├── images
└── editable.png
├── .gitmodules
├── lib
├── __init__.py
├── utils
│ ├── __init__.py
│ ├── loader.py
│ ├── copy_and_replace.py
│ ├── init.py
│ ├── learning_rate.py
│ ├── gumbel.py
│ ├── regularize.py
│ ├── data.py
│ ├── basic.py
│ ├── ingraph_update.py
│ └── trainer.py
├── loss.py
├── evaluate.py
├── trainer.py
└── editable.py
├── mt
├── train.py
├── requirements.txt
├── README.md
├── prepare_iwslt14_de_en.sh
├── train.sh
├── evaluate.ipynb
├── generate_edit_datasets_samples.ipynb
├── edited_generate.py
└── fairseq_criterion.py
├── requirements.txt
├── LICENSE
├── .gitignore
├── README.md
└── notebooks
├── resnet.py
├── imagenet_preprocess_logits.ipynb
├── cifar10_editable_layer3.ipynb
├── imagenet_editable_training.ipynb
├── imagenet_evaluate_nae.ipynb
└── imagenet_editable_training_with_natural_distribution.ipynb
/images/editable.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xtinkt/editable/HEAD/images/editable.png
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "mt/fairseq"]
2 | path = mt/fairseq
3 | url = https://github.com/pytorch/fairseq.git
4 |
--------------------------------------------------------------------------------
/lib/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils import *
2 | from .editable import *
3 | from .trainer import *
4 | from .evaluate import *
5 | from .loss import *
6 |
--------------------------------------------------------------------------------
/mt/train.py:
--------------------------------------------------------------------------------
"""Launches fairseq training with this directory's fairseq_criterion module imported."""
import sys

# Make modules from the fairseq checkout (./fairseq, see .gitmodules) importable.
sys.path.append("fairseq")

# Imported for its side effects only - presumably it registers the criterion selected
# with `--criterion editable_training_criterion` in train.sh; TODO confirm.
import fairseq_criterion
# NOTE(review): this script is itself named train.py and "fairseq" is appended after
# the script directory on sys.path, so verify that `train` resolves to fairseq's
# train.py and not to this file.
from train import cli_main


cli_main()
8 |
--------------------------------------------------------------------------------
/mt/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=1.1.0
2 | numpy>=0.13
3 | scipy>=1.2.0
4 | torchvision
5 | tqdm
6 | tensorboardX
7 | prefetch_generator
8 | mosestokenizer
9 | fairseq==0.8.0
10 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=1.1.0
2 | numpy>=0.13
3 | scipy>=1.2.0
4 | scikit-learn>=0.17
5 | torchvision
6 | matplotlib
7 | tqdm
8 | tensorboardX
9 | pandas
10 | prefetch_generator
11 |
--------------------------------------------------------------------------------
/lib/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .basic import *
2 | from .data import *
3 | from .gumbel import *
4 | from .init import *
5 | from .regularize import *
6 | from .learning_rate import *
7 | from .ingraph_update import *
8 | from .copy_and_replace import *
9 | from .trainer import *
10 | from .loader import *
11 |
--------------------------------------------------------------------------------
/mt/README.md:
--------------------------------------------------------------------------------
1 | ## Get started
2 | 1. Get fairseq submodule: `(git submodule init) && (git submodule update)`
3 | 2. Install fairseq: `cd fairseq && pip install --editable .`
4 | 3. Prepare iwslt14.de-en: `./prepare_iwslt14_de_en.sh`
5 | 4. Prepare edit samples: run notebook `generate_edit_datasets_samples.ipynb`
6 |
7 | ## Training
8 | Train Editable Transformer: `./train.sh`
9 |
10 | ## Evaluating
11 | To evaluate a trained model, run the notebook `evaluate.ipynb`
12 |
--------------------------------------------------------------------------------
/mt/prepare_iwslt14_de_en.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Prepares iwslt14.de-en with the fairseq example script and binarizes it into data-bin/.
#
# Abort on the first failing command, unset variable or failed pipe stage: otherwise
# a failed `cd` would run the following steps in the wrong directory, and a failed
# download would still be handed to fairseq-preprocess.
set -euo pipefail

# Download and prepare the data
cd fairseq/examples/translation/
bash prepare-iwslt14.sh
cd ../../..

# Preprocess/binarize the data
TEXT=fairseq/examples/translation/iwslt14.tokenized.de-en
fairseq-preprocess --source-lang de --target-lang en \
    --trainpref "$TEXT/train" --validpref "$TEXT/valid" --testpref "$TEXT/test" \
    --destdir data-bin/iwslt14.tokenized.de-en \
    --workers 20
14 |
--------------------------------------------------------------------------------
/mt/train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python3 train.py data-bin/iwslt14.tokenized.de-en \
4 | --arch transformer_iwslt_de_en --share-decoder-input-output-embed \
5 | --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
6 | --lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
7 | --dropout 0.3 --weight-decay 0.0001 --criterion editable_training_criterion \
8 | --label-smoothing 0.1 --max-tokens 4096 --tensorboard-logdir train_editable_logs \
9 | --edit-samples-path edit_iwslt14.tokenized.de-en/bpe_train.txt \
10 | --save-dir checkpoints_editable \
11 | --stability-coeff 100 \
12 | --editability-coeff 100
13 |
--------------------------------------------------------------------------------
/lib/utils/loader.py:
--------------------------------------------------------------------------------
1 | import torchvision.datasets as datasets
2 | import numpy as np
3 | import torch
4 |
5 | '''
6 | Dataset loader that returns batch of images and precomputed logits
7 | '''
8 |
class ImageAndLogitsFolder(datasets.ImageFolder):
    """
    ImageFolder whose items are (image, target, logits) triples;
    logits are read from a .npy file that is located with logits_prefix
    """

    def __init__(self, *args, logits_prefix, **kwargs):
        """
        :param logits_prefix: string that is prepended to every logits file name
        :param args: positional arguments for datasets.ImageFolder
        :param kwargs: keyword arguments for datasets.ImageFolder
        """
        super().__init__(*args, **kwargs)
        self.logits_prefix = logits_prefix

    @staticmethod
    def logits_path_create(prefix, path):
        """
        Builds the logits file path for an image path:
        prefix + (part of path after its last "/" and before its first ".j") + ".npy"
        NOTE(review): when path contains no ".j", find() returns -1 and the slice
        drops the last character of path instead - verify for extensions like ".JPEG"
        """
        name_start = path.rfind("/") + 1
        name_end = path.find(".j")
        return prefix + path[name_start:name_end] + ".npy"

    def get_image_path(self, index):
        """ :returns: self.imgs[index]; __getitem__ takes element [0] of it as the image path """
        return self.imgs[index]

    def __getitem__(self, index):
        """ :returns: img and target from ImageFolder plus logits reshaped to (1, -1) """
        img, target = super().__getitem__(index)

        image_path = self.get_image_path(index)[0]
        logits_path = ImageAndLogitsFolder.logits_path_create(self.logits_prefix, image_path)
        logits = torch.Tensor(np.load(logits_path)).reshape(1, -1)

        return img, target, logits
