├── README.md └── deepcaps.py /README.md: -------------------------------------------------------------------------------- 1 | # DeepCaps PyTorch 2 | PyTorch Implementation of "DeepCaps: Going Deeper with Capsule Networks" by J. Rajasegaran et al. [CVPR 2019] 3 | 4 | 5 | [![Say Thanks!](https://img.shields.io/badge/Say%20Thanks-!-1EAEDB.svg)](https://saythanks.io/to/HopefulRational) 6 | -------------------------------------------------------------------------------- /deepcaps.py: -------------------------------------------------------------------------------- 1 | 2 | # coding: utf-8 3 | 4 | # In[1]: 5 | 6 | 7 | ''' 8 | Authors: HopefulRational and team 9 | ''' 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as func 14 | from torch.autograd import Variable 15 | # import torch.autograd as grad 16 | from torchvision import datasets, transforms 17 | import pandas as pd 18 | import numpy as np 19 | import math 20 | #import skimage.transform 21 | #import matplotlib.pyplot as plt 22 | # from tqdm import tqdm_notebook as tqdm 23 | #get_ipython().magic('matplotlib inline') 24 | 25 | eps = 1e-8 26 | cf = 1 27 | # ONLY cuda runnable 28 | device = torch.device("cuda:0") 29 | 30 | # norm_squared = torch.sum(s**2, dim=dim, keepdim=True) 31 | # return ((norm_squared /(1 + norm_squared + eps)) * (s / (torch.sqrt(norm_squared) + eps))) 32 | 33 | 34 | # In[2]: 35 | 36 | 37 | """ 38 | From github: https://gist.github.com/ncullen93/425ca642955f73452ebc097b3b46c493 39 | """ 40 | """ 41 | Affine transforms implemented on torch tensors, and 42 | only requiring one interpolation 43 | Included: 44 | - Affine() 45 | - AffineCompose() 46 | - Rotation() 47 | - Translation() 48 | - Shear() 49 | - Zoom() 50 | - Flip() 51 | """ 52 | 53 | import math 54 | import random 55 | import torch 56 | 57 | # necessary now, but should eventually not be 58 | import scipy.ndimage as ndi 59 | import numpy as np 60 | 61 | 62 | def transform_matrix_offset_center(matrix, x, y): 63 | """Apply offset to a transform matrix so that the image is 64 | transformed about the center of the image. 65 | NOTE: This is a fairly simple operaion, so can easily be 66 | moved to full torch. 67 | Arguments 68 | --------- 69 | matrix : 3x3 matrix/array 70 | x : integer 71 | height dimension of image to be transformed 72 | y : integer 73 | width dimension of image to be transformed 74 | """ 75 | o_x = float(x) / 2 + 0.5 76 | o_y = float(y) / 2 + 0.5 77 | offset_matrix = np.array([[1, 0, o_x], [0, 1, o_y], [0, 0, 1]]) 78 | reset_matrix = np.array([[1, 0, -o_x], [0, 1, -o_y], [0, 0, 1]]) 79 | transform_matrix = np.dot(np.dot(offset_matrix, matrix), reset_matrix) 80 | return transform_matrix 81 | 82 | def apply_transform(x, transform, fill_mode='nearest', fill_value=0.): 83 | """Applies an affine transform to a 2D array, or to each channel of a 3D array. 84 | NOTE: this can and certainly should be moved to full torch operations. 85 | Arguments 86 | --------- 87 | x : np.ndarray 88 | array to transform. 
NOTE: array should be ordered CHW 89 | 90 | transform : 3x3 affine transform matrix 91 | matrix to apply 92 | """ 93 | x = x.astype('float32') 94 | transform = transform_matrix_offset_center(transform, x.shape[1], x.shape[2]) 95 | final_affine_matrix = transform[:2, :2] 96 | final_offset = transform[:2, 2] 97 | channel_images = [ndi.interpolation.affine_transform(x_channel, final_affine_matrix, 98 | final_offset, order=0, mode=fill_mode, cval=fill_value) for x_channel in x] 99 | x = np.stack(channel_images, axis=0) 100 | return x 101 | 102 | class Affine(object): 103 | 104 | def __init__(self, 105 | rotation_range=None, 106 | translation_range=None, 107 | shear_range=None, 108 | zoom_range=None, 109 | fill_mode='constant', 110 | fill_value=0., 111 | target_fill_mode='nearest', 112 | target_fill_value=0.): 113 | """Perform an affine transforms with various sub-transforms, using 114 | only one interpolation and without having to instantiate each 115 | sub-transform individually. 116 | Arguments 117 | --------- 118 | rotation_range : one integer or float 119 | image will be rotated between (-degrees, degrees) degrees 120 | translation_range : a float or a tuple/list w/ 2 floats between [0, 1) 121 | first value: 122 | image will be horizontally shifted between 123 | (-height_range * height_dimension, height_range * height_dimension) 124 | second value: 125 | Image will be vertically shifted between 126 | (-width_range * width_dimension, width_range * width_dimension) 127 | shear_range : float 128 | radian bounds on the shear transform 129 | zoom_range : list/tuple with two floats between [0, infinity). 130 | first float should be less than the second 131 | lower and upper bounds on percent zoom. 132 | Anything less than 1.0 will zoom in on the image, 133 | anything greater than 1.0 will zoom out on the image. 134 | e.g. (0.7, 1.0) will only zoom in, 135 | (1.0, 1.4) will only zoom out, 136 | (0.7, 1.4) will randomly zoom in or out 137 | fill_mode : string in {'constant', 'nearest'} 138 | how to fill the empty space caused by the transform 139 | ProTip : use 'nearest' for discrete images (e.g. 
segmentations) 140 | and use 'constant' for continuous images 141 | fill_value : float 142 | the value to fill the empty space with if fill_mode='constant' 143 | target_fill_mode : same as fill_mode, but for target image 144 | target_fill_value : same as fill_value, but for target image 145 | """ 146 | self.transforms = [] 147 | if translation_range: 148 | translation_tform = Translation(translation_range, lazy=True) 149 | self.transforms.append(translation_tform) 150 | 151 | if rotation_range: 152 | rotation_tform = Rotation(rotation_range, lazy=True) 153 | self.transforms.append(rotation_tform) 154 | 155 | if shear_range: 156 | shear_tform = Shear(shear_range, lazy=True) 157 | self.transforms.append(shear_tform) 158 | 159 | if zoom_range: 160 | zoom_tform = Zoom(zoom_range, lazy=True) 161 | self.transforms.append(zoom_tform) 162 | 163 | self.fill_mode = fill_mode 164 | self.fill_value = fill_value 165 | self.target_fill_mode = target_fill_mode 166 | self.target_fill_value = target_fill_value 167 | 168 | def __call__(self, x, y=None): 169 | # collect all of the lazily returned tform matrices 170 | tform_matrix = self.transforms[0](x) 171 | for tform in self.transforms[1:]: 172 | tform_matrix = np.dot(tform_matrix, tform(x)) 173 | 174 | x = torch.from_numpy(apply_transform(x.numpy(), tform_matrix, 175 | fill_mode=self.fill_mode, fill_value=self.fill_value)) 176 | 177 | if y: 178 | y = torch.from_numpy(apply_transform(y.numpy(), tform_matrix, 179 | fill_mode=self.target_fill_mode, fill_value=self.target_fill_value)) 180 | return x, y 181 | else: 182 | return x 183 | 184 | class AffineCompose(object): 185 | 186 | def __init__(self, 187 | transforms, 188 | fill_mode='constant', 189 | fill_value=0., 190 | target_fill_mode='nearest', 191 | target_fill_value=0.): 192 | """Apply a collection of explicit affine transforms to an input image, 193 | and to a target image if necessary 194 | Arguments 195 | --------- 196 | transforms : list or tuple 197 | each element in the list/tuple should be an affine transform. 198 | currently supported transforms: 199 | - Rotation() 200 | - Translation() 201 | - Shear() 202 | - Zoom() 203 | fill_mode : string in {'constant', 'nearest'} 204 | how to fill the empty space caused by the transform 205 | fill_value : float 206 | the value to fill the empty space with if fill_mode='constant' 207 | """ 208 | self.transforms = transforms 209 | # set transforms to lazy so they only return the tform matrix 210 | for t in self.transforms: 211 | t.lazy = True 212 | self.fill_mode = fill_mode 213 | self.fill_value = fill_value 214 | self.target_fill_mode = target_fill_mode 215 | self.target_fill_value = target_fill_value 216 | 217 | def __call__(self, x, y=None): 218 | # collect all of the lazily returned tform matrices 219 | tform_matrix = self.transforms[0](x) 220 | for tform in self.transforms[1:]: 221 | tform_matrix = np.dot(tform_matrix, tform(x)) 222 | 223 | x = torch.from_numpy(apply_transform(x.numpy(), tform_matrix, 224 | fill_mode=self.fill_mode, fill_value=self.fill_value)) 225 | 226 | if y: 227 | y = torch.from_numpy(apply_transform(y.numpy(), tform_matrix, 228 | fill_mode=self.target_fill_mode, fill_value=self.target_fill_value)) 229 | return x, y 230 | else: 231 | return x 232 | 233 | 234 | class Rotation(object): 235 | 236 | def __init__(self, 237 | rotation_range, 238 | fill_mode='constant', 239 | fill_value=0., 240 | target_fill_mode='nearest', 241 | target_fill_value=0., 242 | lazy=False): 243 | """Randomly rotate an image between (-degrees, degrees). 
If the image 244 | has multiple channels, the same rotation will be applied to each channel. 245 | Arguments 246 | --------- 247 | rotation_range : integer or float 248 | image will be rotated between (-degrees, degrees) degrees 249 | fill_mode : string in {'constant', 'nearest'} 250 | how to fill the empty space caused by the transform 251 | fill_value : float 252 | the value to fill the empty space with if fill_mode='constant' 253 | lazy : boolean 254 | if true, only create the affine transform matrix and return that 255 | if false, perform the transform on the tensor and return the tensor 256 | """ 257 | self.rotation_range = rotation_range 258 | self.fill_mode = fill_mode 259 | self.fill_value = fill_value 260 | self.target_fill_mode = target_fill_mode 261 | self.target_fill_value = target_fill_value 262 | self.lazy = lazy 263 | 264 | def __call__(self, x, y=None): 265 | degree = random.uniform(-self.rotation_range, self.rotation_range) 266 | theta = math.pi / 180 * degree 267 | rotation_matrix = np.array([[math.cos(theta), -math.sin(theta), 0], 268 | [math.sin(theta), math.cos(theta), 0], 269 | [0, 0, 1]]) 270 | if self.lazy: 271 | return rotation_matrix 272 | else: 273 | x_transformed = torch.from_numpy(apply_transform(x.numpy(), rotation_matrix, 274 | fill_mode=self.fill_mode, fill_value=self.fill_value)) 275 | if y: 276 | y_transformed = torch.from_numpy(apply_transform(y.numpy(), rotation_matrix, 277 | fill_mode=self.target_fill_mode, fill_value=self.target_fill_value)) 278 | return x_transformed, y_transformed 279 | else: 280 | return x_transformed 281 | 282 | 283 | class Translation(object): 284 | 285 | def __init__(self, 286 | translation_range, 287 | fill_mode='constant', 288 | fill_value=0., 289 | target_fill_mode='nearest', 290 | target_fill_value=0., 291 | lazy=False): 292 | """Randomly translate an image some fraction of total height and/or 293 | some fraction of total width. If the image has multiple channels, 294 | the same translation will be applied to each channel. 
295 | Arguments 296 | --------- 297 | translation_range : two floats between [0, 1) 298 | first value: 299 | fractional bounds of total height to shift image 300 | image will be horizontally shifted between 301 | (-height_range * height_dimension, height_range * height_dimension) 302 | second value: 303 | fractional bounds of total width to shift image 304 | Image will be vertically shifted between 305 | (-width_range * width_dimension, width_range * width_dimension) 306 | fill_mode : string in {'constant', 'nearest'} 307 | how to fill the empty space caused by the transform 308 | fill_value : float 309 | the value to fill the empty space with if fill_mode='constant' 310 | lazy : boolean 311 | if true, perform the transform on the tensor and return the tensor 312 | if false, only create the affine transform matrix and return that 313 | """ 314 | if isinstance(translation_range, float): 315 | translation_range = (translation_range, translation_range) 316 | self.height_range = translation_range[0] 317 | self.width_range = translation_range[1] 318 | self.fill_mode = fill_mode 319 | self.fill_value = fill_value 320 | self.target_fill_mode = target_fill_mode 321 | self.target_fill_value = target_fill_value 322 | self.lazy = lazy 323 | 324 | def __call__(self, x, y=None): 325 | # height shift 326 | if self.height_range > 0: 327 | tx = random.uniform(-self.height_range, self.height_range) * x.size(1) 328 | else: 329 | tx = 0 330 | # width shift 331 | if self.width_range > 0: 332 | ty = random.uniform(-self.width_range, self.width_range) * x.size(2) 333 | else: 334 | ty = 0 335 | 336 | translation_matrix = np.array([[1, 0, tx], 337 | [0, 1, ty], 338 | [0, 0, 1]]) 339 | if self.lazy: 340 | return translation_matrix 341 | else: 342 | x_transformed = torch.from_numpy(apply_transform(x.numpy(), 343 | translation_matrix, fill_mode=self.fill_mode, fill_value=self.fill_value)) 344 | if y: 345 | y_transformed = torch.from_numpy(apply_transform(y.numpy(), translation_matrix, 346 | fill_mode=self.target_fill_mode, fill_value=self.target_fill_value)) 347 | return x_transformed, y_transformed 348 | else: 349 | return x_transformed 350 | 351 | 352 | class Shear(object): 353 | 354 | def __init__(self, 355 | shear_range, 356 | fill_mode='constant', 357 | fill_value=0., 358 | target_fill_mode='nearest', 359 | target_fill_value=0., 360 | lazy=False): 361 | """Randomly shear an image with radians (-shear_range, shear_range) 362 | Arguments 363 | --------- 364 | shear_range : float 365 | radian bounds on the shear transform 366 | 367 | fill_mode : string in {'constant', 'nearest'} 368 | how to fill the empty space caused by the transform 369 | fill_value : float 370 | the value to fill the empty space with if fill_mode='constant' 371 | lazy : boolean 372 | if true, perform the transform on the tensor and return the tensor 373 | if false, only create the affine transform matrix and return that 374 | """ 375 | self.shear_range = shear_range 376 | self.fill_mode = fill_mode 377 | self.fill_value = fill_value 378 | self.target_fill_mode = target_fill_mode 379 | self.target_fill_value = target_fill_value 380 | self.lazy = lazy 381 | 382 | def __call__(self, x, y=None): 383 | shear = random.uniform(-self.shear_range, self.shear_range) 384 | shear_matrix = np.array([[1, -math.sin(shear), 0], 385 | [0, math.cos(shear), 0], 386 | [0, 0, 1]]) 387 | if self.lazy: 388 | return shear_matrix 389 | else: 390 | x_transformed = torch.from_numpy(apply_transform(x.numpy(), 391 | shear_matrix, fill_mode=self.fill_mode, 
fill_value=self.fill_value)) 392 | if y: 393 | y_transformed = torch.from_numpy(apply_transform(y.numpy(), shear_matrix, 394 | fill_mode=self.target_fill_mode, fill_value=self.target_fill_value)) 395 | return x_transformed, y_transformed 396 | else: 397 | return x_transformed 398 | 399 | 400 | class Zoom(object): 401 | 402 | def __init__(self, 403 | zoom_range, 404 | fill_mode='constant', 405 | fill_value=0, 406 | target_fill_mode='nearest', 407 | target_fill_value=0., 408 | lazy=False): 409 | """Randomly zoom in and/or out on an image 410 | Arguments 411 | --------- 412 | zoom_range : tuple or list with 2 values, both between (0, infinity) 413 | lower and upper bounds on percent zoom. 414 | Anything less than 1.0 will zoom in on the image, 415 | anything greater than 1.0 will zoom out on the image. 416 | e.g. (0.7, 1.0) will only zoom in, 417 | (1.0, 1.4) will only zoom out, 418 | (0.7, 1.4) will randomly zoom in or out 419 | fill_mode : string in {'constant', 'nearest'} 420 | how to fill the empty space caused by the transform 421 | fill_value : float 422 | the value to fill the empty space with if fill_mode='constant' 423 | lazy : boolean 424 | if true, perform the transform on the tensor and return the tensor 425 | if false, only create the affine transform matrix and return that 426 | """ 427 | if not isinstance(zoom_range, list) and not isinstance(zoom_range, tuple): 428 | raise ValueError('zoom_range must be tuple or list with 2 values') 429 | self.zoom_range = zoom_range 430 | self.fill_mode = fill_mode 431 | self.fill_value = fill_value 432 | self.target_fill_mode = target_fill_mode 433 | self.target_fill_value = target_fill_value 434 | self.lazy = lazy 435 | 436 | def __call__(self, x, y=None): 437 | zx = random.uniform(self.zoom_range[0], self.zoom_range[1]) 438 | zy = random.uniform(self.zoom_range[0], self.zoom_range[1]) 439 | zoom_matrix = np.array([[zx, 0, 0], 440 | [0, zy, 0], 441 | [0, 0, 1]]) 442 | if self.lazy: 443 | return zoom_matrix 444 | else: 445 | x_transformed = torch.from_numpy(apply_transform(x.numpy(), 446 | zoom_matrix, fill_mode=self.fill_mode, fill_value=self.fill_value)) 447 | if y: 448 | y_transformed = torch.from_numpy(apply_transform(y.numpy(), zoom_matrix, 449 | fill_mode=self.target_fill_mode, fill_value=self.target_fill_value)) 450 | return x_transformed, y_transformed 451 | else: 452 | return x_transformed 453 | 454 | 455 | 456 | 457 | # In[3]: 458 | 459 | 460 | print("\nclass trans") 461 | class trans(object): 462 | def __init__(self, 463 | rotation_range=None, 464 | translation_range=None, 465 | shear_range=None, 466 | zoom_range=None, 467 | fill_mode='constant', 468 | fill_value=0., 469 | target_fill_mode='nearest', 470 | target_fill_value=0. 471 | ): 472 | self.affine = Affine(rotation_range, translation_range, shear_range, zoom_range) 473 | 474 | def __call__(self, data): 475 | data = transforms.ToTensor()(data) 476 | return self.affine(data) 477 | 478 | 479 | # In[4]: 480 | 481 | 482 | print("\nsquash -> Tensor") 483 | print("softmax_3d -> Tensor") 484 | print("one_hot -> numpy.array") 485 | 486 | def squash(s, dim=-1): 487 | norm = torch.norm(s, dim=dim, keepdim=True) 488 | return (norm /(1 + norm**2 + eps)) * s 489 | 490 | # not being used anymore. 
instead using nn.functional.softmax 491 | def softmax_3d(x, dim): 492 | return (torch.exp(x) / torch.sum(torch.sum(torch.sum(torch.exp(x), dim=dim[0], keepdim=True), dim=dim[1], keepdim=True), dim=dim[2], keepdim=True)) 493 | 494 | def one_hot(tensor, num_classes=10): 495 | return torch.eye(num_classes).cuda().index_select(dim=0, index=tensor.cuda()) # One-hot encode 496 | # return torch.eye(num_classes).index_select(dim=0, index=tensor).numpy() # One-hot encode 497 | 498 | 499 | # In[5]: 500 | 501 | 502 | print("class ConvertToCaps") 503 | 504 | class ConvertToCaps(nn.Module): 505 | def __init__(self): 506 | super().__init__() 507 | 508 | def forward(self, inputs): 509 | # channels first 510 | return torch.unsqueeze(inputs, 2) 511 | 512 | 513 | # In[6]: 514 | 515 | 516 | print("class FlattenCaps") 517 | 518 | class FlattenCaps(nn.Module): 519 | def __init__(self): 520 | super().__init__() 521 | 522 | def forward(self, inputs): 523 | # inputs.shape = (batch, channels, dimensions, height, width) 524 | batch, channels, dimensions, height, width = inputs.shape 525 | inputs = inputs.permute(0, 3, 4, 1, 2).contiguous() 526 | output_shape = (batch, channels * height * width, dimensions) 527 | return inputs.view(*output_shape) 528 | 529 | 530 | # In[7]: 531 | 532 | 533 | print("class CapsToScalars") 534 | 535 | class CapsToScalars(nn.Module): 536 | def __init__(self): 537 | super().__init__() 538 | 539 | def forward(self, inputs): 540 | # inputs.shape = (batch, num_capsules, dimensions) 541 | return torch.norm(inputs, dim=2) 542 | 543 | 544 | # In[8]: 545 | 546 | 547 | print("class Conv2DCaps") 548 | 549 | # padding should be 'SAME' 550 | # LATER correct: DONT PASS h, w FOR conv2d OPERATION 551 | 552 | class Conv2DCaps(nn.Module): 553 | def __init__(self, h, w, ch_i, n_i, ch_j, n_j, kernel_size=3, stride=1, r_num=1): 554 | super().__init__() 555 | self.ch_i = ch_i 556 | self.n_i = n_i 557 | self.ch_j = ch_j 558 | self.n_j = n_j 559 | self.kernel_size = kernel_size 560 | self.stride = stride 561 | self.r_num = r_num 562 | in_channels = self.ch_i * self.n_i 563 | out_channels = self.ch_j * self.n_j 564 | self.pad = 1 565 | 566 | # self.w = nn.Parameter(torch.randn(ch_j, n_j, ch_i, n_i, kernel_size, kernel_size) * 0.01).cuda() 567 | 568 | # self.w_reshaped = self.w.view(ch_j*n_j, ch_i*n_i, kernel_size, kernel_size) 569 | 570 | self.conv1 = nn.Conv2d(in_channels=in_channels, 571 | out_channels=out_channels, 572 | kernel_size=self.kernel_size, 573 | stride=self.stride, 574 | padding=self.pad).cuda() 575 | 576 | 577 | def forward(self, inputs): 578 | # check if happened properly 579 | # inputs.shape: (batch, channels, dimension, hight, width) 580 | 581 | self.batch, self.ch_i, self.n_i, self.h_i, self.w_i = inputs.shape 582 | in_size = self.h_i 583 | x = inputs.view(self.batch, self.ch_i * self.n_i, self.h_i, self.w_i) 584 | 585 | x = self.conv1(x) 586 | width = x.shape[2] 587 | x = x.view(inputs.shape[0], self.ch_j, self.n_j, width, width) 588 | return squash(x,dim=2)# squash(x).shape: (batch, channels, dimension, ht, wdth) 589 | 590 | 591 | # In[9]: 592 | 593 | 594 | print("class ConvCapsLayer3D") 595 | 596 | # SEE kernel_initializer, 597 | 598 | class ConvCapsLayer3D(nn.Module): 599 | def __init__(self, ch_i, n_i, ch_j=32, n_j=4, kernel_size=3, r_num=3): 600 | 601 | super().__init__() 602 | self.ch_i = ch_i 603 | self.n_i = n_i 604 | self.ch_j = ch_j 605 | self.n_j = n_j 606 | self.kernel_size = kernel_size 607 | self.r_num = r_num 608 | in_channels = 1 609 | out_channels = self.ch_j * self.n_j 610 
| stride = (n_i, 1, 1) 611 | pad = (0, 1, 1) 612 | 613 | # self.w = nn.Parameter(torch.randn(ch_j*n_j, 1, n_i, 3, 3)).cuda() 614 | 615 | 616 | self.conv1 = nn.Conv3d(in_channels=in_channels, 617 | out_channels=out_channels, 618 | kernel_size=self.kernel_size, 619 | stride=stride, 620 | padding=pad).cuda() 621 | 622 | 623 | def forward(self, inputs): 624 | # x.shape = (batch, channels, dimension, height, width) 625 | self.batch, self.ch_i, self.n_i, self.h_i, self.w_i = inputs.shape 626 | in_size = self.h_i 627 | out_size = self.h_i 628 | 629 | x = inputs.view(self.batch, self.ch_i * self.n_i, self.h_i, self.w_i) 630 | x = x.unsqueeze(1) 631 | x = self.conv1(x) 632 | self.width = x.shape[-1] 633 | 634 | x = x.permute(0,2,1,3,4) 635 | x = x.view(self.batch, self.ch_i, self.ch_j, self.n_j, self.width, self.width) 636 | x = x.permute(0, 4, 5, 3, 2, 1).contiguous() 637 | self.B = x.new(x.shape[0], self.width, self.width, 1, self.ch_j, self.ch_i).zero_() 638 | x = self.update_routing(x, self.r_num) 639 | return x 640 | 641 | def update_routing(self, x, itr=3): 642 | # x.shape = (batch, width, width, n_j, ch_j, ch_i) 643 | for i in range(itr): 644 | # softmax of self.B along (1,2,4) 645 | tmp = self.B.permute(0,5,3,1,2,4).contiguous().reshape(x.shape[0],self.ch_i,1,self.width*self.width*self.ch_j) 646 | #k = softmax_3d(self.B, (1,2,4)) # (batch, width, width, 1, ch_j, ch_i) 647 | #k = func.softmax(self.B, dim=4) 648 | k = func.softmax(tmp,dim=-1) 649 | k = k.reshape(x.shape[0],self.ch_i,1,self.width,self.width,self.ch_j).permute(0,3,4,2,5,1).contiguous() 650 | S_tmp = k * x 651 | S = torch.sum(S_tmp, dim=-1, keepdim=True) 652 | S_hat = squash(S) 653 | 654 | if i < (itr-1): 655 | agrements = (S_hat * x).sum(dim=3, keepdim=True) # sum over n_j dimension 656 | self.B = self.B + agrements 657 | 658 | S_hat = S_hat.squeeze(-1) 659 | #batch, h_j, w_j, n_j, ch_j = S_hat.shape 660 | return S_hat.permute(0, 4, 3, 1, 2).contiguous() 661 | 662 | 663 | # In[10]: 664 | 665 | 666 | print("class Mask_CID") 667 | 668 | class Mask_CID(nn.Module): 669 | def __init__(self): 670 | super().__init__() 671 | 672 | def forward(self, x, target=None): 673 | # x.shape = (batch, classes, dim) 674 | # one-hot required 675 | if target is None: 676 | classes = torch.norm(x, dim=2) 677 | max_len_indices = classes.max(dim=1)[1].squeeze() 678 | else: 679 | max_len_indices = target.max(dim=1)[1] 680 | 681 | # print("max_len_indices: ", max_len_indices) 682 | increasing = torch.arange(start=0, end=x.shape[0]).cuda() 683 | m = torch.stack([increasing, max_len_indices], dim=1) 684 | 685 | masked = torch.zeros((x.shape[0], 1) + x.shape[2:]) 686 | for i in increasing: 687 | masked[i] = x[m[i][0], m[i][1], :].unsqueeze(0) 688 | 689 | return masked.squeeze(-1), max_len_indices # dim: (batch, 1, capsule_dim) 690 | 691 | 692 | # In[11]: 693 | 694 | 695 | print("class CapsuleLayer") 696 | 697 | class CapsuleLayer(nn.Module): 698 | def __init__(self, num_capsules=10, num_routes=640, in_channels=8, out_channels=16, routing_iters=3): 699 | # in_channels: input_dim; out_channels: output_dim. 
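# Summary of the routing-by-agreement loop implemented in forward() below (comments only, no new logic):
#   u_hat_{j|i} = W_ij @ x_i                        prediction of output capsule j from input capsule i
#   c_ij  = softmax(b_ij over output capsules)      coupling coefficients (func.softmax(b_ij, dim=2))
#   s_j   = sum_i c_ij * u_hat_{j|i} + bias         weighted vote for capsule j
#   v_j   = squash(s_j)                             output capsule; its norm acts as the class score
#   b_ij += <u_hat_{j|i}, v_j>                      agreement update, repeated for routing_iters iterations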
700 | super().__init__() 701 | 702 | self.num_capsules = num_capsules 703 | self.num_routes = num_routes 704 | self.routing_iters = routing_iters 705 | 706 | self.W = nn.Parameter(torch.randn(1, num_routes, num_capsules, out_channels, in_channels) * 0.01) 707 | self.bias = nn.Parameter(torch.rand(1, 1, num_capsules, out_channels) * 0.01) 708 | 709 | def forward(self, x): 710 | # x: [batch_size, 32, 16] -> [batch_size, 32, 1, 16] 711 | # -> [batch_size, 32, 1, 16, 1] 712 | # print("CapsuleLayer_x.shape: ", x.shape) 713 | x = x.unsqueeze(2).unsqueeze(dim=4) 714 | 715 | u_hat = torch.matmul(self.W, x).squeeze() # u_hat -> [batch_size, 32, 10, 32] 716 | 717 | # b_ij = torch.zeros((batch_size, self.num_routes, self.num_capsules, 1)) 718 | b_ij = x.new(x.shape[0], self.num_routes, self.num_capsules, 1).zero_() 719 | 720 | for itr in range(self.routing_iters): 721 | c_ij = func.softmax(b_ij, dim=2) 722 | s_j = (c_ij * u_hat).sum(dim=1, keepdim=True) + self.bias 723 | v_j = squash(s_j, dim=-1) 724 | 725 | if itr < self.routing_iters-1: 726 | a_ij = (u_hat * v_j).sum(dim=-1, keepdim=True) 727 | b_ij = b_ij + a_ij 728 | v_j = v_j.squeeze() #.unsqueeze(-1) 729 | 730 | return v_j # dim: (batch, num_capsules, out_channels or dim_capsules) 731 | 732 | 733 | # In[12]: 734 | 735 | 736 | print("class Decoder_mnist") 737 | 738 | class Decoder_mnist(nn.Module): 739 | def __init__(self, caps_size=16, num_caps=1, img_size=28, img_channels=1): 740 | super().__init__() 741 | 742 | self.num_caps = num_caps 743 | self.img_channels = img_channels 744 | self.img_size = img_size 745 | 746 | self.dense = torch.nn.Linear(caps_size*num_caps, 7*7*16).cuda(device) 747 | self.relu = nn.ReLU(inplace=True) 748 | 749 | 750 | self.reconst_layers1 = nn.Sequential(nn.BatchNorm2d(num_features=16, momentum=0.8), 751 | 752 | nn.ConvTranspose2d(in_channels=16, out_channels=64, 753 | kernel_size=3, stride=1, padding=1 754 | ) 755 | ) 756 | 757 | self.reconst_layers2 = nn.ConvTranspose2d(in_channels=64, out_channels=32, 758 | kernel_size=3, stride=2, padding=1 759 | ) 760 | 761 | self.reconst_layers3 = nn.ConvTranspose2d(in_channels=32, out_channels=16, 762 | kernel_size=3, stride=2, padding=1 763 | ) 764 | 765 | self.reconst_layers4 = nn.ConvTranspose2d(in_channels=16, out_channels=1, 766 | kernel_size=3, stride=1, padding=1 767 | ) 768 | 769 | self.reconst_layers5 = nn.ReLU() 770 | 771 | 772 | 773 | def forward(self, x): 774 | # x.shape = (batch, 1, capsule_dim(=32 for MNIST)) 775 | batch = x.shape[0] 776 | 777 | x = x.type(torch.FloatTensor) 778 | 779 | x = x.cuda() 780 | 781 | x = self.dense(x) 782 | x = self.relu(x) 783 | x = x.reshape(-1, 16, 7, 7) 784 | 785 | x = self.reconst_layers1(x) 786 | 787 | x = self.reconst_layers2(x) 788 | 789 | # padding 790 | p2d = (1, 0, 1, 0) 791 | x = func.pad(x, p2d, "constant", 0) 792 | x = self.reconst_layers3(x) 793 | 794 | # padding 795 | p2d = (1, 0, 1, 0) 796 | x = func.pad(x, p2d, "constant", 0) 797 | x = self.reconst_layers4(x) 798 | 799 | x = self.reconst_layers5(x) 800 | x = x.reshape(-1, 1, self.img_size, self.img_size) 801 | return x # dim: (batch, 1, imsize, imsize) 802 | 803 | 804 | class Decoder_mnist32x32(nn.Module): 805 | def __init__(self, caps_size=16, num_caps=1, img_size=28, img_channels=1): 806 | super().__init__() 807 | 808 | self.num_caps = num_caps 809 | self.img_channels = img_channels 810 | self.img_size = img_size 811 | 812 | self.dense = torch.nn.Linear(caps_size*num_caps, 8*8*16).cuda(device) 813 | self.relu = nn.ReLU(inplace=True) 814 | 815 | 816 | 
self.reconst_layers1 = nn.Sequential(nn.BatchNorm2d(num_features=16, momentum=0.8), 817 | 818 | nn.ConvTranspose2d(in_channels=16, out_channels=64, 819 | kernel_size=3, stride=1, padding=1 820 | ) 821 | ) 822 | 823 | self.reconst_layers2 = nn.ConvTranspose2d(in_channels=64, out_channels=32, 824 | kernel_size=3, stride=2, padding=1 825 | ) 826 | 827 | self.reconst_layers3 = nn.ConvTranspose2d(in_channels=32, out_channels=16, 828 | kernel_size=3, stride=2, padding=1 829 | ) 830 | 831 | self.reconst_layers4 = nn.ConvTranspose2d(in_channels=16, out_channels=3, 832 | kernel_size=3, stride=1, padding=1 833 | ) 834 | 835 | # self.reconst_layers4 = nn.ConvTranspose2d(in_channels=8, out_channels=3, 836 | # kernel_size=3, stride=1, padding=1 837 | # ) 838 | 839 | self.reconst_layers5 = nn.ReLU() 840 | 841 | 842 | 843 | def forward(self, x): 844 | # x.shape = (batch, 1, capsule_dim(=32 for MNIST)) 845 | batch = x.shape[0] 846 | 847 | x = x.type(torch.FloatTensor) 848 | 849 | x = x.cuda() 850 | 851 | x = self.dense(x) 852 | x = self.relu(x) 853 | x = x.reshape(-1, 16, 8, 8) 854 | 855 | x = self.reconst_layers1(x) 856 | 857 | x = self.reconst_layers2(x) 858 | 859 | # padding 860 | p2d = (1, 0, 1, 0) 861 | x = func.pad(x, p2d, "constant", 0) 862 | x = self.reconst_layers3(x) 863 | 864 | # padding 865 | p2d = (1, 0, 1, 0) 866 | x = func.pad(x, p2d, "constant", 0) 867 | x = self.reconst_layers4(x) 868 | 869 | # x = self.reconst_layers5(x) 870 | x = x.reshape(-1, self.img_channels, self.img_size, self.img_size) 871 | return x # dim: (batch, 1, imsize, imsize) 872 | 873 | # In[13]: 874 | 875 | 876 | print("class Model") 877 | 878 | class Model(nn.Module): 879 | def __init__(self): 880 | super().__init__() 881 | self.conv2d = nn.Conv2d(in_channels=1, out_channels=128, 882 | kernel_size=3, stride=1, padding=1) 883 | self.batchNorm = torch.nn.BatchNorm2d(num_features=128, eps=1e-08, momentum=0.99) 884 | self.toCaps = ConvertToCaps() 885 | 886 | self.conv2dCaps1_nj_4_strd_2 = Conv2DCaps(h=28, w=28, ch_i=128, n_i=1, ch_j=32, n_j=4, kernel_size=3, stride=2, r_num=1) 887 | self.conv2dCaps1_nj_4_strd_1_1 = Conv2DCaps(h=14, w=14, ch_i=32, n_i=4, ch_j=32, n_j=4, kernel_size=3, stride=1, r_num=1) 888 | self.conv2dCaps1_nj_4_strd_1_2 = Conv2DCaps(h=14, w=14, ch_i=32, n_i=4, ch_j=32, n_j=4, kernel_size=3, stride=1, r_num=1) 889 | self.conv2dCaps1_nj_4_strd_1_3 = Conv2DCaps(h=14, w=14, ch_i=32, n_i=4, ch_j=32, n_j=4, kernel_size=3, stride=1, r_num=1) 890 | 891 | self.conv2dCaps2_nj_8_strd_2 = Conv2DCaps(h=14, w=14, ch_i=32, n_i=4, ch_j=32, n_j=8, kernel_size=3, stride=2, r_num=1) 892 | self.conv2dCaps2_nj_8_strd_1_1 = Conv2DCaps(h=7, w=7, ch_i=32, n_i=8, ch_j=32, n_j=8, kernel_size=3, stride=1, r_num=1) 893 | self.conv2dCaps2_nj_8_strd_1_2 = Conv2DCaps(h=7, w=7, ch_i=32, n_i=8, ch_j=32, n_j=8, kernel_size=3, stride=1, r_num=1) 894 | self.conv2dCaps2_nj_8_strd_1_3 = Conv2DCaps(h=7, w=7, ch_i=32, n_i=8, ch_j=32, n_j=8, kernel_size=3, stride=1, r_num=1) 895 | 896 | self.conv2dCaps3_nj_8_strd_2 = Conv2DCaps(h=7, w=7, ch_i=32, n_i=8, ch_j=32, n_j=8, kernel_size=3, stride=2, r_num=1) 897 | self.conv2dCaps3_nj_8_strd_1_1 = Conv2DCaps(h=4, w=4, ch_i=32, n_i=8, ch_j=32, n_j=8, kernel_size=3, stride=1, r_num=1) 898 | self.conv2dCaps3_nj_8_strd_1_2 = Conv2DCaps(h=4, w=4, ch_i=32, n_i=8, ch_j=32, n_j=8, kernel_size=3, stride=1, r_num=1) 899 | self.conv2dCaps3_nj_8_strd_1_3 = Conv2DCaps(h=4, w=4, ch_i=32, n_i=8, ch_j=32, n_j=8, kernel_size=3, stride=1, r_num=1) 900 | 901 | self.conv2dCaps4_nj_8_strd_2 = Conv2DCaps(h=4, w=4, ch_i=32, 
n_i=8, ch_j=32, n_j=8, kernel_size=3, stride=2, r_num=1) 902 | self.conv3dCaps4_nj_8 = ConvCapsLayer3D(ch_i=32, n_i=8, ch_j=32, n_j=8, kernel_size=3, r_num=3) 903 | self.conv2dCaps4_nj_8_strd_1_1 = Conv2DCaps(h=2, w=2, ch_i=32, n_i=8, ch_j=32, n_j=8, kernel_size=3, stride=1, r_num=1) 904 | self.conv2dCaps4_nj_8_strd_1_2 = Conv2DCaps(h=2, w=2, ch_i=32, n_i=8, ch_j=32, n_j=8, kernel_size=3, stride=1, r_num=1) 905 | 906 | self.decoder = Decoder_mnist(caps_size=16, num_caps=1, img_size=28, img_channels=1) 907 | self.flatCaps = FlattenCaps() 908 | self.digCaps = CapsuleLayer(num_capsules=10, num_routes=640, in_channels=8, out_channels=16, routing_iters=3) 909 | self.capsToScalars = CapsToScalars() 910 | self.mask = Mask_CID() 911 | self.mse_loss = nn.MSELoss(reduction="none") 912 | 913 | def forward(self, x, target=None): 914 | x = self.conv2d(x) 915 | x = self.batchNorm(x) 916 | x = self.toCaps(x) 917 | 918 | x = self.conv2dCaps1_nj_4_strd_2(x) 919 | x_skip = self.conv2dCaps1_nj_4_strd_1_1(x) 920 | x = self.conv2dCaps1_nj_4_strd_1_2(x) 921 | x = self.conv2dCaps1_nj_4_strd_1_3(x) 922 | x = x + x_skip 923 | 924 | x = self.conv2dCaps2_nj_8_strd_2(x) 925 | x_skip = self.conv2dCaps2_nj_8_strd_1_1(x) 926 | x = self.conv2dCaps2_nj_8_strd_1_2(x) 927 | x = self.conv2dCaps2_nj_8_strd_1_3(x) 928 | x = x + x_skip 929 | 930 | x = self.conv2dCaps3_nj_8_strd_2(x) 931 | x_skip = self.conv2dCaps3_nj_8_strd_1_1(x) 932 | x = self.conv2dCaps3_nj_8_strd_1_2(x) 933 | x = self.conv2dCaps3_nj_8_strd_1_3(x) 934 | x = x + x_skip 935 | x1 = x 936 | 937 | x = self.conv2dCaps4_nj_8_strd_2(x) 938 | x_skip = self.conv3dCaps4_nj_8(x) 939 | x = self.conv2dCaps4_nj_8_strd_1_1(x) 940 | x = self.conv2dCaps4_nj_8_strd_1_2(x) 941 | x = x + x_skip 942 | x2 = x 943 | 944 | xa = self.flatCaps(x1) 945 | xb = self.flatCaps(x2) 946 | x = torch.cat((xa, xb), dim=-2) 947 | dig_caps = self.digCaps(x) 948 | 949 | x = self.capsToScalars(dig_caps) 950 | masked, indices = self.mask(dig_caps, target) 951 | decoded = self.decoder(masked) 952 | 953 | return dig_caps, masked, decoded, indices 954 | 955 | def margin_loss(self, x, labels, lamda, m_plus, m_minus): 956 | v_c = torch.norm(x, dim=2, keepdim=True) 957 | tmp1 = func.relu(m_plus - v_c).view(x.shape[0], -1) ** 2 958 | tmp2 = func.relu(v_c - m_minus).view(x.shape[0], -1) ** 2 959 | loss = labels*tmp1 + lamda*(1-labels)*tmp2 960 | loss = loss.sum(dim=1) 961 | return loss 962 | 963 | def reconst_loss(self, recnstrcted, data): 964 | loss = self.mse_loss(recnstrcted.view(recnstrcted.shape[0], -1), data.view(recnstrcted.shape[0], -1)) 965 | return 0.4 * loss.sum(dim=1) 966 | 967 | def loss(self, x, recnstrcted, data, labels, lamda=0.5, m_plus=0.9, m_minus=0.1): 968 | loss = self.margin_loss(x, labels, lamda, m_plus, m_minus) + self.reconst_loss(recnstrcted, data) 969 | return loss.mean() 970 | 971 | #################################################################################################################################################### 972 | #################################################################################################################################################### 973 | class Model32x32(nn.Module): 974 | def __init__(self): 975 | super().__init__() 976 | self.conv2d = nn.Conv2d(in_channels=3, out_channels=128, 977 | kernel_size=3, stride=1, padding=1) 978 | self.batchNorm = torch.nn.BatchNorm2d(num_features=128, eps=1e-08, momentum=0.99) 979 | self.toCaps = ConvertToCaps() 980 | 981 | self.conv2dCaps1_nj_4_strd_2 = Conv2DCaps(h=32, w=32, ch_i=128, n_i=1, ch_j=32, 
n_j=4, kernel_size=3, stride=2, r_num=1) 982 | self.conv2dCaps1_nj_4_strd_1_1 = Conv2DCaps(h=16, w=16, ch_i=32, n_i=4, ch_j=32, n_j=4, kernel_size=3, stride=1, r_num=1) 983 | self.conv2dCaps1_nj_4_strd_1_2 = Conv2DCaps(h=16, w=16, ch_i=32, n_i=4, ch_j=32, n_j=4, kernel_size=3, stride=1, r_num=1) 984 | self.conv2dCaps1_nj_4_strd_1_3 = Conv2DCaps(h=16, w=16, ch_i=32, n_i=4, ch_j=32, n_j=4, kernel_size=3, stride=1, r_num=1) 985 | 986 | self.conv2dCaps2_nj_8_strd_2 = Conv2DCaps(h=16, w=16, ch_i=32, n_i=4, ch_j=32, n_j=8, kernel_size=3, stride=2, r_num=1) 987 | self.conv2dCaps2_nj_8_strd_1_1 = Conv2DCaps(h=8, w=8, ch_i=32, n_i=8, ch_j=32, n_j=8, kernel_size=3, stride=1, r_num=1) 988 | self.conv2dCaps2_nj_8_strd_1_2 = Conv2DCaps(h=8, w=8, ch_i=32, n_i=8, ch_j=32, n_j=8, kernel_size=3, stride=1, r_num=1) 989 | self.conv2dCaps2_nj_8_strd_1_3 = Conv2DCaps(h=8, w=8, ch_i=32, n_i=8, ch_j=32, n_j=8, kernel_size=3, stride=1, r_num=1) 990 | 991 | self.conv2dCaps3_nj_8_strd_2 = Conv2DCaps(h=8, w=8, ch_i=32, n_i=8, ch_j=32, n_j=8, kernel_size=3, stride=2, r_num=1) 992 | self.conv2dCaps3_nj_8_strd_1_1 = Conv2DCaps(h=4, w=4, ch_i=32, n_i=8, ch_j=32, n_j=8, kernel_size=3, stride=1, r_num=1) 993 | self.conv2dCaps3_nj_8_strd_1_2 = Conv2DCaps(h=4, w=4, ch_i=32, n_i=8, ch_j=32, n_j=8, kernel_size=3, stride=1, r_num=1) 994 | self.conv2dCaps3_nj_8_strd_1_3 = Conv2DCaps(h=4, w=4, ch_i=32, n_i=8, ch_j=32, n_j=8, kernel_size=3, stride=1, r_num=1) 995 | 996 | self.conv2dCaps4_nj_8_strd_2 = Conv2DCaps(h=4, w=4, ch_i=32, n_i=8, ch_j=32, n_j=8, kernel_size=3, stride=2, r_num=1) 997 | self.conv3dCaps4_nj_8 = ConvCapsLayer3D(ch_i=32, n_i=8, ch_j=32, n_j=8, kernel_size=3, r_num=3) 998 | self.conv2dCaps4_nj_8_strd_1_1 = Conv2DCaps(h=2, w=2, ch_i=32, n_i=8, ch_j=32, n_j=8, kernel_size=3, stride=1, r_num=1) 999 | self.conv2dCaps4_nj_8_strd_1_2 = Conv2DCaps(h=2, w=2, ch_i=32, n_i=8, ch_j=32, n_j=8, kernel_size=3, stride=1, r_num=1) 1000 | 1001 | self.decoder = Decoder_mnist32x32(caps_size=32, num_caps=1, img_size=32, img_channels=3) 1002 | self.flatCaps = FlattenCaps() 1003 | self.digCaps = CapsuleLayer(num_capsules=10, num_routes=64*10, in_channels=8, out_channels=32, routing_iters=3) 1004 | self.capsToScalars = CapsToScalars() 1005 | self.mask = Mask_CID() 1006 | self.mse_loss = nn.MSELoss(reduction="none") 1007 | 1008 | def forward(self, x, target=None): 1009 | x = self.conv2d(x) 1010 | x = self.batchNorm(x) 1011 | x = self.toCaps(x) 1012 | 1013 | x = self.conv2dCaps1_nj_4_strd_2(x) 1014 | x_skip = self.conv2dCaps1_nj_4_strd_1_1(x) 1015 | x = self.conv2dCaps1_nj_4_strd_1_2(x) 1016 | x = self.conv2dCaps1_nj_4_strd_1_3(x) 1017 | x = x + x_skip 1018 | 1019 | x = self.conv2dCaps2_nj_8_strd_2(x) 1020 | x_skip = self.conv2dCaps2_nj_8_strd_1_1(x) 1021 | x = self.conv2dCaps2_nj_8_strd_1_2(x) 1022 | x = self.conv2dCaps2_nj_8_strd_1_3(x) 1023 | x = x + x_skip 1024 | 1025 | x = self.conv2dCaps3_nj_8_strd_2(x) 1026 | x_skip = self.conv2dCaps3_nj_8_strd_1_1(x) 1027 | x = self.conv2dCaps3_nj_8_strd_1_2(x) 1028 | x = self.conv2dCaps3_nj_8_strd_1_3(x) 1029 | x = x + x_skip 1030 | x1 = x 1031 | 1032 | x = self.conv2dCaps4_nj_8_strd_2(x) 1033 | x_skip = self.conv3dCaps4_nj_8(x) 1034 | x = self.conv2dCaps4_nj_8_strd_1_1(x) 1035 | x = self.conv2dCaps4_nj_8_strd_1_2(x) 1036 | x = x + x_skip 1037 | x2 = x 1038 | 1039 | # x1.shape : torch.Size([64, 32, 8, 4, 4]) | x2.shape : torch.Size([64, 32, 8, 2, 2]) (for CIFAR10) 1040 | xa = self.flatCaps(x1) 1041 | xb = self.flatCaps(x2) 1042 | x = torch.cat((xa, xb), dim=-2) 1043 | dig_caps = 
self.digCaps(x) 1044 | 1045 | x = self.capsToScalars(dig_caps) 1046 | masked, indices = self.mask(dig_caps, target) 1047 | decoded = self.decoder(masked) 1048 | 1049 | return dig_caps, masked, decoded, indices 1050 | 1051 | def margin_loss(self, x, labels, lamda, m_plus, m_minus): 1052 | v_c = torch.norm(x, dim=2, keepdim=True) 1053 | tmp1 = func.relu(m_plus - v_c).view(x.shape[0], -1) ** 2 1054 | tmp2 = func.relu(v_c - m_minus).view(x.shape[0], -1) ** 2 1055 | loss = labels*tmp1 + lamda*(1-labels)*tmp2 1056 | loss = loss.sum(dim=1) 1057 | return loss 1058 | 1059 | def reconst_loss(self, recnstrcted, data): 1060 | loss = self.mse_loss(recnstrcted.view(recnstrcted.shape[0], -1), data.view(recnstrcted.shape[0], -1)) 1061 | return 0.4 * loss.sum(dim=1) 1062 | 1063 | def loss(self, x, recnstrcted, data, labels, lamda=0.5, m_plus=0.9, m_minus=0.1): 1064 | loss = self.margin_loss(x, labels, lamda, m_plus, m_minus) + self.reconst_loss(recnstrcted, data) 1065 | return loss.mean() 1066 | 1067 | 1068 | #################################################################################################################################################### 1069 | #################################################################################################################################################### 1070 | 1071 | 1072 | # In[14]: 1073 | 1074 | 1075 | # loss 1076 | mse_loss = nn.MSELoss(reduction='none') 1077 | 1078 | def margin_loss(x, labels, lamda=0.5, m_plus=0.9, m_minus=0.1): 1079 | v_c = torch.norm(x, dim=2, keepdim=True) 1080 | tmp1 = func.relu(m_plus - v_c).view(x.shape[0], -1) ** 2 1081 | tmp2 = func.relu(v_c - m_minus).view(x.shape[0], -1) ** 2 1082 | loss_ = labels*tmp1 + lamda*(1-labels)*tmp2 1083 | loss_ = loss_.sum(dim=1) 1084 | return loss_ 1085 | 1086 | def reconst_loss(recnstrcted, data): 1087 | loss = mse_loss(recnstrcted.view(recnstrcted.shape[0], -1), data.view(recnstrcted.shape[0], -1)) 1088 | return 0.4 * loss.sum(dim=1) 1089 | 1090 | def loss(x, recnstrcted, data, labels, lamda=0.5, m_plus=0.9, m_minus=0.1): 1091 | loss_ = margin_loss(x, labels, lamda, m_plus, m_minus) + reconst_loss(recnstrcted, data) 1092 | return loss_.mean() 1093 | 1094 | 1095 | # In[15]: 1096 | 1097 | 1098 | model = Model().cuda() 1099 | # Uncomment below line for CIFAR10 1100 | # model = Model32x32().cuda() 1101 | 1102 | 1103 | # lr = 0.001 1104 | # optimizer = torch.optim.Adam(model.parameters(), lr=0.001) 1105 | # # torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=-1) 1106 | # lambda1 = lambda: epoch: lr * 0.5**(epoch // 10) 1107 | # lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1) 1108 | batch_size = 64 1109 | num_epochs = 100 1110 | lamda = 0.5 1111 | m_plus = 0.9 1112 | m_minus = 0.1 1113 | 1114 | 1115 | # In[16]: 1116 | 1117 | 1118 | train_loader = torch.utils.data.DataLoader(datasets.FashionMNIST(root='/home/mtech3/CODES/ankit/data/FashionMNIST/FashionMNIST/',train=True,download=True,transform=trans(rotation_range=0.1, translation_range=0.1, zoom_range=(0.1, 0.2))),batch_size=batch_size,shuffle=True) 1119 | test_loader = torch.utils.data.DataLoader(datasets.FashionMNIST(root='/home/mtech3/CODES/ankit/data/FashionMNIST/FashionMNIST/',train=False,download=True,transform=transforms.ToTensor()),batch_size=batch_size,shuffle=True) 1120 | 1121 | ###################### 1122 | # Uncomment these for CIFAR10 1123 | # train_loader = torch.utils.data.DataLoader(datasets.CIFAR10(root='./CIFAR10',train=True,download=True,transform=trans(rotation_range=0.1, 
translation_range=0.1, zoom_range=(0.1, 0.2))),batch_size=batch_size,shuffle=True) 1124 | # test_loader = torch.utils.data.DataLoader(datasets.CIFAR10(root='./CIFAR10',train=False,download=True,transform=transforms.ToTensor()),batch_size=batch_size,shuffle=True) 1125 | ###################### 1126 | 1127 | # In[17]: 1128 | 1129 | 1130 | def accuracy(indices, labels): 1131 | correct = 0.0 1132 | for i in range(indices.shape[0]): 1133 | if float(indices[i]) == labels[i]: 1134 | correct += 1 1135 | return correct 1136 | 1137 | 1138 | # In[18]: 1139 | 1140 | 1141 | print("def test") 1142 | 1143 | def test(model, test_loader, loss, batch_size, lamda=0.5, m_plus=0.9, m_minus=0.1): 1144 | test_loss = 0.0 1145 | correct = 0.0 1146 | for batch_idx, (data, label) in enumerate(test_loader): 1147 | data, labels = data.cuda(), one_hot(label.cuda()) 1148 | outputs, masked_output, recnstrcted, indices = model(data) 1149 | # if batch_idx == 9: 1150 | # print("test: ", indices) 1151 | loss_test = model.loss(outputs, recnstrcted, data, labels, lamda, m_plus, m_minus) 1152 | test_loss += loss_test.data 1153 | indices_cpu, labels_cpu = indices.cpu(), label.cpu() 1154 | # for i in range(indices_cpu.shape[0]): 1155 | # if float(indices_cpu[i]) == labels_cpu[i]: 1156 | # correct += 1 1157 | correct += accuracy(indices_cpu, labels_cpu) 1158 | # if batch_idx % 100 == 0: 1159 | # print("batch: ", batch_idx, "accuracy: ", correct/len(test_loader.dataset)) 1160 | # print(indices_cpu) 1161 | print("\nTest Loss: ", test_loss/len(test_loader.dataset), "; Test Accuracy: " , correct/len(test_loader.dataset) * 100,'\n') 1162 | 1163 | 1164 | # In[ ]: 1165 | 1166 | 1167 | def train(train_loader, model, num_epochs, lr=0.001, batch_size=64, lamda=0.5, m_plus=0.9, m_minus=0.1): 1168 | optimizer = torch.optim.Adam(model.parameters(), lr) 1169 | lambda1 = lambda epoch: 0.5**(epoch // 10) 1170 | lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1) 1171 | #lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.96) 1172 | for epoch in range(num_epochs): 1173 | for batch_idx, (data, label_) in enumerate(train_loader): 1174 | data, label = data.cuda(), label_.cuda() 1175 | labels = one_hot(label) 1176 | optimizer.zero_grad() 1177 | outputs, masked, recnstrcted, indices = model(data, labels) 1178 | loss_val = model.loss(outputs, recnstrcted, data, labels, lamda, m_plus, m_minus) 1179 | loss_val.backward() 1180 | optimizer.step() 1181 | if batch_idx%100 == 0: 1182 | outputs, masked, recnstrcted, indices = model(data) 1183 | loss_val = model.loss(outputs, recnstrcted, data, labels, lamda, m_plus, m_minus) 1184 | print("epoch: ", epoch, "batch_idx: ", batch_idx, "loss: ", loss_val, "accuracy: ", accuracy(indices, label_.cpu())/indices.shape[0]) 1185 | test(model, test_loader, loss, batch_size, lamda, m_plus, m_minus) 1186 | lr_scheduler.step() 1187 | 1188 | 1189 | # In[ ]: 1190 | 1191 | 1192 | # soft-training 1193 | train(train_loader, model, num_epochs=100, lr=0.001, batch_size=256, lamda=0.5, m_plus=0.9, m_minus=0.1) 1194 | 1195 | # Hard-Training 1196 | print("\n\n\n\nHard-Training\n") 1197 | train(train_loader, model, num_epochs=100, lr=0.001, batch_size=256, lamda=0.8, m_plus=0.95, m_minus=0.05) 1198 | 1199 | --------------------------------------------------------------------------------
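For quick verification, a minimal smoke-test sketch (not part of the original script): it assumes the definitions above in deepcaps.py (Model, one_hot, the torch import) are in scope and that a CUDA device is present, since the layers call .cuda() internally. The expected shapes simply restate the shape comments in the code.

# Hypothetical smoke test -- illustrative only, not from the repository.
if torch.cuda.is_available():
    net = Model().cuda()
    data = torch.rand(4, 1, 28, 28, device="cuda")   # FashionMNIST-sized batch; keep batch > 1, forward() uses .squeeze()
    labels = one_hot(torch.tensor([0, 1, 2, 3]))     # one_hot() moves the index tensor to CUDA itself
    dig_caps, masked, decoded, indices = net(data, labels)
    print(dig_caps.shape)   # expected: (4, 10, 16)   -- 10 class capsules, 16-D each
    print(masked.shape)     # expected: (4, 1, 16)    -- capsule of the target class
    print(decoded.shape)    # expected: (4, 1, 28, 28) reconstruction
    print(indices.shape)    # expected: (4,)          -- predicted class per sample
    print(net.loss(dig_caps, decoded, data, labels).item())   # margin loss + 0.4 * reconstruction MSE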