├── requirements.txt ├── README.md ├── .gitignore ├── augmentation.py ├── LICENSE ├── utils.py ├── mnist_experiments.py ├── models.py ├── cifar10_experiments.py └── plot.py /requirements.txt: -------------------------------------------------------------------------------- 1 | Pillow 2 | matplotlib 3 | numpy 4 | pytorch>=0.4.0 5 | seaborn 6 | torchvision 7 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Augmentation 2 | 3 | ## Dependencies 4 | - Python 3.6+ 5 | - Pillow, matplotlib, numpy, pytorch==0.4.0, seaborn, torchvision 6 | 7 | ## Usage 8 | 9 | * `mnist_experiments.py` runs a full set of experiments on MNIST and saves the 10 | results to the directory `saved`. Note: the default run takes a long time (43 hours) to 11 | finish, since we're running for all 10 seeds. 12 | ``` 13 | python mnist_experiments.py 14 | ``` 15 | 16 | Currently, it executes the following experiments: 17 | 1. Measure the difference between the exact augmented objective and approximate 18 | objectives (on original images, 1st order approximation, 2nd order approximation). 19 | 2. Measure the agreement and KL divergence between the predictions made by 20 | the model trained on the exact augmented objective and models trained on 21 | approximate objectives. 22 | 3. Compute kernel target alignment for features from different transformations. 23 | 24 | * `plot.py` plots all the figures in the paper using the saved results from 25 | `mnist_experiments.py`. The figures are saved in the directory `figs`. 
26 | ``` 27 | python plot.py 28 | ``` 29 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by https://www.gitignore.io/api/python 2 | 3 | ### Python ### 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | .pytest_cache/ 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | 52 | # Translations 53 | *.mo 54 | *.pot 55 | 56 | # Flask stuff: 57 | instance/ 58 | .webassets-cache 59 | 60 | # Scrapy stuff: 61 | .scrapy 62 | 63 | # Sphinx documentation 64 | docs/_build/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # Jupyter Notebook 70 | .ipynb_checkpoints 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # celery beat schedule file 76 | celerybeat-schedule.* 77 | 78 | # SageMath parsed files 79 | *.sage.py 80 | 81 | # Environments 82 | .env 83 | .venv 84 | env/ 85 | venv/ 86 | ENV/ 87 | env.bak/ 88 | venv.bak/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | 103 | 104 | # End of https://www.gitignore.io/api/python 
# /augmentation.py
"""Factories that turn a dataset into a "stacked" augmented dataset.

Each factory below builds a list of transforms, wraps them with
``augment_transforms`` so that a single data point becomes a stack of
transformed copies, and returns the result as an ``Augmentation``.
"""
import copy
from collections import namedtuple

import numpy as np
import torch
from torchvision import transforms
from PIL import Image, ImageFilter, ImageEnhance

# An augmentation object consists of its name, the transform functions of type
# torchvision.transforms, and the resulting augmented dataset of type
# torch.utils.data.Dataset.
Augmentation = namedtuple('Augmentation', ['name', 'transforms', 'dataset'])


def copy_with_new_transform(dataset, transform):
    """Return a shallow copy of @dataset with its transform set to @transform.

    Will work for datasets from torchvision, e.g., MNIST, CIFAR10, etc. Probably
    won't work for a generic dataset, since it relies on the dataset exposing a
    ``transform`` attribute. The copy is shallow (``copy.copy``), so everything
    other than ``transform`` is shared with the original dataset.
    """
    new_dataset = copy.copy(dataset)
    new_dataset.transform = transform
    return new_dataset


def augment_transforms(augmentations, base_transform, add_id_transform=True):
    """Construct a new transform that stacks all the augmentations.

    Parameters:
        augmentations: list of transforms (e.g. image rotations)
        base_transform: transform to be applied after augmentation (e.g. ToTensor)
        add_id_transform: whether to include the original image (i.e. identity
            transform) in the new transform. When True it is the first entry
            of the stack.
    Return:
        a ``transforms.Lambda`` that takes in a data point, applies every
        augmentation followed by @base_transform, and stacks the results with
        ``torch.stack`` (new leading dimension).
    """
    def stack_views(x):
        # Identity view first, so index 0 of the stack is the original image.
        views = [base_transform(x)] if add_id_transform else []
        views.extend(base_transform(aug(x)) for aug in augmentations)
        return torch.stack(views)

    return transforms.Lambda(stack_views)


def rotation(base_dataset, base_transform, angles=range(-15, 16, 2)):
    """Rotations, e.g. between -15 and 15 degrees.

    One ``RandomRotation`` is created per entry of @angles, each with the
    degenerate range ``(angle, angle)`` so that it rotates by exactly that
    angle.
    """
    rotations = [transforms.RandomRotation((angle, angle)) for angle in angles]
    aug_dataset = copy_with_new_transform(
        base_dataset, augment_transforms(rotations, base_transform))
    return Augmentation('rotation', rotations, aug_dataset)


def resized_crop(base_dataset, base_transform, size=28, scale=(0.64, 1.0),
                 n_random_samples=31):
    """Random crop (with resize).

    Creates @n_random_samples separate ``RandomResizedCrop(size, scale=scale)``
    transforms.
    """
    random_resized_crops = [transforms.RandomResizedCrop(size, scale=scale)
                            for _ in range(n_random_samples)]
    aug_dataset = copy_with_new_transform(
        base_dataset, augment_transforms(random_resized_crops, base_transform))
    return Augmentation('crop', random_resized_crops, aug_dataset)


def blur(base_dataset, base_transform, radii=np.linspace(0.05, 1.0, 20)):
    """Gaussian blur, one transform per radius in @radii.

    Each transform calls ``img.filter(ImageFilter.GaussianBlur(radius))``, so
    the data points are expected to be PIL images.
    """
    def gaussian_blur_fn(radius):
        # A factory function binds @radius per transform (avoids the
        # late-binding closure pitfall of a lambda defined in a loop).
        return transforms.Lambda(
            lambda img: img.filter(ImageFilter.GaussianBlur(radius)))

    blurs = [gaussian_blur_fn(radius) for radius in radii]
    aug_dataset = copy_with_new_transform(
        base_dataset, augment_transforms(blurs, base_transform))
    return Augmentation('blur', blurs, aug_dataset)


def rotation_crop_blur(base_dataset, base_transform, angles=range(-15, 16, 2),
                       size=28, scale=(0.64, 1.0), n_random_samples=31,
                       radii=np.linspace(0.05, 1.0, 20)):
    """All 3: rotations, random crops, and blurs.

    The returned ``Augmentation.transforms`` is the concatenation
    rotations + crops + blurs, i.e. exactly the list used to build the
    augmented dataset.
    """
    rotations = rotation(base_dataset, base_transform, angles).transforms
    random_resized_crops = resized_crop(
        base_dataset, base_transform, size, scale, n_random_samples).transforms
    blurs = blur(base_dataset, base_transform, radii).transforms
    all_transforms = rotations + random_resized_crops + blurs
    aug_dataset = copy_with_new_transform(
        base_dataset, augment_transforms(all_transforms, base_transform))
    # Report every transform in the dataset, not just the blurs, so the
    # `transforms` field agrees with `dataset` like in the other factories.
    return Augmentation('rotation_crop_blur', all_transforms, aug_dataset)


def hflip(base_dataset, base_transform):
    """Horizontal flip.
    """
    flip = [transforms.Lambda(lambda img: img.transpose(Image.FLIP_LEFT_RIGHT))]
    aug_dataset = copy_with_new_transform(
        base_dataset, augment_transforms(flip, base_transform))
    return Augmentation('hflip', flip, aug_dataset)


def hflip_vflip(base_dataset, base_transform):
    """Both horizontal and vertical flips.

    Three transforms: horizontal only, vertical only, and horizontal followed
    by vertical.
    """
    allflips = [
        transforms.Lambda(lambda img: img.transpose(Image.FLIP_LEFT_RIGHT)),
        transforms.Lambda(lambda img: img.transpose(Image.FLIP_TOP_BOTTOM)),
        transforms.Lambda(lambda img: img.transpose(Image.FLIP_LEFT_RIGHT)
                          .transpose(Image.FLIP_TOP_BOTTOM)),
    ]
    aug_dataset = copy_with_new_transform(
        base_dataset, augment_transforms(allflips, base_transform))
    return Augmentation('hflip_vflip', allflips, aug_dataset)


def brightness(base_dataset, base_transform,
               brightness_factors=np.linspace(1 - 0.25, 1 + 0.25, 11)):
    """Brightness adjustment, one transform per factor in @brightness_factors.
    """
    def brightness_fn(brightness_factor):
        return transforms.Lambda(
            lambda img: ImageEnhance.Brightness(img).enhance(brightness_factor))

    brightness_transforms = [brightness_fn(factor) for factor in brightness_factors]
    aug_dataset = copy_with_new_transform(
        base_dataset, augment_transforms(brightness_transforms, base_transform))
    return Augmentation('brightness', brightness_transforms, aug_dataset)


