├── LICENSE ├── .gitignore ├── README.md ├── experiments ├── generate_data.py └── train_and_test.py ├── coordconv.py └── CoordConv.ipynb /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Walsvid 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Settings 2 | .idea 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # pyenv 79 | .python-version 80 | 81 | # celery beat schedule file 82 | celerybeat-schedule 83 | 84 | # SageMath parsed files 85 | *.sage.py 86 | 87 | # Environments 88 | .env 89 | .venv 90 | env/ 91 | venv/ 92 | ENV/ 93 | env.bak/ 94 | venv.bak/ 95 | 96 | # Spyder project settings 97 | .spyderproject 98 | .spyproject 99 | 100 | # Rope project settings 101 | .ropeproject 102 | 103 | # mkdocs documentation 104 | /site 105 | 106 | # mypy 107 | .mypy_cache/ 108 | -------------------------------------------------------------------------------- 
/README.md: -------------------------------------------------------------------------------- 1 | # CoordConv 2 | 3 | ![](https://img.shields.io/badge/pytorch-0.4.0-blue.svg) ![](https://img.shields.io/badge/python-3.6.5-brightgreen.svg) 4 | 5 | PyTorch implementation of CoordConv for 1D, 2D and 3D convolution layers, together with the experiments from the paper. 6 | 7 | Reference paper: [An intriguing failing of convolutional neural networks and the CoordConv solution](https://arxiv.org/abs/1807.03247) 8 | 9 | Extends the coordinate-channel concatenation (`AddCoords`) from 2D tensors to 1D and 3D tensors. 10 | 11 | # Requirements 12 | 13 | - PyTorch 0.4.0 14 | - torchvision 0.2.1 15 | - torchsummary 1.3 16 | - scikit-learn 0.19.1 (the experiments also use numpy and matplotlib) 17 | 18 | # Usage 19 | 20 | ```python 21 | from coordconv import CoordConv1d, CoordConv2d, CoordConv3d 22 | 23 | class Net(nn.Module): 24 | def __init__(self): 25 | super(Net, self).__init__() 26 | self.coordconv = CoordConv2d(2, 32, 1, with_r=True) 27 | self.conv1 = nn.Conv2d(32, 64, 1) 28 | self.conv2 = nn.Conv2d(64, 64, 1) 29 | self.conv3 = nn.Conv2d(64, 1, 1) 30 | self.conv4 = nn.Conv2d( 1, 1, 1) 31 | 32 | def forward(self, x): 33 | x = self.coordconv(x) 34 | x = F.relu(self.conv1(x)) 35 | x = F.relu(self.conv2(x)) 36 | x = F.relu(self.conv3(x)) 37 | x = self.conv4(x) 38 | x = x.view(-1, 64*64) 39 | return x 40 | 41 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 42 | net = Net().to(device) 43 | ``` 44 | 45 | # Experiments 46 | 47 | Implements the experiments from the original paper. 48 | 49 | ## Coordinate Classification 50 | 51 | Use `experiments/generate_data.py` to generate the `Uniform` and `Quadrant` datasets for the Coordinate Classification task. 52 | 53 | Use `experiments/train_and_test.py` to train and test the neural network model.
54 | 55 | ### Uniform Datasets 56 | 57 | |Train|Test|Predictions| 58 | |:---:|:---:|:---:| 59 | |![](https://i.loli.net/2018/07/16/5b4c7db11abf9.png)|![](https://i.loli.net/2018/07/16/5b4c7dbd03169.png)|![](https://i.loli.net/2018/07/16/5b4c8d88a70a2.png)| 60 | 61 | 62 | ### Quadrant Datasets 63 | 64 | |Train|Test|Predictions| 65 | |:---:|:---:|:---:| 66 | |![](https://i.loli.net/2018/07/16/5b4c98bba0fec.png)|![](https://i.loli.net/2018/07/16/5b4c98cbf0293.png)|![](https://i.loli.net/2018/07/16/5b4c98d77096f.png)| -------------------------------------------------------------------------------- /experiments/generate_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | from sklearn.model_selection import train_test_split 6 | 7 | datatype = 'uniform' 8 | assert datatype in ['uniform', 'quadrant'] 9 | 10 | if not os.path.exists('data-uniform/'): 11 | os.makedirs('data-uniform/') 12 | 13 | if not os.path.exists('data-quadrant/'): 14 | os.makedirs('data-quadrant/') 15 | 16 | np.random.seed(0) 17 | torch.manual_seed(0) 18 | 19 | onehots = np.pad(np.eye(3136, dtype='float32').reshape((3136, 56, 56, 1)), 20 | ((0, 0), (4, 4), (4, 4), (0, 0)), mode="constant") 21 | onehots = onehots.transpose(0, 3, 1, 2) 22 | 23 | onehots_tensor = torch.from_numpy(onehots) 24 | 25 | conv_layer = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=(9, 9), padding=4, stride=1, bias=False) 26 | w = torch.ones(1, 1, 9, 9) 27 | conv_layer.weight.data = w 28 | 29 | images_tensor = conv_layer(onehots_tensor) 30 | images = images_tensor.detach().numpy() 31 | 32 | if datatype == 'uniform': 33 | # Create the uniform datasets 34 | indices = np.arange(0, len(onehots), dtype='int32') 35 | train, test = train_test_split(indices, test_size=0.2, random_state=0) 36 | 37 | train_onehot = onehots[train] 38 | train_images = images[train] 39 | 40 | test_onehot = onehots[test] 41 | test_images 
= images[test] 42 | 43 | np.save('data-uniform/train_onehot.npy', train_onehot) 44 | np.save('data-uniform/train_images.npy', train_images) 45 | np.save('data-uniform/test_onehot.npy', test_onehot) 46 | np.save('data-uniform/test_images.npy', test_images) 47 | else: 48 | pos_quadrant = np.where(onehots == 1.0) 49 | # print(onehots.shape) 50 | X = pos_quadrant[2] 51 | Y = pos_quadrant[3] 52 | 53 | train_set = [] 54 | test_set = [] 55 | 56 | train_ids = [] 57 | test_ids = [] 58 | 59 | for i, (x, y) in enumerate(zip(X, Y)): 60 | if x > 32 and y > 32: # 4th quadrant 61 | test_ids.append(i) 62 | test_set.append([x, y]) 63 | else: 64 | train_ids.append(i) 65 | train_set.append([x, y]) 66 | 67 | train_set = np.array(train_set) 68 | test_set = np.array(test_set) 69 | 70 | train_set = train_set[:, :, None, None] 71 | test_set = test_set[:, :, None, None] 72 | 73 | print(train_set.shape) 74 | print(test_set.shape) 75 | 76 | train_onehot = onehots[train_ids] 77 | test_onehot = onehots[test_ids] 78 | 79 | train_images = images[train_ids] 80 | test_images = images[test_ids] 81 | 82 | print(train_onehot.shape, test_onehot.shape) 83 | print(train_images.shape, test_images.shape) 84 | 85 | np.save('data-quadrant/train_set.npy', train_set) 86 | np.save('data-quadrant/test_set.npy', test_set) 87 | np.save('data-quadrant/train_onehot.npy', train_onehot) 88 | np.save('data-quadrant/train_images.npy', train_images) 89 | np.save('data-quadrant/test_onehot.npy', test_onehot) 90 | np.save('data-quadrant/test_images.npy', test_images) 91 | -------------------------------------------------------------------------------- /experiments/train_and_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import torch 4 | import torch.nn as nn 5 | import torch.utils.data as utils 6 | import torch.nn.functional as F 7 | 8 | from torch.autograd import Variable 9 | # from torchsummary import summary 10 | 
11 | datatype = 'uniform' 12 | assert datatype in ['uniform', 'quadrant'] 13 | 14 | if datatype == 'uniform': 15 | # Load the one hot datasets 16 | train_onehot = np.load('data-uniform/train_onehot.npy').astype('float32') 17 | test_onehot = np.load('data-uniform/test_onehot.npy').astype('float32') 18 | 19 | # (N, C, H, W) <=== 数据格式 20 | # make the train and test datasets 21 | # train 22 | pos_train = np.where(train_onehot == 1.0) 23 | X_train = pos_train[2] 24 | Y_train = pos_train[3] 25 | train_set = np.zeros((len(X_train), 2, 1, 1), dtype='float32') 26 | for i, (x, y) in enumerate(zip(X_train, Y_train)): 27 | train_set[i, 0, 0, 0] = x 28 | train_set[i, 1, 0, 0] = y 29 | 30 | # test 31 | pos_test = np.where(test_onehot == 1.0) 32 | X_test = pos_test[2] 33 | Y_test = pos_test[3] 34 | test_set = np.zeros((len(X_test), 2, 1, 1), dtype='float32') 35 | for i, (x, y) in enumerate(zip(X_test, Y_test)): 36 | test_set[i, 0, 0, 0] = x 37 | test_set[i, 1, 0, 0] = y 38 | 39 | train_set = np.tile(train_set, [1, 1, 64, 64]) 40 | test_set = np.tile(test_set, [1, 1, 64, 64]) 41 | 42 | # Normalize the datasets 43 | train_set /= (64. - 1.) # 64x64 grid, 0-based index 44 | test_set /= (64. - 1.) 
# 64x64 grid, 0-based index 45 | 46 | print('Train set : ', train_set.shape, train_set.max(), train_set.min()) 47 | print('Test set : ', test_set.shape, test_set.max(), test_set.min()) 48 | 49 | # Visualize the datasets 50 | 51 | plt.imshow(np.sum(train_onehot, axis=0)[0, :, :], cmap='gray') 52 | plt.title('Train One-hot dataset') 53 | plt.show() 54 | plt.imshow(np.sum(test_onehot, axis=0)[0, :, :], cmap='gray') 55 | plt.title('Test One-hot dataset') 56 | plt.show() 57 | 58 | else: 59 | # Load the one hot datasets and the train / test set 60 | train_set = np.load('data-quadrant/train_set.npy').astype('float32') 61 | test_set = np.load('data-quadrant/test_set.npy').astype('float32') 62 | 63 | train_onehot = np.load('data-quadrant/train_onehot.npy').astype('float32') 64 | test_onehot = np.load('data-quadrant/test_onehot.npy').astype('float32') 65 | 66 | train_set = np.tile(train_set, [1, 1, 64, 64]) 67 | test_set = np.tile(test_set, [1, 1, 64, 64]) 68 | 69 | # Normalize datasets 70 | train_set /= train_set.max() 71 | test_set /= test_set.max() 72 | 73 | print('Train set : ', train_set.shape, train_set.max(), train_set.min()) 74 | print('Test set : ', test_set.shape, test_set.max(), test_set.min()) 75 | 76 | # Visualize the datasets 77 | 78 | plt.imshow(np.sum(train_onehot, axis=0)[0, :, :], cmap='gray') 79 | plt.title('Train One-hot dataset') 80 | plt.show() 81 | plt.imshow(np.sum(test_onehot, axis=0)[0, :, :], cmap='gray') 82 | plt.title('Test One-hot dataset') 83 | plt.show() 84 | 85 | # flatten the datasets 86 | train_onehot = train_onehot.reshape((-1, 64 * 64)).astype('int64') 87 | test_onehot = test_onehot.reshape((-1, 64 * 64)).astype('int64') 88 | 89 | # model definition 90 | 91 | class Net(nn.Module): 92 | def __init__(self): 93 | super(Net, self).__init__() 94 | self.coordconv = CoordConv2d(2, 32, 1, with_r=True) 95 | self.conv1 = nn.Conv2d(32, 64, 1) 96 | self.conv2 = nn.Conv2d(64, 64, 1) 97 | self.conv3 = nn.Conv2d(64, 1, 1) 98 | self.conv4 = nn.Conv2d( 1, 
1, 1) 99 | 100 | def forward(self, x): 101 | x = self.coordconv(x) 102 | x = F.relu(self.conv1(x)) 103 | x = F.relu(self.conv2(x)) 104 | x = F.relu(self.conv3(x)) 105 | x = self.conv4(x) 106 | x = x.view(-1, 64*64) 107 | return x 108 | 109 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 110 | net = Net().to(device) 111 | 112 | # summary(net, input_size=(2, 64, 64)) 113 | # ---------------------------------------------------------------- 114 | # Layer (type) Output Shape Param # 115 | # ================================================================ 116 | # AddCoords-1 [-1, 5, 64, 64] 0 117 | # Conv2d-2 [-1, 32, 64, 64] 192 118 | # CoordConv2d-3 [-1, 32, 64, 64] 96 119 | # Conv2d-4 [-1, 64, 64, 64] 2,112 120 | # Conv2d-5 [-1, 64, 64, 64] 4,160 121 | # Conv2d-6 [-1, 1, 64, 64] 65 122 | # Conv2d-7 [-1, 1, 64, 64] 2 123 | # ================================================================ 124 | # Total params: 6,627 125 | # Trainable params: 6,627 126 | # Non-trainable params: 0 127 | # ---------------------------------------------------------------- 128 | 129 | train_tensor_x = torch.stack([torch.Tensor(i) for i in train_set]) 130 | train_tensor_y = torch.stack([torch.LongTensor(i) for i in train_onehot]) 131 | 132 | train_dataset = utils.TensorDataset(train_tensor_x,train_tensor_y) 133 | train_dataloader = utils.DataLoader(train_dataset, batch_size=32, shuffle=False) 134 | 135 | test_tensor_x = torch.stack([torch.Tensor(i) for i in test_set]) 136 | test_tensor_y = torch.stack([torch.LongTensor(i) for i in test_onehot]) 137 | 138 | test_dataset = utils.TensorDataset(test_tensor_x,test_tensor_y) 139 | test_dataloader = utils.DataLoader(test_dataset, batch_size=32, shuffle=False) 140 | 141 | optimizer = torch.optim.Adam(net.parameters(), lr=1e-3) 142 | 143 | def cross_entropy_one_hot(input, target): 144 | _, labels = target.max(dim=1) 145 | return nn.CrossEntropyLoss()(input, labels) 146 | 147 | criterion = cross_entropy_one_hot 148 | 149 | 
epochs = 10 150 | 151 | def train(epoch, net, train_dataloader, optimizer, criterion, device): 152 | net.train() 153 | iters = 0 154 | for batch_idx, (data, target) in enumerate(train_dataloader): 155 | data, target = Variable(data), Variable(target) 156 | data, target = data.to(device), target.to(device) 157 | optimizer.zero_grad() 158 | output = net(data) 159 | loss = criterion(output, target) 160 | loss.backward() 161 | optimizer.step() 162 | iters += len(data) 163 | print('Train Epoch: {} [{}/{} ({:.0f}%)] Loss: {:.6f}'.format( 164 | epoch, iters, len(train_dataloader.dataset), 165 | 100. * (batch_idx + 1) / len(train_dataloader), loss.data.item()), end='\r', flush=True) 166 | print("") 167 | 168 | 169 | for epoch in range(1, epochs + 1): 170 | train(epoch, net, train_dataloader, optimizer, criterion, device) 171 | 172 | 173 | def test(net, test_loader, optimizer, criterion, device): 174 | net.eval() 175 | test_loss = 0 176 | correct = 0 177 | for data, target in test_loader: 178 | with torch.no_grad(): 179 | data, target = Variable(data), Variable(target) 180 | data, target = data.to(device), target.to(device) 181 | output = net(data) 182 | test_loss += criterion(output, target).item() 183 | _, pred = output.max(1, keepdim=True) 184 | _, label = target.max(dim=1) 185 | correct += pred.eq(label.view_as(pred)).sum().item() 186 | 187 | test_loss = test_loss 188 | test_loss /= len(test_loader) # loss function already averages over batch size 189 | print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( 190 | test_loss, correct, len(test_loader.dataset), 191 | 100. 
* correct / len(test_loader.dataset))) 192 | 193 | 194 | test(net, test_dataloader, optimizer, criterion, device) 195 | -------------------------------------------------------------------------------- /coordconv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.modules.conv as conv 4 | 5 | 6 | class AddCoords(nn.Module): 7 | def __init__(self, rank, with_r=False, use_cuda=True): 8 | super(AddCoords, self).__init__() 9 | self.rank = rank 10 | self.with_r = with_r 11 | self.use_cuda = use_cuda 12 | 13 | def forward(self, input_tensor): 14 | """ 15 | :param input_tensor: shape (N, C_in, H, W) 16 | :return: 17 | """ 18 | if self.rank == 1: 19 | batch_size_shape, channel_in_shape, dim_x = input_tensor.shape 20 | xx_range = torch.arange(dim_x, dtype=torch.int32) 21 | xx_channel = xx_range[None, None, :] 22 | 23 | xx_channel = xx_channel.float() / (dim_x - 1) 24 | xx_channel = xx_channel * 2 - 1 25 | xx_channel = xx_channel.repeat(batch_size_shape, 1, 1) 26 | 27 | if torch.cuda.is_available and self.use_cuda: 28 | input_tensor = input_tensor.cuda() 29 | xx_channel = xx_channel.cuda() 30 | out = torch.cat([input_tensor, xx_channel], dim=1) 31 | 32 | if self.with_r: 33 | rr = torch.sqrt(torch.pow(xx_channel - 0.5, 2)) 34 | out = torch.cat([out, rr], dim=1) 35 | 36 | elif self.rank == 2: 37 | batch_size_shape, channel_in_shape, dim_y, dim_x = input_tensor.shape 38 | xx_ones = torch.ones([1, 1, 1, dim_x], dtype=torch.int32) 39 | yy_ones = torch.ones([1, 1, 1, dim_y], dtype=torch.int32) 40 | 41 | xx_range = torch.arange(dim_y, dtype=torch.int32) 42 | yy_range = torch.arange(dim_x, dtype=torch.int32) 43 | xx_range = xx_range[None, None, :, None] 44 | yy_range = yy_range[None, None, :, None] 45 | 46 | xx_channel = torch.matmul(xx_range, xx_ones) 47 | yy_channel = torch.matmul(yy_range, yy_ones) 48 | 49 | # transpose y 50 | yy_channel = yy_channel.permute(0, 1, 3, 2) 51 | 52 | 
xx_channel = xx_channel.float() / (dim_y - 1) 53 | yy_channel = yy_channel.float() / (dim_x - 1) 54 | 55 | xx_channel = xx_channel * 2 - 1 56 | yy_channel = yy_channel * 2 - 1 57 | 58 | xx_channel = xx_channel.repeat(batch_size_shape, 1, 1, 1) 59 | yy_channel = yy_channel.repeat(batch_size_shape, 1, 1, 1) 60 | 61 | if torch.cuda.is_available and self.use_cuda: 62 | input_tensor = input_tensor.cuda() 63 | xx_channel = xx_channel.cuda() 64 | yy_channel = yy_channel.cuda() 65 | out = torch.cat([input_tensor, xx_channel, yy_channel], dim=1) 66 | 67 | if self.with_r: 68 | rr = torch.sqrt(torch.pow(xx_channel - 0.5, 2) + torch.pow(yy_channel - 0.5, 2)) 69 | out = torch.cat([out, rr], dim=1) 70 | 71 | elif self.rank == 3: 72 | batch_size_shape, channel_in_shape, dim_z, dim_y, dim_x = input_tensor.shape 73 | xx_ones = torch.ones([1, 1, 1, 1, dim_x], dtype=torch.int32) 74 | yy_ones = torch.ones([1, 1, 1, 1, dim_y], dtype=torch.int32) 75 | zz_ones = torch.ones([1, 1, 1, 1, dim_z], dtype=torch.int32) 76 | 77 | xy_range = torch.arange(dim_y, dtype=torch.int32) 78 | xy_range = xy_range[None, None, None, :, None] 79 | 80 | yz_range = torch.arange(dim_z, dtype=torch.int32) 81 | yz_range = yz_range[None, None, None, :, None] 82 | 83 | zx_range = torch.arange(dim_x, dtype=torch.int32) 84 | zx_range = zx_range[None, None, None, :, None] 85 | 86 | xy_channel = torch.matmul(xy_range, xx_ones) 87 | xx_channel = torch.cat([xy_channel + i for i in range(dim_z)], dim=2) 88 | xx_channel = xx_channel.repeat(batch_size_shape, 1, 1, 1, 1) 89 | 90 | yz_channel = torch.matmul(yz_range, yy_ones) 91 | yz_channel = yz_channel.permute(0, 1, 3, 4, 2) 92 | yy_channel = torch.cat([yz_channel + i for i in range(dim_x)], dim=4) 93 | yy_channel = yy_channel.repeat(batch_size_shape, 1, 1, 1, 1) 94 | 95 | zx_channel = torch.matmul(zx_range, zz_ones) 96 | zx_channel = zx_channel.permute(0, 1, 4, 2, 3) 97 | zz_channel = torch.cat([zx_channel + i for i in range(dim_y)], dim=3) 98 | zz_channel = 
zz_channel.repeat(batch_size_shape, 1, 1, 1, 1) 99 | 100 | if torch.cuda.is_available and self.use_cuda: 101 | input_tensor = input_tensor.cuda() 102 | xx_channel = xx_channel.cuda() 103 | yy_channel = yy_channel.cuda() 104 | zz_channel = zz_channel.cuda() 105 | out = torch.cat([input_tensor, xx_channel, yy_channel, zz_channel], dim=1) 106 | 107 | if self.with_r: 108 | rr = torch.sqrt(torch.pow(xx_channel - 0.5, 2) + 109 | torch.pow(yy_channel - 0.5, 2) + 110 | torch.pow(zz_channel - 0.5, 2)) 111 | out = torch.cat([out, rr], dim=1) 112 | else: 113 | raise NotImplementedError 114 | 115 | return out 116 | 117 | 118 | class CoordConv1d(conv.Conv1d): 119 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 120 | padding=0, dilation=1, groups=1, bias=True, with_r=False, use_cuda=True): 121 | super(CoordConv1d, self).__init__(in_channels, out_channels, kernel_size, 122 | stride, padding, dilation, groups, bias) 123 | self.rank = 1 124 | self.addcoords = AddCoords(self.rank, with_r, use_cuda=use_cuda) 125 | self.conv = nn.Conv1d(in_channels + self.rank + int(with_r), out_channels, 126 | kernel_size, stride, padding, dilation, groups, bias) 127 | 128 | def forward(self, input_tensor): 129 | """ 130 | input_tensor_shape: (N, C_in,H,W) 131 | output_tensor_shape: N,C_out,H_out,W_out) 132 | :return: CoordConv2d Result 133 | """ 134 | out = self.addcoords(input_tensor) 135 | out = self.conv(out) 136 | 137 | return out 138 | 139 | 140 | class CoordConv2d(conv.Conv2d): 141 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 142 | padding=0, dilation=1, groups=1, bias=True, with_r=False, use_cuda=True): 143 | super(CoordConv2d, self).__init__(in_channels, out_channels, kernel_size, 144 | stride, padding, dilation, groups, bias) 145 | self.rank = 2 146 | self.addcoords = AddCoords(self.rank, with_r, use_cuda=use_cuda) 147 | self.conv = nn.Conv2d(in_channels + self.rank + int(with_r), out_channels, 148 | kernel_size, stride, padding, dilation, 
groups, bias) 149 | 150 | def forward(self, input_tensor): 151 | """ 152 | input_tensor_shape: (N, C_in,H,W) 153 | output_tensor_shape: N,C_out,H_out,W_out) 154 | :return: CoordConv2d Result 155 | """ 156 | out = self.addcoords(input_tensor) 157 | out = self.conv(out) 158 | 159 | return out 160 | 161 | 162 | class CoordConv3d(conv.Conv3d): 163 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 164 | padding=0, dilation=1, groups=1, bias=True, with_r=False, use_cuda=True): 165 | super(CoordConv3d, self).__init__(in_channels, out_channels, kernel_size, 166 | stride, padding, dilation, groups, bias) 167 | self.rank = 3 168 | self.addcoords = AddCoords(self.rank, with_r, use_cuda=use_cuda) 169 | self.conv = nn.Conv3d(in_channels + self.rank + int(with_r), out_channels, 170 | kernel_size, stride, padding, dilation, groups, bias) 171 | 172 | def forward(self, input_tensor): 173 | """ 174 | input_tensor_shape: (N, C_in,H,W) 175 | output_tensor_shape: N,C_out,H_out,W_out) 176 | :return: CoordConv2d Result 177 | """ 178 | out = self.addcoords(input_tensor) 179 | out = self.conv(out) 180 | 181 | return out 182 | -------------------------------------------------------------------------------- /CoordConv.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "import numpy as np\n", 11 | "import matplotlib.pyplot as plt\n", 12 | "import torch\n", 13 | "import torch.nn as nn\n", 14 | "import torch.nn.modules.conv as conv\n", 15 | "import torch.utils.data as utils\n", 16 | "import torch.nn.functional as F\n", 17 | "\n", 18 | "from torch.autograd import Variable\n", 19 | "from sklearn.model_selection import train_test_split\n", 20 | "from torchsummary import summary" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": null, 26 | "metadata": {}, 27 | 
"outputs": [], 28 | "source": [ 29 | "from coordconv import *" 30 | ] 31 | }, 32 | { 33 | "cell_type": "markdown", 34 | "metadata": {}, 35 | "source": [ 36 | "## Generate Data" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 3, 42 | "metadata": {}, 43 | "outputs": [ 44 | { 45 | "data": { 46 | "text/plain": [ 47 | "" 48 | ] 49 | }, 50 | "execution_count": 3, 51 | "metadata": {}, 52 | "output_type": "execute_result" 53 | } 54 | ], 55 | "source": [ 56 | "datatype = 'uniform'\n", 57 | "assert datatype in ['uniform', 'quadrant']\n", 58 | "\n", 59 | "if not os.path.exists('data-uniform/'):\n", 60 | " os.makedirs('data-uniform/')\n", 61 | "\n", 62 | "if not os.path.exists('data-quadrant/'):\n", 63 | " os.makedirs('data-quadrant/')\n", 64 | "\n", 65 | "np.random.seed(0)\n", 66 | "torch.manual_seed(0)" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": 4, 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [ 75 | "onehots = np.pad(np.eye(3136, dtype='float32').reshape((3136, 56, 56, 1)),\n", 76 | " ((0, 0), (4, 4), (4, 4), (0, 0)), mode=\"constant\")\n", 77 | "onehots = onehots.transpose(0, 3, 1, 2)\n", 78 | "\n", 79 | "onehots_tensor = torch.from_numpy(onehots)\n", 80 | "\n", 81 | "conv_layer = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=(9,9), padding=4, stride=1)\n", 82 | "w = torch.ones(1, 1, 9, 9)\n", 83 | "conv_layer.weight.data = w\n", 84 | "\n", 85 | "images_tensor = conv_layer(onehots_tensor)\n", 86 | "images = images_tensor.detach().numpy()\n", 87 | "\n", 88 | "if datatype == 'uniform':\n", 89 | " # Create the uniform datasets\n", 90 | " indices = np.arange(0, len(onehots), dtype='int32')\n", 91 | " train, test = train_test_split(indices, test_size=0.2, random_state=0)\n", 92 | "\n", 93 | " train_onehot = onehots[train]\n", 94 | " train_images = images[train]\n", 95 | "\n", 96 | " test_onehot = onehots[test]\n", 97 | " test_images = images[test]\n", 98 | "\n", 99 | " np.save('data-uniform/train_onehot.npy', 
train_onehot)\n", 100 | " np.save('data-uniform/train_images.npy', train_images)\n", 101 | " np.save('data-uniform/test_onehot.npy', test_onehot)\n", 102 | " np.save('data-uniform/test_images.npy', test_images)\n", 103 | "else:\n", 104 | " pos_quadrant = np.where(onehots == 1.0)\n", 105 | "# print(onehots.shape)\n", 106 | " X = pos_quadrant[2]\n", 107 | " Y = pos_quadrant[3]\n", 108 | " \n", 109 | " train_set = []\n", 110 | " test_set = []\n", 111 | "\n", 112 | " train_ids = []\n", 113 | " test_ids = []\n", 114 | "\n", 115 | " for i, (x, y) in enumerate(zip(X, Y)):\n", 116 | " if x > 32 and y > 32: # 4th quadrant\n", 117 | " test_ids.append(i)\n", 118 | " test_set.append([x, y])\n", 119 | " else:\n", 120 | " train_ids.append(i)\n", 121 | " train_set.append([x, y])\n", 122 | "\n", 123 | " train_set = np.array(train_set)\n", 124 | " test_set = np.array(test_set)\n", 125 | " \n", 126 | " train_set = train_set[:, :, None, None]\n", 127 | " test_set = test_set[:, :, None, None]\n", 128 | "\n", 129 | " print(train_set.shape)\n", 130 | " print(test_set.shape)\n", 131 | "\n", 132 | " train_onehot = onehots[train_ids]\n", 133 | " test_onehot = onehots[test_ids]\n", 134 | "\n", 135 | " train_images = images[train_ids]\n", 136 | " test_images = images[test_ids]\n", 137 | "\n", 138 | " print(train_onehot.shape, test_onehot.shape)\n", 139 | " print(train_images.shape, test_images.shape)\n", 140 | "\n", 141 | " np.save('data-quadrant/train_set.npy', train_set)\n", 142 | " np.save('data-quadrant/test_set.npy', test_set)\n", 143 | " np.save('data-quadrant/train_onehot.npy', train_onehot)\n", 144 | " np.save('data-quadrant/train_images.npy', train_images)\n", 145 | " np.save('data-quadrant/test_onehot.npy', test_onehot)\n", 146 | " np.save('data-quadrant/test_images.npy', test_images)" 147 | ] 148 | }, 149 | { 150 | "cell_type": "markdown", 151 | "metadata": {}, 152 | "source": [ 153 | "## Load data numpy" 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "execution_count": 
5, 159 | "metadata": {}, 160 | "outputs": [ 161 | { 162 | "name": "stdout", 163 | "output_type": "stream", 164 | "text": [ 165 | "Train set : (2508, 2, 64, 64) 0.93650794 0.06349207\n", 166 | "Test set : (628, 2, 64, 64) 0.93650794 0.06349207\n" 167 | ] 168 | }, 169 | { 170 | "data": { 171 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAEICAYAAACZA4KlAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAYG0lEQVR4nO3de9AkVXnH8e+PXS7KJbuAkIUFFypUhD/iajYIYiniJYQYwZRELMssCtlUSg2WpriYisEUscSkEKMpdQuQjRIBuQiSBEUElRgXllsEVlzEFTYsrGRZAW+w8OSP7tcMr+/M22/P6e7pPb9P1dQ7l57Tz/S8z/Q5fU6fVkRgZtu+7boOwMza4WQ3y4ST3SwTTnazTDjZzTLhZDfLhJO9BZLmSXpS0v5dx1KFpJMl3dji+j4v6cy21pcrJ/sMysScuj0r6ecDj9821/Ii4pmI2CUiHqgZz06Szpb0QBnL9yW9X5LqlNckSWdJurDB8m+SdGJT5be9njbN7zqASRQRu0zdl7QeODkivjZseUnzI2JrE7GUCX05sAdwNPB94FDgc8C+wPuaWK9tgyLCtxE3YD3w2mnPnQVcAnwBeAI4ETgc+A6wBdgI/BOwfbn8fCCAJeXjz5ev/0f5/v8CDhiy/t8Hfg7sM+35lwPPTL0PuAn4EPDtssxrgd0Hlj9iIL47gFeO+MwnA98APlYufz/w+oHXFwPXAJuBdcA7y+ffADwFPA08Cdw6pPzfLWN4otyGXwTOLF/bA/h34MfAY8CXgX3L184uP/MvyvLPLZ//JLABeBy4BXj5wLoOA24rX3sE+IfZtsmw9fT91nkAk34bkexPAX9E0RR6HvB7wMvKxD6QYg/87nL5mZL9UWAZsD3FD8fnh6z/H4Hrh7z2P8BJ5f2bysQ7CHg+8C3grPK1/YD/LX84tqOoITwK7DGk3JPLhH0nMA94D/DgwOv/CXwC2Al4aVnWqwa2zYUjtueOZWL+ZfnZTyjXdWb5+guAN5XbdDfgCuCygfffBJw4rcy3A7uX2/m0crvsWL52C/DW8v6uwMuqbJOZ1tP3m9vs9d0UEV+OiGcj4ucRcUtErI6IrRFxP7ASeNWI918WEWsi4mngImDpkOX2pKgpzGRj+fqU8yNiXUT8jGJvOVXmnwJXR8RXynivBe6k+Acf5gcRcUFEPAOsAhZL2lPSARTNiNMj4hcRcRvwWYqEq+IIih++T0TE0xFxMXD71IsR8eOIuLLcpo8DH2b0diQiPhcRm6NoSn2U4kfit8qXnwYOkrRHRDwREavH2Ca95mSv78HBB5JeJOnfJD0s6XHg73huIk738MD9nwG7DFnuUWDRkNcWla/PVuYLgbdK2jJ1o6je7iPpyIGDj3eOKIuyvH2ARyPipwOv/4ji+EEV+wAbotx9DrwfAEk7SzqvPBj5OPB1Rm9HJJ0q6XuSfkJR9d954D3vAA4B7pV0s6RjyueHbpOKn6N3fICuvumnC36Gov33loh4UtJfUbRhx/U14F2S9omIh6aelPRy4DeBGyqU8SDw2Yj4iyGvD/uhmclDwJ6Sdh5I+P0pqs7w69tluo0Ubf5B+wN3l/dPBQ4ADo2IhyUto6iKT3lO+ZJeTXGQ8jXAPeXTPwEEEBH3AidI2g44Hrhc0kJm3ybb3Omg3rOnsyvFP9lPJ
R0M/Hmicr8CfBO4QtIhkuZLOpziaPwnyybDbD4HvEnS68o+/50kvVrSnPdiEfFDYA3wYUk7SlpKsfe8qFzkEWDJiG7Bm4DtJL27/CzHU7T7p+xKUZN4TNIewAenvf8RimMig8tvpajhbA+cSbFnB0DS2yXtGRHPUnw/ATzL7Ntk+np6z8mezvuB5RRHmD9DcdBtbGV19ziKA25fLcv/F+DTwHsrlrGe4qDX31Ac5X6gjLfu9/8WigOBDwOXAR+IiKkaxiXADsBmSTfPEMsvy1j+jKLK/cfAlwYWOQf4DYqDZ9+m6LEYdC7/X/0+h+LI/dcoDk6upzjqPniM4xhgraQnKA52viUinqqwTaavp/f03KaTmW2rvGc3y4ST3SwTTnazTIyV7JKOlnSvpPsknZ4qKDNLr/YBOknzKIaEvo5i+OPUsMR7RrzHRwPNGhYRM3Z7jrNnPxS4LyLuj4ingIuBY8coz8waNE6y78tzh4xuYIYhk5JWSFojac0Y6zKzMY0zXHamqsKvVdMjYiXFSSGuxpt1aJw9+waK0wSnLKYYN21mE2icZL+F4tTBAyTtQHFe8tVpwjKz1GpX4yNiq6R3U5yoMQ+4ICLunuVtZtaRVsfGu81u1rwmut7MrEec7GaZcLKbZcLJbpYJJ7tZJpzsZpnoxeyyg92Dg/MYNt1tOOpSamOcLTh2GVXLH5RiXXW3R+rPnOJ7mV5Gm99n1XWnvpSf9+xmmXCym2WiFyPo+jYDbhPV/2Hl923bjJKqaj2szCaaEE1u/7rVeI+gM8uck90sE704Gp9Cm0epm65aT8oR7aalqIJXPdJdtfxJ2TZ1eM9ulgknu1kmnOxmmehdm71uOy51W6tqHE235+uOxppLzFXeU9WoONpsD6daV5+6Qb1nN8uEk90sE72oxtetMg9qursq9Xq7rB4OO/EoRZNk+nuaPJGkDXVi7ipe79nNMuFkN8uEk90sE71os3c1lLHpyStSlFm1DTxquVFlNt2OHnZ8YC4xjWoPT0Jbebqujjl4z26WiVmTXdIFkjZJumvgud0lXSdpXfl3YbNhmtm4quzZLwSOnvbc6cD1EXEQcH35uHWSnnNLLSKG3qrGNarMUe+ruq65lF/lPeOcXZa6jKrljVpu2Gup/neqllH3s6U0a7JHxDeBzdOePhZYVd5fBRyXOC4zS6zuAbq9I2IjQERslLTXsAUlrQBW1FyPmSXS+NH4iFgJrIT0V3FNcQJKE0dGU/QeNH0UedRR8D6NCuvaJM6TN0zdo/GPSFoEUP7dlC4kM2tC3WS/Glhe3l8OXJUmHDNryqxTSUv6AnAksCfwCPC3wJeAS4H9gQeA4yNi+kG8mcpqbSrpLqvxKXR55ZscqvGT9L0P246pp5LOZt74JidkGLWuSRnt1fY/d51/4KbPiEux7VPsRObyWh3Dkt0j6Mwy4WQ3y0QvToRJoc0ukiZeq6NqtXWUFF10VcuvKkUcqeaNr3riUdV1pa7SD/Ke3SwTTnazTDjZzTLR6zZ73bnQU3eDpIij7nJ13zeX7qU6cQxbNkUZdTU97DiFJs7enOI9u1kmnOxmmehFNb7JkWVtdoXVLbOJLq+mh7q2Ofpt2HvmoumRjXXid9ebmdXiZDfLRC+q8alHaqU40j2s7FHLzabqCKxR8aeYOKPqyR2j3lPnO6tb3U9RRtMnBk3CWYHes5tlwslulgknu1kmejF5xaBJaPuMo2+XKkqt7qQOk6LpGD15hZmNzclulolsut6GqXsiTN2Y2h6xN2nanjwzxbo9gs7MesXJbpYJJ7tZJnrX9TZoUtq4fegymos259ivo267vG8X3qjLXW9mmZs12SXtJ+kGSWsl3S3plPL53SVdJ2ld+Xdh8+GaWV1VrvW2CFgUEbdJ2hW4FTgOOBHYHBEfkXQ6sDAiTpulrKSXf2p6NFbT1cU6c7enUmdO+RSXRUqhie61quvrsuuwqtrV+IjYGBG3lfefANYC+wLHAqvKx
VZR/ACY2YSa06AaSUuAlwCrgb0jYiMUPwiS9hrynhXAivHCNLNxVT4aL2kX4BvA30fEFZK2RMSCgdcfi4iR7XZX46uXX3VddbkaX299fa7GV9qzS9oeuBy4KCKuKJ9+RNKicq++CNg0dpRzlGJe9CbmMU993bAur7HW9PDhQVU/56R0AU6X4keh0+GyKtZ4PrA2Is4ZeOlqYHl5fzlwVdLIzCypKkfjXwF8C/gu8Gz59Aco2u2XAvsDDwDHR8TmWcpKWo2vq241u80qXJfneVeNI3VMkzo4qemr+FQpby6GVeN7MYKuyckiR61rUkxKEjQx0WOKMur8MLY9CUWKGKvyCDqzzDnZzTLRi8kr6piUSyZ1OZlC6jZkivnx63ZFVp1Tv+68fqm/pyZiHJf37GaZcLKbZcLJbpaJ3nW9TSuv0nKTqg/xdzkmoc619SZxG86mzriGUdz1ZpY5J7tZJnpdjR8jjkpl962aDd11MW5LuhyenIKr8WaZc7KbZcLJbpaJXrTZB6WY+aVqeaPK77Itm2KWmRRDbuu+1qRJPR049UxJo7jNbpY5J7tZJnpx1luKEXSp5werG0cKTc9B1+S2SrGupifAmIsUE6u0xXt2s0w42c0y0Ytq/DBzOQLc1Si8uu+rO2deijn2q0rRFKhzBLvuqMcmqtmTXnUf5D27WSac7GaZcLKbZaLXbfbp2mwzNd0llWJdbUp95lZddS8JVtWoz9n0cZxxec9ulokq13rbSdLNku6UdLekD5XPHyBptaR1ki6RtEPz4ZpZXVWu9SZg54h4srya603AKcD7gCsi4mJJnwbujIhPzVJW0ss/zWW5Pox0mvQ51+pe0ij1pbcmaVKRSbm+26DaJ8JE4cny4fblLYCjgMvK51cBx40dpZk1plKbXdI8SXdQXIP9OuAHwJaI2FousgHYd8h7V0haI2lNioDNrJ5KyR4Rz0TEUmAxcChw8EyLDXnvyohYFhHL6odpZuOaU9dbRGyRdCNwGLBA0vxy774YeKiB+IDqbdkUr41ab9NnZVUtv83LC1ddb5tdXnW7tVLHlLKcmcpL3Z1Z5Wj8CyQtKO8/D3gtsBa4AXhzudhy4KqkkZlZUlX27IuAVZLmUfw4XBoR10i6B7hY0lnA7cD5DcZpZmPqxRx0VbvNJqVLalATcdTZBpPUXTWo6ctPT0o3a+pJOkbxHHRmmXOym2WidyfCdHkpnjpHqUeV38TJNH2rnjfdA9HV9pjEZpP37GaZcLKbZcLJbpaJ3rXZR0lxiZ2uRqfNtGydmIa1ldvuApyENupcpO6im8TP7z27WSac7GaZ6F01PkV1q81RVXMpv8lLJjVRRp0Rem03a6qUN5cym7g6a1tVfu/ZzTLhZDfLhJPdLBO9a7NvS+q23SZxAowmzlirOpFD6qG5c1Hns3V1tqb37GaZcLKbZaJ31fg6c8g3UUYKo6pzqeNou7tnUrs3pzRxmeom1peS9+xmmXCym2Wid9X4QW2fsND0fGlNVncnaYKNOuuaxKbAKHUvldUk79nNMuFkN8uEk90sE71us89Finm7U5zJlUKbxw7qrGsu6+tqkocujwF09Zm9ZzfLROVkLy/bfLuka8rHB0haLWmdpEsk7dBcmGY2rrns2U+huKDjlLOBj0XEQcBjwEkpAxsk6Ve3qstNv0XEr25VDb6njarXqPirfpY6y6UwalvV/f7qqFrGqOVGbfs+q5TskhYDfwicVz4WcBRwWbnIKuC4JgI0szSq7tnPBU4Fni0f7wFsieLa7AAbgH1neqOkFZLWSFozVqRmNpYq12d/A7ApIm4dfHqGRWes50bEyohYFhHLasZoZglU6Xo7AnijpGOAnYDdKPb0CyTNL/fui4GHmgpyUrpxUszJ3vRc66m7B1N07dUdKppiYog6JnHO9xRm3bNHxBkRsTgilgAnAF+PiLcBNwBvLhdbDlzVWJRmNrZx+tlPA94n6T6KNvz5aUIysyaozSqLpFora
2LyhpRlz2W9k3D2U9vanMhiUjQxv3xVETHjGz2CziwTTnazTGyzJ8LUrUaNqnKmPjrcZZW2zues2yTpQ9OlblOjzavmjst7drNMONnNMuFkN8tEll1vLX/moa81MfHEpFzOuUkpujObHsnnrjcz64yT3SwT2XS9DXut6SrypFw5dC5lNNlsSlG9ndSTXVL8zzU5QYb37GaZcLKbZcLJbpaJbLrehmliWG2K9zU9X/soTc9L36Y2Y0y9Lne9mVktTnazTGyzXW9N63Ik1aizq1J0eU3KnH91dDm3+6jtPWy5NnnPbpYJJ7tZJnpXje/biR51j9TXLaPOFVhHxVF3tFebcQxbb9Wyp79vUntoxuU9u1kmnOxmmXCym2WiF2321GdQTepkGHXbm8PKqBtXne1Ytcur6bbydMPKSHG8p66uJt30nt0sE5X27JLWA08AzwBbI2KZpN2BS4AlwHrgTyLisWbCNLNxVToRpkz2ZRHx6MBzHwU2R8RHJJ0OLIyI02YpZ+z6StuXa6qz7kkcPTUp8+jPpfw2dXniUdV1VdXEiTDHAqvK+6uA48Yoy8waVjXZA/iqpFslrSif2zsiNgKUf/ea6Y2SVkhaI2nN+OGaWV1Vj8YfEREPSdoLuE7S96quICJWAishTTXezOqplOwR8VD5d5OkK4FDgUckLYqIjZIWAZuaCrLq3N9V3lN3XaM0PYHEdE0P36y6rj5PBjGJZ+w1bdZqvKSdJe06dR94PXAXcDWwvFxsOXBVU0Ga2fiq7Nn3Bq4sfxXnA/8aEddKugW4VNJJwAPA8c2FaWbjynIOutyrc9NVHdU2iWdyzbauri7dlMIkdb2ZWY842c0y4WQ3y0QvznpLrW9tyFTrriN1l13b7eGuriVXl6/1ZmZjc7KbZaLXXW9NTISQWh9i7FKbl0zq25l57nozs1qc7GaZ6MXR+DrzqTd99dG688BN4ui9Ji5RVbW8NtfVpUmI2Xt2s0w42c0y4WQ3y0Qvut4GTcqECV2d1dXG+po0CW3XlJqcdNNdb2ZWi5PdLBO96Hprs3rX5qWB68bUdFMjxckYKSbAqBpTim1Q9eSlquYSr7vezCwpJ7tZJpzsZpnoRZt90CQONx1lLu3tOpem7sMlhFMf32jiWmyph/ROyvXiBnnPbpYJJ7tZJrIZQde36v8obc5jNyhFk2Qu621zHrum58dPsT2q8gg6s8xVSnZJCyRdJul7ktZKOlzS7pKuk7Su/Luw6WDNrL5K1XhJq4BvRcR5knYAng98ANgcER+RdDqwMCJOm6WcftefzXpgWDV+1mSXtBtwJ3BgDCws6V7gyIFLNt8YEb89S1lOdrOGjdNmPxD4MfBZSbdLOq+8dPPeEbGxLHwjsNdMb5a0QtIaSWtqxm5mCVTZsy8DvgMcERGrJX0ceBx4T0QsGFjusYgY2W73nt2seePs2TcAGyJidfn4MuClwCNl9Z3y76YUgZpZM2ZN9oh4GHhQ0lR7/DXAPcDVwPLyueXAVY1EaGZJVD0avxQ4D9gBuB94B8UPxaXA/sADwPERsXmWclyNN2tY7aPxKTnZzZrnEXRmmXOym2XCyW6WCSe7WSac7GaZcLKbZaLtOegeBX4E7Fne79IkxACOYzrH8VxzjeOFw15otZ/9VyuV1kTEstZXPGExOA7H0WYcrsabZcLJbpaJrpJ9ZUfrHTQJMYDjmM5xPFeyODpps5tZ+1yNN8uEk90sE60mu6SjJd0r6b5yRtq21nuBpE2S7hp4rvWpsCXtJ+mGcjruuyWd0kUsknaSdLOkO8s4PlQ+f4Ck1WUcl5QzCTdO0rxyfsNruopD0npJ35V0x9R8iR39jzQ2bXtryS5pHvDPwB8AhwBvlXRIS6u/EDh62nOnA9dHxEHA9eXjpm0F3h8RBwOHAe8qt0HbsfwSOCoiXgwsBY6WdBhwNvCxMo7HgJMajmPKKcDagcddxfHqiFg60K/dxf/Ix4FrI+JFwIsptkuaOCKil
RtwOPCVgcdnAGe0uP4lwF0Dj+8FFpX3FwH3thXLQAxXAa/rMhaKawDcBryMYqTW/Jm+rwbXv7j8Bz4KuAZQR3GsB/ac9lyr3wuwG/BDygPnqeNosxq/L/DgwOMN5XNdqTQVdlMkLQFeAqzuIpay6nwHxUSh1wE/ALZExNZykba+n3OBU4Fny8d7dBRHAF+VdKukFeVzbX8vY03bPps2k32mqXKy7PeTtAtwOfDeiHi8ixgi4pmIWEqxZz0UOHimxZqMQdIbgE0Rcevg023HUToiIl5K0cx8l6RXtrDO6eZTzNz8qYh4CfBTEjYd2kz2DcB+A48XAw+1uP7pOpkKW9L2FIl+UURc0WUsABGxBbiR4hjCAklTJ0e18f0cAbxR0nrgYoqq/LkdxEFEPFT+3QRcSfED2Pb30ui07W0m+y3AQeWR1h2AEyimo+5K61NhSxJwPrA2Is7pKhZJL5C0oLz/POC1FAeCbgDe3FYcEXFGRCyOiCUU/w9fj4i3tR2HpJ0l7Tp1H3g9cBctfy/R9LTtTR/4mHag4Rjg+xTtw79ucb1fADYCT1P8ep5E0Ta8HlhX/t29hTheQVEl/W/gjvJ2TNuxAL8D3F7GcRfwwfL5A4GbgfuALwI7tvgdHQlc00Uc5fruLG93T/1vdvQ/shRYU343XwIWporDw2XNMuERdGaZcLKbZcLJbpYJJ7tZJpzsZplwsptlwslulon/A7SVnPBOmFiLAAAAAElFTkSuQmCC\n", 172 | "text/plain": [ 173 | "
" 174 | ] 175 | }, 176 | "metadata": { 177 | "needs_background": "light" 178 | }, 179 | "output_type": "display_data" 180 | }, 181 | { 182 | "data": { 183 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAEICAYAAACZA4KlAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAXyklEQVR4nO3df9AdVX3H8ffHAIKATfgdCBhwaNU/FJ2IOLFoES1FK0xHWh2saUsndWodrM5IUKu1YxXtqNiZVicFNLVUoIiGYagKCP4YSyD8UiBCwPIjJBIoiYA/CXz7x56HuVyfe5999u7u3c35vGbuPPfH3rPfu/d+nz1nz9mzigjMbOf3rGkHYGbtcLKbZcLJbpYJJ7tZJpzsZplwsptlwsluT5P0DklXtri+CyR9sK315c7JPgdJjw/cnpL0i4HHp05Q7rWS3jbHMntI+idJ96f13iHp3ZJUdb1NkXSWpHMaLH/O7dWn9UzDLtMOoOsiYq+Z+5LuAf4yIhrf+6WE/hqwN/B6YCNwDPAl4GDgfU3HYDuZiPCt5A24Bzh+6LkFwN8BPwYeBs4HFqbX9gQuAB4BtgPrgEXAp4AngV8CjwOfmmVdbwB+Dhw09PyxwA7gsPT4WuDD6e+jwOXAooHlfzetdztwI7B8zOd7B3AV8M9p+bsHPy9wWCr/EeBOYEV6/mTg18AT6fNcN6L8o4FbgMeA/wAuAT6YXtsf+G/goVT+WmBxem3W7QV8DtiUPvd1wDED61oO3JRe+wnw8bm2SZnvpc+3qQfQp9uIZF8FfJdib7s78EXgC+m104GLgT0oalEvB/ZMr10LvG3Mus4GvjHitQcHEu1a4A7g+emfy/eBv0+vLQX+Dzieosl2YkqmRSPKfUdK2LdT/BP7W+CegdfXAZ8Bng0sS0k5kyhnAeeM+Ty7A5uBvwZ2BU6l+Kc1k+wHAielbfVbKdkvGHj/b2yvFOeiVN4HgPuBXdNrNwGnpPt7A68os03m+l76fHObfXJ/BayKiM0R8UvgI8CfpGr4ExR7rOdHxI6IuD4iflay3P2ALSNe25Jen/FvEXF3Kvti4Kj0/Argkoi4MiKeiojLgdspmgWj3BER/x4RTwJrgOdJWijpSOAlwPsj4lcRsT69/qclP8+xwK8i4l8j4omIOB/4wcyLEfFgRKyNiF9ExE+BjwOvHldginNbRDwBfAzYFzgivfwE8NuS9o2IxyJi3QTbZKfgZJ9ASuhDgcslbZe0nWKP8iyKH965wLeBiyVtkvQxSQtKFv8wsHjEa4vT6zN+MnD/58DMcYbnAW+biS3Ftww4WNLxAwcabxhTFqm8g4GHIuIXA6/fCxxS8vMcTFHlHnTvzB1Je0s6T9J9kh4Fvskz/6H9BklnpoOWPwW2UdQeZt6zAngxcKekdZJ+Pz0/cpuU/By95QN0E4iIkPQA8EcRccOIxT4EfEjSEcA3gNso2vVznW54JbBS0kER8XQCSjqW4gd9TYkQ76eoWr9rxOt7jXh+NpuB/SXtMZDwhwEPpPtzfZ4twJKh5w4DZrbbqvT6yyPiQUnHAN8bWPYZ5Ut6HfAuiur4BkAUxwIEEBEbKGpYC4C3AJdIWsTc22SnPQ3Ue/bJfR44S9KhAJIOkPSH6f7xkl4k6VkUB4p2UBwAgqLdfcRsBSaXU7S/vyLpBZJ2kbScoup8dkTcO+a9M9YAp0h6raQFqSvvtZIOqvA576Kodn9U0rMlvYxi73n+wOc5fEy34HeA3VNf/i6S3kqx552xN0VNYruk/YDh/vfh7bU3RVX9IWA34B8o9uwASHp7qsI/C
fyUIomfYu5tMtf30l/TPmjQpxujj8afQdE19hhFUnw4vbYiPf8ziurxp4BnpddenZbdBnxyxPqek97zAMUR4o3AewENLPOMA0oUB9muHHi8nGIPuQ3YClwKHDxifcPv3Z0iSZakx0spjphvS7H8xcCyBwH/k177/ojyj6H4hzHb0fjDUpyPAz+iOJC3Y+C9z9heFAflvkTxT/QB4N1pG78qLX8RRVPnMeCHwIlltkmZ76WvN6UPaGY7OVfjzTLhZDfLhJPdLBMTJbukE1I/512SVtUVlJnVr/IButR/eSfwOorBEtcDb42I28e8x0cDzRoWEbN2f06yZz8auCsifhwRv6Y44eOkCcozswZNkuyHUIxGmrGJWYZOSlopab2k9ROsy8wmNMlw2dmqCr9RTY+I1cBqcDXebJom2bNvojgJZMYSivHTZtZBkyT79cCRkg6XtBvFyQaX1hOWmdWtcjU+InZI+huKM7kWAOdFxG21RWZmtWp1bLzb7GbNa6Lrzcx6xMlulgknu1kmnOxmmXCym2XCyW6Wid7NLjvYVdj0Jc/GdUtWXXfT8Y+KuY51Vd0edX/mOr6X4TLa/D7rWvd8ec9ulgknu1kmPIKuAU1U/0eV38GrN1fWRPW26SZEF7e/R9CZZc7JbpaJ3h2Nr6rNo9RNV+26ckS7aXVUwUe9r+p31pVtU4X37GaZcLKbZcLJbpaJ3rXZq7bj6m5rlY2j6fZ82fLnE8e0jm+02R6ua1196gb1nt0sE052s0z0ohpftco8qOnuqrrXO83q4eD6BuOoo0ky/J4+nUgymyoxTyte79nNMuFkN8uEk90sEz7rbYymJ68YV2bVNt60jm/U0Y5uYnKMLrSVh+MY1kB3rM96M8vZnMku6TxJWyXdOvDcPpKukLQx/V3UbJhmNqkye/YvAicMPbcKuCoijgSuSo9bFxHPuNVN0shb2bjGlTnufWXXNZ/yy7xnkrPL6i6jbHnjlhv1Wl2/nbJlVP1sdZoz2SPiO8AjQ0+fBKxJ99cAJ9ccl5nVrOqgmgMjYgtARGyRdMCoBSWtBFZWXI+Z1aTxEXQRsRpYDfUfja/jBJQmRmPVMRFC00eRRx35n8/6unKke5q6OE/eKFWPxj8oaTFA+ru1vpDMrAlVk/1SYEW6vwJYW084ZtaUOQfVSPoy8BpgP+BB4MPA14CLgMOA+4BTImL4IN5sZbU2qGaa1fg6TPPKNzlU47v0vTcw3fWshWQzgq7JCRnGrWs+I7qaTJ62f9yjPst8RpLVfUZcHdu+jp1I1RjL8gg6s8w52c0y0YvJK+rQZhdJE69VUbbaOk4dXXRlyy+riZNuqnaXjtrGXfkNDPKe3SwTTnazTDjZzTLR6zZ71bnQq3SDNB1H1eWqvq/KxBbzPZOuqTKqanrYcR2a7H71nt0sE052s0z0ohrfZNWm7W6QaXU1jSuz7VGEZd4z/L46midl193EyMam4y/De3azTDjZzTLR6xNh2j5bq+xR6qrbtOtnm7U5HfJ81l0ljiZ6ULrCJ8KYZc7JbpYJJ7tZJnrRZu9bm2mcPlyqqElVJ3XoiqZj9OQVZjYxJ7tZJnoxgq7J6lzVE2EmucRR3WX2yTS76Joe/VZHGR5BZ2YTc7KbZcLJbpaJ3nW9DZU3UTx16UOX0Xz0bXsP6sqw12leeMNdb2aZmzPZJR0q6WpJGyTdJun09Pw+kq6QtDH9XdR8uGZWVZlrvS0GFkfEjZL2Bm4ATgb+DHgkIs6StApYFBFnzFFWa2e9NTG3eNnyql5KqMq6qqoyp3wdl0WqQ9tn302r+t96NT4itkTEjen+Y8AG4BDgJGBNWmwNxT8AM+uoeQ2qkbQUeCmwDjgwIrZA8Q9B0gEj3rMSWDlZmGY2qdJH4yXtBXwb+MeIuETS9ohYOPD6togY2253Nb58+WXXVZWr8dXW1+dqfKk9u6Rdga8A50fEJenpByUtTnv1xcDWSpFNoI550ZuYx7zu64ZN8xprT
Q8fHlT2c3alC3BYHYk61eGyKtZ+LrAhIj498NKlwIp0fwWwtv7wzKwuZY7Gvwr4LvBD4Kn09Psp2u0XAYcB9wGnRMQjc5TV3gieMapWs6dVhetqHHXH1NXBSU1fxaduo6rxvRhBN6juySKHdeUHNqgrSVBHsjdRRpV/jG1OQjFcfgvr9gg6s5w52c0y0YvJK6royiWTpjmZQt1tyDrmx6/aFTn4vvn0wtTRDVpl2zUR46S8ZzfLhJPdLBNOdrNM9K7rbVBXuqSq6kP80xyTUOXael3chnOpO353vZllzsluloleV+Or6upw2Sq60sW4M5nm8OQ61utqvFnmnOxmmXCym2WiF8Nl6575payudvHUMcvMuDJGnaE1nzKmta3qaG9XPZuy6pltbW0f79nNMuFkN8tEL6rxdUxA0OT8YG131TQ9B12bc6lN60zCur6zsk2eLvCe3SwTTnazTOxUI+j6ML/3tOaUn8+caFXKH7e+NqvWZT9nXb+PLlbdPYLOLHNOdrNMONnNMtGLrrey2mwztXl5n660Bcdp89jPOE2PVBv3OZs+jjMp79nNMlHmWm+7S7pO0i2SbpP0kfT84ZLWSdoo6UJJuzUfrplVVeZabwL2jIjHVVzN9XvA6cB7gEsi4gJJnwduiYjPzVFWrZd/mqX8kcv1YaRT1+dcq3pJoza7EdvWwPxxE5dXuestCo+nh7umWwDHARen59cAJ1eKzMxaUarNLmmBpJsprsF+BXA3sD0idqRFNgGHjHjvSknrJa2vI2Azq6ZUskfEkxFxFLAEOBp44WyLjXjv6ohYFhHLqodpZpOaV9dbRGyXdA1wDLBQ0i5p774E2NxAfDPrnfX5qt0sTbcNqw4/LVt+m5cXLrveNru8qnZr1R1TneU0Vd6gMkfj95e0MN3fAzge2ABcDbw5LbYCWNtUkGY2uTJ79sXAGkkLKP45XBQRl0m6HbhA0keBm4BzG4zTzCbUu7Pe+jDv2aAm4qiyDbrUXTWo6ctPd6WbteUzMn3Wm1nOnOxmmejdiTDTuhTPuHW3Xa3swrTEcyn7OZvugZjW9uhis8l7drNMONnNMuFkN8tE77re6tZ026rq3PZVY+rDpJtdtJN9Fne9meXMyW6Wid51vdV0cv/EZZQ1n/KbvGRSE2VUGaHXdrOmTHnzKbPpue2b5D27WSac7GaZcLKbZaJ3bfadSdW2WxcnwGjijLWyk1bWPTR3Pqp8tmmdrek9u1kmnOxmmehdNb6OSw03fbnissZV5+qOo+3unq52b86oOnK0iW6/tnjPbpYJJ7tZJrI/EWY+mp4vrStyOJlmmidAtbBunwhjljMnu1kmnOxmmcimzV6lbdjEmVx1aPPYQZV1TbI+m5zb7GaZK53s6bLNN0m6LD0+XNI6SRslXShpt+bCNLNJzWfPfjrFBR1nfAL4TEQcCWwDTqszsEER8fSt7HLDN0lP38oafE8b1dJx8Zf9LFWWq8O4bVX1+6uibBnjlhu37fusVLJLWgK8ATgnPRZwHHBxWmQNcHITAZpZPcru2c8G3gc8lR7vC2xP12YH2AQcMtsbJa2UtF7S+okiNbOJlLk++xuBrRFxw+DTsyw6ax0nIlZHxLKIWFYxRjOrQZmz3pYDb5J0IrA78FyKPf1CSbukvfsSYHNTQU5rQoZho7qh6prXvY746+4erKNrr+pQ0TomhqhiZ+02nHPPHhFnRsSSiFgKvAX4VkScClwNvDkttgJY21iUZjaxSfrZzwDeI+kuijb8ufWEZGZNyHIE3aC2LwHd5tlPXdH3swCrmOZ36xF0ZplzsptlYqetxletRo2rcu5M1dEqn7OOJklXmy5Vv9su/iZcjTfLnJPdLBNOdrNM7LRt9mFdmdhwUBMTT3Tlcs5NavvYQZMTnzTBbXazzDnZzTLRu8s/lVXHVT/reF9Xrhw6nzLqvtxU3aMGu3qySx2/uSabUN6zm2XCyW6WCSe7WSay6XobpYlhtXW8b5rzt
XehfVmXNmPsyvZw15tZ5pzsZpnYabvemtbmSKpxZdRxxtqwrsz5V8U053Yft71HLdcm79nNMuFkN8tE76rxfTvRo+qR+qpljKpKNt1jMM04Rq23bNnD7+tqD82kvGc3y4ST3SwTTnazTPSizV73GVRtjjKbj6rtzVFlVI2rynYs2+XVdFt52Kgy6jjeU9W0Jrbwnt0sE6X27JLuAR4DngR2RMQySfsAFwJLgXuAP46Ibc2EaWaTKnUiTEr2ZRHx8MBznwQeiYizJK0CFkXEGXOUU6kONK2TGYY1Wa1sWlfn0e/QySOl4ujqvPeDmjgR5iRgTbq/Bjh5grLMrGFlkz2Ab0q6QdLK9NyBEbEFIP09YLY3Slopab2k9ZOHa2ZVlT0avzwiNks6ALhC0o/KriAiVgOroZvns5vlolSyR8Tm9HerpK8CRwMPSlocEVskLQa2NhVk2bm/y7yn6rrGaXoCiWFND98su66+TQbRleMD0zJnNV7SnpL2nrkPvB64FbgUWJEWWwGsbSpIM5tcmT37gcBX03/CXYD/jIivS7oeuEjSacB9wCnNhWlmk8pyDrrcq3PDyo5q6+KZXHOtq2+XbqqD56Azy5yT3SwTTnazTPTirLe69a0NWde6q6i7y67t9vC0riVXla/1ZmYTc7KbZaLXXW9NTIRQtz7EOE1NjpIb1vcz88py15tZ5pzsZpnoxdH4KvOpN3310arzwHWxStjEJarKltfmuqapCzF7z26WCSe7WSac7GaZ6EXXW1fOoOpCTG2sr0ldaLvWqYuTbrrrzSxzTnazTPSiGt9nTc9P10RTo47fRB0TYJSNqY5tUPbkpXHKXg6rhWafq/FmOXOym2XCyW6WiV4Mlx3UxeGm48xnuGyVS1P34RLCdU/02MS12Ooe0tvF68V5z26WCSe7WSZ60fXWxVFK09TmPHaD6miSzGe9bc5j1/T8+C2PuHTXm1nOSiW7pIWSLpb0I0kbJL1S0j6SrpC0Mf1d1HSwZlZdqWq8pDXAdyPiHEm7Ac8B3g88EhFnSVoFLIqIM+YoJ7sRdGZtG1WNnzPZJT0XuAU4IgYWlnQH8JqBSzZfExG/M0dZTnazhk3SZj8CeAj4gqSbJJ2TLt18YERsSYVvAQ6Y7c2SVkpaL2l9xdjNrAZl9uzLgGuB5RGxTtJngUeBd0XEwoHltkXE2Ha79+xmzZtkz74J2BQR69Lji4GXAQ+m6jvp79Y6AjWzZsyZ7BHxE+B+STPt8dcCtwOXAivScyuAtY1EaGa1KHs0/ijgHGA34MfAn1P8o7gIOAy4DzglIh6ZoxxX480aVvlofJ2c7GbN8wg6s8w52c0y4WQ3y4ST3SwTTnazTDjZzTLR9hx0DwP3Avul+9PUhRjAcQxzHM803zieN+qFVvvZn16ptD4ilrW+4o7F4DgcR5txuBpvlgknu1kmppXsq6e03kFdiAEcxzDH8Uy1xTGVNruZtc/VeLNMONnNMtFqsks6QdIdku5KM9K2td7zJG2VdOvAc61PhS3pUElXp+m4b5N0+jRikbS7pOsk3ZLi+Eh6/nBJ61IcF6aZhBsnaUGa3/CyacUh6R5JP5R088x8iVP6jTQ2bXtryS5pAfAvwB8ALwLeKulFLa3+i8AJQ8+tAq6KiCOBq9Ljpu0A3hsRLwSOAd6ZtkHbsfwKOC4iXgIcBZwg6RjgE8BnUhzbgNMajmPG6cCGgcfTiuP3IuKogX7tafxGPgt8PSJeALyEYrvUE0dEtHIDXgl8Y+DxmcCZLa5/KXDrwOM7gMXp/mLgjrZiGYhhLfC6acZCcQ2AG4FXUIzU2mW276vB9S9JP+DjgMsATSmOe4D9hp5r9XsBngv8L+nAed1xtFmNPwS4f+DxpvTctJSaCrspkpYCLwXWTSOWVHW+mWKi0CuAu4HtEbEjLdLW93M28D7gqfR43ynFEcA3Jd0gaWV6ru3vZaJp2+fSZrLPNlVOlv1+kvYCvgK8OyIenUYMEfFkR
BxFsWc9GnjhbIs1GYOkNwJbI+KGwafbjiNZHhEvo2hmvlPSsS2sc9guFDM3fy4iXgr8jBqbDm0m+ybg0IHHS4DNLa5/2FSmwpa0K0Winx8Rl0wzFoCI2A5cQ3EMYaGkmZOj2vh+lgNvknQPcAFFVf7sKcRBRGxOf7cCX6X4B9j299LotO1tJvv1wJHpSOtuwFsopqOeltanwpYk4FxgQ0R8elqxSNpf0sJ0fw/geIoDQVcDb24rjog4MyKWRMRSit/DtyLi1LbjkLSnpL1n7gOvB26l5e8lmp62vekDH0MHGk4E7qRoH36gxfV+GdgCPEHx3/M0irbhVcDG9HefFuJ4FUWV9AfAzel2YtuxAC8Gbkpx3Ap8KD1/BHAdcBfwX8CzW/yOXgNcNo040vpuSbfbZn6bU/qNHAWsT9/N14BFdcXh4bJmmfAIOrNMONnNMuFkN8uEk90sE052s0w42c0y4WQ3y8T/A37trvyU6+cwAAAAAElFTkSuQmCC\n", 184 | "text/plain": [ 185 | "
" 186 | ] 187 | }, 188 | "metadata": { 189 | "needs_background": "light" 190 | }, 191 | "output_type": "display_data" 192 | } 193 | ], 194 | "source": [ 195 | "if datatype == 'uniform':\n", 196 | " # Load the one hot datasets\n", 197 | " train_onehot = np.load('data-uniform/train_onehot.npy').astype('float32')\n", 198 | " test_onehot = np.load('data-uniform/test_onehot.npy').astype('float32')\n", 199 | "\n", 200 | " # (N, C, H, W) <=== 数据格式\n", 201 | " # make the train and test datasets\n", 202 | " # train\n", 203 | " pos_train = np.where(train_onehot == 1.0)\n", 204 | " X_train = pos_train[2]\n", 205 | " Y_train = pos_train[3]\n", 206 | " train_set = np.zeros((len(X_train), 2, 1, 1), dtype='float32')\n", 207 | " for i, (x, y) in enumerate(zip(X_train, Y_train)):\n", 208 | " train_set[i, 0, 0, 0] = x\n", 209 | " train_set[i, 1, 0, 0] = y\n", 210 | "\n", 211 | " # test\n", 212 | " pos_test = np.where(test_onehot == 1.0)\n", 213 | " X_test = pos_test[2]\n", 214 | " Y_test = pos_test[3]\n", 215 | " test_set = np.zeros((len(X_test), 2, 1, 1), dtype='float32')\n", 216 | " for i, (x, y) in enumerate(zip(X_test, Y_test)):\n", 217 | " test_set[i, 0, 0, 0] = x\n", 218 | " test_set[i, 1, 0, 0] = y\n", 219 | "\n", 220 | " train_set = np.tile(train_set, [1, 1, 64, 64])\n", 221 | " test_set = np.tile(test_set, [1, 1, 64, 64])\n", 222 | "\n", 223 | " # Normalize the datasets\n", 224 | " train_set /= (64. - 1.) # 64x64 grid, 0-based index\n", 225 | " test_set /= (64. - 1.) 
# 64x64 grid, 0-based index\n", 226 | "\n", 227 | " print('Train set : ', train_set.shape, train_set.max(), train_set.min())\n", 228 | " print('Test set : ', test_set.shape, test_set.max(), test_set.min())\n", 229 | "\n", 230 | " # Visualize the datasets\n", 231 | "\n", 232 | " plt.imshow(np.sum(train_onehot, axis=0)[0, :, :], cmap='gray')\n", 233 | " plt.title('Train One-hot dataset')\n", 234 | " plt.show()\n", 235 | " plt.imshow(np.sum(test_onehot, axis=0)[0, :, :], cmap='gray')\n", 236 | " plt.title('Test One-hot dataset')\n", 237 | " plt.show()\n", 238 | "\n", 239 | "else:\n", 240 | " # Load the one hot datasets and the train / test set\n", 241 | " train_set = np.load('data-quadrant/train_set.npy').astype('float32')\n", 242 | " test_set = np.load('data-quadrant/test_set.npy').astype('float32')\n", 243 | "\n", 244 | " train_onehot = np.load('data-quadrant/train_onehot.npy').astype('float32')\n", 245 | " test_onehot = np.load('data-quadrant/test_onehot.npy').astype('float32')\n", 246 | "\n", 247 | " train_set = np.tile(train_set, [1, 1, 64, 64])\n", 248 | " test_set = np.tile(test_set, [1, 1, 64, 64])\n", 249 | "\n", 250 | " # Normalize datasets\n", 251 | " train_set /= train_set.max()\n", 252 | " test_set /= test_set.max()\n", 253 | "\n", 254 | " print('Train set : ', train_set.shape, train_set.max(), train_set.min())\n", 255 | " print('Test set : ', test_set.shape, test_set.max(), test_set.min())\n", 256 | "\n", 257 | " # Visualize the datasets\n", 258 | "\n", 259 | " plt.imshow(np.sum(train_onehot, axis=0)[0, :, :], cmap='gray')\n", 260 | " plt.title('Train One-hot dataset')\n", 261 | " plt.show()\n", 262 | " plt.imshow(np.sum(test_onehot, axis=0)[0, :, :], cmap='gray')\n", 263 | " plt.title('Test One-hot dataset')\n", 264 | " plt.show()" 265 | ] 266 | }, 267 | { 268 | "cell_type": "code", 269 | "execution_count": 6, 270 | "metadata": {}, 271 | "outputs": [], 272 | "source": [ 273 | "# flatten the datasets\n", 274 | "train_onehot = train_onehot.reshape((-1, 64 
* 64)).astype('int64')\n", 275 | "test_onehot = test_onehot.reshape((-1, 64 * 64)).astype('int64')" 276 | ] 277 | }, 278 | { 279 | "cell_type": "code", 280 | "execution_count": 7, 281 | "metadata": {}, 282 | "outputs": [], 283 | "source": [ 284 | "# model definition\n", 285 | "\n", 286 | "class Net(nn.Module):\n", 287 | " def __init__(self):\n", 288 | " super(Net, self).__init__()\n", 289 | " self.coordconv = CoordConv2d(2, 32, 1, with_r=True, use_cuda=False)\n", 290 | " self.conv1 = nn.Conv2d(32, 64, 1)\n", 291 | " self.conv2 = nn.Conv2d(64, 64, 1)\n", 292 | " self.conv3 = nn.Conv2d(64, 1, 1)\n", 293 | " self.conv4 = nn.Conv2d( 1, 1, 1)\n", 294 | "\n", 295 | " def forward(self, x):\n", 296 | " x = self.coordconv(x)\n", 297 | " x = F.relu(self.conv1(x))\n", 298 | " x = F.relu(self.conv2(x))\n", 299 | " x = F.relu(self.conv3(x))\n", 300 | " x = self.conv4(x)\n", 301 | " x = x.view(-1, 64*64)\n", 302 | " return x\n", 303 | "\n", 304 | "device = torch.device(\"cpu\")\n", 305 | "net = Net().to(device)\n", 306 | "\n", 307 | "#summary(net, input_size=(2, 64, 64))" 308 | ] 309 | }, 310 | { 311 | "cell_type": "markdown", 312 | "metadata": {}, 313 | "source": [ 314 | "## Make Datasets" 315 | ] 316 | }, 317 | { 318 | "cell_type": "code", 319 | "execution_count": 8, 320 | "metadata": {}, 321 | "outputs": [], 322 | "source": [ 323 | "train_tensor_x = torch.stack([torch.Tensor(i) for i in train_set])\n", 324 | "train_tensor_y = torch.stack([torch.LongTensor(i) for i in train_onehot])\n", 325 | "\n", 326 | "train_dataset = utils.TensorDataset(train_tensor_x,train_tensor_y)\n", 327 | "train_dataloader = utils.DataLoader(train_dataset, batch_size=32, shuffle=False)\n", 328 | "\n", 329 | "test_tensor_x = torch.stack([torch.Tensor(i) for i in test_set])\n", 330 | "test_tensor_y = torch.stack([torch.LongTensor(i) for i in test_onehot])\n", 331 | "\n", 332 | "test_dataset = utils.TensorDataset(test_tensor_x,test_tensor_y)\n", 333 | "test_dataloader = utils.DataLoader(test_dataset, 
batch_size=32, shuffle=False)" 334 | ] 335 | }, 336 | { 337 | "cell_type": "code", 338 | "execution_count": 9, 339 | "metadata": {}, 340 | "outputs": [], 341 | "source": [ 342 | "optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)\n", 343 | "def cross_entropy_one_hot(input, target):\n", 344 | " _, labels = target.max(dim=1)\n", 345 | " return nn.CrossEntropyLoss()(input, labels)\n", 346 | "criterion = cross_entropy_one_hot\n", 347 | "epochs = 10" 348 | ] 349 | }, 350 | { 351 | "cell_type": "code", 352 | "execution_count": 10, 353 | "metadata": {}, 354 | "outputs": [], 355 | "source": [ 356 | "def train(epoch, net, train_dataloader, optimizer, criterion, device):\n", 357 | " net.train()\n", 358 | " iters = 0\n", 359 | " for batch_idx, (data, target) in enumerate(train_dataloader):\n", 360 | " data, target = Variable(data), Variable(target)\n", 361 | " data, target = data.to(device), target.to(device)\n", 362 | " optimizer.zero_grad()\n", 363 | " output = net(data)\n", 364 | " loss = criterion(output, target)\n", 365 | " loss.backward()\n", 366 | " optimizer.step()\n", 367 | " iters += len(data)\n", 368 | " print('Train Epoch: {} [{}/{} ({:.0f}%)] Loss: {:.6f}'.format(\n", 369 | " epoch, iters, len(train_dataloader.dataset),\n", 370 | " 100. 
* (batch_idx + 1) / len(train_dataloader), loss.data.item()), end='\\r', flush=True)\n", 371 | " print(\"\")" 372 | ] 373 | }, 374 | { 375 | "cell_type": "code", 376 | "execution_count": 11, 377 | "metadata": {}, 378 | "outputs": [ 379 | { 380 | "name": "stdout", 381 | "output_type": "stream", 382 | "text": [ 383 | "Train Epoch: 1 [2508/2508 (100%)] Loss: 7.498067\n", 384 | "Train Epoch: 2 [2508/2508 (100%)] Loss: 3.979326\n", 385 | "Train Epoch: 3 [2508/2508 (100%)] Loss: 2.014220\n", 386 | "Train Epoch: 4 [2508/2508 (100%)] Loss: 0.979962\n", 387 | "Train Epoch: 5 [2508/2508 (100%)] Loss: 0.474519\n", 388 | "Train Epoch: 6 [2508/2508 (100%)] Loss: 0.241503\n", 389 | "Train Epoch: 7 [2508/2508 (100%)] Loss: 0.138851\n", 390 | "Train Epoch: 8 [2508/2508 (100%)] Loss: 0.089052\n", 391 | "Train Epoch: 9 [2508/2508 (100%)] Loss: 0.064587\n", 392 | "Train Epoch: 10 [2508/2508 (100%)] Loss: 0.047140\n" 393 | ] 394 | } 395 | ], 396 | "source": [ 397 | "for epoch in range(1, epochs + 1):\n", 398 | " train(epoch, net, train_dataloader, optimizer, criterion, device)" 399 | ] 400 | }, 401 | { 402 | "cell_type": "code", 403 | "execution_count": 13, 404 | "metadata": {}, 405 | "outputs": [], 406 | "source": [ 407 | "def test(net, test_loader, optimizer, criterion, device):\n", 408 | " net.eval()\n", 409 | " test_loss = 0\n", 410 | " correct = 0\n", 411 | " pred_logits = torch.tensor([])\n", 412 | " for data, target in test_loader:\n", 413 | " with torch.no_grad():\n", 414 | " data, target = data.to(device), target.to(device)\n", 415 | " output = net(data)\n", 416 | " logits = F.softmax(output, dim=1)\n", 417 | " pred_logits = torch.cat((pred_logits, logits.cpu()), dim=0)\n", 418 | " test_loss += criterion(output, target).item()\n", 419 | " _, pred = output.max(1, keepdim=True)\n", 420 | " _, label = target.max(dim=1)\n", 421 | " correct += pred.eq(label.view_as(pred)).sum().item()\n", 422 | "\n", 423 | " test_loss = test_loss\n", 424 | " test_loss /= len(test_loader) # loss 
function already averages over batch size\n", 425 | " print('\\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n", 426 | " test_loss, correct, len(test_loader.dataset),\n", 427 | " 100. * correct / len(test_loader.dataset)))\n", 428 | " pred_logits = torch.sum(pred_logits, dim=0)\n", 429 | " plt.imshow(pred_logits.detach().numpy().reshape(-1,64), cmap='gray')\n", 430 | " plt.title('Predictions One-hot dataset')" 431 | ] 432 | }, 433 | { 434 | "cell_type": "code", 435 | "execution_count": 14, 436 | "metadata": {}, 437 | "outputs": [ 438 | { 439 | "name": "stdout", 440 | "output_type": "stream", 441 | "text": [ 442 | "\n", 443 | "Test set: Average loss: 0.0487, Accuracy: 628/628 (100%)\n", 444 | "\n" 445 | ] 446 | }, 447 | { 448 | "data": { 449 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAEICAYAAACZA4KlAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO2dfbAlV3HYfy3tlz5W7K60Uq20wALi22ULEBgCsbEMGLBjVCkgYCfZUIrXTpFCFCZIECcGhwSBEyCJbYiMsJcUXzIgJKuCQYAItmMLVoCMhBASWEhrLdovfa3Qtzp/zLynvs07fc+dN/e+93b6V/Xqzb0zc07PmTl3uk/36SOqSpIkhz9HLLUASZLMhuzsSTIQsrMnyUDIzp4kAyE7e5IMhOzsSTIQsrNPgIhsExEVkVXt58+JyPYO5TxGRA6JyJH9S7l8EJEbReRFM6pr5N4kP8lh19nbB+yetjPdKiJ/IiLHTqMuVX2Zqu6slGn+oVfVm1T1WFV9aBpyBXL8IxH5sojcJSJ3iMifi8jTZilDLW3HPXVKZb9QRHZPo+ylqKeWw66zt/wTVT0WeCbwbOB3/AHScLhe/08gIs8DvgBcDJwMPA64CvhrEXn8UsqWzAhVPaz+gBuBF5nPvw9c2m5/BfjPwF8D9wCnAo8CLgD2AP8AvBM4sj3+SOC/AvuBHwCvBxRYZcr716au3wCuBe4CvkPzY/O/gYfb+g4BbwG2uXJOBi4BDgI3AL9hynw7cCHwkbbca4DTzf5zWrnvAq4DfrHQLn8J/NEC338O+Ei7/UJgN/DbwN62TV5njl3btsdNwK3AB4GjxtyLNwN/B9wBfBJY59rrhva6LwFObr//ats+d7dt9s8WKHvcvXmduRc/AH6z/f6Y9l483JZ9qG3/5wB/A9zeXvcfAGvacwR4X9smd7TX81NRm5TqWdK+sdSdc5qdHXh02zn+k+mcNwFPB1YBq4HPAv+rvTknAl8zD8ZvAd9ty9kEXE6hswOvajvds9uH41TgsV6m9vM2V87/Bf4IWAecBuyj7bQ0nf1e4OXtA/4u4G/bfU8GbjadZBvwhAXa5GjgIeAXFtj3OmCP6ewPAr/Xts3LgR8DG9v976fplJuA9cCfA+8acy++1namTTSd77fafWfQdNRnth3mfwJfNecqcGpQ9rh788vAE9p78fPt
dTzTXOduV96zgOe2z8W2VtY3tvt+CbgS2NCW91Rgy7g2WaieJe0bSy1A7xfUPGCHaH6hf9h2oqP0kc75e+bYk4D7MG8n4LXA5e32l+cezvbzSyh39s8DZwcyLdjZ24f1IWC92f8u4E/b7bcDXzT7ngbc026fSvO2eRGwOmiTrW19T1lg30uBB8zDec/c9bXf7W07gdC8aZ9g9j0P+Psx9+Kfm8/vAT7Ybl8AvMfsOxZ4ANjWfh7X2cN7s8Dxn527PzWdEHgjcFG7fQbwvbYdjjDHhG1SU88s/w7XkcszVfWLhX03m+3H0rzB9ojI3HdHmGNOdsf/MKjz0cD3JxeVk4GDqnqXq+d08/lHZvvHwDoRWaWqN4jIG2l+EJ4uIp8H3qSqt7g6bqNRJ7fQvA0tW2jesHMcUNUHXX3HAptpNIQrTVsJjbaBiHwO+Mft97+pqh8tyH6yue5vzO1Q1UMicgA4heZHYhzhvRGRlwG/CzyJ5p4eDXy7VJiIPAl4L027H03zQ3xlK9uXReQPgD8EHiMiF9GYJ+sI2mS5MZgBKoOd5nczzZv9BFXd0P4dp6pPb/fvoenEczwmKPdmGrVxXJ2eW4BNIrLe1fMPwTmPFKz6MVV9Ac0PlwLvXuCYu2ns0VctUMSrgS9VVLWf5q3/dNNWj9JmIBRtPBPHtn8fjYsCmut+7NwHETkGOJ7K6ya4NyKyFvg0jS19kqpuAP4PTUeEhe/HB2h+CJ+oqscBbzPHo6r/Q1WfRWMCPgn4d4xpk0I9S8YQO/s8qrqHZoT6v4nIcSJyhIg8QUR+vj3kQuANIrJVRDYC5wbFfQh4s4g8qx3pP1VE5h7mW4EFR7xV9Wbg/wHvEpF1IvLTwFnA2A4jIk8WkTPah/temgev5M47F9guIm8QkfUislFE3kmjdr5jXF2q+jDwx8D7ROTEtv5TROSXxp1b4GPA60TktFb+/wJcoao3tvuLbdYS3Zs1NOMA+4AH27f8S8z+W4HjReRR5rv1wJ3AIRF5CvBv5naIyLNF5GdFZDWN2n4v8FBFmyxUz5Ix6M7e8i9pHo7v0Ki7n6JRbaG5kZ+ncVF9A/hMqRBV/TOakf6P0YwAf5Zm0AYaG/x3ROR2EXnzAqe/lsaOvwW4CPhdVb2sQva1wHk0b5gf0Qwwvq0g31/RDDT9U5q34g+BZwAvUNXrK+qCZuT/BuBvReRO4Is0g4QTo6pfAv4DzRt4D41W9BpzyNuBnW2bvXqBIor3pjWJ3kDzg3Ab8Gs0g2hz+78LfBz4QVv+yTRq+a/R3Ls/pvEczHFc+91tNO12gEZrgKBNCvUsGdIOJCRJcpiTb/YkGQjZ2ZNkIGRnT5KBsKjOLiIvFZHrROQGEYlGqpMkWWI6D9C10zO/B7yYJp7668BrVfU7wTk5GpgkU0ZVZaHvF/Nmfw5wg6r+QFXvBz4BvGIR5SVJMkUW09lPYTRccXf73QgiskNEdonIrkXUlSTJIllMbPxCqsJPqOmqej5wPqQanyRLyWLe7LsZjU3eShMBliTJMmQxnf3rwBNF5HEisoYm1PGSMeckSbJEdFbjVfVBEfm3NPHJRwIfVtVrepMsSZJemWlsfNrsSTJ9puF6S5JkBZGdPUkGwopLS2XS/zBtE+SII0Z/C219i4g8XHQZteVb+qirVPa48vu+5q5yRGXM8n72Vfek5Js9SQZCdvYkGQjZ2ZNkIKwIm/3IIx/JzHv77bfPb2/atGnkuAceeKDXuu6+++6RfdaGP+aYY4r12uO83X/HHXfMbz/qUaN5CB988EFqsDafL//HP/7xgscdffTRneqy7WHLHle+tUPvuuuRLNnHHXdcUY7Ilu0qx0MPPZJ/07aVlQlg/fr1C57jWbVqtMvceeedVWVY+bvWvVjyzZ4kAyE7e5IMhBUXQdfVjWNVuFoXmlcrS3VbFQ3g0KFDC9YLcOyxj6weff/9948Te8HyreroTQGrBkZtFe0rtUnUHg8//PCC54yry6rF
tt1gVL21plLkeovaypoQXt5I/kgFt2V2McOgf9dbRtAlycDJzp4kA+GwVeMj1drus6PqUD+ib9VPO8IOsHHjxvltP7pq1cU+ora8SlhSR/0osvU0eFOj5GmYRvSbrfuee+4pHnvUUUfNb9eqy74Me9+teg/1I+J9RO9Nm1Tjk2TgZGdPkoGQnT1JBsKKsNmtvXnbbbfNbx9//PEjx1lbrussrOi81atXz29bO33Dhg0jx1m72dvz1lUW2Z6RjNb2jNxV9jx/nI9ks5TseWs3Q/34RiRv5LqqHZ/p8gx3dX/5sSDriovs/lnO1kybPUkGTnb2JBkIy3IijFeVSqq7dzNZ9dOXEU1cKWHVdihHxnm1zKqjVrXzRG6oaPJIpAbaMiIXkj3OR/LZySTWHeYnoNjjvApuZbRy+Paw1+mjAe19svJGpoC/tyX1uasq7dvUPle2TO/qtDL7NuhjAlcN+WZPkoGQnT1JBkJ29iQZCCvC9RaFKNaeU2uvRckr7L7a2WtRcgk/rmBtuYMHD85ve5vd2nhRMgV7nm8PK4dvD2uLl2bRwei1+fDT2tlgVn4fLmvlsm4/L4e1+70b1I73WJm87d13wkl/b++9995iXfba+kheka63JBk4Yzu7iHxYRPaKyNXmu00icpmIXN/+3xiVkSTJ0jNWjReRnwMOAR9R1Z9qv3sPcFBVzxORc4GNqnrO2Mp6mPVWykcHo245786oTchgj6tNXuFV9Si5hMW7sko57ryZ0CU60KuV1kTxMtr6JknuUStj3+XVJumIklBYsyxSpfuIoIuuM0qiUUtnNV5VvwocdF+/AtjZbu8EzlyUdEmSTJ2uQTUnqeoeAFXdIyInlg4UkR3Ajo71JEnSE1Wj8SKyDbjUqPG3q+oGs/82VR1rt/e9iqtXh6yK5Seg2LTTVsXyx9nR3K6RTVF0ncXLb0emrYkSpV+uJWqr2pH0SIX1Mto2jkwjyyR58kpEZlmXfHfjyq8liq4rmQJd7nNbV6+j8beKyBaA9v/ejuUkSTIjunb2S4Dt7fZ24OJ+xEmSZFrUuN4+DvwN8GQR2S0iZwHnAS8WkeuBF7efkyRZxqyICLouRLOOrE1tXS4QJ1isdWtZG9gvUWUj46J9fYwd2DbwdrktP4oms66gSfK1WzvUJuD0Li9bl7f7bfRb16WVbN12HMR+D6Nt758JSzTjzrZVJIe9Lhh9Duw+L0dtdF1G0CXJwMnOniQDYVkmr6glcrN4t4Wd3FHr0vHJK0orsPq6IjWwlCPO7+tKSXX3JoNVR73JY9VFey1e/bQqbK056NuqtgwrY6Qi15oa3qypnYwSJd+wbeXXI7DX5vdZ7L4+ouks+WZPkoGQnT1JBkJ29iQZCCvCZi/NXIpCXb29Y20ra78eOHBg5Dg7c87PqrN1R2u9WXveyxHNiCvleY9mpdWGsHpbed26dfPbUfimlT+yQz2lRIx27GQctgx/nSU5ouSfa9asmd+eJOzVHhuNBXWdmVdyHfrxgcUmtsg3e5IMhOzsSTIQlqUa71U2q06feOIjs2l90oha1cyqUZGbzKuctr4oGUGUuz1yNdlybBlRvvZa1c63VSnCDcoun6h9o/K7Rr9Zc8s+A1F+eX/PSqp1FPXo28PWXRvVVpvkwpdRMn/6IN/sSTIQsrMnyUBYERNhrIpot/0ouFW/fN620ui5Tz1sVcIob5uVw6uVVu3zUXiWaCmhKF20lb82aYSnS942n1rbHufV5y4jx5E5ZNNM+/tiR8j9SrNWjshkiFa17WPpptqozT5We82JMEkycLKzJ8lAyM6eJANhWdrskevNumOiZYLHyLHgti+jiyvPEy3PZKP1APbv3z+/HUW/9WHXdaFrXvda95qfVRe57LrIFd3b2tmPfrZcbb75WUbQpc2eJAMnO3uSDIRlqcZ7rJpTyo/WtTw/2WXz5s3FfaXlmSYxJ/owDaZNScY+ZIry13sVv5RD3UfrRdGGJfdg5DKbJD9+
Kbd7ZL5F6rm9lq7Pd6rxSTJwsrMnyUDIzp4kA2FZ2uzeZrLhinZflPPdU7tks8XbhqXzopzsPmy3tvwoIUMflEJiYTQ01eJDUa29Ookrq0YmX4a1gUvyQZwcw96nSXKy19rRteG43s4vJdPMvPFJknSiZvmnR4vI5SJyrYhcIyJnt99vEpHLROT69v/YVVyTJFk6xqrx7SqtW1T1GyKyHrgSOBP4V8BBVT1PRM4FNqrqOWPK6qSPWvWutKwxjEak7du3r7gvWuLJqmKRGydKMmDVRe+qscf68m2kllXZvGraZTmoyJ3kc8rbfVaFjWa2eVOmNvrNEqnxVv577723eJ7NrQdlU2OSqMda1bqPCMNohl3tEs6d1XhV3aOq32i37wKuBU4BXgHsbA/bSfMDkCTJMmWitFQisg14BnAFcJKq7oHmB0FETiycswPYsTgxkyRZLNWdXUSOBT4NvFFV75wgEux84Py2jKULC0uSgVPV2UVkNU1H/6iqfqb9+lYR2dK+1bcAe6clZGkNtyhZpM8eY+0wayv746wN5m1NmwnHyuHtsdqkgT4Pu7UN7RiDd+PY666146KsOFHmFyuHz1QT5YMv2elRqGut3b927doFyx5H7ay3ScKdS/a2H6ux4yy+jNL4zMwTTkoj2QXAtar6XrPrEmB7u70duLhXyZIk6ZWaN/vzgX8BfFtEvtV+9zbgPOBCETkLuAl41XRETJKkD5ZlBJ2ntNRulCzSu0+simWXAfLqZm1O9toZSV5tjdTFUlSblR1GVcQoQq9rkouSHD5yzV6bd3mVEnd6d6MtI4qItDJFSzZ7d2wpT79v09IMO0/XZaXts+plrK27loygS5KBk509SQbCslz+KaK0RBKMqos+n1k0EcFi1TRffsmE8BFtUa51u8+PxluVPBo5tmX6ySn2vGjF2No2sKaMr8uqwpEr1pbh1f2Sl8SXaeuOvDA+6syq0/aeeQ/EfffdN7/t28bWHZl29rxoGSo/Um/b0T4Ti1211ZNv9iQZCNnZk2QgZGdPkoGwIlxvltKabTBqk00wQ6i4r+sspqiMSP7SuIIvI5LLugS7ut6sjNEMPkttssjaeqGcpMLb7HZ8xtvstj2icZBoLYG+bWd/nXYMxl5bl9mNkK63JBk82dmTZCCsCDW+pI5611jtEjtWffYuKav2+YkZ0ZLQXeiat62kZkPZJRhF6/l2tK4g2z6RWllrWkxC7czKiNo2rc1f71X60jNXm4gDRpNx2OOi5acjUo1PkoGTnT1JBkJ29iQZCMvSZo+WbLbJFLwtaG0tO7OtrXt+29o+UchqtF6XtWW7Lh1di28PK8cJJ5xQ3GeJ3FV+3KJUnh2z8GX4GYh2X9/LT/eRAGOSpboje9uOb9j2OXjwYJUcC9W3UL2TkDZ7kgyc7OxJMhCWpRq/wHnz2yUXGoyq+F5VKs1Y8y60qD2suhgt+xzNiKtVW6OloaLj7MwxK2O0dJM3V2y7lpI/QLyUsW1/a25FanCUt602AYY3y0pl9JHnHkZz/UfLMveRlKKWVOOTZOBkZ0+SgbAi1PgSkfrpR4ej1M+1lNRAP0ptVccoAYYfBS8tDeWvM1oayu6rzYnm95WSRniTxF5bdC+iSTFWDa5NmR1NLooiCksegoXkKmHl9TLX5gb0dJ2wVCLV+CQZONnZk2QgZGdPkoGw4hJOWjvRu0i8u6NEVxvJ1mftyagML2OULNHantHSSta+9FGE9rxS4gZPNCMu+t7eC5+H3e4rbfu6u9qrdiwhSvRYuyyXJ0ogam14u88n1rTPQeR+7DuHvCXf7EkyEGrWelsnIl8TkatE5BoReUf7/eNE5AoRuV5EPikia8aVlSTJ0jHW9dYu7HiMqh5qV3P9K+Bs4E3AZ1T1EyLyQeAqVf3AmLI66Wkl94xXCa3rY9++fSP7Nm/evOC+TZs2jRwXrbbZZaKGPyeaVGHVQLvPmycHDhyY367NBx8lgoj2RZNMomWXbLta
N5y/5mhSkr2WKPotmmhTuhddTYZoWTF7nybJH1dyHfr7PvXkFdowd0Wr2z8FzgA+1X6/EzizSpIkSZaEKptdRI5sV3DdC1wGfB+4XVXnfkZ3A6cUzt0hIrtEZFcfAidJ0o2qzq6qD6nqacBW4DnAUxc6rHDu+ap6uqqe3l3MJEkWy0SuN1W9XUS+AjwX2CAiq9q3+1bglr6E8nadtdOtfePdWDZEMcpxHu2LZrNZO9TaZD6EMloDzbqGojXFIvegteUit1ntrDofplqygb1NHbkH7b61a9cuKJ//7N2IpSSQUairt6n9DLk5uiZzjO5Z7RoE/rhSG/cdyl4zGr9ZRDa020cBLwKuBS4HXtketh24uFfJkiTplZo3+xZgp4gcSfPjcKGqXioi3wE+ISLvBL4JXDBFOZMkWSQrYtZbyTXh86/t379/ftu7pGpdIVHSiFLO+mhZZm9qWDlq86lHOehswg4YdcvZfX4J6yhZQx/PREltjUw0r57b80ouLii76Px5tvyuSyvVEuUv7OpSqyVnvSXJwMnOniQDYUWo8a6M+W2vstlILZ9QopRMIIqI8iq9HSmNIu0iGe0Iv0+wYdW5KNqrNhqudsS9q0obLSFVUs+7ppLuOnlp2uVboiWkojbom1Tjk2TgZGdPkoGQnT1JBsKKttknsWXtsdZO97nnra0fLesUtVs0Q8uW711vpfGC2tlgMGor1i451NUFGOV8r7VLa6P8+pixFiXRqHWN1c6EnGS56b77YNrsSTJwsrMnyUBYETnoSmqxd11FLqQuZXRVr0q56sZRmgjjVU4bseddjKXVSCfJX19SwaOcf5GrycoUJZ7w7X3PPffMb5dy2Xu5fBmlZaO8mm1NJX+d1uzzufZKCSsiOaIypumWyzd7kgyE7OxJMhCysyfJQFgRNrslCkXtsp5b7dLIXYlst2j9NWv/eTsuSnDg3XQlavPeR1i5fPKK0kzF2vEBGL0Wa+dGM8oiF2BpTMTjn4nombPjJ7b8aCahHyey1xaNHSyWfLMnyUDIzp4kA2HFRdBFbpAo6syqxbVlRIkn+phB5dVRqxLaMqLlnyKiSD6rTnediVaqC0ZNlKguS627Kopwq21v69aD2JyLctuX6o7ubbQEtzUFurrhMoIuSQZOdvYkGQgrTo135RU/165aGq1M6lU7m3gimtBiJ9dMEuVnZelyLVF50Xm1Knit7F3rikbS+04u4c03W3c0kl47Acp7Wmwacp+ivOuyUSVSjU+SgZOdPUkGQnb2JBkIK8Jm77Kkkbe3bd700jJOEC//ZPOwWzvdu7VKkXCTyB99H40J2PKj46Jow9IS2d4Vac+LZr1FrrHSOAXUJ9UolbdQfXP4No3GcbokzojK6LoUeC1psyfJwKnu7O2yzd8UkUvbz48TkStE5HoR+aSIrJmemEmSLJZJJsKcTbOg45x+9m7gfar6CRH5IHAW8IE+hIpyrVsVPMrh5suwKmgUmWRVTu8KilTOEpHa6lVJm1zBRll5V1A0ESbKnW+xEV3+Om27RupnFFlWShpRu8QT1LukavO1W6JJSN4sq02cYenqLp0mVW92EdkK/DLwofazAGcAn2oP2QmcOQ0BkyTph1o1/v3AW4C5n/zjgdvbtdkBdgOnLHSiiOwQkV0ismtRkiZJsihq1mf/FWCvql5pv17g0AV1E1U9X1VPV9XTO8qYJEkP1Njszwd+VUReDqyjsdnfD2wQkVXt230rcEtfQnk7qGR3edvHzg7rIylFFC5rxw6ihBFRogXvDlu3bt38duRujGw+OzZhbVmfH792Jppt0yhxg79HpbGPyE0ZjWFEs8GisYMoVLdUhr0Pnj5cb0vF2B6hqm9V1a2qug14DfBlVf114HLgle1h24GLpyZlkiSLZjGvv3OAN4nIDTQ2/AX9iJQkyTRYERF0XfDqs3W1WHUuSlDhKanWUdIFr7aecMIJ89sHDhwY2TfNZX2nMUOwdulo61LzOehqZ3n1
/ZxG1zJJ5F7tEtnRUlx9kxF0STJwsrMnyUA4rNT4SH22k1juv//++e1owszmzZtH9u3bt29+O5pMU5IJuk2qmOQeldRzL4dtH9s2MNoGdrTcfg+jEXp+X8mDEi3/VGu69DHJZJJ01Fbt9s9LKfmGNwW6TOrpSqrxSTJwsrMnyUDIzp4kA+GwstmtvWYTMEA5kcN9991XLCMqv9b1ZqPMIM5tX1raOFoyKUrgGC0FHCX6KNmhvoxat1x0XK0tW5vM0Y8d2DaIkm34NrbYe+Fz1lv5S0te+fO6uHcnIW32JBk42dmTZCCsuFVcI6w66id+rF27dn7bRnT5CDqr1kfuKut2sq48GHXV+KWbokgtW1+0cqi9Nu82K6nM0bX4CTml5BiTJOKwanKti642qUiUKMNj5bL3qeuEmah8a7LVrvbqqXUBdiHf7EkyELKzJ8lAyM6eJAPhsHK9ubqKn62LxOeGj9Y2s+fVrufmwzLteZG7J5opZm23KHGiH0soEY0J2Lr9sxKt01bKFT/tpA5r1owmObY2cG3SUU/kDiu5WSN3aWR7p+stSZJFk509SQbCilPj+1BzSuo4xC61khwe647xkVS2/MhMsPnXolx7Hqta28iy/fv3jxwX5aCz8lv3YLS8dW2evD5m8HlKOeq9jNbdNsnSyLXLYtuoTZ8D35bh3X7R8lhdSDU+SQZOdvYkGQgrQo236q2NwLL53OAnJ7VYSquz+gg0q97Vqm+1x/ljoyirKE+eVbujKLxoMo29zii6LhqNtyqzj6CzEz9se/uJJNGknpJXIDI7IiLzqnYyjW/vkkkYeSei6+wjV12q8UkycLKzJ8lAyM6eJANhWdrskQ1pk0DaBJAwGj1WmyAg2ufbxtp8kd1vbcpJlm4qyeXbw7p1vL1q3XJRYohoX2lpai97lKzBuhjtPn+N1kY9ePDgyL5SNKNvDzt24Mu37WHviz8uuhZrz/t9tv2j2Y5R4oy+1wtImz1JBk7VfHYRuRG4C3gIeFBVTxeRTcAngW3AjcCrVfW2UhlJkiwtVWp829lPV9X95rv3AAdV9TwRORfYqKrnjCln0Wq8VaminGhdcpvBqIrs99nIp8iVYtW0SM2OVLYoOq2Wrqu4ltx3tu1h1EUVufZKk5AgnkwTuRgtkVlWIrpnpZz3/jgo5/nr6o7tg2mo8a8AdrbbO4EzF1FWkiRTprazK/AFEblSRHa0352kqnsA2v8nLnSiiOwQkV0ismvx4iZJ0pXaHHTPV9VbRORE4DIR+W5tBap6PnA+zHY+e5Iko1R1dlW9pf2/V0QuAp4D3CoiW1R1j4hsAfb2JZQPE7ThotbW8jOLrJ0bLcVsbSZfl7WpI9vK1uVdRpErJbI9S4kQvBx27CCyt6MEG/a6o1zupbzrXg5PyWXnr7/UpjB6f7u6p0rjD9Y16OX17RHNSpv22n19MlaNF5FjRGT93DbwEuBq4BJge3vYduDiaQmZJMniqXmznwRc1P7arQI+pqp/ISJfBy4UkbOAm4BXTU/MJEkWy7KMoJugvOLn2ug3706aJKnBQmV7ovbtEk3X9bwo+i3KY2dNlEmWoSrlpY9U8Cgvvd0Xue+8uVJyD0aJQ6Jlpb3Z13f/yRx0SZIsmuzsSTIQsrMnyUA4rNZ6i2wca2tF4ZBd8HaytSH9jLgDBw7Mb3v70ia4rLXXarPMRKHFvi7rHrMuzNr29edFS1hHdnQfSSutHJHb07aPdyn2YUdHlNyDudZbkiSdyM6eJANhRbjerGvLbvvEgFFiC6tO1yaVrJUpWhLIu3Fskkyfy72PJAZdVE7vOqydMRiVUTITvNssUltLbeDriswES+2Mw8i1FyXMtEziFi4tj5WutyRJOpGdPUkGwrIcjY/yqVsVzqtUdjTbq4Ql1XqSvF8b7CEAAAaRSURBVPGlZYaiiSpejmhySq1K2DUqzxKZGtEofqkMrz6Xll2KRu1rJwn5JZ5s+/iR9NqEEpZJIhRLarc3
NWqfly4RnLXkmz1JBkJ29iQZCNnZk2QgLEvXm7fZ7Sw1617za7vVXou1mXxd1p3nZ8RZm9K6Y3wihGipZ0vtMsTRrLTaZaWjurx92XfEWJRwxNblc61be7vrzMLa42qX4I7WaYtmx1lyrbckSaZKdvYkGQjLUo33lFS4LpFe46hNGhGpupFby6r8tW4WH9Fl1T5vQtg2iZYQtiqyL9+q2qtXr66S0aucpWWdolx4HnvfvbvNYl1X/l6UIvSi47z5Zvf5tirl6/N5+vt2D0akGp8kAyc7e5IMhOzsSTIQVoTNbu0k6xrzoa59uKH6JrK3o8QWdp/PS29nzvmZf7Y+a096t1bJjQijNmvtuEJtCG/XxJfW/RrZ/d5WLi3jHYW9+rGgKKd86bxoSWh7LdCPu82SNnuSDJzs7EkyEJalGh9F0FkV1qvt9lqiMrqoqX0RmRO1edIj11vJjRPdZ19+aekpP8vQmhfeXVXKNx8tV+XVeHuelb+PvO61ufu8zLXnRa69qPw+SDU+SQZOVWcXkQ0i8ikR+a6IXCsizxORTSJymYhc3/7fOL6kJEmWiio1XkR2An+pqh8SkTXA0cDbgIOqep6InAtsVNVzxpTTyWaoXdJo2mUsF6btWahtqy4myST1ls6bJJnHECmp8WM7u4gcB1wFPF7NwSJyHfBCs2TzV1T1yWPKys7eA9nZs7NHLMZmfzywD/gTEfmmiHyoXbr5JFXd0xa+BzhxoZNFZIeI7BKRXR1lT5KkB2o6+yrgmcAHVPUZwN3AubUVqOr5qnq6qp7eUcYkSXqgprPvBnar6hXt50/RdP5bW/Wd9v/e6YjYqGkL/c26jOXCtOWvbauu+2rrXexxyShjO7uq/gi4WUTm7PFfBL4DXAJsb7/bDlw8FQmTJOmF2tH404APAWuAHwCvo/mhuBB4DHAT8CpVPVgshO4DdEmS1NN5NL5PsrMnyfTJCLokGTjZ2ZNkIGRnT5KBkJ09SQZCdvYkGQjZ2ZNkIMx6yeb9wA+BE9rtpWQ5yAAphyflGGVSOR5b2jFTP/t8pSK7ljpWfjnIkHKkHLOUI9X4JBkI2dmTZCAsVWc/f4nqtSwHGSDl8KQco/Qmx5LY7EmSzJ5U45NkIGRnT5KBMNPOLiIvFZHrROSGNiPtrOr9sIjsFZGrzXczT4UtIo8WkcvbdNzXiMjZSyGLiKwTka+JyFWtHO9ov3+ciFzRyvHJNpPw1BGRI9v8hpculRwicqOIfFtEvjWXL3GJnpGppW2fWWcXkSOBPwReBjwNeK2IPG1G1f8p8FL33bnAl1T1icCXmCCv3iJ4EPhtVX0q8Fzg9W0bzFqW+4AzVPVngNOAl4rIc4F3A+9r5bgNOGvKcsxxNnCt+bxUcvyCqp5m/NpL8Yz8d+AvVPUpwM/QtEs/cpTyjfX9BzwP+Lz5/FbgrTOsfxtwtfl8HbCl3d4CXDcrWYwMFwMvXkpZaNYA+AbwszSRWqsWul9TrH9r+wCfAVwKyBLJcSNwgvtupvcFOA74e9qB877lmKUafwpws/m8u/1uqahKhT0tRGQb8AzgiqWQpVWdv0WTKPQy4PvA7ao6t/DYrO7P+4G3AHPrHR+/RHIo8AURuVJEdrTfzfq+LCpt+zhm2dkXSpUzSL+fiBwLfBp4o6reOe74aaCqD6nqaTRv1ucAT13osGnKICK/AuxV1Svt17OWo+X5qvpMGjPz9SLyczOo07OotO3jmGVn3w082nzeCtwyw/o9M0uFbRGR1TQd/aOq+pmllAVAVW8HvkIzhrBBROYmR83i/jwf+FURuRH4BI0q//4lkANVvaX9vxe4iOYHcNb3Zapp22fZ2b8OPLEdaV0DvIYmHfVSMfNU2NKsW3QBcK2qvnepZBGRzSKyod0+CngRzUDQ5cArZyWHqr5VVbeq6jaa5+HLqvrrs5ZDRI4RkfVz28BLgKuZ8X3Raadtn/bAhxtoeDnwPRr78N/PsN6PA3uAB2h+Pc+isQ2/
BFzf/t80AzleQKOS/h3wrfbv5bOWBfhp4JutHFcD/7H9/vHA14AbgD8D1s7wHr0QuHQp5Gjru6r9u2bu2VyiZ+Q0YFd7bz4LbOxLjgyXTZKBkBF0STIQsrMnyUDIzp4kAyE7e5IMhOzsSTIQsrMnyUDIzp4kA+H/AwsfZhXIqGTLAAAAAElFTkSuQmCC\n", 450 | "text/plain": [ 451 | "
" 452 | ] 453 | }, 454 | "metadata": { 455 | "needs_background": "light" 456 | }, 457 | "output_type": "display_data" 458 | } 459 | ], 460 | "source": [ 461 | "test(net, test_dataloader, optimizer, criterion, device)" 462 | ] 463 | }, 464 | { 465 | "cell_type": "code", 466 | "execution_count": null, 467 | "metadata": {}, 468 | "outputs": [], 469 | "source": [] 470 | } 471 | ], 472 | "metadata": { 473 | "kernelspec": { 474 | "display_name": "Python 3", 475 | "language": "python", 476 | "name": "python3" 477 | }, 478 | "language_info": { 479 | "codemirror_mode": { 480 | "name": "ipython", 481 | "version": 3 482 | }, 483 | "file_extension": ".py", 484 | "mimetype": "text/x-python", 485 | "name": "python", 486 | "nbconvert_exporter": "python", 487 | "pygments_lexer": "ipython3", 488 | "version": "3.7.5" 489 | } 490 | }, 491 | "nbformat": 4, 492 | "nbformat_minor": 2 493 | } 494 | --------------------------------------------------------------------------------