├── models ├── cnn_core │ ├── __init__.py │ ├── mobilenet_v3.py │ └── regnet.py ├── perceiver_core │ ├── __init__.py │ ├── utils.py │ ├── config.py │ └── modules.py ├── __init__.py ├── registry.py └── dyn_perceiver_resnet.py ├── figs ├── fig6_ee.png ├── fig1_idea.png ├── fig9_video.png ├── fig2_overview.png ├── fig3_x2z_ca.png ├── fig5_z2x_ca.png ├── fig3_components.png ├── fig4_token_mixer.png ├── fig7_main_results.png ├── fig8_mob_results.png └── fig10_speed_4subfig.png ├── LICENSE ├── .gitignore ├── losses.py ├── datasets.py ├── README.md ├── adaptive_inference.py ├── optim_factory.py ├── engine.py ├── utils.py └── main_baseline.py /models/cnn_core/__init__.py: -------------------------------------------------------------------------------- 1 | from .regnet import * 2 | from .mobilenet_v3 import * -------------------------------------------------------------------------------- /models/perceiver_core/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import * 2 | from .modules import * 3 | -------------------------------------------------------------------------------- /figs/fig6_ee.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeapLabTHU/Dynamic_Perceiver/HEAD/figs/fig6_ee.png -------------------------------------------------------------------------------- /figs/fig1_idea.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeapLabTHU/Dynamic_Perceiver/HEAD/figs/fig1_idea.png -------------------------------------------------------------------------------- /figs/fig9_video.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeapLabTHU/Dynamic_Perceiver/HEAD/figs/fig9_video.png -------------------------------------------------------------------------------- /figs/fig2_overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeapLabTHU/Dynamic_Perceiver/HEAD/figs/fig2_overview.png -------------------------------------------------------------------------------- /figs/fig3_x2z_ca.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeapLabTHU/Dynamic_Perceiver/HEAD/figs/fig3_x2z_ca.png -------------------------------------------------------------------------------- /figs/fig5_z2x_ca.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeapLabTHU/Dynamic_Perceiver/HEAD/figs/fig5_z2x_ca.png -------------------------------------------------------------------------------- /figs/fig3_components.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeapLabTHU/Dynamic_Perceiver/HEAD/figs/fig3_components.png -------------------------------------------------------------------------------- /figs/fig4_token_mixer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeapLabTHU/Dynamic_Perceiver/HEAD/figs/fig4_token_mixer.png -------------------------------------------------------------------------------- /figs/fig7_main_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeapLabTHU/Dynamic_Perceiver/HEAD/figs/fig7_main_results.png 
-------------------------------------------------------------------------------- /figs/fig8_mob_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeapLabTHU/Dynamic_Perceiver/HEAD/figs/fig8_mob_results.png -------------------------------------------------------------------------------- /figs/fig10_speed_4subfig.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeapLabTHU/Dynamic_Perceiver/HEAD/figs/fig10_speed_4subfig.png -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .dyn_perceiver_regnet import * 2 | from .dyn_perceiver_mobilenet_v3 import * 3 | from .dyn_perceiver_resnet import * 4 | -------------------------------------------------------------------------------- /models/perceiver_core/utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class Sequential(nn.Sequential): 5 | def forward(self, *x): 6 | for module in self: 7 | if type(x) == tuple: 8 | x = module(*x) 9 | else: 10 | x = module(x) 11 | return x 12 | 13 | 14 | def freeze(module: nn.Module): 15 | for param in module.parameters(): 16 | param.requires_grad = False 17 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Meta Platforms, Inc. and affiliates. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /models/perceiver_core/config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import asdict, dataclass, fields 2 | from typing import Generic, Optional, TypeVar 3 | 4 | 5 | @dataclass 6 | class EncoderConfig: 7 | num_cross_attention_heads: int = 8 8 | num_cross_attention_qk_channels: Optional[int] = None 9 | num_cross_attention_v_channels: Optional[int] = None 10 | num_cross_attention_layers: int = 1 11 | first_cross_attention_layer_shared: bool = False 12 | cross_attention_widening_factor: int = 1 13 | num_self_attention_heads: int = 8 14 | num_self_attention_qk_channels: Optional[int] = None 15 | num_self_attention_v_channels: Optional[int] = None 16 | num_self_attention_layers_per_block: int = 8 17 | num_self_attention_blocks: int = 1 18 | first_self_attention_block_shared: bool = True 19 | self_attention_widening_factor: int = 1 20 | dropout: float = 0.0 21 | freeze: bool = False 22 | 23 | def base_kwargs(self, exclude=("freeze",)): 24 | return _base_kwargs(self, EncoderConfig, exclude) 25 | 26 | 27 | @dataclass 28 | class DecoderConfig: 29 | num_cross_attention_heads: int = 8 30 | num_cross_attention_qk_channels: Optional[int] = None 31 | num_cross_attention_v_channels: Optional[int] = None 32 | cross_attention_widening_factor: int = 1 33 | dropout: float = 0.0 34 | freeze: bool = False 35 | 36 | def base_kwargs(self, exclude=("freeze",)): 37 | return _base_kwargs(self, DecoderConfig, exclude) 38 | 39 | 40 | E = TypeVar("E", bound=EncoderConfig) 41 | D = TypeVar("D", bound=DecoderConfig) 42 | 43 | 44 | @dataclass 45 | class PerceiverConfig(Generic[E, D]): 46 | encoder: E 47 | decoder: D 48 | num_latents: int 49 | num_latent_channels: int 50 | activation_checkpointing: bool 51 | 52 | 53 | @dataclass 54 | class ClassificationDecoderConfig(DecoderConfig): 55 | num_output_queries: int = 1 56 | num_output_query_channels: int = 256 57 | num_classes: int = 100 58 | 59 | 60 | def _base_kwargs(config, base_class, exclude): 61 | base_field_names = [field.name for field in fields(base_class) if field.name not in exclude] 62 | return {k: v for k, v in asdict(config).items() if k in base_field_names} 63 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | cluster/ 3 | results/ 4 | results_best/ 5 | ckpt_cswin/cswin_tiny_224.pth 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | *.tar.gz 11 | *.pth 12 | 13 | # C extensions 14 | *.so 15 | 16 | # Distribution / packaging 17 | .Python 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | pip-wheel-metadata/ 31 | share/python-wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | MANIFEST 36 | 37 | exp/ 38 | models/flops.txt 39 | 40 | # PyInstaller 41 | # Usually these files are written by a python script from a template 42 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
43 | *.manifest 44 | *.spec 45 | 46 | # Installer logs 47 | pip-log.txt 48 | pip-delete-this-directory.txt 49 | 50 | # Unit test / coverage reports 51 | htmlcov/ 52 | .tox/ 53 | .nox/ 54 | .coverage 55 | .coverage.* 56 | .cache 57 | nosetests.xml 58 | coverage.xml 59 | *.cover 60 | *.py,cover 61 | .hypothesis/ 62 | .pytest_cache/ 63 | 64 | # Translations 65 | *.mo 66 | *.pot 67 | 68 | # Django stuff: 69 | *.log 70 | local_settings.py 71 | db.sqlite3 72 | db.sqlite3-journal 73 | 74 | # Flask stuff: 75 | instance/ 76 | .webassets-cache 77 | 78 | # Scrapy stuff: 79 | .scrapy 80 | 81 | # Sphinx documentation 82 | docs/_build/ 83 | 84 | # PyBuilder 85 | target/ 86 | 87 | # Jupyter Notebook 88 | .ipynb_checkpoints 89 | 90 | # IPython 91 | profile_default/ 92 | ipython_config.py 93 | 94 | # pyenv 95 | .python-version 96 | 97 | # pipenv 98 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 99 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 100 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 101 | # install all needed dependencies. 102 | #Pipfile.lock 103 | 104 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 105 | __pypackages__/ 106 | 107 | # Celery stuff 108 | celerybeat-schedule 109 | celerybeat.pid 110 | 111 | # SageMath parsed files 112 | *.sage.py 113 | 114 | # Environments 115 | .env 116 | .venv 117 | env/ 118 | venv/ 119 | ENV/ 120 | env.bak/ 121 | venv.bak/ 122 | 123 | # Spyder project settings 124 | .spyderproject 125 | .spyproject 126 | 127 | # Rope project settings 128 | .ropeproject 129 | 130 | # mkdocs documentation 131 | /site 132 | 133 | # mypy 134 | .mypy_cache/ 135 | .dmypy.json 136 | dmypy.json 137 | 138 | # Pyre type checker 139 | .pyre/ 140 | -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | """ 4 | Implements the knowledge distillation loss 5 | """ 6 | import torch 7 | from torch.nn import functional as F 8 | 9 | 10 | class DistillationLoss(torch.nn.Module): 11 | """ 12 | This module wraps a standard criterion and adds an extra knowledge distillation loss by 13 | taking a teacher model prediction and using it as additional supervision. 14 | """ 15 | 16 | def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module, 17 | distillation_type: str, alpha: float, tau: float): 18 | super().__init__() 19 | self.base_criterion = base_criterion 20 | self.teacher_model = teacher_model 21 | self.teacher_model.eval() 22 | assert distillation_type in ['none', 'soft', 'hard'] 23 | self.distillation_type = distillation_type 24 | self.alpha = alpha 25 | self.tau = tau 26 | 27 | def forward(self, inputs, outputs, labels): 28 | """ 29 | Args: 30 | inputs: The original inputs that are feed to the teacher model 31 | outputs: the outputs of the model to be trained. 
It is expected to be 32 | either a Tensor, or a Tuple[Tensor, Tensor], with the original output 33 | in the first position and the distillation predictions as the second output 34 | labels: the labels for the base criterion 35 | """ 36 | outputs_kd = None 37 | if not isinstance(outputs, torch.Tensor): 38 | # assume that the model outputs a tuple of [outputs, outputs_kd] 39 | outputs, outputs_kd = outputs 40 | else: 41 | outputs_kd = outputs 42 | base_loss = self.base_criterion(outputs, labels) 43 | if self.distillation_type == 'none': 44 | return base_loss 45 | 46 | if outputs_kd is None: 47 | raise ValueError("When knowledge distillation is enabled, the model is " 48 | "expected to return a Tuple[Tensor, Tensor] with the output of the " 49 | "class_token and the dist_token") 50 | # don't backprop throught the teacher 51 | with torch.no_grad(): 52 | teacher_outputs = self.teacher_model(inputs) 53 | 54 | if self.distillation_type == 'soft': 55 | T = self.tau 56 | # taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100 57 | # with slight modifications 58 | distillation_loss = F.kl_div( 59 | F.log_softmax(outputs_kd / T, dim=1), 60 | F.log_softmax(teacher_outputs / T, dim=1), 61 | reduction='sum', 62 | log_target=True 63 | ) * (T * T) / outputs_kd.numel() 64 | elif self.distillation_type == 'hard': 65 | distillation_loss = F.cross_entropy( 66 | outputs_kd, teacher_outputs.argmax(dim=1)) 67 | 68 | loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha 69 | return loss 70 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
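# Dataset helpers: build_dataset() returns a (dataset, nb_classes) pair for CIFAR-100,
# ImageNet-1k ('IMNET') or a generic ImageFolder; build_transform() uses timm's
# create_transform for training and a resize + center-crop pipeline for evaluation.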
7 | 8 | 9 | import os 10 | from torchvision import datasets, transforms 11 | # from dataloader_hf import FireFlyerImageNet 12 | from timm.data.constants import \ 13 | IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD 14 | from timm.data import create_transform 15 | 16 | def build_dataset(is_train, args): 17 | transform = build_transform(is_train, args) 18 | 19 | print("Transform = ") 20 | if isinstance(transform, tuple): 21 | for trans in transform: 22 | print(" - - - - - - - - - - ") 23 | for t in trans.transforms: 24 | print(t) 25 | else: 26 | for t in transform.transforms: 27 | print(t) 28 | print("---------------------------") 29 | 30 | if args.data_set == 'CIFAR': 31 | dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform, download=True) 32 | nb_classes = 100 33 | elif args.data_set == 'IMNET': 34 | print("reading from datapath", args.data_path) 35 | root = os.path.join(args.data_path, 'train' if is_train else 'val') 36 | dataset = datasets.ImageFolder(root, transform=transform) 37 | nb_classes = 1000 38 | elif args.data_set == "image_folder": 39 | root = args.data_path if is_train else args.eval_data_path 40 | dataset = datasets.ImageFolder(root, transform=transform) 41 | nb_classes = args.nb_classes 42 | assert len(dataset.class_to_idx) == nb_classes 43 | else: 44 | raise NotImplementedError() 45 | print("Number of the class = %d" % nb_classes) 46 | 47 | return dataset, nb_classes 48 | 49 | def build_transform(is_train, args): 50 | resize_im = args.input_size > 32 51 | imagenet_default_mean_and_std = args.imagenet_default_mean_and_std 52 | mean = IMAGENET_INCEPTION_MEAN if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_MEAN 53 | std = IMAGENET_INCEPTION_STD if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_STD 54 | 55 | if is_train: 56 | # this should always dispatch to transforms_imagenet_train 57 | 58 | transform = create_transform( 59 | input_size=args.input_size, 60 | is_training=True, 61 | color_jitter=args.color_jitter, 62 | auto_augment=args.aa, 63 | interpolation=args.train_interpolation, 64 | re_prob=args.reprob, 65 | re_mode=args.remode, 66 | re_count=args.recount, 67 | mean=mean, 68 | std=std, 69 | ) 70 | if not resize_im: 71 | transform.transforms[0] = transforms.RandomCrop( 72 | args.input_size, padding=4) 73 | return transform 74 | 75 | t = [] 76 | if resize_im: 77 | # warping (no cropping) when evaluated at 384 or larger 78 | if args.input_size >= 384: 79 | t.append( 80 | transforms.Resize((args.input_size, args.input_size), 81 | interpolation=transforms.InterpolationMode.BICUBIC), 82 | ) 83 | print(f"Warping {args.input_size} size input images...") 84 | else: 85 | if args.crop_pct is None: 86 | args.crop_pct = 224 / 256 87 | size = int(args.input_size / args.crop_pct) 88 | t.append( 89 | # to maintain same ratio w.r.t. 
224 images 90 | transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC), 91 | ) 92 | t.append(transforms.CenterCrop(args.input_size)) 93 | 94 | t.append(transforms.ToTensor()) 95 | t.append(transforms.Normalize(mean, std)) 96 | return transforms.Compose(t) 97 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Dynamic Perceiver for Efficient Visual Recognition 2 | 3 | 4 | 5 | ## Introduction 6 | 7 | This repository contains the implementation of the ICCV 2023 paper, [*Dynamic Perceiver for Efficient Visual Recognition*](https://arxiv.org/abs/2306.11248). The proposed **Dynamic Perceiver (Dyn-Perceiver)** decouples the feature extraction procedure and the early classification task with a novel two-branch architecture, which significantly improves model performance in the dynamic early exiting scenario. 8 | 9 | ### Overall idea 10 | 11 | fig1 12 | 13 | 14 | 15 | ### Model overview 16 | 17 | fig2 18 | 19 | 20 | 21 | ### The inference procedure 22 | 23 | fig3 24 | 25 | ## Usage 26 | 27 | ### Dependencies 28 | 29 | - Python: 3.8 30 | - Pytorch: 1.12.1 31 | - Torchvision: 0.13.1 32 | 33 | ### Scripts 34 | 35 | - Train a **RegNetY-based Dynamic Perceiver** model on ImageNet: 36 | 37 | ```bash 38 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 main_earlyExit.py \ 39 | --model reg800m_perceiver_t128 --depth_factor 1 1 1 2 --spatial_reduction true --with_last_CA true --SA_widening_factor 4 --with_x2z true --with_dwc true --with_z2x true --with_isc true \ 40 | --num_workers 4 \ 41 | --model_ema true --model_ema_eval true --epochs 300 \ 42 | --batch_size 128 --lr 1e-3 --loss_cnn_factor 1.0 --loss_att_factor 0.5 --loss_merge_factor 1.0 --update_freq 1 --use_amp false --with_kd true --T_kd 1.0 --alpha_kd 0.5 \ 43 | --data_path YOUR_DATA_PATH \ 44 | --output_dir YOUR_SAVE_PATH &\ 45 | ``` 46 | 47 | - Train a **ResNet-based Dynamic Perceiver** model on ImageNet: 48 | 49 | ```bash 50 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 main_earlyExit.py \ 51 | --model resnet50_0375_perceiver_t128 --depth_factor 1 1 1 1 --spatial_reduction true --with_last_CA true --SA_widening_factor 4 --with_x2z true --with_dwc true --with_z2x true --with_isc true \ 52 | --num_workers 4 \ 53 | --model_ema true --model_ema_eval true --epochs 300 \ 54 | --batch_size 128 --lr 6e-4 --loss_cnn_factor 1.0 --loss_att_factor 0.5 --loss_merge_factor 1.0 --update_freq 1 --use_amp false --with_kd true --T_kd 1.0 --alpha_kd 0.5 \ 55 | --data_path YOUR_DATA_PATH \ 56 | --output_dir YOUR_SAVE_PATH &\ 57 | ``` 58 | 59 | - Train a **MobileNet-based Dynamic Perceiver** model on ImageNet: 60 | 61 | ```bash 62 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 main_earlyExit.py \ 63 | --model mobilenetV3_0x75_perceiver_t128 --depth_factor 1 1 1 3 --spatial_reduction true --with_last_CA true --SA_widening_factor 4 --with_x2z true --with_dwc true --with_z2x true --with_isc true \ 64 | --num_workers 4 \ 65 | --model_ema true --model_ema_eval true --epochs 600 \ 66 | --batch_size 128 --lr 1e-3 --loss_cnn_factor 1.0 --loss_att_factor 0.5 --loss_merge_factor 1.0 --update_freq 1 --use_amp false --with_kd true --T_kd 1.0 --alpha_kd 0.5 \ 67 | --data_path YOUR_DATA_PATH \ 68 | --output_dir YOUR_SAVE_PATH &\ 69 | ``` 70 | 71 | - Evaluate (dynamic): 72 | 73 | ```bash 74 | 
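# Loads a trained checkpoint via --resume and runs dynamic early-exit evaluation
# (the exit-threshold search is implemented in adaptive_inference.py).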
CUDA_VISIBLE_DEVICES=0 python main_earlyExit.py --eval true \ 75 | --resume YOUR_CHECKPOINT_PATH \ 76 | --model reg800m_perceiver_t128 --depth_factor 1 1 1 2 --spatial_reduction true --with_last_CA true --SA_widening_factor 4 --with_x2z true --with_dwc true --with_z2x true --with_isc true \ 77 | --num_workers 4 \ 78 | --batch_size 128 --lr 1e-3 --loss_cnn_factor 1.0 --loss_att_factor 0.5 --loss_merge_factor 1.0 --update_freq 1 --use_amp false --with_kd true --T_kd 1.0 --alpha_kd 0.5 \ 79 | --data_path YOUR_DATA_PATH \ 80 | --output_dir YOUR_SAVE_PATH &\ 81 | ``` 82 | 83 | 84 | 85 | ### Results 86 | 87 | - : ImageNet results of Dyn-Perceiver built on top of MobileNet-v3. 88 | 89 | fig4 90 | 91 | - Speed test results of Dyn-Perceiver. 92 | 93 | fig5 94 | 95 | ### Pre-trained Models on ImageNet 96 | |model|acc_exit1|acc_exit2|acc_exit3|acc_exit4|Checkpoint Link| 97 | |:-:|:-:|:-:|:-:|:-:|:-:| 98 | | reg800m_perceiver_t128 | 68.62 | 78.32 | 79.15 | 79.86 |[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/ca3c8a808d504fdfa2c8/?dl=1) | 99 | | resnet50_0375_perceiver_t128 | 72.93 | 77.52 | 74.32 | 77.70 |[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/844ac164a62246eb9ce8/?dl=1) | 100 | | mobilenetV3_0x75_perceiver_t128 | 53.13 | 71.65 | 71.89 | 74.59 |[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/a86c56e3076146ad8748/?dl=1) | 101 | 102 | ### Contact 103 | 104 | If you have any questions, please feel free to contact the authors. 105 | 106 | Yizeng Han: [hanyz18@mails.tsinghua.edu.cn](mailto:hanyz18@mails.tsinghua.edu.cn), [yizeng38@gmail.com](mailto:yizeng38@gmail.com). 107 | 108 | Dongchen Han: [hdc19@mails.tsinghua.edu.cn](mailto:hdc19@mails.tsinghua.edu.cn), [tianqing1.10000@gmail.com](mailto:tianqing1.10000@gmail.com) 109 | 110 | 111 | 112 | 115 | -------------------------------------------------------------------------------- /models/registry.py: -------------------------------------------------------------------------------- 1 | """ Model Registry 2 | Hacked together by / Copyright 2020 Ross Wightman 3 | """ 4 | 5 | import sys 6 | import re 7 | import fnmatch 8 | from collections import defaultdict 9 | from copy import deepcopy 10 | 11 | __all__ = ['list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules', 12 | 'is_model_default_key', 'has_model_default_key', 'get_model_default_value', 'is_model_pretrained'] 13 | 14 | _module_to_models = defaultdict(set) # dict of sets to check membership of model in module 15 | _model_to_module = {} # mapping of model names to module names 16 | _model_entrypoints = {} # mapping of model names to entrypoint fns 17 | _model_has_pretrained = set() # set of model names that have pretrained weight url present 18 | _model_default_cfgs = dict() # central repo for model default_cfgs 19 | 20 | 21 | def register_model(fn): 22 | # lookup containing module 23 | mod = sys.modules[fn.__module__] 24 | module_name_split = fn.__module__.split('.') 25 | module_name = module_name_split[-1] if len(module_name_split) else '' 26 | 27 | # add model to __all__ in module 28 | model_name = fn.__name__ 29 | if hasattr(mod, '__all__'): 30 | mod.__all__.append(model_name) 31 | else: 32 | mod.__all__ = [model_name] 33 | 34 | # add entries to registry dict/sets 35 | _model_entrypoints[model_name] = fn 36 | _model_to_module[model_name] = module_name 37 | _module_to_models[module_name].add(model_name) 38 | has_pretrained = False # check if model has a pretrained url to allow filtering on this 39 | if hasattr(mod, 'default_cfgs') and model_name in 
mod.default_cfgs: 40 | # this will catch all models that have entrypoint matching cfg key, but miss any aliasing 41 | # entrypoints or non-matching combos 42 | has_pretrained = 'url' in mod.default_cfgs[model_name] and 'http' in mod.default_cfgs[model_name]['url'] 43 | _model_default_cfgs[model_name] = deepcopy(mod.default_cfgs[model_name]) 44 | if has_pretrained: 45 | _model_has_pretrained.add(model_name) 46 | return fn 47 | 48 | 49 | def _natural_key(string_): 50 | return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] 51 | 52 | 53 | def list_models(filter='', module='', pretrained=False, exclude_filters='', name_matches_cfg=False): 54 | """ Return list of available model names, sorted alphabetically 55 | Args: 56 | filter (str) - Wildcard filter string that works with fnmatch 57 | module (str) - Limit model selection to a specific sub-module (ie 'gen_efficientnet') 58 | pretrained (bool) - Include only models with pretrained weights if True 59 | exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter 60 | name_matches_cfg (bool) - Include only models w/ model_name matching default_cfg name (excludes some aliases) 61 | Example: 62 | model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet' 63 | model_list('*resnext*, 'resnet') -- returns all models with 'resnext' in 'resnet' module 64 | """ 65 | if module: 66 | models = list(_module_to_models[module]) 67 | else: 68 | models = _model_entrypoints.keys() 69 | if filter: 70 | models = fnmatch.filter(models, filter) # include these models 71 | if exclude_filters: 72 | if not isinstance(exclude_filters, (tuple, list)): 73 | exclude_filters = [exclude_filters] 74 | for xf in exclude_filters: 75 | exclude_models = fnmatch.filter(models, xf) # exclude these models 76 | if len(exclude_models): 77 | models = set(models).difference(exclude_models) 78 | if pretrained: 79 | models = _model_has_pretrained.intersection(models) 80 | if name_matches_cfg: 81 | models = set(_model_default_cfgs).intersection(models) 82 | return list(sorted(models, key=_natural_key)) 83 | 84 | 85 | def is_model(model_name): 86 | """ Check if a model name exists 87 | """ 88 | return model_name in _model_entrypoints 89 | 90 | 91 | def model_entrypoint(model_name): 92 | """Fetch a model entrypoint for specified model name 93 | """ 94 | return _model_entrypoints[model_name] 95 | 96 | 97 | def list_modules(): 98 | """ Return list of module names that contain models / model entrypoints 99 | """ 100 | modules = _module_to_models.keys() 101 | return list(sorted(modules)) 102 | 103 | 104 | def is_model_in_modules(model_name, module_names): 105 | """Check if a model exists within a subset of modules 106 | Args: 107 | model_name (str) - name of model to check 108 | module_names (tuple, list, set) - names of modules to search in 109 | """ 110 | assert isinstance(module_names, (tuple, list, set)) 111 | return any(model_name in _module_to_models[n] for n in module_names) 112 | 113 | 114 | def has_model_default_key(model_name, cfg_key): 115 | """ Query model default_cfgs for existence of a specific key. 116 | """ 117 | if model_name in _model_default_cfgs and cfg_key in _model_default_cfgs[model_name]: 118 | return True 119 | return False 120 | 121 | 122 | def is_model_default_key(model_name, cfg_key): 123 | """ Return truthy value for specified model default_cfg key, False if does not exist. 
124 | """ 125 | if model_name in _model_default_cfgs and _model_default_cfgs[model_name].get(cfg_key, False): 126 | return True 127 | return False 128 | 129 | 130 | def get_model_default_value(model_name, cfg_key): 131 | """ Get a specific model default_cfg value by key. None if it doesn't exist. 132 | """ 133 | if model_name in _model_default_cfgs: 134 | return _model_default_cfgs[model_name].get(cfg_key, None) 135 | else: 136 | return None 137 | 138 | 139 | def is_model_pretrained(model_name): 140 | return model_name in _model_has_pretrained -------------------------------------------------------------------------------- /adaptive_inference.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import os 4 | import math 5 | import numpy as np 6 | 7 | 8 | def dynamic_evaluate(model, test_loader, val_loader, filename, args): 9 | tester = Tester(model) 10 | # if os.path.exists(os.path.join(args.output_dir, 'logits_single.pth')): 11 | # val_pred, val_target, test_pred, test_target = \ 12 | # torch.load(os.path.join(args.output_dir, 'logits_single.pth')) 13 | # else: 14 | val_pred, val_target = tester.calc_logit(val_loader, early_break=True) 15 | test_pred, test_target = tester.calc_logit(test_loader, early_break=False) 16 | # torch.save((val_pred, val_target, test_pred, test_target), 17 | # os.path.join(args.output_dir, 'logits_single.pth')) 18 | 19 | # flops = torch.load(os.path.join(args.output_dir, 'flops.pth')) 20 | flops = np.loadtxt(f'{args.output_dir}/flops.txt') 21 | flops = [flops[i] for i in [0,1,2,3]] 22 | each_exit = False 23 | with open(os.path.join(args.output_dir, filename), 'w') as fout: 24 | probs_list = generate_distribution(each_exit=each_exit) 25 | for probs in probs_list: 26 | print('\n*****************') 27 | print(probs) 28 | acc_val, _, T = tester.dynamic_eval_find_threshold(val_pred, val_target, probs, flops) 29 | print(T) 30 | acc_test, exp_flops, acc_each_stage = tester.dynamic_eval_with_threshold(test_pred, test_target, flops, T) 31 | print('valid acc: {:.3f}, test acc: {:.3f}, test flops: {:.2f}M'.format(acc_val, acc_test, exp_flops)) 32 | print('acc of each exit: {}'.format(acc_each_stage)) 33 | fout.write('{}\t{}\n'.format(acc_test, exp_flops.item())) 34 | print('----------ALL DONE-----------') 35 | 36 | 37 | def generate_distribution(each_exit=False): 38 | probs_list = [] 39 | if each_exit: 40 | for i in range(4): 41 | probs = torch.zeros(4, dtype=torch.float) 42 | probs[i] = 1 43 | probs_list.append(probs) 44 | else: 45 | p_list = torch.zeros(34) 46 | for i in range(17): 47 | p_list[i] = (i + 4) / 20 48 | p_list[33 - i] = 20 / (i + 4) 49 | 50 | # y_early3, y_att, y_cnn, y_merge 51 | # 对应放缩比例 52 | k = [0.85, 1, 0.5, 1] 53 | for i in range(33): 54 | probs = torch.exp(torch.log(p_list[i]) * torch.range(1, 4)) 55 | probs /= probs.sum() 56 | for j in range(3): 57 | probs[j] *= k[j] 58 | probs[j+1:4] = (1 - probs[0:j+1].sum()) * probs[j+1:4] / probs[j+1:4].sum() 59 | probs_list.append(probs) 60 | return probs_list 61 | 62 | 63 | class Tester(object): 64 | def __init__(self, model): 65 | # self.args = args 66 | self.model = model 67 | self.softmax = nn.Softmax(dim=1).cuda() 68 | 69 | def calc_logit(self, dataloader, early_break=False): 70 | self.model.eval() 71 | n_stage = 4 72 | logits = [[] for _ in range(n_stage)] 73 | targets = [] 74 | # print('xxxxxxxxxxx111111') 75 | # print(len(dataloader)) 76 | for i, (input, target) in enumerate(dataloader): 77 | # print(input.shape, target.shape) 
78 | if early_break and i > 100: 79 | break 80 | targets.append(target) 81 | input = input.cuda() 82 | with torch.no_grad(): 83 | y_early3, y_att, y_cnn, y_merge = self.model(input) 84 | output = [y_early3, y_att, y_cnn, y_merge] 85 | for b in range(n_stage): 86 | _t = self.softmax(output[b]) 87 | 88 | logits[b].append(_t) 89 | if i % 50 == 0: 90 | print('Generate Logit: [{0}/{1}]'.format(i, len(dataloader))) 91 | 92 | for b in range(n_stage): 93 | logits[b] = torch.cat(logits[b], dim=0) 94 | 95 | size = (n_stage, logits[0].size(0), logits[0].size(1)) 96 | ts_logits = torch.Tensor().resize_(size).zero_() 97 | for b in range(n_stage): 98 | ts_logits[b].copy_(logits[b]) 99 | 100 | targets = torch.cat(targets, dim=0) 101 | ts_targets = torch.Tensor().resize_(size[1]).copy_(targets) 102 | 103 | return ts_logits, ts_targets 104 | 105 | def dynamic_eval_find_threshold(self, logits, targets, p, flops): 106 | """ 107 | logits: m * n * c 108 | m: Stages 109 | n: Samples 110 | c: Classes 111 | """ 112 | n_stage, n_sample, c = logits.size() 113 | 114 | max_preds, argmax_preds = logits.max(dim=2, keepdim=False) 115 | 116 | _, sorted_idx = max_preds.sort(dim=1, descending=True) 117 | 118 | filtered = torch.zeros(n_sample) 119 | T = torch.Tensor(n_stage).fill_(1e8) 120 | 121 | for k in range(n_stage - 1): 122 | acc, count = 0.0, 0 123 | out_n = math.floor(n_sample * p[k]) 124 | for i in range(n_sample): 125 | ori_idx = sorted_idx[k][i] 126 | if filtered[ori_idx] == 0: 127 | count += 1 128 | if count == out_n: 129 | T[k] = max_preds[k][ori_idx] 130 | break 131 | filtered.add_(max_preds[k].ge(T[k]).type_as(filtered)) 132 | 133 | T[n_stage -1] = -1e8 # accept all of the samples at the last stage 134 | 135 | acc_rec, exp = torch.zeros(n_stage), torch.zeros(n_stage) 136 | acc, expected_flops = 0, 0 137 | for i in range(n_sample): 138 | gold_label = targets[i] 139 | for k in range(n_stage): 140 | if max_preds[k][i].item() >= T[k]: # force the sample to exit at k 141 | if int(gold_label.item()) == int(argmax_preds[k][i].item()): 142 | acc += 1 143 | acc_rec[k] += 1 144 | exp[k] += 1 145 | break 146 | acc_all = 0 147 | for k in range(n_stage): 148 | _t = 1.0 * exp[k] / n_sample 149 | expected_flops += _t * flops[k] 150 | acc_all += acc_rec[k] 151 | 152 | return acc * 100.0 / n_sample, expected_flops, T 153 | 154 | def dynamic_eval_with_threshold(self, logits, targets, flops, T): 155 | n_stage, n_sample, _ = logits.size() 156 | max_preds, argmax_preds = logits.max(dim=2, keepdim=False) # take the max logits as confidence 157 | 158 | acc_rec, exp = torch.zeros(n_stage), torch.zeros(n_stage) 159 | acc, expected_flops = 0, 0 160 | for i in range(n_sample): 161 | gold_label = targets[i] 162 | for k in range(n_stage): 163 | if max_preds[k][i].item() >= T[k]: # force to exit at k 164 | _g = int(gold_label.item()) 165 | _pred = int(argmax_preds[k][i].item()) 166 | if _g == _pred: 167 | acc += 1 168 | acc_rec[k] += 1 169 | exp[k] += 1 170 | break 171 | acc_all, sample_all = 0, 0 172 | for k in range(n_stage): 173 | _t = exp[k] * 1.0 / n_sample 174 | sample_all += exp[k] 175 | expected_flops += _t * flops[k] 176 | acc_all += acc_rec[k] 177 | 178 | return acc * 100.0 / n_sample, expected_flops, acc_rec / exp 179 | 180 | 181 | if __name__ == '__main__': 182 | print(generate_distribution(each_exit=False)) 183 | -------------------------------------------------------------------------------- /optim_factory.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, 
Inc. and affiliates. 2 | 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | 9 | import torch 10 | from torch import optim as optim 11 | 12 | from timm.optim.adafactor import Adafactor 13 | from timm.optim.adahessian import Adahessian 14 | from timm.optim.adamp import AdamP 15 | from timm.optim.lookahead import Lookahead 16 | from timm.optim.nadam import Nadam 17 | from timm.optim.novograd import NovoGrad 18 | from timm.optim.nvnovograd import NvNovoGrad 19 | from timm.optim.radam import RAdam 20 | from timm.optim.rmsprop_tf import RMSpropTF 21 | from timm.optim.sgdp import SGDP 22 | 23 | import json 24 | 25 | try: 26 | from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD 27 | has_apex = True 28 | except ImportError: 29 | has_apex = False 30 | 31 | 32 | def get_num_layer_for_convnext(var_name): 33 | """ 34 | Divide [3, 3, 27, 3] layers into 12 groups; each group is three 35 | consecutive blocks, including possible neighboring downsample layers; 36 | adapted from https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py 37 | """ 38 | num_max_layer = 12 39 | if var_name.startswith("downsample_layers"): 40 | stage_id = int(var_name.split('.')[1]) 41 | if stage_id == 0: 42 | layer_id = 0 43 | elif stage_id == 1 or stage_id == 2: 44 | layer_id = stage_id + 1 45 | elif stage_id == 3: 46 | layer_id = 12 47 | return layer_id 48 | 49 | elif var_name.startswith("stages"): 50 | stage_id = int(var_name.split('.')[1]) 51 | block_id = int(var_name.split('.')[2]) 52 | if stage_id == 0 or stage_id == 1: 53 | layer_id = stage_id + 1 54 | elif stage_id == 2: 55 | layer_id = 3 + block_id // 3 56 | elif stage_id == 3: 57 | layer_id = 12 58 | return layer_id 59 | else: 60 | return num_max_layer + 1 61 | 62 | class LayerDecayValueAssigner(object): 63 | def __init__(self, values): 64 | self.values = values 65 | 66 | def get_scale(self, layer_id): 67 | return self.values[layer_id] 68 | 69 | def get_layer_id(self, var_name): 70 | return get_num_layer_for_convnext(var_name) 71 | 72 | 73 | def get_parameter_groups(model, weight_decay=1e-5, skip_list=(), get_num_layer=None, get_layer_scale=None): 74 | parameter_group_names = {} 75 | parameter_group_vars = {} 76 | 77 | for name, param in model.named_parameters(): 78 | if not param.requires_grad: 79 | continue # frozen weights 80 | if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list: 81 | group_name = "no_decay" 82 | this_weight_decay = 0. 83 | else: 84 | group_name = "decay" 85 | this_weight_decay = weight_decay 86 | if get_num_layer is not None: 87 | layer_id = get_num_layer(name) 88 | group_name = "layer_%d_%s" % (layer_id, group_name) 89 | else: 90 | layer_id = None 91 | 92 | if group_name not in parameter_group_names: 93 | if get_layer_scale is not None: 94 | scale = get_layer_scale(layer_id) 95 | else: 96 | scale = 1. 
97 | 98 | parameter_group_names[group_name] = { 99 | "weight_decay": this_weight_decay, 100 | "params": [], 101 | "lr_scale": scale 102 | } 103 | parameter_group_vars[group_name] = { 104 | "weight_decay": this_weight_decay, 105 | "params": [], 106 | "lr_scale": scale 107 | } 108 | 109 | parameter_group_vars[group_name]["params"].append(param) 110 | parameter_group_names[group_name]["params"].append(name) 111 | # print("Param groups = %s" % json.dumps(parameter_group_names, indent=2)) 112 | return list(parameter_group_vars.values()) 113 | 114 | 115 | def create_optimizer(args, model, get_num_layer=None, get_layer_scale=None, filter_bias_and_bn=True, skip_list=None): 116 | opt_lower = args.opt.lower() 117 | weight_decay = args.weight_decay 118 | # if weight_decay and filter_bias_and_bn: 119 | if filter_bias_and_bn: 120 | skip = {} 121 | if skip_list is not None: 122 | skip = skip_list 123 | elif hasattr(model, 'no_weight_decay'): 124 | skip = model.no_weight_decay() 125 | parameters = get_parameter_groups(model, weight_decay, skip, get_num_layer, get_layer_scale) 126 | weight_decay = 0. 127 | else: 128 | parameters = model.parameters() 129 | 130 | if 'fused' in opt_lower: 131 | assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers' 132 | 133 | opt_args = dict(lr=args.lr, weight_decay=weight_decay) 134 | if hasattr(args, 'opt_eps') and args.opt_eps is not None: 135 | opt_args['eps'] = args.opt_eps 136 | if hasattr(args, 'opt_betas') and args.opt_betas is not None: 137 | opt_args['betas'] = args.opt_betas 138 | 139 | opt_split = opt_lower.split('_') 140 | opt_lower = opt_split[-1] 141 | if opt_lower == 'sgd' or opt_lower == 'nesterov': 142 | opt_args.pop('eps', None) 143 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args) 144 | elif opt_lower == 'momentum': 145 | opt_args.pop('eps', None) 146 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args) 147 | elif opt_lower == 'adam': 148 | optimizer = optim.Adam(parameters, **opt_args) 149 | elif opt_lower == 'adamw': 150 | optimizer = optim.AdamW(parameters, **opt_args) 151 | elif opt_lower == 'nadam': 152 | optimizer = Nadam(parameters, **opt_args) 153 | elif opt_lower == 'radam': 154 | optimizer = RAdam(parameters, **opt_args) 155 | elif opt_lower == 'adamp': 156 | optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args) 157 | elif opt_lower == 'sgdp': 158 | optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True, **opt_args) 159 | elif opt_lower == 'adadelta': 160 | optimizer = optim.Adadelta(parameters, **opt_args) 161 | elif opt_lower == 'adafactor': 162 | if not args.lr: 163 | opt_args['lr'] = None 164 | optimizer = Adafactor(parameters, **opt_args) 165 | elif opt_lower == 'adahessian': 166 | optimizer = Adahessian(parameters, **opt_args) 167 | elif opt_lower == 'rmsprop': 168 | optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=args.momentum, **opt_args) 169 | elif opt_lower == 'rmsproptf': 170 | optimizer = RMSpropTF(parameters, alpha=0.9, momentum=args.momentum, **opt_args) 171 | elif opt_lower == 'novograd': 172 | optimizer = NovoGrad(parameters, **opt_args) 173 | elif opt_lower == 'nvnovograd': 174 | optimizer = NvNovoGrad(parameters, **opt_args) 175 | elif opt_lower == 'fusedsgd': 176 | opt_args.pop('eps', None) 177 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args) 178 | elif opt_lower == 'fusedmomentum': 179 | opt_args.pop('eps', None) 180 | optimizer = 
FusedSGD(parameters, momentum=args.momentum, nesterov=False, **opt_args) 181 | elif opt_lower == 'fusedadam': 182 | optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args) 183 | elif opt_lower == 'fusedadamw': 184 | optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args) 185 | elif opt_lower == 'fusedlamb': 186 | optimizer = FusedLAMB(parameters, **opt_args) 187 | elif opt_lower == 'fusednovograd': 188 | opt_args.setdefault('betas', (0.95, 0.98)) 189 | optimizer = FusedNovoGrad(parameters, **opt_args) 190 | else: 191 | assert False and "Invalid optimizer" 192 | 193 | if len(opt_split) > 1: 194 | if opt_split[0] == 'lookahead': 195 | optimizer = Lookahead(optimizer) 196 | 197 | return optimizer 198 | -------------------------------------------------------------------------------- /engine.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | 9 | import math 10 | from typing import Iterable, Optional 11 | import torch 12 | from timm.data import Mixup 13 | from timm.utils import accuracy, ModelEma 14 | 15 | import utils 16 | 17 | import torch.nn.functional as F 18 | 19 | def train_one_epoch_earlyExit(model: torch.nn.Module, criterion: torch.nn.Module, 20 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 21 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, 22 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, log_writer=None, 23 | wandb_logger=None, start_steps=None, lr_schedule_values=None, wd_schedule_values=None, 24 | num_training_steps_per_epoch=None, update_freq=None, use_amp=False, 25 | loss_cnn_factor=0.25, 26 | loss_att_factor=0.25, 27 | loss_merge_factor=0.5, 28 | 29 | with_kd=False, 30 | T_kd=4.0, 31 | alpha_kd=0.5, 32 | criterion_distill=None): 33 | model.train(True) 34 | metric_logger = utils.MetricLogger(delimiter=" ") 35 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 36 | metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 37 | 38 | header = 'Epoch: [{}]'.format(epoch) 39 | print_freq = 10 40 | 41 | optimizer.zero_grad() 42 | 43 | for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 44 | 45 | # if data_iter_step >= 50: 46 | # break 47 | step = data_iter_step // update_freq 48 | if step >= num_training_steps_per_epoch: 49 | continue 50 | it = start_steps + step # global training iteration 51 | # Update LR & WD for the first acc 52 | if lr_schedule_values is not None or wd_schedule_values is not None and data_iter_step % update_freq == 0: 53 | for i, param_group in enumerate(optimizer.param_groups): 54 | if lr_schedule_values is not None: 55 | param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"] 56 | if wd_schedule_values is not None and param_group["weight_decay"] > 0: 57 | param_group["weight_decay"] = wd_schedule_values[it] 58 | 59 | samples = samples.to(device, non_blocking=True) 60 | targets = targets.to(device, non_blocking=True) 61 | 62 | if mixup_fn is not None: 63 | samples, targets = mixup_fn(samples, targets) 64 | 65 | if use_amp: 66 | with torch.cuda.amp.autocast(): 67 | y_early3, y_att, y_cnn, y_merge = model(samples) 68 | loss_cnn, loss_att, loss_merge, loss_early3 = criterion(y_cnn, targets), 
criterion(y_att, targets), criterion(y_merge, targets) if criterion_distill is None else criterion_distill(samples, y_merge, targets), criterion(y_early3, targets) 69 | else: # full precision 70 | y_early3, y_att, y_cnn, y_merge, = model(samples) 71 | loss_cnn, loss_att, loss_merge, loss_early3 = criterion(y_cnn, targets), criterion(y_att, targets), criterion(y_merge, targets) if criterion_distill is None else criterion_distill(samples, y_merge, targets), criterion(y_early3, targets) 72 | 73 | loss = loss_cnn_factor*loss_cnn + loss_att_factor*(loss_att+loss_early3) + loss_merge_factor*loss_merge 74 | 75 | if with_kd: 76 | out_teacher = y_merge.detach() 77 | 78 | kd_loss = F.kl_div(F.log_softmax(y_early3/T_kd, dim=1),F.softmax(out_teacher/T_kd, dim=1), reduction='batchmean') * T_kd**2 + \ 79 | F.kl_div(F.log_softmax(y_att/T_kd, dim=1),F.softmax(out_teacher/T_kd, dim=1), reduction='batchmean') * T_kd**2 + \ 80 | F.kl_div(F.log_softmax(y_cnn/T_kd, dim=1),F.softmax(out_teacher/T_kd, dim=1), reduction='batchmean') * T_kd**2 81 | 82 | loss += alpha_kd * kd_loss 83 | 84 | loss_value = loss.item() 85 | 86 | if not math.isfinite(loss_value): # this could trigger if using AMP 87 | print("Loss is {}, stopping training".format(loss_value)) 88 | assert math.isfinite(loss_value) 89 | 90 | if use_amp: 91 | # this attribute is added by timm on one optimizer (adahessian) 92 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order 93 | loss /= update_freq 94 | grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm, 95 | parameters=model.parameters(), create_graph=is_second_order, 96 | update_grad=(data_iter_step + 1) % update_freq == 0) 97 | if (data_iter_step + 1) % update_freq == 0: 98 | optimizer.zero_grad() 99 | if model_ema is not None: 100 | model_ema.update(model) 101 | else: # full precision 102 | loss /= update_freq 103 | loss.backward() 104 | if (data_iter_step + 1) % update_freq == 0: 105 | optimizer.step() 106 | optimizer.zero_grad() 107 | if model_ema is not None: 108 | model_ema.update(model) 109 | 110 | torch.cuda.synchronize() 111 | 112 | if mixup_fn is None: 113 | class_acc_cnn = (y_cnn.max(-1)[-1] == targets).float().mean() 114 | class_acc_att = (y_att.max(-1)[-1] == targets).float().mean() 115 | class_acc_merge = (y_merge.max(-1)[-1] == targets).float().mean() 116 | else: 117 | class_acc_cnn, class_acc_att, class_acc_merge = None, None, None 118 | metric_logger.update(loss=loss_value) 119 | metric_logger.update(class_acc=class_acc_merge) 120 | min_lr = 10. 121 | max_lr = 0. 
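# Track the smallest and largest learning rate across parameter groups;
# they differ when per-group lr_scale values (layer-wise decay) are in use.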
122 | for group in optimizer.param_groups: 123 | min_lr = min(min_lr, group["lr"]) 124 | max_lr = max(max_lr, group["lr"]) 125 | 126 | metric_logger.update(lr=max_lr) 127 | metric_logger.update(min_lr=min_lr) 128 | weight_decay_value = None 129 | for group in optimizer.param_groups: 130 | if group["weight_decay"] > 0: 131 | weight_decay_value = group["weight_decay"] 132 | metric_logger.update(weight_decay=weight_decay_value) 133 | 134 | 135 | 136 | if use_amp: 137 | metric_logger.update(grad_norm=grad_norm) 138 | 139 | # print(str(metric_logger)) 140 | # assert(0==1) 141 | 142 | if log_writer is not None: 143 | log_writer.update(loss=loss_value, head="loss") 144 | # log_writer.update(class_acc_cnn=class_acc_cnn, head="loss") 145 | # log_writer.update(class_acc_att=class_acc_att, head="loss") 146 | log_writer.update(class_acc=class_acc_merge, head="loss") 147 | log_writer.update(lr=max_lr, head="opt") 148 | log_writer.update(min_lr=min_lr, head="opt") 149 | log_writer.update(weight_decay=weight_decay_value, head="opt") 150 | if use_amp: 151 | log_writer.update(grad_norm=grad_norm, head="opt") 152 | log_writer.set_step() 153 | 154 | if wandb_logger: 155 | wandb_logger._wandb.log({ 156 | 'Rank-0 Batch Wise/train_loss': loss_value, 157 | 'Rank-0 Batch Wise/train_max_lr': max_lr, 158 | 'Rank-0 Batch Wise/train_min_lr': min_lr 159 | }, commit=False) 160 | if class_acc_cnn: 161 | wandb_logger._wandb.log({'Rank-0 Batch Wise/train_class_acc_cnn': class_acc_cnn}, commit=False) 162 | if class_acc_att: 163 | wandb_logger._wandb.log({'Rank-0 Batch Wise/train_class_acc_att': class_acc_att}, commit=False) 164 | if class_acc_merge: 165 | wandb_logger._wandb.log({'Rank-0 Batch Wise/train_class_acc_merge': class_acc_merge}, commit=False) 166 | if use_amp: 167 | wandb_logger._wandb.log({'Rank-0 Batch Wise/train_grad_norm': grad_norm}, commit=False) 168 | wandb_logger._wandb.log({'Rank-0 Batch Wise/global_train_step': it}) 169 | 170 | 171 | # gather the stats from all processes 172 | metric_logger.synchronize_between_processes() 173 | print("Averaged stats:", metric_logger) 174 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 175 | 176 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, 177 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 178 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, 179 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, log_writer=None, 180 | wandb_logger=None, start_steps=None, lr_schedule_values=None, wd_schedule_values=None, 181 | num_training_steps_per_epoch=None, update_freq=None, use_amp=False): 182 | model.train(True) 183 | metric_logger = utils.MetricLogger(delimiter=" ") 184 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 185 | metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 186 | header = 'Epoch: [{}]'.format(epoch) 187 | print_freq = 10 188 | 189 | optimizer.zero_grad() 190 | 191 | for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 192 | step = data_iter_step // update_freq 193 | if step >= num_training_steps_per_epoch: 194 | continue 195 | it = start_steps + step # global training iteration 196 | # Update LR & WD for the first acc 197 | if lr_schedule_values is not None or wd_schedule_values is not None and data_iter_step % update_freq == 0: 198 | for i, param_group in enumerate(optimizer.param_groups): 199 | if lr_schedule_values is 
not None: 200 | param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"] 201 | if wd_schedule_values is not None and param_group["weight_decay"] > 0: 202 | param_group["weight_decay"] = wd_schedule_values[it] 203 | 204 | samples = samples.to(device, non_blocking=True) 205 | targets = targets.to(device, non_blocking=True) 206 | 207 | if mixup_fn is not None: 208 | samples, targets = mixup_fn(samples, targets) 209 | 210 | if use_amp: 211 | with torch.cuda.amp.autocast(): 212 | output = model(samples) 213 | loss = criterion(output, targets) 214 | else: # full precision 215 | output = model(samples) 216 | loss = criterion(output, targets) 217 | 218 | loss_value = loss.item() 219 | 220 | if not math.isfinite(loss_value): # this could trigger if using AMP 221 | print("Loss is {}, stopping training".format(loss_value)) 222 | assert math.isfinite(loss_value) 223 | 224 | if use_amp: 225 | # this attribute is added by timm on one optimizer (adahessian) 226 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order 227 | loss /= update_freq 228 | grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm, 229 | parameters=model.parameters(), create_graph=is_second_order, 230 | update_grad=(data_iter_step + 1) % update_freq == 0) 231 | if (data_iter_step + 1) % update_freq == 0: 232 | optimizer.zero_grad() 233 | if model_ema is not None: 234 | model_ema.update(model) 235 | else: # full precision 236 | loss /= update_freq 237 | loss.backward() 238 | if (data_iter_step + 1) % update_freq == 0: 239 | optimizer.step() 240 | optimizer.zero_grad() 241 | if model_ema is not None: 242 | model_ema.update(model) 243 | 244 | torch.cuda.synchronize() 245 | 246 | if mixup_fn is None: 247 | class_acc = (output.max(-1)[-1] == targets).float().mean() 248 | else: 249 | class_acc = None 250 | metric_logger.update(loss=loss_value) 251 | metric_logger.update(class_acc=class_acc) 252 | min_lr = 10. 253 | max_lr = 0. 
254 | for group in optimizer.param_groups: 255 | min_lr = min(min_lr, group["lr"]) 256 | max_lr = max(max_lr, group["lr"]) 257 | 258 | metric_logger.update(lr=max_lr) 259 | metric_logger.update(min_lr=min_lr) 260 | weight_decay_value = None 261 | for group in optimizer.param_groups: 262 | if group["weight_decay"] > 0: 263 | weight_decay_value = group["weight_decay"] 264 | metric_logger.update(weight_decay=weight_decay_value) 265 | if use_amp: 266 | metric_logger.update(grad_norm=grad_norm) 267 | 268 | if log_writer is not None: 269 | log_writer.update(loss=loss_value, head="loss") 270 | log_writer.update(class_acc=class_acc, head="loss") 271 | log_writer.update(lr=max_lr, head="opt") 272 | log_writer.update(min_lr=min_lr, head="opt") 273 | log_writer.update(weight_decay=weight_decay_value, head="opt") 274 | if use_amp: 275 | log_writer.update(grad_norm=grad_norm, head="opt") 276 | log_writer.set_step() 277 | 278 | if wandb_logger: 279 | wandb_logger._wandb.log({ 280 | 'Rank-0 Batch Wise/train_loss': loss_value, 281 | 'Rank-0 Batch Wise/train_max_lr': max_lr, 282 | 'Rank-0 Batch Wise/train_min_lr': min_lr 283 | }, commit=False) 284 | if class_acc: 285 | wandb_logger._wandb.log({'Rank-0 Batch Wise/train_class_acc': class_acc}, commit=False) 286 | if use_amp: 287 | wandb_logger._wandb.log({'Rank-0 Batch Wise/train_grad_norm': grad_norm}, commit=False) 288 | wandb_logger._wandb.log({'Rank-0 Batch Wise/global_train_step': it}) 289 | 290 | 291 | # gather the stats from all processes 292 | metric_logger.synchronize_between_processes() 293 | print("Averaged stats:", metric_logger) 294 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 295 | 296 | @torch.no_grad() 297 | def evaluate(data_loader, model, device, use_amp=False): 298 | criterion = torch.nn.CrossEntropyLoss() 299 | 300 | metric_logger = utils.MetricLogger(delimiter=" ") 301 | header = 'Test:' 302 | 303 | # switch to evaluation mode 304 | model.eval() 305 | for batch in metric_logger.log_every(data_loader, 10, header): 306 | images = batch[0] 307 | target = batch[-1] 308 | 309 | images = images.to(device, non_blocking=True) 310 | target = target.to(device, non_blocking=True) 311 | 312 | # compute output 313 | if use_amp: 314 | with torch.cuda.amp.autocast(): 315 | output = model(images) 316 | loss = criterion(output, target) 317 | else: 318 | output = model(images) 319 | loss = criterion(output, target) 320 | 321 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 322 | 323 | batch_size = images.shape[0] 324 | metric_logger.update(loss=loss.item()) 325 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 326 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 327 | # gather the stats from all processes 328 | metric_logger.synchronize_between_processes() 329 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 330 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) 331 | 332 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 333 | 334 | @torch.no_grad() 335 | def evaluate_earlyExit(data_loader, model, device, use_amp=False, 336 | loss_cnn_factor=0.25, 337 | loss_att_factor=0.25, 338 | loss_merge_factor=0.5): 339 | criterion = torch.nn.CrossEntropyLoss() 340 | 341 | metric_logger = utils.MetricLogger(delimiter=" ") 342 | header = 'Test:' 343 | 344 | # switch to evaluation mode 345 | model.eval() 346 | for batch in metric_logger.log_every(data_loader, 10, header): 347 | images = batch[0] 
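# same unpacking as in evaluate(): the first element of the batch is the image
# tensor and the last element is the label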
348 | target = batch[-1] 349 | 350 | images = images.to(device, non_blocking=True) 351 | target = target.to(device, non_blocking=True) 352 | 353 | # compute output 354 | if use_amp: 355 | with torch.cuda.amp.autocast(): 356 | y_early3, y_att, y_cnn, y_merge, = model(images) 357 | loss_cnn, loss_att, loss_merge, loss_early3 = criterion(y_cnn, target), criterion(y_att, target), criterion(y_merge, target), criterion(y_early3, target) 358 | else: 359 | y_early3, y_att, y_cnn, y_merge, = model(images) 360 | loss_cnn, loss_att, loss_merge, loss_early3 = criterion(y_cnn, target), criterion(y_att, target), criterion(y_merge, target), criterion(y_early3, target) 361 | 362 | loss = loss_cnn_factor*loss_cnn + loss_att_factor*(loss_att+loss_early3) + loss_merge_factor*loss_merge 363 | 364 | acc1_cnn, acc5_cnn = accuracy(y_cnn, target, topk=(1, 5)) 365 | acc1_att, acc5_att = accuracy(y_att, target, topk=(1, 5)) 366 | acc1_early3, acc5_early3 = accuracy(y_early3, target, topk=(1, 5)) 367 | acc1_merge, acc5_merge = accuracy(y_merge, target, topk=(1, 5)) 368 | 369 | batch_size = images.shape[0] 370 | metric_logger.update(loss=loss.item()) 371 | metric_logger.meters['acc1_cnn'].update(acc1_cnn.item(), n=batch_size) 372 | metric_logger.meters['acc1_att'].update(acc1_att.item(), n=batch_size) 373 | metric_logger.meters['acc1_early3'].update(acc1_early3.item(), n=batch_size) 374 | metric_logger.meters['acc1_merge'].update(acc1_merge.item(), n=batch_size) 375 | # gather the stats from all processes 376 | metric_logger.synchronize_between_processes() 377 | print('* Acc@1_early3 {top1_early3.global_avg:.3f} \ 378 | Acc@1_merge {top1_merge.global_avg:.3f} \ 379 | Acc@1_att {top1_att.global_avg:.3f} \ 380 | Acc@1_cnn {top1_cnn.global_avg:.3f} \ 381 | loss {losses.global_avg:.3f}' 382 | .format(top1_cnn=metric_logger.acc1_cnn, 383 | top1_att=metric_logger.acc1_att, 384 | top1_early3=metric_logger.acc1_early3, 385 | top1_merge=metric_logger.acc1_merge, 386 | losses=metric_logger.loss)) 387 | 388 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 389 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | 9 | import os 10 | import math 11 | import time 12 | from collections import defaultdict, deque 13 | import datetime 14 | import numpy as np 15 | from timm.utils import get_state_dict 16 | 17 | from pathlib import Path 18 | 19 | import torch 20 | import torch.distributed as dist 21 | from torch._six import inf 22 | 23 | from tensorboardX import SummaryWriter 24 | 25 | class SmoothedValue(object): 26 | """Track a series of values and provide access to smoothed values over a 27 | window or the global series average. 28 | """ 29 | 30 | def __init__(self, window_size=20, fmt=None): 31 | if fmt is None: 32 | fmt = "{median:.4f} ({global_avg:.4f})" 33 | self.deque = deque(maxlen=window_size) 34 | self.total = 0.0 35 | self.count = 0 36 | self.fmt = fmt 37 | 38 | def update(self, value, n=1): 39 | self.deque.append(value) 40 | self.count += n 41 | self.total += value * n 42 | 43 | def synchronize_between_processes(self): 44 | """ 45 | Warning: does not synchronize the deque! 
46 | """ 47 | if not is_dist_avail_and_initialized(): 48 | return 49 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 50 | dist.barrier() 51 | dist.all_reduce(t) 52 | t = t.tolist() 53 | self.count = int(t[0]) 54 | self.total = t[1] 55 | 56 | @property 57 | def median(self): 58 | d = torch.tensor(list(self.deque)) 59 | return d.median().item() 60 | 61 | @property 62 | def avg(self): 63 | d = torch.tensor(list(self.deque), dtype=torch.float32) 64 | return d.mean().item() 65 | 66 | @property 67 | def global_avg(self): 68 | return self.total / self.count 69 | 70 | @property 71 | def max(self): 72 | return max(self.deque) 73 | 74 | @property 75 | def value(self): 76 | return self.deque[-1] 77 | 78 | def __str__(self): 79 | return self.fmt.format( 80 | median=self.median, 81 | avg=self.avg, 82 | global_avg=self.global_avg, 83 | max=self.max, 84 | value=self.value) 85 | 86 | 87 | class MetricLogger(object): 88 | def __init__(self, delimiter="\t"): 89 | self.meters = defaultdict(SmoothedValue) 90 | self.delimiter = delimiter 91 | 92 | def update(self, **kwargs): 93 | for k, v in kwargs.items(): 94 | if v is None: 95 | continue 96 | if isinstance(v, torch.Tensor): 97 | v = v.item() 98 | assert isinstance(v, (float, int)) 99 | self.meters[k].update(v) 100 | 101 | def __getattr__(self, attr): 102 | if attr in self.meters: 103 | return self.meters[attr] 104 | if attr in self.__dict__: 105 | return self.__dict__[attr] 106 | raise AttributeError("'{}' object has no attribute '{}'".format( 107 | type(self).__name__, attr)) 108 | 109 | def __str__(self): 110 | loss_str = [] 111 | for name, meter in self.meters.items(): 112 | loss_str.append( 113 | "{}: {}".format(name, str(meter)) 114 | ) 115 | return self.delimiter.join(loss_str) 116 | 117 | def synchronize_between_processes(self): 118 | for meter in self.meters.values(): 119 | meter.synchronize_between_processes() 120 | 121 | def add_meter(self, name, meter): 122 | self.meters[name] = meter 123 | 124 | def log_every(self, iterable, print_freq, header=None): 125 | i = 0 126 | if not header: 127 | header = '' 128 | start_time = time.time() 129 | end = time.time() 130 | iter_time = SmoothedValue(fmt='{avg:.4f}') 131 | data_time = SmoothedValue(fmt='{avg:.4f}') 132 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 133 | log_msg = [ 134 | header, 135 | '[{0' + space_fmt + '}/{1}]', 136 | 'eta: {eta}', 137 | '{meters}', 138 | 'time: {time}', 139 | 'data: {data}' 140 | ] 141 | if torch.cuda.is_available(): 142 | log_msg.append('max mem: {memory:.0f}') 143 | log_msg = self.delimiter.join(log_msg) 144 | MB = 1024.0 * 1024.0 145 | for obj in iterable: 146 | data_time.update(time.time() - end) 147 | yield obj 148 | iter_time.update(time.time() - end) 149 | if i % print_freq == 0 or i == len(iterable) - 1: 150 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 151 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 152 | if torch.cuda.is_available(): 153 | print(log_msg.format( 154 | i, len(iterable), eta=eta_string, 155 | meters=str(self), 156 | time=str(iter_time), data=str(data_time), 157 | memory=torch.cuda.max_memory_allocated() / MB)) 158 | else: 159 | print(log_msg.format( 160 | i, len(iterable), eta=eta_string, 161 | meters=str(self), 162 | time=str(iter_time), data=str(data_time))) 163 | i += 1 164 | end = time.time() 165 | total_time = time.time() - start_time 166 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 167 | print('{} Total time: {} ({:.4f} s / it)'.format( 168 | 
header, total_time_str, total_time / len(iterable))) 169 | 170 | 171 | class TensorboardLogger(object): 172 | def __init__(self, log_dir): 173 | self.writer = SummaryWriter(logdir=log_dir) 174 | self.step = 0 175 | 176 | def set_step(self, step=None): 177 | if step is not None: 178 | self.step = step 179 | else: 180 | self.step += 1 181 | 182 | def update(self, head='scalar', step=None, **kwargs): 183 | for k, v in kwargs.items(): 184 | if v is None: 185 | continue 186 | if isinstance(v, torch.Tensor): 187 | v = v.item() 188 | assert isinstance(v, (float, int)) 189 | self.writer.add_scalar(head + "/" + k, v, self.step if step is None else step) 190 | 191 | def flush(self): 192 | self.writer.flush() 193 | 194 | 195 | class WandbLogger(object): 196 | def __init__(self, args): 197 | self.args = args 198 | 199 | try: 200 | import wandb 201 | self._wandb = wandb 202 | except ImportError: 203 | raise ImportError( 204 | "To use the Weights and Biases Logger please install wandb." 205 | "Run `pip install wandb` to install it." 206 | ) 207 | 208 | # Initialize a W&B run 209 | if self._wandb.run is None: 210 | self._wandb.init( 211 | project=args.project, 212 | config=args 213 | ) 214 | 215 | def log_epoch_metrics(self, metrics, commit=True): 216 | """ 217 | Log train/test metrics onto W&B. 218 | """ 219 | # Log number of model parameters as W&B summary 220 | self._wandb.summary['n_parameters'] = metrics.get('n_parameters', None) 221 | metrics.pop('n_parameters', None) 222 | 223 | # Log current epoch 224 | self._wandb.log({'epoch': metrics.get('epoch')}, commit=False) 225 | metrics.pop('epoch') 226 | 227 | for k, v in metrics.items(): 228 | if 'train' in k: 229 | self._wandb.log({f'Global Train/{k}': v}, commit=False) 230 | elif 'test' in k: 231 | self._wandb.log({f'Global Test/{k}': v}, commit=False) 232 | 233 | self._wandb.log({}) 234 | 235 | def log_checkpoints(self): 236 | output_dir = self.args.output_dir 237 | model_artifact = self._wandb.Artifact( 238 | self._wandb.run.id + "_model", type="model" 239 | ) 240 | 241 | model_artifact.add_dir(output_dir) 242 | self._wandb.log_artifact(model_artifact, aliases=["latest", "best"]) 243 | 244 | def set_steps(self): 245 | # Set global training step 246 | self._wandb.define_metric('Rank-0 Batch Wise/*', step_metric='Rank-0 Batch Wise/global_train_step') 247 | # Set epoch-wise step 248 | self._wandb.define_metric('Global Train/*', step_metric='epoch') 249 | self._wandb.define_metric('Global Test/*', step_metric='epoch') 250 | 251 | 252 | def setup_for_distributed(is_master): 253 | """ 254 | This function disables printing when not in master process 255 | """ 256 | import builtins as __builtin__ 257 | builtin_print = __builtin__.print 258 | 259 | def print(*args, **kwargs): 260 | force = kwargs.pop('force', False) 261 | if is_master or force: 262 | builtin_print(*args, **kwargs) 263 | 264 | __builtin__.print = print 265 | 266 | 267 | def is_dist_avail_and_initialized(): 268 | if not dist.is_available(): 269 | return False 270 | if not dist.is_initialized(): 271 | return False 272 | return True 273 | 274 | 275 | def get_world_size(): 276 | if not is_dist_avail_and_initialized(): 277 | return 1 278 | return dist.get_world_size() 279 | 280 | 281 | def get_rank(): 282 | if not is_dist_avail_and_initialized(): 283 | return 0 284 | return dist.get_rank() 285 | 286 | 287 | def is_main_process(): 288 | return get_rank() == 0 289 | 290 | 291 | def save_on_master(*args, **kwargs): 292 | if is_main_process(): 293 | torch.save(*args, **kwargs) 294 | 295 | 296 | def 
init_distributed_mode(args): 297 | 298 | if args.dist_on_itp: 299 | args.rank = int(os.environ['OMPI_COMM_WORLD_RANK']) 300 | args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) 301 | args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) 302 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']) 303 | os.environ['LOCAL_RANK'] = str(args.gpu) 304 | os.environ['RANK'] = str(args.rank) 305 | os.environ['WORLD_SIZE'] = str(args.world_size) 306 | # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] 307 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 308 | args.rank = int(os.environ["RANK"]) 309 | args.world_size = int(os.environ['WORLD_SIZE']) 310 | args.gpu = int(os.environ['LOCAL_RANK']) 311 | elif 'SLURM_PROCID' in os.environ: 312 | args.rank = int(os.environ['SLURM_PROCID']) 313 | args.gpu = args.rank % torch.cuda.device_count() 314 | 315 | os.environ['RANK'] = str(args.rank) 316 | os.environ['LOCAL_RANK'] = str(args.gpu) 317 | os.environ['WORLD_SIZE'] = str(args.world_size) 318 | else: 319 | print('Not using distributed mode') 320 | args.distributed = False 321 | return 322 | 323 | args.distributed = True 324 | 325 | torch.cuda.set_device(args.gpu) 326 | args.dist_backend = 'nccl' 327 | print('| distributed init (rank {}): {}, gpu {}'.format( 328 | args.rank, args.dist_url, args.gpu), flush=True) 329 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 330 | world_size=args.world_size, rank=args.rank) 331 | torch.distributed.barrier() 332 | setup_for_distributed(args.rank == 0) 333 | 334 | 335 | def load_state_dict(model, state_dict, prefix='', ignore_missing="relative_position_index"): 336 | missing_keys = [] 337 | unexpected_keys = [] 338 | error_msgs = [] 339 | # copy state_dict so _load_from_state_dict can modify it 340 | metadata = getattr(state_dict, '_metadata', None) 341 | state_dict = state_dict.copy() 342 | if metadata is not None: 343 | state_dict._metadata = metadata 344 | 345 | def load(module, prefix=''): 346 | local_metadata = {} if metadata is None else metadata.get( 347 | prefix[:-1], {}) 348 | module._load_from_state_dict( 349 | state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) 350 | for name, child in module._modules.items(): 351 | if child is not None: 352 | load(child, prefix + name + '.') 353 | 354 | load(model, prefix=prefix) 355 | 356 | warn_missing_keys = [] 357 | ignore_missing_keys = [] 358 | for key in missing_keys: 359 | keep_flag = True 360 | for ignore_key in ignore_missing.split('|'): 361 | if ignore_key in key: 362 | keep_flag = False 363 | break 364 | if keep_flag: 365 | warn_missing_keys.append(key) 366 | else: 367 | ignore_missing_keys.append(key) 368 | 369 | missing_keys = warn_missing_keys 370 | 371 | if len(missing_keys) > 0: 372 | print("Weights of {} not initialized from pretrained model: {}".format( 373 | model.__class__.__name__, missing_keys)) 374 | if len(unexpected_keys) > 0: 375 | print("Weights from pretrained model not used in {}: {}".format( 376 | model.__class__.__name__, unexpected_keys)) 377 | if len(ignore_missing_keys) > 0: 378 | print("Ignored weights of {} not initialized from pretrained model: {}".format( 379 | model.__class__.__name__, ignore_missing_keys)) 380 | if len(error_msgs) > 0: 381 | print('\n'.join(error_msgs)) 382 | 383 | 384 | class NativeScalerWithGradNormCount: 385 | state_dict_key = "amp_scaler" 386 | 387 | def __init__(self): 388 | self._scaler = 
torch.cuda.amp.GradScaler() 389 | 390 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): 391 | self._scaler.scale(loss).backward(create_graph=create_graph) 392 | if update_grad: 393 | if clip_grad is not None: 394 | assert parameters is not None 395 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 396 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 397 | else: 398 | self._scaler.unscale_(optimizer) 399 | norm = get_grad_norm_(parameters) 400 | self._scaler.step(optimizer) 401 | self._scaler.update() 402 | else: 403 | norm = None 404 | return norm 405 | 406 | def state_dict(self): 407 | return self._scaler.state_dict() 408 | 409 | def load_state_dict(self, state_dict): 410 | self._scaler.load_state_dict(state_dict) 411 | 412 | 413 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: 414 | if isinstance(parameters, torch.Tensor): 415 | parameters = [parameters] 416 | parameters = [p for p in parameters if p.grad is not None] 417 | norm_type = float(norm_type) 418 | if len(parameters) == 0: 419 | return torch.tensor(0.) 420 | device = parameters[0].grad.device 421 | if norm_type == inf: 422 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) 423 | else: 424 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) 425 | return total_norm 426 | 427 | 428 | def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, 429 | start_warmup_value=0, warmup_steps=-1): 430 | warmup_schedule = np.array([]) 431 | warmup_iters = warmup_epochs * niter_per_ep 432 | if warmup_steps > 0: 433 | warmup_iters = warmup_steps 434 | print("Set warmup steps = %d" % warmup_iters) 435 | if warmup_epochs > 0: 436 | warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) 437 | 438 | iters = np.arange(epochs * niter_per_ep - warmup_iters) 439 | schedule = np.array( 440 | [final_value + 0.5 * (base_value - final_value) * (1 + math.cos(math.pi * i / (len(iters)))) for i in iters]) 441 | 442 | schedule = np.concatenate((warmup_schedule, schedule)) 443 | 444 | assert len(schedule) == epochs * niter_per_ep 445 | return schedule 446 | 447 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, model_ema=None): 448 | output_dir = Path(args.output_dir) 449 | epoch_name = str(epoch) 450 | checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)] 451 | for checkpoint_path in checkpoint_paths: 452 | to_save = { 453 | 'model': model_without_ddp.state_dict(), 454 | 'optimizer': optimizer.state_dict(), 455 | 'epoch': epoch, 456 | 'scaler': loss_scaler.state_dict(), 457 | 'args': args, 458 | } 459 | 460 | if model_ema is not None: 461 | to_save['model_ema'] = get_state_dict(model_ema) 462 | 463 | save_on_master(to_save, checkpoint_path) 464 | 465 | if is_main_process() and isinstance(epoch, int): 466 | to_del = epoch - args.save_ckpt_num * args.save_ckpt_freq 467 | old_ckpt = output_dir / ('checkpoint-%s.pth' % to_del) 468 | if os.path.exists(old_ckpt): 469 | os.remove(old_ckpt) 470 | 471 | 472 | def auto_load_model(args, model, model_without_ddp, optimizer, loss_scaler, model_ema=None): 473 | output_dir = Path(args.output_dir) 474 | if args.auto_resume and len(args.resume) == 0: 475 | import glob 476 | all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*.pth')) 477 | latest_ckpt = -1 478 | for ckpt in 
all_checkpoints: 479 | t = ckpt.split('-')[-1].split('.')[0] 480 | if t.isdigit(): 481 | latest_ckpt = max(int(t), latest_ckpt) 482 | if latest_ckpt >= 0: 483 | args.resume = os.path.join(output_dir, 'checkpoint-%d.pth' % latest_ckpt) 484 | print("Auto resume checkpoint: %s" % args.resume) 485 | 486 | if args.resume: 487 | if args.resume.startswith('https'): 488 | checkpoint = torch.hub.load_state_dict_from_url( 489 | args.resume, map_location='cpu', check_hash=True) 490 | else: 491 | checkpoint = torch.load(args.resume, map_location='cpu') 492 | if 'model' in checkpoint: 493 | model_without_ddp.load_state_dict(checkpoint['model']) 494 | elif 'state_dict' in checkpoint: 495 | model_without_ddp.load_state_dict(checkpoint['state_dict']) 496 | else: 497 | model_without_ddp.load_state_dict(checkpoint) 498 | print("Resume checkpoint %s" % args.resume) 499 | if 'optimizer' in checkpoint and 'epoch' in checkpoint: 500 | optimizer.load_state_dict(checkpoint['optimizer']) 501 | if not isinstance(checkpoint['epoch'], str): # does not support resuming with 'best', 'best-ema' 502 | args.start_epoch = checkpoint['epoch'] + 1 503 | else: 504 | assert args.eval, 'Does not support resuming with checkpoint-best' 505 | if hasattr(args, 'model_ema') and args.model_ema: 506 | if 'model_ema' in checkpoint.keys(): 507 | model_ema.ema.load_state_dict(checkpoint['model_ema']) 508 | else: 509 | model_ema.ema.load_state_dict(checkpoint['model']) 510 | if 'scaler' in checkpoint: 511 | loss_scaler.load_state_dict(checkpoint['scaler']) 512 | print("With optim & sched!") 513 | -------------------------------------------------------------------------------- /models/cnn_core/mobilenet_v3.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from functools import partial 3 | from typing import Any, Callable, List, Optional, Sequence 4 | 5 | import torch 6 | from torch import nn, Tensor 7 | 8 | from torchvision._internally_replaced_utils import load_state_dict_from_url 9 | from torchvision.ops.misc import ConvNormActivation, SqueezeExcitation as SElayer 10 | # from torchvision.utils import _log_api_usage_once 11 | from torchvision.models._utils import _make_divisible 12 | 13 | 14 | # __all__ = ["MobileNetV3", "mobilenet_v3_large", "mobilenet_v3_small"] 15 | 16 | 17 | model_urls = { 18 | "mobilenet_v3_large": "https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth", 19 | "mobilenet_v3_small": "https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth", 20 | } 21 | 22 | 23 | class SqueezeExcitation(SElayer): 24 | """DEPRECATED""" 25 | 26 | def __init__(self, input_channels: int, squeeze_factor: int = 4): 27 | squeeze_channels = _make_divisible(input_channels // squeeze_factor, 8) 28 | super().__init__(input_channels, squeeze_channels, scale_activation=nn.Hardsigmoid) 29 | self.relu = self.activation 30 | delattr(self, "activation") 31 | warnings.warn( 32 | "This SqueezeExcitation class is deprecated since 0.12 and will be removed in 0.14. 
" 33 | "Use torchvision.ops.SqueezeExcitation instead.", 34 | FutureWarning, 35 | ) 36 | 37 | 38 | class InvertedResidualConfig: 39 | # Stores information listed at Tables 1 and 2 of the MobileNetV3 paper 40 | def __init__( 41 | self, 42 | input_channels: int, 43 | kernel: int, 44 | expanded_channels: int, 45 | out_channels: int, 46 | use_se: bool, 47 | activation: str, 48 | stride: int, 49 | dilation: int, 50 | width_mult: float, 51 | ): 52 | self.input_channels = self.adjust_channels(input_channels, width_mult) 53 | self.kernel = kernel 54 | self.expanded_channels = self.adjust_channels(expanded_channels, width_mult) 55 | self.out_channels = self.adjust_channels(out_channels, width_mult) 56 | self.use_se = use_se 57 | self.use_hs = activation == "HS" 58 | self.stride = stride 59 | self.dilation = dilation 60 | 61 | @staticmethod 62 | def adjust_channels(channels: int, width_mult: float): 63 | return _make_divisible(channels * width_mult, 8) 64 | 65 | 66 | class InvertedResidual(nn.Module): 67 | # Implemented as described at section 5 of MobileNetV3 paper 68 | def __init__( 69 | self, 70 | cnf: InvertedResidualConfig, 71 | norm_layer: Callable[..., nn.Module], 72 | se_layer: Callable[..., nn.Module] = partial(SElayer, scale_activation=nn.Hardsigmoid), 73 | ): 74 | super().__init__() 75 | if not (1 <= cnf.stride <= 2): 76 | raise ValueError("illegal stride value") 77 | 78 | self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels 79 | 80 | layers: List[nn.Module] = [] 81 | activation_layer = nn.Hardswish if cnf.use_hs else nn.ReLU 82 | 83 | self.c_in = cnf.input_channels 84 | self.c_out = cnf.out_channels 85 | # expand 86 | self.has_conv1 = cnf.expanded_channels != cnf.input_channels 87 | if cnf.expanded_channels != cnf.input_channels: 88 | layers.append( 89 | ConvNormActivation( 90 | cnf.input_channels, 91 | cnf.expanded_channels, 92 | kernel_size=1, 93 | norm_layer=norm_layer, 94 | activation_layer=activation_layer, 95 | ) 96 | ) 97 | self.conv1_flops_per_pixel = cnf.input_channels * cnf.expanded_channels 98 | 99 | # depthwise 100 | stride = 1 if cnf.dilation > 1 else cnf.stride 101 | layers.append( 102 | ConvNormActivation( 103 | cnf.expanded_channels, 104 | cnf.expanded_channels, 105 | kernel_size=cnf.kernel, 106 | stride=stride, 107 | dilation=cnf.dilation, 108 | groups=cnf.expanded_channels, 109 | norm_layer=norm_layer, 110 | activation_layer=activation_layer, 111 | ) 112 | ) 113 | self.conv2_flops_per_pixel = cnf.expanded_channels * cnf.kernel**2 114 | if cnf.use_se: 115 | squeeze_channels = _make_divisible(cnf.expanded_channels // 4, 8) 116 | layers.append(se_layer(cnf.expanded_channels, squeeze_channels)) 117 | self.se_flops = cnf.expanded_channels * squeeze_channels * 2 118 | 119 | # project 120 | layers.append( 121 | ConvNormActivation( 122 | cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=None 123 | ) 124 | ) 125 | self.conv3_flops_per_pixel = cnf.expanded_channels * cnf.out_channels 126 | 127 | self.block = nn.ModuleList(layers) 128 | self.out_channels = cnf.out_channels 129 | self._is_cn = cnf.stride > 1 130 | 131 | def forward(self, input: Tensor) -> Tensor: 132 | residual = input 133 | result = input 134 | for i in range(len(self.block)): 135 | result = self.block[i](result) 136 | if self.use_res_connect: 137 | result += residual 138 | return result 139 | 140 | def forward_calc_flops(self, input: Tensor) -> Tensor: 141 | 142 | # print(len(self.block)) 143 | # if len(self.block) == 4: 144 | # 
print(self.block) 145 | input, flops = input 146 | if len(self.block) == 2: 147 | result = self.block[0](input) 148 | flops += self.conv2_flops_per_pixel * result.shape[2] * result.shape[3] 149 | 150 | result = self.block[1](result) 151 | flops += self.conv3_flops_per_pixel * result.shape[2] * result.shape[3] 152 | 153 | elif len(self.block) == 3: 154 | result = self.block[0](input) 155 | flops += self.conv1_flops_per_pixel * result.shape[2] * result.shape[3] 156 | 157 | result = self.block[1](result) 158 | flops += self.conv2_flops_per_pixel * result.shape[2] * result.shape[3] 159 | 160 | result = self.block[2](result) 161 | flops += self.conv3_flops_per_pixel * result.shape[2] * result.shape[3] 162 | 163 | else: 164 | result = self.block[0](input) 165 | flops += self.conv1_flops_per_pixel * result.shape[2] * result.shape[3] 166 | 167 | result = self.block[1](result) 168 | flops += self.conv2_flops_per_pixel * result.shape[2] * result.shape[3] 169 | 170 | # se layer 171 | flops += result.shape[1] * result.shape[2] * result.shape[3] # global pooling 172 | result = self.block[2](result) 173 | flops += self.se_flops 174 | 175 | result = self.block[3](result) 176 | flops += self.conv3_flops_per_pixel * result.shape[2] * result.shape[3] 177 | 178 | # result = self.block(input) 179 | if self.use_res_connect: 180 | result += input 181 | return result, flops 182 | 183 | 184 | class MobileNetV3(nn.Module): 185 | def __init__( 186 | self, 187 | inverted_residual_setting: List[InvertedResidualConfig], 188 | last_channel: int, 189 | num_classes: int = 1000, 190 | block: Optional[Callable[..., nn.Module]] = None, 191 | norm_layer: Optional[Callable[..., nn.Module]] = None, 192 | dropout: float = 0.2, 193 | **kwargs: Any, 194 | ) -> None: 195 | """ 196 | MobileNet V3 main class 197 | 198 | Args: 199 | inverted_residual_setting (List[InvertedResidualConfig]): Network structure 200 | last_channel (int): The number of channels on the penultimate layer 201 | num_classes (int): Number of classes 202 | block (Optional[Callable[..., nn.Module]]): Module specifying inverted residual building block for mobilenet 203 | norm_layer (Optional[Callable[..., nn.Module]]): Module specifying the normalization layer to use 204 | dropout (float): The droupout probability 205 | """ 206 | super().__init__() 207 | # _log_api_usage_once(self) 208 | 209 | if not inverted_residual_setting: 210 | raise ValueError("The inverted_residual_setting should not be empty") 211 | elif not ( 212 | isinstance(inverted_residual_setting, Sequence) 213 | and all([isinstance(s, InvertedResidualConfig) for s in inverted_residual_setting]) 214 | ): 215 | raise TypeError("The inverted_residual_setting should be List[InvertedResidualConfig]") 216 | 217 | if block is None: 218 | block = InvertedResidual 219 | 220 | if norm_layer is None: 221 | norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.01) 222 | 223 | layers: List[nn.Module] = [] 224 | 225 | # building first layer 226 | firstconv_output_channels = inverted_residual_setting[0].input_channels 227 | layers.append( 228 | ConvNormActivation( 229 | 3, 230 | firstconv_output_channels, 231 | kernel_size=3, 232 | stride=2, 233 | norm_layer=norm_layer, 234 | activation_layer=nn.Hardswish, 235 | ) 236 | ) 237 | self.stem_flops_per_pixel = 3 * firstconv_output_channels * 9 238 | 239 | # building inverted residual blocks 240 | for cnf in inverted_residual_setting: 241 | layers.append(block(cnf, norm_layer)) 242 | 243 | # building last several layers 244 | lastconv_input_channels = 
inverted_residual_setting[-1].out_channels 245 | self.lastconv_input_channels = lastconv_input_channels 246 | lastconv_output_channels = 6 * lastconv_input_channels 247 | layers.append( 248 | ConvNormActivation( 249 | lastconv_input_channels, 250 | lastconv_output_channels, 251 | kernel_size=1, 252 | norm_layer=norm_layer, 253 | activation_layer=nn.Hardswish, 254 | ) 255 | ) 256 | self.tail_flops_per_pixel = lastconv_input_channels * lastconv_output_channels 257 | self.lastconv_output_channels = lastconv_output_channels 258 | self.last_channel = last_channel 259 | 260 | self.features = nn.ModuleList(layers) 261 | self.avgpool = nn.AdaptiveAvgPool2d(1) 262 | self.classifier = nn.Sequential( 263 | nn.Linear(lastconv_output_channels, last_channel), 264 | nn.Hardswish(inplace=True), 265 | nn.Dropout(p=dropout, inplace=True), 266 | nn.Linear(last_channel, num_classes), 267 | ) 268 | self.classifier_flops = lastconv_output_channels * last_channel + last_channel * num_classes 269 | for m in self.modules(): 270 | if isinstance(m, nn.Conv2d): 271 | nn.init.kaiming_normal_(m.weight, mode="fan_out") 272 | if m.bias is not None: 273 | nn.init.zeros_(m.bias) 274 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 275 | nn.init.ones_(m.weight) 276 | nn.init.zeros_(m.bias) 277 | elif isinstance(m, nn.Linear): 278 | nn.init.normal_(m.weight, 0, 0.01) 279 | nn.init.zeros_(m.bias) 280 | 281 | def forward(self, x: Tensor) -> Tensor: 282 | print('input', x.shape) 283 | x = self.features[0](x) 284 | flops = self.stem_flops_per_pixel * x.shape[2] * x.shape[3] 285 | print('after features[0]', x.shape) 286 | x = (x, flops) 287 | for i in range(1, len(self.features)-1): 288 | print(f'after features[{i}], {x[0].shape}') 289 | x = self.features[i].forward_calc_flops(x) 290 | x, flops = x 291 | 292 | print('before the last block in features', x.shape) 293 | x = self.features[-1](x) 294 | flops += self.tail_flops_per_pixel * x.shape[2] * x.shape[3] 295 | 296 | flops += x.shape[1] * x.shape[2] * x.shape[3] 297 | x = self.avgpool(x) 298 | x = torch.flatten(x, 1) 299 | 300 | x = self.classifier(x) 301 | flops += self.classifier_flops 302 | print(flops/1e9) 303 | return x 304 | 305 | # def forward(self, x: Tensor) -> Tensor: 306 | # return self._forward_impl(x) 307 | 308 | 309 | def _mobilenet_v3_conf( 310 | arch: str, width_mult: float = 1.0, reduced_tail: bool = False, dilated: bool = False, **kwargs: Any 311 | ): 312 | reduce_divider = 2 if reduced_tail else 1 313 | dilation = 2 if dilated else 1 314 | 315 | bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult) 316 | adjust_channels = partial(InvertedResidualConfig.adjust_channels, width_mult=width_mult) 317 | 318 | if arch == "mobilenet_v3_large": 319 | inverted_residual_setting = [ 320 | bneck_conf(16, 3, 16, 16, False, "RE", 1, 1), 321 | bneck_conf(16, 3, 64, 24, False, "RE", 2, 1), # C1 322 | bneck_conf(24, 3, 72, 24, False, "RE", 1, 1), 323 | bneck_conf(24, 5, 72, 40, True, "RE", 2, 1), # C2 324 | bneck_conf(40, 5, 120, 40, True, "RE", 1, 1), 325 | bneck_conf(40, 5, 120, 40, True, "RE", 1, 1), 326 | bneck_conf(40, 3, 240, 80, False, "HS", 2, 1), # C3 327 | bneck_conf(80, 3, 200, 80, False, "HS", 1, 1), 328 | bneck_conf(80, 3, 184, 80, False, "HS", 1, 1), 329 | bneck_conf(80, 3, 184, 80, False, "HS", 1, 1), 330 | bneck_conf(80, 3, 480, 112, True, "HS", 1, 1), 331 | bneck_conf(112, 3, 672, 112, True, "HS", 1, 1), 332 | bneck_conf(112, 5, 672, 160 // reduce_divider, True, "HS", 2, dilation), # C4 333 | bneck_conf(160 // reduce_divider, 5, 960 // 
reduce_divider, 160 // reduce_divider, True, "HS", 1, dilation), 334 | bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1, dilation), 335 | ] 336 | last_channel = adjust_channels(1280 // reduce_divider) # C5 337 | elif arch == "mobilenet_v3_small": 338 | inverted_residual_setting = [ 339 | bneck_conf(16, 3, 16, 16, True, "RE", 2, 1), # C1 340 | bneck_conf(16, 3, 72, 24, False, "RE", 2, 1), # C2 341 | bneck_conf(24, 3, 88, 24, False, "RE", 1, 1), 342 | bneck_conf(24, 5, 96, 40, True, "HS", 2, 1), # C3 343 | bneck_conf(40, 5, 240, 40, True, "HS", 1, 1), 344 | bneck_conf(40, 5, 240, 40, True, "HS", 1, 1), 345 | bneck_conf(40, 5, 120, 48, True, "HS", 1, 1), 346 | bneck_conf(48, 5, 144, 48, True, "HS", 1, 1), 347 | bneck_conf(48, 5, 288, 96 // reduce_divider, True, "HS", 2, dilation), # C4 348 | bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1, dilation), 349 | bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1, dilation), 350 | ] 351 | last_channel = adjust_channels(1024 // reduce_divider) # C5 352 | else: 353 | raise ValueError(f"Unsupported model type {arch}") 354 | 355 | return inverted_residual_setting, last_channel 356 | 357 | 358 | def _mobilenet_v3( 359 | arch: str, 360 | inverted_residual_setting: List[InvertedResidualConfig], 361 | last_channel: int, 362 | pretrained: bool, 363 | progress: bool, 364 | **kwargs: Any, 365 | ): 366 | model = MobileNetV3(inverted_residual_setting, last_channel, **kwargs) 367 | if pretrained: 368 | if model_urls.get(arch, None) is None: 369 | raise ValueError(f"No checkpoint is available for model type {arch}") 370 | state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) 371 | model.load_state_dict(state_dict) 372 | return model 373 | 374 | 375 | def mobilenet_v3_large(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3: 376 | """ 377 | Constructs a large MobileNetV3 architecture from 378 | `"Searching for MobileNetV3" `_. 379 | 380 | Args: 381 | pretrained (bool): If True, returns a model pre-trained on ImageNet 382 | progress (bool): If True, displays a progress bar of the download to stderr 383 | """ 384 | arch = "mobilenet_v3_large" 385 | inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, **kwargs) 386 | return _mobilenet_v3(arch, inverted_residual_setting, last_channel, pretrained, progress, **kwargs) 387 | 388 | def mobilenet_v3_large_1x25(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3: 389 | """ 390 | Constructs a large MobileNetV3 architecture from 391 | `"Searching for MobileNetV3" `_. 392 | 393 | Args: 394 | pretrained (bool): If True, returns a model pre-trained on ImageNet 395 | progress (bool): If True, displays a progress bar of the download to stderr 396 | """ 397 | arch = "mobilenet_v3_large" 398 | inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, width_mult=1.25, **kwargs) 399 | return _mobilenet_v3(arch, inverted_residual_setting, last_channel, pretrained, progress, **kwargs) 400 | 401 | def mobilenet_v3_large_1x5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3: 402 | """ 403 | Constructs a large MobileNetV3 architecture from 404 | `"Searching for MobileNetV3" `_. 
405 | 406 | Args: 407 | pretrained (bool): If True, returns a model pre-trained on ImageNet 408 | progress (bool): If True, displays a progress bar of the download to stderr 409 | """ 410 | arch = "mobilenet_v3_large" 411 | inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, width_mult=1.5, **kwargs) 412 | return _mobilenet_v3(arch, inverted_residual_setting, last_channel, pretrained, progress, **kwargs) 413 | 414 | def mobilenet_v3_large_2x0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3: 415 | """ 416 | Constructs a large MobileNetV3 architecture from 417 | `"Searching for MobileNetV3" `_. 418 | 419 | Args: 420 | pretrained (bool): If True, returns a model pre-trained on ImageNet 421 | progress (bool): If True, displays a progress bar of the download to stderr 422 | """ 423 | arch = "mobilenet_v3_large" 424 | inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, width_mult=2.0, **kwargs) 425 | return _mobilenet_v3(arch, inverted_residual_setting, last_channel, pretrained, progress, **kwargs) 426 | 427 | def mobilenet_v3_large_0x75(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3: 428 | """ 429 | Constructs a large MobileNetV3 architecture from 430 | `"Searching for MobileNetV3" `_. 431 | 432 | Args: 433 | pretrained (bool): If True, returns a model pre-trained on ImageNet 434 | progress (bool): If True, displays a progress bar of the download to stderr 435 | """ 436 | arch = "mobilenet_v3_large" 437 | inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, width_mult=0.75, **kwargs) 438 | return _mobilenet_v3(arch, inverted_residual_setting, last_channel, pretrained, progress, **kwargs) 439 | 440 | def mobilenet_v3_large_0x5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3: 441 | """ 442 | Constructs a large MobileNetV3 architecture from 443 | `"Searching for MobileNetV3" `_. 444 | 445 | Args: 446 | pretrained (bool): If True, returns a model pre-trained on ImageNet 447 | progress (bool): If True, displays a progress bar of the download to stderr 448 | """ 449 | arch = "mobilenet_v3_large" 450 | inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, width_mult=0.5, **kwargs) 451 | return _mobilenet_v3(arch, inverted_residual_setting, last_channel, pretrained, progress, **kwargs) 452 | 453 | 454 | def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3: 455 | """ 456 | Constructs a small MobileNetV3 architecture from 457 | `"Searching for MobileNetV3" `_. 
458 | 459 | Args: 460 | pretrained (bool): If True, returns a model pre-trained on ImageNet 461 | progress (bool): If True, displays a progress bar of the download to stderr 462 | """ 463 | arch = "mobilenet_v3_small" 464 | inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, **kwargs) 465 | return _mobilenet_v3(arch, inverted_residual_setting, last_channel, pretrained, progress, **kwargs) 466 | 467 | if __name__ == '__main__': 468 | model = mobilenet_v3_large_2x0() 469 | x = torch.rand(1,3,224,224) 470 | 471 | y = model(x) 472 | 473 | # from op_counter import measure_model 474 | # cls_ops, cls_params = measure_model(model, 224, 224) 475 | # print(cls_ops[-1]/1e9) -------------------------------------------------------------------------------- /models/cnn_core/regnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections import OrderedDict 3 | from functools import partial 4 | from typing import Any, Callable, List, Optional, Tuple 5 | 6 | import torch 7 | from torch import nn, Tensor 8 | 9 | from torchvision._internally_replaced_utils import load_state_dict_from_url 10 | from torchvision.ops.misc import ConvNormActivation, SqueezeExcitation 11 | # from torchvision.utils import _log_api_usage_once 12 | from torchvision.models._utils import _make_divisible 13 | 14 | model_urls = { 15 | "regnet_y_400mf": "https://download.pytorch.org/models/regnet_y_400mf-c65dace8.pth", 16 | "regnet_y_800mf": "https://download.pytorch.org/models/regnet_y_800mf-1b27b58c.pth", 17 | "regnet_y_1_6gf": "https://download.pytorch.org/models/regnet_y_1_6gf-b11a554e.pth", 18 | "regnet_y_3_2gf": "https://download.pytorch.org/models/regnet_y_3_2gf-b5a9779c.pth", 19 | "regnet_y_8gf": "https://download.pytorch.org/models/regnet_y_8gf-d0d0e4a8.pth", 20 | "regnet_y_16gf": "https://download.pytorch.org/models/regnet_y_16gf-9e6ed7dd.pth", 21 | "regnet_y_32gf": "https://download.pytorch.org/models/regnet_y_32gf-4dee3f7a.pth", 22 | "regnet_x_400mf": "https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth", 23 | "regnet_x_800mf": "https://download.pytorch.org/models/regnet_x_800mf-ad17e45c.pth", 24 | "regnet_x_1_6gf": "https://download.pytorch.org/models/regnet_x_1_6gf-e3633e7f.pth", 25 | "regnet_x_3_2gf": "https://download.pytorch.org/models/regnet_x_3_2gf-f342aeae.pth", 26 | "regnet_x_8gf": "https://download.pytorch.org/models/regnet_x_8gf-03ceed89.pth", 27 | "regnet_x_16gf": "https://download.pytorch.org/models/regnet_x_16gf-2007eb11.pth", 28 | "regnet_x_32gf": "https://download.pytorch.org/models/regnet_x_32gf-9d47f8d0.pth", 29 | } 30 | 31 | class SimpleStemIN(ConvNormActivation): 32 | """Simple stem for ImageNet: 3x3, BN, ReLU.""" 33 | 34 | def __init__( 35 | self, 36 | width_in: int, 37 | width_out: int, 38 | norm_layer: Callable[..., nn.Module], 39 | activation_layer: Callable[..., nn.Module], 40 | ) -> None: 41 | self.width_out = width_out 42 | super().__init__( 43 | width_in, width_out, kernel_size=3, stride=2, norm_layer=norm_layer, activation_layer=activation_layer 44 | ) 45 | 46 | class BottleneckTransform(nn.Sequential): 47 | """Bottleneck transformation: 1x1, 3x3 [+SE], 1x1.""" 48 | 49 | def __init__( 50 | self, 51 | width_in: int, 52 | width_out: int, 53 | stride: int, 54 | norm_layer: Callable[..., nn.Module], 55 | activation_layer: Callable[..., nn.Module], 56 | group_width: int, 57 | bottleneck_multiplier: float, 58 | se_ratio: Optional[float], 59 | ) -> None: 60 | self.width_in = width_in 61 | self.width_out = width_out 
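# (Editor's note, added for clarity; not in the original regnet.py.) The layers
# assembled below follow the RegNet bottleneck layout: a 1x1 conv down to the
# bottleneck width w_b = round(width_out * bottleneck_multiplier), a 3x3 grouped
# conv with g = w_b // group_width groups that carries the stride, an optional
# squeeze-excitation block whose squeeze width is computed from width_in via
# se_ratio, and a final 1x1 projection back to width_out with no activation.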
62 | layers: OrderedDict[str, nn.Module] = OrderedDict() 63 | w_b = int(round(width_out * bottleneck_multiplier)) 64 | g = w_b // group_width 65 | 66 | layers["a"] = ConvNormActivation( 67 | width_in, w_b, kernel_size=1, stride=1, norm_layer=norm_layer, activation_layer=activation_layer 68 | ) 69 | layers["b"] = ConvNormActivation( 70 | w_b, w_b, kernel_size=3, stride=stride, groups=g, norm_layer=norm_layer, activation_layer=activation_layer 71 | ) 72 | 73 | if se_ratio: 74 | # The SE reduction ratio is defined with respect to the 75 | # beginning of the block 76 | width_se_out = int(round(se_ratio * width_in)) 77 | layers["se"] = SqueezeExcitation( 78 | input_channels=w_b, 79 | squeeze_channels=width_se_out, 80 | activation=activation_layer, 81 | ) 82 | 83 | layers["c"] = ConvNormActivation( 84 | w_b, width_out, kernel_size=1, stride=1, norm_layer=norm_layer, activation_layer=None 85 | ) 86 | super().__init__(layers) 87 | 88 | 89 | class ResBottleneckBlock(nn.Module): 90 | """Residual bottleneck block: x + F(x), F = bottleneck transform.""" 91 | 92 | def __init__( 93 | self, 94 | width_in: int, 95 | width_out: int, 96 | stride: int, 97 | norm_layer: Callable[..., nn.Module], 98 | activation_layer: Callable[..., nn.Module], 99 | group_width: int = 1, 100 | bottleneck_multiplier: float = 1.0, 101 | se_ratio: Optional[float] = None, 102 | ) -> None: 103 | super().__init__() 104 | 105 | # Use skip connection with projection if shape changes 106 | self.proj = None 107 | should_proj = (width_in != width_out) or (stride != 1) 108 | if should_proj: 109 | self.proj = ConvNormActivation( 110 | width_in, width_out, kernel_size=1, stride=stride, norm_layer=norm_layer, activation_layer=None 111 | ) 112 | self.f = BottleneckTransform( 113 | width_in, 114 | width_out, 115 | stride, 116 | norm_layer, 117 | activation_layer, 118 | group_width, 119 | bottleneck_multiplier, 120 | se_ratio, 121 | ) 122 | self.activation = activation_layer(inplace=True) 123 | w_b = int(round(width_out * bottleneck_multiplier)) 124 | g = w_b // group_width 125 | self.conv1_flops_per_pixel = width_in*w_b 126 | self.conv2_flops_per_pixel = w_b*w_b*9 // g 127 | self.conv3_flops_per_pixel = w_b*width_out 128 | self.stride=stride 129 | 130 | if self.proj is not None: 131 | self.downsample_flops = width_in * width_out 132 | 133 | def forward(self, x: Tensor) -> Tensor: 134 | identity = self.proj(x) if self.proj is not None else x 135 | 136 | x = self.f(x) 137 | x = self.activation(x + identity) 138 | return x 139 | 140 | 141 | def forward_calc_flops(self, x: Tensor) -> Tensor: 142 | x, flops = x 143 | if self.proj is not None: 144 | identity = self.proj(x) 145 | flops += self.downsample_flops * identity.shape[2] * identity.shape[3] 146 | else: 147 | identity = x 148 | 149 | x = self.f(x) 150 | 151 | b,c,h,w = x.shape 152 | h_conv1 = h * self.stride 153 | 154 | flops += (self.conv1_flops_per_pixel * h_conv1 * h_conv1 + self.conv2_flops_per_pixel * h * w + self.conv3_flops_per_pixel * h * w) 155 | 156 | x = self.activation(x + identity) 157 | return x, flops 158 | 159 | 160 | class AnyStage(nn.ModuleList): 161 | """AnyNet stage (sequence of blocks w/ the same output shape).""" 162 | 163 | def __init__( 164 | self, 165 | width_in: int, 166 | width_out: int, 167 | stride: int, 168 | depth: int, 169 | block_constructor: Callable[..., nn.Module], 170 | norm_layer: Callable[..., nn.Module], 171 | activation_layer: Callable[..., nn.Module], 172 | group_width: int, 173 | bottleneck_multiplier: float, 174 | se_ratio: Optional[float] = None, 
175 | stage_index: int = 0, 176 | ) -> None: 177 | super().__init__() 178 | 179 | self.c_in = width_in 180 | self.c_out = width_out 181 | 182 | for i in range(depth): 183 | block = block_constructor( 184 | width_in if i == 0 else width_out, 185 | width_out, 186 | stride if i == 0 else 1, 187 | norm_layer, 188 | activation_layer, 189 | group_width, 190 | bottleneck_multiplier, 191 | se_ratio, 192 | ) 193 | 194 | self.add_module(f"block{stage_index}-{i}", block) 195 | 196 | def forward(self, x): 197 | for block in self: 198 | x = block(x) 199 | return x 200 | 201 | def forward_calc_flops(self, x): 202 | flops = 0 203 | x = (x, flops) 204 | for block in self: 205 | x = block.forward_calc_flops(x) 206 | 207 | # x, flops = x 208 | return x # should be a tuple of (x, flops) 209 | 210 | 211 | class BlockParams: 212 | def __init__( 213 | self, 214 | depths: List[int], 215 | widths: List[int], 216 | group_widths: List[int], 217 | bottleneck_multipliers: List[float], 218 | strides: List[int], 219 | se_ratio: Optional[float] = None, 220 | ) -> None: 221 | self.depths = depths 222 | self.widths = widths 223 | self.group_widths = group_widths 224 | self.bottleneck_multipliers = bottleneck_multipliers 225 | self.strides = strides 226 | self.se_ratio = se_ratio 227 | 228 | @classmethod 229 | def from_init_params( 230 | cls, 231 | depth: int, 232 | w_0: int, 233 | w_a: float, 234 | w_m: float, 235 | group_width: int, 236 | bottleneck_multiplier: float = 1.0, 237 | se_ratio: Optional[float] = None, 238 | **kwargs: Any, 239 | ) -> "BlockParams": 240 | """ 241 | Programmatically compute all the per-block settings, 242 | given the RegNet parameters. 243 | 244 | The first step is to compute the quantized linear block parameters, 245 | in log space. Key parameters are: 246 | - `w_a` is the width progression slope 247 | - `w_0` is the initial width 248 | - `w_m` is the width stepping in the log space 249 | 250 | In other terms 251 | `log(block_width) = log(w_0) + block_capacity * log(w_m)`, 252 | with `block_capacity` ramping up following the w_0 and w_a params. 253 | This block width is finally quantized to multiples of 8. 254 | 255 | The second step is to compute the parameters per stage, 256 | taking into account the skip connection and the final 1x1 convolutions. 257 | We use the fact that the output width is constant within a stage. 258 | """ 259 | 260 | QUANT = 8 261 | STRIDE = 2 262 | 263 | if w_a < 0 or w_0 <= 0 or w_m <= 1 or w_0 % 8 != 0: 264 | raise ValueError("Invalid RegNet settings") 265 | # Compute the block widths.
Each stage has one unique block width 266 | widths_cont = torch.arange(depth) * w_a + w_0 267 | block_capacity = torch.round(torch.log(widths_cont / w_0) / math.log(w_m)) 268 | block_widths = (torch.round(torch.divide(w_0 * torch.pow(w_m, block_capacity), QUANT)) * QUANT).int().tolist() 269 | num_stages = len(set(block_widths)) 270 | 271 | # Convert to per stage parameters 272 | split_helper = zip( 273 | block_widths + [0], 274 | [0] + block_widths, 275 | block_widths + [0], 276 | [0] + block_widths, 277 | ) 278 | splits = [w != wp or r != rp for w, wp, r, rp in split_helper] 279 | 280 | stage_widths = [w for w, t in zip(block_widths, splits[:-1]) if t] 281 | stage_depths = torch.diff(torch.tensor([d for d, t in enumerate(splits) if t])).int().tolist() 282 | 283 | strides = [STRIDE] * num_stages 284 | bottleneck_multipliers = [bottleneck_multiplier] * num_stages 285 | group_widths = [group_width] * num_stages 286 | 287 | # Adjust the compatibility of stage widths and group widths 288 | stage_widths, group_widths = cls._adjust_widths_groups_compatibilty( 289 | stage_widths, bottleneck_multipliers, group_widths 290 | ) 291 | 292 | return cls( 293 | depths=stage_depths, 294 | widths=stage_widths, 295 | group_widths=group_widths, 296 | bottleneck_multipliers=bottleneck_multipliers, 297 | strides=strides, 298 | se_ratio=se_ratio, 299 | ) 300 | 301 | def _get_expanded_params(self): 302 | return zip(self.widths, self.strides, self.depths, self.group_widths, self.bottleneck_multipliers) 303 | 304 | @staticmethod 305 | def _adjust_widths_groups_compatibilty( 306 | stage_widths: List[int], bottleneck_ratios: List[float], group_widths: List[int] 307 | ) -> Tuple[List[int], List[int]]: 308 | """ 309 | Adjusts the compatibility of widths and groups, 310 | depending on the bottleneck ratio. 
311 | """ 312 | # Compute all widths for the current settings 313 | widths = [int(w * b) for w, b in zip(stage_widths, bottleneck_ratios)] 314 | group_widths_min = [min(g, w_bot) for g, w_bot in zip(group_widths, widths)] 315 | 316 | # Compute the adjusted widths so that stage and group widths fit 317 | ws_bot = [_make_divisible(w_bot, g) for w_bot, g in zip(widths, group_widths_min)] 318 | stage_widths = [int(w_bot / b) for w_bot, b in zip(ws_bot, bottleneck_ratios)] 319 | return stage_widths, group_widths_min 320 | 321 | class RegNet(nn.Module): 322 | def __init__( 323 | self, 324 | block_params: BlockParams, 325 | num_classes: int = 1000, 326 | stem_width: int = 32, 327 | stem_type: Optional[Callable[..., nn.Module]] = None, 328 | block_type: Optional[Callable[..., nn.Module]] = None, 329 | norm_layer: Optional[Callable[..., nn.Module]] = None, 330 | activation: Optional[Callable[..., nn.Module]] = None, 331 | ) -> None: 332 | super().__init__() 333 | # _log_api_usage_once(self) 334 | 335 | if stem_type is None: 336 | stem_type = SimpleStemIN 337 | if norm_layer is None: 338 | norm_layer = nn.BatchNorm2d 339 | if block_type is None: 340 | block_type = ResBottleneckBlock 341 | if activation is None: 342 | activation = nn.ReLU 343 | 344 | # Ad hoc stem 345 | self.stem = stem_type( 346 | 3, # width_in 347 | stem_width, 348 | norm_layer, 349 | activation, 350 | ) 351 | 352 | current_width = stem_width 353 | 354 | blocks = [] 355 | 356 | # print(block_params._get_expanded_params()) 357 | 358 | for i, ( 359 | width_out, 360 | stride, 361 | depth, 362 | group_width, 363 | bottleneck_multiplier, 364 | ) in enumerate(block_params._get_expanded_params()): 365 | # print(stride, depth) 366 | blocks.append( 367 | ( 368 | f"block{i+1}", 369 | AnyStage( 370 | current_width, 371 | width_out, 372 | stride, 373 | depth, 374 | block_type, 375 | norm_layer, 376 | activation, 377 | group_width, 378 | bottleneck_multiplier, 379 | block_params.se_ratio, 380 | stage_index=i + 1, 381 | ), 382 | ) 383 | ) 384 | 385 | current_width = width_out 386 | 387 | self.trunk_output = nn.Sequential(OrderedDict(blocks)) 388 | 389 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 390 | self.fc = nn.Linear(in_features=current_width, out_features=num_classes) 391 | 392 | # Performs ResNet-style weight initialization 393 | for m in self.modules(): 394 | if isinstance(m, nn.Conv2d): 395 | # Note that there is no bias due to BN 396 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 397 | nn.init.normal_(m.weight, mean=0.0, std=math.sqrt(2.0 / fan_out)) 398 | elif isinstance(m, nn.BatchNorm2d): 399 | nn.init.ones_(m.weight) 400 | nn.init.zeros_(m.bias) 401 | elif isinstance(m, nn.Linear): 402 | nn.init.normal_(m.weight, mean=0.0, std=0.01) 403 | nn.init.zeros_(m.bias) 404 | 405 | def forward(self, x: Tensor) -> Tensor: 406 | x = self.stem(x) 407 | x = self.trunk_output(x) 408 | 409 | x = self.avgpool(x) 410 | x = x.flatten(start_dim=1) 411 | x = self.fc(x) 412 | 413 | return x 414 | 415 | 416 | def _regnet(arch: str, block_params: BlockParams, pretrained: bool, progress: bool, **kwargs: Any) -> RegNet: 417 | norm_layer = kwargs.pop("norm_layer", partial(nn.BatchNorm2d, eps=1e-05, momentum=0.1)) 418 | model = RegNet(block_params, norm_layer=norm_layer, **kwargs) 419 | if pretrained: 420 | if arch not in model_urls: 421 | raise ValueError(f"No checkpoint is available for model type {arch}") 422 | state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) 423 | model.load_state_dict(state_dict) 424 | return 
model 425 | 426 | def regnet_y_400mf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: 427 | """ 428 | Constructs a RegNetY_400MF architecture from 429 | `"Designing Network Design Spaces" `_. 430 | 431 | Args: 432 | pretrained (bool): If True, returns a model pre-trained on ImageNet 433 | progress (bool): If True, displays a progress bar of the download to stderr 434 | """ 435 | params = BlockParams.from_init_params(depth=16, w_0=48, w_a=27.89, w_m=2.09, group_width=8, se_ratio=0.25, **kwargs) 436 | return _regnet("regnet_y_400mf", params, pretrained, progress, **kwargs) 437 | 438 | 439 | 440 | 441 | def regnet_y_800mf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: 442 | """ 443 | Constructs a RegNetY_800MF architecture from 444 | `"Designing Network Design Spaces" `_. 445 | 446 | Args: 447 | pretrained (bool): If True, returns a model pre-trained on ImageNet 448 | progress (bool): If True, displays a progress bar of the download to stderr 449 | """ 450 | params = BlockParams.from_init_params(depth=14, w_0=56, w_a=38.84, w_m=2.4, group_width=16, se_ratio=0.25, **kwargs) 451 | return _regnet("regnet_y_800mf", params, pretrained, progress, **kwargs) 452 | 453 | 454 | 455 | def regnet_y_1_6gf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: 456 | """ 457 | Constructs a RegNetY_1.6GF architecture from 458 | `"Designing Network Design Spaces" `_. 459 | 460 | Args: 461 | pretrained (bool): If True, returns a model pre-trained on ImageNet 462 | progress (bool): If True, displays a progress bar of the download to stderr 463 | """ 464 | params = BlockParams.from_init_params( 465 | depth=27, w_0=48, w_a=20.71, w_m=2.65, group_width=24, se_ratio=0.25, **kwargs 466 | ) 467 | return _regnet("regnet_y_1_6gf", params, pretrained, progress, **kwargs) 468 | 469 | 470 | 471 | def regnet_y_3_2gf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: 472 | """ 473 | Constructs a RegNetY_3.2GF architecture from 474 | `"Designing Network Design Spaces" `_. 475 | 476 | Args: 477 | pretrained (bool): If True, returns a model pre-trained on ImageNet 478 | progress (bool): If True, displays a progress bar of the download to stderr 479 | """ 480 | params = BlockParams.from_init_params( 481 | depth=21, w_0=80, w_a=42.63, w_m=2.66, group_width=24, se_ratio=0.25, **kwargs 482 | ) 483 | return _regnet("regnet_y_3_2gf", params, pretrained, progress, **kwargs) 484 | 485 | 486 | 487 | 488 | def regnet_y_8gf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: 489 | """ 490 | Constructs a RegNetY_8GF architecture from 491 | `"Designing Network Design Spaces" `_. 492 | 493 | Args: 494 | pretrained (bool): If True, returns a model pre-trained on ImageNet 495 | progress (bool): If True, displays a progress bar of the download to stderr 496 | """ 497 | params = BlockParams.from_init_params( 498 | depth=17, w_0=192, w_a=76.82, w_m=2.19, group_width=56, se_ratio=0.25, **kwargs 499 | ) 500 | return _regnet("regnet_y_8gf", params, pretrained, progress, **kwargs) 501 | 502 | 503 | 504 | def regnet_y_16gf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: 505 | """ 506 | Constructs a RegNetY_16GF architecture from 507 | `"Designing Network Design Spaces" `_. 
508 | 509 | Args: 510 | pretrained (bool): If True, returns a model pre-trained on ImageNet 511 | progress (bool): If True, displays a progress bar of the download to stderr 512 | """ 513 | params = BlockParams.from_init_params( 514 | depth=18, w_0=200, w_a=106.23, w_m=2.48, group_width=112, se_ratio=0.25, **kwargs 515 | ) 516 | return _regnet("regnet_y_16gf", params, pretrained, progress, **kwargs) 517 | 518 | 519 | 520 | def regnet_y_32gf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: 521 | """ 522 | Constructs a RegNetY_32GF architecture from 523 | `"Designing Network Design Spaces" `_. 524 | 525 | Args: 526 | pretrained (bool): If True, returns a model pre-trained on ImageNet 527 | progress (bool): If True, displays a progress bar of the download to stderr 528 | """ 529 | params = BlockParams.from_init_params( 530 | depth=20, w_0=232, w_a=115.89, w_m=2.53, group_width=232, se_ratio=0.25, **kwargs 531 | ) 532 | return _regnet("regnet_y_32gf", params, pretrained, progress, **kwargs) 533 | 534 | 535 | 536 | def regnet_y_128gf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet: 537 | """ 538 | Constructs a RegNetY_128GF architecture from 539 | `"Designing Network Design Spaces" `_. 540 | NOTE: Pretrained weights are not available for this model. 541 | """ 542 | params = BlockParams.from_init_params( 543 | depth=27, w_0=456, w_a=160.83, w_m=2.52, group_width=264, se_ratio=0.25, **kwargs 544 | ) 545 | return _regnet("regnet_y_128gf", params, pretrained, progress, **kwargs) 546 | 547 | 548 | if __name__ == '__main__': 549 | net = regnet_y_400mf(pretrained=False) 550 | print(net.fc.weight.shape) -------------------------------------------------------------------------------- /main_baseline.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
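# (Editor's note: illustrative usage sketch, not part of the original script.
# The model name and dataset path below are placeholders; every flag shown is
# defined in get_args_parser() further down, and the default lr of 4e-3 is meant
# for a total batch size of 4096, i.e. per-GPU batch * number of GPUs * update_freq.)
#
#   torchrun --nproc_per_node=8 main_baseline.py \
#       --model <registered_model_name> --data_set IMNET \
#       --data_path /path/to/imagenet --batch_size 128 --update_freq 4 \
#       --epochs 300 --lr 4e-3 --output_dir ./checkpoints --use_amp true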
7 | 8 | import warnings 9 | warnings.filterwarnings('ignore') 10 | 11 | import argparse 12 | import datetime 13 | import numpy as np 14 | import time 15 | import torch 16 | import torch.nn as nn 17 | import torch.backends.cudnn as cudnn 18 | import json 19 | import os 20 | 21 | from pathlib import Path 22 | 23 | from timm.data.mixup import Mixup 24 | from timm.models import create_model 25 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 26 | from timm.utils import ModelEma 27 | from optim_factory import create_optimizer, LayerDecayValueAssigner 28 | 29 | from datasets import build_dataset 30 | from engine import train_one_epoch, evaluate 31 | 32 | from utils import NativeScalerWithGradNormCount as NativeScaler 33 | import utils 34 | 35 | import models 36 | 37 | 38 | def str2bool(v): 39 | """ 40 | Converts string to bool type; enables command line 41 | arguments in the format of '--arg1 true --arg2 false' 42 | """ 43 | if isinstance(v, bool): 44 | return v 45 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 46 | return True 47 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 48 | return False 49 | else: 50 | raise argparse.ArgumentTypeError('Boolean value expected.') 51 | 52 | def get_args_parser(): 53 | parser = argparse.ArgumentParser('ConvNeXt training and evaluation script for image classification', add_help=False) 54 | parser.add_argument('--batch_size', default=64, type=int, 55 | help='Per GPU batch size') 56 | parser.add_argument('--epochs', default=300, type=int) 57 | parser.add_argument('--update_freq', default=1, type=int, 58 | help='gradient accumulation steps') 59 | 60 | # Model parameters 61 | parser.add_argument('--model', default='convnext_tiny', type=str, metavar='MODEL', 62 | help='Name of model to train') 63 | parser.add_argument('--classifier_type', type=str, default='cnn', choices=['cnn', 'latent', 'merge'],) 64 | parser.add_argument('--drop_path', type=float, default=0, metavar='PCT', 65 | help='Drop path rate (default: 0.0)') 66 | parser.add_argument('--dropout', type=float, default=0) 67 | parser.add_argument('--input_size', default=224, type=int, 68 | help='image input size') 69 | parser.add_argument('--num_latent_channels', type=int, default=None) 70 | parser.add_argument('--SA_widening_factor', type=int, default=1) 71 | parser.add_argument('--num_SA_heads', type=list, default=None) 72 | parser.add_argument('--spatial_reduction', type=str2bool, default=False) 73 | parser.add_argument('--layer_scale_init_value', default=1e-6, type=float, 74 | help="Layer scale initial values") 75 | 76 | # EMA related parameters 77 | parser.add_argument('--model_ema', type=str2bool, default=False) 78 | parser.add_argument('--model_ema_decay', type=float, default=0.9999, help='') 79 | parser.add_argument('--model_ema_force_cpu', type=str2bool, default=False, help='') 80 | parser.add_argument('--model_ema_eval', type=str2bool, default=False, help='Using ema to eval during training.') 81 | 82 | # Optimization parameters 83 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER', 84 | help='Optimizer (default: "adamw"') 85 | parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON', 86 | help='Optimizer Epsilon (default: 1e-8)') 87 | parser.add_argument('--opt_betas', default=None, type=float, nargs='+', metavar='BETA', 88 | help='Optimizer Betas (default: None, use opt default)') 89 | parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM', 90 | help='Clip gradient norm (default: None, no 
clipping)') 91 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 92 | help='SGD momentum (default: 0.9)') 93 | parser.add_argument('--weight_decay', type=float, default=0.05, 94 | help='weight decay (default: 0.05)') 95 | parser.add_argument('--weight_decay_end', type=float, default=None, help="""Final value of the 96 | weight decay. We use a cosine schedule for WD and using a larger decay by 97 | the end of training improves performance for ViTs.""") 98 | 99 | parser.add_argument('--lr', type=float, default=4e-3, metavar='LR', 100 | help='learning rate (default: 4e-3), with total batch size 4096') 101 | parser.add_argument('--layer_decay', type=float, default=1.0) 102 | parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR', 103 | help='lower lr bound for cyclic schedulers that hit 0 (1e-6)') 104 | parser.add_argument('--warmup_epochs', type=int, default=20, metavar='N', 105 | help='epochs to warmup LR, if scheduler supports') 106 | parser.add_argument('--warmup_steps', type=int, default=-1, metavar='N', 107 | help='num of steps to warmup LR, will overload warmup_epochs if set > 0') 108 | 109 | # Augmentation parameters 110 | parser.add_argument('--color_jitter', type=float, default=0.4, metavar='PCT', 111 | help='Color jitter factor (default: 0.4)') 112 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', 113 | help='Use AutoAugment policy. "v0" or "original". " + "(default: rand-m9-mstd0.5-inc1)'), 114 | parser.add_argument('--smoothing', type=float, default=0.1, 115 | help='Label smoothing (default: 0.1)') 116 | parser.add_argument('--train_interpolation', type=str, default='bicubic', 117 | help='Training interpolation (random, bilinear, bicubic default: "bicubic")') 118 | 119 | # Evaluation parameters 120 | parser.add_argument('--crop_pct', type=float, default=None) 121 | 122 | # * Random Erase params 123 | parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', 124 | help='Random erase prob (default: 0.25)') 125 | parser.add_argument('--remode', type=str, default='pixel', 126 | help='Random erase mode (default: "pixel")') 127 | parser.add_argument('--recount', type=int, default=1, 128 | help='Random erase count (default: 1)') 129 | parser.add_argument('--resplit', type=str2bool, default=False, 130 | help='Do not random erase first (clean) augmentation split') 131 | 132 | # * Mixup params 133 | parser.add_argument('--mixup', type=float, default=0.8, 134 | help='mixup alpha, mixup enabled if > 0.') 135 | parser.add_argument('--cutmix', type=float, default=1.0, 136 | help='cutmix alpha, cutmix enabled if > 0.') 137 | parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None, 138 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') 139 | parser.add_argument('--mixup_prob', type=float, default=1.0, 140 | help='Probability of performing mixup or cutmix when either/both is enabled') 141 | parser.add_argument('--mixup_switch_prob', type=float, default=0.5, 142 | help='Probability of switching to cutmix when both mixup and cutmix enabled') 143 | parser.add_argument('--mixup_mode', type=str, default='batch', 144 | help='How to apply mixup/cutmix params. 
Per "batch", "pair", or "elem"') 145 | 146 | # * Finetuning params 147 | parser.add_argument('--finetune', default='', 148 | help='finetune from checkpoint') 149 | parser.add_argument('--head_init_scale', default=1.0, type=float, 150 | help='classifier head initial scale, typically adjusted in fine-tuning') 151 | parser.add_argument('--model_key', default='model|module|state_dict_ema', type=str, 152 | help='which key to load from saved state dict, usually model or model_ema') 153 | parser.add_argument('--model_prefix', default='', type=str) 154 | 155 | # Dataset parameters 156 | parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str, 157 | help='dataset path') 158 | parser.add_argument('--eval_data_path', default=None, type=str, 159 | help='dataset path for evaluation') 160 | parser.add_argument('--nb_classes', default=1000, type=int, 161 | help='number of the classification types') 162 | parser.add_argument('--imagenet_default_mean_and_std', type=str2bool, default=True) 163 | parser.add_argument('--data_set', default='IMNET', choices=['CIFAR', 'IMNET', 'image_folder'], 164 | type=str, help='ImageNet dataset path') 165 | parser.add_argument('--output_dir', default='', 166 | help='path where to save, empty for no saving') 167 | parser.add_argument('--log_dir', default=None, 168 | help='path where to tensorboard log') 169 | parser.add_argument('--device', default='cuda', 170 | help='device to use for training / testing') 171 | parser.add_argument('--seed', default=0, type=int) 172 | 173 | parser.add_argument('--resume', default='', 174 | help='resume from checkpoint') 175 | parser.add_argument('--auto_resume', type=str2bool, default=True) 176 | parser.add_argument('--save_ckpt', type=str2bool, default=True) 177 | parser.add_argument('--save_ckpt_freq', default=1, type=int) 178 | parser.add_argument('--save_ckpt_num', default=3, type=int) 179 | 180 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 181 | help='start epoch') 182 | parser.add_argument('--eval', type=str2bool, default=False, 183 | help='Perform evaluation only') 184 | parser.add_argument('--dist_eval', type=str2bool, default=True, 185 | help='Enabling distributed evaluation') 186 | parser.add_argument('--disable_eval', type=str2bool, default=False, 187 | help='Disabling evaluation during training') 188 | parser.add_argument('--num_workers', default=10, type=int) 189 | parser.add_argument('--pin_mem', type=str2bool, default=True, 190 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 191 | 192 | # distributed training parameters 193 | parser.add_argument('--world_size', default=1, type=int, 194 | help='number of distributed processes') 195 | parser.add_argument('--local_rank', default=-1, type=int) 196 | parser.add_argument('--dist_on_itp', type=str2bool, default=False) 197 | parser.add_argument('--dist_url', default='env://', 198 | help='url used to set up distributed training') 199 | 200 | parser.add_argument('--use_amp', type=str2bool, default=False, 201 | help="Use PyTorch's AMP (Automatic Mixed Precision) or not") 202 | 203 | # Weights and Biases arguments 204 | parser.add_argument('--enable_wandb', type=str2bool, default=False, 205 | help="enable logging to Weights and Biases") 206 | parser.add_argument('--project', default='convnext', type=str, 207 | help="The name of the W&B project where you're sending the new run.") 208 | parser.add_argument('--wandb_ckpt', type=str2bool, default=False, 209 | help="Save model checkpoints as 
W&B Artifacts.") 210 | 211 | return parser 212 | 213 | def main(args): 214 | utils.init_distributed_mode(args) 215 | print(args) 216 | device = torch.device(args.device) 217 | 218 | # fix the seed for reproducibility 219 | seed = args.seed + utils.get_rank() 220 | torch.manual_seed(seed) 221 | np.random.seed(seed) 222 | cudnn.benchmark = True 223 | 224 | dataset_train, args.nb_classes = build_dataset(is_train=True, args=args) 225 | if args.disable_eval: 226 | args.dist_eval = False 227 | dataset_val = None 228 | else: 229 | dataset_val, _ = build_dataset(is_train=False, args=args) 230 | 231 | num_tasks = utils.get_world_size() 232 | global_rank = utils.get_rank() 233 | 234 | sampler_train = torch.utils.data.DistributedSampler( 235 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True, seed=args.seed, 236 | ) 237 | print("Sampler_train = %s" % str(sampler_train)) 238 | if args.dist_eval: 239 | if len(dataset_val) % num_tasks != 0: 240 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. ' 241 | 'This will slightly alter validation results as extra duplicate entries are added to achieve ' 242 | 'equal num of samples per-process.') 243 | sampler_val = torch.utils.data.DistributedSampler( 244 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False) 245 | else: 246 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 247 | 248 | if global_rank == 0 and args.log_dir is not None: 249 | os.makedirs(args.log_dir, exist_ok=True) 250 | log_writer = utils.TensorboardLogger(log_dir=args.log_dir) 251 | else: 252 | log_writer = None 253 | 254 | if global_rank == 0 and args.enable_wandb: 255 | wandb_logger = utils.WandbLogger(args) 256 | else: 257 | wandb_logger = None 258 | 259 | data_loader_train = torch.utils.data.DataLoader( 260 | dataset_train, sampler=sampler_train, 261 | batch_size=args.batch_size, 262 | num_workers=args.num_workers, 263 | pin_memory=args.pin_mem, 264 | drop_last=True, 265 | ) 266 | 267 | if dataset_val is not None: 268 | data_loader_val = torch.utils.data.DataLoader( 269 | dataset_val, sampler=sampler_val, 270 | batch_size=int(1.5 * args.batch_size), 271 | num_workers=args.num_workers, 272 | pin_memory=args.pin_mem, 273 | drop_last=False 274 | ) 275 | else: 276 | data_loader_val = None 277 | 278 | mixup_fn = None 279 | mixup_active = args.mixup > 0 or args.cutmix > 0. 
or args.cutmix_minmax is not None
280 |     if mixup_active:
281 |         print("Mixup is activated!")
282 |         mixup_fn = Mixup(
283 |             mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
284 |             prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
285 |             label_smoothing=args.smoothing, num_classes=args.nb_classes)
286 | 
287 | 
288 |     model = eval(f'models.{args.model}')()
289 | 
290 |     if args.finetune:
291 |         if args.finetune.startswith('https'):
292 |             checkpoint = torch.hub.load_state_dict_from_url(
293 |                 args.finetune, map_location='cpu', check_hash=True)
294 |         else:
295 |             checkpoint = torch.load(args.finetune, map_location='cpu')
296 | 
297 |         print("Load ckpt from %s" % args.finetune)
298 |         checkpoint_model = None
299 |         for model_key in args.model_key.split('|'):
300 |             if model_key in checkpoint:
301 |                 checkpoint_model = checkpoint[model_key]
302 |                 print("Load state_dict by model_key = %s" % model_key)
303 |                 break
304 |         if checkpoint_model is None:
305 |             checkpoint_model = checkpoint
306 |         state_dict = model.state_dict()
307 |         for k in ['head.weight', 'head.bias']:
308 |             if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
309 |                 print(f"Removing key {k} from pretrained checkpoint")
310 |                 del checkpoint_model[k]
311 |         utils.load_state_dict(model, checkpoint_model, prefix=args.model_prefix)
312 |     model.to(device)
313 | 
314 |     model_ema = None
315 |     if args.model_ema:
316 |         # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
317 |         model_ema = ModelEma(
318 |             model,
319 |             decay=args.model_ema_decay,
320 |             device='cpu' if args.model_ema_force_cpu else '',
321 |             resume='')
322 |         print("Using EMA with decay = %.8f" % args.model_ema_decay)
323 | 
324 |     model_without_ddp = model
325 |     n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
326 | 
327 |     # print("Model = %s" % str(model_without_ddp))
328 |     print('number of params:', n_parameters)
329 | 
330 |     total_batch_size = args.batch_size * args.update_freq * utils.get_world_size()
331 |     num_training_steps_per_epoch = len(dataset_train) // total_batch_size
332 |     print("LR = %.8f" % args.lr)
333 |     print("Batch size = %d" % total_batch_size)
334 |     print("Update frequency = %d" % args.update_freq)
335 |     print("Number of training examples = %d" % len(dataset_train))
336 |     print("Number of training steps per epoch = %d" % num_training_steps_per_epoch)
337 | 
338 |     if args.layer_decay < 1.0 or args.layer_decay > 1.0:
339 |         num_layers = 12  # convnext layers divided into 12 parts, each with a different decayed lr value.
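        # Editor's note (illustrative, not in the original source): the assigner
        # constructed below gives parameter group i (0 = stem, num_layers + 1 = head)
        # an LR scale of args.layer_decay ** (num_layers + 1 - i). With the
        # hard-coded num_layers = 12 and, say, layer_decay = 0.8, the scales run
        # from 0.8 ** 13 ≈ 0.055 for the earliest parameters up to 0.8 ** 0 = 1.0
        # for the classifier head.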
340 | assert args.model in ['convnext_small', 'convnext_base', 'convnext_large', 'convnext_xlarge'], \ 341 | "Layer Decay impl only supports convnext_small/base/large/xlarge" 342 | assigner = LayerDecayValueAssigner(list(args.layer_decay ** (num_layers + 1 - i) for i in range(num_layers + 2))) 343 | else: 344 | assigner = None 345 | 346 | if assigner is not None: 347 | print("Assigned values = %s" % str(assigner.values)) 348 | 349 | if args.distributed: 350 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False) 351 | model_without_ddp = model.module 352 | 353 | optimizer = create_optimizer( 354 | args, model_without_ddp, skip_list=None, 355 | get_num_layer=assigner.get_layer_id if assigner is not None else None, 356 | get_layer_scale=assigner.get_scale if assigner is not None else None) 357 | 358 | loss_scaler = NativeScaler() # if args.use_amp is False, this won't be used 359 | 360 | print("Use Cosine LR scheduler") 361 | lr_schedule_values = utils.cosine_scheduler( 362 | args.lr, args.min_lr, args.epochs, num_training_steps_per_epoch, 363 | warmup_epochs=args.warmup_epochs, warmup_steps=args.warmup_steps, 364 | ) 365 | 366 | if args.weight_decay_end is None: 367 | args.weight_decay_end = args.weight_decay 368 | wd_schedule_values = utils.cosine_scheduler( 369 | args.weight_decay, args.weight_decay_end, args.epochs, num_training_steps_per_epoch) 370 | print("Max WD = %.7f, Min WD = %.7f" % (max(wd_schedule_values), min(wd_schedule_values))) 371 | 372 | if mixup_fn is not None: 373 | # smoothing is handled with mixup label transform 374 | criterion = SoftTargetCrossEntropy() 375 | elif args.smoothing > 0.: 376 | criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing) 377 | else: 378 | criterion = torch.nn.CrossEntropyLoss() 379 | 380 | print("criterion = %s" % str(criterion)) 381 | 382 | utils.auto_load_model( 383 | args=args, model=model, model_without_ddp=model_without_ddp, 384 | optimizer=optimizer, loss_scaler=loss_scaler, model_ema=model_ema) 385 | 386 | if args.eval: 387 | print(f"Eval only mode") 388 | test_stats = evaluate(data_loader_val, model, device, use_amp=args.use_amp) 389 | print(f"Accuracy of the network on {len(dataset_val)} test images: {test_stats['acc1']:.5f}%") 390 | return 391 | 392 | max_accuracy = 0.0 393 | if args.model_ema and args.model_ema_eval: 394 | max_accuracy_ema = 0.0 395 | 396 | print("Start training for %d epochs" % args.epochs) 397 | start_time = time.time() 398 | for epoch in range(args.start_epoch, args.epochs): 399 | if args.distributed: 400 | data_loader_train.sampler.set_epoch(epoch) 401 | if log_writer is not None: 402 | log_writer.set_step(epoch * num_training_steps_per_epoch * args.update_freq) 403 | if wandb_logger: 404 | wandb_logger.set_steps() 405 | train_stats = train_one_epoch( 406 | model, criterion, data_loader_train, optimizer, 407 | device, epoch, loss_scaler, args.clip_grad, model_ema, mixup_fn, 408 | log_writer=log_writer, wandb_logger=wandb_logger, start_steps=epoch * num_training_steps_per_epoch, 409 | lr_schedule_values=lr_schedule_values, wd_schedule_values=wd_schedule_values, 410 | num_training_steps_per_epoch=num_training_steps_per_epoch, update_freq=args.update_freq, 411 | use_amp=args.use_amp 412 | ) 413 | if args.output_dir and args.save_ckpt: 414 | if (epoch + 1) % args.save_ckpt_freq == 0 or epoch + 1 == args.epochs: 415 | utils.save_model( 416 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 417 | loss_scaler=loss_scaler, 
epoch=epoch, model_ema=model_ema) 418 | if data_loader_val is not None: 419 | test_stats = evaluate(data_loader_val, model, device, use_amp=args.use_amp) 420 | print(f"Accuracy of the model on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 421 | if max_accuracy < test_stats["acc1"]: 422 | max_accuracy = test_stats["acc1"] 423 | if args.output_dir and args.save_ckpt: 424 | utils.save_model( 425 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 426 | loss_scaler=loss_scaler, epoch="best", model_ema=model_ema) 427 | print(f'Max accuracy: {max_accuracy:.2f}%') 428 | 429 | if log_writer is not None: 430 | log_writer.update(test_acc1=test_stats['acc1'], head="perf", step=epoch) 431 | log_writer.update(test_acc5=test_stats['acc5'], head="perf", step=epoch) 432 | log_writer.update(test_loss=test_stats['loss'], head="perf", step=epoch) 433 | 434 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 435 | **{f'test_{k}': v for k, v in test_stats.items()}, 436 | 'epoch': epoch, 437 | 'n_parameters': n_parameters} 438 | 439 | # repeat testing routines for EMA, if ema eval is turned on 440 | if args.model_ema and args.model_ema_eval: 441 | test_stats_ema = evaluate(data_loader_val, model_ema.ema, device, use_amp=args.use_amp) 442 | print(f"Accuracy of the model EMA on {len(dataset_val)} test images: {test_stats_ema['acc1']:.1f}%") 443 | if max_accuracy_ema < test_stats_ema["acc1"]: 444 | max_accuracy_ema = test_stats_ema["acc1"] 445 | if args.output_dir and args.save_ckpt: 446 | utils.save_model( 447 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 448 | loss_scaler=loss_scaler, epoch="best-ema", model_ema=model_ema) 449 | print(f'Max EMA accuracy: {max_accuracy_ema:.2f}%') 450 | if log_writer is not None: 451 | log_writer.update(test_acc1_ema=test_stats_ema['acc1'], head="perf", step=epoch) 452 | log_stats.update({**{f'test_{k}_ema': v for k, v in test_stats_ema.items()}}) 453 | else: 454 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 455 | 'epoch': epoch, 456 | 'n_parameters': n_parameters} 457 | 458 | if args.output_dir and utils.is_main_process(): 459 | if log_writer is not None: 460 | log_writer.flush() 461 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: 462 | f.write(json.dumps(log_stats) + "\n") 463 | 464 | if wandb_logger: 465 | wandb_logger.log_epoch_metrics(log_stats) 466 | 467 | if wandb_logger and args.wandb_ckpt and args.save_ckpt and args.output_dir: 468 | wandb_logger.log_checkpoints() 469 | 470 | 471 | total_time = time.time() - start_time 472 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 473 | print('Training time {}'.format(total_time_str)) 474 | 475 | if __name__ == '__main__': 476 | parser = argparse.ArgumentParser('ConvNeXt training and evaluation script', parents=[get_args_parser()]) 477 | args = parser.parse_args() 478 | if args.output_dir: 479 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 480 | main(args) 481 | -------------------------------------------------------------------------------- /models/dyn_perceiver_resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | from dataclasses import dataclass 4 | from typing import Any, Optional, Tuple 5 | 6 | import torch 7 | from einops import rearrange, repeat 8 | from torch import Tensor 9 | import torch.nn as nn 10 | 11 | import torch.nn.functional as F 12 | import numpy as np 13 | 
# from .registry import register_model 14 | from timm.models.registry import register_model 15 | # from fairscale.nn import checkpoint_wrapper 16 | # from fvcore.nn import FlopCountAnalysis, parameter_count_table 17 | 18 | from .perceiver_core import ( 19 | CrossAttentionLayer, 20 | SelfAttentionBlock, 21 | ) 22 | from .cnn_core.resnet import * 23 | 24 | 25 | class DynPerceiver(nn.Module): 26 | def __init__(self, 27 | input_size: int=224, 28 | num_classes:int=1000, 29 | cnn_arch: str="resnet18", 30 | num_SA_heads: list=[1,2,4,8], 31 | num_latents: int=32, 32 | num_latent_channels: int=None, 33 | dropout: float = 0.0, 34 | SA_widening_factor: int=4, 35 | activation_checkpointing: bool = False, 36 | spatial_reduction: bool=True, 37 | depth_factor: list=[1,1,1,3], 38 | output_dir: str='./', 39 | with_x2z=True, 40 | with_z2x=True, 41 | with_dwc=True, 42 | with_last_CA=True, 43 | with_isc=True, 44 | 45 | drop_rate=0.2, 46 | drop_path_rate=0.2, 47 | 48 | exit=-1, 49 | 50 | **kwargs): 51 | super().__init__() 52 | self.exit = exit 53 | if num_SA_heads is None: 54 | num_SA_heads = [1,2,4,8] 55 | self.num_classes = num_classes 56 | cnn = eval(f'{cnn_arch}')(drop_rate=drop_rate, 57 | drop_path_rate=drop_path_rate) 58 | self.cnn_stem = nn.Sequential( 59 | cnn.conv1, 60 | cnn.bn1, 61 | cnn.act1, 62 | cnn.maxpool 63 | ) 64 | 65 | self.cnn_body_stage1 = cnn.layer1 66 | self.cnn_body_stage2 = cnn.layer2 67 | self.cnn_body_stage3 = cnn.layer3 68 | self.cnn_body_stage4 = cnn.layer4 69 | 70 | # self.cnn_body_last_conv1x1 = cnn.blocks[6] 71 | 72 | num_blocks_per_stage = [3*depth_factor[0], 3*depth_factor[1], 9*depth_factor[2], 3*depth_factor[3]] 73 | self.avgpool = cnn.global_pool 74 | # self.cnn_head_before_cls = nn.Sequential( 75 | # cnn.conv_head, 76 | # cnn.act2 77 | # ) 78 | # self.flatten_cnn = cnn.flatten 79 | self.drop_rate_cnn = cnn.drop_rate 80 | self.classifier_cnn = cnn.fc 81 | 82 | print(cnn.drop_rate) 83 | 84 | self.spatial_reduction = spatial_reduction 85 | if spatial_reduction: 86 | self.ca_pooling = nn.AdaptiveAvgPool2d((7,7)) 87 | def cross_attn(num_cross_attention_heads, q_input_channels, kv_input_channels, num_cross_attention_qk_channels, num_cross_attention_v_channels, cross_attention_widening_factor, 88 | rpb=False, 89 | feat_w=112, 90 | feat_h=112): 91 | layer = CrossAttentionLayer( 92 | num_heads=num_cross_attention_heads, 93 | num_q_input_channels=q_input_channels, 94 | num_kv_input_channels=kv_input_channels, 95 | num_qk_channels=num_cross_attention_qk_channels, 96 | num_v_channels=num_cross_attention_v_channels, 97 | widening_factor=cross_attention_widening_factor, 98 | dropout=dropout, 99 | 100 | rpb=rpb, 101 | feat_w=feat_w, 102 | feat_h=feat_h, 103 | ) 104 | return layer 105 | 106 | def self_attn(num_self_attention_layers_per_block, num_self_attention_heads, num_channels, num_self_attention_qk_channels, num_self_attention_v_channels, self_attention_widening_factor): 107 | return SelfAttentionBlock( 108 | num_layers=num_self_attention_layers_per_block, 109 | num_heads=num_self_attention_heads, 110 | num_channels=num_channels, 111 | num_qk_channels=num_self_attention_qk_channels, 112 | num_v_channels=num_self_attention_v_channels, 113 | widening_factor=self_attention_widening_factor, 114 | dropout=dropout, 115 | activation_checkpointing=activation_checkpointing, 116 | ) 117 | 118 | 119 | # stage1 120 | x_channels_stage1in = cnn.layer1[0].conv1.in_channels 121 | x_channels_stage2in = cnn.layer2[0].conv1.in_channels 122 | x_channels_stage3in = cnn.layer3[0].conv1.in_channels 123 | 
x_channels_stage4in = cnn.layer4[0].conv1.in_channels 124 | x_channels_stage4out = cnn.layer4[-1].conv3.out_channels if isinstance(cnn.layer4[-1], Bottleneck) else cnn.layer4[-1].conv2.out_channels 125 | z_channels = [x_channels_stage1in, x_channels_stage2in, x_channels_stage3in, x_channels_stage4in] 126 | # print(z_channels) 127 | # assert(0==1) 128 | if num_latent_channels is None: 129 | num_latent_channels = x_channels_stage1in 130 | self.latent = nn.Parameter(torch.empty(num_latents, num_latent_channels)) 131 | 132 | self.with_x2z = with_x2z 133 | self.with_z2x = with_z2x 134 | self.with_dwc = with_dwc 135 | 136 | if with_dwc: 137 | self.dwc1_x2z = nn.Conv2d(in_channels=x_channels_stage1in, out_channels=x_channels_stage1in, kernel_size=7, 138 | groups=x_channels_stage1in, stride=1, padding=3) 139 | feat_hw = 7 if spatial_reduction else input_size//2 140 | 141 | # essential 142 | self.cross_att1_x2z = cross_attn(num_cross_attention_heads=1, 143 | q_input_channels=x_channels_stage1in, 144 | kv_input_channels=x_channels_stage1in, 145 | num_cross_attention_qk_channels=None, 146 | num_cross_attention_v_channels=None, 147 | cross_attention_widening_factor=1, 148 | 149 | rpb=True, 150 | feat_w=feat_hw, 151 | feat_h=feat_hw, 152 | ) 153 | self.self_att1 = self_attn(num_self_attention_layers_per_block=num_blocks_per_stage[0], 154 | num_self_attention_heads=num_SA_heads[0], 155 | num_channels=x_channels_stage1in, 156 | num_self_attention_qk_channels=None, 157 | num_self_attention_v_channels=None, 158 | self_attention_widening_factor=SA_widening_factor 159 | ) 160 | 161 | # stage2 162 | if with_x2z: 163 | if with_dwc: 164 | self.dwc2_x2z = nn.Conv2d(in_channels=x_channels_stage2in, out_channels=x_channels_stage2in, kernel_size=7, groups=x_channels_stage2in, stride=1, padding=3) 165 | feat_hw = 7 if spatial_reduction else input_size//4 166 | self.cross_att2_x2z = cross_attn(num_cross_attention_heads=1, 167 | q_input_channels=x_channels_stage2in, 168 | kv_input_channels=x_channels_stage2in, 169 | num_cross_attention_qk_channels=None, 170 | num_cross_attention_v_channels=None, 171 | cross_attention_widening_factor=1, 172 | 173 | rpb=True, 174 | feat_w=feat_hw, 175 | feat_h=feat_hw, 176 | ) 177 | 178 | if with_z2x: 179 | self.cross_att2_z2x = cross_attn(num_cross_attention_heads=1, 180 | q_input_channels=x_channels_stage2in, 181 | kv_input_channels=x_channels_stage2in, 182 | num_cross_attention_qk_channels=x_channels_stage2in//8, 183 | num_cross_attention_v_channels=x_channels_stage2in//8, 184 | cross_attention_widening_factor=1 185 | ) 186 | self.self_att2 = self_attn(num_self_attention_layers_per_block=num_blocks_per_stage[1], 187 | num_self_attention_heads=num_SA_heads[1], 188 | num_channels=x_channels_stage2in, 189 | num_self_attention_qk_channels=None, 190 | num_self_attention_v_channels=None, 191 | self_attention_widening_factor=SA_widening_factor 192 | ) 193 | 194 | # stage3 195 | if with_x2z: 196 | if with_dwc: 197 | self.dwc3_x2z = nn.Conv2d(in_channels=x_channels_stage3in, out_channels=x_channels_stage3in, kernel_size=7, groups=x_channels_stage3in, stride=1, padding=3) 198 | feat_hw = 7 if spatial_reduction else input_size//8 199 | self.cross_att3_x2z = cross_attn(num_cross_attention_heads=1, 200 | q_input_channels=x_channels_stage3in, 201 | kv_input_channels=x_channels_stage3in, 202 | num_cross_attention_qk_channels=None, 203 | num_cross_attention_v_channels=None, 204 | cross_attention_widening_factor=1, 205 | 206 | rpb=True, 207 | feat_w=feat_hw, 208 | feat_h=feat_hw 209 | ) 210 | 211 | 
if with_z2x: 212 | self.cross_att3_z2x = cross_attn(num_cross_attention_heads=1, 213 | q_input_channels=x_channels_stage3in, 214 | kv_input_channels=x_channels_stage3in, 215 | num_cross_attention_qk_channels=x_channels_stage3in//8, 216 | num_cross_attention_v_channels=x_channels_stage3in//8, 217 | cross_attention_widening_factor=1 218 | ) 219 | self.self_att3 = self_attn(num_self_attention_layers_per_block=num_blocks_per_stage[2], 220 | num_self_attention_heads=num_SA_heads[2], 221 | num_channels=x_channels_stage3in, 222 | num_self_attention_qk_channels=None, 223 | num_self_attention_v_channels=None, 224 | self_attention_widening_factor=SA_widening_factor 225 | ) 226 | 227 | # stage4 228 | if with_x2z: 229 | if with_dwc: 230 | self.dwc4_x2z = nn.Conv2d(in_channels=x_channels_stage4in, out_channels=x_channels_stage4in, kernel_size=7, groups=x_channels_stage4in, stride=1, padding=3) 231 | feat_hw = 7 if spatial_reduction else input_size//16 232 | self.cross_att4_x2z = cross_attn(num_cross_attention_heads=1, 233 | q_input_channels=x_channels_stage4in, 234 | kv_input_channels=x_channels_stage4in, 235 | num_cross_attention_qk_channels=None, 236 | num_cross_attention_v_channels=None, 237 | cross_attention_widening_factor=1, 238 | 239 | rpb=True, 240 | feat_w=feat_hw, 241 | feat_h=feat_hw 242 | ) 243 | 244 | if with_z2x: 245 | # print(x_channels_stage4in) 246 | self.cross_att4_z2x = cross_attn(num_cross_attention_heads=1, 247 | q_input_channels=x_channels_stage4in, 248 | kv_input_channels=x_channels_stage4in, 249 | num_cross_attention_qk_channels=x_channels_stage4in//8, 250 | num_cross_attention_v_channels=x_channels_stage4in//8, 251 | cross_attention_widening_factor=1 252 | ) 253 | # print(num_blocks_per_stage[3]) 254 | self.self_att4 = self_attn(num_self_attention_layers_per_block=num_blocks_per_stage[3], 255 | num_self_attention_heads=num_SA_heads[3], 256 | num_channels=x_channels_stage4in, 257 | num_self_attention_qk_channels=None, 258 | num_self_attention_v_channels=None, 259 | self_attention_widening_factor=SA_widening_factor 260 | ) 261 | 262 | # last cross attention 263 | # print(x_channels_stage4out//8) 264 | # print(x_channels_stage4out, x_channels_stage4in) 265 | self.last_cross_att_z2x = cross_attn(num_cross_attention_heads=1, 266 | q_input_channels=x_channels_stage4out, 267 | kv_input_channels=x_channels_stage4in, 268 | num_cross_attention_qk_channels=x_channels_stage4out//8, 269 | num_cross_attention_v_channels=x_channels_stage4out//8, 270 | cross_attention_widening_factor=1 271 | ) if with_last_CA else None 272 | 273 | # self.classifier_cnn = cnn.classifier 274 | # self.classifier_flops = cnn.classifier_flops 275 | # print(x_channels_stage1in, x_channels_stage2in, x_channels_stage3in, x_channels_stage4in) 276 | # self.early_classifier1 = nn.Linear(x_channels_stage1in, num_classes) 277 | # self.early_classifier2 = nn.Linear(x_channels_stage2in, num_classes) 278 | 279 | self.early_classifier3 = nn.Linear(x_channels_stage3in, num_classes) 280 | self.with_isc = with_isc 281 | 282 | if not with_isc: 283 | self.classifier_att = nn.Linear(x_channels_stage4in, num_classes) 284 | 285 | self.classifier_merge = nn.Sequential( 286 | nn.BatchNorm1d(cnn.num_features + x_channels_stage4in), 287 | nn.Linear(cnn.num_features + x_channels_stage4in, num_classes) 288 | ) 289 | else: 290 | 291 | self.isc3 = nn.Sequential(nn.Linear(num_classes, x_channels_stage4in), 292 | nn.BatchNorm1d(x_channels_stage4in), 293 | nn.ReLU(inplace=True) 294 | ) 295 | 296 | self.classifier_att = 
nn.Linear(2*x_channels_stage4in, num_classes) 297 | 298 | self.isc4 = nn.Sequential(nn.Linear(num_classes, x_channels_stage4in), 299 | nn.BatchNorm1d(x_channels_stage4in), 300 | nn.ReLU(inplace=True) 301 | ) 302 | self.classifier_merge = nn.Sequential( 303 | nn.BatchNorm1d(cnn.num_features + 2*x_channels_stage4in), 304 | nn.Linear(cnn.num_features + 2*x_channels_stage4in, num_classes) 305 | ) 306 | 307 | expander = [] 308 | token_mixer = [] 309 | 310 | num_latents_list = [num_latents, num_latents//2, num_latents//4, num_latents//8] 311 | for i in range(3): 312 | c_in = z_channels[i] 313 | c_out = z_channels[i+1] 314 | expander.append(nn.Sequential( 315 | nn.LayerNorm(c_in), 316 | nn.Linear(c_in, c_out) 317 | )) 318 | n_z_in = num_latents_list[i] 319 | n_z_out = num_latents_list[i+1] 320 | token_mixer.append(nn.Sequential( 321 | nn.LayerNorm(n_z_in), 322 | nn.Linear(n_z_in, n_z_out) 323 | )) 324 | 325 | self.token_expander = nn.ModuleList(expander) 326 | self.token_mixer = nn.ModuleList(token_mixer) 327 | self.output_dir = output_dir 328 | self._init_parameters() 329 | 330 | 331 | def _init_parameters(self): 332 | with torch.no_grad(): 333 | self.latent.normal_(0.0, 0.02).clamp_(-2.0, 2.0) 334 | 335 | def forward(self, x, pad_mask=None): 336 | exit = self.exit 337 | #TODO 338 | b, c_in, _, _ = x.shape 339 | x_latent = repeat(self.latent, "... -> b ...", b=b) 340 | 341 | x = self.cnn_stem(x) 342 | # print(x.shape) 343 | 344 | if self.with_dwc: 345 | x_kv = self.dwc1_x2z(x) + x 346 | x_kv = self.ca_pooling(x_kv) 347 | else: 348 | x_kv = self.ca_pooling(x) 349 | x_kv = rearrange(x_kv, "b c ... -> b (...) c") 350 | # print(x_latent.shape, x_kv.shape) 351 | x_latent = self.cross_att1_x2z(x_latent, x_kv, pad_mask) 352 | 353 | # stage1, conv and self attention 354 | x_latent = self.self_att1(x_latent) 355 | 356 | # y_early1 = torch.mean(x_latent, dim=1).squeeze(1) 357 | # y_early1 = self.early_classifier1(y_early1) 358 | 359 | 360 | x = self.cnn_body_stage1(x) 361 | 362 | # between stage1 and stage2 363 | _, n_tokens, c_in = x_latent.shape 364 | x_latent = x_latent.permute(0,2,1) 365 | x_latent = self.token_mixer[0](x_latent) 366 | x_latent = x_latent.permute(0,2,1) 367 | x_latent = self.token_expander[0](x_latent) 368 | 369 | # transformer to conv 370 | if self.with_z2x: 371 | _,_,h,w = x.shape 372 | x = rearrange(x, "b c ... -> b (...) c") 373 | x = self.cross_att2_z2x(x, x_latent, pad_mask) 374 | x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) 375 | 376 | # conv to transformer 377 | # print(x.shape) 378 | if self.with_x2z: 379 | if self.with_dwc: 380 | x_kv = self.dwc2_x2z(x) + x 381 | x_kv = self.ca_pooling(x_kv) 382 | else: 383 | x_kv = self.ca_pooling(x) 384 | x_kv = rearrange(x_kv, "b c ... -> b (...) c") 385 | x_latent = self.cross_att2_x2z(x_latent, x_kv, pad_mask) 386 | 387 | # stage2 388 | x_latent = self.self_att2(x_latent) 389 | # y_early2 = torch.mean(x_latent, dim=1).squeeze(1) 390 | # y_early2 = self.early_classifier2(y_early2) 391 | 392 | x = self.cnn_body_stage2(x) 393 | 394 | # between stage2 and stage3 395 | _, n_tokens, c_in = x_latent.shape 396 | x_latent = x_latent.permute(0,2,1) 397 | x_latent = self.token_mixer[1](x_latent) 398 | x_latent = x_latent.permute(0,2,1) 399 | x_latent = self.token_expander[1](x_latent) 400 | 401 | # transformer to conv 402 | if self.with_z2x: 403 | _,_,h,w = x.shape 404 | x = rearrange(x, "b c ... -> b (...) 
c") 405 | x = self.cross_att3_z2x(x, x_latent, pad_mask) 406 | x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) 407 | 408 | # conv to transformer 409 | # print(x.shape) 410 | if self.with_x2z: 411 | if self.with_dwc: 412 | x_kv = self.dwc3_x2z(x) + x 413 | x_kv = self.ca_pooling(x_kv) 414 | else: 415 | x_kv = self.ca_pooling(x) 416 | x_kv = rearrange(x_kv, "b c ... -> b (...) c") 417 | x_latent = self.cross_att3_x2z(x_latent, x_kv, pad_mask) 418 | 419 | # stage3 420 | x_latent = self.self_att3(x_latent) 421 | y_early3 = torch.mean(x_latent, dim=1).squeeze(1) 422 | y_early3 = self.early_classifier3(y_early3) 423 | if exit == 0: 424 | return y_early3 425 | 426 | x = self.cnn_body_stage3(x) 427 | 428 | # between stage3 and stage4 429 | _, n_tokens, c_in = x_latent.shape 430 | x_latent = x_latent.permute(0,2,1) 431 | x_latent = self.token_mixer[2](x_latent) 432 | x_latent = x_latent.permute(0,2,1) 433 | x_latent = self.token_expander[2](x_latent) 434 | 435 | # transformer to conv 436 | if self.with_z2x: 437 | _,_,h,w = x.shape 438 | x = rearrange(x, "b c ... -> b (...) c") 439 | # print(x_latent.shape, x.shape) 440 | x = self.cross_att4_z2x(x, x_latent, pad_mask) 441 | x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) 442 | 443 | 444 | # conv to transformer 445 | # print(x.shape) 446 | if self.with_x2z: 447 | 448 | if self.with_dwc: 449 | x_kv = self.dwc4_x2z(x) + x 450 | x_kv = self.ca_pooling(x_kv) 451 | else: 452 | x_kv = self.ca_pooling(x) 453 | x_kv = rearrange(x_kv, "b c ... -> b (...) c") 454 | x_latent = self.cross_att4_x2z(x_latent, x_kv, pad_mask) 455 | 456 | # stage4 457 | x_latent = self.self_att4(x_latent) 458 | 459 | x_latent_mean = torch.mean(x_latent, dim=1).squeeze(1) 460 | if self.with_isc: 461 | y3_ = self.isc3(y_early3) 462 | y_att = torch.cat((x_latent_mean, y3_), dim=1) 463 | y_att = self.classifier_att(y_att) 464 | else: 465 | y_att = self.classifier_att(x_latent_mean) 466 | if exit == 1: 467 | return y_att 468 | 469 | x = self.cnn_body_stage4(x) 470 | # print(x.shape) 471 | 472 | # x = self.cnn_body_last_conv1x1(x) 473 | x_mean = self.avgpool(x) 474 | # x_mean = self.cnn_head_before_cls(x_mean) 475 | # x_mean = self.flatten_cnn(x_mean) 476 | 477 | if self.drop_rate_cnn > 0.: 478 | x_mean = F.dropout(x_mean, p=self.drop_rate_cnn, training=self.training) 479 | 480 | # print(x_mean.shape) 481 | y_cnn = self.classifier_cnn(x_mean) 482 | if exit == 2: 483 | return y_cnn 484 | 485 | if self.last_cross_att_z2x is not None: 486 | _,_,h,w = x.shape 487 | x = rearrange(x, "b c ... -> b (...) 
c") 488 | x = self.last_cross_att_z2x(x, x_latent, pad_mask) 489 | x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) 490 | x_mean = self.avgpool(x) 491 | # x_mean = self.cnn_head_before_cls(x_mean) 492 | # x_mean = self.flatten_cnn(x_mean) 493 | 494 | if self.with_isc: 495 | y4_ = self.isc4(y_att) 496 | x_merge = torch.cat((x_mean, x_latent_mean, y4_), dim=1) 497 | y_merge = self.classifier_merge(x_merge) 498 | else: 499 | x_merge = torch.cat((x_mean, x_latent_mean), dim=1) 500 | y_merge = self.classifier_merge(x_merge) 501 | 502 | return y_early3, y_att, y_cnn, y_merge 503 | 504 | 505 | @register_model 506 | def resnet18_perceiver_t128(**kwargs): 507 | model = DynPerceiver(num_latents=128, cnn_arch='resnet18', **kwargs) 508 | return model 509 | 510 | 511 | @register_model 512 | def resnet18_perceiver_t256(**kwargs): 513 | model = DynPerceiver(num_latents=256, cnn_arch='resnet18', **kwargs) 514 | return model 515 | 516 | 517 | @register_model 518 | def resnet18_perceiver_t512(**kwargs): 519 | model = DynPerceiver(num_latents=512, cnn_arch='resnet18', **kwargs) 520 | return model 521 | 522 | 523 | @register_model 524 | def resnet34_perceiver_t128(**kwargs): 525 | model = DynPerceiver(num_latents=128, cnn_arch='resnet34', **kwargs) 526 | return model 527 | 528 | 529 | @register_model 530 | def resnet34_perceiver_t256(**kwargs): 531 | model = DynPerceiver(num_latents=256, cnn_arch='resnet34', **kwargs) 532 | return model 533 | 534 | 535 | @register_model 536 | def resnet34_perceiver_t512(**kwargs): 537 | model = DynPerceiver(num_latents=512, cnn_arch='resnet34', **kwargs) 538 | return model 539 | 540 | 541 | @register_model 542 | def resnet50_perceiver_t64(**kwargs): 543 | model = DynPerceiver(num_latents=64, cnn_arch='resnet50', **kwargs) 544 | return model 545 | 546 | 547 | @register_model 548 | def resnet50_perceiver_t128(**kwargs): 549 | model = DynPerceiver(num_latents=128, cnn_arch='resnet50', **kwargs) 550 | return model 551 | 552 | 553 | @register_model 554 | def resnet50_perceiver_t256(**kwargs): 555 | model = DynPerceiver(num_latents=256, cnn_arch='resnet50', **kwargs) 556 | return model 557 | 558 | 559 | @register_model 560 | def resnet50_perceiver_t512(**kwargs): 561 | model = DynPerceiver(num_latents=512, cnn_arch='resnet50', **kwargs) 562 | return model 563 | 564 | 565 | @register_model 566 | def resnet50_050_perceiver_t64(**kwargs): 567 | model = DynPerceiver(num_latents=64, cnn_arch='resnet50_050', **kwargs) 568 | return model 569 | 570 | 571 | @register_model 572 | def resnet50_050_perceiver_t128(**kwargs): 573 | model = DynPerceiver(num_latents=128, cnn_arch='resnet50_050', **kwargs) 574 | return model 575 | 576 | 577 | @register_model 578 | def resnet50_050_perceiver_t256(**kwargs): 579 | model = DynPerceiver(num_latents=256, cnn_arch='resnet50_050', **kwargs) 580 | return model 581 | 582 | 583 | @register_model 584 | def resnet50_0375_perceiver_t64(**kwargs): 585 | model = DynPerceiver(num_latents=64, cnn_arch='resnet50_0375', **kwargs) 586 | return model 587 | 588 | 589 | @register_model 590 | def resnet50_0375_perceiver_t128(**kwargs): 591 | model = DynPerceiver(num_latents=128, cnn_arch='resnet50_0375', **kwargs) 592 | return model 593 | 594 | 595 | @register_model 596 | def resnet50_0375_perceiver_t256(**kwargs): 597 | model = DynPerceiver(num_latents=256, cnn_arch='resnet50_0375', **kwargs) 598 | return model 599 | 600 | 601 | @register_model 602 | def resnet50_0625_perceiver_t128(**kwargs): 603 | model = DynPerceiver(num_latents=128, cnn_arch='resnet50_0625', 
**kwargs) 604 | return model 605 | 606 | 607 | @register_model 608 | def resnet50_0625_perceiver_t160(**kwargs): 609 | model = DynPerceiver(num_latents=160, cnn_arch='resnet50_0625', **kwargs) 610 | return model 611 | 612 | 613 | @register_model 614 | def resnet50_0625_perceiver_t192(**kwargs): 615 | model = DynPerceiver(num_latents=192, cnn_arch='resnet50_0625', **kwargs) 616 | return model 617 | 618 | 619 | @register_model 620 | def resnet50_0625_perceiver_t256(**kwargs): 621 | model = DynPerceiver(num_latents=256, cnn_arch='resnet50_0625', **kwargs) 622 | return model 623 | 624 | 625 | @register_model 626 | def resnet50_075_perceiver_t128(**kwargs): 627 | model = DynPerceiver(num_latents=128, cnn_arch='resnet50_075', **kwargs) 628 | return model 629 | 630 | 631 | @register_model 632 | def resnet50_075_perceiver_t256(**kwargs): 633 | model = DynPerceiver(num_latents=256, cnn_arch='resnet50_075', **kwargs) 634 | return model 635 | 636 | 637 | def dyn_flops(): 638 | x = torch.rand(1, 3, 224, 224) 639 | result = [] 640 | for i in range(4,5): 641 | model = resnet50_0375_perceiver_t256( 642 | depth_factor=[1,1,1,1], SA_widening_factor=4, spatial_reduction=True, 643 | with_last_CA=True, with_x2z=True, with_dwc=True, with_z2x=True, exit=i) 644 | model.eval() 645 | from fvcore.nn import FlopCountAnalysis 646 | flops = FlopCountAnalysis(model, x) 647 | result.append(flops.total() / 1e9) 648 | print('***************************') 649 | for flop in result: 650 | print(flop) 651 | print('***************************') 652 | 653 | 654 | if __name__ == '__main__': 655 | 656 | dyn_flops() 657 | 658 | # # Fourier-encodes pixel positions and flatten along spatial dimensions 659 | # input_adapter = ImageInputAdapter( 660 | # image_shape=(224, 224, 3), # M = 224 * 224 661 | # num_frequency_bands=64, 662 | # ) 663 | 664 | # # Projects generic Perceiver decoder output to specified number of classes 665 | # output_adapter = ClassificationOutputAdapter( 666 | # num_classes=1000, 667 | # num_output_query_channels=1024, # F 668 | # ) 669 | 670 | # # Generic Perceiver encoder 671 | # encoder = PerceiverEncoder( 672 | # input_adapter=input_adapter, 673 | # num_latents=512, # N 674 | # num_latent_channels=1024, # D 675 | # num_cross_attention_qk_channels=input_adapter.num_input_channels, # C 676 | # num_cross_attention_heads=1, 677 | # num_self_attention_heads=4, 678 | # num_self_attention_layers_per_block=6, 679 | # num_self_attention_blocks=8, 680 | # dropout=0.0, 681 | # ) 682 | 683 | # # Generic Perceiver decoder 684 | # decoder = PerceiverDecoder( 685 | # output_adapter=output_adapter, 686 | # num_latent_channels=1024, # D 687 | # num_cross_attention_heads=1, 688 | # dropout=0.0, 689 | # ) 690 | 691 | # # Perceiver IO image classifier 692 | # model = PerceiverIO(encoder, decoder) 693 | # model.eval() 694 | # print(model) 695 | # x = torch.rand(4,224,224,3) 696 | # with torch.no_grad(): 697 | # y = model(x) 698 | 699 | # print(y.shape) 700 | 701 | 702 | 703 | 704 | 705 | # regnet = regnet_y_400mf() 706 | # print(regnet) 707 | # print(regnet.trunk_output.block1) 708 | 709 | def count_parameters(model): 710 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 711 | 712 | model = resnet50_perceiver_t128(depth_factor=[1,1,1,3], SA_widening_factor=4, spatial_reduction=True, with_last_CA=True, 713 | with_x2z=True, with_dwc=True, with_z2x=True) 714 | model.eval() 715 | print(count_parameters(model)/1e6) 716 | x = torch.rand(1,3,224,224) 717 | with torch.no_grad(): 718 | y = model(x) 719 | # 
print(y.shape)
720 | # print()
721 | # flops = FlopCountAnalysis(model, x)
722 | # print("FLOPs: ", flops.total()/1e9)
723 | # from fvcore.nn import flop_count_str
724 | # print(flop_count_str(flops))
725 | # # analyze parameters
726 | # print(parameter_count_table(model))
--------------------------------------------------------------------------------
/models/perceiver_core/modules.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | 
3 | import torch
4 | import torch.nn as nn
5 | from einops import rearrange, repeat
6 | # from fairscale.nn import checkpoint_wrapper
7 | from torch import Tensor
8 | 
9 | from .utils import Sequential
10 | from timm.models.layers import trunc_normal_
11 | 
12 | class MultiHeadAttention(nn.Module):
13 |     def __init__(
14 |         self,
15 |         num_heads: int,
16 |         num_q_input_channels: int,
17 |         num_kv_input_channels: int,
18 |         num_qk_channels: Optional[int] = None,
19 |         num_v_channels: Optional[int] = None,
20 |         num_output_channels: Optional[int] = None,
21 |         dropout: float = 0.0,
22 |         rpb=False,
23 |         feat_w=32,
24 |         feat_h=32
25 |     ):
26 |         """Multi-head attention as described in https://arxiv.org/abs/2107.14795 Appendix E.
27 | 
28 |         :param num_heads: Number of attention heads.
29 |         :param num_q_input_channels: Number of query input channels.
30 |         :param num_kv_input_channels: Number of key/value input channels.
31 |         :param num_qk_channels: Number of channels query and key input channels are projected to,
32 |             for computing the attention matrix. Defaults to `num_q_input_channels`
33 |         :param num_v_channels: Number of channels value input channels are projected to.
34 |             Defaults to `num_qk_channels`.
35 |         :param num_output_channels: Number of output channels attention result channels are projected to.
36 |             Defaults to `num_q_input_channels`
37 |         :param dropout: Dropout probability for attention matrix values.
Defaults to `0.0` 38 | """ 39 | super().__init__() 40 | 41 | if num_qk_channels is None: 42 | num_qk_channels = num_q_input_channels 43 | 44 | if num_v_channels is None: 45 | num_v_channels = num_qk_channels 46 | 47 | if num_output_channels is None: 48 | num_output_channels = num_q_input_channels 49 | 50 | if num_qk_channels % num_heads != 0: 51 | raise ValueError("num_qk_channels must be divisible by num_heads") 52 | 53 | if num_v_channels % num_heads != 0: 54 | raise ValueError("num_v_channels must be divisible by num_heads") 55 | 56 | 57 | # print(f'attention, qk channels {num_qk_channels}, v channels {num_v_channels}') 58 | 59 | 60 | self.rpb=rpb 61 | 62 | if rpb: 63 | self.feat_w=feat_w 64 | self.feat_h=feat_h 65 | self.relative_position_bias = nn.Parameter( 66 | torch.zeros((1, self.feat_w*self.feat_h, num_qk_channels)) 67 | ) # same size as k (feature map) 68 | 69 | num_qk_channels_per_head = num_qk_channels // num_heads 70 | 71 | self.dp_scale = num_qk_channels_per_head ** -0.5 72 | self.num_heads = num_heads 73 | 74 | self.q_proj = nn.Linear(num_q_input_channels, num_qk_channels, bias=False) 75 | self.k_proj = nn.Linear(num_kv_input_channels, num_qk_channels, bias=False) 76 | self.v_proj = nn.Linear(num_kv_input_channels, num_v_channels, bias=False) 77 | self.o_proj = nn.Linear(num_v_channels, num_output_channels) 78 | self.dropout = nn.Dropout(dropout) 79 | 80 | def forward(self, x_q, x_kv, pad_mask=None, attn_mask=None): 81 | """ 82 | :param x_q: Query input of shape (B, N, D) where B is the batch size, N the query sequence length 83 | and D the number of query input channels (= `num_q_input_channels`) 84 | :param x_kv: Key/value input of shape (B, L, C) where B is the batch size, L the key/value sequence 85 | length and C are the number of key/value input channels (= `num_kv_input_channels`) 86 | :param pad_mask: Boolean key padding mask. `True` values indicate padding tokens. 87 | :param attn_mask: Boolean attention mask. Not needed/supported yet. 
88 | :return: attention result of shape (B, N, F) where B is the batch size, N the query sequence length 89 | and F the number of output channels (= `num_output_channels`) 90 | """ 91 | if attn_mask is not None: 92 | raise NotImplementedError("attention masks not supported yet") 93 | 94 | # flops = 0 95 | b = x_q.shape[0] 96 | q_in_channels = x_q.shape[-1] 97 | kv_in_channels = x_kv.shape[-1] 98 | 99 | q = self.q_proj(x_q) 100 | k = self.k_proj(x_kv) 101 | v = self.v_proj(x_kv) 102 | # print("x_q.shape:", x_q.shape) 103 | # flops += q_in_channels * q.shape[-1] * q.shape[-2] + 2*kv_in_channels * k.shape[-1] * k.shape[-2] 104 | # print(k.shape) 105 | if self.rpb: 106 | k += self.relative_position_bias 107 | 108 | q, k, v = (rearrange(x, "b n (h c) -> (b h) n c", h=self.num_heads) for x in [q, k, v]) 109 | attn = torch.einsum("b i c, b j c -> b i j", q, k) * self.dp_scale 110 | # attn = q @ k.transpose(-2, -1) * self.dp_scale 111 | # print(attn.shape) 112 | # flops += q.shape[-1] * q.shape[-2] * k.shape[-2] * self.num_heads 113 | 114 | if pad_mask is not None: 115 | pad_mask = repeat(pad_mask, "b j -> (b h) () j", h=self.num_heads) 116 | attn_max_neg = -torch.finfo(attn.dtype).max 117 | attn.masked_fill_(pad_mask, attn_max_neg) 118 | 119 | attn = attn.softmax(dim=-1) 120 | attn = self.dropout(attn) 121 | 122 | o = torch.einsum("b i j, b j c -> b i c", attn, v) 123 | # o = attn @ v 124 | # print(attn.shape, v.shape, o.shape) 125 | # flops += attn.shape[-2] * v.shape[-2] * v.shape[-1] * self.num_heads 126 | o = rearrange(o, "(b h) n c -> b n (h c)", h=self.num_heads) 127 | in_channels = o.shape[-1] 128 | o = self.o_proj(o) 129 | # print(o.shape) 130 | # flops += in_channels * o.shape[-1] * o.shape[-2] 131 | return o 132 | 133 | def forward_calc_flops(self, x_q, x_kv, pad_mask=None, attn_mask=None): 134 | """ 135 | :param x_q: Query input of shape (B, N, D) where B is the batch size, N the query sequence length 136 | and D the number of query input channels (= `num_q_input_channels`) 137 | :param x_kv: Key/value input of shape (B, L, C) where B is the batch size, L the key/value sequence 138 | length and C are the number of key/value input channels (= `num_kv_input_channels`) 139 | :param pad_mask: Boolean key padding mask. `True` values indicate padding tokens. 140 | :param attn_mask: Boolean attention mask. Not needed/supported yet. 
141 | :return: attention result of shape (B, N, F) where B is the batch size, N the query sequence length 142 | and F the number of output channels (= `num_output_channels`) 143 | """ 144 | if attn_mask is not None: 145 | raise NotImplementedError("attention masks not supported yet") 146 | 147 | flops = 0 148 | b = x_q.shape[0] 149 | q_in_channels = x_q.shape[-1] 150 | kv_in_channels = x_kv.shape[-1] 151 | 152 | q = self.q_proj(x_q) 153 | k = self.k_proj(x_kv) 154 | v = self.v_proj(x_kv) 155 | # print("x_q.shape:", x_q.shape) 156 | flops += q_in_channels * q.shape[-1] * q.shape[-2] + 2*kv_in_channels * k.shape[-1] * k.shape[-2] 157 | # print(k.shape) 158 | if self.rpb: 159 | k += self.relative_position_bias 160 | 161 | q, k, v = (rearrange(x, "b n (h c) -> (b h) n c", h=self.num_heads) for x in [q, k, v]) 162 | attn = torch.einsum("b i c, b j c -> b i j", q, k) * self.dp_scale 163 | # attn = q @ k.transpose(-2, -1) * self.dp_scale 164 | # print(attn.shape) 165 | flops += q.shape[-1] * q.shape[-2] * k.shape[-2] * self.num_heads 166 | 167 | if pad_mask is not None: 168 | pad_mask = repeat(pad_mask, "b j -> (b h) () j", h=self.num_heads) 169 | attn_max_neg = -torch.finfo(attn.dtype).max 170 | attn.masked_fill_(pad_mask, attn_max_neg) 171 | 172 | attn = attn.softmax(dim=-1) 173 | attn = self.dropout(attn) 174 | 175 | o = torch.einsum("b i j, b j c -> b i c", attn, v) 176 | # o = attn @ v 177 | # print(attn.shape, v.shape, o.shape) 178 | flops += attn.shape[-2] * v.shape[-2] * v.shape[-1] * self.num_heads 179 | o = rearrange(o, "(b h) n c -> b n (h c)", h=self.num_heads) 180 | in_channels = o.shape[-1] 181 | o = self.o_proj(o) 182 | # print(o.shape) 183 | flops += in_channels * o.shape[-1] * o.shape[-2] 184 | return o, flops 185 | 186 | class CrossAttention(nn.Module): 187 | def __init__( 188 | self, 189 | num_heads: int, 190 | num_q_input_channels: int, 191 | num_kv_input_channels: int, 192 | num_qk_channels: Optional[int] = None, 193 | num_v_channels: Optional[int] = None, 194 | dropout: float = 0.0, 195 | 196 | rpb=False, 197 | feat_w=32, 198 | feat_h=32 199 | ): 200 | """Multi-head cross-attention (see `MultiHeadAttention` for details).""" 201 | super().__init__() 202 | 203 | self.q_norm = nn.LayerNorm(num_q_input_channels) 204 | self.kv_norm = nn.LayerNorm(num_kv_input_channels) 205 | self.attention = MultiHeadAttention( 206 | num_heads=num_heads, 207 | num_q_input_channels=num_q_input_channels, 208 | num_kv_input_channels=num_kv_input_channels, 209 | num_qk_channels=num_qk_channels, 210 | num_v_channels=num_v_channels, 211 | dropout=dropout, 212 | 213 | rpb=rpb, 214 | feat_w=feat_w, 215 | feat_h=feat_h 216 | ) 217 | 218 | def forward(self, x_q, x_kv, pad_mask=None, attn_mask=None): 219 | """Multi-head attention of query input `x_q` to key/value input (`x_kv`) after (separately) applying layer 220 | normalization to these inputs.""" 221 | x_q = self.q_norm(x_q) 222 | x_kv = self.kv_norm(x_kv) 223 | return self.attention(x_q, x_kv, pad_mask=pad_mask, attn_mask=attn_mask) 224 | 225 | def forward_calc_flops(self, x_q, x_kv, pad_mask=None, attn_mask=None): 226 | """Multi-head attention of query input `x_q` to key/value input (`x_kv`) after (separately) applying layer 227 | normalization to these inputs.""" 228 | x_q = self.q_norm(x_q) 229 | x_kv = self.kv_norm(x_kv) 230 | return self.attention.forward_calc_flops(x_q, x_kv, pad_mask=pad_mask, attn_mask=attn_mask) 231 | 232 | 233 | class SelfAttention(nn.Module): 234 | def __init__( 235 | self, 236 | num_heads: int, 237 | num_channels: int, 238 
| num_qk_channels: Optional[int] = None, 239 | num_v_channels: Optional[int] = None, 240 | dropout: float = 0.0, 241 | ): 242 | """Multi-head self-attention (see `MultiHeadAttention` and for details).""" 243 | super().__init__() 244 | self.norm = nn.LayerNorm(num_channels) 245 | self.attention = MultiHeadAttention( 246 | num_heads=num_heads, 247 | num_q_input_channels=num_channels, 248 | num_kv_input_channels=num_channels, 249 | num_qk_channels=num_qk_channels, 250 | num_v_channels=num_v_channels, 251 | dropout=dropout, 252 | ) 253 | 254 | def forward(self, x, pad_mask=None, attn_mask=None): 255 | """Multi-head attention of input `x` to itself after applying layer normalization to the input.""" 256 | x = self.norm(x) 257 | return self.attention(x, x, pad_mask=pad_mask, attn_mask=attn_mask) 258 | 259 | def forward_calc_flops(self, x, pad_mask=None, attn_mask=None): 260 | """Multi-head attention of input `x` to itself after applying layer normalization to the input.""" 261 | x = self.norm(x) 262 | return self.attention.forward_calc_flops(x, x, pad_mask=pad_mask, attn_mask=attn_mask) 263 | 264 | 265 | class CrossAttentionLayer(nn.Module): 266 | def __init__( 267 | self, 268 | num_heads: int, 269 | num_q_input_channels: int, 270 | num_kv_input_channels: int, 271 | num_qk_channels: Optional[int] = None, 272 | num_v_channels: Optional[int] = None, 273 | widening_factor: int = 1, 274 | dropout: float = 0.0, 275 | attention_residual: bool = True, 276 | rpb=False, 277 | feat_w=32, 278 | feat_h=32 279 | ): 280 | # print(attention_residual) 281 | super().__init__( 282 | # Residual(cross_attn, dropout) if attention_residual else cross_attn, 283 | # Residual(MLP(num_q_input_channels, widening_factor), dropout), 284 | ) 285 | 286 | self.cross_attn = CrossAttention( 287 | num_heads=num_heads, 288 | num_q_input_channels=num_q_input_channels, 289 | num_kv_input_channels=num_kv_input_channels, 290 | num_qk_channels=num_qk_channels, 291 | num_v_channels=num_v_channels, 292 | dropout=dropout, 293 | rpb=rpb, 294 | feat_w=feat_w, 295 | feat_h=feat_h 296 | ) 297 | self.attention_residual = attention_residual 298 | self.mlp = MLP(num_q_input_channels, widening_factor) 299 | self.dropout = nn.Dropout(p=dropout) 300 | 301 | def forward(self, x_q, x_kv, pad_mask=None, attn_mask=None): 302 | att_out = self.cross_attn(x_q, x_kv, pad_mask=pad_mask, attn_mask=attn_mask) 303 | 304 | out = x_q + self.dropout(att_out) if self.attention_residual else att_out 305 | 306 | mlp_out = self.mlp(out) 307 | out = out + self.dropout(mlp_out) 308 | 309 | # flops = att_flops + mlp_flops 310 | 311 | return out 312 | 313 | def forward_calc_flops(self, x_q, x_kv, pad_mask=None, attn_mask=None): 314 | att_out, att_flops = self.cross_attn.forward_calc_flops(x_q, x_kv, pad_mask=pad_mask, attn_mask=attn_mask) 315 | 316 | out = x_q + self.dropout(att_out) if self.attention_residual else att_out 317 | 318 | mlp_out, mlp_flops = self.mlp.forward_calc_flops(out) 319 | out = out + self.dropout(mlp_out) 320 | # print(att_flops/1e8, mlp_flops/1e8) 321 | flops = att_flops + mlp_flops 322 | 323 | return out, flops 324 | 325 | class SelfAttentionLayer(Sequential): 326 | def __init__( 327 | self, 328 | num_heads: int, 329 | num_channels: int, 330 | num_qk_channels: Optional[int] = None, 331 | num_v_channels: Optional[int] = None, 332 | widening_factor: int = 1, 333 | dropout: float = 0.0, 334 | ): 335 | 336 | super().__init__( 337 | # Residual(self_attn, dropout), 338 | # Residual(MLP(num_channels, widening_factor), dropout), 339 | ) 340 | 
self.self_attn = SelfAttention( 341 | num_heads=num_heads, 342 | num_channels=num_channels, 343 | num_qk_channels=num_qk_channels, 344 | num_v_channels=num_v_channels, 345 | dropout=dropout, 346 | ) 347 | self.mlp = MLP(num_channels, widening_factor) 348 | self.dropout = nn.Dropout(p=dropout) 349 | 350 | def forward(self, x, pad_mask=None, attn_mask=None): 351 | att_out = self.self_attn(x, pad_mask=pad_mask, attn_mask=attn_mask) 352 | out = x + self.dropout(att_out) 353 | mlp_out = self.mlp(out) 354 | out = out + self.dropout(mlp_out) 355 | return out 356 | 357 | def forward_calc_flops(self, x, pad_mask=None, attn_mask=None): 358 | att_out, att_flops = self.self_attn.forward_calc_flops(x, pad_mask=pad_mask, attn_mask=attn_mask) 359 | out = x + self.dropout(att_out) 360 | mlp_out, mlp_flops = self.mlp.forward_calc_flops(out) 361 | out = out + self.dropout(mlp_out) 362 | flops = att_flops + mlp_flops 363 | return out, flops 364 | 365 | class SelfAttentionBlock(nn.ModuleList): 366 | def __init__( 367 | self, 368 | num_layers: int, 369 | num_heads: int, 370 | num_channels: int, 371 | num_qk_channels: Optional[int] = None, 372 | num_v_channels: Optional[int] = None, 373 | widening_factor: int = 1, 374 | dropout: float = 0.0, 375 | activation_checkpointing: bool = False, 376 | ): 377 | layers = [ 378 | SelfAttentionLayer( 379 | num_heads=num_heads, 380 | num_channels=num_channels, 381 | num_qk_channels=num_qk_channels, 382 | num_v_channels=num_v_channels, 383 | widening_factor=widening_factor, 384 | dropout=dropout, 385 | ) 386 | for _ in range(num_layers) 387 | ] 388 | 389 | # if activation_checkpointing: 390 | # layers = [checkpoint_wrapper(layer) for layer in layers] 391 | 392 | super().__init__(layers) 393 | 394 | def forward(self, x, pad_mask=None, attn_mask=None): 395 | out = x 396 | for layer in self: 397 | out = layer(out, pad_mask=pad_mask, attn_mask=attn_mask) 398 | return out 399 | 400 | def forward_calc_flops(self, x, pad_mask=None, attn_mask=None): 401 | out = x 402 | flops = 0 403 | for layer in self: 404 | out, flops_layer = layer.forward_calc_flops(out, pad_mask=pad_mask, attn_mask=attn_mask) 405 | flops += flops_layer 406 | return out, flops 407 | 408 | 409 | class MLP(nn.Module): 410 | def __init__(self, num_channels: int, widening_factor: int): 411 | super().__init__( 412 | 413 | ) 414 | self.layernorm = nn.LayerNorm(num_channels) 415 | self.fc1 = nn.Linear(num_channels, widening_factor * num_channels) 416 | self.gelu = nn.GELU() 417 | self.fc2 = nn.Linear(widening_factor * num_channels, num_channels) 418 | 419 | def forward(self, x): 420 | x = self.layernorm(x) 421 | _, L, c = x.shape 422 | x = self.fc1(x) 423 | 424 | x = self.gelu(x) 425 | c = x.shape[-1] 426 | x = self.fc2(x) 427 | return x 428 | 429 | 430 | def forward_calc_flops(self, x): 431 | x = self.layernorm(x) 432 | _, L, c = x.shape 433 | x = self.fc1(x) 434 | flops = L * c * x.shape[-1] 435 | 436 | x = self.gelu(x) 437 | c = x.shape[-1] 438 | x = self.fc2(x) 439 | flops += L * c * x.shape[-1] 440 | return x, flops 441 | 442 | 443 | class Residual(nn.Module): 444 | def __init__(self, module: nn.Module, dropout: float): 445 | super().__init__() 446 | self.module = module 447 | self.dropout = nn.Dropout(p=dropout) 448 | self.dropout_p = dropout 449 | 450 | def forward(self, *args, **kwargs): 451 | # print(*args, **kwargs) 452 | # print(f'args:') 453 | # print(*args) 454 | # print(f'kwargs:') 455 | # print(**kwargs) 456 | x = self.module(*args, **kwargs) 457 | return self.dropout(x) + args[0] 458 | 459 | 460 | class 
InputAdapter(nn.Module): 461 | def __init__(self, num_input_channels: int): 462 | """Transforms and position-encodes task-specific input to generic encoder input. 463 | 464 | :param num_input_channels: Number of channels of the generic encoder input produced by this adapter. 465 | """ 466 | super().__init__() 467 | self._num_input_channels = num_input_channels 468 | 469 | @property 470 | def num_input_channels(self): 471 | return self._num_input_channels 472 | 473 | def forward(self, x): 474 | raise NotImplementedError() 475 | 476 | 477 | class OutputAdapter(nn.Module): 478 | def __init__(self, output_query: Tensor): 479 | """Transforms generic decoder cross-attention output to task-specific output. 480 | 481 | :param output_query: output query prototype (does not include batch dimension) used as query input to 482 | generic decoder cross-attention. 483 | """ 484 | super().__init__() 485 | self._output_query = nn.Parameter(output_query) 486 | self._init_parameters() 487 | 488 | def _init_parameters(self): 489 | with torch.no_grad(): 490 | self._output_query.normal_(0.0, 0.02).clamp_(-2.0, 2.0) 491 | 492 | @property 493 | def num_output_query_channels(self): 494 | return self._output_query.shape[-1] 495 | 496 | def output_query(self, x): 497 | return repeat(self._output_query, "... -> b ...", b=x.shape[0]) 498 | 499 | def forward(self, x): 500 | raise NotImplementedError() 501 | 502 | 503 | class ClassificationOutputAdapter(OutputAdapter): 504 | def __init__(self, num_classes: int, num_output_queries: int = 1, num_output_query_channels: Optional[int] = None): 505 | 506 | if num_output_query_channels is None: 507 | num_output_query_channels = num_classes 508 | 509 | super().__init__(output_query=torch.empty(num_output_queries, num_output_query_channels)) 510 | self.linear = nn.Linear(num_output_query_channels, num_classes) 511 | 512 | def forward(self, x): 513 | return self.linear(x).squeeze(dim=1) 514 | 515 | 516 | class PerceiverEncoder(nn.Module): 517 | def __init__( 518 | self, 519 | input_adapter: InputAdapter, 520 | num_latents: int, 521 | num_latent_channels: int, 522 | num_cross_attention_heads: int = 4, 523 | num_cross_attention_qk_channels: Optional[int] = None, 524 | num_cross_attention_v_channels: Optional[int] = None, 525 | num_cross_attention_layers: int = 1, 526 | first_cross_attention_layer_shared: bool = False, 527 | cross_attention_widening_factor: int = 1, 528 | num_self_attention_heads: int = 4, 529 | num_self_attention_qk_channels: Optional[int] = None, 530 | num_self_attention_v_channels: Optional[int] = None, 531 | num_self_attention_layers_per_block: int = 6, 532 | num_self_attention_blocks: int = 1, 533 | first_self_attention_block_shared: bool = True, 534 | self_attention_widening_factor: int = 1, 535 | dropout: float = 0.0, 536 | activation_checkpointing: bool = False, 537 | ): 538 | """Generic Perceiver IO encoder. 539 | 540 | :param input_adapter: Transforms and position-encodes task-specific input to generic encoder input 541 | of shape (B, M, C) where B is the batch size, M the input sequence length and C the number of 542 | key/value input channels. C is determined by the `num_input_channels` property of the 543 | `input_adapter`. 544 | :param num_latents: Number of latent variables (N). 545 | :param num_latent_channels: Number of latent channels (D). 546 | :param num_cross_attention_heads: Number of cross-attention heads. 
547 | :param num_cross_attention_qk_channels: Number of query and key channels for cross-attention 548 | (see `MultiHeadAttention.num_qk_channels` for details). 549 | :param num_cross_attention_v_channels: Number of value channels for cross-attention 550 | (see `MultiHeadAttention.num_v_channels` for details). 551 | :param num_cross_attention_layers: Number of cross-attention layers (alternating with self-attention blocks). 552 | :param first_cross_attention_layer_shared: Whether the first cross-attention layer should share its weights 553 | with subsequent cross-attention layers (if any). 554 | :param num_self_attention_heads: Number of self-attention heads. 555 | :param num_self_attention_qk_channels: Number of query and key channels for self-attention 556 | (see `MultiHeadAttention.num_qk_channels` for details). 557 | :param num_self_attention_v_channels: Number of value channels for self-attention 558 | (see `MultiHeadAttention.num_v_channels` for details). 559 | :param num_self_attention_layers_per_block: Number of self-attention layers per self-attention block. 560 | :param num_self_attention_blocks: Number of self-attention blocks sharing weights between corresponding 561 | self-attention layers. 562 | :param first_self_attention_block_shared: Whether the first self-attention block should share its weights 563 | with subsequent self-attention blocks (if any). 564 | :param dropout: Dropout probability for self- and cross-attention layers and residuals. 565 | :param activation_checkpointing: If True, implements an activation checkpoint for each self-attention 566 | layer and cross-attention layer. 567 | """ 568 | super().__init__() 569 | 570 | self.input_adapter = input_adapter 571 | 572 | if num_cross_attention_layers <= 0: 573 | raise ValueError("num_cross_attention_layers must be > 0") 574 | 575 | if num_self_attention_blocks <= 0: 576 | raise ValueError("num_self_attention_blocks must be > 0") 577 | 578 | if num_cross_attention_layers > num_self_attention_blocks: 579 | raise ValueError("num_cross_attention_layers must be <= num_self_attention_blocks") 580 | 581 | self.num_cross_attention_layers = num_cross_attention_layers 582 | self.num_self_attention_blocks = num_self_attention_blocks 583 | 584 | self.first_cross_attention_layer_shared = first_cross_attention_layer_shared 585 | self.first_self_attention_block_shared = first_self_attention_block_shared 586 | 587 | def cross_attn(): 588 | # print(input_adapter.num_input_channels) 589 | # assert(0==1) 590 | layer = CrossAttentionLayer( 591 | num_heads=num_cross_attention_heads, 592 | num_q_input_channels=num_latent_channels, 593 | num_kv_input_channels=input_adapter.num_input_channels, 594 | num_qk_channels=num_cross_attention_qk_channels, 595 | num_v_channels=num_cross_attention_v_channels, 596 | widening_factor=cross_attention_widening_factor, 597 | dropout=dropout, 598 | ) 599 | # return checkpoint_wrapper(layer) if activation_checkpointing else layer 600 | return layer 601 | 602 | def self_attn(): 603 | return SelfAttentionBlock( 604 | num_layers=num_self_attention_layers_per_block, 605 | num_heads=num_self_attention_heads, 606 | num_channels=num_latent_channels, 607 | num_qk_channels=num_self_attention_qk_channels, 608 | num_v_channels=num_self_attention_v_channels, 609 | widening_factor=self_attention_widening_factor, 610 | dropout=dropout, 611 | activation_checkpointing=activation_checkpointing, 612 | ) 613 | 614 | self.cross_attn_n = cross_attn() 615 | self.self_attn_n = self_attn() 616 | 617 | if 
self.first_cross_attention_layer_shared or num_cross_attention_layers == 1: 618 | self.cross_attn_1 = self.cross_attn_n 619 | else: 620 | self.cross_attn_1 = cross_attn() 621 | 622 | if self.first_self_attention_block_shared or num_self_attention_blocks == 1: 623 | self.self_attn_1 = self.self_attn_n 624 | else: 625 | self.self_attn_1 = self_attn() 626 | 627 | # learnable initial latent vectors 628 | self.latent = nn.Parameter(torch.empty(num_latents, num_latent_channels)) 629 | self._init_parameters() 630 | 631 | def _init_parameters(self): 632 | with torch.no_grad(): 633 | self.latent.normal_(0.0, 0.02).clamp_(-2.0, 2.0) 634 | 635 | def forward(self, x, pad_mask=None): 636 | b, *_ = x.shape 637 | 638 | # encode task-specific input 639 | # print(x.shape) 640 | x = self.input_adapter(x) 641 | # print(x.shape) 642 | # assert(0==1)  # leftover debug guard; must stay disabled for forward() to run 643 | # repeat initial latent vector along batch dimension 644 | x_latent = repeat(self.latent, "... -> b ...", b=b) 645 | 646 | x_latent = self.cross_attn_1(x_latent, x, pad_mask) 647 | x_latent = self.self_attn_1(x_latent) 648 | 649 | for i in range(1, self.num_self_attention_blocks): 650 | if i < self.num_cross_attention_layers: 651 | x_latent = self.cross_attn_n(x_latent, x, pad_mask) 652 | x_latent = self.self_attn_n(x_latent) 653 | 654 | # print(x_latent.shape) 655 | return x_latent 656 | 657 | 658 | class PerceiverDecoder(nn.Module): 659 | def __init__( 660 | self, 661 | output_adapter: OutputAdapter, 662 | num_latent_channels: int, 663 | num_cross_attention_heads: int = 4, 664 | num_cross_attention_qk_channels: Optional[int] = None, 665 | num_cross_attention_v_channels: Optional[int] = None, 666 | cross_attention_widening_factor: int = 1, 667 | dropout: float = 0.0, 668 | activation_checkpointing: bool = False, 669 | ): 670 | """Generic Perceiver IO decoder. 671 | 672 | :param output_adapter: Transforms generic decoder cross-attention output of shape (B, O, F) to task-specific 673 | output. B is the batch size, O the output sequence length and F the number of cross-attention output 674 | channels. F is determined by the `num_output_query_channels` property of the `output_adapter`. 675 | :param num_latent_channels: Number of latent channels (C_latent) as produced by a Perceiver IO encoder. 676 | :param num_cross_attention_heads: Number of cross-attention heads. 677 | :param num_cross_attention_qk_channels: Number of query and key channels for cross-attention 678 | (see `MultiHeadAttention.num_qk_channels` for details). 679 | :param num_cross_attention_v_channels: Number of value channels for cross-attention 680 | (see `MultiHeadAttention.num_v_channels` for details). 681 | :param dropout: Dropout probability for cross-attention layers and residuals. 682 | :param activation_checkpointing: If True, implements an activation checkpoint for the decoder's 683 | cross-attention layer.
684 | """ 685 | super().__init__() 686 | 687 | cross_attn = CrossAttentionLayer( 688 | num_heads=num_cross_attention_heads, 689 | num_q_input_channels=output_adapter.num_output_query_channels, 690 | num_kv_input_channels=num_latent_channels, 691 | num_qk_channels=num_cross_attention_qk_channels, 692 | num_v_channels=num_cross_attention_v_channels, 693 | widening_factor=cross_attention_widening_factor, 694 | dropout=dropout, 695 | ) 696 | 697 | # if activation_checkpointing: 698 | # cross_attn = checkpoint_wrapper(cross_attn) 699 | 700 | self.cross_attn = cross_attn 701 | self.output_adapter = output_adapter 702 | 703 | def forward(self, x): 704 | output_query = self.output_adapter.output_query(x) 705 | output = self.cross_attn(output_query, x) 706 | return self.output_adapter(output) 707 | 708 | 709 | class PerceiverIO(Sequential): 710 | def __init__(self, encoder: PerceiverEncoder, decoder: PerceiverDecoder): 711 | super().__init__(encoder, decoder) 712 | 713 | @property 714 | def encoder(self): 715 | return self[0] 716 | 717 | @property 718 | def decoder(self): 719 | return self[1] 720 | --------------------------------------------------------------------------------