├── Old Trials
│   ├── README.md
│   ├── bnn_mnist.ipynb
│   ├── pbnn_mnist_v1.ipynb
│   ├── bnn_cifar.ipynb
│   ├── pbnn_mnist_v2.ipynb
│   ├── pbnn_mnist_v7.ipynb
│   ├── pbnn_mnist_v3.ipynb
│   ├── pbnn_cifar_v1.ipynb
│   └── pbnn_mnist_v5.ipynb
└── README.md
/Old Trials/README.md:
--------------------------------------------------------------------------------
1 | # pbnn
2 | Reproduction of the PBNN paper.
3 | 
4 | ## Read the Jupyter notebooks online
5 | Visit **https://nbviewer.jupyter.org/**
6 | 
7 | Copy and paste the GitHub URL of a notebook, e.g. **https://github.com/YanghaoZYH/pbnn/blob/master/pbnn_mnist_v1.ipynb**
8 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Reproduction of Probabilistic Binary Neural Networks
2 | 
3 | ### Team
4 | 
5 | - Yanghao Zhang yz16n18@soton.ac.uk
6 | 
7 | - Zhiruo Zhang zz1a18@soton.ac.uk
8 | 
9 | - Yigang Zhou yz13u18@soton.ac.uk
10 | 
11 | 
12 | 
13 | ### Brief
14 | 
15 | This project reproduces the probabilistic binary neural network (PBNN) described in the paper published at ICLR. We implemented all of the components from the paper, including the use of stochasticity in the training process, a stochastic version of Batch Normalization, and sampling of the binary activations. Our experiments obtained results similar to those reported in the original paper.
16 | 
17 | 
18 | 
19 | ### Useful links
20 | 
21 | Peters, J.W.T., Genewein, T. and Welling, M., 2019. Probabilistic Binary Neural Networks.
22 | 
23 | Shayer, O., Levi, D. and Fetaya, E., 2017. Learning Discrete Weights Using the Local Reparameterization Trick. arXiv preprint arXiv:1710.07739.
24 | https://openreview.net/pdf?id=BySRH6CpW
25 | 
26 | Peters, J.W. and Welling, M., 2018. Probabilistic Binary Neural Networks. arXiv preprint arXiv:1809.03368.
27 | https://arxiv.org/pdf/1809.03368.pdf
28 | 
29 | Kingma, D.P., Salimans, T. and Welling, M., 2015. Variational Dropout and the Local Reparameterization Trick. In Advances in Neural Information Processing Systems (pp. 2575-2583).
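
### Method sketch

For illustration, a minimal sketch of the moment-propagation step used throughout the PBNN notebooks (simplified from the `PBinarizeLinear` layer in `pbnn_mnist_v7.ipynb`; the class name `PBinaryLinear` and the variance floor `1e-6` are our own choices here, not code from the repo):

```python
import torch
import torch.nn as nn
from torch.distributions.normal import Normal

class PBinaryLinear(nn.Linear):
    """Linear layer with stochastic binary weights in {-1, +1}.

    Each weight is Bernoulli-distributed with mean theta = tanh(w).
    Under a central-limit (local reparameterization) approximation the
    pre-activation is Gaussian, so we propagate its mean and variance
    instead of sampling every individual weight.
    """

    def forward(self, x):
        theta = torch.tanh(self.weight)                 # E[w_b]
        mu = nn.functional.linear(x, theta, self.bias)  # pre-activation mean
        # Var[w_b] = E[w_b^2] - E[w_b]^2 = 1 - theta^2, hence:
        var = nn.functional.linear(x ** 2, 1 - theta ** 2)
        sigma = var.clamp_min(1e-6).sqrt()
        # probability that the binary activation fires (+1 rather than -1)
        p = 1 - Normal(mu, sigma).cdf(torch.zeros_like(mu))
        # hard Bernoulli sample mapped to {-1, +1}; not differentiable
        return 2 * torch.bernoulli(p) - 1
```

During training the notebooks draw this sample from a `RelaxedBernoulli` (Gumbel-softmax style) distribution rather than the hard `torch.bernoulli` above, so that gradients can flow through the sampling step.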
30 | 31 | Binarized Neural Network (BNN) for pytorch 32 | -------------------------------------------------------------------------------- /Old Trials/bnn_mnist.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "Copy of binaryNN_mnist_save1.ipynb", 7 | "version": "0.3.2", 8 | "provenance": [], 9 | "collapsed_sections": [] 10 | }, 11 | "kernelspec": { 12 | "name": "python3", 13 | "display_name": "Python 3" 14 | }, 15 | "accelerator": "GPU" 16 | }, 17 | "cells": [ 18 | { 19 | "cell_type": "code", 20 | "metadata": { 21 | "id": "3YM3-zGxLa63", 22 | "colab_type": "code", 23 | "colab": {} 24 | }, 25 | "source": [ 26 | "import torch\n", 27 | "import pdb\n", 28 | "import torch.nn as nn\n", 29 | "import math\n", 30 | "from torch.autograd import Variable\n", 31 | "from torch.autograd import Function\n", 32 | "\n", 33 | "import numpy as np\n", 34 | "\n", 35 | "\n", 36 | "def Binarize(tensor,quant_mode='det'):\n", 37 | " if quant_mode=='det':\n", 38 | " return tensor.sign()\n", 39 | " else:\n", 40 | " return tensor.add_(1).div_(2).add_(torch.rand(tensor.size()).add(-0.5)).clamp_(0,1).round().mul_(2).add_(-1)\n", 41 | "\n", 42 | "\n", 43 | "\n", 44 | "\n", 45 | "class HingeLoss(nn.Module):\n", 46 | " def __init__(self):\n", 47 | " super(HingeLoss,self).__init__()\n", 48 | " self.margin=1.0\n", 49 | "\n", 50 | " def hinge_loss(self,input,target):\n", 51 | " #import pdb; pdb.set_trace()\n", 52 | " output=self.margin-input.mul(target)\n", 53 | " output[output.le(0)]=0\n", 54 | " return output.mean()\n", 55 | "\n", 56 | " def forward(self, input, target):\n", 57 | " return self.hinge_loss(input,target)\n", 58 | "\n", 59 | "class SqrtHingeLossFunction(Function):\n", 60 | " def __init__(self):\n", 61 | " super(SqrtHingeLossFunction,self).__init__()\n", 62 | " self.margin=1.0\n", 63 | "\n", 64 | " def forward(self, input, target):\n", 65 | " output=self.margin-input.mul(target)\n", 66 | " output[output.le(0)]=0\n", 67 | " self.save_for_backward(input, target)\n", 68 | " loss=output.mul(output).sum(0).sum(1).div(target.numel())\n", 69 | " return loss\n", 70 | "\n", 71 | " def backward(self,grad_output):\n", 72 | " input, target = self.saved_tensors\n", 73 | " output=self.margin-input.mul(target)\n", 74 | " output[output.le(0)]=0\n", 75 | " import pdb; pdb.set_trace()\n", 76 | " grad_output.resize_as_(input).copy_(target).mul_(-2).mul_(output)\n", 77 | " grad_output.mul_(output.ne(0).float())\n", 78 | " grad_output.div_(input.numel())\n", 79 | " return grad_output,grad_output\n", 80 | "\n", 81 | "def Quantize(tensor,quant_mode='det', params=None, numBits=8):\n", 82 | " tensor.clamp_(-2**(numBits-1),2**(numBits-1))\n", 83 | " if quant_mode=='det':\n", 84 | " tensor=tensor.mul(2**(numBits-1)).round().div(2**(numBits-1))\n", 85 | " else:\n", 86 | " tensor=tensor.mul(2**(numBits-1)).round().add(torch.rand(tensor.size()).add(-0.5)).div(2**(numBits-1))\n", 87 | " quant_fixed(tensor, params)\n", 88 | " return tensor\n", 89 | "\n", 90 | "import torch.nn._functions as tnnf\n", 91 | "\n", 92 | "\n", 93 | "class BinarizeLinear(nn.Linear):\n", 94 | "\n", 95 | " def __init__(self, *kargs, **kwargs):\n", 96 | " super(BinarizeLinear, self).__init__(*kargs, **kwargs)\n", 97 | "\n", 98 | " def forward(self, input):\n", 99 | "\n", 100 | " if input.size(1) != 784:\n", 101 | " input.data=Binarize(input.data)\n", 102 | " if not hasattr(self.weight,'org'):\n", 103 | " 
self.weight.org=self.weight.data.clone()\n", 104 | " self.weight.data=Binarize(self.weight.org)\n", 105 | " out = nn.functional.linear(input, self.weight)\n", 106 | " if not self.bias is None:\n", 107 | " self.bias.org=self.bias.data.clone()\n", 108 | " out += self.bias.view(1, -1).expand_as(out)\n", 109 | "\n", 110 | " return out\n", 111 | "\n", 112 | "class BinarizeConv2d(nn.Conv2d):\n", 113 | "\n", 114 | " def __init__(self, *kargs, **kwargs):\n", 115 | " super(BinarizeConv2d, self).__init__(*kargs, **kwargs)\n", 116 | "\n", 117 | "\n", 118 | " def forward(self, input):\n", 119 | " if input.size(1) != 3:\n", 120 | " input.data = Binarize(input.data)\n", 121 | " if not hasattr(self.weight,'org'):\n", 122 | " self.weight.org=self.weight.data.clone()\n", 123 | " self.weight.data=Binarize(self.weight.org)\n", 124 | "\n", 125 | " out = nn.functional.conv2d(input, self.weight, None, self.stride,\n", 126 | " self.padding, self.dilation, self.groups)\n", 127 | "\n", 128 | " if not self.bias is None:\n", 129 | " self.bias.org=self.bias.data.clone()\n", 130 | " out += self.bias.view(1, -1, 1, 1).expand_as(out)\n", 131 | "\n", 132 | " return out\n" 133 | ], 134 | "execution_count": 0, 135 | "outputs": [] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "metadata": { 140 | "id": "ggHwTvf4MRqC", 141 | "colab_type": "code", 142 | "outputId": "daef3c23-afca-41d7-dee4-b6d924fd8854", 143 | "colab": { 144 | "base_uri": "https://localhost:8080/", 145 | "height": 185 146 | } 147 | }, 148 | "source": [ 149 | "from __future__ import print_function\n", 150 | "import argparse\n", 151 | "import torch\n", 152 | "import torch.nn as nn\n", 153 | "import torch.nn.functional as F\n", 154 | "import torch.optim as optim\n", 155 | "from torchvision import datasets, transforms\n", 156 | "from torch.autograd import Variable\n", 157 | "from tqdm import tqdm\n", 158 | "\n", 159 | "\n", 160 | "torch.manual_seed(1)\n", 161 | "\n", 162 | "\n", 163 | "\n", 164 | "# kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}\n", 165 | "train_loader = torch.utils.data.DataLoader(\n", 166 | " datasets.MNIST('../data', train=True, download=True,\n", 167 | " transform=transforms.Compose([\n", 168 | " transforms.ToTensor(),\n", 169 | " transforms.Normalize((0.1307,), (0.3081,))\n", 170 | " ])),\n", 171 | " batch_size=64, shuffle=True)\n", 172 | "test_loader = torch.utils.data.DataLoader(\n", 173 | " datasets.MNIST('../data', train=False, transform=transforms.Compose([\n", 174 | " transforms.ToTensor(),\n", 175 | " transforms.Normalize((0.1307,), (0.3081,))\n", 176 | " ])),\n", 177 | " batch_size=64, shuffle=True)\n", 178 | "\n", 179 | "\n", 180 | "class Net(nn.Module):\n", 181 | " def __init__(self):\n", 182 | " super(Net, self).__init__()\n", 183 | " self.infl_ratio=3\n", 184 | " self.fc1 = BinarizeLinear(784, 2048*self.infl_ratio)\n", 185 | " self.htanh1 = nn.Hardtanh()\n", 186 | " self.bn1 = nn.BatchNorm1d(2048*self.infl_ratio)\n", 187 | " self.fc2 = BinarizeLinear(2048*self.infl_ratio, 2048*self.infl_ratio)\n", 188 | " self.htanh2 = nn.Hardtanh()\n", 189 | " self.bn2 = nn.BatchNorm1d(2048*self.infl_ratio)\n", 190 | " self.fc3 = BinarizeLinear(2048*self.infl_ratio, 2048*self.infl_ratio)\n", 191 | " self.htanh3 = nn.Hardtanh()\n", 192 | " self.bn3 = nn.BatchNorm1d(2048*self.infl_ratio)\n", 193 | " self.fc4 = nn.Linear(2048*self.infl_ratio, 10)\n", 194 | " self.logsoftmax = nn.LogSoftmax(dim=1)\n", 195 | " self.drop=nn.Dropout(0.5)\n", 196 | "\n", 197 | " def forward(self, x):\n", 198 | " x = x.view(-1, 28*28)\n", 199 | " 
\n", 200 | " x = self.fc1(x)\n", 201 | " x = self.bn1(x)\n", 202 | " x = self.htanh1(x)\n", 203 | " \n", 204 | "# x = self.fc2(x)\n", 205 | "# x = self.bn2(x)\n", 206 | "# x = self.htanh2(x)\n", 207 | "# x = self.fc3(x)\n", 208 | "# x = self.drop(x)\n", 209 | "# x = self.bn3(x)\n", 210 | "# x = self.htanh3(x)\n", 211 | " x = self.fc4(x)\n", 212 | " return self.logsoftmax(x)\n", 213 | "\n", 214 | "model = Net()\n", 215 | "torch.cuda.device('cuda')\n", 216 | "model.cuda()\n", 217 | "\n", 218 | "\n", 219 | "criterion = nn.CrossEntropyLoss()\n", 220 | "optimizer = optim.Adam(model.parameters(), lr=0.001)\n", 221 | "\n", 222 | "\n", 223 | "def train(epoch):\n", 224 | " model.train()\n", 225 | " \n", 226 | " losses = []\n", 227 | " trainloader = tqdm(train_loader)\n", 228 | " \n", 229 | " for batch_idx, (data, target) in enumerate(trainloader):\n", 230 | " \n", 231 | " data, target = data.cuda(), target.cuda()\n", 232 | " data, target = Variable(data), Variable(target)\n", 233 | " optimizer.zero_grad()\n", 234 | " output = model(data)\n", 235 | " loss = criterion(output, target)\n", 236 | "\n", 237 | "# if epoch%40==0:\n", 238 | "# optimizer.param_groups[0]['lr']=optimizer.param_groups[0]['lr']*0.1\n", 239 | "\n", 240 | "# optimizer.zero_grad()\n", 241 | " \n", 242 | " loss.backward()\n", 243 | " \n", 244 | " for p in list(model.parameters()):\n", 245 | " if hasattr(p,'org'):\n", 246 | " p.data.copy_(p.org)\n", 247 | " optimizer.step()\n", 248 | " \n", 249 | " for p in list(model.parameters()):\n", 250 | " if hasattr(p,'org'):\n", 251 | " p.org.copy_(p.data.clamp_(-1,1))\n", 252 | " \n", 253 | " losses.append(loss.item())\n", 254 | " trainloader.set_postfix(loss=np.mean(losses), epoch=epoch)\n", 255 | "# if batch_idx % 10000 == 0:\n", 256 | "# print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n", 257 | "# epoch, batch_idx * len(data), len(train_loader.dataset),\n", 258 | "# 100. * batch_idx / len(train_loader), loss.item()))\n", 259 | "\n", 260 | "def test():\n", 261 | " model.eval()\n", 262 | " test_loss = 0\n", 263 | " correct = 0\n", 264 | " testloader = tqdm(test_loader)\n", 265 | " for data, target in testloader:\n", 266 | " data, target = data.cuda(), target.cuda()\n", 267 | " with torch.no_grad():\n", 268 | " data = Variable(data)\n", 269 | " target = Variable(target)\n", 270 | " output = model(data)\n", 271 | " test_loss += criterion(output, target).item() # sum up batch loss\n", 272 | " pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability\n", 273 | " correct += pred.eq(target.data.view_as(pred)).cpu().sum()\n", 274 | " \n", 275 | " \n", 276 | "\n", 277 | " testloader.set_postfix(loss=test_loss / len(test_loader.dataset),acc=str((100. *correct / len(test_loader.dataset)).numpy())+'%')\n", 278 | " \n", 279 | " test_loss /= len(test_loader.dataset)\n", 280 | " \n", 281 | " \n", 282 | "# print('\\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n", 283 | "# test_loss, correct, len(test_loader.dataset),\n", 284 | "# 100. 
* correct / len(test_loader.dataset)))\n", 285 | "\n", 286 | "\n", 287 | "for epoch in range(5):\n", 288 | " train(epoch)\n", 289 | " test()" 290 | ], 291 | "execution_count": 0, 292 | "outputs": [ 293 | { 294 | "output_type": "stream", 295 | "text": [ 296 | "100%|██████████| 938/938 [00:14<00:00, 65.07it/s, epoch=0, loss=0.347]\n", 297 | "100%|██████████| 157/157 [00:01<00:00, 83.37it/s, acc=93%, loss=0.00361]\n", 298 | "100%|██████████| 938/938 [00:16<00:00, 58.43it/s, epoch=1, loss=0.206]\n", 299 | "100%|██████████| 157/157 [00:02<00:00, 67.73it/s, acc=95%, loss=0.00248]\n", 300 | "100%|██████████| 938/938 [00:16<00:00, 57.91it/s, epoch=2, loss=0.141]\n", 301 | "100%|██████████| 157/157 [00:01<00:00, 85.35it/s, acc=94%, loss=0.00269]\n", 302 | "100%|██████████| 938/938 [00:14<00:00, 66.19it/s, epoch=3, loss=0.11]\n", 303 | "100%|██████████| 157/157 [00:01<00:00, 84.63it/s, acc=96%, loss=0.00216]\n", 304 | "100%|██████████| 938/938 [00:14<00:00, 66.64it/s, epoch=4, loss=0.0814]\n", 305 | "100%|██████████| 157/157 [00:01<00:00, 85.82it/s, acc=96%, loss=0.0019]\n" 306 | ], 307 | "name": "stderr" 308 | } 309 | ] 310 | }, 311 | { 312 | "cell_type": "code", 313 | "metadata": { 314 | "id": "zauLH4y0zg_G", 315 | "colab_type": "code", 316 | "colab": {} 317 | }, 318 | "source": [ 319 | "" 320 | ], 321 | "execution_count": 0, 322 | "outputs": [] 323 | } 324 | ] 325 | } 326 | -------------------------------------------------------------------------------- /Old Trials/pbnn_mnist_v1.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "pbnn_mnist_v1.ipynb", 7 | "version": "0.3.2", 8 | "provenance": [], 9 | "collapsed_sections": [] 10 | }, 11 | "kernelspec": { 12 | "name": "python3", 13 | "display_name": "Python 3" 14 | }, 15 | "accelerator": "GPU" 16 | }, 17 | "cells": [ 18 | { 19 | "cell_type": "code", 20 | "metadata": { 21 | "id": "3YM3-zGxLa63", 22 | "colab_type": "code", 23 | "colab": {} 24 | }, 25 | "source": [ 26 | "import torch\n", 27 | "import pdb\n", 28 | "import torch.nn as nn\n", 29 | "import math\n", 30 | "from torch.autograd import Variable\n", 31 | "from torch.autograd import Function\n", 32 | "import time\n", 33 | "\n", 34 | "import numpy as np\n", 35 | "\n", 36 | "\n", 37 | "def Binarize(tensor,quant_mode='det'):\n", 38 | " if quant_mode=='det':\n", 39 | " return tensor.sign()\n", 40 | " else:\n", 41 | " return tensor.add_(1).div_(2).add_(torch.rand(tensor.size()).add(-0.5)).clamp_(0,1).round().mul_(2).add_(-1)\n", 42 | "\n", 43 | "\n", 44 | "\n", 45 | "\n", 46 | "class HingeLoss(nn.Module):\n", 47 | " def __init__(self):\n", 48 | " super(HingeLoss,self).__init__()\n", 49 | " self.margin=1.0\n", 50 | "\n", 51 | " def hinge_loss(self,input,target):\n", 52 | " #import pdb; pdb.set_trace()\n", 53 | " output=self.margin-input.mul(target)\n", 54 | " output[output.le(0)]=0\n", 55 | " return output.mean()\n", 56 | "\n", 57 | " def forward(self, input, target):\n", 58 | " return self.hinge_loss(input,target)\n", 59 | "\n", 60 | "class SqrtHingeLossFunction(Function):\n", 61 | " def __init__(self):\n", 62 | " super(SqrtHingeLossFunction,self).__init__()\n", 63 | " self.margin=1.0\n", 64 | "\n", 65 | " def forward(self, input, target):\n", 66 | " output=self.margin-input.mul(target)\n", 67 | " output[output.le(0)]=0\n", 68 | " self.save_for_backward(input, target)\n", 69 | " loss=output.mul(output).sum(0).sum(1).div(target.numel())\n", 70 | " return loss\n", 71 | 
"\n", 72 | " def backward(self,grad_output):\n", 73 | " input, target = self.saved_tensors\n", 74 | " output=self.margin-input.mul(target)\n", 75 | " output[output.le(0)]=0\n", 76 | " import pdb; pdb.set_trace()\n", 77 | " grad_output.resize_as_(input).copy_(target).mul_(-2).mul_(output)\n", 78 | " grad_output.mul_(output.ne(0).float())\n", 79 | " grad_output.div_(input.numel())\n", 80 | " return grad_output,grad_output\n", 81 | "\n", 82 | "def Quantize(tensor,quant_mode='det', params=None, numBits=8):\n", 83 | " tensor.clamp_(-2**(numBits-1),2**(numBits-1))\n", 84 | " if quant_mode=='det':\n", 85 | " tensor=tensor.mul(2**(numBits-1)).round().div(2**(numBits-1))\n", 86 | " else:\n", 87 | " tensor=tensor.mul(2**(numBits-1)).round().add(torch.rand(tensor.size()).add(-0.5)).div(2**(numBits-1))\n", 88 | " quant_fixed(tensor, params)\n", 89 | " return tensor\n", 90 | "\n", 91 | "import torch.nn._functions as tnnf\n", 92 | "\n", 93 | "\n", 94 | "class BinarizeLinear(nn.Linear):\n", 95 | "\n", 96 | " def __init__(self, *kargs, **kwargs):\n", 97 | " super(BinarizeLinear, self).__init__(*kargs, **kwargs)\n", 98 | "\n", 99 | " def forward(self, input):\n", 100 | "\n", 101 | " if input.size(1) != 784:\n", 102 | " input.data=Binarize(input.data)\n", 103 | " if not hasattr(self.weight,'org'):\n", 104 | " self.weight.org=self.weight.data.clone()\n", 105 | " self.weight.data=Binarize(self.weight.org)\n", 106 | " out = nn.functional.linear(input, self.weight)\n", 107 | " if not self.bias is None:\n", 108 | " self.bias.org=self.bias.data.clone()\n", 109 | " out += self.bias.view(1, -1).expand_as(out)\n", 110 | "\n", 111 | " return out\n", 112 | "\n", 113 | "class BinarizeConv2d(nn.Conv2d):\n", 114 | "\n", 115 | " def __init__(self, *kargs, **kwargs):\n", 116 | " super(BinarizeConv2d, self).__init__(*kargs, **kwargs)\n", 117 | "\n", 118 | "\n", 119 | " def forward(self, input):\n", 120 | " if input.size(1) != 3:\n", 121 | " input.data = Binarize(input.data)\n", 122 | " if not hasattr(self.weight,'org'):\n", 123 | " self.weight.org=self.weight.data.clone()\n", 124 | " self.weight.data=Binarize(self.weight.org)\n", 125 | "\n", 126 | " out = nn.functional.conv2d(input, self.weight, None, self.stride,\n", 127 | " self.padding, self.dilation, self.groups)\n", 128 | "\n", 129 | " if not self.bias is None:\n", 130 | " self.bias.org=self.bias.data.clone()\n", 131 | " out += self.bias.view(1, -1, 1, 1).expand_as(out)\n", 132 | "\n", 133 | " return out\n", 134 | " \n", 135 | "\n" 136 | ], 137 | "execution_count": 0, 138 | "outputs": [] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "metadata": { 143 | "id": "pCIRllsuAtAD", 144 | "colab_type": "code", 145 | "outputId": "7fd606b7-fa4f-41ab-f207-a188a0de5159", 146 | "colab": { 147 | "base_uri": "https://localhost:8080/", 148 | "height": 185 149 | } 150 | }, 151 | "source": [ 152 | "from __future__ import print_function\n", 153 | "import argparse\n", 154 | "import torch\n", 155 | "import torch.nn as nn\n", 156 | "import torch.nn.functional as F\n", 157 | "import torch.optim as optim\n", 158 | "from torchvision import datasets, transforms\n", 159 | "from torch.autograd import Variable\n", 160 | "from tqdm import tqdm\n", 161 | "# from models.binarized_modules import BinarizeLinear,BinarizeConv2d\n", 162 | "# from models.binarized_modules import Binarize,Ternarize,Ternarize2,Ternarize3,Ternarize4,HingeLoss\n", 163 | "# Training settings\n", 164 | "# parser = argparse.ArgumentParser(description='PyTorch MNIST Example')\n", 165 | "# parser.add_argument('--batch-size', 
type=int, default=64, metavar='N',\n", 166 | "# help='input batch size for training (default: 256)')\n", 167 | "# parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',\n", 168 | "# help='input batch size for testing (default: 1000)')\n", 169 | "# parser.add_argument('--epochs', type=int, default=100, metavar='N',\n", 170 | "# help='number of epochs to train (default: 10)')\n", 171 | "# parser.add_argument('--lr', type=float, default=0.01, metavar='LR',\n", 172 | "# help='learning rate (default: 0.001)')\n", 173 | "# parser.add_argument('--momentum', type=float, default=0.5, metavar='M',\n", 174 | "# help='SGD momentum (default: 0.5)')\n", 175 | "# parser.add_argument('--no-cuda', action='store_true', default=False,\n", 176 | "# help='disables CUDA training')\n", 177 | "# parser.add_argument('--seed', type=int, default=1, metavar='S',\n", 178 | "# help='random seed (default: 1)')\n", 179 | "# parser.add_argument('--gpus', default=3,\n", 180 | "# help='gpus used for training - e.g 0,1,3')\n", 181 | "# parser.add_argument('--log-interval', type=int, default=10, metavar='N',\n", 182 | "# help='how many batches to wait before logging training status')\n", 183 | "# args = parser.parse_args()\n", 184 | "# args.cuda = not args.no_cuda and torch.cuda.is_available()\n", 185 | "\n", 186 | "torch.manual_seed(1)\n", 187 | "# if args.cuda:\n", 188 | "# torch.cuda.manual_seed(args.seed)\n", 189 | "\n", 190 | "\n", 191 | "# kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}\n", 192 | "train_loader = torch.utils.data.DataLoader(\n", 193 | " datasets.MNIST('../data', train=True, download=True,\n", 194 | " transform=transforms.Compose([\n", 195 | " transforms.ToTensor(),\n", 196 | " transforms.Normalize((0.1307,), (0.3081,))\n", 197 | " ])),\n", 198 | " batch_size=128, shuffle=True)\n", 199 | "test_loader = torch.utils.data.DataLoader(\n", 200 | " datasets.MNIST('../data', train=False, transform=transforms.Compose([\n", 201 | " transforms.ToTensor(),\n", 202 | " transforms.Normalize((0.1307,), (0.3081,))\n", 203 | " ])),\n", 204 | " batch_size=128, shuffle=True)\n", 205 | "\n", 206 | "\n", 207 | "# 32C3 - MP2 - 64C3 - Mp2 - 512FC - SM10\n", 208 | "\n", 209 | "class Net(nn.Module):\n", 210 | " def __init__(self):\n", 211 | " super(Net, self).__init__()\n", 212 | " \n", 213 | "# self.fc1 = BinarizeLinear(784, 2048*self.infl_ratio)\n", 214 | " \n", 215 | " \n", 216 | "# self.infl_ratio=3\n", 217 | "# self.fc1 = BinarizeLinear(784, 2048*self.infl_ratio)\n", 218 | "# self.htanh1 = nn.Hardtanh()\n", 219 | "# self.bn1 = nn.BatchNorm1d(2048*self.infl_ratio)\n", 220 | "# self.fc2 = BinarizeLinear(2048*self.infl_ratio, 2048*self.infl_ratio)\n", 221 | "# self.htanh2 = nn.Hardtanh()\n", 222 | "# self.bn2 = nn.BatchNorm1d(2048*self.infl_ratio)\n", 223 | "# self.fc3 = BinarizeLinear(2048*self.infl_ratio, 2048*self.infl_ratio)\n", 224 | "# self.htanh3 = nn.Hardtanh()\n", 225 | "# self.bn3 = nn.BatchNorm1d(2048*self.infl_ratio)\n", 226 | "# self.fc4 = nn.Linear(2048*self.infl_ratio, 10)\n", 227 | "# self.logsoftmax = nn.LogSoftmax(dim=1)\n", 228 | "# self.drop=nn.Dropout(0.5)\n", 229 | " \n", 230 | " self.conv1 = BinarizeConv2d(1, 32, kernel_size=3)\n", 231 | " self.mp1 = nn.MaxPool2d(kernel_size=2, stride=2)\n", 232 | " self.bn1 = nn.BatchNorm2d(32)\n", 233 | " self.htanh1 = nn.Hardtanh()\n", 234 | " \n", 235 | " self.conv2 = BinarizeConv2d(32, 64, kernel_size=3)\n", 236 | " self.mp2 = nn.MaxPool2d(kernel_size=2, stride=2)\n", 237 | " self.bn2 = nn.BatchNorm2d(64)\n", 238 | " 
self.htanh2 = nn.Hardtanh()\n", 239 | " \n", 240 | " self.fc1 = nn.Linear(64*5*5, 512)\n", 241 | " self.bn3 = nn.BatchNorm1d(512)\n", 242 | " self.htanh3 = nn.Hardtanh()\n", 243 | " \n", 244 | " self.fc2 = nn.Linear(512, 10)\n", 245 | " self.sm = nn.Softmax(dim=1)\n", 246 | " \n", 247 | " \n", 248 | "\n", 249 | " def forward(self, x):\n", 250 | "# x = x.view(-1, 28*28)\n", 251 | " x = Binarize(x)\n", 252 | " x = self.conv1(x)\n", 253 | " x = self.mp1(x)\n", 254 | " x = self.bn1(x)\n", 255 | " x = self.htanh1(x)\n", 256 | " x = Binarize(x)\n", 257 | " \n", 258 | " x = self.conv2(x)\n", 259 | " x = self.mp2(x)\n", 260 | " x = self.bn2(x)\n", 261 | " x = self.htanh2(x)\n", 262 | " x = Binarize(x)\n", 263 | " \n", 264 | "# print(x.shape)\n", 265 | "# x = x.view(-1, 64*5*5)\n", 266 | " x = x.view(x.size(0), -1)\n", 267 | " \n", 268 | " \n", 269 | " x = self.fc1(x)\n", 270 | " x = self.bn3(x)\n", 271 | " x = self.htanh3(x)\n", 272 | " x = Binarize(x)\n", 273 | " \n", 274 | " x = self.fc2(x)\n", 275 | " \n", 276 | "\n", 277 | " return self.sm(x)\n", 278 | "\n", 279 | "model = Net()\n", 280 | "torch.cuda.device('cuda')\n", 281 | "model.cuda()\n", 282 | "\n", 283 | "\n", 284 | "criterion = nn.CrossEntropyLoss()\n", 285 | "optimizer = optim.Adam(model.parameters(), lr=0.01)\n", 286 | "\n", 287 | "\n", 288 | "def train(epoch):\n", 289 | " model.train()\n", 290 | " \n", 291 | " losses = []\n", 292 | " trainloader = tqdm(train_loader)\n", 293 | " \n", 294 | " for batch_idx, (data, target) in enumerate(trainloader):\n", 295 | " \n", 296 | " data, target = data.cuda(), target.cuda()\n", 297 | " data, target = Variable(data), Variable(target)\n", 298 | " optimizer.zero_grad()\n", 299 | " output = model(data)\n", 300 | " loss = criterion(output, target)\n", 301 | "\n", 302 | "# if epoch%40==0:\n", 303 | "# optimizer.param_groups[0]['lr']=optimizer.param_groups[0]['lr']*0.1\n", 304 | "\n", 305 | "# optimizer.zero_grad()\n", 306 | " \n", 307 | " loss.backward()\n", 308 | " \n", 309 | " for p in list(model.parameters()):\n", 310 | " if hasattr(p,'org'):\n", 311 | " p.data.copy_(p.org)\n", 312 | " optimizer.step()\n", 313 | " \n", 314 | " for p in list(model.parameters()):\n", 315 | " if hasattr(p,'org'):\n", 316 | " p.org.copy_(p.data.clamp_(-1,1))\n", 317 | " \n", 318 | " losses.append(loss.item())\n", 319 | " trainloader.set_postfix(loss=np.mean(losses), epoch=epoch)\n", 320 | "# if batch_idx % 10000 == 0:\n", 321 | "# print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n", 322 | "# epoch, batch_idx * len(data), len(train_loader.dataset),\n", 323 | "# 100. * batch_idx / len(train_loader), loss.item()))\n", 324 | "\n", 325 | "def test():\n", 326 | " model.eval()\n", 327 | " test_loss = 0\n", 328 | " correct = 0\n", 329 | " testloader = tqdm(test_loader)\n", 330 | " for data, target in testloader:\n", 331 | " data, target = data.cuda(), target.cuda()\n", 332 | " with torch.no_grad():\n", 333 | " data = Variable(data)\n", 334 | " target = Variable(target)\n", 335 | " output = model(data)\n", 336 | " test_loss += criterion(output, target).item() # sum up batch loss\n", 337 | " pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability\n", 338 | " correct += pred.eq(target.data.view_as(pred)).cpu().sum()\n", 339 | " \n", 340 | " \n", 341 | "\n", 342 | " testloader.set_postfix(loss=test_loss / len(test_loader.dataset),acc=str((100. 
*correct / len(test_loader.dataset)).numpy())+'%')\n", 343 | " \n", 344 | " test_loss /= len(test_loader.dataset)\n", 345 | " \n", 346 | " \n", 347 | "# print('\\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n", 348 | "# test_loss, correct, len(test_loader.dataset),\n", 349 | "# 100. * correct / len(test_loader.dataset)))\n", 350 | "\n", 351 | "\n", 352 | "for epoch in range(5):\n", 353 | " train(epoch)\n", 354 | " test()" 355 | ], 356 | "execution_count": 2, 357 | "outputs": [ 358 | { 359 | "output_type": "stream", 360 | "text": [ 361 | "100%|██████████| 469/469 [00:13<00:00, 34.80it/s, epoch=0, loss=1.6]\n", 362 | "100%|██████████| 79/79 [00:01<00:00, 47.50it/s, acc=89%, loss=0.0124]\n", 363 | "100%|██████████| 469/469 [00:11<00:00, 39.50it/s, epoch=1, loss=1.57]\n", 364 | "100%|██████████| 79/79 [00:01<00:00, 47.74it/s, acc=89%, loss=0.0124]\n", 365 | "100%|██████████| 469/469 [00:11<00:00, 39.42it/s, epoch=2, loss=1.57]\n", 366 | "100%|██████████| 79/79 [00:01<00:00, 47.80it/s, acc=89%, loss=0.0124]\n", 367 | "100%|██████████| 469/469 [00:11<00:00, 39.37it/s, epoch=3, loss=1.57]\n", 368 | "100%|██████████| 79/79 [00:01<00:00, 47.96it/s, acc=89%, loss=0.0124]\n", 369 | "100%|██████████| 469/469 [00:11<00:00, 39.30it/s, epoch=4, loss=1.57]\n", 370 | "100%|██████████| 79/79 [00:01<00:00, 48.08it/s, acc=89%, loss=0.0123]\n" 371 | ], 372 | "name": "stderr" 373 | } 374 | ] 375 | }, 376 | { 377 | "cell_type": "code", 378 | "metadata": { 379 | "id": "NWCZnU4KiEhH", 380 | "colab_type": "code", 381 | "colab": {} 382 | }, 383 | "source": [ 384 | "" 385 | ], 386 | "execution_count": 0, 387 | "outputs": [] 388 | } 389 | ] 390 | } -------------------------------------------------------------------------------- /Old Trials/bnn_cifar.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "pbinaryNN_cifar_v1.ipynb", 7 | "version": "0.3.2", 8 | "provenance": [], 9 | "collapsed_sections": [] 10 | }, 11 | "kernelspec": { 12 | "name": "python3", 13 | "display_name": "Python 3" 14 | }, 15 | "accelerator": "GPU" 16 | }, 17 | "cells": [ 18 | { 19 | "cell_type": "code", 20 | "metadata": { 21 | "id": "3YM3-zGxLa63", 22 | "colab_type": "code", 23 | "colab": {} 24 | }, 25 | "source": [ 26 | "import torch\n", 27 | "import pdb\n", 28 | "import torch.nn as nn\n", 29 | "import math\n", 30 | "from torch.autograd import Variable\n", 31 | "from torch.autograd import Function\n", 32 | "import time\n", 33 | "\n", 34 | "import numpy as np\n", 35 | "\n", 36 | "\n", 37 | "def Binarize(tensor,quant_mode='det'):\n", 38 | " if quant_mode=='det':\n", 39 | " return tensor.sign()\n", 40 | " else:\n", 41 | " return tensor.add_(1).div_(2).add_(torch.rand(tensor.size()).add(-0.5)).clamp_(0,1).round().mul_(2).add_(-1)\n", 42 | "\n", 43 | "\n", 44 | "\n", 45 | "\n", 46 | "class HingeLoss(nn.Module):\n", 47 | " def __init__(self):\n", 48 | " super(HingeLoss,self).__init__()\n", 49 | " self.margin=1.0\n", 50 | "\n", 51 | " def hinge_loss(self,input,target):\n", 52 | " #import pdb; pdb.set_trace()\n", 53 | " output=self.margin-input.mul(target)\n", 54 | " output[output.le(0)]=0\n", 55 | " return output.mean()\n", 56 | "\n", 57 | " def forward(self, input, target):\n", 58 | " return self.hinge_loss(input,target)\n", 59 | "\n", 60 | "class SqrtHingeLossFunction(Function):\n", 61 | " def __init__(self):\n", 62 | " super(SqrtHingeLossFunction,self).__init__()\n", 63 | " self.margin=1.0\n", 64 
| "\n", 65 | " def forward(self, input, target):\n", 66 | " output=self.margin-input.mul(target)\n", 67 | " output[output.le(0)]=0\n", 68 | " self.save_for_backward(input, target)\n", 69 | " loss=output.mul(output).sum(0).sum(1).div(target.numel())\n", 70 | " return loss\n", 71 | "\n", 72 | " def backward(self,grad_output):\n", 73 | " input, target = self.saved_tensors\n", 74 | " output=self.margin-input.mul(target)\n", 75 | " output[output.le(0)]=0\n", 76 | " import pdb; pdb.set_trace()\n", 77 | " grad_output.resize_as_(input).copy_(target).mul_(-2).mul_(output)\n", 78 | " grad_output.mul_(output.ne(0).float())\n", 79 | " grad_output.div_(input.numel())\n", 80 | " return grad_output,grad_output\n", 81 | "\n", 82 | "def Quantize(tensor,quant_mode='det', params=None, numBits=8):\n", 83 | " tensor.clamp_(-2**(numBits-1),2**(numBits-1))\n", 84 | " if quant_mode=='det':\n", 85 | " tensor=tensor.mul(2**(numBits-1)).round().div(2**(numBits-1))\n", 86 | " else:\n", 87 | " tensor=tensor.mul(2**(numBits-1)).round().add(torch.rand(tensor.size()).add(-0.5)).div(2**(numBits-1))\n", 88 | " quant_fixed(tensor, params)\n", 89 | " return tensor\n", 90 | "\n", 91 | "import torch.nn._functions as tnnf\n", 92 | "\n", 93 | "\n", 94 | "class BinarizeLinear(nn.Linear):\n", 95 | "\n", 96 | " def __init__(self, *kargs, **kwargs):\n", 97 | " super(BinarizeLinear, self).__init__(*kargs, **kwargs)\n", 98 | "\n", 99 | " def forward(self, input):\n", 100 | "\n", 101 | " if input.size(1) != 784:\n", 102 | " input.data=Binarize(input.data)\n", 103 | " if not hasattr(self.weight,'org'):\n", 104 | " self.weight.org=self.weight.data.clone()\n", 105 | " self.weight.data=Binarize(self.weight.org)\n", 106 | " out = nn.functional.linear(input, self.weight)\n", 107 | " if not self.bias is None:\n", 108 | " self.bias.org=self.bias.data.clone()\n", 109 | " out += self.bias.view(1, -1).expand_as(out)\n", 110 | "\n", 111 | " return out\n", 112 | "\n", 113 | "class BinarizeConv2d(nn.Conv2d):\n", 114 | "\n", 115 | " def __init__(self, *kargs, **kwargs):\n", 116 | " super(BinarizeConv2d, self).__init__(*kargs, **kwargs)\n", 117 | "\n", 118 | "\n", 119 | " def forward(self, input):\n", 120 | " if input.size(1) != 3:\n", 121 | " input.data = Binarize(input.data)\n", 122 | " if not hasattr(self.weight,'org'):\n", 123 | " self.weight.org=self.weight.data.clone()\n", 124 | " self.weight.data=Binarize(self.weight.org)\n", 125 | "\n", 126 | " out = nn.functional.conv2d(input, self.weight, None, self.stride,\n", 127 | " self.padding, self.dilation, self.groups)\n", 128 | "\n", 129 | " if not self.bias is None:\n", 130 | " self.bias.org=self.bias.data.clone()\n", 131 | " out += self.bias.view(1, -1, 1, 1).expand_as(out)\n", 132 | "\n", 133 | " return out\n" 134 | ], 135 | "execution_count": 0, 136 | "outputs": [] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "metadata": { 141 | "id": "ggHwTvf4MRqC", 142 | "colab_type": "code", 143 | "outputId": "94f11deb-6521-4233-ef37-501cb0ba8799", 144 | "colab": { 145 | "base_uri": "https://localhost:8080/", 146 | "height": 238 147 | } 148 | }, 149 | "source": [ 150 | "\n", 151 | "%%time \n", 152 | "from __future__ import print_function\n", 153 | "import argparse\n", 154 | "import torch\n", 155 | "import torch.nn as nn\n", 156 | "import torch.nn.functional as F\n", 157 | "import torch.optim as optim\n", 158 | "from torchvision import datasets, transforms\n", 159 | "from torch.autograd import Variable\n", 160 | "from tqdm import tqdm\n", 161 | "\n", 162 | "# from models.binarized_modules import 
BinarizeLinear,BinarizeConv2d\n", 163 | "# from models.binarized_modules import Binarize,Ternarize,Ternarize2,Ternarize3,Ternarize4,HingeLoss\n", 164 | "# Training settings\n", 165 | "# parser = argparse.ArgumentParser(description='PyTorch MNIST Example')\n", 166 | "# parser.add_argument('--batch-size', type=int, default=64, metavar='N',\n", 167 | "# help='input batch size for training (default: 256)')\n", 168 | "# parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',\n", 169 | "# help='input batch size for testing (default: 1000)')\n", 170 | "# parser.add_argument('--epochs', type=int, default=100, metavar='N',\n", 171 | "# help='number of epochs to train (default: 10)')\n", 172 | "# parser.add_argument('--lr', type=float, default=0.01, metavar='LR',\n", 173 | "# help='learning rate (default: 0.001)')\n", 174 | "# parser.add_argument('--momentum', type=float, default=0.5, metavar='M',\n", 175 | "# help='SGD momentum (default: 0.5)')\n", 176 | "# parser.add_argument('--no-cuda', action='store_true', default=False,\n", 177 | "# help='disables CUDA training')\n", 178 | "# parser.add_argument('--seed', type=int, default=1, metavar='S',\n", 179 | "# help='random seed (default: 1)')\n", 180 | "# parser.add_argument('--gpus', default=3,\n", 181 | "# help='gpus used for training - e.g 0,1,3')\n", 182 | "# parser.add_argument('--log-interval', type=int, default=10, metavar='N',\n", 183 | "# help='how many batches to wait before logging training status')\n", 184 | "# args = parser.parse_args()\n", 185 | "# args.cuda = not args.no_cuda and torch.cuda.is_available()\n", 186 | "\n", 187 | "torch.manual_seed(1)\n", 188 | "# if args.cuda:\n", 189 | "# torch.cuda.manual_seed(args.seed)\n", 190 | "\n", 191 | "\n", 192 | "# kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}\n", 193 | "train_loader = torch.utils.data.DataLoader(\n", 194 | " datasets.CIFAR10('../data', train=True, download=True,\n", 195 | " transform=transforms.Compose([\n", 196 | "# transforms.RandomHorizontalFlip(),\n", 197 | " transforms.ToTensor(),\n", 198 | " transforms.Normalize((0.1307,), (0.3081,)),\n", 199 | " \n", 200 | " ])),\n", 201 | " batch_size=64, shuffle=True)\n", 202 | "test_loader = torch.utils.data.DataLoader(\n", 203 | " datasets.CIFAR10('../data', train=False, transform=transforms.Compose([\n", 204 | "# transforms.RandomHorizontalFlip(),\n", 205 | "# transforms.RandomCrop(32,padding=4),\n", 206 | " transforms.ToTensor(),\n", 207 | " transforms.Normalize((0.1307,), (0.3081,))\n", 208 | " ])),\n", 209 | " batch_size=64, shuffle=True)\n", 210 | "\n", 211 | "\n", 212 | "class Net(nn.Module):\n", 213 | " def __init__(self):\n", 214 | " super(Net, self).__init__()\n", 215 | " self.infl_ratio=3\n", 216 | " self.fc1 = BinarizeLinear(1024, 2048*self.infl_ratio)\n", 217 | " self.htanh1 = nn.Hardtanh()\n", 218 | " self.bn1 = nn.BatchNorm1d(2048*self.infl_ratio)\n", 219 | " \n", 220 | " self.fc2 = BinarizeLinear(2048*self.infl_ratio, 2048*self.infl_ratio)\n", 221 | " self.htanh2 = nn.Hardtanh()\n", 222 | " self.bn2 = nn.BatchNorm1d(2048*self.infl_ratio)\n", 223 | " self.fc3 = BinarizeLinear(2048*self.infl_ratio, 2048*self.infl_ratio)\n", 224 | " self.htanh3 = nn.Hardtanh()\n", 225 | " self.bn3 = nn.BatchNorm1d(2048*self.infl_ratio)\n", 226 | " self.fc4 = nn.Linear(2048*self.infl_ratio, 10)\n", 227 | " \n", 228 | " \n", 229 | " self.logsoftmax = nn.LogSoftmax(dim=1)\n", 230 | " self.drop=nn.Dropout(0.5)\n", 231 | " \n", 232 | " \n", 233 | " self.b1 = BinarizeConv2d(3, 128*self.infl_ratio, 
kernel_size=3, stride=1, padding=1,\n", 234 | " bias=True)\n", 235 | " self.b2 = nn.BatchNorm2d(128*self.infl_ratio)\n", 236 | " self.b3 = nn.Hardtanh(inplace=True)\n", 237 | " \n", 238 | " self.b4 = BinarizeConv2d(128*self.infl_ratio, 512, kernel_size=3, padding=1, bias=True)\n", 239 | " self.b5 = nn.MaxPool2d(kernel_size=2, stride=2)\n", 240 | " self.b6 = nn.BatchNorm2d(512)\n", 241 | " self.b7 = nn.Hardtanh(inplace=True)\n", 242 | " \n", 243 | " \n", 244 | " self.bb1 = BinarizeLinear(512 * 16 * 16, 1024, bias=True)\n", 245 | " self.bb2 = nn.BatchNorm1d(1024)\n", 246 | " self.bb3 = nn.Hardtanh(inplace=True)\n", 247 | " self.bb4 = BinarizeLinear(1024, 10, bias=True)\n", 248 | " self.bb5 = nn.BatchNorm1d(10, affine=False)\n", 249 | " self.bb6 = nn.LogSoftmax(dim=1)\n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | "\n", 254 | " def forward(self, x):\n", 255 | " \n", 256 | " x = self.b1(x)\n", 257 | " x = self.b2(x)\n", 258 | " x = self.b3(x)\n", 259 | " x = self.b4(x)\n", 260 | " x = self.b5(x)\n", 261 | " x = self.b6(x)\n", 262 | " x = self.b7(x)\n", 263 | " \n", 264 | "# print(x.shape)\n", 265 | " x = x.view(-1, 512 * 16 * 16)\n", 266 | " \n", 267 | " x = self.bb1(x)\n", 268 | " x = self.bb2(x)\n", 269 | " x = self.bb3(x)\n", 270 | " x = self.bb4(x)\n", 271 | " x = self.bb5(x)\n", 272 | " return self.bb6(x)\n", 273 | " \n", 274 | " \n", 275 | " \n", 276 | "# x = self.fc1(x)\n", 277 | "# x = self.bn1(x)\n", 278 | "# x = self.htanh1(x)\n", 279 | " \n", 280 | "# x = self.fc2(x)\n", 281 | "# x = self.bn2(x)\n", 282 | "# x = self.htanh2(x)\n", 283 | "# x = self.fc3(x)\n", 284 | "# x = self.drop(x)\n", 285 | "# x = self.bn3(x)\n", 286 | "# x = self.htanh3(x)\n", 287 | "# x = self.fc4(x)\n", 288 | "# return self.logsoftmax(x)\n", 289 | "\n", 290 | "model = Net()\n", 291 | "torch.cuda.device('gpu')\n", 292 | "model.cuda()\n", 293 | "\n", 294 | "\n", 295 | "criterion = nn.CrossEntropyLoss()\n", 296 | "optimizer = optim.Adam(model.parameters(), lr=0.001)\n", 297 | "\n", 298 | "\n", 299 | "def train(epoch):\n", 300 | " model.train()\n", 301 | " \n", 302 | " losses = []\n", 303 | " trainloader = tqdm(train_loader)\n", 304 | " \n", 305 | " for batch_idx, (data, target) in enumerate(trainloader):\n", 306 | " \n", 307 | " data, target = data.cuda(), target.cuda()\n", 308 | " data, target = Variable(data), Variable(target)\n", 309 | " optimizer.zero_grad()\n", 310 | " output = model(data)\n", 311 | " loss = criterion(output, target)\n", 312 | "\n", 313 | "# if epoch%40==0:\n", 314 | "# optimizer.param_groups[0]['lr']=optimizer.param_groups[0]['lr']*0.1\n", 315 | "\n", 316 | "# optimizer.zero_grad()\n", 317 | " \n", 318 | " loss.backward()\n", 319 | " \n", 320 | " for p in list(model.parameters()):\n", 321 | " if hasattr(p,'org'):\n", 322 | " p.data.copy_(p.org)\n", 323 | " optimizer.step()\n", 324 | " \n", 325 | " for p in list(model.parameters()):\n", 326 | " if hasattr(p,'org'):\n", 327 | " p.org.copy_(p.data.clamp_(-1,1))\n", 328 | " \n", 329 | " losses.append(loss.item())\n", 330 | " trainloader.set_postfix(loss=np.mean(losses), epoch=epoch)\n", 331 | "# if batch_idx % 10000 == 0:\n", 332 | "# print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n", 333 | "# epoch, batch_idx * len(data), len(train_loader.dataset),\n", 334 | "# 100. 
* batch_idx / len(train_loader), loss.item()))\n", 335 | "\n", 336 | "def test():\n", 337 | " model.eval()\n", 338 | " test_loss = 0\n", 339 | " correct = 0\n", 340 | " testloader = tqdm(test_loader)\n", 341 | " for data, target in testloader:\n", 342 | " data, target = data.cuda(), target.cuda()\n", 343 | " with torch.no_grad():\n", 344 | " data = Variable(data)\n", 345 | " target = Variable(target)\n", 346 | " output = model(data)\n", 347 | " test_loss += criterion(output, target).item() # sum up batch loss\n", 348 | " pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability\n", 349 | " correct += pred.eq(target.data.view_as(pred)).cpu().sum()\n", 350 | " \n", 351 | " \n", 352 | "\n", 353 | " testloader.set_postfix(loss=test_loss / len(test_loader.dataset),acc=str((100. *correct / len(test_loader.dataset)).numpy())+'%')\n", 354 | " \n", 355 | " test_loss /= len(test_loader.dataset)\n", 356 | " \n", 357 | " \n", 358 | "# print('\\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n", 359 | "# test_loss, correct, len(test_loader.dataset),\n", 360 | "# 100. * correct / len(test_loader.dataset)))\n", 361 | "\n", 362 | "\n", 363 | "for epoch in range(5):\n", 364 | " train(epoch)\n", 365 | " test()" 366 | ], 367 | "execution_count": 0, 368 | "outputs": [ 369 | { 370 | "output_type": "stream", 371 | "text": [ 372 | "Files already downloaded and verified\n" 373 | ], 374 | "name": "stdout" 375 | }, 376 | { 377 | "output_type": "stream", 378 | "text": [ 379 | "100%|██████████| 782/782 [03:43<00:00, 4.30it/s, epoch=0, loss=1.49]\n", 380 | "100%|██████████| 157/157 [00:16<00:00, 9.32it/s, acc=59%, loss=0.0208]\n", 381 | "100%|██████████| 782/782 [03:43<00:00, 4.28it/s, epoch=1, loss=1.18]\n", 382 | "100%|██████████| 157/157 [00:16<00:00, 9.36it/s, acc=64%, loss=0.0186]\n", 383 | "100%|██████████| 782/782 [03:44<00:00, 4.29it/s, epoch=2, loss=1.06]\n", 384 | "100%|██████████| 157/157 [00:16<00:00, 9.34it/s, acc=63%, loss=0.0186]\n", 385 | "100%|██████████| 782/782 [03:55<00:00, 4.28it/s, epoch=3, loss=0.961]\n", 386 | "100%|██████████| 157/157 [00:16<00:00, 9.36it/s, acc=66%, loss=0.0183]\n", 387 | "100%|██████████| 782/782 [03:43<00:00, 4.28it/s, epoch=4, loss=0.851]\n", 388 | "100%|██████████| 157/157 [00:16<00:00, 9.36it/s, acc=64%, loss=0.0187]" 389 | ], 390 | "name": "stderr" 391 | }, 392 | { 393 | "output_type": "stream", 394 | "text": [ 395 | "CPU times: user 11min 30s, sys: 8min 26s, total: 19min 57s\n", 396 | "Wall time: 20min 19s\n" 397 | ], 398 | "name": "stdout" 399 | }, 400 | { 401 | "output_type": "stream", 402 | "text": [ 403 | "\n" 404 | ], 405 | "name": "stderr" 406 | } 407 | ] 408 | }, 409 | { 410 | "cell_type": "code", 411 | "metadata": { 412 | "id": "-d-bkux8OYrg", 413 | "colab_type": "code", 414 | "colab": {} 415 | }, 416 | "source": [ 417 | "" 418 | ], 419 | "execution_count": 0, 420 | "outputs": [] 421 | } 422 | ] 423 | } 424 | -------------------------------------------------------------------------------- /Old Trials/pbnn_mnist_v2.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "pbnn_mnist_v2.ipynb", 7 | "version": "0.3.2", 8 | "provenance": [], 9 | "collapsed_sections": [] 10 | }, 11 | "kernelspec": { 12 | "name": "python3", 13 | "display_name": "Python 3" 14 | }, 15 | "accelerator": "GPU" 16 | }, 17 | "cells": [ 18 | { 19 | "cell_type": "code", 20 | "metadata": { 21 | "id": "3YM3-zGxLa63", 22 
| "colab_type": "code", 23 | "colab": {} 24 | }, 25 | "source": [ 26 | "import torch\n", 27 | "import pdb\n", 28 | "import torch.nn as nn\n", 29 | "import math\n", 30 | "from torch.autograd import Variable\n", 31 | "from torch.autograd import Function\n", 32 | "import time\n", 33 | "\n", 34 | "import numpy as np\n", 35 | "\n", 36 | "\n", 37 | "def Binarize(tensor,quant_mode='det'):\n", 38 | " if quant_mode=='det':\n", 39 | " return tensor.sign()\n", 40 | " else:\n", 41 | " return tensor.add_(1).div_(2).add_(torch.rand(tensor.size()).add(-0.5)).clamp_(0,1).round().mul_(2).add_(-1)\n", 42 | "\n", 43 | "\n", 44 | "\n", 45 | "\n", 46 | "class HingeLoss(nn.Module):\n", 47 | " def __init__(self):\n", 48 | " super(HingeLoss,self).__init__()\n", 49 | " self.margin=1.0\n", 50 | "\n", 51 | " def hinge_loss(self,input,target):\n", 52 | " #import pdb; pdb.set_trace()\n", 53 | " output=self.margin-input.mul(target)\n", 54 | " output[output.le(0)]=0\n", 55 | " return output.mean()\n", 56 | "\n", 57 | " def forward(self, input, target):\n", 58 | " return self.hinge_loss(input,target)\n", 59 | "\n", 60 | "class SqrtHingeLossFunction(Function):\n", 61 | " def __init__(self):\n", 62 | " super(SqrtHingeLossFunction,self).__init__()\n", 63 | " self.margin=1.0\n", 64 | "\n", 65 | " def forward(self, input, target):\n", 66 | " output=self.margin-input.mul(target)\n", 67 | " output[output.le(0)]=0\n", 68 | " self.save_for_backward(input, target)\n", 69 | " loss=output.mul(output).sum(0).sum(1).div(target.numel())\n", 70 | " return loss\n", 71 | "\n", 72 | " def backward(self,grad_output):\n", 73 | " input, target = self.saved_tensors\n", 74 | " output=self.margin-input.mul(target)\n", 75 | " output[output.le(0)]=0\n", 76 | " import pdb; pdb.set_trace()\n", 77 | " grad_output.resize_as_(input).copy_(target).mul_(-2).mul_(output)\n", 78 | " grad_output.mul_(output.ne(0).float())\n", 79 | " grad_output.div_(input.numel())\n", 80 | " return grad_output,grad_output\n", 81 | "\n", 82 | "def Quantize(tensor,quant_mode='det', params=None, numBits=8):\n", 83 | " tensor.clamp_(-2**(numBits-1),2**(numBits-1))\n", 84 | " if quant_mode=='det':\n", 85 | " tensor=tensor.mul(2**(numBits-1)).round().div(2**(numBits-1))\n", 86 | " else:\n", 87 | " tensor=tensor.mul(2**(numBits-1)).round().add(torch.rand(tensor.size()).add(-0.5)).div(2**(numBits-1))\n", 88 | " quant_fixed(tensor, params)\n", 89 | " return tensor\n", 90 | "\n", 91 | "import torch.nn._functions as tnnf\n", 92 | "\n", 93 | "\n", 94 | "class BinarizeLinear(nn.Linear):\n", 95 | "\n", 96 | " def __init__(self, *kargs, **kwargs):\n", 97 | " super(BinarizeLinear, self).__init__(*kargs, **kwargs)\n", 98 | "\n", 99 | " def forward(self, input):\n", 100 | "\n", 101 | " if input.size(1) != 784:\n", 102 | " input.data=Binarize(input.data)\n", 103 | " if not hasattr(self.weight,'org'):\n", 104 | " self.weight.org=self.weight.data.clone()\n", 105 | " self.weight.data=Binarize(self.weight.org)\n", 106 | " out = nn.functional.linear(input, self.weight)\n", 107 | " if not self.bias is None:\n", 108 | " self.bias.org=self.bias.data.clone()\n", 109 | " out += self.bias.view(1, -1).expand_as(out)\n", 110 | "\n", 111 | " return out\n", 112 | "\n", 113 | "class BinarizeConv2d(nn.Conv2d):\n", 114 | "\n", 115 | " def __init__(self, *kargs, **kwargs):\n", 116 | " super(BinarizeConv2d, self).__init__(*kargs, **kwargs)\n", 117 | "\n", 118 | "\n", 119 | " def forward(self, input):\n", 120 | " if input.size(1) != 3:\n", 121 | " input.data = Binarize(input.data)\n", 122 | " if not 
hasattr(self.weight,'org'):\n", 123 | " self.weight.org=self.weight.data.clone()\n", 124 | " self.weight.data=Binarize(self.weight.org)\n", 125 | "\n", 126 | " out = nn.functional.conv2d(input, self.weight, None, self.stride,\n", 127 | " self.padding, self.dilation, self.groups)\n", 128 | "\n", 129 | " if not self.bias is None:\n", 130 | " self.bias.org=self.bias.data.clone()\n", 131 | " out += self.bias.view(1, -1, 1, 1).expand_as(out)\n", 132 | "\n", 133 | " return out\n", 134 | " \n", 135 | "\n", 136 | " \n", 137 | "class Preactivation1(nn.Linear):\n", 138 | "\n", 139 | " def __init__(self, *kargs, **kwargs):\n", 140 | " super(Preactivation1, self).__init__(*kargs, **kwargs)\n", 141 | "\n", 142 | " def forward(self, input):\n", 143 | "\n", 144 | " if input.size(1) != 784:\n", 145 | " input.data=Binarize(input.data)\n", 146 | " if not hasattr(self.weight,'org'):\n", 147 | " self.weight.org=self.weight.data.clone()\n", 148 | " self.weight.data=Binarize(self.weight.org)\n", 149 | " \n", 150 | " p = self.weight.sigmoid()\n", 151 | " mu = self.weight*p\n", 152 | " out1 = nn.functional.linear(input,mu)\n", 153 | " \n", 154 | "# if not self.bias is None:\n", 155 | "# self.bias.org=self.bias.data.clone()\n", 156 | "# out += self.bias.view(1, -1).expand_as(out)\n", 157 | "\n", 158 | " return out1 \n", 159 | " \n", 160 | "class Preactivation2(nn.Linear):\n", 161 | "\n", 162 | " def __init__(self, *kargs, **kwargs):\n", 163 | " super(Preactivation2, self).__init__(*kargs, **kwargs)\n", 164 | "\n", 165 | " def forward(self, input):\n", 166 | "\n", 167 | " if input.size(1) != 784:\n", 168 | " input.data=Binarize(input.data)\n", 169 | " if not hasattr(self.weight,'org'):\n", 170 | " self.weight.org=self.weight.data.clone()\n", 171 | " self.weight.data=Binarize(self.weight.org)\n", 172 | " \n", 173 | " p = self.weight.sigmoid()\n", 174 | " mu = self.weight*p\n", 175 | " var = (self.weight - mu)**2*mu\n", 176 | " out2 = nn.functional.linear(input**2,var)\n", 177 | "# if not self.bias is None:\n", 178 | "# self.bias.org=self.bias.data.clone()\n", 179 | "# out += self.bias.view(1, -1).expand_as(out)\n", 180 | "\n", 181 | " return out2 \n", 182 | " \n", 183 | "\n", 184 | "\n", 185 | "# if not self.bias is None:\n", 186 | "# self.bias.org=self.bias.data.clone()\n", 187 | "# out += self.bias.view(1, -1, 1, 1).expand_as(out)\n", 188 | "\n", 189 | "# return input\n", 190 | " \n", 191 | "\n", 192 | "class BinarizeConv1d(nn.Conv1d):\n", 193 | "\n", 194 | " def __init__(self, *kargs, **kwargs):\n", 195 | " super(BinarizeConv1d, self).__init__(*kargs, **kwargs)\n", 196 | "\n", 197 | "\n", 198 | " def forward(self, input):\n", 199 | " if input.size(1) != 3:\n", 200 | " input.data = Binarize(input.data)\n", 201 | " if not hasattr(self.weight,'org'):\n", 202 | " self.weight.org=self.weight.data.clone()\n", 203 | " self.weight.data=Binarize(self.weight.org)\n", 204 | "\n", 205 | " out = nn.functional.conv1d(input, self.weight, None, self.stride,\n", 206 | " self.padding, self.dilation, self.groups)\n", 207 | " if not self.bias is None:\n", 208 | " self.bias.org=self.bias.data.clone()\n", 209 | " out += self.bias.view(1, -1).expand_as(out)\n", 210 | "\n", 211 | " return out\n" 212 | ], 213 | "execution_count": 0, 214 | "outputs": [] 215 | }, 216 | { 217 | "cell_type": "code", 218 | "metadata": { 219 | "id": "ggHwTvf4MRqC", 220 | "colab_type": "code", 221 | "outputId": "bf28c7fe-ab8d-4462-f31c-85b9ae9e42a0", 222 | "colab": { 223 | "base_uri": "https://localhost:8080/", 224 | "height": 191 225 | } 226 | }, 227 | 
"source": [ 228 | "from __future__ import print_function\n", 229 | "import argparse\n", 230 | "import torch\n", 231 | "import torch.nn as nn\n", 232 | "import torch.nn.functional as F\n", 233 | "import torch.optim as optim\n", 234 | "from torchvision import datasets, transforms\n", 235 | "from torch.autograd import Variable\n", 236 | "from tqdm import tqdm\n", 237 | "from torch.distributions.normal import Normal\n", 238 | "from torch.distributions.relaxed_bernoulli import RelaxedBernoulli\n", 239 | "\n", 240 | "torch.manual_seed(1)\n", 241 | "# if args.cuda:\n", 242 | "# torch.cuda.manual_seed(args.seed)\n", 243 | "\n", 244 | "\n", 245 | "# kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}\n", 246 | "train_loader = torch.utils.data.DataLoader(\n", 247 | " datasets.MNIST('../data', train=True, download=True,\n", 248 | " transform=transforms.Compose([\n", 249 | " transforms.ToTensor(),\n", 250 | " transforms.Normalize((0.1307,), (0.3081,))\n", 251 | " ])),\n", 252 | " batch_size=128, shuffle=True)\n", 253 | "test_loader = torch.utils.data.DataLoader(\n", 254 | " datasets.MNIST('../data', train=False, transform=transforms.Compose([\n", 255 | " transforms.ToTensor(),\n", 256 | " transforms.Normalize((0.1307,), (0.3081,))\n", 257 | " ])),\n", 258 | " batch_size=128, shuffle=True)\n", 259 | "\n", 260 | "\n", 261 | "# 32C3 - MP2 - 64C3 - Mp2 - 512FC - SM10\n", 262 | "\n", 263 | "class Net(nn.Module):\n", 264 | " def __init__(self):\n", 265 | " super(Net, self).__init__()\n", 266 | " \n", 267 | "# self.pre1 = BinarizeLinear(784, 2)\n", 268 | "# self.pre2 = BinarizeLinear(784, 2)\n", 269 | "# self.pre3 = nn.BatchNorm1d(2)\n", 270 | "# self.pre4 = BinarizeLinear(2, 784)\n", 271 | " \n", 272 | " \n", 273 | " self.infl_ratio=3\n", 274 | " self.pre1 = Preactivation1(512,512)\n", 275 | " self.pre2 = Preactivation2(512,512)\n", 276 | " \n", 277 | " self.bn1 = nn.BatchNorm1d(784)\n", 278 | " \n", 279 | " \n", 280 | "# self.fc1 = BinarizeLinear(784, 2048*self.infl_ratio)\n", 281 | " \n", 282 | " \n", 283 | " self.conv1 = BinarizeConv2d(1, 32, kernel_size=3)\n", 284 | " self.mp1 = nn.MaxPool2d(kernel_size=2, stride=2)\n", 285 | " self.bn1 = nn.BatchNorm2d(32)\n", 286 | " self.htanh1 = nn.Hardtanh()\n", 287 | " \n", 288 | " self.conv2 = BinarizeConv2d(32, 64, kernel_size=3)\n", 289 | " self.mp2 = nn.MaxPool2d(kernel_size=2, stride=2)\n", 290 | " self.bn2 = nn.BatchNorm2d(64)\n", 291 | " self.htanh2 = nn.Hardtanh()\n", 292 | " \n", 293 | " self.fc1 = BinarizeLinear(64*5*5, 512)\n", 294 | " self.bn3 = nn.BatchNorm1d(512)\n", 295 | " self.htanh3 = nn.Hardtanh()\n", 296 | " \n", 297 | " self.fc2 = nn.Linear(512, 10)\n", 298 | " self.pre3 = Preactivation1(10,10)\n", 299 | " self.pre4 = Preactivation2(10,10)\n", 300 | " self.sm = nn.Softmax(dim=1)\n", 301 | " \n", 302 | "\n", 303 | " \n", 304 | " def sample(mu, log_sigma2):\n", 305 | " eps = torch.randn(mu.shape[0], mu.shape[1])\n", 306 | " return mu + torch.exp(log_sigma2 / 2) * eps\n", 307 | " \n", 308 | " \n", 309 | " def forward(self, x):\n", 310 | "# x = x.view(-1, 28*28) #input[128,784]\n", 311 | " \n", 312 | " \n", 313 | " \n", 314 | " \n", 315 | " x = x.view(-1, 1,28,28)\n", 316 | " \n", 317 | " \n", 318 | " x = Binarize(x)\n", 319 | " x = self.conv1(x)\n", 320 | " \n", 321 | " \n", 322 | " x = self.mp1(x)\n", 323 | " x = self.bn1(x)\n", 324 | " \n", 325 | "# x = self.htanh1(x)\n", 326 | "# x = Binarize(x)\n", 327 | " \n", 328 | " x = self.conv2(x)\n", 329 | " x = self.mp2(x)\n", 330 | " x = self.bn2(x)\n", 331 | "# x = 
self.htanh2(x)\n", 332 | "# x = Binarize(x)\n", 333 | " \n", 334 | " x = x.view(x.size(0), -1)\n", 335 | " \n", 336 | " \n", 337 | " x = self.fc1(x)\n", 338 | " \n", 339 | " mu = self.pre1(x) \n", 340 | " log_sigma2 = self.pre2(x)\n", 341 | " \n", 342 | "# eps = torch.randn(mu.shape[0], mu.shape[1])\n", 343 | " \n", 344 | "# z = mu + torch.exp(log_sigma2 / 2) * eps.cuda()\n", 345 | " \n", 346 | " \n", 347 | " m = mu.mean(1).unsqueeze(1)\n", 348 | " \n", 349 | "# print(m.shape)\n", 350 | " v = log_sigma2.var(1).unsqueeze(1)\n", 351 | " \n", 352 | " \n", 353 | " mu = 0.5*(mu-m)/(v+0.5).sqrt()+0.5\n", 354 | " log_sigma2 = 0.5**2*log_sigma2/(v+0.5)\n", 355 | " \n", 356 | " \n", 357 | " x = Normal(mu,log_sigma2)\n", 358 | " \n", 359 | " p = 1 - x.cdf(0)\n", 360 | "# print(x[1])\n", 361 | "\n", 362 | "# x = Binarize(x)\n", 363 | " m = RelaxedBernoulli(torch.tensor([1.]).cuda(),p)\n", 364 | " x = m.sample()\n", 365 | " \n", 366 | "# print(x[1])\n", 367 | " \n", 368 | "# print(x.shape)\n", 369 | "# x = Binarize(x)\n", 370 | "# print(x)\n", 371 | " \n", 372 | "# x = self.bn3(x)\n", 373 | " \n", 374 | " \n", 375 | "# x = self.htanh3(x)\n", 376 | " \n", 377 | "# print(x[0])\n", 378 | "# x = Binarize(x)\n", 379 | " \n", 380 | " x = self.fc2(x)\n", 381 | " \n", 382 | "# fm2 = self.pre3(x)\n", 383 | " \n", 384 | "# fv2 = self.pre4(x)\n", 385 | " \n", 386 | "# m2 = fm2.mean(1).unsqueeze(1)\n", 387 | " \n", 388 | "\n", 389 | "# v2 = fv2.var(1).unsqueeze(1)\n", 390 | " \n", 391 | " \n", 392 | "# x = 0.5*(fm2-m2)/(v2+0.5).sqrt()+0.5\n", 393 | " \n", 394 | " \n", 395 | " \n", 396 | "\n", 397 | " return self.sm(x)\n", 398 | " \n", 399 | "\n", 400 | "\n", 401 | "model = Net()\n", 402 | "torch.cuda.device('cuda')\n", 403 | "model.cuda()\n", 404 | "\n", 405 | "\n", 406 | "criterion = nn.CrossEntropyLoss()\n", 407 | "optimizer = optim.Adam(model.parameters(), lr=0.01)\n", 408 | "\n", 409 | "\n", 410 | "def train(epoch):\n", 411 | " model.train()\n", 412 | " \n", 413 | " losses = []\n", 414 | " trainloader = tqdm(train_loader)\n", 415 | " \n", 416 | " for batch_idx, (data, target) in enumerate(trainloader):\n", 417 | " \n", 418 | " data, target = data.cuda(), target.cuda()\n", 419 | " data, target = Variable(data), Variable(target)\n", 420 | " optimizer.zero_grad()\n", 421 | " output = model(data)\n", 422 | " loss = criterion(output, target)\n", 423 | "\n", 424 | "# if epoch%40==0:\n", 425 | "# optimizer.param_groups[0]['lr']=optimizer.param_groups[0]['lr']*0.1\n", 426 | "\n", 427 | "# optimizer.zero_grad()\n", 428 | " \n", 429 | " loss.backward()\n", 430 | " \n", 431 | " for p in list(model.parameters()):\n", 432 | " if hasattr(p,'org'):\n", 433 | " p.data.copy_(p.org)\n", 434 | " optimizer.step()\n", 435 | " \n", 436 | " for p in list(model.parameters()):\n", 437 | " if hasattr(p,'org'):\n", 438 | " p.org.copy_(p.data.clamp_(-1,1))\n", 439 | " \n", 440 | " losses.append(loss.item())\n", 441 | " trainloader.set_postfix(loss=np.mean(losses), epoch=epoch)\n", 442 | "\n", 443 | "\n", 444 | "\n", 445 | "def test():\n", 446 | " model.eval()\n", 447 | " test_loss = 0\n", 448 | " correct = 0\n", 449 | " testloader = tqdm(test_loader)\n", 450 | " for data, target in testloader:\n", 451 | " data, target = data.cuda(), target.cuda()\n", 452 | " with torch.no_grad():\n", 453 | " data = Variable(data)\n", 454 | " target = Variable(target)\n", 455 | " output = model(data)\n", 456 | " test_loss += criterion(output, target).item() # sum up batch loss\n", 457 | " pred = output.data.max(1, keepdim=True)[1] # get the index of the 
max log-probability\n", 458 | " correct += pred.eq(target.data.view_as(pred)).cpu().sum()\n", 459 | " \n", 460 | " \n", 461 | "\n", 462 | " testloader.set_postfix(loss=test_loss / len(test_loader.dataset),acc=str((100. *correct / len(test_loader.dataset)).numpy())+'%')\n", 463 | " \n", 464 | " test_loss /= len(test_loader.dataset)\n", 465 | " \n", 466 | " \n", 467 | "\n", 468 | "\n", 469 | "for epoch in range(5):\n", 470 | " train(epoch)\n", 471 | " test()" 472 | ], 473 | "execution_count": 0, 474 | "outputs": [ 475 | { 476 | "output_type": "stream", 477 | "text": [ 478 | "100%|██████████| 469/469 [00:10<00:00, 43.58it/s, epoch=0, loss=1.81]\n", 479 | "100%|██████████| 79/79 [00:01<00:00, 47.26it/s, acc=71%, loss=0.0139]\n", 480 | "100%|██████████| 469/469 [00:10<00:00, 44.15it/s, epoch=1, loss=1.75]\n", 481 | "100%|██████████| 79/79 [00:01<00:00, 46.77it/s, acc=72%, loss=0.0138]\n", 482 | "100%|██████████| 469/469 [00:12<00:00, 38.95it/s, epoch=2, loss=1.75]\n", 483 | "100%|██████████| 79/79 [00:02<00:00, 36.25it/s, acc=71%, loss=0.0138]\n", 484 | "100%|██████████| 469/469 [00:10<00:00, 43.25it/s, epoch=3, loss=1.75]\n", 485 | "100%|██████████| 79/79 [00:01<00:00, 46.26it/s, acc=72%, loss=0.0137]\n", 486 | "100%|██████████| 469/469 [00:10<00:00, 44.24it/s, epoch=4, loss=1.75]\n", 487 | "100%|██████████| 79/79 [00:01<00:00, 47.30it/s, acc=72%, loss=0.0137]\n" 488 | ], 489 | "name": "stderr" 490 | } 491 | ] 492 | }, 493 | { 494 | "cell_type": "code", 495 | "metadata": { 496 | "id": "bxgRGAoMbkCf", 497 | "colab_type": "code", 498 | "colab": {} 499 | }, 500 | "source": [ 501 | "" 502 | ], 503 | "execution_count": 0, 504 | "outputs": [] 505 | } 506 | ] 507 | } -------------------------------------------------------------------------------- /Old Trials/pbnn_mnist_v7.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "pbnn_mnist_v7.ipynb", 7 | "version": "0.3.2", 8 | "provenance": [], 9 | "collapsed_sections": [] 10 | }, 11 | "kernelspec": { 12 | "name": "python3", 13 | "display_name": "Python 3" 14 | }, 15 | "accelerator": "GPU" 16 | }, 17 | "cells": [ 18 | { 19 | "cell_type": "code", 20 | "metadata": { 21 | "id": "3YM3-zGxLa63", 22 | "colab_type": "code", 23 | "colab": {} 24 | }, 25 | "source": [ 26 | "import torch\n", 27 | "import pdb\n", 28 | "import torch.nn as nn\n", 29 | "import math\n", 30 | "from torch.autograd import Variable\n", 31 | "from torch.autograd import Function\n", 32 | "import time\n", 33 | "from torch.distributions.relaxed_bernoulli import RelaxedBernoulli,LogitRelaxedBernoulli\n", 34 | "import numpy as np\n", 35 | "\n", 36 | "\n", 37 | "def Binarize(tensor,quant_mode='det'):\n", 38 | " if quant_mode=='det':\n", 39 | " tensor = tensor.sign()\n", 40 | " tensor[tensor==0] = 1\n", 41 | " return tensor\n", 42 | " else:\n", 43 | " return tensor.add_(1).div_(2).add_(torch.rand(tensor.size()).add(-0.5)).clamp_(0,1).round().mul_(2).add_(-1)\n", 44 | "\n", 45 | "def sample2(mu, log_sigma2):\n", 46 | " eps = torch.randn_like(mu)\n", 47 | " s = mu + torch.exp(log_sigma2 / 2) * eps\n", 48 | " return s\n", 49 | " \n", 50 | " \n", 51 | "def sample_gumbel(shape, eps=1e-20):\n", 52 | " unif = torch.rand(*shape).cuda()\n", 53 | " g = -torch.log(-torch.log(unif + eps))\n", 54 | " return g\n", 55 | "\n", 56 | "def sample_gumbel_softmax(logits, temperature):\n", 57 | " \"\"\"\n", 58 | " Input:\n", 59 | " logits: Tensor of log probs, shape = BS x k\n", 60 | " 
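For readability, a behavior-preserving restatement of the `Binarize` helper defined above (no in-place ops): deterministic mode takes the sign with exact zeros sent to +1, and stochastic mode rounds probabilistically so that the expectation of the output matches an input in [-1, 1].

```python
# Equivalent, non-in-place restatement of Binarize (a sketch, not a replacement).
import torch

def binarize(t, mode="det"):
    if mode == "det":
        out = t.sign()
        out[out == 0] = 1                    # sign() maps 0 to 0; force +1
        return out
    # stochastic: P(+1) = (t + 1) / 2 for t in [-1, 1], via additive uniform noise
    p = (t + 1) / 2
    return (p + torch.rand_like(t) - 0.5).clamp(0, 1).round() * 2 - 1
```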
temperature = scalar\n", 61 | " \n", 62 | " Output: Tensor of values sampled from Gumbel softmax.\n", 63 | " These will tend towards a one-hot representation in the limit of temp -> 0\n", 64 | " shape = BS x k\n", 65 | " \"\"\"\n", 66 | " g = sample_gumbel(logits.shape)\n", 67 | " h = (g + logits)/temperature\n", 68 | " h_max = h.max(dim=-1, keepdim=True)[0]\n", 69 | " h = h - h_max\n", 70 | " cache = torch.exp(h)\n", 71 | " y = cache / cache.sum(dim=-1, keepdim=True)\n", 72 | " return y\n", 73 | " \n", 74 | "def sampling(mu,sig):\n", 75 | " x = Normal(mu,sig)\n", 76 | "# x = x.sample(torch.tensor([out_features]))\n", 77 | "# print(x.cdf)\n", 78 | " p = 1 - x.cdf(0)\n", 79 | "# print((x.cdf(0))[0])\n", 80 | "# p = Binarize(p)\n", 81 | "# print(p[0])\n", 82 | "# a = ((p+1)/2).bernoulli()\n", 83 | "# a = a*2-1\n", 84 | "# # print(a[0])\n", 85 | "# a = torch.nn.functional.gumbel_softmax(p, tau=1, hard=True, eps=1e-10, dim=-1)\n", 86 | "# \n", 87 | "# l = LogitRelaxedBernoulli(torch.tensor([1.]).cuda(),p)\n", 88 | "# l = l.sample()\n", 89 | "# a = sample_gumbel_softmax(p,1.0)\n", 90 | "# print(x[0]) \n", 91 | " return p\n", 92 | "\n", 93 | "\n", 94 | "\n", 95 | "import torch.nn._functions as tnnf\n", 96 | "\n", 97 | "\n", 98 | "\n", 99 | " \n", 100 | "class PBinarizeLinear(nn.Linear):\n", 101 | "\n", 102 | " def __init__(self, *kargs, **kwargs):\n", 103 | " super(PBinarizeLinear, self).__init__(*kargs, **kwargs)\n", 104 | "# w = torch.empty_like(self.weight)\n", 105 | "# self.weight.data = nn.init.uniform_(w,-1,1)\n", 106 | "# theta.requires_grad_\n", 107 | "# self.weight.data = ((theta+1)/2).bernoulli()\n", 108 | "# self.weight.data = Binarize(self.weight.data-0.5)\n", 109 | "# self.weight.data = Binarize(theta)\n", 110 | " \n", 111 | "\n", 112 | " def forward(self, input):\n", 113 | "# print(input.data[0])\n", 114 | " \n", 115 | " \n", 116 | " if not hasattr(self.weight,'org'):\n", 117 | " self.weight.org=self.weight.data.clone() \n", 118 | " \n", 119 | " self.weight.data=Binarize(self.weight.org)\n", 120 | "# print(self.weight.data)\n", 121 | "# print(self.weight.org)\n", 122 | "# theta = self.weight\n", 123 | " theta = torch.tanh(self.weight)\n", 124 | "# print(theta)\n", 125 | "# print(input[0])\n", 126 | " \n", 127 | "\n", 128 | "# print(input[0])\n", 129 | " if input.size(1) != 784:\n", 130 | " mu = nn.functional.linear(input,theta)\n", 131 | " left = input**2 - (1- input**2)\n", 132 | " right = theta**2 - (1-theta**2)\n", 133 | " sigma = 1 - nn.functional.linear(left,right)\n", 134 | " else:\n", 135 | "# print((input**2)[0])\n", 136 | "# print((1-(theta**2))[0])\n", 137 | " mu = nn.functional.linear(input,theta) \n", 138 | " sigma = nn.functional.linear(input**2,1-(theta**2))\n", 139 | " \n", 140 | "# \n", 141 | "# print(mu.shape)\n", 142 | " m = mu.mean(0,True)\n", 143 | " \n", 144 | " v = sigma.var(0,True)\n", 145 | " \n", 146 | " mu = 0.5*(mu-m)/((v+(0.0001)).sqrt()+0.5)\n", 147 | " sigma = 0.5**2*sigma/(v+0.0001)\n", 148 | "\n", 149 | " \n", 150 | " \n", 151 | " out1 = sampling(mu,sigma)\n", 152 | "\n", 153 | " if self.out_features==10:\n", 154 | " return mu\n", 155 | " else:\n", 156 | " return out1\n", 157 | "\n", 158 | "\n", 159 | "\n", 160 | "\n", 161 | "class PBinarizeConv2d(nn.Conv2d):\n", 162 | "\n", 163 | " def __init__(self, *kargs, **kwargs):\n", 164 | " super(PBinarizeConv2d, self).__init__(*kargs, **kwargs)\n", 165 | " \n", 166 | "\n", 167 | " def forward(self, input):\n", 168 | " \n", 169 | " if not hasattr(self.weight,'org'):\n", 170 | " 
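`PBinarizeLinear.forward` above implements the local reparameterization trick from the papers cited in the README: rather than sampling every binary weight, it propagates the mean and variance of the pre-activation. A condensed sketch of the first-layer branch (deterministic inputs) follows; note that the notebook's variable `sigma` actually holds a variance, while `Normal`'s scale argument expects a standard deviation, so a faithful Gaussian sample would take its square root as done here.

```python
# Sketch of local reparameterization for stochastic binary weights (first layer).
# theta parameterizes each weight: E[w] = tanh(theta), Var[w] = 1 - tanh(theta)**2.
import torch
import torch.nn.functional as F

def binary_linear_moments(x, theta):
    m = torch.tanh(theta)                  # mean of each binary weight, in (-1, 1)
    mu = F.linear(x, m)                    # E[x @ w.T]
    var = F.linear(x ** 2, 1 - m ** 2)     # Var[x @ w.T] for deterministic inputs
    return mu, var

x = torch.randn(16, 784)
theta = 0.1 * torch.randn(512, 784)
mu, var = binary_linear_moments(x, theta)
z = mu + var.clamp_min(1e-8).sqrt() * torch.randn_like(mu)  # one CLT-Gaussian sample
```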
self.weight.org=self.weight.data.clone() \n", 171 | " \n", 172 | " self.weight.data=Binarize(self.weight.org)\n", 173 | " \n", 174 | " theta = torch.tanh(self.weight)\n", 175 | " \n", 176 | "\n", 177 | " if input.size(1) != 3:\n", 178 | " mu = nn.functional.conv2d(input, theta, None, self.stride,\n", 179 | " self.padding, self.dilation, self.groups)\n", 180 | " left = input**2 - (1- input**2)\n", 181 | " right = theta**2 - (1-theta**2)\n", 182 | " sigma = 1 - nn.functional.conv2d(left, right, None, self.stride,\n", 183 | " self.padding, self.dilation, self.groups)\n", 184 | " else:\n", 185 | " mu = nn.functional.conv2d(input, theta, None, self.stride,\n", 186 | " self.padding, self.dilation, self.groups)\n", 187 | " sigma = nn.functional.conv2d(input**2, 1-(theta**2), None, self.stride,\n", 188 | " self.padding, self.dilation, self.groups)\n", 189 | " \n", 190 | "# print(mu.shape)\n", 191 | " m = mu.mean((0,2,3),True)\n", 192 | " \n", 193 | " v = sigma.var((0,2,3)).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)\n", 194 | " \n", 195 | " mu = 0.5*(mu-m)/((v+(0.0001)).sqrt()+0.5)\n", 196 | " sigma = 0.5**2*sigma/(v+0.0001)\n", 197 | " \n", 198 | " out1 = sampling(mu,sigma)\n", 199 | "\n", 200 | "\n", 201 | "# if not self.bias is None:\n", 202 | "# self.bias.org=self.bias.data.clone()\n", 203 | "# out += self.bias.view(1, -1, 1, 1).expand_as(out)\n", 204 | " \n", 205 | " \n", 206 | " return out1\n" 207 | ], 208 | "execution_count": 0, 209 | "outputs": [] 210 | }, 211 | { 212 | "cell_type": "code", 213 | "metadata": { 214 | "id": "ggHwTvf4MRqC", 215 | "colab_type": "code", 216 | "outputId": "c2a52211-c643-4d72-86c3-38350e5505ad", 217 | "colab": { 218 | "base_uri": "https://localhost:8080/", 219 | "height": 155 220 | } 221 | }, 222 | "source": [ 223 | "from __future__ import print_function\n", 224 | "import argparse\n", 225 | "import torch\n", 226 | "import torch.nn as nn\n", 227 | "import torch.nn.functional as F\n", 228 | "import torch.optim as optim\n", 229 | "from torchvision import datasets, transforms\n", 230 | "from torch.autograd import Variable\n", 231 | "from tqdm import tqdm\n", 232 | "from torch.distributions.normal import Normal\n", 233 | "from torch.distributions.relaxed_bernoulli import RelaxedBernoulli\n", 234 | "from torch.distributions.relaxed_categorical import RelaxedOneHotCategorical\n", 235 | "\n", 236 | "from torch.distributions.categorical import Categorical\n", 237 | "\n", 238 | "torch.manual_seed(1)\n", 239 | "# if args.cuda:\n", 240 | "# torch.cuda.manual_seed(args.seed)\n", 241 | "\n", 242 | "\n", 243 | "# kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}\n", 244 | "train_loader = torch.utils.data.DataLoader(\n", 245 | " datasets.MNIST('../data', train=True, download=True,\n", 246 | " transform=transforms.Compose([\n", 247 | " transforms.ToTensor(),\n", 248 | " transforms.Normalize((0.1307,), (0.3081,))\n", 249 | " ])),\n", 250 | " batch_size=128, shuffle=True)\n", 251 | "test_loader = torch.utils.data.DataLoader(\n", 252 | " datasets.MNIST('../data', train=False, transform=transforms.Compose([\n", 253 | " transforms.ToTensor(),\n", 254 | " transforms.Normalize((0.1307,), (0.3081,))\n", 255 | " ])),\n", 256 | " batch_size=128, shuffle=True)\n", 257 | "\n", 258 | "\n", 259 | "\n", 260 | "\n", 261 | "\n", 262 | "# 32C3 - MP2 - 64C3 - Mp2 - 512FC - SM10c\n", 263 | "class Net(nn.Module):\n", 264 | " def __init__(self):\n", 265 | " super(Net, self).__init__()\n", 266 | " \n", 267 | " self.conv1 = PBinarizeConv2d(1, 32, kernel_size=3)\n", 268 | " self.mp1= 
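Both `PBinarizeLinear` and `PBinarizeConv2d` then rescale the propagated moments with the same batch-norm-like step, the "stochastic version of Batch Normalization" mentioned in the README. A sketch with the notebook's own constants; note that `v` is the batch variance of the variance map itself, exactly as written above.

```python
# Sketch of the moment rescaling applied after mu/sigma propagation (constants as
# in pbnn_mnist_v7; dims=(0,) for linear layers, dims=(0, 2, 3) for conv maps).
import torch

def normalize_moments(mu, sigma2, dims=(0,), eps=1e-4):
    m = mu.mean(dims, keepdim=True)         # batch mean of pre-activation means
    v = sigma2.var(dims, keepdim=True)      # batch variance of the variance map
    mu_hat = 0.5 * (mu - m) / ((v + eps).sqrt() + 0.5)
    sigma2_hat = 0.5 ** 2 * sigma2 / (v + eps)
    return mu_hat, sigma2_hat
```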
nn.MaxPool2d(kernel_size=2, stride=2)\n", 269 | " \n", 270 | " self.conv2 = PBinarizeConv2d(32, 64, kernel_size=3)\n", 271 | " self.mp2= nn.MaxPool2d(kernel_size=2, stride=2)\n", 272 | " \n", 273 | " self.fc1 = PBinarizeLinear(36864, 512)\n", 274 | "\n", 275 | " \n", 276 | " self.fc2 = PBinarizeLinear(512, 10)\n", 277 | "\n", 278 | "\n", 279 | " # 32C3 - MP2 - 64C3 - Mp2 - 512FC - SM10c\n", 280 | " \n", 281 | " def forward(self, x):\n", 282 | " \n", 283 | "# print(x.shape)\n", 284 | " \n", 285 | " x = self.conv1(x)\n", 286 | " x = self.conv2(x)\n", 287 | "\n", 288 | " x = x.view(x.size(0), -1)\n", 289 | "# print(x.size())\n", 290 | "\n", 291 | " x = self.fc1(x)\n", 292 | "\n", 293 | " x = self.fc2(x)\n", 294 | "\n", 295 | "\n", 296 | " return x\n", 297 | " \n", 298 | "\n", 299 | "model = Net()\n", 300 | "\n", 301 | "print(model)\n", 302 | "\n", 303 | "torch.cuda.device('cuda')\n", 304 | "model.cuda()\n", 305 | "\n", 306 | "\n", 307 | "\n", 308 | "criterion = nn.CrossEntropyLoss()\n", 309 | "optimizer = optim.Adam(model.parameters(), lr=0.01)\n", 310 | "\n", 311 | "\n", 312 | "def train(epoch):\n", 313 | " model.train()\n", 314 | " \n", 315 | " losses = []\n", 316 | " trainloader = tqdm(train_loader)\n", 317 | " \n", 318 | " for batch_idx, (data, target) in enumerate(trainloader):\n", 319 | " \n", 320 | " data, target = data.cuda(), target.cuda()\n", 321 | " data, target = Variable(data), Variable(target)\n", 322 | " optimizer.zero_grad()\n", 323 | " output = model(data)\n", 324 | "# print(output)\n", 325 | "# output = output+1e-10\n", 326 | " loss = criterion(output, target)\n", 327 | "\n", 328 | "# print(loss)\n", 329 | "\n", 330 | "# if epoch%40==0:\n", 331 | "# optimizer.param_groups[0]['lr']=optimizer.param_groups[0]['lr']*0.1\n", 332 | "\n", 333 | "# optimizer.zero_grad()\n", 334 | "# \n", 335 | " loss.backward()\n", 336 | " \n", 337 | " for p in list(model.parameters()):\n", 338 | " if hasattr(p,'org'):\n", 339 | " p.data.copy_(p.org)\n", 340 | " optimizer.step()\n", 341 | " \n", 342 | " for p in list(model.parameters()):\n", 343 | " if hasattr(p,'org'):\n", 344 | " p.org.copy_(p.data.clamp_(-0.9,0.9))\n", 345 | " \n", 346 | " losses.append(loss.item())\n", 347 | " trainloader.set_postfix(loss=np.mean(losses), epoch=epoch)\n", 348 | "\n", 349 | "\n", 350 | "\n", 351 | "def test():\n", 352 | " model.eval()\n", 353 | " test_loss = 0\n", 354 | " correct = 0\n", 355 | " testloader = tqdm(test_loader)\n", 356 | " for data, target in testloader:\n", 357 | " data, target = data.cuda(), target.cuda()\n", 358 | " with torch.no_grad():\n", 359 | " data = Variable(data)\n", 360 | " target = Variable(target)\n", 361 | " output = model(data)\n", 362 | " test_loss += criterion(output, target).item() # sum up batch loss\n", 363 | " pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability\n", 364 | " correct += pred.eq(target.data.view_as(pred)).cpu().sum()\n", 365 | " \n", 366 | " \n", 367 | "\n", 368 | " testloader.set_postfix(loss=test_loss / len(test_loader.dataset),acc=str((100. 
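The `train` loop above uses the straight-through update pattern shared by every notebook in this repo: gradients are computed through the binarized weights, the optimizer step is applied to the saved real-valued copy (`.org`), and that copy is clamped back into range afterwards. Isolated as a helper (the helper name is ours, and the clip value varies between notebooks):

```python
# Sketch: one straight-through update over parameters carrying a `.org` copy.
def straight_through_step(model, optimizer, loss, clip=0.9):
    loss.backward()                          # grads flow through binarized weights
    for p in model.parameters():
        if hasattr(p, "org"):
            p.data.copy_(p.org)              # restore the real-valued weights
    optimizer.step()                         # update the real-valued copy
    for p in model.parameters():
        if hasattr(p, "org"):
            p.org.copy_(p.data.clamp_(-clip, clip))  # keep weights bounded
```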
*correct / len(test_loader.dataset)).numpy())+'%')\n", 369 | " \n", 370 | " test_loss /= len(test_loader.dataset)\n", 371 | " \n", 372 | " \n", 373 | "\n", 374 | "\n" 375 | ], 376 | "execution_count": 2, 377 | "outputs": [ 378 | { 379 | "output_type": "stream", 380 | "text": [ 381 | "Net(\n", 382 | " (conv1): PBinarizeConv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))\n", 383 | " (mp1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 384 | " (conv2): PBinarizeConv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))\n", 385 | " (mp2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 386 | " (fc1): PBinarizeLinear(in_features=36864, out_features=512, bias=True)\n", 387 | " (fc2): PBinarizeLinear(in_features=512, out_features=10, bias=True)\n", 388 | ")\n" 389 | ], 390 | "name": "stdout" 391 | } 392 | ] 393 | }, 394 | { 395 | "cell_type": "code", 396 | "metadata": { 397 | "id": "bxgRGAoMbkCf", 398 | "colab_type": "code", 399 | "outputId": "2189e180-c534-43ff-c4ca-34372f785713", 400 | "colab": { 401 | "base_uri": "https://localhost:8080/", 402 | "height": 763 403 | } 404 | }, 405 | "source": [ 406 | "%%time\n", 407 | "for epoch in range(20):\n", 408 | " train(epoch)\n", 409 | " test()" 410 | ], 411 | "execution_count": 3, 412 | "outputs": [ 413 | { 414 | "output_type": "stream", 415 | "text": [ 416 | "100%|██████████| 469/469 [00:38<00:00, 12.34it/s, epoch=0, loss=0.426]\n", 417 | "100%|██████████| 79/79 [00:02<00:00, 27.93it/s, acc=93%, loss=0.00174]\n", 418 | "100%|██████████| 469/469 [00:37<00:00, 12.64it/s, epoch=1, loss=0.135]\n", 419 | "100%|██████████| 79/79 [00:02<00:00, 27.86it/s, acc=95%, loss=0.00128]\n", 420 | "100%|██████████| 469/469 [00:37<00:00, 12.42it/s, epoch=2, loss=0.133]\n", 421 | "100%|██████████| 79/79 [00:02<00:00, 28.05it/s, acc=96%, loss=0.000873]\n", 422 | "100%|██████████| 469/469 [00:37<00:00, 12.65it/s, epoch=3, loss=0.115]\n", 423 | "100%|██████████| 79/79 [00:02<00:00, 28.18it/s, acc=95%, loss=0.00109]\n", 424 | "100%|██████████| 469/469 [00:37<00:00, 12.43it/s, epoch=4, loss=0.0959]\n", 425 | "100%|██████████| 79/79 [00:02<00:00, 28.01it/s, acc=96%, loss=0.000876]\n", 426 | "100%|██████████| 469/469 [00:36<00:00, 12.69it/s, epoch=5, loss=0.0714]\n", 427 | "100%|██████████| 79/79 [00:02<00:00, 28.11it/s, acc=96%, loss=0.000863]\n", 428 | "100%|██████████| 469/469 [00:37<00:00, 12.43it/s, epoch=6, loss=0.078]\n", 429 | "100%|██████████| 79/79 [00:02<00:00, 28.13it/s, acc=96%, loss=0.00077]\n", 430 | "100%|██████████| 469/469 [00:38<00:00, 12.28it/s, epoch=7, loss=0.0745]\n", 431 | "100%|██████████| 79/79 [00:02<00:00, 28.10it/s, acc=97%, loss=0.00084]\n", 432 | "100%|██████████| 469/469 [00:37<00:00, 12.44it/s, epoch=8, loss=0.0662]\n", 433 | "100%|██████████| 79/79 [00:02<00:00, 28.01it/s, acc=96%, loss=0.000852]\n", 434 | "100%|██████████| 469/469 [00:36<00:00, 12.70it/s, epoch=9, loss=0.0547]\n", 435 | "100%|██████████| 79/79 [00:02<00:00, 27.96it/s, acc=96%, loss=0.00081]\n", 436 | "100%|██████████| 469/469 [00:37<00:00, 12.45it/s, epoch=10, loss=0.0594]\n", 437 | "100%|██████████| 79/79 [00:02<00:00, 28.12it/s, acc=96%, loss=0.000806]\n", 438 | "100%|██████████| 469/469 [00:36<00:00, 12.68it/s, epoch=11, loss=0.072]\n", 439 | "100%|██████████| 79/79 [00:02<00:00, 28.04it/s, acc=96%, loss=0.00102]\n", 440 | "100%|██████████| 469/469 [00:37<00:00, 12.43it/s, epoch=12, loss=0.055]\n", 441 | "100%|██████████| 79/79 [00:02<00:00, 27.83it/s, acc=97%, loss=0.000707]\n", 442 | "100%|██████████| 469/469 
[00:36<00:00, 12.69it/s, epoch=13, loss=0.0443]\n", 443 | "100%|██████████| 79/79 [00:02<00:00, 28.10it/s, acc=97%, loss=0.00064]\n", 444 | "100%|██████████| 469/469 [00:38<00:00, 12.24it/s, epoch=14, loss=0.0389]\n", 445 | "100%|██████████| 79/79 [00:03<00:00, 23.99it/s, acc=97%, loss=0.000602]\n", 446 | "100%|██████████| 469/469 [00:37<00:00, 12.55it/s, epoch=15, loss=0.0461]\n", 447 | "100%|██████████| 79/79 [00:02<00:00, 27.98it/s, acc=96%, loss=0.000925]\n", 448 | "100%|██████████| 469/469 [00:37<00:00, 12.42it/s, epoch=16, loss=0.0436]\n", 449 | "100%|██████████| 79/79 [00:02<00:00, 27.94it/s, acc=97%, loss=0.000662]\n", 450 | "100%|██████████| 469/469 [00:37<00:00, 12.67it/s, epoch=17, loss=0.0324]\n", 451 | "100%|██████████| 79/79 [00:02<00:00, 27.90it/s, acc=97%, loss=0.000677]\n", 452 | "100%|██████████| 469/469 [00:37<00:00, 12.42it/s, epoch=18, loss=0.0361]\n", 453 | "100%|██████████| 79/79 [00:02<00:00, 28.01it/s, acc=97%, loss=0.000744]\n", 454 | "100%|██████████| 469/469 [00:37<00:00, 12.66it/s, epoch=19, loss=0.0435]\n", 455 | "100%|██████████| 79/79 [00:02<00:00, 27.76it/s, acc=97%, loss=0.000696]" 456 | ], 457 | "name": "stderr" 458 | }, 459 | { 460 | "output_type": "stream", 461 | "text": [ 462 | "CPU times: user 9min 14s, sys: 3min 49s, total: 13min 3s\n", 463 | "Wall time: 13min 26s\n" 464 | ], 465 | "name": "stdout" 466 | }, 467 | { 468 | "output_type": "stream", 469 | "text": [ 470 | "\n" 471 | ], 472 | "name": "stderr" 473 | } 474 | ] 475 | }, 476 | { 477 | "cell_type": "code", 478 | "metadata": { 479 | "id": "DiDjIjU6Mf_y", 480 | "colab_type": "code", 481 | "colab": {} 482 | }, 483 | "source": [ 484 | "a = torch.rand(3,4)\n", 485 | "# a" 486 | ], 487 | "execution_count": 0, 488 | "outputs": [] 489 | }, 490 | { 491 | "cell_type": "code", 492 | "metadata": { 493 | "id": "41XggYVRr8az", 494 | "colab_type": "code", 495 | "outputId": "80c30fd8-92d1-4d08-a1b0-cc27aab724d7", 496 | "colab": { 497 | "base_uri": "https://localhost:8080/", 498 | "height": 165 499 | } 500 | }, 501 | "source": [ 502 | "(a.var((0,2,3),True)).size()" 503 | ], 504 | "execution_count": 5, 505 | "outputs": [ 506 | { 507 | "output_type": "error", 508 | "ename": "IndexError", 509 | "evalue": "ignored", 510 | "traceback": [ 511 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 512 | "\u001b[0;31mIndexError\u001b[0m Traceback (most recent call last)", 513 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvar\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 514 | "\u001b[0;31mIndexError\u001b[0m: Dimension out of range (expected to be in range of [-2, 1], but got 2)" 515 | ] 516 | } 517 | ] 518 | }, 519 | { 520 | "cell_type": "code", 521 | "metadata": { 522 | "id": "IwxNnUxmQkqj", 523 | "colab_type": "code", 524 | "colab": {} 525 | }, 526 | "source": [ 527 | "(a.var((0)).unsqueeze(-1)).size()" 528 | ], 529 | "execution_count": 0, 530 | "outputs": [] 531 | }, 532 | { 533 | "cell_type": "code", 534 | "metadata": { 535 | "id": "sphyNMNrfwQz", 536 | "colab_type": 
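The `IndexError` in the scratch cell above is expected: `a` was created as a 2-D tensor, so reducing over dims `(0, 2, 3)` is out of range. The per-channel statistics used inside `PBinarizeConv2d` need a 4-D `(N, C, H, W)` tensor:

```python
# Sketch: per-channel moments over the batch and spatial dims of a 4-D tensor.
import torch

a4 = torch.rand(8, 3, 5, 5)               # (N, C, H, W)
m = a4.mean((0, 2, 3), keepdim=True)      # per-channel mean, shape (1, 3, 1, 1)
v = a4.var((0, 2, 3), keepdim=True)       # per-channel variance, same shape
```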
"code", 537 | "colab": {} 538 | }, 539 | "source": [ 540 | "torch.ones_like(a)" 541 | ], 542 | "execution_count": 0, 543 | "outputs": [] 544 | }, 545 | { 546 | "cell_type": "code", 547 | "metadata": { 548 | "id": "3GP1UM1ufwHF", 549 | "colab_type": "code", 550 | "colab": {} 551 | }, 552 | "source": [ 553 | "" 554 | ], 555 | "execution_count": 0, 556 | "outputs": [] 557 | }, 558 | { 559 | "cell_type": "code", 560 | "metadata": { 561 | "id": "GaDZ92vonUN0", 562 | "colab_type": "code", 563 | "colab": {} 564 | }, 565 | "source": [ 566 | "a.bernoulli()" 567 | ], 568 | "execution_count": 0, 569 | "outputs": [] 570 | }, 571 | { 572 | "cell_type": "code", 573 | "metadata": { 574 | "id": "8YWe1Ti0MyMt", 575 | "colab_type": "code", 576 | "colab": {} 577 | }, 578 | "source": [ 579 | "a=Binarize(a)\n", 580 | "a" 581 | ], 582 | "execution_count": 0, 583 | "outputs": [] 584 | }, 585 | { 586 | "cell_type": "code", 587 | "metadata": { 588 | "id": "tqOWLfDtr-H2", 589 | "colab_type": "code", 590 | "colab": {} 591 | }, 592 | "source": [ 593 | "torch.nn.functional.gumbel_softmax(a, tau=1, hard=True, eps=1e-10, dim=-1)" 594 | ], 595 | "execution_count": 0, 596 | "outputs": [] 597 | }, 598 | { 599 | "cell_type": "code", 600 | "metadata": { 601 | "id": "Qu8HiOmqjwdx", 602 | "colab_type": "code", 603 | "colab": {} 604 | }, 605 | "source": [ 606 | "(a+1)/2" 607 | ], 608 | "execution_count": 0, 609 | "outputs": [] 610 | }, 611 | { 612 | "cell_type": "code", 613 | "metadata": { 614 | "id": "0hkIV7ztMVyZ", 615 | "colab_type": "code", 616 | "colab": {} 617 | }, 618 | "source": [ 619 | "a.tanh()" 620 | ], 621 | "execution_count": 0, 622 | "outputs": [] 623 | }, 624 | { 625 | "cell_type": "code", 626 | "metadata": { 627 | "id": "9idSm857otNw", 628 | "colab_type": "code", 629 | "colab": {} 630 | }, 631 | "source": [ 632 | "a.mean(0)" 633 | ], 634 | "execution_count": 0, 635 | "outputs": [] 636 | }, 637 | { 638 | "cell_type": "code", 639 | "metadata": { 640 | "id": "cUGDVCa1pZes", 641 | "colab_type": "code", 642 | "colab": {} 643 | }, 644 | "source": [ 645 | "p= np.count_nonzero((a+1)/2,axis=0)/np.count_nonzero(a,axis=0)" 646 | ], 647 | "execution_count": 0, 648 | "outputs": [] 649 | }, 650 | { 651 | "cell_type": "code", 652 | "metadata": { 653 | "id": "NJMpKyyqIkdh", 654 | "colab_type": "code", 655 | "colab": {} 656 | }, 657 | "source": [ 658 | "1-(p)**2" 659 | ], 660 | "execution_count": 0, 661 | "outputs": [] 662 | }, 663 | { 664 | "cell_type": "code", 665 | "metadata": { 666 | "id": "1m-AodLmu8D3", 667 | "colab_type": "code", 668 | "colab": {} 669 | }, 670 | "source": [ 671 | "np.count_nonzero((a+1)/2)" 672 | ], 673 | "execution_count": 0, 674 | "outputs": [] 675 | }, 676 | { 677 | "cell_type": "code", 678 | "metadata": { 679 | "id": "xwOCJG610Shm", 680 | "colab_type": "code", 681 | "colab": {} 682 | }, 683 | "source": [ 684 | "mu = torch.randn(5)\n", 685 | "sig = torch.randn(5)\n" 686 | ], 687 | "execution_count": 0, 688 | "outputs": [] 689 | }, 690 | { 691 | "cell_type": "code", 692 | "metadata": { 693 | "id": "9_oUI0eo0XgN", 694 | "colab_type": "code", 695 | "colab": {} 696 | }, 697 | "source": [ 698 | "x=Normal(mu,sig)\n", 699 | "1 - x.cdf(0)" 700 | ], 701 | "execution_count": 0, 702 | "outputs": [] 703 | }, 704 | { 705 | "cell_type": "code", 706 | "metadata": { 707 | "id": "0jw0mT8duoik", 708 | "colab_type": "code", 709 | "colab": {} 710 | }, 711 | "source": [ 712 | "m = -4.6\n", 713 | "v = 25936\n", 714 | "x = Normal(m,v)\n", 715 | "p = 1 - x.cdf(0)\n", 716 | "# s = sample_gumbel_softmax(p,1.0)\n", 717 | "\n" 
718 | ], 719 | "execution_count": 0, 720 | "outputs": [] 721 | }, 722 | { 723 | "cell_type": "code", 724 | "metadata": { 725 | "id": "amjSvoApwIVS", 726 | "colab_type": "code", 727 | "colab": {} 728 | }, 729 | "source": [ 730 | "w = torch.empty(3, 5)\n", 731 | "nn.init.uniform_(w,-1,1)" 732 | ], 733 | "execution_count": 0, 734 | "outputs": [] 735 | }, 736 | { 737 | "cell_type": "code", 738 | "metadata": { 739 | "id": "2cQjhRs_wIPo", 740 | "colab_type": "code", 741 | "colab": {} 742 | }, 743 | "source": [ 744 | "p.sample()" 745 | ], 746 | "execution_count": 0, 747 | "outputs": [] 748 | }, 749 | { 750 | "cell_type": "code", 751 | "metadata": { 752 | "id": "0AZEDswQu1TZ", 753 | "colab_type": "code", 754 | "colab": {} 755 | }, 756 | "source": [ 757 | "aa = Normal(m,v)\n", 758 | "# aa.sample(torch.tensor([20]))" 759 | ], 760 | "execution_count": 0, 761 | "outputs": [] 762 | }, 763 | { 764 | "cell_type": "code", 765 | "metadata": { 766 | "id": "zJwtW9ld5_tf", 767 | "colab_type": "code", 768 | "colab": {} 769 | }, 770 | "source": [ 771 | "" 772 | ], 773 | "execution_count": 0, 774 | "outputs": [] 775 | }, 776 | { 777 | "cell_type": "code", 778 | "metadata": { 779 | "id": "b5baFgYuu39G", 780 | "colab_type": "code", 781 | "colab": {} 782 | }, 783 | "source": [ 784 | "1 - x.cdf(0)" 785 | ], 786 | "execution_count": 0, 787 | "outputs": [] 788 | }, 789 | { 790 | "cell_type": "code", 791 | "metadata": { 792 | "id": "xAMjkieNIlS-", 793 | "colab_type": "code", 794 | "colab": {} 795 | }, 796 | "source": [ 797 | "m = RelaxedOneHotCategorical(torch.tensor([1.]),a)\n", 798 | "m.sample()" 799 | ], 800 | "execution_count": 0, 801 | "outputs": [] 802 | }, 803 | { 804 | "cell_type": "code", 805 | "metadata": { 806 | "id": "kEYyrZzMZkuJ", 807 | "colab_type": "code", 808 | "colab": {} 809 | }, 810 | "source": [ 811 | "\"\"# # Execute this code block to install dependencies when running on colab\n", 812 | "# try:\n", 813 | "# import torch\n", 814 | "# except:\n", 815 | "# from os.path import exists\n", 816 | "# from wheel.pep425tags import get_abbr_impl, get_impl_ver, get_abi_tag\n", 817 | "# platform = '{}{}-{}'.format(get_abbr_impl(), get_impl_ver(), get_abi_tag())\n", 818 | "# cuda_output = !ldconfig -p|grep cudart.so|sed -e 's/.*\\.\\([0-9]*\\)\\.\\([0-9]*\\)$/cu\\1\\2/'\n", 819 | "# accelerator = cuda_output[0] if exists('/dev/nvidia0') else 'cpu'\n", 820 | "\n", 821 | "# !pip install -q http://download.pytorch.org/whl/{accelerator}/torch-1.0.0-{platform}-linux_x86_64.whl torchvision\n", 822 | "\n", 823 | "# try: \n", 824 | "# import torchbearer\n", 825 | "# except:\n", 826 | "# !pip install torchbearer\n", 827 | " \n", 828 | "# from torchbearer import Trial\n", 829 | "# torchbearer_trial = Trial(model, optimizer, criterion, metrics=['loss', 'accuracy']).to('cuda:0')\n", 830 | "# torchbearer_trial.with_generators(train_loader, test_generator=test_loader)\n", 831 | "# torchbearer_trial.run(epochs=5)" 832 | ], 833 | "execution_count": 0, 834 | "outputs": [] 835 | } 836 | ] 837 | } -------------------------------------------------------------------------------- /Old Trials/pbnn_mnist_v3.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "pbnn_mnist_v3.ipynb", 7 | "version": "0.3.2", 8 | "provenance": [], 9 | "collapsed_sections": [] 10 | }, 11 | "kernelspec": { 12 | "name": "python3", 13 | "display_name": "Python 3" 14 | }, 15 | "accelerator": "GPU" 16 | }, 17 | "cells": [ 18 | { 19 
| "cell_type": "code", 20 | "metadata": { 21 | "id": "3YM3-zGxLa63", 22 | "colab_type": "code", 23 | "colab": {} 24 | }, 25 | "source": [ 26 | "import torch\n", 27 | "import pdb\n", 28 | "import torch.nn as nn\n", 29 | "import math\n", 30 | "from torch.autograd import Variable\n", 31 | "from torch.autograd import Function\n", 32 | "import time\n", 33 | "from torch.distributions.relaxed_bernoulli import RelaxedBernoulli\n", 34 | "import numpy as np\n", 35 | "\n", 36 | "\n", 37 | "def Binarize(tensor,quant_mode='det'):\n", 38 | " if quant_mode=='det':\n", 39 | " return tensor.sign()\n", 40 | " else:\n", 41 | " return tensor.add_(1).div_(2).add_(torch.rand(tensor.size()).add(-0.5)).clamp_(0,1).round().mul_(2).add_(-1)\n", 42 | "\n", 43 | "def sample2(mu, log_sigma2):\n", 44 | " eps = torch.randn_like(mu)\n", 45 | " s = mu + torch.exp(log_sigma2 / 2) * eps\n", 46 | " return s\n", 47 | " \n", 48 | " \n", 49 | "\n", 50 | "def sampling(mu,sig):\n", 51 | " x = Normal(mu,sig)\n", 52 | "# x = x.sample(torch.tensor([out_features]))\n", 53 | " p = 1 - x.cdf(0)\n", 54 | " p = Binarize(p)\n", 55 | " m = RelaxedBernoulli(torch.tensor([1.]).cuda(),p)\n", 56 | " x = m.sample()\n", 57 | " return x\n", 58 | "\n", 59 | "\n", 60 | "\n", 61 | "import torch.nn._functions as tnnf\n", 62 | "\n", 63 | "\n", 64 | "\n", 65 | " \n", 66 | "# def pre(mu,sigma):\n", 67 | " \n", 68 | "# m = mu.mean(1).unsqueeze(1)\n", 69 | "# v = sigma.var(1).unsqueeze(1)\n", 70 | "# mu = 0.5*(mu-m)/((v+(1e-6)).sqrt()+0.5)\n", 71 | "# sigma = 0.5**2*sigma/(v+0.5)\n", 72 | " \n", 73 | "# return mu,sigma\n", 74 | "\n", 75 | " \n", 76 | "class PBinarizeLinear(nn.Linear):\n", 77 | "\n", 78 | " def __init__(self, *kargs, **kwargs):\n", 79 | " super(PBinarizeLinear, self).__init__(*kargs, **kwargs)\n", 80 | " \n", 81 | "\n", 82 | " def forward(self, input):\n", 83 | " \n", 84 | "\n", 85 | " if input.size(1) != 784:\n", 86 | " input.data=Binarize(input.data)\n", 87 | " \n", 88 | " if not hasattr(self.weight,'org'):\n", 89 | " self.weight.org=self.weight.data.clone() \n", 90 | " \n", 91 | " self.weight.data=Binarize(self.weight.org)\n", 92 | "# print(self.weight.shape)\n", 93 | " \n", 94 | "# p = np.count_nonzero((self.weight.data.cpu()+1)/2)/np.count_nonzero(self.weight.data.cpu())\n", 95 | "# print(p)\n", 96 | " \n", 97 | " mu = self.weight.data.std()/self.weight.data\n", 98 | " var = (self.weight.data - mu)**2*mu\n", 99 | "\n", 100 | " \n", 101 | " mu = nn.functional.linear(input,mu) \n", 102 | " sigma = nn.functional.linear(input**2,var)\n", 103 | " \n", 104 | "\n", 105 | " m = mu.mean()\n", 106 | " v = sigma.var()\n", 107 | " \n", 108 | " \n", 109 | " \n", 110 | " mu = 0.5*(mu-m)/((v+(0.5)).sqrt()+0.5)\n", 111 | " sigma = 0.5**2*sigma/(v+0.5)\n", 112 | " \n", 113 | "\n", 114 | " \n", 115 | " out1 = sampling(mu,sigma)\n", 116 | "# out = nn.functional.linear(input, self.weight)\n", 117 | "\n", 118 | "# if not self.bias is None:\n", 119 | "# self.bias.org=self.bias.data.clone()\n", 120 | "# out += self.bias.view(1, -1).expand_as(out)\n", 121 | "\n", 122 | " return out1\n", 123 | "\n", 124 | "\n", 125 | "\n", 126 | "\n" 127 | ], 128 | "execution_count": 0, 129 | "outputs": [] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "metadata": { 134 | "id": "ggHwTvf4MRqC", 135 | "colab_type": "code", 136 | "outputId": "41b02be1-bffe-45c6-fce5-60c7af62b9fa", 137 | "colab": { 138 | "base_uri": "https://localhost:8080/", 139 | "height": 86 140 | } 141 | }, 142 | "source": [ 143 | "from __future__ import print_function\n", 144 | "import argparse\n", 145 | 
"import torch\n", 146 | "import torch.nn as nn\n", 147 | "import torch.nn.functional as F\n", 148 | "import torch.optim as optim\n", 149 | "from torchvision import datasets, transforms\n", 150 | "from torch.autograd import Variable\n", 151 | "from tqdm import tqdm\n", 152 | "from torch.distributions.normal import Normal\n", 153 | "from torch.distributions.relaxed_bernoulli import RelaxedBernoulli\n", 154 | "from torch.distributions.relaxed_categorical import RelaxedOneHotCategorical\n", 155 | "\n", 156 | "from torch.distributions.categorical import Categorical\n", 157 | "\n", 158 | "torch.manual_seed(1)\n", 159 | "# if args.cuda:\n", 160 | "# torch.cuda.manual_seed(args.seed)\n", 161 | "\n", 162 | "\n", 163 | "# kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}\n", 164 | "train_loader = torch.utils.data.DataLoader(\n", 165 | " datasets.MNIST('../data', train=True, download=True,\n", 166 | " transform=transforms.Compose([\n", 167 | " transforms.ToTensor(),\n", 168 | " transforms.Normalize((0.1307,), (0.3081,))\n", 169 | " ])),\n", 170 | " batch_size=128, shuffle=True)\n", 171 | "test_loader = torch.utils.data.DataLoader(\n", 172 | " datasets.MNIST('../data', train=False, transform=transforms.Compose([\n", 173 | " transforms.ToTensor(),\n", 174 | " transforms.Normalize((0.1307,), (0.3081,))\n", 175 | " ])),\n", 176 | " batch_size=128, shuffle=True)\n", 177 | "\n", 178 | "\n", 179 | "\n", 180 | "\n", 181 | "\n", 182 | "# 32C3 - MP2 - 64C3 - Mp2 - 512FC - SM10c\n", 183 | "class Net(nn.Module):\n", 184 | " def __init__(self):\n", 185 | " super(Net, self).__init__()\n", 186 | " \n", 187 | "# self.conv1 = PBinarizeConv2d(1, 32, kernel_size=3)\n", 188 | " \n", 189 | "# self.mp1 = nn.MaxPool2d(kernel_size=2, stride=2)\n", 190 | "# self.htanh1 = nn.Hardtanh()\n", 191 | " \n", 192 | "# self.conv2 = PBinarizeConv2d(32, 64, kernel_size=3)\n", 193 | "# self.mp2 = nn.MaxPool2d(kernel_size=2, stride=2)\n", 194 | "# self.htanh2 = nn.Hardtanh()\n", 195 | " \n", 196 | " self.fc1 = PBinarizeLinear(784, 512)\n", 197 | "# self.htanh3 = nn.Hardtanh()\n", 198 | " \n", 199 | " self.fc2 = nn.Linear(512, 10)\n", 200 | "\n", 201 | "\n", 202 | " # 32C3 - MP2 - 64C3 - Mp2 - 512FC - SM10c\n", 203 | " \n", 204 | " def forward(self, x):\n", 205 | " \n", 206 | " \n", 207 | " x = x.view(x.size(0), -1)\n", 208 | "# print(x.size())\n", 209 | " \n", 210 | " x = self.fc1(x)\n", 211 | " \n", 212 | " x = self.fc2(x)\n", 213 | "\n", 214 | " \n", 215 | " return x\n", 216 | " \n", 217 | "\n", 218 | "\n", 219 | "model = Net()\n", 220 | "\n", 221 | "print(model)\n", 222 | "\n", 223 | "torch.cuda.device('cuda')\n", 224 | "model.cuda()\n", 225 | "\n", 226 | "\n", 227 | "criterion = nn.CrossEntropyLoss()\n", 228 | "optimizer = optim.Adam(model.parameters(), lr=0.01)\n", 229 | "\n", 230 | "\n", 231 | "def train(epoch):\n", 232 | " model.train()\n", 233 | " \n", 234 | " losses = []\n", 235 | " trainloader = tqdm(train_loader)\n", 236 | " \n", 237 | " for batch_idx, (data, target) in enumerate(trainloader):\n", 238 | " \n", 239 | " data, target = data.cuda(), target.cuda()\n", 240 | " data, target = Variable(data), Variable(target)\n", 241 | " optimizer.zero_grad()\n", 242 | " output = model(data)\n", 243 | " loss = criterion(output, target)\n", 244 | "# print()\n", 245 | "# print(loss)\n", 246 | "\n", 247 | "# if epoch%40==0:\n", 248 | "# optimizer.param_groups[0]['lr']=optimizer.param_groups[0]['lr']*0.1\n", 249 | "\n", 250 | "# optimizer.zero_grad()\n", 251 | "# \n", 252 | " loss.backward()\n", 253 | " \n", 254 | " for 
p in list(model.parameters()):\n", 255 | " if hasattr(p,'org'):\n", 256 | " p.data.copy_(p.org)\n", 257 | " optimizer.step()\n", 258 | " \n", 259 | " for p in list(model.parameters()):\n", 260 | " if hasattr(p,'org'):\n", 261 | " p.org.copy_(p.data.clamp_(-0.9,0.9))\n", 262 | " \n", 263 | " losses.append(loss.item())\n", 264 | " trainloader.set_postfix(loss=np.mean(losses), epoch=epoch)\n", 265 | "\n", 266 | "\n", 267 | "\n", 268 | "def test():\n", 269 | " model.eval()\n", 270 | " test_loss = 0\n", 271 | " correct = 0\n", 272 | " testloader = tqdm(test_loader)\n", 273 | " for data, target in testloader:\n", 274 | " data, target = data.cuda(), target.cuda()\n", 275 | " with torch.no_grad():\n", 276 | " data = Variable(data)\n", 277 | " target = Variable(target)\n", 278 | " output = model(data)\n", 279 | " test_loss += criterion(output, target).item() # sum up batch loss\n", 280 | " pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability\n", 281 | " correct += pred.eq(target.data.view_as(pred)).cpu().sum()\n", 282 | " \n", 283 | " \n", 284 | "\n", 285 | " testloader.set_postfix(loss=test_loss / len(test_loader.dataset),acc=str((100. *correct / len(test_loader.dataset)).numpy())+'%')\n", 286 | " \n", 287 | " test_loss /= len(test_loader.dataset)\n", 288 | " \n", 289 | " \n", 290 | "\n", 291 | "\n" 292 | ], 293 | "execution_count": 2, 294 | "outputs": [ 295 | { 296 | "output_type": "stream", 297 | "text": [ 298 | "Net(\n", 299 | " (fc1): PBinarizeLinear(in_features=784, out_features=512, bias=True)\n", 300 | " (fc2): Linear(in_features=512, out_features=10, bias=True)\n", 301 | ")\n" 302 | ], 303 | "name": "stdout" 304 | } 305 | ] 306 | }, 307 | { 308 | "cell_type": "code", 309 | "metadata": { 310 | "id": "bxgRGAoMbkCf", 311 | "colab_type": "code", 312 | "outputId": "910e8b0e-0e7c-442b-8480-46e905368219", 313 | "colab": { 314 | "base_uri": "https://localhost:8080/", 315 | "height": 500 316 | } 317 | }, 318 | "source": [ 319 | "%%time\n", 320 | "for epoch in range(5):\n", 321 | " train(epoch)\n", 322 | " test()" 323 | ], 324 | "execution_count": 3, 325 | "outputs": [ 326 | { 327 | "output_type": "stream", 328 | "text": [ 329 | "100%|██████████| 469/469 [00:17<00:00, 26.70it/s, epoch=0, loss=1.16]\n", 330 | "100%|██████████| 79/79 [00:02<00:00, 27.19it/s, acc=71%, loss=0.00672]\n", 331 | "100%|██████████| 469/469 [00:17<00:00, 26.44it/s, epoch=1, loss=0.795]\n", 332 | "100%|██████████| 79/79 [00:02<00:00, 27.40it/s, acc=75%, loss=0.00606]\n", 333 | "100%|██████████| 469/469 [00:17<00:00, 26.71it/s, epoch=2, loss=0.77]\n", 334 | "100%|██████████| 79/79 [00:02<00:00, 27.62it/s, acc=72%, loss=0.0067]\n", 335 | " 26%|██▌ | 120/469 [00:04<00:13, 26.81it/s, epoch=3, loss=0.739]" 336 | ], 337 | "name": "stderr" 338 | }, 339 | { 340 | "output_type": "error", 341 | "ename": "KeyboardInterrupt", 342 | "evalue": "ignored", 343 | "traceback": [ 344 | "---------------------------------------------------------------------------", 345 | "KeyboardInterrupt Traceback (most recent call last)", 346 | "[ANSI-colored frames condensed: run_cell_magic('time', '', 'for epoch in range(5): train(epoch); test()') -> IPython %%time magic -> train(epoch) -> tqdm.__iter__ -> DataLoader.__next__ / collate over MNIST.__getitem__ -> transforms.Compose.__call__ -> Normalize.__call__ -> torchvision functional.normalize -> tensor.clone()]", 360 | "KeyboardInterrupt: " 361 | ] 362 | } 363 | ] 364 | }, 365 | { 366 | "cell_type": "code", 367 | 
"metadata": { 368 | "id": "DiDjIjU6Mf_y", 369 | "colab_type": "code", 370 | "colab": {} 371 | }, 372 | "source": [ 373 | "a = torch.randn(5,4)\n", 374 | "a" 375 | ], 376 | "execution_count": 0, 377 | "outputs": [] 378 | }, 379 | { 380 | "cell_type": "code", 381 | "metadata": { 382 | "id": "8YWe1Ti0MyMt", 383 | "colab_type": "code", 384 | "colab": {} 385 | }, 386 | "source": [ 387 | "a=Binarize(a)\n", 388 | "a" 389 | ], 390 | "execution_count": 0, 391 | "outputs": [] 392 | }, 393 | { 394 | "cell_type": "code", 395 | "metadata": { 396 | "id": "Qu8HiOmqjwdx", 397 | "colab_type": "code", 398 | "colab": {} 399 | }, 400 | "source": [ 401 | "(a+1)/2" 402 | ], 403 | "execution_count": 0, 404 | "outputs": [] 405 | }, 406 | { 407 | "cell_type": "code", 408 | "metadata": { 409 | "id": "9idSm857otNw", 410 | "colab_type": "code", 411 | "colab": {} 412 | }, 413 | "source": [ 414 | "torch.mean(a, (0), True)" 415 | ], 416 | "execution_count": 0, 417 | "outputs": [] 418 | }, 419 | { 420 | "cell_type": "code", 421 | "metadata": { 422 | "id": "cUGDVCa1pZes", 423 | "colab_type": "code", 424 | "colab": {} 425 | }, 426 | "source": [ 427 | "p= np.count_nonzero((a+1)/2,axis=0)/np.count_nonzero(a,axis=0)" 428 | ], 429 | "execution_count": 0, 430 | "outputs": [] 431 | }, 432 | { 433 | "cell_type": "code", 434 | "metadata": { 435 | "id": "NJMpKyyqIkdh", 436 | "colab_type": "code", 437 | "colab": {} 438 | }, 439 | "source": [ 440 | "1-(p)**2" 441 | ], 442 | "execution_count": 0, 443 | "outputs": [] 444 | }, 445 | { 446 | "cell_type": "code", 447 | "metadata": { 448 | "id": "1m-AodLmu8D3", 449 | "colab_type": "code", 450 | "colab": {} 451 | }, 452 | "source": [ 453 | "np.count_nonzero((a+1)/2)" 454 | ], 455 | "execution_count": 0, 456 | "outputs": [] 457 | }, 458 | { 459 | "cell_type": "code", 460 | "metadata": { 461 | "id": "xwOCJG610Shm", 462 | "colab_type": "code", 463 | "colab": {} 464 | }, 465 | "source": [ 466 | "mu = torch.randn(5)\n", 467 | "sig = torch.randn(5)\n" 468 | ], 469 | "execution_count": 0, 470 | "outputs": [] 471 | }, 472 | { 473 | "cell_type": "code", 474 | "metadata": { 475 | "id": "9_oUI0eo0XgN", 476 | "colab_type": "code", 477 | "colab": {} 478 | }, 479 | "source": [ 480 | "x=Normal(mu,sig)\n", 481 | "1 - x.cdf(0)" 482 | ], 483 | "execution_count": 0, 484 | "outputs": [] 485 | }, 486 | { 487 | "cell_type": "code", 488 | "metadata": { 489 | "id": "0jw0mT8duoik", 490 | "colab_type": "code", 491 | "colab": {} 492 | }, 493 | "source": [ 494 | "m = -4.6\n", 495 | "v = 25936\n", 496 | "x = Normal(m,v)\n", 497 | "p = 1 - x.cdf(0)\n", 498 | "# s = sample_gumbel_softmax(p,1.0)\n", 499 | "\n" 500 | ], 501 | "execution_count": 0, 502 | "outputs": [] 503 | }, 504 | { 505 | "cell_type": "code", 506 | "metadata": { 507 | "id": "amjSvoApwIVS", 508 | "colab_type": "code", 509 | "colab": {} 510 | }, 511 | "source": [ 512 | "eps = torch.randn_like(mu)\n", 513 | " s = mu + torch.exp(log_sigma2 / 2) * eps" 514 | ], 515 | "execution_count": 0, 516 | "outputs": [] 517 | }, 518 | { 519 | "cell_type": "code", 520 | "metadata": { 521 | "id": "2cQjhRs_wIPo", 522 | "colab_type": "code", 523 | "colab": {} 524 | }, 525 | "source": [ 526 | "p.sample()" 527 | ], 528 | "execution_count": 0, 529 | "outputs": [] 530 | }, 531 | { 532 | "cell_type": "code", 533 | "metadata": { 534 | "id": "0AZEDswQu1TZ", 535 | "colab_type": "code", 536 | "colab": {} 537 | }, 538 | "source": [ 539 | "aa = Normal(m,v)\n", 540 | "# aa.sample(torch.tensor([20]))" 541 | ], 542 | "execution_count": 0, 543 | "outputs": [] 544 | }, 545 | { 546 | 
"cell_type": "code", 547 | "metadata": { 548 | "id": "b5baFgYuu39G", 549 | "colab_type": "code", 550 | "colab": {} 551 | }, 552 | "source": [ 553 | "1 - x.cdf(0)" 554 | ], 555 | "execution_count": 0, 556 | "outputs": [] 557 | }, 558 | { 559 | "cell_type": "code", 560 | "metadata": { 561 | "id": "xAMjkieNIlS-", 562 | "colab_type": "code", 563 | "colab": {} 564 | }, 565 | "source": [ 566 | "m = RelaxedOneHotCategorical(torch.tensor([1.]),a)\n", 567 | "m.sample()" 568 | ], 569 | "execution_count": 0, 570 | "outputs": [] 571 | }, 572 | { 573 | "cell_type": "code", 574 | "metadata": { 575 | "id": "kEYyrZzMZkuJ", 576 | "colab_type": "code", 577 | "colab": {} 578 | }, 579 | "source": [ 580 | "\"\"# # Execute this code block to install dependencies when running on colab\n", 581 | "# try:\n", 582 | "# import torch\n", 583 | "# except:\n", 584 | "# from os.path import exists\n", 585 | "# from wheel.pep425tags import get_abbr_impl, get_impl_ver, get_abi_tag\n", 586 | "# platform = '{}{}-{}'.format(get_abbr_impl(), get_impl_ver(), get_abi_tag())\n", 587 | "# cuda_output = !ldconfig -p|grep cudart.so|sed -e 's/.*\\.\\([0-9]*\\)\\.\\([0-9]*\\)$/cu\\1\\2/'\n", 588 | "# accelerator = cuda_output[0] if exists('/dev/nvidia0') else 'cpu'\n", 589 | "\n", 590 | "# !pip install -q http://download.pytorch.org/whl/{accelerator}/torch-1.0.0-{platform}-linux_x86_64.whl torchvision\n", 591 | "\n", 592 | "# try: \n", 593 | "# import torchbearer\n", 594 | "# except:\n", 595 | "# !pip install torchbearer\n", 596 | " \n", 597 | "# from torchbearer import Trial\n", 598 | "# torchbearer_trial = Trial(model, optimizer, criterion, metrics=['loss', 'accuracy']).to('cuda:0')\n", 599 | "# torchbearer_trial.with_generators(train_loader, test_generator=test_loader)\n", 600 | "# torchbearer_trial.run(epochs=5)" 601 | ], 602 | "execution_count": 0, 603 | "outputs": [] 604 | } 605 | ] 606 | } -------------------------------------------------------------------------------- /Old Trials/pbnn_cifar_v1.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "pBNN_cifar.ipynb", 7 | "version": "0.3.2", 8 | "provenance": [], 9 | "collapsed_sections": [] 10 | }, 11 | "kernelspec": { 12 | "name": "python3", 13 | "display_name": "Python 3" 14 | }, 15 | "accelerator": "GPU" 16 | }, 17 | "cells": [ 18 | { 19 | "cell_type": "code", 20 | "metadata": { 21 | "id": "3YM3-zGxLa63", 22 | "colab_type": "code", 23 | "colab": {} 24 | }, 25 | "source": [ 26 | "import torch\n", 27 | "import pdb\n", 28 | "import torch.nn as nn\n", 29 | "import math\n", 30 | "from torch.autograd import Variable\n", 31 | "from torch.autograd import Function\n", 32 | "import time\n", 33 | "\n", 34 | "import numpy as np\n", 35 | "\n", 36 | "\n", 37 | "def Binarize(tensor,quant_mode='det'):\n", 38 | " if quant_mode=='det':\n", 39 | " return tensor.sign()\n", 40 | " else:\n", 41 | " return tensor.add_(1).div_(2).add_(torch.rand(tensor.size()).add(-0.5)).clamp_(0,1).round().mul_(2).add_(-1)\n", 42 | "\n", 43 | "\n", 44 | "\n", 45 | "\n", 46 | "class HingeLoss(nn.Module):\n", 47 | " def __init__(self):\n", 48 | " super(HingeLoss,self).__init__()\n", 49 | " self.margin=1.0\n", 50 | "\n", 51 | " def hinge_loss(self,input,target):\n", 52 | " #import pdb; pdb.set_trace()\n", 53 | " output=self.margin-input.mul(target)\n", 54 | " output[output.le(0)]=0\n", 55 | " return output.mean()\n", 56 | "\n", 57 | " def forward(self, input, target):\n", 58 | " return 
self.hinge_loss(input,target)\n", 59 | "\n", 60 | "class SqrtHingeLossFunction(Function):\n", 61 | " def __init__(self):\n", 62 | " super(SqrtHingeLossFunction,self).__init__()\n", 63 | " self.margin=1.0\n", 64 | "\n", 65 | " def forward(self, input, target):\n", 66 | " output=self.margin-input.mul(target)\n", 67 | " output[output.le(0)]=0\n", 68 | " self.save_for_backward(input, target)\n", 69 | " loss=output.mul(output).sum(0).sum(1).div(target.numel())\n", 70 | " return loss\n", 71 | "\n", 72 | " def backward(self,grad_output):\n", 73 | " input, target = self.saved_tensors\n", 74 | " output=self.margin-input.mul(target)\n", 75 | " output[output.le(0)]=0\n", 76 | " import pdb; pdb.set_trace()\n", 77 | " grad_output.resize_as_(input).copy_(target).mul_(-2).mul_(output)\n", 78 | " grad_output.mul_(output.ne(0).float())\n", 79 | " grad_output.div_(input.numel())\n", 80 | " return grad_output,grad_output\n", 81 | "\n", 82 | "def Quantize(tensor,quant_mode='det', params=None, numBits=8):\n", 83 | " tensor.clamp_(-2**(numBits-1),2**(numBits-1))\n", 84 | " if quant_mode=='det':\n", 85 | " tensor=tensor.mul(2**(numBits-1)).round().div(2**(numBits-1))\n", 86 | " else:\n", 87 | " tensor=tensor.mul(2**(numBits-1)).round().add(torch.rand(tensor.size()).add(-0.5)).div(2**(numBits-1))\n", 88 | " quant_fixed(tensor, params)\n", 89 | " return tensor\n", 90 | "\n", 91 | "import torch.nn._functions as tnnf\n", 92 | "\n", 93 | "\n", 94 | "class BinarizeLinear(nn.Linear):\n", 95 | "\n", 96 | " def __init__(self, *kargs, **kwargs):\n", 97 | " super(BinarizeLinear, self).__init__(*kargs, **kwargs)\n", 98 | "\n", 99 | " def forward(self, input):\n", 100 | "\n", 101 | " if input.size(1) != 784:\n", 102 | " input.data=Binarize(input.data)\n", 103 | " if not hasattr(self.weight,'org'):\n", 104 | " self.weight.org=self.weight.data.clone()\n", 105 | " self.weight.data=Binarize(self.weight.org)\n", 106 | " out = nn.functional.linear(input, self.weight)\n", 107 | " if not self.bias is None:\n", 108 | " self.bias.org=self.bias.data.clone()\n", 109 | " out += self.bias.view(1, -1).expand_as(out)\n", 110 | "\n", 111 | " return out\n", 112 | "\n", 113 | "class BinarizeConv2d(nn.Conv2d):\n", 114 | "\n", 115 | " def __init__(self, *kargs, **kwargs):\n", 116 | " super(BinarizeConv2d, self).__init__(*kargs, **kwargs)\n", 117 | "\n", 118 | "\n", 119 | " def forward(self, input):\n", 120 | " if input.size(1) != 3:\n", 121 | " input.data = Binarize(input.data)\n", 122 | " if not hasattr(self.weight,'org'):\n", 123 | " self.weight.org=self.weight.data.clone()\n", 124 | " self.weight.data=Binarize(self.weight.org)\n", 125 | "\n", 126 | " out = nn.functional.conv2d(input, self.weight, None, self.stride,\n", 127 | " self.padding, self.dilation, self.groups)\n", 128 | "\n", 129 | " if not self.bias is None:\n", 130 | " self.bias.org=self.bias.data.clone()\n", 131 | " out += self.bias.view(1, -1, 1, 1).expand_as(out)\n", 132 | " \n", 133 | " \n", 134 | " return out\n" 135 | ], 136 | "execution_count": 0, 137 | "outputs": [] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "metadata": { 142 | "id": "ggHwTvf4MRqC", 143 | "colab_type": "code", 144 | "outputId": "e03312cf-8fc0-48dd-a1d5-897271d3bc87", 145 | "colab": { 146 | "base_uri": "https://localhost:8080/", 147 | "height": 548 148 | } 149 | }, 150 | "source": [ 151 | "\n", 152 | "%%time \n", 153 | "from __future__ import print_function\n", 154 | "import argparse\n", 155 | "import torch\n", 156 | "import torch.nn as nn\n", 157 | "import torch.nn.functional as F\n", 158 | "import 
torch.optim as optim\n", 159 | "from torchvision import datasets, transforms\n", 160 | "from torch.autograd import Variable\n", 161 | "from tqdm import tqdm\n", 162 | "\n", 163 | "# from models.binarized_modules import BinarizeLinear,BinarizeConv2d\n", 164 | "# from models.binarized_modules import Binarize,Ternarize,Ternarize2,Ternarize3,Ternarize4,HingeLoss\n", 165 | "# Training settings\n", 166 | "# parser = argparse.ArgumentParser(description='PyTorch MNIST Example')\n", 167 | "# parser.add_argument('--batch-size', type=int, default=64, metavar='N',\n", 168 | "# help='input batch size for training (default: 256)')\n", 169 | "# parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',\n", 170 | "# help='input batch size for testing (default: 1000)')\n", 171 | "# parser.add_argument('--epochs', type=int, default=100, metavar='N',\n", 172 | "# help='number of epochs to train (default: 10)')\n", 173 | "# parser.add_argument('--lr', type=float, default=0.01, metavar='LR',\n", 174 | "# help='learning rate (default: 0.001)')\n", 175 | "# parser.add_argument('--momentum', type=float, default=0.5, metavar='M',\n", 176 | "# help='SGD momentum (default: 0.5)')\n", 177 | "# parser.add_argument('--no-cuda', action='store_true', default=False,\n", 178 | "# help='disables CUDA training')\n", 179 | "# parser.add_argument('--seed', type=int, default=1, metavar='S',\n", 180 | "# help='random seed (default: 1)')\n", 181 | "# parser.add_argument('--gpus', default=3,\n", 182 | "# help='gpus used for training - e.g 0,1,3')\n", 183 | "# parser.add_argument('--log-interval', type=int, default=10, metavar='N',\n", 184 | "# help='how many batches to wait before logging training status')\n", 185 | "# args = parser.parse_args()\n", 186 | "# args.cuda = not args.no_cuda and torch.cuda.is_available()\n", 187 | "\n", 188 | "torch.manual_seed(1)\n", 189 | "# if args.cuda:\n", 190 | "# torch.cuda.manual_seed(args.seed)\n", 191 | "\n", 192 | "\n", 193 | "# kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}\n", 194 | "train_loader = torch.utils.data.DataLoader(\n", 195 | " datasets.CIFAR10('../data', train=True, download=True,\n", 196 | " transform=transforms.Compose([\n", 197 | "# transforms.RandomHorizontalFlip(),\n", 198 | " transforms.ToTensor(),\n", 199 | " transforms.Normalize((0.1307,), (0.3081,)),\n", 200 | " \n", 201 | " ])),\n", 202 | " batch_size=128, shuffle=True)\n", 203 | "test_loader = torch.utils.data.DataLoader(\n", 204 | " datasets.CIFAR10('../data', train=False, transform=transforms.Compose([\n", 205 | "# transforms.RandomHorizontalFlip(),\n", 206 | "# transforms.RandomCrop(32,padding=4),\n", 207 | " transforms.ToTensor(),\n", 208 | " transforms.Normalize((0.1307,), (0.3081,))\n", 209 | " ])),\n", 210 | " batch_size=128, shuffle=True)\n", 211 | "\n", 212 | "\n", 213 | "class Net(nn.Module):\n", 214 | " def __init__(self):\n", 215 | " super(Net, self).__init__()\n", 216 | "# self.infl_ratio=3\n", 217 | "# self.fc1 = BinarizeLinear(1024, 2048*self.infl_ratio)\n", 218 | "# self.htanh1 = nn.Hardtanh()\n", 219 | "# self.bn1 = nn.BatchNorm1d(2048*self.infl_ratio)\n", 220 | " \n", 221 | "# self.fc2 = BinarizeLinear(2048*self.infl_ratio, 2048*self.infl_ratio)\n", 222 | "# self.htanh2 = nn.Hardtanh()\n", 223 | "# self.bn2 = nn.BatchNorm1d(2048*self.infl_ratio)\n", 224 | "# self.fc3 = BinarizeLinear(2048*self.infl_ratio, 2048*self.infl_ratio)\n", 225 | "# self.htanh3 = nn.Hardtanh()\n", 226 | "# self.bn3 = nn.BatchNorm1d(2048*self.infl_ratio)\n", 227 | "# self.fc4 = 
nn.Linear(2048*self.infl_ratio, 10)\n", 228 | " \n", 229 | " \n", 230 | "# self.logsoftmax = nn.LogSoftmax(dim=1)\n", 231 | "# self.drop=nn.Dropout(0.5)\n", 232 | " \n", 233 | " \n", 234 | "# self.b1 = BinarizeConv2d(3, 128*self.infl_ratio, kernel_size=3, stride=1, padding=1,\n", 235 | "# bias=True)\n", 236 | "# self.b2 = nn.BatchNorm2d(128*self.infl_ratio)\n", 237 | "# self.b3 = nn.Hardtanh(inplace=True)\n", 238 | " \n", 239 | " \n", 240 | "# self.b4 = BinarizeConv2d(128*self.infl_ratio, 512, kernel_size=3, padding=1, bias=True)\n", 241 | "# self.b5 = nn.MaxPool2d(kernel_size=2, stride=2)\n", 242 | "# self.b6 = nn.BatchNorm2d(512)\n", 243 | "# self.b7 = nn.Hardtanh(inplace=True)\n", 244 | " \n", 245 | " \n", 246 | "# self.bb1 = BinarizeLinear(384 * 16 * 16, 1024, bias=True)\n", 247 | "# self.bb2 = nn.BatchNorm1d(1024)\n", 248 | "# self.bb3 = nn.Hardtanh(inplace=True)\n", 249 | "# self.bb4 = BinarizeLinear(1024, 10, bias=True)\n", 250 | "# self.bb5 = nn.BatchNorm1d(10, affine=False)\n", 251 | " \n", 252 | " self.conv1 = BinarizeConv2d(3, 128, kernel_size=3)\n", 253 | " self.conv2 = BinarizeConv2d(128, 128, kernel_size=3)\n", 254 | " self.mp1 = nn.MaxPool2d(kernel_size=2, stride=2)\n", 255 | " \n", 256 | " self.bn1 = nn.BatchNorm2d(128)\n", 257 | " self.htanh1 = nn.Hardtanh()\n", 258 | " \n", 259 | " self.conv3 = BinarizeConv2d(128, 256, kernel_size=3)\n", 260 | " self.conv4 = BinarizeConv2d(256, 256, kernel_size=3)\n", 261 | " self.mp2 = nn.MaxPool2d(kernel_size=2, stride=2)\n", 262 | " \n", 263 | " self.bn2 = nn.BatchNorm2d(256)\n", 264 | " self.htanh2 = nn.Hardtanh()\n", 265 | " \n", 266 | " self.conv5 = BinarizeConv2d(256, 512, kernel_size=3)\n", 267 | " self.conv6 = BinarizeConv2d(512, 512, kernel_size=3)\n", 268 | " self.mp3 = nn.MaxPool2d(kernel_size=2, stride=2)\n", 269 | " \n", 270 | " self.bn3 = nn.BatchNorm2d(512)\n", 271 | " self.htanh3 = nn.Hardtanh()\n", 272 | " \n", 273 | " self.fc1 = BinarizeLinear(512, 1024)\n", 274 | " self.bn4 = nn.BatchNorm1d(1024)\n", 275 | " self.htanh4 = nn.Hardtanh()\n", 276 | " \n", 277 | " self.fc2 = nn.Linear(1024, 10)\n", 278 | " self.sm = nn.Softmax(dim=1)\n", 279 | " \n", 280 | " \n", 281 | " \n", 282 | " \n", 283 | " \n", 284 | " \n", 285 | "\n", 286 | " def forward(self, x):\n", 287 | " \n", 288 | " x = self.conv1(x)\n", 289 | " \n", 290 | " x = self.conv2(x)\n", 291 | " x = self.mp1(x)\n", 292 | " x = self.bn1(x)\n", 293 | " x = Binarize(x)\n", 294 | " x = self.htanh1(x)\n", 295 | " x = Binarize(x)\n", 296 | " \n", 297 | " x = self.conv3(x)\n", 298 | " x = self.conv4(x)\n", 299 | " x = self.mp2(x)\n", 300 | " x = self.bn2(x)\n", 301 | " x = Binarize(x)\n", 302 | " x = self.htanh2(x)\n", 303 | " x = Binarize(x)\n", 304 | " \n", 305 | " x = self.conv5(x)\n", 306 | " x = self.conv6(x)\n", 307 | " x = self.mp3(x)\n", 308 | " x = self.bn3(x)\n", 309 | " x = Binarize(x)\n", 310 | " x = self.htanh3(x)\n", 311 | " x = Binarize(x)\n", 312 | " \n", 313 | " x = x.view(x.size(0), -1)\n", 314 | "# print(x.shape)\n", 315 | " \n", 316 | " x = self.fc1(x)\n", 317 | " x = self.bn4(x)\n", 318 | " x = Binarize(x)\n", 319 | " x = self.htanh4(x)\n", 320 | " x = Binarize(x)\n", 321 | " \n", 322 | " x = self.fc2(x)\n", 323 | "\n", 324 | " \n", 325 | " return self.sm(x)\n", 326 | " \n", 327 | "\n", 328 | "# return self.logsoftmax(x)\n", 329 | "\n", 330 | "model = Net()\n", 331 | "torch.cuda.device('cuda')\n", 332 | "model.cuda()\n", 333 | "\n", 334 | "\n", 335 | "criterion = nn.CrossEntropyLoss()\n", 336 | "optimizer = optim.Adam(model.parameters(), lr=0.01)\n", 
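# --- Editor's note (added; not part of the original notebook) --------------------
# Two observations about this training cell, with a hedged sketch:
#  * The CIFAR-10 loaders above reuse the MNIST normalization constants
#    (0.1307, 0.3081); per-channel CIFAR-10 statistics would be the conventional
#    choice, though the recorded runs used the values shown.
#  * The train() function that follows implements the BinaryNet straight-through
#    update. A minimal sketch of that pattern, assuming the Binarize* layers stash
#    the real-valued weights in `p.org` (as BinarizeLinear/BinarizeConv2d do here):
#
#        loss.backward()                            # grads flow through the binarized copies
#        for p in model.parameters():
#            if hasattr(p, 'org'):
#                p.data.copy_(p.org)                # restore real-valued weights before the step
#        optimizer.step()                           # apply the update to the real values
#        for p in model.parameters():
#            if hasattr(p, 'org'):
#                p.org.copy_(p.data.clamp_(-1, 1))  # keep the latent weights in [-1, 1]
# ----------------------------------------------------------------------------------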
337 | "\n", 338 | "\n", 339 | "def train(epoch):\n", 340 | " model.train()\n", 341 | " \n", 342 | " losses = []\n", 343 | " trainloader = tqdm(train_loader)\n", 344 | " \n", 345 | " for batch_idx, (data, target) in enumerate(trainloader):\n", 346 | " \n", 347 | " data, target = data.cuda(), target.cuda()\n", 348 | " data, target = Variable(data), Variable(target)\n", 349 | " optimizer.zero_grad()\n", 350 | " output = model(data)\n", 351 | " loss = criterion(output, target)\n", 352 | "\n", 353 | "# if epoch%40==0:\n", 354 | "# optimizer.param_groups[0]['lr']=optimizer.param_groups[0]['lr']*0.1\n", 355 | "\n", 356 | "# optimizer.zero_grad()\n", 357 | " \n", 358 | " loss.backward()\n", 359 | " \n", 360 | " for p in list(model.parameters()):\n", 361 | " if hasattr(p,'org'):\n", 362 | " p.data.copy_(p.org)\n", 363 | " optimizer.step()\n", 364 | " \n", 365 | " for p in list(model.parameters()):\n", 366 | " if hasattr(p,'org'):\n", 367 | " p.org.copy_(p.data.clamp_(-1,1))\n", 368 | " \n", 369 | " losses.append(loss.item())\n", 370 | " trainloader.set_postfix(loss=np.mean(losses), epoch=epoch)\n", 371 | "# if batch_idx % 10000 == 0:\n", 372 | "# print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n", 373 | "# epoch, batch_idx * len(data), len(train_loader.dataset),\n", 374 | "# 100. * batch_idx / len(train_loader), loss.item()))\n", 375 | "\n", 376 | "def test():\n", 377 | " model.eval()\n", 378 | " test_loss = 0\n", 379 | " correct = 0\n", 380 | " testloader = tqdm(test_loader)\n", 381 | " for data, target in testloader:\n", 382 | " data, target = data.cuda(), target.cuda()\n", 383 | " with torch.no_grad():\n", 384 | " data = Variable(data)\n", 385 | " target = Variable(target)\n", 386 | " output = model(data)\n", 387 | " test_loss += criterion(output, target).item() # sum up batch loss\n", 388 | " pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability\n", 389 | " correct += pred.eq(target.data.view_as(pred)).cpu().sum()\n", 390 | " \n", 391 | " \n", 392 | "\n", 393 | " testloader.set_postfix(loss=test_loss / len(test_loader.dataset),acc=str((100. *correct / len(test_loader.dataset)).numpy())+'%')\n", 394 | " \n", 395 | " test_loss /= len(test_loader.dataset)\n", 396 | " \n", 397 | " \n", 398 | "# print('\\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n", 399 | "# test_loss, correct, len(test_loader.dataset),\n", 400 | "# 100. 
* correct / len(test_loader.dataset)))\n", 401 | "\n", 402 | "\n", 403 | "for epoch in range(5):\n", 404 | "    train(epoch)\n", 405 | "    test()" 406 | ], 407 | "execution_count": 0, 408 | "outputs": [ 409 | { 410 | "output_type": "stream", 411 | "text": [ 412 | "Files already downloaded and verified\n" 413 | ], 414 | "name": "stdout" 415 | }, 416 | { 417 | "output_type": "stream", 418 | "text": [ 419 | "100%|██████████| 391/391 [00:25<00:00, 15.51it/s, epoch=0, loss=2.2]\n", 420 | "100%|██████████| 79/79 [00:03<00:00, 26.22it/s, acc=26%, loss=0.0173]\n", 421 | "100%|██████████| 391/391 [00:25<00:00, 15.50it/s, epoch=1, loss=2.18]\n", 422 | "100%|██████████| 79/79 [00:03<00:00, 25.94it/s, acc=26%, loss=0.0173]\n", 423 | " 29%|██▉       | 114/391 [00:07<00:17, 15.61it/s, epoch=2, loss=2.18]" 424 | ], 425 | "name": "stderr" 426 | }, 427 | { 428 | "output_type": "error", 429 | "ename": "KeyboardInterrupt", 430 | "evalue": "ignored", 431 | "traceback": [ 432 | "KeyboardInterrupt: the run was stopped manually partway through epoch 2, inside the DataLoader's CIFAR transform (ANSI-escaped traceback omitted)" 433 | ] 434 | } 435 | ] 436 | } 437 | ] 438 | }
-------------------------------------------------------------------------------- /Old Trials/pbnn_mnist_v5.ipynb: --------------------------------------------------------------------------------
1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "pbnn_mnist_v5.ipynb", 7 | "version": "0.3.2", 8 | "provenance": [], 9 | "collapsed_sections": [] 10 | }, 11 | "kernelspec": { 12 | "name": "python3", 13 | "display_name": "Python 3" 14 | }, 15 | "accelerator": "GPU" 16 | }, 17 | "cells": [ 18 | { 19 | "cell_type": "code", 20 | "metadata": { 21 | "id": "3YM3-zGxLa63", 22 | "colab_type": "code", 23 | "colab": {} 24 | }, 25 | "source": [ 26 | "import torch\n", 27 | "import pdb\n", 28 | "import torch.nn as nn\n", 29 | "import math\n", 30 | "from torch.autograd import Variable\n", 31 | "from torch.autograd import Function\n", 32 | "import time\n", 33 | "from torch.distributions.relaxed_bernoulli import RelaxedBernoulli,LogitRelaxedBernoulli\n", 34 | "import numpy as np\n", 35 | "\n", 36 | "\n", 37 | "def Binarize(tensor,quant_mode='det'):\n", 38 | "    if quant_mode=='det':\n", 39 | "        tensor = tensor.sign()\n", 40 | "        tensor[tensor==0] = 1\n", 41 | "        return tensor\n", 42 | "    else:\n", 43 | "        return tensor.add_(1).div_(2).add_(torch.rand(tensor.size()).add(-0.5)).clamp_(0,1).round().mul_(2).add_(-1)\n", 44 | "\n", 45 | "def sample2(mu, log_sigma2):\n", 46 | "    eps = torch.randn_like(mu)\n", 47 | "    s = mu + torch.exp(log_sigma2 / 2) * eps\n", 48 | "    return s\n", 49 | "    \n", 50 | "    \n", 51 | "def sample_gumbel(shape, eps=1e-20):\n", 52 
| " unif = torch.rand(*shape).cuda()\n", 53 | " g = -torch.log(-torch.log(unif + eps))\n", 54 | " return g\n", 55 | "\n", 56 | "def sample_gumbel_softmax(logits, temperature):\n", 57 | " \"\"\"\n", 58 | " Input:\n", 59 | " logits: Tensor of log probs, shape = BS x k\n", 60 | " temperature = scalar\n", 61 | " \n", 62 | " Output: Tensor of values sampled from Gumbel softmax.\n", 63 | " These will tend towards a one-hot representation in the limit of temp -> 0\n", 64 | " shape = BS x k\n", 65 | " \"\"\"\n", 66 | " g = sample_gumbel(logits.shape)\n", 67 | " h = (g + logits)/temperature\n", 68 | " h_max = h.max(dim=-1, keepdim=True)[0]\n", 69 | " h = h - h_max\n", 70 | " cache = torch.exp(h)\n", 71 | " y = cache / cache.sum(dim=-1, keepdim=True)\n", 72 | " return y\n", 73 | " \n", 74 | "def sampling(mu,sig):\n", 75 | " x = Normal(mu,sig)\n", 76 | "# x = x.sample(torch.tensor([out_features]))\n", 77 | "# print(x.cdf)\n", 78 | " p = 1 - x.cdf(0)\n", 79 | "# print((x.cdf(0))[0])\n", 80 | "# p = Binarize(p)\n", 81 | "# print(p[0])\n", 82 | " a = ((p+1)/2).bernoulli()\n", 83 | " a = a*2-1\n", 84 | "# print(a[0])\n", 85 | " a = torch.nn.functional.gumbel_softmax(p, tau=1, hard=True, eps=1e-10, dim=-1)\n", 86 | "# \n", 87 | "# l = LogitRelaxedBernoulli(torch.tensor([1.]).cuda(),p)\n", 88 | "# l = l.sample()\n", 89 | "# a = sample_gumbel_softmax(p,1.0)\n", 90 | "# print(x[0]) \n", 91 | " return p\n", 92 | "\n", 93 | "\n", 94 | "\n", 95 | "import torch.nn._functions as tnnf\n", 96 | "\n", 97 | "\n", 98 | "\n", 99 | " \n", 100 | "class PBinarizeLinear(nn.Linear):\n", 101 | "\n", 102 | " def __init__(self, *kargs, **kwargs):\n", 103 | " super(PBinarizeLinear, self).__init__(*kargs, **kwargs)\n", 104 | "# w = torch.empty_like(self.weight)\n", 105 | "# self.weight.data = nn.init.uniform_(w,-1,1)\n", 106 | "# theta.requires_grad_\n", 107 | "# self.weight.data = ((theta+1)/2).bernoulli()\n", 108 | "# self.weight.data = Binarize(self.weight.data-0.5)\n", 109 | "# self.weight.data = Binarize(theta)\n", 110 | " \n", 111 | "\n", 112 | " def forward(self, input):\n", 113 | "# print(input.data[0])\n", 114 | " \n", 115 | " \n", 116 | " if not hasattr(self.weight,'org'):\n", 117 | " self.weight.org=self.weight.data.clone() \n", 118 | " \n", 119 | "# self.weight.data=Binarize(self.weight.org)\n", 120 | "# print(self.weight.data)\n", 121 | "# print(self.weight.org)\n", 122 | "# theta = self.weight\n", 123 | " theta = torch.tanh(self.weight)\n", 124 | "# print(theta)\n", 125 | "# print(input[0])\n", 126 | " \n", 127 | "\n", 128 | "# print(input[0])\n", 129 | " if input.size(1) != 784:\n", 130 | " mu = nn.functional.linear(input,theta)\n", 131 | " left = input**2 - (1- input**2)\n", 132 | " right = theta**2 - (1-theta**2)\n", 133 | " sigma = 1 - nn.functional.linear(left,right)\n", 134 | " else:\n", 135 | "# print((input**2)[0])\n", 136 | "# print((1-(theta**2))[0])\n", 137 | " mu = nn.functional.linear(input,theta) \n", 138 | " sigma = nn.functional.linear(input**2,1-(theta**2))\n", 139 | " \n", 140 | "# if input.size(1) == 784:\n", 141 | "# input.data = Binarize(input.data)\n", 142 | "# # print(input[0])\n", 143 | "# input2 = input\n", 144 | "# mu = nn.functional.linear(input2,theta)\n", 145 | "# left = input2**2 - (1- input2**2)\n", 146 | "# right = theta**2 - (1-theta**2)\n", 147 | "# sigma = torch.ones_like(mu) - nn.functional.linear(left,right)\n", 148 | " \n", 149 | "# print(left[0])\n", 150 | "# print(right[0])\n", 151 | " \n", 152 | " \n", 153 | "# ss= (input**2)@(1-(theta**2).t())\n", 154 | "# print(ss[0])\n", 
155 | "# print(sigma[0])\n", 156 | "\n", 157 | "# m = mu.mean(0,True)\n", 158 | " \n", 159 | "# v = sigma.var(0,True)\n", 160 | " \n", 161 | "# mu = 0.5*(mu-m)/((v+(0.0001)).sqrt()+0.5)\n", 162 | "# sigma = 0.5**2*sigma/(v+0.0001)\n", 163 | "\n", 164 | " \n", 165 | " \n", 166 | " out1 = sampling(mu,sigma)\n", 167 | "\n", 168 | " if self.out_features==10:\n", 169 | " return mu\n", 170 | " else:\n", 171 | " return out1\n", 172 | "\n", 173 | "\n", 174 | "\n", 175 | "\n", 176 | "class BinarizeLinear(nn.Linear):\n", 177 | "\n", 178 | " def __init__(self, *kargs, **kwargs):\n", 179 | " super(BinarizeLinear, self).__init__(*kargs, **kwargs)\n", 180 | "\n", 181 | " def forward(self, input):\n", 182 | "\n", 183 | " if input.size(1) != 784:\n", 184 | " input.data=Binarize(input.data)\n", 185 | " if not hasattr(self.weight,'org'):\n", 186 | " self.weight.org=self.weight.data.clone()\n", 187 | " self.weight.data=Binarize(self.weight.org)\n", 188 | " out = nn.functional.linear(input, self.weight)\n", 189 | " if not self.bias is None:\n", 190 | " self.bias.org=self.bias.data.clone()\n", 191 | " out += self.bias.view(1, -1).expand_as(out)\n", 192 | "# print(self.weight)\n", 193 | " p = np.count_nonzero((self.weight.data.cpu()+1)/2)/np.count_nonzero((self.weight.data.cpu()))\n", 194 | " print(self.weight.data)\n", 195 | " return out\n", 196 | "\n" 197 | ], 198 | "execution_count": 0, 199 | "outputs": [] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "metadata": { 204 | "id": "ggHwTvf4MRqC", 205 | "colab_type": "code", 206 | "outputId": "dd576616-4f60-49fe-e097-99261d99a451", 207 | "colab": { 208 | "base_uri": "https://localhost:8080/", 209 | "height": 86 210 | } 211 | }, 212 | "source": [ 213 | "from __future__ import print_function\n", 214 | "import argparse\n", 215 | "import torch\n", 216 | "import torch.nn as nn\n", 217 | "import torch.nn.functional as F\n", 218 | "import torch.optim as optim\n", 219 | "from torchvision import datasets, transforms\n", 220 | "from torch.autograd import Variable\n", 221 | "from tqdm import tqdm\n", 222 | "from torch.distributions.normal import Normal\n", 223 | "from torch.distributions.relaxed_bernoulli import RelaxedBernoulli\n", 224 | "from torch.distributions.relaxed_categorical import RelaxedOneHotCategorical\n", 225 | "\n", 226 | "from torch.distributions.categorical import Categorical\n", 227 | "\n", 228 | "torch.manual_seed(1)\n", 229 | "# if args.cuda:\n", 230 | "# torch.cuda.manual_seed(args.seed)\n", 231 | "\n", 232 | "\n", 233 | "# kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}\n", 234 | "train_loader = torch.utils.data.DataLoader(\n", 235 | " datasets.MNIST('../data', train=True, download=True,\n", 236 | " transform=transforms.Compose([\n", 237 | " transforms.ToTensor(),\n", 238 | " transforms.Normalize((0.1307,), (0.3081,))\n", 239 | " ])),\n", 240 | " batch_size=128, shuffle=True)\n", 241 | "test_loader = torch.utils.data.DataLoader(\n", 242 | " datasets.MNIST('../data', train=False, transform=transforms.Compose([\n", 243 | " transforms.ToTensor(),\n", 244 | " transforms.Normalize((0.1307,), (0.3081,))\n", 245 | " ])),\n", 246 | " batch_size=128, shuffle=True)\n", 247 | "\n", 248 | "\n", 249 | "\n", 250 | "\n", 251 | "\n", 252 | "# 32C3 - MP2 - 64C3 - Mp2 - 512FC - SM10c\n", 253 | "class Net(nn.Module):\n", 254 | " def __init__(self):\n", 255 | " super(Net, self).__init__()\n", 256 | " \n", 257 | "# self.conv1 = PBinarizeConv2d(1, 32, kernel_size=3)\n", 258 | " \n", 259 | "# self.mp1 = nn.MaxPool2d(kernel_size=2, stride=2)\n", 260 | 
"# self.htanh1 = nn.Hardtanh()\n", 261 | " \n", 262 | "# self.conv2 = PBinarizeConv2d(32, 64, kernel_size=3)\n", 263 | "# self.mp2 = nn.MaxPool2d(kernel_size=2, stride=2)\n", 264 | "# self.htanh2 = nn.Hardtanh()\n", 265 | " \n", 266 | " self.fc1 = PBinarizeLinear(784, 512)\n", 267 | "# self.htanh3 = nn.Hardtanh()\n", 268 | " \n", 269 | " self.fc2 = PBinarizeLinear(512, 10)\n", 270 | "\n", 271 | "\n", 272 | " # 32C3 - MP2 - 64C3 - Mp2 - 512FC - SM10c\n", 273 | " \n", 274 | " def forward(self, x):\n", 275 | " \n", 276 | " \n", 277 | " x = x.view(x.size(0), -1)\n", 278 | "# print(x.size())\n", 279 | " \n", 280 | " x = self.fc1(x)\n", 281 | " \n", 282 | " x = self.fc2(x)\n", 283 | "# print(x)\n", 284 | "\n", 285 | " \n", 286 | " return x\n", 287 | " \n", 288 | "\n", 289 | "model = Net()\n", 290 | "\n", 291 | "print(model)\n", 292 | "\n", 293 | "torch.cuda.device('cuda')\n", 294 | "model.cuda()\n", 295 | "\n", 296 | "\n", 297 | "criterion = nn.CrossEntropyLoss()\n", 298 | "optimizer = optim.Adam(model.parameters(), lr=0.01)\n", 299 | "\n", 300 | "\n", 301 | "def train(epoch):\n", 302 | " model.train()\n", 303 | " \n", 304 | " losses = []\n", 305 | " trainloader = tqdm(train_loader)\n", 306 | " \n", 307 | " for batch_idx, (data, target) in enumerate(trainloader):\n", 308 | " \n", 309 | " data, target = data.cuda(), target.cuda()\n", 310 | " data, target = Variable(data), Variable(target)\n", 311 | " optimizer.zero_grad()\n", 312 | " output = model(data)\n", 313 | " loss = criterion(output, target)\n", 314 | "\n", 315 | "# print(loss)\n", 316 | "\n", 317 | "# if epoch%40==0:\n", 318 | "# optimizer.param_groups[0]['lr']=optimizer.param_groups[0]['lr']*0.1\n", 319 | "\n", 320 | "# optimizer.zero_grad()\n", 321 | "# \n", 322 | " loss.backward()\n", 323 | " \n", 324 | "# for p in list(model.parameters()):\n", 325 | "# if hasattr(p,'org'):\n", 326 | "# p.data.copy_(p.org)\n", 327 | " optimizer.step()\n", 328 | " \n", 329 | "# for p in list(model.parameters()):\n", 330 | "# if hasattr(p,'org'):\n", 331 | "# p.org.copy_(p.data.clamp_(-0.9,0.9))\n", 332 | " \n", 333 | " losses.append(loss.item())\n", 334 | " trainloader.set_postfix(loss=np.mean(losses), epoch=epoch)\n", 335 | "\n", 336 | "\n", 337 | "\n", 338 | "def test():\n", 339 | " model.eval()\n", 340 | " test_loss = 0\n", 341 | " correct = 0\n", 342 | " testloader = tqdm(test_loader)\n", 343 | " for data, target in testloader:\n", 344 | " data, target = data.cuda(), target.cuda()\n", 345 | " with torch.no_grad():\n", 346 | " data = Variable(data)\n", 347 | " target = Variable(target)\n", 348 | " output = model(data)\n", 349 | " test_loss += criterion(output, target).item() # sum up batch loss\n", 350 | " pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability\n", 351 | " correct += pred.eq(target.data.view_as(pred)).cpu().sum()\n", 352 | " \n", 353 | " \n", 354 | "\n", 355 | " testloader.set_postfix(loss=test_loss / len(test_loader.dataset),acc=str((100. 
*correct / len(test_loader.dataset)).numpy())+'%')\n", 356 | "    \n", 357 | "    test_loss /= len(test_loader.dataset)\n", 358 | "    \n", 359 | "    \n", 360 | "\n", 361 | "\n" 362 | ], 363 | "execution_count": 2, 364 | "outputs": [ 365 | { 366 | "output_type": "stream", 367 | "text": [ 368 | "Net(\n", 369 | "  (fc1): PBinarizeLinear(in_features=784, out_features=512, bias=True)\n", 370 | "  (fc2): PBinarizeLinear(in_features=512, out_features=10, bias=True)\n", 371 | ")\n" 372 | ], 373 | "name": "stdout" 374 | } 375 | ] 376 | }, 377 | { 378 | "cell_type": "code", 379 | "metadata": { 380 | "id": "bxgRGAoMbkCf", 381 | "colab_type": "code", 382 | "outputId": "2e24e9c1-274c-4eb1-989b-242a83e4e289", 383 | "colab": { 384 | "base_uri": "https://localhost:8080/", 385 | "height": 397 386 | } 387 | }, 388 | "source": [ 389 | "%%time\n", 390 | "for epoch in range(10):\n", 391 | "    train(epoch)\n", 392 | "    test()" 393 | ], 394 | "execution_count": 3, 395 | "outputs": [ 396 | { 397 | "output_type": "stream", 398 | "text": [ 399 | "100%|██████████| 469/469 [00:13<00:00, 35.24it/s, epoch=0, loss=0.945]\n", 400 | "100%|██████████| 79/79 [00:01<00:00, 44.29it/s, acc=88%, loss=0.00286]\n", 401 | "100%|██████████| 469/469 [00:12<00:00, 36.39it/s, epoch=1, loss=0.311]\n", 402 | "100%|██████████| 79/79 [00:01<00:00, 43.47it/s, acc=91%, loss=0.00216]\n", 403 | "100%|██████████| 469/469 [00:11<00:00, 40.58it/s, epoch=2, loss=0.247]\n", 404 | "100%|██████████| 79/79 [00:01<00:00, 43.32it/s, acc=93%, loss=0.00179]\n", 405 | "100%|██████████| 469/469 [00:11<00:00, 41.57it/s, epoch=3, loss=0.209]\n", 406 | "100%|██████████| 79/79 [00:01<00:00, 43.29it/s, acc=94%, loss=0.00154]\n", 407 | "100%|██████████| 469/469 [00:11<00:00, 40.66it/s, epoch=4, loss=0.185]\n", 408 | "100%|██████████| 79/79 [00:01<00:00, 43.47it/s, acc=94%, loss=0.00144]\n", 409 | "100%|██████████| 469/469 [00:11<00:00, 40.69it/s, epoch=5, loss=0.166]\n", 410 | "100%|██████████| 79/79 [00:01<00:00, 43.42it/s, acc=94%, loss=0.00135]\n", 411 | "100%|██████████| 469/469 [00:11<00:00, 40.57it/s, epoch=6, loss=0.148]\n", 412 | "100%|██████████| 79/79 [00:01<00:00, 43.05it/s, acc=95%, loss=0.00121]\n", 413 | "100%|██████████| 469/469 [00:12<00:00, 36.52it/s, epoch=7, loss=0.134]\n", 414 | "100%|██████████| 79/79 [00:01<00:00, 43.68it/s, acc=95%, loss=0.0012]\n", 415 | "100%|██████████| 469/469 [00:11<00:00, 40.86it/s, epoch=8, loss=0.124]\n", 416 | "100%|██████████| 79/79 [00:01<00:00, 45.09it/s, acc=95%, loss=0.00115]\n", 417 | "100%|██████████| 469/469 [00:11<00:00, 40.57it/s, epoch=9, loss=0.112]\n", 418 | "100%|██████████| 79/79 [00:01<00:00, 42.62it/s, acc=95%, loss=0.000996]" 419 | ], 420 | "name": "stderr" 421 | }, 422 | { 423 | "output_type": "stream", 424 | "text": [ 425 | "CPU times: user 2min 4s, sys: 5 s, total: 2min 9s\n", 426 | "Wall time: 2min 18s\n" 427 | ], 428 | "name": "stdout" 429 | }, 430 | { 431 | "output_type": "stream", 432 | "text": [ 433 | "\n" 434 | ], 435 | "name": "stderr" 436 | } 437 | ] 438 | }, 439 | { 440 | "cell_type": "code", 441 | "metadata": { 442 | "id": "DiDjIjU6Mf_y", 443 | "colab_type": "code", 444 | "colab": { 445 | "base_uri": "https://localhost:8080/", 446 | "height": 104 447 | }, 448 | "outputId": "6a0ae9a4-30e3-4502-b1e8-e398ed17b4cd" 449 | }, 450 | "source": [ 451 | "a = torch.rand(5,4)\n", 452 | "a" 453 | ], 454 | "execution_count": 4, 455 | "outputs": [ 456 | { 457 | "output_type": "execute_result", 458 | "data": { 459 | "text/plain": [ 460 | "tensor([[0.2111, 0.3660, 0.8533, 0.7804],\n", 461 | "        [0.6190, 0.2480, 
0.4822, 0.3310],\n", 462 | " [0.3542, 0.7472, 0.1691, 0.9444],\n", 463 | " [0.2635, 0.7556, 0.0211, 0.7558],\n", 464 | " [0.5158, 0.6163, 0.2596, 0.7621]])" 465 | ] 466 | }, 467 | "metadata": { 468 | "tags": [] 469 | }, 470 | "execution_count": 4 471 | } 472 | ] 473 | }, 474 | { 475 | "cell_type": "code", 476 | "metadata": { 477 | "id": "41XggYVRr8az", 478 | "colab_type": "code", 479 | "colab": {} 480 | }, 481 | "source": [ 482 | "" 483 | ], 484 | "execution_count": 0, 485 | "outputs": [] 486 | }, 487 | { 488 | "cell_type": "code", 489 | "metadata": { 490 | "id": "sphyNMNrfwQz", 491 | "colab_type": "code", 492 | "colab": { 493 | "base_uri": "https://localhost:8080/", 494 | "height": 104 495 | }, 496 | "outputId": "702701a3-6b2b-45aa-dacd-7d5f5aef6f6b" 497 | }, 498 | "source": [ 499 | "torch.ones_like(a)" 500 | ], 501 | "execution_count": 5, 502 | "outputs": [ 503 | { 504 | "output_type": "execute_result", 505 | "data": { 506 | "text/plain": [ 507 | "tensor([[1., 1., 1., 1.],\n", 508 | " [1., 1., 1., 1.],\n", 509 | " [1., 1., 1., 1.],\n", 510 | " [1., 1., 1., 1.],\n", 511 | " [1., 1., 1., 1.]])" 512 | ] 513 | }, 514 | "metadata": { 515 | "tags": [] 516 | }, 517 | "execution_count": 5 518 | } 519 | ] 520 | }, 521 | { 522 | "cell_type": "code", 523 | "metadata": { 524 | "id": "3GP1UM1ufwHF", 525 | "colab_type": "code", 526 | "colab": {} 527 | }, 528 | "source": [ 529 | "" 530 | ], 531 | "execution_count": 0, 532 | "outputs": [] 533 | }, 534 | { 535 | "cell_type": "code", 536 | "metadata": { 537 | "id": "GaDZ92vonUN0", 538 | "colab_type": "code", 539 | "colab": { 540 | "base_uri": "https://localhost:8080/", 541 | "height": 104 542 | }, 543 | "outputId": "60c20ca0-96b0-4093-ef8e-5a362d132b1a" 544 | }, 545 | "source": [ 546 | "a.bernoulli()" 547 | ], 548 | "execution_count": 6, 549 | "outputs": [ 550 | { 551 | "output_type": "execute_result", 552 | "data": { 553 | "text/plain": [ 554 | "tensor([[0., 1., 1., 1.],\n", 555 | " [0., 0., 0., 0.],\n", 556 | " [1., 0., 0., 1.],\n", 557 | " [0., 1., 0., 0.],\n", 558 | " [0., 0., 0., 1.]])" 559 | ] 560 | }, 561 | "metadata": { 562 | "tags": [] 563 | }, 564 | "execution_count": 6 565 | } 566 | ] 567 | }, 568 | { 569 | "cell_type": "code", 570 | "metadata": { 571 | "id": "8YWe1Ti0MyMt", 572 | "colab_type": "code", 573 | "colab": { 574 | "base_uri": "https://localhost:8080/", 575 | "height": 104 576 | }, 577 | "outputId": "80b86d19-6aaa-41fb-8093-ddb57966b22d" 578 | }, 579 | "source": [ 580 | "a=Binarize(a)\n", 581 | "a" 582 | ], 583 | "execution_count": 7, 584 | "outputs": [ 585 | { 586 | "output_type": "execute_result", 587 | "data": { 588 | "text/plain": [ 589 | "tensor([[1., 1., 1., 1.],\n", 590 | " [1., 1., 1., 1.],\n", 591 | " [1., 1., 1., 1.],\n", 592 | " [1., 1., 1., 1.],\n", 593 | " [1., 1., 1., 1.]])" 594 | ] 595 | }, 596 | "metadata": { 597 | "tags": [] 598 | }, 599 | "execution_count": 7 600 | } 601 | ] 602 | }, 603 | { 604 | "cell_type": "code", 605 | "metadata": { 606 | "id": "tqOWLfDtr-H2", 607 | "colab_type": "code", 608 | "colab": { 609 | "base_uri": "https://localhost:8080/", 610 | "height": 104 611 | }, 612 | "outputId": "2f7fc180-2971-4c82-9076-c430d0ec06ae" 613 | }, 614 | "source": [ 615 | "torch.nn.functional.gumbel_softmax(a, tau=1, hard=True, eps=1e-10, dim=-1)" 616 | ], 617 | "execution_count": 8, 618 | "outputs": [ 619 | { 620 | "output_type": "execute_result", 621 | "data": { 622 | "text/plain": [ 623 | "tensor([[0., 0., 0., 1.],\n", 624 | " [0., 1., 0., 0.],\n", 625 | " [1., 0., 0., 0.],\n", 626 | " [0., 0., 0., 1.],\n", 627 | " 
[0., 0., 0., 1.]])" 628 | ] 629 | }, 630 | "metadata": { 631 | "tags": [] 632 | }, 633 | "execution_count": 8 634 | } 635 | ] 636 | }, 637 | { 638 | "cell_type": "code", 639 | "metadata": { 640 | "id": "Qu8HiOmqjwdx", 641 | "colab_type": "code", 642 | "colab": { 643 | "base_uri": "https://localhost:8080/", 644 | "height": 104 645 | }, 646 | "outputId": "78423f00-aa52-44e9-ba18-90f2bfbf09f8" 647 | }, 648 | "source": [ 649 | "(a+1)/2" 650 | ], 651 | "execution_count": 9, 652 | "outputs": [ 653 | { 654 | "output_type": "execute_result", 655 | "data": { 656 | "text/plain": [ 657 | "tensor([[1., 1., 1., 1.],\n", 658 | " [1., 1., 1., 1.],\n", 659 | " [1., 1., 1., 1.],\n", 660 | " [1., 1., 1., 1.],\n", 661 | " [1., 1., 1., 1.]])" 662 | ] 663 | }, 664 | "metadata": { 665 | "tags": [] 666 | }, 667 | "execution_count": 9 668 | } 669 | ] 670 | }, 671 | { 672 | "cell_type": "code", 673 | "metadata": { 674 | "id": "0hkIV7ztMVyZ", 675 | "colab_type": "code", 676 | "colab": { 677 | "base_uri": "https://localhost:8080/", 678 | "height": 104 679 | }, 680 | "outputId": "abf001e9-a205-47f4-ef49-530deffd15eb" 681 | }, 682 | "source": [ 683 | "a.tanh()" 684 | ], 685 | "execution_count": 10, 686 | "outputs": [ 687 | { 688 | "output_type": "execute_result", 689 | "data": { 690 | "text/plain": [ 691 | "tensor([[0.7616, 0.7616, 0.7616, 0.7616],\n", 692 | " [0.7616, 0.7616, 0.7616, 0.7616],\n", 693 | " [0.7616, 0.7616, 0.7616, 0.7616],\n", 694 | " [0.7616, 0.7616, 0.7616, 0.7616],\n", 695 | " [0.7616, 0.7616, 0.7616, 0.7616]])" 696 | ] 697 | }, 698 | "metadata": { 699 | "tags": [] 700 | }, 701 | "execution_count": 10 702 | } 703 | ] 704 | }, 705 | { 706 | "cell_type": "code", 707 | "metadata": { 708 | "id": "9idSm857otNw", 709 | "colab_type": "code", 710 | "colab": { 711 | "base_uri": "https://localhost:8080/", 712 | "height": 35 713 | }, 714 | "outputId": "dd7fc7e4-1e6b-4f5f-fe64-031272e58262" 715 | }, 716 | "source": [ 717 | "a.mean(0)" 718 | ], 719 | "execution_count": 11, 720 | "outputs": [ 721 | { 722 | "output_type": "execute_result", 723 | "data": { 724 | "text/plain": [ 725 | "tensor([1., 1., 1., 1.])" 726 | ] 727 | }, 728 | "metadata": { 729 | "tags": [] 730 | }, 731 | "execution_count": 11 732 | } 733 | ] 734 | }, 735 | { 736 | "cell_type": "code", 737 | "metadata": { 738 | "id": "cUGDVCa1pZes", 739 | "colab_type": "code", 740 | "colab": {} 741 | }, 742 | "source": [ 743 | "p= np.count_nonzero((a+1)/2,axis=0)/np.count_nonzero(a,axis=0)" 744 | ], 745 | "execution_count": 0, 746 | "outputs": [] 747 | }, 748 | { 749 | "cell_type": "code", 750 | "metadata": { 751 | "id": "NJMpKyyqIkdh", 752 | "colab_type": "code", 753 | "colab": { 754 | "base_uri": "https://localhost:8080/", 755 | "height": 35 756 | }, 757 | "outputId": "5ace1b64-ef81-4696-c378-b215d94282f6" 758 | }, 759 | "source": [ 760 | "1-(p)**2" 761 | ], 762 | "execution_count": 13, 763 | "outputs": [ 764 | { 765 | "output_type": "execute_result", 766 | "data": { 767 | "text/plain": [ 768 | "array([0., 0., 0., 0.])" 769 | ] 770 | }, 771 | "metadata": { 772 | "tags": [] 773 | }, 774 | "execution_count": 13 775 | } 776 | ] 777 | }, 778 | { 779 | "cell_type": "code", 780 | "metadata": { 781 | "id": "1m-AodLmu8D3", 782 | "colab_type": "code", 783 | "colab": { 784 | "base_uri": "https://localhost:8080/", 785 | "height": 35 786 | }, 787 | "outputId": "78a57934-b025-4000-d029-f43d75f6542d" 788 | }, 789 | "source": [ 790 | "np.count_nonzero((a+1)/2)" 791 | ], 792 | "execution_count": 14, 793 | "outputs": [ 794 | { 795 | "output_type": "execute_result", 796 | 
"data": { 797 | "text/plain": [ 798 | "20" 799 | ] 800 | }, 801 | "metadata": { 802 | "tags": [] 803 | }, 804 | "execution_count": 14 805 | } 806 | ] 807 | }, 808 | { 809 | "cell_type": "code", 810 | "metadata": { 811 | "id": "xwOCJG610Shm", 812 | "colab_type": "code", 813 | "colab": {} 814 | }, 815 | "source": [ 816 | "mu = torch.randn(5)\n", 817 | "sig = torch.randn(5)\n" 818 | ], 819 | "execution_count": 0, 820 | "outputs": [] 821 | }, 822 | { 823 | "cell_type": "code", 824 | "metadata": { 825 | "id": "9_oUI0eo0XgN", 826 | "colab_type": "code", 827 | "colab": { 828 | "base_uri": "https://localhost:8080/", 829 | "height": 35 830 | }, 831 | "outputId": "d3454a76-3475-442a-fa05-886ddb256d79" 832 | }, 833 | "source": [ 834 | "x=Normal(mu,sig)\n", 835 | "1 - x.cdf(0)" 836 | ], 837 | "execution_count": 16, 838 | "outputs": [ 839 | { 840 | "output_type": "execute_result", 841 | "data": { 842 | "text/plain": [ 843 | "tensor([1.0000, 0.0000, 0.0000, 0.9960, 0.2989])" 844 | ] 845 | }, 846 | "metadata": { 847 | "tags": [] 848 | }, 849 | "execution_count": 16 850 | } 851 | ] 852 | }, 853 | { 854 | "cell_type": "code", 855 | "metadata": { 856 | "id": "0jw0mT8duoik", 857 | "colab_type": "code", 858 | "colab": {} 859 | }, 860 | "source": [ 861 | "m = -4.6\n", 862 | "v = 25936\n", 863 | "x = Normal(m,v)\n", 864 | "p = 1 - x.cdf(0)\n", 865 | "# s = sample_gumbel_softmax(p,1.0)\n", 866 | "\n" 867 | ], 868 | "execution_count": 0, 869 | "outputs": [] 870 | }, 871 | { 872 | "cell_type": "code", 873 | "metadata": { 874 | "id": "amjSvoApwIVS", 875 | "colab_type": "code", 876 | "colab": { 877 | "base_uri": "https://localhost:8080/", 878 | "height": 69 879 | }, 880 | "outputId": "3b4bece8-4c82-4174-c0f2-3758df32b339" 881 | }, 882 | "source": [ 883 | "w = torch.empty(3, 5)\n", 884 | "nn.init.uniform_(w,-1,1)" 885 | ], 886 | "execution_count": 18, 887 | "outputs": [ 888 | { 889 | "output_type": "execute_result", 890 | "data": { 891 | "text/plain": [ 892 | "tensor([[-0.9226, -0.8018, -0.3588, 0.8282, -0.8950],\n", 893 | " [-0.5052, -0.7240, -0.3942, -0.4255, -0.2595],\n", 894 | " [-0.8417, 0.1117, -0.6369, -0.6278, -0.7887]])" 895 | ] 896 | }, 897 | "metadata": { 898 | "tags": [] 899 | }, 900 | "execution_count": 18 901 | } 902 | ] 903 | }, 904 | { 905 | "cell_type": "code", 906 | "metadata": { 907 | "id": "2cQjhRs_wIPo", 908 | "colab_type": "code", 909 | "colab": { 910 | "base_uri": "https://localhost:8080/", 911 | "height": 165 912 | }, 913 | "outputId": "b053bf4a-e9b9-457b-a231-6af4a91a7116" 914 | }, 915 | "source": [ 916 | "p.sample()" 917 | ], 918 | "execution_count": 19, 919 | "outputs": [ 920 | { 921 | "output_type": "error", 922 | "ename": "AttributeError", 923 | "evalue": "ignored", 924 | "traceback": [ 925 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 926 | "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", 927 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msample\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 928 | "\u001b[0;31mAttributeError\u001b[0m: 'Tensor' object has no attribute 'sample'" 929 | ] 930 | } 931 | ] 932 | }, 933 | { 934 | "cell_type": "code", 935 | "metadata": { 936 | "id": "0AZEDswQu1TZ", 937 | "colab_type": "code", 938 | "colab": {} 939 | }, 940 | "source": [ 941 | "aa = Normal(m,v)\n", 942 | "# aa.sample(torch.tensor([20]))" 943 | ], 944 
| "execution_count": 0, 945 | "outputs": [] 946 | }, 947 | { 948 | "cell_type": "code", 949 | "metadata": { 950 | "id": "zJwtW9ld5_tf", 951 | "colab_type": "code", 952 | "colab": {} 953 | }, 954 | "source": [ 955 | "" 956 | ], 957 | "execution_count": 0, 958 | "outputs": [] 959 | }, 960 | { 961 | "cell_type": "code", 962 | "metadata": { 963 | "id": "b5baFgYuu39G", 964 | "colab_type": "code", 965 | "colab": {} 966 | }, 967 | "source": [ 968 | "1 - x.cdf(0)" 969 | ], 970 | "execution_count": 0, 971 | "outputs": [] 972 | }, 973 | { 974 | "cell_type": "code", 975 | "metadata": { 976 | "id": "xAMjkieNIlS-", 977 | "colab_type": "code", 978 | "colab": {} 979 | }, 980 | "source": [ 981 | "m = RelaxedOneHotCategorical(torch.tensor([1.]),a)\n", 982 | "m.sample()" 983 | ], 984 | "execution_count": 0, 985 | "outputs": [] 986 | }, 987 | { 988 | "cell_type": "code", 989 | "metadata": { 990 | "id": "kEYyrZzMZkuJ", 991 | "colab_type": "code", 992 | "colab": {} 993 | }, 994 | "source": [ 995 | "\"\"# # Execute this code block to install dependencies when running on colab\n", 996 | "# try:\n", 997 | "# import torch\n", 998 | "# except:\n", 999 | "# from os.path import exists\n", 1000 | "# from wheel.pep425tags import get_abbr_impl, get_impl_ver, get_abi_tag\n", 1001 | "# platform = '{}{}-{}'.format(get_abbr_impl(), get_impl_ver(), get_abi_tag())\n", 1002 | "# cuda_output = !ldconfig -p|grep cudart.so|sed -e 's/.*\\.\\([0-9]*\\)\\.\\([0-9]*\\)$/cu\\1\\2/'\n", 1003 | "# accelerator = cuda_output[0] if exists('/dev/nvidia0') else 'cpu'\n", 1004 | "\n", 1005 | "# !pip install -q http://download.pytorch.org/whl/{accelerator}/torch-1.0.0-{platform}-linux_x86_64.whl torchvision\n", 1006 | "\n", 1007 | "# try: \n", 1008 | "# import torchbearer\n", 1009 | "# except:\n", 1010 | "# !pip install torchbearer\n", 1011 | " \n", 1012 | "# from torchbearer import Trial\n", 1013 | "# torchbearer_trial = Trial(model, optimizer, criterion, metrics=['loss', 'accuracy']).to('cuda:0')\n", 1014 | "# torchbearer_trial.with_generators(train_loader, test_generator=test_loader)\n", 1015 | "# torchbearer_trial.run(epochs=5)" 1016 | ], 1017 | "execution_count": 0, 1018 | "outputs": [] 1019 | } 1020 | ] 1021 | } --------------------------------------------------------------------------------