├── models.py └── MNIST_example.ipynb /models.py: -------------------------------------------------------------------------------- 1 | """ Pytorch models used in topograd project """ 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | class Segmenter_Unet(nn.Module): 8 | def __init__(self, img_dim, num_filters=16): 9 | super(Segmenter_Unet, self).__init__() 10 | self.img_dim = img_dim 11 | self.num_filters = num_filters 12 | 13 | # have series of convs down, using maxpools 14 | # and then convs back up again 15 | # just upsample on the way back up 16 | self.conv1_1 = nn.Conv2d(1, self.num_filters, 3, stride=1, padding=1) 17 | self.conv1_2 = nn.Conv2d(self.num_filters, self.num_filters, 3, stride=1, padding=1) 18 | self.conv1_3 = nn.Conv2d(self.num_filters, self.num_filters*2, 3, stride=2, padding=1) 19 | # downsampled 20 | self.conv2_1 = nn.Conv2d(self.num_filters*2, self.num_filters*2, 3, stride=1, padding=1) 21 | self.conv2_2 = nn.Conv2d(self.num_filters*2, self.num_filters*2, 3, stride=1, padding=1) 22 | self.conv2_3 = nn.Conv2d(self.num_filters*2, self.num_filters*4, 3, stride=2, padding=1) 23 | # downsampled 24 | self.conv3_1 = nn.Conv2d(self.num_filters*4, self.num_filters*4, 3, stride=1, padding=1) 25 | self.conv3_2 = nn.Conv2d(self.num_filters*4, self.num_filters*4, 3, stride=1, padding=1) 26 | self.conv3_3 = nn.Conv2d(self.num_filters*4, self.num_filters*4, 3, stride=1, padding=1) 27 | 28 | # upsampled - will have a num_filters * 2 concatted in 29 | self.conv4_1 = nn.Conv2d(self.num_filters*6, self.num_filters*2, 3, stride=1, padding=1) 30 | self.conv4_2 = nn.Conv2d(self.num_filters*2, self.num_filters*2, 3, stride=1, padding=1) 31 | self.conv4_3 = nn.Conv2d(self.num_filters*2, self.num_filters*2, 3, stride=1, padding=1) 32 | 33 | # upsampled - will have num_filters * 1 concatted in 34 | self.conv5_1 = nn.Conv2d(self.num_filters*3, self.num_filters*1, 3, stride=1, padding=1) 35 | self.conv5_2 = nn.Conv2d(self.num_filters*1, self.num_filters*1, 3, stride=1, padding=1) 36 | self.conv5_3 = nn.Conv2d(self.num_filters*1, self.num_filters*1, 3, stride=1, padding=1) 37 | 38 | # finish with 1x1 conv 39 | self.conv_final = nn.Conv2d(self.num_filters*1, 1, 1, stride=1) 40 | 41 | def forward(self, x): 42 | x = F.relu(self.conv1_1(x)) 43 | x_1 = F.relu(self.conv1_2(x)) 44 | x = F.relu(self.conv1_3(x_1)) 45 | 46 | x = F.relu(self.conv2_1(x)) 47 | x_2 = F.relu(self.conv2_2(x)) 48 | x = F.relu(self.conv2_3(x_2)) 49 | 50 | x = F.relu(self.conv3_1(x)) 51 | x = F.relu(self.conv3_2(x)) 52 | x = F.relu(self.conv3_3(x)) 53 | 54 | x = torch.cat([F.interpolate(x, scale_factor=2), x_2], dim=1) 55 | x = F.relu(self.conv4_1(x)) 56 | x = F.relu(self.conv4_2(x)) 57 | x = F.relu(self.conv4_3(x)) 58 | 59 | x = torch.cat([F.interpolate(x, scale_factor=2), x_1], dim=1) 60 | x = F.relu(self.conv5_1(x)) 61 | x = F.relu(self.conv5_2(x)) 62 | x = F.relu(self.conv5_3(x)) 63 | 64 | x = torch.sigmoid(self.conv_final(x)) 65 | return x 66 | 67 | class MNIST_classifier(nn.Module): 68 | def __init__(self, img_dim, num_filters=16, num_classes=10): 69 | super(MNIST_classifier, self).__init__() 70 | self.img_dim = img_dim 71 | self.num_filters = num_filters 72 | self.num_classes = num_classes 73 | 74 | # have series of convs down, using maxpools 75 | # and then convs back up again 76 | # just upsample on the way back up 77 | self.conv1_1 = nn.Conv2d(1, self.num_filters, 3, stride=1, padding=1) 78 | self.conv1_2 = nn.Conv2d(self.num_filters, self.num_filters, 3, stride=1, padding=1) 79 | self.conv1_3 = nn.Conv2d(self.num_filters, self.num_filters*2, 3, stride=2, padding=1) 80 | # downsampled 81 | self.conv2_1 = nn.Conv2d(self.num_filters*2, self.num_filters*2, 3, stride=1, padding=1) 82 | self.conv2_2 = nn.Conv2d(self.num_filters*2, self.num_filters*2, 3, stride=1, padding=1) 83 | self.conv2_3 = nn.Conv2d(self.num_filters*2, self.num_filters*4, 3, stride=2, padding=1) 84 | # downsampled 85 | self.conv3_1 = nn.Conv2d(self.num_filters*4, self.num_filters*4, 3, stride=1, padding=1) 86 | self.conv3_2 = nn.Conv2d(self.num_filters*4, self.num_filters*4, 3, stride=1, padding=1) 87 | self.conv3_3 = nn.Conv2d(self.num_filters*4, self.num_filters*4, 3, stride=1, padding=1) 88 | 89 | self.low_res_img_dim = self.img_dim // 4 90 | self.final_conv_num_filters = self.num_filters*4 91 | self.fc_1 = nn.Linear(self.low_res_img_dim**2 * self.final_conv_num_filters, self.final_conv_num_filters) 92 | self.fc_2 = nn.Linear(self.final_conv_num_filters, self.final_conv_num_filters) 93 | self.fc_3 = nn.Linear(self.final_conv_num_filters, self.final_conv_num_filters) 94 | 95 | self.fc_final = nn.Linear(self.final_conv_num_filters, self.num_classes) 96 | 97 | def forward(self, x): 98 | x = F.relu(self.conv1_1(x)) 99 | x = F.relu(self.conv1_2(x)) 100 | x = F.relu(self.conv1_3(x)) 101 | x = F.relu(self.conv2_1(x)) 102 | x = F.relu(self.conv2_2(x)) 103 | x = F.relu(self.conv2_3(x)) 104 | x = F.relu(self.conv3_1(x)) 105 | x = F.relu(self.conv3_2(x)) 106 | x = F.relu(self.conv3_3(x)) 107 | 108 | x = x.view(-1, self.low_res_img_dim**2 * self.final_conv_num_filters) 109 | x = F.relu(self.fc_1(x)) 110 | x = F.relu(self.fc_2(x)) 111 | x = F.relu(self.fc_3(x)) 112 | x = self.fc_final(x) 113 | return x 114 | -------------------------------------------------------------------------------- /MNIST_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "\"\"\" Experiment on MNIST data - \n", 10 | "The task is to reconstruct ground truth pixels using pixels+noise\n", 11 | "If prior knowledge of the required topology is provided then we will observe the same data\n", 12 | "being predicted as a 0, 9, 8 etc. as different priors are enforced. \n", 13 | "\"\"\"\n", 14 | "\n", 15 | "import copy\n", 16 | "import gudhi as gd\n", 17 | "import matplotlib.pyplot as plt\n", 18 | "import numpy as np\n", 19 | "import numpy.fft as ft\n", 20 | "import os\n", 21 | "\n", 22 | "%matplotlib inline\n", 23 | "\n", 24 | "import torch\n", 25 | "import torchvision.datasets as datasets\n", 26 | "import torchsummary\n", 27 | "from torch import nn, optim\n", 28 | "from torch.nn import functional as F\n", 29 | "from models import Segmenter_Unet, MNIST_classifier\n", 30 | "\n", 31 | "from topologylayer.nn import LevelSetLayer2D, TopKBarcodeLengths\n", 32 | "device = torch.device(\"cuda\")" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "mnist_trainset = datasets.MNIST(root='./MNIST_data', train=True, download=True, transform=None)\n", 42 | "mnist_testset = datasets.MNIST(root='./MNIST_data', train=False, download=True, transform=None)\n", 43 | "img_dim = 28\n", 44 | "print(\"Size of training set is {}\".format(len(mnist_trainset)))\n", 45 | "print(\"Size of test set is {}\".format(len(mnist_testset)))" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": null, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "\"\"\" Split training set \"\"\"\n", 55 | "\n", 56 | "X_train = np.array([np.array(x[0]) for x in mnist_trainset])\n", 57 | "Y_train = np.array([x[1] for x in mnist_trainset])\n", 58 | "X_test = np.array([np.array(x[0]) for x in mnist_testset])\n", 59 | "Y_test = np.array([x[1] for x in mnist_testset])\n", 60 | "\n", 61 | "# use some of the training set to train denoising network\n", 62 | "# use some of the training set to train digit classifier (that can measure how well digits are denoised)\n", 63 | "N_denoise = 10000\n", 64 | "N_classifier = 50000\n", 65 | "\n", 66 | "X_denoise = X_train[:N_denoise]\n", 67 | "Y_denoise = Y_train[:N_denoise]\n", 68 | "X_classifier = X_train[N_denoise:N_denoise+N_classifier]\n", 69 | "Y_classifier = Y_train[N_denoise:N_denoise+N_classifier]\n", 70 | "\n", 71 | "def norm(X):\n", 72 | " return X.astype(np.float) / np.max(X, axis=(1,2)).reshape((-1, 1, 1))\n", 73 | "\n", 74 | "X_denoise = norm(X_denoise)\n", 75 | "X_classifier = norm(X_classifier)\n", 76 | "\n", 77 | "print(X_denoise.shape)\n", 78 | "print(Y_denoise.shape)\n", 79 | "print(X_classifier.shape)\n", 80 | "print(Y_classifier.shape)" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": null, 86 | "metadata": { 87 | "scrolled": true 88 | }, 89 | "outputs": [], 90 | "source": [ 91 | "classifier_net = MNIST_classifier(img_dim).to(device)\n", 92 | "classifier_optimizer = optim.Adam(classifier_net.parameters(), lr=1e-4)\n", 93 | "\n", 94 | "X_classifier_torch = torch.tensor(X_classifier.reshape(-1, 1, img_dim, img_dim)).float().to(device)\n", 95 | "Y_classifier_torch = torch.tensor(Y_classifier).to(device)\n", 96 | "\n", 97 | "N_classifier_val = 10000\n", 98 | "N_classifier_train = N_classifier - N_classifier_val\n", 99 | "\n", 100 | "def train_classifier(model, optimizer, X, Y, X_v, Y_v, batch_size=50, num_epochs=1, verbose=False):\n", 101 | " \"\"\" Train the classification model\n", 102 | "\n", 103 | " \"\"\"\n", 104 | " model.train()\n", 105 | " N = X.shape[0]\n", 106 | " if Y.shape[0] != N:\n", 107 | " raise ValueError('ERROR: Number of labels ({}) != Number of images ({})!'.format(Y.shape[0], N))\n", 108 | "\n", 109 | " num_batches = N // batch_size\n", 110 | " for e in range(num_epochs):\n", 111 | " train_loss = 0.\n", 112 | " batch_indices = np.arange(N, dtype=np.int)\n", 113 | " np.random.shuffle(batch_indices)\n", 114 | "\n", 115 | " for b in range(num_batches):\n", 116 | " this_batch_indices = batch_indices[b*batch_size:(b+1)*batch_size]\n", 117 | " X_batch = X[this_batch_indices]\n", 118 | " Y_batch = Y[this_batch_indices]\n", 119 | "\n", 120 | " optimizer.zero_grad()\n", 121 | "\n", 122 | " predict_batch = model(X_batch)\n", 123 | " ce_loss = torch.nn.CrossEntropyLoss()(predict_batch, Y_batch)\n", 124 | " train_loss += ce_loss.item()\n", 125 | " ce_loss.backward()\n", 126 | " \n", 127 | " optimizer.step()\n", 128 | "\n", 129 | " if ((e+1) % 5) == 0:\n", 130 | " # check validation loss as well\n", 131 | " model.eval()\n", 132 | " predict_val = model(X_v)\n", 133 | " validation_loss = torch.nn.CrossEntropyLoss()(predict_val, Y_v)\n", 134 | " validation_accuracy = torch.mean((Y_v == torch.argmax(predict_val, dim=1)).float()) * 100.\n", 135 | " \n", 136 | " if verbose:\n", 137 | " print('Epoch: {0:5d} \\t Training Loss: {1:5g} \\t Val Loss: {2:5g} \\t Val Acc: {3:4g}%'.format(e+1,\n", 138 | " train_loss / num_batches,\n", 139 | " validation_loss,\n", 140 | " validation_accuracy))\n", 141 | " # set model back into training mode\n", 142 | " model.train()\n", 143 | "\n", 144 | " return model\n", 145 | "\n", 146 | "torchsummary.summary(classifier_net, (1, img_dim, img_dim))\n", 147 | "\n", 148 | "batch_size = 1000\n", 149 | "num_epochs = 250\n", 150 | "\n", 151 | "try:\n", 152 | " classifier_net = torch.load('./MNIST_classifier.pt')\n", 153 | "except:\n", 154 | " classifier_net = train_classifier(classifier_net,\n", 155 | " classifier_optimizer,\n", 156 | " X_classifier_torch[:N_classifier_train],\n", 157 | " Y_classifier_torch[:N_classifier_train],\n", 158 | " X_classifier_torch[N_classifier_train:N_classifier_train+N_classifier_val],\n", 159 | " Y_classifier_torch[N_classifier_train:N_classifier_train+N_classifier_val],\n", 160 | " batch_size,\n", 161 | " num_epochs,\n", 162 | " verbose=True)\n", 163 | " torch.cuda.empty_cache()\n", 164 | " torch.save(classifier_net, './MNIST_classifier.pt')\n" 165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": null, 170 | "metadata": {}, 171 | "outputs": [], 172 | "source": [ 173 | "\"\"\" Add noise to MNIST digits in the Fourier domain \"\"\"\n", 174 | "def add_noise(X, num_lines_removed):\n", 175 | " N = X.shape[0]\n", 176 | " K = ft.fftshift(ft.fft2(X), axes=(1,2))\n", 177 | " num_img = K.shape[0]\n", 178 | " img_dim = K.shape[-1]\n", 179 | " K_degraded = K.copy()\n", 180 | " \n", 181 | " for n in range(num_img):\n", 182 | " lines = np.arange(img_dim)\n", 183 | " \n", 184 | " np.random.shuffle(lines)\n", 185 | " for l in lines[:num_lines_removed]:\n", 186 | " K_degraded[n, l] = 0\n", 187 | " \n", 188 | " np.random.shuffle(lines)\n", 189 | " for l in lines[:num_lines_removed]:\n", 190 | " K_degraded[n, :, l] = 0\n", 191 | " \n", 192 | " X_recon = np.abs(ft.ifft2(K_degraded))\n", 193 | " # min already 0 due to np.abs\n", 194 | " X_recon = X_recon / np.max(X_recon, axis=(1,2)).reshape((N, 1, 1))\n", 195 | " return X_recon" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": null, 201 | "metadata": {}, 202 | "outputs": [], 203 | "source": [ 204 | "\"\"\" Train network to get original images back - then train one with digit-specific topological priors \"\"\"\n", 205 | "\n", 206 | "def train_model_supervised(model, optimizer, X, Y, X_v, Y_v, batch_size=50, num_epochs=1, verbose=False):\n", 207 | " \"\"\" Train the segmentation model\n", 208 | "\n", 209 | " Parameters\n", 210 | " ----------\n", 211 | " model - Pytorch model\n", 212 | " optimizer - Pytorch optimizer\n", 213 | " X - training images\n", 214 | " Y - training labels\n", 215 | " X_v - validation images\n", 216 | " Y_v - validation labels\n", 217 | " batch_size - int - batch size for training\n", 218 | " num_epochs - int - number of full epochs to train for\n", 219 | " verbose - bool - if True, print training information\n", 220 | "\n", 221 | " Returns\n", 222 | " -------\n", 223 | " model - trained Pytorch model\n", 224 | "\n", 225 | " Notes\n", 226 | " -----\n", 227 | "\n", 228 | " \"\"\"\n", 229 | " model.train()\n", 230 | " N = X.shape[0]\n", 231 | " if Y.shape[0] != N:\n", 232 | " raise ValueError('ERROR: Number of labels ({}) != Number of images ({})!'.format(Y.shape[0], N))\n", 233 | "\n", 234 | " num_batches = N // batch_size\n", 235 | " for e in range(num_epochs):\n", 236 | " train_loss = 0.\n", 237 | " batch_indices = np.arange(N, dtype=np.int)\n", 238 | " np.random.shuffle(batch_indices)\n", 239 | "\n", 240 | " for b in range(num_batches):\n", 241 | " this_batch_indices = batch_indices[b*batch_size:(b+1)*batch_size]\n", 242 | " X_batch = X[this_batch_indices]\n", 243 | " Y_batch = Y[this_batch_indices]\n", 244 | "\n", 245 | " optimizer.zero_grad()\n", 246 | "\n", 247 | " predict_batch = model(X_batch)\n", 248 | " bce_loss = F.binary_cross_entropy(predict_batch, Y_batch)\n", 249 | " train_loss += bce_loss.item()\n", 250 | " bce_loss.backward()\n", 251 | "\n", 252 | " optimizer.step()\n", 253 | "\n", 254 | " if ((e+1) % 10) == 0:\n", 255 | " # check validation loss as well\n", 256 | " model.eval()\n", 257 | " predict_val = model(X_v)\n", 258 | " validation_loss = nn.MSELoss()(predict_val, Y_v)\n", 259 | " \n", 260 | " if verbose:\n", 261 | " print('Epoch: {0:5d} \\t Training Loss: {1:5g} \\t Validation Loss: {2:5g}'.format(e+1,\n", 262 | " train_loss / num_batches,\n", 263 | " validation_loss))\n", 264 | " model.train()\n", 265 | "\n", 266 | " return model" 267 | ] 268 | }, 269 | { 270 | "cell_type": "code", 271 | "execution_count": null, 272 | "metadata": {}, 273 | "outputs": [], 274 | "source": [ 275 | "\"\"\" Create noisy images\"\"\"\n", 276 | "l = 8\n", 277 | "X_noise = add_noise(X_denoise, l) \n", 278 | "X_noise_torch = torch.tensor(X_noise.reshape(-1, 1, img_dim, img_dim)).float().to(device)\n", 279 | "X_denoise_torch = torch.tensor(X_denoise.reshape(-1, 1, img_dim, img_dim)).float().to(device)\n", 280 | "\n", 281 | "X_test_noise = add_noise(X_test, l)\n", 282 | "X_test_noise_torch = torch.tensor(X_test_noise.reshape(-1, 1, img_dim, img_dim)).float().to(device)\n", 283 | "X_test_torch = torch.tensor(X_test.reshape(-1, 1, img_dim, img_dim)).float().to(device)" 284 | ] 285 | }, 286 | { 287 | "cell_type": "code", 288 | "execution_count": null, 289 | "metadata": { 290 | "scrolled": true 291 | }, 292 | "outputs": [], 293 | "source": [ 294 | "\"\"\" Train U-net to denoise the MNIST images \"\"\"\n", 295 | "N_denoise_train = 100\n", 296 | "N_denoise_val = 100\n", 297 | "\n", 298 | "model = Segmenter_Unet(img_dim=img_dim,\n", 299 | " num_filters=16)\n", 300 | "\n", 301 | "model = model.to(device)\n", 302 | "optimizer = optim.Adam(model.parameters(), lr=1e-4)\n", 303 | "\n", 304 | "model = train_model_supervised(model, optimizer,\n", 305 | " X_noise_torch[:N_denoise_train],\n", 306 | " X_denoise_torch[:N_denoise_train],\n", 307 | " X_noise_torch[N_denoise_train:N_denoise_train+N_denoise_val],\n", 308 | " X_denoise_torch[N_denoise_train:N_denoise_train+N_denoise_val],\n", 309 | " batch_size=min(1000, N_denoise_train), num_epochs=1000, verbose=True)" 310 | ] 311 | }, 312 | { 313 | "cell_type": "code", 314 | "execution_count": null, 315 | "metadata": {}, 316 | "outputs": [], 317 | "source": [ 318 | "\"\"\" Predict on test set \"\"\"\n", 319 | "model.eval()\n", 320 | "with torch.no_grad():\n", 321 | " Z_predicted = model(X_test_noise_torch)\n", 322 | "Z_predicted_np = Z_predicted.cpu().numpy()[:,0]\n", 323 | "print(Z_predicted_np.shape)\n", 324 | "print(X_test_noise_torch.shape)" 325 | ] 326 | }, 327 | { 328 | "cell_type": "code", 329 | "execution_count": null, 330 | "metadata": {}, 331 | "outputs": [], 332 | "source": [ 333 | "\"\"\" Assess quality of reconstructed images by passing them through pre-trained MNIST classifier\n", 334 | "if the classifier can correctly tell what they are then they look good, since it is ~98% accurate on real digits \"\"\"\n", 335 | "Y_test_torch = torch.tensor(Y_test)\n", 336 | "Z_digit_prediction = classifier_net(Z_predicted)\n", 337 | "X_noise_digit_prediction = classifier_net(X_test_noise_torch)\n", 338 | "X_digit_prediction = classifier_net(X_test_torch)\n", 339 | "print(Z_digit_prediction.shape)\n", 340 | "print(X_noise_digit_prediction.shape)\n", 341 | "print(X_digit_prediction.shape)" 342 | ] 343 | }, 344 | { 345 | "cell_type": "code", 346 | "execution_count": null, 347 | "metadata": {}, 348 | "outputs": [], 349 | "source": [ 350 | "print(torch.mean((torch.argmax(X_noise_digit_prediction, dim=1).cpu() == Y_test_torch).float()))\n", 351 | "print(torch.mean((torch.argmax(X_digit_prediction, dim=1).cpu() == Y_test_torch).float()))\n", 352 | "print(torch.mean((torch.argmax(Z_digit_prediction, dim=1).cpu() == Y_test_torch).float()))" 353 | ] 354 | }, 355 | { 356 | "cell_type": "code", 357 | "execution_count": null, 358 | "metadata": {}, 359 | "outputs": [], 360 | "source": [ 361 | "\"\"\"Optimise topological loss on a single case to get some nice pictures for the paper\n", 362 | "train network with some specific set of parameters, then observe change in output reconstruction\n", 363 | "when topological priors are applied to the output and the network's weights adjusted \"\"\" \n", 364 | "\n", 365 | "H_1 = {0:1, 1:0} # 1, 2, 3, 4, 5, 7\n", 366 | "H_0 = {0:1, 1:1} # 0, 6, 9\n", 367 | "H_8 = {0:1, 1:2} # 8\n", 368 | "\n", 369 | "# correct topology for each digit\n", 370 | "H_dict = {0:H_0,\n", 371 | " 1:H_1,\n", 372 | " 2:H_1, # note this will close the loop on the 2 - some interesting cases here?\n", 373 | " 3:H_1,\n", 374 | " 4:H_1,\n", 375 | " 5:H_1,\n", 376 | " 6:H_0,\n", 377 | " 7:H_1,\n", 378 | " 8:H_8,\n", 379 | " 9:H_0}\n", 380 | "\n", 381 | "dgminfo = LevelSetLayer2D(size=(28,28), sublevel=False, maxdim=1)\n", 382 | "l2_loss = nn.MSELoss()" 383 | ] 384 | }, 385 | { 386 | "cell_type": "code", 387 | "execution_count": null, 388 | "metadata": {}, 389 | "outputs": [], 390 | "source": [ 391 | "original_network_correct = torch.argmax(Z_digit_prediction, dim=1).cpu() == Y_test_torch\n", 392 | "print(original_network_correct[:20])" 393 | ] 394 | }, 395 | { 396 | "cell_type": "code", 397 | "execution_count": null, 398 | "metadata": {}, 399 | "outputs": [], 400 | "source": [ 401 | "i = 3\n", 402 | "f = plt.figure(figsize=(15,5))\n", 403 | "(ax1, ax2, ax3) = f.subplots(1,3)\n", 404 | "ax1.imshow(X_test_noise[i])\n", 405 | "ax1.set_xticks([])\n", 406 | "ax1.set_yticks([])\n", 407 | "ax2.imshow(X_test[i])\n", 408 | "ax2.set_xticks([])\n", 409 | "ax2.set_yticks([])\n", 410 | "ax3.imshow(Z_predicted_np[i])\n", 411 | "ax3.set_xticks([])\n", 412 | "ax3.set_yticks([])\n", 413 | "print('Ground truth: {}'.format(Y_test[i]))\n", 414 | "print('Predicted as: {}'.format(torch.argmax(Z_digit_prediction[i]).item()))\n", 415 | "print('Logits:')\n", 416 | "print(Z_digit_prediction[i])" 417 | ] 418 | }, 419 | { 420 | "cell_type": "code", 421 | "execution_count": null, 422 | "metadata": {}, 423 | "outputs": [], 424 | "source": [ 425 | "model_topo = copy.deepcopy(model)\n", 426 | "optimizer = torch.optim.Adam(model_topo.parameters(), lr=1e-5)\n", 427 | "num_iter_topo = 100\n", 428 | "digit_i = Y_test[i]\n", 429 | "H_i = H_dict[digit_i]\n", 430 | "\n", 431 | "print(digit_i)\n", 432 | "print(H_i)\n", 433 | "\n", 434 | "original_model_output = model(X_test_noise_torch[i:i+1]).cpu().detach() # detach to avoid second pass error\n", 435 | "\n", 436 | "L_sqdiff_weight = 10 # hyper-parameter\n", 437 | "max_k = 20 # only consider this many bars - most will be 0-length anyway\n", 438 | "\n", 439 | "L_list = []\n", 440 | "for t in range(num_iter_topo):\n", 441 | " optimizer.zero_grad()\n", 442 | " Z_cuda = model_topo(X_test_noise_torch[i:i+1])\n", 443 | " Z_cpu = Z_cuda.cpu()\n", 444 | " a = dgminfo(Z_cpu)\n", 445 | "\n", 446 | " L0 = (TopKBarcodeLengths(dim=0, k=max_k)(a)**2).sum()\n", 447 | " dim_1_sq_bars = TopKBarcodeLengths(dim=1, k=max_k)(a)**2\n", 448 | " bar_signs = torch.ones(max_k)\n", 449 | " bar_signs[:H_i[1]] = -1\n", 450 | " L1 = (dim_1_sq_bars * bar_signs).sum()\n", 451 | "\n", 452 | " L_sqdiff = l2_loss(original_model_output, Z_cpu) * L_sqdiff_weight\n", 453 | " L = L0 + L1 + L_sqdiff\n", 454 | " L.backward()\n", 455 | " L_list.append(L.item())\n", 456 | " optimizer.step()\n", 457 | "\n", 458 | " ground_truth_mask = X_test_torch[i:i+1][0,0].cpu().detach()\n", 459 | " original_predicted_mask = original_model_output[0,0]\n", 460 | " topo_predicted_mask = Z_cpu[0,0].detach()" 461 | ] 462 | }, 463 | { 464 | "cell_type": "code", 465 | "execution_count": null, 466 | "metadata": {}, 467 | "outputs": [], 468 | "source": [ 469 | "f = plt.figure(figsize=(20,5))\n", 470 | "(ax1, ax2, ax3, ax4) = f.subplots(1,4)\n", 471 | "ax1.imshow(X_test_noise[i])\n", 472 | "ax1.set_xticks([])\n", 473 | "ax1.set_yticks([])\n", 474 | "ax2.imshow(X_test[i])\n", 475 | "ax2.set_xticks([])\n", 476 | "ax2.set_yticks([])\n", 477 | "ax3.imshow(Z_predicted_np[i])\n", 478 | "ax3.set_xticks([])\n", 479 | "ax3.set_yticks([])\n", 480 | "ax4.imshow(Z_cpu[0,0].detach().numpy(), cmap='gray')\n", 481 | "ax4.set_xticks([])\n", 482 | "ax4.set_yticks([])\n", 483 | "\n", 484 | "Z_topo_digit_prediction_i = classifier_net(Z_cuda)\n", 485 | "\n", 486 | "print(Y_test[i])\n", 487 | "print(torch.argmax(Z_topo_digit_prediction_i).item())\n", 488 | "print(Z_topo_digit_prediction_i)" 489 | ] 490 | }, 491 | { 492 | "cell_type": "code", 493 | "execution_count": null, 494 | "metadata": {}, 495 | "outputs": [], 496 | "source": [ 497 | "def diag_tidy(diag, eps=1e-1):\n", 498 | " new_diag = []\n", 499 | " for _, x in diag:\n", 500 | " if np.abs(x[0] - x[1]) > eps:\n", 501 | " new_diag.append((_, x))\n", 502 | " return new_diag\n", 503 | "\n", 504 | "plt.figure(figsize=(3,3))\n", 505 | "plt.imshow(Z_predicted_np[i], cmap='gray')\n", 506 | "plt.xticks([])\n", 507 | "plt.yticks([])\n", 508 | "plt.colorbar()\n", 509 | "plt.show()\n", 510 | "\n", 511 | "cc = gd.CubicalComplex(dimensions=(img_dim, img_dim),\n", 512 | " top_dimensional_cells=1-Z_predicted_np[i].flatten())\n", 513 | "\n", 514 | "diag = cc.persistence()\n", 515 | "plt.figure(figsize=(3,3))\n", 516 | "diag_clean = diag_tidy(diag, 1e-3)\n", 517 | "gd.plot_persistence_barcode(diag_clean)\n", 518 | "plt.ylim(-1, len(diag_clean))\n", 519 | "plt.xticks(ticks=np.linspace(0, 1, 6), labels=np.round(np.linspace(1, 0, 6), 2))\n", 520 | "plt.yticks([])\n", 521 | "plt.show()" 522 | ] 523 | }, 524 | { 525 | "cell_type": "code", 526 | "execution_count": null, 527 | "metadata": {}, 528 | "outputs": [], 529 | "source": [ 530 | "plt.figure(figsize=(3,3))\n", 531 | "plt.imshow(Z_cpu[0,0].detach().numpy(), cmap='gray')\n", 532 | "plt.xticks([])\n", 533 | "plt.yticks([])\n", 534 | "plt.colorbar()\n", 535 | "plt.show()\n", 536 | "\n", 537 | "cc = gd.CubicalComplex(dimensions=(img_dim, img_dim),\n", 538 | " top_dimensional_cells=1-Z_cpu[0,0].detach().numpy().flatten())\n", 539 | "diag = cc.persistence()\n", 540 | "\n", 541 | "plt.figure(figsize=(3,3))\n", 542 | "diag_clean = diag_tidy(diag, 1e-3)\n", 543 | "gd.plot_persistence_barcode(diag_clean)\n", 544 | "plt.ylim(-1, len(diag_clean))\n", 545 | "plt.xticks(ticks=np.linspace(0, 1, 6), labels=np.round(np.linspace(1, 0, 6), 2))\n", 546 | "plt.yticks([])" 547 | ] 548 | } 549 | ], 550 | "metadata": { 551 | "kernelspec": { 552 | "display_name": "Python 3", 553 | "language": "python", 554 | "name": "python3" 555 | }, 556 | "language_info": { 557 | "codemirror_mode": { 558 | "name": "ipython", 559 | "version": 3 560 | }, 561 | "file_extension": ".py", 562 | "mimetype": "text/x-python", 563 | "name": "python", 564 | "nbconvert_exporter": "python", 565 | "pygments_lexer": "ipython3", 566 | "version": "3.7.3" 567 | } 568 | }, 569 | "nbformat": 4, 570 | "nbformat_minor": 2 571 | } 572 | --------------------------------------------------------------------------------