├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── common.py ├── functools.py ├── models ├── __init__.py ├── defaults.py ├── modules.py └── thundernet.py └── transforms ├── __init__.py ├── detection ├── __init__.py └── functional.py ├── ext.py └── segmentation └── __init__.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 qixuxiang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Pytorch_Lightweight_Network 2 | Lightweight networks such as MobileNet, ShuffleNet and ThunderNet implemented in PyTorch 3 | 4 | ## Introduction 5 | This project intends to reproduce the original results of the MobileNet, ShuffleNet and ThunderNet papers in PyTorch. 6 | 7 | ## Prerequisites 8 | 9 | Python>=3.6 10 | PyTorch>=1.0 11 | torchsummary 12 | 13 | 14 | I have no GPU resources to finish my training and testing code. I would really appreciate it if you could provide a GPU device; please write to qixuxiang@outlook.com 15 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from common import cuda, one_hot, cpu, detach, Args 2 | -------------------------------------------------------------------------------- /common.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections.abc import Sequence, Mapping 3 | 4 | import torch 5 | 6 | 7 | class Args(tuple): 8 | def __new__(cls, *args): 9 | return super().__new__(cls, tuple(args)) 10 | 11 | def __repr__(self): 12 | return "Args" + super().__repr__() 13 | 14 | 15 | def one_hot(tensor, C=None, dtype=torch.float): 16 | d = tensor.dim() 17 | C = C or tensor.max() + 1 18 | t = tensor.new_zeros(*tensor.size(), C, dtype=dtype) 19 | return t.scatter_(d, tensor.unsqueeze(d), 1) 20 | 21 | 22 | CUDA = torch.cuda.is_available() 23 | 24 | 25 | def detach(t, clone=True): 26 | if torch.is_tensor(t): 27 | if clone: 28 | return t.clone().detach() 29 | else: 30 | return t.detach() 31 | elif isinstance(t, Args): 32 | return t 33 | elif isinstance(t, Sequence): 34 | return t.__class__(detach(x, clone) for x in t) 35 | elif isinstance(t, Mapping): 36 | return t.__class__((k, detach(v, clone)) for k, v in t.items()) 37 | else: 38 | return t 39 | 40 | 41 | def cuda(t): 42 | if torch.is_tensor(t): 43 | return t.cuda() if CUDA else t 44 | elif isinstance(t, Sequence): 45 | return t.__class__(cuda(x) for x in t) 46 | elif isinstance(t, Mapping): 47 | return t.__class__((k, cuda(v)) for k, v in t.items()) 48 | else: 49 | return t 50 | 51 | 52 | def cpu(t): 53 | if torch.is_tensor(t): 54 | return t.cpu() 55 | elif isinstance(t, Sequence): 56 | return t.__class__(cpu(x) for x in t) 57 | elif isinstance(t, Mapping): 58 | return t.__class__((k, cpu(v)) for k, v in t.items()) 59 | else: 60 | return t 61 | 62 | 63 | def _tuple(x, n=-1): 64 | if x is None: 65 | return () 66 | elif torch.is_tensor(x): 67 | return (x,) 68 | elif not isinstance(x, Sequence): 69 | assert n > 0, "Length must be positive, but got %d" % n 70 | return (x,) * n 71 | else: 72 | if n == -1: 73 | n = len(x) 74 | else: 75 | assert len(x) == n, "The length of x is %d, not equal to the expected length %d" % (len(x), n) 76 | return tuple(x) 77 | 78 | 79 | def select0(t, indices): 80 | arange = torch.arange(t.size(1), device=t.device) 81 | return t[indices, arange] 82 | 83 | 84 | def select1(t, indices): 85 | arange = torch.arange(t.size(0), device=t.device) 86 | return t[arange, indices] 87 | 88 | 89 | def select(t, dim, indices): 90 | if dim == 0: 91 | return select0(t, indices) 92 | elif dim == 1: 93 | return select1(t, indices) 94 | else: 95 | raise ValueError("dim could be only 0 or 1, not %d" % dim) 96 | 97 | 98 | def sample(t, n): 99 | if len(t) >= n:
100 | indices = torch.randperm(len(t), device=t.device)[:n] 101 | else: 102 | indices = torch.randint(len(t), size=(n,), device=t.device) 103 | return t[indices] 104 | 105 | 106 | def _concat(xs, dim=1): 107 | if torch.is_tensor(xs): 108 | return xs 109 | elif len(xs) == 1: 110 | return xs[0] 111 | else: 112 | return torch.cat(xs, dim=dim) 113 | 114 | 115 | def inverse_sigmoid(x, eps=1e-6, inplace=False): 116 | if not torch.is_tensor(x): 117 | if eps != 0: 118 | x = min(max(x, eps), 1-eps) 119 | return math.log(x / (1 - x)) 120 | if inplace: 121 | return inverse_sigmoid_(x, eps) 122 | if eps != 0: 123 | x = torch.clamp(x, eps, 1-eps) 124 | return (x / (1 - x)).log() 125 | 126 | 127 | def inverse_sigmoid_(x, eps=1e-6): 128 | if eps != 0: 129 | x = torch.clamp_(x, eps, 1 - eps) 130 | return x.div_(1 - x).log_() 131 | 132 | 133 | def expand_last_dim(t, *size): 134 | return t.view(*t.size()[:-1], *size) -------------------------------------------------------------------------------- /functools.py: -------------------------------------------------------------------------------- 1 | from toolz.curried import curry, isiterable, map, filter 2 | 3 | 4 | def lmap(f, *iterables): 5 | return list(map(f, *iterables)) 6 | 7 | 8 | @curry 9 | def recursive_lmap(f, iterable): 10 | if isiterable(next(iter(iterable))): 11 | return lmap(recursive_lmap(f), iterable) 12 | else: 13 | return lmap(f, iterable) 14 | 15 | 16 | @curry 17 | def find(f, seq): 18 | try: 19 | return next(filter(lambda x: f(x[1]), enumerate(seq)))[0] 20 | except StopIteration: 21 | return None 22 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from models.defaults import set_default_activation 2 | from models.defaults import set_default_norm_layer -------------------------------------------------------------------------------- /models/defaults.py: -------------------------------------------------------------------------------- 1 | __ACTIVATION__ = 'relu' 2 | __NORM_LAYER__ = 'bn' 3 | 4 | 5 | def get_default_activation(): 6 | global __ACTIVATION__ 7 | return __ACTIVATION__ 8 | 9 | 10 | def set_default_activation(name): 11 | global __ACTIVATION__ 12 | __ACTIVATION__ = name 13 | 14 | 15 | def get_default_norm_layer(): 16 | global __NORM_LAYER__ 17 | return __NORM_LAYER__ 18 | 19 | 20 | def set_default_norm_layer(name): 21 | global __NORM_LAYER__ 22 | __NORM_LAYER__ = name 23 | -------------------------------------------------------------------------------- /models/modules.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("..") 3 | from toolz import curry 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from common import _tuple 9 | from models.defaults import get_default_activation, get_default_norm_layer 10 | 11 | 12 | def hardsigmoid(x, inplace=False): 13 | return F.relu6(x + 3, inplace=inplace) / 6 14 | 15 | 16 | def hardswish(x, inplace=False): 17 | return x * (F.relu6(x + 3, inplace=inplace) / 6) 18 | 19 | 20 | def swish(x): 21 | return x * torch.sigmoid(x) 22 | 23 | 24 | class Swish(nn.Module): 25 | def __init__(self): 26 | super().__init__() 27 | 28 | def forward(self, x): 29 | return swish(x) 30 | 31 | 32 | class HardSigmoid(nn.Module): 33 | def __init__(self, inplace=False): 34 | super().__init__() 35 | self.inplace = inplace 36 | 37 | def forward(self, x): 38 | return hardsigmoid(x, 
self.inplace) 39 | 40 | def extra_repr(self): 41 | inplace_str = 'inplace' if self.inplace else '' 42 | return inplace_str 43 | 44 | 45 | class HardSwish(nn.Module): 46 | def __init__(self, inplace=False): 47 | super().__init__() 48 | self.inplace = inplace 49 | 50 | def forward(self, x): 51 | return hardswish(x, self.inplace) 52 | 53 | def extra_repr(self): 54 | inplace_str = 'inplace' if self.inplace else '' 55 | return inplace_str 56 | 57 | 58 | def upsample_add(x, y): 59 | r""" 60 | Upsample x and add it to y. 61 | 62 | Parameters 63 | ---------- 64 | x : torch.Tensor 65 | tensor to upsample 66 | y : torch.Tensor 67 | tensor to be added 68 | """ 69 | h, w = y.size()[2:4] 70 | return F.interpolate(x, size=(h, w), mode='bilinear', align_corners=False) + y 71 | 72 | 73 | def upsample_concat(x, y): 74 | h, w = y.size()[2:4] 75 | x = F.interpolate(x, size=(h, w), mode='bilinear', align_corners=False) 76 | return torch.cat((x, y), dim=1) 77 | 78 | 79 | def get_groups(channels, ref=32): 80 | xs = filter(lambda x: channels % x == 0, range(2, channels + 1)) 81 | c = min(xs, key=lambda x: abs(x - ref)) 82 | if c < 8: 83 | c = max(c, channels // c) 84 | return channels // c 85 | 86 | 87 | def get_norm_layer(name, channels): 88 | if isinstance(name, nn.Module): 89 | return name 90 | elif hasattr(name, '__call__'): 91 | return name(channels) 92 | elif name == 'default': 93 | return get_norm_layer(get_default_norm_layer(), channels) 94 | elif name == 'bn': 95 | return nn.BatchNorm2d(channels) 96 | elif name == 'gn': 97 | num_groups = get_groups(channels, 32) 98 | return nn.GroupNorm(num_groups, channels) 99 | else: 100 | raise NotImplementedError("No normalization named %s" % name) 101 | 102 | 103 | def get_attention(name, **kwargs): 104 | if not name: 105 | return Identity() 106 | name = name.lower() 107 | if name == 'se': 108 | return SEModule(**kwargs) 109 | elif name == 'sem': 110 | return SELayerM(**kwargs) 111 | elif name == 'cbam': 112 | return CBAM(**kwargs) 113 | else: 114 | raise NotImplementedError("No attention module named %s" % name) 115 | 116 | 117 | def get_activation(name): 118 | if isinstance(name, nn.Module): 119 | return name 120 | if name == 'default': 121 | return get_activation(get_default_activation()) 122 | elif name == 'relu': 123 | return nn.ReLU(inplace=True) 124 | elif name == 'relu6': 125 | return nn.ReLU6(inplace=True) 126 | elif name == 'leaky_relu': 127 | return nn.LeakyReLU(negative_slope=0.1, inplace=True) 128 | elif name == 'sigmoid': 129 | return nn.Sigmoid() 130 | elif name == 'hswish': 131 | return HardSwish(inplace=True) 132 | elif name == 'swish': 133 | return Swish() 134 | else: 135 | raise NotImplementedError("No activation named %s" % name) 136 | 137 | 138 | def Conv2d(in_channels, out_channels, 139 | kernel_size, stride=1, 140 | padding='same', dilation=1, groups=1, bias=None, 141 | norm_layer=None, activation=None, depthwise_separable=False, mid_norm_layer=None, transposed=False): 142 | if depthwise_separable: 143 | assert kernel_size != 1, "No need to use depthwise separable convolution in 1x1" 144 | # if norm_layer is None: 145 | # assert mid_norm_layer is not None, "`mid_norm_layer` must be provided when `norm_layer` is None" 146 | # else: 147 | if mid_norm_layer is None: 148 | mid_norm_layer = norm_layer 149 | return DWConv2d(in_channels, out_channels, kernel_size, stride, padding, bias, mid_norm_layer, norm_layer, activation, transposed) 150 | if padding == 'same': 151 | if isinstance(kernel_size, tuple): 152 | kh, kw = kernel_size 153 | ph = (kh - 
1) // 2 154 | pw = (kw - 1) // 2 155 | padding = (ph, pw) 156 | else: 157 | padding = (kernel_size - 1) // 2 158 | layers = [] 159 | if bias is None: 160 | bias = norm_layer is None 161 | if transposed: 162 | conv = nn.ConvTranspose2d( 163 | in_channels, out_channels, 164 | kernel_size, stride, padding, dilation=dilation, groups=groups, bias=bias) 165 | else: 166 | conv = nn.Conv2d( 167 | in_channels, out_channels, 168 | kernel_size, stride, padding, dilation=dilation, groups=groups, bias=bias) 169 | if activation is not None: 170 | if activation == 'sigmoid': 171 | nn.init.xavier_normal_(conv.weight) 172 | elif activation == 'leaky_relu': 173 | nn.init.kaiming_normal_(conv.weight, a=0.1, nonlinearity='leaky_relu') 174 | else: 175 | try: 176 | nn.init.kaiming_normal_(conv.weight, nonlinearity=activation) 177 | except ValueError: 178 | nn.init.kaiming_normal_(conv.weight, nonlinearity='relu') 179 | else: 180 | nn.init.kaiming_normal_(conv.weight, nonlinearity='relu') 181 | if bias: 182 | nn.init.zeros_(conv.bias) 183 | 184 | if norm_layer is not None: 185 | if norm_layer == 'default': 186 | norm_layer = get_default_norm_layer() 187 | layers.append(get_norm_layer(norm_layer, out_channels)) 188 | if activation is not None: 189 | layers.append(get_activation(activation)) 190 | layers = [conv] + layers 191 | if len(layers) == 1: 192 | return layers[0] 193 | else: 194 | return nn.Sequential(*layers) 195 | 196 | 197 | def Linear(in_channels, out_channels, bias=None, norm_layer=None, activation=None): 198 | layers = [] 199 | if bias is None: 200 | bias = norm_layer is None 201 | fc = nn.Linear( 202 | in_channels, out_channels, bias=bias) 203 | if activation is not None: 204 | if activation == 'sigmoid': 205 | nn.init.xavier_normal_(fc.weight) 206 | elif activation == 'leaky_relu': 207 | nn.init.kaiming_normal_(fc.weight, a=0.1, nonlinearity='leaky_relu') 208 | else: 209 | try: 210 | nn.init.kaiming_normal_(fc.weight, nonlinearity=activation) 211 | except ValueError: 212 | nn.init.kaiming_normal_(fc.weight, nonlinearity='relu') 213 | else: 214 | nn.init.kaiming_normal_(fc.weight, nonlinearity='relu') 215 | if bias: 216 | nn.init.zeros_(fc.bias) 217 | 218 | if norm_layer == 'default' or norm_layer == 'bn': 219 | layers.append(nn.BatchNorm1d(out_channels)) 220 | if activation is not None: 221 | layers.append(get_activation(activation)) 222 | layers = [fc] + layers 223 | if len(layers) == 1: 224 | return layers[0] 225 | else: 226 | return nn.Sequential(*layers) 227 | 228 | 229 | def Pool(name, kernel_size, stride=1, padding='same', ceil_mode=False): 230 | if padding == 'same': 231 | if isinstance(kernel_size, tuple): 232 | kh, kw = kernel_size 233 | ph = (kh - 1) // 2 234 | pw = (kw - 1) // 2 235 | padding = (ph, pw) 236 | else: 237 | padding = (kernel_size - 1) // 2 238 | if name == 'avg': 239 | return nn.AvgPool2d(kernel_size, stride, padding, ceil_mode=ceil_mode, count_include_pad=False) 240 | elif name == 'max': 241 | return nn.MaxPool2d(kernel_size, stride, padding, ceil_mode=ceil_mode) 242 | else: 243 | raise NotImplementedError("No activation named %s" % name) 244 | 245 | 246 | class SEModule(nn.Module): 247 | def __init__(self, in_channels, reduction=8): 248 | super().__init__() 249 | channels = in_channels // reduction 250 | self.pool = nn.AdaptiveAvgPool2d(1) 251 | self.layers = nn.Sequential( 252 | nn.Linear(in_channels, channels), 253 | nn.ReLU(True), 254 | nn.Linear(channels, in_channels), 255 | nn.Sigmoid(), 256 | ) 257 | 258 | def forward(self, x): 259 | b, c = x.size()[:2] 260 | s = 
self.pool(x).view(b, c) 261 | s = self.layers(s).view(b, c, 1, 1) 262 | return x * s 263 | 264 | 265 | class CBAMChannelAttention(nn.Module): 266 | def __init__(self, in_channels, reduction=8): 267 | super().__init__() 268 | channels = in_channels // reduction 269 | self.mlp = nn.Sequential( 270 | nn.Linear(in_channels, channels), 271 | nn.ReLU(True), 272 | nn.Linear(channels, in_channels), 273 | ) 274 | 275 | def forward(self, x): 276 | b, c = x.size()[:2] 277 | aa = F.adaptive_avg_pool2d(x, 1).view(b, c) 278 | aa = self.mlp(aa) 279 | am = F.adaptive_max_pool2d(x, 1).view(b, c) 280 | am = self.mlp(am) 281 | a = torch.sigmoid(aa + am).view(b, c, 1, 1) 282 | return x * a 283 | 284 | 285 | class CBAMSpatialAttention(nn.Module): 286 | def __init__(self): 287 | super().__init__() 288 | self.conv = Conv2d(2, 1, kernel_size=7, norm_layer='bn') 289 | 290 | def forward(self, x): 291 | aa = x.mean(dim=1, keepdim=True) 292 | am = x.max(dim=1, keepdim=True)[0] 293 | a = torch.cat([aa, am], dim=1) 294 | a = torch.sigmoid(self.conv(a)) 295 | return x * a 296 | 297 | 298 | class CBAM(nn.Module): 299 | def __init__(self, in_channels, reduction=4): 300 | super().__init__() 301 | self.channel = CBAMChannelAttention(in_channels, reduction) 302 | self.spatial = CBAMSpatialAttention() 303 | 304 | def forward(self, x): 305 | x = self.channel(x) 306 | x = self.spatial(x) 307 | return x 308 | 309 | 310 | class SELayerM(nn.Module): 311 | def __init__(self, in_channels, reduction=4): 312 | super().__init__() 313 | channels = in_channels // reduction 314 | self.avgpool = nn.AdaptiveAvgPool2d(1) 315 | self.layers = nn.Sequential( 316 | nn.Linear(in_channels, channels), 317 | nn.ReLU6(True), 318 | nn.Linear(channels, in_channels), 319 | HardSigmoid(True), 320 | ) 321 | 322 | def forward(self, x): 323 | b, c = x.size()[:2] 324 | s = self.avgpool(x).view(b, c) 325 | s = self.layers(s).view(b, c, 1, 1) 326 | return x * s 327 | 328 | 329 | @curry 330 | def DWConv2d(in_channels, out_channels, 331 | kernel_size=3, stride=1, 332 | padding='same', bias=True, mid_norm_layer='default', 333 | norm_layer=None, activation=None, transposed=False): 334 | return nn.Sequential( 335 | Conv2d(in_channels, in_channels, kernel_size=kernel_size, stride=stride, padding=padding, groups=in_channels, 336 | norm_layer=mid_norm_layer, transposed=transposed), 337 | Conv2d(in_channels, out_channels, kernel_size=1, 338 | norm_layer=norm_layer, activation=activation, bias=bias), 339 | ) 340 | 341 | 342 | class Sequential(nn.Sequential): 343 | def __init__(self, *args, **kwargs): 344 | super().__init__(*args) 345 | if 'inference' in kwargs: 346 | self._inference = kwargs['inference'] 347 | 348 | def forward(self, *xs): 349 | for module in self._modules.values(): 350 | xs = module(*_tuple(xs)) 351 | return xs 352 | 353 | def inference(self, *xs): 354 | self.eval() 355 | with torch.no_grad(): 356 | xs = self.forward(*xs) 357 | preds = self._inference(*_tuple(xs)) 358 | self.train() 359 | return preds 360 | 361 | 362 | class Identity(nn.Module): 363 | def __init__(self, *args, **kwargs): 364 | super().__init__() 365 | 366 | def forward(self, x): 367 | return x 368 | 369 | 370 | class DropConnect(nn.Module): 371 | def __init__(self, p=0.2): 372 | super().__init__() 373 | assert 0 <= p <= 1, "drop probability has to be between 0 and 1, but got %f" % p 374 | self.p = p 375 | 376 | def forward(self, x): 377 | if not self.training or self.p == 0: 378 | return x 379 | keep_prob = 1.0 - self.p 380 | batch_size = x.size(0) 381 | t = torch.rand(batch_size, 1, 
1, 1, dtype=x.dtype, device=x.device) < keep_prob 382 | x = (x / keep_prob).masked_fill(t, 0) 383 | return x 384 | 385 | def extra_repr(self): 386 | return 'p={}'.format(self.p) 387 | 388 | 389 | class Flatten(nn.Module): 390 | def __init__(self): 391 | super().__init__() 392 | 393 | 394 | def forward(self, x): 395 | return x.view(x.size(0), -1) 396 | -------------------------------------------------------------------------------- /models/thundernet.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torchsummary import summary 7 | import numpy as np 8 | 9 | from modules import Conv2d, DWConv2d, SEModule 10 | 11 | CEM_FILTER=245 12 | 13 | def channel_shuffle(x, g): 14 | n, c, h, w = x.size() 15 | x = x.view(n, g, c // g, h, w).permute( 16 | 0, 2, 1, 3, 4).contiguous().view(n, c, h, w) 17 | return x 18 | ''' 19 | def channel_shuffle(x, groups): 20 | batchsize, num_channels, height, width = x.data.size() 21 | channels_per_group = num_channels // groups 22 | 23 | # reshape 24 | x = x.view(batchsize, groups, 25 | channels_per_group, height, width) 26 | 27 | x = torch.transpose(x, 1, 2).contiguous() 28 | 29 | # flatten 30 | x = x.view(batchsize, -1, height, width) 31 | 32 | return x 33 | 34 | ''' 35 | 36 | class ShuffleBlock(nn.Module): 37 | def __init__(self, groups): 38 | super().__init__() 39 | self.groups = groups 40 | 41 | def forward(self, x): 42 | return channel_shuffle(x, g=self.groups) 43 | 44 | 45 | 46 | class InvertedResidual(nn.Module): 47 | def __init__(self, inp, oup, stride): 48 | super(InvertedResidual, self).__init__() 49 | 50 | if not (1 <= stride <= 3): 51 | raise ValueError('illegal stride value') 52 | self.stride = stride 53 | 54 | branch_features = oup // 2 55 | assert (self.stride != 1) or (inp == branch_features << 1) 56 | 57 | pw_conv11 = functools.partial(nn.Conv2d, kernel_size=1, stride=1, padding=0, bias=False) 58 | dw_conv33 = functools.partial(self.depthwise_conv, 59 | kernel_size=3, stride=self.stride, padding=1) 60 | 61 | if self.stride > 1: 62 | self.branch1 = nn.Sequential( 63 | dw_conv33(inp, inp), 64 | nn.BatchNorm2d(inp), 65 | pw_conv11(inp, branch_features), 66 | nn.BatchNorm2d(branch_features), 67 | nn.ReLU(inplace=True), 68 | ) 69 | 70 | self.branch2 = nn.Sequential( 71 | pw_conv11(inp if (self.stride > 1) else branch_features, branch_features), 72 | nn.BatchNorm2d(branch_features), 73 | nn.ReLU(inplace=True), 74 | dw_conv33(branch_features, branch_features), 75 | nn.BatchNorm2d(branch_features), 76 | pw_conv11(branch_features, branch_features), 77 | nn.BatchNorm2d(branch_features), 78 | nn.ReLU(inplace=True), 79 | ) 80 | 81 | @staticmethod 82 | def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False): 83 | return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i) 84 | 85 | def forward(self, x): 86 | if self.stride == 1: 87 | x1, x2 = x.chunk(2, dim=1) 88 | out = torch.cat((x1, self.branch2(x2)), dim=1) 89 | else: 90 | out = torch.cat((self.branch1(x), self.branch2(x)), dim=1) 91 | 92 | out = channel_shuffle(out, 2) 93 | 94 | return out 95 | 96 | class ShuffleNetV2(nn.Module): 97 | def __init__(self, stages_repeats, stages_out_channels, num_classes): 98 | super(ShuffleNetV2, self).__init__() 99 | 100 | if len(stages_repeats) != 3: 101 | raise ValueError('expected stages_repeats as list of 3 positive ints') 102 | if len(stages_out_channels) != 5: 103 | raise ValueError('expected 
stages_out_channels as list of 5 positive ints') 104 | self._stage_out_channels = stages_out_channels 105 | 106 | input_channels = 3 107 | output_channels = self._stage_out_channels[0] 108 | self.conv1 = nn.Sequential( 109 | nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False), 110 | nn.BatchNorm2d(output_channels), 111 | nn.ReLU(inplace=True), 112 | ) 113 | input_channels = output_channels 114 | 115 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 116 | 117 | stage_names = ['stage{}'.format(i) for i in [2, 3, 4]] 118 | for name, repeats, output_channels in zip( 119 | stage_names, stages_repeats, self._stage_out_channels[1:]): 120 | seq = [InvertedResidual(input_channels, output_channels, 2)] 121 | for i in range(repeats - 1): 122 | seq.append(InvertedResidual(output_channels, output_channels, 1)) 123 | setattr(self, name, nn.Sequential(*seq)) 124 | input_channels = output_channels 125 | 126 | output_channels = self._stage_out_channels[-1] 127 | self.conv5 = nn.Sequential( 128 | nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False), 129 | nn.BatchNorm2d(output_channels), 130 | nn.ReLU(inplace=True), 131 | ) 132 | 133 | self.fc = nn.Linear(output_channels, num_classes) 134 | 135 | def forward(self, x): 136 | x = self.conv1(x) 137 | x = self.maxpool(x) 138 | x = self.stage2(x) 139 | x = self.stage3(x) 140 | x = self.stage4(x) 141 | x = self.conv5(x) 142 | x = x.mean([2, 3]) # globalpool 143 | x = self.fc(x) 144 | return x 145 | 146 | class CEM(nn.Module): 147 | """Context Enhancement Module""" 148 | def __init__(self, in_channels, kernel_size=1, stride=1): 149 | super(CEM, self).__init__() 150 | self.conv4 = Conv2d(in_channels, CEM_FILTER, kernel_size, bias=True) 151 | self.conv5 = Conv2d(in_channels, CEM_FILTER, kernel_size, bias=True) 152 | 153 | def forward(self, inputs): 154 | # in keras NHWC 155 | # in torch NCHW 156 | C4_lat = self.conv4(inputs[0]) 157 | C5_lat = self.conv5(inputs[1]) 158 | C5_lat = F.interpolate(C5_lat, scale_factor=2, mode='bilinear', align_corners=True)  # upsample C5 to match C4's spatial size 159 | Cglb_lat = inputs[2].view(-1, CEM_FILTER, 1, 1) 160 | return C4_lat + C5_lat + Cglb_lat 161 | 162 | class RPN(nn.Module): 163 | """region proposal network""" 164 | def __init__(self, in_channels=245, f_channels=256): 165 | super(RPN, self).__init__() 166 | self.num_anchors = 5*5 167 | self.rpn = DWConv2d( 168 | in_channels, f_channels, kernel_size=6, 169 | mid_norm_layer='default', norm_layer='default', 170 | activation='default') 171 | self.loc_conv = Conv2d(f_channels, 2*self.num_anchors, kernel_size=1, stride=1, 172 | padding=0, bias=True 173 | ) 174 | self.rpn_cls_pred = Conv2d(2*self.num_anchors, 4*self.num_anchors, kernel_size=1, 175 | stride=1, padding=0, bias=True 176 | ) 177 | 178 | 179 | 180 | 181 | class SAM(nn.Module): 182 | """spatial attention module""" 183 | def __init__(self, in_channels, kernel_size=1, stride=1): 184 | super(SAM, self).__init__() 185 | self.conv1 = Conv2d( 186 | in_channels, CEM_FILTER, kernel_size, padding=0, 187 | norm_layer='default' 188 | ) 189 | 190 | def forward(self, inputs): 191 | x = self.conv1(inputs[0]) 192 | x = F.softmax(x, dim=1) 193 | x = x.mul(inputs[1]) 194 | return x 195 | 196 | class BasicBlock(nn.Module): 197 | def __init__(self, in_channels, shuffle_groups=2, with_se=False): 198 | super().__init__() 199 | self.with_se = with_se 200 | channels = in_channels // 2 201 | self.conv1 = Conv2d( 202 | channels, channels, kernel_size=1, 203 | norm_layer='default', activation='default', 204 | ) 205 | self.conv2 = Conv2d( 206 | channels,
channels, kernel_size=5, groups=channels, 207 | norm_layer='default', 208 | ) 209 | self.conv3 = Conv2d( 210 | channels, channels, kernel_size=1, 211 | norm_layer='default', activation='default', 212 | ) 213 | if with_se: 214 | self.se = SEModule(channels, reduction=8) 215 | self.shuffle = ShuffleBlock(shuffle_groups) 216 | 217 | def forward(self, x): 218 | x = x.contiguous() 219 | c = x.size(1) // 2 220 | x1 = x[:, :c, :, :] 221 | x2 = x[:, c:, :, :] 222 | x2 = self.conv1(x2) 223 | x2 = self.conv2(x2) 224 | x2 = self.conv3(x2) 225 | if self.with_se: 226 | x2 = self.se(x2) 227 | x = torch.cat((x1, x2), dim=1) 228 | x = self.shuffle(x) 229 | return x 230 | 231 | 232 | class DownBlock(nn.Module): 233 | def __init__(self, in_channels, out_channels, shuffle_groups=2, **kwargs): 234 | super().__init__() 235 | channels = out_channels // 2 236 | self.conv11 = Conv2d( 237 | in_channels, in_channels, kernel_size=5, stride=2, groups=in_channels, 238 | norm_layer='default', 239 | ) 240 | self.conv12 = Conv2d( 241 | in_channels, channels, kernel_size=1, 242 | norm_layer='default', activation='default', 243 | ) 244 | self.conv21 = Conv2d( 245 | in_channels, channels, kernel_size=1, 246 | norm_layer='default', activation='default', 247 | ) 248 | self.conv22 = Conv2d( 249 | channels, channels, kernel_size=5, stride=2, groups=channels, 250 | norm_layer='default', 251 | ) 252 | self.conv23 = Conv2d( 253 | channels, channels, kernel_size=1, 254 | norm_layer='default', activation='default', 255 | ) 256 | self.shuffle = ShuffleBlock(shuffle_groups) 257 | 258 | def forward(self, x): 259 | x1 = self.conv11(x) 260 | 261 | x1 = self.conv12(x1) 262 | 263 | x2 = self.conv21(x) 264 | x2 = self.conv22(x2) 265 | x2 = self.conv23(x2) 266 | 267 | x = torch.cat((x1, x2), dim=1) 268 | x = self.shuffle(x) 269 | return x 270 | 271 | 272 | class SNet(nn.Module): 273 | cfg = { 274 | 49: [24, 60, 120, 240, 512], 275 | 146: [24, 132, 264, 528], 276 | 535: [48, 248, 496, 992], 277 | } 278 | 279 | def __init__(self, num_classes=CEM_FILTER, version=49, **kwargs): 280 | super().__init__() 281 | num_layers = [4, 8, 4] 282 | self.num_layers = num_layers 283 | channels = self.cfg[version] 284 | self.channels = channels 285 | 286 | self.conv1 = Conv2d( 287 | 3, channels[0], kernel_size=3, stride=2, 288 | activation='default', **kwargs 289 | ) 290 | self.maxpool = nn.MaxPool2d( 291 | kernel_size=3, stride=2, padding=1, 292 | ) 293 | self.stage2 = self._make_layer( 294 | num_layers[0], channels[0], channels[1], **kwargs) 295 | self.stage3 = self._make_layer( 296 | num_layers[1], channels[1], channels[2], **kwargs) 297 | self.stage4 = self._make_layer( 298 | num_layers[2], channels[2], channels[3], **kwargs) 299 | if len(self.channels) == 5: 300 | self.conv5 = Conv2d( 301 | channels[3], channels[4], kernel_size=1, **kwargs) 302 | 303 | self.avgpool = nn.AdaptiveAvgPool2d(1) 304 | self.fc = nn.Linear(channels[-1], num_classes) 305 | 306 | def _make_layer(self, num_layers, in_channels, out_channels, **kwargs): 307 | layers = [DownBlock(in_channels, out_channels, **kwargs)] 308 | for i in range(num_layers - 1): 309 | layers.append(BasicBlock(out_channels, **kwargs)) 310 | return nn.Sequential(*layers) 311 | 312 | def forward(self, x): 313 | x = self.conv1(x) 314 | x = self.maxpool(x) 315 | x = self.stage2(x) 316 | x = self.stage3(x) 317 | x = self.stage4(x) 318 | if len(self.channels) == 5: 319 | x = self.conv5(x) 320 | ''' 321 | x = self.avgpool(x) 322 | x = x.view(x.size(0), -1) 323 | x = self.fc(x) 324 | ''' 325 | 326 | return x 327 | 328 
| def main(): 329 | #shape = (10, 320, 320, 3) #NHWC in tf/keras 330 | shape = (10, 16, 320, 320) 331 | nx = np.random.rand(*shape).astype(np.float32) 332 | t = torch.Tensor(nx) 333 | ''' 334 | g = ShuffleNetV2([4, 8, 4], [24, 116, 232, 464, 1024], CEM_FILTER) 335 | x = g(t) 336 | print(x.shape) #torch.Size([10, 245]) 337 | ''' 338 | #senet_49 = SNet() 339 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 340 | model = SNet().to(device) 341 | 342 | summary(model, (3, 320, 320)) 343 | 344 | 345 | if __name__ == "__main__": 346 | main() -------------------------------------------------------------------------------- /transforms/__init__.py: -------------------------------------------------------------------------------- 1 | import time 2 | import random 3 | import torchvision.transforms.functional as VF 4 | 5 | 6 | class Transform: 7 | 8 | def __init__(self): 9 | pass 10 | 11 | def __call__(self, input, target): 12 | pass 13 | 14 | def __repr__(self): 15 | return pprint(self) 16 | 17 | 18 | class JointTransform(Transform): 19 | 20 | def __init__(self, transform=None): 21 | super().__init__() 22 | self.transform = transform 23 | 24 | def __call__(self, input, target): 25 | return self.transform(input, target) 26 | 27 | def __repr__(self): 28 | return self.__class__.__name__ + '()' 29 | 30 | 31 | class InputTransform(Transform): 32 | 33 | def __init__(self, transform): 34 | super().__init__() 35 | self.transform = transform 36 | 37 | def __call__(self, input, target): 38 | return self.transform(input), target 39 | 40 | # def __repr__(self): 41 | # return pprint(self) 42 | # format_string = self.__class__.__name__ + '(' 43 | # format_string += '\n' 44 | # format_string += ' {0}'.format(self.transform) 45 | # format_string += '\n)' 46 | # return format_string 47 | 48 | 49 | class TargetTransform(Transform): 50 | 51 | def __init__(self, transform): 52 | super().__init__() 53 | self.transform = transform 54 | 55 | def __call__(self, input, target): 56 | return input, self.transform(target) 57 | 58 | # def __repr__(self): 59 | # return pprint(self) 60 | 61 | 62 | class Compose(Transform): 63 | """Composes several transforms together. 64 | 65 | Args: 66 | transforms (list of ``Transform`` objects): list of transforms to compose. 67 | 68 | Example: 69 | >>> transforms.Compose([ 70 | >>> transforms.CenterCrop(10), 71 | >>> transforms.ToTensor(), 72 | >>> ]) 73 | """ 74 | 75 | def __init__(self, transforms): 76 | self.transforms = transforms 77 | 78 | def __call__(self, img, target): 79 | for t in self.transforms: 80 | # start = time.time() 81 | img, target = t(img, target) 82 | # print("%.4f" % ((time.time() - start) * 1000)) 83 | return img, target 84 | 85 | # def __repr__(self): 86 | # return pprint(self) 87 | # format_string = self.__class__.__name__ + '(' 88 | # for t in self.transforms: 89 | # format_string += '\n' 90 | # format_string += ' {0}'.format(t) 91 | # format_string += '\n)' 92 | # return format_string 93 | 94 | 95 | class UseOriginal(Transform): 96 | """Use the original image and annotations. 
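Typically used as a no-op branch, e.g. inside ``RandomChoice`` as in ``SSDTransform`` (illustrative): >>> RandomChoice([UseOriginal(), RandomSampleCrop()])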
97 | """ 98 | 99 | def __init__(self): 100 | pass 101 | 102 | def __call__(self, img, target): 103 | return img, target 104 | 105 | 106 | class RandomApply(Transform): 107 | 108 | def __init__(self, transforms, p=0.5): 109 | self.transforms = transforms 110 | self.p = p 111 | 112 | def __call__(self, img, target): 113 | if random.random() < self.p: 114 | for t in self.transforms: 115 | img, target = t(img, target) 116 | return img, target 117 | 118 | # def __repr__(self): 119 | # return pprint(self) 120 | # format_string = self.__class__.__name__ + '(' 121 | # for t in self.transforms: 122 | # format_string += '\n' 123 | # format_string += ' {0}'.format(t) 124 | # format_string += '\n)' 125 | # return format_string 126 | 127 | 128 | class RandomChoice(Transform): 129 | """Apply single transformation randomly picked from a list. 130 | 131 | Args: 132 | transforms (list of ``Transform`` objects): list of transforms to compose. 133 | 134 | Example: 135 | >>> transforms.RandomChoice([ 136 | >>> transforms.CenterCrop(10), 137 | >>> transforms.ToTensor(), 138 | >>> ]) 139 | """ 140 | 141 | def __init__(self, transforms): 142 | self.transforms = transforms 143 | 144 | def __call__(self, img, target): 145 | t = random.choice(self.transforms) 146 | img, target = t(img, target) 147 | return img, target 148 | 149 | # def __repr__(self): 150 | # return pprint(self) 151 | # format_string = self.__class__.__name__ + '(' 152 | # for t in self.transforms: 153 | # format_string += '\n' 154 | # format_string += ' {0}'.format(t) 155 | # format_string += '\n)' 156 | # return format_string 157 | 158 | 159 | class ToTensor(JointTransform): 160 | 161 | def __init__(self): 162 | super().__init__() 163 | 164 | def __call__(self, img, anns): 165 | return VF.to_tensor(img), anns 166 | 167 | 168 | def pprint(t, level=0, sep=' '): 169 | pre = sep * level 170 | if not isinstance(t, Transform) or isinstance(t, JointTransform): 171 | return pre + repr(t) 172 | format_string = pre + type(t).__name__ + '(' 173 | if hasattr(t, 'transforms'): 174 | for t in getattr(t, 'transforms'): 175 | format_string += '\n' 176 | format_string += pprint(t, level + 1) 177 | format_string += '\n' 178 | format_string += pre + ')' 179 | elif hasattr(t, 'transform'): 180 | format_string += '\n' 181 | format_string += pprint(getattr(t, 'transform'), level + 1) 182 | format_string += '\n' 183 | format_string += pre + ')' 184 | else: 185 | format_string += ')' 186 | return format_string 187 | 188 | -------------------------------------------------------------------------------- /transforms/detection/__init__.py: -------------------------------------------------------------------------------- 1 | import random 2 | import math 3 | import warnings 4 | 5 | from typing import Sequence, Tuple 6 | 7 | from PIL import Image 8 | import torchvision.transforms.functional as VF 9 | from torchvision.transforms import ColorJitter 10 | 11 | from transforms import JointTransform, Compose, ToTensor, InputTransform, RandomChoice, RandomApply, UseOriginal 12 | from transforms.detection import functional as HF 13 | 14 | 15 | class RandomExpand(JointTransform): 16 | """ 17 | Expand the given PIL Image to random size. 18 | 19 | This is popularly used to train the SSD-like detectors. 20 | 21 | Parameters 22 | ---------- 23 | ratios : ``tuple`` 24 | Range of expand ratio. 
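Example (an illustrative sketch; ``img`` is a PIL Image and ``anns`` is assumed to be a list of dicts each holding a ``bbox`` of [l, t, w, h]): >>> expand = RandomExpand(ratios=(1, 4)) >>> img, anns = expand(img, anns)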
25 | """ 26 | 27 | def __init__(self, ratios=(1, 4)): 28 | super().__init__() 29 | self.ratios = ratios 30 | 31 | def __call__(self, img, anns): 32 | width, height = img.size 33 | ratio = random.uniform(*self.ratios) 34 | left = random.uniform(0, width * ratio - width) 35 | top = random.uniform(0, height * ratio - height) 36 | expand_image = Image.new( 37 | img.mode, (int(width * ratio), int(height * ratio))) 38 | expand_image.paste(img, (int(left), int(top))) 39 | 40 | new_anns = HF.move(anns, left, top) 41 | if len(new_anns) == 0: 42 | return img, anns 43 | return expand_image, new_anns 44 | 45 | def __repr__(self): 46 | format_string = self.__class__.__name__ 47 | format_string += '(ratio={0})'.format(tuple(round(r, 4) 48 | for r in self.ratios)) 49 | return format_string 50 | 51 | 52 | 53 | class RandomSampleCrop(JointTransform): 54 | """ 55 | Crop the given PIL Image to random size and aspect ratio. 56 | 57 | A crop of random size (default: of 0.08 to 1.0) of the original size and a random 58 | aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop 59 | is finally resized to given size. 60 | This is popularly used to train the Inception networks. 61 | 62 | Parameters 63 | ---------- 64 | min_ious : ``List[float]`` 65 | Range of minimal iou between the objects and the cropped image. 66 | aspect_ratio_constraints : ``tuple`` 67 | Range of cropped aspect ratio. 68 | """ 69 | 70 | def __init__(self, min_ious=(0.1, 0.3, 0.5, 0.9), aspect_ratio_constraints=(0.5, 2)): 71 | super().__init__() 72 | self.min_ious = min_ious 73 | min_ar, max_ar = aspect_ratio_constraints 74 | self.min_ar = min_ar 75 | self.max_ar = max_ar 76 | 77 | def __call__(self, img, anns): 78 | min_iou = random.choice(self.min_ious) 79 | returns = HF.random_sample_crop(anns, img.size, min_iou, self.min_ar, self.max_ar) 80 | if returns is None: 81 | return img, anns 82 | else: 83 | anns, l, t, w, h = returns 84 | new_img = img.crop([l, t, l + w, t + h]) 85 | new_anns = HF.crop(anns, l, t, w, h) 86 | if len(new_anns) == 0: 87 | return img, anns 88 | return new_img, new_anns 89 | 90 | 91 | class RandomResizedCrop(JointTransform): 92 | """ 93 | Crop the given PIL Image to random size and aspect ratio. 94 | 95 | A crop of random size (default: of 0.08 to 1.0) of the original size and a random 96 | aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop 97 | is finally resized to given size. 98 | This is popularly used to train the Inception networks. 99 | 100 | Parameters 101 | ---------- 102 | size : ``Union[Number, Sequence[int]]`` 103 | Desired output size of the crop. If size is an int instead of sequence like (w, h), 104 | a square crop (size, size) is made. 105 | scale : ``Tuple[float, float]`` 106 | Range of size of the origin size cropped. 107 | ratio: ``Tuple[float, float]`` 108 | Range of aspect ratio of the origin aspect ratio cropped. 109 | interpolation: 110 | Default: PIL.Image.BILINEAR 111 | min_area_frac: ``float`` 112 | Minimal area fraction requirement of the original bounding box. 113 | """ 114 | 115 | def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. 
/ 3.), min_area_frac=0.25, interpolation=Image.BILINEAR): 116 | super().__init__() 117 | if isinstance(size, tuple): 118 | self.size = size 119 | else: 120 | self.size = (size, size) 121 | if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): 122 | warnings.warn("range should be of kind (min, max)") 123 | 124 | self.interpolation = interpolation 125 | self.scale = scale 126 | self.ratio = ratio 127 | self.min_area_frac = min_area_frac 128 | 129 | @staticmethod 130 | def get_params(img, scale, ratio): 131 | """ 132 | Get parameters for ``crop`` for a random sized crop. 133 | 134 | Parameters 135 | ---------- 136 | img : ``Image`` 137 | Image to be cropped. 138 | scale : ``tuple`` 139 | Range of size of the origin size cropped. 140 | ratio : ``tuple`` 141 | Range of aspect ratio of the origin aspect ratio cropped. 142 | 143 | Returns 144 | ------- 145 | tuple 146 | Tarams (i, j, h, w) to be passed to ``crop`` for a random sized crop. 147 | """ 148 | width, height = img.size 149 | area = width * height 150 | 151 | for attempt in range(10): 152 | target_area = random.uniform(*scale) * area 153 | log_ratio = (math.log(ratio[0]), math.log(ratio[1])) 154 | aspect_ratio = math.exp(random.uniform(*log_ratio)) 155 | 156 | w = int(round(math.sqrt(target_area * aspect_ratio))) 157 | h = int(round(math.sqrt(target_area / aspect_ratio))) 158 | 159 | if w <= width and h <= height: 160 | i = random.randint(0, height - h) 161 | j = random.randint(0, width - w) 162 | return i, j, h, w 163 | 164 | # Fallback to central crop 165 | in_ratio = width / height 166 | if in_ratio < min(ratio): 167 | w = width 168 | h = w / min(ratio) 169 | elif in_ratio > max(ratio): 170 | h = height 171 | w = h * max(ratio) 172 | else: # whole image 173 | w = width 174 | h = height 175 | i = (height - h) // 2 176 | j = (width - w) // 2 177 | return i, j, h, w 178 | 179 | def __call__(self, img, anns): 180 | i, j, h, w = self.get_params(img, self.scale, self.ratio) 181 | new_anns = HF.resized_crop(anns, j, i, w, h, self.size, self.min_area_frac) 182 | if len(new_anns) == 0: 183 | return img, anns 184 | img = VF.resized_crop(img, i, j, h, w, self.size[::-1], self.interpolation) 185 | return img, new_anns 186 | 187 | def __repr__(self): 188 | format_string = self.__class__.__name__ + '(size={0}'.format(self.size) 189 | format_string += ', scale={0}'.format(tuple(round(s, 4) 190 | for s in self.scale)) 191 | format_string += ', ratio={0}'.format(tuple(round(r, 4) 192 | for r in self.ratio)) 193 | format_string += ', min_area_frac={0})'.format(self.min_area_frac) 194 | return format_string 195 | 196 | 197 | class Resize(JointTransform): 198 | """Resize the image and bounding boxes. 199 | 200 | Parameters 201 | ---------- 202 | size : ``Union[Number, Sequence[int]]`` 203 | Desired output size. If size is a sequence like (w, h), 204 | the output size will be matched to this. If size is an int, 205 | the smaller edge of the image will be matched to this number maintaing 206 | the aspect ratio. 
i.e, if width > height, then image will be rescaled to 207 | (output_size * width / height, output_size) 208 | """ 209 | 210 | def __init__(self, size): 211 | super().__init__() 212 | self.size = size 213 | 214 | def __call__(self, img, anns): 215 | if img.size == self.size: 216 | return img, anns 217 | 218 | anns = HF.resize(anns, img.size, self.size) 219 | if isinstance(self.size, Tuple): 220 | size = self.size[::-1] 221 | else: 222 | size = self.size 223 | img = VF.resize(img, size) 224 | return img, anns 225 | 226 | def __repr__(self): 227 | return self.__class__.__name__ + "(size=%s)" % (self.size,) 228 | 229 | 230 | class CenterCrop(JointTransform): 231 | """ 232 | Crops the given PIL Image at the center and transform the bounding boxes. 233 | 234 | Parameters 235 | ---------- 236 | size : ``Union[Number, Sequence[int]]`` 237 | Desired output size of the crop. If size is an int instead of sequence like (w, h), 238 | a square crop (size, size) is made. 239 | """ 240 | 241 | def __init__(self, size): 242 | super().__init__() 243 | self.size = size 244 | 245 | def __call__(self, img, anns): 246 | if isinstance(self.size, Tuple): 247 | size = self.size[::-1] 248 | else: 249 | size = self.size 250 | img = VF.center_crop(img, size) 251 | anns = HF.center_crop(anns, self.size) 252 | return img, anns 253 | 254 | def __repr__(self): 255 | return self.__class__.__name__ + "(size=%s)".format(self.size) 256 | 257 | 258 | class ToPercentCoords(JointTransform): 259 | 260 | def __init__(self): 261 | super().__init__() 262 | 263 | def __call__(self, img, anns): 264 | return img, HF.to_percent_coords(anns, img.size) 265 | 266 | def __repr__(self): 267 | return self.__class__.__name__ + "()" 268 | 269 | 270 | class ToAbsoluteCoords(JointTransform): 271 | 272 | def __init__(self): 273 | super().__init__() 274 | 275 | def __call__(self, img, anns): 276 | return img, HF.to_absolute_coords(anns, img.size) 277 | 278 | def __repr__(self): 279 | return self.__class__.__name__ + "()" 280 | 281 | 282 | class RandomHorizontalFlip(JointTransform): 283 | """Horizontally flip the given PIL Image randomly with a given probability. 284 | 285 | Args: 286 | p (float): probability of the image being flipped. Default value is 0.5 287 | """ 288 | 289 | def __init__(self, p=0.5): 290 | super().__init__() 291 | self.p = p 292 | 293 | def __call__(self, img, anns): 294 | if random.random() < self.p: 295 | img = VF.hflip(img) 296 | anns = HF.hflip(anns, img.size) 297 | return img, anns 298 | return img, anns 299 | 300 | def __repr__(self): 301 | return self.__class__.__name__ + '(p={})'.format(self.p) 302 | 303 | 304 | class RandomVerticalFlip(JointTransform): 305 | """Vertically flip the given PIL Image randomly with a given probability. 306 | 307 | Args: 308 | p (float): probability of the image being flipped. 
Default value is 0.5 309 | """ 310 | 311 | def __init__(self, p=0.5): 312 | super().__init__() 313 | self.p = p 314 | 315 | def __call__(self, img, anns): 316 | if random.random() < self.p: 317 | img = VF.vflip(img) 318 | anns = HF.vflip(anns, img.size) 319 | return img, anns 320 | return img, anns 321 | 322 | def __repr__(self): 323 | return self.__class__.__name__ + '(p={})'.format(self.p) 324 | 325 | 326 | def SSDTransform(size, color_jitter=True, scale=(0.1, 1), expand=(1, 4), min_area_frac=0.25): 327 | transforms = [] 328 | if color_jitter: 329 | transforms.append( 330 | InputTransform( 331 | ColorJitter( 332 | brightness=0.1, contrast=0.5, 333 | saturation=0.5, hue=0.05, 334 | ) 335 | ) 336 | ) 337 | transforms += [ 338 | RandomApply([ 339 | RandomExpand(expand), 340 | ]), 341 | RandomChoice([ 342 | UseOriginal(), 343 | RandomSampleCrop(), 344 | RandomResizedCrop(size, scale=scale, ratio=(1/2, 2/1), min_area_frac=min_area_frac), 345 | ]), 346 | RandomHorizontalFlip(), 347 | Resize(size) 348 | ] 349 | return Compose(transforms) 350 | -------------------------------------------------------------------------------- /transforms/detection/functional.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Sequence, Union, Tuple 2 | from numbers import Number 3 | import random 4 | 5 | import numpy as np 6 | from toolz import curry 7 | from toolz.curried import get 8 | 9 | from common import _tuple 10 | 11 | __all__ = [ 12 | "resize", "resized_crop", "center_crop", "drop_boundary_bboxes", 13 | "to_absolute_coords", "to_percent_coords", "hflip", "hflip2", 14 | "vflip", "vflip2", "random_sample_crop", "move" 15 | ] 16 | 17 | 18 | def iou_1m(box, boxes): 19 | r""" 20 | Calculates one-to-many ious. 21 | 22 | Parameters 23 | ---------- 24 | box : ``Sequences[Number]`` 25 | A bounding box. 26 | boxes : ``array_like`` 27 | Many bounding boxes. 28 | 29 | Returns 30 | ------- 31 | ious : ``array_like`` 32 | IoUs between the box and boxes. 33 | """ 34 | xi1 = np.maximum(boxes[..., 0], box[0]) 35 | yi1 = np.maximum(boxes[..., 1], box[1]) 36 | xi2 = np.minimum(boxes[..., 2], box[2]) 37 | yi2 = np.minimum(boxes[..., 3], box[3]) 38 | xdiff = xi2 - xi1 39 | ydiff = yi2 - yi1 40 | inter_area = xdiff * ydiff 41 | box_area = (box[2] - box[0]) * (box[3] - box[1]) 42 | boxes_area = (boxes[..., 2] - boxes[..., 0]) * \ 43 | (boxes[..., 3] - boxes[..., 1]) 44 | union_area = boxes_area + box_area - inter_area 45 | 46 | iou = inter_area / union_area 47 | iou[xdiff < 0] = 0 48 | iou[ydiff < 0] = 0 49 | return iou 50 | 51 | 52 | def random_sample_crop(anns, size, min_iou, min_ar, max_ar, max_attemps=50): 53 | """ 54 | Crop the given PIL Image to random size and aspect ratio. 55 | 56 | A crop of random size (default: of 0.08 to 1.0) of the original size and a random 57 | aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop 58 | is finally resized to given size. 59 | This is popularly used to train the Inception networks. 60 | 61 | Parameters 62 | ---------- 63 | anns : ``List[Dict]`` 64 | Sequences of annotation of objects, containing `bbox` of [l, t, w, h]. 65 | size : ``Sequence[int]`` 66 | Size of the original image. 67 | min_iou : ``float`` 68 | Minimal iou between the objects and the cropped image. 69 | min_ar : ``Number`` 70 | Minimal aspect ratio. 71 | max_ar : ``Number`` 72 | Maximum aspect ratio. 73 | max_attemps: ``int`` 74 | Maximum attemps to try. 
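Example (illustrative; mirrors the usage in ``RandomSampleCrop``, which falls back to the original annotations when ``None`` is returned): >>> returns = random_sample_crop(anns, img.size, min_iou=0.5, min_ar=0.5, max_ar=2) >>> if returns is not None: ... anns, l, t, w, h = returns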
75 | """ 76 | width, height = size 77 | bboxes = np.stack([ann['bbox'] for ann in anns]) 78 | bboxes[:, 2:] += bboxes[:, :2] 79 | for _ in range(max_attemps): 80 | w = random.uniform(0.3 * width, width) 81 | h = random.uniform(0.3 * height, height) 82 | 83 | if h / w < min_ar or h / w > max_ar: 84 | continue 85 | 86 | l = random.uniform(0, width - w) 87 | t = random.uniform(0, height - h) 88 | r = l + w 89 | b = t + h 90 | 91 | patch = np.array([l, t, r, b]) 92 | ious = iou_1m(patch, bboxes) 93 | if ious.min() < min_iou: 94 | continue 95 | 96 | centers = (bboxes[:, :2] + bboxes[:, 2:]) / 2.0 97 | mask = (l < centers[:, 0]) & (centers[:, 0] < r) & ( 98 | t < centers[:, 1]) & (centers[:, 1] < b) 99 | 100 | if not mask.any(): 101 | continue 102 | indices = np.nonzero(mask)[0].tolist() 103 | return get(indices, anns), l, t, w, h 104 | return None 105 | 106 | 107 | @curry 108 | def resized_crop(anns, left, upper, width, height, output_size, min_area_frac): 109 | anns = crop(anns, left, upper, width, height, min_area_frac) 110 | size = (width, height) 111 | # if drop: 112 | # anns = drop_boundary_bboxes(anns, size) 113 | anns = resize(anns, size, output_size) 114 | return anns 115 | 116 | 117 | @curry 118 | def drop_boundary_bboxes(anns, size): 119 | r""" 120 | Drop bounding boxes whose centers are out of the image boundary. 121 | 122 | Parameters 123 | ---------- 124 | anns : ``List[Dict]`` 125 | Sequences of annotation of objects, containing `bbox` of [l, t, w, h]. 126 | size : ``Sequence[int]`` 127 | Size of the original image. 128 | """ 129 | width, height = size 130 | new_anns = [] 131 | for ann in anns: 132 | l, t, w, h = ann['bbox'] 133 | x = l + w / 2. 134 | y = t + h / 2. 135 | if 0 <= x <= width and 0 <= y <= height: 136 | new_anns.append({**ann, "bbox": [l, t, w, h]}) 137 | return new_anns 138 | 139 | 140 | @curry 141 | def center_crop(anns, size, output_size): 142 | r""" 143 | Crops the bounding boxes of the given PIL Image at the center. 144 | 145 | Parameters 146 | ---------- 147 | anns : ``List[Dict]`` 148 | Sequences of annotation of objects, containing `bbox` of [l, t, w, h]. 149 | size : ``Sequence[int]`` 150 | Size of the original image. 151 | output_size : ``Union[Number, Sequence[int]]`` 152 | Desired output size of the crop. If output_size is an int instead of a sequence like (h, w), 153 | a square crop (size, size) is made. 154 | """ 155 | output_size = _tuple(output_size, 2) 156 | output_size = tuple(int(x) for x in output_size) 157 | w, h = size 158 | th, tw = output_size 159 | upper = int(round((h - th) / 2.)) 160 | left = int(round((w - tw) / 2.)) 161 | return crop(anns, left, upper, tw, th) 162 | 163 | 164 | @curry 165 | def crop(anns, left, upper, width, height, minimal_area_fraction=0.25): 166 | r""" 167 | Crop the bounding boxes of the given PIL Image. 168 | 169 | Parameters 170 | ---------- 171 | anns : ``List[Dict]`` 172 | Sequences of annotation of objects, containing `bbox` of [l, t, w, h]. 173 | left: ``int`` 174 | Left pixel coordinate. 175 | upper: ``int`` 176 | Upper pixel coordinate. 177 | width: ``int`` 178 | Width of the cropped image. 179 | height: ``int`` 180 | Height of the cropped image. 181 | minimal_area_fraction : ``float`` 182 | Minimal area fraction requirement.
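Example (illustrative; only the ``bbox`` field is required): >>> crop([{'bbox': [50, 40, 60, 30]}], left=30, upper=20, width=100, height=100) [{'bbox': [20, 20, 60, 30]}]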
183 | """ 184 | new_anns = [] 185 | for ann in anns: 186 | l, t, w, h = ann['bbox'] 187 | area = w * h 188 | l -= left 189 | t -= upper 190 | if l + w >= 0 and l <= width and t + h >= 0 and t <= height: 191 | if l < 0: 192 | w += l 193 | l = 0 194 | if t < 0: 195 | h += t 196 | t = 0 197 | w = min(width - l, w) 198 | h = min(height - t, h) 199 | if w * h < area * minimal_area_fraction: 200 | continue 201 | new_anns.append({**ann, "bbox": [l, t, w, h]}) 202 | return new_anns 203 | 204 | 205 | @curry 206 | def resize(anns, size, output_size): 207 | """ 208 | Parameters 209 | ---------- 210 | anns : List[Dict] 211 | Sequences of annotation of objects, containing `bbox` of [l, t, w, h]. 212 | size : Sequence[int] 213 | Size of the original image. 214 | output_size : Union[Number, Sequence[int]] 215 | Desired output size. If size is a sequence like (w, h), the output size will be matched to this. 216 | If size is an int, the smaller edge of the image will be matched to this number maintaing 217 | the aspect ratio. i.e, if width > height, then image will be rescaled to 218 | (output_size * width / height, output_size) 219 | """ 220 | w, h = size 221 | if isinstance(output_size, int): 222 | if (w <= h and w == output_size) or (h <= w and h == output_size): 223 | return anns 224 | if w < h: 225 | ow = output_size 226 | sw = sh = ow / w 227 | else: 228 | oh = output_size 229 | sw = sh = oh / h 230 | else: 231 | ow, oh = output_size 232 | sw = ow / w 233 | sh = oh / h 234 | new_anns = [] 235 | for ann in anns: 236 | bbox = list(ann['bbox']) 237 | bbox[0] *= sw 238 | bbox[1] *= sh 239 | bbox[2] *= sw 240 | bbox[3] *= sh 241 | new_anns.append({**ann, "bbox": bbox}) 242 | return new_anns 243 | 244 | 245 | @curry 246 | def to_percent_coords(anns, size): 247 | r""" 248 | Convert absolute coordinates of the bounding boxes to percent cocoordinates. 249 | 250 | Parameters 251 | ---------- 252 | anns : ``List[Dict]`` 253 | Sequences of annotation of objects, containing `bbox` of [l, t, w, h]. 254 | size : ``Sequence[int]`` 255 | Size of the original image. 256 | """ 257 | w, h = size 258 | new_anns = [] 259 | for ann in anns: 260 | bbox = list(ann['bbox']) 261 | bbox[0] /= w 262 | bbox[1] /= h 263 | bbox[2] /= w 264 | bbox[3] /= h 265 | new_anns.append({**ann, "bbox": bbox}) 266 | return new_anns 267 | 268 | 269 | @curry 270 | def to_absolute_coords(anns, size): 271 | r""" 272 | Convert percent coordinates of the bounding boxes to absolute cocoordinates. 273 | 274 | Parameters 275 | ---------- 276 | anns : ``List[Dict]`` 277 | Sequences of annotation of objects, containing `bbox` of [l, t, w, h]. 278 | size : ``Sequence[int]`` 279 | Size of the original image. 280 | """ 281 | w, h = size 282 | new_anns = [] 283 | for ann in anns: 284 | bbox = list(ann['bbox']) 285 | bbox[0] *= w 286 | bbox[1] *= h 287 | bbox[2] *= w 288 | bbox[3] *= h 289 | new_anns.append({**ann, "bbox": bbox}) 290 | return new_anns 291 | 292 | 293 | @curry 294 | def hflip(anns, size): 295 | """ 296 | Horizontally flip the bounding boxes of the given PIL Image. 297 | 298 | Parameters 299 | ---------- 300 | anns : ``List[Dict]`` 301 | Sequences of annotation of objects, containing `bbox` of [l, t, w, h]. 302 | size : ``Sequence[int]`` 303 | Size of the original image. 
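Example (illustrative; for a 100x200 image, a box starting at x=10 with width 30 is mirrored to start at x=60): >>> hflip([{'bbox': [10, 20, 30, 40]}], size=(100, 200)) [{'bbox': [60, 20, 30, 40]}]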
304 |     """
305 |     w, h = size
306 |     new_anns = []
307 |     for ann in anns:
308 |         bbox = list(ann['bbox'])
309 |         bbox[0] = w - (bbox[0] + bbox[2])
310 |         new_anns.append({**ann, "bbox": bbox})
311 |     return new_anns
312 | 
313 | 
314 | @curry
315 | def hflip2(anns, size):
316 |     """
317 |     Horizontally flip the bounding boxes of the given PIL Image.
318 | 
319 |     Parameters
320 |     ----------
321 |     anns : ``List[Dict]``
322 |         Sequences of annotation of objects, containing `bbox` of [l, t, r, b].
323 |     size : ``Sequence[int]``
324 |         Size of the original image.
325 |     """
326 |     w, h = size
327 |     new_anns = []
328 |     for ann in anns:
329 |         bbox = list(ann['bbox'])
330 |         l = bbox[0]
331 |         bbox[0] = w - bbox[2]
332 |         bbox[2] = w - l
333 |         new_anns.append({**ann, "bbox": bbox})
334 |     return new_anns
335 | 
336 | 
337 | @curry
338 | def vflip(anns, size):
339 |     """
340 |     Vertically flip the bounding boxes of the given PIL Image.
341 | 
342 |     Parameters
343 |     ----------
344 |     anns : ``List[Dict]``
345 |         Sequences of annotation of objects, containing `bbox` of [l, t, w, h].
346 |     size : ``Sequence[int]``
347 |         Size of the original image.
348 |     """
349 |     w, h = size
350 |     new_anns = []
351 |     for ann in anns:
352 |         bbox = list(ann['bbox'])
353 |         bbox[1] = h - (bbox[1] + bbox[3])
354 |         new_anns.append({**ann, "bbox": bbox})
355 |     return new_anns
356 | 
357 | 
358 | @curry
359 | def vflip2(anns, size):
360 |     r"""
361 |     Vertically flip the bounding boxes of the given PIL Image.
362 | 
363 |     Parameters
364 |     ----------
365 |     anns : ``List[Dict]``
366 |         Sequences of annotation of objects, containing `bbox` of [l, t, r, b].
367 |     size : ``Sequence[int]``
368 |         Size of the original image.
369 |     """
370 |     w, h = size
371 | 
372 |     new_anns = []
373 |     for ann in anns:
374 |         bbox = list(ann['bbox'])
375 |         t = bbox[1]
376 |         bbox[1] = h - bbox[3]
377 |         bbox[3] = h - t
378 |         new_anns.append({**ann, "bbox": bbox})
379 |     return new_anns
380 | 
381 | 
382 | @curry
383 | def move(anns, x, y):
384 |     r"""
385 |     Move the bounding boxes by x and y.
386 | 
387 |     Parameters
388 |     ----------
389 |     anns : ``List[Dict]``
390 |         Sequences of annotation of objects, containing `bbox` of [l, t, w, h].
391 |     x : ``Number``
392 |         How far to move along the horizontal axis.
393 |     y : ``Number``
394 |         How far to move along the vertical axis.
395 |     """
396 | 
397 |     new_anns = []
398 |     for ann in anns:
399 |         bbox = list(ann['bbox'])
400 |         bbox[0] += x
401 |         bbox[1] += y
402 |         new_anns.append({**ann, "bbox": bbox})
403 |     return new_anns
404 | 
405 | 
--------------------------------------------------------------------------------
/transforms/ext.py:
--------------------------------------------------------------------------------
1 | import random
2 | 
3 | import numpy as np
4 | from PIL import Image, ImageEnhance, ImageOps
5 | 
6 | import torch
7 | 
8 | 
9 | class Cutout:
10 |     """Randomly mask out one or more patches from an image.
11 | 
12 |     Optimal length: 16 for CIFAR10, 8 for CIFAR100, 20 for SVHN, 24 or 32 for STL10
13 | 
14 |     Note:
15 |         It should be put after ToTensor().
16 | 
17 |     Args:
18 |         n_holes (int): Number of patches to cut out of each image.
19 |         length (int): The length (in pixels) of each square patch.
20 |     """
21 | 
22 |     def __init__(self, n_holes, length):
23 |         self.n_holes = n_holes
24 |         self.length = length
25 | 
26 |     def __call__(self, img):
27 |         """
28 |         Args:
29 |             img (Tensor): Tensor image of size (C, H, W).
30 |         Returns:
31 |             Tensor: Image with n_holes of dimension length x length cut out of it.
32 | """ 33 | h = img.size(1) 34 | w = img.size(2) 35 | 36 | mask = np.ones((h, w), np.float32) 37 | 38 | for n in range(self.n_holes): 39 | y = np.random.randint(h) 40 | x = np.random.randint(w) 41 | 42 | y1 = np.clip(y - self.length // 2, 0, h) 43 | y2 = np.clip(y + self.length // 2, 0, h) 44 | x1 = np.clip(x - self.length // 2, 0, w) 45 | x2 = np.clip(x + self.length // 2, 0, w) 46 | 47 | mask[y1: y2, x1: x2] = 0. 48 | 49 | mask = torch.from_numpy(mask) 50 | mask = mask.expand_as(img) 51 | img = img * mask 52 | 53 | return img 54 | 55 | 56 | class ImageNetPolicy(object): 57 | """ Randomly choose one of the best 24 Sub-policies on ImageNet. 58 | 59 | Example: 60 | >>> policy = ImageNetPolicy() 61 | >>> transformed = policy(image) 62 | 63 | Example as a PyTorch Transform: 64 | >>> transform=transforms.Compose([ 65 | >>> transforms.Resize(256), 66 | >>> ImageNetPolicy(), 67 | >>> transforms.ToTensor()]) 68 | """ 69 | 70 | def __init__(self, fillcolor=(128, 128, 128)): 71 | self.policies = [ 72 | SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor), 73 | SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor), 74 | SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor), 75 | SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor), 76 | SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor), 77 | 78 | SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor), 79 | SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor), 80 | SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor), 81 | SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor), 82 | SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor), 83 | 84 | SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor), 85 | SubPolicy(0.4, "rotate", 9, 0.6, "equalize", 2, fillcolor), 86 | SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor), 87 | SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor), 88 | SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor), 89 | 90 | SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor), 91 | SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor), 92 | SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor), 93 | SubPolicy(0.6, "shearX", 5, 1.0, "equalize", 9, fillcolor), 94 | SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor), 95 | 96 | SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor), 97 | SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor), 98 | SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor), 99 | SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor), 100 | SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor) 101 | ] 102 | 103 | def __call__(self, img): 104 | policy_idx = random.randint(0, len(self.policies) - 1) 105 | return self.policies[policy_idx](img) 106 | 107 | def __repr__(self): 108 | return "AutoAugment ImageNet Policy" 109 | 110 | 111 | class CIFAR10Policy(object): 112 | """ Randomly choose one of the best 25 Sub-policies on CIFAR10. 
113 | 114 | Example: 115 | >>> policy = CIFAR10Policy() 116 | >>> transformed = policy(image) 117 | 118 | Example as a PyTorch Transform: 119 | >>> transform=transforms.Compose([ 120 | >>> transforms.Resize(256), 121 | >>> CIFAR10Policy(), 122 | >>> transforms.ToTensor()]) 123 | """ 124 | 125 | def __init__(self, fillcolor=(128, 128, 128)): 126 | self.policies = [ 127 | SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor), 128 | SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor), 129 | SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor), 130 | SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor), 131 | SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor), 132 | 133 | SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor), 134 | SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor), 135 | SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor), 136 | SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor), 137 | SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor), 138 | 139 | SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor), 140 | SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor), 141 | SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor), 142 | SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor), 143 | SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor), 144 | 145 | SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor), 146 | SubPolicy(0.2, "equalize", 8, 0.8, "equalize", 4, fillcolor), 147 | SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor), 148 | SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor), 149 | SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor), 150 | 151 | SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor), 152 | SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor), 153 | SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor), 154 | SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor), 155 | SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor) 156 | ] 157 | 158 | def __call__(self, img): 159 | policy_idx = random.randint(0, len(self.policies) - 1) 160 | return self.policies[policy_idx](img) 161 | 162 | def __repr__(self): 163 | return "AutoAugment CIFAR10 Policy" 164 | 165 | 166 | class SVHNPolicy(object): 167 | """ Randomly choose one of the best 25 Sub-policies on SVHN. 
168 | 169 | Example: 170 | >>> policy = SVHNPolicy() 171 | >>> transformed = policy(image) 172 | 173 | Example as a PyTorch Transform: 174 | >>> transform=transforms.Compose([ 175 | >>> transforms.Resize(256), 176 | >>> SVHNPolicy(), 177 | >>> transforms.ToTensor()]) 178 | """ 179 | 180 | def __init__(self, fillcolor=(128, 128, 128)): 181 | self.policies = [ 182 | SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3, fillcolor), 183 | SubPolicy(0.9, "shearY", 8, 0.7, "invert", 5, fillcolor), 184 | SubPolicy(0.6, "equalize", 5, 0.6, "solarize", 6, fillcolor), 185 | SubPolicy(0.9, "invert", 3, 0.6, "equalize", 3, fillcolor), 186 | SubPolicy(0.6, "equalize", 1, 0.9, "rotate", 3, fillcolor), 187 | 188 | SubPolicy(0.9, "shearX", 4, 0.8, "autocontrast", 3, fillcolor), 189 | SubPolicy(0.9, "shearY", 8, 0.4, "invert", 5, fillcolor), 190 | SubPolicy(0.9, "shearY", 5, 0.2, "solarize", 6, fillcolor), 191 | SubPolicy(0.9, "invert", 6, 0.8, "autocontrast", 1, fillcolor), 192 | SubPolicy(0.6, "equalize", 3, 0.9, "rotate", 3, fillcolor), 193 | 194 | SubPolicy(0.9, "shearX", 4, 0.3, "solarize", 3, fillcolor), 195 | SubPolicy(0.8, "shearY", 8, 0.7, "invert", 4, fillcolor), 196 | SubPolicy(0.9, "equalize", 5, 0.6, "translateY", 6, fillcolor), 197 | SubPolicy(0.9, "invert", 4, 0.6, "equalize", 7, fillcolor), 198 | SubPolicy(0.3, "contrast", 3, 0.8, "rotate", 4, fillcolor), 199 | 200 | SubPolicy(0.8, "invert", 5, 0.0, "translateY", 2, fillcolor), 201 | SubPolicy(0.7, "shearY", 6, 0.4, "solarize", 8, fillcolor), 202 | SubPolicy(0.6, "invert", 4, 0.8, "rotate", 4, fillcolor), 203 | SubPolicy(0.3, "shearY", 7, 0.9, "translateX", 3, fillcolor), 204 | SubPolicy(0.1, "shearX", 6, 0.6, "invert", 5, fillcolor), 205 | 206 | SubPolicy(0.7, "solarize", 2, 0.6, "translateY", 7, fillcolor), 207 | SubPolicy(0.8, "shearY", 4, 0.8, "invert", 8, fillcolor), 208 | SubPolicy(0.7, "shearX", 9, 0.8, "translateY", 3, fillcolor), 209 | SubPolicy(0.8, "shearY", 5, 0.7, "autocontrast", 3, fillcolor), 210 | SubPolicy(0.7, "shearX", 2, 0.1, "invert", 5, fillcolor) 211 | ] 212 | 213 | def __call__(self, img): 214 | policy_idx = random.randint(0, len(self.policies) - 1) 215 | return self.policies[policy_idx](img) 216 | 217 | def __repr__(self): 218 | return "AutoAugment SVHN Policy" 219 | 220 | 221 | class SubPolicy(object): 222 | def __init__(self, p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2, fillcolor=(128, 128, 128)): 223 | ranges = { 224 | "shearX": np.linspace(0, 0.3, 10), 225 | "shearY": np.linspace(0, 0.3, 10), 226 | "translateX": np.linspace(0, 150 / 331, 10), 227 | "translateY": np.linspace(0, 150 / 331, 10), 228 | "rotate": np.linspace(0, 30, 10), 229 | "color": np.linspace(0.0, 0.9, 10), 230 | "posterize": np.round(np.linspace(8, 4, 10), 0).astype(np.int), 231 | "solarize": np.linspace(256, 0, 10), 232 | "contrast": np.linspace(0.0, 0.9, 10), 233 | "sharpness": np.linspace(0.0, 0.9, 10), 234 | "brightness": np.linspace(0.0, 0.9, 10), 235 | "autocontrast": [0] * 10, 236 | "equalize": [0] * 10, 237 | "invert": [0] * 10 238 | } 239 | 240 | # from https://stackoverflow.com/questions/5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand 241 | def rotate_with_fill(img, magnitude): 242 | rot = img.convert("RGBA").rotate(magnitude) 243 | return Image.composite(rot, Image.new("RGBA", rot.size, (128,) * 4), rot).convert(img.mode) 244 | 245 | func = { 246 | "shearX": lambda img, magnitude: img.transform( 247 | img.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0), 248 | 
Image.BICUBIC, fillcolor=fillcolor), 249 | "shearY": lambda img, magnitude: img.transform( 250 | img.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0), 251 | Image.BICUBIC, fillcolor=fillcolor), 252 | "translateX": lambda img, magnitude: img.transform( 253 | img.size, Image.AFFINE, (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0), 254 | fillcolor=fillcolor), 255 | "translateY": lambda img, magnitude: img.transform( 256 | img.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])), 257 | fillcolor=fillcolor), 258 | "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude), 259 | # "rotate": lambda img, magnitude: img.rotate(magnitude * random.choice([-1, 1])), 260 | "color": lambda img, magnitude: ImageEnhance.Color(img).enhance(1 + magnitude * random.choice([-1, 1])), 261 | "posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude), 262 | "solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude), 263 | "contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance( 264 | 1 + magnitude * random.choice([-1, 1])), 265 | "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance( 266 | 1 + magnitude * random.choice([-1, 1])), 267 | "brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance( 268 | 1 + magnitude * random.choice([-1, 1])), 269 | "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img), 270 | "equalize": lambda img, magnitude: ImageOps.equalize(img), 271 | "invert": lambda img, magnitude: ImageOps.invert(img) 272 | } 273 | 274 | # self.name = "{}_{:.2f}_and_{}_{:.2f}".format( 275 | # operation1, ranges[operation1][magnitude_idx1], 276 | # operation2, ranges[operation2][magnitude_idx2]) 277 | self.p1 = p1 278 | self.operation1 = func[operation1] 279 | self.magnitude1 = ranges[operation1][magnitude_idx1] 280 | self.p2 = p2 281 | self.operation2 = func[operation2] 282 | self.magnitude2 = ranges[operation2][magnitude_idx2] 283 | 284 | def __call__(self, img): 285 | if random.random() < self.p1: img = self.operation1(img, self.magnitude1) 286 | if random.random() < self.p2: img = self.operation2(img, self.magnitude2) 287 | return img 288 | -------------------------------------------------------------------------------- /transforms/segmentation/__init__.py: -------------------------------------------------------------------------------- 1 | import numbers 2 | import random 3 | 4 | import torch 5 | import numpy as np 6 | 7 | import torchvision.transforms.functional as TF 8 | from PIL import Image 9 | from transforms import JointTransform 10 | from typing import Iterable 11 | 12 | 13 | class SameTransform(JointTransform): 14 | 15 | def __init__(self, t): 16 | super().__init__() 17 | self.t = t 18 | 19 | def __call__(self, img, mask): 20 | return self.t(img), self.t(mask) 21 | 22 | 23 | class ToTensor(JointTransform): 24 | """Convert the input ``PIL Image`` to tensor and the target segmentation image to labels. 25 | """ 26 | 27 | def __init__(self): 28 | super().__init__() 29 | 30 | def __call__(self, img, mask): 31 | input = TF.to_tensor(img) 32 | target = np.array(mask) 33 | target = torch.from_numpy(target).long() 34 | return input, target 35 | 36 | 37 | class Resize(JointTransform): 38 | """Resize the input PIL Image to the given size. 39 | 40 | Args: 41 | size (sequence or int): Desired output size. If size is a sequence like 42 | (h, w), output size will be matched to this. 
If size is an int, 43 | smaller edge of the image will be matched to this number. 44 | i.e, if height > width, then image will be rescaled to 45 | (size * height / width, size) 46 | interpolation (int, optional): Desired interpolation. Default is 47 | ``PIL.Image.BILINEAR`` 48 | """ 49 | 50 | def __init__(self, size): 51 | super().__init__() 52 | assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2) 53 | self.size = size 54 | 55 | def __call__(self, img, mask): 56 | """ 57 | Args: 58 | img (PIL Image): Image to be scaled. 59 | 60 | Returns: 61 | PIL Image: Rescaled image. 62 | """ 63 | img = TF.resize(img, self.size, Image.BILINEAR) 64 | mask = TF.resize(mask, self.size, Image.NEAREST) 65 | return img, mask 66 | 67 | def __repr__(self): 68 | return self.__class__.__name__ + '(size={0})'.format(self.size) 69 | 70 | 71 | class RandomCrop(JointTransform): 72 | """Crop the given PIL Image at a random location. 73 | 74 | Args: 75 | size (sequence or int): Desired output size of the crop. If size is an 76 | int instead of sequence like (h, w), a square crop (size, size) is 77 | made. 78 | padding (int or sequence, optional): Optional padding on each border 79 | of the image. Default is None, i.e no padding. If a sequence of length 80 | 4 is provided, it is used to pad left, top, right, bottom borders 81 | respectively. If a sequence of length 2 is provided, it is used to 82 | pad left/right, top/bottom borders, respectively. 83 | pad_if_needed (boolean): It will pad the image if smaller than the 84 | desired size to avoid raising an exception. 85 | fill: Pixel fill value for constant fill. Default is 0. If a tuple of 86 | length 3, it is used to fill R, G, B channels respectively. 87 | This value is only used when the padding_mode is constant 88 | padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant. 89 | 90 | - constant: pads with a constant value, this value is specified with fill 91 | 92 | - edge: pads with the last value on the edge of the image 93 | 94 | - reflect: pads with reflection of image (without repeating the last value on the edge) 95 | 96 | padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode 97 | will result in [3, 2, 1, 2, 3, 4, 3, 2] 98 | 99 | - symmetric: pads with reflection of image (repeating the last value on the edge) 100 | 101 | padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode 102 | will result in [2, 1, 1, 2, 3, 4, 4, 3] 103 | 104 | """ 105 | 106 | def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant'): 107 | super().__init__() 108 | if isinstance(size, numbers.Number): 109 | self.size = (int(size), int(size)) 110 | else: 111 | self.size = size 112 | self.padding = padding 113 | self.pad_if_needed = pad_if_needed 114 | self.fill = fill 115 | self.padding_mode = padding_mode 116 | 117 | @staticmethod 118 | def get_params(img, output_size): 119 | """Get parameters for ``crop`` for a random crop. 120 | 121 | Args: 122 | img (PIL Image): Image to be cropped. 123 | output_size (tuple): Expected output size of the crop. 124 | 125 | Returns: 126 | tuple: params (i, j, h, w) to be passed to ``crop`` for random crop. 127 | """ 128 | w, h = img.size 129 | th, tw = output_size 130 | if w == tw and h == th: 131 | return 0, 0, h, w 132 | 133 | i = random.randint(0, h - th) 134 | j = random.randint(0, w - tw) 135 | return i, j, th, tw 136 | 137 | def __call__(self, img, mask): 138 | """ 139 | Args: 140 | img (PIL Image): Image to be cropped. 
141 | 142 | Returns: 143 | PIL Image: Cropped image. 144 | """ 145 | if self.padding is not None: 146 | img = TF.pad(img, self.padding, self.fill, self.padding_mode) 147 | mask = TF.pad(mask, self.padding, self.fill, self.padding_mode) 148 | 149 | # pad the width if needed 150 | if self.pad_if_needed and img.size[0] < self.size[1]: 151 | img = TF.pad(img, (self.size[1] - img.size[0], 0), self.fill, self.padding_mode) 152 | mask = TF.pad(mask, (self.size[1] - img.size[0], 0), self.fill, self.padding_mode) 153 | # pad the height if needed 154 | if self.pad_if_needed and img.size[1] < self.size[0]: 155 | img = TF.pad(img, (0, self.size[0] - img.size[1]), self.fill, self.padding_mode) 156 | mask = TF.pad(mask, (0, self.size[0] - img.size[1]), self.fill, self.padding_mode) 157 | 158 | i, j, h, w = self.get_params(img, self.size) 159 | 160 | img = TF.crop(img, i, j, h, w) 161 | mask = TF.crop(mask, i, j, h, w) 162 | return img, mask 163 | 164 | def __repr__(self): 165 | return self.__class__.__name__ + '(size={0}, padding={1})'.format(self.size, self.padding) 166 | 167 | 168 | class CenterCrop(JointTransform): 169 | """Crops the given PIL Image at the center. 170 | 171 | Args: 172 | size (sequence or int): Desired output size of the crop. If size is an 173 | int instead of sequence like (h, w), a square crop (size, size) is 174 | made. 175 | """ 176 | 177 | def __init__(self, size): 178 | super().__init__() 179 | if isinstance(size, numbers.Number): 180 | self.size = (int(size), int(size)) 181 | else: 182 | self.size = size 183 | 184 | def __call__(self, img, mask): 185 | """ 186 | Args: 187 | img (PIL Image): Image to be cropped. 188 | 189 | Returns: 190 | PIL Image: Cropped image. 191 | """ 192 | img = TF.center_crop(img, self.size) 193 | mask = TF.center_crop(mask, self.size) 194 | return img, mask 195 | 196 | def __repr__(self): 197 | return self.__class__.__name__ + '(size={0})'.format(self.size) 198 | 199 | 200 | class RandomHorizontalFlip(JointTransform): 201 | """Horizontally flip the given PIL Image randomly with a given probability. 202 | 203 | Args: 204 | p (float): probability of the image being flipped. Default value is 0.5 205 | """ 206 | 207 | def __init__(self, p=0.5): 208 | super().__init__() 209 | self.p = p 210 | 211 | def __call__(self, img, mask): 212 | """ 213 | Args: 214 | img (PIL Image): Image to be flipped. 215 | 216 | Returns: 217 | PIL Image: Randomly flipped image. 218 | """ 219 | if random.random() < self.p: 220 | return TF.hflip(img), TF.hflip(mask) 221 | return img, mask 222 | 223 | def __repr__(self): 224 | return self.__class__.__name__ + '(p={})'.format(self.p) --------------------------------------------------------------------------------
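
Usage sketch. The detection helpers in transforms/detection/functional.py operate on COCO-style annotation lists while the image itself is transformed separately with PIL/torchvision, and Cutout expects a tensor, so it belongs after ToTensor(). The snippet below is a minimal, illustrative sketch rather than part of the repository: the file name "sample.jpg", the sample annotation values, and the `DF` import alias are made up, and it assumes the repository root is on PYTHONPATH and that the `@curry`-decorated functions can be called with all arguments at once.

# Illustrative usage sketch only; sample data below is made up.
from PIL import Image
import torchvision.transforms as T
import torchvision.transforms.functional as TF

from transforms.detection import functional as DF  # hflip, resize, to_percent_coords
from transforms.ext import Cutout

img = Image.open("sample.jpg")                            # any RGB image
anns = [{"category_id": 1, "bbox": [40, 60, 120, 80]}]    # [l, t, w, h], absolute pixels

size = img.size                                           # (w, h), the convention used above

# 1. Horizontal flip: flip the image with torchvision, the boxes with hflip.
img = TF.hflip(img)
anns = DF.hflip(anns, size)

# 2. Resize both to 300x300; resize() takes the original and target (w, h).
img = TF.resize(img, (300, 300))
anns = DF.resize(anns, size, (300, 300))

# 3. Optionally convert boxes to percent coordinates for an SSD-style target.
anns = DF.to_percent_coords(anns, (300, 300))

# 4. Cutout works on tensors, so it goes after ToTensor().
tensor = T.Compose([T.ToTensor(), Cutout(n_holes=1, length=16)])(img)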