def contrast(base_dataset, base_transform,
             contrast_factors=np.linspace(1 - 0.35, 1 + 0.35, 11)):
    """Contrast adjustment, one transform per factor in @contrast_factors.
    """
    def contrast_fn(contrast_factor):
        return transforms.Lambda(
            lambda img: ImageEnhance.Contrast(img).enhance(contrast_factor))

    contrast_transforms = [contrast_fn(factor) for factor in contrast_factors]
    aug_dataset = copy_with_new_transform(
        base_dataset, augment_transforms(contrast_transforms, base_transform))
    return Augmentation('contrast', contrast_transforms, aug_dataset)
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | from sklearn.metrics import roc_auc_score 8 | 9 | from models import combine_transformed_dimension, split_transformed_dimension 10 | 11 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 12 | 13 | def get_train_valid_datasets(dataset, 14 | valid_size=0.1, 15 | random_seed=None, 16 | shuffle=True): 17 | """ 18 | Utility function for loading and returning train and validation 19 | datasets. 20 | Parameters: 21 | ------ 22 | - dataset: the dataset, need to have train_data and train_labels attributes. 23 | - valid_size: percentage split of the training set used for 24 | the validation set. Should be a float in the range [0, 1]. 25 | - random_seed: fix seed for reproducibility. 26 | - shuffle: whether to shuffle the train/validation indices. 27 | Returns: 28 | ------- 29 | - train_dataset: training set. 
30 | - valid_dataset: validation set. 31 | """ 32 | error_msg = "[!] valid_size should be in the range [0, 1]." 33 | assert ((valid_size >= 0) and (valid_size <= 1)), error_msg 34 | num_train = len(dataset) 35 | indices = list(range(num_train)) 36 | if shuffle: 37 | np.random.seed(random_seed) 38 | np.random.shuffle(indices) 39 | split = int(np.floor(valid_size * num_train)) 40 | train_idx, valid_idx = indices[split:], indices[:split] 41 | train_dataset, valid_dataset = copy.copy(dataset), copy.copy(dataset) 42 | train_dataset.train_data = train_dataset.train_data[train_idx] 43 | train_dataset.train_labels = train_dataset.train_labels[train_idx] 44 | valid_dataset.train_data = valid_dataset.train_data[valid_idx] 45 | valid_dataset.train_labels = valid_dataset.train_labels[valid_idx] 46 | return train_dataset, valid_dataset 47 | 48 | 49 | def train(data_loader, model, optimizer): 50 | model.train() 51 | train_loss, train_acc = [], [] 52 | for data, target in data_loader: 53 | data, target = data.to(device), target.to(device) 54 | optimizer.zero_grad() 55 | output = model(data) 56 | pred = model.predict(output) 57 | loss = model.loss(output, target) 58 | loss.backward() 59 | optimizer.step() 60 | acc = (pred == target).sum().item() / target.size(0) 61 | train_loss.append(loss.item()) 62 | train_acc.append(acc) 63 | return train_loss, train_acc 64 | 65 | 66 | def train_models_compute_agreement(data_loader, models, optimizers): 67 | train_agreement = [] 68 | for model in models: 69 | model.train() 70 | for data, target in data_loader: 71 | data, target = data.to(device), target.to(device) 72 | pred, loss = [], [] 73 | for model, optimizer in zip(models, optimizers): 74 | optimizer.zero_grad() 75 | output = model(data) 76 | pred.append(model.predict(output)) 77 | loss_minibatch = model.loss(output, target) 78 | loss_minibatch.backward() 79 | optimizer.step() 80 | loss.append(loss_minibatch.item()) 81 | # To avoid out-of-memory error, as these attributes prevent the 
memory from being freed 82 | if hasattr(model, '_avg_features'): 83 | del model._avg_features 84 | if hasattr(model, '_centered_features'): 85 | del model._centered_features 86 | loss = np.array(loss) 87 | pred = np.array([p.cpu().numpy() for p in pred]) 88 | train_agreement.append((pred == pred[0]).mean(axis=1)) 89 | return train_agreement 90 | 91 | 92 | def train_all_epochs(train_loader, 93 | valid_loader, 94 | model, 95 | optimizer, 96 | n_epochs, 97 | verbose=True): 98 | model.train() 99 | train_loss, train_acc, valid_acc = [], [], [] 100 | for epoch in range(n_epochs): 101 | if verbose: 102 | print(f'Train Epoch: {epoch}') 103 | loss, acc = train(train_loader, model, optimizer) 104 | train_loss += loss 105 | train_acc += acc 106 | correct, total = accuracy(valid_loader, model) 107 | valid_acc.append(correct / total) 108 | if verbose: 109 | print( 110 | f'Validation set: Accuracy: {correct}/{total} ({correct/total*100:.4f}%)' 111 | ) 112 | return train_loss, train_acc, valid_acc 113 | 114 | 115 | def accuracy(data_loader, model): 116 | """Accuracy over all mini-batches. 117 | """ 118 | training = model.training 119 | model.eval() 120 | correct, total = 0, 0 121 | with torch.no_grad(): 122 | for data, target in data_loader: 123 | data, target = data.to(device), target.to(device) 124 | output = model(data) 125 | pred = model.predict(output) 126 | correct += (pred == target).sum().item() 127 | total += target.size(0) 128 | model.train(training) 129 | return correct, total 130 | 131 | def roc_auc(data_loader, model): 132 | """Accuracy over all mini-batches. 
133 | """ 134 | training = model.training 135 | model.eval() 136 | y_true, y_score = [], [] 137 | with torch.no_grad(): 138 | for data, target in data_loader: 139 | data, target = data.to(device), target.to(device) 140 | output = model(data) 141 | y_true.append(target) 142 | y_score.append(torch.nn.Softmax(dim=-1)(output)[:, 1]) 143 | model.train(training) 144 | y_true = torch.cat(y_true).cpu().numpy() 145 | y_score = torch.cat(y_score).cpu().numpy() 146 | return roc_auc_score(y_true, y_score) 147 | 148 | 149 | def all_losses(data_loader, model): 150 | """All losses over all mini-batches. 151 | """ 152 | training = model.training 153 | model.eval() 154 | losses = [] 155 | with torch.no_grad(): 156 | for data, target in data_loader: 157 | data, target = data.to(device), target.to(device) 158 | losses.append([l.item() for l in model.all_losses(data, target)]) 159 | model.train(training) 160 | return np.array(losses) 161 | 162 | 163 | def agreement_kl_accuracy(data_loader, models): 164 | training = [model.training for model in models] 165 | for model in models: 166 | model.eval() 167 | valid_agreement, valid_acc, valid_kl = [], [], [] 168 | with torch.no_grad(): 169 | for data, target in data_loader: 170 | data, target = data.to(device), target.to(device) 171 | pred, out = [], [] 172 | for model in models: 173 | output = model(data).detach() 174 | out.append(output) 175 | pred.append(model.predict(output)) 176 | pred = torch.stack(pred) 177 | out = torch.stack(out) 178 | log_prob = F.log_softmax(out, dim=-1) 179 | prob = F.softmax(out[0], dim=-1) 180 | valid_kl.append([F.kl_div(lp, prob, size_average=False).item() / prob.size(0) for lp in log_prob]) 181 | valid_acc.append((pred == target).float().mean(dim=1).cpu().numpy()) 182 | valid_agreement.append((pred == pred[0]).float().mean(dim=1).cpu().numpy()) 183 | for model, training_ in zip(models, training): 184 | model.train(training_) 185 | return valid_agreement, valid_kl, valid_acc 186 | 187 | 188 | def 
kernel_target_alignment(data_loader, model, n_passes_through_data=10): 189 | """Compute kernel target alignment approximately by summing over 190 | mini-batches. The number of mini-batches is controlled by @n_passes_through_data. 191 | Larger number of passes yields more accurate result, but takes longer. 192 | """ 193 | inclass_kernel, kernel_fro_norm, inclass_num = [], [], [] 194 | with torch.no_grad(): 195 | for _ in range(n_passes_through_data): 196 | for data, target in data_loader: 197 | data, target = data.to(device), target.to(device) 198 | features = model.features(data) 199 | target = target[:, None] 200 | same_labels = target == target.t() 201 | K = features @ features.t() 202 | inclass_kernel.append(K[same_labels].sum().item()) 203 | kernel_fro_norm.append((K * K).sum().item()) 204 | inclass_num.append(same_labels.long().sum().item()) 205 | inclass_kernel = np.array(inclass_kernel) 206 | kernel_fro_norm = np.array(kernel_fro_norm) 207 | inclass_num = np.array(inclass_num) 208 | return inclass_kernel.mean(axis=0) / np.sqrt(kernel_fro_norm.mean(axis=0) * inclass_num.mean()) 209 | 210 | 211 | def kernel_target_alignment_augmented(data_loader, model, n_passes_through_data=10): 212 | """Compute kernel target alignment on augmented dataset, of the original 213 | features and averaged features. Alignment is approximately by summing over 214 | minibatches. The number of minibatches is controlled by 215 | @n_passes_through_data. Larger number of passes yields more accurate 216 | result. 
217 | """ 218 | inclass_kernel, kernel_fro_norm, inclass_num = [], [], [] 219 | with torch.no_grad(): 220 | for _ in range(n_passes_through_data): 221 | for data, target in data_loader: 222 | data, target = data.to(device), target.to(device) 223 | n_transforms = data.size(1) 224 | data = combine_transformed_dimension(data) 225 | features = model.features(data) 226 | features = split_transformed_dimension(features, n_transforms) 227 | features_avg = features.mean(dim=1) 228 | features_og = features[:, 0] 229 | target = target[:, None] 230 | same_labels = target == target.t() 231 | K_avg = features_avg @ features_avg.t() 232 | K_og = features_og @ features_og.t() 233 | inclass_kernel.append([K_avg[same_labels].sum().item(), K_og[same_labels].sum().item()]) 234 | kernel_fro_norm.append([(K_avg * K_avg).sum().item(), (K_og * K_og).sum().item()]) 235 | inclass_num.append(same_labels.long().sum().item()) 236 | inclass_kernel = np.array(inclass_kernel) 237 | kernel_fro_norm = np.array(kernel_fro_norm) 238 | inclass_num = np.array(inclass_num) 239 | return tuple(inclass_kernel.mean(axis=0) / np.sqrt(kernel_fro_norm.mean(axis=0) * inclass_num.mean())) 240 | 241 | 242 | def kernel_target_alignment_augmented_no_avg(data_loader, model, n_passes_through_data=10): 243 | """Compute kernel target alignment approximately by summing over 244 | mini-batches. This is for augmented dataset, and no feature averaging will be done. 245 | Thus this is kernel target alignment on the augmented dataset. 246 | The number of mini-batches is controlled by @n_passes_through_data. 247 | Larger number of passes yields more accurate result, but takes longer. 
248 | """ 249 | inclass_kernel, kernel_fro_norm, inclass_num = [], [], [] 250 | with torch.no_grad(): 251 | for _ in range(n_passes_through_data): 252 | for data, target in data_loader: 253 | data, target = data.to(device), target.to(device) 254 | # Repeat target for augmented data points 255 | target = target[:, None].repeat(1, data.shape[1]).view(-1) 256 | data = combine_transformed_dimension(data) 257 | features = model.features(data) 258 | target = target[:, None] 259 | same_labels = target == target.t() 260 | K = features @ features.t() 261 | inclass_kernel.append(K[same_labels].sum().item()) 262 | kernel_fro_norm.append((K * K).sum().item()) 263 | inclass_num.append(same_labels.long().sum().item()) 264 | inclass_kernel = np.array(inclass_kernel) 265 | kernel_fro_norm = np.array(kernel_fro_norm) 266 | inclass_num = np.array(inclass_num) 267 | return inclass_kernel.mean(axis=0) / np.sqrt(kernel_fro_norm.mean(axis=0) * inclass_num.mean()) 268 | -------------------------------------------------------------------------------- /mnist_experiments.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | import pathlib 4 | 5 | import torch 6 | import torch.optim as optim 7 | import torch.nn.functional as F 8 | from torchvision import datasets, transforms 9 | 10 | from models import LinearLogisticRegression, RBFLogisticRegression, LinearLogisticRegressionAug, RBFLogisticRegressionAug, LeNet, LeNetAug, combine_transformed_dimension, split_transformed_dimension 11 | from augmentation import copy_with_new_transform, augment_transforms, rotation, resized_crop, blur, rotation_crop_blur, hflip, hflip_vflip, brightness, contrast 12 | from utils import get_train_valid_datasets, train, train_all_epochs, accuracy, all_losses, train_models_compute_agreement, agreement_kl_accuracy, kernel_target_alignment, kernel_target_alignment_augmented 13 | 14 | device = torch.device("cuda:0" if torch.cuda.is_available() 
def loader_from_dataset(dataset):
    """Return a shuffling DataLoader over @dataset.

    The batch size and the extra loader options come from the module-level
    @batch_size and @loader_args.
    """
    make_loader = torch.utils.data.DataLoader
    return make_loader(dataset, shuffle=True, batch_size=batch_size, **loader_args)
