├── CIFAR-10_experiments.py ├── README.md ├── mixmatch_utils.py └── notebook.ipynb /CIFAR-10_experiments.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | This code tries to implement the MixMatch technique from the [paper](https://arxiv.org/pdf/1905.02249.pdf) MixMatch: A Holistic Approach to Semi-Supervised Learning and recreate their results on CIFAR10 with WideResnet28. 4 | 5 | It depends on Pytorch, Numpy and imgaug. The WideResnet28 model code is taken from [meliketoy](https://github.com/meliketoy/wide-resnet.pytorch/blob/master/networks/wide_resnet.py)'s github repository. Hopefully I can train this on Colab with a Tesla T4. :) 6 | """ 7 | 8 | import torch 9 | import torchvision 10 | import torchvision.transforms as transforms 11 | import torchvision.models as models 12 | 13 | import numpy as np 14 | import matplotlib.pyplot as plt 15 | from sklearn.model_selection import train_test_split 16 | 17 | from mixmatch_utils import get_augmenter, mixmatch 18 | 19 | 20 | 21 | 22 | # 23 | # 24 | # 25 | # 26 | # def to_torch(*args, device='cuda'): 27 | # convert_fn = lambda x: torch.from_numpy(x).to(device) 28 | # return list(map(convert_fn, args)) 29 | # 30 | # """That about covers all the code we need for train and test loaders. Now we can start the training and evaluation. Let's see if all of this works or is just a mess. Going to add basically this same training code from meliketoy's repo but with the MixMatchLoss.""" 31 | # 32 | # def test(model, test_gen, test_iters): 33 | # acc = [] 34 | # for i, (x, y) in enumerate(test_gen): 35 | # x = to_torch(x) 36 | # pred = model(x).to('cpu').argmax(axis=1) 37 | # acc.append(np.mean(pred == y.argmax(axis=1))) 38 | # if i == test_iters: 39 | # break 40 | # print('Accuracy was : {}'.format(np.mean(acc))) 41 | # 42 | # def report(loss_history): 43 | # print('Average loss in last epoch was : {}'.format(np.mean(loss_history))) 44 | # return [] 45 | # 46 | # def save(model, iter, train_iters): 47 | # torch.save(model.state_dict(), 'model_{}.pth'.format(train_iters // iters)) 48 | # 49 | # def run(model, train_gen, test_gen, epochs, train_iters, test_iters, device): 50 | # optim = torch.optim.Adam(model.parameters(), lr=lr) 51 | # loss_fn = MixMatchLoss() 52 | # loss_history = [] 53 | # for i, (x, u, p, q) in enumerate(train_gen): 54 | # if i % train_iters == 0: 55 | # loss_history = report(loss_history) 56 | # test(model, test_gen, test_iters) 57 | # save(model, i, train_iters) 58 | # if i // train_iters == epochs: 59 | # return 60 | # else: 61 | # optim.zero_grad() 62 | # x, u, p, q = to_torch(x, u, p, q, device=device) 63 | # loss = loss_fn(x, u, p, q, model) 64 | # loss.backward() 65 | # optim.step() 66 | # loss_history.append(loss.to('cpu')) 67 | # 68 | # 69 | # import torch 70 | # import torchvision 71 | # import torchvision.transforms as transforms 72 | 73 | ######################################################################## 74 | # The output of torchvision datasets are PILImage images of range [0, 1]. 75 | # We transform them to Tensors of normalized range [-1, 1]. 76 | 77 | # transform = transforms.Compose( 78 | # [transforms.ToTensor(), 79 | # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 80 | 81 | 82 | 83 | # training_amount + training_u_amount + validation_amount <= 50 000 84 | 85 | 86 | def basic_generator(x, y=None, batch_size=32, shuffle=True): 87 | i = 0 88 | all_indices = np.arange(len(x)) 89 | if shuffle: 90 | np.random.shuffle(all_indices) 91 | while(True): 92 | indices = all_indices[i:i+batch_size] 93 | if y is not None: 94 | yield x[indices], y[indices] 95 | else: 96 | yield x[indices] 97 | i = (i + batch_size) % len(x) 98 | 99 | def mixmatch_wrapper(x, y, u, model, batch_size=32): 100 | augment_fn = get_augmenter() 101 | train_generator = basic_generator(x, y, batch_size) 102 | unlabeled_generator = basic_generator(u, batch_size=batch_size) 103 | while(True): 104 | xi, yi = next(train_generator) 105 | ui = next(unlabeled_generator) 106 | yield mixmatch(xi, yi, ui, model, augment_fn) 107 | 108 | # functions to show an image 109 | def imshow(img): 110 | img = img / 2 + 0.5 # unnormalize 111 | npimg=img 112 | #npimg = img.numpy() 113 | plt.imshow(np.transpose(npimg, (1, 2, 0))) 114 | plt.show() 115 | 116 | 117 | training_amount = 300 118 | 119 | training_u_amount = 30000 120 | 121 | validation_amount = 10000 122 | 123 | transform = transforms.Compose( 124 | [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 125 | 126 | trainset = torchvision.datasets.CIFAR10(root='./data', train=True, 127 | download=True, transform=transform) 128 | 129 | 130 | testset = torchvision.datasets.CIFAR10(root='./data', train=False, 131 | download=True, transform=transform) 132 | 133 | X_train = np.array(trainset.data) 134 | y_train = np.array(trainset.targets) 135 | 136 | X_test = np.array(testset.data) 137 | y_test = np.array(testset.targets) 138 | 139 | # Train set / Validation set split 140 | X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=validation_amount, random_state=1, 141 | shuffle=True, stratify=y_train) 142 | 143 | # Train unsupervised / Train supervised split 144 | # Train set / Validation set split 145 | X_train, X_u_train, y_train, y_u_train = train_test_split(X_train, y_train, test_size=training_u_amount, random_state=1, 146 | shuffle=True, stratify=y_train) 147 | 148 | X_remain, X_train, y_remain, y_train = train_test_split(X_train, y_train, test_size=training_amount, random_state=1, 149 | shuffle=True, stratify=y_train) 150 | 151 | 152 | classes = ('plane', 'car', 'bird', 'cat', 153 | 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') 154 | 155 | 156 | model = models.mobilenet_v2() #TODO: Define "Wide ResNet-28" 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | # # get some random training images 165 | # dataiter = iter(trainloader) 166 | # images, labels = dataiter.next() 167 | # 168 | # # show images 169 | # imshow(torchvision.utils.make_grid(images)) 170 | # # print labels 171 | # print(' '.join('%5s' % classes[labels[j]] for j in range(4))) 172 | 173 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## [WIP] MixMatch: A Holistic Approach to Semi-Supervised Learning 2 | A Pytorch Implementation of the paper MixMatch: A Holistic Approach to Semi-Supervised Learning [[paper](https://arxiv.org/pdf/1905.02249.pdf)]. Till it is no longer a WIP check notebook for latest code. 3 | ## TO-DO 4 | * Train on CIFAR10 data. 5 | * ~Add training code for CIFAR10 with WideResnet28 from [here](https://github.com/meliketoy/wide-resnet.pytorch/blob/master/networks/wide_resnet.py).~ 6 | 7 | ## Dependencies 8 | ``` 9 | pip install 10 | ``` 11 | 12 | To use this layer: 13 | ``` 14 | from layer import mixmatch, MixMatchLoss 15 | ``` 16 | -------------------------------------------------------------------------------- /mixmatch_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import imgaug as ia 4 | import imgaug.augmenters as iaa 5 | 6 | 7 | def get_augmenter(): 8 | seq = iaa.Sequential([iaa.Fliplr(0.5), # horrizontal flips 9 | iaa.Crop(percent=(0, 0.1)), # random crops 10 | # Small gaussian blur with random sigma between 0 and 0.5. 11 | # But we only blur about 50% of all images. 12 | iaa.Sometimes(0.5, 13 | iaa.GaussianBlur(sigma=(0, 0.5)) 14 | ), 15 | # Strengthen or weaken the contrast in each image. 16 | iaa.ContrastNormalization((0.75, 1.5)), 17 | # Add gaussian noise. 18 | # For 50% of all images, we sample the noise once per pixel. 19 | # For the other 50% of all images, we sample the noise per pixel AND 20 | # channel. This can change the color (not only brightness) of the 21 | # pixels. 22 | iaa.AdditiveGaussianNoise(loc=0, scale=(0.0, 0.05*255), per_channel=0.5), 23 | # Make some images brighter and some darker. 24 | # In 20% of all cases, we sample the multiplier once per channel, 25 | # which can end up changing the color of the images. 26 | iaa.Multiply((0.8, 1.2), per_channel=0.2), 27 | # Apply affine transformations to each image. 28 | # Scale/zoom them, translate/move them, rotate them and shear them. 29 | iaa.Affine( 30 | scale={"x": (0.8, 1.2), "y": (0.8, 1.2)}, 31 | translate_percent={"x": (-0.2, 0.2), "y": (-0.2, 0.2)}, 32 | rotate=(-25, 25), 33 | shear=(-8, 8) 34 | ) 35 | ]) 36 | def augment(images): 37 | # Only works with list. Convert np to list 38 | imgs = [] 39 | for i in range(images.shape[0]): 40 | imgs.append(images[i,:,:,:]) 41 | 42 | images = images 43 | 44 | return seq.augment(images=images) 45 | return augment 46 | 47 | 48 | def sharpen(x, T): 49 | temp = x**(1/T) 50 | return temp / temp.sum(axis=1, keepdims=True) 51 | 52 | 53 | def mixup_mod(x1, x2, y1, y2, alpha): 54 | # lambda is a reserved word in python, substituting by beta 55 | beta = np.random.beta(alpha, alpha) 56 | beta = np.amax([beta, 1-beta]) 57 | x = beta * x1 + (1 - beta) * x2 58 | y = beta * y1 + (1 - beta) * y2 59 | return x, y 60 | 61 | def label_guessing(model, ub, K): 62 | # Device configuration 63 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 64 | model.eval() 65 | probs = [] 66 | for batch in ub: 67 | batch = torch.from_numpy(np.asarray(batch).transpose((0, 3, 1, 2))).to(device, dtype=torch.float) 68 | pr = model(batch) 69 | probs.append(pr) 70 | 71 | sum = probs[0] 72 | for i in range(1,len(probs)): 73 | sum.add_(probs[i]) 74 | 75 | return (sum/K).cpu().detach().numpy() 76 | 77 | def to_categorical(y, num_classes=None, dtype='float32'): 78 | """Converts a class vector (integers) to binary class matrix. 79 | E.g. for use with categorical_crossentropy. 80 | Taken from Keras code 81 | https://github.com/keras-team/keras/blob/master/keras/utils/np_utils.py#L9 82 | """ 83 | 84 | y = np.array(y, dtype='int') 85 | input_shape = y.shape 86 | if input_shape and input_shape[-1] == 1 and len(input_shape) > 1: 87 | input_shape = tuple(input_shape[:-1]) 88 | y = y.ravel() 89 | if not num_classes: 90 | num_classes = np.max(y) + 1 91 | n = y.shape[0] 92 | categorical = np.zeros((n, num_classes), dtype=dtype) 93 | categorical[np.arange(n), y] = 1 94 | output_shape = input_shape + (num_classes,) 95 | categorical = np.reshape(categorical, output_shape) 96 | return categorical 97 | 98 | 99 | def mixmatch(x, y, u, model, augment_fn, T=0.5, K=2, alpha=0.75): 100 | xb = augment_fn(x) 101 | y = to_categorical(y, num_classes=10) # Converting to one hot encode, num_clases=10 for future CIFAR test 102 | ub = [augment_fn(u) for _ in range(K)] 103 | avg_probs = label_guessing(model, ub, K) 104 | qb = sharpen(avg_probs, T) 105 | Ux = np.concatenate(ub, axis=0) 106 | Uy = np.concatenate([qb for _ in range(K)], axis=0) 107 | # Randon shuffle according to the paper 108 | indices = np.arange(len(xb) + len(Ux)) 109 | np.random.shuffle(indices) 110 | # MixUp 111 | Wx = np.concatenate([Ux, xb], axis=0)[indices] 112 | Wy = np.concatenate([Uy, y], axis=0)[indices] 113 | X, p = mixup_mod(xb, Wx[:len(xb)], y, Wy[:len(xb)], alpha) 114 | U, q = mixup_mod(Ux, Wx[len(xb):], Uy, Wy[len(xb):], alpha) 115 | 116 | # One hot decode for PyTorch labels compability 117 | p = p.argmax(axis=1) 118 | q = q.argmax(axis=1) 119 | return X, p, U, q 120 | 121 | 122 | class MixMatchLoss(torch.nn.Module): 123 | def __init__(self, lambda_u=100): 124 | self.lambda_u = lambda_u 125 | self.xent = torch.nn.CrossEntropyLoss() 126 | self.mse = torch.nn.MSELoss() 127 | super(MixMatchLoss, self).__init__() 128 | 129 | def forward(self, X, U, p, q): 130 | X_ = np.concatenate([X, U], axis=1) 131 | y_ = np.concatenate([p, q], axis=1) 132 | return self.xent(preds[:len(p)], p) + self.mse(preds[len(p):], q) 133 | -------------------------------------------------------------------------------- /notebook.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "mixmatch.ipynb", 7 | "version": "0.3.2", 8 | "provenance": [], 9 | "collapsed_sections": [], 10 | "include_colab_link": true 11 | }, 12 | "kernelspec": { 13 | "name": "python3", 14 | "display_name": "Python 3" 15 | }, 16 | "accelerator": "GPU" 17 | }, 18 | "cells": [ 19 | { 20 | "cell_type": "markdown", 21 | "metadata": { 22 | "id": "view-in-github", 23 | "colab_type": "text" 24 | }, 25 | "source": [ 26 | "\"Open" 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "metadata": { 32 | "id": "hmES8DZG7pFc", 33 | "colab_type": "text" 34 | }, 35 | "source": [ 36 | "This notebook tries to implement the MixMatch technique from the [paper](https://arxiv.org/pdf/1905.02249.pdf) MixMatch: A Holistic Approach to Semi-Supervised Learning and recreate their results on CIFAR10 with WideResnet28. \n", 37 | "\n", 38 | "It depends on Pytorch, Numpy and imgaug. The WideResnet28 model code is taken from [meliketoy](https://github.com/meliketoy/wide-resnet.pytorch/blob/master/networks/wide_resnet.py)'s github repository. Hopefully I can train this on Colab with a Tesla T4. :)" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "metadata": { 44 | "id": "J3ROGJfXigq3", 45 | "colab_type": "code", 46 | "outputId": "58e098b9-1e6b-434e-f695-2dad12be9ba1", 47 | "colab": { 48 | "base_uri": "https://localhost:8080/", 49 | "height": 306 50 | } 51 | }, 52 | "source": [ 53 | "!nvidia-smi" 54 | ], 55 | "execution_count": 15, 56 | "outputs": [ 57 | { 58 | "output_type": "stream", 59 | "text": [ 60 | "Thu May 23 14:30:40 2019 \n", 61 | "+-----------------------------------------------------------------------------+\n", 62 | "| NVIDIA-SMI 418.67 Driver Version: 410.79 CUDA Version: 10.0 |\n", 63 | "|-------------------------------+----------------------+----------------------+\n", 64 | "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n", 65 | "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n", 66 | "|===============================+======================+======================|\n", 67 | "| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 |\n", 68 | "| N/A 62C P8 17W / 70W | 0MiB / 15079MiB | 0% Default |\n", 69 | "+-------------------------------+----------------------+----------------------+\n", 70 | " \n", 71 | "+-----------------------------------------------------------------------------+\n", 72 | "| Processes: GPU Memory |\n", 73 | "| GPU PID Type Process name Usage |\n", 74 | "|=============================================================================|\n", 75 | "| No running processes found |\n", 76 | "+-----------------------------------------------------------------------------+\n" 77 | ], 78 | "name": "stdout" 79 | } 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "metadata": { 85 | "id": "zkPKM_FeXCUG", 86 | "colab_type": "code", 87 | "colab": {} 88 | }, 89 | "source": [ 90 | "import torch\n", 91 | "import numpy as np\n", 92 | "import imgaug.augmenters as iaa" 93 | ], 94 | "execution_count": 0, 95 | "outputs": [] 96 | }, 97 | { 98 | "cell_type": "markdown", 99 | "metadata": { 100 | "id": "Z_V6d_r-8QUi", 101 | "colab_type": "text" 102 | }, 103 | "source": [ 104 | "Now that we have the basic imports out of the way lets get to it. \n", 105 | "First we shall define the function to get augmented version of a given batch of images. The below function returns the function to do that. " 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "metadata": { 111 | "id": "yKrQ2XsBXLlN", 112 | "colab_type": "code", 113 | "colab": {} 114 | }, 115 | "source": [ 116 | "def get_augmenter():\n", 117 | " seq = iaa.Sequential([\n", 118 | " iaa.Crop(px=(0, 16)),\n", 119 | " iaa.Fliplr(0.5),\n", 120 | " iaa.GaussianBlur(sigma=(0, 3.0))\n", 121 | " ])\n", 122 | " def augment(images):\n", 123 | " return seq.augment(images.transpose(0, 2, 3, 1)).transpose(0, 2, 3, 1)\n", 124 | " return augment" 125 | ], 126 | "execution_count": 0, 127 | "outputs": [] 128 | }, 129 | { 130 | "cell_type": "markdown", 131 | "metadata": { 132 | "id": "se8HRC8z8byR", 133 | "colab_type": "text" 134 | }, 135 | "source": [ 136 | "Next we define the sharpening function to sharpen the prediction from the averaged prediction of all the unlabeled augmented images. It does the same thing as applying a temperature within the softmax function but to the probabilities. " 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "metadata": { 142 | "id": "G_DDDq0qYP5E", 143 | "colab_type": "code", 144 | "colab": {} 145 | }, 146 | "source": [ 147 | "def sharpen(x, T):\n", 148 | " temp = x**(1/T)\n", 149 | " return temp / temp.sum(axis=1, keepdims=True)" 150 | ], 151 | "execution_count": 0, 152 | "outputs": [] 153 | }, 154 | { 155 | "cell_type": "markdown", 156 | "metadata": { 157 | "id": "IhvvJUKN80lU", 158 | "colab_type": "text" 159 | }, 160 | "source": [ 161 | "A simple implementation of the [paper](https://arxiv.org/pdf/1710.09412.pdf) mixup: Beyond Empirical Risk Minimization used in this paper as well." 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "metadata": { 167 | "id": "Q21aM3biiVgi", 168 | "colab_type": "code", 169 | "colab": {} 170 | }, 171 | "source": [ 172 | "def mixup(x1, x2, y1, y2, alpha):\n", 173 | " beta = np.random.beta(alpha, -alpha)\n", 174 | " x = beta * x1 + (1 - beta) * x2\n", 175 | " y = beta * y1 + (1 - beta) * y2\n", 176 | " return x, y" 177 | ], 178 | "execution_count": 0, 179 | "outputs": [] 180 | }, 181 | { 182 | "cell_type": "markdown", 183 | "metadata": { 184 | "id": "HU0JHbCh90o5", 185 | "colab_type": "text" 186 | }, 187 | "source": [ 188 | "This covers Algorithm 1 from the paper. " 189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "metadata": { 194 | "id": "cE2Yi1WWiZNi", 195 | "colab_type": "code", 196 | "colab": {} 197 | }, 198 | "source": [ 199 | "def mixmatch(x, y, u, model, augment_fn, T=0.5, K=2, alpha=0.75):\n", 200 | " xb = augment_fn(x)\n", 201 | " ub = [augment_fn(u) for _ in range(K)]\n", 202 | " qb = sharpen(sum(map(lambda i: model(i), ub)) / K, T)\n", 203 | " Ux = np.concatenate(ub, axis=0)\n", 204 | " Uy = np.concatenate([qb for _ in range(K)], axis=0)\n", 205 | " indices = np.random.shuffle(np.arange(len(xb) + len(Ux)))\n", 206 | " Wx = np.concatenate([Ux, xb], axis=0)[indices]\n", 207 | " Wy = np.concatenate([qb, y], axis=0)[indices]\n", 208 | " X, p = mixup(xb, Wx[:len(xb)], y, Wy[:len(xb)], alpha)\n", 209 | " U, q = mixup(Ux, Wx[len(xb):], Uy, Wy[len(xb):], alpha)\n", 210 | " return X, U, p, q" 211 | ], 212 | "execution_count": 0, 213 | "outputs": [] 214 | }, 215 | { 216 | "cell_type": "markdown", 217 | "metadata": { 218 | "id": "dmSvUmiP94zT", 219 | "colab_type": "text" 220 | }, 221 | "source": [ 222 | "The combined loss for training from the paper." 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "metadata": { 228 | "id": "K5ylws-0kziT", 229 | "colab_type": "code", 230 | "colab": {} 231 | }, 232 | "source": [ 233 | "class MixMatchLoss(torch.nn.Module):\n", 234 | " def __init__(self, lambda_u=100):\n", 235 | " self.lambda_u = lambda_u\n", 236 | " self.xent = torch.nn.CrossEntropyLoss()\n", 237 | " self.mse = torch.nn.MSELoss()\n", 238 | " super(MixMatchLoss, self).__init__()\n", 239 | " \n", 240 | " def forward(self, X, U, p, q, model):\n", 241 | " X_ = np.concatenate([X, U], axis=1)\n", 242 | " preds = model(X_)\n", 243 | " return self.xent(preds[:len(p)], p) + \\\n", 244 | " self.lambda_u * self.mse(preds[len(p):], q)" 245 | ], 246 | "execution_count": 0, 247 | "outputs": [] 248 | }, 249 | { 250 | "cell_type": "markdown", 251 | "metadata": { 252 | "id": "CCqJtpJ--Cik", 253 | "colab_type": "text" 254 | }, 255 | "source": [ 256 | "Now that we have the MixMatch stuff done, we have a few things to do. Namely, define the WideResnet28 model, write the data and training code and write testing code. \n", 257 | "Let's start with the model. The below is just a copy paste mostly from the wide-resnet.pytorch repo by meliketoy. " 258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "metadata": { 263 | "id": "GIkBy3T15P7l", 264 | "colab_type": "code", 265 | "colab": {} 266 | }, 267 | "source": [ 268 | "def conv3x3(in_planes, out_planes, stride=1):\n", 269 | " return torch.nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,\n", 270 | " bias=True)" 271 | ], 272 | "execution_count": 0, 273 | "outputs": [] 274 | }, 275 | { 276 | "cell_type": "markdown", 277 | "metadata": { 278 | "id": "Fud8CmEtCaSN", 279 | "colab_type": "text" 280 | }, 281 | "source": [ 282 | "Will need the below init function later before training." 283 | ] 284 | }, 285 | { 286 | "cell_type": "code", 287 | "metadata": { 288 | "id": "FZBBH5EYCZhi", 289 | "colab_type": "code", 290 | "colab": {} 291 | }, 292 | "source": [ 293 | "def conv_init(m):\n", 294 | " classname = m.__class__.__name__\n", 295 | " if classname.find('Conv') != -1:\n", 296 | " torch.nn.init.xavier_uniform(m.weight, gain=np.sqrt(2))\n", 297 | " torch.nn.init.constant(m.bias, 0)\n", 298 | " elif classname.find('BatchNorm') != -1:\n", 299 | " torch.nn.init.constant(m.weight, 1)\n", 300 | " torch.nn.init.constant(m.bias, 0)" 301 | ], 302 | "execution_count": 0, 303 | "outputs": [] 304 | }, 305 | { 306 | "cell_type": "markdown", 307 | "metadata": { 308 | "id": "V_gOfar1CeUx", 309 | "colab_type": "text" 310 | }, 311 | "source": [ 312 | "The basic block for the WideResnet" 313 | ] 314 | }, 315 | { 316 | "cell_type": "code", 317 | "metadata": { 318 | "id": "QZ068XQR6LZP", 319 | "colab_type": "code", 320 | "colab": {} 321 | }, 322 | "source": [ 323 | "class WideBasic(torch.nn.Module):\n", 324 | " def __init__(self, in_planes, planes, dropout_rate, stride=1):\n", 325 | " super(WideBasic, self).__init__()\n", 326 | " self.bn1 = torch.nn.BatchNorm2d(in_planes)\n", 327 | " self.bn2 = torch.nn.BatchNorm2d(planes)\n", 328 | " self.conv1 = torch.nn.Conv2d(in_planes, planes, kernel_size=3,\n", 329 | " padding=1, bias=True)\n", 330 | " self.conv2 = torch.nn.Conv2d(planes, planes, kernel_size=3,\n", 331 | " padding=1, bias=True)\n", 332 | " self.dropout = torch.nn.Dropout(p=dropout_rate)\n", 333 | " self.shortcut = torch.nn.Sequential()\n", 334 | " if stride != 1 or in_planes != planes:\n", 335 | " self.shortcut = torch.nn.Sequential(\n", 336 | " torch.nn.Conv2d(in_planes, planes, kernel_size=1,\n", 337 | " stride=stride, bias=True)\n", 338 | " )\n", 339 | "\n", 340 | " def forward(self, x):\n", 341 | " out = self.dropout(self.conv1(torch.nn.functional.relu(self.bn1(x))))\n", 342 | " out = self.conv2(torch.nn.functional.relu(self.bn2(out)))\n", 343 | " return out + self.shortcut(x)" 344 | ], 345 | "execution_count": 0, 346 | "outputs": [] 347 | }, 348 | { 349 | "cell_type": "markdown", 350 | "metadata": { 351 | "id": "wdew7GNoChmh", 352 | "colab_type": "text" 353 | }, 354 | "source": [ 355 | "Aaand the full model with default params set for CIFAR10." 356 | ] 357 | }, 358 | { 359 | "cell_type": "code", 360 | "metadata": { 361 | "id": "YvE9l4W27jTx", 362 | "colab_type": "code", 363 | "colab": {} 364 | }, 365 | "source": [ 366 | "class WideResNet(torch.nn.Module):\n", 367 | " def __init__(self, depth=28, widen_factor=10,\n", 368 | " dropout_rate=0.3, num_classes=10):\n", 369 | " super(WideResNet, self).__init__()\n", 370 | " self.in_planes = 16\n", 371 | " n = (depth - 4) // 6\n", 372 | " k = widen_factor\n", 373 | " nStages = [16, 16*k, 32*k, 64*k]\n", 374 | " self.conv1 = conv3x3(3, nStages[0])\n", 375 | " self.layer1 = self.wide_layer(WideBasic, nStages[1], n, dropout_rate,\n", 376 | " stride=1)\n", 377 | " self.layer2 = self.wide_layer(WideBasic, nStages[2], n, dropout_rate,\n", 378 | " stride=2)\n", 379 | " self.layer3 = self.wide_layer(WideBasic, nStages[3], n, dropout_rate,\n", 380 | " stride=2)\n", 381 | " self.b1 = torch.nn.BatchNorm2d(nStages[3], momentum=0.9)\n", 382 | " self.linear = torch.nn.Linear(nStages[3], num_classes)\n", 383 | " \n", 384 | " def wide_layer(self, block, planes, num_blocks, dropout_rate, stride):\n", 385 | " strides = [stride] + [1] * (num_blocks - 1)\n", 386 | " layers = []\n", 387 | " for stride in strides:\n", 388 | " layers.append(block(self.in_planes, planes, dropout_rate, stride))\n", 389 | " self.in_planes = planes\n", 390 | " return torch.nn.Sequential(*layers)\n", 391 | " \n", 392 | " def forward(self, x):\n", 393 | " out = self.conv1(x)\n", 394 | " out = self.layer3(self.layer2(self.layer1(out)))\n", 395 | " out = torch.nn.functional.relu(self.bn1(out))\n", 396 | " out = torch.nn.functional.avg_pool2d(out, 8)\n", 397 | " out = out.view(out.size(0), -1)\n", 398 | " return self.linear(out)" 399 | ], 400 | "execution_count": 0, 401 | "outputs": [] 402 | }, 403 | { 404 | "cell_type": "markdown", 405 | "metadata": { 406 | "id": "JnLX7FkIEz1L", 407 | "colab_type": "text" 408 | }, 409 | "source": [ 410 | "Now that we have the model let's write train and test loaders so that we can pass the model and the data to the MixMatchLoss." 411 | ] 412 | }, 413 | { 414 | "cell_type": "code", 415 | "metadata": { 416 | "id": "EjCTPM8wB-dR", 417 | "colab_type": "code", 418 | "colab": {} 419 | }, 420 | "source": [ 421 | "def basic_generator(x, y=None, batch_size=32, shuffle=True):\n", 422 | " i = 0\n", 423 | " all_indices = np.random.shuffle(np.arange(len(x))) if shuffle else \\\n", 424 | " np.arange(len(x))\n", 425 | " while(True):\n", 426 | " indices = all_indices[i:i+batch_size]\n", 427 | " if y is not None:\n", 428 | " yield x[indices], y[indices]\n", 429 | " yield x[indices]\n", 430 | " i = (i + batch_size) % len(x)" 431 | ], 432 | "execution_count": 0, 433 | "outputs": [] 434 | }, 435 | { 436 | "cell_type": "code", 437 | "metadata": { 438 | "id": "QQb89EOHfUH8", 439 | "colab_type": "code", 440 | "colab": {} 441 | }, 442 | "source": [ 443 | "def mixmatch_wrapper(x, y, u, model, batch_size=32):\n", 444 | " augment_fn = get_augmenter()\n", 445 | " train_generator = basic_generator(x, y, batch_size)\n", 446 | " unlabeled_generator = basic_generator(u, batch_size=batch_size)\n", 447 | " while(True):\n", 448 | " xi, yi = next(train_generator)\n", 449 | " ui = next(unlabeled_generator)\n", 450 | " yield mixmatch(xi, yi, ui, model, augment_fn)" 451 | ], 452 | "execution_count": 0, 453 | "outputs": [] 454 | }, 455 | { 456 | "cell_type": "code", 457 | "metadata": { 458 | "id": "eLafEafCJthx", 459 | "colab_type": "code", 460 | "colab": {} 461 | }, 462 | "source": [ 463 | "def to_torch(*args, device='cuda'):\n", 464 | " convert_fn = lambda x: torch.from_numpy(x).to(device)\n", 465 | " return list(map(convert_fn, args))" 466 | ], 467 | "execution_count": 0, 468 | "outputs": [] 469 | }, 470 | { 471 | "cell_type": "markdown", 472 | "metadata": { 473 | "id": "NSTVdcWriKTq", 474 | "colab_type": "text" 475 | }, 476 | "source": [ 477 | "That about covers all the code we need for train and test loaders. Now we can start the training and evaluation. Let's see if all of this works or is just a mess. Going to add basically this same training code from meliketoy's repo but with the MixMatchLoss. " 478 | ] 479 | }, 480 | { 481 | "cell_type": "code", 482 | "metadata": { 483 | "id": "dRJOZ9FLL40g", 484 | "colab_type": "code", 485 | "colab": {} 486 | }, 487 | "source": [ 488 | "def test(model, test_gen, test_iters):\n", 489 | " acc = []\n", 490 | " for i, (x, y) in enumerate(test_gen):\n", 491 | " x = to_torch(x)\n", 492 | " pred = model(x).to('cpu').argmax(axis=1)\n", 493 | " acc.append(np.mean(pred == y.argmax(axis=1)))\n", 494 | " if i == test_iters:\n", 495 | " break\n", 496 | " print('Accuracy was : {}'.format(np.mean(acc)))" 497 | ], 498 | "execution_count": 0, 499 | "outputs": [] 500 | }, 501 | { 502 | "cell_type": "code", 503 | "metadata": { 504 | "id": "YNac4RKMMvln", 505 | "colab_type": "code", 506 | "colab": {} 507 | }, 508 | "source": [ 509 | "def report(loss_history):\n", 510 | " print('Average loss in last epoch was : {}'.format(np.mean(loss_history)))\n", 511 | " return []" 512 | ], 513 | "execution_count": 0, 514 | "outputs": [] 515 | }, 516 | { 517 | "cell_type": "code", 518 | "metadata": { 519 | "id": "4_IeC5TXNHIg", 520 | "colab_type": "code", 521 | "colab": {} 522 | }, 523 | "source": [ 524 | "def save(model, iter, train_iters):\n", 525 | " torch.save(model.state_dict(), 'model_{}.pth'.format(train_iters // iters))" 526 | ], 527 | "execution_count": 0, 528 | "outputs": [] 529 | }, 530 | { 531 | "cell_type": "code", 532 | "metadata": { 533 | "id": "TAMKAUNtiZwV", 534 | "colab_type": "code", 535 | "colab": {} 536 | }, 537 | "source": [ 538 | "def run(model, train_gen, test_gen, epochs, train_iters, test_iters, device):\n", 539 | " optim = torch.optim.Adam(model.parameters(), lr=lr)\n", 540 | " loss_fn = MixMatchLoss()\n", 541 | " loss_history = []\n", 542 | " for i, (x, u, p, q) in enumerate(train_gen):\n", 543 | " if i % train_iters == 0:\n", 544 | " loss_history = report(loss_history)\n", 545 | " test(model, test_gen, test_iters)\n", 546 | " save(model, i, train_iters)\n", 547 | " if i // train_iters == epochs:\n", 548 | " return\n", 549 | " else:\n", 550 | " optim.zero_grad()\n", 551 | " x, u, p, q = to_torch(x, u, p, q, device=device)\n", 552 | " loss = loss_fn(x, u, p, q, model)\n", 553 | " loss.backward()\n", 554 | " optim.step()\n", 555 | " loss_history.append(loss.to('cpu'))" 556 | ], 557 | "execution_count": 0, 558 | "outputs": [] 559 | } 560 | ] 561 | } --------------------------------------------------------------------------------