├── .gitignore ├── LICENSE ├── README.md ├── policy_based ├── Dockerfile ├── README.md ├── atari_reset │ ├── LICENSE │ ├── __init__.py │ └── atari_reset │ │ ├── __init__.py │ │ ├── policies.py │ │ ├── ppo.py │ │ └── wrappers.py ├── goexplore_py │ ├── __init__.py │ ├── archives.py │ ├── basics.py │ ├── cell_representations.py │ ├── data_classes.py │ ├── experiment_settings.py │ ├── explorers.py │ ├── ge_models.py │ ├── ge_policies.py │ ├── ge_runners.py │ ├── ge_wrappers.py │ ├── generic_atari_env.py │ ├── generic_goal_conditioned_env.py │ ├── globals.py │ ├── goal_representations.py │ ├── goexplore.py │ ├── import_ai.py │ ├── logger.py │ ├── montezuma_env.py │ ├── mpi_support.py │ ├── pitfall_env.py │ ├── profiler.py │ ├── randselectors.py │ ├── trajectory_gatherers.py │ ├── trajectory_manager.py │ ├── trajectory_trackers.py │ └── utils.py ├── goexplore_start.py ├── requirements.txt ├── run_policy_based_ge_montezuma.sh └── run_policy_based_ge_pitfall.sh └── robustified ├── README.md ├── control_im_fetch.sh ├── control_ppo_fetch.sh ├── gen_demo ├── __init__.py ├── atari_demo │ ├── LICENSE │ ├── __init__.py │ ├── cloned_vec_env.py │ ├── utils.py │ └── wrappers.py └── new_gen_demo.py ├── gen_demo_atari.sh ├── gen_demo_fetch.sh ├── goexplore_py ├── __init__.py ├── basics.py ├── complex_fetch_env.py ├── explorers.py ├── fetch_xml │ ├── LICENSE │ ├── README.md │ ├── box │ │ ├── asset.xml │ │ └── chain.xml │ ├── door │ │ ├── asset.xml │ │ ├── chain0.xml │ │ └── chain1.xml │ ├── fetch_maneuver.xml │ ├── fetch_pole.xml │ ├── gallery │ │ ├── boxes.JPG │ │ ├── maneuver.JPG │ │ ├── objects.JPG │ │ └── pole.JPG │ ├── robot │ │ └── fetch │ │ │ ├── asset.xml │ │ │ ├── base_link_collision.stl │ │ │ ├── bellows_link_collision.stl │ │ │ ├── chain.xml │ │ │ ├── elbow_flex_link_collision.stl │ │ │ ├── estop_link.stl │ │ │ ├── forearm_roll_link_collision.stl │ │ │ ├── gripper_link.stl │ │ │ ├── head_pan_link_collision.stl │ │ │ ├── head_tilt_link_collision.stl │ │ │ ├── l_gripper_finger_link.stl │ │ │ ├── l_wheel_link.stl │ │ │ ├── l_wheel_link_collision.stl │ │ │ ├── laser_link.stl │ │ │ ├── main.xml │ │ │ ├── r_gripper_finger_link.stl │ │ │ ├── r_wheel_link_collision.stl │ │ │ ├── shoulder_lift_link_collision.stl │ │ │ ├── shoulder_pan_link_collision.stl │ │ │ ├── torso_fixed_link.stl │ │ │ ├── torso_lift_link_collision.stl │ │ │ ├── upperarm_roll_link_collision.stl │ │ │ ├── wrist_flex_link_collision.stl │ │ │ └── wrist_roll_link_collision.stl │ ├── shelf │ │ └── chain.xml │ ├── table │ │ ├── asset.xml │ │ └── chain.xml │ ├── teleOp_boxes.xml │ ├── teleOp_boxes_1.xml │ ├── teleOp_objects.xml │ └── texture │ │ ├── marble.png │ │ ├── small_placeholder_2d.png │ │ ├── small_placeholder_cube.png │ │ └── wood.png ├── generic_atari_env.py ├── generic_goal_conditioned_env.py ├── goexplore.py ├── import_ai.py ├── main.py ├── montezuma_env.py ├── notebook_utils.py ├── pitfall_env.py ├── randselectors.py ├── utils.py └── visualize.py ├── phase1_downscaled.sh ├── phase1_fetch.sh ├── phase1_montezuma.sh ├── phase1_pitfall.sh ├── phase2_atari.sh ├── phase2_atari_test.sh ├── phase2_fetch.sh └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.pyc 3 | *~ 4 | .idea 5 | results 6 | .ipynb_checkpoints 7 | .DS_Store 8 | demos_* 9 | *.png 10 | *.mp4 11 | to_do 12 | to_kill 13 | *.demo 14 | *.tar.gz 15 | *.jobs 16 | *.monitor.csv 17 | .unison* 18 | *.orig 19 | *~master 20 | 
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | "License" shall mean the terms and conditions for use, reproduction, 2 | and distribution as defined by the text below. 3 | 4 | "You" (or "Your") shall mean an individual or Legal Entity exercising 5 | permissions granted by this License. 6 | 7 | "Legal Entity" shall mean the union of the acting entity and all other 8 | entities that control, are controlled by, or are under common control 9 | with that entity. For the purposes of this definition, "control" means 10 | (i) the power, direct or indirect, to cause the direction or management 11 | of such entity, whether by contract or otherwise, or (ii) ownership of 12 | fifty percent (50%) or more of the outstanding shares, or (iii) 13 | beneficial ownership of such entity. 14 | 15 | "Source" form shall mean the preferred form for making modifications, 16 | including but not limited to software source code, documentation source, 17 | and configuration files. 18 | 19 | "Object" form shall mean any form resulting from mechanical transformation 20 | or translation of a Source form, including but not limited to compiled 21 | object code, generated documentation, and conversions to other media types. 22 | 23 | "Work" shall mean the work of authorship, whether in Source or Object form, 24 | made available under this License. 25 | 26 | This License governs use of the accompanying Work, and your use of the Work 27 | constitutes acceptance of this License. 28 | 29 | You may use this Work for any non-commercial purpose, subject to the 30 | restrictions in this License. Some purposes which can be non-commercial 31 | are teaching, academic research, and personal experimentation. You may also 32 | distribute this Work with books or other teaching materials, or publish the 33 | Work on websites, that are intended to teach the use of the Work. 34 | 35 | You may not use or distribute this Work, or any derivative works, outputs, 36 | or results from the Work, in any form for commercial purposes. Non-exhaustive 37 | examples of commercial purposes would be running business operations, licensing, 38 | leasing, or selling the Work, or distributing the Work for use with commercial 39 | products. 40 | 41 | You may modify this Work and distribute the modified Work for non-commercial 42 | purposes, however, you may not grant rights to the Work or derivative works 43 | that are broader than or in conflict with those provided by this License. 44 | For example, you may not distribute modifications of the Work under terms 45 | that would permit commercial use, or under terms that purport to require 46 | the Work or derivative works to be sublicensed to others. 47 | 48 | In return, we require that you agree: 49 | 50 | 1. Not to remove any copyright or other notices from the Work. 51 | 52 | 2. That if you distribute the Work in Source or Object form, you will include 53 | a verbatim copy of this License. 54 | 55 | 3. That if you distribute derivative works of the Work in Source form, you do 56 | so only under a license that includes all of the provisions of this License 57 | and is not in conflict with this License, and if you distribute derivative 58 | works of the Work solely in Object form you do so only under a license that 59 | complies with this License. 60 | 61 | 4. 
That if you have modified the Work or created derivative works from the 62 | Work, and distribute such modifications or derivative works, you will cause 63 | the modified files to carry prominent notices so that recipients know that 64 | they are not receiving the original Work. Such notices must state: (i) that 65 | you have changed the Work; and (ii) the date of any changes. 66 | 67 | 5. If you publicly use the Work or any output or result of the Work, you will 68 | provide a notice with such use that provides any person who uses, views, 69 | accesses, interacts with, or is otherwise exposed to the Work (i) with 70 | information of the nature of the Work, (ii) with a link to the Work, and 71 | (iii) a notice that the Work is available under this License. 72 | 73 | 6. THAT THE WORK COMES "AS IS", WITH NO WARRANTIES. THIS MEANS NO EXPRESS, 74 | IMPLIED OR STATUTORY WARRANTY, INCLUDING WITHOUT LIMITATION, WARRANTIES OF 75 | MERCHANTABILITY OR FITNESS FOR A PARTICULAR PURPOSE OR ANY WARRANTY OF TITLE 76 | OR NON-INFRINGEMENT. ALSO, YOU MUST PASS THIS DISCLAIMER ON WHENEVER YOU 77 | DISTRIBUTE THE WORK OR DERIVATIVE WORKS. 78 | 79 | 7. THAT NEITHER UBER TECHNOLOGIES, INC. NOR ANY OF ITS AFFILIATES, SUPPLIERS, 80 | SUCCESSORS, NOR ASSIGNS WILL BE LIABLE FOR ANY DAMAGES RELATED TO THE WORK OR 81 | THIS LICENSE, INCLUDING DIRECT, INDIRECT, SPECIAL, CONSEQUENTIAL OR INCIDENTAL 82 | DAMAGES, TO THE MAXIMUM EXTENT THE LAW PERMITS, NO MATTER WHAT LEGAL THEORY IT 83 | IS BASED ON. ALSO, YOU MUST PASS THIS LIMITATION OF LIABILITY ON WHENEVER YOU 84 | DISTRIBUTE THE WORK OR DERIVATIVE WORKS. 85 | 86 | 8. That if you sue anyone over patents that you think may apply to the Work or 87 | anyone's use of the Work, your license to the Work ends automatically. 88 | 89 | 9. That your rights under the License end automatically if you breach it 90 | in any way. 91 | 92 | 10. Uber Technologies, Inc. reserves all rights not expressly granted to you 93 | in this License. 94 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Go-Explore 2 | 3 | This is the code for [First return then explore](https://arxiv.org/abs/2004.12919), the new Go-explore paper. Code for the [original paper](https://arxiv.org/abs/1901.10995) can be found in this repository under the tag "v1.0" or the release "Go-Explore v1". 4 | 5 | The code for Go-Explore with a deterministic exploration phase followed by a robustification phase is located in the `robustified` subdirectory. The code for Go-Explore with a policy-based exploration phase is located in the `policy_based` subdirectory. Installation instructions for each variant of Go-Explore can be found in their respective directories. 6 | -------------------------------------------------------------------------------- /policy_based/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.6-buster 2 | ADD requirements.txt /root/requirements.txt 3 | RUN apt-get update -q \ 4 | && DEBIAN_FRONTEND=noninteractive apt-get install -y \ 5 | openmpi-bin \ 6 | && apt-get clean \ 7 | && rm -rf /var/lib/apt/lists/* 8 | RUN cat /root/requirements.txt | xargs -n 1 -L 1 pip install 9 | ADD . 
/root/ 10 | -------------------------------------------------------------------------------- /policy_based/README.md: -------------------------------------------------------------------------------- 1 | # Policy-based Go-Explore 2 | 3 | Code accompanying the paper "First return then explore", available at: [arxiv.org/abs/2004.12919](https://arxiv.org/abs/2004.12919) 4 | 5 | 6 | ## Requirements 7 | 8 | Tested with Python 3.6. The required libraries are listed below and in `requirements.txt`. Libraries need to be installed in the specified order. Unless otherwise specified, libraries can be installed using `pip install <library_name>`. 9 | 10 | **Required libraries:** 11 | - tensorflow==1.15.2 12 | - mpi4py 13 | - gym[Atari] 14 | - horovod 15 | - baselines@git+https://github.com/openai/baselines@ea25b9e8b234e6ee1bca43083f8f3cf974143998 16 | - Pillow 17 | - imageio 18 | - matplotlib 19 | - loky 20 | - joblib 21 | - dataclasses 22 | - opencv-python 23 | - cloudpickle 24 | 25 | ## Usage 26 | 27 | To test that everything is installed correctly, a local run of policy-based Go-Explore on Montezuma's Revenge or Pitfall can be started by executing `run_policy_based_ge_montezuma.sh` or `run_policy_based_ge_pitfall.sh`, respectively. To reproduce the experiments presented in the aforementioned paper, open each file and change the following settings to: 28 | 29 | ``` 30 | NB_MPI_WORKERS=16 31 | NB_ENVS_PER_WORKER=16 32 | SEED=0 33 | CHECKPOINT=200000000 34 | ``` 35 | 36 | The seed should be changed for each run. Note that, to run efficiently, this code needs to be executed in a compute environment where each worker has access to a GPU. By default, results will be written to `~/temp`, though this can be changed by editing the `sh` files. 37 | -------------------------------------------------------------------------------- /policy_based/atari_reset/LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License 2 | 3 | Copyright (c) 2018 OpenAI (http://openai.com) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE.
22 | -------------------------------------------------------------------------------- /policy_based/atari_reset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/go-explore/702fb9c7a9aeecf2872d07ced236c730d5536a8f/policy_based/atari_reset/__init__.py -------------------------------------------------------------------------------- /policy_based/atari_reset/atari_reset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/go-explore/702fb9c7a9aeecf2872d07ced236c730d5536a8f/policy_based/atari_reset/atari_reset/__init__.py -------------------------------------------------------------------------------- /policy_based/atari_reset/atari_reset/policies.py: -------------------------------------------------------------------------------- 1 | """ 2 | // Modifications Copyright (c) 2020 Uber Technologies Inc. 3 | """ 4 | 5 | import numpy as np 6 | import tensorflow as tf 7 | from tensorflow.nn import rnn_cell 8 | from baselines.common.distributions import make_pdtype 9 | import logging 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | def to2d(x): 14 | size = 1 15 | for shapel in x.get_shape()[1:]: 16 | size *= shapel.value 17 | return tf.reshape(x, (-1, size)) 18 | 19 | 20 | def normc_init(std=1.0, axis=0): 21 | """ 22 | Initialize with normalized columns 23 | """ 24 | 25 | # noinspection PyUnusedLocal 26 | def _initializer(shape, dtype=None, partition_info=None): # pylint: disable=W0613 27 | out = np.random.randn(*shape).astype(np.float32) 28 | out *= std / np.sqrt(np.square(out).sum(axis=axis, keepdims=True)) 29 | return tf.constant(out) 30 | return _initializer 31 | 32 | 33 | def ortho_init(scale=1.0): 34 | # noinspection PyUnusedLocal 35 | def _ortho_init(shape, dtype, partition_info=None): # pylint: disable=W0613 36 | shape = tuple(shape) 37 | if len(shape) == 2: 38 | flat_shape = shape 39 | elif len(shape) == 4: # assumes NHWC 40 | flat_shape = (np.prod(shape[:-1]), shape[-1]) 41 | else: 42 | raise NotImplementedError 43 | a = np.random.normal(0.0, 1.0, flat_shape) 44 | u, _, v = np.linalg.svd(a, full_matrices=False) 45 | q = u if u.shape == flat_shape else v # pick the one with the correct shape 46 | q = q.reshape(shape) 47 | return (scale * q[:shape[0], :shape[1]]).astype(np.float32) 48 | return _ortho_init 49 | 50 | 51 | def fc(x, scope, nout, init_scale=1.0, init_bias=0.0): 52 | with tf.variable_scope(scope): # pylint: disable=E1129 53 | nin = x.get_shape()[1].value 54 | w = tf.get_variable("w", [nin, nout], initializer=normc_init(init_scale)) 55 | b = tf.get_variable("b", [nout], initializer=tf.constant_initializer(init_bias)) 56 | return tf.matmul(x, w) + b 57 | 58 | 59 | def conv(x, scope, noutchannels, filtsize, stride, pad='VALID', init_scale=1.0): 60 | with tf.variable_scope(scope): 61 | nin = x.get_shape()[3].value 62 | w = tf.get_variable("w", [filtsize, filtsize, nin, noutchannels], initializer=ortho_init(init_scale)) 63 | b = tf.get_variable("b", [noutchannels], initializer=tf.constant_initializer(0.0)) 64 | z = tf.nn.conv2d(x, w, strides=[1, stride, stride, 1], padding=pad)+b 65 | return z 66 | 67 | 68 | class GRUCell(rnn_cell.RNNCell): 69 | """Gated Recurrent Unit cell (cf. 
http://arxiv.org/abs/1406.1078).""" 70 | def __init__(self, num_units, name, nin, rec_gate_init=0.): 71 | rnn_cell.RNNCell.__init__(self) 72 | self._num_units = num_units 73 | self.rec_gate_init = rec_gate_init 74 | self.w1 = tf.get_variable(name + "w1", [nin+num_units, 2*num_units], initializer=normc_init(1.)) 75 | self.b1 = tf.get_variable(name + "b1", [2*num_units], initializer=tf.constant_initializer(rec_gate_init)) 76 | self.w2 = tf.get_variable(name + "w2", [nin+num_units, num_units], initializer=normc_init(1.)) 77 | self.b2 = tf.get_variable(name + "b2", [num_units], initializer=tf.constant_initializer(0.)) 78 | 79 | @property 80 | def state_size(self): 81 | return self._num_units 82 | 83 | @property 84 | def output_size(self): 85 | return self._num_units 86 | 87 | # noinspection PyMethodOverriding 88 | def call(self, inputs, state): 89 | """Gated recurrent unit (GRU) with nunits cells.""" 90 | x, new = inputs 91 | while len(state.get_shape().as_list()) > len(new.get_shape().as_list()): 92 | new = tf.expand_dims(new, len(new.get_shape().as_list())) 93 | h = state * (1.0 - new) 94 | hx = tf.concat([h, x], axis=1) 95 | mr = tf.sigmoid(tf.matmul(hx, self.w1) + self.b1) 96 | m, r = tf.split(mr, 2, axis=1) 97 | rh_x = tf.concat([r * h, x], axis=1) 98 | htil = tf.tanh(tf.matmul(rh_x, self.w2) + self.b2) 99 | h = m * h + (1.0 - m) * htil 100 | return h, h 101 | 102 | 103 | class CnnPolicy(object): 104 | 105 | def __init__(self, sess, ob_space, ac_space, nbatch, _nsteps, _test_mode=False, reuse=False): 106 | nh, nw, nc = ob_space.shape 107 | ob_shape = (nbatch, nh, nw, nc) 108 | nact = ac_space.n 109 | x = tf.placeholder(tf.uint8, ob_shape) 110 | with tf.variable_scope("model", reuse=reuse): 111 | h = tf.nn.relu(conv(tf.cast(x, tf.float32)/255., 'c1', noutchannels=64, filtsize=8, stride=4)) 112 | h2 = tf.nn.relu(conv(h, 'c2', noutchannels=128, filtsize=4, stride=2)) 113 | h3 = tf.nn.relu(conv(h2, 'c3', noutchannels=128, filtsize=3, stride=1)) 114 | h3 = to2d(h3) 115 | h4 = tf.nn.relu(fc(h3, 'fc1', nout=1024)) 116 | pi = fc(h4, 'pi', nact, init_scale=0.01) 117 | vf = fc(h4, 'v', 1, init_scale=0.01)[:, 0] 118 | 119 | self.pdtype = make_pdtype(ac_space) 120 | self.pd = self.pdtype.pdfromflat(pi) 121 | 122 | a0 = self.pd.sample() 123 | neglogp0 = self.pd.neglogp(a0) 124 | self.initial_state = None 125 | 126 | def step(ob, *_args, **_kwargs): 127 | a, v, neglogp = sess.run([a0, vf, neglogp0], {x: ob}) 128 | return a, v, self.initial_state, neglogp 129 | 130 | def value(ob, *_args, **_kwargs): 131 | return sess.run(vf, {x: ob}) 132 | 133 | self.X = x 134 | self.pi = pi 135 | self.vf = vf 136 | self.step = step 137 | self.value = value 138 | 139 | 140 | class GRUPolicy(object): 141 | 142 | def __init__(self, sess, ob_space, ac_space, nenv, nsteps, memsize=800, test_mode=False, reuse=False): 143 | nh, nw, nc = ob_space.shape 144 | nbatch = nenv*nsteps 145 | ob_shape = (nbatch, nh, nw, nc) 146 | nact = ac_space.n 147 | 148 | # use variables instead of placeholder to keep data on GPU if we're training 149 | x = tf.placeholder(tf.uint8, ob_shape) # obs 150 | mask = tf.placeholder(tf.float32, [nbatch]) # mask (done t-1) 151 | states = tf.placeholder(tf.float32, [nenv, memsize]) # states 152 | e = tf.placeholder(tf.uint8, [nbatch]) 153 | 154 | with tf.variable_scope("model", reuse=reuse): 155 | h = tf.nn.relu(conv(tf.cast(x, tf.float32)/255., 'c1', noutchannels=64, filtsize=8, stride=4)) 156 | h2 = tf.nn.relu(conv(h, 'c2', noutchannels=128, filtsize=4, stride=2)) 157 | h3 = tf.nn.relu(conv(h2, 'c3', 
noutchannels=128, filtsize=3, stride=1)) 158 | h3 = to2d(h3) 159 | h4 = tf.contrib.layers.layer_norm(fc(h3, 'fc1', nout=memsize), center=False, scale=False, 160 | activation_fn=tf.nn.relu) 161 | h5 = tf.reshape(h4, [nenv, nsteps, memsize]) 162 | 163 | m = tf.reshape(mask, [nenv, nsteps, 1]) 164 | cell = GRUCell(memsize, 'gru1', nin=memsize) 165 | h6, snew = tf.nn.dynamic_rnn(cell, (h5, m), dtype=tf.float32, time_major=False, initial_state=states, 166 | swap_memory=True) 167 | 168 | h7 = tf.concat([tf.reshape(h6, [nbatch, memsize]), h4], axis=1) 169 | pi = fc(h7, 'pi', nact, init_scale=0.01) 170 | if test_mode: 171 | pi *= 2. 172 | else: 173 | pi = tf.where(e > 0, pi/2., pi) 174 | vf = tf.squeeze(fc(h7, 'v', 1, init_scale=0.01)) 175 | 176 | self.pdtype = make_pdtype(ac_space) 177 | self.pd = self.pdtype.pdfromflat(pi) 178 | 179 | a0 = self.pd.sample() 180 | neglogp0 = self.pd.neglogp(a0) 181 | self.initial_state = np.zeros((nenv, memsize), dtype=np.float32) 182 | 183 | def step(ob, state, mask_, increase_ent): 184 | return sess.run([a0, vf, snew, neglogp0], {x: ob, states: state, mask: mask_, e: increase_ent}) 185 | 186 | def value(ob, state, mask_): 187 | return sess.run(vf, {x: ob, states: state, mask: mask_}) 188 | 189 | self.X = x 190 | self.M = mask 191 | self.S = states 192 | self.E = e 193 | self.pi = pi 194 | self.vf = vf 195 | self.step = step 196 | self.value = value 197 | -------------------------------------------------------------------------------- /policy_based/goexplore_py/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Uber Technologies, Inc. 2 | # 3 | # Licensed under the Uber Non-Commercial License (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at the root directory of this project. 6 | # 7 | # See the License for the specific language governing permissions and 8 | # limitations under the License. 9 | -------------------------------------------------------------------------------- /policy_based/goexplore_py/basics.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Uber Technologies, Inc. 2 | # 3 | # Licensed under the Uber Non-Commercial License (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at the root directory of this project. 6 | # 7 | # See the License for the specific language governing permissions and 8 | # limitations under the License. 9 | import copy 10 | import collections 11 | import functools 12 | import warnings as _warnings 13 | 14 | try: 15 | from dataclasses import dataclass, field as datafield 16 | 17 | def copyfield(data): 18 | return datafield(default_factory=lambda: copy.deepcopy(data)) 19 | except ImportError: 20 | _warnings.warn('dataclasses not found. To get it, use Python 3.7 or pip install dataclasses') 21 | raise 22 | 23 | infinity = float('inf') 24 | 25 | 26 | def notebook_max_width(): 27 | from IPython.core.display import display, HTML 28 | display(HTML("")) 29 | 30 | 31 | class Memoized(object): 32 | """ 33 | Decorator. Caches a function's return value each time it is called. 34 | If called later with the same arguments, the cached value is returned 35 | (not reevaluated). 
36 | """ 37 | 38 | def __init__(self, func): 39 | self.func = func 40 | self.cache = {} 41 | 42 | def __call__(self, *args): 43 | if not isinstance(args, collections.Hashable): 44 | # uncacheable. a list, for instance. 45 | # better to not cache than blow up. 46 | return self.func(*args) 47 | if args in self.cache: 48 | return self.cache[args] 49 | else: 50 | value = self.func(*args) 51 | self.cache[args] = value 52 | return value 53 | 54 | def __repr__(self): 55 | """Return the function's docstring.""" 56 | return self.func.__doc__ 57 | 58 | def __get__(self, obj, objtype): 59 | """Support instance methods.""" 60 | return functools.partial(self.__call__, obj) 61 | -------------------------------------------------------------------------------- /policy_based/goexplore_py/data_classes.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Uber Technologies, Inc. 2 | # 3 | # Licensed under the Uber Non-Commercial License (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at the root directory of this project. 6 | # 7 | # See the License for the specific language governing permissions and 8 | # limitations under the License. 9 | import warnings as _warnings 10 | import copy 11 | from typing import List, Any 12 | from collections import deque 13 | import sys 14 | 15 | try: 16 | from dataclasses import dataclass, field as datafield 17 | 18 | def copyfield(data): 19 | return datafield(default_factory=lambda: copy.deepcopy(data)) 20 | except ModuleNotFoundError: 21 | _warnings.warn('dataclasses not found. To get it, use Python 3.7 or pip install dataclasses') 22 | 23 | 24 | @dataclass 25 | class GridDimension: 26 | attr: str 27 | div: int 28 | 29 | 30 | @dataclass 31 | class CellInfoDeterministic: 32 | #: The score of the last accepted trajectory to this cell 33 | score: int = -float('inf') 34 | #: Number of trajectories that included this cell 35 | nb_seen: int = 0 36 | #: The number of times this cell was chosen as the cell to explore from 37 | nb_chosen: int = 0 38 | #: The number of times this cell was chosen since it was last updated 39 | nb_chosen_since_update: int = 0 40 | #: The number of times this cell was chosen since it last resulted in discovering a new cell 41 | nb_chosen_since_to_new: int = 0 42 | #: The number of times this cell was chosen since it last resulted in updating any cell 43 | nb_chosen_since_to_update: int = 0 44 | #: The number of actions that had this cell as the resulting state (i.e. all frames spend in this cell) 45 | nb_actions: int = 0 46 | #: The number of times this cell was chosen to explore towards 47 | nb_chosen_for_exploration: int = 0 48 | #: The number of times this cell was reached when chosen to explore towards 49 | nb_reached_for_exploration: int = 0 50 | #: Length of the trajectory 51 | trajectory_len: int = float('inf') 52 | #: Saved restore state. In a purely deterministic environment, 53 | #: this allows us to fast-forward to the end state instead 54 | #: of replaying. 
55 | restore: Any = None 56 | #: Sliding window for calculating our success rate of reaching different cells 57 | reached: deque = copyfield(deque(maxlen=100)) 58 | #: List of cells that we went through to reach this cell 59 | cell_traj: List[Any] = copyfield([]) 60 | exact_pos = None 61 | real_cell = None 62 | traj_last = None 63 | real_traj: List[int] = None 64 | 65 | @property 66 | def nb_reached(self): 67 | return sum(self.reached) 68 | 69 | 70 | @dataclass 71 | class CellInfoStochastic: 72 | # Basic information used by Go-Explore to determine when to update a cell 73 | #: The score of the last accepted trajectory to this cell 74 | score: int = -sys.maxsize 75 | #: Length of the trajectory 76 | trajectory_len: int = sys.maxsize 77 | 78 | # Together, these determine the trajectory the lead to this cell. 79 | # Necessary in order to follow cell trajectories, as well as for self-imitation learning. 80 | #: The identifier of the trajectory leading to this cell 81 | cell_traj_id: int = -1 82 | #: The index of the last cell of the cell trajectory 83 | cell_traj_end: int = -1 84 | 85 | # Optional information for determining whether cells are discovered more quickly over time. 86 | #: Whether this cell was initially discovered while returning to a cell (rather than while exploring from a cell) 87 | ret_discovered: int = 0 88 | #: At which frame was this cell discovered 89 | frame: int = -1 90 | #: The trajectory id of the first trajectory to find this cell 91 | first_cell_traj_id: int = -1 92 | #: How far along the trajectory was this cell discovered (if it was discovered while returning) 93 | traj_disc: int = 0 94 | #: What was the full length (in cells) of the trajectory being followed when this cell was discovered 95 | #: (if it was discovered while returning) 96 | total_traj_length: int = 0 97 | #: Flag to control the update-on-reset process 98 | should_reset: bool = False 99 | 100 | # Optional information that can be used to take special actions near cells with high failure-to-reach rates. 101 | #: The number of times the agent has failed to reach this cell when it was presented to the agent as a sub-goal. 102 | nb_sub_goal_failed: int = 0 103 | #: Used to track for how long the failure rate has been above a certain threshold. 104 | nb_failures_above_thresh: int = 0 105 | 106 | # Information used to determine cell-selection probabilities 107 | #: The number of times this cell was chosen as the cell to explore from 108 | nb_chosen: int = 0 109 | #: Number of times this cell has been reached 110 | nb_reached: int = 0 111 | #: The number of actions that had this cell as the resulting state (i.e. 
all frames spend in this cell) 112 | nb_actions_taken_in_cell: int = 0 113 | #: The number of times this cell has been part of a trajectory 114 | nb_seen: int = 0 115 | #: The number of times the information of this cell was reset when update-on-reset is enabled 116 | nb_reset: int = 0 117 | 118 | def add(self, other): 119 | self.nb_chosen += other.nb_chosen 120 | self.nb_reached += other.nb_reached 121 | self.nb_actions_taken_in_cell += other.nb_actions_taken_in_cell 122 | 123 | 124 | @dataclass 125 | class TrajectoryElement: 126 | __slots__ = ['cells', 'action', 'reward', 'done', 'length', 'score', 'restore'] 127 | cells: {} 128 | action: int 129 | reward: float 130 | done: bool 131 | length: int 132 | score: float 133 | restore: Any 134 | 135 | 136 | @dataclass 137 | class LogParameters: 138 | n_digits: int 139 | checkpoint_game: int 140 | checkpoint_compute: int 141 | checkpoint_first_iteration: bool 142 | checkpoint_last_iteration: bool 143 | max_game_steps: int 144 | max_compute_steps: int 145 | max_time: int 146 | max_iterations: int 147 | max_cells: int 148 | max_score: int 149 | save_pictures: List[str] 150 | clear_pictures: List[str] 151 | base_path: str 152 | checkpoint_it: int 153 | save_archive: bool 154 | save_model: bool 155 | checkpoint_time: int 156 | 157 | def should_render(self, name): 158 | return name in self.save_pictures or 'all' in self.save_pictures 159 | 160 | 161 | @dataclass() 162 | class Weight: 163 | weight: float = 1.0 164 | power: float = 1.0 165 | 166 | def __repr__(self): 167 | return f'w={self.weight:.2f}=p={self.power:.2f}' 168 | 169 | 170 | @dataclass() 171 | class DirWeights: 172 | horiz: float = 2.0 173 | vert: float = 0.3 174 | score_low: float = 0.0 175 | score_high: float = 0.0 176 | 177 | def __repr__(self): 178 | return f'h={self.horiz:.2f}=v={self.vert:.2f}=l={self.score_low:.2f}=h={self.score_high:.2f}' 179 | -------------------------------------------------------------------------------- /policy_based/goexplore_py/explorers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Uber Technologies, Inc. 2 | # 3 | # Licensed under the Uber Non-Commercial License (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at the root directory of this project. 6 | # 7 | # See the License for the specific language governing permissions and 8 | # limitations under the License. 9 | import random 10 | import os 11 | import numpy as np 12 | from .globals import get_action_meaning, get_trajectory 13 | 14 | 15 | class RandomExplorer: 16 | def init_seed(self): 17 | pass 18 | 19 | def get_action(self, env): 20 | return random.randint(0, env.action_space.n - 1) 21 | 22 | def __repr__(self): 23 | return 'RandomExplorer()' 24 | 25 | 26 | class RepeatedRandomExplorer: 27 | def __init__(self, mean_repeat): 28 | self.mean_repeat = mean_repeat 29 | self.action = 0 30 | self.remaining = 0 31 | 32 | def init_seed(self): 33 | self.remaining = 0 34 | 35 | def get_action(self, env): 36 | if self.remaining <= 0: 37 | self.action = random.randint(0, env.action_space.n - 1) 38 | # Note, this is equivalent to selecting an action and then repeating it 39 | # with some probability. 
40 | self.remaining = np.random.geometric(1 / self.mean_repeat) 41 | self.remaining -= 1 42 | return self.action 43 | 44 | def __repr__(self): 45 | return f'repeat-{self.mean_repeat}' 46 | 47 | 48 | class ReplayTrajectoryExplorer: 49 | def __init__(self, prev_idxs, actions): 50 | self.prev_idxs = prev_idxs 51 | self.actions = actions 52 | self.trajectory = [] 53 | self.action_index = 0 54 | self.current_goal = None 55 | 56 | def init_seed(self): 57 | pass 58 | 59 | def get_action(self, env): 60 | goal_rep = env.recursive_getattr('goal_cell_rep') 61 | current_cell = env.get_current_cell() 62 | # We have reached the end of our trajectory, a new goal should have been chosen 63 | if self.action_index >= len(self.trajectory): 64 | goal = env.recursive_getattr('goal_cell_info') 65 | print('Selected goal:', goal_rep) 66 | print('Previous goal:', self.current_goal) 67 | if goal_rep == self.current_goal: 68 | print("ERROR: The same goal was selected twice in a row.") 69 | raise Exception('The same goal was selected twice in a row.') 70 | self.current_goal = goal_rep 71 | self.trajectory = get_trajectory(self.prev_idxs, self.actions, goal.traj_last) 72 | if goal.real_traj is not None: 73 | assert goal.real_traj == self.trajectory 74 | if goal.trajectory_len is not -1: 75 | assert len(self.trajectory) == goal.trajectory_len 76 | self.action_index = 0 77 | elif goal_rep != self.current_goal: 78 | print("ERROR: New goal selected before trajectory to previous goal was finished.") 79 | print("Full trajectory was:", self.trajectory) 80 | print('process id:', os.getpid()) 81 | print("Which is:", [get_action_meaning(a) for a in self.trajectory]) 82 | raise Exception('New goal selected before trajectory to previous goal was finished.') 83 | if len(self.trajectory) > 0: 84 | action = self.trajectory[self.action_index] 85 | self.action_index += 1 86 | else: 87 | action = 0 88 | print('In cell:', current_cell, 'Playing action:', self.action_index-1, action, get_action_meaning(action)) 89 | return action 90 | 91 | 92 | class RepeatedRandomExplorerRobot: 93 | def __init__(self, mean_repeat=10): 94 | self.mean_repeat = mean_repeat 95 | self.action = 0 96 | self.remaining = 0 97 | 98 | def init_seed(self): 99 | self.remaining = 0 100 | 101 | def get_action(self, env): 102 | if self.remaining <= 0: 103 | self.action = env.action_space.sample() 104 | # Note, this is equivalent to selecting an action and then repeating it 105 | # with some probability. 106 | self.remaining = np.random.geometric(1 / self.mean_repeat) 107 | self.remaining -= 1 108 | return self.action 109 | 110 | def __repr__(self): 111 | return f'repeat-{self.mean_repeat}' 112 | 113 | 114 | class DoNothingExplorer: 115 | def init_seed(self): 116 | pass 117 | 118 | def get_action(self, *_args): 119 | return 0 120 | -------------------------------------------------------------------------------- /policy_based/goexplore_py/ge_models.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Uber Technologies, Inc. 2 | # 3 | # Licensed under the Uber Non-Commercial License (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at the root directory of this project. 6 | # 7 | # See the License for the specific language governing permissions and 8 | # limitations under the License. 
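# Overview: GoalConditionedModelFlexEnt extends the atari_reset PPO Model with
# goal-conditioned act/train policies (the policy constructor receives a goal_space)
# and composes the PPO objective: pg_loss - ent_coef * entropy + vf_coef * vf_loss
# + l2_coef * l2_loss. GoalConFlexEntSilModel adds a self-imitation-learning (SIL)
# term on top of that loss, built from a SIL action placeholder, a SIL return, and a
# per-sample validity mask. Illustrative wiring sketch (hyperparameter values below
# are assumptions, not taken from this repository's experiment settings):
#
#   from goexplore_py.ge_policies import GRUPolicyGoalConSimpleFlexEnt
#   model = GoalConFlexEntSilModel()
#   model.init(policy=GRUPolicyGoalConSimpleFlexEnt, ob_space=ob_space, ac_space=ac_space,
#              nenv=16, nsteps=128, ent_coef=1e-4, vf_coef=0.5, l2_coef=1e-7,
#              cliprange=0.1, goal_space=goal_space, sil_coef=0.1)
#   losses = model.train_from_runner(lr=2.5e-4, runner=runner)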
9 | import tensorflow as tf 10 | from typing import Any 11 | import atari_reset.atari_reset.ppo as ppo 12 | 13 | 14 | class GoalConditionedModelFlexEnt(ppo.Model): 15 | def __init__(self): 16 | super(GoalConditionedModelFlexEnt, self).__init__() 17 | 18 | def init(self, policy, ob_space, ac_space, nenv, nsteps, ent_coef, vf_coef, l2_coef, 19 | cliprange, adam_epsilon=1e-6, load_path=None, test_mode=False, goal_space=None, disable_hvd=False): 20 | self.sess = tf.get_default_session() 21 | self.init_models(policy, ob_space, ac_space, nenv, nsteps, test_mode, goal_space) 22 | self.init_loss(nenv, nsteps, cliprange, disable_hvd) 23 | self.loss = self.pg_loss - self.entropy * ent_coef + self.vf_loss * vf_coef + l2_coef * self.l2_loss 24 | self.finalize(load_path, adam_epsilon) 25 | 26 | def init_models(self, policy, ob_space, ac_space, nenv, nsteps, test_mode, goal_space): 27 | # At test time, we only need the most recent action in order to take a step. 28 | self.act_model = policy(self.sess, ob_space, ac_space, nenv, 1, test_mode=test_mode, reuse=False, 29 | goal_space=goal_space) 30 | # At training time, we need to keep track of the last T (nsteps) of actions that we took. 31 | self.train_model = policy(self.sess, ob_space, ac_space, nenv, nsteps, test_mode=test_mode, reuse=True, 32 | goal_space=goal_space) 33 | 34 | def train_from_runner(self, lr: float, runner: Any): 35 | return self.train(lr, 36 | runner.ar_mb_obs_2.reshape(self.train_model.X.shape), 37 | runner.ar_mb_goals, 38 | runner.ar_mb_rets, 39 | runner.ar_mb_advs, 40 | runner.ar_mb_dones, 41 | runner.ar_mb_actions, 42 | runner.ar_mb_values, 43 | runner.ar_mb_neglogpacs, 44 | runner.ar_mb_valids, 45 | runner.ar_mb_ent, 46 | runner.mb_states[0]) 47 | 48 | def train(self, lr, obs, goals, returns, advs, masks, actions, values, neglogpacs, valids, increase_ent, 49 | states=None): 50 | td_map = {self.LR: lr, self.train_model.X: obs, self.train_model.goal: goals, 51 | self.A: actions, self.ADV: advs, self.VALID: valids, self.R: returns, 52 | self.OLDNEGLOGPAC: neglogpacs, self.OLDVPRED: values, self.train_model.E: increase_ent} 53 | if states is not None: 54 | td_map[self.train_model.S] = states 55 | td_map[self.train_model.M] = masks 56 | return self.sess.run(self.loss_requested, feed_dict=td_map)[:-1] 57 | 58 | 59 | class GoalConFlexEntSilModel(GoalConditionedModelFlexEnt): 60 | def __init__(self): 61 | super(GoalConFlexEntSilModel, self).__init__() 62 | self.sil_loss = None 63 | self.SIL_A = None 64 | self.SIL_VALID = None 65 | self.SIL_R = None 66 | self.sil_pg_loss = None 67 | self.sil_vf_loss = None 68 | self.sil_entropy = None 69 | self.sil_valid_min = None 70 | self.sil_valid_max = None 71 | self.sil_valid_mean = None 72 | 73 | self.neglop_sil_min = None 74 | self.neglop_sil_max = None 75 | self.neglop_sil_mean = None 76 | 77 | # Debug 78 | self.mean_val_pred = None 79 | self.mean_sil_r = None 80 | self.train_it = 0 81 | 82 | def init(self, policy, ob_space, ac_space, nenv, nsteps, ent_coef, vf_coef, l2_coef, 83 | cliprange, adam_epsilon=1e-6, load_path=None, test_mode=False, goal_space=None, sil_coef=0.0, 84 | sil_vf_coef=0.0, sil_ent_coef=0.0, disable_hvd=False): 85 | self.sess = tf.get_default_session() 86 | self.init_models(policy, ob_space, ac_space, nenv, nsteps, test_mode, goal_space) 87 | self.init_loss(nenv, nsteps, cliprange, disable_hvd) 88 | self.init_sil_loss(nenv, nsteps, sil_vf_coef, sil_ent_coef) 89 | self.loss = (self.pg_loss 90 | - self.entropy * ent_coef 91 | + self.vf_loss * vf_coef 92 | + l2_coef * 
self.l2_loss 93 | + sil_coef * self.sil_loss) 94 | 95 | self.finalize(load_path, adam_epsilon) 96 | self.loss_requested_dict = {self.pg_loss: 'policy_loss', 97 | self.vf_loss: 'value_loss', 98 | self.l2_loss: 'l2_loss', 99 | self.entropy: 'policy_entropy', 100 | self.approxkl: 'approxkl', 101 | self.clipfrac: 'clipfrac', 102 | self.sil_pg_loss: 'sil_pg_loss', 103 | self.sil_vf_loss: 'sil_vf_loss', 104 | self.sil_loss: 'sil_loss', 105 | self.sil_entropy: 'sil_entropy', 106 | self.sil_valid_min: 'sil_valid_min', 107 | self.sil_valid_max: 'sil_valid_max', 108 | self.sil_valid_mean: 'sil_valid_mean', 109 | self.neglop_sil_min: 'neglop_sil_min', 110 | self.neglop_sil_max: 'neglop_sil_max', 111 | self.neglop_sil_mean: 'neglop_sil_mean', 112 | self.mean_val_pred: 'mean_val_pred', 113 | self.mean_sil_r: 'mean_sil_r', 114 | self.train_op: ''} 115 | self.init_requested_loss() 116 | 117 | def init_sil_loss(self, nenv, nsteps, sil_vf_coef, sil_ent_coef): 118 | self.SIL_A = self.train_model.pdtype.sample_placeholder([nenv*nsteps], name='sil_action') 119 | self.SIL_VALID = tf.placeholder(tf.float32, [nenv*nsteps], name='sil_valid') 120 | self.SIL_R = tf.placeholder(tf.float32, [nenv*nsteps], name='sil_return') 121 | 122 | neglogp_sil_ac = self.train_model.pd.neglogp(self.SIL_A) 123 | 124 | self.sil_pg_loss = tf.reduce_mean(neglogp_sil_ac * tf.nn.relu(self.SIL_R - self.OLDVPRED) * self.SIL_VALID) 125 | self.sil_vf_loss = .5 * tf.reduce_mean(tf.square(tf.nn.relu(self.SIL_R - self.vpred)) * self.SIL_VALID) 126 | self.sil_entropy = tf.reduce_mean(self.SIL_VALID * self.train_model.pd.entropy()) 127 | self.sil_loss = self.sil_pg_loss + sil_vf_coef * self.sil_vf_loss + sil_ent_coef * self.sil_entropy 128 | 129 | self.sil_valid_min = tf.reduce_min(self.SIL_VALID) 130 | self.sil_valid_max = tf.reduce_max(self.SIL_VALID) 131 | self.sil_valid_mean = tf.reduce_mean(self.SIL_VALID) 132 | 133 | self.neglop_sil_min = tf.reduce_min(neglogp_sil_ac) 134 | self.neglop_sil_max = tf.reduce_max(neglogp_sil_ac) 135 | self.neglop_sil_mean = tf.reduce_mean(neglogp_sil_ac) 136 | 137 | self.mean_val_pred = tf.reduce_mean(self.OLDVPRED) 138 | self.mean_sil_r = tf.reduce_mean(self.SIL_R) 139 | 140 | def train_from_runner(self, lr: float, runner: Any): 141 | obs = runner.ar_mb_obs_2.reshape(self.train_model.X.shape) 142 | 143 | return self.train(lr, 144 | obs, 145 | runner.ar_mb_goals, 146 | runner.ar_mb_rets, 147 | runner.ar_mb_advs, 148 | runner.ar_mb_dones, 149 | runner.ar_mb_actions, 150 | runner.ar_mb_values, 151 | runner.ar_mb_neglogpacs, 152 | runner.ar_mb_valids, 153 | runner.ar_mb_ent, 154 | runner.ar_mb_sil_actions, 155 | runner.ar_mb_sil_rew, 156 | runner.ar_mb_sil_valid, 157 | runner.mb_states[0]) 158 | 159 | def train(self, lr, obs, goals, returns, advs, masks, actions, values, neglogpacs, valids, increase_ent, 160 | sil_actions=None, sil_rew=None, sil_valid=None, states=None): 161 | self.train_it += 1 162 | td_map = {self.LR: lr, self.train_model.X: obs, self.train_model.goal: goals, self.A: actions, self.ADV: advs, 163 | self.VALID: valids, self.R: returns, 164 | self.OLDNEGLOGPAC: neglogpacs, self.OLDVPRED: values, self.train_model.E: increase_ent, 165 | self.SIL_A: sil_actions, self.SIL_R: sil_rew, self.SIL_VALID: sil_valid} 166 | if states is not None: 167 | td_map[self.train_model.S] = states 168 | td_map[self.train_model.M] = masks 169 | return self.filter_requested_losses(self.sess.run(self.loss_requested, feed_dict=td_map)) 170 | -------------------------------------------------------------------------------- 
/policy_based/goexplore_py/ge_policies.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Uber Technologies, Inc. 2 | # 3 | # Licensed under the Uber Non-Commercial License (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at the root directory of this project. 6 | # 7 | # See the License for the specific language governing permissions and 8 | # limitations under the License. 9 | import logging 10 | import numpy as np 11 | import tensorflow as tf 12 | import atari_reset.atari_reset.policies as po 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | class GRUPolicyGoalConSimpleFlexEnt(object): 17 | def __init__(self, sess, ob_space, ac_space, nenv, nsteps, memsize=800, test_mode=False, reuse=False, 18 | goal_space=None): 19 | nh, nw, nc = ob_space.shape 20 | nbatch = nenv*nsteps 21 | ob_shape = (nbatch, nh, nw, nc) 22 | logger.info(f'goal_space.shape: {goal_space.shape}') 23 | goal_shape = tuple([nbatch] + list(goal_space.shape)) 24 | logger.info(f'goal_shape: {goal_shape}') 25 | nact = ac_space.n 26 | 27 | # use variables instead of placeholder to keep data on GPU if we're training 28 | nn_input = tf.placeholder(tf.uint8, ob_shape, 'input') # obs 29 | goal = tf.placeholder(tf.float32, goal_shape, 'goal') # goal 30 | mask = tf.placeholder(tf.float32, [nbatch], 'done_mask') # mask (done t-1) 31 | states = tf.placeholder(tf.float32, [nenv, memsize], 'hidden_state') # states 32 | entropy = tf.placeholder(tf.float32, [nbatch], 'entropy_factor') 33 | fake_actions = tf.placeholder(tf.int64, [nbatch], 'fake_actions') 34 | logger.info(f'fake_actions.shape: {fake_actions.shape}') 35 | logger.info(f'fake_actions.dtype: {fake_actions.dtype}') 36 | 37 | with tf.variable_scope("model", reuse=reuse): 38 | logger.info(f'input.shape {nn_input.shape}') 39 | h = tf.nn.relu(po.conv(tf.cast(nn_input, tf.float32)/255., 'c1', noutchannels=64, filtsize=8, stride=4)) 40 | logger.info(f'h.shape: {h.shape}') 41 | h2 = tf.nn.relu(po.conv(h, 'c2', noutchannels=128, filtsize=4, stride=2)) 42 | logger.info(f'h2.shape: {h2.shape}') 43 | h3 = tf.nn.relu(po.conv(h2, 'c3', noutchannels=128, filtsize=3, stride=1)) 44 | logger.info(f'h3.shape: {h3.shape}') 45 | h3 = po.to2d(h3) 46 | logger.info(f'h3.shape: {h3.shape}') 47 | g1 = tf.cast(goal, tf.float32) 48 | logger.info(f'g1.shape: {g1.shape}') 49 | h3 = tf.concat([h3, g1], axis=1) 50 | logger.info(f'h3.shape: {h3.shape}') 51 | h4 = tf.contrib.layers.layer_norm(po.fc(h3, 'fc1', nout=memsize), center=False, scale=False, 52 | activation_fn=tf.nn.relu) 53 | logger.info(f'h4.shape: {h4.shape}') 54 | h5 = tf.reshape(h4, [nenv, nsteps, memsize]) 55 | 56 | m = tf.reshape(mask, [nenv, nsteps, 1]) 57 | cell = po.GRUCell(memsize, 'gru1', nin=memsize) 58 | h6, snew = tf.nn.dynamic_rnn(cell, (h5, m), dtype=tf.float32, time_major=False, 59 | initial_state=states, swap_memory=True) 60 | logger.info(f'h6.shape: {h6.shape}') 61 | 62 | h7 = tf.concat([tf.reshape(h6, [nbatch, memsize]), h4], axis=1) 63 | pi = po.fc(h7, 'pi', nact, init_scale=0.01) 64 | if test_mode: 65 | pi *= 2. 
66 | else: 67 | pi /= tf.reshape(entropy, (nbatch, 1)) 68 | logger.info(f'h7.shape: {h7.shape}') 69 | vf_before_squeeze = po.fc(h7, 'v', 1, init_scale=0.01) 70 | logger.info(f'vf_before_squeeze.shape: {vf_before_squeeze.shape}') 71 | vf = tf.squeeze(vf_before_squeeze, axis=[1]) 72 | logger.info(f'vf.shape: {vf.shape}') 73 | 74 | self.pdtype = po.make_pdtype(ac_space) 75 | self.pd = self.pdtype.pdfromflat(pi) 76 | a0 = self.pd.sample() 77 | logger.info(f'a0.shape: {a0.shape}') 78 | logger.info(f'a0.dtype: {a0.dtype}') 79 | neglogp0 = self.pd.neglogp(a0) 80 | self.initial_state = np.zeros((nenv, memsize), dtype=np.float32) 81 | 82 | neg_log_fake_a = self.pd.neglogp(fake_actions) 83 | 84 | def step(local_ob, local_goal, local_state, local_mask, local_increase_ent): 85 | return sess.run([a0, vf, snew, neglogp0], 86 | {nn_input: local_ob, states: local_state, mask: local_mask, entropy: local_increase_ent, 87 | goal: local_goal}) 88 | 89 | def step_fake_action(local_ob, local_goal, local_state, local_mask, local_increase_ent, local_fake_action): 90 | return sess.run([a0, vf, snew, neglogp0, neg_log_fake_a], 91 | {nn_input: local_ob, 92 | states: local_state, 93 | mask: local_mask, 94 | entropy: local_increase_ent, 95 | goal: local_goal, 96 | fake_actions: local_fake_action}) 97 | 98 | def value(local_ob, local_goal, local_state, local_mask): 99 | return sess.run(vf, {nn_input: local_ob, states: local_state, mask: local_mask, goal: local_goal}) 100 | 101 | self.X = nn_input 102 | self.goal = goal 103 | self.M = mask 104 | self.S = states 105 | self.E = entropy 106 | self.pi = pi 107 | self.vf = vf 108 | self.step = step 109 | self.step_fake_action = step_fake_action 110 | self.value = value 111 | -------------------------------------------------------------------------------- /policy_based/goexplore_py/ge_runners.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Uber Technologies, Inc. 2 | # 3 | # Licensed under the Uber Non-Commercial License (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at the root directory of this project. 6 | # 7 | # See the License for the specific language governing permissions and 8 | # limitations under the License. 
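# Overview: RunnerFlexEntSilProper extends the atari_reset PPO Runner with the buffers
# needed for the self-imitation-learning (SIL) loss in ge_models. At each step it reads
# every env's info dict: 'sil_action' (the action to imitate, if any) and 'sil_value'
# (the associated return), and stores a 0/1 validity flag marking which samples carry a
# SIL action at all. gather_return_info() then flattens these buffers with ppo.sf01 into
# the ar_mb_sil_* arrays that GoalConFlexEntSilModel.train_from_runner() feeds to train().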
9 | import numpy as np 10 | import atari_reset.atari_reset.ppo as ppo 11 | 12 | 13 | class RunnerFlexEntSilProper(ppo.Runner): 14 | def __init__(self, env, model, nsteps, gamma, lam, norm_adv, subtract_rew_avg): 15 | super(RunnerFlexEntSilProper, self).__init__(env, model, nsteps, gamma, lam, norm_adv, subtract_rew_avg) 16 | self.mb_sil_actions = self.reg_shift_list() 17 | self.mb_sil_rew = self.reg_shift_list() 18 | self.mb_sil_valid = self.reg_shift_list() 19 | self.ar_mb_sil_valid = None 20 | self.ar_mb_sil_actions = None 21 | self.ar_mb_sil_rew = None 22 | self.trunc_lst_mb_sil_valid = None 23 | 24 | def append_mb_data(self, actions, values, states, neglogpacs, obs_and_goals, rewards, dones, infos): 25 | super(RunnerFlexEntSilProper, self).append_mb_data(actions, 26 | values, 27 | states, 28 | neglogpacs, 29 | obs_and_goals, 30 | rewards, 31 | dones, 32 | infos) 33 | 34 | def get_sil_valid(info): 35 | is_valid = float(info.get('sil_action') is not None) 36 | return is_valid 37 | 38 | self.mb_sil_valid.append([get_sil_valid(info) for info in infos]) 39 | sil_actions = np.zeros_like(actions) 40 | for cur_info_id, info in enumerate(infos): 41 | cur_action = info.get('sil_action') 42 | if cur_action is not None: 43 | sil_actions[cur_info_id] = cur_action 44 | 45 | self.mb_sil_actions.append(sil_actions) 46 | self.mb_sil_rew.append([info.get('sil_value', 0) for info in infos]) 47 | 48 | def gather_return_info(self, end): 49 | super(RunnerFlexEntSilProper, self).gather_return_info(end) 50 | self.ar_mb_sil_valid = ppo.sf01(np.asarray(self.mb_sil_valid[:end], dtype=np.float32), 'sil_valids') 51 | self.ar_mb_sil_actions = ppo.sf01(np.asarray(self.mb_sil_actions[:end]), 'sil_actions') 52 | self.ar_mb_sil_rew = ppo.sf01(np.asarray(self.mb_sil_rew[:end], dtype=np.float32), 'sil_rewards') 53 | self.trunc_lst_mb_sil_valid = ppo.sf01(np.asarray(self.mb_sil_valid[-len(self.mb_cells):len(self.mb_sil_valid)], 54 | dtype=np.float32), 'trunc_sil_valids') 55 | -------------------------------------------------------------------------------- /policy_based/goexplore_py/generic_atari_env.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Uber Technologies, Inc. 2 | # 3 | # Licensed under the Uber Non-Commercial License (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at the root directory of this project. 6 | # 7 | # See the License for the specific language governing permissions and 8 | # limitations under the License. 
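# Overview: MyAtari wraps a Deterministic-v4 gym Atari environment, keeps a short list of
# frames downscaled to grayscale by convert_state (resized to TARGET_SHAPE and rescaled
# into the 0..MAX_PIX_VALUE range), and exposes get_restore()/restore() built on the ALE
# clone_full_state()/restore_full_state() calls. TARGET_SHAPE and MAX_PIX_VALUE are not
# defined in this file; they are assumed to be assigned on the class by the experiment
# configuration before use. Hypothetical usage sketch (the two values are placeholders):
#
#   MyAtari.TARGET_SHAPE = (16, 16)
#   MyAtari.MAX_PIX_VALUE = 255
#   env = MyAtari('MontezumaRevenge')
#   frames = env.reset()
#   frames, reward, done, info = env.step(0)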
9 | import numpy as np 10 | import gym 11 | import copy 12 | from typing import Tuple, List 13 | 14 | 15 | def convert_state(state): 16 | import cv2 17 | return ((cv2.resize(cv2.cvtColor(state, cv2.COLOR_RGB2GRAY), 18 | MyAtari.TARGET_SHAPE, interpolation=cv2.INTER_AREA) / 255.0) * 19 | MyAtari.MAX_PIX_VALUE).astype(np.uint8) 20 | 21 | 22 | class AtariPosLevel: 23 | __slots__ = ['level', 'score', 'room', 'x', 'y', 'tuple'] 24 | 25 | def __init__(self, level=0, score=0, room=0, x=0, y=0): 26 | self.level = level 27 | self.score = score 28 | self.room = room 29 | self.x = x 30 | self.y = y 31 | self.tuple = None 32 | self.set_tuple() 33 | 34 | def set_tuple(self): 35 | self.tuple = (self.level, self.score, self.room, self.x, self.y) 36 | 37 | def __hash__(self): 38 | return hash(self.tuple) 39 | 40 | def __eq__(self, other): 41 | if not isinstance(other, AtariPosLevel): 42 | return False 43 | return self.tuple == other.tuple 44 | 45 | def __getstate__(self): 46 | return self.tuple 47 | 48 | def __setstate__(self, d): 49 | self.level, self.score, self.room, self.x, self.y = d 50 | self.tuple = d 51 | 52 | def __repr__(self): 53 | return f'Level={self.level} Room={self.room} Objects={self.score} x={self.x} y={self.y}' 54 | 55 | 56 | def clip(a, min_v, max_v): 57 | if a < min_v: 58 | return min_v 59 | if a > max_v: 60 | return max_v 61 | return a 62 | 63 | 64 | class MyAtari: 65 | def __init__(self, name, x_repeat=2, end_on_death=False): 66 | self.name = name 67 | self.env = gym.make(f'{name}Deterministic-v4') 68 | self.env.reset() 69 | self.unwrapped.seed(0) 70 | self.state = [] 71 | self.x_repeat = x_repeat 72 | self.rooms = [] 73 | self.unprocessed_state = None 74 | self.end_on_death = end_on_death 75 | self.prev_lives = 0 76 | 77 | def __getattr__(self, e): 78 | return getattr(self.env, e) 79 | 80 | def reset(self) -> List[np.ndarray]: 81 | self.unprocessed_state = self.env.reset() 82 | self.state = [convert_state(self.unprocessed_state)] 83 | for _ in range(3): 84 | self.unprocessed_state = self.env.step(0)[0] 85 | self.state.append(convert_state(self.unprocessed_state)) 86 | 87 | return copy.copy(self.state) 88 | 89 | def get_restore(self): 90 | return ( 91 | self.unwrapped.clone_full_state(), 92 | copy.copy(self.state), 93 | ) 94 | 95 | def restore(self, data): 96 | ( 97 | full_state, 98 | state, 99 | ) = data 100 | self.state = copy.copy(state) 101 | self.env.reset() 102 | self.env.unwrapped.restore_full_state(full_state) 103 | return copy.copy(self.state) 104 | 105 | def step(self, action) -> Tuple[List[np.ndarray], float, bool, dict]: 106 | self.unprocessed_state, reward, done, lol = self.env.step(action) 107 | self.state.append(convert_state(self.unprocessed_state)) 108 | self.state.pop(0) 109 | 110 | cur_lives = self.env.unwrapped.ale.lives() 111 | if self.end_on_death and cur_lives < self.prev_lives: 112 | done = True 113 | self.prev_lives = cur_lives 114 | 115 | return copy.copy(self.state), reward, done, lol 116 | 117 | def get_pos(self): 118 | # NOTE: this only returns a dummy position 119 | return AtariPosLevel() 120 | 121 | def render_with_known(self, known_positions, resolution, show=True, filename=None, combine_val=max, 122 | get_val=lambda x: x.score, minmax=None): 123 | pass 124 | -------------------------------------------------------------------------------- /policy_based/goexplore_py/generic_goal_conditioned_env.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Uber Technologies, Inc. 
2 | # 3 | # Licensed under the Uber Non-Commercial License (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at the root directory of this project. 6 | # 7 | # See the License for the specific language governing permissions and 8 | # limitations under the License. 9 | import gym 10 | import numpy as np 11 | import types 12 | import copy 13 | import random 14 | from typing import List 15 | 16 | 17 | def make_robot_env(name, will_render=False): 18 | env = gym.make(name) 19 | 20 | # Fix rendering to 1/ hide the overlay, 2/ read the proper pixels in rgb_array 21 | # mode and 3/ prevent rendering if will_render is not announced (necessary because 22 | # when will_render is announced, we proactively create a viewer as soon as the 23 | # env is created, because creating it later causes inaccuracies). 24 | def render(self, mode='human'): 25 | assert will_render, 'Rendering in an environment with will_render=False' 26 | self._render_callback() 27 | self._get_viewer()._hide_overlay = True 28 | if mode == 'rgb_array': 29 | self._get_viewer().render() 30 | import glfw 31 | width, height = glfw.get_window_size(self.viewer.window) 32 | data = self._get_viewer().read_pixels(width, height, depth=False) 33 | # original image is upside-down, so flip it 34 | return data[::-1, :, :] 35 | elif mode == 'human': 36 | self._get_viewer().render() 37 | 38 | env.unwrapped.render = types.MethodType(render, env.unwrapped) 39 | if will_render: 40 | # Pre-cache the viewer because creating it while the environment is running 41 | # sometimes causes errors 42 | # noinspection PyProtectedMember 43 | env.unwrapped._get_viewer() 44 | 45 | if 'Fetch' in name: 46 | # The way _render_callback is implemented in Fetch environments causes issues. 47 | # This monkey patch fixes them. 48 | def _render_callback(self): 49 | # Visualize target. 
50 | sites_offset = (self.sim.data.site_xpos - self.sim.model.site_pos).copy() 51 | site_id = self.sim.model.site_name2id('target0') 52 | self.sim.model.site_pos[site_id] = self.goal - sites_offset[0] 53 | 54 | env.unwrapped._render_callback = types.MethodType(_render_callback, env.unwrapped) 55 | 56 | return env 57 | 58 | 59 | class DomainConditionedPosLevel: 60 | __slots__ = ['level', 'score', 'room', 'x', 'y', 'tuple'] 61 | 62 | def __init__(self, level=0, score=0, room=0, x=0, y=0): 63 | self.level = level 64 | self.score = score 65 | self.room = room 66 | self.x = x 67 | self.y = y 68 | self.tuple = None 69 | 70 | self.set_tuple() 71 | 72 | def set_tuple(self): 73 | self.tuple = (self.level, self.score, self.room, self.x, self.y) 74 | 75 | def __hash__(self): 76 | return hash(self.tuple) 77 | 78 | def __eq__(self, other): 79 | if not isinstance(other, DomainConditionedPosLevel): 80 | return False 81 | return self.tuple == other.tuple 82 | 83 | def __getstate__(self): 84 | return self.tuple 85 | 86 | def __setstate__(self, d): 87 | self.level, self.score, self.room, self.x, self.y = d 88 | self.tuple = d 89 | 90 | def __repr__(self): 91 | return f'Level={self.level} Room={self.room} Objects={self.score} x={self.x} y={self.y}' 92 | 93 | 94 | class MyRobot: 95 | TARGET_SHAPE = 0 96 | MAX_PIX_VALUE = 0 97 | 98 | def __init__(self, env_name, interval_size=0.1, seed_low=0, seed_high=0): 99 | self.env_name = env_name 100 | self.env = make_robot_env(env_name) 101 | self.interval_size = interval_size 102 | self.state = None 103 | self.actual_state = None 104 | self.rooms = [] 105 | self.trajectory = [] 106 | 107 | self.seed_low = seed_low 108 | self.seed_high = seed_high 109 | self.seed = None 110 | 111 | self.cur_achieved_goal = None 112 | self.achieved_has_moved = False 113 | self.score_so_far = 0 114 | 115 | self.follow_grip_until_moved = ('FetchPickAndPlace' in env_name and False) 116 | 117 | self.reset() 118 | 119 | def __getattr__(self, e): 120 | assert self.env is not self 121 | return getattr(self.env, e) 122 | 123 | def pos_from_state(self, seed, state): 124 | if self.follow_grip_until_moved: 125 | pos = state['achieved_goal'] if self.achieved_has_moved else state['observation'][:3] 126 | return np.array([seed, self.achieved_has_moved] + list(pos / self.interval_size), dtype=np.int32) 127 | return np.array([seed, self.score_so_far] + 128 | list((state['achieved_goal'] / self.interval_size).astype(np.int32)), dtype=np.int32) 129 | 130 | def reset(self) -> List[np.ndarray]: 131 | self.seed = None 132 | self.trajectory = None 133 | self.actual_state = None 134 | self.cur_achieved_goal = None 135 | self.achieved_has_moved = False 136 | self.score_so_far = 0 137 | self.state = [self.pos_from_state(-1, {'achieved_goal': np.array([]), 'observation': np.array([])})] 138 | return copy.copy(self.state) 139 | 140 | def get_restore(self): 141 | # noinspection PyProtectedMember 142 | return copy.deepcopy(( 143 | None, 144 | self.env._elapsed_steps, 145 | self.interval_size, 146 | self.cur_achieved_goal, 147 | self.achieved_has_moved, 148 | self.score_so_far, 149 | self.state, 150 | self.actual_state, 151 | self.trajectory, 152 | self.seed, 153 | )) 154 | 155 | def restore(self, data): 156 | ( 157 | simstate, 158 | _, 159 | self.interval_size, 160 | self.cur_achieved_goal, 161 | self.achieved_has_moved, 162 | self.score_so_far, 163 | self.state, 164 | actual_state, 165 | trajectory, 166 | seed, 167 | ) = copy.deepcopy(data) 168 | self.reset() 169 | self.seed = seed 170 | for a in trajectory: 171 | 
self.step(a) 172 | assert np.allclose(self.actual_state['achieved_goal'], actual_state['achieved_goal']) 173 | return copy.copy(self.state) 174 | 175 | def step(self, action): 176 | if self.trajectory is None: 177 | if self.seed is None: 178 | self.seed = random.randint(self.seed_low, self.seed_high) 179 | self.env.seed(self.seed) 180 | self.actual_state = self.env.reset() 181 | self.trajectory = [] 182 | self.state = [self.pos_from_state(self.seed, self.actual_state)] 183 | 184 | self.trajectory.append(copy.copy(action)) 185 | self.actual_state, reward, done, lol = self.env.step(action) 186 | reward = int(reward) + 1 187 | self.score_so_far += reward 188 | self.state = [self.pos_from_state(self.seed, self.actual_state)] 189 | 190 | if (not self.achieved_has_moved and 191 | self.cur_achieved_goal is not None and 192 | not np.allclose(self.cur_achieved_goal, self.actual_state['achieved_goal'])): 193 | self.achieved_has_moved = True 194 | self.cur_achieved_goal = self.actual_state['achieved_goal'] 195 | 196 | return copy.copy(self.state), reward, done, lol 197 | 198 | def get_pos(self): 199 | return DomainConditionedPosLevel() 200 | 201 | def render_with_known(self, known_positions, resolution, show=True, filename=None, combine_val=max, 202 | get_val=lambda x: x.score, minmax=None): 203 | pass 204 | -------------------------------------------------------------------------------- /policy_based/goexplore_py/globals.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Uber Technologies, Inc. 2 | # 3 | # Licensed under the Uber Non-Commercial License (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at the root directory of this project. 6 | # 7 | # See the License for the specific language governing permissions and 8 | # limitations under the License. 9 | from typing import List, Optional 10 | import os 11 | import logging 12 | logger = logging.getLogger(__name__) 13 | 14 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 15 | 16 | EXP_STRAT_INIT = -1 17 | EXP_STRAT_NONE = 0 18 | EXP_STRAT_RAND = 1 19 | EXP_STRAT_POLICY = 2 20 | 21 | 22 | ACTION_MEANINGS: Optional[List] = None 23 | MASTER_PID = None 24 | BASE_PATH = None 25 | 26 | 27 | def set_action_meanings(meanings=List[str]): 28 | global ACTION_MEANINGS 29 | ACTION_MEANINGS = meanings 30 | logger.debug(f'ACTION_MEANINGS set for process: {os.getpid()}') 31 | 32 | 33 | def get_action_meaning(i): 34 | return ACTION_MEANINGS[i] 35 | 36 | 37 | def get_trajectory(prev_idxs: List[int], actions: List[int], idx: int): 38 | trajectory = [] 39 | if idx is not None: 40 | while prev_idxs[idx] is not None: 41 | action = actions[idx] 42 | idx = idx - prev_idxs[idx] 43 | trajectory.append(action) 44 | trajectory.reverse() 45 | return trajectory 46 | 47 | 48 | def set_master_pid(pid): 49 | global MASTER_PID 50 | MASTER_PID = pid 51 | 52 | 53 | def get_master_pid(): 54 | global MASTER_PID 55 | return MASTER_PID 56 | 57 | 58 | def set_base_path(base_path): 59 | global BASE_PATH 60 | BASE_PATH = base_path 61 | 62 | 63 | def get_base_path(): 64 | global BASE_PATH 65 | return BASE_PATH 66 | -------------------------------------------------------------------------------- /policy_based/goexplore_py/goal_representations.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Uber Technologies, Inc. 
2 | # 3 | # Licensed under the Uber Non-Commercial License (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at the root directory of this project. 6 | # 7 | # See the License for the specific language governing permissions and 8 | # limitations under the License. 9 | import numpy as np 10 | from typing import List, Any 11 | from gym import spaces 12 | 13 | 14 | class AbstractGoalRepresentation: 15 | def get_goal_space(self): 16 | raise NotImplementedError('get_goal_space needs to be implemented.') 17 | 18 | def get(self, current_cell: Any, final_goal: Any, sub_goal: Any): 19 | raise NotImplementedError('get needs to be implemented.') 20 | 21 | 22 | class FlatGoalRep(AbstractGoalRepresentation): 23 | def __init__(self, rep_type: str, rel_final_goal: bool, rel_sub_goal: bool, length_data: Any): 24 | self.rep_type = rep_type 25 | self.rel_final_goal = rel_final_goal 26 | self.rel_sub_goal = rel_sub_goal 27 | self.length_data = length_data 28 | 29 | if self.rep_type == 'final_goal': 30 | self.total_length = self._get_length(self.rel_final_goal, length_data) 31 | elif self.rep_type == 'sub_goal': 32 | self.total_length = self._get_length(self.rel_sub_goal, length_data) 33 | elif self.rep_type == 'final_goal_and_sub_goal': 34 | self.total_length = (self._get_length(self.rel_final_goal, length_data) + 35 | self._get_length(self.rel_sub_goal, length_data)) 36 | else: 37 | raise NotImplementedError('Unknown representation type: ' + self.rep_type) 38 | 39 | def get_goal_space(self): 40 | raise NotImplementedError('get_goal_space needs to be implemented.') 41 | 42 | def get(self, current_cell: Any, final_goal: Any, sub_goal: Any): 43 | if self.rep_type == 'final_goal': 44 | return self._get_goal_rep(final_goal, current_cell, self.rel_final_goal) 45 | elif self.rep_type == 'sub_goal': 46 | return self._get_goal_rep(sub_goal, current_cell, self.rel_sub_goal) 47 | elif self.rep_type == 'final_goal_and_sub_goal': 48 | final_rep = self._get_goal_rep(final_goal, current_cell, self.rel_final_goal) 49 | sub_rep = self._get_goal_rep(sub_goal, current_cell, self.rel_sub_goal) 50 | return np.concatenate((sub_rep, final_rep)) 51 | else: 52 | raise NotImplementedError('Unknown representation type: ' + self.rep_type) 53 | 54 | def _get_length(self, relative: bool, length_data: Any): 55 | raise NotImplementedError('_get_length needs to be implemented.') 56 | 57 | def _get_goal_rep(self, goal: Any, current_cell: Any, relative: bool): 58 | raise NotImplementedError('_get_goal_rep needs to be implemented.') 59 | 60 | 61 | class ScaledGoalRep(FlatGoalRep): 62 | """ 63 | Takes the array from a representation and divides it by normalizing constants. 
64 | """ 65 | def __init__(self, rep_type: str, rel_final_goal: bool, rel_sub_goal: bool, rep_length, norm_const=None, 66 | off_const=None): 67 | super().__init__(rep_type, rel_final_goal, rel_sub_goal, rep_length) 68 | 69 | self.normalizing_constants = np.ones(rep_length) 70 | self.offset_constants = np.zeros(rep_length) 71 | 72 | if norm_const: 73 | self.normalizing_constants = norm_const 74 | if off_const: 75 | self.offset_constants = off_const 76 | 77 | def get_goal_space(self): 78 | return spaces.Box(low=-float('inf'), high=float('inf'), shape=(self.total_length,), dtype=np.float32) 79 | 80 | def _get_goal_rep(self, goal: Any, current_cell: Any, relative: bool): 81 | goal_rep = np.cast[np.float32](goal.as_array()) 82 | goal_rep /= self.normalizing_constants 83 | goal_rep += self.offset_constants 84 | if relative: 85 | current_rep = np.cast[np.float32](current_cell.as_array()) 86 | current_rep /= self.normalizing_constants 87 | current_rep += self.offset_constants 88 | goal_rep -= current_rep 89 | return goal_rep 90 | 91 | def _get_length(self, relative: bool, rep_length): 92 | return rep_length 93 | 94 | 95 | class GoalRepData: 96 | def __init__(self, rep_lengths: List[int], goal: Any, current_loc: Any, relative: bool): 97 | self.rep_lengths = rep_lengths 98 | self.goal_array = goal.as_array() 99 | self.current_array = None 100 | self.relative = relative 101 | if self.relative: 102 | self.current_array = current_loc.as_array() 103 | 104 | def get_index(self, i): 105 | max_value = self.rep_lengths[i] - 1 106 | if self.relative: 107 | feature_index = max_value + int(self.goal_array[i]) - int(self.current_array[i]) 108 | if feature_index < 0: 109 | feature_index = 0 110 | elif feature_index > max_value*2 - 1: 111 | feature_index = max_value*2 - 1 112 | else: 113 | feature_index = int(self.goal_array[i]) 114 | if feature_index < 0: 115 | feature_index = 0 116 | elif feature_index > max_value: 117 | feature_index = max_value 118 | return feature_index 119 | 120 | 121 | class OneHotGoalRep(FlatGoalRep): 122 | """ 123 | Takes the array from a representation and discretizes each value into a one-hot vector. 124 | """ 125 | def __init__(self, rep_type: str, rel_final_goal: bool, rel_sub_goal: bool, rep_lengths: List[int]): 126 | super().__init__(rep_type, rel_final_goal, rel_sub_goal, rep_lengths) 127 | 128 | def get_goal_space(self): 129 | return spaces.Box(low=0, high=1, shape=(self.total_length,), dtype=np.float32) 130 | 131 | def _get_goal_rep(self, goal: Any, current_loc: Any, relative: bool): 132 | cur_index = 0 133 | length = self._get_length(relative, self.length_data) 134 | goal_rep = np.zeros(length) 135 | goal_rep_data = GoalRepData(self.length_data, goal, current_loc, relative) 136 | for i in range(len(self.length_data)): 137 | feature_index = goal_rep_data.get_index(i) 138 | goal_rep[cur_index + feature_index] = 1.0 139 | cur_index += self.length_data[i] 140 | return goal_rep 141 | 142 | def _get_length(self, relative: bool, rep_lengths): 143 | if relative: 144 | return (sum(rep_lengths) * 2) - 1 145 | else: 146 | return sum(rep_lengths) 147 | 148 | 149 | class PosFilterGoalRep(AbstractGoalRepresentation): 150 | """ 151 | Takes the x and y attributes from a representation and turns it into an image sheet that can be stacked as a filter. 
152 | """ 153 | 154 | def get(self, current_cell: Any, final_goal: Any, sub_goal: Any): 155 | if self.rep_type == 'final_goal': 156 | return self._get_goal_rep(final_goal) 157 | elif self.rep_type == 'sub_goal': 158 | return self._get_goal_rep(sub_goal) 159 | elif self.rep_type == 'final_goal_and_sub_goal': 160 | final_rep = self._get_goal_rep(final_goal) 161 | sub_rep = self._get_goal_rep(sub_goal) 162 | return np.concatenate((sub_rep, final_rep)) 163 | else: 164 | raise NotImplementedError('Unknown representation type: ' + self.rep_type) 165 | 166 | def __init__(self, shape, x_res, y_res, x_offset=0, y_offset=0, goal_value=1, norm_const=None, pos_only=False, 167 | rep_type='final_goal'): 168 | self.shape = shape 169 | self.x_res = x_res 170 | self.y_res = y_res 171 | self.x_offset = x_offset 172 | self.y_offset = y_offset 173 | self.goal_value = goal_value 174 | self.norm_const = norm_const 175 | self.rep_type = rep_type 176 | if norm_const is None: 177 | self.norm_const = np.ones(shape[-1] - 1) 178 | self.pos_only = pos_only 179 | 180 | def get_goal_space(self): 181 | return spaces.Box(low=0, high=255, shape=self.shape, dtype=np.float32) 182 | 183 | def _get_goal_rep(self, goal: Any): 184 | goal_rep = np.zeros(self.shape) 185 | x = self.x_offset + goal.get_x() * self.x_res 186 | y = self.y_offset + goal.get_y() * self.y_res 187 | goal_rep[x:x + self.x_res, y:y + self.y_res, 0] = self.goal_value 188 | if not self.pos_only: 189 | non_pos_features = goal.non_pos_as_array() 190 | for i, feature in enumerate(non_pos_features): 191 | goal_rep[:, :, i] = (feature / self.norm_const[i]) * 255 192 | 193 | return goal_rep 194 | -------------------------------------------------------------------------------- /policy_based/goexplore_py/import_ai.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Uber Technologies, Inc. 2 | # 3 | # Licensed under the Uber Non-Commercial License (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at the root directory of this project. 6 | # 7 | # See the License for the specific language governing permissions and 8 | # limitations under the License. 9 | from __future__ import print_function 10 | import warnings as _warnings 11 | import logging 12 | import random 13 | import matplotlib.pyplot as plt 14 | import matplotlib.patches as patches 15 | import matplotlib.patheffects as patheffects 16 | import numpy as np 17 | 18 | 19 | def is_notebook(): 20 | try: 21 | from IPython import get_ipython as _get_ipython 22 | if 'IPKernelApp' not in _get_ipython().config: # pragma: no cover 23 | raise ImportError("console") 24 | except ImportError: 25 | return False 26 | return True 27 | 28 | 29 | if not is_notebook(): 30 | import matplotlib 31 | matplotlib.use('Agg') 32 | 33 | # Known to be benign: https://github.com/ContinuumIO/anaconda-issues/issues/6678#issuecomment-337279157 34 | _warnings.filterwarnings('ignore', 'numpy.dtype size changed, may indicate binary incompatibility. 
Expected 96, got 88') 35 | 36 | try: 37 | import cv2 38 | except ModuleNotFoundError: 39 | _warnings.warn('cv2 not found') 40 | 41 | try: 42 | import gym 43 | except ModuleNotFoundError: 44 | _warnings.warn('gym not found') 45 | 46 | try: 47 | if not is_notebook(): 48 | from tqdm import tqdm, trange 49 | else: 50 | from tqdm import tqdm_notebook as tqdm 51 | from tqdm import tnrange as trange 52 | except ModuleNotFoundError: 53 | _warnings.warn('tqdm not found') 54 | 55 | 56 | class IgnoreNoHandles(logging.Filter): 57 | def filter(self, record): 58 | if record.getMessage() == 'No handles with labels found to put in legend.': 59 | return 0 60 | return 1 61 | 62 | 63 | _plt_logger = logging.getLogger('matplotlib.legend') 64 | _plt_logger.addFilter(IgnoreNoHandles()) 65 | 66 | 67 | def show_img(im, figsize=None, ax=None, grid=False): 68 | if not ax: 69 | fig, ax = plt.subplots(figsize=figsize) 70 | ax.imshow(im) 71 | ax.set_xticks(np.linspace(0, 224, 8)) 72 | ax.set_yticks(np.linspace(0, 224, 8)) 73 | if grid: 74 | ax.grid() 75 | ax.set_yticklabels([]) 76 | ax.set_xticklabels([]) 77 | return ax 78 | 79 | 80 | def draw_outline(o, lw): 81 | o.set_path_effects([patheffects.Stroke( 82 | linewidth=lw, foreground='black'), patheffects.Normal()]) 83 | 84 | 85 | def draw_rect(ax, b, color='white'): 86 | patch = ax.add_patch(patches.Rectangle(b[:2], *b[-2:], fill=False, edgecolor=color, lw=2)) 87 | draw_outline(patch, 4) 88 | 89 | 90 | def draw_text(ax, xy, txt, sz=14, color='white'): 91 | text = ax.text(*xy, txt, verticalalignment='top', color=color, fontsize=sz, weight='bold') 92 | draw_outline(text, 1) 93 | 94 | 95 | class CircularMemory: 96 | def __init__(self, size): 97 | self.size = size 98 | self.mem = [] 99 | self.start_idx = 0 100 | 101 | def add(self, entry): 102 | if len(self.mem) < self.size: 103 | self.mem.append(entry) 104 | else: 105 | self.mem[self.start_idx] = entry 106 | self.start_idx = (self.start_idx + 1) % self.size 107 | 108 | def sample(self, n): 109 | return random.sample(self.mem, n) 110 | 111 | def __len__(self): 112 | return len(self.mem) 113 | 114 | def __getitem__(self, i): 115 | assert i < len(self) 116 | return self.mem[(self.start_idx + i) % self.size] 117 | 118 | -------------------------------------------------------------------------------- /policy_based/goexplore_py/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Uber Technologies, Inc. 2 | # 3 | # Licensed under the Uber Non-Commercial License (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at the root directory of this project. 6 | # 7 | # See the License for the specific language governing permissions and 8 | # limitations under the License. 
9 | class SimpleLogger: 10 | def __init__(self, file_name): 11 | self.file_handle = open(file_name, 'w') 12 | self.column_names = [] 13 | self.values = [] 14 | self.first_line = True 15 | 16 | def write(self, name, value): 17 | if self.first_line: 18 | self.column_names.append(name) 19 | self.values.append(value) 20 | 21 | def flush(self): 22 | if self.first_line: 23 | self.first_line = False 24 | for i, column_name in enumerate(self.column_names): 25 | self.file_handle.write(column_name) 26 | if i < len(self.column_names) - 1: 27 | self.file_handle.write(', ') 28 | self.file_handle.write('\n') 29 | for i, value in enumerate(self.values): 30 | self.file_handle.write(str(value)) 31 | if i < len(self.column_names) - 1: 32 | self.file_handle.write(', ') 33 | self.file_handle.write('\n') 34 | self.file_handle.flush() 35 | self.values = [] 36 | 37 | def close(self): 38 | self.file_handle.close() 39 | -------------------------------------------------------------------------------- /policy_based/goexplore_py/mpi_support.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Uber Technologies, Inc. 2 | # 3 | # Licensed under the Uber Non-Commercial License (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at the root directory of this project. 6 | # 7 | # See the License for the specific language governing permissions and 8 | # limitations under the License. 9 | COMM_WORLD = None 10 | COMM_TYPE_SHARED = None 11 | 12 | 13 | def init_mpi(): 14 | global COMM_WORLD 15 | global COMM_TYPE_SHARED 16 | import mpi4py.rc 17 | mpi4py.rc.initialize = False 18 | from mpi4py import MPI 19 | COMM_WORLD = MPI.COMM_WORLD 20 | COMM_TYPE_SHARED = MPI.COMM_TYPE_SHARED 21 | 22 | 23 | def get_comm_world(): 24 | return COMM_WORLD 25 | 26 | 27 | def get_comm_type_shared(): 28 | return COMM_TYPE_SHARED 29 | -------------------------------------------------------------------------------- /policy_based/goexplore_py/profiler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Uber Technologies, Inc. 2 | # 3 | # Licensed under the Uber Non-Commercial License (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at the root directory of this project. 6 | # 7 | # See the License for the specific language governing permissions and 8 | # limitations under the License. 
9 | import linecache 10 | 11 | 12 | def display_top(snapshot, key_type='traceback', limit=20): 13 | top_stats = snapshot.statistics(key_type) 14 | 15 | print("Top %s lines" % limit) 16 | for index, stat in enumerate(top_stats[:limit], 1): 17 | frame = stat.traceback[0] 18 | filename = frame.filename 19 | print("#%s: %s:%s: %.1f MB" 20 | % (index, filename, frame.lineno, stat.size / (1024*1024))) 21 | line = linecache.getline(frame.filename, frame.lineno).strip() 22 | 23 | for frame in stat.traceback.format(): 24 | print(frame) 25 | 26 | if line: 27 | print(' %s' % line) 28 | 29 | other = top_stats[limit:] 30 | if other: 31 | size = sum(stat.size for stat in other) 32 | print("%s other: %.1f MB" % (len(other), size / (1024*1024))) 33 | total = sum(stat.size for stat in top_stats) 34 | print("Total allocated size: %.1f MB" % (total / (1024*1024))) 35 | -------------------------------------------------------------------------------- /policy_based/goexplore_py/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Uber Technologies, Inc. 2 | # 3 | # Licensed under the Uber Non-Commercial License (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at the root directory of this project. 6 | # 7 | # See the License for the specific language governing permissions and 8 | # limitations under the License. 9 | import time 10 | import random 11 | import numpy as np 12 | import os 13 | import glob 14 | import hashlib 15 | from contextlib import contextmanager 16 | 17 | 18 | class TimedPickle: 19 | def __init__(self, data, name, enabled=True): 20 | self.data = data 21 | self.name = name 22 | self.enabled = enabled 23 | 24 | def __getstate__(self): 25 | return time.time(), self.data, self.name, self.enabled 26 | 27 | def __setstate__(self, s): 28 | tstart, self.data, self.name, self.enabled = s 29 | if self.enabled: 30 | print(f'pickle time for {self.name} = {time.time() - tstart} seconds') 31 | 32 | 33 | @contextmanager 34 | def use_seed(seed): 35 | # Save all the states 36 | python_state = random.getstate() 37 | np_state = np.random.get_state() 38 | 39 | # Seed all the rngs (note: adding different values to the seeds 40 | # in case the same underlying RNG is used by all and in case 41 | # that could be a problem. Probably not necessary) 42 | random.seed(seed + 2) 43 | np.random.seed(seed + 3) 44 | 45 | # Yield control! 
46 | yield 47 | 48 | # Reset the rng states 49 | random.setstate(python_state) 50 | np.random.set_state(np_state) 51 | 52 | 53 | def get_code_hash(): 54 | cur_dir = os.path.dirname(os.path.realpath(__file__)) 55 | all_code = '' 56 | for f in sorted(glob.glob(cur_dir + '**/*.py', recursive=True)): 57 | # We assume all whitespace is irrelevant, as well as comments 58 | with open(f) as fh: 59 | for line in fh: 60 | line = line.partition('#')[0] 61 | line = line.rstrip() 62 | 63 | all_code += ''.join(line.split()) 64 | 65 | code_hash = hashlib.sha256(all_code.encode('utf8')).hexdigest() 66 | 67 | return code_hash 68 | 69 | 70 | def clip(value, low, high): 71 | return max(min(value, high), low) 72 | -------------------------------------------------------------------------------- /policy_based/requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow==1.15.2 2 | mpi4py 3 | gym[Atari] 4 | horovod 5 | baselines@git+https://github.com/openai/baselines@ea25b9e8b234e6ee1bca43083f8f3cf974143998 6 | Pillow 7 | imageio 8 | matplotlib 9 | loky 10 | joblib 11 | dataclasses 12 | opencv-python 13 | cloudpickle -------------------------------------------------------------------------------- /policy_based/run_policy_based_ge_montezuma.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # The settings below are for testing the code locally 4 | # For the full experiment settings, change each setting to each "full experiment" value. 5 | 6 | # Full experiment: 16 7 | NB_MPI_WORKERS=2 8 | 9 | # Full experiment: 16 10 | NB_ENVS_PER_WORKER=2 11 | 12 | # Full experiment: different for each run 13 | SEED=0 14 | 15 | # Full experiment: 200000000 16 | CHECKPOINT=10000 17 | 18 | 19 | # The game is run with both sticky actions and noops. Also, for Montezuma's Revenge, the episode ends on death. 20 | GAME_OPTIONS="--sticky_actions --noops --end_on_death" 21 | 22 | # Both the game reward factor (game_reward_factor) and the trajectory reward factor (goal_reward_factor) are 1, except for reaching the final cell, for which the reward is 3. 23 | # Extrinsic (game) rewards are clipped to [-2, 2]. Because most Atari games have large rewards, this usually means that extrinsic rewards are twice that of the trajectory rewards. 24 | REWARD_OPTIONS="--game_reward_factor 1 --goal_reward_factor 1 --clip_game_reward 1 --rew_clip_range=-2,2 --final_goal_reward 3" 25 | 26 | # Cell selection is relative to: 1 / (1 + 0.5*number_of_actions_taken_in_cell). 27 | CELL_SELECTION_OPTIONS="--selector weighted --selector_weights=attr,nb_actions_taken_in_cell,1,1,0.5 --base_weight 0" 28 | 29 | # When the agent takes too long to reach the next cell, its entropy increases according to (inc_ent_fac*steps)^ent_inc_power. 30 | # When exploring, this entropy increase starts when it takes more than expl_inc_ent_thresh (50) actions to reach a new cell. 31 | # When returning, entropy increase starts relative to the time it originally took to reach the target cell. 32 | ENTROPY_INC_OPTIONS="--entropy_strategy dynamic_increase --inc_ent_fac 0.01 --ent_inc_power 2 --ret_inc_ent_fac 1 --expl_inc_ent_thresh 50 --expl_ent_reset=on_new_cell --legacy_entropy 0" 33 | 34 | # The cell representation for Montezuma's Revenge is a domain knowledge representation including level, room, number of keys, and the x, y coordinate of the agent.
35 | # The x, y coordinate is discretized into bins of 36 by 18 pixels (note that the pixels on the x axis are doubled, so this is 18 by 18 on the original frame). 36 | CELL_REPRESENTATION_OPTIONS="--cell_representation level_room_keys_x_y --resolution=36,18" 37 | 38 | # When following a trajectory, the agent is allowed to reach the goal cell, or any of the subsequent soft_traj_win_size (10) - 1 cells. 39 | # While returning, the episode is terminated if it takes more than max_actions_to_goal (1000) actions to reach the current goal. 40 | # While exploring, the episode is terminated if it takes more than max_actions_to_new_cell (1000) actions to discover a new cell. 41 | # When the final cell is reached, there is a random_exp_prob (0.5) chance that we explore by taking random actions, rather than by sampling from the policy. 42 | EPISODE_OPTIONS="--trajectory_tracker sparse_soft --soft_traj_win_size 10 --random_exp_prob 0.5 --max_actions_to_goal 1000 --max_actions_to_new_cell 1000 --delay 0" 43 | 44 | CHECKPOINT_OPTIONS="--checkpoint_compute ${CHECKPOINT} --clear_checkpoints trajectory" 45 | TRAINING_OPTIONS="--goal_rep onehot --gamma 0.99 --learning_rate=2.5e-4 --no_exploration_gradients --sil=sil --max_compute_steps 12000000000" 46 | MISC_OPTIONS="--low_prob_traj_tresh 0.01 --start_method spawn --log_info INFO --log_files __main__" 47 | mpirun -n ${NB_MPI_WORKERS} python3 goexplore_start.py --base_path ~/temp --seed ${SEED} --nb_envs ${NB_ENVS_PER_WORKER} ${REWARD_OPTIONS} ${CELL_SELECTION_OPTIONS} ${ENTROPY_INC_OPTIONS} ${CHECKPOINT_OPTIONS} ${CELL_REPRESENTATION_OPTIONS} ${EPISODE_OPTIONS} ${GAME_OPTIONS} ${TRAINING_OPTIONS} ${MISC_OPTIONS} 48 | -------------------------------------------------------------------------------- /policy_based/run_policy_based_ge_pitfall.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | # 3 | # Copyright (c) 2020 Uber Technologies, Inc. 4 | # 5 | # Licensed under the Uber Non-Commercial License (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at the root directory of this project. 8 | # 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | # 12 | # 13 | # The settings below are for testing the code locally 14 | # For the full experiment settings, change each setting to each "full experiment" value. 15 | 16 | # Full experiment: 16 17 | NB_MPI_WORKERS=2 18 | 19 | # Full experiment: 16 20 | NB_ENVS_PER_WORKER=2 21 | 22 | # Full experiment: different for each run 23 | SEED=0 24 | 25 | # Full experiment: 200000000 26 | CHECKPOINT=10000 27 | 28 | 29 | # The game is run with both sticky actions and noops. 30 | GAME_OPTIONS="--game pitfall --sticky_actions --noops" 31 | 32 | # Both the game reward factor (game_reward_factor) and the trajectory reward factor (goal_reward_factor) are 1, except for reaching the final cell, for which the reward is 3. 33 | # Extrinsic (game) rewards are clipped to [-2, 2]. Because most Atari games have large rewards, this usually means that extrinsic rewards are twice that of the trajectory rewards. 34 | REWARD_OPTIONS="--game_reward_factor 1 --goal_reward_factor 1 --clip_game_reward 1 --rew_clip_range=-2,2 --final_goal_reward 3" 35 | 36 | # Cell selection is relative to: 1 / (1 + 0.5*number_of_actions_taken_in_cell).
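# For example, with this weighting a cell in which no actions have been taken yet has selection weight
# 1 / (1 + 0.5*0) = 1.0, while a cell in which 10 actions have already been taken has weight
# 1 / (1 + 0.5*10) ~= 0.17, so cells that have received little exploration are selected more often.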
37 | CELL_SELECTION_OPTIONS="--selector weighted --selector_weights=attr,nb_actions_taken_in_cell,1,1,0.5 --base_weight 0" 38 | 39 | # When the agent takes too long to reach the next cell, its entropy increases according to (inc_ent_fac*steps)^ent_inc_power. 40 | # When exploring, this entropy increase starts when it takes more than expl_inc_ent_thresh (50) actions to reach a new cell. 41 | # When returning, entropy increase starts relative to the time it originally took to reach the target cell. 42 | ENTROPY_INC_OPTIONS="--entropy_strategy dynamic_increase --inc_ent_fac 0.01 --ent_inc_power 2 --ret_inc_ent_fac 1 --expl_inc_ent_thresh 50 --expl_ent_reset=on_new_cell --legacy_entropy 0" 43 | 44 | # The cell representation for Pitfall is a domain knowledge representation including the room and the x, y coordinate of the agent. 45 | # The x, y coordinate is discretized into bins of 36 by 18 pixels (note that the pixels on the x axis are doubled, so this is 18 by 18 on the original frame). 46 | CELL_REPRESENTATION_OPTIONS="--cell_representation room_x_y --resolution=36,18" 47 | 48 | # When following a trajectory, the agent is allowed to reach the goal cell, or any of the subsequent soft_traj_win_size (10) - 1 cells. 49 | # While returning, the episode is terminated if it takes more than max_actions_to_goal (1000) actions to reach the current goal. 50 | # While exploring, the episode is terminated if it takes more than max_actions_to_new_cell (1000) actions to discover a new cell. 51 | # When the final cell is reached, there is a random_exp_prob (0.5) chance that we explore by taking random actions, rather than by sampling from the policy. 52 | EPISODE_OPTIONS="--trajectory_tracker sparse_soft --soft_traj_win_size 10 --random_exp_prob 0.5 --max_actions_to_goal 1000 --max_actions_to_new_cell 1000 --delay 0" 53 | 54 | CHECKPOINT_OPTIONS="--checkpoint_compute ${CHECKPOINT} --clear_checkpoints trajectory" 55 | TRAINING_OPTIONS="--goal_rep onehot --gamma 0.99 --learning_rate=2.5e-4 --no_exploration_gradients --sil=sil --max_compute_steps 10000000000" 56 | MISC_OPTIONS="--low_prob_traj_tresh 0.01 --start_method spawn --log_info INFO --log_files __main__" 57 | mpirun -n ${NB_MPI_WORKERS} python3 goexplore_start.py --base_path ~/temp --seed ${SEED} --nb_envs ${NB_ENVS_PER_WORKER} ${REWARD_OPTIONS} ${CELL_SELECTION_OPTIONS} ${ENTROPY_INC_OPTIONS} ${CHECKPOINT_OPTIONS} ${CELL_REPRESENTATION_OPTIONS} ${EPISODE_OPTIONS} ${GAME_OPTIONS} ${TRAINING_OPTIONS} ${MISC_OPTIONS} 58 | -------------------------------------------------------------------------------- /robustified/README.md: -------------------------------------------------------------------------------- 1 | # Go-Explore 2 | 3 | ## Requirements 4 | 5 | Tested with Python 3.6. `requirements.txt` gives the exact libraries used on a test machine 6 | able to run all phases on Atari. 7 | 8 | Required libraries for the exploration phase: 9 | - matplotlib 10 | - loky==2.3.1 11 | - dataclasses 12 | - gym 13 | - opencv-python 14 | 15 | The ALE/atari-py is not part of Go-Explore. If you are interested in running Go-Explore on Atari environments (for example to reproduce our experiments), you may install gym\[atari\] instead of just gym. Doing so will install atari-py. atari-py is licensed under GPLv2.
16 | 17 | Additional libraries for demo generation: 18 | - imageio-ffmpeg (optional) 19 | - fire 20 | - tqdm 21 | 22 | Additional libraries for robustification: 23 | - tensorflow==1.5.2 (or equivalent tensorflow-gpu) 24 | - pandas 25 | - horovod 26 | - filelock 27 | - mpi4py 28 | - baselines 29 | - To avoid having to install mujoco-py, install commit 6d1c6c78d38dd25799145026a590cc584ea22c88 (`pip install git+git://github.com/openai/baselines.git@6d1c6c78d38dd25799145026a590cc584ea22c88`) 30 | 31 | To run robustification, you will need to clone [uber-research/atari-reset](https://github.com/uber-research/atari-reset) (note: this is an improved fork of the original project, which you can find at [openai/atari-reset](https://github.com/openai/atari-reset)) and 32 | put it, copy it or link to it as `atari_reset` in the same folder as `goexplore_py`. 33 | E.g. you could run: 34 | 35 | `git clone https://github.com/uber-research/atari-reset atari_reset` 36 | 37 | 38 | Running the robotics environments requires a local installation of MuJoCo 2.0 (1.5 may work too), 39 | as well as a corresponding version of mujoco-py. mujoco-py is not included in requirements.txt as it is unnecessary 40 | for running the Atari environments. 41 | 42 | ## Usage 43 | 44 | The exploration phase experiments on Atari with a downscaled representation can be run with: 45 | 46 | `./phase1_downscaled.sh <game> <path> <timesteps>` 47 | 48 | Running the exploration phase with domain knowledge on Montezuma's Revenge and Pitfall is done using: 49 | 50 | `./phase1_montezuma.sh <path> <timesteps>` 51 | 52 | and 53 | 54 | `./phase1_pitfall.sh <path> <timesteps>` 55 | 56 | If any argument is not supplied, a default value will be used. The default game is MontezumaRevenge for 57 | the downscaled experiments (for the domain knowledge experiment, there is no game argument), the default 58 | path is results and the default number of timesteps is 500,000 for the downscaled version, corresponding to the 2 billion frames used in the 59 | paper (due to frame skipping, one timestep corresponds to 4 frames), and 250,000 for the domain knowledge version, corresponding 60 | to the 1 billion frames used in the paper. 61 | 62 | The exploration phase produces a folder called `results`, and subfolders for each experiment, of the form 63 | `0000_fb6be589a3dc44c1b561336e04c6b4cb`, where the first element is an automatically increasing 64 | experiment id and the second element is a random string that helps prevent race condition issues if 65 | two experiments are started at the same time and assigned the same id. 66 | 67 | To generate demonstrations for Atari, run 68 | 69 | `./gen_demo_atari.sh <source> <destination> <game>` 70 | 71 | source is mandatory and is the folder produced by the exploration phase, destination will default to `<source>_demo`, 72 | game defaults to MontezumaRevenge. You may also pass `--render` as a fourth argument to generate a video of your 73 | exploration phase agent playing the game. 74 | 75 | To robustify, put a set of `.demo` files from different runs of Phase 1 into a folder 76 | (we used 10 in all cases, a single demonstration can also work, but is less 77 | likely to succeed). Then run `./phase2_atari.sh <game> <demo_folder> <results_folder> <timesteps>`. The default game is MontezumaRevenge, default demo folder is `demos`, default results folder is `results` 78 | and default timesteps is `2,500,000` (corresponding to 10 billion frames as used in the paper for most games). The robustification 79 | code doesn't handle relative paths well so it is recommended to give it absolute paths.
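Before spending GPU time on robustification it can be useful to peek inside a generated demonstration. The snippet below is only an illustrative sketch: it assumes the `.demo` file is a bz2-compressed pickle (matching the `--compress=bz2` flag used by `gen_demo_atari.sh`) whose dictionary uses the keys written by `AtariDemo.save_to_file` in `gen_demo/atari_demo/wrappers.py`; the file name in the example is hypothetical, and if `new_gen_demo.py` writes a different layout the keys will need to be adjusted.

```python
import bz2
import pickle

def summarize_demo(path):
    """Print a short summary of a .demo file (illustrative sketch, see note above)."""
    with bz2.open(path, 'rb') as f:
        dat = pickle.load(f)
    # Keys assumed to match AtariDemo.save_to_file in gen_demo/atari_demo/wrappers.py.
    actions = dat.get('actions', [])
    rewards = dat.get('rewards', [])
    print(f'{path}: {len(actions)} actions, total reward {sum(rewards)}')

if __name__ == '__main__':
    summarize_demo('demos/MontezumaRevenge_0.demo')  # hypothetical file name
```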
80 | 81 | Important: all of the robustification results in the paper were performed with 8 GPUs through MPI. The `phase2*.sh` scripts do not start MPI themselves, you will need to do so yourself when calling them, e.g. by running `mpirun -np 8 ./phase2_atari.sh <game> <demo_folder> <results_folder> <timesteps>`. 82 | 83 | You may then test the performance of your trained neural network using 84 | `./phase2_atari_test.sh <model>` 85 | where `<model>` is one of the files produced by the robustification phase and printed in the log as `Saving to ...`. 86 | This will produce `.json` files for each possible number of no-ops (from 0 to 30) with scores, levels 87 | and exact action sequences produced by the test runs. 88 | 89 | For the fetch environments, the steps are similar but with the `.sh` files containing `fetch`. In this context, `game` 90 | is the target shelf identifier, with the valid identifiers being `0001`, `0010`, `0100` and `1000`. These scripts also need to be run on 8 GPUs with MPI as described above to reproduce our exact results. 91 | 92 | Note that the `gen_demo` 93 | script for fetch produces 10 demos from a single exploration phase run, so you do not need to run the exploration phase 94 | multiple times and combine the demo files to run robustification. 95 | 96 | Crucially, the fetch environment requires that mujoco-py be installed, which itself requires that MuJoCo 2.0 be installed. 97 | It is also important that the folder which contains `goexplore_py` be in the PYTHONPATH during robustification. 98 | 99 | Finally, the two controls (vanilla PPO and PPO + IM) for fetch can be run using `./control_ppo_fetch.sh` and `./control_im_fetch.sh`. 100 | Both take as parameters the target shelf, result output folder and number of frames. -------------------------------------------------------------------------------- /robustified/control_im_fetch.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # Copyright (c) 2020 Uber Technologies, Inc. 4 | 5 | # Licensed under the Uber Non-Commercial License (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at the root directory of this project. 8 | 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | 13 | 14 | game=${1:-1000} 15 | results=${2:-`pwd`/results} 16 | frames=${3:-1000000000} 17 | 18 | python atari_reset/train_atari.py --ffmem=128 --ffsh=1x256 --learning_r=0.0001 --demo_sel=normalize_by_target --nrst=40 --ent_coef=1e-05 --sil_coef=0.1 --extra_frames_exp_factor=4 --allowed_lag=10 --fetch_max_steps=300 --intrinsic_reward_weight=1.0 --num_timesteps $frames --fetch_target_location=$game --fetch_type=boxes_1 --fetch_nsubsteps=80 --fetch_total_timestep=0.08 --test_from_start --nenvs=128 --n_sil_envs=0 --sd_multiply_explore=2 --inc_entropy_threshold=10 --fetch_incl_extra_full_state --game=fetch --demo __nodemo__ --gamma=0.99 --vf_coef=0.5 --steps_per_demo=100 --move_threshold=0.1 --save_path=$results 19 | -------------------------------------------------------------------------------- /robustified/control_ppo_fetch.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # Copyright (c) 2020 Uber Technologies, Inc. 4 | 5 | # Licensed under the Uber Non-Commercial License (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at the root directory of this project.
8 | 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | 13 | 14 | game=${1:-1000} 15 | results=${2:-`pwd`/results} 16 | frames=${3:-1000000000} 17 | 18 | python atari_reset/train_atari.py --ffmem=128 --ffsh=1x256 --learning_r=0.0001 --demo_sel=normalize_by_target --nrst=40 --ent_coef=1e-05 --sil_coef=0.1 --extra_frames_exp_factor=4 --allowed_lag=10 --fetch_max_steps=300 --num_timesteps $frames --fetch_target_location=$game --fetch_type=boxes_1 --fetch_nsubsteps=80 --fetch_total_timestep=0.08 --test_from_start --nenvs=128 --n_sil_envs=0 --sd_multiply_explore=2 --inc_entropy_threshold=10 --fetch_incl_extra_full_state --game=fetch --demo __nodemo__ --gamma=0.99 --vf_coef=0.5 --steps_per_demo=100 --move_threshold=0.1 --save_path=$results 19 | 20 | -------------------------------------------------------------------------------- /robustified/gen_demo/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | # Copyright (c) 2020 Uber Technologies, Inc. 3 | 4 | # Licensed under the Uber Non-Commercial License (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at the root directory of this project. 7 | 8 | # See the License for the specific language governing permissions and 9 | # limitations under the License. 10 | 11 | 12 | -------------------------------------------------------------------------------- /robustified/gen_demo/atari_demo/LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License 2 | 3 | Copyright (c) 2018 OpenAI (http://openai.com) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 
22 | -------------------------------------------------------------------------------- /robustified/gen_demo/atari_demo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/go-explore/702fb9c7a9aeecf2872d07ced236c730d5536a8f/robustified/gen_demo/atari_demo/__init__.py -------------------------------------------------------------------------------- /robustified/gen_demo/atari_demo/cloned_vec_env.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from multiprocessing import Process, Pipe 3 | import gym 4 | from baselines.common.vec_env.subproc_vec_env import CloudpickleWrapper 5 | 6 | class ClonedEnv(gym.Wrapper): 7 | def __init__(self, env, possible_actions_dict, best_action_dict, seed): 8 | gym.Wrapper.__init__(self, env) 9 | self.possible_actions_dict = possible_actions_dict 10 | self.best_action_dict = best_action_dict 11 | self.state = None 12 | self.rng = np.random.RandomState(seed) 13 | self.just_initialized = True 14 | self.l = 0 15 | self.r = 0 16 | 17 | def step(self, action=None): 18 | if self.state in self.possible_actions_dict: 19 | possible_actions = list(self.possible_actions_dict[self.state]) 20 | action = possible_actions[self.rng.randint(len(possible_actions))] 21 | obs, reward, done, info = self.env.step(action) 22 | self.l += 1 23 | self.r += reward 24 | self.state = self.env.unwrapped._get_ram().tostring() 25 | if self.state in self.possible_actions_dict: # still in known territory 26 | info['possible_actions'] = self.possible_actions_dict[self.state] 27 | if self.state in self.best_action_dict: 28 | info['best_action'] = self.best_action_dict[self.state] 29 | else: 30 | done = True 31 | past_l = self.l 32 | past_r = self.r 33 | self.l = 0 34 | self.r = 0 35 | if past_l > 0: 36 | info['episode'] = {'r': past_r, 'l': past_l} 37 | else: 38 | raise Exception('stepping cloned env without resetting') 39 | 40 | return obs, reward, done, info 41 | 42 | def reset(self): 43 | obs = self.env.reset() 44 | if isinstance(obs, tuple): 45 | obs,info = obs 46 | else: 47 | info = {} 48 | 49 | self.state = self.env.unwrapped._get_ram().tostring() 50 | if self.state in self.best_action_dict: 51 | info['best_action'] = self.best_action_dict[self.state] 52 | for randop in range(self.rng.randint(30)): # randomize starting point 53 | obs, reward, done, info = self._step(None) 54 | 55 | if self.just_initialized: 56 | self.just_initialized = False 57 | for randops in range(self.rng.randint(50000)): # randomize starting point further 58 | obs, reward, done, info = self._step(None) 59 | if done: 60 | obs, info = self._reset() 61 | 62 | return obs, info 63 | 64 | def get_best_actions_from_infos(infos): 65 | k = len(infos) 66 | best_actions = [0] * k 67 | action_masks = [1] * k 68 | for i in range(k): 69 | if 'best_action' in infos[i]: 70 | best_actions[i] = infos[i]['best_action'] 71 | action_masks[i] = 0 72 | return best_actions, action_masks 73 | 74 | def get_available_actions_from_infos(infos, n_actions): 75 | k = len(infos) 76 | best_actions = np.zeros((k,n_actions), dtype=np.uint8) 77 | action_masks = [1] * k 78 | for i in range(k): 79 | if 'possible_actions' in infos[i]: 80 | action_masks[i] = 0 81 | for j in infos[i]['possible_actions']: 82 | best_actions[i,j] = 1 83 | return best_actions, action_masks 84 | 85 | def worker2(nr, remote, env_fn_wrapper, mode): 86 | env = env_fn_wrapper.x() 87 | while True: 88 | cmd,count = remote.recv() 89 | if 
cmd == 'step': 90 | obs = [] 91 | rews = [] 92 | dones = [] 93 | infos = [] 94 | for step in range(count): 95 | ob, reward, done, info = env.step(0) # action is ignored in ClonedEnv downstream 96 | if done: 97 | ob = env.reset() 98 | if isinstance(ob, tuple): 99 | ob, new_info = ob 100 | info.update(new_info) 101 | if 'episode' in info: 102 | epinfo = info['episode'] 103 | print('simulator thread %d completed demo run with total return %d obtained in %d steps' % (nr, epinfo["r"], epinfo["l"])) 104 | obs.append(ob) 105 | rews.append(reward) 106 | dones.append(done) 107 | infos.append(info) 108 | if mode == 'best': 109 | best_actions, action_masks = get_best_actions_from_infos(infos) 110 | else: 111 | best_actions, action_masks = get_available_actions_from_infos(infos, env.action_space.n) 112 | remote.send((obs, rews, dones, best_actions, action_masks)) 113 | elif cmd == 'reset': 114 | ob = env.reset() 115 | if isinstance(ob, tuple): 116 | ob,info = ob 117 | else: 118 | info = {} 119 | remote.send((ob,info)) 120 | elif cmd == 'close': 121 | remote.close() 122 | break 123 | elif cmd == 'get_spaces': 124 | remote.send((env.action_space, env.observation_space)) 125 | else: 126 | raise NotImplementedError(str(cmd) + ' action not implemented in worker') 127 | 128 | class ClonedVecEnv(object): 129 | def __init__(self, env_fns, mode='best'): 130 | self.nenvs = len(env_fns) 131 | self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(self.nenvs)]) 132 | self.ps = [Process(target=worker2, args=(nr, work_remote, CloudpickleWrapper(env_fn), mode)) 133 | for (nr, work_remote, env_fn) in zip(range(self.nenvs), self.work_remotes, env_fns)] 134 | for p in self.ps: 135 | p.start() 136 | self.remotes[0].send(('get_spaces', None)) 137 | self.action_space, self.observation_space = self.remotes[0].recv() 138 | self.steps_taken = 0 139 | 140 | def step(self, time_steps=128): 141 | for remote in self.remotes: 142 | remote.send(('step', time_steps)) 143 | results = [remote.recv() for remote in self.remotes] 144 | obs, rews, dones, best_actions, action_masks = [np.stack(x) for x in zip(*results)] 145 | return obs, rews, dones, best_actions, action_masks 146 | 147 | def reset(self): 148 | for remote in self.remotes: 149 | remote.send(('reset', None)) 150 | results = [remote.recv() for remote in self.remotes] 151 | obs, infos = zip(*results) 152 | best_actions, action_masks = [np.stack(x) for x in get_best_actions_from_infos(infos)] 153 | return np.stack(obs), best_actions, action_masks 154 | 155 | def close(self): 156 | for remote in self.remotes: 157 | remote.send(('close', None)) 158 | for p in self.ps: 159 | p.join() 160 | 161 | def make_cloned_vec_env(nenvs, env_id, possible_actions_dict, best_action_dict, wrappers, mode='best'): 162 | def make_env(rank): 163 | def env_fn(): 164 | env = gym.make(env_id) 165 | env = ClonedEnv(env, possible_actions_dict, best_action_dict, rank) 166 | env = wrappers(env) 167 | return env 168 | return env_fn 169 | 170 | return ClonedVecEnv([make_env(i) for i in range(nenvs)], mode) 171 | 172 | -------------------------------------------------------------------------------- /robustified/gen_demo/atari_demo/utils.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import sys 3 | import os 4 | 5 | def save_as_pickled_object(obj, filepath): 6 | """ 7 | This is a defensive way to write pickle.write, allowing for very large files on all platforms 8 | """ 9 | max_bytes = 2**31 - 1 10 | bytes_out = pickle.dumps(obj) 11 | n_bytes 
= sys.getsizeof(bytes_out) 12 | with open(filepath, 'wb') as f_out: 13 | for idx in range(0, n_bytes, max_bytes): 14 | f_out.write(bytes_out[idx:idx+max_bytes]) 15 | 16 | 17 | def load_as_pickled_object(filepath): 18 | """ 19 | This is a defensive way to write pickle.load, allowing for very large files on all platforms 20 | """ 21 | max_bytes = 2**31 - 1 22 | try: 23 | input_size = os.path.getsize(filepath) 24 | bytes_in = bytearray(0) 25 | with open(filepath, 'rb') as f_in: 26 | for _ in range(0, input_size, max_bytes): 27 | bytes_in += f_in.read(max_bytes) 28 | obj = pickle.loads(bytes_in) 29 | except: 30 | return None 31 | return obj 32 | 33 | -------------------------------------------------------------------------------- /robustified/gen_demo/atari_demo/wrappers.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import gym 3 | from gym import spaces 4 | 5 | class AtariDemo(gym.Wrapper): 6 | """ 7 | Records actions taken, creates checkpoints, allows time travel, restoring and saving of states 8 | """ 9 | 10 | def __init__(self, env, disable_time_travel=False): 11 | super(AtariDemo, self).__init__(env) 12 | self.action_space = spaces.Discrete(len(env.unwrapped._action_set)+1) # add "time travel" action 13 | self.save_every_k = 100 14 | self.max_time_travel_steps = 10000 15 | self.disable_time_travel = disable_time_travel 16 | 17 | def step(self, action): 18 | if action >= len(self.env.unwrapped._action_set): 19 | if self.disable_time_travel: 20 | obs, reward, done, info = self.env.step(0) 21 | else: 22 | obs, reward, done, info = self.time_travel() 23 | 24 | else: 25 | if self.steps_in_the_past > 0: 26 | self.restore_past_state() 27 | 28 | if len(self.done)>0 and self.done[-1]: 29 | obs = self.obs[-1] 30 | reward = 0 31 | done = True 32 | info = None 33 | 34 | else: 35 | self.lives.append(self.env.unwrapped.ale.lives()) 36 | 37 | obs, reward, done, info = self.env.step(action) 38 | 39 | self.actions.append(action) 40 | self.obs.append(obs) 41 | self.rewards.append(reward) 42 | self.done.append(done) 43 | self.info.append(info) 44 | 45 | # periodic checkpoint saving 46 | if not done: 47 | if (len(self.checkpoint_action_nr)>0 and len(self.actions) >= self.checkpoint_action_nr[-1] + self.save_every_k) \ 48 | or (len(self.checkpoint_action_nr)==0 and len(self.actions) >= self.save_every_k): 49 | self.save_checkpoint() 50 | 51 | return obs, reward, done, info 52 | 53 | def reset(self): 54 | obs = self.env.reset() 55 | self.actions = [] 56 | self.lives = [] 57 | self.checkpoints = [] 58 | self.checkpoint_action_nr = [] 59 | self.obs = [obs] 60 | self.rewards = [] 61 | self.done = [False] 62 | self.info = [None] 63 | self.steps_in_the_past = 0 64 | return obs 65 | 66 | def time_travel(self): 67 | if len(self.obs) > 1: 68 | reward = self.rewards.pop() 69 | self.obs.pop() 70 | self.done.pop() 71 | self.info.pop() 72 | self.lives.pop() 73 | obs = self.obs[-1] 74 | done = self.done[-1] 75 | info = self.info[-1] 76 | self.steps_in_the_past += 1 77 | 78 | else: # reached time travel limit 79 | reward = 0 80 | obs = self.obs[0] 81 | done = self.done[0] 82 | info = self.info[0] 83 | 84 | # rewards are differences in subsequent state values, and so should get reversed sign when going backward in time 85 | reward = -reward 86 | 87 | return obs, reward, done, info 88 | 89 | def save_to_file(self, file_name): 90 | dat = {'actions': self.actions, 'checkpoints': self.checkpoints, 'checkpoint_action_nr': self.checkpoint_action_nr, 91 | 'rewards': 
self.rewards, 'lives': self.lives} 92 | with open(file_name, "wb") as f: 93 | pickle.dump(dat, f) 94 | 95 | def load_from_file(self, file_name): 96 | self.reset() 97 | with open(file_name, "rb") as f: 98 | dat = pickle.load(f) 99 | self.actions = dat['actions'] 100 | self.checkpoints = dat['checkpoints'] 101 | self.checkpoint_action_nr = dat['checkpoint_action_nr'] 102 | self.rewards = dat['rewards'] 103 | self.lives = dat['lives'] 104 | self.load_state_and_walk_forward() 105 | 106 | def save_checkpoint(self): 107 | chk_pnt = self.env.unwrapped.clone_state() 108 | self.checkpoints.append(chk_pnt) 109 | self.checkpoint_action_nr.append(len(self.actions)) 110 | 111 | def restore_past_state(self): 112 | self.actions = self.actions[:-self.steps_in_the_past] 113 | while len(self.checkpoints)>0 and self.checkpoint_action_nr[-1]>len(self.actions): 114 | self.checkpoints.pop() 115 | self.checkpoint_action_nr.pop() 116 | self.load_state_and_walk_forward() 117 | self.steps_in_the_past = 0 118 | 119 | def load_state_and_walk_forward(self): 120 | if len(self.checkpoints)==0: 121 | self.env.reset() 122 | time_step = 0 123 | else: 124 | self.env.unwrapped.restore_state(self.checkpoints[-1]) 125 | time_step = self.checkpoint_action_nr[-1] 126 | 127 | for a in self.actions[time_step:]: 128 | action = self.env.unwrapped._action_set[a] 129 | self.env.unwrapped.ale.act(action) 130 | -------------------------------------------------------------------------------- /robustified/gen_demo_atari.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # Copyright (c) 2020 Uber Technologies, Inc. 4 | 5 | # Licensed under the Uber Non-Commercial License (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at the root directory of this project. 8 | 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | 13 | 14 | source=$1 15 | dest=${2:-`echo $source | sed -E 's/\/*$//'`_demo} 16 | game=${3:-MontezumaRevenge} 17 | 18 | python gen_demo/new_gen_demo.py --source $source --destination=$dest --game=$game --select_done --n_demos=1 --compress=bz2 --min_compute_steps=0 $4 19 | 20 | -------------------------------------------------------------------------------- /robustified/gen_demo_fetch.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # Copyright (c) 2020 Uber Technologies, Inc. 4 | 5 | # Licensed under the Uber Non-Commercial License (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at the root directory of this project. 8 | 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | 13 | 14 | source=$1 15 | dest=${2:-`echo $source | sed -E 's/\/*$//'`_demo} 16 | game=${3:-1000} 17 | 18 | python gen_demo/new_gen_demo.py --fetch_target_location=$game --source=$source --destination=$dest --fetch_type=boxes_1 --fetch_nsubsteps=80 --fetch_total_timestep=0.08 --game=fetch --render --n_demos=10 --select_reward --compress=bz2 --min_compute_steps=0 19 | 20 | -------------------------------------------------------------------------------- /robustified/goexplore_py/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | # Copyright (c) 2020 Uber Technologies, Inc. 
3 | 4 | # Licensed under the Uber Non-Commercial License (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at the root directory of this project. 7 | 8 | # See the License for the specific language governing permissions and 9 | # limitations under the License. 10 | 11 | 12 | -------------------------------------------------------------------------------- /robustified/goexplore_py/basics.py: -------------------------------------------------------------------------------- 1 | 2 | # Copyright (c) 2020 Uber Technologies, Inc. 3 | 4 | # Licensed under the Uber Non-Commercial License (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at the root directory of this project. 7 | 8 | # See the License for the specific language governing permissions and 9 | # limitations under the License. 10 | 11 | 12 | import itertools 13 | import gzip as gz 14 | import bz2 15 | import glob 16 | from pathlib import Path 17 | import time 18 | import shutil 19 | import copy 20 | import gc 21 | import uuid 22 | import hashlib 23 | import multiprocessing 24 | import loky 25 | import os 26 | import random 27 | import collections 28 | from collections import Counter, defaultdict, namedtuple 29 | import sys 30 | import heapq 31 | from pathlib import Path 32 | import json 33 | import typing 34 | import functools 35 | import warnings as _warnings 36 | import argparse 37 | import pickle 38 | 39 | def fastdump(data, file): 40 | pickler = pickle.Pickler(file) 41 | pickler.fast = True 42 | pickler.dump(data) 43 | 44 | 45 | import enum 46 | from enum import Enum, IntEnum 47 | from contextlib import contextmanager 48 | 49 | try: 50 | from dataclasses import dataclass, field as datafield 51 | def copyfield(data): 52 | return datafield(default_factory=lambda: copy.deepcopy(data)) 53 | except Exception: 54 | _warnings.warn('dataclasses not found. To get it, use Python 3.7 or pip install dataclasses') 55 | 56 | infinity = float('inf') 57 | 58 | def notebook_max_width(): 59 | from IPython.core.display import display, HTML 60 | display(HTML("")) 61 | 62 | 63 | class memoized(object): 64 | '''Decorator. Caches a function's return value each time it is called. 65 | If called later with the same arguments, the cached value is returned 66 | (not reevaluated). 67 | ''' 68 | def __init__(self, func): 69 | self.func = func 70 | self.cache = {} 71 | def __call__(self, *args): 72 | if not isinstance(args, collections.Hashable): 73 | # uncacheable. a list, for instance. 74 | # better to not cache than blow up. 75 | return self.func(*args) 76 | if args in self.cache: 77 | return self.cache[args] 78 | else: 79 | value = self.func(*args) 80 | self.cache[args] = value 81 | return value 82 | def __repr__(self): 83 | '''Return the function's docstring.''' 84 | return self.func.__doc__ 85 | def __get__(self, obj, objtype): 86 | '''Support instance methods.''' 87 | return functools.partial(self.__call__, obj) 88 | -------------------------------------------------------------------------------- /robustified/goexplore_py/explorers.py: -------------------------------------------------------------------------------- 1 | 2 | # Copyright (c) 2020 Uber Technologies, Inc. 3 | 4 | # Licensed under the Uber Non-Commercial License (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at the root directory of this project. 
7 | 8 | # See the License for the specific language governing permissions and 9 | # limitations under the License. 10 | 11 | 12 | from .import_ai import * 13 | 14 | class RandomExplorer: 15 | def init_seed(self): 16 | pass 17 | 18 | def get_action(self, env): 19 | return random.randint(0, env.action_space.n - 1) 20 | 21 | def __repr__(self): 22 | return 'RandomExplorer()' 23 | 24 | 25 | class RepeatedRandomExplorer: 26 | def __init__(self, mean_repeat=10): 27 | self.mean_repeat = mean_repeat 28 | self.action = 0 29 | self.remaining = 0 30 | 31 | def init_seed(self): 32 | self.remaining = 0 33 | 34 | def get_action(self, env): 35 | if self.remaining <= 0: 36 | self.action = random.randint(0, env.action_space.n - 1) 37 | # Note, this is equivalent to selecting an action and then repeating it 38 | # with some probability. 39 | self.remaining = np.random.geometric(1 / self.mean_repeat) 40 | self.remaining -= 1 41 | return self.action 42 | 43 | def __repr__(self): 44 | return f'repeat-{self.mean_repeat}' 45 | 46 | 47 | class RepeatedRandomExplorerRobot: 48 | def __init__(self, mean_repeat=10): 49 | self.mean_repeat = mean_repeat 50 | self.action = 0 51 | self.remaining = 0 52 | 53 | def init_seed(self): 54 | self.remaining = 0 55 | 56 | def get_action(self, env): 57 | if self.remaining <= 0: 58 | self.action = env.action_space.sample() 59 | # Note, this is equivalent to selecting an action and then repeating it 60 | # with some probability. 61 | self.remaining = np.random.geometric(1 / self.mean_repeat) 62 | self.remaining -= 1 63 | return self.action 64 | 65 | def __repr__(self): 66 | return f'repeat-{self.mean_repeat}' 67 | 68 | 69 | class RandomDriftExplorerRobot: 70 | def __init__(self, sd): 71 | self.sd = sd 72 | 73 | def init_seed(self): 74 | pass 75 | 76 | def get_action(self, env): 77 | return env.prev_action + np.random.randn(env.prev_action.size) * self.sd 78 | 79 | def __repr__(self): 80 | return f'drift-{self.sd}' 81 | 82 | 83 | def actstr(act): 84 | return ' '.join([f'{e:01.2f}' for e in act]) 85 | 86 | 87 | class RandomDriftExplorerFetch: 88 | def __init__(self, sd): 89 | self.sd = sd 90 | 91 | def init_seed(self): 92 | pass 93 | 94 | def get_action(self, env): 95 | action = env.prev_action + np.random.randn(env.prev_action.size) * self.sd 96 | return action 97 | 98 | def __repr__(self): 99 | return f'drift-{self.sd}' 100 | 101 | 102 | class RepeatedRandomExplorerFetch: 103 | def __init__(self, mean_repeat=10): 104 | self.mean_repeat = mean_repeat 105 | self.action = 0 106 | self.remaining = 0 107 | 108 | def init_seed(self): 109 | self.remaining = 0 110 | 111 | def get_action(self, env): 112 | if self.remaining <= 0: 113 | self.action = env.sample_action(sd=2) 114 | # Note, this is equivalent to selecting an action and then repeating it 115 | # with some probability. 
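                # With p = 1/mean_repeat, np.random.geometric(p) draws the number of
                # consecutive steps the chosen action is held (support 1, 2, ...),
                # with mean mean_repeat; equivalently, after each step the same
                # action is kept with probability 1 - 1/mean_repeat and a fresh
                # random action is drawn otherwise.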
116 | self.remaining = np.random.geometric(1 / self.mean_repeat) 117 | self.remaining -= 1 118 | return self.action 119 | 120 | def __repr__(self): 121 | return f'repeat-{self.mean_repeat}' 122 | 123 | 124 | class DoNothingExplorer: 125 | def init_seed(self): 126 | pass 127 | 128 | def get_action(self, *args): 129 | return 0 130 | -------------------------------------------------------------------------------- /robustified/goexplore_py/fetch_xml/README.md: -------------------------------------------------------------------------------- 1 | # Fetch 2 | Mujoco Models for Fetch Robot 3 | 4 | # Environments 5 | 6 | fetch_pole.xml | fetch_maneuver.xml |teleOp_boxes.xml | teleOp_objects.xml 7 | :-------------------------:|:-------------------------:|:-------------------------:|:-------------------------: 8 | ![Alt text](gallery/pole.JPG?raw=false "Fetch Pole") | ![Alt text](gallery/maneuver.JPG?raw=false "fetch maneuver") | ![Alt text](gallery/boxes.JPG?raw=false "teleOp Boxes") | ![Alt text](gallery/objects.JPG?raw=false "teleOp objects") 9 | 10 | 11 | 12 | # TeleOp Video 13 | ([Click to play](https://www.youtube.com/watch?v=2qEf5TkFXpQ)) 14 | 15 | [![TeleOp](https://img.youtube.com/vi/2qEf5TkFXpQ/0.jpg)](https://www.youtube.com/watch?v=2qEf5TkFXpQ) 16 | -------------------------------------------------------------------------------- /robustified/goexplore_py/fetch_xml/box/asset.xml: -------------------------------------------------------------------------------- 1 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | -------------------------------------------------------------------------------- /robustified/goexplore_py/fetch_xml/box/chain.xml: -------------------------------------------------------------------------------- 1 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | 208 | 209 | 210 | 211 | 212 | 213 | 214 | 215 | 216 | 217 | 218 | 219 | 220 | 221 | 222 | 223 | 224 | 225 | 226 | 227 | 228 | 229 | 230 | 231 | 232 | 233 | 234 | 235 | 236 | 237 | 238 | 239 | 240 | -------------------------------------------------------------------------------- /robustified/goexplore_py/fetch_xml/door/asset.xml: -------------------------------------------------------------------------------- 1 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | -------------------------------------------------------------------------------- /robustified/goexplore_py/fetch_xml/door/chain0.xml: -------------------------------------------------------------------------------- 1 | 15 
| 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 50 | 51 | -------------------------------------------------------------------------------- /robustified/goexplore_py/fetch_xml/door/chain1.xml: -------------------------------------------------------------------------------- 1 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 48 | 49 | -------------------------------------------------------------------------------- /robustified/goexplore_py/fetch_xml/fetch_maneuver.xml: -------------------------------------------------------------------------------- 1 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | -------------------------------------------------------------------------------- /robustified/goexplore_py/fetch_xml/fetch_pole.xml: -------------------------------------------------------------------------------- 1 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | -------------------------------------------------------------------------------- /robustified/goexplore_py/fetch_xml/gallery/boxes.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/go-explore/702fb9c7a9aeecf2872d07ced236c730d5536a8f/robustified/goexplore_py/fetch_xml/gallery/boxes.JPG -------------------------------------------------------------------------------- /robustified/goexplore_py/fetch_xml/gallery/maneuver.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/go-explore/702fb9c7a9aeecf2872d07ced236c730d5536a8f/robustified/goexplore_py/fetch_xml/gallery/maneuver.JPG -------------------------------------------------------------------------------- /robustified/goexplore_py/fetch_xml/gallery/objects.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/go-explore/702fb9c7a9aeecf2872d07ced236c730d5536a8f/robustified/goexplore_py/fetch_xml/gallery/objects.JPG -------------------------------------------------------------------------------- /robustified/goexplore_py/fetch_xml/gallery/pole.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/go-explore/702fb9c7a9aeecf2872d07ced236c730d5536a8f/robustified/goexplore_py/fetch_xml/gallery/pole.JPG -------------------------------------------------------------------------------- /robustified/goexplore_py/fetch_xml/robot/fetch/asset.xml: -------------------------------------------------------------------------------- 1 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 43 | 44 | 46 | 47 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | -------------------------------------------------------------------------------- 
/robustified/goexplore_py/fetch_xml/robot/fetch/base_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/go-explore/702fb9c7a9aeecf2872d07ced236c730d5536a8f/robustified/goexplore_py/fetch_xml/robot/fetch/base_link_collision.stl -------------------------------------------------------------------------------- /robustified/goexplore_py/fetch_xml/robot/fetch/bellows_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/go-explore/702fb9c7a9aeecf2872d07ced236c730d5536a8f/robustified/goexplore_py/fetch_xml/robot/fetch/bellows_link_collision.stl -------------------------------------------------------------------------------- /robustified/goexplore_py/fetch_xml/robot/fetch/elbow_flex_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/go-explore/702fb9c7a9aeecf2872d07ced236c730d5536a8f/robustified/goexplore_py/fetch_xml/robot/fetch/elbow_flex_link_collision.stl -------------------------------------------------------------------------------- /robustified/goexplore_py/fetch_xml/robot/fetch/estop_link.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/go-explore/702fb9c7a9aeecf2872d07ced236c730d5536a8f/robustified/goexplore_py/fetch_xml/robot/fetch/estop_link.stl -------------------------------------------------------------------------------- /robustified/goexplore_py/fetch_xml/robot/fetch/forearm_roll_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/go-explore/702fb9c7a9aeecf2872d07ced236c730d5536a8f/robustified/goexplore_py/fetch_xml/robot/fetch/forearm_roll_link_collision.stl -------------------------------------------------------------------------------- /robustified/goexplore_py/fetch_xml/robot/fetch/gripper_link.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/go-explore/702fb9c7a9aeecf2872d07ced236c730d5536a8f/robustified/goexplore_py/fetch_xml/robot/fetch/gripper_link.stl -------------------------------------------------------------------------------- /robustified/goexplore_py/fetch_xml/robot/fetch/head_pan_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/go-explore/702fb9c7a9aeecf2872d07ced236c730d5536a8f/robustified/goexplore_py/fetch_xml/robot/fetch/head_pan_link_collision.stl -------------------------------------------------------------------------------- /robustified/goexplore_py/fetch_xml/robot/fetch/head_tilt_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/go-explore/702fb9c7a9aeecf2872d07ced236c730d5536a8f/robustified/goexplore_py/fetch_xml/robot/fetch/head_tilt_link_collision.stl -------------------------------------------------------------------------------- /robustified/goexplore_py/fetch_xml/robot/fetch/l_gripper_finger_link.stl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/uber-research/go-explore/702fb9c7a9aeecf2872d07ced236c730d5536a8f/robustified/goexplore_py/fetch_xml/robot/fetch/l_gripper_finger_link.stl -------------------------------------------------------------------------------- /robustified/goexplore_py/fetch_xml/robot/fetch/l_wheel_link.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/go-explore/702fb9c7a9aeecf2872d07ced236c730d5536a8f/robustified/goexplore_py/fetch_xml/robot/fetch/l_wheel_link.stl -------------------------------------------------------------------------------- /robustified/goexplore_py/fetch_xml/robot/fetch/l_wheel_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/go-explore/702fb9c7a9aeecf2872d07ced236c730d5536a8f/robustified/goexplore_py/fetch_xml/robot/fetch/l_wheel_link_collision.stl -------------------------------------------------------------------------------- /robustified/goexplore_py/fetch_xml/robot/fetch/laser_link.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/go-explore/702fb9c7a9aeecf2872d07ced236c730d5536a8f/robustified/goexplore_py/fetch_xml/robot/fetch/laser_link.stl -------------------------------------------------------------------------------- /robustified/goexplore_py/fetch_xml/robot/fetch/main.xml: -------------------------------------------------------------------------------- 1 | 18 | --> 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | -------------------------------------------------------------------------------- /robustified/goexplore_py/fetch_xml/robot/fetch/r_gripper_finger_link.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/go-explore/702fb9c7a9aeecf2872d07ced236c730d5536a8f/robustified/goexplore_py/fetch_xml/robot/fetch/r_gripper_finger_link.stl -------------------------------------------------------------------------------- /robustified/goexplore_py/fetch_xml/robot/fetch/r_wheel_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/go-explore/702fb9c7a9aeecf2872d07ced236c730d5536a8f/robustified/goexplore_py/fetch_xml/robot/fetch/r_wheel_link_collision.stl -------------------------------------------------------------------------------- /robustified/goexplore_py/fetch_xml/robot/fetch/shoulder_lift_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/go-explore/702fb9c7a9aeecf2872d07ced236c730d5536a8f/robustified/goexplore_py/fetch_xml/robot/fetch/shoulder_lift_link_collision.stl -------------------------------------------------------------------------------- /robustified/goexplore_py/fetch_xml/robot/fetch/shoulder_pan_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/go-explore/702fb9c7a9aeecf2872d07ced236c730d5536a8f/robustified/goexplore_py/fetch_xml/robot/fetch/shoulder_pan_link_collision.stl -------------------------------------------------------------------------------- /robustified/goexplore_py/fetch_xml/robot/fetch/torso_fixed_link.stl: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/go-explore/702fb9c7a9aeecf2872d07ced236c730d5536a8f/robustified/goexplore_py/fetch_xml/robot/fetch/torso_fixed_link.stl -------------------------------------------------------------------------------- /robustified/goexplore_py/fetch_xml/robot/fetch/torso_lift_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/go-explore/702fb9c7a9aeecf2872d07ced236c730d5536a8f/robustified/goexplore_py/fetch_xml/robot/fetch/torso_lift_link_collision.stl -------------------------------------------------------------------------------- /robustified/goexplore_py/fetch_xml/robot/fetch/upperarm_roll_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/go-explore/702fb9c7a9aeecf2872d07ced236c730d5536a8f/robustified/goexplore_py/fetch_xml/robot/fetch/upperarm_roll_link_collision.stl -------------------------------------------------------------------------------- /robustified/goexplore_py/fetch_xml/robot/fetch/wrist_flex_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/go-explore/702fb9c7a9aeecf2872d07ced236c730d5536a8f/robustified/goexplore_py/fetch_xml/robot/fetch/wrist_flex_link_collision.stl -------------------------------------------------------------------------------- /robustified/goexplore_py/fetch_xml/robot/fetch/wrist_roll_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/go-explore/702fb9c7a9aeecf2872d07ced236c730d5536a8f/robustified/goexplore_py/fetch_xml/robot/fetch/wrist_roll_link_collision.stl -------------------------------------------------------------------------------- /robustified/goexplore_py/fetch_xml/shelf/chain.xml: -------------------------------------------------------------------------------- 1 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | -------------------------------------------------------------------------------- /robustified/goexplore_py/fetch_xml/table/asset.xml: -------------------------------------------------------------------------------- 1 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | -------------------------------------------------------------------------------- /robustified/goexplore_py/fetch_xml/table/chain.xml: -------------------------------------------------------------------------------- 1 | 13 | 14 | 16 | 17 | 18 | 19 | 20 | 21 | -------------------------------------------------------------------------------- /robustified/goexplore_py/fetch_xml/teleOp_boxes.xml: -------------------------------------------------------------------------------- 1 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 
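The fetch_xml directory holds the MuJoCo scenes for the Fetch tasks (box, door, shelf, table, and the teleOp_* scenes shown in the README gallery); teleOp_boxes_1.xml is presumably the scene selected by --fetch_type=boxes_1 in phase1_fetch.sh and gen_demo_fetch.sh. A minimal, hypothetical way to load such a scene directly, assuming mujoco_py is installed (it is not pinned in requirements.txt):

import mujoco_py  # assumption: mujoco_py is available in the environment

model = mujoco_py.load_model_from_path(
    'robustified/goexplore_py/fetch_xml/teleOp_boxes_1.xml')
sim = mujoco_py.MjSim(model)
for _ in range(10):
    sim.step()                  # advance the physics a few substeps
print(sim.data.qpos.shape)      # joint positions of the loaded scene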
-------------------------------------------------------------------------------- /robustified/goexplore_py/fetch_xml/teleOp_boxes_1.xml: -------------------------------------------------------------------------------- 1 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | -------------------------------------------------------------------------------- /robustified/goexplore_py/fetch_xml/teleOp_objects.xml: -------------------------------------------------------------------------------- 1 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | -------------------------------------------------------------------------------- /robustified/goexplore_py/fetch_xml/texture/marble.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/go-explore/702fb9c7a9aeecf2872d07ced236c730d5536a8f/robustified/goexplore_py/fetch_xml/texture/marble.png -------------------------------------------------------------------------------- /robustified/goexplore_py/fetch_xml/texture/small_placeholder_2d.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/go-explore/702fb9c7a9aeecf2872d07ced236c730d5536a8f/robustified/goexplore_py/fetch_xml/texture/small_placeholder_2d.png -------------------------------------------------------------------------------- /robustified/goexplore_py/fetch_xml/texture/small_placeholder_cube.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/go-explore/702fb9c7a9aeecf2872d07ced236c730d5536a8f/robustified/goexplore_py/fetch_xml/texture/small_placeholder_cube.png -------------------------------------------------------------------------------- /robustified/goexplore_py/fetch_xml/texture/wood.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/go-explore/702fb9c7a9aeecf2872d07ced236c730d5536a8f/robustified/goexplore_py/fetch_xml/texture/wood.png -------------------------------------------------------------------------------- /robustified/goexplore_py/generic_atari_env.py: -------------------------------------------------------------------------------- 1 | 2 | # Copyright (c) 2020 Uber Technologies, Inc. 3 | 4 | # Licensed under the Uber Non-Commercial License (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at the root directory of this project. 7 | 8 | # See the License for the specific language governing permissions and 9 | # limitations under the License. 10 | 11 | 12 | from .basics import * 13 | from .import_ai import * 14 | from . 
import montezuma_env 15 | from .utils import imdownscale 16 | 17 | def convert_state(state): 18 | if MyAtari.TARGET_SHAPE is None: 19 | return None 20 | import cv2 21 | state = cv2.cvtColor(state, cv2.COLOR_RGB2GRAY) 22 | if MyAtari.TARGET_SHAPE == (-1, -1): 23 | return RLEArray(state) 24 | return imdownscale(state, MyAtari.TARGET_SHAPE, MyAtari.MAX_PIX_VALUE) 25 | 26 | class AtariPosLevel: 27 | __slots__ = ['level', 'score', 'room', 'x', 'y', 'tuple'] 28 | 29 | def __init__(self, level=0, score=0, room=0, x=0, y=0): 30 | self.level = level 31 | self.score = score 32 | self.room = room 33 | self.x = x 34 | self.y = y 35 | 36 | self.set_tuple() 37 | 38 | def set_tuple(self): 39 | self.tuple = (self.level, self.score, self.room, self.x, self.y) 40 | 41 | def __hash__(self): 42 | return hash(self.tuple) 43 | 44 | def __eq__(self, other): 45 | if not isinstance(other, AtariPosLevel): 46 | return False 47 | return self.tuple == other.tuple 48 | 49 | def __getstate__(self): 50 | return self.tuple 51 | 52 | def __setstate__(self, d): 53 | self.level, self.score, self.room, self.x, self.y = d 54 | self.tuple = d 55 | 56 | def __repr__(self): 57 | return f'Level={self.level} Room={self.room} Objects={self.score} x={self.x} y={self.y}' 58 | 59 | def clip(a, m, M): 60 | if a < m: 61 | return m 62 | if a > M: 63 | return M 64 | return a 65 | 66 | 67 | class MyAtari: 68 | def __init__(self, name, x_repeat=2, end_on_death=False): 69 | self.name = name 70 | self.env = gym.make(f'{name}Deterministic-v4') 71 | self.unwrapped.seed(0) 72 | self.env.reset() 73 | self.state = [] 74 | self.x_repeat = x_repeat 75 | self.rooms = [] 76 | self.unprocessed_state = None 77 | self.end_on_death = end_on_death 78 | self.prev_lives = 0 79 | 80 | def __getattr__(self, e): 81 | return getattr(self.env, e) 82 | 83 | def reset(self) -> np.ndarray: 84 | self.env = gym.make(f'{self.name}Deterministic-v4') 85 | self.unwrapped.seed(0) 86 | self.unprocessed_state = self.env.reset() 87 | self.state = [convert_state(self.unprocessed_state)] 88 | return copy.copy(self.state) 89 | 90 | def get_restore(self): 91 | return ( 92 | self.unwrapped.clone_state(), 93 | copy.copy(self.state), 94 | self.env._elapsed_steps 95 | ) 96 | 97 | def restore(self, data): 98 | ( 99 | full_state, 100 | state, 101 | elapsed_steps 102 | ) = data 103 | self.state = copy.copy(state) 104 | self.env.reset() 105 | self.env._elapsed_steps = elapsed_steps 106 | self.env.unwrapped.restore_state(full_state) 107 | return copy.copy(self.state) 108 | 109 | def step(self, action) -> typing.Tuple[np.ndarray, float, bool, dict]: 110 | self.unprocessed_state, reward, done, lol = self.env.step(action) 111 | self.state.append(convert_state(self.unprocessed_state)) 112 | self.state.pop(0) 113 | 114 | cur_lives = self.env.unwrapped.ale.lives() 115 | if self.end_on_death and cur_lives < self.prev_lives: 116 | done = True 117 | self.prev_lives = cur_lives 118 | 119 | return copy.copy(self.state), reward, done, lol 120 | 121 | def get_pos(self): 122 | # NOTE: this only returns a dummy position 123 | return AtariPosLevel() 124 | 125 | def render_with_known(self, known_positions, resolution, show=True, filename=None, combine_val=max, 126 | get_val=lambda x: x.score, minmax=None): 127 | pass 128 | -------------------------------------------------------------------------------- /robustified/goexplore_py/generic_goal_conditioned_env.py: -------------------------------------------------------------------------------- 1 | 2 | # Copyright (c) 2020 Uber Technologies, Inc. 
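A hypothetical usage sketch of the MyAtari wrapper defined above. main.py normally sets the class-level TARGET_SHAPE / MAX_PIX_VALUE from its flags before the wrapper is used, so the sketch sets them explicitly; the (11, 8) target with 8 intensity levels mirrors the downscaled cell representation described in the Go-Explore paper. get_restore() bundles the cloned ALE state, the converted frame list, and the elapsed-step counter, and restore() returns the emulator to exactly that point.

from goexplore_py.generic_atari_env import MyAtari

MyAtari.TARGET_SHAPE = (11, 8)   # (width, height) of the downscaled cell frame
MyAtari.MAX_PIX_VALUE = 8        # coarse intensity range kept after downscaling

env = MyAtari('MontezumaRevenge')
env.reset()
snapshot = env.get_restore()               # (cloned ALE state, state list, elapsed steps)
for _ in range(20):
    obs, reward, done, info = env.step(0)  # take a few NOOP steps
env.restore(snapshot)                      # back to the exact saved emulator state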
3 | 4 | # Licensed under the Uber Non-Commercial License (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at the root directory of this project. 7 | 8 | # See the License for the specific language governing permissions and 9 | # limitations under the License. 10 | 11 | 12 | from .import_ai import * 13 | import types 14 | 15 | 16 | def make_robot_env(name, will_render=False): 17 | env = gym.make(name) 18 | 19 | # Fix rendering to 1/ hide the overlay, 2/ read the proper pixels in rgb_array 20 | # mode and 3/ prevent rendering if will_render is not announced (necessary because 21 | # when will_render is announced, we proactively create a viewer as soon as the 22 | # env is created, because creating it later causes inaccuracies). 23 | def render(self, mode='human'): 24 | assert will_render, 'Rendering in an environment with will_render=False' 25 | self._render_callback() 26 | self._get_viewer()._hide_overlay = True 27 | if mode == 'rgb_array': 28 | self._get_viewer().render() 29 | import glfw 30 | width, height = glfw.get_window_size(self.viewer.window) 31 | data = self._get_viewer().read_pixels(width, height, depth=False) 32 | # original image is upside-down, so flip it 33 | return data[::-1, :, :] 34 | elif mode == 'human': 35 | self._get_viewer().render() 36 | 37 | def get_full_state(self): 38 | pass 39 | 40 | def set_full_state(self, state): 41 | pass 42 | 43 | env.unwrapped.render = types.MethodType(render, env.unwrapped) 44 | env.unwrapped.get_full_state = types.MethodType(get_full_state, env.unwrapped) 45 | env.unwrapped.set_full_state = types.MethodType(set_full_state, env.unwrapped) 46 | if will_render: 47 | # Pre-cache the viewer because creating it while the environment is running 48 | # sometimes causes errors 49 | env.unwrapped._get_viewer() 50 | 51 | if 'Fetch' in name: 52 | # The way _render_callback is implemented in Fetch environments causes issues. 53 | # This monkey patch fixes them. 54 | def _render_callback(self): 55 | # Visualize target. 
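            # (The patched callback below only repositions the target site marker;
            # unlike the stock gym Fetch implementation it appears to omit the
            # trailing self.sim.forward() call, so rendering cannot advance or
            # perturb the simulation state, which is presumably the "issue"
            # referred to above.)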
56 | sites_offset = (self.sim.data.site_xpos - self.sim.model.site_pos).copy() 57 | site_id = self.sim.model.site_name2id('target0') 58 | self.sim.model.site_pos[site_id] = self.goal - sites_offset[0] 59 | 60 | env.unwrapped._render_callback = types.MethodType(_render_callback, env.unwrapped) 61 | 62 | return env 63 | 64 | class DomainConditionedPosLevel: 65 | __slots__ = ['level', 'score', 'room', 'x', 'y', 'tuple'] 66 | 67 | def __init__(self, level=0, score=0, room=0, x=0, y=0): 68 | self.level = level 69 | self.score = score 70 | self.room = room 71 | self.x = x 72 | self.y = y 73 | 74 | self.set_tuple() 75 | 76 | def set_tuple(self): 77 | self.tuple = (self.level, self.score, self.room, self.x, self.y) 78 | 79 | def __hash__(self): 80 | return hash(self.tuple) 81 | 82 | def __eq__(self, other): 83 | if not isinstance(other, DomainConditionedPosLevel): 84 | return False 85 | return self.tuple == other.tuple 86 | 87 | def __getstate__(self): 88 | return self.tuple 89 | 90 | def __setstate__(self, d): 91 | self.level, self.score, self.room, self.x, self.y = d 92 | self.tuple = d 93 | 94 | def __repr__(self): 95 | return f'Level={self.level} Room={self.room} Objects={self.score} x={self.x} y={self.y}' 96 | 97 | 98 | class MyRobot: 99 | TARGET_SHAPE = 0 100 | MAX_PIX_VALUE = 0 101 | 102 | def __init__(self, env_name, interval_size=0.1, seed_low=0, seed_high=0): 103 | self.env_name = env_name 104 | self.env = make_robot_env(env_name) 105 | self.prev_action = np.zeros_like(self.env.action_space.sample()) 106 | self.interval_size = interval_size 107 | self.state = None 108 | self.actual_state = None 109 | self.rooms = [] 110 | self.trajectory = [] 111 | 112 | self.seed_low = seed_low 113 | self.seed_high = seed_high 114 | self.seed = None 115 | 116 | self.cur_achieved_goal = None 117 | self.achieved_has_moved = False 118 | self.score_so_far = 0 119 | 120 | self.follow_grip_until_moved = ('FetchPickAndPlace' in env_name and False) 121 | 122 | self.reset() 123 | 124 | def __getattr__(self, e): 125 | assert self.env is not self 126 | return getattr(self.env, e) 127 | 128 | def pos_from_state(self, seed, state): 129 | if self.follow_grip_until_moved: 130 | pos = state['achieved_goal'] if self.achieved_has_moved else state['observation'][:3] 131 | return np.array([seed, self.achieved_has_moved] + list(pos / self.interval_size), dtype=np.int32) 132 | return np.array([seed, self.score_so_far] + list((state['achieved_goal'] / self.interval_size).astype(np.int32)), dtype=np.int32) 133 | 134 | def reset(self) -> np.ndarray: 135 | self.seed = None 136 | self.trajectory = None 137 | self.actual_state = None 138 | self.cur_achieved_goal = None 139 | self.achieved_has_moved = False 140 | self.score_so_far = 0 141 | self.state = [self.pos_from_state(-1, {'achieved_goal': np.array([]), 'observation': np.array([])})] 142 | return copy.copy(self.state) 143 | 144 | def get_restore(self): 145 | return copy.deepcopy(( 146 | None, 147 | self.env._elapsed_steps, 148 | self.interval_size, 149 | self.cur_achieved_goal, 150 | self.achieved_has_moved, 151 | self.score_so_far, 152 | self.state, 153 | self.actual_state, 154 | self.trajectory, 155 | self.seed, 156 | )) 157 | 158 | def restore(self, data): 159 | ( 160 | simstate, 161 | self.env._elapsed_steps, 162 | self.interval_size, 163 | self.cur_achieved_goal, 164 | self.achieved_has_moved, 165 | self.score_so_far, 166 | state, 167 | actual_state, 168 | trajectory, 169 | seed, 170 | ) = copy.deepcopy(data) 171 | self.reset() 172 | self.seed = seed 173 | for a in 
trajectory: 174 | self.step(a) 175 | assert np.allclose(self.actual_state['achieved_goal'], actual_state['achieved_goal']) 176 | return copy.copy(self.state) 177 | 178 | def step(self, action): 179 | self.prev_action = copy.deepcopy(self.prev_action) 180 | self.prev_action[:] = action 181 | if self.trajectory is None: 182 | if self.seed is None: 183 | self.seed = random.randint(self.seed_low, self.seed_high) 184 | self.env.unwrapped.sim.reset() 185 | self.env.seed(self.seed) 186 | self.actual_state = self.env.reset() 187 | self.trajectory = [] 188 | self.state = [self.pos_from_state(self.seed, self.actual_state)] 189 | 190 | self.trajectory.append(copy.copy(self.prev_action)) 191 | action = np.tanh(self.prev_action) 192 | self.actual_state, reward, done, lol = self.env.step(action) 193 | reward = int(reward) + 1 194 | self.score_so_far += reward 195 | self.state = [self.pos_from_state(self.seed, self.actual_state)] 196 | 197 | if not self.achieved_has_moved and self.cur_achieved_goal is not None and not np.allclose(self.cur_achieved_goal, self.actual_state['achieved_goal']): 198 | self.achieved_has_moved = True 199 | self.cur_achieved_goal = self.actual_state['achieved_goal'] 200 | 201 | return copy.copy(self.state), reward, done, lol 202 | 203 | def get_pos(self): 204 | return DomainConditionedPosLevel() 205 | 206 | def render_with_known(self, known_positions, resolution, show=True, filename=None, combine_val=max, 207 | get_val=lambda x: x.score, minmax=None): 208 | pass -------------------------------------------------------------------------------- /robustified/goexplore_py/import_ai.py: -------------------------------------------------------------------------------- 1 | 2 | # Copyright (c) 2020 Uber Technologies, Inc. 3 | 4 | # Licensed under the Uber Non-Commercial License (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at the root directory of this project. 7 | 8 | # See the License for the specific language governing permissions and 9 | # limitations under the License. 10 | 11 | 12 | from __future__ import print_function 13 | 14 | def is_notebook(): 15 | try: 16 | from IPython import get_ipython as _get_ipython 17 | if 'IPKernelApp' not in _get_ipython().config: # pragma: no cover 18 | raise ImportError("console") 19 | except: 20 | return False 21 | return True 22 | 23 | if not is_notebook(): 24 | import matplotlib 25 | matplotlib.use('Agg') 26 | 27 | from .basics import * 28 | 29 | import warnings as _warnings 30 | # Known to be benign: https://github.com/ContinuumIO/anaconda-issues/issues/6678#issuecomment-337279157 31 | _warnings.filterwarnings('ignore', 'numpy.dtype size changed, may indicate binary incompatibility. 
Expected 96, got 88') 32 | 33 | try: 34 | import cv2 35 | except ModuleNotFoundError: 36 | _warnings.warn('cv2 not found') 37 | 38 | try: 39 | import gym 40 | except ModuleNotFoundError: 41 | _warnings.warn('gym not found') 42 | 43 | 44 | try: 45 | if is_notebook(): 46 | from tqdm import tqdm_notebook as tqdm 47 | from tqdm import tnrange as trange 48 | elif sys.stderr.isatty() and False: 49 | from tqdm import tqdm, trange 50 | else: 51 | class tqdm: 52 | def __init__(self, iterator=None, desc=None, smoothing=0, total=None): 53 | self.iterator = iterator 54 | self.desc = desc 55 | self.smoothing = smoothing 56 | self.total = total 57 | if self.total is None: 58 | try: 59 | self.total = len(iterator) 60 | except Exception: 61 | pass 62 | self.n = 0 63 | self.last_printed = 0 64 | self.start_time = time.time() 65 | 66 | def __enter__(self): 67 | self.start_time = time.time() 68 | return self 69 | 70 | def __exit__(self, exc_type, exc_val, exc_tb): 71 | self.refresh(force_print=True, done=True) 72 | 73 | def __iter__(self): 74 | for e in self.iterator: 75 | yield e 76 | self.update(1) 77 | self.refresh() 78 | 79 | def update(self, d): 80 | if d != 0: 81 | self.n += d 82 | self.refresh() 83 | 84 | def refresh(self, force_print=False, done=False): 85 | cur_time = time.time() 86 | if cur_time - self.last_printed < 10 and not force_print: 87 | return 88 | self.last_printed = cur_time 89 | self.write(f'{self.get_desc_str():16}[{self.get_prog_str():26}{self.get_speed_str(cur_time):13}]' + (' DONE' if done else '')) 90 | 91 | def get_desc_str(self): 92 | if self.desc is None: 93 | return '' 94 | return f'{self.desc}: ' 95 | 96 | def get_prog_str(self): 97 | total_str = '' 98 | if isinstance(self.n, int): 99 | if self.total is not None: 100 | total_substr = f'{int(self.total)}' 101 | total_str = f'{self.n / self.total * 100:2.0f}% {self.n:{len(total_substr)}}it/{total_substr}' 102 | else: 103 | total_str = str(self.n) + 'it' 104 | else: 105 | if self.total is not None: 106 | total_substr = f'{self.total:.1f}' 107 | total_str = f'{self.n / self.total * 100:2.0f}% {self.n:{len(total_substr)}.1f}it/{total_substr}' 108 | else: 109 | total_str = f'{self.n:.1f}it' 110 | return total_str 111 | 112 | def get_speed_str(self, cur_time): 113 | if cur_time <= self.start_time: 114 | return '' 115 | speed = self.n / (cur_time - self.start_time) 116 | if speed > 1: 117 | return f' {speed:.1f}it/s' 118 | if speed < 0.000000000001: 119 | return '' 120 | return f' {1/speed:.1f}s/it' 121 | 122 | @classmethod 123 | def write(cls, str): 124 | print(str, file=sys.stderr) 125 | sys.stderr.flush() 126 | 127 | except ModuleNotFoundError: 128 | _warnings.warn('tqdm not found') 129 | 130 | import numpy as np 131 | class RLEArray: 132 | def __init__(self, array, encoded_array=None, compression=1): 133 | import cv2 134 | if array is None: 135 | self.array = encoded_array 136 | else: 137 | assert not isinstance(array, RLEArray) 138 | # Note: 7 seems to be a good tradeoff between size and speed 139 | self.array = cv2.imencode('.png', array, [cv2.IMWRITE_PNG_COMPRESSION, compression])[1].flatten().tobytes() 140 | 141 | def to_np(self): 142 | return cv2.imdecode(np.frombuffer(self.array, np.uint8), 0) 143 | 144 | def tobytes(self): 145 | return self.array 146 | 147 | @classmethod 148 | def frombytes(cls, byt, dtype=np.uint8): 149 | return cls(None, np.frombuffer(byt, dtype=dtype)) 150 | 151 | 152 | import logging 153 | class IgnoreNoHandles(logging.Filter): 154 | def filter(self, record): 155 | if record.getMessage() == 'No handles 
with labels found to put in legend.': 156 | return 0 157 | return 1 158 | _plt_logger = logging.getLogger('matplotlib.legend') 159 | _plt_logger.addFilter(IgnoreNoHandles()) 160 | 161 | 162 | import matplotlib.pyplot as plt 163 | import numpy as np 164 | 165 | def show_img(im, figsize=None, ax=None, grid=False): 166 | if not ax: fig,ax = plt.subplots(figsize=figsize) 167 | ax.imshow(im) 168 | ax.set_xticks(np.linspace(0, 224, 8)) 169 | ax.set_yticks(np.linspace(0, 224, 8)) 170 | if grid: 171 | ax.grid() 172 | ax.set_yticklabels([]) 173 | ax.set_xticklabels([]) 174 | return ax 175 | 176 | def draw_outline(o, lw): 177 | o.set_path_effects([patheffects.Stroke( 178 | linewidth=lw, foreground='black'), patheffects.Normal()]) 179 | 180 | def draw_rect(ax, b, color='white'): 181 | patch = ax.add_patch(patches.Rectangle(b[:2], *b[-2:], fill=False, edgecolor=color, lw=2)) 182 | draw_outline(patch, 4) 183 | 184 | def draw_text(ax, xy, txt, sz=14, color='white'): 185 | text = ax.text(*xy, txt, 186 | verticalalignment='top', color=color, fontsize=sz, weight='bold') 187 | draw_outline(text, 1) 188 | 189 | 190 | 191 | class CircularMemory: 192 | def __init__(self, size): 193 | self.size = size 194 | self.mem = [] 195 | self.start_idx = 0 196 | 197 | def add(self, entry): 198 | if len(self.mem) < self.size: 199 | self.mem.append(entry) 200 | else: 201 | self.mem[self.start_idx] = entry 202 | self.start_idx = (self.start_idx + 1) % self.size 203 | 204 | def sample(self, n): 205 | return random.sample(self.mem, n) 206 | 207 | def __len__(self): 208 | return len(self.mem) 209 | 210 | def __getitem__(self, i): 211 | assert i < len(self) 212 | return self.mem[(self.start_idx + i) % self.size] 213 | 214 | -------------------------------------------------------------------------------- /robustified/goexplore_py/notebook_utils.py: -------------------------------------------------------------------------------- 1 | 2 | # Copyright (c) 2020 Uber Technologies, Inc. 3 | 4 | # Licensed under the Uber Non-Commercial License (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at the root directory of this project. 7 | 8 | # See the License for the specific language governing permissions and 9 | # limitations under the License. 
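A small, hypothetical round-trip check for the two storage helpers defined in import_ai.py above (assuming the robustified requirements, including opencv-python, are installed): RLEArray keeps a grayscale frame as PNG-encoded bytes, and CircularMemory is a fixed-size ring buffer that overwrites its oldest entries.

import numpy as np
from goexplore_py.import_ai import RLEArray, CircularMemory

frame = (np.random.rand(8, 11) * 255).astype(np.uint8)
packed = RLEArray(frame)                         # stored as PNG-encoded bytes
assert (packed.to_np() == frame).all()           # lossless round trip
restored = RLEArray.frombytes(packed.tobytes())  # rebuild the wrapper from raw bytes

mem = CircularMemory(size=3)
for i in range(5):
    mem.add(i)                                   # 0 and 1 get overwritten
print(len(mem), mem[0], mem[2])                  # prints 3 2 4: oldest kept entry, newest entry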
10 | 11 | 12 | from goexplore_py.import_ai import * 13 | 14 | sns.set_style('whitegrid') 15 | 16 | 17 | def p(x, y, treatment, t_name, treatment_data): 18 | x_to_y = defaultdict(list) 19 | for r in treatment_data[treatment]: 20 | for k, v in zip(r[x], r[y]): 21 | k = (k // 50000) * 50000 22 | x_to_y[k].append(v) 23 | 24 | all_x = [] 25 | all_y = [] 26 | for k, v in x_to_y.items(): 27 | for e in v: 28 | all_x.append(k) 29 | all_y.append(e) 30 | sns.lineplot(all_x, all_y, label=t_name) 31 | 32 | 33 | def plot(xs, ys, treatments, treatment_data, keys_of_interest, keys_of_interest_dict, pretty_names): 34 | for x, xname in xs: 35 | for y, yname in ys: 36 | plt.figure(figsize=(8, 5)) 37 | for treatment in treatments: 38 | key = create_key(treatment, keys_of_interest, keys_of_interest_dict) 39 | name = create_name(treatment, pretty_names) 40 | p(x, y, key, name, treatment_data) 41 | plt.title(yname + ' over ' + xname) 42 | plt.xlabel(xname) 43 | plt.ylabel(yname) 44 | plt.legend() 45 | plt.savefig(yname + ' over ' + xname + '.png') 46 | plt.show() 47 | 48 | 49 | def gather_treatments(results_folder, keys_of_interest): 50 | treatment_dict = {} 51 | for folder in tqdm(sorted(glob.glob(results_folder + '/*'))): 52 | meta_data = json.load(open(folder + '/kwargs.json')) 53 | meta_key = [] 54 | for key in keys_of_interest: 55 | meta_key.append(meta_data[key]) 56 | meta_key_tuple = tuple(meta_key) 57 | if meta_key_tuple not in treatment_dict: 58 | treatment_dict[meta_key_tuple] = [] 59 | treatment_dict[meta_key_tuple].append(folder) 60 | return treatment_dict 61 | 62 | 63 | def collect_data(treatments, treatment_dict, keys_of_interest, keys_of_interest_dict, pretty_names): 64 | treatment_data = {} 65 | for treatment_id in treatments: 66 | treatment_key = create_key(treatment_id, keys_of_interest, keys_of_interest_dict) 67 | all_res = [] 68 | print("Loading treatment:", create_name(treatment_id, pretty_names)) 69 | for folder in tqdm(treatment_dict[treatment_key]): 70 | compute_frames = [] 71 | real_frames = [] 72 | n_found = [] 73 | max_score = [] 74 | n_rooms = [] 75 | n_objects = [] 76 | for f in sorted(glob.glob('%s/*_set.7z' % folder)): 77 | data = pickle.load(lzma.open(f, 'rb')) 78 | real, compute = f.split('/')[-1].split('_set.')[0].split('_') 79 | compute_frames.append(int(compute)) 80 | real_frames.append(int(real)) 81 | n_found.append(len(data)) 82 | max_score.append(max(data[e] for e in data)) 83 | n_rooms.append(len(set((e.level, e.room) for e in data))) 84 | all_res.append({'compute': compute_frames, 'real': real_frames, 'found': n_found, 'score': max_score, 85 | 'rooms': n_rooms}) 86 | treatment_data[treatment_key] = all_res 87 | return treatment_data 88 | 89 | 90 | def create_key(param_dict, keys_of_interest, keys_of_interest_dict): 91 | key_proto = [] 92 | for key in keys_of_interest: 93 | if key in param_dict: 94 | key_proto.append(param_dict[key]) 95 | else: 96 | key_proto.append(keys_of_interest_dict[key]) 97 | return tuple(key_proto) 98 | 99 | 100 | def create_name(param_dict, pretty_names): 101 | name = "" 102 | for i, key in enumerate(param_dict.keys()): 103 | name += pretty_names[key] + " " + str(param_dict[key]) 104 | if i != len(param_dict) - 1: 105 | name += " " 106 | return name 107 | -------------------------------------------------------------------------------- /robustified/goexplore_py/utils.py: -------------------------------------------------------------------------------- 1 | 2 | # Copyright (c) 2020 Uber Technologies, Inc. 
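A hypothetical end-to-end use of the notebook helpers above for plotting phase 1 progress from a results folder. The key names and values passed here are illustrative, not taken from the repo's configs; each run directory is expected to contain kwargs.json plus the periodic *_set.7z archive snapshots that collect_data reads.

from goexplore_py.notebook_utils import gather_treatments, collect_data, plot

# Illustrative keys; any flags recorded in each run's kwargs.json could be used.
keys_of_interest = ['seen_weight', 'high_score_weight']
keys_of_interest_dict = {'seen_weight': 1.0, 'high_score_weight': 0.0}
pretty_names = {'seen_weight': 'seen weight', 'high_score_weight': 'score weight'}
treatments = [{'seen_weight': 1.0, 'high_score_weight': 0.0}]

treatment_dict = gather_treatments('results', keys_of_interest)
treatment_data = collect_data(treatments, treatment_dict, keys_of_interest,
                              keys_of_interest_dict, pretty_names)
plot(xs=[('compute', 'Game frames')],
     ys=[('rooms', 'Rooms found'), ('score', 'Max score')],
     treatments=treatments, treatment_data=treatment_data,
     keys_of_interest=keys_of_interest,
     keys_of_interest_dict=keys_of_interest_dict,
     pretty_names=pretty_names)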
3 | 4 | # Licensed under the Uber Non-Commercial License (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at the root directory of this project. 7 | 8 | # See the License for the specific language governing permissions and 9 | # limitations under the License. 10 | 11 | 12 | from .import_ai import * 13 | 14 | class TimedPickle: 15 | def __init__(self, data, name, enabled=True): 16 | self.data = data 17 | self.name = name 18 | self.enabled = enabled 19 | 20 | def __getstate__(self): 21 | return (time.time(), self.data, self.name, self.enabled) 22 | 23 | def __setstate__(self, s): 24 | tstart, self.data, self.name, self.enabled = s 25 | if self.enabled: 26 | print(f'pickle time for {self.name} = {time.time() - tstart} seconds') 27 | 28 | 29 | @contextmanager 30 | def use_seed(seed): 31 | # Save all the states 32 | python_state = random.getstate() 33 | np_state = np.random.get_state() 34 | 35 | # Seed all the rngs (note: adding different values to the seeds 36 | # in case the same underlying RNG is used by all and in case 37 | # that could be a problem. Probably not necessary) 38 | random.seed(seed) 39 | np.random.seed(seed + 1) 40 | 41 | # Yield control! 42 | yield 43 | 44 | # Reset the rng states 45 | random.setstate(python_state) 46 | np.random.set_state(np_state) 47 | 48 | 49 | def get_code_hash(): 50 | cur_dir = os.path.dirname(os.path.realpath(__file__)) 51 | all_code = '' 52 | for f in sorted(glob.glob(cur_dir + '**/*.py', recursive=True)): 53 | # We assume all whitespace is irrelevant, as well as comments 54 | with open(f) as f: 55 | for line in f: 56 | line = line.partition('#')[0] 57 | line = line.rstrip() 58 | 59 | all_code += ''.join(line.split()) 60 | 61 | hash = hashlib.sha256(all_code.encode('utf8')).hexdigest() 62 | print('HASH', hash) 63 | 64 | return hash 65 | 66 | 67 | def imdownscale(state, target_shape, max_pix_value): 68 | if state.shape[::-1] == target_shape: 69 | resized = state 70 | else: 71 | resized = cv2.resize(state, target_shape, interpolation=cv2.INTER_AREA) 72 | img = ((resized / 255.0) * max_pix_value).astype(np.uint8) 73 | return RLEArray(img) 74 | -------------------------------------------------------------------------------- /robustified/goexplore_py/visualize.py: -------------------------------------------------------------------------------- 1 | 2 | # Copyright (c) 2020 Uber Technologies, Inc. 3 | 4 | # Licensed under the Uber Non-Commercial License (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at the root directory of this project. 7 | 8 | # See the License for the specific language governing permissions and 9 | # limitations under the License. 
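A hypothetical call to the imdownscale helper defined in utils.py above, which produces the PNG-compressed, downscaled frame used as a cell representation; the (11, 8) target and 8 pixel levels echo the downscaled representation described in the Go-Explore paper, while the values actually used by the repo come from main.py's flags.

import numpy as np
from goexplore_py.utils import imdownscale

frame = (np.random.rand(210, 160) * 255).astype(np.uint8)     # grayscale Atari-sized frame
cell = imdownscale(frame, target_shape=(11, 8), max_pix_value=8)
small = cell.to_np()                 # decode back from the PNG bytes
print(small.shape, small.max())      # (8, 11) because target_shape is (width, height); values <= 8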
10 | 11 | 12 | import numpy as np 13 | import cv2 14 | import matplotlib.pyplot as plt 15 | import matplotlib.cm 16 | import matplotlib.colors 17 | import pickle 18 | from collections import defaultdict 19 | 20 | from goexplore_py.montezuma_env import PYRAMID 21 | 22 | 23 | def render_with_known(data, filename): 24 | height, width = data[1][1].shape[:2] 25 | 26 | final_image = np.zeros((height * 4, width * 9, 3), dtype=np.uint8) + 255 27 | 28 | positions = PYRAMID 29 | 30 | def room_pos(room): 31 | for height, l in enumerate(positions): 32 | for width, r in enumerate(l): 33 | if r == room: 34 | return (height, width) 35 | return None 36 | 37 | points = defaultdict(int) 38 | 39 | # print(final_image) 40 | 41 | for room in range(24): 42 | if room in data: 43 | img = data[room][1] 44 | else: 45 | img = np.zeros((height, width, 3)) + 127 46 | y_room, x_room = room_pos(room) 47 | y_room *= height 48 | x_room *= width 49 | final_image[y_room:y_room + height, x_room:x_room + width, :] = img 50 | 51 | plt.figure(figsize=(final_image.shape[1] // 30, final_image.shape[0] // 30)) 52 | 53 | plt.imshow(final_image) 54 | 55 | plt.axis('off') 56 | plt.savefig(filename, bbox_inches='tight') 57 | plt.close() 58 | 59 | 60 | def main(): 61 | with open("all_rooms.pickle", "rb") as file: 62 | data = pickle.load(file) 63 | render_with_known(data, "test.png") 64 | 65 | 66 | if __name__ == "__main__": 67 | main() 68 | -------------------------------------------------------------------------------- /robustified/phase1_downscaled.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # Copyright (c) 2020 Uber Technologies, Inc. 4 | 5 | # Licensed under the Uber Non-Commercial License (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at the root directory of this project. 8 | 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | 13 | 14 | game=${1:-MontezumaRevenge} 15 | results=${2:-results} 16 | frames=${3:-500000000} 17 | 18 | python goexplore_py/main.py --seen_weight=1.0 --reset_cell_on_update --base_path=$results --game=generic_$game --cell_split_factor=0.125 --first_compute_archive_size=1 --first_compute_dynamic_state=10000 --max_archive_size=50000 --max_recent_frames=10000 --recent_frame_add_prob=0.01 --recompute_dynamic_state_every=10000000 --split_iterations=3000 --state_is_pixels --high_score_weight=0.0 --max_hours=256 --max_compute_steps=$frames --n_cpus=88 --batch_size=100 --dynamic_state 19 | -------------------------------------------------------------------------------- /robustified/phase1_fetch.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # Copyright (c) 2020 Uber Technologies, Inc. 4 | 5 | # Licensed under the Uber Non-Commercial License (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at the root directory of this project. 8 | 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 
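Relating to goexplore_py/visualize.py above: it expects an all_rooms.pickle mapping Montezuma room indices to (anything, H x W x 3 image) pairs and tiles them according to the PYRAMID layout. A hypothetical way to produce a placeholder input for it:

import numpy as np
import pickle

# Dummy per-room images; real ones would come from the phase 1 visualisation code.
data = {room: (None, np.full((210, 320, 3), 127, dtype=np.uint8)) for room in range(24)}
with open('all_rooms.pickle', 'wb') as f:
    pickle.dump(data, f)
# Running visualize.py as a module from robustified/ (e.g. `python -m goexplore_py.visualize`)
# then writes the tiled room overview to test.png.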
11 | 12 | 13 | 14 | game=${1:-1000} 15 | results=${2:-results} 16 | frames=${3:-20000000} 17 | 18 | python goexplore_py/main.py --keep_checkpoints --door_resolution=0.2 --door_offset=0.195 --target_location=$game --base_path=$results --fetch_type=boxes_1 --nsubsteps=80 --total_timestep=0.08 --minmax_grip_score=00 --game=fetch --seen_weight=1.0 --max_hours=128 --checkpoint_compute=500000 --max_compute_steps=$frames --repeat_action=10 --explore_steps=30 19 | -------------------------------------------------------------------------------- /robustified/phase1_montezuma.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # Copyright (c) 2020 Uber Technologies, Inc. 4 | 5 | # Licensed under the Uber Non-Commercial License (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at the root directory of this project. 8 | 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | 13 | 14 | results=${2:-results} 15 | frames=${3:-250000000} 16 | 17 | python goexplore_py/main.py --pitfall_treasure_type=none --seen_weight=1 --high_score_weight=1 --horiz_weight=0.1 --vert_weight=0 --low_level_weight=0.1 --remember_rooms --reset_cell_on_update --game=montezuma --max_hours=256 --max_compute_steps=$frames --base_path=$results 18 | 19 | -------------------------------------------------------------------------------- /robustified/phase1_pitfall.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # Copyright (c) 2020 Uber Technologies, Inc. 4 | 5 | # Licensed under the Uber Non-Commercial License (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at the root directory of this project. 8 | 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | 13 | 14 | results=${2:-results} 15 | frames=${3:-250000000} 16 | 17 | python goexplore_py/main.py --seen_weight=1 --chosen_weight=0 --chosen_since_new_weight=0 --high_score_weight=0 --horiz_weight=0 --vert_weight=0 --pitfall_treasure_type=none --remember_rooms --reset_cell_on_update --game=pitfall --max_hours=256 --max_compute_steps=$frames --base_path=$results 18 | 19 | -------------------------------------------------------------------------------- /robustified/phase2_atari.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # Copyright (c) 2020 Uber Technologies, Inc. 4 | 5 | # Licensed under the Uber Non-Commercial License (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at the root directory of this project. 8 | 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 
11 | 12 | 13 | 14 | game=${1:-MontezumaRevenge} 15 | demo=${2:-`pwd`/demos} 16 | results=${3:-`pwd`/results} 17 | frames=${4:-2500000000} 18 | 19 | python atari_reset/train_atari.py --game=$game --sil_vf_coef=0.01 --demo_selection=normalize_by_target --ent_coef=1e-05 --sil_ent_coef=1e-05 --n_sil_envs=2 --extra_sil_from_start_prob=0.3 --autoscale_fn=mean --sil_coef=0.1 --gamma=0.999 --move_threshold=0.1 --autoscale=10 --from_start_demo_reward_interval_factor=20000000 --nrstartsteps=160 --sil_pg_weight_by_value --sil_weight_success_rate --sil_vf_relu --sticky --noops --max_demo_len=400000 --autoscale_fn=mean --num_timesteps=$frames --autoscale_value --nenvs=32 --demo=$demo --no_game_over_on_life_loss --test_from_start --steps_per_demo=200 --no_videos --save_path=$results 20 | -------------------------------------------------------------------------------- /robustified/phase2_atari_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # Copyright (c) 2020 Uber Technologies, Inc. 4 | 5 | # Licensed under the Uber Non-Commercial License (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at the root directory of this project. 8 | 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | 13 | 14 | game=${1:-MontezumaRevenge} 15 | source=${2:-`pwd`/results} 16 | results=${3:-`pwd`/test_results} 17 | 18 | python atari_reset/check_atari.py --num_timesteps=100000000000 --noops --sticky --num_per_noop=1000 --load_path=$source --game=$game --save_path=$results 19 | -------------------------------------------------------------------------------- /robustified/phase2_fetch.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # Copyright (c) 2020 Uber Technologies, Inc. 4 | 5 | # Licensed under the Uber Non-Commercial License (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at the root directory of this project. 8 | 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 
11 | 12 | 13 | 14 | game=${1:-1000} 15 | demo=${2:-`pwd`/demos} 16 | results=${3:-`pwd`/results} 17 | frames=${4:-750000000} 18 | 19 | python atari_reset/train_atari.py --sil_vf_c=0.1 --ffmem=128 --ffsh=1x256 --learning_r=0.0001 --demo_sel=normalize_by_target --nrst=40 --ent_coef=1e-05 --sil_coef=0.1 --extra_frames_exp_factor=4 --allowed_lag=10 --sil_ent=1e-05 --sil_weight_success_rate --sil_vf_relu --num_timesteps $frames --sil_pg_weight_by_value --fetch_target_location=$game --extra_sil_from_start_prob=0 --extra_sil_before_demo_max=10 --fetch_type=boxes_1 --fetch_nsubsteps=80 --fetch_total_timestep=0.08 --nenvs=120 --n_sil_envs=8 --sd_multiply_explore=2 --inc_entropy_threshold=10 --fetch_incl_extra_full_state --game=fetch --demo $demo --gamma=0.99 --vf_coef=0.5 --steps_per_demo=100 --move_threshold=0.1 --save_path=$results 20 | -------------------------------------------------------------------------------- /robustified/requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.9.0 2 | astor==0.8.1 3 | baselines==0.1.5 4 | certifi==2020.4.5.1 5 | cffi==1.14.0 6 | chardet==3.0.4 7 | click==7.1.2 8 | cloudpickle==1.4.1 9 | cycler==0.10.0 10 | Cython==0.29.17 11 | dataclasses==0.7 12 | dill==0.3.1.1 13 | fasteners==0.15 14 | filelock==3.0.12 15 | gast==0.2.2 16 | glfw==1.11.0 17 | google-pasta==0.2.0 18 | grpcio==1.28.1 19 | gym==0.10.11 20 | h5py==2.10.0 21 | horovod==0.19.1 22 | idna==2.9 23 | imageio==2.8.0 24 | imageio-ffmpeg==0.4.1 25 | joblib==0.14.1 26 | Keras-Applications==1.0.8 27 | Keras-Preprocessing==1.1.0 28 | kiwisolver==1.2.0 29 | loky==2.3.1 30 | Markdown==3.2.1 31 | matplotlib==3.2.1 32 | monotonic==1.5 33 | mpi4py==3.0.3 34 | numpy==1.18.4 35 | opencv-python==4.2.0.34 36 | opt-einsum==3.2.1 37 | Pillow==7.1.2 38 | progressbar2==3.51.3 39 | protobuf==3.11.3 40 | psutil==5.7.0 41 | pycparser==2.20 42 | pyglet==1.5.5 43 | PyOpenGL==3.1.5 44 | pyparsing==2.4.7 45 | python-dateutil==2.8.1 46 | python-utils==2.4.0 47 | PyYAML==5.3.1 48 | requests==2.23.0 49 | scipy==1.4.1 50 | six==1.14.0 51 | tensorboard==1.15.0 52 | tensorflow==1.15.2 53 | tensorflow-estimator==1.15.1 54 | termcolor==1.1.0 55 | tqdm==4.46.0 56 | urllib3==1.25.9 57 | Werkzeug==1.0.1 58 | wrapt==1.12.1 59 | --------------------------------------------------------------------------------