def sgd_opt_from_model(model, learning_rate=0.01, momentum=0.9, weight_decay=0.001):
    """Build an SGD optimizer over the trainable parameters of @model.

    Parameters that have requires_grad == False are left out of the optimizer.
    """
    trainable = (param for param in model.parameters() if param.requires_grad)
    hyperparams = {
        'lr': learning_rate,
        'momentum': momentum,
        'weight_decay': weight_decay,
    }
    return optim.SGD(trainable, **hyperparams)
def objective_difference(augmentations):
    """Measure the difference in approximate and true objectives as we train on
    the true objective.

    For each of the 'kernel' and 'lenet' model factories, each augmentation
    and each seed in range(n_trials), a fresh model is trained for
    @sgd_n_epochs epochs on the augmented loader. @all_losses is evaluated
    once before training and once after every epoch, and its mean over axis 0
    is recorded.

    Files written per (model, augmentation, seed) under 'saved/':
        train_valid_acc_*.npz: train_loss, train_acc and valid_acc arrays.
        all_losses_*.npy: the recorded loss snapshots, transposed so that the
            snapshots run along the last axis.

    Parameters:
        augmentations: iterable of objects exposing @dataset and @name.
    """
    # Make the function usable on its own instead of relying on the caller
    # having created the output directory.
    pathlib.Path('saved').mkdir(parents=True, exist_ok=True)
    for model_name in ['kernel', 'lenet']:
        for augmentation in augmentations:
            for seed in range(n_trials):
                print(f'Seed: {seed}')
                torch.manual_seed(seed)
                model = model_factories[model_name]().to(device)
                optimizer = sgd_opt_from_model(model)
                loader = loader_from_dataset(augmentation.dataset)
                model.train()
                # Snapshot before any training; converted the same way as the
                # per-epoch snapshots below so both are handled consistently.
                losses = [np.asarray(all_losses(loader, model)).mean(axis=0)]
                train_loss, train_acc, valid_acc = [], [], []
                for epoch in range(sgd_n_epochs):
                    train_loss_epoch, train_acc_epoch = train(loader, model, optimizer)
                    # train() returns per-step sequences; concatenate them across epochs.
                    train_loss += train_loss_epoch
                    train_acc += train_acc_epoch
                    print(f'Train Epoch: {epoch}')
                    correct, total = accuracy(valid_loader, model)
                    valid_acc.append(correct / total)
                    print(
                        f'Validation set: Accuracy: {correct}/{total} ({correct/total*100:.4f}%)'
                    )
                    losses.append(np.asarray(all_losses(loader, model)).mean(axis=0))
                train_loss, train_acc, valid_acc = np.array(train_loss), np.array(train_acc), np.array(valid_acc)
                np.savez(f'saved/train_valid_acc_{model_name}_{augmentation.name}_{seed}.npz',
                         train_loss=train_loss, train_acc=train_acc, valid_acc=valid_acc)
                losses = np.array(losses).T
                np.save(f'saved/all_losses_{model_name}_{augmentation.name}_{seed}.npy', losses)
def exact_to_og_model(model):
    """Return a deep copy of @model configured to train on the original data only.

    The copy uses the approximate code path with neither feature averaging nor
    variance regularization; @model itself is left untouched.
    """
    clone = copy.deepcopy(model)
    for flag, value in (('approx', True), ('feature_avg', False), ('regularization', False)):
        setattr(clone, flag, value)
    return clone
def exact_to_2nd_order_model_layer_avg(model, layer_to_avg=3):
    """Convert LeNet model training on exact augmented objective to model
    training on 2nd order approximation, but approximation is done at different
    layers.

    Parameters:
        model: the model to convert; it is deep-copied and never modified.
        layer_to_avg: index of the layer at whose output the features are
            averaged.
    Return:
        The configured copy.
    """
    model_2nd = copy.deepcopy(model)
    model_2nd.approx = True
    model_2nd.feature_avg = True
    model_2nd.regularization = True
    model_2nd.layer_to_avg = layer_to_avg
    # Can't use the regularization function specialized to linear model unless
    # averaging at layer 4.
    if layer_to_avg != 4:
        # Patch the copy, not the caller's model: the previous code assigned to
        # @model, so the returned copy kept the linear-only regularizer and the
        # source model was silently altered for every later conversion.
        model_2nd.regularization_2nd_order = model_2nd.regularization_2nd_order_general
    return model_2nd
