├── .gitignore
├── LICENSE
├── README.md
├── aws-pytorch
│   ├── commands.txt
│   └── install.sh
├── knn
│   └── knn.py
├── pr-lr
│   ├── Loss Patterns.ipynb
│   ├── Loss vs Gradient Norm.ipynb
│   ├── Monte Carlo and Importance Sampling.ipynb
│   ├── cifar.py
│   ├── cifar_reg.py
│   ├── classic.py
│   ├── imgs
│   │   ├── loss_vs_grad.jpg
│   │   ├── no_shuffling.jpg
│   │   └── with_shuffling.jpg
│   ├── mini_batch.py
│   └── models.py
└── spatial-transformer
    ├── interpolation.py
    └── utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | data/
3 | dump/
4 |
5 | # Byte-compiled / optimized / DLL files
6 | __pycache__/
7 | *.py[cod]
8 | *$py.class
9 |
10 | # C extensions
11 | *.so
12 |
13 | # Distribution / packaging
14 | .Python
15 | build/
16 | develop-eggs/
17 | dist/
18 | downloads/
19 | eggs/
20 | .eggs/
21 | lib/
22 | lib64/
23 | parts/
24 | sdist/
25 | var/
26 | wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | MANIFEST
31 |
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 |
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 |
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .coverage
46 | .coverage.*
47 | .cache
48 | nosetests.xml
49 | coverage.xml
50 | *.cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 |
63 | # Flask stuff:
64 | instance/
65 | .webassets-cache
66 |
67 | # Scrapy stuff:
68 | .scrapy
69 |
70 | # Sphinx documentation
71 | docs/_build/
72 |
73 | # PyBuilder
74 | target/
75 |
76 | # Jupyter Notebook
77 | .ipynb_checkpoints
78 |
79 | # pyenv
80 | .python-version
81 |
82 | # celery beat schedule file
83 | celerybeat-schedule
84 |
85 | # SageMath parsed files
86 | *.sage.py
87 |
88 | # Environments
89 | .env
90 | .venv
91 | env/
92 | venv/
93 | ENV/
94 | env.bak/
95 | venv.bak/
96 |
97 | # Spyder project settings
98 | .spyderproject
99 | .spyproject
100 |
101 | # Rope project settings
102 | .ropeproject
103 |
104 | # mkdocs documentation
105 | /site
106 |
107 | # mypy
108 | .mypy_cache/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Kevin Zakka
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ### blog-code
2 |
3 | Code from my academic blog [kevinzakka.github.io](https://kevinzakka.github.io/)
--------------------------------------------------------------------------------
/aws-pytorch/commands.txt:
--------------------------------------------------------------------------------
1 | 2 important commands
2 |
3 | # ssh into an instance
4 | ssh -v -i <key.pem> ubuntu@<instance-address>
5 |
6 | # copy file from local computer to instance
7 | scp -i <key.pem> -r <local-path> ubuntu@<instance-address>:<remote-path>
8 |
--------------------------------------------------------------------------------
/aws-pytorch/install.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 |
3 | # drivers
4 | wget http://us.download.nvidia.com/tesla/375.66/nvidia-diag-driver-local-repo-ubuntu1604_375.66-1_amd64.deb
5 | sudo dpkg -i nvidia-diag-driver-local-repo-ubuntu1604_375.66-1_amd64.deb
6 | sudo apt-get update
7 | sudo apt-get -y install cuda-drivers
8 | sudo apt-get update && sudo apt-get -y upgrade
9 |
10 | # python
11 | sudo apt-get install unzip
12 | sudo apt-get --assume-yes install python3-tk
13 | sudo apt-get --assume-yes install python3-pip
14 | sudo pip3 install --upgrade pip
15 | sudo pip3 install virtualenv numpy scipy matplotlib
16 |
17 | # virtualenv
18 | mkdir envs
19 | cd envs
20 | virtualenv --system-site-packages deepL
21 |
22 | # pytorch
23 | source ~/envs/deepL/bin/activate
24 | pip install http://download.pytorch.org/whl/cu91/torch-0.4.0-cp35-cp35m-linux_x86_64.whl
25 | pip install torchvision tqdm
26 |
27 | sudo reboot
28 |
--------------------------------------------------------------------------------
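A quick sanity check once install.sh finishes and the instance reboots (an editor's sketch, not a file in this repo): activate the virtualenv and confirm from Python that the wheel and the driver are visible. Running nvidia-smi from the shell is an equivalent driver-side check.

import torch

print(torch.__version__)          # expect 0.4.0, matching the wheel installed above
print(torch.cuda.is_available())  # True once the NVIDIA driver is loaded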
/knn/knn.py:
--------------------------------------------------------------------------------
1 | # ============================== loading libraries ===========================================
2 | import numpy as np
3 | import pandas as pd
4 | import matplotlib.pyplot as plt
5 | from sklearn.model_selection import train_test_split  # sklearn.cross_validation was renamed to model_selection
6 | from sklearn.neighbors import KNeighborsClassifier
7 | from sklearn.metrics import accuracy_score
8 | from sklearn.model_selection import cross_val_score
9 | from collections import Counter
10 | # =============================================================================================
11 | # Part I
12 | # =============================================================================================
13 |
14 | # ============================== data preprocessing ===========================================
15 | # define column names
16 | names = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width', 'class']
17 |
18 | # loading training data
19 | df = pd.read_csv('/Users/kevin/Desktop/Blog/iris.data.txt', header=None, names=names)
20 | print(df.head())
21 |
22 | # create design matrix X and target vector y
23 | X = np.array(df.iloc[:, 0:4])  # end index is exclusive (.ix is long deprecated; use .iloc)
24 | y = np.array(df['class'])      # showing you two ways of indexing a pandas df
25 |
26 | # split into train and test
27 | X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)
28 | # ============================== KNN with k = 3 ===============================================
29 | # instantiate learning model (k = 3)
30 | knn = KNeighborsClassifier(n_neighbors=3)
31 |
32 | # fitting the model
33 | knn.fit(X_train, y_train)
34 |
35 | # predict the response
36 | pred = knn.predict(X_test)
37 |
38 | # evaluate accuracy
39 | acc = accuracy_score(y_test, pred) * 100
40 | print('\nThe accuracy of the knn classifier for k = 3 is %d%%' % acc)
41 | # ============================== parameter tuning =============================================
42 | # creating odd list of K for KNN
43 | myList = list(range(0,50))
44 | neighbors = list(filter(lambda x: x % 2 != 0, myList))
45 |
46 | # empty list that will hold cv scores
47 | cv_scores = []
48 |
49 | # perform 10-fold cross validation
50 | for k in neighbors:
51 |     knn = KNeighborsClassifier(n_neighbors=k)
52 |     scores = cross_val_score(knn, X_train, y_train, cv=10, scoring='accuracy')
53 |     cv_scores.append(scores.mean())
54 |
55 | # convert accuracy to misclassification error
56 | misclass_error = [1 - x for x in cv_scores]
57 |
58 | # determining best k
59 | optimal_k = neighbors[misclass_error.index(min(misclass_error))]
60 | print('\nThe optimal number of neighbors is %d.' % optimal_k)
61 |
62 | # plot misclassification error vs k
63 | plt.plot(neighbors, misclass_error)
64 | plt.xlabel('Number of Neighbors K')
65 | plt.ylabel('Misclassification Error')
66 | plt.show()
67 | # =============================================================================================
68 | # Part II
69 | # =============================================================================================
70 | # ===================================== writing our own KNN ===================================
71 | def train(X_train, y_train):
72 |     # do nothing: k-NN is a lazy learner, so "training" just holds on to the data
73 |     return
74 |
75 | def predict(X_train, y_train, x_test, k):
76 |     # create list for distances and targets
77 |     distances = []
78 |     targets = []
79 |
80 |     for i in range(len(X_train)):
81 |         # first we compute the euclidean distance
82 |         distance = np.sqrt(np.sum(np.square(x_test - X_train[i, :])))
83 |         # add it to list of distances
84 |         distances.append([distance, i])
85 |
86 |     # sort the list
87 |     distances = sorted(distances)
88 |
89 |     # make a list of the k neighbors' targets
90 |     for i in range(k):
91 |         index = distances[i][1]
92 |         #print(y_train[index])
93 |         targets.append(y_train[index])
94 |
95 |     # return most common target
96 |     return Counter(targets).most_common(1)[0][0]
97 |
98 | def kNearestNeighbor(X_train, y_train, X_test, predictions, k):
99 |     # check that k is not larger than n
100 |     if k > len(X_train):
101 |         raise ValueError('k cannot be larger than the number of training samples')
102 |
103 |     # train on the input data
104 |     train(X_train, y_train)
105 |
106 |     # predict for each testing observation
107 |     for i in range(len(X_test)):
108 |         predictions.append(predict(X_train, y_train, X_test[i, :], k))
109 | # ============================== testing our KNN =============================================
110 | # making our predictions
111 | predictions = []
112 | try:
113 |     kNearestNeighbor(X_train, y_train, X_test, predictions, 7)
114 |     predictions = np.asarray(predictions)
115 |
116 |     # evaluating accuracy
117 |     accuracy = accuracy_score(y_test, predictions) * 100
118 |     print('\nThe accuracy of OUR classifier is %d%%' % accuracy)
119 |
120 | except ValueError:
121 |     print('Can\'t have more neighbors than training samples!!')
122 |
--------------------------------------------------------------------------------
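The predict() loop above is easy to follow, but it computes distances one training point at a time. A vectorized sketch of the same classifier using NumPy broadcasting (an illustrative addition, not part of the original script; it reuses np, Counter, and the arrays defined above):

def predict_vectorized(X_train, y_train, X_test, k):
    # pairwise squared Euclidean distances, shape (n_test, n_train);
    # squared distances preserve the neighbor ordering, so sqrt is unnecessary
    d2 = ((X_test[:, None, :] - X_train[None, :, :]) ** 2).sum(axis=2)
    # indices of the k nearest training points for each test point
    nearest = np.argsort(d2, axis=1)[:, :k]
    # majority vote over the k neighbors' labels
    return np.array([Counter(y_train[row]).most_common(1)[0][0] for row in nearest])

On the iris split above, predict_vectorized(X_train, y_train, X_test, 7) should give the same predictions as the loop-based version, and it avoids the Python-level inner loop on larger datasets.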
/pr-lr/Loss Patterns.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 2,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import pickle\n",
10 | "import numpy as np\n",
11 | "import pandas as pd\n",
12 | "import seaborn as sns\n",
13 | "import matplotlib.pyplot as plt\n",
14 | "\n",
15 | "from functools import reduce\n",
16 | "from matplotlib import colors\n",
17 | "from matplotlib.ticker import MaxNLocator\n",
18 | "\n",
19 | "import torch\n",
20 | "import torch.nn as nn\n",
21 | "import torch.optim as optim\n",
22 | "import torch.nn.functional as F\n",
23 | "\n",
24 | "from torchvision.datasets import MNIST\n",
25 | "from torchvision import transforms\n",
26 | "from torch.utils.data import DataLoader\n",
27 | "from torch.utils.data.sampler import Sampler\n",
28 | "\n",
29 | "# plotting params\n",
30 | "%matplotlib inline\n",
31 | "plt.rcParams['font.size'] = 10\n",
32 | "plt.rcParams['axes.labelsize'] = 10\n",
33 | "plt.rcParams['axes.titlesize'] = 10\n",
34 | "plt.rcParams['xtick.labelsize'] = 8\n",
35 | "plt.rcParams['ytick.labelsize'] = 8\n",
36 | "plt.rcParams['legend.fontsize'] = 10\n",
37 | "plt.rcParams['figure.titlesize'] = 12\n",
38 | "plt.rcParams['figure.figsize'] = (13.0, 6.0)\n",
39 | "sns.set_style(\"white\")\n",
40 | "\n",
41 | "data_dir = './data/'\n",
42 | "plot_dir = './imgs/'\n",
43 | "dump_dir = './dump/'"
44 | ]
45 | },
46 | {
47 | "cell_type": "markdown",
48 | "metadata": {},
49 | "source": [
50 | "## Setup"
51 | ]
52 | },
53 | {
54 | "cell_type": "code",
55 | "execution_count": 3,
56 | "metadata": {},
57 | "outputs": [],
58 | "source": [
59 | "# ensuring reproducibility\n",
60 | "SEED = 42\n",
61 | "torch.manual_seed(SEED)\n",
62 | "torch.backends.cudnn.benchmark = False"
63 | ]
64 | },
65 | {
66 | "cell_type": "code",
67 | "execution_count": 4,
68 | "metadata": {},
69 | "outputs": [],
70 | "source": [
71 | "GPU = False\n",
72 | "\n",
73 | "device = torch.device(\"cuda\" if GPU else \"cpu\")\n",
74 | "kwargs = {'num_workers': 1, 'pin_memory': True} if GPU else {}"
75 | ]
76 | },
77 | {
78 | "cell_type": "markdown",
79 | "metadata": {},
80 | "source": [
81 | "## Data Loader\n",
82 | "\n",
83 | "We need to create a special dataloader for the experiment with shuffling. This is necessary because we need to keep track of each individual sample's loss across epochs, and shuffling loses that information.\n",
84 | "\n",
85 | "To solve this, we can:\n",
86 | "\n",
87 | "- create permutations of a list of numbers from 0 to 59,999 (one index per MNIST training image)\n",
88 | "- create a sampler class that takes a list and iterates over it sequentially\n",
89 | "- at each epoch, create a dataloader with a sampler that gets fed the precomputed permutations"
90 | ]
91 | },
92 | {
93 | "cell_type": "code",
94 | "execution_count": 4,
95 | "metadata": {},
96 | "outputs": [],
97 | "source": [
98 | "class LinearSampler(Sampler):\n",
99 | "    def __init__(self, idx):\n",
100 | "        self.idx = idx\n",
101 | "\n",
102 | "    def __iter__(self):\n",
103 | "        return iter(self.idx)\n",
104 | "\n",
105 | "    def __len__(self):\n",
106 | "        return len(self.idx)"
107 | ]
108 | },
109 | {
110 | "cell_type": "code",
111 | "execution_count": 5,
112 | "metadata": {},
113 | "outputs": [],
114 | "source": [
115 | "def get_data_loader(data_dir, batch_size, permutation=None, num_workers=3, pin_memory=False):\n",
116 | "    normalize = transforms.Normalize(mean=(0.1307,), std=(0.3081,))\n",
117 | "    transform = transforms.Compose([transforms.ToTensor(), normalize])\n",
118 | "    dataset = MNIST(root=data_dir, train=True, download=True, transform=transform)\n",
119 | "    \n",
120 | "    sampler = None\n",
121 | "    if permutation is not None:\n",
122 | "        sampler = LinearSampler(permutation)\n",
123 | "\n",
124 | "    loader = DataLoader(\n",
125 | "        dataset, batch_size=batch_size,\n",
126 | "        shuffle=False, num_workers=num_workers,\n",
127 | "        pin_memory=pin_memory, sampler=sampler\n",
128 | "    )\n",
129 | "\n",
130 | "    return loader"
131 | ]
132 | },
133 | {
134 | "cell_type": "markdown",
135 | "metadata": {},
136 | "source": [
137 | "## Model"
138 | ]
139 | },
140 | {
141 | "cell_type": "code",
142 | "execution_count": 6,
143 | "metadata": {},
144 | "outputs": [],
145 | "source": [
146 | "class SmallConv(nn.Module):\n",
147 | "    def __init__(self):\n",
148 | "        super(SmallConv, self).__init__()\n",
149 | "        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)\n",
150 | "        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)\n",
151 | "        self.fc1 = nn.Linear(320, 50)\n",
152 | "        self.fc2 = nn.Linear(50, 10)\n",
153 | "\n",
154 | "    def forward(self, x):\n",
155 | "        out = F.relu(F.max_pool2d(self.conv1(x), 2))\n",
156 | "        out = F.relu(F.max_pool2d(self.conv2(out), 2))\n",
157 | "        out = out.view(-1, 320)\n",
158 | "        out = F.relu(self.fc1(out))\n",
159 | "        out = self.fc2(out)\n",
160 | "        return F.log_softmax(out, dim=1)"
161 | ]
162 | },
163 | {
164 | "cell_type": "markdown",
165 | "metadata": {},
166 | "source": [
167 | "## Utility Functions"
168 | ]
169 | },
170 | {
171 | "cell_type": "code",
172 | "execution_count": 10,
173 | "metadata": {},
174 | "outputs": [],
175 | "source": [
176 | "def accuracy(predicted, ground_truth):\n",
177 | "    predicted = torch.max(predicted, 1)[1]\n",
178 | "    total = len(ground_truth)\n",
179 | "    correct = (predicted == ground_truth).sum().double()\n",
180 | "    acc = 100 * (correct / total)\n",
181 | "    return acc.item()\n",
182 | "\n",
183 | "def train(model, device, train_loader, optimizer, epoch):\n",
184 | "    model.train()\n",
185 | "    \n",
186 | "    epoch_stats = []\n",
187 | "    for batch_idx, (data, target) in enumerate(train_loader):\n",
188 | "        data, target = data.to(device), target.to(device)\n",
189 | "        optimizer.zero_grad()\n",
190 | "\n",
191 | "        # forward pass\n",
192 | "        output = model(data)\n",
193 | "        acc = accuracy(output, target)\n",
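"        # NOTE (comment added in editing, not in the original notebook): reduction='none'\n",
"        # below keeps one loss per sample; combined with the position-based indices, this\n",
"        # lets each example's loss be tracked across epochs (shuffled runs are remapped\n",
"        # through the saved permutations further down).\n",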
194 | "        \n",
195 | "        # compute the per-sample losses for this batch\n",
196 | "        losses = F.nll_loss(output, target, reduction='none')\n",
197 | "        indices = [batch_idx*train_loader.batch_size + i for i in range(len(data))]\n",
198 | "        \n",
199 | "        batch_stats = []\n",
200 | "        for i, l in zip(indices, losses):\n",
201 | "            batch_stats.append([i, l.detach()])  # detach so stored losses don't hold the graph\n",
202 | "        epoch_stats.append(batch_stats)\n",
203 | "        \n",
204 | "        # take average loss\n",
205 | "        loss = losses.mean()\n",
206 | "        \n",
207 | "        # backwards pass\n",
208 | "        loss.backward()\n",
209 | "        optimizer.step()\n",
210 | "        \n",
211 | "        if batch_idx % 25 == 0:\n",
212 | "            print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}\\tAcc: {:.2f}%'.format(\n",
213 | "                epoch, batch_idx * len(data), len(train_loader.dataset),\n",
214 | "100. * batch_idx / len(train_loader), loss.item(), acc))\n",
215 | "\n",
216 | "    return epoch_stats\n",
217 | "\n",
218 | "def percentage_split(seq, percentages):\n",
219 | "    cdf = np.cumsum(percentages)\n",
220 | "    assert np.allclose(cdf[-1], 1.0)\n",
221 | "    stops = list(map(int, cdf * len(seq)))\n",
222 | "    return [seq[a:b] for a, b in zip([0]+stops, stops)]\n",
223 | "\n",
224 | "def bin_losses(all_epochs, num_quantiles=10):\n",
225 | "    percentile_splits = []\n",
226 | "    for epoch in all_epochs:\n",
227 | "        # sort by decreasing loss\n",
228 | "        sorted_loss_idx = sorted(\n",
229 | "            range(len(epoch)), key=lambda k: epoch[k][1], reverse=True\n",
230 | "        )\n",
231 | "        \n",
232 | "        # bin into num_quantiles equal bins\n",
233 | "        splits = percentage_split(sorted_loss_idx, [1/num_quantiles]*num_quantiles)\n",
234 | "        \n",
235 | "        percentile_splits.append(splits)\n",
236 | "\n",
237 | "    return percentile_splits"
238 | ]
239 | },
240 | {
241 | "cell_type": "markdown",
242 | "metadata": {},
243 | "source": [
244 | "## Without Shuffling"
245 | ]
246 | },
247 | {
248 | "cell_type": "code",
249 | "execution_count": 8,
250 | "metadata": {},
251 | "outputs": [],
252 | "source": [
253 | "num_epochs = 5\n",
254 | "learning_rate = 1e-3\n",
255 | "mom = 0.99\n",
256 | "batch_size = 64"
257 | ]
258 | },
259 | {
260 | "cell_type": "code",
261 | "execution_count": 9,
262 | "metadata": {},
263 | "outputs": [
264 | {
265 | "name": "stdout",
266 | "output_type": "stream",
267 | "text": [
268 | "Train Epoch: 1 [0/60000 (0%)]\tLoss: 3.633006\tAcc: 10.94%\n",
269 | "Train Epoch: 1 [1600/60000 (3%)]\tLoss: 1.619076\tAcc: 48.44%\n",
270 | "Train Epoch: 1 [3200/60000 (5%)]\tLoss: 0.823937\tAcc: 71.88%\n",
271 | "Train Epoch: 1 [4800/60000 (8%)]\tLoss: 0.598104\tAcc: 79.69%\n",
272 | "Train Epoch: 1 [6400/60000 (11%)]\tLoss: 0.370156\tAcc: 92.19%\n",
273 | "Train Epoch: 1 [8000/60000 (13%)]\tLoss: 0.533670\tAcc: 81.25%\n",
274 | "Train Epoch: 1 [9600/60000 (16%)]\tLoss: 0.291023\tAcc: 89.06%\n",
275 | "Train Epoch: 1 [11200/60000 (19%)]\tLoss: 0.611064\tAcc: 82.81%\n",
276 | "Train Epoch: 1 [12800/60000 (21%)]\tLoss: 0.301691\tAcc: 85.94%\n",
277 | "Train Epoch: 1 [14400/60000 (24%)]\tLoss: 0.171518\tAcc: 96.88%\n",
278 | "Train Epoch: 1 [16000/60000 (27%)]\tLoss: 0.311606\tAcc: 92.19%\n",
279 | "Train Epoch: 1 [17600/60000 (29%)]\tLoss: 0.311102\tAcc: 90.62%\n",
280 | "Train Epoch: 1 [19200/60000 (32%)]\tLoss: 0.225483\tAcc: 93.75%\n",
281 | "Train Epoch: 1 [20800/60000 (35%)]\tLoss: 0.140306\tAcc: 93.75%\n",
282 | "Train Epoch: 1 [22400/60000 (37%)]\tLoss: 0.154456\tAcc: 95.31%\n",
283 | "Train Epoch: 1 [24000/60000 (40%)]\tLoss: 0.203898\tAcc: 93.75%\n",
284 | "Train Epoch: 1 [25600/60000 (43%)]\tLoss: 0.032903\tAcc: 100.00%\n",
285 | "Train Epoch: 1 
[27200/60000 (45%)]\tLoss: 0.218158\tAcc: 93.75%\n", 286 | "Train Epoch: 1 [28800/60000 (48%)]\tLoss: 0.071584\tAcc: 98.44%\n", 287 | "Train Epoch: 1 [30400/60000 (51%)]\tLoss: 0.053365\tAcc: 100.00%\n", 288 | "Train Epoch: 1 [32000/60000 (53%)]\tLoss: 0.092851\tAcc: 96.88%\n", 289 | "Train Epoch: 1 [33600/60000 (56%)]\tLoss: 0.023382\tAcc: 100.00%\n", 290 | "Train Epoch: 1 [35200/60000 (59%)]\tLoss: 0.239355\tAcc: 93.75%\n", 291 | "Train Epoch: 1 [36800/60000 (61%)]\tLoss: 0.159463\tAcc: 93.75%\n", 292 | "Train Epoch: 1 [38400/60000 (64%)]\tLoss: 0.054997\tAcc: 98.44%\n", 293 | "Train Epoch: 1 [40000/60000 (67%)]\tLoss: 0.179467\tAcc: 93.75%\n", 294 | "Train Epoch: 1 [41600/60000 (69%)]\tLoss: 0.061505\tAcc: 96.88%\n", 295 | "Train Epoch: 1 [43200/60000 (72%)]\tLoss: 0.169429\tAcc: 95.31%\n", 296 | "Train Epoch: 1 [44800/60000 (75%)]\tLoss: 0.159596\tAcc: 95.31%\n", 297 | "Train Epoch: 1 [46400/60000 (77%)]\tLoss: 0.468967\tAcc: 87.50%\n", 298 | "Train Epoch: 1 [48000/60000 (80%)]\tLoss: 0.086294\tAcc: 95.31%\n", 299 | "Train Epoch: 1 [49600/60000 (83%)]\tLoss: 0.220987\tAcc: 93.75%\n", 300 | "Train Epoch: 1 [51200/60000 (85%)]\tLoss: 0.208579\tAcc: 93.75%\n", 301 | "Train Epoch: 1 [52800/60000 (88%)]\tLoss: 0.264812\tAcc: 93.75%\n", 302 | "Train Epoch: 1 [54400/60000 (91%)]\tLoss: 0.046534\tAcc: 96.88%\n", 303 | "Train Epoch: 1 [56000/60000 (93%)]\tLoss: 0.117640\tAcc: 98.44%\n", 304 | "Train Epoch: 1 [57600/60000 (96%)]\tLoss: 0.196920\tAcc: 93.75%\n", 305 | "Train Epoch: 1 [59200/60000 (99%)]\tLoss: 0.011510\tAcc: 100.00%\n", 306 | "Train Epoch: 2 [0/60000 (0%)]\tLoss: 0.074488\tAcc: 98.44%\n", 307 | "Train Epoch: 2 [1600/60000 (3%)]\tLoss: 0.132464\tAcc: 96.88%\n", 308 | "Train Epoch: 2 [3200/60000 (5%)]\tLoss: 0.066270\tAcc: 98.44%\n", 309 | "Train Epoch: 2 [4800/60000 (8%)]\tLoss: 0.071151\tAcc: 98.44%\n", 310 | "Train Epoch: 2 [6400/60000 (11%)]\tLoss: 0.210089\tAcc: 96.88%\n", 311 | "Train Epoch: 2 [8000/60000 (13%)]\tLoss: 0.061824\tAcc: 98.44%\n", 312 | "Train Epoch: 2 [9600/60000 (16%)]\tLoss: 0.054427\tAcc: 96.88%\n", 313 | "Train Epoch: 2 [11200/60000 (19%)]\tLoss: 0.140843\tAcc: 93.75%\n", 314 | "Train Epoch: 2 [12800/60000 (21%)]\tLoss: 0.093792\tAcc: 95.31%\n", 315 | "Train Epoch: 2 [14400/60000 (24%)]\tLoss: 0.045398\tAcc: 98.44%\n", 316 | "Train Epoch: 2 [16000/60000 (27%)]\tLoss: 0.158572\tAcc: 95.31%\n", 317 | "Train Epoch: 2 [17600/60000 (29%)]\tLoss: 0.026146\tAcc: 100.00%\n", 318 | "Train Epoch: 2 [19200/60000 (32%)]\tLoss: 0.226149\tAcc: 96.88%\n", 319 | "Train Epoch: 2 [20800/60000 (35%)]\tLoss: 0.018054\tAcc: 100.00%\n", 320 | "Train Epoch: 2 [22400/60000 (37%)]\tLoss: 0.030244\tAcc: 100.00%\n", 321 | "Train Epoch: 2 [24000/60000 (40%)]\tLoss: 0.028830\tAcc: 98.44%\n", 322 | "Train Epoch: 2 [25600/60000 (43%)]\tLoss: 0.006937\tAcc: 100.00%\n", 323 | "Train Epoch: 2 [27200/60000 (45%)]\tLoss: 0.154943\tAcc: 95.31%\n", 324 | "Train Epoch: 2 [28800/60000 (48%)]\tLoss: 0.044339\tAcc: 96.88%\n", 325 | "Train Epoch: 2 [30400/60000 (51%)]\tLoss: 0.041290\tAcc: 98.44%\n", 326 | "Train Epoch: 2 [32000/60000 (53%)]\tLoss: 0.060451\tAcc: 98.44%\n", 327 | "Train Epoch: 2 [33600/60000 (56%)]\tLoss: 0.015808\tAcc: 100.00%\n", 328 | "Train Epoch: 2 [35200/60000 (59%)]\tLoss: 0.101732\tAcc: 96.88%\n", 329 | "Train Epoch: 2 [36800/60000 (61%)]\tLoss: 0.061473\tAcc: 98.44%\n", 330 | "Train Epoch: 2 [38400/60000 (64%)]\tLoss: 0.041257\tAcc: 98.44%\n", 331 | "Train Epoch: 2 [40000/60000 (67%)]\tLoss: 0.042018\tAcc: 98.44%\n", 332 | "Train Epoch: 2 [41600/60000 (69%)]\tLoss: 
0.039075\tAcc: 98.44%\n", 333 | "Train Epoch: 2 [43200/60000 (72%)]\tLoss: 0.040230\tAcc: 96.88%\n", 334 | "Train Epoch: 2 [44800/60000 (75%)]\tLoss: 0.066020\tAcc: 96.88%\n", 335 | "Train Epoch: 2 [46400/60000 (77%)]\tLoss: 0.359166\tAcc: 87.50%\n", 336 | "Train Epoch: 2 [48000/60000 (80%)]\tLoss: 0.056128\tAcc: 98.44%\n", 337 | "Train Epoch: 2 [49600/60000 (83%)]\tLoss: 0.159148\tAcc: 95.31%\n", 338 | "Train Epoch: 2 [51200/60000 (85%)]\tLoss: 0.183856\tAcc: 95.31%\n", 339 | "Train Epoch: 2 [52800/60000 (88%)]\tLoss: 0.201649\tAcc: 95.31%\n", 340 | "Train Epoch: 2 [54400/60000 (91%)]\tLoss: 0.019636\tAcc: 100.00%\n", 341 | "Train Epoch: 2 [56000/60000 (93%)]\tLoss: 0.046323\tAcc: 98.44%\n", 342 | "Train Epoch: 2 [57600/60000 (96%)]\tLoss: 0.109396\tAcc: 96.88%\n", 343 | "Train Epoch: 2 [59200/60000 (99%)]\tLoss: 0.015261\tAcc: 100.00%\n", 344 | "Train Epoch: 3 [0/60000 (0%)]\tLoss: 0.024345\tAcc: 100.00%\n", 345 | "Train Epoch: 3 [1600/60000 (3%)]\tLoss: 0.103422\tAcc: 98.44%\n", 346 | "Train Epoch: 3 [3200/60000 (5%)]\tLoss: 0.014144\tAcc: 100.00%\n", 347 | "Train Epoch: 3 [4800/60000 (8%)]\tLoss: 0.037579\tAcc: 100.00%\n", 348 | "Train Epoch: 3 [6400/60000 (11%)]\tLoss: 0.198094\tAcc: 96.88%\n", 349 | "Train Epoch: 3 [8000/60000 (13%)]\tLoss: 0.014647\tAcc: 100.00%\n", 350 | "Train Epoch: 3 [9600/60000 (16%)]\tLoss: 0.030492\tAcc: 100.00%\n", 351 | "Train Epoch: 3 [11200/60000 (19%)]\tLoss: 0.067932\tAcc: 96.88%\n", 352 | "Train Epoch: 3 [12800/60000 (21%)]\tLoss: 0.133189\tAcc: 96.88%\n", 353 | "Train Epoch: 3 [14400/60000 (24%)]\tLoss: 0.037543\tAcc: 98.44%\n", 354 | "Train Epoch: 3 [16000/60000 (27%)]\tLoss: 0.070643\tAcc: 98.44%\n", 355 | "Train Epoch: 3 [17600/60000 (29%)]\tLoss: 0.008885\tAcc: 100.00%\n", 356 | "Train Epoch: 3 [19200/60000 (32%)]\tLoss: 0.181374\tAcc: 95.31%\n", 357 | "Train Epoch: 3 [20800/60000 (35%)]\tLoss: 0.008857\tAcc: 100.00%\n", 358 | "Train Epoch: 3 [22400/60000 (37%)]\tLoss: 0.090927\tAcc: 96.88%\n", 359 | "Train Epoch: 3 [24000/60000 (40%)]\tLoss: 0.029201\tAcc: 98.44%\n", 360 | "Train Epoch: 3 [25600/60000 (43%)]\tLoss: 0.022435\tAcc: 98.44%\n", 361 | "Train Epoch: 3 [27200/60000 (45%)]\tLoss: 0.116180\tAcc: 96.88%\n", 362 | "Train Epoch: 3 [28800/60000 (48%)]\tLoss: 0.016861\tAcc: 100.00%\n", 363 | "Train Epoch: 3 [30400/60000 (51%)]\tLoss: 0.034622\tAcc: 96.88%\n", 364 | "Train Epoch: 3 [32000/60000 (53%)]\tLoss: 0.081620\tAcc: 95.31%\n", 365 | "Train Epoch: 3 [33600/60000 (56%)]\tLoss: 0.017481\tAcc: 98.44%\n", 366 | "Train Epoch: 3 [35200/60000 (59%)]\tLoss: 0.142122\tAcc: 96.88%\n", 367 | "Train Epoch: 3 [36800/60000 (61%)]\tLoss: 0.041635\tAcc: 98.44%\n", 368 | "Train Epoch: 3 [38400/60000 (64%)]\tLoss: 0.041654\tAcc: 98.44%\n", 369 | "Train Epoch: 3 [40000/60000 (67%)]\tLoss: 0.027219\tAcc: 98.44%\n", 370 | "Train Epoch: 3 [41600/60000 (69%)]\tLoss: 0.035152\tAcc: 98.44%\n", 371 | "Train Epoch: 3 [43200/60000 (72%)]\tLoss: 0.032792\tAcc: 96.88%\n", 372 | "Train Epoch: 3 [44800/60000 (75%)]\tLoss: 0.037414\tAcc: 98.44%\n", 373 | "Train Epoch: 3 [46400/60000 (77%)]\tLoss: 0.252676\tAcc: 90.62%\n", 374 | "Train Epoch: 3 [48000/60000 (80%)]\tLoss: 0.022737\tAcc: 100.00%\n", 375 | "Train Epoch: 3 [49600/60000 (83%)]\tLoss: 0.060854\tAcc: 96.88%\n", 376 | "Train Epoch: 3 [51200/60000 (85%)]\tLoss: 0.182138\tAcc: 96.88%\n", 377 | "Train Epoch: 3 [52800/60000 (88%)]\tLoss: 0.145763\tAcc: 96.88%\n", 378 | "Train Epoch: 3 [54400/60000 (91%)]\tLoss: 0.023650\tAcc: 98.44%\n", 379 | "Train Epoch: 3 [56000/60000 (93%)]\tLoss: 0.021866\tAcc: 98.44%\n", 380 | 
"Train Epoch: 3 [57600/60000 (96%)]\tLoss: 0.099294\tAcc: 96.88%\n", 381 | "Train Epoch: 3 [59200/60000 (99%)]\tLoss: 0.007801\tAcc: 100.00%\n", 382 | "Train Epoch: 4 [0/60000 (0%)]\tLoss: 0.008445\tAcc: 100.00%\n", 383 | "Train Epoch: 4 [1600/60000 (3%)]\tLoss: 0.108969\tAcc: 98.44%\n", 384 | "Train Epoch: 4 [3200/60000 (5%)]\tLoss: 0.012685\tAcc: 100.00%\n", 385 | "Train Epoch: 4 [4800/60000 (8%)]\tLoss: 0.051898\tAcc: 98.44%\n", 386 | "Train Epoch: 4 [6400/60000 (11%)]\tLoss: 0.218408\tAcc: 96.88%\n", 387 | "Train Epoch: 4 [8000/60000 (13%)]\tLoss: 0.004244\tAcc: 100.00%\n", 388 | "Train Epoch: 4 [9600/60000 (16%)]\tLoss: 0.048583\tAcc: 96.88%\n", 389 | "Train Epoch: 4 [11200/60000 (19%)]\tLoss: 0.053676\tAcc: 96.88%\n", 390 | "Train Epoch: 4 [12800/60000 (21%)]\tLoss: 0.119289\tAcc: 96.88%\n", 391 | "Train Epoch: 4 [14400/60000 (24%)]\tLoss: 0.014504\tAcc: 100.00%\n", 392 | "Train Epoch: 4 [16000/60000 (27%)]\tLoss: 0.034672\tAcc: 100.00%\n", 393 | "Train Epoch: 4 [17600/60000 (29%)]\tLoss: 0.015773\tAcc: 100.00%\n", 394 | "Train Epoch: 4 [19200/60000 (32%)]\tLoss: 0.146741\tAcc: 96.88%\n", 395 | "Train Epoch: 4 [20800/60000 (35%)]\tLoss: 0.004331\tAcc: 100.00%\n", 396 | "Train Epoch: 4 [22400/60000 (37%)]\tLoss: 0.058988\tAcc: 96.88%\n", 397 | "Train Epoch: 4 [24000/60000 (40%)]\tLoss: 0.023140\tAcc: 98.44%\n", 398 | "Train Epoch: 4 [25600/60000 (43%)]\tLoss: 0.005246\tAcc: 100.00%\n", 399 | "Train Epoch: 4 [27200/60000 (45%)]\tLoss: 0.053999\tAcc: 96.88%\n", 400 | "Train Epoch: 4 [28800/60000 (48%)]\tLoss: 0.005452\tAcc: 100.00%\n" 401 | ] 402 | }, 403 | { 404 | "name": "stdout", 405 | "output_type": "stream", 406 | "text": [ 407 | "Train Epoch: 4 [30400/60000 (51%)]\tLoss: 0.011213\tAcc: 100.00%\n", 408 | "Train Epoch: 4 [32000/60000 (53%)]\tLoss: 0.038986\tAcc: 96.88%\n", 409 | "Train Epoch: 4 [33600/60000 (56%)]\tLoss: 0.011563\tAcc: 100.00%\n", 410 | "Train Epoch: 4 [35200/60000 (59%)]\tLoss: 0.129776\tAcc: 96.88%\n", 411 | "Train Epoch: 4 [36800/60000 (61%)]\tLoss: 0.034880\tAcc: 98.44%\n", 412 | "Train Epoch: 4 [38400/60000 (64%)]\tLoss: 0.020573\tAcc: 100.00%\n", 413 | "Train Epoch: 4 [40000/60000 (67%)]\tLoss: 0.011226\tAcc: 100.00%\n", 414 | "Train Epoch: 4 [41600/60000 (69%)]\tLoss: 0.037513\tAcc: 98.44%\n", 415 | "Train Epoch: 4 [43200/60000 (72%)]\tLoss: 0.058901\tAcc: 98.44%\n", 416 | "Train Epoch: 4 [44800/60000 (75%)]\tLoss: 0.043138\tAcc: 98.44%\n", 417 | "Train Epoch: 4 [46400/60000 (77%)]\tLoss: 0.187012\tAcc: 93.75%\n", 418 | "Train Epoch: 4 [48000/60000 (80%)]\tLoss: 0.007739\tAcc: 100.00%\n", 419 | "Train Epoch: 4 [49600/60000 (83%)]\tLoss: 0.029716\tAcc: 100.00%\n", 420 | "Train Epoch: 4 [51200/60000 (85%)]\tLoss: 0.147594\tAcc: 98.44%\n", 421 | "Train Epoch: 4 [52800/60000 (88%)]\tLoss: 0.112262\tAcc: 96.88%\n", 422 | "Train Epoch: 4 [54400/60000 (91%)]\tLoss: 0.010896\tAcc: 100.00%\n", 423 | "Train Epoch: 4 [56000/60000 (93%)]\tLoss: 0.019521\tAcc: 98.44%\n", 424 | "Train Epoch: 4 [57600/60000 (96%)]\tLoss: 0.133354\tAcc: 95.31%\n", 425 | "Train Epoch: 4 [59200/60000 (99%)]\tLoss: 0.002342\tAcc: 100.00%\n", 426 | "Train Epoch: 5 [0/60000 (0%)]\tLoss: 0.021062\tAcc: 98.44%\n", 427 | "Train Epoch: 5 [1600/60000 (3%)]\tLoss: 0.079082\tAcc: 98.44%\n", 428 | "Train Epoch: 5 [3200/60000 (5%)]\tLoss: 0.022650\tAcc: 98.44%\n", 429 | "Train Epoch: 5 [4800/60000 (8%)]\tLoss: 0.056475\tAcc: 98.44%\n", 430 | "Train Epoch: 5 [6400/60000 (11%)]\tLoss: 0.196984\tAcc: 96.88%\n", 431 | "Train Epoch: 5 [8000/60000 (13%)]\tLoss: 0.002361\tAcc: 100.00%\n", 432 | "Train Epoch: 5 
[9600/60000 (16%)]\tLoss: 0.031388\tAcc: 100.00%\n", 433 | "Train Epoch: 5 [11200/60000 (19%)]\tLoss: 0.030519\tAcc: 98.44%\n", 434 | "Train Epoch: 5 [12800/60000 (21%)]\tLoss: 0.093129\tAcc: 96.88%\n", 435 | "Train Epoch: 5 [14400/60000 (24%)]\tLoss: 0.021665\tAcc: 100.00%\n", 436 | "Train Epoch: 5 [16000/60000 (27%)]\tLoss: 0.043810\tAcc: 96.88%\n", 437 | "Train Epoch: 5 [17600/60000 (29%)]\tLoss: 0.007854\tAcc: 100.00%\n", 438 | "Train Epoch: 5 [19200/60000 (32%)]\tLoss: 0.133513\tAcc: 96.88%\n", 439 | "Train Epoch: 5 [20800/60000 (35%)]\tLoss: 0.003202\tAcc: 100.00%\n", 440 | "Train Epoch: 5 [22400/60000 (37%)]\tLoss: 0.033618\tAcc: 98.44%\n", 441 | "Train Epoch: 5 [24000/60000 (40%)]\tLoss: 0.013201\tAcc: 100.00%\n", 442 | "Train Epoch: 5 [25600/60000 (43%)]\tLoss: 0.002040\tAcc: 100.00%\n", 443 | "Train Epoch: 5 [27200/60000 (45%)]\tLoss: 0.046378\tAcc: 98.44%\n", 444 | "Train Epoch: 5 [28800/60000 (48%)]\tLoss: 0.004634\tAcc: 100.00%\n", 445 | "Train Epoch: 5 [30400/60000 (51%)]\tLoss: 0.009121\tAcc: 100.00%\n", 446 | "Train Epoch: 5 [32000/60000 (53%)]\tLoss: 0.047612\tAcc: 98.44%\n", 447 | "Train Epoch: 5 [33600/60000 (56%)]\tLoss: 0.006736\tAcc: 100.00%\n", 448 | "Train Epoch: 5 [35200/60000 (59%)]\tLoss: 0.138957\tAcc: 96.88%\n", 449 | "Train Epoch: 5 [36800/60000 (61%)]\tLoss: 0.037154\tAcc: 96.88%\n", 450 | "Train Epoch: 5 [38400/60000 (64%)]\tLoss: 0.025383\tAcc: 98.44%\n", 451 | "Train Epoch: 5 [40000/60000 (67%)]\tLoss: 0.017236\tAcc: 100.00%\n", 452 | "Train Epoch: 5 [41600/60000 (69%)]\tLoss: 0.039191\tAcc: 98.44%\n", 453 | "Train Epoch: 5 [43200/60000 (72%)]\tLoss: 0.048996\tAcc: 98.44%\n", 454 | "Train Epoch: 5 [44800/60000 (75%)]\tLoss: 0.043573\tAcc: 98.44%\n", 455 | "Train Epoch: 5 [46400/60000 (77%)]\tLoss: 0.159468\tAcc: 95.31%\n", 456 | "Train Epoch: 5 [48000/60000 (80%)]\tLoss: 0.007549\tAcc: 100.00%\n", 457 | "Train Epoch: 5 [49600/60000 (83%)]\tLoss: 0.023951\tAcc: 100.00%\n", 458 | "Train Epoch: 5 [51200/60000 (85%)]\tLoss: 0.146604\tAcc: 98.44%\n", 459 | "Train Epoch: 5 [52800/60000 (88%)]\tLoss: 0.139337\tAcc: 95.31%\n", 460 | "Train Epoch: 5 [54400/60000 (91%)]\tLoss: 0.002421\tAcc: 100.00%\n", 461 | "Train Epoch: 5 [56000/60000 (93%)]\tLoss: 0.013481\tAcc: 100.00%\n", 462 | "Train Epoch: 5 [57600/60000 (96%)]\tLoss: 0.161190\tAcc: 95.31%\n", 463 | "Train Epoch: 5 [59200/60000 (99%)]\tLoss: 0.002633\tAcc: 100.00%\n" 464 | ] 465 | } 466 | ], 467 | "source": [ 468 | "torch.manual_seed(SEED)\n", 469 | "\n", 470 | "# instantiate convnet\n", 471 | "model = SmallConv().to(device)\n", 472 | "\n", 473 | "# relu init\n", 474 | "for m in model.modules():\n", 475 | " if isinstance(m, (nn.Conv2d, nn.Linear)):\n", 476 | " nn.init.kaiming_normal_(m.weight, mode='fan_in')\n", 477 | "\n", 478 | "# define optimizer\n", 479 | "optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=mom)\n", 480 | "\n", 481 | "# instantiate data loader\n", 482 | "train_loader = get_data_loader(data_dir, batch_size, None, **kwargs)\n", 483 | "\n", 484 | "stats_no_shuffling = []\n", 485 | "for epoch in range(1, num_epochs+1):\n", 486 | " stats_no_shuffling.append(train(model, device, train_loader, optimizer, epoch))\n", 487 | "pickle.dump(stats_no_shuffling, open(dump_dir + \"no_shuffling.pkl\", \"wb\"))" 488 | ] 489 | }, 490 | { 491 | "cell_type": "markdown", 492 | "metadata": {}, 493 | "source": [ 494 | "## With Shuffling" 495 | ] 496 | }, 497 | { 498 | "cell_type": "code", 499 | "execution_count": 10, 500 | "metadata": {}, 501 | "outputs": [], 502 | "source": [ 503 | "num_epochs 
= 5\n", 504 | "learning_rate = 1e-3\n", 505 | "mom = 0.99\n", 506 | "batch_size = 64" 507 | ] 508 | }, 509 | { 510 | "cell_type": "code", 511 | "execution_count": 11, 512 | "metadata": {}, 513 | "outputs": [], 514 | "source": [ 515 | "# create permutations\n", 516 | "permutations = []\n", 517 | "permutations.append(list(np.arange(60000)))\n", 518 | "\n", 519 | "x = list(np.arange(60000))\n", 520 | "np.random.seed(SEED)\n", 521 | "\n", 522 | "for _ in range(num_epochs-1):\n", 523 | " np.random.shuffle(x)\n", 524 | " permutations.append(x.copy())\n", 525 | "pickle.dump(permutations, open(dump_dir + \"permutations.pkl\", \"wb\"))" 526 | ] 527 | }, 528 | { 529 | "cell_type": "code", 530 | "execution_count": 12, 531 | "metadata": {}, 532 | "outputs": [ 533 | { 534 | "name": "stdout", 535 | "output_type": "stream", 536 | "text": [ 537 | "Train Epoch: 1 [0/60000 (0%)]\tLoss: 3.633006\tAcc: 10.94%\n", 538 | "Train Epoch: 1 [1600/60000 (3%)]\tLoss: 1.619076\tAcc: 48.44%\n", 539 | "Train Epoch: 1 [3200/60000 (5%)]\tLoss: 0.823937\tAcc: 71.88%\n", 540 | "Train Epoch: 1 [4800/60000 (8%)]\tLoss: 0.598104\tAcc: 79.69%\n", 541 | "Train Epoch: 1 [6400/60000 (11%)]\tLoss: 0.370156\tAcc: 92.19%\n", 542 | "Train Epoch: 1 [8000/60000 (13%)]\tLoss: 0.533670\tAcc: 81.25%\n", 543 | "Train Epoch: 1 [9600/60000 (16%)]\tLoss: 0.291023\tAcc: 89.06%\n", 544 | "Train Epoch: 1 [11200/60000 (19%)]\tLoss: 0.611064\tAcc: 82.81%\n", 545 | "Train Epoch: 1 [12800/60000 (21%)]\tLoss: 0.301691\tAcc: 85.94%\n", 546 | "Train Epoch: 1 [14400/60000 (24%)]\tLoss: 0.171518\tAcc: 96.88%\n", 547 | "Train Epoch: 1 [16000/60000 (27%)]\tLoss: 0.311606\tAcc: 92.19%\n", 548 | "Train Epoch: 1 [17600/60000 (29%)]\tLoss: 0.311102\tAcc: 90.62%\n", 549 | "Train Epoch: 1 [19200/60000 (32%)]\tLoss: 0.225483\tAcc: 93.75%\n", 550 | "Train Epoch: 1 [20800/60000 (35%)]\tLoss: 0.140306\tAcc: 93.75%\n", 551 | "Train Epoch: 1 [22400/60000 (37%)]\tLoss: 0.154456\tAcc: 95.31%\n", 552 | "Train Epoch: 1 [24000/60000 (40%)]\tLoss: 0.203898\tAcc: 93.75%\n", 553 | "Train Epoch: 1 [25600/60000 (43%)]\tLoss: 0.032903\tAcc: 100.00%\n", 554 | "Train Epoch: 1 [27200/60000 (45%)]\tLoss: 0.218158\tAcc: 93.75%\n", 555 | "Train Epoch: 1 [28800/60000 (48%)]\tLoss: 0.071584\tAcc: 98.44%\n", 556 | "Train Epoch: 1 [30400/60000 (51%)]\tLoss: 0.053365\tAcc: 100.00%\n", 557 | "Train Epoch: 1 [32000/60000 (53%)]\tLoss: 0.092851\tAcc: 96.88%\n", 558 | "Train Epoch: 1 [33600/60000 (56%)]\tLoss: 0.023382\tAcc: 100.00%\n", 559 | "Train Epoch: 1 [35200/60000 (59%)]\tLoss: 0.239355\tAcc: 93.75%\n", 560 | "Train Epoch: 1 [36800/60000 (61%)]\tLoss: 0.159463\tAcc: 93.75%\n", 561 | "Train Epoch: 1 [38400/60000 (64%)]\tLoss: 0.054997\tAcc: 98.44%\n", 562 | "Train Epoch: 1 [40000/60000 (67%)]\tLoss: 0.179467\tAcc: 93.75%\n", 563 | "Train Epoch: 1 [41600/60000 (69%)]\tLoss: 0.061505\tAcc: 96.88%\n", 564 | "Train Epoch: 1 [43200/60000 (72%)]\tLoss: 0.169429\tAcc: 95.31%\n", 565 | "Train Epoch: 1 [44800/60000 (75%)]\tLoss: 0.159596\tAcc: 95.31%\n", 566 | "Train Epoch: 1 [46400/60000 (77%)]\tLoss: 0.468967\tAcc: 87.50%\n", 567 | "Train Epoch: 1 [48000/60000 (80%)]\tLoss: 0.086294\tAcc: 95.31%\n", 568 | "Train Epoch: 1 [49600/60000 (83%)]\tLoss: 0.220987\tAcc: 93.75%\n", 569 | "Train Epoch: 1 [51200/60000 (85%)]\tLoss: 0.208579\tAcc: 93.75%\n", 570 | "Train Epoch: 1 [52800/60000 (88%)]\tLoss: 0.264812\tAcc: 93.75%\n", 571 | "Train Epoch: 1 [54400/60000 (91%)]\tLoss: 0.046534\tAcc: 96.88%\n", 572 | "Train Epoch: 1 [56000/60000 (93%)]\tLoss: 0.117640\tAcc: 98.44%\n", 573 | "Train Epoch: 1 
[57600/60000 (96%)]\tLoss: 0.196920\tAcc: 93.75%\n", 574 | "Train Epoch: 1 [59200/60000 (99%)]\tLoss: 0.011510\tAcc: 100.00%\n", 575 | "Train Epoch: 2 [0/60000 (0%)]\tLoss: 0.102078\tAcc: 98.44%\n", 576 | "Train Epoch: 2 [1600/60000 (3%)]\tLoss: 0.031098\tAcc: 98.44%\n", 577 | "Train Epoch: 2 [3200/60000 (5%)]\tLoss: 0.232979\tAcc: 93.75%\n", 578 | "Train Epoch: 2 [4800/60000 (8%)]\tLoss: 0.016755\tAcc: 100.00%\n", 579 | "Train Epoch: 2 [6400/60000 (11%)]\tLoss: 0.025166\tAcc: 100.00%\n", 580 | "Train Epoch: 2 [8000/60000 (13%)]\tLoss: 0.023412\tAcc: 100.00%\n", 581 | "Train Epoch: 2 [9600/60000 (16%)]\tLoss: 0.036617\tAcc: 98.44%\n", 582 | "Train Epoch: 2 [11200/60000 (19%)]\tLoss: 0.068466\tAcc: 98.44%\n", 583 | "Train Epoch: 2 [12800/60000 (21%)]\tLoss: 0.098916\tAcc: 95.31%\n", 584 | "Train Epoch: 2 [14400/60000 (24%)]\tLoss: 0.026504\tAcc: 100.00%\n", 585 | "Train Epoch: 2 [16000/60000 (27%)]\tLoss: 0.090279\tAcc: 96.88%\n", 586 | "Train Epoch: 2 [17600/60000 (29%)]\tLoss: 0.062686\tAcc: 95.31%\n", 587 | "Train Epoch: 2 [19200/60000 (32%)]\tLoss: 0.080275\tAcc: 95.31%\n", 588 | "Train Epoch: 2 [20800/60000 (35%)]\tLoss: 0.059304\tAcc: 98.44%\n", 589 | "Train Epoch: 2 [22400/60000 (37%)]\tLoss: 0.013309\tAcc: 100.00%\n", 590 | "Train Epoch: 2 [24000/60000 (40%)]\tLoss: 0.011628\tAcc: 100.00%\n", 591 | "Train Epoch: 2 [25600/60000 (43%)]\tLoss: 0.080991\tAcc: 96.88%\n", 592 | "Train Epoch: 2 [27200/60000 (45%)]\tLoss: 0.032248\tAcc: 98.44%\n", 593 | "Train Epoch: 2 [28800/60000 (48%)]\tLoss: 0.111027\tAcc: 95.31%\n", 594 | "Train Epoch: 2 [30400/60000 (51%)]\tLoss: 0.114232\tAcc: 95.31%\n", 595 | "Train Epoch: 2 [32000/60000 (53%)]\tLoss: 0.030984\tAcc: 98.44%\n", 596 | "Train Epoch: 2 [33600/60000 (56%)]\tLoss: 0.078294\tAcc: 95.31%\n", 597 | "Train Epoch: 2 [35200/60000 (59%)]\tLoss: 0.243407\tAcc: 96.88%\n", 598 | "Train Epoch: 2 [36800/60000 (61%)]\tLoss: 0.043078\tAcc: 98.44%\n", 599 | "Train Epoch: 2 [38400/60000 (64%)]\tLoss: 0.090051\tAcc: 98.44%\n", 600 | "Train Epoch: 2 [40000/60000 (67%)]\tLoss: 0.170737\tAcc: 96.88%\n", 601 | "Train Epoch: 2 [41600/60000 (69%)]\tLoss: 0.125144\tAcc: 95.31%\n", 602 | "Train Epoch: 2 [43200/60000 (72%)]\tLoss: 0.127538\tAcc: 96.88%\n", 603 | "Train Epoch: 2 [44800/60000 (75%)]\tLoss: 0.085707\tAcc: 96.88%\n", 604 | "Train Epoch: 2 [46400/60000 (77%)]\tLoss: 0.049352\tAcc: 96.88%\n", 605 | "Train Epoch: 2 [48000/60000 (80%)]\tLoss: 0.120053\tAcc: 95.31%\n", 606 | "Train Epoch: 2 [49600/60000 (83%)]\tLoss: 0.089620\tAcc: 96.88%\n", 607 | "Train Epoch: 2 [51200/60000 (85%)]\tLoss: 0.051293\tAcc: 98.44%\n", 608 | "Train Epoch: 2 [52800/60000 (88%)]\tLoss: 0.039015\tAcc: 98.44%\n", 609 | "Train Epoch: 2 [54400/60000 (91%)]\tLoss: 0.004403\tAcc: 100.00%\n", 610 | "Train Epoch: 2 [56000/60000 (93%)]\tLoss: 0.171666\tAcc: 95.31%\n", 611 | "Train Epoch: 2 [57600/60000 (96%)]\tLoss: 0.025525\tAcc: 98.44%\n", 612 | "Train Epoch: 2 [59200/60000 (99%)]\tLoss: 0.010009\tAcc: 100.00%\n", 613 | "Train Epoch: 3 [0/60000 (0%)]\tLoss: 0.019292\tAcc: 100.00%\n", 614 | "Train Epoch: 3 [1600/60000 (3%)]\tLoss: 0.018023\tAcc: 100.00%\n", 615 | "Train Epoch: 3 [3200/60000 (5%)]\tLoss: 0.083953\tAcc: 95.31%\n", 616 | "Train Epoch: 3 [4800/60000 (8%)]\tLoss: 0.016674\tAcc: 100.00%\n", 617 | "Train Epoch: 3 [6400/60000 (11%)]\tLoss: 0.052217\tAcc: 98.44%\n", 618 | "Train Epoch: 3 [8000/60000 (13%)]\tLoss: 0.022347\tAcc: 100.00%\n", 619 | "Train Epoch: 3 [9600/60000 (16%)]\tLoss: 0.084956\tAcc: 98.44%\n", 620 | "Train Epoch: 3 [11200/60000 (19%)]\tLoss: 0.041385\tAcc: 
98.44%\n", 621 | "Train Epoch: 3 [12800/60000 (21%)]\tLoss: 0.201479\tAcc: 92.19%\n", 622 | "Train Epoch: 3 [14400/60000 (24%)]\tLoss: 0.006730\tAcc: 100.00%\n", 623 | "Train Epoch: 3 [16000/60000 (27%)]\tLoss: 0.038304\tAcc: 98.44%\n", 624 | "Train Epoch: 3 [17600/60000 (29%)]\tLoss: 0.042132\tAcc: 98.44%\n", 625 | "Train Epoch: 3 [19200/60000 (32%)]\tLoss: 0.137135\tAcc: 96.88%\n", 626 | "Train Epoch: 3 [20800/60000 (35%)]\tLoss: 0.140588\tAcc: 96.88%\n", 627 | "Train Epoch: 3 [22400/60000 (37%)]\tLoss: 0.016618\tAcc: 100.00%\n", 628 | "Train Epoch: 3 [24000/60000 (40%)]\tLoss: 0.014411\tAcc: 100.00%\n", 629 | "Train Epoch: 3 [25600/60000 (43%)]\tLoss: 0.133534\tAcc: 95.31%\n", 630 | "Train Epoch: 3 [27200/60000 (45%)]\tLoss: 0.016728\tAcc: 100.00%\n", 631 | "Train Epoch: 3 [28800/60000 (48%)]\tLoss: 0.074302\tAcc: 96.88%\n", 632 | "Train Epoch: 3 [30400/60000 (51%)]\tLoss: 0.081907\tAcc: 96.88%\n", 633 | "Train Epoch: 3 [32000/60000 (53%)]\tLoss: 0.204502\tAcc: 93.75%\n", 634 | "Train Epoch: 3 [33600/60000 (56%)]\tLoss: 0.044639\tAcc: 98.44%\n", 635 | "Train Epoch: 3 [35200/60000 (59%)]\tLoss: 0.005864\tAcc: 100.00%\n", 636 | "Train Epoch: 3 [36800/60000 (61%)]\tLoss: 0.065540\tAcc: 96.88%\n", 637 | "Train Epoch: 3 [38400/60000 (64%)]\tLoss: 0.106513\tAcc: 98.44%\n", 638 | "Train Epoch: 3 [40000/60000 (67%)]\tLoss: 0.028559\tAcc: 98.44%\n", 639 | "Train Epoch: 3 [41600/60000 (69%)]\tLoss: 0.258976\tAcc: 93.75%\n", 640 | "Train Epoch: 3 [43200/60000 (72%)]\tLoss: 0.004482\tAcc: 100.00%\n", 641 | "Train Epoch: 3 [44800/60000 (75%)]\tLoss: 0.008298\tAcc: 100.00%\n", 642 | "Train Epoch: 3 [46400/60000 (77%)]\tLoss: 0.086896\tAcc: 96.88%\n", 643 | "Train Epoch: 3 [48000/60000 (80%)]\tLoss: 0.078228\tAcc: 96.88%\n", 644 | "Train Epoch: 3 [49600/60000 (83%)]\tLoss: 0.025244\tAcc: 98.44%\n", 645 | "Train Epoch: 3 [51200/60000 (85%)]\tLoss: 0.040387\tAcc: 98.44%\n", 646 | "Train Epoch: 3 [52800/60000 (88%)]\tLoss: 0.052178\tAcc: 98.44%\n", 647 | "Train Epoch: 3 [54400/60000 (91%)]\tLoss: 0.035417\tAcc: 98.44%\n", 648 | "Train Epoch: 3 [56000/60000 (93%)]\tLoss: 0.050410\tAcc: 96.88%\n", 649 | "Train Epoch: 3 [57600/60000 (96%)]\tLoss: 0.131319\tAcc: 96.88%\n", 650 | "Train Epoch: 3 [59200/60000 (99%)]\tLoss: 0.064430\tAcc: 98.44%\n", 651 | "Train Epoch: 4 [0/60000 (0%)]\tLoss: 0.089984\tAcc: 95.31%\n", 652 | "Train Epoch: 4 [1600/60000 (3%)]\tLoss: 0.075069\tAcc: 96.88%\n", 653 | "Train Epoch: 4 [3200/60000 (5%)]\tLoss: 0.037407\tAcc: 100.00%\n", 654 | "Train Epoch: 4 [4800/60000 (8%)]\tLoss: 0.284701\tAcc: 96.88%\n", 655 | "Train Epoch: 4 [6400/60000 (11%)]\tLoss: 0.005039\tAcc: 100.00%\n", 656 | "Train Epoch: 4 [8000/60000 (13%)]\tLoss: 0.049454\tAcc: 96.88%\n", 657 | "Train Epoch: 4 [9600/60000 (16%)]\tLoss: 0.070564\tAcc: 98.44%\n", 658 | "Train Epoch: 4 [11200/60000 (19%)]\tLoss: 0.061793\tAcc: 98.44%\n", 659 | "Train Epoch: 4 [12800/60000 (21%)]\tLoss: 0.043478\tAcc: 98.44%\n", 660 | "Train Epoch: 4 [14400/60000 (24%)]\tLoss: 0.015482\tAcc: 98.44%\n", 661 | "Train Epoch: 4 [16000/60000 (27%)]\tLoss: 0.063184\tAcc: 98.44%\n", 662 | "Train Epoch: 4 [17600/60000 (29%)]\tLoss: 0.004628\tAcc: 100.00%\n", 663 | "Train Epoch: 4 [19200/60000 (32%)]\tLoss: 0.119642\tAcc: 95.31%\n", 664 | "Train Epoch: 4 [20800/60000 (35%)]\tLoss: 0.064909\tAcc: 96.88%\n", 665 | "Train Epoch: 4 [22400/60000 (37%)]\tLoss: 0.038996\tAcc: 98.44%\n", 666 | "Train Epoch: 4 [24000/60000 (40%)]\tLoss: 0.034809\tAcc: 100.00%\n", 667 | "Train Epoch: 4 [25600/60000 (43%)]\tLoss: 0.009662\tAcc: 100.00%\n", 668 | "Train Epoch: 
4 [27200/60000 (45%)]\tLoss: 0.010710\tAcc: 100.00%\n", 669 | "Train Epoch: 4 [28800/60000 (48%)]\tLoss: 0.014937\tAcc: 98.44%\n" 670 | ] 671 | }, 672 | { 673 | "name": "stdout", 674 | "output_type": "stream", 675 | "text": [ 676 | "Train Epoch: 4 [30400/60000 (51%)]\tLoss: 0.041674\tAcc: 100.00%\n", 677 | "Train Epoch: 4 [32000/60000 (53%)]\tLoss: 0.203437\tAcc: 95.31%\n", 678 | "Train Epoch: 4 [33600/60000 (56%)]\tLoss: 0.010256\tAcc: 100.00%\n", 679 | "Train Epoch: 4 [35200/60000 (59%)]\tLoss: 0.102535\tAcc: 93.75%\n", 680 | "Train Epoch: 4 [36800/60000 (61%)]\tLoss: 0.030967\tAcc: 98.44%\n", 681 | "Train Epoch: 4 [38400/60000 (64%)]\tLoss: 0.050214\tAcc: 96.88%\n", 682 | "Train Epoch: 4 [40000/60000 (67%)]\tLoss: 0.032882\tAcc: 98.44%\n", 683 | "Train Epoch: 4 [41600/60000 (69%)]\tLoss: 0.174472\tAcc: 98.44%\n", 684 | "Train Epoch: 4 [43200/60000 (72%)]\tLoss: 0.044490\tAcc: 98.44%\n", 685 | "Train Epoch: 4 [44800/60000 (75%)]\tLoss: 0.052338\tAcc: 98.44%\n", 686 | "Train Epoch: 4 [46400/60000 (77%)]\tLoss: 0.005276\tAcc: 100.00%\n", 687 | "Train Epoch: 4 [48000/60000 (80%)]\tLoss: 0.009933\tAcc: 100.00%\n", 688 | "Train Epoch: 4 [49600/60000 (83%)]\tLoss: 0.200144\tAcc: 96.88%\n", 689 | "Train Epoch: 4 [51200/60000 (85%)]\tLoss: 0.048852\tAcc: 98.44%\n", 690 | "Train Epoch: 4 [52800/60000 (88%)]\tLoss: 0.030278\tAcc: 98.44%\n", 691 | "Train Epoch: 4 [54400/60000 (91%)]\tLoss: 0.054861\tAcc: 98.44%\n", 692 | "Train Epoch: 4 [56000/60000 (93%)]\tLoss: 0.003664\tAcc: 100.00%\n", 693 | "Train Epoch: 4 [57600/60000 (96%)]\tLoss: 0.001415\tAcc: 100.00%\n", 694 | "Train Epoch: 4 [59200/60000 (99%)]\tLoss: 0.009499\tAcc: 100.00%\n", 695 | "Train Epoch: 5 [0/60000 (0%)]\tLoss: 0.013112\tAcc: 100.00%\n", 696 | "Train Epoch: 5 [1600/60000 (3%)]\tLoss: 0.029796\tAcc: 100.00%\n", 697 | "Train Epoch: 5 [3200/60000 (5%)]\tLoss: 0.003726\tAcc: 100.00%\n", 698 | "Train Epoch: 5 [4800/60000 (8%)]\tLoss: 0.023652\tAcc: 98.44%\n", 699 | "Train Epoch: 5 [6400/60000 (11%)]\tLoss: 0.007024\tAcc: 100.00%\n", 700 | "Train Epoch: 5 [8000/60000 (13%)]\tLoss: 0.029964\tAcc: 98.44%\n", 701 | "Train Epoch: 5 [9600/60000 (16%)]\tLoss: 0.166122\tAcc: 93.75%\n", 702 | "Train Epoch: 5 [11200/60000 (19%)]\tLoss: 0.005029\tAcc: 100.00%\n", 703 | "Train Epoch: 5 [12800/60000 (21%)]\tLoss: 0.070467\tAcc: 98.44%\n", 704 | "Train Epoch: 5 [14400/60000 (24%)]\tLoss: 0.007154\tAcc: 100.00%\n", 705 | "Train Epoch: 5 [16000/60000 (27%)]\tLoss: 0.118603\tAcc: 96.88%\n", 706 | "Train Epoch: 5 [17600/60000 (29%)]\tLoss: 0.047950\tAcc: 98.44%\n", 707 | "Train Epoch: 5 [19200/60000 (32%)]\tLoss: 0.048197\tAcc: 96.88%\n", 708 | "Train Epoch: 5 [20800/60000 (35%)]\tLoss: 0.024703\tAcc: 98.44%\n", 709 | "Train Epoch: 5 [22400/60000 (37%)]\tLoss: 0.032201\tAcc: 98.44%\n", 710 | "Train Epoch: 5 [24000/60000 (40%)]\tLoss: 0.010922\tAcc: 100.00%\n", 711 | "Train Epoch: 5 [25600/60000 (43%)]\tLoss: 0.033510\tAcc: 98.44%\n", 712 | "Train Epoch: 5 [27200/60000 (45%)]\tLoss: 0.005127\tAcc: 100.00%\n", 713 | "Train Epoch: 5 [28800/60000 (48%)]\tLoss: 0.019979\tAcc: 98.44%\n", 714 | "Train Epoch: 5 [30400/60000 (51%)]\tLoss: 0.041545\tAcc: 100.00%\n", 715 | "Train Epoch: 5 [32000/60000 (53%)]\tLoss: 0.037461\tAcc: 98.44%\n", 716 | "Train Epoch: 5 [33600/60000 (56%)]\tLoss: 0.039346\tAcc: 96.88%\n", 717 | "Train Epoch: 5 [35200/60000 (59%)]\tLoss: 0.045353\tAcc: 98.44%\n", 718 | "Train Epoch: 5 [36800/60000 (61%)]\tLoss: 0.011449\tAcc: 100.00%\n", 719 | "Train Epoch: 5 [38400/60000 (64%)]\tLoss: 0.032527\tAcc: 98.44%\n", 720 | "Train Epoch: 5 
[40000/60000 (67%)]\tLoss: 0.002806\tAcc: 100.00%\n", 721 | "Train Epoch: 5 [41600/60000 (69%)]\tLoss: 0.001624\tAcc: 100.00%\n", 722 | "Train Epoch: 5 [43200/60000 (72%)]\tLoss: 0.066037\tAcc: 98.44%\n", 723 | "Train Epoch: 5 [44800/60000 (75%)]\tLoss: 0.014355\tAcc: 100.00%\n", 724 | "Train Epoch: 5 [46400/60000 (77%)]\tLoss: 0.002218\tAcc: 100.00%\n", 725 | "Train Epoch: 5 [48000/60000 (80%)]\tLoss: 0.049414\tAcc: 98.44%\n", 726 | "Train Epoch: 5 [49600/60000 (83%)]\tLoss: 0.022918\tAcc: 98.44%\n", 727 | "Train Epoch: 5 [51200/60000 (85%)]\tLoss: 0.091738\tAcc: 98.44%\n", 728 | "Train Epoch: 5 [52800/60000 (88%)]\tLoss: 0.004077\tAcc: 100.00%\n", 729 | "Train Epoch: 5 [54400/60000 (91%)]\tLoss: 0.001125\tAcc: 100.00%\n", 730 | "Train Epoch: 5 [56000/60000 (93%)]\tLoss: 0.012950\tAcc: 98.44%\n", 731 | "Train Epoch: 5 [57600/60000 (96%)]\tLoss: 0.003812\tAcc: 100.00%\n", 732 | "Train Epoch: 5 [59200/60000 (99%)]\tLoss: 0.002558\tAcc: 100.00%\n" 733 | ] 734 | } 735 | ], 736 | "source": [ 737 | "torch.manual_seed(SEED)\n", 738 | "\n", 739 | "model = SmallConv().to(device)\n", 740 | "\n", 741 | "# relu init\n", 742 | "for m in model.modules():\n", 743 | " if isinstance(m, (nn.Conv2d, nn.Linear)):\n", 744 | " nn.init.kaiming_normal_(m.weight, mode='fan_in')\n", 745 | "\n", 746 | "# define optimizer\n", 747 | "optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=mom)\n", 748 | "\n", 749 | "stats_with_shuffling = []\n", 750 | "for epoch in range(1, num_epochs+1):\n", 751 | " train_loader = get_data_loader(data_dir, batch_size, permutations[epoch-1], **kwargs)\n", 752 | " stats_with_shuffling.append(train(model, device, train_loader, optimizer, epoch))\n", 753 | "pickle.dump(stats_with_shuffling, open(dump_dir + \"with_shuffling.pkl\", \"wb\"))" 754 | ] 755 | }, 756 | { 757 | "cell_type": "markdown", 758 | "metadata": {}, 759 | "source": [ 760 | "## Generate Quantiles - Shuffling ON" 761 | ] 762 | }, 763 | { 764 | "cell_type": "code", 765 | "execution_count": 23, 766 | "metadata": {}, 767 | "outputs": [], 768 | "source": [ 769 | "stats_with_shuffling = pickle.load(open(dump_dir + \"with_shuffling.pkl\", \"rb\"))\n", 770 | "permutations = pickle.load(open(dump_dir + \"permutations.pkl\", \"rb\"))" 771 | ] 772 | }, 773 | { 774 | "cell_type": "code", 775 | "execution_count": 24, 776 | "metadata": {}, 777 | "outputs": [], 778 | "source": [ 779 | "# flatten the list for each epoch\n", 780 | "stats_with_shuffling_flat = []\n", 781 | "for stat in stats_with_shuffling:\n", 782 | " stats_with_shuffling_flat.append(\n", 783 | " [v for sublist in stat for v in sublist]\n", 784 | " )" 785 | ] 786 | }, 787 | { 788 | "cell_type": "code", 789 | "execution_count": 25, 790 | "metadata": {}, 791 | "outputs": [], 792 | "source": [ 793 | "# remap the indices based on the permutations list\n", 794 | "for stat, perm in zip(stats_with_shuffling_flat, permutations):\n", 795 | " for i in range(len(stat)):\n", 796 | " stat[i][0] = perm[i]\n", 797 | " stat[i][1] = stat[i][1].item()" 798 | ] 799 | }, 800 | { 801 | "cell_type": "code", 802 | "execution_count": 26, 803 | "metadata": {}, 804 | "outputs": [], 805 | "source": [ 806 | "# resort in increasing index order\n", 807 | "for i in range(len(stats_with_shuffling_flat)):\n", 808 | " stats_with_shuffling_flat[i] = sorted(stats_with_shuffling_flat[i], key=lambda x: x[0])" 809 | ] 810 | }, 811 | { 812 | "cell_type": "code", 813 | "execution_count": 27, 814 | "metadata": {}, 815 | "outputs": [], 816 | "source": [ 817 | "# get percentile splits for all 5 
epochs\n",
818 | "num_quantiles = 10\n",
819 | "percentile_splits = bin_losses(stats_with_shuffling_flat, num_quantiles)"
820 | ]
821 | },
822 | {
823 | "cell_type": "code",
824 | "execution_count": 28,
825 | "metadata": {},
826 | "outputs": [
827 | {
828 | "data": {
829 | "image/png": "(base64-encoded PNG omitted: three bar charts of percent match per loss quantile, titled 'Epochs 1 to 5 / 2 to 5 / 4 to 5 (With Shuffling)', y-axis 'Percent Match')\n",
830 | "text/plain": [
831 | ""
832 | ]
833 | },
834 | "metadata": {},
835 | "output_type": "display_data"
836 | }
837 | ],
838 | "source": [
839 | "fr = [0, 1, 3]\n",
840 | "to = [0, 0, 0]\n",
841 | "names = ['1 to 5', '2 to 5', '4 to 5']\n",
842 | "\n",
843 | "all_matches = []\n",
844 | "for f, t in zip(fr, to):\n",
845 | "    percent_matches = []\n",
846 | "    for i in range(num_quantiles):\n",
847 | "        percentile_all = []\n",
848 | "        for j in range(f, len(percentile_splits)-t):\n",
849 | "            percentile_all.append(percentile_splits[j][i])\n",
850 | "        matching = reduce(np.intersect1d, percentile_all)\n",
851 | "        percent = 100 * len(matching) / len(percentile_all[0])\n",
852 | "        percent_matches.append(percent)\n",
853 | "    all_matches.append(percent_matches)\n",
854 | "    \n",
855 | "    \n",
856 | "fig, axes = plt.subplots(1, 3, figsize=(8, 4))\n",
857 | "for i, (ax, match, n) in enumerate(zip(axes, all_matches, names)):\n",
858 | "    ax.bar(range(1, num_quantiles+1), match, width=0.9, color='b')\n",
859 | "    if i == 0:\n",
860 | "        ax.set_ylabel('Percent Match')\n",
861 | "    ax.set_title('Epochs {} (With Shuffling)'.format(n))\n",
862 | "    ax.xaxis.set_major_locator(MaxNLocator(integer=True))\n",
863 | "    ax.set_ylim([0, 90])\n",
864 | "    \n",
865 | "plt.savefig(plot_dir + \"with_shuffling.jpg\", format=\"jpg\", dpi=1000, bbox_inches='tight')"
866 | ]
867 | },
868 | {
869 | "cell_type": "markdown",
870 | "metadata": {},
871 | "source": [
872 | "## Generate Quantiles - Shuffling OFF"
873 | ]
874 | },
875 | {
876 | "cell_type": "code",
877 | "execution_count": 29,
878 | "metadata": {},
879 | "outputs": [],
880 | "source": [
881 | "stats_no_shuffling = pickle.load(open(dump_dir + \"no_shuffling.pkl\", \"rb\"))\n",
882 | "permutations = [np.arange(60000)]*len(stats_no_shuffling)"
883 | ]
884 | },
885 | {
886 | "cell_type": "code",
887 | "execution_count": 30,
888 | "metadata": {},
889 | "outputs": [],
890 | "source": [
891 | "# flatten the list for each epoch\n",
892 | "stats_no_shuffling_flat = []\n",
893 | "for stat in stats_no_shuffling:\n",
894 | "    stats_no_shuffling_flat.append(\n",
895 | "        [v for sublist in stat for v in sublist]\n",
896 | "    )"
897 | ]
898 | },
899 | {
900 | "cell_type": "code",
901 | "execution_count": 31,
902 | "metadata": {},
903 | "outputs": [],
904 | "source": [
905 | "# remap the indices based on the permutations list\n",
906 | "for stat, perm in zip(stats_no_shuffling_flat, permutations):\n",
907 | "    for i in range(len(stat)):\n",
908 | "        stat[i][0] = perm[i]\n",
909 | "        stat[i][1] = stat[i][1].item()"
910 | ]
911 | },
912 | {
913 | "cell_type": "code",
914 | "execution_count": 32,
915 | "metadata": {},
916 | "outputs": [],
917 | "source": [
918 | "# resort in increasing index order\n",
919 | "for i in range(len(stats_no_shuffling_flat)):\n",
920 | "    stats_no_shuffling_flat[i] = sorted(stats_no_shuffling_flat[i], key=lambda x: x[0])"
921 | ]
922 | },
923 | {
924 | "cell_type": "code",
925 | "execution_count": 33,
926 | "metadata": {},
927 | "outputs": [],
928 | "source": [
929 | "# get percentile splits for all 5 epochs\n",
930 | "num_quantiles = 10\n",
931 | "percentile_splits = bin_losses(stats_no_shuffling_flat, num_quantiles)"
932 | ]
933 | },
934 | {
935 | "cell_type": "code",
936 | "execution_count": 34,
937 | "metadata": {},
938 | "outputs": [
939 | {
940 | "data": {
941 | "image/png":
"iVBORw0KGgoAAAANSUhEUgAAAeYAAAD+CAYAAADvYaaNAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAF1BJREFUeJzt3XuQpXV95/F3OxIQgw0MG8exdxwM7JeYbCISLJWbi9G4rmwqISs6EkVkHBTWEYyKxEuBlJiYkBrdIQiImMogEitSuiBJGWHQcFHwsiuRb4yyMu0wzKggjMjFofeP52npaXpOn27O75xf93m/qqa6+5znPL9v93y6v8/vuZ2RiYkJJElSHZ406AIkSdJjbMySJFXExixJUkVszJIkVcTGLElSRWzMkiRV5MmDLmC6iHgxcAXwr1Me3paZ/6MH674UuDwzr5nHaw8ArszM35rhuSOBezPz/3Sxnn2BfwO+3T702cxcN22ZpcAHM3NNRPw/4LzM/Ej73EHABZn54i7rPgP4PeBRYAI4MzNvjYjrgJMz8/Yu1rE/8FngW8B64BPA54DXAAcBbwO+lJlf7aamdp0jwKVtDT/v9nUl1Ji5iPgwcDjN7+iFmXnRtOfN3ALOHNSZu/a1ewI3AGdMf72560/uqmvMrS9l5qsHXcSkiPgTYC2w3y4WORG4HJg1rMDzgE9l5v/ssMw5NKGYdHpE/GNmZjf1ToqI5wD/HTgsMyci4rnAJ4Hfmct6gMOAf87Mt0fEe2h+WT4aEa8ByMwPzXF9tPVcBrwTOGuury+gmsxFxH8BDsjMF0bE7sBtEfGZzLxnymJmbo4qzBxUlLsp1tM0tpmYuzmaT+5qbcwzard8bqfZchkBjsvMLRHxVzSzC4DLMnNdRBwIXAz8CvAAMBn+NRHxTmAUeDNNwK5ov34K8M7MvG7a0PcARwHfm6GmQ4CXA8+LiH8FjqDZqnoI+C7wpsx8ZMpLDmmX3QhsBd6amXdNWd/TgEMz881TXnM68MmIOGza2AcDHwV2AA8CqzPzzimLbAVWACdGxDWZ+c2IeP6U598fEU8HnkqzRbiCZqvu1e36twDPB94D7BkR9wFvAB6OiPEpdVxK88u6DHgFsCfw68CfZ+al7Zjrgfvbmh7MzBOALwLnRcQHMvPR6T/bGgwoczcC32w/nwCWAL/MkJlb3JmDwf2ti4g/pZktj8xQk7nrU+5qPcZ8dERcN+XfO6Y8d0O7a+PTwJkR8Upgf+AFNIFdFRH/GfhL4NzMfCHwMeDg9vW3ZubRNP/JJ9D8UJcBxwCraH7QO8nM/52ZP5up0My8FbiGZmvoZzRbREdn5uHAvcCaaS+5HXh/Zh4FXNnWMdULgOlbi1cD/xd417THLwJObdd1PnDetNp+RLsVCdwYEbcDr5yyyFXtz+ILwB/v4vu7E/gQzR+Bs2h2yZyXmZ+daXlgNDNf2Y57RvvYBcAJ7Vi/3LjJzB004X3c4YEBqCZzmflgZt4TEbvRbPVfmJnbpzxv5na2UDMHFeUuIl4CHDj9sMkkc/c4xXJX64y50+6dL7UfbwD+ANgEfDkzJ4BHIuIm4DlA0Mw8yMwrACJiFXBr+/otwJ6ZeVtErAc+BewGfOQJ1P1s4LbMvL/9+nrgZTPU/0D7+WeBs6c9vx9w9wzrPh24hZ1n7cszc3JmdT1NqH4pmuPi92Xmie3XvwtcHRHXtotM/Vksm2HMx201d2Gynk3AHlPqvK39/Ms8tkUPcBewdB7j9FpVmYuIfYDPANdl5rkd6jZzCzdzUFfu3gg8q52tH0Qz290y5f97KnNXMHe1zpg7OaT9eBhwG/Ad2l077QzjRTS7Vb4DHNo+/tqImDzOsdOxk3aLc6/M/G/A63n8Vl03HqX5Wd4BPCcinto+fhTNyQ9TXQwc237+Eh4LzKStwN7TB2h/AdYAU0+e2BwRv91hrN8G/iYiJkPzb8BPaXYHweOPIz0IPAMgIp4F7Du9ji7MdGxqU3sMCJqt5Kn2ofmea9bXzEXEU4B/Bi7JzA/soiYz95jFmDnoc+4yc1VmHtbO0q+h2dU9vSmbu8cUy12tM+aj2622qf5r+/GEiDidZlfKn2TmjyPixRFxI80xlisy8+vtLqGPRXMA/wHgeB4L+lTfpTn+8DrgYeB986j3ZpotuOOA9wPXRsSjwL/z2C6OSWcAl0TEW9rv4aRpz98E/PlMg2TmdRHxKR7bVbUa+F/RnPX3C5ot3qnL/0NE/AZwc0Rsp/mFekdm/jQiZhriFuDeiLiZ5pf9jlm/8+68heZ73k7zM/4hQEQ8CRhj57NSB6WmzJ1MMyNZHRGr28fekJlT/z/MXGcLIXNQV+66Ye4660nuRhbSu0vFHE57X8gi4gLgY5n5jUHX0gsRcQrNH5FtEXEO8HBmnh0RrwCel5nnDLjEXTJzC9NCzhyYu4WqV7lbiLuyh8H7aLa8Fou7gX+KiC8DzwXWt1u+q4C/HmhlmmTmNAjmbgYLasYsSdJi54xZkqSKFDn5KyJeQHPgfxPNqf5jwEqaC9tPy8xtJcbV8DJzGgRzpxJKnZX9GpoLy78VEZcDSzPzpdHcanA18MHJBaO55eChNNd47ZhxbRoWS2guYfhaZj40x9eaOc2XuVO/dcxcqcb818D7IuInwK8BP2gfHweWT1v2UJoLsaVJRwBfmeNrzJyeKHOnfpsxc6Ua8wrgnMz8fkRcBTy9fXwM2Dxt2bsANmzYwLJlM92QRcNiy5YtvPa1r4U2E3Nk5jQv5k79NlvmSjXmceAvI+Je4DJgv4g4n+bOJydPW3YHwLJlyxgbGytUjhaY+ezmM3N6osyd+m3GzBVpzJn5feCPSqxbmomZ0yCYO5Xg5VKSJFXExixJUkVszJIkVcTGLElSRWzMkiRVxMYsSVJFbMySJFXExixJUkVszJIkVcTGLElSRUrdK1uS1K2Rke6XnZgoV4eq4IxZkqSK2JglSaqIjVmSpIrYmCVJqoiNWZKkitiYJUmqiI1ZkqSKFLmOOSLGgLOBe4AR4C5gJTAKnJaZ20qMq+Fl5jQI5k4llJoxHwQcDawAfgQcmZmnAB8HVhcaU8PNzGkQzJ16rtSdvzbRhPUO4J+A8fbxcWB5oTE13MycBsHcqedKzZhPBUYzcwL4KfCs9vExYHOhMTXczJwGwdyp50rNmD8BnB0RdwI3Aw9HxPnAPsDJhcbUcDNzGgRzp54r0pgz8+vAMSXWLc3EzGkQzJ1K8HIpSZIqYmOWJKkiNmZJkipiY5YkqSI2ZkmSKmJjliSpIjZmSZIqUuoGI5IkLU4jI90vOzEx59U7Y5YkqSI2ZkmSKmJjliSpIjZmSZIqYmOWJKkiNmZJkipiY5YkqSI2ZkmSKmJjliSpIjZmSZIqUuSWnBFxCnAosBtwOPBpYHdgFFiTmQ+VGFfDy8xpEMydSigyY87M9Zl5AjAO/CEwmplrgY3AsSXG1HAzcxoEc6cSiu3KjoiDaLYa96AJLe3H5aXG1HAzcxoEc6deK/nuUm8BPgw8DDyzfWwM2F
xwTA03M6dBMHfqqZKN+YDM/B5ARPwkItbRbFWuLjimhpuZ0yCYO/VUscacma+Y8vmZpcaRJpk5DYK5U695uZQkSRWxMUuSVBEbsyRJFbExS5JUERuzJEkVsTFLklQRG7MkSRWxMUuSVBEbsyRJFbExS5JUERuzJEkVsTFLklQRG7MkSRWxMUuSVBEbsyRJFbExS5JUERuzJEkVsTFLklSRJ5dYaUSsBN4LbAXuB/YGdgdGgTWZ+VCJcTW8zJwGwdyphFIz5rcD48AzgLuB0cxcC2wEji00poabmdMgmDv1XKnGfABwJbAaOJ4muLQflxcaU8PNzGkQzN1iMDLS/b8+KNWYtwD3ZeYj7dfPbD+OAZsLjanhZuY0COZOPVfkGDPwF8C5EbEVuBj4zYhYR3PcZXWhMTXczJwGYTC5m8vMbWKiWBkqo0hjzszvAK8qsW5pJmZOg2DuVIKXS0mSVBEbsyRJFZl1V3ZEvB84FXgEGAEmMtOzDbU4eexO0oB1c4z5lcCKzPx56WIkSRp23ezK3kozW5YkSYXtcsYcEZ8CJoCnA9+IiG+3T01k5qp+FCdJ0rDptCv7gr5VIUmSgA67sjNzY2ZuBJ4GvKT9/N3AHv0qTpKkYdPNyV9nAS9vPz8O+ALwj8UqkiRpiHVz8tcjmbkVIDN/CuwoW5IkScOrmxnzVyPiMuBG4FDgG2VLkiRpeHXTmN8K/AEQwGcy83NlS5IkaXh1ulxqCbAEuBx4NXA1sCQivpSZR/epPkmShkqnGfOJwJnAMuB2mttx7gC+0oe6JEkaSrtszJl5EXBRRJyYmZf0sSZJkoZWN8eYr4+IdwO70cyal2fmmrJlSZI0nLq5XOpv24+HA/sDS8uVI0nScOumMT+QmecC45l5As29syVJUgHd7MoeiYhlwK9GxFOBfWd7QUSsAD4HfBO4i+bs7t2BUWBNZj40/5KlxzNzGgRzpxK6mTGfBfwh8HfAHTS35JzNUcDd7edbgNHMXAtsBI6dR53SbMycBsHcqedmnTFn5vXA9e2Xv9bler8KfJEmsF8Erm0fHwd+Z441St0wcxoEc6ee63SDkTto3o/5cTLz2bOs92Dgxsx8NCJGgMnlx4DN8ylUmoWZ0yCYO/Vcpxnz54FDaLYC/w64cw7r/S7w4YjYBnwaWBER62iOu6yeZ61SJ2ZOg2Du1HOdbjDy1oh4EvAy4L00J31dCVwBdDyhITNvBV7VwzqljsycBsHcqYSOx5gz81HgGuCaiNgX+Bvgo8BT+lCbJElDp2NjbmfMLwVeAzyX5ozsQ/tQlyRJQ6nTyV/raS4FuA64MDNv6FdRkiQNq04z5jcDP6a5Fu/YiJiguVf2RGYu70dxkiQNm04nf3Vz8xFJUu1GRrpfdmLGq2TVRzZfSZIqYmOWJKkiszbmiPjdaV8fVa4cSZKGW6ezso8AngOcFhHntQ8vAU4BfqsPtUmSNHQ6nZV9D7CM5i3MntE+9ijwztJFSZI0rDqdlf1t4NsRcVFmejN2SZL6YNa3fQR+LyLeTTNznryOebZ3l5IkSfPQTWN+F3AMsKlwLZIkDb1uGvP3M/Pfi1ciSZK6aswPRMQXgG8CEwCZeWbRqiRJmq8Ffqezbhrz1cWrkCRJQHd3/toA7AY8G/gBcFXRiiRpLkZGuv8nLQDdNOYLgBXAy4C9gL8tWpEkSUOsm13Zv56ZJ0XEEZn5+Yg4o5sVR8QG4PPAfwRWAqPAaZm5bd7VSrMwd+o3M6de62bG/OSI2A+YiIi9aO7+1VFEnA5sb788MjNPAT4OrJ53pdIszJ36zcyphG5mzO8B/oXmtpw3AW/rtHBEHAPcC9xI0/i3tk+NA8vnXanUgblTv5m5PlrgZ1nP1ayNOTM3RsRLgZ8DKzPza7O85Hia+2xH+/Xk1uQY4K09VYq5U7+ZORUxa2OOiAuA8cw8JyLeExHHZ+baXS2fmce1rzsBeBB4ekScD+wDnNybsqWdmTv126LO3JDNUGvTza7sgzPzZIDMXBsR13ez4sy89IkUJs2HuVO/mTn1Wjcnf41ExFKAiNib7pq5JEmah26a7FnALRHxE2Bv4C1lS5IkaXh105j3Bg4A9gO2ZqYHFCRJ8+cx7I66acxvyswNwN2li5Ekadh105h3j4hvAEl7c5HMXFW0KkmShlQ3jfldxauQJElAd2dlfx14KfA6YCnww6IVSZIWFt/hq6e6acyXAN8H/hOwheY+sJIkqYBuGvPSzLwEeCQzbwDc5JEkqZBuGjMRcVD7cQzYUbQiSZKGWDcnf60FPgH8BvAZvMGIJEnFdGzMEfE04HuZ+cI+1SNJ0lDb5a7siDgV+BbwrYj4/f6VJEnS8Op0jHkVzfuMvhB4W3/KkSRpuHVqzA9m5sOZ+SPgV/pVkCRJw6yrs7LxEilJkvqi08lfvxkRl9E05cnPAe+VLUlSKZ0a86umfH5B6UIkSVKHxpyZG/tZiCRJ6u4GI3MWEQcCH6K5t/bXgP8ArARGgdMyc1uJcTW8zJwGwdyphG5P/pqrUeBPaS6zWgUcmZmn0LwBxupCY2rQBvsOM2ZOg2Du1HNFGnNm3gI8DFwFXAdsbZ8aB5Z3vSLfSkxd6lnmpDkwdyqhSGOOiOfSXAf9MuAQYL/2qTFgc4kxNdzMnAbB3KmEIseYaW5I8rGI+DHNezmPR8T5wD7AyYXG1HAzcxoEc6eeK9KYM/OrwB+XWLc0EzOnQTB3KqHUyV+SJGkebMySJFXExixJUkVszJIkVcTGLElSRWzMkiRVxMYsSVJFbMySJFXExixJUkVszJIkVcTGLElSRWzMkiRVxMYsSVJFbMySJFXExixJUkVszJIkVcTGLElSRWzMkiRV5MklVhoRLwLWAtuBHwC/CuwOjAJrMvOhEuNqeJk5DYK5UwmlZsz7ACdl5huBw4HRzFwLbASOLTSmhpuZ0yCYO/VckcacmVcB2yPiz4B/Acbbp8aB5SXG1HAzcxoEc6cSijTmiNgLuBi4CbgEeGb71BiwucSYGm5mToNg7lRCkWPMwDrgQOANwA7ghxGxjua4y+pCY2q4mTkNgrlTzxVpzJl5Yon1Srti5jQI5k4leLmUJEkVsTFLklQRG7MkSRWxMUuSVBEbsyRJFbExS5JUERuzJEkVsTFLklQRG7MkSRWxMUuSVBEbsyRJFbExS5JUkVLvLiUNh5GR7pedmChXh6RFwxmzJEkVsTFLklQRG7MkSRWxMUuSVBEbsyRJFSl2VnZEHAD8fWYeHBHvAFYCo8Bpmbmt1LgabuZO/Wbm1GtFZswRsQw4CfhZROwBHJmZpwAfB1aXGFMyd+o3M6cSijTmzNySmWcA24F9ga3tU+PA8hJjSuZO/WbmVEI/jjFvBZa2n48Bm/swpmTu1G9mTj1RvDFn5i+AayPifOBNwPrSY0rmTv1m5tQrRW/JmZkvbz+uKzmONJW5W+AW4G1OzZx6yculJEmqiI1ZkqSK2JglSaqIjVmSpIrYmCVJqoiNWZKkitiYJUmqSNHrmCVNswCv0ZXUX86YJUmqiI1ZkqSKuCtbUlnd7
r53170EOGOWJKkqNmZJkiqyuHZle8arFpu5ZrofvwPumpaKWlyNWdLc2WilqrgrW5Kkijhj1q55aECS+s4ZsyRJFenLjDkingn8FfAT4LbMXN+PcWc1bDPCIft+q82dFi0zp17o167sNcBHMvOGiLg6Ii7MzEfa55YAbNmyZYbq5lDe+Hj55fffv/vl77ijvuVr+3lOMyUDS7pfUUe7yt3CyVw//g+6fU1ty0++pq7cLY6/dS7f2+WnmS1zIxN9mBlFxIXABzJzU0RsANZm5o/a5w4Hvly8CC0kR2TmV57oSnaVOzOnXXjCufNvneZoxsz1a8Z8JzAGbAL2Be6d8tzXgCOAu4AdfapHdVoCPIMmE72wq9yZOU3Vy9z5t07d6Ji5fs2YlwHnAfcDt2TmRcUH1dAzd+o3M6de6EtjliRJ3Vm01zFHxIuAtcB24AeZeXahcTYAn8/My3u83pXAe4GtwP2Z+cFerr8dYww4G7gHGMnM03u47gOAv8/MgyPiHcBKYBQ4LTO39Wqc2pi7Wddv5npsoWeuXfdKzN0vLebrmPcBTsrMNwKHlxggIk6n+WUo4e3AOM1xiBsLjXEQcDSwoh2rJ9rdeScBP4uIPYAjM/MU4OPA6l6NUylz15mZ672FnjkwdztZtI05M68CtkfEnwEber3+iDiG5sSOUk3zAOBKmv/c9xUaYxNNWF8F/H5E7NmLlWbmlsw8g+YXeV+arWBofiGW92KMWpm7WZm5HlsEmQNzt5NF25gjYi/gYuCmzPxkgSGOB54PvB54Y0Qs7fH6twD3tddA3t/jdU86FRjNzIl2jBKHNrYCkz+bMWBzgTGqYe5mZeZ6bBFkDszdThbtMWZgHXAg8IaIeF1mvr6XK8/M4wAi4gTgwcz8cS/XD/wFcG5EbAU+3eN1T/oEcHZE3AncnJn39XqAzPxFRFwbEefT7HI7uddjVMbcdWbmem+hZw7M3U48K1uSpIos2l3ZkiQtRDZmSZIqYmOWJKkiNmZJkipiY5YkqSI2ZkmSKmJjliSpIv8f9TvyYtYy2cEAAAAASUVORK5CYII=\n", 942 | "text/plain": [ 943 | "
" 944 | ] 945 | }, 946 | "metadata": {}, 947 | "output_type": "display_data" 948 | } 949 | ], 950 | "source": [ 951 | "fr = [0, 1, 3]\n", 952 | "to = [0, 0, 0]\n", 953 | "names = ['1 to 5', '2 to 5', '4 to 5']\n", 954 | "\n", 955 | "all_matches = []\n", 956 | "for f, t in zip(fr, to):\n", 957 | " percent_matches = []\n", 958 | " for i in range(num_quantiles):\n", 959 | " percentile_all = []\n", 960 | " for j in range(f, len(percentile_splits)-t):\n", 961 | " percentile_all.append(percentile_splits[j][i])\n", 962 | " matching = reduce(np.intersect1d, percentile_all)\n", 963 | " percent = 100 * len(matching) / len(percentile_all[0])\n", 964 | " percent_matches.append(percent)\n", 965 | " all_matches.append(percent_matches)\n", 966 | " \n", 967 | " \n", 968 | "fig, axes = plt.subplots(1, 3, figsize=(8, 4))\n", 969 | "for i, (ax, match, n) in enumerate(zip(axes, all_matches, names)):\n", 970 | " ax.bar(range(1, num_quantiles+1), match, width=0.9, color='r')\n", 971 | " if i == 0:\n", 972 | " ax.set_ylabel('Percent Match')\n", 973 | " ax.set_title('Epochs {} (No Shuffling)'.format(n))\n", 974 | " ax.xaxis.set_major_locator(MaxNLocator(integer=True))\n", 975 | " ax.set_ylim([0, 90])\n", 976 | " \n", 977 | "plt.savefig(plot_dir + \"no_shuffling.jpg\", format=\"jpg\", dpi=1000, bbox_inches='tight')" 978 | ] 979 | } 980 | ], 981 | "metadata": { 982 | "kernelspec": { 983 | "display_name": "Python 3", 984 | "language": "python", 985 | "name": "python3" 986 | }, 987 | "language_info": { 988 | "codemirror_mode": { 989 | "name": "ipython", 990 | "version": 3 991 | }, 992 | "file_extension": ".py", 993 | "mimetype": "text/x-python", 994 | "name": "python", 995 | "nbconvert_exporter": "python", 996 | "pygments_lexer": "ipython3", 997 | "version": "3.6.5" 998 | } 999 | }, 1000 | "nbformat": 4, 1001 | "nbformat_minor": 2 1002 | } 1003 | -------------------------------------------------------------------------------- /pr-lr/Loss vs Gradient Norm.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import pickle\n", 10 | "import numpy as np\n", 11 | "import pandas as pd\n", 12 | "import seaborn as sns\n", 13 | "import matplotlib.pyplot as plt\n", 14 | "\n", 15 | "import torch\n", 16 | "import torch.nn as nn\n", 17 | "import torch.optim as optim\n", 18 | "import torch.nn.functional as F\n", 19 | "\n", 20 | "from torchvision.datasets import MNIST\n", 21 | "from torchvision import transforms\n", 22 | "from torch.utils.data import DataLoader\n", 23 | "\n", 24 | "# plotting params\n", 25 | "%matplotlib inline\n", 26 | "plt.rcParams['font.size'] = 10\n", 27 | "plt.rcParams['axes.labelsize'] = 10\n", 28 | "plt.rcParams['axes.titlesize'] = 10\n", 29 | "plt.rcParams['xtick.labelsize'] = 8\n", 30 | "plt.rcParams['ytick.labelsize'] = 8\n", 31 | "plt.rcParams['legend.fontsize'] = 10\n", 32 | "plt.rcParams['figure.titlesize'] = 12\n", 33 | "plt.rcParams['figure.figsize'] = (13.0, 6.0)\n", 34 | "sns.set_style(\"white\")\n", 35 | "\n", 36 | "data_dir = './data/'\n", 37 | "plot_dir = './imgs/'\n", 38 | "dump_dir = './dump/'" 39 | ] 40 | }, 41 | { 42 | "cell_type": "markdown", 43 | "metadata": {}, 44 | "source": [ 45 | "## Setup" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": null, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "# ensuring reproducibility\n", 55 | "SEED = 42\n", 56 | "torch.manual_seed(SEED)\n", 
57 | "torch.backends.cudnn.benchmark = False" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": null, 63 | "metadata": {}, 64 | "outputs": [], 65 | "source": [ 66 | "GPU = False\n", 67 | "\n", 68 | "device = torch.device(\"cuda\" if GPU else \"cpu\")\n", 69 | "kwargs = {'num_workers': 1, 'pin_memory': True} if GPU else {}" 70 | ] 71 | }, 72 | { 73 | "cell_type": "markdown", 74 | "metadata": {}, 75 | "source": [ 76 | "## Data Loader" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "metadata": {}, 83 | "outputs": [], 84 | "source": [ 85 | "# make sure shuffling is turned off\n", 86 | "train_loader = DataLoader(\n", 87 | " MNIST(data_dir, train=True, download=True,\n", 88 | " transform=transforms.Compose([\n", 89 | " transforms.ToTensor(),\n", 90 | " transforms.Normalize((0.1307,), (0.3081,))\n", 91 | " ])),\n", 92 | " batch_size=64, shuffle=False, **kwargs)" 93 | ] 94 | }, 95 | { 96 | "cell_type": "markdown", 97 | "metadata": {}, 98 | "source": [ 99 | "## Model" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": null, 105 | "metadata": {}, 106 | "outputs": [], 107 | "source": [ 108 | "class SmallConv(nn.Module):\n", 109 | " def __init__(self):\n", 110 | " super(SmallConv, self).__init__()\n", 111 | " self.conv1 = nn.Conv2d(1, 10, kernel_size=5)\n", 112 | " self.conv2 = nn.Conv2d(10, 20, kernel_size=5)\n", 113 | " self.fc1 = nn.Linear(320, 50)\n", 114 | " self.fc2 = nn.Linear(50, 10)\n", 115 | "\n", 116 | " def forward(self, x):\n", 117 | " out = F.relu(F.max_pool2d(self.conv1(x), 2))\n", 118 | " out = F.relu(F.max_pool2d(self.conv2(out), 2))\n", 119 | " out = out.view(-1, 320)\n", 120 | " out = F.relu(self.fc1(out))\n", 121 | " out = self.fc2(out)\n", 122 | " return F.log_softmax(out, dim=1)" 123 | ] 124 | }, 125 | { 126 | "cell_type": "markdown", 127 | "metadata": {}, 128 | "source": [ 129 | "## Utility Functions" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": null, 135 | "metadata": {}, 136 | "outputs": [], 137 | "source": [ 138 | "def accumulate_gradient(losses, model):\n", 139 | " \"\"\"Computes the L2 norm of the gradient of the loss \n", 140 | " with respect to the weights and biases of the network.\n", 141 | " \n", 142 | " Since there's a weight and bias vector associated with \n", 143 | " every convolutional and fully-connected layer, the square\n", 144 | " root of the sum of the squared gradient norms is returned.\n", 145 | " \"\"\"\n", 146 | " norms = []\n", 147 | " for l in losses:\n", 148 | " grad_params = torch.autograd.grad(l, model.parameters(), create_graph=True)\n", 149 | " grad_norm = 0\n", 150 | " for grad in grad_params:\n", 151 | " grad_norm += grad.norm(2).pow(2)\n", 152 | " norms.append(grad_norm.sqrt())\n", 153 | " return norms\n", 154 | "\n", 155 | "def accuracy(predicted, ground_truth):\n", 156 | " predicted = torch.max(predicted, 1)[1]\n", 157 | " total = len(ground_truth)\n", 158 | " correct = (predicted == ground_truth).sum().double()\n", 159 | " acc = 100 * (correct / total)\n", 160 | " return acc.item()" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": null, 166 | "metadata": {}, 167 | "outputs": [], 168 | "source": [ 169 | "def train(model, device, train_loader, optimizer, epoch):\n", 170 | " model.train()\n", 171 | " \n", 172 | " epoch_stats = []\n", 173 | " for batch_idx, (data, target) in enumerate(train_loader):\n", 174 | " data, target = data.to(device), target.to(device)\n", 175 | " optimizer.zero_grad()\n", 176 | 
"\n", 177 | " # forward pass\n", 178 | " output = model(data)\n", 179 | " acc = accuracy(output, target)\n", 180 | " \n", 181 | " # compute batch loss and gradient norm\n", 182 | " losses = F.nll_loss(output, target, reduction='none')\n", 183 | " grad_norms = accumulate_gradient(losses, model)\n", 184 | " indices = [batch_idx*len(data) + i for i in range(len(data))]\n", 185 | " \n", 186 | " batch_stats = []\n", 187 | " for i, g, l in zip(indices, grad_norms, losses):\n", 188 | " batch_stats.append([i, [g, l]])\n", 189 | " epoch_stats.append(batch_stats)\n", 190 | " \n", 191 | " # take average loss and accuracy\n", 192 | " loss = losses.mean()\n", 193 | " \n", 194 | " # backwards pass\n", 195 | " loss.backward()\n", 196 | " optimizer.step()\n", 197 | " \n", 198 | " if batch_idx % 25 == 0:\n", 199 | " print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}\\tAcc: {:.2f}%'.format(\n", 200 | " epoch, batch_idx * len(data), len(train_loader.dataset),\n", 201 | "100. * batch_idx / len(train_loader), loss.item(), acc))\n", 202 | "\n", 203 | " return epoch_stats" 204 | ] 205 | }, 206 | { 207 | "cell_type": "markdown", 208 | "metadata": {}, 209 | "source": [ 210 | "## Main Logic" 211 | ] 212 | }, 213 | { 214 | "cell_type": "code", 215 | "execution_count": null, 216 | "metadata": {}, 217 | "outputs": [], 218 | "source": [ 219 | "model = SmallConv().to(device)\n", 220 | "\n", 221 | "# relu init\n", 222 | "for m in model.modules():\n", 223 | " if isinstance(m, (nn.Conv2d, nn.Linear)):\n", 224 | " nn.init.kaiming_normal_(m.weight, mode='fan_in')\n", 225 | "\n", 226 | "# define optimizer\n", 227 | "optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.99)\n", 228 | "\n", 229 | "num_epochs = 1\n", 230 | "\n", 231 | "stats = []\n", 232 | "for epoch in range(1, num_epochs+1):\n", 233 | " stats.append(train(model, device, train_loader, optimizer, epoch))" 234 | ] 235 | }, 236 | { 237 | "cell_type": "code", 238 | "execution_count": null, 239 | "metadata": {}, 240 | "outputs": [], 241 | "source": [ 242 | "pickle.dump(stats, open(dump_dir + \"statistics.pkl\", \"wb\"))" 243 | ] 244 | }, 245 | { 246 | "cell_type": "markdown", 247 | "metadata": {}, 248 | "source": [ 249 | "## Pearson Correlation Coefficient" 250 | ] 251 | }, 252 | { 253 | "cell_type": "code", 254 | "execution_count": null, 255 | "metadata": {}, 256 | "outputs": [], 257 | "source": [ 258 | "stats = pickle.load(open(dump_dir + \"statistics.pkl\", \"rb\"))[0]\n", 259 | "stats = [val for sublist in stats for val in sublist]\n", 260 | "\n", 261 | "grad_norms = [l[1][0].item() for l in stats]\n", 262 | "losses = [l[1][1].item() for l in stats]\n", 263 | "\n", 264 | "grad_norms = np.asarray(grad_norms)\n", 265 | "losses = np.asarray(losses)" 266 | ] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "execution_count": null, 271 | "metadata": {}, 272 | "outputs": [], 273 | "source": [ 274 | "cov = np.cov(grad_norms, losses)\n", 275 | "std_gn = np.std(grad_norms)\n", 276 | "std_l = np.std(losses)\n", 277 | "\n", 278 | "corr = cov / (std_gn * std_l)\n", 279 | "\n", 280 | "print(\"Pearson Correlation Coeff: {}\".format(corr[0, 1]))" 281 | ] 282 | }, 283 | { 284 | "cell_type": "markdown", 285 | "metadata": {}, 286 | "source": [ 287 | "## Plot" 288 | ] 289 | }, 290 | { 291 | "cell_type": "code", 292 | "execution_count": null, 293 | "metadata": {}, 294 | "outputs": [], 295 | "source": [ 296 | "stats = pickle.load(open(dump_dir + \"statistics.pkl\", \"rb\"))" 297 | ] 298 | }, 299 | { 300 | "cell_type": "code", 301 | "execution_count": null, 302 | 
"metadata": {}, 303 | "outputs": [], 304 | "source": [ 305 | "def loss_vs_gradnorm(list_stats):\n", 306 | " flattened = [val for sublist in list_stats for val in sublist]\n", 307 | " sorted_idx = sorted(range(len(flattened)), key=lambda k: flattened[k][1][0])\n", 308 | " losses = [flattened[idx][1][1].item() for idx in sorted_idx]\n", 309 | " return losses\n", 310 | "\n", 311 | "def rolling_window(a, window):\n", 312 | " shape = a.shape[:-1] + (a.shape[-1] - window + 1, window)\n", 313 | " strides = a.strides + (a.strides[-1],)\n", 314 | " return np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)" 315 | ] 316 | }, 317 | { 318 | "cell_type": "code", 319 | "execution_count": null, 320 | "metadata": {}, 321 | "outputs": [], 322 | "source": [ 323 | "sorted_losses = np.array(loss_vs_gradnorm(stats[0]))\n", 324 | "\n", 325 | "fig, ax = plt.subplots(figsize=(13, 6))\n", 326 | "\n", 327 | "rolling_mean = np.mean(rolling_window(sorted_losses, 50), 1)\n", 328 | "rolling_std = np.std(rolling_window(sorted_losses, 50), 1)\n", 329 | "\n", 330 | "plt.plot(range(len(rolling_mean)), rolling_mean, alpha=0.98, linewidth=0.9)\n", 331 | "plt.fill_between(range(len(rolling_std)), rolling_mean-rolling_std, rolling_mean+rolling_std, alpha=0.5)\n", 332 | "\n", 333 | "plt.grid()\n", 334 | "plt.savefig(plot_dir + \"loss_vs_grad.jpg\", format=\"jpg\", dpi=250, bbox_inches='tight')\n", 335 | "plt.show()" 336 | ] 337 | } 338 | ], 339 | "metadata": { 340 | "kernelspec": { 341 | "display_name": "Python 3", 342 | "language": "python", 343 | "name": "python3" 344 | }, 345 | "language_info": { 346 | "codemirror_mode": { 347 | "name": "ipython", 348 | "version": 3 349 | }, 350 | "file_extension": ".py", 351 | "mimetype": "text/x-python", 352 | "name": "python", 353 | "nbconvert_exporter": "python", 354 | "pygments_lexer": "ipython3", 355 | "version": "3.6.5" 356 | } 357 | }, 358 | "nbformat": 4, 359 | "nbformat_minor": 2 360 | } 361 | -------------------------------------------------------------------------------- /pr-lr/Monte Carlo and Importance Sampling.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import seaborn as sns\n", 11 | "import matplotlib.pyplot as plt\n", 12 | "import matplotlib.lines as mlines\n", 13 | "\n", 14 | "from scipy.stats import norm\n", 15 | "from tqdm import trange\n", 16 | "\n", 17 | "# plotting params\n", 18 | "%matplotlib inline\n", 19 | "plt.rcParams['font.size'] = 10\n", 20 | "plt.rcParams['axes.labelsize'] = 10\n", 21 | "plt.rcParams['axes.titlesize'] = 10\n", 22 | "plt.rcParams['xtick.labelsize'] = 8\n", 23 | "plt.rcParams['ytick.labelsize'] = 8\n", 24 | "plt.rcParams['legend.fontsize'] = 10\n", 25 | "plt.rcParams['figure.titlesize'] = 12\n", 26 | "# plt.rcParams['figure.figsize'] = (15.0, 8.0)\n", 27 | "sns.set_style(\"white\")\n", 28 | "\n", 29 | "# path params\n", 30 | "plot_dir = './'" 31 | ] 32 | }, 33 | { 34 | "cell_type": "markdown", 35 | "metadata": {}, 36 | "source": [ 37 | "## Question Prompt\n", 38 | "\n", 39 | "Given the following current equation\n", 40 | "\n", 41 | "$$I(\\Delta L, \\Delta V_{TH}) = \\frac{50}{0.1 + \\Delta L} (0.6 - \\Delta V_{TH})^2$$\n", 42 | "\n", 43 | "* $\\Delta L \\sim \\ N(0, 0.01^2)$\n", 44 | "* $\\Delta V_{TH} \\sim \\ N(0, 0.03^2)$\n", 45 | "\n", 46 | "We would like to calculate $P(I > 275)$ using direct Monte-Carlo and Importance 
Sampling." 47 | ] 48 | }, 49 | { 50 | "cell_type": "markdown", 51 | "metadata": {}, 52 | "source": [ 53 | "## Direct Monte-Carlo Estimation\n", 54 | "\n", 55 | "In MC estimation, we approximate an integral by the sample mean of a function of simulated random variables. In more mathematical terms,\n", 56 | "\n", 57 | "$$\\int p(x)\\ f(x)\\ dx = \\mathbb{E}_{p(x)} \\big[\\ f(x) \\big] \\approx \\frac{1}{N} \\sum_{n=1}^{N}f(x_n)$$\n", 58 | "\n", 59 | "where $x_n \\sim \\ p(x)$.\n", 60 | "\n", 61 | "A useful application of MC is probability estimation. In fact, we can cast a probability as an expectation using the indicator function. In our case, given that $A = \\{I \\ | \\ I > 275\\}$, we define $f(x)$ as\n", 62 | "\n", 63 | "$$f(x) = I_{A}(x)= \\begin{cases} \n", 64 | " 1 & I \\geq 275 \\\\\n", 65 | " 0 & I < 275 \n", 66 | " \\end{cases}$$\n", 67 | " \n", 68 | "Replacing in our equation above, we get\n", 69 | "\n", 70 | "$$\\int p(x) \\ f(x) \\ dx = \\int I(x)\\ p(x) \\ d(x) = \\int_{x \\in A} p(x)\\ d(x) \\approx \\frac{1}{N} \\sum_{n=1}^{N}I_{A}(x_n)$$" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": 2, 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "def monte_carlo_proba(num_simulations, num_samples, verbose=True, plot=False):\n", 80 | " \n", 81 | " if verbose:\n", 82 | " print(\"===========================================\")\n", 83 | " print(\"{} Monte Carlo Simulations of size {}\".format(num_simulations, num_samples))\n", 84 | " print(\"===========================================\\n\")\n", 85 | " \n", 86 | " num_samples = int(num_samples)\n", 87 | " num_simulations = int(num_simulations)\n", 88 | " \n", 89 | " probas = []\n", 90 | " for i in range(num_simulations):\n", 91 | " mu_1, sigma_1 = 0, 0.01\n", 92 | " mu_2, sigma_2 = 0, 0.03\n", 93 | "\n", 94 | " length = np.random.normal(mu_1, sigma_1, num_samples)\n", 95 | " voltage = np.random.normal(mu_2, sigma_2, num_samples)\n", 96 | "\n", 97 | " num = 50 * np.square((0.6 - voltage))\n", 98 | " denum = 0.1 + length\n", 99 | " I = num / denum\n", 100 | " \n", 101 | " true_condition = np.where(I >= 275)\n", 102 | " false_condition = np.where(I < 275)\n", 103 | " num_true = true_condition[0].shape[0]\n", 104 | " proba = num_true / num_samples\n", 105 | " probas.append(proba)\n", 106 | " \n", 107 | " if plot:\n", 108 | " if i == (num_simulations - 1):\n", 109 | " plt.scatter(length[true_condition], voltage[true_condition], color='r')\n", 110 | " plt.scatter(length[false_condition], voltage[false_condition], color='b')\n", 111 | " plt.xlabel(r'$\\Delta L$ [$\\mu$m]')\n", 112 | " plt.ylabel(r'$\\Delta V_{TH}$ [V]')\n", 113 | " plt.title(\"Monte Carlo Estimation of P(I > 275)\")\n", 114 | " plt.grid(True)\n", 115 | " plt.savefig(plot_dir + 'monte_carlo_{}.pdf'.format(num_samples), format='pdf', dpi=300)\n", 116 | " plt.show()\n", 117 | " \n", 118 | " \n", 119 | " mean_proba = np.mean(probas)\n", 120 | " std_proba = np.std(probas)\n", 121 | " \n", 122 | " if verbose:\n", 123 | " print(\"Probability Mean: {:0.5f}\".format(mean_proba))\n", 124 | " print(\"Probability Std: {:0.5f}\".format(std_proba))\n", 125 | " \n", 126 | " return probas" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": 3, 132 | "metadata": {}, 133 | "outputs": [ 134 | { 135 | "name": "stdout", 136 | "output_type": "stream", 137 | "text": [ 138 | "===========================================\n", 139 | "10 Monte Carlo Simulations of size 10000\n", 140 | "===========================================\n", 141 | "\n", 
142 | "Probability Mean: 0.00199\n", 143 | "Probability Std: 0.00050\n" 144 | ] 145 | } 146 | ], 147 | "source": [ 148 | "probas = monte_carlo_proba(10, 10000, plot=False)" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": 4, 154 | "metadata": {}, 155 | "outputs": [], 156 | "source": [ 157 | "def MC_histogram(num_samples, plot=True):\n", 158 | " \n", 159 | " num_samples = int(num_samples)\n", 160 | " \n", 161 | " mu_1, sigma_1 = 0, 0.01\n", 162 | " mu_2, sigma_2 = 0, 0.03\n", 163 | "\n", 164 | " length = np.random.normal(mu_1, sigma_1, num_samples)\n", 165 | " voltage = np.random.normal(mu_2, sigma_2, num_samples)\n", 166 | "\n", 167 | " num = 50 * np.square((0.6 - voltage))\n", 168 | " denum = 0.1 + length\n", 169 | " I = num / denum\n", 170 | " \n", 171 | " if plot:\n", 172 | " n, bins, patches = plt.hist(I, 50, density=1, facecolor='green', alpha=0.75)\n", 173 | " plt.ylabel('Number of Samples')\n", 174 | " plt.xlabel(r'$I_{DS}$ [$\\mu$A]')\n", 175 | " plt.title(\"Monte Carlo Estimation of P(I > 275)\")\n", 176 | " plt.grid(True)\n", 177 | " plt.savefig(plot_dir + 'mc_histogram_{}.pdf'.format(num_samples), format='pdf', dpi=300)\n", 178 | " plt.show()" 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": 5, 184 | "metadata": {}, 185 | "outputs": [ 186 | { 187 | "data": { 188 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYEAAAEQCAYAAABWY8jCAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3XuYHFWd//H3JObChCEhGA0wrqiDXwZRUFduGgiaRYVV3HX3ERIMQUMCQUCCEBRcwSgu8uRCfoZIDOHiAqJc9AFdf6tLjK4hwMJyi+0XRlgSLhkJSXTIpAkw2T/OmaTT6e6pSaZ6prs+r+eZJz1Vp6q+J91T3z6nqs5p2Lp1KyIikk2D+jsAERHpP0oCIiIZpiQgIpJhSgIiIhmmJCAikmFKAiIiGfam/g5ABgYzGw8sA05299sKlj8GPOzuU3q5v/cCe7v7bxOW/wxwHtAA7AFc5e63J9z2MmCtu38/QdnfAI1AZ8Hiq9z952XKTwOuB94DfNrdv5kkpjL7+gfgfqAL+Bd3n7Gr+6pwjJuAdwNT3P2PcdllwETgBWArMAz4mrv/Jq6/GrgKmAAc5O4X9/KYhwH/D3gDeBWYDOwLzC8odiTwGeAB4Engibj8LmABcANwprtv7s2xZfcpCUihPwKnALfBthP5iF3c12eBtUCPScDMjgbOB05091fMbB9gpZn9wd3/sIvHr2Ry9wkyga8BN7n7I8Aju3nc8wgnuj8CfZ4Aoo+7+1tLLJ/bnSTNrBW4GfiAmR0JvO7uz5nZThuZ2d8Co4FfuXu5h4quBs5x90fMbDowy91nAuPjPv4ZeMHdf2lmE4Bb3f2couPcAlwEXN77KsvuUBKQQo8C7zazUe6+ETiVcLL4GwAzmwR8mfBt7ylgGjAJOIHw7fpdwJXAr4ApwBYze5jwzf7bhG+KfwKmu/trBcc9A5jv7q8AuPvLZnY4sNHMmoFFwHBgH+Cb7v5TM3uC8I3yVcC7d2Rmc4CPxF9vcferk1TczMYQkt8gYAhwJnA4MBb4kZnNJ5zATzazNmAFcCBwLzAylnV3/7yZHQLMjfsaBZwL7A0cBtxkZqcSEsuRZvZ3wLeAPPAy8IVYbhawBXgHcJu7f7so3lLbXQHsbWY/c/eTKlR3NPBKfH0uMKdC2ReAzwGXm9ndwPXu/mJRmZMLlr0pxtQd5wjCif2YuOiDhOSzHPgzcG7c9tfAXDOb7e5dFeKRPqZrAlLsTuAfzKyBcGJbARC/nV8OfNTdPwJsBKbHbUa6+98DnwYudvfnCc37ucCDwA+Af3T3Y4HnCQmi0H7A04UL3H1D/OZ5EDDH3f8O+BJwdiyyJzDb3U/p3sbM/p5w0jySkAgmxtZMsZvM7DcFP2NiXf8CfJJwYtzL3a8jtGZOLtr+AOBSwontXOAa4AjgI2Y2itB1dIG7T4j/B6fH7qZHCF0lW2K8DcDigv+b5XG/AG8ntKaOInxD3qbcdrF7aX2ZBDAz1vU/gZmExAtwLNu7Znbi7i+4+4XAOGA18LiZfaqozIsxrqMJ79G8gtVfBH7i7uvi738EvhHj/imhGwl3f4OQFA4pF4ukQy0BKXYL4Zv308DvCpa/E1jl7h3x998CxxP6uLu7SdYQvrEXGkPoH/5x7G7YA/iPojLPAm8jtEQAMLMPA+3Ai8ClZvZFQn/2kILtnB21Ar+LyeM1M1sJHAw8XlRup+4gM/t3wjf7nwGvEb5ll/Oyu6+O223q7rIys7/E+j8PfN3MNgNNwF/L7OfNwF9j0oTwf3oFcA/wuLu/Drwe95Nku0q2dQcVGezur1ba0Mw+SEgaRugeu7dEmc8BlxC69F4qWDUJ+KeC3+9l+/WYu4DCaywvElp7UkVqCcgO3P1pwnWAc4F/K1j1DHBwbN5D+Ab5ZHxdqq+4i/D5Wgc8B5zk7uMJ3ULLispeD1zYvW8ze0tc1gjMJnSdfD5u11B0jEI5YleQmQ0BjiZ0WyUxHnjR3Y8nJIDuk2p3PQr1NODWAsK33dMICag75uJ9rQP2MrN94+89/Z8m2a63NpvZ4HIr47f+6cAP3P04d1/s7puKypxKaAGMj5+f7uUjgWHuvqag+BJCCwfgY8BDBev2JrQGpIrUEpBSbgM+7+5Pmtk7Adx9nZl9A1hmZl1AG3AxO3eVdHuIcMdJjnBB9
OdmNojwrXhyYUF3v8/MFgO/MrPXCK2Fr7r7Y2b2E2CBma0ltDTeXC5od7/HzMab2X3AUODH7v5wiaI3mVnh3UG3AT8GbjOzLxOuXXR/Q/0d8At6d8Hy34CfmVk7IQF2x7wCuIlwLQV332pmZwB3xv/TDYSusopdIhW22xW/Bz5A6LYrday7gbvLbRwTyAJCV9GdsbW33N2/QbhL6X+LNrkYWGpmM4BNwNS4n0FAM5DGjQBSQYNGERXJLjM7inBh97x+juME4APuXqkbTlKg7iCRDHP3+4A3xbuw+kW80D2RHS8oS5WoJSAikmFqCYiIZJiSgIhIhtXU3UFmNgz4EOF+4jf6ORwRkVoxmPC8zoPFz4XUVBIgJIDf9VhKRERKGQf8V+GCWksCLwLcfPPNjB07tuoHb2tro6WlperHTVs91kt1qh31WK+BVqe1a9cyadIkiOfQQrWWBN4AGDt2LM3N1b+jraOjo1+Om7Z6rJfqVDvqsV4DuE47daPrwrCISIYpCYiIZFgq3UFmtj9hjPL1hJEnF8blEwjjxjQAi9x9RRxkailhdqeVZvYJ4CTCOPHL3P1nacQoIiLptQSmAwvi+OYnxhEdIYxjPpUwgNasuOwSdpzq72xC8hgF/E9K8YmICOldGB5LGPERwgiHIwnD3za4e/eEGsMB3P2iOAdqt0MIMxmNJQzpu8OIkxCuvHd0dBQvTl0+nyeXy1X9uGmrx3qpTrWjHus10OrU3t5edl1aSWA1YVjYNYSp7DbG5XkzG0p4cCFfZttn2T5lXkktLS39cuU9l8vR2tpa9eOmrR7rpTrVjnqs10CrU1NTU9l1aSWBJYT5QqcQpiucZ2YXAPPjuiGEyUJK+R5wI+FWpu+mFJ+IiJBSEnD3tYShYYstjz/F5S8reH07cHsacUn1HHfjcWXXLTuteGIxEekvtfawmNSBcglCyUGk+vScgIhIhikJiIhkmJKAiEiGKQmIiGSYkoCISIYpCYiIZJiSgIhIhikJiIhkmJKAiEiG6YlhGTD0JLFI9SkJyG6pNEaQiAx86g4SEckwJQERkQxTEhARyTAlARGRDFMSEBHJMCUBEZEMUxIQEckwJQERkQxL5WExM9sfmAOsB1a5+8K4fAIwGWgAFrn7CjMbCSwFrnL3lbHcXsBK4KNx0noREUlBWi2B6cACd58BnGhmQ+LymcBUYBowKy67BOjs3tDMBgHfAf6UUmwiIhKlNWzEWGBNfL0BGAmsAxrcfQuAmQ0HcPeLzOyygm2/ASwGzi+387a2Njo6OlIIu7J8Pk8ul6v6cdO2O/Xq3NTZc6HdtCux1eN7VY91gvqs10CrU3t7e9l1aSWB1UAzIRGMBjbG5XkzGwoMBvLFG5nZGOBI4K3AUcBFhNbDDlpaWmhubk4n8gpyuRytra1VP27adqdejQ809nE0O9uV2OrxvarHOkF91mug1ampqansurSSwBJgrplNAe4E5pnZBcD8uG4IMLt4I3d/Cfg4gJndAHw3pfhERISUkkC8mDuxxKrl8ae4/GUllk3p88BERGQHukVURCTDlARERDJMk8rIgKcZx0TSo5aAiEiGKQmIiGSYkoCISIYpCYiIZJguDEsi5S7OikhtU0tARCTDlARERDJMSUBEJMOUBEREMkxJQEQkw5QEREQyTElARCTDlARERDJMSUBEJMOUBEREMkxJQEQkw5QEREQyLJUB5Mxsf2AOsB5Y5e4L4/IJwGSgAVjk7ivMbCSwFLjK3Vea2VnAh4A9gR+6+91pxCgiIum1BKYDC9x9BnCimQ2Jy2cCU4FpwKy47BKgs2Dbje7+BeAsYGJK8YmICOkNJT0WWBNfbwBGAuuABnffAmBmwwHc/SIzu6x7Q3e/1cz2JLQkvlNq521tbXR0dKQUenn5fJ5cLlf146YtSb06N3VWXN8fKsVcj+9VPdYJ6rNeA61O7e3tZdellQRWA82ERDAa2BiX581sKDAYyJfa0MwOAi4Fvu7uz5Qq09LSQnNzc58H3ZNcLkdra2vVj5u2JPVqfKCxStEkVynmenyv6rFOUJ/1Gmh1ampqKrsure6gJcA5ZnYtcCcwL57858d1S4HZxRuZWQNwD7AHMNvMLk4pPhERIaWWgLuvpXR//vL4U1z+soJfW9KISepPudnOlp22rMqRiNQu3SIqIpJhSgIiIhnWqyRgZm9LKxAREam+Hq8JmNm5wGZgFHC6mf3S3WemHpmIiKQuyYXhU4BjgV8C7wH+M9WIRESkapJ0B20F9gXa3X0r4b5/ERGpA0laAsuA3wKnmNk84I50QxIRkWrpMQm4+yXAJWa2NzCre9gHERGpfUkuDB8DXEMY6uEnZvasu1+XemQiIpK6JNcEvgUcA6wFrgBmpBqRiIhUTZIk0OXu64Gt7p4Hqj98p4iIpCJJEmgzs+8A+8QB3Z5NOSYREamSJHcHnUmYCOa/gFeAM1KNSPpVuUHZRKQ+lU0CZnZ8wa9Pxx+A8cB/pBiTiIhUSaWWwClFv28lzA28FSUBEZG6UDYJuPvp3a/N7P2AESaNf7wagYmISPp6vDBsZrOB7wFHANea2YWpRyUiIlWR5O6gTwLj3P18YBzwz+mGJCIi1ZIkCTwHdM9SPAQoP229iIjUlCS3iO4HPGlmjwIHA1vMbAWAux+dZnAiIpKuJElA3T8iInUqSRJ4K3AyMLx7gbtXHD/IzPYH5gDrCXcULYzLJwCTCbeaLnL3FWY2ElgKXOXuK0uV6X21REQkiSTXBG4EHgP+f8FPT6YDC2KyONHMhsTlMwlPH08DZsVllwCdBduWKiMiIilI0hJ4yt1v6OV+xwJr4usNwEhgHdDQPR+BmQ0HcPeLzOyygm13KlOsra2Njo7qj2OXz+fJ5XJVP27aCuvVuamzh9ID3xHXHEFXVxeDlu34HeeG427on4D6SBY+f/VioNWpvb38/TxJksAdZvYj4A/dC9z9mz1ssxpoJiSC0cDGuDxvZkMJcxPky2zbY5mWlhaam5sThN63crkcra2tVT9u2grr1fhAYz9H0zc6N3XSOGLHutT6e5eFz1+9GGh1ampqKrsuSRKYAdzJ9hN5EkuAuWY2JW47z8wuAObHdUOA2WW2TVJGRET6QJIksN7dr+zNTt19LTCxxKrl8ae4/GUFr0uWERGRvpckCawzs2uBhwmDx+Hui1ONSkREqiJJEmiL/46N/25NKRYREamyHpOAu19uZvsS+ugbCE8Qi4hIHegxCZjZdcBRwAhgD8LkMkemHJeIiFRBkofFWoH3EB4SO5jyt3aKiEiNSZIEOtx9KzDC3dcBQ1OOSUREqiRJEnjIzL4CvBAfGhucckwiIlIlSS4Mf83MmoDNhAlmHkg9KhERqYqySSCO2zMdWEAY+2cp8CrwYHVCExGRtFXqDloAvD2WWQg8CtwBLKpCXCIiUgWVksDb3X0m4fmAccCV7n4XMKYqkYmISOoqJYGu+O+HgQfc/bX4+x7phiQiItVS6cLwJjObBvwTcIuZDQK+QBgmWkRE6kCllsCZwLuAnxJmFxsPfAo4K/2wRESk
Gsq2BOKDYYXTO94bf0REpE4keVhMRETqVNkkYGYjqxmIiIhUX6ULw3cDx5jZInfXdYA6c9yNx2173bmps27mFhaR3qmUBDab2YPAgWZ2aFzWAGx196PTD01ERNJWKQl8kjCBzLWEO4IaqhKRiIhUTaW7g7qA58zsJGAaYU6BJ0kwbISZ7Q/MAdYDq9x9YVw+AZhMSCiLgJXAYqADGObuM8zs48BJMbbfu/uNu149ke0Ku8AKLTttWZUjERk4ktwddC3QAvwKOABYkmCb6cACd58BnGhmQ+LymcBUQlKZRXj24Gl3Px94ycyOJgxRcWg85pOJayIiIr2WZKL5A939mPj6p2a2IsE2Y4E18fUGwiik64AGd98C20YpLSz3HKH76dfAt4FG4Hrg08U7b2tro6OjI0EYfSufz5PL5ap+3DR0burc9rqrq2uH3+tBb+pUK+9pPX3+CtVjvQZandrb28uuS5IEhptZo7t3mtkeJJtUZjXQTDjBjwY2xuV5Mxsa95GP5cbFdc3AKuAKYELcpmR8LS0tNDc3Jwijb+VyOVpbW6t+3DQU3g3UuamTxhH1dXdQb+pUK+9pPX3+CtVjvQZanZqamsquS9IddDXwqJndBTwCzEuwzRLgHDO7FrgTmBdP/vPjuqXAbOD3wAFmdjUwyt1XEL793wJcl/BYIiKyi5LMLHazmf078E7gGXd/OcE2a4GJJVYtjz+FphdteyNhrCIREUlZku4g3H094U4fERGpIxo7SEQkw3pMAmb2lWoEIiIi1ZekJXCCmSW5I0hERGpMkmsCbwZeMLNngK1o7CARkbqRJAl8KvUoRESkXyRJAq8DVwJjgNuBx4Bn0wxKRESqI8k1gcWEh7uGAr8lPDwmIiJ1IEkSGO7u9xKuBThhuAcREakDSZLAq3F458FmdiRKAiIidSNJEpgGnE64S+grhAlmRESkDiQZO+g5M7sCeDfwhLs/k35YIiJSDUmeGL4UuAb4MHCdmX059ahERKQqEj0xDBwTZ/86Fjg53ZBERKRakiSBPxNm+YJwm+hL6YUjIiLVVPaagJndRxgm4i3AU2b2KHAw0ON8AiIiUhsqXRhWt4+ISJ0rmwTc/VkAMzuckBCGF6yekXJcIlVz3I3HlVy+7LRlVY5EpPqSjB10I2HsoA0pxyIiIlWWJAk85e43pB2IiIhUX5IkcIeZ/Qj4Q/cCd/9mpQ3MbH9gDmFe4lXuvjAunwBMBhqARcBKwgB1HcAwd59hZp8ATgJeBZa5+896XSsREUkkyS2iM4D/AdoLfnoyHVjg7jOAE81sSFw+E5hKGIpiFjAeeDo+g/CSmR0NnE1IHqPicUVEJCVJWgLr3f3KXu53LLAmvt4AjATWAQ3uvgXAzIYXlXsO2A84BPhcXPctQsthB21tbXR0dPQypN2Xz+fJ5XJVP24aOjd1bnvd1dW1w+/1oC/qNNDe63r6/BWqx3oNtDq1t5f/7p4kCawzs2uBhwnPDeDui3vYZjXQTDjBjwY2xuV5MxsKDCaMRroaGBfXNQOrCBPW5KnwPEJLSwvNzc0JQu9buVyO1tbWqh83DY0PNG573bmpk8YRjRVK156+qNNAe6/r6fNXqB7rNdDq1NTUVHZdkiTQFv8d24tjLgHmmtkU4E5gnpldAMyP64YAs4GHgM+b2dUA7r7CzL5HuCPpDeC7vTimlFDu9kcREUiWBK7v7U7dfS0wscSq5fGn0PSibW8nTGMpIiIpS5IEbiN0Aw0C3gE8BXwkzaBERKQ6kswncFT3azMbBVybakQiIlI1SW4RLfQX4F1pBCIiItXXY0ugYDTRBmAM8Ou0gxIRkepIck2gcDTRvLsneVhMRERqQKX5BHZ6SCsux91vSi8kERGplkotgeInHRqA04FOQElARKQOVJpP4Kvdr82sBbgBuAfQRPMiInUiyYXhswkn/vPd/Z70QxIRkWqpdE1gf8LTwuuBw91dk8qIiNSZSi2BJ4AtwL3AQjPbtsLdSw0JIVJXNO2kZEGlJPCZqkUhIiL9otKF4eKB3kREpM70dtgIERGpI0oCIiIZpiQgIpJhSgIiIhmmJCAikmFKAiIiGaYkICKSYUoCIiIZlmRSmV6L4w7NIYw7tMrdF8blE4DJhGGpFwErgcVABzDM3WfEcnvFdR9197VpxCgiIum1BKYDC+JJ/UQzGxKXzwSmAtOAWcB44Gl3Px94ycyONrNBwHeAP6UUm4iIRKm0BICxwJr4egMwElgHNLj7FgAzG15U7jlgP+AbhNbB+eV23tbWRkdHRzqRV5DP58nlclU/7u7o3NTZY5murq5E5WpJmnXqr89ALX7+kqjHeg20OrW3l58VOK0ksBpoJpzgRwMb4/K8mQ0FBgP5WG5cXNcMPAkcCbwVOAq4iNB62EFLSwvNzc0phV5eLpejtbV4wrWBrfGBxh7LdG7qpHFEz+VqSZp16q/PQC1+/pKox3oNtDo1NTWVXZdWElgCzDWzKcCdwDwzuwCYH9cNAWYDDwGfN7OrYdugdcsBzOwG4LspxSciIqSUBOLF3FJzDmw7yReYXmYfU/o4rLpWbux76XuaZ0DqiW4RFRHJMCUBEZEMUxIQEckwJQERkQxTEhARyTAlARGRDFMSEBHJMCUBEZEMUxIQEckwJQERkQxTEhARyTAlARGRDEtrFFGRzNHAclKL1BIQEckwJQERkQxTEhARyTAlARGRDFMSEBHJMCUBEZEMUxIQEcmwVJ4TMLP9gTnAemCVuy+MyycAk4EGYBGwElgMdADD3H2GmZ0FfAjYE/ihu9+dRoy1ShPKi0hfSqslMB1Y4O4zgBPNbEhcPhOYCkwDZgHjgafd/XzgJTM7Gtjo7l8AzgImphSfiIiQ3hPDY4E18fUGYCSwDmhw9y0AZja8qNxzwH7ufquZ7UloSXyn1M7b2tro6OhIKfTy8vk8uVyu6sct1Lmps8/32dXVlcp++9NAqtMR1xxRdt0Nx92QeD8D4fOXhnqs10CrU3t7e9l1aSWB1UAz4QQ/GtgYl+fNbCgwGMjHcuPiumZglZkdBFwKfN3dnym185aWFpqbm1MKvbxcLkdra2vVj1uo8YHGPt9n56ZOGkf0/X77U63UqTefp4Hw+UtDPdZroNWpqamp7Lq0uoOWAOeY2bXAncC8ePKfH9ctBWYDvwcOMLOrgVHAfcA9wB7AbDO7OKX4RESElFoC7r6W0v35y+NPoelFv7ekEZOIiOxMt4iKiGSYkoCISIYpCYiIZJiSgIhIhmlmMZF+pNnIpL+pJSAikmFKAiIiGaYkICKSYbomMEBptFARqQYlAZEBqNSXgM5Nndzfen8/RCP1TN1BIiIZpiQgIpJhSgIiIhmmawIiNUQPl0lfU0tARCTDlARERDJM3UH9TM8DSF9QN5HsKrUEREQyTC0BkTqmFoL0RElAJIOUHKRbKknAzPYH5gDrgVXuvjAunwBMBhqARcBKYDHQAQxz9xnFZdx9RRoxVpv6/qUWKDlkT1otgenAAndfYWa/MLPF7v4aMBP4DDAY+BFwNfC0u19hZpeb2dElypx
UsN/BAGvXrk0p7Mra29tpamqqWOaUO06pUjR9543Nb/D6a6/3dxh9SnXqW+OuHter8rd+9tbEZZP8XdWagVangnPm4OJ1aSWBscCa+HoDMBJYBzS4+xYAMxteVO45YL8SZQrtCzBp0qSUws6uDjr6O4Q+pzr1n49d87H+DkFK2xf4U+GCtJLAaqCZcIIfDWyMy/NmNpSQjfKxXPdXjGZgVYkyhR6M5V8E3kgpdhGRejOYkAAeLF7RsHXr1j4/mpmNBeYS+vr/G3gfcAFwFPBFYAjhmsFDwPeJJ3t3P8/Mji0s4+7/3ecBiogIkFISEBGR2qBbRMswsxbgJ+7+fjO7EDiAcG3jfGAoJe5+qgVF9VpKaCZuJdyttRn4KtAJ3OPud/VfpD2LNxKcB7wCPAvsCQwjvE/TAaOG6tOtRL0+BLwcV88G3kGN3UFnZgcC/wqsJXRJjKEO/qZK1OsYauxvSk8MlxC7s6YCm+LF6WPc/WzgOuAMtt/9NAM40cyG9F+0yRXWKy56b3z9CuF6zFeAC9z9i8BZ/RJk7+wNTI3xfgQY6e7nAcuBz1J79elWWK/xwLsIXaZ/Bv6XcAfdVGAaMKt/Quy1kYT348vAROrkb4od6/WP1ODflJJACe6+1t0vJryRowl/fLD9DqZSdz8NeIX1MrMG4ML4R/dL4FxgjLu/EIsP+H5Cd/85oS6XAL8nvD+w/X2qqfp0K6rXjcDp7n4m4UaKzxHvoHP3zUDxHXQDUry2twX4OfAb6udvqrBe91ODf1NKAj37M7BPfN0MvMD2u59gx7ufakkTcHB8/TKhOf68me0blzX0S1S9YGZNwBLCQ4dLgf3jqu73qabq062oXssJ3T+w/X3Km9lQM9uDne+gG5DM7DAg7+7HAx8E3hxX1fTfVFG9/hb4cFxVM39TujBcgZn90t0/YWbnEfqX9wbOBPag4O4nd/9BP4bZawX1WkD4djIKuIiQ7C4lfLO53d3v6ccwexSvaRxI6Dd/A3geGEH4FnlGXFcz9elWVK8G4HXCSWUvwrfLD1Fjd9CZ2eGEz9jLwF8JLYCa/5sqUa9h1NjflJKAiEiGqTtIRCTDlARERDJMSUBEJMOUBEREMkxJQEQkw5QEREQyTElAJDKz6Wb2/f6OQ6SalAREtnsf8HhvNzKzKWa22sxmJiw/y8xe7J40ycyWmNlGMzuot8cW2V1KAiLbvRd4bBe3vcXd5wKY2W/MzOLrfczsiaKykwhTp54M4O5TgUd28bgiu0VJQGS7Q9iFlkAJLcBT8fUOrQszG0+Y3u/7wNl9cCyR3aL5BEQAM3sb8Iq7bzSzBwkjQu4FLHP36+OUp3MIYxSNAi5x9+dL7OftwPPu3hUXvY8dWxdTgSXu7mb2qpkd4e73p1g1kYqUBESC9wGPx2Rwv7t/CcDM7jWzmwjzFbzk7t/sYT+HseNJ/4PAbXFfewMnAG8xs3MIA919iZBwRPqFuoNEgu7rAR8kzH3drRPoAu4DhpnZD83s1Ar7OZQ4xn+cdeoktncHnQpc5+7Hu/sngCOA481sTJ/WRKQXlAREgvcSTtbbkoCZHQqsdvet7r7Z3S8BTqNyX/5hwCAzexT4FyAXt4HQFfTD7oLu3gncQRj2WqRfqDtIBHD3SQBm9gtgHzN7ldD/Pysuv4Ywrv+ewBUVdvU+4P3u3lHiGIeWWDYj7v/43a2DyK5QEhAp4O4nlFk+o4dNJ5rZZqCrVAKoxMyWEFoQIlWnSWVERDJM1wRERDJMSUBEJMOUBEREMkxJQEQkw5QEREQyTElARCTDlARERDJMSUBEJMP+D8AbFf6rPhUDAAAAAElFTkSuQmCC\n", 189 | "text/plain": [ 190 | "
" 191 | ] 192 | }, 193 | "metadata": {}, 194 | "output_type": "display_data" 195 | } 196 | ], 197 | "source": [ 198 | "MC_histogram(1e6)" 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": 6, 204 | "metadata": {}, 205 | "outputs": [ 206 | { 207 | "name": "stdout", 208 | "output_type": "stream", 209 | "text": [ 210 | "Iter 1/4\n", 211 | "Iter 2/4\n", 212 | "Iter 3/4\n", 213 | "Iter 4/4\n" 214 | ] 215 | } 216 | ], 217 | "source": [ 218 | "num_samples = [1e3, 1e4, 1e5, 1e6]\n", 219 | "num_repetitions = 25\n", 220 | "\n", 221 | "total_probas = []\n", 222 | "for i, num_sample in enumerate(num_samples):\n", 223 | " print(\"Iter {}/{}\".format(i+1, len(num_samples)))\n", 224 | " probas = monte_carlo_proba(num_repetitions, num_sample, verbose=False)\n", 225 | " total_probas.append(probas)" 226 | ] 227 | }, 228 | { 229 | "cell_type": "code", 230 | "execution_count": 7, 231 | "metadata": {}, 232 | "outputs": [ 233 | { 234 | "data": { 235 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYYAAAETCAYAAAAyK6EVAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3Xl4VdW9//F3DFORiEWwCMEx+jVVayd/KoriLHprubfSgj4gVCSVtldxqK2116m3en0EcaRwsQ5txVZLvbUOHZWiiLWTrXj6tXFiUFKQwSCEKPD7Y63o2afJ4STknH0gn9fz5Dkne1rfk5Vnf8/aa+21K7Zs2YKIiEiLndIOQEREyosSg4iIJCgxiIhIghKDiIgkKDGIiEiCEoOIiCR0SzsA2fGY2XDgx8CLQAXQHZju7j82s48DZ7j7Ndtw/EnAXe7+btay8cBdwBHu/mxc1h14E7jN3a9qZxnHAGvc/a8Fbv9h4EZgf6ASWALUufvaAvdf7u4DC9huOB/8bVuscPdRbWy/J3Couz9sZtOBae6+uJCYWjlWP+BUd7/PzL4O/Nbdf9+RY0l5U2KQYvmtu48GMLM+wDwze8nd/wL8ZRuPfTlwL/BuzvK/A2OAZ+PvpwIFnZhb8UXgfqCgxADMAWa6+08BzGwKMBMY3cHy83n/b1uA44EDgYfd/cJtLPdjwBnAfe5+/TYeS8qYEoMUnbuvM7OZwJlmtivwJXcfbWavE07mGWAqMAvoBTQBk9x9iZldAYwk/K/OAN4DBhJO2iNzinoMOMXMdnL3zYQkMadlpZldTDhRvwf8zt0vM7OrgH2A3YG9gCnASkJS+aSZvQgcDlwEbAKecvevZxdqZnsBA1uSQnQL0Ceu/wrwH4SW09r4/ixC8tkJuDLrWJ8Abo1lNQHnFfoN38wmA+cAm4GngK/Hn95mtiB+hi/Fv0EN0B/oB9wBfA44ADjH3Rea2XXAp4EqIOPuE4BvAofGFttQQh38BvgesB+hpTTN3X9kZk8SvgAcDOwCjHL31wv5HJI+9TFIqTQQTkTZhgBnxW+yNwK3uPtx8f318SQ5gnBiHgp8lHASWk7r38SbgWeAY82sinBCWgpgZocAn4/HGQrsb2b/Fvfb6O4jgAuAKe7+R+Bx4GvAOuBq4AR3PxoYbGYn5ZQ7CHg1e4G7b3L3tWa2E7AbcKK7DyMkh8PiZqvd/Wh3/03Wrv8LfMXdjyWcsKe18jmPN7Mns34ujcsnABe4+5HAK4TLeNcTvuH/LOcYG9z9VGAucJq7fyZuO9rMdomxnRT/VkeY2WDgvwmtlVlZx6kDVrr7UOBE4Ntm1lLPv3f3E4FfEZK0bCfUYpBS2Yt4ks6y0t3fiu8PAS43s8sIJ7RmwAgnl03AesKJGzPLV859hJPQnoSTXo+4/EBgYUu/hJnNBw6K6/4cX5cQWizZaoABwKOx3CpgXzObHdetAC4EqrN3iv0bo+L1+GZgjpmti9t1j5t5K/EPipfbAH5HOFnnautS0gTgEjP7H0KCrGhlmxZ/iq9r+KC/YjXh828AdjezOYTE2Ccr5ly1wK8B3L0xtrD2i+uy/65b7T+R8qEWgxRd/PZ+HvBAzqrNWe//Dlzm7sMJ30IfjMs+aWY7mVl3M/uVmfWM+7X1v/skcAQwKh4j+/iHm1k3M6sAjgFeiutamzCspYxXCSe2k2JstwLPuvtEdx/u7qPcfRmw0sw+m7X/BcBIM/sYMNLdvwB8NR6zIquMXG/EfQCOzYqxEOcRLtMdC3yC8G2/rb9VvknSRgBD3H0MoT/nQzHm1o6VAYbB+/V8CB+0njQR23ZKLQYpluPjdeZNhP+zK93dzWyPNra/BJhhZr0IJ6IL3P0vZvY48DThhDTD3TfGb/uPmtlx7p44+bj7ZjP7FeHE9nZL68Ld/2ZmP8461lPAQ8ChbcTzLOHb+hcIl3PmmVkl8BphVFCuscDtZnYJoZXyMuFE/S7wjpn9AdhIGCU1qM2/Wtjntpi83gPObWWblr9tthHA34DnzGwFsCx+hreBb5rZnyjc74FvmdnCGPMrMeaXgUPMLLsTexbwv2b2FKHernb3f26lVSdlrkKzq4qISDZdShIRkQQlBhERSShJH0Mc6jYVWAUscvfb4/ITgXGEjq0ZwELCNctGoKe7TzazU4HPEq51PuHu/1eKmEVEuqpStRjqCGPUJwOnx6F8EG64mQhMAi4DhgOvuPsUYIWZDQW+TEgou/LB8DcRESmSUo1KGkgY8gdhrHRfwt2lFe7eDBBHo2Rvt5QwEuJgwsiQgcC3CS2M98Xhi4cRRntsKuqnEBHZMVQCewDPufvG3JWlSgyLCTf2LCHcgr8mLm8ysx4xyKa43bC4rhpYBLwe171F6w4D5hcnbBGRHdowwtDthJIMVzWzgYSx4I3AHwiTcV0MHEkYp92d0AfxR+C7hESAu19gZmcS+hg2ATe6+ws5x94PqP/hD3/IwIHtv7myvr6empqaDn4yKQbVSXlSvZSfjtbJ8uXLOfvsswFq3P3l3PUlaTG4+3LCpGG55sWfbHU
5+z5I8g7WXJsABg4cSHV1dZ7NWtfY2Nih/aR4VCflSfVSfjqhTlq9/K7hqiIikqDEICIiCUoMIiKSoMQgIiIJSgwiIpKgxCAiIglKDCIikqDEICIiCUoMIiKSoMQgIiIJSgwiIpKgxCAiIglKDCIikqDEICIiCUoMIiKSoMQgIiIJSgwiIpKgxCAiIglKDCIikqDEICIiCUoMIiKSoMQgIiIJ3UpRiJkNBqYCq4BF7n57XH4iMA6oAGYAC4FZQCPQ090nm9k5wBhgOfCEu99TiphFRLqqUrUY6oBb3H0ycLqZdY/LLwImApOAy4DhwCvuPgVYYWZDgWOAZYTksbBE8YqIdFklaTEAA4El8f1qoC+wEqhw92YAM+uVs91SYBDwPeC5uM9s4LOtFVBfX09jY2O7A2tqaiKTybR7Pyke1Ul5Ur2Un47WSUNDQ971pUoMi4Fqwkm/H7AmLm8ysx5AJdAUtxsW11UDi4CjgWeAtwmthlbV1NRQXV3d7sAymQy1tbXt3k+KR3VSnlQv5aejdVJVVZV3fakuJc0GvmpmM4G5wE0xIUyP674HXAs8DextZjcDu7r7AmAFcCeh7+G6EsUrItJllaTF4O7LgbNaWTUv/mSry9n3LuCuIoUmIiI5NFxVREQSlBhERCRBiUFERBKUGEREJEGJQUREEpQYREQkQYlBREQSlBhERCRBiUFERBKUGEREJEGJQUREEpQYREQkQYlBREQSlBhERCRBiUFERBKUGEREJEGJQUREEpQYREQkQYlBREQSlBhERCRBiUFERBK6laIQMxsMTAVWAYvc/fa4/ERgHFABzAAWArOARqCnu0+O2+0S1x3v7stLEbOISFdVqhZDHXBLPNGfbmbd4/KLgInAJOAyYDjwirtPAVaY2VAz2wm4Dni5s4Natmo9kx5azLJV6zv70CIi262StBiAgcCS+H410BdYCVS4ezOAmfXK2W4pMAi4ktCKmJKvgPr6ehobG9sV1KSHFrNk7XuMnjmfWSP3bNe+UjxNTU1kMpm0w5Acqpfy09E6aWhoyLu+VIlhMVBNOOn3A9bE5U1m1gOoBJridsPiumrgJeAI4CPAkcDXCK2Mf1FTU0N1dXW7grr/I3sxeuZ87q8bxuB+vdu1rxRPJpOhtrY27TAkh+ql/HS0TqqqqvKuL1VimA1MM7PxwFzgJjO7GJge13UHrgX+CIw1s5sB3H0eMA/AzO4GbujMoAb3682skXsqKYiIZClJYogdxme1sur9E3+WujaOMb6TwxIRkVZouKqIiCQoMYiISIISg4iIJCgxiIhIghKDiIgkKDGIiEiCEoOIiCQUdB+DmZ0A7As8C7zk7k1FjUpERFKz1cRgZt8hTE9RCzQD3wDGFDkuERFJSSGXko5293HAOne/B9inyDGJiEiKCkkM3eLMp1vMrBLYVOSYREQkRYX0MdxEmNxuAKGPYVpRIxIRkVQVkhieAY4GaoBXgd2KGpGIiKSqzcRgZgcDg4H/ITwHAUJSuB74ePFDExGRNORrMXwYGE14SE7LKKTNwB3FDkpERNLTZmJw9/nAfDP7pLv/qYQxiYhIigrpY6g2s+sIT1mrAPq7+yHFDUtERNJSyHDV/wKuIjyv+R7g+WIGJCIi6SokMbzl7s8AuPvdwJCiRiQiIqkqJDFsNLNjgO5mdgqwR5FjEhGRFBWSGM4n9C98G5gEfKuoEYmISKoK6XxeTrh/oQ9wA7ClvYWY2WBgKrAKWOTut8flJwLjCJ3aM4CFwCygEejp7pPN7N+BzxGm4pju7n9ub/kiIlK4QloMjwI3E2ZV/Xr8aa864BZ3nwycbmbd4/KLgImElshlwHDgFXefAqwws6GERDQBuA349w6U3aZ/rm3i4seW8c+1mkVcRKRFIS2GXu5+7DaWM5AwqglgNdAXWAlUuHszQJyoL3u7pcAgd3/QzI4lJKcL2yqgvr6exsbGdgV18WPLePGfGxl/51NMHTG4XftK8TQ1NZHJZNIOQ3KoXspPR+ukoaEh7/pCEsPvYqfz+6W7++J2xrGY8EyHJUA/YE1c3mRmPYBKoCluNyyuqwYWxYcE/Rb4FPA48GRrBdTU1FBdXd2uoO4etA/j73yKu889mt379mrXvlI8mUyG2tratMOQHKqX8tPROqmqqsq7vpDE8BFgOh+czLcAQ9sZx2xgmpmNB+YCN5nZxfG4swmd29cSZnEda2Y3A7j7AjObSLh/Ygvwg3aWm9fufXsxdcRgJQURkSyFJAZz9236muDuy4GzWlk1L/5kq8vZdzYheYiISAkUkhj+ZmZHAH8mjkhq6RcQEZEdTyGjko4B5gB/Bzy+ihTHxnX0/+tM2Lgu7UhEuqytthjc/WOlCEQEgKen0z9zFwzoD8dfkXY0Il1Svgf13ObuXzGzZ8i5qc3d29v5LFKYoy5k5YqVDDiqzZHJIlJk+VoM18bXcUB2n0K/4oUjXV7PPqz8WB0DevZJOxKRLitfH0OFmR0AfB/oAfQEPgTMLEVgIiKSjnwthiOACwAjJIMKwqM9f1GCuKSreq+ZPsvmw/77QbceaUcj0iXle7TnQ8BDZnaauz9awpikK3v5N1Q//Q0YMgRsRNrRiHRJhQxXXWdmp5rZaWb2spm1dqOaSOfY7wSWHnUd7HdC2pFIlvXN67l/yf2sb16fdihSAoUkhhuAfwD/CRwFfKmoEUnXtqmZXm+9CJt0D2U5uXPRncx9cy53Lroz7VCkBApJDBuABuC9OLVFz+KGJF1ay30MT09POxLJMmb/MdjOxpj9x6QdipRAIYnhbeDXwI/N7MuEGVBFiuOoC1lZOwF0H0NZmfOPOfg7zpx/zEk7FCmBQuZK+jywn7u/aGYHowntpJgqe9C020ehUiOSysm4A8fxtyV/Y9yB49IORUqgzRaDmU0AcPeNhKGquPsLdOwJbiKFaRmV9PJv0o5EsixcvpCFqxeycPnCtEOREsh3KWls1vtbs95v69PcRNo25HDeHnwsDDk87UgkywsrX2ALW3hh5QtphyLRluZmeO658NrJ8t75XMB7kc618A52WfobWHhH2pFIlr122SvxKulbPXcufOe68NrJ8iWGLQW8F+lc6nwuSwveWJB4lfQ13Hpb4rUz5et83s3MTiIkj35mdjKhtaBJ9KR4NjXTo3Gx7mMoM5d+6lJebHiRSz91adqhSItulcnXTpSvxfAnwuM4RxOe3jYm671IcTxyUbiU9MhFaUciWe7z+1i2cRn3+X1phyIt1r6dfO1E+eZKmtDppYlszenTeHvt2/Q9fVrakUiWdze/m3iV9O188km887OH2fnkkzr92IXc4CZSOpU9aK7aU/cxiGxFxYamxGtnKuQGt21mZoOBqcAqYJG73x6Xn0h4EFAFMANYCMwCGoGe7j7ZzM4HDgP6AN9394dLEbOkRI/2LEtb4piTLRp7UjYGXXsNL61/h0HXXtPpx95qi8HMbjWzj29jOXXALe4+GTjdzLrH5RcBE4FJwGXAcOAVd58CrDCzocAad/8icD6hz0N2ZBqVVJZ6VvZMvEr6KnfdFS
65JLx2skIuJT0CXG5mC8zsfDPbpQPlDASWxPergb7xfYW7N7v7BqBXznZLgUHuPsfM+hBaHNd1oOw2LVu1nkkPLWbZKk0lXDbWNdD3tUdgXUPakUiWcw86lyM/fCTnHnRu2qFICWz1UpK7Pw48bmYDgJuBG83sAeBKd3+9wHIWA9WEk34/YE1c3mRmPYBKoCluNyyuqwYWmdmBwBXAt9z91bYKqK+vp7GxscBwgkkPLWbJ2vcYPXM+s0bu2a59pTj2+9kZdN/wT5pnn8rLZ/ws7XAkemblMyxcvZCf/OEnHNn/yLTDkaipqYlMJtPu/Roa8n/x2mpiMLNaYDzwGeAJ4Oi430+ATxcYx2xgmpmNB+YCN5nZxcD0uK47cC3wR2Csmd0c93uG8CyI54FrzewFd7++tQJqamqorq4uMJzg/o/sxeiZ87m/bhiD+/Vu175SJLs9zLuzT6bHhIep3f3AtKOR6NHnHmULW1jdazW1tbVphyNRJpPpUH1UVVXlXV9I5/NsQofwVfGSDwBmdlehQcTnOLTWPzAv/mSry/m9ptBy2mtwv97MGrmnkkI5Wf0q3d59B1a/CkoMZeOg/gdRQQUH9T8o7VCkBArpY3jM3e9pSQpmdh1Ay8gikU6lR3uWpRP2PIFL97+UE/ZUvXQFbbYYzOxcwoihWjM7LS6uJFz2+UYJYpOuaMMq+mW+D4efAVUD045Gou6V3fn0hz9N98ruW99Ytnv5Wgw/IEyD8eP4OgY4E1DPkxTPA+fQ+62/wgPnpB2JSJeVLzEc4u6vETqZLf7UoucxSDGNuof1u30MRt2TdiQiXVa+zucTgD8QJs7LtgX4ZdEikq7tQ/1YVTuWnT+kSXxF0pIvMdwU7zHIHSUkUjwtj/YcMgRsRNrRiHRJ+RKD868P5amIy/YtWkTStcVRSUM0KkkkNfmm3d6nlIGIANCtB+sGD4Numl1VJC35hqve5u5fMbNnyGk5uPvQokcmIiKpyHcp6dr4mtv5LCIiO7A2h6u6e8ssS5XANODnwPXA5hLEJSIiKSlkSow7CfMlHQ3cB3yvqBGJiEiqCplEb5O7PxbfP2xmeoKKiMgOLF/n88nx7Ttm9jXgd8D/A/QEFRGRHVi+FsOY+LqKMBVGy6TfG4sakYiIpCrffQwTWltuZnsULxwREUlbIU9wuxqYDPQAegMvAXpah4jIDqqQUUkjCM9f/iHhctKyokYkIiKpKiQxvOXuG4Eqd68ntBpERGQHVUhiWGpmXySMTroO2KXIMYmISIoKuY+hjnAp6QFgPJoiQ0Rkh1ZIYvgwcAFwAPAC6mMQEdmhFZIY7iXMk3QvMAy4BxjZnkLMbDAwlXBPxCJ3vz0uPxEYR3jOwwxgITALaAR6uvvkuN2pwER3P7M95YqISPsV0sfQy91nuPvz7n4b0LcD5dQBt8QT/elm1j0uvwiYCEwCLgOGA6+4+xRghZkNNbPjgBqgTwfKzeu1FesY+8BrvLZiXWcfWkRku5VvSowD4tuVZjYKmE+YEuPVDpQzEFgS368mJJeVQIW7N8fyeuVstxQY5O4PAk+Y2b/lK6C+vp7GxsZ2BTX2gddYuX4zZ86Yz/dH7d2ufaV4mpqayGQyaYchOVQv5aejddLQkH9mo3yXkmZmvZ8MnM8Hj/Zsr8WEDuwlQD9gTVzeFJ8rXQk0xe2GxXXVwKJCC6ipqaG6urpdQT14/hDOnDGfB88fxt4DOr1BIh2UyWSora3d+oZSUqqX8tPROqmqqsq7Pt+UGMe1vDez3YD9CJd5VrY7ijBt9zQzGw/MBW4ys4uB6XFdd8KDgf4IjDWzm2MMCzpQVsH2HtCH74/aW0lBRCRLIVNijAK+DWSAg83sKnf/QXsKcfflwFmtrJoXf7LVtXGMU9tTpoiIdEwhnc8XAZ9y95HAJwhDV0VEZAdVSGLY7O7rANy9kdAXICIiO6hC7mN42cymEh7UcwzwcnFDEhGRNBXSYpgIvAKcFF/PK2pEIiKSqkJaDD9395O3vpmIiOwICkkMa8zsDMIDejYDuPtLRY1KRERSU0hiGABMyfp9C3B8ccIREZG05U0MZrYLcLq7ry9RPCIikrI2O5/N7CvA88DzZnZK6UISEZE05RuVdBZgwJHAhaUJR0RE0pYvMTS5e3OcG6lHqQISEZF0FXIfA4RZVUVEpAvI1/l8kJndR0gKLe8BcPfWJsQTEZEdQL7E8Pms998tdiAiIlIe8j2PIXc6bBER6QIK7WMQEZEuQolBREQSlBhERCRBiUFERBKUGEREJEGJQUREEgqZdnubmdlgYCqwCljk7rfH5ScC4wg30c0AFgKzgEagp7tPzt3G3ReUImYRka6qVC2GOuAWd58MnG5m3ePyiwiPDp0EXAYMB15x9ynACjMb2so2nWbZqvVMemgxy1ZpVnERkRYlaTEAA4El8f1qoC+wEqhw92YAM+uVs91SYFAr27Sqvr6exsbGdgU16aHFLFn7HqNnzmfWyD3bta8UT1NTE5lMJu0wJIfqpfx0tE4aGhryri9VYlgMVBNO+v2ANXF5k5n1ACqBprjdsLiuGljUyjatqqmpobq6ul1B3f+RvRg9cz731w1jcL/e7dpXiieTyVBbW5t2GJJD9VJ+OlonVVVVedeX6lLSbOCrZjYTmAvcFE/20+O67wHXAk8De5vZzcCusT8hd5tOM7hfb2aN3FNJQUQkS0laDO6+nPDgn1zz4k+2upx9W9tGRESKRMNVRUQkQYlBREQSlBhERCRBiUFERBKUGEREJEGJQUREEpQYREQkQYlBREQSlBhERCRBiUFERBKUGEREJEGJQUREEpQYREQkQYlBREQSlBhERCRBiUFERBKUGEREJEGJQUREEpQYREQkQYlBREQSlBhERCShWykKMbOxwDFAb+Bqd38pa90NQE+gL1AHGPANYD3wc3f/qZn1Am4EnnL3+0sRs4hIV1WqFsM57n4ecAVwSctCM9sX6OvuFwDzgM/F9Re7+7nA+XHTi4DNJYpVRKRLK0qLwcwmAWdlLdoQX5cCg7KWD4zLWtYdCgxw9zfisi0A7v4dMxufr8z6+noaGxvbHWtTUxOZTKbd+0nxqE7Kk+ql/HS0ThoaGvKuL0picPdZwKyW383skfi2Gngja9MlwOCcdcvMbA93fxOoKLTMmpoaqqur2x1rJpOhtra23ftJ8ahOypPqpfx0tE6qqqryri9JHwNwr5nNBnYBLjWzwcC57n6Nma0ys5sJfQznAX8BpppZM3BLieITEZGoJInB3X8E/Chn8TVx3eU5y18keRmq5Rh3FyU4ERFJ0HBVERFJUGIQEZEEJQYREUlQYhARkQQlBhERSVBiEBGRBCUGERFJUGIQEZEEJQYREUlQYhARkQQlBhERSVBiEBGRBCUGERFJUGIQEZEEJQYREUlQYhARkQQlBhERSVBiEBGRBCUGERFJUGIQEZEEJQYREUnoVopCzGwscAzQG7ja3V/KWncD0BPoC9QBBnwDWA/8HHgMmAlsAPoDk9x9VSniFhHpikqSGIBz3P1EM9uHcNKfBGBm+
wJ93b3OzCYAnwNOBS529zfM7JfAc8Bsd59vZhcDhwG/yDp2JcDy5cs7FFhDQwNVVVUd/VxSBKqT8qR6KT8drZOs82Vla+uLkhjMbBJwVtaiDfF1KTAoa/nAuKxl3aHAAHd/Iy7b4u5LgaVmdlhcf3NOcXsAnH322Z33AUREuoY9gJdzFxYlMbj7LGBWy+9m9kh8Ww28kbXpEmBwzrplZraHu78JVMT96+L6L7r7eznFPQcMA94ENnXyRxER2RFVEpLCc62trNiyZUvRIzCzLwAnAbsAlwLvAee6+zVm9h1gZ0Ifw3nA/sAVQDPwINAI3Ac8EQ93h7svKHrQIiJdVEkSg4iIbD80XFVERBJKNSppu2BmnwGGABXufnva8UgQBzM0u/vdacciQRxFuCewyN0fTDseATM7BTgYWOXud23LsdRiyOLuDwNvEe63kDJgZscA76Qdh/yLA4G1hL5AKQ8nAd0J57BtosSQxcwOcPcfAf3SjkXedzLhpsdhaQciCbe6+3Tg2LQDkff1dvfrgaO39UC6lJR0mJn9Bx/cWyEpc/crzGxvYHjKoUjSGWa2M/BK2oHI+35tZpcAy7b1QF1qVJKZ1QAPuPsnzGwwMBVYRbhOqj6FFKhOypPqpfyUsk66zKUkMxsITOSD69V1wC3uPhk43cy6pxZcF6U6KU+ql/JT6jrpMonB3Ze7+9eBdXHRQMKd1wCrCTfYSQmpTsqT6qX8lLpOukxiaMViwjQbEDqb16QYiwSqk/Kkeik/Ra2Trtz5PBuYZmbjgbmtzMEkpac6KU+ql/JT1DrpUp3PIiKydV35UpKIiLRCiUFERBKUGEREJEGJQUREEpQYREQkQYlBREQSlBhERCShK9/gJtshMxsOPAQc4u5L4rLrgb939EE+cfbW+939iE4KM/vYlcCjhOeaf8bdV8flA4DvAn2ACuB14D/dfUMnl3834bM93pnHlR2bWgyyPWoG7jKzirQDKcAeQH93P7olKUSXAr9y91Pc/WTC5GhfSiVCkRxqMcj26LeELzVfBm5rWZj7zd/MFgKjgfFADdCfMK/MHcDngAOAc4DlwAAz+xmwO/CIu19rZkOAWUAvoAmYBFQCLU/6e9Tdb8gq/2zgQmAj8I+4/SxgfzOb6e51WZ/hdeBMM6sHngYuAbbE41wHfBqoAjLuPsHMrirgMzwAvEmYQ+cxd/9mVmzdCS2U/ePf7gp3f9LM/hs4Pi6bEx++I12cWgyyvTofmGJm+xe4/QZ3PxWYC5zm7p8BrickDgiXdMYCRwEjzOxQ4EbC1MbHxffXx20HAifnJIXdgKuB4939aMKkZnXAZODFnKQAMAO4j9ByeAP4KTDIzHYBVrv7ScBQ4Ig4934hn2FvQhI8DDjezD6ZVd5EYKW7HwN8FmiZv38ccBZwDNCpl7Fk+6XEINsld3+L8O3ChnBFAAABuElEQVT8btr+P86+1PSn+LoGeDG+X01oDQA87+5r3X0T8HvCN/FDgMvN7EngvwitCYBX3T33Wcf7Eh6Y0hh//x1wUJ6PcBxwr7ufQkg0vwemE07Ou5vZHGAmIWG1zLVfyGdYFT/Ds4RHorY4BDgtfpafAN1iMhsNXAf8Atg1T7zShSgxyHbL3R8GnPAtGcLlnt3NrNLMdgX2ydp8a7NF1ppZHzPrBhwOLAL+Dlzm7sMJ3/4fjNtubmX/V4GPxsddQngW8kt5yrsAmBA/x8ZY3kZgBDDE3ccAlwMf4oMEV8hn6B07vA/ng+RB/Cxz4mcZQbjstA4YBYwhXE4ab2Z7baUM6QKUGGR7dyHxEoi7Lwd+BTxHuLZf347jrAJ+BCwAHnT3FwnX/a80s3nAvcBf29rZ3VcCVwJPxL6N/oTLRW35EuHJW382swWEfoJLCC2HfeMxHiQ8U3lQgZ+hmXDCfxb4P3d/PmvdTODA+FkWAK/HhLQK+Auh3+aXhHn+pYvTtNsiO4BiDrmVrkctBhERSVCLQUREEtRiEBGRBCUGERFJUGIQEZEEJQYREUlQYhARkQQlBhERSfj/cZy9Ynz9BsMAAAAASUVORK5CYII=\n", 236 | "text/plain": [ 237 | "
" 238 | ] 239 | }, 240 | "metadata": {}, 241 | "output_type": "display_data" 242 | } 243 | ], 244 | "source": [ 245 | "# plt.figure(figsize=(8, 10))\n", 246 | "y_axis_monte = np.asarray(total_probas)\n", 247 | "x_axis_monte = np.asarray(num_samples)\n", 248 | "\n", 249 | "for x, y in zip(x_axis_monte, y_axis_monte):\n", 250 | " plt.scatter([x] * len(y), y, s=0.5)\n", 251 | "\n", 252 | "plt.xscale('log')\n", 253 | "plt.title(\"Direct Monte-Carlo Estimation\")\n", 254 | "plt.ylabel(\"Probability Estimate\")\n", 255 | "plt.xlabel('Number of Samples')\n", 256 | "plt.grid(True)\n", 257 | "plt.savefig(plot_dir + 'monte_carlo_convergence_speed.jpg', format='jpg', dpi=300)\n", 258 | "plt.show()" 259 | ] 260 | }, 261 | { 262 | "cell_type": "markdown", 263 | "metadata": {}, 264 | "source": [ 265 | "## Importance Sampling" 266 | ] 267 | }, 268 | { 269 | "cell_type": "markdown", 270 | "metadata": { 271 | "collapsed": true 272 | }, 273 | "source": [ 274 | "With importance sampling, we try to reduce the variance of our Monte-Carlo integral estimation by choosing a better distribution from which to simulate our random variables. It involves multiplying the integrand by 1 (usually dressed up in a “tricky fashion”) to yield an expectation of a quantity that varies less than the original integrand over the region of integration. Concretely,\n", 275 | "\n", 276 | "$$\\mathbb{E}_{p(x)} \\big[\\ f(x) \\big] = \\int f(x)\\ p(x)\\ dx = \\int f(x)\\ p(x)\\ \\frac{q(x)}{q(x)}\\ dx = \\int \\frac{p(x)}{q(x)}\\cdot f(x)\\ q(x)\\ dx = \\mathbb{E}_{q(x)} \\big[\\ f(x)\\cdot \\frac{p(x)}{q(x)} \\big]$$\n", 277 | "\n", 278 | "Thus, the MC estimation of the expectation becomes:\n", 279 | "\n", 280 | "$$\\mathbb{E}_{q(x)} \\big[\\ f(x)\\cdot \\frac{p(x)}{q(x)} \\big] \\approx \\frac{1}{N} \\sum_{n=1}^{N} w_n \\cdot f(x_n)$$\n", 281 | "\n", 282 | "where $w_n = \\dfrac{p(x_n)}{q(x_n)}$" 283 | ] 284 | }, 285 | { 286 | "cell_type": "markdown", 287 | "metadata": {}, 288 | "source": [ 289 | "In our current example above, we can alter the mean and/or standard deviation of $\\Delta L$ and $\\Delta V_{TH}$ in the hopes that more of our sampling points will fall in the failure region (red area). 
For example, let us define 2 new distributions with altered $\\sigma^2$.\n", 290 | "\n", 291 | "* $\\Delta \\hat{L} \\sim \\ N(0, 0.02^2)$\n", 292 | "* $\\Delta \\hat{V}_{TH} \\sim \\ N(0, 0.06^2)$" 293 | ] 294 | }, 295 | { 296 | "cell_type": "code", 297 | "execution_count": 8, 298 | "metadata": {}, 299 | "outputs": [], 300 | "source": [ 301 | "def importance_sampling(num_simulations, num_samples, verbose=True, plot=False):\n", 302 | " \n", 303 | " if verbose:\n", 304 | " print(\"===================================================\")\n", 305 | " print(\"{} Importance Sampling Simulations of size {}\".format(num_simulations, num_samples))\n", 306 | " print(\"===================================================\\n\")\n", 307 | " \n", 308 | " num_simulations = int(num_simulations)\n", 309 | " num_samples = int(num_samples)\n", 310 | " \n", 311 | " probas = []\n", 312 | " for i in range(num_simulations):\n", 313 | " mu_1, sigma_1 = 0, 0.01\n", 314 | " mu_2, sigma_2 = 0, 0.03\n", 315 | " mu_1_n, sigma_1_n = 0, 0.02\n", 316 | " mu_2_n, sigma_2_n = 0, 0.06\n", 317 | " \n", 318 | " # setup pdfs\n", 319 | " old_pdf_1 = norm(mu_1, sigma_1)\n", 320 | " new_pdf_1 = norm(mu_1_n, sigma_1_n)\n", 321 | " old_pdf_2 = norm(mu_2, sigma_2)\n", 322 | " new_pdf_2 = norm(mu_2_n, sigma_2_n)\n", 323 | "\n", 324 | " length = np.random.normal(mu_1_n, sigma_1_n, num_samples)\n", 325 | " voltage = np.random.normal(mu_2_n, sigma_2_n, num_samples)\n", 326 | " \n", 327 | " # calculate current\n", 328 | " num = 50 * np.square((0.6 - voltage))\n", 329 | " denum = 0.1 + length\n", 330 | " I = num / denum\n", 331 | " \n", 332 | " # calculate f\n", 333 | " true_condition = np.where(I >= 275)\n", 334 | "\n", 335 | " # calculate weight\n", 336 | " num = old_pdf_1.pdf(length) * old_pdf_2.pdf(voltage)\n", 337 | " denum = new_pdf_1.pdf(length) * new_pdf_2.pdf(voltage)\n", 338 | " weights = num / denum\n", 339 | "\n", 340 | " # select weights for nonzero f\n", 341 | " weights = weights[true_condition]\n", 342 | "\n", 343 | " # compute unbiased proba\n", 344 | " proba = np.sum(weights) / num_samples\n", 345 | " probas.append(proba)\n", 346 | " \n", 347 | " false_condition = np.where(I < 275)\n", 348 | " if plot:\n", 349 | " if i == num_simulations -1:\n", 350 | " plt.scatter(length[true_condition], voltage[true_condition], color='r')\n", 351 | " plt.scatter(length[false_condition], voltage[false_condition], color='b')\n", 352 | " plt.xlabel(r'$\\Delta L$ [$\\mu$m]')\n", 353 | " plt.ylabel(r'$\\Delta V_{TH}$ [V]')\n", 354 | " plt.title(\"Monte Carlo Estimation of P(I > 275)\")\n", 355 | " plt.grid(True)\n", 356 | " plt.savefig(plot_dir + 'imp_sampling_{}.pdf'.format(num_samples), format='pdf', dpi=300)\n", 357 | " plt.show()\n", 358 | " \n", 359 | " \n", 360 | " mean_proba = np.mean(probas)\n", 361 | " std_proba = np.std(probas)\n", 362 | " \n", 363 | " if verbose:\n", 364 | " print(\"Probability Mean: {}\".format(mean_proba))\n", 365 | " print(\"Probability Std: {}\".format(std_proba))\n", 366 | " \n", 367 | " return probas" 368 | ] 369 | }, 370 | { 371 | "cell_type": "code", 372 | "execution_count": 9, 373 | "metadata": {}, 374 | "outputs": [ 375 | { 376 | "name": "stdout", 377 | "output_type": "stream", 378 | "text": [ 379 | "===================================================\n", 380 | "10 Importance Sampling Simulations of size 10000\n", 381 | "===================================================\n", 382 | "\n", 383 | "Probability Mean: 0.002236025696537488\n", 384 | "Probability Std: 0.00015573742910298226\n" 385 | ] 386 | } 387 
| ], 388 | "source": [ 389 | "probas = importance_sampling(10, 10000, plot=False)" 390 | ] 391 | }, 392 | { 393 | "cell_type": "code", 394 | "execution_count": 10, 395 | "metadata": {}, 396 | "outputs": [], 397 | "source": [ 398 | "def IS_histogram(num_samples, plot=True):\n", 399 | " \n", 400 | " num_samples = int(num_samples)\n", 401 | "\n", 402 | " mu_1_n, sigma_1_n = 0, 0.02\n", 403 | " mu_2_n, sigma_2_n = 0, 0.06\n", 404 | "\n", 405 | " length = np.random.normal(mu_1_n, sigma_1_n, num_samples)\n", 406 | " voltage = np.random.normal(mu_2_n, sigma_2_n, num_samples)\n", 407 | "\n", 408 | " # calculate biased current\n", 409 | " num = 50 * np.square((0.6 - voltage))\n", 410 | " denum = 0.1 + length\n", 411 | " I = num / denum\n", 412 | "\n", 413 | " if plot:\n", 414 | " n, bins, patches = plt.hist(I, 50, density=1, facecolor='green', alpha=0.75)\n", 415 | " plt.ylabel('Number of Samples')\n", 416 | " plt.xlabel(r'$I_{DS}$ [$\\mu$A]')\n", 417 | " plt.title(\"Importance Sampling of P(I > 275)\")\n", 418 | " plt.grid(True)\n", 419 | " plt.savefig(plot_dir + 'is_histogram_{}.pdf'.format(num_samples), format='pdf', dpi=300)\n", 420 | " plt.show()" 421 | ] 422 | }, 423 | { 424 | "cell_type": "code", 425 | "execution_count": 11, 426 | "metadata": {}, 427 | "outputs": [ 428 | { 429 | "data": { 430 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYEAAAEQCAYAAABWY8jCAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAH9ZJREFUeJzt3XuYXFWZ7/FvE3OxQ5sQRBpSjpmx9aVHBRVHMBogEhRFDzg6g1wMoJGWxsghXoLCjAfiwYPncEnGgIkhEB2UOQLqA3qYM0qMDiEGFVFC+UobNAmh28Qk0HRTBOjMH2s1VCpd3bs7vauqa/8+z9NPKnvty3pr17PfWmvvWqthz549iIhINh1Q7QqIiEj1KAmIiGSYkoCISIYpCYiIZJiSgIhIhikJiIhk2EuqXQGpXWY2A7jV3Y9N+TjHAbvc/TdpHice6wDg/wBvAPqA3cBF7r5xlI9zLnAEcB3wz+7ePpr7j8f4BvBa4Fx3/11c9j+AM4GtwB5gIvAFd/9JLF8M/G9gDnCEu18yzGO+EfgX4HngGWAucBghzn7HAqcB64HfAw/F5d8FlgA3A59w96eHc2xJh5KA1IKPArcCqScB4GTgcHc/CcDMTgOuBU5N42Du3gmMegKI3u3uhw6w/Bp3/xqAmbUCtwBvNrNjgefcfYuZ7bORmb0FmAb8h7uX+wHRYmC+u//azNqAhe6+ADgh7uMfgK3ufreZzQG+7e7zS47zLeBzwOXDD1lGm5KAJGJmPwEeBF4PPAX8DHg3MBV4F+EieirwMuDlwBXufruZnQR8CSgAfyFc8N8IXEX4Fv4jwoX5zWb2MPDfgL8HxgNPxNdnAu8FGoFXA1e5+81mdgzhotQAPAacBbQQvm029B/P3Z8oCmUz8BYzOx34MfB94Icxxg8BF8ZtAT4U4/084VvvK4GvAe8EjgIWu/sNsd4/A14H7ADOKHrfZhBbU2b2G2ANcCThW/qpwJPAUuAtQCfw18D73f2PRfsY6D28EjjIzL7v7oMlsGmE8wXwKeDqQdbdCpwOXG5mdwI3ufvjJet8uGjZS2Kd+us5mXBhPy4uOppwXtcAfwY+Fbf9EXCNmS1y975B6iMVoHsCMhzr3f1EQhdDb/w2/TBwfCw/EDiJkBSuMbPxwHLg7939eMIF8LK47iR3n+XulwN3E74ZbgEOBua4+yxCIvi7uP4Ud38fIUn0d2EsB85z92MIF5ZW4OvAhe5+AuHi/rniANz9t8DHCd0VG4BfAG+Lxa8FTonbOiHJAeSADwIXxPp/BHgP0BbLG4Fb3P0dwO+Klpd6GeGb8fGEpPWeGM/B7v5W4GOERPMCM2sY6D2M3Us7yiSABWb2EzP7MbAgxgvhPD00wPr9781Wd/8sMAvYBPzWzN5fss7jsV4zgU8SWlH9PgZ8x923x///DvhirPf3CN1IuPvzhKTw+nJ1kcpRS0CG41fx312Eiz/ATmBSfL0mfrPrMrOdQDPwpLs/Fst/SvgGexfhIrsXd+8zs93At83sKcLFd3ws/nX8d3PR8Q5193zc9np4ofvj+tjdMZ7QJ/0CMzsyrO5nxAvsScD/NbNmwoVpVTz2EcB9cbOH3P1ZM9sF/MHdd8f4+uvxrLv/NL5eS7i4929b6oGSOGb0r+vu28zsdyXrv5yB38PBvNAdVGKcuz8z2IZmdjQhaRjwBeCeAdY5HbiUkDC3FRWdRWg99bsH6I2vvwtcUVT2OCHhS5WpJSDDMdRAU0cDmNmhhG+9W4GXmdlhsfx4XrwoF3cD9AEHxAv0ae5+OjCf8Pns75oZ6Nhbzew18ZgLzewDhOQyN36b/xzwg5Jt5gBfNrNxsd97A9AT63s58GFgHvD0EMcuNt7Mjoqv3x73WU7pvh4itkTM7CBCa6TYdsq/h8P1tJmNK1cYv/W3AV9399nuvtzde0rWOZvQAjih+Ga6mU0BJrr75qLVVxBaUAAnAr8sKjuIkHSlytQSkNHUHLsgpgDt7v68mX0cuMPM+githnPZtxvg58D/IvSl95jZLwh98I8Dhw9yvDZgZdz344QnVDYB3yi62H2sZJslhKeDHjCzJwkJ6COEvvl7Ca2dnljXw4FHE8a+0Mz+Kh7/MoruCwzhB8B7zGwt4Z5AL/Bsf6G77ynzHo7EvcCbgfsHKnT3O4E7y20c39MlhBjviK2tNe7+RULy+mPJJpcQzk874T2dF/d
zAKGV9zBSdQ0aRVRGQ/8jkcN95LAemNkfCbEXhlh1oG2PAN7o7rea2cGEVsSrhuq2GQkzexvhxu5Fo73vYdbjvcCb3f1L1ayHBOoOEqmuzcAZZraOcIN8YRoJAMDd7wNeYma5NPafRLwPcyZ731CWKlJLQEQkw9QSEBHJMCUBEZEMG1NPB5nZRMKPhx4njF0iIiJDG0cY4+n+0ntOYyoJEBLAz6pdCRGRMWoW8J/FC8ZaEngc4JZbbqG5uZmOjg5aWlqqXaf9pjhqT73EojhqS7Xi6Ozs5KyzzoJ4DS021pLA8wDNzc3kcjm6u7vJ5ar2tNuoURy1p15iURy1pQbi2KcbXTeGRUQyTElARCTDUukOMrPphHHLdwAb3H1pXD6HMBNRA3ADsI4wTG43YfCpdjN7N2Gc9ZcA97r7qjTqKCIi6bUE2oAlcczzU+K48hDGNp8HnA8sJMxGtNHdLwa2xTHKZxEm7Ghh5KMliohIAmndGG4mjIkCYdTDKYQhcRvcfTeAmU0qWW8LYdTGHwH/kzBRx02ESTf20tHRQXd3N4VCgXw+n1IIlaM4ak+9xKI4aku14ujq6ipbllYS2EQYKnYzYXq7XXF5wcwmEH64UIjrzYplOcIIilcSxnzfVa5+LS0t5HI58vk8ra2tKYVQOYqj9tRLLIqjtlQrjqamprJlaXUHrQDmm9ky4A7g2njxvy6WrQQWEcY3n2Fmi4Gp7r6W8O3/W8CNaKRBEZFUpdIScPdOwnCxpdbEv2J7zccabwSP2ZvBs1fNHnD56nNWV7gmIiJD0yOiIiIZpiQgIpJhSgIiIhmmJCAikmFKAiIiGaYkICKSYWNtKOkxS4+OikgtUktARCTDlARERDJMSUBEJMOUBEREMkxJQEQkw5QEREQyTElARCTDlARERDJMSUBEJMOUBEREMkxJQEQkw5QEREQyTElARCTDlARERDIslaGkzWw6cDWwA9jg7kvj8jnAXKABuAFYBywHuoGJ7t5uZv8CNAHTgU53/0gadRQRkfRaAm3AEndvB04xs/Fx+QJgHnA+sBA4Adjo7hcD28xsprvPj+V/AS5MqX4iIkJ6k8o0A5vj653AFGA70ODuuwHMbFLJeluAw+Pr84Dvu/uTA+28o6OD7u5uCoUC+Xw+pRBGprend1jr5/P5moxjJOolDqifWBRHbalWHF1dXWXL0koCm4Ac4QI/DdgVlxfMbAIwDijE9WbFshywIb5+H3BauZ23tLSQy+XI5/O0tramUP2Ra1zfOKz1W1tbazKOkaiXOKB+YlEctaVacTQ1NZUtS6s7aAUw38yWAXcA18aL/3WxbCWwCLgXmGFmi4Gp7r7WzCYDT7v78ynVTUREolRaAu7eCZw5QNGa+FesrWTbHuAf06iXiIjsTY+IiohkmJKAiEiGKQmIiGSYkoCISIYpCYiIZJiSgIhIhikJiIhkmJKAiEiGKQmIiGSYkoCISIYpCYiIZJiSgIhIhikJiIhkmJKAiEiGKQmIiGSYkoCISIYpCYiIZJiSgIhIhikJiIhkmJKAiEiGpTLRvJlNB64GdgAb3H1pXD4HmAs0ADcA64DlQDcw0d3bzexk4FTgGWC1u38/jTqKiEh6LYE2YIm7twOnmNn4uHwBMA84H1gInABsdPeLgW1mNhO4kJA8pgIPpFQ/EREhpZYA0Axsjq93AlOA7UCDu+8GMLNJJettAQ4HXg+cHsu+RGg57KWjo4Pu7m4KhQL5fD6lEEamt6d3WOsfc/0x9PX1ccDqvfPxzbNvHsVaVUYtno+RqpdYFEdtqVYcXV1dZcvSSgKbgBzhAj8N2BWXF8xsAjAOKMT1ZsWyHLAB+FMs+0u5nbe0tJDL5cjn87S2tqYTwQg1rm8c9ja9Pb00Tt57u1qLK4laPB8jVS+xKI7aUq04mpqaypal1R20AphvZsuAO4Br48X/uli2ElgE3AvMMLPFwFR3Xwt8FVgFLAa+klL9RESElFoC7t4JnDlA0Zr4V6ytZNvbgNvSqJeIiOwtre6gujZ71exqV0FEZFQMqzvIzF6ZVkVERKTyhmwJmNmngKcJj2yeZ2Z3u/uC1GsmIiKpS9IddAZwPHA38Drgx6nWSEREKiZJd9Ae4DCgy933EB75FBGROpCkJbAa+ClwhpldC9yebpVERKRShkwC7n4pcKmZHQQs7P/Fr4iIjH1JbgwfB1xP+JXvd8zsT+5+Y+o1ExGR1CW5J/Al4DigE7gSaE+1RiIiUjFJkkCfu+8A9rh7gTDss4iI1IEkSaDDzL4MHGxmlxAGeBMRkTqQJAl8gnDh/0/gKeDjqdZIREQqpuyNYTN7V9F/N8Y/CBPB/P8U6yQiIhUy2NNBZ5T8fw9hWsg9KAmIiNSFsknA3c/rf21mbwKMMF/wbytRMRERSd+Q9wTMbBFhopdjgGVm9tnUayUiIhWR5Mbwe4BZcTL4WcA/pFslERGplCRJYAvQP0HleKD8jMUiIjKmJBlA7nDg92b2IPC3wG4zWwvg7jPTrJyIiKQrSRJQ94+ISJ1KkgQOBT4MTOpf4O6Djh9kZtOBq4EdhCeKlsblc4C5hEdNbwDWAcsJQ1FMdPd2MzuH8HhqJ7Da3VcNNygREUkmyT2BVcBvgH8v+htKG7AkJotTzGx8XL4AmAecDywk/PBsY7zpvM3MZhIGq3uMkCjWJQ9FRESGK0lL4BF3v3mY+20GNsfXO4EpwHagoX8+AjObVLLeFsL9h5XA/XGbFcCppTvv6Oigu7ubQqFAPp8fZtX2X29P76jur6+vb599ViOu/VWt85GGeolFcdSWasXR1VX+eZ4kSeB2M7sVeLh/gbtfMcQ2m4Ac4QI/DdgVlxfMbAJhboJCXG9WLMsBG4B3APcBTxJaA/toaWkhl8uRz+dpbW1NEMLoalzfOKr76+3ppXHy3vusRlz7q1rnIw31EoviqC3ViqOpqalsWZLuoHbgAcKjof1/Q1kBzDezZcAdwLXx4n9dLFsJLALuBWaY2WJgqruvBbYBNxLuFXw5wbFERGSEkrQEdrj7VcPZqbt3AmcOULQm/hVrK9n2JuCm4RxPRERGJkkS2B6/0f+KMHgc7r481VqJiEhFJEkCHfHf5vjvnpTqIiIiFTZkEnD3y83sMMKQEQ2EJ3hERKQODJkEzOxG4G3AZOClhMlljk25XiIiUgFJng5qBV5H+JHY3xIe7RQRkTqQJAl0u/seYLK7bwcmpFwnERGpkCRJ4Jdm9hlga/zR2LiU6yQiIhWS5MbwF8ysCXiaMMHM+tRrJSIiFVE2CcSxfdqAJYRxfFYCzxDG9RERkTowWHfQEuBVcZ2lwIPA7YQhoEVEpA4MlgRe5e4LCL8PmAVc5e7fBQ6pSM1ERCR1gyWBvvjv24H17v5s/P9L062SiIhUymA3hnvM7HzgQ8C3zOwA4KOE4Z9FRKQODNYS+ATwauB7hNnFTgDeD1yQfrVERKQSyrYE4g/DFhYtuif+iYhInUjyYzEREalTZZOAmU2pZEVERKTyBmsJ3AlgZvpdgIhInRrs6aCnzex+4DVmdlRc1gDscf
eZ6VdNRETSNlgSeA9hApllhCeCGipSIwFg9qrZAy5ffc7qCtdEROrZYE8H9QFbzOxU4HzCnAK/J8GwEWY2Hbga2AFscPelcfkcYC4hodwArAOWA93ARHdvj+u9LJa9M05aLyIiKUjydNAyoAX4D2AGsCLBNm3AknhRP8XMxsflC4B5hKSykPDbg43ufjGwzcxmxh+lfRn4wzDiEBGREUgy0fxr3P24+Pp7ZrY2wTbNwOb4eidhFNLtQIO774YXRiktXm8Lofvpi4TWwcXldt7R0UF3dzeFQoF8Pp+gOqOrt6d3VPfX19eXeJ/ViDepap2PNNRLLIqjtlQrjq6urrJlSZLAJDNrdPdeM3spySaV2QTkCBf4acCuuLxgZhPiPgpxvVmxLEfobjoWOJQwr/HnCK2HvbS0tJDL5cjn87S2tiaozuhqXN84qvvr7emlcXKyfVYj3qSqdT7SUC+xKI7aUq04mpqaypYlSQKLgQfN7CHCHMNfTLDNCuAaMzsXuAO41sw+DVwXy8YDi4BfAh8xs8UA7r4GWANgZjcDX0lwLBERGaEkM4vdYmb/D/gb4FF3/0uCbTqBMwcoeuEiX6StzD7OHeo4IiKyf5K0BHD3HYQnfUREpI5o7CARkQwbMgmY2WcqUREREam8JC2B95pZkieCRERkjElyT+DlwFYzexTYg8YOEhGpG0mSwPtTr4WIiFRFkiTwHHAVcAhwG/Ab4E9pVkpERCojyT2B5cBKYALwU8KPx0REpA4kSQKT3P0ewr0AJwz3ICIidSBJEnjGzN4NjDOzY1ESEBGpG0mSwPnAeYSnhD5DmGBGRETqQJKxg7aY2ZXAa4GH3P3R9KslIiKVkOQXw5cB1wNvB240s/+eeq1ERKQiEv1iGDguzv51PPDhdKskIiKVkiQJ/Bnon/FkArAtveqIiEgllb0nYGb3EYaJeAXwiJk9SJhUZsj5BEREZGwY7Mawun1EROpc2STg7n8CMLO3EhLCpKLi9pTrJSIiFZBk7KBVhLGDdqZcFxERqbAkSeARd7857YqIiEjlJUkCt5vZrcDD/Qvc/YrBNjCz6cDVhHmJN7j70rh8DjAXaABuANYRBqjrBia6e7uZfQD4IPA8cJ27PzDsqEREJJEkj4i2Aw8AXUV/Q2kDlrh7O3CKmY2PyxcA8whDUSwETgA2xt8gbDOzmYQnks4Dvgp8IHkoIiIyXElaAjvc/aph7rcZ2Bxf7wSmANuBBnffDWBmk0rW2wIc7u63mdnxhCGr9etkEZEUJUkC281sGfArwrd03H35ENtsAnKEC/w0YFdcXjCzCcA4wmikm4BZsSwHbDCzE4F7gKOBu4GflO68o6OD7u5uCoUC+Xw+QQijq7end1T319fXl3if1Yg3qWqdjzTUSyyKo7ZUK46urvIdOEmSQEf8t3kYx1wBXGNm5wJ3ANea2aeB62LZeGAR8EvgI2a2GMDd15rZPMITSXuAfx1o5y0tLeRyOfL5PK2trcOo1uhoXN849ErD0NvTS+PkZPusRrxJVet8pKFeYlEctaVacTQ1NZUtS5IEbhruAd29EzhzgKI18a9YW8m2KwiJQkREUpYkCfwb4Vv5AcBfA48A70izUiIiUhlJ5hN4W/9rM5sKLEu1RiIiUjFJHhEt9gTw6jQqIiIilTdkS6BoNNEG4BDgR2lXSkREKiPJPYHi0UQL7p7kx2J1Yfaq2dWugohIqgabT2BumeW4+zfSq5KIiFTKYC2B0odZGwjDOfQCSgIiInVgsPkEPt//2sxagJuBu9BQDiIidSPJjeELCRf+i939rvSrJCIilTLYPYHphF8L7wDe6u6aVKYGlLtZvfqc1RWuiYjUg8FaAg8BuwmDuS01sxcK3H2gISFERGSMGSwJnFaxWoiISFUMdmO4dKA3ERGpM8MdNkJEROqIkoCISIYpCYiIZJiSgIhIhikJiIhkmJKAiEiGKQmIiGSYkoCISIYlmVRm2OK4Q1cTxh3a4O5L4/I5wFzCsNQ3AOuA5UA3MNHd283sAuDvgAOBb7r7nWnUUURE0msJtAFL3L0dOMXMxsflC4B5wPnAQuAEYKO7XwxsM7OZwC53/yhwAaAxikREUpRKSwBoBjbH1zuBKcB2oMHddwOY2aSS9bYAh7v7t83sQEJL4ssD7byjo4Pu7m4KhQL5fD6lEKC3pze1fRfr6+vb72Ol+T4klfb5qKR6iUVx1JZqxdHVVX5W4LSSwCYgR7jATwN2xeUFM5sAjAMKcb1ZsSwHbDCzI4DLgH9y90cH2nlLSwu5XI58Pk9ra+kEaKOncX1javsu1tvTS+Pk/TtWmu9DUmmfj0qql1gUR22pVhxNTU1ly9LqDloBzDezZcAdwLXx4n9dLFsJLALuBWaY2WJgKnAfYfaylwKLzOySlOonIiKk1BJw904G7s9fE/+KtZX8vyWNOomIyL70iKiISIYpCYiIZJiSgIhIhikJiIhkmJKAiEiGKQmIiGSYkoCISIal9YthqbDZq2YPuHz1OasrXBMRGUvUEhARyTAlARGRDFMSEBHJMCUBEZEMUxIQEckwJQERkQxTEhARyTAlARGRDFMSEBHJMCUBEZEMUxIQEckwJQERkQxLZQA5M5sOXA3sADa4+9K4fA4wF2gAbgDWAcuBbmCiu7fH9U4G5rn7h9Kon4iIBGm1BNqAJfGifoqZjY/LFwDzgPOBhcAJwEZ3vxjYZmYzzWw20AIcmFLdREQkSmso6WZgc3y9E5gCbAca3H03gJlNKllvC3C4u98GrDaz95XbeUdHB93d3RQKBfL5fEohQG9Pb2r7LtbX15fasdJ8f0qlfT4qqV5iURy1pVpxdHV1lS1LKwlsAnKEC/w0YFdcXjCzCcA4oBDXmxXLcsCGJDtvaWkhl8uRz+dpbW0d1YoXa1zfmNq+i/X29NI4OZ1jpfn+lEr7fFRSvcSiOGpLteJoamoqW5ZWd9AKYL6ZLQPuAK6NF//rYtlKYBFwLzDDzBYDU919bUr1ERGRAaTSEnD3TuDMAYrWxL9ibWX2cfJo1yuLNOOYiAxGj4iKiGSYkoCISIYpCYiIZJiSgIhIhikJiIhkmJKAiEiGKQmIiGRYWr8YHlPKPUsvIlLvlAQySj8iExFQd5CISKYpCYiIZJiSgIhIhikJiIhkmJKAiEiG6ekg2YueGhLJFrUEREQyTElARCTD1B0kiaibSKQ+qSUgIpJhagnIfiluIfT29NK4vhFQC0FkrEglCZjZdOBqYAewwd2XxuVzgLlAA3ADsA5YDnQDE929vXQdd1+bRh0lXYMNyqcEIVI70moJtAFL3H2tmf3QzJa7+7PAAuA0YBxwK7AY2OjuV5rZ5WY2c4B1Ti3a7ziAzs5OALq6umhqakpcqTNuP2O/A0vD808/z3PPPlftauy3pHHMWjxrWPv99ge/PdIqjdhwP1u1SnHUlmrF0X/NJF5Di6WVBJqBzfH1TmAKsB1ocPfdAGY2qWS9LcDhA6xT7DCAs846K6VqV0833dWuwqhII44Trz9x1PcpklGHAX8oXpBWEtgE5AgX+GnAr
ri8YGYTCNmoENfr/1qYAzYMsE6x++P6jwPPp1R3EZF6M46QAO4vLWjYs2fPqB/NzJqBawh9/b8AjgQ+DbwN+BgwnnDP4JfA14gXe3e/yMyOL17H3X8x6hUUEREgpSQgIiJjw5h8RLTc00e1Lt74vgh4CvgTcCAwkXDPpA0w4PNAL3CXu3+3SlVNxMxuAe4EXgnMIMRxMTCBMXB+zGwG8E/Anwmt1qmMwfNhZjngCsL9twZCd+kMxt75aAG+4+5vMrPPMkQMpeu4+7aqVLxESRxfJVxnDyX0hhxIyWfKzL5C0efO3Z+pZH3H6o/F+p8+agdOMbPx1a5QQgcB89z9Y8A7gCnufhGwBvgg8Bng07H8gupVc2hmtoCQzACOc/cLgRuBjzN2zs+nCQ8kHAZ0MXbPxxHAO4G/IjyAMebOR+xCngf0xAdChoqhaYB1qq4kjgOBu939E8C/AidR8pkys79h389dRY3VJDDQ00c1z91/ADxlZpcC9xIuQPDik1GHuPvWuKxm++nM7P2Em/33ET5Df45F/XGMlfPTAnyPcAE5mzF6Pgjv9TuBfwRmExIBjKHz4e6d7n4J4YvFNIb+TB00wDpVVxyHuz/l7nfFlsHpwLfY9zPVzL6fu4oaq0mg/+kj2Pvpo5oWv72sIPxIbiUwPRblgK3AY2Z2WFzWUPkaJnY28FbgHMIF9BVxeX8cY+X8dAJPxt+wwNg9H58kfJvcAzwBvCouH2vno9+fgYPj63IxbB1gnZpjZqcRuoDPdfdu9v1MbWbfz11Fjckbw6VPH7n716tcpUTMbCXwGsL9gOeBx4DJhG9mH49llwG7gdvc/a4qVTURMzuX8GTXoYT+84OATwAvZQycHzNrBS4nXHTuA17HGDwfZvZmQhybgD8S6jvmzgeAmd3t7ieb2UUMEUPpOu7+RLXqXcrM7iYk5/uAf4+LvwM8QslnysyupOhzV/SlpCLGZBIQEZHRMVa7g0REZBQoCYiIZJiSgIhIhikJiIhkmJKAiEiGKQmIiGSYkoBIZGZtZva1atdDpJKUBERedCTw2+FuZGbnmtmmOJ5SkvUXmtnj/ZMmmdkKM9tlZkcM99gi+0tJQORFbwB+M8Jtv+Xu1wCY2U/MzOLrg83soZJ1zyJMnfphAHefB/x6hMcV2S9KAiIvej0jaAkMoIUwPACUtC7M7ATC9H5fAy4chWOJ7JcxOZ+AyGgzs1cSRn7cZWb3Az8HXgasdveb4pSnVxPGfJoKXOrujw2wn1cBj7l7X1x0JHu3LuYBK9zdzewZMzvG3X+eYmgig1ISEAmOBH4bk8HP3f2TAGZ2j5l9gzD/wzZ3v2KI/byRvS/6RwP/Fvd1EPBe4BVmNp8wYNgnCQlHpCrUHSQS9N8POJow93W/XqCPMBrkRDP7ppmdPch+jgL6b/i+BjiVF7uDzgZudPd3ufvJwDHAu8zskFGNRGQYlAREgjcQLtYvJAEzOwrY5O573P1pd7+UMIfCYH35bwQOMLMHgX8G8nEbCF1B3+xf0d17gdupkVmxJJvUHSQCuPtZAGb2Q+BgM3uG0P+/MC6/HniOMEfslYPs6kjgTXECkdJjHDXAsva4/3ftbwwiI6EkIFLE3d9bZnn7EJueaWZPA30DJYDBmNkKQgtCpOI0qYyISIbpnoCISIYpCYiIZJiSgIhIhikJiIhkmJKAiEiGKQmIiGSYkoCISIYpCYiIZNh/AVB0O5PtsEYzAAAAAElFTkSuQmCC\n", 431 | "text/plain": [ 432 | "
" 433 | ] 434 | }, 435 | "metadata": {}, 436 | "output_type": "display_data" 437 | } 438 | ], 439 | "source": [ 440 | "IS_histogram(1e5)" 441 | ] 442 | }, 443 | { 444 | "cell_type": "code", 445 | "execution_count": 12, 446 | "metadata": {}, 447 | "outputs": [ 448 | { 449 | "name": "stdout", 450 | "output_type": "stream", 451 | "text": [ 452 | "Iter 1/4\n", 453 | "Iter 2/4\n", 454 | "Iter 3/4\n", 455 | "Iter 4/4\n" 456 | ] 457 | } 458 | ], 459 | "source": [ 460 | "num_samples = [1e3, 1e4, 1e5, 1e6]\n", 461 | "num_repetitions = 25\n", 462 | "\n", 463 | "total_probas = []\n", 464 | "for i, num_sample in enumerate(num_samples):\n", 465 | " print(\"Iter {}/{}\".format(i+1, len(num_samples)))\n", 466 | " probas = importance_sampling(num_repetitions, num_sample, verbose=False)\n", 467 | " total_probas.append(probas)" 468 | ] 469 | }, 470 | { 471 | "cell_type": "code", 472 | "execution_count": 13, 473 | "metadata": {}, 474 | "outputs": [ 475 | { 476 | "data": { 477 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYEAAAETCAYAAADQ97psAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAG+hJREFUeJzt3Xl4HeV59/GvkGRcsIDYEIwtJ4AFt9UASVgaAjZm35yFrGUpBAhgMEkTKMWBlyRNncbAVdaEgHkpARqWtoSmDZClgWB4MSRsgWDE7cgm2AakegUZkOVF7x/Po/hIlo6OhOackZ/f57p8zTDbuccPnt+ZZ+bMVHV2diIiImnaqtIFiIhI5SgEREQSphAQEUmYQkBEJGEKARGRhCkEREQSVlPpAkQKmdmuwD3ufmDGn3MIsNrdX8jyc+JnbQX8M7A3sBHoAL7m7ouG+HNOByYB1wLfcvcZQ7l92TIpBCRVZwL3AJmHAHAsMM7djwIwsxOAa4BPZ/Fh7t4CKACkJAoByS0zewR4HtgLWAM8BhwD7AAcTTiIfhrYDtgR+Ed3/4mZHQV8F2gHVhAO+B8BriB8C/814cC8r5m9BHwK+CxQC7wZx08Gjge2ASYCV7j7bWb2MeA6oAp4DTgFaACuj9NWAGe6+5sFu7IE2N/M/hp4CPgv4MG4j58Hzo/rAnw+7u8lwFpgAnATcDjwYeA6d78x1v0Y8CFgJXBSwd/brsSzKTN7AZgL7AN0xr+vt4AbgP2BFmA34JPu/qcSmkW2MLomIHn3O3c/AtgaeCd+m34JmBrnjwKOIoTC1WZWC9wMfNbdpxIOgJfFZUe6+xR3/w7wC+BiYCkwBjjS3acQguCAuPz27v4JQkh8I067GTjD3T9GCJNG4P8C57v7oYSD+8WFO+DufwDOBk4A5gNPAx+Ps/cEpsV1nRByAPXA54DzYv2nAscB0+P8bYA73X0y8HLB9J62A+6OfxevxW18Chjj7n8FfJkQNJIohYDk3bNxuJpw8AdYBYyM43PdfaO7t8bpY4G33P21OP9RwrdlCAfZbty9q4/+bjP7F8LBtzbO/n0cLin4vJ3dvSmu+0N3f5YQBD+MZy5nAuMKP8PM9gmL+0mxvkuAfzezKuB/gdvN7EeEb+tdn/2iu6+L+73Q3Tt67Pc6d380js8DrLe/vOi5HvvRCDwR92EZIUQkUQoBybv+Hm61H4CZ7Uz41vs6sJ2Z7RLnTwUWxPGNBettBLaKB+gT3P2vga8S/k10dc309tmvm9ke8TNnmtlnCOFyWvw2fzHwQI91jgRmm1m1u3cSzgbejvV+BzgROAt4t5/PLlRrZh+O4wfHbfal57ZeJJ6JmNn7CGcjkihdE5DhbqyZPQRsD8xw9w1mdjZwn5ltJHx7Pp3Qz17ot8DlhL70t83saUIf/Bv0+Cbfw3Tg1rjtNwh34iwG7jCz6rjMl3uscz3h7qDnzOwtQgCdSuibf5xwtvN2rHUc8EqJ+z7TzD4QP/8yCq4L9OMB4Dgzm0e4JvAOsK7EdWULU6WniMpw1XVLpLt/o79ltzRm9ifCvrcPYt1JwEfc/R4zG0M4i/igu68d2iplOFB3kEh6lgAnmdmThAvkMxUA6dKZgIhIwnQmICKSMIWAiEjChtXdQWa2NeGHPG8AGypcjojIcFEN7AI81fP6z7AKAUIAPFbpIkREhqkpwP8rnDDcQuANgDvvvJOxY8cOeOXm5mYaGhqGvCh5b9Qu+aM2yafBtktLSwunnHIKxGNooeEWAhsAxo4dS319/YBXbmtrG9R6ki21S/6oTfJpCNpls250XRgWEUmYQkBEJGEKARGRhCkEREQSphAQEUmYQkBEJGEKARGRhCkEREQSphAQEUmYQkBEJGEKARGRhCkEREQSphAQEUmYQkBEJGEKARGRhCkEREQSphAQEUmYQkBEJGEKARGRhCkEREQSphAQEUmYQkBEJGEKARGRhCkEREQSVpPFRs1sPHAVsBKY7+43xOlHAqcBVcCNwJPAzUAbsLW7z4jLbRfnHe7uLVnUKCIi2Z0JTAeujwf1aWZWG6dfCJwFnAPMBA4FFrn7BcAyMzvIzLYCZgMLM6pNRESirEJgLLAkjq8Cto/jVe7e4e7vAiN7LLcUGAd8m3B2sCKj2kREJMqkOwhYDNQTDvCjgdVxeruZjQCqgfa43JQ4rx5YABwI7Ax8HLiYcPbQTXNzM21tbQMuqr29naampgGvJ9lSu+SP2iSfBtsura2tfc7LKgRuAa42s9OB+4BrzOzvgGvjvFpgFvAMcKqZXQfg7nOBuQBmdhtwZW8bb2hooL6+fsBFNTU10djYOOD1JFtql/xRm+TTYNulrq6uz3mZhEC8mHtyL7P+fJAvML2PbZw+xGWJiEgPukVURCRhCgERkYQpBEREEqYQEBFJmEJARCRhCgERkYQpBEREEqYQEBFJmEJARCRhCgERkYQpBEREEqYQEBFJmEJARCRhCgERkYQpBEREEqYQEBFJmEJARCRhCgERkYQpBEREEqYQEBFJmEJARCRhCgERkYQ
pBEREEqYQEBFJmEJARCRhCgERkYQpBEREEqYQEBFJmEJARCRhCgERkYQpBEREEqYQEBFJWE0pC5nZEcDuwG+BBe7enmlVIiJSFv2GgJl9D6gHGoEO4BLgpIzrEhGRMiilO2iyu58GrHH324HdMq5JRETKpJQQqDGzkUCnmVUDGzKuSUREyqSUawLXAM8AOxGuCVydaUUiIlI2pYTAE8BkoAF4BRiTaUUiIlI2fYaAme0FjAeuAC6Ok8cAlwMfKbZRMxsPXAWsBOa7+w1x+pHAaUAVcCPwJHAz0AZs7e4zzOwzwOcI3U7Xuvtzg947EREpqtg1gfcBJwI7E+4GOgn4AvDDErY7Hbje3WcA08ysNk6/EDgLOAeYCRwKLHL3C4BlZnYQ0AmcAfwA+MxAd0hERErX55mAuz8GPGZm+7r7swPc7lhgSRxfBWwPLAeq3L0DIF5sLlxuKTDO3e81s6nAdcDXe9t4c3MzbW1tAywJ2tvbaWpqGvB6ki21S/6oTfJpsO3S2tra57xSrgnUm9lsoJbQjbOju+/dzzqLCb8tWAKMBlbH6e1mNgKoBtrjclO6PgeYH3+Y9jCwH/AL4JGeG29oaKC+vr6E0rtramqisbFxwOtJttQu+aM2yafBtktdXV2f80q5RfRbwD8QDui3A8+XsM4twFfNbA5wH3BNPPhfG+fdCswCHgd2NbPrgB3cfR7hdwi3x2V+XMJniYjIIJVyJrDC3Z8ws3Pd/TYzO6O/Fdy9BTi5l1lz459C03usewshKEREJGOlnAmsNbNDgFozOwbYJeOaRESkTEoJgfMI1wO+S7ir55uZViQiImVTSndQC+H3AaOAKwm3cIqIyBaglBB4EBjBpjt8OoHPZlZRRjrWb+TJJW8zcY+NjKjRaxRERKC0EBjp7lMzryRjjy5Yxnd/08qE+mUc+Zc7V7ocEZFcKCUEHo0XhP/8CwV3X5xdSdk4ZM+duOywnTlkz50qXYqISG6UEgI7E+7vL+wOOiizijIyomYrDpywrbqCREQKlBIC5u766aCIyBaolBD4g5kdCDxHvDOo6/k/IiIyvJUSAocA0wr+u5Pw0nkRERnm+g0Bd9+nHIVkbU37em5/biWX7baeUSNLyT4RkS1fsZfK/MDdv2JmT9DjB2LuPuwuDN80dyH3vLCaHccs5KJjrNLliIjkQrGvxLPi8DSg8BrA6OzKyc65UyeyfMVyzp06sdKliIjkRrH7JavMbE/gXwm/GN4a+AtgTjkKG2qjRtbwpY+OVldQ3qzvYNRrj8F63WsgUgnFjogHAl8DjHDgrwI2Ar8sQ12SioUPUf/4JTBhAthxla5GJDnFXi/5U+CnZna8uz9YxpoyoWcH5dTEI1h68GwmTDyi0pWIJKmUvpE1ZnYsoevo+8A33f2ubMsaeg83tTLr4VZ22aWVY/fWKxFyo2YEa8ZPgZoRla5EJEmlfCW+Evgj8LfAwcC5mVaUlaqCPyIiApQWAu8CrcD6+NrIrbMtKRuHT9qZbx62M4dP0hNEc0UXhkUqqpQQeAv4NfDvZnY+MOyeIAp6gFxudV0YXvhQpSsRSVIpR8QvAme6+x2El8T/TbYlSVLihWF0YVikIvoMATM7A8Dd1xJ70t39ReAb5SltaHXdHdSxfmOlS5FCujAsUlHFzgROLRj/fsH4sHzLWNebxR5dsKzSpYiI5EbRXwyXMD5s6M1iOaULwyIVVSwEOksYF3lvFvyS+sdnwgL9EF2kEor9WGyMmR1FCIrRZnY04SxgWD5ATi+az6kN66GzMwxFpOyKhcCzwMlx/DngpILxYUfdQTlVXQNVVWEoubBuwzqeXvU0DRsaqK2urXQ5krFizw46o5yFSKImHMA7o/di2wkHVLoSiR5Y9ABX/vFK6naq44Q9Tqh0OZKxZH45pbuDcur+C9lmxQtw/4WVrkSiu5ru6jaUfOjs6ICnngrDIZRMCKg7KKc6N3YfSsW1rW3rNpR8WH3/A/C92WE4hPoNATP7vpl9ZEg/VaTLnkd3H0rF7fv+fbsNJR9a5szpNhwqpZwJPABcambzzOw8M9tuSCsok1/Nb+EfH27lV/NbKl2KFFr4SPehVNyzy57tNpScePXV7sMh0m8IuPsv3P2LwKeBKcAbZnabmX1wSCvJ2PzX36IzDiVHjr+Sd8bsA8dfWelKJBpfN77bUPJhqw98oNtwyLbb3wJm1mhmVxAeHrcKmAzcAPxkSCvJ2PmHNXDiPjtw/mENlS5FCi15im1WvghLnqp0JRKt6VjTbSj5MPHOH0PjpDAcQqV0B90CvATs5+7nu/tz7v4U8KMhrSRjetF8Tq1rDxeF17VXuhKJTpx0Yreh5EPNTjvB974XhkOolBD4ubvf7u7vApjZbAB3v2FIK8mYniKaU2/8vvtQKm7a7tO4eI+Lmbb7tEqXImXQ59diM/sycBbQaGbHx8nVQC1wSRlqG1IPv9zKrN+0ssu4Vo7dS+8Yzo3ODd2HUnG11bXs/7799WvhRBTrG/kx8BBwKfBPcdpG4H+zLioL69dvpLMzDCVHqqq7D0WkrIqFwN7u/rSZ/QSwgumNwK+KbdTMxgNXASuB+V1dR2Z2JHAa4UF0NwJPAjcDbcDW7j7DzM4DDgBGAf/q7j8b1J710HXsVwbkTP1+wFZxKCLlVuyaQNf7/k4kPDyu608pV4umA9e7+wxgmpl1nVdeSOhiOgeYCRwKLHL3C4BlZnYQsNrdzwTOY9MD7N6zpjfepDMOJUcmfYKlk6+ASZ+odCUiSSp2JnCNmY0gHNAHaiywJI6vArYHlgNV7t4BYGYjeyy3FBjn7neb2SjCmcTs3jbe3NxMW9vAftI+akPbn4dNTU0DWlcytGEdI9ato+nll0F90LnR3t6ufyc5NNh2aW1t7XNesRBwNn+BTFWctns/n7kYqCcc4EcDq+P09hgs1UB7XG5KnFcPzDezScBlwDfd/ZXeNt7Q0EB9fX0/JXT3s1ebgOW8O2I7GhsbB7SuZMh/TufvvkXV7j8GO67S1UjU1NSkfyc5NNh2qaur63NesUdJ7zbgT9rkFuBqMzsduI9wVvF3wLVxXi0wC3gGONXMrovrPQH8EXgemGVmL7r75e+hjj87e/JE/vCnVs6ePHEoNidDZeIRLD14NhMmHtH/siIy5IrdIvoDd/+KmT1BjzMCdz+o2EbdvYXe+/Pnxj+FenY3ZfKT3mcWr2Leq2/zzOJVerNYntSMYM34KVAzotKViCSpWHfQrDjcIn42qEdJi4hsrs+7g9y960pCNXA1cD9wOeG3AsPOiJqtOHDCtoyoSeYVCiIi/SrliPgvhH78ycBdwK2ZVpQRPTZCRGRzpYTABnf/ubu/GX+4NSy/Suv1kiIimyt2YbjrVU9vm9nFwKPAXwF933CaY7omICKyuWIXhk+Kw5WER0V03Zy6NtOKMqJrAiIimyv2O4EzeptuZnoEp4jIFqLfN6yY2XeAGcAIYBtgAfChjOsacl0XhifusVFnAyIiUSlHw+MIj3S4k9Al9FqmFWVEF4ZFRDZXSgiscPe1QJ27NxPOBo
YdXRgWEdlcKSGw1MzOJNwlNBvYLuOaMqELwyIimyvliDgd+DXw98DrDNPHSOjHYiIimyslBN4HfI3wa+Fx6JqAiMgWo5QQuANoJjzj/zXg9kwryoiuCYiIbK7fW0SBke5+Yxx/3sw+l2VBWelYvxFfvpaO9bpFVESkS7HHRuwZR5eb2ReAxwiPjej1bV95d9Pchdzzwmp2HLOQi46xSpcjIpILxc4E5hSMzyC8+L3r9ZLDzrlTJ7J8xXLOnao3i4mIdCn22IjDusbNbAwwEVjk7svLUdhQGzWyhi99dDSjRpbSAyYikoZ+O8djV9A84FLgSTP7m8yrEhGRsijlCumFwH7ufgLwUcLtoiIisgUoJQQ2uvsaAHdvA9qzLUlERMqllA7yhWZ2FeGlMocAC7MtSUREyqWUM4GzgEXAUXF4dqYViYhI2ZRyJnC/ux/d/2IiIjLclBICq83sU4SXyWwEcPcFmVYlIiJlUUoI7ARcUPDfncDh2ZQjIiLlVDQEzGw7YJq7v1OmekREpIz6vDBsZl8Bnic8NO6Y8pUkIiLlUuzuoJMBAz4OfL085YiISDkVC4F2d++IzwoaUa6CRESkfEp9sH5VplWIiEhFFLsw/CEzu4sQAF3jALj7yZlXJiIimSsWAl8sGL8p60JERKT8ir1PYG45CxERkfLTy3ZFRBKmEBARSZhCQEQkYQoBEZGEZfLWdTMbD1wFrATmu/sNcfqRwGmE205vBJ4EbgbagK3dfUZc7ljgLHf/fBb1iYhIkNWZwHTg+nhQn2ZmtXH6hYSX1JwDzAQOBRa5+wXAMjM7yMwOAxqAURnVJiIiUSZnAsBYYEkcXwVsDywHqty9A8DMRvZYbikwzt3vBX5jZp/oa+PNzc20tbUNuKj29naampoGvJ5kS+2SP2qTfBpsu7S2tvY5L6sQWAzUEw7wo4HVcXq7mY0AqgkvrF8MTInz6oH5pWy8oaGB+vr6ARfV1NREY2PjgNeTbKld8kdtkk+DbZe6uro+52XVHXQL8FUzmwPcB1wTD/7Xxnm3ArOAx4Fdzew6YAd3n5dRPSIi0otMzgTcvYXwKOqe5sY/hab3sY1jh7ouERHpTreIiogkTCEgIpIwhYCISMIUAiIiCVMIiIgkTCEgIpIwhYCISMIUAiIiCVMIiIgkTCEgIpIwhYCISMIUAiIiCVMIiIgkTCEgIpIwhYCISMIUAiIiCVMIiIgkTCEgIpIwhYCISMIUAiIiCVMIiIgkTCEgIpIwhYCISMIUAiIiCVMIiIgkTCEgIpIwhYCISMIUAiIiCVMIiIgkTCEgIpIwhYCISMIUAiIiCVMIiIgkTCEgIpIwhYCISMIUAiIiCVMIiIgkTCEgIpKwmiw2ambjgauAlcB8d78hTj8SOA2oAm4EngRuBtqArd19Rs9l3H1eFjWKiEh2ZwLTgevdfQYwzcxq4/QLgbOAc4CZwKHAIne/AFhmZgf1soyIiGQkkzMBYCywJI6vArYHlgNV7t4BYGYjeyy3FBjXyzKbaW5upq2tbcBFtbe309TUNOD1JFtql/xRm+TTYNultbW1z3lZhcBioJ5wgB8NrI7T281sBFANtMflpsR59cD8XpbZTENDA/X19QMuqqmpicbGxgGvJ9lSu+SP2iSfBtsudXV1fc7LqjvoFuCrZjYHuA+4Jh7Yr43zbgVmAY8Du5rZdcAOsf+/5zIiIpKRTM4E3L0FOLmXWXPjn0LTe6zb2zIiIpIB3SIqIpIwhYCISMIUAiIiCVMIiIgkTCEgIpIwhYCISMIUAiIiCcvqF8NZqQZoaWkZ1Mqtra1FfzknlaF2yR+1ST4Ntl0KjpnVPecNtxDYBeCUU06pdB0iIsPRLsDCwgnDLQSeIjxr6A1gQ4VrEREZLqoJAfBUzxlVnZ2d5S9HRERyQReGRUQSphAQEUnYcLsmMGTM7JPABMJLbG6odD2yiZmdA3S4+22VrkXAzM4APkB4Vey9la5HAjM7BtgLWOnuPxrsdpI9E3D3nwErgG0qXYtsYmaHAG9Xug7pZhLwJtBR6UKkm6OAWsJxbNCSDQEz29Pd/43w5jPJj6MBY9Mb56Tyvu/u1wJTK12IdLONu18OTH4vG0m2Owg4wMw+S3i3seSEu19mZrsCh1a4FNnkU2a2LbCo0oVIN782s4uA197LRrbYW0TNrAH4D3f/qJmNB64CVhL6NXUNoELULvmjNsmncrXLFtkdZGZjgbPY1Lc8Hbje3WcA08ystmLFJUztkj9qk3wqZ7tskSHg7i3u/g1gTZw0FlgSx1cB21eksMSpXfJHbZJP5WyXLTIEerEYqI/jo4HVFaxFNlG75I/aJJ8ya5dULgzfAlxtZqcD97n7+grXI4HaJX/UJvmUWbtssReGRUSkf6l0B4mISC8UAiIiCVMIiIgkTCEgIpIwhYCISMIUAiIiCVMIiIgkLJUfi8kwZGaHAj8F9nb3JXHa5cDLg33hTHxC6T3ufuAQlVm47WrgQWBb4JPuvipO3wm4CRgFVAGvAn/r7u8O8effRti3XwzldmXLpjMBybsO4EdmVlXpQkqwC7Cju0/uCoDo74H/cfdj3P1owkPBzq1IhSI96ExA8u5hwpeV84EfdE3s+Y3ezJ4ETgROBxqAHQnPWPkh8DlgT+BLQAuwk5n9N/B+4AF3n2VmE4CbgZFAO3AOUA10vYHuQXe/suDzTwG+DqwF/hiXvxnYw8zmuPv0gn14Ffi8mTUDjwMXAZ1xO7OB/YE6oMndzzCzfyhhH/4DeIPwPJmfu/v/KaitlnDmsUf8u7vM3R8xs38CDo/T7o4vipHE6UxAhoPzgAvMbI8Sl3/X3Y8F7gOOd/dPApcTQgJCt8ypwMHAcWb2YeCfCY/qPSyOXx6XHQsc3SMAxgDfAQ5398mEh3lNB2YAL/UIAIAbgbsIZwSvA/8JjDOz7YBV7n4UcBBwYHxufCn7sCsh8A4ADjezfQs+7yxgubsfAnwa6Hr2/GnAycAhwJB2RcnwpRCQ3HP3FYRv3bfR9/+zhd1Fz8bhauClOL6K8C0f4Hl3f9PdNwC/I3zD3hu41MweAb5FOEsAeMXde75bd3fCiz3a4n8/CnyoyC4cBtzh7scQQuV3wLWEA/H7zexuYA4hnLqeE1/KPqyM+/Bbwis5u+wNHB/35SdATQyuE4HZwC+BHYrUKwlRCMiw4O4/A5zw7RdCl837zazazHYAditYvL+nIjaa2SgzqwE+BswHXgZmuvuhhG/198ZlN/ay/ivAX8ZXLkJ49+6CIp/3NeCMuB9r4+etBY4DJrj7ScClwF+wKcxK2Ydt4sXoj7EpKIj7cnfcl+MIXUdrgC8AJxG6hE43sw/28xmSAIWADCdfJ3ZjuHsL8D/AU4S++OYBbGcl8G/APOBed3+J0E//bTObC9wBvNDXyu6+HPg28Jt4LWJHQpdPX84lvA3qOTObR+jXv4hwRrB73Ma9hHf4jitxHzoIB/ffAv/l7s8XzJsDTIr7Mg94NYbPSuD3hOssvyI8o14Sp0dJiwwzWd7mKunRmYCISMJ0JiAikjCdCYiIJ
EwhICKSMIWAiEjCFAIiIglTCIiIJEwhICKSsP8Py2LYmRC4TzMAAAAASUVORK5CYII=\n", 478 | "text/plain": [ 479 | "
" 480 | ] 481 | }, 482 | "metadata": {}, 483 | "output_type": "display_data" 484 | } 485 | ], 486 | "source": [ 487 | "# plt.figure(figsize=(8, 10))\n", 488 | "\n", 489 | "y_axis_imp = np.asarray(total_probas)\n", 490 | "x_axis_imp = np.asarray(num_samples)\n", 491 | "\n", 492 | "for x, y in zip(x_axis_imp, y_axis_imp):\n", 493 | " plt.scatter([x] * len(y), y, s=0.5)\n", 494 | " \n", 495 | "plt.xscale('log')\n", 496 | "plt.title(\"Importance Sampling\")\n", 497 | "plt.ylabel(\"Probability Estimate\")\n", 498 | "plt.xlabel('Number of Samples')\n", 499 | "plt.grid(True)\n", 500 | "plt.savefig(plot_dir + 'imp_sampling_convergence_speed.jpg', format='jpg', dpi=300)\n", 501 | "plt.show()" 502 | ] 503 | }, 504 | { 505 | "cell_type": "markdown", 506 | "metadata": {}, 507 | "source": [ 508 | "## Side by Side" 509 | ] 510 | }, 511 | { 512 | "cell_type": "code", 513 | "execution_count": 21, 514 | "metadata": {}, 515 | "outputs": [ 516 | { 517 | "data": { 518 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAagAAAEYCAYAAAAJeGK1AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAH6tJREFUeJzt3X94VNW97/F3EhJDIIYAKQEG+ZE0K1HEivWIIKhX9PJArS1WKlh+eKrQQ/W2pBzlVNoq3qKiEMCChYdS8dctVkMrkuu1Wg8oFn/VPhUYVgigJCAxIYABHJIwuX/MQGNMgEwmzMrm83qeeRj2Wnuv785K8snes2dPXH19PSIiIq6Jj3UBIiIiTVFAiYiIkxRQIiLiJAWUiIg4SQElIiJOUkCJiIiTOrR0BWNMb2A+UAVssdYuCS8fCUwC4oAngE3AcqAaOM9aO90YMwq4CTgGvGGt/XNU9kJERDynxQEFTAMWW2vfNsYUGWOWW2trgXzgO0AC8AdgEbDTWjvXGPOAMWYo8GPgn0Bv4MPGGzbGnAdcDnwKHI9oj0REpD1JAHoC71lrjzVsiCSgMoHS8PMDQBpQCcRZa2sAjDHJjfqVAb2AgcD3w23/m9ARV0OXA29GUJOIiLRvw4G3Gi6IJKB2Az5C4dMVOBheHjDGJBFKw0C43/Bwmw/YAnwSbtvfzLY/BXj22WfJzMyMoLSQkpISsrOzI15f3KL59A7NpbdEYz737dvHbbfdBuHf/w1FElArgAXGmClAIVBgjPkZsDDclgg8CHwATDTGLAIInxL8DbCK0Om7eU1s+zhAZmYmPp8vgtJCqqurW7W+uEXz6R2aS2+J8nx+5WWdFgeUtXYfMKGJpvXhR0PTGq37AvBCS8cUEZFzjy4zFxERJymgRETESQooERFxkgJKREScpIASEREnKaBERMRJCigREXGSAkpERJykgBIREScpoERExEkKKBERcZICSkREnKSAEhERJymgRETESQooERFxkgJKREScpIASEREnKaBERMRJCigREXGSAkpERJykgBIREScpoERExEkKKBERcZICSkREnKSAEhERJ3Vo6QrGmN7AfKAK2GKtXRJePhKYBMQBTwCbgOVANXCetXa6MWYyMB7YB7xhrV0Vlb0QERHPieQIahqw2Fo7HRhjjEkML88H7gCmAvcC1wA7rbUzgApjzFBgBLCHUIhtamXtIiLiYS0+ggIygdLw8wNAGlAJxFlrawCMMcmN+pUBvYCVwHvhdVYANzU1QElJCdXV1RGUFhIIBPD7/RGvL27RfHqH5tJbojGf5eXlzbZFElC7AR+h8OkKHAwvDxhjkoAEIBDuNzzc5gO2AFcBfwM+J3QU1aTs7Gx8Pl8EpYX4/X7y8vIiXl/covn0Ds2lt0RjPlNTU5tti+QU3wrgbmPMMqAQKAgH08Jw20rgQWAj0M8YswjoYq19G6gAfkfotamHIhhbRETOES0+grLW7gMmNNG0PvxoaFqjdX8P/L6lY4qIyLlHl5mLiIiTFFAiIuIkBZSIiDhJASUiIk5SQImIiJMUUCIi4iQFlIiIOEkBJSIiTlJAiYiIkxRQIiLiJAWUiIg4SQElIiJOUkCJiIiTFFAiIuIkBZSIiDhJASUiIk5SQImIiJMUUCIi4iQFlIiIOEkBJSIiTlJAiYiIkxRQIiLiJAWUiIg4SQElIiJOUkCJiIiTOrR0BWNMb2A+UAVssdYuCS8fCUwC4oAngE3AcqAaOM9aOz3c7/xw2/+w1u6Lxk6IiIj3RHIENQ1YHA6cMcaYxPDyfOAOYCpwL3ANsNNaOwOoMMYMNcbEAw8BO1pdeXOOHiX1xRfh6NE2G0JERNpei4+ggEygNPz8AJAGVAJx1toaAGNMcqN+ZUAv4FeEjqpmnGqAkpISqqurIygNUl98kYwFCygDqm++OaJtiFsCgQB+vz/WZUgUaC69JRrzWV5e3mxbJAG1G/ARCp+uwMHw8oAxJglIAALhfsPDbT6gGBgC9ACuBO4hdNT1FdnZ2fh8vghKA/LzKQN8+fmQkhLZNsQpfr+fvLy8WJchUaC59JZozGdqamqzbZEE1ApggTFmClAIFBhjfgYsDLclAg8CHwATjTGLAKy164H1AMaYJ4F5EYx9eikpoSMnhZOISLvW4oAKX9gwoYmmkwHUwLRmtjGlpeOKiMi5RZeZi4iIkxRQIiLiJAWUiIg4SQElIiJOUkCJiIiTFFAiIuIkBZSIiDhJASUiIk5SQImIiJMUUCIi4iQFlIiIOEkBJSIiTlJAiYiIkxRQIiLiJAWUiIg4SQElIiJOUkCJiIiTFFAiIuIkBZSIiDhJASUiIk5SQImIiJMUUCIi4iQFlIiIOEkBJSIiTlJAiYiIkzq0dAVjTG9gPlAFbLHWLgkvHwlMAuKAJ4BNwHKgGjjPWjvdGPNd4GbgOLDQWvthVPaioUCATn/5C/TvD8nJUd+8iIicHS0OKGAasNha+7YxpsgYs9xaWwvkA98BEoA/AIuAndbaucaYB4wxQ4F64HbgG8B3gSYDqqSkhOrq6ghKg05/
+QvdH3uM3cCR66+PaBvilkAggN/vj3UZEgWaS2+JxnyWl5c32xZJQGUCpeHnB4A0oBKIs9bWABhjkhv1KwN6WWtfMMZcTSi8ftrcANnZ2fh8vghKA/r3ZzdwwdSpOoLyCL/fT15eXqzLkCjQXHpLNOYzNTW12bZIAmo34CMUPl2Bg+HlAWNMEqEjqEC43/Bwmw/YYoy5DvgrcBnwCvDfEYx/asnJoSMnhZOISLsWyUUSK4C7jTHLgEKgIBxMC8NtK4EHgY1AP2PMIqCLtfZtoD+wKtznmSjULyIiHtXiIyhr7T5gQhNN68OPhqY1WncFoRATERE5Jc9dZh4MQmVlAsFgrCsREZHW8FxAVVTAhg2dqKiIdSUSFXV1JBUXQ11drCsRkbPMcwGVkQEjRhwhIyPWlUhUbNtG11WrYNu2WFciImeZ5wIqPh66dz9OvOf27ByVm0vV5MmQmxvrSkTkLNOvcXFbhw7U5ORAh0jeESEi7ZnnAqquDoqLk/SShYhIO+e5gNq2DVat6qqXLERE2jnPBVRuLkyeXKWXLDwiWBfkQPEhgnV634DIucZzARUfD1276iIJr6jYtp9Nqyqo2LY/1qVIawWDJFRWojcpypny3K/x8nIoKkrlFDfIlXYkI7cbQyZnkJHbLdalSGuVl5NaVIR+OOVMeS6gxGOCQRIOHtRf3SLnIM8FVI8eMHp0NT16xLoSiYaKt7fz/uLtVLy9PdalSGt168ax3FzopqNhOTOeCyjxloyhX+eb/+vrZAz9eqxLkdYqLibtz3+G4uJYVyLthOcCSvfi85b4DvGkD+hEfAfPfauee3JyOHTTTZCTE+tKpJ3w3E+97sXnLcHyCo4WfUCwXH9xtHsVFaR88AH661HOlOcCSrylPJjBH498i/Kg/uJo9/bvJ6mkBPbrLQNyZjwXUDrF5y11NUEqyo5TV6Or+Nq7YE4u20dNJpijd9F7wll4X5vnAqpbNxg4MKALhTzi07/t5J//XcOnf9sZ61KklezGCmbMvRC7UX89ekFwxy6Cv1pKcMeuNhvDcwFVUQGbNqXoCMojdgYHsIMB7AwOiHUp0kqLF9SxsTyXxQt0J2cv2LXsZX71xhh2LXu5zcbwXEAFg3DkiD7y3Sv69gmSkXqEvn00oe3djT/MoEvyUW78oV5P9IIlG3N5nu+xZGPbnbL1XECF1Me6AImSw+9sJq78AIff2RzrUqSV3iw6yKFAR94sOhjrUiQK1m3JBRLD/7YNzwVUfDx06hTUzWI9Ivs7A7nw8nqyvzMw1qVIK2XmpJGUcJzMnLRYlyJRsG7BLnLZzLoFeg3qjGVkwJAhR/U+KI/YXxVP2cEu7K/y3LfqOefbHd9gXOKLfLvjG7EuRaIgu/9xPsgbT3b/4202hud+6nWRhLdUfPgJ5Ttrqfjwk1iXIq3U9wbDfd/8E31vMLEuRaJh+HA+u/deGD68zYbwXEDV1MDu3YnU1MS6EomGDv36ktg9lQ79+sa6FGml+L1l9D5UTPzesliXItGQlMQX//ZvkJTUZkN0aOkKxpjewHygCthirV0SXj4SmATEAU8Am4DlQDVwnrV2ujHmP4DLgc7A09batVHZiwaK/3mUv/+lluJbj9K3b0q0Ny9n2dVXBPj1sNVcfcVthL5tpN268kqq7riDzldeGetKpJ1ocUAB04DF1tq3jTFFxpjl1tpaIB/4DpAA/AFYBOy01s41xjxgjBkKHLTW/rsxphvwG6DJgCopKaG6ujqiHTq2bgM9qnI4tq4Yf/aIiLYh7kh98imG/OkPfNbnY6qnTIp1OdIKCZWVJFVVUfzeexzv3j3W5UgUBAIB/H5/q7ZRfooPsIwkoDKB0vDzA0AaUAnEWWtrAIwxyY36lQG9rLX/xxjTmdAR2EPNDZCdnY3P54ugNOic35e8LX/nkvyJ9MnREVR7Vz5tNs9tvpQJ08aQN6BTrMuR1qip4eOdO8m54oo2PS0kZ4/f7ycvL69V20hNTW22LZLXoHYDJ9KjK3DiTQ0BY0ySMaYjEGjUzwfsNcbkAr8FHrDW/jOCsU+rQ2oKXa7KokOqwskLMi7oyBVTLyDjgo6xLkVaa9s20p5/HrZti3Ul0k5EElArgLuNMcuAQqDAGJMELAy3rQQeBDYC/Ywxi4AuwN+Al4GOwIPGmFlRqP8r9Im63hK/v4K+m18nfr8uy2z3unWjJitLn6grZ6zFp/istfuACU00rQ8/GprW6P/ZLR1PznEZGRwZMQK9sc0Devbk0C230KNnz1hXIu2E5y4z18dteExNDckffojeN+AB8fGhiyN0mxc5Q577TtEn6npMURFfW7AAiopiXYmInGWRXMUncvaMHs1nZWVcMHp0rCsRkbPMc0dQOsXnMUlJBC69VJcli5yDPBdQOsXnMRUVdNqwAf3FIXLu8VxAicfoKj6Rc5bnAkqn+DxGV36JnLM891Ofng4+Xy3p6bGuREREWsNzAVVcDH/+cxrFxbGuREREWsNzAZWbC5MnV5GbG+tKRESkNTwXUB06QE5ODR30Di8RkXbNcwElIiLeoIASEREnKaBERMRJCigREXGSAkpERJykgBIREScpoERExEkKKBERcZICSkREnKSAEhERJymgRETESQooERFxkgJKRESc1OJ7fhtjegPzgSpgi7V2SXj5SGASEAc8AWwClgPVwHnW2unhfqOAO6y134vKHoiIiCdFcgQ1DVgcDpwxxpjE8PJ84A5gKnAvcA2w01o7A6gwxgw1xlwLZAOdW115cw4fJm3VKjh8uM2GEBGRthfJpyZlAqXh5weANKASiLPW1gAYY5Ib9SsDellrXwDeMMZ861QDlJSUUF1dHUFpkLZqFRmPP85e4NDkyRFtQ9wSCATw+/2xLkOiQHPpLdGYz/Ly8mbbIgmo3YCPUPh0BQ6GlweMMUlAAhAI9xsebvMBW850gOzsbHw+XwSlAbNnsxfoNXs2vTq33YGanD1+v5+8vLxYlyFRoLn0lmjMZ2pqarNtkZziWwHcbYxZBhQCBeFgWhhuWwk8CGwE+hljFgFdrLVvRzBWy3XuHDpyUjiJiLRrLT6CstbuAyY00bQ+/GhoWjPbGNXScUVE5Nyiy8xFRMRJCigREXGSAkpERJykgBIREScpoERExEkKKBERcZICSkREnKSAEhERJymgRETESQooERFxkgJKREScpIASEREnKaBERMRJCigREXGSAkpERJykgBIREScpoERExEkKKBERcZICSkREnKSAEhERJymgRETESQooERFxkgJKREScpIASEREnKaBERMRJHVq6gjGmNzAfqAK2WGuXhJePBCYBccATwCZgOVANnGetnd64j7X27ajshYiIeE4kR1DTgMXW2unAGGNMYnh5PnAHMBW4F7gG2GmtnQFUGGOGNtEn+o4eJfXFF+Ho0TbZvIiInB0tPoICMoHS8PMDQBpQCcRZa2sAjDHJjfqVAb2a6NOkkpISqqurIygNUl98kYwFCygDqm++OaJtiFsCgQB+vz/WZUgUaC69JRrzWV5e3mxbJAG1G/ARCp+uwMH
w8oAxJglIAALhfsPDbT5gSxN9mpSdnY3P54ugNCA/nzLAl58PKSmRbUOc4vf7ycvLi3UZEgWaS2+JxnympqY22xbJKb4VwN3GmGVAIVAQDp2F4baVwIPARqCfMWYR0CX8elPjPtGXkhI6clI4iYi0ay0+grLW7gMmNNG0PvxoaFqjdZvqIyIi8hW6zFxERJykgBIREScpoERExEkKKBERcZICSkREnKSAEhERJymgRETESQooERFxkgJKREScpIASEREnKaBERMRJCigREXGSAkpERJykgBIREScpoERExEkKKBERcZICSkREnKSAEhERJymgRETESQooERFxkgJKREScpIASEREnKaBERMRJCigREXGSAkpERJzUoaUrGGMmAiOAFOABa21xg7Z5wHlAGjANMMB/AUeBl621a4wxycBjwFvW2j+0fhdERMSLIjmCmmytvROYDcw8sdAYMwBIs9b+BFgP3Bxu/5m19ofAf4S75gPBVlUtIhKBDRs2sHr16jYf55lnnmnzMc6Wd955hxkzZgBw1113ndWxT3sEZYyZCkxosOiL8L9lQK8GyzPDy060XQJkWGv3hpfVA1hr5xpjppxqzJKSEqqrq09bfHMCgQB+vz/i9cUtmk/viPVcZmRkkJGR0eY1PP7441x22WVtOsbZ8sknn/D555/j9/v58Y9//KWvXTTms7y8vNm20waUtXY5sPzE/40x68JPfcDeBl1Lgd6N2vYYY3paaz8F4s604OzsbHw+35l2/wq/309eXl7E64tbNJ/eEeu5LCwsZOfOndx6663MmDGDnj17UlZWxpgxY9i+fTtbt27lmmuuIT8/n4kTJ9K/f3927dpFfX09BQUFZGRk8PDDD/PBBx8A8K1vfYvJkycza9YsDh48yMGDB7n66qs5cuQIq1evZubMmdx3331UV1dz4MABbrnlFiZMmMDEiRPJzc1l+/btHD58mEWLFtG7d2+WLl3Ka6+9xvHjxxk/fjy33norTz/9NC+//DJxcXGMHj2aSZMmfWmfCgoK2LRpE8FgkDFjxjBlyhTeffddfvOb3wChEHnkkUdITEyMaJ/79u3L+eefT15eHsOGDWPjxo0n6//HP/5BfX39yfqXLFnCa6+9RteuXfniiy/4yU9+whVXXHHKOUlNTW22rcWvQQFPGWNWAOcD/2mM6Q380Fo7xxhTZYxZROg1qDuBfwDzjTE1wOIIxhKRc8z998MDD5y+369+FeobqdLSUlauXEkgEOC6665jw4YNdOzYkWuvvZb8/HwABg8ezJw5c3j22WdZtmwZw4YNo6ysjOeff566ujomTJjAkCFDABgyZAhTpkwBQqf47r//frZs2cKYMWO44YYbKC8vZ+LEiUyYEDohNWjQIO677z4KCgpYt24dV111FRs2bOCPf/wjNTU1zJ8/n+3bt1NUVMRzzz1HXFwcU6ZM4aqrrmLAgAEn9+NPf/oTzzzzDD169KCwsBCA7du38+ijj9KjRw9++9vf8sorr3DjjTdGtM/XX399k1+/QYMGMXbsWF555RXWrVvHiBEjePPNN3nhhReora3lxhtvjHxywlocUNba1UDjk7hzwm0/b7R8K18+PXhiG0+2dFwROTfcf3/rgudM9enTh9TUVJKSkujevTtdunQBIC7uXyd7ToTP4MGD+etf/0pmZibf/OY3iYuLIzExkUsuuYQdO3YA0L9//6+M0b17d1atWsWrr75K586dqaurO9l24YUXApCZmUllZSW7du1i0KBBJCQk0LFjR2bPnk1RURF79+49GXyHDh1i9+7dXwqoBQsWsGDBAiorKxk+fDgAPXr04Ne//jUpKSmUl5czePDgiPe5ORdeeCE1NTUn69+xYwcXX3wxCQkJJCQkMHDgwDOZhlPSZeYick5q+Eu5OZs3bwbg73//O9nZ2WRlZZ08vVdbW8uHH35I3759v7K9+vp6AFauXMk3vvENHnvsMUaNGnVyeVMGDBjA1q1bCQaD1NbWcvvttzNgwACys7N56qmnePrppxk7diw5OTkn16mpqeGVV15hwYIFrFq1ijVr1rBnzx5mz57N3Llzefjhh/na1752ctxI9vlMZWdn89FHHxEMBqmpqWHr1q1nvG5zIjnFJyJyTlizZg1PPvkkHTt2ZN68eaSnp/Puu+/y/e9/n9raWkaNGsVFF130lfWysrKYOXMm3/ve97j//vtZu3YtXbp0ISEhgZqamibHysvLY/jw4YwfP55gMMj48ePJzc3lyiuvZPz48dTU1DBo0CB69Ohxcp2kpCTS0tK46aabSEtLY9iwYfTq1YubbrqJcePGcf7559O9e3c+++yziPe5uLj49CsBxhiuvvpqxo0bR3p6OomJiXTo0MqIqa+vd+aRk5PTLycnp760tLS+NbZu3dqq9cUtmk/vaE9z+YMf/KC+pKQk1mWcVS3d54bzWVlZWf/MM8/U19fX1x87dqx+5MiR9Xv27DntNkpLS+tzcnLqc3Jy+tU3ygQdQYmISKulp6ezefNmbr75ZuLi4rjlllvo1avX6Vc8BQWUiEgTnn766ViXcNa1Zp/j4+N56KGHoliNLpIQERFHKaBERMRJCigREXGSAkpERJykgBIRiZC1lvfee69Nx5g1axYbNmxo0zFcpYASEYnQq6++SklJSazL8CxdZi4i54zCwkLeeOMNAoEAFRUVTJo0iddff53t27dzzz33MHLkSF566SVWrVpFUlIS/fr1Y86cOaxdu5b169cTCATYvXs3d955J8OGDWPNmjUkJiZy0UUXEQgEKCgoICEhgT59+jBnzhwSExNPjv3xxx8ze/ZsamtrSU5OpqCggMrKSh5++GGCwSCff/45s2fPZvDgwVx77bUMGDDgS/fcq62t5ec//zmlpaUcP36c22+/ndGjR8fiy3jWKKBExC1tfDvzI0eOsHLlStatW8eTTz7J888/zzvvvMNTTz3FZZddxuOPP86aNWvo3Lkzc+fOZfXq1aSkpHD48GF+97vf8fHHH/OjH/2IsWPH8t3vfpfu3btz8cUXM2rUKJ577jm6devGwoULWbNmDePGjTs57iOPPMLUqVMZMWIERUVFbN26lc8//5x7770XYwxr166lsLCQwYMH8+mnn1JYWEh6ejqzZs0CYPXq1aSnp/Poo49y+PBhxo4dy5AhQ+jatWuLvwbthQJKRNzSxrczP/F5VKmpqWRlZREXF0daWhrHjh2jtLSU7OxsOnfuDMDll1/OW2+9xSWXXEJubi4APXv2/Mr99Kqqqvjss8/46U9/CoQ+g2nYsGFf6rNr1y4uvfRSgJNHPu+//z5Lly4lOTmZI0eOnBw3PT2d9PT0L62/Y8cOhg4dCkDnzp3JysqitLRUASUi4hWnuqO3z+djx44dHD16lJSUFN59992TH6PR1HpxcXEEg0HS09PJzMxk6dKlpKam8vrrr5OSkvKlvllZWXz00UcMHTqUl156iUOHDlFYWMhjjz1GVlYWixcvZs+ePUDorgyNZWVl8f7773P99ddz+PBhiouLW/XBru2BAkpEJKxr16
7cfffdTJo0ifj4eC644AJmzpzJunXrmuw/cOBA5s2bR1ZWFvfddx9Tp06lvr6eTp06MW/evC/1veeee/jlL3/JE088QXJyMo8++ih1dXVMnz6dbt26kZmZyYEDB5qtbdy4cfziF79g/PjxHDt2jLvuuotu3bpFdf9dE1d/is8nOduMMf2AXa+//ro+8l1O0nx6h+bSW6Ixn2VlZVx33XUA/a21Hzds02XmIiLiJAWUiIg4SQElIiJOUkCJiIiTFFAiIuIkBZSIiDhJASUiIk5SQImIiJNafCcJY8xEYASQAjxgrS1u0DYPOA9IA6YBBvgv4CjwMvB/gWXAF0B3YKq1tqrB5hMA9u3bF8m+nFReXk5qamqrtiHu0Hx6h+bSW6Ixnw1+3yc0bovkVkeTrbUjjTH9CYXPVABjzAAgzVo7zRhzO3AzMAr4mbV2rzHmVeA9YIW19k1jzM+Ay4H/12DbPQFuu+22CMoSEZF2rCewo+GC0waUMWYqMKHBoi/C/5YBvRoszwwvO9F2CZBhrd0bXlZvrS0Dyowxl4fbFzUa7j1gOPApcPx0tYmISLuXQCicvvLRxKcNKGvtcmD5if8bY07cNdEH7G3QtRTo3ahtjzGmp7X2UyAuvP60cPu/W2vrGo11DHjrzPZJREQ8YkdTC1t8s1hjzPeB64Hzgf8E6oAfWmvnGGPmAp0IvQZ1J/B1YDZQA7wAVAPPAW+EN7fUWvt2i3dFREQ8z6m7mYuIiJygy8xFRMRJnv/AQmPMjUAfIM5auyTW9UjrhS/cqbHWPhnrWiRy4at9LwC2WGtfiHU90jrGmP8JDASqrLW/j8Y2PX8EZa1dC+wn9L4taeeMMSOAI7GuQ6IiFzhE6DVqaf+uBxIJ/b6NCs8HlDEmx1q7Guga61okKm4g9Abw4bEuRFrtcWvtQuDqWBciUZFirX0YuCpaG/T8KT7gcmPMWP71Hi1px6y1s40x/YBrYlyKtN63jTGdgJ2xLkSi4jVjzExgT7Q22O6v4jPGZAN/tNZeaozpDcwHqgid19ZrTu2M5tM7NJfeEov5bNen+IwxmcAd/Os1iWnAYmvtdGCMMSYxZsVJi2k+vUNz6S2xms92HVDW2n3W2lnA4fCiTEJ3tAA4QOgNw9JOaD69Q3PpLbGaz3YdUE3YTeg2ShC6KOJgDGuR1tN8eofm0lvOynx67SKJFcACY8wUoLDxvf6k3dF8eofm0lvOyny2+4skRETEm7x2ik9ERDxCASUiIk5SQImIiJMUUCIi4iQFlIiIOEkBJSIiTlJAiYiIkxRQIiLiJAWUiIg46f8D61k8WWw5Tl0AAAAASUVORK5CYII=\n", 519 | "text/plain": [ 520 | "
" 521 | ] 522 | }, 523 | "metadata": {}, 524 | "output_type": "display_data" 525 | } 526 | ], 527 | "source": [ 528 | "fig, ax = plt.subplots(1, 1)\n", 529 | "\n", 530 | "# monte carlo\n", 531 | "for x, y in zip(x_axis_imp, y_axis_monte):\n", 532 | " ax.scatter([x] * len(y), y, s=0.5, c='r', alpha=0.3)\n", 533 | "\n", 534 | "# importance sampling\n", 535 | "for x, y in zip(x_axis_imp, y_axis_imp):\n", 536 | " ax.scatter([x] * len(y), y, s=0.5, c='b', alpha=0.3)\n", 537 | " \n", 538 | "blue = mlines.Line2D([], [], color='blue', marker='_', linestyle='None', markersize=10, label='importance sampling')\n", 539 | "red = mlines.Line2D([], [], color='red', marker='_', linestyle='None', markersize=10, label='monte carlo')\n", 540 | " \n", 541 | "plt.xscale('log')\n", 542 | "plt.grid(True)\n", 543 | "plt.legend(handles=[blue, red], loc='lower right')\n", 544 | "plt.savefig('/Users/kevin/Desktop/plot.jpg', format='jpg', dpi=300, bbox_inches='tight')\n", 545 | "plt.tight_layout()\n", 546 | "plt.show()" 547 | ] 548 | }, 549 | { 550 | "cell_type": "markdown", 551 | "metadata": {}, 552 | "source": [ 553 | "## References\n", 554 | "\n", 555 | "* http://ib.berkeley.edu/labs/slatkin/eriq/classes/guest_lect/mc_lecture_notes.pdf" 556 | ] 557 | } 558 | ], 559 | "metadata": { 560 | "kernelspec": { 561 | "display_name": "Python 3", 562 | "language": "python", 563 | "name": "python3" 564 | }, 565 | "language_info": { 566 | "codemirror_mode": { 567 | "name": "ipython", 568 | "version": 3 569 | }, 570 | "file_extension": ".py", 571 | "mimetype": "text/x-python", 572 | "name": "python", 573 | "nbconvert_exporter": "python", 574 | "pygments_lexer": "ipython3", 575 | "version": "3.6.5" 576 | } 577 | }, 578 | "nbformat": 4, 579 | "nbformat_minor": 2 580 | } 581 | -------------------------------------------------------------------------------- /pr-lr/cifar.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | import time 5 | import numpy as np 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.nn.parallel 11 | import torch.backends.cudnn as cudnn 12 | import torch.optim 13 | import torch.utils.data 14 | import torchvision.datasets as datasets 15 | from torchvision.datasets import CIFAR10 16 | from torchvision import transforms 17 | from torch.utils.data import DataLoader 18 | from torch.utils.data.sampler import Sampler, WeightedRandomSampler 19 | 20 | from models import WideResNet 21 | 22 | # used for logging to TensorBoard 23 | from tensorboard_logger import configure, log_value 24 | 25 | parser = argparse.ArgumentParser(description='PyTorch WideResNet Training') 26 | parser.add_argument('--data_dir', default='./data/', type=str, 27 | help='data path') 28 | parser.add_argument('--epochs', default=10, type=int, 29 | help='number of total epochs to run') 30 | parser.add_argument('--batch-size', default=64, type=int, 31 | help='mini-batch size (default: 128)') 32 | parser.add_argument('--learning-rate', default=1e-3, type=float, 33 | help='initial learning rate') 34 | parser.add_argument('--weight-decay', '--wd', default=0, type=float, # 5e-4 35 | help='weight decay (default: 5e-4)') 36 | parser.add_argument('--print-freq', '-p', default=10, type=int, 37 | help='print frequency (default: 10)') 38 | parser.add_argument('--layers', default=28, type=int, 39 | help='total number of layers (default: 28)') 40 | parser.add_argument('--widen-factor', default=10, type=int, 41 | help='widen 
factor (default: 10)') 42 | parser.add_argument('--droprate', default=0, type=float, 43 | help='dropout probability (default: 0.0)') 44 | parser.add_argument('--no-augment', dest='augment', action='store_false', 45 | help='whether to use standard augmentation (default: True)') 46 | parser.add_argument('--name', default='WideResNet-28-10', type=str, 47 | help='name of experiment') 48 | 49 | best_prec1 = 0 50 | 51 | def get_data_loader(data_dir, batch_size, num_workers=3, pin_memory=False): 52 | normalize = transforms.Normalize(mean=[x/255.0 for x in [125.3, 123.0, 113.9]], 53 | std=[x/255.0 for x in [63.0, 62.1, 66.7]]) 54 | transform_train = transforms.Compose([ 55 | transforms.ToTensor(), 56 | normalize, 57 | ]) 58 | dataset = CIFAR10(data_dir, train=True, download=True, transform=transform_train) 59 | loader = DataLoader( 60 | dataset, batch_size=batch_size, 61 | shuffle=False, num_workers=num_workers, 62 | pin_memory=pin_memory, 63 | ) 64 | return loader 65 | 66 | def get_weighted_loader(data_dir, batch_size, weights, num_workers=3, pin_memory=False): 67 | normalize = transforms.Normalize(mean=[x/255.0 for x in [125.3, 123.0, 113.9]], 68 | std=[x/255.0 for x in [63.0, 62.1, 66.7]]) 69 | transform_train = transforms.Compose([ 70 | transforms.ToTensor(), 71 | normalize, 72 | ]) 73 | dataset = CIFAR10(data_dir, train=True, download=True, transform=transform_train) 74 | sampler = WeightedRandomSampler(weights, len(weights), False) 75 | loader = DataLoader( 76 | dataset, batch_size=batch_size, 77 | shuffle=False, num_workers=num_workers, 78 | pin_memory=pin_memory, sampler=sampler 79 | ) 80 | return loader 81 | 82 | def get_test_loader(data_dir, batch_size, num_workers=3, pin_memory=False): 83 | normalize = transforms.Normalize(mean=[x/255.0 for x in [125.3, 123.0, 113.9]], 84 | std=[x/255.0 for x in [63.0, 62.1, 66.7]]) 85 | transform_train = transforms.Compose([ 86 | transforms.ToTensor(), 87 | normalize, 88 | ]) 89 | dataset = CIFAR10(root=data_dir, train=False, download=True, transform=transform_train) 90 | loader = DataLoader( 91 | dataset, batch_size=batch_size, 92 | shuffle=False, num_workers=num_workers, 93 | pin_memory=pin_memory, 94 | ) 95 | return loader 96 | 97 | def main(): 98 | global args, best_prec1 99 | args = parser.parse_args() 100 | 101 | # ensuring reproducibility 102 | SEED = 42 103 | torch.manual_seed(SEED) 104 | torch.backends.cudnn.benchmark = False 105 | 106 | kwargs = {'num_workers': 1, 'pin_memory': True} 107 | device = torch.device("cuda") 108 | 109 | num_epochs_transient = 2 110 | num_epochs_steady = 7 111 | perc_to_remove = 10 112 | 113 | torch.manual_seed(SEED) 114 | 115 | # create model 116 | model = WideResNet(args.layers, 10, args.widen_factor, dropRate=args.droprate).to(device) 117 | 118 | optimizer = torch.optim.Adam( 119 | model.parameters(), 120 | args.learning_rate, 121 | weight_decay=args.weight_decay 122 | ) 123 | 124 | # instantiate loaders 125 | train_loader = get_data_loader(args.data_dir, args.batch_size, **kwargs) 126 | test_loader = get_test_loader(args.data_dir, 128, **kwargs) 127 | 128 | tic = time.time() 129 | seen_losses = None 130 | for epoch in range(1, 3): 131 | if epoch == 1: 132 | seen_losses = train_transient(model, device, train_loader, optimizer, epoch, track=True) 133 | else: 134 | train_transient(model, device, train_loader, optimizer, epoch) 135 | test(model, device, test_loader, epoch) 136 | 137 | for epoch in range(3, 4): 138 | seen_losses = [v for sublist in seen_losses for v in sublist] 139 | sorted_loss_idx = 
sorted(range(len(seen_losses)), key=lambda k: seen_losses[k][1], reverse=True) 140 | removed = sorted_loss_idx[-int((perc_to_remove / 100) * len(sorted_loss_idx)):] 141 | sorted_loss_idx = sorted_loss_idx[:-int((perc_to_remove / 100) * len(sorted_loss_idx))] 142 | to_add = list(np.random.choice(removed, int(0.33*len(sorted_loss_idx)), replace=False)) 143 | sorted_loss_idx = sorted_loss_idx + to_add 144 | sorted_loss_idx.sort() 145 | weights = [seen_losses[idx][1] for idx in sorted_loss_idx] 146 | train_loader = get_weighted_loader(args.data_dir, 64*2, weights, **kwargs) 147 | seen_losses = train_steady_state(model, device, train_loader, optimizer, epoch) 148 | test(model, device, test_loader, epoch) 149 | 150 | for epoch in range(4, 8): 151 | train_transient(model, device, train_loader, optimizer, epoch) 152 | test(model, device, test_loader, epoch) 153 | toc = time.time() 154 | print("Time Elapsed: {}s".format(toc-tic)) 155 | 156 | 157 | def train_transient(model, device, train_loader, optimizer, epoch, track=False): 158 | """Train for one epoch on the training set""" 159 | losses = AverageMeter() 160 | top1 = AverageMeter() 161 | 162 | # switch to train mode 163 | model.train() 164 | epoch_stats = [] 165 | 166 | for batch_idx, (data, target) in enumerate(train_loader): 167 | data, target = data.to(device), target.to(device) 168 | 169 | # compute output 170 | output = model(data) 171 | losses_ = F.nll_loss(output, target, reduction='none') 172 | 173 | if track: 174 | indices = [batch_idx*train_loader.batch_size + i for i in range(len(data))] 175 | batch_stats = [] 176 | for i, l in zip(indices, losses_): 177 | batch_stats.append([i, l.item()]) 178 | epoch_stats.append(batch_stats) 179 | 180 | loss = losses_.mean() 181 | 182 | # measure accuracy and record loss 183 | prec1 = accuracy(output, target, topk=(1,))[0] 184 | losses.update(loss.item(), data.size(0)) 185 | top1.update(prec1.item(), data.size(0)) 186 | 187 | # compute gradient and do SGD step 188 | optimizer.zero_grad() 189 | loss.backward() 190 | optimizer.step() 191 | 192 | if batch_idx % args.print_freq == 0: 193 | print('Epoch: [{0}][{1}/{2}]\t' 194 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 195 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 196 | epoch, batch_idx, len(train_loader), loss=losses, top1=top1)) 197 | if track: 198 | return epoch_stats 199 | return None 200 | 201 | def train_steady_state(model, device, train_loader, optimizer, epoch): 202 | """Train for one epoch on the training set""" 203 | losses = AverageMeter() 204 | top1 = AverageMeter() 205 | 206 | # switch to train mode 207 | model.train() 208 | epoch_stats = [] 209 | 210 | for batch_idx, (data, target) in enumerate(train_loader): 211 | data, target = data.to(device), target.to(device) 212 | 213 | # compute output 214 | output = model(data) 215 | losses_ = F.nll_loss(output, target, reduction='none') 216 | indices = [batch_idx*train_loader.batch_size + i for i in range(len(data))] 217 | batch_stats = [] 218 | for i, l in zip(indices, losses_): 219 | batch_stats.append([i, l.item()]) 220 | epoch_stats.append(batch_stats) 221 | 222 | loss = losses_.mean() 223 | 224 | # measure accuracy and record loss 225 | prec1 = accuracy(output, target, topk=(1,))[0] 226 | losses.update(loss.item(), data.size(0)) 227 | top1.update(prec1.item(), data.size(0)) 228 | 229 | # compute gradient and do SGD step 230 | optimizer.zero_grad() 231 | loss.backward() 232 | optimizer.step() 233 | 234 | if batch_idx % args.print_freq == 0: 235 | print('Epoch: [{0}][{1}/{2}]\t' 236 | 
'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 237 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 238 | epoch, batch_idx, len(train_loader), loss=losses, top1=top1)) 239 | return epoch_stats 240 | 241 | def test(model, device, test_loader, epoch): 242 | losses = AverageMeter() 243 | top1 = AverageMeter() 244 | 245 | # switch to evaluate mode 246 | model.eval() 247 | 248 | for batch_idx, (data, target) in enumerate(test_loader): 249 | data, target = data.to(device), target.to(device) 250 | 251 | # compute output 252 | with torch.no_grad(): 253 | output = model(data) 254 | loss = F.nll_loss(output, target) 255 | 256 | # measure accuracy and record loss 257 | prec1 = accuracy(output, target, topk=(1,))[0] 258 | losses.update(loss.item(), data.size(0)) 259 | top1.update(prec1.item(), data.size(0)) 260 | 261 | if batch_idx % args.print_freq == 0: 262 | print('Test: [{0}/{1}]\t' 263 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 264 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 265 | batch_idx, len(test_loader), loss=losses, 266 | top1=top1)) 267 | 268 | print(' * Prec@1 {top1.avg:.3f}'.format(top1=top1)) 269 | return top1.avg 270 | 271 | class AverageMeter(object): 272 | """Computes and stores the average and current value""" 273 | def __init__(self): 274 | self.reset() 275 | 276 | def reset(self): 277 | self.val = 0 278 | self.avg = 0 279 | self.sum = 0 280 | self.count = 0 281 | 282 | def update(self, val, n=1): 283 | self.val = val 284 | self.sum += val * n 285 | self.count += n 286 | self.avg = self.sum / self.count 287 | 288 | 289 | def accuracy(output, target, topk=(1,)): 290 | """Computes the precision@k for the specified values of k""" 291 | maxk = max(topk) 292 | batch_size = target.size(0) 293 | 294 | _, pred = output.topk(maxk, 1, True, True) 295 | pred = pred.t() 296 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 297 | 298 | res = [] 299 | for k in topk: 300 | correct_k = correct[:k].view(-1).float().sum(0) 301 | res.append(correct_k.mul_(100.0 / batch_size)) 302 | return res 303 | 304 | if __name__ == '__main__': 305 | main() 306 | -------------------------------------------------------------------------------- /pr-lr/cifar_reg.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | import time 5 | import numpy as np 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.nn.parallel 11 | import torch.backends.cudnn as cudnn 12 | import torch.optim 13 | import torch.utils.data 14 | import torchvision.datasets as datasets 15 | from torchvision.datasets import CIFAR10 16 | from torchvision import transforms 17 | from torch.utils.data import DataLoader 18 | from torch.utils.data.sampler import Sampler, WeightedRandomSampler 19 | 20 | from models import WideResNet 21 | 22 | # used for logging to TensorBoard 23 | from tensorboard_logger import configure, log_value 24 | 25 | parser = argparse.ArgumentParser(description='PyTorch WideResNet Training') 26 | parser.add_argument('--data_dir', default='./data/', type=str, 27 | help='data path') 28 | parser.add_argument('--epochs', default=10, type=int, 29 | help='number of total epochs to run') 30 | parser.add_argument('--batch-size', default=128, type=int, 31 | help='mini-batch size (default: 128)') 32 | parser.add_argument('--learning-rate', default=1e-3, type=float, 33 | help='initial learning rate') 34 | parser.add_argument('--weight-decay', '--wd', default=0, type=float, #5e-4 35 | help='weight decay 
(default: 5e-4)') 36 | parser.add_argument('--print-freq', '-p', default=10, type=int, 37 | help='print frequency (default: 10)') 38 | parser.add_argument('--layers', default=28, type=int, 39 | help='total number of layers (default: 28)') 40 | parser.add_argument('--widen-factor', default=10, type=int, 41 | help='widen factor (default: 10)') 42 | parser.add_argument('--droprate', default=0, type=float, 43 | help='dropout probability (default: 0.0)') 44 | parser.add_argument('--no-augment', dest='augment', action='store_false', 45 | help='whether to use standard augmentation (default: True)') 46 | parser.add_argument('--name', default='WideResNet-28-10', type=str, 47 | help='name of experiment') 48 | 49 | best_prec1 = 0 50 | 51 | def get_data_loader(data_dir, batch_size, num_workers=3, pin_memory=False): 52 | normalize = transforms.Normalize(mean=[x/255.0 for x in [125.3, 123.0, 113.9]], 53 | std=[x/255.0 for x in [63.0, 62.1, 66.7]]) 54 | transform_train = transforms.Compose([ 55 | transforms.ToTensor(), 56 | normalize, 57 | ]) 58 | dataset = CIFAR10(data_dir, train=True, download=True, transform=transform_train) 59 | loader = DataLoader( 60 | dataset, batch_size=batch_size, 61 | shuffle=False, num_workers=num_workers, 62 | pin_memory=pin_memory, 63 | ) 64 | return loader 65 | 66 | def get_weighted_loader(data_dir, batch_size, weights, num_workers=3, pin_memory=False): 67 | normalize = transforms.Normalize(mean=[x/255.0 for x in [125.3, 123.0, 113.9]], 68 | std=[x/255.0 for x in [63.0, 62.1, 66.7]]) 69 | transform_train = transforms.Compose([ 70 | transforms.ToTensor(), 71 | normalize, 72 | ]) 73 | dataset = CIFAR10(data_dir, train=True, download=True, transform=transform_train) 74 | sampler = WeightedRandomSampler(weights, len(weights), True) 75 | loader = DataLoader( 76 | dataset, batch_size=batch_size, 77 | shuffle=False, num_workers=num_workers, 78 | pin_memory=pin_memory, sampler=sampler 79 | ) 80 | return loader 81 | 82 | def get_test_loader(data_dir, batch_size, num_workers=3, pin_memory=False): 83 | normalize = transforms.Normalize(mean=[x/255.0 for x in [125.3, 123.0, 113.9]], 84 | std=[x/255.0 for x in [63.0, 62.1, 66.7]]) 85 | transform_train = transforms.Compose([ 86 | transforms.ToTensor(), 87 | normalize, 88 | ]) 89 | dataset = CIFAR10(root=data_dir, train=False, download=True, transform=transform_train) 90 | loader = DataLoader( 91 | dataset, batch_size=batch_size, 92 | shuffle=False, num_workers=num_workers, 93 | pin_memory=pin_memory, 94 | ) 95 | return loader 96 | 97 | def main(): 98 | global args, best_prec1 99 | args = parser.parse_args() 100 | 101 | # ensuring reproducibility 102 | SEED = 42 103 | torch.manual_seed(SEED) 104 | torch.backends.cudnn.benchmark = False 105 | 106 | kwargs = {'num_workers': 1, 'pin_memory': True} 107 | device = torch.device("cuda") 108 | 109 | num_epochs = 7 110 | 111 | # create model 112 | model = WideResNet(args.layers, 10, args.widen_factor, dropRate=args.droprate).to(device) 113 | 114 | optimizer = torch.optim.Adam( 115 | model.parameters(), 116 | args.learning_rate, 117 | weight_decay=args.weight_decay 118 | ) 119 | 120 | # instantiate loaders 121 | train_loader = get_data_loader(args.data_dir, args.batch_size, **kwargs) 122 | test_loader = get_test_loader(args.data_dir, 128, **kwargs) 123 | 124 | tic = time.time() 125 | for epoch in range(1, num_epochs+1): 126 | train(model, device, train_loader, optimizer, epoch) 127 | test(model, device, test_loader, epoch) 128 | toc = time.time() 129 | print("Time Elapsed: {}s".format(toc-tic)) 130 
| 131 | 132 | def train(model, device, train_loader, optimizer, epoch): 133 | """Train for one epoch on the training set""" 134 | losses = AverageMeter() 135 | top1 = AverageMeter() 136 | 137 | # switch to train mode 138 | model.train() 139 | for batch_idx, (data, target) in enumerate(train_loader): 140 | data, target = data.to(device), target.to(device) 141 | 142 | # compute output 143 | output = model(data) 144 | loss = F.nll_loss(output, target) 145 | 146 | # measure accuracy and record loss 147 | prec1 = accuracy(output, target, topk=(1,))[0] 148 | losses.update(loss.item(), data.size(0)) 149 | top1.update(prec1.item(), data.size(0)) 150 | 151 | # compute gradient and do SGD step 152 | optimizer.zero_grad() 153 | loss.backward() 154 | optimizer.step() 155 | 156 | if batch_idx % args.print_freq == 0: 157 | print('Epoch: [{0}][{1}/{2}]\t' 158 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 159 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 160 | epoch, batch_idx, len(train_loader), loss=losses, top1=top1)) 161 | 162 | def test(model, device, test_loader, epoch): 163 | losses = AverageMeter() 164 | top1 = AverageMeter() 165 | 166 | # switch to evaluate mode 167 | model.eval() 168 | 169 | for batch_idx, (data, target) in enumerate(test_loader): 170 | data, target = data.to(device), target.to(device) 171 | 172 | # compute output 173 | with torch.no_grad(): 174 | output = model(data) 175 | loss = F.nll_loss(output, target) 176 | 177 | # measure accuracy and record loss 178 | prec1 = accuracy(output, target, topk=(1,))[0] 179 | losses.update(loss.item(), data.size(0)) 180 | top1.update(prec1.item(), data.size(0)) 181 | 182 | if batch_idx % args.print_freq == 0: 183 | print('Test: [{0}/{1}]\t' 184 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 185 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 186 | batch_idx, len(test_loader), loss=losses, 187 | top1=top1)) 188 | 189 | print(' * Prec@1 {top1.avg:.3f}'.format(top1=top1)) 190 | return top1.avg 191 | 192 | class AverageMeter(object): 193 | """Computes and stores the average and current value""" 194 | def __init__(self): 195 | self.reset() 196 | 197 | def reset(self): 198 | self.val = 0 199 | self.avg = 0 200 | self.sum = 0 201 | self.count = 0 202 | 203 | def update(self, val, n=1): 204 | self.val = val 205 | self.sum += val * n 206 | self.count += n 207 | self.avg = self.sum / self.count 208 | 209 | 210 | def accuracy(output, target, topk=(1,)): 211 | """Computes the precision@k for the specified values of k""" 212 | maxk = max(topk) 213 | batch_size = target.size(0) 214 | 215 | _, pred = output.topk(maxk, 1, True, True) 216 | pred = pred.t() 217 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 218 | 219 | res = [] 220 | for k in topk: 221 | correct_k = correct[:k].view(-1).float().sum(0) 222 | res.append(correct_k.mul_(100.0 / batch_size)) 223 | return res 224 | 225 | if __name__ == '__main__': 226 | main() 227 | -------------------------------------------------------------------------------- /pr-lr/classic.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | import torch.nn.functional as F 7 | 8 | from torchvision.datasets import MNIST 9 | from torchvision import transforms 10 | from torch.utils.data import DataLoader 11 | from torch.utils.data.sampler import Sampler, WeightedRandomSampler 12 | 13 | 14 | def get_data_loader(data_dir, batch_size, num_workers=3, pin_memory=False): 15 | 
normalize = transforms.Normalize(mean=(0.1307,), std=(0.3081,)) 16 | transform = transforms.Compose([transforms.ToTensor(), normalize]) 17 | dataset = MNIST(root=data_dir, train=True, download=True, transform=transform) 18 | 19 | loader = DataLoader( 20 | dataset, batch_size=batch_size, 21 | shuffle=False, num_workers=num_workers, 22 | pin_memory=pin_memory, 23 | ) 24 | 25 | return loader 26 | 27 | 28 | def get_test_loader(data_dir, batch_size, num_workers=3, pin_memory=False): 29 | normalize = transforms.Normalize(mean=(0.1307,), std=(0.3081,)) 30 | transform = transforms.Compose([transforms.ToTensor(), normalize]) 31 | dataset = MNIST(root=data_dir, train=False, download=True, transform=transform) 32 | loader = DataLoader( 33 | dataset, batch_size=batch_size, 34 | shuffle=False, num_workers=num_workers, 35 | pin_memory=pin_memory, 36 | ) 37 | return loader 38 | 39 | 40 | class SmallConv(nn.Module): 41 | def __init__(self): 42 | super(SmallConv, self).__init__() 43 | self.conv1 = nn.Conv2d(1, 10, kernel_size=5) 44 | self.conv2 = nn.Conv2d(10, 20, kernel_size=5) 45 | self.fc1 = nn.Linear(320, 50) 46 | self.fc2 = nn.Linear(50, 10) 47 | 48 | def forward(self, x): 49 | out = F.relu(F.max_pool2d(self.conv1(x), 2)) 50 | out = F.relu(F.max_pool2d(self.conv2(out), 2)) 51 | out = out.view(-1, 320) 52 | out = F.relu(self.fc1(out)) 53 | out = self.fc2(out) 54 | return F.log_softmax(out, dim=1) 55 | 56 | 57 | def accuracy(predicted, ground_truth): 58 | predicted = torch.max(predicted, 1)[1] 59 | total = len(ground_truth) 60 | correct = (predicted == ground_truth).sum().double() 61 | acc = 100 * (correct / total) 62 | return acc.item() 63 | 64 | 65 | def train(model, device, train_loader, optimizer, epoch): 66 | model.train() 67 | for batch_idx, (data, target) in enumerate(train_loader): 68 | data, target = data.to(device), target.to(device) 69 | optimizer.zero_grad() 70 | output = model(data) 71 | acc = accuracy(output, target) 72 | loss = F.nll_loss(output, target) 73 | loss.backward() 74 | optimizer.step() 75 | if batch_idx % 25 == 0: 76 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAcc: {:.2f}%'.format( 77 | epoch, batch_idx * len(data), len(train_loader.dataset), 78 | 100. * batch_idx / len(train_loader), loss.item(), acc)) 79 | 80 | 81 | def test(model, device, test_loader): 82 | model.eval() 83 | test_loss = 0 84 | correct = 0 85 | with torch.no_grad(): 86 | for data, target in test_loader: 87 | data, target = data.to(device), target.to(device) 88 | output = model(data) 89 | test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss 90 | pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability 91 | correct += pred.eq(target.view_as(pred)).sum().item() 92 | 93 | test_loss /= len(test_loader.dataset) 94 | print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( 95 | test_loss, correct, len(test_loader.dataset), 96 | 100. 
* correct / len(test_loader.dataset)))
97 | 
98 | 
99 | def main():
100 |     data_dir = './data/'
101 |     plot_dir = './imgs/'
102 |     dump_dir = './dump/'
103 | 
104 |     # ensuring reproducibility
105 |     SEED = 42
106 |     torch.manual_seed(SEED)
107 |     torch.backends.cudnn.benchmark = False
108 | 
109 |     GPU = False
110 |     device = torch.device("cuda" if GPU else "cpu")
111 |     kwargs = {'num_workers': 1, 'pin_memory': True} if GPU else {}
112 | 
113 |     num_epochs = 5
114 |     learning_rate = 1e-3
115 |     mom = 0.99
116 |     batch_size = 64
117 | 
118 |     # instantiate convnet
119 |     model = SmallConv().to(device)
120 | 
121 |     # relu init
122 |     for m in model.modules():
123 |         if isinstance(m, (nn.Conv2d, nn.Linear)):
124 |             nn.init.kaiming_normal_(m.weight, mode='fan_in')
125 | 
126 |     # define optimizer
127 |     optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=mom)
128 | 
129 |     # instantiate loaders
130 |     train_loader = get_data_loader(data_dir, batch_size, **kwargs)
131 |     test_loader = get_test_loader(data_dir, 128, **kwargs)
132 | 
133 |     # training
134 |     tic = time.time()
135 |     for epoch in range(1, num_epochs+1):
136 |         train(model, device, train_loader, optimizer, epoch)
137 |         test(model, device, test_loader)
138 |     toc = time.time()
139 |     print("Time Elapsed: {}s".format(toc-tic))
140 | 
141 | 
142 | if __name__ == '__main__':
143 |     main()
144 | 
--------------------------------------------------------------------------------
/pr-lr/imgs/loss_vs_grad.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kevinzakka/blog-code/1fb5ff808549cdb1bb0bd567e4d27d5ccff74aea/pr-lr/imgs/loss_vs_grad.jpg
--------------------------------------------------------------------------------
/pr-lr/imgs/no_shuffling.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kevinzakka/blog-code/1fb5ff808549cdb1bb0bd567e4d27d5ccff74aea/pr-lr/imgs/no_shuffling.jpg
--------------------------------------------------------------------------------
/pr-lr/imgs/with_shuffling.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kevinzakka/blog-code/1fb5ff808549cdb1bb0bd567e4d27d5ccff74aea/pr-lr/imgs/with_shuffling.jpg
--------------------------------------------------------------------------------
/pr-lr/mini_batch.py:
--------------------------------------------------------------------------------
1 | import time
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torch.optim as optim
6 | import torch.nn.functional as F
7 | 
8 | from torchvision.datasets import MNIST
9 | from torchvision import transforms
10 | from torch.utils.data import DataLoader
11 | from torch.utils.data.sampler import Sampler, WeightedRandomSampler
12 | 
13 | 
14 | def get_data_loader(data_dir, batch_size, permutation=None, num_workers=3, pin_memory=False):
15 |     normalize = transforms.Normalize(mean=(0.1307,), std=(0.3081,))
16 |     transform = transforms.Compose([transforms.ToTensor(), normalize])
17 |     dataset = MNIST(root=data_dir, train=True, download=True, transform=transform)
18 | 
19 |     sampler = None
20 |     if permutation is not None:
21 |         sampler = LinearSampler(permutation)
22 | 
23 |     loader = DataLoader(
24 |         dataset, batch_size=batch_size,
25 |         shuffle=False, num_workers=num_workers,
26 |         pin_memory=pin_memory, sampler=sampler
27 |     )
28 | 
29 |     return loader
30 | 
31 | 
32 | def
get_weighted_loader(data_dir, batch_size, weights, num_workers=3, pin_memory=False): 33 | normalize = transforms.Normalize(mean=(0.1307,), std=(0.3081,)) 34 | transform = transforms.Compose([transforms.ToTensor(), normalize]) 35 | dataset = MNIST(root=data_dir, train=True, download=True, transform=transform) 36 | 37 | sampler = WeightedRandomSampler(weights, len(weights), True) 38 | 39 | loader = DataLoader( 40 | dataset, batch_size=batch_size, 41 | shuffle=False, num_workers=num_workers, 42 | pin_memory=pin_memory, sampler=sampler 43 | ) 44 | 45 | return loader 46 | 47 | 48 | def get_test_loader(data_dir, batch_size, num_workers=3, pin_memory=False): 49 | normalize = transforms.Normalize(mean=(0.1307,), std=(0.3081,)) 50 | transform = transforms.Compose([transforms.ToTensor(), normalize]) 51 | dataset = MNIST(root=data_dir, train=False, download=True, transform=transform) 52 | loader = DataLoader( 53 | dataset, batch_size=batch_size, 54 | shuffle=False, num_workers=num_workers, 55 | pin_memory=pin_memory, 56 | ) 57 | return loader 58 | 59 | 60 | class LinearSampler(Sampler): 61 | def __init__(self, idx): 62 | self.idx = idx 63 | 64 | def __iter__(self): 65 | return iter(self.idx) 66 | 67 | def __len__(self): 68 | return len(self.idx) 69 | 70 | 71 | class SmallConv(nn.Module): 72 | def __init__(self): 73 | super(SmallConv, self).__init__() 74 | self.conv1 = nn.Conv2d(1, 10, kernel_size=5) 75 | self.conv2 = nn.Conv2d(10, 20, kernel_size=5) 76 | self.fc1 = nn.Linear(320, 50) 77 | self.fc2 = nn.Linear(50, 10) 78 | 79 | def forward(self, x): 80 | out = F.relu(F.max_pool2d(self.conv1(x), 2)) 81 | out = F.relu(F.max_pool2d(self.conv2(out), 2)) 82 | out = out.view(-1, 320) 83 | out = F.relu(self.fc1(out)) 84 | out = self.fc2(out) 85 | return F.log_softmax(out, dim=1) 86 | 87 | 88 | def accuracy(predicted, ground_truth): 89 | predicted = torch.max(predicted, 1)[1] 90 | total = len(ground_truth) 91 | correct = (predicted == ground_truth).sum().double() 92 | acc = 100 * (correct / total) 93 | return acc.item() 94 | 95 | 96 | def train_transient(model, device, train_loader, optimizer, epoch, track=False): 97 | model.train() 98 | epoch_stats = [] 99 | for batch_idx, (data, target) in enumerate(train_loader): 100 | data, target = data.to(device), target.to(device) 101 | optimizer.zero_grad() 102 | output = model(data) 103 | acc = accuracy(output, target) 104 | losses = F.nll_loss(output, target, reduction='none') 105 | if track: 106 | indices = [batch_idx*train_loader.batch_size + i for i in range(len(data))] 107 | batch_stats = [] 108 | for i, l in zip(indices, losses): 109 | batch_stats.append([i, l.item()]) 110 | epoch_stats.append(batch_stats) 111 | loss = losses.mean() 112 | loss.backward() 113 | optimizer.step() 114 | if batch_idx % 25 == 0: 115 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAcc: {:.2f}%'.format( 116 | epoch, batch_idx * len(data), len(train_loader.dataset), 117 | 100. 
* batch_idx / len(train_loader), loss.item(), acc)) 118 | if track: 119 | return epoch_stats 120 | return None 121 | 122 | 123 | def train_steady_state(model, device, train_loader, optimizer, epoch): 124 | model.train() 125 | epoch_stats = [] 126 | for batch_idx, (data, target) in enumerate(train_loader): 127 | data, target = data.to(device), target.to(device) 128 | optimizer.zero_grad() 129 | output = model(data) 130 | acc = accuracy(output, target) 131 | losses = F.nll_loss(output, target, reduction='none') 132 | indices = [batch_idx*train_loader.batch_size + i for i in range(len(data))] 133 | batch_stats = [] 134 | for i, l in zip(indices, losses): 135 | batch_stats.append([i, l.item()]) 136 | epoch_stats.append(batch_stats) 137 | loss = losses.mean() 138 | loss.backward() 139 | optimizer.step() 140 | if batch_idx % 25 == 0: 141 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAcc: {:.2f}%'.format( 142 | epoch, batch_idx * len(data), len(train_loader.dataset), 143 | 100. * batch_idx / len(train_loader), loss.item(), acc)) 144 | return epoch_stats 145 | 146 | 147 | def test(model, device, test_loader): 148 | model.eval() 149 | test_loss = 0 150 | correct = 0 151 | with torch.no_grad(): 152 | for data, target in test_loader: 153 | data, target = data.to(device), target.to(device) 154 | output = model(data) 155 | test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss 156 | pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability 157 | correct += pred.eq(target.view_as(pred)).sum().item() 158 | 159 | test_loss /= len(test_loader.dataset) 160 | print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( 161 | test_loss, correct, len(test_loader.dataset), 162 | 100. * correct / len(test_loader.dataset))) 163 | 164 | 165 | def main(): 166 | data_dir = './data/' 167 | plot_dir = './imgs/' 168 | dump_dir = './dump/' 169 | 170 | # ensuring reproducibility 171 | SEED = 42 172 | torch.manual_seed(SEED) 173 | torch.backends.cudnn.benchmark = False 174 | 175 | GPU = False 176 | device = torch.device("cuda" if GPU else "cpu") 177 | kwargs = {'num_workers': 1, 'pin_memory': True} if GPU else {} 178 | 179 | num_epochs_transient = 2 180 | num_epochs_steady = 3 181 | learning_rate = 1e-3 182 | mom = 0.99 183 | batch_size = 64 184 | normalize = False 185 | perc_to_remove = 10 186 | 187 | torch.manual_seed(SEED) 188 | 189 | # instantiate convnet 190 | model = SmallConv().to(device) 191 | 192 | # relu init 193 | for m in model.modules(): 194 | if isinstance(m, (nn.Conv2d, nn.Linear)): 195 | nn.init.kaiming_normal_(m.weight, mode='fan_in') 196 | 197 | # define optimizer 198 | optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=mom) 199 | 200 | # instantiate loaders 201 | train_loader = get_data_loader(data_dir, batch_size, None, **kwargs) 202 | test_loader = get_test_loader(data_dir, 128, **kwargs) 203 | 204 | # transient training 205 | tic = time.time() 206 | losses = None 207 | for epoch in range(1, num_epochs_transient+1): 208 | if epoch == 1: 209 | losses = train_transient(model, device, train_loader, optimizer, epoch, track=True) 210 | else: 211 | train_transient(model, device, train_loader, optimizer, epoch) 212 | test(model, device, test_loader) 213 | 214 | for epoch in range(num_epochs_transient, num_epochs_steady+1): 215 | losses = [v for sublist in losses for v in sublist] 216 | sorted_loss_idx = sorted(range(len(losses)), key=lambda k: losses[k][1], reverse=True) 217 | removed = 
sorted_loss_idx[-int((perc_to_remove / 100) * len(sorted_loss_idx)):]  # tail = lowest-loss examples (list is sorted by loss, descending)
218 |         sorted_loss_idx = sorted_loss_idx[:-int((perc_to_remove / 100) * len(sorted_loss_idx))]  # drop that low-loss tail
219 |         to_add = list(np.random.choice(removed, int(0.01*len(sorted_loss_idx)), replace=False))  # re-admit a random 1% of the dropped examples
220 |         sorted_loss_idx = sorted_loss_idx + to_add
221 |         sorted_loss_idx.sort()
222 |         weights = [losses[idx][1] for idx in sorted_loss_idx]  # per-example losses double as sampling weights
223 |         if normalize:
224 |             max_w = max(weights)
225 |             weights = [w / max_w for w in weights]
226 |         train_loader = get_weighted_loader(data_dir, 128, weights, **kwargs)
227 |         print("\t[*] Effective Size: {:,}".format(len(train_loader.sampler)))
228 |         losses = train_steady_state(model, device, train_loader, optimizer, epoch)
229 |         test(model, device, test_loader)
230 |     toc = time.time()
231 |     print("Time Elapsed: {}s".format(toc-tic))
232 | 
233 | 
234 | if __name__ == '__main__':
235 |     main()
236 | 
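One caveat worth flagging in the weighting step above: WeightedRandomSampler draws indices in [0, len(weights)), so when weights only covers the ~90% of examples that were kept, the sampled indices no longer line up with the dataset positions the losses were recorded at. A sketch of one way to keep that alignment, using a full-length weight vector (same variable names as in main(); an illustration, not the script's actual behavior):

    # zero weight for dropped examples, loss-proportional weight for kept ones
    full_weights = [0.0] * len(losses)
    for idx in sorted_loss_idx:
        full_weights[idx] = losses[idx][1]
    # draws len(sorted_loss_idx) dataset indices per epoch, with replacement;
    # dropped examples now have zero probability of being sampled
    sampler = WeightedRandomSampler(full_weights, len(sorted_loss_idx), True)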
--------------------------------------------------------------------------------
/pr-lr/models.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | 
6 | 
7 | class BasicBlock(nn.Module):
8 |     def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
9 |         super(BasicBlock, self).__init__()
10 |         self.bn1 = nn.BatchNorm2d(in_planes)
11 |         self.relu1 = nn.ReLU(inplace=True)
12 |         self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
13 |                                padding=1, bias=False)
14 |         self.bn2 = nn.BatchNorm2d(out_planes)
15 |         self.relu2 = nn.ReLU(inplace=True)
16 |         self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
17 |                                padding=1, bias=False)
18 |         self.droprate = dropRate
19 |         self.equalInOut = (in_planes == out_planes)
20 |         self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
21 |                                                                 padding=0, bias=False) or None
22 |     def forward(self, x):
23 |         if not self.equalInOut:
24 |             x = self.relu1(self.bn1(x))
25 |         else:
26 |             out = self.relu1(self.bn1(x))
27 |         out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))
28 |         if self.droprate > 0:
29 |             out = F.dropout(out, p=self.droprate, training=self.training)
30 |         out = self.conv2(out)
31 |         return torch.add(x if self.equalInOut else self.convShortcut(x), out)
32 | 
33 | class NetworkBlock(nn.Module):
34 |     def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0):
35 |         super(NetworkBlock, self).__init__()
36 |         self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate)
37 |     def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate):
38 |         layers = []
39 |         for i in range(int(nb_layers)):
40 |             layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate))
41 |         return nn.Sequential(*layers)
42 |     def forward(self, x):
43 |         return self.layer(x)
44 | 
45 | class WideResNet(nn.Module):
46 |     def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0):
47 |         super(WideResNet, self).__init__()
48 |         nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor]
49 |         assert((depth - 4) % 6 == 0)
50 |         n = (depth - 4) // 6  # number of BasicBlocks per group
51 |         block = BasicBlock
52 |         # 1st conv before any network block
53 |         self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
54 |                                padding=1, bias=False)
55 |         # 1st block
56 |         self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate)
57 |         # 2nd block
58 |         self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate)
59 |         # 3rd block
60 |         self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate)
61 |         # global average pooling and classifier
62 |         self.bn1 = nn.BatchNorm2d(nChannels[3])
63 |         self.relu = nn.ReLU(inplace=True)
64 |         self.fc = nn.Linear(nChannels[3], num_classes)
65 |         self.nChannels = nChannels[3]
66 | 
67 |         for m in self.modules():
68 |             if isinstance(m, nn.Conv2d):
69 |                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
70 |                 m.weight.data.normal_(0, math.sqrt(2. / n))
71 |             elif isinstance(m, nn.BatchNorm2d):
72 |                 m.weight.data.fill_(1)
73 |                 m.bias.data.zero_()
74 |             elif isinstance(m, nn.Linear):
75 |                 m.bias.data.zero_()
76 |     def forward(self, x):
77 |         out = self.conv1(x)
78 |         out = self.block1(out)
79 |         out = self.block2(out)
80 |         out = self.block3(out)
81 |         out = self.relu(self.bn1(out))
82 |         out = F.avg_pool2d(out, 8)
83 |         out = out.view(-1, self.nChannels)
84 |         out = self.fc(out)
85 |         out = F.log_softmax(out, dim=1)
86 |         return out
87 | 
--------------------------------------------------------------------------------
/spatial-transformer/interpolation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | from utils import *
4 | 
5 | def main():
6 | 
7 |     DIMS = (500, 500)
8 |     data_path = './data/'
9 | 
10 |     # load 2 cat images
11 |     img1 = img_to_array(data_path + 'cat1.jpg', DIMS)
12 |     img2 = img_to_array(data_path + 'cat2.jpg', DIMS, view=True)
13 | 
14 |     # concat into tensor of shape (2, 500, 500, 3)
15 |     input_img = np.concatenate([img1, img2], axis=0)
16 | 
17 |     # dimension sanity check
18 |     print("Input Img Shape: {}".format(input_img.shape))
19 | 
20 |     # grab shape
21 |     B, H, W, C = input_img.shape
22 | 
23 |     # initialize theta to identity transform
24 |     M = np.array([[1., 0., 0.], [0., 1., 0.]])
25 | 
26 |     # repeat num_batch times
27 |     M = np.resize(M, (B, 2, 3))
28 | 
29 |     # get grids
30 |     batch_grids = affine_grid_generator(H, W, M)
31 | 
32 |     x_s = batch_grids[:, :, :, 0:1].squeeze()
33 |     y_s = batch_grids[:, :, :, 1:2].squeeze()
34 | 
35 |     out = bilinear_sampler(input_img, x_s, y_s)
36 |     print("Out Img Shape: {}".format(out.shape))
37 | 
38 |     # view the 2nd image
39 |     array_to_img(out[-1]).show()
40 | 
41 | 
42 | if __name__ == '__main__':
43 |     main()
44 | 
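interpolation.py only exercises the identity transform. For a non-trivial check of the same pipeline, here is a sketch with a 45-degree rotation; it assumes the same ./data/cat1.jpg that interpolation.py loads, with everything else coming from utils.py:

    import numpy as np
    from utils import affine_grid_generator, bilinear_sampler, img_to_array, array_to_img

    img = img_to_array('./data/cat1.jpg', (500, 500))  # (1, 500, 500, 3)
    B, H, W, C = img.shape

    # rotation about the image center (the grid is already in [-1, 1] coords)
    theta = np.deg2rad(45.)
    M = np.array([[np.cos(theta), -np.sin(theta), 0.],
                  [np.sin(theta),  np.cos(theta), 0.]])
    M = np.resize(M, (B, 2, 3))

    batch_grids = affine_grid_generator(H, W, M)
    out = bilinear_sampler(img, batch_grids[..., 0], batch_grids[..., 1])
    array_to_img(out[0]).show()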
25 | """ 26 | # grab batch size 27 | num_batch = M.shape[0] 28 | 29 | # create normalized 2D grid 30 | x = np.linspace(-1, 1, width) 31 | y = np.linspace(-1, 1, height) 32 | x_t, y_t = np.meshgrid(x, y) 33 | 34 | # reshape to (xt, yt, 1) 35 | ones = np.ones(np.prod(x_t.shape)) 36 | sampling_grid = np.vstack([x_t.flatten(), y_t.flatten(), ones]) 37 | # homogeneous coordinates 38 | 39 | # repeat grid num_batch times 40 | sampling_grid = np.resize(sampling_grid, (num_batch, 3, height*width)) 41 | 42 | # transform the sampling grid - batch multiply 43 | batch_grids = np.matmul(M, sampling_grid) 44 | # batch grid has shape (num_batch, 2, H*W) 45 | 46 | # reshape to (num_batch, H, W, 2) 47 | batch_grids = batch_grids.reshape(num_batch, 2, height, width) 48 | batch_grids = np.moveaxis(batch_grids, 1, -1) 49 | 50 | # sanity check 51 | print("Transformation Matrices: {}".format(M.shape)) 52 | print("Sampling Grid: {}".format(sampling_grid.shape)) 53 | print("Batch Grids: {}".format(batch_grids.shape)) 54 | 55 | return batch_grids 56 | 57 | 58 | def bilinear_sampler(input_img, x, y): 59 | """ 60 | Performs bilinear sampling of the input images according to the 61 | normalized coordinates provided by the sampling grid. Note that 62 | the sampling is done identically for each channel of the input. 63 | 64 | To test if the function works properly, output image should be 65 | identical to input image when theta is initialized to identity 66 | transform. 67 | 68 | Input 69 | ----- 70 | - input_imgs: batch of images in (B, H, W, C) layout. 71 | - grid: x, y which is the output of affine_grid_generator. 72 | 73 | Returns 74 | ------- 75 | - interpolated images according to grids. Same size as grid. 76 | """ 77 | # grab dimensions 78 | B, H, W, C = input_img.shape 79 | 80 | max_y = (H - 1) 81 | max_x = (W - 1) 82 | 83 | x = x.astype(np.float32) 84 | y = y.astype(np.float32) 85 | 86 | # rescale x and y to [0, W/H] 87 | x = ((x + 1.) * max_x) * 0.5 88 | y = ((y + 1.) * max_y) * 0.5 89 | 90 | # grab 4 nearest corner points for each (x_i, y_i) 91 | x0 = np.floor(x).astype(np.int64) 92 | x1 = x0 + 1 93 | y0 = np.floor(y).astype(np.int64) 94 | y1 = y0 + 1 95 | 96 | # calculate deltas 97 | wa = (x1-x) * (y1-y) 98 | wb = (x1-x) * (y-y0) 99 | wc = (x-x0) * (y1-y) 100 | wd = (x-x0) * (y-y0) 101 | 102 | x0 = x0.astype(np.int32) 103 | y0 = y0.astype(np.int32) 104 | x1 = x1.astype(np.int32) 105 | y1 = y1.astype(np.int32) 106 | 107 | # make sure it's inside img range [0, H] or [0, W] 108 | x0 = np.clip(x0, 0, max_x) 109 | x1 = np.clip(x1, 0, max_x) 110 | y0 = np.clip(y0, 0, max_y) 111 | y1 = np.clip(y1, 0, max_y) 112 | 113 | # look up pixel values at corner coords 114 | Ia = input_img[np.arange(B)[:, None, None], y0, x0] 115 | Ib = input_img[np.arange(B)[:, None, None], y1, x0] 116 | Ic = input_img[np.arange(B)[:, None, None], y0, x1] 117 | Id = input_img[np.arange(B)[:, None, None], y1, x1] 118 | 119 | # add dimension for addition 120 | wa = np.expand_dims(wa, axis=3) 121 | wb = np.expand_dims(wb, axis=3) 122 | wc = np.expand_dims(wc, axis=3) 123 | wd = np.expand_dims(wd, axis=3) 124 | 125 | # compute output 126 | out = wa*Ia + wb*Ib + wc*Ic + wd*Id 127 | 128 | return out 129 | 130 | 131 | def img_to_array(data_path, desired_size=None, view=False): 132 | """ 133 | Util function for loading RGB image into 4D numpy array. 
131 | def img_to_array(data_path, desired_size=None, view=False):
132 |     """
133 |     Util function for loading an RGB image into a 4D numpy array.
134 | 
135 |     Returns an array of shape (1, H, W, C) with values scaled to [0, 1].
136 | 
137 |     References
138 |     ----------
139 |     - adapted from keras preprocessing/image.py
140 |     """
141 |     img = Image.open(data_path)
142 |     img = img.convert('RGB')
143 |     if desired_size:
144 |         img = img.resize((desired_size[1], desired_size[0]))  # PIL expects (width, height)
145 |     if view:
146 |         img.show()
147 |     x = np.asarray(img, dtype='float32')
148 |     x = np.expand_dims(x, axis=0)
149 |     x /= 255.0
150 |     return x
151 | 
152 | 
153 | def array_to_img(x):
154 |     """
155 |     Util function for converting a numpy array of shape (H, W, C) to a PIL image.
156 | 
157 |     Returns PIL RGB image.
158 | 
159 |     References
160 |     ----------
161 |     - adapted from keras preprocessing/image.py
162 |     """
163 |     x = np.asarray(x)
164 |     x += max(-np.min(x), 0)  # shift so the minimum value is 0
165 |     x_max = np.max(x)
166 |     if x_max != 0:
167 |         x /= x_max
168 |     x *= 255  # rescale to [0, 255]
169 |     return Image.fromarray(x.astype('uint8'), 'RGB')
170 | 
--------------------------------------------------------------------------------
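Finally, a round trip through the two helpers above, as a usage sketch (the path matches the one interpolation.py expects):

    from utils import img_to_array, array_to_img

    arr = img_to_array('./data/cat1.jpg', desired_size=(500, 500))  # (1, 500, 500, 3), values in [0, 1]
    img = array_to_img(arr[0])  # renormalizes to [0, 255] and returns a PIL image
    img.show()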