├── .gitignore ├── README.md └── notebook ├── cnn.ipynb ├── lstm.ipynb ├── mdn-cls.ipynb ├── mdn.ipynb ├── mha.ipynb ├── mln.ipynb ├── mlp.ipynb └── optm.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Yet Another PyTorch Tutorial 2 | #### All notebooks are colab-compatible. 3 | 4 | - MLP [notebook](https://github.com/sjchoi86/yet-another-pytorch-tutorial/blob/main/notebook/mlp.ipynb) 5 | - CNN [notebook](https://github.com/sjchoi86/yet-another-pytorch-tutorial/blob/main/notebook/cnn.ipynb) 6 | - Linear Regression [notebook](https://github.com/sjchoi86/yet-another-pytorch-tutorial/blob/main/notebook/optm.ipynb) 7 | - LSTM [notebook](https://github.com/sjchoi86/yet-another-pytorch-tutorial/blob/main/notebook/lstm.ipynb) 8 | - Multi-Headed Attention [notebook](https://github.com/sjchoi86/yet-another-pytorch-tutorial/blob/main/notebook/mha.ipynb) 9 | - Mixture Density Network [notebook](https://github.com/sjchoi86/yet-another-pytorch-tutorial/blob/main/notebook/mdn.ipynb) 10 | 11 | #### Special Thanks to [Jerry](https://github.com/jjerry-k). 
12 | 13 | Contact: Sungjoon (sungjoon-choi@korea.ac.kr) 14 | -------------------------------------------------------------------------------- /notebook/lstm.ipynb: -------------------------------------------------------------------------------- 1 | {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"lstm.ipynb","provenance":[],"collapsed_sections":[],"authorship_tag":"ABX9TyO9gYohIEVn4AOL6FCtGFNe"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"Uyq75dn2Mk6x"},"source":["\n"," \n"," \n","
\n"," Colab\n"," \n"," View Source\n","
"]},{"cell_type":"markdown","metadata":{"id":"PJwtTg9a11du"},"source":["# Classification with LSTM"]},{"cell_type":"code","metadata":{"id":"tepz70nH1wwO","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1609935205673,"user_tz":-540,"elapsed":1085,"user":{"displayName":"Sungjoon Choi","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GiFkMNCaA4zshD2C87LC6X0Y7ohjLlu0sIiLepLnQ=s64","userId":"10728677910935649939"}},"outputId":"9e5e04bb-db53-449b-8711-190a72a97895"},"source":["import numpy as np\n","import matplotlib.pyplot as plt\n","import torch\n","import torch.nn as nn\n","import torch.optim as optim\n","import torch.nn.functional as F\n","%matplotlib inline\n","%config InlineBackend.figure_format='retina'\n","print (\"PyTorch version:[%s].\"%(torch.__version__))\n","device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n","print (\"device:[%s].\"%(device))"],"execution_count":1,"outputs":[{"output_type":"stream","text":["PyTorch version:[1.7.0+cu101].\n","device:[cuda:0].\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"gjSfbrHz2NbN"},"source":["### Dataset and Loader"]},{"cell_type":"code","metadata":{"id":"_apH6GPI2Adq","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1609935205674,"user_tz":-540,"elapsed":997,"user":{"displayName":"Sungjoon Choi","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GiFkMNCaA4zshD2C87LC6X0Y7ohjLlu0sIiLepLnQ=s64","userId":"10728677910935649939"}},"outputId":"35ed37dd-88c9-4fed-cbba-baf61280c6bf"},"source":["from torchvision import datasets,transforms\n","mnist_train = datasets.MNIST(root='./data/',train=True,transform=transforms.ToTensor(),download=True)\n","mnist_test = datasets.MNIST(root='./data/',train=False,transform=transforms.ToTensor(),download=True)\n","BATCH_SIZE = 256\n","train_iter = torch.utils.data.DataLoader(mnist_train,batch_size=BATCH_SIZE,shuffle=True,num_workers=1)\n","test_iter = 
torch.utils.data.DataLoader(mnist_test,batch_size=BATCH_SIZE,shuffle=True,num_workers=1)\n","print (\"Done.\")"],"execution_count":2,"outputs":[{"output_type":"stream","text":["Done.\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"10evD4Jg2bQ4"},"source":["### Define Model"]},{"cell_type":"code","metadata":{"id":"QoISvH_O2OWO","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1609935210005,"user_tz":-540,"elapsed":5312,"user":{"displayName":"Sungjoon Choi","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GiFkMNCaA4zshD2C87LC6X0Y7ohjLlu0sIiLepLnQ=s64","userId":"10728677910935649939"}},"outputId":"b14ecad1-40f6-4d69-b237-45fe7abde219"},"source":["class RecurrentNeuralNetworkClass(nn.Module):\n"," def __init__(self,name='rnn',xdim=28,hdim=256,ydim=10,n_layer=3):\n"," super(RecurrentNeuralNetworkClass,self).__init__()\n"," self.name = name\n"," self.xdim = xdim\n"," self.hdim = hdim\n"," self.ydim = ydim\n"," self.n_layer = n_layer # K\n","\n"," self.rnn = nn.LSTM(\n"," input_size=self.xdim,hidden_size=self.hdim,num_layers=self.n_layer,batch_first=True)\n"," self.lin = nn.Linear(self.hdim,self.ydim)\n","\n"," def forward(self,x):\n"," # Set initial hidden and cell states \n"," h0 = torch.zeros(self.n_layer,x.size(0),self.hdim).to(device)\n"," c0 = torch.zeros(self.n_layer,x.size(0),self.hdim).to(device)\n"," # RNN\n"," rnn_out,(hn,cn) = self.rnn(x, (h0,c0)) \n"," # x:[N x L x Q] => rnn_out:[N x L x D]\n"," # Linear\n"," out = self.lin(rnn_out[:,-1 :]).view([-1,self.ydim]) \n"," return out \n","\n","R = RecurrentNeuralNetworkClass(\n"," name='rnn',xdim=28,hdim=256,ydim=10,n_layer=2).to(device)\n","loss = nn.CrossEntropyLoss()\n","optm = optim.Adam(R.parameters(),lr=1e-3)\n","print (\"Done.\")"],"execution_count":3,"outputs":[{"output_type":"stream","text":["Done.\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"liD6DC7KANYR"},"source":["### Check How LSTM Works\n","- `N`: number of 
sequences per batch (batch size)\n","- `L`: sequence length\n","- `Q`: input dim\n","- `K`: number of layers\n","- `D`: LSTM feature dimension\n","\n","` Y,(hn,cn) = LSTM(X) `\n","\n","- `X`: [N x L x Q] - `N` input sequences of length `L` with `Q` dim. \n","- `Y`: [N x L x D] - `N` output sequences of length `L` with `D` feature dim.\n","- `hn`: [K x N x D] - final hidden states of the `N` sequences for each of the `K` layers, with `D` feature dim. \n","- `cn`: [K x N x D] - final cell states of the `N` sequences for each of the `K` layers, with `D` cell dim. "]},{"cell_type":"code","metadata":{"id":"byX3ViAwARpt","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1609935210007,"user_tz":-540,"elapsed":5299,"user":{"displayName":"Sungjoon Choi","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GiFkMNCaA4zshD2C87LC6X0Y7ohjLlu0sIiLepLnQ=s64","userId":"10728677910935649939"}},"outputId":"a192a6f8-a7b9-417f-9558-8c3d2ae04549"},"source":["np.set_printoptions(precision=3)\n","torch.set_printoptions(precision=3)\n","x_numpy = np.random.rand(2,20,28) # [N x L x Q]\n","x_torch = torch.from_numpy(x_numpy).float().to(device)\n","rnn_out,(hn,cn) = R.rnn(x_torch) # forward path\n","\n","print (\"rnn_out:\",rnn_out.shape) # [N x L x D]\n","print (\"Hidden State hn:\",hn.shape) # [K x N x D]\n","print (\"Cell States cn:\",cn.shape) # [K x N x D]"],"execution_count":4,"outputs":[{"output_type":"stream","text":["rnn_out: torch.Size([2, 20, 256])\n","Hidden State hn: torch.Size([2, 2, 256])\n","Cell States cn: torch.Size([2, 2, 256])\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"XuBUgRKD5vTx"},"source":["### Check parameters"]},{"cell_type":"code","metadata":{"id":"raw5y-vn4rWa","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1609935210008,"user_tz":-540,"elapsed":5285,"user":{"displayName":"Sungjoon
Choi","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GiFkMNCaA4zshD2C87LC6X0Y7ohjLlu0sIiLepLnQ=s64","userId":"10728677910935649939"}},"outputId":"21a7f2ab-7095-420b-9e84-e2a4c0051d24"},"source":["np.set_printoptions(precision=3)\n","n_param = 0\n","for p_idx,(param_name,param) in enumerate(R.named_parameters()):\n"," if param.requires_grad:\n"," param_numpy = param.detach().cpu().numpy() # to numpy array \n"," n_param += len(param_numpy.reshape(-1))\n"," print (\"[%d] name:[%s] shape:[%s].\"%(p_idx,param_name,param_numpy.shape))\n"," print (\" val:%s\"%(param_numpy.reshape(-1)[:5]))\n","print (\"Total number of parameters:[%s].\"%(format(n_param,',d')))"],"execution_count":5,"outputs":[{"output_type":"stream","text":["[0] name:[rnn.weight_ih_l0] shape:[(1024, 28)].\n"," val:[-0.018 0.033 0.057 -0.054 -0.017]\n","[1] name:[rnn.weight_hh_l0] shape:[(1024, 256)].\n"," val:[-0.034 -0.043 0.004 0.045 0.001]\n","[2] name:[rnn.bias_ih_l0] shape:[(1024,)].\n"," val:[ 0.005 0.038 -0.008 0.018 -0.02 ]\n","[3] name:[rnn.bias_hh_l0] shape:[(1024,)].\n"," val:[ 0.044 -0.003 -0.032 -0.055 -0.052]\n","[4] name:[rnn.weight_ih_l1] shape:[(1024, 256)].\n"," val:[-0.03 0.042 0.037 -0.037 -0.043]\n","[5] name:[rnn.weight_hh_l1] shape:[(1024, 256)].\n"," val:[-0.042 0.052 -0.012 0.052 -0.012]\n","[6] name:[rnn.bias_ih_l1] shape:[(1024,)].\n"," val:[-0.022 0.015 -0.013 -0.044 0.05 ]\n","[7] name:[rnn.bias_hh_l1] shape:[(1024,)].\n"," val:[-0.017 -0.015 0.047 0.035 -0.031]\n","[8] name:[lin.weight] shape:[(10, 256)].\n"," val:[-0.035 -0.037 0.031 -0.045 0.01 ]\n","[9] name:[lin.bias] shape:[(10,)].\n"," val:[-0.034 0.018 0.044 -0.038 -0.013]\n","Total number of parameters:[821,770].\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"J6rRmikB8dxU"},"source":["### Simple Forward Path 
"]},{"cell_type":"code","metadata":{"id":"DBdN6qoO8dah","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1609935210008,"user_tz":-540,"elapsed":5270,"user":{"displayName":"Sungjoon Choi","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GiFkMNCaA4zshD2C87LC6X0Y7ohjLlu0sIiLepLnQ=s64","userId":"10728677910935649939"}},"outputId":"24892479-aad1-47f7-9aac-5645c1a0398c"},"source":["np.set_printoptions(precision=3)\n","torch.set_printoptions(precision=3)\n","x_numpy = np.random.rand(3,10,28) # [N x L x Q]\n","x_torch = torch.from_numpy(x_numpy).float().to(device)\n","y_torch = R.forward(x_torch) # [N x 1 x R] where R is the output dim.\n","y_numpy = y_torch.detach().cpu().numpy() # torch tensor to numpy array\n","# print (\"x_torch:\\n\",x_torch)\n","# print (\"y_torch:\\n\",y_torch)\n","print (\"x_numpy %s\"%(x_numpy.shape,))\n","print (\"y_numpy %s\"%(y_numpy.shape,))"],"execution_count":6,"outputs":[{"output_type":"stream","text":["x_numpy (3, 10, 28)\n","y_numpy (3, 10)\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"Zi5cIbKG6X3w"},"source":["### Evaluation Function"]},{"cell_type":"code","metadata":{"id":"-STglZMq5xKk","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1609935210009,"user_tz":-540,"elapsed":5256,"user":{"displayName":"Sungjoon Choi","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GiFkMNCaA4zshD2C87LC6X0Y7ohjLlu0sIiLepLnQ=s64","userId":"10728677910935649939"}},"outputId":"e6760929-3d37-4926-e91a-98fd98dd3b16"},"source":["def func_eval(model,data_iter,device):\n"," with torch.no_grad():\n"," n_total,n_correct = 0,0\n"," model.eval() # evaluate (affects DropOut and BN)\n"," for batch_in,batch_out in data_iter:\n"," y_trgt = batch_out.to(device)\n"," model_pred = model.forward(batch_in.view(-1,28,28).to(device))\n"," _,y_pred = torch.max(model_pred,1)\n"," n_correct += (y_pred==y_trgt).sum().item()\n"," n_total += batch_in.size(0)\n"," val_accr = 
(n_correct/n_total)\n"," model.train() # back to train mode \n"," return val_accr\n","print (\"Done\")"],"execution_count":7,"outputs":[{"output_type":"stream","text":["Done\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"pA-3-qPZ6h5u"},"source":["### Initial Evaluation"]},{"cell_type":"code","metadata":{"id":"qGbdjuhB6Z7U","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1609935216009,"user_tz":-540,"elapsed":11241,"user":{"displayName":"Sungjoon Choi","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GiFkMNCaA4zshD2C87LC6X0Y7ohjLlu0sIiLepLnQ=s64","userId":"10728677910935649939"}},"outputId":"c46c3921-755e-42eb-9f93-37585f0566a1"},"source":["train_accr = func_eval(R,train_iter,device)\n","test_accr = func_eval(R,test_iter,device)\n","print (\"train_accr:[%.3f] test_accr:[%.3f].\"%(train_accr,test_accr))"],"execution_count":8,"outputs":[{"output_type":"stream","text":["train_accr:[0.099] test_accr:[0.096].\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"PWywAU1-Lm0G"},"source":["### Train"]},{"cell_type":"code","metadata":{"id":"sp11_Glg6k7e","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1609935272491,"user_tz":-540,"elapsed":67708,"user":{"displayName":"Sungjoon Choi","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GiFkMNCaA4zshD2C87LC6X0Y7ohjLlu0sIiLepLnQ=s64","userId":"10728677910935649939"}},"outputId":"283902af-add5-47ef-abff-3fcf1cde4791"},"source":["print (\"Start training.\")\n","R.train() # to train mode \n","EPOCHS,print_every = 5,1\n","for epoch in range(EPOCHS):\n"," loss_val_sum = 0\n"," for batch_in,batch_out in train_iter:\n"," # Forward path\n"," y_pred = R.forward(batch_in.view(-1,28,28).to(device))\n"," loss_out = loss(y_pred,batch_out.to(device))\n"," # Update\n"," optm.zero_grad() # reset gradient \n"," loss_out.backward() # backpropagate\n"," optm.step() # optimizer update\n"," loss_val_sum += loss_out\n"," 
loss_val_avg = loss_val_sum/len(train_iter)\n"," # Print\n"," if ((epoch%print_every)==0) or (epoch==(EPOCHS-1)):\n"," train_accr = func_eval(R,train_iter,device)\n"," test_accr = func_eval(R,test_iter,device)\n"," print (\"epoch:[%d] loss:[%.3f] train_accr:[%.3f] test_accr:[%.3f].\"%\n"," (epoch,loss_val_avg,train_accr,test_accr))\n","print (\"Done\")"],"execution_count":9,"outputs":[{"output_type":"stream","text":["Start training.\n","epoch:[0] loss:[0.642] train_accr:[0.943] test_accr:[0.945].\n","epoch:[1] loss:[0.147] train_accr:[0.965] test_accr:[0.964].\n","epoch:[2] loss:[0.094] train_accr:[0.978] test_accr:[0.975].\n","epoch:[3] loss:[0.065] train_accr:[0.983] test_accr:[0.980].\n","epoch:[4] loss:[0.050] train_accr:[0.988] test_accr:[0.983].\n","Done\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"4JDDHhJtR1aR"},"source":["### Test"]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/","height":606},"id":"HrcOUIBrmNf-","executionInfo":{"status":"ok","timestamp":1609935273692,"user_tz":-540,"elapsed":68895,"user":{"displayName":"Sungjoon Choi","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GiFkMNCaA4zshD2C87LC6X0Y7ohjLlu0sIiLepLnQ=s64","userId":"10728677910935649939"}},"outputId":"e8c3af87-4660-46cb-cdff-51afc0d65522"},"source":["n_sample = 25\n","sample_indices = np.random.choice(len(mnist_test.targets),n_sample,replace=False)\n","test_x = mnist_test.data[sample_indices]\n","test_y = mnist_test.targets[sample_indices]\n","with torch.no_grad():\n"," R.eval() # to evaluation mode \n"," y_pred = R.forward(test_x.view(-1,28,28).type(torch.float).to(device)/255.)\n","y_pred = y_pred.argmax(axis=1)\n","plt.figure(figsize=(10,10))\n","for idx in range(n_sample):\n"," plt.subplot(5, 5, idx+1)\n"," plt.imshow(test_x[idx], cmap='gray')\n"," plt.axis('off')\n"," plt.title(\"Pred:%d, Label:%d\"%(y_pred[idx],test_y[idx]))\n","plt.show()\n","print 
(\"Done\")"],"execution_count":10,"outputs":[{"output_type":"display_data","data":{"image/png":"\n","text/plain":["
"]},"metadata":{"tags":[],"image/png":{"width":569,"height":573},"needs_background":"light"}},{"output_type":"stream","text":["Done\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"zkVkpsQymNuC","executionInfo":{"status":"ok","timestamp":1609935273694,"user_tz":-540,"elapsed":68894,"user":{"displayName":"Sungjoon Choi","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GiFkMNCaA4zshD2C87LC6X0Y7ohjLlu0sIiLepLnQ=s64","userId":"10728677910935649939"}}},"source":[""],"execution_count":10,"outputs":[]}]} -------------------------------------------------------------------------------- /notebook/mdn-cls.ipynb: -------------------------------------------------------------------------------- 1 | {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"mdn-cls.ipynb","provenance":[],"collapsed_sections":[],"authorship_tag":"ABX9TyNn21QiDcBe8s1K9JAV/khs"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"2rZ1UcDezAyc"},"source":["# MDN for Classification"]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"28QgYFqVzAP7","executionInfo":{"status":"ok","timestamp":1610108868970,"user_tz":-540,"elapsed":998,"user":{"displayName":"Sungjoon Choi","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GiFkMNCaA4zshD2C87LC6X0Y7ohjLlu0sIiLepLnQ=s64","userId":"10728677910935649939"}},"outputId":"83c444ad-f74b-4ec0-dd84-fb003061d063"},"source":["import math\n","import matplotlib.pyplot as plt\n","import numpy as np\n","import torch\n","import torch.nn as nn\n","import torch.optim as optim\n","import torch.nn.functional as F\n","import torch.distributions as TD\n","from torch.autograd import Variable\n","from collections import OrderedDict\n","%matplotlib inline\n","%config InlineBackend.figure_format='retina'\n","np.set_printoptions(precision=3)\n","torch.set_printoptions(precision=3)\n","print (\"PyTorch version:[%s].\"%(torch.__version__))\n","device = 
torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n","print (\"device:[%s].\"%(device))"],"execution_count":1,"outputs":[{"output_type":"stream","text":["PyTorch version:[1.7.0+cu101].\n","device:[cuda:0].\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"eQ5ZZvu4qoOv"},"source":["### Helper functions"]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"6o1j8QgKqqIG","executionInfo":{"status":"ok","timestamp":1610108869305,"user_tz":-540,"elapsed":1313,"user":{"displayName":"Sungjoon Choi","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GiFkMNCaA4zshD2C87LC6X0Y7ohjLlu0sIiLepLnQ=s64","userId":"10728677910935649939"}},"outputId":"7d7589c9-8f22-4b9f-bcfb-a908c0529e94"},"source":["# Codes copied from 'https://github.com/sksq96/pytorch-summary/tree/master/torchsummary' \n","def summary_string(model, input_size, batch_size=-1, device=torch.device('cuda:0'), dtypes=None):\n"," if dtypes == None:\n"," dtypes = [torch.FloatTensor]*len(input_size)\n"," summary_str = ''\n"," def register_hook(module):\n"," def hook(module, input, output):\n"," class_name = str(module.__class__).split(\".\")[-1].split(\"'\")[0]\n"," module_idx = len(summary)\n","\n"," m_key = \"%s-%i\" % (class_name, module_idx + 1)\n"," summary[m_key] = OrderedDict()\n"," summary[m_key][\"input_shape\"] = list(input[0].size())\n"," summary[m_key][\"input_shape\"][0] = batch_size\n"," if isinstance(output, (list, tuple)):\n"," summary[m_key][\"output_shape\"] = [\n"," [-1] + list(o.size())[1:] for o in output\n"," ]\n"," else:\n"," summary[m_key][\"output_shape\"] = list(output.size())\n"," summary[m_key][\"output_shape\"][0] = batch_size\n","\n"," params = 0\n"," if hasattr(module, \"weight\") and hasattr(module.weight, \"size\"):\n"," params += torch.prod(torch.LongTensor(list(module.weight.size())))\n"," summary[m_key][\"trainable\"] = module.weight.requires_grad\n"," if hasattr(module, \"bias\") and hasattr(module.bias, \"size\"):\n"," 
params += torch.prod(torch.LongTensor(list(module.bias.size())))\n"," summary[m_key][\"nb_params\"] = params\n","\n"," if (\n"," not isinstance(module, nn.Sequential)\n"," and not isinstance(module, nn.ModuleList)\n"," ):\n"," hooks.append(module.register_forward_hook(hook))\n","\n"," # multiple inputs to the network\n"," if isinstance(input_size, tuple):\n"," input_size = [input_size]\n","\n"," # batch_size of 2 for batchnorm\n"," x = [torch.rand(2, *in_size).type(dtype).to(device=device)\n"," for in_size, dtype in zip(input_size, dtypes)]\n","\n"," # create properties\n"," summary = OrderedDict()\n"," hooks = []\n","\n"," # register hook\n"," model.apply(register_hook)\n","\n"," # make a forward pass\n"," # print(x.shape)\n"," model(*x)\n","\n"," # remove these hooks\n"," for h in hooks:\n"," h.remove()\n","\n"," summary_str += \"----------------------------------------------------------------\" + \"\\n\"\n"," line_new = \"{:>20} {:>25} {:>15}\".format(\n"," \"Layer (type)\", \"Output Shape\", \"Param #\")\n"," summary_str += line_new + \"\\n\"\n"," summary_str += \"================================================================\" + \"\\n\"\n"," total_params = 0\n"," total_output = 0\n"," trainable_params = 0\n"," for layer in summary:\n"," # input_shape, output_shape, trainable, nb_params\n"," line_new = \"{:>20} {:>25} {:>15}\".format(\n"," layer,\n"," str(summary[layer][\"output_shape\"]),\n"," \"{0:,}\".format(summary[layer][\"nb_params\"]),\n"," )\n"," total_params += summary[layer][\"nb_params\"]\n"," summary_str += line_new + \"\\n\"\n"," # return summary\n"," return summary_str,summary\n","print (\"Done.\")"],"execution_count":2,"outputs":[{"output_type":"stream","text":["Done.\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"pm-jmaWi9Ovh"},"source":["## $\\color{yellow}{\\text{Mixture Logits Network (MLN) }}$ \n","- \n","`Cross Entropy Loss`\n","$ \\mathcal{L}_{\\text{CE}} = \n"," -\\sum_{d=1}^{D} y_d \\log(\\hat{\\mu}_d)\n","$\n","where 
$y \\in [0,1]^d$ is the target and $\\hat{\\mu} \\in \\mathbb{S}^d$ is the prediction result.\n","- `Weighted CE Loss`\n","$ \\mathcal{L}_{\\text{WCE}} = \n"," -\n"," \\sum_{k=1}^{K}\n"," \\hat{\\pi}_k\n"," \\sum_{d=1}^{D} y_d \\log(\\hat{\\mu}_{d,k})\n","$\n","where $\\hat{\\pi}$, $\\hat{\\mu}$, and $y$ are mixture weights,\n","output predictions, and labels, respectively. \n","- \n","`Gal Loss`\n","$\n"," \\mathcal{L}_{\\text{Gal}} \n"," = \\log \\frac{1}{T} \\sum_{t}\n"," \\exp \\left(\n"," \\hat{x}_{t,c} - \\log \\sum_{c'} \\exp \\hat{x}_{t,c'}\n"," \\right)\n","$\n","where $\\hat{x}_t = f^{W} + \\sigma^{W}\\epsilon_t, ~ \\epsilon_t \\sim \\mathcal{N}(0,I)$.\n","- \n","`Mixture of Attenuated CE Loss`\n","$ \\mathcal{L}_{\\text{MACE}} \n"," =\n"," -\n"," \\sum_{k=1}^{K}\n"," \\hat{\\pi}_k\n"," \\sum_{d=1}^{D}\n"," \\frac\n"," {y_d \\log(\\hat{\\mu}_{d,k})}\n"," {\\hat{\\sigma}_{d,k} + \\sigma_{\\text{min}}}\n","$\n","where $\\sigma_{\\text{min}}=1.0$ is the minimum standard deviation.\n"]},{"cell_type":"code","metadata":{"id":"1dhslad3y48u","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1610108873362,"user_tz":-540,"elapsed":5352,"user":{"displayName":"Sungjoon Choi","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GiFkMNCaA4zshD2C87LC6X0Y7ohjLlu0sIiLepLnQ=s64","userId":"10728677910935649939"}},"outputId":"6423fb0c-7d51-47a0-af99-d1523f3baa3b"},"source":["class MixturesOfLogits(nn.Module):\n"," \"\"\"\n"," Mixture of Logits \n"," \"\"\"\n"," def __init__(self,\n"," in_dim = 64, # input feature dimension\n"," y_dim = 10, # output dimension\n"," k = 5, # number of mixtures\n"," sig_min = 1, # minimum sigma\n"," sig_max = None # maximum signa \n"," ):\n"," super(MixturesOfLogits,self).__init__()\n"," self.in_dim = in_dim\n"," self.y_dim = y_dim\n"," self.k = k\n"," self.sig_min = sig_min\n"," self.sig_max = sig_max\n"," self.fc_pi = nn.Linear(self.in_dim,self.k)\n"," self.fc_mu = 
nn.Linear(self.in_dim,self.k*self.y_dim)\n"," self.fc_sigma = nn.Linear(self.in_dim,self.k*self.y_dim)\n","\n"," def forward(self,x):\n"," pi_logit = self.fc_pi(x) # [N x K]\n"," pi = torch.softmax(pi_logit,dim=1) # [N x K]\n"," mu = self.fc_mu(x) # [N x KD]\n"," mu = torch.reshape(mu,(-1,self.k,self.y_dim)) # [N x K x D]\n"," sigma = self.fc_sigma(x) # [N x KD]\n"," sigma = torch.reshape(sigma,(-1,self.k,self.y_dim)) # [N x K x D]\n"," if self.sig_max is None:\n"," sigma = self.sig_min + torch.exp(sigma) # [N x K x D]\n"," else:\n"," sigma = self.sig_min + (self.sig_max-self.sig_min)*torch.sigmoid(sigma) # [N x K x D]\n"," return pi,mu,sigma\n","\n","class MixtureLogitNetwork(nn.Module):\n"," def __init__(self,\n"," name='mln',\n"," x_dim = [1,28,28], # iput dimension \n"," k_size = 3, # kernel size\n"," c_dims = [32,64], # channel dimensions for conv layer(s)\n"," p_sizes = [2,2], # pooling sizes\n"," h_dims = [128], # hidden dimensions for dense layer(s)\n"," y_dim = 10, # output dimension\n"," USE_BN = True, # whether to use batch norm \n"," k = 5, # number of mixtures\n"," sig_min = 1, # $\\sigma_{min}$\n"," sig_max = None, # $\\sigma_{max}$\n"," mu_min = -3, # minimum $\\mu$ while initializing bias \n"," mu_max = +3, # maximum $\\mu$ while initializing bias \n"," ):\n"," super(MixtureLogitNetwork,self).__init__()\n"," self.name = name\n"," self.x_dim = x_dim\n"," self.k_size = k_size\n"," self.c_dims = c_dims\n"," self.p_sizes = p_sizes\n"," self.h_dims = h_dims\n"," self.y_dim = y_dim\n"," self.USE_BN = USE_BN\n"," self.k = k\n"," self.sig_min = sig_min\n"," self.sig_max = sig_max\n"," self.mu_min = mu_min\n"," self.mu_max = mu_max\n","\n"," # Build graph\n"," self.build_graph()\n","\n"," # Initialize parameters \n"," self.init_param() \n","\n"," def build_graph(self):\n"," self.layers = []\n"," # Conv layers\n"," prev_c_dim = self.x_dim[0] # input channel \n"," for (c_dim,p_size) in zip(self.c_dims,self.p_sizes):\n"," self.layers.append(\n"," 
nn.Conv2d(\n"," in_channels = prev_c_dim,\n"," out_channels = c_dim,\n"," kernel_size = self.k_size,\n"," stride = (1,1),\n"," padding = self.k_size//2\n"," ) # conv\n"," )\n"," if self.USE_BN:\n"," self.layers.append(\n"," nn.BatchNorm2d(num_features=c_dim)\n"," )\n"," self.layers.append(nn.ReLU())\n"," self.layers.append(\n"," nn.MaxPool2d(kernel_size=(p_size,p_size),stride=(p_size,p_size))\n"," )\n"," # self.layers.append(nn.Dropout2d(p=0.1)) # p: to be zero-ed\n"," prev_c_dim = c_dim \n"," # Dense layers\n"," self.layers.append(nn.Flatten())\n"," p_prod = np.prod(self.p_sizes)\n"," prev_h_dim = prev_c_dim*(self.x_dim[1]//p_prod)*(self.x_dim[2]//p_prod)\n"," for h_dim in self.h_dims:\n"," self.layers.append(\n"," nn.Linear(\n"," in_features = prev_h_dim,\n"," out_features = h_dim,\n"," bias = True\n"," )\n"," )\n"," self.layers.append(nn.ReLU(True)) # activation\n"," self.layers.append(nn.Dropout2d(p=0.1)) # p: to be zero-ed\n"," prev_h_dim = h_dim\n"," # Final mixture of logits layer\n"," mol = MixturesOfLogits(\n"," in_dim = prev_h_dim, \n"," y_dim = self.y_dim, \n"," k = self.k,\n"," sig_min = self.sig_min,\n"," sig_max = self.sig_max\n"," )\n"," self.layers.append(mol)\n","\n"," # Concatanate all layers\n"," self.net = nn.Sequential()\n"," for l_idx,layer in enumerate(self.layers):\n"," layer_name = \"%s_%02d\"%(type(layer).__name__.lower(),l_idx)\n"," self.net.add_module(layer_name,layer)\n","\n"," def init_param(self): \n"," for m in self.modules():\n"," if isinstance(m,nn.Conv2d): # init conv\n"," nn.init.kaiming_normal_(m.weight)\n"," nn.init.zeros_(m.bias)\n"," if isinstance(m,nn.Linear): # lnit dense\n"," nn.init.kaiming_normal_(m.weight)\n"," nn.init.zeros_(m.bias)\n"," \"\"\"\n"," Heuristic: fc_mu.bias ~ Uniform(mu_min,mu_max)\n"," \"\"\"\n"," self.layers[-1].fc_mu.bias.data.uniform_(self.mu_min,self.mu_max)\n","\n"," def forward(self,x):\n"," return self.net(x)\n","\n","# Instantiate mixture of logits layer \n","M = MixtureLogitNetwork(\n"," 
name='mln',x_dim=[1,28,28],k_size=3,c_dims=[32,64],p_sizes=[2,2],\n"," h_dims=[128],y_dim=10,USE_BN=True,\n"," k=3,sig_min=1,sig_max=None,\n"," mu_min=-3,mu_max =+3).to(device)\n","print (\"Done.\")"],"execution_count":3,"outputs":[{"output_type":"stream","text":["Done.\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"6GE5bkPJjNG3"},"source":["## $\\color{yellow}{\\text{Loss function}}$ \n","`Mixture of Attenuated CE Loss`\n","$ \\mathcal{L}_{\\text{MACE}} \n"," =\n"," \\sum_{k=1}^{K}\n"," \\hat{\\pi}_k\n"," \\sum_{d=1}^{D}\n"," \\frac\n"," {-y_d \\log(\\hat{\\mu}_{d,k})}\n"," {\\hat{\\sigma}_{d,k} \n"," }\n"," + \n"," \\frac{1}{D}\n"," \\sum_{d=1}^{D}\n"," \\sum_{k=1}^{K}\n"," \\hat{\\pi}_k \\hat{\\sigma}_{d,k}\n","$"]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"GZwN1fE3jMuk","executionInfo":{"status":"ok","timestamp":1610108873362,"user_tz":-540,"elapsed":5334,"user":{"displayName":"Sungjoon Choi","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GiFkMNCaA4zshD2C87LC6X0Y7ohjLlu0sIiLepLnQ=s64","userId":"10728677910935649939"}},"outputId":"3c6d8da0-fa72-499c-8876-9ea96fab6820"},"source":["def np2tc(x_np): return torch.from_numpy(x_np).float().to(device)\n","def tc2np(x_tc): return x_tc.detach().cpu().numpy()\n","\n","def mdn_gather(pi,mu,sigma):\n"," \"\"\"\n"," pi: [N x K]\n"," mu: [N x K x D]\n"," sigma: [N x K x D]\n"," \"\"\"\n"," max_idx = torch.argmax(pi,dim=1) # [N]\n"," idx_gather = max_idx.unsqueeze(dim=-1).repeat(1,mu.shape[2]).unsqueeze(1) # [N x 1 x D]\n"," mu_sel = torch.gather(mu,dim=1,index=idx_gather).squeeze(dim=1) # [N x D]\n"," sigma_sel = torch.gather(sigma,dim=1,index=idx_gather).squeeze(dim=1) # [N x D]\n"," out = {'max_idx':max_idx,'idx_gather':idx_gather,\n"," 'mu_sel':mu_sel,'sigma_sel':sigma_sel}\n"," return out\n","\n","def mace_loss(pi,mu,sigma,target,alea_weight=1.0):\n"," \"\"\"\n"," Mixture of attenuated CE loss\n"," pi: [N x K]\n"," mu: [N x K x D]\n"," sigma: [N x K x 
D]\n"," target: [N x D]\n"," \"\"\"\n"," # softmax \\mu\n"," mu_hat = torch.softmax(mu,dim=2) # logit to prob [N x K x D]\n"," log_mu_hat = torch.log(mu_hat+1e-5) # [N x K x D]\n"," \n"," # Expanded \\pi \n"," pi_usq = torch.unsqueeze(pi,2) # [N x K x 1]\n"," pi_exp = pi_usq.expand_as(sigma) # [N x K x D]\n","\n"," # Expanded target\n"," target_usq = torch.unsqueeze(target,1) # [N x 1 x D]\n"," target_exp = target_usq.expand_as(sigma) # [N x K x D]\n","\n"," # Loss\n"," # ce_loss_exp = -target_exp*log_mu_hat # [N x K x D]\n"," ce_loss_exp = -target_exp*log_mu_hat # [N x K x D]\n"," atte_ce = ce_loss_exp / sigma # attenuated CE loss [N x K x D]\n"," waces = torch.sum(torch.mul(pi_exp,atte_ce),dim=1) # weighted attenuated CE loss [N x D]\n"," wace = torch.mean(waces,dim=1) # N\n"," aleas = alea_weight*torch.sum(pi_exp*sigma,dim=1)# aleatoric uncertainty [N x D]\n"," alea = torch.mean(aleas,dim=1) # [N]\n","\n"," # Accumulate loss \n"," loss = wace + alea # [N]\n","\n"," # Average loss\n"," wace_avg = torch.mean(wace) # [1]\n"," alea_avg = torch.mean(alea) # [1]\n"," loss_avg = torch.mean(loss) # [1]\n","\n","\n"," out = {'mu_hat':mu_hat,'log_mu_hat':log_mu_hat,\n"," 'pi_usq':pi_usq,'pi_exp':pi_exp,\n"," 'sigma':sigma,\n"," 'target_usq':target_usq,'target_exp':target_exp,\n"," 'ce_loss_exp':ce_loss_exp,'atte_ce':atte_ce,\n"," 'waces':waces,'wace':wace,\n"," 'aleas':aleas,'alea':alea,\n"," 'loss':loss,\n"," 'wace_avg':wace_avg,'alea_avg':alea_avg,'loss_avg':loss_avg}\n"," return out\n","\n","def gmm_uncertainties(pi, mu, sigma):\n"," # Compute Epistemic Uncertainty\n"," M = 0.1\n"," # pi = torch.softmax(M*pi,1) # (optional) heuristics \n"," pi_usq = torch.unsqueeze(pi,2) # [N x K x 1]\n"," pi_exp = pi_usq.expand_as(sigma) # [N x K x D]\n","\n"," # For classification problems, we use softmax(mu) instead of me\n"," mu = torch.softmax(mu,dim=2) # logit to prob [N x K x D]\n","\n"," mu_avg = torch.sum(torch.mul(pi_exp,mu),dim=1).unsqueeze(1) # [N x 1 x D]\n"," mu_exp = 
mu_avg.expand_as(mu) # [N x K x D]\n"," mu_diff_sq = torch.square(mu-mu_exp) # [N x K x D]\n"," epis_unct = torch.sum(torch.mul(pi_exp,mu_diff_sq), dim=1) # [N x D]\n","\n"," # Compute Aleatoric Uncertainty\n"," alea_unct = torch.sum(torch.mul(pi_exp,sigma), dim=1) # [N x D]\n","\n"," # Square root \n"," epis_unct = torch.sqrt(epis_unct) # [N x D]\n"," alea_unct = torch.sqrt(alea_unct) # [N x D]\n"," epis_unct_avg = torch.mean(epis_unct,dim=1) # [N]\n"," alea_unct_avg = torch.mean(alea_unct,dim=1) # [N]\n","\n"," # Out\n"," unct_out = {'epis_unct':epis_unct,'alea_unct':alea_unct,\n"," 'epis_unct_avg':epis_unct_avg,'alea_unct_avg':alea_unct_avg}\n","\n"," return unct_out\n"," \n","# Demo run to check the loss \n","M = MixtureLogitNetwork(\n"," name='mln',x_dim=[1,28,28],k_size=3,c_dims=[32,64],p_sizes=[2,2],\n"," h_dims=[128],y_dim=10,USE_BN=True).to(device)\n","\n","x_np = np.random.rand(2,1,28,28)\n","x_tc = np2tc(x_np)\n","pi_tc,mu_tc,sigma_tc = M.forward(x_tc) # forward path of MLN\n","target_tc = F.one_hot(torch.randint(low=0,high=10,size=(2,)),num_classes=10).to(device) # random one-hot\n","out = mace_loss(pi_tc,mu_tc,sigma_tc,target_tc) # mixture of CE \n","unct_out = gmm_uncertainties(pi_tc,mu_tc,sigma_tc)\n","\n","print ('pi_tc: %s'%(tc2np(target_tc).shape,))\n","print ('mu_tc: %s'%(tc2np(mu_tc).shape,))\n","print ('sigma_tc: %s'%(tc2np(sigma_tc).shape,))\n","print ('target_tc: %s'%(tc2np(target_tc).shape,))\n","print ('=>')\n","print ('mu_hat: %s'%(tc2np(out['mu_hat']).shape,))\n","print ('log_mu_hat: %s'%(tc2np(out['log_mu_hat']).shape,))\n","print ('pi_usq: %s'%(tc2np(out['pi_usq']).shape,))\n","print ('pi_exp: %s'%(tc2np(out['pi_exp']).shape,))\n","print ('target_usq: %s'%(tc2np(out['target_usq']).shape,))\n","print ('target_exp: %s'%(tc2np(out['target_exp']).shape,))\n","print ('ce_loss_exp: %s'%(tc2np(out['ce_loss_exp']).shape,))\n","print ('atte_ce: %s'%(tc2np(out['atte_ce']).shape,))\n","print ('waces: %s'%(tc2np(out['waces']).shape,))\n","print 
('wace: %s'%(tc2np(out['wace']).shape,))\n","print ('aleas: %s'%(tc2np(out['aleas']).shape,))\n","print ('alea: %s'%(tc2np(out['alea']).shape,))\n","print ('loss: %s'%(tc2np(out['loss']).shape,))"],"execution_count":4,"outputs":[{"output_type":"stream","text":["pi_tc: (2, 10)\n","mu_tc: (2, 5, 10)\n","sigma_tc: (2, 5, 10)\n","target_tc: (2, 10)\n","=>\n","mu_hat: (2, 5, 10)\n","log_mu_hat: (2, 5, 10)\n","pi_usq: (2, 5, 1)\n","pi_exp: (2, 5, 10)\n","target_usq: (2, 1, 10)\n","target_exp: (2, 5, 10)\n","ce_loss_exp: (2, 5, 10)\n","atte_ce: (2, 5, 10)\n","waces: (2, 10)\n","wace: (2,)\n","aleas: (2, 10)\n","alea: (2,)\n","loss: (2,)\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"3xTbbFSlxwI0","executionInfo":{"status":"ok","timestamp":1610108873363,"user_tz":-540,"elapsed":5316,"user":{"displayName":"Sungjoon Choi","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GiFkMNCaA4zshD2C87LC6X0Y7ohjLlu0sIiLepLnQ=s64","userId":"10728677910935649939"}}},"source":[""],"execution_count":4,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"JnycQ-fdqyhR"},"source":["### Summarize the model"]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"_auYKcQ6R-ui","executionInfo":{"status":"ok","timestamp":1610108873364,"user_tz":-540,"elapsed":5302,"user":{"displayName":"Sungjoon Choi","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GiFkMNCaA4zshD2C87LC6X0Y7ohjLlu0sIiLepLnQ=s64","userId":"10728677910935649939"}},"outputId":"1620a69d-89cd-4e09-a450-9af5e18e0a36"},"source":["summary_str,summary = summary_string(M,input_size=(1,28,28),device=device)\n","print (summary_str)"],"execution_count":5,"outputs":[{"output_type":"stream","text":["----------------------------------------------------------------\n"," Layer (type) Output Shape Param #\n","================================================================\n"," Conv2d-1 [-1, 32, 28, 28] 320\n"," BatchNorm2d-2 [-1, 32, 28, 28] 64\n"," ReLU-3 [-1, 32, 28, 28] 0\n"," MaxPool2d-4 
[-1, 32, 14, 14] 0\n"," Conv2d-5 [-1, 64, 14, 14] 18,496\n"," BatchNorm2d-6 [-1, 64, 14, 14] 128\n"," ReLU-7 [-1, 64, 14, 14] 0\n"," MaxPool2d-8 [-1, 64, 7, 7] 0\n"," Flatten-9 [-1, 3136] 0\n"," Linear-10 [-1, 128] 401,536\n"," ReLU-11 [-1, 128] 0\n"," Dropout2d-12 [-1, 128] 0\n"," Linear-13 [-1, 5] 645\n"," Linear-14 [-1, 50] 6,450\n"," Linear-15 [-1, 50] 6,450\n"," MixturesOfLogits-16 [[-1, 5], [-1, 5, 10], [-1, 5, 10]] 0\n","MixtureLogitNetwork-17 [[-1, 5], [-1, 5, 10], [-1, 5, 10]] 0\n","\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"R5XcNuKwO6G7"},"source":["### Check parameters"]},{"cell_type":"code","metadata":{"id":"kyDpsBwk8qr-","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1610108873365,"user_tz":-540,"elapsed":5285,"user":{"displayName":"Sungjoon Choi","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GiFkMNCaA4zshD2C87LC6X0Y7ohjLlu0sIiLepLnQ=s64","userId":"10728677910935649939"}},"outputId":"2b0a15a3-1829-4367-b8f1-66831a91a9d5"},"source":["n_param = 0\n","for p_idx,(param_name,param) in enumerate(M.named_parameters()):\n"," if param.requires_grad:\n"," param_numpy = param.detach().cpu().numpy() # to numpy array \n"," n_param += len(param_numpy.reshape(-1))\n"," print (\"[%02d] name:[%s] shape:[%s].\"%(p_idx,param_name,param_numpy.shape))\n"," print (\" first 3 values:%s\"%(param_numpy.reshape(-1)[:3]))\n","print (\"Total number of parameters:[%s].\"%(format(n_param,',d')))"],"execution_count":6,"outputs":[{"output_type":"stream","text":["[00] name:[net.conv2d_00.weight] shape:[(32, 1, 3, 3)].\n"," first 3 values:[-0.481 -0.531 -1.441]\n","[01] name:[net.conv2d_00.bias] shape:[(32,)].\n"," first 3 values:[0. 0. 0.]\n","[02] name:[net.batchnorm2d_01.weight] shape:[(32,)].\n"," first 3 values:[1. 1. 1.]\n","[03] name:[net.batchnorm2d_01.bias] shape:[(32,)].\n"," first 3 values:[0. 0. 
0.]\n","[04] name:[net.conv2d_04.weight] shape:[(64, 32, 3, 3)].\n"," first 3 values:[-0.028 0.012 -0.068]\n","[05] name:[net.conv2d_04.bias] shape:[(64,)].\n"," first 3 values:[0. 0. 0.]\n","[06] name:[net.batchnorm2d_05.weight] shape:[(64,)].\n"," first 3 values:[1. 1. 1.]\n","[07] name:[net.batchnorm2d_05.bias] shape:[(64,)].\n"," first 3 values:[0. 0. 0.]\n","[08] name:[net.linear_09.weight] shape:[(128, 3136)].\n"," first 3 values:[ 0.023 0.041 -0.054]\n","[09] name:[net.linear_09.bias] shape:[(128,)].\n"," first 3 values:[0. 0. 0.]\n","[10] name:[net.mixturesoflogits_12.fc_pi.weight] shape:[(5, 128)].\n"," first 3 values:[-0.03 0.151 0.103]\n","[11] name:[net.mixturesoflogits_12.fc_pi.bias] shape:[(5,)].\n"," first 3 values:[0. 0. 0.]\n","[12] name:[net.mixturesoflogits_12.fc_mu.weight] shape:[(50, 128)].\n"," first 3 values:[ 0.075 0.026 -0.165]\n","[13] name:[net.mixturesoflogits_12.fc_mu.bias] shape:[(50,)].\n"," first 3 values:[-2.104 1.716 -2.968]\n","[14] name:[net.mixturesoflogits_12.fc_sigma.weight] shape:[(50, 128)].\n"," first 3 values:[-0.089 0.029 -0.171]\n","[15] name:[net.mixturesoflogits_12.fc_sigma.bias] shape:[(50,)].\n"," first 3 values:[0. 0. 
0.]\n","Total number of parameters:[434,089].\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"toW9smKjU0tH"},"source":["### Demo forward path"]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"1fh0_I-fPpid","executionInfo":{"status":"ok","timestamp":1610108873366,"user_tz":-540,"elapsed":5267,"user":{"displayName":"Sungjoon Choi","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GiFkMNCaA4zshD2C87LC6X0Y7ohjLlu0sIiLepLnQ=s64","userId":"10728677910935649939"}},"outputId":"dd285168-1b4e-43f4-9a16-8101bfccb895"},"source":["# Demo instantiate\n","M = MixtureLogitNetwork(\n"," name='mln',x_dim=[1,28,28],k_size=3,c_dims=[32,64],p_sizes=[2,2],\n"," h_dims=[128],y_dim=10,USE_BN=True,\n"," k=3,sig_min=1,sig_max=None,\n"," mu_min=-3,mu_max =+3).to(device)\n","# Demo forward path \n","x_np = np.random.rand(2,1,28,28)\n","x_tc = np2tc(x_np)\n","pi_tc,mu_tc,sigma_tc = M.forward(x_tc) # forward path of MLN\n","pi_np,mu_np,sigma_np = tc2np(pi_tc),tc2np(mu_tc),tc2np(sigma_tc)\n","out = mdn_gather(pi_tc,mu_tc,sigma_tc)\n","mu_sel_np = tc2np(out['mu_sel'])\n","print ('x_np: %s'%(x_np.shape,))\n","print ('=>')\n","print ('pi_np: %s'%(pi_np.shape,)) # [N x K]\n","print ('mu_np: %s'%(mu_np.shape,)) # [N x K x D]\n","print ('sigma_np: %s'%(sigma_np.shape,)) # [N x K x D]\n","print ('=>')\n","print ('mu_sel_np: %s'%(mu_sel_np.shape,)) # [N x D]"],"execution_count":7,"outputs":[{"output_type":"stream","text":["x_np: (2, 1, 28, 28)\n","=>\n","pi_np: (2, 3)\n","mu_np: (2, 3, 10)\n","sigma_np: (2, 3, 10)\n","=>\n","mu_sel_np: (2, 10)\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"89CAV6pAIghW"},"source":["### Dataset"]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"j1zE9HLvmi2U","executionInfo":{"status":"ok","timestamp":1610108873368,"user_tz":-540,"elapsed":5251,"user":{"displayName":"Sungjoon 
Choi","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GiFkMNCaA4zshD2C87LC6X0Y7ohjLlu0sIiLepLnQ=s64","userId":"10728677910935649939"}},"outputId":"3899b5de-4a7e-438e-cf30-684938ccd489"},"source":["from torchvision import datasets,transforms\n","mnist_train = datasets.MNIST(root='./data/',train=True,transform=transforms.ToTensor(),download=True)\n","mnist_test = datasets.MNIST(root='./data/',train=False,transform=transforms.ToTensor(),download=True)\n","mnist_train.targets = mnist_train.targets # manipulate train labels\n","BATCH_SIZE = 64\n","train_iter = torch.utils.data.DataLoader(mnist_train,batch_size=BATCH_SIZE,shuffle=True,num_workers=1)\n","test_iter = torch.utils.data.DataLoader(mnist_test,batch_size=BATCH_SIZE,shuffle=True,num_workers=1)\n","print (\"Done.\")"],"execution_count":8,"outputs":[{"output_type":"stream","text":["Done.\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"wYyrGiMlKIUp"},"source":["### Evaluation function"]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"04YxC4wEKIEy","executionInfo":{"status":"ok","timestamp":1610108873368,"user_tz":-540,"elapsed":5233,"user":{"displayName":"Sungjoon Choi","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GiFkMNCaA4zshD2C87LC6X0Y7ohjLlu0sIiLepLnQ=s64","userId":"10728677910935649939"}},"outputId":"23bf1468-324f-4fc4-8533-51809fe87b0b"},"source":["def func_eval(model,data_iter,device):\n"," with torch.no_grad():\n"," n_total,n_correct,epis_unct_sum,alea_unct_sum = 0,0,0,0\n"," model.eval() # evaluate (affects DropOut and BN)\n"," for batch_in,batch_out in data_iter:\n"," # Foraward path\n"," y_trgt = batch_out.to(device)\n"," pi,mu,sigma = model.forward(batch_in.view(-1,1,28,28).to(device))\n"," out = mdn_gather(pi,mu,sigma)\n"," model_pred = out['mu_sel']\n","\n"," # Uncertainty \n"," unct_out = gmm_uncertainties(pi,mu,sigma)\n"," epis_unct = unct_out['epis_unct'] # [N]\n"," alea_unct = unct_out['alea_unct'] # [N]\n"," epis_unct_sum += 
torch.sum(epis_unct)\n"," alea_unct_sum += torch.sum(alea_unct)\n","\n"," # Check\n"," _,y_pred = torch.max(model_pred,1)\n"," n_correct += (y_pred==y_trgt).sum().item()\n"," n_total += batch_in.size(0)\n"," val_accr = (n_correct/n_total)\n"," epis_unct_avg = (epis_unct_sum/n_total).detach().cpu().item()\n"," alea_unct_avg = (alea_unct_sum/n_total).detach().cpu().item()\n"," model.train() # back to train mode \n"," out_eval = {'val_accr':val_accr,\n"," 'epis_unct_avg':epis_unct_avg,'alea_unct_avg':alea_unct_avg}\n"," return out_eval\n","print (\"Done\")"],"execution_count":9,"outputs":[{"output_type":"stream","text":["Done\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"K5luyCj0anq3","executionInfo":{"status":"ok","timestamp":1610108881374,"user_tz":-540,"elapsed":13221,"user":{"displayName":"Sungjoon Choi","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GiFkMNCaA4zshD2C87LC6X0Y7ohjLlu0sIiLepLnQ=s64","userId":"10728677910935649939"}},"outputId":"369c0f44-ae64-4c32-93d8-cc2d73821f46"},"source":["M.init_param()\n","train_accr = func_eval(M,train_iter,device)['val_accr']\n","test_accr = func_eval(M,test_iter,device)['val_accr']\n","print (\"train_accr:[%.3f] test_accr:[%.3f].\"%(train_accr,test_accr))"],"execution_count":10,"outputs":[{"output_type":"stream","text":["train_accr:[0.169] test_accr:[0.170].\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"vJzFw_SSJWvg"},"source":["### Train with clean data"]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"ZS77J0Hfo8vx","executionInfo":{"status":"ok","timestamp":1610109052179,"user_tz":-540,"elapsed":184010,"user":{"displayName":"Sungjoon Choi","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GiFkMNCaA4zshD2C87LC6X0Y7ohjLlu0sIiLepLnQ=s64","userId":"10728677910935649939"}},"outputId":"80b6a1d5-c6b6-4042-9021-dc095c861171"},"source":["np.random.seed(seed=0)\n","torch.manual_seed(seed=0)\n","M = 
MixtureLogitNetwork(\n"," name='mln',x_dim=[1,28,28],k_size=3,c_dims=[32,64],p_sizes=[2,2],\n"," h_dims=[128],y_dim=10,USE_BN=False,k=5,\n"," sig_min=0.01,sig_max=3,\n"," mu_min=-3,mu_max=+3).to(device)\n","M.init_param()\n","optm = optim.Adam(M.parameters(),lr=1e-3,weight_decay=1e-6)\n","M.train() # train mode\n","\n","# Re-define the train iterator\n","mnist_train.targets = mnist_train.targets # manipulate train labels\n","BATCH_SIZE = 64\n","train_iter = torch.utils.data.DataLoader(mnist_train,batch_size=BATCH_SIZE,shuffle=True,num_workers=1)\n","\n","EPOCHS,print_every = 10,1\n","for epoch in range(EPOCHS):\n"," loss_sum,wace_sum,alea_sum = 0,0,0\n"," for batch_in,batch_out in train_iter:\n"," # Forward path\n"," pi,mu,sigma = M.forward(batch_in.view(-1,1,28,28).to(device)) \n"," target = torch.eye(M.y_dim)[batch_out].to(device)\n"," mace_loss_out = mace_loss(pi,mu,sigma,target,alea_weight=0.5) # mixture of CE \n"," loss_out = mace_loss_out['loss_avg']\n"," wace_out = mace_loss_out['wace_avg']\n"," alea_out = mace_loss_out['alea_avg']\n"," # Update \n"," optm.zero_grad() # reset gradient \n"," loss_out.backward() # backpropagate\n"," optm.step() # optimizer update\n"," # Track losses \n"," loss_sum += loss_out\n"," wace_sum += wace_out\n"," alea_sum += alea_out\n"," loss_avg = loss_sum/len(train_iter)\n"," wace_avg = wace_sum/len(train_iter)\n"," alea_avg = alea_sum/len(train_iter)\n"," # Print\n"," if ((epoch%print_every)==0) or (epoch==(EPOCHS-1)):\n"," train_res = func_eval(M,train_iter,device)\n"," test_res = func_eval(M,test_iter,device)\n"," print (\"epoch:[%d] loss:[%.3f]=(wace:%.3f+alea:%.3f) train_accr:[%.3f] test_accr:[%.3f].\"%\n"," (epoch,loss_avg,wace_avg,alea_avg,\n"," train_res['val_accr'],test_res['val_accr'])) \n"," print (\" [Train] alea:[%.3f] epis:[%.3f]\"%\n"," (train_res['alea_unct_avg'],train_res['epis_unct_avg']))\n"," print (\" [Test] alea:[%.3f] epis:[%.3f]\"%\n"," (test_res['alea_unct_avg'],test_res['epis_unct_avg']))\n","\n","print 
(\"Done\")"],"execution_count":11,"outputs":[{"output_type":"stream","text":["epoch:[0] loss:[0.152]=(wace:0.069+alea:0.084) train_accr:[0.974] test_accr:[0.974].\n"," [Train] alea:[2.257] epis:[0.004]\n"," [Test] alea:[2.244] epis:[0.004]\n","epoch:[1] loss:[0.068]=(wace:0.033+alea:0.035) train_accr:[0.982] test_accr:[0.981].\n"," [Train] alea:[2.061] epis:[0.002]\n"," [Test] alea:[2.014] epis:[0.002]\n","epoch:[2] loss:[0.048]=(wace:0.022+alea:0.026) train_accr:[0.984] test_accr:[0.983].\n"," [Train] alea:[1.724] epis:[0.001]\n"," [Test] alea:[1.700] epis:[0.001]\n","epoch:[3] loss:[0.040]=(wace:0.018+alea:0.021) train_accr:[0.989] test_accr:[0.987].\n"," [Train] alea:[1.572] epis:[0.000]\n"," [Test] alea:[1.555] epis:[0.000]\n","epoch:[4] loss:[0.033]=(wace:0.015+alea:0.019) train_accr:[0.992] test_accr:[0.989].\n"," [Train] alea:[1.429] epis:[0.000]\n"," [Test] alea:[1.404] epis:[0.000]\n","epoch:[5] loss:[0.028]=(wace:0.013+alea:0.016) train_accr:[0.993] test_accr:[0.989].\n"," [Train] alea:[1.341] epis:[0.000]\n"," [Test] alea:[1.319] epis:[0.000]\n","epoch:[6] loss:[0.030]=(wace:0.014+alea:0.017) train_accr:[0.993] test_accr:[0.988].\n"," [Train] alea:[1.296] epis:[0.000]\n"," [Test] alea:[1.278] epis:[0.000]\n","epoch:[7] loss:[0.025]=(wace:0.011+alea:0.014) train_accr:[0.992] test_accr:[0.987].\n"," [Train] alea:[1.354] epis:[0.000]\n"," [Test] alea:[1.333] epis:[0.000]\n","epoch:[8] loss:[0.022]=(wace:0.009+alea:0.013) train_accr:[0.995] test_accr:[0.989].\n"," [Train] alea:[1.200] epis:[0.000]\n"," [Test] alea:[1.185] epis:[0.000]\n","epoch:[9] loss:[0.023]=(wace:0.010+alea:0.013) train_accr:[0.994] test_accr:[0.989].\n"," [Train] alea:[1.280] epis:[0.000]\n"," [Test] alea:[1.269] epis:[0.000]\n","Done\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"HxFB796B3DAn"},"source":["### Train with random shuffle 
noise"]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"-nAEUlM0Toe2","executionInfo":{"status":"ok","timestamp":1610109224031,"user_tz":-540,"elapsed":355846,"user":{"displayName":"Sungjoon Choi","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GiFkMNCaA4zshD2C87LC6X0Y7ohjLlu0sIiLepLnQ=s64","userId":"10728677910935649939"}},"outputId":"457591d2-a6a4-4b7f-96bd-e1b3a6c81857"},"source":["np.random.seed(seed=0)\n","torch.manual_seed(seed=0)\n","M = MixtureLogitNetwork(\n"," name='mln',x_dim=[1,28,28],k_size=3,c_dims=[32,64],p_sizes=[2,2],\n"," h_dims=[128],y_dim=10,USE_BN=False,k=5,\n"," sig_min=0.01,sig_max=3,\n"," mu_min=-3,mu_max=+3).to(device)\n","M.init_param()\n","optm = optim.Adam(M.parameters(),lr=1e-3,weight_decay=1e-6)\n","M.train() # train mode\n","\n","# Re-define the train iterator\n","mnist_train = datasets.MNIST(root='./data/',train=True,transform=transforms.ToTensor(),download=True)\n","n_train = len(mnist_train)\n","corrupt_rate = 0.5 # random shuffle rate \n","n_corrupt = int(n_train*corrupt_rate)\n","r_idx = np.random.permutation(n_train)[:n_corrupt]\n","mnist_train.targets[r_idx] = torch.randint(low=0,high=10,size=(n_corrupt,)) # random label \n","BATCH_SIZE = 64\n","train_iter = torch.utils.data.DataLoader(mnist_train,batch_size=BATCH_SIZE,shuffle=True,num_workers=1)\n","\n","EPOCHS,print_every = 10,1\n","for epoch in range(EPOCHS):\n"," loss_sum,wace_sum,alea_sum = 0,0,0\n"," for batch_in,batch_out in train_iter:\n"," # Forward path\n"," pi,mu,sigma = M.forward(batch_in.view(-1,1,28,28).to(device)) \n"," target = torch.eye(M.y_dim)[batch_out].to(device)\n"," mace_loss_out = mace_loss(pi,mu,sigma,target,alea_weight=0.5) # mixture of CE \n"," loss_out = mace_loss_out['loss_avg']\n"," wace_out = mace_loss_out['wace_avg']\n"," alea_out = mace_loss_out['alea_avg']\n"," # Update \n"," optm.zero_grad() # reset gradient \n"," loss_out.backward() # backpropagate\n"," optm.step() # optimizer update\n"," # Track 
losses \n"," loss_sum += loss_out\n"," wace_sum += wace_out\n"," alea_sum += alea_out\n"," loss_avg = loss_sum/len(train_iter)\n"," wace_avg = wace_sum/len(train_iter)\n"," alea_avg = alea_sum/len(train_iter)\n"," # Print\n"," if ((epoch%print_every)==0) or (epoch==(EPOCHS-1)):\n"," train_res = func_eval(M,train_iter,device)\n"," test_res = func_eval(M,test_iter,device)\n"," print (\"epoch:[%d] loss:[%.3f]=(wace:%.3f+alea:%.3f) train_accr:[%.3f] test_accr:[%.3f].\"%\n"," (epoch,loss_avg,wace_avg,alea_avg,\n"," train_res['val_accr'],test_res['val_accr'])) \n"," print (\" [Train] alea:[%.3f] epis:[%.3f]\"%\n"," (train_res['alea_unct_avg'],train_res['epis_unct_avg']))\n"," print (\" [Test] alea:[%.3f] epis:[%.3f]\"%\n"," (test_res['alea_unct_avg'],test_res['epis_unct_avg']))\n","\n","print (\"Done\")"],"execution_count":12,"outputs":[{"output_type":"stream","text":["epoch:[0] loss:[0.619]=(wace:0.297+alea:0.322) train_accr:[0.529] test_accr:[0.957].\n"," [Train] alea:[7.769] epis:[0.027]\n"," [Test] alea:[7.765] epis:[0.027]\n","epoch:[1] loss:[0.589]=(wace:0.289+alea:0.300) train_accr:[0.534] test_accr:[0.964].\n"," [Train] alea:[7.251] epis:[0.009]\n"," [Test] alea:[7.240] epis:[0.009]\n","epoch:[2] loss:[0.583]=(wace:0.287+alea:0.296) train_accr:[0.539] test_accr:[0.972].\n"," [Train] alea:[7.687] epis:[0.011]\n"," [Test] alea:[7.693] epis:[0.011]\n","epoch:[3] loss:[0.578]=(wace:0.285+alea:0.293) train_accr:[0.540] test_accr:[0.972].\n"," [Train] alea:[7.451] epis:[0.007]\n"," [Test] alea:[7.443] epis:[0.007]\n","epoch:[4] loss:[0.573]=(wace:0.283+alea:0.290) train_accr:[0.541] test_accr:[0.969].\n"," [Train] alea:[7.526] epis:[0.007]\n"," [Test] alea:[7.531] epis:[0.007]\n","epoch:[5] loss:[0.566]=(wace:0.279+alea:0.287) train_accr:[0.543] test_accr:[0.967].\n"," [Train] alea:[7.439] epis:[0.004]\n"," [Test] alea:[7.440] epis:[0.004]\n","epoch:[6] loss:[0.558]=(wace:0.276+alea:0.282) train_accr:[0.545] test_accr:[0.960].\n"," [Train] alea:[7.397] 
epis:[0.003]\n"," [Test] alea:[7.399] epis:[0.003]\n","epoch:[7] loss:[0.549]=(wace:0.271+alea:0.278) train_accr:[0.547] test_accr:[0.953].\n"," [Train] alea:[7.361] epis:[0.003]\n"," [Test] alea:[7.376] epis:[0.003]\n","epoch:[8] loss:[0.536]=(wace:0.265+alea:0.272) train_accr:[0.551] test_accr:[0.941].\n"," [Train] alea:[7.262] epis:[0.003]\n"," [Test] alea:[7.290] epis:[0.003]\n","epoch:[9] loss:[0.524]=(wace:0.258+alea:0.265) train_accr:[0.558] test_accr:[0.934].\n"," [Train] alea:[7.079] epis:[0.002]\n"," [Test] alea:[7.110] epis:[0.002]\n","Done\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"rbfrSSuN9msB"},"source":["### Train with random permutation noise"]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"NqA6YD4W1Q-f","executionInfo":{"status":"ok","timestamp":1610109396420,"user_tz":-540,"elapsed":528217,"user":{"displayName":"Sungjoon Choi","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GiFkMNCaA4zshD2C87LC6X0Y7ohjLlu0sIiLepLnQ=s64","userId":"10728677910935649939"}},"outputId":"c2564487-88fd-4b00-a16f-f45f35f95dd6"},"source":["np.random.seed(seed=0)\n","torch.manual_seed(seed=0)\n","M = MixtureLogitNetwork(\n"," name='mln',x_dim=[1,28,28],k_size=3,c_dims=[32,64],p_sizes=[2,2],\n"," h_dims=[128],y_dim=10,USE_BN=False,k=5,\n"," sig_min=0.01,sig_max=3,\n"," mu_min=-3,mu_max=+3).to(device)\n","M.init_param()\n","optm = optim.Adam(M.parameters(),lr=1e-3,weight_decay=1e-6)\n","M.train() # train mode\n","\n","# Re-define the train iterator\n","mnist_train = datasets.MNIST(root='./data/',train=True,transform=transforms.ToTensor(),download=True)\n","n_train = len(mnist_train)\n","corrupt_rate = 0.3 # random permutation rate \n","targets_bu = mnist_train.targets\n","permute_targets = [1,2,3,4,5,6,7,8,9,0] # shift label \n","for idx in range(10):\n"," sel_idx = torch.where(targets_bu==idx)[0]\n"," n_sel = sel_idx.shape[0]\n"," corrupt_idx = np.random.permutation(n_sel)[:int(n_sel*corrupt_rate)]\n"," 
mnist_train.targets[sel_idx[corrupt_idx]] = permute_targets[idx]\n","BATCH_SIZE = 64\n","train_iter = torch.utils.data.DataLoader(mnist_train,batch_size=BATCH_SIZE,shuffle=True,num_workers=1)\n","\n","EPOCHS,print_every = 10,1\n","for epoch in range(EPOCHS):\n"," loss_sum,wace_sum,alea_sum = 0,0,0\n"," for batch_in,batch_out in train_iter:\n"," # Forward path\n"," pi,mu,sigma = M.forward(batch_in.view(-1,1,28,28).to(device)) \n"," target = torch.eye(M.y_dim)[batch_out].to(device)\n"," mace_loss_out = mace_loss(pi,mu,sigma,target,alea_weight=0.5) # mixture of CE \n"," loss_out = mace_loss_out['loss_avg']\n"," wace_out = mace_loss_out['wace_avg']\n"," alea_out = mace_loss_out['alea_avg']\n"," # Update \n"," optm.zero_grad() # reset gradient \n"," loss_out.backward() # backpropagate\n"," optm.step() # optimizer update\n"," # Track losses \n"," loss_sum += loss_out\n"," wace_sum += wace_out\n"," alea_sum += alea_out\n"," loss_avg = loss_sum/len(train_iter)\n"," wace_avg = wace_sum/len(train_iter)\n"," alea_avg = alea_sum/len(train_iter)\n"," # Print\n"," if ((epoch%print_every)==0) or (epoch==(EPOCHS-1)):\n"," train_res = func_eval(M,train_iter,device)\n"," test_res = func_eval(M,test_iter,device)\n"," print (\"epoch:[%d] loss:[%.3f]=(wace:%.3f+alea:%.3f) train_accr:[%.3f] test_accr:[%.3f].\"%\n"," (epoch,loss_avg,wace_avg,alea_avg,\n"," train_res['val_accr'],test_res['val_accr'])) \n"," print (\" [Train] alea:[%.3f] epis:[%.3f]\"%\n"," (train_res['alea_unct_avg'],train_res['epis_unct_avg']))\n"," print (\" [Test] alea:[%.3f] epis:[%.3f]\"%\n"," (test_res['alea_unct_avg'],test_res['epis_unct_avg']))\n","\n","print (\"Done\")"],"execution_count":13,"outputs":[{"output_type":"stream","text":["epoch:[0] loss:[0.411]=(wace:0.192+alea:0.219) train_accr:[0.659] test_accr:[0.930].\n"," [Train] alea:[5.737] epis:[0.022]\n"," [Test] alea:[5.717] epis:[0.022]\n","epoch:[1] loss:[0.335]=(wace:0.162+alea:0.174) train_accr:[0.657] test_accr:[0.918].\n"," [Train] alea:[5.203] 
epis:[0.012]\n"," [Test] alea:[5.164] epis:[0.012]\n","epoch:[2] loss:[0.317]=(wace:0.154+alea:0.163) train_accr:[0.681] test_accr:[0.962].\n"," [Train] alea:[4.874] epis:[0.005]\n"," [Test] alea:[4.881] epis:[0.005]\n","epoch:[3] loss:[0.303]=(wace:0.147+alea:0.156) train_accr:[0.690] test_accr:[0.976].\n"," [Train] alea:[4.483] epis:[0.002]\n"," [Test] alea:[4.490] epis:[0.002]\n","epoch:[4] loss:[0.293]=(wace:0.143+alea:0.150) train_accr:[0.691] test_accr:[0.973].\n"," [Train] alea:[4.628] epis:[0.003]\n"," [Test] alea:[4.638] epis:[0.003]\n","epoch:[5] loss:[0.283]=(wace:0.138+alea:0.145) train_accr:[0.695] test_accr:[0.982].\n"," [Train] alea:[4.059] epis:[0.001]\n"," [Test] alea:[4.067] epis:[0.001]\n","epoch:[6] loss:[0.275]=(wace:0.134+alea:0.141) train_accr:[0.696] test_accr:[0.976].\n"," [Train] alea:[4.136] epis:[0.001]\n"," [Test] alea:[4.145] epis:[0.001]\n","epoch:[7] loss:[0.271]=(wace:0.132+alea:0.139) train_accr:[0.697] test_accr:[0.976].\n"," [Train] alea:[4.090] epis:[0.001]\n"," [Test] alea:[4.103] epis:[0.001]\n","epoch:[8] loss:[0.261]=(wace:0.127+alea:0.134) train_accr:[0.697] test_accr:[0.970].\n"," [Train] alea:[4.101] epis:[0.001]\n"," [Test] alea:[4.129] epis:[0.001]\n","epoch:[9] loss:[0.257]=(wace:0.125+alea:0.132) train_accr:[0.700] test_accr:[0.974].\n"," [Train] alea:[3.959] epis:[0.001]\n"," [Test] alea:[3.981] epis:[0.001]\n","Done\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"NnGQhAdr2vKV","executionInfo":{"status":"ok","timestamp":1610109396425,"user_tz":-540,"elapsed":528163,"user":{"displayName":"Sungjoon Choi","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GiFkMNCaA4zshD2C87LC6X0Y7ohjLlu0sIiLepLnQ=s64","userId":"10728677910935649939"}}},"source":[""],"execution_count":13,"outputs":[]}]} -------------------------------------------------------------------------------- /notebook/mha.ipynb: -------------------------------------------------------------------------------- 1 | 
{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"mha.ipynb","provenance":[],"collapsed_sections":[],"authorship_tag":"ABX9TyNtKYtrAk1tHkKhTt/4wK5a"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"rsnoyeYbS7i6"},"source":["\n"," \n"," \n","
\n"," Colab\n"," \n"," View Source\n","
"]},{"cell_type":"markdown","metadata":{"id":"xRwj6fIzVc-s"},"source":["# Multi-Headed Attention"]},{"cell_type":"code","metadata":{"id":"G9N9n94EVWy9","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1609913294780,"user_tz":-540,"elapsed":5521,"user":{"displayName":"Sungjoon Choi","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GiFkMNCaA4zshD2C87LC6X0Y7ohjLlu0sIiLepLnQ=s64","userId":"10728677910935649939"}},"outputId":"24a09030-6f5a-4e62-cd26-a642cc2bac57"},"source":["import numpy as np\n","import matplotlib.pyplot as plt\n","import torch\n","import torch.nn as nn\n","import torch.optim as optim\n","import torch.nn.functional as F\n","%matplotlib inline\n","%config InlineBackend.figure_format='retina'\n","print (\"PyTorch version:[%s].\"%(torch.__version__))\n","device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n","print (\"device:[%s].\"%(device))"],"execution_count":1,"outputs":[{"output_type":"stream","text":["PyTorch version:[1.7.0+cu101].\n","device:[cuda:0].\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"TRRvnKZ0XlJT"},"source":["### Scaled Dot-Product Attention (SDPA)\n","- Data $X \\in \\mathbb{R}^{n \\times d}$ where $n$ is the number data and $d$ is the data dimension\n","- Query and Key $Q, K \\in \\mathbb{R}^{n \\times d_K}$ \n","- Value $V \\in \\mathbb{R}^{n \\times d_V} $\n","\n","$\\text{Attention}(Q,K,V) = \\text{softmax} \\left( \\frac{QK^T}{\\sqrt{d_K}} \\right)V \\in \\mathbb{R}^{n \\times d_V} $"]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"K-Z3Vd_VV5Pm","executionInfo":{"status":"ok","timestamp":1609913693623,"user_tz":-540,"elapsed":1321,"user":{"displayName":"Sungjoon Choi","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GiFkMNCaA4zshD2C87LC6X0Y7ohjLlu0sIiLepLnQ=s64","userId":"10728677910935649939"}},"outputId":"15debb01-d0cd-4df9-b636-e19de5c46d40"},"source":["class 
ScaledDotProductAttention(nn.Module):\n"," def forward(self,Q,K,V,mask=None):\n"," d_K = K.size()[-1] # key dimension\n"," scores = Q.matmul(K.transpose(-2,-1)) / np.sqrt(d_K)\n"," if mask is not None:\n"," scores = scores.masked_fill(mask==0, -1e9)\n"," attention = F.softmax(scores,dim=-1)\n"," out = attention.matmul(V)\n"," return out,attention\n","\n","# Demo run of scaled dot product attention \n","SPDA = ScaledDotProductAttention()\n","n_batch,d_K,d_V = 3,128,256 # d_K(=d_Q) need not be equal to d_V\n","n_Q,n_K,n_V = 30,50,50\n","Q = torch.rand(n_batch,n_Q,d_K)\n","K = torch.rand(n_batch,n_K,d_K)\n","V = torch.rand(n_batch,n_V,d_V)\n","out,attention = SPDA.forward(Q,K,V,mask=None)\n","def sh(x): return str(x.shape)[11:-1] \n","print (\"SDPA: Q%s K%s V%s => out%s attention%s\"%\n"," (sh(Q),sh(K),sh(V),sh(out),sh(attention)))\n","\n","# It supports 'multi-headed' attention\n","n_batch,n_head,d_K,d_V = 3,5,128,256\n","n_Q,n_K,n_V = 30,50,50 # n_K and n_V should be the same\n","Q = torch.rand(n_batch,n_head,n_Q,d_K)\n","K = torch.rand(n_batch,n_head,n_K,d_K)\n","V = torch.rand(n_batch,n_head,n_V,d_V)\n","out,attention = SPDA.forward(Q,K,V,mask=None)\n","# out: [n_batch x n_head x n_Q x d_V]\n","# attention: [n_batch x n_head x n_Q x n_K] \n","def sh(x): return str(x.shape)[11:-1] \n","print (\"(Multi-Headed) SDPA: Q%s K%s V%s => out%s attention%s\"%\n"," (sh(Q),sh(K),sh(V),sh(out),sh(attention)))"],"execution_count":4,"outputs":[{"output_type":"stream","text":["SDPA: Q[3, 30, 128] K[3, 50, 128] V[3, 50, 256] => out[3, 30, 256] attention[3, 30, 50]\n","(Multi-Headed) SDPA: Q[3, 5, 30, 128] K[3, 5, 50, 128] V[3, 5, 50, 256] => out[3, 5, 30, 256] attention[3, 5, 30, 50]\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"uLbi13pDi3No"},"source":["### Multi-Headed Attention (MHA)\n","\n","$\\text{head}_{\\color{red}i} = \\text{Attention}(Q {\\color{green}W}^Q_{\\color{red}i},K {\\color{green}W}^K_{\\color{red}i}, V 
{\\color{green}W}^V_{\\color{red}i}) $\n","\n","$\\text{MultiHead}(Q,K,V) = \\text{Concat}(\\text{head}_1,\\dots,\\text{head}_h) {\\color{green}W}^O $"]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"Hf7j24l1dnSF","executionInfo":{"status":"ok","timestamp":1609913914502,"user_tz":-540,"elapsed":1159,"user":{"displayName":"Sungjoon Choi","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GiFkMNCaA4zshD2C87LC6X0Y7ohjLlu0sIiLepLnQ=s64","userId":"10728677910935649939"}},"outputId":"3866ebf2-b462-46f3-dffe-09414727afca"},"source":["class MultiHeadedAttention(nn.Module):\n"," def __init__(self,d_feat=128,n_head=5,actv=F.relu,USE_BIAS=True,dropout_p=0.1,device=None):\n"," \"\"\"\n"," :param d_feat: feature dimension\n"," :param n_head: number of heads\n"," :param actv: activation after each linear layer\n"," :param USE_BIAS: whether to use bias\n"," :param dropout_p: dropout rate\n"," :param device: which device to use (e.g., cuda:0)\n"," \"\"\"\n"," super(MultiHeadedAttention,self).__init__()\n"," if (d_feat%n_head) != 0:\n"," raise ValueError(\"d_feat(%d) should be divisible by n_head(%d)\"%(d_feat,n_head)) \n"," self.d_feat = d_feat\n"," self.n_head = n_head\n"," self.d_head = self.d_feat // self.n_head\n"," self.actv = actv\n"," self.USE_BIAS = USE_BIAS\n"," self.dropout_p = dropout_p # prob. 
of zeroed\n","\n"," self.lin_Q = nn.Linear(self.d_feat,self.d_feat,self.USE_BIAS)\n"," self.lin_K = nn.Linear(self.d_feat,self.d_feat,self.USE_BIAS)\n"," self.lin_V = nn.Linear(self.d_feat,self.d_feat,self.USE_BIAS)\n"," self.lin_O = nn.Linear(self.d_feat,self.d_feat,self.USE_BIAS)\n","\n"," self.dropout = nn.Dropout(p=self.dropout_p)\n"," \n"," def forward(self,Q,K,V,mask=None):\n"," \"\"\"\n"," :param Q: [n_batch, n_Q, d_feat]\n"," :param K: [n_batch, n_K, d_feat]\n"," :param V: [n_batch, n_V, d_feat] <= n_K and n_V must be the same \n"," :param mask: \n"," \"\"\"\n"," n_batch = Q.shape[0]\n"," Q_feat = self.lin_Q(Q) \n"," K_feat = self.lin_K(K) \n"," V_feat = self.lin_V(V)\n"," # Q_feat: [n_batch, n_Q, d_feat]\n"," # K_feat: [n_batch, n_K, d_feat]\n"," # V_feat: [n_batch, n_V, d_feat]\n","\n"," # Multi-head split of Q, K, and V (d_feat = n_head*d_head)\n"," Q_split = Q_feat.view(n_batch, -1, self.n_head, self.d_head).permute(0, 2, 1, 3)\n"," K_split = K_feat.view(n_batch, -1, self.n_head, self.d_head).permute(0, 2, 1, 3)\n"," V_split = V_feat.view(n_batch, -1, self.n_head, self.d_head).permute(0, 2, 1, 3)\n"," # Q_split: [n_batch, n_head, n_Q, d_head]\n"," # K_split: [n_batch, n_head, n_K, d_head]\n"," # V_split: [n_batch, n_head, n_V, d_head]\n","\n"," # Multi-Headed Attention\n"," d_K = K.size()[-1] # key dimension\n"," scores = torch.matmul(Q_split,K_split.permute(0,1,3,2)) / np.sqrt(d_K)\n"," if mask is not None:\n"," scores = scores.masked_fill(mask==0,-1e9)\n"," attention = torch.softmax(scores,dim=-1)\n"," x_raw = torch.matmul(self.dropout(attention),V_split) # dropout is NOT mentioned in the paper\n"," # attention: [n_batch, n_head, n_Q, n_K]\n"," # x_raw: [n_batch, n_head, n_Q, d_head]\n","\n"," # Reshape x\n"," x_rsh1 = x_raw.permute(0,2,1,3).contiguous()\n"," # x_rsh1: [n_batch, n_Q, n_head, d_head]\n"," x_rsh2 = x_rsh1.view(n_batch,-1,self.d_feat)\n"," # x_rsh2: [n_batch, n_Q, d_feat]\n","\n"," # Linear\n"," x = self.lin_O(x_rsh2)\n"," # x: [n_batch, 
n_Q, d_feat]\n"," out = {'Q_feat':Q_feat,'K_feat':K_feat,'V_feat':V_feat,\n"," 'Q_split':Q_split,'K_split':K_split,'V_split':V_split,\n"," 'scores':scores,'attention':attention,\n"," 'x_raw':x_raw,'x_rsh1':x_rsh1,'x_rsh2':x_rsh2,'x':x}\n"," return out\n","\n","# Self-Attention Layer\n","n_batch = 128\n","n_src = 32\n","d_feat = 200\n","n_head = 5\n","src = torch.rand(n_batch,n_src,d_feat)\n","self_attention = MultiHeadedAttention(\n"," d_feat=d_feat,n_head=n_head,actv=F.relu,USE_BIAS=True,dropout_p=0.1,device=device)\n","out = self_attention.forward(src,src,src,mask=None)\n","\n","Q_feat,K_feat,V_feat = out['Q_feat'],out['K_feat'],out['V_feat']\n","Q_split,K_split,V_split = out['Q_split'],out['K_split'],out['V_split']\n","scores,attention = out['scores'],out['attention']\n","x_raw,x_rsh1,x_rsh2,x = out['x_raw'],out['x_rsh1'],out['x_rsh2'],out['x']\n","\n","# Print out shapes\n","def sh(_x): return str(_x.shape)[11:-1] \n","print (\"Input src:\\t%s \\t= [n_batch, n_src, d_feat]\"%(sh(src)))\n","print ()\n","print (\"Q_feat: \\t%s \\t= [n_batch, n_src, d_feat]\"%(sh(Q_feat)))\n","print (\"K_feat: \\t%s \\t= [n_batch, n_src, d_feat]\"%(sh(K_feat)))\n","print (\"V_feat: \\t%s \\t= [n_batch, n_src, d_feat]\"%(sh(V_feat)))\n","print ()\n","print (\"Q_split: \\t%s \\t= [n_batch, n_head, n_src, d_head]\"%(sh(Q_split)))\n","print (\"K_split: \\t%s \\t= [n_batch, n_head, n_src, d_head]\"%(sh(K_split)))\n","print (\"V_split: \\t%s \\t= [n_batch, n_head, n_src, d_head]\"%(sh(V_split)))\n","print ()\n","print (\"scores: \\t%s \\t= [n_batch, n_head, n_src, n_src]\"%(sh(scores)))\n","print (\"attention:\\t%s \\t= [n_batch, n_head, n_src, n_src]\"%(sh(attention)))\n","print ()\n","print (\"x_raw: \\t%s \\t= [n_batch, n_head, n_src, d_head]\"%(sh(x_raw)))\n","print (\"x_rsh1: \\t%s \\t= [n_batch, n_src, n_head, d_head]\"%(sh(x_rsh1)))\n","print (\"x_rsh2: \\t%s \\t= [n_batch, n_src, d_feat]\"%(sh(x_rsh2)))\n","print ()\n","print (\"Output x: \\t%s \\t= [n_batch, n_src, 
d_feat]\"%(sh(x)))\n"],"execution_count":5,"outputs":[{"output_type":"stream","text":["Input src:\t[128, 32, 200] \t= [n_batch, n_src, d_feat]\n","\n","Q_feat: \t[128, 32, 200] \t= [n_batch, n_src, d_feat]\n","K_feat: \t[128, 32, 200] \t= [n_batch, n_src, d_feat]\n","V_feat: \t[128, 32, 200] \t= [n_batch, n_src, d_feat]\n","\n","Q_split: \t[128, 5, 32, 40] \t= [n_batch, n_head, n_src, d_head]\n","K_split: \t[128, 5, 32, 40] \t= [n_batch, n_head, n_src, d_head]\n","V_split: \t[128, 5, 32, 40] \t= [n_batch, n_head, n_src, d_head]\n","\n","scores: \t[128, 5, 32, 32] \t= [n_batch, n_head, n_src, n_src]\n","attention:\t[128, 5, 32, 32] \t= [n_batch, n_head, n_src, n_src]\n","\n","x_raw: \t[128, 5, 32, 40] \t= [n_batch, n_head, n_src, d_head]\n","x_rsh1: \t[128, 32, 5, 40] \t= [n_batch, n_src, n_head, d_head]\n","x_rsh2: \t[128, 32, 200] \t= [n_batch, n_src, d_feat]\n","\n","Output x: \t[128, 32, 200] \t= [n_batch, n_src, d_feat]\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"QgqsHUT-OuJA"},"source":[""],"execution_count":null,"outputs":[]}]} --------------------------------------------------------------------------------