217 | """ 218 | model_variants = {'kernel': lambda model: [model, exact_to_og_model(model), exact_to_1st_order_model(model), 219 | exact_to_2nd_order_no_1st_model(model), exact_to_2nd_order_model(model)], 220 | 'lenet': lambda model: [model, exact_to_og_model(model), exact_to_1st_order_model(model), 221 | exact_to_2nd_order_no_1st_model(model)] + 222 | [exact_to_2nd_order_model_layer_avg(model, layer_to_avg) for layer_to_avg in [4, 3, 2, 1, 0]]} 223 | 224 | for model_name in ['kernel', 'lenet']: 225 | for augmentation in augmentations: 226 | for seed in range(n_trials): 227 | print(f'Seed: {seed}') 228 | torch.manual_seed(n_trials + seed) 229 | loader = loader_from_dataset(augmentation.dataset) 230 | model = model_factories[model_name]() 231 | models = model_variants[model_name](model) 232 | for model in models: 233 | model.to(device) 234 | optimizers = [sgd_opt_from_model(model) for model in models] 235 | for model in models: 236 | model.train() 237 | train_agreement, valid_agreement, valid_acc, valid_kl = [], [], [], [] 238 | for epoch in range(sgd_n_epochs): 239 | print(f'Train Epoch: {epoch}') 240 | train_agreement_epoch = train_models_compute_agreement(loader, models, optimizers) 241 | train_agreement.append(np.array(train_agreement_epoch).mean(axis=0)) 242 | # Agreement and KL on validation set 243 | valid_agreement_epoch, valid_kl_epoch, valid_acc_epoch = agreement_kl_accuracy(valid_loader, models) 244 | valid_agreement.append(np.array(valid_agreement_epoch).mean(axis=0)) 245 | valid_acc.append(np.array(valid_acc_epoch).mean(axis=0)) 246 | valid_kl.append(np.array(valid_kl_epoch).mean(axis=0)) 247 | train_agreement = np.array(train_agreement).T 248 | valid_agreement = np.array(valid_agreement).T 249 | valid_acc = np.array(valid_acc).T 250 | valid_kl = np.array(valid_kl).T 251 | np.savez(f'saved/train_valid_agreement_kl_{model_name}_{augmentation.name}_{seed}.npz', 252 | train_agreement=train_agreement, valid_agreement=valid_agreement, valid_acc=valid_acc, 
def alignment_comparison(augmentations):
    """Compute and save the kernel target alignment of each augmentation.

    A fresh 'kernel' model is built per augmentation and scored with
    @kernel_target_alignment_augmented, which yields one value for averaged
    features and one for the features of the original images.
    The array saved to 'saved/kernel_alignment.npy' holds the mean of the
    original-image values first, followed by the averaged-feature value of
    every augmentation in order.
    """
    build_kernel_model = model_factories['kernel']
    scores = []
    for aug in augmentations:
        print(aug.name)
        aug_loader = loader_from_dataset(aug.dataset)
        kernel_model = build_kernel_model().to(device)
        scores.append(
            kernel_target_alignment_augmented(aug_loader, kernel_model, n_passes_through_data=10))
    scores = np.array(scores)
    averaged_scores = scores[:, 0]
    # The original-image alignment does not depend on the transformation, so
    # the per-augmentation estimates are pooled into a single baseline.
    baseline = scores[:, 1].mean()
    np.save('saved/kernel_alignment.npy', np.array([baseline] + list(averaged_scores)))
def memory_profile():
    """Print the resident tensors known to the garbage collector.

    For every tracked object that is a tensor, or that has a @data attribute
    which is a tensor, print its Python type, its size and its tensor type.
    This is a best-effort debugging aid: objects that fail during inspection
    are skipped.
    """
    import gc
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
                print(type(obj), obj.size(), obj.type())
        except Exception:
            # Arbitrary objects may raise from attribute access; ignore those,
            # but let KeyboardInterrupt / SystemExit propagate (a bare except
            # would have swallowed them).
            continue
class MultinomialLogisticRegressionAug(MultinomialLogisticRegression):
    """Abstract class for multinomial logistic regression on augmented data.

    Input has size B x T x ..., where B is batch size and T is the number of
    transformations. The original i-th data point is placed first among its
    transformed versions, i.e. it is input[i, 0].
    With approx == False the output has size B x T x (number of classes); with
    approx == True the features are reduced over T first, so the output has one
    row per data point.
    An input with at most 4 dimensions (normal loader, size B x ...) is treated
    as non-augmented and handled exactly like the non-augmented version.
    """

    def __init__(self, approx=True, feature_avg=True, regularization=False):
        """Parameters:
            approx: whether to use approximation or train on augmented data points.
                If False, ignore @feature_avg and @regularization.
            feature_avg: whether to average features or just use the features of the original data point.
            regularization: whether to add 2nd order term (variance regularization).
        """
        # NOTE: deliberately no super().__init__() here. Concrete subclasses
        # run the base model's __init__ (which sets up the nn.Module) before
        # calling this one; re-running it would reset the module state.
        self.approx = approx
        self.feature_avg = feature_avg
        self.regularization = regularization
        # Keep a handle on the general implementation so it can be reinstated
        # after a subclass swaps in the linear-only shortcut.
        self.regularization_2nd_order_general = self.regularization_2nd_order

    def forward(self, x):
        """Compute the output for @x, reducing over transformations when approximating."""
        augmented = x.dim() > 4
        if augmented:
            n_transforms = x.size(1)
            x = combine_transformed_dimension(x)
        feat = self.features(x)
        if self.approx and augmented:
            # The regularizers differentiate the output w.r.t. the reduced
            # features, so the features must take part in the autograd graph.
            if not feat.requires_grad:
                feat.requires_grad = True
            feat = split_transformed_dimension(feat, n_transforms)
            if self.feature_avg:
                self._avg_features = feat.mean(dim=1)
            else:
                self._avg_features = feat[:, 0]
            if self.regularization:  # Storing this every time consumes lots of memory
                self._centered_features = feat - self._avg_features[:, None]
            feat = self._avg_features
        output = self.output_from_features(feat)
        if not self.approx and augmented:
            output = split_transformed_dimension(output, n_transforms)
        return output

    @staticmethod
    def predict(output):
        """Return the index of the largest output per data point."""
        if output.dim() > 2:
            # Average over transformed versions of the same data point
            output = output.mean(dim=1)
        return output.detach().max(1)[1]

    @classmethod
    def loss_original(cls, output, target, reduce=True):
        """Original cross entropy loss.
        """
        return super().loss(output, target, reduce=reduce)

    @classmethod
    def loss_on_augmented_data(cls, output, target, reduce=True):
        """Loss averaged over augmented data points (no approximation).

        Parameters:
            output: size B x T x (number of classes).
            target: size B.
        Return:
            The mean over the batch if @reduce, else one loss per data point
            (each already averaged over its T transformed versions).
        """
        n_points, n_transforms = output.shape[:2]
        # Single batched cross entropy call instead of one call per data
        # point: flatten the copies and give every copy its point's label.
        flat_output = output.reshape(n_points * n_transforms, -1)
        flat_target = target[:, None].expand(n_points, n_transforms).reshape(-1)
        per_copy = cls.loss_original(flat_output, flat_target, reduce=False)
        loss = per_copy.view(n_points, n_transforms).mean(dim=1)
        return loss.mean() if reduce else loss

    def regularization_2nd_order(self, output, reduce=True):
        """Compute regularization term from output instead of from loss.
        Fast implementation by evaluating the Jacobian directly instead of relying on 2nd order differentiation.
        Requires a preceding forward pass with approx and regularization
        enabled, which stores the reduced and the centered features.
        """
        p = F.softmax(output, dim=-1)
        n_classes = output.size(1)
        # Using autograd.grad(output[:, i]) is slower since it creates new node in graph.
        eye = torch.eye(n_classes, device=output.device)
        eye = eye[None, :].expand(output.size(0), -1, -1)
        # W[b, i] is the gradient of output[b, i] w.r.t. the reduced features.
        # Iterate over the actual number of classes rather than a fixed 10.
        W = torch.stack([autograd.grad(output, self._avg_features, grad_outputs=eye[:, i], create_graph=True)[0]
                         for i in range(n_classes)], dim=1)
        # t[b, j, i]: change of output i along the centered feature of transform j.
        t = (W.view(W.size(0), 1, W.size(1), -1) @ self._centered_features.view(*self._centered_features.shape[:2], -1, 1)).squeeze(-1)
        term_1 = (t**2 * p[:, None]).sum(dim=-1).mean(dim=-1)
        term_2 = ((t * p[:, None]).sum(dim=-1)**2).mean(dim=-1)
        reg = (term_1 - term_2) / 2
        return reg.mean() if reduce else reg

    def regularization_2nd_order_linear(self, output, reduce=True):
        """Variance regularization (2nd order) term when the model is linear.
        Fastest implementations since it doesn't rely on pytorch's autograd.
        Equal to E[(W phi - W psi)^T (diag(p) - p p^T) (W phi - W psi)] / 2,
        where W is the weight matrix, phi is the feature, psi is the average
        feature, and p is the softmax probability.
        In this case @output is W phi + bias, but the bias will be subtracted away.
        """
        p = F.softmax(output, dim=-1)
        unreduced_output = self.output_from_features(self._centered_features + self._avg_features[:, None])
        reduced_output = self.output_from_features(self._avg_features)
        centered_output = unreduced_output - reduced_output[:, None]
        term_1 = (centered_output**2 * p[:, None]).sum(dim=-1).mean(dim=-1)
        term_2 = ((centered_output * p[:, None]).sum(dim=-1)**2).mean(dim=-1)
        reg = (term_1 - term_2) / 2
        return reg.mean() if reduce else reg

    def regularization_2nd_order_slow(self, output, reduce=True):
        """Compute regularization term from output, but uses pytorch's 2nd order differentiation.
        Slow implementation, only faster than @regularization_2nd_order_from_loss.
        """
        p = F.softmax(output, dim=-1)
        g, = autograd.grad(output, self._avg_features, grad_outputs=p, create_graph=True)
        term_1 = []
        for i in range(self._centered_features.size(1)):
            gg, = autograd.grad(g, p, grad_outputs=self._centered_features[:, i], create_graph=True)
            term_1.append((gg**2 * p).sum(dim=-1))
        term_1 = torch.stack(term_1, dim=-1).mean(dim=-1)
        term_2 = ((g[:, None] * self._centered_features).view(*self._centered_features.shape[:2], -1).sum(dim=-1)**2).mean(dim=-1)
        reg = (term_1 - term_2) / 2
        return reg.mean() if reduce else reg

    def regularization_2nd_order_from_loss(self, loss, reduce=True):
        """Variance regularization (2nd order) term.
        Computed from loss, using Pytorch's 2nd order differentiation.
        This is much slower but more likely to be correct. Used to check other implementations.
        """
        g, = autograd.grad(loss * self._avg_features.size(0), self._avg_features, create_graph=True)
        reg = []
        for i in range(self._centered_features.size(1)):
            gg, = autograd.grad(g, self._avg_features, grad_outputs=self._centered_features[:, i], create_graph=True)
            reg.append((gg * self._centered_features[:, i]).view(gg.size(0), -1).sum(dim=-1))
        reg = torch.stack(reg, dim=-1).mean(dim=-1) / 2
        return reg.mean() if reduce else reg

    def loss(self, output, target, reduce=True):
        """Cross entropy loss, with optional variance regularization.
        """
        if not self.approx:  # No approximation, loss on all augmented data points
            return self.loss_on_augmented_data(output, target, reduce=reduce)
        loss = self.loss_original(output, target, reduce=reduce)
        if self.regularization:
            return loss + self.regularization_2nd_order(output, reduce=reduce)
        else:
            return loss

    def all_losses(self, x, target, reduce=True):
        """All losses: true loss on augmented data, loss on original image, approximate
        loss with feature averaging (1st order), approximate loss with
        variance regularization and no feature averaging, and approximate
        loss with feature averaging and variance regularization (2nd order).
        Used to compare the effects of different approximations.

        The model's @approx, @feature_avg and @regularization settings are
        temporarily forced to True and restored afterwards, also on error.

        Parameters:
            x: the input of size B (batch size) x T (no. of transforms) x ...
            target: target of size B (batch size)
        Return:
            Tuple (true_loss, loss_original, loss_1st_order, loss_2nd_no_1st,
            loss_2nd_order).
        """
        saved_flags = (self.approx, self.feature_avg, self.regularization)
        self.approx, self.feature_avg, self.regularization = True, True, True
        try:
            output = self(x)
            # Rebuild the per-transform features from what forward() stored.
            features = self._centered_features + self._avg_features[:, None]
            n_transforms = features.size(1)
            unreduced_output = self.output_from_features(combine_transformed_dimension(features))
            unreduced_output = split_transformed_dimension(unreduced_output, n_transforms)
            true_loss = self.loss_on_augmented_data(unreduced_output, target, reduce=reduce)
            loss_original = self.loss_original(unreduced_output[:, 0], target, reduce=reduce)
            loss_1st_order = self.loss_original(output, target, reduce=reduce)
            reg_2nd_order = self.regularization_2nd_order(output, reduce=reduce)
        finally:
            # Previously @regularization was left stuck at True and nothing was
            # restored when an exception interrupted the computation.
            self.approx, self.feature_avg, self.regularization = saved_flags
        loss_2nd_order = loss_1st_order + reg_2nd_order
        loss_2nd_no_1st = loss_original + reg_2nd_order
        return true_loss, loss_original, loss_1st_order, loss_2nd_no_1st, loss_2nd_order
class LeNet(MultinomialLogisticRegression):
    """LeNet with 2 convolution-max pooling layers followed by fully connected
    layers.

    The feature extractor is the chain layer_1 .. layer_4 kept in @layers;
    @fc3 maps the resulting features to the 10 outputs.
    """

    def __init__(self, n_channels=1, size=28):
        """Parameters:
            n_channels: number of channels of the input images.
            size: side length of the input images.
        """
        super().__init__()
        # Side length left after two rounds of (5x5 convolution, 2x2 pooling).
        pooled_side = size // 4 - 3
        self.conv1 = nn.Conv2d(n_channels, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * pooled_side * pooled_side, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
        self.layers = [self.layer_1, self.layer_2, self.layer_3, self.layer_4]

    def layer_1(self, x):
        """First convolution, ReLU and 2x2 max pooling."""
        activated = F.relu(self.conv1(x))
        return F.max_pool2d(activated, 2)

    def layer_2(self, x):
        """Second convolution, ReLU and 2x2 max pooling, flattened per data point."""
        activated = F.relu(self.conv2(x))
        pooled = F.max_pool2d(activated, 2)
        return pooled.view(activated.size(0), -1)

    def layer_3(self, x):
        """First fully connected layer with ReLU."""
        return F.relu(self.fc1(x))

    def layer_4(self, x):
        """Second fully connected layer with ReLU."""
        return F.relu(self.fc2(x))

    def features(self, x):
        """Run @x through every stage in @layers, in order."""
        out = x
        for stage in self.layers:
            out = stage(out)
        return out

    def output_from_features(self, feat):
        """Map the features to the outputs with the last linear layer."""
        return self.fc3(feat)
class RBFLogisticRegressionAug(MultinomialLogisticRegressionAug,
                               RBFLogisticRegression):
    """Logistic regression model with RBF kernel and augmented data.

    Input has size B x T x ..., where B is batch size and T is the number of
    transformations. The original i-th data point comes first among its
    transformed versions, i.e. it is input[i, 0].
    The output layer is linear in the features, so the linear shortcut is used
    for the variance regularization.
    """

    def __init__(self,
                 n_features,
                 n_classes,
                 gamma=1.0,
                 n_components=100,
                 approx=True,
                 feature_avg=True,
                 regularization=False):
        """Parameters:
            n_features, n_classes, gamma, n_components: forwarded to
                RBFLogisticRegression.
            approx, feature_avg, regularization: forwarded to
                MultinomialLogisticRegressionAug.
        """
        # The base model must be set up first; the augmentation mixin only
        # stores its options on the already initialized module.
        RBFLogisticRegression.__init__(
            self, n_features, n_classes, gamma, n_components)
        MultinomialLogisticRegressionAug.__init__(
            self, approx, feature_avg, regularization)
        self.regularization_2nd_order = self.regularization_2nd_order_linear
def split_transformed_dimension(input, n_transforms):
    """Split the minibatch and the transformation dimensions.
    Parameter:
        input: Tensor of shape (B * T) x ..., where B is the batch size and T is
            the number of transformations.
        n_transforms: the number of transformations T.
    Return:
        output: Tensor with the same elements, now of shape B x T x ....
            It shares storage with @input whenever the memory layout allows.
    """
    # reshape instead of view: same result for contiguous tensors, but it also
    # accepts non-contiguous ones (e.g. transposed), where view raises.
    return input.reshape(-1, n_transforms, *input.shape[1:])
35 | cifar10_normalize = transforms.Compose([ 36 | # transforms.Grayscale(num_output_channels=1), 37 | transforms.ToTensor(), 38 | transforms.Normalize((0.5, ), (0.5, )) 39 | ]) 40 | cifar10_train = datasets.CIFAR10( 41 | '../data', train=True, download=True, transform=cifar10_normalize) 42 | cifar10_test = datasets.CIFAR10( 43 | '../data', train=False, download=True, transform=cifar10_normalize) 44 | # For some reason the train labels are lists instead of torch.LongTensor 45 | cifar10_train.train_labels = torch.LongTensor(cifar10_train.train_labels) 46 | cifar10_test.test_labels = torch.LongTensor(cifar10_test.test_labels) 47 | cifar10_train, cifar10_valid = get_train_valid_datasets(cifar10_train) 48 | train_loader = loader_from_dataset(cifar10_train) 49 | valid_loader = loader_from_dataset(cifar10_valid) 50 | test_loader = loader_from_dataset(cifar10_test) 51 | 52 | cifar10_normalize_rotate = transforms.Compose([ 53 | # transforms.Grayscale(num_output_channels=1), 54 | transforms.RandomRotation((-5, 5)), 55 | transforms.ToTensor(), 56 | transforms.Normalize((0.5, ), (0.5, )) 57 | ]) 58 | cifar10_test_rotated = datasets.CIFAR10( 59 | '../data', train=False, download=True, transform=cifar10_normalize_rotate) 60 | test_loader_rotated = loader_from_dataset(cifar10_test_rotated) 61 | 62 | augmentations = [rotation(cifar10_train, cifar10_normalize, angles=range(-5, 6, 1)), 63 | resized_crop(cifar10_train, cifar10_normalize, size=32), 64 | blur(cifar10_train, cifar10_normalize), 65 | rotation_crop_blur(cifar10_train, cifar10_normalize, size=32), 66 | hflip(cifar10_train, cifar10_normalize), 67 | hflip_vflip(cifar10_train, cifar10_normalize), 68 | brightness(cifar10_train, cifar10_normalize), 69 | contrast(cifar10_train, cifar10_normalize)] 70 | 71 | n_channels = 3 72 | size = 32 73 | n_features = n_channels * size * size 74 | n_classes = 10 75 | gamma = 0.003 # gamma hyperparam for RBF kernel exp(-gamma ||x - y||^2). 
def train_basic_models(train_loader, augmented_loader):
    """Train a few simple models with data augmentation / approximation, as a
    sanity check.

    For every seed in ``range(n_trials)`` the 16 model configurations listed
    below are built, each one is trained for ``sgd_n_epochs`` epochs on the
    loader paired with it, and its accuracy on ``test_loader`` and
    ``test_loader_rotated`` is recorded.  Per-seed accuracies are saved under
    ``save_dir`` after each seed, and the full (seed x model) arrays once all
    seeds are done.

    Parameters:
        train_loader: Loader over the original (non-augmented) training set.
        augmented_loader: Loader over an augmented version of the training set.
    """
    import time  # Only needed for the per-model wall-clock printout.

    rbf_kwargs = {'gamma': gamma, 'n_components': n_components}
    lenet_kwargs = {'n_channels': n_channels, 'size': size}

    test_acc = []
    test_acc_rotated = []
    for seed in range(n_trials):
        print(f'Seed: {seed}')
        torch.manual_seed(seed)
        models = [
            # No augmentation
            LinearLogisticRegressionAug(n_features, n_classes),
            RBFLogisticRegressionAug(n_features, n_classes, **rbf_kwargs),
            LeNetAug(**lenet_kwargs),
            # Augmented data, exact objective
            LinearLogisticRegressionAug(n_features, n_classes, approx=False),
            RBFLogisticRegressionAug(n_features, n_classes, approx=False, **rbf_kwargs),
            LeNetAug(approx=False, **lenet_kwargs),
            # Augmented data, 1st order approx
            LinearLogisticRegressionAug(n_features, n_classes, regularization=False),
            RBFLogisticRegressionAug(n_features, n_classes, regularization=False, **rbf_kwargs),
            # NOTE(review): unlike its two siblings this LeNet is built without
            # regularization=False; presumably that is LeNetAug's default --
            # confirm against models.py.
            LeNetAug(**lenet_kwargs),
            # Augmented data, 2nd order no 1st approx
            LinearLogisticRegressionAug(n_features, n_classes, feature_avg=False, regularization=True),
            RBFLogisticRegressionAug(n_features, n_classes, feature_avg=False, regularization=True, **rbf_kwargs),
            LeNetAug(feature_avg=False, regularization=True, **lenet_kwargs),
            # Augmented data, 2nd order approx
            LinearLogisticRegressionAug(n_features, n_classes, regularization=True),
            RBFLogisticRegressionAug(n_features, n_classes, regularization=True, **rbf_kwargs),
            LeNetAug(regularization=True, **lenet_kwargs),
            # Rotated RBF features, trained on the original data
            RBFLogisticRegressionRotated(n_features, n_classes, gamma=gamma, n_components=n_components,
                                         size=size, n_channels=n_channels),
        ]
        # One loader per model, in the same order as `models`.
        loaders = [train_loader] * 3 + [augmented_loader] * 12 + [train_loader]

        test_acc_per_seed = []
        test_acc_rotated_per_seed = []
        for model, loader in zip(models, loaders):
            # The model configured above is the one that gets trained.  (A
            # leftover debugging line used to replace it with a fresh exact-RBF
            # model on every iteration, so the list was never exercised.)
            start = time.time()
            model.to(device)
            optimizer = sgd_opt_from_model(model)
            train_all_epochs(loader, valid_loader, model, optimizer, sgd_n_epochs, verbose=True)
            end = time.time()
            print(end - start)
            correct, total = accuracy(test_loader, model)
            print(f'Test set: Accuracy: {correct}/{total} ({correct/total*100:.4f}%)')
            correct_rotated, total_rotated = accuracy(test_loader_rotated, model)
            print(f'Test set rotated: Accuracy: {correct_rotated}/{total_rotated} ({correct_rotated/total_rotated*100:.4f}%)\n')
            test_acc_per_seed.append(correct / total)
            test_acc_rotated_per_seed.append(correct_rotated / total_rotated)
        # Checkpoint the results of this seed.  np.save appends '.npy' to these
        # names because they do not already end with it.
        np.save(f'{save_dir}/basic_models_cifar10_test_accuracy_last_{seed}.np', np.array(test_acc_per_seed))
        np.save(f'{save_dir}/basic_models_cifar10_test_accuracy_rotated_last_{seed}.np', np.array(test_acc_rotated_per_seed))
        test_acc.append(np.array(test_acc_per_seed))
        test_acc_rotated.append(np.array(test_acc_rotated_per_seed))
    test_acc = np.array(test_acc)
    test_acc_rotated = np.array(test_acc_rotated)
    np.save(f'{save_dir}/basic_models_cifar10_test_accuracy_last', test_acc)
    np.save(f'{save_dir}/basic_models_cifar10_test_accuracy_rotated_last', test_acc_rotated)
162 | """ 163 | # for model_name in ['kernel', 'lenet']: 164 | for model_name in ['lenet']: 165 | for augmentation in augmentations: 166 | for seed in range(n_trials): 167 | print(f'Seed: {seed}') 168 | torch.manual_seed(seed) 169 | model = model_factories[model_name]().to(device) 170 | optimizer = sgd_opt_from_model(model) 171 | loader = loader_from_dataset(augmentation.dataset) 172 | model.train() 173 | losses = [] 174 | losses.append(all_losses(loader, model).mean(axis=0)) 175 | train_loss, train_acc, valid_acc = [], [], [] 176 | for epoch in range(sgd_n_epochs): 177 | train_loss_epoch, train_acc_epoch = train(loader, model, optimizer) 178 | train_loss += train_loss_epoch 179 | train_acc += train_acc_epoch 180 | print(f'Train Epoch: {epoch}') 181 | correct, total = accuracy(valid_loader, model) 182 | valid_acc.append(correct / total) 183 | print( 184 | f'Validation set: Accuracy: {correct}/{total} ({correct/total*100:.4f}%)' 185 | ) 186 | losses.append(np.array(all_losses(loader, model)).mean(axis=0)) 187 | train_loss, train_acc, valid_acc = np.array(train_loss), np.array(train_acc), np.array(valid_acc) 188 | np.savez(f'{save_dir}/train_valid_acc_{model_name}_{augmentation.name}_{seed}.npz', 189 | train_loss=train_loss, train_acc=train_acc, valid_acc=valid_acc) 190 | losses = np.array(losses).T 191 | np.save(f'{save_dir}/all_losses_{model_name}_{augmentation.name}_{seed}.npy', losses) 192 | 193 | 194 | def accuracy_on_true_objective(augmentations): 195 | """Measure the accuracy when trained on true augmented objective. 
def _approximation_variant(model, *, feature_avg, regularization):
    """Deep-copy `model` and switch the copy to an approximate objective.

    The source model is left untouched; the copy has ``approx`` set to True and
    ``feature_avg`` / ``regularization`` set to the given flags.
    """
    variant = copy.deepcopy(model)
    variant.approx = True
    variant.feature_avg = feature_avg
    variant.regularization = regularization
    return variant


