├── .gitignore ├── LICENSE ├── __init__.py ├── absorb_bn.py ├── cross_entropy.py ├── dataset.py ├── functions.py ├── log.py ├── meters.py ├── misc.py ├── mixup.py ├── optim.py ├── param_filter.py ├── quantize.py ├── recorder.py ├── regime.py └── regularization.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | # sphinx-apidoc automatically generated documentation 72 | docs/source/packages 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don’t work, or not 91 | # install all needed dependencies. 
92 | #Pipfile.lock 93 | 94 | # celery beat schedule file 95 | celerybeat-schedule 96 | 97 | # SageMath parsed files 98 | *.sage.py 99 | 100 | # Environments 101 | .env 102 | .venv 103 | env/ 104 | venv/ 105 | ENV/ 106 | env.bak/ 107 | venv.bak/ 108 | 109 | # Spyder project settings 110 | .spyderproject 111 | .spyproject 112 | 113 | # Rope project settings 114 | .ropeproject 115 | 116 | # mkdocs documentation 117 | /site 118 | 119 | # mypy 120 | .mypy_cache/ 121 | .dmypy.json 122 | dmypy.json 123 | 124 | # Pyre type checker 125 | .pyre/ 126 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Elad Hoffer 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eladhoffer/utils.pytorch/7a2bbdb80835af0cc29c79653ee544f8301caa2b/__init__.py -------------------------------------------------------------------------------- /absorb_bn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import logging 4 | 5 | 6 | def remove_bn_params(bn_module): 7 | bn_module.register_buffer('running_mean', None) 8 | bn_module.register_buffer('running_var', None) 9 | bn_module.register_parameter('weight', None) 10 | bn_module.register_parameter('bias', None) 11 | 12 | 13 | def init_bn_params(bn_module): 14 | bn_module.running_mean.fill_(0) 15 | bn_module.running_var.fill_(1) 16 | 17 | 18 | def absorb_bn(module, bn_module, remove_bn=True, verbose=False): 19 | with torch.no_grad(): 20 | w = module.weight 21 | if module.bias is None: 22 | zeros = torch.zeros(module.out_channels, 23 | dtype=w.dtype, device=w.device) 24 | bias = nn.Parameter(zeros) 25 | module.register_parameter('bias', bias) 26 | b = module.bias 27 | 28 | if hasattr(bn_module, 'running_mean'): 29 | b.add_(-bn_module.running_mean) 30 | if hasattr(bn_module, 'running_var'): 31 | invstd = bn_module.running_var.clone().add_(bn_module.eps).pow_(-0.5) 32 | w.mul_(invstd.view(w.size(0), 1, 1, 1)) 33 | b.mul_(invstd) 34 | 35 | if remove_bn: 36 | if hasattr(bn_module, 'weight'): 37 | w.mul_(bn_module.weight.view(w.size(0), 1, 1, 1)) 38 | b.mul_(bn_module.weight) 39 | if hasattr(bn_module, 'bias'): 40 | b.add_(bn_module.bias) 41 | remove_bn_params(bn_module) 42 | else: 43 | init_bn_params(bn_module) 44 | 45 | if verbose: 46 | logging.info('BN module %s was asborbed into layer %s' % 47 | (bn_module, module)) 48 | 49 | 50 | def is_bn(m): 51 | return isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d) 52 | 53 | 54 | def is_absorbing(m): 55 | return isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear) 56 | 57 | 58 | def search_absorb_bn(model, prev=None, remove_bn=True, verbose=False): 59 | with torch.no_grad(): 60 | for n, m in model.named_children(): 61 | if is_bn(m) and is_absorbing(prev): 62 | absorb_bn(prev, m, remove_bn=remove_bn, verbose=verbose) 63 | if remove_bn: 64 | setattr(model, n, nn.Identity()) 65 | search_absorb_bn(m, remove_bn=remove_bn, verbose=verbose) 66 | prev = m 67 | -------------------------------------------------------------------------------- /cross_entropy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from .misc import onehot 6 | 7 | 8 | def _is_long(x): 9 | if hasattr(x, 'data'): 10 | x = x.data 11 | return isinstance(x, torch.LongTensor) or isinstance(x, torch.cuda.LongTensor) 12 | 13 | 14 | def cross_entropy(inputs, target, weight=None, ignore_index=-100, reduction='mean', 15 | smooth_eps=None, smooth_dist=None, from_logits=True): 16 | """cross entropy loss, with support for target distributions and label smoothing https://arxiv.org/abs/1512.00567""" 17 | smooth_eps = smooth_eps or 0 18 | 19 | # ordinary log-liklihood - use cross_entropy from nn 20 | if _is_long(target) and smooth_eps == 0: 21 | if from_logits: 22 | return F.cross_entropy(inputs, target, weight, ignore_index=ignore_index, reduction=reduction) 23 | else: 24 | return F.nll_loss(inputs, 
target, weight, ignore_index=ignore_index, reduction=reduction) 25 | 26 | if from_logits: 27 | # log-softmax of inputs 28 | lsm = F.log_softmax(inputs, dim=-1) 29 | else: 30 | lsm = inputs 31 | 32 | masked_indices = None 33 | num_classes = inputs.size(-1) 34 | 35 | if _is_long(target) and ignore_index >= 0: 36 | masked_indices = target.eq(ignore_index) 37 | 38 | if smooth_eps > 0 and smooth_dist is not None: 39 | if _is_long(target): 40 | target = onehot(target, num_classes).type_as(inputs) 41 | if smooth_dist.dim() < target.dim(): 42 | smooth_dist = smooth_dist.unsqueeze(0) 43 | target.lerp_(smooth_dist, smooth_eps) 44 | 45 | if weight is not None: 46 | lsm = lsm * weight.unsqueeze(0) 47 | 48 | if _is_long(target): 49 | eps_sum = smooth_eps / num_classes 50 | eps_nll = 1. - eps_sum - smooth_eps 51 | likelihood = lsm.gather(dim=-1, index=target.unsqueeze(-1)).squeeze(-1) 52 | loss = -(eps_nll * likelihood + eps_sum * lsm.sum(-1)) 53 | else: 54 | loss = -(target * lsm).sum(-1) 55 | 56 | if masked_indices is not None: 57 | loss.masked_fill_(masked_indices, 0) 58 | 59 | if reduction == 'sum': 60 | loss = loss.sum() 61 | elif reduction == 'mean': 62 | if masked_indices is None: 63 | loss = loss.mean() 64 | else: 65 | loss = loss.sum() / float(loss.size(0) - masked_indices.sum()) 66 | 67 | return loss 68 | 69 | 70 | class CrossEntropyLoss(nn.CrossEntropyLoss): 71 | """CrossEntropyLoss - with ability to recieve distrbution as targets, and optional label smoothing""" 72 | 73 | def __init__(self, weight=None, ignore_index=-100, reduction='mean', smooth_eps=None, smooth_dist=None, from_logits=True): 74 | super(CrossEntropyLoss, self).__init__(weight=weight, 75 | ignore_index=ignore_index, reduction=reduction) 76 | self.smooth_eps = smooth_eps 77 | self.smooth_dist = smooth_dist 78 | self.from_logits = from_logits 79 | 80 | def forward(self, input, target, smooth_dist=None): 81 | if smooth_dist is None: 82 | smooth_dist = self.smooth_dist 83 | return cross_entropy(input, target, weight=self.weight, ignore_index=self.ignore_index, 84 | reduction=self.reduction, smooth_eps=self.smooth_eps, 85 | smooth_dist=smooth_dist, from_logits=self.from_logits) 86 | 87 | 88 | def binary_cross_entropy(inputs, target, weight=None, reduction='mean', smooth_eps=None, from_logits=False): 89 | """cross entropy loss, with support for label smoothing https://arxiv.org/abs/1512.00567""" 90 | smooth_eps = smooth_eps or 0 91 | if smooth_eps > 0: 92 | target = target.float() 93 | target.add_(smooth_eps).div_(2.) 
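        # with smoothing, hard labels {0, 1} become {eps/2, (1 + eps)/2}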
94 | if from_logits: 95 | return F.binary_cross_entropy_with_logits(inputs, target, weight=weight, reduction=reduction) 96 | else: 97 | return F.binary_cross_entropy(inputs, target, weight=weight, reduction=reduction) 98 | 99 | 100 | def binary_cross_entropy_with_logits(inputs, target, weight=None, reduction='mean', smooth_eps=None, from_logits=True): 101 | return binary_cross_entropy(inputs, target, weight, reduction, smooth_eps, from_logits) 102 | 103 | 104 | class BCELoss(nn.BCELoss): 105 | def __init__(self, weight=None, size_average=None, reduce=None, reduction='mean', smooth_eps=None, from_logits=False): 106 | super(BCELoss, self).__init__(weight, size_average, reduce, reduction) 107 | self.smooth_eps = smooth_eps 108 | self.from_logits = from_logits 109 | 110 | def forward(self, input, target): 111 | return binary_cross_entropy(input, target, 112 | weight=self.weight, reduction=self.reduction, 113 | smooth_eps=self.smooth_eps, from_logits=self.from_logits) 114 | 115 | 116 | class BCEWithLogitsLoss(BCELoss): 117 | def __init__(self, weight=None, size_average=None, reduce=None, reduction='mean', smooth_eps=None, from_logits=True): 118 | super(BCEWithLogitsLoss, self).__init__(weight, size_average, 119 | reduce, reduction, smooth_eps=smooth_eps, from_logits=from_logits) 120 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | import pickle 3 | import PIL 4 | import torch 5 | from torch.utils.data import Dataset 6 | from torch.utils.data.sampler import Sampler, RandomSampler, BatchSampler, _int_classes 7 | from numpy.random import choice 8 | import csv 9 | from copy import copy 10 | import codecs 11 | from torch._utils import _accumulate 12 | from collections import Counter 13 | 14 | 15 | class RandomSamplerReplacment(torch.utils.data.sampler.Sampler): 16 | """Samples elements randomly, with replacement. 
17 | Arguments: 18 | data_source (Dataset): dataset to sample from 19 | """ 20 | 21 | def __init__(self, data_source): 22 | self.num_samples = len(data_source) 23 | 24 | def __iter__(self): 25 | return iter(torch.from_numpy(choice(self.num_samples, self.num_samples, replace=True))) 26 | 27 | def __len__(self): 28 | return self.num_samples 29 | 30 | class LongestLoader(torch.utils.data.dataloader.DataLoader): 31 | def __init__(self, *loaders): 32 | self.loaders = loaders 33 | self.num_workers = self.loaders[0].num_workers 34 | 35 | def __iter__(self): 36 | self.iterators = [iter(it) for it in self.loaders] 37 | max_length = len(self) 38 | for _ in range(max_length): 39 | values = [] 40 | for i, it in enumerate(self.iterators): 41 | try: 42 | value = next(it) 43 | except StopIteration: 44 | self.iterators[i] = iter(self.loaders[i]) 45 | value = next(self.iterators[i]) 46 | values.append(value) 47 | yield tuple(values) 48 | 49 | def __len__(self): 50 | return max([len(arg) for arg in self.loaders]) 51 | 52 | class LimitDataset(Dataset): 53 | 54 | def __init__(self, dset, max_len): 55 | self.dset = dset 56 | self.max_len = max_len 57 | 58 | def __len__(self): 59 | return min(len(self.dset), self.max_len) 60 | 61 | def __getitem__(self, index): 62 | return self.dset[index] 63 | 64 | 65 | class ByClassDataset(Dataset): 66 | 67 | def __init__(self, ds): 68 | self.dataset = ds 69 | self.idx_by_class = {} 70 | for idx, (_, c) in enumerate(ds): 71 | self.idx_by_class.setdefault(c, []) 72 | self.idx_by_class[c].append(idx) 73 | 74 | def __len__(self): 75 | return min([len(d) for d in self.idx_by_class.values()]) 76 | 77 | def __getitem__(self, idx): 78 | idx_per_class = [self.idx_by_class[c][idx] 79 | for c in range(len(self.idx_by_class))] 80 | labels = torch.LongTensor([self.dataset[i][1] 81 | for i in idx_per_class]) 82 | items = [self.dataset[i][0] for i in idx_per_class] 83 | if torch.is_tensor(items[0]): 84 | items = torch.stack(items) 85 | 86 | return (items, labels) 87 | 88 | 89 | class IdxDataset(Dataset): 90 | """docstring for IdxDataset.""" 91 | 92 | def __init__(self, dset): 93 | super(IdxDataset, self).__init__() 94 | self.dset = dset 95 | self.idxs = range(len(self.dset)) 96 | 97 | def __getitem__(self, idx): 98 | data, labels = self.dset[self.idxs[idx]] 99 | return (idx, data, labels) 100 | 101 | def __len__(self): 102 | return len(self.idxs) 103 | 104 | 105 | def image_loader(imagebytes): 106 | img = PIL.Image.open(BytesIO(imagebytes)) 107 | return img.convert('RGB') 108 | 109 | 110 | class IndexedFileDataset(Dataset): 111 | """ A dataset that consists of an indexed file (with sample offsets in 112 | another file). For example, a .tar that contains image files. 113 | The dataset does not extract the samples, but works with the indexed 114 | file directly. 115 | NOTE: The index file is assumed to be a pickled list of 3-tuples: 116 | (name, offset, size). 
117 | """ 118 | 119 | def __init__(self, filename, index_filename=None, extract_target_fn=None, 120 | transform=None, target_transform=None, loader=image_loader): 121 | super(IndexedFileDataset, self).__init__() 122 | 123 | # Defaults 124 | if index_filename is None: 125 | index_filename = filename + '.index' 126 | if extract_target_fn is None: 127 | extract_target_fn = lambda *args: args 128 | 129 | # Read index 130 | with open(index_filename, 'rb') as index_fp: 131 | sample_list = pickle.load(index_fp) 132 | 133 | # Collect unique targets (sorted by name) 134 | targetset = set(extract_target_fn(target) 135 | for target, _, _ in sample_list) 136 | targetmap = {target: i for i, target in enumerate(sorted(targetset))} 137 | 138 | self.samples = [(targetmap[extract_target_fn(target)], offset, size) 139 | for target, offset, size in sample_list] 140 | self.filename = filename 141 | 142 | self.loader = loader 143 | self.transform = transform 144 | self.target_transform = target_transform 145 | 146 | def _get_sample(self, fp, idx): 147 | target, offset, size = self.samples[idx] 148 | fp.seek(offset) 149 | sample = self.loader(fp.read(size)) 150 | 151 | if self.transform is not None: 152 | sample = self.transform(sample) 153 | if self.target_transform is not None: 154 | target = self.target_transform(target) 155 | 156 | return sample, target 157 | 158 | def __getitem__(self, index): 159 | with open(self.filename, 'rb') as fp: 160 | # Handle slices 161 | if isinstance(index, slice): 162 | return [self._get_sample(fp, subidx) for subidx in 163 | range(index.start or 0, index.stop or len(self), 164 | index.step or 1)] 165 | 166 | return self._get_sample(fp, index) 167 | 168 | def __len__(self): 169 | return len(self.samples) 170 | 171 | 172 | class DuplicateBatchSampler(Sampler): 173 | def __init__(self, sampler, batch_size, duplicates, drop_last): 174 | if not isinstance(sampler, Sampler): 175 | raise ValueError("sampler should be an instance of " 176 | "torch.utils.data.Sampler, but got sampler={}" 177 | .format(sampler)) 178 | if not isinstance(batch_size, _int_classes) or isinstance(batch_size, bool) or \ 179 | batch_size <= 0: 180 | raise ValueError("batch_size should be a positive integeral value, " 181 | "but got batch_size={}".format(batch_size)) 182 | if not isinstance(drop_last, bool): 183 | raise ValueError("drop_last should be a boolean value, but got " 184 | "drop_last={}".format(drop_last)) 185 | self.sampler = sampler 186 | self.batch_size = batch_size 187 | self.drop_last = drop_last 188 | self.duplicates = duplicates 189 | 190 | def __iter__(self): 191 | batch = [] 192 | for idx in self.sampler: 193 | batch.append(idx) 194 | if len(batch) == self.batch_size: 195 | yield batch * self.duplicates 196 | batch = [] 197 | if len(batch) > 0 and not self.drop_last: 198 | yield batch * self.duplicates 199 | 200 | def __len__(self): 201 | if self.drop_last: 202 | return len(self.sampler) // self.batch_size 203 | else: 204 | return (len(self.sampler) + self.batch_size - 1) // self.batch_size 205 | 206 | 207 | def list_line_locations(filename, limit=None): 208 | line_offset = [] 209 | offset = 0 210 | with open(filename, "rb") as f: 211 | for line in f: 212 | line_offset.append(offset) 213 | offset += len(line) 214 | if limit is not None and len(line_offset) > limit: 215 | break 216 | return line_offset 217 | 218 | 219 | def _load_or_create(filename, create_fn, cache=True, force_create=False): 220 | loaded = False 221 | if not force_create: 222 | try: 223 | with open(filename, 'rb') as fp: 224 
| value = pickle.load(fp) 225 | loaded = True 226 | except: 227 | pass 228 | if not loaded: 229 | value = create_fn() 230 | if cache and not loaded: 231 | with open(filename, 'wb') as fp: 232 | pickle.dump(value, fp) 233 | return value 234 | 235 | 236 | class LinedTextDataset(Dataset): 237 | """ Dataset in which every line is a seperate item (e.g translation) 238 | """ 239 | 240 | def __init__(self, filename, transform=None, cache=True): 241 | self.filename = filename 242 | self.transform = transform 243 | self.items = _load_or_create(filename + '_cached_lines', 244 | create_fn=lambda: list_line_locations( 245 | filename), 246 | cache=cache) 247 | 248 | def __getitem__(self, index): 249 | if isinstance(index, slice): 250 | return [self[idx] for idx in range(index.start or 0, index.stop or len(self), index.step or 1)] 251 | with codecs.open(self.filename, encoding='UTF-8') as f: 252 | f.seek(self.items[index]) 253 | item = f.readline() 254 | if self.transform is not None: 255 | item = self.transform(item) 256 | return item 257 | 258 | def __len__(self): 259 | return len(self.items) 260 | 261 | def select_range(self, start, end): 262 | new_dataset = copy(self) 263 | new_dataset.items = new_dataset.items[start:end] 264 | return new_dataset 265 | 266 | def filter(self, filter_func): 267 | new_dataset = copy(self) 268 | new_dataset.items = [item for item in self if filter_func(item)] 269 | return new_dataset 270 | 271 | def subset(self, indices): 272 | new_dataset = copy(self) 273 | new_dataset.items = [new_dataset.items[idx] for idx in indices] 274 | return new_dataset 275 | 276 | def split(self, lengths): 277 | """ 278 | split a dataset into non-overlapping new datasets of given lengths. 279 | Arguments: 280 | dataset (Dataset): Dataset to be split 281 | lengths (sequence): lengths of splits to be produced 282 | """ 283 | if sum(lengths) != len(self): 284 | raise ValueError( 285 | "Sum of input lengths does not equal the length of the input dataset!") 286 | 287 | return [self.select_range(offset-length, offset) for offset, length in zip(_accumulate(lengths), lengths)] 288 | 289 | def random_split(self, lengths): 290 | """ 291 | Randomly split a dataset into non-overlapping new datasets of given lengths. 
292 | Arguments: 293 | dataset (Dataset): Dataset to be split 294 | lengths (sequence): lengths of splits to be produced 295 | """ 296 | if sum(lengths) != len(self): 297 | raise ValueError( 298 | "Sum of input lengths does not equal the length of the input dataset!") 299 | 300 | indices = torch.randperm(sum(lengths)).tolist() 301 | return [self.subset(indices[offset - length:offset]) for offset, length in zip(_accumulate(lengths), lengths)] 302 | 303 | 304 | class CSVDataset(LinedTextDataset): 305 | """ Dataset with delimited items and pre-knwon fieldnames (no header) 306 | """ 307 | 308 | def __init__(self, filename, fieldnames=None, delimiter='\t', transform=None, cache=True): 309 | self.filename = filename 310 | self.fieldnames = fieldnames 311 | self.delimiter = delimiter 312 | self.transform = transform 313 | self.items = _load_or_create(filename + '_cached_lines', 314 | create_fn=lambda: list_line_locations( 315 | filename), 316 | cache=cache) 317 | 318 | def __getitem__(self, index): 319 | if isinstance(index, slice): 320 | return [self[idx] for idx in range(index.start or 0, index.stop or len(self), index.step or 1)] 321 | with codecs.open(self.filename, encoding='UTF-8') as f: 322 | f.seek(self.items[index]) 323 | item = f.readline() 324 | item = next(csv.DictReader([item], 325 | fieldnames=self.fieldnames, 326 | delimiter=self.delimiter)) 327 | if self.transform is not None: 328 | item = self.transform(item) 329 | return item 330 | 331 | def count_fields(self, fieldnames=None): 332 | fieldnames = fieldnames or self.fieldnames 333 | counters = {name: Counter() for name in fieldnames} 334 | for i in range(len(self)): 335 | value = self[i] 336 | for field in fieldnames: 337 | counters[field][value[field]] += 1 338 | return counters 339 | -------------------------------------------------------------------------------- /functions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd.function import Function 3 | 4 | class ScaleGrad(Function): 5 | 6 | @staticmethod 7 | def forward(ctx, input, scale): 8 | ctx.scale = scale 9 | return input 10 | 11 | @staticmethod 12 | def backward(ctx, grad_output): 13 | grad_input = ctx.scale * grad_output 14 | return grad_input, None 15 | 16 | 17 | def scale_grad(x, scale): 18 | return ScaleGrad().apply(x, scale) 19 | 20 | def negate_grad(x): 21 | return scale_grad(x, -1) 22 | -------------------------------------------------------------------------------- /log.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import os 3 | from itertools import cycle 4 | import torch 5 | import logging.config 6 | from datetime import datetime 7 | import json 8 | 9 | import pandas as pd 10 | from bokeh.io import output_file, save, show 11 | from bokeh.plotting import figure 12 | from bokeh.layouts import column 13 | from bokeh.models import Div 14 | 15 | try: 16 | import hyperdash 17 | HYPERDASH_AVAILABLE = True 18 | except ImportError: 19 | HYPERDASH_AVAILABLE = False 20 | 21 | 22 | def export_args_namespace(args, filename): 23 | """ 24 | args: argparse.Namespace 25 | arguments to save 26 | filename: string 27 | filename to save at 28 | """ 29 | with open(filename, 'w') as fp: 30 | json.dump(dict(args._get_kwargs()), fp, sort_keys=True, indent=4) 31 | 32 | 33 | def setup_logging(log_file='log.txt', resume=False, dummy=False): 34 | """ 35 | Setup logging configuration 36 | """ 37 | if dummy: 38 | logging.getLogger('dummy') 39 | return 40 | 
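    # append to an existing log file only when resuming; otherwise start a fresh one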
41 | file_mode = 'a' if os.path.isfile(log_file) and resume else 'w' 42 | 43 | root_logger = logging.getLogger() 44 | logging.basicConfig(level=logging.DEBUG, 45 | format="%(asctime)s - %(levelname)s - %(message)s", 46 | datefmt="%Y-%m-%d %H:%M:%S") 47 | # Remove all existing handlers (can't use the `force` option with 48 | # python < 3.8) 49 | for hdlr in root_logger.handlers[:]: 50 | root_logger.removeHandler(hdlr) 51 | # Add the handlers we want to use 52 | fileout = logging.FileHandler(log_file, mode=file_mode) 53 | fileout.setLevel(logging.DEBUG) 54 | fileout.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")) 55 | logging.getLogger().addHandler(fileout) 56 | console = logging.StreamHandler() 57 | console.setLevel(logging.INFO) 58 | console.setFormatter(logging.Formatter('%(message)s')) 59 | logging.getLogger().addHandler(console) 60 | 61 | 62 | def plot_figure(data, x, y, title=None, xlabel=None, ylabel=None, legend=None, 63 | x_axis_type='linear', y_axis_type='linear', 64 | width=800, height=400, line_width=2, 65 | colors=['red', 'green', 'blue', 'orange', 66 | 'black', 'purple', 'brown'], 67 | tools='pan,box_zoom,wheel_zoom,box_select,hover,reset,save', 68 | append_figure=None): 69 | """ 70 | creates a new plot figures 71 | example: 72 | plot_figure(x='epoch', y=['train_loss', 'val_loss'], 73 | 'title='Loss', 'ylabel'='loss') 74 | """ 75 | if not isinstance(y, list): 76 | y = [y] 77 | xlabel = xlabel or x 78 | legend = legend or y 79 | assert len(legend) == len(y) 80 | if append_figure is not None: 81 | f = append_figure 82 | else: 83 | f = figure(title=title, tools=tools, 84 | width=width, height=height, 85 | x_axis_label=xlabel or x, 86 | y_axis_label=ylabel or '', 87 | x_axis_type=x_axis_type, 88 | y_axis_type=y_axis_type) 89 | colors = cycle(colors) 90 | for i, yi in enumerate(y): 91 | f.line(data[x], data[yi], 92 | line_width=line_width, 93 | line_color=next(colors), legend_label=legend[i]) 94 | f.legend.click_policy = "hide" 95 | return f 96 | 97 | 98 | class ResultsLog(object): 99 | 100 | supported_data_formats = ['csv', 'json'] 101 | 102 | def __init__(self, path='', title='', params=None, resume=False, data_format='csv'): 103 | """ 104 | Parameters 105 | ---------- 106 | path: string 107 | path to directory to save data files 108 | plot_path: string 109 | path to directory to save plot files 110 | title: string 111 | title of HTML file 112 | params: Namespace 113 | optionally save parameters for results 114 | resume: bool 115 | resume previous logging 116 | data_format: str('csv'|'json') 117 | which file format to use to save the data 118 | """ 119 | if data_format not in ResultsLog.supported_data_formats: 120 | raise ValueError('data_format must of the following: ' + 121 | '|'.join(['{}'.format(k) for k in ResultsLog.supported_data_formats])) 122 | 123 | if data_format == 'json': 124 | self.data_path = '{}.json'.format(path) 125 | else: 126 | self.data_path = '{}.csv'.format(path) 127 | if params is not None: 128 | export_args_namespace(params, '{}.json'.format(path)) 129 | self.plot_path = '{}.html'.format(path) 130 | self.results = None 131 | self.clear() 132 | self.first_save = True 133 | if os.path.isfile(self.data_path): 134 | if resume: 135 | self.load(self.data_path) 136 | self.first_save = False 137 | else: 138 | os.remove(self.data_path) 139 | self.results = pd.DataFrame() 140 | else: 141 | self.results = pd.DataFrame() 142 | 143 | self.title = title 144 | self.data_format = data_format 145 | 146 | if HYPERDASH_AVAILABLE: 147 | name = 
self.title if title != '' else path 148 | self.hd_experiment = hyperdash.Experiment(name) 149 | if params is not None: 150 | for k, v in params._get_kwargs(): 151 | self.hd_experiment.param(k, v, log=False) 152 | 153 | def clear(self): 154 | self.figures = [] 155 | 156 | def add(self, **kwargs): 157 | """Add a new row to the dataframe 158 | example: 159 | resultsLog.add(epoch=epoch_num, train_loss=loss, 160 | test_loss=test_loss) 161 | """ 162 | df = pd.DataFrame([kwargs.values()], columns=kwargs.keys()) 163 | self.results = self.results.append(df, ignore_index=True) 164 | if hasattr(self, 'hd_experiment'): 165 | for k, v in kwargs.items(): 166 | self.hd_experiment.metric(k, v, log=False) 167 | 168 | def smooth(self, column_name, window): 169 | """Select an entry to smooth over time""" 170 | # TODO: smooth only new data 171 | smoothed_column = self.results[column_name].rolling( 172 | window=window, center=False).mean() 173 | self.results[column_name + '_smoothed'] = smoothed_column 174 | 175 | def save(self, title=None): 176 | """save the json file. 177 | Parameters 178 | ---------- 179 | title: string 180 | title of the HTML file 181 | """ 182 | title = title or self.title 183 | if len(self.figures) > 0: 184 | if os.path.isfile(self.plot_path): 185 | os.remove(self.plot_path) 186 | if self.first_save: 187 | self.first_save = False 188 | logging.info('Plot file saved at: {}'.format( 189 | os.path.abspath(self.plot_path))) 190 | 191 | output_file(self.plot_path, title=title) 192 | plot = column( 193 | Div(text='
<h1 align="center">{}</h1>
'.format(title)), *self.figures) 194 | save(plot) 195 | self.clear() 196 | 197 | if self.data_format == 'json': 198 | self.results.to_json(self.data_path, orient='records', lines=True) 199 | else: 200 | self.results.to_csv(self.data_path, index=False, index_label=False) 201 | 202 | def load(self, path=None): 203 | """load the data file 204 | Parameters 205 | ---------- 206 | path: 207 | path to load the json|csv file from 208 | """ 209 | path = path or self.data_path 210 | if os.path.isfile(path): 211 | if self.data_format == 'json': 212 | self.results.read_json(path) 213 | else: 214 | self.results.read_csv(path) 215 | else: 216 | raise ValueError('{} isn''t a file'.format(path)) 217 | 218 | def show(self, title=None): 219 | title = title or self.title 220 | if len(self.figures) > 0: 221 | plot = column( 222 | Div(text='
<h1 align="center">{}</h1>
'.format(title)), *self.figures) 223 | show(plot) 224 | 225 | def plot(self, *kargs, **kwargs): 226 | """ 227 | add a new plot to the HTML file 228 | example: 229 | results.plot(x='epoch', y=['train_loss', 'val_loss'], 230 | 'title='Loss', 'ylabel'='loss') 231 | """ 232 | f = plot_figure(self.results, *kargs, **kwargs) 233 | self.figures.append(f) 234 | 235 | def image(self, *kargs, **kwargs): 236 | fig = figure() 237 | fig.image(*kargs, **kwargs) 238 | self.figures.append(fig) 239 | 240 | def end(self): 241 | if hasattr(self, 'hd_experiment'): 242 | self.hd_experiment.end() 243 | 244 | 245 | def save_checkpoint(state, is_best, path='.', filename='checkpoint.pth.tar', save_all=False): 246 | filename = os.path.join(path, filename) 247 | torch.save(state, filename) 248 | if is_best: 249 | shutil.copyfile(filename, os.path.join(path, 'model_best.pth.tar')) 250 | if save_all: 251 | shutil.copyfile(filename, os.path.join( 252 | path, 'checkpoint_epoch_%s.pth.tar' % state['epoch'])) 253 | -------------------------------------------------------------------------------- /meters.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class AverageMeter(object): 5 | """Computes and stores the average and current value""" 6 | 7 | def __init__(self): 8 | self.reset() 9 | 10 | def reset(self): 11 | self.val = 0 12 | self.avg = 0 13 | self.sum = 0 14 | self.count = 0 15 | 16 | def update(self, val, n=1): 17 | self.val = val 18 | self.sum += val * n 19 | self.count += n 20 | self.avg = self.sum / self.count 21 | 22 | 23 | class OnlineMeter(object): 24 | """Computes and stores the average and variance/std values of tensor""" 25 | 26 | def __init__(self): 27 | self.mean = torch.FloatTensor(1).fill_(-1) 28 | self.M2 = torch.FloatTensor(1).zero_() 29 | self.count = 0. 30 | self.needs_init = True 31 | 32 | def reset(self, x): 33 | self.mean = x.new(x.size()).zero_() 34 | self.M2 = x.new(x.size()).zero_() 35 | self.count = 0. 
36 | self.needs_init = False 37 | 38 | def update(self, x): 39 | self.val = x 40 | if self.needs_init: 41 | self.reset(x) 42 | self.count += 1 43 | delta = x - self.mean 44 | self.mean.add_(delta / self.count) 45 | delta2 = x - self.mean 46 | self.M2.add_(delta * delta2) 47 | 48 | @property 49 | def var(self): 50 | if self.count < 2: 51 | return self.M2.clone().zero_() 52 | return self.M2 / (self.count - 1) 53 | 54 | @property 55 | def std(self): 56 | return self.var().sqrt() 57 | 58 | 59 | def accuracy(output, target, topk=(1,)): 60 | """Computes the precision@k for the specified values of k""" 61 | maxk = max(topk) 62 | batch_size = target.size(0) 63 | 64 | _, pred = output.topk(maxk, 1, True, True) 65 | pred = pred.t().type_as(target) 66 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 67 | 68 | res = [] 69 | for k in topk: 70 | correct_k = correct[:k].reshape(-1).float().sum(0) 71 | res.append(correct_k.mul_(100.0 / batch_size)) 72 | return res 73 | 74 | 75 | class AccuracyMeter(object): 76 | """Computes and stores the average and current topk accuracy""" 77 | 78 | def __init__(self, topk=(1,)): 79 | self.topk = topk 80 | self.reset() 81 | 82 | def reset(self): 83 | self._meters = {} 84 | for k in self.topk: 85 | self._meters[k] = AverageMeter() 86 | 87 | def update(self, output, target): 88 | n = target.nelement() 89 | acc_vals = accuracy(output, target, self.topk) 90 | for i, k in enumerate(self.topk): 91 | self._meters[k].update(acc_vals[i]) 92 | 93 | @property 94 | def val(self): 95 | return {n: meter.val for (n, meter) in self._meters.items()} 96 | 97 | @property 98 | def avg(self): 99 | return {n: meter.avg for (n, meter) in self._meters.items()} 100 | 101 | @property 102 | def avg_error(self): 103 | return {n: 100. - meter.avg for (n, meter) in self._meters.items()} 104 | -------------------------------------------------------------------------------- /misc.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | from torch.utils.checkpoint import checkpoint, checkpoint_sequential 6 | from contextlib import contextmanager 7 | from torch.nn.modules.batchnorm import _BatchNorm 8 | 9 | torch_dtypes = { 10 | 'float': torch.float, 11 | 'float32': torch.float32, 12 | 'float64': torch.float64, 13 | 'double': torch.double, 14 | 'float16': torch.float16, 15 | 'half': torch.half, 16 | 'uint8': torch.uint8, 17 | 'int8': torch.int8, 18 | 'int16': torch.int16, 19 | 'short': torch.short, 20 | 'int32': torch.int32, 21 | 'int': torch.int, 22 | 'int64': torch.int64, 23 | 'long': torch.long 24 | } 25 | 26 | 27 | def onehot(indexes, N=None, ignore_index=None): 28 | """ 29 | Creates a one-representation of indexes with N possible entries 30 | if N is not specified, it will suit the maximum index appearing. 
31 | indexes is a long-tensor of indexes 32 | ignore_index will be zero in onehot representation 33 | """ 34 | if N is None: 35 | N = indexes.max() + 1 36 | sz = list(indexes.size()) 37 | output = indexes.new().byte().resize_(*sz, N).zero_() 38 | output.scatter_(-1, indexes.unsqueeze(-1), 1) 39 | if ignore_index is not None and ignore_index >= 0: 40 | output.masked_fill_(indexes.eq(ignore_index).unsqueeze(-1), 0) 41 | return output 42 | 43 | 44 | @contextmanager 45 | def no_bn_update(model): 46 | prev_momentum = {} 47 | for name, module in model.named_modules(): 48 | if isinstance(module, nn.BatchNorm2d): 49 | prev_momentum[name] = module.momentum 50 | module.momentum = 0 51 | try: 52 | yield model 53 | finally: 54 | for name, module in model.named_modules(): 55 | if isinstance(module, nn.BatchNorm2d): 56 | module.momentum = prev_momentum[name] 57 | 58 | 59 | @contextmanager 60 | def calibrate_bn(model): 61 | prev_attributes = {} 62 | for name, module in model.named_modules(): 63 | if isinstance(module, _BatchNorm): 64 | prev_attributes[name] = prev_attributes.get(name, {}) 65 | prev_attributes[name]['momentum'] = module.momentum 66 | prev_attributes[name]['track_running_stats'] = module.track_running_stats 67 | prev_attributes[name]['training'] = module.training 68 | module.momentum = None 69 | module.track_running_stats = True 70 | module.reset_running_stats() 71 | module.train() 72 | try: 73 | yield model 74 | finally: 75 | for name, module in model.named_modules(): 76 | if isinstance(module, _BatchNorm): 77 | module.momentum = prev_attributes[name]['momentum'] 78 | module.track_running_stats = prev_attributes[name]['track_running_stats'] 79 | module.train(prev_attributes[name]['training']) 80 | 81 | 82 | 83 | def set_global_seeds(i): 84 | try: 85 | import torch 86 | except ImportError: 87 | pass 88 | else: 89 | torch.manual_seed(i) 90 | if torch.cuda.is_available(): 91 | torch.cuda.manual_seed_all(i) 92 | np.random.seed(i) 93 | random.seed(i) 94 | 95 | 96 | class CheckpointModule(nn.Module): 97 | def __init__(self, module, num_segments=1): 98 | super(CheckpointModule, self).__init__() 99 | assert num_segments == 1 or isinstance(module, nn.Sequential) 100 | self.module = module 101 | self.num_segments = num_segments 102 | 103 | def forward(self, x): 104 | if self.num_segments > 1: 105 | return checkpoint_sequential(self.module, self.num_segments, x) 106 | else: 107 | return checkpoint(self.module, x) 108 | -------------------------------------------------------------------------------- /mixup.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from numpy.random import beta 5 | from torch.nn.functional import one_hot 6 | 7 | 8 | class MixUp(nn.Module): 9 | def __init__(self, alpha=1., batch_dim=0): 10 | super(MixUp, self).__init__() 11 | self.batch_dim = batch_dim 12 | self.alpha = alpha 13 | self.reset() 14 | 15 | def reset(self): 16 | self.mix_value = None 17 | self.mix_index = None 18 | 19 | def mix(self, x1, x2): 20 | return x2.lerp(x1, self.mix_value) 21 | 22 | def sample(self, batch_size, alpha=None, even=False): 23 | alpha = self.alpha if alpha is None else alpha 24 | self.mix_index = torch.randperm(batch_size) 25 | self.mix_value = beta(alpha, alpha) 26 | if not even: 27 | self.mix_value = max(self.mix_value, 1 - self.mix_value) 28 | 29 | def mix_target(self, y, n_class): 30 | if not self.training or \ 31 | self.mix_value is None: 32 | return y 33 | y = one_hot(y, 
n_class).to(dtype=torch.float) 34 | idx = self.mix_index.to(device=y.device) 35 | y_mix = y.index_select(self.batch_dim, idx) 36 | return self.mix(y, y_mix) 37 | 38 | def forward(self, x): 39 | if not self.training or \ 40 | self.mix_value is None: 41 | return x 42 | idx = self.mix_index.to(device=x.device) 43 | x_mix = x.index_select(self.batch_dim, idx) 44 | return self.mix(x, x_mix) 45 | 46 | 47 | def rand_bbox(size, lam): 48 | W = size[2] 49 | H = size[3] 50 | cut_rat = np.sqrt(1. - lam) 51 | cut_w = np.int(W * cut_rat) 52 | cut_h = np.int(H * cut_rat) 53 | 54 | # uniform 55 | cx = np.random.randint(W) 56 | cy = np.random.randint(H) 57 | 58 | bbx1 = np.clip(cx - cut_w // 2, 0, W) 59 | bby1 = np.clip(cy - cut_h // 2, 0, H) 60 | bbx2 = np.clip(cx + cut_w // 2, 0, W) 61 | bby2 = np.clip(cy + cut_h // 2, 0, H) 62 | 63 | return bbx1, bby1, bbx2, bby2 64 | 65 | 66 | class CutMix(MixUp): 67 | def __init__(self, alpha=1., batch_dim=0): 68 | super(CutMix, self).__init__(alpha, batch_dim) 69 | 70 | def mix_image(self, x1, x2): 71 | assert not torch.is_tensor(self.mix_value) or \ 72 | self.mix_value.nelement() == 1 73 | lam = float(self.mix_value) 74 | bbx1, bby1, bbx2, bby2 = rand_bbox(x1.size(), lam) 75 | x1[:, :, bbx1:bbx2, bby1:bby2] = x2[:, :, bbx1:bbx2, bby1:bby2] 76 | lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / 77 | (x1.size()[-1] * x1.size()[-2])) 78 | self.mix_value = lam 79 | return x1 80 | 81 | def sample(self, batch_size, alpha=None): 82 | super(CutMix, self).sample(batch_size, alpha=alpha) 83 | 84 | def forward(self, x): 85 | if not self.training or \ 86 | self.mix_value is None: 87 | return x 88 | idx = self.mix_index.to(device=x.device) 89 | x_mix = x.index_select(self.batch_dim, idx) 90 | return self.mix_image(x, x_mix) 91 | -------------------------------------------------------------------------------- /optim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging.config 3 | import math 4 | from math import floor 5 | from copy import deepcopy 6 | from six import string_types 7 | from .regime import Regime 8 | from .param_filter import FilterParameters 9 | from . 
import regularization 10 | import torch.nn as nn 11 | from torch.optim.lr_scheduler import _LRScheduler 12 | 13 | _OPTIMIZERS = {name: func for name, func in torch.optim.__dict__.items()} 14 | _LRSCHEDULERS = {name: func for name, 15 | func in torch.optim.lr_scheduler.__dict__.items()} 16 | 17 | try: 18 | from adabound import AdaBound 19 | _OPTIMIZERS['AdaBound'] = AdaBound 20 | except ImportError: 21 | pass 22 | 23 | 24 | def cosine_anneal_lr(lr0, lrT, T, t0=0): 25 | return f"lambda t: {{'lr': {lrT} + {(lr0 - lrT)} * (1 + math.cos(math.pi * (t - {t0}) / {T-t0})) / 2}}" 26 | 27 | 28 | def linear_scale_lr(lr0, lrT, T, t0=0): 29 | rate = (lrT - lr0) / T 30 | return f"lambda t: {{'lr': max({lr0} + (t - {t0}) * {rate}, 0)}}" 31 | 32 | 33 | class _EmptySchedule(torch.optim.lr_scheduler._LRScheduler): 34 | def __init__(self, optimizer, last_epoch=-1): 35 | super(_EmptySchedule, self).__init__(optimizer, last_epoch=-1) 36 | self.last_epoch = 0 37 | 38 | def step(self, epoch=None): 39 | if epoch is None: 40 | epoch = self.last_epoch + 1 41 | 42 | 43 | def copy_params(param_target, param_src): 44 | with torch.no_grad(): 45 | for p_src, p_target in zip(param_src, param_target): 46 | p_target.copy_(p_src) 47 | 48 | 49 | def copy_params_grad(param_target, param_src): 50 | for p_src, p_target in zip(param_src, param_target): 51 | if p_target.grad is None: 52 | p_target.backward(p_src.grad.to(dtype=p_target.dtype)) 53 | else: 54 | p_target.grad.detach().copy_(p_src.grad) 55 | 56 | 57 | class ModuleFloatShadow(nn.Module): 58 | def __init__(self, module): 59 | super(ModuleFloatShadow, self).__init__() 60 | self.original_module = module 61 | self.float_module = deepcopy(module) 62 | self.float_module.to(dtype=torch.float) 63 | 64 | def parameters(self, *kargs, **kwargs): 65 | return self.float_module.parameters(*kargs, **kwargs) 66 | 67 | def named_parameters(self, *kargs, **kwargs): 68 | return self.float_module.named_parameters(*kargs, **kwargs) 69 | 70 | def modules(self, *kargs, **kwargs): 71 | return self.float_module.modules(*kargs, **kwargs) 72 | 73 | def named_modules(self, *kargs, **kwargs): 74 | return self.float_module.named_modules(*kargs, **kwargs) 75 | 76 | def original_parameters(self, *kargs, **kwargs): 77 | return self.original_module.parameters(*kargs, **kwargs) 78 | 79 | def original_named_parameters(self, *kargs, **kwargs): 80 | return self.original_module.named_parameters(*kargs, **kwargs) 81 | 82 | def original_modules(self, *kargs, **kwargs): 83 | return self.original_module.modules(*kargs, **kwargs) 84 | 85 | def original_named_modules(self, *kargs, **kwargs): 86 | return self.original_module.named_modules(*kargs, **kwargs) 87 | 88 | 89 | class OptimRegime(Regime, torch.optim.Optimizer): 90 | """ 91 | Reconfigures the optimizer according to setting list. 
92 | Exposes optimizer methods - state, step, zero_grad, add_param_group 93 | 94 | Examples for regime: 95 | 96 | 1) "[{'epoch': 0, 'optimizer': 'Adam', 'lr': 1e-3}, 97 | {'epoch': 2, 'optimizer': 'Adam', 'lr': 5e-4}, 98 | {'epoch': 4, 'optimizer': 'Adam', 'lr': 1e-4}, 99 | {'epoch': 8, 'optimizer': 'Adam', 'lr': 5e-5} 100 | ]" 101 | 2) 102 | "[{'step_lambda': 103 | "lambda t: { 104 | 'optimizer': 'Adam', 105 | 'lr': 0.1 * min(t ** -0.5, t * 4000 ** -1.5), 106 | 'betas': (0.9, 0.98), 'eps':1e-9} 107 | }]" 108 | """ 109 | 110 | def __init__(self, model, regime, defaults={}, filter=None, use_float_copy=False, log=True): 111 | super(OptimRegime, self).__init__(regime, defaults) 112 | if filter is not None: 113 | model = FilterParameters(model, **filter) 114 | if use_float_copy: 115 | model = ModuleFloatShadow(model) 116 | self._original_parameters = list(model.original_parameters()) 117 | 118 | self.parameters = list(model.parameters()) 119 | self.optimizer = torch.optim.SGD(self.parameters, lr=0) 120 | self.regularizer = regularization.Regularizer(model) 121 | self.use_float_copy = use_float_copy 122 | self.lr_scheduler = _EmptySchedule(self.optimizer, last_epoch=-1) 123 | self.schedule_time_frame = 'epoch' 124 | self.log = log 125 | 126 | def update(self, epoch=None, train_steps=None, metrics=None): 127 | """adjusts optimizer according to current epoch or steps and training regime. 128 | """ 129 | updated = False 130 | if super(OptimRegime, self).update(epoch, train_steps): 131 | self.adjust(self.setting) 132 | updated = True 133 | if self.schedule_time_frame == 'epoch': 134 | time = int(floor(epoch)) + 1 135 | elif self.schedule_time_frame == 'step': 136 | time = train_steps 137 | else: 138 | raise ValueError 139 | if (time != self.lr_scheduler.last_epoch) and \ 140 | getattr(self.optimizer, '_step_count', 0) > 0: 141 | prev_lr = self.get_lr()[0] 142 | if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): 143 | self.lr_scheduler.step(metrics) 144 | self.lr_scheduler.step() 145 | updated = True 146 | if prev_lr != self.get_lr()[0] and self.log: 147 | logging.debug('OPTIMIZER - lr scheduled = %s' 148 | % self.get_lr()[0]) 149 | return updated 150 | 151 | def adjust(self, setting): 152 | """adjusts optimizer according to a setting dict. 
153 | e.g: setting={optimizer': 'Adam', 'lr': 5e-4} 154 | """ 155 | reset = setting.get('reset', False) 156 | if 'optimizer' in setting or reset: 157 | optim_method = _OPTIMIZERS[setting.get('optimizer', 'SGD')] 158 | if reset: # reset the optimizer cache: 159 | self.optimizer = torch.optim.SGD(self.parameters, lr=0) 160 | if self.log: 161 | logging.debug('OPTIMIZER - reset setting') 162 | if not isinstance(self.optimizer, optim_method): 163 | self.optimizer = optim_method(self.optimizer.param_groups) 164 | if self.log: 165 | logging.debug('OPTIMIZER - setting method = %s' % 166 | setting['optimizer']) 167 | for param_group in self.optimizer.param_groups: 168 | for key in param_group.keys(): 169 | if key in setting: 170 | new_val = setting[key] 171 | if new_val != param_group[key]: 172 | if self.log: 173 | logging.debug('OPTIMIZER - setting %s = %s' % 174 | (key, setting[key])) 175 | param_group[key] = setting[key] 176 | if key == 'lr': 177 | param_group['initial_lr'] = param_group['lr'] 178 | base_lrs = list(map(lambda group: group['lr'], 179 | self.optimizer.param_groups)) 180 | self.lr_scheduler.base_lrs = base_lrs 181 | 182 | # fix for AdaBound 183 | if hasattr(self.optimizer, 'base_lrs'): 184 | self.optimizer.base_lrs = base_lrs 185 | 186 | if 'regularizer' in setting: 187 | reg_list = deepcopy(setting['regularizer']) 188 | if not (isinstance(reg_list, list) or isinstance(reg_list, tuple)): 189 | reg_list = (reg_list,) 190 | regularizers = [] 191 | for reg in reg_list: 192 | if isinstance(reg, dict): 193 | name = reg.pop('name') 194 | regularizers.append((regularization.__dict__[name], reg)) 195 | elif isinstance(reg, regularization.Regularizer): 196 | regularizers.append(reg) 197 | else: # callable on model 198 | regularizers.append(reg(self.regularizer._model)) 199 | self.regularizer = regularization.RegularizerList(self.regularizer._model, 200 | regularizers) 201 | 202 | if 'lr_scheduler' in setting: 203 | schedule_config = setting['lr_scheduler'] 204 | if isinstance(schedule_config, _LRScheduler): 205 | self.lr_scheduler = schedule_config 206 | elif isinstance(schedule_config, dict): 207 | name = schedule_config.pop('name') 208 | self.schedule_time_frame = schedule_config.pop('time_frame', 209 | 'epoch') 210 | schedule_config['last_epoch'] = self.lr_scheduler.last_epoch 211 | self.lr_scheduler = _LRSCHEDULERS[name](self.optimizer, 212 | **schedule_config) 213 | elif schedule_config is None: 214 | self.lr_scheduler = _EmptySchedule(self.optimizer, 215 | last_epoch=self.lr_scheduler.last_epoch) 216 | else: # invalid config 217 | raise NotImplementedError 218 | 219 | def __getstate__(self): 220 | return { 221 | 'optimizer_state': self.optimizer.__getstate__(), 222 | 'regime': self.regime, 223 | } 224 | 225 | def __setstate__(self, state): 226 | self.regime = state.get('regime') 227 | self.optimizer.__setstate__(state.get('optimizer_state')) 228 | 229 | def state_dict(self): 230 | """Returns the state of the optimizer as a :class:`dict`. 231 | """ 232 | return self.optimizer.state_dict() 233 | 234 | def load_state_dict(self, state_dict): 235 | """Loads the optimizer state. 236 | 237 | Arguments: 238 | state_dict (dict): optimizer state. Should be an object returned 239 | from a call to :meth:`state_dict`. 
240 | """ 241 | # deepcopy, to be consistent with module API 242 | self.optimizer.load_state_dict(state_dict) 243 | 244 | def zero_grad(self): 245 | """Clears the gradients of all optimized :class:`Variable` s.""" 246 | self.optimizer.zero_grad() 247 | if self.use_float_copy: 248 | for p in self._original_parameters: 249 | if p.grad is not None: 250 | p.grad.detach().zero_() 251 | 252 | def pre_step(self): 253 | if self.use_float_copy: 254 | copy_params_grad(self.parameters, self._original_parameters) 255 | self.regularizer.pre_step() 256 | 257 | def post_step(self): 258 | self.regularizer.post_step() 259 | 260 | if self.use_float_copy: 261 | copy_params(self._original_parameters, self.parameters) 262 | 263 | def step(self, *args, **kwargs): 264 | """Performs a single optimization step (parameter update). 265 | """ 266 | self.pre_step() 267 | self.optimizer.step(*args, **kwargs) 268 | self.post_step() 269 | 270 | def pre_forward(self): 271 | """ allows modification pre-forward pass - e.g for regularization 272 | """ 273 | self.regularizer.pre_forward() 274 | 275 | def pre_backward(self): 276 | """ allows modification post-forward pass and pre-backward - e.g for regularization 277 | """ 278 | self.regularizer.pre_backward() 279 | 280 | def get_value(self, key): 281 | return [group[key] for group in self.optimizer.param_groups] 282 | 283 | def get_lr(self): 284 | return self.get_value('lr') 285 | 286 | @property 287 | def state(self): 288 | return self.optimizer.state 289 | 290 | 291 | class MultiOptimRegime(OptimRegime): 292 | 293 | def __init__(self, *optim_regime_list, log=True): 294 | self.optim_regime_list = [] 295 | for optim_regime in optim_regime_list: 296 | assert isinstance(optim_regime, OptimRegime) 297 | self.optim_regime_list.append(optim_regime) 298 | self.log = log 299 | 300 | def update(self, epoch=None, train_steps=None): 301 | """adjusts optimizer according to current epoch or steps and training regime. 302 | """ 303 | updated = False 304 | for i, optim in enumerate(self.optim_regime_list): 305 | current_updated = optim.update(epoch, train_steps) 306 | if current_updated and self.log: 307 | logging.debug('OPTIMIZER #%s was updated' % i) 308 | updated = updated or current_updated 309 | return updated 310 | 311 | def zero_grad(self): 312 | """Clears the gradients of all optimized :class:`Variable` s.""" 313 | for optim in self.optim_regime_list: 314 | optim.zero_grad() 315 | 316 | def step(self): 317 | """Performs a single optimization step (parameter update). 
318 | """ 319 | for optim in self.optim_regime_list: 320 | optim.step() 321 | 322 | def pre_forward(self): 323 | for optim in self.optim_regime_list: 324 | optim.pre_forward() 325 | 326 | def pre_backward(self): 327 | for optim in self.optim_regime_list: 328 | optim.pre_backward() 329 | 330 | def __repr__(self): 331 | return str([str(optim) for optim in self.optim_regime_list]) 332 | 333 | def get_value(self, key): 334 | return [[group[key] for group in optim.optimizer.param_groups] 335 | for optim in self.optim_regime_list] 336 | 337 | def get_lr(self): 338 | return self.get_value('lr') 339 | -------------------------------------------------------------------------------- /param_filter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import re 4 | 5 | 6 | def _search_pattern_fn(pattern): 7 | def _search(name): 8 | return re.search(pattern, name) is not None 9 | return _search 10 | 11 | 12 | def _search_type_pattern_fn(pattern): 13 | def _search(var): 14 | return re.search(pattern, type(var).__name__) is not None 15 | return _search 16 | 17 | 18 | def is_not_bias(name): 19 | return not name.endswith('bias') 20 | 21 | 22 | def is_bn(module): 23 | return isinstance(module, nn.modules.batchnorm._BatchNorm) 24 | 25 | 26 | def is_not_bn(module): 27 | return not is_bn(module) 28 | 29 | 30 | def _negate_fn(fn): 31 | if fn is None: 32 | return None 33 | else: 34 | def _negate(*kargs, **kwargs): 35 | return not fn(*kargs, **kwargs) 36 | return _negate 37 | 38 | 39 | def filtered_parameter_info(model, module_fn=None, module_name_fn=None, parameter_name_fn=None, memo=None): 40 | if memo is None: 41 | memo = set() 42 | 43 | for module_name, module in model.named_modules(): 44 | if module_fn is not None and not module_fn(module): 45 | continue 46 | if module_name_fn is not None and not module_name_fn(module_name): 47 | continue 48 | for parameter_name, param in module.named_parameters(prefix=module_name, recurse=False): 49 | if parameter_name_fn is not None and not parameter_name_fn(parameter_name): 50 | continue 51 | if param not in memo: 52 | memo.add(param) 53 | yield {'named_module': (module_name, module), 'named_parameter': (parameter_name, param)} 54 | 55 | 56 | class FilterParameters(object): 57 | def __init__(self, source, module=None, module_name=None, parameter_name=None, exclude=False): 58 | if isinstance(module_name, str): 59 | module_name = _search_pattern_fn(module_name) 60 | if isinstance(parameter_name, str): 61 | parameter_name = _search_pattern_fn(parameter_name) 62 | if isinstance(module, str): 63 | module = _search_type_pattern_fn(module) 64 | if exclude: 65 | module_name = _negate_fn(module_name) 66 | parameter_name = _negate_fn(parameter_name) 67 | module = _negate_fn(module) 68 | if isinstance(source, FilterParameters): 69 | self._filtered_parameter_info = list(source.filter( 70 | module=module, 71 | module_name=module_name, 72 | parameter_name=parameter_name)) 73 | elif isinstance(source, torch.nn.Module): # source is a model 74 | self._filtered_parameter_info = list(filtered_parameter_info(source, 75 | module_fn=module, 76 | module_name_fn=module_name, 77 | parameter_name_fn=parameter_name)) 78 | 79 | def named_parameters(self): 80 | for p in self._filtered_parameter_info: 81 | yield p['named_parameter'] 82 | 83 | def parameters(self): 84 | for _, p in self.named_parameters(): 85 | yield p 86 | 87 | def filter(self, module=None, module_name=None, parameter_name=None): 88 | for p_info in 
self._filtered_parameter_info: 89 | if (module is None or module(p_info['named_module'][1]) 90 | and (module_name is None or module_name(p_info['named_module'][0])) 91 | and (parameter_name is None or parameter_name(p_info['named_parameter'][0]))): 92 | yield p_info 93 | 94 | def named_modules(self): 95 | for m in self._filtered_parameter_info: 96 | yield m['named_module'] 97 | 98 | def modules(self): 99 | for _, m in self.named_modules(): 100 | yield m 101 | 102 | def to(self, *kargs, **kwargs): 103 | for m in self.modules(): 104 | m.to(*kargs, **kwargs) 105 | 106 | 107 | class FilterModules(FilterParameters): 108 | pass 109 | 110 | 111 | if __name__ == '__main__': 112 | from torchvision.models import resnet50 113 | model = resnet50() 114 | filterd_params = FilterParameters(model, 115 | module=lambda m: isinstance( 116 | m, torch.nn.Linear), 117 | parameter_name=lambda n: 'bias' in n) 118 | -------------------------------------------------------------------------------- /quantize.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import torch 3 | import torch.nn as nn 4 | 5 | QTensor = namedtuple('QTensor', ['tensor', 'scale', 'zero_point']) 6 | 7 | 8 | def quantize_tensor(x, num_bits=8): 9 | qmin = 0. 10 | qmax = 2.**num_bits - 1. 11 | min_val, max_val = x.min(), x.max() 12 | 13 | scale = (max_val - min_val) / (qmax - qmin) 14 | 15 | initial_zero_point = qmin - min_val / scale 16 | 17 | zero_point = 0 18 | if initial_zero_point < qmin: 19 | zero_point = qmin 20 | elif initial_zero_point > qmax: 21 | zero_point = qmax 22 | else: 23 | zero_point = initial_zero_point 24 | 25 | zero_point = int(zero_point) 26 | q_x = zero_point + x / scale 27 | q_x.clamp_(qmin, qmax).round_() 28 | q_x = q_x.round().byte() 29 | return QTensor(tensor=q_x, scale=scale, zero_point=zero_point) 30 | 31 | 32 | def dequantize_tensor(q_x): 33 | return q_x.scale * (q_x.tensor.float() - q_x.zero_point) 34 | 35 | 36 | def quantize_model(model): 37 | qparams = {} 38 | 39 | for n, p in model.state_dict().items(): 40 | qp = quantize_tensor(p) 41 | qparams[n + '.quantization.scale'] = torch.FloatTensor([qp.scale]) 42 | qparams[ 43 | n + '.quantization.zero_point'] = torch.ByteTensor([qp.zero_point]) 44 | p.copy_(qp.tensor) 45 | model.type('torch.ByteTensor') 46 | for n, p in qparams.items(): 47 | model.register_buffer(n, p) 48 | model.quantized = True 49 | 50 | 51 | def dequantize_model(model): 52 | model.float() 53 | params = model.state_dict() 54 | for n, p in params.items(): 55 | if 'quantization' not in n: 56 | qp = QTensor(tensor=p, 57 | scale=params[n + '.quantization.scale'][0], 58 | zero_point=params[n + '.quantization.zero_point'][0]) 59 | p.copy_(dequantize_tensor(qp)) 60 | model.register_buffer(n + '.quantization.scale', None) 61 | model.register_buffer(n + '.quantization.zero_point', None) 62 | model.quantized = None 63 | -------------------------------------------------------------------------------- /recorder.py: -------------------------------------------------------------------------------- 1 | from .regime import Regime 2 | import logging.config 3 | from os import path 4 | import torch 5 | 6 | 7 | class Recorder(Regime): 8 | def __init__(self, regime, defaults={}): 9 | self.regime = regime 10 | self.current_regime_phase = None 11 | self.setting = defaults 12 | self.measurments = None 13 | 14 | def get_steps(self): 15 | return [item['step'] for item in self.regime] 16 | 17 | @staticmethod 18 | def load(filename, drop_items=[]): 19 | 
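        # returns None when the measurements file does not exist (FileNotFoundError is swallowed)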
try: 20 | measurments = torch.load( 21 | filename, map_location='cpu') 22 | for item in drop_items: 23 | measurments.pop(item) 24 | return measurments 25 | except FileNotFoundError: 26 | return None 27 | 28 | def update(self, train_steps=None, measurments={}): 29 | """adjusts optimizer according to current epoch or steps and training regime. 30 | """ 31 | updated = False 32 | if super(Recorder, self).update(train_steps=train_steps): 33 | save_file = self.setting.get('save', None) 34 | if save_file is not None: 35 | # filename = path.join(self.file_prefix, f'{train_steps}.record') 36 | torch.save(measurments, save_file) 37 | logging.debug(f'Saved measurments to {save_file}') 38 | load_file = self.setting.get('load', None) 39 | if load_file is not None: 40 | logging.debug(f'Loaded measurments from {load_file}') 41 | self.measurments = self.load(load_file, 42 | drop_items=self.setting.get('drop_items', [])) 43 | 44 | updated = True 45 | return updated 46 | -------------------------------------------------------------------------------- /regime.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | from copy import deepcopy 4 | from six import string_types 5 | 6 | 7 | def eval_func(f, x): 8 | if isinstance(f, string_types): 9 | f = eval(f) 10 | return f(x) 11 | 12 | 13 | class Regime(object): 14 | """ 15 | Examples for regime: 16 | 17 | 1) "[{'epoch': 0, 'optimizer': 'Adam', 'lr': 1e-3}, 18 | {'epoch': 2, 'optimizer': 'Adam', 'lr': 5e-4}, 19 | {'epoch': 4, 'optimizer': 'Adam', 'lr': 1e-4}, 20 | {'epoch': 8, 'optimizer': 'Adam', 'lr': 5e-5} 21 | ]" 22 | 2) 23 | "[{'step_lambda': 24 | "lambda t: { 25 | 'optimizer': 'Adam', 26 | 'lr': 0.1 * min(t ** -0.5, t * 4000 ** -1.5), 27 | 'betas': (0.9, 0.98), 'eps':1e-9} 28 | }]" 29 | """ 30 | 31 | def __init__(self, regime, defaults={}): 32 | self.regime = regime 33 | self.defaults = defaults 34 | self.reset(regime, defaults) 35 | 36 | def reset(self, regime=None, defaults=None): 37 | if regime is not None: 38 | self.regime = regime 39 | if defaults is not None: 40 | self.defaults = defaults 41 | self.current_regime_phase = None 42 | self.setting = self.defaults 43 | 44 | def update(self, epoch=None, train_steps=None): 45 | """adjusts according to current epoch or steps and regime. 46 | """ 47 | if self.regime is None: 48 | return False 49 | epoch = -1 if epoch is None else epoch 50 | train_steps = -1 if train_steps is None else train_steps 51 | setting = deepcopy(self.setting) 52 | if self.current_regime_phase is None: 53 | # Find the first entry where the epoch is smallest than current 54 | for regime_phase, regime_setting in enumerate(self.regime): 55 | start_epoch = regime_setting.get('epoch', 0) 56 | start_step = regime_setting.get('step', 0) 57 | if epoch >= start_epoch or train_steps >= start_step: 58 | self.current_regime_phase = regime_phase 59 | break 60 | # each entry is updated from previous 61 | setting.update(regime_setting) 62 | if len(self.regime) > self.current_regime_phase + 1: 63 | next_phase = self.current_regime_phase + 1 64 | # Any more regime steps? 
65 | start_epoch = self.regime[next_phase].get('epoch', float('inf')) 66 | start_step = self.regime[next_phase].get('step', float('inf')) 67 | if epoch >= start_epoch or train_steps >= start_step: 68 | self.current_regime_phase = next_phase 69 | setting.update(self.regime[self.current_regime_phase]) 70 | 71 | if 'lr_decay_rate' in setting and 'lr' in setting: 72 | decay_steps = setting.pop('lr_decay_steps', 100) 73 | if train_steps % decay_steps == 0: 74 | decay_rate = setting.pop('lr_decay_rate') 75 | setting['lr'] *= decay_rate ** (train_steps / decay_steps) 76 | elif 'step_lambda' in setting: 77 | setting.update(eval_func(setting.pop('step_lambda'), train_steps)) 78 | elif 'epoch_lambda' in setting: 79 | setting.update(eval_func(setting.pop('epoch_lambda'), epoch)) 80 | 81 | if 'execute' in setting: 82 | setting.pop('execute')() 83 | 84 | if 'execute_once' in setting: 85 | setting.pop('execute_once')() 86 | # remove from regime, so won't happen again 87 | self.regime[self.current_regime_phase].pop('execute_once', None) 88 | 89 | if setting == self.setting: 90 | return False 91 | else: 92 | self.setting = setting 93 | return True 94 | 95 | def __repr__(self): 96 | return 'Current: %s\n Regime:%s' % (self.setting, self.regime) 97 | -------------------------------------------------------------------------------- /regularization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .param_filter import FilterParameters, is_not_bn, is_not_bias 3 | from .absorb_bn import search_absorb_bn 4 | from torch.nn.utils import clip_grad_norm_ 5 | import logging 6 | 7 | 8 | def sparsity(p): 9 | return float(p.eq(0).sum()) / p.nelement() 10 | 11 | 12 | def _norm(x, dim, p=2): 13 | """Computes the norm over all dimensions except dim""" 14 | if p == -1: 15 | def func(x, dim): return x.max(dim=dim)[0] - x.min(dim=dim)[0] 16 | elif p == float('inf'): 17 | def func(x, dim): return x.max(dim=dim)[0] 18 | else: 19 | def func(x, dim): return torch.norm(x, dim=dim, p=p) 20 | if dim is None: 21 | return x.norm(p=p) 22 | elif dim == 0: 23 | output_size = (x.size(0),) + (1,) * (x.dim() - 1) 24 | return func(x.contiguous().view(x.size(0), -1), 1).view(*output_size) 25 | elif dim == x.dim() - 1: 26 | output_size = (1,) * (x.dim() - 1) + (x.size(-1),) 27 | return func(x.contiguous().view(-1, x.size(-1)), 0).view(*output_size) 28 | else: 29 | return _norm(x.transpose(0, dim), 0).transpose(0, dim) 30 | 31 | 32 | def _param_grad_norm(parameters): 33 | total_norm = 0 34 | for p in parameters: 35 | param_norm = p.grad.data.norm(2) 36 | total_norm += param_norm.item() ** 2 37 | return total_norm ** 0.5 38 | 39 | 40 | class Regularizer(object): 41 | def __init__(self, model, value=0, filter={}, log=False): 42 | self._model = model 43 | self._named_parameters = list( 44 | FilterParameters(model, **filter).named_parameters()) 45 | self.value = value 46 | self.log = log 47 | if self.log: 48 | logging.debug('Applying regularization to parameters: %s', 49 | [n for n, _ in self._named_parameters]) 50 | 51 | def named_parameters(self): 52 | for n, p in self._named_parameters: 53 | yield n, p 54 | 55 | def parameters(self): 56 | for _, p in self.named_parameters(): 57 | yield p 58 | 59 | def _pre_parameter_step(self, parameter): 60 | pass 61 | 62 | def _post_parameter_step(self, parameter): 63 | pass 64 | 65 | def pre_step(self): 66 | pass 67 | 68 | def post_step(self): 69 | pass 70 | 71 | def pre_forward(self): 72 | pass 73 | 74 | def pre_backward(self): 75 | pass 76 | 77 | 
def params_with_grads(self, avoid_none_grads=False):
78 |         return [(p, p.grad) for p in self.parameters()
79 |                 if (not avoid_none_grads or p.grad is not None)]
80 |
81 |
82 | class RegularizerList(Regularizer):
83 |     def __init__(self, model, regularization_list):
84 |         """each item must be of the form (RegClass, kwargs_dict) or an instance of Regularizer"""
85 |         super(RegularizerList, self).__init__(model)
86 |         self.regularization_list = []
87 |         for regularizer in regularization_list:
88 |             if not isinstance(regularizer, Regularizer):
89 |                 reg, reg_params = regularizer
90 |                 regularizer = reg(model=model, **reg_params)
91 |             self.regularization_list.append(regularizer)
92 |
93 |     def pre_step(self):
94 |         for reg in self.regularization_list:
95 |             reg.pre_step()
96 |
97 |     def post_step(self):
98 |         for reg in self.regularization_list:
99 |             reg.post_step()
100 |
101 |     def pre_forward(self):
102 |         for reg in self.regularization_list:
103 |             reg.pre_forward()
104 |
105 |     def pre_backward(self):
106 |         for reg in self.regularization_list:
107 |             reg.pre_backward()
108 |
109 |
110 | class L2Regularization(Regularizer):
111 |     def __init__(self, model, value=0,
112 |                  filter={'parameter_name': is_not_bias,
113 |                          'module': is_not_bn},
114 |                  pre_op=True, post_op=False, **kwargs):
115 |         super(L2Regularization, self).__init__(
116 |             model, value, filter=filter, **kwargs)
117 |         self.pre_op = pre_op
118 |         self.post_op = post_op
119 |
120 |     def pre_step(self):
121 |         if self.pre_op:
122 |             with torch.no_grad():
123 |                 params, grads = zip(*self.params_with_grads())
124 |                 torch._foreach_add_(grads, params, alpha=self.value)
125 |             if self.log:
126 |                 logging.debug('L2 penalty of %s was applied pre optimization step',
127 |                               self.value)
128 |
129 |     def post_step(self):
130 |         if self.post_op:
131 |             with torch.no_grad():
132 |                 torch._foreach_mul_(list(self.parameters()), 1. - self.value)
133 |             if self.log:
134 |                 logging.debug('L2 penalty of %s was applied post optimization step',
135 |                               self.value)
136 |
137 |
138 | class WeightDecay(L2Regularization):
139 |     def __init__(self, *kargs, **kwargs):
140 |         super(WeightDecay, self).__init__(*kargs, **kwargs)
141 |
142 |
143 | class GradClip(Regularizer):
144 |     def __init__(self, *kargs, **kwargs):
145 |         super(GradClip, self).__init__(*kargs, **kwargs)
146 |
147 |     def pre_step(self):
148 |         if self.value > 0:
149 |             with torch.no_grad():
150 |                 grad = clip_grad_norm_(self.parameters(), self.value)
151 |             if self.log:
152 |                 logging.debug('Gradient norm value was clipped from %s to %s',
153 |                               grad, self.value)
154 |
155 |
156 | class GradSmooth(Regularizer):
157 |     def __init__(self, model, value=True, momentum=0.9, filter={}, log=False):
158 |         super(GradSmooth, self).__init__(model,
159 |                                          value=value, filter=filter, log=log)
160 |         self.momentum = momentum
161 |         self.running_norm = None
162 |         self.enabled = value
163 |         self.counter = 0
164 |
165 |     def pre_step(self):
166 |         parameters = list(
167 |             filter(lambda p: p.grad is not None, self.parameters()))
168 |         total_norm = _param_grad_norm(parameters)
169 |         if self.running_norm is None:
170 |             self.running_norm = total_norm
171 |         else:
172 |             self.running_norm = (self.momentum * self.running_norm
173 |                                  + (1 - self.momentum) * total_norm)
174 |         if self.enabled:
175 |             clip_coef = self.running_norm / (total_norm + 1e-6)
176 |             for p in parameters:
177 |                 p.grad.data.mul_(clip_coef)
178 |             if self.log:
179 |                 logging.debug('Gradient norm value was rescaled from %s to %s',
180 |                               total_norm, self.running_norm)
181 |
182 |         self.counter += 1
183 |
184 |
185 | class L1Regularization(Regularizer):
186 |     def __init__(self, model, value=1e-3,
187 |                  filter={'parameter_name': is_not_bias,
188 |                          'module': is_not_bn},
189 |                  pre_op=False, post_op=True, report_sparsity=False, **kwargs):
190 |         super(L1Regularization, self).__init__(
191 |             model, value, filter=filter, **kwargs)
192 |         self.pre_op = pre_op
193 |         self.post_op = post_op
194 |         self.report_sparsity = report_sparsity
195 |
196 |     def pre_step(self):
197 |         if self.pre_op:
198 |             with torch.no_grad():
199 |                 for n, p in self._named_parameters:
200 |                     p.grad.add_(p.sign(), alpha=self.value)
201 |                     if self.report_sparsity:
202 |                         logging.debug('Sparsity for %s is %s', n, sparsity(p))
203 |             if self.log:
204 |                 logging.debug('L1 penalty of %s was applied pre optimization step',
205 |                               self.value)
206 |
207 |     def post_step(self):
208 |         if self.post_op:
209 |             with torch.no_grad():
210 |                 for n, p in self._named_parameters:
211 |                     p.copy_(torch.nn.functional.softshrink(p, self.value))
212 |                     if self.report_sparsity:
213 |                         logging.debug('Sparsity for %s is %s', n, sparsity(p))
214 |             if self.log:
215 |                 logging.debug('L1 penalty of %s was applied post optimization step',
216 |                               self.value)
217 |
218 |
219 | class BoundedWeightNorm(Regularizer):
220 |     def __init__(self, model,
221 |                  filter={'parameter_name': is_not_bias,
222 |                          'module': is_not_bn},
223 |                  dim=0, p=2, **kwargs):
224 |         super(BoundedWeightNorm, self).__init__(
225 |             model, 0, filter=filter, **kwargs)
226 |         self.dim = dim
227 |         self.init_norms = None
228 |         self.p = p
229 |
230 |     def _gather_init_norm(self):
231 |         self.init_norms = {}
232 |         with torch.no_grad():
233 |             for n, p in self._named_parameters:
234 |                 self.init_norms[n] = _norm(
235 |                     p, self.dim, p=self.p).detach().mean()
236 |
237 |     def pre_forward(self):
238 |         if self.init_norms is None:
239 |             self._gather_init_norm()
240 |         with torch.no_grad():
241 | for n, p in self._named_parameters: 242 | init_norm = self.init_norms[n] 243 | new_norm = _norm(p, self.dim, p=self.p) 244 | p.mul_(init_norm / new_norm) 245 | 246 | def pre_step(self): 247 | for n, p in self._named_parameters: 248 | init_norm = self.init_norms[n] 249 | norm = _norm(p, self.dim, p=self.p) 250 | curr_grad = p.grad.data.clone() 251 | p.grad.data.zero_() 252 | p_normed = p * (init_norm / norm) 253 | p_normed.backward(curr_grad) 254 | 255 | 256 | class LARS(Regularizer): 257 | """Large Batch Training of Convolutional Networks - https://arxiv.org/abs/1708.03888 258 | """ 259 | 260 | def __init__(self, model, value=0.01, weight_decay=0, dim=None, p=2, min_scale=None, max_scale=None, 261 | filter={'parameter_name': is_not_bias, 262 | 'module': is_not_bn}, 263 | **kwargs): 264 | super(LARS, self).__init__(model, value, filter=filter, **kwargs) 265 | self.weight_decay = weight_decay 266 | self.dim = dim 267 | self.p = p 268 | self.min_scale = min_scale 269 | self.max_scale = max_scale 270 | 271 | def pre_step(self): 272 | with torch.no_grad(): 273 | for _, param in self._named_parameters: 274 | param.grad.add_(param, alpha=self.weight_decay) 275 | if self.dim is not None: 276 | norm = _norm(param, dim=self.dim, p=self.p) 277 | grad_norm = _norm(param.grad, dim=self.dim, p=self.p) 278 | else: 279 | norm = param.norm(p=self.p) 280 | grad_norm = param.grad.norm(p=self.p) 281 | scale = self.value * norm / grad_norm 282 | if self.min_scale is not None or self.max_scale is not None: 283 | scale.clamp_(min=self.min_scale, max=self.max_scale) 284 | param.grad.mul_(scale) 285 | 286 | 287 | class DropConnect(Regularizer): 288 | def __init__(self, model, value=0, 289 | filter={'parameter_name': is_not_bias, 290 | 'module': is_not_bn}, 291 | shakeshake=False, **kwargs): 292 | super(DropConnect, self).__init__( 293 | model, value=value, filter=filter, **kwargs) 294 | self.shakeshake = shakeshake 295 | 296 | def _drop_parameters(self): 297 | self.parameter_copy = {} 298 | with torch.no_grad(): 299 | for n, p in self._named_parameters: 300 | self.parameter_copy[n] = p.clone() 301 | torch.nn.functional.dropout(p, self.value, 302 | training=True, inplace=True) 303 | 304 | def _reassign_parameters(self): 305 | with torch.no_grad(): 306 | for n, p in self._named_parameters: 307 | p.copy_(self.parameter_copy.pop(n)) 308 | 309 | def pre_forward(self): 310 | self._drop_parameters() 311 | 312 | def pre_backward(self): 313 | if self.shakeshake: 314 | self._reassign_parameters() 315 | 316 | def pre_step(self): 317 | if not self.shakeshake: 318 | self._reassign_parameters() 319 | 320 | 321 | class AbsorbBN(Regularizer): 322 | def __init__(self, model, remove_bn=False): 323 | self._model = model 324 | if not remove_bn: 325 | for m in model.modules(): 326 | if isinstance(m, torch.nn.BatchNorm2d): 327 | m.momentum = 1 328 | self.remove_bn = remove_bn 329 | self._removed = False 330 | 331 | def pre_forward(self): 332 | if self._removed: 333 | return 334 | search_absorb_bn(self._model, remove_bn=self.remove_bn, verbose=False) 335 | self._removed = self.remove_bn 336 | 337 | 338 | class Consolidate(Regularizer): 339 | """ groups is a list of tuples, each tuple is of consolidated parameters (of same size) 340 | """ 341 | 342 | def __init__(self, model, value=0.1, force=False, **kwargs): 343 | super(Consolidate, self).__init__( 344 | model, value, **kwargs) 345 | self.force = force 346 | shared_params = {} 347 | for name, param in self._named_parameters: 348 | numel = param.numel() 349 | 
shared_params.setdefault(numel, [])
350 |             shared_params[numel].append((name, param))
351 |
352 |         self.groups = []
353 |         shared_names = []
354 |         for shared_list in shared_params.values():
355 |             if len(shared_list) < 2:
356 |                 continue
357 |             names, params = zip(*shared_list)
358 |             self.groups.append(params)
359 |             shared_names.append(names)
360 |
361 |         logging.debug('Shared parameter groups %s' % shared_names)
362 |
363 |     def pre_step(self):
364 |         with torch.no_grad():
365 |             for i, group in enumerate(self.groups):
366 |                 mean_group = sum(group) / len(group)
367 |                 for p in group:
368 |                     p.grad.data.add_(p - mean_group, alpha=self.value)
369 |                 if self.log:
370 |                     logging.debug('group %s, diff norm = %s' %
371 |                                   (i, float(sum([(p - mean_group).norm() ** 2 for p in group]).sqrt())))
372 |
373 |     def post_step(self):
374 |         if self.force:
375 |             with torch.no_grad():
376 |                 for group in self.groups:
377 |                     mean_group = sum(group) / len(group)
378 |                     for p in group:
379 |                         p.copy_(mean_group)
380 |
--------------------------------------------------------------------------------
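
Usage sketch (illustrative, not part of the repository): the Regularizer classes in regularization.py are meant to be driven from a training loop through their pre_forward / pre_backward / pre_step / post_step hooks. The snippet below shows one plausible wiring; the import path `utils.regularization`, the model, the batch and the optimizer are all placeholder assumptions supplied here for the example only.

import torch
from utils.regularization import RegularizerList, WeightDecay, GradClip  # package name assumed

model = torch.nn.Linear(10, 2)                        # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = torch.nn.CrossEntropyLoss()
regularizer = RegularizerList(model, [(WeightDecay, {'value': 1e-4}),
                                      (GradClip, {'value': 5.0})])

inputs = torch.randn(8, 10)                           # placeholder batch
targets = torch.randint(0, 2, (8,))

optimizer.zero_grad()
regularizer.pre_forward()        # weight-side hooks (e.g. DropConnect, BoundedWeightNorm)
loss = criterion(model(inputs), targets)
regularizer.pre_backward()       # e.g. shake-shake variant of DropConnect restores weights here
loss.backward()
regularizer.pre_step()           # gradient-side hooks: L2 penalty, gradient clipping, LARS scaling
optimizer.step()
regularizer.post_step()          # post-update hooks: L1 soft-shrinkage, forced consolidation

RegularizerList lets several regularizers share the same hook calls; in this sketch WeightDecay adds the L2 term to the gradients in pre_step and GradClip clips the gradient norm in the same hook.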