├── 101
    ├── linear_model_from_np_to_torch.py
    ├── linear_model_np.py
    └── linear_model_torch.py
├── 301
    ├── model_trainer.py
    ├── plot_squares.py
    ├── seq2seq_simple.py
    ├── seq2seq_transformer.py
    ├── sequence_classification_model.py
    └── square_data_generation.py
├── 401
    └── SDE_diffusion_education.ipynb
├── 201_mnist
    ├── MNIST.ipynb
    ├── README.md
    ├── main.py
    └── requirements.txt
└── README.md

/101/linear_model_from_np_to_torch.py:
--------------------------------------------------------------------------------
1 | """
2 | Illustrative code on how to use PyTorch utilities to reduce the manual work of model fitting
3 | """
4 | 
5 | import numpy as np
6 | import torch
7 | 
8 | """
9 | Prepare Data
10 | """
11 | n = 1000
12 | rv = np.random.RandomState(0)
13 | x = rv.uniform(0, 1, [n, 1])
14 | y = 1 + 2 * x + 0.1 * rv.normal(0, 1, [n, 1])
15 | 
16 | idx = np.arange(n)
17 | rv.shuffle(idx)
18 | 
19 | train_idx = idx[:int(0.8 * n)]
20 | test_idx = idx[int(0.8 * n):]
21 | 
22 | x_train, y_train = x[train_idx], y[train_idx]
23 | x_test, y_test = x[test_idx], y[test_idx]
24 | 
25 | """
26 | PyTorch Training #1:
27 | Use autograd to reduce the work
28 | """
29 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
30 | torch.manual_seed(0)
31 | n_epochs = 1000
32 | lr = 1e-1
33 | 
34 | # When a new tensor is created, requires_grad defaults to False, so we need to set it to True explicitly.
35 | a = torch.randn(1, requires_grad=True, dtype=torch.float, device=device)
36 | b = torch.randn(1, requires_grad=True, dtype=torch.float, device=device)
37 | 
38 | x_train_tensor = torch.from_numpy(x_train).to(device, dtype=torch.float)
39 | y_train_tensor = torch.from_numpy(y_train).to(device, dtype=torch.float)
40 | 
41 | for epoch in range(n_epochs):
42 |     yhat = a + b * x_train_tensor
43 |     error = y_train_tensor - yhat
44 |     loss = (error ** 2).mean()
45 | 
46 |     # No more manual computation of gradients!
47 |     # a_grad = -2 * error.mean()
48 |     # b_grad = -2 * (x_tensor * error).mean()
49 |     # We just tell PyTorch to work its way BACKWARDS from the specified loss!
50 |     loss.backward()
51 | 
52 |     with torch.no_grad():
53 |         a -= lr * a.grad # <1>
54 |         b -= lr * b.grad
55 | 
56 |     a.grad.zero_()
57 |     b.grad.zero_()
58 | 
59 | print(a, b)
60 | 
61 | # <1> This line is different from a = a - lr*a.grad: the -= updates 'a' in place, whereas the latter creates a new 'a'.
62 | # When a new 'a' is created, we lose its gradients. It is also important to wrap the update in a
63 | # torch.no_grad() context so that the parameter update itself is not added to the gradient computation graph.
64 | 
65 | """
66 | PyTorch Training #2:
67 | Use an optimizer to reduce the work
68 | """
69 | import torch.optim as optim
70 | 
71 | optimizer = optim.SGD([a,b], lr=lr)
72 | 
73 | for epoch in range(n_epochs):
74 |     yhat = a + b * x_train_tensor
75 |     error = y_train_tensor - yhat
76 |     loss = (error ** 2).mean()
77 | 
78 |     # No more manual computation of gradients!
79 |     # a_grad = -2 * error.mean()
80 |     # b_grad = -2 * (x_tensor * error).mean()
81 |     # We just tell PyTorch to work its way BACKWARDS from the specified loss!
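    # Keep in mind that backward() ACCUMULATES gradients into a.grad and b.grad on every call,
    # so they must be reset once per iteration: Training #1 did this manually with a.grad.zero_()
    # and b.grad.zero_(), while here optimizer.zero_grad() (below) does the same job for us.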
82 |     loss.backward()
83 | 
84 |     # with torch.no_grad():
85 |     #     a -= lr * a.grad # <1>
86 |     #     b -= lr * b.grad
87 |     optimizer.step()
88 |     # a.grad.zero_()
89 |     # b.grad.zero_()
90 |     optimizer.zero_grad()
91 | 
92 | print(a, b)
93 | 
94 | """
95 | PyTorch Training #3:
96 | Use a loss function to reduce the work
97 | """
98 | import torch.nn as nn
99 | loss_fn = nn.MSELoss(reduction='mean')
100 | 
101 | for epoch in range(n_epochs):
102 |     yhat = a + b * x_train_tensor
103 | 
104 |     # error = y_train_tensor - yhat
105 |     # loss = (error ** 2).mean()
106 |     loss = loss_fn(y_train_tensor, yhat)
107 |     # No more manual computation of gradients!
108 |     # a_grad = -2 * error.mean()
109 |     # b_grad = -2 * (x_tensor * error).mean()
110 |     # We just tell PyTorch to work its way BACKWARDS from the specified loss!
111 |     loss.backward()
112 | 
113 |     # with torch.no_grad():
114 |     #     a -= lr * a.grad # <1>
115 |     #     b -= lr * b.grad
116 |     optimizer.step()
117 |     # a.grad.zero_()
118 |     # b.grad.zero_()
119 |     optimizer.zero_grad()
120 | 
121 | print(a, b)
122 | 
123 | """
124 | PyTorch Training #4:
125 | Use a model class to reduce the work
126 | """
127 | 
128 | class ManualLinearRegression(nn.Module):
129 |     def __init__(self):
130 |         super().__init__()
131 |         self.a = nn.Parameter(torch.randn(1, requires_grad=True, dtype=torch.float))
132 |         self.b = nn.Parameter(torch.randn(1, requires_grad=True, dtype=torch.float))
133 | 
134 |     def forward(self, x):
135 |         return self.a + self.b*x
136 | 
137 | # create the model and send it to the device
138 | model = ManualLinearRegression().to(device)
139 | optimizer = optim.SGD(model.parameters(), lr=lr)
140 | for epoch in range(n_epochs):
141 |     # yhat = a + b * x_train_tensor
142 |     model.train()
143 |     yhat = model(x_train_tensor)
144 |     # error = y_train_tensor - yhat
145 |     # loss = (error ** 2).mean()
146 |     loss = loss_fn(y_train_tensor, yhat)
147 |     # No more manual computation of gradients!
148 |     # a_grad = -2 * error.mean()
149 |     # b_grad = -2 * (x_tensor * error).mean()
150 |     # We just tell PyTorch to work its way BACKWARDS from the specified loss!
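    # A note on argument order: nn.MSELoss is conventionally called as loss_fn(prediction, target),
    # i.e. loss_fn(yhat, y_train_tensor). The squared error is symmetric, so the value here is the
    # same either way, but the (prediction, target) order does matter for asymmetric losses such as
    # nn.CrossEntropyLoss.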
151 |     loss.backward()
152 | 
153 |     # with torch.no_grad():
154 |     #     a -= lr * a.grad # <1>
155 |     #     b -= lr * b.grad
156 |     optimizer.step()
157 |     # a.grad.zero_()
158 |     # b.grad.zero_()
159 |     optimizer.zero_grad()
160 | 
161 | print(model.state_dict())
162 | 
163 | """
164 | PyTorch Training #5:
165 | Improve the model definition using PyTorch's built-in layers (nn.Linear)
166 | """
167 | 
168 | class LayerLinearRegression(nn.Module):
169 |     def __init__(self):
170 |         super().__init__()
171 |         self.linear = nn.Linear(1,1)
172 | 
173 |     def forward(self, x):
174 |         return self.linear(x)
175 | 
176 | model = LayerLinearRegression().to(device)
177 | optimizer = optim.SGD(model.parameters(), lr=lr)
178 | for epoch in range(n_epochs):
179 |     model.train()
180 |     yhat = model(x_train_tensor)
181 |     loss = loss_fn(y_train_tensor, yhat)
182 |     loss.backward()
183 |     optimizer.step()
184 |     optimizer.zero_grad()
185 | 
186 | print(model.state_dict())
187 | 
188 | """
189 | PyTorch Training #6:
190 | Put model, loss_fn, and optimizer into a train step function
191 | """
192 | 
193 | def make_train_step(model, loss_fn, optimizer):
194 |     def train_step(x,y):
195 |         model.train()
196 |         yhat = model(x)
197 |         loss = loss_fn(y, yhat)
198 |         loss.backward()
199 |         optimizer.step()
200 |         optimizer.zero_grad()
201 |         return loss.item()
202 |     return train_step
203 | 
204 | train_step = make_train_step(model,loss_fn, optimizer)
205 | 
206 | losses =[]
207 | 
208 | for epoch in range(n_epochs):
209 |     loss = train_step(x_train_tensor,y_train_tensor)
210 |     losses.append(loss)
211 | print(model.state_dict())
212 | 
213 | 
--------------------------------------------------------------------------------
/101/linear_model_np.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | """
4 | Data Generation
5 | """
6 | n = 100
7 | rv = np.random.RandomState(0) # use np.random.RandomState so the random state is local, reproducible, and thread-safe
8 | 
9 | x = rv.uniform(0,1,[n,1])
10 | y = 1 + 2*x + 0.1*rv.normal(0,1,[n,1])
11 | 
12 | idx = np.arange(n)
13 | rv.shuffle(idx)
14 | train_idx = idx[:int(0.8*n)]
15 | test_idx = idx[int(0.8*n):]
16 | x_train, y_train = x[train_idx], y[train_idx]
17 | x_test, y_test = x[test_idx], y[test_idx]
18 | 
19 | """
20 | Fit Model with NumPy
21 | """
22 | # initializes parameters "a" and "b" randomly
23 | a = rv.normal(0,1,1)
24 | b = rv.normal(0,1,1)
25 | 
26 | # define learning rate, number of epochs
27 | lr = 1e-1
28 | n_epochs = 1000
29 | 
30 | for epoch in range(n_epochs):
31 |     # forward pass to calculate the loss
32 |     yhat = a + b*x_train
33 |     error = y_train - yhat
34 |     loss = (error**2).mean()
35 | 
36 |     a_grad = -2*error.mean()
37 |     b_grad = -2*(x_train*error).mean()
38 | 
39 |     a -= lr*a_grad
40 |     b -= lr*b_grad
41 | 
42 | print(f'a={a}, b={b}')
43 | 
44 | # compare results with the linear regression package
45 | from sklearn.linear_model import LinearRegression
46 | lm = LinearRegression()
47 | lm.fit(x_train, y_train) # x must already be a 2D array for LinearRegression
48 | print(lm.intercept_,lm.coef_)
49 | 
50 | # changing the learning rate will change the results slightly
51 | # if the dataset is very large, we need to split it into mini-batches
52 | 
53 | 
54 | 
55 | 
56 | 
57 | 
--------------------------------------------------------------------------------
/101/linear_model_torch.py:
--------------------------------------------------------------------------------
1 | """
2 | This is the final code using PyTorch to fit a linear model. It contains
3 | 1. How to prepare the dataset and how to use a DataLoader
4 | 2. How to write a make_train_step function
5 | 3. How to integrate the model, loss function, and optimizer
6 | """
7 | import numpy as np
8 | import torch
9 | import torch.nn as nn
10 | import torch.optim as optim
11 | from torch.utils.data import Dataset, TensorDataset, DataLoader, random_split
12 | 
13 | """
14 | Generate Data
15 | """
16 | n = 1000
17 | rv = np.random.RandomState(0)
18 | x = rv.uniform(0, 1, [n,1])
19 | y = 1 + 2 * x + 0.1*rv.normal(0, 1, [n,1])
20 | 
21 | """
22 | Fit Model Using PyTorch
23 | """
24 | 
25 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
26 | 
27 | """
28 | Prepare the Data
29 | """
30 | class CustomDataSet(Dataset):
31 |     def __init__(self, x_tensor, y_tensor):
32 |         self.x = x_tensor
33 |         self.y = y_tensor
34 | 
35 |     def __getitem__(self, index):
36 |         return (self.x[index], self.y[index])
37 | 
38 |     def __len__(self):
39 |         return len(self.x)
40 | 
41 | # x_train_tensor = torch.from_numpy(x_train).float()
42 | # y_train_tensor = torch.from_numpy(y_train).float()
43 | # train_data = CustomDataSet(x_train_tensor, y_train_tensor)
44 | # The lines above generate the same results as using TensorDataset;
45 | # however, when the data are complicated or very large, we need to write our own Dataset that loads data from disk.
46 | # print(train_data[0])
47 | # train_data = TensorDataset(x_train_tensor, y_train_tensor)
48 | # print(train_data[0])
49 | 
50 | x_tensor = torch.from_numpy(x).float()
51 | y_tensor = torch.from_numpy(y).float()
52 | 
53 | dataset = TensorDataset(x_tensor, y_tensor)
54 | lengths = [int(len(dataset)*0.8), int(len(dataset)*0.2)]
55 | train_dataset, test_dataset = random_split(dataset, lengths)
56 | 
57 | train_loader = DataLoader(dataset = train_dataset, batch_size = 16, shuffle = True)
58 | test_loader = DataLoader(dataset = test_dataset, batch_size = 20)
59 | 
60 | """
61 | Define Model, Loss Function, Optimizer
62 | """
63 | 
64 | class LayerLinearRegression(nn.Module):
65 |     def __init__(self):
66 |         super().__init__()
67 |         self.linear = nn.Linear(1,1)
68 | 
69 |     def forward(self, x):
70 |         return self.linear(x)
71 | 
72 | loss_fn = nn.MSELoss(reduction='mean')
73 | model = LayerLinearRegression().to(device)
74 | optimizer = optim.SGD(model.parameters(), lr = 1e-1)
75 | 
76 | """
77 | Make A Train Step
78 | """
79 | def make_train_step(model, loss_fn, optimizer):
80 |     def train_step(x, y):
81 |         # set the model to train mode
82 |         model.train()
83 |         yhat = model(x)
84 |         loss = loss_fn(y, yhat)
85 |         loss.backward()
86 |         optimizer.step()
87 |         optimizer.zero_grad()
88 |         return loss.item()
89 |     return train_step
90 | 
91 | """Train The Model"""
92 | losses = []
93 | test_losses = []
94 | n_epochs = 1000
95 | 
96 | train_step = make_train_step(model, loss_fn, optimizer)
97 | 
98 | for epoch in range(n_epochs):
99 |     for x_batch, y_batch in train_loader:
100 |         x_batch = x_batch.to(device)
101 |         y_batch = y_batch.to(device)
102 |         loss = train_step(x_batch,y_batch)
103 |         losses.append(loss)
104 | 
105 |     with torch.no_grad():
106 |         for x_test, y_test in test_loader:
107 |             x_test = x_test.to(device)
108 |             y_test = y_test.to(device)
109 |             model.eval()
110 |             yhat = model(x_test)
111 |             test_loss = loss_fn(y_test, yhat)
112 |             test_losses.append(test_loss.item())
113 | 
114 | print(model.state_dict())
115 | 
116 | 
117 | 
118 | 
119 | 
--------------------------------------------------------------------------------
/201_mnist/MNIST.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | 
"id": "aed4a4d5", 6 | "metadata": {}, 7 | "source": [ 8 | "# MNIST Explore" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 2, 14 | "id": "93828254", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import argparse\n", 19 | "import torch\n", 20 | "import torch.nn as nn\n", 21 | "import torch.nn.functional as F\n", 22 | "import torch.optim as optim\n", 23 | "from torchvision import datasets, transforms\n", 24 | "from torch.optim.lr_scheduler import StepLR" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 3, 30 | "id": "046e50ea", 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "train_kwargs = {'batch_size': 64, 'num_workers': 1, 'pin_memory': True, 'shuffle': True}\n", 35 | "test_kwargs = {'batch_size': 1000, 'num_workers': 1, 'pin_memory': True, 'shuffle': True}" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 4, 41 | "id": "2522ca58", 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "transform=transforms.Compose([\n", 46 | " transforms.ToTensor(),\n", 47 | " transforms.Normalize((0.1307,), (0.3081,))\n", 48 | " ])\n", 49 | "dataset1 = datasets.MNIST('../data', train=True, download=True,\n", 50 | " transform=transform)\n", 51 | "dataset2 = datasets.MNIST('../data', train=False,\n", 52 | " transform=transform)" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 5, 58 | "id": "872a09a9", 59 | "metadata": {}, 60 | "outputs": [ 61 | { 62 | "data": { 63 | "text/plain": [ 64 | "60000" 65 | ] 66 | }, 67 | "execution_count": 5, 68 | "metadata": {}, 69 | "output_type": "execute_result" 70 | } 71 | ], 72 | "source": [ 73 | "len(dataset1)" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": 9, 79 | "id": "3f249e6f", 80 | "metadata": {}, 81 | "outputs": [ 82 | { 83 | "data": { 84 | "text/plain": [ 85 | "torchvision.datasets.mnist.MNIST" 86 | ] 87 | }, 88 | "execution_count": 9, 89 | "metadata": {}, 90 | "output_type": "execute_result" 91 | } 92 | ], 93 | "source": [ 94 | "type(dataset1)" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": 10, 100 | "id": "e5656c1c", 101 | "metadata": {}, 102 | "outputs": [ 103 | { 104 | "data": { 105 | "text/plain": [ 106 | "tuple" 107 | ] 108 | }, 109 | "execution_count": 10, 110 | "metadata": {}, 111 | "output_type": "execute_result" 112 | } 113 | ], 114 | "source": [ 115 | "type(dataset1[0])" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": 24, 121 | "id": "efe1297a", 122 | "metadata": {}, 123 | "outputs": [ 124 | { 125 | "data": { 126 | "text/plain": [ 127 | "(tensor([[[-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,\n", 128 | " -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,\n", 129 | " -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,\n", 130 | " -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],\n", 131 | " [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,\n", 132 | " -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,\n", 133 | " -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,\n", 134 | " -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],\n", 135 | " [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,\n", 136 | " -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,\n", 137 | " -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,\n", 138 | " -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],\n", 139 | " 
[-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,\n", 140 | " -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,\n", 141 | " -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,\n", 142 | " -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],\n", 143 | " [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,\n", 144 | " -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,\n", 145 | " -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,\n", 146 | " -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],\n", 147 | " [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,\n", 148 | " -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.3860, -0.1951,\n", 149 | " -0.1951, -0.1951, 1.1795, 1.3068, 1.8032, -0.0933, 1.6887,\n", 150 | " 2.8215, 2.7197, 1.1923, -0.4242, -0.4242, -0.4242, -0.4242],\n", 151 | " [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,\n", 152 | " -0.4242, -0.0424, 0.0340, 0.7722, 1.5359, 1.7396, 2.7960,\n", 153 | " 2.7960, 2.7960, 2.7960, 2.7960, 2.4396, 1.7650, 2.7960,\n", 154 | " 2.6560, 2.0578, 0.3904, -0.4242, -0.4242, -0.4242, -0.4242],\n", 155 | " [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,\n", 156 | " 0.1995, 2.6051, 2.7960, 2.7960, 2.7960, 2.7960, 2.7960,\n", 157 | " 2.7960, 2.7960, 2.7960, 2.7706, 0.7595, 0.6195, 0.6195,\n", 158 | " 0.2886, 0.0722, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],\n", 159 | " [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,\n", 160 | " -0.1951, 2.3633, 2.7960, 2.7960, 2.7960, 2.7960, 2.7960,\n", 161 | " 2.0960, 1.8923, 2.7197, 2.6433, -0.4242, -0.4242, -0.4242,\n", 162 | " -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],\n", 163 | " [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,\n", 164 | " -0.4242, 0.5940, 1.5614, 0.9377, 2.7960, 2.7960, 2.1851,\n", 165 | " -0.2842, -0.4242, 0.1231, 1.5359, -0.4242, -0.4242, -0.4242,\n", 166 | " -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],\n", 167 | " [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,\n", 168 | " -0.4242, -0.4242, -0.2460, -0.4115, 1.5359, 2.7960, 0.7213,\n", 169 | " -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,\n", 170 | " -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],\n", 171 | " [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,\n", 172 | " -0.4242, -0.4242, -0.4242, -0.4242, 1.3450, 2.7960, 1.9942,\n", 173 | " -0.3988, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,\n", 174 | " -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],\n", 175 | " [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,\n", 176 | " -0.4242, -0.4242, -0.4242, -0.4242, -0.2842, 1.9942, 2.7960,\n", 177 | " 0.4668, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,\n", 178 | " -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],\n", 179 | " [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,\n", 180 | " -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, 0.0213, 2.6433,\n", 181 | " 2.4396, 1.6123, 0.9504, -0.4115, -0.4242, -0.4242, -0.4242,\n", 182 | " -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],\n", 183 | " [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,\n", 184 | " -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, 0.6068,\n", 185 | " 2.6306, 2.7960, 2.7960, 1.0904, -0.1060, -0.4242, -0.4242,\n", 186 | " -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],\n", 187 | " 
[-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,\n", 188 | " -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,\n", 189 | " 0.1486, 1.9432, 2.7960, 2.7960, 1.4850, -0.0806, -0.4242,\n", 190 | " -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],\n", 191 | " [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,\n", 192 | " -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,\n", 193 | " -0.4242, -0.2206, 0.7595, 2.7833, 2.7960, 1.9560, -0.4242,\n", 194 | " -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],\n", 195 | " [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,\n", 196 | " -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,\n", 197 | " -0.4242, -0.4242, -0.4242, 2.7451, 2.7960, 2.7451, 0.3904,\n", 198 | " -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],\n", 199 | " [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,\n", 200 | " -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,\n", 201 | " 0.1613, 1.2305, 1.9051, 2.7960, 2.7960, 2.2105, -0.3988,\n", 202 | " -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],\n", 203 | " [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,\n", 204 | " -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, 0.0722, 1.4596,\n", 205 | " 2.4906, 2.7960, 2.7960, 2.7960, 2.7578, 1.8923, -0.4242,\n", 206 | " -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],\n", 207 | " [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,\n", 208 | " -0.4242, -0.4242, -0.4242, -0.1187, 1.0268, 2.3887, 2.7960,\n", 209 | " 2.7960, 2.7960, 2.7960, 2.1342, 0.5686, -0.4242, -0.4242,\n", 210 | " -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],\n", 211 | " [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,\n", 212 | " -0.4242, -0.1315, 0.4159, 2.2869, 2.7960, 2.7960, 2.7960,\n", 213 | " 2.7960, 2.0960, 0.6068, -0.3988, -0.4242, -0.4242, -0.4242,\n", 214 | " -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],\n", 215 | " [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.1951,\n", 216 | " 1.7523, 2.3633, 2.7960, 2.7960, 2.7960, 2.7960, 2.0578,\n", 217 | " 0.5940, -0.3097, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,\n", 218 | " -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],\n", 219 | " [-0.4242, -0.4242, -0.4242, -0.4242, 0.2758, 1.7650, 2.4524,\n", 220 | " 2.7960, 2.7960, 2.7960, 2.7960, 2.6815, 1.2686, -0.2842,\n", 221 | " -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,\n", 222 | " -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],\n", 223 | " [-0.4242, -0.4242, -0.4242, -0.4242, 1.3068, 2.7960, 2.7960,\n", 224 | " 2.7960, 2.2742, 1.2941, 1.2559, -0.2206, -0.4242, -0.4242,\n", 225 | " -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,\n", 226 | " -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],\n", 227 | " [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,\n", 228 | " -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,\n", 229 | " -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,\n", 230 | " -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],\n", 231 | " [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,\n", 232 | " -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,\n", 233 | " -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,\n", 234 | " -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],\n", 235 | 
" [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,\n", 236 | " -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,\n", 237 | " -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,\n", 238 | " -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242]]]),\n", 239 | " 5)" 240 | ] 241 | }, 242 | "execution_count": 24, 243 | "metadata": {}, 244 | "output_type": "execute_result" 245 | } 246 | ], 247 | "source": [ 248 | "dataset1[0]" 249 | ] 250 | }, 251 | { 252 | "cell_type": "code", 253 | "execution_count": 11, 254 | "id": "079d3795", 255 | "metadata": {}, 256 | "outputs": [ 257 | { 258 | "data": { 259 | "text/plain": [ 260 | "torch.Size([1, 28, 28])" 261 | ] 262 | }, 263 | "execution_count": 11, 264 | "metadata": {}, 265 | "output_type": "execute_result" 266 | } 267 | ], 268 | "source": [ 269 | "dataset1[0][0].size()" 270 | ] 271 | }, 272 | { 273 | "cell_type": "code", 274 | "execution_count": 22, 275 | "id": "3e312249", 276 | "metadata": {}, 277 | "outputs": [ 278 | { 279 | "data": { 280 | "text/plain": [ 281 | "" 282 | ] 283 | }, 284 | "execution_count": 22, 285 | "metadata": {}, 286 | "output_type": "execute_result" 287 | }, 288 | { 289 | "data": { 290 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAOX0lEQVR4nO3dbYxc5XnG8euKbUwxJvHGseMQFxzjFAg0Jl0ZkBFQoVCCIgGKCLGiiFBapwlOQutKUFoVWtHKrRIiSimSKS6m4iWQgPAHmsSyECRqcFmoAROHN+MS4+0aswIDIfZ6fffDjqsFdp5dZs68eO//T1rNzLnnzLk1cPmcmeeceRwRAjD5faDTDQBoD8IOJEHYgSQIO5AEYQeSmNrOjR3i6XGoZrRzk0Aqv9Fb2ht7PFatqbDbPkfS9ZKmSPrXiFhVev6hmqGTfVYzmwRQsDE21K01fBhve4qkGyV9TtLxkpbZPr7R1wPQWs18Zl8i6fmI2BoReyXdJem8atoCULVmwn6kpF+Nery9tuwdbC+33We7b0h7mtgcgGY0E/axvgR4z7m3EbE6InojoneapjexOQDNaCbs2yXNH/X445J2NNcOgFZpJuyPSlpke4HtQyR9SdK6atoCULWGh94iYp/tFZJ+rJGhtzUR8XRlnQGoVFPj7BHxgKQHKuoFQAtxuiyQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJNDWLK7qfp5b/E0/5yOyWbv+ZPz+6bm34sP3FdY9auLNYP+wbLtb/97pD6tYe7/1+cd1dw28V6yffs7JYP+bPHinWO6GpsNveJukNScOS9kVEbxVNAaheFXv234+IXRW8DoAW4jM7kESzYQ9JP7H9mO3lYz3B9nLbfbb7hrSnyc0BaFSzh/FLI2KH7TmS1tv+ZUQ8PPoJEbFa0mpJOsI90eT2ADSoqT17ROyo3e6UdJ+kJVU0BaB6DYfd9gzbMw/cl3S2pM1VNQagWs0cxs+VdJ/tA69zR0T8qJKuJpkpxy0q1mP6tGJ9xxkfKtbfPqX+mHDPB8vjxT/9dHm8uZP+49czi/V/+OdzivWNJ95Rt/bi0NvFdVcNfLZY/9hPD75PpA2HPSK2Svp0hb0AaCGG3oAkCDuQBGEHkiDsQBKEHUiCS1wrMHzmZ4r16269sVj/5LT6l2JOZkMxXKz/9Q1fLdanvlUe/jr1nhV1azNf3ldcd/qu8tDcYX0bi/VuxJ4dSIKwA0kQdiAJwg4kQdiBJAg7kARhB5JgnL0C05/ZUaw/9pv5xfonpw1U2U6lVvafUqxvfbP8U9S3LvxB3drr+8vj5HP/6T+L9VY6+C5gHR97diAJwg4kQdiBJAg7kARhB5Ig7EAShB1IwhHtG1E8wj1xss9q2/a6xeAlpxbru88p/9zzlCcPL9af+MYN77unA67d9bvF+qNnlMfRh197vViPU+v/APG2bxVX1YJlT5SfgPfYGBu0OwbHnMuaPTuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJME4exeYMvvDxfrwq4PF+ot31B8rf/r0NcV1l/z9N4v1OTd27ppyvH9NjbPbXmN7p+3No5b12F5v+7na7awqGwZQvYkcxt8q6d2z3l8paUNELJK0ofYYQBcbN+wR8bCkdx9Hnidpbe3+WknnV9sWgKo1+gXd3Ijol6Ta7Zx6T7S93Haf7b4h7WlwcwCa1fJv4yNidUT0RkTvNE1v9eYA1NFo2Adsz5Ok2u3O6loC0AqNhn2dpItr9y+WdH817QBolXF/N972nZLOlDTb9nZJV0taJelu25dKeknSha1scrIb3vVqU+sP7W58fvdPffkXxforN00pv8D+8hzr6B7jhj0iltUpcXYMcBDhdFkgCcIOJEHYgSQIO5AEYQeSYMrmSeC4K56tW7vkxPKgyb8dtaFYP+PCy4r1md9/pFhH92DPDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJMM4+CZSmTX7168cV131p3dvF+pXX3las/8UXLyjW478/WLc2/+9+XlxXbfyZ8wzYswNJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEkzZnNzgH55arN9+9XeK9QVTD21425+6bUWxvujm/mJ939ZtDW97smpqymYAkwNhB5Ig7EAShB1IgrADSRB2IAnCDiTBODuKYuniYv2IVduL9Ts/8eOGt33sg39UrP/O39S/jl+Shp/b2vC2D1ZNjbPb
XmN7p+3No5ZdY/tl25tqf+dW2TCA6k3kMP5WSeeMsfx7EbG49vdAtW0BqNq4YY+IhyUNtqEXAC3UzBd0K2w/WTvMn1XvSbaX2+6z3TekPU1sDkAzGg37TZIWSlosqV/Sd+s9MSJWR0RvRPRO0/QGNwegWQ2FPSIGImI4IvZLulnSkmrbAlC1hsJue96ohxdI2lzvuQC6w7jj7LbvlHSmpNmSBiRdXXu8WFJI2ibpaxFRvvhYjLNPRlPmzinWd1x0TN3axiuuL677gXH2RV9+8exi/fXTXi3WJ6PSOPu4k0RExLIxFt/SdFcA2orTZYEkCDuQBGEHkiDsQBKEHUiCS1zRMXdvL0/ZfJgPKdZ/HXuL9c9/8/L6r33fxuK6Byt+ShoAYQeyIOxAEoQdSIKwA0kQdiAJwg4kMe5Vb8ht/2mLi/UXLixP2XzC4m11a+ONo4/nhsGTivXD7u9r6vUnG/bsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AE4+yTnHtPKNaf/VZ5rPvmpWuL9dMPLV9T3ow9MVSsPzK4oPwC+8f9dfNU2LMDSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKMsx8Epi44qlh/4ZKP1a1dc9FdxXW/cPiuhnqqwlUDvcX6Q9efUqzPWlv+3Xm807h7dtvzbT9oe4vtp21/u7a8x/Z628/Vbme1vl0AjZrIYfw+SSsj4jhJp0i6zPbxkq6UtCEiFknaUHsMoEuNG/aI6I+Ix2v335C0RdKRks6TdOBcyrWSzm9RjwAq8L6+oLN9tKSTJG2UNDci+qWRfxAkzamzznLbfbb7hrSnyXYBNGrCYbd9uKQfSro8InZPdL2IWB0RvRHRO03TG+kRQAUmFHbb0zQS9Nsj4t7a4gHb82r1eZJ2tqZFAFUYd+jNtiXdImlLRFw3qrRO0sWSVtVu729Jh5PA1KN/u1h//ffmFesX/e2PivU/+dC9xXorrewvD4/9/F/qD6/13PpfxXVn7WdorUoTGWdfKukrkp6yvam27CqNhPxu25dKeknShS3pEEAlxg17RPxM0piTu0s6q9p2ALQKp8sCSRB2IAnCDiRB2IEkCDuQBJe4TtDUeR+tWxtcM6O47tcXPFSsL5s50FBPVVjx8mnF+uM3LS7WZ/9gc7He8wZj5d2CPTuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJJFmnH3vH5R/tnjvnw4W61cd80Dd2tm/9VZDPVVlYPjturXT160srnvsX/2yWO95rTxOvr9YRTdhzw4kQdiBJAg7kARhB5Ig7EAShB1IgrADSaQZZ992fvnftWdPvKdl277xtYXF+vUPnV2se7jej/uOOPbaF+vWFg1sLK47XKxiMmHPDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJOCLKT7DnS7pN0kc1cvny6oi43vY1kv5Y0iu1p14VEfUv+pZ0hHviZDPxK9AqG2ODdsfgmCdmTOSkmn2SVkbE47ZnSnrM9vpa7XsR8Z2qGgXQOhOZn71fUn/t/hu2t0g6stWNAajW+/rMbvtoSSdJOnAO5grbT9peY3tWnXWW2+6z3TekPc11C6BhEw677cMl/VDS5RGxW9JNkhZKWqyRPf93x1ovIlZHRG9E9E7T9OY7BtCQCYXd9jSNBP32iLhXkiJiICKGI2K/pJslLWldmwCaNW7YbVvSLZK2RMR1o5bPG/W0CySVp/ME0FET+TZ+qaSvSHrK9qbasqskLbO9WFJI2ibpay3oD0BFJvJt/M8kjTVuVxxTB9BdOIMOSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQxLg/JV3pxuxXJP3PqEWzJe1qWwPvT7f21q19SfTWqCp7OyoiPjJWoa1hf8/G7b6I6O1YAwXd2lu39iXRW6Pa1RuH8UAShB1IotNhX93h7Zd0a2/d2pdEb41qS28d/cwOoH06vWcH0CaEHUiiI2G3fY7tZ2w/b/vKTvRQj+1ttp+yvcl2X4d7WWN7p+3No5b12F5v+7na7Zhz7HWot2tsv1x77zbZPrdDvc23/aDtLbaftv3t2vKOvneFvtryvrX9M7vtKZKelfRZSdslPSppWUT8oq2N1GF7m6TeiOj4CRi2T5f0pqTbIuKE2rJ/lDQYEatq/1DOiogruqS3ayS92elpvGuzFc0bPc24pPMlfVUdfO8KfX1RbXjfOrFnXyLp+YjYGhF7Jd0l6bwO9NH1IuJhSYPvWnyepLW1+2s18j9L29XprStERH9EPF67/4akA9OMd/S9K/TVFp0I+5GSfjXq8XZ113zvIeknth+zvbzTzYxhbkT0SyP/80ia0+F+3m3cabzb6V3TjHfNe9fI9OfN6kTYx5pKqpvG/5ZGxGckfU7SZbXDVUzMhKbxbpcxphnvCo1Of96sToR9u6T5ox5/XNKODvQxpojYUbvdKek+dd9U1AMHZtCt3e7scD//r5um8R5rmnF1wXvXyenPOxH2RyUtsr3A9iGSviRpXQf6eA/bM2pfnMj2DElnq/umol4n6eLa/Ysl3d/BXt6hW6bxrjfNuDr83nV8+vOIaPufpHM18o38C5L+shM91OnrE5KeqP093eneJN2pkcO6IY0cEV0q6cOSNkh6rnbb00W9/bukpyQ9qZFgzetQb6dp5KPhk5I21f7O7fR7V+irLe8bp8sCSXAGHZAEYQeSIOxAEoQdSIKwA0kQdiAJwg4k8X+zhHFo7nUhhwAAAABJRU5ErkJggg==\n", 291 | "text/plain": [ 292 | "
" 293 | ] 294 | }, 295 | "metadata": { 296 | "needs_background": "light" 297 | }, 298 | "output_type": "display_data" 299 | } 300 | ], 301 | "source": [ 302 | "import matplotlib.pyplot as plt\n", 303 | "plt.imshow(dataset1[0][0].squeeze())" 304 | ] 305 | }, 306 | { 307 | "cell_type": "code", 308 | "execution_count": 23, 309 | "id": "b9e115c8", 310 | "metadata": {}, 311 | "outputs": [ 312 | { 313 | "name": "stdout", 314 | "output_type": "stream", 315 | "text": [ 316 | "5\n" 317 | ] 318 | } 319 | ], 320 | "source": [ 321 | "print(dataset1[0][1])" 322 | ] 323 | }, 324 | { 325 | "cell_type": "code", 326 | "execution_count": 5, 327 | "id": "b762f362", 328 | "metadata": {}, 329 | "outputs": [], 330 | "source": [ 331 | "train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)\n", 332 | "test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)" 333 | ] 334 | }, 335 | { 336 | "cell_type": "code", 337 | "execution_count": 7, 338 | "id": "78d0debe", 339 | "metadata": {}, 340 | "outputs": [], 341 | "source": [ 342 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" 343 | ] 344 | }, 345 | { 346 | "cell_type": "code", 347 | "execution_count": 24, 348 | "id": "06df6aa8", 349 | "metadata": {}, 350 | "outputs": [], 351 | "source": [ 352 | "class Net(nn.Module):\n", 353 | " def __init__(self):\n", 354 | " super(Net, self).__init__()\n", 355 | " self.conv1 = nn.Conv2d(1, 32, 3, 1)\n", 356 | " self.conv2 = nn.Conv2d(32, 64, 3, 1)\n", 357 | " self.dropout1 = nn.Dropout(0.25)\n", 358 | " self.dropout2 = nn.Dropout(0.5)\n", 359 | " self.fc1 = nn.Linear(9216, 128)\n", 360 | " self.fc2 = nn.Linear(128, 10)\n", 361 | "\n", 362 | " def forward(self, x):\n", 363 | " x = self.conv1(x)\n", 364 | " x = F.relu(x)\n", 365 | " x = self.conv2(x)\n", 366 | " x = F.relu(x)\n", 367 | " x = F.max_pool2d(x, 2)\n", 368 | " x = self.dropout1(x)\n", 369 | " x = torch.flatten(x, 1)\n", 370 | " x = self.fc1(x)\n", 371 | " x = F.relu(x)\n", 372 | " x = self.dropout2(x)\n", 373 | " x = self.fc2(x)\n", 374 | " output = F.log_softmax(x, dim=1)\n", 375 | " return output" 376 | ] 377 | }, 378 | { 379 | "cell_type": "code", 380 | "execution_count": 25, 381 | "id": "1ac519fb", 382 | "metadata": {}, 383 | "outputs": [], 384 | "source": [ 385 | "model = Net().to(device)\n", 386 | "optimizer = optim.Adadelta(model.parameters(), lr=1.0)\n", 387 | "scheduler = StepLR(optimizer, step_size=1, gamma=0.7)" 388 | ] 389 | }, 390 | { 391 | "cell_type": "code", 392 | "execution_count": 29, 393 | "id": "e3d4bcd7", 394 | "metadata": {}, 395 | "outputs": [], 396 | "source": [ 397 | "def train(dry_run, model, device, train_loader, optimizer, epoch):\n", 398 | " model.train()\n", 399 | " for batch_idx, (data, target) in enumerate(train_loader):\n", 400 | " data, target = data.to(device), target.to(device)\n", 401 | " optimizer.zero_grad()\n", 402 | " output = model(data)\n", 403 | " loss = F.nll_loss(output, target)\n", 404 | " loss.backward()\n", 405 | " optimizer.step()\n", 406 | " if batch_idx % 50 == 0:\n", 407 | " print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n", 408 | " epoch, batch_idx * len(data), len(train_loader.dataset),\n", 409 | " 100. 
* batch_idx / len(train_loader), loss.item()))\n", 410 | " if dry_run:\n", 411 | " break" 412 | ] 413 | }, 414 | { 415 | "cell_type": "code", 416 | "execution_count": 27, 417 | "id": "2a4f8d25", 418 | "metadata": {}, 419 | "outputs": [], 420 | "source": [ 421 | "def test(model, device, test_loader):\n", 422 | " model.eval()\n", 423 | " test_loss = 0\n", 424 | " correct = 0\n", 425 | " with torch.no_grad():\n", 426 | " for data, target in test_loader:\n", 427 | " data, target = data.to(device), target.to(device)\n", 428 | " output = model(data)\n", 429 | " test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss\n", 430 | " pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability\n", 431 | " correct += pred.eq(target.view_as(pred)).sum().item()\n", 432 | "\n", 433 | " test_loss /= len(test_loader.dataset)\n", 434 | "\n", 435 | " print('\\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n", 436 | " test_loss, correct, len(test_loader.dataset),\n", 437 | " 100. * correct / len(test_loader.dataset)))" 438 | ] 439 | }, 440 | { 441 | "cell_type": "code", 442 | "execution_count": 31, 443 | "id": "7f25600a", 444 | "metadata": {}, 445 | "outputs": [ 446 | { 447 | "name": "stdout", 448 | "output_type": "stream", 449 | "text": [ 450 | "Train Epoch: 1 [0/60000 (0%)]\tLoss: 2.304086\n", 451 | "Train Epoch: 1 [3200/60000 (5%)]\tLoss: 0.483454\n", 452 | "Train Epoch: 1 [6400/60000 (11%)]\tLoss: 0.553788\n", 453 | "Train Epoch: 1 [9600/60000 (16%)]\tLoss: 0.329113\n", 454 | "Train Epoch: 1 [12800/60000 (21%)]\tLoss: 0.429457\n", 455 | "Train Epoch: 1 [16000/60000 (27%)]\tLoss: 0.138165\n", 456 | "Train Epoch: 1 [19200/60000 (32%)]\tLoss: 0.138324\n", 457 | "Train Epoch: 1 [22400/60000 (37%)]\tLoss: 0.106210\n", 458 | "Train Epoch: 1 [25600/60000 (43%)]\tLoss: 0.077803\n", 459 | "Train Epoch: 1 [28800/60000 (48%)]\tLoss: 0.131333\n", 460 | "Train Epoch: 1 [32000/60000 (53%)]\tLoss: 0.026926\n", 461 | "Train Epoch: 1 [35200/60000 (59%)]\tLoss: 0.044464\n", 462 | "Train Epoch: 1 [38400/60000 (64%)]\tLoss: 0.055920\n", 463 | "Train Epoch: 1 [41600/60000 (69%)]\tLoss: 0.134737\n", 464 | "Train Epoch: 1 [44800/60000 (75%)]\tLoss: 0.083895\n", 465 | "Train Epoch: 1 [48000/60000 (80%)]\tLoss: 0.211297\n", 466 | "Train Epoch: 1 [51200/60000 (85%)]\tLoss: 0.044227\n", 467 | "Train Epoch: 1 [54400/60000 (91%)]\tLoss: 0.038829\n", 468 | "Train Epoch: 1 [57600/60000 (96%)]\tLoss: 0.133101\n", 469 | "\n", 470 | "Test set: Average loss: 0.0453, Accuracy: 9845/10000 (98%)\n", 471 | "\n", 472 | "Train Epoch: 2 [0/60000 (0%)]\tLoss: 0.068935\n", 473 | "Train Epoch: 2 [3200/60000 (5%)]\tLoss: 0.025827\n", 474 | "Train Epoch: 2 [6400/60000 (11%)]\tLoss: 0.036748\n", 475 | "Train Epoch: 2 [9600/60000 (16%)]\tLoss: 0.009165\n", 476 | "Train Epoch: 2 [12800/60000 (21%)]\tLoss: 0.034569\n", 477 | "Train Epoch: 2 [16000/60000 (27%)]\tLoss: 0.041076\n", 478 | "Train Epoch: 2 [19200/60000 (32%)]\tLoss: 0.067809\n", 479 | "Train Epoch: 2 [22400/60000 (37%)]\tLoss: 0.006986\n", 480 | "Train Epoch: 2 [25600/60000 (43%)]\tLoss: 0.176369\n", 481 | "Train Epoch: 2 [28800/60000 (48%)]\tLoss: 0.015830\n", 482 | "Train Epoch: 2 [32000/60000 (53%)]\tLoss: 0.024627\n", 483 | "Train Epoch: 2 [35200/60000 (59%)]\tLoss: 0.120578\n", 484 | "Train Epoch: 2 [38400/60000 (64%)]\tLoss: 0.004441\n", 485 | "Train Epoch: 2 [41600/60000 (69%)]\tLoss: 0.068772\n", 486 | "Train Epoch: 2 [44800/60000 (75%)]\tLoss: 0.041876\n", 487 | "Train Epoch: 2 [48000/60000 (80%)]\tLoss: 
0.036133\n", 488 | "Train Epoch: 2 [51200/60000 (85%)]\tLoss: 0.180075\n", 489 | "Train Epoch: 2 [54400/60000 (91%)]\tLoss: 0.023938\n", 490 | "Train Epoch: 2 [57600/60000 (96%)]\tLoss: 0.009721\n", 491 | "\n", 492 | "Test set: Average loss: 0.0363, Accuracy: 9879/10000 (99%)\n", 493 | "\n", 494 | "Train Epoch: 3 [0/60000 (0%)]\tLoss: 0.151827\n", 495 | "Train Epoch: 3 [3200/60000 (5%)]\tLoss: 0.022769\n", 496 | "Train Epoch: 3 [6400/60000 (11%)]\tLoss: 0.009587\n", 497 | "Train Epoch: 3 [9600/60000 (16%)]\tLoss: 0.148401\n", 498 | "Train Epoch: 3 [12800/60000 (21%)]\tLoss: 0.046925\n", 499 | "Train Epoch: 3 [16000/60000 (27%)]\tLoss: 0.204460\n", 500 | "Train Epoch: 3 [19200/60000 (32%)]\tLoss: 0.026152\n", 501 | "Train Epoch: 3 [22400/60000 (37%)]\tLoss: 0.034202\n", 502 | "Train Epoch: 3 [25600/60000 (43%)]\tLoss: 0.006462\n", 503 | "Train Epoch: 3 [28800/60000 (48%)]\tLoss: 0.035811\n", 504 | "Train Epoch: 3 [32000/60000 (53%)]\tLoss: 0.091731\n", 505 | "Train Epoch: 3 [35200/60000 (59%)]\tLoss: 0.065846\n", 506 | "Train Epoch: 3 [38400/60000 (64%)]\tLoss: 0.044159\n", 507 | "Train Epoch: 3 [41600/60000 (69%)]\tLoss: 0.007944\n", 508 | "Train Epoch: 3 [44800/60000 (75%)]\tLoss: 0.031157\n", 509 | "Train Epoch: 3 [48000/60000 (80%)]\tLoss: 0.004689\n", 510 | "Train Epoch: 3 [51200/60000 (85%)]\tLoss: 0.006645\n", 511 | "Train Epoch: 3 [54400/60000 (91%)]\tLoss: 0.012520\n", 512 | "Train Epoch: 3 [57600/60000 (96%)]\tLoss: 0.021701\n", 513 | "\n", 514 | "Test set: Average loss: 0.0358, Accuracy: 9875/10000 (99%)\n", 515 | "\n", 516 | "Train Epoch: 4 [0/60000 (0%)]\tLoss: 0.017820\n", 517 | "Train Epoch: 4 [3200/60000 (5%)]\tLoss: 0.005245\n", 518 | "Train Epoch: 4 [6400/60000 (11%)]\tLoss: 0.011117\n", 519 | "Train Epoch: 4 [9600/60000 (16%)]\tLoss: 0.012314\n", 520 | "Train Epoch: 4 [12800/60000 (21%)]\tLoss: 0.053708\n", 521 | "Train Epoch: 4 [16000/60000 (27%)]\tLoss: 0.019360\n", 522 | "Train Epoch: 4 [19200/60000 (32%)]\tLoss: 0.037697\n", 523 | "Train Epoch: 4 [22400/60000 (37%)]\tLoss: 0.008664\n", 524 | "Train Epoch: 4 [25600/60000 (43%)]\tLoss: 0.028219\n", 525 | "Train Epoch: 4 [28800/60000 (48%)]\tLoss: 0.047608\n", 526 | "Train Epoch: 4 [32000/60000 (53%)]\tLoss: 0.011517\n", 527 | "Train Epoch: 4 [35200/60000 (59%)]\tLoss: 0.163093\n", 528 | "Train Epoch: 4 [38400/60000 (64%)]\tLoss: 0.092249\n", 529 | "Train Epoch: 4 [41600/60000 (69%)]\tLoss: 0.006683\n", 530 | "Train Epoch: 4 [44800/60000 (75%)]\tLoss: 0.005895\n", 531 | "Train Epoch: 4 [48000/60000 (80%)]\tLoss: 0.003299\n", 532 | "Train Epoch: 4 [51200/60000 (85%)]\tLoss: 0.028930\n", 533 | "Train Epoch: 4 [54400/60000 (91%)]\tLoss: 0.005765\n", 534 | "Train Epoch: 4 [57600/60000 (96%)]\tLoss: 0.057321\n", 535 | "\n", 536 | "Test set: Average loss: 0.0291, Accuracy: 9902/10000 (99%)\n", 537 | "\n", 538 | "Train Epoch: 5 [0/60000 (0%)]\tLoss: 0.005983\n", 539 | "Train Epoch: 5 [3200/60000 (5%)]\tLoss: 0.027719\n", 540 | "Train Epoch: 5 [6400/60000 (11%)]\tLoss: 0.007718\n", 541 | "Train Epoch: 5 [9600/60000 (16%)]\tLoss: 0.003807\n", 542 | "Train Epoch: 5 [12800/60000 (21%)]\tLoss: 0.062214\n", 543 | "Train Epoch: 5 [16000/60000 (27%)]\tLoss: 0.005368\n", 544 | "Train Epoch: 5 [19200/60000 (32%)]\tLoss: 0.001086\n", 545 | "Train Epoch: 5 [22400/60000 (37%)]\tLoss: 0.040702\n", 546 | "Train Epoch: 5 [25600/60000 (43%)]\tLoss: 0.058768\n", 547 | "Train Epoch: 5 [28800/60000 (48%)]\tLoss: 0.009829\n", 548 | "Train Epoch: 5 [32000/60000 (53%)]\tLoss: 0.035533\n", 549 | "Train Epoch: 5 [35200/60000 (59%)]\tLoss: 0.018205\n", 
550 | "Train Epoch: 5 [38400/60000 (64%)]\tLoss: 0.007692\n", 551 | "Train Epoch: 5 [41600/60000 (69%)]\tLoss: 0.018107\n", 552 | "Train Epoch: 5 [44800/60000 (75%)]\tLoss: 0.076740\n", 553 | "Train Epoch: 5 [48000/60000 (80%)]\tLoss: 0.140321\n", 554 | "Train Epoch: 5 [51200/60000 (85%)]\tLoss: 0.016024\n", 555 | "Train Epoch: 5 [54400/60000 (91%)]\tLoss: 0.057671\n", 556 | "Train Epoch: 5 [57600/60000 (96%)]\tLoss: 0.027249\n", 557 | "\n", 558 | "Test set: Average loss: 0.0289, Accuracy: 9911/10000 (99%)\n", 559 | "\n", 560 | "Train Epoch: 6 [0/60000 (0%)]\tLoss: 0.031237\n", 561 | "Train Epoch: 6 [3200/60000 (5%)]\tLoss: 0.025852\n", 562 | "Train Epoch: 6 [6400/60000 (11%)]\tLoss: 0.010477\n", 563 | "Train Epoch: 6 [9600/60000 (16%)]\tLoss: 0.026485\n", 564 | "Train Epoch: 6 [12800/60000 (21%)]\tLoss: 0.020630\n", 565 | "Train Epoch: 6 [16000/60000 (27%)]\tLoss: 0.000460\n", 566 | "Train Epoch: 6 [19200/60000 (32%)]\tLoss: 0.084963\n", 567 | "Train Epoch: 6 [22400/60000 (37%)]\tLoss: 0.011054\n", 568 | "Train Epoch: 6 [25600/60000 (43%)]\tLoss: 0.020268\n", 569 | "Train Epoch: 6 [28800/60000 (48%)]\tLoss: 0.012857\n", 570 | "Train Epoch: 6 [32000/60000 (53%)]\tLoss: 0.007072\n", 571 | "Train Epoch: 6 [35200/60000 (59%)]\tLoss: 0.042861\n", 572 | "Train Epoch: 6 [38400/60000 (64%)]\tLoss: 0.002863\n", 573 | "Train Epoch: 6 [41600/60000 (69%)]\tLoss: 0.004465\n", 574 | "Train Epoch: 6 [44800/60000 (75%)]\tLoss: 0.028030\n", 575 | "Train Epoch: 6 [48000/60000 (80%)]\tLoss: 0.008660\n", 576 | "Train Epoch: 6 [51200/60000 (85%)]\tLoss: 0.031442\n", 577 | "Train Epoch: 6 [54400/60000 (91%)]\tLoss: 0.004995\n", 578 | "Train Epoch: 6 [57600/60000 (96%)]\tLoss: 0.001316\n", 579 | "\n", 580 | "Test set: Average loss: 0.0287, Accuracy: 9910/10000 (99%)\n", 581 | "\n", 582 | "Train Epoch: 7 [0/60000 (0%)]\tLoss: 0.116758\n", 583 | "Train Epoch: 7 [3200/60000 (5%)]\tLoss: 0.044722\n", 584 | "Train Epoch: 7 [6400/60000 (11%)]\tLoss: 0.130645\n", 585 | "Train Epoch: 7 [9600/60000 (16%)]\tLoss: 0.081057\n", 586 | "Train Epoch: 7 [12800/60000 (21%)]\tLoss: 0.043891\n", 587 | "Train Epoch: 7 [16000/60000 (27%)]\tLoss: 0.035503\n", 588 | "Train Epoch: 7 [19200/60000 (32%)]\tLoss: 0.016339\n", 589 | "Train Epoch: 7 [22400/60000 (37%)]\tLoss: 0.008756\n", 590 | "Train Epoch: 7 [25600/60000 (43%)]\tLoss: 0.024321\n", 591 | "Train Epoch: 7 [28800/60000 (48%)]\tLoss: 0.017221\n", 592 | "Train Epoch: 7 [32000/60000 (53%)]\tLoss: 0.019142\n", 593 | "Train Epoch: 7 [35200/60000 (59%)]\tLoss: 0.001405\n", 594 | "Train Epoch: 7 [38400/60000 (64%)]\tLoss: 0.034544\n", 595 | "Train Epoch: 7 [41600/60000 (69%)]\tLoss: 0.042996\n", 596 | "Train Epoch: 7 [44800/60000 (75%)]\tLoss: 0.015506\n", 597 | "Train Epoch: 7 [48000/60000 (80%)]\tLoss: 0.040036\n", 598 | "Train Epoch: 7 [51200/60000 (85%)]\tLoss: 0.039684\n", 599 | "Train Epoch: 7 [54400/60000 (91%)]\tLoss: 0.021614\n", 600 | "Train Epoch: 7 [57600/60000 (96%)]\tLoss: 0.022343\n", 601 | "\n", 602 | "Test set: Average loss: 0.0263, Accuracy: 9910/10000 (99%)\n", 603 | "\n", 604 | "Train Epoch: 8 [0/60000 (0%)]\tLoss: 0.047210\n", 605 | "Train Epoch: 8 [3200/60000 (5%)]\tLoss: 0.013396\n", 606 | "Train Epoch: 8 [6400/60000 (11%)]\tLoss: 0.028145\n", 607 | "Train Epoch: 8 [9600/60000 (16%)]\tLoss: 0.064660\n", 608 | "Train Epoch: 8 [12800/60000 (21%)]\tLoss: 0.032066\n", 609 | "Train Epoch: 8 [16000/60000 (27%)]\tLoss: 0.043324\n", 610 | "Train Epoch: 8 [19200/60000 (32%)]\tLoss: 0.002806\n", 611 | "Train Epoch: 8 [22400/60000 (37%)]\tLoss: 0.080810\n", 612 | "Train 
Epoch: 8 [25600/60000 (43%)]\tLoss: 0.074458\n", 613 | "Train Epoch: 8 [28800/60000 (48%)]\tLoss: 0.049314\n", 614 | "Train Epoch: 8 [32000/60000 (53%)]\tLoss: 0.000767\n", 615 | "Train Epoch: 8 [35200/60000 (59%)]\tLoss: 0.002614\n", 616 | "Train Epoch: 8 [38400/60000 (64%)]\tLoss: 0.006574\n", 617 | "Train Epoch: 8 [41600/60000 (69%)]\tLoss: 0.034785\n", 618 | "Train Epoch: 8 [44800/60000 (75%)]\tLoss: 0.015252\n", 619 | "Train Epoch: 8 [48000/60000 (80%)]\tLoss: 0.060464\n", 620 | "Train Epoch: 8 [51200/60000 (85%)]\tLoss: 0.002489\n", 621 | "Train Epoch: 8 [54400/60000 (91%)]\tLoss: 0.013407\n", 622 | "Train Epoch: 8 [57600/60000 (96%)]\tLoss: 0.018576\n", 623 | "\n", 624 | "Test set: Average loss: 0.0277, Accuracy: 9913/10000 (99%)\n", 625 | "\n", 626 | "Train Epoch: 9 [0/60000 (0%)]\tLoss: 0.004137\n", 627 | "Train Epoch: 9 [3200/60000 (5%)]\tLoss: 0.219774\n", 628 | "Train Epoch: 9 [6400/60000 (11%)]\tLoss: 0.011975\n", 629 | "Train Epoch: 9 [9600/60000 (16%)]\tLoss: 0.003798\n" 630 | ] 631 | }, 632 | { 633 | "name": "stdout", 634 | "output_type": "stream", 635 | "text": [ 636 | "Train Epoch: 9 [12800/60000 (21%)]\tLoss: 0.009582\n", 637 | "Train Epoch: 9 [16000/60000 (27%)]\tLoss: 0.051869\n", 638 | "Train Epoch: 9 [19200/60000 (32%)]\tLoss: 0.149949\n", 639 | "Train Epoch: 9 [22400/60000 (37%)]\tLoss: 0.007237\n", 640 | "Train Epoch: 9 [25600/60000 (43%)]\tLoss: 0.011666\n", 641 | "Train Epoch: 9 [28800/60000 (48%)]\tLoss: 0.003852\n", 642 | "Train Epoch: 9 [32000/60000 (53%)]\tLoss: 0.042090\n", 643 | "Train Epoch: 9 [35200/60000 (59%)]\tLoss: 0.018689\n", 644 | "Train Epoch: 9 [38400/60000 (64%)]\tLoss: 0.002339\n", 645 | "Train Epoch: 9 [41600/60000 (69%)]\tLoss: 0.026952\n", 646 | "Train Epoch: 9 [44800/60000 (75%)]\tLoss: 0.004560\n", 647 | "Train Epoch: 9 [48000/60000 (80%)]\tLoss: 0.110791\n", 648 | "Train Epoch: 9 [51200/60000 (85%)]\tLoss: 0.005805\n", 649 | "Train Epoch: 9 [54400/60000 (91%)]\tLoss: 0.012774\n", 650 | "Train Epoch: 9 [57600/60000 (96%)]\tLoss: 0.019638\n", 651 | "\n", 652 | "Test set: Average loss: 0.0272, Accuracy: 9914/10000 (99%)\n", 653 | "\n", 654 | "Train Epoch: 10 [0/60000 (0%)]\tLoss: 0.005552\n", 655 | "Train Epoch: 10 [3200/60000 (5%)]\tLoss: 0.011325\n", 656 | "Train Epoch: 10 [6400/60000 (11%)]\tLoss: 0.106327\n", 657 | "Train Epoch: 10 [9600/60000 (16%)]\tLoss: 0.016725\n", 658 | "Train Epoch: 10 [12800/60000 (21%)]\tLoss: 0.014983\n", 659 | "Train Epoch: 10 [16000/60000 (27%)]\tLoss: 0.008195\n", 660 | "Train Epoch: 10 [19200/60000 (32%)]\tLoss: 0.004805\n", 661 | "Train Epoch: 10 [22400/60000 (37%)]\tLoss: 0.007613\n", 662 | "Train Epoch: 10 [25600/60000 (43%)]\tLoss: 0.005262\n", 663 | "Train Epoch: 10 [28800/60000 (48%)]\tLoss: 0.015021\n", 664 | "Train Epoch: 10 [32000/60000 (53%)]\tLoss: 0.007859\n", 665 | "Train Epoch: 10 [35200/60000 (59%)]\tLoss: 0.006763\n", 666 | "Train Epoch: 10 [38400/60000 (64%)]\tLoss: 0.007447\n", 667 | "Train Epoch: 10 [41600/60000 (69%)]\tLoss: 0.006996\n", 668 | "Train Epoch: 10 [44800/60000 (75%)]\tLoss: 0.009767\n", 669 | "Train Epoch: 10 [48000/60000 (80%)]\tLoss: 0.002931\n", 670 | "Train Epoch: 10 [51200/60000 (85%)]\tLoss: 0.006214\n", 671 | "Train Epoch: 10 [54400/60000 (91%)]\tLoss: 0.040586\n", 672 | "Train Epoch: 10 [57600/60000 (96%)]\tLoss: 0.010895\n", 673 | "\n", 674 | "Test set: Average loss: 0.0272, Accuracy: 9912/10000 (99%)\n", 675 | "\n", 676 | "Train Epoch: 11 [0/60000 (0%)]\tLoss: 0.020540\n", 677 | "Train Epoch: 11 [3200/60000 (5%)]\tLoss: 0.010748\n", 678 | "Train Epoch: 11 
[6400/60000 (11%)]\tLoss: 0.033409\n", 679 | "Train Epoch: 11 [9600/60000 (16%)]\tLoss: 0.032415\n", 680 | "Train Epoch: 11 [12800/60000 (21%)]\tLoss: 0.024565\n", 681 | "Train Epoch: 11 [16000/60000 (27%)]\tLoss: 0.000984\n", 682 | "Train Epoch: 11 [19200/60000 (32%)]\tLoss: 0.004040\n", 683 | "Train Epoch: 11 [22400/60000 (37%)]\tLoss: 0.061551\n", 684 | "Train Epoch: 11 [25600/60000 (43%)]\tLoss: 0.041262\n", 685 | "Train Epoch: 11 [28800/60000 (48%)]\tLoss: 0.030136\n", 686 | "Train Epoch: 11 [32000/60000 (53%)]\tLoss: 0.002762\n", 687 | "Train Epoch: 11 [35200/60000 (59%)]\tLoss: 0.010743\n", 688 | "Train Epoch: 11 [38400/60000 (64%)]\tLoss: 0.005233\n", 689 | "Train Epoch: 11 [41600/60000 (69%)]\tLoss: 0.117578\n", 690 | "Train Epoch: 11 [44800/60000 (75%)]\tLoss: 0.009364\n", 691 | "Train Epoch: 11 [48000/60000 (80%)]\tLoss: 0.014623\n", 692 | "Train Epoch: 11 [51200/60000 (85%)]\tLoss: 0.003061\n", 693 | "Train Epoch: 11 [54400/60000 (91%)]\tLoss: 0.006869\n", 694 | "Train Epoch: 11 [57600/60000 (96%)]\tLoss: 0.003255\n", 695 | "\n", 696 | "Test set: Average loss: 0.0263, Accuracy: 9912/10000 (99%)\n", 697 | "\n", 698 | "Train Epoch: 12 [0/60000 (0%)]\tLoss: 0.030275\n", 699 | "Train Epoch: 12 [3200/60000 (5%)]\tLoss: 0.004557\n", 700 | "Train Epoch: 12 [6400/60000 (11%)]\tLoss: 0.004224\n", 701 | "Train Epoch: 12 [9600/60000 (16%)]\tLoss: 0.003952\n", 702 | "Train Epoch: 12 [12800/60000 (21%)]\tLoss: 0.012795\n", 703 | "Train Epoch: 12 [16000/60000 (27%)]\tLoss: 0.052782\n", 704 | "Train Epoch: 12 [19200/60000 (32%)]\tLoss: 0.044178\n", 705 | "Train Epoch: 12 [22400/60000 (37%)]\tLoss: 0.006386\n", 706 | "Train Epoch: 12 [25600/60000 (43%)]\tLoss: 0.005551\n", 707 | "Train Epoch: 12 [28800/60000 (48%)]\tLoss: 0.000987\n", 708 | "Train Epoch: 12 [32000/60000 (53%)]\tLoss: 0.017901\n", 709 | "Train Epoch: 12 [35200/60000 (59%)]\tLoss: 0.053025\n", 710 | "Train Epoch: 12 [38400/60000 (64%)]\tLoss: 0.003934\n", 711 | "Train Epoch: 12 [41600/60000 (69%)]\tLoss: 0.176611\n", 712 | "Train Epoch: 12 [44800/60000 (75%)]\tLoss: 0.006120\n", 713 | "Train Epoch: 12 [48000/60000 (80%)]\tLoss: 0.025429\n", 714 | "Train Epoch: 12 [51200/60000 (85%)]\tLoss: 0.000851\n", 715 | "Train Epoch: 12 [54400/60000 (91%)]\tLoss: 0.009207\n", 716 | "Train Epoch: 12 [57600/60000 (96%)]\tLoss: 0.027490\n", 717 | "\n", 718 | "Test set: Average loss: 0.0267, Accuracy: 9913/10000 (99%)\n", 719 | "\n", 720 | "Train Epoch: 13 [0/60000 (0%)]\tLoss: 0.011479\n", 721 | "Train Epoch: 13 [3200/60000 (5%)]\tLoss: 0.020575\n", 722 | "Train Epoch: 13 [6400/60000 (11%)]\tLoss: 0.003176\n", 723 | "Train Epoch: 13 [9600/60000 (16%)]\tLoss: 0.013525\n", 724 | "Train Epoch: 13 [12800/60000 (21%)]\tLoss: 0.024242\n", 725 | "Train Epoch: 13 [16000/60000 (27%)]\tLoss: 0.002399\n", 726 | "Train Epoch: 13 [19200/60000 (32%)]\tLoss: 0.000668\n", 727 | "Train Epoch: 13 [22400/60000 (37%)]\tLoss: 0.025648\n", 728 | "Train Epoch: 13 [25600/60000 (43%)]\tLoss: 0.011894\n", 729 | "Train Epoch: 13 [28800/60000 (48%)]\tLoss: 0.037351\n", 730 | "Train Epoch: 13 [32000/60000 (53%)]\tLoss: 0.000562\n", 731 | "Train Epoch: 13 [35200/60000 (59%)]\tLoss: 0.012582\n", 732 | "Train Epoch: 13 [38400/60000 (64%)]\tLoss: 0.013998\n", 733 | "Train Epoch: 13 [41600/60000 (69%)]\tLoss: 0.001727\n", 734 | "Train Epoch: 13 [44800/60000 (75%)]\tLoss: 0.033062\n", 735 | "Train Epoch: 13 [48000/60000 (80%)]\tLoss: 0.004695\n", 736 | "Train Epoch: 13 [51200/60000 (85%)]\tLoss: 0.014415\n", 737 | "Train Epoch: 13 [54400/60000 (91%)]\tLoss: 0.001962\n", 738 | 
"Train Epoch: 13 [57600/60000 (96%)]\tLoss: 0.004545\n", 739 | "\n", 740 | "Test set: Average loss: 0.0266, Accuracy: 9912/10000 (99%)\n", 741 | "\n", 742 | "Train Epoch: 14 [0/60000 (0%)]\tLoss: 0.049028\n", 743 | "Train Epoch: 14 [3200/60000 (5%)]\tLoss: 0.022294\n", 744 | "Train Epoch: 14 [6400/60000 (11%)]\tLoss: 0.125557\n", 745 | "Train Epoch: 14 [9600/60000 (16%)]\tLoss: 0.001006\n", 746 | "Train Epoch: 14 [12800/60000 (21%)]\tLoss: 0.017536\n", 747 | "Train Epoch: 14 [16000/60000 (27%)]\tLoss: 0.008574\n", 748 | "Train Epoch: 14 [19200/60000 (32%)]\tLoss: 0.003895\n", 749 | "Train Epoch: 14 [22400/60000 (37%)]\tLoss: 0.008526\n", 750 | "Train Epoch: 14 [25600/60000 (43%)]\tLoss: 0.006649\n", 751 | "Train Epoch: 14 [28800/60000 (48%)]\tLoss: 0.054258\n", 752 | "Train Epoch: 14 [32000/60000 (53%)]\tLoss: 0.010476\n", 753 | "Train Epoch: 14 [35200/60000 (59%)]\tLoss: 0.000951\n", 754 | "Train Epoch: 14 [38400/60000 (64%)]\tLoss: 0.011821\n", 755 | "Train Epoch: 14 [41600/60000 (69%)]\tLoss: 0.002824\n", 756 | "Train Epoch: 14 [44800/60000 (75%)]\tLoss: 0.055428\n", 757 | "Train Epoch: 14 [48000/60000 (80%)]\tLoss: 0.062347\n", 758 | "Train Epoch: 14 [51200/60000 (85%)]\tLoss: 0.003034\n", 759 | "Train Epoch: 14 [54400/60000 (91%)]\tLoss: 0.002490\n", 760 | "Train Epoch: 14 [57600/60000 (96%)]\tLoss: 0.008748\n", 761 | "\n", 762 | "Test set: Average loss: 0.0264, Accuracy: 9917/10000 (99%)\n", 763 | "\n" 764 | ] 765 | } 766 | ], 767 | "source": [ 768 | "dry_run = False\n", 769 | "n_epochs = 14\n", 770 | "for epoch in range(1, n_epochs + 1):\n", 771 | " train(dry_run, model, device, train_loader, optimizer, epoch)\n", 772 | " test(model, device, test_loader)\n", 773 | " scheduler.step()" 774 | ] 775 | }, 776 | { 777 | "cell_type": "code", 778 | "execution_count": 32, 779 | "id": "bb7ed1b9", 780 | "metadata": {}, 781 | "outputs": [], 782 | "source": [ 783 | "save_model = False\n", 784 | "if save_model:\n", 785 | " torch.save(model.state_dict(), \"mnist_cnn.pt\")" 786 | ] 787 | } 788 | ], 789 | "metadata": { 790 | "kernelspec": { 791 | "display_name": "Python 3 (ipykernel)", 792 | "language": "python", 793 | "name": "python3" 794 | }, 795 | "language_info": { 796 | "codemirror_mode": { 797 | "name": "ipython", 798 | "version": 3 799 | }, 800 | "file_extension": ".py", 801 | "mimetype": "text/x-python", 802 | "name": "python", 803 | "nbconvert_exporter": "python", 804 | "pygments_lexer": "ipython3", 805 | "version": "3.10.4" 806 | } 807 | }, 808 | "nbformat": 4, 809 | "nbformat_minor": 5 810 | } 811 | -------------------------------------------------------------------------------- /201_mnist/README.md: -------------------------------------------------------------------------------- 1 | # Basic MNIST Example 2 | 3 | ```bash 4 | pip install -r requirements.txt 5 | python tempmain.py 6 | # CUDA_VISIBLE_DEVICES=2 python tempmain.py # to specify GPU id to ex. 
2 7 | ``` 8 | -------------------------------------------------------------------------------- /201_mnist/main.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.optim as optim 7 | from torchvision import datasets, transforms 8 | from torch.optim.lr_scheduler import StepLR 9 | 10 | 11 | class Net(nn.Module): 12 | def __init__(self): 13 | super(Net, self).__init__() 14 | self.conv1 = nn.Conv2d(1, 32, 3, 1) 15 | self.conv2 = nn.Conv2d(32, 64, 3, 1) 16 | self.dropout1 = nn.Dropout(0.25) 17 | self.dropout2 = nn.Dropout(0.5) 18 | self.fc1 = nn.Linear(9216, 128) 19 | self.fc2 = nn.Linear(128, 10) 20 | 21 | def forward(self, x): 22 | x = self.conv1(x) 23 | x = F.relu(x) 24 | x = self.conv2(x) 25 | x = F.relu(x) 26 | x = F.max_pool2d(x, 2) 27 | x = self.dropout1(x) 28 | x = torch.flatten(x, 1) 29 | x = self.fc1(x) 30 | x = F.relu(x) 31 | x = self.dropout2(x) 32 | x = self.fc2(x) 33 | output = F.log_softmax(x, dim=1) 34 | return output 35 | 36 | 37 | def train(args, model, device, train_loader, optimizer, epoch): 38 | model.train() 39 | for batch_idx, (data, target) in enumerate(train_loader): 40 | data, target = data.to(device), target.to(device) 41 | optimizer.zero_grad() 42 | output = model(data) 43 | loss = F.nll_loss(output, target) 44 | loss.backward() 45 | optimizer.step() 46 | if batch_idx % args.log_interval == 0: 47 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 48 | epoch, batch_idx * len(data), len(train_loader.dataset), 49 | 100. * batch_idx / len(train_loader), loss.item())) 50 | if args.dry_run: 51 | break 52 | 53 | 54 | def test(model, device, test_loader): 55 | model.eval() 56 | test_loss = 0 57 | correct = 0 58 | with torch.no_grad(): 59 | for data, target in test_loader: 60 | data, target = data.to(device), target.to(device) 61 | output = model(data) 62 | test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss 63 | pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability 64 | correct += pred.eq(target.view_as(pred)).sum().item() 65 | 66 | test_loss /= len(test_loader.dataset) 67 | 68 | print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( 69 | test_loss, correct, len(test_loader.dataset), 70 | 100. 
* correct / len(test_loader.dataset))) 71 | 72 | 73 | def main(): 74 | # Training settings 75 | parser = argparse.ArgumentParser(description='PyTorch MNIST Example') 76 | parser.add_argument('--batch-size', type=int, default=64, metavar='N', 77 | help='input batch size for training (default: 64)') 78 | parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', 79 | help='input batch size for testing (default: 1000)') 80 | parser.add_argument('--epochs', type=int, default=14, metavar='N', 81 | help='number of epochs to train (default: 14)') 82 | parser.add_argument('--lr', type=float, default=1.0, metavar='LR', 83 | help='learning rate (default: 1.0)') 84 | parser.add_argument('--gamma', type=float, default=0.7, metavar='M', 85 | help='Learning rate step gamma (default: 0.7)') 86 | parser.add_argument('--no-cuda', action='store_true', default=False, 87 | help='disables CUDA training') 88 | parser.add_argument('--dry-run', action='store_true', default=False, 89 | help='quickly check a single pass') 90 | parser.add_argument('--seed', type=int, default=1, metavar='S', 91 | help='random seed (default: 1)') 92 | parser.add_argument('--log-interval', type=int, default=10, metavar='N', 93 | help='how many batches to wait before logging training status') 94 | parser.add_argument('--save-model', action='store_true', default=False, 95 | help='For Saving the current Model') 96 | args = parser.parse_args() 97 | use_cuda = not args.no_cuda and torch.cuda.is_available() 98 | 99 | torch.manual_seed(args.seed) 100 | 101 | device = torch.device("cuda" if use_cuda else "cpu") 102 | 103 | train_kwargs = {'batch_size': args.batch_size} 104 | test_kwargs = {'batch_size': args.test_batch_size} 105 | if use_cuda: 106 | cuda_kwargs = {'num_workers': 1, 107 | 'pin_memory': True, 108 | 'shuffle': True} 109 | train_kwargs.update(cuda_kwargs) 110 | test_kwargs.update(cuda_kwargs) 111 | 112 | transform=transforms.Compose([ 113 | transforms.ToTensor(), 114 | transforms.Normalize((0.1307,), (0.3081,)) 115 | ]) 116 | dataset1 = datasets.MNIST('../data', train=True, download=True, 117 | transform=transform) 118 | dataset2 = datasets.MNIST('../data', train=False, 119 | transform=transform) 120 | train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs) 121 | test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs) 122 | 123 | model = Net().to(device) 124 | optimizer = optim.Adadelta(model.parameters(), lr=args.lr) 125 | 126 | scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma) 127 | for epoch in range(1, args.epochs + 1): 128 | train(args, model, device, train_loader, optimizer, epoch) 129 | test(model, device, test_loader) 130 | scheduler.step() 131 | 132 | if args.save_model: 133 | torch.save(model.state_dict(), "mnist_cnn.pt") 134 | 135 | 136 | if __name__ == '__main__': 137 | main() 138 | -------------------------------------------------------------------------------- /201_mnist/requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | -------------------------------------------------------------------------------- /301/model_trainer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import datetime 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | import random 7 | import matplotlib.pyplot as plt 8 | from copy import deepcopy 9 | from torch.utils.tensorboard import SummaryWriter 10 | from torchvision.transforms 
import Normalize 11 | from torch.optim.lr_scheduler import LambdaLR 12 | 13 | plt.style.use('fivethirtyeight') 14 | 15 | def make_lr_fn(start_lr, end_lr, num_iter, step_mode='exp'): 16 | if step_mode == 'linear': 17 | factor = (end_lr / start_lr - 1) / num_iter 18 | def lr_fn(iteration): 19 | return 1 + iteration * factor 20 | else: 21 | factor = (np.log(end_lr) - np.log(start_lr)) / num_iter 22 | def lr_fn(iteration): 23 | return np.exp(factor)**iteration 24 | return lr_fn 25 | 26 | class ModelTrainer(object): 27 | def __init__(self, model, loss_fn, optimizer): 28 | # Here we define the attributes of our class 29 | 30 | # We start by storing the arguments as attributes 31 | # to use them later 32 | self.model = model 33 | self.loss_fn = loss_fn 34 | self.optimizer = optimizer 35 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu' 36 | # Let's send the model to the specified device right away 37 | self.model.to(self.device) 38 | 39 | # These attributes are defined here, but since they are 40 | # not informed at the moment of creation, we keep them None 41 | self.train_loader = None 42 | self.val_loader = None 43 | self.writer = None 44 | self.scheduler = None 45 | self.is_batch_lr_scheduler = False 46 | self.clipping = None 47 | 48 | # These attributes are going to be computed internally 49 | self.losses = [] 50 | self.val_losses = [] 51 | self.learning_rates = [] 52 | self.total_epochs = 0 53 | 54 | self.visualization = {} 55 | self.handles = {} 56 | 57 | # Creates the train_step function for our model, 58 | # loss function and optimizer 59 | # Note: there are NO ARGS there! It makes use of the class 60 | # attributes directly 61 | self.train_step_fn = self._make_train_step_fn() 62 | # Creates the val_step function for our model and loss 63 | self.val_step_fn = self._make_val_step_fn() 64 | 65 | def to(self, device): 66 | # This method allows the user to specify a different device 67 | # It sets the corresponding attribute (to be used later in 68 | # the mini-batches) and sends the model to the device 69 | try: 70 | self.device = device 71 | self.model.to(self.device) 72 | except RuntimeError: 73 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu' 74 | print(f"Couldn't send it to {device}, sending it to {self.device} instead.") 75 | self.model.to(self.device) 76 | 77 | def set_loaders(self, train_loader, val_loader=None): 78 | # This method allows the user to define which train_loader (and val_loader, optionally) to use 79 | # Both loaders are then assigned to attributes of the class 80 | # So they can be referred to later 81 | self.train_loader = train_loader 82 | self.val_loader = val_loader 83 | 84 | def set_tensorboard(self, name, folder='runs'): 85 | # This method allows the user to define a SummaryWriter to interface with TensorBoard 86 | suffix = datetime.datetime.now().strftime('%Y%m%d%H%M%S') 87 | self.writer = SummaryWriter(f'{folder}/{name}_{suffix}') 88 | 89 | def _make_train_step_fn(self): 90 | # This method does not need ARGS... 
it can refer to 91 | # the attributes: self.model, self.loss_fn and self.optimizer 92 | 93 | # Builds function that performs a step in the train loop 94 | def perform_train_step_fn(x, y): 95 | # Sets model to TRAIN mode 96 | self.model.train() 97 | 98 | # Step 1 - Computes our model's predicted output - forward pass 99 | yhat = self.model(x) 100 | # Step 2 - Computes the loss 101 | loss = self.loss_fn(yhat, y) 102 | # Step 3 - Computes gradients 103 | loss.backward() 104 | 105 | if callable(self.clipping): 106 | self.clipping() 107 | 108 | # Step 4 - Updates parameters using gradients and the learning rate 109 | self.optimizer.step() 110 | self.optimizer.zero_grad() 111 | 112 | # Returns the loss 113 | return loss.item() 114 | 115 | # Returns the function that will be called inside the train loop 116 | return perform_train_step_fn 117 | 118 | def _make_val_step_fn(self): 119 | # Builds function that performs a step in the validation loop 120 | def perform_val_step_fn(x, y): 121 | # Sets model to EVAL mode 122 | self.model.eval() 123 | 124 | # Step 1 - Computes our model's predicted output - forward pass 125 | yhat = self.model(x) 126 | # Step 2 - Computes the loss 127 | loss = self.loss_fn(yhat, y) 128 | # There is no need to compute Steps 3 and 4, since we don't update parameters during evaluation 129 | return loss.item() 130 | 131 | return perform_val_step_fn 132 | 133 | def _mini_batch(self, validation=False): 134 | # The mini-batch can be used with both loaders 135 | # The argument `validation`defines which loader and 136 | # corresponding step function is going to be used 137 | if validation: 138 | data_loader = self.val_loader 139 | step_fn = self.val_step_fn 140 | else: 141 | data_loader = self.train_loader 142 | step_fn = self.train_step_fn 143 | 144 | if data_loader is None: 145 | return None 146 | 147 | n_batches = len(data_loader) 148 | # Once the data loader and step function, this is the same 149 | # mini-batch loop we had before 150 | mini_batch_losses = [] 151 | for i, (x_batch, y_batch) in enumerate(data_loader): 152 | x_batch = x_batch.to(self.device) 153 | y_batch = y_batch.to(self.device) 154 | 155 | mini_batch_loss = step_fn(x_batch, y_batch) 156 | mini_batch_losses.append(mini_batch_loss) 157 | 158 | if not validation: 159 | self._mini_batch_schedulers(i / n_batches) 160 | 161 | loss = np.mean(mini_batch_losses) 162 | return loss 163 | 164 | def set_seed(self, seed=123): 165 | torch.backends.cudnn.deterministic = True 166 | torch.backends.cudnn.benchmark = False 167 | torch.manual_seed(seed) 168 | np.random.seed(seed) 169 | random.seed(seed) 170 | try: 171 | self.train_loader.sampler.generator.manual_seed(seed) 172 | except AttributeError: 173 | pass 174 | 175 | def train(self, n_epochs, seed=123): 176 | # To ensure reproducibility of the training process 177 | self.set_seed(seed) 178 | 179 | for epoch in range(n_epochs): 180 | # Keeps track of the numbers of epochs 181 | # by updating the corresponding attribute 182 | self.total_epochs += 1 183 | 184 | # inner loop 185 | # Performs training using mini-batches 186 | loss = self._mini_batch(validation=False) 187 | self.losses.append(loss) 188 | 189 | # VALIDATION 190 | # no gradients in validation! 191 | with torch.no_grad(): 192 | # Performs evaluation using mini-batches 193 | val_loss = self._mini_batch(validation=True) 194 | self.val_losses.append(val_loss) 195 | 196 | self._epoch_schedulers(val_loss) 197 | 198 | # If a SummaryWriter has been set... 
199 | if self.writer: 200 | scalars = {'training': loss} 201 | if val_loss is not None: 202 | scalars.update({'validation': val_loss}) 203 | # Records both losses for each epoch under the main tag "loss" 204 | self.writer.add_scalars(main_tag='loss', 205 | tag_scalar_dict=scalars, 206 | global_step=epoch) 207 | 208 | if self.writer: 209 | # Closes the writer 210 | self.writer.close() 211 | 212 | def save_checkpoint(self, filename): 213 | # Builds dictionary with all elements for resuming training 214 | checkpoint = {'epoch': self.total_epochs, 215 | 'model_state_dict': self.model.state_dict(), 216 | 'optimizer_state_dict': self.optimizer.state_dict(), 217 | 'loss': self.losses, 218 | 'val_loss': self.val_losses} 219 | 220 | torch.save(checkpoint, filename) 221 | 222 | def load_checkpoint(self, filename): 223 | # Loads dictionary 224 | checkpoint = torch.load(filename) 225 | 226 | # Restore state for model and optimizer 227 | self.model.load_state_dict(checkpoint['model_state_dict']) 228 | self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 229 | 230 | self.total_epochs = checkpoint['epoch'] 231 | self.losses = checkpoint['loss'] 232 | self.val_losses = checkpoint['val_loss'] 233 | 234 | self.model.train() # always use TRAIN for resuming training 235 | 236 | def predict(self, x): 237 | # Set is to evaluation mode for predictions 238 | self.model.eval() 239 | # Takes aNumpy input and make it a float tensor 240 | x_tensor = torch.as_tensor(x).float() 241 | # Send input to device and uses model for prediction 242 | y_hat_tensor = self.model(x_tensor.to(self.device)) 243 | # Set it back to train mode 244 | self.model.train() 245 | # Detaches it, brings it to CPU and back to Numpy 246 | return y_hat_tensor.detach().cpu().numpy() 247 | 248 | def plot_losses(self): 249 | fig = plt.figure(figsize=(10, 4)) 250 | plt.plot(self.losses, label='Training Loss', c='b') 251 | plt.plot(self.val_losses, label='Validation Loss', c='r') 252 | plt.yscale('log') 253 | plt.xlabel('Epochs') 254 | plt.ylabel('Loss') 255 | plt.legend() 256 | plt.tight_layout() 257 | return fig 258 | 259 | def add_graph(self): 260 | # Fetches a single mini-batch so we can use add_graph 261 | if self.train_loader and self.writer: 262 | x_sample, y_sample = next(iter(self.train_loader)) 263 | self.writer.add_graph(self.model, x_sample.to(self.device)) 264 | 265 | def count_parameters(self): 266 | return sum(p.numel() for p in self.model.parameters() if p.requires_grad) 267 | 268 | @staticmethod 269 | def _visualize_tensors(axs, x, y=None, yhat=None, layer_name='', title=None): 270 | # The number of images is the number of subplots in a row 271 | n_images = len(axs) 272 | # Gets max and min values for scaling the grayscale 273 | minv, maxv = np.min(x[:n_images]), np.max(x[:n_images]) 274 | # For each image 275 | for j, image in enumerate(x[:n_images]): 276 | ax = axs[j] 277 | # Sets title, labels, and removes ticks 278 | if title is not None: 279 | ax.set_title('{} #{}'.format(title, j), fontsize=12) 280 | ax.set_ylabel( 281 | '{}\n{}x{}'.format(layer_name, *np.atleast_2d(image).shape), 282 | rotation=0, labelpad=40 283 | ) 284 | xlabel1 = '' if y is None else '\nLabel: {}'.format(y[j]) 285 | xlabel2 = '' if yhat is None else '\nPredicted: {}'.format(yhat[j]) 286 | xlabel = '{}{}'.format(xlabel1, xlabel2) 287 | if len(xlabel): 288 | ax.set_xlabel(xlabel, fontsize=12) 289 | ax.set_xticks([]) 290 | ax.set_yticks([]) 291 | 292 | # Plots weight as an image 293 | ax.imshow( 294 | np.atleast_2d(image.squeeze()), 295 | 
cmap='gray', 296 | vmin=minv, 297 | vmax=maxv 298 | ) 299 | return 300 | 301 | def visualize_filters(self, layer_name, **kwargs): 302 | try: 303 | # Gets the layer object from the model 304 | layer = self.model 305 | for name in layer_name.split('.'): 306 | layer = getattr(layer, name) 307 | # We are only looking at filters for 2D convolutions 308 | if isinstance(layer, nn.Conv2d): 309 | # Takes the weight information 310 | weights = layer.weight.data.cpu().numpy() 311 | # The weights have channels_out (filter), channels_in, H, W shape 312 | n_filters, n_channels, _, _ = weights.shape 313 | 314 | # Builds a figure 315 | size = (2 * n_channels + 2, 2 * n_filters) 316 | fig, axes = plt.subplots(n_filters, n_channels, figsize=size) 317 | axes = np.atleast_2d(axes).reshape(n_filters, n_channels) 318 | # For each channel_out (filter) 319 | for i in range(n_filters): 320 | ModelTrainer._visualize_tensors( 321 | axes[i, :], 322 | weights[i], 323 | layer_name='Filter #{}'.format(i), 324 | title='Channel' if (i == 0) else None 325 | ) 326 | 327 | for ax in axes.flat: 328 | ax.label_outer() 329 | 330 | fig.tight_layout() 331 | return fig 332 | except AttributeError: 333 | return 334 | 335 | def attach_hooks(self, layers_to_hook, hook_fn=None): 336 | # Clear any previous values 337 | self.visualization = {} 338 | # Creates the dictionary to map layer objects to their names 339 | modules = list(self.model.named_modules()) 340 | layer_names = {layer: name for name, layer in modules[1:]} 341 | 342 | if hook_fn is None: 343 | # Hook function to be attached to the forward pass 344 | def hook_fn(layer, inputs, outputs): 345 | # Gets the layer name 346 | name = layer_names[layer] 347 | # Detaches outputs 348 | values = outputs.detach().cpu().numpy() 349 | # Since the hook function may be called multiple times 350 | # for example, if we make predictions for multiple mini-batches 351 | # it concatenates the results 352 | if self.visualization[name] is None: 353 | self.visualization[name] = values 354 | else: 355 | self.visualization[name] = np.concatenate([self.visualization[name], values]) 356 | 357 | for name, layer in modules: 358 | # If the layer is in our list 359 | if name in layers_to_hook: 360 | # Initializes the corresponding key in the dictionary 361 | self.visualization[name] = None 362 | # Register the forward hook and keep the handle in another dict 363 | self.handles[name] = layer.register_forward_hook(hook_fn) 364 | 365 | def remove_hooks(self): 366 | # Loops through all hooks and removes them 367 | for handle in self.handles.values(): 368 | handle.remove() 369 | # Clear the dict, as all hooks have been removed 370 | self.handles = {} 371 | 372 | def visualize_outputs(self, layers, n_images=10, y=None, yhat=None): 373 | layers = list(filter(lambda l: l in self.visualization.keys(), layers)) 374 | shapes = [self.visualization[layer].shape for layer in layers] 375 | n_rows = [shape[1] if len(shape) == 4 else 1 for shape in shapes] 376 | total_rows = np.sum(n_rows) 377 | 378 | fig, axes = plt.subplots(total_rows, n_images, figsize=(1.5*n_images, 1.5*total_rows)) 379 | axes = np.atleast_2d(axes).reshape(total_rows, n_images) 380 | 381 | # Loops through the layers, one layer per row of subplots 382 | row = 0 383 | for i, layer in enumerate(layers): 384 | start_row = row 385 | # Takes the produced feature maps for that layer 386 | output = self.visualization[layer] 387 | 388 | is_vector = len(output.shape) == 2 389 | 390 | for j in range(n_rows[i]): 391 | ModelTrainer._visualize_tensors( 392 | 
                    axes[row, :],
393 |                     output if is_vector else output[:, j].squeeze(),
394 |                     y,
395 |                     yhat,
396 |                     layer_name=layers[i] if is_vector else '{}\nfil#{}'.format(layers[i], row-start_row),
397 |                     title='Image' if (row == 0) else None
398 |                 )
399 |                 row += 1
400 | 
401 |         for ax in axes.flat:
402 |             ax.label_outer()
403 | 
404 |         plt.tight_layout()
405 |         return fig
406 | 
407 |     def correct(self, x, y, threshold=.5):
408 |         self.model.eval()
409 |         yhat = self.model(x.to(self.device))
410 |         y = y.to(self.device)
411 |         self.model.train()
412 | 
413 |         # We get the size of the batch and the number of classes
414 |         # (only 1, if it is binary)
415 |         n_samples, n_dims = yhat.shape
416 |         if n_dims > 1:
417 |             # In a multiclass classification, the biggest logit
418 |             # always wins, so we don't bother getting probabilities
419 | 
420 |             # This is PyTorch's version of argmax,
421 |             # but it returns a tuple: (max value, index of max value)
422 |             _, predicted = torch.max(yhat, 1)
423 |         else:
424 |             n_dims += 1
425 |             # In binary classification, we NEED to check if the
426 |             # last layer is a sigmoid (and then it produces probs)
427 |             if isinstance(self.model, nn.Sequential) and \
428 |                isinstance(self.model[-1], nn.Sigmoid):
429 |                 predicted = (yhat > threshold).long()
430 |             # or something else (logits), which we need to convert
431 |             # to probabilities with a sigmoid before thresholding
432 |             else:
433 |                 predicted = (torch.sigmoid(yhat) > threshold).long()
434 | 
435 |         # How many samples got classified correctly for each class
436 |         result = []
437 |         for c in range(n_dims):
438 |             n_class = (y == c).sum().item()
439 |             n_correct = (predicted[y == c] == c).sum().item()
440 |             result.append((n_correct, n_class))
441 |         return torch.tensor(result)
442 | 
443 |     @staticmethod
444 |     def loader_apply(loader, func, reduce='sum'):
445 |         results = [func(x, y) for x, y in loader]
446 |         results = torch.stack(results, axis=0)
447 | 
448 |         if reduce == 'sum':
449 |             results = results.sum(axis=0)
450 |         elif reduce == 'mean':
451 |             results = results.float().mean(axis=0)
452 | 
453 |         return results
454 | 
455 |     @staticmethod
456 |     def statistics_per_channel(images, labels):
457 |         # NCHW
458 |         n_samples, n_channels, n_height, n_width = images.size()
459 |         # Flatten HW into a single dimension
460 |         flatten_per_channel = images.reshape(n_samples, n_channels, -1)
461 | 
462 |         # Computes statistics of each image per channel
463 |         # Average pixel value per channel
464 |         # (n_samples, n_channels)
465 |         means = flatten_per_channel.mean(axis=2)
466 |         # Standard deviation of pixel values per channel
467 |         # (n_samples, n_channels)
468 |         stds = flatten_per_channel.std(axis=2)
469 | 
470 |         # Adds up statistics of all images in a mini-batch
471 |         # (1, n_channels)
472 |         sum_means = means.sum(axis=0)
473 |         sum_stds = stds.sum(axis=0)
474 |         # Makes a tensor of shape (1, n_channels)
475 |         # with the number of samples in the mini-batch
476 |         n_samples = torch.tensor([n_samples]*n_channels).float()
477 | 
478 |         # Stack the three tensors on top of one another
479 |         # (3, n_channels)
480 |         return torch.stack([n_samples, sum_means, sum_stds], axis=0)
481 | 
482 |     @staticmethod
483 |     def make_normalizer(loader):
484 |         total_samples, total_means, total_stds = ModelTrainer.loader_apply(loader, ModelTrainer.statistics_per_channel)
485 |         norm_mean = total_means / total_samples
486 |         norm_std = total_stds / total_samples
487 |         return Normalize(mean=norm_mean, std=norm_std)
488 | 
489 |     def lr_range_test(self, data_loader, end_lr, num_iter=100, step_mode='exp', alpha=0.05, ax=None):
490 |         # Since the test updates both
model and optimizer we need to store 491 | # their initial states to restore them in the end 492 | previous_states = {'model': deepcopy(self.model.state_dict()), 493 | 'optimizer': deepcopy(self.optimizer.state_dict())} 494 | # Retrieves the learning rate set in the optimizer 495 | start_lr = self.optimizer.state_dict()['param_groups'][0]['lr'] 496 | 497 | # Builds a custom function and corresponding scheduler 498 | lr_fn = make_lr_fn(start_lr, end_lr, num_iter) 499 | scheduler = LambdaLR(self.optimizer, lr_lambda=lr_fn) 500 | 501 | # Variables for tracking results and iterations 502 | tracking = {'loss': [], 'lr': []} 503 | iteration = 0 504 | 505 | # If there are more iterations than mini-batches in the data loader, 506 | # it will have to loop over it more than once 507 | while (iteration < num_iter): 508 | # That's the typical mini-batch inner loop 509 | for x_batch, y_batch in data_loader: 510 | x_batch = x_batch.to(self.device) 511 | y_batch = y_batch.to(self.device) 512 | # Step 1 513 | yhat = self.model(x_batch) 514 | # Step 2 515 | loss = self.loss_fn(yhat, y_batch) 516 | # Step 3 517 | loss.backward() 518 | 519 | # Here we keep track of the losses (smoothed) 520 | # and the learning rates 521 | tracking['lr'].append(scheduler.get_last_lr()[0]) 522 | if iteration == 0: 523 | tracking['loss'].append(loss.item()) 524 | else: 525 | prev_loss = tracking['loss'][-1] 526 | smoothed_loss = alpha * loss.item() + (1-alpha) * prev_loss 527 | tracking['loss'].append(smoothed_loss) 528 | 529 | iteration += 1 530 | # Number of iterations reached 531 | if iteration == num_iter: 532 | break 533 | 534 | # Step 4 535 | self.optimizer.step() 536 | scheduler.step() 537 | self.optimizer.zero_grad() 538 | 539 | # Restores the original states 540 | self.optimizer.load_state_dict(previous_states['optimizer']) 541 | self.model.load_state_dict(previous_states['model']) 542 | 543 | if ax is None: 544 | fig, ax = plt.subplots(1, 1, figsize=(6, 4)) 545 | else: 546 | fig = ax.get_figure() 547 | ax.plot(tracking['lr'], tracking['loss']) 548 | if step_mode == 'exp': 549 | ax.set_xscale('log') 550 | ax.set_xlabel('Learning Rate') 551 | ax.set_ylabel('Loss') 552 | fig.tight_layout() 553 | return tracking, fig 554 | 555 | def set_optimizer(self, optimizer): 556 | self.optimizer = optimizer 557 | 558 | def capture_gradients(self, layers_to_hook): 559 | if not isinstance(layers_to_hook, list): 560 | layers_to_hook = [layers_to_hook] 561 | 562 | modules = list(self.model.named_modules()) 563 | self._gradients = {} 564 | 565 | def make_log_fn(name, parm_id): 566 | def log_fn(grad): 567 | self._gradients[name][parm_id].append(grad.tolist()) 568 | return 569 | return log_fn 570 | 571 | for name, layer in self.model.named_modules(): 572 | if name in layers_to_hook: 573 | self._gradients.update({name: {}}) 574 | for parm_id, p in layer.named_parameters(): 575 | if p.requires_grad: 576 | self._gradients[name].update({parm_id: []}) 577 | log_fn = make_log_fn(name, parm_id) 578 | self.handles[f'{name}.{parm_id}.grad'] = p.register_hook(log_fn) 579 | return 580 | 581 | def capture_parameters(self, layers_to_hook): 582 | if not isinstance(layers_to_hook, list): 583 | layers_to_hook = [layers_to_hook] 584 | 585 | modules = list(self.model.named_modules()) 586 | layer_names = {layer: name for name, layer in modules} 587 | 588 | self._parameters = {} 589 | 590 | for name, layer in modules: 591 | if name in layers_to_hook: 592 | self._parameters.update({name: {}}) 593 | for parm_id, p in layer.named_parameters(): 594 | 
self._parameters[name].update({parm_id: []}) 595 | 596 | def fw_hook_fn(layer, inputs, outputs): 597 | name = layer_names[layer] 598 | for parm_id, parameter in layer.named_parameters(): 599 | self._parameters[name][parm_id].append(parameter.tolist()) 600 | 601 | self.attach_hooks(layers_to_hook, fw_hook_fn) 602 | return 603 | 604 | def set_lr_scheduler(self, scheduler): 605 | # Makes sure the scheduler in the argument is assigned to the 606 | # optimizer we're using in this class 607 | if scheduler.optimizer == self.optimizer: 608 | self.scheduler = scheduler 609 | if (isinstance(scheduler, optim.lr_scheduler.CyclicLR) or 610 | isinstance(scheduler, optim.lr_scheduler.OneCycleLR) or 611 | isinstance(scheduler, optim.lr_scheduler.CosineAnnealingWarmRestarts)): 612 | self.is_batch_lr_scheduler = True 613 | else: 614 | self.is_batch_lr_scheduler = False 615 | 616 | def _epoch_schedulers(self, val_loss): 617 | if self.scheduler: 618 | if not self.is_batch_lr_scheduler: 619 | if isinstance(self.scheduler, optim.lr_scheduler.ReduceLROnPlateau): 620 | self.scheduler.step(val_loss) 621 | else: 622 | self.scheduler.step() 623 | 624 | current_lr = list(map(lambda d: d['lr'], self.scheduler.optimizer.state_dict()['param_groups'])) 625 | self.learning_rates.append(current_lr) 626 | 627 | def _mini_batch_schedulers(self, frac_epoch): 628 | if self.scheduler: 629 | if self.is_batch_lr_scheduler: 630 | if isinstance(self.scheduler, optim.lr_scheduler.CosineAnnealingWarmRestarts): 631 | self.scheduler.step(self.total_epochs + frac_epoch) 632 | else: 633 | self.scheduler.step() 634 | 635 | current_lr = list(map(lambda d: d['lr'], self.scheduler.optimizer.state_dict()['param_groups'])) 636 | self.learning_rates.append(current_lr) 637 | 638 | def set_clip_grad_value(self, clip_value): 639 | self.clipping = lambda: nn.utils.clip_grad_value_(self.model.parameters(), clip_value=clip_value) 640 | 641 | def set_clip_grad_norm(self, max_norm, norm_type=2): 642 | self.clipping = lambda: nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=max_norm, norm_type=norm_type) 643 | 644 | def set_clip_backprop(self, clip_value): 645 | if self.clipping is None: 646 | self.clipping = [] 647 | for p in self.model.parameters(): 648 | if p.requires_grad: 649 | func = lambda grad: torch.clamp(grad, -clip_value, clip_value) 650 | handle = p.register_hook(func) 651 | self.clipping.append(handle) 652 | 653 | def remove_clip(self): 654 | if isinstance(self.clipping, list): 655 | for handle in self.clipping: 656 | handle.remove() 657 | self.clipping = None -------------------------------------------------------------------------------- /301/plot_squares.py: -------------------------------------------------------------------------------- 1 | 2 | from matplotlib import pyplot as plt 3 | import numpy as np 4 | 5 | def plot_squares(points, directions, n_rows=2, n_cols=5): 6 | fig, axs = plt.subplots(n_rows, n_cols, figsize=(3*n_cols, 3*n_rows)) 7 | axs = axs.flatten() 8 | 9 | for e, ax in enumerate(axs): 10 | pred_corners = points[e] 11 | clockwise = directions[e] 12 | for i in range(4): 13 | color = 'k' 14 | ax.scatter(*pred_corners.T, c=color, s=400) 15 | if i == 3: 16 | start = -1 17 | else: 18 | start = i 19 | ax.plot(*pred_corners[[start, start+1]].T, c='k', lw=2, alpha=.5, linestyle='-') 20 | ax.text(*(pred_corners[i] - np.array([.04, 0.04])), str(i+1), c='w', fontsize=12) 21 | if directions is not None: 22 | ax.set_title(f'{"Counter-" if not clockwise else ""}Clockwise (y={clockwise})', fontsize=14) 23 | 24 | 
ax.set_xlabel(r"$x_0$") 25 | ax.set_ylabel(r"$x_1$", rotation=0) 26 | ax.set_xlim([-1.5, 1.5]) 27 | ax.set_ylim([-1.5, 1.5]) 28 | 29 | fig.tight_layout() 30 | return fig 31 | 32 | def counter_vs_clock(basic_corners=None, basic_colors=None, basic_letters=None, draw_arrows=True, binary=True): 33 | transparent_alpha = 0.2 34 | if basic_corners is None: 35 | basic_corners = np.array([[-1, -1], [-1, 1], [1, 1], [1, -1]]) 36 | clock_arrows = np.array([[0, -1], [-1, 0], [0, 1], [1, 0]]) 37 | else: 38 | clock_arrows = np.array([[0, basic_corners[0][1]], [basic_corners[1][0], 0], 39 | [0, basic_corners[2][1]], [basic_corners[3][0], 0]]) 40 | 41 | if basic_colors is None: 42 | basic_colors = ['gray', 'g', 'b', 'r'] 43 | if basic_letters is None: 44 | basic_letters = ['A', 'B', 'C', 'D'] 45 | 46 | fig, axs = plt.subplots(1 + draw_arrows, 1, figsize=(3, 3+ 3 * draw_arrows)) 47 | if not draw_arrows: 48 | axs = [axs] 49 | 50 | corners = basic_corners[:] 51 | factor = (corners.max(axis=0) - corners.min(axis=0)).max() / 2 52 | 53 | for is_clock in range(1 + draw_arrows): 54 | if draw_arrows: 55 | if binary: 56 | if is_clock: 57 | axs[is_clock].text(-.5, 0, 'Clockwise') 58 | axs[is_clock].text(-.2, -.25, 'y=1') 59 | else: 60 | axs[is_clock].text(-.5, .0, ' Counter-\nClockwise') 61 | axs[is_clock].text(-.2, -.25, 'y=0') 62 | 63 | for i in range(4): 64 | coords = corners[i] 65 | color = basic_colors[i] 66 | letter = basic_letters[i] 67 | if not binary: 68 | targets = [2, 3] if is_clock else [1, 2] 69 | else: 70 | targets = [] 71 | 72 | alpha = transparent_alpha if i in targets else 1.0 73 | axs[is_clock].scatter(*coords, c=color, s=400, alpha=alpha) 74 | 75 | start = i 76 | if is_clock: 77 | end = i + 1 if i < 3 else 0 78 | arrow_coords = np.stack([corners[start] - clock_arrows[start]*0.15, 79 | corners[end] + clock_arrows[start]*0.15]) 80 | else: 81 | end = i - 1 if i > 0 else -1 82 | arrow_coords = np.stack([corners[start] + clock_arrows[end]*0.15, 83 | corners[end] - clock_arrows[end]*0.15]) 84 | alpha = 1.0 85 | if draw_arrows: 86 | alpha = transparent_alpha if ((start in targets) or (end in targets)) else 1.0 87 | line = axs[is_clock].plot(*arrow_coords.T, c=color, lw=0 if draw_arrows else 2, 88 | alpha=alpha, linestyle='--' if (alpha < 1) and (not draw_arrows) else '-')[0] 89 | if draw_arrows: 90 | add_arrow(line, lw=3, alpha=alpha) 91 | 92 | axs[is_clock].text(*(coords - factor*np.array([.05, 0.05])), letter, c='k' if i in targets else 'w', fontsize=12, alpha=transparent_alpha if i in targets else 1.0) 93 | 94 | axs[is_clock].grid(False) 95 | limits = np.stack([corners.min(axis=0), corners.max(axis=0)]) 96 | limits = limits.mean(axis=0).reshape(2, 1) + 1.2*np.array([[-factor, factor]]) 97 | axs[is_clock].set_xlim(limits[0]) 98 | axs[is_clock].set_ylim(limits[1]) 99 | 100 | axs[is_clock].set_xlabel(r'$x_0$') 101 | axs[is_clock].set_ylabel(r'$x_1$', rotation=0) 102 | 103 | fig.tight_layout() 104 | 105 | return fig 106 | 107 | def sequence_pred(trainer_obj, X, directions=None, n_rows=2, n_cols=5): 108 | fig, axs = plt.subplots(n_rows, n_cols, figsize=(4*n_cols, 4*n_rows)) 109 | axs = axs.flatten() 110 | 111 | for e, ax in enumerate(axs): 112 | first_corners = X[e, :2, :] 113 | trainer_obj.model.eval() 114 | next_corners = trainer_obj.model(X[e:e+1, :2].to(trainer_obj.device)).squeeze().detach().cpu().numpy() 115 | pred_corners = np.concatenate([first_corners, next_corners], axis=0) 116 | 117 | for j, corners in enumerate([X[e], pred_corners]): 118 | for i in range(4): 119 | coords = corners[i] 120 | 
color = 'k' 121 | ax.scatter(*coords, c=color, s=400) 122 | if i == 3: 123 | start = -1 124 | else: 125 | start = i 126 | if (not j) or (j and i): 127 | ax.plot(*corners[[start, start+1]].T, c='k', lw=2, alpha=.5, linestyle='--' if j else '-') 128 | ax.text(*(coords - np.array([.04, 0.04])), str(i+1), c='w', fontsize=12) 129 | if directions is not None: 130 | ax.set_title(f'{"Counter-" if not directions[e] else ""}Clockwise') 131 | 132 | ax.set_xlabel(r"$x_0$") 133 | ax.set_ylabel(r"$x_1$", rotation=0) 134 | ax.set_xlim([-1.7, 1.7]) 135 | ax.set_ylim([-1.7, 1.7]) 136 | 137 | fig.tight_layout() 138 | return fig 139 | 140 | 141 | def add_arrow(line, position=None, direction='right', size=15, color=None, lw=2, alpha=1.0, text=None, text_offset=(0 , 0)): 142 | """ 143 | add an arrow to a line. 144 | 145 | line: Line2D object 146 | position: x-position of the arrow. If None, mean of xdata is taken 147 | direction: 'left' or 'right' 148 | size: size of the arrow in fontsize points 149 | color: if None, line color is taken. 150 | """ 151 | if color is None: 152 | color = line.get_color() 153 | 154 | xdata = line.get_xdata() 155 | ydata = line.get_ydata() 156 | 157 | if position is None: 158 | position = xdata.mean() 159 | # find closest index 160 | start_ind = np.argmin(np.absolute(xdata - position)) 161 | if direction == 'right': 162 | end_ind = start_ind + 1 163 | else: 164 | end_ind = start_ind - 1 165 | 166 | line.axes.annotate('', 167 | xytext=(xdata[start_ind], ydata[start_ind]), 168 | xy=(xdata[end_ind], ydata[end_ind]), 169 | arrowprops=dict(arrowstyle="->", color=color, lw=lw, linestyle='--' if alpha < 1 else '-', alpha=alpha), 170 | size=size, 171 | ) 172 | if text is not None: 173 | line.axes.annotate(text, color=color, 174 | xytext=(xdata[end_ind] + text_offset[0], ydata[end_ind] + text_offset[1]), 175 | xy=(xdata[end_ind], ydata[end_ind]), 176 | size=size, 177 | ) 178 | 179 | 180 | 181 | # plot the clock and counter-clock sequence with the unknowns gray out 182 | counter_vs_clock(binary=False) 183 | -------------------------------------------------------------------------------- /301/seq2seq_simple.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | import torch.nn as nn 4 | from torch.utils.data import DataLoader, Dataset, random_split, TensorDataset 5 | from square_data_generation import generate_sequences 6 | from model_trainer import ModelTrainer 7 | from plot_squares import sequence_pred 8 | 9 | class Encoder(nn.Module): 10 | def __init__(self, n_features, hidden_dim): 11 | super().__init__() 12 | self.hidden_dim = hidden_dim 13 | self.n_features = n_features 14 | self.hidden = None 15 | self.basic_rnn = nn.GRU(self.n_features, self.hidden_dim, batch_first=True) 16 | 17 | def forward(self, X): 18 | rnn_out, self.hidden = self.basic_rnn(X) 19 | return rnn_out # N, L, F 20 | 21 | class Decoder(nn.Module): 22 | def __init__(self, n_features, hidden_dim): 23 | super().__init__() 24 | self.hidden_dim = hidden_dim 25 | self.n_features = n_features 26 | self.hidden = None 27 | self.basic_rnn = nn.GRU(self.n_features, self.hidden_dim, batch_first=True) 28 | self.regression = nn.Linear(self.hidden_dim, self.n_features) 29 | 30 | def init_hidden(self, hidden_seq): 31 | self.hidden = hidden_seq[:, -1:].permute(1, 0, 2) # from N, 1, H to 1,N,H 32 | 33 | def forward(self, X): 34 | batch_first_output, self.hidden = self.basic_rnn(X, self.hidden) 35 | last_output = batch_first_output[:, -1] 36 | return 
self.regression(last_output).view(-1, 1, self.n_features) # N,1,F 37 | 38 | 39 | class EncoderDecoder(nn.Module): 40 | def __init__(self, encoder, decoder, input_len, target_len, teacher_forcing_prob=0.5): 41 | super().__init__() 42 | self.encoder = encoder 43 | self.decoder = decoder 44 | self.input_len = input_len 45 | self.target_len = target_len 46 | self.teacher_forcing_prob = teacher_forcing_prob 47 | self.outputs = None 48 | 49 | def init_outputs(self, batch_size): 50 | device = next(self.parameters()).device 51 | # N, L (target), F 52 | self.outputs = torch.zeros(batch_size, 53 | self.target_len, 54 | self.encoder.n_features).to(device) 55 | 56 | def store_output(self, i, out): 57 | # Stores the output 58 | self.outputs[:, i:i+1, :] = out 59 | 60 | def forward(self, X): 61 | # splits the data in source and target sequences 62 | # the target seq will be empty in testing mode 63 | # N, L, F 64 | source_seq = X[:, :self.input_len, :] 65 | target_seq = X[:, self.input_len:, :] 66 | self.init_outputs(X.shape[0]) 67 | # Encoder expected N, L, F 68 | hidden_seq = self.encoder(source_seq) 69 | # Output is N, L, H 70 | self.decoder.init_hidden(hidden_seq) 71 | # The last input of the encoder is also 72 | # the first input of the decoder 73 | dec_inputs = source_seq[:, -1:, :] 74 | # Generates as many outputs as the target length 75 | for i in range(self.target_len): 76 | # Output of decoder is N, 1, F 77 | out = self.decoder(dec_inputs) 78 | self.store_output(i, out) 79 | prob = self.teacher_forcing_prob 80 | # In evaluation/test the target sequence is 81 | # unknown, so we cannot use teacher forcing 82 | if not self.training: 83 | prob = 0 84 | # If it is teacher forcing 85 | if torch.rand(1) <= prob: 86 | # Takes the actual element 87 | dec_inputs = target_seq[:, i:i+1, :] 88 | else: 89 | # Otherwise uses the last predicted output 90 | dec_inputs = out 91 | return self.outputs 92 | 93 | # Setup the model 94 | torch.manual_seed(123) 95 | encoder = Encoder(n_features=2, hidden_dim=2) 96 | decoder = Decoder(n_features=2, hidden_dim=2) 97 | model = EncoderDecoder(encoder, decoder, input_len=2, target_len=2, teacher_forcing_prob=0.5) 98 | loss = nn.MSELoss() 99 | optimizer = optim.Adam(model.parameters(), lr=0.01) 100 | 101 | # Training Data 102 | points, directions = generate_sequences(n=256) 103 | full_train = torch.as_tensor(points).float() 104 | target_train = full_train[:, 2:] 105 | 106 | # Testing Data 107 | test_points, test_directions = generate_sequences(n=1024) 108 | full_test = torch.as_tensor(test_points).float() 109 | source_test = full_test[:, :2] 110 | target_test = full_test[:, 2:] 111 | 112 | # Datasets and DataLoaders 113 | train_data = TensorDataset(full_train, target_train) 114 | test_data = TensorDataset(source_test, target_test) 115 | 116 | train_loader = DataLoader(train_data, batch_size=16, shuffle=True) 117 | test_loader = DataLoader(test_data, batch_size=16) 118 | 119 | # Train the model 120 | seq2seq_simple = ModelTrainer(model, loss, optimizer) 121 | seq2seq_simple.set_loaders(train_loader, test_loader) 122 | seq2seq_simple.train(100) 123 | 124 | seq2seq_simple.plot_losses() 125 | 126 | sequence_pred(seq2seq_simple, full_test, test_directions) -------------------------------------------------------------------------------- /301/seq2seq_transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | from torch.utils.data import DataLoader, Dataset, random_split, 
TensorDataset 5 | from square_data_generation import generate_sequences 6 | from model_trainer import ModelTrainer 7 | from plot_squares import sequence_pred 8 | 9 | class PositionalEncoding(nn.Module): 10 | def __init__(self, max_len, d_model): 11 | super(PositionalEncoding, self).__init__() 12 | self.d_model = d_model 13 | pe = torch.zeros(max_len, d_model) 14 | position = torch.arange(0, max_len).float().unsqueeze(1) 15 | slope = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)) 16 | pe[:, 0::2] = torch.sin(position * slope) # even dimensions 17 | pe[:, 1::2] = torch.cos(position * slope) # odd dimensions 18 | self.register_buffer('pe', pe.unsqueeze(0)) 19 | 20 | def forward(self, x): 21 | # x is N, L, D 22 | # pe is 1, maxlen, D 23 | scaled_x = x * np.sqrt(self.d_model) 24 | encoded = scaled_x + self.pe[:, :x.size(1), :] 25 | return encoded 26 | 27 | class TransformerModel(nn.Module): 28 | def __init__(self, transformer, input_len, target_len, n_features): 29 | super().__init__() 30 | self.transf = transformer 31 | self.input_len = input_len 32 | self.target_len = target_len 33 | self.trg_masks = self.transf.generate_square_subsequent_mask(self.target_len) 34 | self.n_features = n_features 35 | self.proj = nn.Linear(n_features, self.transf.d_model) 36 | self.linear = nn.Linear(self.transf.d_model, n_features) 37 | 38 | max_len = max(self.input_len, self.target_len) 39 | self.pe = PositionalEncoding(max_len, self.transf.d_model) 40 | self.norm = nn.LayerNorm(self.transf.d_model) 41 | 42 | def preprocess(self, seq): 43 | seq_proj = self.proj(seq) 44 | seq_enc = self.pe(seq_proj) 45 | return self.norm(seq_enc) 46 | 47 | def encode_decode(self, source, target, source_mask=None, target_mask=None): 48 | # Projections 49 | # PyTorch Transformer expects L, N, F 50 | src = self.preprocess(source).permute(1, 0, 2) 51 | tgt = self.preprocess(target).permute(1, 0, 2) 52 | 53 | out = self.transf(src, tgt, 54 | src_key_padding_mask=source_mask, 55 | tgt_mask=target_mask) 56 | 57 | # Linear 58 | # Back to N, L, D 59 | out = out.permute(1, 0, 2) 60 | out = self.linear(out) # N, L, F 61 | return out 62 | 63 | def predict(self, source_seq, source_mask=None): 64 | inputs = source_seq[:, -1:] 65 | for i in range(self.target_len): 66 | out = self.encode_decode(source_seq, inputs, 67 | source_mask=source_mask, 68 | target_mask=self.trg_masks[:i+1, :i+1]) 69 | out = torch.cat([inputs, out[:, -1:, :]], dim=-2) 70 | inputs = out.detach() 71 | outputs = out[:, 1:, :] 72 | return outputs 73 | 74 | def forward(self, X, source_mask=None): 75 | self.trg_masks = self.trg_masks.type_as(X) 76 | source_seq = X[:, :self.input_len, :] 77 | 78 | if self.training: 79 | shifted_target_seq = X[:, self.input_len-1:-1, :] 80 | outputs = self.encode_decode(source_seq, shifted_target_seq, 81 | source_mask=source_mask, 82 | target_mask=self.trg_masks) 83 | else: 84 | outputs = self.predict(source_seq, source_mask) 85 | 86 | return outputs 87 | 88 | torch.manual_seed(123) 89 | transformer = nn.Transformer(d_model=6, nhead=3, 90 | num_encoder_layers=1, num_decoder_layers=1, 91 | dim_feedforward=20, dropout=0.1) 92 | model = TransformerModel(transformer, input_len=2, target_len=2, n_features=2) 93 | loss = nn.MSELoss() 94 | optimizer = torch.optim.Adam(model.parameters(), lr=0.01) 95 | 96 | for p in model.parameters(): 97 | if p.dim() > 1: 98 | nn.init.xavier_uniform_(p) 99 | 100 | # Training Data 101 | points, directions = generate_sequences(n=512) 102 | full_train = torch.as_tensor(points).float() 103 | 
target_train = full_train[:, 2:] 104 | 105 | # Testing Data 106 | test_points, test_directions = generate_sequences(n=1024) 107 | full_test = torch.as_tensor(test_points).float() 108 | source_test = full_test[:, :2] 109 | target_test = full_test[:, 2:] 110 | 111 | # Datasets and DataLoaders 112 | train_data = TensorDataset(full_train, target_train) 113 | test_data = TensorDataset(source_test, target_test) 114 | 115 | train_loader = DataLoader(train_data, batch_size=16, shuffle=True) 116 | test_loader = DataLoader(test_data, batch_size=16) 117 | 118 | # Train the model 119 | seq2seq_transformer = ModelTrainer(model, loss, optimizer) 120 | seq2seq_transformer.set_loaders(train_loader, test_loader) 121 | seq2seq_transformer.train(100) 122 | 123 | seq2seq_transformer.plot_losses() 124 | 125 | sequence_pred(seq2seq_transformer, full_test, test_directions) -------------------------------------------------------------------------------- /301/sequence_classification_model.py: -------------------------------------------------------------------------------- 1 | from square_data_generation import generate_sequences 2 | from model_trainer import ModelTrainer 3 | import torch 4 | import torch.nn as nn 5 | from torch.utils.data import DataLoader, Dataset, random_split, TensorDataset 6 | import torch.optim as optim 7 | import torch.nn.functional as F 8 | 9 | train_points, train_directions = generate_sequences(n=256) 10 | test_points, test_directions = generate_sequences(n=1024) 11 | train_data = TensorDataset(torch.as_tensor(train_points).float(), 12 | torch.as_tensor(train_directions).view(-1, 1).float()) 13 | test_data = TensorDataset(torch.as_tensor(test_points).float(), 14 | torch.as_tensor(test_directions).view(-1, 1).float()) 15 | train_loader = DataLoader(train_data, batch_size=64, shuffle=True) 16 | test_loader = DataLoader(test_data, batch_size=64) 17 | 18 | class SquareModel(nn.Module): 19 | def __init__(self, n_features, hidden_dim, n_outputs): 20 | super(SquareModel, self).__init__() 21 | self.hidden_dim = hidden_dim 22 | self.n_features = n_features 23 | self.n_outputs = n_outputs 24 | self.hidden = None 25 | # Simple RNN 26 | self.basic_rnn = nn.RNN(self.n_features, self.hidden_dim, batch_first=True) 27 | # self.basic_rnn = nn.GRU(self.n_features, self.hidden_dim, batch_first=True) # Simply change the RNN to GRU will make the testing errors go to 0. 
28 |         # Classifier to produce as many logits as outputs
29 |         self.classifier = nn.Linear(self.hidden_dim, self.n_outputs)
30 | 
31 |     def forward(self, X):
32 |         # X is batch first (N, L, F)
33 |         # output is (N, L, H)
34 |         # final hidden state is (1, N, H)
35 |         batch_first_output, self.hidden = self.basic_rnn(X)
36 | 
37 |         # only last item in sequence (N, H)
38 |         last_output = batch_first_output[:, -1]
39 |         # classifier will output (N, n_outputs)
40 |         out = self.classifier(last_output)
41 |         # return the logits directly: nn.BCEWithLogitsLoss (used below)
42 |         # applies the sigmoid internally, so we must NOT apply one here
43 |         return out
44 | 
45 | 
46 | model = SquareModel(n_features=2, hidden_dim=2, n_outputs=1)
47 | loss_fn = nn.BCEWithLogitsLoss()
48 | optimizer = optim.Adam(model.parameters(), lr=0.01)
49 | 
50 | simple_rnn = ModelTrainer(model, loss_fn, optimizer)
51 | simple_rnn.set_loaders(train_loader, test_loader)
52 | simple_rnn.train(100)
53 | simple_rnn.plot_losses()
54 | 
55 | print(ModelTrainer.loader_apply(test_loader, simple_rnn.correct))
56 | # tensor([[509, 512],
57 | #         [488, 512]])
58 | # (509+488)/(512+512) = 97.4% accuracy
--------------------------------------------------------------------------------
/301/square_data_generation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | seed = 0
3 | rv = np.random.RandomState(seed)
4 | 
5 | def generate_sequences(n=128):
6 |     """
7 |     Generates sequences of points forming squares, either clockwise or counter-clockwise.
8 | 
9 |     Each sequence consists of points that represent the corners of a square. The sequence can
10 |     either go around the square in a clockwise or counter-clockwise direction. This is determined
11 |     randomly for each sequence. The sequences are also slightly randomized by adding a small
12 |     amount of noise to each point.
13 | 
14 |     Args:
15 |         n (int): The number of sequences to generate. Default is 128.
16 | 
17 |     Returns:
18 |         tuple: A tuple containing two elements:
19 |             - A list of arrays, where each array represents a sequence of points (corners of a square).
20 |             - An array indicating the direction of each sequence (0 for counter-clockwise, 1 for clockwise).
21 | 
22 |     Example:
23 |         >>> sequences, directions = generate_sequences(n=5)
24 |         >>> print(sequences[0])  # Prints first sequence of points
25 |         >>> print(directions[0])  # Prints direction of the first sequence (0 or 1)
26 |     """
27 |     basic_corners = np.array([[-1, -1], [-1, 1], [1, 1], [1, -1]])
28 |     bases = rv.randint(4, size=n)  # Starting corner indices for each sequence.
29 |     directions = rv.randint(2, size=n)  # Direction (0 for CCW, 1 for CW); use the seeded `rv` for reproducibility.
30 | 
31 |     # Generating the point sequences (noise also drawn from the seeded `rv`).
32 |     points = [basic_corners[[(b + i) % 4 for i in range(4)]][::d*2-1] + rv.randn(4, 2) * 0.1
33 |               for b, d in zip(bases, directions)]
34 | 
35 |     return points, directions
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Learn PyTorch
2 | 
3 | The purpose of this repo is to provide educational materials on using PyTorch for model fitting.
4 | The materials are organized by course number: 101, 201, and so on.
5 | 
6 | Here are brief summaries:
7 | - 101 We start by fitting a linear model with NumPy, then use PyTorch utilities to simplify the work
8 | - 201 The MNIST example: building a convolutional neural network for handwritten digit (zip code) recognition
9 | - 301 Sequence classification and seq2seq models
10 | - 401 An SDE diffusion model for generative AI
11 | 
12 | 
13 | # Setting Up a Python Virtual Environment for Deep Learning
14 | 
15 | This guide walks you through setting up a Python virtual environment on your local machine. This setup lets you manage dependencies and ensures that the deep learning code from this course runs smoothly in Jupyter Notebook.
16 | 
17 | ## Prerequisites
18 | 
19 | Before you begin, make sure you have Python installed on your system. You can download Python from python.org. This course assumes you are using Python 3.
20 | 
21 | ## Step 1: Create a Virtual Environment
22 | 
23 | First, open your terminal and navigate to your project directory (or wherever you want to set up your virtual environment).
24 | 
25 | `cd path/to/your/project-directory`
26 | 
27 | Create a virtual environment named `env_name` by running:
28 | 
29 | `python3 -m venv env_name`
30 | 
31 | Replace `env_name` with your preferred name for the virtual environment.
32 | 
33 | ## Step 2: Activate the Virtual Environment
34 | 
35 | Activate the virtual environment using the command below:
36 | 
37 | On Windows:
38 | 
39 | `.\env_name\Scripts\activate`
40 | 
41 | On macOS and Linux:
42 | 
43 | `source env_name/bin/activate`
44 | 
45 | You should see the name of your virtual environment in parentheses on your terminal prompt, indicating that it is active.
46 | 
47 | ## Step 3: Install Required Packages
48 | 
49 | With the virtual environment activated, install the necessary Python packages for the course:
50 | 
51 | `pip install torch jupyterlab jupyter`
52 | 
53 | This command installs PyTorch, JupyterLab, and Jupyter. You can modify it to include any other packages you need for the course (for example, the 201 MNIST example also requires `torchvision`; see `201_mnist/requirements.txt`).
54 | 
55 | ## Step 4: Start JupyterLab
56 | 
57 | After installing the packages, you can start JupyterLab by running:
58 | 
59 | `jupyter lab`
60 | 
61 | This command starts the JupyterLab server, and you should see a link in your terminal that you can open in a web browser to access JupyterLab.
62 | 
63 | ## Conclusion
64 | 
65 | You now have a fully functional Python virtual environment with JupyterLab, ready to run the deep learning code provided in this course. If you encounter any issues, ensure that your virtual environment is activated and that you've installed all required packages.
66 | 
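67 | ## Optional: Verify Your Setup
68 | 
69 | If you want a quick sanity check before moving on (optional, and assuming the packages from Step 3 installed without errors), start Python inside the activated environment and run the minimal snippet below. The version number and the CUDA result will depend on your installation and hardware.
70 | 
71 | ```python
72 | import torch
73 | 
74 | # Show the installed PyTorch version and whether a CUDA GPU is visible
75 | print(torch.__version__)
76 | print(torch.cuda.is_available())
77 | 
78 | # Tiny smoke test: multiply two random matrices
79 | x = torch.rand(2, 3)
80 | w = torch.rand(3, 2)
81 | print(x @ w)  # prints a 2x2 tensor
82 | ```
83 | 
84 | If this runs without errors, the environment is ready for the course notebooks and scripts.
85 | 
--------------------------------------------------------------------------------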