def exact_to_og_model(model):
    """Convert model training on exact augmented objective to model training on
    original data.
    """
    return _approximation_variant(model, feature_avg=False, regularization=False)


def exact_to_1st_order_model(model):
    """Convert model training on exact augmented objective to model training on
    1st order approximation.
    """
    return _approximation_variant(model, feature_avg=True, regularization=False)


def exact_to_2nd_order_no_1st_model(model):
    """Convert model training on exact augmented objective to model training on
    2nd order approximation without feature averaging (1st order approx).
    """
    return _approximation_variant(model, feature_avg=False, regularization=True)


def exact_to_2nd_order_model(model):
    """Convert model training on exact augmented objective to model training on
    2nd order approximation.
    """
    return _approximation_variant(model, feature_avg=True, regularization=True)


def exact_to_2nd_order_model_layer_avg(model, layer_to_avg=3):
    """Convert LeNet model training on exact augmented objective to model
    training on 2nd order approximation, but approximation is done at different
    layers.

    Parameters:
        model: The model trained on the exact objective.  It is not modified.
        layer_to_avg: Index of the layer at which features are averaged.
    Return:
        A deep copy of `model` configured for the 2nd order approximation with
        averaging at `layer_to_avg`.
    """
    model_2nd = _approximation_variant(model, feature_avg=True, regularization=True)
    model_2nd.layer_to_avg = layer_to_avg
    # Can't use the regularization function specialized to linear model unless
    # averaging at layer 4.  The override must be applied to the copy (using
    # the copy's own method): patching `model` instead would both mutate the
    # exact model and leave the returned variant unpatched.
    if layer_to_avg != 4:
        model_2nd.regularization_2nd_order = model_2nd.regularization_2nd_order_general
    return model_2nd
