├── LICENSE
├── README.md
├── figure
│   ├── accuracy.png
│   ├── loss.png
│   └── model_arch.png
├── main.py
├── main_ddp.py
├── models
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-37.pyc
│   │   ├── model.cpython-37.pyc
│   │   └── module.cpython-37.pyc
│   ├── model.py
│   └── module.py
└── utils.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 Hong-Jia Chen

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# MobileViT

Unofficial PyTorch implementation of MobileViT, based on the paper [MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer](https://arxiv.org/abs/2110.02178).

---

## Table of Contents
* [Model Architecture](#model-architecture)
* [Usage](#usage)
* [Experiment](#experiment)
* [Citation](#citation)


---
## Model Architecture
![MobileViT Architecture](./figure/model_arch.png)

---

## Usage

```python
import torch
import models

img = torch.randn(1, 3, 256, 256)
net = models.MobileViT_S()

# Parameter counts: XXS: 1.3M, XS: 2.3M, S: 5.6M
print("MobileViT-S params: ", sum(p.numel() for p in net.parameters()))
print(f"Output shape: {net(img).shape}")
```

### Training
- Single node with one GPU
```bash
python main.py
```

- Single node with multiple GPUs
```bash
CUDA_VISIBLE_DEVICES=3,4 python -m torch.distributed.launch --nproc_per_node=2 --master_port=6666 main_ddp.py
```

```text
optional arguments:
  -h, --help            show this help message and exit
  --data DIR            Path to the dataset (default: ../imagenet_data)
  --gpu_device GPU_DEVICE
                        Select a specific GPU to run the model
  --batch-size N        Input batch size for training (default: 64)
  --epochs N            Number of epochs to train (default: 90)
  --warmup_epochs N     Number of warmup epochs (default: 5)
  --num-class N         Number of classes to classify (default: 1000)
  --lr LR               Learning rate (default: 0.05)
  --weight-decay WD     Weight decay (default: 5e-5)
  --print-freq N        Print frequency (default: 10)
  --local_rank RANK     Local rank for DistributedDataParallel
  --seed_num SEED       Random seed (default: 42)
```

---

## Experiment

![Accuracy on ImageNet](./figure/accuracy.png)

![Loss on ImageNet](./figure/loss.png)

### MobileViT-S pretrained weights: [weight](https://drive.google.com/file/d/1ZQt1vACHTN98QJYaT2JW3kPF-wziHyPX/view?usp=sharing)
### MobileViT-XXS pretrained weights: [weight](https://drive.google.com/file/d/1PZGq1hVNokS1r5R3cCJr9IC75CyjL6a8/view?usp=sharing)

### How to load pretrained weights (trained with DataParallel)
Solution by **[@Sehaba95](https://github.com/wilile26811249/MobileViT/issues/7)**; an inference sketch follows at the end of this README:
```python
def load_mobilevit_weights(model_path):
    # Create an instance of the MobileViT model
    net = MobileViT_S()

    # Load the PyTorch state_dict
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))['state_dict']

    # DataParallel prefixes every key with "module.", so strip it to match the plain model
    for key in list(state_dict.keys()):
        state_dict[key.replace('module.', '')] = state_dict.pop(key)

    # Once the keys are fixed, load the parameters into MobileViT
    net.load_state_dict(state_dict)

    return net

net = load_mobilevit_weights("MobileViT_S_model_best.pth.tar")
```

---

| Model | Dataset | Learning Rate | LR Scheduler | Optimizer | Weight decay | Acc@1/Val | Acc@5/Val |
|-------|:--------:|:------:|:----:|:--------:|:-------:|:--------:|:-------:|
| MobileViT | ImageNet-1k | 0.05 | Cosine LR | SGDM | 1e-5 | 61.918% | 83.05% |

---

## Citation
```
@article{Sachin2021,
    title = {MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer},
    author = {Sachin Mehta and Mohammad Rastegari},
    journal = {arXiv preprint arXiv:2110.02178},
    year = {2021}
}
```


### If you find any problems with this implementation, please let me know. Thank you.
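A minimal inference sketch (not part of the original repository) that builds on `load_mobilevit_weights` above. The preprocessing roughly mirrors the validation transforms in `main.py`, and `cat.jpg` is a placeholder image path:

```python
import torch
from PIL import Image
from torchvision import transforms as T

from models import MobileViT_S  # required by load_mobilevit_weights above

# Resize to the model's default 256x256 input and apply ImageNet normalization
preprocess = T.Compose([
    T.Resize((256, 256)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

net = load_mobilevit_weights("MobileViT_S_model_best.pth.tar")
net.eval()

img = preprocess(Image.open("cat.jpg").convert("RGB")).unsqueeze(0)  # placeholder image path
with torch.no_grad():
    probs = net(img).softmax(dim=-1)
print("Top-1 ImageNet class index:", probs.argmax(dim=-1).item())
```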
114 | -------------------------------------------------------------------------------- /figure/accuracy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wilile26811249/MobileViT/5e61fb72bf77b3445c4bc81008efaf2c68822594/figure/accuracy.png -------------------------------------------------------------------------------- /figure/loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wilile26811249/MobileViT/5e61fb72bf77b3445c4bc81008efaf2c68822594/figure/loss.png -------------------------------------------------------------------------------- /figure/model_arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wilile26811249/MobileViT/5e61fb72bf77b3445c4bc81008efaf2c68822594/figure/model_arch.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | import utils 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torchvision import transforms as T 9 | import torchvision.datasets as datasets 10 | from tqdm import tqdm 11 | import wandb 12 | 13 | import models 14 | 15 | def save_checkpoint(state, is_best, filename = 'checkpoint.pth.tar'): 16 | torch.save(state, filename) 17 | if is_best: 18 | shutil.copyfile(filename, 'model_best.pth.tar') 19 | 20 | 21 | def train_epoch(epoch, net, train_loader, val_loader , criterion, optimizer, scheduler, device): 22 | """ 23 | Training logic for an epoch 24 | """ 25 | global best_acc1 26 | train_loss = utils.AverageMeter("Epoch losses", ":.4e") 27 | train_acc1 = utils.AverageMeter("Train Acc@1", ":6.2f") 28 | train_acc5 = utils.AverageMeter("Train Acc@5", ":6.2f") 29 | progress_train = utils.ProgressMeter( 30 | num_batches = len(val_loader), 31 | meters = [train_loss, train_acc1, train_acc5], 32 | prefix = 'Epoch: {} '.format(epoch + 1), 33 | batch_info = " Iter" 34 | ) 35 | net.train() 36 | 37 | for it, (inputs, targets) in enumerate(tqdm(train_loader)): 38 | inputs = inputs.to(device) 39 | targets = targets.to(device) 40 | 41 | # Forward pass 42 | outputs = net(inputs) 43 | loss = criterion(outputs, targets) 44 | acc1, acc5 = utils.accuracy(outputs, targets, topk = (1, 5)) 45 | 46 | train_loss.update(loss.item(), inputs.size(0)) 47 | train_acc1.update(acc1.item(), inputs.size(0)) 48 | train_acc5.update(acc5.item(), inputs.size(0)) 49 | if it % args.print_freq == 0: 50 | progress_train.display(it) 51 | 52 | # Backward and optimize 53 | optimizer.zero_grad() 54 | loss.backward() 55 | optimizer.step() 56 | 57 | # Log on Wandb 58 | wandb.log({ 59 | "Loss/train" : train_loss.avg, 60 | "Acc@1/train" : train_acc1.avg, 61 | "Acc@5/train" : train_acc5.avg, 62 | }) 63 | scheduler.step() 64 | 65 | # Validation model 66 | val_loss = utils.AverageMeter("Val losses", ":.4e") 67 | val_acc1 = utils.AverageMeter("Val Acc@1", ":6.2f") 68 | val_acc5 = utils.AverageMeter("Val Acc@5", ":6.2f") 69 | progress_val = utils.ProgressMeter( 70 | num_batches = len(val_loader), 71 | meters = [val_loss, val_acc1, val_acc5], 72 | prefix = 'Epoch: {} '.format(epoch + 1), 73 | batch_info = " Iter" 74 | ) 75 | net.eval() 76 | 77 | for it, (inputs, targets) in enumerate(val_loader): 78 | inputs = inputs.to(device) 79 | targets = targets.to(device) 80 | 81 | # Forward pass 82 | with torch.no_grad(): 83 | 
outputs = net(inputs) 84 | loss = criterion(outputs, targets) 85 | acc1, acc5 = utils.accuracy(outputs, targets, topk=(1, 5)) 86 | val_loss.update(loss.item(), inputs.size(0)) 87 | val_acc1.update(acc1.item(), inputs.size(0)) 88 | val_acc5.update(acc5.item(), inputs.size(0)) 89 | acc1 = val_acc1.avg 90 | 91 | if it % args.print_freq == 0: 92 | progress_val.display(it) 93 | 94 | # Log on Wandb 95 | wandb.log({ 96 | "Loss/val" : val_loss.avg, 97 | "Acc@1/val" : val_acc1.avg, 98 | "Acc@5/val" : val_acc5.avg 99 | }) 100 | 101 | is_best = acc1 > best_acc1 102 | best_acc1 = max(acc1, best_acc1) 103 | save_checkpoint({ 104 | 'epoch': epoch + 1, 105 | 'state_dict': net.state_dict(), 106 | 'best_acc1': best_acc1, 107 | 'optimizer' : optimizer.state_dict(), 108 | }, is_best) 109 | return val_loss.avg, val_acc1.avg, val_acc5.avg 110 | 111 | 112 | if __name__ == "__main__": 113 | best_acc1 = 0.0 114 | parser = argparse.ArgumentParser(description = "Train classification of CMT model") 115 | parser.add_argument('--data', metavar = 'DIR', default = '../imagenet_data', 116 | help = 'path to dataset') 117 | parser.add_argument("--gpu_device", type = int, default = 2, 118 | help = "Select specific GPU to run the model") 119 | parser.add_argument('--batch-size', type = int, default = 256, metavar = 'N', 120 | help = 'Input batch size for training (default: 64)') 121 | parser.add_argument('--epochs', type = int, default = 90, metavar = 'N', 122 | help = 'Number of epochs to train (default: 90)') 123 | parser.add_argument('--num-class', type = int, default = 1000, metavar = 'N', 124 | help = 'Number of classes to classify (default: 10)') 125 | parser.add_argument('--lr', type = float, default = 0.05, metavar='LR', 126 | help = 'Learning rate (default: 6e-5)') 127 | parser.add_argument('--weight-decay', type = float, default = 5e-5, metavar = 'WD', 128 | help = 'Weight decay (default: 1e-5)') 129 | parser.add_argument('-p', '--print-freq', default = 10, type = int, 130 | metavar='N', help='print frequency (default: 10)') 131 | args = parser.parse_args() 132 | 133 | # autotune cudnn kernel choice 134 | torch.backends.cudnn.benchmark = True 135 | 136 | # Create folder to save model 137 | WEIGHTS_PATH = "./weights" 138 | if not os.path.exists(WEIGHTS_PATH): 139 | os.makedirs(WEIGHTS_PATH) 140 | 141 | # Set device 142 | # os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_device) 143 | os.environ["CUDA_VISIBLE_DEVICES"] = '1,2' 144 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 145 | 146 | # Data loading 147 | traindir = os.path.join(args.data, 'train') 148 | valdir = os.path.join(args.data, 'val') 149 | normalize = T.Normalize(mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225]) 150 | 151 | train_dataset = datasets.ImageFolder( 152 | traindir, 153 | T.Compose([ 154 | T.RandomResizedCrop(256), 155 | T.RandomHorizontalFlip(), 156 | T.ToTensor(), 157 | normalize, 158 | ])) 159 | val_dataset = datasets.ImageFolder( 160 | valdir, 161 | T.Compose([ 162 | T.Resize(256), 163 | T.ToTensor(), 164 | normalize, 165 | ])) 166 | 167 | train_loader = torch.utils.data.DataLoader( 168 | train_dataset, batch_size = args.batch_size, shuffle = True, 169 | num_workers = 4, pin_memory = True 170 | ) 171 | val_loader = torch.utils.data.DataLoader( 172 | val_dataset, batch_size = args.batch_size, shuffle = False, 173 | num_workers = 4, pin_memory = True 174 | ) 175 | 176 | # Create model 177 | net = models.MobileViT_S() 178 | # net.to(device) 179 | net = torch.nn.DataParallel(net).to(device) 180 | 181 | # 
Set loss function and optimizer 182 | # criterion = nn.CrossEntropyLoss() 183 | criterion = nn.CrossEntropyLoss().to(device) 184 | optimizer = torch.optim.SGD(net.parameters(), args.lr, 185 | momentum = 0.9, 186 | weight_decay = args.weight_decay) 187 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs) 188 | 189 | # Using wandb for logging 190 | wandb.init() 191 | wandb.config.update(args) 192 | wandb.watch(net) 193 | 194 | # Train the model 195 | for epoch in tqdm(range(args.epochs)): 196 | loss, acc1, acc5 = train_epoch(epoch, net, train_loader, 197 | val_loader, criterion, optimizer, scheduler, device 198 | ) 199 | print(f"Epoch {epoch} -> Acc@1: {acc1}, Acc@5: {acc5}") 200 | 201 | print("Training is done") 202 | -------------------------------------------------------------------------------- /main_ddp.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import utils 4 | import shutil 5 | import wandb 6 | import random 7 | import numpy as np 8 | from tqdm import tqdm 9 | import models 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torchvision import transforms as T 14 | import torch.distributed as dist 15 | from torch.nn.parallel import DistributedDataParallel as DDP 16 | from torch.utils.data.distributed import DistributedSampler 17 | import torchvision.datasets as datasets 18 | 19 | 20 | def same_seeds(seed=18): 21 | np.random.seed(seed) # Numpy module. 22 | random.seed(seed) # Python random module. 23 | torch.manual_seed(seed) 24 | if torch.cuda.is_available(): 25 | torch.cuda.manual_seed(seed) 26 | torch.cuda.manual_seed_all(seed) # if you are using multi-GPU. 27 | torch.backends.cudnn.benchmark = False 28 | torch.backends.cudnn.deterministic = True 29 | 30 | 31 | def save_checkpoint(state, is_best, filename = 'checkpoint.pth.tar'): 32 | torch.save(state, filename) 33 | if is_best: 34 | shutil.copyfile(filename, 'model_best.pth.tar') 35 | 36 | 37 | def train_epoch(epoch, net, train_loader, val_loader , criterion, optimizer, scheduler, device): 38 | """ 39 | Training logic for an epoch 40 | """ 41 | if args.local_rank == 0: 42 | global best_acc1 43 | same_seeds(args.seed_num) 44 | 45 | train_loss = utils.AverageMeter("Epoch losses", ":.4e") 46 | train_acc1 = utils.AverageMeter("Train Acc@1", ":6.2f") 47 | train_acc5 = utils.AverageMeter("Train Acc@5", ":6.2f") 48 | progress_train = utils.ProgressMeter( 49 | num_batches = len(train_loader), 50 | meters = [train_loss, train_acc1, train_acc5], 51 | prefix = 'Epoch: {} '.format(epoch + 1), 52 | batch_info = " Iter" 53 | ) 54 | net.train() 55 | 56 | # Callbacks 57 | # warm_up_cos = lambda epoch: epoch / args.warmup_epochs if epoch <= args.warmup_epochs else 0.5 * (math.cos((epoch - args.warmup_epochs) /(args.epochs - args.warmup_epochs) * math.pi) + 1) 58 | # scheduler_wucos = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=warm_up_cos) 59 | 60 | for it, (inputs, targets) in enumerate(tqdm(train_loader)): 61 | inputs = inputs.to(device) 62 | targets = targets.to(device) 63 | 64 | # Forward pass 65 | outputs = net(inputs) 66 | loss = criterion(outputs, targets) 67 | acc1, acc5 = utils.accuracy(outputs, targets, topk = (1, 5)) 68 | 69 | train_loss.update(loss.item(), inputs.size(0)) 70 | train_acc1.update(acc1.item(), inputs.size(0)) 71 | train_acc5.update(acc5.item(), inputs.size(0)) 72 | if it % args.print_freq == 0: 73 | progress_train.display(it) 74 | 75 | # Backward and optimize 76 | optimizer.zero_grad() 77 
| loss.backward() 78 | optimizer.step() 79 | 80 | # Log on Wandb 81 | if args.local_rank == 0: 82 | wandb.log({ 83 | "Loss/train" : train_loss.avg, 84 | "Acc@1/train" : train_acc1.avg, 85 | "Acc@5/train" : train_acc5.avg, 86 | }) 87 | 88 | # Validation model 89 | val_loss = utils.AverageMeter("Val losses", ":.4e") 90 | val_acc1 = utils.AverageMeter("Val Acc@1", ":6.2f") 91 | val_acc5 = utils.AverageMeter("Val Acc@5", ":6.2f") 92 | progress_val = utils.ProgressMeter( 93 | num_batches = len(val_loader), 94 | meters = [val_loss, val_acc1, val_acc5], 95 | prefix = 'Epoch: {} '.format(epoch + 1), 96 | batch_info = " Iter" 97 | ) 98 | net.eval() 99 | 100 | for it, (inputs, targets) in enumerate(val_loader): 101 | inputs = inputs.to(device) 102 | targets = targets.to(device) 103 | 104 | # Forward pass 105 | with torch.no_grad(): 106 | outputs = net(inputs) 107 | loss = criterion(outputs, targets) 108 | acc1, acc5 = utils.accuracy(outputs, targets, topk=(1, 5)) 109 | val_loss.update(loss.item(), inputs.size(0)) 110 | val_acc1.update(acc1.item(), inputs.size(0)) 111 | val_acc5.update(acc5.item(), inputs.size(0)) 112 | acc1 = val_acc1.avg 113 | 114 | if it % args.print_freq == 0: 115 | progress_val.display(it) 116 | 117 | # Log on Wandb 118 | if args.local_rank == 0: 119 | wandb.log({ 120 | "Loss/val" : val_loss.avg, 121 | "Acc@1/val" : val_acc1.avg, 122 | "Acc@5/val" : val_acc5.avg 123 | }) 124 | # Learning_rate callbacks 125 | scheduler.step() 126 | 127 | if args.local_rank == 0: 128 | is_best = acc1 > best_acc1 129 | best_acc1 = max(acc1, best_acc1) 130 | save_checkpoint({ 131 | 'epoch': epoch + 1, 132 | 'state_dict': net.state_dict(), 133 | 'best_acc1': best_acc1, 134 | 'optimizer' : optimizer.state_dict(), 135 | }, is_best) 136 | return val_loss.avg, val_acc1.avg, val_acc5.avg 137 | 138 | 139 | if __name__ == "__main__": 140 | best_acc1 = 0.0 141 | parser = argparse.ArgumentParser(description = "Train classification of CMT model") 142 | parser.add_argument('--data', metavar = 'DIR', default = '../imagenet_data', 143 | help = 'path to dataset') 144 | parser.add_argument("--gpu_device", type = int, default = 2, 145 | help = "Select specific GPU to run the model") 146 | parser.add_argument('--batch-size', type = int, default = 64, metavar = 'N', 147 | help = 'Input batch size for training (default: 64)') 148 | parser.add_argument('--epochs', type = int, default = 90, metavar = 'N', 149 | help = 'Number of epochs to train (default: 90)') 150 | parser.add_argument('-we', '--warmup_epochs', default=5, type=int, help='epochs for warmup') 151 | parser.add_argument('--num-class', type = int, default = 1000, metavar = 'N', 152 | help = 'Number of classes to classify (default: 10)') 153 | parser.add_argument('--lr', type = float, default = 0.05, metavar='LR', 154 | help = 'Learning rate (default: 6e-5)') 155 | parser.add_argument('--weight-decay', type = float, default = 5e-5, metavar = 'WD', 156 | help = 'Weight decay (default: 1e-5)') 157 | parser.add_argument('-p', '--print-freq', default = 10, type = int, metavar='N', 158 | help='print frequency (default: 10)') 159 | parser.add_argument('--local_rank', type=int, 160 | help='local rank for DistributedDataParallel') 161 | parser.add_argument('-seed', '--seed_num', default=42, type=int, 162 | help='number of random seed') 163 | args = parser.parse_args() 164 | 165 | # Multi GPU 166 | print(f'Running DDP on rank: {args.local_rank}') 167 | torch.cuda.set_device(args.local_rank) 168 | dist.init_process_group(backend = 'nccl', init_method = 'env://') 169 | 
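# Note: torch.distributed.launch starts one worker process per GPU and passes --local_rank to each;
# calling torch.cuda.set_device(args.local_rank) before init_process_group pins every process to its
# own GPU, as the NCCL backend expects.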
170 | # Create folder to save model 171 | WEIGHTS_PATH = "./weights" 172 | if not os.path.exists(WEIGHTS_PATH): 173 | os.makedirs(WEIGHTS_PATH) 174 | 175 | # Set device 176 | # os.environ["CUDA_VISIBLE_DEVICES"] = "2,3" 177 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 178 | 179 | # Data loading 180 | traindir = os.path.join(args.data, 'train') 181 | valdir = os.path.join(args.data, 'val') 182 | normalize = T.Normalize(mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225]) 183 | 184 | train_dataset = datasets.ImageFolder( 185 | traindir, 186 | T.Compose([ 187 | T.RandomResizedCrop(256), 188 | T.RandomHorizontalFlip(), 189 | T.ToTensor(), 190 | normalize, 191 | ])) 192 | train_sampler = DistributedSampler(train_dataset) 193 | val_dataset = datasets.ImageFolder( 194 | valdir, 195 | T.Compose([ 196 | T.Resize(256), 197 | T.ToTensor(), 198 | normalize, 199 | ])) 200 | 201 | train_loader = torch.utils.data.DataLoader( 202 | train_dataset, batch_size = args.batch_size, shuffle = (train_sampler is None), 203 | num_workers = 8, pin_memory = True, sampler=train_sampler 204 | ) 205 | val_loader = torch.utils.data.DataLoader( 206 | val_dataset, batch_size = args.batch_size, shuffle = False, 207 | num_workers = 8, pin_memory = True 208 | ) 209 | 210 | # Create model 211 | net = models.MobileViT_S() 212 | # net.to(device) 213 | # net = torch.nn.DataParallel(net).to(device) 214 | net = net.to(args.local_rank) 215 | net = DDP(net, device_ids=[args.local_rank], output_device=args.local_rank) 216 | 217 | # Set loss function and optimizer 218 | criterion = nn.CrossEntropyLoss().to(device) 219 | optimizer = torch.optim.SGD(net.parameters(), args.lr, 220 | momentum = 0.9, 221 | weight_decay = args.weight_decay) 222 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs) 223 | 224 | # Using wandb for logging 225 | if args.local_rank == 0: 226 | wandb.init() 227 | wandb.config.update(args) 228 | wandb.watch(net) 229 | 230 | # Train the model 231 | for epoch in tqdm(range(args.epochs)): 232 | train_sampler.set_epoch(epoch) 233 | loss, acc1, acc5 = train_epoch(epoch, net, train_loader, 234 | val_loader, criterion, optimizer, scheduler, device 235 | ) 236 | print(f"Epoch {epoch} -> Acc@1: {acc1}, Acc@5: {acc5}") 237 | 238 | print("Training is done") 239 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import * -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wilile26811249/MobileViT/5e61fb72bf77b3445c4bc81008efaf2c68822594/models/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wilile26811249/MobileViT/5e61fb72bf77b3445c4bc81008efaf2c68822594/models/__pycache__/model.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/module.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/wilile26811249/MobileViT/5e61fb72bf77b3445c4bc81008efaf2c68822594/models/__pycache__/module.cpython-37.pyc -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .module import InvertedResidual, MobileVitBlock 5 | 6 | model_cfg = { 7 | "xxs":{ 8 | "features": [16, 16, 24, 24, 48, 48, 64, 64, 80, 80, 320], 9 | "d": [64, 80, 96], 10 | "expansion_ratio": 2, 11 | "layers": [2, 4, 3] 12 | }, 13 | "xs":{ 14 | "features": [16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384], 15 | "d": [96, 120, 144], 16 | "expansion_ratio": 4, 17 | "layers": [2, 4, 3] 18 | }, 19 | "s":{ 20 | "features": [16, 32, 64, 64, 96, 96, 128, 128, 160, 160, 640], 21 | "d": [144, 192, 240], 22 | "expansion_ratio": 4, 23 | "layers": [2, 4, 3] 24 | }, 25 | } 26 | 27 | class MobileViT(nn.Module): 28 | def __init__(self, img_size, features_list, d_list, transformer_depth, expansion, num_classes = 1000): 29 | super(MobileViT, self).__init__() 30 | 31 | self.stem = nn.Sequential( 32 | nn.Conv2d(in_channels = 3, out_channels = features_list[0], kernel_size = 3, stride = 2, padding = 1), 33 | InvertedResidual(in_channels = features_list[0], out_channels = features_list[1], stride = 1, expand_ratio = expansion), 34 | ) 35 | 36 | self.stage1 = nn.Sequential( 37 | InvertedResidual(in_channels = features_list[1], out_channels = features_list[2], stride = 2, expand_ratio = expansion), 38 | InvertedResidual(in_channels = features_list[2], out_channels = features_list[2], stride = 1, expand_ratio = expansion), 39 | InvertedResidual(in_channels = features_list[2], out_channels = features_list[3], stride = 1, expand_ratio = expansion) 40 | ) 41 | 42 | self.stage2 = nn.Sequential( 43 | InvertedResidual(in_channels = features_list[3], out_channels = features_list[4], stride = 2, expand_ratio = expansion), 44 | MobileVitBlock(in_channels = features_list[4], out_channels = features_list[5], d_model = d_list[0], 45 | layers = transformer_depth[0], mlp_dim = d_list[0] * 2) 46 | ) 47 | 48 | self.stage3 = nn.Sequential( 49 | InvertedResidual(in_channels = features_list[5], out_channels = features_list[6], stride = 2, expand_ratio = expansion), 50 | MobileVitBlock(in_channels = features_list[6], out_channels = features_list[7], d_model = d_list[1], 51 | layers = transformer_depth[1], mlp_dim = d_list[1] * 4) 52 | ) 53 | 54 | self.stage4 = nn.Sequential( 55 | InvertedResidual(in_channels = features_list[7], out_channels = features_list[8], stride = 2, expand_ratio = expansion), 56 | MobileVitBlock(in_channels = features_list[8], out_channels = features_list[9], d_model = d_list[2], 57 | layers = transformer_depth[2], mlp_dim = d_list[2] * 4), 58 | nn.Conv2d(in_channels = features_list[9], out_channels = features_list[10], kernel_size = 1, stride = 1, padding = 0) 59 | ) 60 | 61 | self.avgpool = nn.AvgPool2d(kernel_size = img_size // 32) 62 | self.fc = nn.Linear(features_list[10], num_classes) 63 | 64 | 65 | def forward(self, x): 66 | # Stem 67 | x = self.stem(x) 68 | # Body 69 | x = self.stage1(x) 70 | x = self.stage2(x) 71 | x = self.stage3(x) 72 | x = self.stage4(x) 73 | # Head 74 | x = self.avgpool(x) 75 | x = x.view(x.size(0), -1) 76 | x = self.fc(x) 77 | return x 78 | 79 | 80 | def MobileViT_XXS(img_size = 256, num_classes = 1000): 81 | cfg_xxs = model_cfg["xxs"] 82 | model_xxs = MobileViT(img_size, cfg_xxs["features"], cfg_xxs["d"], 
cfg_xxs["layers"], cfg_xxs["expansion_ratio"], num_classes) 83 | return model_xxs 84 | 85 | def MobileViT_XS(img_size = 256, num_classes = 1000): 86 | cfg_xs = model_cfg["xs"] 87 | model_xs = MobileViT(img_size, cfg_xs["features"], cfg_xs["d"], cfg_xs["layers"], cfg_xs["expansion_ratio"], num_classes) 88 | return model_xs 89 | 90 | def MobileViT_S(img_size = 256, num_classes = 1000): 91 | cfg_s = model_cfg["s"] 92 | model_s = MobileViT(img_size, cfg_s["features"], cfg_s["d"], cfg_s["layers"], cfg_s["expansion_ratio"], num_classes) 93 | return model_s 94 | 95 | 96 | if __name__ == "__main__": 97 | img = torch.randn(1, 3, 256, 256) 98 | 99 | cfg_xxs = model_cfg["xxs"] 100 | model_xxs = MobileViT(256, cfg_xxs["features"], cfg_xxs["d"], cfg_xxs["layers"], cfg_xxs["expansion_ratio"]) 101 | 102 | cfg_xs = model_cfg["xs"] 103 | model_xs = MobileViT(256, cfg_xs["features"], cfg_xs["d"], cfg_xs["layers"], cfg_xs["expansion_ratio"]) 104 | 105 | cfg_s = model_cfg["s"] 106 | model_s = MobileViT(256, cfg_s["features"], cfg_s["d"], cfg_s["layers"], cfg_s["expansion_ratio"]) 107 | 108 | print(model_s) 109 | 110 | # XXS: 1.3M 、 XS: 2.3M 、 S: 5.6M 111 | print("XXS params: ", sum(p.numel() for p in model_xxs.parameters())) 112 | print(" XS params: ", sum(p.numel() for p in model_xs.parameters())) 113 | print(" S params: ", sum(p.numel() for p in model_s.parameters())) -------------------------------------------------------------------------------- /models/module.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Any, Optional, List 2 | from einops import rearrange 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class ConvNormAct(nn.Module): 9 | def __init__(self, 10 | in_channels: int, 11 | out_channels: int, 12 | kernel_size: int = 3, 13 | stride = 1, 14 | padding: Optional[int] = None, 15 | groups: int = 1, 16 | norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d, 17 | activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.SiLU, 18 | dilation: int = 1 19 | ): 20 | super(ConvNormAct, self).__init__() 21 | if padding is None: 22 | padding = (kernel_size - 1) // 2 * dilation 23 | self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, 24 | dilation = dilation, groups = groups, bias = norm_layer is None) 25 | 26 | self.norm_layer = nn.BatchNorm2d(out_channels) if norm_layer is None else norm_layer(out_channels) 27 | self.act = activation_layer() if activation_layer is not None else activation_layer 28 | 29 | def forward(self, x: torch.Tensor) -> torch.Tensor: 30 | x = self.conv(x) 31 | if self.norm_layer is not None: 32 | x = self.norm_layer(x) 33 | if self.act is not None: 34 | x = self.act(x) 35 | return x 36 | 37 | 38 | class PreNorm(nn.Module): 39 | def __init__(self, dim, fn): 40 | super().__init__() 41 | self.norm = nn.LayerNorm(dim) 42 | self.fn = fn 43 | 44 | def forward(self, x, **kwargs): 45 | return self.fn(self.norm(x), **kwargs) 46 | 47 | 48 | class FFN(nn.Module): 49 | def __init__(self, dim, hidden_dim, dropout=0.): 50 | super().__init__() 51 | self.net = nn.Sequential( 52 | nn.Linear(dim, hidden_dim), 53 | nn.SiLU(), 54 | nn.Dropout(dropout), 55 | nn.Linear(hidden_dim, dim), 56 | nn.Dropout(dropout) 57 | ) 58 | 59 | def forward(self, x): 60 | return self.net(x) 61 | 62 | 63 | class MultiHeadSelfAttention(nn.Module): 64 | """ 65 | Implement multi head self attention layer using the "Einstein summation convention". 
66 | Paper: https://arxiv.org/abs/1706.03762 67 | Blog: https://towardsdatascience.com/illustrated-self-attention-2d627e33b20a 68 | Parameters 69 | ---------- 70 | dim: 71 | Token's dimension, EX: word embedding vector size 72 | num_heads: 73 | The number of distinct representations to learn 74 | dim_head: 75 | The dimension of the each head 76 | """ 77 | def __init__(self, dim, num_heads = 8, dim_head = None): 78 | super(MultiHeadSelfAttention, self).__init__() 79 | self.num_heads = num_heads 80 | self.dim_head = int(dim / num_heads) if dim_head is None else dim_head 81 | _weight_dim = self.num_heads * self.dim_head 82 | self.to_qvk = nn.Linear(dim, _weight_dim * 3, bias = False) 83 | self.scale_factor = dim ** -0.5 84 | 85 | self.scale_factor = dim ** -0.5 86 | 87 | # Weight matrix for output, Size: num_heads*dim_head X dim 88 | # Final linear transformation layer 89 | self.w_out = nn.Linear(_weight_dim, dim, bias = False) 90 | 91 | def forward(self, x): 92 | qkv = self.to_qvk(x).chunk(3, dim = -1) 93 | q, k, v = map(lambda t: rearrange(t, 'b p n (h d) -> b p h n d', h = self.num_heads), qkv) 94 | 95 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale_factor 96 | attn = torch.softmax(dots, dim = -1) 97 | out = torch.matmul(attn, v) 98 | out = rearrange(out, 'b p h n d -> b p n (h d)') 99 | return self.w_out(out) 100 | 101 | 102 | class Transformer(nn.Module): 103 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.1): 104 | super().__init__() 105 | self.layers = nn.ModuleList([]) 106 | for _ in range(depth): 107 | self.layers.append(nn.ModuleList([ 108 | PreNorm(dim, MultiHeadSelfAttention(dim, heads, dim_head)), 109 | PreNorm(dim, FFN(dim, mlp_dim, dropout)) 110 | ])) 111 | 112 | def forward(self, x): 113 | for attn, ff in self.layers: 114 | x = attn(x) + x 115 | x = ff(x) + x 116 | return x 117 | 118 | 119 | class InvertedResidual(nn.Module): 120 | """ 121 | MobileNetv2 InvertedResidual block 122 | """ 123 | def __init__(self, in_channels, out_channels, stride = 1, expand_ratio = 2, act_layer = nn.SiLU): 124 | super(InvertedResidual, self).__init__() 125 | self.stride = stride 126 | self.use_res_connect = self.stride == 1 and in_channels == out_channels 127 | hidden_dim = int(round(in_channels * expand_ratio)) 128 | 129 | layers = [] 130 | if expand_ratio != 1: 131 | layers.append(ConvNormAct(in_channels, hidden_dim, kernel_size = 1, activation_layer = None)) 132 | 133 | # Depth-wise convolution 134 | layers.append( 135 | ConvNormAct(hidden_dim, hidden_dim, kernel_size = 3, stride = stride, 136 | padding = 1, groups = hidden_dim, activation_layer = act_layer) 137 | ) 138 | # Point-wise convolution 139 | layers.append( 140 | nn.Conv2d(hidden_dim, out_channels, kernel_size = 1, stride = 1, bias = False) 141 | ) 142 | layers.append(nn.BatchNorm2d(out_channels)) 143 | 144 | self.conv = nn.Sequential(*layers) 145 | 146 | def forward(self, x): 147 | if self.use_res_connect: 148 | return x + self.conv(x) 149 | else: 150 | return self.conv(x) 151 | 152 | 153 | class MobileVitBlock(nn.Module): 154 | def __init__(self, in_channels, out_channels, d_model, layers, mlp_dim): 155 | super(MobileVitBlock, self).__init__() 156 | # Local representation 157 | self.local_representation = nn.Sequential( 158 | # Encode local spatial information 159 | ConvNormAct(in_channels, in_channels, 3), 160 | # Projects the tensor to a high-diementional space 161 | ConvNormAct(in_channels, d_model, 1) 162 | ) 163 | 164 | self.transformer = Transformer(d_model, layers, 1, 32, mlp_dim, 0.1) 165 
| 166 | # Fusion block 167 | self.fusion_block1 = nn.Conv2d(d_model, in_channels, kernel_size = 1) 168 | self.fusion_block2 = nn.Conv2d(in_channels * 2, out_channels, 3, padding = 1) 169 | 170 | def forward(self, x): 171 | local_repr = self.local_representation(x) 172 | # global_repr = self.global_representation(local_repr) 173 | _, _, h, w = local_repr.shape 174 | global_repr = rearrange(local_repr, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=2, pw=2) 175 | global_repr = self.transformer(global_repr) 176 | global_repr = rearrange(global_repr, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h//2, w=w//2, ph=2, pw=2) 177 | 178 | # Fuse the local and gloval features in the concatenation tensor 179 | fuse_repr = self.fusion_block1(global_repr) 180 | result = self.fusion_block2(torch.cat([x, fuse_repr], dim = 1)) 181 | return result 182 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import numpy as np 4 | import torch 5 | from torchvision import datasets 6 | from torchvision import transforms as T 7 | 8 | class AverageMeter(object): 9 | def __init__(self, 10 | name: str, 11 | fmt: Optional[str] = ':f', 12 | ) -> None: 13 | self.name = name 14 | self.fmt = fmt 15 | self.reset() 16 | 17 | def reset(self): 18 | self.val = 0 19 | self.avg = 0 20 | self.sum = 0 21 | self.count = 0 22 | 23 | def update(self, 24 | val: float, 25 | n: Optional[int] = 1 26 | ) -> None: 27 | self.val = val 28 | self.sum += val * n 29 | self.count += n 30 | self.avg = self.sum / self.count 31 | 32 | def __str__(self): 33 | fmtstr = '{name}:{val' + self.fmt + '}({avg' + self.fmt + '})' 34 | return fmtstr.format(**self.__dict__) 35 | 36 | 37 | class ProgressMeter(object): 38 | def __init__(self, 39 | num_batches: int, 40 | meters: List[AverageMeter], 41 | prefix: Optional[str] = "", 42 | batch_info: Optional[str] = "" 43 | ) -> None: 44 | self.batch_fmster = self._get_batch_fmster(num_batches) 45 | self.meters = meters 46 | self.prefix = prefix 47 | self.batch_info = batch_info 48 | 49 | def display(self, batch): 50 | self.info = [self.prefix + self.batch_info + self.batch_fmster.format(batch)] 51 | self.info += [str(meter) for meter in self.meters] 52 | print('\t'.join(self.info)) 53 | 54 | def _get_batch_fmster(self, num_batches): 55 | num_digits = len(str(num_batches // 1)) 56 | fmt = '{:' + str(num_digits) + 'd}' 57 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 58 | 59 | 60 | class EarlyStopping(object): 61 | """ 62 | Arg 63 | """ 64 | def __init__(self, 65 | patience: int = 7, 66 | verbose: Optional[bool] = False, 67 | delta: Optional[float] = 0.0, 68 | path: Optional[str] = "checkpoint.pt" 69 | ) -> None: 70 | self.patience = patience 71 | self.verbose = verbose 72 | self.counter = 0 73 | self.best_score = None 74 | self.early_stop_flag = False 75 | self.val_loss_min = np.Inf 76 | self.delta = delta 77 | self.verbose = verbose 78 | self.path = path 79 | 80 | def __call__(self, val_loss, model): 81 | score = abs(val_loss) 82 | if self.best_score is None: 83 | self.best_score = score 84 | self.save_model(val_loss, model) 85 | elif val_loss > self.val_loss_min + self.delta: 86 | self.counter += 1 87 | if self.verbose: 88 | print(f"EarlyStopping Counter: {self.counter} out of {self.patience}") 89 | print(f"Best val loss: {self.val_loss_min} Current val loss: {score}") 90 | if self.counter >= self.patience: 91 | self.early_stop_flag = True 92 
| else: 93 | self.best_score = score 94 | self.save_model(val_loss, model) 95 | self.counter = 0 96 | 97 | def save_model(self, val_loss, model): 98 | if self.verbose: 99 | print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...') 100 | torch.save(model.state_dict(), self.path) 101 | self.val_loss_min = val_loss 102 | 103 | 104 | def accuracy(output, target, topk = (1,)): 105 | """ 106 | Computes the accuracy over the top k predictions 107 | """ 108 | with torch.no_grad(): 109 | max_k = max(topk) 110 | batch_size = output.size(0) 111 | 112 | _, pred = output.topk(max_k, 113 | dim = 1, 114 | largest = True, 115 | sorted = True 116 | ) 117 | pred = pred.t() 118 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 119 | 120 | result = [] 121 | for k in topk: 122 | correct_k = correct[: k].contiguous().view(-1).float().sum(0, keepdim = True) 123 | result.append(correct_k.mul_(100.0 / batch_size)) 124 | return result 125 | 126 | 127 | def get_lr(optimizer): 128 | for param_group in optimizer.param_groups: 129 | return param_group['lr'] 130 | 131 | 132 | def get_cifar10_dataset(train_transform = None, test_transform = None): 133 | train_dataset = datasets.CIFAR10( 134 | root = './data', 135 | train = True, 136 | transform = train_transform, 137 | download = True 138 | ) 139 | test_dataset = datasets.CIFAR10( 140 | root = './data', 141 | train = False, 142 | transform = test_transform, 143 | download = True 144 | ) 145 | return train_dataset, test_dataset 146 | 147 | 148 | def get_dataloader( 149 | train_transform, 150 | test_transform, 151 | img_size = 224, 152 | split = (0.8, 0.2), 153 | **kwargs 154 | ): 155 | assert len(split) == 2 156 | assert sum(split) == 1 157 | assert split[0] + split[1] == 1 158 | 159 | train_dataset, test_dataset = get_cifar10_dataset(train_transform, test_transform) 160 | train_size = int(len(train_dataset) * split[0]) 161 | test_size = int(len(train_dataset) * split[1]) 162 | train_dataset, val_dataset = torch.utils.data.random_split( 163 | train_dataset, 164 | (train_size, test_size) 165 | ) 166 | 167 | train_loader = torch.utils.data.DataLoader( 168 | train_dataset, 169 | batch_size = kwargs['batch_size'], 170 | shuffle = True, 171 | num_workers = kwargs['num_workers'], 172 | pin_memory = True, 173 | drop_last = True, 174 | sampler = None 175 | ) 176 | val_loader = torch.utils.data.DataLoader( 177 | val_dataset, 178 | batch_size = kwargs['batch_size'], 179 | shuffle = False, 180 | num_workers = kwargs['num_workers'], 181 | pin_memory = True, 182 | drop_last = False, 183 | sampler = None 184 | ) 185 | test_loader = torch.utils.data.DataLoader( 186 | test_dataset, 187 | batch_size = kwargs['batch_size'], 188 | shuffle = False, 189 | num_workers = kwargs['num_workers'], 190 | pin_memory = True, 191 | drop_last = False, 192 | sampler = None 193 | ) 194 | return train_loader, val_loader, test_loader --------------------------------------------------------------------------------
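The CIFAR-10 helpers in `utils.py` (`get_cifar10_dataset`, `get_dataloader`) are not wired into either training script. Below is a hedged sketch of how they might be combined with the model factories for a small fine-tuning run; the transforms, batch size, and `num_classes = 10` are assumptions, and the normalization constants are simply reused from `main.py`:

```python
import torch
from torchvision import transforms as T

import models
import utils

# ImageNet normalization constants reused from main.py (an assumption for CIFAR-10)
normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Upsample 32x32 CIFAR-10 images to 256x256 so MobileViT's average pool (img_size // 32) lines up
train_tf = T.Compose([T.Resize(256), T.RandomHorizontalFlip(), T.ToTensor(), normalize])
test_tf = T.Compose([T.Resize(256), T.ToTensor(), normalize])

train_loader, val_loader, test_loader = utils.get_dataloader(
    train_tf, test_tf, img_size=256, split=(0.8, 0.2),
    batch_size=64, num_workers=4,
)

net = models.MobileViT_S(img_size=256, num_classes=10)
print(len(train_loader.dataset), len(val_loader.dataset), len(test_loader.dataset))
```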