├── .gitignore ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── dnc ├── __init__.py ├── access.py ├── access_test.py ├── addressing.py ├── addressing_test.py ├── dnc.py ├── repeat_copy.py ├── util.py └── util_test.py ├── images └── dnc_model.png ├── setup.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *.cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | 56 | # Flask stuff: 57 | instance/ 58 | .webassets-cache 59 | 60 | # Scrapy stuff: 61 | .scrapy 62 | 63 | # Sphinx documentation 64 | docs/_build/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # Jupyter Notebook 70 | .ipynb_checkpoints 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # celery beat schedule file 76 | celerybeat-schedule 77 | 78 | # SageMath parsed files 79 | *.sage.py 80 | 81 | # Environments 82 | .env 83 | .venv 84 | env/ 85 | venv/ 86 | ENV/ 87 | 88 | # Spyder project settings 89 | .spyderproject 90 | .spyproject 91 | 92 | # Rope project settings 93 | .ropeproject 94 | 95 | # mkdocs documentation 96 | /site 97 | 98 | # mypy 99 | .mypy_cache/ 100 | 101 | # vscode and its extensions 102 | .vscode/* 103 | .history/* -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement. You (or your employer) retain the copyright to your contribution; 10 | this simply gives us permission to use and redistribute your contributions as 11 | part of the project. Head over to https://cla.developers.google.com/ to see 12 | your current agreements on file or to sign a new one. 13 | 14 | You generally only need to submit a CLA once, so if you've already submitted one 15 | (even if it was for a different project), you probably don't need to do it 16 | again. 17 | 18 | ## Code reviews 19 | 20 | All submissions, including submissions by project members, require review. We 21 | use GitHub pull requests for this purpose. Consult [GitHub Help] for more 22 | information on using pull requests.
23 | 24 | [GitHub Help]: https://help.github.com/articles/about-pull-requests/ 25 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 
62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 
180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Differentiable Neural Computer (DNC) 2 | 3 | This package provides an implementation of the Differentiable Neural Computer, 4 | as [published in Nature]( 5 | https://www.nature.com/articles/nature20101.epdf?author_access_token=ImTXBI8aWbYxYQ51Plys8NRgN0jAjWel9jnR3ZoTv0MggmpDmwljGswxVdeocYSurJ3hxupzWuRNeGvvXnoO8o4jTJcnAyhGuZzXJ1GEaD-Z7E6X_a9R-xqJ9TfJWBqz). 6 | 7 | Any publication that discloses findings arising from using this source code must 8 | cite “Hybrid computing using a neural network with dynamic external memory”, 9 | Nature 538, 471–476 (October 2016) doi:10.1038/nature20101. 10 | 11 | ## Introduction 12 | 13 | The Differentiable Neural Computer is a recurrent neural network. At each 14 | timestep, it has state consisting of the current memory contents (and auxiliary 15 | information such as memory usage), and maps input at time `t` to output at time 16 | `t`. It is implemented as a collection of `RNNCore` modules, which allow 17 | plugging together the different modules to experiment with variations on the 18 | architecture. 19 | 20 | * The *access* module is where the main DNC logic happens, as this is where 21 | memory is written to and read from. At every timestep, the input to an 22 | access module is a vector passed from the `controller`, and its output is 23 | the contents read from memory. It uses two further `RNNCore`s: 24 | `TemporalLinkage` which tracks the order of memory writes, and `Freeness` 25 | which tracks which memory locations have been written to and not yet 26 | subsequently "freed". These are both defined in `addressing.py`. 27 | 28 | * The *controller* module "controls" memory access. Typically, it is just a 29 | feedforward or (possibly deep) LSTM network, whose inputs are the inputs to 30 | the overall recurrent network at that time, concatenated with the read 31 | memory output from the access module from the previous timestep. 32 | 33 | * The *dnc* simply wraps the access module and the controller module, and forms 34 | the basic `RNNCore` unit of the overall architecture. This is defined in 35 | `dnc.py`; a minimal usage sketch follows below.
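Concretely, these cores plug together roughly as in the sketch below. This is an illustrative example rather than code taken from the repository: the `from dnc import dnc` import path, the `DNC` constructor arguments shown here (`access_config`, `controller_config`, `output_size`, `clip_value`) and the `hidden_size` controller key are assumptions — see `dnc.py` and `train.py` for the authoritative interface.

```python
import tensorflow as tf

from dnc import dnc  # assumed import path; see train.py

# Hypothetical configuration; only the access keys mirror MemoryAccess in access.py.
access_config = {
    'memory_size': 16,  # N: number of memory slots
    'word_size': 16,    # W: width of each memory slot
    'num_reads': 4,     # R: number of read heads
    'num_writes': 1,    # number of write heads
}
controller_config = {'hidden_size': 64}

# The DNC core behaves like any other RNNCore and is unrolled with standard
# TensorFlow RNN ops.
dnc_core = dnc.DNC(access_config, controller_config,
                   output_size=10, clip_value=20)
initial_state = dnc_core.initial_state(batch_size=16)

input_sequence = tf.random_normal([20, 16, 4])  # [time, batch, features]
output_sequence, final_state = tf.nn.dynamic_rnn(
    cell=dnc_core,
    inputs=input_sequence,
    initial_state=initial_state,
    time_major=True)
```

The access keys (`memory_size`, `word_size`, `num_reads`, `num_writes`) mirror the constructor of `MemoryAccess` in `access.py`; the remaining arguments should be checked against the actual `DNC` signature before use.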
36 | 37 | ![DNC architecture](images/dnc_model.png) 38 | 39 | ## Train 40 | The `DNC` requires an installation of [TensorFlow](https://www.tensorflow.org/) 41 | and [Sonnet](https://github.com/deepmind/sonnet). An example training script is 42 | provided for the algorithmic task of repeatedly copying a given input string. 43 | This can be executed from a python interpreter: 44 | 45 | ```shell 46 | $ ipython train.py 47 | ``` 48 | 49 | You can specify training options, including parameters to the model 50 | and optimizer, via flags: 51 | 52 | ```shell 53 | $ python train.py --memory_size=64 --num_bits=8 --max_length=3 54 | 55 | # Or with ipython: 56 | $ ipython train.py -- --memory_size=64 --num_bits=8 --max_length=3 57 | ``` 58 | 59 | Periodically saving, or 'checkpointing', the model is disabled by default. To 60 | enable, use the `checkpoint_interval` flag. E.g. `--checkpoint_interval=10000` 61 | will ensure a checkpoint is created every `10,000` steps. The model will be 62 | checkpointed to `/tmp/tf/dnc/` by default. From there training can be resumed. 63 | To specify an alternate checkpoint directory, use the `checkpoint_dir` flag. 64 | Note: ensure that `/tmp/tf/dnc/` is deleted before training is resumed with 65 | different model parameters, to avoid shape inconsistency errors. 66 | 67 | More generally, the `DNC` class found within `dnc.py` can be used as a standard 68 | TensorFlow rnn core and unrolled with TensorFlow rnn ops, such as 69 | `tf.nn.dynamic_rnn` on any sequential task. 70 | 71 | Disclaimer: This is not an official Google product 72 | -------------------------------------------------------------------------------- /dnc/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/dnc/f5981c66ac6fe27978014c314c4aba39f1962ee3/dnc/__init__.py -------------------------------------------------------------------------------- /dnc/access.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """DNC access modules.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import sonnet as snt 23 | import tensorflow as tf 24 | 25 | from dnc import addressing 26 | from dnc import util 27 | 28 | AccessState = collections.namedtuple('AccessState', ( 29 | 'memory', 'read_weights', 'write_weights', 'linkage', 'usage')) 30 | 31 | 32 | def _erase_and_write(memory, address, reset_weights, values): 33 | """Module to erase and write in the external memory. 34 | 35 | Erase operation: 36 | M_t'(i) = M_{t-1}(i) * (1 - w_t(i) * e_t) 37 | 38 | Add operation: 39 | M_t(i) = M_t'(i) + w_t(i) * a_t 40 | 41 | where e are the reset_weights, w the write weights and a the values. 
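As an illustrative single-write-head example: if w_t(i) = 0.5, e_t = 1 and a_t = v, then M_t(i) = 0.5 * M_{t-1}(i) + 0.5 * v, i.e. slot i moves halfway from its previous content towards the written value v.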
42 | 43 | Args: 44 | memory: 3-D tensor of shape `[batch_size, memory_size, word_size]`. 45 | address: 3-D tensor `[batch_size, num_writes, memory_size]`. 46 | reset_weights: 3-D tensor `[batch_size, num_writes, word_size]`. 47 | values: 3-D tensor `[batch_size, num_writes, word_size]`. 48 | 49 | Returns: 50 | 3-D tensor of shape `[batch_size, memory_size, word_size]` containing the updated memory. 51 | """ 52 | with tf.name_scope('erase_memory', values=[memory, address, reset_weights]): 53 | expand_address = tf.expand_dims(address, 3) 54 | reset_weights = tf.expand_dims(reset_weights, 2) 55 | weighted_resets = expand_address * reset_weights 56 | reset_gate = util.reduce_prod(1 - weighted_resets, 1) 57 | memory *= reset_gate 58 | 59 | with tf.name_scope('additive_write', values=[memory, address, values]): 60 | add_matrix = tf.matmul(address, values, adjoint_a=True) 61 | memory += add_matrix 62 | 63 | return memory 64 | 65 | 66 | class MemoryAccess(snt.RNNCore): 67 | """Access module of the Differentiable Neural Computer. 68 | 69 | This memory module supports multiple read and write heads. It makes use of: 70 | 71 | * `addressing.TemporalLinkage` to track the temporal ordering of writes in 72 | memory for each write head. 73 | * `addressing.Freeness` for keeping track of memory usage, where 74 | usage increases when a memory location is written to, and decreases when 75 | memory that the controller says can be freed is read from. 76 | 77 | Write-address selection is done by an interpolation between content-based 78 | lookup and using unused memory. 79 | 80 | Read-address selection is done by an interpolation of content-based lookup 81 | and following the link graph in the forward or backward read direction. 82 | """ 83 | 84 | def __init__(self, 85 | memory_size=128, 86 | word_size=20, 87 | num_reads=1, 88 | num_writes=1, 89 | name='memory_access'): 90 | """Creates a MemoryAccess module. 91 | 92 | Args: 93 | memory_size: The number of memory slots (N in the DNC paper). 94 | word_size: The width of each memory slot (W in the DNC paper). 95 | num_reads: The number of read heads (R in the DNC paper). 96 | num_writes: The number of write heads (fixed at 1 in the paper). 97 | name: The name of the module. 98 | """ 99 | super(MemoryAccess, self).__init__(name=name) 100 | self._memory_size = memory_size 101 | self._word_size = word_size 102 | self._num_reads = num_reads 103 | self._num_writes = num_writes 104 | 105 | self._write_content_weights_mod = addressing.CosineWeights( 106 | num_writes, word_size, name='write_content_weights') 107 | self._read_content_weights_mod = addressing.CosineWeights( 108 | num_reads, word_size, name='read_content_weights') 109 | 110 | self._linkage = addressing.TemporalLinkage(memory_size, num_writes) 111 | self._freeness = addressing.Freeness(memory_size) 112 | 113 | def _build(self, inputs, prev_state): 114 | """Connects the MemoryAccess module into the graph. 115 | 116 | Args: 117 | inputs: tensor of shape `[batch_size, input_size]`. This is used to 118 | control this access module. 119 | prev_state: Instance of `AccessState` containing the previous state. 120 | 121 | Returns: 122 | A tuple `(output, next_state)`, where `output` is a tensor of shape 123 | `[batch_size, num_reads, word_size]`, and `next_state` is the new 124 | `AccessState` named tuple at the current time t. 125 | """ 126 | inputs = self._read_inputs(inputs) 127 | 128 | # Update usage using inputs['free_gate'] and previous read & write weights.
129 | usage = self._freeness( 130 | write_weights=prev_state.write_weights, 131 | free_gate=inputs['free_gate'], 132 | read_weights=prev_state.read_weights, 133 | prev_usage=prev_state.usage) 134 | 135 | # Write to memory. 136 | write_weights = self._write_weights(inputs, prev_state.memory, usage) 137 | memory = _erase_and_write( 138 | prev_state.memory, 139 | address=write_weights, 140 | reset_weights=inputs['erase_vectors'], 141 | values=inputs['write_vectors']) 142 | 143 | linkage_state = self._linkage(write_weights, prev_state.linkage) 144 | 145 | # Read from memory. 146 | read_weights = self._read_weights( 147 | inputs, 148 | memory=memory, 149 | prev_read_weights=prev_state.read_weights, 150 | link=linkage_state.link) 151 | read_words = tf.matmul(read_weights, memory) 152 | 153 | return (read_words, AccessState( 154 | memory=memory, 155 | read_weights=read_weights, 156 | write_weights=write_weights, 157 | linkage=linkage_state, 158 | usage=usage)) 159 | 160 | def _read_inputs(self, inputs): 161 | """Applies transformations to `inputs` to get control for this module.""" 162 | 163 | def _linear(first_dim, second_dim, name, activation=None): 164 | """Returns a linear transformation of `inputs`, followed by a reshape.""" 165 | linear = snt.Linear(first_dim * second_dim, name=name)(inputs) 166 | if activation is not None: 167 | linear = activation(linear, name=name + '_activation') 168 | return tf.reshape(linear, [-1, first_dim, second_dim]) 169 | 170 | # v_t^i - The vectors to write to memory, for each write head `i`. 171 | write_vectors = _linear(self._num_writes, self._word_size, 'write_vectors') 172 | 173 | # e_t^i - Amount to erase the memory by before writing, for each write head. 174 | erase_vectors = _linear(self._num_writes, self._word_size, 'erase_vectors', 175 | tf.sigmoid) 176 | 177 | # f_t^j - Amount that the memory at the locations read from at the previous 178 | # time step can be declared unused, for each read head `j`. 179 | free_gate = tf.sigmoid( 180 | snt.Linear(self._num_reads, name='free_gate')(inputs)) 181 | 182 | # g_t^{a, i} - Interpolation between writing to unallocated memory and 183 | # content-based lookup, for each write head `i`. Note: `a` is simply used to 184 | # identify this gate with allocation vs writing (as defined below). 185 | allocation_gate = tf.sigmoid( 186 | snt.Linear(self._num_writes, name='allocation_gate')(inputs)) 187 | 188 | # g_t^{w, i} - Overall gating of write amount for each write head. 189 | write_gate = tf.sigmoid( 190 | snt.Linear(self._num_writes, name='write_gate')(inputs)) 191 | 192 | # \pi_t^j - Mixing between "backwards" and "forwards" positions (for 193 | # each write head), and content-based lookup, for each read head. 194 | num_read_modes = 1 + 2 * self._num_writes 195 | read_mode = snt.BatchApply(tf.nn.softmax)( 196 | _linear(self._num_reads, num_read_modes, name='read_mode')) 197 | 198 | # Parameters for the (read / write) "weights by content matching" modules. 
199 | write_keys = _linear(self._num_writes, self._word_size, 'write_keys') 200 | write_strengths = snt.Linear(self._num_writes, name='write_strengths')( 201 | inputs) 202 | 203 | read_keys = _linear(self._num_reads, self._word_size, 'read_keys') 204 | read_strengths = snt.Linear(self._num_reads, name='read_strengths')(inputs) 205 | 206 | result = { 207 | 'read_content_keys': read_keys, 208 | 'read_content_strengths': read_strengths, 209 | 'write_content_keys': write_keys, 210 | 'write_content_strengths': write_strengths, 211 | 'write_vectors': write_vectors, 212 | 'erase_vectors': erase_vectors, 213 | 'free_gate': free_gate, 214 | 'allocation_gate': allocation_gate, 215 | 'write_gate': write_gate, 216 | 'read_mode': read_mode, 217 | } 218 | return result 219 | 220 | def _write_weights(self, inputs, memory, usage): 221 | """Calculates the memory locations to write to. 222 | 223 | This uses a combination of content-based lookup and finding an unused 224 | location in memory, for each write head. 225 | 226 | Args: 227 | inputs: Collection of inputs to the access module, including controls for 228 | how to chose memory writing, such as the content to look-up and the 229 | weighting between content-based and allocation-based addressing. 230 | memory: A tensor of shape `[batch_size, memory_size, word_size]` 231 | containing the current memory contents. 232 | usage: Current memory usage, which is a tensor of shape `[batch_size, 233 | memory_size]`, used for allocation-based addressing. 234 | 235 | Returns: 236 | tensor of shape `[batch_size, num_writes, memory_size]` indicating where 237 | to write to (if anywhere) for each write head. 238 | """ 239 | with tf.name_scope('write_weights', values=[inputs, memory, usage]): 240 | # c_t^{w, i} - The content-based weights for each write head. 241 | write_content_weights = self._write_content_weights_mod( 242 | memory, inputs['write_content_keys'], 243 | inputs['write_content_strengths']) 244 | 245 | # a_t^i - The allocation weights for each write head. 246 | write_allocation_weights = self._freeness.write_allocation_weights( 247 | usage=usage, 248 | write_gates=(inputs['allocation_gate'] * inputs['write_gate']), 249 | num_writes=self._num_writes) 250 | 251 | # Expands gates over memory locations. 252 | allocation_gate = tf.expand_dims(inputs['allocation_gate'], -1) 253 | write_gate = tf.expand_dims(inputs['write_gate'], -1) 254 | 255 | # w_t^{w, i} - The write weightings for each write head. 256 | return write_gate * (allocation_gate * write_allocation_weights + 257 | (1 - allocation_gate) * write_content_weights) 258 | 259 | def _read_weights(self, inputs, memory, prev_read_weights, link): 260 | """Calculates read weights for each read head. 261 | 262 | The read weights are a combination of following the link graphs in the 263 | forward or backward directions from the previous read position, and doing 264 | content-based lookup. The interpolation between these different modes is 265 | done by `inputs['read_mode']`. 266 | 267 | Args: 268 | inputs: Controls for this access module. This contains the content-based 269 | keys to lookup, and the weightings for the different read modes. 270 | memory: A tensor of shape `[batch_size, memory_size, word_size]` 271 | containing the current memory contents to do content-based lookup. 272 | prev_read_weights: A tensor of shape `[batch_size, num_reads, 273 | memory_size]` containing the previous read locations. 
274 | link: A tensor of shape `[batch_size, num_writes, memory_size, 275 | memory_size]` containing the temporal write transition graphs. 276 | 277 | Returns: 278 | A tensor of shape `[batch_size, num_reads, memory_size]` containing the 279 | read weights for each read head. 280 | """ 281 | with tf.name_scope( 282 | 'read_weights', values=[inputs, memory, prev_read_weights, link]): 283 | # c_t^{r, i} - The content weightings for each read head. 284 | content_weights = self._read_content_weights_mod( 285 | memory, inputs['read_content_keys'], inputs['read_content_strengths']) 286 | 287 | # Calculates f_t^i and b_t^i. 288 | forward_weights = self._linkage.directional_read_weights( 289 | link, prev_read_weights, forward=True) 290 | backward_weights = self._linkage.directional_read_weights( 291 | link, prev_read_weights, forward=False) 292 | 293 | backward_mode = inputs['read_mode'][:, :, :self._num_writes] 294 | forward_mode = ( 295 | inputs['read_mode'][:, :, self._num_writes:2 * self._num_writes]) 296 | content_mode = inputs['read_mode'][:, :, 2 * self._num_writes] 297 | 298 | read_weights = ( 299 | tf.expand_dims(content_mode, 2) * content_weights + tf.reduce_sum( 300 | tf.expand_dims(forward_mode, 3) * forward_weights, 2) + 301 | tf.reduce_sum(tf.expand_dims(backward_mode, 3) * backward_weights, 2)) 302 | 303 | return read_weights 304 | 305 | @property 306 | def state_size(self): 307 | """Returns a tuple of the shape of the state tensors.""" 308 | return AccessState( 309 | memory=tf.TensorShape([self._memory_size, self._word_size]), 310 | read_weights=tf.TensorShape([self._num_reads, self._memory_size]), 311 | write_weights=tf.TensorShape([self._num_writes, self._memory_size]), 312 | linkage=self._linkage.state_size, 313 | usage=self._freeness.state_size) 314 | 315 | @property 316 | def output_size(self): 317 | """Returns the output shape.""" 318 | return tf.TensorShape([self._num_reads, self._word_size]) 319 | -------------------------------------------------------------------------------- /dnc/access_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for memory access.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import numpy as np 22 | import tensorflow as tf 23 | from tensorflow.python.ops import rnn 24 | 25 | from dnc import access 26 | from dnc import util 27 | 28 | BATCH_SIZE = 2 29 | MEMORY_SIZE = 20 30 | WORD_SIZE = 6 31 | NUM_READS = 2 32 | NUM_WRITES = 3 33 | TIME_STEPS = 4 34 | INPUT_SIZE = 10 35 | 36 | 37 | class MemoryAccessTest(tf.test.TestCase): 38 | 39 | def setUp(self): 40 | self.module = access.MemoryAccess(MEMORY_SIZE, WORD_SIZE, NUM_READS, 41 | NUM_WRITES) 42 | self.initial_state = self.module.initial_state(BATCH_SIZE) 43 | 44 | def testBuildAndTrain(self): 45 | inputs = tf.random_normal([TIME_STEPS, BATCH_SIZE, INPUT_SIZE]) 46 | 47 | output, _ = rnn.dynamic_rnn( 48 | cell=self.module, 49 | inputs=inputs, 50 | initial_state=self.initial_state, 51 | time_major=True) 52 | 53 | targets = np.random.rand(TIME_STEPS, BATCH_SIZE, NUM_READS, WORD_SIZE) 54 | loss = tf.reduce_mean(tf.square(output - targets)) 55 | train_op = tf.train.GradientDescentOptimizer(1).minimize(loss) 56 | init = tf.global_variables_initializer() 57 | 58 | with self.test_session(): 59 | init.run() 60 | train_op.run() 61 | 62 | def testValidReadMode(self): 63 | inputs = self.module._read_inputs( 64 | tf.random_normal([BATCH_SIZE, INPUT_SIZE])) 65 | init = tf.global_variables_initializer() 66 | 67 | with self.test_session() as sess: 68 | init.run() 69 | inputs = sess.run(inputs) 70 | 71 | # Check that the read modes for each read head constitute a probability 72 | # distribution. 73 | self.assertAllClose(inputs['read_mode'].sum(2), 74 | np.ones([BATCH_SIZE, NUM_READS])) 75 | self.assertGreaterEqual(inputs['read_mode'].min(), 0) 76 | 77 | def testWriteWeights(self): 78 | memory = 10 * (np.random.rand(BATCH_SIZE, MEMORY_SIZE, WORD_SIZE) - 0.5) 79 | usage = np.random.rand(BATCH_SIZE, MEMORY_SIZE) 80 | 81 | allocation_gate = np.random.rand(BATCH_SIZE, NUM_WRITES) 82 | write_gate = np.random.rand(BATCH_SIZE, NUM_WRITES) 83 | write_content_keys = np.random.rand(BATCH_SIZE, NUM_WRITES, WORD_SIZE) 84 | write_content_strengths = np.random.rand(BATCH_SIZE, NUM_WRITES) 85 | 86 | # Check that turning on allocation gate fully brings the write gate to 87 | # the allocation weighting (which we will control by controlling the usage). 88 | usage[:, 3] = 0 89 | allocation_gate[:, 0] = 1 90 | write_gate[:, 0] = 1 91 | 92 | inputs = { 93 | 'allocation_gate': tf.constant(allocation_gate), 94 | 'write_gate': tf.constant(write_gate), 95 | 'write_content_keys': tf.constant(write_content_keys), 96 | 'write_content_strengths': tf.constant(write_content_strengths) 97 | } 98 | 99 | weights = self.module._write_weights(inputs, 100 | tf.constant(memory), 101 | tf.constant(usage)) 102 | 103 | with self.test_session(): 104 | weights = weights.eval() 105 | 106 | # Check the weights sum to their target gating. 107 | self.assertAllClose(np.sum(weights, axis=2), write_gate, atol=5e-2) 108 | 109 | # Check that we fully allocated to the third row. 
110 | weights_0_0_target = util.one_hot(MEMORY_SIZE, 3) 111 | self.assertAllClose(weights[0, 0], weights_0_0_target, atol=1e-3) 112 | 113 | def testReadWeights(self): 114 | memory = 10 * (np.random.rand(BATCH_SIZE, MEMORY_SIZE, WORD_SIZE) - 0.5) 115 | prev_read_weights = np.random.rand(BATCH_SIZE, NUM_READS, MEMORY_SIZE) 116 | prev_read_weights /= prev_read_weights.sum(2, keepdims=True) + 1 117 | 118 | link = np.random.rand(BATCH_SIZE, NUM_WRITES, MEMORY_SIZE, MEMORY_SIZE) 119 | # Row and column sums should be at most 1: 120 | link /= np.maximum(link.sum(2, keepdims=True), 1) 121 | link /= np.maximum(link.sum(3, keepdims=True), 1) 122 | 123 | # We query the memory on the third location in memory, and select a large 124 | # strength on the query. Then we select a content-based read-mode. 125 | read_content_keys = np.random.rand(BATCH_SIZE, NUM_READS, WORD_SIZE) 126 | read_content_keys[0, 0] = memory[0, 3] 127 | read_content_strengths = tf.constant( 128 | 100., shape=[BATCH_SIZE, NUM_READS], dtype=tf.float64) 129 | read_mode = np.random.rand(BATCH_SIZE, NUM_READS, 1 + 2 * NUM_WRITES) 130 | read_mode[0, 0, :] = util.one_hot(1 + 2 * NUM_WRITES, 2 * NUM_WRITES) 131 | inputs = { 132 | 'read_content_keys': tf.constant(read_content_keys), 133 | 'read_content_strengths': read_content_strengths, 134 | 'read_mode': tf.constant(read_mode), 135 | } 136 | read_weights = self.module._read_weights(inputs, memory, prev_read_weights, 137 | link) 138 | with self.test_session(): 139 | read_weights = read_weights.eval() 140 | 141 | # read_weights for batch 0, read head 0 should be memory location 3 142 | self.assertAllClose( 143 | read_weights[0, 0, :], util.one_hot(MEMORY_SIZE, 3), atol=1e-3) 144 | 145 | def testGradients(self): 146 | inputs = tf.constant(np.random.randn(BATCH_SIZE, INPUT_SIZE), tf.float32) 147 | output, _ = self.module(inputs, self.initial_state) 148 | loss = tf.reduce_sum(output) 149 | 150 | tensors_to_check = [ 151 | inputs, self.initial_state.memory, self.initial_state.read_weights, 152 | self.initial_state.linkage.precedence_weights, 153 | self.initial_state.linkage.link 154 | ] 155 | shapes = [x.get_shape().as_list() for x in tensors_to_check] 156 | with self.test_session() as sess: 157 | sess.run(tf.global_variables_initializer()) 158 | err = tf.test.compute_gradient_error(tensors_to_check, shapes, loss, [1]) 159 | self.assertLess(err, 0.1) 160 | 161 | 162 | if __name__ == '__main__': 163 | tf.test.main() 164 | -------------------------------------------------------------------------------- /dnc/addressing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """DNC addressing modules.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import sonnet as snt 23 | import tensorflow as tf 24 | 25 | from dnc import util 26 | 27 | # Ensure values are greater than epsilon to avoid numerical instability. 28 | _EPSILON = 1e-6 29 | 30 | TemporalLinkageState = collections.namedtuple('TemporalLinkageState', 31 | ('link', 'precedence_weights')) 32 | 33 | 34 | def _vector_norms(m): 35 | squared_norms = tf.reduce_sum(m * m, axis=2, keepdims=True) 36 | return tf.sqrt(squared_norms + _EPSILON) 37 | 38 | 39 | def weighted_softmax(activations, strengths, strengths_op): 40 | """Returns softmax over activations multiplied by positive strengths. 41 | 42 | Args: 43 | activations: A tensor of shape `[batch_size, num_heads, memory_size]`, of 44 | activations to be transformed. Softmax is taken over the last dimension. 45 | strengths: A tensor of shape `[batch_size, num_heads]` containing strengths to 46 | multiply by the activations prior to the softmax. 47 | strengths_op: An operation to transform strengths before softmax. 48 | 49 | Returns: 50 | A tensor of same shape as `activations` with weighted softmax applied. 51 | """ 52 | transformed_strengths = tf.expand_dims(strengths_op(strengths), -1) 53 | sharp_activations = activations * transformed_strengths 54 | softmax = snt.BatchApply(module_or_op=tf.nn.softmax) 55 | return softmax(sharp_activations) 56 | 57 | 58 | class CosineWeights(snt.AbstractModule): 59 | """Cosine-weighted attention. 60 | 61 | Calculates the cosine similarity between a query and each word in memory, then 62 | applies a weighted softmax to return a sharp distribution. 63 | """ 64 | 65 | def __init__(self, 66 | num_heads, 67 | word_size, 68 | strength_op=tf.nn.softplus, 69 | name='cosine_weights'): 70 | """Initializes the CosineWeights module. 71 | 72 | Args: 73 | num_heads: number of memory heads. 74 | word_size: memory word size. 75 | strength_op: operation to apply to strengths (default is tf.nn.softplus). 76 | name: module name (default 'cosine_weights') 77 | """ 78 | super(CosineWeights, self).__init__(name=name) 79 | self._num_heads = num_heads 80 | self._word_size = word_size 81 | self._strength_op = strength_op 82 | 83 | def _build(self, memory, keys, strengths): 84 | """Connects the CosineWeights module into the graph. 85 | 86 | Args: 87 | memory: A 3-D tensor of shape `[batch_size, memory_size, word_size]`. 88 | keys: A 3-D tensor of shape `[batch_size, num_heads, word_size]`. 89 | strengths: A 2-D tensor of shape `[batch_size, num_heads]`. 90 | 91 | Returns: 92 | Weights tensor of shape `[batch_size, num_heads, memory_size]`. 93 | """ 94 | # Calculates the inner product between the query vector and words in memory. 95 | dot = tf.matmul(keys, memory, adjoint_b=True) 96 | 97 | # Outer product to compute denominator (euclidean norm of query and memory). 98 | memory_norms = _vector_norms(memory) 99 | key_norms = _vector_norms(keys) 100 | norm = tf.matmul(key_norms, memory_norms, adjoint_b=True) 101 | 102 | # Calculates cosine similarity between the query vector and words in memory. 103 | similarity = dot / (norm + _EPSILON) 104 | 105 | return weighted_softmax(similarity, strengths, self._strength_op) 106 | 107 | 108 | class TemporalLinkage(snt.RNNCore): 109 | """Keeps track of write order for forward and backward addressing. 
110 | 111 | This is a pseudo-RNNCore module, whose state is a pair `(link, 112 | precedence_weights)`, where `link` is a (collection of) graphs for (possibly 113 | multiple) write heads (represented by a tensor with values in the range 114 | [0, 1]), and `precedence_weights` records the "previous write locations" used 115 | to build the link graphs. 116 | 117 | The function `directional_read_weights` computes addresses following the 118 | forward and backward directions in the link graphs. 119 | """ 120 | 121 | def __init__(self, memory_size, num_writes, name='temporal_linkage'): 122 | """Construct a TemporalLinkage module. 123 | 124 | Args: 125 | memory_size: The number of memory slots. 126 | num_writes: The number of write heads. 127 | name: Name of the module. 128 | """ 129 | super(TemporalLinkage, self).__init__(name=name) 130 | self._memory_size = memory_size 131 | self._num_writes = num_writes 132 | 133 | def _build(self, write_weights, prev_state): 134 | """Calculate the updated linkage state given the write weights. 135 | 136 | Args: 137 | write_weights: A tensor of shape `[batch_size, num_writes, memory_size]` 138 | containing the memory addresses of the different write heads. 139 | prev_state: `TemporalLinkageState` tuple containg a tensor `link` of 140 | shape `[batch_size, num_writes, memory_size, memory_size]`, and a 141 | tensor `precedence_weights` of shape `[batch_size, num_writes, 142 | memory_size]` containing the aggregated history of recent writes. 143 | 144 | Returns: 145 | A `TemporalLinkageState` tuple `next_state`, which contains the updated 146 | link and precedence weights. 147 | """ 148 | link = self._link(prev_state.link, prev_state.precedence_weights, 149 | write_weights) 150 | precedence_weights = self._precedence_weights(prev_state.precedence_weights, 151 | write_weights) 152 | return TemporalLinkageState( 153 | link=link, precedence_weights=precedence_weights) 154 | 155 | def directional_read_weights(self, link, prev_read_weights, forward): 156 | """Calculates the forward or the backward read weights. 157 | 158 | For each read head (at a given address), there are `num_writes` link graphs 159 | to follow. Thus this function computes a read address for each of the 160 | `num_reads * num_writes` pairs of read and write heads. 161 | 162 | Args: 163 | link: tensor of shape `[batch_size, num_writes, memory_size, 164 | memory_size]` representing the link graphs L_t. 165 | prev_read_weights: tensor of shape `[batch_size, num_reads, 166 | memory_size]` containing the previous read weights w_{t-1}^r. 167 | forward: Boolean indicating whether to follow the "future" direction in 168 | the link graph (True) or the "past" direction (False). 169 | 170 | Returns: 171 | tensor of shape `[batch_size, num_reads, num_writes, memory_size]` 172 | """ 173 | with tf.name_scope('directional_read_weights'): 174 | # We calculate the forward and backward directions for each pair of 175 | # read and write heads; hence we need to tile the read weights and do a 176 | # sort of "outer product" to get this. 177 | expanded_read_weights = tf.stack([prev_read_weights] * self._num_writes, 178 | 1) 179 | result = tf.matmul(expanded_read_weights, link, adjoint_b=forward) 180 | # Swap dimensions 1, 2 so order is [batch, reads, writes, memory]: 181 | return tf.transpose(result, perm=[0, 2, 1, 3]) 182 | 183 | def _link(self, prev_link, prev_precedence_weights, write_weights): 184 | """Calculates the new link graphs. 
185 | 186 | For each write head, the link is a directed graph (represented by a matrix 187 | with entries in range [0, 1]) whose vertices are the memory locations, and 188 | an edge indicates temporal ordering of writes. 189 | 190 | Args: 191 | prev_link: A tensor of shape `[batch_size, num_writes, memory_size, 192 | memory_size]` representing the previous link graphs for each write 193 | head. 194 | prev_precedence_weights: A tensor of shape `[batch_size, num_writes, 195 | memory_size]` which is the previous "aggregated" write weights for 196 | each write head. 197 | write_weights: A tensor of shape `[batch_size, num_writes, memory_size]` 198 | containing the new locations in memory written to. 199 | 200 | Returns: 201 | A tensor of shape `[batch_size, num_writes, memory_size, memory_size]` 202 | containing the new link graphs for each write head. 203 | """ 204 | with tf.name_scope('link'): 205 | batch_size = tf.shape(prev_link)[0] 206 | write_weights_i = tf.expand_dims(write_weights, 3) 207 | write_weights_j = tf.expand_dims(write_weights, 2) 208 | prev_precedence_weights_j = tf.expand_dims(prev_precedence_weights, 2) 209 | prev_link_scale = 1 - write_weights_i - write_weights_j 210 | new_link = write_weights_i * prev_precedence_weights_j 211 | link = prev_link_scale * prev_link + new_link 212 | # Return the link with the diagonal set to zero, to remove self-looping 213 | # edges. 214 | return tf.matrix_set_diag( 215 | link, 216 | tf.zeros( 217 | [batch_size, self._num_writes, self._memory_size], 218 | dtype=link.dtype)) 219 | 220 | def _precedence_weights(self, prev_precedence_weights, write_weights): 221 | """Calculates the new precedence weights given the current write weights. 222 | 223 | The precedence weights are the "aggregated write weights" for each write 224 | head, where write weights with sum close to zero will leave the precedence 225 | weights unchanged, but with sum close to one will replace the precedence 226 | weights. 227 | 228 | Args: 229 | prev_precedence_weights: A tensor of shape `[batch_size, num_writes, 230 | memory_size]` containing the previous precedence weights. 231 | write_weights: A tensor of shape `[batch_size, num_writes, memory_size]` 232 | containing the new write weights. 233 | 234 | Returns: 235 | A tensor of shape `[batch_size, num_writes, memory_size]` containing the 236 | new precedence weights. 237 | """ 238 | with tf.name_scope('precedence_weights'): 239 | write_sum = tf.reduce_sum(write_weights, 2, keepdims=True) 240 | return (1 - write_sum) * prev_precedence_weights + write_weights 241 | 242 | @property 243 | def state_size(self): 244 | """Returns a `TemporalLinkageState` tuple of the state tensors' shapes.""" 245 | return TemporalLinkageState( 246 | link=tf.TensorShape( 247 | [self._num_writes, self._memory_size, self._memory_size]), 248 | precedence_weights=tf.TensorShape([self._num_writes, 249 | self._memory_size]),) 250 | 251 | 252 | class Freeness(snt.RNNCore): 253 | """Memory usage that is increased by writing and decreased by reading. 254 | 255 | This module is a pseudo-RNNCore whose state is a tensor with values in 256 | the range [0, 1] indicating the usage of each of `memory_size` memory slots. 257 | 258 | The usage is: 259 | 260 | * Increased by writing, where usage is increased towards 1 at the write 261 | addresses. 262 | * Decreased by reading, where usage is decreased after reading from a 263 | location when free_gate is close to 1. 
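As an illustrative single-slot example: starting from usage 0.2, a write weight of 0.5 raises usage to 0.2 + (1 - 0.2) * 0.5 = 0.6, and a subsequent read with read weight 1 and free_gate 1 rescales it by (1 - 1 * 1) = 0, freeing the slot entirely (see `_usage_after_write` and `_usage_after_read` below).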
264 | 265 | The function `write_allocation_weights` can be invoked to get free locations 266 | to write to for a number of write heads. 267 | """ 268 | 269 | def __init__(self, memory_size, name='freeness'): 270 | """Creates a Freeness module. 271 | 272 | Args: 273 | memory_size: Number of memory slots. 274 | name: Name of the module. 275 | """ 276 | super(Freeness, self).__init__(name=name) 277 | self._memory_size = memory_size 278 | 279 | def _build(self, write_weights, free_gate, read_weights, prev_usage): 280 | """Calculates the new memory usage u_t. 281 | 282 | Memory that was written to in the previous time step will have its usage 283 | increased; memory that was read from and the controller says can be "freed" 284 | will have its usage decreased. 285 | 286 | Args: 287 | write_weights: tensor of shape `[batch_size, num_writes, 288 | memory_size]` giving write weights at previous time step. 289 | free_gate: tensor of shape `[batch_size, num_reads]` which indicates 290 | which read heads read memory that can now be freed. 291 | read_weights: tensor of shape `[batch_size, num_reads, 292 | memory_size]` giving read weights at previous time step. 293 | prev_usage: tensor of shape `[batch_size, memory_size]` giving 294 | usage u_{t - 1} at the previous time step, with entries in range 295 | [0, 1]. 296 | 297 | Returns: 298 | tensor of shape `[batch_size, memory_size]` representing updated memory 299 | usage. 300 | """ 301 | # Calculation of usage is not differentiable with respect to write weights. 302 | write_weights = tf.stop_gradient(write_weights) 303 | usage = self._usage_after_write(prev_usage, write_weights) 304 | usage = self._usage_after_read(usage, free_gate, read_weights) 305 | return usage 306 | 307 | def write_allocation_weights(self, usage, write_gates, num_writes): 308 | """Calculates freeness-based locations for writing to. 309 | 310 | This finds unused memory by ranking the memory locations by usage, for each 311 | write head. (For more than one write head, we use a "simulated new usage" 312 | which takes into account the fact that the previous write head will increase 313 | the usage in that area of the memory.) 314 | 315 | Args: 316 | usage: A tensor of shape `[batch_size, memory_size]` representing 317 | current memory usage. 318 | write_gates: A tensor of shape `[batch_size, num_writes]` with values in 319 | the range [0, 1] indicating how much each write head does writing 320 | based on the address returned here (and hence how much usage 321 | increases). 322 | num_writes: The number of write heads to calculate write weights for. 323 | 324 | Returns: 325 | tensor of shape `[batch_size, num_writes, memory_size]` containing the 326 | freeness-based write locations. Note that this isn't scaled by 327 | `write_gate`; this scaling must be applied externally. 328 | """ 329 | with tf.name_scope('write_allocation_weights'): 330 | # expand gatings over memory locations 331 | write_gates = tf.expand_dims(write_gates, -1) 332 | 333 | allocation_weights = [] 334 | for i in range(num_writes): 335 | allocation_weights.append(self._allocation(usage)) 336 | # update usage to take into account writing to this new allocation 337 | usage += ((1 - usage) * write_gates[:, i, :] * allocation_weights[i]) 338 | 339 | # Pack the allocation weights for the write heads into one tensor. 340 | return tf.stack(allocation_weights, axis=1) 341 | 342 | def _usage_after_write(self, prev_usage, write_weights): 343 | """Calcualtes the new usage after writing to memory. 
344 | 345 | Args: 346 | prev_usage: tensor of shape `[batch_size, memory_size]`. 347 | write_weights: tensor of shape `[batch_size, num_writes, memory_size]`. 348 | 349 | Returns: 350 | New usage, a tensor of shape `[batch_size, memory_size]`. 351 | """ 352 | with tf.name_scope('usage_after_write'): 353 | # Calculate the aggregated effect of all write heads 354 | write_weights = 1 - util.reduce_prod(1 - write_weights, 1) 355 | return prev_usage + (1 - prev_usage) * write_weights 356 | 357 | def _usage_after_read(self, prev_usage, free_gate, read_weights): 358 | """Calcualtes the new usage after reading and freeing from memory. 359 | 360 | Args: 361 | prev_usage: tensor of shape `[batch_size, memory_size]`. 362 | free_gate: tensor of shape `[batch_size, num_reads]` with entries in the 363 | range [0, 1] indicating the amount that locations read from can be 364 | freed. 365 | read_weights: tensor of shape `[batch_size, num_reads, memory_size]`. 366 | 367 | Returns: 368 | New usage, a tensor of shape `[batch_size, memory_size]`. 369 | """ 370 | with tf.name_scope('usage_after_read'): 371 | free_gate = tf.expand_dims(free_gate, -1) 372 | free_read_weights = free_gate * read_weights 373 | phi = util.reduce_prod(1 - free_read_weights, 1, name='phi') 374 | return prev_usage * phi 375 | 376 | def _allocation(self, usage): 377 | r"""Computes allocation by sorting `usage`. 378 | 379 | This corresponds to the value a = a_t[\phi_t[j]] in the paper. 380 | 381 | Args: 382 | usage: tensor of shape `[batch_size, memory_size]` indicating current 383 | memory usage. This is equal to u_t in the paper when we only have one 384 | write head, but for multiple write heads, one should update the usage 385 | while iterating through the write heads to take into account the 386 | allocation returned by this function. 387 | 388 | Returns: 389 | Tensor of shape `[batch_size, memory_size]` corresponding to allocation. 390 | """ 391 | with tf.name_scope('allocation'): 392 | # Ensure values are not too small prior to cumprod. 393 | usage = _EPSILON + (1 - _EPSILON) * usage 394 | 395 | nonusage = 1 - usage 396 | sorted_nonusage, indices = tf.nn.top_k( 397 | nonusage, k=self._memory_size, name='sort') 398 | sorted_usage = 1 - sorted_nonusage 399 | prod_sorted_usage = tf.cumprod(sorted_usage, axis=1, exclusive=True) 400 | sorted_allocation = sorted_nonusage * prod_sorted_usage 401 | inverse_indices = util.batch_invert_permutation(indices) 402 | 403 | # This final line "unsorts" sorted_allocation, so that the indexing 404 | # corresponds to the original indexing of `usage`. 405 | return util.batch_gather(sorted_allocation, inverse_indices) 406 | 407 | @property 408 | def state_size(self): 409 | """Returns the shape of the state tensor.""" 410 | return tf.TensorShape([self._memory_size]) 411 | -------------------------------------------------------------------------------- /dnc/addressing_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for memory addressing.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import numpy as np 22 | import sonnet as snt 23 | import tensorflow as tf 24 | 25 | from dnc import addressing 26 | from dnc import util 27 | 28 | try: 29 | xrange 30 | except NameError: 31 | xrange = range 32 | 33 | 34 | class WeightedSoftmaxTest(tf.test.TestCase): 35 | 36 | def testValues(self): 37 | batch_size = 5 38 | num_heads = 3 39 | memory_size = 7 40 | 41 | activations_data = np.random.randn(batch_size, num_heads, memory_size) 42 | weights_data = np.ones((batch_size, num_heads)) 43 | 44 | activations = tf.placeholder(tf.float32, 45 | [batch_size, num_heads, memory_size]) 46 | weights = tf.placeholder(tf.float32, [batch_size, num_heads]) 47 | # Run weighted softmax with identity placed on weights. Output should be 48 | # equal to a standalone softmax. 49 | observed = addressing.weighted_softmax(activations, weights, tf.identity) 50 | expected = snt.BatchApply( 51 | module_or_op=tf.nn.softmax, name='BatchSoftmax')(activations) 52 | with self.test_session() as sess: 53 | observed = sess.run( 54 | observed, 55 | feed_dict={activations: activations_data, 56 | weights: weights_data}) 57 | expected = sess.run(expected, feed_dict={activations: activations_data}) 58 | self.assertAllClose(observed, expected) 59 | 60 | 61 | class CosineWeightsTest(tf.test.TestCase): 62 | 63 | def testShape(self): 64 | batch_size = 5 65 | num_heads = 3 66 | memory_size = 7 67 | word_size = 2 68 | 69 | module = addressing.CosineWeights(num_heads, word_size) 70 | mem = tf.placeholder(tf.float32, [batch_size, memory_size, word_size]) 71 | keys = tf.placeholder(tf.float32, [batch_size, num_heads, word_size]) 72 | strengths = tf.placeholder(tf.float32, [batch_size, num_heads]) 73 | weights = module(mem, keys, strengths) 74 | self.assertTrue(weights.get_shape().is_compatible_with( 75 | [batch_size, num_heads, memory_size])) 76 | 77 | def testValues(self): 78 | batch_size = 5 79 | num_heads = 4 80 | memory_size = 10 81 | word_size = 2 82 | 83 | mem_data = np.random.randn(batch_size, memory_size, word_size) 84 | np.copyto(mem_data[0, 0], [1, 2]) 85 | np.copyto(mem_data[0, 1], [3, 4]) 86 | np.copyto(mem_data[0, 2], [5, 6]) 87 | 88 | keys_data = np.random.randn(batch_size, num_heads, word_size) 89 | np.copyto(keys_data[0, 0], [5, 6]) 90 | np.copyto(keys_data[0, 1], [1, 2]) 91 | np.copyto(keys_data[0, 2], [5, 6]) 92 | np.copyto(keys_data[0, 3], [3, 4]) 93 | strengths_data = np.random.randn(batch_size, num_heads) 94 | 95 | module = addressing.CosineWeights(num_heads, word_size) 96 | mem = tf.placeholder(tf.float32, [batch_size, memory_size, word_size]) 97 | keys = tf.placeholder(tf.float32, [batch_size, num_heads, word_size]) 98 | strengths = tf.placeholder(tf.float32, [batch_size, num_heads]) 99 | weights = module(mem, keys, strengths) 100 | 101 | with self.test_session() as sess: 102 | result = sess.run( 103 | weights, 104 | feed_dict={mem: mem_data, 105 | keys: keys_data, 106 | strengths: strengths_data}) 107 | 108 | # Manually checks results. 
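# A rough restatement of the check below (using only names from this test):
# for each batch b and head h, the module's output should equal
#   softmax(cosine_similarity(key, memory_row) * softplus(strength))
# taken over memory rows, which is what the numpy loop recomputes.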
109 | strengths_softplus = np.log(1 + np.exp(strengths_data)) 110 | similarity = np.zeros((memory_size)) 111 | 112 | for b in xrange(batch_size): 113 | for h in xrange(num_heads): 114 | key = keys_data[b, h] 115 | key_norm = np.linalg.norm(key) 116 | 117 | for m in xrange(memory_size): 118 | row = mem_data[b, m] 119 | similarity[m] = np.dot(key, row) / (key_norm * np.linalg.norm(row)) 120 | 121 | similarity = np.exp(similarity * strengths_softplus[b, h]) 122 | similarity /= similarity.sum() 123 | self.assertAllClose(result[b, h], similarity, atol=1e-4, rtol=1e-4) 124 | 125 | def testDivideByZero(self): 126 | batch_size = 5 127 | num_heads = 4 128 | memory_size = 10 129 | word_size = 2 130 | 131 | module = addressing.CosineWeights(num_heads, word_size) 132 | keys = tf.random_normal([batch_size, num_heads, word_size]) 133 | strengths = tf.random_normal([batch_size, num_heads]) 134 | 135 | # First row of memory is non-zero to concentrate attention on this location. 136 | # Remaining rows are all zero. 137 | first_row_ones = tf.ones([batch_size, 1, word_size], dtype=tf.float32) 138 | remaining_zeros = tf.zeros( 139 | [batch_size, memory_size - 1, word_size], dtype=tf.float32) 140 | mem = tf.concat((first_row_ones, remaining_zeros), 1) 141 | 142 | output = module(mem, keys, strengths) 143 | gradients = tf.gradients(output, [mem, keys, strengths]) 144 | 145 | with self.test_session() as sess: 146 | output, gradients = sess.run([output, gradients]) 147 | self.assertFalse(np.any(np.isnan(output))) 148 | self.assertFalse(np.any(np.isnan(gradients[0]))) 149 | self.assertFalse(np.any(np.isnan(gradients[1]))) 150 | self.assertFalse(np.any(np.isnan(gradients[2]))) 151 | 152 | 153 | class TemporalLinkageTest(tf.test.TestCase): 154 | 155 | def testModule(self): 156 | batch_size = 7 157 | memory_size = 4 158 | num_reads = 11 159 | num_writes = 5 160 | module = addressing.TemporalLinkage( 161 | memory_size=memory_size, num_writes=num_writes) 162 | 163 | prev_link_in = tf.placeholder( 164 | tf.float32, (batch_size, num_writes, memory_size, memory_size)) 165 | prev_precedence_weights_in = tf.placeholder( 166 | tf.float32, (batch_size, num_writes, memory_size)) 167 | write_weights_in = tf.placeholder(tf.float32, 168 | (batch_size, num_writes, memory_size)) 169 | 170 | state = addressing.TemporalLinkageState( 171 | link=np.zeros([batch_size, num_writes, memory_size, memory_size]), 172 | precedence_weights=np.zeros([batch_size, num_writes, memory_size])) 173 | 174 | calc_state = module(write_weights_in, 175 | addressing.TemporalLinkageState( 176 | link=prev_link_in, 177 | precedence_weights=prev_precedence_weights_in)) 178 | 179 | with self.test_session() as sess: 180 | num_steps = 5 181 | for i in xrange(num_steps): 182 | write_weights = np.random.rand(batch_size, num_writes, memory_size) 183 | write_weights /= write_weights.sum(2, keepdims=True) + 1 184 | 185 | # Simulate (in final steps) link 0-->1 in head 0 and 3-->2 in head 1 186 | if i == num_steps - 2: 187 | write_weights[0, 0, :] = util.one_hot(memory_size, 0) 188 | write_weights[0, 1, :] = util.one_hot(memory_size, 3) 189 | elif i == num_steps - 1: 190 | write_weights[0, 0, :] = util.one_hot(memory_size, 1) 191 | write_weights[0, 1, :] = util.one_hot(memory_size, 2) 192 | 193 | state = sess.run( 194 | calc_state, 195 | feed_dict={ 196 | prev_link_in: state.link, 197 | prev_precedence_weights_in: state.precedence_weights, 198 | write_weights_in: write_weights 199 | }) 200 | 201 | # link should be bounded in range [0, 1] 202 | 
self.assertGreaterEqual(state.link.min(), 0) 203 | self.assertLessEqual(state.link.max(), 1) 204 | 205 | # link diagonal should be zero 206 | self.assertAllEqual( 207 | state.link[:, :, range(memory_size), range(memory_size)], 208 | np.zeros([batch_size, num_writes, memory_size])) 209 | 210 | # link rows and columns should sum to at most 1 211 | self.assertLessEqual(state.link.sum(2).max(), 1) 212 | self.assertLessEqual(state.link.sum(3).max(), 1) 213 | 214 | # records our transitions in batch 0: head 0: 0->1, and head 1: 3->2 215 | self.assertAllEqual(state.link[0, 0, :, 0], util.one_hot(memory_size, 1)) 216 | self.assertAllEqual(state.link[0, 1, :, 3], util.one_hot(memory_size, 2)) 217 | 218 | # Now test calculation of forward and backward read weights 219 | prev_read_weights = np.random.rand(batch_size, num_reads, memory_size) 220 | prev_read_weights[0, 5, :] = util.one_hot(memory_size, 0) # read 5, posn 0 221 | prev_read_weights[0, 6, :] = util.one_hot(memory_size, 2) # read 6, posn 2 222 | forward_read_weights = module.directional_read_weights( 223 | tf.constant(state.link), 224 | tf.constant(prev_read_weights, dtype=tf.float32), 225 | forward=True) 226 | backward_read_weights = module.directional_read_weights( 227 | tf.constant(state.link), 228 | tf.constant(prev_read_weights, dtype=tf.float32), 229 | forward=False) 230 | 231 | with self.test_session(): 232 | forward_read_weights = forward_read_weights.eval() 233 | backward_read_weights = backward_read_weights.eval() 234 | 235 | # Check directional weights calculated correctly. 236 | self.assertAllEqual( 237 | forward_read_weights[0, 5, 0, :], # read=5, write=0 238 | util.one_hot(memory_size, 1)) 239 | self.assertAllEqual( 240 | backward_read_weights[0, 6, 1, :], # read=6, write=1 241 | util.one_hot(memory_size, 3)) 242 | 243 | def testPrecedenceWeights(self): 244 | batch_size = 7 245 | memory_size = 3 246 | num_writes = 5 247 | module = addressing.TemporalLinkage( 248 | memory_size=memory_size, num_writes=num_writes) 249 | 250 | prev_precedence_weights = np.random.rand(batch_size, num_writes, 251 | memory_size) 252 | write_weights = np.random.rand(batch_size, num_writes, memory_size) 253 | 254 | # These should sum to at most 1 for each write head in each batch. 
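# Dividing by (sum + 1) below is one simple way to get weights whose sum is
# strictly less than 1: each head's weights then sum to s / (s + 1), where s
# is the original sum.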
255 | write_weights /= write_weights.sum(2, keepdims=True) + 1
256 | prev_precedence_weights /= prev_precedence_weights.sum(2, keepdims=True) + 1
257 |
258 | write_weights[0, 1, :] = 0 # batch 0 head 1: no writing
259 | write_weights[1, 2, :] /= write_weights[1, 2, :].sum() # b1 h2: all writing
260 |
261 | precedence_weights = module._precedence_weights(
262 | prev_precedence_weights=tf.constant(prev_precedence_weights),
263 | write_weights=tf.constant(write_weights))
264 |
265 | with self.test_session():
266 | precedence_weights = precedence_weights.eval()
267 |
268 | # precedence weights should be bounded in range [0, 1]
269 | self.assertGreaterEqual(precedence_weights.min(), 0)
270 | self.assertLessEqual(precedence_weights.max(), 1)
271 |
272 | # no writing in batch 0, head 1
273 | self.assertAllClose(precedence_weights[0, 1, :],
274 | prev_precedence_weights[0, 1, :])
275 |
276 | # all writing in batch 1, head 2
277 | self.assertAllClose(precedence_weights[1, 2, :], write_weights[1, 2, :])
278 |
279 |
280 | class FreenessTest(tf.test.TestCase):
281 |
282 | def testModule(self):
283 | batch_size = 5
284 | memory_size = 11
285 | num_reads = 3
286 | num_writes = 7
287 | module = addressing.Freeness(memory_size)
288 |
289 | free_gate = np.random.rand(batch_size, num_reads)
290 |
291 | # Produce read weights that sum to 1 for each batch and head.
292 | prev_read_weights = np.random.rand(batch_size, num_reads, memory_size)
293 | prev_read_weights[1, :, 3] = 0 # no read at batch 1, position 3; see below
294 | prev_read_weights /= prev_read_weights.sum(2, keepdims=True)
295 | prev_write_weights = np.random.rand(batch_size, num_writes, memory_size)
296 | prev_write_weights /= prev_write_weights.sum(2, keepdims=True)
297 | prev_usage = np.random.rand(batch_size, memory_size)
298 |
299 | # Add some special values that allow us to test the behaviour:
300 | prev_write_weights[1, 2, 3] = 1 # full write in batch 1, head 2, position 3
301 | prev_read_weights[2, 0, 4] = 1 # full read at batch 2, head 0, position 4
302 | free_gate[2, 0] = 1 # can free up all locations for batch 2, read head 0
303 |
304 | usage = module(
305 | tf.constant(prev_write_weights),
306 | tf.constant(free_gate),
307 | tf.constant(prev_read_weights), tf.constant(prev_usage))
308 | with self.test_session():
309 | usage = usage.eval()
310 |
311 | # Check all usages are between 0 and 1.
312 | self.assertGreaterEqual(usage.min(), 0)
313 | self.assertLessEqual(usage.max(), 1)
314 |
315 | # Check that the full write at batch 1, position 3 makes it fully used.
316 | self.assertEqual(usage[1][3], 1)
317 |
318 | # Check that the full free at batch 2, position 4 makes it fully free.
319 | self.assertEqual(usage[2][4], 0)
320 |
321 | def testWriteAllocationWeights(self):
322 | batch_size = 7
323 | memory_size = 23
324 | num_writes = 5
325 | module = addressing.Freeness(memory_size)
326 |
327 | usage = np.random.rand(batch_size, memory_size)
328 | write_gates = np.random.rand(batch_size, num_writes)
329 |
330 | # Turn off gates for heads 1 and 3 in batch 0. This doesn't scale down the
331 | # weighting, but it means that the usage doesn't change, so we should get
332 | # the same allocation weightings for: (1, 2) and (3, 4) (but all others
333 | # being different).
334 | write_gates[0, 1] = 0
335 | write_gates[0, 3] = 0
336 | # and turn heads 0 and 2 on for full effect.
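# With a write gate of 0, the simulated usage inside write_allocation_weights
# is left unchanged, so the next head sees the same usage and produces an
# identical allocation; that is why the (1, 2) and (3, 4) pairs should match.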
337 | write_gates[0, 0] = 1 338 | write_gates[0, 2] = 1 339 | 340 | # In batch 1, make one of the usages 0 and another almost 0, so that these 341 | # entries get most of the allocation weights for the first and second heads. 342 | usage[1] = usage[1] * 0.9 + 0.1 # make sure all entries are in [0.1, 1] 343 | usage[1][4] = 0 # write head 0 should get allocated to position 4 344 | usage[1][3] = 1e-4 # write head 1 should get allocated to position 3 345 | write_gates[1, 0] = 1 # write head 0 fully on 346 | write_gates[1, 1] = 1 # write head 1 fully on 347 | 348 | weights = module.write_allocation_weights( 349 | usage=tf.constant(usage), 350 | write_gates=tf.constant(write_gates), 351 | num_writes=num_writes) 352 | 353 | with self.test_session(): 354 | weights = weights.eval() 355 | 356 | # Check that all weights are between 0 and 1 357 | self.assertGreaterEqual(weights.min(), 0) 358 | self.assertLessEqual(weights.max(), 1) 359 | 360 | # Check that weights sum to close to 1 361 | self.assertAllClose( 362 | np.sum(weights, axis=2), np.ones([batch_size, num_writes]), atol=1e-3) 363 | 364 | # Check the same / different allocation weight pairs as described above. 365 | self.assertGreater(np.abs(weights[0, 0, :] - weights[0, 1, :]).max(), 0.1) 366 | self.assertAllEqual(weights[0, 1, :], weights[0, 2, :]) 367 | self.assertGreater(np.abs(weights[0, 2, :] - weights[0, 3, :]).max(), 0.1) 368 | self.assertAllEqual(weights[0, 3, :], weights[0, 4, :]) 369 | 370 | self.assertAllClose(weights[1][0], util.one_hot(memory_size, 4), atol=1e-3) 371 | self.assertAllClose(weights[1][1], util.one_hot(memory_size, 3), atol=1e-3) 372 | 373 | def testWriteAllocationWeightsGradient(self): 374 | batch_size = 7 375 | memory_size = 5 376 | num_writes = 3 377 | module = addressing.Freeness(memory_size) 378 | 379 | usage = tf.constant(np.random.rand(batch_size, memory_size)) 380 | write_gates = tf.constant(np.random.rand(batch_size, num_writes)) 381 | weights = module.write_allocation_weights(usage, write_gates, num_writes) 382 | 383 | with self.test_session(): 384 | err = tf.test.compute_gradient_error( 385 | [usage, write_gates], 386 | [usage.get_shape().as_list(), write_gates.get_shape().as_list()], 387 | weights, 388 | weights.get_shape().as_list(), 389 | delta=1e-5) 390 | self.assertLess(err, 0.01) 391 | 392 | def testAllocation(self): 393 | batch_size = 7 394 | memory_size = 13 395 | usage = np.random.rand(batch_size, memory_size) 396 | module = addressing.Freeness(memory_size) 397 | allocation = module._allocation(tf.constant(usage)) 398 | with self.test_session(): 399 | allocation = allocation.eval() 400 | 401 | # 1. Test that max allocation goes to min usage, and vice versa. 402 | self.assertAllEqual(np.argmin(usage, axis=1), np.argmax(allocation, axis=1)) 403 | self.assertAllEqual(np.argmax(usage, axis=1), np.argmin(allocation, axis=1)) 404 | 405 | # 2. Test that allocations sum to almost 1. 
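# The sorted allocation weights telescope: their sum is 1 - prod(usage)
# (up to the epsilon adjustment in _allocation), so with random usage in
# [0, 1] the total is close to, but just below, 1.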
406 | self.assertAllClose(np.sum(allocation, axis=1), np.ones(batch_size), 0.01) 407 | 408 | def testAllocationGradient(self): 409 | batch_size = 1 410 | memory_size = 5 411 | usage = tf.constant(np.random.rand(batch_size, memory_size)) 412 | module = addressing.Freeness(memory_size) 413 | allocation = module._allocation(usage) 414 | with self.test_session(): 415 | err = tf.test.compute_gradient_error( 416 | usage, 417 | usage.get_shape().as_list(), 418 | allocation, 419 | allocation.get_shape().as_list(), 420 | delta=1e-5) 421 | self.assertLess(err, 0.01) 422 | 423 | 424 | if __name__ == '__main__': 425 | tf.test.main() 426 | -------------------------------------------------------------------------------- /dnc/dnc.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """DNC Cores. 16 | 17 | These modules create a DNC core. They take input, pass parameters to the memory 18 | access module, and integrate the output of memory to form an output. 19 | """ 20 | 21 | from __future__ import absolute_import 22 | from __future__ import division 23 | from __future__ import print_function 24 | 25 | import collections 26 | import numpy as np 27 | import sonnet as snt 28 | import tensorflow as tf 29 | 30 | from dnc import access 31 | 32 | DNCState = collections.namedtuple('DNCState', ('access_output', 'access_state', 33 | 'controller_state')) 34 | 35 | 36 | class DNC(snt.RNNCore): 37 | """DNC core module. 38 | 39 | Contains controller and memory access module. 40 | """ 41 | 42 | def __init__(self, 43 | access_config, 44 | controller_config, 45 | output_size, 46 | clip_value=None, 47 | name='dnc'): 48 | """Initializes the DNC core. 49 | 50 | Args: 51 | access_config: dictionary of access module configurations. 52 | controller_config: dictionary of controller (LSTM) module configurations. 53 | output_size: output dimension size of core. 54 | clip_value: clips controller and core output values to between 55 | `[-clip_value, clip_value]` if specified. 56 | name: module name (default 'dnc'). 57 | 58 | Raises: 59 | TypeError: if direct_input_size is not None for any access module other 60 | than KeyValueMemory. 
61 | """
62 | super(DNC, self).__init__(name=name)
63 |
64 | with self._enter_variable_scope():
65 | self._controller = snt.LSTM(**controller_config)
66 | self._access = access.MemoryAccess(**access_config)
67 |
68 | self._access_output_size = np.prod(self._access.output_size.as_list())
69 | self._output_size = output_size
70 | self._clip_value = clip_value or 0
71 |
72 | self._output_size = tf.TensorShape([output_size])
73 | self._state_size = DNCState(
74 | access_output=self._access_output_size,
75 | access_state=self._access.state_size,
76 | controller_state=self._controller.state_size)
77 |
78 | def _clip_if_enabled(self, x):
79 | if self._clip_value > 0:
80 | return tf.clip_by_value(x, -self._clip_value, self._clip_value)
81 | else:
82 | return x
83 |
84 | def _build(self, inputs, prev_state):
85 | """Connects the DNC core into the graph.
86 |
87 | Args:
88 | inputs: Tensor input.
89 | prev_state: A `DNCState` tuple containing the fields `access_output`,
90 | `access_state` and `controller_state`. `access_output` is a 3-D Tensor
91 | of shape `[batch_size, num_reads, word_size]` containing read words.
92 | `access_state` is a tuple of the access module's state, and
93 | `controller_state` is a tuple of the controller module's state.
94 |
95 | Returns:
96 | A tuple `(output, next_state)` where `output` is a tensor and `next_state`
97 | is a `DNCState` tuple containing the fields `access_output`,
98 | `access_state`, and `controller_state`.
99 | """
100 |
101 | prev_access_output = prev_state.access_output
102 | prev_access_state = prev_state.access_state
103 | prev_controller_state = prev_state.controller_state
104 |
105 | batch_flatten = snt.BatchFlatten()
106 | controller_input = tf.concat(
107 | [batch_flatten(inputs), batch_flatten(prev_access_output)], 1)
108 |
109 | controller_output, controller_state = self._controller(
110 | controller_input, prev_controller_state)
111 |
112 | controller_output = self._clip_if_enabled(controller_output)
113 | controller_state = tf.contrib.framework.nest.map_structure(self._clip_if_enabled, controller_state)
114 |
115 | access_output, access_state = self._access(controller_output,
116 | prev_access_state)
117 |
118 | output = tf.concat([controller_output, batch_flatten(access_output)], 1)
119 | output = snt.Linear(
120 | output_size=self._output_size.as_list()[0],
121 | name='output_linear')(output)
122 | output = self._clip_if_enabled(output)
123 |
124 | return output, DNCState(
125 | access_output=access_output,
126 | access_state=access_state,
127 | controller_state=controller_state)
128 |
129 | def initial_state(self, batch_size, dtype=tf.float32):
130 | return DNCState(
131 | controller_state=self._controller.initial_state(batch_size, dtype),
132 | access_state=self._access.initial_state(batch_size, dtype),
133 | access_output=tf.zeros(
134 | [batch_size] + self._access.output_size.as_list(), dtype))
135 |
136 | @property
137 | def state_size(self):
138 | return self._state_size
139 |
140 | @property
141 | def output_size(self):
142 | return self._output_size
143 |
--------------------------------------------------------------------------------
/dnc/repeat_copy.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """A repeat copy task."""
16 | from __future__ import absolute_import
17 | from __future__ import division
18 | from __future__ import print_function
19 |
20 | import collections
21 | import numpy as np
22 | import sonnet as snt
23 | import tensorflow as tf
24 |
25 | DatasetTensors = collections.namedtuple('DatasetTensors', ('observations',
26 | 'target', 'mask'))
27 |
28 |
29 | def masked_sigmoid_cross_entropy(logits,
30 | target,
31 | mask,
32 | time_average=False,
33 | log_prob_in_bits=False):
34 | """Adds ops to the graph which compute the (scalar) NLL of the target sequence.
35 |
36 | The logits parametrize independent Bernoulli distributions per time-step and
37 | per batch element, and irrelevant time/batch elements are masked out by the
38 | mask tensor.
39 |
40 | Args:
41 | logits: `Tensor` of activations for which sigmoid(`logits`) gives the
42 | Bernoulli parameter.
43 | target: time-major `Tensor` of target.
44 | mask: time-major `Tensor` multiplied elementwise with the T x B cost,
45 | masking out irrelevant time-steps.
46 | time_average: optionally average over the time dimension (sum by default).
47 | log_prob_in_bits: iff True, express log-probabilities in bits (default nats).
48 |
49 | Returns:
50 | A `Tensor` representing the log-probability of the target.
51 | """
52 | xent = tf.nn.sigmoid_cross_entropy_with_logits(labels=target, logits=logits)
53 | loss_time_batch = tf.reduce_sum(xent, axis=2)
54 | loss_batch = tf.reduce_sum(loss_time_batch * mask, axis=0)
55 |
56 | batch_size = tf.cast(tf.shape(logits)[1], dtype=loss_time_batch.dtype)
57 |
58 | if time_average:
59 | mask_count = tf.reduce_sum(mask, axis=0)
60 | loss_batch /= (mask_count + np.finfo(np.float32).eps)
61 |
62 | loss = tf.reduce_sum(loss_batch) / batch_size
63 | if log_prob_in_bits:
64 | loss /= tf.log(2.)
65 |
66 | return loss
67 |
68 |
69 | def bitstring_readable(data, batch_size, model_output=None, whole_batch=False):
70 | """Produce a human readable representation of the sequences in data.
71 |
72 | Args:
73 | data: data to be visualised
74 | batch_size: size of batch
75 | model_output: optional model output tensor to visualize alongside data.
76 | whole_batch: whether to visualise the whole batch.
Only the first sample 77 | will be visualized if False 78 | 79 | Returns: 80 | A string used to visualise the data batch 81 | """ 82 | 83 | def _readable(datum): 84 | return '+' + ' '.join(['-' if x == 0 else '%d' % x for x in datum]) + '+' 85 | 86 | obs_batch = data.observations 87 | targ_batch = data.target 88 | 89 | iterate_over = range(batch_size) if whole_batch else range(1) 90 | 91 | batch_strings = [] 92 | for batch_index in iterate_over: 93 | obs = obs_batch[:, batch_index, :] 94 | targ = targ_batch[:, batch_index, :] 95 | 96 | obs_channels = range(obs.shape[1]) 97 | targ_channels = range(targ.shape[1]) 98 | obs_channel_strings = [_readable(obs[:, i]) for i in obs_channels] 99 | targ_channel_strings = [_readable(targ[:, i]) for i in targ_channels] 100 | 101 | readable_obs = 'Observations:\n' + '\n'.join(obs_channel_strings) 102 | readable_targ = 'Targets:\n' + '\n'.join(targ_channel_strings) 103 | strings = [readable_obs, readable_targ] 104 | 105 | if model_output is not None: 106 | output = model_output[:, batch_index, :] 107 | output_strings = [_readable(output[:, i]) for i in targ_channels] 108 | strings.append('Model Output:\n' + '\n'.join(output_strings)) 109 | 110 | batch_strings.append('\n\n'.join(strings)) 111 | 112 | return '\n' + '\n\n\n\n'.join(batch_strings) 113 | 114 | 115 | class RepeatCopy(snt.AbstractModule): 116 | """Sequence data generator for the task of repeating a random binary pattern. 117 | 118 | When called, an instance of this class will return a tuple of tensorflow ops 119 | (obs, targ, mask), representing an input sequence, target sequence, and 120 | binary mask. Each of these ops produces tensors whose first two dimensions 121 | represent sequence position and batch index respectively. The value in 122 | mask[t, b] is equal to 1 iff a prediction about targ[t, b, :] should be 123 | penalized and 0 otherwise. 124 | 125 | For each realisation from this generator, the observation sequence is 126 | comprised of I.I.D. uniform-random binary vectors (and some flags). 127 | 128 | The target sequence is comprised of this binary pattern repeated 129 | some number of times (and some flags). Before explaining in more detail, 130 | let's examine the setup pictorially for a single batch element: 131 | 132 | ```none 133 | Note: blank space represents 0. 134 | 135 | time ------------------------------------------> 136 | 137 | +-------------------------------+ 138 | mask: |0000000001111111111111111111111| 139 | +-------------------------------+ 140 | 141 | +-------------------------------+ 142 | target: | 1| 'end-marker' channel. 143 | | 101100110110011011001 | 144 | | 010101001010100101010 | 145 | +-------------------------------+ 146 | 147 | +-------------------------------+ 148 | observation: | 1011001 | 149 | | 0101010 | 150 | |1 | 'start-marker' channel 151 | | 3 | 'num-repeats' channel. 152 | +-------------------------------+ 153 | ``` 154 | 155 | The length of the random pattern and the number of times it is repeated 156 | in the target are both discrete random variables distributed according to 157 | uniform distributions whose parameters are configured at construction time. 158 | 159 | The obs sequence has two extra channels (components in the trailing dimension) 160 | which are used for flags. One channel is marked with a 1 at the first time 161 | step and is otherwise equal to 0. The other extra channel is zero until the 162 | binary pattern to be repeated ends. 
At this point, it contains an encoding of
163 | the number of times the observation pattern should be repeated. Rather than
164 | simply providing this integer number directly, it is normalised so that
165 | a neural network may have an easier time representing the number of
166 | repetitions internally. To allow a network to be readily evaluated on
167 | instances of this task with greater numbers of repetitions, the range with
168 | respect to which this encoding is normalised is also configurable by the user.
169 |
170 | As in the diagram, the target sequence is offset to begin directly after the
171 | observation sequence; both sequences are padded with zeros to accomplish this,
172 | resulting in their lengths being equal. Additional padding is done at the end
173 | so that all sequences in a minibatch represent tensors with the same shape.
174 | """
175 |
176 | def __init__(
177 | self,
178 | num_bits=6,
179 | batch_size=1,
180 | min_length=1,
181 | max_length=1,
182 | min_repeats=1,
183 | max_repeats=2,
184 | norm_max=10,
185 | log_prob_in_bits=False,
186 | time_average_cost=False,
187 | name='repeat_copy',):
188 | """Creates an instance of RepeatCopy task.
189 |
190 | Args:
191 | name: A name for the generator instance (for name scope purposes).
192 | num_bits: The dimensionality of each random binary vector.
193 | batch_size: Minibatch size per realization.
194 | min_length: Lower limit on number of random binary vectors in the
195 | observation pattern.
196 | max_length: Upper limit on number of random binary vectors in the
197 | observation pattern.
198 | min_repeats: Lower limit on number of times the observation pattern
199 | is repeated in targ.
200 | max_repeats: Upper limit on number of times the observation pattern
201 | is repeated in targ.
202 | norm_max: Upper limit on uniform distribution w.r.t. which the encoding
203 | of the number of repetitions presented in the observation sequence
204 | is normalised.
205 | log_prob_in_bits: By default, log probabilities are expressed in units of
206 | nats. If true, express log probabilities in bits.
207 | time_average_cost: If true, the cost at each time step will be
208 | divided by the `true` sequence length (the number of non-masked time
209 | steps) in each sequence before any subsequent reduction over the time
210 | and batch dimensions.
211 | """ 212 | super(RepeatCopy, self).__init__(name=name) 213 | 214 | self._batch_size = batch_size 215 | self._num_bits = num_bits 216 | self._min_length = min_length 217 | self._max_length = max_length 218 | self._min_repeats = min_repeats 219 | self._max_repeats = max_repeats 220 | self._norm_max = norm_max 221 | self._log_prob_in_bits = log_prob_in_bits 222 | self._time_average_cost = time_average_cost 223 | 224 | def _normalise(self, val): 225 | return val / self._norm_max 226 | 227 | def _unnormalise(self, val): 228 | return val * self._norm_max 229 | 230 | @property 231 | def time_average_cost(self): 232 | return self._time_average_cost 233 | 234 | @property 235 | def log_prob_in_bits(self): 236 | return self._log_prob_in_bits 237 | 238 | @property 239 | def num_bits(self): 240 | """The dimensionality of each random binary vector in a pattern.""" 241 | return self._num_bits 242 | 243 | @property 244 | def target_size(self): 245 | """The dimensionality of the target tensor.""" 246 | return self._num_bits + 1 247 | 248 | @property 249 | def batch_size(self): 250 | return self._batch_size 251 | 252 | def _build(self): 253 | """Implements build method which adds ops to graph.""" 254 | 255 | # short-hand for private fields. 256 | min_length, max_length = self._min_length, self._max_length 257 | min_reps, max_reps = self._min_repeats, self._max_repeats 258 | num_bits = self.num_bits 259 | batch_size = self.batch_size 260 | 261 | # We reserve one dimension for the num-repeats and one for the start-marker. 262 | full_obs_size = num_bits + 2 263 | # We reserve one target dimension for the end-marker. 264 | full_targ_size = num_bits + 1 265 | start_end_flag_idx = full_obs_size - 2 266 | num_repeats_channel_idx = full_obs_size - 1 267 | 268 | # Samples each batch index's sequence length and the number of repeats. 269 | sub_seq_length_batch = tf.random_uniform( 270 | [batch_size], minval=min_length, maxval=max_length + 1, dtype=tf.int32) 271 | num_repeats_batch = tf.random_uniform( 272 | [batch_size], minval=min_reps, maxval=max_reps + 1, dtype=tf.int32) 273 | 274 | # Pads all the batches to have the same total sequence length. 275 | total_length_batch = sub_seq_length_batch * (num_repeats_batch + 1) + 3 276 | max_length_batch = tf.reduce_max(total_length_batch) 277 | residual_length_batch = max_length_batch - total_length_batch 278 | 279 | obs_batch_shape = [max_length_batch, batch_size, full_obs_size] 280 | targ_batch_shape = [max_length_batch, batch_size, full_targ_size] 281 | mask_batch_trans_shape = [batch_size, max_length_batch] 282 | 283 | obs_tensors = [] 284 | targ_tensors = [] 285 | mask_tensors = [] 286 | 287 | # Generates patterns for each batch element independently. 288 | for batch_index in range(batch_size): 289 | sub_seq_len = sub_seq_length_batch[batch_index] 290 | num_reps = num_repeats_batch[batch_index] 291 | 292 | # The observation pattern is a sequence of random binary vectors. 293 | obs_pattern_shape = [sub_seq_len, num_bits] 294 | obs_pattern = tf.cast( 295 | tf.random_uniform( 296 | obs_pattern_shape, minval=0, maxval=2, dtype=tf.int32), 297 | tf.float32) 298 | 299 | # The target pattern is the observation pattern repeated n times. 300 | # Some reshaping is required to accomplish the tiling. 
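# Roughly what the next few lines do: flatten the [sub_seq_len, num_bits]
# pattern to 1-D, tile it num_reps times, and reshape back, which repeats the
# whole pattern num_reps times along the time axis.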
301 | targ_pattern_shape = [sub_seq_len * num_reps, num_bits]
302 | flat_obs_pattern = tf.reshape(obs_pattern, [-1])
303 | flat_targ_pattern = tf.tile(flat_obs_pattern, tf.stack([num_reps]))
304 | targ_pattern = tf.reshape(flat_targ_pattern, targ_pattern_shape)
305 |
306 | # Expand the obs_pattern to have two extra channels for flags.
307 | # Concatenate start flag and num_reps flag to the sequence.
308 | obs_flag_channel_pad = tf.zeros([sub_seq_len, 2])
309 | obs_start_flag = tf.one_hot(
310 | [start_end_flag_idx], full_obs_size, on_value=1., off_value=0.)
311 | num_reps_flag = tf.one_hot(
312 | [num_repeats_channel_idx],
313 | full_obs_size,
314 | on_value=self._normalise(tf.cast(num_reps, tf.float32)),
315 | off_value=0.)
316 |
317 | # note the concatenation dimensions.
318 | obs = tf.concat([obs_pattern, obs_flag_channel_pad], 1)
319 | obs = tf.concat([obs_start_flag, obs], 0)
320 | obs = tf.concat([obs, num_reps_flag], 0)
321 |
322 | # Now do the same for the targ_pattern (it only has one extra channel).
323 | targ_flag_channel_pad = tf.zeros([sub_seq_len * num_reps, 1])
324 | targ_end_flag = tf.one_hot(
325 | [start_end_flag_idx], full_targ_size, on_value=1., off_value=0.)
326 | targ = tf.concat([targ_pattern, targ_flag_channel_pad], 1)
327 | targ = tf.concat([targ, targ_end_flag], 0)
328 |
329 | # Concatenate zeros at end of obs and beginning of targ.
330 | # This aligns them s.t. the target begins as soon as the obs ends.
331 | obs_end_pad = tf.zeros([sub_seq_len * num_reps + 1, full_obs_size])
332 | targ_start_pad = tf.zeros([sub_seq_len + 2, full_targ_size])
333 |
334 | # The mask is zero during the obs and one during the targ.
335 | mask_off = tf.zeros([sub_seq_len + 2])
336 | mask_on = tf.ones([sub_seq_len * num_reps + 1])
337 |
338 | obs = tf.concat([obs, obs_end_pad], 0)
339 | targ = tf.concat([targ_start_pad, targ], 0)
340 | mask = tf.concat([mask_off, mask_on], 0)
341 |
342 | obs_tensors.append(obs)
343 | targ_tensors.append(targ)
344 | mask_tensors.append(mask)
345 |
346 | # End the loop over batch index.
347 | # Compute how much zero padding is needed to make the tensor sequences
348 | # the same length for all batch elements.
349 | residual_obs_pad = [
350 | tf.zeros([residual_length_batch[i], full_obs_size])
351 | for i in range(batch_size)
352 | ]
353 | residual_targ_pad = [
354 | tf.zeros([residual_length_batch[i], full_targ_size])
355 | for i in range(batch_size)
356 | ]
357 | residual_mask_pad = [
358 | tf.zeros([residual_length_batch[i]]) for i in range(batch_size)
359 | ]
360 |
361 | # Concatenate the pad to each batch element.
362 | obs_tensors = [
363 | tf.concat([o, p], 0) for o, p in zip(obs_tensors, residual_obs_pad)
364 | ]
365 | targ_tensors = [
366 | tf.concat([t, p], 0) for t, p in zip(targ_tensors, residual_targ_pad)
367 | ]
368 | mask_tensors = [
369 | tf.concat([m, p], 0) for m, p in zip(mask_tensors, residual_mask_pad)
370 | ]
371 |
372 | # Concatenate each batch element into a single tensor.
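# Each per-element tensor above is [max_length_batch, channels]; concatenating
# along axis 1 and reshaping yields the time-major layout
# [max_length_batch, batch_size, channels] declared in obs_batch_shape earlier.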
373 | obs = tf.reshape(tf.concat(obs_tensors, 1), obs_batch_shape) 374 | targ = tf.reshape(tf.concat(targ_tensors, 1), targ_batch_shape) 375 | mask = tf.transpose( 376 | tf.reshape(tf.concat(mask_tensors, 0), mask_batch_trans_shape)) 377 | return DatasetTensors(obs, targ, mask) 378 | 379 | def cost(self, logits, targ, mask): 380 | return masked_sigmoid_cross_entropy( 381 | logits, 382 | targ, 383 | mask, 384 | time_average=self.time_average_cost, 385 | log_prob_in_bits=self.log_prob_in_bits) 386 | 387 | def to_human_readable(self, data, model_output=None, whole_batch=False): 388 | obs = data.observations 389 | unnormalised_num_reps_flag = self._unnormalise(obs[:,:,-1:]).round() 390 | obs = np.concatenate([obs[:,:,:-1], unnormalised_num_reps_flag], axis=2) 391 | data = data._replace(observations=obs) 392 | return bitstring_readable(data, self.batch_size, model_output, whole_batch) 393 | -------------------------------------------------------------------------------- /dnc/util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """DNC util ops and modules.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import numpy as np 22 | import tensorflow as tf 23 | 24 | 25 | def batch_invert_permutation(permutations): 26 | """Returns batched `tf.invert_permutation` for every row in `permutations`.""" 27 | with tf.name_scope('batch_invert_permutation', values=[permutations]): 28 | perm = tf.cast(permutations, tf.float32) 29 | dim = int(perm.get_shape()[-1]) 30 | size = tf.cast(tf.shape(perm)[0], tf.float32) 31 | delta = tf.cast(tf.shape(perm)[-1], tf.float32) 32 | rg = tf.range(0, size * delta, delta, dtype=tf.float32) 33 | rg = tf.expand_dims(rg, 1) 34 | rg = tf.tile(rg, [1, dim]) 35 | perm = tf.add(perm, rg) 36 | flat = tf.reshape(perm, [-1]) 37 | perm = tf.invert_permutation(tf.cast(flat, tf.int32)) 38 | perm = tf.reshape(perm, [-1, dim]) 39 | return tf.subtract(perm, tf.cast(rg, tf.int32)) 40 | 41 | 42 | def batch_gather(values, indices): 43 | """Returns batched `tf.gather` for every row in the input.""" 44 | with tf.name_scope('batch_gather', values=[values, indices]): 45 | idx = tf.expand_dims(indices, -1) 46 | size = tf.shape(indices)[0] 47 | rg = tf.range(size, dtype=tf.int32) 48 | rg = tf.expand_dims(rg, -1) 49 | rg = tf.tile(rg, [1, int(indices.get_shape()[-1])]) 50 | rg = tf.expand_dims(rg, -1) 51 | gidx = tf.concat([rg, idx], -1) 52 | return tf.gather_nd(values, gidx) 53 | 54 | 55 | def one_hot(length, index): 56 | """Return an nd array of given `length` filled with 0s and a 1 at `index`.""" 57 | result = np.zeros(length) 58 | result[index] = 1 59 | return result 60 | 61 | def reduce_prod(x, axis, name=None): 62 | """Efficient reduce product over axis. 
63 | 64 | Uses tf.cumprod and tf.gather_nd as a workaround to the poor performance of calculating tf.reduce_prod's gradient on CPU. 65 | """ 66 | with tf.name_scope(name, 'util_reduce_prod', values=[x]): 67 | cp = tf.cumprod(x, axis, reverse=True) 68 | size = tf.shape(cp)[0] 69 | idx1 = tf.range(tf.cast(size, tf.float32), dtype=tf.float32) 70 | idx2 = tf.zeros([size], tf.float32) 71 | indices = tf.stack([idx1, idx2], 1) 72 | return tf.gather_nd(cp, tf.cast(indices, tf.int32)) 73 | -------------------------------------------------------------------------------- /dnc/util_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for utility functions.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import numpy as np 22 | import tensorflow as tf 23 | 24 | from dnc import util 25 | 26 | try: 27 | xrange 28 | except NameError: 29 | xrange = range 30 | 31 | 32 | class BatchInvertPermutation(tf.test.TestCase): 33 | 34 | def test(self): 35 | # Tests that the _batch_invert_permutation function correctly inverts a 36 | # batch of permutations. 
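# Illustrative example: for the permutation [2, 0, 1] the inverse is
# [1, 2, 0], since inverse[p[i]] = i; the loop below checks the equivalent
# property p[inverse[j]] = j for every batch row.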
37 | batch_size = 5 38 | length = 7 39 | 40 | permutations = np.empty([batch_size, length], dtype=int) 41 | for i in xrange(batch_size): 42 | permutations[i] = np.random.permutation(length) 43 | 44 | inverse = util.batch_invert_permutation(tf.constant(permutations, tf.int32)) 45 | with self.test_session(): 46 | inverse = inverse.eval() 47 | 48 | for i in xrange(batch_size): 49 | for j in xrange(length): 50 | self.assertEqual(permutations[i][inverse[i][j]], j) 51 | 52 | 53 | class BatchGather(tf.test.TestCase): 54 | 55 | def test(self): 56 | values = np.array([[3, 1, 4, 1], [5, 9, 2, 6], [5, 3, 5, 7]]) 57 | indexs = np.array([[1, 2, 0, 3], [3, 0, 1, 2], [0, 2, 1, 3]]) 58 | target = np.array([[1, 4, 3, 1], [6, 5, 9, 2], [5, 5, 3, 7]]) 59 | result = util.batch_gather(tf.constant(values), tf.constant(indexs)) 60 | with self.test_session(): 61 | result = result.eval() 62 | self.assertAllEqual(target, result) 63 | 64 | 65 | if __name__ == '__main__': 66 | tf.test.main() 67 | -------------------------------------------------------------------------------- /images/dnc_model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/dnc/f5981c66ac6fe27978014c314c4aba39f1962ee3/images/dnc_model.png -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name='dnc', 5 | version='0.0.2', 6 | description='This package provides an implementation of the Differentiable Neural Computer, as published in Nature.', 7 | license='Apache Software License 2.0', 8 | packages=['dnc'], 9 | author='DeepMind', 10 | keywords=['tensorflow', 'differentiable neural computer', 'dnc', 'deepmind', 'deep mind', 'sonnet', 'dm-sonnet', 'machine learning'], 11 | url='https://github.com/deepmind/dnc' 12 | ) 13 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Example script to train the DNC on a repeated copy task.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | import sonnet as snt 23 | 24 | from dnc import dnc 25 | from dnc import repeat_copy 26 | 27 | FLAGS = tf.flags.FLAGS 28 | 29 | # Model parameters 30 | tf.flags.DEFINE_integer("hidden_size", 64, "Size of LSTM hidden layer.") 31 | tf.flags.DEFINE_integer("memory_size", 16, "The number of memory slots.") 32 | tf.flags.DEFINE_integer("word_size", 16, "The width of each memory slot.") 33 | tf.flags.DEFINE_integer("num_write_heads", 1, "Number of memory write heads.") 34 | tf.flags.DEFINE_integer("num_read_heads", 4, "Number of memory read heads.") 35 | tf.flags.DEFINE_integer("clip_value", 20, 36 | "Maximum absolute value of controller and dnc outputs.") 37 | 38 | # Optimizer parameters. 39 | tf.flags.DEFINE_float("max_grad_norm", 50, "Gradient clipping norm limit.") 40 | tf.flags.DEFINE_float("learning_rate", 1e-4, "Optimizer learning rate.") 41 | tf.flags.DEFINE_float("optimizer_epsilon", 1e-10, 42 | "Epsilon used for RMSProp optimizer.") 43 | 44 | # Task parameters 45 | tf.flags.DEFINE_integer("batch_size", 16, "Batch size for training.") 46 | tf.flags.DEFINE_integer("num_bits", 4, "Dimensionality of each vector to copy") 47 | tf.flags.DEFINE_integer( 48 | "min_length", 1, 49 | "Lower limit on number of vectors in the observation pattern to copy") 50 | tf.flags.DEFINE_integer( 51 | "max_length", 2, 52 | "Upper limit on number of vectors in the observation pattern to copy") 53 | tf.flags.DEFINE_integer("min_repeats", 1, 54 | "Lower limit on number of copy repeats.") 55 | tf.flags.DEFINE_integer("max_repeats", 2, 56 | "Upper limit on number of copy repeats.") 57 | 58 | # Training options. 59 | tf.flags.DEFINE_integer("num_training_iterations", 100000, 60 | "Number of iterations to train for.") 61 | tf.flags.DEFINE_integer("report_interval", 100, 62 | "Iterations between reports (samples, valid loss).") 63 | tf.flags.DEFINE_string("checkpoint_dir", "/tmp/tf/dnc", 64 | "Checkpointing directory.") 65 | tf.flags.DEFINE_integer("checkpoint_interval", -1, 66 | "Checkpointing step interval.") 67 | 68 | 69 | def run_model(input_sequence, output_size): 70 | """Runs model on input sequence.""" 71 | 72 | access_config = { 73 | "memory_size": FLAGS.memory_size, 74 | "word_size": FLAGS.word_size, 75 | "num_reads": FLAGS.num_read_heads, 76 | "num_writes": FLAGS.num_write_heads, 77 | } 78 | controller_config = { 79 | "hidden_size": FLAGS.hidden_size, 80 | } 81 | clip_value = FLAGS.clip_value 82 | 83 | dnc_core = dnc.DNC(access_config, controller_config, output_size, clip_value) 84 | initial_state = dnc_core.initial_state(FLAGS.batch_size) 85 | output_sequence, _ = tf.nn.dynamic_rnn( 86 | cell=dnc_core, 87 | inputs=input_sequence, 88 | time_major=True, 89 | initial_state=initial_state) 90 | 91 | return output_sequence 92 | 93 | 94 | def train(num_training_iterations, report_interval): 95 | """Trains the DNC and periodically reports the loss.""" 96 | 97 | dataset = repeat_copy.RepeatCopy(FLAGS.num_bits, FLAGS.batch_size, 98 | FLAGS.min_length, FLAGS.max_length, 99 | FLAGS.min_repeats, FLAGS.max_repeats) 100 | dataset_tensors = dataset() 101 | 102 | output_logits = run_model(dataset_tensors.observations, dataset.target_size) 103 | # Used for visualization. 
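# Multiplying by the mask zeroes the sigmoid outputs at unscored (observation)
# time steps, so after rounding the tensor shows binary predictions only where
# the target is actually evaluated.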
104 | output = tf.round( 105 | tf.expand_dims(dataset_tensors.mask, -1) * tf.sigmoid(output_logits)) 106 | 107 | train_loss = dataset.cost(output_logits, dataset_tensors.target, 108 | dataset_tensors.mask) 109 | 110 | # Set up optimizer with global norm clipping. 111 | trainable_variables = tf.trainable_variables() 112 | grads, _ = tf.clip_by_global_norm( 113 | tf.gradients(train_loss, trainable_variables), FLAGS.max_grad_norm) 114 | 115 | global_step = tf.get_variable( 116 | name="global_step", 117 | shape=[], 118 | dtype=tf.int64, 119 | initializer=tf.zeros_initializer(), 120 | trainable=False, 121 | collections=[tf.GraphKeys.GLOBAL_VARIABLES, tf.GraphKeys.GLOBAL_STEP]) 122 | 123 | optimizer = tf.train.RMSPropOptimizer( 124 | FLAGS.learning_rate, epsilon=FLAGS.optimizer_epsilon) 125 | train_step = optimizer.apply_gradients( 126 | zip(grads, trainable_variables), global_step=global_step) 127 | 128 | saver = tf.train.Saver() 129 | 130 | if FLAGS.checkpoint_interval > 0: 131 | hooks = [ 132 | tf.train.CheckpointSaverHook( 133 | checkpoint_dir=FLAGS.checkpoint_dir, 134 | save_steps=FLAGS.checkpoint_interval, 135 | saver=saver) 136 | ] 137 | else: 138 | hooks = [] 139 | 140 | # Train. 141 | with tf.train.SingularMonitoredSession( 142 | hooks=hooks, checkpoint_dir=FLAGS.checkpoint_dir) as sess: 143 | 144 | start_iteration = sess.run(global_step) 145 | total_loss = 0 146 | 147 | for train_iteration in range(start_iteration, num_training_iterations): 148 | _, loss = sess.run([train_step, train_loss]) 149 | total_loss += loss 150 | 151 | if (train_iteration + 1) % report_interval == 0: 152 | dataset_tensors_np, output_np = sess.run([dataset_tensors, output]) 153 | dataset_string = dataset.to_human_readable(dataset_tensors_np, 154 | output_np) 155 | tf.logging.info("%d: Avg training loss %f.\n%s", 156 | train_iteration, total_loss / report_interval, 157 | dataset_string) 158 | total_loss = 0 159 | 160 | 161 | def main(unused_argv): 162 | tf.logging.set_verbosity(3) # Print INFO log messages. 163 | train(FLAGS.num_training_iterations, FLAGS.report_interval) 164 | 165 | 166 | if __name__ == "__main__": 167 | tf.app.run() 168 | --------------------------------------------------------------------------------