├── AIM
│   ├── modules
│   │   ├── .ipynb_checkpoints
│   │   │   ├── __init__-checkpoint.py
│   │   │   ├── bread-checkpoint.py
│   │   │   ├── burger-checkpoint.py
│   │   │   └── ham-checkpoint.py
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-311.pyc
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── bread.cpython-311.pyc
│   │   │   ├── bread.cpython-37.pyc
│   │   │   ├── bread.cpython-38.pyc
│   │   │   ├── burger.cpython-311.pyc
│   │   │   ├── burger.cpython-37.pyc
│   │   │   ├── burger.cpython-38.pyc
│   │   │   ├── ham.cpython-311.pyc
│   │   │   ├── ham.cpython-37.pyc
│   │   │   └── ham.cpython-38.pyc
│   │   ├── bread.py
│   │   ├── burger.py
│   │   └── ham.py
│   ├── settings.py
│   └── sync_bn
│       ├── __init__.py
│       ├── __pycache__
│       │   ├── __init__.cpython-311.pyc
│       │   ├── __init__.cpython-37.pyc
│       │   └── __init__.cpython-38.pyc
│       ├── metric.py
│       ├── nn
│       │   ├── __init__.py
│       │   ├── __pycache__
│       │   │   ├── __init__.cpython-311.pyc
│       │   │   ├── __init__.cpython-37.pyc
│       │   │   └── __init__.cpython-38.pyc
│       │   ├── modules
│       │   │   ├── __init__.py
│       │   │   ├── __pycache__
│       │   │   │   ├── __init__.cpython-311.pyc
│       │   │   │   ├── __init__.cpython-37.pyc
│       │   │   │   ├── __init__.cpython-38.pyc
│       │   │   │   ├── batchnorm.cpython-311.pyc
│       │   │   │   ├── batchnorm.cpython-37.pyc
│       │   │   │   ├── batchnorm.cpython-38.pyc
│       │   │   │   ├── comm.cpython-311.pyc
│       │   │   │   ├── comm.cpython-37.pyc
│       │   │   │   ├── comm.cpython-38.pyc
│       │   │   │   ├── replicate.cpython-311.pyc
│       │   │   │   ├── replicate.cpython-37.pyc
│       │   │   │   └── replicate.cpython-38.pyc
│       │   │   ├── batchnorm.py
│       │   │   ├── comm.py
│       │   │   ├── replicate.py
│       │   │   ├── tests
│       │   │   │   ├── test_numeric_batchnorm.py
│       │   │   │   └── test_sync_batchnorm.py
│       │   │   └── unittest.py
│       │   └── parallel
│       │       ├── __init__.py
│       │       ├── __pycache__
│       │       │   ├── __init__.cpython-311.pyc
│       │       │   ├── __init__.cpython-37.pyc
│       │       │   ├── __init__.cpython-38.pyc
│       │       │   ├── data_parallel.cpython-311.pyc
│       │       │   ├── data_parallel.cpython-37.pyc
│       │       │   └── data_parallel.cpython-38.pyc
│       │       └── data_parallel.py
│       └── utils
│           ├── __init__.py
│           ├── data
│           │   ├── __init__.py
│           │   ├── dataloader.py
│           │   ├── dataset.py
│           │   ├── distributed.py
│           │   └── sampler.py
│           └── th.py
├── README.md
├── bpe_simple_vocab_16e6.txt.gz
├── datatest.py
├── datatrain.py
├── docs
│   ├── README.md
│   └── WSMA-appendix.pdf
├── images
│   ├── README.md
│   ├── pipelline.png
│   └── pipelline1.pdf
├── models
│   ├── dino
│   │   ├── __pycache__
│   │   │   ├── utils.cpython-311.pyc
│   │   │   ├── utils.cpython-37.pyc
│   │   │   ├── vision_transformer.cpython-311.pyc
│   │   │   └── vision_transformer.cpython-37.pyc
│   │   ├── utils.py
│   │   └── vision_transformer.py
│   └── model.py
├── preprocessing.py
├── save_models
│   └── README.md
├── simple_tokenizer.py
├── test.py
├── train.py
└── utils
    ├── accuracy.py
    ├── evaluation.py
    ├── gtransforms.py
    ├── transform.py
    └── util.py
/AIM/modules/.ipynb_checkpoints/__init__-checkpoint.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 |
--------------------------------------------------------------------------------
/AIM/modules/.ipynb_checkpoints/bread-checkpoint.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Hamburger for Pytorch
4 |
5 | @author: Gsunshine
6 | """
7 |
8 | from functools import partial
9 |
10 | import numpy as np
11 | import AIM.settings as settings
12 | import torch
13 | from AIM.sync_bn.nn.modules import SynchronizedBatchNorm2d
14 | from torch import nn
15 | from torch.nn import functional as F
16 | from torch.nn.modules.batchnorm import _BatchNorm
17 |
18 | norm_layer = partial(SynchronizedBatchNorm2d, momentum=settings.BN_MOM)
19 |
20 |
21 | class ConvBNReLU(nn.Module):
22 | @classmethod
23 | def _same_paddings(cls, kernel_size):
24 | if kernel_size == 1:
25 | return 0
26 | elif kernel_size == 3:
27 | return 1
28 |
29 | def __init__(self, in_c, out_c,
30 | kernel_size=1, stride=1, padding='same',
31 | dilation=1, groups=1):
32 | super().__init__()
33 |
34 | if padding == 'same':
35 | padding = self._same_paddings(kernel_size)
36 |
37 | self.conv = nn.Conv2d(in_c, out_c,
38 | kernel_size=kernel_size, stride=stride,
39 | padding=padding, dilation=dilation,
40 | groups=groups,
41 | bias=False)
42 | self.bn = norm_layer(out_c)
43 | self.act = nn.ReLU(inplace=True)
44 |
45 | def forward(self, x):
46 | x = self.conv(x)
47 | x = self.bn(x)
48 | x = self.act(x)
49 |
50 | return x
51 |
52 |
--------------------------------------------------------------------------------
/AIM/modules/.ipynb_checkpoints/burger-checkpoint.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import numpy as np
4 | import torch
5 | from torch import nn
6 | from torch.nn import functional as F
7 | from torch.nn.modules.batchnorm import _BatchNorm
8 |
9 | from .bread import ConvBNReLU, norm_layer
10 | from .ham import get_hams
11 |
12 |
13 | class HamburgerV1(nn.Module):
14 | def __init__(self, in_c, n=3, D=512, args=None):
15 | super().__init__()
16 |
17 | ham_type = 'NMF'
18 | self.n = n
19 |
20 | D = getattr(args, 'MD_D', D)
21 |
22 | self.lower_bread = nn.Sequential(nn.Conv2d(in_c, D, 1),
23 | nn.ReLU(inplace=True))
24 |
25 | HAM = get_hams(ham_type)
26 |
27 | self.ham = HAM(args, D=D)
28 |
29 | self.upper_bread = nn.Sequential(nn.Conv2d(D, in_c, 1, bias=False),
30 | norm_layer(in_c))
31 | self.shortcut = nn.Sequential()
32 |
33 | self._init_weight()
34 |
35 | print('ham', HAM)
36 |
37 | def _init_weight(self):
38 | for m in self.modules():
39 | if isinstance(m, nn.Conv2d):
40 | N = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
41 | m.weight.data.normal_(0, np.sqrt(2. / N))
42 | elif isinstance(m, _BatchNorm):
43 | m.weight.data.fill_(1)
44 | if m.bias is not None:
45 | m.bias.data.zero_()
46 |
47 | def forward(self, x):
48 | _, c, h, w = x.size()
49 |         shortcut = self.shortcut(x)  # keep a copy of the input for the residual addition later
50 | x = self.lower_bread(x)
51 | x_c = x.size(1)
52 | x = x.view(-1, self.n, x_c, h, w)
53 | x = self.ham(x)
54 | x = x.contiguous().view(-1, x_c, h, w)
55 | x = self.upper_bread(x)
56 | x = F.relu(x + shortcut, inplace=True)
57 |
58 | return x
59 |
60 | def online_update(self, bases):
61 | if hasattr(self.ham, 'online_update'):
62 | self.ham.online_update(bases)
63 |
--------------------------------------------------------------------------------
/AIM/modules/.ipynb_checkpoints/ham-checkpoint.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import torch
3 | from torch import nn
4 | from torch.nn import functional as F
5 | from torch.nn.modules.batchnorm import _BatchNorm
6 |
7 |
8 | class _MatrixDecomposition2DBase(nn.Module):
9 | def __init__(self, args, D):
10 | super().__init__()
11 |
12 | self.spatial = getattr(args, 'SPATIAL', True)
13 | self.S = getattr(args, 'MD_S', 1)
14 | self.D = D
15 | # self.D = getattr(args, 'MD_D', 512)
16 | self.R = getattr(args, 'MD_R', 64)
17 |
18 | self.train_steps = getattr(args, 'TRAIN_STEPS', 6)
19 | self.eval_steps = getattr(args, 'EVAL_STEPS', 7)
20 |
21 | self.inv_t = getattr(args, 'INV_T', 1)
22 | self.eta = getattr(args, 'ETA', 0.9)
23 |
24 | print('spatial', self.spatial)
25 | print('S', self.S)
26 | print('D', self.D)
27 | print('R', self.R)
28 | print('train_steps', self.train_steps)
29 | print('eval_steps', self.eval_steps)
30 | print('inv_t', self.inv_t)
31 | print('eta', self.eta)
32 |
33 | self.bases = self._build_bases(1, self.S, self.D, self.R, cuda=True)
34 |
35 | def _build_bases(self, B, S, D, R, cuda=False):
36 | raise NotImplementedError
37 |
38 | def local_step(self, x, bases, coef):
39 | raise NotImplementedError
40 |
41 | @torch.no_grad()
42 | def local_inference(self, x, bases):
43 | # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
44 | coef = torch.bmm(x.transpose(1, 2), bases)
45 | coef = F.softmax(self.inv_t * coef, dim=-1)
46 |
47 | steps = self.train_steps if self.training else self.eval_steps
48 | for _ in range(steps):
49 | bases, coef = self.local_step(x, bases, coef)
50 |
51 | return bases, coef
52 |
53 | def compute_coef(self, x, bases, coef):
54 | raise NotImplementedError
55 |
56 | def forward(self, x, return_bases=False):
57 | B, Num, C, H, W = x.shape
58 |
59 | # (B, C, H, W) -> (B * S, D, N)
60 |
61 | D = C // self.S
62 | N = Num * H * W
63 | x = x.permute(0, 2, 1, 3, 4) # [B,C,Num,H,W]
64 |
65 | x = x.contiguous().view(B * self.S, D, N) #### [B,C,Num*H*W]
66 | bases = self.bases.repeat(B, 1, 1) ### [B*S,D,R]
67 | bases, coef = self.local_inference(x, bases)
68 | # (B * S, N, R)
69 | coef = self.compute_coef(x, bases, coef)
70 |
71 | # (B * S, D, R) @ (B * S, N, R)^T -> (B * S, D, N)
72 | x = torch.bmm(bases, coef.transpose(1, 2))
73 |
74 | # (B * S, D, N) -> (B, C, H, W)
75 | x = x.view(B, C, Num, H, W)
76 | x = x.permute(0, 2, 1, 3, 4)
77 |         # (B * S, D, R) -> (B, S, D, R)
78 | bases = bases.view(B, self.S, D, self.R)
79 | self.online_update(bases)
80 | return x
81 |
82 | @torch.no_grad()
83 | def online_update(self, bases):
84 | # (B, S, D, R) -> (S, D, R)
85 | bases = bases
86 | update = bases.mean(dim=0)
87 | self.bases = self.bases
88 | self.bases += self.eta * (update - self.bases)
89 | self.bases = F.normalize(self.bases, dim=1)
90 |
91 |
92 | class NMF2D(_MatrixDecomposition2DBase):
93 | def __init__(self, args, D):
94 | super().__init__(args, D)
95 | self.inv_t = 1
96 | self.D = D
97 |
98 | def _build_bases(self, B, S, D, R, cuda=True):
99 | if cuda:
100 | bases = torch.rand((B * S, D, R)).cuda()
101 |
102 | else:
103 | bases = torch.rand((B * S, D, R))
104 |
105 | bases = F.normalize(bases, dim=1)
106 |
107 | return bases
108 |
109 | @torch.no_grad()
110 | def local_step(self, x, bases, coef):
111 | # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
112 |
113 | numerator = torch.bmm(x.transpose(1, 2), bases)
114 | # (B * S, N, R) @ [(B * S, D, R)^T @ (B * S, D, R)] -> (B * S, N, R)
115 | denominator = coef.bmm(bases.transpose(1, 2).bmm(bases))
116 | # Multiplicative Update
117 | coef = coef * numerator / (denominator + 1e-6)
118 |
119 | # (B * S, D, N) @ (B * S, N, R) -> (B * S, D, R)
120 | numerator = torch.bmm(x, coef)
121 | # (B * S, D, R) @ [(B * S, N, R)^T @ (B * S, N, R)] -> (B * S, D, R)
122 | denominator = bases.bmm(coef.transpose(1, 2).bmm(coef))
123 | # Multiplicative Update
124 | bases = bases * numerator / (denominator + 1e-6)
125 |
126 | return bases, coef
127 |
128 | def compute_coef(self, x, bases, coef):
129 | # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
130 | numerator = torch.bmm(x.transpose(1, 2), bases)
131 | # (B * S, N, R) @ (B * S, D, R)^T @ (B * S, D, R) -> (B * S, N, R)
132 | denominator = coef.bmm(bases.transpose(1, 2).bmm(bases))
133 | # multiplication update
134 | coef = coef * numerator / (denominator + 1e-6)
135 |
136 | return coef
137 |
138 |
139 | def get_hams(key):
140 | hams = {'NMF': NMF2D}
141 |
142 | assert key in hams
143 | return hams[key]
144 |
--------------------------------------------------------------------------------
/AIM/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 |
--------------------------------------------------------------------------------
/AIM/modules/__pycache__/__init__.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/modules/__pycache__/__init__.cpython-311.pyc
--------------------------------------------------------------------------------
/AIM/modules/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/modules/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/AIM/modules/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/modules/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/AIM/modules/__pycache__/bread.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/modules/__pycache__/bread.cpython-311.pyc
--------------------------------------------------------------------------------
/AIM/modules/__pycache__/bread.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/modules/__pycache__/bread.cpython-37.pyc
--------------------------------------------------------------------------------
/AIM/modules/__pycache__/bread.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/modules/__pycache__/bread.cpython-38.pyc
--------------------------------------------------------------------------------
/AIM/modules/__pycache__/burger.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/modules/__pycache__/burger.cpython-311.pyc
--------------------------------------------------------------------------------
/AIM/modules/__pycache__/burger.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/modules/__pycache__/burger.cpython-37.pyc
--------------------------------------------------------------------------------
/AIM/modules/__pycache__/burger.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/modules/__pycache__/burger.cpython-38.pyc
--------------------------------------------------------------------------------
/AIM/modules/__pycache__/ham.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/modules/__pycache__/ham.cpython-311.pyc
--------------------------------------------------------------------------------
/AIM/modules/__pycache__/ham.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/modules/__pycache__/ham.cpython-37.pyc
--------------------------------------------------------------------------------
/AIM/modules/__pycache__/ham.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/modules/__pycache__/ham.cpython-38.pyc
--------------------------------------------------------------------------------
/AIM/modules/bread.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Hamburger for Pytorch
4 |
5 | @author: Gsunshine
6 | """
7 |
8 | from functools import partial
9 |
10 | import numpy as np
11 | import AIM.settings as settings
12 | import torch
13 | from AIM.sync_bn.nn.modules import SynchronizedBatchNorm2d
14 | from torch import nn
15 | from torch.nn import functional as F
16 | from torch.nn.modules.batchnorm import _BatchNorm
17 |
18 | norm_layer = partial(SynchronizedBatchNorm2d, momentum=settings.BN_MOM)
19 |
20 |
21 | class ConvBNReLU(nn.Module):
22 | @classmethod
23 | def _same_paddings(cls, kernel_size):
24 | if kernel_size == 1:
25 | return 0
26 | elif kernel_size == 3:
27 | return 1
28 |
29 | def __init__(self, in_c, out_c,
30 | kernel_size=1, stride=1, padding='same',
31 | dilation=1, groups=1):
32 | super().__init__()
33 |
34 | if padding == 'same':
35 | padding = self._same_paddings(kernel_size)
36 |
37 | self.conv = nn.Conv2d(in_c, out_c,
38 | kernel_size=kernel_size, stride=stride,
39 | padding=padding, dilation=dilation,
40 | groups=groups,
41 | bias=False)
42 | self.bn = norm_layer(out_c)
43 | self.act = nn.ReLU(inplace=True)
44 |
45 | def forward(self, x):
46 | x = self.conv(x)
47 | x = self.bn(x)
48 | x = self.act(x)
49 |
50 | return x
51 |
52 |
--------------------------------------------------------------------------------
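
A minimal usage sketch for ConvBNReLU (not part of the repository; it assumes the repository root is on PYTHONPATH): with kernel_size=3 the 'same' padding resolves to 1 and preserves the spatial size, and on a single device the SynchronizedBatchNorm2d built by norm_layer falls back to plain batch normalization.

import torch
from AIM.modules.bread import ConvBNReLU

block = ConvBNReLU(16, 32, kernel_size=3)   # padding='same' -> padding=1
out = block(torch.randn(2, 16, 24, 24))
print(out.shape)                            # torch.Size([2, 32, 24, 24])
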
/AIM/modules/burger.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import numpy as np
4 | import torch
5 | from torch import nn
6 | from torch.nn import functional as F
7 | from torch.nn.modules.batchnorm import _BatchNorm
8 |
9 | from .bread import ConvBNReLU, norm_layer
10 | from .ham import get_hams
11 |
12 |
13 | class HamburgerV1(nn.Module):
14 | def __init__(self, in_c, n=3, D=512, args=None):
15 | super().__init__()
16 |
17 | ham_type = 'NMF'
18 | self.n = n
19 |
20 | D = getattr(args, 'MD_D', D)
21 |
22 | self.lower_bread = nn.Sequential(nn.Conv2d(in_c, D, 1),
23 | nn.ReLU(inplace=True))
24 |
25 | HAM = get_hams(ham_type)
26 |
27 | self.ham = HAM(args, D=D)
28 |
29 | self.upper_bread = nn.Sequential(nn.Conv2d(D, in_c, 1, bias=False),
30 | norm_layer(in_c))
31 | self.shortcut = nn.Sequential()
32 |
33 | self._init_weight()
34 |
35 | print('ham', HAM)
36 |
37 | def _init_weight(self):
38 | for m in self.modules():
39 | if isinstance(m, nn.Conv2d):
40 | N = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
41 | m.weight.data.normal_(0, np.sqrt(2. / N))
42 | elif isinstance(m, _BatchNorm):
43 | m.weight.data.fill_(1)
44 | if m.bias is not None:
45 | m.bias.data.zero_()
46 |
47 | def forward(self, x):
48 | _, c, h, w = x.size()
49 |         shortcut = self.shortcut(x)  # keep a copy of the input for the residual addition later
50 | x = self.lower_bread(x)
51 | x_c = x.size(1)
52 | x = x.view(-1, self.n, x_c, h, w)
53 | x = self.ham(x)
54 | x = x.contiguous().view(-1, x_c, h, w)
55 | x = self.upper_bread(x)
56 | x = F.relu(x + shortcut, inplace=True)
57 |
58 | return x
59 |
60 | def online_update(self, bases):
61 | if hasattr(self.ham, 'online_update'):
62 | self.ham.online_update(bases)
63 |
--------------------------------------------------------------------------------
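
A hedged usage sketch for HamburgerV1 (not part of the repository): it assumes the repository root is on PYTHONPATH and that a CUDA device is available, since NMF2D allocates its bases on the GPU. The settings module doubles as the args object because it exposes the MD_*, TRAIN_STEPS, etc. constants that the code reads via getattr. The module takes groups of n frames flattened into the batch axis and restores the grouping internally for the matrix-decomposition step.

import torch
import AIM.settings as settings
from AIM.modules.burger import HamburgerV1

ham = HamburgerV1(in_c=256, n=3, args=settings).cuda()   # MD_D, MD_R, ... read via getattr
x = torch.randn(2 * 3, 256, 14, 14).cuda()               # (B * n, C, H, W)
y = ham(x)                                               # internally viewed as (B, n, C, H, W)
print(y.shape)                                           # torch.Size([6, 256, 14, 14])
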
/AIM/modules/ham.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import torch
3 | from torch import nn
4 | from torch.nn import functional as F
5 | from torch.nn.modules.batchnorm import _BatchNorm
6 |
7 |
8 | class _MatrixDecomposition2DBase(nn.Module):
9 | def __init__(self, args, D):
10 | super().__init__()
11 |
12 | self.spatial = getattr(args, 'SPATIAL', True)
13 | self.S = getattr(args, 'MD_S', 1)
14 | self.D = D
15 | # self.D = getattr(args, 'MD_D', 512)
16 | self.R = getattr(args, 'MD_R', 64)
17 |
18 | self.train_steps = getattr(args, 'TRAIN_STEPS', 6)
19 | self.eval_steps = getattr(args, 'EVAL_STEPS', 7)
20 |
21 | self.inv_t = getattr(args, 'INV_T', 1)
22 | self.eta = getattr(args, 'ETA', 0.9)
23 |
24 | print('spatial', self.spatial)
25 | print('S', self.S)
26 | print('D', self.D)
27 | print('R', self.R)
28 | print('train_steps', self.train_steps)
29 | print('eval_steps', self.eval_steps)
30 | print('inv_t', self.inv_t)
31 | print('eta', self.eta)
32 |
33 | self.bases = self._build_bases(1, self.S, self.D, self.R, cuda=True)
34 |
35 | def _build_bases(self, B, S, D, R, cuda=False):
36 | raise NotImplementedError
37 |
38 | def local_step(self, x, bases, coef):
39 | raise NotImplementedError
40 |
41 | @torch.no_grad()
42 | def local_inference(self, x, bases):
43 | # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
44 | coef = torch.bmm(x.transpose(1, 2), bases)
45 | coef = F.softmax(self.inv_t * coef, dim=-1)
46 |
47 | steps = self.train_steps if self.training else self.eval_steps
48 | for _ in range(steps):
49 | bases, coef = self.local_step(x, bases, coef)
50 |
51 | return bases, coef
52 |
53 | def compute_coef(self, x, bases, coef):
54 | raise NotImplementedError
55 |
56 | def forward(self, x, return_bases=False):
57 | B, Num, C, H, W = x.shape
58 |
59 | # (B, C, H, W) -> (B * S, D, N)
60 |
61 | D = C // self.S
62 | N = Num * H * W
63 | x = x.permute(0, 2, 1, 3, 4) # [B,C,Num,H,W]
64 |
65 | x = x.contiguous().view(B * self.S, D, N) #### [B,C,Num*H*W]
66 | bases = self.bases.repeat(B, 1, 1) ### [B*S,D,R]
67 | bases, coef = self.local_inference(x, bases)
68 | # (B * S, N, R)
69 | coef = self.compute_coef(x, bases, coef)
70 |
71 | # (B * S, D, R) @ (B * S, N, R)^T -> (B * S, D, N)
72 | x = torch.bmm(bases, coef.transpose(1, 2))
73 |
74 | # (B * S, D, N) -> (B, C, H, W)
75 | x = x.view(B, C, Num, H, W)
76 | x = x.permute(0, 2, 1, 3, 4)
77 |         # (B * S, D, R) -> (B, S, D, R)
78 | bases = bases.view(B, self.S, D, self.R)
79 | self.online_update(bases)
80 | return x
81 |
82 | @torch.no_grad()
83 | def online_update(self, bases):
84 | # (B, S, D, R) -> (S, D, R)
85 | bases = bases
86 | update = bases.mean(dim=0)
87 | self.bases = self.bases
88 | self.bases += self.eta * (update - self.bases)
89 | self.bases = F.normalize(self.bases, dim=1)
90 |
91 |
92 | class NMF2D(_MatrixDecomposition2DBase):
93 | def __init__(self, args, D):
94 | super().__init__(args, D)
95 | self.inv_t = 1
96 | self.D = D
97 |
98 | def _build_bases(self, B, S, D, R, cuda=True):
99 | if cuda:
100 | bases = torch.rand((B * S, D, R)).cuda()
101 |
102 | else:
103 | bases = torch.rand((B * S, D, R))
104 |
105 | bases = F.normalize(bases, dim=1)
106 |
107 | return bases
108 |
109 | @torch.no_grad()
110 | def local_step(self, x, bases, coef):
111 | # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
112 |
113 | numerator = torch.bmm(x.transpose(1, 2), bases)
114 | # (B * S, N, R) @ [(B * S, D, R)^T @ (B * S, D, R)] -> (B * S, N, R)
115 | denominator = coef.bmm(bases.transpose(1, 2).bmm(bases))
116 | # Multiplicative Update
117 | coef = coef * numerator / (denominator + 1e-6)
118 |
119 | # (B * S, D, N) @ (B * S, N, R) -> (B * S, D, R)
120 | numerator = torch.bmm(x, coef)
121 | # (B * S, D, R) @ [(B * S, N, R)^T @ (B * S, N, R)] -> (B * S, D, R)
122 | denominator = bases.bmm(coef.transpose(1, 2).bmm(coef))
123 | # Multiplicative Update
124 | bases = bases * numerator / (denominator + 1e-6)
125 |
126 | return bases, coef
127 |
128 | def compute_coef(self, x, bases, coef):
129 | # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
130 | numerator = torch.bmm(x.transpose(1, 2), bases)
131 | # (B * S, N, R) @ (B * S, D, R)^T @ (B * S, D, R) -> (B * S, N, R)
132 | denominator = coef.bmm(bases.transpose(1, 2).bmm(bases))
133 | # multiplication update
134 | coef = coef * numerator / (denominator + 1e-6)
135 |
136 | return coef
137 |
138 |
139 | def get_hams(key):
140 | hams = {'NMF': NMF2D}
141 |
142 | assert key in hams
143 | return hams[key]
144 |
--------------------------------------------------------------------------------
/AIM/settings.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import numpy as np
4 | import torch
5 |
6 |
7 |
8 | # Data settings
9 | MEAN = torch.Tensor(np.array([0.485, 0.456, 0.406]))
10 | STD = torch.Tensor(np.array([0.229, 0.224, 0.225]))
11 | SCALES = (0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0)
12 | TEST_SCALES = (0.5, 0.75, 1.0, 1.25, 1.5, 1.75)
13 | CROP_SIZE = 513
14 | IGNORE_LABEL = 255
15 | NUM_WORKERS = 64
16 |
17 |
18 | # Training settings
19 | RUN_FOR_TEST = True
20 |
21 | TRAIN_BATCH_SIZE = 16
22 | VAL_BATCH_SIZE = 1
23 |
24 | ITER_MAX = 60000
25 | ITER_SAVE = 1000
26 | ITER_VAL = 1000
27 |
28 | TEST_ITER_MAX = 20000
29 | TEST_ITER_SAVE = 1000
30 | TEST_ITER_VAL = 1000
31 |
32 | LR_DECAY = 10
33 | LR = 9e-3
34 | LR_MOM = 0.9
35 | POLY_POWER = 0.9
36 | WEIGHT_DECAY = 1e-4
37 |
38 |
39 | # Network
40 | N_CLASSES = 21
41 | N_LAYERS = 101
42 | STRIDE = 8
43 | BN_MOM = 3e-4
44 |
45 |
46 | # Hamburger
47 | CHANNELS = 512
48 | VERSION = 'V2'
49 |
50 | HAM_TYPE = 'NMF'
51 | DUAL = False
52 | SPATIAL = True
53 | RAND_INIT = True
54 |
55 | MD_S = 1
56 | MD_D = 512
57 | MD_R = 64
58 |
59 | CHEESE_FACTOR = 1
60 | TRAIN_STEPS = 6
61 | EVAL_STEPS = 6
62 |
63 | INV_T = 1
64 | BETA = 0.1
65 | ETA = 0.9
66 |
67 |
68 | # Logging
69 | logger = logging.getLogger('train')
70 | logger.setLevel(logging.INFO)
71 | ch = logging.StreamHandler()
72 | ch.setLevel(logging.INFO)
73 | formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
74 | ch.setFormatter(formatter)
75 | logger.addHandler(ch)
76 |
--------------------------------------------------------------------------------
/AIM/sync_bn/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/sync_bn/__init__.py
--------------------------------------------------------------------------------
/AIM/sync_bn/__pycache__/__init__.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/sync_bn/__pycache__/__init__.cpython-311.pyc
--------------------------------------------------------------------------------
/AIM/sync_bn/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/sync_bn/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/AIM/sync_bn/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/sync_bn/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/AIM/sync_bn/metric.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import settings
4 |
5 |
6 | def fast_hist(label_true, label_pred):
7 | n_class = settings.N_CLASSES
8 | mask = (label_true >= 0) & (label_true < n_class)
9 | hist = torch.bincount(
10 | n_class * label_true[mask].int() + label_pred[mask].int(),
11 | minlength=n_class ** 2,
12 | ).reshape(n_class, n_class)
13 | return hist
14 |
15 |
16 | label_names = [
17 | 'background',
18 | 'aeroplane',
19 | 'bicycle',
20 | 'bird',
21 | 'boat',
22 | 'bottle',
23 | 'bus',
24 | 'car',
25 | 'cat',
26 | 'chair',
27 | 'cow',
28 | 'diningtable',
29 | 'dog',
30 | 'horse',
31 | 'motorbike',
32 | 'person',
33 | 'pottedplant',
34 | 'sheep',
35 | 'sofa',
36 | 'train',
37 | 'tvmonitor',
38 | ]
39 |
40 |
41 | def cal_scores(hist):
42 | n_class = settings.N_CLASSES
43 | #acc = np.diag(hist).sum() / hist.sum()
44 | #acc_cls = np.diag(hist) / hist.sum(axis=1)
45 | #acc_cls = np.nanmean(acc_cls)
46 | iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
47 | mean_iu = np.nanmean(iu)
48 | freq = hist.sum(axis=1) / hist.sum()
49 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
50 | cls_iu = dict(zip(label_names, iu))
51 |
52 | return {
53 | #"pAcc": acc,
54 | #"mAcc": acc_cls,
55 | "fIoU": fwavacc,
56 | "mIoU": mean_iu,
57 | }#, cls_iu
58 |
59 |
--------------------------------------------------------------------------------
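
A self-contained sketch (toy n_class=3 rather than the repo's 21) of the packing trick in fast_hist: each (true, pred) pair is encoded as a single index true * n_class + pred, so one bincount call yields the flattened confusion matrix from which cal_scores derives the per-class IoU and mIoU.

import torch

n_class = 3
gt   = torch.tensor([0, 0, 1, 2, 2, 2])
pred = torch.tensor([0, 1, 1, 2, 2, 0])

hist = torch.bincount(n_class * gt + pred, minlength=n_class ** 2).reshape(n_class, n_class)
iu = hist.diag().float() / (hist.sum(1) + hist.sum(0) - hist.diag()).float()
print(hist)        # rows = ground truth, columns = prediction
print(iu.mean())   # mIoU, 0.5 for this toy example
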
/AIM/sync_bn/nn/__init__.py:
--------------------------------------------------------------------------------
1 | from .modules import *
2 | from .parallel import UserScatteredDataParallel, user_scattered_collate, async_copy_to
3 |
--------------------------------------------------------------------------------
/AIM/sync_bn/nn/__pycache__/__init__.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/sync_bn/nn/__pycache__/__init__.cpython-311.pyc
--------------------------------------------------------------------------------
/AIM/sync_bn/nn/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/sync_bn/nn/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/AIM/sync_bn/nn/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/sync_bn/nn/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/AIM/sync_bn/nn/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : __init__.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
12 | from .replicate import DataParallelWithCallback, patch_replication_callback
13 |
--------------------------------------------------------------------------------
/AIM/sync_bn/nn/modules/__pycache__/__init__.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/sync_bn/nn/modules/__pycache__/__init__.cpython-311.pyc
--------------------------------------------------------------------------------
/AIM/sync_bn/nn/modules/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/sync_bn/nn/modules/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/AIM/sync_bn/nn/modules/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/sync_bn/nn/modules/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/AIM/sync_bn/nn/modules/__pycache__/batchnorm.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/sync_bn/nn/modules/__pycache__/batchnorm.cpython-311.pyc
--------------------------------------------------------------------------------
/AIM/sync_bn/nn/modules/__pycache__/batchnorm.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/sync_bn/nn/modules/__pycache__/batchnorm.cpython-37.pyc
--------------------------------------------------------------------------------
/AIM/sync_bn/nn/modules/__pycache__/batchnorm.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/sync_bn/nn/modules/__pycache__/batchnorm.cpython-38.pyc
--------------------------------------------------------------------------------
/AIM/sync_bn/nn/modules/__pycache__/comm.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/sync_bn/nn/modules/__pycache__/comm.cpython-311.pyc
--------------------------------------------------------------------------------
/AIM/sync_bn/nn/modules/__pycache__/comm.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/sync_bn/nn/modules/__pycache__/comm.cpython-37.pyc
--------------------------------------------------------------------------------
/AIM/sync_bn/nn/modules/__pycache__/comm.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/sync_bn/nn/modules/__pycache__/comm.cpython-38.pyc
--------------------------------------------------------------------------------
/AIM/sync_bn/nn/modules/__pycache__/replicate.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/sync_bn/nn/modules/__pycache__/replicate.cpython-311.pyc
--------------------------------------------------------------------------------
/AIM/sync_bn/nn/modules/__pycache__/replicate.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/sync_bn/nn/modules/__pycache__/replicate.cpython-37.pyc
--------------------------------------------------------------------------------
/AIM/sync_bn/nn/modules/__pycache__/replicate.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/sync_bn/nn/modules/__pycache__/replicate.cpython-38.pyc
--------------------------------------------------------------------------------
/AIM/sync_bn/nn/modules/batchnorm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : batchnorm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import collections
12 |
13 | import torch
14 | import torch.nn.functional as F
15 |
16 | from torch.nn.modules.batchnorm import _BatchNorm
17 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
18 |
19 | from .comm import SyncMaster
20 |
21 | __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']
22 |
23 |
24 | def _sum_ft(tensor):
25 | """sum over the first and last dimention"""
26 | return tensor.sum(dim=0).sum(dim=-1)
27 |
28 |
29 | def _unsqueeze_ft(tensor):
30 | """add new dementions at the front and the tail"""
31 | return tensor.unsqueeze(0).unsqueeze(-1)
32 |
33 |
34 | _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
35 | _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
36 |
37 |
38 | class _SynchronizedBatchNorm(_BatchNorm):
39 | def __init__(self, num_features, eps=1e-5, momentum=3e-4, affine=True):
40 | super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)
41 |
42 | self._sync_master = SyncMaster(self._data_parallel_master)
43 |
44 | self._is_parallel = False
45 | self._parallel_id = None
46 | self._slave_pipe = None
47 |
48 |         # customized batch norm statistics
49 | self.momentum = momentum
50 |
51 | def forward(self, input, weight=None, bias=None):
52 | # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
53 | if not (self._is_parallel and self.training):
54 | return F.batch_norm(
55 | input, self.running_mean, self.running_var, self.weight, self.bias,
56 | self.training, self.momentum, self.eps)
57 |
58 | # Resize the input to (B, C, -1).
59 | input_shape = input.size()
60 | input = input.view(input.size(0), self.num_features, -1)
61 |
62 | # Compute the sum and square-sum.
63 | sum_size = input.size(0) * input.size(2)
64 | input_sum = _sum_ft(input)
65 | input_ssum = _sum_ft(input ** 2)
66 |
67 | # Reduce-and-broadcast the statistics.
68 | if self._parallel_id == 0:
69 | mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
70 | else:
71 | mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
72 |
73 | # Compute the output.
74 | if self.affine:
75 | if weight is None or bias is None:
76 | weight = self.weight
77 | bias = self.bias
78 |
79 | # MJY:: Fuse the multiplication for speed.
80 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * weight) + _unsqueeze_ft(bias)
81 | else:
82 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
83 |
84 | # Reshape it.
85 | return output.view(input_shape)
86 |
87 | def __data_parallel_replicate__(self, ctx, copy_id):
88 | self._is_parallel = True
89 | self._parallel_id = copy_id
90 |
91 | # parallel_id == 0 means master device.
92 | if self._parallel_id == 0:
93 | ctx.sync_master = self._sync_master
94 | else:
95 | self._slave_pipe = ctx.sync_master.register_slave(copy_id)
96 |
97 | def _data_parallel_master(self, intermediates):
98 | """Reduce the sum and square-sum, compute the statistics, and broadcast it."""
99 | intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
100 |
101 | to_reduce = [i[1][:2] for i in intermediates]
102 | to_reduce = [j for i in to_reduce for j in i] # flatten
103 | target_gpus = [i[1].sum.get_device() for i in intermediates]
104 |
105 | sum_size = sum([i[1].sum_size for i in intermediates])
106 | sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
107 |
108 | mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
109 |
110 | broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
111 |
112 | outputs = []
113 | for i, rec in enumerate(intermediates):
114 | outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))
115 |
116 | return outputs
117 |
118 | def _add_weighted(self, dest, delta, alpha=1, beta=1, bias=0):
119 | """return *dest* by `dest := dest*alpha + delta*beta + bias`"""
120 | return dest * alpha + delta * beta + bias
121 |
122 | def _compute_mean_std(self, sum_, ssum, size):
123 | """Compute the mean and standard-deviation with sum and square-sum. This method
124 | also maintains the moving average on the master device."""
125 | assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
126 | mean = sum_ / size
127 | sumvar = ssum - sum_ * mean
128 | unbias_var = sumvar / (size - 1)
129 | bias_var = sumvar / size
130 |
131 | self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
132 | self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
133 |
134 | return mean, (bias_var + self.eps) ** -0.5
135 |
136 |
137 | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
138 | r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
139 | mini-batch.
140 |
141 | .. math::
142 |
143 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
144 |
145 | This module differs from the built-in PyTorch BatchNorm1d as the mean and
146 | standard-deviation are reduced across all devices during training.
147 |
148 | For example, when one uses `nn.DataParallel` to wrap the network during
149 |     training, PyTorch's implementation normalizes the tensor on each device using
150 |     the statistics only on that device, which accelerates the computation and
151 | is also easy to implement, but the statistics might be inaccurate.
152 | Instead, in this synchronized version, the statistics will be computed
153 | over all training samples distributed on multiple devices.
154 |
155 |     Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
156 | as the built-in PyTorch implementation.
157 |
158 | The mean and standard-deviation are calculated per-dimension over
159 | the mini-batches and gamma and beta are learnable parameter vectors
160 | of size C (where C is the input size).
161 |
162 | During training, this layer keeps a running estimate of its computed mean
163 | and variance. The running sum is kept with a default momentum of 0.1.
164 |
165 | During evaluation, this running mean/variance is used for normalization.
166 |
167 | Because the BatchNorm is done over the `C` dimension, computing statistics
168 | on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm
169 |
170 | Args:
171 | num_features: num_features from an expected input of size
172 | `batch_size x num_features [x width]`
173 | eps: a value added to the denominator for numerical stability.
174 | Default: 1e-5
175 | momentum: the value used for the running_mean and running_var
176 | computation. Default: 0.1
177 | affine: a boolean value that when set to ``True``, gives the layer learnable
178 | affine parameters. Default: ``True``
179 |
180 | Shape:
181 | - Input: :math:`(N, C)` or :math:`(N, C, L)`
182 | - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
183 |
184 | Examples:
185 | >>> # With Learnable Parameters
186 | >>> m = SynchronizedBatchNorm1d(100)
187 | >>> # Without Learnable Parameters
188 | >>> m = SynchronizedBatchNorm1d(100, affine=False)
189 | >>> input = torch.autograd.Variable(torch.randn(20, 100))
190 | >>> output = m(input)
191 | """
192 |
193 | def _check_input_dim(self, input):
194 | if input.dim() != 2 and input.dim() != 3:
195 | raise ValueError('expected 2D or 3D input (got {}D input)'
196 | .format(input.dim()))
197 | super(SynchronizedBatchNorm1d, self)._check_input_dim(input)
198 |
199 |
200 | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
201 | r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
202 | of 3d inputs
203 |
204 | .. math::
205 |
206 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
207 |
208 | This module differs from the built-in PyTorch BatchNorm2d as the mean and
209 | standard-deviation are reduced across all devices during training.
210 |
211 | For example, when one uses `nn.DataParallel` to wrap the network during
212 |     training, PyTorch's implementation normalizes the tensor on each device using
213 |     the statistics only on that device, which accelerates the computation and
214 | is also easy to implement, but the statistics might be inaccurate.
215 | Instead, in this synchronized version, the statistics will be computed
216 | over all training samples distributed on multiple devices.
217 |
218 |     Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
219 | as the built-in PyTorch implementation.
220 |
221 | The mean and standard-deviation are calculated per-dimension over
222 | the mini-batches and gamma and beta are learnable parameter vectors
223 | of size C (where C is the input size).
224 |
225 | During training, this layer keeps a running estimate of its computed mean
226 | and variance. The running sum is kept with a default momentum of 0.1.
227 |
228 | During evaluation, this running mean/variance is used for normalization.
229 |
230 | Because the BatchNorm is done over the `C` dimension, computing statistics
231 | on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm
232 |
233 | Args:
234 | num_features: num_features from an expected input of
235 | size batch_size x num_features x height x width
236 | eps: a value added to the denominator for numerical stability.
237 | Default: 1e-5
238 | momentum: the value used for the running_mean and running_var
239 | computation. Default: 0.1
240 | affine: a boolean value that when set to ``True``, gives the layer learnable
241 | affine parameters. Default: ``True``
242 |
243 | Shape:
244 | - Input: :math:`(N, C, H, W)`
245 | - Output: :math:`(N, C, H, W)` (same shape as input)
246 |
247 | Examples:
248 | >>> # With Learnable Parameters
249 | >>> m = SynchronizedBatchNorm2d(100)
250 | >>> # Without Learnable Parameters
251 | >>> m = SynchronizedBatchNorm2d(100, affine=False)
252 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
253 | >>> output = m(input)
254 | """
255 |
256 | def _check_input_dim(self, input):
257 | if input.dim() != 4:
258 | raise ValueError('expected 4D input (got {}D input)'
259 | .format(input.dim()))
260 | super(SynchronizedBatchNorm2d, self)._check_input_dim(input)
261 |
262 |
263 | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
264 | r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
265 | of 4d inputs
266 |
267 | .. math::
268 |
269 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
270 |
271 | This module differs from the built-in PyTorch BatchNorm3d as the mean and
272 | standard-deviation are reduced across all devices during training.
273 |
274 | For example, when one uses `nn.DataParallel` to wrap the network during
275 |     training, PyTorch's implementation normalizes the tensor on each device using
276 |     the statistics only on that device, which accelerates the computation and
277 | is also easy to implement, but the statistics might be inaccurate.
278 | Instead, in this synchronized version, the statistics will be computed
279 | over all training samples distributed on multiple devices.
280 |
281 |     Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
282 | as the built-in PyTorch implementation.
283 |
284 | The mean and standard-deviation are calculated per-dimension over
285 | the mini-batches and gamma and beta are learnable parameter vectors
286 | of size C (where C is the input size).
287 |
288 | During training, this layer keeps a running estimate of its computed mean
289 | and variance. The running sum is kept with a default momentum of 0.1.
290 |
291 | During evaluation, this running mean/variance is used for normalization.
292 |
293 | Because the BatchNorm is done over the `C` dimension, computing statistics
294 | on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
295 | or Spatio-temporal BatchNorm
296 |
297 | Args:
298 | num_features: num_features from an expected input of
299 | size batch_size x num_features x depth x height x width
300 | eps: a value added to the denominator for numerical stability.
301 | Default: 1e-5
302 | momentum: the value used for the running_mean and running_var
303 | computation. Default: 0.1
304 | affine: a boolean value that when set to ``True``, gives the layer learnable
305 | affine parameters. Default: ``True``
306 |
307 | Shape:
308 | - Input: :math:`(N, C, D, H, W)`
309 | - Output: :math:`(N, C, D, H, W)` (same shape as input)
310 |
311 | Examples:
312 | >>> # With Learnable Parameters
313 | >>> m = SynchronizedBatchNorm3d(100)
314 | >>> # Without Learnable Parameters
315 | >>> m = SynchronizedBatchNorm3d(100, affine=False)
316 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
317 | >>> output = m(input)
318 | """
319 |
320 | def _check_input_dim(self, input):
321 | if input.dim() != 5:
322 | raise ValueError('expected 5D input (got {}D input)'
323 | .format(input.dim()))
324 | super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
325 |
--------------------------------------------------------------------------------
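
A small CPU-only sketch (not part of the repository) of the statistic reduction that _SynchronizedBatchNorm relies on: per-device (sum, square-sum, count) messages are enough to recover the global per-channel mean and biased variance, which is exactly what _data_parallel_master and _compute_mean_std compute across GPUs.

import torch

chunks = [torch.randn(8, 4, 16, 16) for _ in range(2)]       # one chunk per "GPU"
stats = [(c.sum(dim=(0, 2, 3)),                               # per-device sum
          (c ** 2).sum(dim=(0, 2, 3)),                        # per-device square-sum
          c.numel() // c.size(1))                             # per-device element count per channel
         for c in chunks]

sum_ = sum(s for s, _, _ in stats)
ssum = sum(ss for _, ss, _ in stats)
size = sum(n for _, _, n in stats)

mean = sum_ / size
bias_var = (ssum - sum_ * mean) / size                        # biased variance, as in _compute_mean_std

full = torch.cat(chunks, dim=0)
print(torch.allclose(mean, full.mean(dim=(0, 2, 3)), atol=1e-5))
print(torch.allclose(bias_var, full.var(dim=(0, 2, 3), unbiased=False), atol=1e-5))
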
/AIM/sync_bn/nn/modules/comm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : comm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import queue
12 | import collections
13 | import threading
14 |
15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
16 |
17 |
18 | class FutureResult(object):
19 | """A thread-safe future implementation. Used only as one-to-one pipe."""
20 |
21 | def __init__(self):
22 | self._result = None
23 | self._lock = threading.Lock()
24 | self._cond = threading.Condition(self._lock)
25 |
26 | def put(self, result):
27 | with self._lock:
28 |             assert self._result is None, 'Previous result hasn\'t been fetched.'
29 | self._result = result
30 | self._cond.notify()
31 |
32 | def get(self):
33 | with self._lock:
34 | if self._result is None:
35 | self._cond.wait()
36 |
37 | res = self._result
38 | self._result = None
39 | return res
40 |
41 |
42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
44 |
45 |
46 | class SlavePipe(_SlavePipeBase):
47 | """Pipe for master-slave communication."""
48 |
49 | def run_slave(self, msg):
50 | self.queue.put((self.identifier, msg))
51 | ret = self.result.get()
52 | self.queue.put(True)
53 | return ret
54 |
55 |
56 | class SyncMaster(object):
57 | """An abstract `SyncMaster` object.
58 |
59 |     - During replication, as data parallel will trigger a callback on each module, all slave devices should
60 |     call `register_slave(id)` and obtain a `SlavePipe` to communicate with the master.
61 |     - During the forward pass, the master device invokes `run_master`; all messages from the slave devices will be
62 |     collected and passed to the registered callback.
63 |     - After receiving the messages, the master device should gather the information and determine the message to be
64 |     passed back to each slave device.
65 | """
66 |
67 | def __init__(self, master_callback):
68 | """
69 |
70 | Args:
71 | master_callback: a callback to be invoked after having collected messages from slave devices.
72 | """
73 | self._master_callback = master_callback
74 | self._queue = queue.Queue()
75 | self._registry = collections.OrderedDict()
76 | self._activated = False
77 |
78 | def register_slave(self, identifier):
79 | """
80 |         Register a slave device.
81 |
82 | Args:
83 | identifier: an identifier, usually is the device id.
84 |
85 | Returns: a `SlavePipe` object which can be used to communicate with the master device.
86 |
87 | """
88 | if self._activated:
89 | assert self._queue.empty(), 'Queue is not clean before next initialization.'
90 | self._activated = False
91 | self._registry.clear()
92 | future = FutureResult()
93 | self._registry[identifier] = _MasterRegistry(future)
94 | return SlavePipe(identifier, self._queue, future)
95 |
96 | def run_master(self, master_msg):
97 | """
98 | Main entry for the master device in each forward pass.
99 |         The messages are first collected from each device (including the master device), and then
100 |         the callback will be invoked to compute the message to be sent back to each device
101 | (including the master device).
102 |
103 | Args:
104 |             master_msg: the message that the master wants to send to itself. This will be placed as the first
105 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
106 |
107 | Returns: the message to be sent back to the master device.
108 |
109 | """
110 | self._activated = True
111 |
112 | intermediates = [(0, master_msg)]
113 | for i in range(self.nr_slaves):
114 | intermediates.append(self._queue.get())
115 |
116 | results = self._master_callback(intermediates)
117 |         assert results[0][0] == 0, 'The first result should belong to the master.'
118 |
119 | for i, res in results:
120 | if i == 0:
121 | continue
122 | self._registry[i].result.put(res)
123 |
124 | for i in range(self.nr_slaves):
125 | assert self._queue.get() is True
126 |
127 | return results[0][1]
128 |
129 | @property
130 | def nr_slaves(self):
131 | return len(self._registry)
132 |
--------------------------------------------------------------------------------
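
A hedged, CPU-only sketch of the SyncMaster / SlavePipe handshake described above (the import path AIM.sync_bn.nn.modules.comm is an assumption; it presumes the repository root is on PYTHONPATH): two "slave" threads and the master each contribute a number, the master callback reduces them, and every copy receives the same reduced value, mirroring how the batch-norm statistics are synchronized across device copies.

import threading
from AIM.sync_bn.nn.modules.comm import SyncMaster

def master_callback(intermediates):
    # intermediates = [(copy_id, msg), ...] with the master's message first;
    # reply with the global sum to every copy
    total = sum(msg for _, msg in intermediates)
    return [(copy_id, total) for copy_id, _ in intermediates]

master = SyncMaster(master_callback)
pipes = [master.register_slave(i) for i in (1, 2)]   # one pipe per non-master copy

results = {}
def run(pipe, value):
    results[pipe.identifier] = pipe.run_slave(value)

threads = [threading.Thread(target=run, args=(p, v)) for p, v in zip(pipes, (20, 30))]
for t in threads:
    t.start()
results[0] = master.run_master(10)   # master contributes 10 and blocks until all messages arrive
for t in threads:
    t.join()
print(results)                       # every copy sees the same reduced value: 60
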
/AIM/sync_bn/nn/modules/replicate.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : replicate.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import functools
12 |
13 | from torch.nn.parallel.data_parallel import DataParallel
14 |
15 | __all__ = [
16 | 'CallbackContext',
17 | 'execute_replication_callbacks',
18 | 'DataParallelWithCallback',
19 | 'patch_replication_callback'
20 | ]
21 |
22 |
23 | class CallbackContext(object):
24 | pass
25 |
26 |
27 | def execute_replication_callbacks(modules):
28 | """
29 |     Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.
30 |
31 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
32 |
33 |     Note that, as all modules are isomorphic, we assign each sub-module a context
34 | (shared among multiple copies of this module on different devices).
35 | Through this context, different copies can share some information.
36 |
37 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
38 | of any slave copies.
39 | """
40 | master_copy = modules[0]
41 | nr_modules = len(list(master_copy.modules()))
42 | ctxs = [CallbackContext() for _ in range(nr_modules)]
43 |
44 | for i, module in enumerate(modules):
45 | for j, m in enumerate(module.modules()):
46 | if hasattr(m, '__data_parallel_replicate__'):
47 | m.__data_parallel_replicate__(ctxs[j], i)
48 |
49 |
50 | class DataParallelWithCallback(DataParallel):
51 | """
52 | Data Parallel with a replication callback.
53 |
54 |     A replication callback `__data_parallel_replicate__` of each module will be invoked after the module is created by
55 |     the original `replicate` function.
56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
57 |
58 | Examples:
59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
61 | # sync_bn.__data_parallel_replicate__ will be invoked.
62 | """
63 |
64 | def replicate(self, module, device_ids):
65 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
66 | execute_replication_callbacks(modules)
67 | return modules
68 |
69 |
70 | def patch_replication_callback(data_parallel):
71 | """
72 | Monkey-patch an existing `DataParallel` object. Add the replication callback.
73 | Useful when you have customized `DataParallel` implementation.
74 |
75 | Examples:
76 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
77 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
78 | > patch_replication_callback(sync_bn)
79 | # this is equivalent to
80 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
81 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
82 | """
83 |
84 | assert isinstance(data_parallel, DataParallel)
85 |
86 | old_replicate = data_parallel.replicate
87 |
88 | @functools.wraps(old_replicate)
89 | def new_replicate(module, device_ids):
90 | modules = old_replicate(module, device_ids)
91 | execute_replication_callbacks(modules)
92 | return modules
93 |
94 | data_parallel.replicate = new_replicate
95 |
--------------------------------------------------------------------------------
/AIM/sync_bn/nn/modules/tests/test_numeric_batchnorm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : test_numeric_batchnorm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 |
9 | import unittest
10 |
11 | import torch
12 | import torch.nn as nn
13 | from torch.autograd import Variable
14 |
15 | from sync_batchnorm.unittest import TorchTestCase
16 |
17 |
18 | def handy_var(a, unbias=True):
19 | n = a.size(0)
20 | asum = a.sum(dim=0)
21 | as_sum = (a ** 2).sum(dim=0) # a square sum
22 | sumvar = as_sum - asum * asum / n
23 | if unbias:
24 | return sumvar / (n - 1)
25 | else:
26 | return sumvar / n
27 |
28 |
29 | class NumericTestCase(TorchTestCase):
30 | def testNumericBatchNorm(self):
31 | a = torch.rand(16, 10)
32 | bn = nn.BatchNorm2d(10, momentum=1, eps=1e-5, affine=False)
33 | bn.train()
34 |
35 | a_var1 = Variable(a, requires_grad=True)
36 | b_var1 = bn(a_var1)
37 | loss1 = b_var1.sum()
38 | loss1.backward()
39 |
40 | a_var2 = Variable(a, requires_grad=True)
41 | a_mean2 = a_var2.mean(dim=0, keepdim=True)
42 | a_std2 = torch.sqrt(handy_var(a_var2, unbias=False).clamp(min=1e-5))
43 | # a_std2 = torch.sqrt(a_var2.var(dim=0, keepdim=True, unbiased=False) + 1e-5)
44 | b_var2 = (a_var2 - a_mean2) / a_std2
45 | loss2 = b_var2.sum()
46 | loss2.backward()
47 |
48 | self.assertTensorClose(bn.running_mean, a.mean(dim=0))
49 | self.assertTensorClose(bn.running_var, handy_var(a))
50 | self.assertTensorClose(a_var1.data, a_var2.data)
51 | self.assertTensorClose(b_var1.data, b_var2.data)
52 | self.assertTensorClose(a_var1.grad, a_var2.grad)
53 |
54 |
55 | if __name__ == '__main__':
56 | unittest.main()
57 |
--------------------------------------------------------------------------------
/AIM/sync_bn/nn/modules/tests/test_sync_batchnorm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : test_sync_batchnorm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 |
9 | import unittest
10 |
11 | import torch
12 | import torch.nn as nn
13 | from torch.autograd import Variable
14 |
15 | from AIM.sync_bn.nn.modules import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, DataParallelWithCallback
16 | from AIM.sync_bn.nn.modules.unittest import TorchTestCase
17 |
18 |
19 | def handy_var(a, unbias=True):
20 | n = a.size(0)
21 | asum = a.sum(dim=0)
22 | as_sum = (a ** 2).sum(dim=0) # a square sum
23 | sumvar = as_sum - asum * asum / n
24 | if unbias:
25 | return sumvar / (n - 1)
26 | else:
27 | return sumvar / n
28 |
29 |
30 | def _find_bn(module):
31 | for m in module.modules():
32 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, SynchronizedBatchNorm1d, SynchronizedBatchNorm2d)):
33 | return m
34 |
35 |
36 | class SyncTestCase(TorchTestCase):
37 | def _syncParameters(self, bn1, bn2):
38 | bn1.reset_parameters()
39 | bn2.reset_parameters()
40 | if bn1.affine and bn2.affine:
41 | bn2.weight.data.copy_(bn1.weight.data)
42 | bn2.bias.data.copy_(bn1.bias.data)
43 |
44 | def _checkBatchNormResult(self, bn1, bn2, input, is_train, cuda=False):
45 | """Check the forward and backward for the customized batch normalization."""
46 | bn1.train(mode=is_train)
47 | bn2.train(mode=is_train)
48 |
49 | if cuda:
50 | input = input.cuda()
51 |
52 | self._syncParameters(_find_bn(bn1), _find_bn(bn2))
53 |
54 | input1 = Variable(input, requires_grad=True)
55 | output1 = bn1(input1)
56 | output1.sum().backward()
57 | input2 = Variable(input, requires_grad=True)
58 | output2 = bn2(input2)
59 | output2.sum().backward()
60 |
61 | self.assertTensorClose(input1.data, input2.data)
62 | self.assertTensorClose(output1.data, output2.data)
63 | self.assertTensorClose(input1.grad, input2.grad)
64 | self.assertTensorClose(_find_bn(bn1).running_mean, _find_bn(bn2).running_mean)
65 | self.assertTensorClose(_find_bn(bn1).running_var, _find_bn(bn2).running_var)
66 |
67 | def testSyncBatchNormNormalTrain(self):
68 | bn = nn.BatchNorm1d(10)
69 | sync_bn = SynchronizedBatchNorm1d(10)
70 |
71 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True)
72 |
73 | def testSyncBatchNormNormalEval(self):
74 | bn = nn.BatchNorm1d(10)
75 | sync_bn = SynchronizedBatchNorm1d(10)
76 |
77 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False)
78 |
79 | def testSyncBatchNormSyncTrain(self):
80 | bn = nn.BatchNorm1d(10, eps=1e-5, affine=False)
81 | sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
82 | sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
83 |
84 | bn.cuda()
85 | sync_bn.cuda()
86 |
87 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True, cuda=True)
88 |
89 | def testSyncBatchNormSyncEval(self):
90 | bn = nn.BatchNorm1d(10, eps=1e-5, affine=False)
91 | sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
92 | sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
93 |
94 | bn.cuda()
95 | sync_bn.cuda()
96 |
97 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False, cuda=True)
98 |
99 | def testSyncBatchNorm2DSyncTrain(self):
100 | bn = nn.BatchNorm2d(10)
101 | sync_bn = SynchronizedBatchNorm2d(10)
102 | sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
103 |
104 | bn.cuda()
105 | sync_bn.cuda()
106 |
107 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10, 16, 16), True, cuda=True)
108 |
109 |
110 | if __name__ == '__main__':
111 | unittest.main()
112 |
--------------------------------------------------------------------------------
/AIM/sync_bn/nn/modules/unittest.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : unittest.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import unittest
12 |
13 | import numpy as np
14 | from torch.autograd import Variable
15 |
16 |
17 | def as_numpy(v):
18 | if isinstance(v, Variable):
19 | v = v.data
20 | return v.cpu().numpy()
21 |
22 |
23 | class TorchTestCase(unittest.TestCase):
24 | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
25 | npa, npb = as_numpy(a), as_numpy(b)
26 | self.assertTrue(
27 | np.allclose(npa, npb, atol=atol),
28 | 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max())
29 | )
30 |
--------------------------------------------------------------------------------
/AIM/sync_bn/nn/parallel/__init__.py:
--------------------------------------------------------------------------------
1 | from .data_parallel import UserScatteredDataParallel, user_scattered_collate, async_copy_to
2 |
--------------------------------------------------------------------------------
/AIM/sync_bn/nn/parallel/__pycache__/__init__.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/sync_bn/nn/parallel/__pycache__/__init__.cpython-311.pyc
--------------------------------------------------------------------------------
/AIM/sync_bn/nn/parallel/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/sync_bn/nn/parallel/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/AIM/sync_bn/nn/parallel/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/sync_bn/nn/parallel/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/AIM/sync_bn/nn/parallel/__pycache__/data_parallel.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/sync_bn/nn/parallel/__pycache__/data_parallel.cpython-311.pyc
--------------------------------------------------------------------------------
/AIM/sync_bn/nn/parallel/__pycache__/data_parallel.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/sync_bn/nn/parallel/__pycache__/data_parallel.cpython-37.pyc
--------------------------------------------------------------------------------
/AIM/sync_bn/nn/parallel/__pycache__/data_parallel.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/AIM/sync_bn/nn/parallel/__pycache__/data_parallel.cpython-38.pyc
--------------------------------------------------------------------------------
/AIM/sync_bn/nn/parallel/data_parallel.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf8 -*-
2 |
3 | import torch.cuda as cuda
4 | import torch.nn as nn
5 | import torch
6 | from torch.autograd import Variable
7 | import collections
8 | from torch.nn.parallel._functions import Gather
9 |
10 | __all__ = ['UserScatteredDataParallel', 'user_scattered_collate', 'async_copy_to']
11 |
12 | def async_copy_to(obj, dev, main_stream=None):
13 | if torch.is_tensor(obj):
14 | obj = Variable(obj)
15 | if isinstance(obj, Variable):
16 | v = obj.cuda(dev)
17 | if main_stream is not None:
18 | v.data.record_stream(main_stream)
19 | return v
20 | elif isinstance(obj, collections.Mapping):
21 | return {k: async_copy_to(o, dev, main_stream) for k, o in obj.items()}
22 | elif isinstance(obj, collections.Sequence):
23 | return [async_copy_to(o, dev, main_stream) for o in obj]
24 | else:
25 | return obj
26 |
27 |
28 | def dict_gather(outputs, target_device, dim=0):
29 | """
30 | Gathers variables from different GPUs on a specified device
31 | (-1 means the CPU), with dictionary support.
32 | """
33 | def gather_map(outputs):
34 | out = outputs[0]
35 | if isinstance(out, Variable):
36 | # MJY(20180330) HACK:: force nr_dims > 0
37 | if out.dim() == 0:
38 | outputs = [o.unsqueeze(0) for o in outputs]
39 | return Gather.apply(target_device, dim, *outputs)
40 | elif out is None:
41 | return None
42 | elif isinstance(out, collections.Mapping):
43 | return {k: gather_map([o[k] for o in outputs]) for k in out}
44 | elif isinstance(out, collections.Sequence):
45 | return type(out)(map(gather_map, zip(*outputs)))
46 | return gather_map(outputs)
47 |
48 |
49 | class DictGatherDataParallel(nn.DataParallel):
50 | def gather(self, outputs, output_device):
51 | return dict_gather(outputs, output_device, dim=self.dim)
52 |
53 |
54 | class UserScatteredDataParallel(DictGatherDataParallel):
55 | def scatter(self, inputs, kwargs, device_ids):
56 | assert len(inputs) == 1
57 | inputs = inputs[0]
58 | inputs = _async_copy_stream(inputs, device_ids)
59 | inputs = [[i] for i in inputs]
60 | assert len(kwargs) == 0
61 | kwargs = [{} for _ in range(len(inputs))]
62 |
63 | return inputs, kwargs
64 |
65 |
66 | def user_scattered_collate(batch):
67 | return batch
68 |
69 |
70 | def _async_copy(inputs, device_ids):
71 | nr_devs = len(device_ids)
72 | assert type(inputs) in (tuple, list)
73 | assert len(inputs) == nr_devs
74 |
75 | outputs = []
76 | for i, dev in zip(inputs, device_ids):
77 | with cuda.device(dev):
78 | outputs.append(async_copy_to(i, dev))
79 |
80 | return tuple(outputs)
81 |
82 |
83 | def _async_copy_stream(inputs, device_ids):
84 | nr_devs = len(device_ids)
85 | assert type(inputs) in (tuple, list)
86 | assert len(inputs) == nr_devs
87 |
88 | outputs = []
89 | streams = [_get_stream(d) for d in device_ids]
90 | for i, dev, stream in zip(inputs, device_ids, streams):
91 | with cuda.device(dev):
92 | main_stream = cuda.current_stream()
93 | with cuda.stream(stream):
94 | outputs.append(async_copy_to(i, dev, main_stream=main_stream))
95 | main_stream.wait_stream(stream)
96 |
97 | return outputs
98 |
99 |
100 | """Adapted from: torch/nn/parallel/_functions.py"""
101 | # background streams used for copying
102 | _streams = None
103 |
104 |
105 | def _get_stream(device):
106 | """Gets a background stream for copying between CPU and GPU"""
107 | global _streams
108 | if device == -1:
109 | return None
110 | if _streams is None:
111 | _streams = [None] * cuda.device_count()
112 | if _streams[device] is None: _streams[device] = cuda.Stream(device)
113 | return _streams[device]
114 |
--------------------------------------------------------------------------------
/AIM/sync_bn/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .th import *
2 |
--------------------------------------------------------------------------------
/AIM/sync_bn/utils/data/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from .dataset import Dataset, TensorDataset, ConcatDataset
3 | from .dataloader import DataLoader
4 |
--------------------------------------------------------------------------------
/AIM/sync_bn/utils/data/dataloader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.multiprocessing as multiprocessing
3 | from torch._C import _set_worker_signal_handlers, _update_worker_pids, \
4 | _remove_worker_pids, _error_if_any_worker_fails
5 | from .sampler import SequentialSampler, RandomSampler, BatchSampler
6 | import signal
7 | import functools
8 | import collections
9 | import re
10 | import sys
11 | import threading
12 | import traceback
13 | from torch._six import string_classes, int_classes
14 | import numpy as np
15 |
16 | if sys.version_info[0] == 2:
17 | import Queue as queue
18 | else:
19 | import queue
20 |
21 |
22 | class ExceptionWrapper(object):
23 | r"Wraps an exception plus traceback to communicate across threads"
24 |
25 | def __init__(self, exc_info):
26 | self.exc_type = exc_info[0]
27 | self.exc_msg = "".join(traceback.format_exception(*exc_info))
28 |
29 |
30 | _use_shared_memory = False
31 | """Whether to use shared memory in default_collate"""
32 |
33 |
34 | def _worker_loop(dataset, index_queue, data_queue, collate_fn, seed, init_fn, worker_id):
35 | global _use_shared_memory
36 | _use_shared_memory = True
37 |
38 |     # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
39 | # module's handlers are executed after Python returns from C low-level
40 | # handlers, likely when the same fatal signal happened again already.
41 | # https://docs.python.org/3/library/signal.html Sec. 18.8.1.1
42 | _set_worker_signal_handlers()
43 |
44 | torch.set_num_threads(1)
45 | torch.manual_seed(seed)
46 | np.random.seed(seed)
47 |
48 | if init_fn is not None:
49 | init_fn(worker_id)
50 |
51 | while True:
52 | r = index_queue.get()
53 | if r is None:
54 | break
55 | idx, batch_indices = r
56 | try:
57 | samples = collate_fn([dataset[i] for i in batch_indices])
58 | except Exception:
59 | data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
60 | else:
61 | data_queue.put((idx, samples))
62 |
63 |
64 | def _worker_manager_loop(in_queue, out_queue, done_event, pin_memory, device_id):
65 | if pin_memory:
66 | torch.cuda.set_device(device_id)
67 |
68 | while True:
69 | try:
70 | r = in_queue.get()
71 | except Exception:
72 | if done_event.is_set():
73 | return
74 | raise
75 | if r is None:
76 | break
77 | if isinstance(r[1], ExceptionWrapper):
78 | out_queue.put(r)
79 | continue
80 | idx, batch = r
81 | try:
82 | if pin_memory:
83 | batch = pin_memory_batch(batch)
84 | except Exception:
85 | out_queue.put((idx, ExceptionWrapper(sys.exc_info())))
86 | else:
87 | out_queue.put((idx, batch))
88 |
89 | numpy_type_map = {
90 | 'float64': torch.DoubleTensor,
91 | 'float32': torch.FloatTensor,
92 | 'float16': torch.HalfTensor,
93 | 'int64': torch.LongTensor,
94 | 'int32': torch.IntTensor,
95 | 'int16': torch.ShortTensor,
96 | 'int8': torch.CharTensor,
97 | 'uint8': torch.ByteTensor,
98 | }
99 |
100 |
101 | def default_collate(batch):
102 | "Puts each data field into a tensor with outer dimension batch size"
103 |
104 | error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
105 | elem_type = type(batch[0])
106 | if torch.is_tensor(batch[0]):
107 | out = None
108 | if _use_shared_memory:
109 | # If we're in a background process, concatenate directly into a
110 | # shared memory tensor to avoid an extra copy
111 | numel = sum([x.numel() for x in batch])
112 | storage = batch[0].storage()._new_shared(numel)
113 | out = batch[0].new(storage)
114 | return torch.stack(batch, 0, out=out)
115 | elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
116 | and elem_type.__name__ != 'string_':
117 | elem = batch[0]
118 | if elem_type.__name__ == 'ndarray':
119 | # array of string classes and object
120 | if re.search('[SaUO]', elem.dtype.str) is not None:
121 | raise TypeError(error_msg.format(elem.dtype))
122 |
123 | return torch.stack([torch.from_numpy(b) for b in batch], 0)
124 | if elem.shape == (): # scalars
125 | py_type = float if elem.dtype.name.startswith('float') else int
126 | return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
127 | elif isinstance(batch[0], int_classes):
128 | return torch.LongTensor(batch)
129 | elif isinstance(batch[0], float):
130 | return torch.DoubleTensor(batch)
131 | elif isinstance(batch[0], string_classes):
132 | return batch
133 | elif isinstance(batch[0], collections.Mapping):
134 | return {key: default_collate([d[key] for d in batch]) for key in batch[0]}
135 | elif isinstance(batch[0], collections.Sequence):
136 | transposed = zip(*batch)
137 | return [default_collate(samples) for samples in transposed]
138 |
139 | raise TypeError((error_msg.format(type(batch[0]))))
140 |
141 |
142 | def pin_memory_batch(batch):
143 | if torch.is_tensor(batch):
144 | return batch.pin_memory()
145 | elif isinstance(batch, string_classes):
146 | return batch
147 | elif isinstance(batch, collections.Mapping):
148 | return {k: pin_memory_batch(sample) for k, sample in batch.items()}
149 | elif isinstance(batch, collections.Sequence):
150 | return [pin_memory_batch(sample) for sample in batch]
151 | else:
152 | return batch
153 |
154 |
155 | _SIGCHLD_handler_set = False
156 | """Whether SIGCHLD handler is set for DataLoader worker failures. Only one
157 | handler needs to be set for all DataLoaders in a process."""
158 |
159 |
160 | def _set_SIGCHLD_handler():
161 | # Windows doesn't support SIGCHLD handler
162 | if sys.platform == 'win32':
163 | return
164 | # can't set signal in child threads
165 | if not isinstance(threading.current_thread(), threading._MainThread):
166 | return
167 | global _SIGCHLD_handler_set
168 | if _SIGCHLD_handler_set:
169 | return
170 | previous_handler = signal.getsignal(signal.SIGCHLD)
171 | if not callable(previous_handler):
172 | previous_handler = None
173 |
174 | def handler(signum, frame):
175 | # This following call uses `waitid` with WNOHANG from C side. Therefore,
176 | # Python can still get and update the process status successfully.
177 | _error_if_any_worker_fails()
178 | if previous_handler is not None:
179 | previous_handler(signum, frame)
180 |
181 | signal.signal(signal.SIGCHLD, handler)
182 | _SIGCHLD_handler_set = True
183 |
184 |
185 | class DataLoaderIter(object):
186 | "Iterates once over the DataLoader's dataset, as specified by the sampler"
187 |
188 | def __init__(self, loader):
189 | self.dataset = loader.dataset
190 | self.collate_fn = loader.collate_fn
191 | self.batch_sampler = loader.batch_sampler
192 | self.num_workers = loader.num_workers
193 | self.pin_memory = loader.pin_memory and torch.cuda.is_available()
194 | self.timeout = loader.timeout
195 | self.done_event = threading.Event()
196 |
197 | self.sample_iter = iter(self.batch_sampler)
198 |
199 | if self.num_workers > 0:
200 | self.worker_init_fn = loader.worker_init_fn
201 | self.index_queue = multiprocessing.SimpleQueue()
202 | self.worker_result_queue = multiprocessing.SimpleQueue()
203 | self.batches_outstanding = 0
204 | self.worker_pids_set = False
205 | self.shutdown = False
206 | self.send_idx = 0
207 | self.rcvd_idx = 0
208 | self.reorder_dict = {}
209 |
210 | base_seed = torch.LongTensor(1).random_(0, 2**31-1)[0]
211 | self.workers = [
212 | multiprocessing.Process(
213 | target=_worker_loop,
214 | args=(self.dataset, self.index_queue, self.worker_result_queue, self.collate_fn,
215 | base_seed + i, self.worker_init_fn, i))
216 | for i in range(self.num_workers)]
217 |
218 | if self.pin_memory or self.timeout > 0:
219 | self.data_queue = queue.Queue()
220 | if self.pin_memory:
221 | maybe_device_id = torch.cuda.current_device()
222 | else:
223 | # do not initialize cuda context if not necessary
224 | maybe_device_id = None
225 | self.worker_manager_thread = threading.Thread(
226 | target=_worker_manager_loop,
227 | args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory,
228 | maybe_device_id))
229 | self.worker_manager_thread.daemon = True
230 | self.worker_manager_thread.start()
231 | else:
232 | self.data_queue = self.worker_result_queue
233 |
234 | for w in self.workers:
235 | w.daemon = True # ensure that the worker exits on process exit
236 | w.start()
237 |
238 | _update_worker_pids(id(self), tuple(w.pid for w in self.workers))
239 | _set_SIGCHLD_handler()
240 | self.worker_pids_set = True
241 |
242 | # prime the prefetch loop
243 | for _ in range(2 * self.num_workers):
244 | self._put_indices()
245 |
246 | def __len__(self):
247 | return len(self.batch_sampler)
248 |
249 | def _get_batch(self):
250 | if self.timeout > 0:
251 | try:
252 | return self.data_queue.get(timeout=self.timeout)
253 | except queue.Empty:
254 | raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))
255 | else:
256 | return self.data_queue.get()
257 |
258 | def __next__(self):
259 | if self.num_workers == 0: # same-process loading
260 | indices = next(self.sample_iter) # may raise StopIteration
261 | batch = self.collate_fn([self.dataset[i] for i in indices])
262 | if self.pin_memory:
263 | batch = pin_memory_batch(batch)
264 | return batch
265 |
266 | # check if the next sample has already been generated
267 | if self.rcvd_idx in self.reorder_dict:
268 | batch = self.reorder_dict.pop(self.rcvd_idx)
269 | return self._process_next_batch(batch)
270 |
271 | if self.batches_outstanding == 0:
272 | self._shutdown_workers()
273 | raise StopIteration
274 |
275 | while True:
276 | assert (not self.shutdown and self.batches_outstanding > 0)
277 | idx, batch = self._get_batch()
278 | self.batches_outstanding -= 1
279 | if idx != self.rcvd_idx:
280 | # store out-of-order samples
281 | self.reorder_dict[idx] = batch
282 | continue
283 | return self._process_next_batch(batch)
284 |
285 | next = __next__ # Python 2 compatibility
286 |
287 | def __iter__(self):
288 | return self
289 |
290 | def _put_indices(self):
291 | assert self.batches_outstanding < 2 * self.num_workers
292 | indices = next(self.sample_iter, None)
293 | if indices is None:
294 | return
295 | self.index_queue.put((self.send_idx, indices))
296 | self.batches_outstanding += 1
297 | self.send_idx += 1
298 |
299 | def _process_next_batch(self, batch):
300 | self.rcvd_idx += 1
301 | self._put_indices()
302 | if isinstance(batch, ExceptionWrapper):
303 | raise batch.exc_type(batch.exc_msg)
304 | return batch
305 |
306 | def __getstate__(self):
307 | # TODO: add limited pickling support for sharing an iterator
308 | # across multiple threads for HOGWILD.
309 | # Probably the best way to do this is by moving the sample pushing
310 | # to a separate thread and then just sharing the data queue
311 | # but signalling the end is tricky without a non-blocking API
312 | raise NotImplementedError("DataLoaderIterator cannot be pickled")
313 |
314 | def _shutdown_workers(self):
315 | try:
316 | if not self.shutdown:
317 | self.shutdown = True
318 | self.done_event.set()
319 | # if worker_manager_thread is waiting to put
320 | while not self.data_queue.empty():
321 | self.data_queue.get()
322 | for _ in self.workers:
323 | self.index_queue.put(None)
324 | # done_event should be sufficient to exit worker_manager_thread,
325 | # but be safe here and put another None
326 | self.worker_result_queue.put(None)
327 | finally:
328 | # removes pids no matter what
329 | if self.worker_pids_set:
330 | _remove_worker_pids(id(self))
331 | self.worker_pids_set = False
332 |
333 | def __del__(self):
334 | if self.num_workers > 0:
335 | self._shutdown_workers()
336 |
337 |
338 | class DataLoader(object):
339 | """
340 | Data loader. Combines a dataset and a sampler, and provides
341 | single- or multi-process iterators over the dataset.
342 |
343 | Arguments:
344 | dataset (Dataset): dataset from which to load the data.
345 | batch_size (int, optional): how many samples per batch to load
346 | (default: 1).
347 | shuffle (bool, optional): set to ``True`` to have the data reshuffled
348 | at every epoch (default: False).
349 | sampler (Sampler, optional): defines the strategy to draw samples from
350 | the dataset. If specified, ``shuffle`` must be False.
351 | batch_sampler (Sampler, optional): like sampler, but returns a batch of
352 | indices at a time. Mutually exclusive with batch_size, shuffle,
353 | sampler, and drop_last.
354 | num_workers (int, optional): how many subprocesses to use for data
355 | loading. 0 means that the data will be loaded in the main process.
356 | (default: 0)
357 | collate_fn (callable, optional): merges a list of samples to form a mini-batch.
358 | pin_memory (bool, optional): If ``True``, the data loader will copy tensors
359 | into CUDA pinned memory before returning them.
360 | drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
361 | if the dataset size is not divisible by the batch size. If ``False`` and
362 | the size of dataset is not divisible by the batch size, then the last batch
363 | will be smaller. (default: False)
364 | timeout (numeric, optional): if positive, the timeout value for collecting a batch
365 | from workers. Should always be non-negative. (default: 0)
366 | worker_init_fn (callable, optional): If not None, this will be called on each
367 | worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
368 | input, after seeding and before data loading. (default: None)
369 |
370 | .. note:: By default, each worker will have its PyTorch seed set to
371 | ``base_seed + worker_id``, where ``base_seed`` is a long generated
372 | by main process using its RNG. You may use ``torch.initial_seed()`` to access
373 | this value in :attr:`worker_init_fn`, which can be used to set other seeds
374 | (e.g. NumPy) before data loading.
375 |
376 |     .. warning:: If ``spawn`` start method is used, :attr:`worker_init_fn` cannot be an
377 | unpicklable object, e.g., a lambda function.
378 | """
379 |
380 | def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
381 | num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
382 | timeout=0, worker_init_fn=None):
383 | self.dataset = dataset
384 | self.batch_size = batch_size
385 | self.num_workers = num_workers
386 | self.collate_fn = collate_fn
387 | self.pin_memory = pin_memory
388 | self.drop_last = drop_last
389 | self.timeout = timeout
390 | self.worker_init_fn = worker_init_fn
391 |
392 | if timeout < 0:
393 | raise ValueError('timeout option should be non-negative')
394 |
395 | if batch_sampler is not None:
396 | if batch_size > 1 or shuffle or sampler is not None or drop_last:
397 | raise ValueError('batch_sampler is mutually exclusive with '
398 | 'batch_size, shuffle, sampler, and drop_last')
399 |
400 | if sampler is not None and shuffle:
401 | raise ValueError('sampler is mutually exclusive with shuffle')
402 |
403 | if self.num_workers < 0:
404 | raise ValueError('num_workers cannot be negative; '
405 | 'use num_workers=0 to disable multiprocessing.')
406 |
407 | if batch_sampler is None:
408 | if sampler is None:
409 | if shuffle:
410 | sampler = RandomSampler(dataset)
411 | else:
412 | sampler = SequentialSampler(dataset)
413 | batch_sampler = BatchSampler(sampler, batch_size, drop_last)
414 |
415 | self.sampler = sampler
416 | self.batch_sampler = batch_sampler
417 |
418 | def __iter__(self):
419 | return DataLoaderIter(self)
420 |
421 | def __len__(self):
422 | return len(self.batch_sampler)
423 |
--------------------------------------------------------------------------------
/AIM/sync_bn/utils/data/dataset.py:
--------------------------------------------------------------------------------
1 | import bisect
2 | import warnings
3 |
4 | from torch._utils import _accumulate
5 | from torch import randperm
6 |
7 |
8 | class Dataset(object):
9 | """An abstract class representing a Dataset.
10 |
11 | All other datasets should subclass it. All subclasses should override
12 | ``__len__``, that provides the size of the dataset, and ``__getitem__``,
13 | supporting integer indexing in range from 0 to len(self) exclusive.
14 | """
15 |
16 | def __getitem__(self, index):
17 | raise NotImplementedError
18 |
19 | def __len__(self):
20 | raise NotImplementedError
21 |
22 | def __add__(self, other):
23 | return ConcatDataset([self, other])
24 |
25 |
26 | class TensorDataset(Dataset):
27 | """Dataset wrapping data and target tensors.
28 |
29 | Each sample will be retrieved by indexing both tensors along the first
30 | dimension.
31 |
32 | Arguments:
33 | data_tensor (Tensor): contains sample data.
34 | target_tensor (Tensor): contains sample targets (labels).
35 | """
36 |
37 | def __init__(self, data_tensor, target_tensor):
38 | assert data_tensor.size(0) == target_tensor.size(0)
39 | self.data_tensor = data_tensor
40 | self.target_tensor = target_tensor
41 |
42 | def __getitem__(self, index):
43 | return self.data_tensor[index], self.target_tensor[index]
44 |
45 | def __len__(self):
46 | return self.data_tensor.size(0)
47 |
48 |
49 | class ConcatDataset(Dataset):
50 | """
51 | Dataset to concatenate multiple datasets.
52 | Purpose: useful to assemble different existing datasets, possibly
53 | large-scale datasets as the concatenation operation is done in an
54 | on-the-fly manner.
55 |
56 | Arguments:
57 | datasets (iterable): List of datasets to be concatenated
58 | """
59 |
60 | @staticmethod
61 | def cumsum(sequence):
62 | r, s = [], 0
63 | for e in sequence:
64 | l = len(e)
65 | r.append(l + s)
66 | s += l
67 | return r
68 |
69 | def __init__(self, datasets):
70 | super(ConcatDataset, self).__init__()
71 | assert len(datasets) > 0, 'datasets should not be an empty iterable'
72 | self.datasets = list(datasets)
73 | self.cumulative_sizes = self.cumsum(self.datasets)
74 |
75 | def __len__(self):
76 | return self.cumulative_sizes[-1]
77 |
78 | def __getitem__(self, idx):
79 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
80 | if dataset_idx == 0:
81 | sample_idx = idx
82 | else:
83 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
84 | return self.datasets[dataset_idx][sample_idx]
85 |
86 | @property
87 | def cummulative_sizes(self):
88 | warnings.warn("cummulative_sizes attribute is renamed to "
89 | "cumulative_sizes", DeprecationWarning, stacklevel=2)
90 | return self.cumulative_sizes
91 |
92 |
93 | class Subset(Dataset):
94 | def __init__(self, dataset, indices):
95 | self.dataset = dataset
96 | self.indices = indices
97 |
98 | def __getitem__(self, idx):
99 | return self.dataset[self.indices[idx]]
100 |
101 | def __len__(self):
102 | return len(self.indices)
103 |
104 |
105 | def random_split(dataset, lengths):
106 | """
107 |     Randomly split a dataset into non-overlapping new datasets of given lengths.
108 | 
109 |
110 | Arguments:
111 | dataset (Dataset): Dataset to be split
112 | lengths (iterable): lengths of splits to be produced
113 | """
114 | if sum(lengths) != len(dataset):
115 | raise ValueError("Sum of input lengths does not equal the length of the input dataset!")
116 |
117 | indices = randperm(sum(lengths))
118 | return [Subset(dataset, indices[offset - length:offset]) for offset, length in zip(_accumulate(lengths), lengths)]
119 |
--------------------------------------------------------------------------------
/AIM/sync_bn/utils/data/distributed.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from .sampler import Sampler
4 | from torch.distributed import get_world_size, get_rank
5 |
6 |
7 | class DistributedSampler(Sampler):
8 | """Sampler that restricts data loading to a subset of the dataset.
9 |
10 | It is especially useful in conjunction with
11 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
12 | process can pass a DistributedSampler instance as a DataLoader sampler,
13 | and load a subset of the original dataset that is exclusive to it.
14 |
15 | .. note::
16 | Dataset is assumed to be of constant size.
17 |
18 | Arguments:
19 | dataset: Dataset used for sampling.
20 | num_replicas (optional): Number of processes participating in
21 | distributed training.
22 | rank (optional): Rank of the current process within num_replicas.
23 | """
24 |
25 | def __init__(self, dataset, num_replicas=None, rank=None):
26 | if num_replicas is None:
27 | num_replicas = get_world_size()
28 | if rank is None:
29 | rank = get_rank()
30 | self.dataset = dataset
31 | self.num_replicas = num_replicas
32 | self.rank = rank
33 | self.epoch = 0
34 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
35 | self.total_size = self.num_samples * self.num_replicas
36 |
37 | def __iter__(self):
38 | # deterministically shuffle based on epoch
39 | g = torch.Generator()
40 | g.manual_seed(self.epoch)
41 | indices = list(torch.randperm(len(self.dataset), generator=g))
42 |
43 | # add extra samples to make it evenly divisible
44 | indices += indices[:(self.total_size - len(indices))]
45 | assert len(indices) == self.total_size
46 |
47 | # subsample
48 | offset = self.num_samples * self.rank
49 | indices = indices[offset:offset + self.num_samples]
50 | assert len(indices) == self.num_samples
51 |
52 | return iter(indices)
53 |
54 | def __len__(self):
55 | return self.num_samples
56 |
57 | def set_epoch(self, epoch):
58 | self.epoch = epoch
59 |
--------------------------------------------------------------------------------
/AIM/sync_bn/utils/data/sampler.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class Sampler(object):
5 | """Base class for all Samplers.
6 |
7 | Every Sampler subclass has to provide an __iter__ method, providing a way
8 | to iterate over indices of dataset elements, and a __len__ method that
9 | returns the length of the returned iterators.
10 | """
11 |
12 | def __init__(self, data_source):
13 | pass
14 |
15 | def __iter__(self):
16 | raise NotImplementedError
17 |
18 | def __len__(self):
19 | raise NotImplementedError
20 |
21 |
22 | class SequentialSampler(Sampler):
23 | """Samples elements sequentially, always in the same order.
24 |
25 | Arguments:
26 | data_source (Dataset): dataset to sample from
27 | """
28 |
29 | def __init__(self, data_source):
30 | self.data_source = data_source
31 |
32 | def __iter__(self):
33 | return iter(range(len(self.data_source)))
34 |
35 | def __len__(self):
36 | return len(self.data_source)
37 |
38 |
39 | class RandomSampler(Sampler):
40 | """Samples elements randomly, without replacement.
41 |
42 | Arguments:
43 | data_source (Dataset): dataset to sample from
44 | """
45 |
46 | def __init__(self, data_source):
47 | self.data_source = data_source
48 |
49 | def __iter__(self):
50 | return iter(torch.randperm(len(self.data_source)).long())
51 |
52 | def __len__(self):
53 | return len(self.data_source)
54 |
55 |
56 | class SubsetRandomSampler(Sampler):
57 | """Samples elements randomly from a given list of indices, without replacement.
58 |
59 | Arguments:
60 | indices (list): a list of indices
61 | """
62 |
63 | def __init__(self, indices):
64 | self.indices = indices
65 |
66 | def __iter__(self):
67 | return (self.indices[i] for i in torch.randperm(len(self.indices)))
68 |
69 | def __len__(self):
70 | return len(self.indices)
71 |
72 |
73 | class WeightedRandomSampler(Sampler):
74 | """Samples elements from [0,..,len(weights)-1] with given probabilities (weights).
75 |
76 | Arguments:
77 |         weights (list) : a list of weights, not necessarily summing up to one
78 | num_samples (int): number of samples to draw
79 | replacement (bool): if ``True``, samples are drawn with replacement.
80 | If not, they are drawn without replacement, which means that when a
81 | sample index is drawn for a row, it cannot be drawn again for that row.
82 | """
83 |
84 | def __init__(self, weights, num_samples, replacement=True):
85 | self.weights = torch.DoubleTensor(weights)
86 | self.num_samples = num_samples
87 | self.replacement = replacement
88 |
89 | def __iter__(self):
90 | return iter(torch.multinomial(self.weights, self.num_samples, self.replacement))
91 |
92 | def __len__(self):
93 | return self.num_samples
94 |
95 |
96 | class BatchSampler(object):
97 | """Wraps another sampler to yield a mini-batch of indices.
98 |
99 | Args:
100 | sampler (Sampler): Base sampler.
101 | batch_size (int): Size of mini-batch.
102 | drop_last (bool): If ``True``, the sampler will drop the last batch if
103 | its size would be less than ``batch_size``
104 |
105 | Example:
106 | >>> list(BatchSampler(range(10), batch_size=3, drop_last=False))
107 | [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
108 | >>> list(BatchSampler(range(10), batch_size=3, drop_last=True))
109 | [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
110 | """
111 |
112 | def __init__(self, sampler, batch_size, drop_last):
113 | self.sampler = sampler
114 | self.batch_size = batch_size
115 | self.drop_last = drop_last
116 |
117 | def __iter__(self):
118 | batch = []
119 | for idx in self.sampler:
120 | batch.append(idx)
121 | if len(batch) == self.batch_size:
122 | yield batch
123 | batch = []
124 | if len(batch) > 0 and not self.drop_last:
125 | yield batch
126 |
127 | def __len__(self):
128 | if self.drop_last:
129 | return len(self.sampler) // self.batch_size
130 | else:
131 | return (len(self.sampler) + self.batch_size - 1) // self.batch_size
132 |
--------------------------------------------------------------------------------
/AIM/sync_bn/utils/th.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Variable
3 | import numpy as np
4 | import collections
5 |
6 | __all__ = ['as_variable', 'as_numpy', 'mark_volatile']
7 |
8 | def as_variable(obj):
9 | if isinstance(obj, Variable):
10 | return obj
11 | if isinstance(obj, collections.Sequence):
12 | return [as_variable(v) for v in obj]
13 | elif isinstance(obj, collections.Mapping):
14 | return {k: as_variable(v) for k, v in obj.items()}
15 | else:
16 | return Variable(obj)
17 |
18 | def as_numpy(obj):
19 | if isinstance(obj, collections.Sequence):
20 | return [as_numpy(v) for v in obj]
21 | elif isinstance(obj, collections.Mapping):
22 | return {k: as_numpy(v) for k, v in obj.items()}
23 | elif isinstance(obj, Variable):
24 | return obj.data.cpu().numpy()
25 | elif torch.is_tensor(obj):
26 | return obj.cpu().numpy()
27 | else:
28 | return np.array(obj)
29 |
30 | def mark_volatile(obj):
31 | if torch.is_tensor(obj):
32 | obj = Variable(obj)
33 | if isinstance(obj, Variable):
34 | obj.no_grad = True
35 | return obj
36 | elif isinstance(obj, collections.Mapping):
37 | return {k: mark_volatile(o) for k, o in obj.items()}
38 | elif isinstance(obj, collections.Sequence):
39 | return [mark_volatile(o) for o in obj]
40 | else:
41 | return obj
42 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [AAAI 2024] Weakly Supervised Multimodal Affordance Grounding for Egocentric Images
2 | ## Paper
3 | >Weakly Supervised Multimodal Affordance Grounding for Egocentric Images (AAAI 2024)
4 |
5 | Link: https://doi.org/10.1609/aaai.v38i6.28451
6 |
7 | >Appendix
8 |
9 | Link: [Appendix.pdf](/docs)
10 |
11 |
12 |
13 | **Abstract:**
14 |
15 | To enhance the interaction between intelligent systems and the environment, locating the affordance regions of objects is crucial. These regions correspond to specific areas that provide distinct functionalities. Humans often acquire the ability to identify these regions through action demonstrations and verbal instructions. In this paper, we present a novel multimodal framework that extracts affordance knowledge from exocentric images, which depict human-object interactions, as well as from accompanying textual descriptions that describe the performed actions. The extracted knowledge is then transferred to egocentric images. To achieve this goal, we propose the HOI-Transfer Module, which utilizes local perception to disentangle individual actions within exocentric images. This module effectively captures localized features and correlations between actions, leading to valuable affordance knowledge. Additionally, we introduce the Pixel-Text Fusion Module, which fuses affordance knowledge by identifying regions in egocentric images that bear resemblances to the textual features defining affordances. We employ a Weakly Supervised Multimodal Affordance (WSMA) learning approach, utilizing image-level labels for training. Through extensive experiments, we demonstrate the superiority of our proposed method in terms of evaluation metrics and visual results when compared to existing affordance grounding models. Furthermore, ablation experiments confirm the effectiveness of our approach.
16 |
17 | ## Requirements
18 | We run the code in the following environment:
19 | - An NVIDIA GeForce RTX 3090
20 | - Python 3.8
21 | - PyTorch 1.10.0
22 |
23 | ## Required pre-trained models
24 | - Model for the DINO ViT backbone (no need to download it separately; it is already included in the code)
25 | - Model for the text encoder (CLIP): you can find it [here](https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt)
26 |
27 | ## Start
28 | ```bash
29 | git clone https://github.com/xulingjing88/WSMA.git
30 | cd WSMA
31 | ```
32 | Before training, you need to preprocess the data:
33 | - Seen, Unseen(from AGD20K): You can find it [here](https://github.com/lhc1224/Cross-View-AG/tree/main/code/cvpr)
34 | - HICO-IIF: You can find it [here](https://pan.baidu.com/s/1imzN-mRaWLIyDLZ80NibxQ?pwd=c878) or google cloud [here](https://drive.google.com/file/d/1InBbjM6Uo9HK8OKAuK7TZyvCFjSVKkfY/view?usp=drive_link)
35 | ```bash
36 | python preprocessing.py
37 | ```
38 | Set **'data_root'** to the path of the dataset and **'divide'** to the dataset name (Seen, Unseen, or HICO-IIF); then you can start training by running train.py:
39 | ```bash
40 | python train.py
41 | ```
42 |
43 | ## Acknowledgements
44 | We would like to express our gratitude to the following repositories for their contributions and inspirations: [Cross-View-AG](https://github.com/lhc1224/Cross-View-AG), [LOCATE](https://github.com/Reagan1311/LOCATE), [Dino](https://github.com/facebookresearch/dino), [CLIP](https://github.com/openai/CLIP).
45 |
46 |
--------------------------------------------------------------------------------
/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/datatest.py:
--------------------------------------------------------------------------------
1 | import os
2 | from torch.utils import data
3 | import collections
4 | import torch
5 | import numpy as np
6 | import h5py
7 | from joblib import Parallel, delayed
8 | import random
9 | from utils import util
10 | import cv2
11 | from utils.transform import Normalize, cv_random_crop_flip, load_image
12 | from torchvision import transforms
13 | from PIL import Image
14 |
15 | mean = [0.485, 0.456, 0.406]
16 | std = [0.229, 0.224, 0.225]
17 |
18 |
19 | class TrainData(data.Dataset):
20 | def __init__(self, egocentric_root, crop_size=224, divide="Unseen", mask_root=None):
21 | self.egocentric_root = egocentric_root
22 | self.image_list = []
23 | self.crop_size = crop_size
24 | self.mask_root = mask_root
25 | if divide == "Seen":
26 | self.aff_list = ['beat', "boxing", "brush_with", "carry", "catch",
27 | "cut", "cut_with", "drag", 'drink_with', "eat",
28 | "hit", "hold", "jump", "kick", "lie_on", "lift",
29 | "look_out", "open", "pack", "peel", "pick_up",
30 | "pour", "push", "ride", "sip", "sit_on", "stick",
31 | "stir", "swing", "take_photo", "talk_on", "text_on",
32 | "throw", "type_on", "wash", "write"]
33 | elif divide=="Unseen":
34 | self.aff_list = ["carry", "catch", "cut", "cut_with", 'drink_with',
35 | "eat", "hit", "hold", "jump", "kick", "lie_on", "open", "peel",
36 | "pick_up", "pour", "push", "ride", "sip", "sit_on", "stick",
37 | "swing", "take_photo", "throw", "type_on", "wash"]
38 | else: # HICO-IIF
39 | self.aff_list = ['cut_with', 'drink_with', 'hold', 'open', 'pour', 'sip', 'stick', 'stir', 'swing', 'type_on']
40 |
41 | self.transform = transforms.Compose([
42 | transforms.Resize((crop_size, crop_size)),
43 | transforms.ToTensor(),
44 | transforms.Normalize(mean, std)
45 | ])
46 |
47 | files = os.listdir(self.egocentric_root)
48 | for file in files:
49 | file_path = os.path.join(self.egocentric_root, file)
50 | obj_files = os.listdir(file_path)
51 | for obj_file in obj_files:
52 | obj_file_path = os.path.join(file_path, obj_file)
53 | images = os.listdir(obj_file_path)
54 | for img in images:
55 | if 'json' not in img:
56 | img_path = os.path.join(obj_file_path, img)
57 | cur = img_path.split("/")
58 | if os.path.exists(os.path.join(self.mask_root, file, obj_file, img[:-3] + "png")):
59 | self.image_list.append(img_path)
60 |
61 | def __getitem__(self, item):
62 | egocentric_image_path = self.image_list[item]
63 | names = egocentric_image_path.split("/")
64 | aff_name, object = names[-3], names[-2]
65 | label = self.aff_list.index(aff_name)
66 |         egocentric_image = self.load_static_image(egocentric_image_path)  # the image is loaded, transformed, and converted to a tensor here
67 |
68 | mask_path = os.path.join(self.mask_root, names[-3], names[-2], names[-1][:-3] + "png")
69 | mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
70 | mask = cv2.resize(mask, (224, 224))
71 |
72 | return egocentric_image, label, mask_path, aff_name
73 |
74 | def load_static_image(self, path):
75 |
76 | img = util.load_img(path)
77 | img = self.transform(img)
78 | return img
79 |
80 | def __len__(self):
81 |
82 | return len(self.image_list)
83 |
--------------------------------------------------------------------------------
/datatrain.py:
--------------------------------------------------------------------------------
1 | import os
2 | from torch.utils import data
3 | import collections
4 | import torch
5 | import numpy as np
6 | import h5py
7 | from joblib import Parallel, delayed
8 | import random
9 | from utils import util
10 | import cv2
11 | from utils.transform import Normalize, cv_random_crop_flip, load_image
12 | from torchvision import transforms
13 | from PIL import Image
14 |
15 | mean = [0.485, 0.456, 0.406]
16 | std = [0.229, 0.224, 0.225]
17 |
18 |
19 | class TrainData(data.Dataset):
20 | def __init__(self, exocentric_root, egocentric_root, resize_size=256, crop_size=224, divide="Unseen"):
21 | self.exocentric_root = exocentric_root
22 | self.egocentric_root = egocentric_root
23 | self.resize_size = resize_size
24 | self.image_list = []
25 | self.crop_size = crop_size
26 | if divide == "Seen":
27 | self.aff_list = ['beat', "boxing", "brush_with", "carry", "catch",
28 | "cut", "cut_with", "drag", 'drink_with', "eat",
29 | "hit", "hold", "jump", "kick", "lie_on", "lift",
30 | "look_out", "open", "pack", "peel", "pick_up",
31 | "pour", "push", "ride", "sip", "sit_on", "stick",
32 | "stir", "swing", "take_photo", "talk_on", "text_on",
33 | "throw", "type_on", "wash", "write"]
34 | elif divide=="Unseen":
35 | self.aff_list = ["carry", "catch", "cut", "cut_with", 'drink_with',
36 | "eat", "hit", "hold", "jump", "kick", "lie_on", "open", "peel",
37 | "pick_up", "pour", "push", "ride", "sip", "sit_on", "stick",
38 | "swing", "take_photo", "throw", "type_on", "wash"]
39 | else: # HICO-IIF
40 | self.aff_list = ['cut_with', 'drink_with', 'hold', 'open', 'pour', 'sip', 'stick', 'stir', 'swing', 'type_on']
41 |
42 | self.transform = transforms.Compose([
43 | transforms.Resize(resize_size),
44 | transforms.RandomCrop(crop_size),
45 | transforms.RandomHorizontalFlip(),
46 | transforms.ToTensor(),
47 | transforms.Normalize(mean, std)
48 | ])
49 |
50 | files = os.listdir(self.exocentric_root)
51 | for file in files:
52 | file_path = os.path.join(self.exocentric_root, file)
53 | obj_files = os.listdir(file_path)
54 | for obj_file in obj_files:
55 | obj_file_path = os.path.join(file_path, obj_file)
56 | images = os.listdir(obj_file_path)
57 | for img in images:
58 | img_path = os.path.join(obj_file_path, img)
59 | cur = img_path.split("/")
60 | if os.path.exists(os.path.join(self.egocentric_root, cur[-3], cur[-2])):
61 | self.image_list.append(img_path)
62 |
63 | def __getitem__(self, item):
64 | exocentric_image_path = self.image_list[item]
65 | names = exocentric_image_path.split("/")
66 | aff_name, object = names[-3], names[-2]
67 | object_file = os.path.join(self.egocentric_root, aff_name, object)
68 | obj_images = os.listdir(object_file)
69 | label = self.aff_list.index(aff_name)
70 | idx = random.randint(0, len(obj_images) - 1) # Randomly select a picture
71 | egocentric_image_path = os.path.join(object_file, obj_images[idx])
72 | names = egocentric_image_path.split("/")
73 | egocentric_image = self.load_static_image(egocentric_image_path)
74 |
75 | exocentric_file = os.path.join(self.exocentric_root, aff_name, object)
76 | exocentrics = os.listdir(exocentric_file)
77 | exocentric_images = []
78 |         if len(exocentrics) <= 3:  # if there are at most 3 exocentric images, take them all and pad to 3 by repeating the last one
79 | start = 0
80 | for i in range(start, len(exocentrics)):
81 | tmp_exo = self.load_static_image(os.path.join(self.exocentric_root, aff_name, object, exocentrics[i]))
82 | exocentric_images.append(tmp_exo)
83 | for i in range(len(exocentrics), 3):
84 | exocentric_images.append(tmp_exo)
85 | else:
86 |             start = random.randint(0, len(exocentrics) - 4)  # otherwise pick a random window of 3 consecutive images
87 | for idx in range(start, start + 3):
88 | tmp_exo = self.load_static_image(os.path.join(self.exocentric_root, aff_name, object, exocentrics[idx]))
89 | exocentric_images.append(tmp_exo)
90 |
91 | return exocentric_images, egocentric_image, label, aff_name
92 |
93 | def load_static_image(self, path):
94 |
95 | img = util.load_img(path)
96 | img = self.transform(img)
97 | return img
98 |
99 | def __len__(self):
100 |
101 | return len(self.image_list)
102 |
--------------------------------------------------------------------------------
/docs/README.md:
--------------------------------------------------------------------------------
1 | ## Some document here
2 |
--------------------------------------------------------------------------------
/docs/WSMA-appendix.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/docs/WSMA-appendix.pdf
--------------------------------------------------------------------------------
/images/README.md:
--------------------------------------------------------------------------------
1 | Some images
2 |
--------------------------------------------------------------------------------
/images/pipelline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/images/pipelline.png
--------------------------------------------------------------------------------
/images/pipelline1.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/images/pipelline1.pdf
--------------------------------------------------------------------------------
/models/dino/__pycache__/utils.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/models/dino/__pycache__/utils.cpython-311.pyc
--------------------------------------------------------------------------------
/models/dino/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/models/dino/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/models/dino/__pycache__/vision_transformer.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/models/dino/__pycache__/vision_transformer.cpython-311.pyc
--------------------------------------------------------------------------------
/models/dino/__pycache__/vision_transformer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xulingjing88/WSMA/b4f7d3b77b536a80381acdf28c489b2393e177d9/models/dino/__pycache__/vision_transformer.cpython-37.pyc
--------------------------------------------------------------------------------
/models/dino/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Misc functions.
16 |
17 | Mostly copy-paste from torchvision references or other public repos like DETR:
18 | https://github.com/facebookresearch/detr/blob/master/util/misc.py
19 | """
20 | import os
21 | import sys
22 | import time
23 | import math
24 | import random
25 | import datetime
26 | import argparse, subprocess
27 | from collections import defaultdict, deque
28 |
29 | import numpy as np
30 | import torch
31 | from torch import nn
32 | import torch.distributed as dist
33 | from PIL import ImageFilter, ImageOps
34 |
35 |
36 | class GaussianBlur(object):
37 | """
38 | Apply Gaussian Blur to the PIL image.
39 | """
40 |
41 | def __init__(self, p=0.5, radius_min=0.1, radius_max=2.):
42 | self.prob = p
43 | self.radius_min = radius_min
44 | self.radius_max = radius_max
45 |
46 | def __call__(self, img):
47 | do_it = random.random() <= self.prob
48 | if not do_it:
49 | return img
50 |
51 | return img.filter(
52 | ImageFilter.GaussianBlur(
53 | radius=random.uniform(self.radius_min, self.radius_max)
54 | )
55 | )
56 |
57 |
58 | class Solarization(object):
59 | """
60 | Apply Solarization to the PIL image.
61 | """
62 |
63 | def __init__(self, p):
64 | self.p = p
65 |
66 | def __call__(self, img):
67 | if random.random() < self.p:
68 | return ImageOps.solarize(img)
69 | else:
70 | return img
71 |
72 |
73 | def load_pretrained_weights(model, pretrained_weights, checkpoint_key, model_name, patch_size):
74 | if os.path.isfile(pretrained_weights):
75 | state_dict = torch.load(pretrained_weights, map_location="cpu")
76 | if checkpoint_key is not None and checkpoint_key in state_dict:
77 | print(f"Take key {checkpoint_key} in provided checkpoint dict")
78 | state_dict = state_dict[checkpoint_key]
79 | # remove `module.` prefix
80 | state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
81 | # remove `backbone.` prefix induced by multicrop wrapper
82 | state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
83 | msg = model.load_state_dict(state_dict, strict=False)
84 | print('Pretrained weights found at {} and loaded with msg: {}'.format(pretrained_weights, msg))
85 | else:
86 | # print("Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate.")
87 | url = None
88 | if model_name == "vit_small" and patch_size == 16:
89 | url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
90 | elif model_name == "vit_small" and patch_size == 8:
91 | url = "dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth"
92 | elif model_name == "vit_base" and patch_size == 16:
93 | url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
94 | elif model_name == "vit_base" and patch_size == 8:
95 | url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"
96 | if url is not None:
97 | # print("Since no pretrained weights have been provided, we load the reference pretrained dino weights.")
98 | state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url)
99 | model.load_state_dict(state_dict, strict=True)
100 | else:
101 | print("There is no reference weights available for this model => We use random weights.")
102 |
103 |
104 | def clip_gradients(model, clip):
105 | norms = []
106 | for name, p in model.named_parameters():
107 | if p.grad is not None:
108 | param_norm = p.grad.data.norm(2)
109 | norms.append(param_norm.item())
110 | clip_coef = clip / (param_norm + 1e-6)
111 | if clip_coef < 1:
112 | p.grad.data.mul_(clip_coef)
113 | return norms
114 |
115 |
116 | def cancel_gradients_last_layer(epoch, model, freeze_last_layer):
117 | if epoch >= freeze_last_layer:
118 | return
119 | for n, p in model.named_parameters():
120 | if "last_layer" in n:
121 | p.grad = None
122 |
123 |
124 | def restart_from_checkpoint(ckp_path, run_variables=None, **kwargs):
125 | """
126 | Re-start from checkpoint
127 | """
128 | if not os.path.isfile(ckp_path):
129 | return
130 | print("Found checkpoint at {}".format(ckp_path))
131 |
132 | # open checkpoint file
133 | checkpoint = torch.load(ckp_path, map_location="cpu")
134 |
135 | # key is what to look for in the checkpoint file
136 | # value is the object to load
137 | # example: {'state_dict': model}
138 | for key, value in kwargs.items():
139 | if key in checkpoint and value is not None:
140 | try:
141 | msg = value.load_state_dict(checkpoint[key], strict=False)
142 | print("=> loaded '{}' from checkpoint '{}' with msg {}".format(key, ckp_path, msg))
143 | except TypeError:
144 | try:
145 | msg = value.load_state_dict(checkpoint[key])
146 | print("=> loaded '{}' from checkpoint: '{}'".format(key, ckp_path))
147 | except ValueError:
148 | print("=> failed to load '{}' from checkpoint: '{}'".format(key, ckp_path))
149 | else:
150 | print("=> key '{}' not found in checkpoint: '{}'".format(key, ckp_path))
151 |
152 | # re load variable important for the run
153 | if run_variables is not None:
154 | for var_name in run_variables:
155 | if var_name in checkpoint:
156 | run_variables[var_name] = checkpoint[var_name]
157 |
158 |
159 | def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0):
160 | warmup_schedule = np.array([])
161 | warmup_iters = warmup_epochs * niter_per_ep
162 | if warmup_epochs > 0:
163 | warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
164 |
165 | iters = np.arange(epochs * niter_per_ep - warmup_iters)
166 | schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters)))
167 |
168 | schedule = np.concatenate((warmup_schedule, schedule))
169 | assert len(schedule) == epochs * niter_per_ep
170 | return schedule
171 |
172 |
173 | def bool_flag(s):
174 | """
175 | Parse boolean arguments from the command line.
176 | """
177 | FALSY_STRINGS = {"off", "false", "0"}
178 | TRUTHY_STRINGS = {"on", "true", "1"}
179 | if s.lower() in FALSY_STRINGS:
180 | return False
181 | elif s.lower() in TRUTHY_STRINGS:
182 | return True
183 | else:
184 | raise argparse.ArgumentTypeError("invalid value for a boolean flag")
185 |
186 |
187 | def fix_random_seeds(seed=31):
188 | """
189 | Fix random seeds.
190 | """
191 | torch.manual_seed(seed)
192 | torch.cuda.manual_seed_all(seed)
193 | np.random.seed(seed)
194 |
195 |
196 | class SmoothedValue(object):
197 | """Track a series of values and provide access to smoothed values over a
198 | window or the global series average.
199 | """
200 |
201 | def __init__(self, window_size=20, fmt=None):
202 | if fmt is None:
203 | fmt = "{median:.6f} ({global_avg:.6f})"
204 | self.deque = deque(maxlen=window_size)
205 | self.total = 0.0
206 | self.count = 0
207 | self.fmt = fmt
208 |
209 | def update(self, value, n=1):
210 | self.deque.append(value)
211 | self.count += n
212 | self.total += value * n
213 |
214 | def synchronize_between_processes(self):
215 | """
216 | Warning: does not synchronize the deque!
217 | """
218 | if not is_dist_avail_and_initialized():
219 | return
220 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
221 | dist.barrier()
222 | dist.all_reduce(t)
223 | t = t.tolist()
224 | self.count = int(t[0])
225 | self.total = t[1]
226 |
227 | @property
228 | def median(self):
229 | d = torch.tensor(list(self.deque))
230 | return d.median().item()
231 |
232 | @property
233 | def avg(self):
234 | d = torch.tensor(list(self.deque), dtype=torch.float32)
235 | return d.mean().item()
236 |
237 | @property
238 | def global_avg(self):
239 | return self.total / self.count
240 |
241 | @property
242 | def max(self):
243 | return max(self.deque)
244 |
245 | @property
246 | def value(self):
247 | return self.deque[-1]
248 |
249 | def __str__(self):
250 | return self.fmt.format(
251 | median=self.median,
252 | avg=self.avg,
253 | global_avg=self.global_avg,
254 | max=self.max,
255 | value=self.value)
256 |
257 |
258 | def reduce_dict(input_dict, average=True):
259 | """
260 | Args:
261 | input_dict (dict): all the values will be reduced
262 | average (bool): whether to do average or sum
263 | Reduce the values in the dictionary from all processes so that all processes
264 | have the averaged results. Returns a dict with the same fields as
265 | input_dict, after reduction.
266 | """
267 | world_size = get_world_size()
268 | if world_size < 2:
269 | return input_dict
270 | with torch.no_grad():
271 | names = []
272 | values = []
273 | # sort the keys so that they are consistent across processes
274 | for k in sorted(input_dict.keys()):
275 | names.append(k)
276 | values.append(input_dict[k])
277 | values = torch.stack(values, dim=0)
278 | dist.all_reduce(values)
279 | if average:
280 | values /= world_size
281 | reduced_dict = {k: v for k, v in zip(names, values)}
282 | return reduced_dict
283 |
284 |
285 | class MetricLogger(object):
286 | def __init__(self, delimiter="\t"):
287 | self.meters = defaultdict(SmoothedValue)
288 | self.delimiter = delimiter
289 |
290 | def update(self, **kwargs):
291 | for k, v in kwargs.items():
292 | if isinstance(v, torch.Tensor):
293 | v = v.item()
294 | assert isinstance(v, (float, int))
295 | self.meters[k].update(v)
296 |
297 | def __getattr__(self, attr):
298 | if attr in self.meters:
299 | return self.meters[attr]
300 | if attr in self.__dict__:
301 | return self.__dict__[attr]
302 | raise AttributeError("'{}' object has no attribute '{}'".format(
303 | type(self).__name__, attr))
304 |
305 | def __str__(self):
306 | loss_str = []
307 | for name, meter in self.meters.items():
308 | loss_str.append(
309 | "{}: {}".format(name, str(meter))
310 | )
311 | return self.delimiter.join(loss_str)
312 |
313 | def synchronize_between_processes(self):
314 | for meter in self.meters.values():
315 | meter.synchronize_between_processes()
316 |
317 | def add_meter(self, name, meter):
318 | self.meters[name] = meter
319 |
320 | def log_every(self, iterable, print_freq, header=None):
321 | i = 0
322 | if not header:
323 | header = ''
324 | start_time = time.time()
325 | end = time.time()
326 | iter_time = SmoothedValue(fmt='{avg:.6f}')
327 | data_time = SmoothedValue(fmt='{avg:.6f}')
328 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
329 | if torch.cuda.is_available():
330 | log_msg = self.delimiter.join([
331 | header,
332 | '[{0' + space_fmt + '}/{1}]',
333 | 'eta: {eta}',
334 | '{meters}',
335 | 'time: {time}',
336 | 'data: {data}',
337 | 'max mem: {memory:.0f}'
338 | ])
339 | else:
340 | log_msg = self.delimiter.join([
341 | header,
342 | '[{0' + space_fmt + '}/{1}]',
343 | 'eta: {eta}',
344 | '{meters}',
345 | 'time: {time}',
346 | 'data: {data}'
347 | ])
348 | MB = 1024.0 * 1024.0
349 | for obj in iterable:
350 | data_time.update(time.time() - end)
351 | yield obj
352 | iter_time.update(time.time() - end)
353 | if i % print_freq == 0 or i == len(iterable) - 1:
354 | eta_seconds = iter_time.global_avg * (len(iterable) - i)
355 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
356 | if torch.cuda.is_available():
357 | print(log_msg.format(
358 | i, len(iterable), eta=eta_string,
359 | meters=str(self),
360 | time=str(iter_time), data=str(data_time),
361 | memory=torch.cuda.max_memory_allocated() / MB))
362 | else:
363 | print(log_msg.format(
364 | i, len(iterable), eta=eta_string,
365 | meters=str(self),
366 | time=str(iter_time), data=str(data_time)))
367 | i += 1
368 | end = time.time()
369 | total_time = time.time() - start_time
370 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
371 | print('{} Total time: {} ({:.6f} s / it)'.format(
372 | header, total_time_str, total_time / len(iterable)))
373 |
374 |
375 | def get_sha():
376 | cwd = os.path.dirname(os.path.abspath(__file__))
377 |
378 | def _run(command):
379 | return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()
380 |
381 | sha = 'N/A'
382 | diff = "clean"
383 | branch = 'N/A'
384 | try:
385 | sha = _run(['git', 'rev-parse', 'HEAD'])
386 | subprocess.check_output(['git', 'diff'], cwd=cwd)
387 | diff = _run(['git', 'diff-index', 'HEAD'])
388 | diff = "has uncommited changes" if diff else "clean"
389 | branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
390 | except Exception:
391 | pass
392 | message = f"sha: {sha}, status: {diff}, branch: {branch}"
393 | return message
394 |
395 |
396 | def is_dist_avail_and_initialized():
397 | if not dist.is_available():
398 | return False
399 | if not dist.is_initialized():
400 | return False
401 | return True
402 |
403 |
404 | def get_world_size():
405 | if not is_dist_avail_and_initialized():
406 | return 1
407 | return dist.get_world_size()
408 |
409 |
410 | def get_rank():
411 | if not is_dist_avail_and_initialized():
412 | return 0
413 | return dist.get_rank()
414 |
415 |
416 | def is_main_process():
417 | return get_rank() == 0
418 |
419 |
420 | def save_on_master(*args, **kwargs):
421 | if is_main_process():
422 | torch.save(*args, **kwargs)
423 |
424 |
425 | def setup_for_distributed(is_master):
426 | """
427 | This function disables printing when not in master process
428 | """
429 | import builtins as __builtin__
430 | builtin_print = __builtin__.print
431 |
432 | def print(*args, **kwargs):
433 | force = kwargs.pop('force', False)
434 | if is_master or force:
435 | builtin_print(*args, **kwargs)
436 |
437 | __builtin__.print = print
438 |
439 |
440 | def init_distributed_mode(args):
441 | # launched with torch.distributed.launch
442 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
443 | args.rank = int(os.environ["RANK"])
444 | args.world_size = int(os.environ['WORLD_SIZE'])
445 | args.gpu = int(os.environ['LOCAL_RANK'])
446 | # launched with submitit on a slurm cluster
447 | elif 'SLURM_PROCID' in os.environ:
448 | args.rank = int(os.environ['SLURM_PROCID'])
449 | args.gpu = args.rank % torch.cuda.device_count()
450 | # launched naively with `python main_dino.py`
451 | # we manually add MASTER_ADDR and MASTER_PORT to env variables
452 | elif torch.cuda.is_available():
453 | print('Will run the code on one GPU.')
454 | args.rank, args.gpu, args.world_size = 0, 0, 1
455 | os.environ['MASTER_ADDR'] = '127.0.0.1'
456 | os.environ['MASTER_PORT'] = '29500'
457 | else:
458 | print('Does not support training without GPU.')
459 | sys.exit(1)
460 |
461 | dist.init_process_group(
462 | backend="nccl",
463 | init_method=args.dist_url,
464 | world_size=args.world_size,
465 | rank=args.rank,
466 | )
467 |
468 | torch.cuda.set_device(args.gpu)
469 | print('| distributed init (rank {}): {}'.format(
470 | args.rank, args.dist_url), flush=True)
471 | dist.barrier()
472 | setup_for_distributed(args.rank == 0)
473 |
474 |
475 | def accuracy(output, target, topk=(1,)):
476 | """Computes the accuracy over the k top predictions for the specified values of k"""
477 | maxk = max(topk)
478 | batch_size = target.size(0)
479 | _, pred = output.topk(maxk, 1, True, True)
480 | pred = pred.t()
481 | correct = pred.eq(target.reshape(1, -1).expand_as(pred))
482 | return [correct[:k].reshape(-1).float().sum(0) * 100. / batch_size for k in topk]
483 |
484 |
485 | def _no_grad_trunc_normal_(tensor, mean, std, a, b):
486 | # Cut & paste from PyTorch official master until it's in a few official releases - RW
487 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
488 | def norm_cdf(x):
489 | # Computes standard normal cumulative distribution function
490 | return (1. + math.erf(x / math.sqrt(2.))) / 2.
491 |
492 | if (mean < a - 2 * std) or (mean > b + 2 * std):
493 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
494 | "The distribution of values may be incorrect.",
495 | stacklevel=2)
496 |
497 | with torch.no_grad():
498 | # Values are generated by using a truncated uniform distribution and
499 | # then using the inverse CDF for the normal distribution.
500 | # Get upper and lower cdf values
501 | l = norm_cdf((a - mean) / std)
502 | u = norm_cdf((b - mean) / std)
503 |
504 | # Uniformly fill tensor with values from [l, u], then translate to
505 | # [2l-1, 2u-1].
506 | tensor.uniform_(2 * l - 1, 2 * u - 1)
507 |
508 | # Use inverse cdf transform for normal distribution to get truncated
509 | # standard normal
510 | tensor.erfinv_()
511 |
512 | # Transform to proper mean, std
513 | tensor.mul_(std * math.sqrt(2.))
514 | tensor.add_(mean)
515 |
516 | # Clamp to ensure it's in the proper range
517 | tensor.clamp_(min=a, max=b)
518 | return tensor
519 |
520 |
521 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
522 | # type: (Tensor, float, float, float, float) -> Tensor
523 | return _no_grad_trunc_normal_(tensor, mean, std, a, b)
524 |
525 |
526 | class LARS(torch.optim.Optimizer):
527 | """
528 | Almost copy-paste from https://github.com/facebookresearch/barlowtwins/blob/main/main.py
529 | """
530 |
531 | def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, eta=0.001,
532 | weight_decay_filter=None, lars_adaptation_filter=None):
533 | defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum,
534 | eta=eta, weight_decay_filter=weight_decay_filter,
535 | lars_adaptation_filter=lars_adaptation_filter)
536 | super().__init__(params, defaults)
537 |
538 | @torch.no_grad()
539 | def step(self):
540 | for g in self.param_groups:
541 | for p in g['params']:
542 | dp = p.grad
543 |
544 | if dp is None:
545 | continue
546 |
547 | if p.ndim != 1:
548 | dp = dp.add(p, alpha=g['weight_decay'])
549 |
550 | if p.ndim != 1:
551 | param_norm = torch.norm(p)
552 | update_norm = torch.norm(dp)
553 | one = torch.ones_like(param_norm)
554 | q = torch.where(param_norm > 0.,
555 | torch.where(update_norm > 0,
556 | (g['eta'] * param_norm / update_norm), one), one)
557 | dp = dp.mul(q)
558 |
559 | param_state = self.state[p]
560 | if 'mu' not in param_state:
561 | param_state['mu'] = torch.zeros_like(p)
562 | mu = param_state['mu']
563 | mu.mul_(g['momentum']).add_(dp)
564 |
565 | p.add_(mu, alpha=-g['lr'])
566 |
567 |
568 | class MultiCropWrapper(nn.Module):
569 | """
570 | Perform forward pass separately on each resolution input.
571 | The inputs corresponding to a single resolution are clubbed and single
572 | forward is run on the same resolution inputs. Hence we do several
573 | forward passes = number of different resolutions used. We then
574 | concatenate all the output features and run the head forward on these
575 | concatenated features.
576 | """
577 |
578 | def __init__(self, backbone, head):
579 | super(MultiCropWrapper, self).__init__()
580 | # disable layers dedicated to ImageNet labels classification
581 | backbone.fc, backbone.head = nn.Identity(), nn.Identity()
582 | self.backbone = backbone
583 | self.head = head
584 |
585 | def forward(self, x):
586 | # convert to list
587 | if not isinstance(x, list):
588 | x = [x]
589 | idx_crops = torch.cumsum(torch.unique_consecutive(
590 | torch.tensor([inp.shape[-1] for inp in x]),
591 | return_counts=True,
592 | )[1], 0)
593 | start_idx = 0
594 | for end_idx in idx_crops:
595 | _out = self.backbone(torch.cat(x[start_idx: end_idx]))
596 | if start_idx == 0:
597 | output = _out
598 | else:
599 | output = torch.cat((output, _out))
600 | start_idx = end_idx
601 | # Run the head forward on the concatenated features.
602 | return self.head(output)
603 |
604 |
605 | def get_params_groups(model):
606 | regularized = []
607 | not_regularized = []
608 | for name, param in model.named_parameters():
609 | if not param.requires_grad:
610 | continue
611 | # we do not regularize biases nor Norm parameters
612 | if name.endswith(".bias") or len(param.shape) == 1:
613 | not_regularized.append(param)
614 | else:
615 | regularized.append(param)
616 | return [{'params': regularized}, {'params': not_regularized, 'weight_decay': 0.}]
617 |
618 |
619 | def has_batchnorms(model):
620 | bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
621 | for name, module in model.named_modules():
622 | if isinstance(module, bn_types):
623 | return True
624 | return False
625 |
--------------------------------------------------------------------------------
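
Usage note: a hedged sketch of how the schedule and gradient helpers defined above are typically wired into a training loop. The model, optimizer settings and hyperparameters below are placeholders chosen for illustration, not values taken from train.py:

    import torch
    from torch import nn

    from models.dino.utils import cosine_scheduler, clip_gradients

    model = nn.Linear(16, 4)                                  # stand-in model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.0)   # lr is overwritten every iteration
    epochs, niter_per_ep = 10, 100

    # one learning-rate value per optimization step: linear warmup for 1 epoch, then cosine decay
    lr_schedule = cosine_scheduler(base_value=0.1, final_value=1e-4,
                                   epochs=epochs, niter_per_ep=niter_per_ep,
                                   warmup_epochs=1)

    for epoch in range(epochs):
        for it in range(niter_per_ep):
            step = epoch * niter_per_ep + it
            for group in optimizer.param_groups:
                group["lr"] = lr_schedule[step]
            loss = model(torch.randn(8, 16)).pow(2).mean()    # dummy loss
            optimizer.zero_grad()
            loss.backward()
            clip_gradients(model, clip=3.0)                   # per-parameter gradient norm clipping
            optimizer.step()
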
/models/dino/vision_transformer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Mostly copy-paste from timm library.
16 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
17 | """
18 | import math
19 | from torch import Tensor
20 | from typing import Union
21 | from functools import partial
22 |
23 | import torch
24 | import torch.nn as nn
25 |
26 | from models.dino.utils import trunc_normal_
27 |
28 |
29 | def drop_path(x, drop_prob: float = 0., training: bool = False):
30 | if drop_prob == 0. or not training:
31 | return x
32 | keep_prob = 1 - drop_prob
33 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
34 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
35 | random_tensor.floor_() # binarize
36 | output = x.div(keep_prob) * random_tensor
37 | return output
38 |
39 |
40 | class DropPath(nn.Module):
41 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
42 | """
43 |
44 | def __init__(self, drop_prob=None):
45 | super(DropPath, self).__init__()
46 | self.drop_prob = drop_prob
47 |
48 | def forward(self, x):
49 | return drop_path(x, self.drop_prob, self.training)
50 |
51 |
52 | class Mlp(nn.Module):
53 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
54 | super().__init__()
55 | out_features = out_features or in_features
56 | hidden_features = hidden_features or in_features
57 | self.fc1 = nn.Linear(in_features, hidden_features)
58 | self.act = act_layer()
59 | self.fc2 = nn.Linear(hidden_features, out_features)
60 | self.drop = nn.Dropout(drop)
61 |
62 | def forward(self, x):
63 | x = self.fc1(x)
64 | x = self.act(x)
65 | x = self.drop(x)
66 | x = self.fc2(x)
67 | x = self.drop(x)
68 | return x
69 |
70 |
71 | class Attention(nn.Module):
72 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
73 | super().__init__()
74 | self.num_heads = num_heads
75 | head_dim = dim // num_heads
76 | self.scale = qk_scale or head_dim ** -0.5
77 |
78 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
79 | self.attn_drop = nn.Dropout(attn_drop)
80 | self.proj = nn.Linear(dim, dim)
81 | self.proj_drop = nn.Dropout(proj_drop)
82 |
83 | def forward(self, x, return_key=False):
84 | B, N, C = x.shape
85 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
86 | q, k, v = qkv[0], qkv[1], qkv[2]
87 |
88 | attn = (q @ k.transpose(-2, -1)) * self.scale
89 | attn = attn.softmax(dim=-1)
90 | attn = self.attn_drop(attn)
91 |
92 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
93 | x = self.proj(x)
94 | x = self.proj_drop(x)
95 | if not return_key:
96 | return x, attn
97 | else:
98 | return x, attn, k
99 |
100 | class Block(nn.Module):
101 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
102 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
103 | super().__init__()
104 |
105 | self.norm1 = norm_layer(dim)
106 | self.attn = Attention(
107 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
108 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
109 | self.norm2 = norm_layer(dim)
110 | mlp_hidden_dim = int(dim * mlp_ratio)
111 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
112 |
113 | def forward(self, x, return_attention=False, return_key=False):
114 | if return_key:
115 | y, attn, key = self.attn(self.norm1(x), return_key)
116 | else:
117 | y, attn = self.attn(self.norm1(x))
118 | x = x + self.drop_path(y)
119 | x = x + self.drop_path(self.mlp(self.norm2(x)))
120 | if return_attention:
121 | return x, attn
122 | elif return_key:
123 | return x, key, attn
124 | else:
125 | return x
126 |
127 |
128 | class PatchEmbed(nn.Module):
129 | """ Image to Patch Embedding
130 | """
131 |
132 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
133 | super().__init__()
134 | num_patches = (img_size // patch_size) * (img_size // patch_size)
135 | self.img_size = img_size
136 | self.patch_size = patch_size
137 | self.num_patches = num_patches
138 |
139 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
140 |
141 | def forward(self, x):
142 | B, C, H, W = x.shape
143 | x = self.proj(x).flatten(2).transpose(1, 2)
144 | return x
145 |
146 |
147 | class VisionTransformer(nn.Module):
148 | """ Vision Transformer """
149 |
150 | def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12,
151 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
152 | drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs):
153 | super().__init__()
154 | self.num_features = self.embed_dim = embed_dim
155 |
156 | self.patch_embed = PatchEmbed(
157 | img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
158 | num_patches = self.patch_embed.num_patches
159 |
160 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
161 |
162 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
163 | self.pos_drop = nn.Dropout(p=drop_rate)
164 |
165 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
166 | self.blocks = nn.ModuleList([
167 | Block(
168 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
169 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
170 | for i in range(depth)])
171 | self.norm = norm_layer(embed_dim)
172 |
173 | # Classifier head
174 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
175 |
176 | trunc_normal_(self.pos_embed, std=.02)
177 | trunc_normal_(self.cls_token, std=.02)
178 | self.apply(self._init_weights)
179 |
180 | def _init_weights(self, m):
181 | if isinstance(m, nn.Linear):
182 | trunc_normal_(m.weight, std=.02)
183 | if isinstance(m, nn.Linear) and m.bias is not None:
184 | nn.init.constant_(m.bias, 0)
185 | elif isinstance(m, nn.LayerNorm):
186 | nn.init.constant_(m.bias, 0)
187 | nn.init.constant_(m.weight, 1.0)
188 |
189 | def interpolate_pos_encoding(self, x, w, h):
190 | npatch = x.shape[1] - 1
191 | N = self.pos_embed.shape[1] - 1
192 | if npatch == N and w == h:
193 | return self.pos_embed
194 | class_pos_embed = self.pos_embed[:, 0]
195 | patch_pos_embed = self.pos_embed[:, 1:]
196 | dim = x.shape[-1]
197 | w0 = w // self.patch_embed.patch_size
198 | h0 = h // self.patch_embed.patch_size
199 | # we add a small number to avoid floating point error in the interpolation
200 | # see discussion at https://github.com/facebookresearch/dino/issues/8
201 | w0, h0 = w0 + 0.1, h0 + 0.1
202 | patch_pos_embed = nn.functional.interpolate(
203 | patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
204 | scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
205 | mode='bicubic',
206 | )
207 | assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
208 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
209 | return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
210 |
211 | def prepare_tokens(self, x):
212 | B, nc, w, h = x.shape
213 | x = self.patch_embed(x) # patch linear embedding
214 |
215 | # add the [CLS] token to the embed patch tokens
216 | cls_tokens = self.cls_token.expand(B, -1, -1)
217 | x = torch.cat((cls_tokens, x), dim=1)
218 |
219 | # add positional encoding to each token
220 | x = x + self.interpolate_pos_encoding(x, w, h)
221 |
222 | return self.pos_drop(x)
223 |
224 | def forward(self, x, return_attention=False):
225 | if return_attention:
226 | atten_weights = []
227 | x = self.prepare_tokens(x)
228 | for blk_ in self.blocks:
229 | x, weights = blk_(x, return_attention)
230 | atten_weights.append(weights)
231 | x = self.norm(x)
232 | return x, atten_weights
233 |
234 | else:
235 | x = self.prepare_tokens(x)
236 | for blk_ in self.blocks:
237 | x = blk_(x)
238 | x = self.norm(x)
239 | return x
240 |
241 | def get_last_selfattention(self, x):
242 | x = self.prepare_tokens(x)
243 | for i, blk in enumerate(self.blocks):
244 | if i < len(self.blocks) - 1:
245 | x = blk(x)
246 | else:
247 | # return attention of the last block
248 | return blk(x, return_attention=True)
249 |
250 | def get_intermediate_layers(self, x, n=1):
251 | x = self.prepare_tokens(x)
252 | # we return the output tokens from the `n` last blocks
253 | output = []
254 | for i, blk in enumerate(self.blocks):
255 | x = blk(x)
256 | if len(self.blocks) - i <= n:
257 | output.append(self.norm(x))
258 | return output
259 |
260 | def get_last_key(self, x, extra_layer=None):
261 | x = self.prepare_tokens(x)
262 | key_mid = 0
263 | for i, blk in enumerate(self.blocks):
264 |             if extra_layer is not None and i == extra_layer:
265 | x, key, attn = blk(x, return_key=True)
266 | key_mid = key
267 | elif i < len(self.blocks) - 1:
268 | x = blk(x)
269 | else:
270 | # return attention of the last block
271 | x, key, attn = blk(x, return_key=True)
272 |         if extra_layer is None:
273 | return x, key, attn
274 | else:
275 | return key_mid, x, key, attn
276 |
277 | def get_all_key(self, x):
278 | x = self.prepare_tokens(x)
279 | key_mid = 0
280 | keys=[]
281 | xs=[]
282 | for i, blk in enumerate(self.blocks):
283 | if i < len(self.blocks) - 1:
284 | x, key, attn = blk(x, return_key=True)
285 | keys.append(key)
286 | xs.append(x)
287 | else:
288 | # return attention of the last block
289 | x, key, attn = blk(x, return_key=True)
290 | keys.append(key)
291 | xs.append(x)
292 | return xs, keys, attn
293 |
294 |
295 | def vit_tiny(patch_size=16, **kwargs):
296 | model = VisionTransformer(
297 | patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4,
298 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
299 | return model
300 |
301 |
302 | def vit_small(patch_size=16, **kwargs):
303 | model = VisionTransformer(
304 | patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
305 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
306 | return model
307 |
308 |
309 | def vit_base(patch_size=16, **kwargs):
310 | model = VisionTransformer(
311 | patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
312 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
313 | return model
314 |
315 |
316 | class DINOHead(nn.Module):
317 | def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048,
318 | bottleneck_dim=256):
319 | super().__init__()
320 | nlayers = max(nlayers, 1)
321 | if nlayers == 1:
322 | self.mlp = nn.Linear(in_dim, bottleneck_dim)
323 | else:
324 | layers = [nn.Linear(in_dim, hidden_dim)]
325 | if use_bn:
326 | layers.append(nn.BatchNorm1d(hidden_dim))
327 | layers.append(nn.GELU())
328 | for _ in range(nlayers - 2):
329 | layers.append(nn.Linear(hidden_dim, hidden_dim))
330 | if use_bn:
331 | layers.append(nn.BatchNorm1d(hidden_dim))
332 | layers.append(nn.GELU())
333 | layers.append(nn.Linear(hidden_dim, bottleneck_dim))
334 | self.mlp = nn.Sequential(*layers)
335 | self.apply(self._init_weights)
336 | self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
337 | self.last_layer.weight_g.data.fill_(1)
338 | if norm_last_layer:
339 | self.last_layer.weight_g.requires_grad = False
340 |
341 | def _init_weights(self, m):
342 | if isinstance(m, nn.Linear):
343 | trunc_normal_(m.weight, std=.02)
344 | if isinstance(m, nn.Linear) and m.bias is not None:
345 | nn.init.constant_(m.bias, 0)
346 |
347 | def forward(self, x):
348 | x = self.mlp(x)
349 | x = nn.functional.normalize(x, dim=-1, p=2)
350 | x = self.last_layer(x)
351 | return x
352 |
--------------------------------------------------------------------------------
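
Usage note: a small shape sketch to make the key/attention hooks above concrete. The random input is a stand-in; the shapes assume a 224x224 image and the vit_small configuration (384-dim embeddings split over 6 heads):

    import torch

    from models.dino import vision_transformer as vits

    vit = vits.vit_small(patch_size=16, num_classes=0)    # 12 blocks, embed_dim=384, 6 heads
    vit.eval()

    x = torch.randn(2, 3, 224, 224)                       # dummy image batch
    with torch.no_grad():
        xs, keys, attn = vit.get_all_key(x)

    # 224 / 16 = 14, so 14*14 patch tokens + 1 [CLS] token = 197 tokens per image
    print(len(keys))         # 12 key tensors, one per block
    print(keys[-1].shape)    # torch.Size([2, 6, 197, 64])   (batch, heads, tokens, head_dim)
    print(attn.shape)        # torch.Size([2, 6, 197, 197])  attention of the last block

    # model.py folds the head dimension back into the 384-dim embedding:
    desc = keys[-1].permute(0, 2, 3, 1).flatten(-2, -1)   # (2, 197, 384)
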
/models/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from models.dino import vision_transformer as vits
4 | from models.dino.utils import load_pretrained_weights
5 | import numpy as np
6 | from torch.nn import functional as F
7 | from AIM.modules.burger import HamburgerV1
8 | from collections import OrderedDict
9 | from pkg_resources import packaging
10 | from simple_tokenizer import SimpleTokenizer as _Tokenizer
11 | from typing import Any, Union, List
12 |
13 | _tokenizer = _Tokenizer()
14 |
15 | class QuickGELU(nn.Module):
16 | def forward(self, x: torch.Tensor):
17 | return x * torch.sigmoid(1.702 * x)
18 |
19 | class ResidualAttentionBlock(nn.Module):
20 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
21 | super().__init__()
22 |
23 | self.attn = nn.MultiheadAttention(d_model, n_head)
24 | self.ln_1 = nn.LayerNorm(d_model)
25 | self.mlp = nn.Sequential(OrderedDict([
26 | ("c_fc", nn.Linear(d_model, d_model * 4)),
27 | ("gelu", QuickGELU()),
28 | ("c_proj", nn.Linear(d_model * 4, d_model))
29 | ]))
30 | self.ln_2 = nn.LayerNorm(d_model)
31 | self.attn_mask = attn_mask
32 |
33 | def attention(self, x: torch.Tensor):
34 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
35 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
36 |
37 | def forward(self, x: torch.Tensor):
38 | x = x + self.attention(self.ln_1(x))
39 | x = x + self.mlp(self.ln_2(x))
40 | return x
41 |
42 | class Transformer(nn.Module):
43 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
44 | super().__init__()
45 | self.width = width
46 | self.layers = layers
47 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
48 |
49 | def forward(self, x: torch.Tensor):
50 | return self.resblocks(x)
51 |
52 |
53 | class Mlp(nn.Module):
54 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
55 | super(Mlp, self).__init__()
56 | out_features = out_features or in_features
57 | hidden_features = hidden_features or in_features
58 | self.norm = nn.LayerNorm(in_features)
59 | self.fc1 = nn.Linear(in_features, hidden_features)
60 | self.act = act_layer()
61 | self.fc2 = nn.Linear(hidden_features, out_features)
62 | self.drop = nn.Dropout(drop)
63 |
64 | def forward(self, x):
65 | x = self.norm(x)
66 | x = self.fc1(x)
67 | x = self.act(x)
68 | x = self.drop(x)
69 | x = self.fc2(x)
70 | x = self.drop(x)
71 | return x
72 |
73 | class PromptLearner(nn.Module):
74 | def __init__(self, classnames, ln_final, token_embedding):
75 | super().__init__()
76 | n_cls = len(classnames)
77 | n_ctx = 16
78 | ctx_init = ""
79 | dtype = ln_final.weight.dtype
80 | ctx_dim = ln_final.weight.shape[0]
81 |
82 | if ctx_init:
83 | # use given words to initialize context vectors
84 | ctx_init = ctx_init.replace("_", " ")
85 | n_ctx = len(ctx_init.split(" "))
86 | prompt = tokenize(ctx_init)
87 | with torch.no_grad():
88 | embedding = token_embedding(prompt).type(dtype)
89 | ctx_vectors = embedding[0, 1 : 1 + n_ctx, :]
90 | prompt_prefix = ctx_init
91 |
92 | else:
93 | print("Initializing a generic context")
94 | ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)
95 | nn.init.normal_(ctx_vectors, std=0.02)
96 | prompt_prefix = " ".join(["X"] * n_ctx)
97 |
98 | self.ctx = nn.Parameter(ctx_vectors) # to be optimized
99 |
100 | classnames = [name.replace("_", " ") for name in classnames]
101 | name_lens = [len(_tokenizer.encode(name)) for name in classnames]
102 | prompts = [prompt_prefix + " " + name + "." for name in classnames]
103 |
104 | tokenized_prompts = torch.cat([tokenize(p) for p in prompts])
105 | with torch.no_grad():
106 | embedding = token_embedding(tokenized_prompts).type(dtype)
107 |
108 | # These token vectors will be saved when in save_model(),
109 | # but they should be ignored in load_model() as we want to use
110 | # those computed using the current class names
111 | self.register_buffer("token_prefix", embedding[:, :1, :]) # SOS
112 | self.register_buffer("token_suffix", embedding[:, 1 + n_ctx :, :]) # CLS, EOS
113 |
114 | self.n_cls = n_cls
115 | self.n_ctx = n_ctx
116 | self.tokenized_prompts = tokenized_prompts # torch.Tensor
117 | self.name_lens = name_lens
118 | self.class_token_position = 'end'
119 |
120 | def forward(self):
121 | ctx = self.ctx
122 | if ctx.dim() == 2:
123 | ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)
124 |
125 | prefix = self.token_prefix
126 | suffix = self.token_suffix
127 |
128 | if self.class_token_position == "end":
129 | prompts = torch.cat(
130 | [
131 | prefix, # (n_cls, 1, dim)
132 | ctx, # (n_cls, n_ctx, dim)
133 | suffix, # (n_cls, *, dim)
134 | ],
135 | dim=1,
136 | )
137 |
138 | elif self.class_token_position == "middle":
139 | half_n_ctx = self.n_ctx // 2
140 | prompts = []
141 | for i in range(self.n_cls):
142 | name_len = self.name_lens[i]
143 | prefix_i = prefix[i : i + 1, :, :]
144 | class_i = suffix[i : i + 1, :name_len, :]
145 | suffix_i = suffix[i : i + 1, name_len:, :]
146 | ctx_i_half1 = ctx[i : i + 1, :half_n_ctx, :]
147 | ctx_i_half2 = ctx[i : i + 1, half_n_ctx:, :]
148 | prompt = torch.cat(
149 | [
150 | prefix_i, # (1, 1, dim)
151 | ctx_i_half1, # (1, n_ctx//2, dim)
152 | class_i, # (1, name_len, dim)
153 | ctx_i_half2, # (1, n_ctx//2, dim)
154 | suffix_i, # (1, *, dim)
155 | ],
156 | dim=1,
157 | )
158 | prompts.append(prompt)
159 | prompts = torch.cat(prompts, dim=0)
160 |
161 | elif self.class_token_position == "front":
162 | prompts = []
163 | for i in range(self.n_cls):
164 | name_len = self.name_lens[i]
165 | prefix_i = prefix[i : i + 1, :, :]
166 | class_i = suffix[i : i + 1, :name_len, :]
167 | suffix_i = suffix[i : i + 1, name_len:, :]
168 | ctx_i = ctx[i : i + 1, :, :]
169 | prompt = torch.cat(
170 | [
171 | prefix_i, # (1, 1, dim)
172 | class_i, # (1, name_len, dim)
173 | ctx_i, # (1, n_ctx, dim)
174 | suffix_i, # (1, *, dim)
175 | ],
176 | dim=1,
177 | )
178 | prompts.append(prompt)
179 | prompts = torch.cat(prompts, dim=0)
180 |
181 | else:
182 | raise ValueError
183 |
184 | return prompts
185 |
186 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
187 | if isinstance(texts, str):
188 | texts = [texts]
189 |
190 | sot_token = _tokenizer.encoder["<|startoftext|>"]
191 | eot_token = _tokenizer.encoder["<|endoftext|>"]
192 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
193 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
194 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
195 | else:
196 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
197 |
198 | for i, tokens in enumerate(all_tokens):
199 | if len(tokens) > context_length:
200 | if truncate:
201 | tokens = tokens[:context_length]
202 | tokens[-1] = eot_token
203 | else:
204 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
205 | result[i, :len(tokens)] = torch.tensor(tokens)
206 |
207 | return result
208 |
209 |
210 | class AttentionPool2d(nn.Module):
211 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
212 | super().__init__()
213 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
214 | self.k_proj = nn.Linear(embed_dim, embed_dim)
215 | self.q_proj = nn.Linear(embed_dim, embed_dim)
216 | self.v_proj = nn.Linear(embed_dim, embed_dim)
217 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
218 | self.num_heads = num_heads
219 |
220 | def forward(self, x):
221 | x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC
222 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
223 | x_1 = x + self.positional_embedding[:, None, :].to(x.dtype)
224 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
225 | x, att = F.multi_head_attention_forward(
226 | query=x, key=x, value=x,
227 | embed_dim_to_check=x.shape[-1],
228 | num_heads=self.num_heads,
229 | q_proj_weight=self.q_proj.weight,
230 | k_proj_weight=self.k_proj.weight,
231 | v_proj_weight=self.v_proj.weight,
232 | in_proj_weight=None,
233 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
234 | bias_k=None,
235 | bias_v=None,
236 | add_zero_attn=False,
237 | dropout_p=0,
238 | out_proj_weight=self.c_proj.weight,
239 | out_proj_bias=self.c_proj.bias,
240 | use_separate_proj_weight=True,
241 | training=self.training,
242 | need_weights=True
243 | )
244 | return x[0], x[1:], att[:, 1:, 1:]
245 |
246 | class Cross_Attention(nn.Module):
247 | def __init__(self, in_dim, out_dim):
248 | super().__init__()
249 | self.in_dim = in_dim
250 | self.out_dim = out_dim
251 | self.proj_q = nn.Linear(in_dim, out_dim)
252 | self.proj_k = nn.Linear(in_dim, out_dim)
253 | self.proj_v = nn.Linear(in_dim, out_dim)
254 | self.scale = self.out_dim ** (-0.5)
255 |
256 | self.norm = nn.LayerNorm(self.in_dim)
257 | def forward(self, ego, exo):
258 |
259 | B, hw, C = ego.size()
260 | query = self.proj_q(ego)
261 | key = self.proj_k(exo)
262 | value = self.proj_v(exo)
263 |
264 | att = torch.bmm(query, key.transpose(1, 2))*self.scale
265 | att = att.softmax(dim=-1)
266 | out = torch.bmm(att, value)
267 |
268 | out = self.norm(out+ego)
269 |
270 | return out.transpose(1, 2).view(B, C, 14, 14)
271 |
272 | class Model(nn.Module):
273 | def __init__(self, args, embed_dim:int, context_length: int, vocab_size: int,
274 | transformer_width: int, transformer_heads: int,
275 | transformer_layers: int, num_classes=36, pretrained=True, n=3, D=512):
276 | super(Model, self).__init__()
277 | self.num_classes = num_classes
278 | self.criterion = nn.CrossEntropyLoss()
279 | self.pretrained = pretrained
280 | self.n = n
281 | self.D = D
282 | self.context_length = context_length
283 | if args.divide == "Seen":
284 | self.classnames = ['beat', "boxing", "brush_with", "carry", "catch",
285 | "cut", "cut_with", "drag", 'drink_with', "eat",
286 | "hit", "hold", "jump", "kick", "lie_on", "lift",
287 | "look_out", "open", "pack", "peel", "pick_up",
288 | "pour", "push", "ride", "sip", "sit_on", "stick",
289 | "stir", "swing", "take_photo", "talk_on", "text_on",
290 | "throw", "type_on", "wash", "write"]
291 | elif args.divide=="Unseen":
292 | self.classnames = ["carry", "catch", "cut", "cut_with", 'drink_with',
293 | "eat", "hit", "hold", "jump", "kick", "lie_on", "open", "peel",
294 | "pick_up", "pour", "push", "ride", "sip", "sit_on", "stick",
295 | "swing", "take_photo", "throw", "type_on", "wash"]
296 | else: # HICO-IIF
297 | self.classnames = ['cut_with', 'drink_with', 'hold', 'open', 'pour', 'sip', 'stick', 'stir', 'swing', 'type_on']
298 |
299 | # dino-vit
300 | self.vit_feat_dim = 384
301 | self.cluster_num = 3
302 | self.stride = 16
303 | self.patch = 16
304 | self.Hamburger = HamburgerV1(in_c=self.vit_feat_dim, n=self.n, D=self.D)
305 | self.vit_model = vits.__dict__['vit_small'](patch_size=self.patch, num_classes=0)
306 | load_pretrained_weights(self.vit_model, '', None, 'vit_small', self.patch)
307 |
308 | self.aff_proj = Mlp(in_features=int(self.vit_feat_dim*2), hidden_features=int(self.vit_feat_dim*2), out_features=self.vit_feat_dim,
309 | act_layer=nn.GELU, drop=0.)
310 | self.aff_ego_proj = nn.Sequential(
311 | nn.Conv2d(self.vit_feat_dim, self.vit_feat_dim, kernel_size=3, stride=1, padding=1),
312 | nn.BatchNorm2d(self.vit_feat_dim),
313 | nn.ReLU(True),
314 | )
315 | self.aff_exo_proj = nn.Sequential(
316 | nn.Conv2d(self.vit_feat_dim, self.vit_feat_dim, kernel_size=3, stride=1, padding=1),
317 | nn.BatchNorm2d(self.vit_feat_dim),
318 | nn.ReLU(True),
319 | )
320 |
321 | # clip
322 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
323 | self.attnpool = AttentionPool2d(14, self.vit_feat_dim, 64, embed_dim)
324 | self.transformer = Transformer(
325 | width=transformer_width,
326 | layers=transformer_layers,
327 | heads=transformer_heads,
328 | attn_mask=self.build_attention_mask()
329 | )
330 | self.token_embedding = nn.Embedding(vocab_size, transformer_width)
331 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
332 | self.ln_final = nn.LayerNorm(transformer_width)
333 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
334 |
335 | self.prompt_learner = PromptLearner(self.classnames, self.ln_final, self.token_embedding)
336 | self.tokenized_prompts = self.prompt_learner.tokenized_prompts
337 |
338 | # fc
339 | self.fc = nn.Linear(self.vit_feat_dim, self.num_classes)
340 | self.avgpool = nn.AdaptiveAvgPool2d(1)
341 |
342 | def encode_text(self, per, text):
343 | x = per + self.token_embedding(text.cuda()).float() # [batch_size, n_ctx, d_model]
344 |
345 | x = x + self.positional_embedding.float()
346 | x = x.permute(1, 0, 2) # NLD -> LND
347 | x = self.transformer(x)
348 | x = x.permute(1, 0, 2) # LND -> NLD
349 | x = self.ln_final(x).float()
350 |
351 | # x.shape = [batch_size, n_ctx, transformer.width]
352 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
353 | return x
354 |
355 | def forward(self, exocentric, egocentric_image, label, text):
356 | target = label.long().squeeze()
357 | b, n, c, h, w = exocentric.size()
358 |         exocentric_input = exocentric.view(b * n, c, h, w)
359 |
360 | # dino_vit
361 | with torch.no_grad():
362 | _, ego_key, ego_attn = self.vit_model.get_all_key(egocentric_image)
363 |             _, exo_key, exo_attn = self.vit_model.get_all_key(exocentric_input)
364 | ego_desc = ego_key[len(ego_key)-2].permute(0, 2, 3, 1).flatten(-2, -1).detach()
365 | exo_desc = exo_key[len(ego_key)-2].permute(0, 2, 3, 1).flatten(-2, -1).detach()
366 | for i in range(len(ego_key)-1, len(ego_key)):
367 | ego_desc = torch.cat((ego_desc, ego_key[i].permute(0, 2, 3, 1).flatten(-2, -1).detach()), dim=2)
368 | exo_desc = torch.cat((exo_desc, exo_key[i].permute(0, 2, 3, 1).flatten(-2, -1).detach()), dim=2)
369 |
370 | ego_proj = self.aff_proj(ego_desc[:, 1:])
371 | exo_proj = self.aff_proj(exo_desc[:, 1:])
372 | ego_proj = self._reshape_transform(ego_proj, self.patch, self.stride)
373 | exo_proj = self._reshape_transform(exo_proj, self.patch, self.stride)
374 |
375 | exo_proj = self.Hamburger(exo_proj)
376 |
377 | # text branch
378 | prompts = self.prompt_learner()
379 | tokenized_prompts = self.tokenized_prompts
380 | text_features = self.encode_text(prompts, tokenized_prompts)
381 |
382 | e_b, e_c, e_h, e_w = ego_proj.shape
383 | pre_ego = ego_proj
384 | image_features, ego_proj, mu_att = self.attnpool(ego_proj)
385 |
386 | image_features = F.normalize(image_features, dim=1, p=2)
387 |
388 | text_features = F.normalize(text_features, dim=1, p=2)
389 |
390 |
391 | # Pixel-Text Fusion
392 | logit_scale = self.logit_scale.exp()
393 | self.logits_per_image = logit_scale * image_features @ text_features.t()
394 | self.logits_per_text = self.logits_per_image.t()
395 |
396 | text_f = torch.ones((e_b, 1024)).cuda()
397 | for i in range(e_b):
398 | text_f[i] = text_features[label[i]]
399 | att_egoproj = F.normalize(ego_proj, dim=1, p=2)
400 | attego = logit_scale *att_egoproj.permute(1, 0, 2)@text_f.unsqueeze(2)
401 | attego = torch.sigmoid(F.normalize(attego, dim=1, p=2)).permute(1, 0, 2).repeat(1, 1, e_c)
402 | ego_proj = attego.permute(1, 2, 0).view(e_b, e_c, e_h, e_w)*pre_ego + pre_ego
403 |
404 | exocentric_branch =self.aff_exo_proj(exo_proj)
405 | egocentric_branch =self.aff_ego_proj(ego_proj)
406 |
407 | # cls
408 | exo_pool = self.avgpool(exocentric_branch)
409 | exo_pool = exo_pool.view(exo_pool.size(0), -1)
410 | self.exo_score = self.fc(exo_pool)
411 |
412 | batch, channel, h, w = exocentric_branch.shape
413 | exocentric_branch = exocentric_branch.view(batch//3, 3, channel, h, w).mean(1)
414 | batch = batch//3
415 |
416 | exo_weight = self.fc.weight[target]
417 | exo_weight = exo_weight.view(batch, channel, 1, 1).expand_as(exocentric_branch)
418 | self.exo_feature = (exo_weight * exocentric_branch)
419 |
420 | self.exo_features = torch.ones(batch, self.num_classes, e_h, e_w).cuda()
421 | label_sum = torch.ones_like(label).cuda()
422 | for m in range(0,self.num_classes):
423 | weight = self.fc.weight[label_sum.long()*m]
424 | weight = weight.view(batch, channel, 1, 1).expand_as(exocentric_branch)
425 | self.exo_features[:,m] = (weight * exocentric_branch).mean(1)
426 |
427 | ego_pool = self.avgpool(egocentric_branch)
428 | ego_pool = ego_pool.view(ego_pool.size(0), -1)
429 | self.ego_score = self.fc(ego_pool)
430 |
431 | ego_weight = self.fc.weight[target]
432 | ego_weight = ego_weight.view(batch, channel, 1, 1).expand_as(egocentric_branch)
433 | self.ego_feature = (ego_weight * egocentric_branch)
434 |
435 | self.ego_features = torch.ones(batch, self.num_classes, e_h, e_w).cuda()
436 | label_sum = torch.ones_like(label).cuda()
437 | for m in range(0,self.num_classes):
438 | gweight = self.fc.weight[label_sum.long()*m]
439 | gweight = gweight.view(batch, channel, 1, 1).expand_as(egocentric_branch)
440 | self.ego_features[:,m] = (gweight * egocentric_branch).mean(1)
441 |
442 | # l_rela
443 | self.exo_att = self.exo_features.view(batch, self.num_classes, -1).transpose(1, 2)
444 | self.exo_att = torch.matmul(self.exo_features.view(batch, self.num_classes, -1), self.exo_att)
445 | self.ego_att = self.ego_features.view(batch, self.num_classes, -1).transpose(1, 2)
446 | self.ego_att = torch.matmul(self.ego_features.view(batch, self.num_classes, -1), self.ego_att)
447 |
448 | return self.exo_score, self.ego_score, self.logits_per_text, self.logits_per_image
449 |
450 | @torch.no_grad()
451 | def get(self, egocentric_image, label, text):
452 |
453 | # dino_vit
454 | _, ego_key, ego_attn = self.vit_model.get_all_key(egocentric_image)
455 | ego_desc = ego_key[len(ego_key)-2].permute(0, 2, 3, 1).flatten(-2, -1).detach()
456 | for i in range(len(ego_key)-1, len(ego_key)):
457 | ego_desc = torch.cat((ego_desc, ego_key[i].permute(0, 2, 3, 1).flatten(-2, -1).detach()), dim=2)
458 | ego_proj = self.aff_proj(ego_desc[:, 1:])
459 | ego_proj = self._reshape_transform(ego_proj, self.patch, self.stride)
460 |
461 | # text branch
462 | prompts = self.prompt_learner()
463 | tokenized_prompts = self.tokenized_prompts
464 | text_features = self.encode_text(prompts, tokenized_prompts)
465 |
466 | e_b, e_c, e_h, e_w = ego_proj.shape
467 | pre_ego = ego_proj
468 | image_features, ego_proj, mu_att = self.attnpool(ego_proj)
469 |
470 | image_features = F.normalize(image_features, dim=1, p=2)
471 | text_features = F.normalize(text_features, dim=1, p=2)
472 | logit_scale = self.logit_scale.exp()
473 | logits_per_image = logit_scale * image_features @ text_features.t()
474 | logits_per_text = logits_per_image.t()
475 |
476 | text_f = torch.ones((e_b, 1024)).cuda()
477 | for i in range(e_b):
478 | text_f[i] = text_features[label[i]]
479 | att_egoproj = F.normalize(ego_proj, dim=1, p=2)
480 | attego = att_egoproj.permute(1, 0, 2)@text_f.unsqueeze(2)
481 | attego = torch.sigmoid(F.normalize(attego, dim=1, p=2)).permute(1, 0, 2).repeat(1, 1, e_c)
482 | ego_proj = attego.permute(1, 2, 0).view(e_b, e_c, e_h, e_w)*pre_ego + pre_ego
483 |
484 | mu_att = mu_att / torch.sum(mu_att, dim=1, keepdim=True)
485 | mu_att = mu_att / torch.sum(mu_att, dim=2, keepdim=True)
486 | for _ in range(2):
487 | mu_att = mu_att / torch.sum(mu_att, dim=1, keepdim=True)
488 | mu_att = mu_att / torch.sum(mu_att, dim=2, keepdim=True)
489 | mu_att = (mu_att + mu_att.permute(0, 2, 1)) / 2
490 | mu_att = torch.matmul(mu_att, mu_att)
491 |
492 | egocentric_branch =self.aff_ego_proj(ego_proj)
493 |
494 | ego_pool = self.avgpool(egocentric_branch)
495 | ego_pool = ego_pool.view(ego_pool.size(0), -1)
496 | self.ego_score = self.fc(ego_pool)
497 |
498 | target = label.long().squeeze()
499 | batch, channel,_,_ = egocentric_branch.shape
500 |
501 | cam_weight = self.fc.weight[target]
502 | cam_weight = cam_weight.view(batch, channel, 1, 1).expand_as(egocentric_branch)
503 | cam = (cam_weight * egocentric_branch).mean(1)
504 |
505 | cam1 = mu_att@(cam.view(batch, -1, 1))
506 | cam1 = cam1.view(batch, e_h, e_w)
507 |
508 | return cam, cam1
509 |
510 | def get_loss(self, gt_label, separate=False):
511 | # loss L_cls
512 | b, h = self.exo_score.shape
513 | self.exo_score = self.exo_score.view(b//3, -1, h)
514 | loss_cls = (self.criterion(self.ego_score, gt_label)+ (self.criterion(self.exo_score[:, 0], gt_label)
515 | + self.criterion(self.exo_score[:, 1], gt_label)
516 | + self.criterion(self.exo_score[:, 2], gt_label))/3)
517 | # loss L_d
518 | exo_branch, ego_branch = self.exo_feature,self.ego_feature
519 | exo_pool = F.adaptive_avg_pool2d(exo_branch, 1).view(exo_branch.size(0), -1)
520 | exo_pool = F.normalize(exo_pool, 2, dim=1)
521 | ego_pool = F.adaptive_avg_pool2d(ego_branch, 1).view(ego_branch.size(0), -1)
522 | ego_pool = F.normalize(ego_pool, 2, dim=1)
523 | loss_dist =0.5 * (((exo_pool - ego_pool) ** 2).sum(1)).mean(0)
524 |
525 | # loss L_lrela
526 | exo_att, ego_att = self.exo_att,self.ego_att
527 | attb, c, _ = exo_att.size()
528 | exo_att = F.normalize(exo_att, 2, dim=2).view(attb, -1)
529 | ego_att = F.normalize(ego_att, 2, dim=2).view(attb, -1)
530 | loss_att = 0.5*(1 - F.cosine_similarity(exo_att, ego_att, dim=1).mean(0))
531 |
532 | return loss_cls, loss_dist, loss_att
533 |
534 | def _reshape_transform(self, tensor, patch_size, stride):
535 | height = (224 - patch_size) // stride + 1
536 | width = (224 - patch_size) // stride + 1
537 | result = tensor.reshape(tensor.size(0), height, width, tensor.size(-1))
538 | result = result.transpose(2, 3).transpose(1, 2).contiguous()
539 | return result
540 |
541 | def build_attention_mask(self):
542 | # lazily create causal attention mask, with full attention between the vision tokens
543 |
544 | mask = torch.empty(self.context_length, self.context_length)
545 | mask.fill_(float("-inf"))
546 | mask.triu_(1) # zero out the lower diagonal
547 | return mask
548 |
549 | def MODEL(args, num_classes=36,
550 | pretrained=True, n=3, D=512):
551 |     clip_path = 'RN50.pt'  # CLIP's pre-trained model
552 |     state_dict = torch.jit.load(clip_path)
553 | state_dict = state_dict.state_dict()
554 |
555 | embed_dim = state_dict["text_projection"].shape[1]
556 | context_length = state_dict["positional_embedding"].shape[0]
557 | vocab_size = state_dict["token_embedding.weight"].shape[0]
558 | transformer_width = state_dict["ln_final.weight"].shape[0]
559 | transformer_heads = transformer_width // 64
560 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
561 |
562 | model = Model(args, embed_dim=embed_dim, context_length=context_length, vocab_size=vocab_size, transformer_width=transformer_width,
563 | transformer_heads=transformer_heads,transformer_layers=transformer_layers, num_classes=num_classes, pretrained=pretrained, n=n, D=D)
564 |
565 | model_dict = model.state_dict()
566 | par = []
567 | pretrained_dict = {}
568 | for para in model.named_parameters():
569 | k = para[0]
570 | if k in state_dict:
571 | par.append(para[0])
572 | for k, v in state_dict.items():
573 | if k in model_dict:
574 | pretrained_dict[k] = v
575 |
576 | model_dict.update(pretrained_dict)
577 | model.load_state_dict(model_dict)
578 | return model, par
579 |
--------------------------------------------------------------------------------
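
Usage note: a hedged sketch of driving MODEL() above end to end. The args object, batch shapes and environment are assumptions inferred from the forward pass (three exocentric views per sample, 224x224 inputs, CLIP's RN50.pt in the working directory, network access for the DINO ViT-S/16 weights, and a CUDA device, since the model calls .cuda() internally); it is not an excerpt from train.py:

    import argparse
    import torch

    from models.model import MODEL

    args = argparse.Namespace(divide="Seen")            # selects the 36-class "Seen" split
    model, clip_params = MODEL(args, num_classes=36)    # loads RN50.pt and downloads the DINO ViT-S/16 weights
    model = model.cuda()

    exo = torch.randn(2, 3, 3, 224, 224).cuda()         # (batch, 3 exocentric views, C, H, W)
    ego = torch.randn(2, 3, 224, 224).cuda()            # egocentric images
    label = torch.randint(0, 36, (2,)).cuda()           # image-level affordance labels

    exo_score, ego_score, logits_per_text, logits_per_image = model(exo, ego, label, None)
    loss_cls, loss_dist, loss_att = model.get_loss(label)

    # at inference time, class-activation maps for the egocentric branch:
    cam, cam_refined = model.get(ego, label, None)
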
/preprocessing.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import cv2
4 | import torch
5 | gt_root = "your_path here"  # root directory that holds the ground-truth masks
6 | files = os.listdir(gt_root)
7 | dict_1 = {}
8 | for file in files:
9 |     file_path = os.path.join(gt_root, file)
10 |     objs = os.listdir(file_path)
11 |     for obj in objs:
12 |         obj_path = os.path.join(file_path, obj)
13 |         images = os.listdir(obj_path)
14 |         for img in images:
15 |             img_path = os.path.join(obj_path, img)
16 |             mask = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
17 |             key = file + "_" + obj + "_" + img
18 |             dict_1[key] = mask  # cache every mask under "<dir>_<object>_<image>"
19 | torch.save(dict_1, "filename.t7")  # one file holding all ground-truth masks
20 |
--------------------------------------------------------------------------------
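
Usage note: the script above caches every ground-truth mask in a single dictionary keyed by "<dir>_<object>_<image>" and serializes it with torch.save. A minimal sketch of reading it back (the file name matches the placeholder used above):

    import torch

    gt_masks = torch.load("filename.t7")   # dict: "<dir>_<object>_<image>" -> grayscale mask from cv2.imread
    print(len(gt_masks))                    # number of cached masks
    key = next(iter(gt_masks))
    print(key, gt_masks[key].shape)         # e.g. an (H, W) uint8 array
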
/save_models/README.md:
--------------------------------------------------------------------------------
1 | # Weakly Supervised Multimodal Affordance Grounding for Egocentric Images
2 | ## Paper
3 | >Weakly Supervised Multimodal Affordance Grounding for Egocentric Images (AAAI 2024)
4 |
5 | Link: https://doi.org/10.1609/aaai.v38i6.28451
6 |
7 | >Appendix
8 |
9 | Link:
10 |
11 | >Video
12 |
13 | Link:
14 |
15 |
16 |
17 | **Abstract:**
18 |
19 | To enhance the interaction between intelligent systems and the environment, locating the affordance regions of objects is crucial. These regions correspond to specific areas that provide distinct functionalities. Humans often acquire the ability to identify these regions through action demonstrations and verbal instructions. In this paper, we present a novel multimodal framework that extracts affordance knowledge from exocentric images, which depict human-object interactions, as well as from accompanying textual descriptions that describe the performed actions. The extracted knowledge is then transferred to egocentric images. To achieve this goal, we propose the HOI-Transfer Module, which utilizes local perception to disentangle individual actions within exocentric images. This module effectively captures localized features and correlations between actions, leading to valuable affordance knowledge. Additionally, we introduce the Pixel-Text Fusion Module, which fuses affordance knowledge by identifying regions in egocentric images that bear resemblances to the textual features defining affordances. We employ a Weakly Supervised Multimodal Affordance (WSMA) learning approach, utilizing image-level labels for training. Through extensive experiments, we demonstrate the superiority of our proposed method in terms of evaluation metrics and visual results when compared to existing affordance grounding models. Furthermore, ablation experiments confirm the effectiveness of our approach.
20 |
21 | ## Requirements
22 | We run in the following environment:
23 | - A GeForce RTX 3090
24 | - Python (3.8)
25 | - PyTorch (1.10.0)
26 |
27 |
--------------------------------------------------------------------------------
/simple_tokenizer.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import html
3 | import os
4 | from functools import lru_cache
5 |
6 | import ftfy
7 | import regex as re
8 |
9 |
10 | @lru_cache()
11 | def default_bpe():
12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
13 |
14 |
15 | @lru_cache()
16 | def bytes_to_unicode():
17 | """
18 | Returns list of utf-8 byte and a corresponding list of unicode strings.
19 | The reversible bpe codes work on unicode strings.
20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
22 | This is a significant percentage of your normal, say, 32K bpe vocab.
23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
24 | And avoids mapping to whitespace/control characters the bpe code barfs on.
25 | """
26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
27 | cs = bs[:]
28 | n = 0
29 | for b in range(2**8):
30 | if b not in bs:
31 | bs.append(b)
32 | cs.append(2**8+n)
33 | n += 1
34 | cs = [chr(n) for n in cs]
35 | return dict(zip(bs, cs))
36 |
37 |
38 | def get_pairs(word):
39 | """Return set of symbol pairs in a word.
40 | Word is represented as tuple of symbols (symbols being variable-length strings).
41 | """
42 | pairs = set()
43 | prev_char = word[0]
44 | for char in word[1:]:
45 | pairs.add((prev_char, char))
46 | prev_char = char
47 | return pairs
48 |
49 |
50 | def basic_clean(text):
51 | text = ftfy.fix_text(text)
52 | text = html.unescape(html.unescape(text))
53 | return text.strip()
54 |
55 |
56 | def whitespace_clean(text):
57 | text = re.sub(r'\s+', ' ', text)
58 | text = text.strip()
59 | return text
60 |
61 |
62 | class SimpleTokenizer(object):
63 | def __init__(self, bpe_path: str = default_bpe()):
64 | self.byte_encoder = bytes_to_unicode()
65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
67 | merges = merges[1:49152-256-2+1]
68 | merges = [tuple(merge.split()) for merge in merges]
69 | vocab = list(bytes_to_unicode().values())
70 | vocab = vocab + [v+'</w>' for v in vocab]
71 | for merge in merges:
72 | vocab.append(''.join(merge))
73 | vocab.extend(['<|startoftext|>', '<|endoftext|>'])
74 | self.encoder = dict(zip(vocab, range(len(vocab))))
75 | self.decoder = {v: k for k, v in self.encoder.items()}
76 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
79 |
80 | def bpe(self, token):
81 | if token in self.cache:
82 | return self.cache[token]
83 | word = tuple(token[:-1]) + ( token[-1] + '</w>',)
84 | pairs = get_pairs(word)
85 |
86 | if not pairs:
87 | return token+'</w>'
88 |
89 | while True:
90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
91 | if bigram not in self.bpe_ranks:
92 | break
93 | first, second = bigram
94 | new_word = []
95 | i = 0
96 | while i < len(word):
97 | try:
98 | j = word.index(first, i)
99 | new_word.extend(word[i:j])
100 | i = j
101 | except:
102 | new_word.extend(word[i:])
103 | break
104 |
105 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
106 | new_word.append(first+second)
107 | i += 2
108 | else:
109 | new_word.append(word[i])
110 | i += 1
111 | new_word = tuple(new_word)
112 | word = new_word
113 | if len(word) == 1:
114 | break
115 | else:
116 | pairs = get_pairs(word)
117 | word = ' '.join(word)
118 | self.cache[token] = word
119 | return word
120 |
121 | def encode(self, text):
122 | bpe_tokens = []
123 | text = whitespace_clean(basic_clean(text)).lower()
124 | for token in re.findall(self.pat, text):
125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
127 | return bpe_tokens
128 |
129 | def decode(self, tokens):
130 | text = ''.join([self.decoder[token] for token in tokens])
131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
132 | return text
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | from torch.autograd import Variable
4 | from PIL import Image
5 | import torchvision
6 | import torchvision.transforms as transforms
7 | import numpy as np
8 | import matplotlib.pyplot as plt
9 | import cv2
10 | import os
11 | from models.model import MODEL
12 | import argparse
13 | from utils.evaluation import cal_kl, cal_sim, cal_nss
14 |
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument('--data_root', type=str, default='data_path')
17 | parser.add_argument('--phase', type=str, default='test')
18 | parser.add_argument("--divide", type=str, default="Unseen") #"Seen" or "Unseen" or "HICO-IIF"
19 | parser.add_argument("--model_path", type=str, default="save_models_path") # the model weight path
20 | parser.add_argument("--crop_size", type=int, default=224)
21 | parser.add_argument("--batch_size", type=int, default=1)
22 | parser.add_argument('--threshold', type=float, default=0.2)
23 | # parser.add_argument("--init_weights", type=bool, default=False)
24 | parser.add_argument('--num_workers', type=int, default=1)
25 | args = parser.parse_args()
26 |
27 |
28 |
29 | def normalize_map(atten_map):
30 | min_val = np.min(atten_map)
31 | max_val = np.max(atten_map)
32 | atten_norm = (atten_map - min_val) / (max_val - min_val + 1e-10)
33 |
34 | return atten_norm
35 |
36 |
37 | if args.divide == "Seen":
38 | args.num_classes = 36
39 | aff_list = ['beat', "boxing", "brush_with", "carry", "catch", "cut", "cut_with", "drag", 'drink_with',
40 | "eat", "hit", "hold", "jump", "kick", "lie_on", "lift", "look_out", "open", "pack", "peel",
41 | "pick_up", "pour", "push", "ride", "sip", "sit_on", "stick", "stir", "swing", "take_photo",
42 | "talk_on", "text_on", "throw", "type_on", "wash", "write"]
43 | elif args.divide=="Unseen":
44 | aff_list = ["carry", "catch", "cut", "cut_with", 'drink_with',
45 | "eat", "hit", "hold", "jump", "kick", "lie_on", "open", "peel",
46 | "pick_up", "pour", "push", "ride", "sip", "sit_on", "stick",
47 | "swing", "take_photo", "throw", "type_on", "wash"]
48 | else: # HICO-IIF
49 | aff_list = ['cut_with', 'drink_with', 'hold', 'open', 'pour', 'sip', 'stick', 'stir', 'swing', 'type_on']
50 |
51 | args.test_root = os.path.join(args.data_root, args.divide, "testset", "egocentric")
52 | args.mask_root = os.path.join(args.data_root, args.divide, "testset", "GT")
53 |
54 | model, par = MODEL(args, num_classes=len(aff_list), pretrained=False)
55 | model.load_state_dict(torch.load(args.model_path))
56 | model.eval()
57 | model.cuda()
58 |
59 | import datatest
60 |
61 | testset = datatest.TrainData(egocentric_root=args.test_root, crop_size=args.crop_size, divide=args.divide, mask_root=args.mask_root)
62 | MyDataLoader = torch.utils.data.DataLoader(dataset=testset,
63 | batch_size=args.batch_size,
64 | shuffle=False,
65 | num_workers=args.num_workers,
66 | pin_memory=True)
67 |
68 | dict_1 = {}
69 | for step, (image, label, mask_path, name) in enumerate(MyDataLoader):
70 | label = label.cuda(non_blocking=True)
71 | image = image.cuda(
72 | non_blocking=True)
73 | cam, cam1 = model.get(image, label, name)
74 | cam = cam[0].cpu().detach().numpy()
75 | cam1 = cam1[0].cpu().detach().numpy()
76 | cam = normalize_map(cam)
77 | cam1 = normalize_map(cam1)
78 | cam1[cam
--------------------------------------------------------------------------------
/utils/evaluation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | 
4 | 
5 | def cal_kl(pred: np.ndarray, gt: np.ndarray, eps=1e-12) -> np.ndarray:
6 | map1, map2 = pred / (pred.sum() + eps), gt / (gt.sum() + eps)
7 | kld = np.sum(map2 * np.log(map2 / (map1 + eps) + eps))
8 | return kld
9 |
10 |
11 | def cal_sim(pred: np.ndarray, gt: np.ndarray, eps=1e-12) -> np.ndarray:
12 | map1, map2 = pred / (pred.sum() + eps), gt / (gt.sum() + eps)
13 | intersection = np.minimum(map1, map2)
14 |
15 | return np.sum(intersection)
16 |
17 |
18 | def image_binary(image, threshold):
19 | output = np.zeros(image.size).reshape(image.shape)
20 | for xx in range(image.shape[0]):
21 | for yy in range(image.shape[1]):
22 | if (image[xx][yy] > threshold):
23 | output[xx][yy] = 1
24 | return output
25 |
26 |
27 | def cal_nss(pred: np.ndarray, gt: np.ndarray) -> np.ndarray:
28 | pred = pred / 255.0
29 | gt = gt / 255.0
30 | std = np.std(pred)
31 | u = np.mean(pred)
32 |
33 | smap = (pred - u) / std
34 | fixation_map = (gt - np.min(gt)) / (np.max(gt) - np.min(gt) + 1e-12)
35 | fixation_map = image_binary(fixation_map, 0.1)
36 |
37 | nss = smap * fixation_map
38 |
39 | nss = np.sum(nss) / np.sum(fixation_map + 1e-12)
40 |
41 | return nss
42 |
--------------------------------------------------------------------------------
/utils/gtransforms.py:
--------------------------------------------------------------------------------
1 | # Adapted from: https://github.com/yjxiong/tsn-pytorch/blob/master/transforms.py
2 |
3 | import torchvision
4 | import random
5 | from PIL import Image
6 | import numbers
7 | import torch
8 | import torchvision.transforms.functional as F
9 |
10 | class GroupResize(object):
11 | def __init__(self, size, interpolation=Image.BILINEAR):
12 | self.worker = torchvision.transforms.Resize(size, interpolation)
13 |
14 | def __call__(self, img_group):
15 | return [self.worker(img) for img in img_group]
16 |
17 | class GroupRandomCrop(object):
18 | def __init__(self, size):
19 | if isinstance(size, numbers.Number):
20 | self.size = (int(size), int(size))
21 | else:
22 | self.size = size
23 |
24 | def __call__(self, img_group):
25 |
26 | w, h = img_group[0].size
27 | th, tw = self.size
28 |
29 | out_images = list()
30 |
31 | x1 = random.randint(0, w - tw)
32 | y1 = random.randint(0, h - th)
33 |
34 | for img in img_group:
35 | assert(img.size[0] == w and img.size[1] == h)
36 | if w == tw and h == th:
37 | out_images.append(img)
38 | else:
39 | out_images.append(img.crop((x1, y1, x1 + tw, y1 + th)))
40 |
41 | return out_images
42 |
43 | class GroupCenterCrop(object):
44 | def __init__(self, size):
45 | self.worker = torchvision.transforms.CenterCrop(size)
46 |
47 | def __call__(self, img_group):
48 | return [self.worker(img) for img in img_group]
49 |
50 | class GroupRandomHorizontalFlip(object):
51 | def __call__(self, img_group):
52 | if random.random() < 0.5:
53 | img_group = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group]
54 | return img_group
55 |
56 | class ToTensor(object):
57 | def __init__(self):
58 | self.worker = torchvision.transforms.ToTensor()
59 |
60 | def __call__(self, img_group):
61 | img_group = [self.worker(img) for img in img_group]
62 | return torch.stack(img_group, 0)
63 |
64 | class GroupNormalize(object):
65 | def __init__(self, mean, std):
66 | self.mean = mean
67 | self.std = std
68 |
69 | def __call__(self, tensor): # (T, 3, 224, 224)
70 | for b in range(tensor.size(0)):
71 | for t, m, s in zip(tensor[b], self.mean, self.std):
72 | t.sub_(m).div_(s)
73 | return tensor
74 |
75 | class ZeroPad(object):
76 | def __init__(self, max_len):
77 | self.max_len = max_len
78 |
79 | def __call__(self, tensor):
80 | if tensor.size(0)==self.max_len:
81 | return tensor
82 |
83 | n_pad = self.max_len - tensor.size(0)
84 | pad = torch.zeros(n_pad, tensor.size(1), tensor.size(2), tensor.size(3))
85 | tensor = torch.cat([tensor, pad], 0) # (T, 3, 224, 224)
86 | return tensor
--------------------------------------------------------------------------------
/utils/transform.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | import cv2
4 | import torch
5 | from torch.utils import data
6 | from torchvision import transforms
7 | from torchvision.transforms import functional as F
8 | import numbers
9 | import numpy as np
10 | import random
11 | import json
12 |
13 | def load_image(pah,image_size=320,if_resize=False):
14 |
15 | if not os.path.exists(pah):
16 | print(pah)
17 | print('File Not Exists')
18 |
19 | im = cv2.imread(pah)[:,:,::-1]
20 | if if_resize:
21 | im = cv2.resize(im, (image_size,image_size))
22 | #im = randomGaussianBlur(im)
23 | in_ = np.array(im, dtype=np.float32)
24 | in_ = in_.transpose((2, 0, 1))
25 | return in_
26 |
27 | def Normalize(image,mean = [0.485, 0.456, 0.406],std = [0.229, 0.224, 0.225]):
28 | image /= 255.0
29 | image=image.transpose((1,2,0))
30 | image-=np.array((mean[0],mean[1],mean[2]))
31 | image/=np.array((std[0],std[1],std[2]))
32 |
33 | image=image.transpose((2,0,1))
34 |
35 | return image
36 |
37 | def load_image_test(pah,if_resize=False,image_size=320):
38 | if not os.path.exists(pah):
39 | print(pah)
40 | print('File Not Exists')
41 |
42 | im = cv2.imread(pah)
43 | if if_resize:
44 | im = cv2.resize(im, (image_size,image_size))
45 | in_ = np.array(im, dtype=np.float32)
46 | im_size = tuple(in_.shape[:2])
47 |
48 | in_ = in_.transpose((2, 0, 1))
49 | return in_, im_size
50 |
51 | def load_label(pah,image_size=320):
52 | """
53 | Load label image as 1 x height x width integer array of label indices.
54 | The leading singleton dimension is required by the loss.
55 | """
56 | if not os.path.exists(pah):
57 | print('File Not Exists')
58 | im = Image.open(pah)
59 | #im = im.resize((image_size,image_size))
60 | label = np.array(im, dtype=np.float32)
61 | if len(label.shape) == 3:
62 | label = label[:, :, 0]
63 | # label = cv2.resize(label, im_sz, interpolation=cv2.INTER_NEAREST)
64 | label = label / 255.
65 | label = label[np.newaxis, ...]
66 | return label
67 |
68 | def cv_random_flip(img, label, edge):
69 | flip_flag = random.randint(0, 1)
70 | if flip_flag == 1:
71 | img = img[:, :, ::-1].copy()
72 | label = label[:, :, ::-1].copy()
73 | edge = edge[:, :, ::-1].copy()
74 | return img, label, edge
75 |
76 | def cv_random_crop_flip(img, depth,object_image,resize_size, crop_size, random_flip=True):
77 | def get_params(img_size, output_size):
78 | h, w = img_size
79 | th, tw = output_size
80 | if w == tw and h == th:
81 | return 0, 0, h, w
82 | i = random.randint(0, h - th)
83 | j = random.randint(0, w - tw)
84 | return i, j, th, tw
85 |
86 | # always define flip_flag so the flip check below works even when random_flip is False
87 | flip_flag = random.randint(0, 1) if random_flip else 0
88 | img = img.transpose((1, 2, 0)) # H, W, C
89 | depth = depth.transpose((1, 2, 0))
90 | object_image=object_image.transpose((1,2,0))
91 |
92 | img = cv2.resize(img, (resize_size[1], resize_size[0]), interpolation=cv2.INTER_LINEAR)
93 | depth = cv2.resize(depth, (resize_size[1], resize_size[0]), interpolation=cv2.INTER_NEAREST)
94 | object_image = cv2.resize(object_image, (resize_size[1], resize_size[0]), interpolation=cv2.INTER_NEAREST)
95 |
96 |
97 | i, j, h, w = get_params(resize_size, crop_size)
98 | img = img[i:i + h, j:j + w, :].transpose((2, 0, 1)) # C, H, W
99 | depth = depth[i:i + h, j:j + w, :].transpose((2, 0, 1)) # C, H, W
100 | object_image = object_image[i:i + h, j:j + w, :].transpose((2, 0, 1)) # C, H, W
101 |
102 | if flip_flag == 1:
103 | img = img[:, :, ::-1].copy()
104 | depth = depth[:, :, ::-1].copy()
105 | object_image = object_image[:, :, ::-1].copy()
106 |
107 | return img, depth,object_image
108 |
109 | def cv_random_crop_flip_ref(img, obj_mask,per_mask,resize_size, crop_size, random_flip=True):
110 | def get_params(img_size, output_size):
111 | h, w = img_size
112 | th, tw = output_size
113 | if w == tw and h == th:
114 | return 0, 0, h, w
115 | i = random.randint(0, h - th)
116 | j = random.randint(0, w - tw)
117 | return i, j, th, tw
118 |
119 | # always define flip_flag so the flip check below works even when random_flip is False
120 | flip_flag = random.randint(0, 1) if random_flip else 0
121 | img = img.transpose((1, 2, 0)) # H, W, C
122 |
123 | img = cv2.resize(img, (resize_size[1], resize_size[0]), interpolation=cv2.INTER_LINEAR)
124 | obj_mask = cv2.resize(obj_mask, (resize_size[1], resize_size[0]), interpolation=cv2.INTER_NEAREST)
125 | per_mask = cv2.resize(per_mask, (resize_size[1], resize_size[0]), interpolation=cv2.INTER_NEAREST)
126 |
127 | i, j, h, w = get_params(resize_size, crop_size)
128 | img = img[i:i + h, j:j + w, :].transpose((2, 0, 1)) # C, H, W
129 | obj_mask = obj_mask[i:i + h, j:j + w][np.newaxis, ...] # 1, H, W
130 | per_mask = per_mask[i:i + h, j:j + w][np.newaxis, ...]
131 |
132 | if flip_flag == 1:
133 | img = img[:, :, ::-1].copy()
134 | obj_mask = obj_mask[:, :, ::-1].copy()
135 | per_mask=per_mask[:, :, ::-1].copy()
136 | obj_mask = obj_mask[0, :, :] # H, W
137 | per_mask=per_mask[0,:,:]
138 | return img, obj_mask,per_mask
139 |
140 | def cv_center_crop(img, label, edge_label,resize_size, crop_size, random_flip=True):
141 | def get_params(img_size, output_size):
142 | h, w = img_size
143 | th, tw = output_size
144 | if w == tw and h == th:
145 | return 0, 0, h, w
146 | i = (h - th) // 2
147 | j = (w - tw) // 2
148 | return i, j, th, tw
149 |
150 | img = img.transpose((1, 2, 0)) # H, W, C
151 | label = label[0, :, :] # H, W
152 | img = cv2.resize(img, (resize_size[1], resize_size[0]), interpolation=cv2.INTER_LINEAR)
153 | label = cv2.resize(label, (resize_size[1], resize_size[0]), interpolation=cv2.INTER_NEAREST)
154 | edge_label=cv2.resize(edge_label,(resize_size[1], resize_size[0]), interpolation=cv2.INTER_NEAREST)
155 | i, j, h, w = get_params(resize_size, crop_size)
156 | img = img[i:i + h, j:j + w, :].transpose((2, 0, 1)) # C, H, W
157 | label = label[i:i + h, j:j + w][np.newaxis, ...] # 1, H, W
158 | edge_label=edge_label[i:i + h, j:j + w][np.newaxis, ...]
159 |
160 | return img, label,edge_label
161 |
162 | def random_crop(img, label, edge_label, size, padding=None, pad_if_needed=True, fill_img=(123, 116, 103), fill_label=0,
163 | padding_mode='constant'):
164 | def get_params(img, output_size):
165 | w, h = img.size
166 | th, tw = output_size
167 | if w == tw and h == th:
168 | return 0, 0, h, w
169 |
170 | i = random.randint(0, h - th)
171 | j = random.randint(0, w - tw)
172 | return i, j, th, tw
173 |
174 | if isinstance(size, numbers.Number):
175 | size = (int(size), int(size))
176 | if padding is not None:
177 | img = F.pad(img, padding, fill_img, padding_mode)
178 | label = F.pad(label, padding, fill_label, padding_mode)
179 | edge_label=F.pad(edge_label,padding,fill_label, padding_mode)
180 |
181 | # pad the width if needed
182 | if pad_if_needed and img.size[0] < size[1]:
183 | img = F.pad(img, (int((1 + size[1] - img.size[0]) / 2), 0), fill_img, padding_mode)
184 | label = F.pad(label, (int((1 + size[1] - label.size[0]) / 2), 0), fill_label, padding_mode)
185 | edge_label=F.pad(edge_label,(int((1 + size[1] - label.size[0]) / 2), 0), fill_label, padding_mode)
186 |
187 | # pad the height if needed
188 | if pad_if_needed and img.size[1] < size[0]:
189 | img = F.pad(img, (0, int((1 + size[0] - img.size[1]) / 2)), fill_img, padding_mode)
190 | label = F.pad(label, (0, int((1 + size[0] - label.size[1]) / 2)), fill_label, padding_mode)
191 | edge_label=F.pad(edge_label,(0, int((1 + size[0] - label.size[1]) / 2)), fill_label, padding_mode)
192 |
193 | i, j, h, w = get_params(img, size)
194 | return [F.crop(img, i, j, h, w), F.crop(label, i, j, h, w)],[F.crop(edge_label,i,j,h,w)]
195 |
196 | def random_crop_ref(img, obj_mask, per_mask, size, padding=None, pad_if_needed=True, fill_img=(123, 116, 103), fill_label=0,
197 | padding_mode='constant'):
198 | def get_params(img, output_size):
199 | w, h = img.size
200 | th, tw = output_size
201 | if w == tw and h == th:
202 | return 0, 0, h, w
203 |
204 | i = random.randint(0, h - th)
205 | j = random.randint(0, w - tw)
206 | return i, j, th, tw
207 |
208 | if isinstance(size, numbers.Number):
209 | size = (int(size), int(size))
210 | if padding is not None:
211 | img = F.pad(img, padding, fill_img, padding_mode)
212 | obj_mask = F.pad(obj_mask, padding, fill_label, padding_mode)
213 | per_mask=F.pad(per_mask,padding,fill_label, padding_mode)
214 |
215 | # pad the width if needed
216 | if pad_if_needed and img.size[0] < size[1]:
217 | img = F.pad(img, (int((1 + size[1] - img.size[0]) / 2), 0), fill_img, padding_mode)
218 | obj_mask = F.pad(obj_mask, (int((1 + size[1] - obj_mask.size[0]) / 2), 0), fill_label, padding_mode)
219 | per_mask = F.pad(per_mask,(int((1 + size[1] - per_mask.size[0]) / 2), 0), fill_label, padding_mode)
220 |
221 | # pad the height if needed
222 | if pad_if_needed and img.size[1] < size[0]:
223 | img = F.pad(img, (0, int((1 + size[0] - img.size[1]) / 2)), fill_img, padding_mode)
224 | obj_mask = F.pad(obj_mask, (0, int((1 + size[0] - obj_mask.size[1]) / 2)), fill_label, padding_mode)
225 | per_mask = F.pad(per_mask, (0, int((1 + size[0] - per_mask.size[1]) / 2)), fill_label, padding_mode)
226 |
227 | i, j, h, w = get_params(img, size)
228 | return [F.crop(img, i, j, h, w), F.crop(obj_mask, i, j, h, w)],[F.crop(per_mask,i,j,h,w)]
229 |
230 |
231 | def center_crop(img, label, edge_label, size, padding=None, pad_if_needed=True, fill_img=(123, 116, 103), fill_label=0,
232 | padding_mode='constant'):
233 | def get_params(img, output_size):
234 | w, h = img.size
235 | th, tw = output_size
236 | if w == tw and h == th:
237 | return 0, 0, h, w
238 |
239 | i = (h-th)//2
240 | j = (w-tw)//2
241 | return i, j, th, tw
242 |
243 | if isinstance(size, numbers.Number):
244 | size = (int(size), int(size))
245 | if padding is not None:
246 | img = F.pad(img, padding, fill_img, padding_mode)
247 | label = F.pad(label, padding, fill_label, padding_mode)
248 | edge_label=F.pad(edge_label,padding,fill_label, padding_mode)
249 |
250 | # pad the width if needed
251 | if pad_if_needed and img.size[0] < size[1]:
252 | img = F.pad(img, (int((1 + size[1] - img.size[0]) / 2), 0), fill_img, padding_mode)
253 | label = F.pad(label, (int((1 + size[1] - label.size[0]) / 2), 0), fill_label, padding_mode)
254 | edge_label=F.pad(edge_label,(int((1 + size[1] - label.size[0]) / 2), 0), fill_label, padding_mode)
255 |
256 | # pad the height if needed
257 | if pad_if_needed and img.size[1] < size[0]:
258 | img = F.pad(img, (0, int((1 + size[0] - img.size[1]) / 2)), fill_img, padding_mode)
259 | label = F.pad(label, (0, int((1 + size[0] - label.size[1]) / 2)), fill_label, padding_mode)
260 | edge_label=F.pad(edge_label,(0, int((1 + size[0] - label.size[1]) / 2)), fill_label, padding_mode)
261 |
262 | i, j, h, w = get_params(img, size)
263 | return [F.crop(img, i, j, h, w), F.crop(label, i, j, h, w)],[F.crop(edge_label,i,j,h,w)]
264 |
265 |
--------------------------------------------------------------------------------
/utils/util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | import torchvision.transforms as transforms
6 | from PIL import Image
7 |
8 |
9 | def default_mean_std():
10 | mean = [0.485, 0.456, 0.406]
11 | std = [0.229, 0.224, 0.225]
12 | return mean, std
13 |
14 |
15 | def load_img(fl):
16 | return Image.open(fl).convert('RGB')
17 |
18 |
19 | def batch_cuda(batch):
20 | return {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
21 |
22 |
23 | def unnormalize(tensor):
24 | mean, std = default_mean_std()
25 | u_tensor = tensor.clone()
26 |
27 | def _unnorm(t):
28 | for c in range(3):
29 | t[c].mul_(std[c]).add_(mean[c])
30 |
31 | if u_tensor.dim() == 4:
32 | [_unnorm(t) for t in u_tensor]
33 | else:
34 | _unnorm(u_tensor)
35 |
36 | return u_tensor
37 |
38 |
39 | def default_transform(split):
40 | mean, std = default_mean_std()
41 |
42 | if split == 'train':
43 | transform = transforms.Compose([
44 | transforms.Resize(256),
45 | transforms.RandomCrop(224),
46 | transforms.RandomHorizontalFlip(),
47 | transforms.ToTensor(),
48 | transforms.Normalize(mean, std)
49 | ])
50 | else:
51 | transform = transforms.Compose([
52 | # transforms.Resize(256),
53 | transforms.CenterCrop(224),
54 | transforms.ToTensor(),
55 | transforms.Normalize(mean, std)
56 | ])
57 |
58 | return transform
59 |
60 |
61 | import utils.gtransforms as gtransforms
62 |
63 |
64 | def clip_transform(split, max_len):
65 | mean, std = default_mean_std()
66 |
67 | if split == 'train':
68 | transform = transforms.Compose([
69 | gtransforms.GroupResize(256),
70 | gtransforms.GroupRandomCrop(224),
71 | gtransforms.GroupRandomHorizontalFlip(),
72 | gtransforms.ToTensor(),
73 | gtransforms.GroupNormalize(mean, std),
74 | gtransforms.ZeroPad(max_len),
75 | ])
76 |
77 | elif split == 'val':
78 | transform = transforms.Compose([
79 | gtransforms.GroupResize(256),
80 | gtransforms.GroupCenterCrop(224),
81 | gtransforms.ToTensor(),
82 | gtransforms.GroupNormalize(mean, std),
83 | gtransforms.ZeroPad(max_len),
84 | ])
85 |
86 | return transform
87 |
88 |
89 | import torchvision.transforms.functional as TF
90 | import torch.nn.functional as F
91 | import numbers
92 |
93 |
94 | class PairedTransform:
95 |
96 | def __init__(self, split, std_norm=True):
97 | self.split = split
98 | self.mean, self.std = default_mean_std()
99 | self.std_norm = std_norm
100 |
101 | def train_transform(self, image, heatmap):
102 |
103 | heatmap = torch.from_numpy(heatmap)
104 |
105 | # PIL and Tensor have inverted shapes
106 | # image.size == (606, 479) while heatmap.shape == (479, 606)
107 | assert (heatmap.shape[1], heatmap.shape[0]) == image.size, 'image and heatmap sizes mismatch (%s vs. %s)' % (
108 | image.size, heatmap.shape)
109 |
110 | # resize
111 | image = TF.resize(image, size=(256, 256))
112 | heatmap = \
113 | F.interpolate(heatmap.unsqueeze(0).unsqueeze(0), size=image.size, mode='bilinear', align_corners=False)[0][0]
114 |
115 | # random crop
116 | i, j, h, w = transforms.RandomCrop.get_params(image, output_size=(224, 224))
117 | image = TF.crop(image, i, j, h, w)
118 | heatmap = heatmap[i:i + h, j:j + w]  # rows indexed by the crop height, columns by the crop width
119 |
120 | # horizontal flip
121 | if np.random.rand() < 0.5:
122 | image = TF.hflip(image)
123 | heatmap = heatmap.flip(1)
124 |
125 | # to tensor + normalize
126 | image = TF.to_tensor(image)
127 |
128 | if self.std_norm:
129 | image = TF.normalize(image, self.mean, self.std)
130 |
131 | if heatmap.sum().item() != 0:
132 | heatmap = heatmap / heatmap.sum()
133 |
134 | return image, heatmap
135 |
136 | def val_transform(self, image, heatmap):
137 |
138 | # PIL and Tensor have inverted shapes
139 | # image.size == (606, 479) while heatmap.shape == (479, 606)
140 | assert (heatmap.shape[1], heatmap.shape[0]) == image.size, 'image and heatmap sizes mismatch (%s vs. %s)' % (
141 | image.size, heatmap.shape)
142 |
143 | heatmap = torch.from_numpy(heatmap)
144 |
145 | # resize
146 | image = TF.resize(image, size=(224, 224))
147 | heatmap = \
148 | F.interpolate(heatmap.unsqueeze(0).unsqueeze(0), size=image.size, mode='bilinear', align_corners=False)[0][0]
149 |
150 | # to tensor + normalize
151 | image = TF.to_tensor(image)
152 |
153 | if self.std_norm:
154 | image = TF.normalize(image, self.mean, self.std)
155 |
156 | if heatmap.sum().item() != 0:
157 | heatmap = heatmap / heatmap.sum()
158 |
159 | return image, heatmap
160 |
161 | def __call__(self, image, heatmap):
162 |
163 | if self.split == 'train':
164 | image, heatmap = self.train_transform(image, heatmap)
165 | elif self.split == 'val':
166 | image, heatmap = self.val_transform(image, heatmap)
167 |
168 | return image, heatmap
169 |
--------------------------------------------------------------------------------