├── .gitignore ├── LICENSE ├── README.rst ├── dni.py ├── examples ├── mnist-cnn │ ├── LICENSE │ ├── README.rst │ ├── main.py │ └── requirements.txt ├── mnist-full-unlock │ ├── LICENSE │ ├── README.rst │ ├── main.py │ └── requirements.txt ├── mnist-mlp │ ├── LICENSE │ ├── README.rst │ ├── main.py │ └── requirements.txt └── rnn │ ├── LICENSE │ ├── README.rst │ ├── data.py │ ├── data │ └── penn │ │ ├── test.txt │ │ ├── train.txt │ │ └── valid.txt │ ├── generate.py │ ├── main.py │ ├── model.py │ └── requirements.txt ├── images ├── feedforward-complete-unlock.png ├── feedforward-update-unlock.png └── rnn-update-unlock.png └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Piotr Kozakowski 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | Decoupled Neural Interfaces for PyTorch 2 | ======================================= 3 | 4 | This tiny library is an implementation of 5 | `Decoupled Neural Interfaces using Synthetic Gradients `_ 6 | for `PyTorch `_. 7 | It's very simple to use as it was designed to enable researchers to integrate 8 | DNI into existing models with minimal amounts of code. 9 | 10 | To install, run:: 11 | 12 | $ python setup.py install 13 | 14 | Description of the library and how to use it in some typical cases is provided 15 | below. For more information, please read the code. 16 | 17 | Terminology 18 | ----------- 19 | 20 | This library uses a message passing abstraction introduced in the paper. Some 21 | terms used in the API (matching those used in the paper wherever possible): 22 | 23 | - ``Interface`` - A Decoupled Neural Interface that decouples two parts (let's 24 | call them part A and part B) of the network and lets them communicate via 25 | ``message`` passing. It may be ``Forward``, ``Backward`` or 26 | ``Bidirectional``. 27 | - ``BackwardInterface`` - A type of ``Interface`` that the paper focuses on. 28 | It can be used to prevent update locking by predicting gradient for part A 29 | of the decoupled network based on the activation of its last layer. 
30 | - ``ForwardInterface`` - A type of ``Interface`` that can be used to prevent 31 | forward locking by predicting input for part B of the network based on some 32 | information known to both parts - in the paper it's the input of the whole 33 | network. 34 | - ``BidirectionalInterface`` - A combination of ``ForwardInterface`` and 35 | ``BackwardInterface`` that can be used to achieve a complete unlock. 36 | - ``message`` - Information that is passed through an ``Interface`` - 37 | activation of the last layer for ``ForwardInterface`` or gradient w.r.t. 38 | that activation for ``BackwardInterface``. Note that no original information 39 | passes through. A ``message`` is consumed by one end of the ``Interface`` 40 | and used to update a ``Synthesizer``. Then the ``Synthesizer`` can be used to 41 | produce a synthetic ``message`` at the other end of the ``Interface``. 42 | - ``trigger`` - Information based on which ``message`` is synthesized. It needs 43 | to be accessible by both parts of the network. For ``BackwardInterface``, it's 44 | the activation of the layer w.r.t. which gradient is to be synthesized. For 45 | ``ForwardInterface`` it can be anything - in the paper it's the input of 46 | the whole network. 47 | - ``context`` - Additional information normally not shown to the network in 48 | the forward pass that can condition an ``Interface`` to provide a better 49 | estimate of the ``message``. The paper uses labels for this purpose and calls 50 | DNI with context cDNI. 51 | - ``send`` - A method of an ``Interface`` that takes as input a ``message`` 52 | and the ``trigger`` based on which that ``message`` should be generated, 53 | and updates the ``Synthesizer`` to improve the estimate. 54 | - ``receive`` - A method of an ``Interface`` that takes as input a ``trigger`` 55 | and returns a ``message`` generated by a ``Synthesizer``. 56 | - ``Synthesizer`` - A regression model that estimates ``message`` based on 57 | ``trigger`` and ``context``.
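To make the terminology concrete, the message-passing loop can be sketched outside of PyTorch in a few lines of plain Python. This is a conceptual illustration only, not the library's API: the ``ToySynthesizer`` class, the dimensions, and the training loop below are made up for the example. A ``Synthesizer`` is just a regression model; ``send`` fits it to observed (``trigger``, ``message``) pairs, and ``receive`` evaluates it on a ``trigger`` alone.

```python
import numpy as np

class ToySynthesizer:
    """A toy stand-in for a Synthesizer: linear regression from trigger to message."""

    def __init__(self, trigger_dim, message_dim, lr=0.1):
        # zero-initialize, mirroring the paper's choice for the last layer
        self.W = np.zeros((trigger_dim, message_dim))
        self.lr = lr

    def __call__(self, trigger):
        return trigger @ self.W

def send(synth, message, trigger):
    # update the estimate from a ground-truth (trigger, message) pair
    # via one gradient step on the MSE loss
    error = synth(trigger) - message
    synth.W -= synth.lr * trigger.T @ error / len(trigger)

def receive(synth, trigger):
    # produce a synthetic message from the trigger alone
    return synth(trigger)

rng = np.random.default_rng(0)
true_W = rng.normal(size=(4, 3))  # the "real" mapping the interface decouples
synth = ToySynthesizer(trigger_dim=4, message_dim=3)

# one end of the interface repeatedly sends real messages...
for _ in range(500):
    trigger = rng.normal(size=(8, 4))
    send(synth, message=trigger @ true_W, trigger=trigger)

# ...until the other end can receive a close synthetic estimate
trigger = rng.normal(size=(8, 4))
err = np.abs(receive(synth, trigger) - trigger @ true_W).max()
```

In the library itself the same loop runs inside an ``Interface``: ``send`` backpropagates an MSE loss into the ``Synthesizer``'s parameters, and ``receive`` returns the synthesized ``message`` detached from the graph.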
58 | 59 | Typical use cases 60 | ----------------- 61 | 62 | Synthetic Gradient for Feed-Forward Networks 63 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 64 | 65 | In this case we want to decouple two parts A and B of a neural network to 66 | achieve an update unlock, so that there is a normal forward pass from part A to 67 | B, but part A learns using synthetic gradient generated by the DNI. 68 | 69 | .. image:: images/feedforward-update-unlock.png 70 | 71 | Following the paper's convention, solid black arrows are update-locked forward 72 | connections, dashed black arrows are update-unlocked forward connections, green 73 | arrows are real error gradients and blue arrows are synthetic error gradients. 74 | Full circles denote synthetic gradient loss computation and ``Synthesizer`` 75 | update. 76 | 77 | We can use a ``BackwardInterface`` to do that: 78 | 79 | .. code-block:: python 80 | 81 | class Network(torch.nn.Module): 82 | 83 | def __init__(self): 84 | # ... 85 | 86 | # 1. create a BackwardInterface, assuming that dimensionality of 87 | # the activation for which we want to synthesize gradients is 88 | # activation_dim 89 | self.backward_interface = dni.BackwardInterface( 90 | dni.BasicSynthesizer(output_dim=activation_dim, n_hidden=1) 91 | ) 92 | 93 | # ... 94 | 95 | def forward(self, x): 96 | # ... 97 | 98 | # 2. call the BackwardInterface at the point where we want to 99 | # decouple the network 100 | x = self.backward_interface(x) 101 | 102 | # ... 103 | 104 | return x 105 | 106 | That's it! During the forward pass, ``BackwardInterface`` will use a 107 | ``Synthesizer`` to generate synthetic gradient w.r.t. activation, backpropagate 108 | it and add to the computation graph a node that will intercept 109 | the real gradient during the backward pass and use it to update the 110 | ``Synthesizer``'s estimate. 111 | 112 | The ``Synthesizer`` used here is ``BasicSynthesizer`` - a multi-layer 113 | perceptron with ReLU activation function. 
Writing a custom ``Synthesizer`` is 114 | described at `Writing custom Synthesizers`_. 115 | 116 | You can specify a ``context`` by passing ``context_dim`` (dimensionality of the 117 | context vector) to the ``BasicSynthesizer`` constructor and wrapping all DNI 118 | calls in the ``dni.synthesizer_context`` context manager: 119 | 120 | .. code-block:: python 121 | 122 | class Network(torch.nn.Module): 123 | 124 | def __init__(self): 125 | # ... 126 | 127 | self.backward_interface = dni.BackwardInterface( 128 | dni.BasicSynthesizer( 129 | output_dim=activation_dim, n_hidden=1, 130 | context_dim=context_dim 131 | ) 132 | ) 133 | 134 | # ... 135 | 136 | def forward(self, x, y): 137 | # ... 138 | 139 | # assuming that context is labels given in variable y 140 | with dni.synthesizer_context(y): 141 | x = self.backward_interface(x) 142 | 143 | # ... 144 | 145 | return x 146 | 147 | Example code for digit classification on MNIST is at 148 | `examples/mnist-mlp `_. 149 | 150 | Complete Unlock for Feed-Forward Networks 151 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 152 | 153 | In this case we want to decouple two parts A and B of a neural network to 154 | achieve forward and update unlock, so that part B receives synthetic input and 155 | part A learns using synthetic gradient generated by the DNI. 156 | 157 | .. image:: images/feedforward-complete-unlock.png 158 | 159 | Red arrows are synthetic inputs. 160 | 161 | We can use a ``BidirectionalInterface`` to do that: 162 | 163 | .. code-block:: python 164 | 165 | class Network(torch.nn.Module): 166 | 167 | def __init__(self): 168 | # ... 169 | 170 | # 1.
create a BidirectionalInterface, assuming that dimensionality of 171 | # the activation for which we want to synthesize gradients is 172 | # activation_dim and dimensionality of the input of the whole 173 | # network is input_dim 174 | self.bidirectional_interface = dni.BidirectionalInterface( 175 | # Synthesizer generating synthetic inputs for part B, trigger 176 | # here is the input of the network 177 | dni.BasicSynthesizer( 178 | output_dim=activation_dim, n_hidden=1, 179 | trigger_dim=input_dim 180 | ), 181 | # Synthesizer generating synthetic gradients for part A, 182 | # trigger here is the last activation of part A (no need to 183 | # specify dimensionality) 184 | dni.BasicSynthesizer( 185 | output_dim=activation_dim, n_hidden=1 186 | ) 187 | ) 188 | 189 | # ... 190 | 191 | def forward(self, input): 192 | x = input 193 | 194 | # ... 195 | 196 | # 2. call the BidirectionalInterface at the point where we want to 197 | # decouple the network; we need to pass both the last activation 198 | # and the trigger, which in this case is the input of the whole 199 | # network 200 | x = self.bidirectional_interface(x, input) 201 | 202 | # ... 203 | 204 | return x 205 | 206 | During the forward pass, ``BidirectionalInterface`` will receive the real 207 | activation, use it to update the input ``Synthesizer``, generate synthetic 208 | gradient w.r.t. that activation using the gradient ``Synthesizer``, 209 | backpropagate it, generate synthetic input using the input ``Synthesizer`` 210 | and attach to it a computation graph node that will intercept the real gradient 211 | w.r.t. the synthetic input and use it to update the gradient ``Synthesizer``. 212 | 213 | Example code for digit classification on MNIST is at 214 | `examples/mnist-full-unlock `_. 215 | 216 | Writing custom Synthesizers 217 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^ 218 | 219 | This library includes only ``BasicSynthesizer`` - a very simple ``Synthesizer`` 220 | based on a multi-layer perceptron with ReLU activation function.
It may not be 221 | sufficient for all cases; for example, for classifying MNIST digits using a CNN 222 | the paper uses a ``Synthesizer`` that is also a CNN. 223 | 224 | You can easily write a custom ``Synthesizer`` by subclassing 225 | ``torch.nn.Module`` with a ``forward`` method taking ``trigger`` and ``context`` 226 | as arguments and returning a synthetic ``message``: 227 | 228 | .. code-block:: python 229 | 230 | class CustomSynthesizer(torch.nn.Module): 231 | 232 | def forward(self, trigger, context): 233 | # synthesize the message 234 | return message 235 | 236 | ``trigger`` will be a ``torch.autograd.Variable`` and ``context`` will be 237 | whatever is passed to the ``dni.synthesizer_context`` context manager, or 238 | ``None`` if ``dni.synthesizer_context`` is not used. 239 | 240 | Example code for digit classification on MNIST using a CNN is at 241 | `examples/mnist-cnn `_. 242 | 243 | Synthetic Gradient for Recurrent Networks 244 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 245 | 246 | In this case we want to use DNI to approximate the gradient from an 247 | infinitely-unrolled recurrent neural network and feed it to the last step of 248 | the RNN unrolled by truncated BPTT. 249 | 250 | .. image:: images/rnn-update-unlock.png 251 | 252 | We can use methods ``make_trigger`` and ``backward`` of ``BackwardInterface`` 253 | to do that: 254 | 255 | .. code-block:: python 256 | 257 | class Network(torch.nn.Module): 258 | 259 | def __init__(self): 260 | # ... 261 | 262 | # 1. create a BackwardInterface, assuming that dimensionality of 263 | # the RNN hidden state is hidden_dim 264 | self.backward_interface = dni.BackwardInterface( 265 | dni.BasicSynthesizer(output_dim=hidden_dim, n_hidden=1) 266 | ) 267 | 268 | # ... 269 | 270 | def forward(self, input, hidden): 271 | # ... 272 | 273 | # 2.
call make_trigger on the first state of the unrolled RNN 274 | hidden = self.backward_interface.make_trigger(hidden) 275 | # run the RNN 276 | (output, hidden) = self.rnn(input, hidden) 277 | # 3. call backward on the last state of the unrolled RNN 278 | self.backward_interface.backward(hidden) 279 | 280 | # ... 281 | 282 | # in the training loop: 283 | with dni.defer_backward(): 284 | (output, hidden) = model(input, hidden) 285 | loss = criterion(output, target) 286 | dni.backward(loss) 287 | 288 | ``BackwardInterface.make_trigger`` marks the first hidden state as a 289 | ``trigger`` used to update the gradient estimate. During the backward pass, 290 | gradient passing through the ``trigger`` will be compared to synthetic gradient 291 | generated based on the same ``trigger`` and the ``Synthesizer`` will be 292 | updated. ``BackwardInterface.backward`` computes synthetic gradient based on 293 | the last hidden state and backpropagates it. 294 | 295 | Because we are passing both real and synthetic gradients through the same nodes 296 | in the computation graph, we need to use ``dni.defer_backward`` and 297 | ``dni.backward``. ``dni.defer_backward`` is a context manager that accumulates 298 | all gradients passed to ``dni.backward`` (including those generated by 299 | ``Interfaces``) and backpropagates them all at once in the end. If we don't do 300 | that, PyTorch will complain about backpropagating twice through the same 301 | computation graph. 302 | 303 | Example code for word-level language modeling on Penn Treebank is at 304 | `examples/rnn `_. 305 | 306 | Distributed training with a Complete Unlock 307 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 308 | 309 | The paper describes distributed training of complex neural architectures as one 310 | of the potential uses of DNI. In this case we have a network split into parts 311 | A and B trained independently, perhaps on different machines, communicating via 312 | DNI. 
We can use methods ``send`` and ``receive`` of ``BidirectionalInterface`` 313 | to do that: 314 | 315 | .. code-block:: python 316 | 317 | class PartA(torch.nn.Module): 318 | 319 | def forward(self, input): 320 | x = input 321 | 322 | # ... 323 | 324 | # send the intermediate results computed by part A via DNI 325 | self.bidirectional_interface.send(x, input) 326 | 327 | class PartB(torch.nn.Module): 328 | 329 | def forward(self, input): 330 | # receive the intermediate results computed by part A via DNI 331 | x = self.bidirectional_interface.receive(input) 332 | 333 | # ... 334 | 335 | return x 336 | 337 | ``PartA`` and ``PartB`` have their own copies of the 338 | ``BidirectionalInterface``. ``BidirectionalInterface.send`` will compute 339 | synthetic gradient w.r.t. ``x`` (intermediate results computed by ``PartA``) 340 | based on ``x`` and ``input`` (input of the whole network), backpropagate it and 341 | update the estimate of ``x``. ``BidirectionalInterface.receive`` will compute 342 | synthetic ``x`` based on ``input`` and in the backward pass, update the 343 | estimate of the gradient w.r.t. ``x``. This should work as long as 344 | ``BidirectionalInterface`` parameters are synchronized between ``PartA`` and 345 | ``PartB`` once in a while. 346 | 347 | There is no example code for this use case yet. Contributions welcome! 348 | -------------------------------------------------------------------------------- /dni.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | from torch.nn import functional as F 4 | from torch.nn import init 5 | 6 | from contextlib import contextmanager 7 | from functools import partial 8 | 9 | 10 | class UnidirectionalInterface(torch.nn.Module): 11 | """Basic `Interface` for unidirectional communication. 12 | 13 | Can be used to manually pass `messages` with methods `send` and `receive`. 
14 | 15 | Args: 16 | synthesizer: `Synthesizer` to use to generate `messages`. 17 | """ 18 | 19 | def __init__(self, synthesizer): 20 | super().__init__() 21 | 22 | self.synthesizer = synthesizer 23 | 24 | def receive(self, trigger): 25 | """Synthesizes a `message` based on `trigger`. 26 | 27 | Detaches `message` so no gradient will go through it during the 28 | backward pass. 29 | 30 | Args: 31 | trigger: `trigger` to use to synthesize a `message`. 32 | 33 | Returns: 34 | The synthesized `message`. 35 | """ 36 | return self.synthesizer( 37 | trigger, _Manager.get_current_context() 38 | ).detach() 39 | 40 | def send(self, message, trigger): 41 | """Updates the estimate of synthetic `message` based on `trigger`. 42 | 43 | Synthesizes a `message` based on `trigger`, computes the MSE between it 44 | and the input `message` and backpropagates it to compute its gradient 45 | w.r.t. `Synthesizer` parameters. Does not backpropagate through 46 | `trigger`. 47 | 48 | Args: 49 | message: Ground truth `message` that should be synthesized based on 50 | `trigger`. 51 | trigger: `trigger` that the `message` should be synthesized based 52 | on. 53 | """ 54 | synthetic_message = self.synthesizer( 55 | trigger.detach(), _Manager.get_current_context() 56 | ) 57 | loss = F.mse_loss(synthetic_message, message.detach()) 58 | _Manager.backward(loss) 59 | 60 | 61 | class ForwardInterface(UnidirectionalInterface): 62 | """`Interface` for synthesizing activations in the forward pass. 63 | 64 | Can be used to achieve a forward unlock. It does not make too much sense to 65 | use it on its own, as it breaks backpropagation (no gradients pass through 66 | `ForwardInterface`). To achieve both forward and update unlock, use 67 | `BidirectionalInterface`. 68 | 69 | Args: 70 | synthesizer: `Synthesizer` to use to generate `messages`. 71 | """ 72 | 73 | def forward(self, message, trigger): 74 | """Synthetic forward pass, no backward pass. 75 | 76 | Convenience method combining `send` and `receive`. 
Updates the 77 | `message` estimate based on `trigger` and returns a synthetic 78 | `message`. 79 | 80 | Works only in `training` mode, otherwise just returns the input 81 | `message`. 82 | 83 | Args: 84 | message: Ground truth `message` that should be synthesized based on 85 | `trigger`. 86 | trigger: `trigger` that the `message` should be synthesized based 87 | on. 88 | 89 | Returns: 90 | The synthesized `message`. 91 | """ 92 | if self.training: 93 | self.send(message, trigger) 94 | return self.receive(trigger) 95 | else: 96 | return message 97 | 98 | 99 | class BackwardInterface(UnidirectionalInterface): 100 | """`Interface` for synthesizing gradients in the backward pass. 101 | 102 | Can be used to achieve an update unlock. 103 | 104 | Args: 105 | synthesizer: `Synthesizer` to use to generate gradients. 106 | """ 107 | 108 | def forward(self, trigger): 109 | """Normal forward pass, synthetic backward pass. 110 | 111 | Convenience method combining `backward` and `make_trigger`. Can be 112 | used when we want to backpropagate synthetic gradients from and 113 | intercept real gradients at the same `Variable`, for example for 114 | update decoupling feed-forward networks. 115 | 116 | Backpropagates synthetic gradient from `trigger` and returns a copy of 117 | `trigger` with a synthetic gradient update operation attached. 118 | 119 | Works only in `training` mode, otherwise just returns the input 120 | `trigger`. 121 | 122 | Args: 123 | trigger: `trigger` to backpropagate synthetic gradient from and 124 | intercept real gradient at. 125 | 126 | Returns: 127 | A copy of `trigger` with a synthetic gradient update operation 128 | attached. 129 | """ 130 | if self.training: 131 | self.backward(trigger) 132 | return self.make_trigger(trigger.detach()) 133 | else: 134 | return trigger 135 | 136 | def backward(self, trigger, factor=1): 137 | """Backpropagates synthetic gradient from `trigger`. 
138 | 139 | Computes synthetic gradient based on `trigger`, scales it by `factor` 140 | and backpropagates it from `trigger`. 141 | 142 | Works only in `training` mode, otherwise is a no-op. 143 | 144 | Args: 145 | trigger: `trigger` to compute synthetic gradient based on and to 146 | backpropagate it from. 147 | factor (optional): Factor by which to scale the synthetic gradient. 148 | Defaults to 1. 149 | """ 150 | if self.training: 151 | synthetic_gradient = self.receive(trigger) 152 | _Manager.backward(trigger, synthetic_gradient.data * factor) 153 | 154 | def make_trigger(self, trigger): 155 | """Attaches a synthetic gradient update operation to `trigger`. 156 | 157 | Returns a `Variable` with the same `data` as `trigger`, that during 158 | the backward pass will intercept gradient passing through it and use 159 | this gradient to update the `Synthesizer`'s estimate. 160 | 161 | Works only in `training` mode, otherwise just returns the input 162 | `trigger`. 163 | 164 | Returns: 165 | A copy of `trigger` with a synthetic gradient update operation 166 | attached. 
167 | """ 168 | if self.training: 169 | return _SyntheticGradientUpdater.apply( 170 | trigger, 171 | self.synthesizer(trigger, _Manager.get_current_context()) 172 | ) 173 | else: 174 | return trigger 175 | 176 | 177 | class _SyntheticGradientUpdater(torch.autograd.Function): 178 | 179 | @staticmethod 180 | def forward(ctx, trigger, synthetic_gradient): 181 | (_, needs_synthetic_gradient_grad) = ctx.needs_input_grad 182 | if not needs_synthetic_gradient_grad: 183 | raise ValueError( 184 | 'synthetic_gradient should need gradient but it does not' 185 | ) 186 | 187 | ctx.save_for_backward(synthetic_gradient) 188 | # clone trigger to force creating a new Variable with 189 | # requires_grad=True 190 | return trigger.clone() 191 | 192 | @staticmethod 193 | def backward(ctx, true_gradient): 194 | (synthetic_gradient,) = ctx.saved_variables 195 | # compute MSE gradient manually to avoid dependency on PyTorch 196 | # internals 197 | (batch_size, *_) = synthetic_gradient.size() 198 | grad_synthetic_gradient = ( 199 | 2 / batch_size * (synthetic_gradient - true_gradient) 200 | ) 201 | return (true_gradient, grad_synthetic_gradient) 202 | 203 | 204 | class BidirectionalInterface(torch.nn.Module): 205 | """`Interface` for synthesizing both activations and gradients w.r.t. them. 206 | 207 | Can be used to achieve a full unlock. 208 | 209 | Args: 210 | forward_synthesizer: `Synthesizer` to use to generate `messages`. 211 | backward_synthesizer: `Synthesizer` to use to generate gradients w.r.t. 212 | `messages`. 213 | """ 214 | 215 | def __init__(self, forward_synthesizer, backward_synthesizer): 216 | super().__init__() 217 | 218 | self.forward_interface = ForwardInterface(forward_synthesizer) 219 | self.backward_interface = BackwardInterface(backward_synthesizer) 220 | 221 | def forward(self, message, trigger): 222 | """Synthetic forward pass, synthetic backward pass. 223 | 224 | Convenience method combining `send` and `receive`. 
Can be used when we 225 | want to `send` and immediately `receive` using the same `trigger`. For 226 | more complex scenarios, `send` and `receive` need to be used 227 | separately. 228 | 229 | Updates the `message` estimate based on `trigger`, backpropagates 230 | synthetic gradient from `message` and returns a synthetic `message` 231 | with a synthetic gradient update operation attached. 232 | 233 | Works only in `training` mode, otherwise just returns the input 234 | `message`. 235 | """ 236 | if self.training: 237 | self.send(message, trigger) 238 | return self.receive(trigger) 239 | else: 240 | return message 241 | 242 | def receive(self, trigger): 243 | """Combination of `ForwardInterface.receive` and 244 | `BackwardInterface.make_trigger`. 245 | 246 | Generates a synthetic `message` based on `trigger` and attaches to it 247 | a synthetic gradient update operation. 248 | 249 | Args: 250 | trigger: `trigger` to use to synthesize a `message`. 251 | 252 | Returns: 253 | The synthesized `message` with a synthetic gradient update 254 | operation attached. 255 | """ 256 | message = self.forward_interface.receive(trigger) 257 | return self.backward_interface.make_trigger(message) 258 | 259 | def send(self, message, trigger): 260 | """Combination of `ForwardInterface.send` and 261 | `BackwardInterface.backward`. 262 | 263 | Updates the estimate of synthetic `message` based on `trigger` and 264 | backpropagates synthetic gradient from `message`. 265 | 266 | Args: 267 | message: Ground truth `message` that should be synthesized based on 268 | `trigger` and that synthetic gradient should be backpropagated 269 | from. 270 | trigger: `trigger` that the `message` should be synthesized based 271 | on. 272 | """ 273 | self.forward_interface.send(message, trigger) 274 | self.backward_interface.backward(message) 275 | 276 | 277 | class BasicSynthesizer(torch.nn.Module): 278 | """Basic `Synthesizer` based on an MLP with ReLU activation. 
279 | 280 | Args: 281 | output_dim: Dimensionality of the synthesized `messages`. 282 | n_hidden (optional): Number of hidden layers. Defaults to 0. 283 | hidden_dim (optional): Dimensionality of the hidden layers. Defaults to 284 | `output_dim`. 285 | trigger_dim (optional): Dimensionality of the trigger. Defaults to 286 | `output_dim`. 287 | context_dim (optional): Dimensionality of the context. If `None`, do 288 | not use context. Defaults to `None`. 289 | """ 290 | 291 | def __init__(self, output_dim, n_hidden=0, hidden_dim=None, 292 | trigger_dim=None, context_dim=None): 293 | super().__init__() 294 | 295 | if hidden_dim is None: 296 | hidden_dim = output_dim 297 | if trigger_dim is None: 298 | trigger_dim = output_dim 299 | 300 | top_layer_dim = output_dim if n_hidden == 0 else hidden_dim 301 | 302 | self.input_trigger = torch.nn.Linear( 303 | in_features=trigger_dim, out_features=top_layer_dim 304 | ) 305 | 306 | if context_dim is not None: 307 | self.input_context = torch.nn.Linear( 308 | in_features=context_dim, out_features=top_layer_dim 309 | ) 310 | else: 311 | self.input_context = None 312 | 313 | self.layers = torch.nn.ModuleList([ 314 | torch.nn.Linear( 315 | in_features=hidden_dim, 316 | out_features=( 317 | hidden_dim if layer_index < n_hidden - 1 else output_dim 318 | ) 319 | ) 320 | for layer_index in range(n_hidden) 321 | ]) 322 | 323 | # zero-initialize the last layer, as in the paper 324 | if n_hidden > 0: 325 | init.constant(self.layers[-1].weight, 0) 326 | else: 327 | init.constant(self.input_trigger.weight, 0) 328 | if context_dim is not None: 329 | init.constant(self.input_context.weight, 0) 330 | 331 | def forward(self, trigger, context): 332 | """Synthesizes a `message` based on `trigger` and `context`. 333 | 334 | Args: 335 | trigger: `trigger` to synthesize the `message` based on. Size: 336 | (`batch_size`, `trigger_dim`). 337 | context: `context` to condition the synthesizer. 
Ignored if 338 | `context_dim` has not been specified in the constructor. Size: 339 | (`batch_size`, `context_dim`). 340 | 341 | Returns: 342 | The synthesized `message`. 343 | """ 344 | last = self.input_trigger(trigger) 345 | 346 | if self.input_context is not None: 347 | last += self.input_context(context) 348 | 349 | for layer in self.layers: 350 | last = layer(F.relu(last)) 351 | 352 | return last 353 | 354 | 355 | @contextmanager 356 | def defer_backward(): 357 | """Defers backpropagation until the end of scope. 358 | 359 | Accumulates all gradients passed to `dni.backward` inside the scope and 360 | backpropagates them all in a single `torch.autograd.backward` call. 361 | 362 | Use it and `dni.backward` whenever you want to backpropagate multiple times 363 | through the same nodes in the computation graph, for example when mixing 364 | real and synthetic gradients. Otherwise, PyTorch will complain about 365 | backpropagating more than once through the same graph. 366 | 367 | Scopes of this context manager cannot be nested. 368 | """ 369 | if _Manager.defer_backward: 370 | raise RuntimeError('cannot nest defer_backward') 371 | _Manager.defer_backward = True 372 | 373 | try: 374 | yield 375 | 376 | if _Manager.deferred_gradients: 377 | (variables, gradients) = zip(*_Manager.deferred_gradients) 378 | torch.autograd.backward(variables, gradients) 379 | finally: 380 | _Manager.reset_defer_backward() 381 | 382 | 383 | @contextmanager 384 | def synthesizer_context(context): 385 | """Conditions `Synthesizer` calls within the scope on the given `context`. 386 | 387 | All `Synthesizer.forward` calls within the scope will receive `context` 388 | as an argument. 389 | 390 | Scopes of this context manager can be nested. 
391 | """ 392 | _Manager.context_stack.append(context) 393 | yield 394 | _Manager.context_stack.pop() 395 | 396 | 397 | class _Manager: 398 | 399 | defer_backward = False 400 | deferred_gradients = [] 401 | context_stack = [] 402 | 403 | @classmethod 404 | def reset_defer_backward(cls): 405 | cls.defer_backward = False 406 | cls.deferred_gradients = [] 407 | 408 | @classmethod 409 | def backward(cls, variable, gradient=None): 410 | if gradient is None: 411 | gradient = _ones_like(variable.data) 412 | 413 | if cls.defer_backward: 414 | cls.deferred_gradients.append((variable, gradient)) 415 | else: 416 | variable.backward(gradient) 417 | 418 | @classmethod 419 | def get_current_context(cls): 420 | if cls.context_stack: 421 | return cls.context_stack[-1] 422 | else: 423 | return None 424 | 425 | 426 | """A simplified variant of `torch.autograd.backward` influenced by 427 | `defer_backward`. 428 | 429 | Inside of `defer_backward` scope, accumulates passed gradient to backpropagate 430 | it at the end of scope. Outside of `defer_backward`, backpropagates the 431 | gradient immediately. 432 | 433 | Use it and `defer_backward` whenever you want to backpropagate multiple times 434 | through the same nodes in the computation graph. 435 | 436 | Args: 437 | variable: `Variable` to backpropagate the gradient from. 438 | gradient (optional): Gradient to backpropagate from `variable`. Defaults 439 | to a `Tensor` of the same size as `variable`, filled with 1. 440 | """ 441 | backward = _Manager.backward 442 | 443 | 444 | def _ones_like(tensor): 445 | return tensor.new().resize_(tensor.size()).fill_(1) 446 | -------------------------------------------------------------------------------- /examples/mnist-cnn/LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2017, 4 | All rights reserved. 
5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /examples/mnist-cnn/README.rst: -------------------------------------------------------------------------------- 1 | MNIST with CNN 2 | -------------- 3 | 4 | This example illustrates how to implement a custom ``Synthesizer``. 
5 | Code is mostly copied from the official PyTorch MNIST example: 6 | https://github.com/pytorch/examples/blob/master/mnist/main.py 7 | 8 | The classification model is the same as in the original example (a CNN), with 9 | batch normalization added after every layer and DNI inserted between the last 10 | convolutional layer and the first fully-connected layer (before the activation). 11 | 12 | The synthesizer is a CNN with three convolutional layers, padded so that the 13 | sizes of the feature maps are kept constant, with ReLU activations. 14 | 15 | To install requirements:: 16 | 17 | $ pip install -r requirements.txt 18 | 19 | To train with regular backpropagation:: 20 | 21 | $ python main.py 22 | 23 | To train with DNI (no label conditioning):: 24 | 25 | $ python main.py --dni 26 | 27 | To train with cDNI (label conditioning):: 28 | 29 | $ python main.py --dni --context 30 | -------------------------------------------------------------------------------- /examples/mnist-cnn/main.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.optim as optim 7 | from torchvision import datasets, transforms 8 | from torch.autograd import Variable 9 | import dni 10 | 11 | # Training settings 12 | parser = argparse.ArgumentParser(description='PyTorch MNIST Example') 13 | parser.add_argument('--batch-size', type=int, default=64, metavar='N', 14 | help='input batch size for training (default: 64)') 15 | parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', 16 | help='input batch size for testing (default: 1000)') 17 | parser.add_argument('--epochs', type=int, default=10, metavar='N', 18 | help='number of epochs to train (default: 10)') 19 | parser.add_argument('--lr', type=float, default=0.001, metavar='LR', 20 | help='learning rate (default: 0.001)') 21 |
parser.add_argument('--no-cuda', action='store_true', default=False, 22 | help='disables CUDA training') 23 | parser.add_argument('--seed', type=int, default=1, metavar='S', 24 | help='random seed (default: 1)') 25 | parser.add_argument('--log-interval', type=int, default=10, metavar='N', 26 | help='how many batches to wait before logging training status') 27 | parser.add_argument('--dni', action='store_true', default=False, 28 | help='enable DNI') 29 | parser.add_argument('--context', action='store_true', default=False, 30 | help='enable context (label conditioning) in DNI') 31 | args = parser.parse_args() 32 | args.cuda = not args.no_cuda and torch.cuda.is_available() 33 | 34 | torch.manual_seed(args.seed) 35 | if args.cuda: 36 | torch.cuda.manual_seed(args.seed) 37 | 38 | 39 | kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {} 40 | train_loader = torch.utils.data.DataLoader( 41 | datasets.MNIST('../data', train=True, download=True, 42 | transform=transforms.Compose([ 43 | transforms.ToTensor(), 44 | transforms.Normalize((0.1307,), (0.3081,)) 45 | ])), 46 | batch_size=args.batch_size, shuffle=True, **kwargs) 47 | test_loader = torch.utils.data.DataLoader( 48 | datasets.MNIST('../data', train=False, transform=transforms.Compose([ 49 | transforms.ToTensor(), 50 | transforms.Normalize((0.1307,), (0.3081,)) 51 | ])), 52 | batch_size=args.test_batch_size, shuffle=True, **kwargs) 53 | 54 | 55 | def one_hot(indexes, n_classes): 56 | result = torch.FloatTensor(indexes.size() + (n_classes,)) 57 | if args.cuda: 58 | result = result.cuda() 59 | result.zero_() 60 | indexes_rank = len(indexes.size()) 61 | result.scatter_( 62 | dim=indexes_rank, 63 | index=indexes.data.unsqueeze(dim=indexes_rank), 64 | value=1 65 | ) 66 | return Variable(result) 67 | 68 | 69 | class Net(nn.Module): 70 | def __init__(self): 71 | super(Net, self).__init__() 72 | self.conv1 = nn.Conv2d(1, 10, kernel_size=5) 73 | self.conv1_bn = nn.BatchNorm2d(10) 74 | self.conv2 = nn.Conv2d(10, 
20, kernel_size=5) 75 | self.conv2_bn = nn.BatchNorm2d(20) 76 | self.conv2_drop = nn.Dropout2d() 77 | if args.dni: 78 | self.backward_interface = dni.BackwardInterface(ConvSynthesizer()) 79 | self.fc1 = nn.Linear(320, 50) 80 | self.fc1_bn = nn.BatchNorm1d(50) 81 | self.fc2 = nn.Linear(50, 10) 82 | self.fc2_bn = nn.BatchNorm1d(10) 83 | 84 | def forward(self, x, y=None): 85 | x = F.relu(F.max_pool2d(self.conv1_bn(self.conv1(x)), 2)) 86 | x = F.max_pool2d(self.conv2_drop(self.conv2_bn(self.conv2(x))), 2) 87 | if args.dni and self.training: 88 | if args.context: 89 | context = one_hot(y, 10) 90 | else: 91 | context = None 92 | with dni.synthesizer_context(context): 93 | x = self.backward_interface(x) 94 | x = F.relu(x) 95 | x = x.view(-1, 320) 96 | x = F.relu(self.fc1_bn(self.fc1(x))) 97 | x = F.dropout(x, training=self.training) 98 | x = self.fc2_bn(self.fc2(x)) 99 | return F.log_softmax(x) 100 | 101 | 102 | class ConvSynthesizer(nn.Module): 103 | def __init__(self): 104 | super(ConvSynthesizer, self).__init__() 105 | self.input_trigger = nn.Conv2d(20, 20, kernel_size=5, padding=2) 106 | self.input_context = nn.Linear(10, 20) 107 | self.hidden = nn.Conv2d(20, 20, kernel_size=5, padding=2) 108 | self.output = nn.Conv2d(20, 20, kernel_size=5, padding=2) 109 | # zero-initialize the last layer, as in the paper 110 | nn.init.constant(self.output.weight, 0) 111 | 112 | def forward(self, trigger, context): 113 | x = self.input_trigger(trigger) 114 | if context is not None: 115 | x += ( 116 | self.input_context(context).unsqueeze(2) 117 | .unsqueeze(3) 118 | .expand_as(x) 119 | ) 120 | x = self.hidden(F.relu(x)) 121 | return self.output(F.relu(x)) 122 | 123 | 124 | model = Net() 125 | if args.cuda: 126 | model.cuda() 127 | 128 | optimizer = optim.Adam(model.parameters(), lr=args.lr) 129 | 130 | def train(epoch): 131 | model.train() 132 | for batch_idx, (data, target) in enumerate(train_loader): 133 | if args.cuda: 134 | data, target = data.cuda(), target.cuda() 135 | data, 
target = Variable(data), Variable(target) 136 | optimizer.zero_grad() 137 | output = model(data, target) 138 | loss = F.nll_loss(output, target) 139 | loss.backward() 140 | optimizer.step() 141 | if batch_idx % args.log_interval == 0: 142 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 143 | epoch, batch_idx * len(data), len(train_loader.dataset), 144 | 100. * batch_idx / len(train_loader), loss.data[0])) 145 | 146 | def test(): 147 | model.eval() 148 | test_loss = 0 149 | correct = 0 150 | for data, target in test_loader: 151 | if args.cuda: 152 | data, target = data.cuda(), target.cuda() 153 | data, target = Variable(data, volatile=True), Variable(target) 154 | output = model(data) 155 | test_loss += F.nll_loss(output, target, size_average=False).data[0] # sum up batch loss 156 | pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability 157 | correct += pred.eq(target.data.view_as(pred)).cpu().sum() 158 | 159 | test_loss /= len(test_loader.dataset) 160 | print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( 161 | test_loss, correct, len(test_loader.dataset), 162 | 100. * correct / len(test_loader.dataset))) 163 | 164 | 165 | for epoch in range(1, args.epochs + 1): 166 | train(epoch) 167 | test() 168 | -------------------------------------------------------------------------------- /examples/mnist-cnn/requirements.txt: -------------------------------------------------------------------------------- 1 | git+git://github.com/koz4k/dni-pytorch.git#egg=dni-pytorch 2 | torch 3 | torchvision 4 | -------------------------------------------------------------------------------- /examples/mnist-full-unlock/LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2017, 4 | All rights reserved. 
5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /examples/mnist-full-unlock/README.rst: -------------------------------------------------------------------------------- 1 | MNIST with MLP, full unlock 2 | --------------------------- 3 | 4 | This example illustrates how to use ``BidirectionalInterface`` to achieve 5 | a full unlock. 
Code is mostly copied from the official PyTorch 6 | MNIST example: 7 | https://github.com/pytorch/examples/blob/master/mnist/main.py 8 | 9 | The classification model is replaced by a multi-layer perceptron with two hidden 10 | layers of 256 neurons each, batch normalization after every layer, and ReLU 11 | activations. DNI is inserted between the last hidden layer and the 12 | output layer (before the activation). DNI predicts the input to the output layer 13 | from the input image, and the gradient of the last hidden layer's activation 14 | from that activation. 15 | 16 | The synthesizers for both the forward and the backward interface are MLPs with 17 | two hidden layers of 256 neurons and ReLU activations. 18 | 19 | To install requirements:: 20 | 21 | $ pip install -r requirements.txt 22 | 23 | To train with regular backpropagation:: 24 | 25 | $ python main.py 26 | 27 | To train with DNI (no label conditioning):: 28 | 29 | $ python main.py --dni 30 | 31 | To train with cDNI (label conditioning):: 32 | 33 | $ python main.py --dni --context 34 | -------------------------------------------------------------------------------- /examples/mnist-full-unlock/main.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.optim as optim 7 | from torchvision import datasets, transforms 8 | from torch.autograd import Variable 9 | import dni 10 | 11 | # Training settings 12 | parser = argparse.ArgumentParser(description='PyTorch MNIST Example') 13 | parser.add_argument('--batch-size', type=int, default=64, metavar='N', 14 | help='input batch size for training (default: 64)') 15 | parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', 16 | help='input batch size for testing (default: 1000)') 17 | parser.add_argument('--epochs', type=int, default=10, metavar='N', 18 |
help='number of epochs to train (default: 10)') 19 | parser.add_argument('--lr', type=float, default=0.001, metavar='LR', 20 | help='learning rate (default: 0.001)') 21 | parser.add_argument('--no-cuda', action='store_true', default=False, 22 | help='disables CUDA training') 23 | parser.add_argument('--seed', type=int, default=1, metavar='S', 24 | help='random seed (default: 1)') 25 | parser.add_argument('--log-interval', type=int, default=10, metavar='N', 26 | help='how many batches to wait before logging training status') 27 | parser.add_argument('--dni', action='store_true', default=False, 28 | help='enable DNI') 29 | parser.add_argument('--context', action='store_true', default=False, 30 | help='enable context (label conditioning) in DNI') 31 | args = parser.parse_args() 32 | args.cuda = not args.no_cuda and torch.cuda.is_available() 33 | 34 | torch.manual_seed(args.seed) 35 | if args.cuda: 36 | torch.cuda.manual_seed(args.seed) 37 | 38 | 39 | kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {} 40 | train_loader = torch.utils.data.DataLoader( 41 | datasets.MNIST('../data', train=True, download=True, 42 | transform=transforms.Compose([ 43 | transforms.ToTensor(), 44 | transforms.Normalize((0.1307,), (0.3081,)) 45 | ])), 46 | batch_size=args.batch_size, shuffle=True, **kwargs) 47 | test_loader = torch.utils.data.DataLoader( 48 | datasets.MNIST('../data', train=False, transform=transforms.Compose([ 49 | transforms.ToTensor(), 50 | transforms.Normalize((0.1307,), (0.3081,)) 51 | ])), 52 | batch_size=args.test_batch_size, shuffle=True, **kwargs) 53 | 54 | 55 | def one_hot(indexes, n_classes): 56 | result = torch.FloatTensor(indexes.size() + (n_classes,)) 57 | if args.cuda: 58 | result = result.cuda() 59 | result.zero_() 60 | indexes_rank = len(indexes.size()) 61 | result.scatter_( 62 | dim=indexes_rank, 63 | index=indexes.data.unsqueeze(dim=indexes_rank), 64 | value=1 65 | ) 66 | return Variable(result) 67 | 68 | 69 | class Net(nn.Module): 70 | def 
__init__(self): 71 | super(Net, self).__init__() 72 | self.hidden1 = nn.Linear(784, 256, bias=False) 73 | self.hidden1_bn = nn.BatchNorm1d(256) 74 | self.hidden2 = nn.Linear(256, 256, bias=False) 75 | self.hidden2_bn = nn.BatchNorm1d(256) 76 | if args.dni: 77 | if args.context: 78 | context_dim = 10 79 | else: 80 | context_dim = None 81 | self.bidirectional_interface = dni.BidirectionalInterface( 82 | dni.BasicSynthesizer( 83 | output_dim=256, n_hidden=2, trigger_dim=784, 84 | context_dim=context_dim 85 | ), 86 | dni.BasicSynthesizer( 87 | output_dim=256, n_hidden=2, context_dim=context_dim 88 | ) 89 | ) 90 | self.output = nn.Linear(256, 10, bias=False) 91 | self.output_bn = nn.BatchNorm1d(10) 92 | 93 | def forward(self, x, y=None): 94 | input_flat = x.view(x.size()[0], -1) 95 | x = self.hidden1_bn(self.hidden1(input_flat)) 96 | x = self.hidden2_bn(self.hidden2(F.relu(x))) 97 | if args.dni and self.training: 98 | if args.context: 99 | context = one_hot(y, 10) 100 | else: 101 | context = None 102 | with dni.synthesizer_context(context): 103 | x = self.bidirectional_interface(x, input_flat) 104 | x = self.output_bn(self.output(F.relu(x))) 105 | return F.log_softmax(x) 106 | 107 | model = Net() 108 | if args.cuda: 109 | model.cuda() 110 | 111 | optimizer = optim.Adam(model.parameters(), lr=args.lr) 112 | 113 | def train(epoch): 114 | model.train() 115 | for batch_idx, (data, target) in enumerate(train_loader): 116 | if args.cuda: 117 | data, target = data.cuda(), target.cuda() 118 | data, target = Variable(data), Variable(target) 119 | optimizer.zero_grad() 120 | output = model(data, target) 121 | loss = F.nll_loss(output, target) 122 | loss.backward() 123 | optimizer.step() 124 | if batch_idx % args.log_interval == 0: 125 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 126 | epoch, batch_idx * len(data), len(train_loader.dataset), 127 | 100. 
* batch_idx / len(train_loader), loss.data[0])) 128 | 129 | def test(): 130 | model.eval() 131 | test_loss = 0 132 | correct = 0 133 | for data, target in test_loader: 134 | if args.cuda: 135 | data, target = data.cuda(), target.cuda() 136 | data, target = Variable(data, volatile=True), Variable(target) 137 | output = model(data) 138 | test_loss += F.nll_loss(output, target, size_average=False).data[0] # sum up batch loss 139 | pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability 140 | correct += pred.eq(target.data.view_as(pred)).cpu().sum() 141 | 142 | test_loss /= len(test_loader.dataset) 143 | print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( 144 | test_loss, correct, len(test_loader.dataset), 145 | 100. * correct / len(test_loader.dataset))) 146 | 147 | 148 | for epoch in range(1, args.epochs + 1): 149 | train(epoch) 150 | test() 151 | -------------------------------------------------------------------------------- /examples/mnist-full-unlock/requirements.txt: -------------------------------------------------------------------------------- 1 | git+git://github.com/koz4k/dni-pytorch.git#egg=dni-pytorch 2 | torch 3 | torchvision 4 | -------------------------------------------------------------------------------- /examples/mnist-mlp/LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2017, 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 
11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /examples/mnist-mlp/README.rst: -------------------------------------------------------------------------------- 1 | MNIST with MLP 2 | -------------- 3 | 4 | This example illustrates how to use ``BackwardInterface`` with a simple 5 | multi-layer perceptron. Code is mostly copied from the official PyTorch 6 | MNIST example: 7 | https://github.com/pytorch/examples/blob/master/mnist/main.py 8 | 9 | The classification model is replaced by a multi-layer perceptron with two hidden 10 | layers of 256 neurons each, batch normalization after every layer, and ReLU 11 | activations.
DNI is inserted between the last hidden layer and the 12 | output layer (before the activation). 13 | 14 | The synthesizer is an MLP with one hidden layer of 256 neurons and ReLU 15 | activations. 16 | 17 | To install requirements:: 18 | 19 | $ pip install -r requirements.txt 20 | 21 | To train with regular backpropagation:: 22 | 23 | $ python main.py 24 | 25 | To train with DNI (no label conditioning):: 26 | 27 | $ python main.py --dni 28 | 29 | To train with cDNI (label conditioning):: 30 | 31 | $ python main.py --dni --context 32 | -------------------------------------------------------------------------------- /examples/mnist-mlp/main.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.optim as optim 7 | from torchvision import datasets, transforms 8 | from torch.autograd import Variable 9 | import dni 10 | 11 | # Training settings 12 | parser = argparse.ArgumentParser(description='PyTorch MNIST Example') 13 | parser.add_argument('--batch-size', type=int, default=64, metavar='N', 14 | help='input batch size for training (default: 64)') 15 | parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', 16 | help='input batch size for testing (default: 1000)') 17 | parser.add_argument('--epochs', type=int, default=10, metavar='N', 18 | help='number of epochs to train (default: 10)') 19 | parser.add_argument('--lr', type=float, default=0.001, metavar='LR', 20 | help='learning rate (default: 0.001)') 21 | parser.add_argument('--no-cuda', action='store_true', default=False, 22 | help='disables CUDA training') 23 | parser.add_argument('--seed', type=int, default=1, metavar='S', 24 | help='random seed (default: 1)') 25 | parser.add_argument('--log-interval', type=int, default=10, metavar='N', 26 | help='how many batches to wait before logging training
status') 27 | parser.add_argument('--dni', action='store_true', default=False, 28 | help='enable DNI') 29 | parser.add_argument('--context', action='store_true', default=False, 30 | help='enable context (label conditioning) in DNI') 31 | args = parser.parse_args() 32 | args.cuda = not args.no_cuda and torch.cuda.is_available() 33 | 34 | torch.manual_seed(args.seed) 35 | if args.cuda: 36 | torch.cuda.manual_seed(args.seed) 37 | 38 | 39 | kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {} 40 | train_loader = torch.utils.data.DataLoader( 41 | datasets.MNIST('../data', train=True, download=True, 42 | transform=transforms.Compose([ 43 | transforms.ToTensor(), 44 | transforms.Normalize((0.1307,), (0.3081,)) 45 | ])), 46 | batch_size=args.batch_size, shuffle=True, **kwargs) 47 | test_loader = torch.utils.data.DataLoader( 48 | datasets.MNIST('../data', train=False, transform=transforms.Compose([ 49 | transforms.ToTensor(), 50 | transforms.Normalize((0.1307,), (0.3081,)) 51 | ])), 52 | batch_size=args.test_batch_size, shuffle=True, **kwargs) 53 | 54 | 55 | def one_hot(indexes, n_classes): 56 | result = torch.FloatTensor(indexes.size() + (n_classes,)) 57 | if args.cuda: 58 | result = result.cuda() 59 | result.zero_() 60 | indexes_rank = len(indexes.size()) 61 | result.scatter_( 62 | dim=indexes_rank, 63 | index=indexes.data.unsqueeze(dim=indexes_rank), 64 | value=1 65 | ) 66 | return Variable(result) 67 | 68 | 69 | class Net(nn.Module): 70 | def __init__(self): 71 | super(Net, self).__init__() 72 | self.hidden1 = nn.Linear(784, 256, bias=False) 73 | self.hidden1_bn = nn.BatchNorm1d(256) 74 | self.hidden2 = nn.Linear(256, 256, bias=False) 75 | self.hidden2_bn = nn.BatchNorm1d(256) 76 | if args.dni: 77 | if args.context: 78 | context_dim = 10 79 | else: 80 | context_dim = None 81 | self.backward_interface = dni.BackwardInterface( 82 | dni.BasicSynthesizer( 83 | output_dim=256, n_hidden=1, context_dim=context_dim 84 | ) 85 | ) 86 | self.output = nn.Linear(256, 
10, bias=False) 87 | self.output_bn = nn.BatchNorm1d(10) 88 | 89 | def forward(self, x, y=None): 90 | x = x.view(x.size()[0], -1) 91 | x = self.hidden1_bn(self.hidden1(x)) 92 | x = self.hidden2_bn(self.hidden2(F.relu(x))) 93 | if args.dni and self.training: 94 | if args.context: 95 | context = one_hot(y, 10) 96 | else: 97 | context = None 98 | with dni.synthesizer_context(context): 99 | x = self.backward_interface(x) 100 | x = self.output_bn(self.output(F.relu(x))) 101 | return F.log_softmax(x) 102 | 103 | model = Net() 104 | if args.cuda: 105 | model.cuda() 106 | 107 | optimizer = optim.Adam(model.parameters(), lr=args.lr) 108 | 109 | def train(epoch): 110 | model.train() 111 | for batch_idx, (data, target) in enumerate(train_loader): 112 | if args.cuda: 113 | data, target = data.cuda(), target.cuda() 114 | data, target = Variable(data), Variable(target) 115 | optimizer.zero_grad() 116 | output = model(data, target) 117 | loss = F.nll_loss(output, target) 118 | loss.backward() 119 | optimizer.step() 120 | if batch_idx % args.log_interval == 0: 121 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 122 | epoch, batch_idx * len(data), len(train_loader.dataset), 123 | 100. * batch_idx / len(train_loader), loss.data[0])) 124 | 125 | def test(): 126 | model.eval() 127 | test_loss = 0 128 | correct = 0 129 | for data, target in test_loader: 130 | if args.cuda: 131 | data, target = data.cuda(), target.cuda() 132 | data, target = Variable(data, volatile=True), Variable(target) 133 | output = model(data) 134 | test_loss += F.nll_loss(output, target, size_average=False).data[0] # sum up batch loss 135 | pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability 136 | correct += pred.eq(target.data.view_as(pred)).cpu().sum() 137 | 138 | test_loss /= len(test_loader.dataset) 139 | print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( 140 | test_loss, correct, len(test_loader.dataset), 141 | 100. 
* correct / len(test_loader.dataset))) 142 | 143 | 144 | for epoch in range(1, args.epochs + 1): 145 | train(epoch) 146 | test() 147 | -------------------------------------------------------------------------------- /examples/mnist-mlp/requirements.txt: -------------------------------------------------------------------------------- 1 | git+git://github.com/koz4k/dni-pytorch.git#egg=dni-pytorch 2 | torch 3 | torchvision 4 | -------------------------------------------------------------------------------- /examples/rnn/LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2017, 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /examples/rnn/README.rst: -------------------------------------------------------------------------------- 1 | Word-level language modeling 2 | ---------------------------- 3 | 4 | This example illustrates how to use ``BackwardInterface`` with an RNN to 5 | approximate the gradient from an infinitely-unrolled sequence. Code is mostly 6 | copied from the official PyTorch word-level language modeling example: 7 | https://github.com/pytorch/examples/blob/master/word_language_model 8 | 9 | The synthesizer is an MLP with two hidden layers and ReLU activations. 10 | 11 | In the example training commands below, the BPTT length is reduced to 5 to 12 | highlight the ability to train on shorter sequences using DNI.
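The mechanics behind this can be sketched without the library at all. The toy below is an invented illustration (the actual example uses ``dni.BackwardInterface`` with an LSTM; the scalar model and all names here are made up): a scalar linear RNN is trained chunk by chunk, and a linear "synthesizer" ``g(h) = a*h`` stands in for the gradient of the future loss with respect to the hidden state at each chunk boundary. The synthesizer is regressed toward the bootstrapped gradient computed one chunk later, as in the DNI paper.

```python
# Toy sketch of DNI for truncated BPTT in pure Python (no torch, no dni).
# Model: scalar linear RNN  h[t] = w*h[t-1] + x[t-1]  with per-step loss
# 0.5*(h[t] - 1)^2.  A linear synthesizer g(h) = a*h approximates the
# gradient of the future loss w.r.t. the hidden state at a chunk boundary.

def run_chunk(w, a, h0, xs):
    """One chunk of forward + manual backward; returns (h_end, dL/dw, dL/dh0).

    dL/dh0 includes the bootstrapped synthetic gradient a*h_end, so it can be
    used as the regression target for the synthesizer evaluated at h0.
    """
    hs = [h0]
    for x in xs:                   # forward pass through the chunk
        hs.append(w * hs[-1] + x)
    dw = 0.0
    dh = a * hs[-1]                # synthetic gradient at the chunk's last state
    for t in range(len(xs), 0, -1):
        dh += hs[t] - 1.0          # local loss gradient at h[t]
        dw += dh * hs[t - 1]       # through h[t] = w*h[t-1] + x[t-1]
        dh *= w                    # propagate to h[t-1]
    return hs[-1], dw, dh

w, a, h = 0.5, 0.0, 0.0
lr, synth_lr = 0.01, 0.01
for step in range(200):
    h_prev = h
    h, dw, dh0 = run_chunk(w, a, h, xs=[1.0] * 5)
    w -= lr * dw                   # update the model after every chunk
    # fit the synthesizer: regress g(h_prev) toward the bootstrapped target
    # dh0, which is treated as a constant
    a -= synth_lr * (a * h_prev - dh0) * h_prev
```

Without the synthetic term ``a*h_end``, each chunk's update would ignore how its final hidden state influences the loss of all later chunks; the synthesizer is what restores that cross-chunk signal, which is the point of using DNI with short BPTT windows.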
13 | 14 | To install requirements:: 15 | 16 | $ pip install -r requirements.txt 17 | 18 | To train with regular backpropagation through time:: 19 | 20 | $ python main.py --cuda --bptt 5 --epochs 6 21 | 22 | To train with DNI:: 23 | 24 | $ python main.py --cuda --bptt 5 --epochs 6 --dni 25 | -------------------------------------------------------------------------------- /examples/rnn/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | class Dictionary(object): 5 | def __init__(self): 6 | self.word2idx = {} 7 | self.idx2word = [] 8 | 9 | def add_word(self, word): 10 | if word not in self.word2idx: 11 | self.idx2word.append(word) 12 | self.word2idx[word] = len(self.idx2word) - 1 13 | return self.word2idx[word] 14 | 15 | def __len__(self): 16 | return len(self.idx2word) 17 | 18 | 19 | class Corpus(object): 20 | def __init__(self, path): 21 | self.dictionary = Dictionary() 22 | self.train = self.tokenize(os.path.join(path, 'train.txt')) 23 | self.valid = self.tokenize(os.path.join(path, 'valid.txt')) 24 | self.test = self.tokenize(os.path.join(path, 'test.txt')) 25 | 26 | def tokenize(self, path): 27 | """Tokenizes a text file.""" 28 | assert os.path.exists(path) 29 | # Add words to the dictionary 30 | with open(path, 'r') as f: 31 | tokens = 0 32 | for line in f: 33 | words = line.split() + ['<eos>'] 34 | tokens += len(words) 35 | for word in words: 36 | self.dictionary.add_word(word) 37 | 38 | # Tokenize file content 39 | with open(path, 'r') as f: 40 | ids = torch.LongTensor(tokens) 41 | token = 0 42 | for line in f: 43 | words = line.split() + ['<eos>'] 44 | for word in words: 45 | ids[token] = self.dictionary.word2idx[word] 46 | token += 1 47 | 48 | return ids 49 | -------------------------------------------------------------------------------- /examples/rnn/generate.py: -------------------------------------------------------------------------------- 1 |
###############################################################################
2 | # Language Modeling on Penn Tree Bank
3 | #
4 | # This file generates new sentences sampled from the language model
5 | #
6 | ###############################################################################
7 |
8 | import argparse
9 |
10 | import torch
11 | from torch.autograd import Variable
12 |
13 | import data
14 |
15 | parser = argparse.ArgumentParser(description='PyTorch PTB Language Model')
16 |
17 | # Model parameters.
18 | parser.add_argument('--data', type=str, default='./data/penn',
19 |                     help='location of the data corpus')
20 | parser.add_argument('--checkpoint', type=str, default='./model.pt',
21 |                     help='model checkpoint to use')
22 | parser.add_argument('--outf', type=str, default='generated.txt',
23 |                     help='output file for generated text')
24 | parser.add_argument('--words', type=int, default=1000,
25 |                     help='number of words to generate')
26 | parser.add_argument('--seed', type=int, default=1111,
27 |                     help='random seed')
28 | parser.add_argument('--cuda', action='store_true',
29 |                     help='use CUDA')
30 | parser.add_argument('--temperature', type=float, default=1.0,
31 |                     help='temperature - higher will increase diversity')
32 | parser.add_argument('--log-interval', type=int, default=100,
33 |                     help='reporting interval')
34 | args = parser.parse_args()
35 |
36 | # Set the random seed manually for reproducibility.
37 | torch.manual_seed(args.seed)
38 | if torch.cuda.is_available():
39 |     if not args.cuda:
40 |         print("WARNING: You have a CUDA device, so you should probably run with --cuda")
41 |     else:
42 |         torch.cuda.manual_seed(args.seed)
43 |
44 | if args.temperature < 1e-3:
45 |     parser.error("--temperature has to be greater than or equal to 1e-3")
46 |
47 | with open(args.checkpoint, 'rb') as f:
48 |     model = torch.load(f)
49 | model.eval()
50 |
51 | if args.cuda:
52 |     model.cuda()
53 | else:
54 |     model.cpu()
55 |
56 | corpus = data.Corpus(args.data)
57 | ntokens = len(corpus.dictionary)
58 | hidden = model.init_hidden(1)
59 | input = Variable(torch.rand(1, 1).mul(ntokens).long(), volatile=True)
60 | if args.cuda:
61 |     input.data = input.data.cuda()
62 |
63 | with open(args.outf, 'w') as outf:
64 |     for i in range(args.words):
65 |         output, hidden = model(input, hidden)
66 |         word_weights = output.squeeze().data.div(args.temperature).exp().cpu()
67 |         word_idx = torch.multinomial(word_weights, 1)[0]
68 |         input.data.fill_(word_idx)
69 |         word = corpus.dictionary.idx2word[word_idx]
70 |
71 |         outf.write(word + ('\n' if i % 20 == 19 else ' '))
72 |
73 |         if i % args.log_interval == 0:
74 |             print('| Generated {}/{} words'.format(i, args.words))
75 |
--------------------------------------------------------------------------------
/examples/rnn/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | import math
4 | import torch
5 | import torch.nn as nn
6 | from torch.autograd import Variable
7 | import dni
8 |
9 | import data
10 | import model
11 |
12 | parser = argparse.ArgumentParser(description='PyTorch PennTreeBank RNN/LSTM Language Model')
13 | parser.add_argument('--data', type=str, default='./data/penn',
14 |                     help='location of the data corpus')
15 | parser.add_argument('--model', type=str, default='LSTM',
16 |                     help='type of recurrent net (RNN_TANH, RNN_RELU, LSTM, GRU)')
17 | parser.add_argument('--emsize',
type=int, default=200, 18 | help='size of word embeddings') 19 | parser.add_argument('--nhid', type=int, default=200, 20 | help='number of hidden units per layer') 21 | parser.add_argument('--nlayers', type=int, default=2, 22 | help='number of layers') 23 | parser.add_argument('--lr', type=float, default=0.001, 24 | help='initial learning rate') 25 | parser.add_argument('--clip', type=float, default=0.25, 26 | help='gradient clipping') 27 | parser.add_argument('--epochs', type=int, default=40, 28 | help='upper epoch limit') 29 | parser.add_argument('--batch_size', type=int, default=20, metavar='N', 30 | help='batch size') 31 | parser.add_argument('--bptt', type=int, default=35, 32 | help='sequence length') 33 | parser.add_argument('--dropout', type=float, default=0.2, 34 | help='dropout applied to layers (0 = no dropout)') 35 | parser.add_argument('--tied', action='store_true', 36 | help='tie the word embedding and softmax weights') 37 | parser.add_argument('--seed', type=int, default=1111, 38 | help='random seed') 39 | parser.add_argument('--cuda', action='store_true', 40 | help='use CUDA') 41 | parser.add_argument('--log-interval', type=int, default=200, metavar='N', 42 | help='report interval') 43 | parser.add_argument('--save', type=str, default='model.pt', 44 | help='path to save the final model') 45 | parser.add_argument('--dni', action='store_true', default=False, 46 | help='enable DNI') 47 | args = parser.parse_args() 48 | 49 | # Set the random seed manually for reproducibility. 
50 | torch.manual_seed(args.seed) 51 | if torch.cuda.is_available(): 52 | if not args.cuda: 53 | print("WARNING: You have a CUDA device, so you should probably run with --cuda") 54 | else: 55 | torch.cuda.manual_seed(args.seed) 56 | 57 | ############################################################################### 58 | # Load data 59 | ############################################################################### 60 | 61 | corpus = data.Corpus(args.data) 62 | 63 | def batchify(data, bsz): 64 | # Work out how cleanly we can divide the dataset into bsz parts. 65 | nbatch = data.size(0) // bsz 66 | # Trim off any extra elements that wouldn't cleanly fit (remainders). 67 | data = data.narrow(0, 0, nbatch * bsz) 68 | # Evenly divide the data across the bsz batches. 69 | data = data.view(bsz, -1).t().contiguous() 70 | if args.cuda: 71 | data = data.cuda() 72 | return data 73 | 74 | eval_batch_size = 10 75 | train_data = batchify(corpus.train, args.batch_size) 76 | val_data = batchify(corpus.valid, eval_batch_size) 77 | test_data = batchify(corpus.test, eval_batch_size) 78 | 79 | ############################################################################### 80 | # Build the model 81 | ############################################################################### 82 | 83 | ntokens = len(corpus.dictionary) 84 | model = model.RNNModel(args.model, ntokens, args.emsize, args.nhid, args.nlayers, args.dropout, args.tied, args.dni) 85 | if args.cuda: 86 | model.cuda() 87 | 88 | criterion = nn.CrossEntropyLoss() 89 | 90 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 91 | 92 | ############################################################################### 93 | # Training code 94 | ############################################################################### 95 | 96 | def repackage_hidden(h): 97 | """Wraps hidden states in new Variables, to detach them from their history.""" 98 | if type(h) == Variable: 99 | return Variable(h.data) 100 | else: 101 | return 
tuple(repackage_hidden(v) for v in h) 102 | 103 | 104 | def get_batch(source, i, evaluation=False): 105 | seq_len = min(args.bptt, len(source) - 1 - i) 106 | data = Variable(source[i:i+seq_len], volatile=evaluation) 107 | target = Variable(source[i+1:i+1+seq_len].view(-1)) 108 | return data, target 109 | 110 | 111 | def evaluate(data_source): 112 | # Turn on evaluation mode which disables dropout. 113 | model.eval() 114 | total_loss = 0 115 | ntokens = len(corpus.dictionary) 116 | hidden = model.init_hidden(eval_batch_size) 117 | for i in range(0, data_source.size(0) - 1, args.bptt): 118 | data, targets = get_batch(data_source, i, evaluation=True) 119 | output, hidden = model(data, hidden) 120 | output_flat = output.view(-1, ntokens) 121 | total_loss += len(data) * criterion(output_flat, targets).data 122 | hidden = repackage_hidden(hidden) 123 | return total_loss[0] / len(data_source) 124 | 125 | 126 | def train(): 127 | # Turn on training mode which enables dropout. 128 | model.train() 129 | total_loss = 0 130 | start_time = time.time() 131 | ntokens = len(corpus.dictionary) 132 | hidden = model.init_hidden(args.batch_size) 133 | for batch, i in enumerate(range(0, train_data.size(0) - 1, args.bptt)): 134 | data, targets = get_batch(train_data, i) 135 | # Starting each batch, we detach the hidden state from how it was previously produced. 136 | # If we didn't, the model would try backpropagating all the way to start of the dataset. 137 | hidden = repackage_hidden(hidden) 138 | optimizer.zero_grad() 139 | with dni.defer_backward(): 140 | output, hidden = model(data, hidden) 141 | loss = criterion(output.view(-1, ntokens), targets) 142 | dni.backward(loss) 143 | 144 | # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs. 
145 |         torch.nn.utils.clip_grad_norm(model.parameters(), args.clip)
146 |         optimizer.step()
147 |
148 |         total_loss += loss.data
149 |
150 |         if batch % args.log_interval == 0 and batch > 0:
151 |             cur_loss = total_loss[0] / args.log_interval
152 |             elapsed = time.time() - start_time
153 |             print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:.4f} | ms/batch {:5.2f} | '
154 |                   'loss {:5.2f} | ppl {:8.2f}'.format(
155 |                 epoch, batch, len(train_data) // args.bptt, lr,
156 |                 elapsed * 1000 / args.log_interval, cur_loss, math.exp(cur_loss)))
157 |             total_loss = 0
158 |             start_time = time.time()
159 |
160 | # Loop over epochs.
161 | lr = args.lr
162 | best_val_loss = None
163 |
164 | # At any point you can hit Ctrl + C to break out of training early.
165 | try:
166 |     for epoch in range(1, args.epochs+1):
167 |         epoch_start_time = time.time()
168 |         train()
169 |         val_loss = evaluate(val_data)
170 |         print('-' * 89)
171 |         print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
172 |               'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
173 |                                          val_loss, math.exp(val_loss)))
174 |         print('-' * 89)
175 |         # Save the model if the validation loss is the best we've seen so far.
176 |         if not best_val_loss or val_loss < best_val_loss:
177 |             with open(args.save, 'wb') as f:
178 |                 torch.save(model, f)
179 |             best_val_loss = val_loss
180 |         else:
181 |             # Anneal the learning rate (and propagate it to the optimizer) if no improvement has been seen in the validation dataset.
182 |             lr /= 4.0; optimizer.param_groups[0]['lr'] = lr
183 | except KeyboardInterrupt:
184 |     print('-' * 89)
185 |     print('Exiting from training early')
186 |
187 | # Load the best saved model.
188 | with open(args.save, 'rb') as f:
189 |     model = torch.load(f)
190 |
191 | # Run on test data.
192 | test_loss = evaluate(test_data) 193 | print('=' * 89) 194 | print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format( 195 | test_loss, math.exp(test_loss))) 196 | print('=' * 89) 197 | -------------------------------------------------------------------------------- /examples/rnn/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import dni 5 | 6 | class RNNModel(nn.Module): 7 | """Container module with an encoder, a recurrent module, and a decoder.""" 8 | 9 | def __init__(self, rnn_type, ntoken, ninp, nhid, nlayers, dropout=0.5, tie_weights=False, use_dni=False): 10 | super(RNNModel, self).__init__() 11 | self.drop = nn.Dropout(dropout) 12 | self.encoder = nn.Embedding(ntoken, ninp) 13 | if rnn_type in ['LSTM', 'GRU']: 14 | self.rnn = getattr(nn, rnn_type)(ninp, nhid, nlayers, dropout=dropout) 15 | else: 16 | try: 17 | nonlinearity = {'RNN_TANH': 'tanh', 'RNN_RELU': 'relu'}[rnn_type] 18 | except KeyError: 19 | raise ValueError( """An invalid option for `--model` was supplied, 20 | options are ['LSTM', 'GRU', 'RNN_TANH' or 'RNN_RELU']""") 21 | self.rnn = nn.RNN(ninp, nhid, nlayers, nonlinearity=nonlinearity, dropout=dropout) 22 | self.decoder = nn.Linear(nhid, ntoken) 23 | 24 | # Optionally tie weights as in: 25 | # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016) 26 | # https://arxiv.org/abs/1608.05859 27 | # and 28 | # "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al. 
2016) 29 | # https://arxiv.org/abs/1611.01462 30 | if tie_weights: 31 | if nhid != ninp: 32 | raise ValueError('When using the tied flag, nhid must be equal to emsize') 33 | self.decoder.weight = self.encoder.weight 34 | 35 | self.init_weights() 36 | 37 | self.rnn_type = rnn_type 38 | self.nhid = nhid 39 | self.nlayers = nlayers 40 | 41 | if use_dni: 42 | if rnn_type == 'LSTM': 43 | output_dim = 2 * nhid 44 | else: 45 | output_dim = nhid 46 | self.backward_interface = dni.BackwardInterface( 47 | dni.BasicSynthesizer(output_dim, n_hidden=2) 48 | ) 49 | else: 50 | self.backward_interface = None 51 | 52 | def init_weights(self): 53 | initrange = 0.1 54 | self.encoder.weight.data.uniform_(-initrange, initrange) 55 | self.decoder.bias.data.fill_(0) 56 | self.decoder.weight.data.uniform_(-initrange, initrange) 57 | 58 | def join_hidden(self, hidden): 59 | if self.rnn_type == 'LSTM': 60 | hidden = torch.cat(hidden, dim=2) 61 | return hidden 62 | 63 | def split_hidden(self, hidden): 64 | if self.rnn_type == 'LSTM': 65 | (h, c) = hidden.chunk(2, dim=2) 66 | hidden = (h.contiguous(), c.contiguous()) 67 | return hidden 68 | 69 | def forward(self, input, hidden): 70 | emb = self.drop(self.encoder(input)) 71 | if self.backward_interface is not None: 72 | # for LSTM, predict gradient for both cell state and output 73 | # to do that, concatenate them before feeding to DNI 74 | hidden = self.join_hidden(hidden) 75 | hidden = self.backward_interface.make_trigger(hidden) 76 | hidden = self.split_hidden(hidden) 77 | output, hidden = self.rnn(emb, hidden) 78 | if self.backward_interface is not None: 79 | hidden = self.join_hidden(hidden) 80 | # scale synthetic gradient by a factor of 0.1, as in the paper 81 | self.backward_interface.backward(hidden, factor=0.1) 82 | hidden = self.split_hidden(hidden) 83 | output = self.drop(output) 84 | decoded = self.decoder(output.view(output.size(0)*output.size(1), output.size(2))) 85 | return decoded.view(output.size(0), output.size(1), 
decoded.size(1)), hidden
86 |
87 |     def init_hidden(self, bsz):
88 |         weight = next(self.parameters()).data
89 |         if self.rnn_type == 'LSTM':
90 |             return (Variable(weight.new(self.nlayers, bsz, self.nhid).zero_()),
91 |                     Variable(weight.new(self.nlayers, bsz, self.nhid).zero_()))
92 |         else:
93 |             return Variable(weight.new(self.nlayers, bsz, self.nhid).zero_())
94 |
--------------------------------------------------------------------------------
/examples/rnn/requirements.txt:
--------------------------------------------------------------------------------
1 | git+https://github.com/koz4k/dni-pytorch.git#egg=dni-pytorch
2 | torch
3 |
--------------------------------------------------------------------------------
/images/feedforward-complete-unlock.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/koz4k/dni-pytorch/b14a07cab049a72c104b57dcdb56730ab8dbafb1/images/feedforward-complete-unlock.png
--------------------------------------------------------------------------------
/images/feedforward-update-unlock.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/koz4k/dni-pytorch/b14a07cab049a72c104b57dcdb56730ab8dbafb1/images/feedforward-update-unlock.png
--------------------------------------------------------------------------------
/images/rnn-update-unlock.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/koz4k/dni-pytorch/b14a07cab049a72c104b57dcdb56730ab8dbafb1/images/rnn-update-unlock.png
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 |
3 |
4 | setup(
5 |     name='dni-pytorch',
6 |     version='0.1.0',
7 |     author='Piotr Kozakowski',
8 |     author_email='kozak000@gmail.com',
9 |
url='https://github.com/koz4k/dni-pytorch', 10 | description=( 11 | 'Decoupled Neural Interfaces using Synthetic Gradients for PyTorch' 12 | ), 13 | py_modules=['dni'], 14 | install_requires=[ 15 | 'torch>=0.2.0' 16 | ] 17 | ) 18 | --------------------------------------------------------------------------------
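The synthetic-gradient mechanism that ``dni.py`` provides (``BackwardInterface``, ``BasicSynthesizer``, ``make_trigger``, ``defer_backward``, ``dni.backward``, as used in ``examples/rnn``) can be illustrated with a tiny self-contained toy. This is a hedged sketch of the *idea*, not the library's API: the scalar two-layer "network", the linear synthesizer ``m * h + b``, and all constants below are invented for demonstration and do not appear in the package.

```python
import random

# Toy sketch of Decoupled Neural Interfaces (DNI), without the dni package:
# a two-layer scalar "network"
#     h = w1 * x,   y = w2 * h,   loss = 0.5 * (y - t)**2
# Layer 1 is updated from a *synthetic* gradient m * h + b before the true
# gradient exists; the synthesizer (m, b) is then regressed toward the true
# gradient dL/dh once it becomes available.

def train_with_synthetic_gradient(steps=5000, lr=0.02, seed=0):
    rng = random.Random(seed)
    w1, w2 = 0.7, 0.3        # model weights (illustrative init)
    m, b = 0.0, 0.0          # linear synthesizer: predicts dL/dh from h
    losses = []
    for _ in range(steps):
        x = rng.uniform(-1.0, 1.0)
        t = 2.0 * x                      # ground truth: y should equal 2x
        h = w1 * x
        y = w2 * h
        losses.append(0.5 * (y - t) ** 2)
        # Update layer 1 immediately using the synthesizer's prediction
        # (the "update unlock" from the paper).
        synthetic_dh = m * h + b
        w1 -= lr * synthetic_dh * x      # chain rule: dL/dw1 = dL/dh * x
        # The true gradient becomes available once the loss is computed.
        dy = y - t                       # dL/dy
        true_dh = w2 * dy                # dL/dh
        w2 -= lr * dy * h                # ordinary update for layer 2
        # Train the synthesizer to match the true gradient.
        err = synthetic_dh - true_dh
        m -= lr * err * h
        b -= lr * err
    # Return mean loss over the first and last 100 steps.
    return sum(losses[:100]) / 100, sum(losses[-100:]) / 100

first, last = train_with_synthetic_gradient()
print('mean loss over first/last 100 steps: %.3f -> %.5f' % (first, last))
```

Loosely, the "update layer 1 from the synthetic gradient" step plays the role of ``backward_interface.make_trigger`` in ``model.py``, and regressing ``(m, b)`` toward the true gradient plays the role of the synthesizer training that ``backward_interface.backward(hidden)`` performs once the real gradient arrives.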