├── src
│   ├── .gitkeep
│   ├── data
│   │   ├── __init__.py
│   │   ├── rasampler.py
│   │   └── myloader.py
│   ├── gcn_lib
│   │   ├── __init__.py
│   │   ├── torch_local.py
│   │   ├── pos_embed.py
│   │   ├── torch_nn.py
│   │   ├── torch_edge.py
│   │   └── torch_vertex.py
│   ├── dataloaders
│   │   └── celeba_hq.py
│   ├── utils
│   │   └── metrics.py
│   ├── model
│   │   ├── wignn_256.py
│   │   ├── transfer_models.py
│   │   ├── wignn.py
│   │   ├── pyramid_vig.py
│   │   ├── mobilevig.py
│   │   └── greedyvig.py
│   ├── opt_transfer.py
│   ├── opt.py
│   ├── train_trasnfer_learning.py
│   ├── utils.py
│   └── train.py
├── imgs
│   ├── tab_inet.png
│   ├── complexity.png
│   ├── res_celebahq.png
│   └── teaser_wignet.png
├── LICENSE
├── README.md
├── .gitignore
└── env_wignet.yaml
/src/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .myloader import create_loader
2 |
--------------------------------------------------------------------------------
/imgs/tab_inet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EIDOSLAB/WiGNet/HEAD/imgs/tab_inet.png
--------------------------------------------------------------------------------
/imgs/complexity.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EIDOSLAB/WiGNet/HEAD/imgs/complexity.png
--------------------------------------------------------------------------------
/imgs/res_celebahq.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EIDOSLAB/WiGNet/HEAD/imgs/res_celebahq.png
--------------------------------------------------------------------------------
/imgs/teaser_wignet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EIDOSLAB/WiGNet/HEAD/imgs/teaser_wignet.png
--------------------------------------------------------------------------------
/src/gcn_lib/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from .torch_nn import *
3 | from .torch_edge import *
4 | from .torch_vertex import *
5 |
6 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MXM License
2 |
3 | The copyright in this software is being made available under the BSD License, included below. This software may be subject to other third party and contributor rights, including patent rights, and no such rights are granted under this license.
4 |
5 | OWNER = University of Turin, IMT, Sisvel Technology.
6 |
7 | ORGANIZATION = University of Turin, IMT, Sisvel Technology.
8 |
9 | YEAR = 2024
10 |
11 | Copyright (c) 2024, University of Turin, IMT, Sisvel Technology.
12 |
13 | All rights reserved.
14 |
15 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
16 |
17 | - Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
18 | - Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
19 | - Neither the name of the University of Turin, IMT, Sisvel Technology nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
20 |
21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
22 |
23 |
--------------------------------------------------------------------------------
/src/dataloaders/celeba_hq.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 | import torch.optim as optim
5 |
6 | import torchvision
7 | from torchvision import datasets, models, transforms
8 |
9 | import numpy as np
10 | import matplotlib.pyplot as plt
11 |
12 | import time
13 | import os
14 |
15 |
16 |
17 | def get_celeba(args, get_train_sampler = False, transform_train = True, crop_size = 224, drop_last=False):
18 |
19 | transforms_augs = transforms.Compose([
20 | transforms.Resize((crop_size, crop_size)),
21 | transforms.RandomHorizontalFlip(), # data augmentation
22 | transforms.ToTensor(),
23 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # normalization
24 | ])
25 |
26 | transforms_no_augs = transforms.Compose([
27 | transforms.Resize((crop_size, crop_size)),
28 | transforms.ToTensor(),
29 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
30 | ])
31 |
32 | if transform_train:
33 | train_transforms = transforms_augs
34 | else:
35 | train_transforms = transforms_no_augs
36 |
37 | test_transforms = transforms_no_augs
38 |
39 | data_dir = f'{args.root}/CelebA_HQ_facial_identity_dataset'
40 | train_dataset = datasets.ImageFolder(os.path.join(data_dir, 'train'), train_transforms)
41 | test_dataset = datasets.ImageFolder(os.path.join(data_dir, 'test'), test_transforms)
42 |
43 | dataset_labels = train_dataset.classes
44 | num_classes = len(dataset_labels)
45 |
46 | train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=transform_train,
47 | num_workers=args.workers, pin_memory=True, sampler=None,
48 | persistent_workers=args.workers > 0,drop_last=drop_last)
49 |
50 | test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False,
51 | num_workers=args.workers, pin_memory=True, sampler=None,
52 | persistent_workers=args.workers > 0,drop_last=drop_last)
53 |
54 | if get_train_sampler:
55 | return train_dataloader, None, test_dataloader, None, num_classes, dataset_labels
56 |
57 | return train_dataloader, None, test_dataloader, num_classes, dataset_labels
58 |
59 |
60 |
61 |
62 | if __name__ == '__main__':
63 |
64 | pass
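65 |
66 |     # Hypothetical standalone usage sketch (paths and args are placeholders,
67 |     # so this stays commented out):
68 |     # from types import SimpleNamespace
69 |     # args = SimpleNamespace(root='/path/to/datasets', batch_size=32, workers=4)
70 |     # train_dl, _, test_dl, num_classes, labels = get_celeba(args, crop_size=512)
71 |     # print(num_classes, len(labels))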
--------------------------------------------------------------------------------
/src/data/rasampler.py:
--------------------------------------------------------------------------------
1 | # from https://github.com/facebookresearch/deit/blob/main/samplers.py
2 | # Copyright (c) 2015-present, Facebook, Inc.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the CC-by-NC license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | #
8 | import torch
9 | import torch.distributed as dist
10 | import math
11 |
12 |
13 | class RASampler(torch.utils.data.Sampler):
14 | """Sampler that restricts data loading to a subset of the dataset for distributed,
15 | with repeated augmentation.
16 | It ensures that each augmented version of a sample will be visible to a
17 | different process (GPU).
18 | Heavily based on torch.utils.data.DistributedSampler
19 | """
20 |
21 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
22 | if num_replicas is None:
23 | if not dist.is_available():
24 | raise RuntimeError("Requires distributed package to be available")
25 | num_replicas = dist.get_world_size()
26 | if rank is None:
27 | if not dist.is_available():
28 | raise RuntimeError("Requires distributed package to be available")
29 | rank = dist.get_rank()
30 | self.dataset = dataset
31 | self.num_replicas = num_replicas
32 | self.rank = rank
33 | self.epoch = 0
34 | self.num_samples = int(math.ceil(len(self.dataset) * 3.0 / self.num_replicas))
35 | self.total_size = self.num_samples * self.num_replicas
36 | # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas))
37 | self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
38 | self.shuffle = shuffle
39 |
40 | def __iter__(self):
41 | # deterministically shuffle based on epoch
42 | g = torch.Generator()
43 | g.manual_seed(self.epoch)
44 | if self.shuffle:
45 | indices = torch.randperm(len(self.dataset), generator=g).tolist()
46 | else:
47 | indices = list(range(len(self.dataset)))
48 |
49 | # add extra samples to make it evenly divisible
50 | indices = [ele for ele in indices for i in range(3)]
51 | indices += indices[:(self.total_size - len(indices))]
52 | assert len(indices) == self.total_size
53 |
54 | # subsample
55 | indices = indices[self.rank:self.total_size:self.num_replicas]
56 | assert len(indices) == self.num_samples
57 |
58 | return iter(indices[:self.num_selected_samples])
59 |
60 | def __len__(self):
61 | return self.num_selected_samples
62 |
63 | def set_epoch(self, epoch):
64 | self.epoch = epoch
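65 |
66 |
67 | if __name__ == '__main__':
68 |     # Sanity-check sketch (hypothetical standalone run): passing num_replicas
69 |     # and rank explicitly avoids initializing torch.distributed.
70 |     sampler = RASampler(list(range(512)), num_replicas=2, rank=0)
71 |     indices = list(iter(sampler))
72 |     print(len(sampler), len(indices), indices[:5])  # 256 256 [...]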
--------------------------------------------------------------------------------
/src/gcn_lib/torch_local.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from timm.models.layers import to_2tuple
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | def window_partition_channel_last(x, window_size=8):
7 | """
8 | Args:
9 | x: (B, H, W, C)
10 | window_size (int): window size
11 | Returns:
12 | windows: (num_windows*B, window_size, window_size, C)
13 | """
14 | B, H, W, C = x.shape
15 | windows = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
16 | windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
17 | return windows
18 |
19 | def window_partition(x, window_size=7):
20 | """
21 | Args:
22 | x: (B, C, H, W)
23 | window_size (int): window size
24 | Returns:
25 | windows: (num_windows*B, C, window_size, window_size)
26 | """
27 | x = x.transpose(1,2).transpose(2,3)
28 | B, H, W, C = x.shape
29 | windows = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
30 | windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
31 |
32 | windows = windows.transpose(2,3).transpose(1,2)
33 | return windows
34 |
35 | def window_reverse(windows, window_size, H, W):
36 | """
37 | Args:
38 | windows: (num_windows*B, C, window_size, window_size)
39 | window_size (int): Window size
40 | H (int): Height of image
41 | W (int): Width of image
42 | Returns:
43 | x: (B, C, H, W)
44 | """
45 | windows = windows.transpose(1,2).transpose(2,3)
46 | B = int(windows.shape[0] / (H * W / window_size / window_size))
47 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
48 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
49 | x = x.transpose(2,3).transpose(1,2)
50 | return x
51 |
52 |
53 | class PatchEmbed(nn.Module):
54 | def __init__(self, patch_size=4, in_chans=3, embed_dim=96):#, norm_layer=None):
55 | super().__init__()
56 | patch_size = to_2tuple(patch_size)
57 | self.patch_size = patch_size
58 |
59 | self.in_chans = in_chans
60 | self.embed_dim = embed_dim
61 |
62 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
63 | self.norm = nn.BatchNorm2d(embed_dim)
64 |
65 | def forward(self, x):
66 | """Forward function."""
67 | # padding
68 | _, _, H, W = x.size()
69 | if W % self.patch_size[1] != 0:
70 | x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
71 | if H % self.patch_size[0] != 0:
72 | x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
73 |
74 | x = self.proj(x) # B C Wh Ww
75 | x = self.norm(x)
76 |
77 | return x
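78 |
79 |
80 | if __name__ == '__main__':
81 |     # Round-trip sanity sketch for the window helpers: partition a feature
82 |     # map into 8x8 windows, then reverse; the composition is the identity.
83 |     x = torch.randn(2, 16, 32, 32)            # (B, C, H, W)
84 |     win = window_partition(x, window_size=8)  # (num_windows*B, C, 8, 8)
85 |     y = window_reverse(win, 8, 32, 32)        # back to (B, C, H, W)
86 |     print(win.shape, torch.allclose(x, y))    # torch.Size([32, 16, 8, 8]) True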
--------------------------------------------------------------------------------
/src/gcn_lib/pos_embed.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 |
4 | import torch
5 |
6 | # --------------------------------------------------------
7 | # relative position embedding
8 | # References: https://arxiv.org/abs/2009.13658
9 | # --------------------------------------------------------
10 | def get_2d_relative_pos_embed(embed_dim, grid_size):
11 | """
12 | grid_size: int of the grid height and width
13 | return:
14 | pos_embed: [grid_size*grid_size, grid_size*grid_size]
15 | """
16 | pos_embed = get_2d_sincos_pos_embed(embed_dim, grid_size)
17 | relative_pos = 2 * np.matmul(pos_embed, pos_embed.transpose()) / pos_embed.shape[1]
18 | return relative_pos
19 |
20 |
21 | # --------------------------------------------------------
22 | # 2D sine-cosine position embedding
23 | # References:
24 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
25 | # MoCo v3: https://github.com/facebookresearch/moco-v3
26 | # --------------------------------------------------------
27 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
28 | """
29 | grid_size: int of the grid height and width
30 | return:
31 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
32 | """
33 | grid_h = np.arange(grid_size, dtype=np.float32)
34 | grid_w = np.arange(grid_size, dtype=np.float32)
35 | grid = np.meshgrid(grid_w, grid_h) # here w goes first
36 | grid = np.stack(grid, axis=0)
37 |
38 | grid = grid.reshape([2, 1, grid_size, grid_size])
39 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
40 | if cls_token:
41 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
42 | return pos_embed
43 |
44 |
45 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
46 | assert embed_dim % 2 == 0
47 |
48 | # use half of dimensions to encode grid_h
49 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
50 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
51 |
52 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
53 | return emb
54 |
55 |
56 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
57 | """
58 | embed_dim: output dimension for each position
59 | pos: a list of positions to be encoded: size (M,)
60 | out: (M, D)
61 | """
62 | assert embed_dim % 2 == 0
63 | omega = np.arange(embed_dim // 2, dtype=np.float64)  # np.float was removed in NumPy >= 1.24
64 | omega /= embed_dim / 2.
65 | omega = 1. / 10000**omega # (D/2,)
66 |
67 | pos = pos.reshape(-1) # (M,)
68 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
69 |
70 | emb_sin = np.sin(out) # (M, D/2)
71 | emb_cos = np.cos(out) # (M, D/2)
72 |
73 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
74 | return emb
75 |
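76 |
77 | if __name__ == '__main__':
78 |     # Shape sanity sketch: the sin-cos table is (grid*grid, embed_dim) and
79 |     # the relative-position matrix is (grid*grid, grid*grid).
80 |     pe = get_2d_sincos_pos_embed(embed_dim=64, grid_size=8)
81 |     rel = get_2d_relative_pos_embed(embed_dim=64, grid_size=8)
82 |     print(pe.shape, rel.shape)  # (64, 64) (64, 64)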
--------------------------------------------------------------------------------
/src/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.distributed as dist
3 | from enum import Enum
4 |
5 | class Summary(Enum):
6 | NONE = 0
7 | AVERAGE = 1
8 | SUM = 2
9 | COUNT = 3
10 |
11 | class AverageMeter(object):
12 | """Computes and stores the average and current value"""
13 | def __init__(self, name, fmt=':f', summary_type=Summary.AVERAGE):
14 | self.name = name
15 | self.fmt = fmt
16 | self.summary_type = summary_type
17 | self.reset()
18 |
19 | def reset(self):
20 | self.val = 0
21 | self.avg = 0
22 | self.sum = 0
23 | self.count = 0
24 |
25 | def update(self, val, n=1):
26 | self.val = val
27 | self.sum += val * n
28 | self.count += n
29 | self.avg = self.sum / self.count
30 |
31 | def all_reduce(self):
32 | if torch.cuda.is_available():
33 | device = torch.device("cuda")
34 | elif torch.backends.mps.is_available():
35 | device = torch.device("mps")
36 | else:
37 | device = torch.device("cpu")
38 | total = torch.tensor([self.sum, self.count], dtype=torch.float32, device=device)
39 | dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False)
40 | self.sum, self.count = total.tolist()
41 | self.avg = self.sum / self.count
42 |
43 | def __str__(self):
44 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
45 | return fmtstr.format(**self.__dict__)
46 |
47 | def summary(self):
48 | fmtstr = ''
49 | if self.summary_type is Summary.NONE:
50 | fmtstr = ''
51 | elif self.summary_type is Summary.AVERAGE:
52 | fmtstr = '{name} {avg:.3f}'
53 | elif self.summary_type is Summary.SUM:
54 | fmtstr = '{name} {sum:.3f}'
55 | elif self.summary_type is Summary.COUNT:
56 | fmtstr = '{name} {count:.3f}'
57 | else:
58 | raise ValueError('invalid summary type %r' % self.summary_type)
59 |
60 | return fmtstr.format(**self.__dict__)
61 |
62 |
63 | class ProgressMeter(object):
64 | def __init__(self, num_batches, meters, prefix=""):
65 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
66 | self.meters = meters
67 | self.prefix = prefix
68 |
69 | def display(self, batch):
70 | entries = [self.prefix + self.batch_fmtstr.format(batch)]
71 | entries += [str(meter) for meter in self.meters]
72 | print('\t'.join(entries))
73 |
74 | def display_summary(self):
75 | entries = [" *"]
76 | entries += [meter.summary() for meter in self.meters]
77 | print(' '.join(entries))
78 |
79 | def _get_batch_fmtstr(self, num_batches):
80 | num_digits = len(str(num_batches))
81 | fmt = '{:' + str(num_digits) + 'd}'
82 | return '[' + fmt + '/' + fmt.format(num_batches) + ']'
83 |
84 | def accuracy(output, target, topk=(1,)):
85 | """Computes the accuracy over the k top predictions for the specified values of k"""
86 | with torch.no_grad():
87 | maxk = max(topk)
88 | batch_size = target.size(0)
89 |
90 | _, pred = output.topk(maxk, 1, True, True)
91 | pred = pred.t()
92 | correct = pred.eq(target.view(1, -1).expand_as(pred))
93 |
94 | res = []
95 | for k in topk:
96 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
97 | res.append(correct_k.mul_(100.0 / batch_size))
98 | return res
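99 |
100 |
101 | if __name__ == '__main__':
102 |     # Usage sketch: top-1/top-5 accuracy on a random batch, tracked by AverageMeter.
103 |     logits = torch.randn(8, 10)
104 |     target = torch.randint(0, 10, (8,))
105 |     top1, top5 = accuracy(logits, target, topk=(1, 5))
106 |     meter = AverageMeter('Acc@1', ':6.2f')
107 |     meter.update(top1.item(), n=target.size(0))
108 |     print(meter)  # e.g. "Acc@1  12.50 ( 12.50)"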
--------------------------------------------------------------------------------
/src/model/wignn_256.py:
--------------------------------------------------------------------------------
1 | from timm.models.registry import register_model
2 | from model.wignn import DeepGCN
3 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
4 | from timm.models import create_model
5 | from torchprofile import profile_macs
6 |
7 | import torch
8 |
9 | def _cfg(url='', **kwargs):
10 | return {
11 | 'url': url,
12 | 'num_classes': 1000, 'input_size': (3, 256, 256), 'pool_size': None,
13 | 'crop_pct': .9, 'interpolation': 'bicubic',
14 | 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
15 | 'first_conv': 'patch_embed.proj', 'classifier': 'head',
16 | **kwargs
17 | }
18 |
19 |
20 | default_cfgs = {
21 | 'wignn_256_gelu': _cfg(
22 | mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
23 | ),
24 | 'wignn_b_256_gelu': _cfg(
25 | crop_pct=0.95, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
26 | ),
27 | }
28 |
29 |
30 | class OptInit:
31 | def __init__(
32 | self,
33 | num_classes=1000,
34 | drop_path_rate=0.0,
35 | knn = 9,
36 | use_shifts = True,
37 | use_reduce_ratios = False,
38 | img_size = 256,
39 | adapt_knn = True,
40 |
41 | channels = None,
42 | blocks = None,
43 | **kwargs):
44 |
45 | self.k = knn # neighbor num (default:9)
46 | self.conv = 'mr' # graph conv layer {edge, mr}
47 | self.act = 'gelu' # activation layer {relu, prelu, leakyrelu, gelu, hswish}
48 | self.norm = 'batch' # batch or instance normalization {batch, instance}
49 | self.bias = True # bias of conv layer True or False
50 | self.dropout = 0.0 # dropout rate
51 | self.use_dilation = True # use dilated knn or not
52 | self.epsilon = 0.2 # stochastic epsilon for gcn
53 | self.use_stochastic = False # stochastic for gcn, True or False
54 | self.drop_path = drop_path_rate
55 | self.blocks = blocks # number of basic blocks in the backbone
56 | self.channels = channels # number of channels of deep features
57 | self.n_classes = num_classes # Dimension of out_channels
58 | self.emb_dims = 1024 # Dimension of embeddings
59 | self.windows_size = 8
60 | self.use_shifts = use_shifts
61 | self.img_size = img_size
62 | self.use_reduce_ratios = use_reduce_ratios
63 | self.adapt_knn = adapt_knn
64 |
65 | @register_model
66 | def wignn_ti_256_gelu(pretrained=False, **kwargs):
67 |
68 | opt = OptInit(**kwargs, channels = [48, 96, 240, 384], blocks= [2,2,6,2])
69 | model = DeepGCN(opt)
70 | model.default_cfg = default_cfgs['wignn_256_gelu']
71 | return model
72 |
73 |
74 | @register_model
75 | def wignn_s_256_gelu(pretrained=False, **kwargs):
76 |
77 | opt = OptInit(**kwargs, channels = [80, 160, 400, 640], blocks= [2,2,6,2])
78 | model = DeepGCN(opt)
79 | model.default_cfg = default_cfgs['wignn_256_gelu']
80 | return model
81 |
82 |
83 | @register_model
84 | def wignn_m_256_gelu(pretrained=False, **kwargs):
85 |
86 | opt = OptInit(**kwargs, channels = [96, 192, 384, 768], blocks= [2,2,16,2])
87 |
88 | model = DeepGCN(opt)
89 | model.default_cfg = default_cfgs['wignn_256_gelu']
90 | return model
91 |
92 |
93 | @register_model
94 | def wignn_b_256_gelu(pretrained=False, **kwargs):
95 |
96 | opt = OptInit(**kwargs, channels = [128, 256, 512, 1024], blocks= [2,2,18,2])
97 |
98 | model = DeepGCN(opt)
99 | model.default_cfg = default_cfgs['wignn_b_256_gelu']
100 | return model
101 |
102 |
103 |
104 | if __name__ == '__main__':
105 |
106 | pass
107 |
108 |
109 |
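110 |     # Hypothetical smoke test, assuming model/wignn.py provides DeepGCN
111 |     # (uncomment to run; uses timm's create_model and torchprofile imported above):
112 |     # model = create_model('wignn_ti_256_gelu')
113 |     # x = torch.randn(1, 3, 256, 256)
114 |     # print(model(x).shape)          # expected: torch.Size([1, 1000])
115 |     # print(profile_macs(model, x))  # MACs, as in the README complexity check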
--------------------------------------------------------------------------------
/src/gcn_lib/torch_nn.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | from torch import nn
4 | from torch.nn import Sequential as Seq, Linear as Lin, Conv2d
5 | from einops import rearrange
6 |
7 |
8 | ##############################
9 | # Basic layers
10 | ##############################
11 | def act_layer(act, inplace=False, neg_slope=0.2, n_prelu=1):
12 | # activation layer
13 |
14 | act = act.lower()
15 | if act == 'relu':
16 | layer = nn.ReLU(inplace)
17 | elif act == 'leakyrelu':
18 | layer = nn.LeakyReLU(neg_slope, inplace)
19 | elif act == 'prelu':
20 | layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
21 | elif act == 'gelu':
22 | layer = nn.GELU()
23 | elif act == 'hswish':
24 | layer = nn.Hardswish(inplace)
25 | else:
26 | raise NotImplementedError('activation layer [%s] is not found' % act)
27 | return layer
28 |
29 |
30 | def norm_layer(norm, nc):
31 | # normalization layer 2d
32 | norm = norm.lower()
33 | if norm == 'batch':
34 | layer = nn.BatchNorm2d(nc, affine=True)
35 | elif norm == 'instance':
36 | layer = nn.InstanceNorm2d(nc, affine=False)
37 | else:
38 | raise NotImplementedError('normalization layer [%s] is not found' % norm)
39 | return layer
40 |
41 |
42 | class MLP(Seq):
43 | def __init__(self, channels, act='relu', norm=None, bias=True):
44 | m = []
45 | for i in range(1, len(channels)):
46 | m.append(Lin(channels[i - 1], channels[i], bias))
47 | if act is not None and act.lower() != 'none':
48 | m.append(act_layer(act))
49 | if norm is not None and norm.lower() != 'none':
50 | m.append(norm_layer(norm, channels[-1]))
51 | super(MLP, self).__init__(*m)
52 |
53 |
54 | class BasicConv(Seq):
55 | def __init__(self, channels, act='relu', norm=None, bias=True, drop=0.):
56 | m = []
57 | for i in range(1, len(channels)):
58 | m.append(Conv2d(channels[i - 1], channels[i], 1, bias=bias, groups=4))
59 | if norm is not None and norm.lower() != 'none':
60 | m.append(norm_layer(norm, channels[-1]))
61 | if act is not None and act.lower() != 'none':
62 | m.append(act_layer(act))
63 | if drop > 0:
64 | m.append(nn.Dropout2d(drop))
65 |
66 | super(BasicConv, self).__init__(*m)
67 |
68 | self.reset_parameters()
69 |
70 | def reset_parameters(self):
71 | for m in self.modules():
72 | if isinstance(m, nn.Conv2d):
73 | nn.init.kaiming_normal_(m.weight)
74 | if m.bias is not None:
75 | nn.init.zeros_(m.bias)
76 | elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.InstanceNorm2d):
77 | m.weight.data.fill_(1)
78 | m.bias.data.zero_()
79 |
80 |
81 | def batched_index_select(x, idx):
82 | r"""fetches neighbors features from a given neighbor idx
83 |
84 | Args:
85 | x (Tensor): input feature Tensor
86 | :math:`\mathbf{X} \in \mathbb{R}^{B \times C \times N \times 1}`.
87 | idx (Tensor): edge_idx
88 | :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times l}`.
89 | Returns:
90 | Tensor: output neighbors features
91 | :math:`\mathbf{X} \in \mathbb{R}^{B \times C \times N \times k}`.
92 | """
93 | batch_size, num_dims, num_vertices_reduced = x.shape[:3]
94 | _, num_vertices, k = idx.shape
95 | idx_base = torch.arange(0, batch_size, device=idx.device).view(-1, 1, 1) * num_vertices_reduced
96 | idx = idx + idx_base
97 | idx = idx.contiguous().view(-1)
98 |
99 | x = x.transpose(2, 1)
100 | feature = x.contiguous().view(batch_size * num_vertices_reduced, -1)[idx, :]
101 | feature = feature.view(batch_size, num_vertices, k, num_dims).permute(0, 3, 1, 2).contiguous()
102 | return feature
103 |
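104 |
105 | if __name__ == '__main__':
106 |     # Sketch: gather k=3 neighbor features per vertex with batched_index_select.
107 |     x = torch.randn(2, 8, 16, 1)               # (B, C, N, 1) vertex features
108 |     idx = torch.randint(0, 16, (2, 16, 3))     # (B, N, k) neighbor indices
109 |     print(batched_index_select(x, idx).shape)  # torch.Size([2, 8, 16, 3])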
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # WiGNet: Windowed Vision Graph Neural Network
2 |
3 | PyTorch implementation of the paper "**WiGNet: Windowed Vision Graph Neural Network**", published at WACV 2025. This repository is based on [VisionGNN](https://github.com/jichengyuan/Vision_GNN).
4 |
5 | [ArXiv](https://arxiv.org/abs/2410.00807)
6 |
7 |
8 |

9 |
10 |
11 | ## Abstract
12 | In recent years, Graph Neural Networks (GNNs) have demonstrated strong adaptability to various real-world challenges, with architectures such as Vision GNN (ViG) achieving state-of-the-art performance in several computer vision tasks. However, their practical applicability is hindered by the computational complexity of constructing the graph, which scales quadratically with the image size. In this paper, we introduce a novel Windowed vision Graph neural Network (WiGNet) model for efficient image processing. WiGNet explores a different strategy from previous works by partitioning the image into windows and constructing a graph within each window. Therefore, our model uses graph convolutions instead of the typical 2D convolution or self-attention mechanism. WiGNet effectively manages computational and memory complexity for large image sizes. We evaluate our method on the ImageNet-1k benchmark dataset and test the adaptability of WiGNet using the CelebA-HQ dataset as a downstream task with higher-resolution images. In both of these scenarios, our method achieves competitive results compared to previous vision GNNs while keeping memory and computational complexity at bay. WiGNet offers a promising solution toward the deployment of vision GNNs in real-world applications.
13 |
14 |
15 |
16 |

17 |
18 |
19 |
20 | ## Usage
21 |
22 | Download our pretrained model from [here](https://drive.google.com/file/d/11bDJaiYxCIwG2OxapIJSkQZys38wDI4S/view?usp=sharing).
23 |
24 | ### Environment
25 | - conda env create -f env_wignet.yaml
26 | - conda activate wignet
27 |
28 | ### ImageNet Classification
29 |
30 |
31 | - Evaluation
32 | ```
33 | python train.py --model wignn_ti_256_gelu \
34 | --img-size 256 \
35 | --knn 9 \
36 | --use-shift 1 \
37 | --adapt-knn 1 \
38 | --data /path/to/imagenet \
39 | -b 128 \
40 | --resume /path/to/checkpoint.pth.tar \
41 | --evaluate
42 | ```
43 |
44 | - Training WiGNet-Ti on 8 GPUs
45 | ```
46 | python -m torch.distributed.launch \
47 | --nproc_per_node=8 train.py \
48 | --model wignn_ti_256_gelu \
49 | --img-size 256 \
50 | --knn 9 \
51 | --use-shift 1 \
52 | --adapt-knn 1 \
53 | --use-reduce-ratios 0 \
54 | --data /path/to/imagenet \
55 | --sched cosine \
56 | --epochs 300 \
57 | --opt adamw -j 8 \
58 | --warmup-lr 1e-6 \
59 | --mixup .8 \
60 | --cutmix 1.0 \
61 | --model-ema \
62 | --model-ema-decay 0.99996 \
63 | --aa rand-m9-mstd0.5-inc1 \
64 | --color-jitter 0.4 \
65 | --warmup-epochs 20 \
66 | --opt-eps 1e-8 \
67 | --remode pixel \
68 | --reprob 0.25 \
69 | --amp \
70 | --lr 2e-3 \
71 | --weight-decay .05 \
72 | --drop 0 \
73 | --drop-path .1 \
74 | -b 128 \
75 | --output /path/to/save
76 | ```
77 |
78 |
79 |
80 | ### Complexity Evaluation
81 |
82 | **Memory & MACs**
83 | - WiGNet
84 | ```
85 | python -m model.wignn
86 | ```
87 |
88 | - ViG
89 | ```
90 | python -m model.pyramid_vig
91 | ```
92 |
93 | - GreedyViG
94 | ```
95 | python -m model.greedyvig
96 | ```
97 |
98 | - MobileViG
99 | ```
100 | python -m model.mobilevig
101 | ```
102 |
103 |
104 |
105 |
106 | ### Transfer Learning
107 | - WiGNet
108 | ```
109 | python train_trasnfer_learning.py \
110 | --model-type wignn_ti_256_gelu \
111 | --use-shift 1 \
112 | --adapt-knn 1 \
113 | --batch-size 64 \
114 | --checkpoint /path/to/checkpoint.pth.tar \
115 | --crop-size 512 \
116 | --dataset CelebA \
117 | --epochs 30 \
118 | --freezed 1 \
119 | --loss cross_entropy \
120 | --lr 0.001 \
121 | --lr-scheduler constant \
122 | --opt adam \
123 | --root /path/to/save/dataset \
124 | --save-dir /path/to/save/outputs_tl_high_res/ \
125 | --seed 1
126 | ```
127 |
128 | For ViG, include `--num-gpu 8`.
129 |
130 | ## Results
131 |
132 | ### ImageNet-1k
133 |
134 |

135 |
136 |
137 | ### CelebA-HQ
138 |
139 |

140 |
141 |
142 |
143 | ## Citation
144 | If you use our code, please cite:
145 |
146 | ```
147 | @inproceedings{spadaro2024wignet,
148 | title={{W}i{GN}et: {W}indowed {V}ision {G}raph {N}eural {N}etwork},
149 | author={Spadaro, Gabriele and Grangetto, Marco and Fiandrotti, Attilio and Tartaglione, Enzo and Giraldo, Jhony H},
150 | booktitle = {Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision (WACV)},
151 | year={2025}
152 | }
153 | ```
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Custom ignores
2 | wandb
3 |
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | pip-wheel-metadata/
27 | share/python-wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | MANIFEST
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 |
43 | # Unit test / coverage reports
44 | htmlcov/
45 | .tox/
46 | .nox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | *.py,cover
54 | .hypothesis/
55 | .pytest_cache/
56 |
57 | # Translations
58 | *.mo
59 | *.pot
60 |
61 | # Django stuff:
62 | *.log
63 | local_settings.py
64 | db.sqlite3
65 | db.sqlite3-journal
66 |
67 | # Flask stuff:
68 | instance/
69 | .webassets-cache
70 |
71 | # Scrapy stuff:
72 | .scrapy
73 |
74 | # Sphinx documentation
75 | docs/_build/
76 |
77 | # PyBuilder
78 | target/
79 |
80 | # Jupyter Notebook
81 | .ipynb_checkpoints
82 |
83 | # IPython
84 | profile_default/
85 | ipython_config.py
86 |
87 | # pyenv
88 | .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98 | __pypackages__/
99 |
100 | # Celery stuff
101 | celerybeat-schedule
102 | celerybeat.pid
103 |
104 | # SageMath parsed files
105 | *.sage.py
106 |
107 | # Environments
108 | .env
109 | .venv
110 | env/
111 | venv/
112 | ENV/
113 | env.bak/
114 | venv.bak/
115 |
116 | # Spyder project settings
117 | .spyderproject
118 | .spyproject
119 |
120 | # Rope project settings
121 | .ropeproject
122 |
123 | # mkdocs documentation
124 | /site
125 |
126 | # mypy
127 | .mypy_cache/
128 | .dmypy.json
129 | dmypy.json
130 |
131 | # Pyre type checker
132 | .pyre/
133 |
134 | # pytype static type analyzer
135 | .pytype/
136 |
137 | # Cython debug symbols
138 | cython_debug/
139 |
140 | ### Windows template
141 | # Windows thumbnail cache files
142 | Thumbs.db
143 | Thumbs.db:encryptable
144 | ehthumbs.db
145 | ehthumbs_vista.db
146 |
147 | # Dump file
148 | *.stackdump
149 |
150 | # Folder config file
151 | [Dd]esktop.ini
152 |
153 | # Recycle Bin used on file shares
154 | $RECYCLE.BIN/
155 |
156 | # Windows Installer files
157 | *.cab
158 | *.msi
159 | *.msix
160 | *.msm
161 | *.msp
162 |
163 | # Windows shortcuts
164 | *.lnk
165 |
166 | ### JetBrains template
167 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
168 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
169 |
170 | # User-specific stuff
171 | .idea/**/workspace.xml
172 | .idea/**/tasks.xml
173 | .idea/**/usage.statistics.xml
174 | .idea/**/dictionaries
175 | .idea/**/shelf
176 |
177 | # Generated files
178 | .idea/**/contentModel.xml
179 |
180 | # Sensitive or high-churn files
181 | .idea/**/dataSources/
182 | .idea/**/dataSources.ids
183 | .idea/**/dataSources.local.xml
184 | .idea/**/sqlDataSources.xml
185 | .idea/**/dynamic.xml
186 | .idea/**/uiDesigner.xml
187 | .idea/**/dbnavigator.xml
188 |
189 | # Gradle
190 | .idea/**/gradle.xml
191 | .idea/**/libraries
192 |
193 | # Gradle and Maven with auto-import
194 | # When using Gradle or Maven with auto-import, you should exclude module files,
195 | # since they will be recreated, and may cause churn. Uncomment if using
196 | # auto-import.
197 | # .idea/artifacts
198 | # .idea/compiler.xml
199 | # .idea/jarRepositories.xml
200 | # .idea/modules.xml
201 | # .idea/*.iml
202 | # .idea/modules
203 | # *.iml
204 | # *.ipr
205 |
206 | # CMake
207 | cmake-build-*/
208 |
209 | # Mongo Explorer plugin
210 | .idea/**/mongoSettings.xml
211 |
212 | # File-based project format
213 | *.iws
214 |
215 | # IntelliJ
216 | out/
217 |
218 | # mpeltonen/sbt-idea plugin
219 | .idea_modules/
220 |
221 | # JIRA plugin
222 | atlassian-ide-plugin.xml
223 |
224 | # Cursive Clojure plugin
225 | .idea/replstate.xml
226 |
227 | # Crashlytics plugin (for Android Studio and IntelliJ)
228 | com_crashlytics_export_strings.xml
229 | crashlytics.properties
230 | crashlytics-build.properties
231 | fabric.properties
232 |
233 | # Editor-based Rest Client
234 | .idea/httpRequests
235 |
236 | # Android studio 3.1+ serialized cache file
237 | .idea/caches/build_file_checksums.ser
238 |
239 | ### VisualStudioCode template
240 | .vscode/*
241 | !.vscode/settings.json
242 | !.vscode/tasks.json
243 | !.vscode/launch.json
244 | !.vscode/extensions.json
245 | *.code-workspace
246 |
247 | # Local History for Visual Studio Code
248 | .history/
249 |
250 | .idea/
251 |
252 |
--------------------------------------------------------------------------------
/src/data/myloader.py:
--------------------------------------------------------------------------------
1 |
2 | """ Loader Factory, Fast Collate, CUDA Prefetcher
3 |
4 | Prefetcher and Fast Collate inspired by NVIDIA APEX example at
5 | https://github.com/NVIDIA/apex/commit/d5e2bb4bdeedd27b1dfaf5bb2b24d6c000dee9be#diff-cf86c282ff7fba81fad27a559379d5bf
6 |
7 | Hacked together by / Copyright 2020 Ross Wightman
8 | """
9 |
10 | import torch.utils.data
11 | import torch.distributed as dist
12 | import numpy as np
13 |
14 | from timm.data.transforms_factory import create_transform
15 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
16 | from timm.data.distributed_sampler import OrderedDistributedSampler
17 | from timm.data.random_erasing import RandomErasing
18 | from timm.data.mixup import FastCollateMixup
19 | from timm.data.loader import fast_collate, PrefetchLoader, MultiEpochsDataLoader
20 |
21 | from .rasampler import RASampler
22 |
23 |
24 | def is_dist_avail_and_initialized():
25 | if not dist.is_available():
26 | return False
27 | if not dist.is_initialized():
28 | return False
29 | return True
30 |
31 |
32 | def get_world_size():
33 | if not is_dist_avail_and_initialized():
34 | return 1
35 | return dist.get_world_size()
36 |
37 |
38 | def get_rank():
39 | if not is_dist_avail_and_initialized():
40 | return 0
41 | return dist.get_rank()
42 |
43 |
44 | def create_loader(
45 | dataset,
46 | input_size,
47 | batch_size,
48 | is_training=False,
49 | use_prefetcher=True,
50 | no_aug=False,
51 | re_prob=0.,
52 | re_mode='const',
53 | re_count=1,
54 | re_split=False,
55 | scale=None,
56 | ratio=None,
57 | hflip=0.5,
58 | vflip=0.,
59 | color_jitter=0.4,
60 | auto_augment=None,
61 | num_aug_splits=0,
62 | interpolation='bilinear',
63 | mean=IMAGENET_DEFAULT_MEAN,
64 | std=IMAGENET_DEFAULT_STD,
65 | num_workers=1,
66 | distributed=False,
67 | crop_pct=None,
68 | collate_fn=None,
69 | pin_memory=False,
70 | fp16=False,
71 | tf_preprocessing=False,
72 | use_multi_epochs_loader=False,
73 | repeated_aug=False
74 | ):
75 | re_num_splits = 0
76 | if re_split:
77 | # apply RE to second half of batch if no aug split otherwise line up with aug split
78 | re_num_splits = num_aug_splits or 2
79 | dataset.transform = create_transform(
80 | input_size,
81 | is_training=is_training,
82 | use_prefetcher=use_prefetcher,
83 | no_aug=no_aug,
84 | scale=scale,
85 | ratio=ratio,
86 | hflip=hflip,
87 | vflip=vflip,
88 | color_jitter=color_jitter,
89 | auto_augment=auto_augment,
90 | interpolation=interpolation,
91 | mean=mean,
92 | std=std,
93 | crop_pct=crop_pct,
94 | tf_preprocessing=tf_preprocessing,
95 | re_prob=re_prob,
96 | re_mode=re_mode,
97 | re_count=re_count,
98 | re_num_splits=re_num_splits,
99 | separate=num_aug_splits > 0,
100 | )
101 |
102 | sampler = None
103 | if distributed:
104 | if is_training:
105 | if repeated_aug:
106 | print('using repeated_aug')
107 | num_tasks = get_world_size()
108 | global_rank = get_rank()
109 | sampler = RASampler(
110 | dataset, num_replicas=num_tasks, rank=global_rank, shuffle=True
111 | )
112 | else:
113 | sampler = torch.utils.data.distributed.DistributedSampler(dataset)
114 | else:
115 | # This will add extra duplicate entries to result in equal num
116 | # of samples per-process, will slightly alter validation results
117 | sampler = OrderedDistributedSampler(dataset)
118 | else:
119 | if is_training and repeated_aug:
120 | print('using repeated_aug')
121 | num_tasks = get_world_size()
122 | global_rank = get_rank()
123 | sampler = RASampler(
124 | dataset, num_replicas=num_tasks, rank=global_rank, shuffle=True
125 | )
126 |
127 | if collate_fn is None:
128 | collate_fn = fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate
129 |
130 | loader_class = torch.utils.data.DataLoader
131 |
132 | if use_multi_epochs_loader:
133 | loader_class = MultiEpochsDataLoader
134 |
135 | loader = loader_class(
136 | dataset,
137 | batch_size=batch_size,
138 | shuffle=sampler is None and is_training,
139 | num_workers=num_workers,
140 | sampler=sampler,
141 | collate_fn=collate_fn,
142 | pin_memory=pin_memory,
143 | drop_last=is_training,
144 | )
145 | if use_prefetcher:
146 | prefetch_re_prob = re_prob if is_training and not no_aug else 0.
147 | loader = PrefetchLoader(
148 | loader,
149 | mean=mean,
150 | std=std,
151 | fp16=fp16,
152 | re_prob=prefetch_re_prob,
153 | re_mode=re_mode,
154 | re_count=re_count,
155 | re_num_splits=re_num_splits
156 | )
157 |
158 | return loader
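159 |
160 |
161 | if __name__ == '__main__':
162 |     # Smoke-test sketch on synthetic data (hypothetical; no ImageNet needed).
163 |     # The prefetcher is disabled so no GPU is required.
164 |     from torchvision.datasets import FakeData
165 |     dataset = FakeData(size=64, image_size=(3, 256, 256))
166 |     loader = create_loader(dataset, input_size=(3, 256, 256), batch_size=16,
167 |                            is_training=True, use_prefetcher=False, num_workers=0)
168 |     images, targets = next(iter(loader))
169 |     print(images.shape, targets.shape)  # torch.Size([16, 3, 256, 256]) torch.Size([16])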
--------------------------------------------------------------------------------
/src/gcn_lib/torch_edge.py:
--------------------------------------------------------------------------------
1 |
2 | import math
3 | import torch
4 | from torch import nn
5 | import torch.nn.functional as F
6 | import sys
7 |
8 | def pairwise_distance(x):
9 | """
10 | Compute pairwise distance of a point cloud.
11 | Args:
12 | x: tensor (batch_size, num_points, num_dims)
13 | Returns:
14 | pairwise distance: (batch_size, num_points, num_points)
15 | """
16 | with torch.no_grad():
17 | x_inner = -2*torch.matmul(x, x.transpose(2, 1))
18 | x_square = torch.sum(torch.mul(x, x), dim=-1, keepdim=True)
19 | return x_square + x_inner + x_square.transpose(2, 1)
20 |
21 |
22 | def part_pairwise_distance(x, start_idx=0, end_idx=1):
23 | """
24 | Compute pairwise distance of a point cloud.
25 | Args:
26 | x: tensor (batch_size, num_points, num_dims)
27 | Returns:
28 | pairwise distance: (batch_size, num_points, num_points)
29 | """
30 | with torch.no_grad():
31 | x_part = x[:, start_idx:end_idx]
32 | x_square_part = torch.sum(torch.mul(x_part, x_part), dim=-1, keepdim=True)
33 | x_inner = -2*torch.matmul(x_part, x.transpose(2, 1))
34 | x_square = torch.sum(torch.mul(x, x), dim=-1, keepdim=True)
35 | return x_square_part + x_inner + x_square.transpose(2, 1)
36 |
37 |
38 | def xy_pairwise_distance(x, y):
39 | """
40 | Compute pairwise cross-distances between two point clouds x and y.
41 | Args:
42 | x, y: tensors (batch_size, num_points, num_dims)
43 | Returns:
44 | pairwise distance: (batch_size, num_points, num_points)
45 | """
46 | with torch.no_grad():
47 | xy_inner = -2*torch.matmul(x, y.transpose(2, 1))
48 | x_square = torch.sum(torch.mul(x, x), dim=-1, keepdim=True)
49 | y_square = torch.sum(torch.mul(y, y), dim=-1, keepdim=True)
50 | return x_square + xy_inner + y_square.transpose(2, 1)
51 |
52 |
53 | def dense_knn_matrix(x, k=16, relative_pos=None):
54 | """Get KNN based on the pairwise distance.
55 | Args:
56 | x: (batch_size, num_dims, num_points, 1)
57 | k: int
58 | Returns:
59 | nearest neighbors: (2, batch_size, num_points, k), stacking (nn_idx, center_idx)
60 | """
61 | with torch.no_grad():
62 | x = x.transpose(2, 1).squeeze(-1)
63 | batch_size, n_points, n_dims = x.shape
64 | ### memory efficient implementation ###
65 | n_part = 10000
66 | if n_points > n_part:
67 | nn_idx_list = []
68 | groups = math.ceil(n_points / n_part)
69 | for i in range(groups):
70 | start_idx = n_part * i
71 | end_idx = min(n_points, n_part * (i + 1))
72 | dist = part_pairwise_distance(x.detach(), start_idx, end_idx)
73 | if relative_pos is not None:
74 | dist += relative_pos[:, start_idx:end_idx]
75 | _, nn_idx_part = torch.topk(-dist, k=k)
76 | nn_idx_list += [nn_idx_part]
77 | nn_idx = torch.cat(nn_idx_list, dim=1)
78 | else:
79 | dist = pairwise_distance(x.detach())
80 |
81 | if relative_pos is not None:
82 |
83 | dist += relative_pos
84 |
85 | # nn_idx = torch.randint(0, n_points-1, (batch_size,n_points,k)).to(dist.device)
86 | _, nn_idx = torch.topk(-dist, k=k) # b, n, k
87 | ######
88 | center_idx = torch.arange(0, n_points, device=x.device).repeat(batch_size, k, 1).transpose(2, 1)
89 | return torch.stack((nn_idx, center_idx), dim=0)
90 |
91 |
92 | def xy_dense_knn_matrix(x, y, k=16, relative_pos=None):
93 | """Get KNN based on the pairwise distance.
94 | Args:
95 | x: (batch_size, num_dims, num_points, 1)
96 | k: int
97 | Returns:
98 | nearest neighbors: (2, batch_size, num_points, k), stacking (nn_idx, center_idx)
99 | """
100 | with torch.no_grad():
101 | x = x.transpose(2, 1).squeeze(-1)
102 | y = y.transpose(2, 1).squeeze(-1)
103 | batch_size, n_points, n_dims = x.shape
104 | dist = xy_pairwise_distance(x.detach(), y.detach())
105 | if relative_pos is not None:
106 | dist += relative_pos
107 |
108 | # nn_idx = torch.randint(1, 195, (batch_size,n_points,k)).to(dist.device)
109 | _, nn_idx = torch.topk(-dist, k=k)
110 |
111 | # print('Clamping values')
112 | # nn_idx = torch.clamp(nn_idx, min=0, max=n_points-1)
113 |
114 | center_idx = torch.arange(0, n_points, device=x.device).repeat(batch_size, k, 1).transpose(2, 1)
115 | return torch.stack((nn_idx, center_idx), dim=0)
116 |
117 |
118 | class DenseDilated(nn.Module):
119 | """
120 | Find dilated neighbor from neighbor list
121 |
122 | edge_index: (2, batch_size, num_points, k)
123 | """
124 | def __init__(self, k=9, dilation=1, stochastic=False, epsilon=0.0):
125 | super(DenseDilated, self).__init__()
126 | self.dilation = dilation
127 | self.stochastic = stochastic
128 | self.epsilon = epsilon
129 | self.k = k
130 |
131 | def forward(self, edge_index):
132 | if self.stochastic:
133 | if torch.rand(1) < self.epsilon and self.training:
134 | num = self.k * self.dilation
135 | randnum = torch.randperm(num)[:self.k]
136 | edge_index = edge_index[:, :, :, randnum]
137 | else:
138 | edge_index = edge_index[:, :, :, ::self.dilation]
139 | else:
140 | edge_index = edge_index[:, :, :, ::self.dilation]
141 | return edge_index
142 |
143 |
144 | class DenseDilatedKnnGraph(nn.Module):
145 | """
146 | Find the neighbors' indices based on dilated knn
147 | """
148 | def __init__(self, k=9, dilation=1, stochastic=False, epsilon=0.0):
149 | super(DenseDilatedKnnGraph, self).__init__()
150 | self.dilation = dilation
151 | self.stochastic = stochastic
152 | self.epsilon = epsilon
153 | self.k = k
154 | self._dilated = DenseDilated(k, dilation, stochastic, epsilon)
155 |
156 | def forward(self, x, y=None, relative_pos=None):
157 | if y is not None:
158 | #### normalize
159 | x = F.normalize(x, p=2.0, dim=1)
160 | y = F.normalize(y, p=2.0, dim=1)
161 | ####
162 | edge_index = xy_dense_knn_matrix(x, y, self.k * self.dilation, relative_pos)
163 | else:
164 | #### normalize
165 | x = F.normalize(x, p=2.0, dim=1)
166 | ####
167 | edge_index = dense_knn_matrix(x, self.k * self.dilation, relative_pos)
168 | return self._dilated(edge_index)
169 |
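170 |
171 | if __name__ == '__main__':
172 |     # Sketch: build a 9-NN graph over N=64 nodes. The returned edge_index
173 |     # stacks (neighbor_idx, center_idx) along dim 0.
174 |     x = torch.randn(2, 32, 64, 1)  # (B, C, N, 1)
175 |     graph = DenseDilatedKnnGraph(k=9, dilation=1)
176 |     print(graph(x).shape)          # torch.Size([2, 2, 64, 9])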
--------------------------------------------------------------------------------
/src/opt_transfer.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | def int2bool(i):
4 | i = int(i)
5 | assert i == 0 or i == 1
6 | return i == 1
7 |
8 |
9 | def get_args_parser(add_help=True):
10 |
11 | parser = argparse.ArgumentParser(description="Transfer Learning Training", add_help=add_help)
12 |
13 | parser.add_argument("--model-type", type=str, help="model name")
14 | parser.add_argument(
15 | "-b", "--batch-size", default=32, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
16 | )
17 | parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run")
18 | parser.add_argument(
19 | "-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)"
20 | )
21 | parser.add_argument("--opt", default="sgd", type=str, help="optimizer")
22 | parser.add_argument("--lr", default=0.1, type=float, help="initial learning rate")
23 | parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
24 | parser.add_argument(
25 | "--wd",
26 | "--weight-decay",
27 | default=1e-4,
28 | type=float,
29 | metavar="W",
30 | help="weight decay (default: 1e-4)",
31 | dest="weight_decay",
32 | )
33 | parser.add_argument(
34 | "--norm-weight-decay",
35 | default=None,
36 | type=float,
37 | help="weight decay for Normalization layers (default: None, same value as --wd)",
38 | )
39 | parser.add_argument(
40 | "--bias-weight-decay",
41 | default=None,
42 | type=float,
43 | help="weight decay for bias parameters of all layers (default: None, same value as --wd)",
44 | )
45 | parser.add_argument(
46 | "--transformer-embedding-decay",
47 | default=None,
48 | type=float,
49 | help="weight decay for embedding parameters for vision transformer models (default: None, same value as --wd)",
50 | )
51 | parser.add_argument(
52 | "--label-smoothing", default=0.0, type=float, help="label smoothing (default: 0.0)", dest="label_smoothing"
53 | )
54 | parser.add_argument("--mixup-alpha", default=0.0, type=float, help="mixup alpha (default: 0.0)")
55 | parser.add_argument("--cutmix-alpha", default=0.0, type=float, help="cutmix alpha (default: 0.0)")
56 | parser.add_argument("--lr-scheduler", default="steplr", type=str, help="the lr scheduler (default: steplr)")
57 | parser.add_argument("--lr-warmup-epochs", default=0, type=int, help="the number of epochs to warmup (default: 0)")
58 | parser.add_argument(
59 | "--lr-warmup-method", default="constant", type=str, help="the warmup method (default: constant)"
60 | )
61 | parser.add_argument("--lr-warmup-decay", default=0.01, type=float, help="the decay for lr")
62 | parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs")
63 | parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
64 | parser.add_argument("--lr-min", default=0.0, type=float, help="minimum lr of lr schedule (default: 0.0)")
65 | parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
66 | parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
67 | parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
68 | parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")
69 | parser.add_argument(
70 | "--cache-dataset",
71 | dest="cache_dataset",
72 | help="Cache the datasets for quicker initialization. It also serializes the transforms",
73 | action="store_true",
74 | )
75 | parser.add_argument(
76 | "--sync-bn",
77 | dest="sync_bn",
78 | help="Use sync batch norm",
79 | action="store_true",
80 | )
81 | parser.add_argument(
82 | "--test-only",
83 | dest="test_only",
84 | help="Only test the model",
85 | action="store_true",
86 | )
87 | parser.add_argument("--auto-augment", default=None, type=str, help="auto augment policy (default: None)")
88 | parser.add_argument("--ra-magnitude", default=9, type=int, help="magnitude of auto augment policy")
89 | parser.add_argument("--augmix-severity", default=3, type=int, help="severity of augmix policy")
90 | parser.add_argument("--random-erase", default=0.0, type=float, help="random erasing probability (default: 0.0)")
91 |
92 | # Mixed precision training parameters
93 | parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")
94 |
95 | # distributed training parameters
96 | parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
97 | parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
98 | parser.add_argument(
99 | "--model-ema", action="store_true", help="enable tracking Exponential Moving Average of model parameters"
100 | )
101 | parser.add_argument(
102 | "--model-ema-steps",
103 | type=int,
104 | default=32,
105 | help="the number of iterations that controls how often to update the EMA model (default: 32)",
106 | )
107 | parser.add_argument(
108 | "--model-ema-decay",
109 | type=float,
110 | default=0.99998,
111 | help="decay factor for Exponential Moving Average of model parameters (default: 0.99998)",
112 | )
113 | parser.add_argument(
114 | "--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
115 | )
116 | parser.add_argument(
117 | "--interpolation", default="bilinear", type=str, help="the interpolation method (default: bilinear)"
118 | )
119 |
120 | parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
121 | parser.add_argument("--ra-sampler", action="store_true", help="whether to use Repeated Augmentation in training")
122 | parser.add_argument(
123 | "--ra-reps", default=3, type=int, help="number of repetitions for Repeated Augmentation (default: 3)"
124 | )
125 | parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
126 | parser.add_argument("--backend", default="PIL", type=str.lower, help="PIL or tensor - case insensitive")
127 | parser.add_argument("--use-v2", action="store_true", help="Use V2 transforms")
128 |
129 |
130 | parser.add_argument("--seed", default=None, type=int)
131 | parser.add_argument("--save-dir", default=None, type=str)
132 |
133 | parser.add_argument("--dataset", type=str,
134 | help="Source dataset.")
135 |
136 | parser.add_argument("--checkpoint", default="", type=str, help="path of checkpoint")
137 | parser.add_argument("--root", type=str, default="/scratch/dataset",
138 | help="Dataset root folder.")
139 | parser.add_argument("--loss", default="cross_entropy", type=str, choices=["cross_entropy", "nll"])
140 |
141 |
142 | parser.add_argument('--use-shift', type=int2bool, default=0) # 1 == True
143 | parser.add_argument('--adapt-knn', type=int2bool, default=0) # 1 == True
144 | parser.add_argument('--freezed', type=int2bool, default=1) # 1 == True
145 | parser.add_argument("--crop-size", default=None, type=int)
146 |
147 |
148 | parser.add_argument("--num-gpu", default=1, type=int)
149 |
150 |
151 |
152 |
153 |
154 |
155 |
156 |
157 | return parser
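158 |
159 |
160 | if __name__ == '__main__':
161 |     # Quick sanity sketch: parse a minimal transfer-learning command line.
162 |     args = get_args_parser().parse_args(
163 |         ['--model-type', 'wignn_ti_256_gelu', '--dataset', 'CelebA', '--use-shift', '1'])
164 |     print(args.model_type, args.use_shift)  # wignn_ti_256_gelu True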
--------------------------------------------------------------------------------
/env_wignet.yaml:
--------------------------------------------------------------------------------
1 | name: wignet
2 | channels:
3 | - pytorch
4 | - nvidia
5 | - defaults
6 | dependencies:
7 | - _libgcc_mutex=0.1=main
8 | - _openmp_mutex=5.1=1_gnu
9 | - asttokens=2.0.5=pyhd3eb1b0_0
10 | - autopep8=1.6.0=pyhd3eb1b0_1
11 | - backcall=0.2.0=pyhd3eb1b0_0
12 | - beautifulsoup4=4.11.1=py310h06a4308_0
13 | - blas=1.0=mkl
14 | - brotlipy=0.7.0=py310h7f8727e_1002
15 | - bzip2=1.0.8=h7b6447c_0
16 | - c-ares=1.18.1=h7f8727e_0
17 | - ca-certificates=2023.08.22=h06a4308_0
18 | - certifi=2023.11.17=py310h06a4308_0
19 | - cffi=1.15.1=py310h74dc2b5_0
20 | - chardet=4.0.0=py310h06a4308_1003
21 | - charset-normalizer=2.0.4=pyhd3eb1b0_0
22 | - cmake=3.22.1=h1fce559_0
23 | - conda=22.11.1=py310h06a4308_4
24 | - conda-build=3.23.3=py310h06a4308_0
25 | - conda-package-handling=1.9.0=py310h5eee18b_1
26 | - cryptography=38.0.1=py310h9ce1e76_0
27 | - cuda=11.6.1=0
28 | - cuda-cccl=11.6.55=hf6102b2_0
29 | - cuda-command-line-tools=11.6.2=0
30 | - cuda-compiler=11.6.2=0
31 | - cuda-cudart=11.6.55=he381448_0
32 | - cuda-cudart-dev=11.6.55=h42ad0f4_0
33 | - cuda-cuobjdump=11.6.124=h2eeebcb_0
34 | - cuda-cupti=11.6.124=h86345e5_0
35 | - cuda-cuxxfilt=11.6.124=hecbf4f6_0
36 | - cuda-driver-dev=11.6.55=0
37 | - cuda-gdb=12.0.90=0
38 | - cuda-libraries=11.6.1=0
39 | - cuda-libraries-dev=11.6.1=0
40 | - cuda-memcheck=11.8.86=0
41 | - cuda-nsight=12.0.78=0
42 | - cuda-nsight-compute=12.0.0=0
43 | - cuda-nvcc=11.6.124=hbba6d2d_0
44 | - cuda-nvdisasm=12.0.76=0
45 | - cuda-nvml-dev=11.6.55=haa9ef22_0
46 | - cuda-nvprof=12.0.90=0
47 | - cuda-nvprune=11.6.124=he22ec0a_0
48 | - cuda-nvrtc=11.6.124=h020bade_0
49 | - cuda-nvrtc-dev=11.6.124=h249d397_0
50 | - cuda-nvtx=11.6.124=h0630a44_0
51 | - cuda-nvvp=12.0.90=0
52 | - cuda-runtime=11.6.1=0
53 | - cuda-samples=11.6.101=h8efea70_0
54 | - cuda-sanitizer-api=12.0.90=0
55 | - cuda-toolkit=11.6.1=0
56 | - cuda-tools=11.6.1=0
57 | - cuda-visual-tools=11.6.1=0
58 | - decorator=5.1.1=pyhd3eb1b0_0
59 | - executing=0.8.3=pyhd3eb1b0_0
60 | - expat=2.4.9=h6a678d5_0
61 | - ffmpeg=4.3=hf484d3e_0
62 | - filelock=3.6.0=pyhd3eb1b0_0
63 | - flit-core=3.6.0=pyhd3eb1b0_0
64 | - freetype=2.12.1=h4a9f257_0
65 | - gds-tools=1.5.0.59=0
66 | - giflib=5.2.1=h7b6447c_0
67 | - glob2=0.7=pyhd3eb1b0_0
68 | - gmp=6.2.1=h295c915_3
69 | - gnutls=3.6.15=he1e5248_0
70 | - icu=58.2=he6710b0_3
71 | - idna=3.4=py310h06a4308_0
72 | - intel-openmp=2021.4.0=h06a4308_3561
73 | - ipython=8.7.0=py310h06a4308_0
74 | - jedi=0.18.1=py310h06a4308_1
75 | - jinja2=2.11.3=pyhd3eb1b0_0
76 | - jpeg=9e=h7f8727e_0
77 | - krb5=1.19.2=hac12032_0
78 | - lame=3.100=h7b6447c_0
79 | - lcms2=2.12=h3be6417_0
80 | - ld_impl_linux-64=2.38=h1181459_1
81 | - lerc=3.0=h295c915_0
82 | - libarchive=3.6.1=hab531cd_0
83 | - libcublas=11.9.2.110=h5e84587_0
84 | - libcublas-dev=11.9.2.110=h5c901ab_0
85 | - libcufft=10.7.1.112=hf425ae0_0
86 | - libcufft-dev=10.7.1.112=ha5ce4c0_0
87 | - libcufile=1.5.0.59=0
88 | - libcufile-dev=1.5.0.59=0
89 | - libcurand=10.3.1.50=0
90 | - libcurand-dev=10.3.1.50=0
91 | - libcurl=7.86.0=h91b91d3_0
92 | - libcusolver=11.3.4.124=h33c3c4e_0
93 | - libcusparse=11.7.2.124=h7538f96_0
94 | - libcusparse-dev=11.7.2.124=hbbe9722_0
95 | - libdeflate=1.8=h7f8727e_5
96 | - libedit=3.1.20221030=h5eee18b_0
97 | - libev=4.33=h7f8727e_1
98 | - libffi=3.3=he6710b0_2
99 | - libgcc-ng=11.2.0=h1234567_1
100 | - libgomp=11.2.0=h1234567_1
101 | - libiconv=1.16=h7f8727e_2
102 | - libidn2=2.3.2=h7f8727e_0
103 | - liblief=0.12.3=h6a678d5_0
104 | - libnghttp2=1.46.0=hce63b2e_0
105 | - libnpp=11.6.3.124=hd2722f0_0
106 | - libnpp-dev=11.6.3.124=h3c42840_0
107 | - libnvjpeg=11.6.2.124=hd473ad6_0
108 | - libnvjpeg-dev=11.6.2.124=hb5906b9_0
109 | - libpng=1.6.37=hbc83047_0
110 | - libssh2=1.10.0=h8f2d780_0
111 | - libstdcxx-ng=11.2.0=h1234567_1
112 | - libtasn1=4.16.0=h27cfd23_0
113 | - libtiff=4.4.0=hecacb30_2
114 | - libunistring=0.9.10=h27cfd23_0
115 | - libuuid=1.41.5=h5eee18b_0
116 | - libuv=1.40.0=h7b6447c_0
117 | - libwebp=1.2.4=h11a3e52_0
118 | - libwebp-base=1.2.4=h5eee18b_0
119 | - libxml2=2.9.14=h74e7548_0
120 | - lz4-c=1.9.4=h6a678d5_0
121 | - markupsafe=2.0.1=py310h7f8727e_0
122 | - matplotlib-inline=0.1.6=py310h06a4308_0
123 | - mkl=2021.4.0=h06a4308_640
124 | - mkl-service=2.4.0=py310h7f8727e_0
125 | - mkl_fft=1.3.1=py310hd6ae3a3_0
126 | - mkl_random=1.2.2=py310h00e6091_0
127 | - ncurses=6.3=h5eee18b_3
128 | - nettle=3.7.3=hbbd107a_1
129 | - nsight-compute=2022.4.0.15=0
130 | - numpy=1.22.3=py310hfa59a62_0
131 | - numpy-base=1.22.3=py310h9585f30_0
132 | - openh264=2.1.1=h4ff587b_0
133 | - openssl=1.1.1w=h7f8727e_0
134 | - parso=0.8.3=pyhd3eb1b0_0
135 | - patch=2.7.6=h7b6447c_1001
136 | - patchelf=0.15.0=h6a678d5_0
137 | - pexpect=4.8.0=pyhd3eb1b0_3
138 | - pickleshare=0.7.5=pyhd3eb1b0_1003
139 | - pillow=9.3.0=py310hace64e9_0
140 | - pip=22.3.1=py310h06a4308_0
141 | - pkginfo=1.8.3=py310h06a4308_0
142 | - pluggy=1.0.0=py310h06a4308_1
143 | - prompt-toolkit=3.0.20=pyhd3eb1b0_0
144 | - psutil=5.9.0=py310h5eee18b_0
145 | - ptyprocess=0.7.0=pyhd3eb1b0_2
146 | - pure_eval=0.2.2=pyhd3eb1b0_0
147 | - py-lief=0.12.3=py310h6a678d5_0
148 | - pycodestyle=2.11.1=py310h06a4308_0
149 | - pycosat=0.6.4=py310h5eee18b_0
150 | - pycparser=2.21=pyhd3eb1b0_0
151 | - pyopenssl=22.0.0=pyhd3eb1b0_0
152 | - pysocks=1.7.1=py310h06a4308_0
153 | - python=3.10.8=haa1d7c7_0
154 | - python-libarchive-c=2.9=pyhd3eb1b0_1
155 | - pytorch=1.13.1=py3.10_cuda11.6_cudnn8.3.2_0
156 | - pytorch-cuda=11.6=h867d48c_1
157 | - pytorch-mutex=1.0=cuda
158 | - pyyaml=6.0=py310h5eee18b_1
159 | - readline=8.2=h5eee18b_0
160 | - rhash=1.4.1=h3c74f83_1
161 | - ripgrep=13.0.0=hbdeaff8_0
162 | - ruamel.yaml=0.17.21=py310h5eee18b_0
163 | - ruamel.yaml.clib=0.2.6=py310h5eee18b_1
164 | - six=1.16.0=pyhd3eb1b0_1
165 | - soupsieve=2.3.2.post1=py310h06a4308_0
166 | - sqlite=3.40.0=h5082296_0
167 | - stack_data=0.2.0=pyhd3eb1b0_0
168 | - tk=8.6.12=h1ccaba5_0
169 | - toml=0.10.2=pyhd3eb1b0_0
170 | - toolz=0.12.0=py310h06a4308_0
171 | - torchtext=0.14.1=py310
172 | - torchvision=0.14.1=py310_cu116
173 | - traitlets=5.7.1=py310h06a4308_0
174 | - typing_extensions=4.4.0=py310h06a4308_0
175 | - urllib3=1.26.13=py310h06a4308_0
176 | - wcwidth=0.2.5=pyhd3eb1b0_0
177 | - wheel=0.37.1=pyhd3eb1b0_0
178 | - xz=5.2.8=h5eee18b_0
179 | - yaml=0.2.5=h7b6447c_0
180 | - zlib=1.2.13=h5eee18b_0
181 | - zstd=1.5.2=ha4553b6_0
182 | - pip:
183 | - addict==2.4.0
184 | - aliyun-python-sdk-core==2.14.0
185 | - aliyun-python-sdk-kms==2.16.2
186 | - appdirs==1.4.4
187 | - astunparse==1.6.3
188 | - attrs==22.1.0
189 | - click==8.1.3
190 | - colorama==0.4.6
191 | - comm==0.1.3
192 | - compressai==1.2.4
193 | - contourpy==1.0.7
194 | - crcmod==1.7
195 | - cycler==0.11.0
196 | - debugpy==1.6.6
197 | - dnspython==2.2.1
198 | - docker-pycreds==0.4.0
199 | - einops==0.6.1
200 | - exceptiongroup==1.0.4
201 | - expecttest==0.1.4
202 | - fonttools==4.39.3
203 | - future==0.18.2
204 | - gdown==4.7.1
205 | - gitdb==4.0.10
206 | - gitpython==3.1.31
207 | - huggingface-hub==0.13.3
208 | - hypothesis==6.61.0
209 | - importlib-metadata==7.0.1
210 | - iniconfig==2.0.0
211 | - ipykernel==6.22.0
212 | - ipywidgets==8.0.6
213 | - jmespath==0.10.0
214 | - joblib==1.2.0
215 | - jupyter-client==8.1.0
216 | - jupyter-core==5.3.0
217 | - jupyterlab-widgets==3.0.7
218 | - kiwisolver==1.4.4
219 | - kmedoids==0.5.0
220 | - llvmlite==0.41.1
221 | - markdown==3.5.2
222 | - markdown-it-py==3.0.0
223 | - matplotlib==3.7.1
224 | - mdurl==0.1.2
225 | - mmcv==2.1.0
226 | - mmdet==3.3.0
227 | - mmengine==0.10.3
228 | - model-index==0.1.11
229 | - mpmath==1.2.1
230 | - nest-asyncio==1.5.6
231 | - ninja==1.11.1
232 | - numba==0.58.1
233 | - opencv-python==4.8.0.74
234 | - opendatalab==0.0.10
235 | - openmim==0.3.9
236 | - openxlab==0.0.34
237 | - ordered-set==4.1.0
238 | - oss2==2.17.0
239 | - packaging==23.0
240 | - pandas==2.0.3
241 | - pathtools==0.1.2
242 | - platformdirs==4.2.0
243 | - protobuf==4.22.1
244 | - py-cpuinfo==9.0.0
245 | - pycocotools==2.0.7
246 | - pycryptodome==3.20.0
247 | - pygments==2.17.2
248 | - pyparsing==3.0.9
249 | - pytest==7.2.2
250 | - pytest-gc==0.0.1
251 | - python-dateutil==2.8.2
252 | - python-etcd==0.4.5
253 | - python-papi==5.5.1.5
254 | - pytorch-msssim==0.2.1
255 | - pytz==2023.4
256 | - pyzmq==25.0.2
257 | - requests==2.28.2
258 | - rich==13.4.2
259 | - scikit-learn==1.2.2
260 | - scipy==1.10.1
261 | - seaborn==0.12.2
262 | - sentry-sdk==1.19.0
263 | - setproctitle==1.3.2
264 | - setuptools==60.2.0
265 | - shapely==2.0.3
266 | - smmap==5.0.0
267 | - sortedcontainers==2.4.0
268 | - sympy==1.11.1
269 | - tabulate==0.9.0
270 | - termcolor==2.4.0
271 | - terminaltables==3.1.10
272 | - threadpoolctl==3.1.0
273 | - timm==0.6.13
274 | - tomli==2.0.1
275 | - torch-geometric==2.3.1
276 | - torchac==0.9.3
277 | - torchelastic==0.2.2
278 | - torchprofile==0.0.4
279 | - tornado==6.2
280 | - tqdm==4.65.2
281 | - types-dataclasses==0.6.6
282 | - tzdata==2023.3
283 | - ultralytics==8.0.184
284 | - wandb==0.14.0
285 | - widgetsnbextension==4.0.7
286 | - yapf==0.40.2
287 | - zipp==3.17.0
288 |
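289 | # Reproducing the environment (standard conda workflow, not part of the exported file):
290 | #   conda env create -f env_wignet.yaml
291 | #   conda activate wignet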
--------------------------------------------------------------------------------
/src/model/transfer_models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from timm.models import create_model, resume_checkpoint
3 | from timm.models.helpers import clean_state_dict
4 | import torch.nn as nn
5 | from torch.nn import Sequential as Seq
6 | from gcn_lib import act_layer
7 | import numpy as np
8 | from model import pyramid_vig
9 | from model import wignn
10 | from model import wignn_256
11 | from model import greedyvig
12 | from model import mobilevig
13 | import torchvision
14 |
15 | from torchprofile import profile_macs
16 |
17 | from collections import OrderedDict
18 | import sys
19 |
20 |
21 | def remove_pos(state_dict):
22 | cleaned_state_dict = OrderedDict()
23 | for k, v in state_dict.items():
24 | if 'pos_embed' not in k and 'attn_mask' not in k and 'adj_mask' not in k:
25 | cleaned_state_dict[k] = v
26 | return cleaned_state_dict
27 |
28 | def remove_relative_pos(state_dict):
29 | cleaned_state_dict = OrderedDict()
30 | for k, v in state_dict.items():
31 | if 'relative_pos' not in k and 'pos_embed' not in k:
32 | cleaned_state_dict[k] = v
33 | return cleaned_state_dict
34 |
35 | def get_model(model_type, use_shift=False, adapt_knn=False, checkpoint=None, pretrained=True, freezed=True, dataset='CelebA', crop_size=None):
36 |
37 | if dataset == 'CelebA':
38 | n_classes = 307
39 | else:
40 | raise NotImplementedError(f'Dataset: {dataset} not yet implemented')
41 |
42 |
43 | pretrained_creation = False
44 |
45 |     if 'wignn' in model_type:
46 | if crop_size is not None:
47 | model = create_model(
48 | model_type,
49 | pretrained=pretrained_creation,
50 | use_shifts = use_shift,
51 | adapt_knn = adapt_knn,
52 | img_size = crop_size
53 | )
54 | else:
55 | model = create_model(
56 | model_type,
57 | pretrained=pretrained_creation,
58 | use_shifts = use_shift,
59 | adapt_knn = adapt_knn
60 | )
61 |     elif 'pvig' in model_type:
62 | if crop_size is not None:
63 | model = create_model(
64 | model_type,
65 | pretrained=pretrained_creation,
66 | img_size = crop_size
67 | )
68 | else:
69 | model = create_model(
70 | model_type,
71 | pretrained=pretrained_creation
72 | )
73 |     elif 'GreedyViG' in model_type:
74 | model = create_model(
75 | model_type,
76 | num_classes=1000,
77 | distillation=False,
78 | pretrained=pretrained_creation
79 | )
80 |     elif 'mobilevig' in model_type:
81 | model = create_model(
82 | model_type,
83 | )
84 |
85 | else:
86 | raise NotImplementedError(f'Model: {model_type} not yet implemented')
87 |
88 |
89 |
90 |
91 |
92 |
93 | # load checkpoint for our models and ViG
94 | if pretrained and checkpoint != '':
95 | if 'wignn' in model_type:
96 | assert checkpoint is not None, f'Cannot start from pretrained {model_type} model without checkpoints'
97 |
98 | checkpoint = torch.load(checkpoint, map_location='cpu')
99 | if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
100 | print('Restoring model state from checkpoint...')
101 | state_dict = clean_state_dict(checkpoint['state_dict'])
102 |
103 | if crop_size is not None:
104 | state_dict = remove_pos(state_dict)
105 |
106 | model.load_state_dict(state_dict, strict = False)
107 |
108 | print(f'Pretrain weights for {model_type} loaded.')
109 |
110 | elif 'pvig' in model_type:
111 | assert checkpoint is not None, f'Cannot start from pretrained {model_type} model without checkpoints'
112 |
113 |             state_dict = torch.load(checkpoint, map_location='cpu')
114 | if crop_size is not None:
115 | state_dict = remove_relative_pos(state_dict)
116 | model.load_state_dict(state_dict, strict=False)
117 | print('Pretrain weights for vig loaded.')
118 |
119 | elif 'GreedyViG' in model_type:
120 | assert checkpoint is not None, f'Cannot start from pretrained {model_type} model without checkpoints'
121 |
122 | checkpoint = torch.load(checkpoint, map_location='cpu')
123 | checkpoint_model = checkpoint['state_dict']
124 |
125 | state_dict = model.state_dict()
126 |             for k in ['dist_head.weight', 'dist_head.bias']:
127 |                 # the non-distilled model defines no dist_head, so drop those weights if present
128 |                 if k in checkpoint_model:
129 |                     del checkpoint_model[k]
130 |
131 | model.load_state_dict(checkpoint_model, strict=True)
132 | print('Pretrain weights for GreedyViG loaded.')
133 |
134 | elif 'mobilevig' in model_type:
135 | assert checkpoint is not None, f'Cannot start from pretrained {model_type} model without checkpoints'
136 |
137 | checkpoint = torch.load(checkpoint, map_location='cpu')
138 | checkpoint_model = checkpoint['state_dict']
139 |
140 | state_dict = model.state_dict()
141 |             for k in ['dist_head.weight', 'dist_head.bias']:
142 |                 # the non-distilled model defines no dist_head, so drop those weights if present
143 |                 if k in checkpoint_model:
144 |                     del checkpoint_model[k]
145 |
146 | model.load_state_dict(checkpoint_model, strict=True)
147 | print('Pretrain weights for MobileViG loaded.')
148 |
149 |
150 |
151 |
152 |
153 |     # freeze (or unfreeze) the whole backbone
154 |     for param in model.parameters():
155 |         param.requires_grad = not freezed
156 |
157 |
158 |
159 |
160 |
161 | if 'pvig' in model_type or 'wignn' in model_type:
162 | model.prediction = Seq(nn.Conv2d(model.prediction[0].in_channels, 1024, 1, bias=True),
163 | nn.BatchNorm2d(1024),
164 | act_layer('gelu'),
165 | nn.Dropout(0.0),
166 | nn.Conv2d(1024, n_classes, 1, bias=True))
167 |         model.prediction.requires_grad_(True)
168 | elif 'GreedyViG' in model_type:
169 | model.prediction = nn.Sequential(nn.AdaptiveAvgPool2d(1),
170 | nn.Conv2d(model.prediction[1].in_channels, 768, kernel_size=1, bias=True),
171 | nn.BatchNorm2d(768),
172 | nn.GELU(),
173 | nn.Dropout(0.0))
174 |
175 | model.head = nn.Conv2d(768, n_classes, kernel_size=1, bias=True)
176 | model.dist_head = nn.Conv2d(768, n_classes, 1, bias=True)
177 |         model.prediction.requires_grad_(True)
178 |         model.head.requires_grad_(True)
179 |         model.dist_head.requires_grad_(True)
180 |
181 | elif 'mobilevig' in model_type:
182 | model.prediction = nn.Sequential(nn.AdaptiveAvgPool2d(1),
183 | nn.Conv2d(256, 512, 1, bias=True),
184 | nn.BatchNorm2d(512),
185 | nn.GELU(),
186 | nn.Dropout(0.))
187 |
188 | model.head = nn.Conv2d(512, n_classes, 1, bias=True)
189 | model.dist_head = nn.Conv2d(512, n_classes, 1, bias=True)
190 |         model.prediction.requires_grad_(True)
191 |         model.head.requires_grad_(True)
192 |         model.dist_head.requires_grad_(True)
193 |
194 | else:
195 | raise NotImplementedError(f'Model {model_type} not yet implemented\n{model}')
196 |
197 |
198 | params = sum(p.numel() for p in model.parameters())
199 | trainable_parameters = filter(lambda p: p.requires_grad, model.parameters())
200 | trainable_parameters = sum([np.prod(p.size()) for p in trainable_parameters])
201 |
202 |
203 |
204 | return model, params, trainable_parameters, n_classes
205 |
206 |
207 |
208 |
209 |
210 | if __name__ == '__main__':
211 | checkpoint = 'path/to/checkpoint'
212 | model_type = 'wignn_ti_256_gelu'
213 |
214 |
215 | dataset = 'CelebA'
216 | model, params, trainable_parameters, n_classes = get_model(model_type = model_type,
217 | use_shift=True,
218 | adapt_knn=True,
219 | checkpoint = checkpoint,
220 | freezed=True,
221 | dataset = dataset,
222 | crop_size = 512)
223 | model.eval()
224 | model.cuda()
225 | x = torch.rand((1,3,512,512)).cuda()
226 | # print(model)
227 | print(f"Parameters: {params}")
228 | print(f"Trainable Parameters: {trainable_parameters}")
229 |
230 | # out = model(x)
231 | # print(out.shape)
232 | macs = profile_macs(model, x)
233 | print(f'\n\n!!!!! macs : {macs*10**-9}\n\n')
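234 |
235 |     # Minimal fine-tuning sketch (assumed usage, not part of the original script): with
236 |     # freezed=True only the freshly built prediction head has requires_grad=True.
237 |     # model.train()
238 |     # optimizer = torch.optim.AdamW((p for p in model.parameters() if p.requires_grad), lr=1e-4)
239 |     # logits = model(x)                                         # (1, n_classes)
240 |     # labels = torch.zeros(1, dtype=torch.long, device='cuda')  # dummy target
241 |     # torch.nn.functional.cross_entropy(logits, labels).backward()
242 |     # optimizer.step()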
--------------------------------------------------------------------------------
/src/model/wignn.py:
--------------------------------------------------------------------------------
1 |
2 | import math
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch.nn import Sequential as Seq
7 |
8 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
9 | from timm.models.helpers import load_pretrained
10 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_
11 | from timm.models.registry import register_model
12 |
13 | from gcn_lib import act_layer, WindowGrapher
14 | from timm.models import create_model
15 | import time
16 |
17 | from torchprofile import profile_macs
18 |
19 | import sys
20 |
21 | def _cfg(url='', **kwargs):
22 | return {
23 | 'url': url,
24 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
25 | 'crop_pct': .9, 'interpolation': 'bicubic',
26 | 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
27 | 'first_conv': 'patch_embed.proj', 'classifier': 'head',
28 | **kwargs
29 | }
30 |
31 |
32 | default_cfgs = {
33 | 'wignn_224_gelu': _cfg(
34 | mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
35 | ),
36 | 'wignn_b_224_gelu': _cfg(
37 | crop_pct=0.95, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
38 | ),
39 | }
40 |
41 |
42 | class FFN(nn.Module):
43 | def __init__(self, in_features, hidden_features=None, out_features=None, act='relu', drop_path=0.0):
44 | super().__init__()
45 | out_features = out_features or in_features
46 | hidden_features = hidden_features or in_features
47 | self.fc1 = nn.Sequential(
48 | nn.Conv2d(in_features, hidden_features, 1, stride=1, padding=0),
49 | nn.BatchNorm2d(hidden_features),
50 | )
51 | self.act = act_layer(act)
52 | self.fc2 = nn.Sequential(
53 | nn.Conv2d(hidden_features, out_features, 1, stride=1, padding=0),
54 | nn.BatchNorm2d(out_features),
55 | )
56 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
57 |
58 | def forward(self, x):
59 | shortcut = x
60 | x = self.fc1(x)
61 | x = self.act(x)
62 | x = self.fc2(x)
63 | x = self.drop_path(x) + shortcut
64 | return x#.reshape(B, C, N, 1)
65 |
66 |
67 | class Stem(nn.Module):
68 | """ Image to Visual Embedding
69 | Overlap: https://arxiv.org/pdf/2106.13797.pdf
70 | """
71 | def __init__(self, img_size=224, in_dim=3, out_dim=768, act='relu'):
72 | super().__init__()
73 | self.convs = nn.Sequential(
74 | nn.Conv2d(in_dim, out_dim//2, 3, stride=2, padding=1),
75 | nn.BatchNorm2d(out_dim//2),
76 | act_layer(act),
77 | nn.Conv2d(out_dim//2, out_dim, 3, stride=2, padding=1),
78 | nn.BatchNorm2d(out_dim),
79 | act_layer(act),
80 | nn.Conv2d(out_dim, out_dim, 3, stride=1, padding=1),
81 | nn.BatchNorm2d(out_dim),
82 | )
83 |
84 | def forward(self, x):
85 | x = self.convs(x)
86 | return x
87 |
88 |
89 | class Downsample(nn.Module):
90 | """ Convolution-based downsample
91 | """
92 | def __init__(self, in_dim=3, out_dim=768):
93 | super().__init__()
94 | self.conv = nn.Sequential(
95 | nn.Conv2d(in_dim, out_dim, 3, stride=2, padding=1),
96 | nn.BatchNorm2d(out_dim),
97 | )
98 |
99 | def forward(self, x):
100 | x = self.conv(x)
101 | return x
102 |
103 |
104 | class DeepGCN(torch.nn.Module):
105 | def __init__(self, opt):
106 | super(DeepGCN, self).__init__()
107 | print(opt)
108 | self.k = opt.k # knn
109 | self.act = opt.act # activation layer {relu, prelu, leakyrelu, gelu, hswish}
110 | self.norm = opt.norm # batch or instance normalization {batch, instance}
111 | self.bias = opt.bias # bias of conv layer True or False
112 | self.epsilon = opt.epsilon # stochastic epsilon for gcn
113 | self.stochastic = opt.use_stochastic # stochastic for gcn, True or False
114 | self.conv = opt.conv # graph conv layer {edge, mr}
115 |         self.emb_dims = opt.emb_dims # dimension of embeddings
116 | self.drop_path = opt.drop_path
117 |
118 | self.blocks = opt.blocks # [2,2,6,2] # number of basic blocks in the backbone
119 | self.channels = opt.channels # [80, 160, 400, 640] # number of channels of deep features
120 |
121 | self.img_size = opt.img_size
122 | self.use_shifts = opt.use_shifts
123 |
124 |
125 | self.n_blocks = sum(self.blocks)
126 | self.window_size = [opt.windows_size for _ in range(len(self.blocks))]
127 |
128 | if opt.use_reduce_ratios:
129 | self.reduce_ratios = [2, 2, 1, 1]
130 | else:
131 | self.reduce_ratios = [1, 1, 1, 1]
132 |
133 | adapt_knn = opt.adapt_knn
134 | print(f'Created Model wignn ({self.img_size}) Window: {self.window_size}')
135 | print(f'Use shifting windows: {self.use_shifts} adapt knn: {adapt_knn}')
136 |
137 | print(f'Knn: {self.k}')
138 | print(f'Reduce ratios: {self.reduce_ratios}')
139 |
140 |
141 | print(f'Channel: {self.channels}')
142 | print(f'Blocks: {self.blocks}')
143 |
144 |
145 |
146 | self.dpr = [x.item() for x in torch.linspace(0, self.drop_path, self.n_blocks)] # stochastic depth decay rule
147 | self.num_knn = [int(x.item()) for x in torch.linspace(self.k, self.k, self.n_blocks)] # number of knn's k
148 | # max_dilation = 49 // max(num_knn)
149 |
150 | self.stem = Stem(out_dim=self.channels[0], act=self.act)
151 | self.pos_embed = nn.Parameter(torch.zeros(1, self.channels[0], self.img_size//4, self.img_size//4))
152 |
153 | self.backbone = nn.ModuleList([])
154 | idx = 0
155 | for i in range(len(self.blocks)):
156 | if i > 0:
157 | self.backbone.append(Downsample(self.channels[i-1], self.channels[i]))
158 |
159 | for j in range(self.blocks[i]):
160 | shift_size = 0
161 | if j % 2 != 0 and self.use_shifts:
162 | shift_size = self.window_size[i] // 2
163 |
164 | self.backbone += [
165 | Seq(
166 | WindowGrapher(
167 | in_channels = self.channels[i],
168 | kernel_size = self.num_knn[idx],
169 | windows_size = self.window_size[i],
170 | dilation = 1,
171 | conv = self.conv,
172 | act = self.act,
173 | norm = self.norm,
174 | bias = self.bias,
175 | stochastic = self.stochastic,
176 | epsilon = self.epsilon,
177 | drop_path = self.dpr[idx],
178 | relative_pos = True,
179 | shift_size = shift_size,
180 | r = self.reduce_ratios[i],
181 | input_resolution = (
182 | (self.img_size//4) // (2 ** i),
183 | (self.img_size//4) // (2 ** i)),
184 | adapt_knn=adapt_knn
185 | ),
186 | FFN(self.channels[i], self.channels[i] * 4, act=self.act, drop_path=self.dpr[idx])
187 | )]
188 | idx += 1
189 | self.backbone = Seq(*self.backbone)
190 |
191 | self.prediction = Seq(nn.Conv2d(self.channels[-1], self.emb_dims, 1, bias=True),
192 | nn.BatchNorm2d(self.emb_dims),
193 | act_layer(self.act),
194 | nn.Dropout(opt.dropout),
195 | nn.Conv2d(self.emb_dims, opt.n_classes, 1, bias=True))
196 | self.model_init()
197 |
198 | def model_init(self):
199 | for m in self.modules():
200 | if isinstance(m, torch.nn.Conv2d):
201 | torch.nn.init.kaiming_normal_(m.weight)
202 | m.weight.requires_grad = True
203 | if m.bias is not None:
204 | m.bias.data.zero_()
205 | m.bias.requires_grad = True
206 |
207 | def forward(self, inputs):
208 | x = self.stem(inputs) + self.pos_embed
209 | B, C, H, W = x.shape
210 | for i in range(len(self.backbone)):
211 | x = self.backbone[i](x)
212 |
213 | x = F.adaptive_avg_pool2d(x, 1)
214 | return self.prediction(x).squeeze(-1).squeeze(-1)
215 |
216 |
217 |
218 |
219 |
220 |
221 |
222 |
223 |
224 |
225 |
226 |
227 | class OptInit:
228 | def __init__(self,
229 | num_classes=1000,
230 | drop_path_rate=0.0,
231 | knn = 9,
232 | use_shifts = True,
233 | use_reduce_ratios = False,
234 | img_size = 224,
235 | adapt_knn = False,
236 |
237 | channels = None,
238 | blocks = None,
239 | **kwargs):
240 |
241 | self.k = knn # neighbor num (default:9)
242 | self.conv = 'mr' # graph conv layer {edge, mr}
243 | self.act = 'gelu' # activation layer {relu, prelu, leakyrelu, gelu, hswish}
244 | self.norm = 'batch' # batch or instance normalization {batch, instance}
245 | self.bias = True # bias of conv layer True or False
246 | self.dropout = 0.0 # dropout rate
247 | self.use_dilation = True # use dilated knn or not
248 | self.epsilon = 0.2 # stochastic epsilon for gcn
249 | self.use_stochastic = False # stochastic for gcn, True or False
250 | self.drop_path = drop_path_rate
251 | self.blocks = blocks # number of basic blocks in the backbone
252 | self.channels = channels # number of channels of deep features
253 | self.n_classes = num_classes # Dimension of out_channels
254 | self.emb_dims = 1024 # Dimension of embeddings
255 | self.windows_size = 7
256 |
257 | self.use_shifts = use_shifts
258 | self.img_size = img_size
259 |         self.use_reduce_ratios = use_reduce_ratios
260 | self.adapt_knn = adapt_knn
261 |
262 | @register_model
263 | def wignn_ti_224_gelu(pretrained=False, **kwargs):
264 |
265 |
266 | opt = OptInit(**kwargs, channels = [48, 96, 240, 384], blocks= [2,2,6,2])
267 |
268 | model = DeepGCN(opt)
269 | model.default_cfg = default_cfgs['wignn_224_gelu']
270 | return model
271 |
272 |
273 | @register_model
274 | def wignn_s_224_gelu(pretrained=False, **kwargs):
275 |
276 | opt = OptInit(**kwargs, channels = [80, 160, 400, 640], blocks= [2,2,6,2])
277 |
278 | model = DeepGCN(opt)
279 | model.default_cfg = default_cfgs['wignn_224_gelu']
280 | return model
281 |
282 |
283 | @register_model
284 | def wignn_m_224_gelu(pretrained=False, **kwargs):
285 |
286 | opt = OptInit(**kwargs, channels = [96, 192, 384, 768], blocks= [2,2,16,2])
287 |
288 | model = DeepGCN(opt)
289 | model.default_cfg = default_cfgs['wignn_224_gelu']
290 | return model
291 |
292 |
293 | @register_model
294 | def wignn_b_224_gelu(pretrained=False, **kwargs):
295 |
296 | opt = OptInit(**kwargs, channels = [128, 256, 512, 1024], blocks= [2,2,18,2])
297 |
298 | model = DeepGCN(opt)
299 | model.default_cfg = default_cfgs['wignn_b_224_gelu']
300 | return model
301 |
302 |
303 |
304 | if __name__ == '__main__':
305 |
306 |
307 | # for img_size in [224,224*2,224*3,224*4,224*5]:
308 |
309 | # model = create_model(
310 | # 'wignn_ti_224_gelu',
311 | # knn = 9,
312 | # use_shifts = True,
313 | # img_size = img_size,
314 | # adapt_knn = True
315 | # )
316 |
317 | # model = model.cuda()
318 | # model.eval()
319 |
320 | # x = torch.rand((1,3,img_size,img_size)).to(device = 'cuda')
321 |
322 | # # print(model(x).shape)
323 |
324 | # macs = profile_macs(model, x)
325 | # print(f'\n\n!!!!! WiGNet macs ({img_size}): {macs}\n\n')
326 |
327 | for img_size in [224,224*2,224*3,224*4,224*5]:
328 |
329 |
330 | with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA],profile_memory=True, record_shapes=True) as prof:
331 |
332 | x = torch.rand((1,3,img_size,img_size)).to(device = 'cuda')
333 |
334 | model = create_model(
335 | 'wignn_ti_224_gelu',
336 | knn = 9,
337 | use_shifts = True,
338 | img_size = img_size,
339 | adapt_knn = True
340 | )
341 |
342 | model = model.cuda()
343 | model.eval()
344 |
345 | _ = model(x)
346 |
347 |         with open("memory_WiGNet_model.txt", "a") as f:
348 |             f.write(prof.key_averages().table())
349 |             f.write('\n\n')
350 |
351 |
352 |
353 | print('\n\n---------------\nResults saved in memory_WiGNet_model.txt')
354 |
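355 | # Sanity-check sketch (added; n_windows is a hypothetical helper): with a fixed 7x7
356 | # window the graph is built per window, so the stage-1 window count -- and with it the
357 | # graph cost -- grows linearly with the pixel count of the inputs probed above.
358 | # n_windows = lambda s, w=7: ((s // 4) // w) ** 2
359 | # print([n_windows(s) for s in [224, 448, 672, 896, 1120]])  # [64, 256, 576, 1024, 1600]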
--------------------------------------------------------------------------------
/src/model/pyramid_vig.py:
--------------------------------------------------------------------------------
1 |
2 | import math
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch.nn import Sequential as Seq
7 |
8 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
9 | from timm.models.helpers import load_pretrained
10 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_
11 | from timm.models.registry import register_model
12 |
13 | from gcn_lib import Grapher, act_layer
14 | from timm.models import create_model
15 |
16 | from torchprofile import profile_macs
17 | import time
18 |
19 | def _cfg(url='', **kwargs):
20 | return {
21 | 'url': url,
22 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
23 | 'crop_pct': .9, 'interpolation': 'bicubic',
24 | 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
25 | 'first_conv': 'patch_embed.proj', 'classifier': 'head',
26 | **kwargs
27 | }
28 |
29 |
30 | default_cfgs = {
31 | 'vig_224_gelu': _cfg(
32 | mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
33 | ),
34 | 'vig_b_224_gelu': _cfg(
35 | crop_pct=0.95, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
36 | ),
37 | }
38 |
39 |
40 | class FFN(nn.Module):
41 | def __init__(self, in_features, hidden_features=None, out_features=None, act='relu', drop_path=0.0):
42 | super().__init__()
43 | out_features = out_features or in_features
44 | hidden_features = hidden_features or in_features
45 | self.fc1 = nn.Sequential(
46 | nn.Conv2d(in_features, hidden_features, 1, stride=1, padding=0),
47 | nn.BatchNorm2d(hidden_features),
48 | )
49 | self.act = act_layer(act)
50 | self.fc2 = nn.Sequential(
51 | nn.Conv2d(hidden_features, out_features, 1, stride=1, padding=0),
52 | nn.BatchNorm2d(out_features),
53 | )
54 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
55 |
56 | def forward(self, x):
57 | shortcut = x
58 | x = self.fc1(x)
59 | x = self.act(x)
60 | x = self.fc2(x)
61 | x = self.drop_path(x) + shortcut
62 | return x#.reshape(B, C, N, 1)
63 |
64 |
65 | class Stem(nn.Module):
66 | """ Image to Visual Embedding
67 | Overlap: https://arxiv.org/pdf/2106.13797.pdf
68 | """
69 | def __init__(self, img_size=224, in_dim=3, out_dim=768, act='relu'):
70 | super().__init__()
71 | self.convs = nn.Sequential(
72 | nn.Conv2d(in_dim, out_dim//2, 3, stride=2, padding=1),
73 | nn.BatchNorm2d(out_dim//2),
74 | act_layer(act),
75 | nn.Conv2d(out_dim//2, out_dim, 3, stride=2, padding=1),
76 | nn.BatchNorm2d(out_dim),
77 | act_layer(act),
78 | nn.Conv2d(out_dim, out_dim, 3, stride=1, padding=1),
79 | nn.BatchNorm2d(out_dim),
80 | )
81 |
82 | def forward(self, x):
83 | x = self.convs(x)
84 | return x
85 |
86 |
87 | class Downsample(nn.Module):
88 | """ Convolution-based downsample
89 | """
90 | def __init__(self, in_dim=3, out_dim=768):
91 | super().__init__()
92 | self.conv = nn.Sequential(
93 | nn.Conv2d(in_dim, out_dim, 3, stride=2, padding=1),
94 | nn.BatchNorm2d(out_dim),
95 | )
96 |
97 | def forward(self, x):
98 | x = self.conv(x)
99 | return x
100 |
101 |
102 | class DeepGCN(torch.nn.Module):
103 | def __init__(self, opt):
104 | super(DeepGCN, self).__init__()
105 | print(opt)
106 | k = opt.k
107 | act = opt.act
108 | norm = opt.norm
109 | bias = opt.bias
110 | epsilon = opt.epsilon
111 | stochastic = opt.use_stochastic
112 | conv = opt.conv
113 | emb_dims = opt.emb_dims
114 | drop_path = opt.drop_path
115 |
116 | blocks = opt.blocks
117 | self.n_blocks = sum(blocks)
118 | channels = opt.channels
119 | reduce_ratios = [4, 2, 1, 1]
120 | dpr = [x.item() for x in torch.linspace(0, drop_path, self.n_blocks)] # stochastic depth decay rule
121 | num_knn = [int(x.item()) for x in torch.linspace(k, k, self.n_blocks)] # number of knn's k
122 | max_dilation = 49 // max(num_knn)
123 |
124 | img_size = opt.img_size
125 |
126 |
127 |
128 | self.stem = Stem(out_dim=channels[0], act=act)
129 | self.pos_embed = nn.Parameter(torch.zeros(1, channels[0], img_size//4, img_size//4))
130 | HW = img_size // 4 * img_size // 4
131 |
132 | self.backbone = nn.ModuleList([])
133 | idx = 0
134 | for i in range(len(blocks)):
135 | if i > 0:
136 | self.backbone.append(Downsample(channels[i-1], channels[i]))
137 | HW = HW // 4
138 | for j in range(blocks[i]):
139 | self.backbone += [
140 | Seq(Grapher(channels[i], num_knn[idx], min(idx // 4 + 1, max_dilation), conv, act, norm,
141 | bias, stochastic, epsilon, reduce_ratios[i], n=HW, drop_path=dpr[idx],
142 | relative_pos=True),
143 | FFN(channels[i], channels[i] * 4, act=act, drop_path=dpr[idx])
144 | )]
145 | idx += 1
146 | self.backbone = Seq(*self.backbone)
147 |
148 | self.prediction = Seq(nn.Conv2d(channels[-1], 1024, 1, bias=True),
149 | nn.BatchNorm2d(1024),
150 | act_layer(act),
151 | nn.Dropout(opt.dropout),
152 | nn.Conv2d(1024, opt.n_classes, 1, bias=True))
153 | self.model_init()
154 |
155 | def model_init(self):
156 | for m in self.modules():
157 | if isinstance(m, torch.nn.Conv2d):
158 | torch.nn.init.kaiming_normal_(m.weight)
159 | m.weight.requires_grad = True
160 | if m.bias is not None:
161 | m.bias.data.zero_()
162 | m.bias.requires_grad = True
163 |
164 | def forward(self, inputs):
165 | x = self.stem(inputs) + self.pos_embed
166 | B, C, H, W = x.shape
167 | for i in range(len(self.backbone)):
168 | x = self.backbone[i](x)
169 |
170 | x = F.adaptive_avg_pool2d(x, 1)
171 |
172 |
173 | return self.prediction(x).squeeze(-1).squeeze(-1)
174 |
175 |
176 | @register_model
177 | def pvig_ti_224_gelu(pretrained=False, **kwargs):
178 | class OptInit:
179 | def __init__(self, num_classes=1000, drop_path_rate=0.0, img_size = 224, **kwargs):
180 | self.k = 9 # neighbor num (default:9)
181 | self.conv = 'mr' # graph conv layer {edge, mr}
182 | self.act = 'gelu' # activation layer {relu, prelu, leakyrelu, gelu, hswish}
183 | self.norm = 'batch' # batch or instance normalization {batch, instance}
184 | self.bias = True # bias of conv layer True or False
185 | self.dropout = 0.0 # dropout rate
186 | self.use_dilation = True # use dilated knn or not
187 | self.epsilon = 0.2 # stochastic epsilon for gcn
188 | self.use_stochastic = False # stochastic for gcn, True or False
189 | self.drop_path = drop_path_rate
190 | self.blocks = [2,2,6,2] # number of basic blocks in the backbone
191 | self.channels = [48, 96, 240, 384] # number of channels of deep features
192 | self.n_classes = num_classes # Dimension of out_channels
193 | self.emb_dims = 1024 # Dimension of embeddings
194 |
195 | self.img_size = img_size
196 |
197 |
198 | opt = OptInit(**kwargs)
199 | model = DeepGCN(opt)
200 | model.default_cfg = default_cfgs['vig_224_gelu']
201 | return model
202 |
203 |
204 | @register_model
205 | def pvig_s_224_gelu(pretrained=False, **kwargs):
206 | class OptInit:
207 | def __init__(self, num_classes=1000, drop_path_rate=0.0, img_size = 224, **kwargs):
208 | self.k = 9 # neighbor num (default:9)
209 | self.conv = 'mr' # graph conv layer {edge, mr}
210 | self.act = 'gelu' # activation layer {relu, prelu, leakyrelu, gelu, hswish}
211 | self.norm = 'batch' # batch or instance normalization {batch, instance}
212 | self.bias = True # bias of conv layer True or False
213 | self.dropout = 0.0 # dropout rate
214 | self.use_dilation = True # use dilated knn or not
215 | self.epsilon = 0.2 # stochastic epsilon for gcn
216 | self.use_stochastic = False # stochastic for gcn, True or False
217 | self.drop_path = drop_path_rate
218 | self.blocks = [2,2,6,2] # number of basic blocks in the backbone
219 | self.channels = [80, 160, 400, 640] # number of channels of deep features
220 | self.n_classes = num_classes # Dimension of out_channels
221 | self.emb_dims = 1024 # Dimension of embeddings
222 |
223 | self.img_size = img_size
224 |
225 |
226 |
227 | opt = OptInit(**kwargs)
228 | model = DeepGCN(opt)
229 | model.default_cfg = default_cfgs['vig_224_gelu']
230 | return model
231 |
232 |
233 | @register_model
234 | def pvig_m_224_gelu(pretrained=False, **kwargs):
235 | class OptInit:
236 | def __init__(self, num_classes=1000, drop_path_rate=0.0, img_size = 224, **kwargs):
237 | self.k = 9 # neighbor num (default:9)
238 | self.conv = 'mr' # graph conv layer {edge, mr}
239 | self.act = 'gelu' # activation layer {relu, prelu, leakyrelu, gelu, hswish}
240 | self.norm = 'batch' # batch or instance normalization {batch, instance}
241 | self.bias = True # bias of conv layer True or False
242 | self.dropout = 0.0 # dropout rate
243 | self.use_dilation = True # use dilated knn or not
244 | self.epsilon = 0.2 # stochastic epsilon for gcn
245 | self.use_stochastic = False # stochastic for gcn, True or False
246 | self.drop_path = drop_path_rate
247 | self.blocks = [2,2,16,2] # number of basic blocks in the backbone
248 | self.channels = [96, 192, 384, 768] # number of channels of deep features
249 | self.n_classes = num_classes # Dimension of out_channels
250 | self.emb_dims = 1024 # Dimension of embeddings
251 |
252 | self.img_size = img_size
253 |
254 |
255 | opt = OptInit(**kwargs)
256 | model = DeepGCN(opt)
257 | model.default_cfg = default_cfgs['vig_224_gelu']
258 | return model
259 |
260 |
261 | @register_model
262 | def pvig_b_224_gelu(pretrained=False, **kwargs):
263 | class OptInit:
264 | def __init__(self, num_classes=1000, drop_path_rate=0.0, img_size = 224, **kwargs):
265 | self.k = 9 # neighbor num (default:9)
266 | self.conv = 'mr' # graph conv layer {edge, mr}
267 | self.act = 'gelu' # activation layer {relu, prelu, leakyrelu, gelu, hswish}
268 | self.norm = 'batch' # batch or instance normalization {batch, instance}
269 | self.bias = True # bias of conv layer True or False
270 | self.dropout = 0.0 # dropout rate
271 | self.use_dilation = True # use dilated knn or not
272 | self.epsilon = 0.2 # stochastic epsilon for gcn
273 | self.use_stochastic = False # stochastic for gcn, True or False
274 | self.drop_path = drop_path_rate
275 | self.blocks = [2,2,18,2] # number of basic blocks in the backbone
276 | self.channels = [128, 256, 512, 1024] # number of channels of deep features
277 | self.n_classes = num_classes # Dimension of out_channels
278 | self.emb_dims = 1024 # Dimension of embeddings
279 |
280 | self.img_size = img_size
281 |
282 |
283 | opt = OptInit(**kwargs)
284 | model = DeepGCN(opt)
285 | model.default_cfg = default_cfgs['vig_b_224_gelu']
286 | return model
287 |
288 |
289 |
290 | if __name__ == '__main__':
291 |
292 | # for img_size in [224*5]:
293 |
294 | # model = create_model(
295 | # 'pvig_ti_224_gelu',
296 | # img_size = img_size
297 | # )
298 |
299 | # model = model.cuda()
300 | # model.eval()
301 |
302 | # x = torch.rand((1,3,img_size,img_size)).to(device = 'cuda')
303 |
304 | # # print(model(x).shape)
305 |
306 | # macs = profile_macs(model, x)
307 | # print(f'\n\n!!!!! P-ViG macs ({img_size}): {macs*10**-9}\n\n')
308 |
309 | for img_size in [224,224*2,224*3,224*4,224*5]:
310 |
311 |
312 | with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA],profile_memory=True, record_shapes=True) as prof:
313 |
314 | x = torch.rand((1,3,img_size,img_size)).to(device = 'cuda')
315 |
316 | model = create_model(
317 | 'pvig_ti_224_gelu',
318 | img_size = img_size
319 | )
320 |
321 | model = model.cuda()
322 | model.eval()
323 |
324 | _ = model(x)
325 |
326 |         with open("memory_ViG_model.txt", "a") as f:
327 |             f.write(prof.key_averages().table())
328 |             f.write('\n\n')
329 |
330 |
331 |
332 | print('\n\n---------------\nResults saved in memory_ViG_model.txt')
333 |
334 |
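335 | # Context note (added): unlike WiGNet's windowed graphs, ViG computes pairwise distances
336 | # over all (img_size // 4) ** 2 stage-1 tokens when building its KNN graph, so the cost
337 | # grows quadratically with the token count; the profiler loop above records the resulting
338 | # memory growth as img_size increases.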
--------------------------------------------------------------------------------
/src/model/mobilevig.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from torch.nn import Sequential as Seq
5 |
6 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
7 | from timm.models.layers import DropPath
8 | from timm.models.registry import register_model
9 |
10 | from timm.models import create_model
11 | from torchprofile import profile_macs
12 |
13 |
14 |
15 |
16 | def _cfg(url='', **kwargs):
17 | return {
18 | 'url': url,
19 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
20 | 'crop_pct': .9, 'interpolation': 'bicubic',
21 | 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
22 | 'classifier': 'head',
23 | **kwargs
24 | }
25 |
26 |
27 | default_cfgs = {
28 | 'mobilevig': _cfg(crop_pct=0.9, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD)
29 | }
30 |
31 |
32 | class Stem(nn.Module):
33 | def __init__(self, input_dim, output_dim, activation=nn.GELU):
34 | super(Stem, self).__init__()
35 | self.stem = nn.Sequential(
36 | nn.Conv2d(input_dim, output_dim // 2, kernel_size=3, stride=2, padding=1),
37 | nn.BatchNorm2d(output_dim // 2),
38 | nn.GELU(),
39 | nn.Conv2d(output_dim // 2, output_dim, kernel_size=3, stride=2, padding=1),
40 | nn.BatchNorm2d(output_dim),
41 | nn.GELU()
42 | )
43 |
44 | def forward(self, x):
45 | return self.stem(x)
46 |
47 |
48 | class MLP(nn.Module):
49 | """
50 | Implementation of MLP with 1*1 convolutions.
51 | Input: tensor with shape [B, C, H, W]
52 | """
53 |
54 | def __init__(self, in_features, hidden_features=None,
55 | out_features=None, drop=0., mid_conv=False):
56 | super().__init__()
57 | out_features = out_features or in_features
58 | hidden_features = hidden_features or in_features
59 | self.mid_conv = mid_conv
60 | self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
61 | self.act = nn.GELU()
62 | self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
63 | self.drop = nn.Dropout(drop)
64 |
65 | if self.mid_conv:
66 | self.mid = nn.Conv2d(hidden_features, hidden_features, kernel_size=3, stride=1, padding=1,
67 | groups=hidden_features)
68 | self.mid_norm = nn.BatchNorm2d(hidden_features)
69 |
70 | self.norm1 = nn.BatchNorm2d(hidden_features)
71 | self.norm2 = nn.BatchNorm2d(out_features)
72 |
73 | def forward(self, x):
74 | x = self.fc1(x)
75 | x = self.norm1(x)
76 | x = self.act(x)
77 |
78 | if self.mid_conv:
79 | x_mid = self.mid(x)
80 | x_mid = self.mid_norm(x_mid)
81 | x = self.act(x_mid)
82 | x = self.drop(x)
83 |
84 | x = self.fc2(x)
85 | x = self.norm2(x)
86 |
87 | x = self.drop(x)
88 | return x
89 |
90 |
91 | class InvertedResidual(nn.Module):
92 | def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0., use_layer_scale=True, layer_scale_init_value=1e-5):
93 | super().__init__()
94 |
95 | mlp_hidden_dim = int(dim * mlp_ratio)
96 | self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, drop=drop, mid_conv=True)
97 |
98 | self.drop_path = DropPath(drop_path) if drop_path > 0. \
99 | else nn.Identity()
100 | self.use_layer_scale = use_layer_scale
101 | if use_layer_scale:
102 | self.layer_scale_2 = nn.Parameter(
103 | layer_scale_init_value * torch.ones(dim).unsqueeze(-1).unsqueeze(-1), requires_grad=True)
104 |
105 | def forward(self, x):
106 | if self.use_layer_scale:
107 | x = x + self.drop_path(self.layer_scale_2 * self.mlp(x))
108 | else:
109 | x = x + self.drop_path(self.mlp(x))
110 | return x
111 |
112 | class MRConv4d(nn.Module):
113 | """
114 | Max-Relative Graph Convolution (Paper: https://arxiv.org/abs/1904.03751) for dense data type
115 |
116 | K is the number of superpatches, therefore hops equals res // K.
117 | """
118 | def __init__(self, in_channels, out_channels, K=2):
119 | super(MRConv4d, self).__init__()
120 | self.nn = nn.Sequential(
121 | nn.Conv2d(in_channels * 2, out_channels, 1),
122 |             nn.BatchNorm2d(out_channels),
123 | nn.GELU()
124 | )
125 | self.K = K
126 |
127 | def forward(self, x):
128 | B, C, H, W = x.shape
129 |
130 |         x_j = x - x  # zeros with x's shape/dtype: running max of relative features
131 | for i in range(self.K, H, self.K):
132 | x_c = x - torch.cat([x[:, :, -i:, :], x[:, :, :-i, :]], dim=2)
133 | x_j = torch.max(x_j, x_c)
134 | for i in range(self.K, W, self.K):
135 | x_r = x - torch.cat([x[:, :, :, -i:], x[:, :, :, :-i]], dim=3)
136 | x_j = torch.max(x_j, x_r)
137 |
138 | x = torch.cat([x, x_j], dim=1)
139 | return self.nn(x)
140 |
141 |
142 | class Grapher(nn.Module):
143 | """
144 | Grapher module with graph convolution and fc layers
145 | """
146 | def __init__(self, in_channels, drop_path=0.0, K=2):
147 | super(Grapher, self).__init__()
148 | self.channels = in_channels
149 | self.K = K
150 |
151 | self.fc1 = nn.Sequential(
152 | nn.Conv2d(in_channels, in_channels, 1, stride=1, padding=0),
153 | nn.BatchNorm2d(in_channels),
154 | )
155 | self.graph_conv = MRConv4d(in_channels, in_channels * 2, K=self.K)
156 | self.fc2 = nn.Sequential(
157 | nn.Conv2d(in_channels * 2, in_channels, 1, stride=1, padding=0),
158 | nn.BatchNorm2d(in_channels),
159 |         )  # project channels back to the input width
160 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
161 |
162 |
163 | def forward(self, x):
164 | _tmp = x
165 | x = self.fc1(x)
166 | x = self.graph_conv(x)
167 | x = self.fc2(x)
168 | x = self.drop_path(x) + _tmp
169 |
170 | return x
171 |
172 |
173 | class Downsample(nn.Module):
174 | """ Convolution-based downsample
175 | """
176 | def __init__(self, in_dim, out_dim):
177 | super().__init__()
178 | self.conv = nn.Sequential(
179 | nn.Conv2d(in_dim, out_dim, 3, stride=2, padding=1),
180 | nn.BatchNorm2d(out_dim),
181 | )
182 |
183 | def forward(self, x):
184 | x = self.conv(x)
185 | return x
186 |
187 |
188 | class FFN(nn.Module):
189 | def __init__(self, in_features, hidden_features=None, out_features=None, drop_path=0.0):
190 | super().__init__()
191 | out_features = out_features or in_features # same as input
192 | hidden_features = hidden_features or in_features # x4
193 | self.fc1 = nn.Sequential(
194 | nn.Conv2d(in_features, hidden_features, 1, stride=1, padding=0),
195 | nn.BatchNorm2d(hidden_features),
196 | )
197 | self.act = nn.GELU()
198 | self.fc2 = nn.Sequential(
199 | nn.Conv2d(hidden_features, out_features, 1, stride=1, padding=0),
200 | nn.BatchNorm2d(out_features),
201 | )
202 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
203 |
204 | def forward(self, x):
205 | shortcut = x
206 | x = self.fc1(x)
207 | x = self.act(x)
208 | x = self.fc2(x)
209 | x = self.drop_path(x) + shortcut
210 | return x
211 |
212 |
213 | class MobileViG(torch.nn.Module):
214 | def __init__(self, local_blocks, local_channels,
215 | global_blocks, global_channels,
216 | dropout=0., drop_path=0., emb_dims=512,
217 | K=2, distillation=False, num_classes=1000):
218 | super(MobileViG, self).__init__()
219 |
220 | self.distillation = distillation
221 |
222 | n_blocks = sum(global_blocks) + sum(local_blocks)
223 | dpr = [x.item() for x in torch.linspace(0, drop_path, n_blocks)] # stochastic depth decay rule
224 | dpr_idx = 0
225 |
226 | self.stem = Stem(input_dim=3, output_dim=local_channels[0])
227 |
228 | # local processing with inverted residuals
229 | self.local_backbone = nn.ModuleList([])
230 | for i in range(len(local_blocks)):
231 | if i > 0:
232 | self.local_backbone.append(Downsample(local_channels[i-1], local_channels[i]))
233 | for _ in range(local_blocks[i]):
234 | self.local_backbone.append(InvertedResidual(dim=local_channels[i], mlp_ratio=4, drop_path=dpr[dpr_idx]))
235 | dpr_idx += 1
236 | self.local_backbone.append(Downsample(local_channels[-1], global_channels[0])) # transition from local to global
237 |
238 | # global processing with svga
239 | self.backbone = nn.ModuleList([])
240 | for i in range(len(global_blocks)):
241 | if i > 0:
242 | self.backbone.append(Downsample(global_channels[i-1], global_channels[i]))
243 | for j in range(global_blocks[i]):
244 | self.backbone += [nn.Sequential(
245 | Grapher(global_channels[i], drop_path=dpr[dpr_idx], K=K),
246 | FFN(global_channels[i], global_channels[i] * 4, drop_path=dpr[dpr_idx]))
247 | ]
248 | dpr_idx += 1
249 |
250 | self.prediction = nn.Sequential(nn.AdaptiveAvgPool2d(1),
251 | nn.Conv2d(global_channels[-1], emb_dims, 1, bias=True),
252 | nn.BatchNorm2d(emb_dims),
253 | nn.GELU(),
254 | nn.Dropout(dropout))
255 |
256 | self.head = nn.Conv2d(emb_dims, num_classes, 1, bias=True)
257 |
258 | if self.distillation:
259 | self.dist_head = nn.Conv2d(emb_dims, num_classes, 1, bias=True)
260 |
261 | self.model_init()
262 |
263 | def model_init(self):
264 | for m in self.modules():
265 | if isinstance(m, torch.nn.Conv2d):
266 | torch.nn.init.kaiming_normal_(m.weight)
267 | m.weight.requires_grad = True
268 | if m.bias is not None:
269 | m.bias.data.zero_()
270 | m.bias.requires_grad = True
271 |
272 | def forward(self, inputs):
273 | x = self.stem(inputs)
274 | B, C, H, W = x.shape
275 | for i in range(len(self.local_backbone)):
276 | x = self.local_backbone[i](x)
277 | for i in range(len(self.backbone)):
278 | x = self.backbone[i](x)
279 |
280 | x = self.prediction(x)
281 |
282 | if self.distillation:
283 | x = self.head(x).squeeze(-1).squeeze(-1), self.dist_head(x).squeeze(-1).squeeze(-1)
284 | if not self.training:
285 | x = (x[0] + x[1]) / 2
286 | else:
287 | x = self.head(x).squeeze(-1).squeeze(-1)
288 | return x
289 |
290 |
291 | @register_model
292 | def mobilevig_ti(pretrained=False, **kwargs):
293 | model = MobileViG(local_blocks=[2, 2, 6],
294 | local_channels=[42, 84, 168],
295 | global_blocks=[2],
296 | global_channels=[256],
297 | dropout=0.,
298 | drop_path=0.1,
299 | emb_dims=512,
300 | K=2,
301 | distillation=False,
302 | num_classes=1000)
303 | model.default_cfg = default_cfgs['mobilevig']
304 | return model
305 |
306 |
307 | @register_model
308 | def mobilevig_s(pretrained=False, **kwargs):
309 | model = MobileViG(local_blocks=[3, 3, 9],
310 | local_channels=[42, 84, 176],
311 | global_blocks=[3],
312 | global_channels=[256],
313 | dropout=0.,
314 | drop_path=0.1,
315 | emb_dims=512,
316 | K=2,
317 | distillation=False,
318 | num_classes=1000)
319 | model.default_cfg = default_cfgs['mobilevig']
320 | return model
321 |
322 |
323 | @register_model
324 | def mobilevig_m(pretrained=False, **kwargs):
325 | model = MobileViG(local_blocks=[3, 3, 9],
326 | local_channels=[42, 84, 224],
327 | global_blocks=[3],
328 | global_channels=[400],
329 | dropout=0.,
330 | drop_path=0.1,
331 | emb_dims=768,
332 | K=2,
333 | distillation=False,
334 | num_classes=1000)
335 | model.default_cfg = default_cfgs['mobilevig']
336 | return model
337 |
338 |
339 | @register_model
340 | def mobilevig_b(pretrained=False, **kwargs):
341 | model = MobileViG(local_blocks=[5, 5, 15],
342 | local_channels=[42, 84, 240],
343 | global_blocks=[5],
344 | global_channels=[464],
345 | dropout=0.,
346 | drop_path=0.1,
347 | emb_dims=768,
348 | K=2,
349 | distillation=False,
350 | num_classes=1000)
351 | model.default_cfg = default_cfgs['mobilevig']
352 | return model
353 |
354 |
355 |
356 |
357 | if __name__ == '__main__':
358 |
359 | # for img_size in [224,224*2,224*3,224*4,224*5]:
360 | # model = create_model(
361 | # 'mobilevig_s'
362 | # )
363 | # model = model.cuda()
364 | # model.eval()
365 |
366 | # x = torch.rand((1,3,img_size,img_size)).to(device = 'cuda')
367 |
368 |
369 | # macs = profile_macs(model, x)
370 | # print(f'\n\n!!!!! MobileVig macs ({img_size}): {macs*10**(-9)}\n\n')
371 |
372 |
373 | for img_size in [224,224*2,224*3,224*4,224*5]:
374 |
375 | with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA],profile_memory=True, record_shapes=True) as prof:
376 |
377 | x = torch.rand((1,3,img_size,img_size)).to(device = 'cuda')
378 |
379 | model = create_model(
380 | 'mobilevig_s'
381 | )
382 |
383 | model = model.cuda()
384 | model.eval()
385 |
386 | _ = model(x)
387 |
388 |         with open("memory_MobileVig_model.txt", "a") as f:
389 |             f.write(prof.key_averages().table())
390 |             f.write('\n\n')
391 |
392 |
393 |
394 | print('\n\n---------------\nResults saved in memory_MobileVig_model.txt')
395 |
396 |
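397 | # Quick functional check (sketch, not part of the original file):
398 | # model = create_model('mobilevig_s').eval()
399 | # with torch.no_grad():
400 | #     out = model(torch.rand(1, 3, 224, 224))
401 | # print(out.shape)  # torch.Size([1, 1000]); single head since distillation=False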
--------------------------------------------------------------------------------
/src/model/greedyvig.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch import Tensor
4 | import torch.nn.functional as F
5 | from torch.nn import Sequential as Seq
6 |
7 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
8 | from timm.models.layers import DropPath
9 | from timm.models.registry import register_model
10 |
11 | import random
12 | import warnings
13 | from timm.models import create_model
14 | from torchprofile import profile_macs
15 |
16 | warnings.filterwarnings('ignore')
17 |
18 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
19 | import sys
20 |
21 | # IMAGENET
22 | def _cfg(url='', **kwargs):
23 | return {
24 | 'url': url,
25 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
26 | 'crop_pct': .9, 'interpolation': 'bicubic',
27 | 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
28 | 'classifier': 'head',
29 | **kwargs
30 | }
31 |
32 |
33 | default_cfgs = {
34 | 'greedyvig': _cfg(crop_pct=0.9, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD)
35 | }
36 |
37 |
38 | class Stem(nn.Module):
39 | def __init__(self, input_dim, output_dim):
40 | super(Stem, self).__init__()
41 | self.stem = nn.Sequential(
42 | nn.Conv2d(input_dim, output_dim // 2, kernel_size=3, stride=2, padding=1),
43 | nn.BatchNorm2d(output_dim // 2),
44 | nn.GELU(),
45 | nn.Conv2d(output_dim // 2, output_dim, kernel_size=3, stride=2, padding=1),
46 | nn.BatchNorm2d(output_dim),
47 | nn.GELU(),
48 | )
49 |
50 | def forward(self, x):
51 | return self.stem(x)
52 |
53 |
54 | class DepthWiseSeparable(nn.Module):
55 | def __init__(self, in_dim, kernel, expansion=4):
56 | super().__init__()
57 |
58 |         hidden = int(in_dim * expansion)  # expanded width from the expansion argument
59 |         self.pw1 = nn.Conv2d(in_dim, hidden, 1)  # pointwise expansion
60 |         self.norm1 = nn.BatchNorm2d(hidden)
61 |         self.act1 = nn.GELU()
62 |         self.dw = nn.Conv2d(hidden, hidden, kernel_size=kernel, stride=1, padding=kernel // 2, groups=hidden)  # depthwise
63 |         self.norm2 = nn.BatchNorm2d(hidden)
64 |         self.act2 = nn.GELU()
65 |
66 |         self.pw2 = nn.Conv2d(hidden, in_dim, 1)  # pointwise projection
67 |         self.norm3 = nn.BatchNorm2d(in_dim)
68 |
69 | def forward(self, x):
70 | x = self.pw1(x)
71 | x = self.norm1(x)
72 | x = self.act1(x)
73 |
74 | x = self.dw(x)
75 | x = self.norm2(x)
76 | x = self.act2(x)
77 |
78 | x = self.pw2(x)
79 | x = self.norm3(x)
80 | return x
81 |
82 |
83 | class InvertedResidual(nn.Module):
84 | def __init__(self, dim, kernel, expansion_ratio=4., drop=0., drop_path=0., use_layer_scale=True, layer_scale_init_value=1e-5):
85 | super().__init__()
86 |
87 | self.dws = DepthWiseSeparable(in_dim=dim, kernel=kernel, expansion=expansion_ratio)
88 |
89 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
90 | self.use_layer_scale = use_layer_scale
91 | if use_layer_scale:
92 | self.layer_scale_2 = nn.Parameter(
93 | layer_scale_init_value * torch.ones(dim), requires_grad=True)
94 |
95 | def forward(self, x):
96 | if self.use_layer_scale:
97 | x = x + self.drop_path(self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * self.dws(x))
98 | else:
99 | x = x + self.drop_path(self.dws(x))
100 | return x
101 |
102 |
103 |
104 | class DynamicMRConv4d(nn.Module):
105 | def __init__(self, in_channels, out_channels, K):
106 | super().__init__()
107 | self.nn = nn.Sequential(
108 | nn.Conv2d(in_channels, out_channels, 1),
109 | nn.BatchNorm2d(out_channels),
110 | nn.GELU()
111 | )
112 | self.K = K
113 | self.mean = 0
114 | self.std = 0
115 |
116 | def forward(self, x):
117 | B, C, H, W = x.shape
118 |         x_j = x - x  # zeros with x's shape/dtype: running max of masked relative features
119 |
120 |         # cheap estimate of the mean pairwise distance: compare each point with its diagonally opposite quadrant rather than computing all pairs
121 | x_rolled = torch.cat([x[:, :, -H//2:, :], x[:, :, :-H//2, :]], dim=2)
122 | x_rolled = torch.cat([x_rolled[:, :, :, -W//2:], x_rolled[:, :, :, :-W//2]], dim=3)
123 |
124 |         # L1 distance between each point and its rolled counterpart
125 | norm = torch.norm((x - x_rolled), p=1, dim=1, keepdim=True)
126 |
127 | self.mean = torch.mean(norm, dim=[2,3], keepdim=True)
128 | self.std = torch.std(norm, dim=[2,3], keepdim=True)
129 |
130 |
131 | for i in range(0, H, self.K):
132 | x_rolled = torch.cat([x[:, :, -i:, :], x[:, :, :-i, :]], dim=2)
133 |
134 | dist = torch.norm((x - x_rolled), p=1, dim=1, keepdim=True)
135 |
136 | # Got 83.86%
137 | mask = torch.where(dist < self.mean - self.std, 1, 0)
138 |
139 | x_rolled_and_masked = (x_rolled - x) * mask
140 | x_j = torch.max(x_j, x_rolled_and_masked)
141 |
142 | for j in range(0, W, self.K):
143 | x_rolled = torch.cat([x[:, :, :, -j:], x[:, :, :, :-j]], dim=3)
144 |
145 | dist = torch.norm((x - x_rolled), p=1, dim=1, keepdim=True)
146 |
147 | mask = torch.where(dist < self.mean - self.std, 1, 0)
148 |
149 | x_rolled_and_masked = (x_rolled - x) * mask
150 | x_j = torch.max(x_j, x_rolled_and_masked)
151 |
152 | x = torch.cat([x, x_j], dim=1)
153 | return self.nn(x)
154 |
155 |
156 |
157 | class ConditionalPositionEncoding(nn.Module):
158 | """
159 | Implementation of conditional positional encoding. For more details refer to paper:
160 |     `Conditional Positional Encodings for Vision Transformers <https://arxiv.org/abs/2102.10882>`_
161 | """
162 | def __init__(self, in_channels, kernel_size):
163 | super().__init__()
164 | self.pe = nn.Conv2d(
165 | in_channels=in_channels,
166 | out_channels=in_channels,
167 | kernel_size=kernel_size,
168 | stride=1,
169 | padding=kernel_size // 2,
170 | bias=True,
171 | groups=in_channels
172 | )
173 |
174 | def forward(self, x):
175 | x = self.pe(x) + x
176 | return x
177 |
178 |
179 | class Grapher(nn.Module):
180 | """
181 | Grapher module with graph convolution and fc layers
182 | """
183 | def __init__(self, in_channels, K):
184 | super(Grapher, self).__init__()
185 | self.channels = in_channels
186 | self.K = K
187 |
188 | self.cpe = ConditionalPositionEncoding(in_channels, kernel_size=7)
189 | self.fc1 = nn.Sequential(
190 | nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0),
191 | nn.BatchNorm2d(in_channels),
192 | )
193 | self.graph_conv = DynamicMRConv4d(in_channels * 2, in_channels, K=self.K)
194 | self.fc2 = nn.Sequential(
195 | nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0),
196 | nn.BatchNorm2d(in_channels),
197 |         )  # project channels back to the input width
198 |
199 |
200 | def forward(self, x):
201 | x = self.cpe(x)
202 | x = self.fc1(x)
203 | x = self.graph_conv(x)
204 | x = self.fc2(x)
205 |
206 | return x
207 |
208 |
209 | class DynamicGraphConvBlock(nn.Module):
210 | def __init__(self, in_dim, drop_path=0., K=2, use_layer_scale=True, layer_scale_init_value=1e-5):
211 | super().__init__()
212 |
213 | self.mixer = Grapher(in_dim, K)
214 | self.ffn = nn.Sequential(
215 | nn.Conv2d(in_dim, in_dim * 4, kernel_size=1, stride=1, padding=0),
216 | nn.BatchNorm2d(in_dim * 4),
217 | nn.GELU(),
218 | nn.Conv2d(in_dim * 4, in_dim, kernel_size=1, stride=1, padding=0),
219 | nn.BatchNorm2d(in_dim),
220 | )
221 |
222 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
223 | self.use_layer_scale = use_layer_scale
224 | if use_layer_scale:
225 | self.layer_scale_1 = nn.Parameter(
226 | layer_scale_init_value * torch.ones(in_dim), requires_grad=True)
227 | self.layer_scale_2 = nn.Parameter(
228 | layer_scale_init_value * torch.ones(in_dim), requires_grad=True)
229 |
230 | def forward(self, x):
231 | if self.use_layer_scale:
232 | x = x + self.drop_path(self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * self.mixer(x))
233 | x = x + self.drop_path(self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * self.ffn(x))
234 | else:
235 | x = x + self.drop_path(self.mixer(x))
236 | x = x + self.drop_path(self.ffn(x))
237 | return x
238 |
239 |
240 | class Downsample(nn.Module):
241 | """
242 | Convolution-based downsample
243 | """
244 | def __init__(self, in_dim, out_dim):
245 | super().__init__()
246 | self.conv = nn.Sequential(
247 | nn.Conv2d(in_dim, out_dim, kernel_size=3, stride=2, padding=1),
248 | nn.BatchNorm2d(out_dim),
249 | )
250 | def forward(self, x):
251 | x = self.conv(x)
252 | return x
253 |
254 |
255 | class GreedyViG(torch.nn.Module):
256 | def __init__(self, blocks, channels, kernels, stride,
257 | act_func, dropout=0., drop_path=0., emb_dims=512,
258 | K=2, distillation=False, num_classes=1000):
259 | super(GreedyViG, self).__init__()
260 |
261 | self.distillation = distillation
262 | self.stage_names = ['stem', 'local_1', 'local_2', 'local_3', 'global']
263 |
264 | n_blocks = sum([sum(x) for x in blocks])
265 | dpr = [x.item() for x in torch.linspace(0, drop_path, n_blocks)] # stochastic depth decay rule
266 | dpr_idx = 0
267 |
268 | self.stem = Stem(input_dim=3, output_dim=channels[0])
269 |
270 | self.backbone = []
271 | for i in range(len(blocks)):
272 | stage = []
273 | local_stages = blocks[i][0]
274 | global_stages = blocks[i][1]
275 | if i > 0:
276 | stage.append(Downsample(channels[i-1], channels[i]))
277 | for _ in range(local_stages):
278 | stage.append(InvertedResidual(dim=channels[i], kernel=3, expansion_ratio=4, drop_path=dpr[dpr_idx]))
279 | dpr_idx += 1
280 | for _ in range(global_stages):
281 | stage.append(DynamicGraphConvBlock(channels[i], drop_path=dpr[dpr_idx], K=K[i]))
282 | dpr_idx += 1
283 | self.backbone.append(nn.Sequential(*stage))
284 |
285 | self.backbone = nn.Sequential(*self.backbone)
286 |
287 | self.prediction = nn.Sequential(nn.AdaptiveAvgPool2d(1),
288 | nn.Conv2d(channels[-1], emb_dims, kernel_size=1, bias=True),
289 | nn.BatchNorm2d(emb_dims),
290 | nn.GELU(),
291 | nn.Dropout(dropout))
292 |
293 | self.head = nn.Conv2d(emb_dims, num_classes, kernel_size=1, bias=True)
294 |
295 | if self.distillation:
296 | self.dist_head = nn.Conv2d(emb_dims, num_classes, 1, bias=True)
297 |
298 | self.model_init()
299 |
300 | def model_init(self):
301 | for m in self.modules():
302 | if isinstance(m, torch.nn.Conv2d):
303 | torch.nn.init.kaiming_normal_(m.weight)
304 | m.weight.requires_grad = True
305 | if m.bias is not None:
306 | m.bias.data.zero_()
307 | m.bias.requires_grad = True
308 |
309 | def forward(self, inputs):
310 | x = self.stem(inputs)
311 | B, C, H, W = x.shape
312 | x = self.backbone(x)
313 |
314 | x = self.prediction(x)
315 |
316 |
317 | if self.distillation:
318 |
319 | x = self.head(x).squeeze(-1).squeeze(-1), self.dist_head(x).squeeze(-1).squeeze(-1)
320 | if not self.training:
321 | x = (x[0] + x[1]) / 2
322 | else:
323 | x = self.head(x).squeeze(-1).squeeze(-1)
324 | return x
325 |
326 | @register_model
327 | def GreedyViG_S(pretrained=False, **kwargs):  # 12.0 M, 1.6 GMACs
328 | model = GreedyViG(blocks=[[2,2], [2,2], [6,2], [2,2]],
329 | channels=[48, 96, 192, 384],
330 | kernels=3,
331 | stride=1,
332 | act_func='gelu',
333 | dropout=0.,
334 | drop_path=0.1,
335 | emb_dims=768,
336 | K=[8, 4, 2, 1],
337 | distillation=False,
338 | num_classes=1000)
339 | model.default_cfg = default_cfgs['greedyvig']
340 | return model
341 |
342 | @register_model
343 | def GreedyViG_M(pretrained=False, **kwargs): # 21.9 M, 3.2 GMACs
344 | model = GreedyViG(blocks=[[3,3], [3,3], [9,3], [3,3]],
345 | channels=[56, 112, 224, 448],
346 | kernels=3,
347 | stride=1,
348 | act_func='gelu',
349 | dropout=0.,
350 | drop_path=0.1,
351 | emb_dims=768,
352 | K=[8, 4, 2, 1],
353 | distillation=False,
354 | num_classes=1000)
355 | model.default_cfg = default_cfgs['greedyvig']
356 | return model
357 |
358 | @register_model
359 | def GreedyViG_B(pretrained=False, **kwargs): # 30.9 M, 5.2 GMACs
360 | model = GreedyViG(blocks=[[4,4], [4,4], [12,4], [3,3]],
361 | channels=[64, 128, 256, 512],
362 | kernels=3,
363 | stride=1,
364 | act_func='gelu',
365 | dropout=0.,
366 | drop_path=0.1,
367 | emb_dims=768,
368 | K=[8, 4, 2, 1],
369 | distillation=False,
370 | num_classes=1000)
371 | model.default_cfg = default_cfgs['greedyvig']
372 | return model
373 |
374 |
375 |
376 | if __name__ == '__main__':
377 |
378 | # for img_size in [224,224*2,224*3,224*4,224*5]:
379 | # model = create_model(
380 | # 'GreedyViG_S',
381 | # num_classes=1000,
382 | # distillation=False,
383 | # pretrained=False,
384 | # fuse=False
385 | # )
386 | # model = model.cuda()
387 | # model.eval()
388 |
389 | # x = torch.rand((1,3,img_size,img_size)).to(device = 'cuda')
390 |
391 |
392 | # macs = profile_macs(model, x)
393 | # print(f'\n\n!!!!! GreedyViG macs ({img_size}): {macs*10**(-9)}\n\n')
394 |
395 |
396 | for img_size in [224,224*2,224*3,224*4,224*5]:
397 |
398 | with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA],profile_memory=True, record_shapes=True) as prof:
399 |
400 | x = torch.rand((1,3,img_size,img_size)).to(device = 'cuda')
401 |
402 | model = create_model(
403 | 'GreedyViG_S',
404 | num_classes=1000,
405 | distillation=False,
406 | pretrained=False,
407 | fuse=False
408 | )
409 |
410 | model = model.cuda()
411 | model.eval()
412 |
413 | _ = model(x)
414 |
415 | # use a context manager so the file is always closed
416 | with open("memory_GreedyViG_model.txt", "a") as f:
417 |
418 | f.write(prof.key_averages().table())
419 | f.write('\n\n')
420 |
421 | print('\n\n---------------\nResults saved in memory_GreedyViG_model.txt')
--------------------------------------------------------------------------------
/src/opt.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import yaml
3 |
4 | def int2bool(i):
5 | i = int(i)
6 | assert i == 0 or i == 1
7 | return i == 1
8 |
9 |
10 | # The first arg parser parses out only the --config argument; this argument is used to
11 | # load a YAML file containing key-value pairs that override the defaults for the main parser below.
12 | config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False)
13 | parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
14 | help='YAML config file specifying default arguments')
15 |
16 |
17 |
18 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
19 |
20 | # Dataset / Model parameters
21 | parser.add_argument('--data', default='/scratch/dataset/inet', type=str, help='path to dataset (default: imagenet)')
22 | parser.add_argument('--model', default='resnet101', type=str, metavar='MODEL',
23 | help='Name of model to train (default: "resnet101")')
24 | parser.add_argument('--pretrained', action='store_true', default=False,
25 | help='Start with pretrained version of specified network (if avail)')
26 | parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
27 | help='Initialize model from this checkpoint (default: none)')
28 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
29 | help='Resume full model and optimizer state from checkpoint (default: none)')
30 | parser.add_argument('--no-resume-opt', action='store_true', default=False,
31 | help='prevent resume of optimizer state when resuming model')
32 | parser.add_argument('--num-classes', type=int, default=1000, metavar='N',
33 | help='number of label classes (default: 1000)')
34 | parser.add_argument('--gp', default=None, type=str, metavar='POOL',
35 | help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
36 | """ parser.add_argument('--img-size', type=int, default=None, metavar='N',
37 | help='Image patch size (default: None => model default)') """
38 | parser.add_argument('--crop-pct', default=None, type=float,
39 | metavar='N', help='Input image center crop percent (for validation only)')
40 | parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
41 | help='Override mean pixel value of dataset')
42 | parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
43 | help='Override std deviation of dataset')
44 | parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
45 | help='Image resize interpolation type (overrides model)')
46 | parser.add_argument('-b', '--batch-size', type=int, default=32, metavar='N',
47 | help='input batch size for training (default: 32)')
48 | parser.add_argument('-vb', '--validation-batch-size-multiplier', type=int, default=1, metavar='N',
49 | help='ratio of validation batch size to training batch size (default: 1)')
50 |
51 | # Optimizer parameters
52 | parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
53 | help='Optimizer (default: "sgd")')
54 | parser.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON',
55 | help='Optimizer Epsilon (default: None, use opt default)')
56 | parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
57 | help='Optimizer Betas (default: None, use opt default)')
58 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
59 | help='Optimizer momentum (default: 0.9)')
60 | parser.add_argument('--weight-decay', type=float, default=0.0001,
61 | help='weight decay (default: 0.0001)')
62 | parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
63 | help='Clip gradient norm (default: None, no clipping)')
64 |
65 |
66 |
67 | # Learning rate schedule parameters
68 | parser.add_argument('--sched', default='step', type=str, metavar='SCHEDULER',
69 | help='LR scheduler (default: "step")')
70 | parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
71 | help='learning rate (default: 0.01)')
72 | parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
73 | help='learning rate noise on/off epoch percentages')
74 | parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
75 | help='learning rate noise limit percent (default: 0.67)')
76 | parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
77 | help='learning rate noise std-dev (default: 1.0)')
78 | parser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',
79 | help='learning rate cycle len multiplier (default: 1.0)')
80 | parser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',
81 | help='learning rate cycle limit')
82 | parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',
83 | help='warmup learning rate (default: 0.0001)')
84 | parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
85 | help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
86 | parser.add_argument('--epochs', type=int, default=200, metavar='N',
87 | help='number of epochs to train (default: 200)')
88 | parser.add_argument('--start-epoch', default=None, type=int, metavar='N',
89 | help='manual epoch number (useful on restarts)')
90 | parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
91 | help='epoch interval to decay LR')
92 | parser.add_argument('--warmup-epochs', type=int, default=3, metavar='N',
93 | help='epochs to warmup LR, if scheduler supports')
94 | parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
95 | help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
96 | parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
97 | help='patience epochs for Plateau LR scheduler (default: 10)')
98 | parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
99 | help='LR decay rate (default: 0.1)')
100 |
101 | # Augmentation & regularization parameters
102 | parser.add_argument('--no-aug', action='store_true', default=False,
103 | help='Disable all training augmentation, override other train aug args')
104 | parser.add_argument('--repeated-aug', action='store_true')
105 | parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
106 | help='Random resize scale (default: 0.08 1.0)')
107 | parser.add_argument('--ratio', type=float, nargs='+', default=[3./4., 4./3.], metavar='RATIO',
108 | help='Random resize aspect ratio (default: 0.75 1.33)')
109 | parser.add_argument('--hflip', type=float, default=0.5,
110 | help='Horizontal flip training aug probability')
111 | parser.add_argument('--vflip', type=float, default=0.,
112 | help='Vertical flip training aug probability')
113 | parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
114 | help='Color jitter factor (default: 0.4)')
115 | parser.add_argument('--aa', type=str, default=None, metavar='NAME',
116 | help='Use AutoAugment policy. "v0" or "original". (default: None)')
117 | parser.add_argument('--aug-splits', type=int, default=0,
118 | help='Number of augmentation splits (default: 0, valid: 0 or >=2)')
119 | parser.add_argument('--jsd', action='store_true', default=False,
120 | help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
121 | parser.add_argument('--reprob', type=float, default=0., metavar='PCT',
122 | help='Random erase prob (default: 0.)')
123 | parser.add_argument('--remode', type=str, default='const',
124 | help='Random erase mode (default: "const")')
125 | parser.add_argument('--recount', type=int, default=1,
126 | help='Random erase count (default: 1)')
127 | parser.add_argument('--resplit', action='store_true', default=False,
128 | help='Do not random erase first (clean) augmentation split')
129 | parser.add_argument('--mixup', type=float, default=0.0,
130 | help='mixup alpha, mixup enabled if > 0. (default: 0.)')
131 | parser.add_argument('--cutmix', type=float, default=0.0,
132 | help='cutmix alpha, cutmix enabled if > 0. (default: 0.)')
133 | parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
134 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
135 | parser.add_argument('--mixup-prob', type=float, default=1.0,
136 | help='Probability of performing mixup or cutmix when either/both is enabled')
137 | parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
138 | help='Probability of switching to cutmix when both mixup and cutmix enabled')
139 | parser.add_argument('--mixup-mode', type=str, default='batch',
140 | help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
141 | parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
142 | help='Turn off mixup after this epoch, disabled if 0 (default: 0)')
143 | parser.add_argument('--smoothing', type=float, default=0.1,
144 | help='Label smoothing (default: 0.1)')
145 | parser.add_argument('--train-interpolation', type=str, default='random',
146 | help='Training interpolation (random, bilinear, bicubic default: "random")')
147 | parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
148 | help='Dropout rate (default: 0.)')
149 | parser.add_argument('--drop-connect', type=float, default=None, metavar='PCT',
150 | help='Drop connect rate, DEPRECATED, use drop-path (default: None)')
151 | parser.add_argument('--drop-path', type=float, default=None, metavar='PCT',
152 | help='Drop path rate (default: None)')
153 | parser.add_argument('--drop-block', type=float, default=None, metavar='PCT',
154 | help='Drop block rate (default: None)')
155 |
156 | # Batch norm parameters (only works with gen_efficientnet based models currently)
157 | parser.add_argument('--bn-tf', action='store_true', default=False,
158 | help='Use Tensorflow BatchNorm defaults for models that support it (default: False)')
159 | parser.add_argument('--bn-momentum', type=float, default=None,
160 | help='BatchNorm momentum override (if not None)')
161 | parser.add_argument('--bn-eps', type=float, default=None,
162 | help='BatchNorm epsilon override (if not None)')
163 | parser.add_argument('--sync-bn', action='store_true',
164 | help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')
165 | parser.add_argument('--dist-bn', type=str, default='',
166 | help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")')
167 | parser.add_argument('--split-bn', action='store_true',
168 | help='Enable separate BN layers per augmentation split.')
169 |
170 | # Model Exponential Moving Average
171 | parser.add_argument('--model-ema', action='store_true', default=False,
172 | help='Enable tracking moving average of model weights')
173 | parser.add_argument('--model-ema-force-cpu', action='store_true', default=False,
174 | help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
175 | parser.add_argument('--model-ema-decay', type=float, default=0.9998,
176 | help='decay factor for model weights moving average (default: 0.9998)')
177 |
178 | # Misc
179 | parser.add_argument('--seed', type=int, default=42, metavar='S',
180 | help='random seed (default: 42)')
181 | parser.add_argument('--log-interval', type=int, default=50, metavar='N',
182 | help='how many batches to wait before logging training status')
183 | parser.add_argument('--recovery-interval', type=int, default=0, metavar='N',
184 | help='how many batches to wait before writing recovery checkpoint')
185 | parser.add_argument('-j', '--workers', type=int, default=4, metavar='N',
186 | help='how many training processes to use (default: 4)')
187 | parser.add_argument('--num-gpu', type=int, default=1,
188 | help='Number of GPUS to use')
189 | parser.add_argument('--save-images', action='store_true', default=False,
190 | help='save images of input batches every log interval for debugging')
191 | parser.add_argument('--amp', action='store_true', default=False,
192 | help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
193 | parser.add_argument('--apex-amp', action='store_true', default=False,
194 | help='Use NVIDIA Apex AMP mixed precision')
195 | parser.add_argument('--native-amp', action='store_true', default=False,
196 | help='Use Native Torch AMP mixed precision')
197 | parser.add_argument('--channels-last', action='store_true', default=False,
198 | help='Use channels_last memory layout')
199 | parser.add_argument('--pin-mem', action='store_true', default=False,
200 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
201 | parser.add_argument('--no-prefetcher', action='store_true', default=False,
202 | help='disable fast prefetcher')
203 | parser.add_argument('--output', default='', type=str, metavar='PATH',
204 | help='path to output folder (default: none, current dir)')
205 | parser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',
206 | help='Best metric (default: "top1")')
207 | parser.add_argument('--tta', type=int, default=0, metavar='N',
208 | help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
209 | parser.add_argument("--local-rank","--local_rank", default=0, type=int)
210 | parser.add_argument('--use-multi-epochs-loader', action='store_true', default=False,
211 | help='use the multi-epochs-loader to save time at the beginning of every epoch')
212 |
213 | parser.add_argument("--init-method", default='env://', type=str)
214 | parser.add_argument("--train-url", type=str)
215 | # newly added
216 | parser.add_argument('--attn-ratio', type=float, default=1.,
217 | help='attention ratio')
218 | parser.add_argument("--pretrain-path", default=None, type=str)
219 | parser.add_argument("--evaluate", action='store_true', default=False,
220 | help='whether evaluate the model')
221 |
222 |
223 |
224 | # ours
225 | parser.add_argument('--knn', default=9, type=int)
226 | parser.add_argument('--use-reduce-ratios', type=int2bool, default=0) # 1 == True
227 | parser.add_argument('--img-size', default=224, type=int)
228 |
229 | parser.add_argument('--use-shift', type=int2bool, default=0) # 1 == True
230 | parser.add_argument('--adapt-knn', type=int2bool, default=0) # 1 == True
231 |
232 |
233 |
234 | def parse_args():
235 | # Do we have a config file to parse?
236 | args_config, remaining = config_parser.parse_known_args()
237 | if args_config.config:
238 | with open(args_config.config, 'r') as f:
239 | cfg = yaml.safe_load(f)
240 | parser.set_defaults(**cfg)
241 |
242 | # The main arg parser parses the rest of the args, the usual
243 | # defaults will have been overridden if config file specified.
244 | args = parser.parse_args(remaining)
245 |
246 | # Cache the args as a text string to save them in the output dir later
247 | args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
248 | return args, args_text
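# Example (illustrative; the file name is hypothetical): with a YAML config
# such as
#
#   # cfg.yaml -- keys must match the argparse dest names, i.e. underscores
#   model: resnet101
#   batch_size: 128
#   lr: 0.05
#
# running `python train.py -c cfg.yaml --lr 0.01` first installs the YAML
# values as new parser defaults via set_defaults(**cfg), then parses the
# remaining CLI flags, so an explicit flag (here --lr) still wins over the
# config file.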
--------------------------------------------------------------------------------
/src/gcn_lib/torch_vertex.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import torch
4 | from torch import nn
5 | from .torch_nn import BasicConv, batched_index_select, act_layer
6 | from .torch_edge import DenseDilatedKnnGraph
7 | from .pos_embed import get_2d_relative_pos_embed
8 | import torch.nn.functional as F
9 | from timm.models.layers import DropPath
10 | import sys
11 | from .torch_local import window_partition, window_reverse, PatchEmbed, window_partition_channel_last
12 | import time
13 |
14 | from einops import rearrange, repeat
15 |
16 |
17 | class MRConv2d(nn.Module):
18 | """
19 | Max-Relative Graph Convolution (Paper: https://arxiv.org/abs/1904.03751) for dense data type
20 | """
21 | def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True):
22 | super(MRConv2d, self).__init__()
23 | self.nn = BasicConv([in_channels*2, out_channels], act, norm, bias)
24 |
25 | def forward(self, x, edge_index, y=None):
26 | x_i = batched_index_select(x, edge_index[1])
27 | if y is not None:
28 | x_j = batched_index_select(y, edge_index[0])
29 | else:
30 | x_j = batched_index_select(x, edge_index[0])
31 | x_j, _ = torch.max(x_j - x_i, -1, keepdim=True)
32 | b, c, n, _ = x.shape
33 | x = torch.cat([x.unsqueeze(2), x_j.unsqueeze(2)], dim=2).reshape(b, 2 * c, n, _)
34 | return self.nn(x)
35 |
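# --- Note (editorial): MRConv2d implements max-relative aggregation. For each
# centre node i with features x_i and its neighbours x_j (gathered via
# batched_index_select from edge_index), it reduces the relative features
# max_j (x_j - x_i) over the neighbour axis, interleaves the result with x_i
# channel-wise into a (B, 2C, N, 1) tensor, and mixes it with BasicConv.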
36 |
37 | class EdgeConv2d(nn.Module):
38 | """
39 | Edge convolution layer (with activation, batch normalization) for dense data type
40 | """
41 | def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True):
42 | super(EdgeConv2d, self).__init__()
43 | self.nn = BasicConv([in_channels * 2, out_channels], act, norm, bias)
44 |
45 | def forward(self, x, edge_index, y=None):
46 | x_i = batched_index_select(x, edge_index[1])
47 | if y is not None:
48 | x_j = batched_index_select(y, edge_index[0])
49 | else:
50 | x_j = batched_index_select(x, edge_index[0])
51 | max_value, _ = torch.max(self.nn(torch.cat([x_i, x_j - x_i], dim=1)), -1, keepdim=True)
52 | return max_value
53 |
54 |
55 | class GraphSAGE(nn.Module):
56 | """
57 | GraphSAGE Graph Convolution (Paper: https://arxiv.org/abs/1706.02216) for dense data type
58 | """
59 | def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True):
60 | super(GraphSAGE, self).__init__()
61 | self.nn1 = BasicConv([in_channels, in_channels], act, norm, bias)
62 | self.nn2 = BasicConv([in_channels*2, out_channels], act, norm, bias)
63 |
64 | def forward(self, x, edge_index, y=None):
65 | if y is not None:
66 | x_j = batched_index_select(y, edge_index[0])
67 | else:
68 | x_j = batched_index_select(x, edge_index[0])
69 | x_j, _ = torch.max(self.nn1(x_j), -1, keepdim=True)
70 | return self.nn2(torch.cat([x, x_j], dim=1))
71 |
72 |
73 | class GINConv2d(nn.Module):
74 | """
75 | GIN Graph Convolution (Paper: https://arxiv.org/abs/1810.00826) for dense data type
76 | """
77 | def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True):
78 | super(GINConv2d, self).__init__()
79 | self.nn = BasicConv([in_channels, out_channels], act, norm, bias)
80 | eps_init = 0.0
81 | self.eps = nn.Parameter(torch.Tensor([eps_init]))
82 |
83 | def forward(self, x, edge_index, y=None):
84 | if y is not None:
85 | x_j = batched_index_select(y, edge_index[0])
86 | else:
87 | x_j = batched_index_select(x, edge_index[0])
88 | x_j = torch.sum(x_j, -1, keepdim=True)
89 | return self.nn((1 + self.eps) * x + x_j)
90 |
91 |
92 | class GraphConv2d(nn.Module):
93 | """
94 | Static graph convolution layer
95 | """
96 | def __init__(self, in_channels, out_channels, conv='edge', act='relu', norm=None, bias=True):
97 | super(GraphConv2d, self).__init__()
98 | if conv == 'edge':
99 | self.gconv = EdgeConv2d(in_channels, out_channels, act, norm, bias)
100 | elif conv == 'mr':
101 | self.gconv = MRConv2d(in_channels, out_channels, act, norm, bias)
102 |
103 | elif conv == 'sage':
104 | self.gconv = GraphSAGE(in_channels, out_channels, act, norm, bias)
105 | elif conv == 'gin':
106 | self.gconv = GINConv2d(in_channels, out_channels, act, norm, bias)
107 | else:
108 | raise NotImplementedError('conv:{} is not supported'.format(conv))
109 |
110 | def forward(self, x, edge_index, y=None):
111 | return self.gconv(x, edge_index, y)
112 |
113 |
114 | class DyGraphConv2d(GraphConv2d):
115 | """
116 | Dynamic graph convolution layer
117 | """
118 | def __init__(self, in_channels, out_channels, kernel_size=9, dilation=1, conv='edge', act='relu',
119 | norm=None, bias=True, stochastic=False, epsilon=0.0, r=1):
120 | super(DyGraphConv2d, self).__init__(in_channels, out_channels, conv, act, norm, bias)
121 | self.k = kernel_size
122 | self.d = dilation
123 | self.r = r
124 | self.dilated_knn_graph = DenseDilatedKnnGraph(kernel_size, dilation, stochastic, epsilon)
125 | self.debug = True
126 |
127 | def _mask_edge_index(self, edge_index, adj_mask):
128 | # edge_index: [2, B, N, k]
129 | # adj_mask: [B, N, k] (=1 -> keep / =0 -> masked)
130 |
131 | if adj_mask is None:
132 | return edge_index
133 |
134 |
135 | # edge_index_j: [B, N, k]
136 | edge_index_j = edge_index[0]
137 | edge_index_i = edge_index[1]
138 | adj_mask_inv = torch.ones_like(adj_mask) - adj_mask
139 | # adj_mask: [[1,1,0,0],
140 | # [1,1,1,0]]
141 | #
142 | # edge_index_j: [[14,43,20,21],
143 | # [18,12,32,24]]
144 | #
145 | # edge_index_i: [[0,0,0,0],
146 | # [1,1,1,1]]
147 | #
148 | # output: [[14,43,0,0],
149 | # [18,12,32,1]]
150 | edge_index_j = (edge_index_j * adj_mask) + (edge_index_i * adj_mask_inv)
151 |
152 | return torch.stack((edge_index_j, edge_index_i), dim=0).long()
153 |
154 |
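# --- Note (editorial): masked positions are redirected to the centre node's
# own index (edge_index_i), i.e. they become self-loops. For relative
# aggregations such as max-relative, a self-loop contributes x_i - x_i = 0,
# so masked edges add no information beyond a zero difference vector.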
155 |
156 | def forward(self, x, relative_pos=None, adj_mask = None):
157 | # print('Doing gnn')
158 | B, C, H, W = x.shape
159 | y = None
160 | if self.r > 1:
161 | y = F.avg_pool2d(x, self.r, self.r)
162 | # print(f'y: {y.shape}')
163 | y = y.reshape(B, C, -1, 1).contiguous()
164 | x = x.reshape(B, C, -1, 1).contiguous()
165 |
166 | edge_index = self.dilated_knn_graph(x, y, relative_pos)
167 |
168 | edge_index = self._mask_edge_index(edge_index, adj_mask)
169 |
170 | x = super(DyGraphConv2d, self).forward(x, edge_index, y)
171 |
172 |
173 | if self.debug:
174 | return x.reshape(x.shape[0], -1, H, W).contiguous(), edge_index
175 | return x.reshape(x.shape[0], -1, H, W).contiguous()
176 |
177 |
178 |
179 |
180 |
181 |
182 | class Grapher(nn.Module):
183 | """
184 | Grapher module with graph convolution and fc layers
185 | """
186 | def __init__(self, in_channels, kernel_size=9, dilation=1, conv='edge', act='relu', norm=None,
187 | bias=True, stochastic=False, epsilon=0.0, r=1, n=196, drop_path=0.0, relative_pos=False):
188 | super(Grapher, self).__init__()
189 | self.channels = in_channels
190 | self.n = n
191 | self.r = r
192 | self.fc1 = nn.Sequential(
193 | nn.Conv2d(in_channels, in_channels, 1, stride=1, padding=0),
194 | nn.BatchNorm2d(in_channels),
195 | )
196 | self.graph_conv = DyGraphConv2d(in_channels, in_channels * 2, kernel_size, dilation, conv,
197 | act, norm, bias, stochastic, epsilon, r)
198 | self.fc2 = nn.Sequential(
199 | nn.Conv2d(in_channels * 2, in_channels, 1, stride=1, padding=0),
200 | nn.BatchNorm2d(in_channels),
201 | )
202 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
203 | self.relative_pos = None
204 | if relative_pos:
205 | print('using relative_pos')
206 | relative_pos_tensor = torch.from_numpy(np.float32(get_2d_relative_pos_embed(in_channels,
207 | int(n**0.5)))).unsqueeze(0).unsqueeze(1)
208 | relative_pos_tensor = F.interpolate(
209 | relative_pos_tensor, size=(n, n//(r*r)), mode='bicubic', align_corners=False)
210 | self.relative_pos = nn.Parameter(-relative_pos_tensor.squeeze(1), requires_grad=False)
211 |
212 | def _get_relative_pos(self, relative_pos, H, W):
213 | if relative_pos is None or H * W == self.n:
214 | return relative_pos
215 | else:
216 | N = H * W
217 | N_reduced = N // (self.r * self.r)
218 | return F.interpolate(relative_pos.unsqueeze(0), size=(N, N_reduced), mode="bicubic").squeeze(0)
219 |
220 | def forward(self, x):
221 | _tmp = x
222 | x = self.fc1(x)
223 | B, C, H, W = x.shape
224 | relative_pos = self._get_relative_pos(self.relative_pos, H, W)
225 |
226 | x, edge_index = self.graph_conv(x, relative_pos)
227 | x = self.fc2(x)
228 | x = self.drop_path(x) + _tmp
229 | return x
230 |
231 |
232 | class WindowGrapher(nn.Module):
233 | """
234 | Local Grapher module with graph convolution and fc layers
235 | """
236 | def __init__(
237 | self,
238 | in_channels,
239 | kernel_size=9,
240 | windows_size = 7,
241 | dilation=1,
242 | conv='edge',
243 | act='relu',
244 | norm=None,
245 | bias=True,
246 | stochastic=False,
247 | epsilon=0.0,
248 | drop_path=0.0,
249 | relative_pos=False,
250 | shift_size = 0,
251 | r = 1,
252 | input_resolution = (224//4,224//4),
253 | adapt_knn = False):
254 | super(WindowGrapher, self).__init__()
255 |
256 | if min(input_resolution) <= windows_size:
257 | # if window size is larger than input resolution, we don't partition windows
258 | shift_size = 0
259 | windows_size = min(input_resolution)
260 | assert 0 <= shift_size < windows_size, "shift_size must be in [0, windows_size)"
261 |
262 |
263 | max_connection_allowed = (windows_size // r)**2
264 | if shift_size > 0:
265 | assert shift_size % r == 0
266 | max_connection_allowed = (shift_size // r)**2
267 |
268 | assert kernel_size <= max_connection_allowed, f'trying k = {kernel_size} while the max can be: {max_connection_allowed}'
269 |
270 |
271 | self.windows_size = windows_size
272 | self.shift_size = shift_size
273 | self.r = r
274 |
275 |
276 | n_nodes = self.windows_size * self.windows_size
277 |
278 | self.fc1 = nn.Sequential(
279 | nn.Conv2d(in_channels, in_channels, 1, stride=1, padding=0),
280 | nn.BatchNorm2d(in_channels),
281 | )
282 |
283 | self.graph_conv = DyGraphConv2d(in_channels, (in_channels * 2), kernel_size, dilation, conv,
284 | act, norm, bias, stochastic, epsilon, r = r)
285 |
286 | self.fc2 = nn.Sequential(
287 | nn.Conv2d(in_channels * 2, in_channels, 1, stride=1, padding=0),
288 | nn.BatchNorm2d(in_channels),
289 | )
290 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
291 |
292 |
293 |
294 |
295 |
296 | self.relative_pos = None
297 | if relative_pos:
298 | print('using relative_pos')
299 | relative_pos_tensor = torch.from_numpy(np.float32(get_2d_relative_pos_embed(in_channels,
300 | int(n_nodes**0.5)))).unsqueeze(0).unsqueeze(1)
301 | relative_pos_tensor = F.interpolate(
302 | relative_pos_tensor, size=(n_nodes, n_nodes//(r*r)), mode='bicubic', align_corners=False)
303 | self.relative_pos = nn.Parameter(-relative_pos_tensor.squeeze(1), requires_grad=False)
304 |
305 |
306 | attn_mask = None
307 | adj_mask = None
308 | if self.shift_size > 0:
309 | print(f'Shifting windows!')
310 | H, W = input_resolution
311 | # calculate attention mask for SW-MSA
312 | img_mask = torch.zeros((1, 1, H, W))
313 | h_slices = (slice(0, -self.windows_size),
314 | slice(-self.windows_size, -self.shift_size),
315 | slice(-self.shift_size, None))
316 | w_slices = (slice(0, -self.windows_size),
317 | slice(-self.windows_size, -self.shift_size),
318 | slice(-self.shift_size, None))
319 |
320 |
321 | cnt = 0
322 | for h in h_slices:
323 | for w in w_slices:
324 | img_mask[:, :, h, w] = cnt
325 | cnt += 1
326 |
327 | # print(img_mask)
328 | # print('\n\n')
329 |
330 | mask_windows_unf = window_partition(img_mask, self.windows_size) # nW, 1, windows_size, windows_size,
331 |
332 |
333 | mask_windows = mask_windows_unf.view(-1, self.windows_size * self.windows_size)
334 |
335 | if self.r > 1:
336 | mask_windows_y = F.max_pool2d(mask_windows_unf, self.r, self.r)
337 | mask_windows_y = mask_windows_y.view(-1, (self.windows_size // self.r) * (self.windows_size // self.r))
338 | else:
339 | mask_windows_y = mask_windows
340 |
341 | attn_mask = mask_windows_y.unsqueeze(1) - mask_windows.unsqueeze(2) # nW x N x (N // r)
342 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(1000000.0)).masked_fill(attn_mask == 0, float(0.0))
343 |
344 | # Get n_connections_allowed for each node in each windows
345 | if adapt_knn:
346 | print('Adapting knn!')
347 | adj_mask = torch.empty((attn_mask.shape[0], attn_mask.shape[1], kernel_size)) # nW x N x k
348 | for w in range(attn_mask.shape[0]):
349 | for i in range(attn_mask.shape[1]):
350 | all_connection = torch.sum(attn_mask[w,i] == 0)
351 | scaled_knn = (kernel_size * all_connection) // (self.windows_size * (self.windows_size // r))
352 | n_connections_allowed = int(max(scaled_knn, 3.0))
353 | # print(f'Window: {w} node {i} - allowed_connection = {all_connection} (k = {n_connections_allowed})')
354 | masked = torch.zeros(kernel_size - n_connections_allowed)
355 | un_masked = torch.ones(n_connections_allowed)
356 | adj_mask[w,i] = torch.cat([un_masked,masked],dim=0)
357 |
358 | self.register_buffer("attn_mask", attn_mask)
359 | self.register_buffer("adj_mask", adj_mask)
360 |
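# --- Note (editorial): this mask construction mirrors Swin-style shifted
# windows. Each combination of the three h/w slices gets a distinct region
# id `cnt`; after partitioning, attn_mask is zero for position pairs from
# the same region and 1e6 otherwise. Merged with relative_pos in
# _merge_pos_attn, it is intended to bias the kNN graph construction in
# DenseDilatedKnnGraph so that neighbours are not selected across the
# cyclic-shift seam.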
361 |
362 | def _merge_pos_attn(self, batch_size):
363 | if self.attn_mask is None:
364 | return self.relative_pos
365 |
366 | if self.relative_pos is None:
367 | print('Warning: relative_pos is None while attn_mask is set; falling back to a zero positional term')
368 | self.relative_pos = torch.zeros((1, self.attn_mask.shape[1], self.attn_mask.shape[2])).to(self.attn_mask.device)
369 |
370 | nW_nGH = self.attn_mask.shape[0]
371 | #print(f'Attention mask: {self.attn_mask.shape}')
372 | return self.relative_pos.repeat(nW_nGH*batch_size,1,1) + self.attn_mask.repeat(batch_size, 1, 1) # B, N, N
373 |
374 |
375 |
376 |
377 |
378 | def forward(self, x):
379 |
380 |
381 | _tmp = x
382 | x = self.fc1(x)
383 | B, C, H, W = x.shape
384 |
385 | # cyclic shift
386 | if self.shift_size > 0:
387 | x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(2, 3))
388 |
389 |
390 |
391 | x = window_partition(x, window_size = self.windows_size)
392 |
393 | # merge relative_pos w/ attn_mask TODO
394 | pos_att = self._merge_pos_attn(batch_size=B)
395 | # pos_att: b*nW, N, N
396 |
397 | adj_mask = None
398 | if self.adj_mask is not None:
399 | adj_mask = self.adj_mask.repeat(B,1,1)
400 | x, edge_index = self.graph_conv(x, pos_att, adj_mask)
401 |
402 |
403 | x = window_reverse(x, self.windows_size, H=H, W=W)
404 |
405 | # reverse cyclic shift
406 | if self.shift_size > 0:
407 | x = torch.roll(x, shifts=(self.shift_size, self.shift_size), dims=(2, 3))
408 |
409 | x = self.fc2(x)
410 |
411 | x = self.drop_path(x) + _tmp
412 |
413 | return x
414 |
415 |
--------------------------------------------------------------------------------
/src/train_trasnfer_learning.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | import datetime
4 | import os
5 | import time
6 |
7 | import torch
8 | import torch.utils.data
9 | import torchvision
10 | import torchvision.models.detection
11 | import torchvision.models.detection.mask_rcnn
12 | import utils
13 |
14 |
15 | import torch.nn as nn
16 | import sys
17 | import numpy as np
18 | from opt_transfer import get_args_parser
19 | import errno
20 |
21 | import wandb
22 | import random
23 |
24 | import torch.backends.cudnn as cudnn
25 | import warnings
26 |
27 | from model.transfer_models import get_model
28 | from dataloaders.celeba_hq import get_celeba
29 |
30 | from sklearn.metrics import confusion_matrix
31 | import seaborn as sn
32 | import pandas as pd
33 | import matplotlib.pyplot as plt
34 |
35 |
36 | def set_seed(seed):
37 | random.seed(seed)
38 | torch.manual_seed(seed)
39 | cudnn.deterministic = True
40 | cudnn.benchmark = False
41 | np.random.seed(seed)
42 | torch.cuda.manual_seed(seed)
43 |
44 | os.environ["PYTHONHASHSEED"] = str(seed)
45 | torch.cuda.manual_seed_all(seed)
46 |
47 |
48 | def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None):
49 | model.train()
50 |
51 |
52 | metric_logger = utils.MetricLogger(delimiter=" ")
53 | metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
54 | metric_logger.add_meter("img/s", utils.SmoothedValue(window_size=10, fmt="{value}"))
55 |
56 | header = f"Epoch: [{epoch}]"
57 | for i, (image, target) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
58 |
59 | if image.shape[0] <= 1:
60 | continue
61 |
62 | start_time = time.time()
63 | image, target = image.to(device), target.to(device)
64 |
65 | model.zero_grad()
66 | with torch.cuda.amp.autocast(enabled=scaler is not None):
67 | output = model(image)
68 | loss = criterion(output, target)
69 |
70 | optimizer.zero_grad()
71 | if scaler is not None:
72 | scaler.scale(loss).backward()
73 | if args.clip_grad_norm is not None:
74 | # when clipping gradients, first unscale the gradients of the optimizer's assigned params
75 | scaler.unscale_(optimizer)
76 | nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
77 | scaler.step(optimizer)
78 | scaler.update()
79 | else:
80 | loss.backward()
81 | if args.clip_grad_norm is not None:
82 | nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
83 | optimizer.step()
84 |
85 | if model_ema and i % args.model_ema_steps == 0:
86 | model_ema.update_parameters(model)
87 | if epoch < args.lr_warmup_epochs:
88 | # Reset ema buffer to keep copying weights during warmup period
89 | model_ema.n_averaged.fill_(0)
90 |
91 | acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
92 | batch_size = image.shape[0]
93 | metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
94 | metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
95 | metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
96 | metric_logger.meters["img/s"].update(batch_size / (time.time() - start_time))
97 |
98 |
99 |
100 | return metric_logger
101 |
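# --- Note (editorial): the AMP branch above follows the standard GradScaler
# recipe: scaler.scale(loss).backward() -> scaler.unscale_(optimizer) (only
# when clipping) -> scaler.step(optimizer) -> scaler.update(). Calling
# unscale_ before clip_grad_norm_ makes the clipping threshold apply to the
# true (unscaled) gradient magnitudes.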
102 |
103 | def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=""):
104 | model.eval()
105 | n_threads = torch.get_num_threads()
106 | # NOTE: thread limiting inherited from torchvision's reference scripts (paste_masks_in_image is not used here)
107 | torch.set_num_threads(1)
108 | metric_logger = utils.MetricLogger(delimiter=" ")
109 | header = f"Test: {log_suffix}"
110 |
111 | num_processed_samples = 0
112 |
113 | with torch.inference_mode():
114 | for i,(image, target) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
115 | image = image.to(device)
116 | target = target.to(device)
117 |
118 | if torch.cuda.is_available():
119 | torch.cuda.synchronize()
120 |
121 | output = model(image)
122 | loss = criterion(output, target)
123 |
124 | acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
125 | # FIXME need to take into account that the datasets
126 | # could have been padded in distributed setup
127 | batch_size = image.shape[0]
128 | metric_logger.update(loss=loss.item())
129 | metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
130 | metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
131 | num_processed_samples += batch_size
132 |
133 |
134 |
135 | # gather the stats from all processes
136 |
137 | num_processed_samples = utils.reduce_across_processes(num_processed_samples)
138 |
139 | metric_logger.synchronize_between_processes()
140 |
141 | print(f"{header} Acc@1 {metric_logger.acc1.global_avg:.3f} Acc@5 {metric_logger.acc5.global_avg:.3f}")
142 |
143 | torch.set_num_threads(n_threads)
144 |
145 |
146 | return metric_logger.acc1.global_avg, metric_logger.acc5.global_avg, metric_logger.loss.global_avg
147 |
148 |
149 | def main(args):
150 |
151 |
152 | if args.seed is not None:
153 | set_seed(args.seed)
154 |
155 | model, params, trainable_parameters, n_classes = get_model(
156 | model_type = args.model_type,
157 | use_shift = args.use_shift,
158 | adapt_knn= args.adapt_knn,
159 | checkpoint = args.checkpoint,
160 | freezed = args.freezed,
161 | dataset = args.dataset,
162 | crop_size=args.crop_size)
163 |
164 |
165 | if args.num_gpu > 1:
166 | print('Using DataParallel')
167 | # model = torch.nn.DataParallel(model)
168 | model = nn.DataParallel(model)
169 |
170 | print("Let's use", torch.cuda.device_count(), "GPUs!")
171 |
172 | print('\n\n--------------------------------------')
173 |
174 | print(f"Parameters: {params}")
175 | print(f"Trainable Parameters: {trainable_parameters}")
176 |
177 | print('--------------------------------------\n\n')
178 |
179 | if not args.test_only:
180 | wandb.log({
181 | f'params/parameters':params,
182 | f'params/trainable_params':trainable_parameters
183 | })
184 |
185 |
186 | if args.save_dir and not args.test_only:
187 | print(f'# Results will be saved in {args.save_dir}')
188 | try:
189 | os.makedirs(args.save_dir)
190 | except OSError as e:
191 | if e.errno != errno.EEXIST:
192 | raise
193 |
194 |
195 | print(args)
196 |
197 | assert torch.cuda.is_available(), '# --- Cuda Not Available!!'
198 | device = 'cuda'
199 | model.to(device)
200 |
201 |
202 | if args.use_deterministic_algorithms:
203 | torch.use_deterministic_algorithms(True)
204 |
205 | # Data loading code
206 | print("Loading data")
207 |
208 | if args.crop_size is None:
209 | crop_size = 224
210 | if '256' in args.model_type:
211 | crop_size = 256
212 |
213 | if hasattr(model, 'default_cfg'):
214 | default_cfg = model.default_cfg
215 | input_sizes = list(default_cfg['input_size'])
216 | assert input_sizes[1] == input_sizes[2] and len(input_sizes) == 3, f'input sizes for model {args.model_type} is {input_sizes}'
217 | crop_size = int(input_sizes[1])
218 | elif args.crop_size is not None:  # otherwise keep the 224/256 fallback computed above
219 | crop_size = args.crop_size
220 |
221 |
222 |
223 | print(f'\n\n!!!Dataloaders with crop size: {crop_size}\n\n')
224 |
225 | args.distributed = False
226 | if args.dataset == 'CelebA':
227 | drop_last=False
228 | if 'pvig' in args.model_type:
229 | drop_last=True
230 |
231 | train_loader, valid_loader, test_loader, train_sampler, num_classes, _ = get_celeba(args, get_train_sampler=True, crop_size=crop_size, drop_last=drop_last)
232 | else:
233 | raise NotImplementedError(f'Dataset {args.dataset} not yet implemented ')
234 |
235 | if valid_loader is not None:
236 | valid_loader = None if len(valid_loader) == 0 else valid_loader
237 |
238 |
239 | print(f'Train Dataloader: {len(train_loader)}')
240 | if valid_loader is not None:
241 | print(f'Valid Dataloader: {len(valid_loader)}')
242 | print(f'Test Dataloader: {len(test_loader)}')
243 | print(f'Found classes: {num_classes}')
244 |
245 | assert num_classes == n_classes, f'Found {num_classes} classes in the dataset, while you specified {n_classes}'
246 |
247 |
248 | if args.loss == 'cross_entropy':
249 | criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
250 | elif args.loss == 'nll':
251 | criterion = nn.NLLLoss()
252 |
253 |
254 | custom_keys_weight_decay = []
255 | if args.bias_weight_decay is not None:
256 | custom_keys_weight_decay.append(("bias", args.bias_weight_decay))
257 | if args.transformer_embedding_decay is not None:
258 | for key in ["class_token", "position_embedding", "relative_position_bias_table"]:
259 | custom_keys_weight_decay.append((key, args.transformer_embedding_decay))
260 |
261 | parameters = utils.set_weight_decay(
262 | model,
263 | args.weight_decay,
264 | norm_weight_decay=args.norm_weight_decay,
265 | custom_keys_weight_decay=custom_keys_weight_decay if len(custom_keys_weight_decay) > 0 else None,
266 | )
267 |
268 | opt_name = args.opt.lower()
269 | if opt_name.startswith("sgd"):
270 | optimizer = torch.optim.SGD(
271 | parameters,
272 | lr=args.lr,
273 | momentum=args.momentum,
274 | weight_decay=args.weight_decay,
275 | nesterov="nesterov" in opt_name,
276 | )
277 | elif opt_name == "rmsprop":
278 | optimizer = torch.optim.RMSprop(
279 | parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, eps=0.0316, alpha=0.9
280 | )
281 | elif opt_name == "adamw":
282 | optimizer = torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay)
283 | elif opt_name == "adam":
284 | optimizer = torch.optim.Adam(parameters, lr=args.lr)
285 | else:
286 | raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD, RMSprop, Adam and AdamW are supported.")
287 |
288 | scaler = torch.cuda.amp.GradScaler() if args.amp else None
289 |
290 | args.lr_scheduler = args.lr_scheduler.lower()
291 | if args.lr_scheduler == "steplr":
292 | main_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
293 | elif args.lr_scheduler == "cosineannealinglr":
294 | main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
295 | optimizer, T_max=args.epochs - args.lr_warmup_epochs, eta_min=args.lr_min
296 | )
297 | elif args.lr_scheduler == "exponentiallr":
298 | main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_gamma)
299 | elif args.lr_scheduler == "constant":
300 | main_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0, total_iters=args.epochs - args.lr_warmup_epochs)
301 | else:
302 | raise RuntimeError(
303 | f"Invalid lr scheduler '{args.lr_scheduler}'. Only StepLR, CosineAnnealingLR and ExponentialLR "
304 | "are supported."
305 | )
306 |
307 | if args.lr_warmup_epochs > 0:
308 | if args.lr_warmup_method == "linear":
309 | warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
310 | optimizer, start_factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
311 | )
312 | elif args.lr_warmup_method == "constant":
313 | warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
314 | optimizer, factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
315 | )
316 | else:
317 | raise RuntimeError(
318 | f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported."
319 | )
320 | lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
321 | optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[args.lr_warmup_epochs]
322 | )
323 | else:
324 | lr_scheduler = main_lr_scheduler
325 |
326 |
327 |
328 | if args.resume:
329 | checkpoint = torch.load(args.resume, map_location="cpu")
330 | model.load_state_dict(checkpoint["model"])
331 | if not args.test_only:
332 | optimizer.load_state_dict(checkpoint["optimizer"])
333 | lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
334 | args.start_epoch = checkpoint["epoch"] + 1
335 | if scaler and "scaler" in checkpoint:
336 | scaler.load_state_dict(checkpoint["scaler"])
337 |
338 | if args.test_only:
339 | # torch.backends.cudnn.deterministic = True
340 | test_acc, test_acc_5, test_loss = evaluate(model, criterion, test_loader, device=device)
341 | print(f'Test acc: {test_acc}')
342 |
343 | return test_acc, test_acc_5, test_loss
344 |
345 |
346 |
347 | print("Start training")
348 | start_time = time.time()
349 | for epoch in range(args.start_epoch, args.epochs):
350 | print(f'\n\n--------\nStarting epoch: {epoch}')
351 | metric_logger = train_one_epoch(model, criterion, optimizer, train_loader, device, epoch, args, None, scaler)
352 | print(f'train acc: {metric_logger.acc1.global_avg}')
353 |
354 | wandb.log({
355 | f'train/train_loss':metric_logger.loss.global_avg,
356 | f'lr':metric_logger.lr.global_avg,
357 | f'train/train_acc@1':metric_logger.acc1.global_avg,
358 | f'train/train_acc@5':metric_logger.acc5.global_avg
359 | }, step = epoch)
360 | lr_scheduler.step()
361 |
362 |
363 | # validate after every epoch
364 | val_acc = 0.0
365 | if valid_loader is not None:
366 | val_acc, val_acc_5, val_loss = evaluate(model, criterion, valid_loader, device=device)
367 |
368 | print(f'valid acc: {val_acc}')
369 | wandb.log({
370 | f'val/val_loss':val_loss,
371 | f'val/val_acc@1':val_acc,
372 | f'val/val_acc@5':val_acc_5
373 | }, step = epoch)
374 |
375 |
376 | # evaluate after every epoch
377 | test_acc, test_acc_5, test_loss = evaluate(model, criterion, test_loader, device=device)
378 |
379 |
380 | print(f'test acc: {test_acc}')
381 | wandb.log({
382 | f'test/test_loss':test_loss,
383 | f'test/test_acc@1':test_acc,
384 | f'test/test_acc@5':test_acc_5
385 | }, step = epoch)
386 |
387 | if args.save_dir:
388 | checkpoint = {
389 | "model": model.state_dict(),
390 | "optimizer": optimizer.state_dict(),
391 | "lr_scheduler": lr_scheduler.state_dict(),
392 | "args": args,
393 | "epoch": epoch,
394 | "train_acc": metric_logger.acc1.global_avg,
395 | "val_acc": val_acc,
396 | "test_acc": test_acc,
397 | "params":params
398 | }
399 | if scaler:
400 | checkpoint["scaler"] = scaler.state_dict()
401 |
402 | utils.save_on_master(checkpoint, os.path.join(args.save_dir, "checkpoint.pth"))
403 |
404 |
405 | total_time = time.time() - start_time
406 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
407 | print(f'\n\n--------------\nEnd training')
408 | print(f"Training time {total_time_str}")
409 | print(f'Train acc@1: {metric_logger.acc1.global_avg}')
410 | print(f'Test acc@1: {test_acc}')
411 |
412 |
413 | wandb.log({
414 | f'end/train_acc@1': metric_logger.acc1.global_avg,
415 | f'end/val_acc@1': val_acc,
416 | f'end/test_acc@1': test_acc
417 | })
418 |
419 |
420 |
421 |
422 |
423 |
424 | if __name__ == "__main__":
425 |
426 | # torch.autograd.set_detect_anomaly(True)
427 |
428 | args = get_args_parser().parse_args()
429 |
430 | print('Using GPU : ' + str(torch.cuda.current_device()) + ' from ' + str(torch.cuda.device_count()) + ' devices')
431 | print(args.save_dir)
432 |
433 | if not args.test_only:
434 | wandb.login()
435 | wandb_name = f'{args.dataset}/{args.model_type}/train_seed{args.seed}'
436 |
437 | wandb.init(
438 | # Set the project where this run will be logged
439 | project='WACV_WiGNet_transfer_learning_high_res',
440 | # We pass a run name (otherwise it’ll be randomly assigned, like sunshine-lollypop-10)
441 | name=wandb_name,
442 | config=args)
443 |
444 | args.save_dir += wandb_name
445 | else:
446 | args.save_dir = None
447 |
448 |
449 | main(args=args)
450 |
451 | if not args.test_only:
452 | wandb.run.finish()
453 |
454 |
455 |
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import datetime
3 | import errno
4 | import hashlib
5 | import os
6 | import time
7 | from collections import defaultdict, deque, OrderedDict
8 | from typing import List, Optional, Tuple
9 |
10 | import torch
11 | import torch.distributed as dist
12 |
13 |
14 | class SmoothedValue:
15 | """Track a series of values and provide access to smoothed values over a
16 | window or the global series average.
17 | """
18 |
19 | def __init__(self, window_size=20, fmt=None):
20 | if fmt is None:
21 | fmt = "{median:.4f} ({global_avg:.4f})"
22 | self.deque = deque(maxlen=window_size)
23 | self.total = 0.0
24 | self.count = 0
25 | self.fmt = fmt
26 |
27 | def update(self, value, n=1):
28 | self.deque.append(value)
29 | self.count += n
30 | self.total += value * n
31 |
32 | def synchronize_between_processes(self):
33 | """
34 | Warning: does not synchronize the deque!
35 | """
36 | t = reduce_across_processes([self.count, self.total])
37 | t = t.tolist()
38 | self.count = int(t[0])
39 | self.total = t[1]
40 |
41 | @property
42 | def median(self):
43 | d = torch.tensor(list(self.deque))
44 | return d.median().item()
45 |
46 | @property
47 | def avg(self):
48 | d = torch.tensor(list(self.deque), dtype=torch.float32)
49 | return d.mean().item()
50 |
51 | @property
52 | def global_avg(self):
53 | return self.total / self.count
54 |
55 | @property
56 | def max(self):
57 | return max(self.deque)
58 |
59 | @property
60 | def value(self):
61 | return self.deque[-1]
62 |
63 | def __str__(self):
64 | return self.fmt.format(
65 | median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value
66 | )
67 |
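# Example (illustrative): the deque only keeps the last `window_size` values,
# while `total`/`count` accumulate over every update:
#   v = SmoothedValue(window_size=3)
#   for x in (1.0, 2.0, 3.0, 4.0):
#       v.update(x)
#   v.avg        -> 3.0  (mean of the window [2, 3, 4])
#   v.global_avg -> 2.5  (10.0 / 4 over all updates)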
68 |
69 | class MetricLogger:
70 | def __init__(self, delimiter="\t"):
71 | self.meters = defaultdict(SmoothedValue)
72 | self.delimiter = delimiter
73 |
74 | def update(self, **kwargs):
75 | for k, v in kwargs.items():
76 | if isinstance(v, torch.Tensor):
77 | v = v.item()
78 | assert isinstance(v, (float, int))
79 | self.meters[k].update(v)
80 |
81 | def __getattr__(self, attr):
82 | if attr in self.meters:
83 | return self.meters[attr]
84 | if attr in self.__dict__:
85 | return self.__dict__[attr]
86 | raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")
87 |
88 | def __str__(self):
89 | loss_str = []
90 | for name, meter in self.meters.items():
91 | loss_str.append(f"{name}: {str(meter)}")
92 | return self.delimiter.join(loss_str)
93 |
94 | def synchronize_between_processes(self):
95 | for meter in self.meters.values():
96 | meter.synchronize_between_processes()
97 |
98 | def add_meter(self, name, meter):
99 | self.meters[name] = meter
100 |
101 | def log_every(self, iterable, print_freq, header=None):
102 | i = 0
103 | if not header:
104 | header = ""
105 | start_time = time.time()
106 | end = time.time()
107 | iter_time = SmoothedValue(fmt="{avg:.4f}")
108 | data_time = SmoothedValue(fmt="{avg:.4f}")
109 | space_fmt = ":" + str(len(str(len(iterable)))) + "d"
110 | if torch.cuda.is_available():
111 | log_msg = self.delimiter.join(
112 | [
113 | header,
114 | "[{0" + space_fmt + "}/{1}]",
115 | "eta: {eta}",
116 | "{meters}",
117 | "time: {time}",
118 | "data: {data}",
119 | "max mem: {memory:.0f}",
120 | ]
121 | )
122 | else:
123 | log_msg = self.delimiter.join(
124 | [header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", "{meters}", "time: {time}", "data: {data}"]
125 | )
126 | MB = 1024.0 * 1024.0
127 | for obj in iterable:
128 | data_time.update(time.time() - end)
129 | yield obj
130 | iter_time.update(time.time() - end)
131 | if i % print_freq == 0:
132 | eta_seconds = iter_time.global_avg * (len(iterable) - i)
133 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
134 | if torch.cuda.is_available():
135 | print(
136 | log_msg.format(
137 | i,
138 | len(iterable),
139 | eta=eta_string,
140 | meters=str(self),
141 | time=str(iter_time),
142 | data=str(data_time),
143 | memory=torch.cuda.max_memory_allocated() / MB,
144 | )
145 | )
146 | else:
147 | print(
148 | log_msg.format(
149 | i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time)
150 | )
151 | )
152 | i += 1
153 | end = time.time()
154 | total_time = time.time() - start_time
155 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
156 | print(f"{header} Total time: {total_time_str}")
157 |
158 |
159 | class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel):
160 | """Maintains moving averages of model parameters using an exponential decay.
161 | ``ema_avg = decay * avg_model_param + (1 - decay) * model_param``
162 | `torch.optim.swa_utils.AveragedModel <https://pytorch.org/docs/stable/optim.html#torch.optim.swa_utils.AveragedModel>`_
163 | is used to compute the EMA.
164 | """
165 |
166 | def __init__(self, model, decay, device="cpu"):
167 | def ema_avg(avg_model_param, model_param, num_averaged):
168 | return decay * avg_model_param + (1 - decay) * model_param
169 |
170 | super().__init__(model, device, ema_avg, use_buffers=True)
171 |
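# Example (illustrative): typical usage alongside an optimizer step, as in
# train_one_epoch:
#   ema_model = ExponentialMovingAverage(model, decay=0.9998)
#   ...
#   optimizer.step()
#   ema_model.update_parameters(model)  # ema = d * ema + (1 - d) * param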
172 |
173 | def accuracy(output, target, topk=(1,)):
174 | """Computes the accuracy over the k top predictions for the specified values of k"""
175 | with torch.inference_mode():
176 | maxk = max(topk)
177 | batch_size = target.size(0)
178 | if target.ndim == 2:
179 | target = target.max(dim=1)[1]
180 |
181 | _, pred = output.topk(maxk, 1, True, True)
182 | pred = pred.t()
183 | correct = pred.eq(target[None])
184 |
185 | res = []
186 | for k in topk:
187 | correct_k = correct[:k].flatten().sum(dtype=torch.float32)
188 | res.append(correct_k * (100.0 / batch_size))
189 | return res
190 |
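# Example (illustrative): for logits of shape (B, num_classes) and integer
# targets, the returned values are percentages (0-dim tensors):
#   out = torch.randn(8, 10)
#   tgt = torch.randint(0, 10, (8,))
#   top1, top5 = accuracy(out, tgt, topk=(1, 5))  # e.g. tensor(12.5), tensor(50.)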
191 |
192 | def mkdir(path):
193 | try:
194 | os.makedirs(path)
195 | except OSError as e:
196 | if e.errno != errno.EEXIST:
197 | raise
198 |
199 |
200 | def setup_for_distributed(is_master):
201 | """
202 | This function disables printing when not in master process
203 | """
204 | import builtins as __builtin__
205 |
206 | builtin_print = __builtin__.print
207 |
208 | def print(*args, **kwargs):
209 | force = kwargs.pop("force", False)
210 | if is_master or force:
211 | builtin_print(*args, **kwargs)
212 |
213 | __builtin__.print = print
214 |
215 |
216 | def is_dist_avail_and_initialized():
217 | if not dist.is_available():
218 | return False
219 | if not dist.is_initialized():
220 | return False
221 | return True
222 |
223 |
224 | def get_world_size():
225 | if not is_dist_avail_and_initialized():
226 | return 1
227 | return dist.get_world_size()
228 |
229 |
230 | def get_rank():
231 | if not is_dist_avail_and_initialized():
232 | return 0
233 | return dist.get_rank()
234 |
235 |
236 | def is_main_process(distributed = False):
237 | return get_rank() == 0 or not distributed
238 |
239 |
240 | def save_on_master(*args, **kwargs):
241 | if is_main_process():
242 | torch.save(*args, **kwargs)
243 |
244 |
245 | def init_distributed_mode(args):
246 | if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
247 | args.rank = int(os.environ["RANK"])
248 | args.world_size = int(os.environ["WORLD_SIZE"])
249 | args.gpu = int(os.environ["LOCAL_RANK"])
250 | # elif "SLURM_PROCID" in os.environ:
251 | # args.rank = int(os.environ["SLURM_PROCID"])
252 | # args.gpu = args.rank % torch.cuda.device_count()
253 | elif hasattr(args, "rank"):
254 | pass
255 | else:
256 | print("Not using distributed mode")
257 | args.distributed = False
258 | return
259 |
260 | args.distributed = True
261 |
262 | torch.cuda.set_device(args.gpu)
263 | args.dist_backend = "nccl"
264 | print(f"| distributed init (rank {args.rank}): {args.dist_url}", flush=True)
265 | torch.distributed.init_process_group(
266 | backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
267 | )
268 | torch.distributed.barrier()
269 | setup_for_distributed(args.rank == 0)
270 |
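    | # `init_distributed_mode` reads the RANK, WORLD_SIZE and LOCAL_RANK variables
    | # that `torchrun` exports, and uses `args.dist_url` (commonly "env://") as the
    | # rendezvous method. Launch sketch (script name is hypothetical):
    | #
    | #   torchrun --nproc_per_node=4 my_training_script.py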
271 |
272 | def average_checkpoints(inputs):
273 | """Loads checkpoints from inputs and returns a model with averaged weights. Original implementation taken from:
274 | https://github.com/pytorch/fairseq/blob/a48f235636557b8d3bc4922a6fa90f3a0fa57955/scripts/average_checkpoints.py#L16
275 |
276 | Args:
277 | inputs (List[str]): An iterable of string paths of checkpoints to load from.
278 | Returns:
279 | A dict of string keys mapping to various values. The 'model' key
280 | from the returned dict should correspond to an OrderedDict mapping
281 | string parameter names to torch Tensors.
282 | """
283 | params_dict = OrderedDict()
284 | params_keys = None
285 | new_state = None
286 | num_models = len(inputs)
287 | for fpath in inputs:
288 | with open(fpath, "rb") as f:
289 | state = torch.load(
290 | f,
291 | map_location=(lambda s, _: torch.serialization.default_restore_location(s, "cpu")),
292 | )
293 | # Copies over the settings from the first checkpoint
294 | if new_state is None:
295 | new_state = state
296 | model_params = state["model"]
297 | model_params_keys = list(model_params.keys())
298 | if params_keys is None:
299 | params_keys = model_params_keys
300 | elif params_keys != model_params_keys:
301 | raise KeyError(
302 |                 f"For checkpoint {fpath}, expected list of params: {params_keys}, but found: {model_params_keys}"
303 | )
304 | for k in params_keys:
305 | p = model_params[k]
306 | if isinstance(p, torch.HalfTensor):
307 | p = p.float()
308 | if k not in params_dict:
309 | params_dict[k] = p.clone()
310 | # NOTE: clone() is needed in case of p is a shared parameter
311 | else:
312 | params_dict[k] += p
313 | averaged_params = OrderedDict()
314 | for k, v in params_dict.items():
315 | averaged_params[k] = v
316 | if averaged_params[k].is_floating_point():
317 | averaged_params[k].div_(num_models)
318 | else:
319 | averaged_params[k] //= num_models
320 | new_state["model"] = averaged_params
321 | return new_state
322 |
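    | # Usage sketch (checkpoint paths are hypothetical): average the last few
    | # checkpoints of a run; every file must store the same parameter keys under
    | # its "model" entry.
    | #
    | #   avg_state = average_checkpoints(["ckpt_98.pth", "ckpt_99.pth", "ckpt_100.pth"])
    | #   model.load_state_dict(avg_state["model"])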
323 |
324 | def store_model_weights(model, checkpoint_path, checkpoint_key="model", strict=True):
325 | """
326 | This method can be used to prepare weights files for new models. It receives as
327 | input a model architecture and a checkpoint from the training script and produces
328 | a file with the weights ready for release.
329 |
330 | Examples:
331 | from torchvision import models as M
332 |
333 | # Classification
334 | model = M.mobilenet_v3_large(weights=None)
335 | print(store_model_weights(model, './class.pth'))
336 |
337 | # Quantized Classification
338 | model = M.quantization.mobilenet_v3_large(weights=None, quantize=False)
339 | model.fuse_model(is_qat=True)
340 | model.qconfig = torch.ao.quantization.get_default_qat_qconfig('qnnpack')
341 | _ = torch.ao.quantization.prepare_qat(model, inplace=True)
342 | print(store_model_weights(model, './qat.pth'))
343 |
344 | # Object Detection
345 | model = M.detection.fasterrcnn_mobilenet_v3_large_fpn(weights=None, weights_backbone=None)
346 | print(store_model_weights(model, './obj.pth'))
347 |
348 | # Segmentation
349 | model = M.segmentation.deeplabv3_mobilenet_v3_large(weights=None, weights_backbone=None, aux_loss=True)
350 | print(store_model_weights(model, './segm.pth', strict=False))
351 |
352 | Args:
353 | model (pytorch.nn.Module): The model on which the weights will be loaded for validation purposes.
354 | checkpoint_path (str): The path of the checkpoint we will load.
355 | checkpoint_key (str, optional): The key of the checkpoint where the model weights are stored.
356 | Default: "model".
357 | strict (bool): whether to strictly enforce that the keys
358 | in :attr:`state_dict` match the keys returned by this module's
359 | :meth:`~torch.nn.Module.state_dict` function. Default: ``True``
360 |
361 | Returns:
362 | output_path (str): The location where the weights are saved.
363 | """
364 | # Store the new model next to the checkpoint_path
365 | checkpoint_path = os.path.abspath(checkpoint_path)
366 | output_dir = os.path.dirname(checkpoint_path)
367 |
368 | # Deep copy to avoid side effects on the model object.
369 | model = copy.deepcopy(model)
370 | checkpoint = torch.load(checkpoint_path, map_location="cpu")
371 |
372 | # Load the weights to the model to validate that everything works
373 | # and remove unnecessary weights (such as auxiliaries, etc.)
374 | if checkpoint_key == "model_ema":
375 | del checkpoint[checkpoint_key]["n_averaged"]
376 | torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(checkpoint[checkpoint_key], "module.")
377 | model.load_state_dict(checkpoint[checkpoint_key], strict=strict)
378 |
379 | tmp_path = os.path.join(output_dir, str(model.__hash__()))
380 | torch.save(model.state_dict(), tmp_path)
381 |
382 | sha256_hash = hashlib.sha256()
383 | with open(tmp_path, "rb") as f:
384 | # Read and update hash string value in blocks of 4K
385 | for byte_block in iter(lambda: f.read(4096), b""):
386 | sha256_hash.update(byte_block)
387 | hh = sha256_hash.hexdigest()
388 |
389 | output_path = os.path.join(output_dir, "weights-" + str(hh[:8]) + ".pth")
390 | os.replace(tmp_path, output_path)
391 |
392 | return output_path
393 |
394 |
395 | def reduce_across_processes(val):
396 | if not is_dist_avail_and_initialized():
397 | # nothing to sync, but we still convert to tensor for consistency with the distributed case.
398 | return torch.tensor(val)
399 |
400 | t = torch.tensor(val, device="cuda")
401 | dist.barrier()
402 | dist.all_reduce(t)
403 | return t
404 |
405 |
406 | def set_weight_decay(
407 | model: torch.nn.Module,
408 | weight_decay: float,
409 | norm_weight_decay: Optional[float] = None,
410 | norm_classes: Optional[List[type]] = None,
411 | custom_keys_weight_decay: Optional[List[Tuple[str, float]]] = None,
412 | ):
413 | if not norm_classes:
414 | norm_classes = [
415 | torch.nn.modules.batchnorm._BatchNorm,
416 | torch.nn.LayerNorm,
417 | torch.nn.GroupNorm,
418 | torch.nn.modules.instancenorm._InstanceNorm,
419 | torch.nn.LocalResponseNorm,
420 | ]
421 | norm_classes = tuple(norm_classes)
422 |
423 | params = {
424 | "other": [],
425 | "norm": [],
426 | }
427 | params_weight_decay = {
428 | "other": weight_decay,
429 | "norm": norm_weight_decay,
430 | }
431 | custom_keys = []
432 | if custom_keys_weight_decay is not None:
433 | for key, weight_decay in custom_keys_weight_decay:
434 | params[key] = []
435 | params_weight_decay[key] = weight_decay
436 | custom_keys.append(key)
437 |
438 | def _add_params(module, prefix=""):
439 | for name, p in module.named_parameters(recurse=False):
440 | if not p.requires_grad:
441 | continue
442 | is_custom_key = False
443 | for key in custom_keys:
444 | target_name = f"{prefix}.{name}" if prefix != "" and "." in key else name
445 | if key == target_name:
446 | params[key].append(p)
447 | is_custom_key = True
448 | break
449 | if not is_custom_key:
450 | if norm_weight_decay is not None and isinstance(module, norm_classes):
451 | params["norm"].append(p)
452 | else:
453 | params["other"].append(p)
454 |
455 | for child_name, child_module in module.named_children():
456 | child_prefix = f"{prefix}.{child_name}" if prefix != "" else child_name
457 | _add_params(child_module, prefix=child_prefix)
458 |
459 | _add_params(model)
460 |
461 | param_groups = []
462 | for key in params:
463 | if len(params[key]) > 0:
464 | param_groups.append({"params": params[key], "weight_decay": params_weight_decay[key]})
465 | return param_groups
466 |
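    | # Usage sketch (hyper-parameter values are hypothetical): exempt normalization
    | # layers from weight decay and pass the resulting groups to the optimizer.
    | #
    | #   param_groups = set_weight_decay(model, weight_decay=1e-4, norm_weight_decay=0.0)
    | #   optimizer = torch.optim.SGD(param_groups, lr=0.1, momentum=0.9)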
--------------------------------------------------------------------------------
/src/train.py:
--------------------------------------------------------------------------------
1 |
2 | import warnings
3 | warnings.filterwarnings('ignore')
4 | import argparse
5 | import time
6 | import os
7 | import logging
8 | from collections import OrderedDict
9 | from contextlib import suppress
10 | from datetime import datetime
11 |
12 | import torch
13 | import torch.nn as nn
14 | import torchvision.utils
15 | from torch.nn.parallel import DistributedDataParallel as NativeDDP
16 |
17 | from timm.data import ImageDataset as Dataset, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset #, create_loader
18 | from timm.models import create_model, resume_checkpoint #, convert_splitbn_model
19 | from timm.utils import *
20 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
21 | from timm.optim import create_optimizer
22 | from timm.scheduler import create_scheduler
23 | from timm.utils import ApexScaler, NativeScaler
24 |
25 | from data import create_loader
26 | from model import pyramid_vig
27 | from model import wignn
28 | from model import wignn_256
29 |
30 | import sys
31 |
32 | from opt import parse_args
33 | import wandb
34 |
35 | import torch.backends.cudnn as cudnn
36 | import numpy as np
37 | import random
38 |
39 | try:
40 | from apex import amp
41 | from apex.parallel import DistributedDataParallel as ApexDDP
42 | from apex.parallel import convert_syncbn_model
43 | has_apex = True
44 | except ImportError:
45 | has_apex = False
46 |
47 | has_native_amp = False
48 | try:
49 | if getattr(torch.cuda.amp, 'autocast') is not None:
50 | has_native_amp = True
51 | except AttributeError:
52 | pass
53 |
54 | torch.backends.cudnn.benchmark = True
55 | _logger = logging.getLogger('train')
56 |
57 |
58 | def main():
59 | setup_default_logging()
60 | args, args_text = parse_args()
61 |
62 | if args.evaluate:
63 | random.seed(42)
64 | torch.manual_seed(42)
65 | cudnn.deterministic = True
66 | cudnn.benchmark = False
67 | np.random.seed(42)
68 | torch.cuda.manual_seed(42)
69 |
70 | os.environ["PYTHONHASHSEED"] = str(42)
71 | torch.cuda.manual_seed_all(42)
72 |
73 | args.prefetcher = not args.no_prefetcher
74 | args.distributed = False
75 | if 'WORLD_SIZE' in os.environ:
76 | args.distributed = int(os.environ['WORLD_SIZE']) > 1
77 | if args.distributed and args.num_gpu > 1:
78 | _logger.warning(
79 |             'Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.')
80 | args.num_gpu = 1
81 |
82 | args.device = 'cuda:0'
83 | args.world_size = 1
84 | args.rank = 0 # global rank
85 | if args.distributed:
86 | args.num_gpu = 1
87 | args.device = 'cuda:%d' % args.local_rank
88 | torch.cuda.set_device(args.local_rank)
89 | args.world_size = int(os.environ['WORLD_SIZE'])
90 | args.rank = int(os.environ['RANK'])
91 | torch.distributed.init_process_group(backend='nccl', init_method=args.init_method, rank=args.rank, world_size=args.world_size)
92 | args.world_size = torch.distributed.get_world_size()
93 | args.rank = torch.distributed.get_rank()
94 | assert args.rank >= 0
95 |
96 | if args.local_rank == 0 and not args.evaluate:
97 | wandb.init(
98 | # Set the project where this run will be logged
99 | project='WACV_WiGNet-inet',
100 | # We pass a run name (otherwise it’ll be randomly assigned, like sunshine-lollypop-10)
101 | name=f'{args.model}_shift{args.use_shift}_k{args.knn}_adapt{args.adapt_knn}',
102 | config=args)
103 |
104 | if args.distributed:
105 | _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
106 | % (args.rank, args.world_size))
107 | else:
108 | _logger.info('Training with a single process on %d GPUs.' % args.num_gpu)
109 |
110 | torch.manual_seed(args.seed + args.rank)
111 |
112 |     if 'wignn' in args.model:
113 |         model = create_model(
114 |             args.model,
115 |             num_classes=args.num_classes,
116 |             drop_path_rate=args.drop_path,
117 |             knn=args.knn,
118 |             use_shifts=args.use_shift,
119 |             adapt_knn=args.adapt_knn
120 |         )
121 | else:
122 | model = create_model(
123 | args.model,
124 | pretrained=args.pretrained,
125 | num_classes=args.num_classes,
126 | drop_rate=args.drop,
127 | drop_connect_rate=args.drop_connect, # DEPRECATED, use drop_path
128 | drop_path_rate=args.drop_path,
129 | drop_block_rate=args.drop_block,
130 | global_pool=args.gp,
131 | bn_tf=args.bn_tf,
132 | bn_momentum=args.bn_momentum,
133 | bn_eps=args.bn_eps,
134 | checkpoint_path=args.initial_checkpoint
135 | )
136 |
137 | ################## pretrain ############
138 | if args.pretrain_path is not None:
139 | print('Loading:', args.pretrain_path)
140 | state_dict = torch.load(args.pretrain_path)
141 | model.load_state_dict(state_dict, strict=False)
142 | print('Pretrain weights loaded.')
143 | ################### flops #################
144 | # print(model)
145 | if hasattr(model, 'default_cfg'):
146 | default_cfg = model.default_cfg
147 | input_size = [1] + list(default_cfg['input_size'])
148 | else:
149 | input_size = [1, 3, 224, 224]
150 | print(f'\n\n!!!!!!! Using input size: {input_size}\n\n')
151 | input = torch.randn(input_size)#.cuda()
152 |
153 | from torchprofile import profile_macs
154 | model.eval()
155 | macs = profile_macs(model, input)
156 | model.train()
157 |     print('model MACs:', macs, 'input_size:', input_size)  # profile_macs returns multiply-accumulates, not FLOPs
158 | ##########################################
159 |
160 | if args.local_rank == 0:
161 | _logger.info('Model %s created, param count: %d' %
162 | (args.model, sum([m.numel() for m in model.parameters()])))
163 |
164 | data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)
165 |
166 | num_aug_splits = 0
167 | if args.aug_splits > 0:
168 | assert args.aug_splits > 1, 'A split of 1 makes no sense'
169 | num_aug_splits = args.aug_splits
170 |
171 | """ if args.split_bn:
172 | assert num_aug_splits > 1 or args.resplit
173 | model = convert_splitbn_model(model, max(num_aug_splits, 2)) """
174 |
175 | use_amp = None
176 | if args.amp:
177 | # for backwards compat, `--amp` arg tries apex before native amp
178 | if has_apex:
179 | args.apex_amp = True
180 | elif has_native_amp:
181 | args.native_amp = True
182 | if args.apex_amp and has_apex:
183 | use_amp = 'apex'
184 | elif args.native_amp and has_native_amp:
185 | use_amp = 'native'
186 | elif args.apex_amp or args.native_amp:
187 |         _logger.warning("Neither APEX nor native Torch AMP is available, using float32. "
188 |                         "Install NVIDIA Apex or upgrade to PyTorch >= 1.6")
189 |
190 | if args.num_gpu > 1:
191 | if use_amp == 'apex':
192 | _logger.warning(
193 | 'Apex AMP does not work well with nn.DataParallel, disabling. Use DDP or Torch AMP.')
194 | use_amp = None
195 | model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
196 | assert not args.channels_last, "Channels last not supported with DP, use DDP."
197 | else:
198 | model.cuda()
199 | if args.channels_last:
200 | model = model.to(memory_format=torch.channels_last)
201 |
202 | optimizer = create_optimizer(args, model)
203 |
204 | amp_autocast = suppress # do nothing
205 | loss_scaler = None
206 | if use_amp == 'apex':
207 | model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
208 | loss_scaler = ApexScaler()
209 | if args.local_rank == 0:
210 | _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
211 | elif use_amp == 'native':
212 | amp_autocast = torch.cuda.amp.autocast
213 | loss_scaler = NativeScaler()
214 | if args.local_rank == 0:
215 | _logger.info('Using native Torch AMP. Training in mixed precision.')
216 | else:
217 | if args.local_rank == 0:
218 | _logger.info('AMP not enabled. Training in float32.')
219 |
220 | # optionally resume from a checkpoint
221 | resume_epoch = None
222 | if args.resume:
223 | resume_epoch = resume_checkpoint(
224 | model, args.resume,
225 | optimizer=None if args.no_resume_opt else optimizer,
226 | loss_scaler=None if args.no_resume_opt else loss_scaler,
227 | log_info=args.local_rank == 0)
228 |
229 | model_ema = None
230 | if args.model_ema:
231 | # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
232 | model_ema = ModelEma(
233 | model,
234 | decay=args.model_ema_decay,
235 | device='cpu' if args.model_ema_force_cpu else '',
236 | resume=args.resume)
237 |
238 | if args.distributed:
239 | if args.sync_bn:
240 | assert not args.split_bn
241 | try:
242 | if has_apex and use_amp != 'native':
243 | # Apex SyncBN preferred unless native amp is activated
244 | model = convert_syncbn_model(model)
245 | else:
246 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
247 | if args.local_rank == 0:
248 | _logger.info(
249 |                         'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
250 |                         'zero-initialized BN layers (enabled by default for ResNets) while sync-bn is enabled.')
251 |             except Exception as e:
252 |                 _logger.error('Failed to enable Synchronized BatchNorm: %s. Install Apex or Torch >= 1.1', e)
253 | if has_apex and use_amp != 'native':
254 | # Apex DDP preferred unless native amp is activated
255 | if args.local_rank == 0:
256 | _logger.info("Using NVIDIA APEX DistributedDataParallel.")
257 | model = ApexDDP(model, delay_allreduce=True)
258 | else:
259 | if args.local_rank == 0:
260 | _logger.info("Using native Torch DistributedDataParallel.")
261 | model = NativeDDP(model, device_ids=[args.local_rank]) # can use device str in Torch >= 1.1
262 | # NOTE: EMA model does not need to be wrapped by DDP
263 |
264 | lr_scheduler, num_epochs = create_scheduler(args, optimizer)
265 | start_epoch = 0
266 | if args.start_epoch is not None:
267 | # a specified start_epoch will always override the resume epoch
268 | start_epoch = args.start_epoch
269 | elif resume_epoch is not None:
270 | start_epoch = resume_epoch
271 | if lr_scheduler is not None and start_epoch > 0:
272 | lr_scheduler.step(start_epoch)
273 |
274 | if args.local_rank == 0:
275 | _logger.info('Scheduled epochs: {}'.format(num_epochs))
276 |
277 | train_dir = os.path.join(args.data, 'train')
278 | if not os.path.exists(train_dir):
279 | _logger.error('Training folder does not exist at: {}'.format(train_dir))
280 | exit(1)
281 | dataset_train = Dataset(train_dir)
282 |
283 | collate_fn = None
284 | mixup_fn = None
285 | mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
286 | if mixup_active:
287 | mixup_args = dict(
288 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
289 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
290 | label_smoothing=args.smoothing, num_classes=args.num_classes)
291 | if args.prefetcher:
292 | assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup)
293 | collate_fn = FastCollateMixup(**mixup_args)
294 | else:
295 | mixup_fn = Mixup(**mixup_args)
296 |
297 | if num_aug_splits > 1:
298 | dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)
299 |
300 | print(f"\n\n!!! input size dataloader: {data_config['input_size']}\n\n")
301 | train_interpolation = args.train_interpolation
302 | if args.no_aug or not train_interpolation:
303 | train_interpolation = data_config['interpolation']
304 | loader_train = create_loader(
305 | dataset_train,
306 | input_size=data_config['input_size'],
307 | batch_size=args.batch_size,
308 | is_training=True,
309 | use_prefetcher=args.prefetcher,
310 | no_aug=args.no_aug,
311 | re_prob=args.reprob,
312 | re_mode=args.remode,
313 | re_count=args.recount,
314 | re_split=args.resplit,
315 | scale=args.scale,
316 | ratio=args.ratio,
317 | hflip=args.hflip,
318 | vflip=args.vflip,
319 | color_jitter=args.color_jitter,
320 | auto_augment=args.aa,
321 | num_aug_splits=num_aug_splits,
322 | interpolation=train_interpolation,
323 | mean=data_config['mean'],
324 | std=data_config['std'],
325 | num_workers=args.workers,
326 | distributed=args.distributed,
327 | collate_fn=collate_fn,
328 | pin_memory=args.pin_mem,
329 | use_multi_epochs_loader=args.use_multi_epochs_loader,
330 | repeated_aug=args.repeated_aug
331 | )
332 |
333 | eval_dir = os.path.join(args.data, 'val')
334 | if not os.path.isdir(eval_dir):
335 | eval_dir = os.path.join(args.data, 'validation')
336 | if not os.path.isdir(eval_dir):
337 | _logger.error('Validation folder does not exist at: {}'.format(eval_dir))
338 | exit(1)
339 | dataset_eval = Dataset(eval_dir)
340 |
341 | loader_eval = create_loader(
342 | dataset_eval,
343 | input_size=data_config['input_size'],
344 | batch_size=args.validation_batch_size_multiplier * args.batch_size,
345 | is_training=False,
346 | use_prefetcher=args.prefetcher,
347 | interpolation=data_config['interpolation'],
348 | mean=data_config['mean'],
349 | std=data_config['std'],
350 | num_workers=args.workers,
351 | distributed=args.distributed,
352 | crop_pct=data_config['crop_pct'],
353 | pin_memory=args.pin_mem,
354 | )
355 |
356 | if args.jsd:
357 | assert num_aug_splits > 1 # JSD only valid with aug splits set
358 | train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda()
359 | elif mixup_active:
360 | # smoothing is handled with mixup target transform
361 | train_loss_fn = SoftTargetCrossEntropy().cuda()
362 | elif args.smoothing:
363 | train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
364 | else:
365 | train_loss_fn = nn.CrossEntropyLoss().cuda()
366 | validate_loss_fn = nn.CrossEntropyLoss().cuda()
367 |
368 | if args.evaluate:
369 | print('Evaluating model..')
370 | if model_ema is not None:
371 | eval_metrics_test = validate(model_ema.ema, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA) ')
372 | else:
373 | eval_metrics_test = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast)
374 |
375 | print(eval_metrics_test)
376 | return
377 |
378 | eval_metric = args.eval_metric
379 | best_metric_no_ema = None
380 | best_epoch_no_ema = None
381 |
382 | best_metric_ema = None
383 | best_epoch_ema = None
384 |
385 | saver_no_ema = None
386 | output_dir_no_ema = ''
387 | saver_ema = None
388 | output_dir_ema = ''
389 | if args.local_rank == 0:
390 | output_base = args.output if args.output else './output'
391 | exp_name = '-'.join([
392 | datetime.now().strftime("%Y%m%d-%H%M%S"),
393 | args.model,
394 | str(data_config['input_size'][-1])
395 | ])
396 | output_dir_no_ema = get_outdir(output_base, 'train', f'{exp_name}_NO_EMA')
397 | decreasing = True if eval_metric == 'loss' else False
398 | saver_no_ema = CheckpointSaver(
399 | model=model, optimizer=optimizer, args=args, model_ema=None, amp_scaler=loss_scaler,
400 | checkpoint_dir=output_dir_no_ema, recovery_dir=output_dir_no_ema, decreasing=decreasing)
401 | with open(os.path.join(output_dir_no_ema, 'args.yaml'), 'w') as f:
402 | f.write(args_text)
403 |
404 | if args.model_ema:
405 | output_dir_ema = get_outdir(output_base, 'train', f'{exp_name}_EMA')
406 | saver_ema = CheckpointSaver(
407 | model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler,
408 | checkpoint_dir=output_dir_ema, recovery_dir=output_dir_ema, decreasing=decreasing)
409 | with open(os.path.join(output_dir_ema, 'args.yaml'), 'w') as f:
410 | f.write(args_text)
411 |
412 |
413 |
414 | # try:
415 | for epoch in range(start_epoch, num_epochs):
416 | if args.distributed:
417 | loader_train.sampler.set_epoch(epoch)
418 |
419 | train_metrics = train_epoch(
420 | epoch, model, loader_train, optimizer, train_loss_fn, args,
421 | lr_scheduler=lr_scheduler, saver=saver_ema if args.model_ema else saver_no_ema, output_dir=output_dir_ema if args.model_ema else output_dir_no_ema,
422 | amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn)
423 |
424 | if args.local_rank == 0:
425 |             wandb.log({
426 |                 'train/loss': train_metrics['loss'],
427 |                 'train/lr': optimizer.param_groups[0]["lr"]
428 |             }, step=epoch)
429 |
430 | if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
431 | if args.local_rank == 0:
432 | _logger.info("Distributing BatchNorm running means and vars")
433 | distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
434 |
435 | eval_metrics_no_ema = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast)
436 |
437 | if args.local_rank == 0:
438 |             wandb.log({
439 |                 'val/loss': eval_metrics_no_ema['loss'],
440 |                 'val/acc@1': eval_metrics_no_ema['top1'],
441 |                 'val/acc@5': eval_metrics_no_ema['top5']
442 |             }, step=epoch)
443 |
444 | if model_ema is not None and not args.model_ema_force_cpu:
445 | if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
446 | distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
447 | eval_metrics_ema = validate(
448 | model_ema.ema, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)')
449 | # eval_metrics = ema_eval_metrics
450 |
451 | if args.local_rank == 0:
452 |                 wandb.log({
453 |                     'val_ema/loss': eval_metrics_ema['loss'],
454 |                     'val_ema/acc@1': eval_metrics_ema['top1'],
455 |                     'val_ema/acc@5': eval_metrics_ema['top5']
456 |                 }, step=epoch)
457 |
458 | if lr_scheduler is not None:
459 | # step LR for next epoch
460 | if args.model_ema:
461 | lr_scheduler.step(epoch + 1, eval_metrics_ema[eval_metric])
462 | else:
463 | lr_scheduler.step(epoch + 1, eval_metrics_no_ema[eval_metric])
464 |
465 | update_summary(
466 | epoch, train_metrics, eval_metrics_no_ema, os.path.join(output_dir_no_ema, 'summary.csv'),
467 | write_header=best_metric_no_ema is None)
468 |
469 | if args.model_ema:
470 | update_summary(
471 | epoch, train_metrics, eval_metrics_ema, os.path.join(output_dir_ema, 'summary.csv'),
472 | write_header=best_metric_ema is None)
473 |
474 | if saver_no_ema is not None:
475 | # save proper checkpoint with eval metric
476 | save_metric = eval_metrics_no_ema[eval_metric]
477 | best_metric_no_ema, best_epoch_no_ema = saver_no_ema.save_checkpoint(epoch, metric=save_metric)
478 |
479 | if saver_ema is not None:
480 | # save proper checkpoint with eval metric
481 | save_metric = eval_metrics_ema[eval_metric]
482 | best_metric_ema, best_epoch_ema = saver_ema.save_checkpoint(epoch, metric=save_metric)
483 |
484 | # except KeyboardInterrupt:
485 | # pass
486 |
487 | if best_metric_no_ema is not None:
488 | _logger.info('*** Best metric (NO EMA): {0} (epoch {1})'.format(best_metric_no_ema, best_epoch_no_ema))
489 |
490 | if args.local_rank == 0:
491 |         wandb.log({
492 |             f'best_no_ema/{eval_metric}': best_metric_no_ema,
493 |             'best_no_ema/epoch': best_epoch_no_ema
494 |         }, step=epoch)
495 |
496 | if best_metric_ema is not None:
497 | _logger.info('*** Best metric (EMA): {0} (epoch {1})'.format(best_metric_ema, best_epoch_ema))
498 |
499 | if args.local_rank == 0:
500 |         wandb.log({
501 |             f'best_ema/{eval_metric}': best_metric_ema,
502 |             'best_ema/epoch': best_epoch_ema
503 |         }, step=epoch)
504 |
505 | wandb.finish()
506 |
507 |
508 | def train_epoch(
509 | epoch, model, loader, optimizer, loss_fn, args,
510 | lr_scheduler=None, saver=None, output_dir='', amp_autocast=suppress,
511 | loss_scaler=None, model_ema=None, mixup_fn=None):
512 |
513 | if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
514 | if args.prefetcher and loader.mixup_enabled:
515 | loader.mixup_enabled = False
516 | elif mixup_fn is not None:
517 | mixup_fn.mixup_enabled = False
518 |
519 | second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
520 | batch_time_m = AverageMeter()
521 | data_time_m = AverageMeter()
522 | losses_m = AverageMeter()
523 |
524 | model.train()
525 |
526 | end = time.time()
527 | last_idx = len(loader) - 1
528 | num_updates = epoch * len(loader)
529 | for batch_idx, (input, target) in enumerate(loader):
530 | # if batch_idx > 2: # TODO remove it
531 | # break
532 | last_batch = batch_idx == last_idx
533 | data_time_m.update(time.time() - end)
534 | if not args.prefetcher:
535 | input, target = input.cuda(), target.cuda()
536 | if mixup_fn is not None:
537 | input, target = mixup_fn(input, target)
538 | if args.channels_last:
539 | input = input.contiguous(memory_format=torch.channels_last)
540 |
541 | with amp_autocast():
542 | output = model(input)
543 | loss = loss_fn(output, target)
544 |
545 | # sys.exit(1)
546 |
547 | if not args.distributed:
548 | losses_m.update(loss.item(), input.size(0))
549 |
550 | optimizer.zero_grad()
551 | if loss_scaler is not None:
552 | loss_scaler(
553 | loss, optimizer, clip_grad=args.clip_grad, parameters=model.parameters(), create_graph=second_order)
554 | else:
555 | loss.backward(create_graph=second_order)
556 | if args.clip_grad is not None:
557 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)
558 | optimizer.step()
559 |
560 | torch.cuda.synchronize()
561 | if model_ema is not None:
562 | model_ema.update(model)
563 | num_updates += 1
564 |
565 | batch_time_m.update(time.time() - end)
566 | if last_batch or batch_idx % args.log_interval == 0:
567 | lrl = [param_group['lr'] for param_group in optimizer.param_groups]
568 | lr = sum(lrl) / len(lrl)
569 |
570 | if args.distributed:
571 | reduced_loss = reduce_tensor(loss.data, args.world_size)
572 | losses_m.update(reduced_loss.item(), input.size(0))
573 |
574 | if args.local_rank == 0:
575 | _logger.info(
576 | 'Train: {} [{:>4d}/{} ({:>3.0f}%)] '
577 | 'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f}) '
578 | 'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s '
579 | '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
580 | 'LR: {lr:.3e} '
581 | 'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
582 | epoch,
583 | batch_idx, len(loader),
584 | 100. * batch_idx / last_idx,
585 | loss=losses_m,
586 | batch_time=batch_time_m,
587 | rate=input.size(0) * args.world_size / batch_time_m.val,
588 | rate_avg=input.size(0) * args.world_size / batch_time_m.avg,
589 | lr=lr,
590 | data_time=data_time_m))
591 |
592 | if args.save_images and output_dir:
593 | torchvision.utils.save_image(
594 | input,
595 | os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
596 | padding=0,
597 | normalize=True)
598 |
599 | if saver is not None and args.recovery_interval and (
600 | last_batch or (batch_idx + 1) % args.recovery_interval == 0):
601 | saver.save_recovery(epoch, batch_idx=batch_idx)
602 |
603 | if lr_scheduler is not None:
604 | lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
605 |
606 | end = time.time()
607 | # end for
608 |
609 | if hasattr(optimizer, 'sync_lookahead'):
610 | optimizer.sync_lookahead()
611 |
612 | return OrderedDict([('loss', losses_m.avg)])
613 |
614 |
615 | def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix=''):
616 | batch_time_m = AverageMeter()
617 | losses_m = AverageMeter()
618 | top1_m = AverageMeter()
619 | top5_m = AverageMeter()
620 |
621 | model.eval()
622 |
623 | end = time.time()
624 | last_idx = len(loader) - 1
625 | with torch.no_grad():
626 | for batch_idx, (input, target) in enumerate(loader):
627 |
628 | # if batch_idx > 20: # TODO remove it
629 | # break
630 | last_batch = batch_idx == last_idx
631 | if not args.prefetcher:
632 | input = input.cuda()
633 | target = target.cuda()
634 | if args.channels_last:
635 | input = input.contiguous(memory_format=torch.channels_last)
636 |
637 | with amp_autocast():
638 | output = model(input)
639 |
640 | if isinstance(output, (tuple, list)):
641 | output = output[0]
642 |
643 | # augmentation reduction
644 | reduce_factor = args.tta
645 | if reduce_factor > 1:
646 | output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
647 | target = target[0:target.size(0):reduce_factor]
648 |
649 | loss = loss_fn(output, target)
650 |
651 |
652 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
653 |
654 | if args.distributed:
655 | reduced_loss = reduce_tensor(loss.data, args.world_size)
656 | acc1 = reduce_tensor(acc1, args.world_size)
657 | acc5 = reduce_tensor(acc5, args.world_size)
658 | else:
659 | reduced_loss = loss.data
660 |
661 | torch.cuda.synchronize()
662 |
663 | losses_m.update(reduced_loss.item(), input.size(0))
664 | top1_m.update(acc1.item(), output.size(0))
665 | top5_m.update(acc5.item(), output.size(0))
666 |
667 | batch_time_m.update(time.time() - end)
668 | end = time.time()
669 | if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):
670 | log_name = 'Test' + log_suffix
671 | _logger.info(
672 | '{0}: [{1:>4d}/{2}] '
673 | 'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
674 | 'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
675 | 'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
676 | 'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
677 | log_name, batch_idx, last_idx, batch_time=batch_time_m,
678 | loss=losses_m, top1=top1_m, top5=top5_m))
679 |
680 | metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])
681 |
682 | return metrics
683 |
684 |
685 | if __name__ == '__main__':
686 | main()
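    | # Launch sketch (dataset path, model name and flags are hypothetical; see
    | # opt.py for the actual argument definitions). Single-process run:
    | #
    | #   python train.py /path/to/imagenet --model wignn --batch-size 128
    | #
    | # Distributed runs expect WORLD_SIZE (and RANK/LOCAL_RANK) in the environment,
    | # e.g. as exported by torchrun or torch.distributed.launch.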
--------------------------------------------------------------------------------