├── .gitignore
├── README.md
├── deep_nets
    ├── Dockerfile
    ├── data.py
    ├── models.py
    ├── plot_feature_sparsity.ipynb
    ├── train.py
    ├── utils.py
    ├── utils_eval.py
    └── utils_train.py
├── diag_nets.ipynb
├── diag_nets.py
├── diag_nets_2d_loss_surface.ipynb
├── fc_nets.py
├── fc_nets_1d_regression.ipynb
├── fc_nets_multi_layer.ipynb
├── fc_nets_two_layer.ipynb
├── images
    ├── fig1.png
    ├── twitter.gif
    └── twitter.mp4
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | plots
2 | .DS_Store
3 | __pycache__
4 | .mat
5 |
6 | deep_nets/exps/
7 | deep_nets/logs/
8 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SGD with large step sizes learns sparse features
2 |
3 | **Maksym Andriushchenko, Aditya Varre, Loucas Pillaud-Vivien, Nicolas Flammarion (EPFL)**
4 |
5 | **ICML 2023**
6 |
7 | **Paper:** [https://arxiv.org/abs/2210.05337](https://arxiv.org/abs/2210.05337)
8 |
9 |
10 |

11 |
12 | ## Abstract
13 | We showcase important features of the dynamics of Stochastic Gradient Descent (SGD) in the training of neural networks. We present empirical observations that commonly used large step sizes (i) lead the iterates to jump from one side of a valley to the other causing *loss stabilization*, and (ii) this stabilization induces a hidden stochastic dynamics orthogonal to the bouncing directions that *biases it implicitly* toward simple predictors. Furthermore, we show empirically that the longer large step sizes keep SGD high in the loss landscape valleys, the better the implicit regularization can operate and find sparse representations. Notably, no explicit regularization is used so that the regularization effect comes solely from the SGD training dynamics influenced by the step size schedule. Therefore, these observations unveil how, through the step size schedules, both gradient and noise drive together the SGD dynamics through the loss landscape of neural networks. We justify these findings theoretically through the study of simple neural network models as well as qualitative arguments inspired by stochastic processes. Finally, this analysis allows us to shed new light on some common practices and observed phenomena when training neural networks.
14 |
15 |
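The "jumping from one side of a valley to the other" picture in the abstract can be illustrated with a deliberately simplified toy. The sketch below is not taken from the repository and does not reproduce the paper's mechanism (which concerns SGD noise on neural networks). It assumes plain gradient descent on the 1D quadratic `f(w) = 0.5 * curvature * w**2`, and the helper name `run_gd` is hypothetical. With step size `2 / curvature` the iterate bounces between the two walls of the valley and the loss stays constant, while a smaller step size lets the loss decrease.

```python
def run_gd(step_size, curvature=1.0, w0=1.0, n_steps=6):
    """Gradient descent on the toy quadratic f(w) = 0.5 * curvature * w**2.

    Hypothetical helper for illustration only; not part of the repository.
    Returns the list of iterates and the list of loss values.
    """
    w = w0
    iterates = [w]
    losses = [0.5 * curvature * w ** 2]
    for _ in range(n_steps):
        grad = curvature * w          # f'(w) = curvature * w
        w = w - step_size * grad      # plain gradient step
        iterates.append(w)
        losses.append(0.5 * curvature * w ** 2)
    return iterates, losses


# Step size 2 / curvature: the iterate flips sign every step, the loss does not move.
large_iterates, large_losses = run_gd(step_size=2.0)

# Smaller step size: the iterate contracts toward the minimum, the loss decreases.
small_iterates, small_losses = run_gd(step_size=0.5)

print("large step iterates:", large_iterates)
print("large step losses:  ", large_losses)
print("small step losses:  ", small_losses)
```

In this toy the plateau is exact and purely deterministic. The stabilization described in the abstract arises in training runs with stochastic gradients, so the sketch only conveys the geometric intuition.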

16 |
17 |
18 |
19 |
20 |
21 | ## Code
22 | The exact code to reproduce all the reported experiments on simple networks is available in Jupyter notebooks:
23 | - `diag_nets.ipynb`: diagonal linear networks (also see `diag_nets_2d_loss_surface.ipynb` for loss surface visualizations).
24 | - `fc_nets_1d_regression.ipynb`: two-layer ReLU networks on a 1D regression problem.
25 | - `fc_nets_two_layer.ipynb`: two-layer ReLU networks in a teacher-student setup (+ neuron movement visualization).
26 | - `fc_nets_multi_layer.ipynb`: three-layer ReLU networks in a teacher-student setup.
27 |
28 | For deep networks, see folder `deep_nets` where the dependencies are collected in `Dockerfile`. Typical training commands for a ResNet-18 on CIFAR-10 would look like this:
29 | - Plain SGD without explicit regularization (loss stabilization is achieved via exponential warmup):
30 |   - with large step sizes: `python train.py --dataset=cifar10 --lr_init=0.75 --lr_schedule=piecewise_05epochs --warmup_exp=1.05 --model=resnet18_plain --model_width=64 --epochs=100 --batch_size=256 --momentum=0.0 --l2_reg=0.0 --no_data_augm --eval_iter_freq=200 --exp_name=no_explicit_reg`
31 |   - with small step sizes: `python train.py --dataset=cifar10 --lr_init=0.01 --lr_schedule=constant --model=resnet18_plain --model_width=64 --epochs=100 --batch_size=256 --momentum=0.0 --l2_reg=0.0 --no_data_augm --eval_iter_freq=200 --exp_name=no_explicit_reg`
32 | - SGD + momentum in the state-of-the-art setting with data augmentation and weight decay:
33 |   - with large step sizes: `python train.py --dataset=cifar10 --lr_init=0.05 --lr_schedule=piecewise_05epochs --model=resnet18_plain --model_width=64 --epochs=100 --batch_size=256 --momentum=0.9 --l2_reg=0.0005 --eval_iter_freq=200 --exp_name=sota_setting`
34 |   - with small step sizes: `python train.py --dataset=cifar10 --lr_init=0.002 --lr_schedule=constant --model=resnet18_plain --model_width=64 --epochs=100 --batch_size=256 --momentum=0.9 --l2_reg=0.0005
--eval_iter_freq=200 --exp_name=sota_setting`
35 |
36 | The runs with CIFAR-100 are analogous: just pass `--dataset=cifar100`. The step size schedule can be selected from [`constant`, `piecewise_01epochs`, `piecewise_03epochs`, `piecewise_05epochs`]; see `utils_train.py` for more details.
37 |
38 |
39 | ## Contact
40 | Feel free to reach out if you have any questions regarding the code!
41 |
42 |
--------------------------------------------------------------------------------
/deep_nets/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvidia/cuda:11.0.3-cudnn8-devel-ubuntu20.04
2 | LABEL maintainer "Maksym Andriushchenko "
3 |
4 | ARG DEBIAN_FRONTEND=noninteractive # needed to prevent some questions during the Docker building phase
5 |
6 | # install some necessary tools
7 | RUN apt-get update
8 | RUN apt-get install -y \
9 |     cmake \
10 |     curl \
11 |     htop \
12 |     locales \
13 |     python3 \
14 |     python3-pip \
15 |     sudo \
16 |     unzip \
17 |     vim \
18 |     git \
19 |     wget \
20 |     zsh \
21 |     libssl-dev \
22 |     libffi-dev \
23 |     libmagickwand-dev \
24 |     ffmpeg \
25 |     libsm6 \
26 |     libxext6 \
27 |     openssh-server
28 | RUN rm -rf /var/lib/apt/lists/*
29 | RUN mkdir /var/run/sshd
30 |
31 |
32 | # configure environments
33 | RUN locale-gen en_US.UTF-8
34 | ENV LANG en_US.UTF-8
35 | ENV LANGUAGE en_US:en
36 | ENV LC_ALL en_US.UTF-8
37 |
38 |
39 | RUN pip3 install --upgrade pip # needed for opencv
40 | RUN pip3 install -U setuptools # may be needed for opencv
41 | RUN pip3 install torch==1.10.0+cu113 torchvision==0.11.1+cu113 torchaudio==0.10.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html # needed for A100 pods
42 |
43 |
44 | # python packages
45 | RUN pip3 install --upgrade \
46 |     scipy \
47 |     numpy \
48 |     jupyter notebook \
49 |     ipdb \
50 |     pyyaml \
51 |     easydict \
52 |     requests \
53 |     matplotlib \
54 |     seaborn
55 | RUN export LC_ALL=en_US.UTF-8
56 |
57 | # Configure user and group
58 | ENV SHELL=/bin/bash
59 |
60 |
ENV HOME=/home/$NB_USER
61 |
62 | RUN ln -s /usr/bin/python3 /usr/bin/python
63 |
64 | # expose the port to ssh
65 | EXPOSE 22
66 | CMD ["/usr/sbin/sshd", "-D"]
67 |
--------------------------------------------------------------------------------
/deep_nets/data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.utils.data as td
4 | import numpy as np
5 | from torchvision import datasets, transforms
6 |
7 |
8 | class DatasetWithLabelNoise(torch.utils.data.Dataset):
9 |     def __init__(self, data, split, transform):
10 |         self.data = data
11 |         self.split = split
12 |         self.transform = transform
13 |
14 |     def __getitem__(self, index):
15 |         x = self.data.data[index]
16 |         x1 = self.transform(x) if self.transform is not None else x
17 |         if self.split == 'train':
18 |             x2 = self.transform(x) if self.transform is not None else x
19 |         else:  # to save a bit of computations
20 |             x2 = x1
21 |         y = self.data.targets[index]
22 |         y_correct = self.data.targets_correct[index]
23 |         label_noise = self.data.label_noise[index]
24 |         return x1, x2, y, y_correct, label_noise
25 |
26 |     def __len__(self):
27 |         return len(self.data.targets)
28 |
29 |
30 | def uniform_noise(*args, **kwargs):
31 |     shape = [1000, 1, 28, 28]
32 |     x = torch.from_numpy(np.random.rand(*shape)).float()
33 |     # y_train = np.random.randint(0, 10, size=shape_train[0])
34 |     y = np.floor(10 * x[:, 0, 0, 0].numpy())  # take the first feature
35 |     y = torch.from_numpy(y).long()
36 |     data = td.TensorDataset(x, y)
37 |     return data
38 |
39 |
40 | def dataset_gaussians_binary(*args, **kwargs):
41 |     shape = shapes_dict['gaussians_binary']
42 |     n, d = shape[0], shape[3]
43 |     std = 0.1
44 |
45 |     v = v_global.copy()
46 |     v /= (v**2).sum()**0.5  # make it unit norm
47 |     mu_zero, mu_one = v, -v
48 |
49 |     x = np.concatenate([mu_zero + std*np.random.randn(n // 2, d), mu_one + std*np.random.randn(n // 2, d)])
50 |     y = np.concatenate([np.zeros(n // 2), np.ones(n //
2)])
51 |     indices = np.random.permutation(np.arange(n))
52 |     x, y = x[indices], y[indices]
53 |     x = x[:, None, None, :]  # make it image-like
54 |
55 |     data = td.TensorDataset()
56 |     data.data, data.targets = torch.from_numpy(x).float(), torch.from_numpy(y).long()
57 |     return data
58 |
59 |
60 | def asym_label_noise(dataset, label):
61 |     if dataset == 'cifar10':
62 |         if label == 9:  # truck -> automobile
63 |             return 1
64 |         # bird -> airplane
65 |         elif label == 2:
66 |             return 0
67 |         # cat -> dog
68 |         elif label == 3:
69 |             return 5
70 |         # dog -> cat
71 |         elif label == 5:
72 |             return 3
73 |         # deer -> horse
74 |         elif label == 4:
75 |             return 7
76 |         else:
77 |             return label
78 |     elif dataset == 'cifar100':
79 |         return (label + 1) % 100
80 |     elif dataset == 'svhn':
81 |         return (label + 1) % 10
82 |     else:
83 |         raise ValueError('This dataset does not yet support asymmetric label noise.')
84 |
85 |
86 | def get_loaders(dataset, n_ex, batch_size, split, shuffle, data_augm, val_indices=None, p_label_noise=0.0,
87 |                 noise_type='sym', drop_last=False):
88 |     dir_ = '/tmlscratch/andriush/data'
89 |     # dir_ = '/tmldata1/andriush/data'
90 |     dataset_f = datasets_dict[dataset]
91 |     batch_size = n_ex if n_ex < batch_size and n_ex != -1 else batch_size
92 |     num_workers_train, num_workers_val, num_workers_test = 4, 4, 4
93 |
94 |     data_augm_transforms = [transforms.RandomCrop(32, padding=4)]
95 |     if dataset not in ['mnist', 'svhn']:
96 |         data_augm_transforms.append(transforms.RandomHorizontalFlip())
97 |     base_transforms = [transforms.ToPILImage()] if dataset != 'gaussians_binary' else []
98 |     transform_list = base_transforms + data_augm_transforms if data_augm else base_transforms
99 |     transform = transforms.Compose(transform_list + [transforms.ToTensor()])
100 |
101 |     if dataset == 'cifar10_horse_car':
102 |         cl1, cl2 = 7, 1  # 7=horse, 1=car
103 |     elif dataset == 'cifar10_dog_cat':
104 |         cl1, cl2 = 5, 3  # 5=dog, 3=cat
105 |     if split in ['train', 'val']:
106 |         if dataset != 'svhn':
107 |             data = dataset_f(dir_,
train=True, transform=transform, download=True)
108 |         else:
109 |             data = dataset_f(dir_, split='train', transform=transform, download=True)
110 |             data.data = data.data.transpose([0, 2, 3, 1])
111 |             data.targets = data.labels
112 |         data.targets = np.array(data.targets)
113 |         n_cls = max(data.targets) + 1
114 |
115 |         if dataset in ['cifar10_horse_car', 'cifar10_dog_cat']:
116 |             data.targets = np.array(data.targets)
117 |             idx = (data.targets == cl1) + (data.targets == cl2)
118 |             data.data, data.targets = data.data[idx], data.targets[idx]
119 |             data.targets[data.targets == cl1], data.targets[data.targets == cl2] = 0, 1
120 |             n_cls = 2
121 |         n_ex = len(data.targets) if n_ex == -1 else n_ex
122 |         if '_gs' in dataset:
123 |             data.data = data.data.mean(3).astype(np.uint8)
124 |
125 |         if val_indices is not None:
126 |             assert len(val_indices) < len(data.targets), '#val has to be < total #train pts'
127 |             val_indices_mask = np.zeros(len(data.targets), dtype=bool)
128 |             val_indices_mask[val_indices] = True
129 |             if split == 'train':
130 |                 data.data, data.targets = data.data[~val_indices_mask], data.targets[~val_indices_mask]
131 |             else:
132 |                 data.data, data.targets = data.data[val_indices_mask], data.targets[val_indices_mask]
133 |         data.data, data.targets = data.data[:n_ex], data.targets[:n_ex]  # so the #pts can be in [n_ex-n_eval, n_ex]
134 |         # e.g., when frac_train=1.0, for training set, n_ex=50k while data.data.shape[0]=45k bc of val set
135 |         if n_ex > data.data.shape[0]:
136 |             n_ex = data.data.shape[0]
137 |
138 |         data.label_noise = np.zeros(n_ex, dtype=bool)
139 |         data.targets_correct = data.targets.copy()
140 |         if p_label_noise > 0.0:
141 |             print('Split: {}, number of examples: {}, noisy examples: {}'.format(split, n_ex, int(n_ex*p_label_noise)))
142 |             print('Dataset shape: x is {}, y is {}'.format(data.data.shape, data.targets.shape))
143 |             assert n_ex == data.data.shape[0]  # there was a mistake previously here leading to a larger noise level
144 |
145 |             # gen random
indices
146 |             indices = np.random.permutation(np.arange(len(data.targets)))[:int(n_ex*p_label_noise)]
147 |             for index in indices:
148 |                 if noise_type == 'sym':
149 |                     lst_classes = list(range(n_cls))
150 |                     cls_int = data.targets[index] if type(data.targets[index]) is int else data.targets[index].item()
151 |                     lst_classes.remove(cls_int)
152 |                     data.targets[index] = np.random.choice(lst_classes)
153 |                 else:
154 |                     data.targets[index] = asym_label_noise(dataset, data.targets[index])
155 |             data.label_noise[indices] = True
156 |         print(data.data.shape)
157 |         data = DatasetWithLabelNoise(data, split, transform if dataset != 'gaussians_binary' else None)
158 |         loader = torch.utils.data.DataLoader(
159 |             dataset=data, batch_size=batch_size, shuffle=shuffle, pin_memory=True,
160 |             num_workers=num_workers_train if split == 'train' else num_workers_val, drop_last=drop_last)
161 |
162 |     elif split == 'test':
163 |         if dataset != 'svhn':
164 |             data = dataset_f(dir_, train=False, transform=transform, download=True)
165 |         else:
166 |             data = dataset_f(dir_, split='test', transform=transform, download=True)
167 |             data.data = data.data.transpose([0, 2, 3, 1])
168 |             data.targets = data.labels
169 |         n_ex = len(data) if n_ex == -1 else n_ex
170 |
171 |         if dataset in ['cifar10_horse_car', 'cifar10_dog_cat']:
172 |             data.targets = np.array(data.targets)
173 |             idx = (data.targets == cl1) + (data.targets == cl2)
174 |             data.data, data.targets = data.data[idx], data.targets[idx]
175 |             data.targets[data.targets == cl1], data.targets[data.targets == cl2] = 0, 1
176 |             data.targets = list(data.targets)  # to reduce memory consumption
177 |         if '_gs' in dataset:
178 |             data.data = data.data.mean(3).astype(np.uint8)
179 |         data.data, data.targets = data.data[:n_ex], data.targets[:n_ex]
180 |         data.targets_correct = data.targets.copy()
181 |
182 |         data.label_noise = np.zeros(n_ex)
183 |         data = DatasetWithLabelNoise(data, split, transform if dataset != 'gaussians_binary' else None)
184 |         loader =
torch.utils.data.DataLoader(dataset=data, batch_size=batch_size, shuffle=shuffle, pin_memory=True,
185 |                                              num_workers=num_workers_test, drop_last=drop_last)
186 |
187 |     else:
188 |         raise ValueError('wrong split')
189 |
190 |     return loader
191 |
192 |
193 | def create_loader(x, y, ln, n_ex, batch_size, shuffle, drop_last):
194 |     if n_ex > 0:
195 |         x, y, ln = x[:n_ex], y[:n_ex], ln[:n_ex]
196 |     data = td.TensorDataset(x, y, ln)
197 |     loader = torch.utils.data.DataLoader(dataset=data, batch_size=batch_size, shuffle=shuffle, pin_memory=False,
198 |                                          num_workers=2, drop_last=drop_last)
199 |     return loader
200 |
201 |
202 | def get_xy_from_loader(loader, cuda=True, n_batches=-1):
203 |     tuples = [(x, y, y_correct, ln) for i, (x, x_augm2, y, y_correct, ln) in enumerate(loader) if n_batches == -1 or i < n_batches]
204 |     x_vals = torch.cat([x for (x, y, y_correct, ln) in tuples])
205 |     y_vals = torch.cat([y for (x, y, y_correct, ln) in tuples])
206 |     y_correct_vals = torch.cat([y_correct for (x, y, y_correct, ln) in tuples])
207 |     ln_vals = torch.cat([ln for (x, y, y_correct, ln) in tuples])
208 |     if cuda:
209 |         x_vals, y_vals, y_correct_vals, ln_vals = x_vals.cuda(), y_vals.cuda(), y_correct_vals.cuda(), ln_vals.cuda()
210 |     return x_vals, y_vals, y_correct_vals, ln_vals
211 |
212 |
213 | shapes_dict = {'mnist': (60000, 1, 28, 28),
214 |                'mnist_binary': (13007, 1, 28, 28),
215 |                'svhn': (73257, 3, 32, 32),
216 |                'cifar10': (50000, 3, 32, 32),
217 |                'cifar10_horse_car': (10000, 3, 32, 32),
218 |                'cifar10_dog_cat': (10000, 3, 32, 32),
219 |                'cifar100': (50000, 3, 32, 32),
220 |                'uniform_noise': (1000, 1, 28, 28),
221 |                'gaussians_binary': (1000, 1, 1, 100),
222 |                }
223 | np.random.seed(0)
224 | v_global = np.random.randn(shapes_dict['gaussians_binary'][3])  # needed for consistency between train and test
225 | datasets_dict = {'mnist': datasets.MNIST,
226 |                  'mnist_binary': datasets.MNIST,
227 |                  'svhn': datasets.SVHN,
228 |                  'cifar10': datasets.CIFAR10,
229 |                  'cifar10_horse_car':
datasets.CIFAR10,
230 |                  'cifar10_dog_cat': datasets.CIFAR10,
231 |                  'cifar100': datasets.CIFAR100,
232 |                  'uniform_noise': uniform_noise,
233 |                  'gaussians_binary': dataset_gaussians_binary,
234 |                  }
235 | classes_dict = {'cifar10': {0: 'airplane',
236 |                             1: 'automobile',
237 |                             2: 'bird',
238 |                             3: 'cat',
239 |                             4: 'deer',
240 |                             5: 'dog',
241 |                             6: 'frog',
242 |                             7: 'horse',
243 |                             8: 'ship',
244 |                             9: 'truck',
245 |                             }
246 |                 }
247 |
248 |
--------------------------------------------------------------------------------
/deep_nets/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | import math
6 |
7 |
8 | class Flatten(nn.Module):
9 |     def forward(self, x):
10 |         return x.view(x.size(0), -1)
11 |
12 |
13 | class Normalize(nn.Module):
14 |     def __init__(self, mu, std):
15 |         super(Normalize, self).__init__()
16 |         self.mu, self.std = mu, std
17 |
18 |     def forward(self, x):
19 |         return (x - self.mu) / self.std
20 |
21 |
22 | class CustomReLU(nn.Module):
23 |     def __init__(self):
24 |         super(CustomReLU, self).__init__()
25 |         self.collect_preact = True
26 |         self.avg_preacts = []
27 |
28 |     def forward(self, preact):
29 |         if self.collect_preact:
30 |             self.avg_preacts.append(preact.abs().mean().item())
31 |         act = F.relu(preact)
32 |         return act
33 |
34 |
35 | class ModuleWithStats(nn.Module):
36 |     def __init__(self):
37 |         super(ModuleWithStats, self).__init__()
38 |
39 |     def forward(self, x):
40 |         for layer in self._model:
41 |             if type(layer) == CustomReLU:
42 |                 layer.avg_preacts = []
43 |
44 |         out = self._model(x)
45 |
46 |         avg_preacts_all = [layer.avg_preacts for layer in self._model if type(layer) == CustomReLU]
47 |         self.avg_preact = np.mean(avg_preacts_all)
48 |         return out
49 |
50 |
51 | class Linear(ModuleWithStats):
52 |     def __init__(self, n_cls, shape_in):
53 |         n_cls = 1 if n_cls == 2 else n_cls
54 |         super(Linear, self).__init__()
55 |         d =
int(np.prod(shape_in[1:]))
56 |         self._model = nn.Sequential(
57 |             Flatten(),
58 |             nn.Linear(d, n_cls, bias=False)
59 |         )
60 |
61 |     def forward(self, x):
62 |         logits = self._model(x)
63 |         return torch.cat([torch.zeros(logits.shape).cuda(), logits], dim=1)
64 |
65 |
66 | class LinearTwoOutputs(ModuleWithStats):
67 |     def __init__(self, n_cls, shape_in):
68 |         super(LinearTwoOutputs, self).__init__()
69 |         d = int(np.prod(shape_in[1:]))
70 |         self._model = nn.Sequential(
71 |             Flatten(),
72 |             nn.Linear(d, n_cls, bias=False)
73 |         )
74 |
75 |
76 | class IdentityLayer(nn.Module):
77 |     def forward(self, inputs):
78 |         return inputs
79 |
80 |
81 | class PreActBlock(nn.Module):
82 |     """ Pre-activation version of the BasicBlock. """
83 |     expansion = 1
84 |
85 |     def __init__(self, in_planes, planes, bn, learnable_bn, stride=1, activation='relu', droprate=0.0, gn_groups=32):
86 |         super(PreActBlock, self).__init__()
87 |         self.collect_preact = True
88 |         self.activation = activation
89 |         self.droprate = droprate
90 |         self.avg_preacts = []
91 |         self.bn1 = nn.BatchNorm2d(in_planes, affine=learnable_bn) if bn else nn.GroupNorm(gn_groups, in_planes)
92 |         self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=not learnable_bn)
93 |         self.bn2 = nn.BatchNorm2d(planes, affine=learnable_bn) if bn else nn.GroupNorm(gn_groups, planes)
94 |         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=not learnable_bn)
95 |
96 |         if stride != 1 or in_planes != self.expansion*planes:
97 |             self.shortcut = nn.Sequential(
98 |                 nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=not learnable_bn)
99 |             )
100 |
101 |     def act_function(self, preact):
102 |         if self.activation == 'relu':
103 |             act = F.relu(preact)
104 |             # print((act == 0).float().mean().item(), (act.norm() / act.shape[0]).item(), (act.norm() / np.prod(act.shape)).item())
105 |         else:
106 |             assert self.activation[:8] == 'softplus'
107 |             beta =
int(self.activation.split('softplus')[1])
108 |             act = F.softplus(preact, beta=beta)
109 |         return act
110 |
111 |     def forward(self, x):
112 |         out = self.act_function(self.bn1(x))
113 |         shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x  # Important: using out instead of x
114 |         out = self.conv1(out)
115 |         out = self.act_function(self.bn2(out))
116 |         if self.droprate > 0:
117 |             out = F.dropout(out, p=self.droprate, training=self.training)
118 |         out = self.conv2(out)
119 |         out += shortcut
120 |         return out
121 |
122 |
123 | class BasicBlock(nn.Module):
124 |     def __init__(self, in_planes, out_planes, stride, droprate=0.0):
125 |         super(BasicBlock, self).__init__()
126 |         self.bn1 = nn.BatchNorm2d(in_planes)
127 |         self.relu1 = nn.ReLU(inplace=True)
128 |         self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
129 |         self.bn2 = nn.BatchNorm2d(out_planes)
130 |         self.relu2 = nn.ReLU(inplace=True)
131 |         self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, padding=1, bias=False)
132 |         self.droprate = droprate
133 |         self.equalInOut = (in_planes == out_planes)
134 |         self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
135 |                                                                 padding=0, bias=False) or None
136 |
137 |     def forward(self, x):
138 |         if not self.equalInOut:
139 |             x = self.relu1(self.bn1(x))
140 |         else:
141 |             out = self.relu1(self.bn1(x))
142 |         out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))
143 |         if self.droprate > 0:
144 |             out = F.dropout(out, p=self.droprate, training=self.training)
145 |         out = self.conv2(out)
146 |         return torch.add(x if self.equalInOut else self.convShortcut(x), out)
147 |
148 |
149 | class BasicBlockResNet34(nn.Module):
150 |     expansion = 1
151 |
152 |     def __init__(self, in_planes, planes, stride=1):
153 |         super(BasicBlockResNet34, self).__init__()
154 |         self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
155 |         self.bn1 = nn.BatchNorm2d(planes)
156 |         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
157 |         self.bn2 = nn.BatchNorm2d(planes)
158 |
159 |         self.shortcut = nn.Sequential()
160 |         if stride != 1 or in_planes != self.expansion*planes:
161 |             self.shortcut = nn.Sequential(
162 |                 nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
163 |                 nn.BatchNorm2d(self.expansion*planes)
164 |             )
165 |
166 |     def forward(self, x):
167 |         out = F.relu(self.bn1(self.conv1(x)))
168 |         out = self.bn2(self.conv2(out))
169 |         out += self.shortcut(x)
170 |         out = F.relu(out)
171 |         return out
172 |
173 |
174 | class NetworkBlock(nn.Module):
175 |     def __init__(self, nb_layers, in_planes, out_planes, block, stride, droprate=0.0):
176 |         super(NetworkBlock, self).__init__()
177 |         self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, droprate)
178 |
179 |     def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, droprate):
180 |         layers = []
181 |         for i in range(int(nb_layers)):
182 |             layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, droprate))
183 |         return nn.Sequential(*layers)
184 |
185 |     def forward(self, x):
186 |         return self.layer(x)
187 |
188 |
189 | class ResNet(nn.Module):
190 |     def __init__(self, block, num_blocks, num_classes=10, model_width=64, droprate=0.0):
191 |         super(ResNet, self).__init__()
192 |         self.in_planes = model_width
193 |         self.half_prec = False
194 |         # self.mu = torch.tensor((0.4914, 0.4822, 0.4465)).view(1, 3, 1, 1).cuda()
195 |         # self.std = torch.tensor((0.2471, 0.2435, 0.2616)).view(1, 3, 1, 1).cuda()
196 |         self.mu = torch.tensor((0.0, 0.0, 0.0)).view(1, 3, 1, 1).cuda()
197 |         self.std = torch.tensor((1.0, 1.0, 1.0)).view(1, 3, 1, 1).cuda()
198 |         # if self.half_prec:
199 |         #     self.mu, self.std = self.mu.half(), self.std.half()
200 |
201 |         self.normalize = Normalize(self.mu, self.std)
202 |         self.conv1 = nn.Conv2d(3,
model_width, kernel_size=3, stride=1, padding=1, bias=False)
203 |         self.bn1 = nn.BatchNorm2d(model_width)
204 |         self.layer1 = self._make_layer(block, model_width, num_blocks[0], stride=1)
205 |         self.layer2 = self._make_layer(block, 2*model_width, num_blocks[1], stride=2)
206 |         self.layer3 = self._make_layer(block, 4*model_width, num_blocks[2], stride=2)
207 |         self.layer4 = self._make_layer(block, 8*model_width, num_blocks[3], stride=2)
208 |         self.linear = nn.Linear(8*model_width*block.expansion, num_classes)
209 |
210 |     def _make_layer(self, block, planes, num_blocks, stride):
211 |         strides = [stride] + [1]*(num_blocks-1)
212 |         layers = []
213 |         for stride in strides:
214 |             layers.append(block(self.in_planes, planes, stride))
215 |             self.in_planes = planes * block.expansion
216 |         return nn.Sequential(*layers)
217 |
218 |     def forward(self, x, return_features=False, return_block=5):
219 |         assert return_block in [1, 2, 3, 4, 5], 'wrong return_block'
220 |         # out = self.normalize(x)
221 |         out = F.relu(self.bn1(self.conv1(x)))
222 |         out = self.layer1(out)
223 |         if return_features and return_block == 1:
224 |             return out
225 |         out = self.layer2(out)
226 |         if return_features and return_block == 2:
227 |             return out
228 |         out = self.layer3(out)
229 |         if return_features and return_block == 3:
230 |             return out
231 |         out = self.layer4(out)
232 |         if return_features and return_block == 4:
233 |             return out
234 |         out = F.avg_pool2d(out, 4)
235 |         out = out.view(out.size(0), -1)
236 |         if return_features and return_block == 5:
237 |             return out
238 |         out = self.linear(out)
239 |         return out
240 |
241 |
242 | class PreActResNet(nn.Module):
243 |     def __init__(self, block, num_blocks, n_cls, model_width=64, cuda=True, half_prec=False, activation='relu',
244 |                  droprate=0.0, bn_flag=True):
245 |         super(PreActResNet, self).__init__()
246 |         self.half_prec = half_prec
247 |         self.bn_flag = bn_flag
248 |         self.gn_groups = model_width // 2  # in particular, 32 for model_width=64 as in the original
GroupNorm paper
249 |         self.learnable_bn = True  # doesn't matter if self.bn=False
250 |         self.in_planes = model_width
251 |         self.avg_preact = None
252 |         self.activation = activation
253 |         self.n_cls = n_cls
254 |         # self.mu = torch.tensor((0.4914, 0.4822, 0.4465)).view(1, 3, 1, 1)
255 |         # self.std = torch.tensor((0.2471, 0.2435, 0.2616)).view(1, 3, 1, 1)
256 |         self.mu = torch.tensor((0.0, 0.0, 0.0)).view(1, 3, 1, 1)
257 |         self.std = torch.tensor((1.0, 1.0, 1.0)).view(1, 3, 1, 1)
258 |
259 |         if cuda:
260 |             self.mu, self.std = self.mu.cuda(), self.std.cuda()
261 |         # if half_prec:
262 |         #     self.mu, self.std = self.mu.half(), self.std.half()
263 |
264 |         self.normalize = Normalize(self.mu, self.std)
265 |         self.conv1 = nn.Conv2d(3, model_width, kernel_size=3, stride=1, padding=1, bias=not self.learnable_bn)
266 |         self.layer1 = self._make_layer(block, model_width, num_blocks[0], 1, droprate)
267 |         self.layer2 = self._make_layer(block, 2*model_width, num_blocks[1], 2, droprate)
268 |         self.layer3 = self._make_layer(block, 4*model_width, num_blocks[2], 2, droprate)
269 |         final_layer_factor = 8
270 |         self.layer4 = self._make_layer(block, final_layer_factor*model_width, num_blocks[3], 2, droprate)
271 |         self.bn = nn.BatchNorm2d(final_layer_factor*model_width*block.expansion) if self.bn_flag \
272 |             else nn.GroupNorm(self.gn_groups, final_layer_factor*model_width*block.expansion)
273 |         self.linear = nn.Linear(final_layer_factor*model_width*block.expansion, 1 if n_cls == 2 else n_cls)
274 |
275 |     def _make_layer(self, block, planes, num_blocks, stride, droprate):
276 |         strides = [stride] + [1]*(num_blocks-1)
277 |         layers = []
278 |         for stride in strides:
279 |             layers.append(block(self.in_planes, planes, self.bn_flag, self.learnable_bn, stride, self.activation,
280 |                                 droprate, self.gn_groups))
281 |             # layers.append(block(self.in_planes, planes, stride))
282 |             self.in_planes = planes * block.expansion
283 |         return nn.Sequential(*layers)
284 |
285 |     def forward(self, x,
return_features=False, return_block=5):
286 |         assert return_block in [1, 2, 3, 4, 5], 'wrong return_block'
287 |         for layer in [*self.layer1, *self.layer2, *self.layer3, *self.layer4]:
288 |             layer.avg_preacts = []
289 |
290 |         # x = x / ((x**2).sum([1, 2, 3], keepdims=True)**0.5 + 1e-6)  # numerical stability is needed for RLAT
291 |         out = self.normalize(x)
292 |         out = self.conv1(out)
293 |         out = self.layer1(out)
294 |         if return_features and return_block == 1:
295 |             return out
296 |         out = self.layer2(out)
297 |         if return_features and return_block == 2:
298 |             return out
299 |         out = self.layer3(out)
300 |         if return_features and return_block == 3:
301 |             return out
302 |         out = self.layer4(out)
303 |         out = F.relu(self.bn(out))
304 |         if return_features and return_block == 4:
305 |             return out
306 |         out = F.avg_pool2d(out, 4)
307 |         out = out.view(out.size(0), -1)
308 |         if return_features and return_block == 5:
309 |             return out
310 |
311 |         out = self.linear(out)
312 |         if out.shape[1] == 1:
313 |             out = torch.cat([torch.zeros_like(out), out], dim=1)
314 |
315 |         return out
316 |
317 |
318 | class WideResNet(nn.Module):
319 |     """ Based on code from https://github.com/yaodongyu/TRADES """
320 |     def __init__(self, depth=28, num_classes=10, widen_factor=10, droprate=0.0, bias_last=True):
321 |         super(WideResNet, self).__init__()
322 |         self.half_prec = False
323 |         nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
324 |         assert ((depth - 4) % 6 == 0)
325 |         n = (depth - 4) / 6
326 |         block = BasicBlock
327 |         # 1st conv before any network block
328 |         self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, padding=1, bias=False)
329 |         # 1st block
330 |         self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, droprate)
331 |         # 2nd block
332 |         self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, droprate)
333 |         # 3rd block
334 |         self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, droprate)
335 |         # global average pooling
and classifier
336 |         self.bn1 = nn.BatchNorm2d(nChannels[3])
337 |         self.relu = nn.ReLU(inplace=True)
338 |         self.fc = nn.Linear(nChannels[3], num_classes, bias=bias_last)
339 |         self.nChannels = nChannels[3]
340 |
341 |         for m in self.modules():
342 |             if isinstance(m, nn.Conv2d):
343 |                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
344 |                 m.weight.data.normal_(0, math.sqrt(2. / n))
345 |             elif isinstance(m, nn.BatchNorm2d):
346 |                 m.weight.data.fill_(1)
347 |                 m.bias.data.zero_()
348 |             elif isinstance(m, nn.Linear) and not m.bias is None:
349 |                 m.bias.data.zero_()
350 |
351 |     def forward(self, x):
352 |         out = self.conv1(x)
353 |         out = self.block1(out)
354 |         out = self.block2(out)
355 |         out = self.block3(out)
356 |         out = self.relu(self.bn1(out))
357 |         out = F.avg_pool2d(out, 8)
358 |         out = out.view(-1, self.nChannels)
359 |         return self.fc(out)
360 |
361 |
362 | class VGG(nn.Module):
363 |     '''
364 |     VGG model. Source: https://github.com/chengyangfu/pytorch-vgg-cifar10/blob/master/vgg.py
365 |     (in turn modified from https://github.com/pytorch/vision.git)
366 |     '''
367 |     def __init__(self, n_cls, half_prec, cfg):
368 |         super(VGG, self).__init__()
369 |         self.half_prec = half_prec
370 |         self.mu = torch.tensor((0.485, 0.456, 0.406)).view(1, 3, 1, 1).cuda()
371 |         self.std = torch.tensor((0.229, 0.224, 0.225)).view(1, 3, 1, 1).cuda()
372 |         self.normalize = Normalize(self.mu, self.std)
373 |         self.features = self.make_layers(cfg)
374 |         n_out = cfg[-2]  # equal to 8*model_width
375 |         self.classifier = nn.Sequential(
376 |             # nn.Dropout(),
377 |             nn.Linear(n_out, n_out),
378 |             nn.ReLU(True),
379 |             # nn.Dropout(),
380 |             nn.Linear(n_out, n_out),
381 |             nn.ReLU(True),
382 |             nn.Linear(n_out, n_cls),
383 |         )
384 |         # Initialize weights
385 |         for m in self.modules():
386 |             if isinstance(m, nn.Conv2d):
387 |                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
388 |                 m.weight.data.normal_(0, math.sqrt(2.
/ n)) 389 | m.bias.data.zero_() 390 | 391 | def make_layers(self, cfg, batch_norm=False): 392 | layers = [] 393 | in_channels = 3 394 | for v in cfg: 395 | if v == 'M': 396 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 397 | else: 398 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 399 | if batch_norm: 400 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 401 | else: 402 | layers += [conv2d, nn.ReLU(inplace=True)] 403 | in_channels = v 404 | return nn.Sequential(*layers) 405 | 406 | def forward(self, x): 407 | x = self.normalize(x) 408 | x = self.features(x) 409 | x = x.view(x.size(0), -1) 410 | x = self.classifier(x) 411 | return x 412 | 413 | 414 | def VGG16(n_cls, model_width, half_prec): 415 | """VGG 16-layer model (configuration "D")""" 416 | w1, w2, w3, w4, w5 = model_width, 2*model_width, 4*model_width, 8*model_width, 8*model_width 417 | # cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'] 418 | cfg = [w1, w1, 'M', w2, w2, 'M', w3, w3, w3, 'M', w4, w4, w4, 'M', w5, w5, w5, 'M'] 419 | return VGG(n_cls, half_prec, cfg) 420 | 421 | 422 | def TinyResNet(n_cls, model_width=64, cuda=True, half_prec=False, activation='relu', droprate=0.0): 423 | bn_flag = True 424 | return PreActResNet(PreActBlock, [1, 1, 1, 1], n_cls=n_cls, model_width=model_width, cuda=cuda, half_prec=half_prec, 425 | activation=activation, droprate=droprate, bn_flag=bn_flag) 426 | 427 | def TinyResNetGroupNorm(n_cls, model_width=64, cuda=True, half_prec=False, activation='relu', droprate=0.0): 428 | bn_flag = False 429 | return PreActResNet(PreActBlock, [1, 1, 1, 1], n_cls=n_cls, model_width=model_width, cuda=cuda, half_prec=half_prec, 430 | activation=activation, droprate=droprate, bn_flag=bn_flag) 431 | 432 | 433 | def ResNet18(n_cls, model_width=64): 434 | return ResNet(BasicBlockResNet34, [2, 2, 2, 2], num_classes=n_cls, model_width=model_width) 435 | 436 | 437 | def PreActResNet18(n_cls, model_width=64, cuda=True, 
half_prec=False, activation='relu', droprate=0.0): 438 | bn_flag = True 439 | return PreActResNet(PreActBlock, [2, 2, 2, 2], n_cls=n_cls, model_width=model_width, cuda=cuda, half_prec=half_prec, 440 | activation=activation, droprate=droprate, bn_flag=bn_flag) 441 | 442 | 443 | def PreActResNet34(n_cls, model_width=64, cuda=True, half_prec=False, activation='relu', droprate=0.0): 444 | bn_flag = True 445 | return PreActResNet(PreActBlock, [3, 4, 6, 3], n_cls=n_cls, model_width=model_width, cuda=cuda, half_prec=half_prec, 446 | activation=activation, droprate=droprate, bn_flag=bn_flag) 447 | 448 | 449 | def PreActResNet18GroupNorm(n_cls, model_width=64, cuda=True, half_prec=False, activation='relu', droprate=0.0): 450 | bn_flag = False # bn_flag==False means that we use GroupNorm with 32 groups 451 | return PreActResNet(PreActBlock, [2, 2, 2, 2], n_cls=n_cls, model_width=model_width, cuda=cuda, half_prec=half_prec, 452 | activation=activation, droprate=droprate, bn_flag=bn_flag) 453 | 454 | 455 | def PreActResNet34GroupNorm(n_cls, model_width=64, cuda=True, half_prec=False, activation='relu', droprate=0.0): 456 | bn_flag = False # bn_flag==False means that we use GroupNorm with 32 groups 457 | return PreActResNet(PreActBlock, [3, 4, 6, 3], n_cls=n_cls, model_width=model_width, cuda=cuda, half_prec=half_prec, 458 | activation=activation, droprate=droprate, bn_flag=bn_flag) 459 | 460 | 461 | def ResNet34(n_cls, model_width=64): 462 | return ResNet(BasicBlockResNet34, [3, 4, 6, 3], num_classes=n_cls, model_width=model_width) 463 | 464 | 465 | def WideResNet28(n_cls, model_width=10): 466 | return WideResNet(num_classes=n_cls, widen_factor=model_width) 467 | 468 | 469 | def get_model(model_name, n_cls, half_prec, shapes_dict, model_width, activation='relu', droprate=0.0): 470 | if model_name == 'resnet18': 471 | model = PreActResNet18(n_cls, model_width=model_width, half_prec=half_prec, activation=activation, droprate=droprate) 472 | elif model_name == 'resnet18_plain': 
473 | model = ResNet18(n_cls, model_width) 474 | elif model_name == 'resnet18_gn': 475 | model = PreActResNet18GroupNorm(n_cls, model_width=model_width, half_prec=half_prec, activation=activation, droprate=droprate) 476 | elif model_name == 'vgg16': 477 | assert droprate == 0.0, 'dropout is not implemented for vgg16' 478 | model = VGG16(n_cls, model_width, half_prec) 479 | elif model_name in ['resnet34', 'resnet34_plain']: 480 | model = ResNet34(n_cls, model_width) 481 | elif model_name == 'resnet34_gn': 482 | model = PreActResNet34GroupNorm(n_cls, model_width=model_width, half_prec=half_prec, activation=activation, droprate=droprate) 483 | elif model_name == 'resnet34preact': 484 | model = PreActResNet34(n_cls, model_width=model_width, half_prec=half_prec, activation=activation, droprate=droprate) 485 | elif model_name == 'wrn28': 486 | model = WideResNet28(n_cls, model_width) 487 | elif model_name == 'resnet_tiny': 488 | model = TinyResNet(n_cls, model_width=model_width, half_prec=half_prec, activation=activation, droprate=droprate) 489 | elif model_name == 'resnet_tiny_gn': 490 | model = TinyResNetGroupNorm(n_cls, model_width=model_width, half_prec=half_prec, activation=activation, droprate=droprate) 491 | elif model_name == 'linear': 492 | model = Linear(n_cls, shapes_dict) 493 | else: 494 | raise ValueError('wrong model') 495 | return model 496 | 497 | 498 | def init_weights(model, scale_init=0.0): 499 | def init_weights_he(m): 500 | if isinstance(m, nn.Conv2d): 501 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 502 | elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.GroupNorm): 503 | m.weight.data.fill_(1) 504 | m.bias.data.zero_() 505 | elif isinstance(m, nn.Linear): 506 | m.bias.data.zero_() 507 | 508 | return init_weights_he 509 | 510 | 511 | def forward_pass_rlat(model, x, deltas, layers): 512 | i = 0 513 | 514 | def out_hook(m, inp, out_layer): 515 | nonlocal i 516 | if layers[i] == model.normalize: 517 | new_out = 
(torch.clamp(inp[0] + deltas[i], 0, 1) - model.mu) / model.std 518 | else: 519 | new_out = out_layer + deltas[i] 520 | i += 1 521 | return new_out 522 | 523 | handles = [layer.register_forward_hook(out_hook) for layer in layers] 524 | out = model(x) 525 | 526 | for handle in handles: 527 | handle.remove() 528 | return out 529 | 530 | -------------------------------------------------------------------------------- /deep_nets/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | import copy 8 | import utils 9 | import utils_eval 10 | import utils_train 11 | import data 12 | import models 13 | from collections import defaultdict 14 | from datetime import datetime 15 | from models import forward_pass_rlat 16 | 17 | 18 | def get_args(): 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('--batch_size', default=128, type=int) 21 | parser.add_argument('--dataset', default='cifar10', choices=data.datasets_dict.keys(), type=str) 22 | parser.add_argument('--model', default='resnet18', choices=['vgg16', 'resnet18', 'resnet18_plain', 'resnet18_gn', 'resnet_tiny', 'resnet_tiny_gn', 'resnet34', 'resnet34_plain', 'resnet34preact', 'resnet34_gn', 'wrn28'], type=str) 23 | parser.add_argument('--epochs', default=100, type=int, help='100 epochs is standard with batch_size=128') 24 | parser.add_argument('--lr_schedule', default='piecewise', choices=['cyclic', 'piecewise', 'cosine', 'constant', 'piecewise_02epochs', 'piecewise_03epochs', 'piecewise_05epochs']) 25 | parser.add_argument('--ln_schedule', default='constant', choices=['constant', 'inverted_cosine', 'piecewise_10_100', 'piecewise_3_9', 'piecewise_3_inf', 'piecewise_2_3_3', 'piecewise_5_3_3', 'piecewise_8_3_3']) 26 | parser.add_argument('--lr_init', default=0.1, type=float, help='') 27 | parser.add_argument('--warmup_factor', default=1.0, 
type=float, help='linear warmup factor of the peak lr') 28 | parser.add_argument('--warmup_exp', default=1.0, type=float, help='the exponent of the exponential warmup') 29 | parser.add_argument('--momentum', default=0.9, type=float, help='') 30 | parser.add_argument('--p_label_noise', default=0.0, type=float, help='Fraction of flipped labels in the training data.') 31 | parser.add_argument('--noise_type', default='sym', type=str, choices=['sym', 'asym'], help='Noise type: symmetric or asymmetric') 32 | parser.add_argument('--attack', default='none', type=str, choices=['fgsm', 'fgsmpp', 'pgd', 'rlat', 'random_noise', 'none']) 33 | parser.add_argument('--at_pred_label', action='store_true', help='Use predicted labels for AT.') 34 | parser.add_argument('--swa_tau', default=0.999, type=float, help='SWA moving averaging coefficient (averaging executed every iteration).') 35 | parser.add_argument('--sgd_p_label_noise', default=0.0, type=float, help='ratio of label noise in SGD per batch') 36 | parser.add_argument('--frac_train', default=1, type=float, help='Fraction of training points.') 37 | parser.add_argument('--l2_reg', default=0.0, type=float, help='l2 regularization in the objective') 38 | parser.add_argument('--seed', default=0, type=int) 39 | parser.add_argument('--gpu', default=0, type=int) 40 | parser.add_argument('--debug', action='store_true') 41 | parser.add_argument('--save_model_each_k_epochs', default=0, type=int, help='save each k epochs; 0 means saving only at the end') 42 | parser.add_argument('--half_prec', action='store_true', help='if enabled, runs everything as half precision [not recommended]') 43 | parser.add_argument('--no_data_augm', action='store_true') 44 | parser.add_argument('--eval_iter_freq', default=-1, type=int, help='how often to evaluate test stats. 
-1 means to evaluate each #iter that corresponds to 2nd epoch with frac_train=1.') 45 | parser.add_argument('--n_eval_every_k_iter', default=512, type=int, help='on how many examples to eval every k iters') 46 | parser.add_argument('--model_width', default=-1, type=int, help='model width (# conv filters on the first layer for ResNets)') 47 | parser.add_argument('--batch_size_eval', default=512, type=int, help='batch size for the final eval with pgd rr; 6 GB memory is consumed for 1024 examples with fp32 network') 48 | parser.add_argument('--n_final_eval', default=10000, type=int, help='on how many examples to do the final evaluation; -1 means on all test examples.') 49 | parser.add_argument('--exp_name', default='other', type=str) 50 | parser.add_argument('--model_path', type=str, default='', help='Path to a checkpoint to continue training from.') 51 | return parser.parse_args() 52 | 53 | 54 | def main(): 55 | args = get_args() 56 | assert args.model_width != -1, 'args.model_width has to be always specified (e.g., 64 for resnet18, 10 for wrn28)' 57 | assert 0 <= args.frac_train <= 1 58 | assert 0 <= args.sgd_p_label_noise <= 1 59 | assert 0 <= args.swa_tau <= 1 60 | 61 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu) 62 | cur_timestamp = str(datetime.now())[:-3] # include also ms to prevent the probability of name collision 63 | model_name = '{} dataset={} model={} epochs={} lr_init={} model_width={} l2_reg={} batch_size={} frac_train={} p_label_noise={} seed={}'.format( 64 | cur_timestamp, args.dataset, args.model, args.epochs, args.lr_init, args.model_width, args.l2_reg, 65 | args.batch_size, args.frac_train, args.p_label_noise, args.seed) 66 | logger = utils.configure_logger(model_name, args.debug) 67 | logger.info(args) 68 | 69 | n_cls = 2 if args.dataset in ['cifar10_horse_car', 'cifar10_dog_cat'] else 10 if args.dataset != 'cifar100' else 100 70 | n_train = int(args.frac_train * data.shapes_dict[args.dataset][0]) 71 | 72 | args.exp_name = 
'exps/{}'.format(args.exp_name) 73 | if not os.path.exists(args.exp_name): os.makedirs(args.exp_name) 74 | 75 | # fixing the seed helps, but not completely. there is still some non-determinism due to GPU computations. 76 | np.random.seed(args.seed) 77 | torch.manual_seed(args.seed) 78 | torch.cuda.manual_seed(args.seed) 79 | 80 | train_data_augm = False if args.no_data_augm or args.model == 'linear' or args.dataset in ['mnist', 'mnist_binary', 'gaussians_binary'] else True 81 | train_batches = data.get_loaders(args.dataset, n_train, args.batch_size, split='train', val_indices=[], shuffle=True, data_augm=train_data_augm, p_label_noise=args.p_label_noise, noise_type=args.noise_type, drop_last=True) 82 | train_batches_large_bs = data.get_loaders(args.dataset, n_train, args.batch_size_eval, split='train', val_indices=[], shuffle=False, data_augm=False, p_label_noise=args.p_label_noise, noise_type=args.noise_type, drop_last=False) 83 | test_batches = data.get_loaders(args.dataset, args.n_final_eval, args.batch_size_eval, split='test', shuffle=True, data_augm=False, noise_type=args.noise_type, drop_last=False) 84 | 85 | model = models.get_model(args.model, n_cls, args.half_prec, data.shapes_dict[args.dataset], args.model_width).cuda() 86 | if args.model_path != '': 87 | model_dict = torch.load(args.model_path)['last'] 88 | model.load_state_dict({k: v for k, v in model_dict.items() if 'model_preact_hl1' not in k}) 89 | else: 90 | model.apply(models.init_weights(args.model)) 91 | model.train() 92 | model_swa = copy.deepcopy(model).eval() # stochastic weight averaging model (keep it in the eval mode by default) 93 | 94 | opt = torch.optim.SGD(model.parameters(), lr=args.lr_init, momentum=args.momentum) 95 | scaler = torch.cuda.amp.GradScaler(enabled=model.half_prec) 96 | lr_schedule = utils_train.get_lr_schedule(args.lr_schedule, args.epochs, args.lr_init, args.warmup_factor, args.warmup_exp) 97 | ln_schedule = utils_train.get_lr_schedule(args.ln_schedule, args.epochs, 
args.sgd_p_label_noise) 98 | 99 | loss_f = lambda logits, y: F.cross_entropy(logits, y, reduction='mean') 100 | 101 | metr_dict = defaultdict(list, vars(args)) 102 | start_time = time.time() 103 | time_train, iteration = 0, 0 104 | for epoch in range(args.epochs + 1): 105 | model = model.eval() if epoch == 0 else model.train() # epoch=0 is eval only 106 | 107 | train_obj, train_reg = 0, 0 108 | for i, (x, _, y, _, ln) in enumerate(train_batches): 109 | if epoch == 0 and i > 0: # epoch=0 runs only for one iteration (to check the training stats at init) 110 | break 111 | time_start_iter = time.time() 112 | x, y = x.cuda(), y.cuda() 113 | lr = lr_schedule(epoch - 1 + (i + 1) / len(train_batches)) # epoch - 1 since the 0th epoch is skipped 114 | opt.param_groups[0].update(lr=lr) 115 | 116 | # label noise SGD 117 | if args.sgd_p_label_noise > 0.0: 118 | sgd_p_label_noise_eff = ln_schedule(epoch - 1 + (i + 1) / len(train_batches)) 119 | n_noisy_pts = (torch.rand(args.batch_size) < sgd_p_label_noise_eff).int().sum() # randomized fraction of noisy points 120 | rand_indices = torch.randperm(args.batch_size)[:n_noisy_pts].cuda() 121 | rand_labels = torch.randint(low=0, high=n_cls, size=(n_noisy_pts, )).cuda() 122 | y[rand_indices] = rand_labels 123 | 124 | with torch.cuda.amp.autocast(enabled=model.half_prec): 125 | logits = model(x) 126 | obj = loss_f(logits, y) 127 | 128 | reg = torch.zeros(1).cuda()[0] 129 | for param in model.parameters(): 130 | reg += args.l2_reg * 0.5 * torch.sum(param ** 2).float() 131 | obj += reg 132 | 133 | 134 | opt.zero_grad() 135 | scaler.scale(obj).backward() 136 | 137 | train_obj += obj.item() / n_train # only for statistics 138 | train_reg += reg.item() / n_train # only for statistics 139 | 140 | if epoch > 0: # on 0-th epoch only evaluation occurs 141 | scaler.step(opt) 142 | scaler.update() # update the scale of the loss for fp16 143 | 144 | opt.zero_grad() # zero grad (also at epoch==0) 145 | 146 | time_train += time.time() - 
time_start_iter 147 | utils_train.moving_average(model_swa, model, 1-args.swa_tau) # executed every iteration 148 | 149 | # by default, evaluate every 2 epochs (update: 5 temporary to save time) 150 | if (args.eval_iter_freq == -1 and iteration % (5 * (n_train // args.batch_size)) == 0) or \ 151 | (args.eval_iter_freq != -1 and iteration % args.eval_iter_freq == 0): 152 | utils_train.bn_update(train_batches, model_swa) # a bit heavy but ok to do once per 2 epochs 153 | 154 | model.eval() # it'd be incorrect to recalculate the BN stats based on some evaluations 155 | 156 | train_err, train_loss = utils_eval.compute_loss(train_batches, model, loss_f=loss_f, n_batches=4) # i.e. it's evaluated using 4*batch_size examples 157 | train_err_swa, train_loss_swa = utils_eval.compute_loss(train_batches, model_swa, loss_f=loss_f, n_batches=4) # i.e. it's evaluated using 4*batch_size examples 158 | 159 | sparsity_train_block1, sparsity_train_block1_rmdup, _ = utils_eval.compute_feature_sparsity(train_batches_large_bs, model, return_block=1, n_batches=20) 160 | sparsity_train_block2, sparsity_train_block2_rmdup, _ = utils_eval.compute_feature_sparsity(train_batches_large_bs, model, return_block=2, n_batches=20) 161 | sparsity_train_block3, sparsity_train_block3_rmdup, _ = utils_eval.compute_feature_sparsity(train_batches_large_bs, model, return_block=3, n_batches=20) 162 | sparsity_train_block4, sparsity_train_block4_rmdup, _ = utils_eval.compute_feature_sparsity(train_batches_large_bs, model, return_block=4, n_batches=20) 163 | sparsity_train_block5, sparsity_train_block5_rmdup, _ = utils_eval.compute_feature_sparsity(train_batches_large_bs, model, return_block=5, n_batches=20) 164 | # sparsity_train_block1, sparsity_train_block1_rmdup, sparsity_train_block2, sparsity_train_block2_rmdup, sparsity_train_block3, sparsity_train_block3_rmdup, sparsity_train_block4, sparsity_train_block4_rmdup, sparsity_train_block5, sparsity_train_block5_rmdup = 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 165 | 
166 | test_err, test_loss = utils_eval.compute_loss(test_batches, model, loss_f=loss_f) 167 | test_err_swa, _ = utils_eval.compute_loss(test_batches, model_swa, loss_f=loss_f) 168 | 169 | time_elapsed = time.time() - start_time 170 | train_str = '[train] loss {:.4f}/{:.4f} err {:.2%}'.format(train_loss, train_loss_swa, train_err) 171 | test_str = '[test] err {:.2%}/{:.2%} '.format(test_err, test_err_swa) 172 | sparsity_str = 'sparsity {:.1%}/{:.1%}/{:.1%}/{:.1%}/{:.1%}'.format(sparsity_train_block1_rmdup, sparsity_train_block2_rmdup, sparsity_train_block3_rmdup, sparsity_train_block4_rmdup, sparsity_train_block5_rmdup) 173 | logger.info('{}-{}: {} {} {} ({:.2f}m, {:.2f}m)'.format( 174 | epoch, iteration, train_str, test_str, sparsity_str, time_train/60, time_elapsed/60)) 175 | metr_vals = [epoch, iteration, train_obj, train_loss, train_reg, train_err, 176 | test_err, test_loss, train_loss_swa, train_err_swa, 177 | test_err_swa, time_train, time_elapsed, 178 | sparsity_train_block1, sparsity_train_block2, sparsity_train_block3, sparsity_train_block4, sparsity_train_block5] 179 | metr_names = ['epoch', 'iter', 'train_obj', 'train_loss', 'train_reg', 'train_err', 180 | 'test_err', 'test_loss', 'train_loss_swa', 'train_err_swa', 'test_err_swa', 'time_train', 'time_elapsed', 181 | 'sparsity_train_block1', 'sparsity_train_block2', 'sparsity_train_block3', 'sparsity_train_block4', 'sparsity_train_block5'] 182 | utils.update_metrics(metr_dict, metr_vals, metr_names) 183 | 184 | if not args.debug: 185 | np.save('{}/{}.npy'.format(args.exp_name, model_name), metr_dict) 186 | 187 | model.train() 188 | 189 | iteration += 1 190 | 191 | if args.save_model_each_k_epochs > 0: 192 | if epoch % args.save_model_each_k_epochs == 0 or epoch <= 5: 193 | torch.save({'last': model.state_dict()}, 'models/{} epoch={}.pth'.format(model_name, epoch)) 194 | 195 | if not args.debug: 196 | np.save('{}/{}.npy'.format(args.exp_name, model_name), metr_dict) 197 | if epoch == args.epochs: # only 
save at the end 198 | torch.save({'last': model.state_dict(), 'swa_last': model_swa.state_dict()}, 199 | 'models/{} epoch={}.pth'.format(model_name, epoch)) 200 | 201 | logger.info('Saved the model at: models/{} epoch={}.pth'.format(model_name, epoch)) 202 | logger.info('Done in {:.2f}m'.format((time.time() - start_time) / 60)) 203 | 204 | 205 | if __name__ == "__main__": 206 | main() 207 | -------------------------------------------------------------------------------- /deep_nets/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import torch 4 | from contextlib import contextmanager 5 | 6 | 7 | logger = logging.getLogger(__name__) 8 | logging.basicConfig( 9 | format='[%(asctime)s %(filename)s %(name)s %(levelname)s] - %(message)s', 10 | datefmt='%Y/%m/%d %H:%M:%S', 11 | level=logging.DEBUG) 12 | 13 | 14 | def clamp(X, l, u, cuda=True): 15 | if type(l) is not torch.Tensor: 16 | if cuda: 17 | l = torch.cuda.FloatTensor(1).fill_(l) 18 | else: 19 | l = torch.FloatTensor(1).fill_(l) 20 | if type(u) is not torch.Tensor: 21 | if cuda: 22 | u = torch.cuda.FloatTensor(1).fill_(u) 23 | else: 24 | u = torch.FloatTensor(1).fill_(u) 25 | return torch.max(torch.min(X, u), l) 26 | 27 | 28 | def configure_logger(model_name, debug): 29 | if not os.path.exists('logs'): 30 | os.makedirs('logs') 31 | logging.basicConfig(format='%(message)s') # , level=logging.DEBUG) 32 | logger = logging.getLogger() 33 | logger.handlers = [] # remove the default logger 34 | 35 | # add a new logger for stdout 36 | formatter = logging.Formatter('%(message)s') 37 | ch = logging.StreamHandler() 38 | ch.setFormatter(formatter) 39 | ch.setLevel(logging.DEBUG) 40 | logger.addHandler(ch) 41 | 42 | if not debug: 43 | # add a new logger to a log file 44 | logger.addHandler(logging.FileHandler('logs/{}.log'.format(model_name))) 45 | 46 | return logger 47 | 48 | 49 | def get_random_delta(shape, eps, at_norm, requires_grad=True): 50 | 
delta = torch.zeros(shape).cuda() 51 | if at_norm == 'l2': # random direction, uniform on the unit l2 sphere 52 | delta.normal_() 53 | delta /= (delta**2).sum([1, 2, 3], keepdim=True)**0.5 54 | elif at_norm == 'linf': # uniform in the hypercube [-eps, eps]^d 55 | delta.uniform_(-eps, eps) 56 | else: 57 | raise ValueError('wrong at_norm') 58 | delta.requires_grad = requires_grad 59 | return delta 60 | 61 | 62 | def project_lp(img, at_norm, eps): 63 | if at_norm == 'l2': # project onto the l2 ball of radius eps 64 | l2_norms = (img ** 2).sum([1, 2, 3], keepdim=True) ** 0.5 65 | img_proj = img * torch.min(eps/l2_norms, torch.ones_like(l2_norms)) # if eps>l2_norms => multiply by 1 66 | elif at_norm == 'linf': # project onto the linf ball (hypercube) of radius eps 67 | img_proj = clamp(img, -eps, eps) 68 | else: 69 | raise ValueError('wrong at_norm') 70 | return img_proj 71 | 72 | 73 | def update_metrics(metrics_dict, metrics_values, metrics_names): 74 | assert len(metrics_values) == len(metrics_names) 75 | for metric_value, metric_name in zip(metrics_values, metrics_names): 76 | metrics_dict[metric_name].append(metric_value) 77 | return metrics_dict 78 | 79 | 80 | @contextmanager 81 | def nullcontext(enter_result=None): 82 | yield enter_result 83 | 84 | 85 | def get_flat_grad(model): 86 | return torch.cat([p.grad.flatten() for p in model.parameters() if p.grad is not None]) 87 | 88 | 89 | def zero_grad(model): 90 | for p in model.parameters(): 91 | if p.grad is not None: 92 | p.grad.zero_() 93 | 94 | -------------------------------------------------------------------------------- /deep_nets/utils_eval.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | def compute_loss(batches, model, cuda=True, noisy_examples='default', loss_f=F.cross_entropy, n_batches=-1): 8 | n_corr_classified, train_loss_sum, n_ex = 0, 0.0, 0 9 | for i, (X, X_augm2, y, _, ln) in enumerate(batches): 10 | if n_batches != 
-1 and i >= n_batches: # limit to only n_batches (i is 0-based, so stop once n_batches batches are processed) 11 | break 12 | if cuda: 13 | X, y = X.cuda(), y.cuda() 14 | 15 | if noisy_examples == 'none': 16 | X, y = X[~ln], y[~ln] 17 | elif noisy_examples == 'all': 18 | X, y = X[ln], y[ln] 19 | else: 20 | assert noisy_examples == 'default' 21 | 22 | with torch.no_grad(), torch.cuda.amp.autocast(enabled=model.half_prec): 23 | output = model(X) 24 | loss = loss_f(output, y) 25 | 26 | n_corr_classified += (output.max(1)[1] == y).sum().item() 27 | train_loss_sum += loss.item() * y.size(0) 28 | n_ex += y.size(0) 29 | 30 | robust_acc = n_corr_classified / n_ex 31 | avg_loss = train_loss_sum / n_ex 32 | 33 | return 1 - robust_acc, avg_loss 34 | 35 | 36 | def compute_feature_sparsity(batches, model, return_block, corr_threshold=0.95, n_batches=-1, n_relu_max=1000): 37 | with torch.no_grad(): 38 | features_list = [] 39 | for i, (X, X_augm2, y, _, ln) in enumerate(batches): 40 | if n_batches != -1 and i >= n_batches: # limit to only n_batches 41 | break 42 | X, y = X.cuda(), y.cuda() 43 | features = model(X, return_features=True, return_block=return_block).cpu().numpy() 44 | features_list.append(features) 45 | 46 | phi = np.vstack(features_list) 47 | phi = phi.reshape(phi.shape[0], np.prod(phi.shape[1:])) 48 | 49 | sparsity = (phi > 0).sum() / (phi.shape[0] * phi.shape[1]) 50 | 51 | if phi.shape[1] > n_relu_max: # if there are too many neurons, we speed it up by random subsampling 52 | random_idx = np.random.choice(phi.shape[1], n_relu_max, replace=False) 53 | phi = phi[:, random_idx] 54 | 55 | idx_keep = np.where((phi > 0.0).sum(0) > 0)[0] 56 | phi_filtered = phi[:, idx_keep] # filter out always-zeros 57 | corr_matrix = np.corrcoef(phi_filtered.T) 58 | corr_matrix -= np.eye(corr_matrix.shape[0]) 59 | 60 | idx_to_delete, i, j = [], 0, 0 61 | while i != corr_matrix.shape[0]: 62 | # print(i, corr_matrix.shape, (np.abs(corr_matrix[i]) > corr_threshold).sum()) 63 | if (np.abs(corr_matrix[i]) > corr_threshold).sum() > 0: 64 | corr_matrix = np.delete(corr_matrix, 
(i), axis=0) 65 | corr_matrix = np.delete(corr_matrix, (i), axis=1) 66 | # print('delete', j) 67 | idx_to_delete.append(j) 68 | else: 69 | i += 1 70 | j += 1 71 | assert corr_matrix.shape[0] == corr_matrix.shape[1] 72 | # print(idx_to_delete, idx_keep) 73 | idx_keep = np.delete(idx_keep, [idx_to_delete]) 74 | sparsity_rmdup = (phi[:, idx_keep] > 0).sum() / (phi.shape[0] * phi.shape[1]) 75 | 76 | n_highly_corr = phi.shape[1] - len(idx_keep) 77 | return sparsity, sparsity_rmdup, n_highly_corr 78 | -------------------------------------------------------------------------------- /deep_nets/utils_train.py: -------------------------------------------------------------------------------- 1 | from logging import lastResort 2 | import torch 3 | import numpy as np 4 | import math 5 | from utils import clamp, get_random_delta, project_lp 6 | 7 | 8 | def get_lr_schedule(lr_schedule_type, n_epochs, lr, warmup_factor=1.0, warmup_exp=1.0): 9 | if lr_schedule_type == 'cyclic': 10 | lr_schedule = lambda epoch: np.interp([epoch], [0, n_epochs * 2 // 5, n_epochs], [0, lr, 0])[0] 11 | elif lr_schedule_type in ['piecewise', 'piecewise_10_100']: 12 | def lr_schedule(t): 13 | """ 14 | Following the original ResNet paper (+ warmup for resnet34). 15 | t is the (fractional) number of epochs elapsed so far, starting from 0. 16 | """ 17 | # if 100 epochs in total, then warmup lasts for exactly 2 first epochs 18 | # if t / n_epochs < 0.02 and model in ['resnet34']: 19 | # return lr_max / 10. 20 | if t / n_epochs < 0.5: 21 | return lr 22 | elif t / n_epochs < 0.75: 23 | return lr / 10. 24 | else: 25 | return lr / 100. 26 | elif lr_schedule_type == 'piecewise_01epochs': 27 | def lr_schedule(t): 28 | if warmup_exp > 1.0: 29 | if t / n_epochs < 0.1: 30 | return warmup_exp**(t/n_epochs*100) * lr 31 | elif t / n_epochs < 0.9: 32 | return warmup_exp**(0.1*100) * lr / 10. 33 | else: 34 | return warmup_exp**(0.1*100) * lr / 100. 
35 | elif warmup_exp < 1.0: 36 | if t / n_epochs < 0.1: 37 | return (1 + (t/n_epochs*100)**warmup_exp) * lr 38 | elif t / n_epochs < 0.9: 39 | return (1 + (0.1*100)**warmup_exp) * lr / 10. 40 | else: 41 | return (1 + (0.1*100)**warmup_exp) * lr / 100. 42 | else: 43 | if t / n_epochs < 0.1: 44 | return np.interp([t], [0, 0.5 * n_epochs], [lr, warmup_factor*lr])[0] # note: we interpolate up to 0.5*n_epochs to be compatible with toy ReLU net experiments 45 | elif t / n_epochs < 0.9: 46 | return 0.1 / 0.5 * warmup_factor*lr / 10. 47 | else: 48 | return 0.1 / 0.5 * warmup_factor*lr / 100. 49 | elif lr_schedule_type == 'piecewise_03epochs': 50 | def lr_schedule(t): 51 | if warmup_exp > 1.0: 52 | if t / n_epochs < 0.3: 53 | return warmup_exp**(t/n_epochs*100) * lr 54 | elif t / n_epochs < 0.9: 55 | return warmup_exp**(0.3*100) * lr / 10. 56 | else: 57 | return warmup_exp**(0.3*100) * lr / 100. 58 | elif warmup_exp < 1.0: 59 | if t / n_epochs < 0.3: 60 | return (1 + (t/n_epochs*100)**warmup_exp) * lr 61 | elif t / n_epochs < 0.9: 62 | return (1 + (0.3*100)**warmup_exp) * lr / 10. 63 | else: 64 | return (1 + (0.3*100)**warmup_exp) * lr / 100. 65 | else: 66 | if t / n_epochs < 0.3: 67 | return np.interp([t], [0, 0.5 * n_epochs], [lr, warmup_factor*lr])[0] # note: we interpolate up to 0.5*n_epochs to be compatible with toy ReLU net experiments 68 | elif t / n_epochs < 0.9: 69 | return 0.3 / 0.5 * warmup_factor*lr / 10. 70 | else: 71 | return 0.3 / 0.5 * warmup_factor*lr / 100. 72 | elif lr_schedule_type == 'piecewise_05epochs': 73 | def lr_schedule(t): 74 | if warmup_exp > 1.0: 75 | if t / n_epochs < 0.5: 76 | return warmup_exp**(t/n_epochs*100) * lr 77 | elif t / n_epochs < 0.9: 78 | return warmup_exp**(0.5*100) * lr / 10. 79 | else: 80 | return warmup_exp**(0.5*100) * lr / 100. 81 | elif warmup_exp < 1.0: 82 | if t / n_epochs < 0.5: 83 | return (1 + (t/n_epochs*100)**warmup_exp) * lr 84 | elif t / n_epochs < 0.9: 85 | return (1 + (0.5*100)**warmup_exp) * lr / 10. 
86 | else: 87 | return (1 + (0.5*100)**warmup_exp) * lr / 100. 88 | else: 89 | if t / n_epochs < 0.5: 90 | return np.interp([t], [0, 0.5 * n_epochs], [lr, warmup_factor*lr])[0] # note: we interpolate up to 0.5*n_epochs to be compatible with toy ReLU net experiments 91 | elif t / n_epochs < 0.9: 92 | return 0.5 / 0.5 * warmup_factor*lr / 10. 93 | else: 94 | return 0.5 / 0.5 * warmup_factor*lr / 100. 95 | elif lr_schedule_type == 'cosine': 96 | # cosine LR schedule without restarts like in the SAM paper 97 | # (as in the JAX implementation used in SAM https://flax.readthedocs.io/en/latest/_modules/flax/training/lr_schedule.html#create_cosine_learning_rate_schedule) 98 | return lambda epoch: lr * (0.5 + 0.5*math.cos(math.pi * epoch / n_epochs)) 99 | elif lr_schedule_type == 'inverted_cosine': 100 | return lambda epoch: lr - lr * (0.5 + 0.5*math.cos(math.pi * epoch / n_epochs)) 101 | elif lr_schedule_type == 'constant': 102 | return lambda epoch: lr 103 | else: 104 | raise ValueError('wrong lr_schedule_type') 105 | return lr_schedule 106 | 107 | 108 | def change_bn_mode(model, bn_train): 109 | for module in model.modules(): 110 | if isinstance(module, torch.nn.BatchNorm2d): 111 | if bn_train: 112 | module.train() 113 | else: 114 | module.eval() 115 | 116 | 117 | def moving_average(net1, net2, alpha=0.999): 118 | for param1, param2 in zip(net1.parameters(), net2.parameters()): 119 | param1.data *= (1.0 - alpha) 120 | param1.data += param2.data * alpha 121 | 122 | 123 | def _check_bn(module, flag): 124 | if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm): 125 | flag[0] = True 126 | 127 | 128 | def check_bn(model): 129 | flag = [False] 130 | model.apply(lambda module: _check_bn(module, flag)) 131 | return flag[0] 132 | 133 | 134 | def reset_bn(module): 135 | if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm): 136 | module.running_mean = torch.zeros_like(module.running_mean) 137 | module.running_var = torch.ones_like(module.running_var) 138 
| 139 | 140 | def _get_momenta(module, momenta): 141 | if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm): 142 | momenta[module] = module.momentum 143 | 144 | 145 | def _set_momenta(module, momenta): 146 | if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm): 147 | module.momentum = momenta[module] 148 | 149 | 150 | def bn_update(loader, model): 151 | """ 152 | BatchNorm buffers update (if any). 153 | Performs 1 epoch over the train dataset to estimate the buffer averages. 154 | :param loader: train dataset loader for buffers average estimation. 155 | :param model: model being updated 156 | :return: None 157 | """ 158 | if not check_bn(model): 159 | return 160 | with torch.no_grad(): 161 | model.train() 162 | momenta = {} 163 | model.apply(reset_bn) 164 | model.apply(lambda module: _get_momenta(module, momenta)) 165 | n = 0 166 | for x, _, _, _, _ in loader: 167 | x = x.cuda(non_blocking=True) 168 | b = x.data.size(0) 169 | 170 | momentum = b / (n + b) 171 | for module in momenta.keys(): 172 | module.momentum = momentum 173 | 174 | model(x) 175 | n += b 176 | 177 | model.apply(lambda module: _set_momenta(module, momenta)) 178 | model.eval() 179 | 180 | 181 | -------------------------------------------------------------------------------- /diag_nets.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def loss(X, y_hat, y): 5 | return np.linalg.norm(y_hat - y)**2 / (2*X.shape[0]) 6 | 7 | 8 | 9 | def valley_projection(X, y, gamma, u0, v0, num_iter): 10 | u, v = u0, v0 11 | for i in range(num_iter): 12 | y_hat = X @ (u * v) 13 | Xerr = X.T@(y_hat - y) 14 | grad_u, grad_v = (Xerr * v) / X.shape[0], (Xerr * u) / X.shape[0] 15 | 16 | u = u - gamma * grad_u 17 | v = v - gamma * grad_v 18 | 19 | return u, v 20 | 21 | 22 | def GD(X, y, X_test, y_test, gamma, u0, v0, iters_loss, num_iter, thresholds=[-1], decays=[2], wd=0, balancedness=0, normalized_gd=False, weight_avg=0.0, 
return_all_params=False, valley_project=False): 23 | train_losses, test_losses = [], [] 24 | us, vs = [], [] 25 | u, v, u_avg, v_avg = u0, v0, u0, v0 26 | for i in range(num_iter): 27 | if i in iters_loss: 28 | if valley_project: 29 | u_avg, v_avg = valley_projection(X, y, 0.5, u, v, 2000) 30 | train_losses += [loss(X, X @ (u_avg * v_avg), y)] 31 | test_losses += [loss(X_test, X_test @ (u_avg * v_avg), y_test)] 32 | us, vs = us + [u_avg], vs + [v_avg] 33 | 34 | y_hat = X @ (u * v) 35 | Xerr = X.T@(y_hat - y) 36 | grad_u, grad_v = (Xerr * v) / X.shape[0], (Xerr * u) / X.shape[0] 37 | if normalized_gd and (np.linalg.norm(grad_u) > 0 and np.linalg.norm(grad_v) > 0): 38 | grad_u, grad_v = grad_u / np.linalg.norm(grad_u), grad_v / np.linalg.norm(grad_v) 39 | 40 | u = u - gamma * grad_u - wd*u - balancedness*u*(2*(np.abs(u)>np.abs(v))-1) 41 | v = v - gamma * grad_v - wd*v - balancedness*v*(2*(np.abs(v)>np.abs(u))-1) 42 | u_avg, v_avg = weight_avg*u_avg + (1-weight_avg)*u, weight_avg*v_avg + (1-weight_avg)*v 43 | 44 | if i in thresholds: 45 | for threshold, decay in zip(thresholds, decays): 46 | if i == threshold: 47 | gamma = gamma / decay 48 | 49 | if return_all_params: 50 | return train_losses, test_losses, u, v, us, vs 51 | else: 52 | return train_losses, test_losses, u, v 53 | 54 | 55 | def SGD(X, y, X_test, y_test, gamma, u0, v0, iters_loss, num_iter, thresholds=[-1], decays=[2], weight_avg=0.0, return_all_params=False, valley_project=False): 56 | train_losses, test_losses = [], [] 57 | us, vs = [], [] 58 | u, v, u_avg, v_avg = u0, v0, u0, v0 59 | for i in range(num_iter): 60 | if i in iters_loss: 61 | if valley_project: 62 | u_avg, v_avg = valley_projection(X, y, 0.5, u, v, 2000) 63 | train_losses += [loss(X, X @ (u_avg * v_avg), y)] 64 | test_losses += [loss(X_test, X_test @ (u_avg * v_avg), y_test)] 65 | us, vs = us + [u_avg], vs + [v_avg] 66 | 67 | i_t = np.random.randint(X.shape[0]) 68 | error = X[i_t] @ (u * v) - y[i_t] 69 | Xerr = error * X[i_t] 70 | grad_u, 
grad_v = Xerr * v, Xerr * u 71 | 72 | u = u - gamma * grad_u # gradient step 73 | v = v - gamma * grad_v # gradient step 74 | u_avg, v_avg = weight_avg*u_avg + (1-weight_avg)*u, weight_avg*v_avg + (1-weight_avg)*v 75 | 76 | if i in thresholds: 77 | for threshold, decay in zip(thresholds, decays): 78 | if i == threshold: 79 | gamma = gamma / decay 80 | 81 | if return_all_params: 82 | return train_losses, test_losses, u, v, us, vs 83 | else: 84 | return train_losses, test_losses, u, v 85 | 86 | 87 | def n_SAM_GD(X, y, X_test, y_test, gamma, u0, v0, iters_loss, num_iter, rho): 88 | train_losses, test_losses = [], [] 89 | u, v = u0, v0 90 | for i in range(num_iter): 91 | y_hat = X @ (u * v) 92 | Xerr = X.T @ (y_hat - y) 93 | u_sam, v_sam = u + rho * (Xerr * v) / X.shape[0], v + rho * (Xerr * u) / X.shape[0] 94 | 95 | Xerr_sam = X.T @ (X @ (u_sam * v_sam) - y) 96 | grad_u_sam, grad_v_sam = (Xerr_sam * v_sam) / X.shape[0], (Xerr_sam * u_sam) / X.shape[0] 97 | 98 | u = u - gamma * grad_u_sam # gradient step 99 | v = v - gamma * grad_v_sam # gradient step 100 | 101 | if i in iters_loss: 102 | train_losses += [loss(X, X @ (u * v), y)] 103 | test_losses += [loss(X_test, X_test @ (u * v), y_test)] 104 | 105 | return train_losses, test_losses, u, v 106 | 107 | 108 | def one_SAM_GD(X, y, X_test, y_test, gamma, u0, v0, iters_loss, num_iter, rho, loss_derivative_only=False): 109 | train_losses, test_losses = [], [] 110 | u, v = u0, v0 111 | for i in range(num_iter): 112 | y_hat = X @ (u * v) 113 | r = y_hat - y 114 | grad_u_sam, grad_v_sam = np.zeros_like(u), np.zeros_like(v) 115 | for k in range(X.shape[0]): 116 | u_sam_k = u + 2 * rho * r[k] * X[k] * v 117 | v_sam_k = v + 2 * rho * r[k] * X[k] * u 118 | if not loss_derivative_only: 119 | grad_u_sam += ((X[k] * u_sam_k * v_sam_k).sum() - y[k]) * v_sam_k * X[k] / X.shape[0] 120 | grad_v_sam += ((X[k] * u_sam_k * v_sam_k).sum() - y[k]) * u_sam_k * X[k] / X.shape[0] 121 | else: 122 | grad_u_sam += ((X[k] * u_sam_k * v_sam_k).sum() 
- y[k]) * v * X[k] / X.shape[0] 123 | grad_v_sam += ((X[k] * u_sam_k * v_sam_k).sum() - y[k]) * u * X[k] / X.shape[0] 124 | 125 | u = u - gamma * grad_u_sam # gradient step 126 | v = v - gamma * grad_v_sam # gradient step 127 | 128 | if i in iters_loss: 129 | train_losses += [loss(X, X @ (u * v), y)] 130 | test_losses += [loss(X_test, X_test @ (u * v), y_test)] 131 | 132 | return train_losses, test_losses, u, v 133 | 134 | 135 | def dln_hessian(u, v, X, y, normalized=False): 136 | beta = u * v 137 | if normalized: 138 | u = np.abs(beta)**0.5 * np.sign(u) 139 | v = np.abs(beta)**0.5 * np.sign(v) 140 | # print(beta[:10], u[:10], v[:10]) 141 | n, d = X.shape 142 | H = np.zeros([2*d, 2*d]) 143 | H[:d, :d] = np.diag(v) @ (X.T@X) @ np.diag(v) 144 | H[:d, d:] = np.diag(v) @ (X.T@X) @ np.diag(u) + np.diag(X.T@(X@(u*v) - y)) 145 | H[d:, :d] = np.diag(u) @ (X.T@X) @ np.diag(v) + np.diag(X.T@(X@(u*v) - y)) 146 | H[d:, d:] = np.diag(u) @ (X.T@X) @ np.diag(u) 147 | return H / n 148 | 149 | 150 | def dln_hessian_eigs(u, v, X, y, normalized=False): 151 | hess = dln_hessian(u, v, X, y, normalized) 152 | eigs, _ = np.linalg.eig(hess) 153 | return eigs 154 | 155 | 156 | def dln_grad_loss(u, v, X, y): 157 | Xerr = X.T @ (X @ (u * v) - y) 158 | grad_u, grad_v = (Xerr * v) / X.shape[0], (Xerr * u) / X.shape[0] 159 | return np.concatenate([grad_u, grad_v]) 160 | 161 | 162 | def dln_avg_individual_grad_loss(u, v, X, y, residuals=True): 163 | n = X.shape[0] 164 | r = X @ (u * v) - y 165 | 166 | if not residuals: 167 | r = np.ones_like(r) 168 | 169 | total = 0 170 | for i in range(n): 171 | total += np.sum((r[i] * X[i] * v)**2) / n 172 | total += np.sum((r[i] * X[i] * u)**2) / n 173 | 174 | return total 175 | 176 | 177 | def compute_grad_matrix_ranks(us, vs, X, l0_threshold_grad_matrix=0.0001): 178 | grad_matrix_ranks = [] 179 | for u, v in zip(us, vs): 180 | grad_matrix = np.hstack([X * u, X * v]) 181 | svals = np.linalg.svd(grad_matrix)[1] 182 | rank = (svals / svals[0] > 
l0_threshold_grad_matrix).sum() 183 | grad_matrix_ranks.append(rank) 184 | return grad_matrix_ranks 185 | 186 | -------------------------------------------------------------------------------- /fc_nets.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | 8 | class FCNet2Layers(torch.nn.Module): 9 | def __init__(self, n_feature, n_hidden, n_output=1, biases=[False, False]): 10 | super(FCNet2Layers, self).__init__() 11 | self.layer1 = torch.nn.Linear(n_feature, n_hidden, bias=biases[0]) 12 | self.layer2 = torch.nn.Linear(n_hidden, n_output, bias=biases[1]) 13 | 14 | def init_gaussian(self, init_scales): 15 | self.layer1.weight.data = init_scales[0] * torch.randn_like(self.layer1.weight) 16 | self.layer2.weight.data = init_scales[1] * torch.randn_like(self.layer2.weight) 17 | 18 | def init_gaussian_clsf(self, init_scales): 19 | self.layer1.weight.data = init_scales[0] * torch.randn_like(self.layer1.weight) 20 | self.layer2.weight.data = ((torch.randn_like(self.layer2.weight) > 0).float() - 0.5) * 2 * (self.layer1.weight.data**2).sum(1)**0.5 21 | 22 | def init_blanc_et_al(self, init_scales): 23 | self.layer1.weight.data = 2.5 * (-1 + 2 * torch.round(torch.rand_like(self.layer1.weight))) * init_scales[0] 24 | self.layer1.bias.data = 2.5 * init_scales[0] * torch.randn_like(self.layer1.bias) 25 | self.layer2.weight.data = 4.0 * init_scales[1] * torch.randn_like(self.layer2.weight) 26 | 27 | def features(self, x, normalize=True, scaled=False): 28 | x = F.relu(self.layer1(x)) 29 | if scaled: 30 | x = x * self.layer2.weight 31 | if normalize: 32 | x /= (x**2).sum(1, keepdim=True)**0.5 33 | x[torch.isnan(x)] = 0.0 34 | return x.data.numpy() 35 | 36 | def feature_sparsity(self, X, corr_threshold=0.99): 37 | phi = self.features(X) 38 | idx_keep = np.where((phi > 0.0).sum(0) > 0)[0] 39 | phi_filtered = phi[:, 
idx_keep] # filter out zeros 40 | corr_matrix = np.corrcoef(phi_filtered.T) 41 | corr_matrix -= np.eye(corr_matrix.shape[0]) 42 | 43 | idx_to_delete, i, j = [], 0, 0 44 | while i != corr_matrix.shape[0]: 45 | # print(i, corr_matrix.shape, (np.abs(corr_matrix[i]) > corr_threshold).sum()) 46 | if (np.abs(corr_matrix[i]) > corr_threshold).sum() > 0: 47 | corr_matrix = np.delete(corr_matrix, (i), axis=0) 48 | corr_matrix = np.delete(corr_matrix, (i), axis=1) 49 | # print('delete', j) 50 | idx_to_delete.append(j) 51 | else: 52 | i += 1 53 | j += 1 54 | assert corr_matrix.shape[0] == corr_matrix.shape[1] 55 | # print(idx_to_delete, idx_keep) 56 | idx_keep = np.delete(idx_keep, [idx_to_delete]) 57 | sparsity = (phi[:, idx_keep] != 0).sum() / (phi.shape[0] * phi.shape[1]) 58 | 59 | return sparsity 60 | 61 | def forward(self, x): 62 | z = F.relu(self.layer1(x)) 63 | z = self.layer2(z) 64 | return z 65 | 66 | 67 | class FCNet(torch.nn.Module): 68 | def __init__(self, n_feature, n_hidden, biases=True): 69 | super(FCNet, self).__init__() 70 | self.n_hidden = [n_feature] + n_hidden + [1] # add the number of input and output units 71 | self.biases = biases 72 | self.layers = torch.nn.ModuleList() 73 | for i in range(len(self.n_hidden) - 1): 74 | self.layers.append(torch.nn.Linear(self.n_hidden[i], self.n_hidden[i+1], bias=self.biases)) 75 | 76 | def init_gaussian(self, init_scales): 77 | for i in range(len(self.n_hidden) - 1): 78 | self.layers[i].weight.data = init_scales[i] * torch.randn_like(self.layers[i].weight) 79 | 80 | def forward(self, x): 81 | for i in range(len(self.n_hidden) - 2): 82 | x = F.relu(self.layers[i](x)) 83 | x = self.layers[-1](x) 84 | return x 85 | 86 | def features(self, x, normalize=True, scaled=False, n_hidden_to_take=-1): 87 | for i in range(n_hidden_to_take if n_hidden_to_take > 0 else len(self.n_hidden) - 2): 88 | x = F.relu(self.layers[i](x)) 89 | if scaled and n_hidden_to_take in [-1, len(self.n_hidden)]: 90 | x = x * self.layers[-1].weight 91 | 
if normalize: 92 | x /= (x**2).sum(1, keepdim=True)**0.5 93 | x[torch.isnan(x)] = 0.0 94 | return x.data.numpy() 95 | 96 | def feature_sparsity(self, X, n_hidden_to_take=-1, corr_threshold=0.99): 97 | phi = self.features(X, n_hidden_to_take=n_hidden_to_take) 98 | idx_keep = np.where((phi > 0.0).sum(0) > 0)[0] 99 | phi_filtered = phi[:, idx_keep] # filter out zeros 100 | corr_matrix = np.corrcoef(phi_filtered.T) 101 | corr_matrix -= np.eye(corr_matrix.shape[0]) 102 | 103 | idx_to_delete, i, j = [], 0, 0 104 | while i != corr_matrix.shape[0]: 105 | # print(i, corr_matrix.shape, (np.abs(corr_matrix[i]) > corr_threshold).sum()) 106 | if (np.abs(corr_matrix[i]) > corr_threshold).sum() > 0: 107 | corr_matrix = np.delete(corr_matrix, (i), axis=0) 108 | corr_matrix = np.delete(corr_matrix, (i), axis=1) 109 | # print('delete', j) 110 | idx_to_delete.append(j) 111 | else: 112 | i += 1 113 | j += 1 114 | assert corr_matrix.shape[0] == corr_matrix.shape[1] 115 | # print(idx_to_delete, idx_keep) 116 | idx_keep = np.delete(idx_keep, [idx_to_delete]) 117 | sparsity = (phi[:, idx_keep] != 0).sum() / (phi.shape[0] * phi.shape[1]) 118 | 119 | return sparsity 120 | 121 | def n_highly_corr(self, X, n_hidden_to_take=-1, corr_threshold=0.99): 122 | phi = self.features(X, n_hidden_to_take=n_hidden_to_take) 123 | idx_keep = np.where((phi > 0.0).sum(0) > 0)[0] 124 | phi_filtered = phi[:, idx_keep] # filter out zeros 125 | corr_matrix = np.corrcoef(phi_filtered.T) 126 | corr_matrix -= np.eye(corr_matrix.shape[0]) 127 | 128 | idx_to_delete, i, j = [], 0, 0 129 | while i != corr_matrix.shape[0]: 130 | # print(i, corr_matrix.shape, (np.abs(corr_matrix[i]) > corr_threshold).sum()) 131 | if (np.abs(corr_matrix[i]) > corr_threshold).sum() > 0: 132 | corr_matrix = np.delete(corr_matrix, (i), axis=0) 133 | corr_matrix = np.delete(corr_matrix, (i), axis=1) 134 | # print('delete', j) 135 | idx_to_delete.append(j) 136 | else: 137 | i += 1 138 | j += 1 139 | assert corr_matrix.shape[0] == 
corr_matrix.shape[1] 140 | # print(idx_to_delete, idx_keep) 141 | idx_keep = np.delete(idx_keep, [idx_to_delete]) 142 | sparsity = (phi[:, idx_keep] != 0).sum() / (phi.shape[0] * phi.shape[1]) 143 | 144 | return phi.shape[1] - len(idx_keep) 145 | 146 | 147 | def moving_average(net, net_avg, weight_avg): 148 | for param, param_avg in zip(net.parameters(), net_avg.parameters()): 149 | param_avg.data = weight_avg*param_avg.data + (1-weight_avg)*param.data 150 | 151 | 152 | def train_fc_net(X, y, X_test, y_test, gamma, batch_size, net, iters_loss, num_iter, thresholds=[-1], decays=[-1], iters_percentage_linear_warmup=0.0, gamma_warmup_factor_max=1.0, warmup_exponent=1.0, weight_avg=0.0, clsf=False, gauss_ln_scale=0.0): 153 | assert iters_percentage_linear_warmup <= decays[0], 'we should decay the step size only after warmup' 154 | train_losses, test_losses, nets_avg = [], [], [] 155 | net, net_avg = copy.deepcopy(net), copy.deepcopy(net) 156 | 157 | loss_f = (lambda y_pred, y: torch.mean(torch.log(1 + torch.exp(-y_pred * y)))) if clsf else (lambda y_pred, y: torch.mean((y_pred - y)**2)) 158 | # loss_f = lambda y_pred, y: torch.mean(torch.log(1 + torch.exp(-y_pred * y))) 159 | 160 | optimizer = torch.optim.SGD(net.parameters(), lr=gamma) #, momentum=0.9) 161 | for i in range(num_iter): 162 | if i in iters_loss: 163 | train_losses += [loss_f(net_avg(X), y)] 164 | test_losses += [loss_f(net_avg(X_test), y_test)] 165 | nets_avg.append(copy.deepcopy(net_avg)) 166 | if torch.isnan(train_losses[-1]): 167 | return train_losses, test_losses, nets_avg 168 | 169 | if i <= int(iters_percentage_linear_warmup * num_iter) and int(iters_percentage_linear_warmup * num_iter) > 0: 170 | optimizer.param_groups[0]['lr'] = gamma + (gamma_warmup_factor_max - 1) * gamma * (i / int(num_iter))**warmup_exponent 171 | # optimizer.param_groups[0]['lr'] = gamma + gamma * gamma_warmup_factor_max * (i / int(iters_percentage_linear_warmup * num_iter))**warmup_exponent 172 | 173 | if i in thresholds: 
174 | for threshold, decay in zip(thresholds, decays): 175 | if i == threshold: 176 | optimizer.param_groups[0]['lr'] /= decay 177 | 178 | indices = np.random.choice(X.shape[0], size=batch_size, replace=False) 179 | batch_x, batch_y = X[indices], y[indices] 180 | if gauss_ln_scale > 0.0: # label noise with schedule (note: supports only one threshold at the moment) 181 | batch_y += torch.randn_like(batch_y) * gauss_ln_scale / (decays[0] if i > thresholds[0] else 1.0) 182 | loss = loss_f(net(batch_x), batch_y) 183 | 184 | optimizer.zero_grad() 185 | loss.backward() 186 | optimizer.step() 187 | 188 | moving_average(net, net_avg, weight_avg) 189 | 190 | return train_losses, test_losses, nets_avg 191 | 192 | 193 | def compute_grad_matrix(net, X): 194 | optimizer = torch.optim.SGD(net.parameters(), lr=0.0) 195 | grad_matrix_list = [] 196 | for i in range(X.shape[0]): 197 | h = net(X[[i]]) 198 | optimizer.zero_grad() 199 | h.backward() 200 | 201 | grad_total_list = [] 202 | for param in net.parameters(): 203 | grad_total_list.append(param.grad.flatten().data.numpy()) 204 | grad_total = np.concatenate(grad_total_list) 205 | grad_matrix_list.append(grad_total) 206 | 207 | grad_matrix = np.vstack(grad_matrix_list) 208 | return grad_matrix 209 | 210 | 211 | def compute_grad_matrix_ranks(nets, X, l0_threshold_grad_matrix=0.0001): 212 | n_params = sum([np.prod(param.shape) for param in nets[-1].parameters()]) 213 | X_eval = X[:n_params] 214 | grad_matrix_ranks = [] 215 | for net in nets: 216 | svals = np.linalg.svd(compute_grad_matrix(net, X_eval))[1] 217 | rank = (svals / svals[0] > l0_threshold_grad_matrix).sum() 218 | grad_matrix_ranks.append(rank) 219 | return grad_matrix_ranks 220 | 221 | -------------------------------------------------------------------------------- /images/fig1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tml-epfl/sgd-sparse-features/656cb0d9e4e1cdd688073841f145fe9f94703d7c/images/fig1.png -------------------------------------------------------------------------------- /images/twitter.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tml-epfl/sgd-sparse-features/656cb0d9e4e1cdd688073841f145fe9f94703d7c/images/twitter.gif -------------------------------------------------------------------------------- /images/twitter.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tml-epfl/sgd-sparse-features/656cb0d9e4e1cdd688073841f145fe9f94703d7c/images/twitter.mp4 -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from fc_nets import FCNet2Layers, FCNet, compute_grad_matrix 4 | 5 | 6 | def get_iters_eval(n_iter_power, x_log_scale, n_iters_first=101, n_iters_next=151): 7 | num_iter = int(10**n_iter_power) + 1 8 | 9 | iters_loss_first = np.array(range(100)) 10 | if x_log_scale: 11 | iters_loss_next = np.unique(np.round(np.logspace(0, n_iter_power, n_iters_first))) 12 | else: 13 | iters_loss_next = np.unique(np.round(np.linspace(0, num_iter, n_iters_next)))[:-1] 14 | iters_loss = np.unique(np.concatenate((iters_loss_first, iters_loss_next))) 15 | 16 | return num_iter, iters_loss 17 | 18 | 19 | def get_data_two_layer_relu_net(n, d, m_teacher, init_scales_teacher, seed): 20 | np.random.seed(seed + 1) 21 | torch.manual_seed(seed + 1) 22 | 23 | n_test = 1000 24 | H = np.eye(d) 25 | X = torch.tensor(np.random.multivariate_normal(np.zeros(d), H, n)).float() 26 | X = X / torch.sum(X**2, 1, keepdim=True)**0.5 27 | X_test = torch.tensor(np.random.multivariate_normal(np.zeros(d), H, n_test)).float() 28 | X_test = X_test / torch.sum(X_test**2, 1, 
keepdim=True)**0.5 29 | 30 | # generate ground truth labels 31 | with torch.no_grad(): 32 | net_teacher = FCNet2Layers(n_feature=d, n_hidden=m_teacher) 33 | net_teacher.init_gaussian(init_scales_teacher) 34 | net_teacher.layer1.weight.data = net_teacher.layer1.weight.data / torch.sum((net_teacher.layer1.weight.data)**2, 1, keepdim=True)**0.5 35 | net_teacher.layer2.weight.data = torch.sign(net_teacher.layer2.weight.data) 36 | 37 | y, y_test = net_teacher(X), net_teacher(X_test) 38 | 39 | print('y', y[:20, 0]) 40 | 41 | return X, y, X_test, y_test, net_teacher 42 | 43 | 44 | def get_data_multi_layer_relu_net(n, d, m_teacher, init_scales_teacher, seed): 45 | np.random.seed(seed + 1) 46 | torch.manual_seed(seed + 1) 47 | 48 | n_test = 1000 49 | H = np.eye(d) 50 | X = torch.tensor(np.random.multivariate_normal(np.zeros(d), H, n)).float() 51 | X = X / torch.sum(X**2, 1, keepdim=True)**0.5 52 | X_test = torch.tensor(np.random.multivariate_normal(np.zeros(d), H, n_test)).float() 53 | X_test = X_test / torch.sum(X_test**2, 1, keepdim=True)**0.5 54 | 55 | # generate ground truth labels 56 | with torch.no_grad(): 57 | net_teacher = FCNet(n_feature=d, n_hidden=m_teacher) 58 | net_teacher.init_gaussian(init_scales_teacher) 59 | y, y_test = net_teacher(X), net_teacher(X_test) 60 | print('y:', y[:, 0]) 61 | 62 | return X, y, X_test, y_test, net_teacher 63 | 64 | 65 | def effective_rank(v): 66 | v = v[v != 0] 67 | v /= v.sum() 68 | return -(v * np.log(v)).sum() 69 | 70 | 71 | def rm_too_correlated(net, X, V, corr_threshold=0.99): 72 | V = V.T 73 | idx_keep = np.where((V > 0.0).sum(0) > 0)[0] 74 | V_filtered = V[:, idx_keep] # filter out zeros 75 | corr_matrix = np.corrcoef(V_filtered.T) 76 | corr_matrix -= np.eye(corr_matrix.shape[0]) 77 | 78 | idx_to_delete, i, j = [], 0, 0 79 | while i != corr_matrix.shape[0]: 80 | if (np.abs(corr_matrix[i]) > corr_threshold).sum() > 0: 81 | corr_matrix = np.delete(corr_matrix, (i), axis=0) 82 | corr_matrix = np.delete(corr_matrix, (i), axis=1) 
83 | # print('delete', j) 84 | idx_to_delete.append(j) 85 | else: 86 | i += 1 87 | j += 1 88 | assert corr_matrix.shape[0] == corr_matrix.shape[1] 89 | idx_keep = np.delete(idx_keep, [idx_to_delete]) 90 | 91 | return V[:, idx_keep].T 92 | 93 | def compute_grad_matrix_dim(net, X, corr_threshold=0.99): 94 | grad_matrix = compute_grad_matrix(net, X) 95 | grad_matrix_sq_norms = np.sum(grad_matrix**2, 0) 96 | m = 100 97 | v_j = [] 98 | for j in range(m): 99 | v_j.append(grad_matrix_sq_norms[[j, m+j, 2*m+j]]) # matrix: w1, w2, w3, w4 100 | V = np.vstack(v_j) 101 | 102 | V_reduced = rm_too_correlated(net, X, V, corr_threshold=corr_threshold) 103 | grad_matrix_dim = V_reduced.shape[0] 104 | return grad_matrix_dim 105 | 106 | 107 | def project(u, v, u0, v0, u1, v1, u2, v2): 108 | u, v = u - u0, v - v0 109 | alpha = (u @ u1 + v @ v1) / (np.sum(u1**2) + np.sum(v1**2)) 110 | beta = (u @ u2 + v @ v2) / (np.sum(u2**2) + np.sum(v2**2)) 111 | return alpha, beta 112 | 113 | --------------------------------------------------------------------------------
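The diagonal-network helpers in `diag_nets.py` all rely on the same closed-form gradient of the loss `||X (u * v) - y||^2 / (2n)`. The following is a minimal, self-contained sketch (not part of the repository) that checks this gradient against central finite differences and then runs plain gradient descent with a single step-size decay, in the spirit of the `thresholds`/`decays` arguments of `GD`. The helper `numerical_grad`, the seed, the problem size, the 1-sparse ground truth and the step-size values are all illustrative assumptions.

```python
import numpy as np


def loss(X, y_hat, y):
    # Halved mean squared error, same form as `loss` in diag_nets.py.
    return np.linalg.norm(y_hat - y) ** 2 / (2 * X.shape[0])


def grad(u, v, X, y):
    # Gradient w.r.t. (u, v) for the predictor X @ (u * v),
    # same formula as `dln_grad_loss` in diag_nets.py.
    Xerr = X.T @ (X @ (u * v) - y)
    return (Xerr * v) / X.shape[0], (Xerr * u) / X.shape[0]


def numerical_grad(f, w, eps=1e-6):
    # Hypothetical helper: central finite differences, one coordinate at a time.
    g = np.zeros_like(w)
    for k in range(w.size):
        e = np.zeros_like(w)
        e[k] = eps
        g[k] = (f(w + e) - f(w - e)) / (2 * eps)
    return g


rng = np.random.default_rng(0)   # assumed seed
n, d = 20, 5                     # assumed problem size
X = rng.standard_normal((n, d))
beta_star = np.zeros(d)
beta_star[0] = 1.0               # assumed 1-sparse ground truth
y = X @ beta_star

# 1) Compare the analytic gradient with finite differences at a random point.
u, v = rng.standard_normal(d), rng.standard_normal(d)
grad_u, grad_v = grad(u, v, X, y)
num_u = numerical_grad(lambda w: loss(X, X @ (w * v), y), u)
num_v = numerical_grad(lambda w: loss(X, X @ (u * w), y), v)
max_grad_err = max(np.abs(grad_u - num_u).max(), np.abs(grad_v - num_v).max())

# 2) Plain gradient descent with one step-size decay (assumed schedule).
gamma, threshold, decay = 0.05, 500, 2.0
u, v = np.full(d, 0.1), np.full(d, 0.1)
initial_loss = loss(X, X @ (u * v), y)
for i in range(2000):
    grad_u, grad_v = grad(u, v, X, y)
    u, v = u - gamma * grad_u, v - gamma * grad_v
    if i == threshold:
        gamma = gamma / decay
final_loss = loss(X, X @ (u * v), y)

print('max gradient error:', max_grad_err)
print('loss before / after GD:', initial_loss, final_loss)
```

The small step size here is chosen only so that the run is stable. The large-step-size regime studied in the paper would need the repository's own `GD`/`SGD` functions and notebook settings.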