├── .gitignore
├── README.md
├── deep_nets
│   ├── Dockerfile
│   ├── data.py
│   ├── models.py
│   ├── plot_feature_sparsity.ipynb
│   ├── train.py
│   ├── utils.py
│   ├── utils_eval.py
│   └── utils_train.py
├── diag_nets.ipynb
├── diag_nets.py
├── diag_nets_2d_loss_surface.ipynb
├── fc_nets.py
├── fc_nets_1d_regression.ipynb
├── fc_nets_multi_layer.ipynb
├── fc_nets_two_layer.ipynb
├── images
│   ├── fig1.png
│   ├── twitter.gif
│   └── twitter.mp4
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | plots
2 | .DS_Store
3 | __pycache__
4 | .mat
5 |
6 | deep_nets/exps/
7 | deep_nets/logs/
8 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SGD with large step sizes learns sparse features
2 |
3 | **Maksym Andriushchenko, Aditya Varre, Loucas Pillaud-Vivien, Nicolas Flammarion (EPFL)**
4 |
5 | **ICML 2023**
6 |
7 | **Paper:** [https://arxiv.org/abs/2210.05337](https://arxiv.org/abs/2210.05337)
8 |
9 |
10 |

11 |
12 | ## Abstract
13 | We showcase important features of the dynamics of Stochastic Gradient Descent (SGD) in the training of neural networks. We present empirical observations that commonly used large step sizes (i) lead the iterates to jump from one side of a valley to the other, causing *loss stabilization*, and (ii) this stabilization induces a hidden stochastic dynamics orthogonal to the bouncing directions that *biases it implicitly* toward simple predictors. Furthermore, we show empirically that the longer large step sizes keep SGD high in the loss landscape valleys, the better the implicit regularization can operate and find sparse representations. Notably, no explicit regularization is used, so the regularization effect comes solely from the SGD training dynamics influenced by the step size schedule. Therefore, these observations unveil how, through the step size schedules, both gradient and noise together drive the SGD dynamics through the loss landscape of neural networks. We justify these findings theoretically through the study of simple neural network models as well as qualitative arguments inspired by stochastic processes. Finally, this analysis allows us to shed new light on some common practices and observed phenomena when training neural networks.
14 |
15 | 
16 |
17 |
18 |
19 |
20 |
21 | ## Code
22 | The exact code to reproduce all the reported experiments on simple networks is available in Jupyter notebooks:
23 | - `diag_nets.ipynb`: diagonal linear networks (also see `diag_nets_2d_loss_surface.ipynb` for loss surface visualizations).
24 | - `fc_nets_1d_regression.ipynb`: two-layer ReLU networks on a 1D regression problem.
25 | - `fc_nets_two_layer.ipynb`: two-layer ReLU networks in a teacher-student setup (+ neuron movement visualization).
26 | - `fc_nets_multi_layer.ipynb`: three-layer ReLU networks in a teacher-student setup.
27 |
28 | For deep networks, see the `deep_nets` folder, whose dependencies are collected in its `Dockerfile`. Typical training commands for a ResNet-18 on CIFAR-10 would look like this:
29 | - Plain SGD without explicit regularization (loss stabilization is achieved via exponential warmup):
30 |   - with large step sizes: `python train.py --dataset=cifar10 --lr_init=0.75 --lr_schedule=piecewise_05epochs --warmup_exp=1.05 --model=resnet18_plain --model_width=64 --epochs=100 --batch_size=256 --momentum=0.0 --l2_reg=0.0 --no_data_augm --eval_iter_freq=200 --exp_name=no_explicit_reg`
31 |   - with small step sizes: `python train.py --dataset=cifar10 --lr_init=0.01 --lr_schedule=constant --model=resnet18_plain --model_width=64 --epochs=100 --batch_size=256 --momentum=0.0 --l2_reg=0.0 --no_data_augm --eval_iter_freq=200 --exp_name=no_explicit_reg`
32 | - SGD + momentum in the state-of-the-art setting with data augmentation and weight decay:
33 |   - with large step sizes: `python train.py --dataset=cifar10 --lr_init=0.05 --lr_schedule=piecewise_05epochs --model=resnet18_plain --model_width=64 --epochs=100 --batch_size=256 --momentum=0.9 --l2_reg=0.0005 --eval_iter_freq=200 --exp_name=sota_setting`
34 |   - with small step sizes: `python train.py --dataset=cifar10 --lr_init=0.002 --lr_schedule=constant --model=resnet18_plain --model_width=64 --epochs=100 --batch_size=256 --momentum=0.9 --l2_reg=0.0005 --eval_iter_freq=200 --exp_name=sota_setting`
35 |
36 | The runs with CIFAR-100 are analogous: just pass `--dataset=cifar100`. The step size schedule can be selected from [`constant`, `piecewise_01epochs`, `piecewise_03epochs`, `piecewise_05epochs`]; see `utils_train.py` for more details.
37 |
38 |
39 | ## Contact
40 | Feel free to reach out if you have any questions regarding the code!
41 |
42 |
--------------------------------------------------------------------------------
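The README names step size schedules such as `piecewise_05epochs` but defers their definition to `utils_train.py`, which is not reproduced in this listing. As a rough illustration only, the sketch below assumes that such a schedule keeps the initial step size for a fixed fraction of training and then drops it once. The function name `sketch_piecewise_lr`, the parameters `frac_large` and `decay`, and the single tenfold drop are all hypothetical, and the exponential warmup controlled by `--warmup_exp` is deliberately left out because its exact form is not shown here.

```python
def sketch_piecewise_lr(epoch, n_epochs, lr_init, frac_large=0.5, decay=0.1):
    """Hypothetical piecewise-constant schedule (NOT the code from utils_train.py).

    Assumption: the large step size lr_init is kept for the first
    `frac_large` fraction of training and is multiplied by `decay` afterwards.
    `epoch` may be fractional, e.g. 12.5 for the middle of epoch 12.
    """
    if not 0 <= epoch <= n_epochs:
        raise ValueError('epoch must lie in [0, n_epochs]')
    if epoch < frac_large * n_epochs:
        return lr_init  # large-step phase
    return lr_init * decay  # small-step phase after the single drop


if __name__ == '__main__':
    # Print the step size at a few points of a 100-epoch run.
    for epoch in [0, 25, 49.9, 50, 99]:
        print(epoch, sketch_piecewise_lr(epoch, n_epochs=100, lr_init=0.75))
```

Under these assumptions, a longer large-step phase corresponds to a larger `frac_large`, and a constant schedule corresponds to `frac_large=1.0` for every epoch before the last.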
/deep_nets/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvidia/cuda:11.0.3-cudnn8-devel-ubuntu20.04
2 | LABEL maintainer "Maksym Andriushchenko "
3 |
4 | ARG DEBIAN_FRONTEND=noninteractive # needed to prevent some questions during the Docker building phase
5 |
6 | # install some necessary tools
7 | RUN apt-get update
8 | RUN apt-get install -y \
9 | cmake \
10 | curl \
11 | htop \
12 | locales \
13 | python3 \
14 | python3-pip \
15 | sudo \
16 | unzip \
17 | vim \
18 | git \
19 | wget \
20 | zsh \
21 | libssl-dev \
22 | libffi-dev \
23 | libmagickwand-dev \
24 | ffmpeg \
25 | libsm6 \
26 | libxext6 \
27 | openssh-server
28 | RUN rm -rf /var/lib/apt/lists/*
29 | RUN mkdir /var/run/sshd
30 |
31 |
32 | # configure environments
33 | RUN locale-gen en_US.UTF-8
34 | ENV LANG en_US.UTF-8
35 | ENV LANGUAGE en_US:en
36 | ENV LC_ALL en_US.UTF-8
37 |
38 |
39 | RUN pip3 install --upgrade pip # needed for opencv
40 | RUN pip3 install -U setuptools # may be needed for opencv
41 | RUN pip3 install torch==1.10.0+cu113 torchvision==0.11.1+cu113 torchaudio==0.10.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html # needed for A100 pods
42 |
43 |
44 | # python packages
45 | RUN pip3 install --upgrade \
46 | scipy \
47 | numpy \
48 | jupyter notebook \
49 | ipdb \
50 | pyyaml \
51 | easydict \
52 | requests \
53 | matplotlib \
54 | seaborn
55 | RUN export LC_ALL=en_US.UTF-8
56 |
57 | # Configure user and group
58 | ENV SHELL=/bin/bash
59 |
60 | ENV HOME=/home/$NB_USER
61 |
62 | RUN ln -s /usr/bin/python3 /usr/bin/python
63 |
64 | # expose the port to ssh
65 | EXPOSE 22
66 | CMD ["/usr/sbin/sshd", "-D"]
67 |
--------------------------------------------------------------------------------
/deep_nets/data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.utils.data as td
4 | import numpy as np
5 | from torchvision import datasets, transforms
6 |
7 |
8 | class DatasetWithLabelNoise(torch.utils.data.Dataset):
9 | def __init__(self, data, split, transform):
10 | self.data = data
11 | self.split = split
12 | self.transform = transform
13 |
14 | def __getitem__(self, index):
15 | x = self.data.data[index]
16 | x1 = self.transform(x) if self.transform is not None else x
17 | if self.split == 'train':
18 | x2 = self.transform(x) if self.transform is not None else x
19 | else: # to save a bit of computations
20 | x2 = x1
21 | y = self.data.targets[index]
22 | y_correct = self.data.targets_correct[index]
23 | label_noise = self.data.label_noise[index]
24 | return x1, x2, y, y_correct, label_noise
25 |
26 | def __len__(self):
27 | return len(self.data.targets)
28 |
29 |
30 | def uniform_noise(*args, **kwargs):
31 | shape = [1000, 1, 28, 28]
32 | x = torch.from_numpy(np.random.rand(*shape)).float()
33 | # y_train = np.random.randint(0, 10, size=shape_train[0])
34 | y = np.floor(10 * x[:, 0, 0, 0].numpy()) # take the first feature
35 | y = torch.from_numpy(y).long()
36 | data = td.TensorDataset(x, y)
37 | return data
38 |
39 |
40 | def dataset_gaussians_binary(*args, **kwargs):
41 | shape = shapes_dict['gaussians_binary']
42 | n, d = shape[0], shape[3]
43 | std = 0.1
44 |
45 | v = v_global.copy()
46 | v /= (v**2).sum()**0.5 # make it unit norm
47 | mu_zero, mu_one = v, -v
48 |
49 | x = np.concatenate([mu_zero + std*np.random.randn(n // 2, d), mu_one + std*np.random.randn(n // 2, d)])
50 | y = np.concatenate([np.zeros(n // 2), np.ones(n // 2)])
51 | indices = np.random.permutation(np.arange(n))
52 | x, y = x[indices], y[indices]
53 | x = x[:, None, None, :] # make it image-like
54 |
55 | data = td.TensorDataset()
56 | data.data, data.targets = torch.from_numpy(x).float(), torch.from_numpy(y).long()
57 | return data
58 |
59 |
60 | def asym_label_noise(dataset, label):
61 | if dataset == 'cifar10':
62 | if label == 9:  # truck -> automobile
63 | return 1
64 | # bird -> airplane
65 | elif label == 2:
66 | return 0
67 | # cat -> dog
68 | elif label == 3:
69 | return 5
70 | # dog -> cat
71 | elif label == 5:
72 | return 3
73 | # deer -> horse
74 | elif label == 4:
75 | return 7
76 | else:
77 | return label
78 | elif dataset == 'cifar100':
79 | return (label + 1) % 100
80 | elif dataset == 'svhn':
81 | return (label + 1) % 10
82 | else:
83 | raise ValueError('This dataset does not yet support asymmetric label noise.')
84 |
85 |
86 | def get_loaders(dataset, n_ex, batch_size, split, shuffle, data_augm, val_indices=None, p_label_noise=0.0,
87 | noise_type='sym', drop_last=False):
88 | dir_ = '/tmlscratch/andriush/data'
89 | # dir_ = '/tmldata1/andriush/data'
90 | dataset_f = datasets_dict[dataset]
91 | batch_size = n_ex if n_ex < batch_size and n_ex != -1 else batch_size
92 | num_workers_train, num_workers_val, num_workers_test = 4, 4, 4
93 |
94 | data_augm_transforms = [transforms.RandomCrop(32, padding=4)]
95 | if dataset not in ['mnist', 'svhn']:
96 | data_augm_transforms.append(transforms.RandomHorizontalFlip())
97 | base_transforms = [transforms.ToPILImage()] if dataset != 'gaussians_binary' else []
98 | transform_list = base_transforms + data_augm_transforms if data_augm else base_transforms
99 | transform = transforms.Compose(transform_list + [transforms.ToTensor()])
100 |
101 | if dataset == 'cifar10_horse_car':
102 | cl1, cl2 = 7, 1 # 7=horse, 1=car
103 | elif dataset == 'cifar10_dog_cat':
104 | cl1, cl2 = 5, 3 # 5=dog, 3=cat
105 | if split in ['train', 'val']:
106 | if dataset != 'svhn':
107 | data = dataset_f(dir_, train=True, transform=transform, download=True)
108 | else:
109 | data = dataset_f(dir_, split='train', transform=transform, download=True)
110 | data.data = data.data.transpose([0, 2, 3, 1])
111 | data.targets = data.labels
112 | data.targets = np.array(data.targets)
113 | n_cls = max(data.targets) + 1
114 |
115 | if dataset in ['cifar10_horse_car', 'cifar10_dog_cat']:
116 | data.targets = np.array(data.targets)
117 | idx = (data.targets == cl1) + (data.targets == cl2)
118 | data.data, data.targets = data.data[idx], data.targets[idx]
119 | data.targets[data.targets == cl1], data.targets[data.targets == cl2] = 0, 1
120 | n_cls = 2
121 | n_ex = len(data.targets) if n_ex == -1 else n_ex
122 | if '_gs' in dataset:
123 | data.data = data.data.mean(3).astype(np.uint8)
124 |
125 | if val_indices is not None:
126 | assert len(val_indices) < len(data.targets), '#val has to be < total #train pts'
127 | val_indices_mask = np.zeros(len(data.targets), dtype=bool)
128 | val_indices_mask[val_indices] = True
129 | if split == 'train':
130 | data.data, data.targets = data.data[~val_indices_mask], data.targets[~val_indices_mask]
131 | else:
132 | data.data, data.targets = data.data[val_indices_mask], data.targets[val_indices_mask]
133 | data.data, data.targets = data.data[:n_ex], data.targets[:n_ex] # so the #pts can be in [n_ex-n_eval, n_ex]
134 | # e.g., when frac_train=1.0, for training set, n_ex=50k while data.data.shape[0]=45k bc of val set
135 | if n_ex > data.data.shape[0]:
136 | n_ex = data.data.shape[0]
137 |
138 | data.label_noise = np.zeros(n_ex, dtype=bool)
139 | data.targets_correct = data.targets.copy()
140 | if p_label_noise > 0.0:
141 | print('Split: {}, number of examples: {}, noisy examples: {}'.format(split, n_ex, int(n_ex*p_label_noise)))
142 | print('Dataset shape: x is {}, y is {}'.format(data.data.shape, data.targets.shape))
143 | assert n_ex == data.data.shape[0] # there was a mistake previously here leading to a larger noise level
144 |
145 | # gen random indices
146 | indices = np.random.permutation(np.arange(len(data.targets)))[:int(n_ex*p_label_noise)]
147 | for index in indices:
148 | if noise_type == 'sym':
149 | lst_classes = list(range(n_cls))
150 | cls_int = data.targets[index] if type(data.targets[index]) is int else data.targets[index].item()
151 | lst_classes.remove(cls_int)
152 | data.targets[index] = np.random.choice(lst_classes)
153 | else:
154 | data.targets[index] = asym_label_noise(dataset, data.targets[index])
155 | data.label_noise[indices] = True
156 | print(data.data.shape)
157 | data = DatasetWithLabelNoise(data, split, transform if dataset != 'gaussians_binary' else None)
158 | loader = torch.utils.data.DataLoader(
159 | dataset=data, batch_size=batch_size, shuffle=shuffle, pin_memory=True,
160 | num_workers=num_workers_train if split == 'train' else num_workers_val, drop_last=drop_last)
161 |
162 | elif split == 'test':
163 | if dataset != 'svhn':
164 | data = dataset_f(dir_, train=False, transform=transform, download=True)
165 | else:
166 | data = dataset_f(dir_, split='test', transform=transform, download=True)
167 | data.data = data.data.transpose([0, 2, 3, 1])
168 | data.targets = data.labels
169 | n_ex = len(data) if n_ex == -1 else n_ex
170 |
171 | if dataset in ['cifar10_horse_car', 'cifar10_dog_cat']:
172 | data.targets = np.array(data.targets)
173 | idx = (data.targets == cl1) + (data.targets == cl2)
174 | data.data, data.targets = data.data[idx], data.targets[idx]
175 | data.targets[data.targets == cl1], data.targets[data.targets == cl2] = 0, 1
176 | data.targets = list(data.targets) # to reduce memory consumption
177 | if '_gs' in dataset:
178 | data.data = data.data.mean(3).astype(np.uint8)
179 | data.data, data.targets = data.data[:n_ex], data.targets[:n_ex]
180 | data.targets_correct = data.targets.copy()
181 |
182 | data.label_noise = np.zeros(n_ex)
183 | data = DatasetWithLabelNoise(data, split, transform if dataset != 'gaussians_binary' else None)
184 | loader = torch.utils.data.DataLoader(dataset=data, batch_size=batch_size, shuffle=shuffle, pin_memory=True,
185 | num_workers=num_workers_test, drop_last=drop_last)
186 |
187 | else:
188 | raise ValueError('wrong split')
189 |
190 | return loader
191 |
192 |
193 | def create_loader(x, y, ln, n_ex, batch_size, shuffle, drop_last):
194 | if n_ex > 0:
195 | x, y, ln = x[:n_ex], y[:n_ex], ln[:n_ex]
196 | data = td.TensorDataset(x, y, ln)
197 | loader = torch.utils.data.DataLoader(dataset=data, batch_size=batch_size, shuffle=shuffle, pin_memory=False,
198 | num_workers=2, drop_last=drop_last)
199 | return loader
200 |
201 |
202 | def get_xy_from_loader(loader, cuda=True, n_batches=-1):
203 | tuples = [(x, y, y_correct, ln) for i, (x, x_augm2, y, y_correct, ln) in enumerate(loader) if n_batches == -1 or i < n_batches]
204 | x_vals = torch.cat([x for (x, y, y_correct, ln) in tuples])
205 | y_vals = torch.cat([y for (x, y, y_correct, ln) in tuples])
206 | y_correct_vals = torch.cat([y_correct for (x, y, y_correct, ln) in tuples])
207 | ln_vals = torch.cat([ln for (x, y, y_correct, ln) in tuples])
208 | if cuda:
209 | x_vals, y_vals, y_correct_vals, ln_vals = x_vals.cuda(), y_vals.cuda(), y_correct_vals.cuda(), ln_vals.cuda()
210 | return x_vals, y_vals, y_correct_vals, ln_vals
211 |
212 |
213 | shapes_dict = {'mnist': (60000, 1, 28, 28),
214 | 'mnist_binary': (13007, 1, 28, 28),
215 | 'svhn': (73257, 3, 32, 32),
216 | 'cifar10': (50000, 3, 32, 32),
217 | 'cifar10_horse_car': (10000, 3, 32, 32),
218 | 'cifar10_dog_cat': (10000, 3, 32, 32),
219 | 'cifar100': (50000, 3, 32, 32),
220 | 'uniform_noise': (1000, 1, 28, 28),
221 | 'gaussians_binary': (1000, 1, 1, 100),
222 | }
223 | np.random.seed(0)
224 | v_global = np.random.randn(shapes_dict['gaussians_binary'][3]) # needed for consistency between train and test
225 | datasets_dict = {'mnist': datasets.MNIST,
226 | 'mnist_binary': datasets.MNIST,
227 | 'svhn': datasets.SVHN,
228 | 'cifar10': datasets.CIFAR10,
229 | 'cifar10_horse_car': datasets.CIFAR10,
230 | 'cifar10_dog_cat': datasets.CIFAR10,
231 | 'cifar100': datasets.CIFAR100,
232 | 'uniform_noise': uniform_noise,
233 | 'gaussians_binary': dataset_gaussians_binary,
234 | }
235 | classes_dict = {'cifar10': {0: 'airplane',
236 | 1: 'automobile',
237 | 2: 'bird',
238 | 3: 'cat',
239 | 4: 'deer',
240 | 5: 'dog',
241 | 6: 'frog',
242 | 7: 'horse',
243 | 8: 'ship',
244 | 9: 'truck',
245 | }
246 | }
247 |
248 |
--------------------------------------------------------------------------------
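The symmetric label noise branch of `get_loaders` in `data.py` can be illustrated in isolation. The following is a minimal standard-library sketch rather than the repository's code: `add_symmetric_label_noise` is a hypothetical helper, and it uses `random.Random` instead of NumPy so that it runs without extra dependencies. It mirrors the logic above: pick `int(n_ex * p_label_noise)` random indices and replace each selected label with a different class drawn uniformly from the remaining ones.

```python
import random


def add_symmetric_label_noise(targets, n_cls, p_label_noise, seed=0):
    """Return (noisy_targets, noise_mask) for a list of integer labels.

    Hypothetical helper mirroring the 'sym' branch of get_loaders:
    a fraction p_label_noise of the labels is flipped to another class.
    """
    rng = random.Random(seed)
    n_ex = len(targets)
    n_noisy = int(n_ex * p_label_noise)
    indices = rng.sample(range(n_ex), n_noisy)  # distinct random positions

    noisy_targets = list(targets)
    noise_mask = [False] * n_ex
    for index in indices:
        # Exclude the current class so that the label always changes.
        other_classes = [c for c in range(n_cls) if c != targets[index]]
        noisy_targets[index] = rng.choice(other_classes)
        noise_mask[index] = True
    return noisy_targets, noise_mask


if __name__ == '__main__':
    clean = [i % 10 for i in range(100)]
    noisy, mask = add_symmetric_label_noise(clean, n_cls=10, p_label_noise=0.2)
    print('flipped labels:', sum(mask))
```

Excluding the true class from the candidates means the realized noise rate equals `p_label_noise` up to rounding, instead of being diluted by flips that happen to land on the original label.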
/deep_nets/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | import math
6 |
7 |
8 | class Flatten(nn.Module):
9 | def forward(self, x):
10 | return x.view(x.size(0), -1)
11 |
12 |
13 | class Normalize(nn.Module):
14 | def __init__(self, mu, std):
15 | super(Normalize, self).__init__()
16 | self.mu, self.std = mu, std
17 |
18 | def forward(self, x):
19 | return (x - self.mu) / self.std
20 |
21 |
22 | class CustomReLU(nn.Module):
23 | def __init__(self):
24 | super(CustomReLU, self).__init__()
25 | self.collect_preact = True
26 | self.avg_preacts = []
27 |
28 | def forward(self, preact):
29 | if self.collect_preact:
30 | self.avg_preacts.append(preact.abs().mean().item())
31 | act = F.relu(preact)
32 | return act
33 |
34 |
35 | class ModuleWithStats(nn.Module):
36 | def __init__(self):
37 | super(ModuleWithStats, self).__init__()
38 |
39 | def forward(self, x):
40 | for layer in self._model:
41 | if type(layer) == CustomReLU:
42 | layer.avg_preacts = []
43 |
44 | out = self._model(x)
45 |
46 | avg_preacts_all = [layer.avg_preacts for layer in self._model if type(layer) == CustomReLU]
47 | self.avg_preact = np.mean(avg_preacts_all)
48 | return out
49 |
50 |
51 | class Linear(ModuleWithStats):
52 | def __init__(self, n_cls, shape_in):
53 | n_cls = 1 if n_cls == 2 else n_cls
54 | super(Linear, self).__init__()
55 | d = int(np.prod(shape_in[1:]))
56 | self._model = nn.Sequential(
57 | Flatten(),
58 | nn.Linear(d, n_cls, bias=False)
59 | )
60 |
61 | def forward(self, x):
62 | logits = self._model(x)
63 | return torch.cat([torch.zeros_like(logits), logits], dim=1)  # zeros on the same device/dtype as logits
64 |
65 |
66 | class LinearTwoOutputs(ModuleWithStats):
67 | def __init__(self, n_cls, shape_in):
68 | super(LinearTwoOutputs, self).__init__()
69 | d = int(np.prod(shape_in[1:]))
70 | self._model = nn.Sequential(
71 | Flatten(),
72 | nn.Linear(d, n_cls, bias=False)
73 | )
74 |
75 |
76 | class IdentityLayer(nn.Module):
77 | def forward(self, inputs):
78 | return inputs
79 |
80 |
81 | class PreActBlock(nn.Module):
82 | """ Pre-activation version of the BasicBlock. """
83 | expansion = 1
84 |
85 | def __init__(self, in_planes, planes, bn, learnable_bn, stride=1, activation='relu', droprate=0.0, gn_groups=32):
86 | super(PreActBlock, self).__init__()
87 | self.collect_preact = True
88 | self.activation = activation
89 | self.droprate = droprate
90 | self.avg_preacts = []
91 | self.bn1 = nn.BatchNorm2d(in_planes, affine=learnable_bn) if bn else nn.GroupNorm(gn_groups, in_planes)
92 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=not learnable_bn)
93 | self.bn2 = nn.BatchNorm2d(planes, affine=learnable_bn) if bn else nn.GroupNorm(gn_groups, planes)
94 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=not learnable_bn)
95 |
96 | if stride != 1 or in_planes != self.expansion*planes:
97 | self.shortcut = nn.Sequential(
98 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=not learnable_bn)
99 | )
100 |
101 | def act_function(self, preact):
102 | if self.activation == 'relu':
103 | act = F.relu(preact)
104 | # print((act == 0).float().mean().item(), (act.norm() / act.shape[0]).item(), (act.norm() / np.prod(act.shape)).item())
105 | else:
106 | assert self.activation[:8] == 'softplus'
107 | beta = int(self.activation.split('softplus')[1])
108 | act = F.softplus(preact, beta=beta)
109 | return act
110 |
111 | def forward(self, x):
112 | out = self.act_function(self.bn1(x))
113 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x # Important: using out instead of x
114 | out = self.conv1(out)
115 | out = self.act_function(self.bn2(out))
116 | if self.droprate > 0:
117 | out = F.dropout(out, p=self.droprate, training=self.training)
118 | out = self.conv2(out)
119 | out += shortcut
120 | return out
121 |
122 |
123 | class BasicBlock(nn.Module):
124 | def __init__(self, in_planes, out_planes, stride, droprate=0.0):
125 | super(BasicBlock, self).__init__()
126 | self.bn1 = nn.BatchNorm2d(in_planes)
127 | self.relu1 = nn.ReLU(inplace=True)
128 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
129 | self.bn2 = nn.BatchNorm2d(out_planes)
130 | self.relu2 = nn.ReLU(inplace=True)
131 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, padding=1, bias=False)
132 | self.droprate = droprate
133 | self.equalInOut = (in_planes == out_planes)
134 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
135 | padding=0, bias=False) or None
136 |
137 | def forward(self, x):
138 | if not self.equalInOut:
139 | x = self.relu1(self.bn1(x))
140 | else:
141 | out = self.relu1(self.bn1(x))
142 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))
143 | if self.droprate > 0:
144 | out = F.dropout(out, p=self.droprate, training=self.training)
145 | out = self.conv2(out)
146 | return torch.add(x if self.equalInOut else self.convShortcut(x), out)
147 |
148 |
149 | class BasicBlockResNet34(nn.Module):
150 | expansion = 1
151 |
152 | def __init__(self, in_planes, planes, stride=1):
153 | super(BasicBlockResNet34, self).__init__()
154 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
155 | self.bn1 = nn.BatchNorm2d(planes)
156 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
157 | self.bn2 = nn.BatchNorm2d(planes)
158 |
159 | self.shortcut = nn.Sequential()
160 | if stride != 1 or in_planes != self.expansion*planes:
161 | self.shortcut = nn.Sequential(
162 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
163 | nn.BatchNorm2d(self.expansion*planes)
164 | )
165 |
166 | def forward(self, x):
167 | out = F.relu(self.bn1(self.conv1(x)))
168 | out = self.bn2(self.conv2(out))
169 | out += self.shortcut(x)
170 | out = F.relu(out)
171 | return out
172 |
173 |
174 | class NetworkBlock(nn.Module):
175 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, droprate=0.0):
176 | super(NetworkBlock, self).__init__()
177 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, droprate)
178 |
179 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, droprate):
180 | layers = []
181 | for i in range(int(nb_layers)):
182 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, droprate))
183 | return nn.Sequential(*layers)
184 |
185 | def forward(self, x):
186 | return self.layer(x)
187 |
188 |
189 | class ResNet(nn.Module):
190 | def __init__(self, block, num_blocks, num_classes=10, model_width=64, droprate=0.0):
191 | super(ResNet, self).__init__()
192 | self.in_planes = model_width
193 | self.half_prec = False
194 | # self.mu = torch.tensor((0.4914, 0.4822, 0.4465)).view(1, 3, 1, 1).cuda()
195 | # self.std = torch.tensor((0.2471, 0.2435, 0.2616)).view(1, 3, 1, 1).cuda()
196 | self.mu = torch.tensor((0.0, 0.0, 0.0)).view(1, 3, 1, 1).cuda()
197 | self.std = torch.tensor((1.0, 1.0, 1.0)).view(1, 3, 1, 1).cuda()
198 | # if self.half_prec:
199 | # self.mu, self.std = self.mu.half(), self.std.half()
200 |
201 | self.normalize = Normalize(self.mu, self.std)
202 | self.conv1 = nn.Conv2d(3, model_width, kernel_size=3, stride=1, padding=1, bias=False)
203 | self.bn1 = nn.BatchNorm2d(model_width)
204 | self.layer1 = self._make_layer(block, model_width, num_blocks[0], stride=1)
205 | self.layer2 = self._make_layer(block, 2*model_width, num_blocks[1], stride=2)
206 | self.layer3 = self._make_layer(block, 4*model_width, num_blocks[2], stride=2)
207 | self.layer4 = self._make_layer(block, 8*model_width, num_blocks[3], stride=2)
208 | self.linear = nn.Linear(8*model_width*block.expansion, num_classes)
209 |
210 | def _make_layer(self, block, planes, num_blocks, stride):
211 | strides = [stride] + [1]*(num_blocks-1)
212 | layers = []
213 | for stride in strides:
214 | layers.append(block(self.in_planes, planes, stride))
215 | self.in_planes = planes * block.expansion
216 | return nn.Sequential(*layers)
217 |
218 | def forward(self, x, return_features=False, return_block=5):
219 | assert return_block in [1, 2, 3, 4, 5], 'wrong return_block'
220 | # out = self.normalize(x)
221 | out = F.relu(self.bn1(self.conv1(x)))
222 | out = self.layer1(out)
223 | if return_features and return_block == 1:
224 | return out
225 | out = self.layer2(out)
226 | if return_features and return_block == 2:
227 | return out
228 | out = self.layer3(out)
229 | if return_features and return_block == 3:
230 | return out
231 | out = self.layer4(out)
232 | if return_features and return_block == 4:
233 | return out
234 | out = F.avg_pool2d(out, 4)
235 | out = out.view(out.size(0), -1)
236 | if return_features and return_block == 5:
237 | return out
238 | out = self.linear(out)
239 | return out
240 |
241 |
242 | class PreActResNet(nn.Module):
243 | def __init__(self, block, num_blocks, n_cls, model_width=64, cuda=True, half_prec=False, activation='relu',
244 | droprate=0.0, bn_flag=True):
245 | super(PreActResNet, self).__init__()
246 | self.half_prec = half_prec
247 | self.bn_flag = bn_flag
248 | self.gn_groups = model_width // 2 # in particular, 32 for model_width=64 as in the original GroupNorm paper
249 | self.learnable_bn = True # doesn't matter if self.bn=False
250 | self.in_planes = model_width
251 | self.avg_preact = None
252 | self.activation = activation
253 | self.n_cls = n_cls
254 | # self.mu = torch.tensor((0.4914, 0.4822, 0.4465)).view(1, 3, 1, 1)
255 | # self.std = torch.tensor((0.2471, 0.2435, 0.2616)).view(1, 3, 1, 1)
256 | self.mu = torch.tensor((0.0, 0.0, 0.0)).view(1, 3, 1, 1)
257 | self.std = torch.tensor((1.0, 1.0, 1.0)).view(1, 3, 1, 1)
258 |
259 | if cuda:
260 | self.mu, self.std = self.mu.cuda(), self.std.cuda()
261 | # if half_prec:
262 | # self.mu, self.std = self.mu.half(), self.std.half()
263 |
264 | self.normalize = Normalize(self.mu, self.std)
265 | self.conv1 = nn.Conv2d(3, model_width, kernel_size=3, stride=1, padding=1, bias=not self.learnable_bn)
266 | self.layer1 = self._make_layer(block, model_width, num_blocks[0], 1, droprate)
267 | self.layer2 = self._make_layer(block, 2*model_width, num_blocks[1], 2, droprate)
268 | self.layer3 = self._make_layer(block, 4*model_width, num_blocks[2], 2, droprate)
269 | final_layer_factor = 8
270 | self.layer4 = self._make_layer(block, final_layer_factor*model_width, num_blocks[3], 2, droprate)
271 | self.bn = nn.BatchNorm2d(final_layer_factor*model_width*block.expansion) if self.bn_flag \
272 | else nn.GroupNorm(self.gn_groups, final_layer_factor*model_width*block.expansion)
273 | self.linear = nn.Linear(final_layer_factor*model_width*block.expansion, 1 if n_cls == 2 else n_cls)
274 |
275 | def _make_layer(self, block, planes, num_blocks, stride, droprate):
276 | strides = [stride] + [1]*(num_blocks-1)
277 | layers = []
278 | for stride in strides:
279 | layers.append(block(self.in_planes, planes, self.bn_flag, self.learnable_bn, stride, self.activation,
280 | droprate, self.gn_groups))
281 | # layers.append(block(self.in_planes, planes, stride))
282 | self.in_planes = planes * block.expansion
283 | return nn.Sequential(*layers)
284 |
285 | def forward(self, x, return_features=False, return_block=5):
286 | assert return_block in [1, 2, 3, 4, 5], 'wrong return_block'
287 | for layer in [*self.layer1, *self.layer2, *self.layer3, *self.layer4]:
288 | layer.avg_preacts = []
289 |
290 | # x = x / ((x**2).sum([1, 2, 3], keepdims=True)**0.5 + 1e-6) # numerical stability is needed for RLAT
291 | out = self.normalize(x)
292 | out = self.conv1(out)
293 | out = self.layer1(out)
294 | if return_features and return_block == 1:
295 | return out
296 | out = self.layer2(out)
297 | if return_features and return_block == 2:
298 | return out
299 | out = self.layer3(out)
300 | if return_features and return_block == 3:
301 | return out
302 | out = self.layer4(out)
303 | out = F.relu(self.bn(out))
304 | if return_features and return_block == 4:
305 | return out
306 | out = F.avg_pool2d(out, 4)
307 | out = out.view(out.size(0), -1)
308 | if return_features and return_block == 5:
309 | return out
310 |
311 | out = self.linear(out)
312 | if out.shape[1] == 1:
313 | out = torch.cat([torch.zeros_like(out), out], dim=1)
314 |
315 | return out
316 |
317 |
318 | class WideResNet(nn.Module):
319 | """ Based on code from https://github.com/yaodongyu/TRADES """
320 | def __init__(self, depth=28, num_classes=10, widen_factor=10, droprate=0.0, bias_last=True):
321 | super(WideResNet, self).__init__()
322 | self.half_prec = False
323 | nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
324 | assert ((depth - 4) % 6 == 0)
325 | n = (depth - 4) // 6
326 | block = BasicBlock
327 | # 1st conv before any network block
328 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, padding=1, bias=False)
329 | # 1st block
330 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, droprate)
331 | # 2nd block
332 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, droprate)
333 | # 3rd block
334 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, droprate)
335 | # global average pooling and classifier
336 | self.bn1 = nn.BatchNorm2d(nChannels[3])
337 | self.relu = nn.ReLU(inplace=True)
338 | self.fc = nn.Linear(nChannels[3], num_classes, bias=bias_last)
339 | self.nChannels = nChannels[3]
340 |
341 | for m in self.modules():
342 | if isinstance(m, nn.Conv2d):
343 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
344 | m.weight.data.normal_(0, math.sqrt(2. / n))
345 | elif isinstance(m, nn.BatchNorm2d):
346 | m.weight.data.fill_(1)
347 | m.bias.data.zero_()
348 | elif isinstance(m, nn.Linear) and m.bias is not None:
349 | m.bias.data.zero_()
350 |
351 | def forward(self, x):
352 | out = self.conv1(x)
353 | out = self.block1(out)
354 | out = self.block2(out)
355 | out = self.block3(out)
356 | out = self.relu(self.bn1(out))
357 | out = F.avg_pool2d(out, 8)
358 | out = out.view(-1, self.nChannels)
359 | return self.fc(out)
360 |
361 |
362 | class VGG(nn.Module):
363 | '''
364 | VGG model. Source: https://github.com/chengyangfu/pytorch-vgg-cifar10/blob/master/vgg.py
365 | (in turn modified from https://github.com/pytorch/vision.git)
366 | '''
367 | def __init__(self, n_cls, half_prec, cfg):
368 | super(VGG, self).__init__()
369 | self.half_prec = half_prec
370 | self.mu = torch.tensor((0.485, 0.456, 0.406)).view(1, 3, 1, 1).cuda()
371 | self.std = torch.tensor((0.229, 0.224, 0.225)).view(1, 3, 1, 1).cuda()
372 | self.normalize = Normalize(self.mu, self.std)
373 | self.features = self.make_layers(cfg)
374 | n_out = cfg[-2] # equal to 8*model_width
375 | self.classifier = nn.Sequential(
376 | # nn.Dropout(),
377 | nn.Linear(n_out, n_out),
378 | nn.ReLU(True),
379 | # nn.Dropout(),
380 | nn.Linear(n_out, n_out),
381 | nn.ReLU(True),
382 | nn.Linear(n_out, n_cls),
383 | )
384 | # Initialize weights
385 | for m in self.modules():
386 | if isinstance(m, nn.Conv2d):
387 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
388 | m.weight.data.normal_(0, math.sqrt(2. / n))
389 | m.bias.data.zero_()
390 |
391 | def make_layers(self, cfg, batch_norm=False):
392 | layers = []
393 | in_channels = 3
394 | for v in cfg:
395 | if v == 'M':
396 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
397 | else:
398 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
399 | if batch_norm:
400 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
401 | else:
402 | layers += [conv2d, nn.ReLU(inplace=True)]
403 | in_channels = v
404 | return nn.Sequential(*layers)
405 |
406 | def forward(self, x):
407 | x = self.normalize(x)
408 | x = self.features(x)
409 | x = x.view(x.size(0), -1)
410 | x = self.classifier(x)
411 | return x
412 |
413 |
414 | def VGG16(n_cls, model_width, half_prec):
415 | """VGG 16-layer model (configuration "D")"""
416 | w1, w2, w3, w4, w5 = model_width, 2*model_width, 4*model_width, 8*model_width, 8*model_width
417 | # cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M']
418 | cfg = [w1, w1, 'M', w2, w2, 'M', w3, w3, w3, 'M', w4, w4, w4, 'M', w5, w5, w5, 'M']
419 | return VGG(n_cls, half_prec, cfg)
420 |
421 |
422 | def TinyResNet(n_cls, model_width=64, cuda=True, half_prec=False, activation='relu', droprate=0.0):
423 | bn_flag = True
424 | return PreActResNet(PreActBlock, [1, 1, 1, 1], n_cls=n_cls, model_width=model_width, cuda=cuda, half_prec=half_prec,
425 | activation=activation, droprate=droprate, bn_flag=bn_flag)
426 |
427 | def TinyResNetGroupNorm(n_cls, model_width=64, cuda=True, half_prec=False, activation='relu', droprate=0.0):
428 | bn_flag = False
429 | return PreActResNet(PreActBlock, [1, 1, 1, 1], n_cls=n_cls, model_width=model_width, cuda=cuda, half_prec=half_prec,
430 | activation=activation, droprate=droprate, bn_flag=bn_flag)
431 |
432 |
433 | def ResNet18(n_cls, model_width=64):
434 | return ResNet(BasicBlockResNet34, [2, 2, 2, 2], num_classes=n_cls, model_width=model_width)
435 |
436 |
437 | def PreActResNet18(n_cls, model_width=64, cuda=True, half_prec=False, activation='relu', droprate=0.0):
438 | bn_flag = True
439 | return PreActResNet(PreActBlock, [2, 2, 2, 2], n_cls=n_cls, model_width=model_width, cuda=cuda, half_prec=half_prec,
440 | activation=activation, droprate=droprate, bn_flag=bn_flag)
441 |
442 |
443 | def PreActResNet34(n_cls, model_width=64, cuda=True, half_prec=False, activation='relu', droprate=0.0):
444 | bn_flag = True
445 | return PreActResNet(PreActBlock, [3, 4, 6, 3], n_cls=n_cls, model_width=model_width, cuda=cuda, half_prec=half_prec,
446 | activation=activation, droprate=droprate, bn_flag=bn_flag)
447 |
448 |
449 | def PreActResNet18GroupNorm(n_cls, model_width=64, cuda=True, half_prec=False, activation='relu', droprate=0.0):
450 | bn_flag = False # bn_flag==False means that we use GroupNorm with 32 groups
451 | return PreActResNet(PreActBlock, [2, 2, 2, 2], n_cls=n_cls, model_width=model_width, cuda=cuda, half_prec=half_prec,
452 | activation=activation, droprate=droprate, bn_flag=bn_flag)
453 |
454 |
455 | def PreActResNet34GroupNorm(n_cls, model_width=64, cuda=True, half_prec=False, activation='relu', droprate=0.0):
456 | bn_flag = False # bn_flag==False means that we use GroupNorm with 32 groups
457 | return PreActResNet(PreActBlock, [3, 4, 6, 3], n_cls=n_cls, model_width=model_width, cuda=cuda, half_prec=half_prec,
458 | activation=activation, droprate=droprate, bn_flag=bn_flag)
459 |
460 |
461 | def ResNet34(n_cls, model_width=64):
462 | return ResNet(BasicBlockResNet34, [3, 4, 6, 3], num_classes=n_cls, model_width=model_width)
463 |
464 |
465 | def WideResNet28(n_cls, model_width=10):
466 | return WideResNet(num_classes=n_cls, widen_factor=model_width)
467 |
468 |
469 | def get_model(model_name, n_cls, half_prec, shapes_dict, model_width, activation='relu', droprate=0.0):
470 | if model_name == 'resnet18':
471 | model = PreActResNet18(n_cls, model_width=model_width, half_prec=half_prec, activation=activation, droprate=droprate)
472 | elif model_name == 'resnet18_plain':
473 | model = ResNet18(n_cls, model_width)
474 | elif model_name == 'resnet18_gn':
475 | model = PreActResNet18GroupNorm(n_cls, model_width=model_width, half_prec=half_prec, activation=activation, droprate=droprate)
476 | elif model_name == 'vgg16':
477 | assert droprate == 0.0, 'dropout is not implemented for vgg16'
478 | model = VGG16(n_cls, model_width, half_prec)
479 | elif model_name in ['resnet34', 'resnet34_plain']:
480 | model = ResNet34(n_cls, model_width)
481 | elif model_name == 'resnet34_gn':
482 | model = PreActResNet34GroupNorm(n_cls, model_width=model_width, half_prec=half_prec, activation=activation, droprate=droprate)
483 | elif model_name == 'resnet34preact':
484 | model = PreActResNet34(n_cls, model_width=model_width, half_prec=half_prec, activation=activation, droprate=droprate)
485 | elif model_name == 'wrn28':
486 | model = WideResNet28(n_cls, model_width)
487 | elif model_name == 'resnet_tiny':
488 | model = TinyResNet(n_cls, model_width=model_width, half_prec=half_prec, activation=activation, droprate=droprate)
489 | elif model_name == 'resnet_tiny_gn':
490 | model = TinyResNetGroupNorm(n_cls, model_width=model_width, half_prec=half_prec, activation=activation, droprate=droprate)
491 | elif model_name == 'linear':
492 | model = Linear(n_cls, shapes_dict)
493 | else:
494 | raise ValueError('wrong model')
495 | return model
496 |
497 |
498 | def init_weights(model, scale_init=0.0):
499 | def init_weights_he(m):
500 | if isinstance(m, nn.Conv2d):
501 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
502 | elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.GroupNorm):
503 | m.weight.data.fill_(1)
504 | m.bias.data.zero_()
505 | elif isinstance(m, nn.Linear):
506 | m.bias.data.zero_()
507 |
508 | return init_weights_he
509 |
510 |
511 | def forward_pass_rlat(model, x, deltas, layers):
512 | i = 0
513 |
514 | def out_hook(m, inp, out_layer):
515 | nonlocal i
516 | if layers[i] == model.normalize:
517 | new_out = (torch.clamp(inp[0] + deltas[i], 0, 1) - model.mu) / model.std
518 | else:
519 | new_out = out_layer + deltas[i]
520 | i += 1
521 | return new_out
522 |
523 | handles = [layer.register_forward_hook(out_hook) for layer in layers]
524 | out = model(x)
525 |
526 | for handle in handles:
527 | handle.remove()
528 | return out
529 |
530 |
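`VGG16` above scales the standard configuration "D" by `model_width`, and `VGG.__init__` takes the classifier input size from `cfg[-2]`. The sketch below rebuilds that configuration in plain Python and checks when `cfg[-2]` matches the flattened feature size. The helper names `vgg16_cfg` and `flat_features`, and the 32x32 input size, are assumptions made for illustration and are not part of the repository.

```python
def vgg16_cfg(model_width):
    """Width-scaled VGG-16 layout, mirroring the cfg list built in VGG16."""
    w1, w2, w3, w4, w5 = (model_width, 2 * model_width, 4 * model_width,
                          8 * model_width, 8 * model_width)
    return [w1, w1, 'M', w2, w2, 'M', w3, w3, w3, 'M',
            w4, w4, w4, 'M', w5, w5, w5, 'M']


def flat_features(cfg, in_hw=32):
    """Flattened feature size for a square in_hw x in_hw input (hypothetical helper).

    3x3 convolutions with padding=1 keep the spatial size;
    each 'M' (2x2 max-pool with stride 2) halves it.
    """
    hw, channels = in_hw, 3
    for v in cfg:
        if v == 'M':
            hw //= 2
        else:
            channels = v
    return channels * hw * hw


cfg = vgg16_cfg(64)
print(cfg[-2], flat_features(cfg, in_hw=32))  # prints "512 512"
```

With five pooling stages a 32x32 input shrinks to 1x1, so the flattened size equals `cfg[-2]`. For a 64x64 input it would be four times larger, which suggests the classifier as written assumes 32x32 images.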
--------------------------------------------------------------------------------
/deep_nets/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import time
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as F
7 | import copy
8 | import utils
9 | import utils_eval
10 | import utils_train
11 | import data
12 | import models
13 | from collections import defaultdict
14 | from datetime import datetime
15 | from models import forward_pass_rlat
16 |
17 |
18 | def get_args():
19 | parser = argparse.ArgumentParser()
20 | parser.add_argument('--batch_size', default=128, type=int)
21 | parser.add_argument('--dataset', default='cifar10', choices=data.datasets_dict.keys(), type=str)
22 | parser.add_argument('--model', default='resnet18', choices=['vgg16', 'resnet18', 'resnet18_plain', 'resnet18_gn', 'resnet_tiny', 'resnet_tiny_gn', 'resnet34', 'resnet34_plain', 'resnet34preact', 'resnet34_gn', 'wrn28'], type=str)
23 | parser.add_argument('--epochs', default=100, type=int, help='100 epochs is standard with batch_size=128')
24 | parser.add_argument('--lr_schedule', default='piecewise', choices=['cyclic', 'piecewise', 'cosine', 'constant', 'piecewise_02epochs', 'piecewise_03epochs', 'piecewise_05epochs'])
25 | parser.add_argument('--ln_schedule', default='constant', choices=['constant', 'inverted_cosine', 'piecewise_10_100', 'piecewise_3_9', 'piecewise_3_inf', 'piecewise_2_3_3', 'piecewise_5_3_3', 'piecewise_8_3_3'])
26 | parser.add_argument('--lr_init', default=0.1, type=float, help='')
27 | parser.add_argument('--warmup_factor', default=1.0, type=float, help='linear warmup factor of the peak lr')
28 | parser.add_argument('--warmup_exp', default=1.0, type=float, help='the exponent of the exponential warmup')
29 | parser.add_argument('--momentum', default=0.9, type=float, help='')
30 | parser.add_argument('--p_label_noise', default=0.0, type=float, help='Fraction of flipped labels in the training data.')
31 | parser.add_argument('--noise_type', default='sym', type=str, choices=['sym', 'asym'], help='Noise type: symmetric or asymmetric')
32 | parser.add_argument('--attack', default='none', type=str, choices=['fgsm', 'fgsmpp', 'pgd', 'rlat', 'random_noise', 'none'])
33 | parser.add_argument('--at_pred_label', action='store_true', help='Use predicted labels for AT.')
34 | parser.add_argument('--swa_tau', default=0.999, type=float, help='SWA moving averaging coefficient (averaging executed every iteration).')
35 | parser.add_argument('--sgd_p_label_noise', default=0.0, type=float, help='ratio of label noise in SGD per batch')
36 | parser.add_argument('--frac_train', default=1, type=float, help='Fraction of training points.')
37 | parser.add_argument('--l2_reg', default=0.0, type=float, help='l2 regularization in the objective')
38 | parser.add_argument('--seed', default=0, type=int)
39 | parser.add_argument('--gpu', default=0, type=int)
40 | parser.add_argument('--debug', action='store_true')
41 | parser.add_argument('--save_model_each_k_epochs', default=0, type=int, help='save each k epochs; 0 means saving only at the end')
42 | parser.add_argument('--half_prec', action='store_true', help='if enabled, runs everything as half precision [not recommended]')
43 | parser.add_argument('--no_data_augm', action='store_true')
44 | parser.add_argument('--eval_iter_freq', default=-1, type=int, help='how often (in iterations) to evaluate test stats; -1 means every 5 epochs, i.e. every 5 * (n_train // batch_size) iterations.')
45 | parser.add_argument('--n_eval_every_k_iter', default=512, type=int, help='on how many examples to eval every k iters')
46 | parser.add_argument('--model_width', default=-1, type=int, help='model width (# conv filters on the first layer for ResNets)')
47 | parser.add_argument('--batch_size_eval', default=512, type=int, help='batch size for the final eval with pgd rr; 6 GB memory is consumed for 1024 examples with fp32 network')
48 | parser.add_argument('--n_final_eval', default=10000, type=int, help='on how many examples to do the final evaluation; -1 means on all test examples.')
49 | parser.add_argument('--exp_name', default='other', type=str)
50 | parser.add_argument('--model_path', type=str, default='', help='Path to a checkpoint to continue training from.')
51 | return parser.parse_args()
52 |
53 |
54 | def main():
55 | args = get_args()
56 | assert args.model_width != -1, 'args.model_width has to be always specified (e.g., 64 for resnet18, 10 for wrn28)'
57 | assert 0 <= args.frac_train <= 1
58 | assert 0 <= args.sgd_p_label_noise <= 1
59 | assert 0 <= args.swa_tau <= 1
60 |
61 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
62 | cur_timestamp = str(datetime.now())[:-3] # include ms to reduce the probability of a name collision
63 | model_name = '{} dataset={} model={} epochs={} lr_init={} model_width={} l2_reg={} batch_size={} frac_train={} p_label_noise={} seed={}'.format(
64 | cur_timestamp, args.dataset, args.model, args.epochs, args.lr_init, args.model_width, args.l2_reg,
65 | args.batch_size, args.frac_train, args.p_label_noise, args.seed)
66 | logger = utils.configure_logger(model_name, args.debug)
67 | logger.info(args)
68 |
69 | n_cls = 2 if args.dataset in ['cifar10_horse_car', 'cifar10_dog_cat'] else 10 if args.dataset != 'cifar100' else 100
70 | n_train = int(args.frac_train * data.shapes_dict[args.dataset][0])
71 |
72 | args.exp_name = 'exps/{}'.format(args.exp_name)
73 | if not os.path.exists(args.exp_name): os.makedirs(args.exp_name)
74 |
75 | # fixing the seed helps, but not completely. there is still some non-determinism due to GPU computations.
76 | np.random.seed(args.seed)
77 | torch.manual_seed(args.seed)
78 | torch.cuda.manual_seed(args.seed)
79 |
80 | train_data_augm = False if args.no_data_augm or args.model == 'linear' or args.dataset in ['mnist', 'mnist_binary', 'gaussians_binary'] else True
81 | train_batches = data.get_loaders(args.dataset, n_train, args.batch_size, split='train', val_indices=[], shuffle=True, data_augm=train_data_augm, p_label_noise=args.p_label_noise, noise_type=args.noise_type, drop_last=True)
82 | train_batches_large_bs = data.get_loaders(args.dataset, n_train, args.batch_size_eval, split='train', val_indices=[], shuffle=False, data_augm=False, p_label_noise=args.p_label_noise, noise_type=args.noise_type, drop_last=False)
83 | test_batches = data.get_loaders(args.dataset, args.n_final_eval, args.batch_size_eval, split='test', shuffle=True, data_augm=False, noise_type=args.noise_type, drop_last=False)
84 |
85 | model = models.get_model(args.model, n_cls, args.half_prec, data.shapes_dict[args.dataset], args.model_width).cuda()
86 | if args.model_path != '':
87 | model_dict = torch.load(args.model_path)['last']
88 | model.load_state_dict({k: v for k, v in model_dict.items() if 'model_preact_hl1' not in k})
89 | else:
90 | model.apply(models.init_weights(args.model))
91 | model.train()
92 | model_swa = copy.deepcopy(model).eval() # stochastic weight averaging model (keep it in the eval mode by default)
93 |
94 | opt = torch.optim.SGD(model.parameters(), lr=args.lr_init, momentum=args.momentum)
95 | scaler = torch.cuda.amp.GradScaler(enabled=model.half_prec)
96 | lr_schedule = utils_train.get_lr_schedule(args.lr_schedule, args.epochs, args.lr_init, args.warmup_factor, args.warmup_exp)
97 | ln_schedule = utils_train.get_lr_schedule(args.ln_schedule, args.epochs, args.sgd_p_label_noise)
98 |
99 | loss_f = lambda logits, y: F.cross_entropy(logits, y, reduction='mean')
100 |
101 | metr_dict = defaultdict(list, vars(args))
102 | start_time = time.time()
103 | time_train, iteration = 0, 0
104 | for epoch in range(args.epochs + 1):
105 | model = model.eval() if epoch == 0 else model.train() # epoch=0 is eval only
106 |
107 | train_obj, train_reg = 0, 0
108 | for i, (x, _, y, _, ln) in enumerate(train_batches):
109 | if epoch == 0 and i > 0: # epoch=0 runs only for one iteration (to check the training stats at init)
110 | break
111 | time_start_iter = time.time()
112 | x, y = x.cuda(), y.cuda()
113 | lr = lr_schedule(epoch - 1 + (i + 1) / len(train_batches)) # epoch - 1 since the 0th epoch is skipped
114 | opt.param_groups[0].update(lr=lr)
115 |
116 | # label noise SGD
117 | if args.sgd_p_label_noise > 0.0:
118 | sgd_p_label_noise_eff = ln_schedule(epoch - 1 + (i + 1) / len(train_batches))
119 | n_noisy_pts = (torch.rand(args.batch_size) < sgd_p_label_noise_eff).int().sum() # randomized fraction of noisy points
120 | rand_indices = torch.randperm(args.batch_size)[:n_noisy_pts].cuda()
121 | rand_labels = torch.randint(low=0, high=n_cls, size=(n_noisy_pts, )).cuda()
122 | y[rand_indices] = rand_labels
123 |
124 | with torch.cuda.amp.autocast(enabled=model.half_prec):
125 | logits = model(x)
126 | obj = loss_f(logits, y)
127 |
128 | reg = torch.zeros(1).cuda()[0]
129 | for param in model.parameters():
130 | reg += args.l2_reg * 0.5 * torch.sum(param ** 2).float()
131 | obj += reg
132 |
133 |
134 | opt.zero_grad()
135 | scaler.scale(obj).backward()
136 |
137 | train_obj += obj.item() / n_train # only for statistics
138 | train_reg += reg.item() / n_train # only for statistics
139 |
140 | if epoch > 0: # on 0-th epoch only evaluation occurs
141 | scaler.step(opt)
142 | scaler.update() # update the scale of the loss for fp16
143 |
144 | opt.zero_grad() # zero grad (also at epoch==0)
145 |
146 | time_train += time.time() - time_start_iter
147 | utils_train.moving_average(model_swa, model, 1-args.swa_tau) # executed every iteration
148 |
149 | # by default, evaluate every 5 epochs (originally every 2; raised to 5 to save time)
150 | if (args.eval_iter_freq == -1 and iteration % (5 * (n_train // args.batch_size)) == 0) or \
151 | (args.eval_iter_freq != -1 and iteration % args.eval_iter_freq == 0):
152 | utils_train.bn_update(train_batches, model_swa) # a bit heavy, but acceptable at this evaluation frequency
153 |
154 | model.eval() # it'd be incorrect to recalculate the BN stats based on some evaluations
155 |
156 | train_err, train_loss = utils_eval.compute_loss(train_batches, model, loss_f=loss_f, n_batches=4) # i.e. it's evaluated using 4*batch_size examples
157 | train_err_swa, train_loss_swa = utils_eval.compute_loss(train_batches, model_swa, loss_f=loss_f, n_batches=4) # i.e. it's evaluated using 4*batch_size examples
158 |
159 | sparsity_train_block1, sparsity_train_block1_rmdup, _ = utils_eval.compute_feature_sparsity(train_batches_large_bs, model, return_block=1, n_batches=20)
160 | sparsity_train_block2, sparsity_train_block2_rmdup, _ = utils_eval.compute_feature_sparsity(train_batches_large_bs, model, return_block=2, n_batches=20)
161 | sparsity_train_block3, sparsity_train_block3_rmdup, _ = utils_eval.compute_feature_sparsity(train_batches_large_bs, model, return_block=3, n_batches=20)
162 | sparsity_train_block4, sparsity_train_block4_rmdup, _ = utils_eval.compute_feature_sparsity(train_batches_large_bs, model, return_block=4, n_batches=20)
163 | sparsity_train_block5, sparsity_train_block5_rmdup, _ = utils_eval.compute_feature_sparsity(train_batches_large_bs, model, return_block=5, n_batches=20)
164 | # sparsity_train_block1, sparsity_train_block1_rmdup, sparsity_train_block2, sparsity_train_block2_rmdup, sparsity_train_block3, sparsity_train_block3_rmdup, sparsity_train_block4, sparsity_train_block4_rmdup, sparsity_train_block5, sparsity_train_block5_rmdup = 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
165 |
166 | test_err, test_loss = utils_eval.compute_loss(test_batches, model, loss_f=loss_f)
167 | test_err_swa, _ = utils_eval.compute_loss(test_batches, model_swa, loss_f=loss_f)
168 |
169 | time_elapsed = time.time() - start_time
170 | train_str = '[train] loss {:.4f}/{:.4f} err {:.2%}'.format(train_loss, train_loss_swa, train_err)
171 | test_str = '[test] err {:.2%}/{:.2%} '.format(test_err, test_err_swa)
172 | sparsity_str = 'sparsity {:.1%}/{:.1%}/{:.1%}/{:.1%}/{:.1%}'.format(sparsity_train_block1_rmdup, sparsity_train_block2_rmdup, sparsity_train_block3_rmdup, sparsity_train_block4_rmdup, sparsity_train_block5_rmdup)
173 | logger.info('{}-{}: {} {} {} ({:.2f}m, {:.2f}m)'.format(
174 | epoch, iteration, train_str, test_str, sparsity_str, time_train/60, time_elapsed/60))
175 | metr_vals = [epoch, iteration, train_obj, train_loss, train_reg, train_err,
176 | test_err, test_loss, train_loss_swa, train_err_swa,
177 | test_err_swa, time_train, time_elapsed,
178 | sparsity_train_block1, sparsity_train_block2, sparsity_train_block3, sparsity_train_block4, sparsity_train_block5]
179 | metr_names = ['epoch', 'iter', 'train_obj', 'train_loss', 'train_reg', 'train_err',
180 | 'test_err', 'test_loss', 'train_loss_swa', 'train_err_swa', 'test_err_swa', 'time_train', 'time_elapsed',
181 | 'sparsity_train_block1', 'sparsity_train_block2', 'sparsity_train_block3', 'sparsity_train_block4', 'sparsity_train_block5']
182 | utils.update_metrics(metr_dict, metr_vals, metr_names)
183 |
184 | if not args.debug:
185 | np.save('{}/{}.npy'.format(args.exp_name, model_name), metr_dict)
186 |
187 | model.train()
188 |
189 | iteration += 1
190 |
191 | if args.save_model_each_k_epochs > 0:
192 | if epoch % args.save_model_each_k_epochs == 0 or epoch <= 5:
193 | torch.save({'last': model.state_dict()}, 'models/{} epoch={}.pth'.format(model_name, epoch))
194 |
195 | if not args.debug:
196 | np.save('{}/{}.npy'.format(args.exp_name, model_name), metr_dict)
197 | if epoch == args.epochs: # only save at the end
198 | torch.save({'last': model.state_dict(), 'swa_last': model_swa.state_dict()},
199 | 'models/{} epoch={}.pth'.format(model_name, epoch))
200 |
201 | logger.info('Saved the model at: models/{} epoch={}.pth'.format(model_name, epoch))
202 | logger.info('Done in {:.2f}m'.format((time.time() - start_time) / 60))
203 |
204 |
205 | if __name__ == "__main__":
206 | main()
207 |
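The "label noise SGD" step in the training loop above draws a random number of noisy points per batch and overwrites their labels with uniformly random classes. The following is a minimal CPU-only sketch of that step using the standard library instead of `torch`; the function name `resample_labels` and the explicit `rng` argument are assumptions introduced for illustration.

```python
import random


def resample_labels(labels, p, n_cls, rng):
    """Return a copy of labels with a random subset replaced by uniform random classes."""
    n = len(labels)
    # number of noisy points: sum of n Bernoulli(p) draws, as in the training loop
    n_noisy = sum(rng.random() < p for _ in range(n))
    # choose which positions to overwrite (without replacement)
    noisy_idx = rng.sample(range(n), n_noisy)
    out = list(labels)
    for i in noisy_idx:
        out[i] = rng.randrange(n_cls)  # may coincide with the original label
    return out


rng = random.Random(0)
y = [3, 1, 4, 1, 5, 9, 2, 6]
y_noisy = resample_labels(y, p=0.25, n_cls=10, rng=rng)
print(sum(a != b for a, b in zip(y, y_noisy)), "labels changed")
```

Because a resampled label can equal the original one, the effective fraction of changed labels is on average slightly below `p`.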
--------------------------------------------------------------------------------
/deep_nets/utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import torch
4 | from contextlib import contextmanager
5 |
6 |
7 | logger = logging.getLogger(__name__)
8 | logging.basicConfig(
9 | format='[%(asctime)s %(filename)s %(name)s %(levelname)s] - %(message)s',
10 | datefmt='%Y/%m/%d %H:%M:%S',
11 | level=logging.DEBUG)
12 |
13 |
14 | def clamp(X, l, u, cuda=True):
15 | if type(l) is not torch.Tensor:
16 | if cuda:
17 | l = torch.cuda.FloatTensor(1).fill_(l)
18 | else:
19 | l = torch.FloatTensor(1).fill_(l)
20 | if type(u) is not torch.Tensor:
21 | if cuda:
22 | u = torch.cuda.FloatTensor(1).fill_(u)
23 | else:
24 | u = torch.FloatTensor(1).fill_(u)
25 | return torch.max(torch.min(X, u), l)
26 |
27 |
28 | def configure_logger(model_name, debug):
29 | if not os.path.exists('logs'):
30 | os.makedirs('logs')
31 | logging.basicConfig(format='%(message)s') # , level=logging.DEBUG)
32 | logger = logging.getLogger()
33 | logger.handlers = [] # remove the default logger
34 |
35 | # add a new logger for stdout
36 | formatter = logging.Formatter('%(message)s')
37 | ch = logging.StreamHandler()
38 | ch.setFormatter(formatter)
39 | ch.setLevel(logging.DEBUG)
40 | logger.addHandler(ch)
41 |
42 | if not debug:
43 | # add a new logger to a log file
44 | logger.addHandler(logging.FileHandler('logs/{}.log'.format(model_name)))
45 |
46 | return logger
47 |
48 |
49 | def get_random_delta(shape, eps, at_norm, requires_grad=True):
50 | delta = torch.zeros(shape).cuda()
51 | if at_norm == 'l2': # uniform on the unit sphere (note: not rescaled by eps)
52 | delta.normal_()
53 | delta /= (delta**2).sum([1, 2, 3], keepdim=True)**0.5
54 | elif at_norm == 'linf': # uniform from the hypercube [-eps, eps]^d
55 | delta.uniform_(-eps, eps)
56 | else:
57 | raise ValueError('wrong at_norm')
58 | delta.requires_grad = requires_grad
59 | return delta
60 |
61 |
62 | def project_lp(img, at_norm, eps):
63 | if at_norm == 'l2': # rescale onto the l2 ball of radius eps
64 | l2_norms = (img ** 2).sum([1, 2, 3], keepdim=True) ** 0.5
65 | img_proj = img * torch.min(eps/l2_norms, torch.ones_like(l2_norms)) # if eps>l2_norms => multiply by 1
66 | elif at_norm == 'linf': # clip each coordinate to [-eps, eps]
67 | img_proj = clamp(img, -eps, eps)
68 | else:
69 | raise ValueError('wrong at_norm')
70 | return img_proj
71 |
72 |
73 | def update_metrics(metrics_dict, metrics_values, metrics_names):
74 | assert len(metrics_values) == len(metrics_names)
75 | for metric_value, metric_name in zip(metrics_values, metrics_names):
76 | metrics_dict[metric_name].append(metric_value)
77 | return metrics_dict
78 |
79 |
80 | @contextmanager
81 | def nullcontext(enter_result=None):
82 | yield enter_result
83 |
84 |
85 | def get_flat_grad(model):
86 | return torch.cat([p.grad.flatten() for p in model.parameters() if p.grad is not None])
87 |
88 |
89 | def zero_grad(model):
90 | for p in model.parameters():
91 | if p.grad is not None:
92 | p.grad.zero_()
93 |
94 |
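`project_lp` above maps a perturbation back into an l2 or l-infinity ball of radius `eps`. The sketch below shows the same two projections for a single flattened perturbation in plain Python. It is an illustration under that simplifying assumption (the repository version works per example on 4D `torch` batches), and the explicit guard for small norms is an addition.

```python
import math


def project_lp(delta, norm, eps):
    """Project a flat list of floats onto the lp ball of radius eps."""
    if norm == 'l2':
        l2 = math.sqrt(sum(d * d for d in delta))
        # leave points inside the ball untouched, rescale the rest onto its surface
        scale = 1.0 if l2 <= eps else eps / l2
        return [d * scale for d in delta]
    elif norm == 'linf':
        # coordinate-wise clipping to [-eps, eps]
        return [max(-eps, min(eps, d)) for d in delta]
    raise ValueError('wrong norm')


print(project_lp([3.0, 4.0], 'l2', 1.0))
print(project_lp([3.0, -4.0, 0.5], 'linf', 1.0))  # prints "[1.0, -1.0, 0.5]"
```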
--------------------------------------------------------------------------------
/deep_nets/utils_eval.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | import numpy as np
3 | import torch
4 | import torch.nn.functional as F
5 |
6 |
7 | def compute_loss(batches, model, cuda=True, noisy_examples='default', loss_f=F.cross_entropy, n_batches=-1):
8 | n_corr_classified, train_loss_sum, n_ex = 0, 0.0, 0
9 | for i, (X, X_augm2, y, _, ln) in enumerate(batches):
10 | if n_batches != -1 and i >= n_batches: # limit to only n_batches
11 | break
12 | if cuda:
13 | X, y = X.cuda(), y.cuda()
14 |
15 | if noisy_examples == 'none':
16 | X, y = X[~ln], y[~ln]
17 | elif noisy_examples == 'all':
18 | X, y = X[ln], y[ln]
19 | else:
20 | assert noisy_examples == 'default'
21 |
22 | with torch.no_grad(), torch.cuda.amp.autocast(enabled=model.half_prec):
23 | output = model(X)
24 | loss = loss_f(output, y)
25 |
26 | n_corr_classified += (output.max(1)[1] == y).sum().item()
27 | train_loss_sum += loss.item() * y.size(0)
28 | n_ex += y.size(0)
29 |
30 | robust_acc = n_corr_classified / n_ex
31 | avg_loss = train_loss_sum / n_ex
32 |
33 | return 1 - robust_acc, avg_loss
34 |
35 |
36 | def compute_feature_sparsity(batches, model, return_block, corr_threshold=0.95, n_batches=-1, n_relu_max=1000):
37 | with torch.no_grad():
38 | features_list = []
39 | for i, (X, X_augm2, y, _, ln) in enumerate(batches):
40 | if n_batches != -1 and i >= n_batches: # limit to only n_batches
41 | break
42 | X, y = X.cuda(), y.cuda()
43 | features = model(X, return_features=True, return_block=return_block).cpu().numpy()
44 | features_list.append(features)
45 |
46 | phi = np.vstack(features_list)
47 | phi = phi.reshape(phi.shape[0], np.prod(phi.shape[1:]))
48 |
49 | sparsity = (phi > 0).sum() / (phi.shape[0] * phi.shape[1])
50 |
51 | if phi.shape[1] > n_relu_max: # if there are too many neurons, we speed it up by random subsampling
52 | random_idx = np.random.choice(phi.shape[1], n_relu_max, replace=False)
53 | phi = phi[:, random_idx]
54 |
55 | idx_keep = np.where((phi > 0.0).sum(0) > 0)[0]
56 | phi_filtered = phi[:, idx_keep] # filter out always-zeros
57 | corr_matrix = np.corrcoef(phi_filtered.T)
58 | corr_matrix -= np.eye(corr_matrix.shape[0])
59 |
60 | idx_to_delete, i, j = [], 0, 0
61 | while i != corr_matrix.shape[0]:
62 | # print(i, corr_matrix.shape, (np.abs(corr_matrix[i]) > corr_threshold).sum())
63 | if (np.abs(corr_matrix[i]) > corr_threshold).sum() > 0:
64 | corr_matrix = np.delete(corr_matrix, (i), axis=0)
65 | corr_matrix = np.delete(corr_matrix, (i), axis=1)
66 | # print('delete', j)
67 | idx_to_delete.append(j)
68 | else:
69 | i += 1
70 | j += 1
71 | assert corr_matrix.shape[0] == corr_matrix.shape[1]
72 | # print(idx_to_delete, idx_keep)
73 | idx_keep = np.delete(idx_keep, [idx_to_delete])
74 | sparsity_rmdup = (phi[:, idx_keep] > 0).sum() / (phi.shape[0] * phi.shape[1])
75 |
76 | n_highly_corr = phi.shape[1] - len(idx_keep)
77 | return sparsity, sparsity_rmdup, n_highly_corr
78 |
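`compute_feature_sparsity` above reports the fraction of strictly positive activations (so a lower value means sparser features), plus a second value computed after dropping always-zero units and units that are highly correlated with another remaining unit. The sketch below reproduces that logic on a small list-of-lists feature matrix without NumPy. The names `pearson` and `feature_sparsity` are hypothetical, random subsampling of units is omitted, and zero-variance columns are given correlation 0 as a simplifying assumption.

```python
import math


def pearson(a, b):
    """Pearson correlation of two equal-length lists (0.0 if either is constant)."""
    n = len(a)
    ma, mb = sum(a) / n, sum(b) / n
    cov = sum((x - ma) * (y - mb) for x, y in zip(a, b))
    va = sum((x - ma) ** 2 for x in a)
    vb = sum((y - mb) ** 2 for y in b)
    return 0.0 if va == 0 or vb == 0 else cov / math.sqrt(va * vb)


def feature_sparsity(phi, corr_threshold=0.95):
    """phi: rows are examples, columns are post-ReLU units."""
    n_ex, n_units = len(phi), len(phi[0])
    cols = [[row[j] for row in phi] for j in range(n_units)]
    sparsity = sum(v > 0 for col in cols for v in col) / (n_ex * n_units)

    # keep only units that fire at least once
    remaining = [j for j in range(n_units) if any(v > 0 for v in cols[j])]
    i = 0
    while i < len(remaining):
        j = remaining[i]
        # drop unit j if it is highly correlated with any other remaining unit
        if any(abs(pearson(cols[j], cols[k])) > corr_threshold
               for k in remaining if k != j):
            remaining.pop(i)
        else:
            i += 1

    n_active = sum(v > 0 for j in remaining for v in cols[j])
    return sparsity, n_active / (n_ex * n_units)


phi = [[1, 2, 0, 0],
       [2, 4, 0, 1],
       [3, 6, 0, 0],
       [0, 0, 0, 2]]
print(feature_sparsity(phi))  # prints "(0.5, 0.3125)"
```

In this toy matrix unit 1 is an exact multiple of unit 0, so one of the pair is removed, and unit 2 never fires; only units 1 and 3 contribute to the second value.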
--------------------------------------------------------------------------------
/deep_nets/utils_train.py:
--------------------------------------------------------------------------------
1 | from logging import lastResort
2 | import torch
3 | import numpy as np
4 | import math
5 | from utils import clamp, get_random_delta, project_lp
6 |
7 |
8 | def get_lr_schedule(lr_schedule_type, n_epochs, lr, warmup_factor=1.0, warmup_exp=1.0):
9 | if lr_schedule_type == 'cyclic':
10 | lr_schedule = lambda epoch: np.interp([epoch], [0, n_epochs * 2 // 5, n_epochs], [0, lr, 0])[0]
11 | elif lr_schedule_type in ['piecewise', 'piecewise_10_100']:
12 | def lr_schedule(t):
13 | """
14 | Following the original ResNet paper (+ warmup for resnet34).
15 | t is the (fractional) number of epochs elapsed, starting from 0.
16 | """
17 | # if 100 epochs in total, then warmup lasts for exactly the first 2 epochs
18 | # if t / n_epochs < 0.02 and model in ['resnet34']:
19 | # return lr_max / 10.
20 | if t / n_epochs < 0.5:
21 | return lr
22 | elif t / n_epochs < 0.75:
23 | return lr / 10.
24 | else:
25 | return lr / 100.
26 | elif lr_schedule_type == 'piecewise_01epochs':
27 | def lr_schedule(t):
28 | if warmup_exp > 1.0:
29 | if t / n_epochs < 0.1:
30 | return warmup_exp**(t/n_epochs*100) * lr
31 | elif t / n_epochs < 0.9:
32 | return warmup_exp**(0.1*100) * lr / 10.
33 | else:
34 | return warmup_exp**(0.1*100) * lr / 100.
35 | elif warmup_exp < 1.0:
36 | if t / n_epochs < 0.1:
37 | return (1 + (t/n_epochs*100)**warmup_exp) * lr
38 | elif t / n_epochs < 0.9:
39 | return (1 + (0.1*100)**warmup_exp) * lr / 10.
40 | else:
41 | return (1 + (0.1*100)**warmup_exp) * lr / 100.
42 | else:
43 | if t / n_epochs < 0.1:
44 | return np.interp([t], [0, 0.5 * n_epochs], [lr, warmup_factor*lr])[0] # note: we interpolate up to 0.5*n_epochs to be compatible with toy ReLU net experiments
45 | elif t / n_epochs < 0.9:
46 | return 0.1 / 0.5 * warmup_factor*lr / 10.
47 | else:
48 | return 0.1 / 0.5 * warmup_factor*lr / 100.
49 | elif lr_schedule_type == 'piecewise_03epochs':
50 | def lr_schedule(t):
51 | if warmup_exp > 1.0:
52 | if t / n_epochs < 0.3:
53 | return warmup_exp**(t/n_epochs*100) * lr
54 | elif t / n_epochs < 0.9:
55 | return warmup_exp**(0.3*100) * lr / 10.
56 | else:
57 | return warmup_exp**(0.3*100) * lr / 100.
58 | elif warmup_exp < 1.0:
59 | if t / n_epochs < 0.3:
60 | return (1 + (t/n_epochs*100)**warmup_exp) * lr
61 | elif t / n_epochs < 0.9:
62 | return (1 + (0.3*100)**warmup_exp) * lr / 10.
63 | else:
64 | return (1 + (0.3*100)**warmup_exp) * lr / 100.
65 | else:
66 | if t / n_epochs < 0.3:
67 | return np.interp([t], [0, 0.5 * n_epochs], [lr, warmup_factor*lr])[0] # note: we interpolate up to 0.5*n_epochs to be compatible with toy ReLU net experiments
68 | elif t / n_epochs < 0.9:
69 | return 0.3 / 0.5 * warmup_factor*lr / 10.
70 | else:
71 | return 0.3 / 0.5 * warmup_factor*lr / 100.
72 | elif lr_schedule_type == 'piecewise_05epochs':
73 | def lr_schedule(t):
74 | if warmup_exp > 1.0:
75 | if t / n_epochs < 0.5:
76 | return warmup_exp**(t/n_epochs*100) * lr
77 | elif t / n_epochs < 0.9:
78 | return warmup_exp**(0.5*100) * lr / 10.
79 | else:
80 | return warmup_exp**(0.5*100) * lr / 100.
81 | elif warmup_exp < 1.0:
82 | if t / n_epochs < 0.5:
83 | return (1 + (t/n_epochs*100)**warmup_exp) * lr
84 | elif t / n_epochs < 0.9:
85 | return (1 + (0.5*100)**warmup_exp) * lr / 10.
86 | else:
87 | return (1 + (0.5*100)**warmup_exp) * lr / 100.
88 | else:
89 | if t / n_epochs < 0.5:
90 | return np.interp([t], [0, 0.5 * n_epochs], [lr, warmup_factor*lr])[0] # note: we interpolate up to 0.5*n_epochs to be compatible with toy ReLU net experiments
91 | elif t / n_epochs < 0.9:
92 | return 0.5 / 0.5 * warmup_factor*lr / 10.
93 | else:
94 | return 0.5 / 0.5 * warmup_factor*lr / 100.
95 | elif lr_schedule_type == 'cosine':
96 | # cosine LR schedule without restarts like in the SAM paper
97 | # (as in the JAX implementation used in SAM https://flax.readthedocs.io/en/latest/_modules/flax/training/lr_schedule.html#create_cosine_learning_rate_schedule)
98 | return lambda epoch: lr * (0.5 + 0.5*math.cos(math.pi * epoch / n_epochs))
99 | elif lr_schedule_type == 'inverted_cosine':
100 | return lambda epoch: lr - lr * (0.5 + 0.5*math.cos(math.pi * epoch / n_epochs))
101 | elif lr_schedule_type == 'constant':
102 | return lambda epoch: lr
103 | else:
104 | raise ValueError('wrong lr_schedule_type')
105 | return lr_schedule
106 |
107 |
108 | def change_bn_mode(model, bn_train):
109 | for module in model.modules():
110 | if isinstance(module, torch.nn.BatchNorm2d):
111 | if bn_train:
112 | module.train()
113 | else:
114 | module.eval()
115 |
116 |
117 | def moving_average(net1, net2, alpha=0.999):
118 | for param1, param2 in zip(net1.parameters(), net2.parameters()):
119 | param1.data *= (1.0 - alpha)
120 | param1.data += param2.data * alpha
121 |
122 |
123 | def _check_bn(module, flag):
124 | if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm):
125 | flag[0] = True
126 |
127 |
128 | def check_bn(model):
129 | flag = [False]
130 | model.apply(lambda module: _check_bn(module, flag))
131 | return flag[0]
132 |
133 |
134 | def reset_bn(module):
135 | if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm):
136 | module.running_mean = torch.zeros_like(module.running_mean)
137 | module.running_var = torch.ones_like(module.running_var)
138 |
139 |
140 | def _get_momenta(module, momenta):
141 | if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm):
142 | momenta[module] = module.momentum
143 |
144 |
145 | def _set_momenta(module, momenta):
146 | if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm):
147 | module.momentum = momenta[module]
148 |
149 |
150 | def bn_update(loader, model):
151 |     """
152 |     Updates the BatchNorm buffers (if any).
153 |     Performs one epoch over the train dataset to estimate the average buffer statistics.
154 |     :param loader: train dataset loader used to estimate the buffer averages.
155 |     :param model: model being updated.
156 |     :return: None
157 |     """
158 | if not check_bn(model):
159 | return
160 | with torch.no_grad():
161 | model.train()
162 | momenta = {}
163 | model.apply(reset_bn)
164 | model.apply(lambda module: _get_momenta(module, momenta))
165 | n = 0
166 | for x, _, _, _, _ in loader:
167 | x = x.cuda(non_blocking=True)
168 | b = x.data.size(0)
169 |
170 | momentum = b / (n + b)
171 | for module in momenta.keys():
172 | module.momentum = momentum
173 |
174 | model(x)
175 | n += b
176 |
177 | model.apply(lambda module: _set_momenta(module, momenta))
178 | model.eval()
179 |
180 |
181 |
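The momentum rule `b / (n + b)` in `bn_update` turns the usual exponential moving average into a cumulative average over all samples seen so far. Below is a minimal NumPy sketch of that idea, assuming the standard BatchNorm update `running = (1 - momentum) * running + momentum * batch_stat`. The batch sizes and the three-feature data are hypothetical. Only the running mean is reproduced here; the running variance obtained this way is an average of per-batch variances rather than the variance of the pooled data.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical batches of unequal size with 3 features each.
batches = [rng.normal(size=(b, 3)) for b in (8, 8, 4)]

running_mean = np.zeros(3)  # reset to zero, as reset_bn does
n = 0
for x in batches:
    b = x.shape[0]
    momentum = b / (n + b)  # same rule as in bn_update
    # BatchNorm-style buffer update with the per-batch momentum.
    running_mean = (1.0 - momentum) * running_mean + momentum * x.mean(axis=0)
    n += b

# Mean over all samples pooled together, for comparison.
full_mean = np.concatenate(batches).mean(axis=0)
```

After the loop, `running_mean` matches `full_mean` up to floating-point error, even though the batches have different sizes.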
--------------------------------------------------------------------------------
/diag_nets.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def loss(X, y_hat, y):
5 | return np.linalg.norm(y_hat - y)**2 / (2*X.shape[0])
6 |
7 |
8 |
9 | def valley_projection(X, y, gamma, u0, v0, num_iter):
10 | u, v = u0, v0
11 | for i in range(num_iter):
12 | y_hat = X @ (u * v)
13 | Xerr = X.T@(y_hat - y)
14 | grad_u, grad_v = (Xerr * v) / X.shape[0], (Xerr * u) / X.shape[0]
15 |
16 | u = u - gamma * grad_u
17 | v = v - gamma * grad_v
18 |
19 | return u, v
20 |
21 |
22 | def GD(X, y, X_test, y_test, gamma, u0, v0, iters_loss, num_iter, thresholds=[-1], decays=[2], wd=0, balancedness=0, normalized_gd=False, weight_avg=0.0, return_all_params=False, valley_project=False):
23 | train_losses, test_losses = [], []
24 | us, vs = [], []
25 | u, v, u_avg, v_avg = u0, v0, u0, v0
26 | for i in range(num_iter):
27 | if i in iters_loss:
28 | if valley_project:
29 | u_avg, v_avg = valley_projection(X, y, 0.5, u, v, 2000)
30 | train_losses += [loss(X, X @ (u_avg * v_avg), y)]
31 | test_losses += [loss(X_test, X_test @ (u_avg * v_avg), y_test)]
32 | us, vs = us + [u_avg], vs + [v_avg]
33 |
34 | y_hat = X @ (u * v)
35 | Xerr = X.T@(y_hat - y)
36 | grad_u, grad_v = (Xerr * v) / X.shape[0], (Xerr * u) / X.shape[0]
37 | if normalized_gd and (np.linalg.norm(grad_u) > 0 and np.linalg.norm(grad_v) > 0):
38 | grad_u, grad_v = grad_u / np.linalg.norm(grad_u), grad_v / np.linalg.norm(grad_v)
39 |
40 | u = u - gamma * grad_u - wd*u - balancedness*u*(2*(np.abs(u)>np.abs(v))-1)
41 | v = v - gamma * grad_v - wd*v - balancedness*v*(2*(np.abs(v)>np.abs(u))-1)
42 | u_avg, v_avg = weight_avg*u_avg + (1-weight_avg)*u, weight_avg*v_avg + (1-weight_avg)*v
43 |
44 | if i in thresholds:
45 | for threshold, decay in zip(thresholds, decays):
46 | if i == threshold:
47 | gamma = gamma / decay
48 |
49 | if return_all_params:
50 | return train_losses, test_losses, u, v, us, vs
51 | else:
52 | return train_losses, test_losses, u, v
53 |
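As a rough illustration of the plain update inside `GD` (no weight decay, balancedness term, normalization, or averaging), the sketch below runs full-batch gradient descent on a diagonal linear network for a noiseless sparse regression problem. It is self-contained and not part of the repository: the problem sizes, the initialization scale `alpha`, and the step size `gamma` are hypothetical values picked so that the iteration is stable.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, k = 40, 100, 3  # hypothetical sizes: samples, features, non-zero coefficients
X = rng.normal(size=(n, d))
beta_star = np.zeros(d)
beta_star[:k] = 1.0
y = X @ beta_star  # noiseless labels from a sparse ground truth


def loss(u, v):
    # Same loss as in diag_nets.py: squared error of the predictor X @ (u * v).
    return np.linalg.norm(X @ (u * v) - y) ** 2 / (2 * n)


alpha, gamma = 0.1, 0.01  # hypothetical init scale and step size
u, v = alpha * np.ones(d), np.zeros(d)
loss_init = loss(u, v)
for _ in range(2000):
    Xerr = X.T @ (X @ (u * v) - y)
    grad_u, grad_v = Xerr * v / n, Xerr * u / n
    u, v = u - gamma * grad_u, v - gamma * grad_v
loss_final = loss(u, v)
```

With a step size this small, the training loss decreases from `loss_init`. The large-step-size regime studied in the paper needs a larger `gamma` together with the `thresholds` and `decays` arguments of `GD`.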
54 |
55 | def SGD(X, y, X_test, y_test, gamma, u0, v0, iters_loss, num_iter, thresholds=[-1], decays=[2], weight_avg=0.0, return_all_params=False, valley_project=False):
56 | train_losses, test_losses = [], []
57 | us, vs = [], []
58 | u, v, u_avg, v_avg = u0, v0, u0, v0
59 | for i in range(num_iter):
60 | if i in iters_loss:
61 | if valley_project:
62 | u_avg, v_avg = valley_projection(X, y, 0.5, u, v, 2000)
63 | train_losses += [loss(X, X @ (u_avg * v_avg), y)]
64 | test_losses += [loss(X_test, X_test @ (u_avg * v_avg), y_test)]
65 | us, vs = us + [u_avg], vs + [v_avg]
66 |
67 | i_t = np.random.randint(X.shape[0])
68 | error = X[i_t] @ (u * v) - y[i_t]
69 | Xerr = error * X[i_t]
70 | grad_u, grad_v = Xerr * v, Xerr * u
71 |
72 | u = u - gamma * grad_u # gradient step
73 | v = v - gamma * grad_v # gradient step
74 | u_avg, v_avg = weight_avg*u_avg + (1-weight_avg)*u, weight_avg*v_avg + (1-weight_avg)*v
75 |
76 | if i in thresholds:
77 | for threshold, decay in zip(thresholds, decays):
78 | if i == threshold:
79 | gamma = gamma / decay
80 |
81 | if return_all_params:
82 | return train_losses, test_losses, u, v, us, vs
83 | else:
84 | return train_losses, test_losses, u, v
85 |
86 |
87 | def n_SAM_GD(X, y, X_test, y_test, gamma, u0, v0, iters_loss, num_iter, rho):
88 | train_losses, test_losses = [], []
89 | u, v = u0, v0
90 | for i in range(num_iter):
91 | y_hat = X @ (u * v)
92 | Xerr = X.T @ (y_hat - y)
93 | u_sam, v_sam = u + rho * (Xerr * v) / X.shape[0], v + rho * (Xerr * u) / X.shape[0]
94 |
95 | Xerr_sam = X.T @ (X @ (u_sam * v_sam) - y)
96 | grad_u_sam, grad_v_sam = (Xerr_sam * v_sam) / X.shape[0], (Xerr_sam * u_sam) / X.shape[0]
97 |
98 | u = u - gamma * grad_u_sam # gradient step
99 | v = v - gamma * grad_v_sam # gradient step
100 |
101 | if i in iters_loss:
102 | train_losses += [loss(X, X @ (u * v), y)]
103 | test_losses += [loss(X_test, X_test @ (u * v), y_test)]
104 |
105 | return train_losses, test_losses, u, v
106 |
107 |
108 | def one_SAM_GD(X, y, X_test, y_test, gamma, u0, v0, iters_loss, num_iter, rho, loss_derivative_only=False):
109 | train_losses, test_losses = [], []
110 | u, v = u0, v0
111 | for i in range(num_iter):
112 | y_hat = X @ (u * v)
113 | r = y_hat - y
114 | grad_u_sam, grad_v_sam = np.zeros_like(u), np.zeros_like(v)
115 | for k in range(X.shape[0]):
116 | u_sam_k = u + 2 * rho * r[k] * X[k] * v
117 | v_sam_k = v + 2 * rho * r[k] * X[k] * u
118 | if not loss_derivative_only:
119 | grad_u_sam += ((X[k] * u_sam_k * v_sam_k).sum() - y[k]) * v_sam_k * X[k] / X.shape[0]
120 | grad_v_sam += ((X[k] * u_sam_k * v_sam_k).sum() - y[k]) * u_sam_k * X[k] / X.shape[0]
121 | else:
122 | grad_u_sam += ((X[k] * u_sam_k * v_sam_k).sum() - y[k]) * v * X[k] / X.shape[0]
123 | grad_v_sam += ((X[k] * u_sam_k * v_sam_k).sum() - y[k]) * u * X[k] / X.shape[0]
124 |
125 | u = u - gamma * grad_u_sam # gradient step
126 | v = v - gamma * grad_v_sam # gradient step
127 |
128 | if i in iters_loss:
129 | train_losses += [loss(X, X @ (u * v), y)]
130 | test_losses += [loss(X_test, X_test @ (u * v), y_test)]
131 |
132 | return train_losses, test_losses, u, v
133 |
134 |
135 | def dln_hessian(u, v, X, y, normalized=False):
136 |     beta = u * v
137 |     if normalized:
138 |         u = np.abs(beta)**0.5 * np.sign(u)
139 |         v = np.abs(beta)**0.5 * np.sign(v)
140 |     # print(beta[:10], u[:10], v[:10])
141 |     n, d = X.shape
142 |     H = np.zeros([2*d, 2*d])  # Hessian w.r.t. the stacked vector [u, v]
143 |     H[:d, :d] = np.diag(v) @ (X.T@X) @ np.diag(v)  # d^2 L / (du du)
144 |     H[d:, :d] = np.diag(u) @ (X.T@X) @ np.diag(v) + np.diag(X.T@(X@(u*v) - y))  # d^2 L / (dv du)
145 |     H[:d, d:] = np.diag(v) @ (X.T@X) @ np.diag(u) + np.diag(X.T@(X@(u*v) - y))  # d^2 L / (du dv)
146 |     H[d:, d:] = np.diag(u) @ (X.T@X) @ np.diag(u)  # d^2 L / (dv dv)
147 |     return H / n
148 |
149 |
150 | def dln_hessian_eigs(u, v, X, y, normalized=False):
151 |     hess = dln_hessian(u, v, X, y, normalized)
152 |     eigs = np.linalg.eigvalsh(hess)  # the Hessian is symmetric: real eigenvalues, returned in ascending order
153 |     return eigs
154 |
155 |
156 | def dln_grad_loss(u, v, X, y):
157 | Xerr = X.T @ (X @ (u * v) - y)
158 | grad_u, grad_v = (Xerr * v) / X.shape[0], (Xerr * u) / X.shape[0]
159 | return np.concatenate([grad_u, grad_v])
160 |
161 |
162 | def dln_avg_individual_grad_loss(u, v, X, y, residuals=True):
163 |     n = X.shape[0]
164 |     r = X @ (u * v) - y
165 |
166 |     if not residuals:
167 |         r = np.ones_like(r)
168 |
169 |     total = 0  # average squared norm of the per-sample gradients (avoids shadowing the builtin `sum`)
170 |     for i in range(n):
171 |         total += np.sum((r[i] * X[i] * v)**2) / n
172 |         total += np.sum((r[i] * X[i] * u)**2) / n
173 |
174 |     return total
175 |
176 |
177 | def compute_grad_matrix_ranks(us, vs, X, l0_threshold_grad_matrix=0.0001):
178 | grad_matrix_ranks = []
179 | for u, v in zip(us, vs):
180 | grad_matrix = np.hstack([X * u, X * v])
181 | svals = np.linalg.svd(grad_matrix)[1]
182 | rank = (svals / svals[0] > l0_threshold_grad_matrix).sum()
183 | grad_matrix_ranks.append(rank)
184 | return grad_matrix_ranks
185 |
186 |
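The block structure of the Hessian of the diagonal-network loss can be sanity-checked against central finite differences of the gradient. The sketch below is self-contained and re-derives the four blocks directly from the loss `||X(u*v) - y||^2 / (2n)`. The function names `grad` and `hessian`, the problem sizes, and the random data are assumptions made for illustration.

```python
import numpy as np


def grad(u, v, X, y):
    # Gradient of ||X(u*v) - y||^2 / (2n) w.r.t. the stacked vector [u, v].
    n = X.shape[0]
    Xerr = X.T @ (X @ (u * v) - y)
    return np.concatenate([Xerr * v, Xerr * u]) / n


def hessian(u, v, X, y):
    # Analytic Hessian in the same [u, v] ordering.
    n, d = X.shape
    G = X.T @ X
    R = np.diag(X.T @ (X @ (u * v) - y))
    H = np.zeros((2 * d, 2 * d))
    H[:d, :d] = np.diag(v) @ G @ np.diag(v)      # d^2 L / (du du)
    H[:d, d:] = np.diag(v) @ G @ np.diag(u) + R  # d^2 L / (du dv)
    H[d:, :d] = np.diag(u) @ G @ np.diag(v) + R  # d^2 L / (dv du)
    H[d:, d:] = np.diag(u) @ G @ np.diag(u)      # d^2 L / (dv dv)
    return H / n


rng = np.random.default_rng(0)
n, d = 20, 5  # hypothetical sizes
X, y = rng.normal(size=(n, d)), rng.normal(size=n)
u, v = rng.normal(size=d), rng.normal(size=d)

# Central finite differences of the gradient, one coordinate at a time.
eps = 1e-6
w = np.concatenate([u, v])
H_fd = np.zeros((2 * d, 2 * d))
for j in range(2 * d):
    e = np.zeros(2 * d)
    e[j] = eps
    wp, wm = w + e, w - e
    H_fd[:, j] = (grad(wp[:d], wp[d:], X, y) - grad(wm[:d], wm[d:], X, y)) / (2 * eps)

H = hessian(u, v, X, y)
max_err = np.abs(H - H_fd).max()
```

A small `max_err` indicates that the analytic blocks agree with the numerical derivative. Because `H` is symmetric, `np.linalg.eigvalsh` is a natural choice for its spectrum.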
--------------------------------------------------------------------------------
/fc_nets.py:
--------------------------------------------------------------------------------
1 | from unittest import TestResult
2 | import numpy as np
3 | import torch
4 | import torch.nn.functional as F
5 | import copy
6 |
7 |
8 | class FCNet2Layers(torch.nn.Module):
9 | def __init__(self, n_feature, n_hidden, n_output=1, biases=[False, False]):
10 | super(FCNet2Layers, self).__init__()
11 | self.layer1 = torch.nn.Linear(n_feature, n_hidden, bias=biases[0])
12 | self.layer2 = torch.nn.Linear(n_hidden, n_output, bias=biases[1])
13 |
14 | def init_gaussian(self, init_scales):
15 | self.layer1.weight.data = init_scales[0] * torch.randn_like(self.layer1.weight)
16 | self.layer2.weight.data = init_scales[1] * torch.randn_like(self.layer2.weight)
17 |
18 | def init_gaussian_clsf(self, init_scales):
19 | self.layer1.weight.data = init_scales[0] * torch.randn_like(self.layer1.weight)
20 | self.layer2.weight.data = ((torch.randn_like(self.layer2.weight) > 0).float() - 0.5) * 2 * (self.layer1.weight.data**2).sum(1)**0.5
21 |
22 | def init_blanc_et_al(self, init_scales):
23 | self.layer1.weight.data = 2.5 * (-1 + 2 * torch.round(torch.rand_like(self.layer1.weight))) * init_scales[0]
24 | self.layer1.bias.data = 2.5 * init_scales[0] * torch.randn_like(self.layer1.bias)
25 | self.layer2.weight.data = 4.0 * init_scales[1] * torch.randn_like(self.layer2.weight)
26 |
27 | def features(self, x, normalize=True, scaled=False):
28 | x = F.relu(self.layer1(x))
29 | if scaled:
30 | x = x * self.layer2.weight
31 | if normalize:
32 | x /= (x**2).sum(1, keepdim=True)**0.5
33 | x[torch.isnan(x)] = 0.0
34 | return x.data.numpy()
35 |
36 | def feature_sparsity(self, X, corr_threshold=0.99):
37 | phi = self.features(X)
38 | idx_keep = np.where((phi > 0.0).sum(0) > 0)[0]
39 | phi_filtered = phi[:, idx_keep] # filter out zeros
40 | corr_matrix = np.corrcoef(phi_filtered.T)
41 | corr_matrix -= np.eye(corr_matrix.shape[0])
42 |
43 | idx_to_delete, i, j = [], 0, 0
44 | while i != corr_matrix.shape[0]:
45 | # print(i, corr_matrix.shape, (np.abs(corr_matrix[i]) > corr_threshold).sum())
46 | if (np.abs(corr_matrix[i]) > corr_threshold).sum() > 0:
47 | corr_matrix = np.delete(corr_matrix, (i), axis=0)
48 | corr_matrix = np.delete(corr_matrix, (i), axis=1)
49 | # print('delete', j)
50 | idx_to_delete.append(j)
51 | else:
52 | i += 1
53 | j += 1
54 | assert corr_matrix.shape[0] == corr_matrix.shape[1]
55 | # print(idx_to_delete, idx_keep)
56 | idx_keep = np.delete(idx_keep, [idx_to_delete])
57 | sparsity = (phi[:, idx_keep] != 0).sum() / (phi.shape[0] * phi.shape[1])
58 |
59 | return sparsity
60 |
61 | def forward(self, x):
62 | z = F.relu(self.layer1(x))
63 | z = self.layer2(z)
64 | return z
65 |
66 |
67 | class FCNet(torch.nn.Module):
68 | def __init__(self, n_feature, n_hidden, biases=True):
69 | super(FCNet, self).__init__()
70 | self.n_hidden = [n_feature] + n_hidden + [1] # add the number of input and output units
71 | self.biases = biases
72 | self.layers = torch.nn.ModuleList()
73 | for i in range(len(self.n_hidden) - 1):
74 | self.layers.append(torch.nn.Linear(self.n_hidden[i], self.n_hidden[i+1], bias=self.biases))
75 |
76 | def init_gaussian(self, init_scales):
77 | for i in range(len(self.n_hidden) - 1):
78 | self.layers[i].weight.data = init_scales[i] * torch.randn_like(self.layers[i].weight)
79 |
80 | def forward(self, x):
81 | for i in range(len(self.n_hidden) - 2):
82 | x = F.relu(self.layers[i](x))
83 | x = self.layers[-1](x)
84 | return x
85 |
86 | def features(self, x, normalize=True, scaled=False, n_hidden_to_take=-1):
87 | for i in range(n_hidden_to_take if n_hidden_to_take > 0 else len(self.n_hidden) - 2):
88 | x = F.relu(self.layers[i](x))
89 | if scaled and n_hidden_to_take in [-1, len(self.n_hidden)]:
90 | x = x * self.layers[-1].weight
91 | if normalize:
92 | x /= (x**2).sum(1, keepdim=True)**0.5
93 | x[torch.isnan(x)] = 0.0
94 | return x.data.numpy()
95 |
96 | def feature_sparsity(self, X, n_hidden_to_take=-1, corr_threshold=0.99):
97 | phi = self.features(X, n_hidden_to_take=n_hidden_to_take)
98 | idx_keep = np.where((phi > 0.0).sum(0) > 0)[0]
99 | phi_filtered = phi[:, idx_keep] # filter out zeros
100 | corr_matrix = np.corrcoef(phi_filtered.T)
101 | corr_matrix -= np.eye(corr_matrix.shape[0])
102 |
103 | idx_to_delete, i, j = [], 0, 0
104 | while i != corr_matrix.shape[0]:
105 | # print(i, corr_matrix.shape, (np.abs(corr_matrix[i]) > corr_threshold).sum())
106 | if (np.abs(corr_matrix[i]) > corr_threshold).sum() > 0:
107 | corr_matrix = np.delete(corr_matrix, (i), axis=0)
108 | corr_matrix = np.delete(corr_matrix, (i), axis=1)
109 | # print('delete', j)
110 | idx_to_delete.append(j)
111 | else:
112 | i += 1
113 | j += 1
114 | assert corr_matrix.shape[0] == corr_matrix.shape[1]
115 | # print(idx_to_delete, idx_keep)
116 | idx_keep = np.delete(idx_keep, [idx_to_delete])
117 | sparsity = (phi[:, idx_keep] != 0).sum() / (phi.shape[0] * phi.shape[1])
118 |
119 | return sparsity
120 |
121 | def n_highly_corr(self, X, n_hidden_to_take=-1, corr_threshold=0.99):
122 | phi = self.features(X, n_hidden_to_take=n_hidden_to_take)
123 | idx_keep = np.where((phi > 0.0).sum(0) > 0)[0]
124 | phi_filtered = phi[:, idx_keep] # filter out zeros
125 | corr_matrix = np.corrcoef(phi_filtered.T)
126 | corr_matrix -= np.eye(corr_matrix.shape[0])
127 |
128 | idx_to_delete, i, j = [], 0, 0
129 | while i != corr_matrix.shape[0]:
130 | # print(i, corr_matrix.shape, (np.abs(corr_matrix[i]) > corr_threshold).sum())
131 | if (np.abs(corr_matrix[i]) > corr_threshold).sum() > 0:
132 | corr_matrix = np.delete(corr_matrix, (i), axis=0)
133 | corr_matrix = np.delete(corr_matrix, (i), axis=1)
134 | # print('delete', j)
135 | idx_to_delete.append(j)
136 | else:
137 | i += 1
138 | j += 1
139 | assert corr_matrix.shape[0] == corr_matrix.shape[1]
140 | # print(idx_to_delete, idx_keep)
141 | idx_keep = np.delete(idx_keep, [idx_to_delete])
142 | sparsity = (phi[:, idx_keep] != 0).sum() / (phi.shape[0] * phi.shape[1])
143 |
144 | return phi.shape[1] - len(idx_keep)
145 |
146 |
147 | def moving_average(net, net_avg, weight_avg):
148 | for param, param_avg in zip(net.parameters(), net_avg.parameters()):
149 | param_avg.data = weight_avg*param_avg.data + (1-weight_avg)*param.data
150 |
151 |
152 | def train_fc_net(X, y, X_test, y_test, gamma, batch_size, net, iters_loss, num_iter, thresholds=[-1], decays=[-1], iters_percentage_linear_warmup=0.0, gamma_warmup_factor_max=1.0, warmup_exponent=1.0, weight_avg=0.0, clsf=False, gauss_ln_scale=0.0):
153 | assert iters_percentage_linear_warmup <= decays[0], 'we should decay the step size only after warmup'
154 | train_losses, test_losses, nets_avg = [], [], []
155 | net, net_avg = copy.deepcopy(net), copy.deepcopy(net)
156 |
157 | loss_f = (lambda y_pred, y: torch.mean(torch.log(1 + torch.exp(-y_pred * y)))) if clsf else (lambda y_pred, y: torch.mean((y_pred - y)**2))
158 | # loss_f = lambda y_pred, y: torch.mean(torch.log(1 + torch.exp(-y_pred * y)))
159 |
160 | optimizer = torch.optim.SGD(net.parameters(), lr=gamma) #, momentum=0.9)
161 | for i in range(num_iter):
162 | if i in iters_loss:
163 | train_losses += [loss_f(net_avg(X), y)]
164 | test_losses += [loss_f(net_avg(X_test), y_test)]
165 | nets_avg.append(copy.deepcopy(net_avg))
166 | if torch.isnan(train_losses[-1]):
167 | return train_losses, test_losses, nets_avg
168 |
169 | if i <= int(iters_percentage_linear_warmup * num_iter) and int(iters_percentage_linear_warmup * num_iter) > 0:
170 | optimizer.param_groups[0]['lr'] = gamma + (gamma_warmup_factor_max - 1) * gamma * (i / int(num_iter))**warmup_exponent
171 | # optimizer.param_groups[0]['lr'] = gamma + gamma * gamma_warmup_factor_max * (i / int(iters_percentage_linear_warmup * num_iter))**warmup_exponent
172 |
173 | if i in thresholds:
174 | for threshold, decay in zip(thresholds, decays):
175 | if i == threshold:
176 | optimizer.param_groups[0]['lr'] /= decay
177 |
178 | indices = np.random.choice(X.shape[0], size=batch_size, replace=False)
179 | batch_x, batch_y = X[indices], y[indices]
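180 |         if gauss_ln_scale > 0.0:  # label noise with schedule (note: supports only one threshold at the moment)
181 |             batch_y += torch.randn_like(batch_y) * gauss_ln_scale / (decays[0] if i > thresholds[0] else 1.0)  # decays[0] instead of the loop variable `decay`, which is undefined until a threshold has been reached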
182 | loss = loss_f(net(batch_x), batch_y)
183 |
184 | optimizer.zero_grad()
185 | loss.backward()
186 | optimizer.step()
187 |
188 | moving_average(net, net_avg, weight_avg)
189 |
190 | return train_losses, test_losses, nets_avg
191 |
192 |
193 | def compute_grad_matrix(net, X):
194 | optimizer = torch.optim.SGD(net.parameters(), lr=0.0)
195 | grad_matrix_list = []
196 | for i in range(X.shape[0]):
197 | h = net(X[[i]])
198 | optimizer.zero_grad()
199 | h.backward()
200 |
201 | grad_total_list = []
202 | for param in net.parameters():
203 | grad_total_list.append(param.grad.flatten().data.numpy())
204 | grad_total = np.concatenate(grad_total_list)
205 | grad_matrix_list.append(grad_total)
206 |
207 | grad_matrix = np.vstack(grad_matrix_list)
208 | return grad_matrix
209 |
210 |
211 | def compute_grad_matrix_ranks(nets, X, l0_threshold_grad_matrix=0.0001):
212 | n_params = sum([np.prod(param.shape) for param in nets[-1].parameters()])
213 | X_eval = X[:n_params]
214 | grad_matrix_ranks = []
215 | for net in nets:
216 | svals = np.linalg.svd(compute_grad_matrix(net, X_eval))[1]
217 | rank = (svals / svals[0] > l0_threshold_grad_matrix).sum()
218 | grad_matrix_ranks.append(rank)
219 | return grad_matrix_ranks
220 |
221 |
--------------------------------------------------------------------------------
/images/fig1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tml-epfl/sgd-sparse-features/656cb0d9e4e1cdd688073841f145fe9f94703d7c/images/fig1.png
--------------------------------------------------------------------------------
/images/twitter.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tml-epfl/sgd-sparse-features/656cb0d9e4e1cdd688073841f145fe9f94703d7c/images/twitter.gif
--------------------------------------------------------------------------------
/images/twitter.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tml-epfl/sgd-sparse-features/656cb0d9e4e1cdd688073841f145fe9f94703d7c/images/twitter.mp4
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from fc_nets import FCNet2Layers, FCNet, compute_grad_matrix
4 |
5 |
6 | def get_iters_eval(n_iter_power, x_log_scale, n_iters_first=101, n_iters_next=151):
7 | num_iter = int(10**n_iter_power) + 1
8 |
9 | iters_loss_first = np.array(range(100))
10 | if x_log_scale:
11 | iters_loss_next = np.unique(np.round(np.logspace(0, n_iter_power, n_iters_first)))
12 | else:
13 | iters_loss_next = np.unique(np.round(np.linspace(0, num_iter, n_iters_next)))[:-1]
14 | iters_loss = np.unique(np.concatenate((iters_loss_first, iters_loss_next)))
15 |
16 | return num_iter, iters_loss
17 |
18 |
19 | def get_data_two_layer_relu_net(n, d, m_teacher, init_scales_teacher, seed):
20 | np.random.seed(seed + 1)
21 | torch.manual_seed(seed + 1)
22 |
23 | n_test = 1000
24 | H = np.eye(d)
25 | X = torch.tensor(np.random.multivariate_normal(np.zeros(d), H, n)).float()
26 | X = X / torch.sum(X**2, 1, keepdim=True)**0.5
27 | X_test = torch.tensor(np.random.multivariate_normal(np.zeros(d), H, n_test)).float()
28 | X_test = X_test / torch.sum(X_test**2, 1, keepdim=True)**0.5
29 |
30 | # generate ground truth labels
31 | with torch.no_grad():
32 | net_teacher = FCNet2Layers(n_feature=d, n_hidden=m_teacher)
33 | net_teacher.init_gaussian(init_scales_teacher)
34 | net_teacher.layer1.weight.data = net_teacher.layer1.weight.data / torch.sum((net_teacher.layer1.weight.data)**2, 1, keepdim=True)**0.5
35 | net_teacher.layer2.weight.data = torch.sign(net_teacher.layer2.weight.data)
36 |
37 | y, y_test = net_teacher(X), net_teacher(X_test)
38 |
39 | print('y', y[:20, 0])
40 |
41 | return X, y, X_test, y_test, net_teacher
42 |
43 |
44 | def get_data_multi_layer_relu_net(n, d, m_teacher, init_scales_teacher, seed):
45 | np.random.seed(seed + 1)
46 | torch.manual_seed(seed + 1)
47 |
48 | n_test = 1000
49 | H = np.eye(d)
50 | X = torch.tensor(np.random.multivariate_normal(np.zeros(d), H, n)).float()
51 | X = X / torch.sum(X**2, 1, keepdim=True)**0.5
52 | X_test = torch.tensor(np.random.multivariate_normal(np.zeros(d), H, n_test)).float()
53 | X_test = X_test / torch.sum(X_test**2, 1, keepdim=True)**0.5
54 |
55 | # generate ground truth labels
56 | with torch.no_grad():
57 | net_teacher = FCNet(n_feature=d, n_hidden=m_teacher)
58 | net_teacher.init_gaussian(init_scales_teacher)
59 | y, y_test = net_teacher(X), net_teacher(X_test)
60 | print('y:', y[:, 0])
61 |
62 | return X, y, X_test, y_test, net_teacher
63 |
64 |
65 | def effective_rank(v):
66 | v = v[v != 0]
67 | v /= v.sum()
68 | return -(v * np.log(v)).sum()
69 |
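`effective_rank` returns the Shannon entropy of the normalized non-zero entries of `v`, so `k` equal values give `log(k)` and a single non-zero value gives 0. The following sketch reproduces the computation under the hypothetical name `spectral_entropy`. It converts the input to a float array first so that plain lists and integer arrays work as well.

```python
import numpy as np


def spectral_entropy(svals):
    # Same computation as effective_rank, applied to a float copy of the input.
    p = np.asarray(svals, dtype=float)
    p = p[p != 0]
    p = p / p.sum()
    return -(p * np.log(p)).sum()


h_uniform = spectral_entropy([1.0, 1.0, 1.0, 1.0])  # four equal values
h_single = spectral_entropy([5, 0, 0])              # one non-zero value
```

Exponentiating the result, `np.exp(h_uniform)`, gives the more common "effective rank" scale, where four equal values count as rank 4.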
70 |
71 | def rm_too_correlated(net, X, V, corr_threshold=0.99):
72 | V = V.T
73 | idx_keep = np.where((V > 0.0).sum(0) > 0)[0]
74 | V_filtered = V[:, idx_keep] # filter out zeros
75 | corr_matrix = np.corrcoef(V_filtered.T)
76 | corr_matrix -= np.eye(corr_matrix.shape[0])
77 |
78 | idx_to_delete, i, j = [], 0, 0
79 | while i != corr_matrix.shape[0]:
80 | if (np.abs(corr_matrix[i]) > corr_threshold).sum() > 0:
81 | corr_matrix = np.delete(corr_matrix, (i), axis=0)
82 | corr_matrix = np.delete(corr_matrix, (i), axis=1)
83 | # print('delete', j)
84 | idx_to_delete.append(j)
85 | else:
86 | i += 1
87 | j += 1
88 | assert corr_matrix.shape[0] == corr_matrix.shape[1]
89 | idx_keep = np.delete(idx_keep, [idx_to_delete])
90 |
91 | return V[:, idx_keep].T
92 |
93 | def compute_grad_matrix_dim(net, X, corr_threshold=0.99):
94 | grad_matrix = compute_grad_matrix(net, X)
95 | grad_matrix_sq_norms = np.sum(grad_matrix**2, 0)
96 | m = 100
97 | v_j = []
98 | for j in range(m):
99 | v_j.append(grad_matrix_sq_norms[[j, m+j, 2*m+j]]) # matrix: w1, w2, w3, w4
100 | V = np.vstack(v_j)
101 |
102 | V_reduced = rm_too_correlated(net, X, V, corr_threshold=corr_threshold)
103 | grad_matrix_dim = V_reduced.shape[0]
104 | return grad_matrix_dim
105 |
106 |
107 | def project(u, v, u0, v0, u1, v1, u2, v2):
108 | u, v = u - u0, v - v0
109 | alpha = (u @ u1 + v @ v1) / (np.sum(u1**2) + np.sum(v1**2))
110 | beta = (u @ u2 + v @ v2) / (np.sum(u2**2) + np.sum(v2**2))
111 | return alpha, beta
112 |
113 |
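`project` computes the coordinates of a point `(u, v)` along two directions `(u1, v1)` and `(u2, v2)` relative to an origin `(u0, v0)`, which is what is needed to place iterates on a 2D loss-surface plot. The sketch below does the same for already-stacked vectors. The name `plane_coords` and the vectors are hypothetical. The recovered coordinates are exact only when the two directions are orthogonal, as they are in this example.

```python
import numpy as np


def plane_coords(w, w0, d1, d2):
    # Hypothetical stacked-vector version of project(): w plays the role of [u, v].
    w = w - w0
    alpha = (w @ d1) / (d1 @ d1)
    beta = (w @ d2) / (d2 @ d2)
    return alpha, beta


w0 = np.array([1.0, -1.0, 0.5, 2.0])   # origin of the plane
d1 = np.array([1.0, 1.0, 0.0, 0.0])    # first direction
d2 = np.array([0.0, 0.0, 1.0, -1.0])   # second direction, orthogonal to d1
alpha, beta = plane_coords(w0 + 2.0 * d1 + 3.0 * d2, w0, d1, d2)
```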
--------------------------------------------------------------------------------