276 | """ 277 | model_variants = {'kernel': lambda model: [model, exact_to_og_model(model), exact_to_1st_order_model(model), 278 | exact_to_2nd_order_no_1st_model(model), exact_to_2nd_order_model(model)], 279 | 'lenet': lambda model: [model, exact_to_og_model(model), exact_to_1st_order_model(model), 280 | exact_to_2nd_order_no_1st_model(model)] + 281 | [exact_to_2nd_order_model_layer_avg(model, layer_to_avg) for layer_to_avg in [4, 3, 2, 1, 0]]} 282 | 283 | for model_name in ['kernel', 'lenet']: 284 | # for model_name in ['lenet']: 285 | for augmentation in augmentations: 286 | for seed in range(5, 5 + n_trials): 287 | print(f'Seed: {seed}') 288 | torch.manual_seed(n_trials + seed) 289 | loader = loader_from_dataset(augmentation.dataset) 290 | model = model_factories[model_name]() 291 | models = model_variants[model_name](model) 292 | for model in models: 293 | model.to(device) 294 | optimizers = [sgd_opt_from_model(model) for model in models] 295 | for model in models: 296 | model.train() 297 | train_agreement, valid_agreement, valid_acc, valid_kl = [], [], [], [] 298 | for epoch in range(sgd_n_epochs): 299 | print(f'Train Epoch: {epoch}') 300 | train_agreement_epoch = train_models_compute_agreement(loader, models, optimizers) 301 | train_agreement.append(np.array(train_agreement_epoch).mean(axis=0)) 302 | # Agreement and KL on validation set 303 | valid_agreement_epoch, valid_kl_epoch, valid_acc_epoch = agreement_kl_accuracy(valid_loader, models) 304 | valid_agreement.append(np.array(valid_agreement_epoch).mean(axis=0)) 305 | valid_acc.append(np.array(valid_acc_epoch).mean(axis=0)) 306 | valid_kl.append(np.array(valid_kl_epoch).mean(axis=0)) 307 | train_agreement = np.array(train_agreement).T 308 | valid_agreement = np.array(valid_agreement).T 309 | valid_acc = np.array(valid_acc).T 310 | valid_kl = np.array(valid_kl).T 311 | np.savez(f'{save_dir}/train_valid_agreement_kl_{model_name}_{augmentation.name}_{seed}.npz', 312 | train_agreement=train_agreement, 
def find_gamma_by_alignment(train_loader, gamma_vals=(0.03, 0.01, 0.003, 0.001, 0.0003, 0.0001)):
    """Example use of kernel target alignment: pick the hyperparameter gamma of
    the RBF kernel exp(-gamma ||x-y||^2).

    For every candidate in `gamma_vals`, a fresh exact-objective RBF model is
    built and the kernel target alignment of its random features on
    `train_loader` is printed.  The value of gamma giving the highest alignment
    is likely the best gamma.
    """
    for candidate in gamma_vals:
        rbf_model = RBFLogisticRegressionAug(n_features, n_classes, gamma=candidate,
                                             n_components=n_components, approx=False)
        rbf_model = rbf_model.to(device)
        print(kernel_target_alignment(train_loader, rbf_model))
    # Best gamma is 0.003


def alignment_comparison(augmentations):
    """Compute the kernel target alignment of different augmentations.

    Two arrays are written to `save_dir`: ``kernel_alignment.npy`` (from
    `kernel_target_alignment_augmented`) and ``kernel_alignment_no_avg.npy``
    (from `kernel_target_alignment_augmented_no_avg`).
    """
    make_kernel_model = model_factories['kernel']

    # NOTE(review): only the first augmentation is used in this pass while the
    # second pass below covers all of them -- confirm the slice is intended.
    with_avg = []
    for aug in augmentations[:1]:
        print(aug.name)
        aug_loader = loader_from_dataset(aug.dataset)
        kernel_model = make_kernel_model().to(device)
        with_avg.append(
            kernel_target_alignment_augmented(aug_loader, kernel_model, n_passes_through_data=50))
    print(with_avg)
    with_avg = np.array(with_avg)
    # Column 1 is averaged into a single "no transform" value that is stored
    # first, followed by the per-augmentation values of column 0.
    no_transform = with_avg[:, 1].mean()
    np.save(f'{save_dir}/kernel_alignment.npy', np.array([no_transform] + list(with_avg[:, 0])))

    without_avg = []
    for aug in augmentations:
        print(aug.name)
        aug_loader = loader_from_dataset(aug.dataset)
        kernel_model = make_kernel_model().to(device)
        without_avg.append(
            kernel_target_alignment_augmented_no_avg(aug_loader, kernel_model, n_passes_through_data=10))
    np.save(f'{save_dir}/kernel_alignment_no_avg.npy', np.array(without_avg))
def main():
    """Create the output directory, then run the currently enabled experiment.

    Experiments are switched on and off by (un)commenting the calls below.
    """
    output_dir = pathlib.Path(save_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    # train_basic_models(train_loader, loader_from_dataset(augmentations[0].dataset))
    # objective_difference(augmentations[:4])
    rotation_only = augmentations[:1]
    accuracy_on_true_objective(rotation_only)
    # agreement_kl_difference(augmentations[:1])
    # find_gamma_by_alignment(train_loader)
    # alignment_comparison(augmentations)
    # alignment_lenet(augmentations)


if __name__ == '__main__':
    main()
def _load_agreement_metrics(model_name, transform_name):
    """Load the per-seed agreement/KL/accuracy arrays and stack them by seed.

    Returns a dict mapping each metric name to an array whose first axis
    indexes the seed.
    """
    keys = ('train_agreement', 'valid_agreement', 'valid_kl', 'valid_acc')
    per_seed = {key: [] for key in keys}
    for seed in range(n_trials):
        # np.load on an .npz returns a lazy archive holding an open file
        # handle; read what we need and close it straight away.
        with np.load(f'saved/train_valid_agreement_kl_{model_name}_{transform_name}_{seed}.npz') as saved:
            for key in keys:
                per_seed[key].append(saved[key])
    return {key: np.array(values) for key, values in per_seed.items()}


def _save_errorbar_figure(curves, ylabel, out_path, zero_line=False):
    """Plot mean +/- std over seeds against the epoch and save the figure.

    Parameters:
        curves: Sequence of (values, label) pairs; `values` has one row per
            seed and one column per epoch.
        ylabel: Label of the y axis.
        out_path: Where the figure is written.
        zero_line: Whether to draw a black horizontal line at y = 0.
    """
    epochs = range(1, sgd_n_epochs + 1)
    plt.clf()
    for values, label in curves:
        plt.errorbar(epochs, values.mean(axis=0), values.std(axis=0), fmt='o-', capsize=5, label=label)
    plt.xlabel('Epoch')
    plt.ylabel(ylabel)
    plt.legend()
    if zero_line:
        plt.axhline(color='k')
    plt.savefig(out_path, bbox_inches='tight')


def plot_agreement_kl():
    """Plot training/valid agreements and KL divergence

    For every (model, transform) pair six figures are written to ``figs/``:
    training agreement, validation agreement, validation KL, validation
    accuracy and validation accuracy difference to the exact model.
    """
    # Index of each approximate model along the model axis of the saved
    # arrays; index 0 is the model trained on the exact augmented objective.
    approximations = [(1, 'Original image'),
                      (2, '1st-order'),
                      (3, '2nd-order w/o 1st-order'),
                      (4, '2nd-order')]

    def curves_of(metric):
        return [(metric[:, index], label) for index, label in approximations]

    for model_name in ['kernel', 'lenet']:
        for transform_name in ['rotation', 'crop', 'blur', 'rotation_crop_blur']:
            metrics = _load_agreement_metrics(model_name, transform_name)
            valid_acc = metrics['valid_acc']
            exact_acc = valid_acc[:, 0]

            _save_errorbar_figure(
                curves_of(metrics['train_agreement']), 'Prediction agreement',
                f'figs/prediction_agreement_training_{model_name}_{transform_name}.pdf')
            _save_errorbar_figure(
                curves_of(metrics['valid_agreement']), 'Prediction agreement',
                f'figs/prediction_agreement_valid_{model_name}_{transform_name}.pdf')
            _save_errorbar_figure(
                curves_of(metrics['valid_kl']), 'Prediction KL',
                f'figs/kl_valid_{model_name}_{transform_name}.pdf', zero_line=True)
            # The exact model is drawn last so that colors of the approximate
            # models match the other figures.
            _save_errorbar_figure(
                curves_of(valid_acc) + [(exact_acc, 'Exact (augmented images)')], 'Accuracy',
                f'figs/accuracy_valid_{model_name}_{transform_name}.pdf')
            _save_errorbar_figure(
                [(values - exact_acc, label) for values, label in curves_of(valid_acc)],
                'Accuracy difference',
                f'figs/accuracy_difference_valid_{model_name}_{transform_name}.pdf', zero_line=True)
label='Averaged at 0th layer') 119 | plt.xlabel('Epoch') 120 | plt.ylabel('Prediction agreement') 121 | plt.legend() 122 | # plt.axhline(color='k') 123 | plt.savefig(f'figs/prediction_agreement_training_{model_name}_{transform_name}_layers.pdf', bbox_inches='tight') 124 | 125 | plt.clf() 126 | plt.errorbar(range(1, sgd_n_epochs + 1), valid_agreement[:, 1].mean(axis=0), valid_agreement[:, 1].std(axis=0), fmt='o-', capsize=5, label='Original image') 127 | plt.errorbar(range(1, sgd_n_epochs + 1), valid_agreement[:, 4].mean(axis=0), valid_agreement[:, 4].std(axis=0), fmt='o-', capsize=5, label='Averaged at 4th layer') 128 | plt.errorbar(range(1, sgd_n_epochs + 1), valid_agreement[:, 5].mean(axis=0), valid_agreement[:, 5].std(axis=0), fmt='o-', capsize=5, label='Averaged at 3rd layer') 129 | plt.errorbar(range(1, sgd_n_epochs + 1), valid_agreement[:, 6].mean(axis=0), valid_agreement[:, 6].std(axis=0), fmt='o-', capsize=5, label='Averaged at 2nd layer') 130 | plt.errorbar(range(1, sgd_n_epochs + 1), valid_agreement[:, 7].mean(axis=0), valid_agreement[:, 7].std(axis=0), fmt='o-', capsize=5, label='Averaged at 1st layer') 131 | plt.errorbar(range(1, sgd_n_epochs + 1), valid_agreement[:, 8].mean(axis=0), valid_agreement[:, 8].std(axis=0), fmt='o-', capsize=5, label='Averaged at 0th layer') 132 | plt.xlabel('Epoch') 133 | plt.ylabel('Prediction agreement') 134 | plt.legend() 135 | # plt.axhline(color='k') 136 | plt.savefig(f'figs/prediction_agreement_valid_{model_name}_{transform_name}_layers.pdf', bbox_inches='tight') 137 | plt.clf() 138 | plt.errorbar(range(1, sgd_n_epochs + 1), valid_kl[:, 1].mean(axis=0), valid_kl[:, 1].std(axis=0), fmt='o-', capsize=5, label='Original image') 139 | plt.errorbar(range(1, sgd_n_epochs + 1), valid_kl[:, 4].mean(axis=0), valid_kl[:, 4].std(axis=0), fmt='o-', capsize=5, label='Averaged at 4th layer') 140 | plt.errorbar(range(1, sgd_n_epochs + 1), valid_kl[:, 5].mean(axis=0), valid_kl[:, 5].std(axis=0), fmt='o-', capsize=5, label='Averaged at 
3rd layer') 141 | plt.errorbar(range(1, sgd_n_epochs + 1), valid_kl[:, 6].mean(axis=0), valid_kl[:, 6].std(axis=0), fmt='o-', capsize=5, label='Averaged at 2nd layer') 142 | plt.errorbar(range(1, sgd_n_epochs + 1), valid_kl[:, 7].mean(axis=0), valid_kl[:, 7].std(axis=0), fmt='o-', capsize=5, label='Averaged at 1st layer') 143 | plt.errorbar(range(1, sgd_n_epochs + 1), valid_kl[:, 8].mean(axis=0), valid_kl[:, 8].std(axis=0), fmt='o-', capsize=5, label='Averaged at 0th layer') 144 | plt.xlabel('Epoch') 145 | plt.ylabel('Prediction KL') 146 | plt.legend() 147 | plt.axhline(color='k') 148 | plt.savefig(f'figs/kl_valid_{model_name}_{transform_name}_layers.pdf', bbox_inches='tight') 149 | plt.clf() 150 | plt.errorbar(range(1, sgd_n_epochs + 1), valid_acc[:, 1].mean(axis=0), valid_acc[:, 1].std(axis=0), fmt='o-', capsize=5, label='Original image') 151 | plt.errorbar(range(1, sgd_n_epochs + 1), valid_acc[:, 4].mean(axis=0), valid_acc[:, 4].std(axis=0), fmt='o-', capsize=5, label='Averaged at 4th layer') 152 | plt.errorbar(range(1, sgd_n_epochs + 1), valid_acc[:, 5].mean(axis=0), valid_acc[:, 5].std(axis=0), fmt='o-', capsize=5, label='Averaged at 3rd layer') 153 | plt.errorbar(range(1, sgd_n_epochs + 1), valid_acc[:, 6].mean(axis=0), valid_acc[:, 6].std(axis=0), fmt='o-', capsize=5, label='Averaged at 2nd layer') 154 | plt.errorbar(range(1, sgd_n_epochs + 1), valid_acc[:, 7].mean(axis=0), valid_acc[:, 7].std(axis=0), fmt='o-', capsize=5, label='Averaged at 1st layer') 155 | plt.errorbar(range(1, sgd_n_epochs + 1), valid_acc[:, 8].mean(axis=0), valid_acc[:, 8].std(axis=0), fmt='o-', capsize=5, label='Averaged at 0th layer') 156 | plt.errorbar(range(1, sgd_n_epochs + 1), valid_acc[:, 0].mean(axis=0), valid_acc[:, 0].std(axis=0), fmt='o-', capsize=5, label='Exact (augmented images)') 157 | plt.xlabel('Epoch') 158 | plt.ylabel('Accuracy') 159 | plt.legend() 160 | # plt.axhline(color='k') 161 | plt.savefig(f'figs/accuracy_valid_{model_name}_{transform_name}_layers.pdf', 
def plot_accuracy_vs_computation():
    """Plot computational savings when doing averaging at earlier layers of LeNet

    Two figures are written to ``figs/``: final validation accuracy against
    the fraction of computation (relative to training on the exact augmented
    objective), and the same curve rescaled to the relative accuracy gain.
    """
    # Time spent in each of the five LeNet stages (conv1_maxpool1,
    # conv2_maxpool2, fc1, fc2, fc3).
    computation_time = np.array([123, 94, 41, 40, 30])  # Measured with iPython's %timeit
    ratio = computation_time / computation_time.sum()
    n_transforms = 16
    exact = n_transforms
    # avg[k]: cost of averaging at layer k relative to the exact objective.
    # The first k stages run once per transformed copy, the remaining ones
    # once on the averaged features.  The last entry is the exact objective.
    avg = np.empty(6)
    avg[5] = 1.0
    for k in range(5):
        avg[k] = (ratio[:k].sum() * n_transforms + ratio[k:].sum()) / exact

    model_name = 'lenet'
    transform_name = 'rotation'
    valid_acc = []
    for seed in range(n_trials):
        # Close each .npz archive as soon as the array has been read.
        with np.load(f'saved/train_valid_agreement_kl_{model_name}_{transform_name}_{seed}.npz') as saved:
            valid_acc.append(saved['valid_acc'])
    valid_acc = np.array(valid_acc)

    # Final-epoch accuracy of the models averaged at layers 0..4, followed by
    # the exact model, in the same order as `avg`.
    final_acc = valid_acc[:, [8, 7, 6, 5, 4, 0], -1]
    acc_mean = final_acc.mean(axis=0)
    acc_std = final_acc.std(axis=0)

    plt.clf()
    plt.errorbar(avg, acc_mean, acc_std, fmt='o-', capsize=5)
    plt.ylabel('Accuracy')
    plt.xlabel('Computation fraction')
    plt.savefig(f'figs/accuracy_vs_computation_{model_name}_{transform_name}.pdf', bbox_inches='tight')

    # Plot relative accuracy gain
    lowest = valid_acc[:, 8, -1].mean(axis=0)
    highest = valid_acc[:, 0, -1].mean(axis=0)
    plt.figure()
    plt.errorbar(avg, (acc_mean - lowest) / (highest - lowest), acc_std / (highest - lowest),
                 fmt='o-', capsize=10, markersize=10, linewidth=2)
    plt.ylabel('Relative accuracy gain', fontsize=16)
    plt.xlabel('Computation fraction', fontsize=16)
    plt.tick_params(axis='both', which='major', labelsize=12)
    plt.savefig(f'figs/accuracy_vs_computation_relative_{model_name}_{transform_name}.pdf', bbox_inches='tight')
    plt.close()
valid_acc_per_model = [] 221 | # Accuracy on no transform 222 | saved_arrays = [np.load(f'saved/train_valid_agreement_kl_{model_name}_blur_{seed}.npz') 223 | for seed in range(n_trials)] 224 | valid_acc_per_model.append(np.array([saved['valid_acc'] for saved in saved_arrays])[:, 1, -1]) 225 | for transform_name in ['rotation', 'crop', 'blur', 'rotation_crop_blur', 'hflip', 'hflip_vflip', 'brightness', 'contrast']: 226 | saved_arrays = [np.load(f'saved/train_valid_acc_{model_name}_{transform_name}_{seed}.npz') 227 | for seed in range(n_trials)] 228 | valid_acc_per_model.append(np.array([saved['valid_acc'] for saved in saved_arrays])[:, -1]) 229 | # print(valid_acc.mean(axis=0)[-1], valid_acc.std(axis=0)[-1]) 230 | valid_acc.append(valid_acc_per_model) 231 | valid_acc = np.array(valid_acc) 232 | kernel_alignment = np.load('saved/kernel_alignment.npy') 233 | 234 | plt.clf() 235 | plt.errorbar(kernel_alignment, valid_acc[0].mean(axis=-1), valid_acc[0].std(axis=-1), fmt='o', capsize=5) 236 | plt.axhline(valid_acc[0, 0].mean(axis=-1), color='k') 237 | plt.axvline(kernel_alignment[0], color='k') 238 | plt.errorbar(kernel_alignment[0], valid_acc[0, 0].mean(axis=-1), valid_acc[0, 0].std(axis=-1), fmt='o', capsize=5) 239 | plt.ylabel('Accuracy') 240 | plt.xlabel('Kernel target alignment') 241 | plt.savefig(f'figs/accuracy_vs_alignment_kernel.pdf', bbox_inches='tight') 242 | plt.clf() 243 | plt.errorbar(kernel_alignment, valid_acc[1].mean(axis=-1), valid_acc[1].std(axis=-1), fmt='o', capsize=5) 244 | plt.axhline(valid_acc[1, 0].mean(axis=-1), color='k') 245 | plt.axvline(kernel_alignment[0], color='k') 246 | plt.errorbar(kernel_alignment[0], valid_acc[1, 0].mean(axis=-1), valid_acc[1, 0].std(axis=-1), fmt='o', capsize=5) 247 | plt.ylabel('Accuracy') 248 | plt.xlabel('Kernel target alignment') 249 | plt.savefig(f'figs/accuracy_vs_alignment_lenet.pdf', bbox_inches='tight') 250 | 251 | plt.clf() 252 | sns.set_style('white') 253 | plt.figure(figsize=(10, 5)) 254 | ax = 
plt.subplot(1, 2, 1) 255 | ax.errorbar(kernel_alignment[0], valid_acc[0, 0].mean(axis=-1), valid_acc[0, 0].std(axis=-1), fmt='x', color='r', capsize=5) 256 | ax.errorbar(kernel_alignment[1], valid_acc[0, 1].mean(axis=-1), valid_acc[0, 1].std(axis=-1), fmt='s', color='b', capsize=5) 257 | ax.errorbar(kernel_alignment[2], valid_acc[0, 2].mean(axis=-1), valid_acc[0, 2].std(axis=-1), fmt='s', color='g', capsize=5) 258 | ax.errorbar(kernel_alignment[3], valid_acc[0, 3].mean(axis=-1), valid_acc[0, 3].std(axis=-1), fmt='o', color='b', capsize=5) 259 | ax.errorbar(kernel_alignment[4], valid_acc[0, 4].mean(axis=-1), valid_acc[0, 4].std(axis=-1), fmt='s', color='tab:orange', capsize=5) 260 | ax.errorbar(kernel_alignment[5], valid_acc[0, 5].mean(axis=-1), valid_acc[0, 5].std(axis=-1), fmt='D', color='g', capsize=5) 261 | ax.errorbar(kernel_alignment[6], valid_acc[0, 6].mean(axis=-1), valid_acc[0, 6].std(axis=-1), fmt='o', color='g', capsize=5) 262 | ax.errorbar(kernel_alignment[7], valid_acc[0, 7].mean(axis=-1), valid_acc[0, 7].std(axis=-1), fmt='D', color='b', capsize=5) 263 | ax.errorbar(kernel_alignment[8], valid_acc[0, 8].mean(axis=-1), valid_acc[0, 8].std(axis=-1), fmt='D', color='m', capsize=5) 264 | ax.axhline(valid_acc[0, 0].mean(axis=-1), color='k') 265 | ax.axvline(kernel_alignment[0], color='k') 266 | ax.set_yticks([0.94, 0.96, 0.98]) 267 | ax.tick_params(axis='both', which='major', labelsize=12) 268 | ax.set_ylabel('Accuracy', fontsize=16) 269 | ax.set_title('RBF Kernel', fontsize=16) 270 | ax = plt.subplot(1, 2, 2) 271 | ax.errorbar(kernel_alignment[0], valid_acc[1, 0].mean(axis=-1), valid_acc[1, 0].std(axis=-1), fmt='x', color='r', capsize=5, label='original') 272 | ax.errorbar(kernel_alignment[1], valid_acc[1, 1].mean(axis=-1), valid_acc[1, 1].std(axis=-1), fmt='s', color='b', capsize=5, label='rotation') 273 | ax.errorbar(kernel_alignment[2], valid_acc[1, 2].mean(axis=-1), valid_acc[1, 2].std(axis=-1), fmt='s', color='g', capsize=5, label='crop') 274 | 
ax.errorbar(kernel_alignment[3], valid_acc[1, 3].mean(axis=-1), valid_acc[1, 3].std(axis=-1), fmt='o', color='b', capsize=5, label='blur') 275 | ax.errorbar(kernel_alignment[4], valid_acc[1, 4].mean(axis=-1), valid_acc[1, 4].std(axis=-1), fmt='s', color='tab:orange', capsize=5, label='rotation, crop, blur') 276 | ax.errorbar(kernel_alignment[5], valid_acc[1, 5].mean(axis=-1), valid_acc[1, 5].std(axis=-1), fmt='D', color='g', capsize=5, label='h. flip') 277 | ax.errorbar(kernel_alignment[6], valid_acc[1, 6].mean(axis=-1), valid_acc[1, 6].std(axis=-1), fmt='o', color='g', capsize=5, label='h. flip, v. flip') 278 | ax.errorbar(kernel_alignment[7], valid_acc[1, 7].mean(axis=-1), valid_acc[1, 7].std(axis=-1), fmt='D', color='b', capsize=5, label='brightness') 279 | ax.errorbar(kernel_alignment[8], valid_acc[1, 8].mean(axis=-1), valid_acc[1, 8].std(axis=-1), fmt='D', color='m', capsize=5, label='contrast') 280 | ax.axhline(valid_acc[1, 0].mean(axis=-1), color='k') 281 | ax.axvline(kernel_alignment[0], color='k') 282 | ax.set_yticks([0.97, 0.98, 0.99]) 283 | ax.tick_params(axis='both', which='major', labelsize=12) 284 | ax.set_ylabel('Accuracy', fontsize=16) 285 | ax.set_title('LeNet', fontsize=16) 286 | # sns.despine() 287 | # labels = ['original', 'rotation', 'crop', 'blur', 'rot., crop, blur', 'h. flip', 'h. flip, v. 
flip', 'brightness', 'contrast'] 288 | # plt.legend(labels, loc='upper center', bbox_transform=plt.gcf().transFigure, bbox_to_anchor=(0,0,1,1), ncol=3, fontsize=14) 289 | plt.legend(loc='upper center', bbox_transform=plt.gcf().transFigure, bbox_to_anchor=(0,0.07,1,1), ncol=3, fontsize=16, frameon=True, edgecolor='k') 290 | plt.tight_layout() 291 | plt.subplots_adjust(wspace=0.4, top=0.75, bottom=0.1) 292 | plt.suptitle('Kernel target alignment', x=0.5, y=0.05, fontsize=16) 293 | # ax.set_ylabel('Accuracy') 294 | # plt.xlabel('Kernel target alignment') 295 | plt.savefig(f'figs/accuracy_vs_alignment.pdf', bbox_inches='tight') 296 | 297 | def main(): 298 | pathlib.Path('figs').mkdir(parents=True, exist_ok=True) 299 | plot_objective_difference() 300 | plot_agreement_kl() 301 | plot_agreement_kl_avg_at_layers() 302 | plot_accuracy_vs_computation() 303 | plot_accuracy_vs_kernel_alignment() 304 | 305 | if __name__ == '__main__': 306 | main() 307 | --------------------------------------------------------------------------------