├── Lecture.ipynb ├── README.md ├── main.py └── save ├── loss.pdf └── result.png /Lecture.ipynb: -------------------------------------------------------------------------------- 1 | {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"Lecture.ipynb","provenance":[],"collapsed_sections":[],"mount_file_id":"16YeIonTTR3PifruBx4ubCRGsrY0lvksL","authorship_tag":"ABX9TyNNKlol00SM+BKOnUGOyIg+"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/","height":568},"id":"YbVKAnEsyRor","executionInfo":{"status":"ok","timestamp":1608630334006,"user_tz":-540,"elapsed":30076,"user":{"displayName":"gakusyu umelab","photoUrl":"","userId":"09554957494174248921"}},"outputId":"40bcfd0f-5db1-418d-bd8f-77306a1387f3"},"source":["import os\r\n","import numpy as np\r\n","import torch\r\n","from torch import nn\r\n","from torch.autograd import Variable\r\n","from torch.utils.data import DataLoader,Dataset\r\n","from torchvision import transforms\r\n","from torchvision.datasets import MNIST\r\n","import pylab\r\n","import matplotlib.pyplot as plt\r\n","\r\n","mount_dir = './drive/MyDrive/Colab Notebooks/anomaly-detection-using-autoencoder-PyTorch-master'\r\n","\r\n","class Mnisttox(Dataset):\r\n"," def __init__(self, datasets ,labels:list):\r\n"," self.dataset = [datasets[i][0] for i in range(len(datasets))\r\n"," if datasets[i][1] in labels ]\r\n"," self.labels = labels\r\n"," self.len_oneclass = int(len(self.dataset)/10)\r\n","\r\n"," def __len__(self):\r\n"," return int(len(self.dataset))\r\n","\r\n"," def __getitem__(self, index):\r\n"," img = self.dataset[index]\r\n"," return img,[]\r\n","\r\n","class Autoencoder(nn.Module):\r\n"," def __init__(self,z_dim):\r\n"," super(Autoencoder, self).__init__()\r\n"," self.encoder = nn.Sequential(\r\n"," nn.Linear(28 * 28, 256),\r\n"," nn.ReLU(True),\r\n"," nn.Linear(256, 128),\r\n"," nn.ReLU(True),\r\n"," nn.Linear(128, z_dim))\r\n","\r\n"," self.decoder = nn.Sequential(\r\n"," nn.Linear(z_dim, 128),\r\n"," nn.ReLU(True),\r\n"," nn.Linear(128, 256),\r\n"," nn.ReLU(True),\r\n"," nn.Linear(256, 28 * 28),\r\n"," nn.Tanh()\r\n"," )\r\n","\r\n"," def forward(self, x):\r\n"," z = self.encoder(x)\r\n"," xhat = self.decoder(z)\r\n"," return xhat\r\n","\r\n","z_dim = 64\r\n","batch_size = 16\r\n","num_epochs = 10\r\n","learning_rate = 3.0e-4\r\n","n = 6 #number of test sample\r\n","cuda = True\r\n","model = Autoencoder(z_dim)\r\n","mse_loss = nn.MSELoss()\r\n","optimizer = torch.optim.Adam(model.parameters(),\r\n"," lr=learning_rate,\r\n"," weight_decay=1e-5)\r\n","\r\n","if cuda:\r\n"," model.cuda()\r\n","\r\n","img_transform = transforms.Compose([\r\n"," transforms.ToTensor(),\r\n"," transforms.Normalize((0.5, ), (0.5, )) # [0,1] => [-1,1]\r\n","])\r\n","train_dataset = MNIST('./data', download=True,train=True, transform=img_transform)\r\n","train_1 = Mnisttox(train_dataset,[1])\r\n","train_loader = DataLoader(train_1, batch_size=batch_size, shuffle=True)\r\n","losses = np.zeros(num_epochs)\r\n","\r\n","for epoch in range(num_epochs):\r\n"," i = 0\r\n"," for img,_ in train_loader:\r\n"," # print(\"now\")\r\n","\r\n"," x = img.view(img.size(0), -1)\r\n","\r\n"," if cuda:\r\n"," x = Variable(x).cuda()\r\n"," else:\r\n"," x = Variable(x)\r\n","\r\n"," xhat = model(x)\r\n","\r\n"," # 出力画像(再構成画像)と入力画像の間でlossを計算\r\n"," loss = mse_loss(xhat, x)\r\n"," losses[epoch] = losses[epoch] * (i / (i + 1.)) + loss * (1. / (i + 1.))\r\n"," optimizer.zero_grad()\r\n"," loss.backward()\r\n"," optimizer.step()\r\n"," i += 1\r\n","\r\n"," plt.figure()\r\n"," pylab.xlim(0, num_epochs)\r\n"," plt.plot(range(0, num_epochs), losses, label='loss')\r\n"," plt.legend()\r\n"," plt.savefig(os.path.join(mount_dir+\"/save/\", 'loss.pdf'))\r\n"," plt.close()\r\n","\r\n"," print('epoch [{}/{}], loss: {:.4f}'.format(\r\n"," epoch + 1,\r\n"," num_epochs,\r\n"," loss))\r\n","\r\n","test_dataset = MNIST('./data', train=False,download=True, transform=img_transform)\r\n","test_1_9 = Mnisttox(test_dataset,[1,9])\r\n","test_loader = DataLoader(test_1_9, batch_size=len(test_dataset), shuffle=True)\r\n","\r\n","for img,_ in test_loader:\r\n"," x = img.view(img.size(0), -1)\r\n","\r\n"," if cuda:\r\n"," x = Variable(x).cuda()\r\n"," else:\r\n"," x = Variable(x)\r\n","\r\n"," xhat = model(x)\r\n"," x = x.cpu().detach().numpy()\r\n"," xhat = xhat.cpu().detach().numpy()\r\n"," x = x/2 + 0.5\r\n"," xhat = xhat/2 + 0.5\r\n","\r\n","# サンプル画像表示\r\n","plt.figure(figsize=(12, 6))\r\n","for i in range(n):\r\n"," # テスト画像を表示\r\n"," ax = plt.subplot(3, n, i + 1)\r\n"," plt.imshow(x[i].reshape(28, 28))\r\n"," plt.gray()\r\n"," ax.get_xaxis().set_visible(False)\r\n"," ax.get_yaxis().set_visible(False)\r\n","\r\n"," # 出力画像を表示\r\n"," ax = plt.subplot(3, n, i + 1 + n)\r\n"," plt.imshow(xhat[i].reshape(28, 28))\r\n"," plt.gray()\r\n"," ax.get_xaxis().set_visible(False)\r\n"," ax.get_yaxis().set_visible(False)\r\n","\r\n"," # 入出力の差分画像を計算\r\n"," diff_img = np.abs(x[i] - xhat[i])\r\n","\r\n"," # 入出力の差分数値を計算\r\n"," diff = np.sum(diff_img)\r\n","\r\n"," # 差分画像と差分数値の表示\r\n"," ax = plt.subplot(3, n, i + 1 + n * 2)\r\n"," plt.imshow(diff_img.reshape(28, 28),cmap=\"jet\")\r\n"," #plt.gray()\r\n"," ax.get_xaxis().set_visible(True)\r\n"," ax.get_yaxis().set_visible(True)\r\n"," ax.set_xlabel('score = ' + str(diff))\r\n","\r\n","plt.savefig(mount_dir+\"/save/result.png\")\r\n","plt.show()\r\n","plt.close()\r\n","\r\n"],"execution_count":2,"outputs":[{"output_type":"stream","text":["epoch [1/10], loss: 0.0754\n","epoch [2/10], loss: 0.0461\n","epoch [3/10], loss: 0.0380\n","epoch [4/10], loss: 0.0418\n","epoch [5/10], loss: 0.0262\n","epoch [6/10], loss: 0.0172\n","epoch [7/10], loss: 0.0111\n","epoch [8/10], loss: 0.0133\n","epoch [9/10], loss: 0.0122\n","epoch [10/10], loss: 0.0300\n"],"name":"stdout"},{"output_type":"display_data","data":{"image/png":"\n","text/plain":["
"]},"metadata":{"tags":[],"needs_background":"light"}}]}]} -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch-Autoencoder using MNIST 2 | https://qiita.com/satolab/items/8efa513e7fd6cb41fdc5 3 | 4 | Powered by [satolab](https://qiita.com/satolab) 5 | 6 | ## Overview 7 | PyTorchを用いた,AutoencoderによるMNISTの異常検知プログラムです. 8 | 9 | 10 | 11 | ## Model 12 | 13 | 14 | ## Results 15 | - 10 epochs(input,output,difference) 16 | ![result.png](https://qiita-image-store.s3.ap-northeast-1.amazonaws.com/0/583727/c2921e0a-dee3-2964-42cd-f9cf97ebc320.png) 17 | 18 | 19 | ## Usage 20 | - main.pyで学習.save dirにサンプルが保存されます. 21 | Learn with main.py. The sample is saved in save dir. 22 | 23 | ## References 24 | 差分画像の計算と表示部分 25 | http://cedro3.com/ai/keras-autoencoder-anomaly/ 26 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | from torch import nn 5 | from torch.autograd import Variable 6 | from torch.utils.data import DataLoader,Dataset 7 | from torchvision import transforms 8 | from torchvision.datasets import MNIST 9 | import pylab 10 | import matplotlib.pyplot as plt 11 | 12 | class Mnisttox(Dataset): 13 | def __init__(self, datasets ,labels:list): 14 | self.dataset = [datasets[i][0] for i in range(len(datasets)) 15 | if datasets[i][1] in labels ] 16 | self.labels = labels 17 | self.len_oneclass = int(len(self.dataset)/10) 18 | 19 | def __len__(self): 20 | return int(len(self.dataset)) 21 | 22 | def __getitem__(self, index): 23 | img = self.dataset[index] 24 | return img,[] 25 | 26 | class Autoencoder(nn.Module): 27 | def __init__(self,z_dim): 28 | super(Autoencoder, self).__init__() 29 | self.encoder = nn.Sequential( 30 | nn.Linear(28 * 28, 256), 31 | nn.ReLU(True), 32 | nn.Linear(256, 128), 33 | nn.ReLU(True), 34 | nn.Linear(128, z_dim)) 35 | 36 | self.decoder = nn.Sequential( 37 | nn.Linear(z_dim, 128), 38 | nn.ReLU(True), 39 | nn.Linear(128, 256), 40 | nn.ReLU(True), 41 | nn.Linear(256, 28 * 28), 42 | nn.Tanh() 43 | ) 44 | 45 | def forward(self, x): 46 | z = self.encoder(x) 47 | xhat = self.decoder(z) 48 | return xhat 49 | 50 | z_dim = 64 51 | batch_size = 16 52 | num_epochs = 10 53 | learning_rate = 3.0e-4 54 | n = 6 #number of test sample 55 | cuda = True 56 | model = Autoencoder(z_dim) 57 | mse_loss = nn.MSELoss() 58 | optimizer = torch.optim.Adam(model.parameters(), 59 | lr=learning_rate, 60 | weight_decay=1e-5) 61 | 62 | if cuda: 63 | model.cuda() 64 | 65 | img_transform = transforms.Compose([ 66 | transforms.ToTensor(), 67 | transforms.Normalize((0.5, ), (0.5, )) # [0,1] => [-1,1] 68 | ]) 69 | train_dataset = MNIST('./data', download=True,train=True, transform=img_transform) 70 | train_1 = Mnisttox(train_dataset,[1]) 71 | train_loader = DataLoader(train_1, batch_size=batch_size, shuffle=True) 72 | losses = np.zeros(num_epochs) 73 | 74 | for epoch in range(num_epochs): 75 | i = 0 76 | for img,_ in train_loader: 77 | 78 | x = img.view(img.size(0), -1) 79 | 80 | if cuda: 81 | x = Variable(x).cuda() 82 | else: 83 | x = Variable(x) 84 | 85 | xhat = model(x) 86 | 87 | # 出力画像(再構成画像)と入力画像の間でlossを計算 88 | loss = mse_loss(xhat, x) 89 | losses[epoch] = losses[epoch] * (i / (i + 1.)) + loss * (1. / (i + 1.)) 90 | optimizer.zero_grad() 91 | loss.backward() 92 | optimizer.step() 93 | i += 1 94 | 95 | plt.figure() 96 | pylab.xlim(0, num_epochs) 97 | plt.plot(range(0, num_epochs), losses, label='loss') 98 | plt.legend() 99 | plt.savefig(os.path.join("./save/", 'loss.pdf')) 100 | plt.close() 101 | 102 | print('epoch [{}/{}], loss: {:.4f}'.format( 103 | epoch + 1, 104 | num_epochs, 105 | loss)) 106 | 107 | test_dataset = MNIST('./data', train=False,download=True, transform=img_transform) 108 | test_1_9 = Mnisttox(test_dataset,[1,9]) 109 | test_loader = DataLoader(test_1_9, batch_size=len(test_dataset), shuffle=True) 110 | 111 | for img,_ in test_loader: 112 | x = img.view(img.size(0), -1) 113 | 114 | if cuda: 115 | x = Variable(x).cuda() 116 | else: 117 | x = Variable(x) 118 | 119 | xhat = model(x) 120 | x = x.cpu().detach().numpy() 121 | xhat = xhat.cpu().detach().numpy() 122 | x = x/2 + 0.5 123 | xhat = xhat/2 + 0.5 124 | 125 | # サンプル画像表示 126 | plt.figure(figsize=(12, 6)) 127 | for i in range(n): 128 | # テスト画像を表示 129 | ax = plt.subplot(3, n, i + 1) 130 | plt.imshow(x[i].reshape(28, 28)) 131 | plt.gray() 132 | ax.get_xaxis().set_visible(False) 133 | ax.get_yaxis().set_visible(False) 134 | 135 | # 出力画像を表示 136 | ax = plt.subplot(3, n, i + 1 + n) 137 | plt.imshow(xhat[i].reshape(28, 28)) 138 | plt.gray() 139 | ax.get_xaxis().set_visible(False) 140 | ax.get_yaxis().set_visible(False) 141 | 142 | # 入出力の差分画像を計算 143 | diff_img = np.abs(x[i] - xhat[i]) 144 | 145 | # 入出力の差分数値を計算 146 | diff = np.sum(diff_img) 147 | 148 | # 差分画像と差分数値の表示 149 | ax = plt.subplot(3, n, i + 1 + n * 2) 150 | plt.imshow(diff_img.reshape(28, 28),cmap="jet") 151 | #plt.gray() 152 | ax.get_xaxis().set_visible(True) 153 | ax.get_yaxis().set_visible(True) 154 | ax.set_xlabel('score = ' + str(diff)) 155 | 156 | plt.savefig("./save/result.png") 157 | plt.show() 158 | plt.close() 159 | 160 | -------------------------------------------------------------------------------- /save/loss.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/satolab12/anomaly-detection-using-autoencoder-PyTorch/2581f7d8045b8e22cb1234e720dd757f21e3223e/save/loss.pdf -------------------------------------------------------------------------------- /save/result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/satolab12/anomaly-detection-using-autoencoder-PyTorch/2581f7d8045b8e22cb1234e720dd757f21e3223e/save/result.png --------------------------------------------------------------------------------