├── .gitignore ├── LICENSE ├── README.md ├── consumers ├── consumer.py ├── control.py ├── eval_consumer.py ├── generator_consumer.py ├── imitation_loss.py ├── margin_loss.py └── task_embedding.py ├── data ├── data_sequencer.py ├── dataset.py ├── generator.py ├── mil_sim_push.py ├── mil_sim_reach.py └── utils.py ├── datasets └── README.md ├── evaluation ├── __init__.py ├── eval.py ├── eval_mil_push.py └── eval_mil_reach.py ├── main_il.py ├── networks ├── cnn.py ├── input_output.py ├── save_load.py └── utils.py ├── readme_images └── tecnets.png ├── scripts └── mil_to_tecnet.py ├── tecnets_corl_results.sh ├── test ├── consumers │ ├── test_control.py │ ├── test_imitation_loss.py │ ├── test_margin_loss.py │ └── test_task_embedding.py ├── data │ ├── test_data_sequencer.py │ ├── test_datasets.py │ └── test_generator.py ├── evaluation │ └── test_evalulation.py ├── integration │ └── test_integration.py ├── networks │ └── test_networks.py └── test_data │ └── test_task │ ├── 0.gif │ ├── 1.gif │ ├── 2.gif │ └── 3.gif └── trainers ├── il_trainer.py ├── pipeline.py └── summary_writer.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | .DS_STORE 7 | 8 | datasets/* 9 | !datasets/README.md 10 | 11 | logs/ 12 | 13 | .idea/ 14 | 15 | *.json 16 | 17 | # C extensions 18 | *.so 19 | 20 | # Distribution / packaging 21 | .Python 22 | env/ 23 | build/ 24 | develop-eggs/ 25 | dist/ 26 | downloads/ 27 | eggs/ 28 | .eggs/ 29 | lib/ 30 | lib64/ 31 | parts/ 32 | sdist/ 33 | var/ 34 | *.egg-info/ 35 | .installed.cfg 36 | *.egg 37 | 38 | # PyInstaller 39 | # Usually these files are written by a python script from a template 40 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 41 | *.manifest 42 | *.spec 43 | 44 | # Installer logs 45 | pip-log.txt 46 | pip-delete-this-directory.txt 47 | 48 | # Unit test / coverage reports 49 | htmlcov/ 50 | .tox/ 51 | .coverage 52 | .coverage.* 53 | .cache 54 | nosetests.xml 55 | coverage.xml 56 | *,cover 57 | .hypothesis/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # IPython Notebook 81 | .ipynb_checkpoints 82 | 83 | # pyenv 84 | .python-version 85 | 86 | # celery beat schedule file 87 | celerybeat-schedule 88 | 89 | # dotenv 90 | .env 91 | 92 | # virtualenv 93 | venv/ 94 | ENV/ 95 | 96 | # Spyder project settings 97 | .spyderproject 98 | 99 | # Rope project settings 100 | .ropeproject 101 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | TASK-EMBEDDED CONTROL NETWORKS SOFTWARE 2 | 3 | LICENCE AGREEMENT 4 | 5 | WE (Imperial College of Science, Technology and Medicine, (“Imperial College London”)) 6 | ARE WILLING TO LICENSE THIS SOFTWARE TO YOU (a licensee “You”) ONLY ON THE 7 | CONDITION THAT YOU ACCEPT ALL OF THE TERMS CONTAINED IN THE 8 | FOLLOWING AGREEMENT. PLEASE READ THE AGREEMENT CAREFULLY BEFORE 9 | DOWNLOADING THE SOFTWARE. BY EXERCISING THE OPTION TO DOWNLOAD 10 | THE SOFTWARE YOU AGREE TO BE BOUND BY THE TERMS OF THE AGREEMENT. 
11 | SOFTWARE LICENCE AGREEMENT (EXCLUDING BSD COMPONENTS)
12 | 
13 | 1. This Agreement pertains to a worldwide, non-exclusive, temporary, fully paid-up, royalty
14 | free, non-transferable, non-sub-licensable licence (the “Licence”) to use the task-embedded
15 | control networks source code, including any modification, part or derivative (the “Software”).
16 | Ownership and Licence. Your rights to use and download the Software onto your computer,
17 | and all other copies that You are authorised to make, are specified in this Agreement.
18 | However, we (or our licensors) retain all rights, including but not limited to all copyright and
19 | other intellectual property rights anywhere in the world, in the Software not expressly
20 | granted to You in this Agreement.
21 | 
22 | 2. Permitted use of the Licence:
23 | 
24 | (a) You may download and install the Software onto one computer or server for use in
25 | accordance with Clause 2(b) of this Agreement provided that You ensure that the Software is
26 | not accessible by other users unless they have themselves accepted the terms of this licence
27 | agreement.
28 | 
29 | (b) You may use the Software solely for non-commercial, internal or academic research
30 | purposes and only in accordance with the terms of this Agreement. You may not use the
31 | Software for commercial purposes, including but not limited to (1) integration of all or part of
32 | the source code or the Software into a product for sale or licence by or on behalf of You to
33 | third parties or (2) use of the Software or any derivative of it for research to develop software
34 | products for sale or licence to a third party or (3) use of the Software or any derivative of it
35 | for research to develop non-software products for sale or licence to a third party, or (4) use of
36 | the Software to provide any service to an external organisation for which payment is
37 | received.
38 | 
39 | Should You wish to use the Software for commercial purposes, You shall
40 | email researchcontracts.engineering@imperial.ac.uk .
41 | 
42 | (c) Right to Copy. You may copy the Software for back-up and archival purposes, provided
43 | that each copy is kept in your possession and provided You reproduce our copyright notice
44 | (set out in Schedule 1) on each copy.
45 | 
46 | (d) Transfer and sub-licensing. You may not rent, lend, or lease the Software and You may
47 | not transmit, transfer or sub-license this licence to use the Software or any of your rights or
48 | obligations under this Agreement to another party.
49 | 
50 | (e) Identity of Licensee. The licence granted herein is personal to You. You shall not permit
51 | any third party to access, modify or otherwise use the Software nor shall You access, modify
52 | or otherwise use the Software on behalf of any third party. If You wish to obtain a licence for
53 | multiple users or a site licence for the Software please contact us
54 | at researchcontracts.engineering@imperial.ac.uk .
55 | 
56 | (f) Publications and presentations. You may make public, results or data obtained from,
57 | dependent on or arising from research carried out using the Software, provided that any such
58 | presentation or publication identifies the Software as the source of the results or the data,
59 | including the Copyright Notice given in each element of the Software, and stating that the
60 | Software has been made available for use by You under licence from Imperial College London
61 | and You provide a copy of any such publication to Imperial College London. 
62 | 
63 | 3. Prohibited Uses. You may not, without written permission from us
64 | at researchcontracts.engineering@imperial.ac.uk :
65 | 
66 | (a) Use, copy, modify, merge, or transfer copies of the Software or any documentation
67 | provided by us which relates to the Software except as provided in this Agreement;
68 | 
69 | (b) Use any back-up or archival copies of the Software (or allow anyone else to use such
70 | copies) for any purpose other than to replace the original copy in the event it is destroyed or
71 | becomes defective; or
72 | 
73 | (c) Disassemble, decompile or "unlock", reverse translate, or in any manner decode the
74 | Software for any reason.
75 | 
76 | 4. Warranty Disclaimer
77 | 
78 | (a) Disclaimer. The Software has been developed for research purposes only. You
79 | acknowledge that we are providing the Software to You under this licence agreement free of
80 | charge and on condition that the disclaimer set out below shall apply. We do not represent or
81 | warrant the Software as to: (i) the quality, accuracy or reliability of the Software; (ii) the
82 | suitability of the Software for any particular use or for use under any specific conditions; and
83 | (iii) whether use of the Software will infringe third-party rights.
84 | You acknowledge that You have reviewed and evaluated the Software to determine that it
85 | meets your needs and that You assume all responsibility and liability for determining the
86 | suitability of the Software as fit for your particular purposes and requirements. Subject to
87 | Clause 4(b), we exclude and expressly disclaim all express and implied representations,
88 | warranties, conditions and terms not stated herein (including the implied conditions or
89 | warranties of satisfactory quality, merchantable quality, merchantability and fitness for
90 | purpose).
91 | 
92 | (b) Savings. Some jurisdictions may imply warranties, conditions or terms or impose
93 | obligations upon us which cannot, in whole or in part, be excluded, restricted or modified or
94 | otherwise do not allow the exclusion of implied warranties, conditions or terms, in which
95 | case the above warranty disclaimer and exclusion will only apply to You to the extent
96 | permitted in the relevant jurisdiction and does not in any event exclude any implied
97 | warranties, conditions or terms which may not under applicable law be excluded.
98 | 
99 | (c) Imperial College London disclaims all responsibility for the use which is made of the
100 | Software and any liability for the outcomes arising from using the Software.
101 | 
102 | 5. Limitation of Liability
103 | 
104 | (a) You acknowledge that we are providing the Software to You under this licence agreement
105 | free of charge and on condition that the limitation of liability set out below shall apply.
106 | Accordingly, subject to Clause 5(b), we exclude all liability whether in contract, tort,
107 | negligence or otherwise, in respect of the Software and/or any related documentation
108 | provided to You by us including, but not limited to, liability for loss or corruption of data,
109 | loss of contracts, loss of income, loss of profits, loss of cover and any consequential or indirect
110 | loss or damage of any kind arising out of or in connection with this licence agreement,
111 | however caused. This exclusion shall apply even if we have been advised of the possibility of
112 | such loss or damage. 
113 | 114 | (b) You agree to indemnify Imperial College London and hold it harmless from and against 115 | any and all claims, damages and liabilities asserted by third parties (including claims for 116 | negligence) which arise directly or indirectly from the use of the Software or any derivative 117 | of it or the sale of any products based on the Software. You undertake to make no liability 118 | claim against any employee, student, agent or appointee of Imperial College London, in 119 | connection with this Licence or the Software. 120 | 121 | (c) Nothing in this Agreement shall have the effect of excluding or limiting our statutory 122 | liability. 123 | 124 | (d) Some jurisdictions do not allow these limitations or exclusions either wholly or in part, 125 | and, to that extent, they may not apply to you. Nothing in this licence agreement will affect 126 | your statutory rights or other relevant statutory provisions which cannot be excluded, 127 | restricted or modified, and its terms and conditions must be read and construed subject to any 128 | such statutory rights and/or provisions. 129 | 130 | 6. Confidentiality. You agree not to disclose any confidential information provided to You by 131 | us pursuant to this Agreement to any third party without our prior written consent. The 132 | obligations in this Clause 6 shall survive the termination of this Agreement for any reason. 133 | 134 | 7. Termination. 135 | 136 | (a) We may terminate this licence agreement and your right to use the Software at any time 137 | with immediate effect upon written notice to You. 138 | 139 | (b) This licence agreement and your right to use the Software automatically terminate if You: 140 | (i) fail to comply with any provisions of this Agreement; or 141 | (ii) destroy the copies of the Software in your possession, or voluntarily return the Software 142 | to us. 143 | 144 | (c) Upon termination You will destroy all copies of the Software. 145 | 146 | (d) Otherwise, the restrictions on your rights to use the Software will expire 10 (ten) years 147 | after first use of the Software under this licence agreement. 148 | 149 | 8. Miscellaneous Provisions. 150 | 151 | (a) This Agreement will be governed by and construed in accordance with the substantive 152 | laws of England and Wales whose courts shall have exclusive jurisdiction over all disputes 153 | which may arise between us. 154 | 155 | (b) This is the entire agreement between us relating to the Software, and supersedes any prior 156 | purchase order, communications, advertising or representations concerning the Software. 157 | 158 | (c) No change or modification of this Agreement will be valid unless it is in writing, and is 159 | signed by us. 160 | 161 | (d) The unenforceability or invalidity of any part of this Agreement will not affect the 162 | enforceability or validity of the remaining parts. 163 | 164 | BSD Elements of the Software 165 | 166 | For BSD elements of the Software, the following terms shall apply: 167 | 168 | Copyright as indicated in the header of the individual element of the Software. 169 | 170 | All rights reserved. 171 | 172 | Redistribution and use in source and binary forms, with or without modification, are 173 | permitted provided that the following conditions are met: 174 | 175 | 1. Redistributions of source code must retain the above copyright notice, this list of 176 | conditions and the following disclaimer. 177 | 178 | 2. 
Redistributions in binary form must reproduce the above copyright notice, this list of
179 | conditions and the following disclaimer in the documentation and/or other materials
180 | provided with the distribution.
181 | 
182 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to
183 | endorse or promote products derived from this software without specific prior written
184 | permission.
185 | 
186 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
187 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
188 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
189 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
190 | HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
191 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
192 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
193 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
194 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
195 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
196 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
197 | 
198 | SCHEDULE 1
199 | 
200 | The Software
201 | 
202 | Task-Embedded Control Networks, which employ ideas from metric learning in
203 | order to create a task embedding that can be used by a robot to learn new tasks
204 | from one or more demonstrations. It is based on the techniques described in the
205 | following publication:
206 | 
207 | Stephen James, Michael Bloesch, Andrew J Davison. Task-Embedded Control Networks
208 | for Few-Shot Imitation Learning, Conference on Robot Learning (CoRL), 2018.
209 | 
210 | 
211 | If you use the software, you should reference the following paper in any publication:
212 | 
213 | Stephen James, Michael Bloesch, Andrew J Davison. Task-Embedded Control Networks
214 | for Few-Shot Imitation Learning, Conference on Robot Learning (CoRL), 2018.
215 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Task Embedded Control Networks
2 | 
3 | ![TecNets Example](readme_images/tecnets.png)
4 | 
5 | The code presented here was used in: [Task-Embedded Control Networks for Few-Shot Imitation Learning](https://arxiv.org/abs/1810.03237).
6 | 
7 | ### Running Paper Experiments
8 | 
9 | If you want to re-run the experiments presented in the paper, you will need some of the dependencies from a paper that we compare against: [One-Shot Visual Imitation Learning via Meta-Learning
10 | ](https://arxiv.org/abs/1709.04905).
11 | 
12 | Follow these steps:
13 | 
14 | 1. First clone the fork of the gym repo found [here](https://github.com/tianheyu927/gym), and switch to the _mil_ branch.
15 | 2. You can now either install this, or just add the gym fork to your PYTHONPATH.
16 | 3. Download the _mil_sim_reach_ and _mil_sim_push_ datasets from [here](https://www.doc.ic.ac.uk/~slj12/data/mil_data.zip). Unzip them to the _datasets_ folder. _Note: The data format here has been changed slightly in comparison to the original data from the MIL paper._
17 | 4. (Optional) Run the integration test to make sure everything is set up correctly (see the example command below). 
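For example, assuming `pytest` is installed (the repository does not pin a particular test runner, so this exact invocation is an assumption):

```bash
python -m pytest test/integration/test_integration.py
```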
18 | 
19 | To run the reaching task, run:
20 | ```bash
21 | ./tecnets_corl_results.sh sim_reach
22 | ```
23 | To run the pushing task, run:
24 | ```bash
25 | ./tecnets_corl_results.sh sim_push
26 | ```
27 | 
28 | ### Code Design
29 | 
30 | This section is for people who wish to extend the framework.
31 | 
32 | The code is designed in a pipelined fashion: there is a list of
33 | consumers, each of which takes in a dictionary of inputs (from the previous
34 | consumer) and then outputs a combined dictionary of the inputs and outputs
35 | of that consumer.
36 | For example:
37 | 
38 | ```python
39 | a = GeneratorConsumer(...)
40 | b = TaskEmbedding(...)
41 | c = MarginLoss(...)
42 | d = Control(...)
43 | e = ImitationLoss(...)
44 | consumers = [a, b, c, d, e]
45 | p = Pipeline(consumers)
46 | ```
47 | 
48 | This allows the TecNet to be built in a modular way. For example, if one
49 | wanted to use a prototypical loss rather than a margin loss, then one would
50 | only need to swap out the `MarginLoss` consumer.
51 | 
52 | ## Citation
53 | 
54 | ```
55 | @article{james2018task,
56 |   title={Task-Embedded Control Networks for Few-Shot Imitation Learning},
57 |   author={James, Stephen and Bloesch, Michael and Davison, Andrew J},
58 |   journal={Conference on Robot Learning (CoRL)},
59 |   year={2018}
60 | }
61 | ```
--------------------------------------------------------------------------------
/consumers/consumer.py:
--------------------------------------------------------------------------------
1 | class Consumer(object):
2 | 
3 |     def consume(self, inputs):
4 |         raise NotImplementedError('Must be overridden.')
5 | 
6 |     def get_summaries(self, prefix):
7 |         del prefix
8 |         return []
9 | 
10 |     def get_loss(self):
11 |         return 0
12 | 
13 |     def verify(self, item):
14 |         if item is None:
15 |             raise RuntimeError(
16 |                 'Attempted to get summaries or loss before calling consume.')
17 |         return item
18 | 
19 |     def get(self, dict, key):
20 |         if key not in dict:
21 |             raise KeyError("Key %s was not found in the consumer's inputs." 
% key) 22 | return dict[key] 23 | -------------------------------------------------------------------------------- /consumers/control.py: -------------------------------------------------------------------------------- 1 | from consumers.consumer import Consumer 2 | import tensorflow as tf 3 | from networks.input_output import * 4 | 5 | 6 | class Control(Consumer): 7 | 8 | def __init__(self, network, action_size, include_state=False): 9 | self.network = network 10 | self.action_size = action_size 11 | self.include_state = include_state 12 | super().__init__() 13 | 14 | def consume(self, inputs): 15 | 16 | # Shape (batch, embsize) 17 | s = self.get(inputs, 'sentences') 18 | 19 | # (batch, examples, h, w, 3) 20 | ctrnet_images = self.get(inputs, 'ctrnet_images') 21 | 22 | examples = ctrnet_images.shape[1] 23 | width = ctrnet_images.shape[-2] 24 | height = ctrnet_images.shape[-3] 25 | 26 | # (batch, 1, 1, 1, emb) 27 | s = tf.expand_dims( 28 | tf.expand_dims(tf.expand_dims(s, axis=1), axis=1), axis=1) 29 | # (batch, examples, height, width, emb) 30 | tiled = tf.tile(s, [1, examples, height, width, 1]) 31 | ctrnet_input = tf.concat([ctrnet_images, tiled], axis=-1) 32 | emb_plus_channels = ctrnet_input.shape[-1] 33 | 34 | # Squash (batch * examples, h, w, emb) 35 | ctrnet_input = tf.reshape( 36 | ctrnet_input, (-1, height, width, emb_plus_channels)) 37 | 38 | net_ins = [NetworkInput(name='ctr_images', layer_type='conv', 39 | layer_num=0, tensor=ctrnet_input)] 40 | 41 | if self.include_state: 42 | states = self.get(inputs, 'ctrnet_states') 43 | states = tf.reshape(states, (-1, states.shape[-1])) 44 | net_ins.append(NetworkInput( 45 | name='ctrnet_states', layer_type='fc', 46 | layer_num=0, tensor=states, merge_mode='concat')) 47 | 48 | net_out = NetworkHead(name='output_action', 49 | nodes=self.action_size) 50 | 51 | with tf.variable_scope('control_net', reuse=tf.AUTO_REUSE): 52 | outputs = self.network.forward(net_ins, [net_out], 53 | self.get(inputs, 'training')) 54 | 55 | inputs['output_actions'] = tf.reshape(self.get(outputs, 'output_action'), 56 | (-1, examples, self.action_size)) 57 | return inputs 58 | -------------------------------------------------------------------------------- /consumers/eval_consumer.py: -------------------------------------------------------------------------------- 1 | from consumers.consumer import Consumer 2 | import tensorflow as tf 3 | from data import utils 4 | 5 | 6 | class EvalConsumer(Consumer): 7 | 8 | def __init__(self, dataset, data_sequencer, support, disk_images=True): 9 | self.dataset = dataset 10 | self.data_sequencer = data_sequencer 11 | self.support = support 12 | self.disk_images = disk_images 13 | super().__init__() 14 | 15 | def consume(self, inputs): 16 | 17 | if self.disk_images: 18 | # (Examples,) 19 | input_image = tf.placeholder(tf.string, (self.support,)) 20 | else: 21 | # (Examples, timesteps) 22 | input_image = tf.placeholder(tf.float32, 23 | (None, None) + self.dataset.img_shape) 24 | input_states = tf.placeholder( 25 | tf.float32, 26 | (self.support, self.dataset.time_horizon, self.dataset.state_size)) 27 | input_outputs = tf.placeholder( 28 | tf.float32, 29 | (self.support, self.dataset.time_horizon, self.dataset.action_size)) 30 | 31 | # (B. 
W, H, C) 32 | input_ctr_image = tf.placeholder(tf.float32, 33 | (None, 1) + self.dataset.img_shape) 34 | input_ctr_state = tf.placeholder(tf.float32, 35 | (None, 1, self.dataset.state_size)) 36 | 37 | training = tf.placeholder_with_default(False, None) 38 | 39 | stacked_embnet_images, bs, cs = [], [], [] 40 | for i in range(self.support): 41 | embnet_images, embnet_states, embnet_outputs = ( 42 | self.data_sequencer.load( 43 | input_image[i], input_states[i], input_outputs[i])) 44 | embnet_images = utils.preprocess(embnet_images) 45 | stacked_embnet_images.append(embnet_images) 46 | bs.append(embnet_states) 47 | cs.append(embnet_outputs) 48 | 49 | embnet_images = tf.stack(stacked_embnet_images) 50 | embnet_images = tf.expand_dims(embnet_images, axis=0) # set batchsize 1 51 | 52 | embnet_states = tf.stack(bs) 53 | embnet_states = tf.expand_dims(embnet_states, axis=0) 54 | 55 | embnet_outputs = tf.stack(cs) 56 | embnet_outputs = tf.expand_dims(embnet_outputs, axis=0) 57 | 58 | embnet_images.set_shape( 59 | (None, None, self.data_sequencer.frames) + self.dataset.img_shape) 60 | embnet_states.set_shape( 61 | (None, None, self.data_sequencer.frames, self.dataset.state_size)) 62 | embnet_outputs.set_shape( 63 | (None, None, self.data_sequencer.frames, self.dataset.action_size)) 64 | 65 | return { 66 | 'embnet_images': embnet_images, 67 | 'embnet_states': embnet_states, 68 | 'embnet_outputs': embnet_outputs, 69 | 'input_image_files': input_image, 70 | 'input_states': input_states, 71 | 'input_outputs': input_outputs, 72 | 'ctrnet_images': input_ctr_image, 73 | 'ctrnet_states': input_ctr_state, 74 | 'training': training, 75 | 'support': tf.placeholder_with_default(self.support, None), 76 | 'query': tf.placeholder_with_default(0, None), 77 | } 78 | -------------------------------------------------------------------------------- /consumers/generator_consumer.py: -------------------------------------------------------------------------------- 1 | from consumers.consumer import Consumer 2 | import tensorflow as tf 3 | 4 | 5 | class GeneratorConsumer(Consumer): 6 | 7 | def __init__(self, generator, dataset, support, query): 8 | super().__init__() 9 | self.generator = generator 10 | self.dataset = dataset 11 | self.support = support 12 | self.query = query 13 | self.embnet_images = None 14 | self.ctrnet_images = None 15 | 16 | def consume(self, inputs): 17 | (embnet_images, embnet_states, embnet_outputs, ctrnet_images, 18 | ctrnet_states, ctrnet_outputs) = self.generator.next_element 19 | training = tf.placeholder(tf.bool) 20 | 21 | embnet_images.set_shape((None, None, None) + self.dataset.img_shape) 22 | embnet_states.set_shape((None, None, None, self.dataset.state_size)) 23 | embnet_outputs.set_shape((None, None, None, self.dataset.action_size)) 24 | ctrnet_images.set_shape((None, 2,) + self.dataset.img_shape) 25 | ctrnet_states.set_shape((None, 2, self.dataset.state_size)) 26 | ctrnet_outputs.set_shape((None, 2, self.dataset.action_size)) 27 | 28 | self.embnet_images = embnet_images 29 | self.ctrnet_images = ctrnet_images 30 | 31 | return { 32 | 'embnet_images': embnet_images, 33 | 'embnet_states': embnet_states, 34 | 'embnet_outputs': embnet_outputs, 35 | 'ctrnet_images': ctrnet_images, 36 | 'ctrnet_states': ctrnet_states, 37 | 'ctrnet_outputs': ctrnet_outputs, 38 | 'training': training, 39 | 'support': tf.placeholder_with_default(self.support, None), 40 | 'query': tf.placeholder_with_default(self.query, None), 41 | } 42 | 43 | def get_summaries(self, prefix): 44 | # Grab the last frame for each 
task. 45 | # We know there should be at least 2 examples. 46 | embnet_example_1 = self.verify(self.embnet_images)[:, 0, -1] 47 | embnet_example_2 = self.verify(self.embnet_images)[:, 1, -1] 48 | ctrnet_support = self.verify(self.ctrnet_images)[:, 0] 49 | ctrnet_query = self.verify(self.ctrnet_images)[:, 1] 50 | return [ 51 | tf.summary.image(prefix + '_embnet_example_1', embnet_example_1), 52 | tf.summary.image(prefix + '_embnet_example_2', embnet_example_2), 53 | tf.summary.image(prefix + '_ctrnet_support', ctrnet_support), 54 | tf.summary.image(prefix + '_ctrnet_query', ctrnet_query), 55 | ] 56 | -------------------------------------------------------------------------------- /consumers/imitation_loss.py: -------------------------------------------------------------------------------- 1 | from consumers.consumer import Consumer 2 | import tensorflow as tf 3 | from networks.input_output import * 4 | 5 | 6 | class ImitationLoss(Consumer): 7 | 8 | def __init__(self, support_lambda=1.0, query_lambda=1.0): 9 | self.support_lambda = support_lambda 10 | self.query_lambda = query_lambda 11 | self.loss_support = None 12 | self.loss_query = None 13 | 14 | def consume(self, inputs): 15 | 16 | # (batch, 2, actions) 17 | a = self.get(inputs, 'output_actions') 18 | labels = self.get(inputs, 'ctrnet_outputs') 19 | 20 | support_loss = tf.losses.mean_squared_error(a[:, 0], labels[:, 0]) 21 | query_loss = tf.losses.mean_squared_error(a[:, 1], labels[:, 1]) 22 | 23 | self.loss_support = self.support_lambda * support_loss 24 | self.loss_query = self.query_lambda * query_loss 25 | inputs['loss_support'] = self.loss_support 26 | inputs['loss_query'] = self.loss_query 27 | return inputs 28 | 29 | def get_summaries(self, prefix): 30 | return [ 31 | tf.summary.scalar( 32 | prefix + '_support_loss', self.verify(self.loss_support)), 33 | tf.summary.scalar( 34 | prefix + '_query_loss', self.verify(self.loss_query)) 35 | ] 36 | 37 | def get_loss(self): 38 | return self.verify(self.loss_support) + self.verify(self.loss_query) 39 | -------------------------------------------------------------------------------- /consumers/margin_loss.py: -------------------------------------------------------------------------------- 1 | from consumers.consumer import Consumer 2 | import tensorflow as tf 3 | 4 | 5 | class MarginLoss(Consumer): 6 | 7 | def __init__(self, margin, loss_lambda=1.0): 8 | self.margin = margin 9 | self.loss_lambda = loss_lambda 10 | self.loss_embedding = None 11 | self.embedding_accuracy = None 12 | 13 | def _norm(self, vecs, axis=1): 14 | mag = tf.sqrt(tf.reduce_sum(tf.square(vecs), axis=axis, keep_dims=True)) 15 | return vecs / tf.maximum(mag, 1e-6) 16 | 17 | def consume(self, inputs): 18 | 19 | # (batch, sup, emb_size) 20 | semb = self.get(inputs, 'support_embedding') 21 | qemb = self.get(inputs, 'query_embedding') 22 | qemb_shape = tf.shape(qemb) 23 | batch_size, query_size = qemb_shape[0], qemb_shape[1] 24 | qemb = tf.reshape(qemb, (batch_size * query_size, -1)) 25 | 26 | # Shape (batch, embsize) 27 | support_sentences = self._norm( 28 | tf.reduce_mean(self._norm(semb, axis=2), axis=1), axis=1) 29 | inputs['sentences'] = support_sentences 30 | 31 | # Similarities of every sentence with every query 32 | # Shape (batch, batch * queries) 33 | similarities = tf.matmul(support_sentences, qemb, transpose_b=True) 34 | # Shape (batch, batch, queries) 35 | similarities = tf.reshape(similarities, 36 | (batch_size, batch_size, query_size)) 37 | 38 | # Gets the diagonal to give (batch, query) 39 | positives = 
tf.boolean_mask(similarities, tf.eye(batch_size)) 40 | positives_ex = tf.expand_dims(positives, axis=1) # (batch, 1, query) 41 | 42 | negatives = tf.boolean_mask(similarities, 43 | tf.equal(tf.eye(batch_size), 0)) 44 | # (batch, batch-1, query) 45 | negatives = tf.reshape(negatives, (batch_size, batch_size - 1, -1)) 46 | 47 | loss = tf.maximum(0.0, self.margin - positives_ex + negatives) 48 | loss = tf.reduce_mean(loss) 49 | 50 | self.loss_embedding = self.loss_lambda * loss 51 | 52 | # Summaries 53 | max_of_negs = tf.reduce_max(negatives, axis=1) # (batch, query) 54 | accuracy = tf.greater(positives, max_of_negs) 55 | self.embedding_accuracy = tf.reduce_mean(tf.cast(accuracy, tf.float32)) 56 | inputs['loss_embedding'] = self.loss_embedding 57 | inputs['embedding_accuracy'] = self.embedding_accuracy 58 | return inputs 59 | 60 | def get_summaries(self, prefix): 61 | return [ 62 | tf.summary.scalar(prefix + 'embedding_accuracy', 63 | self.verify(self.embedding_accuracy)), 64 | tf.summary.scalar(prefix + 'loss_embedding', 65 | self.verify(self.loss_embedding)) 66 | ] 67 | 68 | def get_loss(self): 69 | return self.verify(self.loss_embedding) 70 | -------------------------------------------------------------------------------- /consumers/task_embedding.py: -------------------------------------------------------------------------------- 1 | from consumers.consumer import Consumer 2 | from networks.input_output import * 3 | import tensorflow as tf 4 | 5 | VALID_FRAME_COLLAPSE = ['concat'] 6 | 7 | 8 | class TaskEmbedding(Consumer): 9 | 10 | def __init__(self, network, embedding_size, frame_collapse_method='concat', 11 | include_state=False, include_action=False): 12 | self.network = network 13 | self.embedding_size = embedding_size 14 | self.include_state = include_state 15 | self.include_action = include_action 16 | if frame_collapse_method not in VALID_FRAME_COLLAPSE: 17 | raise ValueError('%s is not a valid frame collapse method.' 
18 | % frame_collapse_method) 19 | self.frame_collapse_method = frame_collapse_method 20 | 21 | def _squash_input(self, tensor, shape, batch_size, support_plus_query): 22 | return tf.reshape(tensor, tf.concat( 23 | [[batch_size * support_plus_query], shape[2:]], axis=0)) 24 | 25 | def _expand_output(self, tensor, batch_size,support_plus_query): 26 | return tf.reshape(tensor, (batch_size, support_plus_query, -1)) 27 | 28 | def consume(self, inputs): 29 | 30 | # Condense the inputs from (Batch, support_query, transis, w, h, c) 31 | # to (Batch, support_query, w, h, c * transis) 32 | embed_images = self.get(inputs, 'embnet_images') 33 | support = self.get(inputs, 'support') 34 | query = self.get(inputs, 'query') 35 | 36 | if self.frame_collapse_method == 'concat': 37 | embed_images = tf.concat(tf.unstack(embed_images, axis=2), axis=-1) 38 | 39 | embed_images_shape = tf.shape(embed_images) 40 | batch_size = embed_images_shape[0] 41 | support_plus_query = embed_images_shape[1] 42 | 43 | # Sanity check 44 | assertion_op = tf.assert_equal( 45 | support_plus_query, support + query, 46 | message='Support and Query size is different than expected.') 47 | with tf.control_dependencies([assertion_op]): 48 | # Condense to shape (batch_size*(support_plus_query),w,h,c) 49 | reshaped_images = self._squash_input( 50 | embed_images, embed_images_shape, batch_size, 51 | support_plus_query) 52 | net_ins = [NetworkInput(name='embed_images', layer_type='conv', 53 | layer_num=0, tensor=reshaped_images)] 54 | 55 | if self.include_state: 56 | embnet_states = self.get(inputs, 'embnet_states') 57 | if self.frame_collapse_method == 'concat': 58 | # (Batch, support_query, State*Frames) 59 | embnet_states = tf.concat( 60 | tf.unstack(embnet_states, axis=2), axis=-1) 61 | reshaped_state = self._squash_input( 62 | embnet_states, tf.shape(embnet_states), batch_size, 63 | support_plus_query) 64 | net_ins.append(NetworkInput( 65 | name='embnet_states', layer_type='fc', 66 | layer_num=0, tensor=reshaped_state, merge_mode='concat')) 67 | 68 | if self.include_action: 69 | embnet_actions = self.get(inputs, 'embnet_outputs') 70 | if self.frame_collapse_method == 'concat': 71 | # (Batch, support_query, Actions*Frames) 72 | embnet_actions = tf.concat( 73 | tf.unstack(embnet_actions, axis=2), axis=-1) 74 | reshaped_action = self._squash_input( 75 | embnet_actions, tf.shape(embnet_actions), batch_size, 76 | support_plus_query) 77 | net_ins.append(NetworkInput( 78 | name='embnet_actions', layer_type='fc', 79 | layer_num=0, tensor=reshaped_action, merge_mode='concat')) 80 | 81 | net_out = NetworkHead(name='output_embedding', 82 | nodes=self.embedding_size) 83 | with tf.variable_scope('task_embedding_net', reuse=tf.AUTO_REUSE): 84 | outputs = self.network.forward(net_ins, [net_out], 85 | self.get(inputs, 'training')) 86 | 87 | # Convert to (Batch, support_query, emb_size) 88 | embedding = tf.reshape( 89 | self.get(outputs, 'output_embedding'), 90 | (batch_size, support_plus_query, self.embedding_size)) 91 | 92 | outputs['support_embedding'] = embedding[:, :support] 93 | outputs['query_embedding'] = embedding[:, support:] 94 | 95 | inputs.update(outputs) 96 | return inputs 97 | -------------------------------------------------------------------------------- /data/data_sequencer.py: -------------------------------------------------------------------------------- 1 | from data import utils 2 | import tensorflow as tf 3 | 4 | VALID_SEQUENCE_STRATEGIES = ['first', 'last', 'first_last', 'all'] 5 | 6 | 7 | class DataSequencer(object): 8 | 9 | 
def __init__(self, sequence_strategy, time_horizon): 10 | self.sequence_strategy = sequence_strategy 11 | self.time_horizon = time_horizon 12 | if sequence_strategy not in VALID_SEQUENCE_STRATEGIES: 13 | raise ValueError('%s is not a valid sequence embedding strategy.' 14 | % sequence_strategy) 15 | self.frames = 1 16 | if sequence_strategy == 'first_last': 17 | self.frames = 2 18 | elif sequence_strategy == 'all': 19 | self.frames = self.time_horizon 20 | 21 | def load(self, images, states, outputs): 22 | is_image_file = images.dtype == tf.string 23 | # Embedding images 24 | if self.sequence_strategy == 'first': 25 | if is_image_file: 26 | loaded_images = [utils.tf_load_image(images, 0)] 27 | else: 28 | loaded_images = [images[0]] 29 | emb_states = [states[0]] 30 | emb_outputs = [outputs[0]] 31 | elif self.sequence_strategy == 'last': 32 | if is_image_file: 33 | loaded_images = [utils.tf_load_image(images, 34 | self.time_horizon - 1)] 35 | else: 36 | loaded_images = [images[self.time_horizon - 1]] 37 | emb_states = [states[-1]] 38 | emb_outputs = [outputs[-1]] 39 | elif self.sequence_strategy == 'first_last': 40 | if is_image_file: 41 | loaded_images = [utils.tf_load_image(images, 0), 42 | utils.tf_load_image(images, 43 | self.time_horizon - 1)] 44 | else: 45 | loaded_images = [images[0], images[self.time_horizon - 1]] 46 | emb_states = [states[0], states[-1]] 47 | emb_outputs = [outputs[0], outputs[-1]] 48 | elif self.sequence_strategy == 'all': 49 | if is_image_file: 50 | loaded_images = [utils.tf_load_image(images, t) 51 | for t in range(self.time_horizon)] 52 | else: 53 | loaded_images = images 54 | emb_states = [states[t] 55 | for t in range(self.time_horizon)] 56 | emb_outputs = [outputs[t] 57 | for t in range(self.time_horizon)] 58 | else: 59 | raise ValueError( 60 | '%s is not a valid sequence embedding strategy.' 
61 | % self.sequence_strategy) 62 | return loaded_images, emb_states, emb_outputs 63 | -------------------------------------------------------------------------------- /data/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from natsort import natsorted 4 | 5 | 6 | class Dataset(object): 7 | 8 | def __init__(self, name, img_shape, state_size, action_size, time_horizon, 9 | training_size=None, validation_size=None): 10 | self.name = name 11 | self.img_shape = img_shape 12 | self.state_size = state_size 13 | self.action_size = action_size 14 | self.time_horizon = time_horizon 15 | self.training_size = training_size 16 | self.validation_size = validation_size 17 | self.data_root = os.path.join( 18 | os.path.dirname(os.path.realpath(__file__)), '../datasets', name) 19 | 20 | def training_set(self): 21 | tasks = self.load('train', self.training_size + self.validation_size) 22 | return tasks[:self.training_size], tasks[-self.validation_size:] 23 | 24 | def test_set(self): 25 | return self.load("test") 26 | 27 | def load(self, train_or_test, count=None): 28 | """Expected to be the test or train folder""" 29 | train_test_dir = os.path.join(self.data_root, train_or_test) 30 | tasks = [] 31 | for task_f in natsorted(os.listdir(train_test_dir)): 32 | task_path = os.path.join(train_test_dir, task_f) 33 | if not os.path.isdir(task_path): 34 | continue 35 | pkl_file = task_path + '.pkl' 36 | with open(pkl_file, "rb") as f: 37 | data = pickle.load(f) 38 | example_img_folders = natsorted(os.listdir(task_path)) 39 | examples = [] 40 | for e_idx, ex_file in enumerate(example_img_folders): 41 | img_path = os.path.join(task_path, ex_file) 42 | example = { 43 | 'image_files': img_path, 44 | 'actions': data['actions'][e_idx], 45 | 'states': data['states'][e_idx] 46 | } 47 | if 'demo_selection' in data: 48 | example['demo_selection'] = data['demo_selection'] 49 | examples.append(example) 50 | tasks.append(examples) 51 | if count is not None and len(tasks) >= count: 52 | break 53 | return tasks 54 | 55 | def get_outputs(self): 56 | return { 57 | 'actions': (self.action_size,) 58 | } 59 | 60 | def get_inputs(self): 61 | return { 62 | 'states': (self.state_size,) 63 | } 64 | -------------------------------------------------------------------------------- /data/generator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import multiprocessing 4 | import logging 5 | from data import utils 6 | 7 | 8 | class Generator(object): 9 | 10 | def __init__(self, dataset, batch_size, support_size, query_size, 11 | data_sequencer): 12 | self.dataset = dataset 13 | self.batch_size = batch_size 14 | self.support_size = support_size 15 | self.query_size = query_size 16 | self.data_sequencer = data_sequencer 17 | self.support_query_size = support_size + query_size 18 | self.train, self.validation = self.dataset.training_set() 19 | self._construct() 20 | 21 | def _create_generator(self, data): 22 | 23 | num_tasks = len(data) 24 | samples_to_take = np.minimum(self.batch_size, num_tasks) 25 | if samples_to_take != self.batch_size: 26 | logging.warning('Batch size was greater than number of tasks.') 27 | 28 | def gen(): 29 | 30 | while True: 31 | states, outputs, image_files = [], [], [] 32 | task_indices = np.random.choice( 33 | num_tasks, samples_to_take, replace=False) 34 | for index in task_indices: 35 | task = data[index] 36 | num_examples_of_task = len(task) 37 | 
if num_examples_of_task < self.support_query_size: 38 | raise RuntimeError( 39 | 'Tried to sample %d support and query samples,' 40 | 'but there are only %d samples of this task.' 41 | % (self.support_query_size, num_examples_of_task)) 42 | 43 | sample_indices = np.random.choice( 44 | num_examples_of_task, self.support_query_size, 45 | replace=False) 46 | sampled_examples = [task[sample_index] for sample_index in 47 | sample_indices] 48 | states.append([ex['states'] for ex in sampled_examples]) 49 | outputs.append([ex['actions'] for ex in sampled_examples]) 50 | image_files.append( 51 | [ex['image_files'] for ex in sampled_examples]) 52 | 53 | states = np.array(states) 54 | outputs = np.array(outputs) 55 | image_files = np.array(image_files) 56 | yield image_files, states, outputs 57 | 58 | return gen 59 | 60 | def _load_from_disk(self, image_files, states, outputs): 61 | 62 | embnet_images, embnet_states, embnet_outputs = [], [], [] 63 | for i in range(self.support_query_size): 64 | 65 | images, emb_states, emb_outputs = self.data_sequencer.load( 66 | image_files[i], states[i], outputs[i]) 67 | # images will be of shape (sequence, w, h, 3) 68 | embnet_images.append(images) 69 | embnet_states.append(emb_states) 70 | embnet_outputs.append(emb_outputs) 71 | 72 | embed_images = tf.stack(embnet_images) 73 | embnet_states = tf.stack(embnet_states) 74 | embnet_outputs = tf.stack(embnet_outputs) 75 | 76 | embnet_states.set_shape( 77 | (self.support_query_size, self.data_sequencer.frames, None)) 78 | embnet_outputs.set_shape( 79 | (self.support_query_size, self.data_sequencer.frames, None)) 80 | 81 | # Grab a random timestep in one of the support and query trajectories 82 | ctrnet_timestep = tf.random_uniform( 83 | (2,), 0, self.dataset.time_horizon, tf.int32) 84 | # The first should be a support and the last should be a query 85 | ctrnet_images = [ 86 | utils.tf_load_image(image_files[0], ctrnet_timestep[0]), 87 | utils.tf_load_image(image_files[-1], ctrnet_timestep[1]) 88 | ] 89 | ctrnet_states = [states[0][ctrnet_timestep[0]], 90 | states[-1][ctrnet_timestep[1]]] 91 | ctrnet_outputs = [outputs[0][ctrnet_timestep[0]], 92 | outputs[-1][ctrnet_timestep[1]]] 93 | 94 | ctrnet_images = tf.stack(ctrnet_images) 95 | ctrnet_states = tf.stack(ctrnet_states) 96 | ctrnet_outputs = tf.stack(ctrnet_outputs) 97 | 98 | embed_images = utils.preprocess(embed_images) 99 | ctrnet_images = utils.preprocess(ctrnet_images) 100 | 101 | return (embed_images, embnet_states, embnet_outputs, 102 | ctrnet_images, ctrnet_states, ctrnet_outputs) 103 | 104 | def _construct_dataset(self, data, prefetch): 105 | dataset = tf.data.Dataset.from_generator( 106 | self._create_generator(data), (tf.string, tf.float32, tf.float32)) 107 | dataset = dataset.apply(tf.contrib.data.unbatch()).map( 108 | map_func=self._load_from_disk, 109 | num_parallel_calls=multiprocessing.cpu_count()) 110 | dataset = dataset.batch(self.batch_size).prefetch(prefetch) 111 | return dataset 112 | 113 | def _construct(self): 114 | train_dataset = self._construct_dataset(self.train, 5) 115 | validation_dataset = self._construct_dataset(self.validation, 1) 116 | handle = tf.placeholder(tf.string, shape=[]) 117 | iterator = tf.data.Iterator.from_string_handle( 118 | handle, train_dataset.output_types, train_dataset.output_shapes) 119 | 120 | self.next_element = iterator.get_next() 121 | self.train_iterator = train_dataset.make_one_shot_iterator() 122 | self.validation_iterator = validation_dataset.make_one_shot_iterator() 123 | self.handle = handle 124 | 125 | 
    def get_handles(self, sess):
126 |         training_handle = sess.run(self.train_iterator.string_handle())
127 |         validation_handle = sess.run(self.validation_iterator.string_handle())
128 |         return training_handle, validation_handle
129 | 
--------------------------------------------------------------------------------
/data/mil_sim_push.py:
--------------------------------------------------------------------------------
1 | from data.dataset import Dataset
2 | 
3 | 
4 | class MilSimPush(Dataset):
5 | 
6 |     def __init__(self, training_size=693, validation_size=76):
7 |         super().__init__(name='mil_sim_push', img_shape=(125, 125, 3),
8 |                          state_size=20, action_size=7, time_horizon=100,
9 |                          training_size=training_size,
10 |                          validation_size=validation_size)
11 | 
--------------------------------------------------------------------------------
/data/mil_sim_reach.py:
--------------------------------------------------------------------------------
1 | from data.dataset import Dataset
2 | 
3 | 
4 | class MilSimReach(Dataset):
5 | 
6 |     def __init__(self, training_size=1500, validation_size=150):
7 |         super().__init__(name='mil_sim_reach', img_shape=(64, 80, 3),
8 |                          state_size=10, action_size=2, time_horizon=50,
9 |                          training_size=training_size,
10 |                          validation_size=validation_size)
11 | 
--------------------------------------------------------------------------------
/data/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tensorflow as tf
3 | 
4 | 
5 | def create_dir(path):
6 |     exist = os.path.exists(path)
7 |     if not exist:
8 |         os.makedirs(path)
9 |     return exist
10 | 
11 | 
12 | def tf_load_image(foldername, timestep):
13 |     file = tf.string_join([foldername, '/', tf.as_string(timestep), '.gif'])
14 |     return tf.image.decode_gif(tf.read_file(file))[0]
15 | 
16 | 
17 | def preprocess(img):
18 |     # In range [-1, 1]
19 |     return ((tf.cast(img, tf.float32) / 255.) * 2.) - 1.
20 | 
--------------------------------------------------------------------------------
/datasets/README.md:
--------------------------------------------------------------------------------
1 | Download the _mil_sim_reach_ and _mil_sim_push_ datasets from [here](https://www.doc.ic.ac.uk/~slj12/data/mil_data.zip). Unzip them here; the expected layout is sketched below. 
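For reference, `data/dataset.py` expects each unzipped dataset to be laid out roughly as follows. Task and example folder names here are illustrative: each task directory contains one sub-folder of per-timestep GIF frames per example, plus a matching `.pkl` file holding that task's `states` and `actions` arrays (and, for some datasets, a `demo_selection` entry).

```
datasets/
├── mil_sim_reach/
│   ├── train/
│   │   ├── task_0/
│   │   │   ├── example_0/   # 0.gif, 1.gif, ... (one frame per timestep)
│   │   │   └── example_1/
│   │   ├── task_0.pkl       # 'states' and 'actions' for each example
│   │   └── ...
│   └── test/
└── mil_sim_push/
```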
2 | 3 | _Note: The data format here has been changed slightly in comparison to the original data from the MIL paper._ 4 | -------------------------------------------------------------------------------- /evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepjam/TecNets/bf885956fd45b601ea0a820c124d70702e88ac7a/evaluation/__init__.py -------------------------------------------------------------------------------- /evaluation/eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | from data import utils 3 | import numpy as np 4 | import imageio 5 | 6 | 7 | class Eval(object): 8 | 9 | def __init__(self, sess, dataset, outputs, supports, num_tasks, 10 | num_trials, log_dir=".", record_gifs=False, render=True): 11 | self.time_horizon = dataset.time_horizon 12 | self.sess = sess 13 | self.demos = dataset.test_set() 14 | self.supports = supports 15 | self.num_tasks = num_tasks 16 | self.num_trials = num_trials 17 | self.log_dir = log_dir 18 | self.record_gifs = record_gifs 19 | self.render = render 20 | self.record_gifs_dir = os.path.join(self.log_dir, 'evaluated_gifs') 21 | self.outputs = outputs 22 | 23 | def evaluate(self, iter): 24 | raise NotImplementedError("Override this function.") 25 | 26 | def get_embedding(self, task_index, demo_indexes): 27 | image_files = [ 28 | self.demos[task_index][j]['image_files'] for j in demo_indexes] 29 | states = [ 30 | self.demos[task_index][j]['states'] for j in demo_indexes] 31 | outs = [ 32 | self.demos[task_index][j]['actions'] for j in demo_indexes] 33 | 34 | feed_dict = { 35 | self.outputs['input_image_files']: image_files, 36 | self.outputs['input_states']: states, 37 | self.outputs['input_outputs']: outs, 38 | } 39 | embedding, = self.sess.run( 40 | self.outputs['sentences'], feed_dict=feed_dict) 41 | return embedding 42 | 43 | def get_action(self, obs, state, embedding): 44 | feed_dict = { 45 | self.outputs['ctrnet_images']: [[obs]], 46 | self.outputs['ctrnet_states']: [[state]], 47 | self.outputs['sentences']: [embedding], 48 | } 49 | action, = self.sess.run( 50 | self.outputs['output_actions'], feed_dict=feed_dict) 51 | return action 52 | 53 | def create_gif_dir(self, iteration_dir, task_id): 54 | gifs_dir = None 55 | if self.record_gifs: 56 | gifs_dir = os.path.join(iteration_dir, 'task_%d' % task_id) 57 | utils.create_dir(gifs_dir) 58 | return gifs_dir 59 | 60 | def save_gifs(self, observations, gifs_dir, trial): 61 | if self.record_gifs: 62 | video = np.array(observations) 63 | record_gif_path = os.path.join( 64 | gifs_dir, 'cond%d.samp0.gif' % trial) 65 | imageio.mimwrite(record_gif_path, video) 66 | -------------------------------------------------------------------------------- /evaluation/eval_mil_push.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import random 4 | from gym.envs.mujoco.pusher import PusherEnv 5 | from evaluation.eval import Eval 6 | from data import utils 7 | 8 | XML_FOLDER = "/media/stephen/c6c2821e-ed17-493a-b35b-4b66f0b21ee7/MIL/gym/gym/envs/mujoco/assets" 9 | 10 | 11 | class EvalMilPush(Eval): 12 | 13 | def _load_env(self, xml): 14 | xml = xml[xml.rfind('pusher'):] 15 | xml_file = 'sim_push_xmls/test2_ensure_woodtable_distractor_%s' % xml 16 | xml_file = os.path.join(XML_FOLDER, xml_file) 17 | env = PusherEnv(**{'xml_file': xml_file, 'distractors': True}) 18 | env.set_visibility(self.render) 19 | 
env.render() 20 | viewer = env.viewer 21 | viewer.autoscale() 22 | viewer.cam.trackbodyid = -1 23 | viewer.cam.lookat[0] = 0.4 24 | viewer.cam.lookat[1] = -0.1 25 | viewer.cam.lookat[2] = 0.0 26 | viewer.cam.distance = 0.75 27 | viewer.cam.elevation = -50 28 | viewer.cam.azimuth = -90 29 | return env 30 | 31 | def _eval_success(self, obs): 32 | obs = np.array(obs) 33 | target = obs[:, -3:-1] 34 | obj = obs[:, -6:-4] 35 | dists = np.sum((target - obj) ** 2, 1) # distances at each time step 36 | return np.sum(dists < 0.017) >= 10 37 | 38 | def evaluate(self, iter): 39 | 40 | print("Evaluating at iteration: %i" % iter) 41 | iter_dir = os.path.join(self.record_gifs_dir, 'iter_%i' % iter) 42 | utils.create_dir(iter_dir) 43 | 44 | successes = [] 45 | for i in range(self.num_tasks): 46 | 47 | # demo_selection will be an xml file 48 | env = self._load_env(self.demos[i][0]['demo_selection']) 49 | 50 | selected_demo_indexs = random.sample( 51 | range(len(self.demos[i])), self.supports) 52 | 53 | embedding = self.get_embedding(i, selected_demo_indexs) 54 | gifs_dir = self.create_gif_dir(iter_dir, i) 55 | 56 | for j in range(self.num_trials): 57 | env.reset() 58 | observations = [] 59 | world_state = [] 60 | for t in range(self.time_horizon): 61 | env.render() 62 | # Observation is shape (100,100,3) 63 | obs, state = env.get_current_image_obs() 64 | observations.append(obs) 65 | obs = ((obs / 255.0) * 2.) - 1. 66 | 67 | action = self.get_action(obs, state, embedding) 68 | ob, reward, done, reward_dict = env.step(np.squeeze(action)) 69 | world_state.append(np.squeeze(ob)) 70 | if done: 71 | break 72 | 73 | if self._eval_success(world_state): 74 | successes.append(1.) 75 | else: 76 | successes.append(0.) 77 | self.save_gifs(observations, gifs_dir, j) 78 | 79 | env.render(close=True) 80 | 81 | final_suc = np.mean(successes) 82 | print("Final success rate is %.5f" % (final_suc)) 83 | return final_suc 84 | -------------------------------------------------------------------------------- /evaluation/eval_mil_reach.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import random 4 | import gym 5 | from evaluation.eval import Eval 6 | from data import utils 7 | 8 | REACH_SUCCESS_THRESH = 0.05 9 | REACH_SUCCESS_TIME = 10 10 | 11 | 12 | class EvalMilReach(Eval): 13 | 14 | def __init__(self, sess, dataset, outputs, supports, num_tasks, 15 | num_trials, log_dir=".", record_gifs=False, render=True): 16 | super().__init__(sess, dataset, outputs, supports, num_tasks, 17 | num_trials, log_dir, record_gifs, render) 18 | self.env = gym.make('ReacherMILTest-v1') 19 | self.env.env.set_visibility(render) 20 | 21 | def evaluate(self, iter): 22 | 23 | print("Evaluating at iteration: %i" % iter) 24 | iter_dir = os.path.join(self.record_gifs_dir, 'iter_%i' % iter) 25 | utils.create_dir(iter_dir) 26 | self.env.reset() 27 | 28 | successes = [] 29 | for i in range(self.num_tasks): 30 | 31 | # TODO hacked in for now. 
Remove 0 32 | dem_conds = self.demos[i][0]['demo_selection'] 33 | 34 | # randomly select a demo from each of the folders 35 | selected_demo_indexs = random.sample( 36 | range(len(dem_conds)), self.supports) 37 | 38 | embedding = self.get_embedding(i, selected_demo_indexs) 39 | gifs_dir = self.create_gif_dir(iter_dir, i) 40 | 41 | for j in range(self.num_trials): 42 | if j in dem_conds: 43 | distances = [] 44 | observations = [] 45 | for t in range(self.time_horizon): 46 | self.env.render() 47 | # Observation is shape (64,80,3) 48 | obs, state = self.env.env.get_current_image_obs() 49 | observations.append(obs) 50 | obs = ((obs / 255.0) * 2.) - 1. 51 | 52 | action = self.get_action(obs, state, embedding) 53 | ob, reward, done, reward_dict = self.env.step( 54 | np.squeeze(action)) 55 | dist = -reward_dict['reward_dist'] 56 | if t >= self.time_horizon - REACH_SUCCESS_TIME: 57 | distances.append(dist) 58 | if np.amin(distances) <= REACH_SUCCESS_THRESH: 59 | successes.append(1.) 60 | else: 61 | successes.append(0.) 62 | self.save_gifs(observations, gifs_dir, j) 63 | 64 | self.env.render(close=True) 65 | self.env.env.next() 66 | self.env.env.set_visibility(self.render) 67 | self.env.render() 68 | 69 | self.env.render(close=True) 70 | self.env.env.reset_iter() 71 | final_suc = np.mean(successes) 72 | print("Final success rate is %.5f" % final_suc) 73 | return final_suc 74 | -------------------------------------------------------------------------------- /main_il.py: -------------------------------------------------------------------------------- 1 | from tensorflow.python.platform import flags 2 | from data.mil_sim_reach import MilSimReach 3 | from data.mil_sim_push import MilSimPush 4 | from consumers.control import Control 5 | from consumers.eval_consumer import EvalConsumer 6 | from consumers.generator_consumer import GeneratorConsumer 7 | from consumers.imitation_loss import ImitationLoss 8 | from consumers.margin_loss import MarginLoss 9 | from consumers.task_embedding import TaskEmbedding 10 | from data.data_sequencer import DataSequencer 11 | from data.generator import Generator 12 | from networks.cnn import CNN 13 | from trainers.il_trainer import ILTrainer 14 | from trainers.pipeline import Pipeline 15 | from trainers.summary_writer import SummaryWriter 16 | import os 17 | from networks.save_load import Saver, Loader 18 | 19 | # Dataset/method options 20 | flags.DEFINE_string( 21 | 'dataset', 'sim_reach', 'One of sim_reach, sim_push.') 22 | 23 | # Training Options 24 | flags.DEFINE_integer( 25 | 'iterations', 500000, 'The number of training iterations.') 26 | flags.DEFINE_integer( 27 | 'batch_size', 64, 'The number of tasks sampled per batch (aka batch size).') 28 | flags.DEFINE_float( 29 | 'lr', 0.0001, 'The learning rate.') 30 | flags.DEFINE_integer( 31 | 'support', 5, 'The number of support examples per task (aka k-shot).') 32 | flags.DEFINE_integer( 33 | 'query', 5, 'The number of query examples per task.') 34 | flags.DEFINE_integer( 35 | 'embedding', 20, 'The embedding size.') 36 | 37 | # Model Options 38 | flags.DEFINE_string( 39 | 'activation', 'relu', 'One of relu, elu, or leaky_relu.') 40 | flags.DEFINE_bool( 41 | 'max_pool', False, 'Use max pool rather than strides.') 42 | flags.DEFINE_list( 43 | 'filters', [32, 64], 'List of filters per convolution layer.') 44 | flags.DEFINE_list( 45 | 'kernels', [3, 3], 'List of kernel sizes per convolution layer.') 46 | flags.DEFINE_list( 47 | 'strides', [2, 2], 'List of strides per convolution layer. 
' 48 | 'Can be None if using max pooling.') 49 | flags.DEFINE_list( 50 | 'fc_layers', [64, 64], 'List of fully connected nodes per layer.') 51 | flags.DEFINE_float( 52 | 'drop_rate', 0.0, 'Dropout probability. 0 for no dropout.') 53 | flags.DEFINE_string( 54 | 'norm', None, 'One of layer, batch, or None') 55 | 56 | # Loss Options 57 | flags.DEFINE_float( 58 | 'lambda_embedding', 1.0, 'Lambda for the embedding loss.') 59 | flags.DEFINE_float( 60 | 'lambda_support', 1.0, 'Lambda for the support control loss.') 61 | flags.DEFINE_float( 62 | 'lambda_query', 1.0, 'Lambda for the query control loss.') 63 | flags.DEFINE_float( 64 | 'margin', 0.1, 'The margin for the embedding loss.') 65 | 66 | # Logging, Saving, and Eval Options 67 | flags.DEFINE_bool( 68 | 'summaries', True, 'If false do not write summaries (for tensorboard).') 69 | flags.DEFINE_bool( 70 | 'save', True, 'If false do not save network weights.') 71 | flags.DEFINE_bool( 72 | 'load', False, 'If we should load a checkpoint.') 73 | flags.DEFINE_string( 74 | 'logdir', '/tmp/data', 'The directory to store summaries and checkpoints.') 75 | flags.DEFINE_bool( 76 | 'eval', False, 'If evaluation should be done.') 77 | flags.DEFINE_integer( 78 | 'checkpoint_iter', -1, 'The checkpoint iteration to restore ' 79 | '(-1 for latest model).') 80 | flags.DEFINE_string( 81 | 'checkpoint_dir', None, 'The checkpoint directory.') 82 | flags.DEFINE_bool( 83 | 'no_mujoco', True, 'Run without Mujoco. Eval should be False.') 84 | 85 | FLAGS = flags.FLAGS 86 | 87 | if not FLAGS.no_mujoco: 88 | from evaluation.eval_mil_reach import EvalMilReach 89 | from evaluation.eval_mil_push import EvalMilPush 90 | 91 | filters = list(map(int, FLAGS.filters)) 92 | kernels = list(map(int, FLAGS.kernels)) 93 | strides = list(map(int, FLAGS.strides)) 94 | fc_layers = list(map(int, FLAGS.fc_layers)) 95 | 96 | data = None 97 | if FLAGS.dataset == 'sim_reach': 98 | data = MilSimReach() 99 | elif FLAGS.dataset == 'sim_push': 100 | data = MilSimPush() 101 | else: 102 | raise RuntimeError('Unrecognised dataset.') 103 | 104 | loader = saver = None 105 | if FLAGS.save: 106 | saver = Saver(savedir=FLAGS.logdir) 107 | if FLAGS.load: 108 | loader = Loader(savedir=FLAGS.logdir, 109 | checkpoint=FLAGS.checkpoint_iter) 110 | 111 | net = CNN(filters=filters, 112 | fc_layers=fc_layers, 113 | kernel_sizes=kernels, 114 | strides=strides, 115 | max_pool=FLAGS.max_pool, 116 | drop_rate=FLAGS.drop_rate, 117 | norm=FLAGS.norm, 118 | activation=FLAGS.activation) 119 | 120 | sequencer = DataSequencer('first_last', data.time_horizon) 121 | gen = Generator(dataset=data, 122 | batch_size=FLAGS.batch_size, 123 | support_size=FLAGS.support, 124 | query_size=FLAGS.query, 125 | data_sequencer=sequencer) 126 | 127 | generator_consumer = GeneratorConsumer(gen, data, FLAGS.support, FLAGS.query) 128 | task_emb = TaskEmbedding(network=net, 129 | embedding_size=FLAGS.embedding, 130 | include_state=False, 131 | include_action=False) 132 | ml = MarginLoss(margin=FLAGS.margin, loss_lambda=FLAGS.lambda_embedding) 133 | ctr = Control(network=net, 134 | action_size=data.action_size, 135 | include_state=True) 136 | il = ImitationLoss(support_lambda=FLAGS.lambda_support, 137 | query_lambda=FLAGS.lambda_query) 138 | consumers = [generator_consumer, task_emb, ml, ctr, il] 139 | p = Pipeline(consumers, 140 | saver=saver, 141 | loader=loader, 142 | learning_rate=FLAGS.lr) 143 | train_outs = p.get_outputs() 144 | 145 | summary_w = None 146 | log_dir = os.path.join(FLAGS.logdir, 'no_state_action') 147 | if FLAGS.summaries: 
/networks/cnn.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from networks.utils import *
3 | from networks.input_output import *
4 | 
5 | 
6 | class CNN(object):
7 | 
8 |     def __init__(self, filters, fc_layers, kernel_sizes, strides=None,
9 |                  max_pool=False, drop_rate=0.0, norm=None, activation='relu'):
10 |         """Initializes a standard CNN network.
11 | 
12 |         :param filters: List of number of filters per convolution layer.
13 |         :param fc_layers: List of fully connected units per layer.
14 |         :param norm: String defining the type of normalization.
15 |         """
16 |         self.filters = filters
17 |         self.fc_layers = fc_layers
18 |         self.norm = norm
19 |         self.drop_rate = drop_rate
20 |         self.kernel_sizes = kernel_sizes
21 |         self.activation = activation
22 |         self.max_pool = max_pool
23 |         self.strides = strides if strides is not None else [1] * len(filters)  # one (no-op) stride per conv layer
24 |         if not max_pool and strides is None:
25 |             raise RuntimeError('No dimensionality reduction.')
26 | 
27 |     def _pre_layer_util(self, layer, cur_layer_num, ins):
28 |         for cin in ins:
29 |             if cin.layer_num > cur_layer_num:
30 |                 break
31 |             elif cin.layer_num == cur_layer_num:
32 |                 if cin.merge_mode == 'concat':
33 |                     layer = tf.concat([layer, cin.tensor], axis=cin.axis)
34 |                 elif cin.merge_mode == 'addition':
35 |                     layer += cin.tensor
36 |                 elif cin.merge_mode == 'multiply':
37 |                     layer *= cin.tensor
38 |                 else:
39 |                     raise RuntimeError('Unrecognised merging method for %s.' %
40 |                                        cin.name)
41 |         return layer
42 | 
43 |     def _post_layer_util(self, layer, training, norm):
44 | 
45 |         if self.drop_rate > 0:
46 |             layer = tf.layers.dropout(layer, rate=self.drop_rate, training=training)
47 | 
48 |         act_fn = activation(self.activation)
49 |         if norm and self.norm is not None:
50 |             if self.norm == 'batch':
51 |                 layer = tf.contrib.layers.batch_norm(
52 |                     layer, is_training=training, activation_fn=act_fn)
53 |             elif self.norm == 'layer':
54 |                 layer = tf.contrib.layers.layer_norm(
55 |                     layer, activation_fn=act_fn)
56 |             else:
57 |                 raise RuntimeError('Unsupported normalization method: %s'
58 |                                    % self.norm)
59 |         else:
60 |             layer = act_fn(layer)
61 |         return layer
62 | 
63 |     def forward(self, inputs, heads, training):
64 |         """Inputs may be fused in at different layers. """
65 | 
66 |         inputs = sorted(inputs, key=lambda item: item.layer_num)
67 |         conv_inputs = list(filter(lambda item: item.layer_type == 'conv', inputs))
68 |         fc_inputs = list(filter(lambda item: item.layer_type == 'fc', inputs))
69 | 
70 |         if conv_inputs[0].layer_num > 0:
71 |             raise RuntimeError('Need an input tensor.')
72 |         elif len(conv_inputs) > 1 and conv_inputs[1].layer_num == 0:
73 |             raise RuntimeError('Can only have one main input tensor.')
74 | 
75 |         layer = conv_inputs[0].tensor
76 |         del conv_inputs[0]
77 | 
78 |         outputs = {}
79 | 
80 |         for i, (filters, ksize, stride) in enumerate(
81 |                 zip(self.filters, self.kernel_sizes, self.strides)):
82 |             layer = self._pre_layer_util(layer, i, conv_inputs)
83 |             layer = tf.layers.conv2d(layer, filters, ksize, stride, 'same')
84 |             layer = self._post_layer_util(layer, training, True)
85 | 
86 |         layer = tf.layers.flatten(layer)
87 |         for i, fc_units in enumerate(self.fc_layers):
88 |             layer = self._pre_layer_util(layer, i, fc_inputs)
89 |             layer = tf.layers.dense(layer, fc_units)
90 |             layer = self._post_layer_util(layer, training, False)
91 | 
92 |         for head in heads:
93 |             act_fn = activation(head.activation)
94 |             output = tf.layers.dense(layer, head.nodes)
95 |             outputs[head.name] = output if act_fn is None else act_fn(output)
96 | 
97 |         return outputs
98 | 
--------------------------------------------------------------------------------
/networks/input_output.py:
--------------------------------------------------------------------------------
1 | class NetworkInputOutput(object):
2 |     """Used for pulling out info at different layers. """
3 | 
4 |     def __init__(self, name):
5 |         self.name = name
6 | 
7 | 
8 | class NetworkHead(NetworkInputOutput):
9 |     """Used for pulling out info at different layers. """
10 | 
11 |     def __init__(self, name, nodes, activation=None):
12 |         super().__init__(name)
13 |         self.activation = activation
14 |         self.nodes = nodes
15 | 
16 | 
17 | class NetworkInput(NetworkInputOutput):
18 |     """Used for feeding inputs in at different layers. """
19 | 
20 |     def __init__(self, name, layer_type, layer_num, tensor,
21 |                  merge_mode=None, axis=-1):
22 |         super().__init__(name)
23 |         self.layer_type = layer_type
24 |         self.layer_num = layer_num
25 |         self.tensor = tensor
26 |         self.merge_mode = merge_mode
27 |         self.axis = axis
28 | 
--------------------------------------------------------------------------------
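A minimal usage sketch of the two modules above, mirroring the repository's own tests (tensor shapes and names are illustrative; TF1-style API as used throughout this repo):

    import numpy as np
    import tensorflow as tf
    from networks.cnn import CNN
    from networks.input_output import NetworkHead, NetworkInput

    # The main image enters at conv layer 0; 'angles' is concatenated at fc layer 0.
    img_in = NetworkInput('img', 'conv', 0,
                          tf.placeholder(tf.float32, (None, 8, 8, 3)))
    angles_in = NetworkInput('angles', 'fc', 0,
                             tf.placeholder(tf.float32, (None, 2)),
                             'concat', axis=-1)
    action = NetworkHead('action', 3)  # linear output head

    net = CNN(filters=[8, 16], fc_layers=[20, 20], kernel_sizes=[3, 3],
              strides=[2, 2])
    outputs = net.forward([img_in, angles_in], [action], training=None)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        acts = sess.run(outputs['action'],
                        {img_in.tensor: np.ones((1, 8, 8, 3)),
                         angles_in.tensor: np.ones((1, 2))})
        print(acts.shape)  # (1, 3)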
""" 19 | 20 | def __init__(self, name, layer_type, layer_num, tensor, 21 | merge_mode=None, axis=-1): 22 | super().__init__(name) 23 | self.layer_type = layer_type 24 | self.layer_num = layer_num 25 | self.tensor = tensor 26 | self.merge_mode = merge_mode 27 | self.axis = axis 28 | -------------------------------------------------------------------------------- /networks/save_load.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import os 3 | 4 | PREFIX = "itr" 5 | 6 | 7 | class Saver(object): 8 | 9 | def __init__(self, savedir='.', savetitle=''): 10 | self.savedir = savedir 11 | self.savefile = os.path.join(savedir, savetitle) 12 | self.saver = None 13 | 14 | def save(self, sess, itr): 15 | if self.saver is None: 16 | self.saver = tf.train.Saver(max_to_keep=10) 17 | self.saver.save(sess, self.savefile + "_" + PREFIX + str(itr)) 18 | print('Saved model at iteration', itr) 19 | 20 | 21 | class Loader(object): 22 | 23 | def __init__(self, savedir='.', savetitle='', checkpoint=-1): 24 | self.savedir = savedir 25 | self.checkpoint = checkpoint 26 | self.savefile = os.path.join(savedir, savetitle) 27 | self.saver = None 28 | 29 | def load(self, sess): 30 | if self.saver is None: 31 | self.saver = tf.train.Saver(max_to_keep=10) 32 | model_file = tf.train.latest_checkpoint(self.savedir) 33 | if model_file: 34 | ind1 = model_file.rfind('itr') 35 | if self.checkpoint > 0: 36 | model_file = model_file[:ind1] + PREFIX + str(self.checkpoint) 37 | resume_itr = self.checkpoint 38 | else: 39 | resume_itr = int(model_file[ind1 + len(PREFIX):]) 40 | print("Restoring model weights from " + model_file) 41 | self.saver.restore(sess, model_file) 42 | return resume_itr 43 | raise RuntimeError('Could not find model file in: %s' % self.savedir) 44 | -------------------------------------------------------------------------------- /networks/utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def activation(fn_name): 5 | fn = None 6 | if fn_name == 'relu': 7 | fn = tf.nn.relu 8 | elif fn_name == 'elu': 9 | fn = tf.nn.elu 10 | elif fn_name == 'leaky_relu': 11 | fn = tf.nn.leaky_relu 12 | return fn 13 | -------------------------------------------------------------------------------- /readme_images/tecnets.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepjam/TecNets/bf885956fd45b601ea0a820c124d70702e88ac7a/readme_images/tecnets.png -------------------------------------------------------------------------------- /scripts/mil_to_tecnet.py: -------------------------------------------------------------------------------- 1 | """Converts the MIL data to a format suited for TecNets""" 2 | 3 | import os 4 | import argparse 5 | from multiprocessing import Pool 6 | import pickle 7 | from PIL import Image 8 | from natsort import natsorted 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('source_dir') 12 | parser.add_argument('target_dir') 13 | parser.add_argument('dataset', choices=['reach', 'push']) 14 | args = parser.parse_args() 15 | 16 | 17 | def process_task_folder(tasks): 18 | src_task_path = os.path.join(args.source_dir, tasks) 19 | if not os.path.isdir(src_task_path): 20 | return 21 | task_id = tasks.split('_')[1] 22 | target_task_path = os.path.join(args.target_dir, 'task_' + task_id) 23 | 24 | task_pkl_opt1 = os.path.join(args.source_dir, task_id + '.pkl') 25 | task_pkl_opt2 = 
/networks/utils.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | 
3 | 
4 | def activation(fn_name):
5 |     fn = None  # unknown names fall through to None (treated as linear)
6 |     if fn_name == 'relu':
7 |         fn = tf.nn.relu
8 |     elif fn_name == 'elu':
9 |         fn = tf.nn.elu
10 |     elif fn_name == 'leaky_relu':
11 |         fn = tf.nn.leaky_relu
12 |     return fn
13 | 
--------------------------------------------------------------------------------
/readme_images/tecnets.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stepjam/TecNets/bf885956fd45b601ea0a820c124d70702e88ac7a/readme_images/tecnets.png
--------------------------------------------------------------------------------
/scripts/mil_to_tecnet.py:
--------------------------------------------------------------------------------
1 | """Converts the MIL data to a format suited for TecNets."""
2 | 
3 | import os
4 | import argparse
5 | from multiprocessing import Pool
6 | import pickle
7 | from PIL import Image
8 | from natsort import natsorted
9 | 
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument('source_dir')
12 | parser.add_argument('target_dir')
13 | parser.add_argument('dataset', choices=['reach', 'push'])
14 | args = parser.parse_args()
15 | 
16 | 
17 | def process_task_folder(tasks):
18 |     src_task_path = os.path.join(args.source_dir, tasks)
19 |     if not os.path.isdir(src_task_path):
20 |         return
21 |     task_id = tasks.split('_')[1]
22 |     target_task_path = os.path.join(args.target_dir, 'task_' + task_id)
23 | 
24 |     task_pkl_opt1 = os.path.join(args.source_dir, task_id + '.pkl')
25 |     task_pkl_opt2 = os.path.join(args.source_dir, 'demos_' + task_id + '.pkl')
26 |     task_pkl = task_pkl_opt1 if os.path.exists(task_pkl_opt1) else task_pkl_opt2
27 | 
28 |     gif_dirs = natsorted(os.listdir(src_task_path))
29 |     with open(task_pkl, 'rb') as pkl:
30 |         data = pickle.load(pkl, encoding='bytes')
31 |     if args.dataset == 'reach':
32 |         states = data[b'demoX']
33 |         actions = data[b'demoU']
34 |         demo_selection = data[b'demoConditions']
35 |     elif args.dataset == 'push':
36 |         # MIL only used part of the data.
37 |         states = data['demoX'][6:-6]
38 |         actions = data['demoU'][6:-6]
39 |         demo_selection = data['xml']
40 |         gif_dirs = gif_dirs[6:-6]
41 |     else:
42 |         raise RuntimeError('Unrecognized dataset', args.dataset)
43 |     new_data = {
44 |         'states': states,
45 |         'actions': actions,
46 |         'demo_selection': demo_selection
47 |     }
48 | 
49 |     for gif_file in gif_dirs:
50 |         new_gif_folder = os.path.join(target_task_path, gif_file[:-4])
51 |         print('Splitting gif', gif_file)
52 |         os.makedirs(new_gif_folder)
53 |         with Image.open(os.path.join(src_task_path, gif_file)) as frame:
54 |             nframes = 0
55 |             while frame:  # exits via EOFError when seeking past the last frame
56 |                 gif_file_path = os.path.join(new_gif_folder, "%i.gif" % nframes)
57 |                 if not os.path.exists(gif_file_path):
58 |                     f = frame.convert("RGB")
59 |                     f.save(gif_file_path, 'gif')
60 |                 nframes += 1
61 |                 try:
62 |                     frame.seek(nframes)
63 |                 except EOFError:
64 |                     break
65 | 
66 |     new_pickle_path = target_task_path + '.pkl'
67 |     print('Saving', new_pickle_path)
68 |     with open(new_pickle_path, 'wb') as f:
69 |         pickle.dump(new_data, f)
70 | 
71 | 
72 | with Pool(20) as p:
73 |     p.map(process_task_folder, os.listdir(args.source_dir))
74 | 
--------------------------------------------------------------------------------
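An illustrative invocation of the conversion script (both paths are placeholders for wherever the raw MIL data and the TecNets dataset root live):

    python3 scripts/mil_to_tecnet.py /path/to/mil_data/sim_push datasets/mil_sim_push push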
/tecnets_corl_results.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | if [ "$1" = "sim_reach" ]; then
4 |     python3 main_il.py \
5 |         --dataset=sim_reach --iterations=400000 --batch_size=64 \
6 |         --lr=0.0005 --support=2 --query=2 --embedding=20 \
7 |         --activation=elu --filters=40,40,40 --kernels=3,3,3 --strides=2,2,2 \
8 |         --fc_layers=200,200,200,200 --lambda_embedding=1.0 \
9 |         --lambda_support=0.1 --lambda_query=0.1 --margin=0.1 --norm=layer \
10 |         --logdir='mylog/' --eval=True
11 | elif [ "$1" = "sim_push" ]; then
12 |     python3 main_il.py \
13 |         --dataset=sim_push --iterations=400000 --batch_size=100 \
14 |         --lr=0.0005 --support=5 --query=5 --embedding=20 \
15 |         --activation=elu --filters=16,16,16,16 --kernels=5,5,5,5 \
16 |         --strides=2,2,2,2 --fc_layers=200,200,200 --lambda_embedding=1.0 \
17 |         --lambda_support=0.1 --lambda_query=0.1 --margin=0.1 --norm=layer \
18 |         --logdir='mylog/' --eval=True
19 | else
20 |     echo 'Usage: ./tecnets_corl_results.sh {sim_reach|sim_push}'
21 | fi
--------------------------------------------------------------------------------
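For example, reproducing the CoRL push result (assuming the converted dataset is already in place under datasets/):

    ./tecnets_corl_results.sh sim_push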
/test/consumers/test_control.py:
--------------------------------------------------------------------------------
1 | from unittest import main, TestCase
2 | import tensorflow as tf
3 | from consumers.control import Control
4 | from networks.cnn import CNN
5 | 
6 | ACTION_SIZE = 4
7 | STATE_SIZE = 6
8 | BATCH = 5
9 | FRAMES = 3
10 | EMB_SIZE = 7
11 | 
12 | 
13 | class TestControl(TestCase):
14 | 
15 |     def _fake_tensors(self):
16 |         return {
17 |             'sentences': tf.random_uniform((BATCH, EMB_SIZE)),
18 |             'ctrnet_images': tf.random_uniform((BATCH, 2, 32, 32, 3)),
19 |             'ctrnet_states': tf.random_uniform((BATCH, 2, STATE_SIZE)),
20 |             'training': True
21 |         }
22 | 
23 |     def _check_actions(self, scope, state):
24 |         net = CNN(filters=[8, 16], fc_layers=[20, 20], kernel_sizes=[3, 3],
25 |                   strides=[2, 2], max_pool=False, norm=None, activation='relu')
26 |         c = Control(network=net, action_size=ACTION_SIZE, include_state=state)
27 |         with tf.variable_scope(scope):
28 |             outputs = c.consume(self._fake_tensors())
29 |         sess = tf.InteractiveSession()
30 |         tf.global_variables_initializer().run()
31 |         actions = sess.run(outputs['output_actions'])
32 |         self.assertEqual(actions.shape, (BATCH, 2, ACTION_SIZE))
33 | 
34 |     def test_no_state(self):
35 |         self._check_actions('test_no_state', False)
36 | 
37 |     def test_with_state(self):
38 |         self._check_actions('test_with_state', True)
39 | 
40 | 
41 | if __name__ == '__main__':
42 |     main()
--------------------------------------------------------------------------------
/test/consumers/test_imitation_loss.py:
--------------------------------------------------------------------------------
1 | from unittest import main, TestCase
2 | import tensorflow as tf
3 | from consumers.imitation_loss import ImitationLoss
4 | import numpy as np
5 | 
6 | BATCH = 5
7 | ACTION_SIZE = 10
8 | 
9 | 
10 | class TestImitationLoss(TestCase):
11 | 
12 |     def _fake_tensors(self):
13 |         return {
14 |             'output_actions': tf.random_uniform((BATCH, 2, ACTION_SIZE)),
15 |             'ctrnet_outputs': tf.random_uniform((BATCH, 2, ACTION_SIZE)),
16 |         }
17 | 
18 |     def test_float_outputs(self):
19 |         il = ImitationLoss()
20 |         with tf.variable_scope('test_float_outputs'):
21 |             outputs = il.consume(self._fake_tensors())
22 |         sess = tf.InteractiveSession()
23 |         loss_support, loss_query = sess.run(
24 |             [outputs['loss_support'], outputs['loss_query']])
25 |         self.assertIs(type(loss_support), np.float32)
26 |         self.assertIs(type(loss_query), np.float32)
27 | 
28 | 
29 | if __name__ == '__main__':
30 |     main()
--------------------------------------------------------------------------------
/test/consumers/test_margin_loss.py:
--------------------------------------------------------------------------------
1 | from unittest import main, TestCase
2 | import tensorflow as tf
3 | from consumers.margin_loss import MarginLoss
4 | import numpy as np
5 | 
6 | BATCH = 5
7 | EMB_SIZE = 10
8 | 
9 | 
10 | class TestMarginLoss(TestCase):
11 | 
12 |     def _fake_tensors(self, support, query):
13 |         return {
14 |             'support_embedding': tf.random_uniform((BATCH, support, EMB_SIZE)),
15 |             'query_embedding': tf.random_uniform((BATCH, query, EMB_SIZE)),
16 |             'prefix': 'test'
17 |         }
18 | 
19 |     def _check_loss_and_accuracy(self, support, query, scope):
20 |         ml = MarginLoss(0.1)
21 |         with tf.variable_scope(scope):
22 |             outputs = ml.consume(self._fake_tensors(support, query))
23 |         sess = tf.InteractiveSession()
24 |         loss, accuracy, sentences = sess.run(
25 |             [outputs['loss_embedding'], outputs['embedding_accuracy'],
26 |              outputs['sentences']])
27 |         self.assertIs(type(loss), np.float32)
28 |         self.assertIs(type(accuracy), np.float32)
29 |         self.assertTrue(0.0 <= accuracy <= 1.0)
30 |         self.assertEqual(sentences.shape, (BATCH, EMB_SIZE))
31 | 
32 |     def test_equal_support_query_size(self):
33 |         self._check_loss_and_accuracy(3, 3, 'test_equal_support_query_size')
34 | 
35 |     def test_support_more_than_query_size(self):
36 |         self._check_loss_and_accuracy(5, 3, 'test_support_more_than_query_size')
37 | 
38 |     def test_query_more_than_support_size(self):
39 |         self._check_loss_and_accuracy(3, 5, 'test_query_more_than_support_size')
40 | 
41 | if __name__ == '__main__':
42 |     main()
--------------------------------------------------------------------------------
/test/consumers/test_task_embedding.py:
--------------------------------------------------------------------------------
1 | from unittest import main, TestCase
2 | import tensorflow as tf
3 | from consumers.task_embedding import TaskEmbedding
4 | from networks.cnn import CNN
5 | 
6 | STATE_SIZE = 3
7 | ACTION_SIZE = 4
8 | BATCH = 5
9 | FRAMES = 3
10 | EMB_SIZE = 7
11 | 
12 | 
13 | class TestTaskEmbedding(TestCase):
14 | 
15 |     def _fake_tensors(self, support, query):
16 |         return {
17 |             'embnet_images': tf.random_uniform(
18 |                 (BATCH, support + query, FRAMES, 32, 32, 3)),
19 |             'embnet_states': tf.random_uniform(
20 |                 (BATCH, support + query, FRAMES, STATE_SIZE)),
21 |             'embnet_outputs': tf.random_uniform(
22 |                 (BATCH, support + query, FRAMES, ACTION_SIZE)),
23 |             'training': False,
24 |             'support': support,
25 |             'query': query
26 |         }
27 | 
28 |     def _check_loss_and_accuracy(self, support, query, scope, state, action):
29 |         net = CNN(filters=[8, 16], fc_layers=[20, 20], kernel_sizes=[3, 3],
30 |                   strides=[2, 2], max_pool=False, norm=None, activation='relu')
31 |         te = TaskEmbedding(network=net, embedding_size=EMB_SIZE,
32 |                            frame_collapse_method='concat',
33 |                            include_state=state, include_action=action)
34 |         with tf.variable_scope(scope):
35 |             outputs = te.consume(self._fake_tensors(support, query))
36 |         sess = tf.InteractiveSession()
37 |         tf.global_variables_initializer().run()
38 |         sup_emb, que_emb = sess.run(
39 |             [outputs['support_embedding'], outputs['query_embedding']])
40 |         self.assertEqual(sup_emb.shape, (BATCH, support, EMB_SIZE))
41 |         self.assertEqual(que_emb.shape, (BATCH, query, EMB_SIZE))
42 | 
43 |     def test_no_state_action(self):
44 |         self._check_loss_and_accuracy(
45 |             3, 3, 'test_no_state_action', False, False)
46 | 
47 |     def test_with_state(self):
48 |         self._check_loss_and_accuracy(
49 |             3, 3, 'test_with_state', True, False)
50 | 
51 |     def test_with_state_and_action(self):
52 |         self._check_loss_and_accuracy(
53 |             3, 3, 'test_with_state_and_action', True, True)
54 | 
55 | 
56 | if __name__ == '__main__':
57 |     main()
--------------------------------------------------------------------------------
/test/data/test_data_sequencer.py:
--------------------------------------------------------------------------------
1 | from unittest import main, TestCase
2 | import tensorflow as tf
3 | from data.data_sequencer import DataSequencer
4 | import numpy as np
5 | 
6 | TIME = 10
7 | STATES = 2
8 | OUTS = 3
9 | 
10 | 
11 | class TestDataSequencer(TestCase):
12 | 
13 |     def _ram_data(self, sequence_strategy, sequence_num, scope):
14 |         ds = DataSequencer(sequence_strategy, TIME)
15 |         with tf.variable_scope(scope):
16 |             images = tf.random_uniform((TIME, 8, 8, 3))
17 |             states = tf.random_uniform((TIME, STATES))
18 |             outputs = tf.random_uniform((TIME, OUTS))
19 |             limgs, lstates, louts = ds.load(images, states, outputs)
20 |         sess = tf.InteractiveSession()
21 |         out_imgs, out_states, out_outs = sess.run([limgs, lstates, louts])
22 |         self.assertEqual(np.array(out_imgs).shape, (sequence_num, 8, 8, 3))
23 |         self.assertEqual(np.array(out_states).shape, (sequence_num, STATES))
24 |         self.assertEqual(np.array(out_outs).shape, (sequence_num, OUTS))
25 | 
26 |     def test_ram_first(self):
27 |         self._ram_data('first', 1, 'test_ram_first')
28 | 
29 |     def test_ram_last(self):
30 |         self._ram_data('last', 1, 'test_ram_last')
31 | 
32 |     def test_ram_first_last(self):
33 |         self._ram_data('first_last', 2, 'test_ram_first_last')
34 | 
35 |     def test_ram_all(self):
36 |         self._ram_data('all', TIME, 'test_ram_all')
37 | 
38 | 
39 | if __name__ == '__main__':
40 |     main()
--------------------------------------------------------------------------------
/test/data/test_datasets.py:
--------------------------------------------------------------------------------
1 | from unittest import main, TestCase
2 | from data.mil_sim_push import MilSimPush
3 | from data.mil_sim_reach import MilSimReach
4 | 
5 | 
6 | class TestDatasets(TestCase):
7 | 
8 |     def test_mil_sim_reach_train(self):
9 |         data = MilSimReach(training_size=10, validation_size=5)
10 |         train, validation = data.training_set()
11 |         self.assertEqual(len(train), 10)
12 |         self.assertEqual(len(validation), 5)
13 |         self.assertEqual(len(train[0]), 9)
14 |         self.assertEqual(len(validation[0]), 5)
15 |         self.assertIn('actions', train[0][0])
16 |         self.assertIn('states', train[0][0])
17 |         self.assertIn('image_files', train[0][0])
18 | 
19 |     def test_mil_sim_reach_test(self):
20 |         data = MilSimReach(training_size=10, validation_size=10)
21 |         test = data.test_set()
22 |         self.assertEqual(len(test), 150)
23 | 
24 |     def test_mil_sim_push_train(self):
25 |         data = MilSimPush(training_size=10, validation_size=5)
26 |         train, validation = data.training_set()
27 |         self.assertEqual(len(train), 10)
28 |         self.assertEqual(len(validation), 5)
29 |         self.assertEqual(len(train[0]), 12)
30 |         self.assertEqual(len(validation[0]), 12)
31 |         self.assertIn('actions', train[0][0])
32 |         self.assertIn('states', train[0][0])
33 |         self.assertIn('image_files', train[0][0])
34 | 
35 |     def test_mil_sim_push_test(self):
36 |         data = MilSimPush(training_size=10, validation_size=10)
37 |         test = data.test_set()
38 |         self.assertEqual(len(test), 74)
39 | 
40 | 
41 | if __name__ == '__main__':
42 |     main()
43 | 
--------------------------------------------------------------------------------
/test/data/test_generator.py:
--------------------------------------------------------------------------------
1 | from unittest import main, TestCase
2 | from unittest.mock import MagicMock
3 | from data.generator import Generator
4 | import numpy as np
5 | import tensorflow as tf
6 | import os
7 | from data.data_sequencer import DataSequencer
8 | 
9 | 
10 | TIME_HORIZON = 4
11 | BATCH_SIZE = 3
12 | SUPPORT_SIZE = 2
13 | QUERY_SIZE = 2
14 | TASKS = 20
15 | EXAMPLES = 6
16 | IMG_SHAPE = (8, 8, 3)
17 | STATE_SIZE = 11
18 | OUTPUT_SIZE = 9
19 | 
20 | 
21 | class TestGenerator(TestCase):
22 | 
23 |     def _fake_dataset_load(self, tasks, examples):
24 |         fake_folder = os.path.join(
25 |             os.path.dirname(os.path.realpath(__file__)),
26 |             '../test_data', 'test_task')
27 |         data = [[{
28 |             'image_files': fake_folder,
29 |             'states': np.ones((TIME_HORIZON, STATE_SIZE)),
30 |             'actions': np.ones((TIME_HORIZON, OUTPUT_SIZE))
31 |         } for _ in range(examples)] for _ in range(tasks)]
32 |         # Return fake train and validation data
33 |         return data, data
34 | 
35 |     def _fake_dataset(self, tasks, examples):
36 |         dataset = MagicMock()
37 |         dataset.time_horizon = TIME_HORIZON
38 |         dataset.training_set = MagicMock(
39 |             return_value=self._fake_dataset_load(tasks, examples))
40 |         return dataset
41 | 
42 |     # TODO: Should move the data_sequencer code to the correct test class
43 |     def _embedding_strategy(self, scope, strategy, frames,
44 |                             batch_size=BATCH_SIZE, support_size=SUPPORT_SIZE,
45 |                             query_size=QUERY_SIZE):
46 |         dataset = self._fake_dataset(TASKS, EXAMPLES)
47 |         data_seq = DataSequencer(strategy, TIME_HORIZON)
48 |         gen = Generator(dataset, batch_size, support_size, query_size,
49 |                         data_sequencer=data_seq)
50 |         with tf.variable_scope(scope):
51 |             sess = tf.InteractiveSession()
52 |             train_handle, val_handle = gen.get_handles(sess)
53 |             (embed_images, embnet_states, embnet_outputs, ctrnet_images,
54 |              ctrnet_states, ctrnet_outputs) = sess.run(
55 |                 gen.next_element, feed_dict={gen.handle: train_handle})
56 |         self.assertEqual(
57 |             embed_images.shape,
58 |             (batch_size, support_size + query_size, frames) + IMG_SHAPE)
59 |         self.assertEqual(
60 |             embnet_states.shape,
61 |             (batch_size, support_size + query_size, frames, STATE_SIZE))
62 |         self.assertEqual(
63 |             embnet_outputs.shape,
64 |             (batch_size, support_size + query_size, frames, OUTPUT_SIZE))
65 |         self.assertEqual(ctrnet_images.shape, (batch_size, 2) + IMG_SHAPE)
66 |         self.assertEqual(ctrnet_states.shape, (batch_size, 2, STATE_SIZE))
67 |         self.assertEqual(ctrnet_outputs.shape, (batch_size, 2, OUTPUT_SIZE))
68 | 
69 |     def test_first_frame_embedding(self):
70 |         self._embedding_strategy('test_first_frame_embedding', 'first', 1)
71 | 
72 |     def test_last_frame_embedding(self):
73 |         self._embedding_strategy('test_last_frame_embedding', 'last', 1)
74 | 
75 |     def test_first_last_frame_embedding(self):
76 |         self._embedding_strategy('test_first_last_frame_embedding',
77 |                                  'first_last', 2)
78 | 
79 |     def test_all_frame_embedding(self):
80 |         self._embedding_strategy('test_all_frame_embedding', 'all',
81 |                                  TIME_HORIZON)
82 | 
83 |     def test_invalid_frame_embedding_throws_error(self):
84 |         with self.assertRaises(ValueError):
85 |             self._embedding_strategy(
86 |                 'test_invalid_frame_embedding_throws_error', 'invalid', 1)
87 | 
88 |     def test_support_and_query_more_than_samples(self):
89 |         with self.assertRaises(Exception):
90 |             self._embedding_strategy(
91 |                 'test_support_and_query_more_than_samples', 'first',
92 |                 1, support_size=TIME_HORIZON+1)
93 | 
94 | 
95 | if __name__ == '__main__':
96 |     main()
97 | 
--------------------------------------------------------------------------------
/test/evaluation/test_evalulation.py:
--------------------------------------------------------------------------------
1 | from unittest import main, TestCase
2 | import tensorflow as tf
3 | from evaluation.eval_mil_reach import EvalMilReach
4 | from evaluation.eval_mil_push import EvalMilPush
5 | from consumers import eval_consumer
6 | from data import data_sequencer
7 | from data.mil_sim_reach import MilSimReach
8 | from data.mil_sim_push import MilSimPush
9 | 
10 | 
11 | class TestEvaluation(TestCase):
12 | 
13 |     def test_sim_reach(self):
14 | 
15 |         data = MilSimReach()
16 |         with tf.variable_scope('test_sim_reach'):
17 |             data_seq = data_sequencer.DataSequencer('first_last',
18 |                                                     data.time_horizon)
19 |             eval_con = eval_consumer.EvalConsumer(data, data_seq, 2)
20 |             outputs = eval_con.consume({})
21 |             outputs['sentences'] = tf.ones((1, 20))
22 |             outputs['output_actions'] = tf.ones((1, data.action_size))
23 |             sess = tf.InteractiveSession()
24 |             tf.global_variables_initializer().run()
25 |             eval = EvalMilReach(sess, dataset=data, outputs=outputs, supports=2,
26 |                                 num_tasks=2, num_trials=2, log_dir=".",
27 |                                 record_gifs=False, render=False)
28 |             eval.evaluate(0)
29 | 
30 |     def test_sim_push(self):
31 | 
32 |         data = MilSimPush()
33 |         with tf.variable_scope('test_sim_push'):
34 |             data_seq = data_sequencer.DataSequencer('first_last',
35 |                                                     data.time_horizon)
36 |             eval_con = eval_consumer.EvalConsumer(data, data_seq, 2)
37 |             outputs = eval_con.consume({})
38 |             outputs['sentences'] = tf.ones((1, 20))
39 |             outputs['output_actions'] = tf.ones((1, data.action_size))
40 |             sess = tf.InteractiveSession()
41 |             tf.global_variables_initializer().run()
42 |             eval = EvalMilPush(sess, dataset=data, outputs=outputs, supports=2,
43 |                                num_tasks=2, num_trials=2, log_dir=".",
44 |                                record_gifs=False, render=False)
45 |             eval.evaluate(0)
46 | 
47 | 
48 | if __name__ == '__main__':
49 |     main()
50 | 
--------------------------------------------------------------------------------
/test/integration/test_integration.py:
--------------------------------------------------------------------------------
1 | from unittest import main, TestCase
2 | from evaluation.eval_mil_reach import EvalMilReach
3 | from consumers.eval_consumer import EvalConsumer
4 | from consumers.control import Control
5 | from consumers.generator_consumer import GeneratorConsumer
6 | from consumers.imitation_loss import ImitationLoss
7 | from consumers.margin_loss import MarginLoss
8 | from consumers.task_embedding import TaskEmbedding
9 | from data.data_sequencer import DataSequencer
10 | from data.generator import Generator
11 | from data.mil_sim_push import MilSimPush
12 | from data.mil_sim_reach import MilSimReach
13 | from networks.cnn import CNN
14 | from trainers.il_trainer import ILTrainer
15 | from trainers.pipeline import Pipeline
16 | import tensorflow as tf
17 | 
18 | 
19 | class TestIntegration(TestCase):
20 | 
21 |     def _default_pipeline(self, dataset, q_s_size=2):
22 |         support_size = query_size = q_s_size
23 |         net = CNN(filters=[4, 4, 4, 4], fc_layers=[20, 20],
24 |                   kernel_sizes=[3, 3, 3, 3], strides=[2, 2, 2, 2],
25 |                   max_pool=False, norm=None, activation='relu')
26 | 
27 |         sequencer = DataSequencer('first_last', dataset.time_horizon)
28 |         gen = Generator(dataset=dataset, batch_size=4,
29 |                         support_size=support_size, query_size=query_size,
30 |                         data_sequencer=sequencer)
31 | 
32 |         gen_con = GeneratorConsumer(gen, dataset, support_size, query_size)
33 |         task_emb = TaskEmbedding(
34 |             network=net, embedding_size=6, support_size=support_size,
35 |             query_size=query_size, include_state=False, include_action=False)
36 |         ml = MarginLoss(margin=0.1)
37 |         ctr = Control(network=net, action_size=dataset.action_size,
38 |                       include_state=True)
39 |         il = ImitationLoss()
40 |         consumers = [gen_con, task_emb, ml, ctr, il]
41 |         p = Pipeline(consumers)
42 |         outputs = p.get_outputs()
43 |         trainer = ILTrainer(pipeline=p, outputs=outputs,
44 |                             generator=gen, iterations=10)
45 |         trainer.train()
46 | 
47 |     def _default_evaluation(self, dataset, q_s_size=2, disk_images=True):
48 |         support_size = query_size = q_s_size
49 |         net = CNN(filters=[4, 4, 4, 4], fc_layers=[20, 20],
50 |                   kernel_sizes=[3, 3, 3, 3], strides=[2, 2, 2, 2],
51 |                   max_pool=False, norm=None, activation='relu')
52 |         sequencer = DataSequencer('first_last', dataset.time_horizon)
53 |         eval_cons = EvalConsumer(dataset, sequencer, support_size, disk_images)
54 |         task_emb = TaskEmbedding(
55 |             network=net, embedding_size=6, support_size=support_size,
56 |             query_size=query_size, include_state=False, include_action=False)
57 |         ml = MarginLoss(margin=0.1)
58 |         ctr = Control(network=net, action_size=dataset.action_size,
59 |                       include_state=True)
60 |         consumers = [eval_cons, task_emb, ml, ctr]
61 |         p = Pipeline(consumers)
62 |         outs = p.get_outputs()
63 |         return outs, p.get_session()
64 | 
65 |     def test_imitation_learning_mil_reach(self):
66 |         data = MilSimReach()
67 |         with tf.variable_scope('test_imitation_learning_mil_reach'):
68 |             self._default_pipeline(data, 1)
69 | 
70 |     def test_imitation_learning_mil_push(self):
71 |         data = MilSimPush()
72 |         with tf.variable_scope('test_imitation_learning_mil_push'):
73 |             self._default_pipeline(data)
74 | 
75 |     def test_eval_mil_reach(self):
76 |         data = MilSimReach()
77 |         with tf.variable_scope('test_eval_mil_reach'):
78 |             outs, sess = self._default_evaluation(data, 2)
79 |             eval = EvalMilReach(sess=sess,
80 |                                 dataset=data,
81 |                                 outputs=outs,
82 |                                 supports=2,
83 |                                 num_tasks=2,
84 |                                 num_trials=2,
85 |                                 record_gifs=False,
86 |                                 render=False)
87 |             tf.global_variables_initializer().run()
88 |             eval.evaluate(0)
89 | 
90 | 
91 | if __name__ == '__main__':
92 |     main()
93 | 
--------------------------------------------------------------------------------
/test/networks/test_networks.py:
--------------------------------------------------------------------------------
1 | from unittest import main, TestCase
2 | import tensorflow as tf
3 | import numpy as np
4 | from networks.cnn import CNN
5 | from networks.input_output import *
6 | 
7 | 
8 | class TestNetwork(TestCase):
9 |     """
10 |     Tests that everything to do with the network is configured correctly.
11 |     """
12 | 
13 |     def _inputs(self):
14 |         action = NetworkHead('action', 3, None)
15 |         pose = NetworkHead('pose', 2, 'relu')
16 |         img_in = NetworkInput('img_img', 'conv', 0,
17 |                               tf.placeholder(tf.float32, (None, 8, 8, 3)))
18 |         angles_in = NetworkInput('angles', 'fc', 0,
19 |                                  tf.placeholder(tf.float32, (None, 2)),
20 |                                  'concat', axis=-1)
21 |         return action, pose, img_in, angles_in
22 | 
23 |     def _basic_cnn(self):
24 |         action, pose, img_in, angles_in = self._inputs()
25 |         net = CNN(filters=[8, 16], fc_layers=[20, 20], kernel_sizes=[3, 3],
26 |                   strides=[2, 2], max_pool=False, norm=None, activation='relu')
27 |         outputs = net.forward([img_in, angles_in], [action, pose],
28 |                               training=None)
29 |         return outputs, img_in, angles_in
30 | 
31 |     def _norm_cnn(self, norm, scope):
32 |         action, pose, img_in, angles_in = self._inputs()
33 |         net = CNN(filters=[8, 16], fc_layers=[20, 20], kernel_sizes=[3, 3],
34 |                   strides=[2, 2], max_pool=False, norm=norm,
35 |                   activation='relu')
36 | 
37 |         with tf.variable_scope(scope):
38 |             train = tf.placeholder(tf.bool)
39 |             outputs = net.forward([img_in, angles_in], [action, pose],
40 |                                   training=train)
41 |             sess = tf.InteractiveSession()
42 |             tf.global_variables_initializer().run()
43 |             action, pose = sess.run(
44 |                 [outputs['action'], outputs['pose']], feed_dict={
45 |                     img_in.tensor: np.ones((1, 8, 8, 3)),
46 |                     angles_in.tensor: np.ones((1, 2)),
47 |                     train: True
48 |                 })
49 |             self.assertEqual(action.shape, (1, 3,))
50 |             self.assertEqual(pose.shape, (1, 2,))
51 | 
52 |     def test_construct_model(self):
53 |         with tf.variable_scope('test_construct_model'):
54 |             outputs, _, _ = self._basic_cnn()
55 |             self.assertTrue('action' in outputs)
56 |             self.assertTrue('pose' in outputs)
57 | 
58 |     def test_model_forward_pass(self):
59 |         with tf.variable_scope('test_model_forward_pass'):
60 |             outputs, img_in, angles_in = self._basic_cnn()
61 |             sess = tf.InteractiveSession()
62 |             tf.global_variables_initializer().run()
63 |             action, pose = sess.run(
64 |                 [outputs['action'], outputs['pose']], feed_dict={
65 |                     img_in.tensor: np.ones((1, 8, 8, 3)),
66 |                     angles_in.tensor: np.ones((1, 2))
67 |                 })
68 |             self.assertEqual(action.shape, (1, 3,))
69 |             self.assertEqual(pose.shape, (1, 2,))
70 | 
71 |     def test_model_with_batchnorm(self):
72 |         self._norm_cnn('batch', 'test_model_with_batchnorm')
73 | 
74 |     def test_model_with_layernorm(self):
75 |         self._norm_cnn('layer', 'test_model_with_layernorm')
76 | 
77 | 
78 | if __name__ == '__main__':
79 |     main()
--------------------------------------------------------------------------------
/test/test_data/test_task/0.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stepjam/TecNets/bf885956fd45b601ea0a820c124d70702e88ac7a/test/test_data/test_task/0.gif
--------------------------------------------------------------------------------
/test/test_data/test_task/1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stepjam/TecNets/bf885956fd45b601ea0a820c124d70702e88ac7a/test/test_data/test_task/1.gif
--------------------------------------------------------------------------------
/test/test_data/test_task/2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stepjam/TecNets/bf885956fd45b601ea0a820c124d70702e88ac7a/test/test_data/test_task/2.gif
--------------------------------------------------------------------------------
/test/test_data/test_task/3.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stepjam/TecNets/bf885956fd45b601ea0a820c124d70702e88ac7a/test/test_data/test_task/3.gif
--------------------------------------------------------------------------------
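The suite above is plain unittest; one way to run everything from the repository root (the data, evaluation, and integration tests additionally need the datasets in place and MuJoCo available):

    python3 -m unittest discover -s test -p 'test_*.py'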
/trainers/il_trainer.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import time
3 | 
4 | VAL_SUMMARY_INTERVAL = 100
5 | SUMMARY_INTERVAL = 100
6 | SAVE_INTERVAL = 10000
7 | EVAL_INTERVAL = 25000
8 | 
9 | 
10 | class ILTrainer(object):
11 | 
12 |     def __init__(self, pipeline, outputs, generator, iterations,
13 |                  summary_writer=None, eval=None):
14 |         self.pipeline = pipeline
15 |         self.generator = generator
16 |         self.outputs = outputs
17 |         self.iterations = iterations
18 |         self.summary_writer = summary_writer
19 |         self.eval = eval
20 | 
21 |         if eval is not None:
22 |             # Convenience for plotting eval successes in tensorboard
23 |             self.eval_summary_in = tf.placeholder(tf.float32)
24 |             self.eval_summary = tf.summary.scalar('evaluation_success',
25 |                                                   self.eval_summary_in)
26 | 
27 |     def train(self):
28 | 
29 |         sess = self.pipeline.get_session()
30 |         train_handle, validation_handle = self.generator.get_handles(sess)
31 | 
32 |         outputs = self.outputs
33 |         total_loss = self.pipeline.get_loss()
34 |         train_op = self.pipeline.get_train_op(total_loss)
35 |         train_summaries = self.pipeline.get_summaries('train')
36 |         validation_summaries = self.pipeline.get_summaries('validation')
37 | 
38 |         tf.global_variables_initializer().run()
39 | 
40 |         # Load if we have supplied a checkpoint
41 |         resume_itr = self.pipeline.load()
42 | 
43 |         print('Setup Complete. Starting training...')
44 | 
45 |         for itr in range(resume_itr, self.iterations + 1):
46 | 
47 |             fetches = [train_op]
48 | 
49 |             feed_dict = {
50 |                 self.generator.handle: train_handle,
51 |                 outputs['training']: True
52 |             }
53 | 
54 |             if itr % SUMMARY_INTERVAL == 0:
55 |                 fetches.append(total_loss)
56 |                 if self.summary_writer is not None:
57 |                     fetches.append(train_summaries)
58 | 
59 |             start = time.time()
60 |             result = sess.run(fetches, feed_dict)
61 | 
62 |             if itr % SUMMARY_INTERVAL == 0:
63 |                 print('Summary iter', itr, '| Loss:',
64 |                       result[1], '| Time:', time.time() - start)
65 |                 if self.summary_writer is not None:
66 |                     self.summary_writer.add_summary(sess, result[-1], itr)
67 | 
68 |             if (itr % VAL_SUMMARY_INTERVAL == 0 and
69 |                     self.summary_writer is not None):
70 |                 feed_dict = {
71 |                     self.generator.handle: validation_handle,
72 |                     outputs['training']: False
73 |                 }
74 |                 result = sess.run([validation_summaries], feed_dict)
75 |                 self.summary_writer.add_summary(sess, result[0], itr)
76 | 
77 |             if itr % EVAL_INTERVAL == 0 and itr > 1 and self.eval is not None:
78 |                 acc = self.eval.evaluate(itr)
79 |                 print('Evaluation at iter %d. Success rate: %.2f' % (itr, acc))
80 |                 if self.summary_writer is not None:
81 |                     eval_success = sess.run(
82 |                         self.eval_summary, {self.eval_summary_in: acc})
83 |                     self.summary_writer.add_summary(sess, eval_success, itr)
84 | 
85 |             if itr % SAVE_INTERVAL == 0:
86 |                 self.pipeline.save(itr)
87 | 
--------------------------------------------------------------------------------
/trainers/pipeline.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | 
3 | 
4 | class Pipeline(object):
5 | 
6 |     def __init__(self, consumers, learning_rate=1e-4, grad_clip=None,
7 |                  saver=None, loader=None):
8 |         self.consumers = consumers
9 |         self.grad_clip = grad_clip
10 |         self.opt = tf.train.AdamOptimizer(learning_rate)
11 |         self.total_loss = None
12 |         self.sess = None
13 |         self.gradients = None
14 |         self.saver = saver
15 |         self.loader = loader
16 |         self.got_outputs = False
17 | 
18 |     def get_session(self):
19 |         if self.sess is None:
20 |             self.sess = tf.InteractiveSession()
21 |         return self.sess
22 | 
23 |     def get_outputs(self):
24 |         outputs = {}
25 |         for consumer in self.consumers:
26 |             outputs = consumer.consume(outputs)
27 |         self.got_outputs = True
28 |         return outputs
29 | 
30 |     def load(self):
31 |         if self.loader is None:
32 |             return 1
33 |         if not self.got_outputs:
34 |             raise RuntimeError(
35 |                 'get_outputs() needs to be called before loading a model.')
36 |         return self.loader.load(self.get_session())
37 | 
38 |     def save(self, itr):
39 |         if self.saver is not None:
40 |             self.saver.save(self.get_session(), itr)
41 | 
42 |     def get_summaries(self, prefix):
43 |         if self.total_loss is None:
44 |             self.get_loss()
45 |         summaries = [tf.summary.scalar(prefix + '_total_loss', self.total_loss)]
46 |         for consumer in self.consumers:
47 |             summaries.append(consumer.get_summaries(prefix))
48 | 
49 |         if self.gradients is None:
50 |             raise RuntimeError('Call get_train_op before this.')
51 |         for grad, var in self.gradients:
52 |             summaries.append(tf.summary.histogram(var.name, var))
53 |             summaries.append(tf.summary.histogram(var.name + '/gradient', grad))
54 | 
55 |         return tf.summary.merge(summaries)
56 | 
57 |     def get_loss(self):
58 |         loss = 0
59 |         for consumer in self.consumers:
60 |             loss += consumer.get_loss()
61 |         self.total_loss = loss
62 |         return loss
63 | 
64 |     def get_train_op(self, loss):
65 |         # Equivalent to self.opt.compute_gradients(loss); kept explicit for clipping.
66 |         gradients = tf.gradients(loss, tf.trainable_variables())
67 |         self.gradients = list(zip(gradients, tf.trainable_variables()))
68 |         if self.grad_clip is not None:
69 |             self.gradients = [
70 |                 (tf.clip_by_value(grad, -self.grad_clip, self.grad_clip)
71 |                  if grad is not None else grad, var)
72 |                 for grad, var in self.gradients]
73 |         return self.opt.apply_gradients(self.gradients)
74 | 
--------------------------------------------------------------------------------
/trainers/summary_writer.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import os
3 | 
4 | 
5 | class SummaryWriter(object):
6 | 
7 |     def __init__(self, log_dir):
8 |         self.log_dir = log_dir
9 |         if not os.path.exists(self.log_dir):
10 |             os.makedirs(self.log_dir)
11 |         self.writer = None
12 | 
13 |     def add_summary(self, sess, data, itr):
14 |         if self.writer is None:
15 |             self.writer = tf.summary.FileWriter(self.log_dir, sess.graph)
16 |         self.writer.add_summary(data, itr)
17 | 
--------------------------------------------------------------------------------
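A minimal end-to-end sketch of the Pipeline contract (the consumer below is a hypothetical stand-in for the real classes in consumers/; note that get_train_op must be called before get_summaries, since the gradient histograms need the computed gradients):

    import tensorflow as tf
    from trainers.pipeline import Pipeline

    class ConstLossConsumer(object):
        # Hypothetical stand-in consumer: one variable, squared loss.
        def consume(self, outputs):
            self.v = tf.get_variable('v', initializer=3.0)
            return outputs
        def get_loss(self):
            return tf.square(self.v)
        def get_summaries(self, prefix):
            return tf.summary.scalar(prefix + '_v', self.v)

    p = Pipeline([ConstLossConsumer()], learning_rate=0.1, grad_clip=1.0)
    outs = p.get_outputs()                   # builds the graph
    train_op = p.get_train_op(p.get_loss())  # must come before get_summaries
    summaries = p.get_summaries('train')
    sess = p.get_session()
    tf.global_variables_initializer().run()
    sess.run([train_op, summaries])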