├── train.py
├── dataset
│   └── transform.py
├── LICENSE
├── config.py
├── model
│   ├── custom_blocks.py
│   ├── cifar100_model.py
│   └── patch_conv_net.py
├── README.md
└── .gitignore

/train.py:
--------------------------------------------------------------------------------
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning import Trainer

from model.cifar100_model import Cifar100Model
from config import Config


seed_everything(Config.seed)
model = Cifar100Model(Config)
trainer = Trainer(gpus=1, max_epochs=Config.epochs, auto_lr_find=True)
# auto_lr_find only takes effect through tune(); uncomment to run the LR finder before fitting.
# trainer.tune(model)
trainer.fit(model)
--------------------------------------------------------------------------------
/dataset/transform.py:
--------------------------------------------------------------------------------
from torchvision import transforms
from timm.data.auto_augment import rand_augment_transform
from config import Config

resize = [transforms.Resize(Config.image_size)]
randaug = []
# ImageNet mean/std, the values timm backbones are commonly normalized with.
preprocess = [transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]
if Config.use_randaug:
    # RandAugment with magnitude 7 and magnitude std 0.5; it operates on PIL images,
    # so it has to run before ToTensor.
    randaug.append(rand_augment_transform(
        config_str='rand-m7-mstd0.5',
        hparams={'translate_const': 117, 'img_mean': (124, 116, 104)}
    ))

train_transform = transforms.Compose(resize + randaug + preprocess)

test_transform = transforms.Compose(resize + preprocess)
--------------------------------------------------------------------------------
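A minimal sanity check of the training pipeline above (the random image is only a stand-in for a CIFAR-100 sample, not repo data):

```python
from PIL import Image
import numpy as np

from dataset.transform import train_transform

# Dummy 32x32 RGB image in place of a CIFAR-100 sample.
img = Image.fromarray(np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8))
x = train_transform(img)
print(x.shape)  # torch.Size([3, 384, 384]) after Resize(Config.image_size)
```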
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 Dongkyun Kim

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
class Config:
    # Directories
    data_dir = 'data'

    # Data
    image_size = 384
    input_dim = 3
    num_of_classes = 100

    # Trainer
    train_batch_size = 96
    val_batch_size = 96
    num_workers = 8
    epochs = 40
    lr = 1e-4
    seed = 42
    use_randaug = True
    use_mixup = True
    use_stochastic_depth = True
    mixup_args = {
        'mixup_alpha': 0.8,
        'cutmix_alpha': 1.0,
        'cutmix_minmax': None,
        'prob': 1.0,
        'switch_prob': 0.,
        'mode': 'batch',
        'label_smoothing': 0,
        'num_classes': num_of_classes
    }

    # Model
    model_width = 'S'
    model_depth = 60
    model_name = model_width + str(model_depth)

    if model_width == 'S':
        patch_dim = 384
    elif model_width == 'B':
        patch_dim = 768
    elif model_width == 'L':
        patch_dim = 1024
    else:
        raise ValueError(f'unknown model_width: {model_width}')

    conv_stem_hidden_dims = [32, 64, 128]
    conv_stem_layers = 4
    assert len(conv_stem_hidden_dims) + 1 == conv_stem_layers

    column_hidden_dim = patch_dim // 3
--------------------------------------------------------------------------------
/model/custom_blocks.py:
--------------------------------------------------------------------------------
import torch.nn as nn


def get_activation(act):
    if act == 'relu':
        return nn.ReLU()
    elif act == 'gelu':
        return nn.GELU()
    elif act is None:
        return nn.Identity()
    raise ValueError(f'unknown activation: {act}')


class Conv2dBlock(nn.Module):
    """Conv2d followed by an optional activation."""

    def __init__(self, in_dim, out_dim, kernel_size, stride, padding, act=None):
        super().__init__()
        self.conv = nn.Conv2d(in_dim, out_dim, kernel_size, stride, padding)
        self.act = get_activation(act)

    def forward(self, x):
        x = self.conv(x)
        x = self.act(x)
        return x


class SEBlock(nn.Module):
    """Squeeze-and-Excitation: reweights channels via global pooling and a bottleneck MLP."""

    def __init__(self, in_dim, reduction=16):
        super().__init__()
        self.layers = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(in_dim, in_dim//reduction, bias=False),
            nn.ReLU(),
            nn.Linear(in_dim//reduction, in_dim, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        weights = self.layers(x)  # (B, C) channel weights in [0, 1]
        weights = weights.unsqueeze(-1).unsqueeze(-1)
        return x * weights.expand_as(x)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# PyTorch Implementation of "Augmenting Convolutional networks with attention-based aggregation"

This is an unofficial PyTorch implementation of "Augmenting Convolutional networks with attention-based aggregation".

Reference: https://arxiv.org/pdf/2112.13692.pdf

## Prerequisites

+ Python 3
+ CUDA
+ PyTorch
+ PyTorch Lightning
+ timm
+ torchmetrics
+ torchvision
+ torch_optimizer (for MADGRAD)

## Comments
- Due to compute limits, the CIFAR-100 dataset was used rather than ImageNet as in the original paper.
- Since the official code has not been released yet, there may be differences in structure and hyperparameters.
- Most of the hidden dimensions were chosen based on guesswork.
- MADGRAD was used instead of the LAMB optimizer, since LAMB seemed inefficient for the small batch sizes my local machine can handle.
- LayerScale will be added soon; a rough sketch is given below.
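As a reference for the pending LayerScale feature: the idea (introduced in CaiT and used in the paper) is a learnable per-channel scale on the residual branch, initialized to a small constant. The module name, init value, and placement below are my assumptions, not the authors' exact settings:

```python
import torch
import torch.nn as nn


class LayerScale2d(nn.Module):
    """Learnable per-channel scaling of a residual branch, NCHW layout (hypothetical sketch)."""

    def __init__(self, dim, init_value=1e-4):  # init value assumed; CaiT uses 1e-1..1e-6 depending on depth
        super().__init__()
        self.gamma = nn.Parameter(init_value * torch.ones(dim))

    def forward(self, x):
        # Broadcast (C,) -> (1, C, 1, 1) over batch and spatial dimensions.
        return x * self.gamma.view(1, -1, 1, 1)
```

In `ColumnBlock` this would scale the branch output before the residual addition, e.g. `x = x_skip + self.layer_scale(self.layers(x))`.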

## Citations

```bibtex
@misc{touvron2021augmenting,
    title={Augmenting Convolutional networks with attention-based aggregation},
    author={Hugo Touvron and Matthieu Cord and Alaaeldin El-Nouby and Piotr Bojanowski and Armand Joulin and Gabriel Synnaeve and Hervé Jégou},
    year={2021},
    eprint={2112.13692},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}
```
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
data
lightning_logs

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
--------------------------------------------------------------------------------
/model/cifar100_model.py:
--------------------------------------------------------------------------------
import torch
import torchmetrics
import torch.nn as nn
import pytorch_lightning as pl
import torchvision.datasets as datasets
from timm.data.mixup import Mixup
from torch_optimizer import MADGRAD
from model.patch_conv_net import PatchConvNet
from dataset.transform import train_transform, test_transform


class Cifar100Model(pl.LightningModule):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.patch_conv_net = PatchConvNet(cfg)
        self.criterion = nn.CrossEntropyLoss()
        self.val_acc = torchmetrics.Accuracy()
        self.mixup_fn = Mixup(**cfg.mixup_args)
        self.lr = cfg.lr

    def forward(self, x, training):
        return self.patch_conv_net(x, training)

    def shared_step(self, batch, batch_idx, training):
        images, target = batch
        if training and self.cfg.use_mixup:
            # Mixup/CutMix returns soft (mixed one-hot) targets of shape (B, num_classes);
            # nn.CrossEntropyLoss accepts probabilistic targets in PyTorch >= 1.10.
            images, target = self.mixup_fn(images, target)
        # Stochastic depth is only active during training, and only if enabled in the config.
        training = training and self.cfg.use_stochastic_depth
        output = self(images, training)
        loss = self.criterion(output, target)
        return loss, output, target

    def training_step(self, batch, batch_idx):
        loss, _, _ = self.shared_step(batch, batch_idx, True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, output, target = self.shared_step(batch, batch_idx, False)
        self.val_acc(output, target)
        logs = {'val_loss': loss, 'val_acc': self.val_acc}
        self.log_dict(logs, prog_bar=True)

    def configure_optimizers(self):
        optimizer = MADGRAD(self.parameters(), lr=self.lr)
        return optimizer

    def train_dataloader(self):
        train_dataset = datasets.CIFAR100(self.cfg.data_dir, train=True, download=True, transform=train_transform)
        train_loader = torch.utils.data.DataLoader(
            dataset=train_dataset, batch_size=self.cfg.train_batch_size, shuffle=True, num_workers=self.cfg.num_workers
        )
        return train_loader

    def val_dataloader(self):
        val_dataset = datasets.CIFAR100(self.cfg.data_dir, train=False, download=True, transform=test_transform)
        val_loader = torch.utils.data.DataLoader(
            dataset=val_dataset, batch_size=self.cfg.val_batch_size, shuffle=False, num_workers=self.cfg.num_workers
        )
        return val_loader
--------------------------------------------------------------------------------
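A quick illustration of the soft targets produced by `Mixup` in `shared_step` above (dummy tensors, not repo data); these soft labels are consumed directly by `nn.CrossEntropyLoss`:

```python
import torch
from timm.data.mixup import Mixup
from config import Config

mixup_fn = Mixup(**Config.mixup_args)
images = torch.randn(8, 3, 384, 384)   # dummy batch (Mixup requires an even batch size)
target = torch.randint(0, 100, (8,))   # integer class labels
images, target = mixup_fn(images, target)
print(target.shape)  # torch.Size([8, 100]): one row of soft labels per sample
```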
/model/patch_conv_net.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from model.custom_blocks import Conv2dBlock, SEBlock
from torchvision.ops import stochastic_depth


class ConvolutionalStem(nn.Module):
    """Stack of stride-2 3x3 convolutions; each layer halves the spatial resolution."""

    def __init__(self, in_dim, out_dim, hidden_dims, num_layers):
        super().__init__()
        assert len(hidden_dims) + 1 == num_layers
        dims = [in_dim] + hidden_dims + [out_dim]

        self.layers = [Conv2dBlock(dims[i], dims[i+1], 3, 2, 1, 'gelu') for i in range(num_layers)]
        self.layers = nn.Sequential(*self.layers)

    def forward(self, x):
        return self.layers(x)


class ColumnBlock(nn.Module):
    """Pre-norm bottleneck residual block (1x1 -> 3x3 -> SE -> 1x1) with stochastic depth."""

    def __init__(self, in_dim, hidden_dim=None):
        super().__init__()
        if not hidden_dim:
            hidden_dim = in_dim // 4

        self.layers = nn.Sequential(
            nn.BatchNorm2d(in_dim),
            Conv2dBlock(in_dim, hidden_dim, 1, 1, 0, 'gelu'),
            Conv2dBlock(hidden_dim, hidden_dim, 3, 1, 1, 'gelu'),
            SEBlock(hidden_dim),
            nn.Conv2d(hidden_dim, in_dim, 1, 1, 0),
        )

    def forward(self, x, training):
        x_skip = x
        x = self.layers(x)
        # Randomly drop the whole residual branch per sample ("row" mode) during training.
        x = stochastic_depth(x, p=0.5, mode="row", training=training)
        x = x + x_skip
        return x


class Column(nn.Module):
    """Isotropic trunk: a stack of ColumnBlocks at constant resolution and width."""

    def __init__(self, in_dim, hidden_dim=None, num_layers=60):
        super().__init__()
        self.layers = [ColumnBlock(in_dim, hidden_dim) for _ in range(num_layers)]
        self.layers = nn.ModuleList(self.layers)

    def forward(self, x, training):
        for layer in self.layers:
            x = layer(x, training)
        return x


class AttentionPooling(nn.Module):
    """Aggregates the patch grid into one vector via attention against a learned class vector."""

    def __init__(self, in_dim):
        super().__init__()
        self.cls_vec = nn.Parameter(torch.randn(in_dim))
        self.fc = nn.Linear(in_dim, in_dim)
        self.softmax = nn.Softmax(-1)

    def forward(self, x):
        b, c = x.shape[0], x.shape[1]
        x = x.view(b, c, -1)  # (B, C, N): one column per spatial patch
        # One attention logit per patch: dot product of the patch vector with the class vector.
        # (The previous x.view(-1, C) on an NCHW tensor did not yield per-patch channel vectors.)
        weights = self.softmax(torch.matmul(self.cls_vec, x))  # (B, N)
        # Attention-weighted sum of the patch vectors.
        x = torch.bmm(x, weights.unsqueeze(-1)).squeeze(-1)  # (B, C); squeeze(-1) keeps batch dim when B == 1
        x = x + self.cls_vec
        x = self.fc(x)
        x = x + self.cls_vec
        return x


class PatchConvNet(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.conv_stem = ConvolutionalStem(
            in_dim=cfg.input_dim,
            out_dim=cfg.patch_dim,
            hidden_dims=cfg.conv_stem_hidden_dims,
            num_layers=cfg.conv_stem_layers
        )
        self.column = Column(
            in_dim=cfg.patch_dim,
            hidden_dim=cfg.column_hidden_dim,
            num_layers=cfg.model_depth
        )
        self.pooling = AttentionPooling(
            in_dim=cfg.patch_dim,
        )
        self.head = nn.Linear(cfg.patch_dim, cfg.num_of_classes)

    def forward(self, x, training):
        x = self.conv_stem(x)
        x = self.column(x, training)
        x = self.pooling(x)
        x = self.head(x)
        return x
--------------------------------------------------------------------------------
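A quick end-to-end shape check of the model, assuming the Config defaults above (S60 at 384x384; dummy batch):

```python
import torch
from config import Config
from model.patch_conv_net import PatchConvNet

model = PatchConvNet(Config)
# Four stride-2 stem layers: 384 -> 24, so the column sees a 24x24 grid of 384-dim patch vectors.
x = torch.randn(2, 3, Config.image_size, Config.image_size)
out = model(x, training=False)
print(out.shape)  # torch.Size([2, 100])
```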