├── models
│   ├── cnn_core
│   │   ├── __init__.py
│   │   ├── mobilenet_v3.py
│   │   └── regnet.py
│   ├── perceiver_core
│   │   ├── __init__.py
│   │   ├── utils.py
│   │   ├── config.py
│   │   └── modules.py
│   ├── __init__.py
│   ├── registry.py
│   └── dyn_perceiver_resnet.py
├── figs
│   ├── fig6_ee.png
│   ├── fig1_idea.png
│   ├── fig9_video.png
│   ├── fig2_overview.png
│   ├── fig3_x2z_ca.png
│   ├── fig5_z2x_ca.png
│   ├── fig3_components.png
│   ├── fig4_token_mixer.png
│   ├── fig7_main_results.png
│   ├── fig8_mob_results.png
│   └── fig10_speed_4subfig.png
├── LICENSE
├── .gitignore
├── losses.py
├── datasets.py
├── README.md
├── adaptive_inference.py
├── optim_factory.py
├── engine.py
├── utils.py
└── main_baseline.py
/models/cnn_core/__init__.py:
--------------------------------------------------------------------------------
1 | from .regnet import *
2 | from .mobilenet_v3 import *
--------------------------------------------------------------------------------
/models/perceiver_core/__init__.py:
--------------------------------------------------------------------------------
1 | from .config import *
2 | from .modules import *
3 |
--------------------------------------------------------------------------------
/figs/fig6_ee.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeapLabTHU/Dynamic_Perceiver/HEAD/figs/fig6_ee.png
--------------------------------------------------------------------------------
/figs/fig1_idea.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeapLabTHU/Dynamic_Perceiver/HEAD/figs/fig1_idea.png
--------------------------------------------------------------------------------
/figs/fig9_video.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeapLabTHU/Dynamic_Perceiver/HEAD/figs/fig9_video.png
--------------------------------------------------------------------------------
/figs/fig2_overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeapLabTHU/Dynamic_Perceiver/HEAD/figs/fig2_overview.png
--------------------------------------------------------------------------------
/figs/fig3_x2z_ca.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeapLabTHU/Dynamic_Perceiver/HEAD/figs/fig3_x2z_ca.png
--------------------------------------------------------------------------------
/figs/fig5_z2x_ca.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeapLabTHU/Dynamic_Perceiver/HEAD/figs/fig5_z2x_ca.png
--------------------------------------------------------------------------------
/figs/fig3_components.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeapLabTHU/Dynamic_Perceiver/HEAD/figs/fig3_components.png
--------------------------------------------------------------------------------
/figs/fig4_token_mixer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeapLabTHU/Dynamic_Perceiver/HEAD/figs/fig4_token_mixer.png
--------------------------------------------------------------------------------
/figs/fig7_main_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeapLabTHU/Dynamic_Perceiver/HEAD/figs/fig7_main_results.png
--------------------------------------------------------------------------------
/figs/fig8_mob_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeapLabTHU/Dynamic_Perceiver/HEAD/figs/fig8_mob_results.png
--------------------------------------------------------------------------------
/figs/fig10_speed_4subfig.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeapLabTHU/Dynamic_Perceiver/HEAD/figs/fig10_speed_4subfig.png
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .dyn_perceiver_regnet import *
2 | from .dyn_perceiver_mobilenet_v3 import *
3 | from .dyn_perceiver_resnet import *
4 |
--------------------------------------------------------------------------------
/models/perceiver_core/utils.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | class Sequential(nn.Sequential):
5 | def forward(self, *x):
6 | for module in self:
7 | if type(x) == tuple:
8 | x = module(*x)
9 | else:
10 | x = module(x)
11 | return x
12 |
13 |
14 | def freeze(module: nn.Module):
15 | for param in module.parameters():
16 | param.requires_grad = False
17 |
--------------------------------------------------------------------------------
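A minimal usage sketch (not part of the repository) of the tuple-aware `Sequential` wrapper above; `AddPair` is a hypothetical module used only for illustration, and the import path assumes the layout shown in the tree:

```python
import torch
import torch.nn as nn

from models.perceiver_core.utils import Sequential  # path assumes the repo layout above


class AddPair(nn.Module):
    """Toy module that takes two tensors and returns their sum."""
    def forward(self, a, b):
        return a + b


# The first module receives both positional arguments; once a single tensor is
# returned, the remaining modules are called with that tensor only.
seq = Sequential(AddPair(), nn.ReLU())
out = seq(torch.randn(2, 3), torch.randn(2, 3))
print(out.shape)  # torch.Size([2, 3])
```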
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) Meta Platforms, Inc. and affiliates.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/models/perceiver_core/config.py:
--------------------------------------------------------------------------------
1 | from dataclasses import asdict, dataclass, fields
2 | from typing import Generic, Optional, TypeVar
3 |
4 |
5 | @dataclass
6 | class EncoderConfig:
7 | num_cross_attention_heads: int = 8
8 | num_cross_attention_qk_channels: Optional[int] = None
9 | num_cross_attention_v_channels: Optional[int] = None
10 | num_cross_attention_layers: int = 1
11 | first_cross_attention_layer_shared: bool = False
12 | cross_attention_widening_factor: int = 1
13 | num_self_attention_heads: int = 8
14 | num_self_attention_qk_channels: Optional[int] = None
15 | num_self_attention_v_channels: Optional[int] = None
16 | num_self_attention_layers_per_block: int = 8
17 | num_self_attention_blocks: int = 1
18 | first_self_attention_block_shared: bool = True
19 | self_attention_widening_factor: int = 1
20 | dropout: float = 0.0
21 | freeze: bool = False
22 |
23 | def base_kwargs(self, exclude=("freeze",)):
24 | return _base_kwargs(self, EncoderConfig, exclude)
25 |
26 |
27 | @dataclass
28 | class DecoderConfig:
29 | num_cross_attention_heads: int = 8
30 | num_cross_attention_qk_channels: Optional[int] = None
31 | num_cross_attention_v_channels: Optional[int] = None
32 | cross_attention_widening_factor: int = 1
33 | dropout: float = 0.0
34 | freeze: bool = False
35 |
36 | def base_kwargs(self, exclude=("freeze",)):
37 | return _base_kwargs(self, DecoderConfig, exclude)
38 |
39 |
40 | E = TypeVar("E", bound=EncoderConfig)
41 | D = TypeVar("D", bound=DecoderConfig)
42 |
43 |
44 | @dataclass
45 | class PerceiverConfig(Generic[E, D]):
46 | encoder: E
47 | decoder: D
48 | num_latents: int
49 | num_latent_channels: int
50 | activation_checkpointing: bool
51 |
52 |
53 | @dataclass
54 | class ClassificationDecoderConfig(DecoderConfig):
55 | num_output_queries: int = 1
56 | num_output_query_channels: int = 256
57 | num_classes: int = 100
58 |
59 |
60 | def _base_kwargs(config, base_class, exclude):
61 | base_field_names = [field.name for field in fields(base_class) if field.name not in exclude]
62 | return {k: v for k, v in asdict(config).items() if k in base_field_names}
63 |
--------------------------------------------------------------------------------
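An illustrative sketch (not from the repository) of how the config dataclasses above can be combined; the field values here are arbitrary examples:

```python
from models.perceiver_core.config import (
    ClassificationDecoderConfig, EncoderConfig, PerceiverConfig)

encoder_cfg = EncoderConfig(num_self_attention_layers_per_block=2)
decoder_cfg = ClassificationDecoderConfig(num_classes=1000)

cfg = PerceiverConfig(
    encoder=encoder_cfg,
    decoder=decoder_cfg,
    num_latents=128,
    num_latent_channels=256,
    activation_checkpointing=False,
)

# base_kwargs() keeps only the fields declared on the base config class and drops
# anything in `exclude` (by default "freeze"), so subclass-only fields such as
# num_classes are filtered out here.
print(cfg.decoder.base_kwargs())
```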
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/
2 | cluster/
3 | results/
4 | results_best/
5 | ckpt_cswin/cswin_tiny_224.pth
6 | # Byte-compiled / optimized / DLL files
7 | __pycache__/
8 | *.py[cod]
9 | *$py.class
10 | *.tar.gz
11 | *.pth
12 |
13 | # C extensions
14 | *.so
15 |
16 | # Distribution / packaging
17 | .Python
18 | build/
19 | develop-eggs/
20 | dist/
21 | downloads/
22 | eggs/
23 | .eggs/
24 | lib/
25 | lib64/
26 | parts/
27 | sdist/
28 | var/
29 | wheels/
30 | pip-wheel-metadata/
31 | share/python-wheels/
32 | *.egg-info/
33 | .installed.cfg
34 | *.egg
35 | MANIFEST
36 |
37 | exp/
38 | models/flops.txt
39 |
40 | # PyInstaller
41 | # Usually these files are written by a python script from a template
42 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
43 | *.manifest
44 | *.spec
45 |
46 | # Installer logs
47 | pip-log.txt
48 | pip-delete-this-directory.txt
49 |
50 | # Unit test / coverage reports
51 | htmlcov/
52 | .tox/
53 | .nox/
54 | .coverage
55 | .coverage.*
56 | .cache
57 | nosetests.xml
58 | coverage.xml
59 | *.cover
60 | *.py,cover
61 | .hypothesis/
62 | .pytest_cache/
63 |
64 | # Translations
65 | *.mo
66 | *.pot
67 |
68 | # Django stuff:
69 | *.log
70 | local_settings.py
71 | db.sqlite3
72 | db.sqlite3-journal
73 |
74 | # Flask stuff:
75 | instance/
76 | .webassets-cache
77 |
78 | # Scrapy stuff:
79 | .scrapy
80 |
81 | # Sphinx documentation
82 | docs/_build/
83 |
84 | # PyBuilder
85 | target/
86 |
87 | # Jupyter Notebook
88 | .ipynb_checkpoints
89 |
90 | # IPython
91 | profile_default/
92 | ipython_config.py
93 |
94 | # pyenv
95 | .python-version
96 |
97 | # pipenv
98 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
99 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
100 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
101 | # install all needed dependencies.
102 | #Pipfile.lock
103 |
104 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
105 | __pypackages__/
106 |
107 | # Celery stuff
108 | celerybeat-schedule
109 | celerybeat.pid
110 |
111 | # SageMath parsed files
112 | *.sage.py
113 |
114 | # Environments
115 | .env
116 | .venv
117 | env/
118 | venv/
119 | ENV/
120 | env.bak/
121 | venv.bak/
122 |
123 | # Spyder project settings
124 | .spyderproject
125 | .spyproject
126 |
127 | # Rope project settings
128 | .ropeproject
129 |
130 | # mkdocs documentation
131 | /site
132 |
133 | # mypy
134 | .mypy_cache/
135 | .dmypy.json
136 | dmypy.json
137 |
138 | # Pyre type checker
139 | .pyre/
140 |
--------------------------------------------------------------------------------
/losses.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2015-present, Facebook, Inc.
2 | # All rights reserved.
3 | """
4 | Implements the knowledge distillation loss
5 | """
6 | import torch
7 | from torch.nn import functional as F
8 |
9 |
10 | class DistillationLoss(torch.nn.Module):
11 | """
12 | This module wraps a standard criterion and adds an extra knowledge distillation loss by
13 | taking a teacher model prediction and using it as additional supervision.
14 | """
15 |
16 | def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module,
17 | distillation_type: str, alpha: float, tau: float):
18 | super().__init__()
19 | self.base_criterion = base_criterion
20 | self.teacher_model = teacher_model
21 | self.teacher_model.eval()
22 | assert distillation_type in ['none', 'soft', 'hard']
23 | self.distillation_type = distillation_type
24 | self.alpha = alpha
25 | self.tau = tau
26 |
27 | def forward(self, inputs, outputs, labels):
28 | """
29 | Args:
30 |             inputs: The original inputs that are fed to the teacher model
31 | outputs: the outputs of the model to be trained. It is expected to be
32 | either a Tensor, or a Tuple[Tensor, Tensor], with the original output
33 | in the first position and the distillation predictions as the second output
34 | labels: the labels for the base criterion
35 | """
36 | outputs_kd = None
37 | if not isinstance(outputs, torch.Tensor):
38 | # assume that the model outputs a tuple of [outputs, outputs_kd]
39 | outputs, outputs_kd = outputs
40 | else:
41 | outputs_kd = outputs
42 | base_loss = self.base_criterion(outputs, labels)
43 | if self.distillation_type == 'none':
44 | return base_loss
45 |
46 | if outputs_kd is None:
47 | raise ValueError("When knowledge distillation is enabled, the model is "
48 | "expected to return a Tuple[Tensor, Tensor] with the output of the "
49 | "class_token and the dist_token")
50 |         # don't backprop through the teacher
51 | with torch.no_grad():
52 | teacher_outputs = self.teacher_model(inputs)
53 |
54 | if self.distillation_type == 'soft':
55 | T = self.tau
56 | # taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
57 | # with slight modifications
58 | distillation_loss = F.kl_div(
59 | F.log_softmax(outputs_kd / T, dim=1),
60 | F.log_softmax(teacher_outputs / T, dim=1),
61 | reduction='sum',
62 | log_target=True
63 | ) * (T * T) / outputs_kd.numel()
64 | elif self.distillation_type == 'hard':
65 | distillation_loss = F.cross_entropy(
66 | outputs_kd, teacher_outputs.argmax(dim=1))
67 |
68 | loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha
69 | return loss
70 |
--------------------------------------------------------------------------------
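A hypothetical usage sketch of `DistillationLoss`; the teacher below is a stand-in `nn.Linear`, not a real pretrained model:

```python
import torch
import torch.nn as nn

from losses import DistillationLoss

teacher = nn.Linear(8, 10)                  # placeholder for a pretrained teacher network
criterion = DistillationLoss(nn.CrossEntropyLoss(), teacher,
                             distillation_type='soft', alpha=0.5, tau=1.0)

inputs = torch.randn(4, 8)
labels = torch.randint(0, 10, (4,))
outputs = torch.randn(4, 10, requires_grad=True)  # a single tensor is reused as the distillation head
loss = criterion(inputs, outputs, labels)
loss.backward()
```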
/datasets.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 |
3 | # All rights reserved.
4 |
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 |
9 | import os
10 | from torchvision import datasets, transforms
11 | # from dataloader_hf import FireFlyerImageNet
12 | from timm.data.constants import \
13 | IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
14 | from timm.data import create_transform
15 |
16 | def build_dataset(is_train, args):
17 | transform = build_transform(is_train, args)
18 |
19 | print("Transform = ")
20 | if isinstance(transform, tuple):
21 | for trans in transform:
22 | print(" - - - - - - - - - - ")
23 | for t in trans.transforms:
24 | print(t)
25 | else:
26 | for t in transform.transforms:
27 | print(t)
28 | print("---------------------------")
29 |
30 | if args.data_set == 'CIFAR':
31 | dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform, download=True)
32 | nb_classes = 100
33 | elif args.data_set == 'IMNET':
34 | print("reading from datapath", args.data_path)
35 | root = os.path.join(args.data_path, 'train' if is_train else 'val')
36 | dataset = datasets.ImageFolder(root, transform=transform)
37 | nb_classes = 1000
38 | elif args.data_set == "image_folder":
39 | root = args.data_path if is_train else args.eval_data_path
40 | dataset = datasets.ImageFolder(root, transform=transform)
41 | nb_classes = args.nb_classes
42 | assert len(dataset.class_to_idx) == nb_classes
43 | else:
44 | raise NotImplementedError()
45 |     print("Number of classes = %d" % nb_classes)
46 |
47 | return dataset, nb_classes
48 |
49 | def build_transform(is_train, args):
50 | resize_im = args.input_size > 32
51 | imagenet_default_mean_and_std = args.imagenet_default_mean_and_std
52 | mean = IMAGENET_INCEPTION_MEAN if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_MEAN
53 | std = IMAGENET_INCEPTION_STD if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_STD
54 |
55 | if is_train:
56 | # this should always dispatch to transforms_imagenet_train
57 |
58 | transform = create_transform(
59 | input_size=args.input_size,
60 | is_training=True,
61 | color_jitter=args.color_jitter,
62 | auto_augment=args.aa,
63 | interpolation=args.train_interpolation,
64 | re_prob=args.reprob,
65 | re_mode=args.remode,
66 | re_count=args.recount,
67 | mean=mean,
68 | std=std,
69 | )
70 | if not resize_im:
71 | transform.transforms[0] = transforms.RandomCrop(
72 | args.input_size, padding=4)
73 | return transform
74 |
75 | t = []
76 | if resize_im:
77 | # warping (no cropping) when evaluated at 384 or larger
78 | if args.input_size >= 384:
79 | t.append(
80 | transforms.Resize((args.input_size, args.input_size),
81 | interpolation=transforms.InterpolationMode.BICUBIC),
82 | )
83 | print(f"Warping {args.input_size} size input images...")
84 | else:
85 | if args.crop_pct is None:
86 | args.crop_pct = 224 / 256
87 | size = int(args.input_size / args.crop_pct)
88 | t.append(
89 | # to maintain same ratio w.r.t. 224 images
90 | transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC),
91 | )
92 | t.append(transforms.CenterCrop(args.input_size))
93 |
94 | t.append(transforms.ToTensor())
95 | t.append(transforms.Normalize(mean, std))
96 | return transforms.Compose(t)
97 |
--------------------------------------------------------------------------------
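A rough sketch (not part of the repository) of calling `build_dataset` for evaluation; the data path is a placeholder and the `Namespace` fields mirror the arguments read by `build_transform`:

```python
from argparse import Namespace

from datasets import build_dataset

args = Namespace(
    data_set='IMNET',
    data_path='/path/to/imagenet',       # expects train/ and val/ ImageFolder subdirectories
    input_size=224,
    imagenet_default_mean_and_std=True,
    crop_pct=None,                       # falls back to 224/256 inside build_transform
)

val_dataset, nb_classes = build_dataset(is_train=False, args=args)
print(nb_classes)  # 1000
```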
/README.md:
--------------------------------------------------------------------------------
1 | # Dynamic Perceiver for Efficient Visual Recognition
2 |
3 |
4 |
5 | ## Introduction
6 |
7 | This repository contains the implementation of the ICCV 2023 paper, [*Dynamic Perceiver for Efficient Visual Recognition*](https://arxiv.org/abs/2306.11248). The proposed **Dynamic Perceiver (Dyn-Perceiver)** decouples the feature extraction procedure and the early classification task with a novel two-branch architecture, which significantly improves model performance in the dynamic early exiting scenario.
8 |
9 | ### Overall idea
10 |
11 |
12 |
13 |
14 |
15 | ### Model overview
16 |
17 |
18 |
19 |
20 |
21 | ### The inference procedure
22 |
23 |
24 |
25 | ## Usage
26 |
27 | ### Dependencies
28 |
29 | - Python: 3.8
30 | - PyTorch: 1.12.1
31 | - Torchvision: 0.13.1
32 |
33 | ### Scripts
34 |
35 | - Train a **RegNetY-based Dynamic Perceiver** model on ImageNet:
36 |
37 | ```bash
38 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 main_earlyExit.py \
39 | --model reg800m_perceiver_t128 --depth_factor 1 1 1 2 --spatial_reduction true --with_last_CA true --SA_widening_factor 4 --with_x2z true --with_dwc true --with_z2x true --with_isc true \
40 | --num_workers 4 \
41 | --model_ema true --model_ema_eval true --epochs 300 \
42 | --batch_size 128 --lr 1e-3 --loss_cnn_factor 1.0 --loss_att_factor 0.5 --loss_merge_factor 1.0 --update_freq 1 --use_amp false --with_kd true --T_kd 1.0 --alpha_kd 0.5 \
43 | --data_path YOUR_DATA_PATH \
44 | --output_dir YOUR_SAVE_PATH &\
45 | ```
46 |
47 | - Train a **ResNet-based Dynamic Perceiver** model on ImageNet:
48 |
49 | ```bash
50 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 main_earlyExit.py \
51 | --model resnet50_0375_perceiver_t128 --depth_factor 1 1 1 1 --spatial_reduction true --with_last_CA true --SA_widening_factor 4 --with_x2z true --with_dwc true --with_z2x true --with_isc true \
52 | --num_workers 4 \
53 | --model_ema true --model_ema_eval true --epochs 300 \
54 | --batch_size 128 --lr 6e-4 --loss_cnn_factor 1.0 --loss_att_factor 0.5 --loss_merge_factor 1.0 --update_freq 1 --use_amp false --with_kd true --T_kd 1.0 --alpha_kd 0.5 \
55 | --data_path YOUR_DATA_PATH \
56 | --output_dir YOUR_SAVE_PATH &\
57 | ```
58 |
59 | - Train a **MobileNet-based Dynamic Perceiver** model on ImageNet:
60 |
61 | ```bash
62 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 main_earlyExit.py \
63 | --model mobilenetV3_0x75_perceiver_t128 --depth_factor 1 1 1 3 --spatial_reduction true --with_last_CA true --SA_widening_factor 4 --with_x2z true --with_dwc true --with_z2x true --with_isc true \
64 | --num_workers 4 \
65 | --model_ema true --model_ema_eval true --epochs 600 \
66 | --batch_size 128 --lr 1e-3 --loss_cnn_factor 1.0 --loss_att_factor 0.5 --loss_merge_factor 1.0 --update_freq 1 --use_amp false --with_kd true --T_kd 1.0 --alpha_kd 0.5 \
67 | --data_path YOUR_DATA_PATH \
68 | --output_dir YOUR_SAVE_PATH &\
69 | ```
70 |
71 | - Evaluate (dynamic):
72 |
73 | ```bash
74 | CUDA_VISIBLE_DEVICES=0 python main_earlyExit.py --eval true \
75 | --resume YOUR_CHECKPOINT_PATH \
76 | --model reg800m_perceiver_t128 --depth_factor 1 1 1 2 --spatial_reduction true --with_last_CA true --SA_widening_factor 4 --with_x2z true --with_dwc true --with_z2x true --with_isc true \
77 | --num_workers 4 \
78 | --batch_size 128 --lr 1e-3 --loss_cnn_factor 1.0 --loss_att_factor 0.5 --loss_merge_factor 1.0 --update_freq 1 --use_amp false --with_kd true --T_kd 1.0 --alpha_kd 0.5 \
79 | --data_path YOUR_DATA_PATH \
80 | --output_dir YOUR_SAVE_PATH &\
81 | ```
82 |
83 |
84 |
85 | ### Results
86 |
87 | - ImageNet results of Dyn-Perceiver built on top of MobileNet-v3.
88 |
89 |
90 |
91 | - Speed test results of Dyn-Perceiver.
92 |
93 |
94 |
95 | ### Pre-trained Models on ImageNet
96 | |model|acc_exit1|acc_exit2|acc_exit3|acc_exit4|Checkpoint Link|
97 | |:-:|:-:|:-:|:-:|:-:|:-:|
98 | | reg800m_perceiver_t128 | 68.62 | 78.32 | 79.15 | 79.86 |[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/ca3c8a808d504fdfa2c8/?dl=1) |
99 | | resnet50_0375_perceiver_t128 | 72.93 | 77.52 | 74.32 | 77.70 |[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/844ac164a62246eb9ce8/?dl=1) |
100 | | mobilenetV3_0x75_perceiver_t128 | 53.13 | 71.65 | 71.89 | 74.59 |[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/a86c56e3076146ad8748/?dl=1) |
101 |
102 | ### Contact
103 |
104 | If you have any questions, please feel free to contact the authors.
105 |
106 | Yizeng Han: [hanyz18@mails.tsinghua.edu.cn](mailto:hanyz18@mails.tsinghua.edu.cn), [yizeng38@gmail.com](mailto:yizeng38@gmail.com).
107 |
108 | Dongchen Han: [hdc19@mails.tsinghua.edu.cn](mailto:hdc19@mails.tsinghua.edu.cn), [tianqing1.10000@gmail.com](mailto:tianqing1.10000@gmail.com)
109 |
110 |
111 |
112 |
115 |
--------------------------------------------------------------------------------
/models/registry.py:
--------------------------------------------------------------------------------
1 | """ Model Registry
2 | Hacked together by / Copyright 2020 Ross Wightman
3 | """
4 |
5 | import sys
6 | import re
7 | import fnmatch
8 | from collections import defaultdict
9 | from copy import deepcopy
10 |
11 | __all__ = ['list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules',
12 | 'is_model_default_key', 'has_model_default_key', 'get_model_default_value', 'is_model_pretrained']
13 |
14 | _module_to_models = defaultdict(set) # dict of sets to check membership of model in module
15 | _model_to_module = {} # mapping of model names to module names
16 | _model_entrypoints = {} # mapping of model names to entrypoint fns
17 | _model_has_pretrained = set() # set of model names that have pretrained weight url present
18 | _model_default_cfgs = dict() # central repo for model default_cfgs
19 |
20 |
21 | def register_model(fn):
22 | # lookup containing module
23 | mod = sys.modules[fn.__module__]
24 | module_name_split = fn.__module__.split('.')
25 | module_name = module_name_split[-1] if len(module_name_split) else ''
26 |
27 | # add model to __all__ in module
28 | model_name = fn.__name__
29 | if hasattr(mod, '__all__'):
30 | mod.__all__.append(model_name)
31 | else:
32 | mod.__all__ = [model_name]
33 |
34 | # add entries to registry dict/sets
35 | _model_entrypoints[model_name] = fn
36 | _model_to_module[model_name] = module_name
37 | _module_to_models[module_name].add(model_name)
38 | has_pretrained = False # check if model has a pretrained url to allow filtering on this
39 | if hasattr(mod, 'default_cfgs') and model_name in mod.default_cfgs:
40 | # this will catch all models that have entrypoint matching cfg key, but miss any aliasing
41 | # entrypoints or non-matching combos
42 | has_pretrained = 'url' in mod.default_cfgs[model_name] and 'http' in mod.default_cfgs[model_name]['url']
43 | _model_default_cfgs[model_name] = deepcopy(mod.default_cfgs[model_name])
44 | if has_pretrained:
45 | _model_has_pretrained.add(model_name)
46 | return fn
47 |
48 |
49 | def _natural_key(string_):
50 | return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
51 |
52 |
53 | def list_models(filter='', module='', pretrained=False, exclude_filters='', name_matches_cfg=False):
54 | """ Return list of available model names, sorted alphabetically
55 | Args:
56 | filter (str) - Wildcard filter string that works with fnmatch
57 | module (str) - Limit model selection to a specific sub-module (ie 'gen_efficientnet')
58 | pretrained (bool) - Include only models with pretrained weights if True
59 | exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter
60 | name_matches_cfg (bool) - Include only models w/ model_name matching default_cfg name (excludes some aliases)
61 | Example:
62 |         list_models('gluon_resnet*') -- returns all models starting with 'gluon_resnet'
63 |         list_models('*resnext*', 'resnet') -- returns all models with 'resnext' in the 'resnet' module
64 | """
65 | if module:
66 | models = list(_module_to_models[module])
67 | else:
68 | models = _model_entrypoints.keys()
69 | if filter:
70 | models = fnmatch.filter(models, filter) # include these models
71 | if exclude_filters:
72 | if not isinstance(exclude_filters, (tuple, list)):
73 | exclude_filters = [exclude_filters]
74 | for xf in exclude_filters:
75 | exclude_models = fnmatch.filter(models, xf) # exclude these models
76 | if len(exclude_models):
77 | models = set(models).difference(exclude_models)
78 | if pretrained:
79 | models = _model_has_pretrained.intersection(models)
80 | if name_matches_cfg:
81 | models = set(_model_default_cfgs).intersection(models)
82 | return list(sorted(models, key=_natural_key))
83 |
84 |
85 | def is_model(model_name):
86 | """ Check if a model name exists
87 | """
88 | return model_name in _model_entrypoints
89 |
90 |
91 | def model_entrypoint(model_name):
92 | """Fetch a model entrypoint for specified model name
93 | """
94 | return _model_entrypoints[model_name]
95 |
96 |
97 | def list_modules():
98 | """ Return list of module names that contain models / model entrypoints
99 | """
100 | modules = _module_to_models.keys()
101 | return list(sorted(modules))
102 |
103 |
104 | def is_model_in_modules(model_name, module_names):
105 | """Check if a model exists within a subset of modules
106 | Args:
107 | model_name (str) - name of model to check
108 | module_names (tuple, list, set) - names of modules to search in
109 | """
110 | assert isinstance(module_names, (tuple, list, set))
111 | return any(model_name in _module_to_models[n] for n in module_names)
112 |
113 |
114 | def has_model_default_key(model_name, cfg_key):
115 | """ Query model default_cfgs for existence of a specific key.
116 | """
117 | if model_name in _model_default_cfgs and cfg_key in _model_default_cfgs[model_name]:
118 | return True
119 | return False
120 |
121 |
122 | def is_model_default_key(model_name, cfg_key):
123 | """ Return truthy value for specified model default_cfg key, False if does not exist.
124 | """
125 | if model_name in _model_default_cfgs and _model_default_cfgs[model_name].get(cfg_key, False):
126 | return True
127 | return False
128 |
129 |
130 | def get_model_default_value(model_name, cfg_key):
131 | """ Get a specific model default_cfg value by key. None if it doesn't exist.
132 | """
133 | if model_name in _model_default_cfgs:
134 | return _model_default_cfgs[model_name].get(cfg_key, None)
135 | else:
136 | return None
137 |
138 |
139 | def is_model_pretrained(model_name):
140 | return model_name in _model_has_pretrained
--------------------------------------------------------------------------------
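An illustrative sketch (not from the repository) of registering and looking up an entrypoint with the registry above; `toy_net` is a hypothetical model function:

```python
import torch.nn as nn

from models.registry import register_model, list_models, model_entrypoint


@register_model
def toy_net(pretrained=False, **kwargs):
    # hypothetical entrypoint; the real ones live in models/dyn_perceiver_*.py
    return nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())


print(list_models('toy*'))                        # ['toy_net']
model = model_entrypoint('toy_net')(pretrained=False)
```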
/adaptive_inference.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import os
4 | import math
5 | import numpy as np
6 |
7 |
8 | def dynamic_evaluate(model, test_loader, val_loader, filename, args):
9 | tester = Tester(model)
10 | # if os.path.exists(os.path.join(args.output_dir, 'logits_single.pth')):
11 | # val_pred, val_target, test_pred, test_target = \
12 | # torch.load(os.path.join(args.output_dir, 'logits_single.pth'))
13 | # else:
14 | val_pred, val_target = tester.calc_logit(val_loader, early_break=True)
15 | test_pred, test_target = tester.calc_logit(test_loader, early_break=False)
16 | # torch.save((val_pred, val_target, test_pred, test_target),
17 | # os.path.join(args.output_dir, 'logits_single.pth'))
18 |
19 | # flops = torch.load(os.path.join(args.output_dir, 'flops.pth'))
20 | flops = np.loadtxt(f'{args.output_dir}/flops.txt')
21 | flops = [flops[i] for i in [0,1,2,3]]
22 | each_exit = False
23 | with open(os.path.join(args.output_dir, filename), 'w') as fout:
24 | probs_list = generate_distribution(each_exit=each_exit)
25 | for probs in probs_list:
26 | print('\n*****************')
27 | print(probs)
28 | acc_val, _, T = tester.dynamic_eval_find_threshold(val_pred, val_target, probs, flops)
29 | print(T)
30 | acc_test, exp_flops, acc_each_stage = tester.dynamic_eval_with_threshold(test_pred, test_target, flops, T)
31 | print('valid acc: {:.3f}, test acc: {:.3f}, test flops: {:.2f}M'.format(acc_val, acc_test, exp_flops))
32 | print('acc of each exit: {}'.format(acc_each_stage))
33 | fout.write('{}\t{}\n'.format(acc_test, exp_flops.item()))
34 | print('----------ALL DONE-----------')
35 |
36 |
37 | def generate_distribution(each_exit=False):
38 | probs_list = []
39 | if each_exit:
40 | for i in range(4):
41 | probs = torch.zeros(4, dtype=torch.float)
42 | probs[i] = 1
43 | probs_list.append(probs)
44 | else:
45 | p_list = torch.zeros(34)
46 | for i in range(17):
47 | p_list[i] = (i + 4) / 20
48 | p_list[33 - i] = 20 / (i + 4)
49 |
50 | # y_early3, y_att, y_cnn, y_merge
51 |         # corresponding scaling factors for each exit
52 | k = [0.85, 1, 0.5, 1]
53 | for i in range(33):
54 | probs = torch.exp(torch.log(p_list[i]) * torch.range(1, 4))
55 | probs /= probs.sum()
56 | for j in range(3):
57 | probs[j] *= k[j]
58 | probs[j+1:4] = (1 - probs[0:j+1].sum()) * probs[j+1:4] / probs[j+1:4].sum()
59 | probs_list.append(probs)
60 | return probs_list
61 |
62 |
63 | class Tester(object):
64 | def __init__(self, model):
65 | # self.args = args
66 | self.model = model
67 | self.softmax = nn.Softmax(dim=1).cuda()
68 |
69 | def calc_logit(self, dataloader, early_break=False):
70 | self.model.eval()
71 | n_stage = 4
72 | logits = [[] for _ in range(n_stage)]
73 | targets = []
74 | # print('xxxxxxxxxxx111111')
75 | # print(len(dataloader))
76 | for i, (input, target) in enumerate(dataloader):
77 | # print(input.shape, target.shape)
78 | if early_break and i > 100:
79 | break
80 | targets.append(target)
81 | input = input.cuda()
82 | with torch.no_grad():
83 | y_early3, y_att, y_cnn, y_merge = self.model(input)
84 | output = [y_early3, y_att, y_cnn, y_merge]
85 | for b in range(n_stage):
86 | _t = self.softmax(output[b])
87 |
88 | logits[b].append(_t)
89 | if i % 50 == 0:
90 | print('Generate Logit: [{0}/{1}]'.format(i, len(dataloader)))
91 |
92 | for b in range(n_stage):
93 | logits[b] = torch.cat(logits[b], dim=0)
94 |
95 | size = (n_stage, logits[0].size(0), logits[0].size(1))
96 | ts_logits = torch.Tensor().resize_(size).zero_()
97 | for b in range(n_stage):
98 | ts_logits[b].copy_(logits[b])
99 |
100 | targets = torch.cat(targets, dim=0)
101 | ts_targets = torch.Tensor().resize_(size[1]).copy_(targets)
102 |
103 | return ts_logits, ts_targets
104 |
105 | def dynamic_eval_find_threshold(self, logits, targets, p, flops):
106 | """
107 | logits: m * n * c
108 | m: Stages
109 | n: Samples
110 | c: Classes
111 | """
112 | n_stage, n_sample, c = logits.size()
113 |
114 | max_preds, argmax_preds = logits.max(dim=2, keepdim=False)
115 |
116 | _, sorted_idx = max_preds.sort(dim=1, descending=True)
117 |
118 | filtered = torch.zeros(n_sample)
119 | T = torch.Tensor(n_stage).fill_(1e8)
120 |
121 | for k in range(n_stage - 1):
122 | acc, count = 0.0, 0
123 | out_n = math.floor(n_sample * p[k])
124 | for i in range(n_sample):
125 | ori_idx = sorted_idx[k][i]
126 | if filtered[ori_idx] == 0:
127 | count += 1
128 | if count == out_n:
129 | T[k] = max_preds[k][ori_idx]
130 | break
131 | filtered.add_(max_preds[k].ge(T[k]).type_as(filtered))
132 |
133 | T[n_stage -1] = -1e8 # accept all of the samples at the last stage
134 |
135 | acc_rec, exp = torch.zeros(n_stage), torch.zeros(n_stage)
136 | acc, expected_flops = 0, 0
137 | for i in range(n_sample):
138 | gold_label = targets[i]
139 | for k in range(n_stage):
140 | if max_preds[k][i].item() >= T[k]: # force the sample to exit at k
141 | if int(gold_label.item()) == int(argmax_preds[k][i].item()):
142 | acc += 1
143 | acc_rec[k] += 1
144 | exp[k] += 1
145 | break
146 | acc_all = 0
147 | for k in range(n_stage):
148 | _t = 1.0 * exp[k] / n_sample
149 | expected_flops += _t * flops[k]
150 | acc_all += acc_rec[k]
151 |
152 | return acc * 100.0 / n_sample, expected_flops, T
153 |
154 | def dynamic_eval_with_threshold(self, logits, targets, flops, T):
155 | n_stage, n_sample, _ = logits.size()
156 | max_preds, argmax_preds = logits.max(dim=2, keepdim=False) # take the max logits as confidence
157 |
158 | acc_rec, exp = torch.zeros(n_stage), torch.zeros(n_stage)
159 | acc, expected_flops = 0, 0
160 | for i in range(n_sample):
161 | gold_label = targets[i]
162 | for k in range(n_stage):
163 | if max_preds[k][i].item() >= T[k]: # force to exit at k
164 | _g = int(gold_label.item())
165 | _pred = int(argmax_preds[k][i].item())
166 | if _g == _pred:
167 | acc += 1
168 | acc_rec[k] += 1
169 | exp[k] += 1
170 | break
171 | acc_all, sample_all = 0, 0
172 | for k in range(n_stage):
173 | _t = exp[k] * 1.0 / n_sample
174 | sample_all += exp[k]
175 | expected_flops += _t * flops[k]
176 | acc_all += acc_rec[k]
177 |
178 | return acc * 100.0 / n_sample, expected_flops, acc_rec / exp
179 |
180 |
181 | if __name__ == '__main__':
182 | print(generate_distribution(each_exit=False))
183 |
--------------------------------------------------------------------------------
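A toy sketch (not part of the repository) of the confidence-threshold rule that `dynamic_eval_with_threshold` applies per sample: exit at the first stage whose max softmax probability reaches that stage's threshold `T[k]`:

```python
import torch


def exit_stage(stage_probs, T):
    """stage_probs: (n_stage, n_classes) softmax outputs of one sample at each exit."""
    max_conf, preds = stage_probs.max(dim=1)
    for k in range(stage_probs.size(0)):
        if max_conf[k].item() >= T[k].item():
            return k, preds[k].item()           # confident enough: exit here
    return stage_probs.size(0) - 1, preds[-1].item()


probs = torch.softmax(torch.randn(4, 1000), dim=1)
T = torch.tensor([0.9, 0.8, 0.6, -1e8])         # the last stage accepts every remaining sample
print(exit_stage(probs, T))
```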
/optim_factory.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 |
3 | # All rights reserved.
4 |
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 |
9 | import torch
10 | from torch import optim as optim
11 |
12 | from timm.optim.adafactor import Adafactor
13 | from timm.optim.adahessian import Adahessian
14 | from timm.optim.adamp import AdamP
15 | from timm.optim.lookahead import Lookahead
16 | from timm.optim.nadam import Nadam
17 | from timm.optim.novograd import NovoGrad
18 | from timm.optim.nvnovograd import NvNovoGrad
19 | from timm.optim.radam import RAdam
20 | from timm.optim.rmsprop_tf import RMSpropTF
21 | from timm.optim.sgdp import SGDP
22 |
23 | import json
24 |
25 | try:
26 | from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
27 | has_apex = True
28 | except ImportError:
29 | has_apex = False
30 |
31 |
32 | def get_num_layer_for_convnext(var_name):
33 | """
34 | Divide [3, 3, 27, 3] layers into 12 groups; each group is three
35 | consecutive blocks, including possible neighboring downsample layers;
36 | adapted from https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py
37 | """
38 | num_max_layer = 12
39 | if var_name.startswith("downsample_layers"):
40 | stage_id = int(var_name.split('.')[1])
41 | if stage_id == 0:
42 | layer_id = 0
43 | elif stage_id == 1 or stage_id == 2:
44 | layer_id = stage_id + 1
45 | elif stage_id == 3:
46 | layer_id = 12
47 | return layer_id
48 |
49 | elif var_name.startswith("stages"):
50 | stage_id = int(var_name.split('.')[1])
51 | block_id = int(var_name.split('.')[2])
52 | if stage_id == 0 or stage_id == 1:
53 | layer_id = stage_id + 1
54 | elif stage_id == 2:
55 | layer_id = 3 + block_id // 3
56 | elif stage_id == 3:
57 | layer_id = 12
58 | return layer_id
59 | else:
60 | return num_max_layer + 1
61 |
62 | class LayerDecayValueAssigner(object):
63 | def __init__(self, values):
64 | self.values = values
65 |
66 | def get_scale(self, layer_id):
67 | return self.values[layer_id]
68 |
69 | def get_layer_id(self, var_name):
70 | return get_num_layer_for_convnext(var_name)
71 |
72 |
73 | def get_parameter_groups(model, weight_decay=1e-5, skip_list=(), get_num_layer=None, get_layer_scale=None):
74 | parameter_group_names = {}
75 | parameter_group_vars = {}
76 |
77 | for name, param in model.named_parameters():
78 | if not param.requires_grad:
79 | continue # frozen weights
80 | if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list:
81 | group_name = "no_decay"
82 | this_weight_decay = 0.
83 | else:
84 | group_name = "decay"
85 | this_weight_decay = weight_decay
86 | if get_num_layer is not None:
87 | layer_id = get_num_layer(name)
88 | group_name = "layer_%d_%s" % (layer_id, group_name)
89 | else:
90 | layer_id = None
91 |
92 | if group_name not in parameter_group_names:
93 | if get_layer_scale is not None:
94 | scale = get_layer_scale(layer_id)
95 | else:
96 | scale = 1.
97 |
98 | parameter_group_names[group_name] = {
99 | "weight_decay": this_weight_decay,
100 | "params": [],
101 | "lr_scale": scale
102 | }
103 | parameter_group_vars[group_name] = {
104 | "weight_decay": this_weight_decay,
105 | "params": [],
106 | "lr_scale": scale
107 | }
108 |
109 | parameter_group_vars[group_name]["params"].append(param)
110 | parameter_group_names[group_name]["params"].append(name)
111 | # print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
112 | return list(parameter_group_vars.values())
113 |
114 |
115 | def create_optimizer(args, model, get_num_layer=None, get_layer_scale=None, filter_bias_and_bn=True, skip_list=None):
116 | opt_lower = args.opt.lower()
117 | weight_decay = args.weight_decay
118 | # if weight_decay and filter_bias_and_bn:
119 | if filter_bias_and_bn:
120 | skip = {}
121 | if skip_list is not None:
122 | skip = skip_list
123 | elif hasattr(model, 'no_weight_decay'):
124 | skip = model.no_weight_decay()
125 | parameters = get_parameter_groups(model, weight_decay, skip, get_num_layer, get_layer_scale)
126 | weight_decay = 0.
127 | else:
128 | parameters = model.parameters()
129 |
130 | if 'fused' in opt_lower:
131 | assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'
132 |
133 | opt_args = dict(lr=args.lr, weight_decay=weight_decay)
134 | if hasattr(args, 'opt_eps') and args.opt_eps is not None:
135 | opt_args['eps'] = args.opt_eps
136 | if hasattr(args, 'opt_betas') and args.opt_betas is not None:
137 | opt_args['betas'] = args.opt_betas
138 |
139 | opt_split = opt_lower.split('_')
140 | opt_lower = opt_split[-1]
141 | if opt_lower == 'sgd' or opt_lower == 'nesterov':
142 | opt_args.pop('eps', None)
143 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
144 | elif opt_lower == 'momentum':
145 | opt_args.pop('eps', None)
146 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
147 | elif opt_lower == 'adam':
148 | optimizer = optim.Adam(parameters, **opt_args)
149 | elif opt_lower == 'adamw':
150 | optimizer = optim.AdamW(parameters, **opt_args)
151 | elif opt_lower == 'nadam':
152 | optimizer = Nadam(parameters, **opt_args)
153 | elif opt_lower == 'radam':
154 | optimizer = RAdam(parameters, **opt_args)
155 | elif opt_lower == 'adamp':
156 | optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)
157 | elif opt_lower == 'sgdp':
158 | optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True, **opt_args)
159 | elif opt_lower == 'adadelta':
160 | optimizer = optim.Adadelta(parameters, **opt_args)
161 | elif opt_lower == 'adafactor':
162 | if not args.lr:
163 | opt_args['lr'] = None
164 | optimizer = Adafactor(parameters, **opt_args)
165 | elif opt_lower == 'adahessian':
166 | optimizer = Adahessian(parameters, **opt_args)
167 | elif opt_lower == 'rmsprop':
168 | optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
169 | elif opt_lower == 'rmsproptf':
170 | optimizer = RMSpropTF(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
171 | elif opt_lower == 'novograd':
172 | optimizer = NovoGrad(parameters, **opt_args)
173 | elif opt_lower == 'nvnovograd':
174 | optimizer = NvNovoGrad(parameters, **opt_args)
175 | elif opt_lower == 'fusedsgd':
176 | opt_args.pop('eps', None)
177 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
178 | elif opt_lower == 'fusedmomentum':
179 | opt_args.pop('eps', None)
180 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
181 | elif opt_lower == 'fusedadam':
182 | optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
183 | elif opt_lower == 'fusedadamw':
184 | optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
185 | elif opt_lower == 'fusedlamb':
186 | optimizer = FusedLAMB(parameters, **opt_args)
187 | elif opt_lower == 'fusednovograd':
188 | opt_args.setdefault('betas', (0.95, 0.98))
189 | optimizer = FusedNovoGrad(parameters, **opt_args)
190 | else:
191 |         assert False, "Invalid optimizer"
192 |
193 | if len(opt_split) > 1:
194 | if opt_split[0] == 'lookahead':
195 | optimizer = Lookahead(optimizer)
196 |
197 | return optimizer
198 |
--------------------------------------------------------------------------------
/engine.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 |
3 | # All rights reserved.
4 |
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 |
9 | import math
10 | from typing import Iterable, Optional
11 | import torch
12 | from timm.data import Mixup
13 | from timm.utils import accuracy, ModelEma
14 |
15 | import utils
16 |
17 | import torch.nn.functional as F
18 |
19 | def train_one_epoch_earlyExit(model: torch.nn.Module, criterion: torch.nn.Module,
20 | data_loader: Iterable, optimizer: torch.optim.Optimizer,
21 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
22 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, log_writer=None,
23 | wandb_logger=None, start_steps=None, lr_schedule_values=None, wd_schedule_values=None,
24 | num_training_steps_per_epoch=None, update_freq=None, use_amp=False,
25 | loss_cnn_factor=0.25,
26 | loss_att_factor=0.25,
27 | loss_merge_factor=0.5,
28 |
29 | with_kd=False,
30 | T_kd=4.0,
31 | alpha_kd=0.5,
32 | criterion_distill=None):
33 | model.train(True)
34 | metric_logger = utils.MetricLogger(delimiter=" ")
35 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
36 | metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
37 |
38 | header = 'Epoch: [{}]'.format(epoch)
39 | print_freq = 10
40 |
41 | optimizer.zero_grad()
42 |
43 | for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
44 |
45 | # if data_iter_step >= 50:
46 | # break
47 | step = data_iter_step // update_freq
48 | if step >= num_training_steps_per_epoch:
49 | continue
50 | it = start_steps + step # global training iteration
51 |         # Update LR & WD for the first accumulation step
52 | if lr_schedule_values is not None or wd_schedule_values is not None and data_iter_step % update_freq == 0:
53 | for i, param_group in enumerate(optimizer.param_groups):
54 | if lr_schedule_values is not None:
55 | param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"]
56 | if wd_schedule_values is not None and param_group["weight_decay"] > 0:
57 | param_group["weight_decay"] = wd_schedule_values[it]
58 |
59 | samples = samples.to(device, non_blocking=True)
60 | targets = targets.to(device, non_blocking=True)
61 |
62 | if mixup_fn is not None:
63 | samples, targets = mixup_fn(samples, targets)
64 |
65 | if use_amp:
66 | with torch.cuda.amp.autocast():
67 | y_early3, y_att, y_cnn, y_merge = model(samples)
68 | loss_cnn, loss_att, loss_merge, loss_early3 = criterion(y_cnn, targets), criterion(y_att, targets), criterion(y_merge, targets) if criterion_distill is None else criterion_distill(samples, y_merge, targets), criterion(y_early3, targets)
69 | else: # full precision
70 | y_early3, y_att, y_cnn, y_merge, = model(samples)
71 | loss_cnn, loss_att, loss_merge, loss_early3 = criterion(y_cnn, targets), criterion(y_att, targets), criterion(y_merge, targets) if criterion_distill is None else criterion_distill(samples, y_merge, targets), criterion(y_early3, targets)
72 |
73 | loss = loss_cnn_factor*loss_cnn + loss_att_factor*(loss_att+loss_early3) + loss_merge_factor*loss_merge
74 |
75 | if with_kd:
76 | out_teacher = y_merge.detach()
77 |
78 | kd_loss = F.kl_div(F.log_softmax(y_early3/T_kd, dim=1),F.softmax(out_teacher/T_kd, dim=1), reduction='batchmean') * T_kd**2 + \
79 | F.kl_div(F.log_softmax(y_att/T_kd, dim=1),F.softmax(out_teacher/T_kd, dim=1), reduction='batchmean') * T_kd**2 + \
80 | F.kl_div(F.log_softmax(y_cnn/T_kd, dim=1),F.softmax(out_teacher/T_kd, dim=1), reduction='batchmean') * T_kd**2
81 |
82 | loss += alpha_kd * kd_loss
83 |
84 | loss_value = loss.item()
85 |
86 | if not math.isfinite(loss_value): # this could trigger if using AMP
87 | print("Loss is {}, stopping training".format(loss_value))
88 | assert math.isfinite(loss_value)
89 |
90 | if use_amp:
91 | # this attribute is added by timm on one optimizer (adahessian)
92 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
93 | loss /= update_freq
94 | grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm,
95 | parameters=model.parameters(), create_graph=is_second_order,
96 | update_grad=(data_iter_step + 1) % update_freq == 0)
97 | if (data_iter_step + 1) % update_freq == 0:
98 | optimizer.zero_grad()
99 | if model_ema is not None:
100 | model_ema.update(model)
101 | else: # full precision
102 | loss /= update_freq
103 | loss.backward()
104 | if (data_iter_step + 1) % update_freq == 0:
105 | optimizer.step()
106 | optimizer.zero_grad()
107 | if model_ema is not None:
108 | model_ema.update(model)
109 |
110 | torch.cuda.synchronize()
111 |
112 | if mixup_fn is None:
113 | class_acc_cnn = (y_cnn.max(-1)[-1] == targets).float().mean()
114 | class_acc_att = (y_att.max(-1)[-1] == targets).float().mean()
115 | class_acc_merge = (y_merge.max(-1)[-1] == targets).float().mean()
116 | else:
117 | class_acc_cnn, class_acc_att, class_acc_merge = None, None, None
118 | metric_logger.update(loss=loss_value)
119 | metric_logger.update(class_acc=class_acc_merge)
120 | min_lr = 10.
121 | max_lr = 0.
122 | for group in optimizer.param_groups:
123 | min_lr = min(min_lr, group["lr"])
124 | max_lr = max(max_lr, group["lr"])
125 |
126 | metric_logger.update(lr=max_lr)
127 | metric_logger.update(min_lr=min_lr)
128 | weight_decay_value = None
129 | for group in optimizer.param_groups:
130 | if group["weight_decay"] > 0:
131 | weight_decay_value = group["weight_decay"]
132 | metric_logger.update(weight_decay=weight_decay_value)
133 |
134 |
135 |
136 | if use_amp:
137 | metric_logger.update(grad_norm=grad_norm)
138 |
139 | # print(str(metric_logger))
140 | # assert(0==1)
141 |
142 | if log_writer is not None:
143 | log_writer.update(loss=loss_value, head="loss")
144 | # log_writer.update(class_acc_cnn=class_acc_cnn, head="loss")
145 | # log_writer.update(class_acc_att=class_acc_att, head="loss")
146 | log_writer.update(class_acc=class_acc_merge, head="loss")
147 | log_writer.update(lr=max_lr, head="opt")
148 | log_writer.update(min_lr=min_lr, head="opt")
149 | log_writer.update(weight_decay=weight_decay_value, head="opt")
150 | if use_amp:
151 | log_writer.update(grad_norm=grad_norm, head="opt")
152 | log_writer.set_step()
153 |
154 | if wandb_logger:
155 | wandb_logger._wandb.log({
156 | 'Rank-0 Batch Wise/train_loss': loss_value,
157 | 'Rank-0 Batch Wise/train_max_lr': max_lr,
158 | 'Rank-0 Batch Wise/train_min_lr': min_lr
159 | }, commit=False)
160 | if class_acc_cnn:
161 | wandb_logger._wandb.log({'Rank-0 Batch Wise/train_class_acc_cnn': class_acc_cnn}, commit=False)
162 | if class_acc_att:
163 | wandb_logger._wandb.log({'Rank-0 Batch Wise/train_class_acc_att': class_acc_att}, commit=False)
164 | if class_acc_merge:
165 | wandb_logger._wandb.log({'Rank-0 Batch Wise/train_class_acc_merge': class_acc_merge}, commit=False)
166 | if use_amp:
167 | wandb_logger._wandb.log({'Rank-0 Batch Wise/train_grad_norm': grad_norm}, commit=False)
168 | wandb_logger._wandb.log({'Rank-0 Batch Wise/global_train_step': it})
169 |
170 |
171 | # gather the stats from all processes
172 | metric_logger.synchronize_between_processes()
173 | print("Averaged stats:", metric_logger)
174 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
175 |
176 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
177 | data_loader: Iterable, optimizer: torch.optim.Optimizer,
178 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
179 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, log_writer=None,
180 | wandb_logger=None, start_steps=None, lr_schedule_values=None, wd_schedule_values=None,
181 | num_training_steps_per_epoch=None, update_freq=None, use_amp=False):
182 | model.train(True)
183 | metric_logger = utils.MetricLogger(delimiter=" ")
184 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
185 | metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
186 | header = 'Epoch: [{}]'.format(epoch)
187 | print_freq = 10
188 |
189 | optimizer.zero_grad()
190 |
191 | for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
192 | step = data_iter_step // update_freq
193 | if step >= num_training_steps_per_epoch:
194 | continue
195 | it = start_steps + step # global training iteration
196 |         # Update LR & WD for the first accumulation step
197 | if lr_schedule_values is not None or wd_schedule_values is not None and data_iter_step % update_freq == 0:
198 | for i, param_group in enumerate(optimizer.param_groups):
199 | if lr_schedule_values is not None:
200 | param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"]
201 | if wd_schedule_values is not None and param_group["weight_decay"] > 0:
202 | param_group["weight_decay"] = wd_schedule_values[it]
203 |
204 | samples = samples.to(device, non_blocking=True)
205 | targets = targets.to(device, non_blocking=True)
206 |
207 | if mixup_fn is not None:
208 | samples, targets = mixup_fn(samples, targets)
209 |
210 | if use_amp:
211 | with torch.cuda.amp.autocast():
212 | output = model(samples)
213 | loss = criterion(output, targets)
214 | else: # full precision
215 | output = model(samples)
216 | loss = criterion(output, targets)
217 |
218 | loss_value = loss.item()
219 |
220 | if not math.isfinite(loss_value): # this could trigger if using AMP
221 | print("Loss is {}, stopping training".format(loss_value))
222 | assert math.isfinite(loss_value)
223 |
224 | if use_amp:
225 | # this attribute is added by timm on one optimizer (adahessian)
226 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
227 | loss /= update_freq
228 | grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm,
229 | parameters=model.parameters(), create_graph=is_second_order,
230 | update_grad=(data_iter_step + 1) % update_freq == 0)
231 | if (data_iter_step + 1) % update_freq == 0:
232 | optimizer.zero_grad()
233 | if model_ema is not None:
234 | model_ema.update(model)
235 | else: # full precision
236 | loss /= update_freq
237 | loss.backward()
238 | if (data_iter_step + 1) % update_freq == 0:
239 | optimizer.step()
240 | optimizer.zero_grad()
241 | if model_ema is not None:
242 | model_ema.update(model)
243 |
244 | torch.cuda.synchronize()
245 |
246 | if mixup_fn is None:
247 | class_acc = (output.max(-1)[-1] == targets).float().mean()
248 | else:
249 | class_acc = None
250 | metric_logger.update(loss=loss_value)
251 | metric_logger.update(class_acc=class_acc)
252 | min_lr = 10.
253 | max_lr = 0.
254 | for group in optimizer.param_groups:
255 | min_lr = min(min_lr, group["lr"])
256 | max_lr = max(max_lr, group["lr"])
257 |
258 | metric_logger.update(lr=max_lr)
259 | metric_logger.update(min_lr=min_lr)
260 | weight_decay_value = None
261 | for group in optimizer.param_groups:
262 | if group["weight_decay"] > 0:
263 | weight_decay_value = group["weight_decay"]
264 | metric_logger.update(weight_decay=weight_decay_value)
265 | if use_amp:
266 | metric_logger.update(grad_norm=grad_norm)
267 |
268 | if log_writer is not None:
269 | log_writer.update(loss=loss_value, head="loss")
270 | log_writer.update(class_acc=class_acc, head="loss")
271 | log_writer.update(lr=max_lr, head="opt")
272 | log_writer.update(min_lr=min_lr, head="opt")
273 | log_writer.update(weight_decay=weight_decay_value, head="opt")
274 | if use_amp:
275 | log_writer.update(grad_norm=grad_norm, head="opt")
276 | log_writer.set_step()
277 |
278 | if wandb_logger:
279 | wandb_logger._wandb.log({
280 | 'Rank-0 Batch Wise/train_loss': loss_value,
281 | 'Rank-0 Batch Wise/train_max_lr': max_lr,
282 | 'Rank-0 Batch Wise/train_min_lr': min_lr
283 | }, commit=False)
284 | if class_acc:
285 | wandb_logger._wandb.log({'Rank-0 Batch Wise/train_class_acc': class_acc}, commit=False)
286 | if use_amp:
287 | wandb_logger._wandb.log({'Rank-0 Batch Wise/train_grad_norm': grad_norm}, commit=False)
288 | wandb_logger._wandb.log({'Rank-0 Batch Wise/global_train_step': it})
289 |
290 |
291 | # gather the stats from all processes
292 | metric_logger.synchronize_between_processes()
293 | print("Averaged stats:", metric_logger)
294 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
295 |
296 | @torch.no_grad()
297 | def evaluate(data_loader, model, device, use_amp=False):
298 | criterion = torch.nn.CrossEntropyLoss()
299 |
300 | metric_logger = utils.MetricLogger(delimiter=" ")
301 | header = 'Test:'
302 |
303 | # switch to evaluation mode
304 | model.eval()
305 | for batch in metric_logger.log_every(data_loader, 10, header):
306 | images = batch[0]
307 | target = batch[-1]
308 |
309 | images = images.to(device, non_blocking=True)
310 | target = target.to(device, non_blocking=True)
311 |
312 | # compute output
313 | if use_amp:
314 | with torch.cuda.amp.autocast():
315 | output = model(images)
316 | loss = criterion(output, target)
317 | else:
318 | output = model(images)
319 | loss = criterion(output, target)
320 |
321 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
322 |
323 | batch_size = images.shape[0]
324 | metric_logger.update(loss=loss.item())
325 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
326 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
327 | # gather the stats from all processes
328 | metric_logger.synchronize_between_processes()
329 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
330 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
331 |
332 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
333 |
334 | @torch.no_grad()
335 | def evaluate_earlyExit(data_loader, model, device, use_amp=False,
336 | loss_cnn_factor=0.25,
337 | loss_att_factor=0.25,
338 | loss_merge_factor=0.5):
339 | criterion = torch.nn.CrossEntropyLoss()
340 |
341 | metric_logger = utils.MetricLogger(delimiter=" ")
342 | header = 'Test:'
343 |
344 | # switch to evaluation mode
345 | model.eval()
346 | for batch in metric_logger.log_every(data_loader, 10, header):
347 | images = batch[0]
348 | target = batch[-1]
349 |
350 | images = images.to(device, non_blocking=True)
351 | target = target.to(device, non_blocking=True)
352 |
353 | # compute output
354 | if use_amp:
355 | with torch.cuda.amp.autocast():
356 |                 y_early3, y_att, y_cnn, y_merge = model(images)
357 | loss_cnn, loss_att, loss_merge, loss_early3 = criterion(y_cnn, target), criterion(y_att, target), criterion(y_merge, target), criterion(y_early3, target)
358 | else:
359 |             y_early3, y_att, y_cnn, y_merge = model(images)
360 | loss_cnn, loss_att, loss_merge, loss_early3 = criterion(y_cnn, target), criterion(y_att, target), criterion(y_merge, target), criterion(y_early3, target)
361 |
362 | loss = loss_cnn_factor*loss_cnn + loss_att_factor*(loss_att+loss_early3) + loss_merge_factor*loss_merge
363 |
364 | acc1_cnn, acc5_cnn = accuracy(y_cnn, target, topk=(1, 5))
365 | acc1_att, acc5_att = accuracy(y_att, target, topk=(1, 5))
366 | acc1_early3, acc5_early3 = accuracy(y_early3, target, topk=(1, 5))
367 | acc1_merge, acc5_merge = accuracy(y_merge, target, topk=(1, 5))
368 |
369 | batch_size = images.shape[0]
370 | metric_logger.update(loss=loss.item())
371 | metric_logger.meters['acc1_cnn'].update(acc1_cnn.item(), n=batch_size)
372 | metric_logger.meters['acc1_att'].update(acc1_att.item(), n=batch_size)
373 | metric_logger.meters['acc1_early3'].update(acc1_early3.item(), n=batch_size)
374 | metric_logger.meters['acc1_merge'].update(acc1_merge.item(), n=batch_size)
375 | # gather the stats from all processes
376 | metric_logger.synchronize_between_processes()
377 |     print('* Acc@1_early3 {top1_early3.global_avg:.3f} '
378 |           'Acc@1_merge {top1_merge.global_avg:.3f} '
379 |           'Acc@1_att {top1_att.global_avg:.3f} '
380 |           'Acc@1_cnn {top1_cnn.global_avg:.3f} '
381 |           'loss {losses.global_avg:.3f}'
382 | .format(top1_cnn=metric_logger.acc1_cnn,
383 | top1_att=metric_logger.acc1_att,
384 | top1_early3=metric_logger.acc1_early3,
385 | top1_merge=metric_logger.acc1_merge,
386 | losses=metric_logger.loss))
387 |
388 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
389 |
--------------------------------------------------------------------------------
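
The `evaluate_earlyExit` routine above returns one top-1 meter per classifier head plus a validation loss that weights the cnn, attention (plus early-exit) and merge heads by its `loss_*_factor` arguments (0.25 / 0.25 / 0.5 by default). A minimal, self-contained usage sketch; `FourHeadToy`, the random tensors and the class count are illustrative stand-ins, not parts of this repo:

import torch
from torch.utils.data import DataLoader, TensorDataset
from engine import evaluate_earlyExit

class FourHeadToy(torch.nn.Module):
    # Illustrative stand-in: any model whose forward returns the four logits
    # (y_early3, y_att, y_cnn, y_merge) in this order can be evaluated this way.
    def __init__(self, num_classes=10):
        super().__init__()
        self.fc = torch.nn.Linear(3 * 32 * 32, num_classes)
    def forward(self, x):
        y = self.fc(x.flatten(1))
        return y, y, y, y

data = TensorDataset(torch.randn(64, 3, 32, 32), torch.randint(0, 10, (64,)))
loader = DataLoader(data, batch_size=16)

stats = evaluate_earlyExit(loader, FourHeadToy(), device=torch.device('cpu'))
# Each meter's global average is returned under its meter name.
print('top-1 early3 / att / cnn / merge:',
      stats['acc1_early3'], stats['acc1_att'],
      stats['acc1_cnn'], stats['acc1_merge'])
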
/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 |
3 | # All rights reserved.
4 |
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 |
9 | import os
10 | import math
11 | import time
12 | from collections import defaultdict, deque
13 | import datetime
14 | import numpy as np
15 | from timm.utils import get_state_dict
16 |
17 | from pathlib import Path
18 |
19 | import torch
20 | import torch.distributed as dist
21 | from math import inf  # torch._six (removed in recent PyTorch releases) re-exported math.inf
22 |
23 | from tensorboardX import SummaryWriter
24 |
25 | class SmoothedValue(object):
26 | """Track a series of values and provide access to smoothed values over a
27 | window or the global series average.
28 | """
29 |
30 | def __init__(self, window_size=20, fmt=None):
31 | if fmt is None:
32 | fmt = "{median:.4f} ({global_avg:.4f})"
33 | self.deque = deque(maxlen=window_size)
34 | self.total = 0.0
35 | self.count = 0
36 | self.fmt = fmt
37 |
38 | def update(self, value, n=1):
39 | self.deque.append(value)
40 | self.count += n
41 | self.total += value * n
42 |
43 | def synchronize_between_processes(self):
44 | """
45 | Warning: does not synchronize the deque!
46 | """
47 | if not is_dist_avail_and_initialized():
48 | return
49 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
50 | dist.barrier()
51 | dist.all_reduce(t)
52 | t = t.tolist()
53 | self.count = int(t[0])
54 | self.total = t[1]
55 |
56 | @property
57 | def median(self):
58 | d = torch.tensor(list(self.deque))
59 | return d.median().item()
60 |
61 | @property
62 | def avg(self):
63 | d = torch.tensor(list(self.deque), dtype=torch.float32)
64 | return d.mean().item()
65 |
66 | @property
67 | def global_avg(self):
68 | return self.total / self.count
69 |
70 | @property
71 | def max(self):
72 | return max(self.deque)
73 |
74 | @property
75 | def value(self):
76 | return self.deque[-1]
77 |
78 | def __str__(self):
79 | return self.fmt.format(
80 | median=self.median,
81 | avg=self.avg,
82 | global_avg=self.global_avg,
83 | max=self.max,
84 | value=self.value)
85 |
86 |
87 | class MetricLogger(object):
88 | def __init__(self, delimiter="\t"):
89 | self.meters = defaultdict(SmoothedValue)
90 | self.delimiter = delimiter
91 |
92 | def update(self, **kwargs):
93 | for k, v in kwargs.items():
94 | if v is None:
95 | continue
96 | if isinstance(v, torch.Tensor):
97 | v = v.item()
98 | assert isinstance(v, (float, int))
99 | self.meters[k].update(v)
100 |
101 | def __getattr__(self, attr):
102 | if attr in self.meters:
103 | return self.meters[attr]
104 | if attr in self.__dict__:
105 | return self.__dict__[attr]
106 | raise AttributeError("'{}' object has no attribute '{}'".format(
107 | type(self).__name__, attr))
108 |
109 | def __str__(self):
110 | loss_str = []
111 | for name, meter in self.meters.items():
112 | loss_str.append(
113 | "{}: {}".format(name, str(meter))
114 | )
115 | return self.delimiter.join(loss_str)
116 |
117 | def synchronize_between_processes(self):
118 | for meter in self.meters.values():
119 | meter.synchronize_between_processes()
120 |
121 | def add_meter(self, name, meter):
122 | self.meters[name] = meter
123 |
124 | def log_every(self, iterable, print_freq, header=None):
125 | i = 0
126 | if not header:
127 | header = ''
128 | start_time = time.time()
129 | end = time.time()
130 | iter_time = SmoothedValue(fmt='{avg:.4f}')
131 | data_time = SmoothedValue(fmt='{avg:.4f}')
132 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
133 | log_msg = [
134 | header,
135 | '[{0' + space_fmt + '}/{1}]',
136 | 'eta: {eta}',
137 | '{meters}',
138 | 'time: {time}',
139 | 'data: {data}'
140 | ]
141 | if torch.cuda.is_available():
142 | log_msg.append('max mem: {memory:.0f}')
143 | log_msg = self.delimiter.join(log_msg)
144 | MB = 1024.0 * 1024.0
145 | for obj in iterable:
146 | data_time.update(time.time() - end)
147 | yield obj
148 | iter_time.update(time.time() - end)
149 | if i % print_freq == 0 or i == len(iterable) - 1:
150 | eta_seconds = iter_time.global_avg * (len(iterable) - i)
151 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
152 | if torch.cuda.is_available():
153 | print(log_msg.format(
154 | i, len(iterable), eta=eta_string,
155 | meters=str(self),
156 | time=str(iter_time), data=str(data_time),
157 | memory=torch.cuda.max_memory_allocated() / MB))
158 | else:
159 | print(log_msg.format(
160 | i, len(iterable), eta=eta_string,
161 | meters=str(self),
162 | time=str(iter_time), data=str(data_time)))
163 | i += 1
164 | end = time.time()
165 | total_time = time.time() - start_time
166 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
167 | print('{} Total time: {} ({:.4f} s / it)'.format(
168 | header, total_time_str, total_time / len(iterable)))
169 |
170 |
171 | class TensorboardLogger(object):
172 | def __init__(self, log_dir):
173 | self.writer = SummaryWriter(logdir=log_dir)
174 | self.step = 0
175 |
176 | def set_step(self, step=None):
177 | if step is not None:
178 | self.step = step
179 | else:
180 | self.step += 1
181 |
182 | def update(self, head='scalar', step=None, **kwargs):
183 | for k, v in kwargs.items():
184 | if v is None:
185 | continue
186 | if isinstance(v, torch.Tensor):
187 | v = v.item()
188 | assert isinstance(v, (float, int))
189 | self.writer.add_scalar(head + "/" + k, v, self.step if step is None else step)
190 |
191 | def flush(self):
192 | self.writer.flush()
193 |
194 |
195 | class WandbLogger(object):
196 | def __init__(self, args):
197 | self.args = args
198 |
199 | try:
200 | import wandb
201 | self._wandb = wandb
202 | except ImportError:
203 | raise ImportError(
204 | "To use the Weights and Biases Logger please install wandb."
205 | "Run `pip install wandb` to install it."
206 | )
207 |
208 | # Initialize a W&B run
209 | if self._wandb.run is None:
210 | self._wandb.init(
211 | project=args.project,
212 | config=args
213 | )
214 |
215 | def log_epoch_metrics(self, metrics, commit=True):
216 | """
217 | Log train/test metrics onto W&B.
218 | """
219 | # Log number of model parameters as W&B summary
220 | self._wandb.summary['n_parameters'] = metrics.get('n_parameters', None)
221 | metrics.pop('n_parameters', None)
222 |
223 | # Log current epoch
224 | self._wandb.log({'epoch': metrics.get('epoch')}, commit=False)
225 |         metrics.pop('epoch', None)
226 |
227 | for k, v in metrics.items():
228 | if 'train' in k:
229 | self._wandb.log({f'Global Train/{k}': v}, commit=False)
230 | elif 'test' in k:
231 | self._wandb.log({f'Global Test/{k}': v}, commit=False)
232 |
233 | self._wandb.log({})
234 |
235 | def log_checkpoints(self):
236 | output_dir = self.args.output_dir
237 | model_artifact = self._wandb.Artifact(
238 | self._wandb.run.id + "_model", type="model"
239 | )
240 |
241 | model_artifact.add_dir(output_dir)
242 | self._wandb.log_artifact(model_artifact, aliases=["latest", "best"])
243 |
244 | def set_steps(self):
245 | # Set global training step
246 | self._wandb.define_metric('Rank-0 Batch Wise/*', step_metric='Rank-0 Batch Wise/global_train_step')
247 | # Set epoch-wise step
248 | self._wandb.define_metric('Global Train/*', step_metric='epoch')
249 | self._wandb.define_metric('Global Test/*', step_metric='epoch')
250 |
251 |
252 | def setup_for_distributed(is_master):
253 | """
254 | This function disables printing when not in master process
255 | """
256 | import builtins as __builtin__
257 | builtin_print = __builtin__.print
258 |
259 | def print(*args, **kwargs):
260 | force = kwargs.pop('force', False)
261 | if is_master or force:
262 | builtin_print(*args, **kwargs)
263 |
264 | __builtin__.print = print
265 |
266 |
267 | def is_dist_avail_and_initialized():
268 | if not dist.is_available():
269 | return False
270 | if not dist.is_initialized():
271 | return False
272 | return True
273 |
274 |
275 | def get_world_size():
276 | if not is_dist_avail_and_initialized():
277 | return 1
278 | return dist.get_world_size()
279 |
280 |
281 | def get_rank():
282 | if not is_dist_avail_and_initialized():
283 | return 0
284 | return dist.get_rank()
285 |
286 |
287 | def is_main_process():
288 | return get_rank() == 0
289 |
290 |
291 | def save_on_master(*args, **kwargs):
292 | if is_main_process():
293 | torch.save(*args, **kwargs)
294 |
295 |
296 | def init_distributed_mode(args):
297 |
298 | if args.dist_on_itp:
299 | args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
300 | args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
301 | args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
302 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
303 | os.environ['LOCAL_RANK'] = str(args.gpu)
304 | os.environ['RANK'] = str(args.rank)
305 | os.environ['WORLD_SIZE'] = str(args.world_size)
306 | # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
307 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
308 | args.rank = int(os.environ["RANK"])
309 | args.world_size = int(os.environ['WORLD_SIZE'])
310 | args.gpu = int(os.environ['LOCAL_RANK'])
311 | elif 'SLURM_PROCID' in os.environ:
312 | args.rank = int(os.environ['SLURM_PROCID'])
313 | args.gpu = args.rank % torch.cuda.device_count()
314 |
315 | os.environ['RANK'] = str(args.rank)
316 | os.environ['LOCAL_RANK'] = str(args.gpu)
317 | os.environ['WORLD_SIZE'] = str(args.world_size)
318 | else:
319 | print('Not using distributed mode')
320 | args.distributed = False
321 | return
322 |
323 | args.distributed = True
324 |
325 | torch.cuda.set_device(args.gpu)
326 | args.dist_backend = 'nccl'
327 | print('| distributed init (rank {}): {}, gpu {}'.format(
328 | args.rank, args.dist_url, args.gpu), flush=True)
329 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
330 | world_size=args.world_size, rank=args.rank)
331 | torch.distributed.barrier()
332 | setup_for_distributed(args.rank == 0)
333 |
334 |
335 | def load_state_dict(model, state_dict, prefix='', ignore_missing="relative_position_index"):
336 | missing_keys = []
337 | unexpected_keys = []
338 | error_msgs = []
339 | # copy state_dict so _load_from_state_dict can modify it
340 | metadata = getattr(state_dict, '_metadata', None)
341 | state_dict = state_dict.copy()
342 | if metadata is not None:
343 | state_dict._metadata = metadata
344 |
345 | def load(module, prefix=''):
346 | local_metadata = {} if metadata is None else metadata.get(
347 | prefix[:-1], {})
348 | module._load_from_state_dict(
349 | state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
350 | for name, child in module._modules.items():
351 | if child is not None:
352 | load(child, prefix + name + '.')
353 |
354 | load(model, prefix=prefix)
355 |
356 | warn_missing_keys = []
357 | ignore_missing_keys = []
358 | for key in missing_keys:
359 | keep_flag = True
360 | for ignore_key in ignore_missing.split('|'):
361 | if ignore_key in key:
362 | keep_flag = False
363 | break
364 | if keep_flag:
365 | warn_missing_keys.append(key)
366 | else:
367 | ignore_missing_keys.append(key)
368 |
369 | missing_keys = warn_missing_keys
370 |
371 | if len(missing_keys) > 0:
372 | print("Weights of {} not initialized from pretrained model: {}".format(
373 | model.__class__.__name__, missing_keys))
374 | if len(unexpected_keys) > 0:
375 | print("Weights from pretrained model not used in {}: {}".format(
376 | model.__class__.__name__, unexpected_keys))
377 | if len(ignore_missing_keys) > 0:
378 | print("Ignored weights of {} not initialized from pretrained model: {}".format(
379 | model.__class__.__name__, ignore_missing_keys))
380 | if len(error_msgs) > 0:
381 | print('\n'.join(error_msgs))
382 |
383 |
384 | class NativeScalerWithGradNormCount:
385 | state_dict_key = "amp_scaler"
386 |
387 | def __init__(self):
388 | self._scaler = torch.cuda.amp.GradScaler()
389 |
390 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
391 | self._scaler.scale(loss).backward(create_graph=create_graph)
392 | if update_grad:
393 | if clip_grad is not None:
394 | assert parameters is not None
395 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
396 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
397 | else:
398 | self._scaler.unscale_(optimizer)
399 | norm = get_grad_norm_(parameters)
400 | self._scaler.step(optimizer)
401 | self._scaler.update()
402 | else:
403 | norm = None
404 | return norm
405 |
406 | def state_dict(self):
407 | return self._scaler.state_dict()
408 |
409 | def load_state_dict(self, state_dict):
410 | self._scaler.load_state_dict(state_dict)
411 |
412 |
413 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
414 | if isinstance(parameters, torch.Tensor):
415 | parameters = [parameters]
416 | parameters = [p for p in parameters if p.grad is not None]
417 | norm_type = float(norm_type)
418 | if len(parameters) == 0:
419 | return torch.tensor(0.)
420 | device = parameters[0].grad.device
421 | if norm_type == inf:
422 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
423 | else:
424 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
425 | return total_norm
426 |
427 |
428 | def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0,
429 | start_warmup_value=0, warmup_steps=-1):
430 | warmup_schedule = np.array([])
431 | warmup_iters = warmup_epochs * niter_per_ep
432 | if warmup_steps > 0:
433 | warmup_iters = warmup_steps
434 | print("Set warmup steps = %d" % warmup_iters)
435 |     if warmup_iters > 0:  # build warmup whether it was given in epochs or in steps
436 | warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
437 |
438 | iters = np.arange(epochs * niter_per_ep - warmup_iters)
439 | schedule = np.array(
440 | [final_value + 0.5 * (base_value - final_value) * (1 + math.cos(math.pi * i / (len(iters)))) for i in iters])
441 |
442 | schedule = np.concatenate((warmup_schedule, schedule))
443 |
444 | assert len(schedule) == epochs * niter_per_ep
445 | return schedule
446 |
447 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, model_ema=None):
448 | output_dir = Path(args.output_dir)
449 | epoch_name = str(epoch)
450 | checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)]
451 | for checkpoint_path in checkpoint_paths:
452 | to_save = {
453 | 'model': model_without_ddp.state_dict(),
454 | 'optimizer': optimizer.state_dict(),
455 | 'epoch': epoch,
456 | 'scaler': loss_scaler.state_dict(),
457 | 'args': args,
458 | }
459 |
460 | if model_ema is not None:
461 | to_save['model_ema'] = get_state_dict(model_ema)
462 |
463 | save_on_master(to_save, checkpoint_path)
464 |
465 | if is_main_process() and isinstance(epoch, int):
466 | to_del = epoch - args.save_ckpt_num * args.save_ckpt_freq
467 | old_ckpt = output_dir / ('checkpoint-%s.pth' % to_del)
468 | if os.path.exists(old_ckpt):
469 | os.remove(old_ckpt)
470 |
471 |
472 | def auto_load_model(args, model, model_without_ddp, optimizer, loss_scaler, model_ema=None):
473 | output_dir = Path(args.output_dir)
474 | if args.auto_resume and len(args.resume) == 0:
475 | import glob
476 | all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*.pth'))
477 | latest_ckpt = -1
478 | for ckpt in all_checkpoints:
479 | t = ckpt.split('-')[-1].split('.')[0]
480 | if t.isdigit():
481 | latest_ckpt = max(int(t), latest_ckpt)
482 | if latest_ckpt >= 0:
483 | args.resume = os.path.join(output_dir, 'checkpoint-%d.pth' % latest_ckpt)
484 | print("Auto resume checkpoint: %s" % args.resume)
485 |
486 | if args.resume:
487 | if args.resume.startswith('https'):
488 | checkpoint = torch.hub.load_state_dict_from_url(
489 | args.resume, map_location='cpu', check_hash=True)
490 | else:
491 | checkpoint = torch.load(args.resume, map_location='cpu')
492 | if 'model' in checkpoint:
493 | model_without_ddp.load_state_dict(checkpoint['model'])
494 | elif 'state_dict' in checkpoint:
495 | model_without_ddp.load_state_dict(checkpoint['state_dict'])
496 | else:
497 | model_without_ddp.load_state_dict(checkpoint)
498 | print("Resume checkpoint %s" % args.resume)
499 | if 'optimizer' in checkpoint and 'epoch' in checkpoint:
500 | optimizer.load_state_dict(checkpoint['optimizer'])
501 | if not isinstance(checkpoint['epoch'], str): # does not support resuming with 'best', 'best-ema'
502 | args.start_epoch = checkpoint['epoch'] + 1
503 | else:
504 | assert args.eval, 'Does not support resuming with checkpoint-best'
505 | if hasattr(args, 'model_ema') and args.model_ema:
506 | if 'model_ema' in checkpoint.keys():
507 | model_ema.ema.load_state_dict(checkpoint['model_ema'])
508 | else:
509 | model_ema.ema.load_state_dict(checkpoint['model'])
510 | if 'scaler' in checkpoint:
511 | loss_scaler.load_state_dict(checkpoint['scaler'])
512 | print("With optim & sched!")
513 |
--------------------------------------------------------------------------------
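
`cosine_scheduler` in the file above returns one value per optimizer update, `epochs * niter_per_ep` entries with an optional linear warmup prepended, so a training loop indexes it by the global iteration counter. A minimal sketch of building and applying such a schedule; the concrete numbers and the stand-in optimizer are illustrative assumptions, not values from this repo:

import torch
import utils

# Linear warmup to 4e-3 over 20 epochs, then cosine decay towards 1e-6.
lr_schedule = utils.cosine_scheduler(base_value=4e-3, final_value=1e-6,
                                     epochs=300, niter_per_ep=1000,
                                     warmup_epochs=20)
assert len(lr_schedule) == 300 * 1000

# Stand-in optimizer so the sketch runs on its own.
optimizer = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=4e-3)

it = 20 * 1000  # global update index, e.g. the first step after warmup
for param_group in optimizer.param_groups:
    # 'lr_scale' is only present when layer-wise LR decay assigns per-group scales.
    param_group['lr'] = lr_schedule[it] * param_group.get('lr_scale', 1.0)
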
/models/cnn_core/mobilenet_v3.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from functools import partial
3 | from typing import Any, Callable, List, Optional, Sequence
4 |
5 | import torch
6 | from torch import nn, Tensor
7 |
8 | from torchvision._internally_replaced_utils import load_state_dict_from_url
9 | from torchvision.ops.misc import ConvNormActivation, SqueezeExcitation as SElayer
10 | # from torchvision.utils import _log_api_usage_once
11 | from torchvision.models._utils import _make_divisible
12 |
13 |
14 | # __all__ = ["MobileNetV3", "mobilenet_v3_large", "mobilenet_v3_small"]
15 |
16 |
17 | model_urls = {
18 | "mobilenet_v3_large": "https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth",
19 | "mobilenet_v3_small": "https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth",
20 | }
21 |
22 |
23 | class SqueezeExcitation(SElayer):
24 | """DEPRECATED"""
25 |
26 | def __init__(self, input_channels: int, squeeze_factor: int = 4):
27 | squeeze_channels = _make_divisible(input_channels // squeeze_factor, 8)
28 | super().__init__(input_channels, squeeze_channels, scale_activation=nn.Hardsigmoid)
29 | self.relu = self.activation
30 | delattr(self, "activation")
31 | warnings.warn(
32 | "This SqueezeExcitation class is deprecated since 0.12 and will be removed in 0.14. "
33 | "Use torchvision.ops.SqueezeExcitation instead.",
34 | FutureWarning,
35 | )
36 |
37 |
38 | class InvertedResidualConfig:
39 | # Stores information listed at Tables 1 and 2 of the MobileNetV3 paper
40 | def __init__(
41 | self,
42 | input_channels: int,
43 | kernel: int,
44 | expanded_channels: int,
45 | out_channels: int,
46 | use_se: bool,
47 | activation: str,
48 | stride: int,
49 | dilation: int,
50 | width_mult: float,
51 | ):
52 | self.input_channels = self.adjust_channels(input_channels, width_mult)
53 | self.kernel = kernel
54 | self.expanded_channels = self.adjust_channels(expanded_channels, width_mult)
55 | self.out_channels = self.adjust_channels(out_channels, width_mult)
56 | self.use_se = use_se
57 | self.use_hs = activation == "HS"
58 | self.stride = stride
59 | self.dilation = dilation
60 |
61 | @staticmethod
62 | def adjust_channels(channels: int, width_mult: float):
63 | return _make_divisible(channels * width_mult, 8)
64 |
65 |
66 | class InvertedResidual(nn.Module):
67 | # Implemented as described at section 5 of MobileNetV3 paper
68 | def __init__(
69 | self,
70 | cnf: InvertedResidualConfig,
71 | norm_layer: Callable[..., nn.Module],
72 | se_layer: Callable[..., nn.Module] = partial(SElayer, scale_activation=nn.Hardsigmoid),
73 | ):
74 | super().__init__()
75 | if not (1 <= cnf.stride <= 2):
76 | raise ValueError("illegal stride value")
77 |
78 | self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels
79 |
80 | layers: List[nn.Module] = []
81 | activation_layer = nn.Hardswish if cnf.use_hs else nn.ReLU
82 |
83 | self.c_in = cnf.input_channels
84 | self.c_out = cnf.out_channels
85 | # expand
86 | self.has_conv1 = cnf.expanded_channels != cnf.input_channels
87 | if cnf.expanded_channels != cnf.input_channels:
88 | layers.append(
89 | ConvNormActivation(
90 | cnf.input_channels,
91 | cnf.expanded_channels,
92 | kernel_size=1,
93 | norm_layer=norm_layer,
94 | activation_layer=activation_layer,
95 | )
96 | )
97 | self.conv1_flops_per_pixel = cnf.input_channels * cnf.expanded_channels
98 |
99 | # depthwise
100 | stride = 1 if cnf.dilation > 1 else cnf.stride
101 | layers.append(
102 | ConvNormActivation(
103 | cnf.expanded_channels,
104 | cnf.expanded_channels,
105 | kernel_size=cnf.kernel,
106 | stride=stride,
107 | dilation=cnf.dilation,
108 | groups=cnf.expanded_channels,
109 | norm_layer=norm_layer,
110 | activation_layer=activation_layer,
111 | )
112 | )
113 | self.conv2_flops_per_pixel = cnf.expanded_channels * cnf.kernel**2
114 | if cnf.use_se:
115 | squeeze_channels = _make_divisible(cnf.expanded_channels // 4, 8)
116 | layers.append(se_layer(cnf.expanded_channels, squeeze_channels))
117 | self.se_flops = cnf.expanded_channels * squeeze_channels * 2
118 |
119 | # project
120 | layers.append(
121 | ConvNormActivation(
122 | cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=None
123 | )
124 | )
125 | self.conv3_flops_per_pixel = cnf.expanded_channels * cnf.out_channels
126 |
127 | self.block = nn.ModuleList(layers)
128 | self.out_channels = cnf.out_channels
129 | self._is_cn = cnf.stride > 1
130 |
131 | def forward(self, input: Tensor) -> Tensor:
132 | residual = input
133 | result = input
134 | for i in range(len(self.block)):
135 | result = self.block[i](result)
136 | if self.use_res_connect:
137 | result += residual
138 | return result
139 |
140 | def forward_calc_flops(self, input: Tensor) -> Tensor:
141 |
142 | # print(len(self.block))
143 | # if len(self.block) == 4:
144 | # print(self.block)
145 | input, flops = input
146 | if len(self.block) == 2:
147 | result = self.block[0](input)
148 | flops += self.conv2_flops_per_pixel * result.shape[2] * result.shape[3]
149 |
150 | result = self.block[1](result)
151 | flops += self.conv3_flops_per_pixel * result.shape[2] * result.shape[3]
152 |
153 | elif len(self.block) == 3:
154 | result = self.block[0](input)
155 | flops += self.conv1_flops_per_pixel * result.shape[2] * result.shape[3]
156 |
157 | result = self.block[1](result)
158 | flops += self.conv2_flops_per_pixel * result.shape[2] * result.shape[3]
159 |
160 | result = self.block[2](result)
161 | flops += self.conv3_flops_per_pixel * result.shape[2] * result.shape[3]
162 |
163 | else:
164 | result = self.block[0](input)
165 | flops += self.conv1_flops_per_pixel * result.shape[2] * result.shape[3]
166 |
167 | result = self.block[1](result)
168 | flops += self.conv2_flops_per_pixel * result.shape[2] * result.shape[3]
169 |
170 | # se layer
171 | flops += result.shape[1] * result.shape[2] * result.shape[3] # global pooling
172 | result = self.block[2](result)
173 | flops += self.se_flops
174 |
175 | result = self.block[3](result)
176 | flops += self.conv3_flops_per_pixel * result.shape[2] * result.shape[3]
177 |
178 | # result = self.block(input)
179 | if self.use_res_connect:
180 | result += input
181 | return result, flops
182 |
183 |
184 | class MobileNetV3(nn.Module):
185 | def __init__(
186 | self,
187 | inverted_residual_setting: List[InvertedResidualConfig],
188 | last_channel: int,
189 | num_classes: int = 1000,
190 | block: Optional[Callable[..., nn.Module]] = None,
191 | norm_layer: Optional[Callable[..., nn.Module]] = None,
192 | dropout: float = 0.2,
193 | **kwargs: Any,
194 | ) -> None:
195 | """
196 | MobileNet V3 main class
197 |
198 | Args:
199 | inverted_residual_setting (List[InvertedResidualConfig]): Network structure
200 | last_channel (int): The number of channels on the penultimate layer
201 | num_classes (int): Number of classes
202 | block (Optional[Callable[..., nn.Module]]): Module specifying inverted residual building block for mobilenet
203 | norm_layer (Optional[Callable[..., nn.Module]]): Module specifying the normalization layer to use
204 |             dropout (float): The dropout probability
205 | """
206 | super().__init__()
207 | # _log_api_usage_once(self)
208 |
209 | if not inverted_residual_setting:
210 | raise ValueError("The inverted_residual_setting should not be empty")
211 | elif not (
212 | isinstance(inverted_residual_setting, Sequence)
213 | and all([isinstance(s, InvertedResidualConfig) for s in inverted_residual_setting])
214 | ):
215 | raise TypeError("The inverted_residual_setting should be List[InvertedResidualConfig]")
216 |
217 | if block is None:
218 | block = InvertedResidual
219 |
220 | if norm_layer is None:
221 | norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.01)
222 |
223 | layers: List[nn.Module] = []
224 |
225 | # building first layer
226 | firstconv_output_channels = inverted_residual_setting[0].input_channels
227 | layers.append(
228 | ConvNormActivation(
229 | 3,
230 | firstconv_output_channels,
231 | kernel_size=3,
232 | stride=2,
233 | norm_layer=norm_layer,
234 | activation_layer=nn.Hardswish,
235 | )
236 | )
237 | self.stem_flops_per_pixel = 3 * firstconv_output_channels * 9
238 |
239 | # building inverted residual blocks
240 | for cnf in inverted_residual_setting:
241 | layers.append(block(cnf, norm_layer))
242 |
243 | # building last several layers
244 | lastconv_input_channels = inverted_residual_setting[-1].out_channels
245 | self.lastconv_input_channels = lastconv_input_channels
246 | lastconv_output_channels = 6 * lastconv_input_channels
247 | layers.append(
248 | ConvNormActivation(
249 | lastconv_input_channels,
250 | lastconv_output_channels,
251 | kernel_size=1,
252 | norm_layer=norm_layer,
253 | activation_layer=nn.Hardswish,
254 | )
255 | )
256 | self.tail_flops_per_pixel = lastconv_input_channels * lastconv_output_channels
257 | self.lastconv_output_channels = lastconv_output_channels
258 | self.last_channel = last_channel
259 |
260 | self.features = nn.ModuleList(layers)
261 | self.avgpool = nn.AdaptiveAvgPool2d(1)
262 | self.classifier = nn.Sequential(
263 | nn.Linear(lastconv_output_channels, last_channel),
264 | nn.Hardswish(inplace=True),
265 | nn.Dropout(p=dropout, inplace=True),
266 | nn.Linear(last_channel, num_classes),
267 | )
268 | self.classifier_flops = lastconv_output_channels * last_channel + last_channel * num_classes
269 | for m in self.modules():
270 | if isinstance(m, nn.Conv2d):
271 | nn.init.kaiming_normal_(m.weight, mode="fan_out")
272 | if m.bias is not None:
273 | nn.init.zeros_(m.bias)
274 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
275 | nn.init.ones_(m.weight)
276 | nn.init.zeros_(m.bias)
277 | elif isinstance(m, nn.Linear):
278 | nn.init.normal_(m.weight, 0, 0.01)
279 | nn.init.zeros_(m.bias)
280 |
281 | def forward(self, x: Tensor) -> Tensor:
282 | print('input', x.shape)
283 | x = self.features[0](x)
284 | flops = self.stem_flops_per_pixel * x.shape[2] * x.shape[3]
285 | print('after features[0]', x.shape)
286 | x = (x, flops)
287 | for i in range(1, len(self.features)-1):
288 | print(f'after features[{i}], {x[0].shape}')
289 | x = self.features[i].forward_calc_flops(x)
290 | x, flops = x
291 |
292 | print('before the last block in features', x.shape)
293 | x = self.features[-1](x)
294 | flops += self.tail_flops_per_pixel * x.shape[2] * x.shape[3]
295 |
296 | flops += x.shape[1] * x.shape[2] * x.shape[3]
297 | x = self.avgpool(x)
298 | x = torch.flatten(x, 1)
299 |
300 | x = self.classifier(x)
301 | flops += self.classifier_flops
302 | print(flops/1e9)
303 | return x
304 |
305 | # def forward(self, x: Tensor) -> Tensor:
306 | # return self._forward_impl(x)
307 |
308 |
309 | def _mobilenet_v3_conf(
310 | arch: str, width_mult: float = 1.0, reduced_tail: bool = False, dilated: bool = False, **kwargs: Any
311 | ):
312 | reduce_divider = 2 if reduced_tail else 1
313 | dilation = 2 if dilated else 1
314 |
315 | bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult)
316 | adjust_channels = partial(InvertedResidualConfig.adjust_channels, width_mult=width_mult)
317 |
318 | if arch == "mobilenet_v3_large":
319 | inverted_residual_setting = [
320 | bneck_conf(16, 3, 16, 16, False, "RE", 1, 1),
321 | bneck_conf(16, 3, 64, 24, False, "RE", 2, 1), # C1
322 | bneck_conf(24, 3, 72, 24, False, "RE", 1, 1),
323 | bneck_conf(24, 5, 72, 40, True, "RE", 2, 1), # C2
324 | bneck_conf(40, 5, 120, 40, True, "RE", 1, 1),
325 | bneck_conf(40, 5, 120, 40, True, "RE", 1, 1),
326 | bneck_conf(40, 3, 240, 80, False, "HS", 2, 1), # C3
327 | bneck_conf(80, 3, 200, 80, False, "HS", 1, 1),
328 | bneck_conf(80, 3, 184, 80, False, "HS", 1, 1),
329 | bneck_conf(80, 3, 184, 80, False, "HS", 1, 1),
330 | bneck_conf(80, 3, 480, 112, True, "HS", 1, 1),
331 | bneck_conf(112, 3, 672, 112, True, "HS", 1, 1),
332 | bneck_conf(112, 5, 672, 160 // reduce_divider, True, "HS", 2, dilation), # C4
333 | bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1, dilation),
334 | bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1, dilation),
335 | ]
336 | last_channel = adjust_channels(1280 // reduce_divider) # C5
337 | elif arch == "mobilenet_v3_small":
338 | inverted_residual_setting = [
339 | bneck_conf(16, 3, 16, 16, True, "RE", 2, 1), # C1
340 | bneck_conf(16, 3, 72, 24, False, "RE", 2, 1), # C2
341 | bneck_conf(24, 3, 88, 24, False, "RE", 1, 1),
342 | bneck_conf(24, 5, 96, 40, True, "HS", 2, 1), # C3
343 | bneck_conf(40, 5, 240, 40, True, "HS", 1, 1),
344 | bneck_conf(40, 5, 240, 40, True, "HS", 1, 1),
345 | bneck_conf(40, 5, 120, 48, True, "HS", 1, 1),
346 | bneck_conf(48, 5, 144, 48, True, "HS", 1, 1),
347 | bneck_conf(48, 5, 288, 96 // reduce_divider, True, "HS", 2, dilation), # C4
348 | bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1, dilation),
349 | bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1, dilation),
350 | ]
351 | last_channel = adjust_channels(1024 // reduce_divider) # C5
352 | else:
353 | raise ValueError(f"Unsupported model type {arch}")
354 |
355 | return inverted_residual_setting, last_channel
356 |
357 |
358 | def _mobilenet_v3(
359 | arch: str,
360 | inverted_residual_setting: List[InvertedResidualConfig],
361 | last_channel: int,
362 | pretrained: bool,
363 | progress: bool,
364 | **kwargs: Any,
365 | ):
366 | model = MobileNetV3(inverted_residual_setting, last_channel, **kwargs)
367 | if pretrained:
368 | if model_urls.get(arch, None) is None:
369 | raise ValueError(f"No checkpoint is available for model type {arch}")
370 | state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
371 | model.load_state_dict(state_dict)
372 | return model
373 |
374 |
375 | def mobilenet_v3_large(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3:
376 | """
377 | Constructs a large MobileNetV3 architecture from
378 |     `"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.
379 |
380 | Args:
381 | pretrained (bool): If True, returns a model pre-trained on ImageNet
382 | progress (bool): If True, displays a progress bar of the download to stderr
383 | """
384 | arch = "mobilenet_v3_large"
385 | inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, **kwargs)
386 | return _mobilenet_v3(arch, inverted_residual_setting, last_channel, pretrained, progress, **kwargs)
387 |
388 | def mobilenet_v3_large_1x25(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3:
389 | """
390 | Constructs a large MobileNetV3 architecture from
391 |     `"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.
392 |
393 | Args:
394 | pretrained (bool): If True, returns a model pre-trained on ImageNet
395 | progress (bool): If True, displays a progress bar of the download to stderr
396 | """
397 | arch = "mobilenet_v3_large"
398 | inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, width_mult=1.25, **kwargs)
399 | return _mobilenet_v3(arch, inverted_residual_setting, last_channel, pretrained, progress, **kwargs)
400 |
401 | def mobilenet_v3_large_1x5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3:
402 | """
403 | Constructs a large MobileNetV3 architecture from
404 |     `"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.
405 |
406 | Args:
407 | pretrained (bool): If True, returns a model pre-trained on ImageNet
408 | progress (bool): If True, displays a progress bar of the download to stderr
409 | """
410 | arch = "mobilenet_v3_large"
411 | inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, width_mult=1.5, **kwargs)
412 | return _mobilenet_v3(arch, inverted_residual_setting, last_channel, pretrained, progress, **kwargs)
413 |
414 | def mobilenet_v3_large_2x0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3:
415 | """
416 | Constructs a large MobileNetV3 architecture from
417 |     `"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.
418 |
419 | Args:
420 | pretrained (bool): If True, returns a model pre-trained on ImageNet
421 | progress (bool): If True, displays a progress bar of the download to stderr
422 | """
423 | arch = "mobilenet_v3_large"
424 | inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, width_mult=2.0, **kwargs)
425 | return _mobilenet_v3(arch, inverted_residual_setting, last_channel, pretrained, progress, **kwargs)
426 |
427 | def mobilenet_v3_large_0x75(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3:
428 | """
429 | Constructs a large MobileNetV3 architecture from
430 |     `"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.
431 |
432 | Args:
433 | pretrained (bool): If True, returns a model pre-trained on ImageNet
434 | progress (bool): If True, displays a progress bar of the download to stderr
435 | """
436 | arch = "mobilenet_v3_large"
437 | inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, width_mult=0.75, **kwargs)
438 | return _mobilenet_v3(arch, inverted_residual_setting, last_channel, pretrained, progress, **kwargs)
439 |
440 | def mobilenet_v3_large_0x5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3:
441 | """
442 | Constructs a large MobileNetV3 architecture from
443 |     `"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.
444 |
445 | Args:
446 | pretrained (bool): If True, returns a model pre-trained on ImageNet
447 | progress (bool): If True, displays a progress bar of the download to stderr
448 | """
449 | arch = "mobilenet_v3_large"
450 | inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, width_mult=0.5, **kwargs)
451 | return _mobilenet_v3(arch, inverted_residual_setting, last_channel, pretrained, progress, **kwargs)
452 |
453 |
454 | def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3:
455 | """
456 | Constructs a small MobileNetV3 architecture from
457 |     `"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.
458 |
459 | Args:
460 | pretrained (bool): If True, returns a model pre-trained on ImageNet
461 | progress (bool): If True, displays a progress bar of the download to stderr
462 | """
463 | arch = "mobilenet_v3_small"
464 | inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, **kwargs)
465 | return _mobilenet_v3(arch, inverted_residual_setting, last_channel, pretrained, progress, **kwargs)
466 |
467 | if __name__ == '__main__':
468 | model = mobilenet_v3_large_2x0()
469 | x = torch.rand(1,3,224,224)
470 |
471 | y = model(x)
472 |
473 | # from op_counter import measure_model
474 | # cls_ops, cls_params = measure_model(model, 224, 224)
475 | # print(cls_ops[-1]/1e9)
--------------------------------------------------------------------------------
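
The FLOPs-counting path in the file above threads a `(tensor, flops)` tuple through each block's `forward_calc_flops`, adding a per-pixel multiply count scaled by the output resolution. A small sketch of the same pattern in isolation; `ToyBlock` and its sizes are illustrative, not classes from this repo:

import torch
from torch import nn

class ToyBlock(nn.Module):
    """Illustrative only: a 1x1 conv whose per-pixel cost is c_in * c_out."""
    def __init__(self, c_in, c_out):
        super().__init__()
        self.conv = nn.Conv2d(c_in, c_out, kernel_size=1, bias=False)
        self.flops_per_pixel = c_in * c_out

    def forward_calc_flops(self, x):
        x, flops = x                      # unpack the running (tensor, flops) pair
        x = self.conv(x)
        flops += self.flops_per_pixel * x.shape[2] * x.shape[3]
        return x, flops                   # re-pack, as the blocks above do

blocks = nn.ModuleList([ToyBlock(3, 16), ToyBlock(16, 32)])
x, flops = torch.rand(1, 3, 56, 56), 0
for blk in blocks:
    x, flops = blk.forward_calc_flops((x, flops))
print(flops / 1e6, 'M multiply-adds (toy estimate)')
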
/models/cnn_core/regnet.py:
--------------------------------------------------------------------------------
1 | import math
2 | from collections import OrderedDict
3 | from functools import partial
4 | from typing import Any, Callable, List, Optional, Tuple
5 |
6 | import torch
7 | from torch import nn, Tensor
8 |
9 | from torchvision._internally_replaced_utils import load_state_dict_from_url
10 | from torchvision.ops.misc import ConvNormActivation, SqueezeExcitation
11 | # from torchvision.utils import _log_api_usage_once
12 | from torchvision.models._utils import _make_divisible
13 |
14 | model_urls = {
15 | "regnet_y_400mf": "https://download.pytorch.org/models/regnet_y_400mf-c65dace8.pth",
16 | "regnet_y_800mf": "https://download.pytorch.org/models/regnet_y_800mf-1b27b58c.pth",
17 | "regnet_y_1_6gf": "https://download.pytorch.org/models/regnet_y_1_6gf-b11a554e.pth",
18 | "regnet_y_3_2gf": "https://download.pytorch.org/models/regnet_y_3_2gf-b5a9779c.pth",
19 | "regnet_y_8gf": "https://download.pytorch.org/models/regnet_y_8gf-d0d0e4a8.pth",
20 | "regnet_y_16gf": "https://download.pytorch.org/models/regnet_y_16gf-9e6ed7dd.pth",
21 | "regnet_y_32gf": "https://download.pytorch.org/models/regnet_y_32gf-4dee3f7a.pth",
22 | "regnet_x_400mf": "https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth",
23 | "regnet_x_800mf": "https://download.pytorch.org/models/regnet_x_800mf-ad17e45c.pth",
24 | "regnet_x_1_6gf": "https://download.pytorch.org/models/regnet_x_1_6gf-e3633e7f.pth",
25 | "regnet_x_3_2gf": "https://download.pytorch.org/models/regnet_x_3_2gf-f342aeae.pth",
26 | "regnet_x_8gf": "https://download.pytorch.org/models/regnet_x_8gf-03ceed89.pth",
27 | "regnet_x_16gf": "https://download.pytorch.org/models/regnet_x_16gf-2007eb11.pth",
28 | "regnet_x_32gf": "https://download.pytorch.org/models/regnet_x_32gf-9d47f8d0.pth",
29 | }
30 |
31 | class SimpleStemIN(ConvNormActivation):
32 | """Simple stem for ImageNet: 3x3, BN, ReLU."""
33 |
34 | def __init__(
35 | self,
36 | width_in: int,
37 | width_out: int,
38 | norm_layer: Callable[..., nn.Module],
39 | activation_layer: Callable[..., nn.Module],
40 | ) -> None:
41 | self.width_out = width_out
42 | super().__init__(
43 | width_in, width_out, kernel_size=3, stride=2, norm_layer=norm_layer, activation_layer=activation_layer
44 | )
45 |
46 | class BottleneckTransform(nn.Sequential):
47 | """Bottleneck transformation: 1x1, 3x3 [+SE], 1x1."""
48 |
49 | def __init__(
50 | self,
51 | width_in: int,
52 | width_out: int,
53 | stride: int,
54 | norm_layer: Callable[..., nn.Module],
55 | activation_layer: Callable[..., nn.Module],
56 | group_width: int,
57 | bottleneck_multiplier: float,
58 | se_ratio: Optional[float],
59 | ) -> None:
60 | self.width_in = width_in
61 | self.width_out = width_out
62 | layers: OrderedDict[str, nn.Module] = OrderedDict()
63 | w_b = int(round(width_out * bottleneck_multiplier))
64 | g = w_b // group_width
65 |
66 | layers["a"] = ConvNormActivation(
67 | width_in, w_b, kernel_size=1, stride=1, norm_layer=norm_layer, activation_layer=activation_layer
68 | )
69 | layers["b"] = ConvNormActivation(
70 | w_b, w_b, kernel_size=3, stride=stride, groups=g, norm_layer=norm_layer, activation_layer=activation_layer
71 | )
72 |
73 | if se_ratio:
74 | # The SE reduction ratio is defined with respect to the
75 | # beginning of the block
76 | width_se_out = int(round(se_ratio * width_in))
77 | layers["se"] = SqueezeExcitation(
78 | input_channels=w_b,
79 | squeeze_channels=width_se_out,
80 | activation=activation_layer,
81 | )
82 |
83 | layers["c"] = ConvNormActivation(
84 | w_b, width_out, kernel_size=1, stride=1, norm_layer=norm_layer, activation_layer=None
85 | )
86 | super().__init__(layers)
87 |
88 |
89 | class ResBottleneckBlock(nn.Module):
90 | """Residual bottleneck block: x + F(x), F = bottleneck transform."""
91 |
92 | def __init__(
93 | self,
94 | width_in: int,
95 | width_out: int,
96 | stride: int,
97 | norm_layer: Callable[..., nn.Module],
98 | activation_layer: Callable[..., nn.Module],
99 | group_width: int = 1,
100 | bottleneck_multiplier: float = 1.0,
101 | se_ratio: Optional[float] = None,
102 | ) -> None:
103 | super().__init__()
104 |
105 | # Use skip connection with projection if shape changes
106 | self.proj = None
107 | should_proj = (width_in != width_out) or (stride != 1)
108 | if should_proj:
109 | self.proj = ConvNormActivation(
110 | width_in, width_out, kernel_size=1, stride=stride, norm_layer=norm_layer, activation_layer=None
111 | )
112 | self.f = BottleneckTransform(
113 | width_in,
114 | width_out,
115 | stride,
116 | norm_layer,
117 | activation_layer,
118 | group_width,
119 | bottleneck_multiplier,
120 | se_ratio,
121 | )
122 | self.activation = activation_layer(inplace=True)
123 | w_b = int(round(width_out * bottleneck_multiplier))
124 | g = w_b // group_width
125 | self.conv1_flops_per_pixel = width_in*w_b
126 | self.conv2_flops_per_pixel = w_b*w_b*9 // g
127 | self.conv3_flops_per_pixel = w_b*width_out
128 | self.stride=stride
129 |
130 | if self.proj is not None:
131 | self.downsample_flops = width_in * width_out
132 |
133 | def forward(self, x: Tensor) -> Tensor:
134 | identity = self.proj(x) if self.proj is not None else x
135 |
136 | x = self.f(x)
137 | x = self.activation(x + identity)
138 | return x
139 |
140 |
141 | def forward_calc_flops(self, x: Tensor) -> Tensor:
142 | x, flops = x
143 | if self.proj is not None:
144 | identity = self.proj(x)
145 | flops += self.downsample_flops * identity.shape[2] * identity.shape[3]
146 | else:
147 | identity = x
148 |
149 | x = self.f(x)
150 |
151 | b,c,h,w = x.shape
152 | h_conv1 = h * self.stride
153 |
154 | flops += (self.conv1_flops_per_pixel * h_conv1 * h_conv1 + self.conv2_flops_per_pixel * h * w + self.conv3_flops_per_pixel * h * w)
155 |
156 | x = self.activation(x + identity)
157 | return x, flops
158 |
159 |
160 | class AnyStage(nn.ModuleList):
161 | """AnyNet stage (sequence of blocks w/ the same output shape)."""
162 |
163 | def __init__(
164 | self,
165 | width_in: int,
166 | width_out: int,
167 | stride: int,
168 | depth: int,
169 | block_constructor: Callable[..., nn.Module],
170 | norm_layer: Callable[..., nn.Module],
171 | activation_layer: Callable[..., nn.Module],
172 | group_width: int,
173 | bottleneck_multiplier: float,
174 | se_ratio: Optional[float] = None,
175 | stage_index: int = 0,
176 | ) -> None:
177 | super().__init__()
178 |
179 | self.c_in = width_in
180 | self.c_out = width_out
181 |
182 | for i in range(depth):
183 | block = block_constructor(
184 | width_in if i == 0 else width_out,
185 | width_out,
186 | stride if i == 0 else 1,
187 | norm_layer,
188 | activation_layer,
189 | group_width,
190 | bottleneck_multiplier,
191 | se_ratio,
192 | )
193 |
194 | self.add_module(f"block{stage_index}-{i}", block)
195 |
196 | def forward(self, x):
197 | for block in self:
198 | x = block(x)
199 | return x
200 |
201 | def forward_calc_flops(self, x):
202 | flops = 0
203 | x = (x, flops)
204 | for block in self:
205 | x = block.forward_calc_flops(x)
206 |
207 | # x, flops = x
208 | return x # should be a tuple of (x, flops)
209 |
210 |
211 | class BlockParams:
212 | def __init__(
213 | self,
214 | depths: List[int],
215 | widths: List[int],
216 | group_widths: List[int],
217 | bottleneck_multipliers: List[float],
218 | strides: List[int],
219 | se_ratio: Optional[float] = None,
220 | ) -> None:
221 | self.depths = depths
222 | self.widths = widths
223 | self.group_widths = group_widths
224 | self.bottleneck_multipliers = bottleneck_multipliers
225 | self.strides = strides
226 | self.se_ratio = se_ratio
227 |
228 | @classmethod
229 | def from_init_params(
230 | cls,
231 | depth: int,
232 | w_0: int,
233 | w_a: float,
234 | w_m: float,
235 | group_width: int,
236 | bottleneck_multiplier: float = 1.0,
237 | se_ratio: Optional[float] = None,
238 | **kwargs: Any,
239 | ) -> "BlockParams":
240 | """
241 |         Programmatically compute all the per-block settings,
242 | given the RegNet parameters.
243 |
244 | The first step is to compute the quantized linear block parameters,
245 | in log space. Key parameters are:
246 | - `w_a` is the width progression slope
247 | - `w_0` is the initial width
248 | - `w_m` is the width stepping in the log space
249 |
250 |         In other terms,
251 |         `log(block_width) = log(w_0) + block_capacity * log(w_m)`,
252 |         with `block_capacity` ramping up following the w_0 and w_a params.
253 | This block width is finally quantized to multiples of 8.
254 |
255 | The second step is to compute the parameters per stage,
256 | taking into account the skip connection and the final 1x1 convolutions.
257 | We use the fact that the output width is constant within a stage.
258 | """
259 |
260 | QUANT = 8
261 | STRIDE = 2
262 |
263 | if w_a < 0 or w_0 <= 0 or w_m <= 1 or w_0 % 8 != 0:
264 | raise ValueError("Invalid RegNet settings")
265 | # Compute the block widths. Each stage has one unique block width
266 | widths_cont = torch.arange(depth) * w_a + w_0
267 | block_capacity = torch.round(torch.log(widths_cont / w_0) / math.log(w_m))
268 | block_widths = (torch.round(torch.divide(w_0 * torch.pow(w_m, block_capacity), QUANT)) * QUANT).int().tolist()
269 | num_stages = len(set(block_widths))
270 |
271 | # Convert to per stage parameters
272 | split_helper = zip(
273 | block_widths + [0],
274 | [0] + block_widths,
275 | block_widths + [0],
276 | [0] + block_widths,
277 | )
278 | splits = [w != wp or r != rp for w, wp, r, rp in split_helper]
279 |
280 | stage_widths = [w for w, t in zip(block_widths, splits[:-1]) if t]
281 | stage_depths = torch.diff(torch.tensor([d for d, t in enumerate(splits) if t])).int().tolist()
282 |
283 | strides = [STRIDE] * num_stages
284 | bottleneck_multipliers = [bottleneck_multiplier] * num_stages
285 | group_widths = [group_width] * num_stages
286 |
287 | # Adjust the compatibility of stage widths and group widths
288 | stage_widths, group_widths = cls._adjust_widths_groups_compatibilty(
289 | stage_widths, bottleneck_multipliers, group_widths
290 | )
291 |
292 | return cls(
293 | depths=stage_depths,
294 | widths=stage_widths,
295 | group_widths=group_widths,
296 | bottleneck_multipliers=bottleneck_multipliers,
297 | strides=strides,
298 | se_ratio=se_ratio,
299 | )
300 |
301 | def _get_expanded_params(self):
302 | return zip(self.widths, self.strides, self.depths, self.group_widths, self.bottleneck_multipliers)
303 |
304 | @staticmethod
305 | def _adjust_widths_groups_compatibilty(
306 | stage_widths: List[int], bottleneck_ratios: List[float], group_widths: List[int]
307 | ) -> Tuple[List[int], List[int]]:
308 | """
309 | Adjusts the compatibility of widths and groups,
310 | depending on the bottleneck ratio.
311 | """
312 | # Compute all widths for the current settings
313 | widths = [int(w * b) for w, b in zip(stage_widths, bottleneck_ratios)]
314 | group_widths_min = [min(g, w_bot) for g, w_bot in zip(group_widths, widths)]
315 |
316 | # Compute the adjusted widths so that stage and group widths fit
317 | ws_bot = [_make_divisible(w_bot, g) for w_bot, g in zip(widths, group_widths_min)]
318 | stage_widths = [int(w_bot / b) for w_bot, b in zip(ws_bot, bottleneck_ratios)]
319 | return stage_widths, group_widths_min
320 |
321 | class RegNet(nn.Module):
322 | def __init__(
323 | self,
324 | block_params: BlockParams,
325 | num_classes: int = 1000,
326 | stem_width: int = 32,
327 | stem_type: Optional[Callable[..., nn.Module]] = None,
328 | block_type: Optional[Callable[..., nn.Module]] = None,
329 | norm_layer: Optional[Callable[..., nn.Module]] = None,
330 | activation: Optional[Callable[..., nn.Module]] = None,
331 | ) -> None:
332 | super().__init__()
333 | # _log_api_usage_once(self)
334 |
335 | if stem_type is None:
336 | stem_type = SimpleStemIN
337 | if norm_layer is None:
338 | norm_layer = nn.BatchNorm2d
339 | if block_type is None:
340 | block_type = ResBottleneckBlock
341 | if activation is None:
342 | activation = nn.ReLU
343 |
344 | # Ad hoc stem
345 | self.stem = stem_type(
346 | 3, # width_in
347 | stem_width,
348 | norm_layer,
349 | activation,
350 | )
351 |
352 | current_width = stem_width
353 |
354 | blocks = []
355 |
356 | # print(block_params._get_expanded_params())
357 |
358 | for i, (
359 | width_out,
360 | stride,
361 | depth,
362 | group_width,
363 | bottleneck_multiplier,
364 | ) in enumerate(block_params._get_expanded_params()):
365 | # print(stride, depth)
366 | blocks.append(
367 | (
368 | f"block{i+1}",
369 | AnyStage(
370 | current_width,
371 | width_out,
372 | stride,
373 | depth,
374 | block_type,
375 | norm_layer,
376 | activation,
377 | group_width,
378 | bottleneck_multiplier,
379 | block_params.se_ratio,
380 | stage_index=i + 1,
381 | ),
382 | )
383 | )
384 |
385 | current_width = width_out
386 |
387 | self.trunk_output = nn.Sequential(OrderedDict(blocks))
388 |
389 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
390 | self.fc = nn.Linear(in_features=current_width, out_features=num_classes)
391 |
392 | # Performs ResNet-style weight initialization
393 | for m in self.modules():
394 | if isinstance(m, nn.Conv2d):
395 | # Note that there is no bias due to BN
396 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
397 | nn.init.normal_(m.weight, mean=0.0, std=math.sqrt(2.0 / fan_out))
398 | elif isinstance(m, nn.BatchNorm2d):
399 | nn.init.ones_(m.weight)
400 | nn.init.zeros_(m.bias)
401 | elif isinstance(m, nn.Linear):
402 | nn.init.normal_(m.weight, mean=0.0, std=0.01)
403 | nn.init.zeros_(m.bias)
404 |
405 | def forward(self, x: Tensor) -> Tensor:
406 | x = self.stem(x)
407 | x = self.trunk_output(x)
408 |
409 | x = self.avgpool(x)
410 | x = x.flatten(start_dim=1)
411 | x = self.fc(x)
412 |
413 | return x
414 |
415 |
416 | def _regnet(arch: str, block_params: BlockParams, pretrained: bool, progress: bool, **kwargs: Any) -> RegNet:
417 | norm_layer = kwargs.pop("norm_layer", partial(nn.BatchNorm2d, eps=1e-05, momentum=0.1))
418 | model = RegNet(block_params, norm_layer=norm_layer, **kwargs)
419 | if pretrained:
420 | if arch not in model_urls:
421 | raise ValueError(f"No checkpoint is available for model type {arch}")
422 | state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
423 | model.load_state_dict(state_dict)
424 | return model
425 |
426 | def regnet_y_400mf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet:
427 | """
428 | Constructs a RegNetY_400MF architecture from
429 |     `"Designing Network Design Spaces" <https://arxiv.org/abs/2003.13678>`_.
430 |
431 | Args:
432 | pretrained (bool): If True, returns a model pre-trained on ImageNet
433 | progress (bool): If True, displays a progress bar of the download to stderr
434 | """
435 | params = BlockParams.from_init_params(depth=16, w_0=48, w_a=27.89, w_m=2.09, group_width=8, se_ratio=0.25, **kwargs)
436 | return _regnet("regnet_y_400mf", params, pretrained, progress, **kwargs)
437 |
438 |
439 |
440 |
441 | def regnet_y_800mf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet:
442 | """
443 | Constructs a RegNetY_800MF architecture from
444 |     `"Designing Network Design Spaces" <https://arxiv.org/abs/2003.13678>`_.
445 |
446 | Args:
447 | pretrained (bool): If True, returns a model pre-trained on ImageNet
448 | progress (bool): If True, displays a progress bar of the download to stderr
449 | """
450 | params = BlockParams.from_init_params(depth=14, w_0=56, w_a=38.84, w_m=2.4, group_width=16, se_ratio=0.25, **kwargs)
451 | return _regnet("regnet_y_800mf", params, pretrained, progress, **kwargs)
452 |
453 |
454 |
455 | def regnet_y_1_6gf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet:
456 | """
457 | Constructs a RegNetY_1.6GF architecture from
458 |     `"Designing Network Design Spaces" <https://arxiv.org/abs/2003.13678>`_.
459 |
460 | Args:
461 | pretrained (bool): If True, returns a model pre-trained on ImageNet
462 | progress (bool): If True, displays a progress bar of the download to stderr
463 | """
464 | params = BlockParams.from_init_params(
465 | depth=27, w_0=48, w_a=20.71, w_m=2.65, group_width=24, se_ratio=0.25, **kwargs
466 | )
467 | return _regnet("regnet_y_1_6gf", params, pretrained, progress, **kwargs)
468 |
469 |
470 |
471 | def regnet_y_3_2gf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet:
472 | """
473 | Constructs a RegNetY_3.2GF architecture from
474 |     `"Designing Network Design Spaces" <https://arxiv.org/abs/2003.13678>`_.
475 |
476 | Args:
477 | pretrained (bool): If True, returns a model pre-trained on ImageNet
478 | progress (bool): If True, displays a progress bar of the download to stderr
479 | """
480 | params = BlockParams.from_init_params(
481 | depth=21, w_0=80, w_a=42.63, w_m=2.66, group_width=24, se_ratio=0.25, **kwargs
482 | )
483 | return _regnet("regnet_y_3_2gf", params, pretrained, progress, **kwargs)
484 |
485 |
486 |
487 |
488 | def regnet_y_8gf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet:
489 | """
490 | Constructs a RegNetY_8GF architecture from
491 |     `"Designing Network Design Spaces" <https://arxiv.org/abs/2003.13678>`_.
492 |
493 | Args:
494 | pretrained (bool): If True, returns a model pre-trained on ImageNet
495 | progress (bool): If True, displays a progress bar of the download to stderr
496 | """
497 | params = BlockParams.from_init_params(
498 | depth=17, w_0=192, w_a=76.82, w_m=2.19, group_width=56, se_ratio=0.25, **kwargs
499 | )
500 | return _regnet("regnet_y_8gf", params, pretrained, progress, **kwargs)
501 |
502 |
503 |
504 | def regnet_y_16gf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet:
505 | """
506 | Constructs a RegNetY_16GF architecture from
507 |     `"Designing Network Design Spaces" <https://arxiv.org/abs/2003.13678>`_.
508 |
509 | Args:
510 | pretrained (bool): If True, returns a model pre-trained on ImageNet
511 | progress (bool): If True, displays a progress bar of the download to stderr
512 | """
513 | params = BlockParams.from_init_params(
514 | depth=18, w_0=200, w_a=106.23, w_m=2.48, group_width=112, se_ratio=0.25, **kwargs
515 | )
516 | return _regnet("regnet_y_16gf", params, pretrained, progress, **kwargs)
517 |
518 |
519 |
520 | def regnet_y_32gf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet:
521 | """
522 | Constructs a RegNetY_32GF architecture from
523 |     `"Designing Network Design Spaces" <https://arxiv.org/abs/2003.13678>`_.
524 |
525 | Args:
526 | pretrained (bool): If True, returns a model pre-trained on ImageNet
527 | progress (bool): If True, displays a progress bar of the download to stderr
528 | """
529 | params = BlockParams.from_init_params(
530 | depth=20, w_0=232, w_a=115.89, w_m=2.53, group_width=232, se_ratio=0.25, **kwargs
531 | )
532 | return _regnet("regnet_y_32gf", params, pretrained, progress, **kwargs)
533 |
534 |
535 |
536 | def regnet_y_128gf(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RegNet:
537 | """
538 | Constructs a RegNetY_128GF architecture from
539 |     `"Designing Network Design Spaces" <https://arxiv.org/abs/2003.13678>`_.
540 | NOTE: Pretrained weights are not available for this model.
541 | """
542 | params = BlockParams.from_init_params(
543 | depth=27, w_0=456, w_a=160.83, w_m=2.52, group_width=264, se_ratio=0.25, **kwargs
544 | )
545 | return _regnet("regnet_y_128gf", params, pretrained, progress, **kwargs)
546 |
547 |
548 | if __name__ == '__main__':
549 | net = regnet_y_400mf(pretrained=False)
550 | print(net.fc.weight.shape)
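
    # Rough sketch of the width rule behind BlockParams.from_init_params (simplified: torchvision's
    # group-width compatibility adjustment is skipped). Continuous widths u_j = w_0 + w_a * j are
    # snapped to the nearest power of w_m times w_0 and rounded to a multiple of 8.
    import math
    depth, w_0, w_a, w_m = 27, 48, 20.71, 2.65  # the regnet_y_1_6gf init params above
    u = [w_0 + w_a * j for j in range(depth)]
    s = [round(math.log(uj / w_0) / math.log(w_m)) for uj in u]
    print(sorted({int(round(w_0 * w_m ** sj / 8) * 8) for sj in s}))  # distinct (unadjusted) stage widths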
--------------------------------------------------------------------------------
/main_baseline.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 |
3 | # All rights reserved.
4 |
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | import warnings
9 | warnings.filterwarnings('ignore')
10 |
11 | import argparse
12 | import datetime
13 | import numpy as np
14 | import time
15 | import torch
16 | import torch.nn as nn
17 | import torch.backends.cudnn as cudnn
18 | import json
19 | import os
20 |
21 | from pathlib import Path
22 |
23 | from timm.data.mixup import Mixup
24 | from timm.models import create_model
25 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
26 | from timm.utils import ModelEma
27 | from optim_factory import create_optimizer, LayerDecayValueAssigner
28 |
29 | from datasets import build_dataset
30 | from engine import train_one_epoch, evaluate
31 |
32 | from utils import NativeScalerWithGradNormCount as NativeScaler
33 | import utils
34 |
35 | import models
36 |
37 |
38 | def str2bool(v):
39 | """
40 | Converts string to bool type; enables command line
41 | arguments in the format of '--arg1 true --arg2 false'
42 | """
43 | if isinstance(v, bool):
44 | return v
45 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
46 | return True
47 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
48 | return False
49 | else:
50 | raise argparse.ArgumentTypeError('Boolean value expected.')
51 |
52 | def get_args_parser():
53 | parser = argparse.ArgumentParser('ConvNeXt training and evaluation script for image classification', add_help=False)
54 | parser.add_argument('--batch_size', default=64, type=int,
55 | help='Per GPU batch size')
56 | parser.add_argument('--epochs', default=300, type=int)
57 | parser.add_argument('--update_freq', default=1, type=int,
58 | help='gradient accumulation steps')
59 |
60 | # Model parameters
61 | parser.add_argument('--model', default='convnext_tiny', type=str, metavar='MODEL',
62 | help='Name of model to train')
63 | parser.add_argument('--classifier_type', type=str, default='cnn', choices=['cnn', 'latent', 'merge'],)
64 | parser.add_argument('--drop_path', type=float, default=0, metavar='PCT',
65 | help='Drop path rate (default: 0.0)')
66 | parser.add_argument('--dropout', type=float, default=0)
67 | parser.add_argument('--input_size', default=224, type=int,
68 | help='image input size')
69 | parser.add_argument('--num_latent_channels', type=int, default=None)
70 | parser.add_argument('--SA_widening_factor', type=int, default=1)
71 |     parser.add_argument('--num_SA_heads', type=int, nargs='+', default=None)
72 | parser.add_argument('--spatial_reduction', type=str2bool, default=False)
73 | parser.add_argument('--layer_scale_init_value', default=1e-6, type=float,
74 | help="Layer scale initial values")
75 |
76 | # EMA related parameters
77 | parser.add_argument('--model_ema', type=str2bool, default=False)
78 | parser.add_argument('--model_ema_decay', type=float, default=0.9999, help='')
79 | parser.add_argument('--model_ema_force_cpu', type=str2bool, default=False, help='')
80 | parser.add_argument('--model_ema_eval', type=str2bool, default=False, help='Using ema to eval during training.')
81 |
82 | # Optimization parameters
83 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
84 |                         help='Optimizer (default: "adamw")')
85 | parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON',
86 | help='Optimizer Epsilon (default: 1e-8)')
87 | parser.add_argument('--opt_betas', default=None, type=float, nargs='+', metavar='BETA',
88 | help='Optimizer Betas (default: None, use opt default)')
89 | parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
90 | help='Clip gradient norm (default: None, no clipping)')
91 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
92 | help='SGD momentum (default: 0.9)')
93 | parser.add_argument('--weight_decay', type=float, default=0.05,
94 | help='weight decay (default: 0.05)')
95 | parser.add_argument('--weight_decay_end', type=float, default=None, help="""Final value of the
96 | weight decay. We use a cosine schedule for WD and using a larger decay by
97 | the end of training improves performance for ViTs.""")
98 |
99 | parser.add_argument('--lr', type=float, default=4e-3, metavar='LR',
100 | help='learning rate (default: 4e-3), with total batch size 4096')
101 | parser.add_argument('--layer_decay', type=float, default=1.0)
102 | parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR',
103 | help='lower lr bound for cyclic schedulers that hit 0 (1e-6)')
104 | parser.add_argument('--warmup_epochs', type=int, default=20, metavar='N',
105 | help='epochs to warmup LR, if scheduler supports')
106 | parser.add_argument('--warmup_steps', type=int, default=-1, metavar='N',
107 |                         help='num of steps to warmup LR, will override warmup_epochs if set > 0')
108 |
109 | # Augmentation parameters
110 | parser.add_argument('--color_jitter', type=float, default=0.4, metavar='PCT',
111 | help='Color jitter factor (default: 0.4)')
112 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
113 |                         help='Use AutoAugment policy. "v0" or "original" (default: rand-m9-mstd0.5-inc1)')
114 | parser.add_argument('--smoothing', type=float, default=0.1,
115 | help='Label smoothing (default: 0.1)')
116 | parser.add_argument('--train_interpolation', type=str, default='bicubic',
117 | help='Training interpolation (random, bilinear, bicubic default: "bicubic")')
118 |
119 | # Evaluation parameters
120 | parser.add_argument('--crop_pct', type=float, default=None)
121 |
122 | # * Random Erase params
123 | parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
124 | help='Random erase prob (default: 0.25)')
125 | parser.add_argument('--remode', type=str, default='pixel',
126 | help='Random erase mode (default: "pixel")')
127 | parser.add_argument('--recount', type=int, default=1,
128 | help='Random erase count (default: 1)')
129 | parser.add_argument('--resplit', type=str2bool, default=False,
130 | help='Do not random erase first (clean) augmentation split')
131 |
132 | # * Mixup params
133 | parser.add_argument('--mixup', type=float, default=0.8,
134 | help='mixup alpha, mixup enabled if > 0.')
135 | parser.add_argument('--cutmix', type=float, default=1.0,
136 | help='cutmix alpha, cutmix enabled if > 0.')
137 | parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None,
138 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
139 | parser.add_argument('--mixup_prob', type=float, default=1.0,
140 | help='Probability of performing mixup or cutmix when either/both is enabled')
141 | parser.add_argument('--mixup_switch_prob', type=float, default=0.5,
142 | help='Probability of switching to cutmix when both mixup and cutmix enabled')
143 | parser.add_argument('--mixup_mode', type=str, default='batch',
144 | help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
145 |
146 | # * Finetuning params
147 | parser.add_argument('--finetune', default='',
148 | help='finetune from checkpoint')
149 | parser.add_argument('--head_init_scale', default=1.0, type=float,
150 | help='classifier head initial scale, typically adjusted in fine-tuning')
151 | parser.add_argument('--model_key', default='model|module|state_dict_ema', type=str,
152 | help='which key to load from saved state dict, usually model or model_ema')
153 | parser.add_argument('--model_prefix', default='', type=str)
154 |
155 | # Dataset parameters
156 | parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str,
157 | help='dataset path')
158 | parser.add_argument('--eval_data_path', default=None, type=str,
159 | help='dataset path for evaluation')
160 | parser.add_argument('--nb_classes', default=1000, type=int,
161 |                         help='number of classification classes')
162 | parser.add_argument('--imagenet_default_mean_and_std', type=str2bool, default=True)
163 | parser.add_argument('--data_set', default='IMNET', choices=['CIFAR', 'IMNET', 'image_folder'],
164 |                         type=str, help='dataset type (CIFAR, IMNET, or image_folder)')
165 | parser.add_argument('--output_dir', default='',
166 | help='path where to save, empty for no saving')
167 | parser.add_argument('--log_dir', default=None,
168 | help='path where to tensorboard log')
169 | parser.add_argument('--device', default='cuda',
170 | help='device to use for training / testing')
171 | parser.add_argument('--seed', default=0, type=int)
172 |
173 | parser.add_argument('--resume', default='',
174 | help='resume from checkpoint')
175 | parser.add_argument('--auto_resume', type=str2bool, default=True)
176 | parser.add_argument('--save_ckpt', type=str2bool, default=True)
177 | parser.add_argument('--save_ckpt_freq', default=1, type=int)
178 | parser.add_argument('--save_ckpt_num', default=3, type=int)
179 |
180 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
181 | help='start epoch')
182 | parser.add_argument('--eval', type=str2bool, default=False,
183 | help='Perform evaluation only')
184 | parser.add_argument('--dist_eval', type=str2bool, default=True,
185 | help='Enabling distributed evaluation')
186 | parser.add_argument('--disable_eval', type=str2bool, default=False,
187 | help='Disabling evaluation during training')
188 | parser.add_argument('--num_workers', default=10, type=int)
189 | parser.add_argument('--pin_mem', type=str2bool, default=True,
190 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
191 |
192 | # distributed training parameters
193 | parser.add_argument('--world_size', default=1, type=int,
194 | help='number of distributed processes')
195 | parser.add_argument('--local_rank', default=-1, type=int)
196 | parser.add_argument('--dist_on_itp', type=str2bool, default=False)
197 | parser.add_argument('--dist_url', default='env://',
198 | help='url used to set up distributed training')
199 |
200 | parser.add_argument('--use_amp', type=str2bool, default=False,
201 | help="Use PyTorch's AMP (Automatic Mixed Precision) or not")
202 |
203 | # Weights and Biases arguments
204 | parser.add_argument('--enable_wandb', type=str2bool, default=False,
205 | help="enable logging to Weights and Biases")
206 | parser.add_argument('--project', default='convnext', type=str,
207 | help="The name of the W&B project where you're sending the new run.")
208 | parser.add_argument('--wandb_ckpt', type=str2bool, default=False,
209 | help="Save model checkpoints as W&B Artifacts.")
210 |
211 | return parser
212 |
213 | def main(args):
214 | utils.init_distributed_mode(args)
215 | print(args)
216 | device = torch.device(args.device)
217 |
218 | # fix the seed for reproducibility
219 | seed = args.seed + utils.get_rank()
220 | torch.manual_seed(seed)
221 | np.random.seed(seed)
222 | cudnn.benchmark = True
223 |
224 | dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
225 | if args.disable_eval:
226 | args.dist_eval = False
227 | dataset_val = None
228 | else:
229 | dataset_val, _ = build_dataset(is_train=False, args=args)
230 |
231 | num_tasks = utils.get_world_size()
232 | global_rank = utils.get_rank()
233 |
234 | sampler_train = torch.utils.data.DistributedSampler(
235 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True, seed=args.seed,
236 | )
237 | print("Sampler_train = %s" % str(sampler_train))
238 | if args.dist_eval:
239 | if len(dataset_val) % num_tasks != 0:
240 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
241 | 'This will slightly alter validation results as extra duplicate entries are added to achieve '
242 | 'equal num of samples per-process.')
243 | sampler_val = torch.utils.data.DistributedSampler(
244 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)
245 | else:
246 | sampler_val = torch.utils.data.SequentialSampler(dataset_val)
247 |
248 | if global_rank == 0 and args.log_dir is not None:
249 | os.makedirs(args.log_dir, exist_ok=True)
250 | log_writer = utils.TensorboardLogger(log_dir=args.log_dir)
251 | else:
252 | log_writer = None
253 |
254 | if global_rank == 0 and args.enable_wandb:
255 | wandb_logger = utils.WandbLogger(args)
256 | else:
257 | wandb_logger = None
258 |
259 | data_loader_train = torch.utils.data.DataLoader(
260 | dataset_train, sampler=sampler_train,
261 | batch_size=args.batch_size,
262 | num_workers=args.num_workers,
263 | pin_memory=args.pin_mem,
264 | drop_last=True,
265 | )
266 |
267 | if dataset_val is not None:
268 | data_loader_val = torch.utils.data.DataLoader(
269 | dataset_val, sampler=sampler_val,
270 | batch_size=int(1.5 * args.batch_size),
271 | num_workers=args.num_workers,
272 | pin_memory=args.pin_mem,
273 | drop_last=False
274 | )
275 | else:
276 | data_loader_val = None
277 |
278 | mixup_fn = None
279 | mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
280 | if mixup_active:
281 | print("Mixup is activated!")
282 | mixup_fn = Mixup(
283 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
284 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
285 | label_smoothing=args.smoothing, num_classes=args.nb_classes)
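        # With mixup/cutmix active the targets become soft label distributions, which is why the
        # criterion below switches to SoftTargetCrossEntropy (label smoothing is folded in here).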
286 |
287 |
288 |     model = getattr(models, args.model)()
289 |
290 | if args.finetune:
291 | if args.finetune.startswith('https'):
292 | checkpoint = torch.hub.load_state_dict_from_url(
293 | args.finetune, map_location='cpu', check_hash=True)
294 | else:
295 | checkpoint = torch.load(args.finetune, map_location='cpu')
296 |
297 | print("Load ckpt from %s" % args.finetune)
298 | checkpoint_model = None
299 | for model_key in args.model_key.split('|'):
300 | if model_key in checkpoint:
301 | checkpoint_model = checkpoint[model_key]
302 | print("Load state_dict by model_key = %s" % model_key)
303 | break
304 | if checkpoint_model is None:
305 | checkpoint_model = checkpoint
306 | state_dict = model.state_dict()
307 | for k in ['head.weight', 'head.bias']:
308 | if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
309 | print(f"Removing key {k} from pretrained checkpoint")
310 | del checkpoint_model[k]
311 | utils.load_state_dict(model, checkpoint_model, prefix=args.model_prefix)
312 | model.to(device)
313 |
314 | model_ema = None
315 | if args.model_ema:
316 | # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
317 | model_ema = ModelEma(
318 | model,
319 | decay=args.model_ema_decay,
320 | device='cpu' if args.model_ema_force_cpu else '',
321 | resume='')
322 | print("Using EMA with decay = %.8f" % args.model_ema_decay)
323 |
324 | model_without_ddp = model
325 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
326 |
327 | # print("Model = %s" % str(model_without_ddp))
328 | print('number of params:', n_parameters)
329 |
330 | total_batch_size = args.batch_size * args.update_freq * utils.get_world_size()
331 | num_training_steps_per_epoch = len(dataset_train) // total_batch_size
332 | print("LR = %.8f" % args.lr)
333 | print("Batch size = %d" % total_batch_size)
334 |     print("Update frequency = %d" % args.update_freq)
335 | print("Number of training examples = %d" % len(dataset_train))
336 |     print("Number of training steps per epoch = %d" % num_training_steps_per_epoch)
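    # e.g. with the defaults (batch_size=64, update_freq=1) on 8 GPUs the effective batch size is
    # 64 * 1 * 8 = 512, i.e. roughly 1.28M / 512 ~= 2500 optimizer steps per ImageNet epoch.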
337 |
338 |     if args.layer_decay != 1.0:
339 | num_layers = 12 # convnext layers divided into 12 parts, each with a different decayed lr value.
340 | assert args.model in ['convnext_small', 'convnext_base', 'convnext_large', 'convnext_xlarge'], \
341 | "Layer Decay impl only supports convnext_small/base/large/xlarge"
342 | assigner = LayerDecayValueAssigner(list(args.layer_decay ** (num_layers + 1 - i) for i in range(num_layers + 2)))
343 | else:
344 | assigner = None
345 |
346 | if assigner is not None:
347 | print("Assigned values = %s" % str(assigner.values))
348 |
349 | if args.distributed:
350 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False)
351 | model_without_ddp = model.module
352 |
353 | optimizer = create_optimizer(
354 | args, model_without_ddp, skip_list=None,
355 | get_num_layer=assigner.get_layer_id if assigner is not None else None,
356 | get_layer_scale=assigner.get_scale if assigner is not None else None)
357 |
358 | loss_scaler = NativeScaler() # if args.use_amp is False, this won't be used
359 |
360 | print("Use Cosine LR scheduler")
361 | lr_schedule_values = utils.cosine_scheduler(
362 | args.lr, args.min_lr, args.epochs, num_training_steps_per_epoch,
363 | warmup_epochs=args.warmup_epochs, warmup_steps=args.warmup_steps,
364 | )
365 |
366 | if args.weight_decay_end is None:
367 | args.weight_decay_end = args.weight_decay
368 | wd_schedule_values = utils.cosine_scheduler(
369 | args.weight_decay, args.weight_decay_end, args.epochs, num_training_steps_per_epoch)
370 | print("Max WD = %.7f, Min WD = %.7f" % (max(wd_schedule_values), min(wd_schedule_values)))
371 |
372 | if mixup_fn is not None:
373 | # smoothing is handled with mixup label transform
374 | criterion = SoftTargetCrossEntropy()
375 | elif args.smoothing > 0.:
376 | criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
377 | else:
378 | criterion = torch.nn.CrossEntropyLoss()
379 |
380 | print("criterion = %s" % str(criterion))
381 |
382 | utils.auto_load_model(
383 | args=args, model=model, model_without_ddp=model_without_ddp,
384 | optimizer=optimizer, loss_scaler=loss_scaler, model_ema=model_ema)
385 |
386 | if args.eval:
387 | print(f"Eval only mode")
388 | test_stats = evaluate(data_loader_val, model, device, use_amp=args.use_amp)
389 | print(f"Accuracy of the network on {len(dataset_val)} test images: {test_stats['acc1']:.5f}%")
390 | return
391 |
392 | max_accuracy = 0.0
393 | if args.model_ema and args.model_ema_eval:
394 | max_accuracy_ema = 0.0
395 |
396 | print("Start training for %d epochs" % args.epochs)
397 | start_time = time.time()
398 | for epoch in range(args.start_epoch, args.epochs):
399 | if args.distributed:
400 | data_loader_train.sampler.set_epoch(epoch)
401 | if log_writer is not None:
402 | log_writer.set_step(epoch * num_training_steps_per_epoch * args.update_freq)
403 | if wandb_logger:
404 | wandb_logger.set_steps()
405 | train_stats = train_one_epoch(
406 | model, criterion, data_loader_train, optimizer,
407 | device, epoch, loss_scaler, args.clip_grad, model_ema, mixup_fn,
408 | log_writer=log_writer, wandb_logger=wandb_logger, start_steps=epoch * num_training_steps_per_epoch,
409 | lr_schedule_values=lr_schedule_values, wd_schedule_values=wd_schedule_values,
410 | num_training_steps_per_epoch=num_training_steps_per_epoch, update_freq=args.update_freq,
411 | use_amp=args.use_amp
412 | )
413 | if args.output_dir and args.save_ckpt:
414 | if (epoch + 1) % args.save_ckpt_freq == 0 or epoch + 1 == args.epochs:
415 | utils.save_model(
416 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
417 | loss_scaler=loss_scaler, epoch=epoch, model_ema=model_ema)
418 | if data_loader_val is not None:
419 | test_stats = evaluate(data_loader_val, model, device, use_amp=args.use_amp)
420 | print(f"Accuracy of the model on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
421 | if max_accuracy < test_stats["acc1"]:
422 | max_accuracy = test_stats["acc1"]
423 | if args.output_dir and args.save_ckpt:
424 | utils.save_model(
425 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
426 | loss_scaler=loss_scaler, epoch="best", model_ema=model_ema)
427 | print(f'Max accuracy: {max_accuracy:.2f}%')
428 |
429 | if log_writer is not None:
430 | log_writer.update(test_acc1=test_stats['acc1'], head="perf", step=epoch)
431 | log_writer.update(test_acc5=test_stats['acc5'], head="perf", step=epoch)
432 | log_writer.update(test_loss=test_stats['loss'], head="perf", step=epoch)
433 |
434 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
435 | **{f'test_{k}': v for k, v in test_stats.items()},
436 | 'epoch': epoch,
437 | 'n_parameters': n_parameters}
438 |
439 | # repeat testing routines for EMA, if ema eval is turned on
440 | if args.model_ema and args.model_ema_eval:
441 | test_stats_ema = evaluate(data_loader_val, model_ema.ema, device, use_amp=args.use_amp)
442 | print(f"Accuracy of the model EMA on {len(dataset_val)} test images: {test_stats_ema['acc1']:.1f}%")
443 | if max_accuracy_ema < test_stats_ema["acc1"]:
444 | max_accuracy_ema = test_stats_ema["acc1"]
445 | if args.output_dir and args.save_ckpt:
446 | utils.save_model(
447 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
448 | loss_scaler=loss_scaler, epoch="best-ema", model_ema=model_ema)
449 | print(f'Max EMA accuracy: {max_accuracy_ema:.2f}%')
450 | if log_writer is not None:
451 | log_writer.update(test_acc1_ema=test_stats_ema['acc1'], head="perf", step=epoch)
452 | log_stats.update({**{f'test_{k}_ema': v for k, v in test_stats_ema.items()}})
453 | else:
454 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
455 | 'epoch': epoch,
456 | 'n_parameters': n_parameters}
457 |
458 | if args.output_dir and utils.is_main_process():
459 | if log_writer is not None:
460 | log_writer.flush()
461 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
462 | f.write(json.dumps(log_stats) + "\n")
463 |
464 | if wandb_logger:
465 | wandb_logger.log_epoch_metrics(log_stats)
466 |
467 | if wandb_logger and args.wandb_ckpt and args.save_ckpt and args.output_dir:
468 | wandb_logger.log_checkpoints()
469 |
470 |
471 | total_time = time.time() - start_time
472 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
473 | print('Training time {}'.format(total_time_str))
474 |
475 | if __name__ == '__main__':
476 | parser = argparse.ArgumentParser('ConvNeXt training and evaluation script', parents=[get_args_parser()])
477 | args = parser.parse_args()
478 | if args.output_dir:
479 | Path(args.output_dir).mkdir(parents=True, exist_ok=True)
480 | main(args)
481 |
--------------------------------------------------------------------------------
/models/dyn_perceiver_resnet.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | from dataclasses import dataclass
4 | from typing import Any, Optional, Tuple
5 |
6 | import torch
7 | from einops import rearrange, repeat
8 | from torch import Tensor
9 | import torch.nn as nn
10 |
11 | import torch.nn.functional as F
12 | import numpy as np
13 | # from .registry import register_model
14 | from timm.models.registry import register_model
15 | # from fairscale.nn import checkpoint_wrapper
16 | # from fvcore.nn import FlopCountAnalysis, parameter_count_table
17 |
18 | from .perceiver_core import (
19 | CrossAttentionLayer,
20 | SelfAttentionBlock,
21 | )
22 | from .cnn_core.resnet import *
23 |
24 |
25 | class DynPerceiver(nn.Module):
26 | def __init__(self,
27 | input_size: int=224,
28 | num_classes:int=1000,
29 | cnn_arch: str="resnet18",
30 | num_SA_heads: list=[1,2,4,8],
31 | num_latents: int=32,
32 | num_latent_channels: int=None,
33 | dropout: float = 0.0,
34 | SA_widening_factor: int=4,
35 | activation_checkpointing: bool = False,
36 | spatial_reduction: bool=True,
37 | depth_factor: list=[1,1,1,3],
38 | output_dir: str='./',
39 | with_x2z=True,
40 | with_z2x=True,
41 | with_dwc=True,
42 | with_last_CA=True,
43 | with_isc=True,
44 |
45 | drop_rate=0.2,
46 | drop_path_rate=0.2,
47 |
48 | exit=-1,
49 |
50 | **kwargs):
51 | super().__init__()
52 | self.exit = exit
53 | if num_SA_heads is None:
54 | num_SA_heads = [1,2,4,8]
55 | self.num_classes = num_classes
56 | cnn = eval(f'{cnn_arch}')(drop_rate=drop_rate,
57 | drop_path_rate=drop_path_rate)
58 | self.cnn_stem = nn.Sequential(
59 | cnn.conv1,
60 | cnn.bn1,
61 | cnn.act1,
62 | cnn.maxpool
63 | )
64 |
65 | self.cnn_body_stage1 = cnn.layer1
66 | self.cnn_body_stage2 = cnn.layer2
67 | self.cnn_body_stage3 = cnn.layer3
68 | self.cnn_body_stage4 = cnn.layer4
69 |
70 | # self.cnn_body_last_conv1x1 = cnn.blocks[6]
71 |
72 | num_blocks_per_stage = [3*depth_factor[0], 3*depth_factor[1], 9*depth_factor[2], 3*depth_factor[3]]
73 | self.avgpool = cnn.global_pool
74 | # self.cnn_head_before_cls = nn.Sequential(
75 | # cnn.conv_head,
76 | # cnn.act2
77 | # )
78 | # self.flatten_cnn = cnn.flatten
79 | self.drop_rate_cnn = cnn.drop_rate
80 | self.classifier_cnn = cnn.fc
81 |
82 | print(cnn.drop_rate)
83 |
84 | self.spatial_reduction = spatial_reduction
85 | if spatial_reduction:
86 | self.ca_pooling = nn.AdaptiveAvgPool2d((7,7))
87 | def cross_attn(num_cross_attention_heads, q_input_channels, kv_input_channels, num_cross_attention_qk_channels, num_cross_attention_v_channels, cross_attention_widening_factor,
88 | rpb=False,
89 | feat_w=112,
90 | feat_h=112):
91 | layer = CrossAttentionLayer(
92 | num_heads=num_cross_attention_heads,
93 | num_q_input_channels=q_input_channels,
94 | num_kv_input_channels=kv_input_channels,
95 | num_qk_channels=num_cross_attention_qk_channels,
96 | num_v_channels=num_cross_attention_v_channels,
97 | widening_factor=cross_attention_widening_factor,
98 | dropout=dropout,
99 |
100 | rpb=rpb,
101 | feat_w=feat_w,
102 | feat_h=feat_h,
103 | )
104 | return layer
105 |
106 | def self_attn(num_self_attention_layers_per_block, num_self_attention_heads, num_channels, num_self_attention_qk_channels, num_self_attention_v_channels, self_attention_widening_factor):
107 | return SelfAttentionBlock(
108 | num_layers=num_self_attention_layers_per_block,
109 | num_heads=num_self_attention_heads,
110 | num_channels=num_channels,
111 | num_qk_channels=num_self_attention_qk_channels,
112 | num_v_channels=num_self_attention_v_channels,
113 | widening_factor=self_attention_widening_factor,
114 | dropout=dropout,
115 | activation_checkpointing=activation_checkpointing,
116 | )
117 |
118 |
119 | # stage1
120 | x_channels_stage1in = cnn.layer1[0].conv1.in_channels
121 | x_channels_stage2in = cnn.layer2[0].conv1.in_channels
122 | x_channels_stage3in = cnn.layer3[0].conv1.in_channels
123 | x_channels_stage4in = cnn.layer4[0].conv1.in_channels
124 | x_channels_stage4out = cnn.layer4[-1].conv3.out_channels if isinstance(cnn.layer4[-1], Bottleneck) else cnn.layer4[-1].conv2.out_channels
125 | z_channels = [x_channels_stage1in, x_channels_stage2in, x_channels_stage3in, x_channels_stage4in]
126 | # print(z_channels)
127 | # assert(0==1)
128 | if num_latent_channels is None:
129 | num_latent_channels = x_channels_stage1in
130 | self.latent = nn.Parameter(torch.empty(num_latents, num_latent_channels))
131 |
132 | self.with_x2z = with_x2z
133 | self.with_z2x = with_z2x
134 | self.with_dwc = with_dwc
135 |
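        # with_dwc: a 7x7 depthwise conv, added residually to the feature map, gathers local
        # context before each x->z cross-attention; when spatial_reduction is set, the resulting
        # key/value map is average-pooled to 7x7 (ca_pooling) before attention.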
136 | if with_dwc:
137 | self.dwc1_x2z = nn.Conv2d(in_channels=x_channels_stage1in, out_channels=x_channels_stage1in, kernel_size=7,
138 | groups=x_channels_stage1in, stride=1, padding=3)
139 | feat_hw = 7 if spatial_reduction else input_size//2
140 |
141 | # essential
142 | self.cross_att1_x2z = cross_attn(num_cross_attention_heads=1,
143 | q_input_channels=x_channels_stage1in,
144 | kv_input_channels=x_channels_stage1in,
145 | num_cross_attention_qk_channels=None,
146 | num_cross_attention_v_channels=None,
147 | cross_attention_widening_factor=1,
148 |
149 | rpb=True,
150 | feat_w=feat_hw,
151 | feat_h=feat_hw,
152 | )
153 | self.self_att1 = self_attn(num_self_attention_layers_per_block=num_blocks_per_stage[0],
154 | num_self_attention_heads=num_SA_heads[0],
155 | num_channels=x_channels_stage1in,
156 | num_self_attention_qk_channels=None,
157 | num_self_attention_v_channels=None,
158 | self_attention_widening_factor=SA_widening_factor
159 | )
160 |
161 | # stage2
162 | if with_x2z:
163 | if with_dwc:
164 | self.dwc2_x2z = nn.Conv2d(in_channels=x_channels_stage2in, out_channels=x_channels_stage2in, kernel_size=7, groups=x_channels_stage2in, stride=1, padding=3)
165 | feat_hw = 7 if spatial_reduction else input_size//4
166 | self.cross_att2_x2z = cross_attn(num_cross_attention_heads=1,
167 | q_input_channels=x_channels_stage2in,
168 | kv_input_channels=x_channels_stage2in,
169 | num_cross_attention_qk_channels=None,
170 | num_cross_attention_v_channels=None,
171 | cross_attention_widening_factor=1,
172 |
173 | rpb=True,
174 | feat_w=feat_hw,
175 | feat_h=feat_hw,
176 | )
177 |
178 | if with_z2x:
179 | self.cross_att2_z2x = cross_attn(num_cross_attention_heads=1,
180 | q_input_channels=x_channels_stage2in,
181 | kv_input_channels=x_channels_stage2in,
182 | num_cross_attention_qk_channels=x_channels_stage2in//8,
183 | num_cross_attention_v_channels=x_channels_stage2in//8,
184 | cross_attention_widening_factor=1
185 | )
186 | self.self_att2 = self_attn(num_self_attention_layers_per_block=num_blocks_per_stage[1],
187 | num_self_attention_heads=num_SA_heads[1],
188 | num_channels=x_channels_stage2in,
189 | num_self_attention_qk_channels=None,
190 | num_self_attention_v_channels=None,
191 | self_attention_widening_factor=SA_widening_factor
192 | )
193 |
194 | # stage3
195 | if with_x2z:
196 | if with_dwc:
197 | self.dwc3_x2z = nn.Conv2d(in_channels=x_channels_stage3in, out_channels=x_channels_stage3in, kernel_size=7, groups=x_channels_stage3in, stride=1, padding=3)
198 | feat_hw = 7 if spatial_reduction else input_size//8
199 | self.cross_att3_x2z = cross_attn(num_cross_attention_heads=1,
200 | q_input_channels=x_channels_stage3in,
201 | kv_input_channels=x_channels_stage3in,
202 | num_cross_attention_qk_channels=None,
203 | num_cross_attention_v_channels=None,
204 | cross_attention_widening_factor=1,
205 |
206 | rpb=True,
207 | feat_w=feat_hw,
208 | feat_h=feat_hw
209 | )
210 |
211 | if with_z2x:
212 | self.cross_att3_z2x = cross_attn(num_cross_attention_heads=1,
213 | q_input_channels=x_channels_stage3in,
214 | kv_input_channels=x_channels_stage3in,
215 | num_cross_attention_qk_channels=x_channels_stage3in//8,
216 | num_cross_attention_v_channels=x_channels_stage3in//8,
217 | cross_attention_widening_factor=1
218 | )
219 | self.self_att3 = self_attn(num_self_attention_layers_per_block=num_blocks_per_stage[2],
220 | num_self_attention_heads=num_SA_heads[2],
221 | num_channels=x_channels_stage3in,
222 | num_self_attention_qk_channels=None,
223 | num_self_attention_v_channels=None,
224 | self_attention_widening_factor=SA_widening_factor
225 | )
226 |
227 | # stage4
228 | if with_x2z:
229 | if with_dwc:
230 | self.dwc4_x2z = nn.Conv2d(in_channels=x_channels_stage4in, out_channels=x_channels_stage4in, kernel_size=7, groups=x_channels_stage4in, stride=1, padding=3)
231 | feat_hw = 7 if spatial_reduction else input_size//16
232 | self.cross_att4_x2z = cross_attn(num_cross_attention_heads=1,
233 | q_input_channels=x_channels_stage4in,
234 | kv_input_channels=x_channels_stage4in,
235 | num_cross_attention_qk_channels=None,
236 | num_cross_attention_v_channels=None,
237 | cross_attention_widening_factor=1,
238 |
239 | rpb=True,
240 | feat_w=feat_hw,
241 | feat_h=feat_hw
242 | )
243 |
244 | if with_z2x:
245 | # print(x_channels_stage4in)
246 | self.cross_att4_z2x = cross_attn(num_cross_attention_heads=1,
247 | q_input_channels=x_channels_stage4in,
248 | kv_input_channels=x_channels_stage4in,
249 | num_cross_attention_qk_channels=x_channels_stage4in//8,
250 | num_cross_attention_v_channels=x_channels_stage4in//8,
251 | cross_attention_widening_factor=1
252 | )
253 | # print(num_blocks_per_stage[3])
254 | self.self_att4 = self_attn(num_self_attention_layers_per_block=num_blocks_per_stage[3],
255 | num_self_attention_heads=num_SA_heads[3],
256 | num_channels=x_channels_stage4in,
257 | num_self_attention_qk_channels=None,
258 | num_self_attention_v_channels=None,
259 | self_attention_widening_factor=SA_widening_factor
260 | )
261 |
262 | # last cross attention
263 | # print(x_channels_stage4out//8)
264 | # print(x_channels_stage4out, x_channels_stage4in)
265 | self.last_cross_att_z2x = cross_attn(num_cross_attention_heads=1,
266 | q_input_channels=x_channels_stage4out,
267 | kv_input_channels=x_channels_stage4in,
268 | num_cross_attention_qk_channels=x_channels_stage4out//8,
269 | num_cross_attention_v_channels=x_channels_stage4out//8,
270 | cross_attention_widening_factor=1
271 | ) if with_last_CA else None
272 |
273 | # self.classifier_cnn = cnn.classifier
274 | # self.classifier_flops = cnn.classifier_flops
275 | # print(x_channels_stage1in, x_channels_stage2in, x_channels_stage3in, x_channels_stage4in)
276 | # self.early_classifier1 = nn.Linear(x_channels_stage1in, num_classes)
277 | # self.early_classifier2 = nn.Linear(x_channels_stage2in, num_classes)
278 |
279 | self.early_classifier3 = nn.Linear(x_channels_stage3in, num_classes)
280 | self.with_isc = with_isc
281 |
282 | if not with_isc:
283 | self.classifier_att = nn.Linear(x_channels_stage4in, num_classes)
284 |
285 | self.classifier_merge = nn.Sequential(
286 | nn.BatchNorm1d(cnn.num_features + x_channels_stage4in),
287 | nn.Linear(cnn.num_features + x_channels_stage4in, num_classes)
288 | )
289 | else:
290 |
291 | self.isc3 = nn.Sequential(nn.Linear(num_classes, x_channels_stage4in),
292 | nn.BatchNorm1d(x_channels_stage4in),
293 | nn.ReLU(inplace=True)
294 | )
295 |
296 | self.classifier_att = nn.Linear(2*x_channels_stage4in, num_classes)
297 |
298 | self.isc4 = nn.Sequential(nn.Linear(num_classes, x_channels_stage4in),
299 | nn.BatchNorm1d(x_channels_stage4in),
300 | nn.ReLU(inplace=True)
301 | )
302 | self.classifier_merge = nn.Sequential(
303 | nn.BatchNorm1d(cnn.num_features + 2*x_channels_stage4in),
304 | nn.Linear(cnn.num_features + 2*x_channels_stage4in, num_classes)
305 | )
306 |
307 | expander = []
308 | token_mixer = []
309 |
310 | num_latents_list = [num_latents, num_latents//2, num_latents//4, num_latents//8]
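        # Between consecutive stages the latent array is reshaped in two steps: token_mixer halves
        # the number of latent tokens (a Linear applied over the token dimension) and token_expander
        # projects the latent channels up to the next CNN stage's width.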
311 | for i in range(3):
312 | c_in = z_channels[i]
313 | c_out = z_channels[i+1]
314 | expander.append(nn.Sequential(
315 | nn.LayerNorm(c_in),
316 | nn.Linear(c_in, c_out)
317 | ))
318 | n_z_in = num_latents_list[i]
319 | n_z_out = num_latents_list[i+1]
320 | token_mixer.append(nn.Sequential(
321 | nn.LayerNorm(n_z_in),
322 | nn.Linear(n_z_in, n_z_out)
323 | ))
324 |
325 | self.token_expander = nn.ModuleList(expander)
326 | self.token_mixer = nn.ModuleList(token_mixer)
327 | self.output_dir = output_dir
328 | self._init_parameters()
329 |
330 |
331 | def _init_parameters(self):
332 | with torch.no_grad():
333 | self.latent.normal_(0.0, 0.02).clamp_(-2.0, 2.0)
334 |
335 | def forward(self, x, pad_mask=None):
336 | exit = self.exit
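        # Early-exit behaviour: exit == 0 returns y_early3 (latent classifier after stage 3),
        # exit == 1 returns y_att (attention classifier after stage 4), exit == 2 returns y_cnn
        # (CNN classifier); any other value runs the full model and returns all four predictions.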
337 | #TODO
338 | b, c_in, _, _ = x.shape
339 | x_latent = repeat(self.latent, "... -> b ...", b=b)
340 |
341 | x = self.cnn_stem(x)
342 | # print(x.shape)
343 |
344 | if self.with_dwc:
345 | x_kv = self.dwc1_x2z(x) + x
346 | x_kv = self.ca_pooling(x_kv)
347 | else:
348 | x_kv = self.ca_pooling(x)
349 | x_kv = rearrange(x_kv, "b c ... -> b (...) c")
350 | # print(x_latent.shape, x_kv.shape)
351 | x_latent = self.cross_att1_x2z(x_latent, x_kv, pad_mask)
352 |
353 | # stage1, conv and self attention
354 | x_latent = self.self_att1(x_latent)
355 |
356 | # y_early1 = torch.mean(x_latent, dim=1).squeeze(1)
357 | # y_early1 = self.early_classifier1(y_early1)
358 |
359 |
360 | x = self.cnn_body_stage1(x)
361 |
362 | # between stage1 and stage2
363 | _, n_tokens, c_in = x_latent.shape
364 | x_latent = x_latent.permute(0,2,1)
365 | x_latent = self.token_mixer[0](x_latent)
366 | x_latent = x_latent.permute(0,2,1)
367 | x_latent = self.token_expander[0](x_latent)
368 |
369 | # transformer to conv
370 | if self.with_z2x:
371 | _,_,h,w = x.shape
372 | x = rearrange(x, "b c ... -> b (...) c")
373 | x = self.cross_att2_z2x(x, x_latent, pad_mask)
374 | x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
375 |
376 | # conv to transformer
377 | # print(x.shape)
378 | if self.with_x2z:
379 | if self.with_dwc:
380 | x_kv = self.dwc2_x2z(x) + x
381 | x_kv = self.ca_pooling(x_kv)
382 | else:
383 | x_kv = self.ca_pooling(x)
384 | x_kv = rearrange(x_kv, "b c ... -> b (...) c")
385 | x_latent = self.cross_att2_x2z(x_latent, x_kv, pad_mask)
386 |
387 | # stage2
388 | x_latent = self.self_att2(x_latent)
389 | # y_early2 = torch.mean(x_latent, dim=1).squeeze(1)
390 | # y_early2 = self.early_classifier2(y_early2)
391 |
392 | x = self.cnn_body_stage2(x)
393 |
394 | # between stage2 and stage3
395 | _, n_tokens, c_in = x_latent.shape
396 | x_latent = x_latent.permute(0,2,1)
397 | x_latent = self.token_mixer[1](x_latent)
398 | x_latent = x_latent.permute(0,2,1)
399 | x_latent = self.token_expander[1](x_latent)
400 |
401 | # transformer to conv
402 | if self.with_z2x:
403 | _,_,h,w = x.shape
404 | x = rearrange(x, "b c ... -> b (...) c")
405 | x = self.cross_att3_z2x(x, x_latent, pad_mask)
406 | x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
407 |
408 | # conv to transformer
409 | # print(x.shape)
410 | if self.with_x2z:
411 | if self.with_dwc:
412 | x_kv = self.dwc3_x2z(x) + x
413 | x_kv = self.ca_pooling(x_kv)
414 | else:
415 | x_kv = self.ca_pooling(x)
416 | x_kv = rearrange(x_kv, "b c ... -> b (...) c")
417 | x_latent = self.cross_att3_x2z(x_latent, x_kv, pad_mask)
418 |
419 | # stage3
420 | x_latent = self.self_att3(x_latent)
421 | y_early3 = torch.mean(x_latent, dim=1).squeeze(1)
422 | y_early3 = self.early_classifier3(y_early3)
423 | if exit == 0:
424 | return y_early3
425 |
426 | x = self.cnn_body_stage3(x)
427 |
428 | # between stage3 and stage4
429 | _, n_tokens, c_in = x_latent.shape
430 | x_latent = x_latent.permute(0,2,1)
431 | x_latent = self.token_mixer[2](x_latent)
432 | x_latent = x_latent.permute(0,2,1)
433 | x_latent = self.token_expander[2](x_latent)
434 |
435 | # transformer to conv
436 | if self.with_z2x:
437 | _,_,h,w = x.shape
438 | x = rearrange(x, "b c ... -> b (...) c")
439 | # print(x_latent.shape, x.shape)
440 | x = self.cross_att4_z2x(x, x_latent, pad_mask)
441 | x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
442 |
443 |
444 | # conv to transformer
445 | # print(x.shape)
446 | if self.with_x2z:
447 |
448 | if self.with_dwc:
449 | x_kv = self.dwc4_x2z(x) + x
450 | x_kv = self.ca_pooling(x_kv)
451 | else:
452 | x_kv = self.ca_pooling(x)
453 | x_kv = rearrange(x_kv, "b c ... -> b (...) c")
454 | x_latent = self.cross_att4_x2z(x_latent, x_kv, pad_mask)
455 |
456 | # stage4
457 | x_latent = self.self_att4(x_latent)
458 |
459 | x_latent_mean = torch.mean(x_latent, dim=1).squeeze(1)
460 | if self.with_isc:
461 | y3_ = self.isc3(y_early3)
462 | y_att = torch.cat((x_latent_mean, y3_), dim=1)
463 | y_att = self.classifier_att(y_att)
464 | else:
465 | y_att = self.classifier_att(x_latent_mean)
466 | if exit == 1:
467 | return y_att
468 |
469 | x = self.cnn_body_stage4(x)
470 | # print(x.shape)
471 |
472 | # x = self.cnn_body_last_conv1x1(x)
473 | x_mean = self.avgpool(x)
474 | # x_mean = self.cnn_head_before_cls(x_mean)
475 | # x_mean = self.flatten_cnn(x_mean)
476 |
477 | if self.drop_rate_cnn > 0.:
478 | x_mean = F.dropout(x_mean, p=self.drop_rate_cnn, training=self.training)
479 |
480 | # print(x_mean.shape)
481 | y_cnn = self.classifier_cnn(x_mean)
482 | if exit == 2:
483 | return y_cnn
484 |
485 | if self.last_cross_att_z2x is not None:
486 | _,_,h,w = x.shape
487 | x = rearrange(x, "b c ... -> b (...) c")
488 | x = self.last_cross_att_z2x(x, x_latent, pad_mask)
489 | x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
490 | x_mean = self.avgpool(x)
491 | # x_mean = self.cnn_head_before_cls(x_mean)
492 | # x_mean = self.flatten_cnn(x_mean)
493 |
494 | if self.with_isc:
495 | y4_ = self.isc4(y_att)
496 | x_merge = torch.cat((x_mean, x_latent_mean, y4_), dim=1)
497 | y_merge = self.classifier_merge(x_merge)
498 | else:
499 | x_merge = torch.cat((x_mean, x_latent_mean), dim=1)
500 | y_merge = self.classifier_merge(x_merge)
501 |
502 | return y_early3, y_att, y_cnn, y_merge
503 |
504 |
505 | @register_model
506 | def resnet18_perceiver_t128(**kwargs):
507 | model = DynPerceiver(num_latents=128, cnn_arch='resnet18', **kwargs)
508 | return model
509 |
510 |
511 | @register_model
512 | def resnet18_perceiver_t256(**kwargs):
513 | model = DynPerceiver(num_latents=256, cnn_arch='resnet18', **kwargs)
514 | return model
515 |
516 |
517 | @register_model
518 | def resnet18_perceiver_t512(**kwargs):
519 | model = DynPerceiver(num_latents=512, cnn_arch='resnet18', **kwargs)
520 | return model
521 |
522 |
523 | @register_model
524 | def resnet34_perceiver_t128(**kwargs):
525 | model = DynPerceiver(num_latents=128, cnn_arch='resnet34', **kwargs)
526 | return model
527 |
528 |
529 | @register_model
530 | def resnet34_perceiver_t256(**kwargs):
531 | model = DynPerceiver(num_latents=256, cnn_arch='resnet34', **kwargs)
532 | return model
533 |
534 |
535 | @register_model
536 | def resnet34_perceiver_t512(**kwargs):
537 | model = DynPerceiver(num_latents=512, cnn_arch='resnet34', **kwargs)
538 | return model
539 |
540 |
541 | @register_model
542 | def resnet50_perceiver_t64(**kwargs):
543 | model = DynPerceiver(num_latents=64, cnn_arch='resnet50', **kwargs)
544 | return model
545 |
546 |
547 | @register_model
548 | def resnet50_perceiver_t128(**kwargs):
549 | model = DynPerceiver(num_latents=128, cnn_arch='resnet50', **kwargs)
550 | return model
551 |
552 |
553 | @register_model
554 | def resnet50_perceiver_t256(**kwargs):
555 | model = DynPerceiver(num_latents=256, cnn_arch='resnet50', **kwargs)
556 | return model
557 |
558 |
559 | @register_model
560 | def resnet50_perceiver_t512(**kwargs):
561 | model = DynPerceiver(num_latents=512, cnn_arch='resnet50', **kwargs)
562 | return model
563 |
564 |
565 | @register_model
566 | def resnet50_050_perceiver_t64(**kwargs):
567 | model = DynPerceiver(num_latents=64, cnn_arch='resnet50_050', **kwargs)
568 | return model
569 |
570 |
571 | @register_model
572 | def resnet50_050_perceiver_t128(**kwargs):
573 | model = DynPerceiver(num_latents=128, cnn_arch='resnet50_050', **kwargs)
574 | return model
575 |
576 |
577 | @register_model
578 | def resnet50_050_perceiver_t256(**kwargs):
579 | model = DynPerceiver(num_latents=256, cnn_arch='resnet50_050', **kwargs)
580 | return model
581 |
582 |
583 | @register_model
584 | def resnet50_0375_perceiver_t64(**kwargs):
585 | model = DynPerceiver(num_latents=64, cnn_arch='resnet50_0375', **kwargs)
586 | return model
587 |
588 |
589 | @register_model
590 | def resnet50_0375_perceiver_t128(**kwargs):
591 | model = DynPerceiver(num_latents=128, cnn_arch='resnet50_0375', **kwargs)
592 | return model
593 |
594 |
595 | @register_model
596 | def resnet50_0375_perceiver_t256(**kwargs):
597 | model = DynPerceiver(num_latents=256, cnn_arch='resnet50_0375', **kwargs)
598 | return model
599 |
600 |
601 | @register_model
602 | def resnet50_0625_perceiver_t128(**kwargs):
603 | model = DynPerceiver(num_latents=128, cnn_arch='resnet50_0625', **kwargs)
604 | return model
605 |
606 |
607 | @register_model
608 | def resnet50_0625_perceiver_t160(**kwargs):
609 | model = DynPerceiver(num_latents=160, cnn_arch='resnet50_0625', **kwargs)
610 | return model
611 |
612 |
613 | @register_model
614 | def resnet50_0625_perceiver_t192(**kwargs):
615 | model = DynPerceiver(num_latents=192, cnn_arch='resnet50_0625', **kwargs)
616 | return model
617 |
618 |
619 | @register_model
620 | def resnet50_0625_perceiver_t256(**kwargs):
621 | model = DynPerceiver(num_latents=256, cnn_arch='resnet50_0625', **kwargs)
622 | return model
623 |
624 |
625 | @register_model
626 | def resnet50_075_perceiver_t128(**kwargs):
627 | model = DynPerceiver(num_latents=128, cnn_arch='resnet50_075', **kwargs)
628 | return model
629 |
630 |
631 | @register_model
632 | def resnet50_075_perceiver_t256(**kwargs):
633 | model = DynPerceiver(num_latents=256, cnn_arch='resnet50_075', **kwargs)
634 | return model
635 |
636 |
637 | def dyn_flops():
638 | x = torch.rand(1, 3, 224, 224)
639 | result = []
640 | for i in range(4,5):
641 | model = resnet50_0375_perceiver_t256(
642 | depth_factor=[1,1,1,1], SA_widening_factor=4, spatial_reduction=True,
643 | with_last_CA=True, with_x2z=True, with_dwc=True, with_z2x=True, exit=i)
644 | model.eval()
645 | from fvcore.nn import FlopCountAnalysis
646 | flops = FlopCountAnalysis(model, x)
647 | result.append(flops.total() / 1e9)
648 | print('***************************')
649 | for flop in result:
650 | print(flop)
651 | print('***************************')
652 |
653 |
654 | if __name__ == '__main__':
655 |
656 | dyn_flops()
657 |
658 | # # Fourier-encodes pixel positions and flatten along spatial dimensions
659 | # input_adapter = ImageInputAdapter(
660 | # image_shape=(224, 224, 3), # M = 224 * 224
661 | # num_frequency_bands=64,
662 | # )
663 |
664 | # # Projects generic Perceiver decoder output to specified number of classes
665 | # output_adapter = ClassificationOutputAdapter(
666 | # num_classes=1000,
667 | # num_output_query_channels=1024, # F
668 | # )
669 |
670 | # # Generic Perceiver encoder
671 | # encoder = PerceiverEncoder(
672 | # input_adapter=input_adapter,
673 | # num_latents=512, # N
674 | # num_latent_channels=1024, # D
675 | # num_cross_attention_qk_channels=input_adapter.num_input_channels, # C
676 | # num_cross_attention_heads=1,
677 | # num_self_attention_heads=4,
678 | # num_self_attention_layers_per_block=6,
679 | # num_self_attention_blocks=8,
680 | # dropout=0.0,
681 | # )
682 |
683 | # # Generic Perceiver decoder
684 | # decoder = PerceiverDecoder(
685 | # output_adapter=output_adapter,
686 | # num_latent_channels=1024, # D
687 | # num_cross_attention_heads=1,
688 | # dropout=0.0,
689 | # )
690 |
691 | # # Perceiver IO image classifier
692 | # model = PerceiverIO(encoder, decoder)
693 | # model.eval()
694 | # print(model)
695 | # x = torch.rand(4,224,224,3)
696 | # with torch.no_grad():
697 | # y = model(x)
698 |
699 | # print(y.shape)
700 |
701 |
702 |
703 |
704 |
705 | # regnet = regnet_y_400mf()
706 | # print(regnet)
707 | # print(regnet.trunk_output.block1)
708 |
709 | def count_parameters(model):
710 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
711 |
712 | model = resnet50_perceiver_t128(depth_factor=[1,1,1,3], SA_widening_factor=4, spatial_reduction=True, with_last_CA=True,
713 | with_x2z=True, with_dwc=True, with_z2x=True)
714 | model.eval()
715 | print(count_parameters(model)/1e6)
716 | x = torch.rand(1,3,224,224)
717 | with torch.no_grad():
718 | y = model(x)
719 | # print(y.shape)
720 | # print()
721 | # flops = FlopCountAnalysis(model, x)
722 | # print("FLOPs: ", flops.total()/1e9)
723 | # from fvcore.nn import flop_count_str
724 | # print(flop_count_str(flops))
725 |     # # analyze parameters
726 | # print(parameter_count_table(model))
--------------------------------------------------------------------------------
/models/perceiver_core/modules.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import torch
4 | import torch.nn as nn
5 | from einops import rearrange, repeat
6 | # from fairscale.nn import checkpoint_wrapper
7 | from torch import Tensor
8 |
9 | from .utils import Sequential
10 | from timm.models.layers import trunc_normal_
11 |
12 | class MultiHeadAttention(nn.Module):
13 | def __init__(
14 | self,
15 | num_heads: int,
16 | num_q_input_channels: int,
17 | num_kv_input_channels: int,
18 | num_qk_channels: Optional[int] = None,
19 | num_v_channels: Optional[int] = None,
20 | num_output_channels: Optional[int] = None,
21 | dropout: float = 0.0,
22 | rpb=False,
23 | feat_w=32,
24 | feat_h=32
25 | ):
26 | """Multi-head attention as described in https://arxiv.org/abs/2107.14795 Appendix E.
27 |
28 | :param num_heads: Number of attention heads.
29 | :param num_q_input_channels: Number of query input channels.
30 | :param num_kv_input_channels: Number of key/value input channels.
31 |         :param num_qk_channels: Number of channels query and key inputs are projected to for
32 |             computing the attention matrix. Defaults to `num_q_input_channels`.
33 |         :param num_v_channels: Number of channels value inputs are projected to. Defaults to `num_qk_channels`.
34 |         :param num_output_channels: Number of output channels the attention result is projected to.
35 |             Defaults to `num_q_input_channels`.
36 |         :param dropout: Dropout probability for attention matrix values. Defaults to `0.0`.
37 |         :param rpb: If `True`, add a learned relative position bias of shape `(1, feat_w * feat_h, num_qk_channels)` to the projected keys; `feat_w` and `feat_h` give the key/value feature map size.
38 | """
39 | super().__init__()
40 |
41 | if num_qk_channels is None:
42 | num_qk_channels = num_q_input_channels
43 |
44 | if num_v_channels is None:
45 | num_v_channels = num_qk_channels
46 |
47 | if num_output_channels is None:
48 | num_output_channels = num_q_input_channels
49 |
50 | if num_qk_channels % num_heads != 0:
51 | raise ValueError("num_qk_channels must be divisible by num_heads")
52 |
53 | if num_v_channels % num_heads != 0:
54 | raise ValueError("num_v_channels must be divisible by num_heads")
55 |
56 |
57 | # print(f'attention, qk channels {num_qk_channels}, v channels {num_v_channels}')
58 |
59 |
60 | self.rpb=rpb
61 |
62 | if rpb:
63 | self.feat_w=feat_w
64 | self.feat_h=feat_h
65 | self.relative_position_bias = nn.Parameter(
66 | torch.zeros((1, self.feat_w*self.feat_h, num_qk_channels))
67 | ) # same size as k (feature map)
68 |
69 | num_qk_channels_per_head = num_qk_channels // num_heads
70 |
71 | self.dp_scale = num_qk_channels_per_head ** -0.5
72 | self.num_heads = num_heads
73 |
74 | self.q_proj = nn.Linear(num_q_input_channels, num_qk_channels, bias=False)
75 | self.k_proj = nn.Linear(num_kv_input_channels, num_qk_channels, bias=False)
76 | self.v_proj = nn.Linear(num_kv_input_channels, num_v_channels, bias=False)
77 | self.o_proj = nn.Linear(num_v_channels, num_output_channels)
78 | self.dropout = nn.Dropout(dropout)
79 |
80 | def forward(self, x_q, x_kv, pad_mask=None, attn_mask=None):
81 | """
82 | :param x_q: Query input of shape (B, N, D) where B is the batch size, N the query sequence length
83 | and D the number of query input channels (= `num_q_input_channels`)
84 | :param x_kv: Key/value input of shape (B, L, C) where B is the batch size, L the key/value sequence
85 |             length and C is the number of key/value input channels (= `num_kv_input_channels`)
86 | :param pad_mask: Boolean key padding mask. `True` values indicate padding tokens.
87 | :param attn_mask: Boolean attention mask. Not needed/supported yet.
88 | :return: attention result of shape (B, N, F) where B is the batch size, N the query sequence length
89 | and F the number of output channels (= `num_output_channels`)
90 | """
91 | if attn_mask is not None:
92 | raise NotImplementedError("attention masks not supported yet")
93 |
94 | # flops = 0
95 | b = x_q.shape[0]
96 | q_in_channels = x_q.shape[-1]
97 | kv_in_channels = x_kv.shape[-1]
98 |
99 | q = self.q_proj(x_q)
100 | k = self.k_proj(x_kv)
101 | v = self.v_proj(x_kv)
102 | # print("x_q.shape:", x_q.shape)
103 | # flops += q_in_channels * q.shape[-1] * q.shape[-2] + 2*kv_in_channels * k.shape[-1] * k.shape[-2]
104 | # print(k.shape)
105 | if self.rpb:
106 | k += self.relative_position_bias
107 |
108 | q, k, v = (rearrange(x, "b n (h c) -> (b h) n c", h=self.num_heads) for x in [q, k, v])
109 | attn = torch.einsum("b i c, b j c -> b i j", q, k) * self.dp_scale
110 | # attn = q @ k.transpose(-2, -1) * self.dp_scale
111 | # print(attn.shape)
112 | # flops += q.shape[-1] * q.shape[-2] * k.shape[-2] * self.num_heads
113 |
114 | if pad_mask is not None:
115 | pad_mask = repeat(pad_mask, "b j -> (b h) () j", h=self.num_heads)
116 | attn_max_neg = -torch.finfo(attn.dtype).max
117 | attn.masked_fill_(pad_mask, attn_max_neg)
118 |
119 | attn = attn.softmax(dim=-1)
120 | attn = self.dropout(attn)
121 |
122 | o = torch.einsum("b i j, b j c -> b i c", attn, v)
123 | # o = attn @ v
124 | # print(attn.shape, v.shape, o.shape)
125 | # flops += attn.shape[-2] * v.shape[-2] * v.shape[-1] * self.num_heads
126 | o = rearrange(o, "(b h) n c -> b n (h c)", h=self.num_heads)
127 | in_channels = o.shape[-1]
128 | o = self.o_proj(o)
129 | # print(o.shape)
130 | # flops += in_channels * o.shape[-1] * o.shape[-2]
131 | return o
132 |
133 | def forward_calc_flops(self, x_q, x_kv, pad_mask=None, attn_mask=None):
134 | """
135 | :param x_q: Query input of shape (B, N, D) where B is the batch size, N the query sequence length
136 | and D the number of query input channels (= `num_q_input_channels`)
137 | :param x_kv: Key/value input of shape (B, L, C) where B is the batch size, L the key/value sequence
138 |             length and C is the number of key/value input channels (= `num_kv_input_channels`)
139 | :param pad_mask: Boolean key padding mask. `True` values indicate padding tokens.
140 | :param attn_mask: Boolean attention mask. Not needed/supported yet.
141 | :return: attention result of shape (B, N, F) where B is the batch size, N the query sequence length
142 | and F the number of output channels (= `num_output_channels`)
143 | """
144 | if attn_mask is not None:
145 | raise NotImplementedError("attention masks not supported yet")
146 |
147 | flops = 0
148 | b = x_q.shape[0]
149 | q_in_channels = x_q.shape[-1]
150 | kv_in_channels = x_kv.shape[-1]
151 |
152 | q = self.q_proj(x_q)
153 | k = self.k_proj(x_kv)
154 | v = self.v_proj(x_kv)
155 | # print("x_q.shape:", x_q.shape)
156 | flops += q_in_channels * q.shape[-1] * q.shape[-2] + 2*kv_in_channels * k.shape[-1] * k.shape[-2]
157 | # print(k.shape)
158 | if self.rpb:
159 | k += self.relative_position_bias
160 |
161 | q, k, v = (rearrange(x, "b n (h c) -> (b h) n c", h=self.num_heads) for x in [q, k, v])
162 | attn = torch.einsum("b i c, b j c -> b i j", q, k) * self.dp_scale
163 | # attn = q @ k.transpose(-2, -1) * self.dp_scale
164 | # print(attn.shape)
165 | flops += q.shape[-1] * q.shape[-2] * k.shape[-2] * self.num_heads
166 |
167 | if pad_mask is not None:
168 | pad_mask = repeat(pad_mask, "b j -> (b h) () j", h=self.num_heads)
169 | attn_max_neg = -torch.finfo(attn.dtype).max
170 | attn.masked_fill_(pad_mask, attn_max_neg)
171 |
172 | attn = attn.softmax(dim=-1)
173 | attn = self.dropout(attn)
174 |
175 | o = torch.einsum("b i j, b j c -> b i c", attn, v)
176 | # o = attn @ v
177 | # print(attn.shape, v.shape, o.shape)
178 | flops += attn.shape[-2] * v.shape[-2] * v.shape[-1] * self.num_heads
179 | o = rearrange(o, "(b h) n c -> b n (h c)", h=self.num_heads)
180 | in_channels = o.shape[-1]
181 | o = self.o_proj(o)
182 | # print(o.shape)
183 | flops += in_channels * o.shape[-1] * o.shape[-2]
184 | return o, flops
185 |
186 | class CrossAttention(nn.Module):
187 | def __init__(
188 | self,
189 | num_heads: int,
190 | num_q_input_channels: int,
191 | num_kv_input_channels: int,
192 | num_qk_channels: Optional[int] = None,
193 | num_v_channels: Optional[int] = None,
194 | dropout: float = 0.0,
195 |
196 | rpb=False,
197 | feat_w=32,
198 | feat_h=32
199 | ):
200 | """Multi-head cross-attention (see `MultiHeadAttention` for details)."""
201 | super().__init__()
202 |
203 | self.q_norm = nn.LayerNorm(num_q_input_channels)
204 | self.kv_norm = nn.LayerNorm(num_kv_input_channels)
205 | self.attention = MultiHeadAttention(
206 | num_heads=num_heads,
207 | num_q_input_channels=num_q_input_channels,
208 | num_kv_input_channels=num_kv_input_channels,
209 | num_qk_channels=num_qk_channels,
210 | num_v_channels=num_v_channels,
211 | dropout=dropout,
212 |
213 | rpb=rpb,
214 | feat_w=feat_w,
215 | feat_h=feat_h
216 | )
217 |
218 | def forward(self, x_q, x_kv, pad_mask=None, attn_mask=None):
219 | """Multi-head attention of query input `x_q` to key/value input (`x_kv`) after (separately) applying layer
220 | normalization to these inputs."""
221 | x_q = self.q_norm(x_q)
222 | x_kv = self.kv_norm(x_kv)
223 | return self.attention(x_q, x_kv, pad_mask=pad_mask, attn_mask=attn_mask)
224 |
225 | def forward_calc_flops(self, x_q, x_kv, pad_mask=None, attn_mask=None):
226 | """Multi-head attention of query input `x_q` to key/value input (`x_kv`) after (separately) applying layer
227 | normalization to these inputs."""
228 | x_q = self.q_norm(x_q)
229 | x_kv = self.kv_norm(x_kv)
230 | return self.attention.forward_calc_flops(x_q, x_kv, pad_mask=pad_mask, attn_mask=attn_mask)
231 |
232 |
233 | class SelfAttention(nn.Module):
234 | def __init__(
235 | self,
236 | num_heads: int,
237 | num_channels: int,
238 | num_qk_channels: Optional[int] = None,
239 | num_v_channels: Optional[int] = None,
240 | dropout: float = 0.0,
241 | ):
242 |         """Multi-head self-attention (see `MultiHeadAttention` for details)."""
243 | super().__init__()
244 | self.norm = nn.LayerNorm(num_channels)
245 | self.attention = MultiHeadAttention(
246 | num_heads=num_heads,
247 | num_q_input_channels=num_channels,
248 | num_kv_input_channels=num_channels,
249 | num_qk_channels=num_qk_channels,
250 | num_v_channels=num_v_channels,
251 | dropout=dropout,
252 | )
253 |
254 | def forward(self, x, pad_mask=None, attn_mask=None):
255 | """Multi-head attention of input `x` to itself after applying layer normalization to the input."""
256 | x = self.norm(x)
257 | return self.attention(x, x, pad_mask=pad_mask, attn_mask=attn_mask)
258 |
259 | def forward_calc_flops(self, x, pad_mask=None, attn_mask=None):
260 | """Multi-head attention of input `x` to itself after applying layer normalization to the input."""
261 | x = self.norm(x)
262 | return self.attention.forward_calc_flops(x, x, pad_mask=pad_mask, attn_mask=attn_mask)
263 |
264 |
265 | class CrossAttentionLayer(nn.Module):
266 | def __init__(
267 | self,
268 | num_heads: int,
269 | num_q_input_channels: int,
270 | num_kv_input_channels: int,
271 | num_qk_channels: Optional[int] = None,
272 | num_v_channels: Optional[int] = None,
273 | widening_factor: int = 1,
274 | dropout: float = 0.0,
275 | attention_residual: bool = True,
276 | rpb=False,
277 | feat_w=32,
278 | feat_h=32
279 | ):
281 | super().__init__()
285 |
286 | self.cross_attn = CrossAttention(
287 | num_heads=num_heads,
288 | num_q_input_channels=num_q_input_channels,
289 | num_kv_input_channels=num_kv_input_channels,
290 | num_qk_channels=num_qk_channels,
291 | num_v_channels=num_v_channels,
292 | dropout=dropout,
293 | rpb=rpb,
294 | feat_w=feat_w,
295 | feat_h=feat_h
296 | )
297 | self.attention_residual = attention_residual
298 | self.mlp = MLP(num_q_input_channels, widening_factor)
299 | self.dropout = nn.Dropout(p=dropout)
300 |
301 | def forward(self, x_q, x_kv, pad_mask=None, attn_mask=None):
302 | att_out = self.cross_attn(x_q, x_kv, pad_mask=pad_mask, attn_mask=attn_mask)
303 |
304 | out = x_q + self.dropout(att_out) if self.attention_residual else att_out
305 |
306 | mlp_out = self.mlp(out)
307 | out = out + self.dropout(mlp_out)
308 |
311 | return out
312 |
313 | def forward_calc_flops(self, x_q, x_kv, pad_mask=None, attn_mask=None):
314 | att_out, att_flops = self.cross_attn.forward_calc_flops(x_q, x_kv, pad_mask=pad_mask, attn_mask=attn_mask)
315 |
316 | out = x_q + self.dropout(att_out) if self.attention_residual else att_out
317 |
318 | mlp_out, mlp_flops = self.mlp.forward_calc_flops(out)
319 | out = out + self.dropout(mlp_out)
321 | flops = att_flops + mlp_flops
322 |
323 | return out, flops
324 |
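    | # Minimal usage sketch for CrossAttentionLayer (the sizes below are arbitrary examples,
    | # not values used elsewhere in this repository): 16 query tokens with 128 channels
    | # attend to 196 feature tokens with 64 channels.
    | #
    | #   layer = CrossAttentionLayer(num_heads=4, num_q_input_channels=128, num_kv_input_channels=64)
    | #   out = layer(torch.randn(2, 16, 128), torch.randn(2, 196, 64))        # -> (2, 16, 128)
    | #   out, flops = layer.forward_calc_flops(torch.randn(2, 16, 128), torch.randn(2, 196, 64))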
325 | class SelfAttentionLayer(Sequential):
326 | def __init__(
327 | self,
328 | num_heads: int,
329 | num_channels: int,
330 | num_qk_channels: Optional[int] = None,
331 | num_v_channels: Optional[int] = None,
332 | widening_factor: int = 1,
333 | dropout: float = 0.0,
334 | ):
335 |
336 | super().__init__()
340 | self.self_attn = SelfAttention(
341 | num_heads=num_heads,
342 | num_channels=num_channels,
343 | num_qk_channels=num_qk_channels,
344 | num_v_channels=num_v_channels,
345 | dropout=dropout,
346 | )
347 | self.mlp = MLP(num_channels, widening_factor)
348 | self.dropout = nn.Dropout(p=dropout)
349 |
350 | def forward(self, x, pad_mask=None, attn_mask=None):
351 | att_out = self.self_attn(x, pad_mask=pad_mask, attn_mask=attn_mask)
352 | out = x + self.dropout(att_out)
353 | mlp_out = self.mlp(out)
354 | out = out + self.dropout(mlp_out)
355 | return out
356 |
357 | def forward_calc_flops(self, x, pad_mask=None, attn_mask=None):
358 | att_out, att_flops = self.self_attn.forward_calc_flops(x, pad_mask=pad_mask, attn_mask=attn_mask)
359 | out = x + self.dropout(att_out)
360 | mlp_out, mlp_flops = self.mlp.forward_calc_flops(out)
361 | out = out + self.dropout(mlp_out)
362 | flops = att_flops + mlp_flops
363 | return out, flops
364 |
365 | class SelfAttentionBlock(nn.ModuleList):
366 | def __init__(
367 | self,
368 | num_layers: int,
369 | num_heads: int,
370 | num_channels: int,
371 | num_qk_channels: Optional[int] = None,
372 | num_v_channels: Optional[int] = None,
373 | widening_factor: int = 1,
374 | dropout: float = 0.0,
375 | activation_checkpointing: bool = False,
376 | ):
377 | layers = [
378 | SelfAttentionLayer(
379 | num_heads=num_heads,
380 | num_channels=num_channels,
381 | num_qk_channels=num_qk_channels,
382 | num_v_channels=num_v_channels,
383 | widening_factor=widening_factor,
384 | dropout=dropout,
385 | )
386 | for _ in range(num_layers)
387 | ]
388 |
389 | # NOTE: the activation_checkpointing flag is currently ignored; the intended wrapping was:
390 | # layers = [checkpoint_wrapper(layer) for layer in layers]
391 |
392 | super().__init__(layers)
393 |
394 | def forward(self, x, pad_mask=None, attn_mask=None):
395 | out = x
396 | for layer in self:
397 | out = layer(out, pad_mask=pad_mask, attn_mask=attn_mask)
398 | return out
399 |
400 | def forward_calc_flops(self, x, pad_mask=None, attn_mask=None):
401 | out = x
402 | flops = 0
403 | for layer in self:
404 | out, flops_layer = layer.forward_calc_flops(out, pad_mask=pad_mask, attn_mask=attn_mask)
405 | flops += flops_layer
406 | return out, flops
407 |
408 |
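    | # Minimal usage sketch for SelfAttentionBlock (sizes are arbitrary examples): six
    | # self-attention layers applied sequentially to 128 tokens with 256 channels;
    | # forward_calc_flops sums the per-layer FLOPs estimates.
    | #
    | #   block = SelfAttentionBlock(num_layers=6, num_heads=4, num_channels=256)
    | #   z = block(torch.randn(2, 128, 256))                              # -> (2, 128, 256)
    | #   z, flops = block.forward_calc_flops(torch.randn(2, 128, 256))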
409 | class MLP(nn.Module):
410 | def __init__(self, num_channels: int, widening_factor: int):
411 | super().__init__()
414 | self.layernorm = nn.LayerNorm(num_channels)
415 | self.fc1 = nn.Linear(num_channels, widening_factor * num_channels)
416 | self.gelu = nn.GELU()
417 | self.fc2 = nn.Linear(widening_factor * num_channels, num_channels)
418 |
419 | def forward(self, x):
420 | x = self.layernorm(x)
422 | x = self.fc1(x)
424 | x = self.gelu(x)
426 | x = self.fc2(x)
427 | return x
428 |
429 |
430 | def forward_calc_flops(self, x):
431 | x = self.layernorm(x)
432 | _, L, c = x.shape
433 | x = self.fc1(x)
434 | flops = L * c * x.shape[-1]  # FLOPs of fc1
435 |
436 | x = self.gelu(x)
437 | c = x.shape[-1]
438 | x = self.fc2(x)
439 | flops += L * c * x.shape[-1]  # FLOPs of fc2
440 | return x, flops
441 |
442 |
443 | class Residual(nn.Module):
444 | def __init__(self, module: nn.Module, dropout: float):
445 | super().__init__()
446 | self.module = module
447 | self.dropout = nn.Dropout(p=dropout)
448 | self.dropout_p = dropout
449 |
450 | def forward(self, *args, **kwargs):
456 | x = self.module(*args, **kwargs)
457 | return self.dropout(x) + args[0]
458 |
459 |
460 | class InputAdapter(nn.Module):
461 | def __init__(self, num_input_channels: int):
462 | """Transforms and position-encodes task-specific input to generic encoder input.
463 |
464 | :param num_input_channels: Number of channels of the generic encoder input produced by this adapter.
465 | """
466 | super().__init__()
467 | self._num_input_channels = num_input_channels
468 |
469 | @property
470 | def num_input_channels(self):
471 | return self._num_input_channels
472 |
473 | def forward(self, x):
474 | raise NotImplementedError()
475 |
476 |
477 | class OutputAdapter(nn.Module):
478 | def __init__(self, output_query: Tensor):
479 | """Transforms generic decoder cross-attention output to task-specific output.
480 |
481 | :param output_query: output query prototype (does not include batch dimension) used as query input to
482 | generic decoder cross-attention.
483 | """
484 | super().__init__()
485 | self._output_query = nn.Parameter(output_query)
486 | self._init_parameters()
487 |
488 | def _init_parameters(self):
489 | with torch.no_grad():
490 | self._output_query.normal_(0.0, 0.02).clamp_(-2.0, 2.0)
491 |
492 | @property
493 | def num_output_query_channels(self):
494 | return self._output_query.shape[-1]
495 |
496 | def output_query(self, x):
497 | return repeat(self._output_query, "... -> b ...", b=x.shape[0])
498 |
499 | def forward(self, x):
500 | raise NotImplementedError()
501 |
502 |
503 | class ClassificationOutputAdapter(OutputAdapter):
504 | def __init__(self, num_classes: int, num_output_queries: int = 1, num_output_query_channels: Optional[int] = None):
505 |
506 | if num_output_query_channels is None:
507 | num_output_query_channels = num_classes
508 |
509 | super().__init__(output_query=torch.empty(num_output_queries, num_output_query_channels))
510 | self.linear = nn.Linear(num_output_query_channels, num_classes)
511 |
512 | def forward(self, x):
513 | return self.linear(x).squeeze(dim=1)
514 |
515 |
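    | # Shape sketch for ClassificationOutputAdapter (the numbers are illustrative only): with a
    | # single learned output query of 256 channels and 1000 classes,
    | #
    | #   adapter = ClassificationOutputAdapter(num_classes=1000, num_output_query_channels=256)
    | #   q = adapter.output_query(torch.randn(8, 128, 256))   # -> (8, 1, 256), query repeated per example
    | #   logits = adapter(torch.randn(8, 1, 256))             # -> (8, 1000) after squeezing the query dim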
516 | class PerceiverEncoder(nn.Module):
517 | def __init__(
518 | self,
519 | input_adapter: InputAdapter,
520 | num_latents: int,
521 | num_latent_channels: int,
522 | num_cross_attention_heads: int = 4,
523 | num_cross_attention_qk_channels: Optional[int] = None,
524 | num_cross_attention_v_channels: Optional[int] = None,
525 | num_cross_attention_layers: int = 1,
526 | first_cross_attention_layer_shared: bool = False,
527 | cross_attention_widening_factor: int = 1,
528 | num_self_attention_heads: int = 4,
529 | num_self_attention_qk_channels: Optional[int] = None,
530 | num_self_attention_v_channels: Optional[int] = None,
531 | num_self_attention_layers_per_block: int = 6,
532 | num_self_attention_blocks: int = 1,
533 | first_self_attention_block_shared: bool = True,
534 | self_attention_widening_factor: int = 1,
535 | dropout: float = 0.0,
536 | activation_checkpointing: bool = False,
537 | ):
538 | """Generic Perceiver IO encoder.
539 |
540 | :param input_adapter: Transforms and position-encodes task-specific input to generic encoder input
541 | of shape (B, M, C) where B is the batch size, M the input sequence length and C the number of
542 | key/value input channels. C is determined by the `num_input_channels` property of the
543 | `input_adapter`.
544 | :param num_latents: Number of latent variables (N).
545 | :param num_latent_channels: Number of latent channels (D).
546 | :param num_cross_attention_heads: Number of cross-attention heads.
547 | :param num_cross_attention_qk_channels: Number of query and key channels for cross-attention
548 | (see `MultiHeadAttention.num_qk_channels` for details).
549 | :param num_cross_attention_v_channels: Number of value channels for cross-attention
550 | (see `MultiHeadAttention.num_v_channels` for details).
551 | :param num_cross_attention_layers: Number of cross-attention layers (alternating with self-attention blocks).
552 | :param first_cross_attention_layer_shared: Whether the first cross-attention layer should share its weights
553 | with subsequent cross-attention layers (if any).
554 | :param num_self_attention_heads: Number of self-attention heads.
555 | :param num_self_attention_qk_channels: Number of query and key channels for self-attention
556 | (see `MultiHeadAttention.num_qk_channels` for details).
557 | :param num_self_attention_v_channels: Number of value channels for self-attention
558 | (see `MultiHeadAttention.num_v_channels` for details).
559 | :param num_self_attention_layers_per_block: Number of self-attention layers per self-attention block.
560 | :param num_self_attention_blocks: Number of self-attention blocks sharing weights between corresponding
561 | self-attention layers.
562 | :param first_self_attention_block_shared: Whether the first self-attention block should share its weights
563 | with subsequent self-attention blocks (if any).
564 | :param dropout: Dropout probability for self- and cross-attention layers and residuals.
565 | :param activation_checkpointing: If True, implements an activation checkpoint for each self-attention
566 | layer and cross-attention layer.
567 | """
568 | super().__init__()
569 |
570 | self.input_adapter = input_adapter
571 |
572 | if num_cross_attention_layers <= 0:
573 | raise ValueError("num_cross_attention_layers must be > 0")
574 |
575 | if num_self_attention_blocks <= 0:
576 | raise ValueError("num_self_attention_blocks must be > 0")
577 |
578 | if num_cross_attention_layers > num_self_attention_blocks:
579 | raise ValueError("num_cross_attention_layers must be <= num_self_attention_blocks")
580 |
581 | self.num_cross_attention_layers = num_cross_attention_layers
582 | self.num_self_attention_blocks = num_self_attention_blocks
583 |
584 | self.first_cross_attention_layer_shared = first_cross_attention_layer_shared
585 | self.first_self_attention_block_shared = first_self_attention_block_shared
586 |
587 | def cross_attn():
590 | layer = CrossAttentionLayer(
591 | num_heads=num_cross_attention_heads,
592 | num_q_input_channels=num_latent_channels,
593 | num_kv_input_channels=input_adapter.num_input_channels,
594 | num_qk_channels=num_cross_attention_qk_channels,
595 | num_v_channels=num_cross_attention_v_channels,
596 | widening_factor=cross_attention_widening_factor,
597 | dropout=dropout,
598 | )
599 | # return checkpoint_wrapper(layer) if activation_checkpointing else layer
600 | return layer
601 |
602 | def self_attn():
603 | return SelfAttentionBlock(
604 | num_layers=num_self_attention_layers_per_block,
605 | num_heads=num_self_attention_heads,
606 | num_channels=num_latent_channels,
607 | num_qk_channels=num_self_attention_qk_channels,
608 | num_v_channels=num_self_attention_v_channels,
609 | widening_factor=self_attention_widening_factor,
610 | dropout=dropout,
611 | activation_checkpointing=activation_checkpointing,
612 | )
613 |
614 | self.cross_attn_n = cross_attn()
615 | self.self_attn_n = self_attn()
616 |
617 | if self.first_cross_attention_layer_shared or num_cross_attention_layers == 1:
618 | self.cross_attn_1 = self.cross_attn_n
619 | else:
620 | self.cross_attn_1 = cross_attn()
621 |
622 | if self.first_self_attention_block_shared or num_self_attention_blocks == 1:
623 | self.self_attn_1 = self.self_attn_n
624 | else:
625 | self.self_attn_1 = self_attn()
626 |
627 | # learnable initial latent vectors
628 | self.latent = nn.Parameter(torch.empty(num_latents, num_latent_channels))
629 | self._init_parameters()
630 |
631 | def _init_parameters(self):
632 | with torch.no_grad():
633 | self.latent.normal_(0.0, 0.02).clamp_(-2.0, 2.0)
634 |
635 | def forward(self, x, pad_mask=None):
636 | b, *_ = x.shape
637 |
638 | # encode task-specific input
640 | x = self.input_adapter(x)
643 | # repeat initial latent vector along batch dimension
644 | x_latent = repeat(self.latent, "... -> b ...", b=b)
645 |
646 | x_latent = self.cross_attn_1(x_latent, x, pad_mask)
647 | x_latent = self.self_attn_1(x_latent)
648 |
649 | for i in range(1, self.num_self_attention_blocks):
650 | if i < self.num_cross_attention_layers:
651 | x_latent = self.cross_attn_n(x_latent, x, pad_mask)
652 | x_latent = self.self_attn_n(x_latent)
653 |
655 | return x_latent
656 |
657 |
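    | # Shape sketch for PerceiverEncoder (`my_adapter` is a placeholder for any InputAdapter
    | # subclass; the sizes are illustrative): an adapter emitting (B, M, C) tokens drives
    | # num_latents x num_latent_channels learnable latents through alternating cross- and
    | # self-attention, so the encoder output does not depend on the input length M.
    | #
    | #   encoder = PerceiverEncoder(input_adapter=my_adapter, num_latents=128, num_latent_channels=256)
    | #   latents = encoder(x)                                  # -> (B, 128, 256)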
658 | class PerceiverDecoder(nn.Module):
659 | def __init__(
660 | self,
661 | output_adapter: OutputAdapter,
662 | num_latent_channels: int,
663 | num_cross_attention_heads: int = 4,
664 | num_cross_attention_qk_channels: Optional[int] = None,
665 | num_cross_attention_v_channels: Optional[int] = None,
666 | cross_attention_widening_factor: int = 1,
667 | dropout: float = 0.0,
668 | activation_checkpointing: bool = False,
669 | ):
670 | """Generic Perceiver IO decoder.
671 |
672 | :param output_adapter: Transforms generic decoder cross-attention output of shape (B, O, F) to task-specific
673 | output. B is the batch size, O the output sequence length and F the number of cross-attention output
674 | channels. F is determined by the `num_output_query_channels` property of the `output_adapter`.
675 | :param num_latent_channels: Number of latent channels (C_latent) as produced by a Perceiver IO encoder.
676 | :param num_cross_attention_heads: Number of cross-attention heads.
677 | :param num_cross_attention_qk_channels: Number of query and key channels for cross-attention
678 | (see `MultiHeadAttention.num_qk_channels` for details).
679 | :param num_cross_attention_v_channels: Number of value channels for cross-attention
680 | (see `MultiHeadAttention.num_v_channels` for details).
681 | :param dropout: Dropout probability for cross-attention layers and residuals.
682 | :param activation_checkpointing: If True, implements an activation checkpoint for the decoder's
683 | cross-attention layer.
684 | """
685 | super().__init__()
686 |
687 | cross_attn = CrossAttentionLayer(
688 | num_heads=num_cross_attention_heads,
689 | num_q_input_channels=output_adapter.num_output_query_channels,
690 | num_kv_input_channels=num_latent_channels,
691 | num_qk_channels=num_cross_attention_qk_channels,
692 | num_v_channels=num_cross_attention_v_channels,
693 | widening_factor=cross_attention_widening_factor,
694 | dropout=dropout,
695 | )
696 |
697 | # if activation_checkpointing:
698 | # cross_attn = checkpoint_wrapper(cross_attn)
699 |
700 | self.cross_attn = cross_attn
701 | self.output_adapter = output_adapter
702 |
703 | def forward(self, x):
704 | output_query = self.output_adapter.output_query(x)
705 | output = self.cross_attn(output_query, x)
706 | return self.output_adapter(output)
707 |
708 |
709 | class PerceiverIO(Sequential):
710 | def __init__(self, encoder: PerceiverEncoder, decoder: PerceiverDecoder):
711 | super().__init__(encoder, decoder)
712 |
713 | @property
714 | def encoder(self):
715 | return self[0]
716 |
717 | @property
718 | def decoder(self):
719 | return self[1]
720 |
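    | # End-to-end sketch (`my_adapter` is again a placeholder InputAdapter; sizes are illustrative):
    | # PerceiverIO simply chains an encoder and a decoder, so a classification model can be
    | # assembled as
    | #
    | #   encoder = PerceiverEncoder(input_adapter=my_adapter, num_latents=128, num_latent_channels=256)
    | #   out_adapter = ClassificationOutputAdapter(num_classes=1000, num_output_query_channels=256)
    | #   decoder = PerceiverDecoder(output_adapter=out_adapter, num_latent_channels=256)
    | #   model = PerceiverIO(encoder, decoder)
    | #   logits = model(x)                                     # -> (B, 1000)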
--------------------------------------------------------------------------------