├── .gitignore
├── LICENSE
├── __init__.py
├── absorb_bn.py
├── cross_entropy.py
├── dataset.py
├── functions.py
├── log.py
├── meters.py
├── misc.py
├── mixup.py
├── optim.py
├── param_filter.py
├── quantize.py
├── recorder.py
├── regime.py
└── regularization.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | .hypothesis/
51 | .pytest_cache/
52 |
53 | # Translations
54 | *.mo
55 | *.pot
56 |
57 | # Django stuff:
58 | *.log
59 | local_settings.py
60 | db.sqlite3
61 |
62 | # Flask stuff:
63 | instance/
64 | .webassets-cache
65 |
66 | # Scrapy stuff:
67 | .scrapy
68 |
69 | # Sphinx documentation
70 | docs/_build/
71 | # sphinx-apidoc automatically generated documentation
72 | docs/source/packages
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don’t work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # celery beat schedule file
95 | celerybeat-schedule
96 |
97 | # SageMath parsed files
98 | *.sage.py
99 |
100 | # Environments
101 | .env
102 | .venv
103 | env/
104 | venv/
105 | ENV/
106 | env.bak/
107 | venv.bak/
108 |
109 | # Spyder project settings
110 | .spyderproject
111 | .spyproject
112 |
113 | # Rope project settings
114 | .ropeproject
115 |
116 | # mkdocs documentation
117 | /site
118 |
119 | # mypy
120 | .mypy_cache/
121 | .dmypy.json
122 | dmypy.json
123 |
124 | # Pyre type checker
125 | .pyre/
126 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Elad Hoffer
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eladhoffer/utils.pytorch/7a2bbdb80835af0cc29c79653ee544f8301caa2b/__init__.py
--------------------------------------------------------------------------------
/absorb_bn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import logging
4 |
5 |
6 | def remove_bn_params(bn_module):
7 | bn_module.register_buffer('running_mean', None)
8 | bn_module.register_buffer('running_var', None)
9 | bn_module.register_parameter('weight', None)
10 | bn_module.register_parameter('bias', None)
11 |
12 |
13 | def init_bn_params(bn_module):
14 | bn_module.running_mean.fill_(0)
15 | bn_module.running_var.fill_(1)
16 |
17 |
18 | def absorb_bn(module, bn_module, remove_bn=True, verbose=False):
19 | with torch.no_grad():
20 | w = module.weight
21 | if module.bias is None:
22 | zeros = torch.zeros(module.out_channels,
23 | dtype=w.dtype, device=w.device)
24 | bias = nn.Parameter(zeros)
25 | module.register_parameter('bias', bias)
26 | b = module.bias
27 |
28 | if hasattr(bn_module, 'running_mean'):
29 | b.add_(-bn_module.running_mean)
30 | if hasattr(bn_module, 'running_var'):
31 | invstd = bn_module.running_var.clone().add_(bn_module.eps).pow_(-0.5)
32 | w.mul_(invstd.view(w.size(0), 1, 1, 1))
33 | b.mul_(invstd)
34 |
35 | if remove_bn:
36 | if hasattr(bn_module, 'weight'):
37 | w.mul_(bn_module.weight.view(w.size(0), 1, 1, 1))
38 | b.mul_(bn_module.weight)
39 | if hasattr(bn_module, 'bias'):
40 | b.add_(bn_module.bias)
41 | remove_bn_params(bn_module)
42 | else:
43 | init_bn_params(bn_module)
44 |
45 | if verbose:
46 | logging.info('BN module %s was absorbed into layer %s' %
47 | (bn_module, module))
48 |
49 |
50 | def is_bn(m):
51 | return isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d)
52 |
53 |
54 | def is_absorbing(m):
55 | return isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear)
56 |
57 |
58 | def search_absorb_bn(model, prev=None, remove_bn=True, verbose=False):
59 | with torch.no_grad():
60 | for n, m in model.named_children():
61 | if is_bn(m) and is_absorbing(prev):
62 | absorb_bn(prev, m, remove_bn=remove_bn, verbose=verbose)
63 | if remove_bn:
64 | setattr(model, n, nn.Identity())
65 | search_absorb_bn(m, remove_bn=remove_bn, verbose=verbose)
66 | prev = m
67 |
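68 | # Minimal usage sketch (illustrative, assumes torchvision is installed): fold
69 | # BatchNorm statistics into the preceding convolutions of a ResNet so inference
70 | # can run without separate BN layers; absorbed BN modules become nn.Identity().
71 | if __name__ == '__main__':
72 | from torchvision.models import resnet18
73 | model = resnet18().eval()
74 | search_absorb_bn(model, verbose=True)
75 | out = model(torch.randn(1, 3, 224, 224))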
--------------------------------------------------------------------------------
/cross_entropy.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from .misc import onehot
6 |
7 |
8 | def _is_long(x):
9 | if hasattr(x, 'data'):
10 | x = x.data
11 | return x.dtype == torch.long  # long tensors on any device
12 |
13 |
14 | def cross_entropy(inputs, target, weight=None, ignore_index=-100, reduction='mean',
15 | smooth_eps=None, smooth_dist=None, from_logits=True):
16 | """cross entropy loss, with support for target distributions and label smoothing https://arxiv.org/abs/1512.00567"""
17 | smooth_eps = smooth_eps or 0
18 |
19 | # ordinary log-likelihood - use cross_entropy from nn
20 | if _is_long(target) and smooth_eps == 0:
21 | if from_logits:
22 | return F.cross_entropy(inputs, target, weight, ignore_index=ignore_index, reduction=reduction)
23 | else:
24 | return F.nll_loss(inputs, target, weight, ignore_index=ignore_index, reduction=reduction)
25 |
26 | if from_logits:
27 | # log-softmax of inputs
28 | lsm = F.log_softmax(inputs, dim=-1)
29 | else:
30 | lsm = inputs
31 |
32 | masked_indices = None
33 | num_classes = inputs.size(-1)
34 |
35 | if _is_long(target) and ignore_index >= 0:
36 | masked_indices = target.eq(ignore_index)
37 |
38 | if smooth_eps > 0 and smooth_dist is not None:
39 | if _is_long(target):
40 | target = onehot(target, num_classes).type_as(inputs)
41 | if smooth_dist.dim() < target.dim():
42 | smooth_dist = smooth_dist.unsqueeze(0)
43 | target.lerp_(smooth_dist, smooth_eps)
44 |
45 | if weight is not None:
46 | lsm = lsm * weight.unsqueeze(0)
47 |
48 | if _is_long(target):
49 | eps_sum = smooth_eps / num_classes
50 | eps_nll = 1. - eps_sum - smooth_eps
51 | likelihood = lsm.gather(dim=-1, index=target.unsqueeze(-1)).squeeze(-1)
52 | loss = -(eps_nll * likelihood + eps_sum * lsm.sum(-1))
53 | else:
54 | loss = -(target * lsm).sum(-1)
55 |
56 | if masked_indices is not None:
57 | loss.masked_fill_(masked_indices, 0)
58 |
59 | if reduction == 'sum':
60 | loss = loss.sum()
61 | elif reduction == 'mean':
62 | if masked_indices is None:
63 | loss = loss.mean()
64 | else:
65 | loss = loss.sum() / float(loss.size(0) - masked_indices.sum())
66 |
67 | return loss
68 |
69 |
70 | class CrossEntropyLoss(nn.CrossEntropyLoss):
71 | """CrossEntropyLoss - with ability to recieve distrbution as targets, and optional label smoothing"""
72 |
73 | def __init__(self, weight=None, ignore_index=-100, reduction='mean', smooth_eps=None, smooth_dist=None, from_logits=True):
74 | super(CrossEntropyLoss, self).__init__(weight=weight,
75 | ignore_index=ignore_index, reduction=reduction)
76 | self.smooth_eps = smooth_eps
77 | self.smooth_dist = smooth_dist
78 | self.from_logits = from_logits
79 |
80 | def forward(self, input, target, smooth_dist=None):
81 | if smooth_dist is None:
82 | smooth_dist = self.smooth_dist
83 | return cross_entropy(input, target, weight=self.weight, ignore_index=self.ignore_index,
84 | reduction=self.reduction, smooth_eps=self.smooth_eps,
85 | smooth_dist=smooth_dist, from_logits=self.from_logits)
86 |
87 |
88 | def binary_cross_entropy(inputs, target, weight=None, reduction='mean', smooth_eps=None, from_logits=False):
89 | """cross entropy loss, with support for label smoothing https://arxiv.org/abs/1512.00567"""
90 | smooth_eps = smooth_eps or 0
91 | if smooth_eps > 0:
92 | # smooth targets without mutating the caller's tensor in-place
93 | target = (target.float() + smooth_eps) / 2.
94 | if from_logits:
95 | return F.binary_cross_entropy_with_logits(inputs, target, weight=weight, reduction=reduction)
96 | else:
97 | return F.binary_cross_entropy(inputs, target, weight=weight, reduction=reduction)
98 |
99 |
100 | def binary_cross_entropy_with_logits(inputs, target, weight=None, reduction='mean', smooth_eps=None, from_logits=True):
101 | return binary_cross_entropy(inputs, target, weight, reduction, smooth_eps, from_logits)
102 |
103 |
104 | class BCELoss(nn.BCELoss):
105 | def __init__(self, weight=None, size_average=None, reduce=None, reduction='mean', smooth_eps=None, from_logits=False):
106 | super(BCELoss, self).__init__(weight, size_average, reduce, reduction)
107 | self.smooth_eps = smooth_eps
108 | self.from_logits = from_logits
109 |
110 | def forward(self, input, target):
111 | return binary_cross_entropy(input, target,
112 | weight=self.weight, reduction=self.reduction,
113 | smooth_eps=self.smooth_eps, from_logits=self.from_logits)
114 |
115 |
116 | class BCEWithLogitsLoss(BCELoss):
117 | def __init__(self, weight=None, size_average=None, reduce=None, reduction='mean', smooth_eps=None, from_logits=True):
118 | super(BCEWithLogitsLoss, self).__init__(weight, size_average,
119 | reduce, reduction, smooth_eps=smooth_eps, from_logits=from_logits)
120 |
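121 | # Minimal usage sketch (illustrative): label-smoothed loss over random logits.
122 | # Assumes the module is imported in its package context (it relies on a relative import).
123 | if __name__ == '__main__':
124 | logits = torch.randn(4, 10)
125 | target = torch.randint(10, (4,))
126 | criterion = CrossEntropyLoss(smooth_eps=0.1)
127 | print(criterion(logits, target))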
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | from io import BytesIO
2 | import pickle
3 | import PIL
4 | import torch
5 | from torch.utils.data import Dataset
6 | from torch.utils.data.sampler import Sampler, RandomSampler, BatchSampler  # _int_classes was removed from torch; plain int is checked below
7 | from numpy.random import choice
8 | import csv
9 | from copy import copy
10 | import codecs
11 | from torch._utils import _accumulate
12 | from collections import Counter
13 |
14 |
15 | class RandomSamplerReplacment(torch.utils.data.sampler.Sampler):
16 | """Samples elements randomly, with replacement.
17 | Arguments:
18 | data_source (Dataset): dataset to sample from
19 | """
20 |
21 | def __init__(self, data_source):
22 | self.num_samples = len(data_source)
23 |
24 | def __iter__(self):
25 | return iter(torch.from_numpy(choice(self.num_samples, self.num_samples, replace=True)))
26 |
27 | def __len__(self):
28 | return self.num_samples
29 |
30 | class LongestLoader(torch.utils.data.dataloader.DataLoader):
31 | def __init__(self, *loaders):
32 | self.loaders = loaders
33 | self.num_workers = self.loaders[0].num_workers
34 |
35 | def __iter__(self):
36 | self.iterators = [iter(it) for it in self.loaders]
37 | max_length = len(self)
38 | for _ in range(max_length):
39 | values = []
40 | for i, it in enumerate(self.iterators):
41 | try:
42 | value = next(it)
43 | except StopIteration:
44 | self.iterators[i] = iter(self.loaders[i])
45 | value = next(self.iterators[i])
46 | values.append(value)
47 | yield tuple(values)
48 |
49 | def __len__(self):
50 | return max([len(arg) for arg in self.loaders])
51 |
52 | class LimitDataset(Dataset):
53 |
54 | def __init__(self, dset, max_len):
55 | self.dset = dset
56 | self.max_len = max_len
57 |
58 | def __len__(self):
59 | return min(len(self.dset), self.max_len)
60 |
61 | def __getitem__(self, index):
62 | return self.dset[index]
63 |
64 |
65 | class ByClassDataset(Dataset):
66 |
67 | def __init__(self, ds):
68 | self.dataset = ds
69 | self.idx_by_class = {}
70 | for idx, (_, c) in enumerate(ds):
71 | self.idx_by_class.setdefault(c, [])
72 | self.idx_by_class[c].append(idx)
73 |
74 | def __len__(self):
75 | return min([len(d) for d in self.idx_by_class.values()])
76 |
77 | def __getitem__(self, idx):
78 | idx_per_class = [self.idx_by_class[c][idx]
79 | for c in range(len(self.idx_by_class))]
80 | labels = torch.LongTensor([self.dataset[i][1]
81 | for i in idx_per_class])
82 | items = [self.dataset[i][0] for i in idx_per_class]
83 | if torch.is_tensor(items[0]):
84 | items = torch.stack(items)
85 |
86 | return (items, labels)
87 |
88 |
89 | class IdxDataset(Dataset):
90 | """docstring for IdxDataset."""
91 |
92 | def __init__(self, dset):
93 | super(IdxDataset, self).__init__()
94 | self.dset = dset
95 | self.idxs = range(len(self.dset))
96 |
97 | def __getitem__(self, idx):
98 | data, labels = self.dset[self.idxs[idx]]
99 | return (idx, data, labels)
100 |
101 | def __len__(self):
102 | return len(self.idxs)
103 |
104 |
105 | def image_loader(imagebytes):
106 | img = PIL.Image.open(BytesIO(imagebytes))
107 | return img.convert('RGB')
108 |
109 |
110 | class IndexedFileDataset(Dataset):
111 | """ A dataset that consists of an indexed file (with sample offsets in
112 | another file). For example, a .tar that contains image files.
113 | The dataset does not extract the samples, but works with the indexed
114 | file directly.
115 | NOTE: The index file is assumed to be a pickled list of 3-tuples:
116 | (name, offset, size).
117 | """
118 |
119 | def __init__(self, filename, index_filename=None, extract_target_fn=None,
120 | transform=None, target_transform=None, loader=image_loader):
121 | super(IndexedFileDataset, self).__init__()
122 |
123 | # Defaults
124 | if index_filename is None:
125 | index_filename = filename + '.index'
126 | if extract_target_fn is None:
127 | extract_target_fn = lambda *args: args
128 |
129 | # Read index
130 | with open(index_filename, 'rb') as index_fp:
131 | sample_list = pickle.load(index_fp)
132 |
133 | # Collect unique targets (sorted by name)
134 | targetset = set(extract_target_fn(target)
135 | for target, _, _ in sample_list)
136 | targetmap = {target: i for i, target in enumerate(sorted(targetset))}
137 |
138 | self.samples = [(targetmap[extract_target_fn(target)], offset, size)
139 | for target, offset, size in sample_list]
140 | self.filename = filename
141 |
142 | self.loader = loader
143 | self.transform = transform
144 | self.target_transform = target_transform
145 |
146 | def _get_sample(self, fp, idx):
147 | target, offset, size = self.samples[idx]
148 | fp.seek(offset)
149 | sample = self.loader(fp.read(size))
150 |
151 | if self.transform is not None:
152 | sample = self.transform(sample)
153 | if self.target_transform is not None:
154 | target = self.target_transform(target)
155 |
156 | return sample, target
157 |
158 | def __getitem__(self, index):
159 | with open(self.filename, 'rb') as fp:
160 | # Handle slices
161 | if isinstance(index, slice):
162 | return [self._get_sample(fp, subidx) for subidx in
163 | range(index.start or 0, index.stop or len(self),
164 | index.step or 1)]
165 |
166 | return self._get_sample(fp, index)
167 |
168 | def __len__(self):
169 | return len(self.samples)
170 |
171 |
172 | class DuplicateBatchSampler(Sampler):
173 | def __init__(self, sampler, batch_size, duplicates, drop_last):
174 | if not isinstance(sampler, Sampler):
175 | raise ValueError("sampler should be an instance of "
176 | "torch.utils.data.Sampler, but got sampler={}"
177 | .format(sampler))
178 | if not isinstance(batch_size, int) or isinstance(batch_size, bool) or \
179 | batch_size <= 0:
180 | raise ValueError("batch_size should be a positive integer value, "
181 | "but got batch_size={}".format(batch_size))
182 | if not isinstance(drop_last, bool):
183 | raise ValueError("drop_last should be a boolean value, but got "
184 | "drop_last={}".format(drop_last))
185 | self.sampler = sampler
186 | self.batch_size = batch_size
187 | self.drop_last = drop_last
188 | self.duplicates = duplicates
189 |
190 | def __iter__(self):
191 | batch = []
192 | for idx in self.sampler:
193 | batch.append(idx)
194 | if len(batch) == self.batch_size:
195 | yield batch * self.duplicates
196 | batch = []
197 | if len(batch) > 0 and not self.drop_last:
198 | yield batch * self.duplicates
199 |
200 | def __len__(self):
201 | if self.drop_last:
202 | return len(self.sampler) // self.batch_size
203 | else:
204 | return (len(self.sampler) + self.batch_size - 1) // self.batch_size
205 |
206 |
207 | def list_line_locations(filename, limit=None):
208 | line_offset = []
209 | offset = 0
210 | with open(filename, "rb") as f:
211 | for line in f:
212 | line_offset.append(offset)
213 | offset += len(line)
214 | if limit is not None and len(line_offset) > limit:
215 | break
216 | return line_offset
217 |
218 |
219 | def _load_or_create(filename, create_fn, cache=True, force_create=False):
220 | loaded = False
221 | if not force_create:
222 | try:
223 | with open(filename, 'rb') as fp:
224 | value = pickle.load(fp)
225 | loaded = True
226 | except (OSError, pickle.UnpicklingError, EOFError):
227 | pass
228 | if not loaded:
229 | value = create_fn()
230 | if cache and not loaded:
231 | with open(filename, 'wb') as fp:
232 | pickle.dump(value, fp)
233 | return value
234 |
235 |
236 | class LinedTextDataset(Dataset):
237 | """ Dataset in which every line is a seperate item (e.g translation)
238 | """
239 |
240 | def __init__(self, filename, transform=None, cache=True):
241 | self.filename = filename
242 | self.transform = transform
243 | self.items = _load_or_create(filename + '_cached_lines',
244 | create_fn=lambda: list_line_locations(
245 | filename),
246 | cache=cache)
247 |
248 | def __getitem__(self, index):
249 | if isinstance(index, slice):
250 | return [self[idx] for idx in range(index.start or 0, index.stop or len(self), index.step or 1)]
251 | with codecs.open(self.filename, encoding='UTF-8') as f:
252 | f.seek(self.items[index])
253 | item = f.readline()
254 | if self.transform is not None:
255 | item = self.transform(item)
256 | return item
257 |
258 | def __len__(self):
259 | return len(self.items)
260 |
261 | def select_range(self, start, end):
262 | new_dataset = copy(self)
263 | new_dataset.items = new_dataset.items[start:end]
264 | return new_dataset
265 |
266 | def filter(self, filter_func):
267 | new_dataset = copy(self)
268 | new_dataset.items = [item for item in self if filter_func(item)]
269 | return new_dataset
270 |
271 | def subset(self, indices):
272 | new_dataset = copy(self)
273 | new_dataset.items = [new_dataset.items[idx] for idx in indices]
274 | return new_dataset
275 |
276 | def split(self, lengths):
277 | """
278 | split a dataset into non-overlapping new datasets of given lengths.
279 | Arguments:
280 | dataset (Dataset): Dataset to be split
281 | lengths (sequence): lengths of splits to be produced
282 | """
283 | if sum(lengths) != len(self):
284 | raise ValueError(
285 | "Sum of input lengths does not equal the length of the input dataset!")
286 |
287 | return [self.select_range(offset-length, offset) for offset, length in zip(_accumulate(lengths), lengths)]
288 |
289 | def random_split(self, lengths):
290 | """
291 | Randomly split a dataset into non-overlapping new datasets of given lengths.
292 | Arguments:
293 | dataset (Dataset): Dataset to be split
294 | lengths (sequence): lengths of splits to be produced
295 | """
296 | if sum(lengths) != len(self):
297 | raise ValueError(
298 | "Sum of input lengths does not equal the length of the input dataset!")
299 |
300 | indices = torch.randperm(sum(lengths)).tolist()
301 | return [self.subset(indices[offset - length:offset]) for offset, length in zip(_accumulate(lengths), lengths)]
302 |
303 |
304 | class CSVDataset(LinedTextDataset):
305 | """ Dataset with delimited items and pre-knwon fieldnames (no header)
306 | """
307 |
308 | def __init__(self, filename, fieldnames=None, delimiter='\t', transform=None, cache=True):
309 | self.filename = filename
310 | self.fieldnames = fieldnames
311 | self.delimiter = delimiter
312 | self.transform = transform
313 | self.items = _load_or_create(filename + '_cached_lines',
314 | create_fn=lambda: list_line_locations(
315 | filename),
316 | cache=cache)
317 |
318 | def __getitem__(self, index):
319 | if isinstance(index, slice):
320 | return [self[idx] for idx in range(index.start or 0, index.stop or len(self), index.step or 1)]
321 | with codecs.open(self.filename, encoding='UTF-8') as f:
322 | f.seek(self.items[index])
323 | item = f.readline()
324 | item = next(csv.DictReader([item],
325 | fieldnames=self.fieldnames,
326 | delimiter=self.delimiter))
327 | if self.transform is not None:
328 | item = self.transform(item)
329 | return item
330 |
331 | def count_fields(self, fieldnames=None):
332 | fieldnames = fieldnames or self.fieldnames
333 | counters = {name: Counter() for name in fieldnames}
334 | for i in range(len(self)):
335 | value = self[i]
336 | for field in fieldnames:
337 | counters[field][value[field]] += 1
338 | return counters
339 |
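340 | # Minimal usage sketch (illustrative): index a small tab-separated file and read
341 | # a record back. The file path below is hypothetical.
342 | if __name__ == '__main__':
343 | with open('/tmp/example.tsv', 'w', encoding='UTF-8') as f:
344 | f.write('hello\t0\nworld\t1\n')
345 | ds = CSVDataset('/tmp/example.tsv', fieldnames=['text', 'label'], cache=False)
346 | print(len(ds), ds[0])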
--------------------------------------------------------------------------------
/functions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd.function import Function
3 |
4 | class ScaleGrad(Function):
5 |
6 | @staticmethod
7 | def forward(ctx, input, scale):
8 | ctx.scale = scale
9 | return input
10 |
11 | @staticmethod
12 | def backward(ctx, grad_output):
13 | grad_input = ctx.scale * grad_output
14 | return grad_input, None
15 |
16 |
17 | def scale_grad(x, scale):
18 | return ScaleGrad.apply(x, scale)
19 |
20 | def negate_grad(x):
21 | return scale_grad(x, -1)
22 |
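23 | # Minimal usage sketch (illustrative): the forward value is unchanged while the
24 | # gradient flowing back through scale_grad is multiplied by the given factor.
25 | if __name__ == '__main__':
26 | x = torch.ones(3, requires_grad=True)
27 | scale_grad(x, 0.5).sum().backward()
28 | print(x.grad)  # expected: a tensor filled with 0.5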
--------------------------------------------------------------------------------
/log.py:
--------------------------------------------------------------------------------
1 | import shutil
2 | import os
3 | from itertools import cycle
4 | import torch
5 | import logging.config
6 | from datetime import datetime
7 | import json
8 |
9 | import pandas as pd
10 | from bokeh.io import output_file, save, show
11 | from bokeh.plotting import figure
12 | from bokeh.layouts import column
13 | from bokeh.models import Div
14 |
15 | try:
16 | import hyperdash
17 | HYPERDASH_AVAILABLE = True
18 | except ImportError:
19 | HYPERDASH_AVAILABLE = False
20 |
21 |
22 | def export_args_namespace(args, filename):
23 | """
24 | args: argparse.Namespace
25 | arguments to save
26 | filename: string
27 | filename to save at
28 | """
29 | with open(filename, 'w') as fp:
30 | json.dump(dict(args._get_kwargs()), fp, sort_keys=True, indent=4)
31 |
32 |
33 | def setup_logging(log_file='log.txt', resume=False, dummy=False):
34 | """
35 | Setup logging configuration
36 | """
37 | if dummy:
38 | logging.getLogger('dummy')
39 | return
40 |
41 | file_mode = 'a' if os.path.isfile(log_file) and resume else 'w'
42 |
43 | root_logger = logging.getLogger()
44 | logging.basicConfig(level=logging.DEBUG,
45 | format="%(asctime)s - %(levelname)s - %(message)s",
46 | datefmt="%Y-%m-%d %H:%M:%S")
47 | # Remove all existing handlers (can't use the `force` option with
48 | # python < 3.8)
49 | for hdlr in root_logger.handlers[:]:
50 | root_logger.removeHandler(hdlr)
51 | # Add the handlers we want to use
52 | fileout = logging.FileHandler(log_file, mode=file_mode)
53 | fileout.setLevel(logging.DEBUG)
54 | fileout.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"))
55 | logging.getLogger().addHandler(fileout)
56 | console = logging.StreamHandler()
57 | console.setLevel(logging.INFO)
58 | console.setFormatter(logging.Formatter('%(message)s'))
59 | logging.getLogger().addHandler(console)
60 |
61 |
62 | def plot_figure(data, x, y, title=None, xlabel=None, ylabel=None, legend=None,
63 | x_axis_type='linear', y_axis_type='linear',
64 | width=800, height=400, line_width=2,
65 | colors=['red', 'green', 'blue', 'orange',
66 | 'black', 'purple', 'brown'],
67 | tools='pan,box_zoom,wheel_zoom,box_select,hover,reset,save',
68 | append_figure=None):
69 | """
70 | creates a new plot figure
71 | example:
72 | plot_figure(data, x='epoch', y=['train_loss', 'val_loss'],
73 | title='Loss', ylabel='loss')
74 | """
75 | if not isinstance(y, list):
76 | y = [y]
77 | xlabel = xlabel or x
78 | legend = legend or y
79 | assert len(legend) == len(y)
80 | if append_figure is not None:
81 | f = append_figure
82 | else:
83 | f = figure(title=title, tools=tools,
84 | width=width, height=height,
85 | x_axis_label=xlabel or x,
86 | y_axis_label=ylabel or '',
87 | x_axis_type=x_axis_type,
88 | y_axis_type=y_axis_type)
89 | colors = cycle(colors)
90 | for i, yi in enumerate(y):
91 | f.line(data[x], data[yi],
92 | line_width=line_width,
93 | line_color=next(colors), legend_label=legend[i])
94 | f.legend.click_policy = "hide"
95 | return f
96 |
97 |
98 | class ResultsLog(object):
99 |
100 | supported_data_formats = ['csv', 'json']
101 |
102 | def __init__(self, path='', title='', params=None, resume=False, data_format='csv'):
103 | """
104 | Parameters
105 | ----------
106 | path: string
107 | path to directory to save data files
108 | plot_path: string
109 | path to directory to save plot files
110 | title: string
111 | title of HTML file
112 | params: Namespace
113 | optionally save parameters for results
114 | resume: bool
115 | resume previous logging
116 | data_format: str('csv'|'json')
117 | which file format to use to save the data
118 | """
119 | if data_format not in ResultsLog.supported_data_formats:
120 | raise ValueError('data_format must be one of the following: ' +
121 | '|'.join(['{}'.format(k) for k in ResultsLog.supported_data_formats]))
122 |
123 | if data_format == 'json':
124 | self.data_path = '{}.json'.format(path)
125 | else:
126 | self.data_path = '{}.csv'.format(path)
127 | if params is not None:
128 | export_args_namespace(params, '{}.json'.format(path))
129 | self.plot_path = '{}.html'.format(path)
130 | self.results = None
131 | self.clear()
132 | self.first_save = True
133 | if os.path.isfile(self.data_path):
134 | if resume:
135 | self.load(self.data_path)
136 | self.first_save = False
137 | else:
138 | os.remove(self.data_path)
139 | self.results = pd.DataFrame()
140 | else:
141 | self.results = pd.DataFrame()
142 |
143 | self.title = title
144 | self.data_format = data_format
145 |
146 | if HYPERDASH_AVAILABLE:
147 | name = self.title if title != '' else path
148 | self.hd_experiment = hyperdash.Experiment(name)
149 | if params is not None:
150 | for k, v in params._get_kwargs():
151 | self.hd_experiment.param(k, v, log=False)
152 |
153 | def clear(self):
154 | self.figures = []
155 |
156 | def add(self, **kwargs):
157 | """Add a new row to the dataframe
158 | example:
159 | resultsLog.add(epoch=epoch_num, train_loss=loss,
160 | test_loss=test_loss)
161 | """
162 | df = pd.DataFrame([list(kwargs.values())], columns=list(kwargs.keys()))
163 | self.results = pd.concat([self.results, df], ignore_index=True)
164 | if hasattr(self, 'hd_experiment'):
165 | for k, v in kwargs.items():
166 | self.hd_experiment.metric(k, v, log=False)
167 |
168 | def smooth(self, column_name, window):
169 | """Select an entry to smooth over time"""
170 | # TODO: smooth only new data
171 | smoothed_column = self.results[column_name].rolling(
172 | window=window, center=False).mean()
173 | self.results[column_name + '_smoothed'] = smoothed_column
174 |
175 | def save(self, title=None):
176 | """save the json file.
177 | Parameters
178 | ----------
179 | title: string
180 | title of the HTML file
181 | """
182 | title = title or self.title
183 | if len(self.figures) > 0:
184 | if os.path.isfile(self.plot_path):
185 | os.remove(self.plot_path)
186 | if self.first_save:
187 | self.first_save = False
188 | logging.info('Plot file saved at: {}'.format(
189 | os.path.abspath(self.plot_path)))
190 |
191 | output_file(self.plot_path, title=title)
192 | plot = column(
193 | Div(text='<h1 align="center">{}</h1>'.format(title)), *self.figures)
194 | save(plot)
195 | self.clear()
196 |
197 | if self.data_format == 'json':
198 | self.results.to_json(self.data_path, orient='records', lines=True)
199 | else:
200 | self.results.to_csv(self.data_path, index=False, index_label=False)
201 |
202 | def load(self, path=None):
203 | """load the data file
204 | Parameters
205 | ----------
206 | path:
207 | path to load the json|csv file from
208 | """
209 | path = path or self.data_path
210 | if os.path.isfile(path):
211 | if self.data_format == 'json':
212 | self.results = pd.read_json(path, orient='records', lines=True)
213 | else:
214 | self.results = pd.read_csv(path)
215 | else:
216 | raise ValueError("{} isn't a file".format(path))
217 |
218 | def show(self, title=None):
219 | title = title or self.title
220 | if len(self.figures) > 0:
221 | plot = column(
222 | Div(text='<h1 align="center">{}</h1>'.format(title)), *self.figures)
223 | show(plot)
224 |
225 | def plot(self, *kargs, **kwargs):
226 | """
227 | add a new plot to the HTML file
228 | example:
229 | results.plot(x='epoch', y=['train_loss', 'val_loss'],
230 | title='Loss', ylabel='loss')
231 | """
232 | f = plot_figure(self.results, *kargs, **kwargs)
233 | self.figures.append(f)
234 |
235 | def image(self, *kargs, **kwargs):
236 | fig = figure()
237 | fig.image(*kargs, **kwargs)
238 | self.figures.append(fig)
239 |
240 | def end(self):
241 | if hasattr(self, 'hd_experiment'):
242 | self.hd_experiment.end()
243 |
244 |
245 | def save_checkpoint(state, is_best, path='.', filename='checkpoint.pth.tar', save_all=False):
246 | filename = os.path.join(path, filename)
247 | torch.save(state, filename)
248 | if is_best:
249 | shutil.copyfile(filename, os.path.join(path, 'model_best.pth.tar'))
250 | if save_all:
251 | shutil.copyfile(filename, os.path.join(
252 | path, 'checkpoint_epoch_%s.pth.tar' % state['epoch']))
253 |
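254 | # Minimal usage sketch (illustrative): log two epochs of metrics and write the
255 | # CSV and HTML outputs under a hypothetical '/tmp/results_demo' prefix.
256 | if __name__ == '__main__':
257 | results = ResultsLog(path='/tmp/results_demo', title='demo')
258 | for epoch in range(2):
259 | results.add(epoch=epoch, train_loss=1. / (epoch + 1), val_loss=1.5 / (epoch + 1))
260 | results.plot(x='epoch', y=['train_loss', 'val_loss'], title='Loss', ylabel='loss')
261 | results.save()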
--------------------------------------------------------------------------------
/meters.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class AverageMeter(object):
5 | """Computes and stores the average and current value"""
6 |
7 | def __init__(self):
8 | self.reset()
9 |
10 | def reset(self):
11 | self.val = 0
12 | self.avg = 0
13 | self.sum = 0
14 | self.count = 0
15 |
16 | def update(self, val, n=1):
17 | self.val = val
18 | self.sum += val * n
19 | self.count += n
20 | self.avg = self.sum / self.count
21 |
22 |
23 | class OnlineMeter(object):
24 | """Computes and stores the average and variance/std values of tensor"""
25 |
26 | def __init__(self):
27 | self.mean = torch.FloatTensor(1).fill_(-1)
28 | self.M2 = torch.FloatTensor(1).zero_()
29 | self.count = 0.
30 | self.needs_init = True
31 |
32 | def reset(self, x):
33 | self.mean = x.new(x.size()).zero_()
34 | self.M2 = x.new(x.size()).zero_()
35 | self.count = 0.
36 | self.needs_init = False
37 |
38 | def update(self, x):
39 | self.val = x
40 | if self.needs_init:
41 | self.reset(x)
42 | self.count += 1
43 | delta = x - self.mean
44 | self.mean.add_(delta / self.count)
45 | delta2 = x - self.mean
46 | self.M2.add_(delta * delta2)
47 |
48 | @property
49 | def var(self):
50 | if self.count < 2:
51 | return self.M2.clone().zero_()
52 | return self.M2 / (self.count - 1)
53 |
54 | @property
55 | def std(self):
56 | return self.var.sqrt()  # `var` is a property, not a method
57 |
58 |
59 | def accuracy(output, target, topk=(1,)):
60 | """Computes the precision@k for the specified values of k"""
61 | maxk = max(topk)
62 | batch_size = target.size(0)
63 |
64 | _, pred = output.topk(maxk, 1, True, True)
65 | pred = pred.t().type_as(target)
66 | correct = pred.eq(target.view(1, -1).expand_as(pred))
67 |
68 | res = []
69 | for k in topk:
70 | correct_k = correct[:k].reshape(-1).float().sum(0)
71 | res.append(correct_k.mul_(100.0 / batch_size))
72 | return res
73 |
74 |
75 | class AccuracyMeter(object):
76 | """Computes and stores the average and current topk accuracy"""
77 |
78 | def __init__(self, topk=(1,)):
79 | self.topk = topk
80 | self.reset()
81 |
82 | def reset(self):
83 | self._meters = {}
84 | for k in self.topk:
85 | self._meters[k] = AverageMeter()
86 |
87 | def update(self, output, target):
88 | n = target.nelement()
89 | acc_vals = accuracy(output, target, self.topk)
90 | for i, k in enumerate(self.topk):
91 | self._meters[k].update(float(acc_vals[i]), n)  # weight by batch size
92 |
93 | @property
94 | def val(self):
95 | return {n: meter.val for (n, meter) in self._meters.items()}
96 |
97 | @property
98 | def avg(self):
99 | return {n: meter.avg for (n, meter) in self._meters.items()}
100 |
101 | @property
102 | def avg_error(self):
103 | return {n: 100. - meter.avg for (n, meter) in self._meters.items()}
104 |
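105 | # Minimal usage sketch (illustrative): track top-1 / top-5 accuracy over a batch
106 | # of random predictions.
107 | if __name__ == '__main__':
108 | meter = AccuracyMeter(topk=(1, 5))
109 | output, target = torch.randn(8, 10), torch.randint(10, (8,))
110 | meter.update(output, target)
111 | print(meter.avg, meter.avg_error)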
--------------------------------------------------------------------------------
/misc.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | from torch.utils.checkpoint import checkpoint, checkpoint_sequential
6 | from contextlib import contextmanager
7 | from torch.nn.modules.batchnorm import _BatchNorm
8 |
9 | torch_dtypes = {
10 | 'float': torch.float,
11 | 'float32': torch.float32,
12 | 'float64': torch.float64,
13 | 'double': torch.double,
14 | 'float16': torch.float16,
15 | 'half': torch.half,
16 | 'uint8': torch.uint8,
17 | 'int8': torch.int8,
18 | 'int16': torch.int16,
19 | 'short': torch.short,
20 | 'int32': torch.int32,
21 | 'int': torch.int,
22 | 'int64': torch.int64,
23 | 'long': torch.long
24 | }
25 |
26 |
27 | def onehot(indexes, N=None, ignore_index=None):
28 | """
29 | Creates a one-hot representation of indexes with N possible entries.
30 | If N is not specified, it is inferred from the maximum index appearing.
31 | indexes is a long tensor of indexes
32 | ignore_index entries are all-zero in the one-hot representation
33 | """
34 | if N is None:
35 | N = indexes.max() + 1
36 | sz = list(indexes.size())
37 | output = indexes.new().byte().resize_(*sz, N).zero_()
38 | output.scatter_(-1, indexes.unsqueeze(-1), 1)
39 | if ignore_index is not None and ignore_index >= 0:
40 | output.masked_fill_(indexes.eq(ignore_index).unsqueeze(-1), 0)
41 | return output
42 |
43 |
44 | @contextmanager
45 | def no_bn_update(model):
46 | prev_momentum = {}
47 | for name, module in model.named_modules():
48 | if isinstance(module, nn.BatchNorm2d):
49 | prev_momentum[name] = module.momentum
50 | module.momentum = 0
51 | try:
52 | yield model
53 | finally:
54 | for name, module in model.named_modules():
55 | if isinstance(module, nn.BatchNorm2d):
56 | module.momentum = prev_momentum[name]
57 |
58 |
59 | @contextmanager
60 | def calibrate_bn(model):
61 | prev_attributes = {}
62 | for name, module in model.named_modules():
63 | if isinstance(module, _BatchNorm):
64 | prev_attributes[name] = prev_attributes.get(name, {})
65 | prev_attributes[name]['momentum'] = module.momentum
66 | prev_attributes[name]['track_running_stats'] = module.track_running_stats
67 | prev_attributes[name]['training'] = module.training
68 | module.momentum = None
69 | module.track_running_stats = True
70 | module.reset_running_stats()
71 | module.train()
72 | try:
73 | yield model
74 | finally:
75 | for name, module in model.named_modules():
76 | if isinstance(module, _BatchNorm):
77 | module.momentum = prev_attributes[name]['momentum']
78 | module.track_running_stats = prev_attributes[name]['track_running_stats']
79 | module.train(prev_attributes[name]['training'])
80 |
81 |
82 |
83 | def set_global_seeds(i):
84 | try:
85 | import torch
86 | except ImportError:
87 | pass
88 | else:
89 | torch.manual_seed(i)
90 | if torch.cuda.is_available():
91 | torch.cuda.manual_seed_all(i)
92 | np.random.seed(i)
93 | random.seed(i)
94 |
95 |
96 | class CheckpointModule(nn.Module):
97 | def __init__(self, module, num_segments=1):
98 | super(CheckpointModule, self).__init__()
99 | assert num_segments == 1 or isinstance(module, nn.Sequential)
100 | self.module = module
101 | self.num_segments = num_segments
102 |
103 | def forward(self, x):
104 | if self.num_segments > 1:
105 | return checkpoint_sequential(self.module, self.num_segments, x)
106 | else:
107 | return checkpoint(self.module, x)
108 |
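109 | # Minimal usage sketch (illustrative): one-hot encode labels and recalibrate the
110 | # BatchNorm statistics of a small toy model over a few random batches.
111 | if __name__ == '__main__':
112 | print(onehot(torch.tensor([1, 0, 3]), N=4))
113 | model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
114 | with calibrate_bn(model):
115 | for _ in range(4):
116 | model(torch.randn(2, 3, 16, 16))
117 | print(model[1].running_mean)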
--------------------------------------------------------------------------------
/mixup.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | from numpy.random import beta
5 | from torch.nn.functional import one_hot
6 |
7 |
8 | class MixUp(nn.Module):
9 | def __init__(self, alpha=1., batch_dim=0):
10 | super(MixUp, self).__init__()
11 | self.batch_dim = batch_dim
12 | self.alpha = alpha
13 | self.reset()
14 |
15 | def reset(self):
16 | self.mix_value = None
17 | self.mix_index = None
18 |
19 | def mix(self, x1, x2):
20 | return x2.lerp(x1, self.mix_value)
21 |
22 | def sample(self, batch_size, alpha=None, even=False):
23 | alpha = self.alpha if alpha is None else alpha
24 | self.mix_index = torch.randperm(batch_size)
25 | self.mix_value = beta(alpha, alpha)
26 | if not even:
27 | self.mix_value = max(self.mix_value, 1 - self.mix_value)
28 |
29 | def mix_target(self, y, n_class):
30 | if not self.training or \
31 | self.mix_value is None:
32 | return y
33 | y = one_hot(y, n_class).to(dtype=torch.float)
34 | idx = self.mix_index.to(device=y.device)
35 | y_mix = y.index_select(self.batch_dim, idx)
36 | return self.mix(y, y_mix)
37 |
38 | def forward(self, x):
39 | if not self.training or \
40 | self.mix_value is None:
41 | return x
42 | idx = self.mix_index.to(device=x.device)
43 | x_mix = x.index_select(self.batch_dim, idx)
44 | return self.mix(x, x_mix)
45 |
46 |
47 | def rand_bbox(size, lam):
48 | W = size[2]
49 | H = size[3]
50 | cut_rat = np.sqrt(1. - lam)
51 | cut_w = int(W * cut_rat)  # np.int was removed in recent NumPy versions
52 | cut_h = int(H * cut_rat)
53 |
54 | # uniform
55 | cx = np.random.randint(W)
56 | cy = np.random.randint(H)
57 |
58 | bbx1 = np.clip(cx - cut_w // 2, 0, W)
59 | bby1 = np.clip(cy - cut_h // 2, 0, H)
60 | bbx2 = np.clip(cx + cut_w // 2, 0, W)
61 | bby2 = np.clip(cy + cut_h // 2, 0, H)
62 |
63 | return bbx1, bby1, bbx2, bby2
64 |
65 |
66 | class CutMix(MixUp):
67 | def __init__(self, alpha=1., batch_dim=0):
68 | super(CutMix, self).__init__(alpha, batch_dim)
69 |
70 | def mix_image(self, x1, x2):
71 | assert not torch.is_tensor(self.mix_value) or \
72 | self.mix_value.nelement() == 1
73 | lam = float(self.mix_value)
74 | bbx1, bby1, bbx2, bby2 = rand_bbox(x1.size(), lam)
75 | x1[:, :, bbx1:bbx2, bby1:bby2] = x2[:, :, bbx1:bbx2, bby1:bby2]
76 | lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) /
77 | (x1.size()[-1] * x1.size()[-2]))
78 | self.mix_value = lam
79 | return x1
80 |
81 | def sample(self, batch_size, alpha=None):
82 | super(CutMix, self).sample(batch_size, alpha=alpha)
83 |
84 | def forward(self, x):
85 | if not self.training or \
86 | self.mix_value is None:
87 | return x
88 | idx = self.mix_index.to(device=x.device)
89 | x_mix = x.index_select(self.batch_dim, idx)
90 | return self.mix_image(x, x_mix)
91 |
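92 | # Minimal usage sketch (illustrative): mix a batch of images and the matching
93 | # one-hot targets with a single sampled coefficient.
94 | if __name__ == '__main__':
95 | mixup = MixUp(alpha=0.2).train()
96 | x, y = torch.randn(8, 3, 32, 32), torch.randint(10, (8,))
97 | mixup.sample(batch_size=x.size(0))
98 | x_mixed, y_mixed = mixup(x), mixup.mix_target(y, n_class=10)
99 | print(x_mixed.shape, y_mixed.shape)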
--------------------------------------------------------------------------------
/optim.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import logging.config
3 | import math
4 | from math import floor
5 | from copy import deepcopy
6 | from six import string_types
7 | from .regime import Regime
8 | from .param_filter import FilterParameters
9 | from . import regularization
10 | import torch.nn as nn
11 | from torch.optim.lr_scheduler import _LRScheduler
12 |
13 | _OPTIMIZERS = {name: func for name, func in torch.optim.__dict__.items()}
14 | _LRSCHEDULERS = {name: func for name,
15 | func in torch.optim.lr_scheduler.__dict__.items()}
16 |
17 | try:
18 | from adabound import AdaBound
19 | _OPTIMIZERS['AdaBound'] = AdaBound
20 | except ImportError:
21 | pass
22 |
23 |
24 | def cosine_anneal_lr(lr0, lrT, T, t0=0):
25 | return f"lambda t: {{'lr': {lrT} + {(lr0 - lrT)} * (1 + math.cos(math.pi * (t - {t0}) / {T-t0})) / 2}}"
26 |
27 |
28 | def linear_scale_lr(lr0, lrT, T, t0=0):
29 | rate = (lrT - lr0) / T
30 | return f"lambda t: {{'lr': max({lr0} + (t - {t0}) * {rate}, 0)}}"
31 |
32 |
33 | class _EmptySchedule(torch.optim.lr_scheduler._LRScheduler):
34 | def __init__(self, optimizer, last_epoch=-1):
35 | super(_EmptySchedule, self).__init__(optimizer, last_epoch=-1)
36 | self.last_epoch = 0
37 |
38 | def step(self, epoch=None):
39 | if epoch is None:
40 | epoch = self.last_epoch + 1
41 |
42 |
43 | def copy_params(param_target, param_src):
44 | with torch.no_grad():
45 | for p_src, p_target in zip(param_src, param_target):
46 | p_target.copy_(p_src)
47 |
48 |
49 | def copy_params_grad(param_target, param_src):
50 | for p_src, p_target in zip(param_src, param_target):
51 | if p_target.grad is None:
52 | p_target.backward(p_src.grad.to(dtype=p_target.dtype))
53 | else:
54 | p_target.grad.detach().copy_(p_src.grad)
55 |
56 |
57 | class ModuleFloatShadow(nn.Module):
58 | def __init__(self, module):
59 | super(ModuleFloatShadow, self).__init__()
60 | self.original_module = module
61 | self.float_module = deepcopy(module)
62 | self.float_module.to(dtype=torch.float)
63 |
64 | def parameters(self, *kargs, **kwargs):
65 | return self.float_module.parameters(*kargs, **kwargs)
66 |
67 | def named_parameters(self, *kargs, **kwargs):
68 | return self.float_module.named_parameters(*kargs, **kwargs)
69 |
70 | def modules(self, *kargs, **kwargs):
71 | return self.float_module.modules(*kargs, **kwargs)
72 |
73 | def named_modules(self, *kargs, **kwargs):
74 | return self.float_module.named_modules(*kargs, **kwargs)
75 |
76 | def original_parameters(self, *kargs, **kwargs):
77 | return self.original_module.parameters(*kargs, **kwargs)
78 |
79 | def original_named_parameters(self, *kargs, **kwargs):
80 | return self.original_module.named_parameters(*kargs, **kwargs)
81 |
82 | def original_modules(self, *kargs, **kwargs):
83 | return self.original_module.modules(*kargs, **kwargs)
84 |
85 | def original_named_modules(self, *kargs, **kwargs):
86 | return self.original_module.named_modules(*kargs, **kwargs)
87 |
88 |
89 | class OptimRegime(Regime, torch.optim.Optimizer):
90 | """
91 | Reconfigures the optimizer according to setting list.
92 | Exposes optimizer methods - state, step, zero_grad, add_param_group
93 |
94 | Examples for regime:
95 |
96 | 1) "[{'epoch': 0, 'optimizer': 'Adam', 'lr': 1e-3},
97 | {'epoch': 2, 'optimizer': 'Adam', 'lr': 5e-4},
98 | {'epoch': 4, 'optimizer': 'Adam', 'lr': 1e-4},
99 | {'epoch': 8, 'optimizer': 'Adam', 'lr': 5e-5}
100 | ]"
101 | 2)
102 | "[{'step_lambda':
103 | "lambda t: {
104 | 'optimizer': 'Adam',
105 | 'lr': 0.1 * min(t ** -0.5, t * 4000 ** -1.5),
106 | 'betas': (0.9, 0.98), 'eps':1e-9}
107 | }]"
108 | """
109 |
110 | def __init__(self, model, regime, defaults={}, filter=None, use_float_copy=False, log=True):
111 | super(OptimRegime, self).__init__(regime, defaults)
112 | if filter is not None:
113 | model = FilterParameters(model, **filter)
114 | if use_float_copy:
115 | model = ModuleFloatShadow(model)
116 | self._original_parameters = list(model.original_parameters())
117 |
118 | self.parameters = list(model.parameters())
119 | self.optimizer = torch.optim.SGD(self.parameters, lr=0)
120 | self.regularizer = regularization.Regularizer(model)
121 | self.use_float_copy = use_float_copy
122 | self.lr_scheduler = _EmptySchedule(self.optimizer, last_epoch=-1)
123 | self.schedule_time_frame = 'epoch'
124 | self.log = log
125 |
126 | def update(self, epoch=None, train_steps=None, metrics=None):
127 | """adjusts optimizer according to current epoch or steps and training regime.
128 | """
129 | updated = False
130 | if super(OptimRegime, self).update(epoch, train_steps):
131 | self.adjust(self.setting)
132 | updated = True
133 | if self.schedule_time_frame == 'epoch':
134 | time = int(floor(epoch)) + 1
135 | elif self.schedule_time_frame == 'step':
136 | time = train_steps
137 | else:
138 | raise ValueError
139 | if (time != self.lr_scheduler.last_epoch) and \
140 | getattr(self.optimizer, '_step_count', 0) > 0:
141 | prev_lr = self.get_lr()[0]
142 | if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
143 | self.lr_scheduler.step(metrics)  # plateau schedulers consume the metric
144 | else: self.lr_scheduler.step()
145 | updated = True
146 | if prev_lr != self.get_lr()[0] and self.log:
147 | logging.debug('OPTIMIZER - lr scheduled = %s'
148 | % self.get_lr()[0])
149 | return updated
150 |
151 | def adjust(self, setting):
152 | """adjusts optimizer according to a setting dict.
153 | e.g: setting={'optimizer': 'Adam', 'lr': 5e-4}
154 | """
155 | reset = setting.get('reset', False)
156 | if 'optimizer' in setting or reset:
157 | optim_method = _OPTIMIZERS[setting.get('optimizer', 'SGD')]
158 | if reset: # reset the optimizer cache:
159 | self.optimizer = torch.optim.SGD(self.parameters, lr=0)
160 | if self.log:
161 | logging.debug('OPTIMIZER - reset setting')
162 | if not isinstance(self.optimizer, optim_method):
163 | self.optimizer = optim_method(self.optimizer.param_groups)
164 | if self.log:
165 | logging.debug('OPTIMIZER - setting method = %s' %
166 | setting['optimizer'])
167 | for param_group in self.optimizer.param_groups:
168 | for key in param_group.keys():
169 | if key in setting:
170 | new_val = setting[key]
171 | if new_val != param_group[key]:
172 | if self.log:
173 | logging.debug('OPTIMIZER - setting %s = %s' %
174 | (key, setting[key]))
175 | param_group[key] = setting[key]
176 | if key == 'lr':
177 | param_group['initial_lr'] = param_group['lr']
178 | base_lrs = list(map(lambda group: group['lr'],
179 | self.optimizer.param_groups))
180 | self.lr_scheduler.base_lrs = base_lrs
181 |
182 | # fix for AdaBound
183 | if hasattr(self.optimizer, 'base_lrs'):
184 | self.optimizer.base_lrs = base_lrs
185 |
186 | if 'regularizer' in setting:
187 | reg_list = deepcopy(setting['regularizer'])
188 | if not (isinstance(reg_list, list) or isinstance(reg_list, tuple)):
189 | reg_list = (reg_list,)
190 | regularizers = []
191 | for reg in reg_list:
192 | if isinstance(reg, dict):
193 | name = reg.pop('name')
194 | regularizers.append((regularization.__dict__[name], reg))
195 | elif isinstance(reg, regularization.Regularizer):
196 | regularizers.append(reg)
197 | else: # callable on model
198 | regularizers.append(reg(self.regularizer._model))
199 | self.regularizer = regularization.RegularizerList(self.regularizer._model,
200 | regularizers)
201 |
202 | if 'lr_scheduler' in setting:
203 | schedule_config = setting['lr_scheduler']
204 | if isinstance(schedule_config, _LRScheduler):
205 | self.lr_scheduler = schedule_config
206 | elif isinstance(schedule_config, dict):
207 | name = schedule_config.pop('name')
208 | self.schedule_time_frame = schedule_config.pop('time_frame',
209 | 'epoch')
210 | schedule_config['last_epoch'] = self.lr_scheduler.last_epoch
211 | self.lr_scheduler = _LRSCHEDULERS[name](self.optimizer,
212 | **schedule_config)
213 | elif schedule_config is None:
214 | self.lr_scheduler = _EmptySchedule(self.optimizer,
215 | last_epoch=self.lr_scheduler.last_epoch)
216 | else: # invalid config
217 | raise NotImplementedError
218 |
219 | def __getstate__(self):
220 | return {
221 | 'optimizer_state': self.optimizer.__getstate__(),
222 | 'regime': self.regime,
223 | }
224 |
225 | def __setstate__(self, state):
226 | self.regime = state.get('regime')
227 | self.optimizer.__setstate__(state.get('optimizer_state'))
228 |
229 | def state_dict(self):
230 | """Returns the state of the optimizer as a :class:`dict`.
231 | """
232 | return self.optimizer.state_dict()
233 |
234 | def load_state_dict(self, state_dict):
235 | """Loads the optimizer state.
236 |
237 | Arguments:
238 | state_dict (dict): optimizer state. Should be an object returned
239 | from a call to :meth:`state_dict`.
240 | """
241 | # deepcopy, to be consistent with module API
242 | self.optimizer.load_state_dict(state_dict)
243 |
244 | def zero_grad(self):
245 | """Clears the gradients of all optimized :class:`Variable` s."""
246 | self.optimizer.zero_grad()
247 | if self.use_float_copy:
248 | for p in self._original_parameters:
249 | if p.grad is not None:
250 | p.grad.detach().zero_()
251 |
252 | def pre_step(self):
253 | if self.use_float_copy:
254 | copy_params_grad(self.parameters, self._original_parameters)
255 | self.regularizer.pre_step()
256 |
257 | def post_step(self):
258 | self.regularizer.post_step()
259 |
260 | if self.use_float_copy:
261 | copy_params(self._original_parameters, self.parameters)
262 |
263 | def step(self, *args, **kwargs):
264 | """Performs a single optimization step (parameter update).
265 | """
266 | self.pre_step()
267 | self.optimizer.step(*args, **kwargs)
268 | self.post_step()
269 |
270 | def pre_forward(self):
271 | """ allows modification pre-forward pass - e.g for regularization
272 | """
273 | self.regularizer.pre_forward()
274 |
275 | def pre_backward(self):
276 | """ allows modification post-forward pass and pre-backward - e.g for regularization
277 | """
278 | self.regularizer.pre_backward()
279 |
280 | def get_value(self, key):
281 | return [group[key] for group in self.optimizer.param_groups]
282 |
283 | def get_lr(self):
284 | return self.get_value('lr')
285 |
286 | @property
287 | def state(self):
288 | return self.optimizer.state
289 |
290 |
291 | class MultiOptimRegime(OptimRegime):
292 |
293 | def __init__(self, *optim_regime_list, log=True):
294 | self.optim_regime_list = []
295 | for optim_regime in optim_regime_list:
296 | assert isinstance(optim_regime, OptimRegime)
297 | self.optim_regime_list.append(optim_regime)
298 | self.log = log
299 |
300 | def update(self, epoch=None, train_steps=None):
301 | """adjusts optimizer according to current epoch or steps and training regime.
302 | """
303 | updated = False
304 | for i, optim in enumerate(self.optim_regime_list):
305 | current_updated = optim.update(epoch, train_steps)
306 | if current_updated and self.log:
307 | logging.debug('OPTIMIZER #%s was updated' % i)
308 | updated = updated or current_updated
309 | return updated
310 |
311 | def zero_grad(self):
312 | """Clears the gradients of all optimized :class:`Variable` s."""
313 | for optim in self.optim_regime_list:
314 | optim.zero_grad()
315 |
316 | def step(self):
317 | """Performs a single optimization step (parameter update).
318 | """
319 | for optim in self.optim_regime_list:
320 | optim.step()
321 |
322 | def pre_forward(self):
323 | for optim in self.optim_regime_list:
324 | optim.pre_forward()
325 |
326 | def pre_backward(self):
327 | for optim in self.optim_regime_list:
328 | optim.pre_backward()
329 |
330 | def __repr__(self):
331 | return str([str(optim) for optim in self.optim_regime_list])
332 |
333 | def get_value(self, key):
334 | return [[group[key] for group in optim.optimizer.param_groups]
335 | for optim in self.optim_regime_list]
336 |
337 | def get_lr(self):
338 | return self.get_value('lr')
339 |
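340 | # Minimal usage sketch (illustrative): drive an SGD learning-rate regime for a tiny
341 | # hypothetical model. Assumes the package context (relative imports above) and the
342 | # regime format shown in the OptimRegime docstring.
343 | if __name__ == '__main__':
344 | model = nn.Linear(10, 2)
345 | regime = [{'epoch': 0, 'optimizer': 'SGD', 'lr': 0.1, 'momentum': 0.9},
346 | {'epoch': 2, 'lr': 0.01}]
347 | optim_regime = OptimRegime(model, regime)
348 | for epoch in range(4):
349 | optim_regime.update(epoch=epoch, train_steps=epoch * 100)
350 | # training loop would go here: zero_grad(), backward(), optim_regime.step()
351 | print(epoch, optim_regime.get_lr())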
--------------------------------------------------------------------------------
/param_filter.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import re
4 |
5 |
6 | def _search_pattern_fn(pattern):
7 | def _search(name):
8 | return re.search(pattern, name) is not None
9 | return _search
10 |
11 |
12 | def _search_type_pattern_fn(pattern):
13 | def _search(var):
14 | return re.search(pattern, type(var).__name__) is not None
15 | return _search
16 |
17 |
18 | def is_not_bias(name):
19 | return not name.endswith('bias')
20 |
21 |
22 | def is_bn(module):
23 | return isinstance(module, nn.modules.batchnorm._BatchNorm)
24 |
25 |
26 | def is_not_bn(module):
27 | return not is_bn(module)
28 |
29 |
30 | def _negate_fn(fn):
31 | if fn is None:
32 | return None
33 | else:
34 | def _negate(*kargs, **kwargs):
35 | return not fn(*kargs, **kwargs)
36 | return _negate
37 |
38 |
39 | def filtered_parameter_info(model, module_fn=None, module_name_fn=None, parameter_name_fn=None, memo=None):
40 | if memo is None:
41 | memo = set()
42 |
43 | for module_name, module in model.named_modules():
44 | if module_fn is not None and not module_fn(module):
45 | continue
46 | if module_name_fn is not None and not module_name_fn(module_name):
47 | continue
48 | for parameter_name, param in module.named_parameters(prefix=module_name, recurse=False):
49 | if parameter_name_fn is not None and not parameter_name_fn(parameter_name):
50 | continue
51 | if param not in memo:
52 | memo.add(param)
53 | yield {'named_module': (module_name, module), 'named_parameter': (parameter_name, param)}
54 |
55 |
56 | class FilterParameters(object):
57 | def __init__(self, source, module=None, module_name=None, parameter_name=None, exclude=False):
58 | if isinstance(module_name, str):
59 | module_name = _search_pattern_fn(module_name)
60 | if isinstance(parameter_name, str):
61 | parameter_name = _search_pattern_fn(parameter_name)
62 | if isinstance(module, str):
63 | module = _search_type_pattern_fn(module)
64 | if exclude:
65 | module_name = _negate_fn(module_name)
66 | parameter_name = _negate_fn(parameter_name)
67 | module = _negate_fn(module)
68 | if isinstance(source, FilterParameters):
69 | self._filtered_parameter_info = list(source.filter(
70 | module=module,
71 | module_name=module_name,
72 | parameter_name=parameter_name))
73 | elif isinstance(source, torch.nn.Module): # source is a model
74 | self._filtered_parameter_info = list(filtered_parameter_info(source,
75 | module_fn=module,
76 | module_name_fn=module_name,
77 | parameter_name_fn=parameter_name))
78 |
79 | def named_parameters(self):
80 | for p in self._filtered_parameter_info:
81 | yield p['named_parameter']
82 |
83 | def parameters(self):
84 | for _, p in self.named_parameters():
85 | yield p
86 |
87 | def filter(self, module=None, module_name=None, parameter_name=None):
88 | for p_info in self._filtered_parameter_info:
89 | if (module is None or module(p_info['named_module'][1])
90 | and (module_name is None or module_name(p_info['named_module'][0]))
91 | and (parameter_name is None or parameter_name(p_info['named_parameter'][0]))):
92 | yield p_info
93 |
94 | def named_modules(self):
95 | for m in self._filtered_parameter_info:
96 | yield m['named_module']
97 |
98 | def modules(self):
99 | for _, m in self.named_modules():
100 | yield m
101 |
102 | def to(self, *kargs, **kwargs):
103 | for m in self.modules():
104 | m.to(*kargs, **kwargs)
105 |
106 |
107 | class FilterModules(FilterParameters):
108 | pass
109 |
110 |
111 | if __name__ == '__main__':
112 | from torchvision.models import resnet50
113 | model = resnet50()
114 | filterd_params = FilterParameters(model,
115 | module=lambda m: isinstance(
116 | m, torch.nn.Linear),
117 | parameter_name=lambda n: 'bias' in n)
118 |
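119 | # Illustrative follow-up: list the selected parameter names
120 | # (for resnet50 this should come out as ['fc.bias']).
121 | print([name for name, _ in filterd_params.named_parameters()])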
--------------------------------------------------------------------------------
/quantize.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | import torch
3 | import torch.nn as nn
4 |
5 | QTensor = namedtuple('QTensor', ['tensor', 'scale', 'zero_point'])
6 |
7 |
8 | def quantize_tensor(x, num_bits=8):
9 | qmin = 0.
10 | qmax = 2.**num_bits - 1.
11 | min_val, max_val = x.min(), x.max()
12 |
13 | scale = (max_val - min_val) / (qmax - qmin)
14 |
15 | initial_zero_point = qmin - min_val / scale
16 |
17 | zero_point = 0
18 | if initial_zero_point < qmin:
19 | zero_point = qmin
20 | elif initial_zero_point > qmax:
21 | zero_point = qmax
22 | else:
23 | zero_point = initial_zero_point
24 |
25 | zero_point = int(zero_point)
26 | q_x = zero_point + x / scale
27 | q_x.clamp_(qmin, qmax).round_()
28 |     q_x = q_x.byte()
29 | return QTensor(tensor=q_x, scale=scale, zero_point=zero_point)
30 |
31 |
32 | def dequantize_tensor(q_x):
33 | return q_x.scale * (q_x.tensor.float() - q_x.zero_point)
34 |
35 |
36 | def quantize_model(model):
37 | qparams = {}
38 |
39 | for n, p in model.state_dict().items():
40 | qp = quantize_tensor(p)
41 | qparams[n + '.quantization.scale'] = torch.FloatTensor([qp.scale])
42 | qparams[
43 | n + '.quantization.zero_point'] = torch.ByteTensor([qp.zero_point])
44 | p.copy_(qp.tensor)
45 | model.type('torch.ByteTensor')
46 | for n, p in qparams.items():
47 | model.register_buffer(n, p)
48 | model.quantized = True
49 |
50 |
51 | def dequantize_model(model):
52 | model.float()
53 | params = model.state_dict()
54 | for n, p in params.items():
55 | if 'quantization' not in n:
56 | qp = QTensor(tensor=p,
57 | scale=params[n + '.quantization.scale'][0],
58 | zero_point=params[n + '.quantization.zero_point'][0])
59 | p.copy_(dequantize_tensor(qp))
60 | model.register_buffer(n + '.quantization.scale', None)
61 | model.register_buffer(n + '.quantization.zero_point', None)
62 | model.quantized = None
63 |
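A short round-trip sketch of the affine scheme above: quantize_tensor maps a float tensor to uint8 values using a single per-tensor scale and zero point, and dequantize_tensor recovers the input up to roughly half a quantization step. The flat import path is an assumption; inside the package the relative form (from .quantize import ...) applies:

    import torch
    from quantize import quantize_tensor, dequantize_tensor  # import path is illustrative

    x = torch.randn(3, 4)
    qx = quantize_tensor(x, num_bits=8)      # QTensor(tensor=uint8 values, scale, zero_point)
    x_hat = dequantize_tensor(qx)            # scale * (q - zero_point)
    print((x - x_hat).abs().max().item())    # reconstruction error, on the order of scale / 2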
--------------------------------------------------------------------------------
/recorder.py:
--------------------------------------------------------------------------------
1 | from .regime import Regime
2 | import logging.config
3 | from os import path
4 | import torch
5 |
6 |
7 | class Recorder(Regime):
8 | def __init__(self, regime, defaults={}):
9 | self.regime = regime
10 | self.current_regime_phase = None
11 | self.setting = defaults
12 | self.measurments = None
13 |
14 | def get_steps(self):
15 | return [item['step'] for item in self.regime]
16 |
17 | @staticmethod
18 | def load(filename, drop_items=[]):
19 | try:
20 | measurments = torch.load(
21 | filename, map_location='cpu')
22 | for item in drop_items:
23 | measurments.pop(item)
24 | return measurments
25 | except FileNotFoundError:
26 | return None
27 |
28 | def update(self, train_steps=None, measurments={}):
29 | """adjusts optimizer according to current epoch or steps and training regime.
30 | """
31 | updated = False
32 | if super(Recorder, self).update(train_steps=train_steps):
33 | save_file = self.setting.get('save', None)
34 | if save_file is not None:
35 | # filename = path.join(self.file_prefix, f'{train_steps}.record')
36 | torch.save(measurments, save_file)
37 |                 logging.debug(f'Saved measurements to {save_file}')
38 | load_file = self.setting.get('load', None)
39 | if load_file is not None:
40 |                 self.measurments = self.load(load_file,
41 |                                              drop_items=self.setting.get('drop_items', []))
42 |                 logging.debug(f'Loaded measurements from {load_file}')
43 |
44 | updated = True
45 | return updated
46 |
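A usage sketch (file paths, step values and the recorded statistics are illustrative): each regime entry gives the 'step' at which a phase starts plus a 'save' and/or 'load' target, and update writes the passed measurements whenever a new phase begins:

    from recorder import Recorder   # import path is illustrative; in-package it is relative

    recorder = Recorder([{'step': 0, 'save': '/tmp/step0.record'},
                         {'step': 100, 'save': '/tmp/step100.record'}])
    for step in range(200):
        stats = {'step': step, 'dummy_stat': 0.5 * step}   # stand-in for real measurements
        recorder.update(train_steps=step, measurments=stats)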
--------------------------------------------------------------------------------
/regime.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | from copy import deepcopy
4 | from six import string_types
5 |
6 |
7 | def eval_func(f, x):
8 | if isinstance(f, string_types):
9 | f = eval(f)
10 | return f(x)
11 |
12 |
13 | class Regime(object):
14 | """
15 | Examples for regime:
16 |
17 | 1) "[{'epoch': 0, 'optimizer': 'Adam', 'lr': 1e-3},
18 | {'epoch': 2, 'optimizer': 'Adam', 'lr': 5e-4},
19 | {'epoch': 4, 'optimizer': 'Adam', 'lr': 1e-4},
20 | {'epoch': 8, 'optimizer': 'Adam', 'lr': 5e-5}
21 | ]"
22 | 2)
23 | "[{'step_lambda':
24 | "lambda t: {
25 | 'optimizer': 'Adam',
26 | 'lr': 0.1 * min(t ** -0.5, t * 4000 ** -1.5),
27 | 'betas': (0.9, 0.98), 'eps':1e-9}
28 | }]"
29 | """
30 |
31 | def __init__(self, regime, defaults={}):
32 | self.regime = regime
33 | self.defaults = defaults
34 | self.reset(regime, defaults)
35 |
36 | def reset(self, regime=None, defaults=None):
37 | if regime is not None:
38 | self.regime = regime
39 | if defaults is not None:
40 | self.defaults = defaults
41 | self.current_regime_phase = None
42 | self.setting = self.defaults
43 |
44 | def update(self, epoch=None, train_steps=None):
45 | """adjusts according to current epoch or steps and regime.
46 | """
47 | if self.regime is None:
48 | return False
49 | epoch = -1 if epoch is None else epoch
50 | train_steps = -1 if train_steps is None else train_steps
51 | setting = deepcopy(self.setting)
52 | if self.current_regime_phase is None:
53 |             # Find the first phase whose start epoch/step has already been reached
54 | for regime_phase, regime_setting in enumerate(self.regime):
55 | start_epoch = regime_setting.get('epoch', 0)
56 | start_step = regime_setting.get('step', 0)
57 | if epoch >= start_epoch or train_steps >= start_step:
58 | self.current_regime_phase = regime_phase
59 | break
60 | # each entry is updated from previous
61 | setting.update(regime_setting)
62 | if len(self.regime) > self.current_regime_phase + 1:
63 | next_phase = self.current_regime_phase + 1
64 | # Any more regime steps?
65 | start_epoch = self.regime[next_phase].get('epoch', float('inf'))
66 | start_step = self.regime[next_phase].get('step', float('inf'))
67 | if epoch >= start_epoch or train_steps >= start_step:
68 | self.current_regime_phase = next_phase
69 | setting.update(self.regime[self.current_regime_phase])
70 |
71 | if 'lr_decay_rate' in setting and 'lr' in setting:
72 | decay_steps = setting.pop('lr_decay_steps', 100)
73 | if train_steps % decay_steps == 0:
74 | decay_rate = setting.pop('lr_decay_rate')
75 | setting['lr'] *= decay_rate ** (train_steps / decay_steps)
76 | elif 'step_lambda' in setting:
77 | setting.update(eval_func(setting.pop('step_lambda'), train_steps))
78 | elif 'epoch_lambda' in setting:
79 | setting.update(eval_func(setting.pop('epoch_lambda'), epoch))
80 |
81 | if 'execute' in setting:
82 | setting.pop('execute')()
83 |
84 | if 'execute_once' in setting:
85 | setting.pop('execute_once')()
86 | # remove from regime, so won't happen again
87 | self.regime[self.current_regime_phase].pop('execute_once', None)
88 |
89 | if setting == self.setting:
90 | return False
91 | else:
92 | self.setting = setting
93 | return True
94 |
95 | def __repr__(self):
96 | return 'Current: %s\n Regime:%s' % (self.setting, self.regime)
97 |
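A minimal driver for the first docstring example above; update returns True only when a new phase begins, at which point setting holds the merged configuration (applying it to an actual optimizer is out of scope here):

    from regime import Regime   # import path is illustrative; in-package it is relative

    regime = Regime([{'epoch': 0, 'optimizer': 'Adam', 'lr': 1e-3},
                     {'epoch': 2, 'optimizer': 'Adam', 'lr': 5e-4},
                     {'epoch': 4, 'optimizer': 'Adam', 'lr': 1e-4}])
    for epoch in range(6):
        if regime.update(epoch=epoch):
            print(epoch, regime.setting)   # prints at epochs 0, 2 and 4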
--------------------------------------------------------------------------------
/regularization.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .param_filter import FilterParameters, is_not_bn, is_not_bias
3 | from .absorb_bn import search_absorb_bn
4 | from torch.nn.utils import clip_grad_norm_
5 | import logging
6 |
7 |
8 | def sparsity(p):
9 | return float(p.eq(0).sum()) / p.nelement()
10 |
11 |
12 | def _norm(x, dim, p=2):
13 | """Computes the norm over all dimensions except dim"""
14 | if p == -1:
15 | def func(x, dim): return x.max(dim=dim)[0] - x.min(dim=dim)[0]
16 | elif p == float('inf'):
17 | def func(x, dim): return x.max(dim=dim)[0]
18 | else:
19 | def func(x, dim): return torch.norm(x, dim=dim, p=p)
20 | if dim is None:
21 | return x.norm(p=p)
22 | elif dim == 0:
23 | output_size = (x.size(0),) + (1,) * (x.dim() - 1)
24 | return func(x.contiguous().view(x.size(0), -1), 1).view(*output_size)
25 | elif dim == x.dim() - 1:
26 | output_size = (1,) * (x.dim() - 1) + (x.size(-1),)
27 | return func(x.contiguous().view(-1, x.size(-1)), 0).view(*output_size)
28 | else:
29 |         return _norm(x.transpose(0, dim), 0, p=p).transpose(0, dim)
30 |
31 |
32 | def _param_grad_norm(parameters):
33 | total_norm = 0
34 | for p in parameters:
35 | param_norm = p.grad.data.norm(2)
36 | total_norm += param_norm.item() ** 2
37 | return total_norm ** 0.5
38 |
39 |
40 | class Regularizer(object):
41 | def __init__(self, model, value=0, filter={}, log=False):
42 | self._model = model
43 | self._named_parameters = list(
44 | FilterParameters(model, **filter).named_parameters())
45 | self.value = value
46 | self.log = log
47 | if self.log:
48 | logging.debug('Applying regularization to parameters: %s',
49 | [n for n, _ in self._named_parameters])
50 |
51 | def named_parameters(self):
52 | for n, p in self._named_parameters:
53 | yield n, p
54 |
55 | def parameters(self):
56 | for _, p in self.named_parameters():
57 | yield p
58 |
59 | def _pre_parameter_step(self, parameter):
60 | pass
61 |
62 | def _post_parameter_step(self, parameter):
63 | pass
64 |
65 | def pre_step(self):
66 | pass
67 |
68 | def post_step(self):
69 | pass
70 |
71 | def pre_forward(self):
72 | pass
73 |
74 | def pre_backward(self):
75 | pass
76 |
77 | def params_with_grads(self, avoid_none_grads=False):
78 | return [(p, p.grad) for p in self.parameters()
79 | if (not avoid_none_grads or p.grad is not None)]
80 |
81 |
82 | class RegularizerList(Regularizer):
83 | def __init__(self, model, regularization_list):
84 | """each item must of of format (RegClass, **kwargs) or instance of Regularizer"""
85 | super(RegularizerList, self).__init__(model)
86 | self.regularization_list = []
87 | for regularizer in regularization_list:
88 | if not isinstance(regularizer, Regularizer):
89 | reg, reg_params = regularizer
90 | regularizer = reg(model=model, **reg_params)
91 | self.regularization_list.append(regularizer)
92 |
93 | def pre_step(self):
94 | for reg in self.regularization_list:
95 | reg.pre_step()
96 |
97 | def post_step(self):
98 | for reg in self.regularization_list:
99 | reg.post_step()
100 |
101 | def pre_forward(self):
102 | for reg in self.regularization_list:
103 | reg.pre_forward()
104 |
105 | def pre_backward(self):
106 | for reg in self.regularization_list:
107 | reg.pre_backward()
108 |
109 |
110 | class L2Regularization(Regularizer):
111 | def __init__(self, model, value=0,
112 | filter={'parameter_name': is_not_bias,
113 | 'module': is_not_bn},
114 | pre_op=True, post_op=False, **kwargs):
115 | super(L2Regularization, self).__init__(
116 | model, value, filter=filter, **kwargs)
117 | self.pre_op = pre_op
118 | self.post_op = post_op
119 |
120 | def pre_step(self):
121 | if self.pre_op:
122 | with torch.no_grad():
123 | params, grads = zip(*self.params_with_grads())
124 | torch._foreach_add_(grads, params, alpha=self.value)
125 | if self.log:
126 | logging.debug('L2 penalty of %s was applied pre optimization step',
127 | self.value)
128 |
129 | def post_step(self):
130 | if self.post_op:
131 | with torch.no_grad():
132 | torch._foreach_mul_(list(self.parameters()), 1. - self.value)
133 | if self.log:
134 | logging.debug('L2 penalty of %s was applied post optimization step',
135 | self.value)
136 |
137 |
138 | class WeightDecay(L2Regularization):
139 | def __init__(self, *kargs, **kwargs):
140 | super(WeightDecay, self).__init__(*kargs, **kwargs)
141 |
142 |
143 | class GradClip(Regularizer):
144 | def __init__(self, *kargs, **kwargs):
145 | super(GradClip, self).__init__(*kargs, **kwargs)
146 |
147 | def pre_step(self):
148 | if self.value > 0:
149 | with torch.no_grad():
150 | grad = clip_grad_norm_(self.parameters(), self.value)
151 | if self.log:
152 | logging.debug('Gradient norm value was clipped from %s to %s',
153 | grad, self.value)
154 |
155 |
156 | class GradSmooth(Regularizer):
157 | def __init__(self, model, value=True, momentum=0.9, filter={}, log=False):
158 | super(GradSmooth, self).__init__(model,
159 | value=value, filter=filter, log=log)
160 | self.momentum = momentum
161 | self.running_norm = None
162 | self.enabled = value
163 | self.counter = 0
164 |
165 | def pre_step(self):
166 | parameters = list(
167 | filter(lambda p: p.grad is not None, self.parameters()))
168 | total_norm = _param_grad_norm(parameters)
169 | if self.running_norm is None:
170 | self.running_norm = total_norm
171 | else:
172 |             self.running_norm = (self.momentum * self.running_norm
173 |                                  + (1 - self.momentum) * total_norm)
174 | if self.enabled:
175 | clip_coef = self.running_norm / (total_norm + 1e-6)
176 | for p in parameters:
177 | p.grad.data.mul_(clip_coef)
178 | if self.log:
179 |                 logging.debug('Gradient norm was rescaled from %s to %s',
180 | total_norm, self.running_norm)
181 |
182 | self.counter += 1
183 |
184 |
185 | class L1Regularization(Regularizer):
186 | def __init__(self, model, value=1e-3,
187 | filter={'parameter_name': is_not_bias,
188 | 'module': is_not_bn},
189 | pre_op=False, post_op=True, report_sparsity=False, **kwargs):
190 | super(L1Regularization, self).__init__(
191 | model, value, filter=filter, **kwargs)
192 | self.pre_op = pre_op
193 | self.post_op = post_op
194 | self.report_sparsity = report_sparsity
195 |
196 | def pre_step(self):
197 | if self.pre_op:
198 | with torch.no_grad():
199 | for n, p in self._named_parameters:
200 | p.grad.add_(p.sign(), alpha=self.value)
201 | if self.report_sparsity:
202 | logging.debug('Sparsity for %s is %s', n, sparsity(p))
203 | if self.log:
204 | logging.debug('L1 penalty of %s was applied pre optimization step',
205 | self.value)
206 |
207 | def post_step(self):
208 | if self.post_op:
209 | with torch.no_grad():
210 | for n, p in self._named_parameters:
211 | p.copy_(torch.nn.functional.softshrink(p, self.value))
212 | if self.report_sparsity:
213 | logging.debug('Sparsity for %s is %s', n, sparsity(p))
214 | if self.log:
215 | logging.debug('L1 penalty of %s was applied post optimization step',
216 | self.value)
217 |
218 |
219 | class BoundedWeightNorm(Regularizer):
220 | def __init__(self, model,
221 | filter={'parameter_name': is_not_bias,
222 | 'module': is_not_bn},
223 | dim=0, p=2, **kwargs):
224 | super(BoundedWeightNorm, self).__init__(
225 | model, 0, filter=filter, **kwargs)
226 | self.dim = dim
227 | self.init_norms = None
228 | self.p = p
229 |
230 | def _gather_init_norm(self):
231 | self.init_norms = {}
232 | with torch.no_grad():
233 | for n, p in self._named_parameters:
234 | self.init_norms[n] = _norm(
235 | p, self.dim, p=self.p).detach().mean()
236 |
237 | def pre_forward(self):
238 | if self.init_norms is None:
239 | self._gather_init_norm()
240 | with torch.no_grad():
241 | for n, p in self._named_parameters:
242 | init_norm = self.init_norms[n]
243 | new_norm = _norm(p, self.dim, p=self.p)
244 | p.mul_(init_norm / new_norm)
245 |
246 | def pre_step(self):
247 | for n, p in self._named_parameters:
248 | init_norm = self.init_norms[n]
249 | norm = _norm(p, self.dim, p=self.p)
250 | curr_grad = p.grad.data.clone()
251 | p.grad.data.zero_()
252 | p_normed = p * (init_norm / norm)
253 | p_normed.backward(curr_grad)
254 |
255 |
256 | class LARS(Regularizer):
257 | """Large Batch Training of Convolutional Networks - https://arxiv.org/abs/1708.03888
258 | """
259 |
260 | def __init__(self, model, value=0.01, weight_decay=0, dim=None, p=2, min_scale=None, max_scale=None,
261 | filter={'parameter_name': is_not_bias,
262 | 'module': is_not_bn},
263 | **kwargs):
264 | super(LARS, self).__init__(model, value, filter=filter, **kwargs)
265 | self.weight_decay = weight_decay
266 | self.dim = dim
267 | self.p = p
268 | self.min_scale = min_scale
269 | self.max_scale = max_scale
270 |
271 | def pre_step(self):
272 | with torch.no_grad():
273 | for _, param in self._named_parameters:
274 | param.grad.add_(param, alpha=self.weight_decay)
275 | if self.dim is not None:
276 | norm = _norm(param, dim=self.dim, p=self.p)
277 | grad_norm = _norm(param.grad, dim=self.dim, p=self.p)
278 | else:
279 | norm = param.norm(p=self.p)
280 | grad_norm = param.grad.norm(p=self.p)
281 | scale = self.value * norm / grad_norm
282 | if self.min_scale is not None or self.max_scale is not None:
283 | scale.clamp_(min=self.min_scale, max=self.max_scale)
284 | param.grad.mul_(scale)
285 |
286 |
287 | class DropConnect(Regularizer):
288 | def __init__(self, model, value=0,
289 | filter={'parameter_name': is_not_bias,
290 | 'module': is_not_bn},
291 | shakeshake=False, **kwargs):
292 | super(DropConnect, self).__init__(
293 | model, value=value, filter=filter, **kwargs)
294 | self.shakeshake = shakeshake
295 |
296 | def _drop_parameters(self):
297 | self.parameter_copy = {}
298 | with torch.no_grad():
299 | for n, p in self._named_parameters:
300 | self.parameter_copy[n] = p.clone()
301 | torch.nn.functional.dropout(p, self.value,
302 | training=True, inplace=True)
303 |
304 | def _reassign_parameters(self):
305 | with torch.no_grad():
306 | for n, p in self._named_parameters:
307 | p.copy_(self.parameter_copy.pop(n))
308 |
309 | def pre_forward(self):
310 | self._drop_parameters()
311 |
312 | def pre_backward(self):
313 | if self.shakeshake:
314 | self._reassign_parameters()
315 |
316 | def pre_step(self):
317 | if not self.shakeshake:
318 | self._reassign_parameters()
319 |
320 |
321 | class AbsorbBN(Regularizer):
322 | def __init__(self, model, remove_bn=False):
323 | self._model = model
324 | if not remove_bn:
325 | for m in model.modules():
326 | if isinstance(m, torch.nn.BatchNorm2d):
327 | m.momentum = 1
328 | self.remove_bn = remove_bn
329 | self._removed = False
330 |
331 | def pre_forward(self):
332 | if self._removed:
333 | return
334 | search_absorb_bn(self._model, remove_bn=self.remove_bn, verbose=False)
335 | self._removed = self.remove_bn
336 |
337 |
338 | class Consolidate(Regularizer):
339 | """ groups is a list of tuples, each tuple is of consolidated parameters (of same size)
340 | """
341 |
342 | def __init__(self, model, value=0.1, force=False, **kwargs):
343 | super(Consolidate, self).__init__(
344 | model, value, **kwargs)
345 | self.force = force
346 | shared_params = {}
347 | for name, param in self._named_parameters:
348 | numel = param.numel()
349 | shared_params.setdefault(numel, [])
350 | shared_params[numel].append((name, param))
351 |
352 | self.groups = []
353 | shared_names = []
354 | for shared_list in shared_params.values():
355 | if len(shared_list) < 2:
356 | continue
357 | names, params = zip(*shared_list)
358 | self.groups.append(params)
359 | shared_names.append(names)
360 |
361 | logging.debug('Shared parameter groups %s' % shared_names)
362 |
363 | def pre_step(self):
364 | with torch.no_grad():
365 | for i, group in enumerate(self.groups):
366 | mean_group = sum(group) / len(group)
367 | for p in group:
368 | p.grad.data.add_(p - mean_group, alpha=self.value)
369 | if self.log:
370 | logging.debug('group %s, diff norm = %s' %
371 | (i, float(sum([(p - mean_group).norm() ** 2 for p in group]).sqrt())))
372 |
373 | def post_step(self):
374 | if self.force:
375 | with torch.no_grad():
376 |                 for group in self.groups:
377 | mean_group = sum(group) / len(group)
378 | for p in group:
379 | p.copy_(mean_group)
380 |
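A sketch of where the Regularizer hooks are intended to sit in a training loop, inferred from the hook names used above; the toy model, data and hyper-parameters are illustrative assumptions:

    import torch
    from regularization import L2Regularization  # import path is illustrative; in-package it is relative

    model = torch.nn.Sequential(torch.nn.Linear(8, 4), torch.nn.ReLU(), torch.nn.Linear(4, 1))
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    reg = L2Regularization(model, value=1e-4)     # decays weights, skips biases / BN by default

    x, y = torch.randn(16, 8), torch.randn(16, 1)
    for _ in range(3):
        reg.pre_forward()                         # e.g. DropConnect would drop weights here
        loss = torch.nn.functional.mse_loss(model(x), y)
        reg.pre_backward()                        # e.g. shake-shake style restore happens here
        loss.backward()
        reg.pre_step()                            # L2: adds value * p to each selected gradient
        optimizer.step()
        reg.post_step()                           # no-op for L2 unless post_op=True
        optimizer.zero_grad()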
--------------------------------------------------------------------------------