33 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Anton Sinitsin
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # node and NPM
2 | npm-debug.log
3 | node_modules
4 |
5 | # swap files
6 | *~
7 | *.swp
8 |
9 | notebooks/data/*
10 | notebooks/runs/*
11 | notebooks/.ipynb_checkpoints/*
12 |
13 | env.sh
14 | # Byte-compiled / optimized / DLL files
15 | __pycache__/
16 | *.py[cod]
17 |
18 | # C extensions
19 | *.so
20 |
21 | # Distribution / packaging
22 | .Python
23 | env/
24 | bin/
25 | build/
26 | develop-eggs/
27 | dist/
28 | eggs/
29 | lib64/
30 | parts/
31 | sdist/
32 | var/
33 | *.egg-info/
34 | .installed.cfg
35 | *.egg/
36 |
37 | # Installer logs
38 | pip-log.txt
39 | pip-delete-this-directory.txt
40 |
41 | # Unit test / coverage reports
42 | htmlcov/
43 | .tox/
44 | .coverage
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 |
49 | # Translations
50 | *.mo
51 |
52 | # Mr Developer
53 | .mr.developer.cfg
54 | .project
55 | .pydevproject
56 | .idea
57 | .ipynb_checkpoints
58 |
59 | # Rope
60 | .ropeproject
61 |
62 | # Django stuff:
63 | *.log
64 | *.pot
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 | docs/tmp*
69 |
70 | # OS X garbage
71 | .DS_Store
72 |
73 | # Debian things
74 | debian/reproducible-experiment-platform
75 | debian/files
76 | *.substvars
77 | *.debhelper.log
78 |
79 | .vscode
80 |
81 | #pytorch model weights
82 | *.pth
83 |
--------------------------------------------------------------------------------
/lib/utils/copy_and_replace.py:
--------------------------------------------------------------------------------
1 | """ A convenience function for creating modified copies out-of-place with deepcopy """
2 | from contextlib import contextmanager
3 | from copy import deepcopy
4 |
# {id(obj): obj} entries that every copy_and_replace call starts its deepcopy memo from;
# filled temporarily by the do_not_copy context manager below
DEFAULT_MEMO = dict()


def copy_and_replace(original, replace=None, do_not_copy=None):
    """
    :param original: object to be copied
    :param replace: a dictionary {old object -> new object}, replace all occurrences of old object with new object
    :param do_not_copy: a sequence of objects that will not be copied (but may be replaced)
    :return: a copy of obj with replacements
    """
    replace, do_not_copy = replace or {}, do_not_copy or ()
    # deepcopy returns memo[id(x)] instead of copying x whenever that key is present,
    # so pre-filling the memo is enough to share or substitute objects
    memo = dict(DEFAULT_MEMO)
    for item in do_not_copy:
        memo[id(item)] = item

    # replacements are written last so that they take priority over do_not_copy
    for item, replacement in replace.items():
        memo[id(item)] = replacement

    return deepcopy(original, memo)


@contextmanager
def do_not_copy(*items):
    """ all calls to copy_and_replace within this context won't copy items (but can replace them) """
    keys_to_remove = []
    for item in items:
        key = id(item)
        # register only items that are not registered yet, so that leaving a nested
        # context does not drop entries that belong to an enclosing context
        if key not in DEFAULT_MEMO:
            DEFAULT_MEMO[key] = item
            keys_to_remove.append(key)

    try:
        yield
    finally:
        # unregister even if the body raised, otherwise items would stay shared forever
        for key in keys_to_remove:
            DEFAULT_MEMO.pop(key, None)
41 |
--------------------------------------------------------------------------------
/lib/loss.py:
--------------------------------------------------------------------------------
1 | """
2 | Loss functions
3 | """
4 |
5 | import torch
6 | from torch import nn as nn
7 | import torch.nn.functional as F
8 |
9 |
def kl_distill_loss(logits, ref_probs):
    """
    Kullback-Leibler divergence computed by F.kl_div (with its default reduction)
    :param logits: unnormalized predictions, log-softmax is taken over the last dimension
    :param ref_probs: reference probabilities, used as the target of F.kl_div
    """
    log_probs = F.log_softmax(logits, dim=-1)
    return F.kl_div(log_probs, ref_probs)
15 |
16 |
def contrastive_cross_entropy(logits, target, margin=0.0):
    """
    A special loss that is similar to crossentropy but becomes exactly zero if
    logp(target) >= max(logp(all_excluding_target)) + margin
    Used for classification edits
    :param logits: unnormalized class scores, classes are along the last dimension
    :param target: int64 indices of target classes, one per position of logits.shape[:-1]
    :param margin: required gap between logp(target) and the largest other logp
    :return: scalar tensor, average over all targets
    """
    logp = F.log_softmax(logits, dim=-1)
    target_index = target.unsqueeze(-1)
    # select logp(target) by index: multiplying logp by a one-hot mask would produce
    # nan (-inf * 0) as soon as any other class has logp = -inf
    logp_target = logp.gather(-1, target_index).squeeze(-1)
    # exclude the target class from the maximum by setting its logp to -inf
    logp_others = logp.scatter(-1, target_index, -float('inf'))
    return F.relu(margin + logp_others.max(dim=-1)[0] - logp_target).mean()
28 |
29 |
def threshold_mse(predictions, targets, threshold=0.0, reduction_axes=None):
    """
    Squared error that is clipped to exactly zero wherever it does not exceed threshold,
    averaged over all remaining elements. Used for regression edits
    :param predictions: model outputs
    :param targets: reference values, subtracted from predictions
    :param threshold: this much (summed) squared error is tolerated for free
    :param reduction_axes: if given, squared errors are summed over these axes before thresholding
    """
    squared_error = (predictions - targets).pow(2)
    if reduction_axes is None:
        excess = squared_error - threshold
    else:
        excess = squared_error.sum(reduction_axes) - threshold
    return F.relu(excess).mean()
40 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Editable neural networks
2 |
3 | A supplementary code for [Editable Neural Networks](https://openreview.net/forum?id=HJedXaEtvS), an ICLR 2020 submission by Anton Sinitsin, Vsevolod Plokhotnyuk, Dmitry Pyrkin, Sergei Popov, Artem Babenko.
4 |
5 |
6 |
7 | # What does it do?
8 |
9 | It trains a model so that it can later be edited: forced to predict a specific class on a specific input without losing accuracy.
10 |
11 | # What do I need to run it?
12 | * A machine with some CPU (preferably 2+ free cores) and GPU(s)
13 | * Running without GPU is possible but does not scale well, especially for ImageNet
14 | * Some popular Linux x64 distribution
15 | * Tested on Ubuntu16.04, should work fine on any popular linux64 and even MacOS;
16 | * Windows and x32 systems may require heavy wizardry to run;
17 | * When in doubt, use Docker, preferably GPU-enabled (i.e. nvidia-docker)
18 |
19 | # How do I run it?
20 | 1. Clone or download this repo. `cd` yourself to its root directory.
21 | 2. Grab or build a working python environment. [Anaconda](https://www.anaconda.com/) works fine.
22 | 3. Install packages from `requirements.txt`
23 | 4. Run jupyter notebook and open a notebook in `./notebooks/`
24 | * Before you run the first cell, change `%env CUDA_VISIBLE_DEVICES=#` to an index that you plan to use.
25 | * [CIFAR10 notebook](./notebooks/cifar10_editable_layer3.ipynb) can be run with no extra preparation
26 | * The ImageNet notebooks require a step-by-step procedure to get running:
27 | 1. Download the dataset first. See [this page](https://pytorch.org/docs/stable/_modules/torchvision/datasets/imagenet.html) or just google it. No, really, go google it!
28 | 2. Run [`imagenet_preprocess_logits.ipynb`](./notebooks/imagenet_preprocess_logits.ipynb)
29 | 3. Train with [`imagenet_editable_training.ipynb`](./notebooks/imagenet_editable_training.ipynb)
30 | 4. Evaluate by using one of the two remaining notebooks.
31 | * To reproduce machine translation experiments, follow the instructions in [`./mt/README.md`](./mt/)
32 |
--------------------------------------------------------------------------------
/lib/utils/init.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch import nn as nn
4 | from torch.jit import ScriptModule
5 | from torch.nn import functional as F
6 |
7 |
class ModuleWithInit(nn.Module):
    """ Base class for pytorch module with data-aware initializer on first batch """

    def __init__(self):
        super().__init__()
        for reserved_name in ('_is_initialized_bool', '_is_initialized_tensor'):
            assert not hasattr(self, reserved_name)
        # The "is initialized" flag is stored twice:
        # * _is_initialized_tensor is a parameter, so it is saved alongside model in state_dict
        # * _is_initialized_bool is a plain python copy of it (None = not read from the tensor yet),
        #   so that later calls do not have to read the tensor again
        # please DO NOT use these flags in child modules
        self._is_initialized_tensor = nn.Parameter(torch.tensor(0, dtype=torch.uint8), requires_grad=False)
        self._is_initialized_bool = None

    def initialize(self, *args, **kwargs):
        """ initialize module tensors using first batch of data; receives the arguments of the first call """
        raise NotImplementedError("Please implement ")

    def is_initialized(self):
        """ whether data aware initialization was already performed """
        cached = self._is_initialized_bool
        if cached is None:
            cached = self._is_initialized_bool = bool(self._is_initialized_tensor.item())
        return cached

    def __call__(self, *args, **kwargs):
        """ runs self.initialize on the same inputs before the first forward pass, then calls the module as usual """
        cached = self._is_initialized_bool
        if cached is None:
            cached = self._is_initialized_bool = bool(self._is_initialized_tensor.item())
        if not cached:
            self.initialize(*args, **kwargs)
            self._is_initialized_tensor.data[...] = 1
            self._is_initialized_bool = True
        return super().__call__(*args, **kwargs)
39 |
40 |
class ScriptModuleWithInit(ModuleWithInit, ScriptModule):
    """ Same as ModuleWithInit, but additionally derived from torch.jit.ScriptModule """
    def __init__(self, optimize=True, **kwargs):
        # both base initializers are called explicitly, ScriptModule first, instead of going through super()
        # NOTE(review): `optimize` is forwarded to ScriptModule.__init__ - confirm that the
        # torch version in use still accepts this argument
        ScriptModule.__init__(self, optimize=optimize, **kwargs)
        ModuleWithInit.__init__(self)
46 |
47 |
def init_normalized_(x, init_=nn.init.normal_, dim=-1, **kwargs):
    """
    Initialize x in-place: fill it with init_ (random normal by default),
    then rescale it with F.normalize over dim
    :param x: tensor to initialize
    :param init_: in-place initializer, called as init_(x)
    :param dim: dimension passed to F.normalize
    :param kwargs: extra keyword arguments for F.normalize
    :return: x itself
    """
    init_(x)
    normalized = F.normalize(x, dim=dim, **kwargs)
    x.data = normalized
    return x
53 |
--------------------------------------------------------------------------------
/lib/utils/learning_rate.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from torch.optim.optimizer import Optimizer
3 | import numpy as np
4 |
5 |
def get_learning_rate(optimizer):
    """
    :returns: 'lr' of the first group in optimizer.param_groups that has one
    :raises ValueError: if no param group contains 'lr'
    """
    lr_values = (group['lr'] for group in optimizer.param_groups if 'lr' in group)
    for lr in lr_values:
        return lr
    raise ValueError("Could not infer learning rate from optimizer {}".format(optimizer))
11 |
12 |
class OneCycleSchedule(Optimizer):
    """ A simplified torch lr schedule that updates learning rate before every opt.step """

    def __init__(self, optimizer, **kwargs):
        """
        :type optimizer: torch.optim.Optimizer
        :param kwargs: see self.update_learning_rate
        """
        # Optimizer.__init__ is not called here: this object only wraps `optimizer`,
        # attributes that it does not define itself are looked up on self.opt (see __getattr__)
        self.learning_rate_opts = kwargs
        self.opt = optimizer
        self.step_count = 0

    def step(self, **kwargs):
        """ sets learning rate for the current step count, then performs self.opt.step(**kwargs) """
        self.current_lr = self.update_learning_rate(t=self.step_count, **self.learning_rate_opts)
        res = self.opt.step(**kwargs)
        self.step_count += 1
        return res

    def state_dict(self, **kwargs):
        """ :returns: OrderedDict with state of the wrapped optimizer, schedule options and step count """
        return OrderedDict([
            ('optimizer_state_dict', self.opt.state_dict(**kwargs)),
            ('learning_rate_opts', self.learning_rate_opts),
            ('step_count', self.step_count)
        ])

    def load_state_dict(self, state_dict, load_step=True, load_opts=True, **kwargs):
        """
        :param state_dict: a dictionary in the format of self.state_dict()
        :param load_step: if False, keeps current step_count instead of the saved one
        :param load_opts: if False, keeps current learning_rate_opts instead of the saved ones
        :param kwargs: extra parameters for self.opt.load_state_dict
        """
        self.learning_rate_opts = state_dict['learning_rate_opts'] if load_opts else self.learning_rate_opts
        self.step_count = state_dict['step_count'] if load_step else self.step_count
        return self.opt.load_state_dict(state_dict['optimizer_state_dict'], **kwargs)

    def __getattr__(self, attr):
        # __getattr__ is only invoked after normal attribute lookup has failed.
        # If 'opt' itself is missing (e.g. the instance was created without __init__),
        # evaluating self.opt below would call __getattr__ again and recurse forever
        if attr == 'opt':
            raise AttributeError(attr)
        return getattr(self.opt, attr)

    def update_learning_rate(self, t, learning_rate_base=1e-3, warmup_steps=10000,
                             decay_rate=0.2, learning_rate_min=1e-5):
        """
        Learning rate with linear warmup and exponential decay
        :param t: zero-based index of the step that is about to be made
        :param learning_rate_base: learning rate reached at the end of warmup
        :param warmup_steps: number of steps over which learning rate grows linearly
        :param decay_rate: after warmup, learning rate is multiplied by exp(-decay_rate)
            every warmup_steps steps
        :param learning_rate_min: learning rate is never set below this value
        :return: the learning rate that was written to every param group of self.opt
        """
        lr = learning_rate_base * np.minimum(
            (t + 1.0) / warmup_steps,
            np.exp(decay_rate * ((warmup_steps - t - 1.0) / warmup_steps)),
        )
        lr = np.maximum(lr, learning_rate_min)
        for param_group in self.opt.param_groups:
            param_group['lr'] = lr
        return lr
59 |
--------------------------------------------------------------------------------
/lib/utils/gumbel.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .basic import to_one_hot
3 |
4 |
def gumbel_noise(*sizes, epsilon=1e-9, **kwargs):
    """
    Sample noise from gumbel distribution
    :param sizes: shape of the sample, passed to torch.rand
    :param epsilon: small constant added inside both logarithms to avoid log(0)
    :param kwargs: extra keyword arguments for torch.rand (e.g. device, dtype)
    """
    uniform = torch.rand(*sizes, **kwargs)
    neg_log_uniform = -torch.log(uniform + epsilon)
    return -torch.log(neg_log_uniform + epsilon)
8 |
9 |
def gumbel_softmax(logits, dim=-1, tau=1.0, noise=1.0, hard=False, **kwargs):
    """
    Softmax with gumbel noise
    :param logits: inputs for softmax
    :param dim: normalize softmax along this dimension (negative values count from the end)
    :param tau: gumbel softmax temperature
    :param noise: gumbel noise is multiplied by this value before it is added to logits;
        noise=0 disables sampling
    :param hard: if True, works like onehot(sample) during forward pass,
        gumbel-softmax for backward pass
    :param kwargs: accepted and ignored
    :return: gumbel-softmax "probabilities", tensor of same shape as logits
    """
    if noise != 0:
        z = gumbel_noise(*logits.shape, device=logits.device, dtype=logits.dtype)
        logits = logits + noise * z
    if tau != 1.0:
        logits = logits / tau

    probs_gumbel = torch.softmax(logits, dim=dim)

    if hard:
        # build one-hot directly along dim: scatter handles positive and negative dim
        # alike, so no permutation of a trailing one-hot axis is necessary
        argmax_indices = probs_gumbel.argmax(dim=dim, keepdim=True)
        hard_argmax_onehot = torch.zeros_like(probs_gumbel).scatter_(dim, argmax_indices, 1.0)

        # forward pass: onehot sample, backward pass: gumbel softmax
        probs_gumbel = (hard_argmax_onehot - probs_gumbel).detach() + probs_gumbel

    return probs_gumbel
40 |
41 |
def gumbel_sigmoid(logits, tau=1.0, noise=1.0, hard=False, **kwargs):
    """
    A special case of gumbel softmax with 2 classes: [logit] and 0
    :param logits: sigmoid inputs, this tensor is never modified in-place
    :param tau: same as gumbel softmax temperature
    :param noise: the difference of two gumbel noise samples is multiplied by this value
        before it is added to logits; noise=0 disables sampling
    :param hard: if True, works like bernoulli sample for forward pass,
        gumbel sigmoid for backward pass
    :param kwargs: accepted and ignored
    :return: tensor with same shape as logits
    """
    if noise != 0.0:
        z1 = gumbel_noise(*logits.shape, device=logits.device, dtype=logits.dtype)
        z2 = gumbel_noise(*logits.shape, device=logits.device, dtype=logits.dtype)
        logits = logits + noise * (z1 - z2)
    if tau != 1.0:
        # out-of-place on purpose: with noise == 0 `logits` is still the caller's tensor,
        # an in-place division would overwrite it (and fails for leaf tensors that require grad)
        logits = logits / tau
    sigm = torch.sigmoid(logits)
    if hard:
        hard_sample = torch.ge(sigm, 0.5).to(dtype=logits.dtype)
        # forward pass: hard sample, backward pass: gumbel sigmoid
        sigm = (hard_sample - sigm).detach() + sigm
    return sigm
62 |
--------------------------------------------------------------------------------
/lib/utils/regularize.py:
--------------------------------------------------------------------------------
1 | """
2 | Context-based regularizers, thread-safe
3 |
4 | Usage:
5 | >>> with collect_regularizers() as regularizers:
6 | >>> y1 = model1(x)
7 | >>> y2 = model2(x)
8 | >>>
9 | >>> do_something([regularizers[module]['my_activation'] for module in regularizers])
10 |
11 | Inside model1 and model2 code:
12 | >>> <...>
13 | >>> activation = self.foo(x)
14 | >>> if is_regularized('my_activation'):
15 | >>> add_regularizer(self, 'my_activation', activation)
16 | >>> output = self.bar(activation)
17 | >>> return output
18 |
19 | """
20 | from contextlib import contextmanager
21 | from collections import defaultdict
22 | import threading
23 | from warnings import warn
24 |
25 |
# Collection and keys shared by all threads, set by collect_regularizers(within_thread=False)
REGULARIZERS_KEYS = None
REGULARIZERS = None
# Collection and keys of the current thread only, set by collect_regularizers(within_thread=True)
tls = threading.local()


@contextmanager
def collect_regularizers(collection=None, keys=None, within_thread=None):
    """
    Context in which add_regularizer stores values into collection
    :param collection: a dictionary {module -> {key -> list of values}},
        by default a new nested defaultdict is created
    :param keys: only these keys are regularized, None means all keys
    :param within_thread: if True, collects in the current thread only (thread-local storage),
        if False, uses module-level globals that are visible from every thread;
        by default True in main thread and False (with a warning) elsewhere
    :return: yields collection; previous collection and keys are restored on exit
    """
    if within_thread is None:
        if threading.current_thread() is not threading.main_thread():
            warn("Calling collect_regularizers while not in main thread, please set within_thread explicitly")
        within_thread = threading.current_thread() == threading.main_thread()

    if collection is None:
        collection = defaultdict(lambda: defaultdict(list))

    global REGULARIZERS, REGULARIZERS_KEYS
    # make sure both thread-local attributes exist so that they can be saved and restored below
    setattr(tls, 'REGULARIZERS', getattr(tls, 'REGULARIZERS', None))
    setattr(tls, 'REGULARIZERS_KEYS', getattr(tls, 'REGULARIZERS_KEYS', None))

    _old_regs, _old_keys = REGULARIZERS, REGULARIZERS_KEYS
    _old_local_regs, _old_local_keys = tls.REGULARIZERS, tls.REGULARIZERS_KEYS
    try:
        # exactly one of (thread-local, global) holds the collection, the other is cleared
        if within_thread:
            tls.REGULARIZERS, tls.REGULARIZERS_KEYS = collection, keys
            REGULARIZERS, REGULARIZERS_KEYS = None, None
        else:
            REGULARIZERS, REGULARIZERS_KEYS = collection, keys
            tls.REGULARIZERS, tls.REGULARIZERS_KEYS = None, None

        yield collection
    finally:
        REGULARIZERS = _old_regs
        REGULARIZERS_KEYS = _old_keys
        tls.REGULARIZERS = _old_local_regs
        tls.REGULARIZERS_KEYS = _old_local_keys


def get_regularized_keys():
    """ :returns: keys of the active collection: thread-local if it is set, global otherwise """
    # keys alone cannot tell which storage is active because None is a valid value ("all keys"),
    # so the decision is made by the collection they belong to
    if getattr(tls, 'REGULARIZERS', None) is not None:
        return getattr(tls, 'REGULARIZERS_KEYS', None)
    return REGULARIZERS_KEYS


def get_regularizer_collection():
    """ :returns: the active collection: thread-local if it is set, global otherwise (may be None) """
    # the thread-local attribute still exists (as None) after any collect_regularizers call
    # in this thread, so its mere presence must not hide the global collection
    local_collection = getattr(tls, 'REGULARIZERS', None)
    return local_collection if local_collection is not None else REGULARIZERS


def is_regularized(key):
    """ :returns: True if a collection is active and key is among its keys (or it accepts all keys) """
    if get_regularizer_collection() is None:
        return False
    keys = get_regularized_keys()
    return keys is None or key in keys


def add_regularizer(module, key, value):
    """ appends value to collection[module][key]; callers should check is_regularized(key) first """
    assert is_regularized(key)
    get_regularizer_collection()[module][key].append(value)
83 |
--------------------------------------------------------------------------------
/lib/evaluate.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import nn as nn
4 | from tqdm import tqdm
5 | import torch.nn.functional as F
6 |
7 | from .utils import infer_model_device, training_mode, process_in_chunks, check_numpy, nop
8 | from .editable import Editable
9 |
10 |
def classification_error(model: nn.Module, X_test, y_test, batch_size=1024, device=None):
    """
    Fraction of samples for which argmax of model outputs (over axis 1) differs from the target
    :param model: model to be evaluated, switched to is_train=False for the duration of the call
    :param X_test: input data, converted to a tensor on device
    :param y_test: reference class indices
    :param batch_size: passed to process_in_chunks
    :param device: where to put X_test, by default uses infer_model_device(model)
    """
    if not device:
        device = infer_model_device(model)
    with torch.no_grad(), training_mode(model, is_train=False):
        inputs = torch.as_tensor(X_test, device=device)
        val_logits = check_numpy(process_in_chunks(model, inputs, batch_size=batch_size))
        predictions = np.argmax(val_logits, axis=1)
        error_rate = (check_numpy(y_test) != predictions).mean()
    return error_rate
19 |
def eval_func(model: nn.Module, X_test, batch_size=1024, device=None):
    """
    Computes model outputs for X_test without gradients, with model switched to is_train=False
    :param model: model to be evaluated
    :param X_test: input data, converted to a tensor on device
    :param batch_size: passed to process_in_chunks
    :param device: where to put X_test, by default uses infer_model_device(model)
        (same rule as in classification_error) instead of assuming that the model is on cuda
    :return: model outputs after check_numpy
    """
    device = device or infer_model_device(model)
    with torch.no_grad(), training_mode(model, is_train=False):
        val_logits = process_in_chunks(
            model, torch.as_tensor(X_test, device=device), batch_size=batch_size)
        val_logits = check_numpy(val_logits)
    return val_logits
26 |
def calculate_edit_statistics(editable_model: Editable, X_test, y_test, X_edit, y_edit,
                              error_function=classification_error, progressbar=None, **kwargs):
    """
    Edits the model on each sample of X_edit, y_edit separately and measures quality of every edited model
    :param editable_model: model to be edited
    :param X_test: data for quality evaluation
    :param y_test: targets for quality evaluation
    :param X_edit: sequence of data for training model on
    :param y_edit: sequence of targets for training model on
    :param error_function: function that measures quality, called as error_function(model, X_test, y_test)
    :param progressbar: iterator(range) that may or may not print progress, use progressbar=True for tqdm.tqdm
    :param kwargs: extra parameters for model.edit
    :return: list of tuples (error of edited model, edit success, edit complexity), one per edit sample
    """
    if progressbar is True:
        progressbar = tqdm
    elif not progressbar:
        progressbar = nop

    statistics = []
    with training_mode(editable_model, is_train=False):
        for index in progressbar(range(len(X_edit))):
            x_single, y_single = X_edit[index:index + 1], y_edit[index:index + 1]
            edited_model, success, _, complexity = editable_model.edit(
                x_single, y_single, detach=True, **kwargs)
            error = error_function(edited_model, X_test, y_test)
            statistics.append((error, success, complexity))
    return statistics
49 |
50 |
def evaluate_quality(editable_model: Editable, X_test, y_test, X_edit, y_edit,
                     error_function=classification_error, progressbar=None, **kwargs):
    """
    Compares quality of the model before editing with its average quality after single edits
    :param editable_model: model to be edited
    :param X_test: data for quality evaluation
    :param y_test: targets for quality evaluation
    :param X_edit: sequence of data for training model on
    :param y_edit: sequence of targets for training model on
    :param error_function: function that measures quality
    :param progressbar: see calculate_edit_statistics
    :param kwargs: extra parameters for model.edit
    :return: dictionary of metrics: base_error (before edits), drawdown (mean error after
        an edit minus base_error), success_rate and mean_complexity over all edits
    """
    base_error = error_function(editable_model, X_test, y_test)
    edit_results = calculate_edit_statistics(
        editable_model, X_test, y_test, X_edit, y_edit,
        progressbar=progressbar, error_function=error_function, **kwargs)
    errors, successes, complexities = zip(*edit_results)
    return dict(
        base_error=base_error,
        drawdown=np.mean(errors) - base_error,
        success_rate=np.mean(successes),
        mean_complexity=np.mean(complexities),
    )
71 |
--------------------------------------------------------------------------------
/lib/trainer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from torch.nn import functional as F
4 |
5 | from .evaluate import classification_error, evaluate_quality
6 | from .utils import training_mode, BaseTrainer
7 | from .editable import Editable
8 | from .loss import kl_distill_loss
9 |
10 |
class EditableTrainer(BaseTrainer):
    """ Trainer that adds stability and editability terms to the main loss """

    def __init__(self, model: Editable, loss_function, error_function=classification_error, opt=None,
                 stability_coeff=0.01, editability_coeff=0.01, max_norm=None, **kwargs):
        """
        :param model: editable model to be trained
        :param loss_function: main loss, called as loss_function(logits, y_batch)
        :param error_function: quality measure used by evaluate_metrics
        :param opt: optimizer, by default torch.optim.Adam over model.parameters()
        :param stability_coeff: weight of stability loss in the final loss
        :param editability_coeff: weight of editability loss in the final loss
        :param max_norm: if given, gradients are clipped to this norm before every step
        :param kwargs: extra parameters for BaseTrainer
        """
        if opt is None:
            opt = torch.optim.Adam(model.parameters())
        super().__init__(model, loss_function=loss_function, opt=opt, error_function=error_function, **kwargs)
        self.stability_coeff = stability_coeff
        self.editability_coeff = editability_coeff
        self.max_norm = max_norm

    def train_on_batch(self, x_batch, y_batch, x_edit, y_edit, prefix='train/', is_train=True, **kwargs):
        """
        Performs a single gradient update and reports metrics
        :param x_batch: inputs for the main loss and stability loss
        :param y_batch: targets for the main loss
        :param x_edit: inputs for model.edit
        :param y_edit: targets for model.edit
        :param prefix: metrics are recorded under this prefix
        :param is_train: training mode of the model while computing the main loss
        :param kwargs: extra parameters for model.edit
        """
        x_batch, y_batch = torch.as_tensor(x_batch), torch.as_tensor(y_batch)
        self.opt.zero_grad()

        with training_mode(self.model, is_train=is_train):
            logits = self.model(x_batch)

        main_loss = self.loss_function(logits, y_batch).mean()

        # the edit itself and predictions of the edited model are computed with is_train=False
        with training_mode(self.model, is_train=False):
            model_edited, success, editability_loss, complexity = self.model.edit(x_edit, y_edit, **kwargs)
            logits_updated = model_edited(x_batch)

        # crossentropy between predictions before the edit (treated as constants) and after it
        reference_probs = F.softmax(logits.detach(), dim=1)
        updated_logp = F.log_softmax(logits_updated, dim=1)
        stability_loss = - (reference_probs * updated_logp).sum(dim=1).mean()

        final_loss = main_loss + self.stability_coeff * stability_loss + self.editability_coeff * editability_loss

        metrics = dict(
            final_loss=final_loss.item(),
            stability_loss=stability_loss.item(),
            editability_loss=editability_loss.item(),
            main_loss=main_loss.item(),
        )

        final_loss.backward()

        if self.max_norm is not None:
            metrics['grad_norm'] = torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.max_norm)
        self.opt.step()

        return self.record(**metrics, prefix=prefix)

    def evaluate_metrics(self, X, y, X_edit=None, y_edit=None, prefix='val/', **kwargs):
        """
        For each sample in X_edit, y_edit attempts to train model and evaluates trained model quality
        :param X: data for quality evaluation
        :param y: targets for quality evaluation
        :param X_edit: sequence of data for training model on,
            by default 10 random samples from X are used
        :param y_edit: sequence of targets for training model on,
            by default each chosen sample gets a random class different from its own
        :param prefix: metrics are recorded under this prefix
        :param kwargs: extra parameters for evaluate_quality
        :return: dictionary of metrics
        """
        assert (X_edit is None) == (y_edit is None), "provide either both X_edit and y_edit or none of them"
        if X_edit is None:
            num_classes = y.max() + 1
            ind = np.random.permutation(len(X))[:10]
            X_edit = X[ind]
            y_chosen = y[ind]
            # a shift in [1, num_classes) modulo num_classes never maps a class onto itself
            y_edit = (y_chosen + torch.randint_like(y_chosen, 1, num_classes)) % num_classes

        quality = evaluate_quality(
            self.model, X, y, X_edit, y_edit, error_function=self.error_function, **kwargs)
        return self.record(**quality, prefix=prefix)

    def extra_repr(self):
        """ :returns: text with coefficients, max_norm, loss function and optimizer """
        coefficients = "stability_coeff = {}, editability_coeff = {}, max_norm = {}".format(
            self.stability_coeff, self.editability_coeff, self.max_norm)
        loss_line = '\nloss = {} '.format(self.loss_function)
        opt_line = '\nopt = {} '.format(self.opt)
        return coefficients + loss_line + opt_line
77 |
78 |
class DistillationEditableTrainer(EditableTrainer):
    """ EditableTrainer whose main loss is kl_distill_loss between model outputs and given logits_batch """

    def __init__(self, model, **kwargs):
        """
        :param model: editable model to be trained
        :param kwargs: extra parameters for EditableTrainer (except loss_function)
        """
        super().__init__(model, loss_function=kl_distill_loss, **kwargs)

    def train_on_batch(self, x_batch, logits_batch, *args, **kwargs):
        """ same as EditableTrainer.train_on_batch with logits_batch as targets, always with is_train=True """
        metrics = super().train_on_batch(x_batch, logits_batch, *args, is_train=True, **kwargs)
        return metrics
85 |
--------------------------------------------------------------------------------
/mt/evaluate.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import re\n",
10 | "import os\n",
11 | "import numpy as np\n",
12 | "import tqdm\n",
13 | "\n",
14 | "fout_path = 'bleus_editable.txt'\n",
15 | "if os.path.isfile(fout_path):\n",
16 | " raise ValueError(\"File already exists\")\n",
17 | "\n",
18 | "def get_edit_bleu(checkpoint_path, edit_sample_index=None):\n",
19 | " edit_sample_str = ''\n",
20 | " if edit_sample_index is not None:\n",
21 | " edit_sample_str = '--edit-sample-index {}'.format(edit_sample_index)\n",
22 | "\n",
23 | " res = !python edited_generate.py data-bin/iwslt14.tokenized.de-en \\\n",
24 | " --path {checkpoint_path} \\\n",
25 | " --edit-samples-path edit_iwslt14.tokenized.de-en/bpe_test.txt \\\n",
26 | " {edit_sample_str} \\\n",
27 | " --no-progress-bar \\\n",
28 | " --beam 5 --remove-bpe --sacrebleu --moses-detokenizer de \n",
29 | " \n",
30 | " if edit_sample_index is not None:\n",
31 | " bleu_id = -2\n",
32 | " else:\n",
33 | " bleu_id = -1\n",
34 | "\n",
35 | " try:\n",
36 | " bleu = float(re.findall('BLEU\\(score=(\\d+(\\.\\d+|)),', res[bleu_id])[0][0])\n",
37 | " except:\n",
38 | " bleu = None\n",
39 | " \n",
40 | " if edit_sample_index is not None:\n",
41 | " try:\n",
42 | " success, complexity = re.findall('EditResult\\(success=(True|False), complexity=(\\d+)\\)', res[-1])[0]\n",
43 | " success = eval(success)\n",
44 | " complexity = int(complexity)\n",
45 | " except:\n",
46 | " success = complexity = None\n",
47 | " \n",
48 | " with open(fout_path, 'a') as f:\n",
49 | " print(bleu, success, complexity, file=f)\n",
50 | " else:\n",
51 | " success = complexity = None\n",
52 | " \n",
53 | " return bleu, success, complexity"
54 | ]
55 | },
56 | {
57 | "cell_type": "code",
58 | "execution_count": null,
59 | "metadata": {},
60 | "outputs": [],
61 | "source": [
62 | "get_edit_bleu('checkpoints_editable/checkpoint_best.pt')"
63 | ]
64 | },
65 | {
66 | "cell_type": "code",
67 | "execution_count": null,
68 | "metadata": {},
69 | "outputs": [],
70 | "source": [
71 | "np.random.seed(42)\n",
72 | "ids = np.random.permutation(6749)[:1000]\n",
73 | "[get_edit_bleu('checkpoints_editable/checkpoint_best.pt', id) for id in tqdm.tqdm_notebook(ids)]"
74 | ]
75 | },
76 | {
77 | "cell_type": "code",
78 | "execution_count": null,
79 | "metadata": {},
80 | "outputs": [],
81 | "source": [
82 | "bleus = []\n",
83 | "successes = []\n",
84 | "complexities = []\n",
85 | "with open('bleus_samples_stability1_kl_almost_last.txt') as f:\n",
86 | " for line in f:\n",
87 | " bleu, success, complexity = line[:-1].split()\n",
88 | " \n",
89 | " try:\n",
90 | " bleu = float(bleu)\n",
91 | " success = bool(success)\n",
92 | " complexity = float(complexity)\n",
93 | " except:\n",
94 | " bleu = 0\n",
95 | " success = False\n",
96 | " complexity = -1\n",
97 | " \n",
98 | " bleus.append(bleu)\n",
99 | " successes.append(success)\n",
100 | " complexities.append(complexity)\n",
101 | " \n",
102 | "bleus = np.array(bleus)\n",
103 | "successes = np.array(successes)\n",
104 | "complexities = np.array(complexities)\n",
105 | "\n",
106 | "print('Mean BLEU:', np.mean(bleus))\n",
107 | "print('Success rate:', np.mean(successes))\n",
108 | "print('Mean complexity:', np.mean(complexities))"
109 | ]
110 | }
111 | ],
112 | "metadata": {
113 | "kernelspec": {
114 | "display_name": "Python 3",
115 | "language": "python",
116 | "name": "python3"
117 | },
118 | "language_info": {
119 | "codemirror_mode": {
120 | "name": "ipython",
121 | "version": 3
122 | },
123 | "file_extension": ".py",
124 | "mimetype": "text/x-python",
125 | "name": "python",
126 | "nbconvert_exporter": "python",
127 | "pygments_lexer": "ipython3",
128 | "version": "3.7.4"
129 | }
130 | },
131 | "nbformat": 4,
132 | "nbformat_minor": 2
133 | }
134 |
--------------------------------------------------------------------------------
/notebooks/resnet.py:
--------------------------------------------------------------------------------
1 | '''ResNet in PyTorch.
2 |
3 | For Pre-activation ResNet, see 'preact_resnet.py'.
4 |
5 | Reference:
6 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
7 | Deep Residual Learning for Image Recognition. arXiv:1512.03385
8 | '''
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 |
13 |
class BasicBlock(nn.Module):
    """
    Residual block made of two 3x3 convolutions, each followed by batch normalization.
    The block input is added to the output through `shortcut` before the final ReLU.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        """
        :param in_planes: number of input channels
        :param planes: number of channels produced by both convolutions
        :param stride: stride of the first convolution (and of the projection shortcut, if any)
        """
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # identity shortcut unless the spatial size or the channel count changes
        if stride != 1 or in_planes != out_planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        return F.relu(hidden + self.shortcut(x))
37 |
38 |
class Bottleneck(nn.Module):
    """
    Residual block made of 1x1, 3x3 and 1x1 convolutions, each followed by batch normalization.
    The last convolution outputs expansion * planes channels; the block input is added
    through `shortcut` before the final ReLU.
    """
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        """
        :param in_planes: number of input channels
        :param planes: number of channels of the two inner convolutions
        :param stride: stride of the 3x3 convolution (and of the projection shortcut, if any)
        """
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        # identity shortcut unless the spatial size or the channel count changes
        if stride != 1 or in_planes != out_planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = F.relu(self.bn2(self.conv2(hidden)))
        hidden = self.bn3(self.conv3(hidden))
        return F.relu(hidden + self.shortcut(x))
65 |
66 |
class ResNet(nn.Module):
    """
    Convolutional stem followed by four stages of residual blocks and a linear classifier.
    NOTE(review): forward pools with a fixed 4x4 window before the linear layer, so it presumably
    expects 32x32 inputs (three stride-2 stages give a 4x4 feature map) - confirm against callers.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        """
        :param block: block constructor called as block(in_planes, planes, stride); must expose `expansion`
        :param num_blocks: number of blocks in each of the four stages
        :param num_classes: size of the classifier output
        """
        super().__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        # (planes, stride of the first block) for layer1 .. layer4
        stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
        for stage_index, (planes, stride) in enumerate(stage_specs):
            stage = self._make_layer(block, planes, num_blocks[stage_index], stride=stride)
            setattr(self, 'layer{}'.format(stage_index + 1), stage)

        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """ Stacks num_blocks blocks; only the first one uses `stride`, the rest use stride 1 """
        stage_blocks = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            stage_blocks.append(block(self.in_planes, planes, block_stride))
            # the next block consumes what this one produced
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage_blocks)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
        out = F.avg_pool2d(out, 4)
        flat = out.view(out.size(0), -1)
        return self.linear(flat)
98 |
99 |
def ResNet18(**kwargs):
    """
    Builds a ResNet out of BasicBlock with [2, 2, 2, 2] blocks per stage
    :param kwargs: forwarded to the ResNet constructor
    """
    blocks_per_stage = [2, 2, 2, 2]
    return ResNet(BasicBlock, blocks_per_stage, **kwargs)
102 |
def ResNet34(**kwargs):
    """
    Builds a ResNet out of BasicBlock with [3, 4, 6, 3] blocks per stage
    :param kwargs: forwarded to the ResNet constructor (e.g. num_classes), the same way ResNet18 does
    """
    return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
105 |
def ResNet50(**kwargs):
    """
    Builds a ResNet out of Bottleneck with [3, 4, 6, 3] blocks per stage
    :param kwargs: forwarded to the ResNet constructor (e.g. num_classes), the same way ResNet18 does
    """
    return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
108 |
def ResNet101(**kwargs):
    """
    Builds a ResNet out of Bottleneck with [3, 4, 23, 3] blocks per stage
    :param kwargs: forwarded to the ResNet constructor (e.g. num_classes), the same way ResNet18 does
    """
    return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
111 |
def ResNet152(**kwargs):
    """
    Builds a ResNet out of Bottleneck with [3, 8, 36, 3] blocks per stage
    :param kwargs: forwarded to the ResNet constructor (e.g. num_classes), the same way ResNet18 does
    """
    return ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
114 |
115 |
def test():
    """ Smoke test: runs a default ResNet18 on one random 3x32x32 image and prints the output size """
    model = ResNet18()
    sample = torch.randn(1, 3, 32, 32)
    output = model(sample)
    print(output.size())
120 |
121 | # test()
122 |
--------------------------------------------------------------------------------
/notebooks/imagenet_preprocess_logits.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "__This notebook__ computes teacher logits for editable fine-tuning on Imagenet"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": 1,
13 | "metadata": {},
14 | "outputs": [
15 | {
16 | "name": "stdout",
17 | "output_type": "stream",
18 | "text": [
19 | "env: CUDA_VISIBLE_DEVICES=5\n"
20 | ]
21 | }
22 | ],
23 | "source": [
24 | "%load_ext autoreload\n",
25 | "%autoreload 2\n",
26 | "%env CUDA_VISIBLE_DEVICES=YOURDEVICEHERE\n",
27 | "\n",
28 | "imagenet_train_path = '../../imagenet/train' # path to train ImageFolder\n",
29 | "logits_path = './imagenet_logits/' # saves logits to this path\n",
30 | "\n",
31 | "import os, sys, time\n",
32 | "sys.path.insert(0, '..')\n",
33 | "import lib\n",
34 | "\n",
35 | "import numpy as np\n",
36 | "import torch, torch.nn as nn\n",
37 | "import torch.nn.functional as F\n",
38 | "import matplotlib.pyplot as plt\n",
39 | "%matplotlib inline\n",
40 | "\n",
41 | "import random\n",
42 | "random.seed(42)\n",
43 | "np.random.seed(42)\n",
44 | "torch.random.manual_seed(42)\n",
45 | "\n",
46 | "import time\n",
47 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n"
48 | ]
49 | },
50 | {
51 | "cell_type": "code",
52 | "execution_count": 2,
53 | "metadata": {},
54 | "outputs": [],
55 | "source": [
56 | "import torchvision.transforms as transforms\n",
57 | "import torchvision.datasets as datasets\n",
58 | "import torchvision.models as models\n",
59 | "\n",
60 | "normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],\n",
61 | " std=[0.229, 0.224, 0.225])\n",
62 | "\n",
63 | "train_dataset = datasets.ImageFolder(\n",
64 | " imagenet_train_path,\n",
65 | " transforms.Compose([\n",
66 | " transforms.RandomResizedCrop(224),\n",
67 | " transforms.RandomHorizontalFlip(),\n",
68 | " transforms.ToTensor(),\n",
69 | " normalize,\n",
70 | " ]))\n",
71 | "\n",
72 | "batch_size = 64\n",
73 | "\n",
74 | "train_loader = torch.utils.data.DataLoader(\n",
75 | " train_dataset, batch_size=batch_size, shuffle=False,\n",
76 | " num_workers=1, pin_memory=True)"
77 | ]
78 | },
79 | {
80 | "cell_type": "code",
81 | "execution_count": null,
82 | "metadata": {},
83 | "outputs": [],
84 | "source": [
85 | "from tqdm import tqdm_notebook, tnrange\n",
86 | "from IPython.display import clear_output\n",
87 | "\n",
88 | "import torchvision\n",
89 | "\n",
90 | "image_index = 0\n",
91 | "\n",
92 | "def get_logits(model, X_test):\n",
93 | " with lib.training_mode(model, is_train=False):\n",
94 | " return lib.eval_func(lib.Lambda(lambda x: model(x.to(device))),\n",
95 | " X_test, device='cuda', batch_size=64)\n",
96 | "\n",
97 | "model = torchvision.models.resnet18(pretrained=True).to(device)\n",
98 | "\n",
99 | "for batch in tqdm_notebook(train_loader):\n",
100 | " logits = get_logits(model, batch[0])\n",
101 | " for logit in logits:\n",
102 | " image_name = train_loader.dataset.imgs[image_index][0]\n",
103 | " image_name = \"{}/{}\".format(logits_path, image_name[image_name.rfind(\"/\") + 1:image_name.find(\".j\")])\n",
104 | " np.save(image_name, torch.softmax(logits, dim=-1))\n",
105 | " image_index += 1"
106 | ]
107 | },
108 | {
109 | "cell_type": "code",
110 | "execution_count": 5,
111 | "metadata": {},
112 | "outputs": [
113 | {
114 | "name": "stdout",
115 | "output_type": "stream",
116 | "text": [
117 | "9.9G\timagenet_logits\r\n"
118 | ]
119 | }
120 | ],
121 | "source": [
122 | "!du -h imagenet_logits"
123 | ]
124 | }
125 | ],
126 | "metadata": {
127 | "kernelspec": {
128 | "display_name": "Python 3",
129 | "language": "python",
130 | "name": "python3"
131 | },
132 | "language_info": {
133 | "codemirror_mode": {
134 | "name": "ipython",
135 | "version": 3
136 | },
137 | "file_extension": ".py",
138 | "mimetype": "text/x-python",
139 | "name": "python",
140 | "nbconvert_exporter": "python",
141 | "pygments_lexer": "ipython3",
142 | "version": "3.6.8"
143 | }
144 | },
145 | "nbformat": 4,
146 | "nbformat_minor": 2
147 | }
148 |
--------------------------------------------------------------------------------
/lib/utils/data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import hashlib
4 | import numpy as np
5 | import requests
6 | from tqdm import tqdm
7 | import torch
8 |
9 |
def download(url, filename, delete_if_interrupted=True, chunk_size=4096):
    """
    Saves file from url to filename with a fancy progressbar
    :param url: address to download from
    :param filename: local path the content is written to (opened in "wb" mode)
    :param delete_if_interrupted: if True, a partially written file is removed when the download fails
        or is interrupted (including KeyboardInterrupt)
    :param chunk_size: number of bytes requested per chunk when streaming the response
    :returns: filename
    :raises: re-raises whatever interrupted the download, including HTTP error statuses
    """
    try:
        with open(filename, "wb") as f:
            print("Downloading {} > {}".format(url, filename))
            # the context manager releases the connection even if writing fails midway
            with requests.get(url, stream=True) as response:
                # fail loudly on 4xx/5xx instead of saving the error page as if it were the file
                response.raise_for_status()
                total_length = response.headers.get('content-length')

                if total_length is None:  # no content length header
                    f.write(response.content)
                else:
                    with tqdm(total=int(total_length)) as progressbar:
                        for data in response.iter_content(chunk_size=chunk_size):
                            if data:  # filter-out keep-alive chunks
                                f.write(data)
                                progressbar.update(len(data))
    except BaseException:
        # BaseException so that Ctrl+C also triggers the cleanup;
        # the existence check keeps a failed open() from being masked by a failing os.remove()
        if delete_if_interrupted and os.path.exists(filename):
            print("Removing incomplete download {}.".format(filename))
            os.remove(filename)
        raise
    return filename
33 |
34 |
def iterate_minibatches(*tensors, batch_size, shuffle=True, epochs=1,
                        allow_incomplete=True, callback=lambda x: x):
    """
    Generator that slices all tensors into aligned minibatches, for one or several epochs.
    :param tensors: tensors to iterate over, all indexed along their first dimension;
        works with np arrays, torch tensors and anything that can be indexed via tensor[[i0, i1, i2]]
    :param batch_size: maximum number of objects in one yield
    :param shuffle: if True, the order of objects is reshuffled (np.random.shuffle) before every epoch
    :param epochs: number of full passes over the data; at least one pass is always made
    :param allow_incomplete: if dataset_size is not divisible by batch_size,
        if True, the last batch in epoch will have less than batch_size objects
        if False, the last batch in epoch will be omitted
    :param callback: a wrapper applied to the range of batch starts in every epoch, e.g. tqdm
    :returns: yields a list of batches (one per tensor), or the batch itself if a single tensor was given
    """
    dataset_size = len(tensors[0])
    order = np.arange(dataset_size)
    rounding = np.ceil if allow_incomplete else np.floor
    stop = int(rounding(dataset_size / batch_size)) * batch_size
    single_tensor = not len(tensors) > 1

    epochs_done = 0
    while True:
        if shuffle:
            np.random.shuffle(order)
        for start in callback(range(0, stop, batch_size)):
            if start == dataset_size:
                break
            chosen = order[start: start + batch_size]
            batch = [tensor[chosen] for tensor in tensors]
            yield batch[0] if single_tensor else batch
        epochs_done += 1
        if epochs_done >= epochs:
            break
64 |
65 |
def process_in_chunks(function, *args, batch_size, out=None, **kwargs):
    """
    Applies a batch-parallel function to large tensors chunk by chunk and gathers the results
    :param function: a function(*[x[indices, ...] for x in args]) -> out[indices, ...]
    :param args: one or many tensors, each [num_instances, ...]
    :param batch_size: maximum chunk size processed in one go
    :param out: memory buffer for out, defaults to torch.zeros with the shape tail, dtype, device
        and layout of the first chunk's output
    :param kwargs: extra keyword arguments for torch.zeros, used only when out is not given
    :returns: out, filled with function outputs for all chunks
    """
    num_instances = args[0].shape[0]

    # the first chunk is computed up front because it defines the output buffer
    head_ix = slice(0, batch_size)
    head_output = function(*(x[head_ix] for x in args))
    buffer_shape = (num_instances,) + tuple(head_output.shape[1:])
    if out is None:
        out = torch.zeros(*buffer_shape, dtype=head_output.dtype, device=head_output.device,
                          layout=head_output.layout, **kwargs)
    out[head_ix] = head_output

    for chunk_start in range(batch_size, num_instances, batch_size):
        chunk_ix = slice(chunk_start, min(chunk_start + batch_size, num_instances))
        out[chunk_ix] = function(*(x[chunk_ix] for x in args))
    return out
87 |
88 |
def check_numpy(x):
    """ Converts x to a numpy array; torch tensors are detached and moved to cpu first """
    source = x.detach().cpu().numpy() if isinstance(x, torch.Tensor) else x
    array = np.asarray(source)
    assert isinstance(array, np.ndarray)
    return array
96 |
97 |
def get_latest_file(pattern):
    """
    Returns the path with the largest os.path.getctime among files matching a glob pattern
    :param pattern: glob pattern, e.g. * for everything or *.csv for a specific format
    """
    candidates = glob.glob(pattern)
    num_found = len(candidates)
    assert num_found > 0, "No files found: " + pattern
    latest = max(candidates, key=os.path.getctime)
    return latest
102 |
103 |
def md5sum(fname):
    """ Computes md5 checksum of a file, reading it in 4096-byte chunks; returns the hex digest """
    digest = hashlib.md5()
    with open(fname, "rb") as stream:
        while True:
            chunk = stream.read(4096)
            if chunk == b"":
                break
            digest.update(chunk)
    return digest.hexdigest()
111 |
--------------------------------------------------------------------------------
/mt/generate_edit_datasets_samples.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import sys\n",
10 | "import os\n",
11 | "from collections import defaultdict\n",
12 | "import random\n",
13 | "import numpy as np\n",
14 | "\n",
15 | "np.random.seed(1234)\n",
16 | "random.seed(1234)\n",
17 | "\n",
18 | "%env CUDA_VISIBLE_DEVICES=0\n",
19 | "\n",
20 | "\n",
21 | "src_lang, dst_lang = 'de', 'en'\n",
22 | "checkpoint_path = 'baseline_checkpoint.pt'\n",
23 | "data_path = 'data-bin/iwslt14.tokenized.{}-{}/'.format(src_lang, dst_lang)\n",
24 | "output_path = 'edit_iwslt14.tokenized.{}-{}'.format(src_lang, dst_lang)\n",
25 | "beam = nbest = 32\n",
26 | "max_tokens = 1024\n",
27 | "keys = ['valid', 'test', 'train']\n",
28 | "\n",
29 | "\n",
30 | "sys.path.append('fairseq')"
31 | ]
32 | },
33 | {
34 | "cell_type": "code",
35 | "execution_count": null,
36 | "metadata": {},
37 | "outputs": [],
38 | "source": [
39 | "if os.path.isdir(output_path):\n",
40 | " if len(os.listdir(output_path)) > 0:\n",
41 | " raise ValueError('Output directory {} is not empty'.format(output_path))\n",
42 | "else:\n",
43 | " os.makedirs(output_path, exist_ok=True)"
44 | ]
45 | },
46 | {
47 | "cell_type": "code",
48 | "execution_count": null,
49 | "metadata": {},
50 | "outputs": [],
51 | "source": [
52 | "if not os.path.isfile(checkpoint_path):\n",
53 | " ! wget https://www.dropbox.com/s/iksezig9qi4g92e/baseline_checkpoint.pt?dl=1 -O {checkpoint_path}"
54 | ]
55 | },
56 | {
57 | "cell_type": "code",
58 | "execution_count": null,
59 | "metadata": {},
60 | "outputs": [],
61 | "source": [
62 | "from fairseq.data.dictionary import Dictionary\n",
63 | "\n",
64 | "src_voc = Dictionary.load(os.path.join(data_path, 'dict.{}.txt'.format(src_lang)))\n",
65 | "dst_voc = Dictionary.load(os.path.join(data_path, 'dict.{}.txt'.format(dst_lang)))"
66 | ]
67 | },
68 | {
69 | "cell_type": "code",
70 | "execution_count": null,
71 | "metadata": {},
72 | "outputs": [],
73 | "source": [
74 | "def generate_alternatives(key, file):\n",
75 | " !fairseq-generate {data_path} --path {checkpoint_path} \\\n",
76 | " --beam {beam} --nbest {nbest} \\\n",
77 | " --max-tokens {max_tokens} \\\n",
78 | " --gen-subset {key} > {file}"
79 | ]
80 | },
81 | {
82 | "cell_type": "code",
83 | "execution_count": null,
84 | "metadata": {},
85 | "outputs": [],
86 | "source": [
87 | "def generate_edits(key, file):\n",
88 | " !fairseq-generate {data_path} --path {checkpoint_path} \\\n",
89 | " --beam {beam} --nbest {nbest} \\\n",
90 | " --max-tokens {max_tokens} \\\n",
91 | " --sampling --temperature 1.2 \\\n",
92 | " --gen-subset {key} > {file}"
93 | ]
94 | },
95 | {
96 | "cell_type": "code",
97 | "execution_count": null,
98 | "metadata": {},
99 | "outputs": [],
100 | "source": [
101 | "for key in keys:\n",
102 | " tmp_alternatives_output_file_path = os.path.join(output_path, 'tmp_alternaatives_{}.txt'.format(key))\n",
103 | " tmp_edits_output_file_path = os.path.join(output_path, 'tmp_edits_{}.txt'.format(key))\n",
104 | " bpe_output_file_path = os.path.join(output_path, 'bpe_{}.txt'.format(key))\n",
105 | " \n",
106 | " print('Start generating {} alternatives'.format(key))\n",
107 | " generate_alternatives(key, tmp_alternatives_output_file_path)\n",
108 | " print('Start generating {} edits'.format(key))\n",
109 | " generate_edits(key, tmp_edits_output_file_path)\n",
110 | "\n",
111 | " logs = defaultdict(list)\n",
112 | "\n",
113 | " print('Start parsing the alternatives beam search output')\n",
114 | "\n",
115 | " with open(tmp_alternatives_output_file_path) as f_in:\n",
116 | " for line in f_in:\n",
117 | " try:\n",
118 | " tag, *payload = line.split('\\t')\n",
119 | " logs[tag].append(payload)\n",
120 | " except: pass\n",
121 | " \n",
122 | " print('Start parsing the edits beam search output')\n",
123 | " \n",
124 | " edit_logs = defaultdict(list)\n",
125 | "\n",
126 | " with open(tmp_edits_output_file_path) as f_in:\n",
127 | " for line in f_in:\n",
128 | " try:\n",
129 | " tag, *payload = line.split('\\t')\n",
130 | " edit_logs[tag].append(payload)\n",
131 | " except: pass\n",
132 | " \n",
133 | " print('Start edit samples generating')\n",
134 | "\n",
135 | " with open(bpe_output_file_path, 'w') as f_out:\n",
136 | " i = 0\n",
137 | " while True:\n",
138 | " i += 1\n",
139 | " try:\n",
140 | " source = logs['S-{}'.format(i)][0][0].strip()\n",
141 | " edit_source = edit_logs['S-{}'.format(i)][0][0].strip()\n",
142 | " except:\n",
143 | " break\n",
144 | " hypos = [hypo.strip() for prob, hypo in logs['H-{}'.format(i)]]\n",
145 | " \n",
146 | " edits = [edit.strip() for prob, edit in edit_logs['H-{}'.format(i)]]\n",
147 | " edit = random.choice(edits)\n",
148 | "\n",
149 | " np.random.shuffle(hypos)\n",
150 | " f_out.write('{}\\t{}\\t{}\\n'.format(source, edit, '\\t'.join(hypos)))\n",
151 | " \n",
152 | " print('_'*100)\n",
153 | " print('\\n'*3)"
154 | ]
155 | }
156 | ],
157 | "metadata": {
158 | "kernelspec": {
159 | "display_name": "Python 3",
160 | "language": "python",
161 | "name": "python3"
162 | },
163 | "language_info": {
164 | "codemirror_mode": {
165 | "name": "ipython",
166 | "version": 3
167 | },
168 | "file_extension": ".py",
169 | "mimetype": "text/x-python",
170 | "name": "python",
171 | "nbconvert_exporter": "python",
172 | "pygments_lexer": "ipython3",
173 | "version": "3.7.4"
174 | }
175 | },
176 | "nbformat": 4,
177 | "nbformat_minor": 2
178 | }
179 |
--------------------------------------------------------------------------------
/lib/utils/basic.py:
--------------------------------------------------------------------------------
1 | import contextlib
2 | import gc
3 | import os
4 | import time
5 | from collections import Counter
6 | from itertools import chain
7 |
8 | import torch
9 | from torch import nn as nn
10 |
11 |
def to_one_hot(y, depth=None):
    r"""
    Converts an integer tensor with n dims into its 1-hot representation with n + 1 dims.
    The last dimension is zero everywhere except at the index given by y, where it equals 1.
    Args:
        y: input integer tensor of any shape
        depth (int): the size of the one hot dimension, defaults to max(y) + 1
    """
    flat_indices = y.to(torch.int64).view(-1, 1)
    if depth is None:
        depth = int(torch.max(flat_indices)) + 1

    encoded = torch.zeros(flat_indices.size()[0], depth, device=y.device)
    encoded.scatter_(1, flat_indices, 1)
    return encoded.view(*y.shape, -1)
25 |
26 |
27 | def dot(x, y):
28 | """ numpy-like dot product """
29 | out_flat = x.view(-1, x.shape[-1]) @ y.view(y.shape[0], -1)
30 | return out_flat.view(*x.shape[:-1], *y.shape[1:])
31 |
32 |
def batch_outer_sum(*tensors):
    """
    :param tensors: each matrix should have shape [..., d_i]
    :returns: [..., d_0, d_1, ..., d_N] where N = len(tensors)
        output[..., i, j, k] = tensors[0][..., i] + tensors[1][..., j] + tensors[2][..., k]
        returns None if no tensors are given
    """
    num_tensors = len(tensors)
    total = None
    for position, tensor in enumerate(tensors):
        # keep this tensor's last axis at `position`, add singleton axes for all the other tensors
        index = (
            (Ellipsis,)
            + (None,) * position
            + (slice(tensor.shape[-1]),)
            + (None,) * (num_tensors - position - 1)
        )
        expanded = tensor[index]
        total = expanded if position == 0 else total + expanded
    return total
46 |
47 |
def batch_outer_product(*tensors):
    """
    :param tensors: each matrix should have shape [..., d_i]
    :returns: [..., d_0, d_1, ..., d_N] where N = len(tensors)
        output[..., i, j, k] = tensors[0][..., i] * tensors[1][..., j] * tensors[2][..., k]
    """
    num_prefix_dims = len(tensors[0].shape[:-1])
    # every axis needs its own lowercase letter in the einsum equation
    assert len(tensors) + num_prefix_dims <= ord('z') - ord('a')

    letters = [chr(ord('a') + offset) for offset in range(num_prefix_dims + len(tensors))]
    prefix_chars = ''.join(letters[:num_prefix_dims])
    dim_chars = ''.join(letters[num_prefix_dims:])

    operands = ','.join(prefix_chars + dim_char for dim_char in dim_chars)
    equation = "{}->{}".format(operands, prefix_chars + dim_chars)
    return torch.einsum(equation, *tensors)
62 |
63 |
def straight_through_grad(function, **kwargs):
    """
    Wraps function so that its outputs are used in the forward pass, while gradients flow to the inputs
    as if the function was an identity
    :param function: callable(*inputs) -> *outputs, number and shape of outputs must match that of inputs,
    :param kwargs: keyword arguments that will be sent to each function call
    :returns: a callable(*inputs) with the same output structure as function
    """
    def f_straight_through(*inputs):
        raw = function(*inputs, **kwargs)
        is_single = isinstance(raw, torch.Tensor)
        outputs = [raw] if is_single else raw

        assert isinstance(outputs, (list, tuple)) and len(outputs) == len(inputs)
        container = type(outputs)
        # value equals the output, but the only differentiable term is the input itself
        patched = container(
            source + (result - source).detach()
            for source, result in zip(inputs, outputs)
        )
        return patched[0] if is_single else patched

    return f_straight_through
84 |
85 |
def nop(x):
    """ Identity function: returns its argument unchanged """
    result = x
    return result
88 |
89 |
@contextlib.contextmanager
def nop_ctx():
    """ Context manager that does nothing on enter or exit; `as` binds None """
    yield
93 |
94 |
class Nop(nn.Module):
    """ Module that returns its input unchanged """

    def forward(self, x):
        output = x
        return output
98 |
99 |
class Residual(nn.Sequential):
    """ nn.Sequential whose output is added to its input """

    def forward(self, x):
        transformed = super().forward(x)
        return transformed + x
103 |
104 |
class Flatten(nn.Module):
    """ Reshapes input to [len(x), -1], merging everything after the first dimension """

    def forward(self, x):
        num_rows = len(x)
        return x.view(num_rows, -1)
108 |
109 |
@contextlib.contextmanager
def training_mode(*modules, is_train:bool):
    """
    Temporarily switches modules (and all their submodules) to train or eval mode,
    then restores the `training` flag each of them had before entering the context
    :param modules: one or several nn.Module instances
    :param is_train: True for train mode, False for eval mode
    :returns: yields an nn.ModuleList with the given modules, already switched to the requested mode
    :raises ValueError: if a submodule that did not exist on enter is found on exit
    """
    group = nn.ModuleList(modules)
    previous_flags = {}
    for module in group.modules():
        previous_flags[module] = module.training
    try:
        yield group.train(is_train)
    finally:
        for key, module in group.named_modules():
            if module not in previous_flags:
                raise ValueError("Model was modified inside training_mode(...) context, could not find {}".format(key))
            module.training = previous_flags[module]
122 |
123 |
def free_memory(sleep_time=0.1):
    """
    Black magic function to free torch memory and some jupyter whims.
    Runs the garbage collector, waits for pending CUDA work and empties the CUDA cache;
    the CUDA steps are skipped when CUDA is not available, so this is safe on CPU-only machines.
    :param sleep_time: seconds to sleep after the cleanup
    """
    has_cuda = torch.cuda.is_available()
    gc.collect()
    if has_cuda:
        # torch.cuda.synchronize raises without a CUDA device, hence the guard
        torch.cuda.synchronize()
    gc.collect()
    if has_cuda:
        torch.cuda.empty_cache()
    time.sleep(sleep_time)
131 |
132 |
def infer_model_device(model: nn.Module):
    """
    Infers model device as the device that holds the largest number of
    parameter and buffer tensors (each tensor counts once, regardless of its size)
    """
    device_stats = Counter()
    for tensor in chain(model.parameters(), model.buffers()):
        if torch.is_tensor(tensor):
            device_stats[tensor.device] += 1
    return max(device_stats, key=device_stats.get)
140 |
141 |
class Lambda(nn.Module):
    """ Wraps an arbitrary callable into a module """

    def __init__(self, func):
        """ :param func: call this function during forward """
        super().__init__()
        self.func = func

    def forward(self, *args, **kwargs):
        """ Calls func with the given arguments and returns its result """
        func = self.func
        return func(*args, **kwargs)
150 |
151 |
def to_float_str(element):
    """
    Formats element as the string of its float value, e.g. "1" -> "1.0"
    :param element: anything; values that float() cannot convert are returned unchanged
    :returns: str(float(element)) if the conversion succeeds, otherwise element itself
    """
    try:
        return str(float(element))
    except (TypeError, ValueError):
        # float() raises TypeError (not ValueError) for None, lists and other non-numeric objects
        return element
157 |
158 |
def run_from_ipython():
    """ Returns True if the name __IPYTHON__ is defined, False otherwise """
    try:
        __IPYTHON__
    except NameError:
        return False
    else:
        return True
165 |
166 |
# clear_output comes from IPython when available; otherwise a console fallback is defined
if not run_from_ipython():
    def clear_output(*args, **kwargs):
        """ Fallback for plain python: runs the `clear` shell command, arguments are ignored """
        os.system('clear')
else:
    from IPython.display import clear_output
172 |
173 |
class OptimizerList(torch.optim.Optimizer):
    """
    Groups several optimizers behind a single optimizer-like object:
    every call is forwarded to each wrapped optimizer in the order they were given.
    """

    def __init__(self, *optimizers):
        """ :param optimizers: optimizers to be driven together """
        # NOTE(review): torch.optim.Optimizer.__init__ is not called here, so this object only
        # carries `optimizers`; base-class state is presumably meant to live in the wrapped optimizers
        self.optimizers = optimizers

    def step(self, closure=None):
        """
        Performs one step of every optimizer
        :param closure: optional closure, accepted for compatibility with torch.optim.Optimizer.step;
            if given, it is passed to (and hence re-evaluated by) each wrapped optimizer
        :returns: list of the values returned by each optimizer's step
        """
        if closure is None:
            return [opt.step() for opt in self.optimizers]
        return [opt.step(closure) for opt in self.optimizers]

    def zero_grad(self, *args, **kwargs):
        """ Calls zero_grad of every optimizer, forwarding any arguments (e.g. set_to_none) """
        return [opt.zero_grad(*args, **kwargs) for opt in self.optimizers]

    def add_param_group(self, *args, **kwargs):
        raise ValueError("Please call add_param_group in one of self.optimizers")

    def __getstate__(self):
        return [opt.__getstate__() for opt in self.optimizers]

    def __setstate__(self, state):
        return [opt.__setstate__(opt_state) for opt, opt_state in zip(self.optimizers, state)]

    def __repr__(self):
        return repr(self.optimizers)

    def state_dict(self, **kwargs):
        """ :returns: dict with one entry "opt_{i}" per optimizer, holding that optimizer's state_dict """
        return {"opt_{}".format(i): opt.state_dict(**kwargs) for i, opt in enumerate(self.optimizers)}

    def load_state_dict(self, state_dict, **kwargs):
        """ Loads the "opt_{i}" entries produced by state_dict into the corresponding optimizers """
        return [
            opt.load_state_dict(state_dict["opt_{}".format(i)])
            for i, opt in enumerate(self.optimizers)
        ]
204 |
--------------------------------------------------------------------------------
/lib/editable.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | from copy import copy
3 | from itertools import count
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | from .utils.copy_and_replace import do_not_copy, copy_and_replace
9 | from .utils.ingraph_update import IngraphGradientDescent
10 |
11 |
class BaseEditable(nn.Module):
    """ Base class for editable models; its edit() changes nothing and reports failure """

    # Fields of EditResult:
    #   model: a model that was adjusted out-of-place. Must be the same type as self (*Editable)
    #   success: True if edit was successful, False otherwise
    #   loss: objective function at the termination of the edit procedure
    #   complexity: a measure of effort it took to edit model, e.g. number of SGD steps
    EditResult = namedtuple('EditResult', 'model success loss complexity')

    def edit(self, *data):
        """ Subclasses should perform editing without changing the current model and return EditResult """
        return self.EditResult(model=self, success=False, loss=0.0, complexity=0.0)
22 |
23 |
24 | class Editable(BaseEditable):
25 |
def __init__(self, module: nn.Module, loss_function,
             optimizer=IngraphGradientDescent(0.01), max_steps=float('inf'),
             get_editable_parameters=lambda module: module.parameters(),
             is_edit_finished=lambda loss, **kwargs: loss.item() <= 0,
             ):
    """
    Editable module that attempts to change model by performing SGD (with optional momentum and rms scaling)
    :param module: a torch module that will be edited
    :param loss_function: objective function(model(inputs), targets) that is minimized by editor.
        By default this function should be non-negative and loss == 0 is a trigger to finish editing
    :param optimizer: in-graph optimizer that creates updated copies of model
    :param max_steps: default limit on the number of gradient steps per edit, unlimited by default
    :param get_editable_parameters: a function(Editable.module) that takes the wrapped module and returns
        an iterable of parameters that should be affected by edits, defaults to all parameters of the module
    :param is_edit_finished: a function(loss, prediction, **local variables) that returns True if edit is finished
    """
    super().__init__()
    self.module = module
    self.loss_function = loss_function
    self.optimizer = optimizer
    self.get_editable_parameters = get_editable_parameters
    self.is_edit_finished = is_edit_finished
    self.max_steps = max_steps
46 |
47 | def forward(self, *args, **kwargs):
48 | return self.module(*args, **kwargs)
49 |
50 | def edit(self, inputs, targets, max_steps=None, model_kwargs=None, loss_kwargs=None, opt_kwargs=None, **kwargs):
51 | """
52 | Attempts to edit model (out-of-place) and return an edited copy
53 | :param inputs: data that is fed into the model
54 | :param targets: reference answers that are fed into loss function
55 | :param max_steps: after this many gradient steps the process is terminated
56 | :param model_kwargs: optional extra model inputs, used as model(inputs, **model_params)
57 | :param loss_kwargs: optional extra loss parameters, self.loss_function(model(inputs), targets, **loss_params)
58 | :param opt_kwargs: optional overrides for optimizer.get_initial_state
59 | :param kwargs: extra parameters passed to optimizer.step
60 | :returns: edited_model, is_edit_successful, final_loss, gradients_steps
61 | :rtype: Editable.EditResult
62 | """
63 | model_kwargs, loss_kwargs, opt_kwargs = model_kwargs or {}, loss_kwargs or {}, opt_kwargs or {}
64 | optimizer_state = self.optimizer.get_initial_state(self, **opt_kwargs)
65 | editable = self
66 |
67 | for step in count():
68 | prediction = editable(inputs, **model_kwargs)
69 | loss = self.loss_function(prediction, targets, **loss_kwargs)
70 |
71 | if self.is_edit_finished(**locals()):
72 | return self.EditResult(editable, success=True, loss=loss, complexity=step)
73 | elif step >= (max_steps or self.max_steps):
74 | return self.EditResult(editable, success=False, loss=loss, complexity=step)
75 |
76 | optimizer_state, editable = self.optimizer.step(
77 | optimizer_state, editable, loss, parameters=editable.get_editable_parameters(editable.module), **kwargs)
78 |
79 | def extra_repr(self):
80 | return "max_steps={}, loss_function={}".format(self.max_steps, repr(self.loss_function))
81 |
82 |
class SequentialWithEditable(BaseEditable):
    """A sequential chain of modules in which exactly one element is an editable module."""

    def __init__(self, *args):
        """ A chain of modules with exactly one Editable, edit procedure will only compute pre-editable modules once """
        super().__init__()
        editable_positions = [index for index, layer in enumerate(args) if isinstance(layer, BaseEditable)]
        assert len(editable_positions) <= 1, "SequentialEditable only supports one Editable module for now"
        assert len(editable_positions) == 1, "SequentialEditable must have one Editable at init, got 0"

        (split,) = editable_positions
        head, core, tail = args[:split], args[split], args[split + 1:]

        # Everything before the editable runs as a plain Sequential; everything after it is folded
        # into the editable itself so that edits are evaluated on the final outputs.
        self.prefix_layers = nn.Sequential(*head)
        if tail:
            self.editable = self._editable_with_suffix(core, *tail)
        else:
            self.editable = core

    def forward(self, *args, **kwargs):
        """Run the prefix layers, then the editable part (which includes any suffix layers)."""
        prefix_output = self.prefix_layers(*args, **kwargs)
        return self.editable(prefix_output)

    def edit(self, inputs, *args, **kwargs):
        """
        Edit the editable part on the prefix layers' output and return an EditResult whose model is
        a copy of this chain with the editable replaced by its edited version.
        Extra positional and keyword arguments are passed to the editable's edit().
        """
        prefix_output = self.prefix_layers(inputs)
        inner_result = self.editable.edit(prefix_output, *args, **kwargs)

        # The prefix module, its parameters and its buffers are excluded from copying
        shared = (self.prefix_layers, *self.prefix_layers.parameters(), *self.prefix_layers.buffers())
        with do_not_copy(*shared):
            edited_model = copy_and_replace(self, replace={self.editable: inner_result.model})

        # Keep every field of the inner result except the model, which is swapped for the full chain
        _, *remaining_fields = inner_result
        return self.EditResult(edited_model, *remaining_fields)

    @staticmethod
    def _editable_with_suffix(base_editable: Editable, *suffix):
        """Return a copy of base_editable whose module is followed by the suffix layers."""
        # NOTE(review): copy() is shallow, so the copy may share nn.Module internals (e.g. the
        # submodule registry) with base_editable - confirm base_editable is not reused afterwards.
        extended = copy(base_editable)
        extended.module = nn.Sequential(base_editable.module, *suffix)
        # Only the original module (element 0 of the new Sequential) stays editable
        extended.get_editable_parameters = lambda module: base_editable.get_editable_parameters(module[0])
        return extended
116 |
117 |
class RehearsalEditable(Editable):
    """
    Editable whose optimizer steps minimize the edit loss plus a rehearsal term: the KL divergence between
    the model's predictions on rehearsal inputs before the edit and its current predictions on them.
    """

    def __init__(self, *args, rehearsal_loss_weight=1.0, get_rehearsals, **kwargs):
        """
        :param args: positional arguments passed to Editable.__init__
        :param rehearsal_loss_weight: multiplier for the rehearsal term in the optimized loss
        :param get_rehearsals: a function(inputs) that returns a batch of rehearsal inputs for the model
        :param kwargs: keyword arguments passed to Editable.__init__
        """
        super().__init__(*args, **kwargs)
        self.rehearsal_loss_weight = rehearsal_loss_weight
        self.get_rehearsals = get_rehearsals

    def edit(self, inputs, targets, max_steps=None, model_kwargs=None, loss_kwargs=None, opt_kwargs=None, **kwargs):
        """
        Attempts to edit model (out-of-place) and return an edited copy
        :param inputs: data that is fed into the model
        :param targets: reference answers that are fed into loss function
        :param max_steps: after this many gradient steps the process is terminated;
            None (default) means use self.max_steps, an explicit 0 means "take no steps"
        :param model_kwargs: optional extra model inputs, used as model(inputs, **model_kwargs)
        :param loss_kwargs: optional extra loss parameters, self.loss_function(model(inputs), targets, **loss_kwargs)
        :param opt_kwargs: optional overrides for optimizer.get_initial_state
        :param kwargs: extra parameters passed to optimizer.step
        :returns: edited_model, is_edit_successful, final_loss, gradients_steps; the reported loss is
            the edit loss alone, without the rehearsal term
        :rtype: Editable.EditResult
        """
        model_kwargs, loss_kwargs, opt_kwargs = model_kwargs or {}, loss_kwargs or {}, opt_kwargs or {}
        optimizer_state = self.optimizer.get_initial_state(self, **opt_kwargs)
        editable = self

        # Reference distribution: predictions of the not-yet-edited model, detached from the graph
        X_batch = self.get_rehearsals(inputs)
        vanilla_probs = F.softmax(editable(X_batch, **model_kwargs), dim=-1).detach()
        for step in count():
            prediction = editable(inputs, **model_kwargs)

            loss = self.loss_function(prediction, targets, **loss_kwargs)

            # All local variables are forwarded to the callback by name; do not rename locals in this method.
            if self.is_edit_finished(**locals()):
                return self.EditResult(editable, success=True, loss=loss, complexity=step)
            # Compare against None explicitly: `max_steps or self.max_steps` would silently ignore max_steps=0.
            elif step >= (self.max_steps if max_steps is None else max_steps):
                return self.EditResult(editable, success=False, loss=loss, complexity=step)

            current_logp = F.log_softmax(editable(X_batch, **model_kwargs), dim=-1)
            # NOTE(review): F.kl_div is called with its default reduction; confirm that the resulting
            # scale is the one rehearsal_loss_weight was tuned for before changing it.
            batch_loss = F.kl_div(current_logp, vanilla_probs)
            total_loss = loss + self.rehearsal_loss_weight * batch_loss

            # presumably keeps get_rehearsals from being duplicated while the optimizer builds
            # the updated model - verify against do_not_copy
            with do_not_copy(self.get_rehearsals):
                optimizer_state, editable = self.optimizer.step(
                    optimizer_state, editable, total_loss, parameters=editable.get_editable_parameters(editable.module), **kwargs)
148 |
--------------------------------------------------------------------------------
/notebooks/cifar10_editable_layer3.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "__This notebook__ trains resnet18 from scratch on CIFAR10 dataset."
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": 1,
13 | "metadata": {},
14 | "outputs": [
15 | {
16 | "name": "stdout",
17 | "output_type": "stream",
18 | "text": [
19 | "env: CUDA_VISIBLE_DEVICES=1\n",
20 | "editable_layer3_2019.09.19_23:06:14\n",
21 | "PyTorch version: 1.1.0\n"
22 | ]
23 | }
24 | ],
25 | "source": [
26 | "%load_ext autoreload\n",
27 | "%autoreload 2\n",
28 | "%env CUDA_VISIBLE_DEVICES=YOURDEVICEHERE\n",
29 | "import os, sys, time\n",
30 | "sys.path.insert(0, '..')\n",
31 | "import lib\n",
32 | "\n",
33 | "import numpy as np\n",
34 | "import torch, torch.nn as nn\n",
35 | "import torch.nn.functional as F\n",
36 | "import matplotlib.pyplot as plt\n",
37 | "%matplotlib inline\n",
38 | "\n",
39 | "import random\n",
40 | "random.seed(42)\n",
41 | "np.random.seed(42)\n",
42 | "torch.random.manual_seed(42)\n",
43 | "\n",
44 | "import time\n",
45 | "from resnet import ResNet18\n",
46 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
47 | "\n",
48 | "experiment_name = 'editable_layer3'\n",
49 | "experiment_name = '{}_{}.{:0>2d}.{:0>2d}_{:0>2d}:{:0>2d}:{:0>2d}'.format(experiment_name, *time.gmtime()[:6])\n",
50 | "print(experiment_name)\n",
51 | "print(\"PyTorch version:\", torch.__version__)"
52 | ]
53 | },
54 | {
55 | "cell_type": "code",
56 | "execution_count": 2,
57 | "metadata": {},
58 | "outputs": [
59 | {
60 | "name": "stdout",
61 | "output_type": "stream",
62 | "text": [
63 | "Files already downloaded and verified\n",
64 | "Files already downloaded and verified\n"
65 | ]
66 | }
67 | ],
68 | "source": [
69 | "from torchvision import transforms, datasets\n",
70 | "\n",
71 | "transform_train = transforms.Compose([\n",
72 | " transforms.RandomCrop(32, padding=4),\n",
73 | " transforms.RandomHorizontalFlip(),\n",
74 | " transforms.ToTensor(),\n",
75 | " transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n",
76 | "])\n",
77 | "\n",
78 | "transform_test = transforms.Compose([\n",
79 | " transforms.ToTensor(),\n",
80 | " transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n",
81 | "])\n",
82 | "\n",
83 | "trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)\n",
84 | "trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)\n",
85 | "\n",
86 | "testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)\n",
87 | "testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)\n",
88 | "\n",
89 | "classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')\n",
90 | "X_test, y_test = map(torch.cat, zip(*list(testloader)))"
91 | ]
92 | },
93 | {
94 | "cell_type": "code",
95 | "execution_count": 3,
96 | "metadata": {},
97 | "outputs": [],
98 | "source": [
99 | "model = lib.Editable(\n",
100 | " module=ResNet18(), loss_function=lib.contrastive_cross_entropy,\n",
101 | " get_editable_parameters=lambda module: module.layer3.parameters(),\n",
102 | " optimizer=lib.IngraphRMSProp(\n",
103 | " learning_rate=1e-3, beta=nn.Parameter(torch.tensor(0.5, dtype=torch.float32)), \n",
104 | " ), max_steps=10,\n",
105 | "\n",
106 | ").to(device)\n",
107 | "\n",
108 | "trainer = lib.EditableTrainer(model, F.cross_entropy, experiment_name=experiment_name, max_norm=10)\n",
109 | "trainer.writer.add_text(\"trainer\", repr(trainer).replace('\\n', '
'))"
110 | ]
111 | },
112 | {
113 | "cell_type": "code",
114 | "execution_count": null,
115 | "metadata": {},
116 | "outputs": [
117 | {
118 | "data": {
119 | "application/vnd.jupyter.widget-view+json": {
120 | "model_id": "af4871f821da4d80b68dd4841955e2ac",
121 | "version_major": 2,
122 | "version_minor": 0
123 | },
124 | "text/html": [
125 | "
Failed to display Jupyter Widget of type HBox.
\n", 127 | " If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n", 128 | " that the widgets JavaScript is still loading. If this message persists, it\n", 129 | " likely means that the widgets JavaScript library is either not installed or\n", 130 | " not enabled. See the Jupyter\n", 131 | " Widgets Documentation for setup instructions.\n", 132 | "
\n", 133 | "\n", 134 | " If you're reading this message in another frontend (for example, a static\n", 135 | " rendering on GitHub or NBViewer),\n", 136 | " it may mean that your frontend doesn't currently support widgets.\n", 137 | "
\n" 138 | ], 139 | "text/plain": [ 140 | "HBox(children=(IntProgress(value=0, max=391), HTML(value='')))" 141 | ] 142 | }, 143 | "metadata": {}, 144 | "output_type": "display_data" 145 | } 146 | ], 147 | "source": [ 148 | "from tqdm import tqdm_notebook, tnrange\n", 149 | "from IPython.display import clear_output\n", 150 | "\n", 151 | "val_metrics = trainer.evaluate_metrics(X_test.to(device), y_test.to(device))\n", 152 | "min_error, min_drawdown = val_metrics['base_error'], val_metrics['drawdown']\n", 153 | "early_stopping_epochs = 500\n", 154 | "number_of_epochs_without_improvement = 0\n", 155 | "\n", 156 | "def edit_generator():\n", 157 | " while True:\n", 158 | " for xb, yb in torch.utils.data.DataLoader(trainset, batch_size=1, shuffle=True, num_workers=2):\n", 159 | " yield xb.to(device), torch.randint_like(yb, low=0, high=len(classes), device=device)\n", 160 | "\n", 161 | "edit_generator = edit_generator()\n", 162 | "\n", 163 | "\n", 164 | "while True:\n", 165 | " for x_batch, y_batch in tqdm_notebook(trainloader):\n", 166 | " trainer.step(x_batch.to(device), y_batch.to(device), *next(edit_generator))\n", 167 | " \n", 168 | " val_metrics = trainer.evaluate_metrics(X_test.to(device), y_test.to(device))\n", 169 | " clear_output(True)\n", 170 | " \n", 171 | " error_rate, drawdown = val_metrics['base_error'], val_metrics['drawdown']\n", 172 | " \n", 173 | " number_of_epochs_without_improvement += 1\n", 174 | " \n", 175 | " \n", 176 | " if error_rate < min_error:\n", 177 | " trainer.save_checkpoint(tag='best_val_error')\n", 178 | " min_error = error_rate\n", 179 | " number_of_epochs_without_improvement = 0\n", 180 | " \n", 181 | " if drawdown < min_drawdown:\n", 182 | " trainer.save_checkpoint(tag='best_drawdown')\n", 183 | " min_drawdown = drawdown\n", 184 | " number_of_epochs_without_improvement = 0\n", 185 | " \n", 186 | " trainer.save_checkpoint()\n", 187 | " trainer.remove_old_temp_checkpoints()\n", 188 | "\n", 189 | " if number_of_epochs_without_improvement > 
early_stopping_epochs:\n", 190 | " break" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": null, 196 | "metadata": {}, 197 | "outputs": [], 198 | "source": [ 199 | "from lib import evaluate_quality\n", 200 | "\n", 201 | "np.random.seed(9)\n", 202 | "indices = np.random.permutation(len(X_test))[:1000]\n", 203 | "X_edit = X_test[indices].clone().to(device)\n", 204 | "y_edit = torch.tensor(np.random.randint(0, 10, size=y_test[indices].shape), device=device)\n", 205 | "metrics = evaluate_quality(editable_model, X_test, y_test, X_edit, y_edit, batch_size=512)\n", 206 | "for key in sorted(metrics.keys()):\n", 207 | " print('{}\\t:{:.5}'.format(key, metrics[key]))" 208 | ] 209 | } 210 | ], 211 | "metadata": { 212 | "kernelspec": { 213 | "display_name": "Python 3", 214 | "language": "python", 215 | "name": "python3" 216 | }, 217 | "language_info": { 218 | "codemirror_mode": { 219 | "name": "ipython", 220 | "version": 3 221 | }, 222 | "file_extension": ".py", 223 | "mimetype": "text/x-python", 224 | "name": "python", 225 | "nbconvert_exporter": "python", 226 | "pygments_lexer": "ipython3", 227 | "version": "3.6.4" 228 | } 229 | }, 230 | "nbformat": 4, 231 | "nbformat_minor": 2 232 | } 233 | -------------------------------------------------------------------------------- /mt/edited_generate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 -u 2 | # This source code is based fairseq/generate.py 3 | """ 4 | Perform edit and translate pre-processed data with a trained model. 
5 | """ 6 | 7 | 8 | import sys 9 | sys.path.append("fairseq") 10 | 11 | import torch 12 | 13 | from fairseq import bleu, checkpoint_utils, options, progress_bar, tasks, utils 14 | from fairseq.meters import StopwatchMeter, TimeMeter 15 | import tqdm 16 | 17 | from fairseq_criterion import EditableTrainingCriterion 18 | from mosestokenizer import MosesDetokenizer 19 | 20 | 21 | def main(args): 22 | assert args.path is not None, '--path required for generation!' 23 | assert not args.sampling or args.nbest == args.beam, \ 24 | '--sampling requires --nbest to be equal to --beam' 25 | assert args.replace_unk is None or args.raw_text, \ 26 | '--replace-unk requires a raw text dataset (--raw-text)' 27 | 28 | utils.import_user_module(args) 29 | 30 | if args.max_tokens is None and args.max_sentences is None: 31 | args.max_tokens = 12000 32 | print(args) 33 | 34 | use_cuda = torch.cuda.is_available() and not args.cpu 35 | 36 | # Load dataset splits 37 | task = tasks.setup_task(args) 38 | task.load_dataset(args.gen_subset) 39 | 40 | # Set dictionaries 41 | try: 42 | src_dict = getattr(task, 'source_dictionary', None) 43 | except NotImplementedError: 44 | src_dict = None 45 | tgt_dict = task.target_dictionary 46 | 47 | # Load ensemble 48 | print('| loading model(s) from {}'.format(args.path)) 49 | models, _model_args = checkpoint_utils.load_model_ensemble( 50 | args.path.split(':'), 51 | arg_overrides=eval(args.model_overrides), 52 | task=task, 53 | ) 54 | 55 | # Optimize ensemble for generation 56 | for model in models: 57 | model.make_generation_fast_( 58 | beamable_mm_beam_size=None if args.no_beamable_mm else args.beam, 59 | need_attn=args.print_alignment, 60 | ) 61 | if args.fp16: 62 | model.half() 63 | if use_cuda: 64 | model.cuda() 65 | 66 | # Load alignment dictionary for unknown word replacement 67 | # (None if no unknown word replacement, empty if no path to align dictionary) 68 | align_dict = utils.load_align_dict(args.replace_unk) 69 | 70 | # Load dataset (possibly 
sharded) 71 | itr = task.get_batch_iterator( 72 | dataset=task.dataset(args.gen_subset), 73 | max_tokens=args.max_tokens, 74 | max_sentences=args.max_sentences, 75 | max_positions=utils.resolve_max_positions( 76 | task.max_positions(), 77 | *[model.max_positions() for model in models] 78 | ), 79 | ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, 80 | required_batch_size_multiple=args.required_batch_size_multiple, 81 | num_shards=args.num_shards, 82 | shard_id=args.shard_id, 83 | num_workers=args.num_workers, 84 | ).next_epoch_itr(shuffle=False) 85 | 86 | # Initialize generator 87 | gen_timer = StopwatchMeter() 88 | generator = task.build_generator(args) 89 | 90 | detokenizer = MosesDetokenizer(args.moses_detokenizer) 91 | 92 | # Generate and compute BLEU score 93 | if args.sacrebleu: 94 | scorer = bleu.SacrebleuScorer() 95 | else: 96 | scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk()) 97 | num_sentences = 0 98 | has_target = True 99 | with progress_bar.build_progress_bar(args, itr) as t: 100 | wps_meter = TimeMeter() 101 | 102 | if args.edit_sample_index is not None: 103 | # make EditableTransformer 104 | 105 | assert len(models) == 1 106 | 107 | criterion = EditableTrainingCriterion(args, task).train(False) 108 | critetion_state_dict = torch.load(args.path) 109 | if 'criterion' in critetion_state_dict: 110 | criterion.load_state_dict(critetion_state_dict['criterion']) 111 | 112 | model = models[0] 113 | edit_sample = criterion.samples[args.edit_sample_index] 114 | device = 'cuda' if use_cuda else 'cpu' 115 | edited_model, success, _, complexity = criterion.get_edited_transformer(model, edit_sample, device, detach=True) 116 | edited_model.train(False) 117 | 118 | models[0] = edited_model.recover_transformer() 119 | 120 | for sample in t: 121 | sample = utils.move_to_cuda(sample) if use_cuda else sample 122 | if 'net_input' not in sample: 123 | continue 124 | 125 | prefix_tokens = None 126 | if args.prefix_size > 0: 127 | 
prefix_tokens = sample['target'][:, :args.prefix_size] 128 | 129 | gen_timer.start() 130 | hypos = task.inference_step(generator, models, sample, prefix_tokens) 131 | num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos) 132 | gen_timer.stop(num_generated_tokens) 133 | 134 | for i, sample_id in enumerate(sample['id'].tolist()): 135 | has_target = sample['target'] is not None 136 | 137 | # Remove padding 138 | src_tokens = utils.strip_pad(sample['net_input']['src_tokens'][i, :], tgt_dict.pad()) 139 | target_tokens = None 140 | if has_target: 141 | target_tokens = utils.strip_pad(sample['target'][i, :], tgt_dict.pad()).int().cpu() 142 | 143 | # Either retrieve the original sentences or regenerate them from tokens. 144 | if align_dict is not None: 145 | src_str = task.dataset(args.gen_subset).src.get_original_text(sample_id) 146 | target_str = task.dataset(args.gen_subset).tgt.get_original_text(sample_id) 147 | else: 148 | if src_dict is not None: 149 | src_str = src_dict.string(src_tokens, args.remove_bpe) 150 | else: 151 | src_str = "" 152 | if has_target: 153 | target_str = tgt_dict.string(target_tokens, args.remove_bpe, escape_unk=True) 154 | 155 | if not args.quiet: 156 | if src_dict is not None: 157 | print('S-{}\t{}'.format(sample_id, src_str)) 158 | if has_target: 159 | print('T-{}\t{}'.format(sample_id, target_str)) 160 | 161 | # Process top predictions 162 | for j, hypo in enumerate(hypos[i][:args.nbest]): 163 | hypo_tokens, hypo_str, alignment = utils.post_process_prediction( 164 | hypo_tokens=hypo['tokens'].int().cpu(), 165 | src_str=src_str, 166 | alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None, 167 | align_dict=align_dict, 168 | tgt_dict=tgt_dict, 169 | remove_bpe=args.remove_bpe, 170 | ) 171 | 172 | if not args.quiet: 173 | print('H-{}\t{}\t{}'.format(sample_id, hypo['score'], hypo_str)) 174 | print('P-{}\t{}'.format( 175 | sample_id, 176 | ' '.join(map( 177 | lambda x: '{:.4f}'.format(x), 178 | 
hypo['positional_scores'].tolist(), 179 | )) 180 | )) 181 | 182 | if args.print_alignment: 183 | print('A-{}\t{}'.format( 184 | sample_id, 185 | ' '.join(map(lambda x: str(utils.item(x)), alignment)) 186 | )) 187 | 188 | # Score only the top hypothesis 189 | if has_target and j == 0: 190 | if align_dict is not None or args.remove_bpe is not None: 191 | # Convert back to tokens for evaluation with unk replacement and/or without BPE 192 | target_tokens = tgt_dict.encode_line(target_str, add_if_not_exist=True) 193 | if hasattr(scorer, 'add_string'): 194 | if args.moses_detokenizer: 195 | target_str = detokenizer(target_str.split()) 196 | hypo_str = detokenizer(hypo_str.split()) 197 | 198 | scorer.add_string(target_str, hypo_str) 199 | else: 200 | assert not args.moses_detokenizer, "detokenizer has no effect with current bleu scorer" 201 | scorer.add(target_tokens, hypo_tokens) 202 | 203 | wps_meter.update(num_generated_tokens) 204 | t.log({'wps': round(wps_meter.avg)}) 205 | num_sentences += sample['nsentences'] 206 | 207 | print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'.format( 208 | num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. 
/ gen_timer.avg)) 209 | if has_target: 210 | print('| Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string())) 211 | if args.edit_sample_index is not None: 212 | print('EditResult(success={}, complexity={})'.format(success, complexity)) 213 | return scorer 214 | 215 | 216 | def cli_main(): 217 | parser = options.get_generation_parser() 218 | parser.add_argument('--edit-sample-index', type=int, metavar='D', default=None, 219 | help='Index of edit sample to use (pass a dataset via criterion --edit-samples-path arg)') 220 | parser.add_argument('--moses-detokenizer', type=str, default=None) 221 | EditableTrainingCriterion.add_args(parser) 222 | args = options.parse_args_and_arch(parser) 223 | main(args) 224 | 225 | 226 | if __name__ == '__main__': 227 | cli_main() 228 | -------------------------------------------------------------------------------- /notebooks/imagenet_editable_training.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "__This notebook__ fine-tunes a pre-trained resnet18 model with editable training.\n", 8 | "\n", 9 | "__Prepare data:__\n", 10 | "* Download imagenet training and dataset\n", 11 | "* Make sure folder names are called \"000\", \"001\", ... \"010\", \"011\", ... 
and not \"0\", \"1\", ..., \"10\", \"11\", ...\n", 12 | " * rename if necessary\n", 13 | "* Run `imagenet_preprocess_logits.ipynb` to prepare fine-tuning metadata.\n", 14 | "\n", 15 | "__Training:__\n", 16 | "* Set environment variables and paths in the next cell\n", 17 | "* Run all cells :)" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": 1, 23 | "metadata": {}, 24 | "outputs": [ 25 | { 26 | "name": "stdout", 27 | "output_type": "stream", 28 | "text": [ 29 | "env: CUDA_VISIBLE_DEVICES=1\n", 30 | "imagenet_editable_extra_layer_2019.09.19_23:11:24\n" 31 | ] 32 | } 33 | ], 34 | "source": [ 35 | "%load_ext autoreload\n", 36 | "%autoreload 2\n", 37 | "%env CUDA_VISIBLE_DEVICES=YOURDEVICEHERE\n", 38 | "\n", 39 | "traindir = '../../imagenet/train' # path to train ImageFolder\n", 40 | "valdir = '../../imagenet/val' # path to validation ImageFolder\n", 41 | "logits_path = './imagenet_logits/' # see imagenet_preprocess_logits\n", 42 | "\n", 43 | "import os, sys, time\n", 44 | "sys.path.insert(0, '..')\n", 45 | "import lib\n", 46 | "\n", 47 | "import numpy as np\n", 48 | "import torch, torch.nn as nn\n", 49 | "import torch.nn.functional as F\n", 50 | "import matplotlib.pyplot as plt\n", 51 | "%matplotlib inline\n", 52 | "\n", 53 | "import random\n", 54 | "random.seed(42)\n", 55 | "np.random.seed(42)\n", 56 | "torch.random.manual_seed(42)\n", 57 | "\n", 58 | "import time\n", 59 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", 60 | "\n", 61 | "experiment_name = 'imagenet_editable_extra_layer'\n", 62 | "experiment_name = '{}_{}.{:0>2d}.{:0>2d}_{:0>2d}:{:0>2d}:{:0>2d}'.format(experiment_name, *time.gmtime()[:6])\n", 63 | "print(experiment_name)" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 2, 69 | "metadata": {}, 70 | "outputs": [], 71 | "source": [ 72 | "import torchvision.transforms as transforms\n", 73 | "import torchvision.datasets as datasets\n", 74 | "import torchvision.models as models\n", 75 | "\n", 76 | 
"normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],\n", 77 | " std=[0.229, 0.224, 0.225])\n", 78 | "\n", 79 | "train_dataset = lib.ImageAndLogitsFolder(\n", 80 | " traindir,\n", 81 | " transforms.Compose([\n", 82 | " transforms.RandomResizedCrop(224),\n", 83 | " transforms.RandomHorizontalFlip(),\n", 84 | " transforms.ToTensor(),\n", 85 | " normalize,\n", 86 | " ]),\n", 87 | " logits_prefix = logits_path\n", 88 | ")\n", 89 | "\n", 90 | "batch_size = 128\n", 91 | "\n", 92 | "train_loader = torch.utils.data.DataLoader(\n", 93 | " train_dataset, batch_size=batch_size, shuffle=True,\n", 94 | " num_workers=12, pin_memory=True)\n", 95 | "\n", 96 | "val_loader = torch.utils.data.DataLoader(\n", 97 | " datasets.ImageFolder(valdir, transforms.Compose([\n", 98 | " transforms.Resize(256),\n", 99 | " transforms.CenterCrop(224),\n", 100 | " transforms.ToTensor(),\n", 101 | " normalize,\n", 102 | " ])),\n", 103 | " batch_size=batch_size, shuffle=False,\n", 104 | " num_workers=32, pin_memory=True)\n", 105 | "\n", 106 | "X_test, y_test = map(torch.cat, zip(*val_loader))\n", 107 | "X_test, y_test = X_test[::10], y_test[::10]\n", 108 | "# !!!IMPORTANT!!!\n", 109 | "# We use 10% of validation samples for faster validation, please use full validation set to measure \"final\" error rate" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": 3, 115 | "metadata": {}, 116 | "outputs": [], 117 | "source": [ 118 | "import torchvision\n", 119 | "\n", 120 | "model = torchvision.models.resnet18(pretrained=True)\n", 121 | "\n", 122 | "optimizer = lib.IngraphRMSProp(learning_rate=1e-4, beta=nn.Parameter(torch.as_tensor(0.5)))\n", 123 | "\n", 124 | "model = lib.SequentialWithEditable(\n", 125 | " model.conv1, model.bn1, model.relu, model.maxpool,\n", 126 | " model.layer1, model.layer2, model.layer3, model.layer4,\n", 127 | " model.avgpool, lib.Flatten(),\n", 128 | " lib.Editable(\n", 129 | " lib.Residual(nn.Linear(512, 4096), nn.ELU(), nn.Linear(4096, 512)),\n", 
130 | " loss_function=lib.contrastive_cross_entropy, \n", 131 | " optimizer=optimizer, max_steps=10),\n", 132 | "\n", 133 | " model.fc\n", 134 | ").to(device)" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": 4, 140 | "metadata": {}, 141 | "outputs": [], 142 | "source": [ 143 | "def classification_error(model, X_test, y_test):\n", 144 | " with lib.training_mode(model, is_train=False):\n", 145 | " return lib.classification_error(lib.Lambda(lambda x: model(x.to(device))),\n", 146 | " X_test, y_test, device='cpu', batch_size=128)" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": null, 152 | "metadata": {}, 153 | "outputs": [], 154 | "source": [ 155 | "new_params = set(model.editable.module[0].parameters())\n", 156 | "old_params = [param for param in model.parameters() if param not in new_params]\n", 157 | "\n", 158 | "training_opt = lib.OptimizerList(\n", 159 | " torch.optim.SGD(old_params, lr=1e-5, momentum=0.9, weight_decay=1e-4),\n", 160 | " torch.optim.SGD(new_params, lr=1e-3, momentum=0.9, weight_decay=1e-4),\n", 161 | ")\n", 162 | "\n", 163 | "trainer = lib.DistillationEditableTrainer(model,\n", 164 | " stability_coeff=0.03, editability_coeff=0.03,\n", 165 | " experiment_name=experiment_name,\n", 166 | " error_function=classification_error,\n", 167 | " opt=training_opt, max_norm=10)\n", 168 | "\n", 169 | "trainer.writer.add_text(\"trainer\", repr(trainer).replace('\\n', 'Failed to display Jupyter Widget of type HBox.
\n", 391 | " If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n", 392 | " that the widgets JavaScript is still loading. If this message persists, it\n", 393 | " likely means that the widgets JavaScript library is either not installed or\n", 394 | " not enabled. See the Jupyter\n", 395 | " Widgets Documentation for setup instructions.\n", 396 | "
\n", 397 | "\n", 398 | " If you're reading this message in another frontend (for example, a static\n", 399 | " rendering on GitHub or NBViewer),\n", 400 | " it may mean that your frontend doesn't currently support widgets.\n", 401 | "
\n" 402 | ], 403 | "text/plain": [ 404 | "HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))" 405 | ] 406 | }, 407 | "metadata": {}, 408 | "output_type": "display_data" 409 | }, 410 | { 411 | "name": "stdout", 412 | "output_type": "stream", 413 | "text": [ 414 | "base_error\t:0.30764\n", 415 | "drawdown\t:0.0014556\n", 416 | "mean_complexity\t:2.149\n", 417 | "success_rate\t:1.0\n" 418 | ] 419 | } 420 | ], 421 | "source": [ 422 | "from lib import evaluate_quality\n", 423 | "from tqdm import tqdm_notebook\n", 424 | "\n", 425 | "np.random.seed(9)\n", 426 | "indices = np.random.permutation(len(X_adv))[:1000]\n", 427 | "X_edit = X_adv[indices].cuda()\n", 428 | "y_edit = y_adv[indices].cuda()\n", 429 | "metrics_adv = evaluate_quality(model, X_test, y_test, X_edit, y_edit, \n", 430 | " error_function=classification_error, progressbar=tqdm_notebook)\n", 431 | "\n", 432 | "for key in sorted(metrics_adv.keys()):\n", 433 | " print('{}\\t:{:.5}'.format(key, metrics_adv[key]))" 434 | ] 435 | } 436 | ], 437 | "metadata": { 438 | "kernelspec": { 439 | "display_name": "Python 3", 440 | "language": "python", 441 | "name": "python3" 442 | }, 443 | "language_info": { 444 | "codemirror_mode": { 445 | "name": "ipython", 446 | "version": 3 447 | }, 448 | "file_extension": ".py", 449 | "mimetype": "text/x-python", 450 | "name": "python", 451 | "nbconvert_exporter": "python", 452 | "pygments_lexer": "ipython3", 453 | "version": "3.6.4" 454 | } 455 | }, 456 | "nbformat": 4, 457 | "nbformat_minor": 2 458 | } 459 | --------------------------------------------------------------------------------