├── .gitignore
├── 01_extremeSB_slab_data.ipynb
├── 02_extremeSB_mnistcifar_data.ipynb
├── 03_suboptimal_generalization.ipynb
├── 04_effect_of_ensembles.ipynb
├── 05_effect_of_adversarial_training.ipynb
├── 06_uaps.ipynb
├── README.md
├── cifar10_models
│   ├── .gitignore
│   ├── LICENSE
│   ├── README.md
│   ├── cifar10_download.py
│   ├── cifar10_models
│   │   ├── __init__.py
│   │   ├── densenet.py
│   │   ├── googlenet.py
│   │   ├── inception.py
│   │   ├── mobilenetv2.py
│   │   ├── resnet.py
│   │   ├── resnet_orig.py
│   │   └── vgg.py
│   ├── cifar10_module.py
│   ├── cifar10_test.py
│   └── cifar10_train.py
├── imports.py
├── requirements.txt
└── scripts
    ├── data_utils.py
    ├── ensemble.py
    ├── gendata.py
    ├── gpu_utils.py
    ├── lms_utils.py
    ├── mnistcifar_utils.py
    ├── ptb_utils.py
    ├── synth_models.py
    └── utils.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | data/*.pkl
3 | data/
4 | .DS_Store
5 | data/fixed_data
6 | */*/nolin_mnist.pkl
7 | data/*mnist*
8 | .idea/
9 | data/pytorch_datasets
10 | data/*/*mnist*
11 | datasets/cifar*
12 | datasets/
13 | models/
14 | data/batch_science
15 | data/demogen
16 | data/df_*
17 | __pycache__/
18 | *.py[cod]
19 | *$py.class
20 | 
21 | # C extensions
22 | *.so
23 | 
24 | # Distribution / packaging
25 | .Python
26 | build/
27 | develop-eggs/
28 | dist/
29 | downloads/
30 | eggs/
31 | .eggs/
32 | lib/
33 | lib64/
34 | parts/
35 | sdist/
36 | var/
37 | wheels/
38 | *.egg-info/
39 | .installed.cfg
40 | *.egg
41 | MANIFEST
42 | 
43 | # PyInstaller
44 | # Usually these files are written by a python script from a template
45 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
46 | *.manifest
47 | *.spec
48 | 
49 | # Installer logs
50 | pip-log.txt
51 | pip-delete-this-directory.txt
52 | 
53 | # Unit test / coverage reports
54 | htmlcov/
55 | .tox/
56 | .coverage
57 | .coverage.*
58 | .cache
59 | nosetests.xml
60 | coverage.xml
61 | *.cover
62 | .hypothesis/
63 | .pytest_cache/
64 | 
65 | # Translations
66 | *.mo
67 | *.pot
68 | 
69 | # Django stuff:
70 | *.log
71 | local_settings.py
72 | db.sqlite3
73 | 
74 | # Flask stuff:
75 | instance/
76 | .webassets-cache
77 | 
78 | # Scrapy stuff:
79 | .scrapy
80 | 
81 | # Sphinx documentation
82 | docs/_build/
83 | 
84 | # PyBuilder
85 | target/
86 | 
87 | # Jupyter Notebook
88 | .ipynb_checkpoints
89 | 
90 | # pyenv
91 | .python-version
92 | 
93 | # celery beat schedule file
94 | celerybeat-schedule
95 | 
96 | # SageMath parsed files
97 | *.sage.py
98 | 
99 | # Environments
100 | .env
101 | .venv
102 | env/
103 | venv/
104 | ENV/
105 | env.bak/
106 | venv.bak/
107 | 
108 | # Spyder project settings
109 | .spyderproject
110 | .spyproject
111 | 
112 | # Rope project settings
113 | .ropeproject
114 | 
115 | # mkdocs documentation
116 | /site
117 | 
118 | # mypy
119 | .mypy_cache/

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## The Pitfalls of Simplicity Bias in Neural Networks
2 | 
3 | ### Summary
4 | 
5 | This repository consists of code primitives and Jupyter notebooks that can be used to replicate and extend the findings presented in the paper "The Pitfalls of Simplicity Bias in Neural Networks" ([link](https://arxiv.org/abs/2006.07710)). In addition to the code (in `scripts/`) to generate the proposed datasets, we provide six Jupyter notebooks:
6 | 
7 | 1. ```01_extremeSB_slab_data.ipynb``` shows the simplicity bias of fully-connected networks trained on synthetic slab-structured datasets.
8 | 2. ```02_extremeSB_mnistcifar_data.ipynb``` highlights the simplicity bias of commonly used convolutional neural networks (CNNs) on the concatenated MNIST-CIFAR dataset.
9 | 3. ```03_suboptimal_generalization.ipynb``` analyzes the effect of extreme simplicity bias on standard generalization.
10 | 4. ```04_effect_of_ensembles.ipynb``` studies the effectiveness of ensembles of independently trained models in mitigating simplicity bias and its pitfalls.
11 | 5. ```05_effect_of_adversarial_training.ipynb``` evaluates the effectiveness of adversarial training in mitigating simplicity bias.
12 | 6. ```06_uaps.ipynb``` demonstrates how extreme simplicity bias can lead to small-norm, data-agnostic "universal" adversarial perturbations that nullify the performance of SGD-trained neural networks.
13 | 
14 | 
15 | Please check out our [paper](https://arxiv.org/abs/2006.07710) or [poster](http://harshay.me/pdf/poster_neurips20_simplicitybias.pdf) for more details.
16 | 
17 | ### Setup
18 | 
19 | Our code uses Python 3.7.3, PyTorch 1.1.0, Torchvision 0.3.0, Ubuntu 18.04.2 LTS and the packages listed in `requirements.txt`.
20 | 
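As a quick illustration of the kind of probe used in the MNIST-CIFAR notebooks, the sketch below estimates how strongly a trained model relies on the "simple" half of each input by randomizing that half and measuring the accuracy drop. This is a minimal sketch rather than code from this repository: the model, the data loader, and the MNIST-on-top image layout are assumptions, and the notebooks above implement the actual experiments.

```python
import torch

def accuracy_with_randomized_simple_half(model, loader, device='cpu'):
    """Accuracy on clean inputs vs. inputs whose (assumed) MNIST half --
    here, the top half of each image -- is swapped with another example's.
    A large drop suggests the model leans heavily on the simple feature."""
    model = model.to(device).eval()
    clean, randomized, total = 0, 0, 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            h = x.shape[2] // 2  # boundary between the two image halves
            perm = torch.randperm(x.size(0), device=device)
            x_rand = x.clone()
            x_rand[:, :, :h, :] = x[perm, :, :h, :]  # shuffle top halves within the batch
            clean += (model(x).argmax(1) == y).sum().item()
            randomized += (model(x_rand).argmax(1) == y).sum().item()
            total += y.size(0)
    return clean / total, randomized / total
```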
21 | ---
22 | 
23 | If you find this project useful in your research, please consider citing the following publication:
24 | 
25 | ```
26 | @article{shah2020pitfalls,
27 |   title={The Pitfalls of Simplicity Bias in Neural Networks},
28 |   author={Shah, Harshay and Tamuly, Kaustav and Raghunathan, Aditi and Jain, Prateek and Netrapalli, Praneeth},
29 |   journal={Advances in Neural Information Processing Systems},
30 |   volume={33},
31 |   year={2020}
32 | }
33 | ```
34 | 

--------------------------------------------------------------------------------
/cifar10_models/.gitignore:
--------------------------------------------------------------------------------
1 | *.pt
2 | *.ckpt
3 | __pycache__/
4 | */__pycache__/
5 | .ipynb_checkpoints/
6 | */.ipynb_checkpoints/

--------------------------------------------------------------------------------
/cifar10_models/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2019 Huy Phan
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 

--------------------------------------------------------------------------------
/cifar10_models/README.md:
--------------------------------------------------------------------------------
1 | # PyTorch models trained on the CIFAR-10 dataset
2 | - I modified the official [TorchVision](https://pytorch.org/docs/stable/torchvision/models.html) implementations of popular CNN models and trained them on the CIFAR-10 dataset.
3 | - I changed the *number of classes, filter size, stride, and padding* in the original code so that it works with CIFAR-10.
4 | - I also share the **weights** of these models, so you can just load the weights and use them.
5 | - The code is highly reproducible and readable thanks to PyTorch-Lightning.
6 | 
7 | ## Statistics of supported models
8 | | No. | Model        | Val. Acc. | No. Params |   Size |
9 | |:---:|:-------------|----------:|-----------:|-------:|
10 | |  1  | vgg11_bn     |    92.09% |  128.813 M | 491 MB |
11 | |  2  | vgg13_bn     |    94.29% |  128.998 M | 492 MB |
12 | |  3  | vgg16_bn     |    93.91% |  134.310 M | 512 MB |
13 | |  4  | vgg19_bn     |    93.80% |  139.622 M | 533 MB |
14 | |  5  | resnet18     |    93.33% |   11.174 M |  43 MB |
15 | |  6  | resnet34     |    92.92% |   21.282 M |  81 MB |
16 | |  7  | resnet50     |    93.86% |   23.521 M |  90 MB |
17 | |  8  | densenet121  |    94.14% |    6.956 M |  27 MB |
18 | |  9  | densenet161  |    94.24% |   26.483 M | 102 MB |
19 | | 10  | densenet169  |    94.00% |   12.493 M |  48 MB |
20 | | 11  | mobilenet_v2 |    94.17% |    2.237 M |   9 MB |
21 | | 12  | googlenet    |    92.73% |    5.491 M |  21 MB |
22 | | 13  | inception_v3 |    93.76% |   21.640 M |  83 MB |
23 | 
24 | ## How to use pretrained models
25 | 
26 | **Automatically download and extract the weights from Box (2.39 GB)**
27 | ```bash
28 | python cifar10_download.py
29 | ```
30 | Or use the [Google Drive](https://drive.google.com/file/d/11DDSbPqFXLzooIv6YPmXuKRIZJ24808g/view?usp=sharing) backup link (you have to download and extract it manually).
31 | 
32 | **Load model and run**
33 | ```python
34 | from cifar10_models import *
35 | 
36 | # Untrained model
37 | my_model = vgg11_bn()
38 | 
39 | # Pretrained model
40 | my_model = vgg11_bn(pretrained=True)
41 | ```
42 | 
43 | If you use your own images, note that all models expect inputs in the range [0, 1], normalized with
44 | ```python
45 | mean = [0.4914, 0.4822, 0.4465]
46 | std = [0.2023, 0.1994, 0.2010]
47 | ```
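For example, a standard torchvision pipeline that applies this normalization might look as follows (the `./data` root below is an assumption; adjust it to your setup):

```python
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10

transform = transforms.Compose([
    transforms.ToTensor(),  # converts images to float tensors in [0, 1]
    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                         std=[0.2023, 0.1994, 0.2010]),
])
testset = CIFAR10(root='./data', train=False, download=True, transform=transform)
```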
48 | 
49 | ## How to train models from scratch
50 | Check `cifar10_train.py` to see all available hyper-parameter choices.
51 | To reproduce the same accuracy, use the default hyper-parameters:
52 | 
53 | `python cifar10_train.py --classifier resnet18 --gpu '0,'`
54 | 
55 | ## How to test trained models
56 | `python cifar10_test.py --classifier resnet18 --gpu '0,'`
57 | 
58 | Output
59 | 
60 | `TEST RESULTS
61 | {'Accuracy': 93.33}`
62 | 
63 | ## Check the TensorBoard logs
64 | To see the training progress, `cd` to the `tensorboard_logs` directory and run TensorBoard there:
65 | 
66 | `tensorboard --logdir=. --port=YOUR_PORT_NUMBER`
67 | 
68 | Then go to
69 | `http://localhost:YOUR_PORT_NUMBER`
70 | 
71 | ## Requirements
72 | **Just to use pretrained models**
73 | - pytorch = 1.5.0
74 | 
75 | **To train & test**
76 | - torchvision = 0.6.0
77 | - tensorboard = 2.2.1
78 | - pytorch-lightning = 0.7.6

--------------------------------------------------------------------------------
/cifar10_models/cifar10_download.py:
--------------------------------------------------------------------------------
1 | import requests, zipfile, os
2 | from tqdm import tqdm
3 | 
4 | def main():
5 |     url = "https://rutgers.box.com/shared/static/y9wi8ic7bshe2nn63prj9vsea7wibd4x.zip"
6 | 
7 |     # Streaming, so we can iterate over the response.
8 |     r = requests.get(url, stream=True)
9 | 
10 |     # Total size in bytes, taken from the Content-Length header
11 |     total_size = int(r.headers.get('content-length', 0))
12 |     block_size = 2**20  # read in 1 MiB chunks
13 |     t = tqdm(total=total_size, unit='B', unit_scale=True)
14 | 
15 |     with open('state_dicts.zip', 'wb') as f:
16 |         for data in r.iter_content(block_size):
17 |             t.update(len(data))
18 |             f.write(data)
19 |     t.close()
20 | 
21 |     if total_size != 0 and t.n != total_size:
22 |         raise Exception('Error, something went wrong')
23 | 
24 |     print('Download successful. Unzipping file.')
25 |     path_to_zip_file = os.path.join(os.getcwd(), 'state_dicts.zip')
26 |     directory_to_extract_to = os.path.join(os.getcwd(), 'cifar10_models')
27 |     with zipfile.ZipFile(path_to_zip_file, 'r') as zip_ref:
28 |         zip_ref.extractall(directory_to_extract_to)
29 |     print('Unzipping successful!')
30 | 
31 | if __name__ == '__main__':
32 |     main()

--------------------------------------------------------------------------------
/cifar10_models/cifar10_models/__init__.py:
--------------------------------------------------------------------------------
1 | from .mobilenetv2 import *
2 | from .resnet import *
3 | from .vgg import *
4 | from .densenet import *
5 | from .resnet_orig import *
6 | from .googlenet import *
7 | from .inception import *

--------------------------------------------------------------------------------
/cifar10_models/cifar10_models/densenet.py:
--------------------------------------------------------------------------------
1 | import re
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from collections import OrderedDict
6 | import os
7 | 
8 | __all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161']
9 | 
10 | class _DenseLayer(nn.Sequential):
11 |     def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
12 |         super(_DenseLayer, self).__init__()
13 |         self.add_module('norm1', nn.BatchNorm2d(num_input_features))
14 |         self.add_module('relu1', nn.ReLU(inplace=True))
15 |         self.add_module('conv1', nn.Conv2d(num_input_features, bn_size *
16 |                                            growth_rate, kernel_size=1, stride=1,
17 |                                            bias=False))
18 |         self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate))
19 |         self.add_module('relu2', nn.ReLU(inplace=True))
20 |         self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
21 |                                            kernel_size=3, stride=1, padding=1,
22 |                                            bias=False))
23 |         self.drop_rate = drop_rate
24 | 
25 |     def forward(self, x):
26 |         new_features = super(_DenseLayer, self).forward(x)
27 |         if self.drop_rate > 0:
28 |             new_features = F.dropout(new_features, p=self.drop_rate,
29 |                                      training=self.training)
30 |         return torch.cat([x, new_features], 1)
31 | 
32 | 
33 | class _DenseBlock(nn.Sequential):
34 |     def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate):
35 |         
super(_DenseBlock, self).__init__() 36 | for i in range(num_layers): 37 | layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, 38 | bn_size, drop_rate) 39 | self.add_module('denselayer%d' % (i + 1), layer) 40 | 41 | 42 | class _Transition(nn.Sequential): 43 | def __init__(self, num_input_features, num_output_features): 44 | super(_Transition, self).__init__() 45 | self.add_module('norm', nn.BatchNorm2d(num_input_features)) 46 | self.add_module('relu', nn.ReLU(inplace=True)) 47 | self.add_module('conv', nn.Conv2d(num_input_features, num_output_features, 48 | kernel_size=1, stride=1, bias=False)) 49 | self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) 50 | 51 | 52 | class DenseNet(nn.Module): 53 | r"""Densenet-BC model class, based on 54 | `"Densely Connected Convolutional Networks" `_ 55 | 56 | Args: 57 | growth_rate (int) - how many filters to add each layer (`k` in paper) 58 | block_config (list of 4 ints) - how many layers in each pooling block 59 | num_init_features (int) - the number of filters to learn in the first convolution layer 60 | bn_size (int) - multiplicative factor for number of bottle neck layers 61 | (i.e. bn_size * k features in the bottleneck layer) 62 | drop_rate (float) - dropout rate after each dense layer 63 | num_classes (int) - number of classification classes 64 | """ 65 | 66 | def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), 67 | num_init_features=64, bn_size=4, drop_rate=0, num_classes=10): 68 | 69 | super(DenseNet, self).__init__() 70 | 71 | # First convolution 72 | 73 | # CIFAR-10: kernel_size 7 ->3, stride 2->1, padding 3->1 74 | self.features = nn.Sequential(OrderedDict([ 75 | ('conv0', nn.Conv2d(3, num_init_features, kernel_size=3, stride=1, 76 | padding=1, bias=False)), 77 | ('norm0', nn.BatchNorm2d(num_init_features)), 78 | ('relu0', nn.ReLU(inplace=True)), 79 | ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), 80 | ])) 81 | ## END 82 | 83 | # Each denseblock 84 | num_features = num_init_features 85 | for i, num_layers in enumerate(block_config): 86 | block = _DenseBlock(num_layers=num_layers, num_input_features=num_features, 87 | bn_size=bn_size, growth_rate=growth_rate, 88 | drop_rate=drop_rate) 89 | self.features.add_module('denseblock%d' % (i + 1), block) 90 | num_features = num_features + num_layers * growth_rate 91 | if i != len(block_config) - 1: 92 | trans = _Transition(num_input_features=num_features, 93 | num_output_features=num_features // 2) 94 | self.features.add_module('transition%d' % (i + 1), trans) 95 | num_features = num_features // 2 96 | 97 | # Final batch norm 98 | self.features.add_module('norm5', nn.BatchNorm2d(num_features)) 99 | 100 | # Linear layer 101 | self.classifier = nn.Linear(num_features, num_classes) 102 | 103 | # Official init from torch repo. 
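        # (Conv2d: Kaiming-normal weights; BatchNorm2d: weight=1, bias=0;
        #  Linear: bias=0 -- Linear weights keep PyTorch's default init.)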
104 | for m in self.modules(): 105 | if isinstance(m, nn.Conv2d): 106 | nn.init.kaiming_normal_(m.weight) 107 | elif isinstance(m, nn.BatchNorm2d): 108 | nn.init.constant_(m.weight, 1) 109 | nn.init.constant_(m.bias, 0) 110 | elif isinstance(m, nn.Linear): 111 | nn.init.constant_(m.bias, 0) 112 | 113 | def forward(self, x): 114 | features = self.features(x) 115 | out = F.relu(features, inplace=True) 116 | out = F.adaptive_avg_pool2d(out, (1, 1)).view(features.size(0), -1) 117 | out = self.classifier(out) 118 | return out 119 | 120 | def _densenet(arch, growth_rate, block_config, num_init_features, pretrained, progress, device, **kwargs): 121 | model = DenseNet(growth_rate, block_config, num_init_features, **kwargs) 122 | if pretrained: 123 | script_dir = os.path.dirname(__file__) 124 | state_dict = torch.load(script_dir + '/state_dicts/'+arch+'.pt', map_location=device) 125 | model.load_state_dict(state_dict) 126 | return model 127 | 128 | 129 | def densenet121(pretrained=False, progress=True, device='cpu', **kwargs): 130 | r"""Densenet-121 model from 131 | `"Densely Connected Convolutional Networks" `_ 132 | 133 | Args: 134 | pretrained (bool): If True, returns a model pre-trained on ImageNet 135 | progress (bool): If True, displays a progress bar of the download to stderr 136 | """ 137 | return _densenet('densenet121', 32, (6, 12, 24, 16), 64, pretrained, progress, device, 138 | **kwargs) 139 | 140 | 141 | def densenet161(pretrained=False, progress=True, device='cpu', **kwargs): 142 | r"""Densenet-161 model from 143 | `"Densely Connected Convolutional Networks" `_ 144 | 145 | Args: 146 | pretrained (bool): If True, returns a model pre-trained on ImageNet 147 | progress (bool): If True, displays a progress bar of the download to stderr 148 | """ 149 | return _densenet('densenet161', 48, (6, 12, 36, 24), 96, pretrained, progress, device, 150 | **kwargs) 151 | 152 | 153 | def densenet169(pretrained=False, progress=True, device='cpu', **kwargs): 154 | r"""Densenet-169 model from 155 | `"Densely Connected Convolutional Networks" `_ 156 | 157 | Args: 158 | pretrained (bool): If True, returns a model pre-trained on ImageNet 159 | progress (bool): If True, displays a progress bar of the download to stderr 160 | """ 161 | return _densenet('densenet169', 32, (6, 12, 32, 32), 64, pretrained, progress, device, 162 | **kwargs) 163 | 164 | 165 | def densenet201(pretrained=False, progress=True, device='cpu', **kwargs): 166 | r"""Densenet-201 model from 167 | `"Densely Connected Convolutional Networks" `_ 168 | 169 | Args: 170 | pretrained (bool): If True, returns a model pre-trained on ImageNet 171 | progress (bool): If True, displays a progress bar of the download to stderr 172 | """ 173 | return _densenet('densenet201', 32, (6, 12, 48, 32), 64, pretrained, progress, device, 174 | **kwargs) 175 | -------------------------------------------------------------------------------- /cifar10_models/cifar10_models/googlenet.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from collections import namedtuple 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import os 7 | 8 | __all__ = ['GoogLeNet', 'googlenet'] 9 | 10 | 11 | _GoogLeNetOuputs = namedtuple('GoogLeNetOuputs', ['logits', 'aux_logits2', 'aux_logits1']) 12 | 13 | 14 | def googlenet(pretrained=False, progress=True, device='cpu', **kwargs): 15 | r"""GoogLeNet (Inception v1) model architecture from 16 | `"Going Deeper with Convolutions" `_. 
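    .. note::
        This variant is adapted for CIFAR-10: the stem is reduced to a single
        3x3 convolution and ``aux_logits`` defaults to ``False`` (see
        ``GoogLeNet.__init__`` below).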
17 | 18 | Args: 19 | pretrained (bool): If True, returns a model pre-trained on ImageNet 20 | progress (bool): If True, displays a progress bar of the download to stderr 21 | aux_logits (bool): If True, adds two auxiliary branches that can improve training. 22 | Default: *False* when pretrained is True otherwise *True* 23 | transform_input (bool): If True, preprocesses the input according to the method with which it 24 | was trained on ImageNet. Default: *False* 25 | """ 26 | model = GoogLeNet() 27 | if pretrained: 28 | script_dir = os.path.dirname(__file__) 29 | state_dict = torch.load(script_dir + '/state_dicts/googlenet.pt', map_location=device) 30 | model.load_state_dict(state_dict) 31 | return model 32 | 33 | 34 | class GoogLeNet(nn.Module): 35 | 36 | ## CIFAR10: aux_logits True->False 37 | def __init__(self, num_classes=10, aux_logits=False, transform_input=False): 38 | super(GoogLeNet, self).__init__() 39 | self.aux_logits = aux_logits 40 | self.transform_input = transform_input 41 | 42 | ## CIFAR10: out_channels 64->192, kernel_size 7->3, stride 2->1, padding 3->1 43 | self.conv1 = BasicConv2d(3, 192, kernel_size=3, stride=1, padding=1) 44 | # self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True) 45 | # self.conv2 = BasicConv2d(64, 64, kernel_size=1) 46 | # self.conv3 = BasicConv2d(64, 192, kernel_size=3, padding=1) 47 | # self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True) 48 | ## END 49 | 50 | self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32) 51 | self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64) 52 | 53 | ## CIFAR10: padding 0->1, ciel_model True->False 54 | self.maxpool3 = nn.MaxPool2d(3, stride=2, padding=1, ceil_mode=False) 55 | ## END 56 | 57 | self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64) 58 | self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64) 59 | self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64) 60 | self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64) 61 | self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128) 62 | 63 | ## CIFAR10: kernel_size 2->3, padding 0->1, ciel_model True->False 64 | self.maxpool4 = nn.MaxPool2d(3, stride=2, padding=1, ceil_mode=False) 65 | ## END 66 | 67 | self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128) 68 | self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128) 69 | 70 | if aux_logits: 71 | self.aux1 = InceptionAux(512, num_classes) 72 | self.aux2 = InceptionAux(528, num_classes) 73 | 74 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 75 | self.dropout = nn.Dropout(0.2) 76 | self.fc = nn.Linear(1024, num_classes) 77 | 78 | # if init_weights: 79 | # self._initialize_weights() 80 | 81 | # def _initialize_weights(self): 82 | # for m in self.modules(): 83 | # if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): 84 | # import scipy.stats as stats 85 | # X = stats.truncnorm(-2, 2, scale=0.01) 86 | # values = torch.as_tensor(X.rvs(m.weight.numel()), dtype=m.weight.dtype) 87 | # values = values.view(m.weight.size()) 88 | # with torch.no_grad(): 89 | # m.weight.copy_(values) 90 | # elif isinstance(m, nn.BatchNorm2d): 91 | # nn.init.constant_(m.weight, 1) 92 | # nn.init.constant_(m.bias, 0) 93 | 94 | def forward(self, x): 95 | if self.transform_input: 96 | x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5 97 | x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5 98 | x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5 99 | x = torch.cat((x_ch0, x_ch1, x_ch2), 1) 100 | 101 | # N x 3 x 
224 x 224 102 | x = self.conv1(x) 103 | 104 | ## CIFAR10 105 | # N x 64 x 112 x 112 106 | # x = self.maxpool1(x) 107 | # N x 64 x 56 x 56 108 | # x = self.conv2(x) 109 | # N x 64 x 56 x 56 110 | # x = self.conv3(x) 111 | # N x 192 x 56 x 56 112 | # x = self.maxpool2(x) 113 | ## END 114 | 115 | # N x 192 x 28 x 28 116 | x = self.inception3a(x) 117 | # N x 256 x 28 x 28 118 | x = self.inception3b(x) 119 | # N x 480 x 28 x 28 120 | x = self.maxpool3(x) 121 | # N x 480 x 14 x 14 122 | x = self.inception4a(x) 123 | # N x 512 x 14 x 14 124 | if self.training and self.aux_logits: 125 | aux1 = self.aux1(x) 126 | 127 | x = self.inception4b(x) 128 | # N x 512 x 14 x 14 129 | x = self.inception4c(x) 130 | # N x 512 x 14 x 14 131 | x = self.inception4d(x) 132 | # N x 528 x 14 x 14 133 | if self.training and self.aux_logits: 134 | aux2 = self.aux2(x) 135 | 136 | x = self.inception4e(x) 137 | # N x 832 x 14 x 14 138 | x = self.maxpool4(x) 139 | # N x 832 x 7 x 7 140 | x = self.inception5a(x) 141 | # N x 832 x 7 x 7 142 | x = self.inception5b(x) 143 | # N x 1024 x 7 x 7 144 | 145 | x = self.avgpool(x) 146 | # N x 1024 x 1 x 1 147 | x = x.view(x.size(0), -1) 148 | # N x 1024 149 | x = self.dropout(x) 150 | x = self.fc(x) 151 | # N x 1000 (num_classes) 152 | if self.training and self.aux_logits: 153 | return _GoogLeNetOuputs(x, aux2, aux1) 154 | return x 155 | 156 | 157 | class Inception(nn.Module): 158 | 159 | def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj): 160 | super(Inception, self).__init__() 161 | 162 | self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1) 163 | 164 | self.branch2 = nn.Sequential( 165 | BasicConv2d(in_channels, ch3x3red, kernel_size=1), 166 | BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1) 167 | ) 168 | 169 | self.branch3 = nn.Sequential( 170 | BasicConv2d(in_channels, ch5x5red, kernel_size=1), 171 | BasicConv2d(ch5x5red, ch5x5, kernel_size=3, padding=1) 172 | ) 173 | 174 | self.branch4 = nn.Sequential( 175 | nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True), 176 | BasicConv2d(in_channels, pool_proj, kernel_size=1) 177 | ) 178 | 179 | def forward(self, x): 180 | branch1 = self.branch1(x) 181 | branch2 = self.branch2(x) 182 | branch3 = self.branch3(x) 183 | branch4 = self.branch4(x) 184 | 185 | outputs = [branch1, branch2, branch3, branch4] 186 | return torch.cat(outputs, 1) 187 | 188 | 189 | class InceptionAux(nn.Module): 190 | 191 | def __init__(self, in_channels, num_classes): 192 | super(InceptionAux, self).__init__() 193 | self.conv = BasicConv2d(in_channels, 128, kernel_size=1) 194 | 195 | self.fc1 = nn.Linear(2048, 1024) 196 | self.fc2 = nn.Linear(1024, num_classes) 197 | 198 | def forward(self, x): 199 | # aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14 200 | x = F.adaptive_avg_pool2d(x, (4, 4)) 201 | # aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4 202 | x = self.conv(x) 203 | # N x 128 x 4 x 4 204 | x = x.view(x.size(0), -1) 205 | # N x 2048 206 | x = F.relu(self.fc1(x), inplace=True) 207 | # N x 2048 208 | x = F.dropout(x, 0.7, training=self.training) 209 | # N x 2048 210 | x = self.fc2(x) 211 | # N x 1024 212 | 213 | return x 214 | 215 | 216 | class BasicConv2d(nn.Module): 217 | 218 | def __init__(self, in_channels, out_channels, **kwargs): 219 | super(BasicConv2d, self).__init__() 220 | self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) 221 | self.bn = nn.BatchNorm2d(out_channels, eps=0.001) 222 | 223 | def forward(self, x): 224 | x = self.conv(x) 225 | x = self.bn(x) 226 | return 
F.relu(x, inplace=True) 227 | -------------------------------------------------------------------------------- /cifar10_models/cifar10_models/inception.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import os 6 | 7 | __all__ = ['Inception3', 'inception_v3'] 8 | 9 | 10 | _InceptionOuputs = namedtuple('InceptionOuputs', ['logits', 'aux_logits']) 11 | 12 | 13 | def inception_v3(pretrained=False, progress=True, device='cpu', **kwargs): 14 | r"""Inception v3 model architecture from 15 | `"Rethinking the Inception Architecture for Computer Vision" `_. 16 | 17 | .. note:: 18 | **Important**: In contrast to the other models the inception_v3 expects tensors with a size of 19 | N x 3 x 299 x 299, so ensure your images are sized accordingly. 20 | 21 | Args: 22 | pretrained (bool): If True, returns a model pre-trained on ImageNet 23 | progress (bool): If True, displays a progress bar of the download to stderr 24 | aux_logits (bool): If True, add an auxiliary branch that can improve training. 25 | Default: *True* 26 | transform_input (bool): If True, preprocesses the input according to the method with which it 27 | was trained on ImageNet. Default: *False* 28 | """ 29 | model = Inception3() 30 | if pretrained: 31 | script_dir = os.path.dirname(__file__) 32 | state_dict = torch.load(script_dir + '/state_dicts/inception_v3.pt', map_location=device) 33 | model.load_state_dict(state_dict) 34 | return model 35 | 36 | class Inception3(nn.Module): 37 | ## CIFAR10: aux_logits True->False 38 | def __init__(self, num_classes=10, aux_logits=False, transform_input=False): 39 | super(Inception3, self).__init__() 40 | self.aux_logits = aux_logits 41 | self.transform_input = transform_input 42 | 43 | ## CIFAR10: stride 2->1, padding 0 -> 1 44 | self.Conv2d_1a_3x3 = BasicConv2d(3, 192, kernel_size=3, stride=1, padding=1) 45 | # self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3) 46 | # self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1) 47 | # self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1) 48 | # self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3) 49 | self.Mixed_5b = InceptionA(192, pool_features=32) 50 | self.Mixed_5c = InceptionA(256, pool_features=64) 51 | self.Mixed_5d = InceptionA(288, pool_features=64) 52 | self.Mixed_6a = InceptionB(288) 53 | self.Mixed_6b = InceptionC(768, channels_7x7=128) 54 | self.Mixed_6c = InceptionC(768, channels_7x7=160) 55 | self.Mixed_6d = InceptionC(768, channels_7x7=160) 56 | self.Mixed_6e = InceptionC(768, channels_7x7=192) 57 | if aux_logits: 58 | self.AuxLogits = InceptionAux(768, num_classes) 59 | self.Mixed_7a = InceptionD(768) 60 | self.Mixed_7b = InceptionE(1280) 61 | self.Mixed_7c = InceptionE(2048) 62 | self.fc = nn.Linear(2048, num_classes) 63 | 64 | # for m in self.modules(): 65 | # if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): 66 | # import scipy.stats as stats 67 | # stddev = m.stddev if hasattr(m, 'stddev') else 0.1 68 | # X = stats.truncnorm(-2, 2, scale=stddev) 69 | # values = torch.as_tensor(X.rvs(m.weight.numel()), dtype=m.weight.dtype) 70 | # values = values.view(m.weight.size()) 71 | # with torch.no_grad(): 72 | # m.weight.copy_(values) 73 | # elif isinstance(m, nn.BatchNorm2d): 74 | # nn.init.constant_(m.weight, 1) 75 | # nn.init.constant_(m.bias, 0) 76 | 77 | def forward(self, x): 78 | if self.transform_input: 79 | x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 
/ 0.5) + (0.485 - 0.5) / 0.5 80 | x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5 81 | x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5 82 | x = torch.cat((x_ch0, x_ch1, x_ch2), 1) 83 | # N x 3 x 299 x 299 84 | x = self.Conv2d_1a_3x3(x) 85 | 86 | ## CIFAR10 87 | # N x 32 x 149 x 149 88 | # x = self.Conv2d_2a_3x3(x) 89 | # N x 32 x 147 x 147 90 | # x = self.Conv2d_2b_3x3(x) 91 | # N x 64 x 147 x 147 92 | # x = F.max_pool2d(x, kernel_size=3, stride=2) 93 | # N x 64 x 73 x 73 94 | # x = self.Conv2d_3b_1x1(x) 95 | # N x 80 x 73 x 73 96 | # x = self.Conv2d_4a_3x3(x) 97 | # N x 192 x 71 x 71 98 | # x = F.max_pool2d(x, kernel_size=3, stride=2) 99 | # N x 192 x 35 x 35 100 | x = self.Mixed_5b(x) 101 | # N x 256 x 35 x 35 102 | x = self.Mixed_5c(x) 103 | # N x 288 x 35 x 35 104 | x = self.Mixed_5d(x) 105 | # N x 288 x 35 x 35 106 | x = self.Mixed_6a(x) 107 | # N x 768 x 17 x 17 108 | x = self.Mixed_6b(x) 109 | # N x 768 x 17 x 17 110 | x = self.Mixed_6c(x) 111 | # N x 768 x 17 x 17 112 | x = self.Mixed_6d(x) 113 | # N x 768 x 17 x 17 114 | x = self.Mixed_6e(x) 115 | # N x 768 x 17 x 17 116 | if self.training and self.aux_logits: 117 | aux = self.AuxLogits(x) 118 | # N x 768 x 17 x 17 119 | x = self.Mixed_7a(x) 120 | # N x 1280 x 8 x 8 121 | x = self.Mixed_7b(x) 122 | # N x 2048 x 8 x 8 123 | x = self.Mixed_7c(x) 124 | # N x 2048 x 8 x 8 125 | # Adaptive average pooling 126 | x = F.adaptive_avg_pool2d(x, (1, 1)) 127 | # N x 2048 x 1 x 1 128 | x = F.dropout(x, training=self.training) 129 | # N x 2048 x 1 x 1 130 | x = x.view(x.size(0), -1) 131 | # N x 2048 132 | x = self.fc(x) 133 | # N x 1000 (num_classes) 134 | if self.training and self.aux_logits: 135 | return _InceptionOuputs(x, aux) 136 | return x 137 | 138 | 139 | class InceptionA(nn.Module): 140 | 141 | def __init__(self, in_channels, pool_features): 142 | super(InceptionA, self).__init__() 143 | self.branch1x1 = BasicConv2d(in_channels, 64, kernel_size=1) 144 | 145 | self.branch5x5_1 = BasicConv2d(in_channels, 48, kernel_size=1) 146 | self.branch5x5_2 = BasicConv2d(48, 64, kernel_size=5, padding=2) 147 | 148 | self.branch3x3dbl_1 = BasicConv2d(in_channels, 64, kernel_size=1) 149 | self.branch3x3dbl_2 = BasicConv2d(64, 96, kernel_size=3, padding=1) 150 | self.branch3x3dbl_3 = BasicConv2d(96, 96, kernel_size=3, padding=1) 151 | 152 | self.branch_pool = BasicConv2d(in_channels, pool_features, kernel_size=1) 153 | 154 | def forward(self, x): 155 | branch1x1 = self.branch1x1(x) 156 | 157 | branch5x5 = self.branch5x5_1(x) 158 | branch5x5 = self.branch5x5_2(branch5x5) 159 | 160 | branch3x3dbl = self.branch3x3dbl_1(x) 161 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 162 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) 163 | 164 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) 165 | branch_pool = self.branch_pool(branch_pool) 166 | 167 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] 168 | return torch.cat(outputs, 1) 169 | 170 | 171 | class InceptionB(nn.Module): 172 | 173 | def __init__(self, in_channels): 174 | super(InceptionB, self).__init__() 175 | self.branch3x3 = BasicConv2d(in_channels, 384, kernel_size=3, stride=2) 176 | 177 | self.branch3x3dbl_1 = BasicConv2d(in_channels, 64, kernel_size=1) 178 | self.branch3x3dbl_2 = BasicConv2d(64, 96, kernel_size=3, padding=1) 179 | self.branch3x3dbl_3 = BasicConv2d(96, 96, kernel_size=3, stride=2) 180 | 181 | def forward(self, x): 182 | branch3x3 = self.branch3x3(x) 183 | 184 | branch3x3dbl = 
self.branch3x3dbl_1(x) 185 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 186 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) 187 | 188 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=2) 189 | 190 | outputs = [branch3x3, branch3x3dbl, branch_pool] 191 | return torch.cat(outputs, 1) 192 | 193 | 194 | class InceptionC(nn.Module): 195 | 196 | def __init__(self, in_channels, channels_7x7): 197 | super(InceptionC, self).__init__() 198 | self.branch1x1 = BasicConv2d(in_channels, 192, kernel_size=1) 199 | 200 | c7 = channels_7x7 201 | self.branch7x7_1 = BasicConv2d(in_channels, c7, kernel_size=1) 202 | self.branch7x7_2 = BasicConv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3)) 203 | self.branch7x7_3 = BasicConv2d(c7, 192, kernel_size=(7, 1), padding=(3, 0)) 204 | 205 | self.branch7x7dbl_1 = BasicConv2d(in_channels, c7, kernel_size=1) 206 | self.branch7x7dbl_2 = BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0)) 207 | self.branch7x7dbl_3 = BasicConv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3)) 208 | self.branch7x7dbl_4 = BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0)) 209 | self.branch7x7dbl_5 = BasicConv2d(c7, 192, kernel_size=(1, 7), padding=(0, 3)) 210 | 211 | self.branch_pool = BasicConv2d(in_channels, 192, kernel_size=1) 212 | 213 | def forward(self, x): 214 | branch1x1 = self.branch1x1(x) 215 | 216 | branch7x7 = self.branch7x7_1(x) 217 | branch7x7 = self.branch7x7_2(branch7x7) 218 | branch7x7 = self.branch7x7_3(branch7x7) 219 | 220 | branch7x7dbl = self.branch7x7dbl_1(x) 221 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) 222 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) 223 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) 224 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) 225 | 226 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) 227 | branch_pool = self.branch_pool(branch_pool) 228 | 229 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] 230 | return torch.cat(outputs, 1) 231 | 232 | 233 | class InceptionD(nn.Module): 234 | 235 | def __init__(self, in_channels): 236 | super(InceptionD, self).__init__() 237 | self.branch3x3_1 = BasicConv2d(in_channels, 192, kernel_size=1) 238 | self.branch3x3_2 = BasicConv2d(192, 320, kernel_size=3, stride=2) 239 | 240 | self.branch7x7x3_1 = BasicConv2d(in_channels, 192, kernel_size=1) 241 | self.branch7x7x3_2 = BasicConv2d(192, 192, kernel_size=(1, 7), padding=(0, 3)) 242 | self.branch7x7x3_3 = BasicConv2d(192, 192, kernel_size=(7, 1), padding=(3, 0)) 243 | self.branch7x7x3_4 = BasicConv2d(192, 192, kernel_size=3, stride=2) 244 | 245 | def forward(self, x): 246 | branch3x3 = self.branch3x3_1(x) 247 | branch3x3 = self.branch3x3_2(branch3x3) 248 | 249 | branch7x7x3 = self.branch7x7x3_1(x) 250 | branch7x7x3 = self.branch7x7x3_2(branch7x7x3) 251 | branch7x7x3 = self.branch7x7x3_3(branch7x7x3) 252 | branch7x7x3 = self.branch7x7x3_4(branch7x7x3) 253 | 254 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=2) 255 | outputs = [branch3x3, branch7x7x3, branch_pool] 256 | return torch.cat(outputs, 1) 257 | 258 | 259 | class InceptionE(nn.Module): 260 | 261 | def __init__(self, in_channels): 262 | super(InceptionE, self).__init__() 263 | self.branch1x1 = BasicConv2d(in_channels, 320, kernel_size=1) 264 | 265 | self.branch3x3_1 = BasicConv2d(in_channels, 384, kernel_size=1) 266 | self.branch3x3_2a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1)) 267 | self.branch3x3_2b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0)) 268 | 269 | self.branch3x3dbl_1 = BasicConv2d(in_channels, 
448, kernel_size=1) 270 | self.branch3x3dbl_2 = BasicConv2d(448, 384, kernel_size=3, padding=1) 271 | self.branch3x3dbl_3a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1)) 272 | self.branch3x3dbl_3b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0)) 273 | 274 | self.branch_pool = BasicConv2d(in_channels, 192, kernel_size=1) 275 | 276 | def forward(self, x): 277 | branch1x1 = self.branch1x1(x) 278 | 279 | branch3x3 = self.branch3x3_1(x) 280 | branch3x3 = [ 281 | self.branch3x3_2a(branch3x3), 282 | self.branch3x3_2b(branch3x3), 283 | ] 284 | branch3x3 = torch.cat(branch3x3, 1) 285 | 286 | branch3x3dbl = self.branch3x3dbl_1(x) 287 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 288 | branch3x3dbl = [ 289 | self.branch3x3dbl_3a(branch3x3dbl), 290 | self.branch3x3dbl_3b(branch3x3dbl), 291 | ] 292 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 293 | 294 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) 295 | branch_pool = self.branch_pool(branch_pool) 296 | 297 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 298 | return torch.cat(outputs, 1) 299 | 300 | 301 | class InceptionAux(nn.Module): 302 | 303 | def __init__(self, in_channels, num_classes): 304 | super(InceptionAux, self).__init__() 305 | self.conv0 = BasicConv2d(in_channels, 128, kernel_size=1) 306 | self.conv1 = BasicConv2d(128, 768, kernel_size=5) 307 | self.conv1.stddev = 0.01 308 | self.fc = nn.Linear(768, num_classes) 309 | self.fc.stddev = 0.001 310 | 311 | def forward(self, x): 312 | # N x 768 x 17 x 17 313 | x = F.avg_pool2d(x, kernel_size=5, stride=3) 314 | # N x 768 x 5 x 5 315 | x = self.conv0(x) 316 | # N x 128 x 5 x 5 317 | x = self.conv1(x) 318 | # N x 768 x 1 x 1 319 | # Adaptive average pooling 320 | x = F.adaptive_avg_pool2d(x, (1, 1)) 321 | # N x 768 x 1 x 1 322 | x = x.view(x.size(0), -1) 323 | # N x 768 324 | x = self.fc(x) 325 | # N x 1000 326 | return x 327 | 328 | 329 | class BasicConv2d(nn.Module): 330 | 331 | def __init__(self, in_channels, out_channels, **kwargs): 332 | super(BasicConv2d, self).__init__() 333 | self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) 334 | self.bn = nn.BatchNorm2d(out_channels, eps=0.001) 335 | 336 | def forward(self, x): 337 | x = self.conv(x) 338 | x = self.bn(x) 339 | return F.relu(x, inplace=True) 340 | -------------------------------------------------------------------------------- /cifar10_models/cifar10_models/mobilenetv2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import os 4 | 5 | __all__ = ['MobileNetV2', 'mobilenet_v2'] 6 | 7 | 8 | class ConvBNReLU(nn.Sequential): 9 | def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): 10 | padding = (kernel_size - 1) // 2 11 | super(ConvBNReLU, self).__init__( 12 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False), 13 | nn.BatchNorm2d(out_planes), 14 | nn.ReLU6(inplace=True) 15 | ) 16 | 17 | 18 | class InvertedResidual(nn.Module): 19 | def __init__(self, inp, oup, stride, expand_ratio): 20 | super(InvertedResidual, self).__init__() 21 | self.stride = stride 22 | assert stride in [1, 2] 23 | 24 | hidden_dim = int(round(inp * expand_ratio)) 25 | self.use_res_connect = self.stride == 1 and inp == oup 26 | 27 | layers = [] 28 | if expand_ratio != 1: 29 | # pw 30 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) 31 | layers.extend([ 32 | # dw 33 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, 
groups=hidden_dim), 34 | # pw-linear 35 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 36 | nn.BatchNorm2d(oup), 37 | ]) 38 | self.conv = nn.Sequential(*layers) 39 | 40 | def forward(self, x): 41 | if self.use_res_connect: 42 | return x + self.conv(x) 43 | else: 44 | return self.conv(x) 45 | 46 | 47 | class MobileNetV2(nn.Module): 48 | def __init__(self, num_classes=10, width_mult=1.0): 49 | super(MobileNetV2, self).__init__() 50 | block = InvertedResidual 51 | input_channel = 32 52 | last_channel = 1280 53 | 54 | ## CIFAR10 55 | inverted_residual_setting = [ 56 | # t, c, n, s 57 | [1, 16, 1, 1], 58 | [6, 24, 2, 1], # Stride 2 -> 1 for CIFAR-10 59 | [6, 32, 3, 2], 60 | [6, 64, 4, 2], 61 | [6, 96, 3, 1], 62 | [6, 160, 3, 2], 63 | [6, 320, 1, 1], 64 | ] 65 | ## END 66 | 67 | # building first layer 68 | input_channel = int(input_channel * width_mult) 69 | self.last_channel = int(last_channel * max(1.0, width_mult)) 70 | 71 | # CIFAR10: stride 2 -> 1 72 | features = [ConvBNReLU(3, input_channel, stride=1)] 73 | # END 74 | 75 | # building inverted residual blocks 76 | for t, c, n, s in inverted_residual_setting: 77 | output_channel = int(c * width_mult) 78 | for i in range(n): 79 | stride = s if i == 0 else 1 80 | features.append(block(input_channel, output_channel, stride, expand_ratio=t)) 81 | input_channel = output_channel 82 | # building last several layers 83 | features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1)) 84 | # make it nn.Sequential 85 | self.features = nn.Sequential(*features) 86 | 87 | # building classifier 88 | self.classifier = nn.Sequential( 89 | nn.Dropout(0.2), 90 | nn.Linear(self.last_channel, num_classes), 91 | ) 92 | 93 | # weight initialization 94 | for m in self.modules(): 95 | if isinstance(m, nn.Conv2d): 96 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 97 | if m.bias is not None: 98 | nn.init.zeros_(m.bias) 99 | elif isinstance(m, nn.BatchNorm2d): 100 | nn.init.ones_(m.weight) 101 | nn.init.zeros_(m.bias) 102 | elif isinstance(m, nn.Linear): 103 | nn.init.normal_(m.weight, 0, 0.01) 104 | nn.init.zeros_(m.bias) 105 | 106 | def forward(self, x): 107 | x = self.features(x) 108 | x = x.mean([2, 3]) 109 | x = self.classifier(x) 110 | return x 111 | 112 | 113 | def mobilenet_v2(pretrained=False, progress=True, device='cpu', **kwargs): 114 | """ 115 | Constructs a MobileNetV2 architecture from 116 | `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" `_. 
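    .. note::
        Adapted for CIFAR-10: the stem convolution and the second
        inverted-residual stage use stride 1 instead of 2 (see the CIFAR10
        comments in ``MobileNetV2.__init__``).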
117 | 118 | Args: 119 | pretrained (bool): If True, returns a model pre-trained on ImageNet 120 | progress (bool): If True, displays a progress bar of the download to stderr 121 | """ 122 | model = MobileNetV2(**kwargs) 123 | if pretrained: 124 | script_dir = os.path.dirname(__file__) 125 | state_dict = torch.load(script_dir+'/state_dicts/mobilenet_v2.pt', map_location=device) 126 | model.load_state_dict(state_dict) 127 | return model 128 | -------------------------------------------------------------------------------- /cifar10_models/cifar10_models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import os 4 | 5 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 6 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d'] 7 | 8 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 9 | """3x3 convolution with padding""" 10 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 11 | padding=dilation, groups=groups, bias=False, dilation=dilation) 12 | 13 | 14 | def conv1x1(in_planes, out_planes, stride=1): 15 | """1x1 convolution""" 16 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 17 | 18 | 19 | class BasicBlock(nn.Module): 20 | expansion = 1 21 | 22 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 23 | base_width=64, dilation=1, norm_layer=None): 24 | super(BasicBlock, self).__init__() 25 | if norm_layer is None: 26 | norm_layer = nn.BatchNorm2d 27 | if groups != 1 or base_width != 64: 28 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 29 | if dilation > 1: 30 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 31 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 32 | self.conv1 = conv3x3(inplanes, planes, stride) 33 | self.bn1 = norm_layer(planes) 34 | self.relu = nn.ReLU(inplace=True) 35 | self.conv2 = conv3x3(planes, planes) 36 | self.bn2 = norm_layer(planes) 37 | self.downsample = downsample 38 | self.stride = stride 39 | 40 | def forward(self, x): 41 | identity = x 42 | 43 | out = self.conv1(x) 44 | out = self.bn1(out) 45 | out = self.relu(out) 46 | 47 | out = self.conv2(out) 48 | out = self.bn2(out) 49 | 50 | if self.downsample is not None: 51 | identity = self.downsample(x) 52 | 53 | out += identity 54 | out = self.relu(out) 55 | 56 | return out 57 | 58 | 59 | class Bottleneck(nn.Module): 60 | expansion = 4 61 | 62 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 63 | base_width=64, dilation=1, norm_layer=None): 64 | super(Bottleneck, self).__init__() 65 | if norm_layer is None: 66 | norm_layer = nn.BatchNorm2d 67 | width = int(planes * (base_width / 64.)) * groups 68 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 69 | self.conv1 = conv1x1(inplanes, width) 70 | self.bn1 = norm_layer(width) 71 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 72 | self.bn2 = norm_layer(width) 73 | self.conv3 = conv1x1(width, planes * self.expansion) 74 | self.bn3 = norm_layer(planes * self.expansion) 75 | self.relu = nn.ReLU(inplace=True) 76 | self.downsample = downsample 77 | self.stride = stride 78 | 79 | def forward(self, x): 80 | identity = x 81 | 82 | out = self.conv1(x) 83 | out = self.bn1(out) 84 | out = self.relu(out) 85 | 86 | out = self.conv2(out) 87 | out = self.bn2(out) 88 | out = self.relu(out) 89 | 90 | out = self.conv3(out) 91 | 
out = self.bn3(out) 92 | 93 | if self.downsample is not None: 94 | identity = self.downsample(x) 95 | 96 | out += identity 97 | out = self.relu(out) 98 | 99 | return out 100 | 101 | 102 | class ResNet(nn.Module): 103 | 104 | def __init__(self, block, layers, num_classes=10, zero_init_residual=False, 105 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 106 | norm_layer=None): 107 | super(ResNet, self).__init__() 108 | if norm_layer is None: 109 | norm_layer = nn.BatchNorm2d 110 | self._norm_layer = norm_layer 111 | 112 | self.inplanes = 64 113 | self.dilation = 1 114 | if replace_stride_with_dilation is None: 115 | # each element in the tuple indicates if we should replace 116 | # the 2x2 stride with a dilated convolution instead 117 | replace_stride_with_dilation = [False, False, False] 118 | if len(replace_stride_with_dilation) != 3: 119 | raise ValueError("replace_stride_with_dilation should be None " 120 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 121 | self.groups = groups 122 | self.base_width = width_per_group 123 | 124 | ## CIFAR10: kernel_size 7 -> 3, stride 2 -> 1, padding 3->1 125 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) 126 | ## END 127 | 128 | self.bn1 = norm_layer(self.inplanes) 129 | self.relu = nn.ReLU(inplace=True) 130 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 131 | self.layer1 = self._make_layer(block, 64, layers[0]) 132 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 133 | dilate=replace_stride_with_dilation[0]) 134 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 135 | dilate=replace_stride_with_dilation[1]) 136 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 137 | dilate=replace_stride_with_dilation[2]) 138 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 139 | self.fc = nn.Linear(512 * block.expansion, num_classes) 140 | 141 | for m in self.modules(): 142 | if isinstance(m, nn.Conv2d): 143 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 144 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 145 | nn.init.constant_(m.weight, 1) 146 | nn.init.constant_(m.bias, 0) 147 | 148 | # Zero-initialize the last BN in each residual branch, 149 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
150 |         # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
151 |         if zero_init_residual:
152 |             for m in self.modules():
153 |                 if isinstance(m, Bottleneck):
154 |                     nn.init.constant_(m.bn3.weight, 0)
155 |                 elif isinstance(m, BasicBlock):
156 |                     nn.init.constant_(m.bn2.weight, 0)
157 | 
158 |     def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
159 |         norm_layer = self._norm_layer
160 |         downsample = None
161 |         previous_dilation = self.dilation
162 |         if dilate:
163 |             self.dilation *= stride
164 |             stride = 1
165 |         if stride != 1 or self.inplanes != planes * block.expansion:
166 |             downsample = nn.Sequential(
167 |                 conv1x1(self.inplanes, planes * block.expansion, stride),
168 |                 norm_layer(planes * block.expansion),
169 |             )
170 | 
171 |         layers = []
172 |         layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
173 |                             self.base_width, previous_dilation, norm_layer))
174 |         self.inplanes = planes * block.expansion
175 |         for _ in range(1, blocks):
176 |             layers.append(block(self.inplanes, planes, groups=self.groups,
177 |                                 base_width=self.base_width, dilation=self.dilation,
178 |                                 norm_layer=norm_layer))
179 | 
180 |         return nn.Sequential(*layers)
181 | 
182 |     def forward(self, x):
183 |         x = self.conv1(x)
184 |         x = self.bn1(x)
185 |         x = self.relu(x)
186 |         x = self.maxpool(x)
187 | 
188 |         x = self.layer1(x)
189 |         x = self.layer2(x)
190 |         x = self.layer3(x)
191 |         x = self.layer4(x)
192 | 
193 |         x = self.avgpool(x)
194 |         x = x.reshape(x.size(0), -1)
195 |         x = self.fc(x)
196 | 
197 |         return x
198 | 
199 | 
200 | def _resnet(arch, block, layers, pretrained, progress, device, **kwargs):
201 |     model = ResNet(block, layers, **kwargs)
202 |     if pretrained:
203 |         script_dir = os.path.dirname(__file__)
204 |         state_dict = torch.load(script_dir + '/state_dicts/'+arch+'.pt', map_location=device)
205 |         model.load_state_dict(state_dict)
206 |     return model
207 | 
208 | 
209 | def resnet18(pretrained=False, progress=True, device='cpu', **kwargs):
210 |     """Constructs a ResNet-18 model.
211 | 
212 |     Args:
213 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
214 |         progress (bool): If True, displays a progress bar of the download to stderr
215 |     """
216 |     return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, device,
217 |                    **kwargs)
218 | 
219 | 
220 | def resnet34(pretrained=False, progress=True, device='cpu', **kwargs):
221 |     """Constructs a ResNet-34 model.
222 | 
223 |     Args:
224 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
225 |         progress (bool): If True, displays a progress bar of the download to stderr
226 |     """
227 |     return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, device,
228 |                    **kwargs)
229 | 
230 | 
231 | def resnet50(pretrained=False, progress=True, device='cpu', **kwargs):
232 |     """Constructs a ResNet-50 model.
233 | 
234 |     Args:
235 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
236 |         progress (bool): If True, displays a progress bar of the download to stderr
237 |     """
238 |     return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, device,
239 |                    **kwargs)
240 | 
241 | 
242 | def resnet101(pretrained=False, progress=True, device='cpu', **kwargs):
243 |     """Constructs a ResNet-101 model.
244 | 
245 |     Args:
246 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
247 |         progress (bool): If True, displays a progress bar of the download to stderr
248 |     """
249 |     return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, device,
250 |                    **kwargs)
251 | 
252 | 
253 | def resnet152(pretrained=False, progress=True, device='cpu', **kwargs):
254 |     """Constructs a ResNet-152 model.
255 | 
256 |     Args:
257 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
258 |         progress (bool): If True, displays a progress bar of the download to stderr
259 |     """
260 |     return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, device,
261 |                    **kwargs)
262 | 
263 | 
264 | def resnext50_32x4d(pretrained=False, progress=True, device='cpu', **kwargs):
265 |     """Constructs a ResNeXt-50 32x4d model.
266 | 
267 |     Args:
268 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
269 |         progress (bool): If True, displays a progress bar of the download to stderr
270 |     """
271 |     kwargs['groups'] = 32
272 |     kwargs['width_per_group'] = 4
273 |     return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
274 |                    pretrained, progress, device, **kwargs)
275 | 
276 | 
277 | def resnext101_32x8d(pretrained=False, progress=True, device='cpu', **kwargs):
278 |     """Constructs a ResNeXt-101 32x8d model.
279 | 
280 |     Args:
281 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
282 |         progress (bool): If True, displays a progress bar of the download to stderr
283 |     """
284 |     kwargs['groups'] = 32
285 |     kwargs['width_per_group'] = 8
286 |     return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
287 |                    pretrained, progress, device, **kwargs)
288 | 

--------------------------------------------------------------------------------
/cifar10_models/cifar10_models/resnet_orig.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import os
5 | 
6 | # Credit to https://github.com/akamaster/pytorch_resnet_cifar10
7 | 
8 | __all__ = ['resnet_orig']
9 | 
10 | class LambdaLayer(nn.Module):
11 |     def __init__(self, lambd):
12 |         super(LambdaLayer, self).__init__()
13 |         self.lambd = lambd
14 | 
15 |     def forward(self, x):
16 |         return self.lambd(x)
17 | 
18 | class BasicBlock(nn.Module):
19 |     expansion = 1
20 | 
21 |     def __init__(self, in_planes, planes, stride=1, option='A'):
22 |         super(BasicBlock, self).__init__()
23 |         self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
24 |         self.bn1 = nn.BatchNorm2d(planes)
25 |         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
26 |         self.bn2 = nn.BatchNorm2d(planes)
27 | 
28 |         self.shortcut = nn.Sequential()
29 |         if stride != 1 or in_planes != planes:
30 |             if option == 'A':
31 |                 """
32 |                 For CIFAR-10, the ResNet paper uses option A.
33 |                 
33 | """ 34 | self.shortcut = LambdaLayer(lambda x: 35 | F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0)) 36 | elif option == 'B': 37 | self.shortcut = nn.Sequential( 38 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 39 | nn.BatchNorm2d(self.expansion * planes) 40 | ) 41 | 42 | def forward(self, x): 43 | out = F.relu(self.bn1(self.conv1(x))) 44 | out = self.bn2(self.conv2(out)) 45 | out += self.shortcut(x) 46 | out = F.relu(out) 47 | return out 48 | 49 | class ResNet(nn.Module): 50 | def __init__(self, block, num_blocks, num_classes=10): 51 | super(ResNet, self).__init__() 52 | self.in_planes = 16 53 | 54 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) 55 | self.bn1 = nn.BatchNorm2d(16) 56 | self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1) 57 | self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2) 58 | self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2) 59 | self.linear = nn.Linear(64, num_classes) 60 | 61 | def _make_layer(self, block, planes, num_blocks, stride): 62 | strides = [stride] + [1]*(num_blocks-1) 63 | layers = [] 64 | for stride in strides: 65 | layers.append(block(self.in_planes, planes, stride)) 66 | self.in_planes = planes * block.expansion 67 | 68 | return nn.Sequential(*layers) 69 | 70 | def forward(self, x): 71 | out = F.relu(self.bn1(self.conv1(x))) 72 | out = self.layer1(out) 73 | out = self.layer2(out) 74 | out = self.layer3(out) 75 | out = F.avg_pool2d(out, out.size()[3]) 76 | out = out.view(out.size(0), -1) 77 | out = self.linear(out) 78 | return out 79 | 80 | def resnet_orig(pretrained=True, device='cpu'): 81 | net = ResNet(BasicBlock, [3, 3, 3]) 82 | if pretrained: 83 | script_dir = os.path.dirname(__file__) 84 | state_dict = torch.load(script_dir + '/state_dicts/resnet_orig.pt', map_location=device) 85 | net.load_state_dict(state_dict) 86 | return net -------------------------------------------------------------------------------- /cifar10_models/cifar10_models/vgg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import os 4 | 5 | __all__ = [ 6 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 7 | 'vgg19_bn', 'vgg19', 8 | ] 9 | 10 | class VGG(nn.Module): 11 | 12 | def __init__(self, features, num_classes=10, init_weights=True): 13 | super(VGG, self).__init__() 14 | self.features = features 15 | self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) 16 | self.classifier = nn.Sequential( 17 | nn.Linear(512 * 7 * 7, 4096), 18 | nn.ReLU(True), 19 | nn.Dropout(), 20 | nn.Linear(4096, 4096), 21 | nn.ReLU(True), 22 | nn.Dropout(), 23 | nn.Linear(4096, num_classes), 24 | ) 25 | if init_weights: 26 | self._initialize_weights() 27 | 28 | def forward(self, x): 29 | x = self.features(x) 30 | x = self.avgpool(x) 31 | x = x.view(x.size(0), -1) 32 | x = self.classifier(x) 33 | return x 34 | 35 | def _initialize_weights(self): 36 | for m in self.modules(): 37 | if isinstance(m, nn.Conv2d): 38 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 39 | if m.bias is not None: 40 | nn.init.constant_(m.bias, 0) 41 | elif isinstance(m, nn.BatchNorm2d): 42 | nn.init.constant_(m.weight, 1) 43 | nn.init.constant_(m.bias, 0) 44 | elif isinstance(m, nn.Linear): 45 | nn.init.normal_(m.weight, 0, 0.01) 46 | nn.init.constant_(m.bias, 0) 47 | 48 | 49 | def make_layers(cfg, batch_norm=False): 50 | layers = [] 51 | in_channels = 3 52 | 
for v in cfg: 53 | if v == 'M': 54 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 55 | else: 56 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 57 | if batch_norm: 58 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 59 | else: 60 | layers += [conv2d, nn.ReLU(inplace=True)] 61 | in_channels = v 62 | return nn.Sequential(*layers) 63 | 64 | 65 | cfgs = { 66 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 67 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 68 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 69 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 70 | } 71 | 72 | 73 | def _vgg(arch, cfg, batch_norm, pretrained, progress, device, **kwargs): 74 | if pretrained: 75 | kwargs['init_weights'] = False 76 | model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs) 77 | if pretrained: 78 | script_dir = os.path.dirname(__file__) 79 | state_dict = torch.load(script_dir + '/state_dicts/'+arch+'.pt', map_location=device) 80 | model.load_state_dict(state_dict) 81 | return model 82 | 83 | 84 | def vgg11(pretrained=False, progress=True, device='cpu', **kwargs): 85 | """VGG 11-layer model (configuration "A") 86 | 87 | Args: 88 | pretrained (bool): If True, returns a model pre-trained on ImageNet 89 | progress (bool): If True, displays a progress bar of the download to stderr 90 | """ 91 | return _vgg('vgg11', 'A', False, pretrained, progress, device, **kwargs) 92 | 93 | 94 | def vgg11_bn(pretrained=False, progress=True, device='cpu', **kwargs): 95 | """VGG 11-layer model (configuration "A") with batch normalization 96 | 97 | Args: 98 | pretrained (bool): If True, returns a model pre-trained on ImageNet 99 | progress (bool): If True, displays a progress bar of the download to stderr 100 | """ 101 | return _vgg('vgg11_bn', 'A', True, pretrained, progress, device, **kwargs) 102 | 103 | 104 | def vgg13(pretrained=False, progress=True, device='cpu', **kwargs): 105 | """VGG 13-layer model (configuration "B") 106 | 107 | Args: 108 | pretrained (bool): If True, returns a model pre-trained on ImageNet 109 | progress (bool): If True, displays a progress bar of the download to stderr 110 | """ 111 | return _vgg('vgg13', 'B', False, pretrained, progress, device, **kwargs) 112 | 113 | 114 | def vgg13_bn(pretrained=False, progress=True, device='cpu', **kwargs): 115 | """VGG 13-layer model (configuration "B") with batch normalization 116 | 117 | Args: 118 | pretrained (bool): If True, returns a model pre-trained on ImageNet 119 | progress (bool): If True, displays a progress bar of the download to stderr 120 | """ 121 | return _vgg('vgg13_bn', 'B', True, pretrained, progress, device, **kwargs) 122 | 123 | 124 | def vgg16(pretrained=False, progress=True, device='cpu', **kwargs): 125 | """VGG 16-layer model (configuration "D") 126 | 127 | Args: 128 | pretrained (bool): If True, returns a model pre-trained on ImageNet 129 | progress (bool): If True, displays a progress bar of the download to stderr 130 | """ 131 | return _vgg('vgg16', 'D', False, pretrained, progress, device, **kwargs) 132 | 133 | 134 | def vgg16_bn(pretrained=False, progress=True, device='cpu', **kwargs): 135 | """VGG 16-layer model (configuration "D") with batch normalization 136 | 137 | Args: 138 | pretrained (bool): If True, returns a model pre-trained on ImageNet 139 | progress (bool): If True, displays a progress bar of the download to stderr 140 | """ 141 | return _vgg('vgg16_bn', 'D', True, 
pretrained, progress, device, **kwargs) 142 | 143 | 144 | def vgg19(pretrained=False, progress=True, device='cpu', **kwargs): 145 | """VGG 19-layer model (configuration "E") 146 | 147 | Args: 148 | pretrained (bool): If True, returns a model pre-trained on ImageNet 149 | progress (bool): If True, displays a progress bar of the download to stderr 150 | """ 151 | return _vgg('vgg19', 'E', False, pretrained, progress, device, **kwargs) 152 | 153 | 154 | def vgg19_bn(pretrained=False, progress=True, device='cpu', **kwargs): 155 | """VGG 19-layer model (configuration "E") with batch normalization 156 | 157 | Args: 158 | pretrained (bool): If True, returns a model pre-trained on ImageNet 159 | progress (bool): If True, displays a progress bar of the download to stderr 160 | """ 161 | return _vgg('vgg19_bn', 'E', True, pretrained, progress, device, **kwargs) 162 | -------------------------------------------------------------------------------- /cifar10_models/cifar10_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytorch_lightning as pl 3 | import torchvision.transforms as transforms 4 | from torchvision.datasets import CIFAR10 5 | from torch.utils.data import DataLoader 6 | from cifar10_models import * 7 | 8 | def get_classifier(classifier, pretrained): 9 | if classifier == 'vgg11_bn': 10 | return vgg11_bn(pretrained=pretrained) 11 | elif classifier == 'vgg13_bn': 12 | return vgg13_bn(pretrained=pretrained) 13 | elif classifier == 'vgg16_bn': 14 | return vgg16_bn(pretrained=pretrained) 15 | elif classifier == 'vgg19_bn': 16 | return vgg19_bn(pretrained=pretrained) 17 | elif classifier == 'resnet18': 18 | return resnet18(pretrained=pretrained) 19 | elif classifier == 'resnet34': 20 | return resnet34(pretrained=pretrained) 21 | elif classifier == 'resnet50': 22 | return resnet50(pretrained=pretrained) 23 | elif classifier == 'densenet121': 24 | return densenet121(pretrained=pretrained) 25 | elif classifier == 'densenet161': 26 | return densenet161(pretrained=pretrained) 27 | elif classifier == 'densenet169': 28 | return densenet169(pretrained=pretrained) 29 | elif classifier == 'mobilenet_v2': 30 | return mobilenet_v2(pretrained=pretrained) 31 | elif classifier == 'googlenet': 32 | return googlenet(pretrained=pretrained) 33 | elif classifier == 'inception_v3': 34 | return inception_v3(pretrained=pretrained) 35 | else: 36 | raise NameError('Please enter a valid classifier') 37 | 38 | class CIFAR10_Module(pl.LightningModule): 39 | def __init__(self, hparams, pretrained=False): 40 | super().__init__() 41 | self.hparams = hparams 42 | self.criterion = torch.nn.CrossEntropyLoss() 43 | self.mean = [0.4914, 0.4822, 0.4465] 44 | self.std = [0.2023, 0.1994, 0.2010] 45 | self.model = get_classifier(hparams.classifier, pretrained) 46 | self.train_size = len(self.train_dataloader().dataset) 47 | self.val_size = len(self.val_dataloader().dataset) 48 | 49 | def forward(self, batch): 50 | images, labels = batch 51 | predictions = self.model(images) 52 | loss = self.criterion(predictions, labels) 53 | accuracy = torch.sum(torch.max(predictions, 1)[1] == labels.data).float() / batch[0].size(0) 54 | return loss, accuracy 55 | 56 | def training_step(self, batch, batch_nb): 57 | loss, accuracy = self.forward(batch) 58 | logs = {'loss/train': loss, 'accuracy/train': accuracy} 59 | return {'loss': loss, 'log': logs} 60 | 61 | def validation_step(self, batch, batch_nb): 62 | avg_loss, accuracy = self.forward(batch) 63 | loss = avg_loss * batch[0].size(0) 64 | corrects = 
accuracy * batch[0].size(0) 65 | logs = {'loss/val': loss, 'corrects': corrects} 66 | return logs 67 | 68 | def validation_epoch_end(self, outputs): 69 | loss = torch.stack([x['loss/val'] for x in outputs]).sum() / self.val_size 70 | accuracy = torch.stack([x['corrects'] for x in outputs]).sum() / self.val_size 71 | logs = {'loss/val': loss, 'accuracy/val': accuracy} 72 | return {'val_loss': loss, 'log': logs} 73 | 74 | def test_step(self, batch, batch_nb): 75 | return self.validation_step(batch, batch_nb) 76 | 77 | def test_epoch_end(self, outputs): 78 | accuracy = self.validation_epoch_end(outputs)['log']['accuracy/val'] 79 | accuracy = round((100 * accuracy).item(), 2) 80 | return {'progress_bar': {'Accuracy': accuracy}} 81 | 82 | def configure_optimizers(self): 83 | optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.learning_rate, 84 | weight_decay=self.hparams.weight_decay, momentum=0.9, nesterov=True) 85 | 86 | scheduler = {'scheduler': torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=self.hparams.learning_rate, 87 | steps_per_epoch=self.train_size//self.hparams.batch_size, 88 | epochs=self.hparams.max_epochs), 89 | 'interval': 'step', 'name': 'learning_rate'} 90 | return [optimizer], [scheduler] 91 | 92 | def train_dataloader(self): 93 | transform_train = transforms.Compose([transforms.RandomCrop(32, padding=4), 94 | transforms.RandomHorizontalFlip(), 95 | transforms.ToTensor(), 96 | transforms.Normalize(self.mean, self.std)]) 97 | dataset = CIFAR10(root=self.hparams.data_dir, train=True, transform=transform_train) 98 | dataloader = DataLoader(dataset, batch_size=self.hparams.batch_size, num_workers=4, shuffle=True, drop_last=True, pin_memory=True) 99 | return dataloader 100 | 101 | def val_dataloader(self): 102 | transform_val = transforms.Compose([transforms.ToTensor(), 103 | transforms.Normalize(self.mean, self.std)]) 104 | dataset = CIFAR10(root=self.hparams.data_dir, train=False, transform=transform_val) 105 | dataloader = DataLoader(dataset, batch_size=self.hparams.batch_size, num_workers=4, pin_memory=True) 106 | return dataloader 107 | 108 | def test_dataloader(self): 109 | return self.val_dataloader() -------------------------------------------------------------------------------- /cifar10_models/cifar10_test.py: -------------------------------------------------------------------------------- 1 | import os, shutil 2 | import torch 3 | from argparse import ArgumentParser 4 | from pytorch_lightning import Trainer 5 | from cifar10_module import CIFAR10_Module 6 | 7 | def main(hparams): 8 | # If only train on 1 GPU. Must set_device otherwise PyTorch always store model on GPU 0 first 9 | if type(hparams.gpus) == str: 10 | if len(hparams.gpus) == 2: # GPU number and comma e.g. 
'0,' or '1,' 11 | torch.cuda.set_device(int(hparams.gpus[0])) 12 | 13 | model = CIFAR10_Module(hparams, pretrained=True) 14 | trainer = Trainer(gpus=hparams.gpus, default_save_path=os.path.join(os.getcwd(), 'test_temp')) 15 | trainer.test(model) 16 | shutil.rmtree(os.path.join(os.getcwd(), 'test_temp')) 17 | 18 | if __name__ == '__main__': 19 | parser = ArgumentParser() 20 | parser.add_argument('--classifier', type=str, default='resnet18') 21 | parser.add_argument('--data_dir', type=str, default='/data/huy/cifar10/') 22 | parser.add_argument('--gpus', default='0,') 23 | parser.add_argument('--max_epochs', type=int, default=100) 24 | parser.add_argument('--batch_size', type=int, default=256) 25 | parser.add_argument('--learning_rate', type=float, default=1e-2) 26 | parser.add_argument('--weight_decay', type=float, default=1e-2) 27 | args = parser.parse_args() 28 | main(args) -------------------------------------------------------------------------------- /cifar10_models/cifar10_train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from argparse import ArgumentParser 4 | from pytorch_lightning import Trainer, seed_everything 5 | from pytorch_lightning.callbacks import LearningRateLogger 6 | from pytorch_lightning.loggers import TensorBoardLogger 7 | from cifar10_module import CIFAR10_Module 8 | 9 | def main(hparams): 10 | 11 | seed_everything(0) 12 | 13 | # If only train on 1 GPU. Must set_device otherwise PyTorch always store model on GPU 0 first 14 | if type(hparams.gpus) == str: 15 | if len(hparams.gpus) == 2: # GPU number and comma e.g. '0,' or '1,' 16 | torch.cuda.set_device(int(hparams.gpus[0])) 17 | 18 | # Model 19 | classifier = CIFAR10_Module(hparams) 20 | 21 | # Trainer 22 | lr_logger = LearningRateLogger() 23 | logger = TensorBoardLogger("logs", name=hparams.classifier) 24 | trainer = Trainer(callbacks=[lr_logger], gpus=hparams.gpus, max_epochs=hparams.max_epochs, 25 | deterministic=True, early_stop_callback=False, logger=logger) 26 | trainer.fit(classifier) 27 | 28 | # Load best checkpoint 29 | checkpoint_path = os.path.join(os.getcwd(), 'logs', hparams.classifier, 'version_' + str(classifier.logger.version),'checkpoints') 30 | classifier = CIFAR10_Module.load_from_checkpoint(os.path.join(checkpoint_path, os.listdir(checkpoint_path)[0])) 31 | 32 | # Save weights from checkpoint 33 | statedict_path = os.path.join(os.getcwd(), 'cifar10_models', 'state_dicts', hparams.classifier + '.pt') 34 | torch.save(classifier.model.state_dict(), statedict_path) 35 | 36 | # Test model 37 | trainer.test(classifier) 38 | 39 | if __name__ == '__main__': 40 | parser = ArgumentParser() 41 | parser.add_argument('--classifier', type=str, default='resnet18') 42 | parser.add_argument('--data_dir', type=str, default='/data/huy/cifar10/') 43 | parser.add_argument('--gpus', default='0,') # use None to train on CPU 44 | parser.add_argument('--batch_size', type=int, default=256) 45 | parser.add_argument('--max_epochs', type=int, default=100) 46 | parser.add_argument('--learning_rate', type=float, default=1e-2) 47 | parser.add_argument('--weight_decay', type=float, default=1e-2) 48 | args = parser.parse_args() 49 | main(args) -------------------------------------------------------------------------------- /imports.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import random 3 | import os, copy, pickle, time 4 | import argparse 5 | import itertools 6 | from collections import 
defaultdict, Counter, OrderedDict 7 | import matplotlib.pyplot as plt 8 | 9 | import numpy as np 10 | import seaborn as sns 11 | import torch 12 | import torchvision 13 | from torch import optim, nn 14 | import torch.nn.functional as F 15 | from torch.utils.data import TensorDataset, DataLoader 16 | from torch.autograd import Variable 17 | import pandas as pd 18 | 19 | sys.path.append('scripts') 20 | import lms_utils 21 | import data_utils as du 22 | import gpu_utils as gu 23 | import ptb_utils as pu 24 | import synth_models 25 | import gendata 26 | import utils 27 | import synth_models as sm 28 | import mnistcifar_utils as mc_utils 29 | import ensemble 30 | 31 | torch.backends.cudnn.benchmark = True 32 | torch.backends.cudnn.enabled = True 33 | 34 | def get_data(**c): 35 | smargin = c['lin_margin'] if c['same_margin'] else c['slab_margin'] 36 | data_func = gendata.generate_ub_linslab_data_v2 37 | spc = [3]*c['num_slabs3']+[5]*c['num_slabs'] + [7]*c['num_slabs7'] 38 | data = data_func(c['num_train'], c['dim'], c['lin_margin'], slabs_per_coord=spc, eff_slab_margin=smargin, random_transform=c['random_transform'], N_te=c['num_test'], 39 | corrupt_lin_margin=c['corrupt_lin_margin'], num_lin=c['num_lin'], num_slabs=c['num_slabs3']+c['num_slabs']+c['num_slabs7'], width=c['width'], bs=c['bs'], 40 | corrupt_lin=c['corrupt_lin'], corrupt_slab=c['corrupt_slab'], corrupt_slab7=c['corrupt_slab7']) 41 | return data 42 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | seaborn==0.9.0 2 | torchvision==0.3.0 3 | matplotlib==3.0.3 4 | torch==1.1.0.post2 5 | requests==2.21.0 6 | tqdm==4.31.1 7 | scipy==1.2.1 8 | numpy==1.19.2 9 | pandas==0.24.2 10 | pytorch_lightning==0.3.6.9 11 | dill==0.3.2 12 | pycuda==2020.1 13 | scikit_learn==0.23.2 14 | -------------------------------------------------------------------------------- /scripts/data_utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import random, os, copy, pickle, time, random, argparse, itertools 4 | from collections import defaultdict, Counter, OrderedDict 5 | import numpy as np 6 | import torch 7 | import torchvision 8 | from torch import optim, nn 9 | import torch.nn.functional as F 10 | from sklearn import metrics 11 | from torch.utils.data import TensorDataset, DataLoader 12 | 13 | import gpu_utils as gu 14 | import lms_utils as au 15 | import synth_models 16 | import utils 17 | import matplotlib.pyplot as plt 18 | import pathlib 19 | 20 | try: 21 | sys.path.append('../../cifar10_models/') 22 | import cifar10_models as c10 23 | c10_not_found = False 24 | except: 25 | c10_not_found = True 26 | 27 | torch.backends.cudnn.benchmark = True 28 | torch.backends.cudnn.enabled = True 29 | 30 | REPO_DIR = pathlib.Path(__file__).parent.parent.absolute() 31 | DOWNLOAD_DIR = os.path.join(REPO_DIR, 'datasets') 32 | 33 | def msd(x, r=3): 34 | return np.round(np.mean(x), r), np.round(np.std(x), r) 35 | 36 | def _get_dataloaders(trd, ted, bs, pm=True, shuffle=True): 37 | train_dl = DataLoader(trd, batch_size=bs, shuffle=shuffle, pin_memory=pm) 38 | test_dl = DataLoader(ted, batch_size=bs, pin_memory=pm) 39 | return train_dl, test_dl 40 | 41 | def get_cifar10_models(device=None, pretrained=True): 42 | if c10_not_found: return {} 43 | device = gu.get_device(None) if device is None else device 44 | get_lmbda = lambda cls: (lambda: 
cls(pretrained=pretrained).eval().to(device)) 45 | return { 46 | 'vgg11_bn': get_lmbda(c10.vgg11_bn), 47 | 'vgg13_bn': get_lmbda(c10.vgg13_bn), 48 | 'vgg16_bn': get_lmbda(c10.vgg16_bn), 49 | 'vgg19_bn': get_lmbda(c10.vgg19_bn), 50 | 'resnet18': get_lmbda(c10.resnet18), 51 | 'resnet34': get_lmbda(c10.resnet34), 52 | 'resnet50': get_lmbda(c10.resnet50), 53 | 'densenet121': get_lmbda(c10.densenet121), 54 | 'densenet161': get_lmbda(c10.densenet161), 55 | 'densenet169': get_lmbda(c10.densenet169), 56 | 'mobilenet_v2': get_lmbda(c10.mobilenet_v2), 57 | 'googlenet': get_lmbda(c10.googlenet), 58 | 'inception_v3': get_lmbda(c10.inception_v3) 59 | } 60 | 61 | def plot_decision_boundary(dl, model, c1, c2, ax=None, print_info=True): 62 | if ax is None: fig, ax = plt.subplots(1,1,figsize=(6,4)) 63 | model = model.cpu() 64 | deps = sorted(au.get_feature_deps(dl, model).items(), key=lambda t: t[-1]) 65 | 66 | if print_info: 67 | for k, v in deps: print ('{}:{:.3f}'.format(k,v), end=', ') 68 | print ("") 69 | 70 | X, Y = utils.extract_numpy_from_loader(dl) 71 | K = 100_000 72 | U = np.random.uniform(low=X.min(), high=X.max(), size=(K, X.shape[1])) # copy.deepcopy(X) 73 | U[:, c1] = np.random.uniform(low=X[:, c1].min(), high=X[:, c1].max(), size=K) 74 | U[:, c2] = np.random.uniform(low=X[:, c2].min(), high=X[:, c2].max(), size=K) 75 | U = torch.Tensor(U) 76 | 77 | with torch.no_grad(): 78 | out = model(U) 79 | Yu = torch.argmax(out, 1) 80 | 81 | ax.scatter(U[:,c1], U[:,c2], c=Yu, alpha=0.3, s=24) 82 | ax.scatter(X[:,c1], X[:,c2], c=Y, cmap='coolwarm', s=12) 83 | 84 | def get_binary_datasets(X, Y, y1, y2, image_width=28, use_cnn=False): 85 | assert type(X) is np.ndarray and type(Y) is np.ndarray 86 | idx0 = (Y==y1).nonzero()[0] 87 | idx1 = (Y==y2).nonzero()[0] 88 | idx = np.concatenate((idx0, idx1)) 89 | X_, Y_ = X[idx,:], (Y[idx]==y2).astype(int) 90 | P = np.random.permutation(len(X_)) 91 | X_, Y_ = X_[P,:], Y_[P] 92 | if use_cnn: X_ = X_.reshape(X_.shape[0], -1, image_width)[:, None, :, :] 93 | return X_, Y_ # already shuffled above; indexing with P again would shuffle twice 94 | 95 | def get_binary_loader(dl, y1, y2): 96 | X, Y = utils.extract_numpy_from_loader(dl) 97 | X, Y = get_binary_datasets(X, Y, y1, y2) 98 | return utils._to_dl(X, Y, bs=dl.batch_size) 99 | 100 | def get_mnist(fpath=DOWNLOAD_DIR, flatten=False, binarize=False, normalize=True, y0={0,1,2,3,4}): 101 | """get preprocessed mnist torch.TensorDataset class""" 102 | def _to_torch(d): 103 | X, Y = [], [] 104 | for xb, yb in d: 105 | X.append(xb) 106 | Y.append(yb) 107 | return torch.Tensor(np.stack(X)), torch.LongTensor(np.stack(Y)) 108 | 109 | to_tensor = torchvision.transforms.ToTensor() 110 | to_flat = torchvision.transforms.Lambda(lambda X: X.reshape(-1).squeeze()) 111 | to_norm = torchvision.transforms.Normalize((0.5, ), (0.5, )) 112 | to_binary = torchvision.transforms.Lambda(lambda y: 0 if y in y0 else 1) 113 | 114 | transforms = [to_tensor] 115 | if normalize: transforms.append(to_norm) 116 | if flatten: transforms.append(to_flat) 117 | tf = torchvision.transforms.Compose(transforms) 118 | ttf = to_binary if binarize else None 119 | 120 | X_tr = torchvision.datasets.MNIST(fpath, download=True, transform=tf, target_transform=ttf) 121 | X_te = torchvision.datasets.MNIST(fpath, download=True, train=False, transform=tf, target_transform=ttf) 122 | 123 | return _to_torch(X_tr), _to_torch(X_te) 124 | 125 | def get_mnist_dl(fpath=DOWNLOAD_DIR, to_np=False, bs=128, pm=False, shuffle=False, 126 | normalize=True, flatten=False, binarize=False, y0={0,1,2,3,4}): 127 | (X_tr, Y_tr), (X_te, Y_te) = 
get_mnist(fpath, normalize=normalize, flatten=flatten, binarize=binarize, y0=y0) 128 | tr_dl = DataLoader(TensorDataset(X_tr, Y_tr), batch_size=bs, shuffle=shuffle, pin_memory=pm) 129 | te_dl = DataLoader(TensorDataset(X_te, Y_te), batch_size=bs, pin_memory=pm) 130 | return tr_dl, te_dl 131 | 132 | def get_cifar(fpath=DOWNLOAD_DIR, use_cifar10=False, flatten_data=False, transform_type='none', 133 | means=None, std=None, use_grayscale=False, binarize=False, normalize=True, y0={0,1,2,3,4}): 134 | """get preprocessed cifar torch.Dataset class""" 135 | 136 | if transform_type == 'none': 137 | normalize_cifar = lambda: torchvision.transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]) 138 | tensorize = torchvision.transforms.ToTensor() 139 | to_grayscale = torchvision.transforms.Grayscale() 140 | flatten = torchvision.transforms.Lambda(lambda X: X.reshape(-1).squeeze()) 141 | 142 | transforms = [tensorize] 143 | if use_grayscale: transforms = [to_grayscale] + transforms 144 | if normalize: transforms.append(normalize_cifar()) 145 | if flatten_data: transforms.append(flatten) 146 | tr_transforms = te_transforms = torchvision.transforms.Compose(transforms) 147 | 148 | if transform_type == 'basic': 149 | normalize_cifar = lambda: torchvision.transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]) 150 | 151 | tr_transforms= [ 152 | torchvision.transforms.RandomCrop(32, padding=4), 153 | torchvision.transforms.RandomHorizontalFlip(), 154 | torchvision.transforms.ToTensor() 155 | ] 156 | 157 | te_transforms = [ 158 | torchvision.transforms.Resize(32), 159 | torchvision.transforms.CenterCrop(32), 160 | torchvision.transforms.ToTensor(), 161 | ] 162 | 163 | if normalize: 164 | tr_transforms.append(normalize_cifar()) 165 | te_transforms.append(normalize_cifar()) 166 | 167 | tr_transforms = torchvision.transforms.Compose(tr_transforms) 168 | te_transforms = torchvision.transforms.Compose(te_transforms) 169 | 170 | to_binary = torchvision.transforms.Lambda(lambda y: 0 if y in y0 else 1) 171 | target_transforms = to_binary if binarize else None 172 | dset = 'cifar10' if use_cifar10 else 'cifar100' 173 | func = torchvision.datasets.CIFAR10 if use_cifar10 else torchvision.datasets.CIFAR100 174 | 175 | X_tr = func(fpath, download=True, transform=tr_transforms, target_transform=target_transforms) 176 | X_te = func(fpath, download=True, train=False, transform=te_transforms, target_transform=target_transforms) 177 | 178 | return X_tr, X_te 179 | 180 | def get_cifar_dl(fpath=DOWNLOAD_DIR, use_cifar10=False, bs=128, shuffle=True, transform_type='none', 181 | means=None, std=None, normalize=True, flatten_data=False, use_grayscale=False, nw=4, pm=False, binarize=False, y0={0,1,2,3,4}): 182 | """data in dataloaders have has shape (B, C, W, H)""" 183 | d_tr, d_te = get_cifar(fpath, use_cifar10=use_cifar10, use_grayscale=use_grayscale, transform_type=transform_type, normalize=normalize, means=means, std=std, flatten_data=flatten_data, binarize=binarize, y0=y0) 184 | tr_dl = DataLoader(d_tr, batch_size=bs, shuffle=shuffle, num_workers=nw, pin_memory=pm) 185 | te_dl = DataLoader(d_te, batch_size=bs, num_workers=nw, pin_memory=pm) 186 | return tr_dl, te_dl 187 | 188 | def get_cifar_np(fpath=DOWNLOAD_DIR, use_cifar10=False, flatten_data=False, transform_type='none', normalize=True, binarize=False, y0={0,1,2,3,4}, use_grayscale=False): 189 | """get numpy matrices of preprocessed cifar data""" 190 | 191 | def _to_np(d): 192 | X, Y = [], [] 193 | for xb, yb in d: 194 | X.append(xb) 195 | 
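        # iterating the dataset applies its torchvision transforms sample-by-sample;
        # the np.stack below then yields (N, C, H, W) arrays, or (N, D) if flattened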
Y.append(yb) 196 | return map(np.stack, [X,Y]) 197 | 198 | d_tr, d_te = get_cifar(fpath, use_cifar10=use_cifar10, use_grayscale=use_grayscale, transform_type=transform_type, normalize=normalize, flatten_data=flatten_data, binarize=binarize, y0=y0) 199 | return _to_np(d_tr), _to_np(d_te) 200 | 201 | if __name__ == '__main__': 202 | pass -------------------------------------------------------------------------------- /scripts/ensemble.py: -------------------------------------------------------------------------------- 1 | import os, copy, pickle, time 2 | import random, itertools 3 | from collections import defaultdict, Counter, OrderedDict 4 | import numpy as np 5 | import torch 6 | import pandas as pd 7 | import torchvision 8 | from torch.utils.data import TensorDataset, DataLoader 9 | from torch import optim, nn 10 | import torch.nn.functional as F 11 | import dill 12 | import gpu_utils as gu 13 | import data_utils as du 14 | import synth_models as sm 15 | import utils 16 | 17 | class Ensemble(nn.Module): 18 | 19 | def _get_dummy_classifier(self): 20 | def dummy(x): 21 | return x 22 | return dummy 23 | 24 | def __init__(self, models, num_classes, use_softmax=False): 25 | super(Ensemble, self).__init__() 26 | self.num_classes = num_classes 27 | self.use_softmax = use_softmax 28 | 29 | # register models as pytorch modules 30 | self.models = [] 31 | for idx, m in enumerate(models,1): 32 | setattr(self, 'm{}'.format(idx), m.eval()) 33 | self.models.append(getattr(self, 'm{}'.format(idx))) 34 | 35 | self.classifier = self._get_dummy_classifier() 36 | 37 | def _forward(self, x): 38 | return x 39 | 40 | def forward(self, x): 41 | outs = self._forward(x) 42 | return self.classifier(outs) 43 | 44 | def get_output_loader(self, dl, device=gu.get_device(None), bs=None): 45 | """return dataloader of model output (logit or softmax prob)""" 46 | X, Y = [], [] 47 | with torch.no_grad(): 48 | for xb, yb in dl: 49 | xb = xb.to(device) 50 | out = self._forward(xb).cpu() 51 | X.append(out) 52 | Y.append(yb) 53 | X, Y = torch.cat(X), torch.cat(Y) 54 | return DataLoader(TensorDataset(X, Y), batch_size=bs or dl.batch_size) 55 | 56 | def fit_classifier(self, tr_dl, te_dl, lr=0.05, adam=False, wd=5e-5, device=None, **fit_kw): 57 | device = gu.get_device(None) if device is None else device 58 | self = self.to(device) 59 | 60 | c = dict(gap=1000, epsilon=1e-2, wd=5e-5, is_loss_epsilon=True) 61 | c.update(**fit_kw) 62 | 63 | tro_dl = self.get_output_loader(tr_dl, device) 64 | teo_dl = self.get_output_loader(te_dl, device) 65 | 66 | if adam: opt = optim.Adam(self.classifier.parameters()) 67 | else: opt = optim.SGD(self.classifier.parameters(), lr=lr, weight_decay=wd) 68 | stats = utils.fit_model(self.classifier, F.cross_entropy, opt, tro_dl, teo_dl, device=device, **c) 69 | 70 | self.classifier = stats['best_model'][-1].to(device) 71 | self = self.cpu() 72 | return stats 73 | 74 | class EnsembleLinear(Ensemble): 75 | 76 | def _get_classifier(self): 77 | # linear with equal weights and zero bias 78 | nl = self.num_classes*len(self.models) 79 | linear = nn.Linear(nl, self.num_classes, bias=self.use_bias) 80 | nn.init.ones_(linear.weight.data) 81 | linear.weight.data /= float(nl) 82 | if self.use_bias: linear.bias.data.zero_() 83 | return linear 84 | 85 | def __init__(self, models, num_classes=2, use_softmax=False, use_bias=True): 86 | super(EnsembleLinear, self).__init__(models, num_classes, use_softmax) 87 | self.use_bias = use_bias 88 | self.classifier = self._get_classifier() 89 | 90 | def _forward(self, x): 91 | 
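        # stack per-model logits (optionally softmax-normalized) and flatten them into
        # one (batch, num_models * num_classes) feature vector for the linear combiner.
        # Usage sketch (m1, m2 are hypothetical pre-trained models; illustration only):
        #   ens = EnsembleLinear([m1, m2], num_classes=2, use_softmax=True)
        #   stats = ens.fit_classifier(tr_dl, te_dl, lr=0.05)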
outs = [m(x) for m in self.models] 92 | if self.use_softmax: outs = [F.softmax(o, dim=1) for o in outs] 93 | outs = torch.stack(outs, dim=2) 94 | outs = outs.reshape(outs.shape[0], -1) 95 | return outs 96 | 97 | class EnsembleMLP(Ensemble): 98 | 99 | def _get_classifier(self): 100 | nl = self.num_classes*len(self.models) 101 | fcn = sm.get_fcn(nl, self.hdim or nl, self.num_classes, hl=self.hl) 102 | return fcn 103 | 104 | def __init__(self, models, num_classes=2, use_softmax=False, hdim=None, hl=1): 105 | super(EnsembleMLP, self).__init__(models, num_classes, use_softmax) 106 | self.hdim = hdim 107 | self.hl = hl 108 | self.classifier = self._get_classifier() 109 | 110 | def _forward(self, x): 111 | outs = [m(x) for m in self.models] 112 | if self.use_softmax: outs = [F.softmax(o, dim=1) for o in outs] 113 | outs = torch.stack(outs, dim=2) 114 | outs = outs.reshape(outs.shape[0], -1) 115 | return outs 116 | 117 | class EnsembleAverage(Ensemble): 118 | 119 | def __init__(self, models, num_classes=2, use_softmax=False): 120 | super(EnsembleAverage, self).__init__(models, num_classes, use_softmax) 121 | self.classifier = self._get_dummy_classifier() 122 | 123 | def _forward(self, x): 124 | outs = [m(x) for m in self.models] 125 | if self.use_softmax: outs = [F.softmax(o, dim=1) for o in outs] 126 | outs = torch.stack(outs) 127 | return outs.mean(dim=0) 128 | 129 | def fit_classifier(self, *args, **kwargs): 130 | return None 131 | -------------------------------------------------------------------------------- /scripts/gendata.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.stats as scs 3 | import random 4 | from collections import Counter 5 | import torch 6 | from torch.utils.data import TensorDataset, DataLoader 7 | import utils 8 | import gpu_utils as gu 9 | 10 | def _prep_data(X, Y, N_tr, N_te, bs, nw, pm, w, orth_matrix=None): 11 | X_te, Y_te = torch.Tensor(X[:N_te,:]), torch.Tensor(Y[:N_te]) 12 | X_tr, Y_tr = torch.Tensor(X[N_te:,:]), torch.Tensor(Y[N_te:]) 13 | Y_te, Y_tr = map(lambda Z: Z.long(), [Y_te, Y_tr]) 14 | 15 | tr_dl = DataLoader(TensorDataset(X_tr, Y_tr), batch_size=bs, num_workers=nw, pin_memory=pm, shuffle=True) 16 | te_dl = DataLoader(TensorDataset(X_te, Y_te), batch_size=bs, num_workers=nw, pin_memory=pm, shuffle=False) 17 | 18 | return { 19 | 'X': torch.tensor(X).float(), 20 | 'Y': torch.tensor(Y).long(), 21 | 'w': w, 22 | 'tr_dl': tr_dl, 23 | 'te_dl': te_dl, 24 | 'N': (N_tr, N_te), 25 | 'W': orth_matrix 26 | } 27 | 28 | def _get_random_data(N, dim, scale): 29 | X = np.random.uniform(size=(N, dim)) 30 | X *= scale 31 | Y = np.random.choice([0,1], size=N) 32 | return X, Y 33 | 34 | def generate_linsep_data_v2(N_tr, dim, eff_margin, width=10., bs=256, scale_noise=True, pm=True, nw=0, no_width=False, N_te=5000): # no unif_max. 35 | assert eff_margin < 1, "equal range constraint" 36 | margin = eff_margin if no_width else eff_margin*width 37 | 38 | N = N_tr + N_te 39 | w = np.zeros(shape=dim) 40 | w[0] = 1 41 | 42 | X, Y = _get_random_data(N, dim, width if scale_noise else 1.) 
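    # make coordinate 0 linearly separable: x_0 = (2y - 1) * (margin + (width - margin) * u)
    # with u ~ U[0, 1], so each class sits at least `margin` away from the hyperplane
    # x_0 = 0 (a total separation of 2 * margin); all other coordinates remain noise.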
43 | 44 | U = np.random.uniform(size=N) 45 | if no_width: X[:,0] = (2*Y-1)*margin 46 | else: X[:, 0] = (2*Y-1)*(margin + (width-margin)*U) 47 | 48 | P = np.random.permutation(X.shape[0]) 49 | X, Y = X[P,:], Y[P] 50 | 51 | return _prep_data(X, Y, N_tr, N_te, bs, nw, pm, w) 52 | 53 | def sample_from_unif_union_of_unifs(unifs, size): 54 | x = [] 55 | choices = Counter(np.random.choice(list(range(len(unifs))), size=size)) 56 | for choice, sz in choices.items(): 57 | s = np.random.uniform(low=unifs[choice][0], high=unifs[choice][1], size=sz) 58 | x.append(s) 59 | x = np.concatenate(x) 60 | return x 61 | 62 | def generate_ub_linslab_data_diffmargin_v2(N_tr, dim, eff_lin_margins, eff_slab_margins, 63 | slabs_per_coord, slab_p_vals, corrupt_lin=0., corrupt_slab=0., 64 | corrupt_slab7=0., scale_noise=True, width=10., lin_coord=0, lin_shift=0., 65 | slab_shift=0., indep_slabs=True, bs=256, pm=True, nw=0, N_te=10000, 66 | random_transform=False, corrupt_lin_margin=False, corrupt_5slab_margin=False): 67 | get_unif = lambda a: np.random.uniform(size=a) 68 | get_bool = lambda a: np.random.choice([0,1], size=a) 69 | get_sign = lambda a: 2*get_bool(a)-1. 70 | 71 | def get_slab_width(NS, B, SM): 72 | if NS==3: return (2.*B-4.*SM)/3. 73 | if NS==5: return (2.*B-8.*SM)/5. 74 | if NS==7: return (2.*B-12.*SM)/7. 75 | return None 76 | 77 | num_lin, num_slabs = map(len, [eff_lin_margins, eff_slab_margins]) 78 | assert 0 <= corrupt_lin <= 1, "input is probability" 79 | assert num_lin + num_slabs <= dim, "dim constraint, num_lin: {}, num_slabs: {}, dim: {}".format(num_lin, num_slabs, dim) 80 | for elm in eff_lin_margins: assert 0 < elm < 1, "equal range constraint (0 < eff_lin_margin={} < 1)".format(elm) 81 | for esm in eff_slab_margins: assert 0 < esm < 1, "equal range constraint (0 < eff_slab_margin={} < 0.25)".format(esm) 82 | 83 | lin_margins = list(map(lambda x: x*width, eff_lin_margins)) 84 | slab_margins = list(map(lambda x: x*width, eff_slab_margins)) 85 | 86 | # hyperplane 87 | N = N_tr + N_te 88 | half_N = N//2 89 | w = np.zeros(shape=dim); w[0] = 1 90 | 91 | X, Y = _get_random_data(N, dim, width if scale_noise else 1.) 
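    # layout: coordinates [0, num_lin) carry the linearly separable feature, the next
    # num_slabs coordinates carry 3/5/7-slab features, and the remaining coordinates
    # stay uniform noise; the corrupt_* arguments later move a matching fraction of
    # points to the wrong side (or into the margin) of the corresponding feature.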
92 | nrange = list(range(N)) 93 | # linear 94 | total_corrupt = int(round(N*corrupt_lin)) 95 | no_linear = num_lin == 0 96 | if not no_linear: 97 | for coord, lin_margin in enumerate(lin_margins): 98 | if indep_slabs: 99 | P = np.random.permutation(N) 100 | X, Y = X[P, :], Y[P] 101 | X[:, coord] = (2*Y-1)*(lin_margin+(width-lin_margin)*get_unif(N)) + lin_shift*width 102 | 103 | # corrupt linear coordinate 104 | if total_corrupt > 0: 105 | corrupt_sample = np.random.choice(nrange, size=total_corrupt, replace=False) 106 | if corrupt_lin_margin: 107 | X[corrupt_sample, 0] = np.random.uniform(low=-lin_margin, high=lin_margin, size=total_corrupt) 108 | else: 109 | X[corrupt_sample, 0] *= -1 110 | 111 | # slabs 112 | i = (num_lin)*int(not no_linear) 113 | for idx, coord in enumerate(range(i, i+num_slabs)): 114 | slab_per = slabs_per_coord[idx] 115 | assert slab_per in [3, 5, 7], "Invalid slabs_per_coord" 116 | 117 | slab_pval = slab_p_vals[idx] 118 | slab_margin = slab_margins[idx] 119 | slab_width = get_slab_width(slab_per, width, slab_margin) 120 | 121 | if indep_slabs: 122 | P = np.random.permutation(N) 123 | X, Y = X[P, :], Y[P] 124 | 125 | if slab_per == 3: 126 | # positive slabs 127 | idx_p = (Y==1).nonzero()[0] 128 | offset = 0.5*slab_width + 2*slab_margin 129 | X[idx_p, coord] = get_sign(len(idx_p))*(offset+slab_width*get_unif(len(idx_p))) 130 | 131 | # negative center 132 | idx_n = (Y==0).nonzero()[0] 133 | X[idx_n, coord] = 0.5*get_sign(len(idx_n))*slab_width*get_unif(len(idx_n)) 134 | 135 | if slab_per == 5: 136 | # positive slabs 137 | idx_p = (Y==1).nonzero()[0] 138 | offset = (width+6*slab_margin)/5. 139 | X[idx_p, coord] = get_sign(len(idx_p))*(offset+slab_width*get_unif(len(idx_p))) 140 | 141 | # negative slabs partitioned using p val 142 | idx_n = (Y==0).nonzero()[0] 143 | in_ctr = np.random.choice([0,1], p=[1-slab_pval, slab_pval], size=len(idx_n)) 144 | idx_nc, idx_ns = idx_n[(in_ctr==1)], idx_n[(in_ctr==0)] 145 | 146 | # negative center 147 | X[idx_nc, coord] = 0.5*get_sign(len(idx_nc))*slab_width*get_unif(len(idx_nc)) 148 | 149 | # negative sides 150 | offset = (8*slab_margin+3*width)/5. 151 | X[idx_ns, coord] = get_sign(len(idx_ns))*(offset+slab_width*get_unif(len(idx_ns))) 152 | 153 | # corrupt slab 5 154 | total_corrupt = int(round(N*corrupt_slab)) 155 | if total_corrupt > 0: 156 | if corrupt_5slab_margin: 157 | offset1 = (width+6*slab_margin)/5. 158 | offset2 = (8*slab_margin+3*width)/5. 
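                    # resample the corrupted points uniformly from the four margin gaps
                    # between adjacent slabs (listed below for both half-axes), instead of
                    # flipping them into the opposite class's slabs as in the else-branch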
159 | unifs = [ 160 | (0.5*slab_width, offset1), 161 | (offset1+slab_width, offset2), 162 | (-offset1, -0.5*slab_width), 163 | (-offset2, -(offset1+slab_width)) 164 | ] 165 | 166 | idx = np.random.choice(range(N), size=total_corrupt, replace=False) 167 | X[idx, coord] = sample_from_unif_union_of_unifs(unifs, total_corrupt) 168 | else: 169 | # get corrupt sample 170 | idx = np.random.choice(range(N), size=total_corrupt, replace=False) 171 | idx_p = idx[np.argwhere((Y[idx]==1))].reshape(-1) 172 | idx_n = idx[np.argwhere((Y[idx]==0))].reshape(-1) 173 | 174 | # move negative points to random positive slabs 175 | offset = (0.5*slab_width+2*slab_margin) 176 | X[idx_n, coord] = torch.Tensor(get_sign(len(idx_n))*(offset+slab_width*get_unif(len(idx_n)))) 177 | 178 | # pick negative slab for each positive point 179 | mv_to_ctr = np.random.choice([0, 1], size=len(idx_p)) 180 | idx_p_ctr = idx_p[np.argwhere(mv_to_ctr==1)].reshape(-1) 181 | idx_p_sid = idx_p[np.argwhere(mv_to_ctr==0)].reshape(-1) 182 | 183 | # move positive points to the negative center slab 184 | X[idx_p_ctr, coord] = torch.Tensor(0.5*get_sign(len(idx_p_ctr))*slab_width*get_unif(len(idx_p_ctr))) 185 | 186 | # move positive points to the negative side slabs 187 | offset = 1.5*slab_width + 4*slab_margin 188 | X[idx_p_sid, coord] = torch.Tensor(get_sign(len(idx_p_sid))*(offset+slab_width*get_unif(len(idx_p_sid)))) 189 | 190 | if slab_per == 7: 191 | # positive slabs 192 | idx_p = (Y==1).nonzero()[0] 193 | in_s0 = np.random.choice([0,1], p=[1-slab_pval, slab_pval], size=len(idx_p)) 194 | idx_p0, idx_p1 = idx_p[(in_s0==1)], idx_p[(in_s0==0)] 195 | 196 | # positive slab 0 (inner) 197 | offset = 0.5*slab_width+2*slab_margin 198 | X[idx_p0, coord] = get_sign(len(idx_p0))*(offset+slab_width*get_unif(len(idx_p0))) 199 | 200 | # positive slab 1 (outer) 201 | offset = 2.5*slab_width+6*slab_margin 202 | X[idx_p1, coord] = get_sign(len(idx_p1))*(offset+slab_width*get_unif(len(idx_p1))) 203 | 204 | # negative slabs 205 | idx_n = (Y==0).nonzero()[0] 206 | in_s0 = get_bool(len(idx_n)) 207 | idx_n0, idx_n1 = idx_n[(in_s0==1)], idx_n[(in_s0==0)] 208 | 209 | # negative slab 0 (center) 210 | X[idx_n0, coord] = 0.5*get_sign(len(idx_n0))*slab_width*get_unif(len(idx_n0)) 211 | 212 | # negative slab 1 (outer) 213 | offset = 1.5*slab_width+4*slab_margin 214 | X[idx_n1, coord] = get_sign(len(idx_n1))*(offset+slab_width*get_unif(len(idx_n1))) 215 | 216 | # corrupt slab7 217 | total_corrupt = int(round(N*corrupt_slab7)) 218 | if total_corrupt > 0: 219 | # corrupt data 220 | idx = np.random.choice(range(len(X)), size=total_corrupt, replace=False) 221 | idx_p = idx[np.argwhere((Y[idx]==1))].reshape(-1) 222 | idx_n = idx[np.argwhere((Y[idx]==0))].reshape(-1) 223 | 224 | # pick positive slab for each negative point 225 | mv_to_inner = get_bool(len(idx_n)) 226 | idx_n_inner = idx_n[np.argwhere(mv_to_inner==1)].reshape(-1) 227 | idx_n_outer = idx_n[np.argwhere(mv_to_inner==0)].reshape(-1) 228 | 229 | # move negative points into the inner/outer positive slabs 230 | offset = 0.5*slab_width+2*slab_margin 231 | X[idx_n_inner, coord] = torch.Tensor(get_sign(len(idx_n_inner))*(offset+slab_width*get_unif(len(idx_n_inner)))) 232 | offset = 2.5*slab_width+6*slab_margin 233 | X[idx_n_outer, coord] = torch.Tensor(get_sign(len(idx_n_outer))*(offset+slab_width*get_unif(len(idx_n_outer)))) 234 | 235 | # pick negative slab for each positive point 236 | mv_to_ctr = get_bool(len(idx_p)) 237 | idx_p_ctr = idx_p[np.argwhere(mv_to_ctr==1)].reshape(-1) 238 | idx_p_sid = idx_p[np.argwhere(mv_to_ctr==0)].reshape(-1) 239 | 240 | # move positive points to negative slabs, mirroring 
idx_n_inner and outer 241 | X[idx_p_ctr, coord] = torch.Tensor(0.5*get_sign(len(idx_p_ctr))*(slab_width*get_unif(len(idx_p_ctr)))) 242 | offset = 1.5*slab_width+4*slab_margin 243 | X[idx_p_sid, coord] = torch.Tensor(get_sign(len(idx_p_sid))*(offset+slab_width*get_unif(len(idx_p_sid)))) 244 | 245 | # shift 246 | X[:, coord] += slab_shift*width 247 | 248 | # reshuffle 249 | P = np.random.permutation(N) 250 | X, Y = X[P,:], Y[P] 251 | 252 | # lin coord position 253 | if not random_transform and lin_coord != 0: 254 | X[:, [0, lin_coord]] = X[:, [lin_coord, 0]] 255 | 256 | # transform 257 | W = np.eye(dim) 258 | if random_transform: W = utils.get_orthonormal_matrix(dim) 259 | X = X.dot(W) 260 | 261 | return _prep_data(X, Y, N_tr, N_te, bs, nw, pm, w, orth_matrix=W) 262 | 263 | 264 | def generate_ub_linslab_data_v2(N_tr, dim, eff_lin_margin, eff_slab_margin=None, lin_coord=0, 265 | corrupt_lin=0., corrupt_slab=0., corrupt_slab3=0., corrupt_slab7=0., 266 | scale_noise=True, num_lin=1, lin_shift=0., slab_shift=0., random_transform=False, 267 | num_slabs=1, slabs_per_coord=5, width=10., indep_slabs=True, no_linear=False, 268 | bs=256, pm=True, nw=0, N_te=10000, corrupt_lin_margin=False, slab5_pval=3/4., 269 | slab3_pval=1/2., slab7_pval=7/8., corrupt_5slab_margin=False): 270 | slab_p_map = {5: slab5_pval, 7: slab7_pval, 3: slab3_pval} 271 | slabs_per_coord = [slabs_per_coord]*num_slabs if type(slabs_per_coord) is int else slabs_per_coord[:] 272 | for x in slabs_per_coord: assert x in slab_p_map 273 | slab_p_vals = [slab_p_map[x] for x in slabs_per_coord] 274 | lms = [eff_lin_margin]*num_lin 275 | sms = eff_slab_margin if type(eff_slab_margin) is list else [eff_slab_margin]*num_slabs 276 | return generate_ub_linslab_data_diffmargin_v2(N_tr, dim, lms, sms, slabs_per_coord, slab_p_vals, lin_coord=lin_coord, corrupt_slab=corrupt_slab, 277 | corrupt_slab7=corrupt_slab7, corrupt_lin=corrupt_lin, scale_noise=scale_noise, width=width, 278 | lin_shift=lin_shift, slab_shift=slab_shift, random_transform=random_transform, indep_slabs=indep_slabs, 279 | pm=pm, bs=bs, corrupt_lin_margin=corrupt_lin_margin, nw=nw, N_te=N_te, corrupt_5slab_margin=corrupt_5slab_margin) 280 | 281 | 282 | def get_lms_data(**kw): 283 | 284 | c = config = { 285 | 'num_train': 100_000, 286 | 'dim': 20, 287 | 'lin_margin': 0.1, 288 | 'slab_margin': 0.1, 289 | 'same_margin': False, 290 | 'random_transform': False, 291 | 'width': 1, # data width 292 | 'bs': 256, 293 | 'corrupt_lin': 0.0, 294 | 'corrupt_lin_margin': False, 295 | 'corrupt_slab': 0.0, 296 | 'num_test': 2_000, 297 | 'hdim': 200, # model width 298 | 'hl': 2, # model depth 299 | 'device': gu.get_device(0), 300 | 'input_dropout': 0, 301 | 'num_lin': 1, 302 | 'num_slabs': 19, 303 | 'num_slabs7': 0, 304 | 'num_slabs3': 0, 305 | } 306 | 307 | c.update(kw) 308 | 309 | smargin = c['lin_margin'] if c['same_margin'] else c['slab_margin'] 310 | data_func = generate_ub_linslab_data_v2 311 | spc = [3]*c['num_slabs3']+[5]*c['num_slabs'] + [7]*c['num_slabs7'] 312 | data = data_func(c['num_train'], c['dim'], c['lin_margin'], slabs_per_coord=spc, eff_slab_margin=smargin, random_transform=c['random_transform'], N_te=c['num_test'], 313 | corrupt_lin_margin=c['corrupt_lin_margin'], num_lin=c['num_lin'], num_slabs=c['num_slabs3']+c['num_slabs']+c['num_slabs7'], width=c['width'], bs=c['bs'], 314 | corrupt_lin=c['corrupt_lin'], corrupt_slab=c['corrupt_slab']) 315 | return data, c 316 | 317 | -------------------------------------------------------------------------------- /scripts/gpu_utils.py: 
-------------------------------------------------------------------------------- 1 | try: import pycuda.driver as cuda 2 | except: print ("pycuda not available") 3 | 4 | import torch 5 | import sys, os, glob, subprocess 6 | 7 | def get_gpu_info(print_info=True, get_specs=False): 8 | cuda.init() 9 | if get_specs: gpu_specs = cuda.Device(0).get_attributes() # assume same for all (dnnx) 10 | else: gpu_specs = None 11 | 12 | gpu_info = { 13 | 'available': torch.cuda.is_available(), 14 | 'num_devices': torch.cuda.device_count(), 15 | 'devices': set([torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())]), 16 | 'current device id': torch.cuda.current_device(), 17 | 'allocated memory': torch.cuda.memory_allocated(), 18 | 'cached memory': torch.cuda.memory_cached() 19 | } 20 | 21 | if print_info: 22 | for k,v in gpu_info.items(): print ("{}: {}".format(k, v)) 23 | 24 | return gpu_info, gpu_specs 25 | 26 | 27 | def get_device(device_id=None): # None -> cpu 28 | device = 'cuda:{}'.format(device_id) if device_id is not None else 'cpu' 29 | device = torch.device(device if torch.cuda.is_available() and device_id is not None else 'cpu') 30 | return device 31 | 32 | def get_gpu_name(): 33 | try: 34 | out_str = subprocess.run(["nvidia-smi", "--query-gpu=gpu_name", "--format=csv"], stdout=subprocess.PIPE).stdout 35 | out_list = out_str.decode("utf-8").split('\n') 36 | out_list = out_list[1:-1] 37 | return out_list 38 | except Exception as e: 39 | print(e) 40 | 41 | def get_cuda_version(): 42 | """Get CUDA version""" 43 | if sys.platform == 'win32': 44 | raise NotImplementedError("Implement this!") 45 | # This breaks on linux: 46 | #cuda=!ls "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA" 47 | #path = "C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\" + str(cuda[0]) +"\\version.txt" 48 | elif sys.platform == 'linux' or sys.platform == 'darwin': 49 | path = '/usr/local/cuda/version.txt' 50 | else: 51 | raise ValueError("Not in Windows, Linux or Mac") 52 | if os.path.isfile(path): 53 | with open(path, 'r') as f: 54 | data = f.read().replace('\n','') 55 | return data 56 | else: 57 | return "No CUDA in this machine" 58 | 59 | def get_cudnn_version(): 60 | """Get CUDNN version""" 61 | if sys.platform == 'win32': 62 | raise NotImplementedError("Implement this!") 63 | # This breaks on linux: 64 | #cuda=!ls "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA" 65 | #candidates = ["C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\" + str(cuda[0]) +"\\include\\cudnn.h"] 66 | elif sys.platform == 'linux': 67 | candidates = ['/usr/include/x86_64-linux-gnu/cudnn_v[0-99].h', 68 | '/usr/local/cuda/include/cudnn.h', 69 | '/usr/include/cudnn.h'] 70 | elif sys.platform == 'darwin': 71 | candidates = ['/usr/local/cuda/include/cudnn.h', 72 | '/usr/include/cudnn.h'] 73 | else: 74 | raise ValueError("Not in Windows, Linux or Mac") 75 | for c in candidates: 76 | file = glob.glob(c) 77 | if file: break 78 | if file: 79 | with open(file[0], 'r') as f: 80 | version = '' 81 | for line in f: 82 | if "#define CUDNN_MAJOR" in line: 83 | version = line.split()[-1] 84 | if "#define CUDNN_MINOR" in line: 85 | version += '.' + line.split()[-1] 86 | if "#define CUDNN_PATCHLEVEL" in line: 87 | version += '.' 
+ line.split()[-1] 88 | if version: 89 | return version 90 | else: 91 | return "Cannot find CUDNN version" 92 | else: 93 | return "No CUDNN in this machine" 94 | 95 | if __name__=='__main__': 96 | print ('gpu name', get_gpu_name()) 97 | print ('cuda', get_cuda_version()) 98 | print ('cudnn', get_cudnn_version()) 99 | print ('device0', get_device(0)) 100 | print ('available', torch.cuda.is_available()) -------------------------------------------------------------------------------- /scripts/lms_utils.py: -------------------------------------------------------------------------------- 1 | import seaborn as sns 2 | import gpu_utils as gu 3 | import data_utils as du 4 | import utils 5 | import random 6 | import os, copy, pickle, time 7 | import itertools 8 | from collections import defaultdict, Counter, OrderedDict 9 | import matplotlib.pyplot as plt 10 | import numpy as np 11 | import torch 12 | import pandas as pd 13 | #import foolbox 14 | from torch.utils.data import TensorDataset, DataLoader 15 | from torch import optim, nn 16 | import torch.nn.functional as F 17 | from sklearn.metrics import roc_auc_score 18 | 19 | def parse_data(exps=None, root='/', **funcs_kw): 20 | """ 21 | main function (parse data files and run added functions on it) 22 | """ 23 | exps = exps if exps is not None else os.listdir(root) 24 | total = len(exps) 25 | print ("total: {}".format(total)) 26 | parsed = defaultdict(dict) 27 | if total == 0: return parsed 28 | for idx, exp in enumerate(exps): 29 | if (idx+1) % 1 == 0: print (idx+1, end=' ', flush=True) 30 | # load data 31 | fpath = os.path.join(root, exp) 32 | try: 33 | data = torch.load(fpath, map_location=lambda storage, loc: storage) 34 | except: 35 | print ("File {} corrupted, skip.".format(fpath)) 36 | continue 37 | config = data['config'] 38 | 39 | # make exp config key 40 | config['run'] = int(exp.rsplit('.', 1)[0][-1]) 41 | config['fname'] = exp 42 | ckeys = ['exp_name', 'dim', 'num_train', 'lin_margin', 'slab_margin', 'num_slabs', 43 | 'num_slabs', 'width', 'hdim', 'hl', 'linear', 'use_bn', 'run', 'fname', 44 | 'weight_decay', 'dropout'] 45 | ckeys = [c for c in ckeys if c in config] 46 | cvals = [config[k] for k in ckeys] 47 | ckv = tuple(zip(ckeys, cvals)) 48 | 49 | # save config 50 | parsed[ckv]['config'] = config 51 | 52 | # save functions 53 | for func_name, func in funcs_kw.items(): 54 | parsed[ckv][func_name] = func(data) 55 | 56 | return parsed 57 | 58 | def parse_exp_stats(data): 59 | """training summary statistics""" 60 | stats = data['stats'] 61 | s = {} 62 | 63 | # loss + accuracy 64 | for t1, t2 in itertools.product(['acc', 'loss'], ['tr', 'te']): 65 | s['{}_{}'.format(t1,t2)] = stats['{}_{}'.format(t1, t2)][-1] 66 | s['orig_stats'] = stats 67 | s['acc_gap'] = s['acc_tr']-s['acc_te'] 68 | s['loss_gap'] = s['loss_te']-s['loss_tr'] 69 | s['fin_acc_te'] = stats['acc_te'][-1] 70 | s['fin_acc_tr'] = stats['acc_tr'][-1] 71 | 72 | # updates 73 | s['update_gap'] = stats['update_gap'] 74 | s['num_updates'] = stats['num_updates'] 75 | 76 | # effective number of updates 77 | for acc_threshold in [0.96, 0.97, 0.98, 0.99, 1]: 78 | eff = np.argmin(np.abs(np.array(stats['acc_tr'])-acc_threshold))*s['update_gap'] 79 | s['effective_num_updates{}'.format(int(acc_threshold*100))] = eff 80 | 81 | return s 82 | 83 | def parse_exp_model(data): 84 | """model parameter stats""" 85 | depth = data['config']['hl'] 86 | linear = data['config']['linear'] 87 | mtype = data['config'].get('mtype', 'fcn') 88 | if mtype == 'fcn' and depth == 1 and not linear: d = 
parse_exp_depth1_model(data) 89 | elif mtype == 'fcn' and depth == 1 and linear: d = parse_exp_linear_model(data) 90 | else: d = {} 91 | return d 92 | def parse_exp_depth1_model(data): 93 | """cosine + w2""" 94 | device = gu.get_device() 95 | model = data['model'].to(device) 96 | p = W1, b1, w2, b2 = list(map(lambda x: x.detach().numpy(), model.parameters())) 97 | s = {} 98 | s['params'] = p 99 | s['cosine'] = W1[:, 0]/np.linalg.norm(W1, axis=1) 100 | s['l2'] = np.linalg.norm(W1, axis=1) 101 | s['w2'] = w2 102 | s['corr0'] = np.corrcoef(s['cosine'], w2[0, :])[0,1] 103 | s['corr1'] = np.corrcoef(s['cosine'], w2[1, :])[0,1] 104 | s['max_weight_cosine'] = s['cosine'][np.argmax(s['w2'][1,:])] 105 | return s 106 | 107 | def parse_exp_linear_model(data): 108 | """cosine""" 109 | device = gu.get_device() 110 | model = data['model'].to(device) 111 | p = W,b = list(map(lambda x: x.detach().numpy(), model.parameters())) 112 | s = {} 113 | s['cosine0'], s['cosine1'] = W[:, 0]/np.linalg.norm(W, axis=1) 114 | return s 115 | 116 | def parse_exp_data(data, load_X=False): 117 | s = {} 118 | model = data['model'].to(gu.get_device()) 119 | data = data['data'] 120 | X, Y = data['X'], data['Y'] 121 | 122 | if type(X) != np.ndarray: 123 | X = data['X'].detach().cpu() 124 | 125 | if type(Y) != np.ndarray: 126 | Y = data['Y'].detach().cpu() 127 | 128 | s['Y'] = Y 129 | if load_X: s['X'] = X 130 | s['Y_'] = get_yhat(model, X) 131 | s['model'] = model 132 | return s 133 | 134 | def get_yhat(model, data): 135 | if type(data)==np.ndarray: data = torch.Tensor(data) 136 | return torch.argmax(model(data), 1) 137 | 138 | def get_acc(y,yhat): 139 | n = float(len(y)) 140 | return (y==yhat).sum().item()/n 141 | 142 | def parse_and_get_df(root, prefix, files=None, device_id=None, only_load=False, only_linear=False, sample_pct=0.5, load_X=False, use_model_pred=False): 143 | exps = files if files is not None else [f for f in os.listdir(root) if f.startswith(prefix)] 144 | 145 | funcs = { 146 | 'config': lambda d: d['config'], 147 | 'stats': parse_exp_stats, 148 | 'model': parse_exp_model, 149 | 'data': lambda x: parse_exp_data(x, load_X=load_X), 150 | 'random_dep': lambda d: get_feature_deps(d['data']['te_dl'], d['model'], only_linear=only_linear, W=d['data'].get('W', None), dep_type='random', use_model_pred=use_model_pred, print_info=False, sample_pct=sample_pct, device_id=device_id), 151 | 'swap_dep': lambda d: get_feature_deps(d['data']['te_dl'], d['model'], only_linear=only_linear, W=d['data'].get('W', None), dep_type='swap', use_model_pred=use_model_pred, print_info=False, sample_pct=sample_pct, device_id=device_id), 152 | } 153 | 154 | P = parse_data(root=root, exps=exps, **funcs) 155 | if only_load: return P 156 | 157 | D = [] 158 | for idx, (k,v) in enumerate(P.items(),1): 159 | d = OrderedDict() 160 | for a,b in k: d[a] = b 161 | for vk in ['model', 'data', 'stats', 'config']: 162 | for a,b in v[vk].items(): d[a] = b 163 | for vk in ['random_dep', 'swap_dep']: 164 | for coord, dep in v[vk].items(): 165 | d[f'{vk[0]}dep_{coord}'] = dep 166 | D.append(d) 167 | 168 | df = pd.DataFrame(D) 169 | if len(df): df['nd'] = df['num_train']/df['dim'] 170 | return df 171 | 172 | def viz(d, c1, c2, k=80_000, info=True, plot_dm=True, plot_data=True, use_yhat=False, unif_k=False, width=10, title=None, is_binary=False, dep_type='swap', ax=None): 173 | if 'W' not in d['data']: W = np.eye(d['config']['dim']) 174 | else: W = d['data']['W'] 175 | if W is None: W = np.eye(d['config']['dim']) 176 | 177 | z = parse_exp_data(d) 178 | X = d['data']['X'] 
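    # the sample K below renders the decision surface: background points are drawn,
    # their coordinates c1/c2 are overwritten with uniform values spanning the data
    # range, and each point is colored by the model's prediction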
179 | 180 | # visualize un-transformed data... 181 | X_ = np.array(X).dot(W.T) 182 | Y, Y_ = z['Y'], z['Y_'] 183 | model = d['model'].cpu() 184 | D = X.shape[1] 185 | kn = k if unif_k else len(X) 186 | K = torch.Tensor(np.random.uniform(size=(k, D)))*width if unif_k else np.array(X_) 187 | K[:, c1] = torch.Tensor(np.random.uniform(low=min(X_[:,c1]), high=max(X_[:,c1]), size=kn)) 188 | K[:, c2] = torch.Tensor(np.random.uniform(low=min(X_[:,c2]), high=max(X_[:,c2]), size=kn)) 189 | KO = model(torch.Tensor(np.array(K).dot(W))) 190 | if is_binary: KY = (KO > 0).squeeze().numpy() 191 | else: KY = torch.argmax(KO, 1).numpy() 192 | 193 | if info: 194 | deps = get_feature_deps(d['data']['te_dl'], d['model'], W=d['data'].get('W', None), dep_type=dep_type) 195 | for k,v in sorted(deps.items(), reverse=False, key=lambda t: t[-1]): print ('{}:{:.3f}'.format(k,v), end=' ') 196 | print ("\n") 197 | 198 | if ax is None: fig, ax = plt.subplots(1,1,figsize=(6,4)) 199 | 200 | if plot_dm: ax.scatter(K[:, c1], K[:, c2], c=KY, cmap='binary', s=8, alpha=.2) 201 | if plot_data: ax.scatter(X_[:, c1], X_[:, c2], c=Y_ if use_yhat else Y, cmap='coolwarm', s=8, alpha=.4) 202 | 203 | ax.set_xlabel('e_{}'.format(c1)) 204 | ax.set_ylabel('e_{}'.format(c2)) 205 | ax.set_title(title if title else '') 206 | plt.tight_layout() 207 | return ax 208 | 209 | def visualize_boundary(model, data, c1, c2, dim, ax=None, is_binary=False, use_yhat=False, width=1, unif_k=True, k=100_000, print_info=True, dep_type='random'): 210 | agg = {'model': model, 'data': data, 'config': dict(dim=dim)} 211 | return viz(agg, c1, c2, unif_k=unif_k, width=width, dep_type=dep_type, is_binary=is_binary, use_yhat=use_yhat, ax=ax, info=print_info) 212 | 213 | def get_randomized_loader(dl, W, coordinates): 214 | """ 215 | dl: dataloader 216 | W: rotation matrix 217 | coordinates: list of coordinates to randomize 218 | output: randomized dataloader 219 | """ 220 | 221 | def _randomize(X, coords): 222 | p = torch.randperm(len(X)) 223 | for c in coords: X[:, c] = X[p, c] 224 | return X 225 | 226 | # rotate data 227 | X, Y = map(copy.deepcopy, dl.dataset.tensors) 228 | dim = X.shape[1] 229 | if W is None: W = np.eye(dim) 230 | 231 | rt_X = torch.Tensor(X.numpy().dot(W.T)) 232 | rand_rt_X = _randomize(rt_X, coordinates) 233 | rand_X = torch.Tensor(rand_rt_X.numpy().dot(W)) 234 | 235 | return utils._to_dl(rand_X, Y, dl.batch_size) 236 | 237 | 238 | def get_feature_deps(dl, model, W=None, dep_type='random', only_linear=False, coords=None, metric='accuracy', 239 | use_model_pred=False, print_info=False, sample_pct=1.0, device_id=None): 240 | """Compute feature dependencies using randomization or swapping""" 241 | def _randomize(X, Y, coords): 242 | p = torch.randperm(len(X)) 243 | for c in coords: X[:, c] = X[p, c] 244 | return X 245 | 246 | def _swap(X, Y, coords): 247 | idx0, idx1 = map(lambda c: (Y.numpy()==c).nonzero()[0], [0, 1]) 248 | idx0_new = np.random.choice(idx1, size=len(idx0), replace=True) 249 | idx1_new = np.random.choice(idx0, size=len(idx1), replace=True) 250 | for c in coords: X[idx0, c], X[idx1, c] = X[idx0_new, c], X[idx1_new, c] 251 | return X 252 | 253 | def _get_dep_data(X, Y, coords): 254 | return dict(random=_randomize, swap=_swap)[dep_type](X, Y, coords) 255 | 256 | 257 | assert metric in {'accuracy', 'loss', 'auc'} 258 | 259 | # setup data 260 | device = gu.get_device(device_id) 261 | model = model.to(device) 262 | X, Y = map(lambda Z: Z.to(device), dl.dataset.tensors) 263 | Yh = get_yhat(model, X) 264 | dim = X.shape[1] 265 | if W is 
None: W = np.eye(dim) 266 | W = torch.Tensor(W).to(device) 267 | rt_X = torch.mm(X, torch.transpose(W,0,1)) 268 | 269 | # subsample data 270 | n_samp = int(round(sample_pct*len(rt_X))) 271 | perm = torch.randperm(len(rt_X))[:n_samp] 272 | rt_X, Y, Yh = rt_X[perm, :], Y[perm], Yh[perm] 273 | 274 | # compute deps 275 | deps = {} 276 | 277 | dims = list(range(dim)) 278 | if coords is None and not only_linear: coords = dims 279 | if coords is None and only_linear: coords = [0,1] 280 | 281 | for idx, coord in enumerate(coords): 282 | if print_info: print ('{}/{}'.format(idx, len(coords)), end=' ') 283 | rt_X_ = copy.deepcopy(rt_X).to(device) 284 | rt_X_ = _get_dep_data(rt_X_, Y, coord if type(coord) in (list, tuple) else [coord]) 285 | X_ = torch.mm(rt_X_, W) 286 | Ys = get_yhat(model, X_) 287 | 288 | key = tuple(coord) if type(coord) in (list, tuple) else coord 289 | 290 | if metric == 'auc': 291 | L = utils.get_logits_given_tensor(X_, model, device=device, bs=250) 292 | S = L[:,1]-L[:,0] 293 | auc = roc_auc_score(Y.cpu().numpy(), S.cpu().numpy()) 294 | deps[key] = auc 295 | elif metric == 'accuracy': 296 | deps[key] = get_acc(Yh if use_model_pred else Y, Ys) 297 | elif metric == 'loss': 298 | L = utils.get_logits_given_tensor(X_, model, device=device, bs=250) 299 | with torch.no_grad(): 300 | loss_val = F.cross_entropy(L, Y).item() 301 | deps[key] = loss_val 302 | 303 | return deps 304 | 305 | def get_subset_feature_deps(dl, model, coords_set, comb_size, W=None, dep_type='random', sample_pct=0.5, device_id=None, print_info=False): 306 | coords = list(itertools.combinations(coords_set, comb_size)) 307 | return get_feature_deps(dl, model, W=W, dep_type=dep_type, coords=coords, print_info=print_info, sample_pct=sample_pct, device_id=device_id) 308 | -------------------------------------------------------------------------------- /scripts/mnistcifar_utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import os, copy, pickle, time 3 | import itertools 4 | from collections import defaultdict, Counter, OrderedDict 5 | import numpy as np 6 | import torch 7 | from torch.utils.data import TensorDataset, DataLoader 8 | import utils 9 | import gpu_utils as gu 10 | import data_utils as du 11 | 12 | def get_binary_mnist(y1=0, y2=1, apply_padding=True, repeat_channels=True): 13 | 14 | def _make_cifar_compatible(X): 15 | if apply_padding: X = np.stack([np.pad(X[i][0], 2)[None,:] for i in range(len(X))]) # pad 16 | if repeat_channels: X = np.repeat(X, 3, axis=1) # add channels 17 | return X 18 | 19 | binarize = lambda X,Y: du.get_binary_datasets(X, Y, y1=y1, y2=y2) 20 | 21 | tr_dl, te_dl = du.get_mnist_dl(normalize=False) 22 | Xtr, Ytr = binarize(*utils.extract_numpy_from_loader(tr_dl)) 23 | Xte, Yte = binarize(*utils.extract_numpy_from_loader(te_dl)) 24 | Xtr, Xte = map(_make_cifar_compatible, [Xtr, Xte]) 25 | return (Xtr, Ytr), (Xte, Yte) 26 | 27 | def get_binary_cifar(y1=3, y2=5, c={0,1,2,3,4}, use_cifar10=True): 28 | binarize = lambda X,Y: du.get_binary_datasets(X, Y, y1=y1, y2=y2) 29 | binary = False if y1 is not None and y2 is not None else True 30 | if binary: print ("grouping cifar classes") 31 | tr_dl, te_dl = du.get_cifar_dl(use_cifar10=use_cifar10, shuffle=False, normalize=False, binarize=binary, y0=c) 32 | 33 | Xtr, Ytr = binarize(*utils.extract_numpy_from_loader(tr_dl)) 34 | Xte, Yte = binarize(*utils.extract_numpy_from_loader(te_dl)) 35 | return (Xtr, Ytr), (Xte, Yte) 36 | 37 | def combine_datasets(Xm, Ym, Xc, Yc, 
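                     # combines the two sources into one image per example: samples are
                     # paired by binary label (or at random when a randomize_*_block flag
                     # is set, which removes that block's label signal) and concatenated
                     # along the height axis, e.g. an MNIST block on top of a CIFAR block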
randomize_order=False, randomize_first_block=False, randomize_second_block=False): 38 | """combine two datasets""" 39 | 40 | def partition(X, Y, randomize=False): 41 | """partition randomly or using labels""" 42 | if randomize: 43 | n = len(Y) 44 | p = np.random.permutation(n) 45 | ni, pi = p[:n//2], p[n//2:] 46 | else: 47 | ni, pi = (Y==0).nonzero()[0], (Y==1).nonzero()[0] 48 | return X[pi], X[ni] 49 | 50 | def _combine(X1, X2): 51 | """concatenate images from two sources""" 52 | X = [] 53 | for i in range(min(len(X1), len(X2))): 54 | x1, x2 = X1[i], X2[i] 55 | # randomize order 56 | if randomize_order and random.random() < 0.5: 57 | x1, x2 = x2, x1 58 | x = np.concatenate((x1,x2), axis=1) 59 | X.append(x) 60 | return np.stack(X) 61 | 62 | Xmp, Xmn = partition(Xm, Ym, randomize=randomize_first_block) 63 | Xcp, Xcn = partition(Xc, Yc, randomize=randomize_second_block) 64 | n = min(map(len, [Xmp, Xmn, Xcp, Xcn])) 65 | Xmp, Xmn, Xcp, Xcn = map(lambda Z: Z[:n], [Xmp, Xmn, Xcp, Xcn]) 66 | 67 | Xp = _combine(Xmp, Xcp) 68 | Yp = np.ones(len(Xp)) 69 | 70 | Xn = _combine(Xmn, Xcn) 71 | Yn = np.zeros(len(Xn)) 72 | 73 | X = np.concatenate([Xp, Xn], axis=0) 74 | Y = np.concatenate([Yp, Yn], axis=0) 75 | P = np.random.permutation(len(X)) 76 | X, Y = X[P], Y[P] 77 | return X, Y 78 | 79 | def get_mnist_cifar(mnist_classes=(0,1), cifar_classes=None, c={0,1,2,3,4}, 80 | randomize_mnist=False, randomize_cifar=False): 81 | 82 | y1, y2 = mnist_classes 83 | (Xtrm, Ytrm), (Xtem, Ytem) = get_binary_mnist(y1=y1, y2=y2) 84 | 85 | y1, y2 = (None, None) if cifar_classes is None else cifar_classes 86 | (Xtrc, Ytrc), (Xtec, Ytec) = get_binary_cifar(c=c, y1=y1, y2=y2) 87 | 88 | Xtr, Ytr = combine_datasets(Xtrm, Ytrm, Xtrc, Ytrc, randomize_first_block=randomize_mnist, randomize_second_block=randomize_cifar) 89 | Xte, Yte = combine_datasets(Xtem, Ytem, Xtec, Ytec, randomize_first_block=randomize_mnist, randomize_second_block=randomize_cifar) 90 | return (Xtr, Ytr), (Xte, Yte) 91 | 92 | def get_mnist_cifar_dl(mnist_classes=(0,1), cifar_classes=None, c={0,1,2,3,4}, bs=256, 93 | randomize_mnist=False, randomize_cifar=False): 94 | (Xtr, Ytr), (Xte, Yte) = get_mnist_cifar(mnist_classes=mnist_classes, cifar_classes=cifar_classes, 95 | c=c, randomize_mnist=randomize_mnist, randomize_cifar=randomize_cifar) 96 | tr_dl = utils._to_dl(Xtr, Ytr, bs=bs, shuffle=True) 97 | te_dl = utils._to_dl(Xte, Yte, bs=100, shuffle=False) 98 | return tr_dl, te_dl -------------------------------------------------------------------------------- /scripts/ptb_utils.py: -------------------------------------------------------------------------------- 1 | import seaborn as sns 2 | import utils 3 | import random 4 | import os, copy, pickle, time 5 | import itertools 6 | from collections import defaultdict, Counter, OrderedDict 7 | import matplotlib.pyplot as plt 8 | import numpy as np 9 | import torch 10 | import pandas as pd 11 | from torch.utils.data import TensorDataset, DataLoader 12 | from torch import optim, nn 13 | import torch.nn.functional as F 14 | 15 | import gpu_utils as gu 16 | import data_utils as du 17 | import synth_models 18 | 19 | #import foolbox as fb 20 | #from autoattack import AutoAttack 21 | 22 | # Misc 23 | def get_yhat(model, data): return torch.argmax(model(data), 1) 24 | def get_acc(y,yhat): return (y==yhat).sum().item()/float(len(y)) 25 | 26 | class PGD_Attack(object): 27 | 28 | def __init__(self, eps, lr, num_iter, loss_type, rand_eps=1e-3, 29 | num_classes=2, bounds=(0.,1.), minimal=False, restarts=1, device=None): 30 | 
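        # PGD hyperparameters: eps is the radius of the perturbation ball,
        # lr the per-step size, and num_iter the number of gradient steps per
        # restart; across `restarts` random restarts, the highest-loss delta
        # is kept. With minimal=True, updates stop per example once the label
        # flips and the eps-ball projection is skipped, giving approximately
        # minimal-norm perturbations. Illustrative usage (hypothetical
        # hyperparameters; L2_PGD_Attack is defined below):
        #   attack = L2_PGD_Attack(eps=1.0, lr=0.25, num_iter=40,
        #                          loss_type='untargeted', restarts=2)
        #   delta = attack.perturb(xb, yb, model)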
self.eps = eps 31 | self.lr = lr 32 | self.num_iter = num_iter 33 | self.B = bounds 34 | self.restarts = restarts 35 | self.rand_eps = rand_eps 36 | self.device = device or gu.get_device(None) 37 | self.loss_type = loss_type 38 | self.num_classes = num_classes 39 | self.classes = list(range(self.num_classes)) 40 | self.delta = None 41 | self.minimal = minimal # early stop + no eps 42 | self.project = not self.minimal 43 | self.loss = -np.inf 44 | 45 | def evaluate_attack(self, dl, model): 46 | model = model.to(self.device) 47 | Xa, Ya, Yh, P = [], [], [], [] 48 | 49 | for xb, yb in dl: 50 | xb, yb = xb.to(self.device), yb.to(self.device) 51 | delta = self.perturb(xb, yb, model) 52 | xba = xb+delta 53 | 54 | with torch.no_grad(): 55 | out = model(xba).detach() 56 | yh = torch.argmax(out, dim=1) 57 | xb, yb, yh, xba, delta = xb.cpu(), yb.cpu(), yh.cpu(), xba.cpu(), delta.cpu() 58 | 59 | Ya.append(yb) 60 | Yh.append(yh) 61 | Xa.append(xba) 62 | P.append(delta) 63 | 64 | Xa, Ya, Yh, P = map(torch.cat, [Xa, Ya, Yh, P]) 65 | ta_dl = utils._to_dl(Xa, Ya, dl.batch_size) 66 | acc, loss = utils.compute_loss_and_accuracy_from_dl(ta_dl, model, 67 | F.cross_entropy, 68 | device=self.device) 69 | return { 70 | 'acc': acc.item(), 71 | 'loss': loss.item(), 72 | 'ta_dl': ta_dl, 73 | 'Xa': Xa.numpy(), 74 | 'Ya': Ya.numpy(), 75 | 'Yh': Yh.numpy(), 76 | 'P': P.numpy() 77 | } 78 | 79 | def perturb(self, xb, yb, model, cpu=False): 80 | model, xb, yb = model.to(self.device), xb.to(self.device), yb.to(self.device) 81 | if self.eps == 0: return torch.zeros_like(xb) 82 | 83 | # compute perturbations and track best perturbations 84 | self.loss = -np.inf 85 | max_delta = self._perturb_once(xb, yb, model) 86 | 87 | with torch.no_grad(): 88 | out = model(xb+max_delta) 89 | max_loss = nn.CrossEntropyLoss(reduction='none')(out, yb) 90 | 91 | for _ in range(self.restarts-1): 92 | delta = self._perturb_once(xb, yb, model) 93 | 94 | with torch.no_grad(): 95 | out = model(xb+delta) 96 | all_loss = nn.CrossEntropyLoss(reduction='none')(out, yb) 97 | 98 | loss_flag = all_loss >= max_loss 99 | max_delta[loss_flag] = delta[loss_flag] 100 | max_loss = torch.max(max_loss, all_loss) 101 | 102 | if cpu: max_delta = max_delta.cpu() 103 | return max_delta 104 | 105 | def _perturb_once(self, xb, yb, model, track_scores=False, stop_const=1e-5): 106 | self.delta = self._init_delta(xb, yb) 107 | scores = [] 108 | 109 | # (minimal) mask perturbations if model already misclassifies 110 | for t in range(self.num_iter): 111 | loss, out = self._get_loss(xb, yb, model, get_scores=True) 112 | 113 | if self.minimal: 114 | yh = torch.argmax(out, dim=1).detach() 115 | not_flipped = yh == yb 116 | not_flipped_ratio = not_flipped.sum().item()/float(len(yb)) 117 | else: 118 | not_flipped = None 119 | not_flipped_ratio = 1.0 120 | 121 | # stop if almost all examples in the batch misclassified 122 | if not_flipped_ratio < stop_const: 123 | break 124 | 125 | if track_scores: 126 | scores.append(out.detach().cpu().numpy()) 127 | 128 | # compute loss, update + clamp delta 129 | loss.backward() 130 | self.loss = max(self.loss, loss.item()) 131 | 132 | self.delta = self._update_delta(xb, yb, update_mask=not_flipped) 133 | self.delta = self._clamp_input(xb, yb) 134 | 135 | d = self.delta.detach() 136 | 137 | if track_scores: 138 | scores = np.stack(scores).swapaxes(0, 1) 139 | return d, scores 140 | 141 | return d 142 | 143 | def _init_delta(self, xb, yb): 144 | delta = torch.empty_like(xb) 145 | delta = delta.uniform_(-self.rand_eps, self.rand_eps) 146 | 
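        # delta is the only leaf tensor the attack optimizes; it starts as
        # uniform noise in [-rand_eps, rand_eps] so the first gradient step
        # is informative even where the loss surface is flat around xb.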
delta = delta.to(self.device) 147 | delta.requires_grad = True 148 | return delta 149 | 150 | def _clamp_input(self, xb, yb): 151 | # clamp delta s.t. X+delta in valid input range 152 | self.delta.data = torch.max(self.B[0]-xb, 153 | torch.min(self.B[1]-xb, 154 | self.delta.data)) 155 | return self.delta 156 | 157 | def _get_loss(self, xb, yb, model, get_scores=False): 158 | out = model(xb+self.delta) 159 | 160 | if self.loss_type == 'untargeted': 161 | L = -1*F.cross_entropy(out, yb) 162 | 163 | elif self.loss_type == 'targeted': 164 | L = nn.CrossEntropyLoss()(out, yb) 165 | 166 | elif self.loss_type == 'random_targeted': 167 | rand_yb = torch.randint(low=0, high=self.num_classes, size=(len(yb),), device=self.device) 168 | #rand_yb[rand_yb==yb] = (yb[rand_yb==yb]+1) % self.num_classes 169 | L = nn.CrossEntropyLoss()(out, rand_yb) 170 | 171 | elif self.loss_type == 'plusone_targeted': 172 | next_yb = (yb+1) % self.num_classes 173 | L = nn.CrossEntropyLoss()(out, next_yb) 174 | 175 | elif self.loss_type == 'binary_targeted': 176 | yb_opp = 1-yb 177 | L = nn.CrossEntropyLoss()(out, yb_opp) 178 | 179 | elif self.loss_type == 'binary_hybrid': 180 | yb_opp = 1-yb 181 | L = nn.CrossEntropyLoss()(out, yb_opp) - nn.CrossEntropyLoss()(out, yb) 182 | 183 | else: 184 | assert False, "unknown loss type" 185 | 186 | if get_scores: return L, out 187 | return L 188 | 189 | class L2_PGD_Attack(PGD_Attack): 190 | 191 | OVERFLOW_CONST = 1e-10 192 | 193 | def get_norms(self, X): 194 | nch = len(X.shape) 195 | return X.view(X.shape[0], -1).norm(dim=1)[(...,) + (None,)*(nch-1)] 196 | 197 | def _update_delta(self, xb, yb, update_mask=None): 198 | # normalize gradients 199 | grad = self.delta.grad.detach() 200 | norms = self.get_norms(grad) 201 | grad = grad/(norms+self.OVERFLOW_CONST) # add const to avoid overflow 202 | 203 | # steepest descent 204 | if self.minimal and update_mask is not None: 205 | um = update_mask 206 | self.delta.data[um] = self.delta.data[um] - self.lr*grad[um] 207 | else: 208 | self.delta.data = self.delta.data - self.lr*grad 209 | 210 | # l2 ball projection 211 | if self.project: 212 | delta_norms = self.get_norms(self.delta.data) 213 | self.delta.data = self.eps*self.delta.data / (delta_norms.clamp(min=self.eps)) 214 | 215 | self.delta.grad.zero_() 216 | return self.delta 217 | 218 | def _init_delta(self, xb, yb): 219 | # random vector with L2 norm rand_eps 220 | delta = torch.zeros_like(xb) 221 | delta = delta.uniform_(-self.rand_eps, self.rand_eps) 222 | delta_norm = self.get_norms(delta) 223 | delta = self.rand_eps*delta/(delta_norm+self.OVERFLOW_CONST) 224 | delta = delta.to(self.device) 225 | delta.requires_grad = True 226 | return delta 227 | 228 | class Linf_PGD_Attack(PGD_Attack): 229 | 230 | def _update_delta(self, xb, yb, **kw): 231 | # steepest descent + linf projection (GD) 232 | self.delta.data = self.delta.data - self.lr*(self.delta.grad.detach().sign()) 233 | self.delta.data = self.delta.data.clamp(-self.eps, self.eps) 234 | self.delta.grad.zero_() 235 | return self.delta 236 | 237 | 238 | # UAP methods 239 | 240 | class AS_UAP(object): 241 | """ 242 | UAP method (Algorithm 2) in Universal Adversarial Training paper (https://arxiv.org/abs/1811.11304) 243 | not using clipped version to avoid hyper-parameter tuning 244 | (even with tuning, improvement is marginal) 245 | """ 246 | 247 | def __init__(self, eps, lr, num_iter, shape, num_classes=2, bounds=(0.,1.), 248 | loss_type='untargeted', rand_eps=0., device=None): 249 | self.device = device if device else 
gu.get_device(None) 250 | self.loss_type = loss_type 251 | self.B = bounds 252 | self.rand_eps = rand_eps 253 | self.eps = eps 254 | self.lr = lr 255 | self.num_iter = num_iter 256 | self.num_classes = num_classes 257 | self.classes = list(range(self.num_classes)) 258 | self.shape = shape 259 | self._init_delta() 260 | 261 | @property 262 | def uap(self): 263 | return copy.deepcopy(self.delta.detach().cpu()).numpy() 264 | 265 | def fit(self, dl, model, num_epochs): 266 | # compute uap 267 | model = model.to(self.device) 268 | for t in range(num_epochs): 269 | for xb, yb in dl: 270 | xb, yb = xb.to(self.device), yb.to(self.device) 271 | 272 | # update + project + clamp delta t times 273 | for t in range(self.num_iter): 274 | loss = self._get_loss(xb, yb, model, True) 275 | loss.backward() 276 | self.delta = self._update_delta(xb, yb) 277 | 278 | def evaluate_attack(self, dl, model, **kw): 279 | model = model.to(self.device) 280 | X, Xa, Ya, P = [], [], [], [] 281 | 282 | for xb, yb in dl: 283 | xb, yb = xb.to(self.device), yb.to(self.device) 284 | xba, ptb = self._perturb(xb, yb, False) 285 | X.append(xb.cpu()) 286 | Xa.append(xba.cpu()) 287 | Ya.append(yb.cpu()) 288 | P.append(ptb.cpu()) 289 | 290 | X, Xa, Ya, P = map(torch.cat, [X, Xa, Ya, P]) 291 | ta_dl = utils._to_dl(Xa, Ya, dl.batch_size) 292 | acc_func = utils.compute_loss_and_accuracy_from_dl 293 | acc, loss = acc_func(ta_dl, model, F.cross_entropy, device=self.device) 294 | 295 | return { 296 | 'P': P, 297 | 'X': X, 298 | 'Xa': Xa, 299 | 'Ya': Ya, 300 | 'acc': acc.item(), 301 | 'loss': loss.item(), 302 | 'dl': ta_dl 303 | } 304 | 305 | def _get_loss(self, xb, yb, model, train_mode): 306 | xba, delta = self._perturb(xb, yb, train_mode=train_mode) 307 | out = model(xba) 308 | 309 | if self.loss_type == 'untargeted': 310 | return -1*nn.CrossEntropyLoss()(out, yb) 311 | 312 | elif self.loss_type == 'targeted': 313 | return nn.CrossEntropyLoss()(out, yb) 314 | 315 | elif self.loss_type == 'binary_targeted': 316 | yb_opp = 1-yb 317 | return nn.CrossEntropyLoss()(out, yb_opp) 318 | 319 | elif self.loss_type == 'binary_hybrid': 320 | yb_opp = 1-yb 321 | return nn.CrossEntropyLoss()(out, yb_opp) - nn.CrossEntropyLoss()(out, yb) 322 | 323 | else: 324 | assert False, "unknown loss type" 325 | 326 | def _perturb(self, xb, yb, train_mode): 327 | # broadcast clamped + scaled + signed-if-binary UAPs 328 | d = self.delta if train_mode else self.delta.data 329 | sign = ((2*yb-1)*1.0) if self.num_classes == 2 else torch.ones(len(yb)).to(self.device) 330 | sign = sign[(...,)+(None,)*(len(xb.shape)-1)].float() 331 | delta = torch.zeros_like(xb, device=self.device) 332 | delta = sign*(delta+d) 333 | 334 | # perturb and re-clamp data 335 | xba = (xb + delta).clamp(self.B[0], self.B[1]) 336 | delta = xba-xb 337 | return xba, delta 338 | 339 | def _init_delta(self): 340 | delta = torch.zeros(*self.shape) 341 | delta = delta.uniform_(-self.rand_eps, self.rand_eps).to(self.device) 342 | delta.requires_grad = True 343 | self.delta = delta 344 | 345 | class Linf_AS_UAP(AS_UAP): 346 | 347 | def _update_delta(self, xb, yb): 348 | # steepest descent + linf projection (GD) 349 | self.delta.data = self.delta.data - self.lr*(self.delta.grad.detach().sign()) 350 | self.delta.data = self.delta.data.clamp(-self.eps, self.eps) 351 | self.delta.grad.zero_() 352 | return self.delta 353 | 354 | class L2_AS_UAP(AS_UAP): 355 | 356 | OVERFLOW_CONST = 1e-10 357 | 358 | def _update_delta(self, xb, yb): 359 | # normalize gradients 360 | grad = self.delta.grad.detach() 361 | 
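        # normalized-gradient step: move a fixed distance lr along
        # grad/||grad||_2, then project back onto the L2 ball of radius eps
        # (the clamp below leaves deltas with norm <= eps unchanged).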
norms = grad.norm() 362 | grad = grad/(norms+self.OVERFLOW_CONST) # add const to avoid overflow 363 | 364 | # steepest descent 365 | self.delta.data = self.delta.data - self.lr*grad 366 | 367 | # l2 ball projection 368 | delta_norms = self.delta.data.norm() 369 | self.delta.data = self.eps*self.delta.data / (delta_norms.clamp(min=self.eps)) 370 | 371 | self.delta.grad.zero_() 372 | return self.delta 373 | 374 | 375 | class SVD_UAP(object): 376 | # based on https://arxiv.org/abs/2005.08632 377 | 378 | def __init__(self, attack, bounds=(0.,1.), device=None, num_classes=2): 379 | self.device = device if device else gu.get_device(None) 380 | self.attack = attack 381 | self.attack.device = self.device 382 | self.B = bounds 383 | self.num_classes = num_classes 384 | 385 | def fit(self, dl, model): 386 | # get perturbations 387 | model = model.to(self.device) 388 | pdata = self.attack.evaluate_attack(dl, model) 389 | 390 | self.p_acc = pdata['acc'] 391 | shape = list(pdata['Xa'].shape) 392 | self.num_imgs, self.img_shape = shape[0], shape[1:] 393 | 394 | # run SVD 395 | P = pdata['P'].reshape(pdata['P'].shape[0], -1) 396 | self.P = P / np.linalg.norm(P, axis=1)[:, None] 397 | U, self.S, self.VH = np.linalg.svd(P) 398 | del U 399 | 400 | # setup UAP 401 | self.uaps = self.VH.reshape(self.VH.shape[0], *self.img_shape) 402 | self.uap = self.uaps[0] 403 | 404 | def evaluate_attack(self, dl, model, eps, kth=0, **kw): 405 | self.delta = torch.FloatTensor(self.uaps[kth]).to(self.device) 406 | eval1 = self._eval(dl, model, eps, 1.0) 407 | eval2 = self._eval(dl, model, eps, -1.0) 408 | if eval1['acc'] < eval2['acc']: return eval1 409 | return eval2 410 | 411 | def _eval(self, dl, model, eps, pos_dir): 412 | model = model.to(self.device) 413 | X, Xa, Ya, P = [], [], [], [] 414 | 415 | for xb, yb in dl: 416 | xb, yb = xb.to(self.device), yb.to(self.device) 417 | xba, ptb = self._perturb(xb, yb, eps, pos_dir) 418 | X.append(xb.cpu()) 419 | Xa.append(xba.cpu()) 420 | Ya.append(yb.cpu()) 421 | P.append(ptb.cpu()) 422 | 423 | X, Xa, Ya, P = map(torch.cat, [X, Xa, Ya, P]) 424 | ta_dl = utils._to_dl(Xa, Ya, dl.batch_size) 425 | acc_func = utils.compute_loss_and_accuracy_from_dl 426 | acc, loss = acc_func(ta_dl, model, F.cross_entropy, device=self.device) 427 | 428 | return { 429 | 'P': P, 430 | 'X': X, 431 | 'Xa': Xa, 432 | 'Ya': Ya, 433 | 'acc': acc.item(), 434 | 'loss': loss.item(), 435 | 'dl': ta_dl, 436 | 'pos_dir': pos_dir 437 | } 438 | 439 | def _perturb(self, xb, yb, eps, pos_dir): 440 | nch = len(xb.shape) 441 | 442 | # broadcast clamped + scaled + signed UAPs 443 | sign = ((2*yb-1)*(1.0)) if self.num_classes == 2 else torch.ones(len(yb)).to(self.device) 444 | sign = sign[(...,)+(None,)*(len(xb.shape)-1)].float() 445 | 446 | delta = torch.zeros_like(xb, device=self.device) 447 | delta = eps*pos_dir*sign*(delta + self.delta) 448 | 449 | # perturb and re-clamp data 450 | xba = xb + delta 451 | xba = xba.clamp(self.B[0], self.B[1]) 452 | delta = xba-xb 453 | return xba, delta 454 | 455 | -------------------------------------------------------------------------------- /scripts/synth_models.py: -------------------------------------------------------------------------------- 1 | import sys, copy 2 | import torch, torchvision 3 | from torch import optim, nn 4 | import torch.nn.functional as F 5 | from torch.utils.data import TensorDataset, DataLoader 6 | import gendata 7 | import utils 8 | import numpy as np 9 | import gpu_utils as gu 10 | import ptb_utils as pu 11 | 12 | def kaiming_init(m): 13 | if 
isinstance(m, nn.Linear):
14 |         nn.init.kaiming_uniform_(m.weight.data)
15 |         nn.init.zeros_(m.bias.data) # kaiming init is undefined for 1-d bias tensors
16 | 
17 | class SequenceClassifier(nn.Module):
18 | 
19 |     def __init__(self, seq_model, idim, hdim, hl, input_size, num_classes=2, many_to_many=False, unsqueeze_input=True):
20 |         super(SequenceClassifier, self).__init__()
21 |         self.seq_model = seq_model
22 |         self.hdim = hdim
23 |         self.hl = hl
24 |         self.input_size = input_size
25 |         self.idim = idim
26 |         self.num_classes = num_classes
27 |         self.unsqueeze_input = unsqueeze_input
28 |         self.many_to_many = many_to_many
29 | 
30 |         self.seq_length = self.idim//self.input_size
31 |         self.seq = self.seq_model(input_size=input_size, hidden_size=hdim, num_layers=hl, batch_first=True)
32 |         self.lin_idim = hdim*self.seq_length if many_to_many else hdim
33 |         self.lin = nn.Linear(self.lin_idim, num_classes)
34 | 
35 |     def forward(self, x):
36 |         if self.unsqueeze_input: x = x.unsqueeze(2)
37 |         bsize, idim, _ = x.shape
38 |         seq_length = idim//self.input_size
39 |         x = x.view((bsize, seq_length, self.input_size))
40 |         out, hidden = self.seq(x)
41 |         lin_in = out[:,-1,:]
42 |         if self.many_to_many: lin_in = out.contiguous().view((bsize, -1))
43 |         lin_out = self.lin(lin_in)
44 |         return lin_out
45 | 
46 | class GRUClassifier(SequenceClassifier):
47 | 
48 |     def __init__(self, idim, hdim, hl, input_size, num_classes=2, many_to_many=False, unsqueeze_input=True):
49 |         super(GRUClassifier, self).__init__(nn.GRU, idim, hdim, hl, input_size, many_to_many=many_to_many, num_classes=num_classes, unsqueeze_input=unsqueeze_input)
50 | 
51 | class LSTMClassifier(SequenceClassifier):
52 | 
53 |     def __init__(self, idim, hdim, hl, input_size, num_classes=2, many_to_many=False, unsqueeze_input=True):
54 |         super(LSTMClassifier, self).__init__(nn.LSTM, idim, hdim, hl, input_size, many_to_many=many_to_many, num_classes=num_classes, unsqueeze_input=unsqueeze_input)
55 | 
56 | class CNNClassifier(nn.Module):
57 | 
58 |     def __init__(self, out_channels, hl, kernel_size, idim, num_classes=2, padding=None, stride=1, maxpool_kernel_size=None, use_maxpool=False):
59 |         """
60 |         Fixed architecture:
61 |         - default max-pool kernel size is half the convolution kernel size
62 |         - default padding is (kernel size - 1) // 2, which preserves the input dimension
63 |         - stride = 1
64 |         - 1 FC layer
65 |         """
66 |         if padding is None: assert kernel_size % 2 == 1, "use odd kernel size, equal padding constraint"
67 |         super(CNNClassifier, self).__init__()
68 |         self.out_channels = out_channels
69 |         self.num_conv = hl
70 |         self.kernel_size = kernel_size
71 |         self.padding = padding or (self.kernel_size-1)//2
72 |         self.stride = 1
73 |         self.num_classes = num_classes
74 |         self.idim = idim
75 |         self.use_maxpool = use_maxpool
76 |         self.maxpool_kernel_size = maxpool_kernel_size or self.kernel_size//2
77 | 
78 |         self.maxpool = nn.MaxPool1d(self.maxpool_kernel_size)
79 |         self.ih_conv = nn.Conv1d(1, self.out_channels, self.kernel_size, padding=self.padding, stride=self.stride)
80 | 
81 |         self.hh_convs = []
82 |         for _ in range(self.num_conv-1):
83 |             self.hh_convs.append(nn.Conv1d(self.out_channels, self.out_channels, self.kernel_size, padding=self.padding, stride=self.stride))
84 |             self.hh_convs.append(nn.ReLU())
85 |         self.hh_convs = nn.Sequential(*self.hh_convs)
86 | 
87 |         fc_idim = int(self.idim/self.maxpool_kernel_size) if self.use_maxpool else self.idim
88 |         self.fc_layer = nn.Linear(self.out_channels*fc_idim, self.idim)
89 |         self.out_layer = nn.Linear(self.idim, self.num_classes)
90 |         self.relu = nn.ReLU()
91 | 
92 |     def forward(self, x):
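        # shape flow for input x of shape (N, idim): unsqueeze -> (N, 1, idim);
        # the "same"-padded convolutions preserve the last axis, giving
        # (N, out_channels, idim); optional max-pooling shrinks the last axis
        # by a factor of maxpool_kernel_size before the two linear layers.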
bs = x.shape[0] 94 | x_ = x.unsqueeze(1) 95 | 96 | x_ = self.relu(self.ih_conv(x_)) 97 | x_ = self.hh_convs(x_) 98 | 99 | if self.use_maxpool: x_ = self.maxpool(x_) 100 | x_ = self.relu(self.fc_layer(x_.view(bs, -1))) 101 | 102 | return self.out_layer(x_) 103 | 104 | class CNN2DClassifier(nn.Module): 105 | 106 | def __init__(self, num_filters, filter_size, num_layers, input_shape, input_channels=1, stride=2, padding=None, num_stride2_layers=2, fc_idim=None, fc_odim=None, num_classes=2, use_avgpool=True, avgpool_ksize=5): 107 | super(CNN2DClassifier, self).__init__() 108 | self.outch = num_filters 109 | self.fsize = filter_size 110 | self.input_channels = input_channels 111 | self.hl = num_layers 112 | self.padding = (self.fsize-1)//2 if padding is None else padding 113 | self.num_classes = num_classes 114 | num_stride2_layers = num_stride2_layers 115 | self.strides = iter([stride]*num_stride2_layers+[1]*(num_layers-num_stride2_layers)) 116 | self.use_avgpool = use_avgpool 117 | self.avgpool_ksize = avgpool_ksize 118 | 119 | self.convs = [nn.Conv2d(self.input_channels, self.outch, self.fsize, padding=self.padding, stride=next(self.strides)), nn.ReLU()] 120 | if self.use_avgpool: self.convs.append(nn.AvgPool2d(self.avgpool_ksize)) 121 | 122 | for _ in range(self.hl-1): 123 | self.convs.append(nn.Conv2d(self.outch, self.outch, self.fsize, stride=next(self.strides), padding=self.padding)) 124 | self.convs.append(nn.ReLU()) 125 | if self.use_avgpool: self.convs.append(nn.AvgPool2d(self.avgpool_ksize)) 126 | 127 | self.convs = nn.Sequential(*self.convs) # need to wrap for gpu 128 | sl = min(self.hl, num_stride2_layers) 129 | self.fc_idim = int(num_filters*input_shape[0]*input_shape[1]/float(4**sl)) if fc_idim is None else fc_idim 130 | self.fc_odim = fc_odim if fc_odim is not None else self.fc_idim 131 | self.fc = nn.Linear(self.fc_idim, self.fc_odim) 132 | self.out = nn.Linear(self.fc_odim, self.num_classes) 133 | 134 | def forward(self, x): 135 | x = self.convs(x) 136 | x = x.reshape(x.shape[0], -1) 137 | return self.out(F.relu(self.fc(x))) 138 | 139 | def get_linear(input_dim, num_classes): 140 | return nn.Sequential(nn.Linear(input_dim, num_classes)) 141 | 142 | def get_fcn(idim, hdim, odim, hl=1, init=False, activation=nn.ReLU, use_activation=True, use_bn=False, input_dropout=0, dropout=0): 143 | use_dropout = dropout > 0 144 | layers = [] 145 | if input_dropout > 0: layers.append(nn.Dropout(input_dropout)) 146 | layers.append(nn.Linear(idim, hdim)) 147 | if use_activation: layers.append(activation()) 148 | if use_dropout: layers.append(nn.Dropout(dropout)) 149 | if use_bn: layers.append(nn.BatchNorm1d(hdim)) 150 | for _ in range(hl-1): 151 | l = [nn.Linear(hdim, hdim)] 152 | if use_activation: l.append(activation()) 153 | if use_dropout: l.append(nn.Dropout(dropout)) 154 | if use_bn: l.append(nn.BatchNorm1d(hdim)) 155 | layers.extend(l) 156 | layers.append(nn.Linear(hdim, odim)) 157 | model = nn.Sequential(*layers) 158 | 159 | if init: model.apply(kaiming_init) 160 | return model -------------------------------------------------------------------------------- /scripts/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | import seaborn as sns 5 | import pickle 6 | import copy 7 | from collections import defaultdict, Counter, OrderedDict 8 | import time 9 | from torch.utils.data import TensorDataset, DataLoader 10 | import torchvision 11 | from torch import optim, nn 12 | 
import torch.nn.functional as F 13 | from scipy.linalg import qr 14 | import lms_utils as au 15 | import ptb_utils as pu 16 | import gpu_utils as gu 17 | from sklearn import metrics 18 | import collections 19 | from sklearn.metrics import roc_auc_score 20 | 21 | plt.style.use('seaborn-ticks') 22 | import matplotlib.ticker as ticker 23 | 24 | def get_orthonormal_matrix(n): 25 | H = np.random.randn(n, n) 26 | s = np.linalg.svd(H)[1] 27 | s = s[s>1e-7] 28 | if len(s) != n: return get_orthonormal_matrix(n) 29 | Q, R = qr(H) 30 | return Q 31 | 32 | def get_dataloader(X, Y, bs, **kw): 33 | return DataLoader(TensorDataset(X, Y), batch_size=bs, **kw) 34 | 35 | def split_dataloader(dl, frac=0.5): 36 | bs = dl.batch_size 37 | X, Y = dl.dataset.tensors 38 | p = torch.randperm(len(X)) 39 | X, Y = X[p, :], Y[p] 40 | n = int(round(len(X)*frac)) 41 | X0, Y0 = X[:n, :], Y[:n] 42 | X1, Y1 = X[n:, :], Y[n:] 43 | dl0 = DataLoader(TensorDataset(torch.Tensor(X0), torch.LongTensor(Y0)), batch_size=bs, shuffle=True) 44 | dl1 = DataLoader(TensorDataset(torch.Tensor(X1), torch.LongTensor(Y1)), batch_size=bs, shuffle=True) 45 | return dl0, dl1 46 | 47 | def _to_dl(X, Y, bs, shuffle=True): 48 | return DataLoader(TensorDataset(torch.Tensor(X), torch.LongTensor(Y)), batch_size=bs, shuffle=shuffle) 49 | 50 | def extract_tensors_from_loader(dl, repeat=1, transform_fn=None): 51 | X, Y = [], [] 52 | for _ in range(repeat): 53 | for xb, yb in dl: 54 | if transform_fn: 55 | xb, yb = transform_fn(xb, yb) 56 | X.append(xb) 57 | Y.append(yb) 58 | X = torch.FloatTensor(torch.cat(X)) 59 | Y = torch.LongTensor(torch.cat(Y)) 60 | return X, Y 61 | 62 | def extract_numpy_from_loader(dl, repeat=1, transform_fn=None): 63 | X, Y = extract_tensors_from_loader(dl, repeat=repeat, transform_fn=transform_fn) 64 | return X.numpy(), Y.numpy() 65 | 66 | def _to_tensor_dl(dl, repeat=1, bs=None): 67 | X, Y = extract_numpy_from_loader(dl, repeat=repeat) 68 | dl = _to_dl(X, Y, bs if bs else dl.batch_size) 69 | return dl 70 | 71 | def flatten_loader(dl, bs=None): 72 | X, Y = extract_numpy_from_loader(dl) 73 | X = X.reshape(X.shape[0], -1) 74 | return _to_dl(X, Y, bs=bs if bs else dl.batch_size) 75 | 76 | def merge_loaders(dla, dlb): 77 | bs = dla.batch_size 78 | Xa, Ya = extract_numpy_from_loader(dla) 79 | Xb, Yb = extract_numpy_from_loader(dlb) 80 | return _to_dl(np.concatenate([Xa, Xb]), np.concatenate([Ya, Yb]), bs) 81 | 82 | def transform_loader(dl, func, shuffle=True): 83 | #assert type(dl.sampler) is torch.utils.data.sampler.SequentialSampler 84 | X, Y = extract_numpy_from_loader(dl, transform_fn=func) 85 | return _to_dl(X, Y, dl.batch_size, shuffle=shuffle) 86 | 87 | def visualize_tensors(P, size=8, normalize=True, scale_each=False, permute=True, ax=None, pad_value=0.): 88 | if ax is None: _, ax = plt.subplots(1,1,figsize=(20,4)) 89 | if permute: 90 | s = np.random.choice(len(P), size=size, replace=False) 91 | p = P[s] 92 | else: 93 | p = P[:size] 94 | g = torchvision.utils.make_grid(torch.FloatTensor(p), nrow=size, normalize=normalize, scale_each=scale_each, pad_value=pad_value) 95 | g = g.permute(1,2,0).numpy() 96 | ax.imshow(g) 97 | ax.set_xticks([]) 98 | ax.set_yticks([]) 99 | return ax 100 | 101 | def visualize_loader(dl, ax=None, size=8, normalize=True, scale_each=False, reshape=None): 102 | if ax is None: _, ax = plt.subplots(1,1,figsize=(20,4)) 103 | for xb, yb in dl: break 104 | if reshape: xb = xb.reshape(len(xb), *reshape) 105 | return visualize_tensors(xb, size=size, normalize=normalize, scale_each=scale_each, permute=True, 
ax=ax)
106 | 
107 | def visualize_loader_by_class(dl, ax=None, size=8, normalize=True, scale_each=False, reshape=None):
108 |     for xb, yb in dl: break
109 |     if reshape: xb = xb.reshape(len(xb), *reshape)
110 | 
111 |     classes = list(set(list(yb.numpy())))
112 |     fig, axs = plt.subplots(len(classes), 1, figsize=(15, 3*len(classes)))
113 | 
114 |     for y, ax in zip(classes, axs):
115 |         xb_ = xb[yb==y]
116 |         ax = visualize_tensors(xb_, size=size, normalize=normalize, scale_each=scale_each, permute=True, ax=ax)
117 |         ax.set_title('Class: {}'.format(y))
118 | 
119 |     return fig
120 | 
121 | def visualize_perturbations(P, transform_fn=None):
122 |     if transform_fn is not None:
123 |         P = transform_fn(P)
124 |     plt.figure(figsize=(20,4))
125 |     s = np.random.choice(len(P), size=8, replace=False)
126 |     p = P[s]
127 |     g = torchvision.utils.make_grid(torch.FloatTensor(p))
128 |     g = g.permute(1,2,0).numpy()
129 |     g = (g-g.min())/(g.max()-g.min()) # min-max normalize to [0,1] for imshow
130 |     plt.imshow(g)
131 | 
132 | def get_logits_given_tensor(X, model, device=None, bs=250, softmax=False):
133 |     if device is None: device = gu.get_device(None)
134 |     sampler = torch.utils.data.SequentialSampler(X)
135 |     sampler = torch.utils.data.BatchSampler(sampler, bs, False)
136 | 
137 |     logits = []
138 | 
139 |     with torch.no_grad():
140 |         model = model.to(device)
141 |         for idx in sampler:
142 |             xb = X[idx].to(device)
143 |             out = model(xb)
144 |             logits.append(out)
145 | 
146 |     L = torch.cat(logits)
147 |     if softmax: return F.softmax(L, 1)
148 |     return L
149 | 
150 | def get_predictions_given_tensor(X, model, device=None, bs=250):
151 |     out = get_logits_given_tensor(X, model, device=device, bs=bs)
152 |     return torch.argmax(out, 1)
153 | 
154 | def get_accuracy_given_tensor(X, Y, model, device=None, bs=250):
155 |     if device is None: device = gu.get_device(None)
156 |     Y = torch.LongTensor(Y).to(device)
157 |     yhat = get_predictions_given_tensor(X, model, device=device, bs=bs)
158 |     return (Y==yhat).float().mean().item()
159 | 
160 | def compute_accuracy(X, Y, model):
161 |     with torch.no_grad():
162 |         pred = torch.argmax(model(X),1)
163 |         correct = (pred == Y).sum().item()
164 |         accuracy = correct/float(len(Y))
165 |     return accuracy
166 | 
167 | def compute_loss_and_accuracy_from_dl(dl, model, loss_fn, sample_pct=1.0, device=None, transform_fn=None):
168 |     in_tr_mode = model.training
169 |     model = model.eval()
170 |     data_size = float(len(dl.dataset))
171 |     samp_size = int(np.ceil(sample_pct*data_size))
172 |     num_eval = 0.
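    # per-batch accuracies and losses are accumulated below and combined as a
    # batch-size-weighted average at the end; sample_pct < 1 scores only a
    # prefix of the loader (the loop exits once samp_size examples are seen).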
173 |     bs = dl.batch_size
174 |     accs, losses, bss = [], [], []
175 | 
176 |     with torch.no_grad():
177 |         for xb, yb in dl:
178 |             xb, yb = xb.to(device, non_blocking=False), yb.to(device, non_blocking=False)
179 | 
180 |             if transform_fn:
181 |                 xb, yb = transform_fn(xb, yb)
182 | 
183 |             sc = model(xb)
184 | 
185 |             if loss_fn is F.cross_entropy:
186 |                 loss = loss_fn(sc, yb, reduction='mean')
187 |                 pred = torch.argmax(sc, 1)
188 |             elif loss_fn is F.binary_cross_entropy_with_logits:
189 |                 loss = loss_fn(sc, yb.float().unsqueeze(1))
190 |                 pred = (sc > 0.).long().squeeze()
191 |             elif loss_fn is hinge_loss:
192 |                 loss = loss_fn(sc, yb)
193 |                 pred = (sc > 0).long().squeeze()
194 |             else:
195 |                 try:
196 |                     loss = loss_fn(sc, yb)
197 |                     pred = torch.argmax(sc, 1)
198 |                 except:
199 |                     assert False, "unknown loss function"
200 | 
201 |             correct = (pred==yb).sum().float()
202 |             n = float(len(xb))
203 |             losses.append(loss.item())
204 |             accs.append((correct/n).item())
205 |             bss.append(n)
206 | 
207 |             num_eval += n
208 |             if num_eval >= samp_size: break
209 | 
210 |     accs, losses, bss = map(np.array, [accs, losses, bss])
211 |     if in_tr_mode: model = model.train()
212 |     return np.sum(bss*accs)/num_eval, np.sum(bss*losses)/num_eval # weight both averages by the per-batch sizes (bss)
213 | 
214 | def count_parameters(model):
215 |     return sum(p.numel() for p in model.parameters() if p.requires_grad)
216 | 
217 | def get_logits(model, loader, device):
218 |     S, Y = [], []
219 |     with torch.no_grad():
220 |         for xb, yb in loader:
221 |             xb = xb.to(device)
222 |             out = model(xb).cpu().numpy()
223 |             S.append(out)
224 |             Y.append(list(yb))
225 |     S, Y = map(np.concatenate, [S, Y])
226 |     return S, Y
227 | 
228 | def get_scores(model, loader, device):
229 |     """binary tasks only"""
230 |     S, Y = get_logits(model, loader, device)
231 |     return S[:,1]-S[:,0], Y
232 | 
233 | def get_multiclass_logit_score(L, Y):
234 |     scores = []
235 |     for idx, (l, y) in enumerate(zip(L, Y)):
236 |         sc_y = l[y]
237 | 
238 |         indices = np.argsort(l)
239 |         best2_idx, best1_idx = indices[-2:]
240 |         sc_max = l[best2_idx] if y == best1_idx else l[best1_idx]
241 | 
242 |         score = sc_y - sc_max
243 |         scores.append(score)
244 | 
245 |     return np.array(scores)
246 | 
247 | def get_binary_auc(model, loader, device):
248 |     S, Y = get_scores(model, loader, device)
249 |     return roc_auc_score(Y, S)
250 | 
251 | def get_multiclass_auc(model, loader, device, one_vs_rest=True):
252 |     X, Y = extract_tensors_from_loader(loader)
253 |     S = get_logits_given_tensor(X, model, device=device, softmax=True).cpu()
254 |     mc = 'ovr' if one_vs_rest is True else 'ovo'
255 |     S, Y = S.numpy(), Y.numpy()
256 |     return roc_auc_score(Y, S, multi_class=mc)
257 | 
258 | def clip_gradient(model, clip_value):
259 |     params = list(filter(lambda p: p.grad is not None, model.parameters()))
260 |     for p in params: p.grad.data.clamp_(-clip_value, clip_value)
261 | 
262 | def print_model_gradients(model, print_bias=True):
263 |     for name, params in model.named_parameters():
264 |         if not print_bias and 'bias' in name: continue
265 |         if not params.requires_grad: continue
266 |         avg_grad = np.mean(params.grad.cpu().numpy())
267 |         print (name, params.shape, avg_grad)
268 | 
269 | def hinge_loss(out, y):
270 |     y_ = (2*y.float()-1).unsqueeze(1)
271 |     return torch.mean(F.relu(1-out*y_))
272 | 
273 | def pgd_adv_fit_model(model, opt, tr_dl, te_dl, attack, eval_attack=None, device=None, sch=None, max_epochs=100, epoch_gap=2,
274 |                       min_loss=0.001, print_info=True, save_init_model=True):
275 | 
276 |     # setup tracking
277 |     PR = lambda x: print (x) if print_info else None
278 |     stop_training = 
False 279 | stats = defaultdict(list) 280 | best_val, best_model = np.inf, None 281 | adv_epoch_timer = [] 282 | epoch_gap_timer = [time.time()] 283 | init_model = copy.deepcopy(model).cpu() if save_init_model else None 284 | 285 | # eval attack 286 | eval_attack = eval_attack or attack 287 | 288 | print ("Min loss: {}".format(min_loss)) 289 | 290 | def standard_epoch(loader, model, optimizer=None, sch=None): 291 | """compute accuracy and loss. Backprop if optimizer provided""" 292 | total_loss, total_err = 0.,0. 293 | model = model.eval() if optimizer is None else model.train() 294 | model = model.to(device) 295 | update_params = optimizer is not None 296 | 297 | with torch.set_grad_enabled(update_params): 298 | for xb, yb in loader: 299 | xb, yb = xb.to(device), yb.to(device) 300 | yp = model(xb) 301 | loss = F.cross_entropy(yp, yb) 302 | if update_params: 303 | optimizer.zero_grad() 304 | loss.backward() 305 | optimizer.step() 306 | total_err += (yp.max(dim=1)[1] != yb).sum().item() 307 | total_loss += loss.item() * xb.shape[0] 308 | return total_err / len(loader.dataset), total_loss / len(loader.dataset) 309 | 310 | def adv_epoch(loader, model, attack, optimizer=None, sch=None): 311 | """compute adv accuracy and loss. Backprop if optimizer provided""" 312 | start_time = time.time() 313 | total_loss, total_err = 0.,0. 314 | model = model.eval() if optimizer is None else model.train() 315 | model = model.to(device) 316 | update_params = optimizer is not None 317 | 318 | for xb, yb in loader: 319 | torch.set_grad_enabled(True) 320 | xb, yb = xb.to(device), yb.to(device) 321 | delta = attack.perturb(xb, yb, model).to(device) 322 | xb = xb + delta 323 | with torch.set_grad_enabled(update_params): 324 | yp = model(xb) 325 | loss = F.cross_entropy(yp, yb) 326 | if update_params: 327 | optimizer.zero_grad() 328 | loss.backward() 329 | optimizer.step() 330 | total_err += (yp.max(dim=1)[1] != yb).sum().item() 331 | total_loss += loss.item() * xb.shape[0] 332 | 333 | if optimizer is not None and sch is not None: 334 | cur_lr = next(iter(opt.param_groups))['lr'] 335 | sch.step() 336 | new_lr = next(iter(opt.param_groups))['lr'] 337 | if new_lr != cur_lr: 338 | PR('Epoch {}, LR : {} -> {}'.format(epoch, cur_lr, new_lr)) 339 | 340 | total_time = time.time()-start_time 341 | adv_epoch_timer.append(total_time) 342 | return total_err / len(loader.dataset), total_loss / len(loader.dataset) 343 | 344 | epoch = 0 345 | while epoch < max_epochs: 346 | if stop_training: 347 | break 348 | try: 349 | stat = {} 350 | model = model.train() 351 | train_err, train_loss = adv_epoch(tr_dl, model, attack, optimizer=opt, sch=sch) 352 | 353 | if epoch % epoch_gap == 0: 354 | model = model.eval() 355 | test_err, test_loss = standard_epoch(te_dl, model, optimizer=None, sch=None) 356 | adv_err, adv_loss = adv_epoch(te_dl, model, eval_attack, optimizer=None, sch=None) 357 | stat['acc_te'], stat['acc_te_std'] = adv_err, test_err 358 | stat['loss_te'], stat['loss_te_std'] = adv_loss, test_loss 359 | 360 | if adv_err < best_val: 361 | best_val = adv_err 362 | best_model = copy.deepcopy(model).eval() 363 | 364 | if print_info: 365 | if epoch==0: print ("Epoch", "l-tr", "a-tr", "a-te", "s-te", "time", sep='\t') 366 | #print (epoch, *("{:.4f}".format(i) for i in (train_loss, train_err)), sep=' ') 367 | diff_time = time.time()-epoch_gap_timer[-1] 368 | epoch_gap_timer.append(time.time()) 369 | print (epoch, *("{:.4f}".format(i) for i in (train_loss, 1.-train_err, 1.-adv_err, 1.-test_err, diff_time)), sep=' ') 370 | 371 | if 
train_loss < min_loss:
372 |                 stop_training = True
373 | 
374 |                 print ("Epoch {}: accuracy {:.3f} and loss {:.3f}".format(epoch, 1-train_err, train_loss))
375 | 
376 |             stat['epoch'] = epoch
377 |             stat['acc_tr'] = train_err
378 |             stat['loss_tr'] = train_loss
379 | 
380 |             for k, v in stat.items():
381 |                 stats[k].append(v)
382 | 
383 |             epoch += 1
384 | 
385 |         except KeyboardInterrupt:
386 |             inp = input("LR num or Q or SAVE or GAP or MAXEPOCHS: ")
387 |             if inp.startswith('LR'):
388 |                 lr = float(inp.split(' ')[-1])
389 |                 cur_lr = next(iter(opt.param_groups))['lr']
390 |                 PR("New LR: {}".format(lr))
391 |                 for g in opt.param_groups: g['lr'] = lr
392 |             if inp.startswith('Q'):
393 |                 stop_training = True
394 |             if inp.startswith('SAVE'):
395 |                 fpath = inp.split(' ')[-1]
396 |                 stats['best_model'] = (best_val, best_model.cpu())
397 |                 torch.save({
398 |                     'model': copy.deepcopy(model).cpu(),
399 |                     'stats': stats,
400 |                     'opt': copy.deepcopy(opt.state_dict()) # optimizers have no .cpu(); save the state dict instead
401 |                 }, fpath)
402 |                 PR(f'Saved to {fpath}')
403 |             if inp.startswith('GAP'):
404 |                 _, gap = inp.split(' ')
405 |                 gap = int(gap)
406 |                 print ("epoch gap: {} -> {}".format(epoch_gap, gap))
407 |                 epoch_gap = gap
408 |             if inp.startswith('MAXEPOCHS'):
409 |                 _, me = inp.split(' ')
410 |                 me = int(me)
411 |                 print ("max_epochs: {} -> {}".format(max_epochs, me))
412 |                 max_epochs = me
413 | 
414 |     stats['best_model'] = (best_val, best_model.cpu())
415 |     stats['init_model'] = init_model
416 |     return stats
417 | 
418 | 
419 | def fit_model(model, loss, opt, train_dl, valid_dl, sch=None, epsilon=1e-2, is_loss_epsilon=False, update_gap=50, update_print_gap=50, gap=None,
420 |               print_info=True, save_grads=False, test_dl=None, skip_epoch_eval=True, sample_pct=0.5, sample_loss_threshold=0.75, save_models=False,
421 |               print_grads=False, print_model_layers=False, tr_batch_fn=None, te_batch_fn=None, device=None, max_updates=800_000, patience_updates=1,
422 |               enable_redo=False, save_best_model=True, save_init_model=True, max_epochs=100000, **misc):
423 | 
424 |     # setup update metadata
425 |     MAX_LOSS_VAL = 1000000.
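    # convergence scheme: train until train accuracy reaches 1 - epsilon (or
    # train loss drops below epsilon when is_loss_epsilon=True), then run
    # patience_updates additional steps; MAX_LOSS_VAL is the divergence
    # threshold checked against the train/val losses. Illustrative call
    # (hypothetical hyperparameters, with get_fcn from scripts/synth_models.py):
    #   model = synth_models.get_fcn(idim=50, hdim=100, odim=2, hl=1)
    #   opt = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    #   stats = fit_model(model, F.cross_entropy, opt, tr_dl, te_dl,
    #                     epsilon=1e-2, gap=100, device=gu.get_device(None))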
426 | PR = lambda x: print (x) if print_info else None 427 | use_epoch = False 428 | if gap is not None: update_gap = update_print_gap = gap 429 | bs_ratio = int(len(train_dl.dataset)/float(train_dl.batch_size)) 430 | act_update_gap = update_gap if not use_epoch else update_gap*bs_ratio 431 | act_pr_update_gap = update_print_gap if not use_epoch else update_print_gap*bs_ratio 432 | PR("accuracy/loss measured every {} updates".format(act_update_gap)) 433 | 434 | if save_models: 435 | PR("saving models every {} updates".format(act_update_gap)) 436 | 437 | PR("update_print_gap: {}, epss: {}, bs: {}, device: {}".format(act_pr_update_gap, epsilon, train_dl.batch_size, device or 'cpu')) 438 | 439 | # init_save setup 440 | init_model = copy.deepcopy(model).cpu() if save_init_model else None 441 | 442 | # redo setup 443 | if enable_redo: 444 | init_model_sd = copy.deepcopy(model.state_dict()) 445 | init_opt_sd = copy.deepcopy(opt.state_dict()) 446 | else: 447 | init_model_sd = None 448 | init_opt_sd = None 449 | 450 | # best model setup 451 | best_val, best_model = 0, None 452 | 453 | # tracking setup 454 | start_time = time.time() 455 | num_evals, num_epochs, num_updates, num_patience = 0, 0, 0, 0 456 | stats = dict(loss_tr=[], loss_te=[], acc_tr=[], acc_te=[], acc_test=[], loss_test=[], models=[], gradients=[]) 457 | if save_models: stats['models'].append(copy.deepcopy(model).cpu()) 458 | first_run, converged = True, False 459 | print_stats_flag = update_print_gap is not None 460 | exceeded_max = False 461 | diverged = False 462 | 463 | def _evaluate(device=device): 464 | model.eval() 465 | with torch.no_grad(): 466 | prev_loss = stats['loss_tr'][-1] if stats['loss_tr'] else 1. 467 | tr_sample_pct = sample_pct if prev_loss > sample_loss_threshold else 1. 
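            # while train loss is high, evaluate on a sample_pct subsample of
            # the train set for speed; once loss_tr falls to
            # sample_loss_threshold or below, switch to full passes so the
            # convergence check is exact.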
468 | acc_tr, loss_tr = compute_loss_and_accuracy_from_dl(train_dl,model,loss,sample_pct=tr_sample_pct,device=device,transform_fn=tr_batch_fn) 469 | acc_te, loss_te = compute_loss_and_accuracy_from_dl(valid_dl,model,loss,sample_pct=1.,device=device,transform_fn=te_batch_fn) 470 | acc_tr, loss_tr, acc_te, loss_te = map(lambda x: x.item(), [acc_tr, loss_tr, acc_te, loss_te]) 471 | stats['loss_tr'].append(loss_tr) 472 | stats['loss_te'].append(loss_te) 473 | stats['acc_tr'].append(acc_tr) 474 | stats['acc_te'].append(acc_te) 475 | 476 | if test_dl is not None: 477 | acc_test, loss_test = compute_loss_and_accuracy_from_dl(test_dl,model,loss,sample_pct=1.,device=device,transform_fn=te_batch_fn) 478 | acc_test, loss_test = acc_test.item(), loss_test.item() 479 | stats['acc_test'].append(acc_test) 480 | stats['loss_test'].append(loss_test) 481 | 482 | if save_models: 483 | stats['models'].append(copy.deepcopy(model).cpu()) 484 | 485 | def _update(x,y,diff_device, device=device, save_grads=False, print_grads=False): 486 | model.train() 487 | 488 | # if diff_device: 489 | # x = x.to(device, non_blocking=False) 490 | # y = y.to(device, non_blocking=False) 491 | 492 | opt.zero_grad() 493 | out = model(x) 494 | if loss is F.cross_entropy or loss is hinge_loss: 495 | bloss = loss(out, y) 496 | elif loss is F.binary_cross_entropy_with_logits: 497 | bloss = loss(out, y.float().unsqueeze(1)) 498 | else: 499 | try: 500 | bloss = loss(out, y) 501 | except: 502 | assert False, "unknown loss function" 503 | 504 | bloss.backward() 505 | if print_grads and print_info: print_model_gradients(model) 506 | #clip_gradient(model, clip_value) 507 | opt.step() 508 | 509 | if save_grads: 510 | g = {k: v.grad.data.cpu().numpy() for k, v in model.named_parameters() if v.requires_grad} 511 | stats['gradients'].append(g) 512 | 513 | opt.zero_grad() 514 | model.eval() 515 | 516 | def print_time(): 517 | end_time = time.time() 518 | minutes, seconds = divmod(end_time-start_time, 60) 519 | gap_valid = len(stats['acc_tr']) > 0 520 | gap = round(stats['acc_tr'][-1]-stats['acc_te'][-1],4) if gap_valid else 'na' 521 | PR("converged after {} epochs in {}m {:1f}s, gap: {}".format(num_epochs, minutes, seconds, gap)) 522 | 523 | def print_stats(force_print=False): 524 | 525 | if test_dl is None: 526 | atr, ate, ltr = [stats[k][-1] for k in ['acc_tr', 'acc_te', 'loss_tr']] 527 | PR("{} {:.4f} {:.4f} {:.4f}".format(num_updates, atr, ate, ltr)) 528 | if not print_info and force_print: 529 | print ("{} {:.4f} {:.4f} {:.4f}".format(num_updates, atr, ate, ltr)) 530 | else: 531 | atr, aval, ate, ltr = [stats[k][-1] for k in ['acc_tr', 'acc_te', 'acc_test', 'loss_tr']] 532 | PR("{} {:.4f} {:.4f} {:.4f} {:.4f}".format(num_updates, atr, aval, ate, ltr)) 533 | if not print_info and force_print: 534 | print ("{} {:.4f} {:.4f} {:.4f} {:.4f}".format(num_updates, atr, aval, ate, ltr)) 535 | 536 | #xb_, yb_ = next(iter(train_dl)) 537 | diff_device = True #xb_.device != device 538 | 539 | if test_dl is None: PR("#updates, train acc, test acc, train loss") 540 | else: PR("#updates, train acc, val acc, test acc, train loss") 541 | 542 | while not converged or num_patience < patience_updates: 543 | try: 544 | model.train() 545 | for xb, yb in train_dl: 546 | 547 | if tr_batch_fn: 548 | xb, yb = tr_batch_fn(xb, yb) 549 | 550 | if diff_device: 551 | xb = xb.to(device, non_blocking=False) 552 | yb = yb.to(device, non_blocking=False) 553 | 554 | if converged: 555 | num_patience += 1 556 | 557 | if converged and num_patience == patience_updates: 558 | 
_evaluate() 559 | print_stats() 560 | break 561 | 562 | # update flag for printing gradients 563 | update_flag = print_model_layers and (num_updates == 0 or (num_updates % act_update_gap == 0 and print_grads)) 564 | _update(xb, yb, diff_device, device=device, save_grads=save_grads, print_grads=update_flag) 565 | 566 | if (num_evals == 0 or num_updates % act_update_gap == 0): 567 | num_evals += 1 568 | _evaluate() 569 | print_stats() 570 | 571 | val_acc = stats['acc_te'][-1] 572 | if num_updates > 0 and val_acc >= best_val: 573 | best_val = val_acc 574 | best_model = copy.deepcopy(model).eval() 575 | 576 | # check if loss has diverged 577 | loss_val = max(stats['loss_tr'][-1], stats['loss_te'][-1]) 578 | if loss_val > MAX_LOSS_VAL: diverged = True 579 | if not np.isfinite(loss_val): diverged = True 580 | 581 | 582 | if is_loss_epsilon: stop = stats['loss_tr'][-1] < epsilon 583 | else: stop = stats['acc_tr'][-1] >= 1-epsilon 584 | 585 | if not converged and diverged: 586 | converged = True 587 | print_time() 588 | PR("loss diverging...exiting".format(patience_updates)) 589 | 590 | if not converged and stop: 591 | converged = True 592 | print_time() 593 | PR("init-ing patience ({} updates)".format(patience_updates)) 594 | 595 | num_updates += 1 596 | first_run = False 597 | 598 | if num_updates > max_updates: 599 | converged = True 600 | exceeded_max = True 601 | num_patience = patience_updates 602 | PR("Exceeded max updates") 603 | print_stats() 604 | print_time() 605 | break 606 | 607 | # re-eval at the end of epoch 608 | if not converged: 609 | num_epochs += 1 610 | 611 | if not converged and num_epochs >= max_epochs: 612 | converged = True 613 | exceeded_max = True 614 | num_patience = patience_updates 615 | PR("Exceeded max epochs") 616 | print_stats() 617 | print_time() 618 | break 619 | 620 | if not skip_epoch_eval: 621 | _evaluate() 622 | print_stats() 623 | 624 | if is_loss_epsilon: stop = stats['loss_tr'][-1] < epsilon 625 | else: stop = stats['acc_tr'][-1] >= 1-epsilon 626 | 627 | if not converged and stop: 628 | converged = True 629 | print_time() 630 | PR("init-ing patience ({} updates)".format(patience_updates)) 631 | 632 | if num_patience >= patience_updates: 633 | _evaluate() 634 | print_stats() 635 | break 636 | 637 | # update LR via scheduler 638 | if sch is not None: 639 | cur_lr = next(iter(opt.param_groups))['lr'] 640 | sch.step() 641 | new_lr = next(iter(opt.param_groups))['lr'] 642 | if new_lr != cur_lr: 643 | PR('Epoch {}, LR : {} -> {}'.format(num_epochs, cur_lr, new_lr)) 644 | 645 | except KeyboardInterrupt: 646 | inp = input("LR num or Q or GAP num or SAVE fpath or EVAL or REDO: ") 647 | if inp.startswith('LR'): 648 | lr = float(inp.split(' ')[-1]) 649 | cur_lr = next(iter(opt.param_groups))['lr'] 650 | PR("LR: {} - > {}".format(cur_lr, lr)) 651 | for g in opt.param_groups: g['lr'] = lr 652 | elif inp.startswith('GAP'): 653 | gap = int(inp.split(' ')[-1]) 654 | act_update_gap = act_pr_update_gap = gap 655 | elif inp == "Q": 656 | converged = True 657 | num_patience = patience_updates 658 | print_time() 659 | elif inp.startswith('SAVE'): 660 | fpath = inp.split(' ')[-1] 661 | torch.save({ 662 | 'model': model, 663 | 'opt': opt, 664 | 'update_gap': update_gap 665 | }, fpath) 666 | elif inp == 'EVAL': 667 | _evaluate() 668 | print_stats(True) 669 | elif inp == 'REDO': 670 | if enable_redo: 671 | model.load_state_dict(init_model_sd) 672 | opt.load_state_dict(init_opt_sd) 673 | else: 674 | print ("REDO disabled") 675 | 676 | best_test = None 677 | if test_dl is not None: 
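        # score the best (highest validation accuracy) checkpoint on the
        # held-out test set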
678 | best_test = compute_loss_and_accuracy_from_dl(test_dl, best_model, loss, sample_pct=1.0, device=device)[0].item() 679 | 680 | stats['num_updates'] = num_updates 681 | stats['num_epochs'] = num_epochs 682 | stats['update_gap'] = update_gap 683 | 684 | stats['best_model'] = (best_val, best_test, best_model.cpu() if best_model else model.cpu()) 685 | stats['init_model'] = init_model 686 | if save_models: stats['models'].append(copy.deepcopy(model).cpu()) 687 | 688 | stats['x_updates']= list(range(0, num_evals*(update_gap+1), update_gap)) 689 | stats['x'] = stats['x_updates'][:] 690 | stats['x_epochs'] = list(range(num_epochs)) 691 | stats['gap'] = stats['acc_tr'][-1]-stats['acc_te'][-1] 692 | return stats 693 | 694 | def save_pickle(fname, d, mode='w'): 695 | with open(fname, mode) as f: 696 | pickle.dump(d, f) 697 | 698 | def load_pickle(fname, mode='r'): 699 | with open(fname, mode) as f: 700 | return pickle.load(f) 701 | 702 | def update_ax(ax, title=None, xlabel=None, ylabel=None, legend_loc='best', ticks=True, ticks_fs=10, label_fs=12, legend_fs=12, title_fs=14, hide_xlabels=False, hide_ylabels=False, despine=True): 703 | if title: ax.set_title(title, fontsize=title_fs) 704 | if xlabel: ax.set_xlabel(xlabel, fontsize=label_fs) 705 | if ylabel: ax.set_ylabel(ylabel, fontsize=label_fs) 706 | if legend_loc: ax.legend(loc=legend_loc, fontsize=legend_fs) 707 | if despine: sns.despine(ax=ax) 708 | 709 | if ticks: 710 | # ax.minorticks_on() 711 | ax.tick_params(direction='in', length=6, width=2, colors='k', which='major', top=False, right=False) 712 | ax.tick_params(direction='in', length=4, width=1, colors='k', which='minor', top=False, right=False) 713 | ax.tick_params(labelsize=ticks_fs) 714 | 715 | if hide_xlabels: ax.set_xticks([]) 716 | if hide_ylabels: ax.set_yticks([]) 717 | return ax --------------------------------------------------------------------------------
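A minimal end-to-end sketch of how these pieces fit together: MNIST-CIFAR loaders from scripts/mnistcifar_utils.py, a small CNN from scripts/synth_models.py, training via utils.fit_model, and a PGD robustness check from scripts/ptb_utils.py. It assumes the scripts/ modules are on the import path; all hyperparameters below are illustrative, not the paper's settings.

from torch import optim
import torch.nn.functional as F

import gpu_utils as gu
import mnistcifar_utils as mcu
import ptb_utils as pu
import synth_models
import utils

if __name__ == '__main__':
    device = gu.get_device(None)

    # MNIST digits (0 vs 1) stacked on top of CIFAR images grouped into two
    # classes: each example is a 3x64x32 image whose top (MNIST) half is the
    # "simple" feature and bottom (CIFAR) half the "complex" one
    tr_dl, te_dl = mcu.get_mnist_cifar_dl(mnist_classes=(0, 1), bs=256)

    model = synth_models.CNN2DClassifier(num_filters=16, filter_size=3,
                                         num_layers=2, input_shape=(64, 32),
                                         input_channels=3, use_avgpool=False)
    opt = optim.SGD(model.parameters(), lr=0.05, momentum=0.9)
    stats = utils.fit_model(model, F.cross_entropy, opt, tr_dl, te_dl,
                            epsilon=1e-2, gap=200, device=device)

    # robust accuracy under an untargeted L2 PGD attack
    attack = pu.L2_PGD_Attack(eps=1.0, lr=0.5, num_iter=20,
                              loss_type='untargeted', device=device)
    adv = attack.evaluate_attack(te_dl, model)
    print('clean acc: {:.3f}, adv acc: {:.3f}'.format(stats['acc_te'][-1],
                                                      adv['acc']))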