├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── dataset_reader.py ├── ensembles.py ├── model.py ├── scores.py ├── train.py └── utils.py /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement. You (or your employer) retain the copyright to your contribution; 10 | this simply gives us permission to use and redistribute your contributions as 11 | part of the project. Head over to <https://cla.developers.google.com/> to see 12 | your current agreements on file or to sign a new one. 13 | 14 | You generally only need to submit a CLA once, so if you've already submitted one 15 | (even if it was for a different project), you probably don't need to do it 16 | again. 17 | 18 | ## Code reviews 19 | 20 | All submissions, including submissions by project members, require review. We 21 | use GitHub pull requests for this purpose. Consult 22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 23 | information on using pull requests. 24 | 25 | ## Community Guidelines 26 | 27 | This project follows 28 | [Google's Open Source Community Guidelines](https://opensource.google.com/conduct/). 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 
11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Vector-based navigation using grid-like representations in artificial agents 2 | 3 | This package provides an implementation of the supervised learning experiments 4 | in Vector-based navigation using grid-like representations in artificial agents, 5 | as [published in Nature](https://www.nature.com/articles/s41586-018-0102-6) 6 | 7 | Any publication that discloses findings arising from using this source code must 8 | cite "Banino et al. "Vector-based navigation using grid-like representations in 9 | artificial agents." Nature 557.7705 (2018): 429." 10 | 11 | ## Introduction 12 | 13 | The grid-cell network is a recurrent deep neural network (LSTM). This network 14 | learns to path integrate within a square arena, using simulated trajectories 15 | modelled on those of foraging rodents. The network is required to update its 16 | estimate of location and head direction using translational and angular velocity 17 | signals which are provided as input. 
The output of the LSTM projects to place 18 | and head direction units via a linear layer which is subject to regularization. 19 | The vector of activities in the place and head direction units, corresponding to 20 | the current position, is provided as a supervised training signal at each time 21 | step. 22 | 23 | The dataset needed to run this code can be downloaded from 24 | [here](https://console.cloud.google.com/storage/browser/grid-cells-datasets). 25 | 26 | The files contained in the repository are the following: 27 | 28 | * `train.py` is where the training and logging loop happens; the file comes 29 | with the flags defined in Table 1 of the paper. To run this file 30 | you will need to specify where the dataset is stored and where you want to 31 | save the results. The results are saved in PDF format and contain the 32 | ratemaps and the spatial autocorrelograms, with the units ordered from 33 | highest to lowest grid score. Only the last evaluation is saved. 34 | Please note that results can vary between runs depending on the random seed. 35 | 36 | * `dataset_reader.py` reads the TFRecords and returns a ready-to-use batch, 37 | which is already shuffled. 38 | 39 | * `model.py` contains the grid-cell network. 40 | 41 | * `scores.py` contains all the functions for calculating the grid scores and 42 | doing the plotting. 43 | 44 | * `ensembles.py` contains the classes that generate the targets for training 45 | the grid-cell networks. 46 | 47 | ## Train 48 | 49 | The implementation requires an installation of 50 | [TensorFlow](https://www.tensorflow.org/) version 1.12 and 51 | [Sonnet](https://github.com/deepmind/sonnet) version 1.27. 
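As a quick illustration of the supervised targets described in the Introduction (independent of the TensorFlow setup below), here is a minimal NumPy sketch. The helper names are hypothetical; the actual implementation lives in `ensembles.py` and uses TensorFlow, but the defaults mirror it (Gaussian place-cell tuning with `stdev=0.35`, von Mises head-direction tuning with concentration 20):

```python
import numpy as np

def place_cell_targets(positions, means, stdev=0.35):
    # Gaussian log-activation of each place cell at each position,
    # then a softmax across cells (the "softmax" target option).
    diff = positions[:, None, :] - means[None, :, :]        # (T, N, 2)
    logp = -0.5 * np.sum(diff ** 2, axis=-1) / stdev ** 2   # (T, N)
    logp -= logp.max(axis=-1, keepdims=True)                # numerical stability
    p = np.exp(logp)
    return p / p.sum(axis=-1, keepdims=True)

def hd_cell_targets(angles, means, kappa=20.0):
    # Von Mises log-activation of each head direction cell, then softmax.
    logp = kappa * np.cos(angles[:, None] - means[None, :])  # (T, N)
    logp -= logp.max(axis=-1, keepdims=True)
    p = np.exp(logp)
    return p / p.sum(axis=-1, keepdims=True)

rs = np.random.RandomState(0)
pc_means = rs.uniform(-1.1, 1.1, size=(256, 2))  # square_room coord range
hd_means = rs.uniform(-np.pi, np.pi, size=12)
traj = rs.uniform(-1.1, 1.1, size=(100, 2))      # one 100-step trajectory
hd = rs.uniform(-np.pi, np.pi, size=100)

pc_t = place_cell_targets(traj, pc_means)        # (100, 256), rows sum to 1
hd_t = hd_cell_targets(hd, hd_means)             # (100, 12), rows sum to 1
```

In the repository itself these targets are produced by `PlaceCellEnsemble` and `HeadDirectionCellEnsemble` in `ensembles.py`, which also support the "normalized", "sample", and "voronoi" target variants.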
52 | 53 | ```shell 54 | $ virtualenv env 55 | $ source env/bin/activate 56 | $ pip install --upgrade numpy==1.13.3 57 | $ pip install --upgrade tensorflow==1.12.0-rc0 58 | $ pip install --upgrade dm-sonnet==1.27 59 | $ pip install --upgrade scipy==1.0.0 60 | $ pip install --upgrade matplotlib==1.5.2 61 | $ pip install --upgrade tensorflow-probability==0.5.0 62 | $ pip install --upgrade wrapt==1.9.0 63 | ``` 64 | 65 | The training script can then be executed from the command line: 66 | 67 | ```shell 68 | $ python train.py --task_root='path/to/datasets/root/folder' --saver_results_directory='path/to/results/folder' 69 | ``` 70 | 71 | Disclaimer: This is not an official Google product. 72 | 73 | -------------------------------------------------------------------------------- /dataset_reader.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Minimal queue-based TFRecord reader for the Grid Cell paper.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import collections 23 | import os 24 | import tensorflow as tf 25 | nest = tf.contrib.framework.nest 26 | 27 | DatasetInfo = collections.namedtuple( 28 | 'DatasetInfo', ['basepath', 'size', 'sequence_length', 'coord_range']) 29 | 30 | _DATASETS = dict( 31 | square_room=DatasetInfo( 32 | basepath='square_room_100steps_2.2m_1000000', 33 | size=100, 34 | sequence_length=100, 35 | coord_range=((-1.1, 1.1), (-1.1, 1.1))),) 36 | 37 | 38 | def _get_dataset_files(dataset_info, root): 39 | """Generates the list of files for a given dataset version.""" 40 | basepath = dataset_info.basepath 41 | base = os.path.join(root, basepath) 42 | num_files = dataset_info.size 43 | template = '{:04d}-of-{:04d}.tfrecord' 44 | return [ 45 | os.path.join(base, template.format(i, num_files - 1)) 46 | for i in range(num_files) 47 | ] 48 | 49 | 50 | class DataReader(object): 51 | """Minimal queue-based TFRecord reader. 52 | 53 | You can use this reader to load the datasets used to train the grid cell 54 | network in the 'Vector-based Navigation using Grid-like Representations 55 | in Artificial Agents' paper. 56 | See README.md for a description of the datasets and an example of how to use 57 | the reader. 58 | """ 59 | 60 | def __init__( 61 | self, 62 | dataset, 63 | root, 64 | # Queue params 65 | num_threads=4, 66 | capacity=256, 67 | min_after_dequeue=128, 68 | seed=None): 69 | """Instantiates a DataReader object and sets up queues for data reading. 70 | 71 | Args: 72 | dataset: string, a key of the _DATASETS dictionary defined at the top 73 | of this file; currently the only available 74 | dataset is 'square_room' (a 2.2m square arena, 75 | 100-step trajectories). 
76 | root: string, path to the root folder of the data. 77 | num_threads: (optional) integer, number of threads used to feed the reader 78 | queues, defaults to 4. 79 | capacity: (optional) integer, capacity of the underlying 80 | RandomShuffleQueue, defaults to 256. 81 | min_after_dequeue: (optional) integer, min_after_dequeue of the underlying 82 | RandomShuffleQueue, defaults to 128. 83 | seed: (optional) integer, seed for the random number generators used in 84 | the reader. 85 | 86 | Raises: 87 | ValueError: if the requested dataset does not exist. 88 | """ 89 | 90 | if dataset not in _DATASETS: 91 | raise ValueError('Unrecognized dataset {} requested. Available datasets ' 92 | 'are {}'.format(dataset, _DATASETS.keys())) 93 | 94 | self._dataset_info = _DATASETS[dataset] 95 | self._steps = _DATASETS[dataset].sequence_length 96 | 97 | with tf.device('/cpu'): 98 | file_names = _get_dataset_files(self._dataset_info, root) 99 | filename_queue = tf.train.string_input_producer(file_names, seed=seed) 100 | reader = tf.TFRecordReader() 101 | 102 | read_ops = [ 103 | self._make_read_op(reader, filename_queue) for _ in range(num_threads) 104 | ] 105 | dtypes = nest.map_structure(lambda x: x.dtype, read_ops[0]) 106 | shapes = nest.map_structure(lambda x: x.shape[1:], read_ops[0]) 107 | 108 | self._queue = tf.RandomShuffleQueue( 109 | capacity=capacity, 110 | min_after_dequeue=min_after_dequeue, 111 | dtypes=dtypes, 112 | shapes=shapes, 113 | seed=seed) 114 | 115 | enqueue_ops = [self._queue.enqueue_many(op) for op in read_ops] 116 | tf.train.add_queue_runner(tf.train.QueueRunner(self._queue, enqueue_ops)) 117 | 118 | def read(self, batch_size): 119 | """Dequeues a shuffled batch of batch_size trajectories.""" 120 | in_pos, in_hd, ego_vel, target_pos, target_hd = self._queue.dequeue_many( 121 | batch_size) 122 | return in_pos, in_hd, ego_vel, target_pos, target_hd 123 | 124 | def get_coord_range(self): 125 | return self._dataset_info.coord_range 126 | 127 | def _make_read_op(self, reader, filename_queue): 128 
| """Instantiates the ops used to read and parse the data into tensors.""" 129 | _, raw_data = reader.read_up_to(filename_queue, num_records=64) 130 | feature_map = { 131 | 'init_pos': 132 | tf.FixedLenFeature(shape=[2], dtype=tf.float32), 133 | 'init_hd': 134 | tf.FixedLenFeature(shape=[1], dtype=tf.float32), 135 | 'ego_vel': 136 | tf.FixedLenFeature( 137 | shape=[self._dataset_info.sequence_length, 3], 138 | dtype=tf.float32), 139 | 'target_pos': 140 | tf.FixedLenFeature( 141 | shape=[self._dataset_info.sequence_length, 2], 142 | dtype=tf.float32), 143 | 'target_hd': 144 | tf.FixedLenFeature( 145 | shape=[self._dataset_info.sequence_length, 1], 146 | dtype=tf.float32), 147 | } 148 | example = tf.parse_example(raw_data, feature_map) 149 | batch = [ 150 | example['init_pos'], example['init_hd'], 151 | example['ego_vel'][:, :self._steps, :], 152 | example['target_pos'][:, :self._steps, :], 153 | example['target_hd'][:, :self._steps, :] 154 | ] 155 | return batch 156 | -------------------------------------------------------------------------------- /ensembles.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Ensembles of place and head direction cells. 
17 | 18 | These classes provide the targets for the training of grid-cell networks. 19 | 20 | """ 21 | 22 | from __future__ import absolute_import 23 | from __future__ import division 24 | from __future__ import print_function 25 | 26 | import numpy as np 27 | import tensorflow as tf 28 | 29 | 30 | def one_hot_max(x, axis=-1): 31 | """Compute one-hot vectors setting to one the index with the maximum value.""" 32 | return tf.one_hot(tf.argmax(x, axis=axis), 33 | depth=x.get_shape()[-1], 34 | dtype=x.dtype) 35 | 36 | 37 | def softmax(x, axis=-1): 38 | """Compute softmax values for each set of scores in x.""" 39 | return tf.nn.softmax(x, dim=axis) 40 | 41 | 42 | def softmax_sample(x): 43 | """Sample a one-hot vector from the categorical distribution given by logits x.""" 44 | dist = tf.contrib.distributions.OneHotCategorical(logits=x, dtype=tf.float32) 45 | return dist.sample() 46 | 47 | 48 | class CellEnsemble(object): 49 | """Abstract parent class for place and head direction cell ensembles.""" 50 | 51 | def __init__(self, n_cells, soft_targets, soft_init): 52 | self.n_cells = n_cells 53 | if soft_targets not in ["softmax", "voronoi", "sample", "normalized"]: 54 | raise ValueError("Unknown soft_targets: {}".format(soft_targets)) 55 | else: 56 | self.soft_targets = soft_targets 57 | # Provide initialization of LSTM in the same way as targets if not specified, 58 | # i.e. one-hot if targets are Voronoi 59 | if soft_init is None: 60 | self.soft_init = soft_targets 61 | else: 62 | if soft_init not in [ 63 | "softmax", "voronoi", "sample", "normalized", "zeros" 64 | ]: 65 | raise ValueError("Unknown soft_init: {}".format(soft_init)) 66 | else: 67 | self.soft_init = soft_init 68 | 69 | def get_targets(self, x): 70 | """Computes the ensemble targets for x, depending on soft_targets.""" 71 | 72 | if self.soft_targets == "normalized": 73 | targets = tf.exp(self.unnor_logpdf(x)) 74 | elif self.soft_targets == "softmax": 75 | lp = self.log_posterior(x) 76 | targets = softmax(lp) 77 | elif self.soft_targets == "sample": 78 | lp = self.log_posterior(x) 79 | targets = softmax_sample(lp) 80 | elif self.soft_targets == "voronoi": 81 
| lp = self.log_posterior(x) 82 | targets = one_hot_max(lp) 83 | return targets 84 | 85 | def get_init(self, x): 86 | """Type of initialisation.""" 87 | 88 | if self.soft_init == "normalized": 89 | init = tf.exp(self.unnor_logpdf(x)) 90 | elif self.soft_init == "softmax": 91 | lp = self.log_posterior(x) 92 | init = softmax(lp) 93 | elif self.soft_init == "sample": 94 | lp = self.log_posterior(x) 95 | init = softmax_sample(lp) 96 | elif self.soft_init == "voronoi": 97 | lp = self.log_posterior(x) 98 | init = one_hot_max(lp) 99 | elif self.soft_init == "zeros": 100 | init = tf.zeros_like(self.unnor_logpdf(x)) 101 | return init 102 | 103 | def loss(self, predictions, targets): 104 | """Loss.""" 105 | 106 | if self.soft_targets == "normalized": 107 | smoothing = 1e-2 108 | loss = tf.nn.sigmoid_cross_entropy_with_logits( 109 | labels=(1. - smoothing) * targets + smoothing * 0.5, 110 | logits=predictions, 111 | name="ensemble_loss") 112 | loss = tf.reduce_mean(loss, axis=-1) 113 | else: 114 | loss = tf.nn.softmax_cross_entropy_with_logits( 115 | labels=targets, 116 | logits=predictions, 117 | name="ensemble_loss") 118 | return loss 119 | 120 | def log_posterior(self, x): 121 | logp = self.unnor_logpdf(x) 122 | log_posteriors = logp - tf.reduce_logsumexp(logp, axis=2, keep_dims=True) 123 | return log_posteriors 124 | 125 | 126 | class PlaceCellEnsemble(CellEnsemble): 127 | """Calculates the dist over place cells given an absolute position.""" 128 | 129 | def __init__(self, n_cells, stdev=0.35, pos_min=-5, pos_max=5, seed=None, 130 | soft_targets=None, soft_init=None): 131 | super(PlaceCellEnsemble, self).__init__(n_cells, soft_targets, soft_init) 132 | # Create a random MoG with fixed cov over the position (Nx2) 133 | rs = np.random.RandomState(seed) 134 | self.means = rs.uniform(pos_min, pos_max, size=(self.n_cells, 2)) 135 | self.variances = np.ones_like(self.means) * stdev**2 136 | 137 | def unnor_logpdf(self, trajs): 138 | # Output the probability of each component at 
each point (BxTxN) 139 | diff = trajs[:, :, tf.newaxis, :] - self.means[np.newaxis, np.newaxis, ...] 140 | unnor_logp = -0.5 * tf.reduce_sum((diff**2) / self.variances, axis=-1) 141 | return unnor_logp 142 | 143 | 144 | class HeadDirectionCellEnsemble(CellEnsemble): 145 | """Calculates the dist over HD cells given an absolute angle.""" 146 | 147 | def __init__(self, n_cells, concentration=20, seed=None, 148 | soft_targets=None, soft_init=None): 149 | super(HeadDirectionCellEnsemble, self).__init__(n_cells, 150 | soft_targets, 151 | soft_init) 152 | # Create random von Mises tuning curves with fixed concentration over the angle 153 | rs = np.random.RandomState(seed) 154 | self.means = rs.uniform(-np.pi, np.pi, (n_cells,)) 155 | self.kappa = np.ones_like(self.means) * concentration 156 | 157 | def unnor_logpdf(self, x): 158 | return self.kappa * tf.cos(x - self.means[np.newaxis, np.newaxis, :]) 159 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Model for grid cells supervised training. 
17 | 18 | """ 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import numpy 24 | import sonnet as snt 25 | import tensorflow as tf 26 | 27 | 28 | def displaced_linear_initializer(input_size, displace, dtype=tf.float32): 29 | stddev = 1. / numpy.sqrt(input_size) 30 | return tf.truncated_normal_initializer( 31 | mean=displace*stddev, stddev=stddev, dtype=dtype) 32 | 33 | 34 | class GridCellsRNNCell(snt.RNNCore): 35 | """LSTM core implementation for the grid cell network.""" 36 | 37 | def __init__(self, 38 | target_ensembles, 39 | nh_lstm, 40 | nh_bottleneck, 41 | nh_embed=None, 42 | dropoutrates_bottleneck=None, 43 | bottleneck_weight_decay=0.0, 44 | bottleneck_has_bias=False, 45 | init_weight_disp=0.0, 46 | name="grid_cells_core"): 47 | """Constructor of the RNN cell. 48 | 49 | Args: 50 | target_ensembles: Targets, place cells and head direction cells. 51 | nh_lstm: Size of LSTM cell. 52 | nh_bottleneck: Size of the linear layer between LSTM output and output. 53 | nh_embed: Number of hiddens between input and LSTM input. 54 | dropoutrates_bottleneck: Iterable of keep rates (0,1]. The linear layer is 55 | partitioned into as many groups as the len of this parameter. 56 | bottleneck_weight_decay: Weight decay used in the bottleneck layer. 57 | bottleneck_has_bias: If the bottleneck has a bias. 58 | init_weight_disp: Displacement in the weights initialisation. 59 | name: the name of the module. 
60 | """ 61 | super(GridCellsRNNCell, self).__init__(name=name) 62 | self._target_ensembles = target_ensembles 63 | self._nh_embed = nh_embed 64 | self._nh_lstm = nh_lstm 65 | self._nh_bottleneck = nh_bottleneck 66 | self._dropoutrates_bottleneck = dropoutrates_bottleneck 67 | self._bottleneck_weight_decay = bottleneck_weight_decay 68 | self._bottleneck_has_bias = bottleneck_has_bias 69 | self._init_weight_disp = init_weight_disp 70 | self.training = False 71 | with self._enter_variable_scope(): 72 | self._lstm = snt.LSTM(self._nh_lstm) 73 | 74 | def _build(self, inputs, prev_state): 75 | """Build the module. 76 | 77 | Args: 78 | inputs: Egocentric velocity (BxN) 79 | prev_state: Previous state of the recurrent network 80 | 81 | Returns: 82 | ((predictions, bottleneck, lstm_outputs), next_state) 83 | The predictions 84 | """ 85 | conc_inputs = tf.concat(inputs, axis=1, name="conc_inputs") 86 | # Embedding layer 87 | lstm_inputs = conc_inputs 88 | # LSTM 89 | lstm_output, next_state = self._lstm(lstm_inputs, prev_state) 90 | # Bottleneck 91 | bottleneck = snt.Linear(self._nh_bottleneck, 92 | use_bias=self._bottleneck_has_bias, 93 | regularizers={ 94 | "w": tf.contrib.layers.l2_regularizer( 95 | self._bottleneck_weight_decay)}, 96 | name="bottleneck")(lstm_output) 97 | if self.training and self._dropoutrates_bottleneck is not None: 98 | tf.logging.info("Adding dropout layers") 99 | n_scales = len(self._dropoutrates_bottleneck) 100 | scale_pops = tf.split(bottleneck, n_scales, axis=1) 101 | dropped_pops = [tf.nn.dropout(pop, rate, name="dropout") 102 | for rate, pop in zip(self._dropoutrates_bottleneck, 103 | scale_pops)] 104 | bottleneck = tf.concat(dropped_pops, axis=1) 105 | # Outputs 106 | ens_outputs = [snt.Linear( 107 | ens.n_cells, 108 | regularizers={ 109 | "w": tf.contrib.layers.l2_regularizer( 110 | self._bottleneck_weight_decay)}, 111 | initializers={ 112 | "w": displaced_linear_initializer(self._nh_bottleneck, 113 | self._init_weight_disp, 114 | 
dtype=tf.float32)}, 115 | name="pc_logits")(bottleneck) 116 | for ens in self._target_ensembles] 117 | return (ens_outputs, bottleneck, lstm_output), tuple(list(next_state)) 118 | 119 | @property 120 | def state_size(self): 121 | """Returns a description of the state size, without batch dimension.""" 122 | return self._lstm.state_size 123 | 124 | @property 125 | def output_size(self): 126 | """Returns a description of the output size, without batch dimension.""" 127 | return tuple([ens.n_cells for ens in self._target_ensembles] + 128 | [self._nh_bottleneck, self._nh_lstm]) 129 | 130 | 131 | class GridCellsRNN(snt.AbstractModule): 132 | """RNN computes place and head-direction cell predictions from velocities.""" 133 | 134 | def __init__(self, rnn_cell, nh_lstm, name="grid_cell_supervised"): 135 | super(GridCellsRNN, self).__init__(name=name) 136 | self._core = rnn_cell 137 | self._nh_lstm = nh_lstm 138 | 139 | def _build(self, init_conds, vels, training=False): 140 | """Outputs place and head-direction cell predictions from velocity inputs. 141 | 142 | Args: 143 | init_conds: Initial conditions given by ensemble activations, list [BxN_i] 144 | vels: Translational and angular velocities [BxTxV] 145 | training: Activates and deactivates dropout. 146 | 147 | Returns: 148 | [logits_i]: 149 | logits_i: Logits predicting i-th ensemble activations (BxTxN_i) 150 | """ 151 | # Calculate initialization for LSTM. 
Concatenate pc and hdc activations 152 | concat_init = tf.concat(init_conds, axis=1) 153 | 154 | init_lstm_state = snt.Linear(self._nh_lstm, name="state_init")(concat_init) 155 | init_lstm_cell = snt.Linear(self._nh_lstm, name="cell_init")(concat_init) 156 | self._core.training = training 157 | 158 | # Run LSTM 159 | output_seq, final_state = tf.nn.dynamic_rnn(cell=self._core, 160 | inputs=(vels,), 161 | time_major=False, 162 | initial_state=(init_lstm_state, 163 | init_lstm_cell)) 164 | ens_targets = output_seq[:-2] 165 | bottleneck = output_seq[-2] 166 | lstm_output = output_seq[-1] 167 | # Return 168 | return (ens_targets, bottleneck, lstm_output), final_state 169 | 170 | def get_all_variables(self): 171 | return (super(GridCellsRNN, self).get_variables() 172 | + self._core.get_variables()) 173 | -------------------------------------------------------------------------------- /scores.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Grid score calculations. 
17 | 18 | """ 19 | 20 | from __future__ import absolute_import 21 | from __future__ import division 22 | from __future__ import print_function 23 | 24 | import math 25 | import matplotlib.pyplot as plt 26 | import numpy as np 27 | import scipy.ndimage 28 | import scipy.signal 29 | import scipy.stats 30 | def circle_mask(size, radius, in_val=1.0, out_val=0.0): 31 | """Returns a circular mask: in_val inside the radius, out_val outside.""" 32 | sz = [math.floor(size[0] / 2), math.floor(size[1] / 2)] 33 | x = np.linspace(-sz[0], sz[1], size[1]) 34 | x = np.expand_dims(x, 0) 35 | x = x.repeat(size[0], 0) 36 | y = np.linspace(-sz[0], sz[1], size[1]) 37 | y = np.expand_dims(y, 1) 38 | y = y.repeat(size[1], 1) 39 | z = np.sqrt(x**2 + y**2) 40 | z = np.less_equal(z, radius) 41 | vfunc = np.vectorize(lambda b: in_val if b else out_val) 42 | return vfunc(z) 43 | 44 | 45 | class GridScorer(object): 46 | """Class for scoring ratemaps given trajectories.""" 47 | 48 | def __init__(self, nbins, coords_range, mask_parameters, min_max=False): 49 | """Scoring ratemaps given trajectories. 50 | 51 | Args: 52 | nbins: Number of bins per dimension in the ratemap. 53 | coords_range: Environment coordinates range. 54 | mask_parameters: Iterable of (min, max) radius pairs for the annular masks 55 | used to analyse the angular autocorrelation of the 2D autocorrelogram. 56 | min_max: If True, use the min/max variant of the 60-degree grid score. 
57 | """ 58 | self._nbins = nbins 59 | self._min_max = min_max 60 | self._coords_range = coords_range 61 | self._corr_angles = [30, 45, 60, 90, 120, 135, 150] 62 | # Create all masks 63 | self._masks = [(self._get_ring_mask(mask_min, mask_max), (mask_min, 64 | mask_max)) 65 | for mask_min, mask_max in mask_parameters] 66 | # Mask for hiding the parts of the SAC that are never used 67 | self._plotting_sac_mask = circle_mask( 68 | [self._nbins * 2 - 1, self._nbins * 2 - 1], 69 | self._nbins, 70 | in_val=1.0, 71 | out_val=np.nan) 72 | 73 | def calculate_ratemap(self, xs, ys, activations, statistic='mean'): 74 | return scipy.stats.binned_statistic_2d( 75 | xs, 76 | ys, 77 | activations, 78 | bins=self._nbins, 79 | statistic=statistic, 80 | range=self._coords_range)[0] 81 | 82 | def _get_ring_mask(self, mask_min, mask_max): 83 | n_points = [self._nbins * 2 - 1, self._nbins * 2 - 1] 84 | return (circle_mask(n_points, mask_max * self._nbins) * 85 | (1 - circle_mask(n_points, mask_min * self._nbins))) 86 | 87 | def grid_score_60(self, corr): 88 | if self._min_max: 89 | return np.minimum(corr[60], corr[120]) - np.maximum( 90 | corr[30], np.maximum(corr[90], corr[150])) 91 | else: 92 | return (corr[60] + corr[120]) / 2 - (corr[30] + corr[90] + corr[150]) / 3 93 | 94 | def grid_score_90(self, corr): 95 | return corr[90] - (corr[45] + corr[135]) / 2 96 | 97 | def calculate_sac(self, seq1): 98 | """Calculating spatial autocorrelogram.""" 99 | seq2 = seq1 100 | 101 | def filter2(b, x): 102 | stencil = np.rot90(b, 2) 103 | return scipy.signal.convolve2d(x, stencil, mode='full') 104 | 105 | seq1 = np.nan_to_num(seq1) 106 | seq2 = np.nan_to_num(seq2) 107 | 108 | ones_seq1 = np.ones(seq1.shape) 109 | ones_seq1[np.isnan(seq1)] = 0 110 | ones_seq2 = np.ones(seq2.shape) 111 | ones_seq2[np.isnan(seq2)] = 0 112 | 113 | seq1[np.isnan(seq1)] = 0 114 | seq2[np.isnan(seq2)] = 0 115 | 116 | seq1_sq = np.square(seq1) 117 | seq2_sq = np.square(seq2) 118 | 119 | seq1_x_seq2 = filter2(seq1, seq2) 
120 | sum_seq1 = filter2(seq1, ones_seq2) 121 | sum_seq2 = filter2(ones_seq1, seq2) 122 | sum_seq1_sq = filter2(seq1_sq, ones_seq2) 123 | sum_seq2_sq = filter2(ones_seq1, seq2_sq) 124 | n_bins = filter2(ones_seq1, ones_seq2) 125 | n_bins_sq = np.square(n_bins) 126 | 127 | std_seq1 = np.power( 128 | np.subtract( 129 | np.divide(sum_seq1_sq, n_bins), 130 | (np.divide(np.square(sum_seq1), n_bins_sq))), 0.5) 131 | std_seq2 = np.power( 132 | np.subtract( 133 | np.divide(sum_seq2_sq, n_bins), 134 | (np.divide(np.square(sum_seq2), n_bins_sq))), 0.5) 135 | covar = np.subtract( 136 | np.divide(seq1_x_seq2, n_bins), 137 | np.divide(np.multiply(sum_seq1, sum_seq2), n_bins_sq)) 138 | x_coef = np.divide(covar, np.multiply(std_seq1, std_seq2)) 139 | x_coef = np.real(x_coef) 140 | x_coef = np.nan_to_num(x_coef) 141 | return x_coef 142 | 143 | def rotated_sacs(self, sac, angles): 144 | return [ 145 | scipy.ndimage.interpolation.rotate(sac, angle, reshape=False) 146 | for angle in angles 147 | ] 148 | 149 | def get_grid_scores_for_mask(self, sac, rotated_sacs, mask): 150 | """Calculate Pearson correlations of area inside mask at corr_angles.""" 151 | masked_sac = sac * mask 152 | ring_area = np.sum(mask) 153 | # Calculate dc on the ring area 154 | masked_sac_mean = np.sum(masked_sac) / ring_area 155 | # Center the sac values inside the ring 156 | masked_sac_centered = (masked_sac - masked_sac_mean) * mask 157 | variance = np.sum(masked_sac_centered**2) / ring_area + 1e-5 158 | corrs = dict() 159 | for angle, rotated_sac in zip(self._corr_angles, rotated_sacs): 160 | masked_rotated_sac = (rotated_sac - masked_sac_mean) * mask 161 | cross_prod = np.sum(masked_sac_centered * masked_rotated_sac) / ring_area 162 | corrs[angle] = cross_prod / variance 163 | return self.grid_score_60(corrs), self.grid_score_90(corrs), variance 164 | 165 | def get_scores(self, rate_map): 166 | """Get summary of scores for grid cells.""" 167 | sac = self.calculate_sac(rate_map) 168 | rotated_sacs = 
self.rotated_sacs(sac, self._corr_angles) 169 | 170 | scores = [ 171 | self.get_grid_scores_for_mask(sac, rotated_sacs, mask) 172 | for mask, mask_params in self._masks # pylint: disable=unused-variable 173 | ] 174 | scores_60, scores_90, variances = map(np.asarray, zip(*scores)) # pylint: disable=unused-variable 175 | max_60_ind = np.argmax(scores_60) 176 | max_90_ind = np.argmax(scores_90) 177 | 178 | return (scores_60[max_60_ind], scores_90[max_90_ind], 179 | self._masks[max_60_ind][1], self._masks[max_90_ind][1], sac) 180 | 181 | def plot_ratemap(self, ratemap, ax=None, title=None, *args, **kwargs): # pylint: disable=keyword-arg-before-vararg 182 | """Plot ratemaps.""" 183 | if ax is None: 184 | ax = plt.gca() 185 | # Plot the ratemap 186 | ax.imshow(ratemap, interpolation='none', *args, **kwargs) 187 | # ax.pcolormesh(ratemap, *args, **kwargs) 188 | ax.axis('off') 189 | if title is not None: 190 | ax.set_title(title) 191 | 192 | def plot_sac(self, 193 | sac, 194 | mask_params=None, 195 | ax=None, 196 | title=None, 197 | *args, 198 | **kwargs): # pylint: disable=keyword-arg-before-vararg 199 | """Plot spatial autocorrelogram.""" 200 | if ax is None: 201 | ax = plt.gca() 202 | # Plot the sac 203 | useful_sac = sac * self._plotting_sac_mask 204 | ax.imshow(useful_sac, interpolation='none', *args, **kwargs) 205 | # ax.pcolormesh(useful_sac, *args, **kwargs) 206 | # Plot a ring for the adequate mask 207 | if mask_params is not None: 208 | center = self._nbins - 1 209 | ax.add_artist( 210 | plt.Circle( 211 | (center, center), 212 | mask_params[0] * self._nbins, 213 | # lw=bump_size, 214 | fill=False, 215 | edgecolor='k')) 216 | ax.add_artist( 217 | plt.Circle( 218 | (center, center), 219 | mask_params[1] * self._nbins, 220 | # lw=bump_size, 221 | fill=False, 222 | edgecolor='k')) 223 | ax.axis('off') 224 | if title is not None: 225 | ax.set_title(title) 226 | -------------------------------------------------------------------------------- /train.py: 
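For intuition, the ring-mask correlation procedure above (`get_grid_scores_for_mask` followed by `grid_score_60`) can be sketched outside TensorFlow with plain NumPy/SciPy. This is an illustrative re-implementation, not the module's API: `ring_pearson`, `hex_grid_score`, and the synthetic hexagonal pattern are assumptions for the demo, and the sketch operates directly on a SAC-like array rather than on a real ratemap's autocorrelogram.

```python
import numpy as np
import scipy.ndimage


def ring_pearson(sac, rotated, mask):
    """Correlation of a SAC with a rotated copy, restricted to a ring mask."""
    area = mask.sum()
    mean = (sac * mask).sum() / area
    centered = (sac - mean) * mask
    rotated_centered = (rotated - mean) * mask
    variance = (centered ** 2).sum() / area + 1e-5
    return (centered * rotated_centered).sum() / area / variance


def hex_grid_score(sac, mask):
    """60-degree grid score: on-period correlations minus off-period ones."""
    corrs = {
        angle: ring_pearson(
            sac, scipy.ndimage.rotate(sac, angle, reshape=False), mask)
        for angle in (30, 60, 90, 120, 150)
    }
    return (corrs[60] + corrs[120]) / 2 - (corrs[30] + corrs[90] + corrs[150]) / 3


# Synthetic 39x39 "SAC": three plane waves 60 degrees apart form a
# hexagonal pattern, which should receive a high 60-degree score.
ys, xs = np.mgrid[-19:20, -19:20].astype(float)
k = 2 * np.pi / 8.0
hex_pattern = sum(np.cos(k * (np.cos(t) * xs + np.sin(t) * ys))
                  for t in np.deg2rad([0, 60, 120]))
ring = ((np.hypot(xs, ys) >= 4) & (np.hypot(xs, ys) <= 18)).astype(float)
print(hex_grid_score(hex_pattern, ring))
```

The averaged form above matches the default branch of `grid_score_60`; with `min_max=True` the scorer instead takes the min over the on-period angles minus the max over the off-period ones.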
-------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Supervised training for the Grid cell network.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import matplotlib 23 | import numpy as np 24 | import tensorflow as tf 25 | import Tkinter # pylint: disable=unused-import 26 | 27 | matplotlib.use('Agg') 28 | 29 | import dataset_reader # pylint: disable=g-bad-import-order, g-import-not-at-top 30 | import model # pylint: disable=g-bad-import-order 31 | import scores # pylint: disable=g-bad-import-order 32 | import utils # pylint: disable=g-bad-import-order 33 | 34 | 35 | # Task config 36 | tf.flags.DEFINE_string('task_dataset_info', 'square_room', 37 | 'Name of the room in which the experiment is performed.') 38 | tf.flags.DEFINE_string('task_root', 39 | None, 40 | 'Dataset path.') 41 | tf.flags.DEFINE_float('task_env_size', 2.2, 42 | 'Environment size (meters).') 43 | tf.flags.DEFINE_list('task_n_pc', [256], 44 | 'Number of target place cells.') 45 | tf.flags.DEFINE_list('task_pc_scale', [0.01], 46 | 'Place cell standard deviation parameter (meters).') 47 | tf.flags.DEFINE_list('task_n_hdc', [12], 48 | 'Number of target head direction 
cells.') 49 | tf.flags.DEFINE_list('task_hdc_concentration', [20.], 50 | 'Head direction concentration parameter.') 51 | tf.flags.DEFINE_integer('task_neurons_seed', 8341, 52 | 'Seed for generating the target cell populations.') 53 | tf.flags.DEFINE_string('task_targets_type', 'softmax', 54 | 'Type of target, soft or hard.') 55 | tf.flags.DEFINE_string('task_lstm_init_type', 'softmax', 56 | 'Type of LSTM initialisation, soft or hard.') 57 | tf.flags.DEFINE_bool('task_velocity_inputs', True, 58 | 'Input velocity.') 59 | tf.flags.DEFINE_list('task_velocity_noise', [0.0, 0.0, 0.0], 60 | 'Add noise to velocity.') 61 | 62 | # Model config 63 | tf.flags.DEFINE_integer('model_nh_lstm', 128, 'Number of hidden units in LSTM.') 64 | tf.flags.DEFINE_integer('model_nh_bottleneck', 256, 65 | 'Number of hidden units in linear bottleneck.') 66 | tf.flags.DEFINE_list('model_dropout_rates', [0.5], 67 | 'List of floats with dropout keep rates.') 68 | tf.flags.DEFINE_float('model_weight_decay', 1e-5, 69 | 'Weight decay regularisation.') 70 | tf.flags.DEFINE_bool('model_bottleneck_has_bias', False, 71 | 'Whether to include a bias in the linear bottleneck.') 72 | tf.flags.DEFINE_float('model_init_weight_disp', 0.0, 73 | 'Initial weight displacement.') 74 | 75 | # Training config 76 | tf.flags.DEFINE_integer('training_epochs', 1000, 'Number of training epochs.') 77 | tf.flags.DEFINE_integer('training_steps_per_epoch', 1000, 78 | 'Number of optimization steps per epoch.') 79 | tf.flags.DEFINE_integer('training_minibatch_size', 10, 80 | 'Size of the training minibatch.') 81 | tf.flags.DEFINE_integer('training_evaluation_minibatch_size', 4000, 82 | 'Size of the minibatch during evaluation.') 83 | tf.flags.DEFINE_string('training_clipping_function', 'utils.clip_all_gradients', 84 | 'Function for gradient clipping.') 85 | tf.flags.DEFINE_float('training_clipping', 1e-5, 86 | 'The absolute value to clip by.') 87 | 88 | tf.flags.DEFINE_string('training_optimizer_class', 'tf.train.RMSPropOptimizer', 89 | 'The optimizer used for training.') 90 | 
tf.flags.DEFINE_string('training_optimizer_options', 91 | '{"learning_rate": 1e-5, "momentum": 0.9}', 92 | 'Defines a dict with opts passed to the optimizer.') 93 | 94 | # Store 95 | tf.flags.DEFINE_string('saver_results_directory', 96 | None, 97 | 'Path to directory for saving results.') 98 | tf.flags.DEFINE_integer('saver_eval_time', 2, 99 | 'Frequency at which results are saved.') 100 | 101 | # Require flags 102 | tf.flags.mark_flag_as_required('task_root') 103 | tf.flags.mark_flag_as_required('saver_results_directory') 104 | FLAGS = tf.flags.FLAGS 105 | 106 | 107 | def train(): 108 | """Training loop.""" 109 | 110 | tf.reset_default_graph() 111 | 112 | # Create the motion models for training and evaluation 113 | data_reader = dataset_reader.DataReader( 114 | FLAGS.task_dataset_info, root=FLAGS.task_root, num_threads=4) 115 | train_traj = data_reader.read(batch_size=FLAGS.training_minibatch_size) 116 | 117 | # Create the ensembles that provide targets during training 118 | place_cell_ensembles = utils.get_place_cell_ensembles( 119 | env_size=FLAGS.task_env_size, 120 | neurons_seed=FLAGS.task_neurons_seed, 121 | targets_type=FLAGS.task_targets_type, 122 | lstm_init_type=FLAGS.task_lstm_init_type, 123 | n_pc=FLAGS.task_n_pc, 124 | pc_scale=FLAGS.task_pc_scale) 125 | 126 | head_direction_ensembles = utils.get_head_direction_ensembles( 127 | neurons_seed=FLAGS.task_neurons_seed, 128 | targets_type=FLAGS.task_targets_type, 129 | lstm_init_type=FLAGS.task_lstm_init_type, 130 | n_hdc=FLAGS.task_n_hdc, 131 | hdc_concentration=FLAGS.task_hdc_concentration) 132 | target_ensembles = place_cell_ensembles + head_direction_ensembles 133 | 134 | # Model creation 135 | rnn_core = model.GridCellsRNNCell( 136 | target_ensembles=target_ensembles, 137 | nh_lstm=FLAGS.model_nh_lstm, 138 | nh_bottleneck=FLAGS.model_nh_bottleneck, 139 | dropoutrates_bottleneck=np.array(FLAGS.model_dropout_rates), 140 | bottleneck_weight_decay=FLAGS.model_weight_decay, 141 | 
bottleneck_has_bias=FLAGS.model_bottleneck_has_bias, 142 | init_weight_disp=FLAGS.model_init_weight_disp) 143 | rnn = model.GridCellsRNN(rnn_core, FLAGS.model_nh_lstm) 144 | 145 | # Get a trajectory batch 146 | input_tensors = [] 147 | init_pos, init_hd, ego_vel, target_pos, target_hd = train_traj 148 | if FLAGS.task_velocity_inputs: 149 | # Add the required amount of noise to the velocities 150 | vel_noise = tf.distributions.Normal(0.0, 1.0).sample( 151 | sample_shape=ego_vel.get_shape()) * FLAGS.task_velocity_noise 152 | input_tensors = [ego_vel + vel_noise] + input_tensors 153 | # Concatenate all inputs 154 | inputs = tf.concat(input_tensors, axis=2) 155 | 156 | # Replace euclidean positions and angles by encoding of place and hd ensembles 157 | # Note that the initial_conds will be zeros if the ensembles were configured 158 | # to provide that type of initialization 159 | initial_conds = utils.encode_initial_conditions( 160 | init_pos, init_hd, place_cell_ensembles, head_direction_ensembles) 161 | 162 | # Encode targets as well 163 | ensembles_targets = utils.encode_targets( 164 | target_pos, target_hd, place_cell_ensembles, head_direction_ensembles) 165 | 166 | # Estimate future encodings of place and hd ensembles from egocentric vels 167 | outputs, _ = rnn(initial_conds, inputs, training=True) 168 | ensembles_logits, bottleneck, lstm_output = outputs 169 | 170 | # Training loss 171 | pc_loss = tf.nn.softmax_cross_entropy_with_logits_v2( 172 | labels=ensembles_targets[0], logits=ensembles_logits[0], name='pc_loss') 173 | hd_loss = tf.nn.softmax_cross_entropy_with_logits_v2( 174 | labels=ensembles_targets[1], logits=ensembles_logits[1], name='hd_loss') 175 | total_loss = pc_loss + hd_loss 176 | train_loss = tf.reduce_mean(total_loss, name='train_loss') 177 | 178 | # Optimisation ops 179 | optimizer_class = eval(FLAGS.training_optimizer_class) # pylint: disable=eval-used 180 | optimizer = optimizer_class(**eval(FLAGS.training_optimizer_options)) # pylint: 
disable=eval-used 181 | grad = optimizer.compute_gradients(train_loss) 182 | clip_gradient = eval(FLAGS.training_clipping_function) # pylint: disable=eval-used 183 | clipped_grad = [ 184 | clip_gradient(g, var, FLAGS.training_clipping) for g, var in grad 185 | ] 186 | train_op = optimizer.apply_gradients(clipped_grad) 187 | 188 | # Store the grid scores 189 | grid_scores = dict() 190 | grid_scores['btln_60'] = np.zeros((FLAGS.model_nh_bottleneck,)) 191 | grid_scores['btln_90'] = np.zeros((FLAGS.model_nh_bottleneck,)) 192 | grid_scores['btln_60_separation'] = np.zeros((FLAGS.model_nh_bottleneck,)) 193 | grid_scores['btln_90_separation'] = np.zeros((FLAGS.model_nh_bottleneck,)) 194 | grid_scores['lstm_60'] = np.zeros((FLAGS.model_nh_lstm,)) 195 | grid_scores['lstm_90'] = np.zeros((FLAGS.model_nh_lstm,)) 196 | 197 | # Create scorer objects 198 | starts = [0.2] * 10 199 | ends = np.linspace(0.4, 1.0, num=10) 200 | masks_parameters = zip(starts, ends.tolist()) 201 | latest_epoch_scorer = scores.GridScorer(20, data_reader.get_coord_range(), 202 | masks_parameters) 203 | 204 | with tf.train.SingularMonitoredSession() as sess: 205 | for epoch in range(FLAGS.training_epochs): 206 | loss_acc = list() 207 | for _ in range(FLAGS.training_steps_per_epoch): 208 | res = sess.run({'train_op': train_op, 'total_loss': train_loss}) 209 | loss_acc.append(res['total_loss']) 210 | 211 | tf.logging.info('Epoch %i, mean loss %.5f, std loss %.5f', epoch, 212 | np.mean(loss_acc), np.std(loss_acc)) 213 | if epoch % FLAGS.saver_eval_time == 0: 214 | res = dict() 215 | for _ in xrange(FLAGS.training_evaluation_minibatch_size // 216 | FLAGS.training_minibatch_size): 217 | mb_res = sess.run({ 218 | 'bottleneck': bottleneck, 219 | 'lstm': lstm_output, 220 | 'pos_xy': target_pos 221 | }) 222 | res = utils.concat_dict(res, mb_res) 223 | 224 | # Store at the end of validation 225 | filename = 'rates_and_sac_latest_hd.pdf' 226 | grid_scores['btln_60'], grid_scores['btln_90'], grid_scores[ 227 | 
'btln_60_separation'], grid_scores[ 228 | 'btln_90_separation'] = utils.get_scores_and_plot( 229 | latest_epoch_scorer, res['pos_xy'], res['bottleneck'], 230 | FLAGS.saver_results_directory, filename) 231 | 232 | 233 | def main(unused_argv): 234 | tf.logging.set_verbosity(3) # Print INFO log messages. 235 | train() 236 | 237 | if __name__ == '__main__': 238 | tf.app.run() 239 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Helper functions for creating the training graph and plotting. 
17 | 18 | """ 19 | 20 | from __future__ import absolute_import 21 | from __future__ import division 22 | from __future__ import print_function 23 | 24 | import os 25 | from matplotlib.backends.backend_pdf import PdfPages 26 | import matplotlib.pyplot as plt 27 | import numpy as np 28 | import tensorflow as tf 29 | 30 | import ensembles # pylint: disable=g-bad-import-order 31 | 32 | 33 | np.seterr(invalid="ignore") 34 | 35 | 36 | def get_place_cell_ensembles( 37 | env_size, neurons_seed, targets_type, lstm_init_type, n_pc, pc_scale): 38 | """Create the ensembles for the Place cells.""" 39 | place_cell_ensembles = [ 40 | ensembles.PlaceCellEnsemble( 41 | n, 42 | stdev=s, 43 | pos_min=-env_size / 2.0, 44 | pos_max=env_size / 2.0, 45 | seed=neurons_seed, 46 | soft_targets=targets_type, 47 | soft_init=lstm_init_type) 48 | for n, s in zip(n_pc, pc_scale) 49 | ] 50 | return place_cell_ensembles 51 | 52 | 53 | def get_head_direction_ensembles( 54 | neurons_seed, targets_type, lstm_init_type, n_hdc, hdc_concentration): 55 | """Create the ensembles for the Head direction cells.""" 56 | head_direction_ensembles = [ 57 | ensembles.HeadDirectionCellEnsemble( 58 | n, 59 | concentration=con, 60 | seed=neurons_seed, 61 | soft_targets=targets_type, 62 | soft_init=lstm_init_type) 63 | for n, con in zip(n_hdc, hdc_concentration) 64 | ] 65 | return head_direction_ensembles 66 | 67 | 68 | def encode_initial_conditions(init_pos, init_hd, place_cell_ensembles, 69 | head_direction_ensembles): 70 | initial_conds = [] 71 | for ens in place_cell_ensembles: 72 | initial_conds.append( 73 | tf.squeeze(ens.get_init(init_pos[:, tf.newaxis, :]), axis=1)) 74 | for ens in head_direction_ensembles: 75 | initial_conds.append( 76 | tf.squeeze(ens.get_init(init_hd[:, tf.newaxis, :]), axis=1)) 77 | return initial_conds 78 | 79 | 80 | def encode_targets(target_pos, target_hd, place_cell_ensembles, 81 | head_direction_ensembles): 82 | ensembles_targets = [] 83 | for ens in place_cell_ensembles: 84 | 
ensembles_targets.append(ens.get_targets(target_pos)) 85 | for ens in head_direction_ensembles: 86 | ensembles_targets.append(ens.get_targets(target_hd)) 87 | return ensembles_targets 88 | 89 | 90 | def clip_all_gradients(g, var, limit): 91 | # print(var.name) 92 | return (tf.clip_by_value(g, -limit, limit), var) 93 | 94 | 95 | def clip_bottleneck_gradient(g, var, limit): 96 | if ("bottleneck" in var.name or "pc_logits" in var.name): 97 | return (tf.clip_by_value(g, -limit, limit), var) 98 | else: 99 | return (g, var) 100 | 101 | 102 | def no_clipping(g, var): 103 | return (g, var) 104 | 105 | 106 | def concat_dict(acc, new_data): 107 | """Dictionary concatenation function.""" 108 | 109 | def to_array(kk): 110 | if isinstance(kk, np.ndarray): 111 | return kk 112 | else: 113 | return np.asarray([kk]) 114 | 115 | for k, v in new_data.iteritems(): 116 | if isinstance(v, dict): 117 | if k in acc: 118 | acc[k] = concat_dict(acc[k], v) 119 | else: 120 | acc[k] = concat_dict(dict(), v) 121 | else: 122 | v = to_array(v) 123 | if k in acc: 124 | acc[k] = np.concatenate([acc[k], v]) 125 | else: 126 | acc[k] = np.copy(v) 127 | return acc 128 | 129 | 130 | def get_scores_and_plot(scorer, 131 | data_abs_xy, 132 | activations, 133 | directory, 134 | filename, 135 | plot_graphs=True, # pylint: disable=unused-argument 136 | nbins=20, # pylint: disable=unused-argument 137 | cm="jet", 138 | sort_by_score_60=True): 139 | """Plotting function.""" 140 | 141 | # Concatenate all trajectories 142 | xy = data_abs_xy.reshape(-1, data_abs_xy.shape[-1]) 143 | act = activations.reshape(-1, activations.shape[-1]) 144 | n_units = act.shape[1] 145 | # Get the rate-map for each unit 146 | s = [ 147 | scorer.calculate_ratemap(xy[:, 0], xy[:, 1], act[:, i]) 148 | for i in xrange(n_units) 149 | ] 150 | # Get the scores 151 | score_60, score_90, max_60_mask, max_90_mask, sac = zip( 152 | *[scorer.get_scores(rate_map) for rate_map in s]) 153 | # Separations 154 | # separations = map(np.mean, 
max_60_mask) 155 | # Sort by score if desired 156 | if sort_by_score_60: 157 | ordering = np.argsort(-np.array(score_60)) 158 | else: 159 | ordering = range(n_units) 160 | # Plot 161 | cols = 16 162 | rows = int(np.ceil(n_units / cols)) 163 | fig = plt.figure(figsize=(24, rows * 4)) 164 | for i in xrange(n_units): 165 | rf = plt.subplot(rows * 2, cols, i + 1) 166 | acr = plt.subplot(rows * 2, cols, n_units + i + 1) 167 | if i < n_units: 168 | index = ordering[i] 169 | title = "%d (%.2f)" % (index, score_60[index]) 170 | # Plot the activation maps 171 | scorer.plot_ratemap(s[index], ax=rf, title=title, cmap=cm) 172 | # Plot the autocorrelation of the activation maps 173 | scorer.plot_sac( 174 | sac[index], 175 | mask_params=max_60_mask[index], 176 | ax=acr, 177 | title=title, 178 | cmap=cm) 179 | # Save 180 | if not os.path.exists(directory): 181 | os.makedirs(directory) 182 | with PdfPages(os.path.join(directory, filename)) as f: 183 | plt.savefig(f, format="pdf") 184 | plt.close(fig) 185 | return (np.asarray(score_60), np.asarray(score_90), 186 | np.asarray(map(np.mean, max_60_mask)), 187 | np.asarray(map(np.mean, max_90_mask))) 188 | --------------------------------------------------------------------------------
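`concat_dict` in utils.py accumulates nested dictionaries of NumPy values across evaluation minibatches, but relies on Python 2's `dict.iteritems`. The same accumulation logic in Python 3 syntax reads as follows (a sketch; the name `concat_dict_py3` is illustrative, not part of the repo):

```python
import numpy as np


def concat_dict_py3(acc, new_data):
    """Recursively concatenate (possibly nested) dicts of arrays or scalars."""
    for k, v in new_data.items():
        if isinstance(v, dict):
            # Recurse into nested dictionaries, creating them on first sight.
            acc[k] = concat_dict_py3(acc.get(k, {}), v)
        else:
            v = v if isinstance(v, np.ndarray) else np.asarray([v])
            acc[k] = np.concatenate([acc[k], v]) if k in acc else np.copy(v)
    return acc


# Accumulate two evaluation "minibatches", as the loop in train.py does.
acc = {}
acc = concat_dict_py3(acc, {'bottleneck': np.ones((2, 4)),
                            'stats': {'loss': 0.5}})
acc = concat_dict_py3(acc, {'bottleneck': np.zeros((3, 4)),
                            'stats': {'loss': 0.4}})
print(acc['bottleneck'].shape)  # batches stack along the first axis: (5, 4)
```

Scalars are wrapped into length-1 arrays before concatenation, so per-step scalars such as losses accumulate into a 1-D history array.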