├── AIM ├── modules │ ├── .ipynb_checkpoints │ │ ├── __init__-checkpoint.py │ │ ├── bread-checkpoint.py │ │ ├── burger-checkpoint.py │ │ └── ham-checkpoint.py │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-311.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── bread.cpython-311.pyc │ │ ├── bread.cpython-37.pyc │ │ ├── bread.cpython-38.pyc │ │ ├── burger.cpython-311.pyc │ │ ├── burger.cpython-37.pyc │ │ ├── burger.cpython-38.pyc │ │ ├── ham.cpython-311.pyc │ │ ├── ham.cpython-37.pyc │ │ └── ham.cpython-38.pyc │ ├── bread.py │ ├── burger.py │ └── ham.py ├── settings.py └── sync_bn │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-311.pyc │ ├── __init__.cpython-37.pyc │ └── __init__.cpython-38.pyc │ ├── metric.py │ ├── nn │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-311.pyc │ │ ├── __init__.cpython-37.pyc │ │ └── __init__.cpython-38.pyc │ ├── modules │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-311.pyc │ │ │ ├── __init__.cpython-37.pyc │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── batchnorm.cpython-311.pyc │ │ │ ├── batchnorm.cpython-37.pyc │ │ │ ├── batchnorm.cpython-38.pyc │ │ │ ├── comm.cpython-311.pyc │ │ │ ├── comm.cpython-37.pyc │ │ │ ├── comm.cpython-38.pyc │ │ │ ├── replicate.cpython-311.pyc │ │ │ ├── replicate.cpython-37.pyc │ │ │ └── replicate.cpython-38.pyc │ │ ├── batchnorm.py │ │ ├── comm.py │ │ ├── replicate.py │ │ ├── tests │ │ │ ├── test_numeric_batchnorm.py │ │ │ └── test_sync_batchnorm.py │ │ └── unittest.py │ └── parallel │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── __init__.cpython-311.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── data_parallel.cpython-311.pyc │ │ ├── data_parallel.cpython-37.pyc │ │ └── data_parallel.cpython-38.pyc │ │ └── data_parallel.py │ └── utils │ ├── __init__.py │ ├── data │ ├── __init__.py │ ├── dataloader.py │ ├── dataset.py │ ├── distributed.py │ └── sampler.py │ └── th.py ├── README.md ├── bpe_simple_vocab_16e6.txt.gz ├── datatest.py ├── datatrain.py ├── docs ├── README.md └── WSMA-appendix.pdf ├── images ├── README.md ├── pipelline.png └── pipelline1.pdf ├── models ├── dino │ ├── __pycache__ │ │ ├── utils.cpython-311.pyc │ │ ├── utils.cpython-37.pyc │ │ ├── vision_transformer.cpython-311.pyc │ │ └── vision_transformer.cpython-37.pyc │ ├── utils.py │ └── vision_transformer.py └── model.py ├── preprocessing.py ├── save_models └── README.md ├── simple_tokenizer.py ├── test.py ├── train.py └── utils ├── accuracy.py ├── evaluation.py ├── gtransforms.py ├── transform.py └── util.py /AIM/modules/.ipynb_checkpoints/__init__-checkpoint.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | 4 | -------------------------------------------------------------------------------- /AIM/modules/.ipynb_checkpoints/bread-checkpoint.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Hamburger for Pytorch 4 | 5 | @author: Gsunshine 6 | """ 7 | 8 | from functools import partial 9 | 10 | import numpy as np 11 | import AIM.settings as settings 12 | import torch 13 | from AIM.sync_bn.nn.modules import SynchronizedBatchNorm2d 14 | from torch import nn 15 | from torch.nn import functional as F 16 | from torch.nn.modules.batchnorm import _BatchNorm 17 | 18 | norm_layer = partial(SynchronizedBatchNorm2d, momentum=settings.BN_MOM) 19 | 20 | 21 | class ConvBNReLU(nn.Module): 22 | @classmethod 23 | def _same_paddings(cls, 
kernel_size): 24 | if kernel_size == 1: 25 | return 0 26 | elif kernel_size == 3: 27 | return 1 28 | 29 | def __init__(self, in_c, out_c, 30 | kernel_size=1, stride=1, padding='same', 31 | dilation=1, groups=1): 32 | super().__init__() 33 | 34 | if padding == 'same': 35 | padding = self._same_paddings(kernel_size) 36 | 37 | self.conv = nn.Conv2d(in_c, out_c, 38 | kernel_size=kernel_size, stride=stride, 39 | padding=padding, dilation=dilation, 40 | groups=groups, 41 | bias=False) 42 | self.bn = norm_layer(out_c) 43 | self.act = nn.ReLU(inplace=True) 44 | 45 | def forward(self, x): 46 | x = self.conv(x) 47 | x = self.bn(x) 48 | x = self.act(x) 49 | 50 | return x 51 | 52 | -------------------------------------------------------------------------------- /AIM/modules/.ipynb_checkpoints/burger-checkpoint.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import numpy as np 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | from torch.nn.modules.batchnorm import _BatchNorm 8 | 9 | from .bread import ConvBNReLU, norm_layer 10 | from .ham import get_hams 11 | 12 | 13 | class HamburgerV1(nn.Module): 14 | def __init__(self, in_c, n=3, D=512, args=None): 15 | super().__init__() 16 | 17 | ham_type = 'NMF' 18 | self.n = n 19 | 20 | D = getattr(args, 'MD_D', D) 21 | 22 | self.lower_bread = nn.Sequential(nn.Conv2d(in_c, D, 1), 23 | nn.ReLU(inplace=True)) 24 | 25 | HAM = get_hams(ham_type) 26 | 27 | self.ham = HAM(args, D=D) 28 | 29 | self.upper_bread = nn.Sequential(nn.Conv2d(D, in_c, 1, bias=False), 30 | norm_layer(in_c)) 31 | self.shortcut = nn.Sequential() 32 | 33 | self._init_weight() 34 | 35 | print('ham', HAM) 36 | 37 | def _init_weight(self): 38 | for m in self.modules(): 39 | if isinstance(m, nn.Conv2d): 40 | N = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 41 | m.weight.data.normal_(0, np.sqrt(2. 
/ N)) 42 | elif isinstance(m, _BatchNorm): 43 | m.weight.data.fill_(1) 44 | if m.bias is not None: 45 | m.bias.data.zero_() 46 | 47 | def forward(self, x): 48 | _, c, h, w = x.size() 49 | shortcut = self.shortcut(x) # 存储一个备份用于后面相加 50 | x = self.lower_bread(x) 51 | x_c = x.size(1) 52 | x = x.view(-1, self.n, x_c, h, w) 53 | x = self.ham(x) 54 | x = x.contiguous().view(-1, x_c, h, w) 55 | x = self.upper_bread(x) 56 | x = F.relu(x + shortcut, inplace=True) 57 | 58 | return x 59 | 60 | def online_update(self, bases): 61 | if hasattr(self.ham, 'online_update'): 62 | self.ham.online_update(bases) 63 | -------------------------------------------------------------------------------- /AIM/modules/.ipynb_checkpoints/ham-checkpoint.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from torch.nn.modules.batchnorm import _BatchNorm 6 | 7 | 8 | class _MatrixDecomposition2DBase(nn.Module): 9 | def __init__(self, args, D): 10 | super().__init__() 11 | 12 | self.spatial = getattr(args, 'SPATIAL', True) 13 | self.S = getattr(args, 'MD_S', 1) 14 | self.D = D 15 | # self.D = getattr(args, 'MD_D', 512) 16 | self.R = getattr(args, 'MD_R', 64) 17 | 18 | self.train_steps = getattr(args, 'TRAIN_STEPS', 6) 19 | self.eval_steps = getattr(args, 'EVAL_STEPS', 7) 20 | 21 | self.inv_t = getattr(args, 'INV_T', 1) 22 | self.eta = getattr(args, 'ETA', 0.9) 23 | 24 | print('spatial', self.spatial) 25 | print('S', self.S) 26 | print('D', self.D) 27 | print('R', self.R) 28 | print('train_steps', self.train_steps) 29 | print('eval_steps', self.eval_steps) 30 | print('inv_t', self.inv_t) 31 | print('eta', self.eta) 32 | 33 | self.bases = self._build_bases(1, self.S, self.D, self.R, cuda=True) 34 | 35 | def _build_bases(self, B, S, D, R, cuda=False): 36 | raise NotImplementedError 37 | 38 | def local_step(self, x, bases, coef): 39 | raise NotImplementedError 40 | 41 | @torch.no_grad() 42 | def local_inference(self, x, bases): 43 | # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R) 44 | coef = torch.bmm(x.transpose(1, 2), bases) 45 | coef = F.softmax(self.inv_t * coef, dim=-1) 46 | 47 | steps = self.train_steps if self.training else self.eval_steps 48 | for _ in range(steps): 49 | bases, coef = self.local_step(x, bases, coef) 50 | 51 | return bases, coef 52 | 53 | def compute_coef(self, x, bases, coef): 54 | raise NotImplementedError 55 | 56 | def forward(self, x, return_bases=False): 57 | B, Num, C, H, W = x.shape 58 | 59 | # (B, C, H, W) -> (B * S, D, N) 60 | 61 | D = C // self.S 62 | N = Num * H * W 63 | x = x.permute(0, 2, 1, 3, 4) # [B,C,Num,H,W] 64 | 65 | x = x.contiguous().view(B * self.S, D, N) #### [B,C,Num*H*W] 66 | bases = self.bases.repeat(B, 1, 1) ### [B*S,D,R] 67 | bases, coef = self.local_inference(x, bases) 68 | # (B * S, N, R) 69 | coef = self.compute_coef(x, bases, coef) 70 | 71 | # (B * S, D, R) @ (B * S, N, R)^T -> (B * S, D, N) 72 | x = torch.bmm(bases, coef.transpose(1, 2)) 73 | 74 | # (B * S, D, N) -> (B, C, H, W) 75 | x = x.view(B, C, Num, H, W) 76 | x = x.permute(0, 2, 1, 3, 4) 77 | # (B * H, D, R) -> (B, H, N, D) 78 | bases = bases.view(B, self.S, D, self.R) 79 | self.online_update(bases) 80 | return x 81 | 82 | @torch.no_grad() 83 | def online_update(self, bases): 84 | # (B, S, D, R) -> (S, D, R) 85 | bases = bases 86 | update = bases.mean(dim=0) 87 | self.bases = self.bases 88 | self.bases += self.eta * (update - self.bases) 89 | self.bases = 
F.normalize(self.bases, dim=1) 90 | 91 | 92 | class NMF2D(_MatrixDecomposition2DBase): 93 | def __init__(self, args, D): 94 | super().__init__(args, D) 95 | self.inv_t = 1 96 | self.D = D 97 | 98 | def _build_bases(self, B, S, D, R, cuda=True): 99 | if cuda: 100 | bases = torch.rand((B * S, D, R)).cuda() 101 | 102 | else: 103 | bases = torch.rand((B * S, D, R)) 104 | 105 | bases = F.normalize(bases, dim=1) 106 | 107 | return bases 108 | 109 | @torch.no_grad() 110 | def local_step(self, x, bases, coef): 111 | # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R) 112 | 113 | numerator = torch.bmm(x.transpose(1, 2), bases) 114 | # (B * S, N, R) @ [(B * S, D, R)^T @ (B * S, D, R)] -> (B * S, N, R) 115 | denominator = coef.bmm(bases.transpose(1, 2).bmm(bases)) 116 | # Multiplicative Update 117 | coef = coef * numerator / (denominator + 1e-6) 118 | 119 | # (B * S, D, N) @ (B * S, N, R) -> (B * S, D, R) 120 | numerator = torch.bmm(x, coef) 121 | # (B * S, D, R) @ [(B * S, N, R)^T @ (B * S, N, R)] -> (B * S, D, R) 122 | denominator = bases.bmm(coef.transpose(1, 2).bmm(coef)) 123 | # Multiplicative Update 124 | bases = bases * numerator / (denominator + 1e-6) 125 | 126 | return bases, coef 127 | 128 | def compute_coef(self, x, bases, coef): 129 | # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R) 130 | numerator = torch.bmm(x.transpose(1, 2), bases) 131 | # (B * S, N, R) @ (B * S, D, R)^T @ (B * S, D, R) -> (B * S, N, R) 132 | denominator = coef.bmm(bases.transpose(1, 2).bmm(bases)) 133 | # multiplication update 134 | coef = coef * numerator / (denominator + 1e-6) 135 | 136 | return coef 137 | 138 | 139 | def get_hams(key): 140 | hams = {'NMF': NMF2D} 141 | 142 | assert key in hams 143 | return hams[key] 144 | -------------------------------------------------------------------------------- /AIM/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | 4 | -------------------------------------------------------------------------------- /AIM/modules/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/modules/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /AIM/modules/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/modules/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /AIM/modules/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/modules/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /AIM/modules/__pycache__/bread.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/modules/__pycache__/bread.cpython-311.pyc -------------------------------------------------------------------------------- /AIM/modules/__pycache__/bread.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/modules/__pycache__/bread.cpython-37.pyc -------------------------------------------------------------------------------- /AIM/modules/__pycache__/bread.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/modules/__pycache__/bread.cpython-38.pyc -------------------------------------------------------------------------------- /AIM/modules/__pycache__/burger.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/modules/__pycache__/burger.cpython-311.pyc -------------------------------------------------------------------------------- /AIM/modules/__pycache__/burger.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/modules/__pycache__/burger.cpython-37.pyc -------------------------------------------------------------------------------- /AIM/modules/__pycache__/burger.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/modules/__pycache__/burger.cpython-38.pyc -------------------------------------------------------------------------------- /AIM/modules/__pycache__/ham.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/modules/__pycache__/ham.cpython-311.pyc -------------------------------------------------------------------------------- /AIM/modules/__pycache__/ham.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/modules/__pycache__/ham.cpython-37.pyc -------------------------------------------------------------------------------- /AIM/modules/__pycache__/ham.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/modules/__pycache__/ham.cpython-38.pyc -------------------------------------------------------------------------------- /AIM/modules/bread.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Hamburger for Pytorch 4 | 5 | @author: Gsunshine 6 | """ 7 | 8 | from functools import partial 9 | 10 | import numpy as np 11 | import AIM.settings as settings 12 | import torch 13 | from AIM.sync_bn.nn.modules import SynchronizedBatchNorm2d 14 | from torch import nn 15 | from torch.nn import functional as F 16 | from torch.nn.modules.batchnorm import _BatchNorm 17 | 18 | norm_layer = partial(SynchronizedBatchNorm2d, momentum=settings.BN_MOM) 19 | 20 | 21 | class ConvBNReLU(nn.Module): 22 | @classmethod 23 | def _same_paddings(cls, kernel_size): 24 | if kernel_size == 1: 25 | return 0 26 | elif kernel_size == 3: 27 | return 1 28 | 29 | def __init__(self, in_c, out_c, 30 | 
kernel_size=1, stride=1, padding='same', 31 | dilation=1, groups=1): 32 | super().__init__() 33 | 34 | if padding == 'same': 35 | padding = self._same_paddings(kernel_size) 36 | 37 | self.conv = nn.Conv2d(in_c, out_c, 38 | kernel_size=kernel_size, stride=stride, 39 | padding=padding, dilation=dilation, 40 | groups=groups, 41 | bias=False) 42 | self.bn = norm_layer(out_c) 43 | self.act = nn.ReLU(inplace=True) 44 | 45 | def forward(self, x): 46 | x = self.conv(x) 47 | x = self.bn(x) 48 | x = self.act(x) 49 | 50 | return x 51 | 52 | -------------------------------------------------------------------------------- /AIM/modules/burger.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import numpy as np 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | from torch.nn.modules.batchnorm import _BatchNorm 8 | 9 | from .bread import ConvBNReLU, norm_layer 10 | from .ham import get_hams 11 | 12 | 13 | class HamburgerV1(nn.Module): 14 | def __init__(self, in_c, n=3, D=512, args=None): 15 | super().__init__() 16 | 17 | ham_type = 'NMF' 18 | self.n = n 19 | 20 | D = getattr(args, 'MD_D', D) 21 | 22 | self.lower_bread = nn.Sequential(nn.Conv2d(in_c, D, 1), 23 | nn.ReLU(inplace=True)) 24 | 25 | HAM = get_hams(ham_type) 26 | 27 | self.ham = HAM(args, D=D) 28 | 29 | self.upper_bread = nn.Sequential(nn.Conv2d(D, in_c, 1, bias=False), 30 | norm_layer(in_c)) 31 | self.shortcut = nn.Sequential() 32 | 33 | self._init_weight() 34 | 35 | print('ham', HAM) 36 | 37 | def _init_weight(self): 38 | for m in self.modules(): 39 | if isinstance(m, nn.Conv2d): 40 | N = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 41 | m.weight.data.normal_(0, np.sqrt(2. / N)) 42 | elif isinstance(m, _BatchNorm): 43 | m.weight.data.fill_(1) 44 | if m.bias is not None: 45 | m.bias.data.zero_() 46 | 47 | def forward(self, x): 48 | _, c, h, w = x.size() 49 | shortcut = self.shortcut(x) # 存储一个备份用于后面相加 50 | x = self.lower_bread(x) 51 | x_c = x.size(1) 52 | x = x.view(-1, self.n, x_c, h, w) 53 | x = self.ham(x) 54 | x = x.contiguous().view(-1, x_c, h, w) 55 | x = self.upper_bread(x) 56 | x = F.relu(x + shortcut, inplace=True) 57 | 58 | return x 59 | 60 | def online_update(self, bases): 61 | if hasattr(self.ham, 'online_update'): 62 | self.ham.online_update(bases) 63 | -------------------------------------------------------------------------------- /AIM/modules/ham.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from torch.nn.modules.batchnorm import _BatchNorm 6 | 7 | 8 | class _MatrixDecomposition2DBase(nn.Module): 9 | def __init__(self, args, D): 10 | super().__init__() 11 | 12 | self.spatial = getattr(args, 'SPATIAL', True) 13 | self.S = getattr(args, 'MD_S', 1) 14 | self.D = D 15 | # self.D = getattr(args, 'MD_D', 512) 16 | self.R = getattr(args, 'MD_R', 64) 17 | 18 | self.train_steps = getattr(args, 'TRAIN_STEPS', 6) 19 | self.eval_steps = getattr(args, 'EVAL_STEPS', 7) 20 | 21 | self.inv_t = getattr(args, 'INV_T', 1) 22 | self.eta = getattr(args, 'ETA', 0.9) 23 | 24 | print('spatial', self.spatial) 25 | print('S', self.S) 26 | print('D', self.D) 27 | print('R', self.R) 28 | print('train_steps', self.train_steps) 29 | print('eval_steps', self.eval_steps) 30 | print('inv_t', self.inv_t) 31 | print('eta', self.eta) 32 | 33 | self.bases = self._build_bases(1, self.S, self.D, self.R, 
cuda=True) 34 | 35 | def _build_bases(self, B, S, D, R, cuda=False): 36 | raise NotImplementedError 37 | 38 | def local_step(self, x, bases, coef): 39 | raise NotImplementedError 40 | 41 | @torch.no_grad() 42 | def local_inference(self, x, bases): 43 | # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R) 44 | coef = torch.bmm(x.transpose(1, 2), bases) 45 | coef = F.softmax(self.inv_t * coef, dim=-1) 46 | 47 | steps = self.train_steps if self.training else self.eval_steps 48 | for _ in range(steps): 49 | bases, coef = self.local_step(x, bases, coef) 50 | 51 | return bases, coef 52 | 53 | def compute_coef(self, x, bases, coef): 54 | raise NotImplementedError 55 | 56 | def forward(self, x, return_bases=False): 57 | B, Num, C, H, W = x.shape 58 | 59 | # (B, C, H, W) -> (B * S, D, N) 60 | 61 | D = C // self.S 62 | N = Num * H * W 63 | x = x.permute(0, 2, 1, 3, 4) # [B,C,Num,H,W] 64 | 65 | x = x.contiguous().view(B * self.S, D, N) #### [B,C,Num*H*W] 66 | bases = self.bases.repeat(B, 1, 1) ### [B*S,D,R] 67 | bases, coef = self.local_inference(x, bases) 68 | # (B * S, N, R) 69 | coef = self.compute_coef(x, bases, coef) 70 | 71 | # (B * S, D, R) @ (B * S, N, R)^T -> (B * S, D, N) 72 | x = torch.bmm(bases, coef.transpose(1, 2)) 73 | 74 | # (B * S, D, N) -> (B, C, H, W) 75 | x = x.view(B, C, Num, H, W) 76 | x = x.permute(0, 2, 1, 3, 4) 77 | # (B * H, D, R) -> (B, H, N, D) 78 | bases = bases.view(B, self.S, D, self.R) 79 | self.online_update(bases) 80 | return x 81 | 82 | @torch.no_grad() 83 | def online_update(self, bases): 84 | # (B, S, D, R) -> (S, D, R) 85 | bases = bases 86 | update = bases.mean(dim=0) 87 | self.bases = self.bases 88 | self.bases += self.eta * (update - self.bases) 89 | self.bases = F.normalize(self.bases, dim=1) 90 | 91 | 92 | class NMF2D(_MatrixDecomposition2DBase): 93 | def __init__(self, args, D): 94 | super().__init__(args, D) 95 | self.inv_t = 1 96 | self.D = D 97 | 98 | def _build_bases(self, B, S, D, R, cuda=True): 99 | if cuda: 100 | bases = torch.rand((B * S, D, R)).cuda() 101 | 102 | else: 103 | bases = torch.rand((B * S, D, R)) 104 | 105 | bases = F.normalize(bases, dim=1) 106 | 107 | return bases 108 | 109 | @torch.no_grad() 110 | def local_step(self, x, bases, coef): 111 | # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R) 112 | 113 | numerator = torch.bmm(x.transpose(1, 2), bases) 114 | # (B * S, N, R) @ [(B * S, D, R)^T @ (B * S, D, R)] -> (B * S, N, R) 115 | denominator = coef.bmm(bases.transpose(1, 2).bmm(bases)) 116 | # Multiplicative Update 117 | coef = coef * numerator / (denominator + 1e-6) 118 | 119 | # (B * S, D, N) @ (B * S, N, R) -> (B * S, D, R) 120 | numerator = torch.bmm(x, coef) 121 | # (B * S, D, R) @ [(B * S, N, R)^T @ (B * S, N, R)] -> (B * S, D, R) 122 | denominator = bases.bmm(coef.transpose(1, 2).bmm(coef)) 123 | # Multiplicative Update 124 | bases = bases * numerator / (denominator + 1e-6) 125 | 126 | return bases, coef 127 | 128 | def compute_coef(self, x, bases, coef): 129 | # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R) 130 | numerator = torch.bmm(x.transpose(1, 2), bases) 131 | # (B * S, N, R) @ (B * S, D, R)^T @ (B * S, D, R) -> (B * S, N, R) 132 | denominator = coef.bmm(bases.transpose(1, 2).bmm(bases)) 133 | # multiplication update 134 | coef = coef * numerator / (denominator + 1e-6) 135 | 136 | return coef 137 | 138 | 139 | def get_hams(key): 140 | hams = {'NMF': NMF2D} 141 | 142 | assert key in hams 143 | return hams[key] 144 | -------------------------------------------------------------------------------- /AIM/settings.py: 
-------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import numpy as np 4 | import torch 5 | 6 | 7 | 8 | # Data settings 9 | MEAN = torch.Tensor(np.array([0.485, 0.456, 0.406])) 10 | STD = torch.Tensor(np.array([0.229, 0.224, 0.225])) 11 | SCALES = (0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0) 12 | TEST_SCALES = (0.5, 0.75, 1.0, 1.25, 1.5, 1.75) 13 | CROP_SIZE = 513 14 | IGNORE_LABEL = 255 15 | NUM_WORKERS = 64 16 | 17 | 18 | # Training settings 19 | RUN_FOR_TEST = True 20 | 21 | TRAIN_BATCH_SIZE = 16 22 | VAL_BATCH_SIZE = 1 23 | 24 | ITER_MAX = 60000 25 | ITER_SAVE = 1000 26 | ITER_VAL = 1000 27 | 28 | TEST_ITER_MAX = 20000 29 | TEST_ITER_SAVE = 1000 30 | TEST_ITER_VAL = 1000 31 | 32 | LR_DECAY = 10 33 | LR = 9e-3 34 | LR_MOM = 0.9 35 | POLY_POWER = 0.9 36 | WEIGHT_DECAY = 1e-4 37 | 38 | 39 | # Network 40 | N_CLASSES = 21 41 | N_LAYERS = 101 42 | STRIDE = 8 43 | BN_MOM = 3e-4 44 | 45 | 46 | # Hamburger 47 | CHANNELS = 512 48 | VERSION = 'V2' 49 | 50 | HAM_TYPE = 'NMF' 51 | DUAL = False 52 | SPATIAL = True 53 | RAND_INIT = True 54 | 55 | MD_S = 1 56 | MD_D = 512 57 | MD_R = 64 58 | 59 | CHEESE_FACTOR = 1 60 | TRAIN_STEPS = 6 61 | EVAL_STEPS = 6 62 | 63 | INV_T = 1 64 | BETA = 0.1 65 | ETA = 0.9 66 | 67 | 68 | # Tensorboard 69 | logger = logging.getLogger('train') 70 | logger.setLevel(logging.INFO) 71 | ch = logging.StreamHandler() 72 | ch.setLevel(logging.INFO) 73 | formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') 74 | ch.setFormatter(formatter) 75 | logger.addHandler(ch) 76 | -------------------------------------------------------------------------------- /AIM/sync_bn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/sync_bn/__init__.py -------------------------------------------------------------------------------- /AIM/sync_bn/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/sync_bn/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /AIM/sync_bn/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/sync_bn/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /AIM/sync_bn/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/sync_bn/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /AIM/sync_bn/metric.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import settings 4 | 5 | 6 | def fast_hist(label_true, label_pred): 7 | n_class = settings.N_CLASSES 8 | mask = (label_true >= 0) & (label_true < n_class) 9 | hist = torch.bincount( 10 | n_class * label_true[mask].int() + label_pred[mask].int(), 11 | minlength=n_class ** 2, 12 | ).reshape(n_class, n_class) 13 | return hist 14 | 15 | 16 | label_names = [ 17 | 'background', 18 | 
'aeroplane', 19 | 'bicycle', 20 | 'bird', 21 | 'boat', 22 | 'bottle', 23 | 'bus', 24 | 'car', 25 | 'cat', 26 | 'chair', 27 | 'cow', 28 | 'diningtable', 29 | 'dog', 30 | 'horse', 31 | 'motorbike', 32 | 'person', 33 | 'pottedplant', 34 | 'sheep', 35 | 'sofa', 36 | 'train', 37 | 'tvmonitor', 38 | ] 39 | 40 | 41 | def cal_scores(hist): 42 | n_class = settings.N_CLASSES 43 | #acc = np.diag(hist).sum() / hist.sum() 44 | #acc_cls = np.diag(hist) / hist.sum(axis=1) 45 | #acc_cls = np.nanmean(acc_cls) 46 | iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist)) 47 | mean_iu = np.nanmean(iu) 48 | freq = hist.sum(axis=1) / hist.sum() 49 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum() 50 | cls_iu = dict(zip(label_names, iu)) 51 | 52 | return { 53 | #"pAcc": acc, 54 | #"mAcc": acc_cls, 55 | "fIoU": fwavacc, 56 | "mIoU": mean_iu, 57 | }#, cls_iu 58 | 59 | -------------------------------------------------------------------------------- /AIM/sync_bn/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from .modules import * 2 | from .parallel import UserScatteredDataParallel, user_scattered_collate, async_copy_to 3 | -------------------------------------------------------------------------------- /AIM/sync_bn/nn/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/sync_bn/nn/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /AIM/sync_bn/nn/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/sync_bn/nn/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /AIM/sync_bn/nn/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/sync_bn/nn/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /AIM/sync_bn/nn/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
10 | 11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 12 | from .replicate import DataParallelWithCallback, patch_replication_callback 13 | -------------------------------------------------------------------------------- /AIM/sync_bn/nn/modules/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/sync_bn/nn/modules/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /AIM/sync_bn/nn/modules/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/sync_bn/nn/modules/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /AIM/sync_bn/nn/modules/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/sync_bn/nn/modules/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /AIM/sync_bn/nn/modules/__pycache__/batchnorm.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/sync_bn/nn/modules/__pycache__/batchnorm.cpython-311.pyc -------------------------------------------------------------------------------- /AIM/sync_bn/nn/modules/__pycache__/batchnorm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/sync_bn/nn/modules/__pycache__/batchnorm.cpython-37.pyc -------------------------------------------------------------------------------- /AIM/sync_bn/nn/modules/__pycache__/batchnorm.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/sync_bn/nn/modules/__pycache__/batchnorm.cpython-38.pyc -------------------------------------------------------------------------------- /AIM/sync_bn/nn/modules/__pycache__/comm.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/sync_bn/nn/modules/__pycache__/comm.cpython-311.pyc -------------------------------------------------------------------------------- /AIM/sync_bn/nn/modules/__pycache__/comm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/sync_bn/nn/modules/__pycache__/comm.cpython-37.pyc -------------------------------------------------------------------------------- /AIM/sync_bn/nn/modules/__pycache__/comm.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/sync_bn/nn/modules/__pycache__/comm.cpython-38.pyc -------------------------------------------------------------------------------- /AIM/sync_bn/nn/modules/__pycache__/replicate.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/sync_bn/nn/modules/__pycache__/replicate.cpython-311.pyc -------------------------------------------------------------------------------- /AIM/sync_bn/nn/modules/__pycache__/replicate.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/sync_bn/nn/modules/__pycache__/replicate.cpython-37.pyc -------------------------------------------------------------------------------- /AIM/sync_bn/nn/modules/__pycache__/replicate.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/sync_bn/nn/modules/__pycache__/replicate.cpython-38.pyc -------------------------------------------------------------------------------- /AIM/sync_bn/nn/modules/batchnorm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : batchnorm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import collections 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | 16 | from torch.nn.modules.batchnorm import _BatchNorm 17 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast 18 | 19 | from .comm import SyncMaster 20 | 21 | __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d'] 22 | 23 | 24 | def _sum_ft(tensor): 25 | """sum over the first and last dimention""" 26 | return tensor.sum(dim=0).sum(dim=-1) 27 | 28 | 29 | def _unsqueeze_ft(tensor): 30 | """add new dementions at the front and the tail""" 31 | return tensor.unsqueeze(0).unsqueeze(-1) 32 | 33 | 34 | _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) 35 | _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) 36 | 37 | 38 | class _SynchronizedBatchNorm(_BatchNorm): 39 | def __init__(self, num_features, eps=1e-5, momentum=3e-4, affine=True): 40 | super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine) 41 | 42 | self._sync_master = SyncMaster(self._data_parallel_master) 43 | 44 | self._is_parallel = False 45 | self._parallel_id = None 46 | self._slave_pipe = None 47 | 48 | # customed batch norm statistics 49 | self.momentum = momentum 50 | 51 | def forward(self, input, weight=None, bias=None): 52 | # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. 53 | if not (self._is_parallel and self.training): 54 | return F.batch_norm( 55 | input, self.running_mean, self.running_var, self.weight, self.bias, 56 | self.training, self.momentum, self.eps) 57 | 58 | # Resize the input to (B, C, -1). 
59 | input_shape = input.size() 60 | input = input.view(input.size(0), self.num_features, -1) 61 | 62 | # Compute the sum and square-sum. 63 | sum_size = input.size(0) * input.size(2) 64 | input_sum = _sum_ft(input) 65 | input_ssum = _sum_ft(input ** 2) 66 | 67 | # Reduce-and-broadcast the statistics. 68 | if self._parallel_id == 0: 69 | mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) 70 | else: 71 | mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) 72 | 73 | # Compute the output. 74 | if self.affine: 75 | if weight is None or bias is None: 76 | weight = self.weight 77 | bias = self.bias 78 | 79 | # MJY:: Fuse the multiplication for speed. 80 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * weight) + _unsqueeze_ft(bias) 81 | else: 82 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) 83 | 84 | # Reshape it. 85 | return output.view(input_shape) 86 | 87 | def __data_parallel_replicate__(self, ctx, copy_id): 88 | self._is_parallel = True 89 | self._parallel_id = copy_id 90 | 91 | # parallel_id == 0 means master device. 92 | if self._parallel_id == 0: 93 | ctx.sync_master = self._sync_master 94 | else: 95 | self._slave_pipe = ctx.sync_master.register_slave(copy_id) 96 | 97 | def _data_parallel_master(self, intermediates): 98 | """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" 99 | intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) 100 | 101 | to_reduce = [i[1][:2] for i in intermediates] 102 | to_reduce = [j for i in to_reduce for j in i] # flatten 103 | target_gpus = [i[1].sum.get_device() for i in intermediates] 104 | 105 | sum_size = sum([i[1].sum_size for i in intermediates]) 106 | sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) 107 | 108 | mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) 109 | 110 | broadcasted = Broadcast.apply(target_gpus, mean, inv_std) 111 | 112 | outputs = [] 113 | for i, rec in enumerate(intermediates): 114 | outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2]))) 115 | 116 | return outputs 117 | 118 | def _add_weighted(self, dest, delta, alpha=1, beta=1, bias=0): 119 | """return *dest* by `dest := dest*alpha + delta*beta + bias`""" 120 | return dest * alpha + delta * beta + bias 121 | 122 | def _compute_mean_std(self, sum_, ssum, size): 123 | """Compute the mean and standard-deviation with sum and square-sum. This method 124 | also maintains the moving average on the master device.""" 125 | assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' 126 | mean = sum_ / size 127 | sumvar = ssum - sum_ * mean 128 | unbias_var = sumvar / (size - 1) 129 | bias_var = sumvar / size 130 | 131 | self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data 132 | self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data 133 | 134 | return mean, (bias_var + self.eps) ** -0.5 135 | 136 | 137 | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): 138 | r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a 139 | mini-batch. 140 | 141 | .. math:: 142 | 143 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 144 | 145 | This module differs from the built-in PyTorch BatchNorm1d as the mean and 146 | standard-deviation are reduced across all devices during training. 
147 | 148 | For example, when one uses `nn.DataParallel` to wrap the network during 149 | training, PyTorch's implementation normalize the tensor on each device using 150 | the statistics only on that device, which accelerated the computation and 151 | is also easy to implement, but the statistics might be inaccurate. 152 | Instead, in this synchronized version, the statistics will be computed 153 | over all training samples distributed on multiple devices. 154 | 155 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 156 | as the built-in PyTorch implementation. 157 | 158 | The mean and standard-deviation are calculated per-dimension over 159 | the mini-batches and gamma and beta are learnable parameter vectors 160 | of size C (where C is the input size). 161 | 162 | During training, this layer keeps a running estimate of its computed mean 163 | and variance. The running sum is kept with a default momentum of 0.1. 164 | 165 | During evaluation, this running mean/variance is used for normalization. 166 | 167 | Because the BatchNorm is done over the `C` dimension, computing statistics 168 | on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm 169 | 170 | Args: 171 | num_features: num_features from an expected input of size 172 | `batch_size x num_features [x width]` 173 | eps: a value added to the denominator for numerical stability. 174 | Default: 1e-5 175 | momentum: the value used for the running_mean and running_var 176 | computation. Default: 0.1 177 | affine: a boolean value that when set to ``True``, gives the layer learnable 178 | affine parameters. Default: ``True`` 179 | 180 | Shape: 181 | - Input: :math:`(N, C)` or :math:`(N, C, L)` 182 | - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) 183 | 184 | Examples: 185 | >>> # With Learnable Parameters 186 | >>> m = SynchronizedBatchNorm1d(100) 187 | >>> # Without Learnable Parameters 188 | >>> m = SynchronizedBatchNorm1d(100, affine=False) 189 | >>> input = torch.autograd.Variable(torch.randn(20, 100)) 190 | >>> output = m(input) 191 | """ 192 | 193 | def _check_input_dim(self, input): 194 | if input.dim() != 2 and input.dim() != 3: 195 | raise ValueError('expected 2D or 3D input (got {}D input)' 196 | .format(input.dim())) 197 | super(SynchronizedBatchNorm1d, self)._check_input_dim(input) 198 | 199 | 200 | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): 201 | r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch 202 | of 3d inputs 203 | 204 | .. math:: 205 | 206 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 207 | 208 | This module differs from the built-in PyTorch BatchNorm2d as the mean and 209 | standard-deviation are reduced across all devices during training. 210 | 211 | For example, when one uses `nn.DataParallel` to wrap the network during 212 | training, PyTorch's implementation normalize the tensor on each device using 213 | the statistics only on that device, which accelerated the computation and 214 | is also easy to implement, but the statistics might be inaccurate. 215 | Instead, in this synchronized version, the statistics will be computed 216 | over all training samples distributed on multiple devices. 217 | 218 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 219 | as the built-in PyTorch implementation. 
220 | 221 | The mean and standard-deviation are calculated per-dimension over 222 | the mini-batches and gamma and beta are learnable parameter vectors 223 | of size C (where C is the input size). 224 | 225 | During training, this layer keeps a running estimate of its computed mean 226 | and variance. The running sum is kept with a default momentum of 0.1. 227 | 228 | During evaluation, this running mean/variance is used for normalization. 229 | 230 | Because the BatchNorm is done over the `C` dimension, computing statistics 231 | on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm 232 | 233 | Args: 234 | num_features: num_features from an expected input of 235 | size batch_size x num_features x height x width 236 | eps: a value added to the denominator for numerical stability. 237 | Default: 1e-5 238 | momentum: the value used for the running_mean and running_var 239 | computation. Default: 0.1 240 | affine: a boolean value that when set to ``True``, gives the layer learnable 241 | affine parameters. Default: ``True`` 242 | 243 | Shape: 244 | - Input: :math:`(N, C, H, W)` 245 | - Output: :math:`(N, C, H, W)` (same shape as input) 246 | 247 | Examples: 248 | >>> # With Learnable Parameters 249 | >>> m = SynchronizedBatchNorm2d(100) 250 | >>> # Without Learnable Parameters 251 | >>> m = SynchronizedBatchNorm2d(100, affine=False) 252 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) 253 | >>> output = m(input) 254 | """ 255 | 256 | def _check_input_dim(self, input): 257 | if input.dim() != 4: 258 | raise ValueError('expected 4D input (got {}D input)' 259 | .format(input.dim())) 260 | super(SynchronizedBatchNorm2d, self)._check_input_dim(input) 261 | 262 | 263 | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): 264 | r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch 265 | of 4d inputs 266 | 267 | .. math:: 268 | 269 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 270 | 271 | This module differs from the built-in PyTorch BatchNorm3d as the mean and 272 | standard-deviation are reduced across all devices during training. 273 | 274 | For example, when one uses `nn.DataParallel` to wrap the network during 275 | training, PyTorch's implementation normalize the tensor on each device using 276 | the statistics only on that device, which accelerated the computation and 277 | is also easy to implement, but the statistics might be inaccurate. 278 | Instead, in this synchronized version, the statistics will be computed 279 | over all training samples distributed on multiple devices. 280 | 281 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 282 | as the built-in PyTorch implementation. 283 | 284 | The mean and standard-deviation are calculated per-dimension over 285 | the mini-batches and gamma and beta are learnable parameter vectors 286 | of size C (where C is the input size). 287 | 288 | During training, this layer keeps a running estimate of its computed mean 289 | and variance. The running sum is kept with a default momentum of 0.1. 290 | 291 | During evaluation, this running mean/variance is used for normalization. 
292 | 293 | Because the BatchNorm is done over the `C` dimension, computing statistics 294 | on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm 295 | or Spatio-temporal BatchNorm 296 | 297 | Args: 298 | num_features: num_features from an expected input of 299 | size batch_size x num_features x depth x height x width 300 | eps: a value added to the denominator for numerical stability. 301 | Default: 1e-5 302 | momentum: the value used for the running_mean and running_var 303 | computation. Default: 0.1 304 | affine: a boolean value that when set to ``True``, gives the layer learnable 305 | affine parameters. Default: ``True`` 306 | 307 | Shape: 308 | - Input: :math:`(N, C, D, H, W)` 309 | - Output: :math:`(N, C, D, H, W)` (same shape as input) 310 | 311 | Examples: 312 | >>> # With Learnable Parameters 313 | >>> m = SynchronizedBatchNorm3d(100) 314 | >>> # Without Learnable Parameters 315 | >>> m = SynchronizedBatchNorm3d(100, affine=False) 316 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) 317 | >>> output = m(input) 318 | """ 319 | 320 | def _check_input_dim(self, input): 321 | if input.dim() != 5: 322 | raise ValueError('expected 5D input (got {}D input)' 323 | .format(input.dim())) 324 | super(SynchronizedBatchNorm3d, self)._check_input_dim(input) 325 | -------------------------------------------------------------------------------- /AIM/sync_bn/nn/modules/comm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : comm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import queue 12 | import collections 13 | import threading 14 | 15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 16 | 17 | 18 | class FutureResult(object): 19 | """A thread-safe future implementation. Used only as one-to-one pipe.""" 20 | 21 | def __init__(self): 22 | self._result = None 23 | self._lock = threading.Lock() 24 | self._cond = threading.Condition(self._lock) 25 | 26 | def put(self, result): 27 | with self._lock: 28 | assert self._result is None, 'Previous result has\'t been fetched.' 29 | self._result = result 30 | self._cond.notify() 31 | 32 | def get(self): 33 | with self._lock: 34 | if self._result is None: 35 | self._cond.wait() 36 | 37 | res = self._result 38 | self._result = None 39 | return res 40 | 41 | 42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) 43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) 44 | 45 | 46 | class SlavePipe(_SlavePipeBase): 47 | """Pipe for master-slave communication.""" 48 | 49 | def run_slave(self, msg): 50 | self.queue.put((self.identifier, msg)) 51 | ret = self.result.get() 52 | self.queue.put(True) 53 | return ret 54 | 55 | 56 | class SyncMaster(object): 57 | """An abstract `SyncMaster` object. 58 | 59 | - During the replication, as the data parallel will trigger an callback of each module, all slave devices should 60 | call `register(id)` and obtain an `SlavePipe` to communicate with the master. 61 | - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, 62 | and passed to a registered callback. 
63 | - After receiving the messages, the master device should gather the information and determine to message passed 64 | back to each slave devices. 65 | """ 66 | 67 | def __init__(self, master_callback): 68 | """ 69 | 70 | Args: 71 | master_callback: a callback to be invoked after having collected messages from slave devices. 72 | """ 73 | self._master_callback = master_callback 74 | self._queue = queue.Queue() 75 | self._registry = collections.OrderedDict() 76 | self._activated = False 77 | 78 | def register_slave(self, identifier): 79 | """ 80 | Register an slave device. 81 | 82 | Args: 83 | identifier: an identifier, usually is the device id. 84 | 85 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 86 | 87 | """ 88 | if self._activated: 89 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 90 | self._activated = False 91 | self._registry.clear() 92 | future = FutureResult() 93 | self._registry[identifier] = _MasterRegistry(future) 94 | return SlavePipe(identifier, self._queue, future) 95 | 96 | def run_master(self, master_msg): 97 | """ 98 | Main entry for the master device in each forward pass. 99 | The messages were first collected from each devices (including the master device), and then 100 | an callback will be invoked to compute the message to be sent back to each devices 101 | (including the master device). 102 | 103 | Args: 104 | master_msg: the message that the master want to send to itself. This will be placed as the first 105 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 106 | 107 | Returns: the message to be sent back to the master device. 108 | 109 | """ 110 | self._activated = True 111 | 112 | intermediates = [(0, master_msg)] 113 | for i in range(self.nr_slaves): 114 | intermediates.append(self._queue.get()) 115 | 116 | results = self._master_callback(intermediates) 117 | assert results[0][0] == 0, 'The first result should belongs to the master.' 118 | 119 | for i, res in results: 120 | if i == 0: 121 | continue 122 | self._registry[i].result.put(res) 123 | 124 | for i in range(self.nr_slaves): 125 | assert self._queue.get() is True 126 | 127 | return results[0][1] 128 | 129 | @property 130 | def nr_slaves(self): 131 | return len(self._registry) 132 | -------------------------------------------------------------------------------- /AIM/sync_bn/nn/modules/replicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import functools 12 | 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | __all__ = [ 16 | 'CallbackContext', 17 | 'execute_replication_callbacks', 18 | 'DataParallelWithCallback', 19 | 'patch_replication_callback' 20 | ] 21 | 22 | 23 | class CallbackContext(object): 24 | pass 25 | 26 | 27 | def execute_replication_callbacks(modules): 28 | """ 29 | Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. 
30 | 31 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 32 | 33 | Note that, as all modules are isomorphism, we assign each sub-module with a context 34 | (shared among multiple copies of this module on different devices). 35 | Through this context, different copies can share some information. 36 | 37 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback 38 | of any slave copies. 39 | """ 40 | master_copy = modules[0] 41 | nr_modules = len(list(master_copy.modules())) 42 | ctxs = [CallbackContext() for _ in range(nr_modules)] 43 | 44 | for i, module in enumerate(modules): 45 | for j, m in enumerate(module.modules()): 46 | if hasattr(m, '__data_parallel_replicate__'): 47 | m.__data_parallel_replicate__(ctxs[j], i) 48 | 49 | 50 | class DataParallelWithCallback(DataParallel): 51 | """ 52 | Data Parallel with a replication callback. 53 | 54 | An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by 55 | original `replicate` function. 56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 57 | 58 | Examples: 59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 61 | # sync_bn.__data_parallel_replicate__ will be invoked. 62 | """ 63 | 64 | def replicate(self, module, device_ids): 65 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 66 | execute_replication_callbacks(modules) 67 | return modules 68 | 69 | 70 | def patch_replication_callback(data_parallel): 71 | """ 72 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 73 | Useful when you have customized `DataParallel` implementation. 74 | 75 | Examples: 76 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 77 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 78 | > patch_replication_callback(sync_bn) 79 | # this is equivalent to 80 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 81 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 82 | """ 83 | 84 | assert isinstance(data_parallel, DataParallel) 85 | 86 | old_replicate = data_parallel.replicate 87 | 88 | @functools.wraps(old_replicate) 89 | def new_replicate(module, device_ids): 90 | modules = old_replicate(module, device_ids) 91 | execute_replication_callbacks(modules) 92 | return modules 93 | 94 | data_parallel.replicate = new_replicate 95 | -------------------------------------------------------------------------------- /AIM/sync_bn/nn/modules/tests/test_numeric_batchnorm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : test_numeric_batchnorm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 
8 | 9 | import unittest 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torch.autograd import Variable 14 | 15 | from sync_batchnorm.unittest import TorchTestCase 16 | 17 | 18 | def handy_var(a, unbias=True): 19 | n = a.size(0) 20 | asum = a.sum(dim=0) 21 | as_sum = (a ** 2).sum(dim=0) # a square sum 22 | sumvar = as_sum - asum * asum / n 23 | if unbias: 24 | return sumvar / (n - 1) 25 | else: 26 | return sumvar / n 27 | 28 | 29 | class NumericTestCase(TorchTestCase): 30 | def testNumericBatchNorm(self): 31 | a = torch.rand(16, 10) 32 | bn = nn.BatchNorm2d(10, momentum=1, eps=1e-5, affine=False) 33 | bn.train() 34 | 35 | a_var1 = Variable(a, requires_grad=True) 36 | b_var1 = bn(a_var1) 37 | loss1 = b_var1.sum() 38 | loss1.backward() 39 | 40 | a_var2 = Variable(a, requires_grad=True) 41 | a_mean2 = a_var2.mean(dim=0, keepdim=True) 42 | a_std2 = torch.sqrt(handy_var(a_var2, unbias=False).clamp(min=1e-5)) 43 | # a_std2 = torch.sqrt(a_var2.var(dim=0, keepdim=True, unbiased=False) + 1e-5) 44 | b_var2 = (a_var2 - a_mean2) / a_std2 45 | loss2 = b_var2.sum() 46 | loss2.backward() 47 | 48 | self.assertTensorClose(bn.running_mean, a.mean(dim=0)) 49 | self.assertTensorClose(bn.running_var, handy_var(a)) 50 | self.assertTensorClose(a_var1.data, a_var2.data) 51 | self.assertTensorClose(b_var1.data, b_var2.data) 52 | self.assertTensorClose(a_var1.grad, a_var2.grad) 53 | 54 | 55 | if __name__ == '__main__': 56 | unittest.main() 57 | -------------------------------------------------------------------------------- /AIM/sync_bn/nn/modules/tests/test_sync_batchnorm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : test_sync_batchnorm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 
8 | 9 | import unittest 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torch.autograd import Variable 14 | 15 | from sync_batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, DataParallelWithCallback 16 | from sync_batchnorm.unittest import TorchTestCase 17 | 18 | 19 | def handy_var(a, unbias=True): 20 | n = a.size(0) 21 | asum = a.sum(dim=0) 22 | as_sum = (a ** 2).sum(dim=0) # a square sum 23 | sumvar = as_sum - asum * asum / n 24 | if unbias: 25 | return sumvar / (n - 1) 26 | else: 27 | return sumvar / n 28 | 29 | 30 | def _find_bn(module): 31 | for m in module.modules(): 32 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, SynchronizedBatchNorm1d, SynchronizedBatchNorm2d)): 33 | return m 34 | 35 | 36 | class SyncTestCase(TorchTestCase): 37 | def _syncParameters(self, bn1, bn2): 38 | bn1.reset_parameters() 39 | bn2.reset_parameters() 40 | if bn1.affine and bn2.affine: 41 | bn2.weight.data.copy_(bn1.weight.data) 42 | bn2.bias.data.copy_(bn1.bias.data) 43 | 44 | def _checkBatchNormResult(self, bn1, bn2, input, is_train, cuda=False): 45 | """Check the forward and backward for the customized batch normalization.""" 46 | bn1.train(mode=is_train) 47 | bn2.train(mode=is_train) 48 | 49 | if cuda: 50 | input = input.cuda() 51 | 52 | self._syncParameters(_find_bn(bn1), _find_bn(bn2)) 53 | 54 | input1 = Variable(input, requires_grad=True) 55 | output1 = bn1(input1) 56 | output1.sum().backward() 57 | input2 = Variable(input, requires_grad=True) 58 | output2 = bn2(input2) 59 | output2.sum().backward() 60 | 61 | self.assertTensorClose(input1.data, input2.data) 62 | self.assertTensorClose(output1.data, output2.data) 63 | self.assertTensorClose(input1.grad, input2.grad) 64 | self.assertTensorClose(_find_bn(bn1).running_mean, _find_bn(bn2).running_mean) 65 | self.assertTensorClose(_find_bn(bn1).running_var, _find_bn(bn2).running_var) 66 | 67 | def testSyncBatchNormNormalTrain(self): 68 | bn = nn.BatchNorm1d(10) 69 | sync_bn = SynchronizedBatchNorm1d(10) 70 | 71 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True) 72 | 73 | def testSyncBatchNormNormalEval(self): 74 | bn = nn.BatchNorm1d(10) 75 | sync_bn = SynchronizedBatchNorm1d(10) 76 | 77 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False) 78 | 79 | def testSyncBatchNormSyncTrain(self): 80 | bn = nn.BatchNorm1d(10, eps=1e-5, affine=False) 81 | sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 82 | sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 83 | 84 | bn.cuda() 85 | sync_bn.cuda() 86 | 87 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True, cuda=True) 88 | 89 | def testSyncBatchNormSyncEval(self): 90 | bn = nn.BatchNorm1d(10, eps=1e-5, affine=False) 91 | sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 92 | sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 93 | 94 | bn.cuda() 95 | sync_bn.cuda() 96 | 97 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False, cuda=True) 98 | 99 | def testSyncBatchNorm2DSyncTrain(self): 100 | bn = nn.BatchNorm2d(10) 101 | sync_bn = SynchronizedBatchNorm2d(10) 102 | sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 103 | 104 | bn.cuda() 105 | sync_bn.cuda() 106 | 107 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10, 16, 16), True, cuda=True) 108 | 109 | 110 | if __name__ == '__main__': 111 | unittest.main() 112 | -------------------------------------------------------------------------------- /AIM/sync_bn/nn/modules/unittest.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : unittest.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import unittest 12 | 13 | import numpy as np 14 | from torch.autograd import Variable 15 | 16 | 17 | def as_numpy(v): 18 | if isinstance(v, Variable): 19 | v = v.data 20 | return v.cpu().numpy() 21 | 22 | 23 | class TorchTestCase(unittest.TestCase): 24 | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3): 25 | npa, npb = as_numpy(a), as_numpy(b) 26 | self.assertTrue( 27 | np.allclose(npa, npb, atol=atol), 28 | 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) 29 | ) 30 | -------------------------------------------------------------------------------- /AIM/sync_bn/nn/parallel/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_parallel import UserScatteredDataParallel, user_scattered_collate, async_copy_to 2 | -------------------------------------------------------------------------------- /AIM/sync_bn/nn/parallel/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/sync_bn/nn/parallel/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /AIM/sync_bn/nn/parallel/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/sync_bn/nn/parallel/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /AIM/sync_bn/nn/parallel/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/sync_bn/nn/parallel/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /AIM/sync_bn/nn/parallel/__pycache__/data_parallel.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/sync_bn/nn/parallel/__pycache__/data_parallel.cpython-311.pyc -------------------------------------------------------------------------------- /AIM/sync_bn/nn/parallel/__pycache__/data_parallel.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/sync_bn/nn/parallel/__pycache__/data_parallel.cpython-37.pyc -------------------------------------------------------------------------------- /AIM/sync_bn/nn/parallel/__pycache__/data_parallel.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/sync_bn/nn/parallel/__pycache__/data_parallel.cpython-38.pyc -------------------------------------------------------------------------------- /AIM/sync_bn/nn/parallel/data_parallel.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf8 -*- 2 | 3 | import torch.cuda as cuda 4 | import torch.nn as nn 5 | import torch 6 | from torch.autograd import Variable 7 | import collections 8 | from torch.nn.parallel._functions import Gather 9 | 10 | __all__ = ['UserScatteredDataParallel', 'user_scattered_collate', 'async_copy_to'] 11 | 12 | def async_copy_to(obj, dev, main_stream=None): 13 | if torch.is_tensor(obj): 14 | obj = Variable(obj) 15 | if isinstance(obj, Variable): 16 | v = obj.cuda(dev) 17 | if main_stream is not None: 18 | v.data.record_stream(main_stream) 19 | return v 20 | elif isinstance(obj, collections.Mapping): 21 | return {k: async_copy_to(o, dev, main_stream) for k, o in obj.items()} 22 | elif isinstance(obj, collections.Sequence): 23 | return [async_copy_to(o, dev, main_stream) for o in obj] 24 | else: 25 | return obj 26 | 27 | 28 | def dict_gather(outputs, target_device, dim=0): 29 | """ 30 | Gathers variables from different GPUs on a specified device 31 | (-1 means the CPU), with dictionary support. 32 | """ 33 | def gather_map(outputs): 34 | out = outputs[0] 35 | if isinstance(out, Variable): 36 | # MJY(20180330) HACK:: force nr_dims > 0 37 | if out.dim() == 0: 38 | outputs = [o.unsqueeze(0) for o in outputs] 39 | return Gather.apply(target_device, dim, *outputs) 40 | elif out is None: 41 | return None 42 | elif isinstance(out, collections.Mapping): 43 | return {k: gather_map([o[k] for o in outputs]) for k in out} 44 | elif isinstance(out, collections.Sequence): 45 | return type(out)(map(gather_map, zip(*outputs))) 46 | return gather_map(outputs) 47 | 48 | 49 | class DictGatherDataParallel(nn.DataParallel): 50 | def gather(self, outputs, output_device): 51 | return dict_gather(outputs, output_device, dim=self.dim) 52 | 53 | 54 | class UserScatteredDataParallel(DictGatherDataParallel): 55 | def scatter(self, inputs, kwargs, device_ids): 56 | assert len(inputs) == 1 57 | inputs = inputs[0] 58 | inputs = _async_copy_stream(inputs, device_ids) 59 | inputs = [[i] for i in inputs] 60 | assert len(kwargs) == 0 61 | kwargs = [{} for _ in range(len(inputs))] 62 | 63 | return inputs, kwargs 64 | 65 | 66 | def user_scattered_collate(batch): 67 | return batch 68 | 69 | 70 | def _async_copy(inputs, device_ids): 71 | nr_devs = len(device_ids) 72 | assert type(inputs) in (tuple, list) 73 | assert len(inputs) == nr_devs 74 | 75 | outputs = [] 76 | for i, dev in zip(inputs, device_ids): 77 | with cuda.device(dev): 78 | outputs.append(async_copy_to(i, dev)) 79 | 80 | return tuple(outputs) 81 | 82 | 83 | def _async_copy_stream(inputs, device_ids): 84 | nr_devs = len(device_ids) 85 | assert type(inputs) in (tuple, list) 86 | assert len(inputs) == nr_devs 87 | 88 | outputs = [] 89 | streams = [_get_stream(d) for d in device_ids] 90 | for i, dev, stream in zip(inputs, device_ids, streams): 91 | with cuda.device(dev): 92 | main_stream = cuda.current_stream() 93 | with cuda.stream(stream): 94 | outputs.append(async_copy_to(i, dev, main_stream=main_stream)) 95 | main_stream.wait_stream(stream) 96 | 97 | return outputs 98 | 99 | 100 | """Adapted from: torch/nn/parallel/_functions.py""" 101 | # background streams used for copying 102 | _streams = None 103 | 
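# Usage sketch for the classes defined above (illustrative; `net`, `dataset`
# and the two-GPU setup are assumptions, not taken from this file). Because
# `user_scattered_collate` returns the batch list unchanged and
# `UserScatteredDataParallel.scatter` expects its single input to already be
# split into one element per device, the effective batch list must satisfy
# len(batch) == len(device_ids):
#
#     loader = DataLoader(dataset, batch_size=2, collate_fn=user_scattered_collate)
#     model = UserScatteredDataParallel(net.cuda(), device_ids=[0, 1])
#     for batch in loader:      # `batch` is a list with one raw sample per GPU
#         out = model(batch)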
104 | 105 | def _get_stream(device): 106 | """Gets a background stream for copying between CPU and GPU""" 107 | global _streams 108 | if device == -1: 109 | return None 110 | if _streams is None: 111 | _streams = [None] * cuda.device_count() 112 | if _streams[device] is None: _streams[device] = cuda.Stream(device) 113 | return _streams[device] 114 | -------------------------------------------------------------------------------- /AIM/sync_bn/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .th import * 2 | -------------------------------------------------------------------------------- /AIM/sync_bn/utils/data/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .dataset import Dataset, TensorDataset, ConcatDataset 3 | from .dataloader import DataLoader 4 | -------------------------------------------------------------------------------- /AIM/sync_bn/utils/data/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.multiprocessing as multiprocessing 3 | from torch._C import _set_worker_signal_handlers, _update_worker_pids, \ 4 | _remove_worker_pids, _error_if_any_worker_fails 5 | from .sampler import SequentialSampler, RandomSampler, BatchSampler 6 | import signal 7 | import functools 8 | import collections 9 | import re 10 | import sys 11 | import threading 12 | import traceback 13 | from torch._six import string_classes, int_classes 14 | import numpy as np 15 | 16 | if sys.version_info[0] == 2: 17 | import Queue as queue 18 | else: 19 | import queue 20 | 21 | 22 | class ExceptionWrapper(object): 23 | r"Wraps an exception plus traceback to communicate across threads" 24 | 25 | def __init__(self, exc_info): 26 | self.exc_type = exc_info[0] 27 | self.exc_msg = "".join(traceback.format_exception(*exc_info)) 28 | 29 | 30 | _use_shared_memory = False 31 | """Whether to use shared memory in default_collate""" 32 | 33 | 34 | def _worker_loop(dataset, index_queue, data_queue, collate_fn, seed, init_fn, worker_id): 35 | global _use_shared_memory 36 | _use_shared_memory = True 37 | 38 | # Intialize C side signal handlers for SIGBUS and SIGSEGV. Python signal 39 | # module's handlers are executed after Python returns from C low-level 40 | # handlers, likely when the same fatal signal happened again already. 41 | # https://docs.python.org/3/library/signal.html Sec. 
18.8.1.1 42 | _set_worker_signal_handlers() 43 | 44 | torch.set_num_threads(1) 45 | torch.manual_seed(seed) 46 | np.random.seed(seed) 47 | 48 | if init_fn is not None: 49 | init_fn(worker_id) 50 | 51 | while True: 52 | r = index_queue.get() 53 | if r is None: 54 | break 55 | idx, batch_indices = r 56 | try: 57 | samples = collate_fn([dataset[i] for i in batch_indices]) 58 | except Exception: 59 | data_queue.put((idx, ExceptionWrapper(sys.exc_info()))) 60 | else: 61 | data_queue.put((idx, samples)) 62 | 63 | 64 | def _worker_manager_loop(in_queue, out_queue, done_event, pin_memory, device_id): 65 | if pin_memory: 66 | torch.cuda.set_device(device_id) 67 | 68 | while True: 69 | try: 70 | r = in_queue.get() 71 | except Exception: 72 | if done_event.is_set(): 73 | return 74 | raise 75 | if r is None: 76 | break 77 | if isinstance(r[1], ExceptionWrapper): 78 | out_queue.put(r) 79 | continue 80 | idx, batch = r 81 | try: 82 | if pin_memory: 83 | batch = pin_memory_batch(batch) 84 | except Exception: 85 | out_queue.put((idx, ExceptionWrapper(sys.exc_info()))) 86 | else: 87 | out_queue.put((idx, batch)) 88 | 89 | numpy_type_map = { 90 | 'float64': torch.DoubleTensor, 91 | 'float32': torch.FloatTensor, 92 | 'float16': torch.HalfTensor, 93 | 'int64': torch.LongTensor, 94 | 'int32': torch.IntTensor, 95 | 'int16': torch.ShortTensor, 96 | 'int8': torch.CharTensor, 97 | 'uint8': torch.ByteTensor, 98 | } 99 | 100 | 101 | def default_collate(batch): 102 | "Puts each data field into a tensor with outer dimension batch size" 103 | 104 | error_msg = "batch must contain tensors, numbers, dicts or lists; found {}" 105 | elem_type = type(batch[0]) 106 | if torch.is_tensor(batch[0]): 107 | out = None 108 | if _use_shared_memory: 109 | # If we're in a background process, concatenate directly into a 110 | # shared memory tensor to avoid an extra copy 111 | numel = sum([x.numel() for x in batch]) 112 | storage = batch[0].storage()._new_shared(numel) 113 | out = batch[0].new(storage) 114 | return torch.stack(batch, 0, out=out) 115 | elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ 116 | and elem_type.__name__ != 'string_': 117 | elem = batch[0] 118 | if elem_type.__name__ == 'ndarray': 119 | # array of string classes and object 120 | if re.search('[SaUO]', elem.dtype.str) is not None: 121 | raise TypeError(error_msg.format(elem.dtype)) 122 | 123 | return torch.stack([torch.from_numpy(b) for b in batch], 0) 124 | if elem.shape == (): # scalars 125 | py_type = float if elem.dtype.name.startswith('float') else int 126 | return numpy_type_map[elem.dtype.name](list(map(py_type, batch))) 127 | elif isinstance(batch[0], int_classes): 128 | return torch.LongTensor(batch) 129 | elif isinstance(batch[0], float): 130 | return torch.DoubleTensor(batch) 131 | elif isinstance(batch[0], string_classes): 132 | return batch 133 | elif isinstance(batch[0], collections.Mapping): 134 | return {key: default_collate([d[key] for d in batch]) for key in batch[0]} 135 | elif isinstance(batch[0], collections.Sequence): 136 | transposed = zip(*batch) 137 | return [default_collate(samples) for samples in transposed] 138 | 139 | raise TypeError((error_msg.format(type(batch[0])))) 140 | 141 | 142 | def pin_memory_batch(batch): 143 | if torch.is_tensor(batch): 144 | return batch.pin_memory() 145 | elif isinstance(batch, string_classes): 146 | return batch 147 | elif isinstance(batch, collections.Mapping): 148 | return {k: pin_memory_batch(sample) for k, sample in batch.items()} 149 | elif isinstance(batch, 
collections.Sequence): 150 | return [pin_memory_batch(sample) for sample in batch] 151 | else: 152 | return batch 153 | 154 | 155 | _SIGCHLD_handler_set = False 156 | """Whether SIGCHLD handler is set for DataLoader worker failures. Only one 157 | handler needs to be set for all DataLoaders in a process.""" 158 | 159 | 160 | def _set_SIGCHLD_handler(): 161 | # Windows doesn't support SIGCHLD handler 162 | if sys.platform == 'win32': 163 | return 164 | # can't set signal in child threads 165 | if not isinstance(threading.current_thread(), threading._MainThread): 166 | return 167 | global _SIGCHLD_handler_set 168 | if _SIGCHLD_handler_set: 169 | return 170 | previous_handler = signal.getsignal(signal.SIGCHLD) 171 | if not callable(previous_handler): 172 | previous_handler = None 173 | 174 | def handler(signum, frame): 175 | # This following call uses `waitid` with WNOHANG from C side. Therefore, 176 | # Python can still get and update the process status successfully. 177 | _error_if_any_worker_fails() 178 | if previous_handler is not None: 179 | previous_handler(signum, frame) 180 | 181 | signal.signal(signal.SIGCHLD, handler) 182 | _SIGCHLD_handler_set = True 183 | 184 | 185 | class DataLoaderIter(object): 186 | "Iterates once over the DataLoader's dataset, as specified by the sampler" 187 | 188 | def __init__(self, loader): 189 | self.dataset = loader.dataset 190 | self.collate_fn = loader.collate_fn 191 | self.batch_sampler = loader.batch_sampler 192 | self.num_workers = loader.num_workers 193 | self.pin_memory = loader.pin_memory and torch.cuda.is_available() 194 | self.timeout = loader.timeout 195 | self.done_event = threading.Event() 196 | 197 | self.sample_iter = iter(self.batch_sampler) 198 | 199 | if self.num_workers > 0: 200 | self.worker_init_fn = loader.worker_init_fn 201 | self.index_queue = multiprocessing.SimpleQueue() 202 | self.worker_result_queue = multiprocessing.SimpleQueue() 203 | self.batches_outstanding = 0 204 | self.worker_pids_set = False 205 | self.shutdown = False 206 | self.send_idx = 0 207 | self.rcvd_idx = 0 208 | self.reorder_dict = {} 209 | 210 | base_seed = torch.LongTensor(1).random_(0, 2**31-1)[0] 211 | self.workers = [ 212 | multiprocessing.Process( 213 | target=_worker_loop, 214 | args=(self.dataset, self.index_queue, self.worker_result_queue, self.collate_fn, 215 | base_seed + i, self.worker_init_fn, i)) 216 | for i in range(self.num_workers)] 217 | 218 | if self.pin_memory or self.timeout > 0: 219 | self.data_queue = queue.Queue() 220 | if self.pin_memory: 221 | maybe_device_id = torch.cuda.current_device() 222 | else: 223 | # do not initialize cuda context if not necessary 224 | maybe_device_id = None 225 | self.worker_manager_thread = threading.Thread( 226 | target=_worker_manager_loop, 227 | args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory, 228 | maybe_device_id)) 229 | self.worker_manager_thread.daemon = True 230 | self.worker_manager_thread.start() 231 | else: 232 | self.data_queue = self.worker_result_queue 233 | 234 | for w in self.workers: 235 | w.daemon = True # ensure that the worker exits on process exit 236 | w.start() 237 | 238 | _update_worker_pids(id(self), tuple(w.pid for w in self.workers)) 239 | _set_SIGCHLD_handler() 240 | self.worker_pids_set = True 241 | 242 | # prime the prefetch loop 243 | for _ in range(2 * self.num_workers): 244 | self._put_indices() 245 | 246 | def __len__(self): 247 | return len(self.batch_sampler) 248 | 249 | def _get_batch(self): 250 | if self.timeout > 0: 251 | try: 252 | 
return self.data_queue.get(timeout=self.timeout) 253 | except queue.Empty: 254 | raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout)) 255 | else: 256 | return self.data_queue.get() 257 | 258 | def __next__(self): 259 | if self.num_workers == 0: # same-process loading 260 | indices = next(self.sample_iter) # may raise StopIteration 261 | batch = self.collate_fn([self.dataset[i] for i in indices]) 262 | if self.pin_memory: 263 | batch = pin_memory_batch(batch) 264 | return batch 265 | 266 | # check if the next sample has already been generated 267 | if self.rcvd_idx in self.reorder_dict: 268 | batch = self.reorder_dict.pop(self.rcvd_idx) 269 | return self._process_next_batch(batch) 270 | 271 | if self.batches_outstanding == 0: 272 | self._shutdown_workers() 273 | raise StopIteration 274 | 275 | while True: 276 | assert (not self.shutdown and self.batches_outstanding > 0) 277 | idx, batch = self._get_batch() 278 | self.batches_outstanding -= 1 279 | if idx != self.rcvd_idx: 280 | # store out-of-order samples 281 | self.reorder_dict[idx] = batch 282 | continue 283 | return self._process_next_batch(batch) 284 | 285 | next = __next__ # Python 2 compatibility 286 | 287 | def __iter__(self): 288 | return self 289 | 290 | def _put_indices(self): 291 | assert self.batches_outstanding < 2 * self.num_workers 292 | indices = next(self.sample_iter, None) 293 | if indices is None: 294 | return 295 | self.index_queue.put((self.send_idx, indices)) 296 | self.batches_outstanding += 1 297 | self.send_idx += 1 298 | 299 | def _process_next_batch(self, batch): 300 | self.rcvd_idx += 1 301 | self._put_indices() 302 | if isinstance(batch, ExceptionWrapper): 303 | raise batch.exc_type(batch.exc_msg) 304 | return batch 305 | 306 | def __getstate__(self): 307 | # TODO: add limited pickling support for sharing an iterator 308 | # across multiple threads for HOGWILD. 309 | # Probably the best way to do this is by moving the sample pushing 310 | # to a separate thread and then just sharing the data queue 311 | # but signalling the end is tricky without a non-blocking API 312 | raise NotImplementedError("DataLoaderIterator cannot be pickled") 313 | 314 | def _shutdown_workers(self): 315 | try: 316 | if not self.shutdown: 317 | self.shutdown = True 318 | self.done_event.set() 319 | # if worker_manager_thread is waiting to put 320 | while not self.data_queue.empty(): 321 | self.data_queue.get() 322 | for _ in self.workers: 323 | self.index_queue.put(None) 324 | # done_event should be sufficient to exit worker_manager_thread, 325 | # but be safe here and put another None 326 | self.worker_result_queue.put(None) 327 | finally: 328 | # removes pids no matter what 329 | if self.worker_pids_set: 330 | _remove_worker_pids(id(self)) 331 | self.worker_pids_set = False 332 | 333 | def __del__(self): 334 | if self.num_workers > 0: 335 | self._shutdown_workers() 336 | 337 | 338 | class DataLoader(object): 339 | """ 340 | Data loader. Combines a dataset and a sampler, and provides 341 | single- or multi-process iterators over the dataset. 342 | 343 | Arguments: 344 | dataset (Dataset): dataset from which to load the data. 345 | batch_size (int, optional): how many samples per batch to load 346 | (default: 1). 347 | shuffle (bool, optional): set to ``True`` to have the data reshuffled 348 | at every epoch (default: False). 349 | sampler (Sampler, optional): defines the strategy to draw samples from 350 | the dataset. If specified, ``shuffle`` must be False. 
351 | batch_sampler (Sampler, optional): like sampler, but returns a batch of 352 | indices at a time. Mutually exclusive with batch_size, shuffle, 353 | sampler, and drop_last. 354 | num_workers (int, optional): how many subprocesses to use for data 355 | loading. 0 means that the data will be loaded in the main process. 356 | (default: 0) 357 | collate_fn (callable, optional): merges a list of samples to form a mini-batch. 358 | pin_memory (bool, optional): If ``True``, the data loader will copy tensors 359 | into CUDA pinned memory before returning them. 360 | drop_last (bool, optional): set to ``True`` to drop the last incomplete batch, 361 | if the dataset size is not divisible by the batch size. If ``False`` and 362 | the size of dataset is not divisible by the batch size, then the last batch 363 | will be smaller. (default: False) 364 | timeout (numeric, optional): if positive, the timeout value for collecting a batch 365 | from workers. Should always be non-negative. (default: 0) 366 | worker_init_fn (callable, optional): If not None, this will be called on each 367 | worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as 368 | input, after seeding and before data loading. (default: None) 369 | 370 | .. note:: By default, each worker will have its PyTorch seed set to 371 | ``base_seed + worker_id``, where ``base_seed`` is a long generated 372 | by main process using its RNG. You may use ``torch.initial_seed()`` to access 373 | this value in :attr:`worker_init_fn`, which can be used to set other seeds 374 | (e.g. NumPy) before data loading. 375 | 376 | .. warning:: If ``spawn'' start method is used, :attr:`worker_init_fn` cannot be an 377 | unpicklable object, e.g., a lambda function. 378 | """ 379 | 380 | def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, 381 | num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False, 382 | timeout=0, worker_init_fn=None): 383 | self.dataset = dataset 384 | self.batch_size = batch_size 385 | self.num_workers = num_workers 386 | self.collate_fn = collate_fn 387 | self.pin_memory = pin_memory 388 | self.drop_last = drop_last 389 | self.timeout = timeout 390 | self.worker_init_fn = worker_init_fn 391 | 392 | if timeout < 0: 393 | raise ValueError('timeout option should be non-negative') 394 | 395 | if batch_sampler is not None: 396 | if batch_size > 1 or shuffle or sampler is not None or drop_last: 397 | raise ValueError('batch_sampler is mutually exclusive with ' 398 | 'batch_size, shuffle, sampler, and drop_last') 399 | 400 | if sampler is not None and shuffle: 401 | raise ValueError('sampler is mutually exclusive with shuffle') 402 | 403 | if self.num_workers < 0: 404 | raise ValueError('num_workers cannot be negative; ' 405 | 'use num_workers=0 to disable multiprocessing.') 406 | 407 | if batch_sampler is None: 408 | if sampler is None: 409 | if shuffle: 410 | sampler = RandomSampler(dataset) 411 | else: 412 | sampler = SequentialSampler(dataset) 413 | batch_sampler = BatchSampler(sampler, batch_size, drop_last) 414 | 415 | self.sampler = sampler 416 | self.batch_sampler = batch_sampler 417 | 418 | def __iter__(self): 419 | return DataLoaderIter(self) 420 | 421 | def __len__(self): 422 | return len(self.batch_sampler) 423 | -------------------------------------------------------------------------------- /AIM/sync_bn/utils/data/dataset.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | import warnings 3 | 4 | 
from torch._utils import _accumulate 5 | from torch import randperm 6 | 7 | 8 | class Dataset(object): 9 | """An abstract class representing a Dataset. 10 | 11 | All other datasets should subclass it. All subclasses should override 12 | ``__len__``, that provides the size of the dataset, and ``__getitem__``, 13 | supporting integer indexing in range from 0 to len(self) exclusive. 14 | """ 15 | 16 | def __getitem__(self, index): 17 | raise NotImplementedError 18 | 19 | def __len__(self): 20 | raise NotImplementedError 21 | 22 | def __add__(self, other): 23 | return ConcatDataset([self, other]) 24 | 25 | 26 | class TensorDataset(Dataset): 27 | """Dataset wrapping data and target tensors. 28 | 29 | Each sample will be retrieved by indexing both tensors along the first 30 | dimension. 31 | 32 | Arguments: 33 | data_tensor (Tensor): contains sample data. 34 | target_tensor (Tensor): contains sample targets (labels). 35 | """ 36 | 37 | def __init__(self, data_tensor, target_tensor): 38 | assert data_tensor.size(0) == target_tensor.size(0) 39 | self.data_tensor = data_tensor 40 | self.target_tensor = target_tensor 41 | 42 | def __getitem__(self, index): 43 | return self.data_tensor[index], self.target_tensor[index] 44 | 45 | def __len__(self): 46 | return self.data_tensor.size(0) 47 | 48 | 49 | class ConcatDataset(Dataset): 50 | """ 51 | Dataset to concatenate multiple datasets. 52 | Purpose: useful to assemble different existing datasets, possibly 53 | large-scale datasets as the concatenation operation is done in an 54 | on-the-fly manner. 55 | 56 | Arguments: 57 | datasets (iterable): List of datasets to be concatenated 58 | """ 59 | 60 | @staticmethod 61 | def cumsum(sequence): 62 | r, s = [], 0 63 | for e in sequence: 64 | l = len(e) 65 | r.append(l + s) 66 | s += l 67 | return r 68 | 69 | def __init__(self, datasets): 70 | super(ConcatDataset, self).__init__() 71 | assert len(datasets) > 0, 'datasets should not be an empty iterable' 72 | self.datasets = list(datasets) 73 | self.cumulative_sizes = self.cumsum(self.datasets) 74 | 75 | def __len__(self): 76 | return self.cumulative_sizes[-1] 77 | 78 | def __getitem__(self, idx): 79 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) 80 | if dataset_idx == 0: 81 | sample_idx = idx 82 | else: 83 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] 84 | return self.datasets[dataset_idx][sample_idx] 85 | 86 | @property 87 | def cummulative_sizes(self): 88 | warnings.warn("cummulative_sizes attribute is renamed to " 89 | "cumulative_sizes", DeprecationWarning, stacklevel=2) 90 | return self.cumulative_sizes 91 | 92 | 93 | class Subset(Dataset): 94 | def __init__(self, dataset, indices): 95 | self.dataset = dataset 96 | self.indices = indices 97 | 98 | def __getitem__(self, idx): 99 | return self.dataset[self.indices[idx]] 100 | 101 | def __len__(self): 102 | return len(self.indices) 103 | 104 | 105 | def random_split(dataset, lengths): 106 | """ 107 | Randomly split a dataset into non-overlapping new datasets of given lengths 108 | ds 109 | 110 | Arguments: 111 | dataset (Dataset): Dataset to be split 112 | lengths (iterable): lengths of splits to be produced 113 | """ 114 | if sum(lengths) != len(dataset): 115 | raise ValueError("Sum of input lengths does not equal the length of the input dataset!") 116 | 117 | indices = randperm(sum(lengths)) 118 | return [Subset(dataset, indices[offset - length:offset]) for offset, length in zip(_accumulate(lengths), lengths)] 119 | 
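# Usage sketch for the dataset utilities defined above (illustrative values):
#
#     import torch
#     data = torch.randn(10, 3)
#     targets = torch.arange(10)
#     pairs = TensorDataset(data, targets)      # len(pairs) == 10
#     both = ConcatDataset([pairs, pairs])      # len(both) == 20, indexed end-to-end
#     train, val = random_split(both, [16, 4])  # two non-overlapping Subset views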
-------------------------------------------------------------------------------- /AIM/sync_bn/utils/data/distributed.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from .sampler import Sampler 4 | from torch.distributed import get_world_size, get_rank 5 | 6 | 7 | class DistributedSampler(Sampler): 8 | """Sampler that restricts data loading to a subset of the dataset. 9 | 10 | It is especially useful in conjunction with 11 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 12 | process can pass a DistributedSampler instance as a DataLoader sampler, 13 | and load a subset of the original dataset that is exclusive to it. 14 | 15 | .. note:: 16 | Dataset is assumed to be of constant size. 17 | 18 | Arguments: 19 | dataset: Dataset used for sampling. 20 | num_replicas (optional): Number of processes participating in 21 | distributed training. 22 | rank (optional): Rank of the current process within num_replicas. 23 | """ 24 | 25 | def __init__(self, dataset, num_replicas=None, rank=None): 26 | if num_replicas is None: 27 | num_replicas = get_world_size() 28 | if rank is None: 29 | rank = get_rank() 30 | self.dataset = dataset 31 | self.num_replicas = num_replicas 32 | self.rank = rank 33 | self.epoch = 0 34 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 35 | self.total_size = self.num_samples * self.num_replicas 36 | 37 | def __iter__(self): 38 | # deterministically shuffle based on epoch 39 | g = torch.Generator() 40 | g.manual_seed(self.epoch) 41 | indices = list(torch.randperm(len(self.dataset), generator=g)) 42 | 43 | # add extra samples to make it evenly divisible 44 | indices += indices[:(self.total_size - len(indices))] 45 | assert len(indices) == self.total_size 46 | 47 | # subsample 48 | offset = self.num_samples * self.rank 49 | indices = indices[offset:offset + self.num_samples] 50 | assert len(indices) == self.num_samples 51 | 52 | return iter(indices) 53 | 54 | def __len__(self): 55 | return self.num_samples 56 | 57 | def set_epoch(self, epoch): 58 | self.epoch = epoch 59 | -------------------------------------------------------------------------------- /AIM/sync_bn/utils/data/sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class Sampler(object): 5 | """Base class for all Samplers. 6 | 7 | Every Sampler subclass has to provide an __iter__ method, providing a way 8 | to iterate over indices of dataset elements, and a __len__ method that 9 | returns the length of the returned iterators. 10 | """ 11 | 12 | def __init__(self, data_source): 13 | pass 14 | 15 | def __iter__(self): 16 | raise NotImplementedError 17 | 18 | def __len__(self): 19 | raise NotImplementedError 20 | 21 | 22 | class SequentialSampler(Sampler): 23 | """Samples elements sequentially, always in the same order. 24 | 25 | Arguments: 26 | data_source (Dataset): dataset to sample from 27 | """ 28 | 29 | def __init__(self, data_source): 30 | self.data_source = data_source 31 | 32 | def __iter__(self): 33 | return iter(range(len(self.data_source))) 34 | 35 | def __len__(self): 36 | return len(self.data_source) 37 | 38 | 39 | class RandomSampler(Sampler): 40 | """Samples elements randomly, without replacement. 
41 | 42 | Arguments: 43 | data_source (Dataset): dataset to sample from 44 | """ 45 | 46 | def __init__(self, data_source): 47 | self.data_source = data_source 48 | 49 | def __iter__(self): 50 | return iter(torch.randperm(len(self.data_source)).long()) 51 | 52 | def __len__(self): 53 | return len(self.data_source) 54 | 55 | 56 | class SubsetRandomSampler(Sampler): 57 | """Samples elements randomly from a given list of indices, without replacement. 58 | 59 | Arguments: 60 | indices (list): a list of indices 61 | """ 62 | 63 | def __init__(self, indices): 64 | self.indices = indices 65 | 66 | def __iter__(self): 67 | return (self.indices[i] for i in torch.randperm(len(self.indices))) 68 | 69 | def __len__(self): 70 | return len(self.indices) 71 | 72 | 73 | class WeightedRandomSampler(Sampler): 74 | """Samples elements from [0,..,len(weights)-1] with given probabilities (weights). 75 | 76 | Arguments: 77 | weights (list) : a list of weights, not necessary summing up to one 78 | num_samples (int): number of samples to draw 79 | replacement (bool): if ``True``, samples are drawn with replacement. 80 | If not, they are drawn without replacement, which means that when a 81 | sample index is drawn for a row, it cannot be drawn again for that row. 82 | """ 83 | 84 | def __init__(self, weights, num_samples, replacement=True): 85 | self.weights = torch.DoubleTensor(weights) 86 | self.num_samples = num_samples 87 | self.replacement = replacement 88 | 89 | def __iter__(self): 90 | return iter(torch.multinomial(self.weights, self.num_samples, self.replacement)) 91 | 92 | def __len__(self): 93 | return self.num_samples 94 | 95 | 96 | class BatchSampler(object): 97 | """Wraps another sampler to yield a mini-batch of indices. 98 | 99 | Args: 100 | sampler (Sampler): Base sampler. 101 | batch_size (int): Size of mini-batch. 
102 | drop_last (bool): If ``True``, the sampler will drop the last batch if 103 | its size would be less than ``batch_size`` 104 | 105 | Example: 106 | >>> list(BatchSampler(range(10), batch_size=3, drop_last=False)) 107 | [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]] 108 | >>> list(BatchSampler(range(10), batch_size=3, drop_last=True)) 109 | [[0, 1, 2], [3, 4, 5], [6, 7, 8]] 110 | """ 111 | 112 | def __init__(self, sampler, batch_size, drop_last): 113 | self.sampler = sampler 114 | self.batch_size = batch_size 115 | self.drop_last = drop_last 116 | 117 | def __iter__(self): 118 | batch = [] 119 | for idx in self.sampler: 120 | batch.append(idx) 121 | if len(batch) == self.batch_size: 122 | yield batch 123 | batch = [] 124 | if len(batch) > 0 and not self.drop_last: 125 | yield batch 126 | 127 | def __len__(self): 128 | if self.drop_last: 129 | return len(self.sampler) // self.batch_size 130 | else: 131 | return (len(self.sampler) + self.batch_size - 1) // self.batch_size 132 | -------------------------------------------------------------------------------- /AIM/sync_bn/utils/th.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import numpy as np 4 | import collections 5 | 6 | __all__ = ['as_variable', 'as_numpy', 'mark_volatile'] 7 | 8 | def as_variable(obj): 9 | if isinstance(obj, Variable): 10 | return obj 11 | if isinstance(obj, collections.Sequence): 12 | return [as_variable(v) for v in obj] 13 | elif isinstance(obj, collections.Mapping): 14 | return {k: as_variable(v) for k, v in obj.items()} 15 | else: 16 | return Variable(obj) 17 | 18 | def as_numpy(obj): 19 | if isinstance(obj, collections.Sequence): 20 | return [as_numpy(v) for v in obj] 21 | elif isinstance(obj, collections.Mapping): 22 | return {k: as_numpy(v) for k, v in obj.items()} 23 | elif isinstance(obj, Variable): 24 | return obj.data.cpu().numpy() 25 | elif torch.is_tensor(obj): 26 | return obj.cpu().numpy() 27 | else: 28 | return np.array(obj) 29 | 30 | def mark_volatile(obj): 31 | if torch.is_tensor(obj): 32 | obj = Variable(obj) 33 | if isinstance(obj, Variable): 34 | obj.no_grad = True 35 | return obj 36 | elif isinstance(obj, collections.Mapping): 37 | return {k: mark_volatile(o) for k, o in obj.items()} 38 | elif isinstance(obj, collections.Sequence): 39 | return [mark_volatile(o) for o in obj] 40 | else: 41 | return obj 42 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [AAAI 2024]Weakly Supervised Multimodal Affordance Grounding for Egocentric Images 2 | ## Paper 3 | >Weakly Supervised Multimodal Affordance Grounding for Egocentric Images(AAAI 2024) 4 | 5 | Link: https://doi.org/10.1609/aaai.v38i6.28451 6 | 7 | >Appendix 8 | 9 | Link: [Appendix.pdf](/docs) 10 | 11 |


12 | 13 | **Abstract:** 14 | 15 | To enhance the interaction between intelligent systems and the environment, locating the affordance regions of objects is crucial. These regions correspond to specific areas that provide distinct functionalities. Humans often acquire the ability to identify these regions through action demonstrations and verbal instructions. In this paper, we present a novel multimodal framework that extracts affordance knowledge from exocentric images, which depict human-object interactions, as well as from accompanying textual descriptions that describe the performed actions. The extracted knowledge is then transferred to egocentric images. To achieve this goal, we propose the HOI-Transfer Module, which utilizes local perception to disentangle individual actions within exocentric images. This module effectively captures localized features and correlations between actions, leading to valuable affordance knowledge. Additionally, we introduce the Pixel-Text Fusion Module, which fuses affordance knowledge by identifying regions in egocentric images that bear resemblances to the textual features defining affordances. We employ a Weakly Supervised Multimodal Affordance (WSMA) learning approach, utilizing image-level labels for training. Through extensive experiments, we demonstrate the superiority of our proposed method in terms of evaluation metrics and visual results when compared to existing affordance grounding models. Furthermore, ablation experiments confirm the effectiveness of our approach. 16 | 17 | ## Requirements 18 | We run our experiments in the following environment: 19 | - An NVIDIA GeForce RTX 3090 20 | - Python (3.8) 21 | - PyTorch (1.10.0) 22 | 23 | ## Required pre-trained models 24 | - Model for DINO ViT (no need to download separately; the code is already included) 25 | - Model for the text encoder (CLIP): you can find it [here](https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt) 26 | 27 | ## Start 28 | ```bash 29 | git clone https://github.com/xulingjing88/WSMA.git 30 | cd WSMA 31 | ``` 32 | Before training, you need to preprocess the data: 33 | - Seen, Unseen (from AGD20K): you can find it [here](https://github.com/lhc1224/Cross-View-AG/tree/main/code/cvpr) 34 | - HICO-IIF: you can find it [here](https://pan.baidu.com/s/1imzN-mRaWLIyDLZ80NibxQ?pwd=c878) or on Google Drive [here](https://drive.google.com/file/d/1InBbjM6Uo9HK8OKAuK7TZyvCFjSVKkfY/view?usp=drive_link) 35 | ```bash 36 | python preprocessing.py 37 | ``` 38 | Set **'data_root'** to the path of the dataset and **'divide'** to the dataset name (Seen, Unseen, or HICO-IIF), then start training by running train.py. 39 | ```bash 40 | python train.py 41 | ``` 42 | 43 | ## Acknowledgements 44 | We would like to express our gratitude to the following repositories for their contributions and inspirations: [Cross-View-AG](https://github.com/lhc1224/Cross-View-AG), [LOCATE](https://github.com/Reagan1311/LOCATE), [Dino](https://github.com/facebookresearch/dino), [CLIP](https://github.com/openai/CLIP).
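**Dataset layout (for reference).** The data loaders (`datatrain.py`, `datatest.py`) walk an `<affordance>/<object>/<image>` hierarchy under each root directory: training pairs exocentric and egocentric folders with matching affordance/object names, and testing expects a ground-truth mask at the same relative path. A sketch of the expected structure (the root names are placeholders passed in by the training/testing scripts):
```
<exocentric_root>/<affordance>/<object>/xxx.jpg   # human-object interaction images (training)
<egocentric_root>/<affordance>/<object>/xxx.jpg   # egocentric images (training / testing)
<mask_root>/<affordance>/<object>/xxx.png         # ground-truth masks (testing)
```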
45 | 46 | -------------------------------------------------------------------------------- /bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /datatest.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils import data 3 | import collections 4 | import torch 5 | import numpy as np 6 | import h5py 7 | from joblib import Parallel, delayed 8 | import random 9 | from utils import util 10 | import cv2 11 | from utils.transform import Normalize, cv_random_crop_flip, load_image 12 | from torchvision import transforms 13 | from PIL import Image 14 | 15 | mean = [0.485, 0.456, 0.406] 16 | std = [0.229, 0.224, 0.225] 17 | 18 | 19 | class TrainData(data.Dataset): 20 | def __init__(self, egocentric_root, crop_size=224, divide="Unseen", mask_root=None): 21 | self.egocentric_root = egocentric_root 22 | self.image_list = [] 23 | self.crop_size = crop_size 24 | self.mask_root = mask_root 25 | if divide == "Seen": 26 | self.aff_list = ['beat', "boxing", "brush_with", "carry", "catch", 27 | "cut", "cut_with", "drag", 'drink_with', "eat", 28 | "hit", "hold", "jump", "kick", "lie_on", "lift", 29 | "look_out", "open", "pack", "peel", "pick_up", 30 | "pour", "push", "ride", "sip", "sit_on", "stick", 31 | "stir", "swing", "take_photo", "talk_on", "text_on", 32 | "throw", "type_on", "wash", "write"] 33 | elif divide=="Unseen": 34 | self.aff_list = ["carry", "catch", "cut", "cut_with", 'drink_with', 35 | "eat", "hit", "hold", "jump", "kick", "lie_on", "open", "peel", 36 | "pick_up", "pour", "push", "ride", "sip", "sit_on", "stick", 37 | "swing", "take_photo", "throw", "type_on", "wash"] 38 | else: # HICO-IIF 39 | self.aff_list = ['cut_with', 'drink_with', 'hold', 'open', 'pour', 'sip', 'stick', 'stir', 'swing', 'type_on'] 40 | 41 | self.transform = transforms.Compose([ 42 | transforms.Resize((crop_size, crop_size)), 43 | transforms.ToTensor(), 44 | transforms.Normalize(mean, std) 45 | ]) 46 | 47 | files = os.listdir(self.egocentric_root) 48 | for file in files: 49 | file_path = os.path.join(self.egocentric_root, file) 50 | obj_files = os.listdir(file_path) 51 | for obj_file in obj_files: 52 | obj_file_path = os.path.join(file_path, obj_file) 53 | images = os.listdir(obj_file_path) 54 | for img in images: 55 | if 'json' not in img: 56 | img_path = os.path.join(obj_file_path, img) 57 | cur = img_path.split("/") 58 | if os.path.exists(os.path.join(self.mask_root, file, obj_file, img[:-3] + "png")): 59 | self.image_list.append(img_path) 60 | 61 | def __getitem__(self, item): 62 | egocentric_image_path = self.image_list[item] 63 | names = egocentric_image_path.split("/") 64 | aff_name, object = names[-3], names[-2] 65 | label = self.aff_list.index(aff_name) 66 | egocentric_image = self.load_static_image(egocentric_image_path) # At this time, the graph of individual items has been converted into tensor 67 | 68 | mask_path = os.path.join(self.mask_root, names[-3], names[-2], names[-1][:-3] + "png") 69 | mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE) 70 | mask = cv2.resize(mask, (224, 224)) 71 | 72 | return egocentric_image, label, mask_path, aff_name 73 | 74 | def load_static_image(self, path): 75 | 76 | img = util.load_img(path) 77 | img = self.transform(img) 78 | return img 79 | 80 | def 
__len__(self): 81 | 82 | return len(self.image_list) 83 | -------------------------------------------------------------------------------- /datatrain.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils import data 3 | import collections 4 | import torch 5 | import numpy as np 6 | import h5py 7 | from joblib import Parallel, delayed 8 | import random 9 | from utils import util 10 | import cv2 11 | from utils.transform import Normalize, cv_random_crop_flip, load_image 12 | from torchvision import transforms 13 | from PIL import Image 14 | 15 | mean = [0.485, 0.456, 0.406] 16 | std = [0.229, 0.224, 0.225] 17 | 18 | 19 | class TrainData(data.Dataset): 20 | def __init__(self, exocentric_root, egocentric_root, resize_size=256, crop_size=224, divide="Unseen"): 21 | self.exocentric_root = exocentric_root 22 | self.egocentric_root = egocentric_root 23 | self.resize_size = resize_size 24 | self.image_list = [] 25 | self.crop_size = crop_size 26 | if divide == "Seen": 27 | self.aff_list = ['beat', "boxing", "brush_with", "carry", "catch", 28 | "cut", "cut_with", "drag", 'drink_with', "eat", 29 | "hit", "hold", "jump", "kick", "lie_on", "lift", 30 | "look_out", "open", "pack", "peel", "pick_up", 31 | "pour", "push", "ride", "sip", "sit_on", "stick", 32 | "stir", "swing", "take_photo", "talk_on", "text_on", 33 | "throw", "type_on", "wash", "write"] 34 | elif divide=="Unseen": 35 | self.aff_list = ["carry", "catch", "cut", "cut_with", 'drink_with', 36 | "eat", "hit", "hold", "jump", "kick", "lie_on", "open", "peel", 37 | "pick_up", "pour", "push", "ride", "sip", "sit_on", "stick", 38 | "swing", "take_photo", "throw", "type_on", "wash"] 39 | else: # HICO-IIF 40 | self.aff_list = ['cut_with', 'drink_with', 'hold', 'open', 'pour', 'sip', 'stick', 'stir', 'swing', 'type_on'] 41 | 42 | self.transform = transforms.Compose([ 43 | transforms.Resize(resize_size), 44 | transforms.RandomCrop(crop_size), 45 | transforms.RandomHorizontalFlip(), 46 | transforms.ToTensor(), 47 | transforms.Normalize(mean, std) 48 | ]) 49 | 50 | files = os.listdir(self.exocentric_root) 51 | for file in files: 52 | file_path = os.path.join(self.exocentric_root, file) 53 | obj_files = os.listdir(file_path) 54 | for obj_file in obj_files: 55 | obj_file_path = os.path.join(file_path, obj_file) 56 | images = os.listdir(obj_file_path) 57 | for img in images: 58 | img_path = os.path.join(obj_file_path, img) 59 | cur = img_path.split("/") 60 | if os.path.exists(os.path.join(self.egocentric_root, cur[-3], cur[-2])): 61 | self.image_list.append(img_path) 62 | 63 | def __getitem__(self, item): 64 | exocentric_image_path = self.image_list[item] 65 | names = exocentric_image_path.split("/") 66 | aff_name, object = names[-3], names[-2] 67 | object_file = os.path.join(self.egocentric_root, aff_name, object) 68 | obj_images = os.listdir(object_file) 69 | label = self.aff_list.index(aff_name) 70 | idx = random.randint(0, len(obj_images) - 1) # Randomly select a picture 71 | egocentric_image_path = os.path.join(object_file, obj_images[idx]) 72 | names = egocentric_image_path.split("/") 73 | egocentric_image = self.load_static_image(egocentric_image_path) 74 | 75 | exocentric_file = os.path.join(self.exocentric_root, aff_name, object) 76 | exocentrics = os.listdir(exocentric_file) 77 | exocentric_images = [] 78 | if len(exocentrics) <= 3: # If it is less than 3, select all of them and fill them up 79 | start = 0 80 | for i in range(start, len(exocentrics)): 81 | tmp_exo = 
self.load_static_image(os.path.join(self.exocentric_root, aff_name, object, exocentrics[i])) 82 | exocentric_images.append(tmp_exo) 83 | for i in range(len(exocentrics), 3): 84 | exocentric_images.append(tmp_exo) 85 | else: 86 | start = random.randint(0, len(exocentrics) - 4) # If more than 3, choose 3 randomly 87 | for idx in range(start, start + 3): 88 | tmp_exo = self.load_static_image(os.path.join(self.exocentric_root, aff_name, object, exocentrics[idx])) 89 | exocentric_images.append(tmp_exo) 90 | 91 | return exocentric_images, egocentric_image, label, aff_name 92 | 93 | def load_static_image(self, path): 94 | 95 | img = util.load_img(path) 96 | img = self.transform(img) 97 | return img 98 | 99 | def __len__(self): 100 | 101 | return len(self.image_list) 102 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | ## Some document here 2 | -------------------------------------------------------------------------------- /docs/WSMA-appendix.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/docs/WSMA-appendix.pdf -------------------------------------------------------------------------------- /images/README.md: -------------------------------------------------------------------------------- 1 | Some images 2 | -------------------------------------------------------------------------------- /images/pipelline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/images/pipelline.png -------------------------------------------------------------------------------- /images/pipelline1.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/images/pipelline1.pdf -------------------------------------------------------------------------------- /models/dino/__pycache__/utils.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/models/dino/__pycache__/utils.cpython-311.pyc -------------------------------------------------------------------------------- /models/dino/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/models/dino/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /models/dino/__pycache__/vision_transformer.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/models/dino/__pycache__/vision_transformer.cpython-311.pyc -------------------------------------------------------------------------------- /models/dino/__pycache__/vision_transformer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/models/dino/__pycache__/vision_transformer.cpython-37.pyc 
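# The next file, models/dino/utils.py, vendors the standard DINO helper
# functions. A minimal sketch of how a ViT backbone is typically built and
# loaded with them (this assumes the usual `vit_small` factory in the vendored
# models/dino/vision_transformer.py; it is an illustration, not an excerpt):
#
#     from models.dino import vision_transformer as vits
#     from models.dino import utils as dino_utils
#
#     backbone = vits.vit_small(patch_size=16)
#     # with an empty checkpoint path, load_pretrained_weights falls back to
#     # downloading the reference DINO weights for this model / patch size
#     dino_utils.load_pretrained_weights(backbone, '', None, 'vit_small', 16)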
-------------------------------------------------------------------------------- /models/dino/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Misc functions. 16 | 17 | Mostly copy-paste from torchvision references or other public repos like DETR: 18 | https://github.com/facebookresearch/detr/blob/master/util/misc.py 19 | """ 20 | import os 21 | import sys 22 | import time 23 | import math 24 | import random 25 | import datetime 26 | import subprocess 27 | from collections import defaultdict, deque 28 | 29 | import numpy as np 30 | import torch 31 | from torch import nn 32 | import torch.distributed as dist 33 | from PIL import ImageFilter, ImageOps 34 | 35 | 36 | class GaussianBlur(object): 37 | """ 38 | Apply Gaussian Blur to the PIL image. 39 | """ 40 | 41 | def __init__(self, p=0.5, radius_min=0.1, radius_max=2.): 42 | self.prob = p 43 | self.radius_min = radius_min 44 | self.radius_max = radius_max 45 | 46 | def __call__(self, img): 47 | do_it = random.random() <= self.prob 48 | if not do_it: 49 | return img 50 | 51 | return img.filter( 52 | ImageFilter.GaussianBlur( 53 | radius=random.uniform(self.radius_min, self.radius_max) 54 | ) 55 | ) 56 | 57 | 58 | class Solarization(object): 59 | """ 60 | Apply Solarization to the PIL image. 
61 | """ 62 | 63 | def __init__(self, p): 64 | self.p = p 65 | 66 | def __call__(self, img): 67 | if random.random() < self.p: 68 | return ImageOps.solarize(img) 69 | else: 70 | return img 71 | 72 | 73 | def load_pretrained_weights(model, pretrained_weights, checkpoint_key, model_name, patch_size): 74 | if os.path.isfile(pretrained_weights): 75 | state_dict = torch.load(pretrained_weights, map_location="cpu") 76 | if checkpoint_key is not None and checkpoint_key in state_dict: 77 | print(f"Take key {checkpoint_key} in provided checkpoint dict") 78 | state_dict = state_dict[checkpoint_key] 79 | # remove `module.` prefix 80 | state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} 81 | # remove `backbone.` prefix induced by multicrop wrapper 82 | state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()} 83 | msg = model.load_state_dict(state_dict, strict=False) 84 | print('Pretrained weights found at {} and loaded with msg: {}'.format(pretrained_weights, msg)) 85 | else: 86 | # print("Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate.") 87 | url = None 88 | if model_name == "vit_small" and patch_size == 16: 89 | url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth" 90 | elif model_name == "vit_small" and patch_size == 8: 91 | url = "dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth" 92 | elif model_name == "vit_base" and patch_size == 16: 93 | url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth" 94 | elif model_name == "vit_base" and patch_size == 8: 95 | url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth" 96 | if url is not None: 97 | # print("Since no pretrained weights have been provided, we load the reference pretrained dino weights.") 98 | state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url) 99 | model.load_state_dict(state_dict, strict=True) 100 | else: 101 | print("There is no reference weights available for this model => We use random weights.") 102 | 103 | 104 | def clip_gradients(model, clip): 105 | norms = [] 106 | for name, p in model.named_parameters(): 107 | if p.grad is not None: 108 | param_norm = p.grad.data.norm(2) 109 | norms.append(param_norm.item()) 110 | clip_coef = clip / (param_norm + 1e-6) 111 | if clip_coef < 1: 112 | p.grad.data.mul_(clip_coef) 113 | return norms 114 | 115 | 116 | def cancel_gradients_last_layer(epoch, model, freeze_last_layer): 117 | if epoch >= freeze_last_layer: 118 | return 119 | for n, p in model.named_parameters(): 120 | if "last_layer" in n: 121 | p.grad = None 122 | 123 | 124 | def restart_from_checkpoint(ckp_path, run_variables=None, **kwargs): 125 | """ 126 | Re-start from checkpoint 127 | """ 128 | if not os.path.isfile(ckp_path): 129 | return 130 | print("Found checkpoint at {}".format(ckp_path)) 131 | 132 | # open checkpoint file 133 | checkpoint = torch.load(ckp_path, map_location="cpu") 134 | 135 | # key is what to look for in the checkpoint file 136 | # value is the object to load 137 | # example: {'state_dict': model} 138 | for key, value in kwargs.items(): 139 | if key in checkpoint and value is not None: 140 | try: 141 | msg = value.load_state_dict(checkpoint[key], strict=False) 142 | print("=> loaded '{}' from checkpoint '{}' with msg {}".format(key, ckp_path, msg)) 143 | except TypeError: 144 | try: 145 | msg = value.load_state_dict(checkpoint[key]) 146 | print("=> loaded '{}' from checkpoint: '{}'".format(key, ckp_path)) 147 | except ValueError: 148 | print("=> failed to 
load '{}' from checkpoint: '{}'".format(key, ckp_path)) 149 | else: 150 | print("=> key '{}' not found in checkpoint: '{}'".format(key, ckp_path)) 151 | 152 | # re load variable important for the run 153 | if run_variables is not None: 154 | for var_name in run_variables: 155 | if var_name in checkpoint: 156 | run_variables[var_name] = checkpoint[var_name] 157 | 158 | 159 | def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0): 160 | warmup_schedule = np.array([]) 161 | warmup_iters = warmup_epochs * niter_per_ep 162 | if warmup_epochs > 0: 163 | warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) 164 | 165 | iters = np.arange(epochs * niter_per_ep - warmup_iters) 166 | schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters))) 167 | 168 | schedule = np.concatenate((warmup_schedule, schedule)) 169 | assert len(schedule) == epochs * niter_per_ep 170 | return schedule 171 | 172 | 173 | def bool_flag(s): 174 | """ 175 | Parse boolean arguments from the command line. 176 | """ 177 | FALSY_STRINGS = {"off", "false", "0"} 178 | TRUTHY_STRINGS = {"on", "true", "1"} 179 | if s.lower() in FALSY_STRINGS: 180 | return False 181 | elif s.lower() in TRUTHY_STRINGS: 182 | return True 183 | else: 184 | raise argparse.ArgumentTypeError("invalid value for a boolean flag") 185 | 186 | 187 | def fix_random_seeds(seed=31): 188 | """ 189 | Fix random seeds. 190 | """ 191 | torch.manual_seed(seed) 192 | torch.cuda.manual_seed_all(seed) 193 | np.random.seed(seed) 194 | 195 | 196 | class SmoothedValue(object): 197 | """Track a series of values and provide access to smoothed values over a 198 | window or the global series average. 199 | """ 200 | 201 | def __init__(self, window_size=20, fmt=None): 202 | if fmt is None: 203 | fmt = "{median:.6f} ({global_avg:.6f})" 204 | self.deque = deque(maxlen=window_size) 205 | self.total = 0.0 206 | self.count = 0 207 | self.fmt = fmt 208 | 209 | def update(self, value, n=1): 210 | self.deque.append(value) 211 | self.count += n 212 | self.total += value * n 213 | 214 | def synchronize_between_processes(self): 215 | """ 216 | Warning: does not synchronize the deque! 217 | """ 218 | if not is_dist_avail_and_initialized(): 219 | return 220 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 221 | dist.barrier() 222 | dist.all_reduce(t) 223 | t = t.tolist() 224 | self.count = int(t[0]) 225 | self.total = t[1] 226 | 227 | @property 228 | def median(self): 229 | d = torch.tensor(list(self.deque)) 230 | return d.median().item() 231 | 232 | @property 233 | def avg(self): 234 | d = torch.tensor(list(self.deque), dtype=torch.float32) 235 | return d.mean().item() 236 | 237 | @property 238 | def global_avg(self): 239 | return self.total / self.count 240 | 241 | @property 242 | def max(self): 243 | return max(self.deque) 244 | 245 | @property 246 | def value(self): 247 | return self.deque[-1] 248 | 249 | def __str__(self): 250 | return self.fmt.format( 251 | median=self.median, 252 | avg=self.avg, 253 | global_avg=self.global_avg, 254 | max=self.max, 255 | value=self.value) 256 | 257 | 258 | def reduce_dict(input_dict, average=True): 259 | """ 260 | Args: 261 | input_dict (dict): all the values will be reduced 262 | average (bool): whether to do average or sum 263 | Reduce the values in the dictionary from all processes so that all processes 264 | have the averaged results. 
Returns a dict with the same fields as 265 | input_dict, after reduction. 266 | """ 267 | world_size = get_world_size() 268 | if world_size < 2: 269 | return input_dict 270 | with torch.no_grad(): 271 | names = [] 272 | values = [] 273 | # sort the keys so that they are consistent across processes 274 | for k in sorted(input_dict.keys()): 275 | names.append(k) 276 | values.append(input_dict[k]) 277 | values = torch.stack(values, dim=0) 278 | dist.all_reduce(values) 279 | if average: 280 | values /= world_size 281 | reduced_dict = {k: v for k, v in zip(names, values)} 282 | return reduced_dict 283 | 284 | 285 | class MetricLogger(object): 286 | def __init__(self, delimiter="\t"): 287 | self.meters = defaultdict(SmoothedValue) 288 | self.delimiter = delimiter 289 | 290 | def update(self, **kwargs): 291 | for k, v in kwargs.items(): 292 | if isinstance(v, torch.Tensor): 293 | v = v.item() 294 | assert isinstance(v, (float, int)) 295 | self.meters[k].update(v) 296 | 297 | def __getattr__(self, attr): 298 | if attr in self.meters: 299 | return self.meters[attr] 300 | if attr in self.__dict__: 301 | return self.__dict__[attr] 302 | raise AttributeError("'{}' object has no attribute '{}'".format( 303 | type(self).__name__, attr)) 304 | 305 | def __str__(self): 306 | loss_str = [] 307 | for name, meter in self.meters.items(): 308 | loss_str.append( 309 | "{}: {}".format(name, str(meter)) 310 | ) 311 | return self.delimiter.join(loss_str) 312 | 313 | def synchronize_between_processes(self): 314 | for meter in self.meters.values(): 315 | meter.synchronize_between_processes() 316 | 317 | def add_meter(self, name, meter): 318 | self.meters[name] = meter 319 | 320 | def log_every(self, iterable, print_freq, header=None): 321 | i = 0 322 | if not header: 323 | header = '' 324 | start_time = time.time() 325 | end = time.time() 326 | iter_time = SmoothedValue(fmt='{avg:.6f}') 327 | data_time = SmoothedValue(fmt='{avg:.6f}') 328 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 329 | if torch.cuda.is_available(): 330 | log_msg = self.delimiter.join([ 331 | header, 332 | '[{0' + space_fmt + '}/{1}]', 333 | 'eta: {eta}', 334 | '{meters}', 335 | 'time: {time}', 336 | 'data: {data}', 337 | 'max mem: {memory:.0f}' 338 | ]) 339 | else: 340 | log_msg = self.delimiter.join([ 341 | header, 342 | '[{0' + space_fmt + '}/{1}]', 343 | 'eta: {eta}', 344 | '{meters}', 345 | 'time: {time}', 346 | 'data: {data}' 347 | ]) 348 | MB = 1024.0 * 1024.0 349 | for obj in iterable: 350 | data_time.update(time.time() - end) 351 | yield obj 352 | iter_time.update(time.time() - end) 353 | if i % print_freq == 0 or i == len(iterable) - 1: 354 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 355 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 356 | if torch.cuda.is_available(): 357 | print(log_msg.format( 358 | i, len(iterable), eta=eta_string, 359 | meters=str(self), 360 | time=str(iter_time), data=str(data_time), 361 | memory=torch.cuda.max_memory_allocated() / MB)) 362 | else: 363 | print(log_msg.format( 364 | i, len(iterable), eta=eta_string, 365 | meters=str(self), 366 | time=str(iter_time), data=str(data_time))) 367 | i += 1 368 | end = time.time() 369 | total_time = time.time() - start_time 370 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 371 | print('{} Total time: {} ({:.6f} s / it)'.format( 372 | header, total_time_str, total_time / len(iterable))) 373 | 374 | 375 | def get_sha(): 376 | cwd = os.path.dirname(os.path.abspath(__file__)) 377 | 378 | def _run(command): 379 | 
return subprocess.check_output(command, cwd=cwd).decode('ascii').strip() 380 | 381 | sha = 'N/A' 382 | diff = "clean" 383 | branch = 'N/A' 384 | try: 385 | sha = _run(['git', 'rev-parse', 'HEAD']) 386 | subprocess.check_output(['git', 'diff'], cwd=cwd) 387 | diff = _run(['git', 'diff-index', 'HEAD']) 388 | diff = "has uncommited changes" if diff else "clean" 389 | branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD']) 390 | except Exception: 391 | pass 392 | message = f"sha: {sha}, status: {diff}, branch: {branch}" 393 | return message 394 | 395 | 396 | def is_dist_avail_and_initialized(): 397 | if not dist.is_available(): 398 | return False 399 | if not dist.is_initialized(): 400 | return False 401 | return True 402 | 403 | 404 | def get_world_size(): 405 | if not is_dist_avail_and_initialized(): 406 | return 1 407 | return dist.get_world_size() 408 | 409 | 410 | def get_rank(): 411 | if not is_dist_avail_and_initialized(): 412 | return 0 413 | return dist.get_rank() 414 | 415 | 416 | def is_main_process(): 417 | return get_rank() == 0 418 | 419 | 420 | def save_on_master(*args, **kwargs): 421 | if is_main_process(): 422 | torch.save(*args, **kwargs) 423 | 424 | 425 | def setup_for_distributed(is_master): 426 | """ 427 | This function disables printing when not in master process 428 | """ 429 | import builtins as __builtin__ 430 | builtin_print = __builtin__.print 431 | 432 | def print(*args, **kwargs): 433 | force = kwargs.pop('force', False) 434 | if is_master or force: 435 | builtin_print(*args, **kwargs) 436 | 437 | __builtin__.print = print 438 | 439 | 440 | def init_distributed_mode(args): 441 | # launched with torch.distributed.launch 442 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 443 | args.rank = int(os.environ["RANK"]) 444 | args.world_size = int(os.environ['WORLD_SIZE']) 445 | args.gpu = int(os.environ['LOCAL_RANK']) 446 | # launched with submitit on a slurm cluster 447 | elif 'SLURM_PROCID' in os.environ: 448 | args.rank = int(os.environ['SLURM_PROCID']) 449 | args.gpu = args.rank % torch.cuda.device_count() 450 | # launched naively with `python main_dino.py` 451 | # we manually add MASTER_ADDR and MASTER_PORT to env variables 452 | elif torch.cuda.is_available(): 453 | print('Will run the code on one GPU.') 454 | args.rank, args.gpu, args.world_size = 0, 0, 1 455 | os.environ['MASTER_ADDR'] = '127.0.0.1' 456 | os.environ['MASTER_PORT'] = '29500' 457 | else: 458 | print('Does not support training without GPU.') 459 | sys.exit(1) 460 | 461 | dist.init_process_group( 462 | backend="nccl", 463 | init_method=args.dist_url, 464 | world_size=args.world_size, 465 | rank=args.rank, 466 | ) 467 | 468 | torch.cuda.set_device(args.gpu) 469 | print('| distributed init (rank {}): {}'.format( 470 | args.rank, args.dist_url), flush=True) 471 | dist.barrier() 472 | setup_for_distributed(args.rank == 0) 473 | 474 | 475 | def accuracy(output, target, topk=(1,)): 476 | """Computes the accuracy over the k top predictions for the specified values of k""" 477 | maxk = max(topk) 478 | batch_size = target.size(0) 479 | _, pred = output.topk(maxk, 1, True, True) 480 | pred = pred.t() 481 | correct = pred.eq(target.reshape(1, -1).expand_as(pred)) 482 | return [correct[:k].reshape(-1).float().sum(0) * 100. 
/ batch_size for k in topk] 483 | 484 | 485 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 486 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 487 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 488 | def norm_cdf(x): 489 | # Computes standard normal cumulative distribution function 490 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 491 | 492 | if (mean < a - 2 * std) or (mean > b + 2 * std): 493 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 494 | "The distribution of values may be incorrect.", 495 | stacklevel=2) 496 | 497 | with torch.no_grad(): 498 | # Values are generated by using a truncated uniform distribution and 499 | # then using the inverse CDF for the normal distribution. 500 | # Get upper and lower cdf values 501 | l = norm_cdf((a - mean) / std) 502 | u = norm_cdf((b - mean) / std) 503 | 504 | # Uniformly fill tensor with values from [l, u], then translate to 505 | # [2l-1, 2u-1]. 506 | tensor.uniform_(2 * l - 1, 2 * u - 1) 507 | 508 | # Use inverse cdf transform for normal distribution to get truncated 509 | # standard normal 510 | tensor.erfinv_() 511 | 512 | # Transform to proper mean, std 513 | tensor.mul_(std * math.sqrt(2.)) 514 | tensor.add_(mean) 515 | 516 | # Clamp to ensure it's in the proper range 517 | tensor.clamp_(min=a, max=b) 518 | return tensor 519 | 520 | 521 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 522 | # type: (Tensor, float, float, float, float) -> Tensor 523 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 524 | 525 | 526 | class LARS(torch.optim.Optimizer): 527 | """ 528 | Almost copy-paste from https://github.com/facebookresearch/barlowtwins/blob/main/main.py 529 | """ 530 | 531 | def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, eta=0.001, 532 | weight_decay_filter=None, lars_adaptation_filter=None): 533 | defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, 534 | eta=eta, weight_decay_filter=weight_decay_filter, 535 | lars_adaptation_filter=lars_adaptation_filter) 536 | super().__init__(params, defaults) 537 | 538 | @torch.no_grad() 539 | def step(self): 540 | for g in self.param_groups: 541 | for p in g['params']: 542 | dp = p.grad 543 | 544 | if dp is None: 545 | continue 546 | 547 | if p.ndim != 1: 548 | dp = dp.add(p, alpha=g['weight_decay']) 549 | 550 | if p.ndim != 1: 551 | param_norm = torch.norm(p) 552 | update_norm = torch.norm(dp) 553 | one = torch.ones_like(param_norm) 554 | q = torch.where(param_norm > 0., 555 | torch.where(update_norm > 0, 556 | (g['eta'] * param_norm / update_norm), one), one) 557 | dp = dp.mul(q) 558 | 559 | param_state = self.state[p] 560 | if 'mu' not in param_state: 561 | param_state['mu'] = torch.zeros_like(p) 562 | mu = param_state['mu'] 563 | mu.mul_(g['momentum']).add_(dp) 564 | 565 | p.add_(mu, alpha=-g['lr']) 566 | 567 | 568 | class MultiCropWrapper(nn.Module): 569 | """ 570 | Perform forward pass separately on each resolution input. 571 | The inputs corresponding to a single resolution are clubbed and single 572 | forward is run on the same resolution inputs. Hence we do several 573 | forward passes = number of different resolutions used. We then 574 | concatenate all the output features and run the head forward on these 575 | concatenated features. 
576 | """ 577 | 578 | def __init__(self, backbone, head): 579 | super(MultiCropWrapper, self).__init__() 580 | # disable layers dedicated to ImageNet labels classification 581 | backbone.fc, backbone.head = nn.Identity(), nn.Identity() 582 | self.backbone = backbone 583 | self.head = head 584 | 585 | def forward(self, x): 586 | # convert to list 587 | if not isinstance(x, list): 588 | x = [x] 589 | idx_crops = torch.cumsum(torch.unique_consecutive( 590 | torch.tensor([inp.shape[-1] for inp in x]), 591 | return_counts=True, 592 | )[1], 0) 593 | start_idx = 0 594 | for end_idx in idx_crops: 595 | _out = self.backbone(torch.cat(x[start_idx: end_idx])) 596 | if start_idx == 0: 597 | output = _out 598 | else: 599 | output = torch.cat((output, _out)) 600 | start_idx = end_idx 601 | # Run the head forward on the concatenated features. 602 | return self.head(output) 603 | 604 | 605 | def get_params_groups(model): 606 | regularized = [] 607 | not_regularized = [] 608 | for name, param in model.named_parameters(): 609 | if not param.requires_grad: 610 | continue 611 | # we do not regularize biases nor Norm parameters 612 | if name.endswith(".bias") or len(param.shape) == 1: 613 | not_regularized.append(param) 614 | else: 615 | regularized.append(param) 616 | return [{'params': regularized}, {'params': not_regularized, 'weight_decay': 0.}] 617 | 618 | 619 | def has_batchnorms(model): 620 | bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm) 621 | for name, module in model.named_modules(): 622 | if isinstance(module, bn_types): 623 | return True 624 | return False 625 | -------------------------------------------------------------------------------- /models/dino/vision_transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Mostly copy-paste from timm library. 16 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py 17 | """ 18 | import math 19 | from torch import Tensor 20 | from typing import Union 21 | from functools import partial 22 | 23 | import torch 24 | import torch.nn as nn 25 | 26 | from models.dino.utils import trunc_normal_ 27 | 28 | 29 | def drop_path(x, drop_prob: float = 0., training: bool = False): 30 | if drop_prob == 0. or not training: 31 | return x 32 | keep_prob = 1 - drop_prob 33 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 34 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 35 | random_tensor.floor_() # binarize 36 | output = x.div(keep_prob) * random_tensor 37 | return output 38 | 39 | 40 | class DropPath(nn.Module): 41 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
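During training, each sample's residual path is zeroed with probability drop_prob and the surviving samples are rescaled by 1 / (1 - drop_prob), so the expected value of the output is unchanged; at evaluation time the input is returned untouched.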
42 | """ 43 | 44 | def __init__(self, drop_prob=None): 45 | super(DropPath, self).__init__() 46 | self.drop_prob = drop_prob 47 | 48 | def forward(self, x): 49 | return drop_path(x, self.drop_prob, self.training) 50 | 51 | 52 | class Mlp(nn.Module): 53 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 54 | super().__init__() 55 | out_features = out_features or in_features 56 | hidden_features = hidden_features or in_features 57 | self.fc1 = nn.Linear(in_features, hidden_features) 58 | self.act = act_layer() 59 | self.fc2 = nn.Linear(hidden_features, out_features) 60 | self.drop = nn.Dropout(drop) 61 | 62 | def forward(self, x): 63 | x = self.fc1(x) 64 | x = self.act(x) 65 | x = self.drop(x) 66 | x = self.fc2(x) 67 | x = self.drop(x) 68 | return x 69 | 70 | 71 | class Attention(nn.Module): 72 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 73 | super().__init__() 74 | self.num_heads = num_heads 75 | head_dim = dim // num_heads 76 | self.scale = qk_scale or head_dim ** -0.5 77 | 78 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 79 | self.attn_drop = nn.Dropout(attn_drop) 80 | self.proj = nn.Linear(dim, dim) 81 | self.proj_drop = nn.Dropout(proj_drop) 82 | 83 | def forward(self, x, return_key=False): 84 | B, N, C = x.shape 85 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 86 | q, k, v = qkv[0], qkv[1], qkv[2] 87 | 88 | attn = (q @ k.transpose(-2, -1)) * self.scale 89 | attn = attn.softmax(dim=-1) 90 | attn = self.attn_drop(attn) 91 | 92 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 93 | x = self.proj(x) 94 | x = self.proj_drop(x) 95 | if not return_key: 96 | return x, attn 97 | else: 98 | return x, attn, k 99 | 100 | class Block(nn.Module): 101 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 102 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 103 | super().__init__() 104 | 105 | self.norm1 = norm_layer(dim) 106 | self.attn = Attention( 107 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 108 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 109 | self.norm2 = norm_layer(dim) 110 | mlp_hidden_dim = int(dim * mlp_ratio) 111 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 112 | 113 | def forward(self, x, return_attention=False, return_key=False): 114 | if return_key: 115 | y, attn, key = self.attn(self.norm1(x), return_key) 116 | else: 117 | y, attn = self.attn(self.norm1(x)) 118 | x = x + self.drop_path(y) 119 | x = x + self.drop_path(self.mlp(self.norm2(x))) 120 | if return_attention: 121 | return x, attn 122 | elif return_key: 123 | return x, key, attn 124 | else: 125 | return x 126 | 127 | 128 | class PatchEmbed(nn.Module): 129 | """ Image to Patch Embedding 130 | """ 131 | 132 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 133 | super().__init__() 134 | num_patches = (img_size // patch_size) * (img_size // patch_size) 135 | self.img_size = img_size 136 | self.patch_size = patch_size 137 | self.num_patches = num_patches 138 | 139 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 140 | 141 | def forward(self, x): 142 | B, C, H, W = x.shape 143 | x = self.proj(x).flatten(2).transpose(1, 2) 144 | return x 145 | 146 | 147 | class VisionTransformer(nn.Module): 148 | """ Vision Transformer """ 149 | 150 | def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12, 151 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 152 | drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs): 153 | super().__init__() 154 | self.num_features = self.embed_dim = embed_dim 155 | 156 | self.patch_embed = PatchEmbed( 157 | img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 158 | num_patches = self.patch_embed.num_patches 159 | 160 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 161 | 162 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 163 | self.pos_drop = nn.Dropout(p=drop_rate) 164 | 165 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 166 | self.blocks = nn.ModuleList([ 167 | Block( 168 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 169 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) 170 | for i in range(depth)]) 171 | self.norm = norm_layer(embed_dim) 172 | 173 | # Classifier head 174 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 175 | 176 | trunc_normal_(self.pos_embed, std=.02) 177 | trunc_normal_(self.cls_token, std=.02) 178 | self.apply(self._init_weights) 179 | 180 | def _init_weights(self, m): 181 | if isinstance(m, nn.Linear): 182 | trunc_normal_(m.weight, std=.02) 183 | if isinstance(m, nn.Linear) and m.bias is not None: 184 | nn.init.constant_(m.bias, 0) 185 | elif isinstance(m, nn.LayerNorm): 186 | nn.init.constant_(m.bias, 0) 187 | nn.init.constant_(m.weight, 1.0) 188 | 189 | def interpolate_pos_encoding(self, x, w, h): 190 | npatch = x.shape[1] - 1 191 | N = self.pos_embed.shape[1] - 1 192 | if npatch == N and w == h: 193 | return self.pos_embed 194 | class_pos_embed = self.pos_embed[:, 0] 195 | patch_pos_embed = self.pos_embed[:, 1:] 196 | dim = x.shape[-1] 197 | w0 = w // self.patch_embed.patch_size 198 | h0 = h // self.patch_embed.patch_size 199 | # we add a small number to avoid floating point error in the interpolation 200 | # see discussion at 
https://github.com/facebookresearch/dino/issues/8 201 | w0, h0 = w0 + 0.1, h0 + 0.1 202 | patch_pos_embed = nn.functional.interpolate( 203 | patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), 204 | scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), 205 | mode='bicubic', 206 | ) 207 | assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] 208 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) 209 | return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) 210 | 211 | def prepare_tokens(self, x): 212 | B, nc, w, h = x.shape 213 | x = self.patch_embed(x) # patch linear embedding 214 | 215 | # add the [CLS] token to the embed patch tokens 216 | cls_tokens = self.cls_token.expand(B, -1, -1) 217 | x = torch.cat((cls_tokens, x), dim=1) 218 | 219 | # add positional encoding to each token 220 | x = x + self.interpolate_pos_encoding(x, w, h) 221 | 222 | return self.pos_drop(x) 223 | 224 | def forward(self, x, return_attention=False): 225 | if return_attention: 226 | atten_weights = [] 227 | x = self.prepare_tokens(x) 228 | for blk_ in self.blocks: 229 | x, weights = blk_(x, return_attention) 230 | atten_weights.append(weights) 231 | x = self.norm(x) 232 | return x, atten_weights 233 | 234 | else: 235 | x = self.prepare_tokens(x) 236 | for blk_ in self.blocks: 237 | x = blk_(x) 238 | x = self.norm(x) 239 | return x 240 | 241 | def get_last_selfattention(self, x): 242 | x = self.prepare_tokens(x) 243 | for i, blk in enumerate(self.blocks): 244 | if i < len(self.blocks) - 1: 245 | x = blk(x) 246 | else: 247 | # return attention of the last block 248 | return blk(x, return_attention=True) 249 | 250 | def get_intermediate_layers(self, x, n=1): 251 | x = self.prepare_tokens(x) 252 | # we return the output tokens from the `n` last blocks 253 | output = [] 254 | for i, blk in enumerate(self.blocks): 255 | x = blk(x) 256 | if len(self.blocks) - i <= n: 257 | output.append(self.norm(x)) 258 | return output 259 | 260 | def get_last_key(self, x, extra_layer=None): 261 | x = self.prepare_tokens(x) 262 | key_mid = 0 263 | for i, blk in enumerate(self.blocks): 264 | if extra_layer != None and i == extra_layer: 265 | x, key, attn = blk(x, return_key=True) 266 | key_mid = key 267 | elif i < len(self.blocks) - 1: 268 | x = blk(x) 269 | else: 270 | # return attention of the last block 271 | x, key, attn = blk(x, return_key=True) 272 | if extra_layer == None: 273 | return x, key, attn 274 | else: 275 | return key_mid, x, key, attn 276 | 277 | def get_all_key(self, x): 278 | x = self.prepare_tokens(x) 279 | key_mid = 0 280 | keys=[] 281 | xs=[] 282 | for i, blk in enumerate(self.blocks): 283 | if i < len(self.blocks) - 1: 284 | x, key, attn = blk(x, return_key=True) 285 | keys.append(key) 286 | xs.append(x) 287 | else: 288 | # return attention of the last block 289 | x, key, attn = blk(x, return_key=True) 290 | keys.append(key) 291 | xs.append(x) 292 | return xs, keys, attn 293 | 294 | 295 | def vit_tiny(patch_size=16, **kwargs): 296 | model = VisionTransformer( 297 | patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, 298 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 299 | return model 300 | 301 | 302 | def vit_small(patch_size=16, **kwargs): 303 | model = VisionTransformer( 304 | patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, 305 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 306 | return model 307 | 308 | 309 | 
def vit_base(patch_size=16, **kwargs): 310 | model = VisionTransformer( 311 | patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, 312 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 313 | return model 314 | 315 | 316 | class DINOHead(nn.Module): 317 | def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, 318 | bottleneck_dim=256): 319 | super().__init__() 320 | nlayers = max(nlayers, 1) 321 | if nlayers == 1: 322 | self.mlp = nn.Linear(in_dim, bottleneck_dim) 323 | else: 324 | layers = [nn.Linear(in_dim, hidden_dim)] 325 | if use_bn: 326 | layers.append(nn.BatchNorm1d(hidden_dim)) 327 | layers.append(nn.GELU()) 328 | for _ in range(nlayers - 2): 329 | layers.append(nn.Linear(hidden_dim, hidden_dim)) 330 | if use_bn: 331 | layers.append(nn.BatchNorm1d(hidden_dim)) 332 | layers.append(nn.GELU()) 333 | layers.append(nn.Linear(hidden_dim, bottleneck_dim)) 334 | self.mlp = nn.Sequential(*layers) 335 | self.apply(self._init_weights) 336 | self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) 337 | self.last_layer.weight_g.data.fill_(1) 338 | if norm_last_layer: 339 | self.last_layer.weight_g.requires_grad = False 340 | 341 | def _init_weights(self, m): 342 | if isinstance(m, nn.Linear): 343 | trunc_normal_(m.weight, std=.02) 344 | if isinstance(m, nn.Linear) and m.bias is not None: 345 | nn.init.constant_(m.bias, 0) 346 | 347 | def forward(self, x): 348 | x = self.mlp(x) 349 | x = nn.functional.normalize(x, dim=-1, p=2) 350 | x = self.last_layer(x) 351 | return x 352 | -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from models.dino import vision_transformer as vits 4 | from models.dino.utils import load_pretrained_weights 5 | import numpy as np 6 | from torch.nn import functional as F 7 | from AIM.modules.burger import HamburgerV1 8 | from collections import OrderedDict 9 | from pkg_resources import packaging 10 | from simple_tokenizer import SimpleTokenizer as _Tokenizer 11 | from typing import Any, Union, List 12 | 13 | _tokenizer = _Tokenizer() 14 | 15 | class QuickGELU(nn.Module): 16 | def forward(self, x: torch.Tensor): 17 | return x * torch.sigmoid(1.702 * x) 18 | 19 | class ResidualAttentionBlock(nn.Module): 20 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 21 | super().__init__() 22 | 23 | self.attn = nn.MultiheadAttention(d_model, n_head) 24 | self.ln_1 = nn.LayerNorm(d_model) 25 | self.mlp = nn.Sequential(OrderedDict([ 26 | ("c_fc", nn.Linear(d_model, d_model * 4)), 27 | ("gelu", QuickGELU()), 28 | ("c_proj", nn.Linear(d_model * 4, d_model)) 29 | ])) 30 | self.ln_2 = nn.LayerNorm(d_model) 31 | self.attn_mask = attn_mask 32 | 33 | def attention(self, x: torch.Tensor): 34 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 35 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 36 | 37 | def forward(self, x: torch.Tensor): 38 | x = x + self.attention(self.ln_1(x)) 39 | x = x + self.mlp(self.ln_2(x)) 40 | return x 41 | 42 | class Transformer(nn.Module): 43 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): 44 | super().__init__() 45 | self.width = width 46 | self.layers = layers 47 | self.resblocks = 
nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) 48 | 49 | def forward(self, x: torch.Tensor): 50 | return self.resblocks(x) 51 | 52 | 53 | class Mlp(nn.Module): 54 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 55 | super(Mlp, self).__init__() 56 | out_features = out_features or in_features 57 | hidden_features = hidden_features or in_features 58 | self.norm = nn.LayerNorm(in_features) 59 | self.fc1 = nn.Linear(in_features, hidden_features) 60 | self.act = act_layer() 61 | self.fc2 = nn.Linear(hidden_features, out_features) 62 | self.drop = nn.Dropout(drop) 63 | 64 | def forward(self, x): 65 | x = self.norm(x) 66 | x = self.fc1(x) 67 | x = self.act(x) 68 | x = self.drop(x) 69 | x = self.fc2(x) 70 | x = self.drop(x) 71 | return x 72 | 73 | class PromptLearner(nn.Module): 74 | def __init__(self, classnames, ln_final, token_embedding): 75 | super().__init__() 76 | n_cls = len(classnames) 77 | n_ctx = 16 78 | ctx_init = "" 79 | dtype = ln_final.weight.dtype 80 | ctx_dim = ln_final.weight.shape[0] 81 | 82 | if ctx_init: 83 | # use given words to initialize context vectors 84 | ctx_init = ctx_init.replace("_", " ") 85 | n_ctx = len(ctx_init.split(" ")) 86 | prompt = tokenize(ctx_init) 87 | with torch.no_grad(): 88 | embedding = token_embedding(prompt).type(dtype) 89 | ctx_vectors = embedding[0, 1 : 1 + n_ctx, :] 90 | prompt_prefix = ctx_init 91 | 92 | else: 93 | print("Initializing a generic context") 94 | ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype) 95 | nn.init.normal_(ctx_vectors, std=0.02) 96 | prompt_prefix = " ".join(["X"] * n_ctx) 97 | 98 | self.ctx = nn.Parameter(ctx_vectors) # to be optimized 99 | 100 | classnames = [name.replace("_", " ") for name in classnames] 101 | name_lens = [len(_tokenizer.encode(name)) for name in classnames] 102 | prompts = [prompt_prefix + " " + name + "." 
for name in classnames] 103 | 104 | tokenized_prompts = torch.cat([tokenize(p) for p in prompts]) 105 | with torch.no_grad(): 106 | embedding = token_embedding(tokenized_prompts).type(dtype) 107 | 108 | # These token vectors will be saved when in save_model(), 109 | # but they should be ignored in load_model() as we want to use 110 | # those computed using the current class names 111 | self.register_buffer("token_prefix", embedding[:, :1, :]) # SOS 112 | self.register_buffer("token_suffix", embedding[:, 1 + n_ctx :, :]) # CLS, EOS 113 | 114 | self.n_cls = n_cls 115 | self.n_ctx = n_ctx 116 | self.tokenized_prompts = tokenized_prompts # torch.Tensor 117 | self.name_lens = name_lens 118 | self.class_token_position = 'end' 119 | 120 | def forward(self): 121 | ctx = self.ctx 122 | if ctx.dim() == 2: 123 | ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1) 124 | 125 | prefix = self.token_prefix 126 | suffix = self.token_suffix 127 | 128 | if self.class_token_position == "end": 129 | prompts = torch.cat( 130 | [ 131 | prefix, # (n_cls, 1, dim) 132 | ctx, # (n_cls, n_ctx, dim) 133 | suffix, # (n_cls, *, dim) 134 | ], 135 | dim=1, 136 | ) 137 | 138 | elif self.class_token_position == "middle": 139 | half_n_ctx = self.n_ctx // 2 140 | prompts = [] 141 | for i in range(self.n_cls): 142 | name_len = self.name_lens[i] 143 | prefix_i = prefix[i : i + 1, :, :] 144 | class_i = suffix[i : i + 1, :name_len, :] 145 | suffix_i = suffix[i : i + 1, name_len:, :] 146 | ctx_i_half1 = ctx[i : i + 1, :half_n_ctx, :] 147 | ctx_i_half2 = ctx[i : i + 1, half_n_ctx:, :] 148 | prompt = torch.cat( 149 | [ 150 | prefix_i, # (1, 1, dim) 151 | ctx_i_half1, # (1, n_ctx//2, dim) 152 | class_i, # (1, name_len, dim) 153 | ctx_i_half2, # (1, n_ctx//2, dim) 154 | suffix_i, # (1, *, dim) 155 | ], 156 | dim=1, 157 | ) 158 | prompts.append(prompt) 159 | prompts = torch.cat(prompts, dim=0) 160 | 161 | elif self.class_token_position == "front": 162 | prompts = [] 163 | for i in range(self.n_cls): 164 | name_len = self.name_lens[i] 165 | prefix_i = prefix[i : i + 1, :, :] 166 | class_i = suffix[i : i + 1, :name_len, :] 167 | suffix_i = suffix[i : i + 1, name_len:, :] 168 | ctx_i = ctx[i : i + 1, :, :] 169 | prompt = torch.cat( 170 | [ 171 | prefix_i, # (1, 1, dim) 172 | class_i, # (1, name_len, dim) 173 | ctx_i, # (1, n_ctx, dim) 174 | suffix_i, # (1, *, dim) 175 | ], 176 | dim=1, 177 | ) 178 | prompts.append(prompt) 179 | prompts = torch.cat(prompts, dim=0) 180 | 181 | else: 182 | raise ValueError 183 | 184 | return prompts 185 | 186 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]: 187 | if isinstance(texts, str): 188 | texts = [texts] 189 | 190 | sot_token = _tokenizer.encoder["<|startoftext|>"] 191 | eot_token = _tokenizer.encoder["<|endoftext|>"] 192 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 193 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"): 194 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 195 | else: 196 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) 197 | 198 | for i, tokens in enumerate(all_tokens): 199 | if len(tokens) > context_length: 200 | if truncate: 201 | tokens = tokens[:context_length] 202 | tokens[-1] = eot_token 203 | else: 204 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 205 | result[i, :len(tokens)] = torch.tensor(tokens) 206 | 207 | return result 
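# Editor's sketch (not part of the original file): a minimal usage example of
# `tokenize`. Each prompt is wrapped with <|startoftext|>/<|endoftext|> tokens and
# zero-padded to the CLIP context length of 77, so the result has shape
# (len(texts), 77). The prompt strings below are hypothetical; `argmax(dim=-1)`
# recovers the <|endoftext|> position, which is how `encode_text` later selects
# the final token embedding.
if __name__ == "__main__":
    _demo_prompts = ["a photo of an object to hold.", "a photo of an object to pour."]
    _demo_tokens = tokenize(_demo_prompts)      # integer tensor of shape (2, 77)
    print(_demo_tokens.shape)                   # torch.Size([2, 77])
    print(_demo_tokens.argmax(dim=-1))          # index of <|endoftext|> in each row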
208 | 209 | 210 | class AttentionPool2d(nn.Module): 211 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 212 | super().__init__() 213 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 214 | self.k_proj = nn.Linear(embed_dim, embed_dim) 215 | self.q_proj = nn.Linear(embed_dim, embed_dim) 216 | self.v_proj = nn.Linear(embed_dim, embed_dim) 217 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 218 | self.num_heads = num_heads 219 | 220 | def forward(self, x): 221 | x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC 222 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 223 | x_1 = x + self.positional_embedding[:, None, :].to(x.dtype) 224 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 225 | x, att = F.multi_head_attention_forward( 226 | query=x, key=x, value=x, 227 | embed_dim_to_check=x.shape[-1], 228 | num_heads=self.num_heads, 229 | q_proj_weight=self.q_proj.weight, 230 | k_proj_weight=self.k_proj.weight, 231 | v_proj_weight=self.v_proj.weight, 232 | in_proj_weight=None, 233 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 234 | bias_k=None, 235 | bias_v=None, 236 | add_zero_attn=False, 237 | dropout_p=0, 238 | out_proj_weight=self.c_proj.weight, 239 | out_proj_bias=self.c_proj.bias, 240 | use_separate_proj_weight=True, 241 | training=self.training, 242 | need_weights=True 243 | ) 244 | return x[0], x[1:], att[:, 1:, 1:] 245 | 246 | class Cross_Attention(nn.Module): 247 | def __init__(self, in_dim, out_dim): 248 | super().__init__() 249 | self.in_dim = in_dim 250 | self.out_dim = out_dim 251 | self.proj_q = nn.Linear(in_dim, out_dim) 252 | self.proj_k = nn.Linear(in_dim, out_dim) 253 | self.proj_v = nn.Linear(in_dim, out_dim) 254 | self.scale = self.out_dim ** (-0.5) 255 | 256 | self.norm = nn.LayerNorm(self.in_dim) 257 | def forward(self, ego, exo): 258 | 259 | B, hw, C = ego.size() 260 | query = self.proj_q(ego) 261 | key = self.proj_k(exo) 262 | value = self.proj_v(exo) 263 | 264 | att = torch.bmm(query, key.transpose(1, 2))*self.scale 265 | att = att.softmax(dim=-1) 266 | out = torch.bmm(att, value) 267 | 268 | out = self.norm(out+ego) 269 | 270 | return out.transpose(1, 2).view(B, C, 14, 14) 271 | 272 | class Model(nn.Module): 273 | def __init__(self, args, embed_dim:int, context_length: int, vocab_size: int, 274 | transformer_width: int, transformer_heads: int, 275 | transformer_layers: int, num_classes=36, pretrained=True, n=3, D=512): 276 | super(Model, self).__init__() 277 | self.num_classes = num_classes 278 | self.criterion = nn.CrossEntropyLoss() 279 | self.pretrained = pretrained 280 | self.n = n 281 | self.D = D 282 | self.context_length = context_length 283 | if args.divide == "Seen": 284 | self.classnames = ['beat', "boxing", "brush_with", "carry", "catch", 285 | "cut", "cut_with", "drag", 'drink_with', "eat", 286 | "hit", "hold", "jump", "kick", "lie_on", "lift", 287 | "look_out", "open", "pack", "peel", "pick_up", 288 | "pour", "push", "ride", "sip", "sit_on", "stick", 289 | "stir", "swing", "take_photo", "talk_on", "text_on", 290 | "throw", "type_on", "wash", "write"] 291 | elif args.divide=="Unseen": 292 | self.classnames = ["carry", "catch", "cut", "cut_with", 'drink_with', 293 | "eat", "hit", "hold", "jump", "kick", "lie_on", "open", "peel", 294 | "pick_up", "pour", "push", "ride", "sip", "sit_on", "stick", 295 | "swing", "take_photo", "throw", "type_on", "wash"] 296 | 
else: # HICO-IIF 297 | self.classnames = ['cut_with', 'drink_with', 'hold', 'open', 'pour', 'sip', 'stick', 'stir', 'swing', 'type_on'] 298 | 299 | # dino-vit 300 | self.vit_feat_dim = 384 301 | self.cluster_num = 3 302 | self.stride = 16 303 | self.patch = 16 304 | self.Hamburger = HamburgerV1(in_c=self.vit_feat_dim, n=self.n, D=self.D) 305 | self.vit_model = vits.__dict__['vit_small'](patch_size=self.patch, num_classes=0) 306 | load_pretrained_weights(self.vit_model, '', None, 'vit_small', self.patch) 307 | 308 | self.aff_proj = Mlp(in_features=int(self.vit_feat_dim*2), hidden_features=int(self.vit_feat_dim*2), out_features=self.vit_feat_dim, 309 | act_layer=nn.GELU, drop=0.) 310 | self.aff_ego_proj = nn.Sequential( 311 | nn.Conv2d(self.vit_feat_dim, self.vit_feat_dim, kernel_size=3, stride=1, padding=1), 312 | nn.BatchNorm2d(self.vit_feat_dim), 313 | nn.ReLU(True), 314 | ) 315 | self.aff_exo_proj = nn.Sequential( 316 | nn.Conv2d(self.vit_feat_dim, self.vit_feat_dim, kernel_size=3, stride=1, padding=1), 317 | nn.BatchNorm2d(self.vit_feat_dim), 318 | nn.ReLU(True), 319 | ) 320 | 321 | # clip 322 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 323 | self.attnpool = AttentionPool2d(14, self.vit_feat_dim, 64, embed_dim) 324 | self.transformer = Transformer( 325 | width=transformer_width, 326 | layers=transformer_layers, 327 | heads=transformer_heads, 328 | attn_mask=self.build_attention_mask() 329 | ) 330 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 331 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) 332 | self.ln_final = nn.LayerNorm(transformer_width) 333 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 334 | 335 | self.prompt_learner = PromptLearner(self.classnames, self.ln_final, self.token_embedding) 336 | self.tokenized_prompts = self.prompt_learner.tokenized_prompts 337 | 338 | # fc 339 | self.fc = nn.Linear(self.vit_feat_dim, self.num_classes) 340 | self.avgpool = nn.AdaptiveAvgPool2d(1) 341 | 342 | def encode_text(self, per, text): 343 | x = per + self.token_embedding(text.cuda()).float() # [batch_size, n_ctx, d_model] 344 | 345 | x = x + self.positional_embedding.float() 346 | x = x.permute(1, 0, 2) # NLD -> LND 347 | x = self.transformer(x) 348 | x = x.permute(1, 0, 2) # LND -> NLD 349 | x = self.ln_final(x).float() 350 | 351 | # x.shape = [batch_size, n_ctx, transformer.width] 352 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 353 | return x 354 | 355 | def forward(self, exocentric, egocentric_image, label, text): 356 | target = label.long().squeeze() 357 | b, n, c, h, w = exocentric.size() 358 | exocentrin_input = exocentric.view(b * n, c, h, w) 359 | 360 | # dino_vit 361 | with torch.no_grad(): 362 | _, ego_key, ego_attn = self.vit_model.get_all_key(egocentric_image) 363 | _, exo_key, exo_attn = self.vit_model.get_all_key(exocentrin_input) 364 | ego_desc = ego_key[len(ego_key)-2].permute(0, 2, 3, 1).flatten(-2, -1).detach() 365 | exo_desc = exo_key[len(ego_key)-2].permute(0, 2, 3, 1).flatten(-2, -1).detach() 366 | for i in range(len(ego_key)-1, len(ego_key)): 367 | ego_desc = torch.cat((ego_desc, ego_key[i].permute(0, 2, 3, 1).flatten(-2, -1).detach()), dim=2) 368 | exo_desc = torch.cat((exo_desc, exo_key[i].permute(0, 2, 3, 1).flatten(-2, -1).detach()), dim=2) 369 | 370 | ego_proj = self.aff_proj(ego_desc[:, 1:]) 371 | exo_proj = self.aff_proj(exo_desc[:, 1:]) 372 | ego_proj = self._reshape_transform(ego_proj, self.patch, 
self.stride) 373 | exo_proj = self._reshape_transform(exo_proj, self.patch, self.stride) 374 | 375 | exo_proj = self.Hamburger(exo_proj) 376 | 377 | # text branch 378 | prompts = self.prompt_learner() 379 | tokenized_prompts = self.tokenized_prompts 380 | text_features = self.encode_text(prompts, tokenized_prompts) 381 | 382 | e_b, e_c, e_h, e_w = ego_proj.shape 383 | pre_ego = ego_proj 384 | image_features, ego_proj, mu_att = self.attnpool(ego_proj) 385 | 386 | image_features = F.normalize(image_features, dim=1, p=2) 387 | 388 | text_features = F.normalize(text_features, dim=1, p=2) 389 | 390 | 391 | # Pixel-Text Fusion 392 | logit_scale = self.logit_scale.exp() 393 | self.logits_per_image = logit_scale * image_features @ text_features.t() 394 | self.logits_per_text = self.logits_per_image.t() 395 | 396 | text_f = torch.ones((e_b, 1024)).cuda() 397 | for i in range(e_b): 398 | text_f[i] = text_features[label[i]] 399 | att_egoproj = F.normalize(ego_proj, dim=1, p=2) 400 | attego = logit_scale *att_egoproj.permute(1, 0, 2)@text_f.unsqueeze(2) 401 | attego = torch.sigmoid(F.normalize(attego, dim=1, p=2)).permute(1, 0, 2).repeat(1, 1, e_c) 402 | ego_proj = attego.permute(1, 2, 0).view(e_b, e_c, e_h, e_w)*pre_ego + pre_ego 403 | 404 | exocentric_branch =self.aff_exo_proj(exo_proj) 405 | egocentric_branch =self.aff_ego_proj(ego_proj) 406 | 407 | # cls 408 | exo_pool = self.avgpool(exocentric_branch) 409 | exo_pool = exo_pool.view(exo_pool.size(0), -1) 410 | self.exo_score = self.fc(exo_pool) 411 | 412 | batch, channel, h, w = exocentric_branch.shape 413 | exocentric_branch = exocentric_branch.view(batch//3, 3, channel, h, w).mean(1) 414 | batch = batch//3 415 | 416 | exo_weight = self.fc.weight[target] 417 | exo_weight = exo_weight.view(batch, channel, 1, 1).expand_as(exocentric_branch) 418 | self.exo_feature = (exo_weight * exocentric_branch) 419 | 420 | self.exo_features = torch.ones(batch, self.num_classes, e_h, e_w).cuda() 421 | label_sum = torch.ones_like(label).cuda() 422 | for m in range(0,self.num_classes): 423 | weight = self.fc.weight[label_sum.long()*m] 424 | weight = weight.view(batch, channel, 1, 1).expand_as(exocentric_branch) 425 | self.exo_features[:,m] = (weight * exocentric_branch).mean(1) 426 | 427 | ego_pool = self.avgpool(egocentric_branch) 428 | ego_pool = ego_pool.view(ego_pool.size(0), -1) 429 | self.ego_score = self.fc(ego_pool) 430 | 431 | ego_weight = self.fc.weight[target] 432 | ego_weight = ego_weight.view(batch, channel, 1, 1).expand_as(egocentric_branch) 433 | self.ego_feature = (ego_weight * egocentric_branch) 434 | 435 | self.ego_features = torch.ones(batch, self.num_classes, e_h, e_w).cuda() 436 | label_sum = torch.ones_like(label).cuda() 437 | for m in range(0,self.num_classes): 438 | gweight = self.fc.weight[label_sum.long()*m] 439 | gweight = gweight.view(batch, channel, 1, 1).expand_as(egocentric_branch) 440 | self.ego_features[:,m] = (gweight * egocentric_branch).mean(1) 441 | 442 | # l_rela 443 | self.exo_att = self.exo_features.view(batch, self.num_classes, -1).transpose(1, 2) 444 | self.exo_att = torch.matmul(self.exo_features.view(batch, self.num_classes, -1), self.exo_att) 445 | self.ego_att = self.ego_features.view(batch, self.num_classes, -1).transpose(1, 2) 446 | self.ego_att = torch.matmul(self.ego_features.view(batch, self.num_classes, -1), self.ego_att) 447 | 448 | return self.exo_score, self.ego_score, self.logits_per_text, self.logits_per_image 449 | 450 | @torch.no_grad() 451 | def get(self, egocentric_image, label, text): 452 | 453 | # 
dino_vit 454 | _, ego_key, ego_attn = self.vit_model.get_all_key(egocentric_image) 455 | ego_desc = ego_key[len(ego_key)-2].permute(0, 2, 3, 1).flatten(-2, -1).detach() 456 | for i in range(len(ego_key)-1, len(ego_key)): 457 | ego_desc = torch.cat((ego_desc, ego_key[i].permute(0, 2, 3, 1).flatten(-2, -1).detach()), dim=2) 458 | ego_proj = self.aff_proj(ego_desc[:, 1:]) 459 | ego_proj = self._reshape_transform(ego_proj, self.patch, self.stride) 460 | 461 | # text branch 462 | prompts = self.prompt_learner() 463 | tokenized_prompts = self.tokenized_prompts 464 | text_features = self.encode_text(prompts, tokenized_prompts) 465 | 466 | e_b, e_c, e_h, e_w = ego_proj.shape 467 | pre_ego = ego_proj 468 | image_features, ego_proj, mu_att = self.attnpool(ego_proj) 469 | 470 | image_features = F.normalize(image_features, dim=1, p=2) 471 | text_features = F.normalize(text_features, dim=1, p=2) 472 | logit_scale = self.logit_scale.exp() 473 | logits_per_image = logit_scale * image_features @ text_features.t() 474 | logits_per_text = logits_per_image.t() 475 | 476 | text_f = torch.ones((e_b, 1024)).cuda() 477 | for i in range(e_b): 478 | text_f[i] = text_features[label[i]] 479 | att_egoproj = F.normalize(ego_proj, dim=1, p=2) 480 | attego = att_egoproj.permute(1, 0, 2)@text_f.unsqueeze(2) 481 | attego = torch.sigmoid(F.normalize(attego, dim=1, p=2)).permute(1, 0, 2).repeat(1, 1, e_c) 482 | ego_proj = attego.permute(1, 2, 0).view(e_b, e_c, e_h, e_w)*pre_ego + pre_ego 483 | 484 | mu_att = mu_att / torch.sum(mu_att, dim=1, keepdim=True) 485 | mu_att = mu_att / torch.sum(mu_att, dim=2, keepdim=True) 486 | for _ in range(2): 487 | mu_att = mu_att / torch.sum(mu_att, dim=1, keepdim=True) 488 | mu_att = mu_att / torch.sum(mu_att, dim=2, keepdim=True) 489 | mu_att = (mu_att + mu_att.permute(0, 2, 1)) / 2 490 | mu_att = torch.matmul(mu_att, mu_att) 491 | 492 | egocentric_branch =self.aff_ego_proj(ego_proj) 493 | 494 | ego_pool = self.avgpool(egocentric_branch) 495 | ego_pool = ego_pool.view(ego_pool.size(0), -1) 496 | self.ego_score = self.fc(ego_pool) 497 | 498 | target = label.long().squeeze() 499 | batch, channel,_,_ = egocentric_branch.shape 500 | 501 | cam_weight = self.fc.weight[target] 502 | cam_weight = cam_weight.view(batch, channel, 1, 1).expand_as(egocentric_branch) 503 | cam = (cam_weight * egocentric_branch).mean(1) 504 | 505 | cam1 = mu_att@(cam.view(batch, -1, 1)) 506 | cam1 = cam1.view(batch, e_h, e_w) 507 | 508 | return cam, cam1 509 | 510 | def get_loss(self, gt_label, separate=False): 511 | # loss L_cls 512 | b, h = self.exo_score.shape 513 | self.exo_score = self.exo_score.view(b//3, -1, h) 514 | loss_cls = (self.criterion(self.ego_score, gt_label)+ (self.criterion(self.exo_score[:, 0], gt_label) 515 | + self.criterion(self.exo_score[:, 1], gt_label) 516 | + self.criterion(self.exo_score[:, 2], gt_label))/3) 517 | # loss L_d 518 | exo_branch, ego_branch = self.exo_feature,self.ego_feature 519 | exo_pool = F.adaptive_avg_pool2d(exo_branch, 1).view(exo_branch.size(0), -1) 520 | exo_pool = F.normalize(exo_pool, 2, dim=1) 521 | ego_pool = F.adaptive_avg_pool2d(ego_branch, 1).view(ego_branch.size(0), -1) 522 | ego_pool = F.normalize(ego_pool, 2, dim=1) 523 | loss_dist =0.5 * (((exo_pool - ego_pool) ** 2).sum(1)).mean(0) 524 | 525 | # loss L_lrela 526 | exo_att, ego_att = self.exo_att,self.ego_att 527 | attb, c, _ = exo_att.size() 528 | exo_att = F.normalize(exo_att, 2, dim=2).view(attb, -1) 529 | ego_att = F.normalize(ego_att, 2, dim=2).view(attb, -1) 530 | loss_att = 0.5*(1 - 
F.cosine_similarity(exo_att, ego_att, dim=1).mean(0)) 531 | 532 | return loss_cls, loss_dist, loss_att 533 | 534 | def _reshape_transform(self, tensor, patch_size, stride): 535 | height = (224 - patch_size) // stride + 1 536 | width = (224 - patch_size) // stride + 1 537 | result = tensor.reshape(tensor.size(0), height, width, tensor.size(-1)) 538 | result = result.transpose(2, 3).transpose(1, 2).contiguous() 539 | return result 540 | 541 | def build_attention_mask(self): 542 | # lazily create causal attention mask, with full attention between the vision tokens 543 | 544 | mask = torch.empty(self.context_length, self.context_length) 545 | mask.fill_(float("-inf")) 546 | mask.triu_(1) # zero out the lower diagonal 547 | return mask 548 | 549 | def MODEL(args, num_classes=36, 550 | pretrained=True, n=3, D=512): 551 | dict = 'RN50.pt' # clip's pre-trained model 552 | state_dict = torch.jit.load(dict) 553 | state_dict = state_dict.state_dict() 554 | 555 | embed_dim = state_dict["text_projection"].shape[1] 556 | context_length = state_dict["positional_embedding"].shape[0] 557 | vocab_size = state_dict["token_embedding.weight"].shape[0] 558 | transformer_width = state_dict["ln_final.weight"].shape[0] 559 | transformer_heads = transformer_width // 64 560 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks"))) 561 | 562 | model = Model(args, embed_dim=embed_dim, context_length=context_length, vocab_size=vocab_size, transformer_width=transformer_width, 563 | transformer_heads=transformer_heads,transformer_layers=transformer_layers, num_classes=num_classes, pretrained=pretrained, n=n, D=D) 564 | 565 | model_dict = model.state_dict() 566 | par = [] 567 | pretrained_dict = {} 568 | for para in model.named_parameters(): 569 | k = para[0] 570 | if k in state_dict: 571 | par.append(para[0]) 572 | for k, v in state_dict.items(): 573 | if k in model_dict: 574 | pretrained_dict[k] = v 575 | 576 | model_dict.update(pretrained_dict) 577 | model.load_state_dict(model_dict) 578 | return model, par 579 | -------------------------------------------------------------------------------- /preprocessing.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import cv2 4 | import torch 5 | gt_root="your_path here" 6 | files=os.listdir(gt_root) 7 | dict_1={} 8 | for file in files: 9 | file_path=os.path.join(gt_root,file) 10 | objs=os.listdir(file_path) 11 | for obj in objs: 12 | obj_path=os.path.join(file_path,obj) 13 | images=os.listdir(obj_path) 14 | for img in images: 15 | img_path=os.path.join(obj_path,img) 16 | mask = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE) 17 | key=file+"_"+obj+"_"+img 18 | dict_1[key]=mask 19 | torch.save(dict_1,"filename.t7") 20 | -------------------------------------------------------------------------------- /save_models/README.md: -------------------------------------------------------------------------------- 1 | # Weakly Supervised Multimodal Affordance Grounding for Egocentric Images 2 | ## Paper 3 | >Weakly Supervised Multimodal Affordance Grounding for Egocentric Images(AAAI 2024) 4 | 5 | Link: https://doi.org/10.1609/aaai.v38i6.28451 6 | 7 | >Appendix 8 | 9 | Link: 10 | 11 | >Video 12 | 13 | Link: 14 | 15 |


16 | 17 | **Abstract:** 18 | 19 | To enhance the interaction between intelligent systems and the environment, locating the affordance regions of objects is crucial. These regions correspond to specific areas that provide distinct functionalities. Humans often acquire the ability to identify these regions through action demonstrations and verbal instructions. In this paper, we present a novel multimodal framework that extracts affordance knowledge from exocentric images, which depict human-object interactions, as well as from accompanying textual descriptions that describe the performed actions. The extracted knowledge is then transferred to egocentric images. To achieve this goal, we propose the HOI-Transfer Module, which utilizes local perception to disentangle individual actions within exocentric images. This module effectively captures localized features and correlations between actions, leading to valuable affordance knowledge. Additionally, we introduce the Pixel-Text Fusion Module, which fuses affordance knowledge by identifying regions in egocentric images that bear resemblances to the textual features defining affordances. We employ a Weakly Supervised Multimodal Affordance (WSMA) learning approach, utilizing image-level labels for training. Through extensive experiments, we demonstrate the superiority of our proposed method in terms of evaluation metrics and visual results when compared to existing affordance grounding models. Furthermore, ablation experiments confirm the effectiveness of our approach. 20 | 21 | ## Requirements 22 | We run in the following environment: 23 | - A GeForce RTX 3090 24 | - Python(3.8) 25 | - Pytorch(1.10.0) 26 | 27 | -------------------------------------------------------------------------------- /simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 
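For example, get_pairs(('l', 'o', 'w')) returns {('l', 'o'), ('o', 'w')}.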
41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 132 | return text -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from torch.autograd import Variable 4 | from PIL import Image 5 | import torchvision 6 | import torchvision.transforms as transforms 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | import cv2 10 | import os 11 | from models.model import MODEL 12 | import argparse 13 | from utils.evaluation import cal_kl, cal_sim, cal_nss 14 | 15 
| parser = argparse.ArgumentParser() 16 | parser.add_argument('--data_root', type=str, default='data_path') 17 | parser.add_argument('--phase', type=str, default='test') 18 | parser.add_argument("--divide", type=str, default="Unseen") #"Seen" or "Unseen" or "HICO-IIF" 19 | parser.add_argument("--model_path", type=str, default="save_models_path") # the model weight path 20 | parser.add_argument("--crop_size", type=int, default=224) 21 | parser.add_argument("--batch_size", type=int, default=1) 22 | parser.add_argument('--threshold', type=float, default='0.2') 23 | # parser.add_argument("--init_weights", type=bool, default=False) 24 | parser.add_argument('--num_workers', type=int, default=1) 25 | args = parser.parse_args() 26 | 27 | 28 | 29 | def normalize_map(atten_map): 30 | min_val = np.min(atten_map) 31 | max_val = np.max(atten_map) 32 | atten_norm = (atten_map - min_val) / (max_val - min_val + 1e-10) 33 | 34 | return atten_norm 35 | 36 | 37 | if args.divide == "Seen": 38 | args.num_classes = 36 39 | aff_list = ['beat', "boxing", "brush_with", "carry", "catch", "cut", "cut_with", "drag", 'drink_with', 40 | "eat", "hit", "hold", "jump", "kick", "lie_on", "lift", "look_out", "open", "pack", "peel", 41 | "pick_up", "pour", "push", "ride", "sip", "sit_on", "stick", "stir", "swing", "take_photo", 42 | "talk_on", "text_on", "throw", "type_on", "wash", "write"] 43 | elif args.divide=="Unseen": 44 | aff_list = ["carry", "catch", "cut", "cut_with", 'drink_with', 45 | "eat", "hit", "hold", "jump", "kick", "lie_on", "open", "peel", 46 | "pick_up", "pour", "push", "ride", "sip", "sit_on", "stick", 47 | "swing", "take_photo", "throw", "type_on", "wash"] 48 | else: # HICO-IIF 49 | aff_list = ['cut_with', 'drink_with', 'hold', 'open', 'pour', 'sip', 'stick', 'stir', 'swing', 'type_on'] 50 | 51 | args.test_root = os.path.join(args.data_root, args.divide, "testset", "egocentric") 52 | args.mask_root = os.path.join(args.data_root, args.divide, "testset", "GT") 53 | 54 | model, par = MODEL(args, num_classes=len(aff_list), pretrained=False) 55 | model.load_state_dict(torch.load(args.model_path)) 56 | model.eval() 57 | model.cuda() 58 | 59 | import datatest 60 | 61 | testset = datatest.TrainData(egocentric_root=args.test_root, crop_size=args.crop_size, divide=args.divide, mask_root=args.mask_root) 62 | MyDataLoader = torch.utils.data.DataLoader(dataset=testset, 63 | batch_size=args.batch_size, 64 | shuffle=False, 65 | num_workers=args.num_workers, 66 | pin_memory=True) 67 | 68 | dict_1 = {} 69 | for step, (image, label, mask_path, name) in enumerate(MyDataLoader): 70 | label = label.cuda(non_blocking=True) 71 | image = image.cuda( 72 | non_blocking=True) 73 | cam, cam1 = model.get(image, label, name) 74 | cam = cam[0].cpu().detach().numpy() 75 | cam1 = cam1[0].cpu().detach().numpy() 76 | cam = normalize_map(cam) 77 | cam1 = normalize_map(cam1) 78 | cam1[cam