├── LICENSE
├── README.md
├── __pycache__
│   ├── checkpoints.cpython-36.pyc
│   ├── config.cpython-36.pyc
│   ├── dataloader.cpython-36.pyc
│   ├── model.cpython-36.pyc
│   ├── train.cpython-36.pyc
│   └── utils.cpython-36.pyc
├── dataloader.py
├── datasets
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-35.pyc
│   │   ├── __init__.cpython-36.pyc
│   │   ├── filelist.cpython-35.pyc
│   │   ├── filelist.cpython-36.pyc
│   │   ├── folderlist.cpython-35.pyc
│   │   ├── folderlist.cpython-36.pyc
│   │   ├── loaders.cpython-35.pyc
│   │   ├── loaders.cpython-36.pyc
│   │   └── svhn.cpython-35.pyc
│   ├── filelist.py
│   ├── folderlist.py
│   ├── loaders.py
│   └── transforms.py
├── main.py
├── models.py
├── test.py
└── utils.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2018 Felix Juefei Xu

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Perturbative Neural Networks (PNN)
This is an attempt to reproduce the results in the Perturbative Neural Networks paper.
See the [original repo](https://github.com/juefeix/pnn.pytorch) for details.

## Motivation
The original implementation used regular convolutions in the first layer, and the remaining layers used a fanout of 1, which means each input channel was perturbed with a single noise mask.

However, the biggest issue with the original implementation is that test accuracy was calculated incorrectly. Instead of the usual method of calculating the ratio of correct samples to total samples in the test dataset, the authors calculated accuracy on a per-batch basis and applied a smoothing weight (test_accuracy = 0.7 * prev_batch_accuracy + 0.3 * current_batch_accuracy).
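As a sketch of the difference between the two calculations (all numbers and variable names below are made up for illustration):

```
# batches: hypothetical (num_correct, batch_size) pairs collected over
# one pass through the test set.
batches = [(54, 64), (60, 64), (49, 64)]

# Reported (original repo): exponentially smoothed per-batch accuracy --
# the final value tracks recent batches, not the whole test set.
reported = None
for correct, total in batches:
    batch_acc = 100.0 * correct / total
    reported = batch_acc if reported is None else 0.7 * reported + 0.3 * batch_acc

# Actual: ratio of correct samples to total samples over the entire test set.
actual = 100.0 * sum(c for c, _ in batches) / sum(n for _, n in batches)
print('reported: {:.2f}%, actual: {:.2f}%'.format(reported, actual))
```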
Here's how [this method](https://github.com/juefeix/pnn.pytorch/blob/master/plugins/monitor.py#L31) (reported) compares to the [proper accuracy calculation](https://github.com/michaelklachko/pnn.pytorch/blob/master/main.py#L226-L230) (actual):

![img](https://s15.postimg.cc/vta2ku9nv/image.png)

For this model run (noiseresnet18 on CIFAR10), the code in the original repo reports a best test accuracy of 90.53%, while the actual best test accuracy is 85.91%.

After correcting this issue, I ran a large number of experiments to see whether perturbing the input with noise masks provides any benefit, and my conclusion is that it does not.

Here, for example, is the difference between three ResNet18-like models: a baseline model with the number of filters reduced to keep the same parameter count, a model where all layers except the first use only 1x1 convolutions (no noise), and a model where all layers except the first use perturbations followed by 1x1 convolutions. All three models have ~5.5M parameters:

![img](https://s15.postimg.cc/5jrce4zyz/image.png)

The accuracy difference between the regular ResNet baseline and PNN remains ~5% throughout training, and the addition of noise masks results in less than a 1% improvement over the equivalently "crippled" ResNet without any noise applied.

## Implementation details
Most of the modifications are contained in the [PerturbLayer class](https://github.com/michaelklachko/pnn.pytorch/blob/master/models.py#L15). Here are the main changes from the original code:

The `--first_filter_size` and `--filter_size` arguments control the type of the first layer and of the remaining layers, respectively. A value of 0 turns the layer into a perturbation layer, as described in the paper. Any value n > 0 turns the layer into a regular convolutional layer with filter size n. The original implementation only supports first_filter_size=7 and filter_size=0.

`--nmasks` specifies the number of noise masks to apply to each input channel. This is the "fanout" parameter mentioned in the paper. The original implementation only supports nmasks=1.

`--unique_masks` specifies whether to use a different set of `nmasks` noise masks for each input channel. `--no-unique_masks` forces the same set of nmasks to be used for all input channels.

`--train_masks` treats the noise masks as regular parameters, optimizing their values during training at the same time as the model weights.

`--mix_maps` adds a second 1x1 convolutional layer after the perturbed input channels are combined with the first 1x1 convolution. Without this second 1x1 "mixing" layer, there is no information exchange between input channels all the way until the softmax layer at the end. Note that it is not needed when `--nmasks` is 1, because then the first 1x1 convolutional layer already plays this role.

Other arguments allow changing the noise type (uniform or normal), the pooling type (max or avg), the activation function (relu, rrelu, prelu, elu, selu, tanh, sigmoid), whether to apply the activation function in the first layer (`--use_act`, immediately after perturbing the input RGB channels; this results in some information loss), and whether to scale the noise level in the first layer. The `--debug` argument prints out the values of input, noise, and output for every update step, to verify that noise is being applied correctly.
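To make these options concrete, here is a minimal sketch of the perturbation mechanics in PyTorch. It is a simplified illustration of what the `PerturbLayer` class does, not a drop-in copy: the class name, defaults, and the exact ordering of activation and convolution here are assumptions.

```
import torch
import torch.nn as nn

class PerturbSketch(nn.Module):
    """Perturb each input channel with `nmasks` fixed noise masks, then
    combine the responses with 1x1 convolutions (sketch, not the repo code)."""
    def __init__(self, in_channels, out_channels, nmasks=1, level=0.5,
                 input_size=32, train_masks=False, mix_maps=False):
        super().__init__()
        # Uniform noise in [-level, level]; a unique set of masks per input channel.
        noise = (torch.rand(in_channels * nmasks, input_size, input_size) * 2 - 1) * level
        if train_masks:
            self.noise = nn.Parameter(noise)      # --train_masks: learn the masks
        else:
            self.register_buffer('noise', noise)  # fixed masks
        self.nmasks = nmasks
        # fanout > 1 is combined with a grouped 1x1 conv (one group per input
        # channel); out_channels must then be divisible by in_channels.
        groups = in_channels if nmasks > 1 else 1
        self.combine = nn.Conv2d(in_channels * nmasks, out_channels, 1, groups=groups)
        # --mix_maps: extra 1x1 conv to exchange information across channels.
        self.mix = nn.Conv2d(out_channels, out_channels, 1) if mix_maps else None
        self.act = nn.ReLU()

    def forward(self, x):
        # Repeat each channel nmasks times, add its masks (broadcast over batch).
        y = x.repeat_interleave(self.nmasks, dim=1) + self.noise
        y = self.combine(self.act(y))
        return self.mix(y) if self.mix is not None else y

# Example: 64 -> 128 channels with fanout 2 and a mixing layer.
layer = PerturbSketch(64, 128, nmasks=2, mix_maps=True)
out = layer(torch.randn(8, 64, 32, 32))  # -> torch.Size([8, 128, 32, 32])
```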
Three different models are supported: `perturb_resnet18`, `cifarnet` (6 conv layers followed by a fully connected layer), and `lenet` (3 conv layers followed by a fully connected layer). In addition, I included the baseline ResNet-18 model `resnet18` taken from [here](https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py), and the `noiseresnet18` model from the original repo. Note that the `perturb_resnet18` model is flexible enough to replace both the baseline and `noiseresnet18` models, given appropriate arguments.

## Results
CIFAR-10:

1. Baseline (regular ResNet18 with 3x3 convolutions, number of filters reduced to match the PNN parameter count). Test Accuracy: 91.8%
```
python main.py --net-type 'resnet18' --dataset-test 'CIFAR10' --dataset-train 'CIFAR10' --nfilters 44 --batch-size 10 --learning-rate 1e-3
```

2. Original implementation (equivalent to running the code from the original repo). Test Accuracy: 85.7%
```
python main.py --net-type 'noiseresnet18' --dataset-test 'CIFAR10' --dataset-train 'CIFAR10' --nfilters 128 --batch-size 10 --learning-rate 1e-4 --first_filter_size 7
```

3. Same as above, but changing the `first_filter_size` argument to 3 improves the accuracy to 86.2%.

4. Same as above, but without any noise (resnet18 with 3x3 convolutions in the first layer, and 1x1 convolutions in the remaining layers). Test Accuracy: 85.5%
```
python main.py --net-type 'perturb_resnet18' --dataset-test 'CIFAR10' --dataset-train 'CIFAR10' --nfilters 128 --batch-size 16 --learning-rate 1e-3 --first_filter_size 3 --filter_size 1
```

5. PNN with uniform noise in all layers (including the first layer). Test Accuracy: 72.6%
```
python main.py --net-type 'perturb_resnet18' --dataset-test 'CIFAR10' --dataset-train 'CIFAR10' --nfilters 128 --batch-size 16 --learning-rate 1e-3 --first_filter_size 0 --filter_size 0 --nmasks 1
```

6. PNN with noise masks in all layers except the first, which uses regular 3x3 convolutions, with fanout=64 (internally, fanout is implemented with grouped 1x1 convolutions). The `--unique_masks` argument creates a unique set of masks for each input channel in every layer, and the `--mix_maps` argument adds an extra 1x1 convolutional layer to all perturbation layers. Test Accuracy: 82.7%
```
python main.py --net-type 'perturb_resnet18' --dataset-test 'CIFAR10' --dataset-train 'CIFAR10' --nfilters 128 --batch-size 16 --learning-rate 1e-3 --first_filter_size 3 --filter_size 0 --nmasks 64 --unique_masks --mix_maps
```

7. Same as above, but with the `--no-unique_masks` argument, which means the same set of masks is used for all input channels. Test Accuracy: 82.4%
```
python main.py --net-type 'perturb_resnet18' --dataset-test 'CIFAR10' --dataset-train 'CIFAR10' --nfilters 128 --batch-size 16 --learning-rate 1e-3 --first_filter_size 3 --filter_size 0 --nmasks 64 --no-unique_masks
```

Experiments 6 and 7 are the closest to what was described in the paper.

8. Training the noise masks (updated every batch, at the same time as the regular model parameters).
Test Accuracy: 85.9%

`python main.py --net-type 'perturb_resnet18' --dataset-test 'CIFAR10' --dataset-train 'CIFAR10' --nfilters 128 --batch-size 16 --learning-rate 1e-3 --first_filter_size 3 --filter_size 0 --nmasks 64 --no-unique_masks --train_masks`

## Weakness of reasoning
Section 3.3: "given the known input x and convolution transformation matrix A, we can always solve for the matching noise perturbation matrix N".

While for any given single input sample a PNN might be able to find the weights required to match the output of a CNN, it does not follow that it can find weights that do so for all input samples in the dataset.

Section 3.4: The result of a single convolution operation is represented as the value of the center pixel Xc in a patch X, plus some quantity Nc (a function of the filter weights W and the neighboring pixels of Xc): Y = XW = Xc + Nc. The claim is: "Establishing that Nc behaves like additive perturbation noise, will allows us to relate the CNN formulation to the PNN formulation".

Even if Nc statistically behaves like random noise, that does not mean it can be replaced with random noise. The random noise in PNN does not depend on the values of neighboring pixels in the patch, unlike Nc in a regular convolution. A PNN layer lacks the main feature-extraction property of a regular convolution: it cannot directly match any spatial pattern with a filter (see the short numerical sketch at the end of this README).

## Conclusion
It appears that perturbing layer inputs with noise does not provide any significant benefit. Simple 1x1 convolutions without noise masks provide similar performance. No matter how we apply noise masks, the accuracy drop resulting from using 1x1 filters is severe (~5% on CIFAR-10, even when the first layer is left unmodified). The results published by the authors are invalid due to the incorrect accuracy calculation method.
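To illustrate the Section 3.4 point numerically, here is a small sketch with made-up values. The center filter weight is set to 1 so that Y = Xc + Nc holds exactly (my reading of the paper's decomposition); the point is that Nc is a deterministic function of the neighborhood, so changing a neighboring pixel changes Nc, whereas PNN's pre-sampled noise never looks at the input:

```
import torch
import torch.nn.functional as F

torch.manual_seed(0)
patch = torch.randn(1, 1, 3, 3)  # one 3x3 input patch X
w = torch.randn(1, 1, 3, 3)      # one 3x3 filter W
w[0, 0, 1, 1] = 1.0              # center weight = 1, so Y = Xc + Nc exactly

y = F.conv2d(patch, w).item()    # single convolution response Y
xc = patch[0, 0, 1, 1].item()    # center pixel Xc
nc = y - xc                      # Nc: the neighbors' contribution
print('Y = {:.3f}, Xc = {:.3f}, Nc = {:.3f}'.format(y, xc, nc))

# Perturb one neighbor: Nc changes with the input neighborhood, unlike
# PNN's noise masks, which are sampled once and are independent of x.
patch2 = patch.clone()
patch2[0, 0, 0, 0] += 1.0
nc2 = F.conv2d(patch2, w).item() - patch2[0, 0, 1, 1].item()
print('after changing one neighbor: Nc = {:.3f}'.format(nc2))
```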
93 | 94 | -------------------------------------------------------------------------------- /__pycache__/checkpoints.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/michaelklachko/pnn.pytorch/08865dbac326f2f1537c6989932bce477e448b67/__pycache__/checkpoints.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/config.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/michaelklachko/pnn.pytorch/08865dbac326f2f1537c6989932bce477e448b67/__pycache__/config.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/dataloader.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/michaelklachko/pnn.pytorch/08865dbac326f2f1537c6989932bce477e448b67/__pycache__/dataloader.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/michaelklachko/pnn.pytorch/08865dbac326f2f1537c6989932bce477e448b67/__pycache__/model.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/train.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/michaelklachko/pnn.pytorch/08865dbac326f2f1537c6989932bce477e448b67/__pycache__/train.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/michaelklachko/pnn.pytorch/08865dbac326f2f1537c6989932bce477e448b67/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /dataloader.py: -------------------------------------------------------------------------------- 1 | # dataloader.py 2 | 3 | import os 4 | import torch 5 | import datasets 6 | #import torch.utils.data 7 | import torchvision.transforms as transforms 8 | 9 | class Dataloader: 10 | 11 | def __init__(self, args, input_size): 12 | self.args = args 13 | 14 | self.dataset_test_name = args.dataset_test 15 | self.dataset_train_name = args.dataset_train 16 | self.input_size = input_size 17 | 18 | if self.dataset_train_name == 'LSUN': 19 | self.dataset_train = getattr(datasets, self.dataset_train_name)(db_path=args.dataroot, classes=['bedroom_train'], 20 | transform=transforms.Compose([ 21 | transforms.Scale(self.input_size), 22 | transforms.CenterCrop(self.input_size), 23 | transforms.ToTensor(), 24 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 25 | ]) 26 | ) 27 | 28 | elif self.dataset_train_name == 'CIFAR10' or self.dataset_train_name == 'CIFAR100': 29 | self.dataset_train = getattr(datasets, self.dataset_train_name)(root=self.args.dataroot, train=True, download=True, 30 | transform=transforms.Compose([ 31 | transforms.RandomCrop(self.input_size, padding=4), 32 | transforms.RandomHorizontalFlip(), 33 | transforms.ToTensor(), 34 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 35 | ]) 36 | ) 37 | 38 | elif self.dataset_train_name == 'CocoCaption' or self.dataset_train_name == 'CocoDetection': 
39 | self.dataset_train = getattr(datasets, self.dataset_train_name)(root=self.args.dataroot, train=True, download=True, 40 | transform=transforms.Compose([ 41 | transforms.Scale(self.input_size), 42 | transforms.ToTensor(), 43 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 44 | ]) 45 | ) 46 | 47 | elif self.dataset_train_name == 'STL10' or self.dataset_train_name == 'SVHN': 48 | self.dataset_train = getattr(datasets, self.dataset_train_name)(root=self.args.dataroot, split='train', download=True, 49 | transform=transforms.Compose([ 50 | transforms.Scale(self.input_size), 51 | transforms.ToTensor(), 52 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 53 | ]) 54 | ) 55 | 56 | elif self.dataset_train_name == 'MNIST': 57 | self.dataset_train = getattr(datasets, self.dataset_train_name)(root=self.args.dataroot, train=True, download=True, 58 | transform=transforms.Compose([ 59 | transforms.ToTensor(), 60 | transforms.Normalize((0.1307,), (0.3081,)) 61 | ]) 62 | ) 63 | 64 | elif self.dataset_train_name == 'ImageNet': 65 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 66 | std=[0.229, 0.224, 0.225]) 67 | self.dataset_train = datasets.ImageFolder(root=os.path.join(self.args.dataroot,self.args.input_filename_train), 68 | transform=transforms.Compose([ 69 | transforms.RandomSizedCrop(224), 70 | transforms.RandomHorizontalFlip(), 71 | transforms.ToTensor(), 72 | normalize, 73 | ]) 74 | ) 75 | 76 | elif self.dataset_train_name == 'FRGC': 77 | self.dataset_train = datasets.ImageFolder(root=self.args.dataroot+self.args.input_filename_train, 78 | transform=transforms.Compose([ 79 | transforms.Scale(self.input_size), 80 | transforms.ToTensor(), 81 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 82 | ]) 83 | ) 84 | 85 | elif self.dataset_train_name == 'Folder': 86 | self.dataset_train = datasets.ImageFolder(root=self.args.dataroot+self.args.input_filename_train, 87 | transform=transforms.Compose([ 88 | transforms.Scale(self.input_size), 89 | transforms.ToTensor(), 90 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 91 | ]) 92 | ) 93 | 94 | elif self.dataset_train_name == 'FileList': 95 | self.dataset_train = datasets.FileList(self.input_filename_train, self.label_filename_train, self.split_train, 96 | self.split_test, train=True, 97 | transform_train=transforms.Compose([ 98 | transforms.Scale(self.input_size), 99 | transforms.CenterCrop(self.input_size), 100 | transforms.ToTensor(), 101 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 102 | ]), 103 | transform_test=transforms.Compose([ 104 | transforms.Scale(self.input_size), 105 | transforms.CenterCrop(self.input_size), 106 | transforms.ToTensor(), 107 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 108 | ]), 109 | loader_input=self.loader_input, 110 | loader_label=self.loader_label, 111 | ) 112 | 113 | elif self.dataset_train_name == 'FolderList': 114 | self.dataset_train = datasets.FileList(self.input_filename_train, self.label_filename_train, self.split_train, 115 | self.split_test, train=True, 116 | transform_train=transforms.Compose([ 117 | transforms.Scale(self.input_size), 118 | transforms.CenterCrop(self.input_size), 119 | transforms.ToTensor(), 120 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 121 | ]), 122 | transform_test=transforms.Compose([ 123 | transforms.Scale(self.input_size), 124 | transforms.CenterCrop(self.input_size), 125 | transforms.ToTensor(), 126 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 127 | ]), 128 | 
loader_input=self.loader_input, 129 | loader_label=self.loader_label, 130 | ) 131 | 132 | else: 133 | raise(Exception("Unknown Dataset")) 134 | 135 | if self.dataset_test_name == 'LSUN': 136 | self.dataset_test = getattr(datasets, self.dataset_test_name)(db_path=args.dataroot, classes=['bedroom_val'], 137 | transform=transforms.Compose([ 138 | transforms.Scale(self.input_size), 139 | transforms.CenterCrop(self.input_size), 140 | transforms.ToTensor(), 141 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 142 | ]) 143 | ) 144 | 145 | elif self.dataset_test_name == 'CIFAR10' or self.dataset_test_name == 'CIFAR100': 146 | self.dataset_test = getattr(datasets, self.dataset_test_name)(root=self.args.dataroot, train=False, download=True, 147 | transform=transforms.Compose([ 148 | transforms.ToTensor(), 149 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 150 | ]) 151 | ) 152 | 153 | elif self.dataset_test_name == 'CocoCaption' or self.dataset_test_name == 'CocoDetection': 154 | self.dataset_test = getattr(datasets, self.dataset_test_name)(root=self.args.dataroot, train=False, download=True, 155 | transform=transforms.Compose([ 156 | transforms.Scale(self.input_size), 157 | transforms.ToTensor(), 158 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 159 | ]) 160 | ) 161 | 162 | elif self.dataset_test_name == 'STL10' or self.dataset_test_name == 'SVHN': 163 | self.dataset_test = getattr(datasets, self.dataset_test_name)(root=self.args.dataroot, split='test', download=True, 164 | transform=transforms.Compose([ 165 | transforms.Scale(self.input_size), 166 | transforms.ToTensor(), 167 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 168 | ]) 169 | ) 170 | 171 | elif self.dataset_test_name == 'MNIST': 172 | self.dataset_test = getattr(datasets, self.dataset_test_name)(root=self.args.dataroot, train=False, download=True, 173 | transform=transforms.Compose([ 174 | transforms.ToTensor(), 175 | transforms.Normalize((0.1307,), (0.3081,)) 176 | ]) 177 | ) 178 | 179 | elif self.dataset_test_name == 'ImageNet': 180 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 181 | std=[0.229, 0.224, 0.225]) 182 | self.dataset_test = datasets.ImageFolder(root=os.path.join(self.args.dataroot,self.args.input_filename_test), 183 | transform=transforms.Compose([ 184 | transforms.Scale(256), 185 | transforms.CenterCrop(224), 186 | transforms.ToTensor(), 187 | normalize, 188 | ]) 189 | ) 190 | 191 | elif self.dataset_test_name == 'FRGC': 192 | self.dataset_test = datasets.ImageFolder(root=self.args.dataroot+self.args.input_filename_test, 193 | transform=transforms.Compose([ 194 | transforms.Scale(self.input_size), 195 | transforms.ToTensor(), 196 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 197 | ]) 198 | ) 199 | 200 | elif self.dataset_test_name == 'Folder': 201 | self.dataset_test = datasets.ImageFolder(root=self.args.dataroot+self.args.input_filename_test, 202 | transform=transforms.Compose([ 203 | transforms.Scale(self.input_size), 204 | transforms.ToTensor(), 205 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 206 | ]) 207 | ) 208 | 209 | elif self.dataset_test_name == 'FileList': 210 | self.dataset_test = datasets.FileList(self.input_filename_test, self.label_filename_test, self.split_train, 211 | self.split_test, train=True, 212 | transform_train=transforms.Compose([ 213 | transforms.Scale(self.input_size), 214 | transforms.CenterCrop(self.input_size), 215 | transforms.ToTensor(), 216 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 
0.5, 0.5)), 217 | ]), 218 | loader_input=self.loader_input, 219 | loader_label=self.loader_label, 220 | ) 221 | 222 | elif self.dataset_test_name == 'FolderList': 223 | self.dataset_test = datasets.FileList(self.input_filename_test, self.label_filename_test, self.split_train, 224 | self.split_test, train=True, 225 | transform_train=transforms.Compose([ 226 | transforms.Scale(self.input_size), 227 | transforms.CenterCrop(self.input_size), 228 | transforms.ToTensor(), 229 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 230 | ]), 231 | loader_input=self.loader_input, 232 | loader_label=self.loader_label, 233 | ) 234 | 235 | else: 236 | raise(Exception("Unknown Dataset")) 237 | 238 | def create(self, flag=None): 239 | if flag == "Train": 240 | dataloader_train = torch.utils.data.DataLoader(self.dataset_train, batch_size=self.args.batch_size, 241 | shuffle=True, num_workers=int(self.args.nthreads), pin_memory=True) 242 | return dataloader_train 243 | 244 | if flag == "Test": 245 | dataloader_test = torch.utils.data.DataLoader(self.dataset_test, batch_size=self.args.batch_size, 246 | shuffle=False, num_workers=int(self.args.nthreads), pin_memory=True) 247 | return dataloader_test 248 | 249 | if flag == None: 250 | dataloader_train = torch.utils.data.DataLoader(self.dataset_train, batch_size=self.args.batch_size, 251 | shuffle=True, num_workers=int(self.args.nthreads), pin_memory=True) 252 | 253 | dataloader_test = torch.utils.data.DataLoader(self.dataset_test, batch_size=self.args.batch_size, 254 | shuffle=False, num_workers=int(self.args.nthreads), pin_memory=True) 255 | return dataloader_train, dataloader_test -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # __init__.py 2 | 3 | from torchvision.datasets import * 4 | from .filelist import FileList 5 | from .folderlist import FolderList 6 | #from .svhn import SVHN 7 | -------------------------------------------------------------------------------- /datasets/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/michaelklachko/pnn.pytorch/08865dbac326f2f1537c6989932bce477e448b67/datasets/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/michaelklachko/pnn.pytorch/08865dbac326f2f1537c6989932bce477e448b67/datasets/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/filelist.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/michaelklachko/pnn.pytorch/08865dbac326f2f1537c6989932bce477e448b67/datasets/__pycache__/filelist.cpython-35.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/filelist.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/michaelklachko/pnn.pytorch/08865dbac326f2f1537c6989932bce477e448b67/datasets/__pycache__/filelist.cpython-36.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/folderlist.cpython-35.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/michaelklachko/pnn.pytorch/08865dbac326f2f1537c6989932bce477e448b67/datasets/__pycache__/folderlist.cpython-35.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/folderlist.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/michaelklachko/pnn.pytorch/08865dbac326f2f1537c6989932bce477e448b67/datasets/__pycache__/folderlist.cpython-36.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/loaders.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/michaelklachko/pnn.pytorch/08865dbac326f2f1537c6989932bce477e448b67/datasets/__pycache__/loaders.cpython-35.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/loaders.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/michaelklachko/pnn.pytorch/08865dbac326f2f1537c6989932bce477e448b67/datasets/__pycache__/loaders.cpython-36.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/svhn.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/michaelklachko/pnn.pytorch/08865dbac326f2f1537c6989932bce477e448b67/datasets/__pycache__/svhn.cpython-35.pyc -------------------------------------------------------------------------------- /datasets/filelist.py: -------------------------------------------------------------------------------- 1 | # filelist.py 2 | 3 | import torch 4 | import numpy as np 5 | import utils as utils 6 | import torch.utils.data as data 7 | from sklearn.utils import shuffle 8 | import datasets.loaders as loaders 9 | import torchvision.transforms as transforms 10 | 11 | class FileList(data.Dataset): 12 | def __init__(self, ifile, lfile=None, split_train=1.0, split_test=0.0, train=True, 13 | transform_train=None, transform_test=None, loader_input=loaders.loader_image, loader_label=loaders.loader_torch): 14 | 15 | self.ifile = ifile 16 | self.lfile = lfile 17 | self.train = train 18 | self.split_test = split_test 19 | self.split_train = split_train 20 | self.transform_test = transform_test 21 | self.transform_train = transform_train 22 | 23 | self.loader_input = loader_input 24 | self.loader_label = loader_label 25 | 26 | if loader_input == 'image': 27 | self.loader_input = loaders.loader_image 28 | if loader_input == 'torch': 29 | self.loader_input = loaders.loader_torch 30 | if loader_input == 'numpy': 31 | self.loader_input = loaders.loader_numpy 32 | 33 | if loader_label == 'image': 34 | self.loader_label = loaders.loader_image 35 | if loader_label == 'torch': 36 | self.loader_label = loaders.loader_torch 37 | if loader_label == 'numpy': 38 | self.loader_label = loaders.loader_numpy 39 | 40 | if ifile != None: 41 | imagelist = utils.readtextfile(ifile) 42 | imagelist = [x.rstrip('\n') for x in imagelist] 43 | else: 44 | imagelist = [] 45 | 46 | if lfile != None: 47 | labellist = utils.readtextfile(lfile) 48 | labellist = [x.rstrip('\n') for x in labellist] 49 | else: 50 | labellist = [] 51 | 52 | if len(imagelist) == len(labellist): 53 | shuffle(imagelist, labellist) 54 | 55 | if len(imagelist) > 0 and 
len(labellist) == 0: 56 | shuffle(imagelist) 57 | 58 | if len(labellist) > 0 and len(imagelist) == 0: 59 | shuffle(labellist) 60 | 61 | if (self.split_train < 1.0) & (self.split_train > 0.0): 62 | if len(imagelist) > 0: 63 | num = math.floor(self.split*len(imagelist)) 64 | self.images_train = imagelist[0:num] 65 | self.images_test = images[num+1:len(imagelist)] 66 | if len(labellist) > 0: 67 | num = math.floor(self.split*len(labellist)) 68 | self.labels_train = labellist[0:num] 69 | self.labels_test = labellist[num+1:len(labellist)] 70 | 71 | elif self.split_train == 1.0: 72 | if len(imagelist) > 0: 73 | self.images_train = imagelist 74 | if len(labellist) > 0: 75 | self.labels_train = labellist 76 | 77 | elif self.split_test == 1.0: 78 | if len(imagelist) > 0: 79 | self.images_test = imagelist 80 | if len(labellist) > 0: 81 | self.labels_test = labellist 82 | 83 | def __len__(self): 84 | if self.train == True: 85 | return len(self.images_train) 86 | if self.train == False: 87 | return len(self.images_test) 88 | 89 | def __getitem__(self, index): 90 | input = {} 91 | if self.train == True: 92 | if len(self.images_train) > 0: 93 | path = self.images_train[index] 94 | input['inp'] = self.loader_input(path) 95 | 96 | if len(self.labels_train) > 0: 97 | path = self.labels_train[index] 98 | input['tgt'] = self.loader_label(path) 99 | 100 | if self.transform_train is not None: 101 | input = self.transform_train(input) 102 | 103 | image = input['inp'] 104 | label = input['tgt'] 105 | 106 | if self.train == False: 107 | if len(self.images_test) > 0: 108 | path = self.images_test[index] 109 | input['inp'] = self.loader_input(path) 110 | 111 | if len(self.labels_test) > 0: 112 | path = self.labels_test[index] 113 | input['tgt'] = self.loader_label(path) 114 | 115 | if self.transform_test is not None: 116 | input = self.transform_test(input) 117 | 118 | image = input['inp'] 119 | label = input['tgt'] 120 | 121 | return image, label -------------------------------------------------------------------------------- /datasets/folderlist.py: -------------------------------------------------------------------------------- 1 | # folderlist.py 2 | 3 | import os 4 | import os.path 5 | import utils as utils 6 | import torch.utils.data as data 7 | import datasets.loaders as loaders 8 | 9 | IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG','.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP'] 10 | 11 | def is_image_file(filename): 12 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 13 | 14 | def make_dataset(classlist, labellist=None): 15 | images = [] 16 | labels = [] 17 | classes = utils.readtextfile(ifile) 18 | classes = [x.rstrip('\n') for x in classes] 19 | classes.sort() 20 | 21 | for i in len(classes): 22 | for fname in os.listdir(classes[i]): 23 | if is_image_file(fname): 24 | label = {} 25 | label['class'] = os.path.split(classes[i]) 26 | images.append(fname) 27 | labels.append(label) 28 | 29 | if labellist != None: 30 | labels = utils.readtextfile(ifile) 31 | labels = [x.rstrip('\n') for x in labels] 32 | labels.sort() 33 | for i in len(labels): 34 | for fname in os.listdir(labels[i]): 35 | if is_image_file(fname): 36 | labels.append(os.path.split(classes[i])) 37 | 38 | return images, labels 39 | 40 | class FolderList(data.Dataset): 41 | def __init__(self, ifile, lfile=None, split_train=1.0, split_test=0.0, train=True, 42 | transform_train=None, transform_test=None, loader_input=loaders.loader_image, loader_label=loaders.loader_torch): 43 | 44 | imagelist, labellist = 
make_dataset(ifile, lfile) 45 | if len(imagelist) == 0: 46 | raise(RuntimeError("No images found")) 47 | if len(labellist) == 0: 48 | raise(RuntimeError("No labels found")) 49 | 50 | self.loader_input = loader_input 51 | self.loader_label = loader_label 52 | 53 | if loader_input == 'image': 54 | self.loader_input = loaders.loader_image 55 | if loader_input == 'torch': 56 | self.loader_input = loaders.loader_torch 57 | if loader_input == 'numpy': 58 | self.loader_input = loaders.loader_numpy 59 | 60 | if loader_label == 'image': 61 | self.loader_label = loaders.loader_image 62 | if loader_label == 'torch': 63 | self.loader_label = loaders.loader_torch 64 | if loader_label == 'numpy': 65 | self.loader_label = loaders.loader_numpy 66 | 67 | self.imagelist = imagelist 68 | self.labellist = labellist 69 | self.transform_test = transform_test 70 | self.transform_train = transform_train 71 | 72 | if len(imagelist) == len(labellist): 73 | shuffle(imagelist, labellist) 74 | 75 | if len(imagelist) > 0 and len(labellist) == 0: 76 | shuffle(imagelist) 77 | 78 | if len(labellist) > 0 and len(imagelist) == 0: 79 | shuffle(labellist) 80 | 81 | if (args.split_train < 1.0) & (args.split_train > 0.0): 82 | if len(imagelist) > 0: 83 | num = math.floor(args.split*len(imagelist)) 84 | self.images_train = imagelist[0:num] 85 | self.images_test = images[num+1:len(imagelist)] 86 | if len(labellist) > 0: 87 | num = math.floor(args.split*len(labellist)) 88 | self.labels_train = labellist[0:num] 89 | self.labels_test = labellist[num+1:len(labellist)] 90 | 91 | elif args.split_train == 1.0: 92 | if len(imagelist) > 0: 93 | self.images_train = imagelist 94 | if len(labellist) > 0: 95 | self.labels_train = labellist 96 | 97 | elif args.split_test == 1.0: 98 | if len(imagelist) > 0: 99 | self.images_test = imagelist 100 | if len(labellist) > 0: 101 | self.labels_test = labellist 102 | 103 | def __len__(self): 104 | if self.train == True: 105 | return len(self.images_train) 106 | if self.train == False: 107 | return len(self.images_test) 108 | 109 | def __getitem__(self, index): 110 | if self.train == True: 111 | if len(self.images_train) > 0: 112 | path = self.images_train[index] 113 | input['inp'] = self.loader_input(path) 114 | 115 | if len(self.labels_train) > 0: 116 | path = self.labels_train[index] 117 | input['tgt'] = self.loader_label(path) 118 | 119 | if self.transform_train is not None: 120 | input = self.transform_train(input) 121 | 122 | image = input['inp'] 123 | label = input['tgt'] 124 | 125 | if self.train == False: 126 | if len(self.images_test) > 0: 127 | path = self.images_test[index] 128 | input['inp'] = self.loader_input(path) 129 | 130 | if len(self.labels_test) > 0: 131 | path = self.labels_test[index] 132 | input['tgt'] = self.loader_label(path) 133 | 134 | if self.transform_test is not None: 135 | input = self.transform_test(input) 136 | 137 | image = input['inp'] 138 | label = input['tgt'] 139 | 140 | return image, label -------------------------------------------------------------------------------- /datasets/loaders.py: -------------------------------------------------------------------------------- 1 | # loaders.py 2 | 3 | import torch 4 | import numpy as np 5 | from PIL import Image 6 | 7 | def loader_image(path): 8 | return Image.open(path).convert('RGB') 9 | 10 | def loader_torch(path): 11 | return torch.load(path) 12 | 13 | def loader_numpy(path): 14 | return np.load(path) -------------------------------------------------------------------------------- /datasets/transforms.py: 
-------------------------------------------------------------------------------- 1 | # transforms.py 2 | 3 | from __future__ import division 4 | 5 | import math 6 | import types 7 | import torch 8 | import utils 9 | import random 10 | import numbers 11 | import numpy as np 12 | import scipy as sp 13 | from scipy import misc 14 | from PIL import Image, ImageOps, ImageDraw 15 | 16 | class Compose(object): 17 | """Composes several transforms together. 18 | 19 | Args: 20 | transforms (List[Transform]): list of transforms to compose. 21 | 22 | Example: 23 | >>> transforms.Compose([ 24 | >>> transforms.CenterCrop(10), 25 | >>> transforms.ToTensor(), 26 | >>> ]) 27 | """ 28 | def __init__(self, transforms): 29 | self.transforms = transforms 30 | 31 | def __call__(self, input): 32 | for t in self.transforms: 33 | input = t(input) 34 | 35 | return input 36 | 37 | 38 | 39 | class ToTensor(object): 40 | """Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range 41 | [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].""" 42 | 43 | def __call__(self, input): 44 | for key in input.keys(): 45 | value = input[key] 46 | if isinstance(value, np.ndarray): 47 | # handle numpy array 48 | input[key] = torch.from_numpy(value) 49 | else: 50 | # handle PIL Image 51 | tmp = torch.ByteTensor(torch.ByteStorage.from_buffer(value.tobytes())) 52 | value = tmp.view(value.size[1], value.size[0], len(value.mode)) 53 | # put it from HWC to CHW format 54 | # yikes, this transpose takes 80% of the loading time/CPU 55 | value = value.transpose(0, 1).transpose(0, 2).contiguous() 56 | input[key] = value.float().div(255) 57 | return input 58 | 59 | 60 | 61 | class ToPILImage(object): 62 | """Converts a torch.*Tensor of range [0, 1] and shape C x H x W 63 | or numpy ndarray of dtype=uint8, range[0, 255] and shape H x W x C 64 | to a PIL.Image of range [0, 255] 65 | """ 66 | def __call__(self, input): 67 | if isinstance(input['img'], np.ndarray): 68 | # handle numpy array 69 | input['img'] = Image.fromarray(input['img']) 70 | else: 71 | npimg = input['img'].mul(255).byte().numpy() 72 | npimg = np.transpose(npimg, (1,2,0)) 73 | input['img'] = Image.fromarray(npimg) 74 | return input 75 | 76 | class Normalize(object): 77 | """Given mean: (R, G, B) and std: (R, G, B), 78 | will normalize each channel of the torch.*Tensor, i.e. 79 | channel = (channel - mean) / std 80 | """ 81 | def __init__(self, mean, std): 82 | self.mean = mean 83 | self.std = std 84 | 85 | def __call__(self, input): 86 | # TODO: make efficient 87 | for t, m, s in zip(input['img'], self.mean, self.std): 88 | t.sub_(m).div_(s) 89 | return input 90 | 91 | class Scale(object): 92 | """Rescales the input PIL.Image to the given 'size'. 93 | 'size' will be the size of the smaller edge. 
94 | For example, if height > width, then image will be 95 | rescaled to (size * height / width, size) 96 | size: size of the smaller edge 97 | interpolation: Default: PIL.Image.BILINEAR 98 | """ 99 | def __init__(self, size, interpolation=Image.BILINEAR): 100 | self.size = size 101 | self.interpolation = interpolation 102 | 103 | def __call__(self, input): 104 | w, h = input['img'].size 105 | if (w <= h and w == self.size) or (h <= w and h == self.size): 106 | return input 107 | if w < h: 108 | ow = self.size 109 | oh = int(self.size * h / w) 110 | input['img'] = input['img'].resize((ow, oh), self.interpolation) 111 | return input 112 | else: 113 | oh = self.size 114 | ow = int(self.size * w / h) 115 | input['img'] = input['img'].resize((ow, oh), self.interpolation) 116 | return input 117 | 118 | 119 | class CenterCrop(object): 120 | """Crops the given PIL.Image at the center to have a region of 121 | the given size. size can be a tuple (target_height, target_width) 122 | or an integer, in which case the target will be of a square shape (size, size) 123 | """ 124 | def __init__(self, size): 125 | if isinstance(size, numbers.Number): 126 | self.size = (int(size), int(size)) 127 | else: 128 | self.size = size 129 | 130 | def __call__(self, input): 131 | w, h = input['img'].size 132 | th, tw = self.size 133 | x1 = int(round((w - tw) / 2.)) 134 | y1 = int(round((h - th) / 2.)) 135 | input['img'] = input['img'].crop((x1, y1, x1 + tw, y1 + th)) 136 | return input 137 | 138 | class Pad(object): 139 | """Pads the given PIL.Image on all sides with the given "pad" value""" 140 | def __init__(self, padding, fill=0): 141 | assert isinstance(padding, numbers.Number) 142 | assert isinstance(fill, numbers.Number) 143 | self.padding = padding 144 | self.fill = fill 145 | 146 | def __call__(self, input): 147 | input['img'] = ImageOps.expand(input['img'], border=self.padding, fill=self.fill) 148 | return input 149 | 150 | class Lambda(object): 151 | """Applies a lambda as a transform.""" 152 | def __init__(self, lambd): 153 | assert type(lambd) is types.LambdaType 154 | self.lambd = lambd 155 | 156 | def __call__(self, input): 157 | input['img'] = self.lambd(img) 158 | return input 159 | 160 | 161 | class RandomCrop(object): 162 | """Crops the given PIL.Image at a random location to have a region of 163 | the given size. 
size can be a tuple (target_height, target_width) 164 | or an integer, in which case the target will be of a square shape (size, size) 165 | """ 166 | def __init__(self, size, padding=0): 167 | if isinstance(size, numbers.Number): 168 | self.size = (int(size), int(size)) 169 | else: 170 | self.size = size 171 | self.padding = padding 172 | 173 | def __call__(self, input): 174 | if self.padding > 0: 175 | input['img'] = ImageOps.expand(img, border=self.padding, fill=0) 176 | 177 | w, h = input['img'].size 178 | th, tw = self.size 179 | if w == tw and h == th: 180 | return input 181 | 182 | x1 = random.randint(0, w - tw) 183 | y1 = random.randint(0, h - th) 184 | input['img'] = input['img'].crop((x1, y1, x1 + tw, y1 + th)) 185 | return input 186 | 187 | 188 | class RandomHorizontalFlip(object): 189 | """Randomly horizontally flips the given PIL.Image with a probability of 0.5 190 | """ 191 | def __call__(self, input): 192 | if random.random() < 0.5: 193 | input['img'] = input['img'].transpose(Image.FLIP_LEFT_RIGHT) 194 | input['tgt'] = input['tgt'].transpose(Image.FLIP_LEFT_RIGHT) 195 | input['loc'][0] = input['loc'][0] - math.ceil(input['img'].size[0]/2) 196 | return input 197 | 198 | class RandomSizedCrop(object): 199 | """Random crop the given PIL.Image to a random size of (0.08 to 1.0) of the original size 200 | and and a random aspect ratio of 3/4 to 4/3 of the original aspect ratio 201 | This is popularly used to train the Inception networks 202 | size: size of the smaller edge 203 | interpolation: Default: PIL.Image.BILINEAR 204 | """ 205 | def __init__(self, size, interpolation=Image.BILINEAR): 206 | self.size = size 207 | self.interpolation = interpolation 208 | 209 | def __call__(self, input): 210 | for attempt in range(10): 211 | area = input['img'].size[0] * input['img'].size[1] 212 | target_area = random.uniform(0.08, 1.0) * area 213 | aspect_ratio = random.uniform(3. / 4, 4. / 3) 214 | 215 | w = int(round(math.sqrt(target_area * aspect_ratio))) 216 | h = int(round(math.sqrt(target_area / aspect_ratio))) 217 | 218 | if random.random() < 0.5: 219 | w, h = h, w 220 | 221 | if w <= input['img'].size[0] and h <= input['img'].size[1]: 222 | x1 = random.randint(0, input['img'].size[0] - w) 223 | y1 = random.randint(0, input['img'].size[1] - h) 224 | 225 | input['img'] = input['img'].crop((x1, y1, x1 + w, y1 + h)) 226 | assert(input['img'].size == (w, h)) 227 | input['img'] = input['img'].resize((self.size, self.size), self.interpolation) 228 | return input 229 | 230 | # Fallback 231 | scale = Scale(self.size, interpolation=self.interpolation) 232 | crop = CenterCrop(self.size) 233 | return crop(scale(input)) 234 | 235 | class NormalizeLandmarks(object): 236 | """ max-min normalization of landmarks to range [-1,1]""" 237 | def __init__(self,xsize,ysize): 238 | self.xsize = xsize 239 | self.ysize = ysize 240 | 241 | def __call__(self, input): 242 | valid_points = [v for v in input['loc'] if v[0] != 0 and v[1] != 0] 243 | mean = np.mean(valid_points,axis = 0) 244 | for i in range(input['loc'].shape[0]): 245 | input['loc'][i][0] = -1 + (input['loc'][i][0] * 2. )/(inputx_res) 246 | input['loc'][i][1] = -1 + (input['loc'][i][1] * 2. 
)/(inputy_res) 247 | 248 | return input 249 | 250 | class AffineCrop(object): 251 | def __init__(self,nlandmark,ix,iy,ox,oy,rangle=0,rscale=0,rtrans=0,gauss=1): 252 | self.rangle=rangle 253 | self.rscale=rscale 254 | self.rtrans=rtrans 255 | self.nlandmark=nlandmark 256 | self.ix = ix 257 | self.iy = iy 258 | self.ox = ox 259 | self.oy = oy 260 | self.utils = utils 261 | self.gauss = gauss 262 | 263 | def __call__(self, input): 264 | 265 | angle = self.rangle*(2*torch.rand(1)[0] - 1) 266 | grad_angle = angle * math.pi / 180 267 | scale = 1+self.rscale*(2*torch.rand(1)[0] - 1) 268 | transx = self.rtrans*(2*torch.rand(1)[0] - 1) 269 | transy = self.rtrans*(2*torch.rand(1)[0] - 1) 270 | 271 | img = input['img'] 272 | size = img.size 273 | h, w = size[0], size[1] 274 | centerX, centerY = int(w/2), int(h/2) 275 | 276 | # perform rotation 277 | img = img.rotate(angle, Image.BICUBIC) 278 | # perform translation 279 | img = img.transform(img.size, Image.AFFINE, (1, 0, transx, 0, 1, transy)) 280 | # perform scaling 281 | img = img.resize((int(math.ceil(scale*h)) , int(math.ceil(scale*w))) , Image.ANTIALIAS) 282 | 283 | w, h = img.size 284 | x1 = int(round((w - self.ix) / 2.)) 285 | y1 = int(round((h - self.ix) / 2.)) 286 | input['img'] = img.crop((x1, y1, x1 + self.ix, y1 + self.iy)) 287 | 288 | if (np.sum(input['loc']) != 0): 289 | 290 | occ = input['occ'] 291 | loc = input['loc'] 292 | newloc = np.ones((3,loc.shape[1]+1)) 293 | newloc[0:2,0:loc.shape[1]] = loc 294 | newloc[0,loc.shape[1]] = centerY 295 | newloc[1,loc.shape[1]] = centerX 296 | 297 | trans_matrix = np.array([[1,0,-1*transx],[0,1,-1*transy],[0,0,1]]) 298 | scale_matrix = np.array([[scale,0,0],[0,scale,0],[0,0,1]]) 299 | angle_matrix = np.array([[math.cos(grad_angle),math.sin(grad_angle),0],[-math.sin(grad_angle),math.cos(grad_angle),0],[0,0,1]]) 300 | 301 | # perform rotation 302 | newloc[0,:] = newloc[0,:] - centerY 303 | newloc[1,:] = newloc[1,:] - centerX 304 | newloc = np.dot(angle_matrix, newloc) 305 | newloc[0,:] = newloc[0,:] + centerY 306 | newloc[1,:] = newloc[1,:] + centerX 307 | # perform translation 308 | newloc = np.dot(trans_matrix, newloc) 309 | # perform scaling 310 | newloc = np.dot(scale_matrix, newloc) 311 | 312 | newloc[0,:] = newloc[0,:] - y1 313 | newloc[1,:] = newloc[1,:] - x1 314 | input['loc'] = newloc[0:2,:] 315 | 316 | for i in range(input['loc'].shape[1]): 317 | if ~((input['loc'][0,i] == np.nan) & (input['loc'][1,i] == np.nan)): 318 | if ((input['loc'][0,i] < 0) | (input['loc'][0,i] > self.iy) | (input['loc'][1,i] < 0) | (input['loc'][1,i] > self.ix)): 319 | input['loc'][:,i] = np.nan 320 | input['occ'][i] = 0 321 | 322 | # generate heatmaps 323 | input['tgt'] = np.zeros((self.nlandmark+1, self.ox, self.oy)) 324 | for i in range(self.nlandmark): 325 | if (not np.isnan(input['loc'][:,i][0]) and not np.isnan(input['loc'][:,i][1])): 326 | tmp = self.utils.gaussian(np.array([self.ix,self.iy]),input['loc'][:,i],self.gauss) 327 | scaled_tmp = sp.misc.imresize(tmp, [self.ox, self.oy]) 328 | scaled_tmp = (scaled_tmp - min(scaled_tmp.flatten()) ) / ( max(scaled_tmp.flatten()) - min(scaled_tmp.flatten())) 329 | else: 330 | scaled_tmp = np.zeros([self.ox,self.oy]) 331 | input['tgt'][i] = scaled_tmp 332 | 333 | tmp = self.utils.gaussian(np.array([self.iy,self.ix]),input['loc'][:,-1],4*self.gauss) 334 | scaled_tmp = sp.misc.imresize(tmp, [self.ox, self.oy]) 335 | scaled_tmp = (scaled_tmp - min(scaled_tmp.flatten()) ) / ( max(scaled_tmp.flatten()) - min(scaled_tmp.flatten())) 336 | input['tgt'][self.nlandmark] = 
scaled_tmp 337 | 338 | return input 339 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | from dataloader import Dataloader 4 | import utils 5 | import os 6 | from datetime import datetime 7 | import argparse 8 | import math 9 | import numpy as np 10 | from torch import nn 11 | import models 12 | import torch.optim as optim 13 | 14 | result_path = "results/" 15 | result_path = os.path.join(result_path, datetime.now().strftime('%Y-%m-%d_%H-%M-%S/')) 16 | 17 | parser = argparse.ArgumentParser(description='Your project title goes here') 18 | 19 | # ======================== Data Setings ============================================ 20 | parser.add_argument('--dataset-test', type=str, default='CIFAR10', metavar='', help='name of training dataset') 21 | parser.add_argument('--dataset-train', type=str, default='CIFAR10', metavar='', help='name of training dataset') 22 | parser.add_argument('--dataroot', type=str, default='./data', metavar='', help='path to the data') 23 | parser.add_argument('--save', type=str, default=result_path +'Save', metavar='', help='save the trained models here') 24 | parser.add_argument('--logs', type=str, default=result_path +'Logs', metavar='', help='save the training log files here') 25 | parser.add_argument('--resume', type=str, default=None, metavar='', help='full path of models to resume training') 26 | 27 | # ======================== Network Model Setings =================================== 28 | 29 | feature_parser = parser.add_mutually_exclusive_group(required=False) 30 | feature_parser.add_argument('--use_act', dest='use_act', action='store_true') 31 | feature_parser.add_argument('--no-use_act', dest='use_act', action='store_false') 32 | parser.set_defaults(use_act=False) 33 | 34 | feature_parser = parser.add_mutually_exclusive_group(required=False) 35 | feature_parser.add_argument('--unique_masks', dest='unique_masks', action='store_true') 36 | feature_parser.add_argument('--no-unique_masks', dest='unique_masks', action='store_false') 37 | parser.set_defaults(unique_masks=True) 38 | 39 | feature_parser = parser.add_mutually_exclusive_group(required=False) 40 | feature_parser.add_argument('--debug', dest='debug', action='store_true') 41 | feature_parser.add_argument('--no-debug', dest='debug', action='store_false') 42 | parser.set_defaults(debug=False) 43 | 44 | feature_parser = parser.add_mutually_exclusive_group(required=False) 45 | feature_parser.add_argument('--train_masks', dest='train_masks', action='store_true') 46 | feature_parser.add_argument('--no-train_masks', dest='train_masks', action='store_false') 47 | parser.set_defaults(train_masks=False) 48 | 49 | feature_parser = parser.add_mutually_exclusive_group(required=False) 50 | feature_parser.add_argument('--mix_maps', dest='mix_maps', action='store_true') 51 | feature_parser.add_argument('--no-mix_maps', dest='mix_maps', action='store_false') 52 | parser.set_defaults(mix_maps=False) 53 | 54 | parser.add_argument('--filter_size', type=int, default=0, metavar='', help='use conv layer with this kernel size in FirstLayer') 55 | parser.add_argument('--first_filter_size', type=int, default=0, metavar='', help='use conv layer with this kernel size in FirstLayer') 56 | parser.add_argument('--nfilters', type=int, default=64, metavar='', help='number of filters in each layer') 57 | parser.add_argument('--nmasks', type=int, default=1, metavar='', 
help='number of noise masks per input channel (fan out)') 58 | parser.add_argument('--level', type=float, default=0.5, metavar='', help='noise level for uniform noise') 59 | parser.add_argument('--scale_noise', type=float, default=1.0, metavar='', help='noise level for uniform noise') 60 | parser.add_argument('--noise_type', type=str, default='uniform', metavar='', help='type of noise') 61 | parser.add_argument('--dropout', type=float, default=0.5, metavar='', help='dropout parameter') 62 | parser.add_argument('--net-type', type=str, default='resnet18', metavar='', help='type of network') 63 | parser.add_argument('--act', type=str, default='relu', metavar='', help='activation function (for both perturb and conv layers)') 64 | parser.add_argument('--pool_type', type=str, default='max', metavar='', help='pooling function (max or avg)') 65 | 66 | # ======================== Training Settings ======================================= 67 | parser.add_argument('--batch-size', type=int, default=64, metavar='', help='batch size for training') 68 | parser.add_argument('--nepochs', type=int, default=150, metavar='', help='number of epochs to train') 69 | parser.add_argument('--nthreads', type=int, default=4, metavar='', help='number of threads for data loading') 70 | parser.add_argument('--manual-seed', type=int, default=1, metavar='', help='manual seed for randomness') 71 | 72 | # ======================== Hyperparameter Setings ================================== 73 | parser.add_argument('--optim-method', type=str, default='SGD', metavar='', help='the optimization routine ') 74 | parser.add_argument('--learning-rate', type=float, default=1e-3, metavar='', help='learning rate') 75 | parser.add_argument('--learning-rate-decay', type=float, default=None, metavar='', help='learning rate decay') 76 | parser.add_argument('--momentum', type=float, default=0.9, metavar='', help='momentum') 77 | parser.add_argument('--weight-decay', type=float, default=1e-4, metavar='', help='weight decay') 78 | parser.add_argument('--adam-beta1', type=float, default=0.9, metavar='', help='Beta 1 parameter for Adam') 79 | parser.add_argument('--adam-beta2', type=float, default=0.999, metavar='', help='Beta 2 parameter for Adam') 80 | 81 | args = parser.parse_args() 82 | random.seed(args.manual_seed) 83 | torch.manual_seed(args.manual_seed) 84 | utils.saveargs(args) 85 | 86 | class Model: 87 | def __init__(self, args): 88 | self.cuda = torch.cuda.is_available() 89 | self.lr = args.learning_rate 90 | self.dataset_train_name = args.dataset_train 91 | self.nfilters = args.nfilters 92 | self.batch_size = args.batch_size 93 | self.level = args.level 94 | self.net_type = args.net_type 95 | self.nmasks = args.nmasks 96 | self.unique_masks = args.unique_masks 97 | self.filter_size = args.filter_size 98 | self.first_filter_size = args.first_filter_size 99 | self.scale_noise = args.scale_noise 100 | self.noise_type = args.noise_type 101 | self.act = args.act 102 | self.use_act = args.use_act 103 | self.dropout = args.dropout 104 | self.train_masks = args.train_masks 105 | self.debug = args.debug 106 | self.pool_type = args.pool_type 107 | self.mix_maps = args.mix_maps 108 | 109 | if self.dataset_train_name.startswith("CIFAR"): 110 | self.input_size = 32 111 | self.nclasses = 10 112 | if self.filter_size < 7: 113 | self.avgpool = 4 114 | elif self.filter_size == 7: 115 | self.avgpool = 1 116 | 117 | elif self.dataset_train_name.startswith("MNIST"): 118 | self.nclasses = 10 119 | self.input_size = 28 120 | if self.filter_size < 7: 121 | 
self.avgpool = 14 #TODO
122 |         elif self.filter_size == 7:
123 |             self.avgpool = 7
124 | 
125 |         self.model = getattr(models, self.net_type)(
126 |             nfilters=self.nfilters,
127 |             avgpool=self.avgpool,
128 |             nclasses=self.nclasses,
129 |             nmasks=self.nmasks,
130 |             unique_masks=self.unique_masks,
131 |             level=self.level,
132 |             filter_size=self.filter_size,
133 |             first_filter_size=self.first_filter_size,
134 |             act=self.act,
135 |             scale_noise=self.scale_noise,
136 |             noise_type=self.noise_type,
137 |             use_act=self.use_act,
138 |             dropout=self.dropout,
139 |             train_masks=self.train_masks,
140 |             pool_type=self.pool_type,
141 |             debug=self.debug,
142 |             input_size=self.input_size,
143 |             mix_maps=self.mix_maps
144 |         )
145 | 
146 |         self.loss_fn = nn.CrossEntropyLoss()
147 | 
148 |         if self.cuda:
149 |             self.model = self.model.cuda()
150 |             self.loss_fn = self.loss_fn.cuda()
151 | 
152 |         parameters = filter(lambda p: p.requires_grad, self.model.parameters())
153 | 
154 |         if args.optim_method == 'Adam':
155 |             self.optimizer = optim.Adam(parameters, lr=self.lr, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.weight_decay)  #increase weight decay for no-noise large models
156 |         elif args.optim_method == 'RMSprop':
157 |             self.optimizer = optim.RMSprop(parameters, lr=self.lr, momentum=args.momentum, weight_decay=args.weight_decay)
158 |         elif args.optim_method == 'SGD':
159 |             self.optimizer = optim.SGD(parameters, lr=self.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
160 |             """
161 |             # use this to set different learning rates for training noise masks and regular parameters:
162 |             self.optimizer = optim.SGD([{'params': [param for name, param in self.model.named_parameters() if 'noise' not in name]},
163 |                                         {'params': [param for name, param in self.model.named_parameters() if 'noise' in name], 'lr': self.lr * 10},
164 |                                         ], lr=self.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True) #"""
165 |         else:
166 |             raise ValueError('Unknown optimization method: {}'.format(args.optim_method))
167 | 
168 | 
169 |     def learning_rate(self, epoch):
170 |         if self.dataset_train_name == 'CIFAR10':
171 |             new_lr = self.lr * ((0.2 ** int(epoch >= 60)) * (0.2 ** int(epoch >= 90)) * (0.2 ** int(epoch >= 120)) * (0.2 ** int(epoch >= 160)))
172 |         elif self.dataset_train_name == 'CIFAR100':
173 |             new_lr = self.lr * ((0.1 ** int(epoch >= 80)) * (0.1 ** int(epoch >= 120)) * (0.1 ** int(epoch >= 160)))
174 |         elif self.dataset_train_name == 'MNIST':
175 |             new_lr = self.lr * ((0.2 ** int(epoch >= 30)) * (0.2 ** int(epoch >= 60)) * (0.2 ** int(epoch >= 90)))
176 |         elif self.dataset_train_name == 'FRGC':
177 |             new_lr = self.lr * ((0.1 ** int(epoch >= 80)) * (0.1 ** int(epoch >= 120)) * (0.1 ** int(epoch >= 160)))
178 |         else:  # ImageNet schedule; also serves as a fallback so new_lr is always defined
179 |             decay = math.floor((epoch - 1) / 30)
180 |             new_lr = self.lr * math.pow(0.1, decay)
181 |         #print('\nReducing learning rate to {}\n'.format(new_lr))
182 |         return new_lr
183 | 
184 | 
185 |     def train(self, epoch, dataloader):
186 |         self.model.train()
187 | 
188 |         lr = self.learning_rate(epoch+1)
189 | 
190 |         for param_group in self.optimizer.param_groups:
191 |             #print(param_group)  #TODO figure out how to set diff learning rate to noise params if train_masks
192 |             param_group['lr'] = lr
193 | 
194 |         losses = []
195 |         accuracies = []
196 |         for i, (input, label) in enumerate(dataloader):
197 |             if self.cuda:
198 |                 label = label.cuda()
199 |                 input = input.cuda()
200 | 
201 |             output = self.model(input)
202 |             loss = self.loss_fn(output, label)
203 |             if self.debug:
204 |                 print('\nBatch:', i)
205 |             self.optimizer.zero_grad()
206 |             loss.backward()
207 |             self.optimizer.step()
208 | 
209 |             pred = output.data.max(1)[1]
210 | 
211 |             acc = pred.eq(label.data).cpu().sum().item() * 100.0 / input.size(0)  # use the actual batch size: the last batch can be smaller than self.batch_size
212 | 
213 |             losses.append(loss.item())
214 |             accuracies.append(acc)
215 | 
216 |         return np.mean(losses), np.mean(accuracies)
217 | 
218 |     def test(self, dataloader):
219 |         self.model.eval()
220 |         losses = []
221 |         accuracies = []
222 |         with torch.no_grad():
223 |             for i, (input, label) in enumerate(dataloader):
224 |                 if self.cuda:
225 |                     label = label.cuda()
226 |                     input = input.cuda()
227 | 
228 |                 output = self.model(input)
229 |                 loss = self.loss_fn(output, label)
230 | 
231 |                 pred = output.data.max(1)[1]
232 |                 acc = pred.eq(label.data).cpu().sum().item() * 100.0 / input.size(0)  # actual batch size here as well
233 |                 losses.append(loss.item())
234 |                 accuracies.append(acc)
235 | 
236 |         return np.mean(losses), np.mean(accuracies)
237 | 
238 | print('\n\n****** Creating {} model ******\n\n'.format(args.net_type))
239 | setup = Model(args)
240 | print('\n\n****** Preparing {} dataset *******\n\n'.format(args.dataset_train))
241 | dataloader = Dataloader(args, setup.input_size)
242 | loader_train, loader_test = dataloader.create()
243 | 
244 | # initialize model:
245 | if args.resume is None:
246 |     model = setup.model
247 |     model.apply(utils.weights_init)
248 |     train = setup.train
249 |     test = setup.test
250 |     init_epoch = 0
251 |     acc_best = 0
252 |     best_epoch = 0
253 |     if not os.path.isdir(args.save):
254 |         os.makedirs(args.save)
255 | else:
256 |     print('\n\nLoading model from saved checkpoint at {}\n\n'.format(args.resume))
257 |     #self.model.load_state_dict(checkpoints.load(checkpoints.latest('resume')))
258 |     setup.model = torch.load(args.resume)
259 |     model = setup.model
260 |     train = setup.train
261 |     test = setup.test
262 |     te_loss, te_acc = test(loader_test)
263 |     init_epoch = int(args.resume.split('_')[3])  # extract N from 'results/xxx_xxx/Save/model_epoch_N_acc_nn.nn.pth'
264 |     print('\n\nRestored Model Accuracy (epoch {:d}): {:.2f}\n\n'.format(init_epoch, te_acc))
265 |     acc_best = te_acc
266 |     best_epoch = init_epoch
267 |     args.save = '/'.join(args.resume.split('/')[:-1])
268 |     init_epoch += 1
269 | 
270 | 
271 | print('\n\n****** Model Graph ******\n\n')
272 | for arg in vars(model):
273 |     print(arg, getattr(model, arg))
274 | 
275 | print('\n\nModel parameters:\n')
276 | model_total = 0
277 | for name, param in model.named_parameters():
278 |     size = param.numel() / 1000000.
279 |     print('{} {} requires_grad: {} size: {:.2f}M'.format(name, list(param.size()), param.requires_grad, param.numel()/1000000.))
280 |     model_total += size
281 | 
282 | print('\n\nNoise masks:\n')
283 | masks_total = 0
284 | for name, param in [(name, param) for name, param in model.named_parameters() if 'noise' in name]:
285 |     size = param.numel() / 1000000.
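    # illustrative count (assumed shapes): a perturb layer with unique masks on CIFAR,
    # noise of shape [1, 3, 64, 32, 32], holds 3 * 64 * 32 * 32 = 196,608 values, i.e. ~0.20M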
286 | print('{:>22} size: {:.2f}M'.format(str(list(param.size())), param.numel()/1000000.)) 287 | masks_total += size 288 | 289 | print('\n\nModel size: {:.2f}M regular parameters, {:.2f}M noise mask values\n\n'.format(model_total - masks_total, masks_total)) 290 | """ 291 | print('\n\n******************** Model parameters:\n') 292 | for param in model.parameters(): 293 | #if param.requires_grad: 294 | print('{} {}'.format(list(param.size()), param.requires_grad)) 295 | 296 | print('\n\n****** Model state_dict() ******\n\n') 297 | for name, param in model.state_dict().items(): 298 | print('{} {} {}'.format(name, list(param.size()), param.requires_grad)) 299 | """ 300 | 301 | print('\n\n****** Model Configuration ******\n\n') 302 | for arg in vars(args): 303 | print(arg, getattr(args, arg)) 304 | 305 | if args.net_type != 'resnet18' and args.net_type != 'noiseresnet18' and (args.first_filter_size == 0 or args.filter_size == 0): 306 | if args.train_masks: 307 | msg = '(also training noise masks values)' 308 | else: 309 | msg = '(noise masks are fixed)' 310 | else: 311 | msg = '' 312 | 313 | print('\n\nTraining {} model {}\n\n'.format(args.net_type, msg)) 314 | 315 | accuracies = [] 316 | 317 | for epoch in range(init_epoch, args.nepochs, 1): 318 | 319 | tr_loss, tr_acc = train(epoch, loader_train) 320 | te_loss, te_acc = test(loader_test) 321 | 322 | accuracies.append(te_acc) 323 | 324 | if te_acc > acc_best and epoch > 10: 325 | print('{} Epoch {:d}/{:d} Train: Loss {:.2f} Accuracy {:.2f} Test: Loss {:.2f} Accuracy {:.2f} (best result, saving to {})'.format( 326 | str(datetime.now())[:-7], epoch, args.nepochs, tr_loss, tr_acc, te_loss, te_acc, args.save)) 327 | model_best = True 328 | acc_best = te_acc 329 | best_epoch = epoch 330 | torch.save(model, args.save + '/model_epoch_{:d}_acc_{:.2f}.pth'.format(epoch, te_acc)) 331 | else: 332 | if epoch == 0: 333 | print('\n') 334 | print('{} Epoch {:d}/{:d} Train: Loss {:.2f} Accuracy {:.2f} Test: Loss {:.2f} Accuracy {:.2f}'.format( 335 | str(datetime.now())[:-7], epoch, args.nepochs, tr_loss, tr_acc, te_loss, te_acc)) 336 | 337 | print('\n\nBest Accuracy: {:.2f} (epoch {:d})\n\n'.format(acc_best, best_epoch)) 338 | 339 | print('\n\nTest Accuracies:\n\n') 340 | 341 | for v in accuracies: 342 | print('{:.2f}'.format(v)+', ', end='') 343 | print('\n\n') 344 | 345 | plot = False 346 | if plot: 347 | import matplotlib.pyplot as plt 348 | plt.plot(range(args.nepochs), accuracies, 'black', label='model_1') 349 | plt.plot(range(args.nepochs), accuracies, 'red', label='model_2') 350 | plt.plot(range(args.nepochs), accuracies, 'blue', label='model_3') 351 | plt.title('Test Accuracy (CIFAR-10)', fontsize=18) 352 | plt.xlabel('Epochs', fontsize=16) 353 | plt.ylabel('%', fontsize=16) 354 | plt.xticks(fontsize=14) 355 | plt.yticks(fontsize=14) 356 | plt.legend(loc='center right', prop={'size': 14}) 357 | plt.show() -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from utils import act_fn, print_values 5 | 6 | 7 | if torch.cuda.is_available(): 8 | device = torch.device("cuda") 9 | else: 10 | device = torch.device("cpu") 11 | 12 | 13 | """ ****************** Modified (Michael Klachko) PNN Implementation ******************* """ 14 | 15 | class PerturbLayer(nn.Module): 16 | def __init__(self, in_channels=None, out_channels=None, nmasks=None, level=None, 
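                 # these keyword arguments mirror the command-line flags described in the README
                 # (--nmasks, --level, --filter_size, --first_filter_size, --unique_masks, --train_masks, --mix_maps)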
filter_size=None,
17 |                  debug=False, use_act=False, stride=1, act=None, unique_masks=False, mix_maps=None,
18 |                  train_masks=False, noise_type='uniform', input_size=None):
19 |         super(PerturbLayer, self).__init__()
20 |         self.nmasks = nmasks                #per input channel
21 |         self.unique_masks = unique_masks    #same set or different sets of nmasks per input channel
22 |         self.train_masks = train_masks      #whether to treat noise masks as regular trainable parameters of the model
23 |         self.level = level                  #noise magnitude
24 |         self.filter_size = filter_size      #if filter_size=0, layers=(perturb, conv_1x1) else layers=(conv_NxN), N=filter_size
25 |         self.use_act = use_act              #whether to use activation immediately after perturbing input (set it to False for the first layer)
26 |         self.act = act_fn(act)              #relu, prelu, rrelu, elu, selu, tanh, sigmoid (see utils)
27 |         self.debug = debug                  #print input, mask, output values for each batch
28 |         self.noise_type = noise_type        #normal or uniform
29 |         self.in_channels = in_channels
30 |         self.input_size = input_size        #input image resolution (28 for MNIST, 32 for CIFAR), needed to construct masks
31 |         self.mix_maps = mix_maps            #whether to apply second 1x1 convolution after perturbation, to mix output feature maps
32 | 
33 |         if filter_size == 1:
34 |             padding = 0
35 |             bias = True
36 |         elif filter_size == 3 or filter_size == 5:
37 |             padding = 1
38 |             bias = False
39 |         elif filter_size == 7:
40 |             stride = 2
41 |             padding = 3
42 |             bias = False
43 | 
44 |         if self.filter_size > 0:
45 |             self.noise = None
46 |             self.layers = nn.Sequential(
47 |                 nn.Conv2d(in_channels, out_channels, kernel_size=filter_size, padding=padding, stride=stride, bias=bias),
48 |                 nn.BatchNorm2d(out_channels),
49 |                 self.act
50 |             )
51 |         else:
52 |             noise_channels = in_channels if self.unique_masks else 1
53 |             shape = (1, noise_channels, self.nmasks, input_size, input_size)  # can't dynamically reshape masks in forward if we want to train them
54 |             self.noise = nn.Parameter(torch.Tensor(*shape), requires_grad=self.train_masks)
55 |             if noise_type == "uniform":
56 |                 self.noise.data.uniform_(-1, 1)
57 |             elif self.noise_type == 'normal':
58 |                 self.noise.data.normal_()
59 |             else:
60 |                 print('\n\nNoise type {} is not supported / understood\n\n'.format(self.noise_type))
61 | 
62 |             if nmasks != 1:
63 |                 if out_channels % in_channels != 0:
64 |                     print('\n\n\nout_channels (nfilters) must be divisible by in_channels when using multiple noise masks per input channel\n\n\n')
65 |                 groups = in_channels
66 |             else:
67 |                 groups = 1
68 | 
69 |             self.layers = nn.Sequential(
70 |                 #self.act,                      #TODO orig code uses ReLU here
71 |                 #nn.BatchNorm2d(out_channels),  #TODO: orig code uses BN here
72 |                 nn.Conv2d(in_channels*self.nmasks, out_channels, kernel_size=1, stride=1, groups=groups),
73 |                 nn.BatchNorm2d(out_channels),
74 |                 self.act,
75 |             )
76 |             if self.mix_maps:
77 |                 self.mix_layers = nn.Sequential(
78 |                     nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, groups=1),
79 |                     nn.BatchNorm2d(out_channels),
80 |                     self.act,
81 |                 )
82 | 
83 |     def forward(self, x):
84 |         if self.filter_size > 0:
85 |             return self.layers(x)  #image, conv, batchnorm, relu
86 |         else:
87 |             y = torch.add(x.unsqueeze(2), self.noise * self.level)  # (10, 3, 1, 32, 32) + (1, 3, 128, 32, 32) --> (10, 3, 128, 32, 32)
88 | 
89 |             if self.debug:
90 |                 print_values(x, self.noise, y, self.unique_masks)
91 | 
92 |             if self.use_act:
93 |                 y = self.act(y)
94 | 
95 |             y = y.view(-1, self.in_channels * self.nmasks, self.input_size, self.input_size)
96 |             y = self.layers(y)
97 | 
98 |             if self.mix_maps:
99 |                 y = self.mix_layers(y)
100 | 
101 | 
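            # shape walkthrough with illustrative values (batch=10, RGB 32x32 input, nmasks=128):
            #   x (10, 3, 32, 32) -> unsqueeze (10, 3, 1, 32, 32) -> add noise -> (10, 3, 128, 32, 32)
            #   -> view (10, 3*128, 32, 32) -> 1x1 conv + BN + act -> (10, out_channels, 32, 32)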
return y #image, perturb, (relu?), conv1x1, batchnorm, relu + mix_maps (conv1x1, batchnorm relu) 102 | 103 | 104 | class PerturbBasicBlock(nn.Module): 105 | expansion = 1 106 | def __init__(self, in_channels=None, out_channels=None, stride=1, shortcut=None, nmasks=None, train_masks=False, 107 | level=None, use_act=False, filter_size=None, act=None, unique_masks=False, noise_type=None, 108 | input_size=None, pool_type=None, mix_maps=None): 109 | super(PerturbBasicBlock, self).__init__() 110 | self.shortcut = shortcut 111 | if pool_type == 'max': 112 | pool = nn.MaxPool2d 113 | elif pool_type == 'avg': 114 | pool = nn.AvgPool2d 115 | else: 116 | print('\n\nPool Type {} is not supported/understood\n\n'.format(pool_type)) 117 | return 118 | self.layers = nn.Sequential( 119 | PerturbLayer(in_channels=in_channels, out_channels=out_channels, nmasks=nmasks, input_size=input_size, 120 | level=level, filter_size=filter_size, use_act=use_act, train_masks=train_masks, 121 | act=act, unique_masks=unique_masks, noise_type=noise_type, mix_maps=mix_maps), 122 | pool(stride, stride), 123 | PerturbLayer(in_channels=out_channels, out_channels=out_channels, nmasks=nmasks, input_size=input_size//stride, 124 | level=level, filter_size=filter_size, use_act=use_act, train_masks=train_masks, 125 | act=act, unique_masks=unique_masks, noise_type=noise_type, mix_maps=mix_maps), 126 | ) 127 | 128 | def forward(self, x): 129 | residual = x 130 | y = self.layers(x) 131 | if self.shortcut: 132 | residual = self.shortcut(x) 133 | y += residual 134 | y = F.relu(y) 135 | return y 136 | 137 | 138 | class PerturbResNet(nn.Module): 139 | def __init__(self, block, nblocks=None, avgpool=None, nfilters=None, nclasses=None, nmasks=None, input_size=32, 140 | level=None, filter_size=None, first_filter_size=None, use_act=False, train_masks=False, mix_maps=None, 141 | act=None, scale_noise=1, unique_masks=False, debug=False, noise_type=None, pool_type=None): 142 | super(PerturbResNet, self).__init__() 143 | self.nfilters = nfilters 144 | self.unique_masks = unique_masks 145 | self.noise_type = noise_type 146 | self.train_masks = train_masks 147 | self.pool_type = pool_type 148 | self.mix_maps = mix_maps 149 | 150 | layers = [PerturbLayer(in_channels=3, out_channels=nfilters, nmasks=nmasks, level=level*scale_noise, 151 | debug=debug, filter_size=first_filter_size, use_act=use_act, train_masks=train_masks, input_size=input_size, 152 | act=act, unique_masks=self.unique_masks, noise_type=self.noise_type, mix_maps=mix_maps)] 153 | 154 | if first_filter_size == 7: 155 | layers.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1)) 156 | 157 | self.pre_layers = nn.Sequential(*layers) 158 | self.layer1 = self._make_layer(block, 1*nfilters, nblocks[0], stride=1, level=level, nmasks=nmasks, use_act=True, 159 | filter_size=filter_size, act=act, input_size=input_size) 160 | self.layer2 = self._make_layer(block, 2*nfilters, nblocks[1], stride=2, level=level, nmasks=nmasks, use_act=True, 161 | filter_size=filter_size, act=act, input_size=input_size) 162 | self.layer3 = self._make_layer(block, 4*nfilters, nblocks[2], stride=2, level=level, nmasks=nmasks, use_act=True, 163 | filter_size=filter_size, act=act, input_size=input_size//2) 164 | self.layer4 = self._make_layer(block, 8*nfilters, nblocks[3], stride=2, level=level, nmasks=nmasks, use_act=True, 165 | filter_size=filter_size, act=act, input_size=input_size//4) 166 | self.avgpool = nn.AvgPool2d(avgpool, stride=1) 167 | self.linear = nn.Linear(8*nfilters*block.expansion, nclasses) 168 | 169 | 
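    # _make_layer stacks `nblocks` residual blocks per stage: the first block receives the stride
    # (the downsampling itself happens via the pooling layer inside PerturbBasicBlock) and, whenever
    # the spatial size or channel count changes, a 1x1 conv + BN shortcut, as in standard ResNets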
def _make_layer(self, block, out_channels, nblocks, stride=1, level=0.2, nmasks=None, use_act=False, 170 | filter_size=None, act=None, input_size=None): 171 | shortcut = None 172 | if stride != 1 or self.nfilters != out_channels * block.expansion: 173 | shortcut = nn.Sequential( 174 | nn.Conv2d(self.nfilters, out_channels * block.expansion, kernel_size=1, stride=stride, bias=False), 175 | nn.BatchNorm2d(out_channels * block.expansion), 176 | ) 177 | layers = [] 178 | layers.append(block(self.nfilters, out_channels, stride, shortcut, level=level, nmasks=nmasks, use_act=use_act, 179 | filter_size=filter_size, act=act, unique_masks=self.unique_masks, noise_type=self.noise_type, 180 | train_masks=self.train_masks, input_size=input_size, pool_type=self.pool_type, mix_maps=self.mix_maps)) 181 | self.nfilters = out_channels * block.expansion 182 | for i in range(1, nblocks): 183 | layers.append(block(self.nfilters, out_channels, level=level, nmasks=nmasks, use_act=use_act, 184 | train_masks=self.train_masks, filter_size=filter_size, act=act, unique_masks=self.unique_masks, 185 | noise_type=self.noise_type, input_size=input_size//stride, pool_type=self.pool_type, mix_maps=self.mix_maps)) 186 | return nn.Sequential(*layers) 187 | 188 | def forward(self, x): 189 | x = self.pre_layers(x) 190 | x = self.layer1(x) 191 | x = self.layer2(x) 192 | x = self.layer3(x) 193 | x = self.layer4(x) 194 | x = self.avgpool(x) 195 | x = x.view(x.size(0), -1) 196 | x = self.linear(x) 197 | return x 198 | 199 | 200 | class LeNet(nn.Module): 201 | def __init__(self, nfilters=None, nclasses=None, nmasks=None, level=None, filter_size=None, linear=128, input_size=28, 202 | debug=False, scale_noise=1, act='relu', use_act=False, first_filter_size=None, pool_type=None, 203 | dropout=None, unique_masks=False, train_masks=False, noise_type='uniform', mix_maps=None): 204 | super(LeNet, self).__init__() 205 | if filter_size == 5: 206 | n = 5 207 | else: 208 | n = 4 209 | 210 | if input_size == 32: 211 | first_channels = 3 212 | elif input_size == 28: 213 | first_channels = 1 214 | 215 | if pool_type == 'max': 216 | pool = nn.MaxPool2d 217 | elif pool_type == 'avg': 218 | pool = nn.AvgPool2d 219 | else: 220 | print('\n\nPool Type {} is not supported/understood\n\n'.format(pool_type)) 221 | return 222 | 223 | self.linear1 = nn.Linear(nfilters*n*n, linear) 224 | self.linear2 = nn.Linear(linear, nclasses) 225 | self.dropout = nn.Dropout(p=dropout) 226 | self.act = act_fn(act) 227 | self.batch_norm = nn.BatchNorm1d(linear) 228 | 229 | self.first_layers = nn.Sequential( 230 | PerturbLayer(in_channels=first_channels, out_channels=nfilters, nmasks=nmasks, level=level*scale_noise, 231 | filter_size=first_filter_size, use_act=use_act, act=act, unique_masks=unique_masks, 232 | train_masks=train_masks, noise_type=noise_type, input_size=input_size, mix_maps=mix_maps), 233 | pool(kernel_size=3, stride=2, padding=1), 234 | 235 | PerturbLayer(in_channels=nfilters, out_channels=nfilters, nmasks=nmasks, level=level, filter_size=filter_size, 236 | use_act=True, act=act, unique_masks=unique_masks, debug=debug, train_masks=train_masks, 237 | noise_type=noise_type, input_size=input_size//2, mix_maps=mix_maps), 238 | pool(kernel_size=3, stride=2, padding=1), 239 | 240 | PerturbLayer(in_channels=nfilters, out_channels=nfilters, nmasks=nmasks, level=level, filter_size=filter_size, 241 | use_act=True, act=act, unique_masks=unique_masks, train_masks=train_masks, noise_type=noise_type, 242 | input_size=input_size//4, mix_maps=mix_maps), 243 | 
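            # spatial bookkeeping (assuming size-preserving perturb/1x1/3x3 layers on 28x28 MNIST):
            # the three stride-2 pools give 28 -> 14 -> 7 -> 4, matching linear1's nfilters*4*4 input (n=4)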
pool(kernel_size=3, stride=2, padding=1), 244 | ) 245 | 246 | self.last_layers = nn.Sequential( 247 | self.dropout, 248 | self.linear1, 249 | self.batch_norm, 250 | self.act, 251 | self.dropout, 252 | self.linear2, 253 | ) 254 | 255 | def forward(self, x): 256 | x = self.first_layers(x) 257 | x = x.view(x.size(0), -1) 258 | x = self.last_layers(x) 259 | return x 260 | 261 | 262 | 263 | class CifarNet(nn.Module): 264 | def __init__(self, nfilters=None, nclasses=None, nmasks=None, level=None, filter_size=None, input_size=32, 265 | linear=256, scale_noise=1, act='relu', use_act=False, first_filter_size=None, pool_type=None, 266 | dropout=None, unique_masks=False, debug=False, train_masks=False, noise_type='uniform', mix_maps=None): 267 | super(CifarNet, self).__init__() 268 | if filter_size == 5: 269 | n = 5 270 | else: 271 | n = 4 272 | 273 | if input_size == 32: 274 | first_channels = 3 275 | elif input_size == 28: 276 | first_channels = 1 277 | 278 | if pool_type == 'max': 279 | pool = nn.MaxPool2d 280 | elif pool_type == 'avg': 281 | pool = nn.AvgPool2d 282 | else: 283 | print('\n\nPool Type {} is not supported/understood\n\n'.format(pool_type)) 284 | return 285 | 286 | self.linear1 = nn.Linear(nfilters*n*n, linear) 287 | self.linear2 = nn.Linear(linear, nclasses) 288 | self.dropout = nn.Dropout(p=dropout) 289 | self.act = act_fn(act) 290 | self.batch_norm = nn.BatchNorm1d(linear) 291 | 292 | self.first_layers = nn.Sequential( 293 | PerturbLayer(in_channels=first_channels, out_channels=nfilters, nmasks=nmasks, level=level*scale_noise, 294 | unique_masks=unique_masks, filter_size=first_filter_size, use_act=use_act, input_size=input_size, 295 | act=act, train_masks=train_masks, noise_type=noise_type, mix_maps=mix_maps), 296 | PerturbLayer(in_channels=nfilters, out_channels=nfilters, nmasks=nmasks, level=level, filter_size=filter_size, 297 | debug=debug, use_act=True, act=act, mix_maps=mix_maps, 298 | unique_masks=unique_masks, train_masks=train_masks, noise_type=noise_type, input_size=input_size), 299 | pool(kernel_size=3, stride=2, padding=1), 300 | 301 | PerturbLayer(in_channels=nfilters, out_channels=nfilters, nmasks=nmasks, level=level, filter_size=filter_size, 302 | use_act=True, act=act, unique_masks=unique_masks, mix_maps=mix_maps, 303 | train_masks=train_masks, noise_type=noise_type, input_size=input_size//2), 304 | PerturbLayer(in_channels=nfilters, out_channels=nfilters, nmasks=nmasks, level=level, filter_size=filter_size, 305 | use_act=True, act=act, unique_masks=unique_masks, mix_maps=mix_maps, 306 | train_masks=train_masks, noise_type=noise_type, input_size=input_size//2), 307 | pool(kernel_size=3, stride=2, padding=1), 308 | 309 | PerturbLayer(in_channels=nfilters, out_channels=nfilters, nmasks=nmasks, level=level, filter_size=filter_size, 310 | use_act=True, act=act, unique_masks=unique_masks, mix_maps=mix_maps, 311 | train_masks=train_masks, noise_type=noise_type, input_size=input_size//4), 312 | PerturbLayer(in_channels=nfilters, out_channels=nfilters, nmasks=nmasks, level=level, filter_size=filter_size, 313 | use_act=True, act=act, unique_masks=unique_masks, mix_maps=mix_maps, 314 | train_masks=train_masks, noise_type=noise_type, input_size=input_size//4), 315 | pool(kernel_size=3, stride=2, padding=1), 316 | ) 317 | 318 | self.last_layers = nn.Sequential( 319 | self.dropout, 320 | self.linear1, 321 | self.batch_norm, 322 | self.act, 323 | self.dropout, 324 | self.linear2, 325 | ) 326 | 327 | def forward(self, x): 328 | x = self.first_layers(x) 329 | x = x.view(x.size(0), 
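                   # flatten feature maps to (batch, nfilters*n*n) for the linear classifier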
-1) 330 | x = self.last_layers(x) 331 | return x 332 | 333 | 334 | 335 | """************* Original PNN Implementation ****************""" 336 | 337 | class NoiseLayer(nn.Module): 338 | def __init__(self, in_planes, out_planes, level): 339 | super(NoiseLayer, self).__init__() 340 | self.noise = nn.Parameter(torch.Tensor(0), requires_grad=False).to(device) 341 | self.level = level 342 | self.layers = nn.Sequential( 343 | nn.ReLU(True), 344 | nn.BatchNorm2d(in_planes), #TODO paper does not use it! 345 | nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1), 346 | ) 347 | 348 | def forward(self, x): 349 | if self.noise.numel() == 0: 350 | self.noise.resize_(x.data[0].shape).uniform_() #fill with uniform noise 351 | self.noise = (2 * self.noise - 1) * self.level 352 | y = torch.add(x, self.noise) 353 | return self.layers(y) #input, perturb, relu, batchnorm, conv1x1 354 | 355 | 356 | class NoiseBasicBlock(nn.Module): 357 | expansion = 1 358 | 359 | def __init__(self, in_planes, planes, stride=1, shortcut=None, level=0.2): 360 | super(NoiseBasicBlock, self).__init__() 361 | self.layers = nn.Sequential( 362 | NoiseLayer(in_planes, planes, level), #perturb, relu, conv1x1 363 | nn.MaxPool2d(stride, stride), 364 | nn.BatchNorm2d(planes), 365 | nn.ReLU(True), #TODO paper does not use it! 366 | NoiseLayer(planes, planes, level), #perturb, relu, conv1x1 367 | nn.BatchNorm2d(planes), 368 | ) 369 | self.shortcut = shortcut 370 | 371 | def forward(self, x): 372 | residual = x 373 | y = self.layers(x) 374 | if self.shortcut: 375 | residual = self.shortcut(x) 376 | y += residual 377 | y = F.relu(y) 378 | return y 379 | 380 | 381 | class NoiseResNet(nn.Module): 382 | def __init__(self, block, nblocks, nfilters, nclasses, pool, level, first_filter_size=3): 383 | super(NoiseResNet, self).__init__() 384 | self.in_planes = nfilters 385 | if first_filter_size == 7: 386 | pool = 1 387 | self.pre_layers = nn.Sequential( 388 | nn.Conv2d(3, nfilters, kernel_size=first_filter_size, stride=2, padding=3, bias=False), 389 | nn.BatchNorm2d(nfilters), 390 | nn.ReLU(True), 391 | nn.MaxPool2d(kernel_size=3,stride=2,padding=1) 392 | ) 393 | elif first_filter_size == 3: 394 | pool = 4 395 | self.pre_layers = nn.Sequential( 396 | nn.Conv2d(3, nfilters, kernel_size=first_filter_size, stride=1, padding=1, bias=False), 397 | nn.BatchNorm2d(nfilters), 398 | nn.ReLU(True), 399 | ) 400 | elif first_filter_size == 0: 401 | print('\n\nThe original noiseresnet18 model does not support noise masks in the first layer, ' 402 | 'use perturb_resnet18 model, or set first_filter_size to 3 or 7\n\n') 403 | return 404 | 405 | self.layer1 = self._make_layer(block, 1*nfilters, nblocks[0], stride=1, level=level) 406 | self.layer2 = self._make_layer(block, 2*nfilters, nblocks[1], stride=2, level=level) 407 | self.layer3 = self._make_layer(block, 4*nfilters, nblocks[2], stride=2, level=level) 408 | self.layer4 = self._make_layer(block, 8*nfilters, nblocks[3], stride=2, level=level) 409 | self.avgpool = nn.AvgPool2d(pool, stride=1) 410 | self.linear = nn.Linear(8*nfilters*block.expansion, nclasses) 411 | 412 | def _make_layer(self, block, planes, nblocks, stride=1, level=0.2, filter_size=1): 413 | shortcut = None 414 | if stride != 1 or self.in_planes != planes * block.expansion: 415 | shortcut = nn.Sequential( 416 | nn.Conv2d(self.in_planes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), 417 | nn.BatchNorm2d(planes * block.expansion), 418 | ) 419 | layers = [] 420 | layers.append(block(self.in_planes, planes, stride, 
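                            # first block of the stage takes the stride and the optional 1x1 conv shortcut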
shortcut, level=level)) 421 | self.in_planes = planes * block.expansion 422 | for i in range(1, nblocks): 423 | layers.append(block(self.in_planes, planes, level=level)) 424 | return nn.Sequential(*layers) 425 | 426 | def forward(self, x): 427 | x1 = self.pre_layers(x) 428 | x2 = self.layer1(x1) 429 | x3 = self.layer2(x2) 430 | x4 = self.layer3(x3) 431 | x5 = self.layer4(x4) 432 | x6 = self.avgpool(x5) 433 | x7 = x6.view(x6.size(0), -1) 434 | x8 = self.linear(x7) 435 | return x8 436 | 437 | 438 | 439 | """ *************** Reference ResNet Implementation (https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py) ****************** """ 440 | 441 | class BasicBlock(nn.Module): 442 | expansion = 1 443 | 444 | def __init__(self, in_planes, planes, stride=1): 445 | super(BasicBlock, self).__init__() 446 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 447 | self.bn1 = nn.BatchNorm2d(planes) 448 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 449 | self.bn2 = nn.BatchNorm2d(planes) 450 | 451 | self.shortcut = nn.Sequential() 452 | if stride != 1 or in_planes != self.expansion*planes: 453 | self.shortcut = nn.Sequential( 454 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 455 | nn.BatchNorm2d(self.expansion*planes) 456 | ) 457 | 458 | def forward(self, x): 459 | out = F.relu(self.bn1(self.conv1(x))) 460 | out = self.bn2(self.conv2(out)) 461 | out += self.shortcut(x) 462 | out = F.relu(out) 463 | return out 464 | 465 | 466 | class ResNet(nn.Module): 467 | def __init__(self, block, num_blocks, nfilters=64, avgpool=4, nclasses=10): 468 | super(ResNet, self).__init__() 469 | self.in_planes = nfilters 470 | self.avgpool = avgpool 471 | 472 | self.conv1 = nn.Conv2d(3, nfilters, kernel_size=3, stride=1, padding=1, bias=False) 473 | self.bn1 = nn.BatchNorm2d(nfilters) 474 | self.layer1 = self._make_layer(block, nfilters, num_blocks[0], stride=1) 475 | self.layer2 = self._make_layer(block, nfilters*2, num_blocks[1], stride=2) 476 | self.layer3 = self._make_layer(block, nfilters*4, num_blocks[2], stride=2) 477 | self.layer4 = self._make_layer(block, nfilters*8, num_blocks[3], stride=2) 478 | self.linear = nn.Linear(nfilters*8*block.expansion, nclasses) 479 | 480 | def _make_layer(self, block, planes, num_blocks, stride): 481 | strides = [stride] + [1]*(num_blocks-1) 482 | layers = [] 483 | for stride in strides: 484 | layers.append(block(self.in_planes, planes, stride)) 485 | self.in_planes = planes * block.expansion 486 | return nn.Sequential(*layers) 487 | 488 | def forward(self, x): 489 | out = F.relu(self.bn1(self.conv1(x))) 490 | out = self.layer1(out) 491 | out = self.layer2(out) 492 | out = self.layer3(out) 493 | out = self.layer4(out) 494 | out = F.avg_pool2d(out, self.avgpool) 495 | out = out.view(out.size(0), -1) 496 | out = self.linear(out) 497 | return out 498 | 499 | 500 | 501 | def resnet18(nfilters, avgpool=4, nclasses=10, nmasks=32, level=0.1, filter_size=0, first_filter_size=0, pool_type=None, 502 | input_size=None, scale_noise=1, act='relu', use_act=True, dropout=0.5, unique_masks=False, 503 | noise_type='uniform', train_masks=False, debug=False, mix_maps=None): 504 | return ResNet(BasicBlock, [2, 2, 2, 2], nfilters=nfilters, avgpool=avgpool, nclasses=nclasses) 505 | 506 | 507 | def noiseresnet18(nfilters, avgpool=4, nclasses=10, nmasks=32, level=0.1, filter_size=0, first_filter_size=7, 508 | pool_type=None, input_size=None, scale_noise=1, 
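                  # the extra keyword arguments keep the factory signatures interchangeable;
                  # NoiseResNet itself only uses nfilters, avgpool (pool), nclasses, level, and first_filter_size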
act='relu', use_act=True, dropout=0.5, unique_masks=False, 509 | debug=False, noise_type='uniform', train_masks=False, mix_maps=None): 510 | return NoiseResNet(NoiseBasicBlock, [2, 2, 2, 2], nfilters=nfilters, pool=avgpool, nclasses=nclasses, 511 | level=level, first_filter_size=first_filter_size) 512 | 513 | 514 | def perturb_resnet18(nfilters, avgpool=4, nclasses=10, nmasks=32, level=0.1, filter_size=0, first_filter_size=0, 515 | pool_type=None, input_size=None, scale_noise=1, act='relu', use_act=True, dropout=0.5, 516 | unique_masks=False, debug=False, noise_type='uniform', train_masks=False, mix_maps=None): 517 | return PerturbResNet(PerturbBasicBlock, [2, 2, 2, 2], nfilters=nfilters, avgpool=avgpool, nclasses=nclasses, pool_type=pool_type, 518 | scale_noise=scale_noise, nmasks=nmasks, level=level, filter_size=filter_size, train_masks=train_masks, 519 | first_filter_size=first_filter_size, act=act, use_act=use_act, unique_masks=unique_masks, 520 | debug=debug, noise_type=noise_type, input_size=input_size, mix_maps=mix_maps) 521 | 522 | 523 | def lenet(nfilters, avgpool=None, nclasses=10, nmasks=32, level=0.1, filter_size=3, first_filter_size=0, 524 | pool_type=None, input_size=None, scale_noise=1, act='relu', use_act=True, dropout=0.5, 525 | unique_masks=False, debug=False, noise_type='uniform', train_masks=False, mix_maps=None): 526 | return LeNet(nfilters=nfilters, nclasses=nclasses, nmasks=nmasks, level=level, filter_size=filter_size, pool_type=pool_type, 527 | scale_noise=scale_noise, act=act, first_filter_size=first_filter_size, input_size=input_size, mix_maps=mix_maps, 528 | use_act=use_act, dropout=dropout, unique_masks=unique_masks, debug=debug, noise_type=noise_type, train_masks=train_masks) 529 | 530 | 531 | def cifarnet(nfilters, avgpool=None, nclasses=10, nmasks=32, level=0.1, filter_size=3, first_filter_size=0, 532 | pool_type=None, input_size=None, scale_noise=1, act='relu', use_act=True, dropout=0.5, 533 | unique_masks=False, debug=False, noise_type='uniform', train_masks=False, mix_maps=None): 534 | return CifarNet(nfilters=nfilters, nclasses=nclasses, nmasks=nmasks, level=level, filter_size=filter_size, pool_type=pool_type, 535 | scale_noise=scale_noise, act=act, use_act=use_act, first_filter_size=first_filter_size, input_size=input_size, 536 | dropout=dropout, unique_masks=unique_masks, debug=debug, noise_type=noise_type, train_masks=train_masks, mix_maps=mix_maps) 537 | 538 | 539 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import utils 3 | import numpy as np 4 | from torch import nn 5 | import torch.optim as optim 6 | import datasets 7 | import torchvision.transforms as transforms 8 | 9 | 10 | class Net(nn.Module): 11 | def __init__(self, ): 12 | super(Net, self).__init__() 13 | self.linear = nn.Linear(9 * 6 * 6, 10) 14 | self.noise = nn.Parameter(torch.Tensor(1, 1, 28, 28), requires_grad=True) 15 | self.noise.data.uniform_(-1, 1) 16 | 17 | self.layers = nn.Sequential( 18 | nn.Conv2d(1, 9, kernel_size=5, stride=2, bias=False), 19 | nn.MaxPool2d(2, 2), 20 | nn.ReLU(), 21 | ) 22 | 23 | def forward(self, x): 24 | x = torch.add(x, self.noise) 25 | x = self.layers(x) 26 | x = x.view(x.size(0), -1) 27 | x = self.linear(x) 28 | print('{:.5f}'.format(self.noise.data[0, 0, 0, 0].cpu().numpy())) 29 | return x 30 | 31 | 32 | model = Net() 33 | model.apply(utils.weights_init) 34 | model = model.cuda() 35 | 36 | dataset_train = 
getattr(datasets, 'MNIST')(root='./data', train=True, transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])) 37 | dataset_test = getattr(datasets, 'MNIST')(root='./data', train=False, transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])) 38 | 39 | loader_train = torch.utils.data.DataLoader(dataset_train, batch_size=16, shuffle=True, num_workers=4, pin_memory=True) 40 | loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=16, shuffle=False, num_workers=4, pin_memory=True) 41 | 42 | optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=0.0001, nesterov=True) 43 | 44 | print('\n\n****** Model Graph ******\n\n') 45 | for arg in vars(model): 46 | print(arg, getattr(model, arg)) 47 | 48 | print('\n\nModel named_parameters():\n') 49 | for name, param in model.named_parameters(): 50 | #if param.requires_grad: 51 | print('{} {} requires_grad: {} {:.2f}k'.format(name, list(param.size()), param.requires_grad, param.numel()/1000.)) 52 | 53 | print('\n\nModel parameters():\n') 54 | for param in model.parameters(): 55 | #if param.requires_grad: 56 | print('{} requires_grad: {}'.format(list(param.size()), param.requires_grad)) 57 | 58 | print('\n\n****** Model state_dict() ******\n\n') 59 | for name, param in model.state_dict().items(): 60 | print('{} {} requires_grad: {}'.format(name, list(param.size()), param.requires_grad)) 61 | 62 | print('\n\n') 63 | for epoch in range(1): 64 | model.train() 65 | tr_accuracies = [] 66 | for i, (input, label) in enumerate(loader_train): 67 | label = label.cuda() 68 | input = input.cuda() 69 | 70 | output = model(input) 71 | loss = nn.CrossEntropyLoss()(output, label) 72 | #print('\nBatch:', i) 73 | optimizer.zero_grad() 74 | loss.backward() 75 | optimizer.step() 76 | 77 | pred = output.data.max(1)[1] 78 | acc = pred.eq(label.data).cpu().sum() * 100.0 / 16 79 | tr_accuracies.append(acc) 80 | 81 | model.eval() 82 | te_accuracies = [] 83 | with torch.no_grad(): 84 | for i, (input, label) in enumerate(loader_test): 85 | label = label.cuda() 86 | input = input.cuda() 87 | 88 | output = model(input) 89 | pred = output.data.max(1)[1] 90 | acc = pred.eq(label.data).cpu().sum() * 100.0 / 16 91 | te_accuracies.append(acc) 92 | 93 | print('Epoch {:d} Train Accuracy {:.2f} Test Accuracy {:.2f}'.format(epoch, np.mean(tr_accuracies), np.mean(te_accuracies))) -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import numpy as np 4 | import math 5 | import torch 6 | import torch.nn as nn 7 | 8 | def readtextfile(filename): 9 | with open(filename) as f: 10 | content = f.readlines() 11 | f.close() 12 | return content 13 | 14 | def writetextfile(data, filename): 15 | with open(filename, 'w') as f: 16 | f.writelines(data) 17 | f.close() 18 | 19 | def delete_file(filename): 20 | if os.path.isfile(filename) == True: 21 | os.remove(filename) 22 | 23 | def eformat(f, prec, exp_digits): 24 | s = "%.*e"%(prec, f) 25 | mantissa, exp = s.split('e') 26 | # add 1 to digits as 1 is taken by sign +/- 27 | return "%se%+0*d"%(mantissa, exp_digits+1, int(exp)) 28 | 29 | def saveargs(args): 30 | path = args.logs 31 | if os.path.isdir(path) == False: 32 | os.makedirs(path) 33 | with open(os.path.join(path,'args.txt'), 'w') as f: 34 | for arg in vars(args): 35 | f.write(arg+' '+str(getattr(args,arg))+'\n') 36 | 37 | def 
init_params(net):
38 |     for m in net.modules():
39 |         if isinstance(m, nn.Conv2d):
40 |             nn.init.kaiming_normal_(m.weight, mode='fan_out')
41 |             if m.bias is not None:
42 |                 nn.init.constant_(m.bias, 0)
43 |         elif isinstance(m, nn.BatchNorm2d):
44 |             nn.init.constant_(m.weight, 1)
45 |             nn.init.constant_(m.bias, 0)
46 |         elif isinstance(m, nn.Linear):
47 |             nn.init.normal_(m.weight, std=1e-3)
48 |             if m.bias is not None:
49 |                 nn.init.constant_(m.bias, 0)
50 | 
51 | def weights_init(m):
52 |     if isinstance(m, nn.Conv2d):
53 |         n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
54 |         m.weight.data.normal_(0, math.sqrt(2. / n))
55 |     elif isinstance(m, nn.BatchNorm2d):
56 |         m.weight.data.fill_(1)
57 |         m.bias.data.zero_()
58 | 
59 | 
60 | class Counter:  #not used currently
61 |     def __init__(self):
62 |         self.mask_size = 0
63 | 
64 |     def update(self, size):
65 |         self.mask_size += size
66 | 
67 |     def get_total(self):
68 |         return self.mask_size
69 | 
70 | 
71 | def act_fn(act):
72 |     if act == 'relu':
73 |         act_ = nn.ReLU(inplace=False)
74 |     elif act == 'lrelu':
75 |         act_ = nn.LeakyReLU(inplace=True)
76 |     elif act == 'prelu':
77 |         act_ = nn.PReLU()
78 |     elif act == 'rrelu':
79 |         act_ = nn.RReLU(inplace=True)
80 |     elif act == 'elu':
81 |         act_ = nn.ELU(inplace=True)
82 |     elif act == 'selu':
83 |         act_ = nn.SELU(inplace=True)
84 |     elif act == 'tanh':
85 |         act_ = nn.Tanh()
86 |     elif act == 'sigmoid':
87 |         act_ = nn.Sigmoid()
88 |     else:
89 |         print('\n\nActivation function {} is not supported/understood\n\n'.format(act))
90 |         act_ = None
91 |     return act_
92 | 
93 | 
94 | def print_values(x, noise, y, unique_masks, n=2):
95 |     np.set_printoptions(precision=5, linewidth=200, threshold=1000000, suppress=True)
96 |     print('\nimage: {} image0, channel0 {}'.format(list(x.unsqueeze(2).size()), x.unsqueeze(2).data[0, 0, 0, 0, :n].cpu().numpy()))
97 |     print('image: {} image0, channel1 {}'.format(list(x.unsqueeze(2).size()), x.unsqueeze(2).data[0, 1, 0, 0, :n].cpu().numpy()))
98 |     print('\nimage: {} image1, channel0 {}'.format(list(x.unsqueeze(2).size()), x.unsqueeze(2).data[1, 0, 0, 0, :n].cpu().numpy()))
99 |     print('image: {} image1, channel1 {}'.format(list(x.unsqueeze(2).size()), x.unsqueeze(2).data[1, 1, 0, 0, :n].cpu().numpy()))
100 |     if noise is not None:
101 |         print('\nnoise {} channel0, mask0: {}'.format(list(noise.size()), noise.data[0, 0, 0, 0, :n].cpu().numpy()))
102 |         print('noise {} channel0, mask1: {}'.format(list(noise.size()), noise.data[0, 0, 1, 0, :n].cpu().numpy()))
103 |         if unique_masks:
104 |             print('\nnoise {} channel1, mask0: {}'.format(list(noise.size()), noise.data[0, 1, 0, 0, :n].cpu().numpy()))
105 |             print('noise {} channel1, mask1: {}'.format(list(noise.size()), noise.data[0, 1, 1, 0, :n].cpu().numpy()))
106 | 
107 |     print('\nmasks: {} image0, channel0, mask0: {}'.format(list(y.size()), y.data[0, 0, 0, 0, :n].cpu().numpy()))
108 |     print('masks: {} image0, channel0, mask1: {}'.format(list(y.size()), y.data[0, 0, 1, 0, :n].cpu().numpy()))
109 |     print('masks: {} image0, channel1, mask0: {}'.format(list(y.size()), y.data[0, 1, 0, 0, :n].cpu().numpy()))
110 |     print('masks: {} image0, channel1, mask1: {}'.format(list(y.size()), y.data[0, 1, 1, 0, :n].cpu().numpy()))
111 |     print('\nmasks: {} image1, channel0, mask0: {}'.format(list(y.size()), y.data[1, 0, 0, 0, :n].cpu().numpy()))
112 |     print('masks: {} image1, channel0, mask1: {}'.format(list(y.size()), y.data[1, 0, 1, 0, :n].cpu().numpy()))
113 |     print('masks: {} image1, channel1, mask0: {}'.format(list(y.size()), y.data[1, 1, 0, 0, :n].cpu().numpy()))
114 |     print('masks: {} image1, channel1, 
mask1: {}'.format(list(y.size()), y.data[1, 1, 1, 0, :n].cpu().numpy())) --------------------------------------------------------------------------------