├── README.md
├── kervolutionMnist.ipynb
└── layer.py

/README.md:
--------------------------------------------------------------------------------
## Kervolutional Neural Networks
A PyTorch implementation of the kervolutional (kernel convolutional) layer from Kervolutional Neural Networks [[paper](https://arxiv.org/pdf/1904.03955.pdf)].
It is similar in spirit to Network in Network, but adds the non-linearity through kernel functions instead of extra layers.

## Dependencies
```
pip install torch torchvision numpy matplotlib pandas
```

To use this layer:
```
from layer import KernelConv2d, GaussianKernel, PolynomialKernel
```
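The kernel is passed to `KernelConv2d` as a callable, and its hyper-parameters are bound with `functools.partial`, as done in `kervolutionMnist.ipynb`. A minimal usage sketch (the `KernelConv2d(in_channels, out_channels, kernel_size, kernel_fn)` argument order mirrors the notebook; the default stride and padding are assumed to behave like `nn.Conv2d`, so check `layer.py` for the exact defaults):
```
import torch
from functools import partial
from layer import KernelConv2d, GaussianKernel, PolynomialKernel

# Gaussian kernel with gamma=0.05, matching the notebook's first layer
kerv1 = KernelConv2d(1, 10, 5, partial(GaussianKernel, 0.05))

# Polynomial kernel with the parameters suggested in the notebook;
# omitting the kernel argument (e.g. KernelConv2d(10, 20, 5)) uses the default kernel
kerv2 = KernelConv2d(10, 20, 5, partial(PolynomialKernel, 2, 3))

# The layer is used like nn.Conv2d in a forward pass
x = torch.randn(8, 1, 28, 28)   # a batch of MNIST-sized images
y = kerv1(x)
print(y.shape)                  # expected: [8, 10, 24, 24] with stride 1 and no padding
```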
--------------------------------------------------------------------------------
/kervolutionMnist.ipynb:
--------------------------------------------------------------------------------
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torchvision import datasets,transforms\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import os\n",
    "import pandas as pd\n",
    "from layer import KernelConv2d, GaussianKernel, PolynomialKernel\n",
    "from functools import partial  # binds kernel hyper-parameters when constructing KernelConv2d, e.g. partial(GaussianKernel, 0.05) or partial(PolynomialKernel, 2, 3)\n",
    "%matplotlib inline\n",
    "def mkdirs(path):\n",
    "    if not os.path.exists(path):\n",
    "        os.makedirs(path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Kervolution LeNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class KNNet(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(KNNet,self).__init__()\n",
    "        self.conv1=KernelConv2d(1,10,5,partial(GaussianKernel, 0.05))  # self.conv1=KernelConv2d(1,10,5) uses the default polynomial kernel with default parameters\n",
    "        print(self.conv1)\n",
    "        self.bn1=nn.BatchNorm2d(10)\n",
    "        self.conv2=KernelConv2d(10,20,5)\n",
    "        self.bn2=nn.BatchNorm2d(20)\n",
    "        self.conv2_drop=nn.Dropout2d()\n",
    "        self.fc1=nn.Linear(320,50)\n",
    "        self.fc2=nn.Linear(50,10)\n",
    "    def forward(self,x):\n",
    "        x=F.relu(F.max_pool2d(self.conv1(x),2))\n",
    "        x=self.bn1(x)\n",
    "        x=F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)),2))\n",
    "        x=self.bn2(x)\n",
    "        x=x.view(-1,320)\n",
    "        x=F.relu(self.fc1(x))\n",
    "        x=F.dropout(x,training=self.training)\n",
    "        x=self.fc2(x)  # output logits go straight to log_softmax (no ReLU on the final layer)\n",
    "        return F.log_softmax(x,dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "train_loader=torch.utils.data.DataLoader(\n",
    "    datasets.MNIST(\"data\",train=True,download=True,transform=transforms.Compose([\n",
    "        transforms.ToTensor(),\n",
    "    ])),batch_size=128,shuffle=True)\n",
    "test_loader=torch.utils.data.DataLoader(\n",
    "    datasets.MNIST(\"data\",train=False,download=True,transform=transforms.Compose([\n",
    "        transforms.ToTensor(),\n",
    "    ])),batch_size=128,shuffle=False\n",
    ")\n",
    "attack_test_loader=torch.utils.data.DataLoader(\n",
    "    datasets.MNIST(\"data\",train=False,download=True,transform=transforms.Compose([\n",
    "        transforms.ToTensor(),\n",
    "    ])),batch_size=1,shuffle=False\n",
    ")\n",
    "print(len(train_loader))\n",
    "print(len(test_loader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device=torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "knn=KNNet().to(device)\n",
    "knn.train(mode=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "criterion=torch.nn.NLLLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_accuray(pred,true):\n",
    "    pred_idx=pred.argmax(dim=1).detach().cpu().numpy()\n",
    "    tmp=pred_idx==true.cpu().numpy()\n",
    "    return sum(tmp)/len(pred_idx)\n",
    "def train(m,out_dir):\n",
    "    iter_loss=[]\n",
    "    train_losses=[]\n",
    "    test_losses=[]\n",
    "    iter_loss_path=os.path.join(out_dir,\"iter_loss.csv\")\n",
    "    epoch_loss_path=os.path.join(out_dir,\"epoch_loss.csv\")\n",
    "    nb_epochs=20\n",
    "    last_loss=99999\n",
    "    mkdirs(os.path.join(out_dir,\"models\"))\n",
    "    optimizer=optim.SGD(m.parameters(),lr=0.003,momentum=0.9)\n",
    "    for epoch in range(nb_epochs):\n",
    "        train_loss=0.\n",
    "        train_acc=0.\n",
    "        m.train(mode=True)\n",
    "        for data,target in train_loader:\n",
    "            data,target=data.to(device),target.to(device)\n",
    "            optimizer.zero_grad()\n",
    "            output=m(data)\n",
    "            loss=criterion(output,target)\n",
    "            loss_value=loss.item()\n",
    "            iter_loss.append(loss_value)\n",
    "            train_loss+=loss_value\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            acc=compute_accuray(output,target)\n",
    "            train_acc+=acc\n",
    "        train_losses.append(train_loss/len(train_loader))\n",
    "\n",
    "        test_loss=0.\n",
    "        test_acc=0.\n",
    "        m.train(mode=False)\n",
    "        for data,target in test_loader:\n",
    "            data,target=data.to(device),target.to(device)\n",
    "            output=m(data)\n",
    "            loss=criterion(output,target)\n",
    "            loss_value=loss.item()\n",
    "            iter_loss.append(loss_value)\n",
    "            test_loss+=loss_value\n",
    "            acc=compute_accuray(output,target)\n",
    "            test_acc+=acc\n",
    "        test_losses.append(test_loss/len(test_loader))\n",
    "        print(\"Epoch {}: train loss is {}, train accuracy is {}; test loss is {}, test accuracy is {}\".\n",
    "              format(epoch,round(train_loss/len(train_loader),2),\n",
    "                     round(train_acc/len(train_loader),2),\n",
    "                     round(test_loss/len(test_loader),2),\n",
    "                     round(test_acc/len(test_loader),2)))\n",
    "        if test_loss/len(test_loader)