├── src ├── .gitkeep ├── data │ ├── __init__.py │ ├── rasampler.py │ └── myloader.py ├── gcn_lib │ ├── __init__.py │ ├── torch_local.py │ ├── pos_embed.py │ ├── torch_nn.py │ ├── torch_edge.py │ └── torch_vertex.py ├── dataloaders │ └── celeba_hq.py ├── utils │ └── metrics.py ├── model │ ├── wignn_256.py │ ├── transfer_models.py │ ├── wignn.py │ ├── pyramid_vig.py │ ├── mobilevig.py │ └── greedyvig.py ├── opt_transfer.py ├── opt.py ├── train_trasnfer_learning.py ├── utils.py └── train.py ├── imgs ├── tab_inet.png ├── complexity.png ├── res_celebahq.png └── teaser_wignet.png ├── LICENSE ├── README.md ├── .gitignore └── env_wignet.yaml /src/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .myloader import create_loader 2 | -------------------------------------------------------------------------------- /imgs/tab_inet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EIDOSLAB/WiGNet/HEAD/imgs/tab_inet.png -------------------------------------------------------------------------------- /imgs/complexity.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EIDOSLAB/WiGNet/HEAD/imgs/complexity.png -------------------------------------------------------------------------------- /imgs/res_celebahq.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EIDOSLAB/WiGNet/HEAD/imgs/res_celebahq.png -------------------------------------------------------------------------------- /imgs/teaser_wignet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EIDOSLAB/WiGNet/HEAD/imgs/teaser_wignet.png -------------------------------------------------------------------------------- /src/gcn_lib/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .torch_nn import * 3 | from .torch_edge import * 4 | from .torch_vertex import * 5 | 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MXM License 2 | 3 | The copyright in this software is being made available under the BSD License, included below. This software may be subject to other third party and contributor rights, including patent rights, and no such rights are granted under this license. 4 | 5 | OWNER = University of Turin, IMT, Sisvel Technology. 6 | 7 | ORGANIZATION = University of Turin, IMT, Sisvel Technology. 8 | 9 | YEAR = 2024 10 | 11 | Copyright (c) 2024, University of Turin, IMT, Sisvel Technology. 12 | 13 | All rights reserved. 14 | 15 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 16 | 17 | - Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 18 | - Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 
19 | - Neither the name of the University of Turin, IMT, Sisvel Technology nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 22 | 23 | -------------------------------------------------------------------------------- /src/dataloaders/celeba_hq.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | 6 | import torchvision 7 | from torchvision import datasets, models, transforms 8 | 9 | import numpy as np 10 | import matplotlib.pyplot as plt 11 | 12 | import time 13 | import os 14 | 15 | 16 | 17 | def get_celeba(args, get_train_sampler = False, transform_train = True, crop_size = 224, drop_last=False): 18 | 19 | transforms_augs = transforms.Compose([ 20 | transforms.Resize((crop_size, crop_size)), 21 | transforms.RandomHorizontalFlip(), # data augmentation 22 | transforms.ToTensor(), 23 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # normalization 24 | ]) 25 | 26 | transforms_no_augs = transforms.Compose([ 27 | transforms.Resize((crop_size, crop_size)), 28 | transforms.ToTensor(), 29 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 30 | ]) 31 | 32 | if transform_train: 33 | train_transforms = transforms_augs 34 | else: 35 | train_transforms = transforms_no_augs 36 | 37 | test_transforms = transforms_no_augs 38 | 39 | data_dir = f'{args.root}/CelebA_HQ_facial_identity_dataset' 40 | train_dataset = datasets.ImageFolder(os.path.join(data_dir, 'train'), train_transforms) 41 | test_dataset = datasets.ImageFolder(os.path.join(data_dir, 'test'), test_transforms) 42 | 43 | dataset_labels = train_dataset.classes 44 | num_classes = len(dataset_labels) 45 | 46 | train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle= transform_train, 47 | num_workers=args.workers, pin_memory=True, sampler=None, 48 | persistent_workers=args.workers > 0,drop_last=drop_last) 49 | 50 | test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, 51 | num_workers=args.workers, pin_memory=True, sampler=None, 52 | persistent_workers=args.workers > 0,drop_last=drop_last) 53 | 54 | if(get_train_sampler): 55 | return train_dataloader, None, test_dataloader, None, num_classes, dataset_labels 56 | 57 | return train_dataloader, None, test_dataloader, num_classes, dataset_labels 58 | 59 | 60 | 61 | 62 | if __name__ == '__main__': 63 | 64 | pass -------------------------------------------------------------------------------- /src/data/rasampler.py: -------------------------------------------------------------------------------- 1 | # from 
https://github.com/facebookresearch/deit/blob/main/samplers.py 2 | # Copyright (c) 2015-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the CC-by-NC license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | # 8 | import torch 9 | import torch.distributed as dist 10 | import math 11 | 12 | 13 | class RASampler(torch.utils.data.Sampler): 14 | """Sampler that restricts data loading to a subset of the dataset for distributed, 15 | with repeated augmentation. 16 | It ensures that each augmented version of a sample will be visible to a 17 | different process (GPU). 18 | Heavily based on torch.utils.data.DistributedSampler 19 | """ 20 | 21 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True): 22 | if num_replicas is None: 23 | if not dist.is_available(): 24 | raise RuntimeError("Requires distributed package to be available") 25 | num_replicas = dist.get_world_size() 26 | if rank is None: 27 | if not dist.is_available(): 28 | raise RuntimeError("Requires distributed package to be available") 29 | rank = dist.get_rank() 30 | self.dataset = dataset 31 | self.num_replicas = num_replicas 32 | self.rank = rank 33 | self.epoch = 0 34 | self.num_samples = int(math.ceil(len(self.dataset) * 3.0 / self.num_replicas)) 35 | self.total_size = self.num_samples * self.num_replicas 36 | # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas)) 37 | self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas)) 38 | self.shuffle = shuffle 39 | 40 | def __iter__(self): 41 | # deterministically shuffle based on epoch 42 | g = torch.Generator() 43 | g.manual_seed(self.epoch) 44 | if self.shuffle: 45 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 46 | else: 47 | indices = list(range(len(self.dataset))) 48 | 49 | # add extra samples to make it evenly divisible 50 | indices = [ele for ele in indices for i in range(3)] 51 | indices += indices[:(self.total_size - len(indices))] 52 | assert len(indices) == self.total_size 53 | 54 | # subsample 55 | indices = indices[self.rank:self.total_size:self.num_replicas] 56 | assert len(indices) == self.num_samples 57 | 58 | return iter(indices[:self.num_selected_samples]) 59 | 60 | def __len__(self): 61 | return self.num_selected_samples 62 | 63 | def set_epoch(self, epoch): 64 | self.epoch = epoch -------------------------------------------------------------------------------- /src/gcn_lib/torch_local.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from timm.models.layers import to_2tuple 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | def window_partition_channel_last(x, window_size=8): 7 | """ 8 | Args: 9 | x: (B, H, W, C) 10 | window_size (int): window size 11 | Returns: 12 | windows: (num_windows*B, window_size, window_size, C) 13 | """ 14 | B, H, W, C = x.shape 15 | windows = x.view(B, H // window_size, window_size, W // window_size, window_size, C) 16 | windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 17 | return windows 18 | 19 | def window_partition(x, window_size=7): 20 | """ 21 | Args: 22 | x: (B, C, H, W) 23 | window_size (int): window size 24 | Returns: 25 | windows: (num_windows*B, C, window_size, window_size) 26 | """ 27 | x = x.transpose(1,2).transpose(2,3) 28 | B, H, W, C = x.shape 29 | windows = x.view(B, H // window_size, window_size, W // window_size, 
window_size, C) 30 | windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 31 | 32 | windows = windows.transpose(2,3).transpose(1,2) 33 | return windows 34 | 35 | def window_reverse(windows, window_size, H, W): 36 | """ 37 | Args: 38 | windows: (num_windows*B, C, window_size, window_size) 39 | window_size (int): Window size 40 | H (int): Height of image 41 | W (int): Width of image 42 | Returns: 43 | x: (B, C, H, W) 44 | """ 45 | windows = windows.transpose(1,2).transpose(2,3) 46 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 47 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) 48 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 49 | x = x.transpose(2,3).transpose(1,2) 50 | return x 51 | 52 | 53 | class PatchEmbed(nn.Module): 54 | def __init__(self, patch_size=4, in_chans=3, embed_dim=96):#, norm_layer=None): 55 | super().__init__() 56 | patch_size = to_2tuple(patch_size) 57 | self.patch_size = patch_size 58 | 59 | self.in_chans = in_chans 60 | self.embed_dim = embed_dim 61 | 62 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 63 | self.norm = nn.BatchNorm2d(embed_dim) 64 | 65 | def forward(self, x): 66 | """Forward function.""" 67 | # padding 68 | _, _, H, W = x.size() 69 | if W % self.patch_size[1] != 0: 70 | x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1])) 71 | if H % self.patch_size[0] != 0: 72 | x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) 73 | 74 | x = self.proj(x) # B C Wh Ww 75 | x = self.norm(x) 76 | 77 | return x -------------------------------------------------------------------------------- /src/gcn_lib/pos_embed.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | 4 | import torch 5 | 6 | # -------------------------------------------------------- 7 | # relative position embedding 8 | # References: https://arxiv.org/abs/2009.13658 9 | # -------------------------------------------------------- 10 | def get_2d_relative_pos_embed(embed_dim, grid_size): 11 | """ 12 | grid_size: int of the grid height and width 13 | return: 14 | pos_embed: [grid_size*grid_size, grid_size*grid_size] 15 | """ 16 | pos_embed = get_2d_sincos_pos_embed(embed_dim, grid_size) 17 | relative_pos = 2 * np.matmul(pos_embed, pos_embed.transpose()) / pos_embed.shape[1] 18 | return relative_pos 19 | 20 | 21 | # -------------------------------------------------------- 22 | # 2D sine-cosine position embedding 23 | # References: 24 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py 25 | # MoCo v3: https://github.com/facebookresearch/moco-v3 26 | # -------------------------------------------------------- 27 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 28 | """ 29 | grid_size: int of the grid height and width 30 | return: 31 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 32 | """ 33 | grid_h = np.arange(grid_size, dtype=np.float32) 34 | grid_w = np.arange(grid_size, dtype=np.float32) 35 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 36 | grid = np.stack(grid, axis=0) 37 | 38 | grid = grid.reshape([2, 1, grid_size, grid_size]) 39 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 40 | if cls_token: 41 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 42 | return pos_embed 43 | 44 | 45 
| def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 46 | assert embed_dim % 2 == 0 47 | 48 | # use half of dimensions to encode grid_h 49 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 50 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 51 | 52 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 53 | return emb 54 | 55 | 56 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 57 | """ 58 | embed_dim: output dimension for each position 59 | pos: a list of positions to be encoded: size (M,) 60 | out: (M, D) 61 | """ 62 | assert embed_dim % 2 == 0 63 | omega = np.arange(embed_dim // 2, dtype=np.float64) # np.float64 replaces the deprecated np.float alias 64 | omega /= embed_dim / 2. 65 | omega = 1. / 10000**omega # (D/2,) 66 | 67 | pos = pos.reshape(-1) # (M,) 68 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 69 | 70 | emb_sin = np.sin(out) # (M, D/2) 71 | emb_cos = np.cos(out) # (M, D/2) 72 | 73 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 74 | return emb 75 | -------------------------------------------------------------------------------- /src/utils/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | from enum import Enum 4 | 5 | class Summary(Enum): 6 | NONE = 0 7 | AVERAGE = 1 8 | SUM = 2 9 | COUNT = 3 10 | 11 | class AverageMeter(object): 12 | """Computes and stores the average and current value""" 13 | def __init__(self, name, fmt=':f', summary_type=Summary.AVERAGE): 14 | self.name = name 15 | self.fmt = fmt 16 | self.summary_type = summary_type 17 | self.reset() 18 | 19 | def reset(self): 20 | self.val = 0 21 | self.avg = 0 22 | self.sum = 0 23 | self.count = 0 24 | 25 | def update(self, val, n=1): 26 | self.val = val 27 | self.sum += val * n 28 | self.count += n 29 | self.avg = self.sum / self.count 30 | 31 | def all_reduce(self): 32 | if torch.cuda.is_available(): 33 | device = torch.device("cuda") 34 | elif torch.backends.mps.is_available(): 35 | device = torch.device("mps") 36 | else: 37 | device = torch.device("cpu") 38 | total = torch.tensor([self.sum, self.count], dtype=torch.float32, device=device) 39 | dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False) 40 | self.sum, self.count = total.tolist() 41 | self.avg = self.sum / self.count 42 | 43 | def __str__(self): 44 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 45 | return fmtstr.format(**self.__dict__) 46 | 47 | def summary(self): 48 | fmtstr = '' 49 | if self.summary_type is Summary.NONE: 50 | fmtstr = '' 51 | elif self.summary_type is Summary.AVERAGE: 52 | fmtstr = '{name} {avg:.3f}' 53 | elif self.summary_type is Summary.SUM: 54 | fmtstr = '{name} {sum:.3f}' 55 | elif self.summary_type is Summary.COUNT: 56 | fmtstr = '{name} {count:.3f}' 57 | else: 58 | raise ValueError('invalid summary type %r' % self.summary_type) 59 | 60 | return fmtstr.format(**self.__dict__) 61 | 62 | 63 | class ProgressMeter(object): 64 | def __init__(self, num_batches, meters, prefix=""): 65 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 66 | self.meters = meters 67 | self.prefix = prefix 68 | 69 | def display(self, batch): 70 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 71 | entries += [str(meter) for meter in self.meters] 72 | print('\t'.join(entries)) 73 | 74 | def display_summary(self): 75 | entries = [" *"] 76 | entries += [meter.summary() for meter in self.meters] 77 | print(' '.join(entries)) 78 | 79 | def _get_batch_fmtstr(self, 
num_batches): 80 | num_digits = len(str(num_batches // 1)) 81 | fmt = '{:' + str(num_digits) + 'd}' 82 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 83 | 84 | def accuracy(output, target, topk=(1,)): 85 | """Computes the accuracy over the k top predictions for the specified values of k""" 86 | with torch.no_grad(): 87 | maxk = max(topk) 88 | batch_size = target.size(0) 89 | 90 | _, pred = output.topk(maxk, 1, True, True) 91 | pred = pred.t() 92 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 93 | 94 | res = [] 95 | for k in topk: 96 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 97 | res.append(correct_k.mul_(100.0 / batch_size)) 98 | return res -------------------------------------------------------------------------------- /src/model/wignn_256.py: -------------------------------------------------------------------------------- 1 | from timm.models.registry import register_model 2 | from model.wignn import DeepGCN 3 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 4 | from timm.models import create_model 5 | from torchprofile import profile_macs 6 | 7 | import torch 8 | 9 | def _cfg(url='', **kwargs): 10 | return { 11 | 'url': url, 12 | 'num_classes': 1000, 'input_size': (3, 256, 256), 'pool_size': None, 13 | 'crop_pct': .9, 'interpolation': 'bicubic', 14 | 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 15 | 'first_conv': 'patch_embed.proj', 'classifier': 'head', 16 | **kwargs 17 | } 18 | 19 | 20 | default_cfgs = { 21 | 'wignn_256_gelu': _cfg( 22 | mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), 23 | ), 24 | 'wignn_b_256_gelu': _cfg( 25 | crop_pct=0.95, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), 26 | ), 27 | } 28 | 29 | 30 | class OptInit: 31 | def __init__( 32 | self, 33 | num_classes=1000, 34 | drop_path_rate=0.0, 35 | knn = 9, 36 | use_shifts = True, 37 | use_reduce_ratios = False, 38 | img_size = 256, 39 | adapt_knn = True, 40 | 41 | channels = None, 42 | blocks = None, 43 | **kwargs): 44 | 45 | self.k = knn # neighbor num (default:9) 46 | self.conv = 'mr' # graph conv layer {edge, mr} 47 | self.act = 'gelu' # activation layer {relu, prelu, leakyrelu, gelu, hswish} 48 | self.norm = 'batch' # batch or instance normalization {batch, instance} 49 | self.bias = True # bias of conv layer True or False 50 | self.dropout = 0.0 # dropout rate 51 | self.use_dilation = True # use dilated knn or not 52 | self.epsilon = 0.2 # stochastic epsilon for gcn 53 | self.use_stochastic = False # stochastic for gcn, True or False 54 | self.drop_path = drop_path_rate 55 | self.blocks = blocks # number of basic blocks in the backbone 56 | self.channels = channels # number of channels of deep features 57 | self.n_classes = num_classes # Dimension of out_channels 58 | self.emb_dims = 1024 # Dimension of embeddings 59 | self.windows_size = 8 60 | self.use_shifts = use_shifts 61 | self.img_size = img_size 62 | self.use_reduce_ratios = use_reduce_ratios 63 | self.adapt_knn = adapt_knn 64 | 65 | @register_model 66 | def wignn_ti_256_gelu(pretrained=False, **kwargs): 67 | 68 | opt = OptInit(**kwargs, channels = [48, 96, 240, 384], blocks= [2,2,6,2]) 69 | model = DeepGCN(opt) 70 | model.default_cfg = default_cfgs['wignn_256_gelu'] 71 | return model 72 | 73 | 74 | @register_model 75 | def wignn_s_256_gelu(pretrained=False, **kwargs): 76 | 77 | opt = OptInit(**kwargs, channels = [80, 160, 400, 640], blocks= [2,2,6,2]) 78 | model = DeepGCN(opt) 79 | model.default_cfg = default_cfgs['wignn_256_gelu'] 80 | return model 81 | 82 | 83 | @register_model 84 | 
def wignn_m_256_gelu(pretrained=False, **kwargs): 85 | 86 | opt = OptInit(**kwargs, channels = [96, 192, 384, 768], blocks= [2,2,16,2]) 87 | 88 | model = DeepGCN(opt) 89 | model.default_cfg = default_cfgs['wignn_256_gelu'] 90 | return model 91 | 92 | 93 | @register_model 94 | def wignn_b_256_gelu(pretrained=False, **kwargs): 95 | 96 | opt = OptInit(**kwargs, channels = [128, 256, 512, 1024], blocks= [2,2,18,2]) 97 | 98 | model = DeepGCN(opt) 99 | model.default_cfg = default_cfgs['wignn_b_256_gelu'] 100 | return model 101 | 102 | 103 | 104 | if __name__ == '__main__': 105 | 106 | pass 107 | 108 | 109 | -------------------------------------------------------------------------------- /src/gcn_lib/torch_nn.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from torch import nn 4 | from torch.nn import Sequential as Seq, Linear as Lin, Conv2d 5 | from einops import rearrange 6 | 7 | 8 | ############################## 9 | # Basic layers 10 | ############################## 11 | def act_layer(act, inplace=False, neg_slope=0.2, n_prelu=1): 12 | # activation layer 13 | 14 | act = act.lower() 15 | if act == 'relu': 16 | layer = nn.ReLU(inplace) 17 | elif act == 'leakyrelu': 18 | layer = nn.LeakyReLU(neg_slope, inplace) 19 | elif act == 'prelu': 20 | layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope) 21 | elif act == 'gelu': 22 | layer = nn.GELU() 23 | elif act == 'hswish': 24 | layer = nn.Hardswish(inplace) 25 | else: 26 | raise NotImplementedError('activation layer [%s] is not found' % act) 27 | return layer 28 | 29 | 30 | def norm_layer(norm, nc): 31 | # normalization layer 2d 32 | norm = norm.lower() 33 | if norm == 'batch': 34 | layer = nn.BatchNorm2d(nc, affine=True) 35 | elif norm == 'instance': 36 | layer = nn.InstanceNorm2d(nc, affine=False) 37 | else: 38 | raise NotImplementedError('normalization layer [%s] is not found' % norm) 39 | return layer 40 | 41 | 42 | class MLP(Seq): 43 | def __init__(self, channels, act='relu', norm=None, bias=True): 44 | m = [] 45 | for i in range(1, len(channels)): 46 | m.append(Lin(channels[i - 1], channels[i], bias)) 47 | if act is not None and act.lower() != 'none': 48 | m.append(act_layer(act)) 49 | if norm is not None and norm.lower() != 'none': 50 | m.append(norm_layer(norm, channels[-1])) 51 | super(MLP, self).__init__(*m) 52 | 53 | 54 | class BasicConv(Seq): 55 | def __init__(self, channels, act='relu', norm=None, bias=True, drop=0.): 56 | m = [] 57 | for i in range(1, len(channels)): 58 | m.append(Conv2d(channels[i - 1], channels[i], 1, bias=bias, groups=4)) 59 | if norm is not None and norm.lower() != 'none': 60 | m.append(norm_layer(norm, channels[-1])) 61 | if act is not None and act.lower() != 'none': 62 | m.append(act_layer(act)) 63 | if drop > 0: 64 | m.append(nn.Dropout2d(drop)) 65 | 66 | super(BasicConv, self).__init__(*m) 67 | 68 | self.reset_parameters() 69 | 70 | def reset_parameters(self): 71 | for m in self.modules(): 72 | if isinstance(m, nn.Conv2d): 73 | nn.init.kaiming_normal_(m.weight) 74 | if m.bias is not None: 75 | nn.init.zeros_(m.bias) 76 | elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.InstanceNorm2d): 77 | m.weight.data.fill_(1) 78 | m.bias.data.zero_() 79 | 80 | 81 | def batched_index_select(x, idx): 82 | r"""fetches neighbors features from a given neighbor idx 83 | 84 | Args: 85 | x (Tensor): input feature Tensor 86 | :math:`\mathbf{X} \in \mathbb{R}^{B \times C \times N \times 1}`. 
87 | idx (Tensor): edge_idx 88 | :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times l}`. 89 | Returns: 90 | Tensor: output neighbor features 91 | :math:`\mathbf{X} \in \mathbb{R}^{B \times C \times N \times k}`. 92 | """ 93 | batch_size, num_dims, num_vertices_reduced = x.shape[:3] 94 | _, num_vertices, k = idx.shape 95 | idx_base = torch.arange(0, batch_size, device=idx.device).view(-1, 1, 1) * num_vertices_reduced 96 | idx = idx + idx_base 97 | idx = idx.contiguous().view(-1) 98 | 99 | x = x.transpose(2, 1) 100 | feature = x.contiguous().view(batch_size * num_vertices_reduced, -1)[idx, :] 101 | feature = feature.view(batch_size, num_vertices, k, num_dims).permute(0, 3, 1, 2).contiguous() 102 | return feature 103 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # WiGNet: Windowed Vision Graph Neural Network 2 | 3 | PyTorch implementation of the paper "**WiGNet: Windowed Vision Graph Neural Network**", published at WACV 2025. This repository is based on [VisionGNN](https://github.com/jichengyuan/Vision_GNN). 4 | 5 | [ArXiv](https://arxiv.org/abs/2410.00807) 6 | 7 | 
8 | ![teaser](./imgs/teaser_wignet.png) 9 | 
10 | 11 | ## Abstract 12 | In recent years, Graph Neural Networks (GNNs) have demonstrated strong adaptability to various real-world challenges, with architectures such as Vision GNN (ViG) achieving state-of-the-art performance in several computer vision tasks. However, their practical applicability is hindered by the computational complexity of constructing the graph, which scales quadratically with the image size. In this paper, we introduce a novel Windowed vision Graph neural Network (WiGNet) model for efficient image processing. WiGNet explores a different strategy from previous works by partitioning the image into windows and constructing a graph within each window. Therefore, our model uses graph convolutions instead of the typical 2D convolution or self-attention mechanism. WiGNet effectively manages computational and memory complexity for large image sizes. We evaluate our method on the ImageNet-1k benchmark dataset and test the adaptability of WiGNet using the CelebA-HQ dataset as a downstream task with higher-resolution images. In both of these scenarios, our method achieves competitive results compared to previous vision GNNs while keeping memory and computational complexity at bay. WiGNet offers a promising solution toward the deployment of vision GNNs in real-world applications. 13 | 14 | 15 | 
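The core mechanism described above (window partitioning followed by per-window graph construction) can be sketched directly with the helpers defined in this repository, `window_partition` in `src/gcn_lib/torch_local.py` and `DenseDilatedKnnGraph` in `src/gcn_lib/torch_edge.py`. The snippet below is a minimal illustration only, not the full WiGNet block; the batch size, channel count, window size, and `k` are arbitrary example values, and it assumes you run from the `src/` directory.

```python
import torch
from gcn_lib.torch_local import window_partition
from gcn_lib.torch_edge import DenseDilatedKnnGraph

B, C, H, W, ws, k = 2, 48, 32, 32, 8, 9
x = torch.randn(B, C, H, W)  # an intermediate feature map

# Partition into non-overlapping windows: (B, C, H, W) -> (B*num_windows, C, ws, ws)
windows = window_partition(x, window_size=ws)

# Treat the ws*ws positions of each window as graph nodes: (B*num_windows, C, ws*ws, 1)
nodes = windows.reshape(windows.shape[0], C, ws * ws, 1)

# Build a k-NN graph independently inside every window; the cost of each
# graph depends only on the window size, not on the image resolution.
edge_index = DenseDilatedKnnGraph(k=k)(nodes)
print(edge_index.shape)  # torch.Size([2, 32, 64, 9]) = (neighbor/center idx, B*num_windows, nodes, k)
```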
16 | ![complexity](./imgs/complexity.png) 17 | 
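As a back-of-the-envelope check of the trend in the figure above: a global graph needs pairwise distances between all N = H*W nodes, i.e. N^2 entries, whereas the windowed construction needs (ws^2)^2 entries for each of the N/ws^2 windows, i.e. N*ws^2 entries in total. With a fixed window size (8 in the models here, per `OptInit.windows_size`), the windowed cost therefore grows linearly rather than quadratically with image area:

```python
def global_graph_entries(h, w):
    n = h * w
    return n * n                       # quadratic in image area

def windowed_graph_entries(h, w, ws=8):
    n_windows = (h // ws) * (w // ws)
    return n_windows * (ws * ws) ** 2  # linear in image area

for side in (32, 64, 128):
    print(side, global_graph_entries(side, side), windowed_graph_entries(side, side))
# 32  1048576    65536
# 64  16777216   262144
# 128 268435456  1048576
```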
18 | 19 | 20 | ## Usage 21 | 22 | Download our pretrained model from [here](https://drive.google.com/file/d/11bDJaiYxCIwG2OxapIJSkQZys38wDI4S/view?usp=sharing). 23 | 24 | ### Environment 25 | - conda env create -f env_wignet.yaml 26 | - conda activate wignet 27 | 28 | ### ImageNet Classification 29 | 30 | 31 | - Evaluation 32 | ``` 33 | python train.py --model wignn_ti_256_gelu \ 34 | --img-size 256 \ 35 | --knn 9 \ 36 | --use-shift 1 \ 37 | --adapt-knn 1 \ 38 | --data /path/to/imagenet \ 39 | -b 128 \ 40 | --resume /path/to/checkpoint.pth.tar \ 41 | --evaluate 42 | ``` 43 | 44 | - Training WiGNet-Ti on 8 GPUs 45 | ``` 46 | python -m torch.distributed.launch \ 47 | --nproc_per_node=8 train.py \ 48 | --model wignn_ti_256_gelu \ 49 | --img-size 256 \ 50 | --knn 9 \ 51 | --use-shift 1 \ 52 | --adapt-knn 1 \ 53 | --use-reduce-ratios 0 \ 54 | --data /path/to/imagenet \ 55 | --sched cosine \ 56 | --epochs 300 \ 57 | --opt adamw -j 8 \ 58 | --warmup-lr 1e-6 \ 59 | --mixup .8 \ 60 | --cutmix 1.0 \ 61 | --model-ema \ 62 | --model-ema-decay 0.99996 \ 63 | --aa rand-m9-mstd0.5-inc1 \ 64 | --color-jitter 0.4 \ 65 | --warmup-epochs 20 \ 66 | --opt-eps 1e-8 \ 67 | --remode pixel \ 68 | --reprob 0.25 \ 69 | --amp \ 70 | --lr 2e-3 \ 71 | --weight-decay .05 \ 72 | --drop 0 \ 73 | --drop-path .1 \ 74 | -b 128 \ 75 | --output /path/to/save 76 | ``` 77 | 78 | 79 | 80 | ### Complexity Evaluation 81 | 82 | **Memory & MACs** 83 | - WiGNet 84 | ``` 85 | python -m model.wignn 86 | ``` 87 | 88 | - ViG 89 | ``` 90 | python -m model.pyramid_vig 91 | ``` 92 | 93 | - GreedyViG 94 | ``` 95 | python -m model.greedyvig 96 | ``` 97 | 98 | - MobileViG 99 | ``` 100 | python -m model.mobilevig 101 | ``` 102 | 103 | 104 | 105 | 106 | ### Transfer Learning 107 | - WiGNet 108 | ``` 109 | python train_trasnfer_learning.py \ 110 | --model-type wignn_ti_256_gelu \ 111 | --use-shift 1 \ 112 | --adapt-knn 1 \ 113 | --batch-size 64 \ 114 | --checkpoint /path/to/checkpoint.pth.tar \ 115 | --crop-size 512 \ 116 | --dataset CelebA \ 117 | --epochs 30 \ 118 | --freezed 1 \ 119 | --loss cross_entropy \ 120 | --lr 0.001 \ 121 | --lr-scheduler constant \ 122 | --opt adam \ 123 | --root /path/to/save/dataset \ 124 | --save-dir /path/to/save/outputs_tl_high_res/ \ 125 | --seed 1 126 | ``` 127 | 128 | For ViG, include `--num-gpu 8`. 129 | 130 | ## Results 131 | 132 | ### ImageNet-1k 133 | 
134 | ![inet](./imgs/tab_inet.png) 135 | 
136 | 137 | ### CelebA-HQ 138 | 
139 | ![celebahq](./imgs/res_celebahq.png) 140 | 
141 | 142 | 143 | # Citation 144 | If you use our code, please cite 145 | 146 | ``` 147 | @inproceedings{spadaro2024wignet, 148 | title={{W}i{GN}et: {W}indowed {V}ision {G}raph {N}eural {N}etwork}, 149 | author={Spadaro, Gabriele and Grangetto, Marco and Fiandrotti, Attilio and Tartaglione, Enzo and Giraldo, Jhony H}, 150 | booktitle = {Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision (WACV)}, 151 | year={2025} 152 | } 153 | ``` -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Custom ignores 2 | wandb 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | pip-wheel-metadata/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ 139 | 140 | ### Windows template 141 | # Windows thumbnail cache files 142 | Thumbs.db 143 | Thumbs.db:encryptable 144 | ehthumbs.db 145 | ehthumbs_vista.db 146 | 147 | # Dump file 148 | *.stackdump 149 | 150 | # Folder config file 151 | [Dd]esktop.ini 152 | 153 | # Recycle Bin used on file shares 154 | $RECYCLE.BIN/ 155 | 156 | # Windows Installer files 157 | *.cab 158 | *.msi 159 | *.msix 160 | *.msm 161 | *.msp 162 | 163 | # Windows shortcuts 164 | *.lnk 165 | 166 | ### JetBrains template 167 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider 168 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 169 | 170 | # User-specific stuff 171 | .idea/**/workspace.xml 172 | .idea/**/tasks.xml 173 | .idea/**/usage.statistics.xml 174 | .idea/**/dictionaries 175 | .idea/**/shelf 176 | 177 | # Generated files 178 | .idea/**/contentModel.xml 179 | 180 | # Sensitive or high-churn files 181 | .idea/**/dataSources/ 182 | .idea/**/dataSources.ids 183 | .idea/**/dataSources.local.xml 184 | .idea/**/sqlDataSources.xml 185 | .idea/**/dynamic.xml 186 | .idea/**/uiDesigner.xml 187 | .idea/**/dbnavigator.xml 188 | 189 | # Gradle 190 | .idea/**/gradle.xml 191 | .idea/**/libraries 192 | 193 | # Gradle and Maven with auto-import 194 | # When using Gradle or Maven with auto-import, you should exclude module files, 195 | # since they will be recreated, and may cause churn. Uncomment if using 196 | # auto-import. 
197 | # .idea/artifacts 198 | # .idea/compiler.xml 199 | # .idea/jarRepositories.xml 200 | # .idea/modules.xml 201 | # .idea/*.iml 202 | # .idea/modules 203 | # *.iml 204 | # *.ipr 205 | 206 | # CMake 207 | cmake-build-*/ 208 | 209 | # Mongo Explorer plugin 210 | .idea/**/mongoSettings.xml 211 | 212 | # File-based project format 213 | *.iws 214 | 215 | # IntelliJ 216 | out/ 217 | 218 | # mpeltonen/sbt-idea plugin 219 | .idea_modules/ 220 | 221 | # JIRA plugin 222 | atlassian-ide-plugin.xml 223 | 224 | # Cursive Clojure plugin 225 | .idea/replstate.xml 226 | 227 | # Crashlytics plugin (for Android Studio and IntelliJ) 228 | com_crashlytics_export_strings.xml 229 | crashlytics.properties 230 | crashlytics-build.properties 231 | fabric.properties 232 | 233 | # Editor-based Rest Client 234 | .idea/httpRequests 235 | 236 | # Android studio 3.1+ serialized cache file 237 | .idea/caches/build_file_checksums.ser 238 | 239 | ### VisualStudioCode template 240 | .vscode/* 241 | !.vscode/settings.json 242 | !.vscode/tasks.json 243 | !.vscode/launch.json 244 | !.vscode/extensions.json 245 | *.code-workspace 246 | 247 | # Local History for Visual Studio Code 248 | .history/ 249 | 250 | .idea/ 251 | 252 | -------------------------------------------------------------------------------- /src/data/myloader.py: -------------------------------------------------------------------------------- 1 | 2 | """ Loader Factory, Fast Collate, CUDA Prefetcher 3 | 4 | Prefetcher and Fast Collate inspired by NVIDIA APEX example at 5 | https://github.com/NVIDIA/apex/commit/d5e2bb4bdeedd27b1dfaf5bb2b24d6c000dee9be#diff-cf86c282ff7fba81fad27a559379d5bf 6 | 7 | Hacked together by / Copyright 2020 Ross Wightman 8 | """ 9 | 10 | import torch.utils.data 11 | import torch.distributed as dist 12 | import numpy as np 13 | 14 | from timm.data.transforms_factory import create_transform 15 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 16 | from timm.data.distributed_sampler import OrderedDistributedSampler 17 | from timm.data.random_erasing import RandomErasing 18 | from timm.data.mixup import FastCollateMixup 19 | from timm.data.loader import fast_collate, PrefetchLoader, MultiEpochsDataLoader 20 | 21 | from .rasampler import RASampler 22 | 23 | 24 | def is_dist_avail_and_initialized(): 25 | if not dist.is_available(): 26 | return False 27 | if not dist.is_initialized(): 28 | return False 29 | return True 30 | 31 | 32 | def get_world_size(): 33 | if not is_dist_avail_and_initialized(): 34 | return 1 35 | return dist.get_world_size() 36 | 37 | 38 | def get_rank(): 39 | if not is_dist_avail_and_initialized(): 40 | return 0 41 | return dist.get_rank() 42 | 43 | 44 | def create_loader( 45 | dataset, 46 | input_size, 47 | batch_size, 48 | is_training=False, 49 | use_prefetcher=True, 50 | no_aug=False, 51 | re_prob=0., 52 | re_mode='const', 53 | re_count=1, 54 | re_split=False, 55 | scale=None, 56 | ratio=None, 57 | hflip=0.5, 58 | vflip=0., 59 | color_jitter=0.4, 60 | auto_augment=None, 61 | num_aug_splits=0, 62 | interpolation='bilinear', 63 | mean=IMAGENET_DEFAULT_MEAN, 64 | std=IMAGENET_DEFAULT_STD, 65 | num_workers=1, 66 | distributed=False, 67 | crop_pct=None, 68 | collate_fn=None, 69 | pin_memory=False, 70 | fp16=False, 71 | tf_preprocessing=False, 72 | use_multi_epochs_loader=False, 73 | repeated_aug=False 74 | ): 75 | re_num_splits = 0 76 | if re_split: 77 | # apply RE to second half of batch if no aug split otherwise line up with aug split 78 | re_num_splits = num_aug_splits or 2 79 | 
dataset.transform = create_transform( 80 | input_size, 81 | is_training=is_training, 82 | use_prefetcher=use_prefetcher, 83 | no_aug=no_aug, 84 | scale=scale, 85 | ratio=ratio, 86 | hflip=hflip, 87 | vflip=vflip, 88 | color_jitter=color_jitter, 89 | auto_augment=auto_augment, 90 | interpolation=interpolation, 91 | mean=mean, 92 | std=std, 93 | crop_pct=crop_pct, 94 | tf_preprocessing=tf_preprocessing, 95 | re_prob=re_prob, 96 | re_mode=re_mode, 97 | re_count=re_count, 98 | re_num_splits=re_num_splits, 99 | separate=num_aug_splits > 0, 100 | ) 101 | 102 | sampler = None 103 | if distributed: 104 | if is_training: 105 | if repeated_aug: 106 | print('using repeated_aug') 107 | num_tasks = get_world_size() 108 | global_rank = get_rank() 109 | sampler = RASampler( 110 | dataset, num_replicas=num_tasks, rank=global_rank, shuffle=True 111 | ) 112 | else: 113 | sampler = torch.utils.data.distributed.DistributedSampler(dataset) 114 | else: 115 | # This will add extra duplicate entries to result in equal num 116 | # of samples per-process, will slightly alter validation results 117 | sampler = OrderedDistributedSampler(dataset) 118 | else: 119 | if is_training and repeated_aug: 120 | print('using repeated_aug') 121 | num_tasks = get_world_size() 122 | global_rank = get_rank() 123 | sampler = RASampler( 124 | dataset, num_replicas=num_tasks, rank=global_rank, shuffle=True 125 | ) 126 | 127 | if collate_fn is None: 128 | collate_fn = fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate 129 | 130 | loader_class = torch.utils.data.DataLoader 131 | 132 | if use_multi_epochs_loader: 133 | loader_class = MultiEpochsDataLoader 134 | 135 | loader = loader_class( 136 | dataset, 137 | batch_size=batch_size, 138 | shuffle=sampler is None and is_training, 139 | num_workers=num_workers, 140 | sampler=sampler, 141 | collate_fn=collate_fn, 142 | pin_memory=pin_memory, 143 | drop_last=is_training, 144 | ) 145 | if use_prefetcher: 146 | prefetch_re_prob = re_prob if is_training and not no_aug else 0. 147 | loader = PrefetchLoader( 148 | loader, 149 | mean=mean, 150 | std=std, 151 | fp16=fp16, 152 | re_prob=prefetch_re_prob, 153 | re_mode=re_mode, 154 | re_count=re_count, 155 | re_num_splits=re_num_splits 156 | ) 157 | 158 | return loader -------------------------------------------------------------------------------- /src/gcn_lib/torch_edge.py: -------------------------------------------------------------------------------- 1 | 2 | import math 3 | import torch 4 | from torch import nn 5 | import torch.nn.functional as F 6 | import sys 7 | 8 | def pairwise_distance(x): 9 | """ 10 | Compute pairwise distance of a point cloud. 11 | Args: 12 | x: tensor (batch_size, num_points, num_dims) 13 | Returns: 14 | pairwise distance: (batch_size, num_points, num_points) 15 | """ 16 | with torch.no_grad(): 17 | x_inner = -2*torch.matmul(x, x.transpose(2, 1)) 18 | x_square = torch.sum(torch.mul(x, x), dim=-1, keepdim=True) 19 | return x_square + x_inner + x_square.transpose(2, 1) 20 | 21 | 22 | def part_pairwise_distance(x, start_idx=0, end_idx=1): 23 | """ 24 | Compute pairwise distance of a point cloud. 
25 | Args: 26 | x: tensor (batch_size, num_points, num_dims) 27 | Returns: 28 | pairwise distance: (batch_size, num_points, num_points) 29 | """ 30 | with torch.no_grad(): 31 | x_part = x[:, start_idx:end_idx] 32 | x_square_part = torch.sum(torch.mul(x_part, x_part), dim=-1, keepdim=True) 33 | x_inner = -2*torch.matmul(x_part, x.transpose(2, 1)) 34 | x_square = torch.sum(torch.mul(x, x), dim=-1, keepdim=True) 35 | return x_square_part + x_inner + x_square.transpose(2, 1) 36 | 37 | 38 | def xy_pairwise_distance(x, y): 39 | """ 40 | Compute pairwise distance of a point cloud. 41 | Args: 42 | x: tensor (batch_size, num_points, num_dims) 43 | Returns: 44 | pairwise distance: (batch_size, num_points, num_points) 45 | """ 46 | with torch.no_grad(): 47 | xy_inner = -2*torch.matmul(x, y.transpose(2, 1)) 48 | x_square = torch.sum(torch.mul(x, x), dim=-1, keepdim=True) 49 | y_square = torch.sum(torch.mul(y, y), dim=-1, keepdim=True) 50 | return x_square + xy_inner + y_square.transpose(2, 1) 51 | 52 | 53 | def dense_knn_matrix(x, k=16, relative_pos=None): 54 | """Get KNN based on the pairwise distance. 55 | Args: 56 | x: (batch_size, num_dims, num_points, 1) 57 | k: int 58 | Returns: 59 | nearest neighbors: (batch_size, num_points, k) (batch_size, num_points, k) 60 | """ 61 | with torch.no_grad(): 62 | x = x.transpose(2, 1).squeeze(-1) 63 | batch_size, n_points, n_dims = x.shape 64 | ### memory efficient implementation ### 65 | n_part = 10000 66 | if n_points > n_part: 67 | nn_idx_list = [] 68 | groups = math.ceil(n_points / n_part) 69 | for i in range(groups): 70 | start_idx = n_part * i 71 | end_idx = min(n_points, n_part * (i + 1)) 72 | dist = part_pairwise_distance(x.detach(), start_idx, end_idx) 73 | if relative_pos is not None: 74 | dist += relative_pos[:, start_idx:end_idx] 75 | _, nn_idx_part = torch.topk(-dist, k=k) 76 | nn_idx_list += [nn_idx_part] 77 | nn_idx = torch.cat(nn_idx_list, dim=1) 78 | else: 79 | dist = pairwise_distance(x.detach()) 80 | 81 | if relative_pos is not None: 82 | 83 | dist += relative_pos 84 | 85 | # nn_idx = torch.randint(0, n_points-1, (batch_size,n_points,k)).to(dist.device) 86 | _, nn_idx = torch.topk(-dist, k=k) # b, n, k 87 | ###### 88 | center_idx = torch.arange(0, n_points, device=x.device).repeat(batch_size, k, 1).transpose(2, 1) 89 | return torch.stack((nn_idx, center_idx), dim=0) 90 | 91 | 92 | def xy_dense_knn_matrix(x, y, k=16, relative_pos=None): 93 | """Get KNN based on the pairwise distance. 
94 | Args: 95 | x: (batch_size, num_dims, num_points, 1) 96 | k: int 97 | Returns: 98 | nearest neighbors: (batch_size, num_points, k) (batch_size, num_points, k) 99 | """ 100 | with torch.no_grad(): 101 | x = x.transpose(2, 1).squeeze(-1) 102 | y = y.transpose(2, 1).squeeze(-1) 103 | batch_size, n_points, n_dims = x.shape 104 | dist = xy_pairwise_distance(x.detach(), y.detach()) 105 | if relative_pos is not None: 106 | dist += relative_pos 107 | 108 | # nn_idx = torch.randint(1, 195, (batch_size,n_points,k)).to(dist.device) 109 | _, nn_idx = torch.topk(-dist, k=k) 110 | 111 | # print('Campling values') 112 | # nn_idx = torch.clamp(nn_idx, min=0, max=n_points-1) 113 | 114 | center_idx = torch.arange(0, n_points, device=x.device).repeat(batch_size, k, 1).transpose(2, 1) 115 | return torch.stack((nn_idx, center_idx), dim=0) 116 | 117 | 118 | class DenseDilated(nn.Module): 119 | """ 120 | Find dilated neighbor from neighbor list 121 | 122 | edge_index: (2, batch_size, num_points, k) 123 | """ 124 | def __init__(self, k=9, dilation=1, stochastic=False, epsilon=0.0): 125 | super(DenseDilated, self).__init__() 126 | self.dilation = dilation 127 | self.stochastic = stochastic 128 | self.epsilon = epsilon 129 | self.k = k 130 | 131 | def forward(self, edge_index): 132 | if self.stochastic: 133 | if torch.rand(1) < self.epsilon and self.training: 134 | num = self.k * self.dilation 135 | randnum = torch.randperm(num)[:self.k] 136 | edge_index = edge_index[:, :, :, randnum] 137 | else: 138 | edge_index = edge_index[:, :, :, ::self.dilation] 139 | else: 140 | edge_index = edge_index[:, :, :, ::self.dilation] 141 | return edge_index 142 | 143 | 144 | class DenseDilatedKnnGraph(nn.Module): 145 | """ 146 | Find the neighbors' indices based on dilated knn 147 | """ 148 | def __init__(self, k=9, dilation=1, stochastic=False, epsilon=0.0): 149 | super(DenseDilatedKnnGraph, self).__init__() 150 | self.dilation = dilation 151 | self.stochastic = stochastic 152 | self.epsilon = epsilon 153 | self.k = k 154 | self._dilated = DenseDilated(k, dilation, stochastic, epsilon) 155 | 156 | def forward(self, x, y=None, relative_pos=None): 157 | if y is not None: 158 | #### normalize 159 | x = F.normalize(x, p=2.0, dim=1) 160 | y = F.normalize(y, p=2.0, dim=1) 161 | #### 162 | edge_index = xy_dense_knn_matrix(x, y, self.k * self.dilation, relative_pos) 163 | else: 164 | #### normalize 165 | x = F.normalize(x, p=2.0, dim=1) 166 | #### 167 | edge_index = dense_knn_matrix(x, self.k * self.dilation, relative_pos) 168 | return self._dilated(edge_index) 169 | -------------------------------------------------------------------------------- /src/opt_transfer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def int2bool(i): 4 | i = int(i) 5 | assert i == 0 or i == 1 6 | return i == 1 7 | 8 | 9 | def get_args_parser(add_help=True): 10 | 11 | parser = argparse.ArgumentParser(description="Transfer Learning Training", add_help=add_help) 12 | 13 | parser.add_argument("--model-type", type=str, help="model name") 14 | parser.add_argument( 15 | "-b", "--batch-size", default=32, type=int, help="images per gpu, the total batch size is $NGPU x batch_size" 16 | ) 17 | parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run") 18 | parser.add_argument( 19 | "-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)" 20 | ) 21 | parser.add_argument("--opt", default="sgd", type=str, 
help="optimizer") 22 | parser.add_argument("--lr", default=0.1, type=float, help="initial learning rate") 23 | parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum") 24 | parser.add_argument( 25 | "--wd", 26 | "--weight-decay", 27 | default=1e-4, 28 | type=float, 29 | metavar="W", 30 | help="weight decay (default: 1e-4)", 31 | dest="weight_decay", 32 | ) 33 | parser.add_argument( 34 | "--norm-weight-decay", 35 | default=None, 36 | type=float, 37 | help="weight decay for Normalization layers (default: None, same value as --wd)", 38 | ) 39 | parser.add_argument( 40 | "--bias-weight-decay", 41 | default=None, 42 | type=float, 43 | help="weight decay for bias parameters of all layers (default: None, same value as --wd)", 44 | ) 45 | parser.add_argument( 46 | "--transformer-embedding-decay", 47 | default=None, 48 | type=float, 49 | help="weight decay for embedding parameters for vision transformer models (default: None, same value as --wd)", 50 | ) 51 | parser.add_argument( 52 | "--label-smoothing", default=0.0, type=float, help="label smoothing (default: 0.0)", dest="label_smoothing" 53 | ) 54 | parser.add_argument("--mixup-alpha", default=0.0, type=float, help="mixup alpha (default: 0.0)") 55 | parser.add_argument("--cutmix-alpha", default=0.0, type=float, help="cutmix alpha (default: 0.0)") 56 | parser.add_argument("--lr-scheduler", default="steplr", type=str, help="the lr scheduler (default: steplr)") 57 | parser.add_argument("--lr-warmup-epochs", default=0, type=int, help="the number of epochs to warmup (default: 0)") 58 | parser.add_argument( 59 | "--lr-warmup-method", default="constant", type=str, help="the warmup method (default: constant)" 60 | ) 61 | parser.add_argument("--lr-warmup-decay", default=0.01, type=float, help="the decay for lr") 62 | parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs") 63 | parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma") 64 | parser.add_argument("--lr-min", default=0.0, type=float, help="minimum lr of lr schedule (default: 0.0)") 65 | parser.add_argument("--print-freq", default=10, type=int, help="print frequency") 66 | parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs") 67 | parser.add_argument("--resume", default="", type=str, help="path of checkpoint") 68 | parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch") 69 | parser.add_argument( 70 | "--cache-dataset", 71 | dest="cache_dataset", 72 | help="Cache the datasets for quicker initialization. 
It also serializes the transforms", 73 | action="store_true", 74 | ) 75 | parser.add_argument( 76 | "--sync-bn", 77 | dest="sync_bn", 78 | help="Use sync batch norm", 79 | action="store_true", 80 | ) 81 | parser.add_argument( 82 | "--test-only", 83 | dest="test_only", 84 | help="Only test the model", 85 | action="store_true", 86 | ) 87 | parser.add_argument("--auto-augment", default=None, type=str, help="auto augment policy (default: None)") 88 | parser.add_argument("--ra-magnitude", default=9, type=int, help="magnitude of auto augment policy") 89 | parser.add_argument("--augmix-severity", default=3, type=int, help="severity of augmix policy") 90 | parser.add_argument("--random-erase", default=0.0, type=float, help="random erasing probability (default: 0.0)") 91 | 92 | # Mixed precision training parameters 93 | parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training") 94 | 95 | # distributed training parameters 96 | parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes") 97 | parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training") 98 | parser.add_argument( 99 | "--model-ema", action="store_true", help="enable tracking Exponential Moving Average of model parameters" 100 | ) 101 | parser.add_argument( 102 | "--model-ema-steps", 103 | type=int, 104 | default=32, 105 | help="the number of iterations that controls how often to update the EMA model (default: 32)", 106 | ) 107 | parser.add_argument( 108 | "--model-ema-decay", 109 | type=float, 110 | default=0.99998, 111 | help="decay factor for Exponential Moving Average of model parameters (default: 0.99998)", 112 | ) 113 | parser.add_argument( 114 | "--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only." 
115 | ) 116 | parser.add_argument( 117 | "--interpolation", default="bilinear", type=str, help="the interpolation method (default: bilinear)" 118 | ) 119 | 120 | parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)") 121 | parser.add_argument("--ra-sampler", action="store_true", help="whether to use Repeated Augmentation in training") 122 | parser.add_argument( 123 | "--ra-reps", default=3, type=int, help="number of repetitions for Repeated Augmentation (default: 3)" 124 | ) 125 | parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load") 126 | parser.add_argument("--backend", default="PIL", type=str.lower, help="PIL or tensor - case insensitive") 127 | parser.add_argument("--use-v2", action="store_true", help="Use V2 transforms") 128 | 129 | 130 | parser.add_argument("--seed", default=None, type=int) 131 | parser.add_argument("--save-dir", default=None, type=str) 132 | 133 | parser.add_argument("--dataset", type=str, 134 | help="Source dataset.") 135 | 136 | parser.add_argument("--checkpoint", default="", type=str, help="path of checkpoint") 137 | parser.add_argument("--root", type=str, default="/scratch/dataset", 138 | help="Dataset root folder.") 139 | parser.add_argument("--loss", default="cross_entropy", type=str, choices=["cross_entropy", "nll"]) 140 | 141 | 142 | parser.add_argument('--use-shift', type=int2bool, default=0) # 1 == True 143 | parser.add_argument('--adapt-knn', type=int2bool, default=0) # 1 == True 144 | parser.add_argument('--freezed', type=int2bool, default=1) # 1 == True 145 | parser.add_argument("--crop-size", default=None, type=int) 146 | 147 | 148 | parser.add_argument("--num-gpu", default=1, type=int) 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | return parser -------------------------------------------------------------------------------- /env_wignet.yaml: -------------------------------------------------------------------------------- 1 | name: wignet 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - _openmp_mutex=5.1=1_gnu 9 | - asttokens=2.0.5=pyhd3eb1b0_0 10 | - autopep8=1.6.0=pyhd3eb1b0_1 11 | - backcall=0.2.0=pyhd3eb1b0_0 12 | - beautifulsoup4=4.11.1=py310h06a4308_0 13 | - blas=1.0=mkl 14 | - brotlipy=0.7.0=py310h7f8727e_1002 15 | - bzip2=1.0.8=h7b6447c_0 16 | - c-ares=1.18.1=h7f8727e_0 17 | - ca-certificates=2023.08.22=h06a4308_0 18 | - certifi=2023.11.17=py310h06a4308_0 19 | - cffi=1.15.1=py310h74dc2b5_0 20 | - chardet=4.0.0=py310h06a4308_1003 21 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 22 | - cmake=3.22.1=h1fce559_0 23 | - conda=22.11.1=py310h06a4308_4 24 | - conda-build=3.23.3=py310h06a4308_0 25 | - conda-package-handling=1.9.0=py310h5eee18b_1 26 | - cryptography=38.0.1=py310h9ce1e76_0 27 | - cuda=11.6.1=0 28 | - cuda-cccl=11.6.55=hf6102b2_0 29 | - cuda-command-line-tools=11.6.2=0 30 | - cuda-compiler=11.6.2=0 31 | - cuda-cudart=11.6.55=he381448_0 32 | - cuda-cudart-dev=11.6.55=h42ad0f4_0 33 | - cuda-cuobjdump=11.6.124=h2eeebcb_0 34 | - cuda-cupti=11.6.124=h86345e5_0 35 | - cuda-cuxxfilt=11.6.124=hecbf4f6_0 36 | - cuda-driver-dev=11.6.55=0 37 | - cuda-gdb=12.0.90=0 38 | - cuda-libraries=11.6.1=0 39 | - cuda-libraries-dev=11.6.1=0 40 | - cuda-memcheck=11.8.86=0 41 | - cuda-nsight=12.0.78=0 42 | - cuda-nsight-compute=12.0.0=0 43 | - cuda-nvcc=11.6.124=hbba6d2d_0 44 | - cuda-nvdisasm=12.0.76=0 45 | - cuda-nvml-dev=11.6.55=haa9ef22_0 46 | - cuda-nvprof=12.0.90=0 47 | - 
cuda-nvprune=11.6.124=he22ec0a_0 48 | - cuda-nvrtc=11.6.124=h020bade_0 49 | - cuda-nvrtc-dev=11.6.124=h249d397_0 50 | - cuda-nvtx=11.6.124=h0630a44_0 51 | - cuda-nvvp=12.0.90=0 52 | - cuda-runtime=11.6.1=0 53 | - cuda-samples=11.6.101=h8efea70_0 54 | - cuda-sanitizer-api=12.0.90=0 55 | - cuda-toolkit=11.6.1=0 56 | - cuda-tools=11.6.1=0 57 | - cuda-visual-tools=11.6.1=0 58 | - decorator=5.1.1=pyhd3eb1b0_0 59 | - executing=0.8.3=pyhd3eb1b0_0 60 | - expat=2.4.9=h6a678d5_0 61 | - ffmpeg=4.3=hf484d3e_0 62 | - filelock=3.6.0=pyhd3eb1b0_0 63 | - flit-core=3.6.0=pyhd3eb1b0_0 64 | - freetype=2.12.1=h4a9f257_0 65 | - gds-tools=1.5.0.59=0 66 | - giflib=5.2.1=h7b6447c_0 67 | - glob2=0.7=pyhd3eb1b0_0 68 | - gmp=6.2.1=h295c915_3 69 | - gnutls=3.6.15=he1e5248_0 70 | - icu=58.2=he6710b0_3 71 | - idna=3.4=py310h06a4308_0 72 | - intel-openmp=2021.4.0=h06a4308_3561 73 | - ipython=8.7.0=py310h06a4308_0 74 | - jedi=0.18.1=py310h06a4308_1 75 | - jinja2=2.11.3=pyhd3eb1b0_0 76 | - jpeg=9e=h7f8727e_0 77 | - krb5=1.19.2=hac12032_0 78 | - lame=3.100=h7b6447c_0 79 | - lcms2=2.12=h3be6417_0 80 | - ld_impl_linux-64=2.38=h1181459_1 81 | - lerc=3.0=h295c915_0 82 | - libarchive=3.6.1=hab531cd_0 83 | - libcublas=11.9.2.110=h5e84587_0 84 | - libcublas-dev=11.9.2.110=h5c901ab_0 85 | - libcufft=10.7.1.112=hf425ae0_0 86 | - libcufft-dev=10.7.1.112=ha5ce4c0_0 87 | - libcufile=1.5.0.59=0 88 | - libcufile-dev=1.5.0.59=0 89 | - libcurand=10.3.1.50=0 90 | - libcurand-dev=10.3.1.50=0 91 | - libcurl=7.86.0=h91b91d3_0 92 | - libcusolver=11.3.4.124=h33c3c4e_0 93 | - libcusparse=11.7.2.124=h7538f96_0 94 | - libcusparse-dev=11.7.2.124=hbbe9722_0 95 | - libdeflate=1.8=h7f8727e_5 96 | - libedit=3.1.20221030=h5eee18b_0 97 | - libev=4.33=h7f8727e_1 98 | - libffi=3.3=he6710b0_2 99 | - libgcc-ng=11.2.0=h1234567_1 100 | - libgomp=11.2.0=h1234567_1 101 | - libiconv=1.16=h7f8727e_2 102 | - libidn2=2.3.2=h7f8727e_0 103 | - liblief=0.12.3=h6a678d5_0 104 | - libnghttp2=1.46.0=hce63b2e_0 105 | - libnpp=11.6.3.124=hd2722f0_0 106 | - libnpp-dev=11.6.3.124=h3c42840_0 107 | - libnvjpeg=11.6.2.124=hd473ad6_0 108 | - libnvjpeg-dev=11.6.2.124=hb5906b9_0 109 | - libpng=1.6.37=hbc83047_0 110 | - libssh2=1.10.0=h8f2d780_0 111 | - libstdcxx-ng=11.2.0=h1234567_1 112 | - libtasn1=4.16.0=h27cfd23_0 113 | - libtiff=4.4.0=hecacb30_2 114 | - libunistring=0.9.10=h27cfd23_0 115 | - libuuid=1.41.5=h5eee18b_0 116 | - libuv=1.40.0=h7b6447c_0 117 | - libwebp=1.2.4=h11a3e52_0 118 | - libwebp-base=1.2.4=h5eee18b_0 119 | - libxml2=2.9.14=h74e7548_0 120 | - lz4-c=1.9.4=h6a678d5_0 121 | - markupsafe=2.0.1=py310h7f8727e_0 122 | - matplotlib-inline=0.1.6=py310h06a4308_0 123 | - mkl=2021.4.0=h06a4308_640 124 | - mkl-service=2.4.0=py310h7f8727e_0 125 | - mkl_fft=1.3.1=py310hd6ae3a3_0 126 | - mkl_random=1.2.2=py310h00e6091_0 127 | - ncurses=6.3=h5eee18b_3 128 | - nettle=3.7.3=hbbd107a_1 129 | - nsight-compute=2022.4.0.15=0 130 | - numpy=1.22.3=py310hfa59a62_0 131 | - numpy-base=1.22.3=py310h9585f30_0 132 | - openh264=2.1.1=h4ff587b_0 133 | - openssl=1.1.1w=h7f8727e_0 134 | - parso=0.8.3=pyhd3eb1b0_0 135 | - patch=2.7.6=h7b6447c_1001 136 | - patchelf=0.15.0=h6a678d5_0 137 | - pexpect=4.8.0=pyhd3eb1b0_3 138 | - pickleshare=0.7.5=pyhd3eb1b0_1003 139 | - pillow=9.3.0=py310hace64e9_0 140 | - pip=22.3.1=py310h06a4308_0 141 | - pkginfo=1.8.3=py310h06a4308_0 142 | - pluggy=1.0.0=py310h06a4308_1 143 | - prompt-toolkit=3.0.20=pyhd3eb1b0_0 144 | - psutil=5.9.0=py310h5eee18b_0 145 | - ptyprocess=0.7.0=pyhd3eb1b0_2 146 | - pure_eval=0.2.2=pyhd3eb1b0_0 147 | - py-lief=0.12.3=py310h6a678d5_0 148 | 
- pycodestyle=2.11.1=py310h06a4308_0 149 | - pycosat=0.6.4=py310h5eee18b_0 150 | - pycparser=2.21=pyhd3eb1b0_0 151 | - pyopenssl=22.0.0=pyhd3eb1b0_0 152 | - pysocks=1.7.1=py310h06a4308_0 153 | - python=3.10.8=haa1d7c7_0 154 | - python-libarchive-c=2.9=pyhd3eb1b0_1 155 | - pytorch=1.13.1=py3.10_cuda11.6_cudnn8.3.2_0 156 | - pytorch-cuda=11.6=h867d48c_1 157 | - pytorch-mutex=1.0=cuda 158 | - pyyaml=6.0=py310h5eee18b_1 159 | - readline=8.2=h5eee18b_0 160 | - rhash=1.4.1=h3c74f83_1 161 | - ripgrep=13.0.0=hbdeaff8_0 162 | - ruamel.yaml=0.17.21=py310h5eee18b_0 163 | - ruamel.yaml.clib=0.2.6=py310h5eee18b_1 164 | - six=1.16.0=pyhd3eb1b0_1 165 | - soupsieve=2.3.2.post1=py310h06a4308_0 166 | - sqlite=3.40.0=h5082296_0 167 | - stack_data=0.2.0=pyhd3eb1b0_0 168 | - tk=8.6.12=h1ccaba5_0 169 | - toml=0.10.2=pyhd3eb1b0_0 170 | - toolz=0.12.0=py310h06a4308_0 171 | - torchtext=0.14.1=py310 172 | - torchvision=0.14.1=py310_cu116 173 | - traitlets=5.7.1=py310h06a4308_0 174 | - typing_extensions=4.4.0=py310h06a4308_0 175 | - urllib3=1.26.13=py310h06a4308_0 176 | - wcwidth=0.2.5=pyhd3eb1b0_0 177 | - wheel=0.37.1=pyhd3eb1b0_0 178 | - xz=5.2.8=h5eee18b_0 179 | - yaml=0.2.5=h7b6447c_0 180 | - zlib=1.2.13=h5eee18b_0 181 | - zstd=1.5.2=ha4553b6_0 182 | - pip: 183 | - addict==2.4.0 184 | - aliyun-python-sdk-core==2.14.0 185 | - aliyun-python-sdk-kms==2.16.2 186 | - appdirs==1.4.4 187 | - astunparse==1.6.3 188 | - attrs==22.1.0 189 | - click==8.1.3 190 | - colorama==0.4.6 191 | - comm==0.1.3 192 | - compressai==1.2.4 193 | - contourpy==1.0.7 194 | - crcmod==1.7 195 | - cycler==0.11.0 196 | - debugpy==1.6.6 197 | - dnspython==2.2.1 198 | - docker-pycreds==0.4.0 199 | - einops==0.6.1 200 | - exceptiongroup==1.0.4 201 | - expecttest==0.1.4 202 | - fonttools==4.39.3 203 | - future==0.18.2 204 | - gdown==4.7.1 205 | - gitdb==4.0.10 206 | - gitpython==3.1.31 207 | - huggingface-hub==0.13.3 208 | - hypothesis==6.61.0 209 | - importlib-metadata==7.0.1 210 | - iniconfig==2.0.0 211 | - ipykernel==6.22.0 212 | - ipywidgets==8.0.6 213 | - jmespath==0.10.0 214 | - joblib==1.2.0 215 | - jupyter-client==8.1.0 216 | - jupyter-core==5.3.0 217 | - jupyterlab-widgets==3.0.7 218 | - kiwisolver==1.4.4 219 | - kmedoids==0.5.0 220 | - llvmlite==0.41.1 221 | - markdown==3.5.2 222 | - markdown-it-py==3.0.0 223 | - matplotlib==3.7.1 224 | - mdurl==0.1.2 225 | - mmcv==2.1.0 226 | - mmdet==3.3.0 227 | - mmengine==0.10.3 228 | - model-index==0.1.11 229 | - mpmath==1.2.1 230 | - nest-asyncio==1.5.6 231 | - ninja==1.11.1 232 | - numba==0.58.1 233 | - opencv-python==4.8.0.74 234 | - opendatalab==0.0.10 235 | - openmim==0.3.9 236 | - openxlab==0.0.34 237 | - ordered-set==4.1.0 238 | - oss2==2.17.0 239 | - packaging==23.0 240 | - pandas==2.0.3 241 | - pathtools==0.1.2 242 | - platformdirs==4.2.0 243 | - protobuf==4.22.1 244 | - py-cpuinfo==9.0.0 245 | - pycocotools==2.0.7 246 | - pycryptodome==3.20.0 247 | - pygments==2.17.2 248 | - pyparsing==3.0.9 249 | - pytest==7.2.2 250 | - pytest-gc==0.0.1 251 | - python-dateutil==2.8.2 252 | - python-etcd==0.4.5 253 | - python-papi==5.5.1.5 254 | - pytorch-msssim==0.2.1 255 | - pytz==2023.4 256 | - pyzmq==25.0.2 257 | - requests==2.28.2 258 | - rich==13.4.2 259 | - scikit-learn==1.2.2 260 | - scipy==1.10.1 261 | - seaborn==0.12.2 262 | - sentry-sdk==1.19.0 263 | - setproctitle==1.3.2 264 | - setuptools==60.2.0 265 | - shapely==2.0.3 266 | - smmap==5.0.0 267 | - sortedcontainers==2.4.0 268 | - sympy==1.11.1 269 | - tabulate==0.9.0 270 | - termcolor==2.4.0 271 | - terminaltables==3.1.10 272 | - 
threadpoolctl==3.1.0 273 | - timm==0.6.13 274 | - tomli==2.0.1 275 | - torch-geometric==2.3.1 276 | - torchac==0.9.3 277 | - torchelastic==0.2.2 278 | - torchprofile==0.0.4 279 | - tornado==6.2 280 | - tqdm==4.65.2 281 | - types-dataclasses==0.6.6 282 | - tzdata==2023.3 283 | - ultralytics==8.0.184 284 | - wandb==0.14.0 285 | - widgetsnbextension==4.0.7 286 | - yapf==0.40.2 287 | - zipp==3.17.0 288 | -------------------------------------------------------------------------------- /src/model/transfer_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from timm.models import create_model, resume_checkpoint 3 | from timm.models.helpers import clean_state_dict 4 | import torch.nn as nn 5 | from torch.nn import Sequential as Seq 6 | from gcn_lib import act_layer 7 | import numpy as np 8 | from model import pyramid_vig 9 | from model import wignn 10 | from model import wignn_256 11 | from model import greedyvig 12 | from model import mobilevig 13 | import torchvision 14 | 15 | from torchprofile import profile_macs 16 | 17 | from collections import OrderedDict 18 | import sys 19 | 20 | 21 | def remove_pos(state_dict): 22 | cleaned_state_dict = OrderedDict() 23 | for k, v in state_dict.items(): 24 | if 'pos_embed' not in k and 'attn_mask' not in k and 'adj_mask' not in k: 25 | cleaned_state_dict[k] = v 26 | return cleaned_state_dict 27 | 28 | def remove_relative_pos(state_dict): 29 | cleaned_state_dict = OrderedDict() 30 | for k, v in state_dict.items(): 31 | if 'relative_pos' not in k and 'pos_embed' not in k: 32 | cleaned_state_dict[k] = v 33 | return cleaned_state_dict 34 | 35 | def get_model(model_type, use_shift = False, adapt_knn = False, checkpoint = None, pretrained = True, freezed = True, dataset = 'PET', crop_size = None): 36 | 37 | if dataset == 'CelebA': 38 | n_classes = 307 39 | else: 40 | raise NotImplementedError(f'Dataset: {dataset} not yet implemented') 41 | 42 | 43 | pretrained_creation = False 44 | 45 | if('wignn' in model_type): 46 | if crop_size is not None: 47 | model = create_model( 48 | model_type, 49 | pretrained=pretrained_creation, 50 | use_shifts = use_shift, 51 | adapt_knn = adapt_knn, 52 | img_size = crop_size 53 | ) 54 | else: 55 | model = create_model( 56 | model_type, 57 | pretrained=pretrained_creation, 58 | use_shifts = use_shift, 59 | adapt_knn = adapt_knn 60 | ) 61 | elif('pvig' in model_type): 62 | if crop_size is not None: 63 | model = create_model( 64 | model_type, 65 | pretrained=pretrained_creation, 66 | img_size = crop_size 67 | ) 68 | else: 69 | model = create_model( 70 | model_type, 71 | pretrained=pretrained_creation 72 | ) 73 | elif('GreedyViG' in model_type): 74 | model = create_model( 75 | model_type, 76 | num_classes=1000, 77 | distillation=False, 78 | pretrained=pretrained_creation 79 | ) 80 | elif('mobilevig' in model_type): 81 | model = create_model( 82 | model_type, 83 | ) 84 | 85 | else: 86 | raise NotImplementedError(f'Model: {model_type} not yet implemented') 87 | 88 | 89 | 90 | 91 | 92 | 93 | # load checkpoint for our models and ViG 94 | if pretrained and checkpoint != '': 95 | if 'wignn' in model_type: 96 | assert checkpoint is not None, f'Cannot start from pretrained {model_type} model without checkpoints' 97 | 98 | checkpoint = torch.load(checkpoint, map_location='cpu') 99 | if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: 100 | print('Restoring model state from checkpoint...') 101 | state_dict = clean_state_dict(checkpoint['state_dict']) 102 | 103 | if 
crop_size is not None: 104 | state_dict = remove_pos(state_dict) 105 | 106 | model.load_state_dict(state_dict, strict = False) 107 | 108 | print(f'Pretrained weights for {model_type} loaded.') 109 | 110 | elif 'pvig' in model_type: 111 | assert checkpoint is not None, f'Cannot start from pretrained {model_type} model without checkpoints' 112 | 113 | state_dict = torch.load(checkpoint) 114 | if crop_size is not None: 115 | state_dict = remove_relative_pos(state_dict) 116 | model.load_state_dict(state_dict, strict=False) 117 | print('Pretrained weights for vig loaded.') 118 | 119 | elif 'GreedyViG' in model_type: 120 | assert checkpoint is not None, f'Cannot start from pretrained {model_type} model without checkpoints' 121 | 122 | checkpoint = torch.load(checkpoint, map_location='cpu') 123 | checkpoint_model = checkpoint['state_dict'] 124 | 125 | # drop the distillation head: the model above is created with distillation=False 126 | for k in ['dist_head.weight', 'dist_head.bias']: 127 | # guard the delete so checkpoints without these keys still load 128 | if k in checkpoint_model: 129 | del checkpoint_model[k] 130 | 131 | model.load_state_dict(checkpoint_model, strict=True) 132 | print('Pretrained weights for GreedyViG loaded.') 133 | 134 | elif 'mobilevig' in model_type: 135 | assert checkpoint is not None, f'Cannot start from pretrained {model_type} model without checkpoints' 136 | 137 | checkpoint = torch.load(checkpoint, map_location='cpu') 138 | checkpoint_model = checkpoint['state_dict'] 139 | 140 | # drop the distillation head: the model above is created without distillation 141 | for k in ['dist_head.weight', 'dist_head.bias']: 142 | # guard the delete so checkpoints without these keys still load 143 | if k in checkpoint_model: 144 | del checkpoint_model[k] 145 | 146 | model.load_state_dict(checkpoint_model, strict=True) 147 | print('Pretrained weights for MobileViG loaded.') 148 | 149 | 150 | 151 | 152 | 153 | # freeze the pretrained parameters when requested; the heads rebuilt below stay trainable 154 | for param in model.parameters(): 155 | if freezed: 156 | param.requires_grad = False 157 | else: 158 | param.requires_grad = True 159 | 160 | 161 | if 'pvig' in model_type or 'wignn' in model_type: 162 | model.prediction = Seq(nn.Conv2d(model.prediction[0].in_channels, 1024, 1, bias=True), 163 | nn.BatchNorm2d(1024), 164 | act_layer('gelu'), 165 | nn.Dropout(0.0), 166 | nn.Conv2d(1024, n_classes, 1, bias=True)) 167 | model.prediction.requires_grad = True 168 | elif 'GreedyViG' in model_type: 169 | model.prediction = nn.Sequential(nn.AdaptiveAvgPool2d(1), 170 | nn.Conv2d(model.prediction[1].in_channels, 768, kernel_size=1, bias=True), 171 | nn.BatchNorm2d(768), 172 | nn.GELU(), 173 | nn.Dropout(0.0)) 174 | 175 | model.head = nn.Conv2d(768, n_classes, kernel_size=1, bias=True) 176 | model.dist_head = nn.Conv2d(768, n_classes, 1, bias=True) 177 | model.prediction.requires_grad = True 178 | model.head.requires_grad = True 179 | model.dist_head.requires_grad = True 180 | 181 | elif 'mobilevig' in model_type: 182 | model.prediction = nn.Sequential(nn.AdaptiveAvgPool2d(1), 183 | nn.Conv2d(256, 512, 1, bias=True), 184 | nn.BatchNorm2d(512), 185 | nn.GELU(), 186 | nn.Dropout(0.)) 187 | 188 | model.head = nn.Conv2d(512, n_classes, 1, bias=True) 189 | model.dist_head = nn.Conv2d(512, n_classes, 1, bias=True) 190 | model.prediction.requires_grad = True 191 | model.head.requires_grad = True 192 | model.dist_head.requires_grad = True 193 | 194 | else: 195 | raise NotImplementedError(f'Model {model_type} not yet implemented\n{model}') 196 | 197 | 198 | params = sum(p.numel()
for p in model.parameters()) 199 | trainable_parameters = filter(lambda p: p.requires_grad, model.parameters()) 200 | trainable_parameters = sum([np.prod(p.size()) for p in trainable_parameters]) 201 | 202 | 203 | 204 | return model, params, trainable_parameters, n_classes 205 | 206 | 207 | 208 | 209 | 210 | if __name__ == '__main__': 211 | checkpoint = 'path/to/checkpoint' 212 | model_type = 'wignn_ti_256_gelu' 213 | 214 | 215 | dataset = 'CelebA' 216 | model, params, trainable_parameters, n_classes = get_model(model_type = model_type, 217 | use_shift=True, 218 | adapt_knn=True, 219 | checkpoint = checkpoint, 220 | freezed=True, 221 | dataset = dataset, 222 | crop_size = 512) 223 | model.eval() 224 | model.cuda() 225 | x = torch.rand((1,3,512,512)).cuda() 226 | # print(model) 227 | print(f"Parameters: {params}") 228 | print(f"Trainable Parameters: {trainable_parameters}") 229 | 230 | # out = model(x) 231 | # print(out.shape) 232 | macs = profile_macs(model, x) 233 | print(f'\n\n!!!!! macs : {macs*10**-9}\n\n') -------------------------------------------------------------------------------- /src/model/wignn.py: -------------------------------------------------------------------------------- 1 | 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.nn import Sequential as Seq 7 | 8 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 9 | from timm.models.helpers import load_pretrained 10 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 11 | from timm.models.registry import register_model 12 | 13 | from gcn_lib import act_layer, WindowGrapher 14 | from timm.models import create_model 15 | import time 16 | 17 | from torchprofile import profile_macs 18 | 19 | import sys 20 | 21 | def _cfg(url='', **kwargs): 22 | return { 23 | 'url': url, 24 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 25 | 'crop_pct': .9, 'interpolation': 'bicubic', 26 | 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 27 | 'first_conv': 'patch_embed.proj', 'classifier': 'head', 28 | **kwargs 29 | } 30 | 31 | 32 | default_cfgs = { 33 | 'wignn_224_gelu': _cfg( 34 | mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), 35 | ), 36 | 'wignn_b_224_gelu': _cfg( 37 | crop_pct=0.95, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), 38 | ), 39 | } 40 | 41 | 42 | class FFN(nn.Module): 43 | def __init__(self, in_features, hidden_features=None, out_features=None, act='relu', drop_path=0.0): 44 | super().__init__() 45 | out_features = out_features or in_features 46 | hidden_features = hidden_features or in_features 47 | self.fc1 = nn.Sequential( 48 | nn.Conv2d(in_features, hidden_features, 1, stride=1, padding=0), 49 | nn.BatchNorm2d(hidden_features), 50 | ) 51 | self.act = act_layer(act) 52 | self.fc2 = nn.Sequential( 53 | nn.Conv2d(hidden_features, out_features, 1, stride=1, padding=0), 54 | nn.BatchNorm2d(out_features), 55 | ) 56 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 57 | 58 | def forward(self, x): 59 | shortcut = x 60 | x = self.fc1(x) 61 | x = self.act(x) 62 | x = self.fc2(x) 63 | x = self.drop_path(x) + shortcut 64 | return x#.reshape(B, C, N, 1) 65 | 66 | 67 | class Stem(nn.Module): 68 | """ Image to Visual Embedding 69 | Overlap: https://arxiv.org/pdf/2106.13797.pdf 70 | """ 71 | def __init__(self, img_size=224, in_dim=3, out_dim=768, act='relu'): 72 | super().__init__() 73 | self.convs = nn.Sequential( 74 | nn.Conv2d(in_dim, out_dim//2, 3, stride=2, padding=1), 75 | nn.BatchNorm2d(out_dim//2), 76 | act_layer(act), 77 | nn.Conv2d(out_dim//2, out_dim, 3, stride=2, padding=1), 78 | nn.BatchNorm2d(out_dim), 79 | act_layer(act), 80 | nn.Conv2d(out_dim, out_dim, 3, stride=1, padding=1), 81 | nn.BatchNorm2d(out_dim), 82 | ) 83 | 84 | def forward(self, x): 85 | x = self.convs(x) 86 | return x 87 | 88 | 89 | class Downsample(nn.Module): 90 | """ Convolution-based downsample 91 | """ 92 | def __init__(self, in_dim=3, out_dim=768): 93 | super().__init__() 94 | self.conv = nn.Sequential( 95 | nn.Conv2d(in_dim, out_dim, 3, stride=2, padding=1), 96 | nn.BatchNorm2d(out_dim), 97 | ) 98 | 99 | def forward(self, x): 100 | x = self.conv(x) 101 | return x 102 | 103 | 104 | class DeepGCN(torch.nn.Module): 105 | def __init__(self, opt): 106 | super(DeepGCN, self).__init__() 107 | print(opt) 108 | self.k = opt.k # knn 109 | self.act = opt.act # activation layer {relu, prelu, leakyrelu, gelu, hswish} 110 | self.norm = opt.norm # batch or instance normalization {batch, instance} 111 | self.bias = opt.bias # bias of conv layer True or False 112 | self.epsilon = opt.epsilon # stochastic epsilon for gcn 113 | self.stochastic = opt.use_stochastic # stochastic for gcn, True or False 114 | self.conv = opt.conv # graph conv layer {edge, mr} 115 | self.emb_dims = opt.emb_dims # # Dimension of embeddings 116 | self.drop_path = opt.drop_path 117 | 118 | self.blocks = opt.blocks # [2,2,6,2] # number of basic blocks in the backbone 119 | self.channels = opt.channels # [80, 160, 400, 640] # number of channels of deep features 120 | 121 | self.img_size = opt.img_size 122 | self.use_shifts = opt.use_shifts 123 | 124 | 125 | self.n_blocks = sum(self.blocks) 126 | self.window_size = [opt.windows_size for _ in range(len(self.blocks))] 127 | 128 | if opt.use_reduce_ratios: 129 | self.reduce_ratios = [2, 2, 1, 1] 130 | else: 131 | self.reduce_ratios = [1, 1, 1, 1] 132 | 133 | adapt_knn = opt.adapt_knn 134 | print(f'Created Model wignn ({self.img_size}) Window: {self.window_size}') 135 | print(f'Use shifting windows: {self.use_shifts} adapt knn: {adapt_knn}') 136 | 137 | print(f'Knn: {self.k}') 138 | print(f'Reduce ratios: {self.reduce_ratios}') 139 | 140 | 141 | print(f'Channel: {self.channels}') 142 | print(f'Blocks: {self.blocks}') 143 | 144 | 145 | 146 | self.dpr = [x.item() for x in torch.linspace(0, self.drop_path, self.n_blocks)] # stochastic depth decay rule 147 | self.num_knn = [int(x.item()) for x in torch.linspace(self.k, self.k, self.n_blocks)] # number of knn's k 148 | # max_dilation = 49 // max(num_knn) 149 | 150 | self.stem = Stem(out_dim=self.channels[0], act=self.act) 151 | self.pos_embed = nn.Parameter(torch.zeros(1, self.channels[0], self.img_size//4, self.img_size//4)) 152 | 153 | self.backbone = nn.ModuleList([]) 154 | idx = 0 155 | for i in range(len(self.blocks)): 156 | if i > 0: 157 | self.backbone.append(Downsample(self.channels[i-1], self.channels[i])) 158 | 159 | for j in range(self.blocks[i]): 160 | shift_size = 0 161 | if j % 2 != 0 and 
self.use_shifts: 162 | shift_size = self.window_size[i] // 2 163 | 164 | self.backbone += [ 165 | Seq( 166 | WindowGrapher( 167 | in_channels = self.channels[i], 168 | kernel_size = self.num_knn[idx], 169 | windows_size = self.window_size[i], 170 | dilation = 1, 171 | conv = self.conv, 172 | act = self.act, 173 | norm = self.norm, 174 | bias = self.bias, 175 | stochastic = self.stochastic, 176 | epsilon = self.epsilon, 177 | drop_path = self.dpr[idx], 178 | relative_pos = True, 179 | shift_size = shift_size, 180 | r = self.reduce_ratios[i], 181 | input_resolution = ( 182 | (self.img_size//4) // (2 ** i), 183 | (self.img_size//4) // (2 ** i)), 184 | adapt_knn=adapt_knn 185 | ), 186 | FFN(self.channels[i], self.channels[i] * 4, act=self.act, drop_path=self.dpr[idx]) 187 | )] 188 | idx += 1 189 | self.backbone = Seq(*self.backbone) 190 | 191 | self.prediction = Seq(nn.Conv2d(self.channels[-1], self.emb_dims, 1, bias=True), 192 | nn.BatchNorm2d(self.emb_dims), 193 | act_layer(self.act), 194 | nn.Dropout(opt.dropout), 195 | nn.Conv2d(self.emb_dims, opt.n_classes, 1, bias=True)) 196 | self.model_init() 197 | 198 | def model_init(self): 199 | for m in self.modules(): 200 | if isinstance(m, torch.nn.Conv2d): 201 | torch.nn.init.kaiming_normal_(m.weight) 202 | m.weight.requires_grad = True 203 | if m.bias is not None: 204 | m.bias.data.zero_() 205 | m.bias.requires_grad = True 206 | 207 | def forward(self, inputs): 208 | x = self.stem(inputs) + self.pos_embed 209 | B, C, H, W = x.shape 210 | for i in range(len(self.backbone)): 211 | x = self.backbone[i](x) 212 | 213 | x = F.adaptive_avg_pool2d(x, 1) 214 | return self.prediction(x).squeeze(-1).squeeze(-1) 215 | 216 | 217 | 218 | 219 | 220 | 221 | 222 | 223 | 224 | 225 | 226 | 227 | class OptInit: 228 | def __init__(self, 229 | num_classes=1000, 230 | drop_path_rate=0.0, 231 | knn = 9, 232 | use_shifts = True, 233 | use_reduce_ratios = False, 234 | img_size = 224, 235 | adapt_knn = False, 236 | 237 | channels = None, 238 | blocks = None, 239 | **kwargs): 240 | 241 | self.k = knn # neighbor num (default:9) 242 | self.conv = 'mr' # graph conv layer {edge, mr} 243 | self.act = 'gelu' # activation layer {relu, prelu, leakyrelu, gelu, hswish} 244 | self.norm = 'batch' # batch or instance normalization {batch, instance} 245 | self.bias = True # bias of conv layer True or False 246 | self.dropout = 0.0 # dropout rate 247 | self.use_dilation = True # use dilated knn or not 248 | self.epsilon = 0.2 # stochastic epsilon for gcn 249 | self.use_stochastic = False # stochastic for gcn, True or False 250 | self.drop_path = drop_path_rate 251 | self.blocks = blocks # number of basic blocks in the backbone 252 | self.channels = channels # number of channels of deep features 253 | self.n_classes = num_classes # Dimension of out_channels 254 | self.emb_dims = 1024 # Dimension of embeddings 255 | self.windows_size = 7 256 | 257 | self.use_shifts = use_shifts 258 | self.img_size = img_size 259 | self.use_reduce_ratios = False 260 | self.adapt_knn = adapt_knn 261 | 262 | @register_model 263 | def wignn_ti_224_gelu(pretrained=False, **kwargs): 264 | 265 | 266 | opt = OptInit(**kwargs, channels = [48, 96, 240, 384], blocks= [2,2,6,2]) 267 | 268 | model = DeepGCN(opt) 269 | model.default_cfg = default_cfgs['wignn_224_gelu'] 270 | return model 271 | 272 | 273 | @register_model 274 | def wignn_s_224_gelu(pretrained=False, **kwargs): 275 | 276 | opt = OptInit(**kwargs, channels = [80, 160, 400, 640], blocks= [2,2,6,2]) 277 | 278 | model = DeepGCN(opt) 279 | model.default_cfg = 
default_cfgs['wignn_224_gelu'] 280 | return model 281 | 282 | 283 | @register_model 284 | def wignn_m_224_gelu(pretrained=False, **kwargs): 285 | 286 | opt = OptInit(**kwargs, channels = [96, 192, 384, 768], blocks= [2,2,16,2]) 287 | 288 | model = DeepGCN(opt) 289 | model.default_cfg = default_cfgs['wignn_224_gelu'] 290 | return model 291 | 292 | 293 | @register_model 294 | def wignn_b_224_gelu(pretrained=False, **kwargs): 295 | 296 | opt = OptInit(**kwargs, channels = [128, 256, 512, 1024], blocks= [2,2,18,2]) 297 | 298 | model = DeepGCN(opt) 299 | model.default_cfg = default_cfgs['wignn_b_224_gelu'] 300 | return model 301 | 302 | 303 | 304 | if __name__ == '__main__': 305 | 306 | 307 | # for img_size in [224,224*2,224*3,224*4,224*5]: 308 | 309 | # model = create_model( 310 | # 'wignn_ti_224_gelu', 311 | # knn = 9, 312 | # use_shifts = True, 313 | # img_size = img_size, 314 | # adapt_knn = True 315 | # ) 316 | 317 | # model = model.cuda() 318 | # model.eval() 319 | 320 | # x = torch.rand((1,3,img_size,img_size)).to(device = 'cuda') 321 | 322 | # # print(model(x).shape) 323 | 324 | # macs = profile_macs(model, x) 325 | # print(f'\n\n!!!!! WiGNet macs ({img_size}): {macs}\n\n') 326 | 327 | for img_size in [224,224*2,224*3,224*4,224*5]: 328 | 329 | 330 | with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA],profile_memory=True, record_shapes=True) as prof: 331 | 332 | x = torch.rand((1,3,img_size,img_size)).to(device = 'cuda') 333 | 334 | model = create_model( 335 | 'wignn_ti_224_gelu', 336 | knn = 9, 337 | use_shifts = True, 338 | img_size = img_size, 339 | adapt_knn = True 340 | ) 341 | 342 | model = model.cuda() 343 | model.eval() 344 | 345 | _ = model(x) 346 | 347 | f = open("memory_WiGNet_model.txt", "a") 348 | 349 | f.write(prof.key_averages().table()) 350 | f.write('\n\n') 351 | f.close() 352 | 353 | print('\n\n---------------\nResults saved in memory_WiGNet_model.txt') 354 | -------------------------------------------------------------------------------- /src/model/pyramid_vig.py: -------------------------------------------------------------------------------- 1 | 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.nn import Sequential as Seq 7 | 8 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 9 | from timm.models.helpers import load_pretrained 10 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 11 | from timm.models.registry import register_model 12 | 13 | from gcn_lib import Grapher, act_layer 14 | from timm.models import create_model 15 | 16 | from torchprofile import profile_macs 17 | import time 18 | 19 | def _cfg(url='', **kwargs): 20 | return { 21 | 'url': url, 22 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 23 | 'crop_pct': .9, 'interpolation': 'bicubic', 24 | 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 25 | 'first_conv': 'patch_embed.proj', 'classifier': 'head', 26 | **kwargs 27 | } 28 | 29 | 30 | default_cfgs = { 31 | 'vig_224_gelu': _cfg( 32 | mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), 33 | ), 34 | 'vig_b_224_gelu': _cfg( 35 | crop_pct=0.95, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), 36 | ), 37 | } 38 | 39 | 40 | class FFN(nn.Module): 41 | def __init__(self, in_features, hidden_features=None, out_features=None, act='relu', drop_path=0.0): 42 | super().__init__() 43 | out_features = out_features or in_features 44 | hidden_features = hidden_features or in_features 45 | self.fc1 = nn.Sequential( 46 | 
nn.Conv2d(in_features, hidden_features, 1, stride=1, padding=0), 47 | nn.BatchNorm2d(hidden_features), 48 | ) 49 | self.act = act_layer(act) 50 | self.fc2 = nn.Sequential( 51 | nn.Conv2d(hidden_features, out_features, 1, stride=1, padding=0), 52 | nn.BatchNorm2d(out_features), 53 | ) 54 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 55 | 56 | def forward(self, x): 57 | shortcut = x 58 | x = self.fc1(x) 59 | x = self.act(x) 60 | x = self.fc2(x) 61 | x = self.drop_path(x) + shortcut 62 | return x#.reshape(B, C, N, 1) 63 | 64 | 65 | class Stem(nn.Module): 66 | """ Image to Visual Embedding 67 | Overlap: https://arxiv.org/pdf/2106.13797.pdf 68 | """ 69 | def __init__(self, img_size=224, in_dim=3, out_dim=768, act='relu'): 70 | super().__init__() 71 | self.convs = nn.Sequential( 72 | nn.Conv2d(in_dim, out_dim//2, 3, stride=2, padding=1), 73 | nn.BatchNorm2d(out_dim//2), 74 | act_layer(act), 75 | nn.Conv2d(out_dim//2, out_dim, 3, stride=2, padding=1), 76 | nn.BatchNorm2d(out_dim), 77 | act_layer(act), 78 | nn.Conv2d(out_dim, out_dim, 3, stride=1, padding=1), 79 | nn.BatchNorm2d(out_dim), 80 | ) 81 | 82 | def forward(self, x): 83 | x = self.convs(x) 84 | return x 85 | 86 | 87 | class Downsample(nn.Module): 88 | """ Convolution-based downsample 89 | """ 90 | def __init__(self, in_dim=3, out_dim=768): 91 | super().__init__() 92 | self.conv = nn.Sequential( 93 | nn.Conv2d(in_dim, out_dim, 3, stride=2, padding=1), 94 | nn.BatchNorm2d(out_dim), 95 | ) 96 | 97 | def forward(self, x): 98 | x = self.conv(x) 99 | return x 100 | 101 | 102 | class DeepGCN(torch.nn.Module): 103 | def __init__(self, opt): 104 | super(DeepGCN, self).__init__() 105 | print(opt) 106 | k = opt.k 107 | act = opt.act 108 | norm = opt.norm 109 | bias = opt.bias 110 | epsilon = opt.epsilon 111 | stochastic = opt.use_stochastic 112 | conv = opt.conv 113 | emb_dims = opt.emb_dims 114 | drop_path = opt.drop_path 115 | 116 | blocks = opt.blocks 117 | self.n_blocks = sum(blocks) 118 | channels = opt.channels 119 | reduce_ratios = [4, 2, 1, 1] 120 | dpr = [x.item() for x in torch.linspace(0, drop_path, self.n_blocks)] # stochastic depth decay rule 121 | num_knn = [int(x.item()) for x in torch.linspace(k, k, self.n_blocks)] # number of knn's k 122 | max_dilation = 49 // max(num_knn) 123 | 124 | img_size = opt.img_size 125 | 126 | 127 | 128 | self.stem = Stem(out_dim=channels[0], act=act) 129 | self.pos_embed = nn.Parameter(torch.zeros(1, channels[0], img_size//4, img_size//4)) 130 | HW = img_size // 4 * img_size // 4 131 | 132 | self.backbone = nn.ModuleList([]) 133 | idx = 0 134 | for i in range(len(blocks)): 135 | if i > 0: 136 | self.backbone.append(Downsample(channels[i-1], channels[i])) 137 | HW = HW // 4 138 | for j in range(blocks[i]): 139 | self.backbone += [ 140 | Seq(Grapher(channels[i], num_knn[idx], min(idx // 4 + 1, max_dilation), conv, act, norm, 141 | bias, stochastic, epsilon, reduce_ratios[i], n=HW, drop_path=dpr[idx], 142 | relative_pos=True), 143 | FFN(channels[i], channels[i] * 4, act=act, drop_path=dpr[idx]) 144 | )] 145 | idx += 1 146 | self.backbone = Seq(*self.backbone) 147 | 148 | self.prediction = Seq(nn.Conv2d(channels[-1], 1024, 1, bias=True), 149 | nn.BatchNorm2d(1024), 150 | act_layer(act), 151 | nn.Dropout(opt.dropout), 152 | nn.Conv2d(1024, opt.n_classes, 1, bias=True)) 153 | self.model_init() 154 | 155 | def model_init(self): 156 | for m in self.modules(): 157 | if isinstance(m, torch.nn.Conv2d): 158 | torch.nn.init.kaiming_normal_(m.weight) 159 | m.weight.requires_grad = True 
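# Kaiming-normal init for every Conv2d weight (biases are zeroed just below); requires_grad is re-asserted so that re-initialized layers stay trainable.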
160 | if m.bias is not None: 161 | m.bias.data.zero_() 162 | m.bias.requires_grad = True 163 | 164 | def forward(self, inputs): 165 | x = self.stem(inputs) + self.pos_embed 166 | B, C, H, W = x.shape 167 | for i in range(len(self.backbone)): 168 | x = self.backbone[i](x) 169 | 170 | x = F.adaptive_avg_pool2d(x, 1) 171 | 172 | 173 | return self.prediction(x).squeeze(-1).squeeze(-1) 174 | 175 | 176 | @register_model 177 | def pvig_ti_224_gelu(pretrained=False, **kwargs): 178 | class OptInit: 179 | def __init__(self, num_classes=1000, drop_path_rate=0.0, img_size = 224, **kwargs): 180 | self.k = 9 # neighbor num (default:9) 181 | self.conv = 'mr' # graph conv layer {edge, mr} 182 | self.act = 'gelu' # activation layer {relu, prelu, leakyrelu, gelu, hswish} 183 | self.norm = 'batch' # batch or instance normalization {batch, instance} 184 | self.bias = True # bias of conv layer True or False 185 | self.dropout = 0.0 # dropout rate 186 | self.use_dilation = True # use dilated knn or not 187 | self.epsilon = 0.2 # stochastic epsilon for gcn 188 | self.use_stochastic = False # stochastic for gcn, True or False 189 | self.drop_path = drop_path_rate 190 | self.blocks = [2,2,6,2] # number of basic blocks in the backbone 191 | self.channels = [48, 96, 240, 384] # number of channels of deep features 192 | self.n_classes = num_classes # Dimension of out_channels 193 | self.emb_dims = 1024 # Dimension of embeddings 194 | 195 | self.img_size = img_size 196 | 197 | 198 | opt = OptInit(**kwargs) 199 | model = DeepGCN(opt) 200 | model.default_cfg = default_cfgs['vig_224_gelu'] 201 | return model 202 | 203 | 204 | @register_model 205 | def pvig_s_224_gelu(pretrained=False, **kwargs): 206 | class OptInit: 207 | def __init__(self, num_classes=1000, drop_path_rate=0.0, img_size = 224, **kwargs): 208 | self.k = 9 # neighbor num (default:9) 209 | self.conv = 'mr' # graph conv layer {edge, mr} 210 | self.act = 'gelu' # activation layer {relu, prelu, leakyrelu, gelu, hswish} 211 | self.norm = 'batch' # batch or instance normalization {batch, instance} 212 | self.bias = True # bias of conv layer True or False 213 | self.dropout = 0.0 # dropout rate 214 | self.use_dilation = True # use dilated knn or not 215 | self.epsilon = 0.2 # stochastic epsilon for gcn 216 | self.use_stochastic = False # stochastic for gcn, True or False 217 | self.drop_path = drop_path_rate 218 | self.blocks = [2,2,6,2] # number of basic blocks in the backbone 219 | self.channels = [80, 160, 400, 640] # number of channels of deep features 220 | self.n_classes = num_classes # Dimension of out_channels 221 | self.emb_dims = 1024 # Dimension of embeddings 222 | 223 | self.img_size = img_size 224 | 225 | 226 | 227 | opt = OptInit(**kwargs) 228 | model = DeepGCN(opt) 229 | model.default_cfg = default_cfgs['vig_224_gelu'] 230 | return model 231 | 232 | 233 | @register_model 234 | def pvig_m_224_gelu(pretrained=False, **kwargs): 235 | class OptInit: 236 | def __init__(self, num_classes=1000, drop_path_rate=0.0, img_size = 224, **kwargs): 237 | self.k = 9 # neighbor num (default:9) 238 | self.conv = 'mr' # graph conv layer {edge, mr} 239 | self.act = 'gelu' # activation layer {relu, prelu, leakyrelu, gelu, hswish} 240 | self.norm = 'batch' # batch or instance normalization {batch, instance} 241 | self.bias = True # bias of conv layer True or False 242 | self.dropout = 0.0 # dropout rate 243 | self.use_dilation = True # use dilated knn or not 244 | self.epsilon = 0.2 # stochastic epsilon for gcn 245 | self.use_stochastic = False # stochastic for gcn, 
True or False 246 | self.drop_path = drop_path_rate 247 | self.blocks = [2,2,16,2] # number of basic blocks in the backbone 248 | self.channels = [96, 192, 384, 768] # number of channels of deep features 249 | self.n_classes = num_classes # Dimension of out_channels 250 | self.emb_dims = 1024 # Dimension of embeddings 251 | 252 | self.img_size = img_size 253 | 254 | 255 | opt = OptInit(**kwargs) 256 | model = DeepGCN(opt) 257 | model.default_cfg = default_cfgs['vig_224_gelu'] 258 | return model 259 | 260 | 261 | @register_model 262 | def pvig_b_224_gelu(pretrained=False, **kwargs): 263 | class OptInit: 264 | def __init__(self, num_classes=1000, drop_path_rate=0.0, img_size = 224, **kwargs): 265 | self.k = 9 # neighbor num (default:9) 266 | self.conv = 'mr' # graph conv layer {edge, mr} 267 | self.act = 'gelu' # activation layer {relu, prelu, leakyrelu, gelu, hswish} 268 | self.norm = 'batch' # batch or instance normalization {batch, instance} 269 | self.bias = True # bias of conv layer True or False 270 | self.dropout = 0.0 # dropout rate 271 | self.use_dilation = True # use dilated knn or not 272 | self.epsilon = 0.2 # stochastic epsilon for gcn 273 | self.use_stochastic = False # stochastic for gcn, True or False 274 | self.drop_path = drop_path_rate 275 | self.blocks = [2,2,18,2] # number of basic blocks in the backbone 276 | self.channels = [128, 256, 512, 1024] # number of channels of deep features 277 | self.n_classes = num_classes # Dimension of out_channels 278 | self.emb_dims = 1024 # Dimension of embeddings 279 | 280 | self.img_size = img_size 281 | 282 | 283 | opt = OptInit(**kwargs) 284 | model = DeepGCN(opt) 285 | model.default_cfg = default_cfgs['vig_b_224_gelu'] 286 | return model 287 | 288 | 289 | 290 | if __name__ == '__main__': 291 | 292 | # for img_size in [224*5]: 293 | 294 | # model = create_model( 295 | # 'pvig_ti_224_gelu', 296 | # img_size = img_size 297 | # ) 298 | 299 | # model = model.cuda() 300 | # model.eval() 301 | 302 | # x = torch.rand((1,3,img_size,img_size)).to(device = 'cuda') 303 | 304 | # # print(model(x).shape) 305 | 306 | # macs = profile_macs(model, x) 307 | # print(f'\n\n!!!!! 
P-ViG macs ({img_size}): {macs*10**-9}\n\n') 308 | 309 | for img_size in [224,224*2,224*3,224*4,224*5]: 310 | 311 | 312 | with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA],profile_memory=True, record_shapes=True) as prof: 313 | 314 | x = torch.rand((1,3,img_size,img_size)).to(device = 'cuda') 315 | 316 | model = create_model( 317 | 'pvig_ti_224_gelu', 318 | img_size = img_size 319 | ) 320 | 321 | model = model.cuda() 322 | model.eval() 323 | 324 | _ = model(x) 325 | 326 | f = open("memory_ViG_model.txt", "a") 327 | 328 | f.write(prof.key_averages().table()) 329 | f.write('\n\n') 330 | f.close() 331 | 332 | print('\n\n---------------\nResults saved in memory_ViG_model.txt') 333 | 334 | -------------------------------------------------------------------------------- /src/model/mobilevig.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from torch.nn import Sequential as Seq 5 | 6 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 7 | from timm.models.layers import DropPath 8 | from timm.models.registry import register_model 9 | 10 | from timm.models import create_model 11 | from torchprofile import profile_macs 12 | 13 | 14 | 15 | 16 | def _cfg(url='', **kwargs): 17 | return { 18 | 'url': url, 19 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 20 | 'crop_pct': .9, 'interpolation': 'bicubic', 21 | 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 22 | 'classifier': 'head', 23 | **kwargs 24 | } 25 | 26 | 27 | default_cfgs = { 28 | 'mobilevig': _cfg(crop_pct=0.9, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD) 29 | } 30 | 31 | 32 | class Stem(nn.Module): 33 | def __init__(self, input_dim, output_dim, activation=nn.GELU): 34 | super(Stem, self).__init__() 35 | self.stem = nn.Sequential( 36 | nn.Conv2d(input_dim, output_dim // 2, kernel_size=3, stride=2, padding=1), 37 | nn.BatchNorm2d(output_dim // 2), 38 | nn.GELU(), 39 | nn.Conv2d(output_dim // 2, output_dim, kernel_size=3, stride=2, padding=1), 40 | nn.BatchNorm2d(output_dim), 41 | nn.GELU() 42 | ) 43 | 44 | def forward(self, x): 45 | return self.stem(x) 46 | 47 | 48 | class MLP(nn.Module): 49 | """ 50 | Implementation of MLP with 1*1 convolutions. 
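Each 1*1 convolution acts as a position-wise linear layer over the channel dimension, so the spatial resolution is preserved.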
51 | Input: tensor with shape [B, C, H, W] 52 | """ 53 | 54 | def __init__(self, in_features, hidden_features=None, 55 | out_features=None, drop=0., mid_conv=False): 56 | super().__init__() 57 | out_features = out_features or in_features 58 | hidden_features = hidden_features or in_features 59 | self.mid_conv = mid_conv 60 | self.fc1 = nn.Conv2d(in_features, hidden_features, 1) 61 | self.act = nn.GELU() 62 | self.fc2 = nn.Conv2d(hidden_features, out_features, 1) 63 | self.drop = nn.Dropout(drop) 64 | 65 | if self.mid_conv: 66 | self.mid = nn.Conv2d(hidden_features, hidden_features, kernel_size=3, stride=1, padding=1, 67 | groups=hidden_features) 68 | self.mid_norm = nn.BatchNorm2d(hidden_features) 69 | 70 | self.norm1 = nn.BatchNorm2d(hidden_features) 71 | self.norm2 = nn.BatchNorm2d(out_features) 72 | 73 | def forward(self, x): 74 | x = self.fc1(x) 75 | x = self.norm1(x) 76 | x = self.act(x) 77 | 78 | if self.mid_conv: 79 | x_mid = self.mid(x) 80 | x_mid = self.mid_norm(x_mid) 81 | x = self.act(x_mid) 82 | x = self.drop(x) 83 | 84 | x = self.fc2(x) 85 | x = self.norm2(x) 86 | 87 | x = self.drop(x) 88 | return x 89 | 90 | 91 | class InvertedResidual(nn.Module): 92 | def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0., use_layer_scale=True, layer_scale_init_value=1e-5): 93 | super().__init__() 94 | 95 | mlp_hidden_dim = int(dim * mlp_ratio) 96 | self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, drop=drop, mid_conv=True) 97 | 98 | self.drop_path = DropPath(drop_path) if drop_path > 0. \ 99 | else nn.Identity() 100 | self.use_layer_scale = use_layer_scale 101 | if use_layer_scale: 102 | self.layer_scale_2 = nn.Parameter( 103 | layer_scale_init_value * torch.ones(dim).unsqueeze(-1).unsqueeze(-1), requires_grad=True) 104 | 105 | def forward(self, x): 106 | if self.use_layer_scale: 107 | x = x + self.drop_path(self.layer_scale_2 * self.mlp(x)) 108 | else: 109 | x = x + self.drop_path(self.mlp(x)) 110 | return x 111 | 112 | class MRConv4d(nn.Module): 113 | """ 114 | Max-Relative Graph Convolution (Paper: https://arxiv.org/abs/1904.03751) for dense data type 115 | 116 | K is the number of superpatches, therefore hops equals res // K. 
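Neighbors are gathered by rolling the feature map by multiples of K along the height and width axes; the element-wise maximum of the resulting feature differences realizes the max-relative aggregation over this fixed grid neighborhood.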
117 | """ 118 | def __init__(self, in_channels, out_channels, K=2): 119 | super(MRConv4d, self).__init__() 120 | self.nn = nn.Sequential( 121 | nn.Conv2d(in_channels * 2, out_channels, 1), 122 | nn.BatchNorm2d(out_channels), # normalize the conv output (out_channels == in_channels * 2 at the call site) 123 | nn.GELU() 124 | ) 125 | self.K = K 126 | 127 | def forward(self, x): 128 | B, C, H, W = x.shape 129 | 130 | x_j = x - x # all-zeros tensor with the same shape, dtype and device as x 131 | for i in range(self.K, H, self.K): 132 | x_c = x - torch.cat([x[:, :, -i:, :], x[:, :, :-i, :]], dim=2) 133 | x_j = torch.max(x_j, x_c) 134 | for i in range(self.K, W, self.K): 135 | x_r = x - torch.cat([x[:, :, :, -i:], x[:, :, :, :-i]], dim=3) 136 | x_j = torch.max(x_j, x_r) 137 | 138 | x = torch.cat([x, x_j], dim=1) 139 | return self.nn(x) 140 | 141 | 142 | class Grapher(nn.Module): 143 | """ 144 | Grapher module with graph convolution and fc layers 145 | """ 146 | def __init__(self, in_channels, drop_path=0.0, K=2): 147 | super(Grapher, self).__init__() 148 | self.channels = in_channels 149 | self.K = K 150 | 151 | self.fc1 = nn.Sequential( 152 | nn.Conv2d(in_channels, in_channels, 1, stride=1, padding=0), 153 | nn.BatchNorm2d(in_channels), 154 | ) 155 | self.graph_conv = MRConv4d(in_channels, in_channels * 2, K=self.K) 156 | self.fc2 = nn.Sequential( 157 | nn.Conv2d(in_channels * 2, in_channels, 1, stride=1, padding=0), 158 | nn.BatchNorm2d(in_channels), 159 | ) # out_channels back to 1x 160 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 161 | 162 | 163 | def forward(self, x): 164 | _tmp = x 165 | x = self.fc1(x) 166 | x = self.graph_conv(x) 167 | x = self.fc2(x) 168 | x = self.drop_path(x) + _tmp 169 | 170 | return x 171 | 172 | 173 | class Downsample(nn.Module): 174 | """ Convolution-based downsample 175 | """ 176 | def __init__(self, in_dim, out_dim): 177 | super().__init__() 178 | self.conv = nn.Sequential( 179 | nn.Conv2d(in_dim, out_dim, 3, stride=2, padding=1), 180 | nn.BatchNorm2d(out_dim), 181 | ) 182 | 183 | def forward(self, x): 184 | x = self.conv(x) 185 | return x 186 | 187 | 188 | class FFN(nn.Module): 189 | def __init__(self, in_features, hidden_features=None, out_features=None, drop_path=0.0): 190 | super().__init__() 191 | out_features = out_features or in_features # same as input 192 | hidden_features = hidden_features or in_features # x4 193 | self.fc1 = nn.Sequential( 194 | nn.Conv2d(in_features, hidden_features, 1, stride=1, padding=0), 195 | nn.BatchNorm2d(hidden_features), 196 | ) 197 | self.act = nn.GELU() 198 | self.fc2 = nn.Sequential( 199 | nn.Conv2d(hidden_features, out_features, 1, stride=1, padding=0), 200 | nn.BatchNorm2d(out_features), 201 | ) 202 | self.drop_path = DropPath(drop_path) if drop_path > 0.
else nn.Identity() 203 | 204 | def forward(self, x): 205 | shortcut = x 206 | x = self.fc1(x) 207 | x = self.act(x) 208 | x = self.fc2(x) 209 | x = self.drop_path(x) + shortcut 210 | return x 211 | 212 | 213 | class MobileViG(torch.nn.Module): 214 | def __init__(self, local_blocks, local_channels, 215 | global_blocks, global_channels, 216 | dropout=0., drop_path=0., emb_dims=512, 217 | K=2, distillation=False, num_classes=1000): 218 | super(MobileViG, self).__init__() 219 | 220 | self.distillation = distillation 221 | 222 | n_blocks = sum(global_blocks) + sum(local_blocks) 223 | dpr = [x.item() for x in torch.linspace(0, drop_path, n_blocks)] # stochastic depth decay rule 224 | dpr_idx = 0 225 | 226 | self.stem = Stem(input_dim=3, output_dim=local_channels[0]) 227 | 228 | # local processing with inverted residuals 229 | self.local_backbone = nn.ModuleList([]) 230 | for i in range(len(local_blocks)): 231 | if i > 0: 232 | self.local_backbone.append(Downsample(local_channels[i-1], local_channels[i])) 233 | for _ in range(local_blocks[i]): 234 | self.local_backbone.append(InvertedResidual(dim=local_channels[i], mlp_ratio=4, drop_path=dpr[dpr_idx])) 235 | dpr_idx += 1 236 | self.local_backbone.append(Downsample(local_channels[-1], global_channels[0])) # transition from local to global 237 | 238 | # global processing with svga 239 | self.backbone = nn.ModuleList([]) 240 | for i in range(len(global_blocks)): 241 | if i > 0: 242 | self.backbone.append(Downsample(global_channels[i-1], global_channels[i])) 243 | for j in range(global_blocks[i]): 244 | self.backbone += [nn.Sequential( 245 | Grapher(global_channels[i], drop_path=dpr[dpr_idx], K=K), 246 | FFN(global_channels[i], global_channels[i] * 4, drop_path=dpr[dpr_idx])) 247 | ] 248 | dpr_idx += 1 249 | 250 | self.prediction = nn.Sequential(nn.AdaptiveAvgPool2d(1), 251 | nn.Conv2d(global_channels[-1], emb_dims, 1, bias=True), 252 | nn.BatchNorm2d(emb_dims), 253 | nn.GELU(), 254 | nn.Dropout(dropout)) 255 | 256 | self.head = nn.Conv2d(emb_dims, num_classes, 1, bias=True) 257 | 258 | if self.distillation: 259 | self.dist_head = nn.Conv2d(emb_dims, num_classes, 1, bias=True) 260 | 261 | self.model_init() 262 | 263 | def model_init(self): 264 | for m in self.modules(): 265 | if isinstance(m, torch.nn.Conv2d): 266 | torch.nn.init.kaiming_normal_(m.weight) 267 | m.weight.requires_grad = True 268 | if m.bias is not None: 269 | m.bias.data.zero_() 270 | m.bias.requires_grad = True 271 | 272 | def forward(self, inputs): 273 | x = self.stem(inputs) 274 | B, C, H, W = x.shape 275 | for i in range(len(self.local_backbone)): 276 | x = self.local_backbone[i](x) 277 | for i in range(len(self.backbone)): 278 | x = self.backbone[i](x) 279 | 280 | x = self.prediction(x) 281 | 282 | if self.distillation: 283 | x = self.head(x).squeeze(-1).squeeze(-1), self.dist_head(x).squeeze(-1).squeeze(-1) 284 | if not self.training: 285 | x = (x[0] + x[1]) / 2 286 | else: 287 | x = self.head(x).squeeze(-1).squeeze(-1) 288 | return x 289 | 290 | 291 | @register_model 292 | def mobilevig_ti(pretrained=False, **kwargs): 293 | model = MobileViG(local_blocks=[2, 2, 6], 294 | local_channels=[42, 84, 168], 295 | global_blocks=[2], 296 | global_channels=[256], 297 | dropout=0., 298 | drop_path=0.1, 299 | emb_dims=512, 300 | K=2, 301 | distillation=False, 302 | num_classes=1000) 303 | model.default_cfg = default_cfgs['mobilevig'] 304 | return model 305 | 306 | 307 | @register_model 308 | def mobilevig_s(pretrained=False, **kwargs): 309 | model = MobileViG(local_blocks=[3, 3, 9], 310 | 
local_channels=[42, 84, 176], 311 | global_blocks=[3], 312 | global_channels=[256], 313 | dropout=0., 314 | drop_path=0.1, 315 | emb_dims=512, 316 | K=2, 317 | distillation=False, 318 | num_classes=1000) 319 | model.default_cfg = default_cfgs['mobilevig'] 320 | return model 321 | 322 | 323 | @register_model 324 | def mobilevig_m(pretrained=False, **kwargs): 325 | model = MobileViG(local_blocks=[3, 3, 9], 326 | local_channels=[42, 84, 224], 327 | global_blocks=[3], 328 | global_channels=[400], 329 | dropout=0., 330 | drop_path=0.1, 331 | emb_dims=768, 332 | K=2, 333 | distillation=False, 334 | num_classes=1000) 335 | model.default_cfg = default_cfgs['mobilevig'] 336 | return model 337 | 338 | 339 | @register_model 340 | def mobilevig_b(pretrained=False, **kwargs): 341 | model = MobileViG(local_blocks=[5, 5, 15], 342 | local_channels=[42, 84, 240], 343 | global_blocks=[5], 344 | global_channels=[464], 345 | dropout=0., 346 | drop_path=0.1, 347 | emb_dims=768, 348 | K=2, 349 | distillation=False, 350 | num_classes=1000) 351 | model.default_cfg = default_cfgs['mobilevig'] 352 | return model 353 | 354 | 355 | 356 | 357 | if __name__ == '__main__': 358 | 359 | # for img_size in [224,224*2,224*3,224*4,224*5]: 360 | # model = create_model( 361 | # 'mobilevig_s' 362 | # ) 363 | # model = model.cuda() 364 | # model.eval() 365 | 366 | # x = torch.rand((1,3,img_size,img_size)).to(device = 'cuda') 367 | 368 | 369 | # macs = profile_macs(model, x) 370 | # print(f'\n\n!!!!! MobileVig macs ({img_size}): {macs*10**(-9)}\n\n') 371 | 372 | 373 | for img_size in [224,224*2,224*3,224*4,224*5]: 374 | 375 | with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA],profile_memory=True, record_shapes=True) as prof: 376 | 377 | x = torch.rand((1,3,img_size,img_size)).to(device = 'cuda') 378 | 379 | model = create_model( 380 | 'mobilevig_s' 381 | ) 382 | 383 | model = model.cuda() 384 | model.eval() 385 | 386 | _ = model(x) 387 | 388 | f = open("memory_MobileVig_model.txt", "a") 389 | 390 | 391 | f.write(prof.key_averages().table()) 392 | f.write('\n\n') 393 | f.close() 394 | print('\n\n---------------\nResults saved in memory_MobileVig_model.txt') 395 | 396 | -------------------------------------------------------------------------------- /src/model/greedyvig.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch import Tensor 4 | import torch.nn.functional as F 5 | from torch.nn import Sequential as Seq 6 | 7 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 8 | from timm.models.layers import DropPath 9 | from timm.models.registry import register_model 10 | 11 | import random 12 | import warnings 13 | from timm.models import create_model 14 | from torchprofile import profile_macs 15 | 16 | warnings.filterwarnings('ignore') 17 | 18 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 19 | import sys 20 | 21 | # IMAGENET 22 | def _cfg(url='', **kwargs): 23 | return { 24 | 'url': url, 25 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 26 | 'crop_pct': .9, 'interpolation': 'bicubic', 27 | 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 28 | 'classifier': 'head', 29 | **kwargs 30 | } 31 | 32 | 33 | default_cfgs = { 34 | 'greedyvig': _cfg(crop_pct=0.9, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD) 35 | } 36 | 37 | 38 | class Stem(nn.Module): 39 | def __init__(self, input_dim, output_dim): 40 | super(Stem, self).__init__() 41 | self.stem = 
nn.Sequential( 42 | nn.Conv2d(input_dim, output_dim // 2, kernel_size=3, stride=2, padding=1), 43 | nn.BatchNorm2d(output_dim // 2), 44 | nn.GELU(), 45 | nn.Conv2d(output_dim // 2, output_dim, kernel_size=3, stride=2, padding=1), 46 | nn.BatchNorm2d(output_dim), 47 | nn.GELU(), 48 | ) 49 | 50 | def forward(self, x): 51 | return self.stem(x) 52 | 53 | 54 | class DepthWiseSeparable(nn.Module): 55 | def __init__(self, in_dim, kernel, expansion=4): 56 | super().__init__() 57 | 58 | self.pw1 = nn.Conv2d(in_dim, in_dim * 4, 1) # kernel size = 1 59 | self.norm1 = nn.BatchNorm2d(in_dim * 4) 60 | self.act1 = nn.GELU() 61 | 62 | self.dw = nn.Conv2d(in_dim * 4, in_dim * 4, kernel_size=kernel, stride=1, padding=1, groups=in_dim * 4) # kernel size = 3 63 | self.norm2 = nn.BatchNorm2d(in_dim * 4) 64 | self.act2 = nn.GELU() 65 | 66 | self.pw2 = nn.Conv2d(in_dim * 4, in_dim, 1) 67 | self.norm3 = nn.BatchNorm2d(in_dim) 68 | 69 | def forward(self, x): 70 | x = self.pw1(x) 71 | x = self.norm1(x) 72 | x = self.act1(x) 73 | 74 | x = self.dw(x) 75 | x = self.norm2(x) 76 | x = self.act2(x) 77 | 78 | x = self.pw2(x) 79 | x = self.norm3(x) 80 | return x 81 | 82 | 83 | class InvertedResidual(nn.Module): 84 | def __init__(self, dim, kernel, expansion_ratio=4., drop=0., drop_path=0., use_layer_scale=True, layer_scale_init_value=1e-5): 85 | super().__init__() 86 | 87 | self.dws = DepthWiseSeparable(in_dim=dim, kernel=kernel, expansion=expansion_ratio) 88 | 89 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 90 | self.use_layer_scale = use_layer_scale 91 | if use_layer_scale: 92 | self.layer_scale_2 = nn.Parameter( 93 | layer_scale_init_value * torch.ones(dim), requires_grad=True) 94 | 95 | def forward(self, x): 96 | if self.use_layer_scale: 97 | x = x + self.drop_path(self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * self.dws(x)) 98 | else: 99 | x = x + self.drop_path(self.dws(x)) 100 | return x 101 | 102 | 103 | 104 | class DynamicMRConv4d(nn.Module): 105 | def __init__(self, in_channels, out_channels, K): 106 | super().__init__() 107 | self.nn = nn.Sequential( 108 | nn.Conv2d(in_channels, out_channels, 1), 109 | nn.BatchNorm2d(out_channels), 110 | nn.GELU() 111 | ) 112 | self.K = K 113 | self.mean = 0 114 | self.std = 0 115 | 116 | def forward(self, x): 117 | B, C, H, W = x.shape 118 | x_j = x - x 119 | 120 | # get an estimate of the mean distance by computing the distance of points b/w quadrants. This is for efficiency to minimize computations. 
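# The two concatenations below circularly shift the map by half its height and then half its width, i.e. the equivalent of torch.roll(x, shifts=(H // 2, W // 2), dims=(2, 3)); the L1 distance to this rolled copy yields per-image statistics (mean, std) that serve further down as an adaptive threshold for masking candidate neighbors.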
121 | x_rolled = torch.cat([x[:, :, -H//2:, :], x[:, :, :-H//2, :]], dim=2) 122 | x_rolled = torch.cat([x_rolled[:, :, :, -W//2:], x_rolled[:, :, :, :-W//2]], dim=3) 123 | 124 | # per-pixel L1 distance between the original and the rolled features 125 | norm = torch.norm((x - x_rolled), p=1, dim=1, keepdim=True) 126 | 127 | self.mean = torch.mean(norm, dim=[2,3], keepdim=True) 128 | self.std = torch.std(norm, dim=[2,3], keepdim=True) 129 | 130 | 131 | for i in range(0, H, self.K): 132 | x_rolled = torch.cat([x[:, :, -i:, :], x[:, :, :-i, :]], dim=2) 133 | 134 | dist = torch.norm((x - x_rolled), p=1, dim=1, keepdim=True) 135 | 136 | # Got 83.86% 137 | mask = torch.where(dist < self.mean - self.std, 1, 0) 138 | 139 | x_rolled_and_masked = (x_rolled - x) * mask 140 | x_j = torch.max(x_j, x_rolled_and_masked) 141 | 142 | for j in range(0, W, self.K): 143 | x_rolled = torch.cat([x[:, :, :, -j:], x[:, :, :, :-j]], dim=3) 144 | 145 | dist = torch.norm((x - x_rolled), p=1, dim=1, keepdim=True) 146 | 147 | mask = torch.where(dist < self.mean - self.std, 1, 0) 148 | 149 | x_rolled_and_masked = (x_rolled - x) * mask 150 | x_j = torch.max(x_j, x_rolled_and_masked) 151 | 152 | x = torch.cat([x, x_j], dim=1) 153 | return self.nn(x) 154 | 155 | 156 | 157 | class ConditionalPositionEncoding(nn.Module): 158 | """ 159 | Implementation of conditional positional encoding. For more details refer to paper: 160 | `Conditional Positional Encodings for Vision Transformers <https://arxiv.org/abs/2102.10882>`_ 161 | """ 162 | def __init__(self, in_channels, kernel_size): 163 | super().__init__() 164 | self.pe = nn.Conv2d( 165 | in_channels=in_channels, 166 | out_channels=in_channels, 167 | kernel_size=kernel_size, 168 | stride=1, 169 | padding=kernel_size // 2, 170 | bias=True, 171 | groups=in_channels 172 | ) 173 | 174 | def forward(self, x): 175 | x = self.pe(x) + x 176 | return x 177 | 178 | 179 | class Grapher(nn.Module): 180 | """ 181 | Grapher module with graph convolution and fc layers 182 | """ 183 | def __init__(self, in_channels, K): 184 | super(Grapher, self).__init__() 185 | self.channels = in_channels 186 | self.K = K 187 | 188 | self.cpe = ConditionalPositionEncoding(in_channels, kernel_size=7) 189 | self.fc1 = nn.Sequential( 190 | nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0), 191 | nn.BatchNorm2d(in_channels), 192 | ) 193 | self.graph_conv = DynamicMRConv4d(in_channels * 2, in_channels, K=self.K) 194 | self.fc2 = nn.Sequential( 195 | nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0), 196 | nn.BatchNorm2d(in_channels), 197 | ) # out_channels back to 1x 198 | 199 | 200 | def forward(self, x): 201 | x = self.cpe(x) 202 | x = self.fc1(x) 203 | x = self.graph_conv(x) 204 | x = self.fc2(x) 205 | 206 | return x 207 | 208 | 209 | class DynamicGraphConvBlock(nn.Module): 210 | def __init__(self, in_dim, drop_path=0., K=2, use_layer_scale=True, layer_scale_init_value=1e-5): 211 | super().__init__() 212 | 213 | self.mixer = Grapher(in_dim, K) 214 | self.ffn = nn.Sequential( 215 | nn.Conv2d(in_dim, in_dim * 4, kernel_size=1, stride=1, padding=0), 216 | nn.BatchNorm2d(in_dim * 4), 217 | nn.GELU(), 218 | nn.Conv2d(in_dim * 4, in_dim, kernel_size=1, stride=1, padding=0), 219 | nn.BatchNorm2d(in_dim), 220 | ) 221 | 222 | self.drop_path = DropPath(drop_path) if drop_path > 0.
else nn.Identity() 223 | self.use_layer_scale = use_layer_scale 224 | if use_layer_scale: 225 | self.layer_scale_1 = nn.Parameter( 226 | layer_scale_init_value * torch.ones(in_dim), requires_grad=True) 227 | self.layer_scale_2 = nn.Parameter( 228 | layer_scale_init_value * torch.ones(in_dim), requires_grad=True) 229 | 230 | def forward(self, x): 231 | if self.use_layer_scale: 232 | x = x + self.drop_path(self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * self.mixer(x)) 233 | x = x + self.drop_path(self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * self.ffn(x)) 234 | else: 235 | x = x + self.drop_path(self.mixer(x)) 236 | x = x + self.drop_path(self.ffn(x)) 237 | return x 238 | 239 | 240 | class Downsample(nn.Module): 241 | """ 242 | Convolution-based downsample 243 | """ 244 | def __init__(self, in_dim, out_dim): 245 | super().__init__() 246 | self.conv = nn.Sequential( 247 | nn.Conv2d(in_dim, out_dim, kernel_size=3, stride=2, padding=1), 248 | nn.BatchNorm2d(out_dim), 249 | ) 250 | def forward(self, x): 251 | x = self.conv(x) 252 | return x 253 | 254 | 255 | class GreedyViG(torch.nn.Module): 256 | def __init__(self, blocks, channels, kernels, stride, 257 | act_func, dropout=0., drop_path=0., emb_dims=512, 258 | K=2, distillation=False, num_classes=1000): 259 | super(GreedyViG, self).__init__() 260 | 261 | self.distillation = distillation 262 | self.stage_names = ['stem', 'local_1', 'local_2', 'local_3', 'global'] 263 | 264 | n_blocks = sum([sum(x) for x in blocks]) 265 | dpr = [x.item() for x in torch.linspace(0, drop_path, n_blocks)] # stochastic depth decay rule 266 | dpr_idx = 0 267 | 268 | self.stem = Stem(input_dim=3, output_dim=channels[0]) 269 | 270 | self.backbone = [] 271 | for i in range(len(blocks)): 272 | stage = [] 273 | local_stages = blocks[i][0] 274 | global_stages = blocks[i][1] 275 | if i > 0: 276 | stage.append(Downsample(channels[i-1], channels[i])) 277 | for _ in range(local_stages): 278 | stage.append(InvertedResidual(dim=channels[i], kernel=3, expansion_ratio=4, drop_path=dpr[dpr_idx])) 279 | dpr_idx += 1 280 | for _ in range(global_stages): 281 | stage.append(DynamicGraphConvBlock(channels[i], drop_path=dpr[dpr_idx], K=K[i])) 282 | dpr_idx += 1 283 | self.backbone.append(nn.Sequential(*stage)) 284 | 285 | self.backbone = nn.Sequential(*self.backbone) 286 | 287 | self.prediction = nn.Sequential(nn.AdaptiveAvgPool2d(1), 288 | nn.Conv2d(channels[-1], emb_dims, kernel_size=1, bias=True), 289 | nn.BatchNorm2d(emb_dims), 290 | nn.GELU(), 291 | nn.Dropout(dropout)) 292 | 293 | self.head = nn.Conv2d(emb_dims, num_classes, kernel_size=1, bias=True) 294 | 295 | if self.distillation: 296 | self.dist_head = nn.Conv2d(emb_dims, num_classes, 1, bias=True) 297 | 298 | self.model_init() 299 | 300 | def model_init(self): 301 | for m in self.modules(): 302 | if isinstance(m, torch.nn.Conv2d): 303 | torch.nn.init.kaiming_normal_(m.weight) 304 | m.weight.requires_grad = True 305 | if m.bias is not None: 306 | m.bias.data.zero_() 307 | m.bias.requires_grad = True 308 | 309 | def forward(self, inputs): 310 | x = self.stem(inputs) 311 | B, C, H, W = x.shape 312 | x = self.backbone(x) 313 | 314 | x = self.prediction(x) 315 | 316 | 317 | if self.distillation: 318 | 319 | x = self.head(x).squeeze(-1).squeeze(-1), self.dist_head(x).squeeze(-1).squeeze(-1) 320 | if not self.training: 321 | x = (x[0] + x[1]) / 2 322 | else: 323 | x = self.head(x).squeeze(-1).squeeze(-1) 324 | return x 325 | 326 | @register_model 327 | def GreedyViG_S(pretrained=False, **kwargs): ## 12.0 M, 1.6 GMACs 328 | model = 
GreedyViG(blocks=[[2,2], [2,2], [6,2], [2,2]], 329 | channels=[48, 96, 192, 384], 330 | kernels=3, 331 | stride=1, 332 | act_func='gelu', 333 | dropout=0., 334 | drop_path=0.1, 335 | emb_dims=768, 336 | K=[8, 4, 2, 1], 337 | distillation=False, 338 | num_classes=1000) 339 | model.default_cfg = default_cfgs['greedyvig'] 340 | return model 341 | 342 | @register_model 343 | def GreedyViG_M(pretrained=False, **kwargs): # 21.9 M, 3.2 GMACs 344 | model = GreedyViG(blocks=[[3,3], [3,3], [9,3], [3,3]], 345 | channels=[56, 112, 224, 448], 346 | kernels=3, 347 | stride=1, 348 | act_func='gelu', 349 | dropout=0., 350 | drop_path=0.1, 351 | emb_dims=768, 352 | K=[8, 4, 2, 1], 353 | distillation=False, 354 | num_classes=1000) 355 | model.default_cfg = default_cfgs['greedyvig'] 356 | return model 357 | 358 | @register_model 359 | def GreedyViG_B(pretrained=False, **kwargs): # 30.9 M, 5.2 GMACs 360 | model = GreedyViG(blocks=[[4,4], [4,4], [12,4], [3,3]], 361 | channels=[64, 128, 256, 512], 362 | kernels=3, 363 | stride=1, 364 | act_func='gelu', 365 | dropout=0., 366 | drop_path=0.1, 367 | emb_dims=768, 368 | K=[8, 4, 2, 1], 369 | distillation=False, 370 | num_classes=1000) 371 | model.default_cfg = default_cfgs['greedyvig'] 372 | return model 373 | 374 | 375 | 376 | if __name__ == '__main__': 377 | 378 | # for img_size in [224,224*2,224*3,224*4,224*5]: 379 | # model = create_model( 380 | # 'GreedyViG_S', 381 | # num_classes=1000, 382 | # distillation=False, 383 | # pretrained=False, 384 | # fuse=False 385 | # ) 386 | # model = model.cuda() 387 | # model.eval() 388 | 389 | # x = torch.rand((1,3,img_size,img_size)).to(device = 'cuda') 390 | 391 | 392 | # macs = profile_macs(model, x) 393 | # print(f'\n\n!!!!! GreedyViG macs ({img_size}): {macs*10**(-9)}\n\n') 394 | 395 | 396 | for img_size in [224,224*2,224*3,224*4,224*5]: 397 | 398 | with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA],profile_memory=True, record_shapes=True) as prof: 399 | 400 | x = torch.rand((1,3,img_size,img_size)).to(device = 'cuda') 401 | 402 | model = create_model( 403 | 'GreedyViG_S', 404 | num_classes=1000, 405 | distillation=False, 406 | pretrained=False, 407 | fuse=False 408 | ) 409 | 410 | model = model.cuda() 411 | model.eval() 412 | 413 | _ = model(x) 414 | 415 | f = open("memory_GreedyViG_model.txt", "a") 416 | 417 | 418 | f.write(prof.key_averages().table()) 419 | f.write('\n\n') 420 | f.close() 421 | print('\n\n---------------\nResults saved in memory_GreedyViG_model.txt') -------------------------------------------------------------------------------- /src/opt.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import yaml 3 | 4 | def int2bool(i): 5 | i = int(i) 6 | assert i == 0 or i == 1 7 | return i == 1 8 | 9 | 10 | # The first arg parser parses out only the --config argument, this argument is used to 11 | # load a yaml file containing key-values that override the defaults for the main parser below 12 | config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False) 13 | parser.add_argument('-c', '--config', default='', type=str, metavar='FILE', 14 | help='YAML config file specifying default arguments') 15 | 16 | 17 | 18 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 19 | 20 | # Dataset / Model parameters 21 | parser.add_argument('--data', default='/scratch/dataset/inet', type=str, help='path to dataset (default: imagenet)') 22 | parser.add_argument('--model', 
default='resnet101', type=str, metavar='MODEL',
 23 |                     help='Name of model to train (default: "resnet101")')
 24 | parser.add_argument('--pretrained', action='store_true', default=False,
 25 |                     help='Start with pretrained version of specified network (if avail)')
 26 | parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
 27 |                     help='Initialize model from this checkpoint (default: none)')
 28 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
 29 |                     help='Resume full model and optimizer state from checkpoint (default: none)')
 30 | parser.add_argument('--no-resume-opt', action='store_true', default=False,
 31 |                     help='prevent resume of optimizer state when resuming model')
 32 | parser.add_argument('--num-classes', type=int, default=1000, metavar='N',
 33 |                     help='number of label classes (default: 1000)')
 34 | parser.add_argument('--gp', default=None, type=str, metavar='POOL',
 35 |                     help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
 36 | """ parser.add_argument('--img-size', type=int, default=None, metavar='N',
 37 |                     help='Image patch size (default: None => model default)') """
 38 | parser.add_argument('--crop-pct', default=None, type=float,
 39 |                     metavar='N', help='Input image center crop percent (for validation only)')
 40 | parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
 41 |                     help='Override mean pixel value of dataset')
 42 | parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
 43 |                     help='Override std deviation of dataset')
 44 | parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
 45 |                     help='Image resize interpolation type (overrides model)')
 46 | parser.add_argument('-b', '--batch-size', type=int, default=32, metavar='N',
 47 |                     help='input batch size for training (default: 32)')
 48 | parser.add_argument('-vb', '--validation-batch-size-multiplier', type=int, default=1, metavar='N',
 49 |                     help='ratio of validation batch size to training batch size (default: 1)')
 50 | 
 51 | # Optimizer parameters
 52 | parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
 53 |                     help='Optimizer (default: "sgd")')
 54 | parser.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON',
 55 |                     help='Optimizer Epsilon (default: None, use opt default)')
 56 | parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
 57 |                     help='Optimizer Betas (default: None, use opt default)')
 58 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
 59 |                     help='Optimizer momentum (default: 0.9)')
 60 | parser.add_argument('--weight-decay', type=float, default=0.0001,
 61 |                     help='weight decay (default: 0.0001)')
 62 | parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
 63 |                     help='Clip gradient norm (default: None, no clipping)')
 64 | 
 65 | 
 66 | 
 67 | # Learning rate schedule parameters
 68 | parser.add_argument('--sched', default='step', type=str, metavar='SCHEDULER',
 69 |                     help='LR scheduler (default: "step")')
 70 | parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
 71 |                     help='learning rate (default: 0.01)')
 72 | parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
 73 |                     help='learning rate noise on/off epoch percentages')
 74 | parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
 75 |                     help='learning rate noise limit percent (default: 0.67)')
 76 | parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
 77 |                     help='learning rate noise std-dev (default: 1.0)')
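# Example (illustrative values, not from the original file): these schedule
# flags are consumed by timm's create_scheduler in train.py. A cosine run
# warms up linearly from --warmup-lr to --lr over --warmup-epochs, then
# decays towards --min-lr over the remaining epochs, e.g.:
#
#   python train.py --sched cosine --lr 2e-3 --warmup-lr 1e-6 --warmup-epochs 5 --min-lr 1e-5 --epochs 300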
 78 | parser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',
 79 |                     help='learning rate cycle len multiplier (default: 1.0)')
 80 | parser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',
 81 |                     help='learning rate cycle limit')
 82 | parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',
 83 |                     help='warmup learning rate (default: 0.0001)')
 84 | parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
 85 |                     help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
 86 | parser.add_argument('--epochs', type=int, default=200, metavar='N',
 87 |                     help='number of epochs to train (default: 200)')
 88 | parser.add_argument('--start-epoch', default=None, type=int, metavar='N',
 89 |                     help='manual epoch number (useful on restarts)')
 90 | parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
 91 |                     help='epoch interval to decay LR')
 92 | parser.add_argument('--warmup-epochs', type=int, default=3, metavar='N',
 93 |                     help='epochs to warmup LR, if scheduler supports')
 94 | parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
 95 |                     help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
 96 | parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
 97 |                     help='patience epochs for Plateau LR scheduler (default: 10)')
 98 | parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
 99 |                     help='LR decay rate (default: 0.1)')
 100 | 
 101 | # Augmentation & regularization parameters
 102 | parser.add_argument('--no-aug', action='store_true', default=False,
 103 |                     help='Disable all training augmentation, override other train aug args')
 104 | parser.add_argument('--repeated-aug', action='store_true')
 105 | parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
 106 |                     help='Random resize scale (default: 0.08 1.0)')
 107 | parser.add_argument('--ratio', type=float, nargs='+', default=[3./4., 4./3.], metavar='RATIO',
 108 |                     help='Random resize aspect ratio (default: 0.75 1.33)')
 109 | parser.add_argument('--hflip', type=float, default=0.5,
 110 |                     help='Horizontal flip training aug probability')
 111 | parser.add_argument('--vflip', type=float, default=0.,
 112 |                     help='Vertical flip training aug probability')
 113 | parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
 114 |                     help='Color jitter factor (default: 0.4)')
 115 | parser.add_argument('--aa', type=str, default=None, metavar='NAME',
 116 |                     help='Use AutoAugment policy. "v0" or "original". (default: None)')
 117 | parser.add_argument('--aug-splits', type=int, default=0,
 118 |                     help='Number of augmentation splits (default: 0, valid: 0 or >=2)')
 119 | parser.add_argument('--jsd', action='store_true', default=False,
 120 |                     help='Enable Jensen-Shannon Divergence + CE loss.
Use with `--aug-splits`.') 121 | parser.add_argument('--reprob', type=float, default=0., metavar='PCT', 122 | help='Random erase prob (default: 0.)') 123 | parser.add_argument('--remode', type=str, default='const', 124 | help='Random erase mode (default: "const")') 125 | parser.add_argument('--recount', type=int, default=1, 126 | help='Random erase count (default: 1)') 127 | parser.add_argument('--resplit', action='store_true', default=False, 128 | help='Do not random erase first (clean) augmentation split') 129 | parser.add_argument('--mixup', type=float, default=0.0, 130 | help='mixup alpha, mixup enabled if > 0. (default: 0.)') 131 | parser.add_argument('--cutmix', type=float, default=0.0, 132 | help='cutmix alpha, cutmix enabled if > 0. (default: 0.)') 133 | parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None, 134 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') 135 | parser.add_argument('--mixup-prob', type=float, default=1.0, 136 | help='Probability of performing mixup or cutmix when either/both is enabled') 137 | parser.add_argument('--mixup-switch-prob', type=float, default=0.5, 138 | help='Probability of switching to cutmix when both mixup and cutmix enabled') 139 | parser.add_argument('--mixup-mode', type=str, default='batch', 140 | help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"') 141 | parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N', 142 | help='Turn off mixup after this epoch, disabled if 0 (default: 0)') 143 | parser.add_argument('--smoothing', type=float, default=0.1, 144 | help='Label smoothing (default: 0.1)') 145 | parser.add_argument('--train-interpolation', type=str, default='random', 146 | help='Training interpolation (random, bilinear, bicubic default: "random")') 147 | parser.add_argument('--drop', type=float, default=0.0, metavar='PCT', 148 | help='Dropout rate (default: 0.)') 149 | parser.add_argument('--drop-connect', type=float, default=None, metavar='PCT', 150 | help='Drop connect rate, DEPRECATED, use drop-path (default: None)') 151 | parser.add_argument('--drop-path', type=float, default=None, metavar='PCT', 152 | help='Drop path rate (default: None)') 153 | parser.add_argument('--drop-block', type=float, default=None, metavar='PCT', 154 | help='Drop block rate (default: None)') 155 | 156 | # Batch norm parameters (only works with gen_efficientnet based models currently) 157 | parser.add_argument('--bn-tf', action='store_true', default=False, 158 | help='Use Tensorflow BatchNorm defaults for models that support it (default: False)') 159 | parser.add_argument('--bn-momentum', type=float, default=None, 160 | help='BatchNorm momentum override (if not None)') 161 | parser.add_argument('--bn-eps', type=float, default=None, 162 | help='BatchNorm epsilon override (if not None)') 163 | parser.add_argument('--sync-bn', action='store_true', 164 | help='Enable NVIDIA Apex or Torch synchronized BatchNorm.') 165 | parser.add_argument('--dist-bn', type=str, default='', 166 | help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")') 167 | parser.add_argument('--split-bn', action='store_true', 168 | help='Enable separate BN layers per augmentation split.') 169 | 170 | # Model Exponential Moving Average 171 | parser.add_argument('--model-ema', action='store_true', default=False, 172 | help='Enable tracking moving average of model weights') 173 | parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, 
174 |                     help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
 175 | parser.add_argument('--model-ema-decay', type=float, default=0.9998,
 176 |                     help='decay factor for model weights moving average (default: 0.9998)')
 177 | 
 178 | # Misc
 179 | parser.add_argument('--seed', type=int, default=42, metavar='S',
 180 |                     help='random seed (default: 42)')
 181 | parser.add_argument('--log-interval', type=int, default=50, metavar='N',
 182 |                     help='how many batches to wait before logging training status')
 183 | parser.add_argument('--recovery-interval', type=int, default=0, metavar='N',
 184 |                     help='how many batches to wait before writing recovery checkpoint')
 185 | parser.add_argument('-j', '--workers', type=int, default=4, metavar='N',
 186 |                     help='how many training processes to use (default: 4)')
 187 | parser.add_argument('--num-gpu', type=int, default=1,
 188 |                     help='Number of GPUs to use')
 189 | parser.add_argument('--save-images', action='store_true', default=False,
 190 |                     help='save images of input batches every log interval for debugging')
 191 | parser.add_argument('--amp', action='store_true', default=False,
 192 |                     help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
 193 | parser.add_argument('--apex-amp', action='store_true', default=False,
 194 |                     help='Use NVIDIA Apex AMP mixed precision')
 195 | parser.add_argument('--native-amp', action='store_true', default=False,
 196 |                     help='Use Native Torch AMP mixed precision')
 197 | parser.add_argument('--channels-last', action='store_true', default=False,
 198 |                     help='Use channels_last memory layout')
 199 | parser.add_argument('--pin-mem', action='store_true', default=False,
 200 |                     help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
 201 | parser.add_argument('--no-prefetcher', action='store_true', default=False,
 202 |                     help='disable fast prefetcher')
 203 | parser.add_argument('--output', default='', type=str, metavar='PATH',
 204 |                     help='path to output folder (default: none, current dir)')
 205 | parser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',
 206 |                     help='Best metric (default: "top1")')
 207 | parser.add_argument('--tta', type=int, default=0, metavar='N',
 208 |                     help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
 209 | parser.add_argument("--local-rank", "--local_rank", default=0, type=int)
 210 | parser.add_argument('--use-multi-epochs-loader', action='store_true', default=False,
 211 |                     help='use the multi-epochs-loader to save time at the beginning of every epoch')
 212 | 
 213 | parser.add_argument("--init-method", default='env://', type=str)
 214 | parser.add_argument("--train-url", type=str)
 215 | # newly added
 216 | parser.add_argument('--attn-ratio', type=float, default=1.,
 217 |                     help='attention ratio')
 218 | parser.add_argument("--pretrain-path", default=None, type=str)
 219 | parser.add_argument("--evaluate", action='store_true', default=False,
 220 |                     help='whether to evaluate the model')
 221 | 
 222 | 
 223 | 
 224 | # ours
 225 | parser.add_argument('--knn', default=9, type=int)
 226 | parser.add_argument('--use-reduce-ratios', type=int2bool, default=0)  # 1 == True
 227 | parser.add_argument('--img-size', default=224, type=int)
 228 | 
 229 | parser.add_argument('--use-shift', type=int2bool, default=0)  # 1 == True
 230 | parser.add_argument('--adapt-knn', type=int2bool, default=0)  # 1 == True
 231 | 
 232 | 
 233 | 
 234 | def parse_args():
 235 |     # Do we have a config file to parse?
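    # Example (hypothetical config file): values from the YAML become new
    # parser defaults via parser.set_defaults(**cfg) below, so flags given
    # explicitly on the command line still win:
    #
    #   # wignn_s.yaml
    #   model: wignn_s
    #   batch_size: 128
    #   sched: cosine
    #   lr: 0.001
    #
    #   $ python train.py -c wignn_s.yaml --lr 0.01   # --lr overrides the YAML value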
236 | args_config, remaining = config_parser.parse_known_args() 237 | if args_config.config: 238 | with open(args_config.config, 'r') as f: 239 | cfg = yaml.safe_load(f) 240 | parser.set_defaults(**cfg) 241 | 242 | # The main arg parser parses the rest of the args, the usual 243 | # defaults will have been overridden if config file specified. 244 | args = parser.parse_args(remaining) 245 | 246 | # Cache the args as a text string to save them in the output dir later 247 | args_text = yaml.safe_dump(args.__dict__, default_flow_style=False) 248 | return args, args_text -------------------------------------------------------------------------------- /src/gcn_lib/torch_vertex.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import torch 4 | from torch import nn 5 | from .torch_nn import BasicConv, batched_index_select, act_layer 6 | from .torch_edge import DenseDilatedKnnGraph 7 | from .pos_embed import get_2d_relative_pos_embed 8 | import torch.nn.functional as F 9 | from timm.models.layers import DropPath 10 | import sys 11 | from .torch_local import window_partition, window_reverse, PatchEmbed, window_partition_channel_last 12 | import time 13 | 14 | from einops import rearrange, repeat 15 | 16 | 17 | class MRConv2d(nn.Module): 18 | """ 19 | Max-Relative Graph Convolution (Paper: https://arxiv.org/abs/1904.03751) for dense data type 20 | """ 21 | def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True): 22 | super(MRConv2d, self).__init__() 23 | self.nn = BasicConv([in_channels*2, out_channels], act, norm, bias) 24 | 25 | def forward(self, x, edge_index, y=None): 26 | x_i = batched_index_select(x, edge_index[1]) 27 | if y is not None: 28 | x_j = batched_index_select(y, edge_index[0]) 29 | else: 30 | x_j = batched_index_select(x, edge_index[0]) 31 | x_j, _ = torch.max(x_j - x_i, -1, keepdim=True) 32 | b, c, n, _ = x.shape 33 | x = torch.cat([x.unsqueeze(2), x_j.unsqueeze(2)], dim=2).reshape(b, 2 * c, n, _) 34 | return self.nn(x) 35 | 36 | 37 | class EdgeConv2d(nn.Module): 38 | """ 39 | Edge convolution layer (with activation, batch normalization) for dense data type 40 | """ 41 | def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True): 42 | super(EdgeConv2d, self).__init__() 43 | self.nn = BasicConv([in_channels * 2, out_channels], act, norm, bias) 44 | 45 | def forward(self, x, edge_index, y=None): 46 | x_i = batched_index_select(x, edge_index[1]) 47 | if y is not None: 48 | x_j = batched_index_select(y, edge_index[0]) 49 | else: 50 | x_j = batched_index_select(x, edge_index[0]) 51 | max_value, _ = torch.max(self.nn(torch.cat([x_i, x_j - x_i], dim=1)), -1, keepdim=True) 52 | return max_value 53 | 54 | 55 | class GraphSAGE(nn.Module): 56 | """ 57 | GraphSAGE Graph Convolution (Paper: https://arxiv.org/abs/1706.02216) for dense data type 58 | """ 59 | def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True): 60 | super(GraphSAGE, self).__init__() 61 | self.nn1 = BasicConv([in_channels, in_channels], act, norm, bias) 62 | self.nn2 = BasicConv([in_channels*2, out_channels], act, norm, bias) 63 | 64 | def forward(self, x, edge_index, y=None): 65 | if y is not None: 66 | x_j = batched_index_select(y, edge_index[0]) 67 | else: 68 | x_j = batched_index_select(x, edge_index[0]) 69 | x_j, _ = torch.max(self.nn1(x_j), -1, keepdim=True) 70 | return self.nn2(torch.cat([x, x_j], dim=1)) 71 | 72 | 73 | class GINConv2d(nn.Module): 74 | """ 75 | GIN Graph Convolution 
(Paper: https://arxiv.org/abs/1810.00826) for dense data type 76 | """ 77 | def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True): 78 | super(GINConv2d, self).__init__() 79 | self.nn = BasicConv([in_channels, out_channels], act, norm, bias) 80 | eps_init = 0.0 81 | self.eps = nn.Parameter(torch.Tensor([eps_init])) 82 | 83 | def forward(self, x, edge_index, y=None): 84 | if y is not None: 85 | x_j = batched_index_select(y, edge_index[0]) 86 | else: 87 | x_j = batched_index_select(x, edge_index[0]) 88 | x_j = torch.sum(x_j, -1, keepdim=True) 89 | return self.nn((1 + self.eps) * x + x_j) 90 | 91 | 92 | class GraphConv2d(nn.Module): 93 | """ 94 | Static graph convolution layer 95 | """ 96 | def __init__(self, in_channels, out_channels, conv='edge', act='relu', norm=None, bias=True): 97 | super(GraphConv2d, self).__init__() 98 | if conv == 'edge': 99 | self.gconv = EdgeConv2d(in_channels, out_channels, act, norm, bias) 100 | elif conv == 'mr': 101 | self.gconv = MRConv2d(in_channels, out_channels, act, norm, bias) 102 | 103 | elif conv == 'sage': 104 | self.gconv = GraphSAGE(in_channels, out_channels, act, norm, bias) 105 | elif conv == 'gin': 106 | self.gconv = GINConv2d(in_channels, out_channels, act, norm, bias) 107 | else: 108 | raise NotImplementedError('conv:{} is not supported'.format(conv)) 109 | 110 | def forward(self, x, edge_index, y=None): 111 | return self.gconv(x, edge_index, y) 112 | 113 | 114 | class DyGraphConv2d(GraphConv2d): 115 | """ 116 | Dynamic graph convolution layer 117 | """ 118 | def __init__(self, in_channels, out_channels, kernel_size=9, dilation=1, conv='edge', act='relu', 119 | norm=None, bias=True, stochastic=False, epsilon=0.0, r=1): 120 | super(DyGraphConv2d, self).__init__(in_channels, out_channels, conv, act, norm, bias) 121 | self.k = kernel_size 122 | self.d = dilation 123 | self.r = r 124 | self.dilated_knn_graph = DenseDilatedKnnGraph(kernel_size, dilation, stochastic, epsilon) 125 | self.debug = True 126 | 127 | def _mask_edge_index(self, edge_index, adj_mask): 128 | # edge_index: [2, B, N, k] 129 | # adj_mask: [B, N, k] (=1 -> keep / =0 -> masked) 130 | 131 | if adj_mask is None: 132 | return edge_index 133 | 134 | 135 | # edge_index_j: [B, N, k] 136 | edge_index_j = edge_index[0] 137 | edge_index_i = edge_index[1] 138 | adj_mask_inv = torch.ones_like(adj_mask) - adj_mask 139 | # adj_mask: [[1,1,0,0], 140 | # [1,1,1,0]] 141 | # 142 | # edge_index_j: [[14,43,20,21], 143 | # [18,12,32,24]] 144 | # 145 | # edge_index_i: [[0,0,0,0], 146 | # [1,1,1,1]] 147 | # 148 | # output: [[14,43,0,0], 149 | # [18,12,32,1]] 150 | edge_index_j = (edge_index_j * adj_mask) + (edge_index_i * adj_mask_inv) 151 | 152 | return torch.stack((edge_index_j, edge_index_i), dim=0).long() 153 | 154 | 155 | 156 | def forward(self, x, relative_pos=None, adj_mask = None): 157 | # print('Doing gnn') 158 | B, C, H, W = x.shape 159 | y = None 160 | if self.r > 1: 161 | y = F.avg_pool2d(x, self.r, self.r) 162 | # print(f'y: {y.shape}') 163 | y = y.reshape(B, C, -1, 1).contiguous() 164 | x = x.reshape(B, C, -1, 1).contiguous() 165 | 166 | edge_index = self.dilated_knn_graph(x, y, relative_pos) 167 | 168 | edge_index = self._mask_edge_index(edge_index, adj_mask) 169 | 170 | x = super(DyGraphConv2d, self).forward(x, edge_index, y) 171 | 172 | 173 | if self.debug: 174 | return x.reshape(x.shape[0], -1, H, W).contiguous(), edge_index 175 | return x.reshape(x.shape[0], -1, H, W).contiguous() 176 | 177 | 178 | 179 | 180 | 181 | 182 | class Grapher(nn.Module): 183 | """ 184 
| Grapher module with graph convolution and fc layers 185 | """ 186 | def __init__(self, in_channels, kernel_size=9, dilation=1, conv='edge', act='relu', norm=None, 187 | bias=True, stochastic=False, epsilon=0.0, r=1, n=196, drop_path=0.0, relative_pos=False): 188 | super(Grapher, self).__init__() 189 | self.channels = in_channels 190 | self.n = n 191 | self.r = r 192 | self.fc1 = nn.Sequential( 193 | nn.Conv2d(in_channels, in_channels, 1, stride=1, padding=0), 194 | nn.BatchNorm2d(in_channels), 195 | ) 196 | self.graph_conv = DyGraphConv2d(in_channels, in_channels * 2, kernel_size, dilation, conv, 197 | act, norm, bias, stochastic, epsilon, r) 198 | self.fc2 = nn.Sequential( 199 | nn.Conv2d(in_channels * 2, in_channels, 1, stride=1, padding=0), 200 | nn.BatchNorm2d(in_channels), 201 | ) 202 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 203 | self.relative_pos = None 204 | if relative_pos: 205 | print('using relative_pos') 206 | relative_pos_tensor = torch.from_numpy(np.float32(get_2d_relative_pos_embed(in_channels, 207 | int(n**0.5)))).unsqueeze(0).unsqueeze(1) 208 | relative_pos_tensor = F.interpolate( 209 | relative_pos_tensor, size=(n, n//(r*r)), mode='bicubic', align_corners=False) 210 | self.relative_pos = nn.Parameter(-relative_pos_tensor.squeeze(1), requires_grad=False) 211 | 212 | def _get_relative_pos(self, relative_pos, H, W): 213 | if relative_pos is None or H * W == self.n: 214 | return relative_pos 215 | else: 216 | N = H * W 217 | N_reduced = N // (self.r * self.r) 218 | return F.interpolate(relative_pos.unsqueeze(0), size=(N, N_reduced), mode="bicubic").squeeze(0) 219 | 220 | def forward(self, x): 221 | _tmp = x 222 | x = self.fc1(x) 223 | B, C, H, W = x.shape 224 | relative_pos = self._get_relative_pos(self.relative_pos, H, W) 225 | 226 | x, edge_index = self.graph_conv(x, relative_pos) 227 | x = self.fc2(x) 228 | x = self.drop_path(x) + _tmp 229 | return x 230 | 231 | 232 | class WindowGrapher(nn.Module): 233 | """ 234 | Local Grapher module with graph convolution and fc layers 235 | """ 236 | def __init__( 237 | self, 238 | in_channels, 239 | kernel_size=9, 240 | windows_size = 7, 241 | dilation=1, 242 | conv='edge', 243 | act='relu', 244 | norm=None, 245 | bias=True, 246 | stochastic=False, 247 | epsilon=0.0, 248 | drop_path=0.0, 249 | relative_pos=False, 250 | shift_size = 0, 251 | r = 1, 252 | input_resolution = (224//4,224//4), 253 | adapt_knn = False): 254 | super(WindowGrapher, self).__init__() 255 | 256 | if min(input_resolution) <= windows_size: 257 | # if window size is larger than input resolution, we don't partition windows 258 | shift_size = 0 259 | windows_size = min(input_resolution) 260 | assert 0 <= shift_size < windows_size, "shift_size must in 0-window_size" 261 | 262 | 263 | max_connection_allowed = (windows_size // r)**2 264 | if shift_size > 0: 265 | assert shift_size % r == 0 266 | max_connection_allowed = (shift_size // r)**2 267 | 268 | assert kernel_size <= max_connection_allowed, f'trying k = {kernel_size} while the max can be: {max_connection_allowed}' 269 | 270 | 271 | self.windows_size = windows_size 272 | self.shift_size = shift_size 273 | self.r = r 274 | 275 | 276 | n_nodes = self.windows_size * self.windows_size 277 | 278 | self.fc1 = nn.Sequential( 279 | nn.Conv2d(in_channels, in_channels, 1, stride=1, padding=0), 280 | nn.BatchNorm2d(in_channels), 281 | ) 282 | 283 | self.graph_conv = DyGraphConv2d(in_channels, (in_channels * 2), kernel_size, dilation, conv, 284 | act, norm, bias, stochastic, epsilon, r = 
r) 285 | 286 | self.fc2 = nn.Sequential( 287 | nn.Conv2d(in_channels * 2, in_channels, 1, stride=1, padding=0), 288 | nn.BatchNorm2d(in_channels), 289 | ) 290 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 291 | 292 | 293 | 294 | 295 | 296 | self.relative_pos = None 297 | if relative_pos: 298 | print('using relative_pos') 299 | relative_pos_tensor = torch.from_numpy(np.float32(get_2d_relative_pos_embed(in_channels, 300 | int(n_nodes**0.5)))).unsqueeze(0).unsqueeze(1) 301 | relative_pos_tensor = F.interpolate( 302 | relative_pos_tensor, size=(n_nodes, n_nodes//(r*r)), mode='bicubic', align_corners=False) 303 | self.relative_pos = nn.Parameter(-relative_pos_tensor.squeeze(1), requires_grad=False) 304 | 305 | 306 | attn_mask = None 307 | adj_mask = None 308 | if self.shift_size > 0: 309 | print(f'Shifting windows!') 310 | H, W = input_resolution 311 | # calculate attention mask for SW-MSA 312 | img_mask = torch.zeros((1, 1, H, W)) 313 | h_slices = (slice(0, -self.windows_size), 314 | slice(-self.windows_size, -self.shift_size), 315 | slice(-self.shift_size, None)) 316 | w_slices = (slice(0, -self.windows_size), 317 | slice(-self.windows_size, -self.shift_size), 318 | slice(-self.shift_size, None)) 319 | 320 | 321 | cnt = 0 322 | for h in h_slices: 323 | for w in w_slices: 324 | img_mask[:, :, h, w] = cnt 325 | cnt += 1 326 | 327 | # print(img_mask) 328 | # print('\n\n') 329 | 330 | mask_windows_unf = window_partition(img_mask, self.windows_size) # nW, 1, windows_size, windows_size, 331 | 332 | 333 | mask_windows = mask_windows_unf.view(-1, self.windows_size * self.windows_size) 334 | 335 | if self.r > 1: 336 | mask_windows_y = F.max_pool2d(mask_windows_unf, self.r, self.r) 337 | mask_windows_y = mask_windows_y.view(-1, (self.windows_size // self.r) * (self.windows_size // self.r)) 338 | else: 339 | mask_windows_y = mask_windows 340 | 341 | attn_mask = mask_windows_y.unsqueeze(1) - mask_windows.unsqueeze(2) # nW x N x (N // r) 342 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(1000000.0)).masked_fill(attn_mask == 0, float(0.0)) 343 | 344 | # Get n_connections_allowed for each node in each windows 345 | if adapt_knn: 346 | print('Adapting knn!') 347 | adj_mask = torch.empty((attn_mask.shape[0], attn_mask.shape[1], kernel_size)) # nW x N x k 348 | for w in range(attn_mask.shape[0]): 349 | for i in range(attn_mask.shape[1]): 350 | all_connection = torch.sum(attn_mask[w,i] == 0) 351 | scaled_knn = (kernel_size * all_connection) // (self.windows_size * (self.windows_size // r)) 352 | n_connections_allowed = int(max(scaled_knn, 3.0)) 353 | # print(f'Window: {w} node {i} - allowed_connection = {all_connection} (k = {n_connections_allowed})') 354 | masked = torch.zeros(kernel_size - n_connections_allowed) 355 | un_masked = torch.ones(n_connections_allowed) 356 | adj_mask[w,i] = torch.cat([un_masked,masked],dim=0) 357 | 358 | self.register_buffer("attn_mask", attn_mask) 359 | self.register_buffer("adj_mask", adj_mask) 360 | 361 | 362 | def _merge_pos_attn(self, batch_size): 363 | if self.attn_mask is None: 364 | return self.relative_pos 365 | 366 | if self.relative_pos is None: 367 | print('Should not be here..') 368 | self.relative_pos = torch.zeros((1, self.attn_mask.shape[1], self.attn_mask.shape[2])).to(self.attn_mask.device) 369 | 370 | nW_nGH = self.attn_mask.shape[0] 371 | #print(f'Attention mask: {self.attn_mask.shape}') 372 | return self.relative_pos.repeat(nW_nGH*batch_size,1,1) + self.attn_mask.repeat(batch_size, 1, 1) # B, N, N 373 | 374 | 375 | 376 | 
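    # Worked example of the masking above (illustrative comment, not executed):
    # with input_resolution 8x8, windows_size 4 and shift_size 2, the h/w
    # slices label the feature map with 3x3 = 9 region ids, so after the
    # cyclic shift a single window can mix up to four regions.
    # attn_mask[w, i, j] is 0 only when nodes i and j of window w share a
    # region id, and 1e6 otherwise; merged into relative_pos by
    # _merge_pos_attn, this large penalty is what keeps cross-region pairs
    # out of the k-NN graph built in DyGraphConv2d. With adapt_knn, adj_mask
    # additionally rescales each node's degree to roughly
    #     k_node = max(3, (k * n_valid_neighbours) // (windows_size * (windows_size // r)))
    # so nodes in fragmented shifted windows keep a proportionate number of
    # neighbours; the remaining masked slots collapse into self-loops in
    # _mask_edge_index.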
377 | 378 | def forward(self, x): 379 | 380 | 381 | _tmp = x 382 | x = self.fc1(x) 383 | B, C, H, W = x.shape 384 | 385 | # cyclic shift 386 | if self.shift_size > 0: 387 | x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(2, 3)) 388 | 389 | 390 | 391 | x = window_partition(x, window_size = self.windows_size) 392 | 393 | # merge relative_pos w/ attn_mask TODO 394 | pos_att = self._merge_pos_attn(batch_size=B) 395 | # pos_att: b*nW, N, N 396 | 397 | adj_mask = None 398 | if self.adj_mask is not None: 399 | adj_mask = self.adj_mask.repeat(B,1,1) 400 | x, edge_index = self.graph_conv(x, pos_att, adj_mask) 401 | 402 | 403 | x = window_reverse(x, self.windows_size, H=H, W=W) 404 | 405 | # reverse cyclic shift 406 | if self.shift_size > 0: 407 | x = torch.roll(x, shifts=(self.shift_size, self.shift_size), dims=(2, 3)) 408 | 409 | x = self.fc2(x) 410 | 411 | x = self.drop_path(x) + _tmp 412 | 413 | return x 414 | 415 | -------------------------------------------------------------------------------- /src/train_trasnfer_learning.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import datetime 4 | import os 5 | import time 6 | 7 | import torch 8 | import torch.utils.data 9 | import torchvision 10 | import torchvision.models.detection 11 | import torchvision.models.detection.mask_rcnn 12 | import utils 13 | 14 | 15 | import torch.nn as nn 16 | import sys 17 | import numpy as np 18 | from opt_transfer import get_args_parser 19 | import errno 20 | 21 | import wandb 22 | import random 23 | 24 | import torch.backends.cudnn as cudnn 25 | import warnings 26 | 27 | from model.transfer_models import get_model 28 | from dataloaders.celeba_hq import get_celeba 29 | 30 | from sklearn.metrics import confusion_matrix 31 | import seaborn as sn 32 | import pandas as pd 33 | import matplotlib.pyplot as plt 34 | import utils 35 | 36 | def set_seed(seed): 37 | random.seed(seed) 38 | torch.manual_seed(seed) 39 | cudnn.deterministic = True 40 | cudnn.benchmark = False 41 | np.random.seed(seed) 42 | torch.cuda.manual_seed(seed) 43 | 44 | os.environ["PYTHONHASHSEED"] = str(seed) 45 | torch.cuda.manual_seed_all(seed) 46 | 47 | 48 | def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None): 49 | model.train() 50 | 51 | 52 | metric_logger = utils.MetricLogger(delimiter=" ") 53 | metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}")) 54 | metric_logger.add_meter("img/s", utils.SmoothedValue(window_size=10, fmt="{value}")) 55 | 56 | header = f"Epoch: [{epoch}]" 57 | for i, (image, target) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)): 58 | 59 | if image.shape[0] <= 1: 60 | continue 61 | 62 | start_time = time.time() 63 | image, target = image.to(device), target.to(device) 64 | 65 | model.zero_grad() 66 | with torch.cuda.amp.autocast(enabled=scaler is not None): 67 | output = model(image) 68 | loss = criterion(output, target) 69 | 70 | optimizer.zero_grad() 71 | if scaler is not None: 72 | scaler.scale(loss).backward() 73 | if args.clip_grad_norm is not None: 74 | # we should unscale the gradients of optimizer's assigned params if do gradient clipping 75 | scaler.unscale_(optimizer) 76 | nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm) 77 | scaler.step(optimizer) 78 | scaler.update() 79 | else: 80 | loss.backward() 81 | if args.clip_grad_norm is not None: 82 | nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm) 83 | 
optimizer.step() 84 | 85 | if model_ema and i % args.model_ema_steps == 0: 86 | model_ema.update_parameters(model) 87 | if epoch < args.lr_warmup_epochs: 88 | # Reset ema buffer to keep copying weights during warmup period 89 | model_ema.n_averaged.fill_(0) 90 | 91 | acc1, acc5 = utils.accuracy(output, target, topk=(1, 5)) 92 | batch_size = image.shape[0] 93 | metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"]) 94 | metric_logger.meters["acc1"].update(acc1.item(), n=batch_size) 95 | metric_logger.meters["acc5"].update(acc5.item(), n=batch_size) 96 | metric_logger.meters["img/s"].update(batch_size / (time.time() - start_time)) 97 | 98 | 99 | 100 | return metric_logger 101 | 102 | 103 | def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=""): 104 | model.eval() 105 | n_threads = torch.get_num_threads() 106 | # FIXME remove this and make paste_masks_in_image run on the GPU 107 | torch.set_num_threads(1) 108 | metric_logger = utils.MetricLogger(delimiter=" ") 109 | header = f"Test: {log_suffix}" 110 | 111 | num_processed_samples = 0 112 | 113 | with torch.inference_mode(): 114 | for i,(image, target) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 115 | image = image.to(device) 116 | target = target.to(device) 117 | 118 | if torch.cuda.is_available(): 119 | torch.cuda.synchronize() 120 | 121 | output = model(image) 122 | loss = criterion(output, target) 123 | 124 | acc1, acc5 = utils.accuracy(output, target, topk=(1, 5)) 125 | # FIXME need to take into account that the datasets 126 | # could have been padded in distributed setup 127 | batch_size = image.shape[0] 128 | metric_logger.update(loss=loss.item()) 129 | metric_logger.meters["acc1"].update(acc1.item(), n=batch_size) 130 | metric_logger.meters["acc5"].update(acc5.item(), n=batch_size) 131 | num_processed_samples += batch_size 132 | 133 | 134 | 135 | # gather the stats from all processes 136 | 137 | num_processed_samples = utils.reduce_across_processes(num_processed_samples) 138 | 139 | metric_logger.synchronize_between_processes() 140 | 141 | print(f"{header} Acc@1 {metric_logger.acc1.global_avg:.3f} Acc@5 {metric_logger.acc5.global_avg:.3f}") 142 | 143 | torch.set_num_threads(n_threads) 144 | 145 | 146 | return metric_logger.acc1.global_avg, metric_logger.acc5.global_avg, metric_logger.loss.global_avg 147 | 148 | 149 | def main(args): 150 | 151 | 152 | if args.seed is not None: 153 | set_seed(args.seed) 154 | 155 | model, params, trainable_parameters, n_classes = get_model( 156 | model_type = args.model_type, 157 | use_shift = args.use_shift, 158 | adapt_knn= args.adapt_knn, 159 | checkpoint = args.checkpoint, 160 | freezed = args.freezed, 161 | dataset = args.dataset, 162 | crop_size=args.crop_size) 163 | 164 | 165 | if args.num_gpu > 1: 166 | print('Using DataParallel') 167 | # model = torch.nn.DataParallel(model) 168 | model = nn.DataParallel(model) 169 | 170 | print("Let's use", torch.cuda.device_count(), "GPUs!") 171 | 172 | print('\n\n--------------------------------------') 173 | 174 | print(f"Parameters: {params}") 175 | print(f"Trainable Parameters: {trainable_parameters}") 176 | 177 | print('--------------------------------------\n\n') 178 | 179 | if not args.test_only: 180 | wandb.log({ 181 | f'params/parameters':params, 182 | f'params/trainable_params':trainable_parameters 183 | }) 184 | 185 | 186 | if args.save_dir and not args.test_only: 187 | print(f'# Results will be saved in {args.save_dir}') 188 | try: 189 | os.makedirs(args.save_dir) 190 | 
        except OSError as e:
 191 |             if e.errno != errno.EEXIST:
 192 |                 raise
 193 | 
 194 | 
 195 |     print(args)
 196 | 
 197 |     assert torch.cuda.is_available(), '# --- Cuda Not Available!!'
 198 |     device = 'cuda'
 199 |     model.to(device)
 200 | 
 201 | 
 202 |     if args.use_deterministic_algorithms:
 203 |         torch.use_deterministic_algorithms(True)
 204 | 
 205 |     # Data loading code
 206 |     print("Loading data")
 207 | 
 208 |     if args.crop_size is None:
 209 |         crop_size = 224
 210 |         if '256' in args.model_type:
 211 |             crop_size = 256
 212 | 
 213 |         if hasattr(model, 'default_cfg'):
 214 |             default_cfg = model.default_cfg
 215 |             input_sizes = list(default_cfg['input_size'])
 216 |             assert input_sizes[1] == input_sizes[2] and len(input_sizes) == 3, f'input sizes for model {args.model_type} is {input_sizes}'
 217 |             crop_size = int(input_sizes[1])
 218 |     else:
 219 |         crop_size = args.crop_size
 220 | 
 221 | 
 222 | 
 223 |     print(f'\n\n!!!Dataloaders with crop size: {crop_size}\n\n')
 224 | 
 225 |     args.distributed = False
 226 |     if args.dataset == 'CelebA':
 227 |         drop_last = False
 228 |         if 'pvig' in args.model_type:
 229 |             drop_last = True
 230 | 
 231 |         train_loader, valid_loader, test_loader, train_sampler, num_classes, _ = get_celeba(args, get_train_sampler=True, crop_size=crop_size, drop_last=drop_last)
 232 |     else:
 233 |         raise NotImplementedError(f'Dataset {args.dataset} not yet implemented')
 234 | 
 235 |     if valid_loader is not None:
 236 |         valid_loader = None if len(valid_loader) == 0 else valid_loader
 237 | 
 238 | 
 239 |     print(f'Train Dataloader: {len(train_loader)}')
 240 |     if valid_loader is not None:
 241 |         print(f'Valid Dataloader: {len(valid_loader)}')
 242 |     print(f'Test Dataloader: {len(test_loader)}')
 243 |     print(f'Found classes: {num_classes}')
 244 | 
 245 |     assert num_classes == n_classes, f'Found {num_classes} classes in the dataset, but you specified {n_classes}'
 246 | 
 247 | 
 248 |     if args.loss == 'cross_entropy':
 249 |         criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
 250 |     elif args.loss == 'nll':
 251 |         criterion = nn.NLLLoss()
 252 | 
 253 | 
 254 |     custom_keys_weight_decay = []
 255 |     if args.bias_weight_decay is not None:
 256 |         custom_keys_weight_decay.append(("bias", args.bias_weight_decay))
 257 |     if args.transformer_embedding_decay is not None:
 258 |         for key in ["class_token", "position_embedding", "relative_position_bias_table"]:
 259 |             custom_keys_weight_decay.append((key, args.transformer_embedding_decay))
 260 | 
 261 |     parameters = utils.set_weight_decay(
 262 |         model,
 263 |         args.weight_decay,
 264 |         norm_weight_decay=args.norm_weight_decay,
 265 |         custom_keys_weight_decay=custom_keys_weight_decay if len(custom_keys_weight_decay) > 0 else None,
 266 |     )
 267 | 
 268 |     opt_name = args.opt.lower()
 269 |     if opt_name.startswith("sgd"):
 270 |         optimizer = torch.optim.SGD(
 271 |             parameters,
 272 |             lr=args.lr,
 273 |             momentum=args.momentum,
 274 |             weight_decay=args.weight_decay,
 275 |             nesterov="nesterov" in opt_name,
 276 |         )
 277 |     elif opt_name == "rmsprop":
 278 |         optimizer = torch.optim.RMSprop(
 279 |             parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, eps=0.0316, alpha=0.9
 280 |         )
 281 |     elif opt_name == "adamw":
 282 |         optimizer = torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay)
 283 |     elif opt_name == "adam":
 284 |         optimizer = torch.optim.Adam(parameters, lr=args.lr)
 285 |     else:
 286 |         raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD, RMSprop, Adam and AdamW are supported.")
 287 | 
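    # Note (explanatory comment, added): each param group returned by
    # utils.set_weight_decay carries its own 'weight_decay' entry, and
    # per-group options override the value passed to the optimizer
    # constructor, so norm layers and custom keys keep their dedicated decay.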
 288 |     scaler = torch.cuda.amp.GradScaler() if args.amp else None
 289 | 
 290 |     args.lr_scheduler = args.lr_scheduler.lower()
 291 |     if args.lr_scheduler == "steplr":
 292 |         main_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
 293 |     elif args.lr_scheduler == "cosineannealinglr":
 294 |         main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
 295 |             optimizer, T_max=args.epochs - args.lr_warmup_epochs, eta_min=args.lr_min
 296 |         )
 297 |     elif args.lr_scheduler == "exponentiallr":
 298 |         main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_gamma)
 299 |     elif args.lr_scheduler == "constant":
 300 |         main_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0, total_iters=args.epochs - args.lr_warmup_epochs)
 301 |     else:
 302 |         raise RuntimeError(
 303 |             f"Invalid lr scheduler '{args.lr_scheduler}'. Only StepLR, CosineAnnealingLR, ExponentialLR "
 304 |             "and ConstantLR are supported."
 305 |         )
 306 | 
 307 |     if args.lr_warmup_epochs > 0:
 308 |         if args.lr_warmup_method == "linear":
 309 |             warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
 310 |                 optimizer, start_factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
 311 |             )
 312 |         elif args.lr_warmup_method == "constant":
 313 |             warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
 314 |                 optimizer, factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
 315 |             )
 316 |         else:
 317 |             raise RuntimeError(
 318 |                 f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported."
 319 |             )
 320 |         lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
 321 |             optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[args.lr_warmup_epochs]
 322 |         )
 323 |     else:
 324 |         lr_scheduler = main_lr_scheduler
 325 | 
 326 | 
 327 | 
 328 |     if args.resume:
 329 |         checkpoint = torch.load(args.resume, map_location="cpu")
 330 |         model.load_state_dict(checkpoint["model"])
 331 |         if not args.test_only:
 332 |             optimizer.load_state_dict(checkpoint["optimizer"])
 333 |             lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
 334 |             args.start_epoch = checkpoint["epoch"] + 1
 335 |             if scaler and "scaler" in checkpoint:
 336 |                 scaler.load_state_dict(checkpoint["scaler"])
 337 | 
 338 |     if args.test_only:
 339 |         # torch.backends.cudnn.deterministic = True
 340 |         test_acc, test_acc_5, test_loss = evaluate(model, criterion, test_loader, device=device)
 341 |         print(f'Test acc: {test_acc}')
 342 | 
 343 |         return test_acc, test_acc_5, test_loss
 344 | 
 345 | 
 346 | 
 347 |     print("Start training")
 348 |     start_time = time.time()
 349 |     for epoch in range(args.start_epoch, args.epochs):
 350 |         print(f'\n\n--------\nStarting epoch: {epoch}')
 351 |         metric_logger = train_one_epoch(model, criterion, optimizer, train_loader, device, epoch, args, None, scaler)
 352 |         print(f'train acc: {metric_logger.acc1.global_avg}')
 353 | 
 354 |         wandb.log({
 355 |             f'train/train_loss':metric_logger.loss.global_avg,
 356 |             f'lr':metric_logger.lr.global_avg,
 357 |             f'train/train_acc@1':metric_logger.acc1.global_avg,
 358 |             f'train/train_acc@5':metric_logger.acc5.global_avg
 359 |         }, step = epoch)
 360 |         lr_scheduler.step()
 361 | 
 362 | 
 363 |         # validate after every epoch
 364 |         val_acc = 0.0
 365 |         if valid_loader is not None:
 366 |             val_acc, val_acc_5, val_loss = evaluate(model, criterion, valid_loader, device=device)
 367 | 
 368 |             print(f'valid acc: {val_acc}')
 369 |             wandb.log({
 370 |                 f'val/val_loss':val_loss,
 371 |                 f'val/val_acc@1':val_acc,
 372 |
f'val/val_acc@5':val_acc_5 373 | }, step = epoch) 374 | 375 | 376 | # evaluate after every epoch 377 | test_acc, test_acc_5, test_loss = evaluate(model, criterion, test_loader, device=device) 378 | 379 | 380 | print(f'test acc: {test_acc}') 381 | wandb.log({ 382 | f'test/test_loss':test_loss, 383 | f'test/test_acc@1':test_acc, 384 | f'test/test_acc@5':test_acc_5 385 | }, step = epoch) 386 | 387 | if args.save_dir: 388 | checkpoint = { 389 | "model": model.state_dict(), 390 | "optimizer": optimizer.state_dict(), 391 | "lr_scheduler": lr_scheduler.state_dict(), 392 | "args": args, 393 | "epoch": epoch, 394 | "train_acc": metric_logger.acc1.global_avg, 395 | "val_acc": val_acc, 396 | "test_acc": test_acc, 397 | "params":params 398 | } 399 | if scaler: 400 | checkpoint["scaler"] = scaler.state_dict() 401 | 402 | utils.save_on_master(checkpoint, os.path.join(args.save_dir, "checkpoint.pth")) 403 | 404 | 405 | total_time = time.time() - start_time 406 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 407 | print(f'\n\n--------------\nEnd training') 408 | print(f"Training time {total_time_str}") 409 | print(f'Train acc@1: {metric_logger.acc1.global_avg}') 410 | print(f'Test acc@1: {test_acc}') 411 | 412 | 413 | wandb.log({ 414 | f'end/train_acc@1': metric_logger.acc1.global_avg, 415 | f'end/val_acc@1': val_acc, 416 | f'end/test_acc@1': test_acc 417 | }) 418 | 419 | 420 | 421 | 422 | 423 | 424 | if __name__ == "__main__": 425 | 426 | # torch.autograd.set_detect_anomaly(True) 427 | 428 | args = get_args_parser().parse_args() 429 | 430 | print('Using GPU : ' + str(torch.cuda.current_device()) + ' from ' + str(torch.cuda.device_count()) + ' devices') 431 | print(args.save_dir) 432 | 433 | if not args.test_only: 434 | wandb.login() 435 | wandb_name = f'{args.dataset}/{args.model_type}/train_seed{args.seed}' 436 | 437 | wandb.init( 438 | # Set the project where this run will be logged 439 | project='WACV_WiGNet_transfer_learning_high_res', 440 | # We pass a run name (otherwise it’ll be randomly assigned, like sunshine-lollypop-10) 441 | name=wandb_name, 442 | config=args) 443 | 444 | args.save_dir += wandb_name 445 | else: 446 | args.save_dir = None 447 | 448 | 449 | main(args=args) 450 | 451 | if not args.test_only: 452 | wandb.run.finish() 453 | 454 | 455 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import datetime 3 | import errno 4 | import hashlib 5 | import os 6 | import time 7 | from collections import defaultdict, deque, OrderedDict 8 | from typing import List, Optional, Tuple 9 | 10 | import torch 11 | import torch.distributed as dist 12 | 13 | 14 | class SmoothedValue: 15 | """Track a series of values and provide access to smoothed values over a 16 | window or the global series average. 17 | """ 18 | 19 | def __init__(self, window_size=20, fmt=None): 20 | if fmt is None: 21 | fmt = "{median:.4f} ({global_avg:.4f})" 22 | self.deque = deque(maxlen=window_size) 23 | self.total = 0.0 24 | self.count = 0 25 | self.fmt = fmt 26 | 27 | def update(self, value, n=1): 28 | self.deque.append(value) 29 | self.count += n 30 | self.total += value * n 31 | 32 | def synchronize_between_processes(self): 33 | """ 34 | Warning: does not synchronize the deque! 
35 | """ 36 | t = reduce_across_processes([self.count, self.total]) 37 | t = t.tolist() 38 | self.count = int(t[0]) 39 | self.total = t[1] 40 | 41 | @property 42 | def median(self): 43 | d = torch.tensor(list(self.deque)) 44 | return d.median().item() 45 | 46 | @property 47 | def avg(self): 48 | d = torch.tensor(list(self.deque), dtype=torch.float32) 49 | return d.mean().item() 50 | 51 | @property 52 | def global_avg(self): 53 | return self.total / self.count 54 | 55 | @property 56 | def max(self): 57 | return max(self.deque) 58 | 59 | @property 60 | def value(self): 61 | return self.deque[-1] 62 | 63 | def __str__(self): 64 | return self.fmt.format( 65 | median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value 66 | ) 67 | 68 | 69 | class MetricLogger: 70 | def __init__(self, delimiter="\t"): 71 | self.meters = defaultdict(SmoothedValue) 72 | self.delimiter = delimiter 73 | 74 | def update(self, **kwargs): 75 | for k, v in kwargs.items(): 76 | if isinstance(v, torch.Tensor): 77 | v = v.item() 78 | assert isinstance(v, (float, int)) 79 | self.meters[k].update(v) 80 | 81 | def __getattr__(self, attr): 82 | if attr in self.meters: 83 | return self.meters[attr] 84 | if attr in self.__dict__: 85 | return self.__dict__[attr] 86 | raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'") 87 | 88 | def __str__(self): 89 | loss_str = [] 90 | for name, meter in self.meters.items(): 91 | loss_str.append(f"{name}: {str(meter)}") 92 | return self.delimiter.join(loss_str) 93 | 94 | def synchronize_between_processes(self): 95 | for meter in self.meters.values(): 96 | meter.synchronize_between_processes() 97 | 98 | def add_meter(self, name, meter): 99 | self.meters[name] = meter 100 | 101 | def log_every(self, iterable, print_freq, header=None): 102 | i = 0 103 | if not header: 104 | header = "" 105 | start_time = time.time() 106 | end = time.time() 107 | iter_time = SmoothedValue(fmt="{avg:.4f}") 108 | data_time = SmoothedValue(fmt="{avg:.4f}") 109 | space_fmt = ":" + str(len(str(len(iterable)))) + "d" 110 | if torch.cuda.is_available(): 111 | log_msg = self.delimiter.join( 112 | [ 113 | header, 114 | "[{0" + space_fmt + "}/{1}]", 115 | "eta: {eta}", 116 | "{meters}", 117 | "time: {time}", 118 | "data: {data}", 119 | "max mem: {memory:.0f}", 120 | ] 121 | ) 122 | else: 123 | log_msg = self.delimiter.join( 124 | [header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", "{meters}", "time: {time}", "data: {data}"] 125 | ) 126 | MB = 1024.0 * 1024.0 127 | for obj in iterable: 128 | data_time.update(time.time() - end) 129 | yield obj 130 | iter_time.update(time.time() - end) 131 | if i % print_freq == 0: 132 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 133 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 134 | if torch.cuda.is_available(): 135 | print( 136 | log_msg.format( 137 | i, 138 | len(iterable), 139 | eta=eta_string, 140 | meters=str(self), 141 | time=str(iter_time), 142 | data=str(data_time), 143 | memory=torch.cuda.max_memory_allocated() / MB, 144 | ) 145 | ) 146 | else: 147 | print( 148 | log_msg.format( 149 | i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time) 150 | ) 151 | ) 152 | i += 1 153 | end = time.time() 154 | total_time = time.time() - start_time 155 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 156 | print(f"{header} Total time: {total_time_str}") 157 | 158 | 159 | class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel): 160 | 
"""Maintains moving averages of model parameters using an exponential decay. 161 | ``ema_avg = decay * avg_model_param + (1 - decay) * model_param`` 162 | `torch.optim.swa_utils.AveragedModel `_ 163 | is used to compute the EMA. 164 | """ 165 | 166 | def __init__(self, model, decay, device="cpu"): 167 | def ema_avg(avg_model_param, model_param, num_averaged): 168 | return decay * avg_model_param + (1 - decay) * model_param 169 | 170 | super().__init__(model, device, ema_avg, use_buffers=True) 171 | 172 | 173 | def accuracy(output, target, topk=(1,)): 174 | """Computes the accuracy over the k top predictions for the specified values of k""" 175 | with torch.inference_mode(): 176 | maxk = max(topk) 177 | batch_size = target.size(0) 178 | if target.ndim == 2: 179 | target = target.max(dim=1)[1] 180 | 181 | _, pred = output.topk(maxk, 1, True, True) 182 | pred = pred.t() 183 | correct = pred.eq(target[None]) 184 | 185 | res = [] 186 | for k in topk: 187 | correct_k = correct[:k].flatten().sum(dtype=torch.float32) 188 | res.append(correct_k * (100.0 / batch_size)) 189 | return res 190 | 191 | 192 | def mkdir(path): 193 | try: 194 | os.makedirs(path) 195 | except OSError as e: 196 | if e.errno != errno.EEXIST: 197 | raise 198 | 199 | 200 | def setup_for_distributed(is_master): 201 | """ 202 | This function disables printing when not in master process 203 | """ 204 | import builtins as __builtin__ 205 | 206 | builtin_print = __builtin__.print 207 | 208 | def print(*args, **kwargs): 209 | force = kwargs.pop("force", False) 210 | if is_master or force: 211 | builtin_print(*args, **kwargs) 212 | 213 | __builtin__.print = print 214 | 215 | 216 | def is_dist_avail_and_initialized(): 217 | if not dist.is_available(): 218 | return False 219 | if not dist.is_initialized(): 220 | return False 221 | return True 222 | 223 | 224 | def get_world_size(): 225 | if not is_dist_avail_and_initialized(): 226 | return 1 227 | return dist.get_world_size() 228 | 229 | 230 | def get_rank(): 231 | if not is_dist_avail_and_initialized(): 232 | return 0 233 | return dist.get_rank() 234 | 235 | 236 | def is_main_process(distributed = False): 237 | return get_rank() == 0 or not distributed 238 | 239 | 240 | def save_on_master(*args, **kwargs): 241 | if is_main_process(): 242 | torch.save(*args, **kwargs) 243 | 244 | 245 | def init_distributed_mode(args): 246 | if "RANK" in os.environ and "WORLD_SIZE" in os.environ: 247 | args.rank = int(os.environ["RANK"]) 248 | args.world_size = int(os.environ["WORLD_SIZE"]) 249 | args.gpu = int(os.environ["LOCAL_RANK"]) 250 | # elif "SLURM_PROCID" in os.environ: 251 | # args.rank = int(os.environ["SLURM_PROCID"]) 252 | # args.gpu = args.rank % torch.cuda.device_count() 253 | elif hasattr(args, "rank"): 254 | pass 255 | else: 256 | print("Not using distributed mode") 257 | args.distributed = False 258 | return 259 | 260 | args.distributed = True 261 | 262 | torch.cuda.set_device(args.gpu) 263 | args.dist_backend = "nccl" 264 | print(f"| distributed init (rank {args.rank}): {args.dist_url}", flush=True) 265 | torch.distributed.init_process_group( 266 | backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank 267 | ) 268 | torch.distributed.barrier() 269 | setup_for_distributed(args.rank == 0) 270 | 271 | 272 | def average_checkpoints(inputs): 273 | """Loads checkpoints from inputs and returns a model with averaged weights. 
Original implementation taken from: 274 | https://github.com/pytorch/fairseq/blob/a48f235636557b8d3bc4922a6fa90f3a0fa57955/scripts/average_checkpoints.py#L16 275 | 276 | Args: 277 | inputs (List[str]): An iterable of string paths of checkpoints to load from. 278 | Returns: 279 | A dict of string keys mapping to various values. The 'model' key 280 | from the returned dict should correspond to an OrderedDict mapping 281 | string parameter names to torch Tensors. 282 | """ 283 | params_dict = OrderedDict() 284 | params_keys = None 285 | new_state = None 286 | num_models = len(inputs) 287 | for fpath in inputs: 288 | with open(fpath, "rb") as f: 289 | state = torch.load( 290 | f, 291 | map_location=(lambda s, _: torch.serialization.default_restore_location(s, "cpu")), 292 | ) 293 | # Copies over the settings from the first checkpoint 294 | if new_state is None: 295 | new_state = state 296 | model_params = state["model"] 297 | model_params_keys = list(model_params.keys()) 298 | if params_keys is None: 299 | params_keys = model_params_keys 300 | elif params_keys != model_params_keys: 301 | raise KeyError( 302 | f"For checkpoint {f}, expected list of params: {params_keys}, but found: {model_params_keys}" 303 | ) 304 | for k in params_keys: 305 | p = model_params[k] 306 | if isinstance(p, torch.HalfTensor): 307 | p = p.float() 308 | if k not in params_dict: 309 | params_dict[k] = p.clone() 310 | # NOTE: clone() is needed in case of p is a shared parameter 311 | else: 312 | params_dict[k] += p 313 | averaged_params = OrderedDict() 314 | for k, v in params_dict.items(): 315 | averaged_params[k] = v 316 | if averaged_params[k].is_floating_point(): 317 | averaged_params[k].div_(num_models) 318 | else: 319 | averaged_params[k] //= num_models 320 | new_state["model"] = averaged_params 321 | return new_state 322 | 323 | 324 | def store_model_weights(model, checkpoint_path, checkpoint_key="model", strict=True): 325 | """ 326 | This method can be used to prepare weights files for new models. It receives as 327 | input a model architecture and a checkpoint from the training script and produces 328 | a file with the weights ready for release. 329 | 330 | Examples: 331 | from torchvision import models as M 332 | 333 | # Classification 334 | model = M.mobilenet_v3_large(weights=None) 335 | print(store_model_weights(model, './class.pth')) 336 | 337 | # Quantized Classification 338 | model = M.quantization.mobilenet_v3_large(weights=None, quantize=False) 339 | model.fuse_model(is_qat=True) 340 | model.qconfig = torch.ao.quantization.get_default_qat_qconfig('qnnpack') 341 | _ = torch.ao.quantization.prepare_qat(model, inplace=True) 342 | print(store_model_weights(model, './qat.pth')) 343 | 344 | # Object Detection 345 | model = M.detection.fasterrcnn_mobilenet_v3_large_fpn(weights=None, weights_backbone=None) 346 | print(store_model_weights(model, './obj.pth')) 347 | 348 | # Segmentation 349 | model = M.segmentation.deeplabv3_mobilenet_v3_large(weights=None, weights_backbone=None, aux_loss=True) 350 | print(store_model_weights(model, './segm.pth', strict=False)) 351 | 352 | Args: 353 | model (pytorch.nn.Module): The model on which the weights will be loaded for validation purposes. 354 | checkpoint_path (str): The path of the checkpoint we will load. 355 | checkpoint_key (str, optional): The key of the checkpoint where the model weights are stored. 356 | Default: "model". 
357 | strict (bool): whether to strictly enforce that the keys 358 | in :attr:`state_dict` match the keys returned by this module's 359 | :meth:`~torch.nn.Module.state_dict` function. Default: ``True`` 360 | 361 | Returns: 362 | output_path (str): The location where the weights are saved. 363 | """ 364 | # Store the new model next to the checkpoint_path 365 | checkpoint_path = os.path.abspath(checkpoint_path) 366 | output_dir = os.path.dirname(checkpoint_path) 367 | 368 | # Deep copy to avoid side effects on the model object. 369 | model = copy.deepcopy(model) 370 | checkpoint = torch.load(checkpoint_path, map_location="cpu") 371 | 372 | # Load the weights to the model to validate that everything works 373 | # and remove unnecessary weights (such as auxiliaries, etc.) 374 | if checkpoint_key == "model_ema": 375 | del checkpoint[checkpoint_key]["n_averaged"] 376 | torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(checkpoint[checkpoint_key], "module.") 377 | model.load_state_dict(checkpoint[checkpoint_key], strict=strict) 378 | 379 | tmp_path = os.path.join(output_dir, str(model.__hash__())) 380 | torch.save(model.state_dict(), tmp_path) 381 | 382 | sha256_hash = hashlib.sha256() 383 | with open(tmp_path, "rb") as f: 384 | # Read and update hash string value in blocks of 4K 385 | for byte_block in iter(lambda: f.read(4096), b""): 386 | sha256_hash.update(byte_block) 387 | hh = sha256_hash.hexdigest() 388 | 389 | output_path = os.path.join(output_dir, "weights-" + str(hh[:8]) + ".pth") 390 | os.replace(tmp_path, output_path) 391 | 392 | return output_path 393 | 394 | 395 | def reduce_across_processes(val): 396 | if not is_dist_avail_and_initialized(): 397 | # nothing to sync, but we still convert to tensor for consistency with the distributed case. 398 | return torch.tensor(val) 399 | 400 | t = torch.tensor(val, device="cuda") 401 | dist.barrier() 402 | dist.all_reduce(t) 403 | return t 404 | 405 | 406 | def set_weight_decay( 407 | model: torch.nn.Module, 408 | weight_decay: float, 409 | norm_weight_decay: Optional[float] = None, 410 | norm_classes: Optional[List[type]] = None, 411 | custom_keys_weight_decay: Optional[List[Tuple[str, float]]] = None, 412 | ): 413 | if not norm_classes: 414 | norm_classes = [ 415 | torch.nn.modules.batchnorm._BatchNorm, 416 | torch.nn.LayerNorm, 417 | torch.nn.GroupNorm, 418 | torch.nn.modules.instancenorm._InstanceNorm, 419 | torch.nn.LocalResponseNorm, 420 | ] 421 | norm_classes = tuple(norm_classes) 422 | 423 | params = { 424 | "other": [], 425 | "norm": [], 426 | } 427 | params_weight_decay = { 428 | "other": weight_decay, 429 | "norm": norm_weight_decay, 430 | } 431 | custom_keys = [] 432 | if custom_keys_weight_decay is not None: 433 | for key, weight_decay in custom_keys_weight_decay: 434 | params[key] = [] 435 | params_weight_decay[key] = weight_decay 436 | custom_keys.append(key) 437 | 438 | def _add_params(module, prefix=""): 439 | for name, p in module.named_parameters(recurse=False): 440 | if not p.requires_grad: 441 | continue 442 | is_custom_key = False 443 | for key in custom_keys: 444 | target_name = f"{prefix}.{name}" if prefix != "" and "." 
in key else name 445 | if key == target_name: 446 | params[key].append(p) 447 | is_custom_key = True 448 | break 449 | if not is_custom_key: 450 | if norm_weight_decay is not None and isinstance(module, norm_classes): 451 | params["norm"].append(p) 452 | else: 453 | params["other"].append(p) 454 | 455 | for child_name, child_module in module.named_children(): 456 | child_prefix = f"{prefix}.{child_name}" if prefix != "" else child_name 457 | _add_params(child_module, prefix=child_prefix) 458 | 459 | _add_params(model) 460 | 461 | param_groups = [] 462 | for key in params: 463 | if len(params[key]) > 0: 464 | param_groups.append({"params": params[key], "weight_decay": params_weight_decay[key]}) 465 | return param_groups 466 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | 2 | import warnings 3 | warnings.filterwarnings('ignore') 4 | import argparse 5 | import time 6 | import os 7 | import logging 8 | from collections import OrderedDict 9 | from contextlib import suppress 10 | from datetime import datetime 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torchvision.utils 15 | from torch.nn.parallel import DistributedDataParallel as NativeDDP 16 | 17 | from timm.data import ImageDataset as Dataset, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset #, create_loader 18 | from timm.models import create_model, resume_checkpoint #, convert_splitbn_model 19 | from timm.utils import * 20 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy 21 | from timm.optim import create_optimizer 22 | from timm.scheduler import create_scheduler 23 | from timm.utils import ApexScaler, NativeScaler 24 | 25 | from data import create_loader 26 | from model import pyramid_vig 27 | from model import wignn 28 | from model import wignn_256 29 | 30 | import sys 31 | 32 | from opt import parse_args 33 | import wandb 34 | 35 | import torch.backends.cudnn as cudnn 36 | import numpy as np 37 | import random 38 | 39 | try: 40 | from apex import amp 41 | from apex.parallel import DistributedDataParallel as ApexDDP 42 | from apex.parallel import convert_syncbn_model 43 | has_apex = True 44 | except ImportError: 45 | has_apex = False 46 | 47 | has_native_amp = False 48 | try: 49 | if getattr(torch.cuda.amp, 'autocast') is not None: 50 | has_native_amp = True 51 | except AttributeError: 52 | pass 53 | 54 | torch.backends.cudnn.benchmark = True 55 | _logger = logging.getLogger('train') 56 | 57 | 58 | def main(): 59 | setup_default_logging() 60 | args, args_text = parse_args() 61 | 62 | if args.evaluate: 63 | random.seed(42) 64 | torch.manual_seed(42) 65 | cudnn.deterministic = True 66 | cudnn.benchmark = False 67 | np.random.seed(42) 68 | torch.cuda.manual_seed(42) 69 | 70 | os.environ["PYTHONHASHSEED"] = str(42) 71 | torch.cuda.manual_seed_all(42) 72 | 73 | args.prefetcher = not args.no_prefetcher 74 | args.distributed = False 75 | if 'WORLD_SIZE' in os.environ: 76 | args.distributed = int(os.environ['WORLD_SIZE']) > 1 77 | if args.distributed and args.num_gpu > 1: 78 | _logger.warning( 79 | 'Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.') 80 | args.num_gpu = 1 81 | 82 | args.device = 'cuda:0' 83 | args.world_size = 1 84 | args.rank = 0 # global rank 85 | if args.distributed: 86 | args.num_gpu = 1 87 | args.device = 'cuda:%d' % args.local_rank 88 | torch.cuda.set_device(args.local_rank) 89 
| args.world_size = int(os.environ['WORLD_SIZE']) 90 | args.rank = int(os.environ['RANK']) 91 | torch.distributed.init_process_group(backend='nccl', init_method=args.init_method, rank=args.rank, world_size=args.world_size) 92 | args.world_size = torch.distributed.get_world_size() 93 | args.rank = torch.distributed.get_rank() 94 | assert args.rank >= 0 95 | 96 | if args.local_rank == 0 and not args.evaluate: 97 | wandb.init( 98 | # Set the project where this run will be logged 99 | project='WACV_WiGNet-inet', 100 | # We pass a run name (otherwise it’ll be randomly assigned, like sunshine-lollypop-10) 101 | name=f'{args.model}_shift{args.use_shift}_k{args.knn}_adapt{args.adapt_knn}', 102 | config=args) 103 | 104 | if args.distributed: 105 | _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.' 106 | % (args.rank, args.world_size)) 107 | else: 108 | _logger.info('Training with a single process on %d GPUs.' % args.num_gpu) 109 | 110 | torch.manual_seed(args.seed + args.rank) 111 | 112 | if('wignn' in args.model): 113 | model = create_model( 114 | args.model, 115 | num_classes=args.num_classes, 116 | drop_path_rate=args.drop_path, 117 | knn = args.knn, 118 | use_shifts = args.use_shift, 119 | adapt_knn = args.adapt_knn 120 | ) 121 | else: 122 | model = create_model( 123 | args.model, 124 | pretrained=args.pretrained, 125 | num_classes=args.num_classes, 126 | drop_rate=args.drop, 127 | drop_connect_rate=args.drop_connect, # DEPRECATED, use drop_path 128 | drop_path_rate=args.drop_path, 129 | drop_block_rate=args.drop_block, 130 | global_pool=args.gp, 131 | bn_tf=args.bn_tf, 132 | bn_momentum=args.bn_momentum, 133 | bn_eps=args.bn_eps, 134 | checkpoint_path=args.initial_checkpoint 135 | ) 136 | 137 | ################## pretrain ############ 138 | if args.pretrain_path is not None: 139 | print('Loading:', args.pretrain_path) 140 | state_dict = torch.load(args.pretrain_path) 141 | model.load_state_dict(state_dict, strict=False) 142 | print('Pretrain weights loaded.') 143 | ################### flops ################# 144 | # print(model) 145 | if hasattr(model, 'default_cfg'): 146 | default_cfg = model.default_cfg 147 | input_size = [1] + list(default_cfg['input_size']) 148 | else: 149 | input_size = [1, 3, 224, 224] 150 | print(f'\n\n!!!!!!! 
Using input size: {input_size}\n\n') 151 | input = torch.randn(input_size)#.cuda() 152 | 153 | from torchprofile import profile_macs 154 | model.eval() 155 | macs = profile_macs(model, input) 156 | model.train() 157 | print('model flops:', macs, 'input_size:', input_size) 158 | ########################################## 159 | 160 | if args.local_rank == 0: 161 | _logger.info('Model %s created, param count: %d' % 162 | (args.model, sum([m.numel() for m in model.parameters()]))) 163 | 164 | data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0) 165 | 166 | num_aug_splits = 0 167 | if args.aug_splits > 0: 168 | assert args.aug_splits > 1, 'A split of 1 makes no sense' 169 | num_aug_splits = args.aug_splits 170 | 171 | """ if args.split_bn: 172 | assert num_aug_splits > 1 or args.resplit 173 | model = convert_splitbn_model(model, max(num_aug_splits, 2)) """ 174 | 175 | use_amp = None 176 | if args.amp: 177 | # for backwards compat, `--amp` arg tries apex before native amp 178 | if has_apex: 179 | args.apex_amp = True 180 | elif has_native_amp: 181 | args.native_amp = True 182 | if args.apex_amp and has_apex: 183 | use_amp = 'apex' 184 | elif args.native_amp and has_native_amp: 185 | use_amp = 'native' 186 | elif args.apex_amp or args.native_amp: 187 | _logger.warning("Neither APEX nor native Torch AMP is available, using float32. " 188 | "Install NVIDIA Apex or upgrade to PyTorch 1.6") 189 | 190 | if args.num_gpu > 1: 191 | if use_amp == 'apex': 192 | _logger.warning( 193 | 'Apex AMP does not work well with nn.DataParallel, disabling. Use DDP or Torch AMP.') 194 | use_amp = None 195 | model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda() 196 | assert not args.channels_last, "Channels last not supported with DP, use DDP." 197 | else: 198 | model.cuda() 199 | if args.channels_last: 200 | model = model.to(memory_format=torch.channels_last) 201 | 202 | optimizer = create_optimizer(args, model) 203 | 204 | amp_autocast = suppress # do nothing 205 | loss_scaler = None 206 | if use_amp == 'apex': 207 | model, optimizer = amp.initialize(model, optimizer, opt_level='O1') 208 | loss_scaler = ApexScaler() 209 | if args.local_rank == 0: 210 | _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.') 211 | elif use_amp == 'native': 212 | amp_autocast = torch.cuda.amp.autocast 213 | loss_scaler = NativeScaler() 214 | if args.local_rank == 0: 215 | _logger.info('Using native Torch AMP. Training in mixed precision.') 216 | else: 217 | if args.local_rank == 0: 218 | _logger.info('AMP not enabled. 
Training in float32.') 219 | 220 | # optionally resume from a checkpoint 221 | resume_epoch = None 222 | if args.resume: 223 | resume_epoch = resume_checkpoint( 224 | model, args.resume, 225 | optimizer=None if args.no_resume_opt else optimizer, 226 | loss_scaler=None if args.no_resume_opt else loss_scaler, 227 | log_info=args.local_rank == 0) 228 | 229 | model_ema = None 230 | if args.model_ema: 231 | # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper 232 | model_ema = ModelEma( 233 | model, 234 | decay=args.model_ema_decay, 235 | device='cpu' if args.model_ema_force_cpu else '', 236 | resume=args.resume) 237 | 238 | if args.distributed: 239 | if args.sync_bn: 240 | assert not args.split_bn 241 | try: 242 | if has_apex and use_amp != 'native': 243 | # Apex SyncBN preferred unless native amp is activated 244 | model = convert_syncbn_model(model) 245 | else: 246 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 247 | if args.local_rank == 0: 248 | _logger.info( 249 | 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using ' 250 | 'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.') 251 | except Exception as e: 252 | _logger.error('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1') 253 | if has_apex and use_amp != 'native': 254 | # Apex DDP preferred unless native amp is activated 255 | if args.local_rank == 0: 256 | _logger.info("Using NVIDIA APEX DistributedDataParallel.") 257 | model = ApexDDP(model, delay_allreduce=True) 258 | else: 259 | if args.local_rank == 0: 260 | _logger.info("Using native Torch DistributedDataParallel.") 261 | model = NativeDDP(model, device_ids=[args.local_rank]) # can use device str in Torch >= 1.1 262 | # NOTE: EMA model does not need to be wrapped by DDP 263 | 264 | lr_scheduler, num_epochs = create_scheduler(args, optimizer) 265 | start_epoch = 0 266 | if args.start_epoch is not None: 267 | # a specified start_epoch will always override the resume epoch 268 | start_epoch = args.start_epoch 269 | elif resume_epoch is not None: 270 | start_epoch = resume_epoch 271 | if lr_scheduler is not None and start_epoch > 0: 272 | lr_scheduler.step(start_epoch) 273 | 274 | if args.local_rank == 0: 275 | _logger.info('Scheduled epochs: {}'.format(num_epochs)) 276 | 277 | train_dir = os.path.join(args.data, 'train') 278 | if not os.path.exists(train_dir): 279 | _logger.error('Training folder does not exist at: {}'.format(train_dir)) 280 | exit(1) 281 | dataset_train = Dataset(train_dir) 282 | 283 | collate_fn = None 284 | mixup_fn = None 285 | mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None 286 | if mixup_active: 287 | mixup_args = dict( 288 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, 289 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, 290 | label_smoothing=args.smoothing, num_classes=args.num_classes) 291 | if args.prefetcher: 292 | assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup) 293 | collate_fn = FastCollateMixup(**mixup_args) 294 | else: 295 | mixup_fn = Mixup(**mixup_args) 296 | 297 | if num_aug_splits > 1: 298 | dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits) 299 | 300 | print(f"\n\n!!! 
input size dataloader: {data_config['input_size']}\n\n") 301 | train_interpolation = args.train_interpolation 302 | if args.no_aug or not train_interpolation: 303 | train_interpolation = data_config['interpolation'] 304 | loader_train = create_loader( 305 | dataset_train, 306 | input_size=data_config['input_size'], 307 | batch_size=args.batch_size, 308 | is_training=True, 309 | use_prefetcher=args.prefetcher, 310 | no_aug=args.no_aug, 311 | re_prob=args.reprob, 312 | re_mode=args.remode, 313 | re_count=args.recount, 314 | re_split=args.resplit, 315 | scale=args.scale, 316 | ratio=args.ratio, 317 | hflip=args.hflip, 318 | vflip=args.vflip, 319 | color_jitter=args.color_jitter, 320 | auto_augment=args.aa, 321 | num_aug_splits=num_aug_splits, 322 | interpolation=train_interpolation, 323 | mean=data_config['mean'], 324 | std=data_config['std'], 325 | num_workers=args.workers, 326 | distributed=args.distributed, 327 | collate_fn=collate_fn, 328 | pin_memory=args.pin_mem, 329 | use_multi_epochs_loader=args.use_multi_epochs_loader, 330 | repeated_aug=args.repeated_aug 331 | ) 332 | 333 | eval_dir = os.path.join(args.data, 'val') 334 | if not os.path.isdir(eval_dir): 335 | eval_dir = os.path.join(args.data, 'validation') 336 | if not os.path.isdir(eval_dir): 337 | _logger.error('Validation folder does not exist at: {}'.format(eval_dir)) 338 | exit(1) 339 | dataset_eval = Dataset(eval_dir) 340 | 341 | loader_eval = create_loader( 342 | dataset_eval, 343 | input_size=data_config['input_size'], 344 | batch_size=args.validation_batch_size_multiplier * args.batch_size, 345 | is_training=False, 346 | use_prefetcher=args.prefetcher, 347 | interpolation=data_config['interpolation'], 348 | mean=data_config['mean'], 349 | std=data_config['std'], 350 | num_workers=args.workers, 351 | distributed=args.distributed, 352 | crop_pct=data_config['crop_pct'], 353 | pin_memory=args.pin_mem, 354 | ) 355 | 356 | if args.jsd: 357 | assert num_aug_splits > 1 # JSD only valid with aug splits set 358 | train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda() 359 | elif mixup_active: 360 | # smoothing is handled with mixup target transform 361 | train_loss_fn = SoftTargetCrossEntropy().cuda() 362 | elif args.smoothing: 363 | train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda() 364 | else: 365 | train_loss_fn = nn.CrossEntropyLoss().cuda() 366 | validate_loss_fn = nn.CrossEntropyLoss().cuda() 367 | 368 | if args.evaluate: 369 | print('Evaluating model..') 370 | if model_ema is not None: 371 | eval_metrics_test = validate(model_ema.ema, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA) ') 372 | else: 373 | eval_metrics_test = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast) 374 | 375 | print(eval_metrics_test) 376 | return 377 | 378 | eval_metric = args.eval_metric 379 | best_metric_no_ema = None 380 | best_epoch_no_ema = None 381 | 382 | best_metric_ema = None 383 | best_epoch_ema = None 384 | 385 | saver_no_ema = None 386 | output_dir_no_ema = '' 387 | saver_ema = None 388 | output_dir_ema = '' 389 | if args.local_rank == 0: 390 | output_base = args.output if args.output else './output' 391 | exp_name = '-'.join([ 392 | datetime.now().strftime("%Y%m%d-%H%M%S"), 393 | args.model, 394 | str(data_config['input_size'][-1]) 395 | ]) 396 | output_dir_no_ema = get_outdir(output_base, 'train', f'{exp_name}_NO_EMA') 397 | decreasing = True if eval_metric == 'loss' else False 398 | saver_no_ema = CheckpointSaver( 
399 | model=model, optimizer=optimizer, args=args, model_ema=None, amp_scaler=loss_scaler, 400 | checkpoint_dir=output_dir_no_ema, recovery_dir=output_dir_no_ema, decreasing=decreasing) 401 | with open(os.path.join(output_dir_no_ema, 'args.yaml'), 'w') as f: 402 | f.write(args_text) 403 | 404 | if args.model_ema: 405 | output_dir_ema = get_outdir(output_base, 'train', f'{exp_name}_EMA') 406 | saver_ema = CheckpointSaver( 407 | model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler, 408 | checkpoint_dir=output_dir_ema, recovery_dir=output_dir_ema, decreasing=decreasing) 409 | with open(os.path.join(output_dir_ema, 'args.yaml'), 'w') as f: 410 | f.write(args_text) 411 | 412 | 413 | 414 | # try: 415 | for epoch in range(start_epoch, num_epochs): 416 | if args.distributed: 417 | loader_train.sampler.set_epoch(epoch) 418 | 419 | train_metrics = train_epoch( 420 | epoch, model, loader_train, optimizer, train_loss_fn, args, 421 | lr_scheduler=lr_scheduler, saver=saver_ema if args.model_ema else saver_no_ema, output_dir=output_dir_ema if args.model_ema else output_dir_no_ema, 422 | amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn) 423 | 424 | if args.local_rank == 0: 425 | wandb.log({ 426 | 'train/loss':train_metrics['loss'], 427 | 'train/lr': optimizer.param_groups[0]["lr"] 428 | },step = epoch) 429 | 430 | if args.distributed and args.dist_bn in ('broadcast', 'reduce'): 431 | if args.local_rank == 0: 432 | _logger.info("Distributing BatchNorm running means and vars") 433 | distribute_bn(model, args.world_size, args.dist_bn == 'reduce') 434 | 435 | eval_metrics_no_ema = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast) 436 | 437 | if args.local_rank == 0: 438 | wandb.log({ 439 | 'val/loss':eval_metrics_no_ema['loss'], 440 | 'val/acc@1':eval_metrics_no_ema['top1'], 441 | 'val/acc@5':eval_metrics_no_ema['top5'] 442 | },step = epoch) 443 | 444 | if model_ema is not None and not args.model_ema_force_cpu: 445 | if args.distributed and args.dist_bn in ('broadcast', 'reduce'): 446 | distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce') 447 | eval_metrics_ema = validate( 448 | model_ema.ema, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)') 449 | # eval_metrics = ema_eval_metrics 450 | 451 | if args.local_rank == 0: 452 | wandb.log({ 453 | 'val_ema/loss':eval_metrics_ema['loss'], 454 | 'val_ema/acc@1':eval_metrics_ema['top1'], 455 | 'val_ema/acc@5':eval_metrics_ema['top5'] 456 | },step = epoch) 457 | 458 | if lr_scheduler is not None: 459 | # step LR for next epoch 460 | if args.model_ema: 461 | lr_scheduler.step(epoch + 1, eval_metrics_ema[eval_metric]) 462 | else: 463 | lr_scheduler.step(epoch + 1, eval_metrics_no_ema[eval_metric]) 464 | 465 | update_summary( 466 | epoch, train_metrics, eval_metrics_no_ema, os.path.join(output_dir_no_ema, 'summary.csv'), 467 | write_header=best_metric_no_ema is None) 468 | 469 | if args.model_ema: 470 | update_summary( 471 | epoch, train_metrics, eval_metrics_ema, os.path.join(output_dir_ema, 'summary.csv'), 472 | write_header=best_metric_ema is None) 473 | 474 | if saver_no_ema is not None: 475 | # save proper checkpoint with eval metric 476 | save_metric = eval_metrics_no_ema[eval_metric] 477 | best_metric_no_ema, best_epoch_no_ema = saver_no_ema.save_checkpoint(epoch, metric=save_metric) 478 | 479 | if saver_ema is not None: 480 | # save proper checkpoint with eval metric 481 | save_metric = 
eval_metrics_ema[eval_metric] 482 | best_metric_ema, best_epoch_ema = saver_ema.save_checkpoint(epoch, metric=save_metric) 483 | 484 | # except KeyboardInterrupt: 485 | # pass 486 | 487 | if best_metric_no_ema is not None: 488 | _logger.info('*** Best metric (NO EMA): {0} (epoch {1})'.format(best_metric_no_ema, best_epoch_no_ema)) 489 | 490 | if args.local_rank == 0: 491 | wandb.log({ 492 | f'best_no_ema/{eval_metric}':best_metric_no_ema, 493 | 'best_no_ema/epoch':best_epoch_no_ema 494 | },step = epoch) 495 | 496 | if best_metric_ema is not None: 497 | _logger.info('*** Best metric (EMA): {0} (epoch {1})'.format(best_metric_ema, best_epoch_ema)) 498 | 499 | if args.local_rank == 0: 500 | wandb.log({ 501 | f'best_ema/{eval_metric}':best_metric_ema, 502 | 'best_ema/epoch':best_epoch_ema 503 | },step = epoch) 504 | 505 | wandb.finish() 506 | 507 | 508 | def train_epoch( 509 | epoch, model, loader, optimizer, loss_fn, args, 510 | lr_scheduler=None, saver=None, output_dir='', amp_autocast=suppress, 511 | loss_scaler=None, model_ema=None, mixup_fn=None): 512 | 513 | if args.mixup_off_epoch and epoch >= args.mixup_off_epoch: 514 | if args.prefetcher and loader.mixup_enabled: 515 | loader.mixup_enabled = False 516 | elif mixup_fn is not None: 517 | mixup_fn.mixup_enabled = False 518 | 519 | second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order 520 | batch_time_m = AverageMeter() 521 | data_time_m = AverageMeter() 522 | losses_m = AverageMeter() 523 | 524 | model.train() 525 | 526 | end = time.time() 527 | last_idx = len(loader) - 1 528 | num_updates = epoch * len(loader) 529 | for batch_idx, (input, target) in enumerate(loader): 530 | # if batch_idx > 2: # TODO remove it 531 | # break 532 | last_batch = batch_idx == last_idx 533 | data_time_m.update(time.time() - end) 534 | if not args.prefetcher: 535 | input, target = input.cuda(), target.cuda() 536 | if mixup_fn is not None: 537 | input, target = mixup_fn(input, target) 538 | if args.channels_last: 539 | input = input.contiguous(memory_format=torch.channels_last) 540 | 541 | with amp_autocast(): 542 | output = model(input) 543 | loss = loss_fn(output, target) 544 | 545 | # sys.exit(1) 546 | 547 | if not args.distributed: 548 | losses_m.update(loss.item(), input.size(0)) 549 | 550 | optimizer.zero_grad() 551 | if loss_scaler is not None: 552 | loss_scaler( 553 | loss, optimizer, clip_grad=args.clip_grad, parameters=model.parameters(), create_graph=second_order) 554 | else: 555 | loss.backward(create_graph=second_order) 556 | if args.clip_grad is not None: 557 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad) 558 | optimizer.step() 559 | 560 | torch.cuda.synchronize() 561 | if model_ema is not None: 562 | model_ema.update(model) 563 | num_updates += 1 564 | 565 | batch_time_m.update(time.time() - end) 566 | if last_batch or batch_idx % args.log_interval == 0: 567 | lrl = [param_group['lr'] for param_group in optimizer.param_groups] 568 | lr = sum(lrl) / len(lrl) 569 | 570 | if args.distributed: 571 | reduced_loss = reduce_tensor(loss.data, args.world_size) 572 | losses_m.update(reduced_loss.item(), input.size(0)) 573 | 574 | if args.local_rank == 0: 575 | _logger.info( 576 | 'Train: {} [{:>4d}/{} ({:>3.0f}%)] ' 577 | 'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f}) ' 578 | 'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s ' 579 | '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) ' 580 | 'LR: {lr:.3e} ' 581 | 'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format( 582 | epoch, 583 | batch_idx, len(loader), 584 | 
100. * batch_idx / last_idx, 585 | loss=losses_m, 586 | batch_time=batch_time_m, 587 | rate=input.size(0) * args.world_size / batch_time_m.val, 588 | rate_avg=input.size(0) * args.world_size / batch_time_m.avg, 589 | lr=lr, 590 | data_time=data_time_m)) 591 | 592 | if args.save_images and output_dir: 593 | torchvision.utils.save_image( 594 | input, 595 | os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx), 596 | padding=0, 597 | normalize=True) 598 | 599 | if saver is not None and args.recovery_interval and ( 600 | last_batch or (batch_idx + 1) % args.recovery_interval == 0): 601 | saver.save_recovery(epoch, batch_idx=batch_idx) 602 | 603 | if lr_scheduler is not None: 604 | lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg) 605 | 606 | end = time.time() 607 | # end for 608 | 609 | if hasattr(optimizer, 'sync_lookahead'): 610 | optimizer.sync_lookahead() 611 | 612 | return OrderedDict([('loss', losses_m.avg)]) 613 | 614 | 615 | def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix=''): 616 | batch_time_m = AverageMeter() 617 | losses_m = AverageMeter() 618 | top1_m = AverageMeter() 619 | top5_m = AverageMeter() 620 | 621 | model.eval() 622 | 623 | end = time.time() 624 | last_idx = len(loader) - 1 625 | with torch.no_grad(): 626 | for batch_idx, (input, target) in enumerate(loader): 627 | 628 | # if batch_idx > 20: # TODO remove it 629 | # break 630 | last_batch = batch_idx == last_idx 631 | if not args.prefetcher: 632 | input = input.cuda() 633 | target = target.cuda() 634 | if args.channels_last: 635 | input = input.contiguous(memory_format=torch.channels_last) 636 | 637 | with amp_autocast(): 638 | output = model(input) 639 | 640 | if isinstance(output, (tuple, list)): 641 | output = output[0] 642 | 643 | # augmentation reduction 644 | reduce_factor = args.tta 645 | if reduce_factor > 1: 646 | output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2) 647 | target = target[0:target.size(0):reduce_factor] 648 | 649 | loss = loss_fn(output, target) 650 | 651 | 652 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 653 | 654 | if args.distributed: 655 | reduced_loss = reduce_tensor(loss.data, args.world_size) 656 | acc1 = reduce_tensor(acc1, args.world_size) 657 | acc5 = reduce_tensor(acc5, args.world_size) 658 | else: 659 | reduced_loss = loss.data 660 | 661 | torch.cuda.synchronize() 662 | 663 | losses_m.update(reduced_loss.item(), input.size(0)) 664 | top1_m.update(acc1.item(), output.size(0)) 665 | top5_m.update(acc5.item(), output.size(0)) 666 | 667 | batch_time_m.update(time.time() - end) 668 | end = time.time() 669 | if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0): 670 | log_name = 'Test' + log_suffix 671 | _logger.info( 672 | '{0}: [{1:>4d}/{2}] ' 673 | 'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) ' 674 | 'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) ' 675 | 'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) ' 676 | 'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format( 677 | log_name, batch_idx, last_idx, batch_time=batch_time_m, 678 | loss=losses_m, top1=top1_m, top5=top5_m)) 679 | 680 | metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)]) 681 | 682 | return metrics 683 | 684 | 685 | if __name__ == '__main__': 686 | main() --------------------------------------------------------------------------------
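Editorial appendix (not a repository file): the EMA wrapper at the top of the utilities file above is a thin subclass of torch.optim.swa_utils.AveragedModel that applies ``ema = decay * ema + (1 - decay) * param`` at every update. Below is a minimal self-contained sketch of the same averaging rule built directly on AveragedModel; all names are illustrative, and note that train.py itself builds its EMA through timm's ModelEma rather than this wrapper.

import torch
import torch.nn as nn
from torch.optim.swa_utils import AveragedModel

decay = 0.999

def ema_avg(avg_p, p, num_averaged):
    # Same rule as the wrapper above: ema = decay * ema + (1 - decay) * param.
    return decay * avg_p + (1 - decay) * p

model = nn.Linear(16, 10)
ema = AveragedModel(model, device="cpu", avg_fn=ema_avg, use_buffers=True)

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
for _ in range(10):
    loss = model(torch.randn(4, 16)).square().mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema.update_parameters(model)  # apply the EMA update after each optimizer step

with torch.inference_mode():
    out = ema.module(torch.randn(4, 16))  # predictions from the averaged weights

As in the wrapper above, use_buffers=True averages buffers such as BatchNorm running statistics alongside the weights, so the EMA model can be evaluated without a separate BN-recalibration pass.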
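Editorial appendix: set_weight_decay (defined in the utilities above) returns optimizer parameter groups so that normalization layers, and any custom-keyed parameters, can receive their own weight-decay value. A minimal sketch of how its output is consumed, assuming the function is in scope (e.g. imported from the utilities file above):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())

# Two groups come back: conv weights/biases with weight_decay=0.05,
# and the BatchNorm affine parameters with weight_decay=0.0.
param_groups = set_weight_decay(model, weight_decay=0.05, norm_weight_decay=0.0)
optimizer = torch.optim.AdamW(param_groups, lr=1e-3)

Biases of non-norm layers stay in the "other" group under this scheme; to exempt them as well, pass something like custom_keys_weight_decay=[("bias", 0.0)], which the key-matching logic above resolves against each parameter's unqualified name.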
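Editorial appendix: average_checkpoints and store_model_weights above compose into a small release pipeline, averaging the last few training checkpoints and then re-saving just the validated weights under a content-hashed filename. A sketch with purely hypothetical paths, assuming both functions are in scope:

import torch

# Checkpoints saved by the training script; the paths are illustrative only.
avg_state = average_checkpoints(["out/checkpoint-98.pth", "out/checkpoint-99.pth"])
torch.save(avg_state, "out/averaged.pth")

# model = create_model(...)  # same architecture the checkpoints were trained with
# print(store_model_weights(model, "out/averaged.pth"))  # writes weights-<hash>.pth

store_model_weights loads the averaged state into the freshly built model first, so a key mismatch fails loudly before anything is published.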