├── images └── editable.png ├── .gitmodules ├── lib ├── __init__.py ├── utils │ ├── __init__.py │ ├── loader.py │ ├── copy_and_replace.py │ ├── init.py │ ├── learning_rate.py │ ├── gumbel.py │ ├── regularize.py │ ├── data.py │ ├── basic.py │ ├── ingraph_update.py │ └── trainer.py ├── loss.py ├── evaluate.py ├── trainer.py └── editable.py ├── mt ├── train.py ├── requirements.txt ├── README.md ├── prepare_iwslt14_de_en.sh ├── train.sh ├── evaluate.ipynb ├── generate_edit_datasets_samples.ipynb ├── edited_generate.py └── fairseq_criterion.py ├── requirements.txt ├── LICENSE ├── .gitignore ├── README.md └── notebooks ├── resnet.py ├── imagenet_preprocess_logits.ipynb ├── cifar10_editable_layer3.ipynb ├── imagenet_editable_training.ipynb ├── imagenet_evaluate_nae.ipynb └── imagenet_editable_training_with_natural_distribution.ipynb /images/editable.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xtinkt/editable/HEAD/images/editable.png -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "mt/fairseq"] 2 | path = mt/fairseq 3 | url = https://github.com/pytorch/fairseq.git 4 | -------------------------------------------------------------------------------- /lib/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import * 2 | from .editable import * 3 | from .trainer import * 4 | from .evaluate import * 5 | from .loss import * 6 | -------------------------------------------------------------------------------- /mt/train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("fairseq") 3 | import fairseq_criterion 4 | from train import cli_main() 5 | 6 | 7 | cli_main() 8 | -------------------------------------------------------------------------------- /mt/requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.1.0 2 | numpy>=0.13 3 | scipy>=1.2.0 4 | torchvision 5 | tqdm 6 | tensorboardX 7 | prefetch_generator 8 | mosestokenizer 9 | fairseq==0.8.0 10 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.1.0 2 | numpy>=0.13 3 | scipy>=1.2.0 4 | scikit-learn>=0.17 5 | torchvision 6 | matplotlib 7 | tqdm 8 | tensorboardX 9 | pandas 10 | prefetch_generator 11 | -------------------------------------------------------------------------------- /lib/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .basic import * 2 | from .data import * 3 | from .gumbel import * 4 | from .init import * 5 | from .regularize import * 6 | from .learning_rate import * 7 | from .ingraph_update import * 8 | from .copy_and_replace import * 9 | from .trainer import * 10 | from .loader import * 11 | -------------------------------------------------------------------------------- /mt/README.md: -------------------------------------------------------------------------------- 1 | ## Get started 2 | 1. Get fairseq submodule: `(git submodule init) && (git submodule update)` 3 | 2. Install fairseq: `cd fairseq && pip install --editable .` 4 | 3. Prepare iwslt14.de-en: `./prepare_iwslt14_de_en.sh` 5 | 4. 
Prepare edit samples: run notebook `generate_edit_datasets_samples.ipynb` 6 | 7 | ## Training 8 | Train Editable Transformer: `./train.sh` 9 | 10 | ## Evaluating 11 | Evaluate trained model run notebook: `evaluate.ipynb` 12 | -------------------------------------------------------------------------------- /mt/prepare_iwslt14_de_en.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Download and prepare the data 4 | cd fairseq/examples/translation/ 5 | bash prepare-iwslt14.sh 6 | cd ../../.. 7 | 8 | # Preprocess/binarize the data 9 | TEXT=fairseq/examples/translation/iwslt14.tokenized.de-en 10 | fairseq-preprocess --source-lang de --target-lang en \ 11 | --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \ 12 | --destdir data-bin/iwslt14.tokenized.de-en \ 13 | --workers 20 14 | -------------------------------------------------------------------------------- /mt/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python3 train.py data-bin/iwslt14.tokenized.de-en \ 4 | --arch transformer_iwslt_de_en --share-decoder-input-output-embed \ 5 | --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \ 6 | --lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 \ 7 | --dropout 0.3 --weight-decay 0.0001 --criterion editable_training_criterion \ 8 | --label-smoothing 0.1 --max-tokens 4096 --tensorboard-logdir train_editable_logs \ 9 | --edit-samples-path edit_iwslt14.tokenized.de-en/bpe_train.txt \ 10 | --save-dir checkpoints_editable \ 11 | --stability-coeff 100 \ 12 | --editability-coeff 100 13 | -------------------------------------------------------------------------------- /lib/utils/loader.py: -------------------------------------------------------------------------------- 1 | import torchvision.datasets as datasets 2 | import numpy as np 3 | import torch 4 | 5 | ''' 6 | Dataset loader that returns batch of images and precomputed logits 7 | ''' 8 | 9 | class ImageAndLogitsFolder(datasets.ImageFolder): 10 | 11 | def __init__(self, *args, logits_prefix, **kwargs): 12 | super().__init__(*args, **kwargs) 13 | self.logits_prefix = logits_prefix 14 | 15 | @staticmethod 16 | def logits_path_create(prefix, path): 17 | return prefix + path[path.rfind("/") + 1:path.find(".j")] + ".npy" 18 | 19 | def get_image_path(self, index): 20 | return self.imgs[index] 21 | 22 | def __getitem__(self, index): 23 | img, target = super().__getitem__(index) 24 | 25 | logits_path = ImageAndLogitsFolder.logits_path_create( 26 | self.logits_prefix, 27 | self.get_image_path(index)[0] 28 | ) 29 | logits = np.load(logits_path) 30 | logits = torch.reshape(torch.Tensor(logits), (1, -1)) 31 | 32 | return img, target, logits 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Anton Sinitsin 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be 
included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # node and NPM 2 | npm-debug.log 3 | node_modules 4 | 5 | # swap files 6 | *~ 7 | *.swp 8 | 9 | notebooks/data/* 10 | notebooks/runs/* 11 | notebooks/.ipynb_checkpoints/* 12 | 13 | env.sh 14 | # Byte-compiled / optimized / DLL files 15 | __pycache__/ 16 | *.py[cod] 17 | 18 | # C extensions 19 | *.so 20 | 21 | # Distribution / packaging 22 | .Python 23 | env/ 24 | bin/ 25 | build/ 26 | develop-eggs/ 27 | dist/ 28 | eggs/ 29 | lib64/ 30 | parts/ 31 | sdist/ 32 | var/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg/ 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .coverage 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | 49 | # Translations 50 | *.mo 51 | 52 | # Mr Developer 53 | .mr.developer.cfg 54 | .project 55 | .pydevproject 56 | .idea 57 | .ipynb_checkpoints 58 | 59 | # Rope 60 | .ropeproject 61 | 62 | # Django stuff: 63 | *.log 64 | *.pot 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | docs/tmp* 69 | 70 | # OS X garbage 71 | .DS_Store 72 | 73 | # Debian things 74 | debian/reproducible-experiment-platform 75 | debian/files 76 | *.substvars 77 | *.debhelper.log 78 | 79 | .vscode 80 | 81 | #pytorch model weights 82 | *.pth 83 | -------------------------------------------------------------------------------- /lib/utils/copy_and_replace.py: -------------------------------------------------------------------------------- 1 | """ A convenience function for creating modified copies out-of-place with deepcopy """ 2 | from contextlib import contextmanager 3 | from copy import deepcopy 4 | 5 | DEFAULT_MEMO = dict() 6 | 7 | 8 | def copy_and_replace(original, replace=None, do_not_copy=None): 9 | """ 10 | :param original: object to be copied 11 | :param replace: a dictionary {old object -> new object}, replace all occurrences of old object with new object 12 | :param do_not_copy: a sequence of objects that will not be copied (but may be replaced) 13 | :return: a copy of original with replacements 14 | """ 15 | replace, do_not_copy = replace or {}, do_not_copy or {} 16 | memo = dict(DEFAULT_MEMO) 17 | for item in do_not_copy: 18 | memo[id(item)] = item 19 | 20 | for item, replacement in replace.items(): 21 | memo[id(item)] = replacement 22 | 23 | return deepcopy(original, memo) 24 | 25 | 26 | @contextmanager 27 | def do_not_copy(*items): 28 | """ all calls to copy_and_replace within this context won't copy items (but can replace them) """ 29 | global DEFAULT_MEMO 30 | keys_to_remove = [] 31 | for item in items: 32 | key = id(item) 33 | if key not in DEFAULT_MEMO:  # only pin items that are not already registered, so we can un-pin them afterwards 34 | DEFAULT_MEMO[key] = item 35 | keys_to_remove.append(key) 36 | 37 | yield 38 | 39 | for key in keys_to_remove: 40 | DEFAULT_MEMO.pop(key) 41 |
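Below is a minimal usage sketch for `copy_and_replace` and `do_not_copy` (not a file from this repo; it assumes the repository root is on `PYTHONPATH` and uses a hypothetical toy `nn.Sequential`). It shows how `replace` swaps a submodule in the copy while `do_not_copy` keeps a submodule shared between the original and the copy:

```python
import torch.nn as nn
from lib.utils.copy_and_replace import copy_and_replace, do_not_copy

# hypothetical toy model, just for illustration
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# build an edited copy: swap the first layer for a fresh one, share (do not deep-copy) the last layer
new_first = nn.Linear(4, 8)
edited = copy_and_replace(model, replace={model[0]: new_first}, do_not_copy=[model[2]])

assert edited[0] is new_first        # replaced in the copy
assert edited[2] is model[2]         # shared with the original
assert edited[1] is not model[1]     # everything else is deep-copied

# the same sharing can be enabled for every copy_and_replace call inside a context
with do_not_copy(model[2]):
    another_copy = copy_and_replace(model)
assert another_copy[2] is model[2]
```

This out-of-place copying is what `Editable.edit` relies on later in the repo: edited models are new objects, so the original parameters are never overwritten.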
-------------------------------------------------------------------------------- /lib/loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | Loss functions 3 | """ 4 | 5 | import torch 6 | from torch import nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | def kl_distill_loss(logits, ref_probs): 11 | """ 12 | Kullback-Leibler divergence between reference probabilities and model predictions 13 | """ 14 | return F.kl_div(F.log_softmax(logits, dim=-1), ref_probs) 15 | 16 | 17 | def contrastive_cross_entropy(logits, target, margin=0.0): 18 | """ 19 | A special loss that is similar to cross-entropy but becomes exactly zero if 20 | logp(target) >= max(logp(all_excluding_target)) + margin 21 | Used for classification edits 22 | """ 23 | logp = F.log_softmax(logits, dim=-1) 24 | target_one_hot = F.one_hot(target, num_classes=logp.shape[-1]) 25 | logp_target = (logp * target_one_hot.to(logits.dtype)).sum(-1) 26 | logp_others = torch.where(target_one_hot.to(torch.uint8), torch.full_like(logp, -float('inf')), logp) 27 | return F.relu(margin + logp_others.max(dim=-1)[0] - logp_target).mean() 28 | 29 | 30 | def threshold_mse(predictions, targets, threshold=0.0, reduction_axes=None): 31 | """ 32 | Like mean squared error but becomes exactly zero if 33 | the sum of squared errors along reduction axes is below threshold 34 | Used for regression edits 35 | """ 36 | squared_error = (predictions - targets) ** 2 37 | if reduction_axes is not None: 38 | squared_error = squared_error.sum(reduction_axes) 39 | return F.relu(squared_error - threshold).mean() 40 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Editable neural networks 2 | 3 | Supplementary code for [Editable Neural Networks](https://openreview.net/forum?id=HJedXaEtvS), an ICLR 2020 submission by Anton Sinitsin, Vsevolod Plokhotnyuk, Dmitry Pyrkin, Sergei Popov, Artem Babenko. 4 | 5 | 6 | 7 | # What does it do? 8 | 9 | It trains a model so that it can later be edited: forced to predict a specific class on a specific input without losing accuracy on the rest of the data. 10 | 11 | # What do I need to run it? 12 | * A machine with some CPU (preferably 2+ free cores) and GPU(s) 13 | * Running without GPU is possible but does not scale well, especially for ImageNet 14 | * Some popular Linux x64 distribution 15 | * Tested on Ubuntu 16.04; should work fine on any popular 64-bit Linux and even macOS; 16 | * Windows and x32 systems may require heavy wizardry to run; 17 | * When in doubt, use Docker, preferably GPU-enabled (i.e. nvidia-docker) 18 | 19 | # How do I run it? 20 | 1. Clone or download this repo. `cd` yourself to its root directory. 21 | 2. Grab or build a working Python environment. [Anaconda](https://www.anaconda.com/) works fine. 22 | 3. Install packages from `requirements.txt` 23 | 4. Run Jupyter Notebook and open a notebook in `./notebooks/` 24 | * Before you run the first cell, change `%env CUDA_VISIBLE_DEVICES=#` to the index of the GPU you plan to use. 25 | * The [CIFAR10 notebook](./notebooks/cifar10_editable_layer3.ipynb) can be run with no extra preparation 26 | * The ImageNet notebooks require a step-by-step procedure to get running: 27 | 1. Download the dataset first. See [this page](https://pytorch.org/docs/stable/_modules/torchvision/datasets/imagenet.html) or just google it. No, really, go google it! 28 | 2. Run [`imagenet_preprocess_logits.ipynb`](./notebooks/imagenet_preprocess_logits.ipynb) 29 | 3.
Train with [`imagenet_editable_training.ipynb`](./notebooks/imagenet_editable_training.ipynb) 30 | 4. Evaluate by using one of the two remaining notebooks. 31 | * To reproduce machine translation experiments, follow the instructions in [`./mt/README.md`](./mt/) 32 | -------------------------------------------------------------------------------- /lib/utils/init.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch import nn as nn 4 | from torch.jit import ScriptModule 5 | from torch.nn import functional as F 6 | 7 | 8 | class ModuleWithInit(nn.Module): 9 | """ Base class for pytorch module with data-aware initializer on first batch """ 10 | def __init__(self): 11 | super().__init__() 12 | assert not hasattr(self, '_is_initialized_bool') 13 | assert not hasattr(self, '_is_initialized_tensor') 14 | self._is_initialized_tensor = nn.Parameter(torch.tensor(0, dtype=torch.uint8), requires_grad=False) 15 | self._is_initialized_bool = None 16 | # Note: this module uses a separate flag self._is_initialized_* so as to achieve both 17 | # * persistence: is_initialized is saved alongside model in state_dict 18 | # * speed: model doesn't need to cache 19 | # please DO NOT use these flags in child modules 20 | 21 | def initialize(self, *args, **kwargs): 22 | """ initialize module tensors using first batch of data """ 23 | raise NotImplementedError("Please implement ") 24 | 25 | def is_initialized(self): 26 | """ whether data aware initialization was already performed """ 27 | if self._is_initialized_bool is None: 28 | self._is_initialized_bool = bool(self._is_initialized_tensor.item()) 29 | return self._is_initialized_bool 30 | 31 | def __call__(self, *args, **kwargs): 32 | if self._is_initialized_bool is None: 33 | self._is_initialized_bool = bool(self._is_initialized_tensor.item()) 34 | if not self._is_initialized_bool: 35 | self.initialize(*args, **kwargs) 36 | self._is_initialized_tensor.data[...] 
= 1 37 | self._is_initialized_bool = True 38 | return super().__call__(*args, **kwargs) 39 | 40 | 41 | class ScriptModuleWithInit(ModuleWithInit, ScriptModule): 42 | """ Base class for pytorch module with data-aware initializer on first batch """ 43 | def __init__(self, optimize=True, **kwargs): 44 | ScriptModule.__init__(self, optimize=optimize, **kwargs) 45 | ModuleWithInit.__init__(self) 46 | 47 | 48 | def init_normalized_(x, init_=nn.init.normal_, dim=-1, **kwargs): 49 | """ initialize x inp-place by sampling random normal values and normalizing them over dim """ 50 | init_(x) 51 | x.data = F.normalize(x, dim=dim, **kwargs) 52 | return x 53 | -------------------------------------------------------------------------------- /lib/utils/learning_rate.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from torch.optim.optimizer import Optimizer 3 | import numpy as np 4 | 5 | 6 | def get_learning_rate(optimizer): 7 | for param_group in optimizer.param_groups: 8 | if 'lr' in param_group: 9 | return param_group['lr'] 10 | raise ValueError("Could not infer learning rate from optimizer {}".format(optimizer)) 11 | 12 | 13 | class OneCycleSchedule(Optimizer): 14 | """ A simplified torch lr schedule that updates learning rate before every opt.step """ 15 | 16 | def __init__(self, optimizer, **kwargs): 17 | """ 18 | :type optimizer: torch.optim.Optimizer 19 | :param kwargs: see self.update_learning_rate 20 | """ 21 | self.learning_rate_opts = kwargs 22 | self.opt = optimizer 23 | self.step_count = 0 24 | 25 | def step(self, **kwargs): 26 | self.current_lr = self.update_learning_rate(t=self.step_count, **self.learning_rate_opts) 27 | res = self.opt.step(**kwargs) 28 | self.step_count += 1 29 | return res 30 | 31 | def state_dict(self, **kwargs): 32 | return OrderedDict([ 33 | ('optimizer_state_dict', self.opt.state_dict(**kwargs)), 34 | ('learning_rate_opts', self.learning_rate_opts), 35 | ('step_count', self.step_count) 36 | ]) 37 | 38 | def load_state_dict(self, state_dict, load_step=True, load_opts=True, **kwargs): 39 | self.learning_rate_opts = state_dict['learning_rate_opts'] if load_opts else self.learning_rate_opts 40 | self.step_count = state_dict['step_count'] if load_step else self.step_count 41 | return self.opt.load_state_dict(state_dict['optimizer_state_dict'], **kwargs) 42 | 43 | def __getattr__(self, attr): 44 | if attr in self.__dict__: 45 | return getattr(self, attr) 46 | return getattr(self.opt, attr) 47 | 48 | def update_learning_rate(self, t, learning_rate_base=1e-3, warmup_steps=10000, 49 | decay_rate=0.2, learning_rate_min=1e-5): 50 | """ Learning rate with linear warmup and exponential decay """ 51 | lr = learning_rate_base * np.minimum( 52 | (t + 1.0) / warmup_steps, 53 | np.exp(decay_rate * ((warmup_steps - t - 1.0) / warmup_steps)), 54 | ) 55 | lr = np.maximum(lr, learning_rate_min) 56 | for param_group in self.opt.param_groups: 57 | param_group['lr'] = lr 58 | return lr 59 | -------------------------------------------------------------------------------- /lib/utils/gumbel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .basic import to_one_hot 3 | 4 | 5 | def gumbel_noise(*sizes, epsilon=1e-9, **kwargs): 6 | """ Sample noise from gumbel distribution """ 7 | return -torch.log(-torch.log(torch.rand(*sizes, **kwargs) + epsilon) + epsilon) 8 | 9 | 10 | def gumbel_softmax(logits, dim=-1, tau=1.0, noise=1.0, hard=False, **kwargs): 11 | """ 
12 | Softmax with gumbel noise 13 | :param logits: inputs for softmax 14 | :param dim: normalize softmax along this dimension 15 | :param tau: gumbel softmax temperature 16 | :param hard: if True, works like onehot(sample) during forward pass, 17 | gumbel-softmax for backward pass 18 | :return: gumbel-softmax "probabilities", tensor of same shape as logits 19 | """ 20 | if noise != 0: 21 | z = gumbel_noise(*logits.shape, device=logits.device, dtype=logits.dtype) 22 | logits = logits + noise * z 23 | if tau != 1.0: 24 | logits = logits / tau 25 | 26 | probs_gumbel = torch.softmax(logits, dim=dim) 27 | 28 | if hard: 29 | _, argmax_indices = torch.max(probs_gumbel, dim=dim) 30 | hard_argmax_onehot = to_one_hot(argmax_indices, depth=logits.shape[dim]) 31 | if dim != -1 and dim != len(logits.shape) - 1: 32 | new_dim_order = list(range(len(logits.shape) - 1)) 33 | new_dim_order.insert(dim, -1) 34 | hard_argmax_onehot = hard_argmax_onehot.permute(*new_dim_order) 35 | 36 | # forward pass: onehot sample, backward pass: gumbel softmax 37 | probs_gumbel = (hard_argmax_onehot - probs_gumbel).detach() + probs_gumbel 38 | 39 | return probs_gumbel 40 | 41 | 42 | def gumbel_sigmoid(logits, tau=1.0, noise=1.0, hard=False, **kwargs): 43 | """ 44 | A special case of gumbel softmax with 2 classes: [logit] and 0 45 | :param logits: sigmoid inputs 46 | :param tau: same as gumbel softmax temperature 47 | :param hard: if True, works like bernoulli sample for forward pass, 48 | gumbel sigmoid for backward pass 49 | :return: tensor with same shape as logits 50 | """ 51 | if noise != 0.0: 52 | z1 = gumbel_noise(*logits.shape, device=logits.device, dtype=logits.dtype) 53 | z2 = gumbel_noise(*logits.shape, device=logits.device, dtype=logits.dtype) 54 | logits = logits + noise *(z1 - z2) 55 | if tau != 1.0: 56 | logits /= tau 57 | sigm = torch.sigmoid(logits) 58 | if hard: 59 | hard_sample = torch.ge(sigm, 0.5).to(dtype=logits.dtype) 60 | sigm = (hard_sample - sigm).detach() + sigm 61 | return sigm 62 | -------------------------------------------------------------------------------- /lib/utils/regularize.py: -------------------------------------------------------------------------------- 1 | """ 2 | Context-based regularizers, thread-safe 3 | 4 | Usage: 5 | >>> with collect_regularizers() as regularizers: 6 | >>> y1 = model1(x) 7 | >>> y2 = model2(x) 8 | >>> 9 | >>> do_something([regularizers[module]['my_activation'] for module in regularizers]) 10 | 11 | Inside model1 and model2 code: 12 | >>> <...> 13 | >>> activation = self.foo(x) 14 | >>> if is_regularized('my_activation') 15 | >>> add_regularizer(self, 'my_activation', activation) 16 | >>> output = self.bar(activation) 17 | >>> return output 18 | 19 | """ 20 | from contextlib import contextmanager 21 | from collections import defaultdict 22 | import threading 23 | from warnings import warn 24 | 25 | 26 | REGULARIZERS_KEYS = None 27 | REGULARIZERS = None 28 | tls = threading.local() 29 | 30 | 31 | @contextmanager 32 | def collect_regularizers(collection=None, keys=None, within_thread=None): 33 | if within_thread is None: 34 | if threading.current_thread() is not threading.main_thread(): 35 | warn("Calling collect_regularizers while not in main thread, please set within_thread explicitly") 36 | within_thread = threading.current_thread() == threading.main_thread() 37 | 38 | if collection is None: 39 | collection = defaultdict(lambda: defaultdict(list)) 40 | 41 | global REGULARIZERS, REGULARIZERS_KEYS 42 | setattr(tls, 'REGULARIZERS', getattr(tls, 'REGULARIZERS', 
None)) 43 | setattr(tls, 'REGULARIZERS_KEYS', getattr(tls, 'REGULARIZERS_KEYS', None)) 44 | 45 | _old_regs, _old_keys = REGULARIZERS, REGULARIZERS_KEYS 46 | _old_local_regs, _old_local_keys = tls.REGULARIZERS, tls.REGULARIZERS_KEYS 47 | try: 48 | if within_thread: 49 | tls.REGULARIZERS, tls.REGULARIZERS_KEYS = collection, keys 50 | REGULARIZERS, REGULARIZERS_KEYS = None, None 51 | else: 52 | REGULARIZERS, REGULARIZERS_KEYS = collection, keys 53 | tls.REGULARIZERS, tls.REGULARIZERS_KEYS = None, None 54 | 55 | yield collection 56 | finally: 57 | REGULARIZERS = _old_regs 58 | REGULARIZERS_KEYS = _old_keys 59 | tls.REGULARIZERS = _old_local_regs 60 | tls.REGULARIZERS_KEYS = _old_local_keys 61 | 62 | 63 | def get_regularized_keys(): 64 | is_local = hasattr(tls, 'REGULARIZERS') 65 | return getattr(tls, 'REGULARIZERS_KEYS', REGULARIZERS_KEYS) if is_local else REGULARIZERS_KEYS 66 | 67 | 68 | def get_regularizer_collection(): 69 | is_local = hasattr(tls, 'REGULARIZERS') 70 | return getattr(tls, 'REGULARIZERS', None) if is_local else REGULARIZERS 71 | 72 | 73 | def is_regularized(key): 74 | if get_regularizer_collection() is None: 75 | return False 76 | keys = get_regularized_keys() 77 | return keys is None or key in keys 78 | 79 | 80 | def add_regularizer(module, key, value): 81 | assert is_regularized(key) 82 | get_regularizer_collection()[module][key].append(value) 83 | -------------------------------------------------------------------------------- /lib/evaluate.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn as nn 4 | from tqdm import tqdm 5 | import torch.nn.functional as F 6 | 7 | from .utils import infer_model_device, training_mode, process_in_chunks, check_numpy, nop 8 | from .editable import Editable 9 | 10 | 11 | def classification_error(model: nn.Module, X_test, y_test, batch_size=1024, device=None): 12 | device = device or infer_model_device(model) 13 | with torch.no_grad(), training_mode(model, is_train=False): 14 | val_logits = process_in_chunks( 15 | model, torch.as_tensor(X_test, device=device), batch_size=batch_size) 16 | val_logits = check_numpy(val_logits) 17 | error_rate = (check_numpy(y_test) != np.argmax(val_logits, axis=1)).mean() 18 | return error_rate 19 | 20 | def eval_func(model: nn.Module, X_test, batch_size=1024, device='cuda'): 21 | with torch.no_grad(), training_mode(model, is_train=False): 22 | val_logits = process_in_chunks( 23 | model, torch.as_tensor(X_test, device=device), batch_size=batch_size) 24 | val_logits = check_numpy(val_logits) 25 | return val_logits 26 | 27 | def calculate_edit_statistics(editable_model: Editable, X_test, y_test, X_edit, y_edit, 28 | error_function=classification_error, progressbar=None, **kwargs): 29 | """ 30 | For each sample in X_edit, y_edit attempts to train model and evaluates trained model quality 31 | :param editable_model: model to be edited 32 | :param X_test: data for quality evaluaton 33 | :param y_test: targets for quality evaluaton 34 | :param X_edit: sequence of data for training model on 35 | :param y_edit: sequence of targets for training model on 36 | :param error_function: function that measures quality 37 | :param progressbar: iterator(range) that may or may not print progress, use progressbar=True for tqdm.tqdm 38 | :param kwargs: extra parameters for model.edit 39 | :return: list of results of experiments 40 | """ 41 | progressbar = tqdm if progressbar is True else progressbar or nop 42 | results_temporary = [] 43 
| with training_mode(editable_model, is_train=False): 44 | for i in progressbar(range(len(X_edit))): 45 | edited_model, success, loss, complexity = editable_model.edit( 46 | X_edit[i:i + 1], y_edit[i:i + 1], detach=True, **kwargs) 47 | results_temporary.append((error_function(edited_model, X_test, y_test), success, complexity)) 48 | return results_temporary 49 | 50 | 51 | def evaluate_quality(editable_model: Editable, X_test, y_test, X_edit, y_edit, 52 | error_function=classification_error, progressbar=None, **kwargs): 53 | """ 54 | For each sample in X_edit, y_edit attempts to train model and evaluates trained model quality 55 | :param editable_model: model to be edited 56 | :param X_test: data for quality evaluaton 57 | :param y_test: targets for quality evaluaton 58 | :param X_edit: sequence of data for training model on 59 | :param y_edit: sequence of targets for training model on 60 | :param error_function: function that measures quality 61 | :param kwargs: extra parameters for model.edit 62 | :return: dictionary of metrics 63 | """ 64 | base_error = error_function(editable_model, X_test, y_test) 65 | results_temporary = calculate_edit_statistics(editable_model, X_test, y_test, X_edit, y_edit, 66 | progressbar=progressbar, error_function=error_function, **kwargs) 67 | errors, succeses, complexities = zip(*results_temporary) 68 | drawdown = np.mean(errors) - base_error 69 | return dict(base_error=base_error, drawdown=drawdown, success_rate=np.mean(succeses), 70 | mean_complexity=np.mean(complexities)) 71 | -------------------------------------------------------------------------------- /lib/trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch.nn import functional as F 4 | 5 | from .evaluate import classification_error, evaluate_quality 6 | from .utils import training_mode, BaseTrainer 7 | from .editable import Editable 8 | from .loss import kl_distill_loss 9 | 10 | 11 | class EditableTrainer(BaseTrainer): 12 | def __init__(self, model: Editable, loss_function, error_function=classification_error, opt=None, 13 | stability_coeff=0.01, editability_coeff=0.01, max_norm=None, **kwargs): 14 | """ A simple optimizer that trains to minimize classification or regression loss """ 15 | opt = opt if opt is not None else torch.optim.Adam(model.parameters()) 16 | super().__init__(model, loss_function=loss_function, opt=opt, error_function=error_function, **kwargs) 17 | self.stability_coeff, self.editability_coeff, self.max_norm = stability_coeff, editability_coeff, max_norm 18 | 19 | def train_on_batch(self, x_batch, y_batch, x_edit, y_edit, prefix='train/', is_train=True, **kwargs): 20 | """ Performs a single gradient update and reports metrics """ 21 | x_batch, y_batch = map(torch.as_tensor, (x_batch, y_batch)) 22 | self.opt.zero_grad() 23 | 24 | with training_mode(self.model, is_train=is_train): 25 | logits = self.model(x_batch) 26 | 27 | main_loss = self.loss_function(logits, y_batch).mean() 28 | 29 | with training_mode(self.model, is_train=False): 30 | model_edited, success, editability_loss, complexity = self.model.edit(x_edit, y_edit, **kwargs) 31 | logits_updated = model_edited(x_batch) 32 | 33 | stability_loss = - (F.softmax(logits.detach(), dim=1) * F.log_softmax(logits_updated, dim=1)).sum(dim=1).mean() 34 | 35 | final_loss = main_loss + self.stability_coeff * stability_loss + self.editability_coeff * editability_loss 36 | 37 | metrics = dict( 38 | final_loss=final_loss.item(), 
stability_loss=stability_loss.item(), 39 | editability_loss=editability_loss.item(), main_loss=main_loss.item(), 40 | ) 41 | 42 | final_loss.backward() 43 | 44 | if self.max_norm is not None: 45 | metrics['grad_norm'] = torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.max_norm) 46 | self.opt.step() 47 | 48 | return self.record(**metrics, prefix=prefix) 49 | 50 | def evaluate_metrics(self, X, y, X_edit=None, y_edit=None, prefix='val/', **kwargs): 51 | """ 52 | For each sample in X_edit, y_edit attempts to train model and evaluates trained model quality 53 | :param X: data for quality evaluaton 54 | :param y: targets for quality evaluaton 55 | :param X_edit: sequence of data for training model on 56 | :param y_edit: sequence of targets for training model on 57 | :param prefix: tensorboard metrics will be written under this prefix 58 | :param kwargs: extra parameters for error function 59 | :return: dictionary of metrics 60 | """ 61 | assert (X_edit is None) == (y_edit is None), "provide either both X_edit and y_edit or none of them" 62 | if X_edit is None: 63 | num_classes = y.max() + 1 64 | ind = np.random.permutation(len(X))[:10] 65 | X_edit = X[ind] 66 | y_edit = (y[ind] + torch.randint_like(y[ind], 1, num_classes)) % num_classes 67 | 68 | return self.record(**evaluate_quality( 69 | self.model, X, y, X_edit, y_edit, error_function=self.error_function, **kwargs), prefix=prefix) 70 | 71 | def extra_repr(self): 72 | line = "stability_coeff = {}, editability_coeff = {}, max_norm = {}".format( 73 | self.stability_coeff, self.editability_coeff, self.max_norm) 74 | line += '\nloss = {} '.format(self.loss_function) 75 | line += '\nopt = {} '.format(self.opt) 76 | return line 77 | 78 | 79 | class DistillationEditableTrainer(EditableTrainer): 80 | def __init__(self, model, **kwargs): 81 | return super().__init__(model, loss_function=kl_distill_loss, **kwargs) 82 | 83 | def train_on_batch(self, x_batch, logits_batch, *args, **kwargs): 84 | return super().train_on_batch(x_batch, logits_batch, *args, is_train=True, **kwargs) 85 | -------------------------------------------------------------------------------- /mt/evaluate.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import re\n", 10 | "import os\n", 11 | "import numpy as np\n", 12 | "import tqdm\n", 13 | "\n", 14 | "fout_path = 'bleus_editable.txt'\n", 15 | "if os.path.isfile(fout_path):\n", 16 | " raise ValueError(\"File already exists\")\n", 17 | "\n", 18 | "def get_edit_bleu(checkpoint_path, edit_sample_index=None):\n", 19 | " edit_sample_str = ''\n", 20 | " if edit_sample_index is not None:\n", 21 | " edit_sample_str = '--edit-sample-index {}'.format(edit_sample_index)\n", 22 | "\n", 23 | " res = !python edited_generate.py data-bin/iwslt14.tokenized.de-en \\\n", 24 | " --path {checkpoint_path} \\\n", 25 | " --edit-samples-path edit_iwslt14.tokenized.de-en/bpe_test.txt \\\n", 26 | " {edit_sample_str} \\\n", 27 | " --no-progress-bar \\\n", 28 | " --beam 5 --remove-bpe --sacrebleu --moses-detokenizer de \n", 29 | " \n", 30 | " if edit_sample_index is not None:\n", 31 | " bleu_id = -2\n", 32 | " else:\n", 33 | " bleu_id = -1\n", 34 | "\n", 35 | " try:\n", 36 | " bleu = float(re.findall('BLEU\\(score=(\\d+(\\.\\d+|)),', res[bleu_id])[0][0])\n", 37 | " except:\n", 38 | " bleu = None\n", 39 | " \n", 40 | " if edit_sample_index is not None:\n", 41 | " 
try:\n", 42 | " success, complexity = re.findall('EditResult\\(success=(True|False), complexity=(\\d+)\\)', res[-1])[0]\n", 43 | " success = eval(success)\n", 44 | " complexity = int(complexity)\n", 45 | " except:\n", 46 | " success = complexity = None\n", 47 | " \n", 48 | " with open(fout_path, 'a') as f:\n", 49 | " print(bleu, success, complexity, file=f)\n", 50 | " else:\n", 51 | " success = complexity = None\n", 52 | " \n", 53 | " return bleu, success, complexity" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": null, 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "get_edit_bleu('checkpoints_editable/checkpoint_best.pt')" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": null, 68 | "metadata": {}, 69 | "outputs": [], 70 | "source": [ 71 | "np.random.seed(42)\n", 72 | "ids = np.random.permutation(6749)[:1000]\n", 73 | "[get_edit_bleu('checkpoints_editable/checkpoint_best.pt', id) for id in tqdm.tqdm_notebook(ids)]" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": null, 79 | "metadata": {}, 80 | "outputs": [], 81 | "source": [ 82 | "bleus = []\n", 83 | "successes = []\n", 84 | "complexities = []\n", 85 | "with open('bleus_samples_stability1_kl_almost_last.txt') as f:\n", 86 | " for line in f:\n", 87 | " bleu, success, complexity = line[:-1].split()\n", 88 | " \n", 89 | " try:\n", 90 | " bleu = float(bleu)\n", 91 | " success = bool(success)\n", 92 | " complexity = float(complexity)\n", 93 | " except:\n", 94 | " bleu = 0\n", 95 | " success = False\n", 96 | " complexity = -1\n", 97 | " \n", 98 | " bleus.append(bleu)\n", 99 | " successes.append(success)\n", 100 | " complexities.append(complexity)\n", 101 | " \n", 102 | "bleus = np.array(bleus)\n", 103 | "successes = np.array(successes)\n", 104 | "complexities = np.array(complexities)\n", 105 | "\n", 106 | "print('Mean BLEU:', np.mean(bleus))\n", 107 | "print('Success rate:', np.mean(successes))\n", 108 | "print('Mean complexity:', np.mean(complexities))" 109 | ] 110 | } 111 | ], 112 | "metadata": { 113 | "kernelspec": { 114 | "display_name": "Python 3", 115 | "language": "python", 116 | "name": "python3" 117 | }, 118 | "language_info": { 119 | "codemirror_mode": { 120 | "name": "ipython", 121 | "version": 3 122 | }, 123 | "file_extension": ".py", 124 | "mimetype": "text/x-python", 125 | "name": "python", 126 | "nbconvert_exporter": "python", 127 | "pygments_lexer": "ipython3", 128 | "version": "3.7.4" 129 | } 130 | }, 131 | "nbformat": 4, 132 | "nbformat_minor": 2 133 | } 134 | -------------------------------------------------------------------------------- /notebooks/resnet.py: -------------------------------------------------------------------------------- 1 | '''ResNet in PyTorch. 2 | 3 | For Pre-activation ResNet, see 'preact_resnet.py'. 4 | 5 | Reference: 6 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 7 | Deep Residual Learning for Image Recognition. 
arXiv:1512.03385 8 | ''' 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | 14 | class BasicBlock(nn.Module): 15 | expansion = 1 16 | 17 | def __init__(self, in_planes, planes, stride=1): 18 | super(BasicBlock, self).__init__() 19 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 20 | self.bn1 = nn.BatchNorm2d(planes) 21 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | 24 | self.shortcut = nn.Sequential() 25 | if stride != 1 or in_planes != self.expansion*planes: 26 | self.shortcut = nn.Sequential( 27 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 28 | nn.BatchNorm2d(self.expansion*planes) 29 | ) 30 | 31 | def forward(self, x): 32 | out = F.relu(self.bn1(self.conv1(x))) 33 | out = self.bn2(self.conv2(out)) 34 | out += self.shortcut(x) 35 | out = F.relu(out) 36 | return out 37 | 38 | 39 | class Bottleneck(nn.Module): 40 | expansion = 4 41 | 42 | def __init__(self, in_planes, planes, stride=1): 43 | super(Bottleneck, self).__init__() 44 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 45 | self.bn1 = nn.BatchNorm2d(planes) 46 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 47 | self.bn2 = nn.BatchNorm2d(planes) 48 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 49 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 50 | 51 | self.shortcut = nn.Sequential() 52 | if stride != 1 or in_planes != self.expansion*planes: 53 | self.shortcut = nn.Sequential( 54 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 55 | nn.BatchNorm2d(self.expansion*planes) 56 | ) 57 | 58 | def forward(self, x): 59 | out = F.relu(self.bn1(self.conv1(x))) 60 | out = F.relu(self.bn2(self.conv2(out))) 61 | out = self.bn3(self.conv3(out)) 62 | out += self.shortcut(x) 63 | out = F.relu(out) 64 | return out 65 | 66 | 67 | class ResNet(nn.Module): 68 | def __init__(self, block, num_blocks, num_classes=10): 69 | super(ResNet, self).__init__() 70 | self.in_planes = 64 71 | 72 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 73 | self.bn1 = nn.BatchNorm2d(64) 74 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 75 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 76 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 77 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 78 | self.linear = nn.Linear(512*block.expansion, num_classes) 79 | 80 | def _make_layer(self, block, planes, num_blocks, stride): 81 | strides = [stride] + [1]*(num_blocks-1) 82 | layers = [] 83 | for stride in strides: 84 | layers.append(block(self.in_planes, planes, stride)) 85 | self.in_planes = planes * block.expansion 86 | return nn.Sequential(*layers) 87 | 88 | def forward(self, x): 89 | out = F.relu(self.bn1(self.conv1(x))) 90 | out = self.layer1(out) 91 | out = self.layer2(out) 92 | out = self.layer3(out) 93 | out = self.layer4(out) 94 | out = F.avg_pool2d(out, 4) 95 | out = out.view(out.size(0), -1) 96 | out = self.linear(out) 97 | return out 98 | 99 | 100 | def ResNet18(**kwargs): 101 | return ResNet(BasicBlock, [2,2,2,2], **kwargs) 102 | 103 | def ResNet34(): 104 | return ResNet(BasicBlock, [3,4,6,3]) 105 | 106 | def ResNet50(): 107 | return ResNet(Bottleneck, [3,4,6,3]) 108 | 109 | def ResNet101(): 
110 | return ResNet(Bottleneck, [3,4,23,3]) 111 | 112 | def ResNet152(): 113 | return ResNet(Bottleneck, [3,8,36,3]) 114 | 115 | 116 | def test(): 117 | net = ResNet18() 118 | y = net(torch.randn(1,3,32,32)) 119 | print(y.size()) 120 | 121 | # test() 122 | -------------------------------------------------------------------------------- /notebooks/imagenet_preprocess_logits.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "__This notebook__ computes teacher logits for editable fine-tuning on Imagenet" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [ 15 | { 16 | "name": "stdout", 17 | "output_type": "stream", 18 | "text": [ 19 | "env: CUDA_VISIBLE_DEVICES=5\n" 20 | ] 21 | } 22 | ], 23 | "source": [ 24 | "%load_ext autoreload\n", 25 | "%autoreload 2\n", 26 | "%env CUDA_VISIBLE_DEVICES=YOURDEVICEHERE\n", 27 | "\n", 28 | "imagenet_train_path = '../../imagenet/train' # path to train ImageFolder\n", 29 | "logits_path = './imagenet_logits/' # saves logits to this path\n", 30 | "\n", 31 | "import os, sys, time\n", 32 | "sys.path.insert(0, '..')\n", 33 | "import lib\n", 34 | "\n", 35 | "import numpy as np\n", 36 | "import torch, torch.nn as nn\n", 37 | "import torch.nn.functional as F\n", 38 | "import matplotlib.pyplot as plt\n", 39 | "%matplotlib inline\n", 40 | "\n", 41 | "import random\n", 42 | "random.seed(42)\n", 43 | "np.random.seed(42)\n", 44 | "torch.random.manual_seed(42)\n", 45 | "\n", 46 | "import time\n", 47 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": 2, 53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [ 56 | "import torchvision.transforms as transforms\n", 57 | "import torchvision.datasets as datasets\n", 58 | "import torchvision.models as models\n", 59 | "\n", 60 | "normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],\n", 61 | " std=[0.229, 0.224, 0.225])\n", 62 | "\n", 63 | "train_dataset = datasets.ImageFolder(\n", 64 | " imagenet_train_path,\n", 65 | " transforms.Compose([\n", 66 | " transforms.RandomResizedCrop(224),\n", 67 | " transforms.RandomHorizontalFlip(),\n", 68 | " transforms.ToTensor(),\n", 69 | " normalize,\n", 70 | " ]))\n", 71 | "\n", 72 | "batch_size = 64\n", 73 | "\n", 74 | "train_loader = torch.utils.data.DataLoader(\n", 75 | " train_dataset, batch_size=batch_size, shuffle=False,\n", 76 | " num_workers=1, pin_memory=True)" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "metadata": {}, 83 | "outputs": [], 84 | "source": [ 85 | "from tqdm import tqdm_notebook, tnrange\n", 86 | "from IPython.display import clear_output\n", 87 | "\n", 88 | "import torchvision\n", 89 | "\n", 90 | "image_index = 0\n", 91 | "\n", 92 | "def get_logits(model, X_test):\n", 93 | " with lib.training_mode(model, is_train=False):\n", 94 | " return lib.eval_func(lib.Lambda(lambda x: model(x.to(device))),\n", 95 | " X_test, device='cuda', batch_size=64)\n", 96 | "\n", 97 | "model = torchvision.models.resnet18(pretrained=True).to(device)\n", 98 | "\n", 99 | "for batch in tqdm_notebook(train_loader):\n", 100 | " logits = get_logits(model, batch[0])\n", 101 | " for logit in logits:\n", 102 | " image_name = train_loader.dataset.imgs[image_index][0]\n", 103 | " image_name = \"{}/{}\".format(logits_path, image_name[image_name.rfind(\"/\") + 1:image_name.find(\".j\")])\n", 
104 | " np.save(image_name, torch.softmax(logits, dim=-1))\n", 105 | " image_index += 1" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": 5, 111 | "metadata": {}, 112 | "outputs": [ 113 | { 114 | "name": "stdout", 115 | "output_type": "stream", 116 | "text": [ 117 | "9.9G\timagenet_logits\r\n" 118 | ] 119 | } 120 | ], 121 | "source": [ 122 | "!du -h imagenet_logits" 123 | ] 124 | } 125 | ], 126 | "metadata": { 127 | "kernelspec": { 128 | "display_name": "Python 3", 129 | "language": "python", 130 | "name": "python3" 131 | }, 132 | "language_info": { 133 | "codemirror_mode": { 134 | "name": "ipython", 135 | "version": 3 136 | }, 137 | "file_extension": ".py", 138 | "mimetype": "text/x-python", 139 | "name": "python", 140 | "nbconvert_exporter": "python", 141 | "pygments_lexer": "ipython3", 142 | "version": "3.6.8" 143 | } 144 | }, 145 | "nbformat": 4, 146 | "nbformat_minor": 2 147 | } 148 | -------------------------------------------------------------------------------- /lib/utils/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import hashlib 4 | import numpy as np 5 | import requests 6 | from tqdm import tqdm 7 | import torch 8 | 9 | 10 | def download(url, filename, delete_if_interrupted=True, chunk_size=4096): 11 | """ saves file from url to filename with a fancy progressbar """ 12 | try: 13 | with open(filename, "wb") as f: 14 | print("Downloading {} > {}".format(url, filename)) 15 | response = requests.get(url, stream=True) 16 | total_length = response.headers.get('content-length') 17 | 18 | if total_length is None: # no content length header 19 | f.write(response.content) 20 | else: 21 | total_length = int(total_length) 22 | with tqdm(total=total_length) as progressbar: 23 | for data in response.iter_content(chunk_size=chunk_size): 24 | if data: # filter-out keep-alive chunks 25 | f.write(data) 26 | progressbar.update(len(data)) 27 | except Exception as e: 28 | if delete_if_interrupted: 29 | print("Removing incomplete download {}.".format(filename)) 30 | os.remove(filename) 31 | raise e 32 | return filename 33 | 34 | 35 | def iterate_minibatches(*tensors, batch_size, shuffle=True, epochs=1, 36 | allow_incomplete=True, callback=lambda x: x): 37 | """ 38 | Minibatch iterator that's faster than torch dataloader when dealing with 39 | very simple objects and hundreds of batches per second. 40 | :param tensors: tensors to iterate over, all tensors must have shape [dataset_size, ...] 
41 | compatible with np arrays, torch tensors and anything that can be indexed via tensor[[i0, i1, i2]] 42 | :param batch_size: maximum number of objects in one yield 43 | :param shuffle: if True, each epoch runs over data tensors in shuffled order 44 | :param epochs: number of full passes over train data 45 | :param allow_incomplete: if dataset_size is not divisible by batch_size, 46 | if True, the last batch in epoch will have less than batch_size objects 47 | if False, the last batch in epoch will be omitted 48 | :param callback: a wrapper for batch iterator, insert your tqdm/progressbar here :) 49 | """ 50 | indices = np.arange(len(tensors[0])) 51 | upper_bound = int((np.ceil if allow_incomplete else np.floor)(len(indices) / batch_size)) * batch_size 52 | epoch = 0 53 | while True: 54 | if shuffle: 55 | np.random.shuffle(indices) 56 | for batch_start in callback(range(0, upper_bound, batch_size)): 57 | if batch_start == len(indices): break 58 | batch_ix = indices[batch_start: batch_start + batch_size] 59 | batch = [tensor[batch_ix] for tensor in tensors] 60 | yield batch if len(tensors) > 1 else batch[0] 61 | epoch += 1 62 | if epoch >= epochs: 63 | break 64 | 65 | 66 | def process_in_chunks(function, *args, batch_size, out=None, **kwargs): 67 | """ 68 | Computes output by applying batch-parallel function to large data tensor in chunks 69 | :param function: a function(*[x[indices, ...] for x in args]) -> out[indices, ...] 70 | :param args: one or many tensors, each [num_instances, ...] 71 | :param batch_size: maximum chunk size processed in one go 72 | :param out: memory buffer for out, defaults to torch.zeros of appropriate size and type 73 | :returns: function(data), computed in a memory-efficient way 74 | """ 75 | total_size = args[0].shape[0] 76 | first_output = function(*[x[0: batch_size] for x in args]) 77 | output_shape = (total_size,) + tuple(first_output.shape[1:]) 78 | if out is None: 79 | out = torch.zeros(*output_shape, dtype=first_output.dtype, device=first_output.device, 80 | layout=first_output.layout, **kwargs) 81 | 82 | out[0: batch_size] = first_output 83 | for i in range(batch_size, total_size, batch_size): 84 | batch_ix = slice(i, min(i + batch_size, total_size)) 85 | out[batch_ix] = function(*[x[batch_ix] for x in args]) 86 | return out 87 | 88 | 89 | def check_numpy(x): 90 | """ Makes sure x is a numpy array """ 91 | if isinstance(x, torch.Tensor): 92 | x = x.detach().cpu().numpy() 93 | x = np.asarray(x) 94 | assert isinstance(x, np.ndarray) 95 | return x 96 | 97 | 98 | def get_latest_file(pattern): 99 | list_of_files = glob.glob(pattern) # * means all if need specific format then *.csv 100 | assert len(list_of_files) > 0, "No files found: " + pattern 101 | return max(list_of_files, key=os.path.getctime) 102 | 103 | 104 | def md5sum(fname): 105 | """ Computes mdp checksum of a file """ 106 | hash_md5 = hashlib.md5() 107 | with open(fname, "rb") as f: 108 | for chunk in iter(lambda: f.read(4096), b""): 109 | hash_md5.update(chunk) 110 | return hash_md5.hexdigest() 111 | -------------------------------------------------------------------------------- /mt/generate_edit_datasets_samples.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import sys\n", 10 | "import os\n", 11 | "from collections import defaultdict\n", 12 | "import random\n", 13 | "import numpy as np\n", 14 | "\n", 15 | 
"np.random.seed(1234)\n", 16 | "random.seed(1234)\n", 17 | "\n", 18 | "%env CUDA_VISIBLE_DEVICES=0\n", 19 | "\n", 20 | "\n", 21 | "src_lang, dst_lang = 'de', 'en'\n", 22 | "checkpoint_path = 'baseline_checkpoint.pt'\n", 23 | "data_path = 'data-bin/iwslt14.tokenized.{}-{}/'.format(src_lang, dst_lang)\n", 24 | "output_path = 'edit_iwslt14.tokenized.{}-{}'.format(src_lang, dst_lang)\n", 25 | "beam = nbest = 32\n", 26 | "max_tokens = 1024\n", 27 | "keys = ['valid', 'test', 'train']\n", 28 | "\n", 29 | "\n", 30 | "sys.path.append('fairseq')" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": null, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "if os.path.isdir(output_path):\n", 40 | " if len(os.listdir(output_path)) > 0:\n", 41 | " raise ValueError('Output directory {} is not empty'.format(output_path))\n", 42 | "else:\n", 43 | " os.makedirs(output_path, exist_ok=True)" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "if not os.path.isfile(checkpoint_path):\n", 53 | " ! wget https://www.dropbox.com/s/iksezig9qi4g92e/baseline_checkpoint.pt?dl=1 -O {checkpoint_path}" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": null, 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "from fairseq.data.dictionary import Dictionary\n", 63 | "\n", 64 | "src_voc = Dictionary.load(os.path.join(data_path, 'dict.{}.txt'.format(src_lang)))\n", 65 | "dst_voc = Dictionary.load(os.path.join(data_path, 'dict.{}.txt'.format(dst_lang)))" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": null, 71 | "metadata": {}, 72 | "outputs": [], 73 | "source": [ 74 | "def generate_alternatives(key, file):\n", 75 | " !fairseq-generate {data_path} --path {checkpoint_path} \\\n", 76 | " --beam {beam} --nbest {nbest} \\\n", 77 | " --max-tokens {max_tokens} \\\n", 78 | " --gen-subset {key} > {file}" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": null, 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [ 87 | "def generate_edits(key, file):\n", 88 | " !fairseq-generate {data_path} --path {checkpoint_path} \\\n", 89 | " --beam {beam} --nbest {nbest} \\\n", 90 | " --max-tokens {max_tokens} \\\n", 91 | " --sampling --temperature 1.2 \\\n", 92 | " --gen-subset {key} > {file}" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": null, 98 | "metadata": {}, 99 | "outputs": [], 100 | "source": [ 101 | "for key in keys:\n", 102 | " tmp_alternatives_output_file_path = os.path.join(output_path, 'tmp_alternaatives_{}.txt'.format(key))\n", 103 | " tmp_edits_output_file_path = os.path.join(output_path, 'tmp_edits_{}.txt'.format(key))\n", 104 | " bpe_output_file_path = os.path.join(output_path, 'bpe_{}.txt'.format(key))\n", 105 | " \n", 106 | " print('Start generating {} alternatives'.format(key))\n", 107 | " generate_alternatives(key, tmp_alternatives_output_file_path)\n", 108 | " print('Start generating {} edits'.format(key))\n", 109 | " generate_edits(key, tmp_edits_output_file_path)\n", 110 | "\n", 111 | " logs = defaultdict(list)\n", 112 | "\n", 113 | " print('Start parsing the alternatives beam search output')\n", 114 | "\n", 115 | " with open(tmp_alternatives_output_file_path) as f_in:\n", 116 | " for line in f_in:\n", 117 | " try:\n", 118 | " tag, *payload = line.split('\\t')\n", 119 | " logs[tag].append(payload)\n", 120 | " except: pass\n", 121 | " \n", 122 | " print('Start parsing the edits beam search 
output')\n", 123 | " \n", 124 | " edit_logs = defaultdict(list)\n", 125 | "\n", 126 | " with open(tmp_edits_output_file_path) as f_in:\n", 127 | " for line in f_in:\n", 128 | " try:\n", 129 | " tag, *payload = line.split('\\t')\n", 130 | " edit_logs[tag].append(payload)\n", 131 | " except: pass\n", 132 | " \n", 133 | " print('Start edit samples generating')\n", 134 | "\n", 135 | " with open(bpe_output_file_path, 'w') as f_out:\n", 136 | " i = 0\n", 137 | " while True:\n", 138 | " i += 1\n", 139 | " try:\n", 140 | " source = logs['S-{}'.format(i)][0][0].strip()\n", 141 | " edit_source = edit_logs['S-{}'.format(i)][0][0].strip()\n", 142 | " except:\n", 143 | " break\n", 144 | " hypos = [hypo.strip() for prob, hypo in logs['H-{}'.format(i)]]\n", 145 | " \n", 146 | " edits = [edit.strip() for prob, edit in edit_logs['H-{}'.format(i)]]\n", 147 | " edit = random.choice(edits)\n", 148 | "\n", 149 | " np.random.shuffle(hypos)\n", 150 | " f_out.write('{}\\t{}\\t{}\\n'.format(source, edit, '\\t'.join(hypos)))\n", 151 | " \n", 152 | " print('_'*100)\n", 153 | " print('\\n'*3)" 154 | ] 155 | } 156 | ], 157 | "metadata": { 158 | "kernelspec": { 159 | "display_name": "Python 3", 160 | "language": "python", 161 | "name": "python3" 162 | }, 163 | "language_info": { 164 | "codemirror_mode": { 165 | "name": "ipython", 166 | "version": 3 167 | }, 168 | "file_extension": ".py", 169 | "mimetype": "text/x-python", 170 | "name": "python", 171 | "nbconvert_exporter": "python", 172 | "pygments_lexer": "ipython3", 173 | "version": "3.7.4" 174 | } 175 | }, 176 | "nbformat": 4, 177 | "nbformat_minor": 2 178 | } 179 | -------------------------------------------------------------------------------- /lib/utils/basic.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import gc 3 | import os 4 | import time 5 | from collections import Counter 6 | from itertools import chain 7 | 8 | import torch 9 | from torch import nn as nn 10 | 11 | 12 | def to_one_hot(y, depth=None): 13 | r""" 14 | Takes integer with n dims and converts it to 1-hot representation with n + 1 dims. 15 | The n+1'st dimension will have zeros everywhere but at y'th index, where it will be equal to 1. 16 | Args: 17 | y: input integer (IntTensor, LongTensor or Variable) of any shape 18 | depth (int): the size of the one hot dimension 19 | """ 20 | y_flat = y.to(torch.int64).view(-1, 1) 21 | depth = depth if depth is not None else int(torch.max(y_flat)) + 1 22 | y_one_hot = torch.zeros(y_flat.size()[0], depth, device=y.device).scatter_(1, y_flat, 1) 23 | y_one_hot = y_one_hot.view(*(tuple(y.shape) + (-1,))) 24 | return y_one_hot 25 | 26 | 27 | def dot(x, y): 28 | """ numpy-like dot product """ 29 | out_flat = x.view(-1, x.shape[-1]) @ y.view(y.shape[0], -1) 30 | return out_flat.view(*x.shape[:-1], *y.shape[1:]) 31 | 32 | 33 | def batch_outer_sum(*tensors): 34 | """ 35 | :param tensors: each matrix should have shape [..., d_i] 36 | :returns: [..., d_0, d_1, ..., d_N] where N = len(tensors) 37 | output[..., i, j, k] = tensors[0][..., i] + tensors[1][..., j] + tensors[2][..., k] 38 | """ 39 | outer_sum = None 40 | for i, tensor in enumerate(tensors): 41 | broadcaster = [None] * len(tensors) 42 | broadcaster[i] = slice(tensor.shape[-1]) 43 | broadcaster = tuple([...] 
+ broadcaster) 44 | outer_sum = tensor[broadcaster] if i == 0 else outer_sum + tensor[broadcaster] 45 | return outer_sum 46 | 47 | 48 | def batch_outer_product(*tensors): 49 | """ 50 | :param tensors: each matrix should have shape [..., d_i] 51 | :returns: [..., d_0, d_1, ..., d_N] where N = len(tensors) 52 | output[..., i, j, k] = tensors[0][..., i] * tensors[1][..., j] * tensors[2][..., k] 53 | """ 54 | prefix_shape = tensors[0].shape[:-1] 55 | assert len(tensors) + len(prefix_shape) <= ord('z') - ord('a') 56 | 57 | prefix_chars = ''.join(map(chr, range(ord('a'), ord('a') + len(prefix_shape)))) 58 | dim_chars = ''.join(map(chr, range(ord('a') + len(prefix_shape), ord('a') + len(prefix_shape) + len(tensors)))) 59 | einsum_lhs = ','.join(prefix_chars + d_i for d_i in dim_chars) 60 | einsum_rhs = prefix_chars + dim_chars 61 | return torch.einsum("{}->{}".format(einsum_lhs, einsum_rhs), *tensors) 62 | 63 | 64 | def straight_through_grad(function, **kwargs): 65 | """ 66 | modify function so that it is applied normally but excluded from backward pass 67 | :param function: callable(*inputs) -> *outputs, number and shape of outputs must match that of inputs, 68 | :param kwargs: keyword arguments that will be sent to each function call 69 | """ 70 | def f_straight_through(*inputs): 71 | outputs = function(*inputs, **kwargs) 72 | single_output = isinstance(outputs, torch.Tensor) 73 | if single_output: 74 | outputs = [outputs] 75 | 76 | assert isinstance(outputs, (list, tuple)) and len(outputs) == len(inputs) 77 | outputs = type(outputs)( 78 | input + (output - input).detach() 79 | for input, output in zip(inputs, outputs) 80 | ) 81 | return outputs[0] if single_output else outputs 82 | 83 | return f_straight_through 84 | 85 | 86 | def nop(x): 87 | return x 88 | 89 | 90 | @contextlib.contextmanager 91 | def nop_ctx(): 92 | yield None 93 | 94 | 95 | class Nop(nn.Module): 96 | def forward(self, x): 97 | return x 98 | 99 | 100 | class Residual(nn.Sequential): 101 | def forward(self, x): 102 | return super().forward(x) + x 103 | 104 | 105 | class Flatten(nn.Module): 106 | def forward(self, x): 107 | return x.view(len(x), -1) 108 | 109 | 110 | @contextlib.contextmanager 111 | def training_mode(*modules, is_train:bool): 112 | group = nn.ModuleList(modules) 113 | was_training = {module: module.training for module in group.modules()} 114 | try: 115 | yield group.train(is_train) 116 | finally: 117 | for key, module in group.named_modules(): 118 | if module in was_training: 119 | module.training = was_training[module] 120 | else: 121 | raise ValueError("Model was modified inside training_mode(...) 
context, could not find {}".format(key)) 122 | 123 | 124 | def free_memory(sleep_time=0.1): 125 | """ Black magic function to free torch memory and some jupyter whims """ 126 | gc.collect() 127 | torch.cuda.synchronize() 128 | gc.collect() 129 | torch.cuda.empty_cache() 130 | time.sleep(sleep_time) 131 | 132 | 133 | def infer_model_device(model: nn.Module): 134 | """ infers model device as the device where the majority of parameters and buffers are stored """ 135 | device_stats = Counter( 136 | tensor.device for tensor in chain(model.parameters(), model.buffers()) 137 | if torch.is_tensor(tensor) 138 | ) 139 | return max(device_stats, key=device_stats.get) 140 | 141 | 142 | class Lambda(nn.Module): 143 | def __init__(self, func): 144 | """ :param func: call this function during forward """ 145 | super().__init__() 146 | self.func = func 147 | 148 | def forward(self, *args, **kwargs): 149 | return self.func(*args, **kwargs) 150 | 151 | 152 | def to_float_str(element): 153 | try: 154 | return str(float(element)) 155 | except ValueError: 156 | return element 157 | 158 | 159 | def run_from_ipython(): 160 | try: 161 | __IPYTHON__ 162 | return True 163 | except NameError: 164 | return False 165 | 166 | 167 | if run_from_ipython(): 168 | from IPython.display import clear_output 169 | else: 170 | def clear_output(*args, **kwargs): 171 | os.system('clear') 172 | 173 | 174 | class OptimizerList(torch.optim.Optimizer): 175 | def __init__(self, *optimizers): 176 | self.optimizers = optimizers 177 | 178 | def step(self): 179 | return [opt.step() for opt in self.optimizers] 180 | 181 | def zero_grad(self): 182 | return [opt.zero_grad() for opt in self.optimizers] 183 | 184 | def add_param_group(self, *args, **kwargs): 185 | raise ValueError("Please call add_param_group in one of self.optimizers") 186 | 187 | def __getstate__(self): 188 | return [opt.__getstate__() for opt in self.optimizers] 189 | 190 | def __setstate__(self, state): 191 | return [opt.__setstate__(opt_state) for opt, opt_state in zip(self.optimizers, state)] 192 | 193 | def __repr__(self): 194 | return repr(self.optimizers) 195 | 196 | def state_dict(self, **kwargs): 197 | return {"opt_{}".format(i): opt.state_dict(**kwargs) for i, opt in enumerate(self.optimizers)} 198 | 199 | def load_state_dict(self, state_dict, **kwargs): 200 | return [ 201 | opt.load_state_dict(state_dict["opt_{}".format(i)]) 202 | for i, opt in enumerate(self.optimizers) 203 | ] 204 | -------------------------------------------------------------------------------- /lib/editable.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | from copy import copy 3 | from itertools import count 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from .utils.copy_and_replace import do_not_copy, copy_and_replace 9 | from .utils.ingraph_update import IngraphGradientDescent 10 | 11 | 12 | class BaseEditable(nn.Module): 13 | EditResult = namedtuple('EditResult', ['model', 'success', 'loss', 'complexity']) 14 | # model: a model that was adjusted out-of-place. Must be the same type as self (*Editable) 15 | # success: True if edit was successful, False otherwise 16 | # loss: objective function at the termination of the edit procedure 17 | # complexity: a measure of effort it took to edit model, e.g. 
number of SGD steps 18 | 19 | def edit(self, *data): 20 | # This should perform editing without changing current model and return EditResult 21 | return self.EditResult(self, success=False, loss=0.0, complexity=0.0) 22 | 23 | 24 | class Editable(BaseEditable): 25 | 26 | def __init__(self, module: nn.Module, loss_function, 27 | optimizer=IngraphGradientDescent(0.01), max_steps=float('inf'), 28 | get_editable_parameters=lambda module: module.parameters(), 29 | is_edit_finished=lambda loss, **kwargs: loss.item() <= 0, 30 | ): 31 | """ 32 | Editable module that attempts to change model by performing SGD (with optional momentum and rms scaling) 33 | :param module: a torch module that will be edited 34 | :param loss_function: objective function(model(inputs), targets) that is minimized by editor. 35 | By default this function should be non-negative and loss == 0 is a trigger to finish editing 36 | :param optimizer: in-graph optimizer that creates updated copies of model 37 | :param get_editable_parameters: a function(Editable.module) that takes the wrapped module and returns 38 | an iterable of parameters that should affected by edits, defaults to all parameters inside Editable.module 39 | :param is_edit_finished: a function(loss, prediction, **local variables) that returns True if edit is finished 40 | """ 41 | super().__init__() 42 | self.module, self.loss_function, self.optimizer = module, loss_function, optimizer 43 | self.get_editable_parameters = get_editable_parameters 44 | self.is_edit_finished = is_edit_finished 45 | self.max_steps = max_steps 46 | 47 | def forward(self, *args, **kwargs): 48 | return self.module(*args, **kwargs) 49 | 50 | def edit(self, inputs, targets, max_steps=None, model_kwargs=None, loss_kwargs=None, opt_kwargs=None, **kwargs): 51 | """ 52 | Attempts to edit model (out-of-place) and return an edited copy 53 | :param inputs: data that is fed into the model 54 | :param targets: reference answers that are fed into loss function 55 | :param max_steps: after this many gradient steps the process is terminated 56 | :param model_kwargs: optional extra model inputs, used as model(inputs, **model_params) 57 | :param loss_kwargs: optional extra loss parameters, self.loss_function(model(inputs), targets, **loss_params) 58 | :param opt_kwargs: optional overrides for optimizer.get_initial_state 59 | :param kwargs: extra parameters passed to optimizer.step 60 | :returns: edited_model, is_edit_successful, final_loss, gradients_steps 61 | :rtype: Editable.EditResult 62 | """ 63 | model_kwargs, loss_kwargs, opt_kwargs = model_kwargs or {}, loss_kwargs or {}, opt_kwargs or {} 64 | optimizer_state = self.optimizer.get_initial_state(self, **opt_kwargs) 65 | editable = self 66 | 67 | for step in count(): 68 | prediction = editable(inputs, **model_kwargs) 69 | loss = self.loss_function(prediction, targets, **loss_kwargs) 70 | 71 | if self.is_edit_finished(**locals()): 72 | return self.EditResult(editable, success=True, loss=loss, complexity=step) 73 | elif step >= (max_steps or self.max_steps): 74 | return self.EditResult(editable, success=False, loss=loss, complexity=step) 75 | 76 | optimizer_state, editable = self.optimizer.step( 77 | optimizer_state, editable, loss, parameters=editable.get_editable_parameters(editable.module), **kwargs) 78 | 79 | def extra_repr(self): 80 | return "max_steps={}, loss_function={}".format(self.max_steps, repr(self.loss_function)) 81 | 82 | 83 | class SequentialWithEditable(BaseEditable): 84 | def __init__(self, *args): 85 | """ A chain of modules with 
exactly one Editable, edit procedure will only compute pre-editable modules once """ 86 | super().__init__() 87 | pre_editable, editable, post_editable = [], None, [] 88 | for module in args: 89 | if isinstance(module, BaseEditable): 90 | assert editable is None, "SequentialEditable only supports one Editable module for now" 91 | editable = module 92 | elif editable is None: 93 | pre_editable.append(module) 94 | else: 95 | post_editable.append(module) 96 | 97 | assert editable is not None, "SequentialEditable must have one Editable at init, got 0" 98 | self.prefix_layers = nn.Sequential(*pre_editable) 99 | self.editable = editable if len(post_editable) == 0 else self._editable_with_suffix(editable, *post_editable) 100 | 101 | def forward(self, *args, **kwargs): 102 | return self.editable(self.prefix_layers(*args, **kwargs)) 103 | 104 | def edit(self, inputs, *args, **kwargs): 105 | result = self.editable.edit(self.prefix_layers(inputs), *args, **kwargs) 106 | with do_not_copy(self.prefix_layers, *self.prefix_layers.parameters(), *self.prefix_layers.buffers()): 107 | edited_model = copy_and_replace(self, replace={self.editable: result.model}) 108 | return self.EditResult(edited_model, *result[1:]) 109 | 110 | @staticmethod 111 | def _editable_with_suffix(base_editable: Editable, *suffix): 112 | new_editable = copy(base_editable) 113 | new_editable.module = nn.Sequential(base_editable.module, *suffix) 114 | new_editable.get_editable_parameters = lambda module: base_editable.get_editable_parameters(module[0]) 115 | return new_editable 116 | 117 | 118 | class RehearsalEditable(Editable): 119 | def __init__(self, *args, rehearsal_loss_weight=1.0, get_rehearsals, **kwargs): 120 | super().__init__(*args, **kwargs) 121 | self.rehearsal_loss_weight = rehearsal_loss_weight 122 | self.get_rehearsals = get_rehearsals 123 | 124 | def edit(self, inputs, targets, max_steps=None, model_kwargs=None, loss_kwargs=None, opt_kwargs={}, **kwargs): 125 | model_kwargs, loss_kwargs, opt_kwargs = model_kwargs or {}, loss_kwargs or {}, opt_kwargs or {} 126 | optimizer_state = self.optimizer.get_initial_state(self, **opt_kwargs) 127 | editable = self 128 | 129 | X_batch = self.get_rehearsals(inputs) 130 | vanilla_probs = F.softmax(editable(X_batch, **model_kwargs), dim=-1).detach() 131 | for step in count(): 132 | prediction = editable(inputs, **model_kwargs) 133 | 134 | loss = self.loss_function(prediction, targets, **loss_kwargs) 135 | 136 | if self.is_edit_finished(**locals()): 137 | return self.EditResult(editable, success=True, loss=loss, complexity=step) 138 | elif step >= (max_steps or self.max_steps): 139 | return self.EditResult(editable, success=False, loss=loss, complexity=step) 140 | 141 | current_logp = F.log_softmax(editable(X_batch, **model_kwargs), dim=-1) 142 | batch_loss = F.kl_div(current_logp, vanilla_probs) 143 | total_loss = loss + self.rehearsal_loss_weight * batch_loss 144 | 145 | with do_not_copy(self.get_rehearsals): 146 | optimizer_state, editable = self.optimizer.step( 147 | optimizer_state, editable, total_loss, parameters=editable.get_editable_parameters(editable.module), **kwargs) 148 | -------------------------------------------------------------------------------- /notebooks/cifar10_editable_layer3.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "__This notebook__ trains resnet18 from scratch on CIFAR10 dataset." 
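A minimal sketch of the edit call provided by lib/editable.py above (an aside, not a cell of this notebook; it assumes the `model`, `device` and `X_test` variables created in the cells below):

x_edit = X_test[:1].to(device)                      # a single image to edit on
y_edit = torch.randint(0, 10, (1,), device=device)  # the label we want it to predict
edited_model, success, loss, complexity = model.edit(x_edit, y_edit)
# `model` itself is left unchanged: edit() works out-of-place and returns an EditResult
# namedtuple; `complexity` counts the in-graph optimizer steps that were needed.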
8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [ 15 | { 16 | "name": "stdout", 17 | "output_type": "stream", 18 | "text": [ 19 | "env: CUDA_VISIBLE_DEVICES=1\n", 20 | "editable_layer3_2019.09.19_23:06:14\n", 21 | "PyTorch version: 1.1.0\n" 22 | ] 23 | } 24 | ], 25 | "source": [ 26 | "%load_ext autoreload\n", 27 | "%autoreload 2\n", 28 | "%env CUDA_VISIBLE_DEVICES=YOURDEVICEHERE\n", 29 | "import os, sys, time\n", 30 | "sys.path.insert(0, '..')\n", 31 | "import lib\n", 32 | "\n", 33 | "import numpy as np\n", 34 | "import torch, torch.nn as nn\n", 35 | "import torch.nn.functional as F\n", 36 | "import matplotlib.pyplot as plt\n", 37 | "%matplotlib inline\n", 38 | "\n", 39 | "import random\n", 40 | "random.seed(42)\n", 41 | "np.random.seed(42)\n", 42 | "torch.random.manual_seed(42)\n", 43 | "\n", 44 | "import time\n", 45 | "from resnet import ResNet18\n", 46 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", 47 | "\n", 48 | "experiment_name = 'editable_layer3'\n", 49 | "experiment_name = '{}_{}.{:0>2d}.{:0>2d}_{:0>2d}:{:0>2d}:{:0>2d}'.format(experiment_name, *time.gmtime()[:6])\n", 50 | "print(experiment_name)\n", 51 | "print(\"PyTorch version:\", torch.__version__)" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 2, 57 | "metadata": {}, 58 | "outputs": [ 59 | { 60 | "name": "stdout", 61 | "output_type": "stream", 62 | "text": [ 63 | "Files already downloaded and verified\n", 64 | "Files already downloaded and verified\n" 65 | ] 66 | } 67 | ], 68 | "source": [ 69 | "from torchvision import transforms, datasets\n", 70 | "\n", 71 | "transform_train = transforms.Compose([\n", 72 | " transforms.RandomCrop(32, padding=4),\n", 73 | " transforms.RandomHorizontalFlip(),\n", 74 | " transforms.ToTensor(),\n", 75 | " transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n", 76 | "])\n", 77 | "\n", 78 | "transform_test = transforms.Compose([\n", 79 | " transforms.ToTensor(),\n", 80 | " transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n", 81 | "])\n", 82 | "\n", 83 | "trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)\n", 84 | "trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)\n", 85 | "\n", 86 | "testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)\n", 87 | "testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)\n", 88 | "\n", 89 | "classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')\n", 90 | "X_test, y_test = map(torch.cat, zip(*list(testloader)))" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": 3, 96 | "metadata": {}, 97 | "outputs": [], 98 | "source": [ 99 | "model = lib.Editable(\n", 100 | " module=ResNet18(), loss_function=lib.contrastive_cross_entropy,\n", 101 | " get_editable_parameters=lambda module: module.layer3.parameters(),\n", 102 | " optimizer=lib.IngraphRMSProp(\n", 103 | " learning_rate=1e-3, beta=nn.Parameter(torch.tensor(0.5, dtype=torch.float32)), \n", 104 | " ), max_steps=10,\n", 105 | "\n", 106 | ").to(device)\n", 107 | "\n", 108 | "trainer = lib.EditableTrainer(model, F.cross_entropy, experiment_name=experiment_name, max_norm=10)\n", 109 | "trainer.writer.add_text(\"trainer\", repr(trainer).replace('\\n', '
'))" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": null, 115 | "metadata": {}, 116 | "outputs": [ 117 | { 118 | "data": { 119 | "application/vnd.jupyter.widget-view+json": { 120 | "model_id": "af4871f821da4d80b68dd4841955e2ac", 121 | "version_major": 2, 122 | "version_minor": 0 123 | }, 124 | "text/html": [ 125 | "

Failed to display Jupyter Widget of type HBox.\n", 137 | "
\n" 138 | ], 139 | "text/plain": [ 140 | "HBox(children=(IntProgress(value=0, max=391), HTML(value='')))" 141 | ] 142 | }, 143 | "metadata": {}, 144 | "output_type": "display_data" 145 | } 146 | ], 147 | "source": [ 148 | "from tqdm import tqdm_notebook, tnrange\n", 149 | "from IPython.display import clear_output\n", 150 | "\n", 151 | "val_metrics = trainer.evaluate_metrics(X_test.to(device), y_test.to(device))\n", 152 | "min_error, min_drawdown = val_metrics['base_error'], val_metrics['drawdown']\n", 153 | "early_stopping_epochs = 500\n", 154 | "number_of_epochs_without_improvement = 0\n", 155 | "\n", 156 | "def edit_generator():\n", 157 | " while True:\n", 158 | " for xb, yb in torch.utils.data.DataLoader(trainset, batch_size=1, shuffle=True, num_workers=2):\n", 159 | " yield xb.to(device), torch.randint_like(yb, low=0, high=len(classes), device=device)\n", 160 | "\n", 161 | "edit_generator = edit_generator()\n", 162 | "\n", 163 | "\n", 164 | "while True:\n", 165 | " for x_batch, y_batch in tqdm_notebook(trainloader):\n", 166 | " trainer.step(x_batch.to(device), y_batch.to(device), *next(edit_generator))\n", 167 | " \n", 168 | " val_metrics = trainer.evaluate_metrics(X_test.to(device), y_test.to(device))\n", 169 | " clear_output(True)\n", 170 | " \n", 171 | " error_rate, drawdown = val_metrics['base_error'], val_metrics['drawdown']\n", 172 | " \n", 173 | " number_of_epochs_without_improvement += 1\n", 174 | " \n", 175 | " \n", 176 | " if error_rate < min_error:\n", 177 | " trainer.save_checkpoint(tag='best_val_error')\n", 178 | " min_error = error_rate\n", 179 | " number_of_epochs_without_improvement = 0\n", 180 | " \n", 181 | " if drawdown < min_drawdown:\n", 182 | " trainer.save_checkpoint(tag='best_drawdown')\n", 183 | " min_drawdown = drawdown\n", 184 | " number_of_epochs_without_improvement = 0\n", 185 | " \n", 186 | " trainer.save_checkpoint()\n", 187 | " trainer.remove_old_temp_checkpoints()\n", 188 | "\n", 189 | " if number_of_epochs_without_improvement > early_stopping_epochs:\n", 190 | " break" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": null, 196 | "metadata": {}, 197 | "outputs": [], 198 | "source": [ 199 | "from lib import evaluate_quality\n", 200 | "\n", 201 | "np.random.seed(9)\n", 202 | "indices = np.random.permutation(len(X_test))[:1000]\n", 203 | "X_edit = X_test[indices].clone().to(device)\n", 204 | "y_edit = torch.tensor(np.random.randint(0, 10, size=y_test[indices].shape), device=device)\n", 205 | "metrics = evaluate_quality(editable_model, X_test, y_test, X_edit, y_edit, batch_size=512)\n", 206 | "for key in sorted(metrics.keys()):\n", 207 | " print('{}\\t:{:.5}'.format(key, metrics[key]))" 208 | ] 209 | } 210 | ], 211 | "metadata": { 212 | "kernelspec": { 213 | "display_name": "Python 3", 214 | "language": "python", 215 | "name": "python3" 216 | }, 217 | "language_info": { 218 | "codemirror_mode": { 219 | "name": "ipython", 220 | "version": 3 221 | }, 222 | "file_extension": ".py", 223 | "mimetype": "text/x-python", 224 | "name": "python", 225 | "nbconvert_exporter": "python", 226 | "pygments_lexer": "ipython3", 227 | "version": "3.6.4" 228 | } 229 | }, 230 | "nbformat": 4, 231 | "nbformat_minor": 2 232 | } 233 | -------------------------------------------------------------------------------- /mt/edited_generate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 -u 2 | # This source code is based fairseq/generate.py 3 | """ 4 | Perform edit and translate pre-processed data 
with a trained model. 5 | """ 6 | 7 | 8 | import sys 9 | sys.path.append("fairseq") 10 | 11 | import torch 12 | 13 | from fairseq import bleu, checkpoint_utils, options, progress_bar, tasks, utils 14 | from fairseq.meters import StopwatchMeter, TimeMeter 15 | import tqdm 16 | 17 | from fairseq_criterion import EditableTrainingCriterion 18 | from mosestokenizer import MosesDetokenizer 19 | 20 | 21 | def main(args): 22 | assert args.path is not None, '--path required for generation!' 23 | assert not args.sampling or args.nbest == args.beam, \ 24 | '--sampling requires --nbest to be equal to --beam' 25 | assert args.replace_unk is None or args.raw_text, \ 26 | '--replace-unk requires a raw text dataset (--raw-text)' 27 | 28 | utils.import_user_module(args) 29 | 30 | if args.max_tokens is None and args.max_sentences is None: 31 | args.max_tokens = 12000 32 | print(args) 33 | 34 | use_cuda = torch.cuda.is_available() and not args.cpu 35 | 36 | # Load dataset splits 37 | task = tasks.setup_task(args) 38 | task.load_dataset(args.gen_subset) 39 | 40 | # Set dictionaries 41 | try: 42 | src_dict = getattr(task, 'source_dictionary', None) 43 | except NotImplementedError: 44 | src_dict = None 45 | tgt_dict = task.target_dictionary 46 | 47 | # Load ensemble 48 | print('| loading model(s) from {}'.format(args.path)) 49 | models, _model_args = checkpoint_utils.load_model_ensemble( 50 | args.path.split(':'), 51 | arg_overrides=eval(args.model_overrides), 52 | task=task, 53 | ) 54 | 55 | # Optimize ensemble for generation 56 | for model in models: 57 | model.make_generation_fast_( 58 | beamable_mm_beam_size=None if args.no_beamable_mm else args.beam, 59 | need_attn=args.print_alignment, 60 | ) 61 | if args.fp16: 62 | model.half() 63 | if use_cuda: 64 | model.cuda() 65 | 66 | # Load alignment dictionary for unknown word replacement 67 | # (None if no unknown word replacement, empty if no path to align dictionary) 68 | align_dict = utils.load_align_dict(args.replace_unk) 69 | 70 | # Load dataset (possibly sharded) 71 | itr = task.get_batch_iterator( 72 | dataset=task.dataset(args.gen_subset), 73 | max_tokens=args.max_tokens, 74 | max_sentences=args.max_sentences, 75 | max_positions=utils.resolve_max_positions( 76 | task.max_positions(), 77 | *[model.max_positions() for model in models] 78 | ), 79 | ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, 80 | required_batch_size_multiple=args.required_batch_size_multiple, 81 | num_shards=args.num_shards, 82 | shard_id=args.shard_id, 83 | num_workers=args.num_workers, 84 | ).next_epoch_itr(shuffle=False) 85 | 86 | # Initialize generator 87 | gen_timer = StopwatchMeter() 88 | generator = task.build_generator(args) 89 | 90 | detokenizer = MosesDetokenizer(args.moses_detokenizer) 91 | 92 | # Generate and compute BLEU score 93 | if args.sacrebleu: 94 | scorer = bleu.SacrebleuScorer() 95 | else: 96 | scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk()) 97 | num_sentences = 0 98 | has_target = True 99 | with progress_bar.build_progress_bar(args, itr) as t: 100 | wps_meter = TimeMeter() 101 | 102 | if args.edit_sample_index is not None: 103 | # make EditableTransformer 104 | 105 | assert len(models) == 1 106 | 107 | criterion = EditableTrainingCriterion(args, task).train(False) 108 | critetion_state_dict = torch.load(args.path) 109 | if 'criterion' in critetion_state_dict: 110 | criterion.load_state_dict(critetion_state_dict['criterion']) 111 | 112 | model = models[0] 113 | edit_sample = criterion.samples[args.edit_sample_index] 114 | device = 
'cuda' if use_cuda else 'cpu' 115 | edited_model, success, _, complexity = criterion.get_edited_transformer(model, edit_sample, device, detach=True) 116 | edited_model.train(False) 117 | 118 | models[0] = edited_model.recover_transformer() 119 | 120 | for sample in t: 121 | sample = utils.move_to_cuda(sample) if use_cuda else sample 122 | if 'net_input' not in sample: 123 | continue 124 | 125 | prefix_tokens = None 126 | if args.prefix_size > 0: 127 | prefix_tokens = sample['target'][:, :args.prefix_size] 128 | 129 | gen_timer.start() 130 | hypos = task.inference_step(generator, models, sample, prefix_tokens) 131 | num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos) 132 | gen_timer.stop(num_generated_tokens) 133 | 134 | for i, sample_id in enumerate(sample['id'].tolist()): 135 | has_target = sample['target'] is not None 136 | 137 | # Remove padding 138 | src_tokens = utils.strip_pad(sample['net_input']['src_tokens'][i, :], tgt_dict.pad()) 139 | target_tokens = None 140 | if has_target: 141 | target_tokens = utils.strip_pad(sample['target'][i, :], tgt_dict.pad()).int().cpu() 142 | 143 | # Either retrieve the original sentences or regenerate them from tokens. 144 | if align_dict is not None: 145 | src_str = task.dataset(args.gen_subset).src.get_original_text(sample_id) 146 | target_str = task.dataset(args.gen_subset).tgt.get_original_text(sample_id) 147 | else: 148 | if src_dict is not None: 149 | src_str = src_dict.string(src_tokens, args.remove_bpe) 150 | else: 151 | src_str = "" 152 | if has_target: 153 | target_str = tgt_dict.string(target_tokens, args.remove_bpe, escape_unk=True) 154 | 155 | if not args.quiet: 156 | if src_dict is not None: 157 | print('S-{}\t{}'.format(sample_id, src_str)) 158 | if has_target: 159 | print('T-{}\t{}'.format(sample_id, target_str)) 160 | 161 | # Process top predictions 162 | for j, hypo in enumerate(hypos[i][:args.nbest]): 163 | hypo_tokens, hypo_str, alignment = utils.post_process_prediction( 164 | hypo_tokens=hypo['tokens'].int().cpu(), 165 | src_str=src_str, 166 | alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None, 167 | align_dict=align_dict, 168 | tgt_dict=tgt_dict, 169 | remove_bpe=args.remove_bpe, 170 | ) 171 | 172 | if not args.quiet: 173 | print('H-{}\t{}\t{}'.format(sample_id, hypo['score'], hypo_str)) 174 | print('P-{}\t{}'.format( 175 | sample_id, 176 | ' '.join(map( 177 | lambda x: '{:.4f}'.format(x), 178 | hypo['positional_scores'].tolist(), 179 | )) 180 | )) 181 | 182 | if args.print_alignment: 183 | print('A-{}\t{}'.format( 184 | sample_id, 185 | ' '.join(map(lambda x: str(utils.item(x)), alignment)) 186 | )) 187 | 188 | # Score only the top hypothesis 189 | if has_target and j == 0: 190 | if align_dict is not None or args.remove_bpe is not None: 191 | # Convert back to tokens for evaluation with unk replacement and/or without BPE 192 | target_tokens = tgt_dict.encode_line(target_str, add_if_not_exist=True) 193 | if hasattr(scorer, 'add_string'): 194 | if args.moses_detokenizer: 195 | target_str = detokenizer(target_str.split()) 196 | hypo_str = detokenizer(hypo_str.split()) 197 | 198 | scorer.add_string(target_str, hypo_str) 199 | else: 200 | assert not args.moses_detokenizer, "detokenizer has no effect with current bleu scorer" 201 | scorer.add(target_tokens, hypo_tokens) 202 | 203 | wps_meter.update(num_generated_tokens) 204 | t.log({'wps': round(wps_meter.avg)}) 205 | num_sentences += sample['nsentences'] 206 | 207 | print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, 
{:.2f} tokens/s)'.format( 208 | num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. / gen_timer.avg)) 209 | if has_target: 210 | print('| Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string())) 211 | if args.edit_sample_index is not None: 212 | print('EditResult(success={}, complexity={})'.format(success, complexity)) 213 | return scorer 214 | 215 | 216 | def cli_main(): 217 | parser = options.get_generation_parser() 218 | parser.add_argument('--edit-sample-index', type=int, metavar='D', default=None, 219 | help='Index of edit sample to use (pass a dataset via criterion --edit-samples-path arg)') 220 | parser.add_argument('--moses-detokenizer', type=str, default=None) 221 | EditableTrainingCriterion.add_args(parser) 222 | args = options.parse_args_and_arch(parser) 223 | main(args) 224 | 225 | 226 | if __name__ == '__main__': 227 | cli_main() 228 | -------------------------------------------------------------------------------- /notebooks/imagenet_editable_training.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "__This notebook__ fine-tunes a pre-trained resnet18 model with editable training.\n", 8 | "\n", 9 | "__Prepare data:__\n", 10 | "* Download imagenet training and dataset\n", 11 | "* Make sure folder names are called \"000\", \"001\", ... \"010\", \"011\", ... and not \"0\", \"1\", ..., \"10\", \"11\", ...\n", 12 | " * rename if necessary\n", 13 | "* Run `imagenet_preprocess_logits.ipynb` to prepare fine-tuning metadata.\n", 14 | "\n", 15 | "__Training:__\n", 16 | "* Set environment variables and paths in the next cell\n", 17 | "* Run all cells :)" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": 1, 23 | "metadata": {}, 24 | "outputs": [ 25 | { 26 | "name": "stdout", 27 | "output_type": "stream", 28 | "text": [ 29 | "env: CUDA_VISIBLE_DEVICES=1\n", 30 | "imagenet_editable_extra_layer_2019.09.19_23:11:24\n" 31 | ] 32 | } 33 | ], 34 | "source": [ 35 | "%load_ext autoreload\n", 36 | "%autoreload 2\n", 37 | "%env CUDA_VISIBLE_DEVICES=YOURDEVICEHERE\n", 38 | "\n", 39 | "traindir = '../../imagenet/train' # path to train ImageFolder\n", 40 | "valdir = '../../imagenet/val' # path to validation ImageFolder\n", 41 | "logits_path = './imagenet_logits/' # see imagenet_preprocess_logits\n", 42 | "\n", 43 | "import os, sys, time\n", 44 | "sys.path.insert(0, '..')\n", 45 | "import lib\n", 46 | "\n", 47 | "import numpy as np\n", 48 | "import torch, torch.nn as nn\n", 49 | "import torch.nn.functional as F\n", 50 | "import matplotlib.pyplot as plt\n", 51 | "%matplotlib inline\n", 52 | "\n", 53 | "import random\n", 54 | "random.seed(42)\n", 55 | "np.random.seed(42)\n", 56 | "torch.random.manual_seed(42)\n", 57 | "\n", 58 | "import time\n", 59 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", 60 | "\n", 61 | "experiment_name = 'imagenet_editable_extra_layer'\n", 62 | "experiment_name = '{}_{}.{:0>2d}.{:0>2d}_{:0>2d}:{:0>2d}:{:0>2d}'.format(experiment_name, *time.gmtime()[:6])\n", 63 | "print(experiment_name)" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 2, 69 | "metadata": {}, 70 | "outputs": [], 71 | "source": [ 72 | "import torchvision.transforms as transforms\n", 73 | "import torchvision.datasets as datasets\n", 74 | "import torchvision.models as models\n", 75 | "\n", 76 | "normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],\n", 77 | " 
std=[0.229, 0.224, 0.225])\n", 78 | "\n", 79 | "train_dataset = lib.ImageAndLogitsFolder(\n", 80 | " traindir,\n", 81 | " transforms.Compose([\n", 82 | " transforms.RandomResizedCrop(224),\n", 83 | " transforms.RandomHorizontalFlip(),\n", 84 | " transforms.ToTensor(),\n", 85 | " normalize,\n", 86 | " ]),\n", 87 | " logits_prefix = logits_path\n", 88 | ")\n", 89 | "\n", 90 | "batch_size = 128\n", 91 | "\n", 92 | "train_loader = torch.utils.data.DataLoader(\n", 93 | " train_dataset, batch_size=batch_size, shuffle=True,\n", 94 | " num_workers=12, pin_memory=True)\n", 95 | "\n", 96 | "val_loader = torch.utils.data.DataLoader(\n", 97 | " datasets.ImageFolder(valdir, transforms.Compose([\n", 98 | " transforms.Resize(256),\n", 99 | " transforms.CenterCrop(224),\n", 100 | " transforms.ToTensor(),\n", 101 | " normalize,\n", 102 | " ])),\n", 103 | " batch_size=batch_size, shuffle=False,\n", 104 | " num_workers=32, pin_memory=True)\n", 105 | "\n", 106 | "X_test, y_test = map(torch.cat, zip(*val_loader))\n", 107 | "X_test, y_test = X_test[::10], y_test[::10]\n", 108 | "# !!!IMPORTANT!!!\n", 109 | "# We use 10% of validation samples for faster validation, please use full validation set to measure \"final\" error rate" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": 3, 115 | "metadata": {}, 116 | "outputs": [], 117 | "source": [ 118 | "import torchvision\n", 119 | "\n", 120 | "model = torchvision.models.resnet18(pretrained=True)\n", 121 | "\n", 122 | "optimizer = lib.IngraphRMSProp(learning_rate=1e-4, beta=nn.Parameter(torch.as_tensor(0.5)))\n", 123 | "\n", 124 | "model = lib.SequentialWithEditable(\n", 125 | " model.conv1, model.bn1, model.relu, model.maxpool,\n", 126 | " model.layer1, model.layer2, model.layer3, model.layer4,\n", 127 | " model.avgpool, lib.Flatten(),\n", 128 | " lib.Editable(\n", 129 | " lib.Residual(nn.Linear(512, 4096), nn.ELU(), nn.Linear(4096, 512)),\n", 130 | " loss_function=lib.contrastive_cross_entropy, \n", 131 | " optimizer=optimizer, max_steps=10),\n", 132 | "\n", 133 | " model.fc\n", 134 | ").to(device)" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": 4, 140 | "metadata": {}, 141 | "outputs": [], 142 | "source": [ 143 | "def classification_error(model, X_test, y_test):\n", 144 | " with lib.training_mode(model, is_train=False):\n", 145 | " return lib.classification_error(lib.Lambda(lambda x: model(x.to(device))),\n", 146 | " X_test, y_test, device='cpu', batch_size=128)" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": null, 152 | "metadata": {}, 153 | "outputs": [], 154 | "source": [ 155 | "new_params = set(model.editable.module[0].parameters())\n", 156 | "old_params = [param for param in model.parameters() if param not in new_params]\n", 157 | "\n", 158 | "training_opt = lib.OptimizerList(\n", 159 | " torch.optim.SGD(old_params, lr=1e-5, momentum=0.9, weight_decay=1e-4),\n", 160 | " torch.optim.SGD(new_params, lr=1e-3, momentum=0.9, weight_decay=1e-4),\n", 161 | ")\n", 162 | "\n", 163 | "trainer = lib.DistillationEditableTrainer(model,\n", 164 | " stability_coeff=0.03, editability_coeff=0.03,\n", 165 | " experiment_name=experiment_name,\n", 166 | " error_function=classification_error,\n", 167 | " opt=training_opt, max_norm=10)\n", 168 | "\n", 169 | "trainer.writer.add_text(\"trainer\", repr(trainer).replace('\\n', '
'))" 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": null, 175 | "metadata": {}, 176 | "outputs": [], 177 | "source": [ 178 | "from tqdm import tqdm_notebook, tnrange\n", 179 | "from IPython.display import clear_output\n", 180 | "\n", 181 | "# Learnign params\n", 182 | "eval_batch_cd = 500\n", 183 | "val_metrics = trainer.evaluate_metrics(X_test.to(device), y_test.to(device))\n", 184 | "min_error, min_drawdown = val_metrics['base_error'], val_metrics['drawdown']\n", 185 | "early_stopping_epochs = 500\n", 186 | "number_of_epochs_without_improvement = 0\n", 187 | " \n", 188 | "def edit_generator():\n", 189 | " while True:\n", 190 | " for xb, yb, lg in torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=2):\n", 191 | " yield xb.to(device), torch.randint_like(yb, low=0, high=max(y_test) + 1, device=device)\n", 192 | "\n", 193 | "edit_generator = edit_generator()" 194 | ] 195 | }, 196 | { 197 | "cell_type": "markdown", 198 | "metadata": {}, 199 | "source": [ 200 | "### Train" 201 | ] 202 | }, 203 | { 204 | "cell_type": "code", 205 | "execution_count": null, 206 | "metadata": {}, 207 | "outputs": [], 208 | "source": [ 209 | "while True:\n", 210 | " \n", 211 | " for x_batch, y_batch, logits in tqdm_notebook(train_loader):\n", 212 | " trainer.step(x_batch.to(device), logits.to(device), *next(edit_generator))\n", 213 | " \n", 214 | " if trainer.total_steps % eval_batch_cd == 0:\n", 215 | " val_metrics = trainer.evaluate_metrics(X_test.to(device), y_test.to(device))\n", 216 | " clear_output(True)\n", 217 | "\n", 218 | " error_rate, drawdown = val_metrics['base_error'], val_metrics['drawdown']\n", 219 | "\n", 220 | " number_of_epochs_without_improvement += 1\n", 221 | "\n", 222 | " if error_rate < min_error:\n", 223 | " trainer.save_checkpoint(tag='best_val_error')\n", 224 | " min_error = error_rate\n", 225 | " number_of_epochs_without_improvement = 0\n", 226 | "\n", 227 | " if drawdown < min_drawdown:\n", 228 | " trainer.save_checkpoint(tag='best_drawdown')\n", 229 | " min_drawdown = drawdown\n", 230 | " number_of_epochs_without_improvement = 0\n", 231 | "\n", 232 | " trainer.save_checkpoint()\n", 233 | " trainer.remove_old_temp_checkpoints()\n", 234 | "\n", 235 | " if number_of_epochs_without_improvement > early_stopping_epochs:\n", 236 | " break" 237 | ] 238 | }, 239 | { 240 | "cell_type": "markdown", 241 | "metadata": {}, 242 | "source": [ 243 | "### Evaluate drawdown\n", 244 | "\n", 245 | "__Note:__ this code evaluates quality on 10% of the validation set. In paper we use this subset when evaluating drawdown but we measure the base error on all 50k validation samples." 
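As a point of reference, a rough sketch of the reported drawdown metric (an assumption about its definition, not code from this notebook; it reuses `classification_error` from above and the `X_edit`, `y_edit` tensors constructed in the next cell):

base_error = classification_error(model, X_test, y_test)
edited = model.edit(X_edit[:1], y_edit[:1]).model        # a single out-of-place edit
edited_error = classification_error(edited, X_test, y_test)
drawdown = edited_error - base_error                     # error increase caused by the edit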
246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": null, 251 | "metadata": {}, 252 | "outputs": [], 253 | "source": [ 254 | "# edit quality\n", 255 | "\n", 256 | "from lib import evaluate_quality\n", 257 | "\n", 258 | "np.random.seed(9)\n", 259 | "indices = np.random.permutation(len(X_test))[:1000]\n", 260 | "X_edit = X_test[indices].clone().to(device)\n", 261 | "y_edit = torch.tensor(np.random.randint(0, max(y_test) + 1, size=y_test[indices].shape), device=device)\n", 262 | "metrics = evaluate_quality(model, X_test, y_test, X_edit, y_edit, \n", 263 | " error_function=classification_error, progressbar=tqdm_notebook)\n", 264 | "\n", 265 | "for key in sorted(metrics.keys()):\n", 266 | " print('{}\\t:{:.5}'.format(key, metrics[key]))\n" 267 | ] 268 | } 269 | ], 270 | "metadata": { 271 | "kernelspec": { 272 | "display_name": "Python 3", 273 | "language": "python", 274 | "name": "python3" 275 | }, 276 | "language_info": { 277 | "codemirror_mode": { 278 | "name": "ipython", 279 | "version": 3 280 | }, 281 | "file_extension": ".py", 282 | "mimetype": "text/x-python", 283 | "name": "python", 284 | "nbconvert_exporter": "python", 285 | "pygments_lexer": "ipython3", 286 | "version": "3.6.4" 287 | } 288 | }, 289 | "nbformat": 4, 290 | "nbformat_minor": 2 291 | } 292 | -------------------------------------------------------------------------------- /lib/utils/ingraph_update.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities required for backpropagating through gradient descent steps, inspired by: 3 | Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks https://arxiv.org/abs/1703.03400 4 | """ 5 | from collections import namedtuple 6 | from warnings import warn 7 | 8 | import torch 9 | import torch.nn as nn 10 | from copy import deepcopy 11 | from itertools import chain 12 | 13 | from torch import nn as nn 14 | from .basic import straight_through_grad 15 | from .copy_and_replace import copy_and_replace 16 | 17 | 18 | def get_updated_model(model: nn.Module, loss=None, gradients=None, parameters=None, 19 | detach=False, learning_rate=1.0, allow_unused=False, **kwargs): 20 | """ 21 | Creates a copy of model whose parameters are updated with one-step gradient descent w.r.t. 
loss 22 | The copy will propagate gradients into the original model 23 | :param model: original model 24 | :param loss: scalar objective to backprop from; provide either this or gradients 25 | :param gradients: a list or tuple of gradients (updates) for each parameter; provide either this or loss 26 | :param parameters: list/tuple of parameters to update, defaults to model.parameters() 27 | :param detach: if True, the resulting model will not propagate gradients to the original model 28 | :param learning_rate: scales gradients by this value before updating 29 | :param allow_unused: by default, raise an error if one or more parameters receive None gradients 30 | Otherwise (allow_unused=True) simply do not update these parameters 31 | """ 32 | assert (loss is None) != (gradients is None) 33 | parameters = list(model.parameters() if parameters is None else parameters) 34 | if gradients is None: 35 | assert torch.is_grad_enabled() 36 | gradients = torch.autograd.grad( 37 | loss, parameters, create_graph=not detach, only_inputs=True, allow_unused=allow_unused, **kwargs) 38 | 39 | assert isinstance(gradients, (list, tuple)) and len(gradients) == len(parameters) 40 | 41 | updates = dict() 42 | for weight, grad in zip(parameters, gradients): 43 | if grad is not None: 44 | update = weight - learning_rate * grad 45 | if detach: 46 | update = update.detach().requires_grad_(weight.requires_grad) 47 | updates[weight] = update 48 | 49 | do_not_copy = [tensor for tensor in chain(model.parameters(), model.buffers()) 50 | if tensor not in updates] 51 | 52 | return copy_and_replace(model, updates, do_not_copy) 53 | 54 | 55 | class IngraphGradientDescent(nn.Module): 56 | """ Optimizer that updates model out-of-place and returns a copy with changed parameters """ 57 | OptimizerState = namedtuple("OptimizerState", []) 58 | 59 | def __init__(self, learning_rate=1.0): 60 | super().__init__() 61 | self.learning_rate = learning_rate 62 | 63 | def get_initial_state(self, editable, **kwargs): 64 | """ Return initial optimizer state: momenta, rms, etc. """ 65 | return self.OptimizerState() 66 | 67 | def step(self, state: OptimizerState, module: nn.Module, loss, parameters=None, **kwargs): 68 | """ 69 | Return an updated copy of model after one iteration of gradient descent 70 | :param state: optimizer state (as in self.get_initial_state) 71 | :param module: module to be edited (lib.Editable) 72 | :param loss: torch scalar that is differentiable w.r.t. 
model parameters 73 | :parameters: parameters of :module: that will be edited by updates (default = module.parameters()) 74 | :param kwargs: extra parameters passed to get_updated_model 75 | :returns: new_state, updated_self 76 | new_state: self.OptimizerState - optimizer state after performing sgd step 77 | updated_self: Editable - updated(out-of-place) version of self 78 | """ 79 | updated_editable = get_updated_model(module, loss=loss, learning_rate=self.learning_rate, 80 | parameters=list(parameters or module.parameters()), **kwargs) 81 | return state, updated_editable 82 | 83 | def forward(self, *args, **kwargs): 84 | return self.step(*args, **kwargs) 85 | 86 | def extra_repr(self): 87 | return "learning_rate={}".format(self.learning_rate) 88 | 89 | 90 | class IngraphRMSProp(IngraphGradientDescent): 91 | OptimizerState = namedtuple( 92 | "OptimizerState", ["grad_momenta", "ewma_grad_norms_sq", "learning_rate", "momentum", "beta", "epsilon"]) 93 | 94 | def __init__(self, learning_rate=None, log_learning_rate=None, momentum=None, beta=None, 95 | epsilon=None, log_epsilon=None, force_trainable_params=False): 96 | """ 97 | Ingraph optimzier that performs RMSProp updates with optional momentum 98 | :param learning_rate: log(alpha) for gradient descent, all updates are scaled by exponent of this value 99 | :param momentum: momentum coefficient, the update direction is (1 - momentum) * prev_update + update, 100 | default = no momentum 101 | :param beta: RMSProp decay coefficient, the update is scaled by 1 / sqrt(ewma + epsilon) 102 | where ewma = prev_ewma * beta + dL/dw ^ 2 * (1 - beta), default = no RMSProp 103 | :param force_trainable_params: if True, treats all optimizer parameters that are not None as learnable 104 | parameters that are trained alongside other non-edited layers 105 | 106 | """ 107 | nn.Module.__init__(self) 108 | self.params = dict( 109 | learning_rate=learning_rate, log_learning_rate=log_learning_rate, 110 | momentum=momentum, beta=beta, epsilon=epsilon, log_epsilon=log_epsilon 111 | ) 112 | 113 | if force_trainable_params: 114 | for key in self.params: 115 | if self.params[key] is None: continue 116 | elif isinstance(self.params[key], nn.Parameter): continue 117 | elif isinstance(self.params[key], torch.Tensor) and self.params[key].requires_grad: continue 118 | self.params[key] = nn.Parameter(torch.as_tensor(self.params[key])) 119 | 120 | for key in self.params: 121 | if isinstance(self.params[key], nn.Parameter): 122 | self.register_parameter(key, self.params[key]) 123 | 124 | def get_initial_state(self, editable, **overrides): 125 | """ 126 | Create initial state and make sure all parameters are in a valid range 127 | :param editable: module to be edited 128 | :param overrides: send key-value optimizer params with same names as at init to override them 129 | :return: Editable.OptimizerState 130 | """ 131 | for key in overrides: 132 | assert key in self.params, "unknown optimizer parameter {}".format(key) 133 | params = dict(self.params, **overrides) 134 | 135 | assert (params['learning_rate'] is None) != (params['log_learning_rate'] is None), "provide either lr or log lr" 136 | learning_rate = params['learning_rate'] or torch.exp(params['log_learning_rate']) 137 | learning_rate = straight_through_grad(torch.clamp_min, min=0.0)(torch.as_tensor(learning_rate)) 138 | 139 | momentum = params.get('momentum') 140 | if momentum is not None: 141 | momentum = straight_through_grad(torch.clamp, min=0.0, max=1.0)(torch.as_tensor(momentum)) 142 | if isinstance(momentum, 
torch.Tensor) and momentum.requires_grad: 143 | warn("The derivative of updated params w.r.t. momentum is proportional to momentum^{n_steps - 1}, " 144 | "optimizing it with gradient descent may suffer from poor numerical stability.") 145 | 146 | beta = params.get('beta') 147 | if beta is not None: 148 | beta = straight_through_grad(torch.clamp, min=0.0, max=1.0)(torch.as_tensor(beta)) 149 | 150 | assert params['epsilon'] is None or params['log_epsilon'] is None, "provide either epsilon or log epsilon" 151 | if params['epsilon'] is None and params['log_epsilon'] is None: 152 | params['epsilon'] = 1e-6 153 | epsilon = params['epsilon'] or torch.exp(params['log_epsilon']) 154 | epsilon = straight_through_grad(torch.clamp_min, min=1e-9)(torch.as_tensor(epsilon)) 155 | 156 | else: 157 | epsilon = None 158 | 159 | return self.OptimizerState(None, None, learning_rate, momentum, beta, epsilon) 160 | 161 | def step(self, state: OptimizerState, module: nn.Module, loss, parameters=None, **kwargs): 162 | """ 163 | :param state: optimizer state (as in self.get_initial_state) 164 | :param module: module to be edited 165 | :param loss: torch scalar that is differentiable w.r.t. model parameters 166 | :param parameters: if model 167 | :param kwargs: extra parameters passed to get_updated_model 168 | :returns: new_state, updated_self 169 | new_state: self.OptimizerState - optimizer state after performing sgd step 170 | updated_self: updated copy of module 171 | """ 172 | grad_momenta, ewma_grad_norms_sq, learning_rate, momentum, beta, epsilon = state 173 | parameters = list(parameters or module.parameters()) 174 | gradients = torch.autograd.grad(loss, parameters, create_graph=True, only_inputs=True, allow_unused=False) 175 | updates = list(gradients) # updates are the scaled/accumulated/tuned gradients 176 | 177 | if momentum is not None: 178 | # momentum: accumulate gradients with moving average-like procedure 179 | if grad_momenta is None: 180 | grad_momenta = list(gradients) 181 | else: 182 | for i in range(len(grad_momenta)): 183 | grad_momenta[i] = grad_momenta[i] * momentum + gradients[i] 184 | updates = grad_momenta 185 | 186 | if self.beta is not None: 187 | # RMSProp: first, update the moving average squared norms 188 | if ewma_grad_norms_sq is None: 189 | ewma_grad_norms_sq = list(map(lambda g: g ** 2, gradients)) 190 | else: 191 | for i in range(len(ewma_grad_norms_sq)): 192 | ewma_grad_norms_sq[i] = beta * ewma_grad_norms_sq[i] + (1.0 - beta) * gradients[i] ** 2 193 | 194 | # scale updates by 1 / sqrt(moving_average_norm_squared + epsilon) 195 | for i in range(len(updates)): 196 | updates[i] = updates[i] / torch.sqrt(ewma_grad_norms_sq[i] + epsilon) 197 | 198 | # finally, perform sgd update 199 | updated_module = get_updated_model(module, loss=None, gradients=updates, parameters=parameters, 200 | learning_rate=learning_rate, **kwargs) 201 | new_state = self.OptimizerState(grad_momenta, ewma_grad_norms_sq, learning_rate, momentum, beta, epsilon) 202 | return new_state, updated_module 203 | 204 | def extra_repr(self): 205 | return repr(self.params) 206 | -------------------------------------------------------------------------------- /notebooks/imagenet_evaluate_nae.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "__This notebook__ evaluates trained imaneget classifier on natural adversarial examples" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | 
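A small usage sketch for the in-graph optimizers defined in lib/utils/ingraph_update.py above (the names `model`, `x_edit`, `y_edit`, `x_val`, `y_val` are placeholders, not from the repository): the updated copy keeps the autograd graph, so a loss computed with it backpropagates through the inner update.

import torch.nn.functional as F

opt = lib.IngraphGradientDescent(learning_rate=0.1)
state = opt.get_initial_state(model)
inner_loss = F.cross_entropy(model(x_edit), y_edit)
state, updated_model = opt.step(state, model, inner_loss)   # out-of-place copy of `model`
meta_loss = F.cross_entropy(updated_model(x_val), y_val)
meta_loss.backward()   # gradients flow through the inner step into model's parameters
# (and into learnable optimizer hyper-parameters such as IngraphRMSProp's beta)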
"execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [ 15 | { 16 | "name": "stdout", 17 | "output_type": "stream", 18 | "text": [ 19 | "env: CUDA_VISIBLE_DEVICES=2\n" 20 | ] 21 | } 22 | ], 23 | "source": [ 24 | "%env CUDA_VISIBLE_DEVICES=2\n", 25 | "import os, sys, time\n", 26 | "sys.path.insert(0, '..')\n", 27 | "import lib\n", 28 | "import numpy as np\n", 29 | "from torchvision.models import resnet18 as ResNet18\n", 30 | "\n", 31 | "import torch, torch.nn as nn\n", 32 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 2, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "import torchvision.datasets as dset\n", 42 | "import torchvision.transforms as trn\n", 43 | "\n", 44 | "mean = [0.485, 0.456, 0.406]\n", 45 | "std = [0.229, 0.224, 0.225]\n", 46 | "\n", 47 | "test_transform = trn.Compose(\n", 48 | " [trn.Resize(256), trn.CenterCrop(224), trn.ToTensor(), trn.Normalize(mean, std)])\n", 49 | "\n", 50 | "thousand_k_to_200 = {0: -1, 1: -1, 2: -1, 3: -1, 4: -1, 5: -1, 6: 0, 7: -1, 8: -1, 9: -1, 10: -1, 11: 1, 12: -1, 13: 2, 14: -1, 15: 3, 16: -1, 17: 4, 18: -1, 19: -1, 20: -1, 21: -1, 22: 5, 23: 6, 24: -1, 25: -1, 26: -1, 27: 7, 28: -1, 29: -1, 30: 8, 31: -1, 32: -1, 33: -1, 34: -1, 35: -1, 36: -1, 37: 9, 38: -1, 39: 10, 40: -1, 41: -1, 42: 11, 43: -1, 44: -1, 45: -1, 46: -1, 47: 12, 48: -1, 49: -1, 50: 13, 51: -1, 52: -1, 53: -1, 54: -1, 55: -1, 56: -1, 57: 14, 58: -1, 59: -1, 60: -1, 61: -1, 62: -1, 63: -1, 64: -1, 65: -1, 66: -1, 67: -1, 68: -1, 69: -1, 70: 15, 71: 16, 72: -1, 73: -1, 74: -1, 75: -1, 76: 17, 77: -1, 78: -1, 79: 18, 80: -1, 81: -1, 82: -1, 83: -1, 84: -1, 85: -1, 86: -1, 87: -1, 88: -1, 89: 19, 90: 20, 91: -1, 92: -1, 93: -1, 94: 21, 95: -1, 96: 22, 97: 23, 98: -1, 99: 24, 100: -1, 101: -1, 102: -1, 103: -1, 104: -1, 105: 25, 106: -1, 107: 26, 108: 27, 109: -1, 110: 28, 111: -1, 112: -1, 113: 29, 114: -1, 115: -1, 116: -1, 117: -1, 118: -1, 119: -1, 120: -1, 121: -1, 122: -1, 123: -1, 124: 30, 125: 31, 126: -1, 127: -1, 128: -1, 129: -1, 130: 32, 131: -1, 132: 33, 133: -1, 134: -1, 135: -1, 136: -1, 137: -1, 138: -1, 139: -1, 140: -1, 141: -1, 142: -1, 143: 34, 144: 35, 145: -1, 146: -1, 147: -1, 148: -1, 149: -1, 150: 36, 151: 37, 152: -1, 153: -1, 154: -1, 155: -1, 156: -1, 157: -1, 158: -1, 159: -1, 160: -1, 161: -1, 162: -1, 163: -1, 164: -1, 165: -1, 166: -1, 167: -1, 168: -1, 169: -1, 170: -1, 171: -1, 172: -1, 173: -1, 174: -1, 175: -1, 176: -1, 177: -1, 178: -1, 179: -1, 180: -1, 181: -1, 182: -1, 183: -1, 184: -1, 185: -1, 186: -1, 187: -1, 188: -1, 189: -1, 190: -1, 191: -1, 192: -1, 193: -1, 194: -1, 195: -1, 196: -1, 197: -1, 198: -1, 199: -1, 200: -1, 201: -1, 202: -1, 203: -1, 204: -1, 205: -1, 206: -1, 207: 38, 208: -1, 209: -1, 210: -1, 211: -1, 212: -1, 213: -1, 214: -1, 215: -1, 216: -1, 217: -1, 218: -1, 219: -1, 220: -1, 221: -1, 222: -1, 223: -1, 224: -1, 225: -1, 226: -1, 227: -1, 228: -1, 229: -1, 230: -1, 231: -1, 232: -1, 233: -1, 234: 39, 235: 40, 236: -1, 237: -1, 238: -1, 239: -1, 240: -1, 241: -1, 242: -1, 243: -1, 244: -1, 245: -1, 246: -1, 247: -1, 248: -1, 249: -1, 250: -1, 251: -1, 252: -1, 253: -1, 254: 41, 255: -1, 256: -1, 257: -1, 258: -1, 259: -1, 260: -1, 261: -1, 262: -1, 263: -1, 264: -1, 265: -1, 266: -1, 267: -1, 268: -1, 269: -1, 270: -1, 271: -1, 272: -1, 273: -1, 274: -1, 275: -1, 276: -1, 277: 42, 278: -1, 279: -1, 280: -1, 281: -1, 282: -1, 283: 43, 284: -1, 285: -1, 286: -1, 287: 44, 288: -1, 289: -1, 290: -1, 291: 45, 292: -1, 
293: -1, 294: -1, 295: 46, 296: -1, 297: -1, 298: 47, 299: -1, 300: -1, 301: 48, 302: -1, 303: -1, 304: -1, 305: -1, 306: 49, 307: 50, 308: 51, 309: 52, 310: 53, 311: 54, 312: -1, 313: 55, 314: 56, 315: 57, 316: -1, 317: 58, 318: -1, 319: 59, 320: -1, 321: -1, 322: -1, 323: 60, 324: 61, 325: -1, 326: 62, 327: 63, 328: -1, 329: -1, 330: 64, 331: -1, 332: -1, 333: -1, 334: 65, 335: 66, 336: 67, 337: -1, 338: -1, 339: -1, 340: -1, 341: -1, 342: -1, 343: -1, 344: -1, 345: -1, 346: -1, 347: 68, 348: -1, 349: -1, 350: -1, 351: -1, 352: -1, 353: -1, 354: -1, 355: -1, 356: -1, 357: -1, 358: -1, 359: -1, 360: -1, 361: 69, 362: -1, 363: 70, 364: -1, 365: -1, 366: -1, 367: -1, 368: -1, 369: -1, 370: -1, 371: -1, 372: 71, 373: -1, 374: -1, 375: -1, 376: -1, 377: -1, 378: 72, 379: -1, 380: -1, 381: -1, 382: -1, 383: -1, 384: -1, 385: -1, 386: 73, 387: -1, 388: -1, 389: -1, 390: -1, 391: -1, 392: -1, 393: -1, 394: -1, 395: -1, 396: -1, 397: 74, 398: -1, 399: -1, 400: 75, 401: 76, 402: 77, 403: -1, 404: 78, 405: -1, 406: -1, 407: 79, 408: -1, 409: -1, 410: -1, 411: 80, 412: -1, 413: -1, 414: -1, 415: -1, 416: 81, 417: 82, 418: -1, 419: -1, 420: 83, 421: -1, 422: -1, 423: -1, 424: -1, 425: 84, 426: -1, 427: -1, 428: 85, 429: -1, 430: 86, 431: -1, 432: -1, 433: -1, 434: -1, 435: -1, 436: -1, 437: 87, 438: 88, 439: -1, 440: -1, 441: -1, 442: -1, 443: -1, 444: -1, 445: 89, 446: -1, 447: -1, 448: -1, 449: -1, 450: -1, 451: -1, 452: -1, 453: -1, 454: -1, 455: -1, 456: 90, 457: 91, 458: -1, 459: -1, 460: -1, 461: 92, 462: 93, 463: -1, 464: -1, 465: -1, 466: -1, 467: -1, 468: -1, 469: -1, 470: 94, 471: -1, 472: 95, 473: -1, 474: -1, 475: -1, 476: -1, 477: -1, 478: -1, 479: -1, 480: -1, 481: -1, 482: -1, 483: 96, 484: -1, 485: -1, 486: 97, 487: -1, 488: 98, 489: -1, 490: -1, 491: -1, 492: 99, 493: -1, 494: -1, 495: -1, 496: 100, 497: -1, 498: -1, 499: -1, 500: -1, 501: -1, 502: -1, 503: -1, 504: -1, 505: -1, 506: -1, 507: -1, 508: -1, 509: -1, 510: -1, 511: -1, 512: -1, 513: -1, 514: 101, 515: -1, 516: 102, 517: -1, 518: -1, 519: -1, 520: -1, 521: -1, 522: -1, 523: -1, 524: -1, 525: -1, 526: -1, 527: -1, 528: 103, 529: -1, 530: 104, 531: -1, 532: -1, 533: -1, 534: -1, 535: -1, 536: -1, 537: -1, 538: -1, 539: 105, 540: -1, 541: -1, 542: 106, 543: 107, 544: -1, 545: -1, 546: -1, 547: -1, 548: -1, 549: 108, 550: -1, 551: -1, 552: 109, 553: -1, 554: -1, 555: -1, 556: -1, 557: 110, 558: -1, 559: -1, 560: -1, 561: 111, 562: 112, 563: -1, 564: -1, 565: -1, 566: -1, 567: -1, 568: -1, 569: 113, 570: -1, 571: -1, 572: 114, 573: 115, 574: -1, 575: 116, 576: -1, 577: -1, 578: -1, 579: 117, 580: -1, 581: -1, 582: -1, 583: -1, 584: -1, 585: -1, 586: -1, 587: -1, 588: -1, 589: 118, 590: -1, 591: -1, 592: -1, 593: -1, 594: -1, 595: -1, 596: -1, 597: -1, 598: -1, 599: -1, 600: -1, 601: -1, 602: -1, 603: -1, 604: -1, 605: -1, 606: 119, 607: 120, 608: -1, 609: 121, 610: -1, 611: -1, 612: -1, 613: -1, 614: 122, 615: -1, 616: -1, 617: -1, 618: -1, 619: -1, 620: -1, 621: -1, 622: -1, 623: -1, 624: -1, 625: -1, 626: 123, 627: 124, 628: -1, 629: -1, 630: -1, 631: -1, 632: -1, 633: -1, 634: -1, 635: -1, 636: -1, 637: -1, 638: -1, 639: -1, 640: 125, 641: 126, 642: 127, 643: 128, 644: -1, 645: -1, 646: -1, 647: -1, 648: -1, 649: -1, 650: -1, 651: -1, 652: -1, 653: -1, 654: -1, 655: -1, 656: -1, 657: -1, 658: 129, 659: -1, 660: -1, 661: -1, 662: -1, 663: -1, 664: -1, 665: -1, 666: -1, 667: -1, 668: 130, 669: -1, 670: -1, 671: -1, 672: -1, 673: -1, 674: -1, 675: -1, 676: -1, 677: 131, 678: -1, 679: -1, 680: -1, 681: -1, 682: 132, 683: -1, 
684: 133, 685: -1, 686: -1, 687: 134, 688: -1, 689: -1, 690: -1, 691: -1, 692: -1, 693: -1, 694: -1, 695: -1, 696: -1, 697: -1, 698: -1, 699: -1, 700: -1, 701: 135, 702: -1, 703: -1, 704: 136, 705: -1, 706: -1, 707: -1, 708: -1, 709: -1, 710: -1, 711: -1, 712: -1, 713: -1, 714: -1, 715: -1, 716: -1, 717: -1, 718: -1, 719: 137, 720: -1, 721: -1, 722: -1, 723: -1, 724: -1, 725: -1, 726: -1, 727: -1, 728: -1, 729: -1, 730: -1, 731: -1, 732: -1, 733: -1, 734: -1, 735: -1, 736: 138, 737: -1, 738: -1, 739: -1, 740: -1, 741: -1, 742: -1, 743: -1, 744: -1, 745: -1, 746: 139, 747: -1, 748: -1, 749: 140, 750: -1, 751: -1, 752: 141, 753: -1, 754: -1, 755: -1, 756: -1, 757: -1, 758: 142, 759: -1, 760: -1, 761: -1, 762: -1, 763: 143, 764: -1, 765: 144, 766: -1, 767: -1, 768: 145, 769: -1, 770: -1, 771: -1, 772: -1, 773: 146, 774: 147, 775: -1, 776: 148, 777: -1, 778: -1, 779: 149, 780: 150, 781: -1, 782: -1, 783: -1, 784: -1, 785: -1, 786: 151, 787: -1, 788: -1, 789: -1, 790: -1, 791: -1, 792: 152, 793: -1, 794: -1, 795: -1, 796: -1, 797: 153, 798: -1, 799: -1, 800: -1, 801: -1, 802: 154, 803: 155, 804: 156, 805: -1, 806: -1, 807: -1, 808: -1, 809: -1, 810: -1, 811: -1, 812: -1, 813: 157, 814: -1, 815: 158, 816: -1, 817: -1, 818: -1, 819: -1, 820: 159, 821: -1, 822: -1, 823: 160, 824: -1, 825: -1, 826: -1, 827: -1, 828: -1, 829: -1, 830: -1, 831: 161, 832: -1, 833: 162, 834: -1, 835: 163, 836: -1, 837: -1, 838: -1, 839: 164, 840: -1, 841: -1, 842: -1, 843: -1, 844: -1, 845: 165, 846: -1, 847: 166, 848: -1, 849: -1, 850: 167, 851: -1, 852: -1, 853: -1, 854: -1, 855: -1, 856: -1, 857: -1, 858: -1, 859: 168, 860: -1, 861: -1, 862: 169, 863: -1, 864: -1, 865: -1, 866: -1, 867: -1, 868: -1, 869: -1, 870: 170, 871: -1, 872: -1, 873: -1, 874: -1, 875: -1, 876: -1, 877: -1, 878: -1, 879: 171, 880: 172, 881: -1, 882: -1, 883: -1, 884: -1, 885: -1, 886: -1, 887: -1, 888: 173, 889: -1, 890: 174, 891: -1, 892: -1, 893: -1, 894: -1, 895: -1, 896: -1, 897: 175, 898: -1, 899: -1, 900: 176, 901: -1, 902: -1, 903: -1, 904: -1, 905: -1, 906: -1, 907: 177, 908: -1, 909: -1, 910: -1, 911: -1, 912: -1, 913: 178, 914: -1, 915: -1, 916: -1, 917: -1, 918: -1, 919: -1, 920: -1, 921: -1, 922: -1, 923: -1, 924: 179, 925: -1, 926: -1, 927: -1, 928: -1, 929: -1, 930: -1, 931: -1, 932: 180, 933: 181, 934: 182, 935: -1, 936: -1, 937: 183, 938: -1, 939: -1, 940: -1, 941: -1, 942: -1, 943: 184, 944: -1, 945: 185, 946: -1, 947: 186, 948: -1, 949: -1, 950: -1, 951: 187, 952: -1, 953: -1, 954: 188, 955: -1, 956: 189, 957: 190, 958: -1, 959: 191, 960: -1, 961: -1, 962: -1, 963: -1, 964: -1, 965: -1, 966: -1, 967: -1, 968: -1, 969: -1, 970: -1, 971: 192, 972: 193, 973: -1, 974: -1, 975: -1, 976: -1, 977: -1, 978: -1, 979: -1, 980: 194, 981: 195, 982: -1, 983: -1, 984: 196, 985: -1, 986: 197, 987: 198, 988: 199, 989: -1, 990: -1, 991: -1, 992: -1, 993: -1, 994: -1, 995: -1, 996: -1, 997: -1, 998: -1, 999: -1}\n", 51 | "two_hundred_to_1000 = dict(map(reversed, thousand_k_to_200.items()))" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 3, 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "# download Natural Adversarial Examples\n", 61 | "if not os.path.isdir(\"./imagenet-a/\"):\n", 62 | " if not os.path.isfile(\"./imagenet-a.tar\"):\n", 63 | " !wget https://people.eecs.berkeley.edu/~hendrycks/imagenet-a.tar\n", 64 | " !tar xf imagenet-a.tar" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 4, 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "X_adv, y_adv = 
zip(*dset.ImageFolder(root=\"./imagenet-a/\", transform=test_transform))\n", 74 | "X_adv = torch.stack(X_adv).to(device)\n", 75 | "y_adv = torch.tensor(list(map(two_hundred_to_1000.get, y_adv)), device = device)" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": 5, 81 | "metadata": {}, 82 | "outputs": [], 83 | "source": [ 84 | "import torchvision.transforms as transforms\n", 85 | "import torchvision.datasets as datasets\n", 86 | "import torchvision.models as models\n", 87 | "\n", 88 | "data_path_val = '../../imagenet_val'\n", 89 | "val_loader = torch.utils.data.DataLoader(\n", 90 | " datasets.ImageFolder(data_path_val, transforms.Compose([\n", 91 | " transforms.Resize(256),\n", 92 | " transforms.CenterCrop(224),\n", 93 | " transforms.ToTensor(),\n", 94 | " transforms.Normalize(mean=[0.485, 0.456, 0.406],\n", 95 | " std=[0.229, 0.224, 0.225]),\n", 96 | " ])),\n", 97 | " batch_size=128, shuffle=False,\n", 98 | " num_workers=16, pin_memory=True)\n", 99 | "\n", 100 | "X_test, y_test = map(torch.cat, zip(*val_loader))\n", 101 | "X_test, y_test = X_test[::10], y_test[::10]" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": 6, 107 | "metadata": {}, 108 | "outputs": [], 109 | "source": [ 110 | "model = ResNet18(pretrained=True)\n", 111 | "\n", 112 | "optimizer = lib.IngraphRMSProp(\n", 113 | " log_learning_rate=nn.Parameter(torch.log(torch.as_tensor(1e-4))),\n", 114 | " beta=nn.Parameter(torch.as_tensor(0.5)), momentum=None,\n", 115 | " log_epsilon=nn.Parameter(torch.log(torch.as_tensor(1e-3))),\n", 116 | ")\n", 117 | "\n", 118 | "class Flatten(nn.Module):\n", 119 | " def forward(self, x):\n", 120 | " return x.view(len(x), -1)\n", 121 | "\n", 122 | "model = lib.SequentialWithEditable(\n", 123 | " model.conv1, model.bn1, model.relu, model.maxpool,\n", 124 | " model.layer1,\n", 125 | " model.layer2,\n", 126 | " lib.Editable(model.layer3, loss_function=lib.contrastive_cross_entropy, optimizer=optimizer, max_steps=10),\n", 127 | " model.layer4,\n", 128 | " model.avgpool, Flatten(),\n", 129 | " model.fc\n", 130 | ").to(device).train(False)" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": 7, 136 | "metadata": {}, 137 | "outputs": [], 138 | "source": [ 139 | "def classification_error(model, X_test, y_test):\n", 140 | " with lib.training_mode(model, is_train=False):\n", 141 | " return lib.classification_error(lib.Lambda(lambda x: model(x.to(device))),\n", 142 | " X_test, y_test, device='cpu', batch_size=128)" 143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": 8, 148 | "metadata": {}, 149 | "outputs": [ 150 | { 151 | "data": { 152 | "text/plain": [ 153 | "0.9970666666666667" 154 | ] 155 | }, 156 | "execution_count": 8, 157 | "metadata": {}, 158 | "output_type": "execute_result" 159 | } 160 | ], 161 | "source": [ 162 | "classification_error(model, X_adv, y_adv)" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": null, 168 | "metadata": {}, 169 | "outputs": [], 170 | "source": [ 171 | "from lib import evaluate_quality\n", 172 | "from tqdm import tqdm_notebook\n", 173 | "\n", 174 | "np.random.seed(9)\n", 175 | "indices = np.random.permutation(len(X_adv))[:1000]\n", 176 | "X_edit = X_adv[indices]\n", 177 | "y_edit = y_adv[indices]\n", 178 | "metrics_adv = evaluate_quality(model, X_test, y_test, X_edit, y_edit, \n", 179 | " error_function=classification_error, progressbar=tqdm_notebook)\n", 180 | "\n", 181 | "for key in sorted(metrics_adv.keys()):\n", 182 | " 
print('{}\\t:{:.5}'.format(key, metrics_adv[key]))" 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": null, 188 | "metadata": {}, 189 | "outputs": [], 190 | "source": [] 191 | } 192 | ], 193 | "metadata": { 194 | "kernelspec": { 195 | "display_name": "Python 3", 196 | "language": "python", 197 | "name": "python3" 198 | }, 199 | "language_info": { 200 | "codemirror_mode": { 201 | "name": "ipython", 202 | "version": 3 203 | }, 204 | "file_extension": ".py", 205 | "mimetype": "text/x-python", 206 | "name": "python", 207 | "nbconvert_exporter": "python", 208 | "pygments_lexer": "ipython3", 209 | "version": "3.6.4" 210 | } 211 | }, 212 | "nbformat": 4, 213 | "nbformat_minor": 2 214 | } 215 | -------------------------------------------------------------------------------- /mt/fairseq_criterion.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | from copy import copy 4 | 5 | import tqdm 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from fairseq import utils 13 | from fairseq.criterions import register_criterion, FairseqCriterion 14 | from fairseq.criterions.label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion 15 | from fairseq.data import Dictionary 16 | from fairseq.data.language_pair_dataset import collate 17 | from fairseq.models.transformer import TransformerModel, TransformerDecoder 18 | from fairseq.tasks.translation import TranslationTask 19 | 20 | from lib import Editable, BaseEditable, IngraphRMSProp, training_mode, Lambda, copy_and_replace 21 | 22 | 23 | def encode_fast(voc, line): 24 | tokens = [voc.indices.get(tok, voc.unk_index) for tok in line.split()] 25 | tokens.append(voc.eos_index) 26 | return tokens 27 | 28 | 29 | def get_sentence_logp(logits, target, padding_ix, mean=True): 30 | logp = F.log_softmax(logits, dim=-1) 31 | logp_target_tokens = torch.gather(logp, -1, target[..., None])[..., 0] # [batch_size, max_len] 32 | mask = (target != padding_ix).to(dtype=logp.dtype) 33 | logp_target = (logp_target_tokens * mask).sum(dim=-1) 34 | if mean: 35 | logp_target = logp_target / mask.sum(dim=-1) 36 | return logp_target 37 | 38 | 39 | def read_edits(data_path : str, src_dict : Dictionary, tgt_dict : Dictionary): 40 | samples = [] 41 | with open(data_path) as f_in: 42 | for line in tqdm.tqdm(f_in): 43 | sentences = line.split('\t') 44 | src_sent = encode_fast(src_dict, sentences[0]) 45 | target = encode_fast(tgt_dict, sentences[1]) 46 | alternatives = [encode_fast(tgt_dict, x) for x in sentences[2:]] 47 | samples.append((src_sent, target, alternatives)) 48 | 49 | return samples 50 | 51 | 52 | @register_criterion('editable_training_criterion') 53 | class EditableTrainingCriterion(LabelSmoothedCrossEntropyCriterion): 54 | 55 | def __init__(self, args, task: TranslationTask): 56 | super().__init__(args, task) 57 | self.task = task 58 | self.eps = args.label_smoothing 59 | self.data_path = args.edit_samples_path 60 | self.editability_coeff = args.editability_coeff 61 | self.stability_coeff = args.stability_coeff 62 | self.max_steps = args.edit_max_steps 63 | self.almost_last = (args.almost_last != 0) 64 | print('!!!'*30) 65 | print('Editability coeff:', self.editability_coeff) 66 | print('Stability coeff:', self.stability_coeff) 67 | print('Max steps:', self.max_steps) 68 | print('Edit learning rate:', args.edit_learning_rate) 69 | print('Almost last:', self.almost_last) 70 | print('!!!'*30) 71 | 
self.optimizer = IngraphRMSProp(learning_rate=args.edit_learning_rate, beta=nn.Parameter(torch.tensor(0.5, dtype=torch.float32)) 72 | ) 73 | 74 | 75 | self.samples = read_edits(self.data_path, task.src_dict, task.tgt_dict) 76 | 77 | @staticmethod 78 | def add_args(parser): 79 | """Add criterion-specific arguments to the parser.""" 80 | # fmt: off 81 | LabelSmoothedCrossEntropyCriterion.add_args(parser) 82 | parser.add_argument('--edit-samples-path', type=str, metavar='D', 83 | help='path to training edits tsv') 84 | 85 | parser.add_argument('--stability-coeff', default=1e2, type=float, metavar='D', 86 | help='Stability loss multiplier') 87 | parser.add_argument('--editability-coeff', default=1e2, type=float, metavar='D', 88 | help='Failed edit penalty multiplier') 89 | parser.add_argument('--edit-max-steps', default=10, type=int, metavar='D', 90 | help='Max steps to perform during an editing') 91 | parser.add_argument('--edit-learning-rate', default=1e-3, type=float, metavar='D', 92 | help='Learning rate for RMSPror editor') 93 | parser.add_argument('--almost-last', default=0, type=int, metavar='D', 94 | help='if 0 use the last decoder layer to perform an edit else use penultimate') 95 | # fmt: on 96 | 97 | def get_edited_transformer(self, model, edit_sample, device=None, dtype=torch.int64, **kwargs): 98 | with torch.no_grad(): 99 | targets = [edit_sample[1], *edit_sample[2]] 100 | pad_ix = self.task.tgt_dict.pad() 101 | edit_target = torch.full([len(targets), max(map(len, targets))], fill_value=pad_ix, device=device, 102 | dtype=dtype) 103 | prev_output_tokens = edit_target.clone() 104 | for i, seq in enumerate(targets): 105 | edit_target[i, :len(seq)] = torch.as_tensor(seq, dtype=dtype) 106 | prev_output_tokens[i, :len(seq)] = torch.as_tensor([self.task.tgt_dict.eos_index] + seq[:-1], 107 | dtype=dtype) 108 | 109 | edit_source = torch.as_tensor([edit_sample[0]] * edit_target.shape[0], device=device, dtype=dtype) 110 | edit_lengths = torch.full(edit_source.shape[:1], len(edit_sample[0]), device=device, dtype=dtype) 111 | 112 | edit_input = {'src_tokens': edit_source, # [batch, max_src_len] 113 | 'src_lengths': edit_lengths, # [batch] 114 | 'target': edit_target, # [batch, max_tgt_len] 115 | 'prev_output_tokens': prev_output_tokens} # [batch, max_tgt_len] 116 | 117 | while not isinstance(model, TransformerModel): 118 | model = model.module 119 | 120 | if self.almost_last: 121 | editable_model = EditableTransformer( 122 | self, model, xbost_fist_layer_id=len(model.decoder.layers) - 2, 123 | optimizer=self.optimizer, max_steps=self.max_steps, 124 | get_editable_parameters=lambda decoder_xbost: decoder_xbost.xbost_layers[0].parameters() 125 | ) 126 | else: 127 | editable_model = EditableTransformer( 128 | self, model, xbost_fist_layer_id=len(model.decoder.layers) - 1, 129 | optimizer=self.optimizer, max_steps=self.max_steps, 130 | get_editable_parameters=lambda decoder_xbost: decoder_xbost.xbost_layers.parameters() 131 | ) 132 | with training_mode(model, is_train=False): 133 | return editable_model.edit(edit_input, **kwargs) 134 | 135 | def forward(self, model, sample, reduce=True): 136 | """Compute the loss for the given sample. 
137 | 138 | Returns a tuple with three elements: 139 | 1) the loss 140 | 2) the sample size, which is used as the denominator for the gradient 141 | 3) logging outputs to display while training 142 | """ 143 | 144 | if not model.training: 145 | return super().forward(model, sample, reduce) 146 | 147 | device = sample['net_input']['src_tokens'].device 148 | dtype = sample['net_input']['src_tokens'].dtype 149 | 150 | edit_sample = random.choice(self.samples) 151 | 152 | edited_model, success, editability_loss, edit_complexity = \ 153 | self.get_edited_transformer(model, edit_sample, device, dtype) 154 | 155 | net_output = model(**sample['net_input']) 156 | 157 | with training_mode(model, is_train=False): 158 | edited_output = edited_model(**sample['net_input']) 159 | 160 | main_loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce) 161 | 162 | ref_logits = net_output[0].detach() 163 | stability_loss = (F.softmax(ref_logits, dim=-1) 164 | * (F.log_softmax(ref_logits, dim=-1) - F.log_softmax(edited_output[0], dim=-1)) 165 | ).sum(-1).mean() 166 | loss = main_loss + self.stability_coeff * stability_loss + self.editability_coeff * editability_loss 167 | 168 | sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens'] 169 | logging_output = { 170 | 'loss': utils.item(loss.data) if reduce else loss.data, 171 | 'main_loss': utils.item(main_loss.data) if reduce else main_loss.data, 172 | 'editability_loss': utils.item(editability_loss.data) if reduce else editability_loss.data, 173 | 'stability_loss': utils.item(stability_loss.data) if reduce else stability_loss.data, 174 | 'nll_loss': utils.item(nll_loss.data) if reduce else nll_loss.data, 175 | 'ntokens': sample['ntokens'], 176 | 'nsentences': sample['target'].size(0), 177 | 'sample_size': sample_size, 178 | 'edit_complexity': edit_complexity 179 | } 180 | return loss, sample_size, logging_output 181 | 182 | @staticmethod 183 | def aggregate_logging_outputs(logging_outputs): 184 | """Aggregate logging outputs from data parallel training.""" 185 | xent_outputs_dict = LabelSmoothedCrossEntropyCriterion.aggregate_logging_outputs(logging_outputs) 186 | 187 | ntokens = sum(log.get('ntokens', 0) for log in logging_outputs) 188 | nsentences = sum(log.get('nsentences', 0) for log in logging_outputs) 189 | sample_size = sum(log.get('sample_size', 0) for log in logging_outputs) 190 | 191 | if 'editability_loss' not in logging_outputs[0]: 192 | return xent_outputs_dict 193 | 194 | xent_outputs_dict['editability_loss'] = sum(log['editability_loss'] for log in logging_outputs) / len( 195 | logging_outputs) 196 | xent_outputs_dict['main_loss'] = sum( 197 | log.get('main_loss', 0) for log in logging_outputs) / sample_size / math.log(2) if sample_size > 0 else 0. 198 | xent_outputs_dict['stability_loss'] = sum(log['stability_loss'] for log in logging_outputs) / len( 199 | logging_outputs) 200 | xent_outputs_dict['edit_complexity'] = sum(log['edit_complexity'] for log in logging_outputs) / len( 201 | logging_outputs) 202 | 203 | return xent_outputs_dict 204 | 205 | 206 | def loss_function(net_output, target, criterion, transformer, **loss_kwargs): 207 | """ Compute editability loss. 
Only apply layers that were not pre-computed in def edit """ 208 | _, nll_losses = criterion.compute_loss(transformer, net_output, dict(target=target), reduce=False) 209 | logps = get_sentence_logp(net_output[0], target, padding_ix=1, mean=True) 210 | loss = torch.relu(torch.max(logps[1:]) - logps[0]) 211 | 212 | return loss # scalar 213 | 214 | 215 | class EditableTransformer(BaseEditable): 216 | def __init__(self, criterion, transformer: TransformerModel, xbost_fist_layer_id, mean_logp=True, **kwargs): 217 | super().__init__() 218 | self.criterion = criterion 219 | self.transformer = transformer 220 | self.mean_logp = mean_logp 221 | self.padding_ix = self.criterion.task.tgt_dict.pad() 222 | self.xbost_fist_layer_id = xbost_fist_layer_id 223 | 224 | self.editable_xbost = Editable( 225 | self.TransformerDecoderXBost(transformer.decoder, first_layer_id=xbost_fist_layer_id), 226 | loss_function=loss_function, 227 | **kwargs 228 | ) 229 | 230 | def edit(self, edit_input, **kwargs): 231 | transformer = self.transformer 232 | assert isinstance(transformer, TransformerModel) 233 | encoder_out = transformer.encoder(edit_input['src_tokens'], src_lengths=edit_input['src_lengths']) 234 | decoder_states_pre_xbost = self.decoder_pre_xbost(edit_input['prev_output_tokens'], encoder_out) 235 | edit_result = self.editable_xbost.edit( 236 | dict(edit_input, encoder_out=encoder_out, decoder_states_pre_xbost=decoder_states_pre_xbost), 237 | targets=edit_input['target'], loss_kwargs=dict(criterion=self.criterion, transformer=self.transformer), 238 | **kwargs) 239 | 240 | edited_xbost, success, loss, complexity = edit_result 241 | 242 | edited_self = EditableTransformer(self.criterion, self.transformer, self.xbost_fist_layer_id, 243 | mean_logp=self.mean_logp) 244 | edited_self.training = self.training 245 | edited_self.editable_xbost = edited_xbost 246 | return Editable.EditResult(edited_self, success, loss, complexity) 247 | 248 | def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs): 249 | encoder_out = self.transformer.encoder(src_tokens, src_lengths=src_lengths, **kwargs) 250 | decoder_states_pre_xbost = self.decoder_pre_xbost(prev_output_tokens, encoder_out, **kwargs) 251 | model_out = self.editable_xbost.module.decoder_post_xbost(encoder_out, decoder_states_pre_xbost) 252 | return model_out 253 | 254 | def recover_transformer(self): 255 | original_xbost = self.TransformerDecoderXBost(self.transformer.decoder, first_layer_id=self.xbost_fist_layer_id) 256 | edited_xbost = self.editable_xbost.module 257 | assert isinstance(original_xbost, self.TransformerDecoderXBost) and isinstance(edited_xbost, self.TransformerDecoderXBost) 258 | 259 | replacement_dict = {} 260 | edited_params = dict(edited_xbost.named_parameters()) 261 | for key, param in original_xbost.named_parameters(): 262 | replacement_dict[param] = edited_params[key] 263 | 264 | return copy_and_replace(self.transformer, replace=replacement_dict) 265 | 266 | def decoder_pre_xbost(self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused): 267 | decoder = self.transformer.decoder 268 | # embed positions 269 | positions = decoder.embed_positions( 270 | prev_output_tokens, 271 | incremental_state=incremental_state, 272 | ) if decoder.embed_positions is not None else None 273 | 274 | if incremental_state is not None: 275 | prev_output_tokens = prev_output_tokens[:, -1:] 276 | if positions is not None: 277 | positions = positions[:, -1:] 278 | 279 | # embed tokens and positions 280 | x = decoder.embed_scale * 
decoder.embed_tokens(prev_output_tokens) 281 | 282 | if decoder.project_in_dim is not None: 283 | x = decoder.project_in_dim(x) 284 | 285 | if positions is not None: 286 | x += positions 287 | x = F.dropout(x, p=decoder.dropout, training=decoder.training) 288 | 289 | # B x T x C -> T x B x C 290 | x = x.transpose(0, 1) 291 | attn = None 292 | 293 | inner_states = [x] 294 | 295 | # decoder layers 296 | for layer in decoder.layers[:self.xbost_fist_layer_id]: 297 | x, attn = layer( 298 | x, 299 | encoder_out['encoder_out'] if encoder_out is not None else None, 300 | encoder_out['encoder_padding_mask'] if encoder_out is not None else None, 301 | incremental_state, 302 | self_attn_mask=decoder.buffered_future_mask(x) if incremental_state is None else None, 303 | ) 304 | inner_states.append(x) 305 | return inner_states 306 | 307 | class TransformerDecoderXBost(TransformerDecoder): 308 | """ Temporary module that applies the second part of transformer decoder """ 309 | 310 | def __init__(self, decoder: TransformerDecoder, *, first_layer_id): 311 | nn.Module.__init__(self) 312 | self.first_layer_id = first_layer_id 313 | self.xbost_layers = decoder.layers[first_layer_id:] 314 | assert isinstance(self.xbost_layers, nn.ModuleList) 315 | self.layer_norm = decoder.layer_norm 316 | self.project_out_dim = decoder.project_out_dim 317 | self.output_layer = decoder.output_layer 318 | 319 | def forward(self, edit_input): 320 | return self.decoder_post_xbost(**edit_input) 321 | 322 | def decoder_post_xbost(self, encoder_out, decoder_states_pre_xbost, **unused): 323 | """ Apply final decoder layers after forward_pre_xbost """ 324 | incremental_state = None 325 | inner_states = list(decoder_states_pre_xbost) 326 | x = decoder_states_pre_xbost[-1] 327 | 328 | attn = None 329 | 330 | # decoder layers: xbost 331 | for layer in self.xbost_layers: 332 | x, attn = layer( 333 | x, 334 | encoder_out['encoder_out'] if encoder_out is not None else None, 335 | encoder_out['encoder_padding_mask'] if encoder_out is not None else None, 336 | incremental_state, 337 | self_attn_mask=self.buffered_future_mask(x) if incremental_state is None else None, 338 | ) 339 | inner_states.append(x) 340 | 341 | if self.layer_norm: 342 | x = self.layer_norm(x) 343 | 344 | # T x B x C -> B x T x C 345 | x = x.transpose(0, 1) 346 | 347 | if self.project_out_dim is not None: 348 | x = self.project_out_dim(x) 349 | 350 | x = self.output_layer(x) 351 | return x, {'attn': attn, 'inner_states': inner_states} 352 | -------------------------------------------------------------------------------- /lib/utils/trainer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generalized & extendable trainer class that handles training and evaluation 3 | """ 4 | import os 5 | import time 6 | import glob 7 | from itertools import count, chain 8 | from warnings import warn 9 | 10 | import numpy as np 11 | import torch 12 | import torch.nn as nn 13 | from torch.utils.data import DataLoader, Dataset 14 | from tqdm import tqdm 15 | 16 | from ..utils import get_latest_file, check_numpy, process_in_chunks, training_mode, \ 17 | infer_model_device, iterate_minibatches, nop_ctx, nop, clear_output 18 | from contextlib import contextmanager 19 | from collections import OrderedDict 20 | from copy import deepcopy 21 | from tensorboardX import SummaryWriter 22 | 23 | 24 | class BaseTrainer(nn.Module): 25 | def __init__(self, model: nn.Module, experiment_name=None, warm_start=False, verbose=False, 26 | num_averaged_checkpoints=1, 
keep_checkpoints=None, **extra_attrs): 27 | """ 28 | Training helper that trains the model to minimize loss in a supervised mode, 29 | computes metrics and does a few other tricks if you ask nicely 30 | :param experiment_name: a path where all logs and checkpoints are saved 31 | :param warm_start: when set to True, loads the last checkpoint 32 | :param verbose: when set to True, produces logging information 33 | :param num_averaged_checkpoints: if > 1, averages this many previous model checkpoints for evaluation 34 | :param keep_checkpoints: how many recent temporary checkpoints to keep, defaults to num_averaged_checkpoints 35 | :param extra_attrs: dict {name: module} to be saved inside trainer via setattr 36 | """ 37 | super().__init__() 38 | self.keep_checkpoints = keep_checkpoints or num_averaged_checkpoints 39 | self.num_averaged_checkpoints = num_averaged_checkpoints 40 | self.verbose = verbose 41 | self.total_steps = 0 42 | self.model = model 43 | self.best_metrics = {} 44 | for module_name, module in extra_attrs.items(): 45 | setattr(self, module_name, module) 46 | 47 | if experiment_name is None: 48 | experiment_name = 'untitled_{}.{:0>2d}.{:0>2d}_{:0>2d}:{:0>2d}:{:0>2d}'.format(*time.gmtime()[:6]) 49 | if self.verbose: 50 | print('using automatic experiment name: ' + experiment_name) 51 | 52 | self.experiment_path = os.path.join('logs/', experiment_name) 53 | if not warm_start and experiment_name != 'debug': 54 | assert not os.path.exists(self.experiment_path), 'experiment {} already exists'.format(experiment_name) 55 | self.writer = SummaryWriter(self.experiment_path, comment=experiment_name) 56 | if warm_start: 57 | self.load_checkpoint() 58 | 59 | def train_on_batch(self, *args, **kwargs): 60 | """ Performs a single gradient update and reports metrics; the step counter is incremented by self.step """ 61 | raise NotImplementedError() 62 | 63 | def evaluate_metrics(self, *args, **kwargs): 64 | """ Predicts and evaluates metrics over the entire dataset """ 65 | raise NotImplementedError() 66 | 67 | def predict(self, *inputs, batch_size=1024, is_train=False, device=None, memory_efficient=False, **kwargs): 68 | """ 69 | Compute model predictions over a (large) number of samples 70 | :param inputs: one or several input arrays, each of shape [batch_size, *whatever] 71 | :param batch_size: predicts for this many samples over one call to the model 72 | :param is_train: if True, runs model in training mode (e.g. 
with dropout) 73 | :param device: moves all inputs to that device, defaults to infer_model_device 74 | :param memory_efficient: if True, data is transferred to device one batch at a time 75 | otherwise (default), transfers all data on device in advance 76 | :param kwargs: key-value arguments passed to every model call 77 | :return: 78 | """ 79 | inputs = tuple(map(torch.as_tensor, inputs)) 80 | device = device or infer_model_device(self.model) 81 | 82 | if memory_efficient: 83 | def predict_on_batch(*batch): 84 | batch = (tensor.to(device=device) for tensor in batch) 85 | return self.model(*batch, **kwargs).cpu() 86 | else: 87 | inputs = (tensor.to(device=device) for tensor in inputs) 88 | predict_on_batch = self.model 89 | 90 | with training_mode(self.model, is_train=is_train), torch.no_grad(): 91 | predictions = process_in_chunks(predict_on_batch, *inputs, batch_size=batch_size) 92 | predictions = check_numpy(predictions) 93 | return predictions 94 | 95 | def record(self, *, prefix='', **metrics): 96 | """ 97 | Computes and saves metrics into tensorboard 98 | :param prefix: common prefix for tensorboard 99 | :param metrics: key-value parameters forwarded into every metric 100 | :return: metrics (same as input) 101 | """ 102 | if not (prefix == '' or prefix.endswith('/')): 103 | warn("It is recommended that prefix ends with slash(/) for readability") 104 | 105 | for key, value in metrics.items(): 106 | assert np.shape(value) == (), "metric {} must be scalar, but got {}".format(key, np.shape(value)) 107 | self.writer.add_scalar(prefix + str(key), value, self.total_steps) 108 | return metrics 109 | 110 | def save_checkpoint(self, tag=None, path=None, mkdir=True, clear_old=False, number_ckpts_to_keep=None, **kwargs): 111 | assert tag is None or path is None, "please provide either tag or path or nothing, not both" 112 | if tag is None and path is None: 113 | tag = "temp_{}".format(self.total_steps) 114 | if path is None: 115 | path = os.path.join(self.experiment_path, "checkpoint_{}.pth".format(tag)) 116 | if mkdir: 117 | os.makedirs(os.path.dirname(path), exist_ok=True) 118 | torch.save(OrderedDict([ 119 | ('model', self.state_dict(**kwargs)), 120 | ('opt', self.opt.state_dict()), 121 | ('step', self.total_steps), 122 | ('best_metrics', self.best_metrics), 123 | ]), path) 124 | if self.verbose: 125 | print("Saved " + path) 126 | if clear_old: 127 | self.remove_old_temp_checkpoints(number_ckpts_to_keep) 128 | return path 129 | 130 | def load_checkpoint(self, tag=None, path=None, **kwargs): 131 | assert tag is None or path is None, "please provide either tag or path or nothing, not both" 132 | if tag is None and path is None: 133 | path = get_latest_file(os.path.join(self.experiment_path, 'checkpoint_temp_[0-9]*.pth')) 134 | elif tag is not None and path is None: 135 | path = os.path.join(self.experiment_path, "checkpoint_{}.pth".format(tag)) 136 | checkpoint = torch.load(path) 137 | 138 | self.load_state_dict(checkpoint['model'], **kwargs) 139 | self.opt.load_state_dict(checkpoint['opt']) 140 | self.total_steps = int(checkpoint['step']) 141 | self.best_metrics = checkpoint['best_metrics'] 142 | 143 | if self.verbose: 144 | print('Loaded ' + path) 145 | return self 146 | 147 | @contextmanager 148 | def using_checkpoint(self, **kwargs): 149 | """ 150 | Backups current checkpoint, loads new one in context, restores current checkpoint upon exiting context 151 | :param kwargs: loads checkpoint with these params (e.g. 
tag or path) 152 | """ 153 | current_checkpoint_tag = 'current' 154 | while True: 155 | current_checkpoint_tag += '_backup' 156 | path = os.path.join(self.experiment_path, "checkpoint_{}.pth".format(current_checkpoint_tag)) 157 | if not os.path.exists(path): 158 | break 159 | 160 | self.save_checkpoint(current_checkpoint_tag) 161 | self.load_checkpoint(**kwargs) 162 | yield 163 | self.load_checkpoint(current_checkpoint_tag) 164 | os.remove(path) 165 | 166 | def average_checkpoints(self, tags=None, paths=None, out_tag='avg', out_path=None): 167 | assert tags is None or paths is None, "please provide either tags or paths or nothing, not both" 168 | assert out_tag is not None or out_path is not None, "please provide either out_tag or out_path or both" 169 | if tags is None and paths is None: 170 | paths = self.get_latest_checkpoints( 171 | os.path.join(self.experiment_path, 'checkpoint_temp_[0-9]*.pth'), self.num_averaged_checkpoints) 172 | elif tags is not None and paths is None: 173 | paths = [os.path.join(self.experiment_path, 'checkpoint_{}.pth'.format(tag)) for tag in tags] 174 | 175 | checkpoints = [torch.load(path) for path in paths] 176 | averaged_ckpt = deepcopy(checkpoints[0]) 177 | for key in averaged_ckpt['model']: 178 | values = [ckpt['model'][key] for ckpt in checkpoints] 179 | averaged_ckpt['model'][key] = sum(values) / len(values) 180 | 181 | if out_path is None: 182 | out_path = os.path.join(self.experiment_path, 'checkpoint_{}.pth'.format(out_tag)) 183 | torch.save(averaged_ckpt, out_path) 184 | 185 | def get_latest_checkpoints(self, pattern, n_last=None): 186 | list_of_files = glob.glob(pattern) 187 | assert len(list_of_files) > 0, "No files found: " + pattern 188 | return sorted(list_of_files, key=os.path.getctime, reverse=True)[:n_last] 189 | 190 | def remove_old_temp_checkpoints(self, number_ckpts_to_keep=None): 191 | if number_ckpts_to_keep is None: 192 | number_ckpts_to_keep = self.keep_checkpoints 193 | paths = self.get_latest_checkpoints(os.path.join(self.experiment_path, 'checkpoint_temp_[0-9]*.pth')) 194 | paths_to_delete = paths[number_ckpts_to_keep:] 195 | 196 | for ckpt in paths_to_delete: 197 | if self.verbose: 198 | print("Removing", ckpt) 199 | os.remove(ckpt) 200 | 201 | def step(self, *args, **kwargs): 202 | """ Trains on batch and updates the counter of total_steps """ 203 | was_steps = self.total_steps 204 | metrics = self.train_on_batch(*args, **kwargs) 205 | assert self.total_steps == was_steps, "total_steps changed within train_on_batch" 206 | self.total_steps += 1 207 | return metrics 208 | 209 | def forward(self, *inputs, **kwargs): 210 | """ see train_on_batch """ 211 | return self.step(*inputs, **kwargs) 212 | 213 | def fit(self, training_data, batch_size=None, shuffle=True, epochs=1, start_epoch=1, batches_per_epoch=None, 214 | batcher_kwargs=None, progressbar=None, clear_outputs=False, device='auto', val_data=None, eval_kwargs=None, 215 | early_stopping_minimize=(), early_stopping_maximize=(), early_stopping_epochs=None, **kwargs): 216 | """ 217 | Trains for one or several epochs on minibatches of data, optionally evaluates dev metrics after each epoch 218 | :param training_data: training data source, must be either of 219 | * torch DataLoader or Dataset 220 | * a list or tuple of tensors 221 | * iterator of minibatches 222 | :param batch_size: splits tensors into chunks of this size over 0-th divension 223 | :param shuffle: (default) shuffles tensors symmetrically over 0-th dimension 224 | :param batcher_kwargs: keyword parameters to be fed into 
data iterator 225 | :param progressbar: if True or callback (e.g. tqdm), prints progress of each training epoch 226 | :param epochs: performs this many passes over training data, float('inf') for infinite loop 227 | :param device: puts minibatches on this device. None to keep original device, 'auto' to infer the model device 228 | :param val_data: if not None, calls self.evaluate_metrics on this data after each epoch 229 | :param eval_kwargs: additional kwargs for self.evaluate_metrics, only used if val_data is not None 230 | :param early_stopping_maximize: keeps checkpoints with highest values of these metrics 231 | :param early_stopping_minimize: keeps checkpoints with lowest values of these metrics 232 | :param early_stopping_epochs: stops training if there were no updates on early_stopping_maximize/minimize 233 | for at least this many epochs 234 | :param start_epoch: initial epoch index, only used for printing (epoch ##) 235 | :param kwargs: additional kwargs for self.step (train_on_batch) 236 | :return: self 237 | """ 238 | device = getattr(self, 'device', infer_model_device(self)) if device == 'auto' else device 239 | progressbar = tqdm if progressbar is True else progressbar or nop 240 | epochs, early_stopping_epochs = epochs or float('inf'), early_stopping_epochs or float('inf') 241 | eval_kwargs, batcher_kwargs = eval_kwargs or dict(), batcher_kwargs or dict() 242 | if isinstance(early_stopping_minimize, str): early_stopping_minimize = [early_stopping_minimize] 243 | if isinstance(early_stopping_maximize, str): early_stopping_maximize = [early_stopping_maximize] 244 | number_of_epochs_without_improvement = 0 245 | 246 | # prepare training data one way or another 247 | if isinstance(training_data, DataLoader): 248 | make_training_epoch = lambda: iter(progressbar(training_data)) 249 | elif isinstance(training_data, Dataset): 250 | make_training_epoch = lambda: iter(progressbar(torch.utils.data.DataLoader( 251 | training_data, batch_size=batch_size, shuffle=shuffle, **batcher_kwargs)))  # wrapped in a lambda so it can be called once per epoch 252 | elif isinstance(training_data, (list, tuple)): 253 | make_training_epoch = lambda: iterate_minibatches( 254 | *training_data, batch_size=batch_size, epochs=1, shuffle=shuffle, 255 | callback=progressbar, **batcher_kwargs) 256 | else: 257 | training_data = iter(training_data) 258 | assert batches_per_epoch is not None or epochs == 1, "if data is an iterator, please provide " \ 259 | "batches_per_epoch or use a single epoch" 260 | def make_training_epoch(): 261 | for _ in progressbar(range(batches_per_epoch) if batches_per_epoch else count()): 262 | yield next(training_data) 263 | 264 | # iterate training epochs 265 | for epoch_i in count(start=start_epoch): 266 | if epoch_i >= epochs + start_epoch: 267 | if self.verbose: 268 | print("Stopping because of reaching target number of epochs") 269 | break 270 | if self.verbose: 271 | print("Epoch #{}/{}".format(epoch_i, epochs)) 272 | 273 | for batch in make_training_epoch(): 274 | if device is not None: 275 | batch = tuple(torch.as_tensor(tensor, device=device) for tensor in batch) 276 | self.step(*batch, **kwargs) 277 | 278 | if clear_outputs: 279 | clear_output() 280 | 281 | self.save_checkpoint(clear_old=True) 282 | if self.num_averaged_checkpoints > 1: 283 | self.average_checkpoints(out_tag='avg') 284 | 285 | if val_data is not None: 286 | if self.verbose: 287 | print("Evaluating...") 288 | 289 | with self.using_checkpoint(tag='avg') if self.num_averaged_checkpoints > 1 else nop_ctx(): 290 | val_metrics = self.evaluate_metrics(*val_data, **eval_kwargs) 291 | 292 | if 
self.verbose: 293 | for key, value in val_metrics.items(): 294 | print(key, value) 295 | print() 296 | 297 | # handle best metrics and early stopping 298 | number_of_epochs_without_improvement += 1 299 | 300 | for key, value in val_metrics.items(): 301 | found_new_best = False 302 | if key in early_stopping_maximize: 303 | if value > self.best_metrics.get(key, -float('inf')): 304 | found_new_best = True 305 | if key in early_stopping_minimize: 306 | if value < self.best_metrics.get(key, float('inf')): 307 | found_new_best = True 308 | if found_new_best: 309 | self.best_metrics[key] = value 310 | number_of_epochs_without_improvement = 0 311 | self.save_checkpoint(tag='best_' + key) 312 | 313 | for key in chain(early_stopping_maximize, early_stopping_minimize): 314 | if key not in val_metrics: 315 | warn("Metric name {} not found but requested for maximizing/minimizing") 316 | 317 | if number_of_epochs_without_improvement >= early_stopping_epochs: 318 | if self.verbose: 319 | print("Early stopping because of no improvement in " 320 | "{} epochs".format(number_of_epochs_without_improvement)) 321 | break 322 | 323 | else: 324 | assert eval_kwargs is None, "Eval kwargs is unused if val_data is None" 325 | assert early_stopping_epochs == float('inf'), "Early stopping requires val_data" 326 | assert len(early_stopping_minimize) == len(early_stopping_maximize) == 0, "Please provide val_data" 327 | 328 | return self 329 | 330 | 331 | class SupervisedTrainer(BaseTrainer): 332 | def __init__(self, model: nn.Module, loss_function, opt=None, **kwargs): 333 | """ A simple optimizer that trains to minimize classification or regression loss """ 334 | opt = opt if opt is not None else torch.optim.Adam(model.parameters()) 335 | super().__init__(model, loss_function=loss_function, opt=opt, **kwargs) 336 | 337 | def train_on_batch(self, x_batch, y_batch, prefix='train/', is_train=True): 338 | """ Performs a single gradient update and reports metrics """ 339 | x_batch, y_batch = map(torch.as_tensor, (x_batch, y_batch)) 340 | self.opt.zero_grad() 341 | 342 | with training_mode(self.model, is_train=is_train): 343 | prediction = self.model(x_batch) 344 | 345 | loss = self.loss_function(prediction, y_batch).mean() 346 | loss.backward() 347 | self.opt.step() 348 | 349 | return self.record(loss=loss.item(), prefix=prefix) 350 | 351 | def evaluate_metrics(self, X, y, prefix='val/', **kwargs): 352 | """ Predicts and evaluates metrics over the entire dataset """ 353 | prediction = self.predict(X, **kwargs) 354 | with torch.no_grad(): 355 | loss = self.loss_function(torch.as_tensor(prediction), torch.as_tensor(y)).mean() 356 | 357 | return self.record(loss=loss.item(), prefix=prefix) 358 | -------------------------------------------------------------------------------- /notebooks/imagenet_editable_training_with_natural_distribution.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "imagenet_nashlepka_layer4_editable_SGD_momentum_match_rank_2019.09.24_17:00:13\n" 13 | ] 14 | } 15 | ], 16 | "source": [ 17 | "%load_ext autoreload\n", 18 | "%autoreload 2\n", 19 | "%env CUDA_VISIBLE_DEVICES=SETYOURDEVICEHERE\n", 20 | "import os, sys, time\n", 21 | "sys.path.insert(0, '..')\n", 22 | "import lib\n", 23 | "\n", 24 | "import numpy as np\n", 25 | "import torch, torch.nn as nn\n", 26 | "import 
torch.nn.functional as F\n", 27 | "import matplotlib.pyplot as plt\n", 28 | "%matplotlib inline\n", 29 | "\n", 30 | "import random\n", 31 | "random.seed(42)\n", 32 | "np.random.seed(42)\n", 33 | "torch.random.manual_seed(42)\n", 34 | "\n", 35 | "import time\n", 36 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", 37 | "\n", 38 | "experiment_name = 'imagenet_nashlepka_layer4_editable_SGD_momentum_match_rank'\n", 39 | "experiment_name = '{}_{}.{:0>2d}.{:0>2d}_{:0>2d}:{:0>2d}:{:0>2d}'.format(experiment_name, *time.gmtime()[:6])\n", 40 | "print(experiment_name)" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 3, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "import torchvision.transforms as transforms\n", 50 | "import torchvision.datasets as datasets\n", 51 | "import torchvision.models as models\n", 52 | "\n", 53 | "data_path = '../../imagenet/'\n", 54 | "logits_path = 'imagenet_logits/'\n", 55 | "\n", 56 | "traindir = os.path.join(data_path, 'train')\n", 57 | "valdir = os.path.join(data_path, 'val')\n", 58 | "normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],\n", 59 | " std=[0.229, 0.224, 0.225])\n", 60 | "\n", 61 | "train_dataset = lib.ImageAndLogitsFolder(\n", 62 | " traindir,\n", 63 | " transforms.Compose([\n", 64 | " transforms.RandomResizedCrop(224),\n", 65 | " transforms.RandomHorizontalFlip(),\n", 66 | " transforms.ToTensor(),\n", 67 | " normalize,\n", 68 | " ]),\n", 69 | " logits_prefix = logits_path\n", 70 | ")\n", 71 | "\n", 72 | "batch_size = 128\n", 73 | "\n", 74 | "train_loader = torch.utils.data.DataLoader(\n", 75 | " train_dataset, batch_size=batch_size, shuffle=True,\n", 76 | " num_workers=12, pin_memory=True)\n", 77 | "\n", 78 | "val_loader = torch.utils.data.DataLoader(\n", 79 | " datasets.ImageFolder(valdir, transforms.Compose([\n", 80 | " transforms.Resize(256),\n", 81 | " transforms.CenterCrop(224),\n", 82 | " transforms.ToTensor(),\n", 83 | " normalize,\n", 84 | " ])),\n", 85 | " batch_size=batch_size, shuffle=False,\n", 86 | " num_workers=32, pin_memory=True)\n", 87 | "\n", 88 | "X_test, y_test = map(torch.cat, zip(*val_loader))\n", 89 | "X_test, y_test = X_test[::10], y_test[::10] \n", 90 | "# Note: we use 10% of data for early stopping\n", 91 | "# We evaluate on all data later" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": 4, 97 | "metadata": {}, 98 | "outputs": [], 99 | "source": [ 100 | "import torchvision\n", 101 | "\n", 102 | "model = torchvision.models.resnet18(pretrained=True)\n", 103 | "\n", 104 | "optimizer = lib.IngraphRMSProp(learning_rate=1e-4, beta=nn.Parameter(torch.as_tensor(0.5)))\n", 105 | "\n", 106 | "model = lib.SequentialWithEditable(\n", 107 | " model.conv1, model.bn1, model.relu, model.maxpool,\n", 108 | " model.layer1, model.layer2, model.layer3, model.layer4,\n", 109 | " model.avgpool, lib.Flatten(),\n", 110 | " lib.Editable(\n", 111 | " lib.Residual(nn.Linear(512, 4096), nn.ELU(), nn.Linear(4096, 512)),\n", 112 | " loss_function=lib.contrastive_cross_entropy, \n", 113 | " optimizer=optimizer, max_steps=10),\n", 114 | "\n", 115 | " model.fc\n", 116 | ").to(device)" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": 5, 122 | "metadata": {}, 123 | "outputs": [], 124 | "source": [ 125 | "class EditableMatchRankTrainer(lib.DistillationEditableTrainer):\n", 126 | " def __init__(self, distribution_cumsum, **kwargs):\n", 127 | " super().__init__(**kwargs)\n", 128 | " self.cumsum = distribution_cumsum\n", 129 | " \n", 130 | " def 
train_on_batch(self, x_batch, y_batch, x_edit, y_edit, **kwargs):\n", 131 | " with torch.no_grad(), lib.training_mode(self.model, is_train=False):\n", 132 | " logits = self.model(x_edit)\n", 133 | " sorted_ans = logits.topk(k=logits.shape[1], dim=1).indices\n", 134 | " choices = (torch.rand(sorted_ans.shape[0]).view(-1, 1) > self.cumsum).to(torch.int32).sum(-1)\n", 135 | " choices = choices.to(sorted_ans.device)\n", 136 | " y_edit = sorted_ans.gather(dim=1, index=choices.view(-1, 1)).view(-1)\n", 137 | " super().train_on_batch(x_batch, y_batch, x_edit, y_edit, **kwargs)\n", 138 | " \n", 139 | " def evaluate_metrics(self, X, y, X_edit=None, y_edit=None, size_top=25, **kwargs):\n", 140 | " \"\"\"\n", 141 | " For each sample in X_edit, y_edit attempts to train model and evaluates trained model quality\n", 142 | " :param X: data for quality evaluaton\n", 143 | " :param y: targets for quality evaluaton\n", 144 | " :param X_edit: sequence of data for training model on\n", 145 | " :param y_edit: sequence of targets for training model on\n", 146 | " :param kwargs: extra parameters for error function\n", 147 | " :return: dictionary of metrics\n", 148 | " \"\"\"\n", 149 | " assert (X_edit is None) == (y_edit is None), \"provide either both X_edit and y_edit or none of them\"\n", 150 | " if X_edit is None:\n", 151 | " num_classes = y.max() + 1\n", 152 | " ind = np.random.permutation(len(X))[:10]\n", 153 | " X_edit = X[ind]\n", 154 | " with torch.no_grad(), lib.training_mode(self.model, is_train=False):\n", 155 | " logits = self.model(X_edit)\n", 156 | " sorted_ans = logits.topk(k=logits.shape[1], dim=1).indices\n", 157 | " choices = (torch.rand(sorted_ans.shape[0]).view(-1, 1) > self.cumsum).to(torch.int32).sum(-1)\n", 158 | " choices = choices.to(sorted_ans.device)\n", 159 | " y_edit = sorted_ans.gather(dim=1, index=choices.view(-1, 1)).view(-1)\n", 160 | " return super().evaluate_metrics(X, y, X_edit, y_edit, **kwargs)" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": 6, 166 | "metadata": {}, 167 | "outputs": [], 168 | "source": [ 169 | "def classification_error(model, X_test, y_test):\n", 170 | " with lib.training_mode(model, is_train=False):\n", 171 | " return lib.classification_error(lib.Lambda(lambda x: model(x.to(device))),\n", 172 | " X_test, y_test, device='cpu', batch_size=128)" 173 | ] 174 | }, 175 | { 176 | "cell_type": "markdown", 177 | "metadata": {}, 178 | "source": [ 179 | "## Read natural adversarial examples" 180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": null, 185 | "metadata": {}, 186 | "outputs": [], 187 | "source": [ 188 | "mean = [0.485, 0.456, 0.406]\n", 189 | "std = [0.229, 0.224, 0.225]\n", 190 | "\n", 191 | "test_transform = transforms.Compose(\n", 192 | " [transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean, std)])\n", 193 | "\n", 194 | "thousand_k_to_200 = {0: -1, 1: -1, 2: -1, 3: -1, 4: -1, 5: -1, 6: 0, 7: -1, 8: -1, 9: -1, 10: -1, 11: 1, 12: -1, 13: 2, 14: -1, 15: 3, 16: -1, 17: 4, 18: -1, 19: -1, 20: -1, 21: -1, 22: 5, 23: 6, 24: -1, 25: -1, 26: -1, 27: 7, 28: -1, 29: -1, 30: 8, 31: -1, 32: -1, 33: -1, 34: -1, 35: -1, 36: -1, 37: 9, 38: -1, 39: 10, 40: -1, 41: -1, 42: 11, 43: -1, 44: -1, 45: -1, 46: -1, 47: 12, 48: -1, 49: -1, 50: 13, 51: -1, 52: -1, 53: -1, 54: -1, 55: -1, 56: -1, 57: 14, 58: -1, 59: -1, 60: -1, 61: -1, 62: -1, 63: -1, 64: -1, 65: -1, 66: -1, 67: -1, 68: -1, 69: -1, 70: 15, 71: 16, 72: -1, 73: -1, 74: -1, 75: -1, 76: 17, 77: -1, 78: -1, 79: 18, 
80: -1, 81: -1, 82: -1, 83: -1, 84: -1, 85: -1, 86: -1, 87: -1, 88: -1, 89: 19, 90: 20, 91: -1, 92: -1, 93: -1, 94: 21, 95: -1, 96: 22, 97: 23, 98: -1, 99: 24, 100: -1, 101: -1, 102: -1, 103: -1, 104: -1, 105: 25, 106: -1, 107: 26, 108: 27, 109: -1, 110: 28, 111: -1, 112: -1, 113: 29, 114: -1, 115: -1, 116: -1, 117: -1, 118: -1, 119: -1, 120: -1, 121: -1, 122: -1, 123: -1, 124: 30, 125: 31, 126: -1, 127: -1, 128: -1, 129: -1, 130: 32, 131: -1, 132: 33, 133: -1, 134: -1, 135: -1, 136: -1, 137: -1, 138: -1, 139: -1, 140: -1, 141: -1, 142: -1, 143: 34, 144: 35, 145: -1, 146: -1, 147: -1, 148: -1, 149: -1, 150: 36, 151: 37, 152: -1, 153: -1, 154: -1, 155: -1, 156: -1, 157: -1, 158: -1, 159: -1, 160: -1, 161: -1, 162: -1, 163: -1, 164: -1, 165: -1, 166: -1, 167: -1, 168: -1, 169: -1, 170: -1, 171: -1, 172: -1, 173: -1, 174: -1, 175: -1, 176: -1, 177: -1, 178: -1, 179: -1, 180: -1, 181: -1, 182: -1, 183: -1, 184: -1, 185: -1, 186: -1, 187: -1, 188: -1, 189: -1, 190: -1, 191: -1, 192: -1, 193: -1, 194: -1, 195: -1, 196: -1, 197: -1, 198: -1, 199: -1, 200: -1, 201: -1, 202: -1, 203: -1, 204: -1, 205: -1, 206: -1, 207: 38, 208: -1, 209: -1, 210: -1, 211: -1, 212: -1, 213: -1, 214: -1, 215: -1, 216: -1, 217: -1, 218: -1, 219: -1, 220: -1, 221: -1, 222: -1, 223: -1, 224: -1, 225: -1, 226: -1, 227: -1, 228: -1, 229: -1, 230: -1, 231: -1, 232: -1, 233: -1, 234: 39, 235: 40, 236: -1, 237: -1, 238: -1, 239: -1, 240: -1, 241: -1, 242: -1, 243: -1, 244: -1, 245: -1, 246: -1, 247: -1, 248: -1, 249: -1, 250: -1, 251: -1, 252: -1, 253: -1, 254: 41, 255: -1, 256: -1, 257: -1, 258: -1, 259: -1, 260: -1, 261: -1, 262: -1, 263: -1, 264: -1, 265: -1, 266: -1, 267: -1, 268: -1, 269: -1, 270: -1, 271: -1, 272: -1, 273: -1, 274: -1, 275: -1, 276: -1, 277: 42, 278: -1, 279: -1, 280: -1, 281: -1, 282: -1, 283: 43, 284: -1, 285: -1, 286: -1, 287: 44, 288: -1, 289: -1, 290: -1, 291: 45, 292: -1, 293: -1, 294: -1, 295: 46, 296: -1, 297: -1, 298: 47, 299: -1, 300: -1, 301: 48, 302: -1, 303: -1, 304: -1, 305: -1, 306: 49, 307: 50, 308: 51, 309: 52, 310: 53, 311: 54, 312: -1, 313: 55, 314: 56, 315: 57, 316: -1, 317: 58, 318: -1, 319: 59, 320: -1, 321: -1, 322: -1, 323: 60, 324: 61, 325: -1, 326: 62, 327: 63, 328: -1, 329: -1, 330: 64, 331: -1, 332: -1, 333: -1, 334: 65, 335: 66, 336: 67, 337: -1, 338: -1, 339: -1, 340: -1, 341: -1, 342: -1, 343: -1, 344: -1, 345: -1, 346: -1, 347: 68, 348: -1, 349: -1, 350: -1, 351: -1, 352: -1, 353: -1, 354: -1, 355: -1, 356: -1, 357: -1, 358: -1, 359: -1, 360: -1, 361: 69, 362: -1, 363: 70, 364: -1, 365: -1, 366: -1, 367: -1, 368: -1, 369: -1, 370: -1, 371: -1, 372: 71, 373: -1, 374: -1, 375: -1, 376: -1, 377: -1, 378: 72, 379: -1, 380: -1, 381: -1, 382: -1, 383: -1, 384: -1, 385: -1, 386: 73, 387: -1, 388: -1, 389: -1, 390: -1, 391: -1, 392: -1, 393: -1, 394: -1, 395: -1, 396: -1, 397: 74, 398: -1, 399: -1, 400: 75, 401: 76, 402: 77, 403: -1, 404: 78, 405: -1, 406: -1, 407: 79, 408: -1, 409: -1, 410: -1, 411: 80, 412: -1, 413: -1, 414: -1, 415: -1, 416: 81, 417: 82, 418: -1, 419: -1, 420: 83, 421: -1, 422: -1, 423: -1, 424: -1, 425: 84, 426: -1, 427: -1, 428: 85, 429: -1, 430: 86, 431: -1, 432: -1, 433: -1, 434: -1, 435: -1, 436: -1, 437: 87, 438: 88, 439: -1, 440: -1, 441: -1, 442: -1, 443: -1, 444: -1, 445: 89, 446: -1, 447: -1, 448: -1, 449: -1, 450: -1, 451: -1, 452: -1, 453: -1, 454: -1, 455: -1, 456: 90, 457: 91, 458: -1, 459: -1, 460: -1, 461: 92, 462: 93, 463: -1, 464: -1, 465: -1, 466: -1, 467: -1, 468: -1, 469: -1, 470: 94, 471: -1, 472: 95, 473: -1, 474: -1, 475: -1, 476: -1, 
477: -1, 478: -1, 479: -1, 480: -1, 481: -1, 482: -1, 483: 96, 484: -1, 485: -1, 486: 97, 487: -1, 488: 98, 489: -1, 490: -1, 491: -1, 492: 99, 493: -1, 494: -1, 495: -1, 496: 100, 497: -1, 498: -1, 499: -1, 500: -1, 501: -1, 502: -1, 503: -1, 504: -1, 505: -1, 506: -1, 507: -1, 508: -1, 509: -1, 510: -1, 511: -1, 512: -1, 513: -1, 514: 101, 515: -1, 516: 102, 517: -1, 518: -1, 519: -1, 520: -1, 521: -1, 522: -1, 523: -1, 524: -1, 525: -1, 526: -1, 527: -1, 528: 103, 529: -1, 530: 104, 531: -1, 532: -1, 533: -1, 534: -1, 535: -1, 536: -1, 537: -1, 538: -1, 539: 105, 540: -1, 541: -1, 542: 106, 543: 107, 544: -1, 545: -1, 546: -1, 547: -1, 548: -1, 549: 108, 550: -1, 551: -1, 552: 109, 553: -1, 554: -1, 555: -1, 556: -1, 557: 110, 558: -1, 559: -1, 560: -1, 561: 111, 562: 112, 563: -1, 564: -1, 565: -1, 566: -1, 567: -1, 568: -1, 569: 113, 570: -1, 571: -1, 572: 114, 573: 115, 574: -1, 575: 116, 576: -1, 577: -1, 578: -1, 579: 117, 580: -1, 581: -1, 582: -1, 583: -1, 584: -1, 585: -1, 586: -1, 587: -1, 588: -1, 589: 118, 590: -1, 591: -1, 592: -1, 593: -1, 594: -1, 595: -1, 596: -1, 597: -1, 598: -1, 599: -1, 600: -1, 601: -1, 602: -1, 603: -1, 604: -1, 605: -1, 606: 119, 607: 120, 608: -1, 609: 121, 610: -1, 611: -1, 612: -1, 613: -1, 614: 122, 615: -1, 616: -1, 617: -1, 618: -1, 619: -1, 620: -1, 621: -1, 622: -1, 623: -1, 624: -1, 625: -1, 626: 123, 627: 124, 628: -1, 629: -1, 630: -1, 631: -1, 632: -1, 633: -1, 634: -1, 635: -1, 636: -1, 637: -1, 638: -1, 639: -1, 640: 125, 641: 126, 642: 127, 643: 128, 644: -1, 645: -1, 646: -1, 647: -1, 648: -1, 649: -1, 650: -1, 651: -1, 652: -1, 653: -1, 654: -1, 655: -1, 656: -1, 657: -1, 658: 129, 659: -1, 660: -1, 661: -1, 662: -1, 663: -1, 664: -1, 665: -1, 666: -1, 667: -1, 668: 130, 669: -1, 670: -1, 671: -1, 672: -1, 673: -1, 674: -1, 675: -1, 676: -1, 677: 131, 678: -1, 679: -1, 680: -1, 681: -1, 682: 132, 683: -1, 684: 133, 685: -1, 686: -1, 687: 134, 688: -1, 689: -1, 690: -1, 691: -1, 692: -1, 693: -1, 694: -1, 695: -1, 696: -1, 697: -1, 698: -1, 699: -1, 700: -1, 701: 135, 702: -1, 703: -1, 704: 136, 705: -1, 706: -1, 707: -1, 708: -1, 709: -1, 710: -1, 711: -1, 712: -1, 713: -1, 714: -1, 715: -1, 716: -1, 717: -1, 718: -1, 719: 137, 720: -1, 721: -1, 722: -1, 723: -1, 724: -1, 725: -1, 726: -1, 727: -1, 728: -1, 729: -1, 730: -1, 731: -1, 732: -1, 733: -1, 734: -1, 735: -1, 736: 138, 737: -1, 738: -1, 739: -1, 740: -1, 741: -1, 742: -1, 743: -1, 744: -1, 745: -1, 746: 139, 747: -1, 748: -1, 749: 140, 750: -1, 751: -1, 752: 141, 753: -1, 754: -1, 755: -1, 756: -1, 757: -1, 758: 142, 759: -1, 760: -1, 761: -1, 762: -1, 763: 143, 764: -1, 765: 144, 766: -1, 767: -1, 768: 145, 769: -1, 770: -1, 771: -1, 772: -1, 773: 146, 774: 147, 775: -1, 776: 148, 777: -1, 778: -1, 779: 149, 780: 150, 781: -1, 782: -1, 783: -1, 784: -1, 785: -1, 786: 151, 787: -1, 788: -1, 789: -1, 790: -1, 791: -1, 792: 152, 793: -1, 794: -1, 795: -1, 796: -1, 797: 153, 798: -1, 799: -1, 800: -1, 801: -1, 802: 154, 803: 155, 804: 156, 805: -1, 806: -1, 807: -1, 808: -1, 809: -1, 810: -1, 811: -1, 812: -1, 813: 157, 814: -1, 815: 158, 816: -1, 817: -1, 818: -1, 819: -1, 820: 159, 821: -1, 822: -1, 823: 160, 824: -1, 825: -1, 826: -1, 827: -1, 828: -1, 829: -1, 830: -1, 831: 161, 832: -1, 833: 162, 834: -1, 835: 163, 836: -1, 837: -1, 838: -1, 839: 164, 840: -1, 841: -1, 842: -1, 843: -1, 844: -1, 845: 165, 846: -1, 847: 166, 848: -1, 849: -1, 850: 167, 851: -1, 852: -1, 853: -1, 854: -1, 855: -1, 856: -1, 857: -1, 858: -1, 859: 168, 860: -1, 861: -1, 862: 169, 863: -1, 
864: -1, 865: -1, 866: -1, 867: -1, 868: -1, 869: -1, 870: 170, 871: -1, 872: -1, 873: -1, 874: -1, 875: -1, 876: -1, 877: -1, 878: -1, 879: 171, 880: 172, 881: -1, 882: -1, 883: -1, 884: -1, 885: -1, 886: -1, 887: -1, 888: 173, 889: -1, 890: 174, 891: -1, 892: -1, 893: -1, 894: -1, 895: -1, 896: -1, 897: 175, 898: -1, 899: -1, 900: 176, 901: -1, 902: -1, 903: -1, 904: -1, 905: -1, 906: -1, 907: 177, 908: -1, 909: -1, 910: -1, 911: -1, 912: -1, 913: 178, 914: -1, 915: -1, 916: -1, 917: -1, 918: -1, 919: -1, 920: -1, 921: -1, 922: -1, 923: -1, 924: 179, 925: -1, 926: -1, 927: -1, 928: -1, 929: -1, 930: -1, 931: -1, 932: 180, 933: 181, 934: 182, 935: -1, 936: -1, 937: 183, 938: -1, 939: -1, 940: -1, 941: -1, 942: -1, 943: 184, 944: -1, 945: 185, 946: -1, 947: 186, 948: -1, 949: -1, 950: -1, 951: 187, 952: -1, 953: -1, 954: 188, 955: -1, 956: 189, 957: 190, 958: -1, 959: 191, 960: -1, 961: -1, 962: -1, 963: -1, 964: -1, 965: -1, 966: -1, 967: -1, 968: -1, 969: -1, 970: -1, 971: 192, 972: 193, 973: -1, 974: -1, 975: -1, 976: -1, 977: -1, 978: -1, 979: -1, 980: 194, 981: 195, 982: -1, 983: -1, 984: 196, 985: -1, 986: 197, 987: 198, 988: 199, 989: -1, 990: -1, 991: -1, 992: -1, 993: -1, 994: -1, 995: -1, 996: -1, 997: -1, 998: -1, 999: -1}\n", 195 | "two_hundred_to_1000 = dict(map(reversed, thousand_k_to_200.items()))\n", 196 | "\n", 197 | "X_adv, y_adv = zip(*datasets.ImageFolder(root=\"./imagenet-a/\", transform=test_transform))\n", 198 | "X_adv = torch.stack(X_adv)\n", 199 | "y_adv = torch.tensor(list(map(two_hundred_to_1000.get, y_adv)))" 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": 27, 205 | "metadata": {}, 206 | "outputs": [ 207 | { 208 | "data": { 209 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYQAAAD8CAYAAAB3u9PLAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAIABJREFUeJzt3Xl4VOXZ+PHvPTPJZE8IhH0JSJBNQUVEXOsKWsW28hbbWhdau2i1br8Xq7WtrVatLdq+amvrvhRcqFJFUURc2cImayDsIRAChOx7nt8f58xkZjJJhpD93J/r4uLMOc8585wZOPc8uxhjUEoppVwdnQGllFKdgwYEpZRSgAYEpZRSNg0ISimlAA0ISimlbBoQlFJKARoQlFJK2TQgKKWUAjQgKKWUsnk6OgPHolevXiY9Pb2js6GUUl3GqlWrDhlj0iJJ26UCQnp6OpmZmR2dDaWU6jJEZHekabXKSCmlFKABQSmllE0DglJKKUADglJKKZsGBKWUUoAGBKWUUjYNCEoppQCHBYSVu46QdaC4o7OhlFKdUpcamHa8pv99KQC7Hr68g3OilFKdj6NKCEoppRqnAUEppRSgAUEppZRNA4JSSikgwoAgIlNEJEtEskVkVpjjXhGZax9fLiLp9v6eIvKJiJSIyP+FnHOaiKy3z/mriEhr3JBSSqmWaTYgiIgbeBKYCowGrhGR0SHJZgIFxpjhwGzgEXt/BfBr4K4wl34auAnIsP9MackNRMoY05aXV0qpLi+SEsJEINsYs8MYUwXMAaaFpJkGvGhvvwlcKCJijCk1xnyBFRj8RKQfkGSMWWqsJ/VLwFXHcyPNqanTgKCUUk2JJCAMAPYGvM6x94VNY4ypAQqBns1cM6eZa7aq6to6/7aWFpRSqqFIAkK4uv3QJ2okaVqUXkRuEpFMEcnMz89v4pJN+/unO/zb1bUaEJRSKlQkASEHGBTweiCQ21gaEfEAycCRZq45sJlrAmCMecYYM8EYMyEtLaJlQRsoLKvmrx9v878ur65t0XWUUqo7iyQgrAQyRGSoiEQDM4D5IWnmA9fZ21cDi00T9TLGmP1AsYhMsnsX/RB455hzHyETUvio1ICglFINNDuXkTGmRkRuARYCbuA5Y8xGEXkAyDTGzAeeBV4WkWysksEM3/kisgtIAqJF5CrgEmPMJuBnwAtALPC+/adNhDYoawlBKaUaimhyO2PMAmBByL77A7YrgOmNnJveyP5MYGykGT0etSEBoaK6rpGUSinlXI4YqawlBKWUap4zAkJtcImgvEoDglJKhXJGQAitMqrRgKCUUqEcERBC2xBqdByCUko14IiA4AsAo/olAVBbp43KSikVyhkBwQ4A08b3B3SkslJKheOQgGAFgBiPdbuhVUhKKaUcEhB8AcAb5QaCJ7pTSillcURA8AWAmCjrdnUqbKWUasgRAaHWX2VklRA0ICilVEOOCAj+NgS7yih0oJpSSimHBITaWl8bgjYqK6VUYxwREHzdTmP8jcoaEJRSKpRDAkJIG4JWGSmlVAOOCAi+KqJoj/YyUkqpxjgiINTZi7e5XYLbJf4qJKWUUvUcERB8i3m6BDwu0RKCUkqF4YiAUOcPCGIFBG1UVkqpBhwSEOoDgMft0kZlpZQKwxEBAV8JwSV4PS5dQlMppcJwREDwlRBcAn2TY9hfWNHBOVJKqc7HIQHB+tslwoCUWPYdLe/YDCmlVCfkkIBgRQQBUuKiKK6o6dgMKaVUJ+SIgOBrUhYRPC5tVFZKqXCcERAC2hA8bh2HoJRS4TgiINTZAUACxiF8tC
[... base64-encoded PNG output elided: matplotlib plot of soft_counts (the smoothed label-rank distribution) ...]\n",
210 | "text/plain": [
211 | ""
212 | ]
213 | },
214 | "metadata": {},
215 | "output_type": "display_data"
216 | }
217 | ],
218 | "source": [
219 | "from scipy.signal import correlate\n",
220 | "def calculate_natural_distribution(model, X_adv, y_adv, batch_size=256, kernel=np.array([0.1, 0.2, 0.4, 0.2, 0.1])):\n",
221 | " with torch.no_grad(), lib.training_mode(model, is_train=False):\n",
222 | " logits = lib.process_in_chunks(lambda X_batch:model(X_batch.to(device)), X_adv, batch_size=batch_size)\n",
223 | " sorted_ans = logits.topk(k=logits.shape[1], dim=1).indices.to('cpu')\n",
224 | " y_rank = (sorted_ans == y_adv.view(-1,1)).argmax(dim=1)\n",
225 | " bin_counts = np.bincount(lib.check_numpy(y_rank), minlength=logits.shape[-1]).astype('float32')\n",
226 | " soft_counts = correlate(bin_counts, kernel)[(len(kernel) - 1) // 2: (len(kernel) - 1) // 2 + len(bin_counts)]\n",
227 | " soft_counts[0] = 0\n",
228 | " assert len(soft_counts) == len(bin_counts)\n",
229 | " return soft_counts / soft_counts.sum()\n",
230 | "\n",
231 | "soft_counts = 
calculate_natural_distribution(models.resnet18(pretrained=True).to(device), X_adv, y_adv)\n", 232 | "\n", 233 | "plt.plot(soft_counts)\n", 234 | "\n", 235 | "cumsum = soft_counts.cumsum() / soft_counts.sum()\n", 236 | "cumsum = torch.as_tensor(cumsum, dtype=torch.float32)" 237 | ] 238 | }, 239 | { 240 | "cell_type": "markdown", 241 | "metadata": {}, 242 | "source": [ 243 | "### Train" 244 | ] 245 | }, 246 | { 247 | "cell_type": "code", 248 | "execution_count": 8, 249 | "metadata": {}, 250 | "outputs": [], 251 | "source": [ 252 | "new_params = set(model.editable.module[0].parameters())\n", 253 | "old_params = [param for param in model.parameters() if param not in new_params]\n", 254 | "\n", 255 | "training_opt = lib.OptimizerList(\n", 256 | " torch.optim.SGD(old_params, lr=1e-5, momentum=0.9, weight_decay=1e-4),\n", 257 | " torch.optim.SGD(new_params, lr=1e-3, momentum=0.9, weight_decay=1e-4),\n", 258 | ")\n", 259 | "\n", 260 | "trainer = EditableMatchRankTrainer(cumsum, model=model,\n", 261 | " stability_coeff=0.03, editability_coeff=0.03,\n", 262 | " experiment_name=experiment_name,\n", 263 | " error_function=classification_error,\n", 264 | " opt=training_opt, max_norm=10)\n", 265 | "\n", 266 | "trainer.writer.add_text(\"trainer\", repr(trainer).replace('\\n', '
'))" 267 | ] 268 | }, 269 | { 270 | "cell_type": "code", 271 | "execution_count": 29, 272 | "metadata": {}, 273 | "outputs": [], 274 | "source": [ 275 | "from tqdm import tqdm_notebook, tnrange\n", 276 | "from IPython.display import clear_output\n", 277 | "\n", 278 | "# Learnign params\n", 279 | "eval_batch_cd = 500\n", 280 | "val_metrics = trainer.evaluate_metrics(X_test.to(device), y_test.to(device))\n", 281 | "min_error, min_drawdown = val_metrics['base_error'], val_metrics['drawdown']\n", 282 | "early_stopping_epochs = 500\n", 283 | "number_of_epochs_without_improvement = 0\n", 284 | " \n", 285 | "def edit_generator():\n", 286 | " while True:\n", 287 | " for xb, yb, lg in torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=2):\n", 288 | " yield xb.to(device), torch.randint_like(yb, low=0, high=max(y_test) + 1, device=device)\n", 289 | "\n", 290 | "edit_generator = edit_generator()" 291 | ] 292 | }, 293 | { 294 | "cell_type": "code", 295 | "execution_count": null, 296 | "metadata": {}, 297 | "outputs": [], 298 | "source": [ 299 | "while True:\n", 300 | " \n", 301 | " for x_batch, y_batch, logits in tqdm_notebook(train_loader):\n", 302 | " trainer.step(x_batch.to(device), logits.to(device), *next(edit_generator))\n", 303 | " \n", 304 | " if trainer.total_steps % eval_batch_cd == 0:\n", 305 | " val_metrics = trainer.evaluate_metrics(X_test.to(device), y_test.to(device))\n", 306 | " clear_output(True)\n", 307 | "\n", 308 | " error_rate, drawdown = val_metrics['base_error'], val_metrics['drawdown']\n", 309 | "\n", 310 | " number_of_epochs_without_improvement += 1\n", 311 | "\n", 312 | " if error_rate < min_error:\n", 313 | " trainer.save_checkpoint(tag='best_val_error')\n", 314 | " min_error = error_rate\n", 315 | " number_of_epochs_without_improvement = 0\n", 316 | "\n", 317 | " if drawdown < min_drawdown:\n", 318 | " trainer.save_checkpoint(tag='best_drawdown')\n", 319 | " min_drawdown = drawdown\n", 320 | " number_of_epochs_without_improvement = 0\n", 321 | "\n", 322 | " trainer.save_checkpoint()\n", 323 | " trainer.remove_old_temp_checkpoints()\n", 324 | "\n", 325 | " if number_of_epochs_without_improvement > early_stopping_epochs:\n", 326 | " break" 327 | ] 328 | }, 329 | { 330 | "cell_type": "markdown", 331 | "metadata": {}, 332 | "source": [ 333 | "### Eval metrics" 334 | ] 335 | }, 336 | { 337 | "cell_type": "code", 338 | "execution_count": 13, 339 | "metadata": {}, 340 | "outputs": [], 341 | "source": [ 342 | "trainer.load_checkpoint(path='best_val_error');\n", 343 | "# if you're not running this in the same notebook, you can also select the checkpoint via path\n", 344 | "# trainer.load_checkpoint(path='./logs/EXPERIMENTNAMEHERE/checkpoint_best_val_error.pth')" 345 | ] 346 | }, 347 | { 348 | "cell_type": "markdown", 349 | "metadata": {}, 350 | "source": [ 351 | "## Adv metrics" 352 | ] 353 | }, 354 | { 355 | "cell_type": "code", 356 | "execution_count": 14, 357 | "metadata": {}, 358 | "outputs": [], 359 | "source": [ 360 | "mean = [0.485, 0.456, 0.406]\n", 361 | "std = [0.229, 0.224, 0.225]\n", 362 | "\n", 363 | "test_transform = transforms.Compose(\n", 364 | " [transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean, std)])\n", 365 | "\n", 366 | "thousand_k_to_200 = {0: -1, 1: -1, 2: -1, 3: -1, 4: -1, 5: -1, 6: 0, 7: -1, 8: -1, 9: -1, 10: -1, 11: 1, 12: -1, 13: 2, 14: -1, 15: 3, 16: -1, 17: 4, 18: -1, 19: -1, 20: -1, 21: -1, 22: 5, 23: 6, 24: -1, 25: -1, 26: -1, 27: 7, 28: -1, 29: -1, 30: 8, 31: -1, 32: 
-1, 33: -1, 34: -1, 35: -1, 36: -1, 37: 9, 38: -1, 39: 10, 40: -1, 41: -1, 42: 11, 43: -1, 44: -1, 45: -1, 46: -1, 47: 12, 48: -1, 49: -1, 50: 13, 51: -1, 52: -1, 53: -1, 54: -1, 55: -1, 56: -1, 57: 14, 58: -1, 59: -1, 60: -1, 61: -1, 62: -1, 63: -1, 64: -1, 65: -1, 66: -1, 67: -1, 68: -1, 69: -1, 70: 15, 71: 16, 72: -1, 73: -1, 74: -1, 75: -1, 76: 17, 77: -1, 78: -1, 79: 18, 80: -1, 81: -1, 82: -1, 83: -1, 84: -1, 85: -1, 86: -1, 87: -1, 88: -1, 89: 19, 90: 20, 91: -1, 92: -1, 93: -1, 94: 21, 95: -1, 96: 22, 97: 23, 98: -1, 99: 24, 100: -1, 101: -1, 102: -1, 103: -1, 104: -1, 105: 25, 106: -1, 107: 26, 108: 27, 109: -1, 110: 28, 111: -1, 112: -1, 113: 29, 114: -1, 115: -1, 116: -1, 117: -1, 118: -1, 119: -1, 120: -1, 121: -1, 122: -1, 123: -1, 124: 30, 125: 31, 126: -1, 127: -1, 128: -1, 129: -1, 130: 32, 131: -1, 132: 33, 133: -1, 134: -1, 135: -1, 136: -1, 137: -1, 138: -1, 139: -1, 140: -1, 141: -1, 142: -1, 143: 34, 144: 35, 145: -1, 146: -1, 147: -1, 148: -1, 149: -1, 150: 36, 151: 37, 152: -1, 153: -1, 154: -1, 155: -1, 156: -1, 157: -1, 158: -1, 159: -1, 160: -1, 161: -1, 162: -1, 163: -1, 164: -1, 165: -1, 166: -1, 167: -1, 168: -1, 169: -1, 170: -1, 171: -1, 172: -1, 173: -1, 174: -1, 175: -1, 176: -1, 177: -1, 178: -1, 179: -1, 180: -1, 181: -1, 182: -1, 183: -1, 184: -1, 185: -1, 186: -1, 187: -1, 188: -1, 189: -1, 190: -1, 191: -1, 192: -1, 193: -1, 194: -1, 195: -1, 196: -1, 197: -1, 198: -1, 199: -1, 200: -1, 201: -1, 202: -1, 203: -1, 204: -1, 205: -1, 206: -1, 207: 38, 208: -1, 209: -1, 210: -1, 211: -1, 212: -1, 213: -1, 214: -1, 215: -1, 216: -1, 217: -1, 218: -1, 219: -1, 220: -1, 221: -1, 222: -1, 223: -1, 224: -1, 225: -1, 226: -1, 227: -1, 228: -1, 229: -1, 230: -1, 231: -1, 232: -1, 233: -1, 234: 39, 235: 40, 236: -1, 237: -1, 238: -1, 239: -1, 240: -1, 241: -1, 242: -1, 243: -1, 244: -1, 245: -1, 246: -1, 247: -1, 248: -1, 249: -1, 250: -1, 251: -1, 252: -1, 253: -1, 254: 41, 255: -1, 256: -1, 257: -1, 258: -1, 259: -1, 260: -1, 261: -1, 262: -1, 263: -1, 264: -1, 265: -1, 266: -1, 267: -1, 268: -1, 269: -1, 270: -1, 271: -1, 272: -1, 273: -1, 274: -1, 275: -1, 276: -1, 277: 42, 278: -1, 279: -1, 280: -1, 281: -1, 282: -1, 283: 43, 284: -1, 285: -1, 286: -1, 287: 44, 288: -1, 289: -1, 290: -1, 291: 45, 292: -1, 293: -1, 294: -1, 295: 46, 296: -1, 297: -1, 298: 47, 299: -1, 300: -1, 301: 48, 302: -1, 303: -1, 304: -1, 305: -1, 306: 49, 307: 50, 308: 51, 309: 52, 310: 53, 311: 54, 312: -1, 313: 55, 314: 56, 315: 57, 316: -1, 317: 58, 318: -1, 319: 59, 320: -1, 321: -1, 322: -1, 323: 60, 324: 61, 325: -1, 326: 62, 327: 63, 328: -1, 329: -1, 330: 64, 331: -1, 332: -1, 333: -1, 334: 65, 335: 66, 336: 67, 337: -1, 338: -1, 339: -1, 340: -1, 341: -1, 342: -1, 343: -1, 344: -1, 345: -1, 346: -1, 347: 68, 348: -1, 349: -1, 350: -1, 351: -1, 352: -1, 353: -1, 354: -1, 355: -1, 356: -1, 357: -1, 358: -1, 359: -1, 360: -1, 361: 69, 362: -1, 363: 70, 364: -1, 365: -1, 366: -1, 367: -1, 368: -1, 369: -1, 370: -1, 371: -1, 372: 71, 373: -1, 374: -1, 375: -1, 376: -1, 377: -1, 378: 72, 379: -1, 380: -1, 381: -1, 382: -1, 383: -1, 384: -1, 385: -1, 386: 73, 387: -1, 388: -1, 389: -1, 390: -1, 391: -1, 392: -1, 393: -1, 394: -1, 395: -1, 396: -1, 397: 74, 398: -1, 399: -1, 400: 75, 401: 76, 402: 77, 403: -1, 404: 78, 405: -1, 406: -1, 407: 79, 408: -1, 409: -1, 410: -1, 411: 80, 412: -1, 413: -1, 414: -1, 415: -1, 416: 81, 417: 82, 418: -1, 419: -1, 420: 83, 421: -1, 422: -1, 423: -1, 424: -1, 425: 84, 426: -1, 427: -1, 428: 85, 429: -1, 430: 86, 431: -1, 432: -1, 433: -1, 434: -1, 
435: -1, 436: -1, 437: 87, 438: 88, 439: -1, 440: -1, 441: -1, 442: -1, 443: -1, 444: -1, 445: 89, 446: -1, 447: -1, 448: -1, 449: -1, 450: -1, 451: -1, 452: -1, 453: -1, 454: -1, 455: -1, 456: 90, 457: 91, 458: -1, 459: -1, 460: -1, 461: 92, 462: 93, 463: -1, 464: -1, 465: -1, 466: -1, 467: -1, 468: -1, 469: -1, 470: 94, 471: -1, 472: 95, 473: -1, 474: -1, 475: -1, 476: -1, 477: -1, 478: -1, 479: -1, 480: -1, 481: -1, 482: -1, 483: 96, 484: -1, 485: -1, 486: 97, 487: -1, 488: 98, 489: -1, 490: -1, 491: -1, 492: 99, 493: -1, 494: -1, 495: -1, 496: 100, 497: -1, 498: -1, 499: -1, 500: -1, 501: -1, 502: -1, 503: -1, 504: -1, 505: -1, 506: -1, 507: -1, 508: -1, 509: -1, 510: -1, 511: -1, 512: -1, 513: -1, 514: 101, 515: -1, 516: 102, 517: -1, 518: -1, 519: -1, 520: -1, 521: -1, 522: -1, 523: -1, 524: -1, 525: -1, 526: -1, 527: -1, 528: 103, 529: -1, 530: 104, 531: -1, 532: -1, 533: -1, 534: -1, 535: -1, 536: -1, 537: -1, 538: -1, 539: 105, 540: -1, 541: -1, 542: 106, 543: 107, 544: -1, 545: -1, 546: -1, 547: -1, 548: -1, 549: 108, 550: -1, 551: -1, 552: 109, 553: -1, 554: -1, 555: -1, 556: -1, 557: 110, 558: -1, 559: -1, 560: -1, 561: 111, 562: 112, 563: -1, 564: -1, 565: -1, 566: -1, 567: -1, 568: -1, 569: 113, 570: -1, 571: -1, 572: 114, 573: 115, 574: -1, 575: 116, 576: -1, 577: -1, 578: -1, 579: 117, 580: -1, 581: -1, 582: -1, 583: -1, 584: -1, 585: -1, 586: -1, 587: -1, 588: -1, 589: 118, 590: -1, 591: -1, 592: -1, 593: -1, 594: -1, 595: -1, 596: -1, 597: -1, 598: -1, 599: -1, 600: -1, 601: -1, 602: -1, 603: -1, 604: -1, 605: -1, 606: 119, 607: 120, 608: -1, 609: 121, 610: -1, 611: -1, 612: -1, 613: -1, 614: 122, 615: -1, 616: -1, 617: -1, 618: -1, 619: -1, 620: -1, 621: -1, 622: -1, 623: -1, 624: -1, 625: -1, 626: 123, 627: 124, 628: -1, 629: -1, 630: -1, 631: -1, 632: -1, 633: -1, 634: -1, 635: -1, 636: -1, 637: -1, 638: -1, 639: -1, 640: 125, 641: 126, 642: 127, 643: 128, 644: -1, 645: -1, 646: -1, 647: -1, 648: -1, 649: -1, 650: -1, 651: -1, 652: -1, 653: -1, 654: -1, 655: -1, 656: -1, 657: -1, 658: 129, 659: -1, 660: -1, 661: -1, 662: -1, 663: -1, 664: -1, 665: -1, 666: -1, 667: -1, 668: 130, 669: -1, 670: -1, 671: -1, 672: -1, 673: -1, 674: -1, 675: -1, 676: -1, 677: 131, 678: -1, 679: -1, 680: -1, 681: -1, 682: 132, 683: -1, 684: 133, 685: -1, 686: -1, 687: 134, 688: -1, 689: -1, 690: -1, 691: -1, 692: -1, 693: -1, 694: -1, 695: -1, 696: -1, 697: -1, 698: -1, 699: -1, 700: -1, 701: 135, 702: -1, 703: -1, 704: 136, 705: -1, 706: -1, 707: -1, 708: -1, 709: -1, 710: -1, 711: -1, 712: -1, 713: -1, 714: -1, 715: -1, 716: -1, 717: -1, 718: -1, 719: 137, 720: -1, 721: -1, 722: -1, 723: -1, 724: -1, 725: -1, 726: -1, 727: -1, 728: -1, 729: -1, 730: -1, 731: -1, 732: -1, 733: -1, 734: -1, 735: -1, 736: 138, 737: -1, 738: -1, 739: -1, 740: -1, 741: -1, 742: -1, 743: -1, 744: -1, 745: -1, 746: 139, 747: -1, 748: -1, 749: 140, 750: -1, 751: -1, 752: 141, 753: -1, 754: -1, 755: -1, 756: -1, 757: -1, 758: 142, 759: -1, 760: -1, 761: -1, 762: -1, 763: 143, 764: -1, 765: 144, 766: -1, 767: -1, 768: 145, 769: -1, 770: -1, 771: -1, 772: -1, 773: 146, 774: 147, 775: -1, 776: 148, 777: -1, 778: -1, 779: 149, 780: 150, 781: -1, 782: -1, 783: -1, 784: -1, 785: -1, 786: 151, 787: -1, 788: -1, 789: -1, 790: -1, 791: -1, 792: 152, 793: -1, 794: -1, 795: -1, 796: -1, 797: 153, 798: -1, 799: -1, 800: -1, 801: -1, 802: 154, 803: 155, 804: 156, 805: -1, 806: -1, 807: -1, 808: -1, 809: -1, 810: -1, 811: -1, 812: -1, 813: 157, 814: -1, 815: 158, 816: -1, 817: -1, 818: -1, 819: -1, 820: 159, 821: -1, 822: -1, 
823: 160, 824: -1, 825: -1, 826: -1, 827: -1, 828: -1, 829: -1, 830: -1, 831: 161, 832: -1, 833: 162, 834: -1, 835: 163, 836: -1, 837: -1, 838: -1, 839: 164, 840: -1, 841: -1, 842: -1, 843: -1, 844: -1, 845: 165, 846: -1, 847: 166, 848: -1, 849: -1, 850: 167, 851: -1, 852: -1, 853: -1, 854: -1, 855: -1, 856: -1, 857: -1, 858: -1, 859: 168, 860: -1, 861: -1, 862: 169, 863: -1, 864: -1, 865: -1, 866: -1, 867: -1, 868: -1, 869: -1, 870: 170, 871: -1, 872: -1, 873: -1, 874: -1, 875: -1, 876: -1, 877: -1, 878: -1, 879: 171, 880: 172, 881: -1, 882: -1, 883: -1, 884: -1, 885: -1, 886: -1, 887: -1, 888: 173, 889: -1, 890: 174, 891: -1, 892: -1, 893: -1, 894: -1, 895: -1, 896: -1, 897: 175, 898: -1, 899: -1, 900: 176, 901: -1, 902: -1, 903: -1, 904: -1, 905: -1, 906: -1, 907: 177, 908: -1, 909: -1, 910: -1, 911: -1, 912: -1, 913: 178, 914: -1, 915: -1, 916: -1, 917: -1, 918: -1, 919: -1, 920: -1, 921: -1, 922: -1, 923: -1, 924: 179, 925: -1, 926: -1, 927: -1, 928: -1, 929: -1, 930: -1, 931: -1, 932: 180, 933: 181, 934: 182, 935: -1, 936: -1, 937: 183, 938: -1, 939: -1, 940: -1, 941: -1, 942: -1, 943: 184, 944: -1, 945: 185, 946: -1, 947: 186, 948: -1, 949: -1, 950: -1, 951: 187, 952: -1, 953: -1, 954: 188, 955: -1, 956: 189, 957: 190, 958: -1, 959: 191, 960: -1, 961: -1, 962: -1, 963: -1, 964: -1, 965: -1, 966: -1, 967: -1, 968: -1, 969: -1, 970: -1, 971: 192, 972: 193, 973: -1, 974: -1, 975: -1, 976: -1, 977: -1, 978: -1, 979: -1, 980: 194, 981: 195, 982: -1, 983: -1, 984: 196, 985: -1, 986: 197, 987: 198, 988: 199, 989: -1, 990: -1, 991: -1, 992: -1, 993: -1, 994: -1, 995: -1, 996: -1, 997: -1, 998: -1, 999: -1}\n", 367 | "two_hundred_to_1000 = dict(map(reversed, thousand_k_to_200.items()))\n", 368 | "\n", 369 | "X_adv, y_adv = zip(*datasets.ImageFolder(root=\"./imagenet-a/\", transform=test_transform))\n", 370 | "X_adv = torch.stack(X_adv)\n", 371 | "y_adv = torch.tensor(list(map(two_hundred_to_1000.get, y_adv)))\n", 372 | "\n", 373 | "X_test, y_test = map(torch.cat, zip(*val_loader)) # Read the whole test" 374 | ] 375 | }, 376 | { 377 | "cell_type": "code", 378 | "execution_count": 25, 379 | "metadata": {}, 380 | "outputs": [ 381 | { 382 | "data": { 383 | "application/vnd.jupyter.widget-view+json": { 384 | "model_id": "d907ecd079674f3ebb264d87cca2ee00", 385 | "version_major": 2, 386 | "version_minor": 0 387 | }, 388 | "text/html": [ 389 | "

\n" 402 | ], 403 | "text/plain": [ 404 | "HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))" 405 | ] 406 | }, 407 | "metadata": {}, 408 | "output_type": "display_data" 409 | }, 410 | { 411 | "name": "stdout", 412 | "output_type": "stream", 413 | "text": [ 414 | "base_error\t:0.30764\n", 415 | "drawdown\t:0.0014556\n", 416 | "mean_complexity\t:2.149\n", 417 | "success_rate\t:1.0\n" 418 | ] 419 | } 420 | ], 421 | "source": [ 422 | "from lib import evaluate_quality\n", 423 | "from tqdm import tqdm_notebook\n", 424 | "\n", 425 | "np.random.seed(9)\n", 426 | "indices = np.random.permutation(len(X_adv))[:1000]\n", 427 | "X_edit = X_adv[indices].cuda()\n", 428 | "y_edit = y_adv[indices].cuda()\n", 429 | "metrics_adv = evaluate_quality(model, X_test, y_test, X_edit, y_edit, \n", 430 | " error_function=classification_error, progressbar=tqdm_notebook)\n", 431 | "\n", 432 | "for key in sorted(metrics_adv.keys()):\n", 433 | " print('{}\\t:{:.5}'.format(key, metrics_adv[key]))" 434 | ] 435 | } 436 | ], 437 | "metadata": { 438 | "kernelspec": { 439 | "display_name": "Python 3", 440 | "language": "python", 441 | "name": "python3" 442 | }, 443 | "language_info": { 444 | "codemirror_mode": { 445 | "name": "ipython", 446 | "version": 3 447 | }, 448 | "file_extension": ".py", 449 | "mimetype": "text/x-python", 450 | "name": "python", 451 | "nbconvert_exporter": "python", 452 | "pygments_lexer": "ipython3", 453 | "version": "3.6.4" 454 | } 455 | }, 456 | "nbformat": 4, 457 | "nbformat_minor": 2 458 | } 459 | --------------------------------------------------------------------------------