├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── colab
│   └── plot_training_graphs.ipynb
├── episodic_curiosity
│   ├── __init__.py
│   ├── constants.py
│   ├── constants_test.py
│   ├── curiosity_env_wrapper.py
│   ├── curiosity_env_wrapper_test.py
│   ├── curiosity_evaluation.py
│   ├── curiosity_evaluation_test.py
│   ├── env_factory.py
│   ├── env_factory_test.py
│   ├── environments
│   │   ├── __init__.py
│   │   ├── dmlab_utils.py
│   │   └── fake_gym_env.py
│   ├── episodic_memory.py
│   ├── episodic_memory_test.py
│   ├── eval_policy.py
│   ├── generate_r_training_data.py
│   ├── generate_r_training_data_test.py
│   ├── keras_checkpoint.py
│   ├── logging.py
│   ├── oracle.py
│   ├── r_network.py
│   ├── r_network_training.py
│   ├── r_network_training_test.py
│   ├── train_policy.py
│   ├── train_r.py
│   ├── train_r_test.py
│   ├── utils.py
│   └── visualize_curiosity_reward.py
├── misc
│   ├── ant_github.gif
│   └── navigation_github.gif
├── scripts
│   ├── gs_sync.py
│   ├── launch_cloud_vms.py
│   ├── launcher_script.py
│   ├── vm_drop_root.sh
│   └── vm_start.sh
├── setup.py
└── third_party
    ├── __init__.py
    ├── baselines
    │   ├── LICENSE
    │   ├── __init__.py
    │   ├── a2c
    │   │   ├── __init__.py
    │   │   └── utils.py
    │   ├── bench
    │   │   ├── __init__.py
    │   │   └── monitor.py
    │   ├── common
    │   │   ├── __init__.py
    │   │   ├── atari_wrappers.py
    │   │   ├── cmd_util.py
    │   │   ├── distributions.py
    │   │   ├── input.py
    │   │   ├── math_util.py
    │   │   ├── misc_util.py
    │   │   ├── runners.py
    │   │   ├── tf_util.py
    │   │   ├── tile_images.py
    │   │   └── vec_env
    │   │       ├── __init__.py
    │   │       ├── dummy_vec_env.py
    │   │       ├── subproc_vec_env.py
    │   │       └── threaded_vec_env.py
    │   ├── logger.py
    │   └── ppo2
    │       ├── __init__.py
    │       ├── pathak_utils.py
    │       ├── policies.py
    │       └── ppo2.py
    ├── dmlab
    │   ├── LICENSE
    │   └── dmlab_min_goal_distance.patch
    ├── gym
    │   ├── LICENSE
    │   ├── __init__.py
    │   ├── ant.py
    │   ├── ant_wrapper.py
    │   ├── ant_wrapper_test.py
    │   ├── assets
    │   │   ├── mujoco_ant_custom_texture_camerav2.xml
    │   │   └── texture.png
    │   └── mujoco_env.py
    └── keras_resnet
        ├── LICENSE
        ├── __init__.py
        └── models.py

/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 | 
3 | You are welcome to contribute patches to this project. There are
4 | just a few small guidelines you need to follow.
5 | 
6 | ## Contributor License Agreement
7 | 
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement. You (or your employer) retain the copyright to your contribution;
10 | this simply gives us permission to use and redistribute your contributions as
11 | part of the project. Head over to <https://cla.developers.google.com/> to see
12 | your current agreements on file or to sign a new one.
13 | 
14 | You generally only need to submit a CLA once, so if you've already submitted one
15 | (even if it was for a different project), you probably don't need to do it
16 | again.
17 | 
18 | ## Code reviews
19 | 
20 | All submissions, including submissions by project members, require review. We
21 | use GitHub pull requests for this purpose. Consult
22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
23 | information on using pull requests.
24 | 
25 | ## Community Guidelines
26 | 
27 | This project follows
28 | [Google's Open Source Community Guidelines](https://opensource.google.com/conduct/).
29 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Episodic Curiosity Through Reachability
2 | 
3 | #### In ICLR 2019 [[Project Website](https://sites.google.com/corp/view/episodic-curiosity)][[Paper](https://arxiv.org/abs/1810.02274)]
4 | 
5 | [Nikolay Savinov¹](http://people.inf.ethz.ch/nsavinov/), [Anton Raichuk²](https://ai.google/research/people/AntonRaichuk), [Raphaël Marinier²](https://ai.google/research/people/105955), Damien Vincent², [Marc Pollefeys¹](https://www.inf.ethz.ch/personal/marc.pollefeys/), [Timothy Lillicrap³](http://contrastiveconvergence.net/~timothylillicrap/index.php), [Sylvain Gelly²](https://ai.google/research/people/SylvainGelly)
6 | ¹ETH Zurich, ²Google AI, ³DeepMind
7 | 8 | Navigation out of curiosity | Locomotion out of curiosity 9 | --------------------------------------------------- | --------------------------- 10 | | 11 | 12 | This is an implementation of our 13 | [ICLR 2019 Episodic Curiosity Through Reachability](https://arxiv.org/abs/1810.02274). 14 | If you use this work, please cite: 15 | 16 | @inproceedings{Savinov2019_EC, 17 | Author = {Savinov, Nikolay and Raichuk, Anton and Marinier, Rapha{\"e}l and Vincent, Damien and Pollefeys, Marc and Lillicrap, Timothy and Gelly, Sylvain}, 18 | Title = {Episodic Curiosity through Reachability}, 19 | Booktitle = {International Conference on Learning Representations ({ICLR})}, 20 | Year = {2019} 21 | } 22 | 23 | ### Requirements 24 | 25 | The code was tested on Linux only. The code assumes that the command "python" 26 | invokes python 2.7. We recommend you use virtualenv: 27 | 28 | ```shell 29 | sudo apt-get install python-pip 30 | pip install virtualenv 31 | python -m virtualenv episodic_curiosity_env 32 | source episodic_curiosity_env/bin/activate 33 | ``` 34 | 35 | ### Installation 36 | 37 | Clone this repository: 38 | 39 | ```shell 40 | git clone https://github.com/google-research/episodic-curiosity.git 41 | cd episodic-curiosity 42 | ``` 43 | 44 | We require a modified version of 45 | [DeepMind lab](https://github.com/deepmind/lab): 46 | 47 | Clone DeepMind Lab: 48 | 49 | ```shell 50 | git clone https://github.com/deepmind/lab 51 | cd lab 52 | ``` 53 | 54 | Apply our patch to DeepMind Lab: 55 | 56 | ```shell 57 | git checkout 7b851dcbf6171fa184bf8a25bf2c87fe6d3f5380 58 | git checkout -b modified_dmlab 59 | git apply ../third_party/dmlab/dmlab_min_goal_distance.patch 60 | ``` 61 | 62 | Install DMLab as a PIP module by following 63 | [these instructions](https://github.com/deepmind/lab/tree/master/python/pip_package) 64 | 65 | In a nutshell, once you've installed DMLab dependencies, you need to run: 66 | 67 | ```shell 68 | bazel build -c opt python/pip_package:build_pip_package 69 | ./bazel-bin/python/pip_package/build_pip_package /tmp/dmlab_pkg 70 | pip install /tmp/dmlab_pkg/DeepMind_Lab-1.0-py2-none-any.whl --force-reinstall 71 | ``` 72 | 73 | If you wish to run Mujoco experiments (section S1 of the paper), you need to 74 | install dm_control and its dependencies. See [this 75 | documentation](https://github.com/deepmind/dm_control#requirements-and-installation), 76 | and replace `pip install -e .` by `pip install -e .[mujoco]` in the command 77 | below. 78 | 79 | Finally, install episodic curiosity and its pip dependencies: 80 | 81 | ```shell 82 | cd episodic-curiosity 83 | pip install -e . 84 | ``` 85 | 86 | 87 | 88 | ### Resource requirements for training 89 | 90 | | Environment | Training method | Required GPU | Recommended RAM | 91 | | ----------- | ---------- | -------------------- | -------------------------- | 92 | | DMLab | PPO | No | 32GBs | 93 | | DMLab | PPO + Grid Oracle | No | 32GBs | 94 | | DMLab | PPO + EC using already trained R-networks | No | 32GBs | 95 | | DMLab | PPO + EC with R-network training | Yes, otherwise, training is slower by >20x.
Required GPU RAM: 5GBs | 50GBs
Tip: reduce `dataset_buffer_size` to use less RAM at the expense of policy performance. | 96 | | DMLab | PPO + ECO | Yes, otherwise, training is slower by >20x.
Required GPU RAM: 5GBs | 80GBs
Tip: reduce `observation_history_size` for using less RAM, at the expense of policy performance | 97 | | Mujoco | PPO + EC using already trained R-networks | No | 32GBs | 98 | 99 | 100 | ## Trained models 101 | 102 | Trained R-networks and policies can be found in the 103 | `episodic-curiosity` Google cloud bucket. You can access them via the 104 | [web interface](https://console.cloud.google.com/storage/browser/episodic-curiosity), 105 | or copy them with the `gsutil` command from the 106 | [Google Cloud SDK](https://cloud.google.com/sdk): 107 | 108 | ```shell 109 | gsutil -m cp -r gs://episodic-curiosity/r_networks . 110 | gsutil -m cp -r gs://episodic-curiosity/policies . 111 | ``` 112 | 113 | Example of command to visualize a trained policy with two episodes of 114 | 1000 steps, and create videos similar to the ones at the top of this 115 | page: 116 | 117 | ``` shell 118 | python -m episodic_curiosity.visualize_curiosity_reward --workdir=/tmp/ec_visualizations --r_net_weights= --policy_path= --alsologtostderr --num_episodes=2 --num_steps=1000 --visualization_type=surrogate_reward --trajectory_mode=do_nothing 119 | ``` 120 | 121 | This requires that you install extra dependencies for generating 122 | videos, with `pip install -e .[video]` 123 | 124 | 125 | ## Training 126 | 127 | ### On a single machine 128 | 129 | [scripts/launcher_script.py](https://github.com/google-research/episodic-curiosity/blob/master/scripts/launcher_script.py) 130 | is the main entry point to reproduce the results of Table 1 in the 131 | [paper](https://arxiv.org/abs/1810.02274). For instance, the following command 132 | line launches training of the *PPO + EC* method on the *Sparse+Doors* scenario: 133 | 134 | ```sh 135 | python episodic-curiosity/scripts/launcher_script.py --workdir=/tmp/ec_workdir --method=ppo_plus_ec --scenario=sparseplusdoors 136 | ``` 137 | 138 | Main flags: 139 | 140 | | Flag | Descriptions | 141 | | :----------- | :--------- | 142 | | --method | Solving method to use, corresponds to the rows in table 1 of the [paper](https://arxiv.org/abs/1810.02274). Possible values: `ppo, ppo_plus_ec, ppo_plus_eco, ppo_plus_grid_oracle` | 143 | | --scenario | Scenario to launch. Corresponds to the columns in table 1 of the [paper](https://arxiv.org/abs/1810.02274). Possible values: `noreward, norewardnofire, sparse, verysparse, sparseplusdoors, dense1, dense2`. `ant_no_reward` is also supported which corresponds to the first row of table S1. | 144 | | --workdir | Directory where logs and checkpoints will be stored. | 145 | | --run_number | Run number of the current run. This is used to create an appropriate subdir in workdir. | 146 | | --r_networks_path | Only meaningful for the `ppo_plus_ec` method. Path to the root dir for pre-trained r networks. If specified, we train the policy using those pre-trained r networks. If not specified, we first generate the R network training data, train the R network and then train the policy. | 147 | 148 | 149 | Training takes a couple of days. We used CPUs with 16 hyper-threads, but smaller 150 | CPUs should do. 151 | 152 | Under the hood, 153 | [launcher_script.py](https://github.com/google-research/episodic-curiosity/blob/master/scripts/launcher_script.py) 154 | launches 155 | [train_policy.py](https://github.com/google-research/episodic-curiosity/blob/master/episodic_curiosity/train_policy.py) 156 | with the right hyperparameters. 
For the method `ppo_plus_ec`, it first launches 157 | [generate_r_training_data.py](https://github.com/google-research/episodic-curiosity/blob/master/episodic_curiosity/generate_r_training_data.py) 158 | to accumulate training data for the R-network using a random policy, then 159 | launches 160 | [train_r.py](https://github.com/google-research/episodic-curiosity/blob/master/episodic_curiosity/train_r.py) 161 | to train the R-network, and finally 162 | [train_policy.py](https://github.com/google-research/episodic-curiosity/blob/master/episodic_curiosity/train_policy.py) 163 | for the policy. In the method `ppo_plus_eco`, all this happens online as part of 164 | the policy training. 165 | 166 | ### On Google Cloud 167 | 168 | First, make sure you have the [Google Cloud SDK](https://cloud.google.com/sdk) 169 | installed. 170 | 171 | [scripts/launch_cloud_vms.py](https://github.com/google-research/episodic-curiosity/blob/master/scripts/launch_cloud_vms.py) 172 | is the main entry point. Edit the script and replace the `FILL-ME`s with the 173 | details of your GCP project. In particular, you will need to point it to a GCP 174 | disk snapshot with the installed dependencies as described in the 175 | [Installation](#Installation) section. 176 | 177 | IMPORTANT: By default the script reproduces all results in table 1 and launches 178 | ~300 VMs on cloud with GPUs (7 scenarios x 4 methods x 10 runs). The cost of 179 | running all those VMs is very significant: on the order of USD 30 **per day** 180 | **per VM** based on early 2019 GCP pricing. Pass 181 | `--i_understand_launching_vms_is_expensive` to 182 | [scripts/launch_cloud_vms.py](https://github.com/google-research/episodic-curiosity/blob/master/scripts/launch_cloud_vms.py) 183 | to indicate that you understood that. 184 | 185 | Under the hood, `launch_cloud_vms.py` launches one VM for each (scenario, 186 | method, run_number) tuple. The VMs use startup scripts to launch training, and 187 | retrieve the parameters of the run through 188 | [Instance Metadata](https://cloud.google.com/compute/docs/storing-retrieving-metadata). 189 | 190 | TIP: Use `sudo journalctl -u google-startup-scripts.service` to see the logs of 191 | the startup script. 192 | 193 | ### Training logs 194 | 195 | Each training job stores logs and checkpoints in a workdir. The workdir is 196 | organized as follows: 197 | 198 | | File or Directory | Description | 199 | | :----------------------------------------- | :------------------------------ | 200 | | `r_training_data/{R_TRAINING,VALIDATION}/` | TF Records with data generated from a random policy for R-network training. Only for method `ppo_plus_ec` without supplying pre-trained R-networks. | 201 | | `r_networks/` | Keras checkpoints of trained R-networks. Only for method `ppo_plus_ec` without supplying pre-trained R-networks. | 202 | | `reward_{train,valid,test}.csv` | CSV files with {train,valid,test} rewards, tracking the performance of the policy at multiple training steps. | 203 | | `checkpoints/` | Checkpoints of the policy. | 204 | | `log.txt`, `progress.csv` | Training logs and CSV from OpenAI's PPO2 code. | 205 | 206 | On cloud, the workdir of each job will be synced to a cloud bucket directory of 207 | the form `////run_number_/`. 208 | 209 | We provide a 210 | [colab](https://github.com/google-research/episodic-curiosity/blob/master/colab/plot_training_graphs.ipynb) 211 | to plot graphs during training of the policies, using data from the 212 | `reward_{train,valid,test}.csv` files. 
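If you want a quick local look at these curves without the colab, a short pandas/matplotlib script is enough. The sketch below is illustrative only: the exact CSV schema is an assumption (it simply uses the first column as the training step and the last column as the reward), and you may need to point it at the per-run subdirectory created under your `--workdir`.

```python
# Quick local plot of the reward CSVs written during training.
# Assumption: each reward_{split}.csv has a step-like first column and a
# reward-like last column; adjust if the actual schema differs.
import matplotlib.pyplot as plt
import pandas as pd

workdir = '/tmp/ec_workdir'  # same value as the --workdir flag

for split in ('train', 'valid', 'test'):
    df = pd.read_csv('{}/reward_{}.csv'.format(workdir, split))
    plt.plot(df.iloc[:, 0], df.iloc[:, -1], label=split)

plt.xlabel('training step')
plt.ylabel('episode reward')
plt.legend()
plt.savefig('{}/reward_curves.png'.format(workdir))
```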
213 | 214 | ### Related projects 215 | Check out the code for [Semi-parametric Topological Memory]( 216 | https://github.com/nsavinov/SPTM), which uses graph-based episodic memory 217 | constructed from a short video to navigate in novel environments (thus providing 218 | exploitation policy, complementary to the exploration policy in this work). 219 | 220 | ### Known limitations 221 | 222 | - As of 2019/02/20, `ppo_plus_eco` method is not robust to restarts, because 223 | the R-network trained online is not checkpointed. 224 | 225 | -------------------------------------------------------------------------------- /episodic_curiosity/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | -------------------------------------------------------------------------------- /episodic_curiosity/constants.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Constants for episodic curiosity.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | 21 | from __future__ import print_function 22 | 23 | from enum import Enum 24 | 25 | 26 | class Level(object): 27 | """Represents a DMLab level, possibly with additional non-standard settings. 28 | 29 | Attributes: 30 | dmlab_level_name: Name of the DMLab level 31 | fully_qualified_name: Unique name used to distinguish between multiple DMLab 32 | levels with the same name but different settings. 33 | extra_env_settings: dict, additional DMLab environment settings for this 34 | level. 35 | random_maze: Whether the geometry of the maze is supposed to change when we 36 | change the seed. 37 | use_r_net_from_level: If provided, don't train a R-net for this level, but 38 | instead, use the trained R-net from another level 39 | (identified by its fully qualified name). 40 | include_in_paper: Whether this level is included in the paper. 41 | scenarios: Optional list of scenarios this level is used for. 
42 | """ 43 | 44 | def __init__(self, 45 | dmlab_level_name, 46 | fully_qualified_name = None, 47 | extra_env_settings = None, 48 | random_maze = False, 49 | use_r_net_from_level = None, 50 | include_in_paper = False, 51 | scenarios = None): 52 | self.dmlab_level_name = dmlab_level_name 53 | self.fully_qualified_name = fully_qualified_name or dmlab_level_name 54 | self.extra_env_settings = extra_env_settings or {} 55 | self.random_maze = random_maze 56 | self.use_r_net_from_level = use_r_net_from_level 57 | self.include_in_paper = include_in_paper 58 | self.scenarios = scenarios 59 | 60 | def asdict(self): 61 | return vars(self) 62 | 63 | 64 | class SplitType(Enum): 65 | R_TRAINING = 0 66 | POLICY_TRAINING = 3 67 | VALIDATION = 1 68 | TEST = 2 69 | 70 | 71 | class Const(object): 72 | """Constants""" 73 | MAX_ACTION_DISTANCE = 5 74 | NEGATIVE_SAMPLE_MULTIPLIER = 5 75 | # env 76 | OBSERVATION_HEIGHT = 120 77 | OBSERVATION_WIDTH = 160 78 | OBSERVATION_CHANNELS = 3 79 | OBSERVATION_SHAPE = (OBSERVATION_HEIGHT, OBSERVATION_WIDTH, 80 | OBSERVATION_CHANNELS) 81 | # model and training 82 | BATCH_SIZE = 64 83 | EDGE_CLASSES = 2 84 | DUMP_AFTER_BATCHES = 100 85 | EDGE_MAX_EPOCHS = 2000 86 | ADAM_PARAMS = { 87 | 'lr': 1e-04, 88 | 'beta_1': 0.9, 89 | 'beta_2': 0.999, 90 | 'epsilon': 1e-08, 91 | 'decay': 0.0 92 | } 93 | ACTION_REPEAT = 4 94 | STORE_CHECKPOINT_EVERY_N_EPOCHS = 30 95 | 96 | LEVELS = [ 97 | # Levels on which we evaluate episodic curiosity. 98 | # Corresponds to 'Sparse' setting in the paper 99 | # (arxiv.org/pdf/1810.02274.pdf). 100 | Level('contributed/dmlab30/explore_goal_locations_large', 101 | fully_qualified_name='explore_goal_locations_large', 102 | random_maze=True, 103 | include_in_paper=True, 104 | scenarios=['sparse', 'noreward', 'norewardnofire']), 105 | 106 | # WARNING!! For explore_goal_locations_large_sparse and 107 | # explore_goal_locations_large_verysparse to work properly (i.e. taking 108 | # into account minGoalDistance), you need to use the dmlab MPM: 109 | # learning/brain/research/dune/rl/dmlab_env_package. 110 | # Corresponds to 'Very Sparse' setting in the paper. 111 | Level( 112 | 'contributed/dmlab30/explore_goal_locations_large', 113 | fully_qualified_name='explore_goal_locations_large_verysparse', 114 | extra_env_settings={ 115 | # Forces the spawn and goals to be further apart. 116 | # Unfortunately, we cannot go much higher, because we need to 117 | # guarantee that for any goal location, we can at least find one 118 | # spawn location that is further than this number (the goal 119 | # location might be in the middle of the map...). 120 | 'minGoalDistance': 10, 121 | }, 122 | use_r_net_from_level='explore_goal_locations_large', 123 | random_maze=True, include_in_paper=True, 124 | scenarios=['verysparse']), 125 | 126 | # Corresponds to 'Sparse+Doors' setting in the paper. 127 | Level('contributed/dmlab30/explore_obstructed_goals_large', 128 | fully_qualified_name='explore_obstructed_goals_large', 129 | random_maze=True, 130 | include_in_paper=True, 131 | scenarios=['sparseplusdoors']), 132 | 133 | # Two levels where we expect to show episodic curiosity does not hurt. 134 | # Corresponds to 'Dense 1' setting in the paper. 135 | Level('contributed/dmlab30/rooms_keys_doors_puzzle', 136 | fully_qualified_name='rooms_keys_doors_puzzle', 137 | include_in_paper=True, 138 | scenarios=['dense1']), 139 | # Corresponds to 'Dense 2' setting in the paper. 
140 | Level('contributed/dmlab30/rooms_collect_good_objects_train', 141 | fully_qualified_name='rooms_collect_good_objects_train', 142 | include_in_paper=True, 143 | scenarios=['dense2']), 144 | ] 145 | 146 | MIXER_SEEDS = { 147 | # Equivalent to not setting a mixer seed. Mixer seed to train the 148 | # R-network. 149 | SplitType.R_TRAINING: 0, 150 | # Mixer seed for training the policy. 151 | SplitType.POLICY_TRAINING: 0x3D23BE66, 152 | SplitType.VALIDATION: 0x2B79ED94, # Invented. 153 | SplitType.TEST: 0x600D5EED, # Same as DM's. 154 | } 155 | 156 | @staticmethod 157 | def find_level(fully_qualified_name): 158 | """Finds a DMLab level by fully qualified name.""" 159 | for level in Const.LEVELS: 160 | if level.fully_qualified_name == fully_qualified_name: 161 | return level 162 | # Fallback to the DMLab level with the corresponding name. 163 | return Level(fully_qualified_name, 164 | extra_env_settings = { 165 | # Make 'rooms_exploit_deferred_effects_test', 166 | # 'rooms_collect_good_objects_test' work. 167 | 'allowHoldOutLevels': True 168 | }) 169 | 170 | @staticmethod 171 | def find_level_by_scenario(scenario): 172 | """Finds a DMLab level by scenario name.""" 173 | for level in Const.LEVELS: 174 | if level.scenarios and scenario in level.scenarios: 175 | return level 176 | raise ValueError('Scenario "{}" not found.'.format(scenario)) 177 | -------------------------------------------------------------------------------- /episodic_curiosity/constants_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for dune.rl.episodic_curiosity.constants.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | from absl.testing import absltest 23 | from episodic_curiosity import constants 24 | 25 | 26 | class ConstantsTest(absltest.TestCase): 27 | 28 | def test_unique_levels(self): 29 | unique_levels = set() 30 | for level in constants.Const.LEVELS: 31 | self.assertNotIn(level.fully_qualified_name, unique_levels) 32 | unique_levels.add(level.fully_qualified_name) 33 | 34 | def test_find_level(self): 35 | self.assertEqual( 36 | constants.Const.find_level('explore_goal_locations_large') 37 | .dmlab_level_name, 'contributed/dmlab30/explore_goal_locations_large') 38 | 39 | 40 | if __name__ == '__main__': 41 | absltest.main() 42 | -------------------------------------------------------------------------------- /episodic_curiosity/curiosity_env_wrapper_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Test of curiosity_env_wrapper.py.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | from episodic_curiosity import curiosity_env_wrapper 23 | from episodic_curiosity import episodic_memory 24 | from third_party.baselines.common.vec_env.dummy_vec_env import DummyVecEnv 25 | import gym 26 | import numpy as np 27 | import tensorflow as tf 28 | 29 | 30 | class DummyImageEnv(gym.Env): 31 | 32 | def __init__(self): 33 | self._num_actions = 4 34 | self._image_shape = (28, 28, 3) 35 | self._done_prob = 0.01 36 | 37 | self.action_space = gym.spaces.Discrete(self._num_actions) 38 | self.observation_space = gym.spaces.Box( 39 | 0, 255, self._image_shape, dtype=np.float32) 40 | 41 | def seed(self, seed=None): 42 | pass 43 | 44 | def step(self, action): 45 | observation = np.random.normal(size=self._image_shape) 46 | reward = 0.0 47 | done = (np.random.rand() < self._done_prob) 48 | info = {} 49 | return observation, reward, done, info 50 | 51 | def reset(self): 52 | return np.random.normal(size=self._image_shape) 53 | 54 | def render(self, mode='human'): 55 | raise NotImplementedError('Rendering not implemented') 56 | 57 | 58 | # TODO(damienv): To be removed once the code in third_party 59 | # is compatible with python 2. 60 | class HackDummyVecEnv(DummyVecEnv): 61 | 62 | def step_wait(self): 63 | for e in range(self.num_envs): 64 | action = self.actions[e] 65 | if isinstance(self.envs[e].action_space, gym.spaces.Discrete): 66 | action = int(action) 67 | 68 | obs, self.buf_rews[e], self.buf_dones[e], self.buf_infos[e] = ( 69 | self.envs[e].step(action)) 70 | if self.buf_dones[e]: 71 | obs = self.envs[e].reset() 72 | self._save_obs(e, obs) 73 | return (np.copy(self._obs_from_buf()), 74 | np.copy(self.buf_rews), 75 | np.copy(self.buf_dones), 76 | list(self.buf_infos)) 77 | 78 | 79 | def embedding_similarity(x1, x2): 80 | assert x1.shape[0] == x2.shape[0] 81 | epsilon = 1e-6 82 | 83 | # Inner product between the embeddings in x1 84 | # and the embeddings in x2. 85 | s = np.sum(x1 * x2, axis=-1) 86 | 87 | s /= np.linalg.norm(x1, axis=-1) * np.linalg.norm(x2, axis=-1) + epsilon 88 | return 0.5 * (s + 1.0) 89 | 90 | 91 | def linear_embedding(m, x): 92 | # Flatten all but the batch dimension if needed. 
93 | if len(x.shape) > 2: 94 | x = np.reshape(x, [x.shape[0], -1]) 95 | return np.matmul(x, m) 96 | 97 | 98 | class EpisodicEnvWrapperTest(tf.test.TestCase): 99 | 100 | def EnvFactory(self): 101 | return DummyImageEnv() 102 | 103 | def testResizeObservation(self): 104 | img_grayscale = np.random.randint(low=0, high=256, size=[64, 48, 1]) 105 | img_grayscale = img_grayscale.astype(np.uint8) 106 | resized_img = curiosity_env_wrapper.resize_observation(img_grayscale, 107 | [16, 12, 1]) 108 | self.assertAllEqual([16, 12, 1], resized_img.shape) 109 | 110 | img_color = np.random.randint(low=0, high=256, size=[64, 48, 3]) 111 | img_color = img_color.astype(np.uint8) 112 | resized_img = curiosity_env_wrapper.resize_observation(img_color, 113 | [16, 12, 1]) 114 | self.assertAllEqual([16, 12, 1], resized_img.shape) 115 | resized_img = curiosity_env_wrapper.resize_observation(img_color, 116 | [16, 12, 3]) 117 | self.assertAllEqual([16, 12, 3], resized_img.shape) 118 | 119 | def testEpisodicEnvWrapperSimple(self): 120 | num_envs = 10 121 | vec_env = HackDummyVecEnv([self.EnvFactory] * num_envs) 122 | 123 | embedding_size = 16 124 | vec_episodic_memory = [episodic_memory.EpisodicMemory( 125 | capacity=1000, 126 | observation_shape=[embedding_size], 127 | observation_compare_fn=embedding_similarity) 128 | for _ in range(num_envs)] 129 | 130 | mat = np.random.normal(size=[28 * 28 * 3, embedding_size]) 131 | observation_embedding = lambda x, m=mat: linear_embedding(m, x) 132 | 133 | target_image_shape = [14, 14, 1] 134 | env_wrapper = curiosity_env_wrapper.CuriosityEnvWrapper( 135 | vec_env, vec_episodic_memory, 136 | observation_embedding, 137 | target_image_shape) 138 | 139 | observations = env_wrapper.reset() 140 | self.assertAllEqual([num_envs] + target_image_shape, 141 | observations.shape) 142 | 143 | dummy_actions = [1] * num_envs 144 | for _ in range(100): 145 | previous_mem_length = [len(mem) for mem in vec_episodic_memory] 146 | observations, unused_rewards, dones, unused_infos = ( 147 | env_wrapper.step(dummy_actions)) 148 | current_mem_length = [len(mem) for mem in vec_episodic_memory] 149 | 150 | self.assertAllEqual([num_envs] + target_image_shape, 151 | observations.shape) 152 | for k in range(num_envs): 153 | if dones[k]: 154 | self.assertEqual(1, current_mem_length[k]) 155 | else: 156 | self.assertGreaterEqual(current_mem_length[k], 157 | previous_mem_length[k]) 158 | 159 | 160 | if __name__ == '__main__': 161 | tf.test.main() 162 | -------------------------------------------------------------------------------- /episodic_curiosity/curiosity_evaluation.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Library to evaluate exploration.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | 21 | from __future__ import print_function 22 | 23 | from episodic_curiosity.constants import Const 24 | import numpy as np 25 | import tensorflow as tf 26 | 27 | 28 | class OracleExplorationReward(object): 29 | """Class that computes the ideal exploration bonus.""" 30 | 31 | def __init__(self, reward_grid_size): 32 | """Creates a new oracle to compute the exploration reward.""" 33 | self._reward_grid_size = reward_grid_size 34 | self._collected_positions = set() 35 | 36 | def update_state(self, agent_position): 37 | """Set the new state (i.e. the position). 38 | 39 | Args: 40 | agent_position: x,y,z position of the agent. 41 | 42 | Returns: 43 | The exploration bonus for having visited this position. 44 | """ 45 | x, y, z = agent_position 46 | quantized_x = int(x / self._reward_grid_size) 47 | quantized_y = int(y / self._reward_grid_size) 48 | quantized_z = int(z / self._reward_grid_size) 49 | position_id = (quantized_x, quantized_y, quantized_z) 50 | if position_id in self._collected_positions: 51 | # No reward if the position has already been explored. 52 | return 0.0 53 | else: 54 | self._collected_positions.add(position_id) 55 | return 1.0 56 | 57 | 58 | def policy_state_coverage(env, policy_action, reward_grid_size, 59 | eval_time_steps): 60 | """Computes the maze coverage by a given policy. 61 | 62 | Args: 63 | env: A Gym environment. 64 | policy_action: Function which, given a state, returns an action. 65 | e.g. For DMLab, the policy will return an Actions protobuf. 66 | reward_grid_size: L1 distance between 2 consecutive curiosity reward. 67 | eval_time_steps: List of times after which state coverage should be 68 | computed. 69 | 70 | Returns: 71 | The total number of cumulative rewards at different times 72 | in the episode. 73 | """ 74 | max_episode_length = max(eval_time_steps) + 1 75 | 76 | # During test, the exploration reward is given by an oracle that has 77 | # access to the agent coordinates. 78 | cumulative_exploration_reward = 0.0 79 | oracle_exploration_reward = OracleExplorationReward( 80 | reward_grid_size=reward_grid_size) 81 | 82 | # Initial observation. 83 | observation = env.reset() 84 | 85 | reward_list = {} 86 | for k in range(max_episode_length): 87 | action = policy_action(observation) 88 | 89 | # TODO(damienv): Repeating an action should be a wrapper 90 | # of the environment. 91 | repeat_action_count = Const.ACTION_REPEAT 92 | done = False 93 | for _ in range(repeat_action_count): 94 | observation, _, done, metadata = env.step(action) 95 | if done: 96 | break 97 | 98 | # Abort if we've reached the end of the episode. 99 | if done: 100 | reward_list[k] = cumulative_exploration_reward 101 | break 102 | 103 | # Convert the new agent position into a possible exploration bonus. 104 | # Note: getting the position of the agent is specific to DMLab. 
105 | agent_position = metadata['position'] 106 | cumulative_exploration_reward += oracle_exploration_reward.update_state( 107 | agent_position) 108 | 109 | step_count = k + 1 110 | if step_count in eval_time_steps: 111 | reward_list[step_count] = cumulative_exploration_reward 112 | 113 | return reward_list 114 | 115 | 116 | class PolicyWrapper(object): 117 | """Wrap a policy defined by some input/output nodes in a TF graph.""" 118 | 119 | def __init__(self, 120 | input_observation, 121 | input_state, 122 | output_state, 123 | output_actions): 124 | # The tensors to feed the policy and to retrieve the relevant outputs. 125 | self._input_observation = input_observation 126 | self._input_state = input_state 127 | self._output_state = output_state 128 | self._output_actions = output_actions 129 | 130 | # TensorFlow session. 131 | self._sess = None 132 | 133 | def set_session(self, tf_session): 134 | self._sess = tf_session 135 | 136 | def reset(self): 137 | self._current_state = np.zeros( 138 | self._input_state.get_shape().as_list(), 139 | dtype=np.float32) 140 | 141 | def action(self, observation): 142 | """Action to perform given an observation.""" 143 | # Converts to batched obervation (with batch_size=1). 144 | observation = np.expand_dims(observation, axis=0) 145 | 146 | # Run the underlying policy and update the state of the policy. 147 | actions, next_state = self._sess.run( 148 | [self._output_actions, self._output_state], 149 | feed_dict={self._input_observation: observation, 150 | self._input_state: self._current_state}) 151 | self._current_state = next_state 152 | 153 | # Un-batch the action. 154 | action = actions[0] 155 | return action 156 | 157 | 158 | def load_policy(graph_def, 159 | input_observation_name, 160 | input_state_name, 161 | output_state_name, 162 | output_pd_params_name, 163 | tf_sampling_fn): 164 | """Load a policy from a graph file. 165 | 166 | Args: 167 | graph_def: Graph definition. 168 | input_observation_name: Name in the graph definition of the tensor 169 | of observations. 170 | input_state_name: Name in the graph definition of the tensor 171 | of input states. 172 | output_state_name: Name in the graph definition of the tensor of output 173 | states. 174 | output_pd_params_name: Name in the graph definition of the tensor 175 | representing the parameters of the distribution over actions. 176 | tf_sampling_fn: Function which samples action based in the probability 177 | distribution parameters (given by output_pd_params_name). 178 | 179 | Returns: 180 | Returns a python policy. 181 | """ 182 | tensors = tf.import_graph_def( 183 | graph_def, 184 | return_elements=[input_observation_name, 185 | input_state_name, 186 | output_state_name, 187 | output_pd_params_name], 188 | name='') 189 | 190 | input_observation, input_state, output_state, output_pd_params = tensors 191 | output_actions = tf_sampling_fn(output_pd_params) 192 | 193 | return PolicyWrapper(input_observation, 194 | input_state, 195 | output_state, 196 | output_actions) 197 | -------------------------------------------------------------------------------- /episodic_curiosity/curiosity_evaluation_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Simple test of the curiosity evaluation.""" 17 | 18 | from __future__ import division 19 | from __future__ import print_function 20 | from __future__ import unicode_literals 21 | 22 | from episodic_curiosity import curiosity_evaluation 23 | from episodic_curiosity.environments import fake_gym_env 24 | import numpy as np 25 | import tensorflow as tf 26 | from tensorflow.contrib import layers as contrib_layers 27 | 28 | 29 | def random_policy(unused_observation): 30 | action = np.random.randint(low=0, high=fake_gym_env.FakeGymEnv.NUM_ACTIONS) 31 | return action 32 | 33 | 34 | class CuriosityEvaluationTest(tf.test.TestCase): 35 | 36 | def EvalPolicy(self, policy): 37 | env = fake_gym_env.FakeGymEnv() 38 | 39 | # Distance between 2 consecutive curiosity rewards. 40 | reward_grid_size = 10.0 41 | 42 | # Times of evaluation. 43 | eval_time_steps = [100, 500, 3000] 44 | 45 | rewards = curiosity_evaluation.policy_state_coverage( 46 | env, policy, reward_grid_size, eval_time_steps) 47 | 48 | # The exploration reward is at most the number of steps. 49 | # It is equal to the number of steps when the policy explores a new state 50 | # at every time step. 51 | print('Curiosity reward: {}'.format(rewards)) 52 | for k, r in rewards.items(): 53 | self.assertGreaterEqual(k, r) 54 | 55 | def testRandomPolicy(self): 56 | self.EvalPolicy(random_policy) 57 | 58 | def testNNPolicy(self): 59 | batch_size = 1 60 | x = tf.placeholder( 61 | tf.float32, 62 | shape=(batch_size,) + fake_gym_env.FakeGymEnv.OBSERVATION_SHAPE) 63 | x = tf.div(x, 255.0) 64 | 65 | # This is just to make the test run fast enough. 66 | x_downscaled = tf.image.resize_images(x, [8, 8]) 67 | x_downscaled = tf.reshape(x_downscaled, [batch_size, -1]) 68 | 69 | # Logits to select the action. 70 | num_actions = 7 71 | h = contrib_layers.fully_connected( 72 | inputs=x_downscaled, num_outputs=32, activation_fn=None, scope='fc0') 73 | h = tf.nn.relu(h) 74 | y_logits = contrib_layers.fully_connected( 75 | inputs=h, num_outputs=num_actions, activation_fn=None, scope='fc1') 76 | temperature = 100.0 77 | y_logits /= temperature 78 | 79 | # Draw the action according to the distribution inferred by the logits. 80 | r = tf.random_uniform(tf.shape(y_logits), 81 | minval=0.001, maxval=0.999) 82 | y_logits -= tf.log(-tf.log(r)) 83 | y = tf.argmax(y_logits, axis=-1) 84 | 85 | input_state = tf.placeholder(tf.float32, shape=(37)) 86 | output_state = input_state 87 | 88 | # Policy from the previous network. 
89 | policy = curiosity_evaluation.PolicyWrapper( 90 | x, input_state, output_state, y) 91 | 92 | global_init_op = tf.global_variables_initializer() 93 | with self.test_session() as sess: 94 | sess.run(global_init_op) 95 | 96 | policy.set_session(sess) 97 | policy.reset() 98 | 99 | self.EvalPolicy(policy.action) 100 | 101 | 102 | if __name__ == '__main__': 103 | tf.test.main() 104 | -------------------------------------------------------------------------------- /episodic_curiosity/env_factory_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Test of env_factory.py.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import os 23 | 24 | from absl import flags 25 | from episodic_curiosity import env_factory 26 | from episodic_curiosity import r_network 27 | from episodic_curiosity.environments import fake_gym_env 28 | from third_party.keras_resnet import models 29 | import numpy as np 30 | import tensorflow as tf 31 | from tensorflow import keras 32 | 33 | 34 | FLAGS = flags.FLAGS 35 | 36 | 37 | class EnvFactoryTest(tf.test.TestCase): 38 | 39 | def setUp(self): 40 | super(EnvFactoryTest, self).setUp() 41 | keras.backend.clear_session() 42 | self.weight_path = os.path.join(tf.test.get_temp_dir(), 'weights.h5') 43 | self.input_shape = fake_gym_env.FakeGymEnv.OBSERVATION_SHAPE 44 | self.dumped_r_network, _, _ = models.ResnetBuilder.build_siamese_resnet_18( 45 | self.input_shape) 46 | self.dumped_r_network.compile( 47 | loss='categorical_crossentropy', optimizer=keras.optimizers.Adam()) 48 | self.dumped_r_network.save_weights(self.weight_path) 49 | 50 | def testCreateRNetwork(self): 51 | r_network.RNetwork(self.input_shape, self.weight_path) 52 | 53 | def testCreateAndRunEnvironment(self): 54 | # pylint: disable=g-long-lambda 55 | env_factory.create_single_env = ( 56 | lambda level_name, seed, dmlab_homepath, use_monitor, split, action_set: 57 | fake_gym_env.FakeGymEnv()) 58 | # pylint: enable=g-long-lambda 59 | 60 | env, env_valid, env_test = env_factory.create_environments( 61 | 'explore_object_locations_small', 1, self.weight_path) 62 | env.reset() 63 | actions = [0] 64 | for _ in range(5): 65 | env.step(actions) 66 | env.close() 67 | env_valid.close() 68 | env_test.close() 69 | 70 | def testRNetworkLazyLoading(self): 71 | """Tests that the RNetwork weights are lazy loaded.""" 72 | # pylint: disable=g-long-lambda 73 | rand_obs = lambda: np.random.uniform(low=-0.01, high=0.01, 74 | size=(1,) + self.input_shape) 75 | obs1 = rand_obs() 76 | obs2 = rand_obs() 77 | expected_similarity = self.dumped_r_network.predict([obs1, obs2]) 78 | r_net = r_network.RNetwork(self.input_shape, self.weight_path) 79 | with self.test_session() as sess: 80 | sess.run(tf.global_variables_initializer()) 81 | emb1 = 
r_net.embed_observation(obs1) 82 | emb2 = r_net.embed_observation(obs2) 83 | similarity = r_net.embedding_similarity(emb1, emb2) 84 | self.assertAlmostEqual(similarity[0], expected_similarity[0, 1]) 85 | 86 | 87 | if __name__ == '__main__': 88 | tf.test.main() 89 | -------------------------------------------------------------------------------- /episodic_curiosity/environments/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | -------------------------------------------------------------------------------- /episodic_curiosity/environments/fake_gym_env.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Fake gym environment. 17 | """ 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import gym 24 | import numpy as np 25 | 26 | 27 | # This class is shared among multiple tests. Please refrain from adding logic 28 | # that is specific to your test in this class. Instead, you should create a new 29 | # local fake env that is specific to your use-case (possibly inheriting from 30 | # this one). 
31 | class FakeGymEnv(gym.Env): 32 | """Fake gym environment.""" 33 | OBSERVATION_HEIGHT = 120 34 | OBSERVATION_WIDTH = 160 35 | OBSERVATION_CHANNELS = 3 36 | OBSERVATION_SHAPE = (OBSERVATION_HEIGHT, OBSERVATION_WIDTH, 37 | OBSERVATION_CHANNELS) 38 | NUM_ACTIONS = 4 39 | EPISODE_LENGTH = 100 40 | 41 | def __init__(self): 42 | self.action_space = gym.spaces.Discrete(self.NUM_ACTIONS) 43 | self.observation_space = gym.spaces.Box( 44 | 0, 255, self.OBSERVATION_SHAPE, dtype=np.float32) 45 | self.episode_step = 0 46 | 47 | def seed(self, seed=None): 48 | pass 49 | 50 | def _observation(self): 51 | return np.random.randint( 52 | low=0, high=255, size=self.OBSERVATION_SHAPE, dtype=np.uint8) 53 | 54 | def step(self, action): 55 | observation = self._observation() 56 | reward = 0.0 57 | done = self.episode_step >= self.EPISODE_LENGTH 58 | self.episode_step += 1 59 | info = {'position': np.random.uniform(low=0, high=1000, size=[3])} 60 | return observation, reward, done, info 61 | 62 | def reset(self): 63 | self.episode_step = 0 64 | return self._observation() 65 | 66 | def render(self, mode='human'): 67 | raise NotImplementedError('Rendering not implemented') 68 | -------------------------------------------------------------------------------- /episodic_curiosity/episodic_memory.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Class that represents an episodic memory.""" 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | 20 | from __future__ import print_function 21 | 22 | import gin 23 | import numpy as np 24 | 25 | 26 | @gin.configurable 27 | class EpisodicMemory(object): 28 | """Episodic memory.""" 29 | 30 | def __init__(self, 31 | observation_shape, 32 | observation_compare_fn, 33 | replacement='fifo', 34 | capacity=200): 35 | """Creates an episodic memory. 36 | 37 | Args: 38 | observation_shape: Shape of an observation. 39 | observation_compare_fn: Function used to measure similarity between 40 | two observations. This function returns the estimated probability that 41 | two observations are similar. 42 | replacement: String to select the behavior when a sample is added 43 | to the memory when this one is full. 44 | Can be one of: 'fifo', 'random'. 45 | 'fifo' keeps the last "capacity" samples into the memory. 46 | 'random' results in a geometric distribution of the age of the samples 47 | present in the memory. 48 | capacity: Capacity of the episodic memory. 49 | 50 | Raises: 51 | ValueError: when the replacement scheme is invalid. 
52 | """ 53 | self._capacity = capacity 54 | self._replacement = replacement 55 | if self._replacement not in ['fifo', 'random']: 56 | raise ValueError('Invalid replacement scheme') 57 | self._observation_shape = observation_shape 58 | self._observation_compare_fn = observation_compare_fn 59 | self.reset(False) 60 | 61 | def reset(self, show_stats=True): 62 | """Resets the memory.""" 63 | if show_stats: 64 | size = len(self) 65 | age_histogram, _ = np.histogram(self._memory_age[:size], 66 | 10, [0, self._count]) 67 | age_histogram = age_histogram.astype(np.float32) 68 | age_histogram = age_histogram / np.sum(age_histogram) 69 | print('Number of samples added in the previous trajectory: {}'.format( 70 | self._count)) 71 | print('Histogram of sample freshness (old to fresh): {}'.format( 72 | age_histogram)) 73 | 74 | self._count = 0 75 | # Stores environment observations. 76 | self._obs_memory = np.zeros([self._capacity] + self._observation_shape) 77 | # Stores the infos returned by the environment. For debugging and 78 | # visualization purposes. 79 | self._info_memory = [None] * self._capacity 80 | self._memory_age = np.zeros([self._capacity], dtype=np.int32) 81 | 82 | @property 83 | def capacity(self): 84 | return self._capacity 85 | 86 | def __len__(self): 87 | return min(self._count, self._capacity) 88 | 89 | @property 90 | def info_memory(self): 91 | return self._info_memory 92 | 93 | def add(self, observation, info): 94 | """Adds an observation to the memory. 95 | 96 | Args: 97 | observation: Observation to add to the episodic memory. 98 | info: Info returned by the environment together with the observation, 99 | for debugging and visualization purposes. 100 | 101 | Raises: 102 | ValueError: when the capacity of the memory is exceeded. 103 | """ 104 | if self._count >= self._capacity: 105 | if self._replacement == 'random': 106 | # By using random replacement, the age of elements inside the memory 107 | # follows a geometric distribution (more fresh samples compared to 108 | # old samples). 109 | index = np.random.randint(low=0, high=self._capacity) 110 | elif self._replacement == 'fifo': 111 | # In this scheme, only the last self._capacity elements are kept. 112 | # Samples are replaced using a FIFO scheme (implemented as a circular 113 | # buffer). 114 | index = self._count % self._capacity 115 | else: 116 | raise ValueError('Invalid replacement scheme') 117 | else: 118 | index = self._count 119 | 120 | self._obs_memory[index] = observation 121 | self._info_memory[index] = info 122 | self._memory_age[index] = self._count 123 | self._count += 1 124 | 125 | def similarity(self, observation): 126 | """Similarity between the input observation and the ones from the memory. 127 | 128 | Args: 129 | observation: The input observation. 130 | 131 | Returns: 132 | A numpy array of similarities corresponding to the similarity between 133 | the input and each of the element in the memory. 134 | """ 135 | # Make the observation batched with batch_size = self._size before 136 | # computing the similarities. 137 | # TODO(damienv): could we avoid replicating the observation ? 138 | # (with some form of broadcasting). 
139 | size = len(self) 140 | observation = np.array([observation] * size) 141 | similarities = self._observation_compare_fn(observation, 142 | self._obs_memory[:size]) 143 | return similarities 144 | 145 | 146 | @gin.configurable 147 | def similarity_to_memory(observation, 148 | episodic_memory, 149 | similarity_aggregation='percentile'): 150 | """Returns the similarity of the observation to the episodic memory. 151 | 152 | Args: 153 | observation: The observation the agent transitions to. 154 | episodic_memory: Episodic memory. 155 | similarity_aggregation: Aggregation method to turn the multiple 156 | similarities to each observation in the memory into a scalar. 157 | 158 | Returns: 159 | A scalar corresponding to the similarity to episodic memory. This is 160 | computed by aggregating the similarities between the new observation 161 | and every observation in the memory, according to 'similarity_aggregation'. 162 | """ 163 | # Computes the similarities between the current observation and the past 164 | # observations in the memory. 165 | memory_length = len(episodic_memory) 166 | if memory_length == 0: 167 | return 0.0 168 | similarities = episodic_memory.similarity(observation) 169 | # Implements different surrogate aggregated similarities. 170 | # TODO(damienv): Implement other types of surrogate aggregated similarities. 171 | if similarity_aggregation == 'max': 172 | aggregated = np.max(similarities) 173 | elif similarity_aggregation == 'nth_largest': 174 | n = min(10, memory_length) 175 | aggregated = np.partition(similarities, -n)[-n] 176 | elif similarity_aggregation == 'percentile': 177 | percentile = 90 178 | aggregated = np.percentile(similarities, percentile) 179 | elif similarity_aggregation == 'relative_count': 180 | # Number of samples in the memory similar to the input observation. 181 | count = sum(similarities > 0.5) 182 | aggregated = float(count) / len(similarities) 183 | 184 | return aggregated 185 | -------------------------------------------------------------------------------- /episodic_curiosity/episodic_memory_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Test of episodic_memory.py.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | from episodic_curiosity import episodic_memory 23 | import numpy as np 24 | import tensorflow as tf 25 | 26 | 27 | def embedding_similarity(x1, x2): 28 | assert x1.shape[0] == x2.shape[0] 29 | epsilon = 1e-6 30 | 31 | # Inner product between the embeddings in x1 32 | # and the embeddings in x2. 
33 | s = np.sum(x1 * x2, axis=-1) 34 | 35 | s /= np.linalg.norm(x1, axis=-1) * np.linalg.norm(x2, axis=-1) + epsilon 36 | return 0.5 * (s + 1.0) 37 | 38 | 39 | class EpisodicMemoryTest(tf.test.TestCase): 40 | 41 | def RunTest(self, memory, observation_shape, add_count): 42 | expected_size = min(add_count, memory.capacity) 43 | 44 | for _ in range(add_count): 45 | observation = np.random.normal(size=observation_shape) 46 | memory.add(observation, dict()) 47 | self.assertEqual(expected_size, len(memory)) 48 | 49 | current_observation = np.random.normal(size=observation_shape) 50 | similarities = memory.similarity(current_observation) 51 | self.assertEqual(expected_size, len(similarities)) 52 | self.assertAllLessEqual(similarities, 1.0) 53 | self.assertAllGreaterEqual(similarities, 0.0) 54 | 55 | def testEpisodicMemory(self): 56 | observation_shape = [9] 57 | memory = episodic_memory.EpisodicMemory( 58 | observation_shape=observation_shape, 59 | observation_compare_fn=embedding_similarity, 60 | capacity=150) 61 | 62 | self.RunTest(memory, 63 | observation_shape, 64 | add_count=100) 65 | memory.reset() 66 | 67 | self.RunTest(memory, 68 | observation_shape, 69 | add_count=200) 70 | memory.reset() 71 | 72 | 73 | if __name__ == '__main__': 74 | tf.test.main() 75 | -------------------------------------------------------------------------------- /episodic_curiosity/eval_policy.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Evaluation of a policy on a GYM environment.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import gym 23 | import numpy as np 24 | 25 | 26 | class DummyVideoWriter(object): 27 | 28 | def add(self, obs): 29 | pass 30 | 31 | def close(self, filename): 32 | pass 33 | 34 | 35 | class PolicyEvaluator(object): 36 | """Evaluate a policy on a GYM environment.""" 37 | 38 | def __init__(self, vec_env, 39 | metric_callback=None, 40 | video_filename=None, grayscale=False, 41 | eval_frequency=25): 42 | """New policy evaluator. 43 | 44 | Args: 45 | vec_env: baselines.VecEnv correspond to a vector of GYM environments. 46 | metric_callback: Function that is given the average reward and the time 47 | step of the evaluation. 48 | video_filename: Prefix of filenames used to record video. 49 | grayscale: Whether the observation is grayscale or color. 50 | eval_frequency: Only performs evaluation once every eval_frequency times. 
51 | """ 52 | self._vec_env = vec_env 53 | self._metric_callback = metric_callback 54 | self._video_filename = video_filename 55 | self._grayscale = grayscale 56 | 57 | self._eval_count = 0 58 | self._eval_frequency = eval_frequency 59 | self._discrete_actions = isinstance(self._vec_env.observation_space, 60 | gym.spaces.Discrete) 61 | 62 | def evaluate(self, model_step_fn, global_step): 63 | """Evaluate the policy as given by its step function. 64 | 65 | Args: 66 | model_step_fn: Function which given a batch of observations, 67 | a batch of policy states and a batch of dones flags returns 68 | a batch of selected actions and updated policy states. 69 | global_step: The global step of the training process. 70 | """ 71 | if self._eval_count % self._eval_frequency != 0: 72 | self._eval_count += 1 73 | return 74 | self._eval_count += 1 75 | 76 | video_writer = DummyVideoWriter() 77 | video_writer2 = DummyVideoWriter() 78 | has_alternative_video = False 79 | if self._video_filename: 80 | video_filename = '{}_{}.mp4'.format(self._video_filename, global_step) 81 | video_filename2 = '{}_{}_v2.mp4'.format(self._video_filename, global_step) 82 | else: 83 | video_filename = 'dummy.mp4' 84 | video_filename2 = 'dummy2.mp4' 85 | 86 | # Initial state of the policy. 87 | # TODO(damienv): make the policy state dimension part of the constructor. 88 | policy_state_dim = 512 89 | policy_states = np.zeros((self._vec_env.num_envs, policy_state_dim), 90 | dtype=np.float32) 91 | 92 | # Reset the environments before starting the evaluation. 93 | dones = [False] * self._vec_env.num_envs 94 | sticky_dones = [False] * self._vec_env.num_envs 95 | obs = self._vec_env.reset() 96 | 97 | # Evaluation loop. 98 | total_reward = np.zeros((self._vec_env.num_envs,), dtype=np.float32) 99 | step_iter = 0 100 | action_distribution = {} 101 | while not all(sticky_dones): 102 | actions, _, policy_states, _ = model_step_fn(obs, policy_states, dones) 103 | 104 | # Update the distribution of actions seen along the trajectory. 105 | if self._discrete_actions: 106 | for action in actions: 107 | if action not in action_distribution: 108 | action_distribution[action] = 0 109 | action_distribution[action] += 1 110 | 111 | # Update the states of the environment based on the selected actions. 112 | obs, rewards, dones, infos = self._vec_env.step(actions) 113 | step_iter += 1 114 | for k in range(self._vec_env.num_envs,): 115 | if not sticky_dones[k]: 116 | total_reward[k] += rewards[k] 117 | sticky_dones = [sd or d for (sd, d) in zip(sticky_dones, dones)] 118 | 119 | # Optionally record the frames of the 1st environment. 
120 | if not sticky_dones[0]: 121 | if infos[0].get('frame') is not None: 122 | frame = infos[0]['frame'] 123 | else: 124 | frame = obs[0] 125 | if self._grayscale: 126 | video_writer.add(frame[:, :, 0]) 127 | else: 128 | video_writer.add(frame) 129 | if infos[0].get('frame:track') is not None: 130 | has_alternative_video = True 131 | frame = infos[0]['frame:track'] 132 | if self._grayscale: 133 | video_writer2.add(frame[:, :, 0]) 134 | else: 135 | video_writer2.add(frame) 136 | 137 | if self._metric_callback: 138 | self._metric_callback(np.mean(total_reward), global_step) 139 | 140 | print('Average reward: {}, total reward: {}'.format(np.mean(total_reward), 141 | total_reward)) 142 | if self._discrete_actions: 143 | print('Action distribution: {}'.format(action_distribution)) 144 | video_writer.close(video_filename) 145 | if has_alternative_video: 146 | video_writer2.close(video_filename2) 147 | -------------------------------------------------------------------------------- /episodic_curiosity/generate_r_training_data_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for generate_r_training_data.""" 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from absl.testing import absltest 22 | from episodic_curiosity.environments import fake_gym_env 23 | from episodic_curiosity.generate_r_training_data import generate_random_episode_buffer 24 | 25 | 26 | class TestGenerateRTrainingData(absltest.TestCase): 27 | 28 | def setUp(self): 29 | # Fake Environment with an infinite length episode. 30 | self.env = fake_gym_env.FakeGymEnv() 31 | 32 | def test_generate_random_episode(self): 33 | for _ in range(2): 34 | episode_buffer = generate_random_episode_buffer(self.env) 35 | self.assertEqual(len(episode_buffer), self.env.EPISODE_LENGTH) 36 | self.assertTupleEqual(episode_buffer[0][0].shape, 37 | self.env.OBSERVATION_SHAPE) 38 | 39 | 40 | if __name__ == '__main__': 41 | absltest.main() 42 | -------------------------------------------------------------------------------- /episodic_curiosity/keras_checkpoint.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Keras checkpointing using GFile.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | 21 | from __future__ import print_function 22 | 23 | import os 24 | import tempfile 25 | import time 26 | 27 | from absl import logging 28 | import tensorflow as tf 29 | from tensorflow import keras 30 | 31 | 32 | # Taken from: dos/ml/lstm/train_util.py, but also supports unformatted strings 33 | # and writing summary files. 34 | class GFileModelCheckpoint(keras.callbacks.ModelCheckpoint): 35 | """Keras callback to checkpoint model to a gfile location. 36 | 37 | Makes the keras ModelCheckpoint callback compatible with google filesystem 38 | paths, such as CNS files. 39 | Models will be saved to tmp_file_path and copied from there to file_path. 40 | Also writes a summary file with model performance along with the checkpoint. 41 | """ 42 | 43 | def __init__(self, 44 | file_path, 45 | save_summary, 46 | summary = None, 47 | *args, 48 | **kwargs): # pylint: disable=keyword-arg-before-vararg 49 | """Initializes checkpointer with appropriate filepaths. 50 | 51 | Args: 52 | file_path: gfile location to save model to. Supports unformatted strings 53 | similarly to keras ModelCheckpoint. 54 | save_summary: Whether we should generate and save a summary file. 55 | summary: Additional items to write to the summary file. 56 | *args: positional args passed to the underlying ModelCheckpoint. 57 | **kwargs: named args passed to the underlying ModelCheckpoint. 58 | """ 59 | self.save_summary = save_summary 60 | self.summary = summary 61 | # We assume that this directory is not used by anybody else, so we uniquify 62 | # it (a bit overkill, but hey). 63 | self.tmp_dir = os.path.join( 64 | tempfile.gettempdir(), 65 | 'tmp_keras_weights_%d_%d' % (int(time.time() * 1e6), id(self))) 66 | tf.gfile.MakeDirs(self.tmp_dir) 67 | self.tmp_path = os.path.join(self.tmp_dir, os.path.basename(file_path)) 68 | self.gfile_dir = os.path.dirname(file_path) 69 | super(GFileModelCheckpoint, self).__init__(self.tmp_path, *args, **kwargs) 70 | 71 | def on_epoch_end(self, epoch, logs = None): 72 | """At end of epoch, performs the gfile checkpointing.""" 73 | super(GFileModelCheckpoint, self).on_epoch_end(epoch, logs=None) 74 | if self.epochs_since_last_save == 0: # ModelCheckpoint just saved 75 | tmp_dir_contents = tf.gfile.ListDirectory(self.tmp_dir) 76 | for tmp_weights_filename in tmp_dir_contents: 77 | src = os.path.join(self.tmp_dir, tmp_weights_filename) 78 | dst = os.path.join(self.gfile_dir, tmp_weights_filename) 79 | logging.info('Copying saved keras model weights from %s to %s', src, 80 | dst) 81 | tf.gfile.Copy(src, dst, overwrite=True) 82 | tf.gfile.Remove(src) 83 | if self.save_summary: 84 | merged_summary = {} 85 | merged_summary.update(self.summary) 86 | if logs: 87 | merged_summary.update(logs) 88 | with tf.gfile.Open(dst.replace('.h5', '.summary.txt'), 89 | 'w') as summary_file: 90 | summary_file.write('\n'.join( 91 | ['{}: {}'.format(k, v) for k, v in merged_summary.items()])) 92 | -------------------------------------------------------------------------------- /episodic_curiosity/logging.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Logging for episodic curiosity.""" 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | import datetime 21 | import numpy as np 22 | 23 | 24 | class VideoWriter(object): 25 | """Wrapper around video writer APIs.""" 26 | 27 | def __init__(self, filename): 28 | # We don't have this library inside Google, so we only import it in the 29 | # Open-source version when we need it, and tell pytype to ignore the 30 | # fact that it's missing. 31 | # pylint: disable=g-import-not-at-top 32 | import skvideo.io # type: ignore 33 | self._writer = skvideo.io.FFmpegWriter(filename) 34 | 35 | def add(self, frame): 36 | self._writer.writeFrame(frame) 37 | 38 | def close(self): 39 | self._writer.close() 40 | 41 | 42 | 43 | 44 | def get_video_writer(video_filename): 45 | return VideoWriter(video_filename) # pylint:disable=unreachable 46 | 47 | 48 | def save_episode_buffer_as_video(episode_buffer, video_filename): 49 | """Saves episode_buffer.""" 50 | video_writer = get_video_writer(video_filename) 51 | for frame in episode_buffer: 52 | video_writer.add(frame) 53 | video_writer.close() 54 | 55 | 56 | def save_training_examples_as_video(training_examples, video_filename): 57 | """Split example into two images and show side-by-side for a while.""" 58 | video_writer = get_video_writer(video_filename) 59 | for example in training_examples: 60 | first = example[Ellipsis, :3] 61 | second = example[Ellipsis, 3:] 62 | side_by_side = np.concatenate((first, second), axis=0) 63 | video_writer.add(side_by_side) 64 | video_writer.close() 65 | 66 | 67 | def get_logger_dir(exp_id): 68 | return datetime.datetime.now().strftime('ec-%Y-%m-%d-%H-%M-%S-%f_') + exp_id 69 | -------------------------------------------------------------------------------- /episodic_curiosity/oracle.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Computes some oracle reward based on the actual agent position.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import gin 23 | 24 | 25 | @gin.configurable 26 | class OracleExplorationReward(object): 27 | """Class that computes the ideal exploration bonus.""" 28 | 29 | def __init__(self, reward_grid_size=30.0, cell_reward_normalizer=900.0): 30 | """Creates a new oracle to compute the exploration reward. 
31 | 32 | Args: 33 | reward_grid_size: Size of a cell that contains a unique reward. 34 | cell_reward_normalizer: Denominator for computation of a cell reward 35 | """ 36 | self._reward_grid_size = reward_grid_size 37 | 38 | # Make the total sum of exploration reward that can be collected 39 | # independent of the grid size. 40 | # Here, we assume that the position is laying on a 2D manifold, 41 | # hence the multiplication by the area of a 2D cell. 42 | self._cell_reward = float(reward_grid_size * reward_grid_size) 43 | 44 | # Somewhat normalize the exploration reward so that it is neither 45 | # too big or too small. 46 | self._cell_reward /= cell_reward_normalizer 47 | 48 | self.reset() 49 | 50 | def reset(self): 51 | self._collected_positions = set() 52 | 53 | def update_position(self, agent_position): 54 | """Set the new state (i.e. the position). 55 | 56 | Args: 57 | agent_position: x,y,z position of the agent. 58 | 59 | Returns: 60 | The exploration bonus for having visited this position. 61 | """ 62 | x, y, z = agent_position 63 | quantized_x = int(x / self._reward_grid_size) 64 | quantized_y = int(y / self._reward_grid_size) 65 | quantized_z = int(z / self._reward_grid_size) 66 | position_id = (quantized_x, quantized_y, quantized_z) 67 | if position_id in self._collected_positions: 68 | # No reward if the position has already been explored. 69 | return 0.0 70 | else: 71 | self._collected_positions.add(position_id) 72 | return self._cell_reward 73 | -------------------------------------------------------------------------------- /episodic_curiosity/r_network.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """R-network and some related functions to train R-networks.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | 21 | from __future__ import print_function 22 | 23 | import tempfile 24 | 25 | from absl import logging 26 | from third_party.keras_resnet import models 27 | import numpy as np 28 | import tensorflow as tf 29 | from tensorflow import keras 30 | 31 | 32 | class RNetwork(object): 33 | """Encapsulates a trained R network, with lazy loading of weights.""" 34 | 35 | def __init__(self, input_shape, weight_path): 36 | """Inits the RNetwork. 37 | 38 | Args: 39 | input_shape: (height, width, channel) 40 | weight_path: Path to the weights of the r_network. 41 | """ 42 | self._weight_path = weight_path 43 | (self._r_network, self._embedding_network, 44 | self._similarity_network) = models.ResnetBuilder.build_siamese_resnet_18( 45 | input_shape) 46 | self._r_network.compile( 47 | loss='categorical_crossentropy', optimizer=keras.optimizers.Adam()) 48 | self._weights_loaded = False 49 | 50 | def _maybe_load_weights(self): 51 | """Loads R-network weights if needed. 52 | 53 | The RNetwork is used together with an environment used by ppo2.learn. 
54 | Unfortunately, ppo2.learn initializes all global TF variables at the 55 | beginning of the training, which, in particular, random-initializes the 56 | weights of the R Network. We therefore load the weights lazily, to make sure 57 | they are loaded after the global initialization happens in ppo2.learn. 58 | """ 59 | if self._weights_loaded: 60 | return 61 | if self._weight_path is None: 62 | # Typically the case when doing online training of the R-network. 63 | return 64 | # Keras does not support reading weights from CNS, so we have to copy the 65 | # weights to a temporary local file. 66 | with tempfile.NamedTemporaryFile(prefix='r_net', suffix='.h5', 67 | delete=False) as tmp_file: 68 | tmp_path = tmp_file.name 69 | tf.gfile.Copy(self._weight_path, tmp_path, overwrite=True) 70 | logging.info('Loading weights from %s...', tmp_path) 71 | print('Loading into R network:') 72 | self._r_network.summary() 73 | self._r_network.load_weights(tmp_path) 74 | tf.gfile.Remove(tmp_path) 75 | self._weights_loaded = True 76 | 77 | def embed_observation(self, x): 78 | """Embeds an observation. 79 | 80 | Args: 81 | x: batched input observations. Expected to have the shape specified when 82 | the RNetwork was constructed (plus the batch dimension as first dim). 83 | 84 | Returns: 85 | embedding, shape [batch, models.EMBEDDING_DIM] 86 | """ 87 | self._maybe_load_weights() 88 | return self._embedding_network.predict(x) 89 | 90 | def embedding_similarity(self, x, y): 91 | """Computes the similarity between two embeddings. 92 | 93 | Args: 94 | x: batch of the first embedding. Shape [batch, models.EMBEDDING_DIM]. 95 | y: batch of the second embedding. Shape [batch, models.EMBEDDING_DIM]. 96 | 97 | Returns: 98 | Similarity probabilities. 1 means very similar according to the net. 99 | 0 means very dissimilar. Shape [batch]. 100 | """ 101 | self._maybe_load_weights() 102 | return self._similarity_network.predict([x, y], 103 | batch_size=1024)[:, 1] 104 | -------------------------------------------------------------------------------- /episodic_curiosity/r_network_training_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests for r_network_training.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import random 23 | 24 | from absl.testing import absltest 25 | from episodic_curiosity import r_network_training 26 | from episodic_curiosity.constants import Const 27 | from episodic_curiosity.r_network_training import create_training_data_from_episode_buffer_v123 28 | from episodic_curiosity.r_network_training import create_training_data_from_episode_buffer_v4 29 | from episodic_curiosity.r_network_training import generate_negative_example 30 | import numpy as np 31 | 32 | 33 | _OBSERVATION_SHAPE = (120, 160, 3) 34 | 35 | 36 | class TestRNetworkTraining(absltest.TestCase): 37 | 38 | def test_generate_negative_example(self): 39 | max_action_distance = 5 40 | len_episode_buffer = 100 41 | for _ in range(1000): 42 | buffer_position = np.random.randint(low=0, high=len_episode_buffer) 43 | first, second = generate_negative_example(buffer_position, 44 | len_episode_buffer, 45 | max_action_distance) 46 | self.assertGreater( 47 | abs(second - first), 48 | Const.NEGATIVE_SAMPLE_MULTIPLIER * max_action_distance) 49 | self.assertGreaterEqual(second, 0) 50 | self.assertLess(second, len_episode_buffer) 51 | 52 | def test_generate_negative_example2(self): 53 | max_action_distance = 5 54 | range_max = 5 * Const.NEGATIVE_SAMPLE_MULTIPLIER * max_action_distance 55 | for buffer_position in range(0, range_max): 56 | for buffer_length in range(buffer_position + 1, 3 * range_max): 57 | # Mainly check that it does not raise an exception. 58 | new_pos, index = generate_negative_example(buffer_position, 59 | buffer_length, 60 | max_action_distance) 61 | msg = 'buffer_pos={}, buffer_length={}'.format(buffer_position, 62 | buffer_length) 63 | self.assertLess(new_pos, buffer_length, msg) 64 | self.assertGreaterEqual(new_pos, 0, msg) 65 | if index is not None: 66 | self.assertLess(index, buffer_length, msg) 67 | self.assertGreaterEqual(index, 0, msg) 68 | 69 | def test_create_training_data_from_episode_buffer_v1(self): 70 | # Make the test deterministic. 71 | random.seed(7) 72 | max_action_distance_mode = 'v1_affect_num_training_examples' 73 | max_action_distance = 5 74 | len_episode_buffer = 1000 75 | episode_buffer = [(np.zeros(_OBSERVATION_SHAPE), { 76 | 'position': i 77 | }) for i in range(len_episode_buffer)] 78 | x1, x2, labels = create_training_data_from_episode_buffer_v123( 79 | episode_buffer, max_action_distance, max_action_distance_mode) 80 | self.assertEqual(len(x1), len(x2)) 81 | self.assertEqual(len(labels), len(x1)) 82 | # 4 is the average of 1 + randint(1, 5). 
83 | self.assertGreater(len(x1), len_episode_buffer / 4 * 0.8) 84 | self.assertLess(len(x1), len_episode_buffer / 4 * 1.2) 85 | for x in [x1, x2]: 86 | self.assertTupleEqual(x[0][0].shape, _OBSERVATION_SHAPE) 87 | max_previous_pos = -1 88 | for xx1, xx2, label in zip(x1, x2, labels): 89 | if not label: 90 | continue 91 | obs1, info1 = xx1 92 | del obs1 # unused 93 | pos1 = info1['position'] 94 | obs2, info2 = xx2 95 | del obs2 # unused 96 | pos2 = info2['position'] 97 | assert max_previous_pos < pos1, ( 98 | 'In v1 mode, intervals of positive examples should have no overlap.') 99 | assert max_previous_pos < pos2, ( 100 | 'In v1 mode, intervals of positive examples should have no overlap.') 101 | max_previous_pos = max(pos1, pos2) 102 | self._check_example_pairs(x1, x2, labels, max_action_distance) 103 | 104 | def test_create_training_data_from_episode_buffer_v2(self): 105 | # Make the test deterministic. 106 | random.seed(7) 107 | max_action_distance_mode = 'v2_fixed_num_training_examples' 108 | max_action_distance = 2 109 | len_episode_buffer = 1000 110 | episode_buffer = [(np.zeros(_OBSERVATION_SHAPE), { 111 | 'position': i 112 | }) for i in range(len_episode_buffer)] 113 | x1, x2, labels = create_training_data_from_episode_buffer_v123( 114 | episode_buffer, max_action_distance, max_action_distance_mode) 115 | self.assertEqual(len(x1), len(x2)) 116 | self.assertEqual(len(labels), len(x1)) 117 | # 4 is the average of 1 + randint(1, 5), where 5 is used in the v2 sampling 118 | # algorithm in order to keep the same number of training example regardless 119 | # of max_action_distance. 120 | self.assertGreater(len(x1), len_episode_buffer / 4 * 0.8) 121 | self.assertLess(len(x1), len_episode_buffer / 4 * 1.2) 122 | self._check_example_pairs(x1, x2, labels, max_action_distance) 123 | 124 | def test_create_training_data_from_episode_buffer_v4(self): 125 | # Make the test deterministic. 126 | random.seed(7) 127 | for avg_num_examples_per_env_step in (0.2, 0.5, 1, 2): 128 | max_action_distance = 5 129 | len_episode_buffer = 1000 130 | episode_buffer = [(np.zeros(_OBSERVATION_SHAPE), { 131 | 'position': i 132 | }) for i in range(len_episode_buffer)] 133 | x1, x2, labels = create_training_data_from_episode_buffer_v4( 134 | episode_buffer, max_action_distance, avg_num_examples_per_env_step) 135 | self.assertEqual(len(x1), len(x2)) 136 | self.assertEqual(len(labels), len(x1)) 137 | num_positives = len([label for label in labels if label == 1]) 138 | self._assert_within(num_positives, len(x1) // 2, 5) 139 | self._assert_within( 140 | len(labels), 141 | len(episode_buffer) * avg_num_examples_per_env_step, 5) 142 | self._check_example_pairs(x1, x2, labels, max_action_distance) 143 | 144 | def test_create_training_data_from_episode_buffer_too_short(self): 145 | max_action_distance = 5 146 | buff = [np.zeros(_OBSERVATION_SHAPE)] * max_action_distance 147 | # Repeat the test multiple times. create_training_data_from_episode_buffer 148 | # uses randomness, so we want to hit the case where it tries to generate 149 | # negative examples. 150 | for _ in range(50): 151 | _, _, labels = create_training_data_from_episode_buffer_v123( 152 | buff, max_action_distance, mode='v1_affect_num_training_examples') 153 | for label in labels: 154 | # Not enough buffer to generate negative examples, so we should get only 155 | # (but possibly none) positive examples. 
156 | self.assertEqual(label, 1) 157 | 158 | def _check_example_pairs(self, x1, x2, labels, max_action_distance): 159 | for xx1, xx2, label in zip(x1, x2, labels): 160 | obs1, info1 = xx1 161 | del obs1 # unused 162 | pos1 = info1['position'] 163 | obs2, info2 = xx2 164 | del obs2 # unused 165 | pos2 = info2['position'] 166 | diff = abs(pos1 - pos2) 167 | if label: 168 | self.assertLessEqual(diff, max_action_distance) 169 | else: 170 | self.assertGreater( 171 | diff, Const.NEGATIVE_SAMPLE_MULTIPLIER * max_action_distance) 172 | 173 | def _assert_within(self, value, expected, percentage): 174 | self.assertGreater( 175 | value, 176 | expected * (100 - percentage) / 100, '{} vs {} within {}%'.format( 177 | value, expected, percentage)) 178 | self.assertLess(value, 179 | expected * (100 + percentage) / 100, 180 | '{} vs {} within {}%'.format(value, expected, percentage)) 181 | 182 | def fit_generator(self, 183 | batch_gen, steps_per_epoch, epochs, 184 | validation_data): # pylint: disable=unused-argument 185 | max_distance = 5 186 | for _ in range(epochs): 187 | for _ in range(steps_per_epoch): 188 | pairs, labels = next(batch_gen) 189 | 190 | # Make sure all the pairs are coming from the same trajectory / env 191 | # and that they are compatible with the given labels. 192 | batch_x1 = pairs[0] 193 | batch_x2 = pairs[1] 194 | batch_size = batch_x1.shape[0] 195 | assert batch_x2.shape[0] == batch_size 196 | for k in range(batch_size): 197 | x1 = batch_x1[k] 198 | x2 = batch_x2[k] 199 | env_x1, trajectory_x1, step_x1 = x1 200 | env_x2, trajectory_x2, step_x2 = x2 201 | self.assertEqual(env_x1, env_x2) 202 | self.assertEqual(trajectory_x1, trajectory_x2) 203 | 204 | distance = abs(step_x2 - step_x1) 205 | if labels[k][1]: 206 | self.assertLessEqual(distance, max_distance) 207 | else: 208 | self.assertGreater(distance, max_distance) 209 | 210 | def test_r_network_training(self): 211 | r_model = self 212 | r_trainer = r_network_training.RNetworkTrainer( 213 | r_model, 214 | observation_history_size=10000, 215 | training_interval=20000) 216 | 217 | # Every observation is a vector of dimension 3: 218 | # (worker_id, trajectory_id, step_num) 219 | feed_count = 20000 220 | env_count = 8 221 | worker_id = list(range(env_count)) 222 | trajectory_id = [0] * env_count 223 | step_idx = [0] * env_count 224 | proba_done = 0.01 225 | for _ in range(feed_count): 226 | # Observations: size = env_count x 3 227 | observations = np.stack([worker_id, trajectory_id, step_idx]) 228 | observations = np.transpose(observations) 229 | dones = np.random.choice([True, False], size=env_count, 230 | p=[proba_done, 1.0 - proba_done]) 231 | r_trainer.on_new_observation(observations, None, dones, None) 232 | 233 | step_idx = [s + 1 for s in step_idx] 234 | 235 | # Update the trajectory index and the environment step. 236 | for k in range(env_count): 237 | if dones[k]: 238 | step_idx[k] = 0 239 | trajectory_id[k] += 1 240 | 241 | 242 | if __name__ == '__main__': 243 | absltest.main() 244 | -------------------------------------------------------------------------------- /episodic_curiosity/train_policy.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | r"""Main file for training policies. 17 | 18 | Many hyperparameters need to be passed through gin flags. 19 | Consider using scripts/launcher_script.py to invoke train_policy with the 20 | right hyperparameters and flags. 21 | """ 22 | 23 | from __future__ import absolute_import 24 | from __future__ import division 25 | from __future__ import print_function 26 | 27 | import os 28 | import tempfile 29 | import time 30 | 31 | from absl import flags 32 | from episodic_curiosity import env_factory 33 | from episodic_curiosity import eval_policy 34 | from episodic_curiosity import utils 35 | from third_party.baselines import logger 36 | from third_party.baselines.ppo2 import policies 37 | from third_party.baselines.ppo2 import ppo2 38 | import gin 39 | import tensorflow as tf 40 | 41 | 42 | flags.DEFINE_string('workdir', None, 43 | 'Root directory for writing logs/summaries/checkpoints.') 44 | flags.DEFINE_string('env_name', 'CartPole-v0', 'What environment to run') 45 | flags.DEFINE_string('policy_architecture', 'cnn', 46 | 'What model architecture to use') 47 | flags.DEFINE_string('r_checkpoint', '', 'Location of the R-network checkpoint') 48 | flags.DEFINE_integer('num_env', 12, 'Number of environment copies to run in ' 49 | 'subprocesses.') 50 | flags.DEFINE_string('dmlab_homepath', '', '') 51 | flags.DEFINE_integer('num_timesteps', 10000000, 'Number of frames to run ' 52 | 'training for.') 53 | flags.DEFINE_string('action_set', '', 54 | '(small|nofire|) - which action set to use') 55 | flags.DEFINE_bool('use_curiosity', False, 56 | 'Whether to enable Pathak\'s curiosity') 57 | flags.DEFINE_bool('random_state_predictor', False, 58 | 'Whether to use random state predictor for Pathak\'s ' 59 | 'curiosity') 60 | flags.DEFINE_float('curiosity_strength', 0.01, 61 | 'Strength of the intrinsic reward in Pathak\'s algorithm.') 62 | flags.DEFINE_float('forward_inverse_ratio', 0.2, 63 | 'Weighting of forward vs inverse loss in Pathak\'s ' 64 | 'algorithm') 65 | flags.DEFINE_float('curiosity_loss_strength', 10, 66 | 'Weight of the curiosity loss in Pathak\'s algorithm.') 67 | 68 | 69 | # pylint: disable=g-inconsistent-quotes 70 | flags.DEFINE_multi_string( 71 | 'gin_files', [], 'List of paths to gin configuration files') 72 | flags.DEFINE_multi_string( 73 | 'gin_bindings', [], 74 | 'Gin bindings to override the values set in the config files ' 75 | '(e.g. 
"DQNAgent.epsilon_train=0.1",' 76 | ' "create_environment.game_name="Pong"").') 77 | # pylint: enable=g-inconsistent-quotes 78 | 79 | FLAGS = flags.FLAGS 80 | 81 | 82 | def get_environment(env_name): 83 | dmlab_prefix = 'dmlab:' 84 | atari_prefix = 'atari:' 85 | parkour_prefix = 'parkour:' 86 | if env_name.startswith(dmlab_prefix): 87 | level_name = env_name[len(dmlab_prefix):] 88 | return env_factory.create_environments( 89 | level_name, 90 | FLAGS.num_env, 91 | FLAGS.r_checkpoint, 92 | FLAGS.dmlab_homepath, 93 | action_set=FLAGS.action_set, 94 | r_network_weights_store_path=FLAGS.workdir) 95 | elif env_name.startswith(atari_prefix): 96 | level_name = env_name[len(atari_prefix):] 97 | return env_factory.create_environments( 98 | level_name, 99 | FLAGS.num_env, 100 | FLAGS.r_checkpoint, 101 | environment_engine='atari', 102 | r_network_weights_store_path=FLAGS.workdir) 103 | if env_name.startswith(parkour_prefix): 104 | return env_factory.create_environments( 105 | env_name[len(parkour_prefix):], 106 | FLAGS.num_env, 107 | FLAGS.r_checkpoint, 108 | environment_engine='parkour', 109 | r_network_weights_store_path=FLAGS.workdir) 110 | raise ValueError('Unknown environment: {}'.format(env_name)) 111 | 112 | 113 | @gin.configurable 114 | def train(workdir, env_name, num_timesteps, 115 | nsteps=256, 116 | nminibatches=4, 117 | noptepochs=4, 118 | learning_rate=2.5e-4, 119 | ent_coef=0.01): 120 | """Runs PPO training. 121 | 122 | Args: 123 | workdir: where to store experiment results/logs 124 | env_name: the name of the environment to run 125 | num_timesteps: for how many timesteps to run training 126 | nsteps: Number of consecutive environment steps to use during training. 127 | nminibatches: Minibatch size. 128 | noptepochs: Number of optimization epochs. 129 | learning_rate: Initial learning rate. 130 | ent_coef: Entropy coefficient. 131 | """ 132 | train_measurements = utils.create_measurement_series(workdir, 'reward_train') 133 | valid_measurements = utils.create_measurement_series(workdir, 'reward_valid') 134 | test_measurements = utils.create_measurement_series(workdir, 'reward_test') 135 | 136 | def measurement_callback(unused_eplenmean, eprewmean, global_step_val): 137 | if train_measurements: 138 | train_measurements.create_measurement( 139 | objective_value=eprewmean, step=global_step_val) 140 | 141 | def eval_callback_on_valid(eprewmean, global_step_val): 142 | if valid_measurements: 143 | valid_measurements.create_measurement( 144 | objective_value=eprewmean, step=global_step_val) 145 | 146 | def eval_callback_on_test(eprewmean, global_step_val): 147 | if test_measurements: 148 | test_measurements.create_measurement( 149 | objective_value=eprewmean, step=global_step_val) 150 | 151 | logger_dir = workdir 152 | logger.configure(logger_dir) 153 | 154 | env, valid_env, test_env = get_environment(env_name) 155 | is_ant = env_name.startswith('parkour:') 156 | 157 | # Validation metric. 158 | policy_evaluator_on_valid = eval_policy.PolicyEvaluator( 159 | valid_env, 160 | metric_callback=eval_callback_on_valid, 161 | video_filename=None) 162 | 163 | # Test metric (+ videos). 164 | video_filename = os.path.join(FLAGS.workdir, 'video') 165 | policy_evaluator_on_test = eval_policy.PolicyEvaluator( 166 | test_env, 167 | metric_callback=eval_callback_on_test, 168 | video_filename=video_filename, 169 | grayscale=(env_name.startswith('atari:'))) 170 | 171 | # Delay to make sure that all the DMLab environments acquire 172 | # the GPU resources before TensorFlow acquire the rest of the memory. 
173 | # TODO(damienv): Possibly use allow_grow in a TensorFlow session 174 | # so that there is no such problem anymore. 175 | time.sleep(15) 176 | 177 | cloud_sync_callback = lambda: None 178 | 179 | def evaluate_valid_test(model_step_fn, global_step): 180 | if not is_ant: 181 | policy_evaluator_on_valid.evaluate(model_step_fn, global_step) 182 | policy_evaluator_on_test.evaluate(model_step_fn, global_step) 183 | 184 | with tf.Session(): 185 | policy = {'cnn': policies.CnnPolicy, 186 | 'lstm': policies.LstmPolicy, 187 | 'lnlstm': policies.LnLstmPolicy, 188 | 'mlp': policies.MlpPolicy}[FLAGS.policy_architecture] 189 | 190 | # Openai baselines never performs num_timesteps env steps because 191 | # of the way it samples training data in batches. The number of timesteps 192 | # is multiplied by 1.1 (hacky) to insure at least num_timesteps are 193 | # performed. 194 | 195 | ppo2.learn(policy, env=env, nsteps=nsteps, nminibatches=nminibatches, 196 | lam=0.95, gamma=0.99, noptepochs=noptepochs, log_interval=1, 197 | ent_coef=ent_coef, 198 | lr=learning_rate if is_ant else lambda f: f * learning_rate, 199 | cliprange=0.2 if is_ant else lambda f: f * 0.1, 200 | total_timesteps=int(num_timesteps * 1.1), 201 | train_callback=measurement_callback, 202 | eval_callback=evaluate_valid_test, 203 | cloud_sync_callback=cloud_sync_callback, 204 | save_interval=200, workdir=workdir, 205 | use_curiosity=FLAGS.use_curiosity, 206 | curiosity_strength=FLAGS.curiosity_strength, 207 | forward_inverse_ratio=FLAGS.forward_inverse_ratio, 208 | curiosity_loss_strength=FLAGS.curiosity_loss_strength, 209 | random_state_predictor=FLAGS.random_state_predictor) 210 | cloud_sync_callback() 211 | test_env.close() 212 | valid_env.close() 213 | utils.maybe_close_measurements(train_measurements) 214 | utils.maybe_close_measurements(valid_measurements) 215 | utils.maybe_close_measurements(test_measurements) 216 | 217 | 218 | 219 | 220 | def main(_): 221 | utils.dump_flags_to_file(os.path.join(FLAGS.workdir, 'flags.txt')) 222 | tf.logging.set_verbosity(tf.logging.INFO) 223 | gin.parse_config_files_and_bindings(FLAGS.gin_files, 224 | FLAGS.gin_bindings) 225 | train(FLAGS.workdir, env_name=FLAGS.env_name, 226 | num_timesteps=FLAGS.num_timesteps) 227 | 228 | 229 | if __name__ == '__main__': 230 | tf.app.run() 231 | -------------------------------------------------------------------------------- /episodic_curiosity/train_r_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tests for dune.rl.episodic_curiosity.train_r.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | import os 22 | from absl import flags 23 | from absl.testing import absltest 24 | from episodic_curiosity import constants 25 | from episodic_curiosity import keras_checkpoint 26 | from episodic_curiosity import train_r 27 | import mock 28 | import numpy as np 29 | import tensorflow as tf 30 | from tensorflow import keras 31 | 32 | FLAGS = flags.FLAGS 33 | 34 | 35 | class TrainRTest(absltest.TestCase): 36 | 37 | def setUp(self): 38 | super(TrainRTest, self).setUp() 39 | keras.backend.clear_session() 40 | 41 | def test_export_stats_to_xm(self): 42 | xm_series = train_r.XmSeries( 43 | loss=mock.MagicMock(), 44 | acc=mock.MagicMock(), 45 | val_loss=mock.MagicMock(), 46 | val_acc=mock.MagicMock()) 47 | self._fit_model_with_callback(train_r.ExportStatsToXm(xm_series)) 48 | for series in xm_series._asdict().values(): 49 | self.assertEqual(series.create_measurement.call_count, 1) 50 | 51 | def test_model_weights_checkpoint(self): 52 | path = os.path.join(FLAGS.test_tmpdir, 'r_network_weights.{epoch:05d}.h5') 53 | self._fit_model_with_callback( 54 | keras_checkpoint.GFileModelCheckpoint( 55 | path, 56 | save_summary=True, 57 | summary=constants.Level('explore_goal_locations_small').asdict(), 58 | save_weights_only=True, 59 | period=1)) 60 | self.assertTrue(tf.gfile.Exists(path.format(epoch=1))) 61 | self.assertTrue( 62 | tf.gfile.Exists(path.format(epoch=1).replace('h5', 'summary.txt'))) 63 | 64 | def test_full_model_checkpoint(self): 65 | path = os.path.join(FLAGS.test_tmpdir, 'r_network_full.{epoch:05d}.h5') 66 | self._fit_model_with_callback( 67 | keras_checkpoint.GFileModelCheckpoint( 68 | path, save_summary=False, save_weights_only=False, period=1)) 69 | self.assertTrue(tf.gfile.Exists(path.format(epoch=1))) 70 | self.assertFalse( 71 | tf.gfile.Exists(path.format(epoch=1).replace('h5', 'summary.txt'))) 72 | 73 | def _fit_model_with_callback(self, callback): 74 | inpt = keras.layers.Input(shape=(1,)) 75 | model = keras.models.Model(inputs=inpt, outputs=keras.layers.Dense(1)(inpt)) 76 | model.compile( 77 | loss='binary_crossentropy', optimizer='sgd', metrics=['accuracy']) 78 | model.fit( 79 | x=np.random.rand(100, 1), 80 | y=np.random.rand(100, 1), 81 | validation_data=(np.random.rand(100, 1), np.random.rand(100, 1)), 82 | callbacks=[callback]) 83 | 84 | 85 | if __name__ == '__main__': 86 | absltest.main() 87 | -------------------------------------------------------------------------------- /episodic_curiosity/utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """A few utilities for episodic curiosity. 
17 | """ 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | 22 | from __future__ import print_function 23 | import csv 24 | import os 25 | import time 26 | from absl import flags 27 | import numpy as np 28 | import tensorflow as tf 29 | 30 | FLAGS = flags.FLAGS 31 | 32 | 33 | def get_frame(env_observation, info): 34 | """Searches for a rendered frame in 'info', fallbacks to the env obs.""" 35 | info_frame = info.get('frame') 36 | if info_frame is not None: 37 | return info_frame 38 | return env_observation 39 | 40 | 41 | def dump_flags_to_file(filename): 42 | """Dumps FLAGS to a file.""" 43 | with tf.gfile.Open(filename, 'w') as output: 44 | output.write('\n'.join([ 45 | '{}={}'.format(flag_name, flag_value) 46 | for flag_name, flag_value in FLAGS.flag_values_dict().items() 47 | ])) 48 | 49 | 50 | class MeasurementsWriter(object): 51 | """Writes measurements to CSV.""" 52 | 53 | def __init__(self, workdir, measurement_name): 54 | """Initializes a MeasurementsWriter. 55 | 56 | Args: 57 | workdir: Directory to which the CSV file will be written. 58 | measurement_name: Name of the measurement. 59 | """ 60 | filename = os.path.join(workdir, measurement_name + '.csv') 61 | file_exists = tf.gfile.Exists(filename) 62 | self._out_file = tf.gfile.Open(filename, mode='a+') 63 | self._csv_writer = csv.writer(self._out_file) 64 | if not file_exists: 65 | self._csv_writer.writerow(['step', measurement_name, 'timestamp_s']) 66 | self._out_file.flush() 67 | self._measurement_name = measurement_name 68 | self._last_flush_time = 0 69 | 70 | def create_measurement(self, objective_value, step): 71 | """Adds a measurement. 72 | 73 | Args: 74 | objective_value: Value to report for the given training step. 75 | step: Training step. 76 | """ 77 | flush_every_s = 5 78 | self._csv_writer.writerow( 79 | [str(step), str(objective_value), str(int(time.time()))]) 80 | if time.time() - self._last_flush_time >= flush_every_s: 81 | self._last_flush_time = time.time() 82 | self._out_file.flush() 83 | 84 | def close(self): 85 | del self._csv_writer 86 | self._out_file.close() 87 | 88 | 89 | def create_measurement_series(workdir, label): 90 | """Creates an object for exporting a per-training-step metric.""" 91 | return MeasurementsWriter(workdir, label) # pylint:disable=unreachable 92 | 93 | 94 | def maybe_close_measurements(measurements): 95 | if isinstance(measurements, MeasurementsWriter): 96 | measurements.close() 97 | 98 | 99 | def load_keras_model(path): 100 | """Loads a keras model from a h5 file path.""" 101 | # pylint:disable=unreachable 102 | return tf.keras.models.load_model(path, compile=True) 103 | -------------------------------------------------------------------------------- /misc/ant_github.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/episodic-curiosity/3c406964473d98fb977b1617a170a447b3c548fd/misc/ant_github.gif -------------------------------------------------------------------------------- /misc/navigation_github.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/episodic-curiosity/3c406964473d98fb977b1617a170a447b3c548fd/misc/navigation_github.gif -------------------------------------------------------------------------------- /scripts/gs_sync.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 Google LLC. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Script that periodically syncs a dir to a cloud storage bucket. 17 | 18 | We only use standard python deps so that this script can be executed in any 19 | setup. 20 | This script repeatedly calls the 'gsutil rsync' command. 21 | """ 22 | from __future__ import absolute_import 23 | from __future__ import division 24 | from __future__ import print_function 25 | 26 | import argparse 27 | import os 28 | import subprocess 29 | import time 30 | 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument('--workdir', help='Source to sync to the cloud bucket') 33 | parser.add_argument( 34 | '--sync_to_cloud_bucket', 35 | help='Cloud bucket (format gs://bucket_name) to sync the workdir to') 36 | FLAGS = parser.parse_args() 37 | 38 | 39 | def sync_to_cloud_bucket(): 40 | """Repeatedly syncs a path to a cloud bucket using gsutil.""" 41 | sync_cmd = ( 42 | 'gsutil -m rsync -r {src} {dst}'.format( 43 | src=os.path.expanduser(FLAGS.workdir), 44 | dst=os.path.join( 45 | FLAGS.sync_to_cloud_bucket, os.path.basename(FLAGS.workdir)))) 46 | while True: 47 | print('Syncing to cloud bucket:', sync_cmd) 48 | # We don't stop on failure, it can be transcient issue, a subsequent run 49 | # might work. 50 | subprocess.call(sync_cmd, shell=True) 51 | time.sleep(60) 52 | 53 | 54 | if __name__ == '__main__': 55 | sync_to_cloud_bucket() 56 | -------------------------------------------------------------------------------- /scripts/launch_cloud_vms.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """py3 script that creates all the VMs on GCP for episodic curiosity. 17 | 18 | Right now, this launches the experiments to reproduce table 1 of 19 | https://arxiv.org/pdf/1810.02274.pdf. 20 | 21 | This script only depends on the standard py libraries on purpose, so that it can 22 | be executed under any setup. 23 | 24 | The gcloud command should be accessible (see: 25 | https://cloud.google.com/sdk/gcloud/). 26 | 27 | Invoke this script at the root of episodic-curiosity: 28 | python3 scripts/launch_cloud_vms.py. 
29 | 30 | Tip: inspect the logs of the startup script on the VMs with: 31 | sudo journalctl -u google-startup-scripts.service 32 | """ 33 | 34 | from __future__ import absolute_import 35 | from __future__ import division 36 | from __future__ import print_function 37 | import argparse 38 | import subprocess 39 | import concurrent.futures 40 | 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument('--i_understand_launching_vms_is_expensive', 43 | action='store_true', 44 | help=('Each VM costs on the order of USD 30 per day based ' 45 | 'on early 2019 GCP pricing. This script launches ' 46 | 'many VMs at once, which can cost significant money. ' 47 | 'Pass this flag to show that you understood this.')) 48 | FLAGS = parser.parse_args() 49 | 50 | # Information about your GCP account/project: 51 | GCLOUD_PROJECT = 'FILL-ME' 52 | SERVICE_ACCOUNT = 'FILL-ME' 53 | ZONE = 'FILL-ME' 54 | # GCP snapshot with episodic-curiosity (and its dependencies) installed as 55 | # explained in README.md. 56 | SOURCE_SNAPSHOT = 'FILL-ME' 57 | # User to use on the VMs. 58 | VM_USER = 'FILL-ME' 59 | # Training logs and checkpoints will be synced to this Google Cloud bucket path. 60 | # E.g. 'gs://my-episodic-curiosity-logs/training_logs' 61 | SYNC_LOGS_TO_PATH = 'FILL-ME' 62 | 63 | # Path on a Google Cloud bucket to the pre-trained R-networks. If empty, 64 | # R-networks will be re-trained. 65 | PRETRAINED_R_NETS_PATH = 'gs://episodic-curiosity/r_networks' 66 | 67 | 68 | # Name templates for instances and disks. 69 | NAME_TEMPLATE = 'ec-20190301-{method}-{scenario}-{run_number}' 70 | 71 | # Scenarios to launch. 72 | SCENARIOS = [ 73 | 'noreward', 74 | 'norewardnofire', 75 | 'sparse', 76 | 'verysparse', 77 | 'sparseplusdoors', 78 | 'dense1', 79 | 'dense2', 80 | ] 81 | 82 | # Methods to launch. 83 | METHODS = [ 84 | 'ppo', 85 | # This is the online version of episodic curiosity. 86 | 'ppo_plus_eco', 87 | # This is the version of episodic curiosity where the R-network is trained 88 | # before the policy. 89 | 'ppo_plus_ec', 90 | 'ppo_plus_grid_oracle' 91 | ] 92 | 93 | # Number of identical training jobs to launch for each scenario and method. 94 | # Given the variance across runs, multiple of them are needed in order to get 95 | # confidence in the results. 96 | NUM_REPEATS = 10 97 | 98 | 99 | def required_resources_for_method(method, uses_pretrained_r_net): 100 | """Returns the required resources for the given training method. 101 | 102 | Args: 103 | method: str, training method. 104 | uses_pretrained_r_net: bool, whether we use pre-trained r-net. 105 | 106 | Returns: 107 | Tuple (RAM (MBs), num CPUs, num GPUs) 108 | """ 109 | if method == 'ppo_plus_eco': 110 | # We need to rent 2 GPUs, because with this amount of RAM, GCP won't allow 111 | # us to rent only one. 112 | return (105472, 16, 2) 113 | if method == 'ppo_plus_ec' and not uses_pretrained_r_net: 114 | return (52224, 12, 1) 115 | return (32768, 12, 1) 116 | 117 | 118 | def launch_vm(vm_id, vm_metadata): 119 | """Creates and launches a VM on Google Cloud compute engine. 120 | 121 | Args: 122 | vm_id: str, unique ID of the vm. 123 | vm_metadata: Dict[str, str], metadata key/value pairs passed to the vm. 
124 | """ 125 | print('\nCreating disk and vm with ID:', vm_id) 126 | vm_metadata['vm_id'] = vm_id 127 | ram_mbs, num_cpus, num_gpus = required_resources_for_method( 128 | vm_metadata['method'], 129 | bool(vm_metadata['pretrained_r_nets_path'])) 130 | 131 | create_disk_cmd = ( 132 | 'gcloud compute disks create ' 133 | '"{disk_name}" --zone "{zone}" --source-snapshot "{source_snapshot}" ' 134 | '--type "pd-standard" --project="{gcloud_project}" ' 135 | '--size=200GB'.format( 136 | disk_name=vm_id, 137 | zone=ZONE, 138 | source_snapshot=SOURCE_SNAPSHOT, 139 | gcloud_project=GCLOUD_PROJECT, 140 | )) 141 | print('Calling', create_disk_cmd) 142 | # Don't fail if disk already exists. 143 | subprocess.call(create_disk_cmd, shell=True) 144 | 145 | create_instance_cmd = ( 146 | 'gcloud compute --project={gcloud_project} instances create ' 147 | '{instance_name} --zone={zone} --machine-type={machine_type} ' 148 | '--subnet=default --network-tier=PREMIUM --maintenance-policy=TERMINATE ' 149 | '--service-account={service_account} ' 150 | '--scopes=storage-full,compute-rw ' 151 | '--accelerator=type=nvidia-tesla-p100,count={gpu_count} ' 152 | '--disk=name={disk_name},device-name={disk_name},mode=rw,boot=yes,' 153 | 'auto-delete=yes --restart-on-failure ' 154 | '--metadata-from-file startup-script=./scripts/vm_drop_root.sh ' 155 | '--metadata {vm_metadata} --async'.format( 156 | instance_name=vm_id, 157 | zone=ZONE, 158 | machine_type='custom-{num_cpus}-{ram_mbs}'.format( 159 | num_cpus=num_cpus, ram_mbs=ram_mbs), 160 | gpu_count=num_gpus, 161 | disk_name=vm_id, 162 | vm_metadata=( 163 | ','.join('{}={}'.format(k, v) for k, v in vm_metadata.items())), 164 | gcloud_project=GCLOUD_PROJECT, 165 | service_account=SERVICE_ACCOUNT, 166 | )) 167 | 168 | print('Calling', create_instance_cmd) 169 | subprocess.check_call(create_instance_cmd, shell=True) 170 | 171 | 172 | def main(): 173 | launch_args = [] 174 | for method in METHODS: 175 | for scenario in SCENARIOS: 176 | for run_number in range(NUM_REPEATS): 177 | vm_id = NAME_TEMPLATE.format( 178 | method=method.replace('_', '-'), 179 | scenario=scenario.replace('_', '-'), 180 | run_number=run_number) 181 | launch_args.append(( 182 | vm_id, 183 | { 184 | 'method': method, 185 | 'scenario': scenario, 186 | 'run_number': str(run_number), 187 | 'user': VM_USER, 188 | 'pretrained_r_nets_path': PRETRAINED_R_NETS_PATH, 189 | 'sync_logs_to_path': SYNC_LOGS_TO_PATH 190 | })) 191 | print('YOU ARE ABOUT TO START', len(launch_args), 'VMs on GCP.') 192 | if not FLAGS.i_understand_launching_vms_is_expensive: 193 | print('Please pass --i_understand_launching_vms_is_expensive to specify ' 194 | 'that you understood the cost implications of launching', 195 | len(launch_args), 'VMs') 196 | return 197 | # We use many threads in order to quickly start many instances. 198 | with concurrent.futures.ThreadPoolExecutor(max_workers=50) as executor: 199 | futures = [] 200 | for args in launch_args: 201 | futures.append(executor.submit(launch_vm, *args)) 202 | for f in futures: 203 | assert f.result() is None 204 | 205 | 206 | if __name__ == '__main__': 207 | main() 208 | -------------------------------------------------------------------------------- /scripts/vm_drop_root.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | # This script runs as root when the VM instance starts. 18 | # It launches a child script, after dropping the root user. 19 | 20 | set -x 21 | 22 | USER=$(curl "http://metadata.google.internal/computeMetadata/v1/instance/attributes/user" -H "Metadata-Flavor: Google") 23 | EPISODIC_CURIOSITY_DIR="/home/${USER}/episodic-curiosity" 24 | PRETRAINED_R_NETS_PATH=$(curl "http://metadata.google.internal/computeMetadata/v1/instance/attributes/pretrained_r_nets_path" -H "Metadata-Flavor: Google") 25 | 26 | # Note: during development, you could sync code from your local machine to a 27 | # cloud bucket, sync it here from the bucket to the VM, and pip install it. 28 | 29 | if [[ "${PRETRAINED_R_NETS_PATH}" ]] 30 | then 31 | gsutil -m cp -r "${PRETRAINED_R_NETS_PATH}" "/home/${USER}" 32 | fi 33 | 34 | chmod a+x "${EPISODIC_CURIOSITY_DIR}/scripts/vm_start.sh" 35 | 36 | # Launch vm_start under the given user. 37 | su - "${USER}" -c "${EPISODIC_CURIOSITY_DIR}/scripts/vm_start.sh" 38 | -------------------------------------------------------------------------------- /scripts/vm_start.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | # Copyright 2019 Google LLC. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | # This script runs when the VM instance starts. It launches episodic training 18 | # using launcher_script.py, and using the per-instance metadata (preferably set 19 | # using launch_cloud_vms.py). 
20 | 21 | set -x 22 | 23 | whoami 24 | 25 | cd 26 | 27 | VM_ID=$(curl "http://metadata.google.internal/computeMetadata/v1/instance/attributes/vm_id" -H "Metadata-Flavor: Google") 28 | METHOD=$(curl "http://metadata.google.internal/computeMetadata/v1/instance/attributes/method" -H "Metadata-Flavor: Google") 29 | SCENARIO=$(curl "http://metadata.google.internal/computeMetadata/v1/instance/attributes/scenario" -H "Metadata-Flavor: Google") 30 | RUN_NUMBER=$(curl "http://metadata.google.internal/computeMetadata/v1/instance/attributes/run_number" -H "Metadata-Flavor: Google") 31 | PRETRAINED_R_NETS_PATH=$(curl "http://metadata.google.internal/computeMetadata/v1/instance/attributes/pretrained_r_nets_path" -H "Metadata-Flavor: Google") 32 | SYNC_LOGS_TO_PATH=$(curl "http://metadata.google.internal/computeMetadata/v1/instance/attributes/sync_logs_to_path" -H "Metadata-Flavor: Google") 33 | HOMEDIR=$(pwd) 34 | 35 | WORKDIR="${HOMEDIR}/${VM_ID}" 36 | 37 | EPISODIC_CURIOSITY_DIR="${HOMEDIR}/episodic-curiosity" 38 | 39 | mkdir -p "${WORKDIR}" 40 | 41 | # This must happen before we activate the virtual env, otherwise, gsutil does not work. 42 | python "${EPISODIC_CURIOSITY_DIR}/scripts/gs_sync.py" --workdir="${WORKDIR}" --sync_to_cloud_bucket="${SYNC_LOGS_TO_PATH}" & 43 | 44 | source episodic_curiosity_env/bin/activate 45 | 46 | cd "${EPISODIC_CURIOSITY_DIR}" 47 | 48 | if [[ "${PRETRAINED_R_NETS_PATH}" ]] 49 | then 50 | BASENAME=$(basename "${PRETRAINED_R_NETS_PATH}") 51 | R_NETWORKS_PATH_FLAG="--r_networks_path=${HOMEDIR}/${BASENAME}" 52 | else 53 | R_NETWORKS_PATH_FLAG="" 54 | fi 55 | 56 | python "${EPISODIC_CURIOSITY_DIR}/scripts/launcher_script.py" --workdir="${WORKDIR}" --method="${METHOD}" --scenario="${SCENARIO}" --run_number="${RUN_NUMBER}" ${R_NETWORKS_PATH_FLAG} 57 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Setup script for installing episodic_curiosity as a pip module.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | import setuptools 21 | 22 | VERSION = '1.0.0' 23 | 24 | install_requires = [ 25 | # See installation instructions: 26 | # https://github.com/deepmind/lab/tree/master/python/pip_package 27 | 'DeepMind-Lab', 28 | 'absl-py>=0.7.0', 29 | 'dill>=0.2.9', 30 | 'enum>=0.4.7', 31 | # Won't be needed anymore when moving to python3. 32 | 'futures>=3.2.0', 33 | 'gin-config>=0.1.2', 34 | 'gym>=0.10.9', 35 | 'numpy>=1.16.0', 36 | 'opencv-python>=4.0.0.21', 37 | 'pypng>=0.0.19', 38 | 'pytype>=2019.1.18', 39 | 'scikit-image>=0.14.2', 40 | 'six>=1.12.0', 41 | 'tensorflow-gpu>=1.12.0', 42 | ] 43 | 44 | description = ('Episodic Curiosity. 
This is the code that allows reproducing ' 45 | 'the results in the scientific paper ' 46 | 'https://arxiv.org/pdf/1810.02274.pdf.') 47 | 48 | 49 | setuptools.setup( 50 | name='episodic-curiosity', 51 | version=VERSION, 52 | packages=setuptools.find_packages(), 53 | description=description, 54 | long_description=description, 55 | url='https://github.com/google-research/episodic-curiosity', 56 | author='Google LLC', 57 | author_email='opensource@google.com', 58 | install_requires=install_requires, 59 | extras_require={ 60 | 'video': ['sk-video'], 61 | 'mujoco': [ 62 | # For installation of dm_control see: 63 | # https://github.com/deepmind/dm_control#requirements-and-installation. 64 | 'dm_control', 65 | 'functools32', 66 | 'scikit-image', 67 | ], 68 | }, 69 | license='Apache 2.0', 70 | keywords='reinforcement-learning curiosity exploration deepmind-lab', 71 | ) 72 | -------------------------------------------------------------------------------- /third_party/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | -------------------------------------------------------------------------------- /third_party/baselines/LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License 2 | 3 | Copyright (c) 2017 OpenAI (http://openai.com) 4 | Copyright (c) 2018 Google LLC (http://google.com) 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in 14 | all copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 22 | THE SOFTWARE. 
23 | -------------------------------------------------------------------------------- /third_party/baselines/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | -------------------------------------------------------------------------------- /third_party/baselines/a2c/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | -------------------------------------------------------------------------------- /third_party/baselines/a2c/utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import os 3 | import gym 4 | import numpy as np 5 | import tensorflow as tf 6 | from gym import spaces 7 | from collections import deque 8 | 9 | def sample(logits): 10 | noise = tf.random_uniform(tf.shape(logits)) 11 | return tf.argmax(logits - tf.log(-tf.log(noise)), 1) 12 | 13 | def cat_entropy(logits): 14 | a0 = logits - tf.reduce_max(logits, 1, keep_dims=True) 15 | ea0 = tf.exp(a0) 16 | z0 = tf.reduce_sum(ea0, 1, keep_dims=True) 17 | p0 = ea0 / z0 18 | return tf.reduce_sum(p0 * (tf.log(z0) - a0), 1) 19 | 20 | def cat_entropy_softmax(p0): 21 | return - tf.reduce_sum(p0 * tf.log(p0 + 1e-6), axis = 1) 22 | 23 | def mse(pred, target): 24 | return tf.square(pred-target)/2. 25 | 26 | def ortho_init(scale=1.0): 27 | def _ortho_init(shape, dtype, partition_info=None): 28 | #lasagne ortho init for tf 29 | shape = tuple(shape) 30 | if len(shape) == 2: 31 | flat_shape = shape 32 | elif len(shape) == 4: # assumes NHWC 33 | flat_shape = (np.prod(shape[:-1]), shape[-1]) 34 | else: 35 | raise NotImplementedError 36 | a = np.random.normal(0.0, 1.0, flat_shape) 37 | u, _, v = np.linalg.svd(a, full_matrices=False) 38 | q = u if u.shape == flat_shape else v # pick the one with the correct shape 39 | q = q.reshape(shape) 40 | return (scale * q[:shape[0], :shape[1]]).astype(np.float32) 41 | return _ortho_init 42 | 43 | def conv(x, scope, nf, rf, stride, pad='VALID', init_scale=1.0, data_format='NHWC', one_dim_bias=False): 44 | if data_format == 'NHWC': 45 | channel_ax = 3 46 | strides = [1, stride, stride, 1] 47 | bshape = [1, 1, 1, nf] 48 | elif data_format == 'NCHW': 49 | channel_ax = 1 50 | strides = [1, 1, stride, stride] 51 | bshape = [1, nf, 1, 1] 52 | else: 53 | raise NotImplementedError 54 | bias_var_shape = [nf] if one_dim_bias else [1, nf, 1, 1] 55 | nin = x.get_shape()[channel_ax].value 56 | wshape = [rf, rf, nin, nf] 57 | with tf.variable_scope(scope): 58 | w = tf.get_variable("w", wshape, initializer=ortho_init(init_scale)) 59 | b = tf.get_variable("b", bias_var_shape, initializer=tf.constant_initializer(0.0)) 60 | if not one_dim_bias and data_format == 'NHWC': 61 | b = tf.reshape(b, bshape) 62 | return b + tf.nn.conv2d(x, w, strides=strides, padding=pad, data_format=data_format) 63 | 64 | def fc(x, scope, nh, init_scale=1.0, init_bias=0.0): 65 | with tf.variable_scope(scope): 66 | nin = x.get_shape()[1].value 67 | w = tf.get_variable("w", [nin, nh], initializer=ortho_init(init_scale)) 68 | b = tf.get_variable("b", [nh], initializer=tf.constant_initializer(init_bias)) 69 | return tf.matmul(x, w)+b 70 | 71 | def batch_to_seq(h, nbatch, nsteps, flat=False): 72 | if flat: 73 | h = tf.reshape(h, [nbatch, nsteps]) 74 | else: 75 | h = tf.reshape(h, [nbatch, nsteps, -1]) 76 | return [tf.squeeze(v, [1]) for v in 
tf.split(axis=1, num_or_size_splits=nsteps, value=h)] 77 | 78 | def seq_to_batch(h, flat = False): 79 | shape = h[0].get_shape().as_list() 80 | if not flat: 81 | assert(len(shape) > 1) 82 | nh = h[0].get_shape()[-1].value 83 | return tf.reshape(tf.concat(axis=1, values=h), [-1, nh]) 84 | else: 85 | return tf.reshape(tf.stack(values=h, axis=1), [-1]) 86 | 87 | def lstm(xs, ms, s, scope, nh, init_scale=1.0): 88 | nbatch, nin = [v.value for v in xs[0].get_shape()] 89 | nsteps = len(xs) 90 | with tf.variable_scope(scope): 91 | wx = tf.get_variable("wx", [nin, nh*4], initializer=ortho_init(init_scale)) 92 | wh = tf.get_variable("wh", [nh, nh*4], initializer=ortho_init(init_scale)) 93 | b = tf.get_variable("b", [nh*4], initializer=tf.constant_initializer(0.0)) 94 | 95 | c, h = tf.split(axis=1, num_or_size_splits=2, value=s) 96 | for idx, (x, m) in enumerate(zip(xs, ms)): 97 | c = c*(1-m) 98 | h = h*(1-m) 99 | z = tf.matmul(x, wx) + tf.matmul(h, wh) + b 100 | i, f, o, u = tf.split(axis=1, num_or_size_splits=4, value=z) 101 | i = tf.nn.sigmoid(i) 102 | f = tf.nn.sigmoid(f) 103 | o = tf.nn.sigmoid(o) 104 | u = tf.tanh(u) 105 | c = f*c + i*u 106 | h = o*tf.tanh(c) 107 | xs[idx] = h 108 | s = tf.concat(axis=1, values=[c, h]) 109 | return xs, s 110 | 111 | def _ln(x, g, b, e=1e-5, axes=[1]): 112 | u, s = tf.nn.moments(x, axes=axes, keep_dims=True) 113 | x = (x-u)/tf.sqrt(s+e) 114 | x = x*g+b 115 | return x 116 | 117 | def lnlstm(xs, ms, s, scope, nh, init_scale=1.0): 118 | nbatch, nin = [v.value for v in xs[0].get_shape()] 119 | nsteps = len(xs) 120 | with tf.variable_scope(scope): 121 | wx = tf.get_variable("wx", [nin, nh*4], initializer=ortho_init(init_scale)) 122 | gx = tf.get_variable("gx", [nh*4], initializer=tf.constant_initializer(1.0)) 123 | bx = tf.get_variable("bx", [nh*4], initializer=tf.constant_initializer(0.0)) 124 | 125 | wh = tf.get_variable("wh", [nh, nh*4], initializer=ortho_init(init_scale)) 126 | gh = tf.get_variable("gh", [nh*4], initializer=tf.constant_initializer(1.0)) 127 | bh = tf.get_variable("bh", [nh*4], initializer=tf.constant_initializer(0.0)) 128 | 129 | b = tf.get_variable("b", [nh*4], initializer=tf.constant_initializer(0.0)) 130 | 131 | gc = tf.get_variable("gc", [nh], initializer=tf.constant_initializer(1.0)) 132 | bc = tf.get_variable("bc", [nh], initializer=tf.constant_initializer(0.0)) 133 | 134 | c, h = tf.split(axis=1, num_or_size_splits=2, value=s) 135 | for idx, (x, m) in enumerate(zip(xs, ms)): 136 | c = c*(1-m) 137 | h = h*(1-m) 138 | z = _ln(tf.matmul(x, wx), gx, bx) + _ln(tf.matmul(h, wh), gh, bh) + b 139 | i, f, o, u = tf.split(axis=1, num_or_size_splits=4, value=z) 140 | i = tf.nn.sigmoid(i) 141 | f = tf.nn.sigmoid(f) 142 | o = tf.nn.sigmoid(o) 143 | u = tf.tanh(u) 144 | c = f*c + i*u 145 | h = o*tf.tanh(_ln(c, gc, bc)) 146 | xs[idx] = h 147 | s = tf.concat(axis=1, values=[c, h]) 148 | return xs, s 149 | 150 | def conv_to_fc(x): 151 | nh = np.prod([v.value for v in x.get_shape()[1:]]) 152 | x = tf.reshape(x, [-1, nh]) 153 | return x 154 | 155 | def discount_with_dones(rewards, dones, gamma): 156 | discounted = [] 157 | r = 0 158 | for reward, done in zip(rewards[::-1], dones[::-1]): 159 | r = reward + gamma*r*(1.-done) # fixed off by one bug 160 | discounted.append(r) 161 | return discounted[::-1] 162 | 163 | def find_trainable_variables(key): 164 | with tf.variable_scope(key): 165 | return tf.trainable_variables() 166 | 167 | def make_path(f): 168 | return os.makedirs(f, exist_ok=True) 169 | 170 | def constant(p): 171 | return 1 172 | 173 | def 
linear(p): 174 | return 1-p 175 | 176 | def middle_drop(p): 177 | eps = 0.75 178 | if 1-p 0 28 | obs = None 29 | for _ in range(noops): 30 | obs, _, done, _ = self.env.step(self.noop_action) 31 | if done: 32 | obs = self.env.reset(**kwargs) 33 | return obs 34 | 35 | def step(self, ac): 36 | return self.env.step(ac) 37 | 38 | class FireResetEnv(gym.Wrapper): 39 | def __init__(self, env): 40 | """Take action on reset for environments that are fixed until firing.""" 41 | gym.Wrapper.__init__(self, env) 42 | assert env.unwrapped.get_action_meanings()[1] == 'FIRE' 43 | assert len(env.unwrapped.get_action_meanings()) >= 3 44 | 45 | def reset(self, **kwargs): 46 | self.env.reset(**kwargs) 47 | obs, _, done, _ = self.env.step(1) 48 | if done: 49 | self.env.reset(**kwargs) 50 | obs, _, done, _ = self.env.step(2) 51 | if done: 52 | self.env.reset(**kwargs) 53 | return obs 54 | 55 | def step(self, ac): 56 | return self.env.step(ac) 57 | 58 | class EpisodicLifeEnv(gym.Wrapper): 59 | def __init__(self, env): 60 | """Make end-of-life == end-of-episode, but only reset on true game over. 61 | Done by DeepMind for the DQN and co. since it helps value estimation. 62 | """ 63 | gym.Wrapper.__init__(self, env) 64 | self.lives = 0 65 | self.was_real_done = True 66 | 67 | def step(self, action): 68 | obs, reward, done, info = self.env.step(action) 69 | self.was_real_done = done 70 | # check current lives, make loss of life terminal, 71 | # then update lives to handle bonus lives 72 | lives = self.env.unwrapped.ale.lives() 73 | if lives < self.lives and lives > 0: 74 | # for Qbert sometimes we stay in lives == 0 condtion for a few frames 75 | # so its important to keep lives > 0, so that we only reset once 76 | # the environment advertises done. 77 | done = True 78 | self.lives = lives 79 | return obs, reward, done, info 80 | 81 | def reset(self, **kwargs): 82 | """Reset only when lives are exhausted. 83 | This way all states are still reachable even though lives are episodic, 84 | and the learner need not know about any of this behind-the-scenes. 
85 | """ 86 | if self.was_real_done: 87 | obs = self.env.reset(**kwargs) 88 | else: 89 | # no-op step to advance from terminal/lost life state 90 | obs, _, _, _ = self.env.step(0) 91 | self.lives = self.env.unwrapped.ale.lives() 92 | return obs 93 | 94 | class MaxAndSkipEnv(gym.Wrapper): 95 | def __init__(self, env, skip=4): 96 | """Return only every `skip`-th frame""" 97 | gym.Wrapper.__init__(self, env) 98 | # most recent raw observations (for max pooling across time steps) 99 | self._obs_buffer = np.zeros((2,)+env.observation_space.shape, dtype=np.uint8) 100 | self._skip = skip 101 | 102 | def step(self, action): 103 | """Repeat action, sum reward, and max over last observations.""" 104 | total_reward = 0.0 105 | done = None 106 | for i in range(self._skip): 107 | obs, reward, done, info = self.env.step(action) 108 | if i == self._skip - 2: self._obs_buffer[0] = obs 109 | if i == self._skip - 1: self._obs_buffer[1] = obs 110 | total_reward += reward 111 | if done: 112 | break 113 | # Note that the observation on the done=True frame 114 | # doesn't matter 115 | max_frame = self._obs_buffer.max(axis=0) 116 | 117 | return max_frame, total_reward, done, info 118 | 119 | def reset(self, **kwargs): 120 | return self.env.reset(**kwargs) 121 | 122 | class ClipRewardEnv(gym.RewardWrapper): 123 | def __init__(self, env): 124 | gym.RewardWrapper.__init__(self, env) 125 | 126 | def reward(self, reward): 127 | """Bin reward to {+1, 0, -1} by its sign.""" 128 | return np.sign(reward) 129 | 130 | class WarpFrame(gym.ObservationWrapper): 131 | def __init__(self, env): 132 | """Warp frames to 84x84 as done in the Nature paper and later work.""" 133 | gym.ObservationWrapper.__init__(self, env) 134 | self.width = 84 135 | self.height = 84 136 | self.observation_space = spaces.Box(low=0, high=255, 137 | shape=(self.height, self.width, 1), dtype=np.uint8) 138 | 139 | def observation(self, frame): 140 | frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) 141 | frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA) 142 | return frame[:, :, None] 143 | 144 | class FrameStack(gym.Wrapper): 145 | def __init__(self, env, k): 146 | """Stack k last frames. 147 | 148 | Returns lazy array, which is much more memory efficient. 149 | 150 | See Also 151 | -------- 152 | baselines.common.atari_wrappers.LazyFrames 153 | """ 154 | gym.Wrapper.__init__(self, env) 155 | self.k = k 156 | self.frames = deque([], maxlen=k) 157 | shp = env.observation_space.shape 158 | self.observation_space = spaces.Box(low=0, high=255, shape=(shp[0], shp[1], shp[2] * k), dtype=np.uint8) 159 | 160 | def reset(self): 161 | ob = self.env.reset() 162 | for _ in range(self.k): 163 | self.frames.append(ob) 164 | return self._get_ob() 165 | 166 | def step(self, action): 167 | ob, reward, done, info = self.env.step(action) 168 | self.frames.append(ob) 169 | return self._get_ob(), reward, done, info 170 | 171 | def _get_ob(self): 172 | assert len(self.frames) == self.k 173 | return LazyFrames(list(self.frames)) 174 | 175 | class ScaledFloatFrame(gym.ObservationWrapper): 176 | def __init__(self, env): 177 | gym.ObservationWrapper.__init__(self, env) 178 | 179 | def observation(self, observation): 180 | # careful! This undoes the memory optimization, use 181 | # with smaller replay buffers only. 182 | return np.array(observation).astype(np.float32) / 255.0 183 | 184 | class LazyFrames(object): 185 | def __init__(self, frames): 186 | """This object ensures that common frames between the observations are only stored once. 
187 | It exists purely to optimize memory usage which can be huge for DQN's 1M frames replay 188 | buffers. 189 | 190 | This object should only be converted to numpy array before being passed to the model. 191 | 192 | You'd not believe how complex the previous solution was.""" 193 | self._frames = frames 194 | self._out = None 195 | 196 | def _force(self): 197 | if self._out is None: 198 | self._out = np.concatenate(self._frames, axis=2) 199 | self._frames = None 200 | return self._out 201 | 202 | def __array__(self, dtype=None): 203 | out = self._force() 204 | if dtype is not None: 205 | out = out.astype(dtype) 206 | return out 207 | 208 | def __len__(self): 209 | return len(self._force()) 210 | 211 | def __getitem__(self, i): 212 | return self._force()[i] 213 | 214 | def make_atari(env_id): 215 | env = gym.make(env_id) 216 | assert 'NoFrameskip' in env.spec.id 217 | env = NoopResetEnv(env, noop_max=30) 218 | env = MaxAndSkipEnv(env, skip=4) 219 | return env 220 | 221 | def wrap_deepmind(env, episode_life=True, clip_rewards=True, frame_stack=False, scale=False): 222 | """Configure environment for DeepMind-style Atari. 223 | """ 224 | if episode_life: 225 | env = EpisodicLifeEnv(env) 226 | if 'FIRE' in env.unwrapped.get_action_meanings(): 227 | env = FireResetEnv(env) 228 | env = WarpFrame(env) 229 | if scale: 230 | env = ScaledFloatFrame(env) 231 | if clip_rewards: 232 | env = ClipRewardEnv(env) 233 | if frame_stack: 234 | env = FrameStack(env, 4) 235 | return env 236 | 237 | -------------------------------------------------------------------------------- /third_party/baselines/common/cmd_util.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | """Helpers for scripts like run_atari.py. 3 | """ 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import os 10 | from third_party.baselines import logger 11 | from third_party.baselines.bench import Monitor 12 | from third_party.baselines.common import set_global_seeds 13 | from third_party.baselines.common.atari_wrappers import make_atari 14 | from third_party.baselines.common.atari_wrappers import wrap_deepmind 15 | from third_party.baselines.common.vec_env.subproc_vec_env import SubprocVecEnv 16 | 17 | 18 | def make_atari_env(env_id, num_env, seed, wrapper_kwargs=None, start_index=0, 19 | use_monitor=True): 20 | """Create a wrapped, monitored SubprocVecEnv for Atari. 
21 | """ 22 | if wrapper_kwargs is None: wrapper_kwargs = {} 23 | def make_env(rank): # pylint: disable=C0111 24 | def _thunk(): 25 | env = make_atari(env_id) 26 | env.seed(seed + rank) 27 | if use_monitor: 28 | env = Monitor(env, logger.get_dir() and os.path.join(logger.get_dir(), 29 | str(rank))) 30 | return wrap_deepmind(env, **wrapper_kwargs) 31 | return _thunk 32 | set_global_seeds(seed) 33 | return SubprocVecEnv([make_env(i + start_index) for i in range(num_env)]) 34 | -------------------------------------------------------------------------------- /third_party/baselines/common/input.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import tensorflow as tf 3 | from gym.spaces import Discrete, Box 4 | 5 | def observation_input(ob_space, batch_size=None, name='Ob'): 6 | ''' 7 | Build observation input with encoding depending on the 8 | observation space type 9 | Params: 10 | 11 | ob_space: observation space (should be one of gym.spaces) 12 | batch_size: batch size for input (default is None, so that resulting input placeholder can take tensors with any batch size) 13 | name: tensorflow variable name for input placeholder 14 | 15 | returns: tuple (input_placeholder, processed_input_tensor) 16 | ''' 17 | if isinstance(ob_space, Discrete): 18 | input_x = tf.placeholder(shape=(batch_size,), dtype=tf.int32, name=name) 19 | processed_x = tf.to_float(tf.one_hot(input_x, ob_space.n)) 20 | return input_x, processed_x 21 | 22 | elif isinstance(ob_space, Box): 23 | input_shape = (batch_size,) + ob_space.shape 24 | input_x = tf.placeholder(shape=input_shape, dtype=ob_space.dtype, name=name) 25 | processed_x = tf.to_float(input_x) 26 | return input_x, processed_x 27 | 28 | else: 29 | raise NotImplementedError 30 | 31 | 32 | -------------------------------------------------------------------------------- /third_party/baselines/common/math_util.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import numpy as np 3 | import scipy.signal 4 | 5 | 6 | def discount(x, gamma): 7 | """ 8 | computes discounted sums along 0th dimension of x. 9 | 10 | inputs 11 | ------ 12 | x: ndarray 13 | gamma: float 14 | 15 | outputs 16 | ------- 17 | y: ndarray with same shape as x, satisfying 18 | 19 | y[t] = x[t] + gamma*x[t+1] + gamma^2*x[t+2] + ... + gamma^k x[t+k], 20 | where k = len(x) - t - 1 21 | 22 | """ 23 | assert x.ndim >= 1 24 | return scipy.signal.lfilter([1],[1,-gamma],x[::-1], axis=0)[::-1] 25 | 26 | def explained_variance(ypred,y): 27 | """ 28 | Computes fraction of variance that ypred explains about y. 
29 | Returns 1 - Var[y-ypred] / Var[y] 30 | 31 | interpretation: 32 | ev=0 => might as well have predicted zero 33 | ev=1 => perfect prediction 34 | ev<0 => worse than just predicting zero 35 | 36 | """ 37 | assert y.ndim == 1 and ypred.ndim == 1 38 | vary = np.var(y) 39 | return np.nan if vary==0 else 1 - np.var(y-ypred)/vary 40 | 41 | def explained_variance_2d(ypred, y): 42 | assert y.ndim == 2 and ypred.ndim == 2 43 | vary = np.var(y, axis=0) 44 | out = 1 - np.var(y-ypred)/vary 45 | out[vary < 1e-10] = 0 46 | return out 47 | 48 | def ncc(ypred, y): 49 | return np.corrcoef(ypred, y)[1,0] 50 | 51 | def flatten_arrays(arrs): 52 | return np.concatenate([arr.flat for arr in arrs]) 53 | 54 | def unflatten_vector(vec, shapes): 55 | i=0 56 | arrs = [] 57 | for shape in shapes: 58 | size = np.prod(shape) 59 | arr = vec[i:i+size].reshape(shape) 60 | arrs.append(arr) 61 | i += size 62 | return arrs 63 | 64 | def discount_with_boundaries(X, New, gamma): 65 | """ 66 | X: 2d array of floats, time x features 67 | New: 2d array of bools, indicating when a new episode has started 68 | """ 69 | Y = np.zeros_like(X) 70 | T = X.shape[0] 71 | Y[T-1] = X[T-1] 72 | for t in range(T-2, -1, -1): 73 | Y[t] = X[t] + gamma * Y[t+1] * (1 - New[t+1]) 74 | return Y 75 | 76 | def test_discount_with_boundaries(): 77 | gamma=0.9 78 | x = np.array([1.0, 2.0, 3.0, 4.0], 'float32') 79 | starts = [1.0, 0.0, 0.0, 1.0] 80 | y = discount_with_boundaries(x, starts, gamma) 81 | assert np.allclose(y, [ 82 | 1 + gamma * 2 + gamma**2 * 3, 83 | 2 + gamma * 3, 84 | 3, 85 | 4 86 | ]) 87 | -------------------------------------------------------------------------------- /third_party/baselines/common/misc_util.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import gym 3 | import numpy as np 4 | import os 5 | import pickle 6 | import random 7 | import tempfile 8 | import zipfile 9 | 10 | 11 | def zipsame(*seqs): 12 | L = len(seqs[0]) 13 | assert all(len(seq) == L for seq in seqs[1:]) 14 | return zip(*seqs) 15 | 16 | 17 | def unpack(seq, sizes): 18 | """ 19 | Unpack 'seq' into a sequence of lists, with lengths specified by 'sizes'. 20 | None = just one bare element, not a list 21 | 22 | Example: 23 | unpack([1,2,3,4,5,6], [3,None,2]) -> ([1,2,3], 4, [5,6]) 24 | """ 25 | seq = list(seq) 26 | it = iter(seq) 27 | assert sum(1 if s is None else s for s in sizes) == len(seq), "Trying to unpack %s into %s" % (seq, sizes) 28 | for size in sizes: 29 | if size is None: 30 | yield it.__next__() 31 | else: 32 | li = [] 33 | for _ in range(size): 34 | li.append(it.__next__()) 35 | yield li 36 | 37 | 38 | class EzPickle(object): 39 | """Objects that are pickled and unpickled via their constructor 40 | arguments. 41 | 42 | Example usage: 43 | 44 | class Dog(Animal, EzPickle): 45 | def __init__(self, furcolor, tailkind="bushy"): 46 | Animal.__init__() 47 | EzPickle.__init__(furcolor, tailkind) 48 | ... 49 | 50 | When this object is unpickled, a new Dog will be constructed by passing the provided 51 | furcolor and tailkind into the constructor. However, philosophers are still not sure 52 | whether it is still the same dog. 53 | 54 | This is generally needed only for environments which wrap C/C++ code, such as MuJoCo 55 | and Atari. 
56 | """ 57 | 58 | def __init__(self, *args, **kwargs): 59 | self._ezpickle_args = args 60 | self._ezpickle_kwargs = kwargs 61 | 62 | def __getstate__(self): 63 | return {"_ezpickle_args": self._ezpickle_args, "_ezpickle_kwargs": self._ezpickle_kwargs} 64 | 65 | def __setstate__(self, d): 66 | out = type(self)(*d["_ezpickle_args"], **d["_ezpickle_kwargs"]) 67 | self.__dict__.update(out.__dict__) 68 | 69 | 70 | def set_global_seeds(i): 71 | try: 72 | import tensorflow as tf 73 | except ImportError: 74 | pass 75 | else: 76 | tf.set_random_seed(i) 77 | np.random.seed(i) 78 | random.seed(i) 79 | 80 | 81 | def pretty_eta(seconds_left): 82 | """Print the number of seconds in human readable format. 83 | 84 | Examples: 85 | 2 days 86 | 2 hours and 37 minutes 87 | less than a minute 88 | 89 | Paramters 90 | --------- 91 | seconds_left: int 92 | Number of seconds to be converted to the ETA 93 | Returns 94 | ------- 95 | eta: str 96 | String representing the pretty ETA. 97 | """ 98 | minutes_left = seconds_left // 60 99 | seconds_left %= 60 100 | hours_left = minutes_left // 60 101 | minutes_left %= 60 102 | days_left = hours_left // 24 103 | hours_left %= 24 104 | 105 | def helper(cnt, name): 106 | return "{} {}{}".format(str(cnt), name, ('s' if cnt > 1 else '')) 107 | 108 | if days_left > 0: 109 | msg = helper(days_left, 'day') 110 | if hours_left > 0: 111 | msg += ' and ' + helper(hours_left, 'hour') 112 | return msg 113 | if hours_left > 0: 114 | msg = helper(hours_left, 'hour') 115 | if minutes_left > 0: 116 | msg += ' and ' + helper(minutes_left, 'minute') 117 | return msg 118 | if minutes_left > 0: 119 | return helper(minutes_left, 'minute') 120 | return 'less than a minute' 121 | 122 | 123 | class RunningAvg(object): 124 | def __init__(self, gamma, init_value=None): 125 | """Keep a running estimate of a quantity. This is a bit like mean 126 | but more sensitive to recent changes. 127 | 128 | Parameters 129 | ---------- 130 | gamma: float 131 | Must be between 0 and 1, where 0 is the most sensitive to recent 132 | changes. 133 | init_value: float or None 134 | Initial value of the estimate. If None, it will be set on the first update. 135 | """ 136 | self._value = init_value 137 | self._gamma = gamma 138 | 139 | def update(self, new_val): 140 | """Update the estimate. 141 | 142 | Parameters 143 | ---------- 144 | new_val: float 145 | new observated value of estimated quantity. 146 | """ 147 | if self._value is None: 148 | self._value = new_val 149 | else: 150 | self._value = self._gamma * self._value + (1.0 - self._gamma) * new_val 151 | 152 | def __float__(self): 153 | """Get the current estimate""" 154 | return self._value 155 | 156 | def boolean_flag(parser, name, default=False, help=None): 157 | """Add a boolean flag to argparse parser. 
158 | 159 | Parameters 160 | ---------- 161 | parser: argparse.Parser 162 | parser to add the flag to 163 | name: str 164 | --<name> will enable the flag, while --no-<name> will disable it 165 | default: bool or None 166 | default value of the flag 167 | help: str 168 | help string for the flag 169 | """ 170 | dest = name.replace('-', '_') 171 | parser.add_argument("--" + name, action="store_true", default=default, dest=dest, help=help) 172 | parser.add_argument("--no-" + name, action="store_false", dest=dest) 173 | 174 | 175 | def get_wrapper_by_name(env, classname): 176 | """Given a gym environment possibly wrapped multiple times, returns a wrapper 177 | of class named classname or raises ValueError if no such wrapper was applied 178 | 179 | Parameters 180 | ---------- 181 | env: gym.Env or gym.Wrapper 182 | gym environment 183 | classname: str 184 | name of the wrapper 185 | 186 | Returns 187 | ------- 188 | wrapper: gym.Wrapper 189 | wrapper named classname 190 | """ 191 | currentenv = env 192 | while True: 193 | if classname == currentenv.class_name(): 194 | return currentenv 195 | elif isinstance(currentenv, gym.Wrapper): 196 | currentenv = currentenv.env 197 | else: 198 | raise ValueError("Couldn't find wrapper named %s" % classname) 199 | 200 | 201 | def relatively_safe_pickle_dump(obj, path, compression=False): 202 | """This is just like regular pickle dump, except that failure cases are 203 | different: 204 | 205 | - It's never possible that we end up with a pickle in corrupted state. 206 | - If there was a different file at the path, that file will remain unchanged in the 207 | event of failure (provided that filesystem rename is atomic). 208 | - It is sometimes possible that we end up with a useless temp file which needs to be 209 | deleted manually (it will be removed automatically on the next function call) 210 | 211 | The intended use case is periodic checkpoints of experiment state, such that we never 212 | corrupt previous checkpoints if the current one fails. 213 | 214 | Parameters 215 | ---------- 216 | obj: object 217 | object to pickle 218 | path: str 219 | path to the output file 220 | compression: bool 221 | if true pickle will be compressed 222 | """ 223 | temp_storage = path + ".relatively_safe" 224 | if compression: 225 | # Using gzip here would be simpler, but the size is limited to 2GB 226 | with tempfile.NamedTemporaryFile() as uncompressed_file: 227 | pickle.dump(obj, uncompressed_file) 228 | uncompressed_file.file.flush() 229 | with zipfile.ZipFile(temp_storage, "w", compression=zipfile.ZIP_DEFLATED) as myzip: 230 | myzip.write(uncompressed_file.name, "data") 231 | else: 232 | with open(temp_storage, "wb") as f: 233 | pickle.dump(obj, f) 234 | os.rename(temp_storage, path) 235 | 236 | 237 | def pickle_load(path, compression=False): 238 | """Unpickle a possibly compressed pickle. 239 | 240 | Parameters 241 | ---------- 242 | path: str 243 | path to the output file 244 | compression: bool 245 | if true assumes that pickle was compressed when created and attempts decompression. 
246 | 247 | Returns 248 | ------- 249 | obj: object 250 | the unpickled object 251 | """ 252 | 253 | if compression: 254 | with zipfile.ZipFile(path, "r", compression=zipfile.ZIP_DEFLATED) as myzip: 255 | with myzip.open("data") as f: 256 | return pickle.load(f) 257 | else: 258 | with open(path, "rb") as f: 259 | return pickle.load(f) 260 | -------------------------------------------------------------------------------- /third_party/baselines/common/runners.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import numpy as np 3 | import abc 4 | from abc import abstractmethod 5 | 6 | class AbstractEnvRunner(object): 7 | __metaclass__ = abc.ABCMeta 8 | 9 | def __init__(self, env, model, nsteps): 10 | self.env = env 11 | self.model = model 12 | nenv = env.num_envs 13 | self.batch_ob_shape = (nenv*nsteps,) + env.observation_space.shape 14 | self.obs = np.zeros((nenv,) + env.observation_space.shape, dtype=env.observation_space.dtype.name) 15 | self.obs[:] = env.reset() 16 | self.nsteps = nsteps 17 | self.states = model.initial_state 18 | self.dones = [False for _ in range(nenv)] 19 | 20 | @abstractmethod 21 | def run(self): 22 | raise NotImplementedError 23 | -------------------------------------------------------------------------------- /third_party/baselines/common/tile_images.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import numpy as np 3 | 4 | def tile_images(img_nhwc): 5 | """ 6 | Tile N images into one big PxQ image 7 | (P,Q) are chosen to be as close as possible, and if N 8 | is square, then P=Q. 9 | 10 | input: img_nhwc, list or array of images, ndim=4 once turned into array 11 | n = batch index, h = height, w = width, c = channel 12 | returns: 13 | bigim_HWc, ndarray with ndim=3 14 | """ 15 | img_nhwc = np.asarray(img_nhwc) 16 | N, h, w, c = img_nhwc.shape 17 | H = int(np.ceil(np.sqrt(N))) 18 | W = int(np.ceil(float(N)/H)) 19 | img_nhwc = np.array(list(img_nhwc) + [img_nhwc[0]*0 for _ in range(N, H*W)]) 20 | img_HWhwc = img_nhwc.reshape(H, W, h, w, c) 21 | img_HhWwc = img_HWhwc.transpose(0, 2, 1, 3, 4) 22 | img_Hh_Ww_c = img_HhWwc.reshape(H*h, W*w, c) 23 | return img_Hh_Ww_c 24 | 25 | -------------------------------------------------------------------------------- /third_party/baselines/common/vec_env/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | from abc import ABCMeta, abstractmethod 7 | from third_party.baselines import logger 8 | 9 | class AlreadySteppingError(Exception): 10 | """ 11 | Raised when an asynchronous step is running while 12 | step_async() is called again. 13 | """ 14 | def __init__(self): 15 | msg = 'already running an async step' 16 | Exception.__init__(self, msg) 17 | 18 | class NotSteppingError(Exception): 19 | """ 20 | Raised when an asynchronous step is not running but 21 | step_wait() is called. 22 | """ 23 | def __init__(self): 24 | msg = 'not running an async step' 25 | Exception.__init__(self, msg) 26 | 27 | class VecEnv(object): 28 | """ 29 | An abstract asynchronous, vectorized environment. 
30 | """ 31 | __metaclass__ = ABCMeta 32 | 33 | def __init__(self, num_envs, observation_space, action_space): 34 | self.num_envs = num_envs 35 | self.observation_space = observation_space 36 | self.action_space = action_space 37 | 38 | @abstractmethod 39 | def reset(self): 40 | """ 41 | Reset all the environments and return an array of 42 | observations, or a tuple of observation arrays. 43 | 44 | If step_async is still doing work, that work will 45 | be cancelled and step_wait() should not be called 46 | until step_async() is invoked again. 47 | """ 48 | pass 49 | 50 | @abstractmethod 51 | def step_async(self, actions): 52 | """ 53 | Tell all the environments to start taking a step 54 | with the given actions. 55 | Call step_wait() to get the results of the step. 56 | 57 | You should not call this if a step_async run is 58 | already pending. 59 | """ 60 | pass 61 | 62 | @abstractmethod 63 | def step_wait(self): 64 | """ 65 | Wait for the step taken with step_async(). 66 | 67 | Returns (obs, rews, dones, infos): 68 | - obs: an array of observations, or a tuple of 69 | arrays of observations. 70 | - rews: an array of rewards 71 | - dones: an array of "episode done" booleans 72 | - infos: a sequence of info objects 73 | """ 74 | pass 75 | 76 | @abstractmethod 77 | def close(self): 78 | """ 79 | Clean up the environments' resources. 80 | """ 81 | pass 82 | 83 | def step(self, actions): 84 | self.step_async(actions) 85 | return self.step_wait() 86 | 87 | def render(self, mode='human'): 88 | logger.warn('Render not defined for %s'%self) 89 | 90 | @property 91 | def unwrapped(self): 92 | if isinstance(self, VecEnvWrapper): 93 | return self.venv.unwrapped 94 | else: 95 | return self 96 | 97 | class VecEnvWrapper(VecEnv): 98 | def __init__(self, venv, observation_space=None, action_space=None): 99 | self.venv = venv 100 | VecEnv.__init__(self, 101 | num_envs=venv.num_envs, 102 | observation_space=observation_space or venv.observation_space, 103 | action_space=action_space or venv.action_space) 104 | 105 | def step_async(self, actions): 106 | self.venv.step_async(actions) 107 | 108 | @abstractmethod 109 | def reset(self): 110 | pass 111 | 112 | @abstractmethod 113 | def step_wait(self): 114 | pass 115 | 116 | def close(self): 117 | return self.venv.close() 118 | 119 | def render(self): 120 | self.venv.render() 121 | 122 | class CloudpickleWrapper(object): 123 | """ 124 | Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle) 125 | """ 126 | def __init__(self, x): 127 | self.x = x 128 | def __getstate__(self): 129 | import cloudpickle 130 | return cloudpickle.dumps(self.x) 131 | def __setstate__(self, ob): 132 | import pickle 133 | self.x = pickle.loads(ob) 134 | -------------------------------------------------------------------------------- /third_party/baselines/common/vec_env/dummy_vec_env.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import numpy as np 7 | from gym import spaces 8 | from collections import OrderedDict 9 | from third_party.baselines.common.vec_env import VecEnv 10 | 11 | class DummyVecEnv(VecEnv): 12 | def __init__(self, env_fns): 13 | self.envs = [fn() for fn in env_fns] 14 | env = self.envs[0] 15 | VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space) 16 | shapes, dtypes = {}, {} 17 | self.keys = [] 18 | obs_space = env.observation_space 
19 | 20 | if isinstance(obs_space, spaces.Dict): 21 | assert isinstance(obs_space.spaces, OrderedDict) 22 | subspaces = obs_space.spaces 23 | else: 24 | subspaces = {None: obs_space} 25 | 26 | for key, box in subspaces.items(): 27 | shapes[key] = box.shape 28 | dtypes[key] = box.dtype 29 | self.keys.append(key) 30 | 31 | self.buf_obs = { k: np.zeros((self.num_envs,) + tuple(shapes[k]), dtype=dtypes[k]) for k in self.keys } 32 | self.buf_dones = np.zeros((self.num_envs,), dtype=np.bool) 33 | self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32) 34 | self.buf_infos = [{} for _ in range(self.num_envs)] 35 | self.actions = None 36 | 37 | def step_async(self, actions): 38 | self.actions = actions 39 | 40 | def step_wait(self): 41 | for e in range(self.num_envs): 42 | obs, self.buf_rews[e], self.buf_dones[e], self.buf_infos[e] = self.envs[e].step(self.actions[e]) 43 | if self.buf_dones[e]: 44 | obs = self.envs[e].reset() 45 | self._save_obs(e, obs) 46 | return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones), 47 | self.buf_infos.copy()) 48 | 49 | def reset(self): 50 | for e in range(self.num_envs): 51 | obs = self.envs[e].reset() 52 | self._save_obs(e, obs) 53 | return self._obs_from_buf() 54 | 55 | def close(self): 56 | return 57 | 58 | def render(self, mode='human'): 59 | return [e.render(mode=mode) for e in self.envs] 60 | 61 | def _save_obs(self, e, obs): 62 | for k in self.keys: 63 | if k is None: 64 | self.buf_obs[k][e] = obs 65 | else: 66 | self.buf_obs[k][e] = obs[k] 67 | 68 | def _obs_from_buf(self): 69 | if self.keys==[None]: 70 | return self.buf_obs[None] 71 | else: 72 | return self.buf_obs 73 | -------------------------------------------------------------------------------- /third_party/baselines/common/vec_env/subproc_vec_env.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import numpy as np 7 | from multiprocessing import Process, Pipe 8 | from third_party.baselines.common.vec_env import VecEnv, CloudpickleWrapper 9 | from third_party.baselines.common.tile_images import tile_images 10 | 11 | 12 | def worker(remote, parent_remote, env_fn_wrapper): 13 | parent_remote.close() 14 | env = env_fn_wrapper.x() 15 | while True: 16 | cmd, data = remote.recv() 17 | if cmd == 'step': 18 | ob, reward, done, info = env.step(data) 19 | if done: 20 | ob = env.reset() 21 | remote.send((ob, reward, done, info)) 22 | elif cmd == 'reset': 23 | ob = env.reset() 24 | remote.send(ob) 25 | elif cmd == 'render': 26 | remote.send(env.render(mode='rgb_array')) 27 | elif cmd == 'close': 28 | remote.close() 29 | break 30 | elif cmd == 'get_spaces': 31 | remote.send((env.observation_space, env.action_space)) 32 | else: 33 | raise NotImplementedError 34 | 35 | 36 | class SubprocVecEnv(VecEnv): 37 | def __init__(self, env_fns, spaces=None): 38 | """ 39 | envs: list of gym environments to run in subprocesses 40 | """ 41 | self.waiting = False 42 | self.closed = False 43 | nenvs = len(env_fns) 44 | self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)]) 45 | self.ps = [Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn))) 46 | for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)] 47 | for p in self.ps: 48 | p.daemon = True # if the main process crashes, we should not cause things to hang 49 | p.start() 50 | for remote in 
self.work_remotes: 51 | remote.close() 52 | 53 | self.remotes[0].send(('get_spaces', None)) 54 | observation_space, action_space = self.remotes[0].recv() 55 | VecEnv.__init__(self, len(env_fns), observation_space, action_space) 56 | 57 | def step_async(self, actions): 58 | for remote, action in zip(self.remotes, actions): 59 | remote.send(('step', action)) 60 | self.waiting = True 61 | 62 | def step_wait(self): 63 | results = [remote.recv() for remote in self.remotes] 64 | self.waiting = False 65 | obs, rews, dones, infos = zip(*results) 66 | return np.stack(obs), np.stack(rews), np.stack(dones), infos 67 | 68 | def reset(self): 69 | for remote in self.remotes: 70 | remote.send(('reset', None)) 71 | return np.stack([remote.recv() for remote in self.remotes]) 72 | 73 | def reset_task(self): 74 | for remote in self.remotes: 75 | remote.send(('reset_task', None)) 76 | return np.stack([remote.recv() for remote in self.remotes]) 77 | 78 | def close(self): 79 | if self.closed: 80 | return 81 | if self.waiting: 82 | for remote in self.remotes: 83 | remote.recv() 84 | for remote in self.remotes: 85 | remote.send(('close', None)) 86 | for p in self.ps: 87 | p.join() 88 | self.closed = True 89 | 90 | def render(self, mode='human'): 91 | for pipe in self.remotes: 92 | pipe.send(('render', None)) 93 | imgs = [pipe.recv() for pipe in self.remotes] 94 | bigimg = tile_images(imgs) 95 | if mode == 'human': 96 | import cv2 97 | cv2.imshow('vecenv', bigimg[:,:,::-1]) 98 | cv2.waitKey(1) 99 | elif mode == 'rgb_array': 100 | return bigimg 101 | else: 102 | raise NotImplementedError 103 | -------------------------------------------------------------------------------- /third_party/baselines/common/vec_env/threaded_vec_env.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | """VecEnv implementation using python threads instead of subprocesses.""" 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import os 9 | import threading 10 | from third_party.baselines.common.vec_env import VecEnv 11 | import numpy as np 12 | from six.moves import queue as Queue # pylint: disable=redefined-builtin 13 | 14 | 15 | def thread_worker(send_q, recv_q, env_fn): 16 | """Similar to SubprocVecEnv.worker(), but for TreadedVecEnv. 17 | 18 | Args: 19 | send_q: Queue which ThreadedVecEnv sends commands to. 20 | recv_q: Queue which ThreadedVecEnv receives commands from. 21 | env_fn: Callable that creates an instance of the environment. 22 | """ 23 | env = env_fn() 24 | while True: 25 | cmd, data = send_q.get() 26 | if cmd == 'step': 27 | ob, reward, done, info = env.step(data) 28 | if done: 29 | ob = env.reset() 30 | recv_q.put((ob, reward, done, info)) 31 | elif cmd == 'reset': 32 | ob = env.reset() 33 | recv_q.put(ob) 34 | elif cmd == 'render': 35 | recv_q.put(env.render(mode='rgb_array')) 36 | elif cmd == 'close': 37 | break 38 | elif cmd == 'get_spaces': 39 | recv_q.put((env.observation_space, env.action_space)) 40 | else: 41 | raise NotImplementedError 42 | 43 | 44 | class ThreadedVecEnv(VecEnv): 45 | """Similar to SubprocVecEnv, but uses python threads instead of subprocs. 46 | 47 | Sub-processes involve forks, and a lot of code (incl. google3's) is not 48 | fork-safe, leading to deadlocks. The drawback of python threads is that the 49 | python code is still executed serially because of the GIL. 
However, many 50 | environments do the heavy lifting in C++ (where the GIL is released, and 51 | hence execution can happen in parallel), so python threads are not often 52 | limiting. 53 | """ 54 | 55 | def __init__(self, env_fns, spaces=None): 56 | """ 57 | envs: list of gym environments to run in python threads. 58 | """ 59 | self.waiting = False 60 | self.closed = False 61 | nenvs = len(env_fns) 62 | self.send_queues = [Queue.Queue() for _ in range(nenvs)] 63 | self.recv_queues = [Queue.Queue() for _ in range(nenvs)] 64 | self.threads = [threading.Thread(target=thread_worker, 65 | args=(send_q, recv_q, env_fn)) 66 | for (send_q, recv_q, env_fn) in 67 | zip(self.send_queues, self.recv_queues, env_fns)] 68 | for thread in self.threads: 69 | thread.daemon = True 70 | thread.start() 71 | 72 | self.send_queues[0].put(('get_spaces', None)) 73 | observation_space, action_space = self.recv_queues[0].get() 74 | VecEnv.__init__(self, len(env_fns), observation_space, action_space) 75 | 76 | def step_async(self, actions): 77 | for send_q, action in zip(self.send_queues, actions): 78 | send_q.put(('step', action)) 79 | self.waiting = True 80 | 81 | def step_wait(self): 82 | results = self._receive_all() 83 | self.waiting = False 84 | obs, rews, dones, infos = zip(*results) 85 | return np.stack(obs), np.stack(rews), np.stack(dones), infos 86 | 87 | def reset(self): 88 | self._send_all(('reset', None)) 89 | return np.stack(self._receive_all()) 90 | 91 | def reset_task(self): 92 | self._send_all(('reset_task', None)) 93 | return np.stack(self._receive_all()) 94 | 95 | def close(self): 96 | if self.closed: 97 | return 98 | if self.waiting: 99 | self._receive_all() 100 | self._send_all(('close', None)) 101 | for thread in self.threads: 102 | thread.join() 103 | self.closed = True 104 | 105 | def render(self, mode='human'): 106 | raise NotImplementedError 107 | 108 | def _send_all(self, item): 109 | for send_q in self.send_queues: 110 | send_q.put(item) 111 | 112 | def _receive_all(self): 113 | return [recv_q.get() for recv_q in self.recv_queues] 114 | -------------------------------------------------------------------------------- /third_party/baselines/ppo2/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | -------------------------------------------------------------------------------- /third_party/baselines/ppo2/pathak_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | """Utility functions used by Pathak's curiosity algorithm. 
3 | """ 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import numpy as np 10 | import tensorflow as tf 11 | 12 | 13 | def conv2d(x, num_filters, name, filter_size=(3, 3), stride=(1, 1), pad='SAME', 14 | dtype=tf.float32, collections=None, trainable=True): 15 | with tf.variable_scope(name, reuse=tf.AUTO_REUSE): 16 | x = tf.to_float(x) 17 | stride_shape = [1, stride[0], stride[1], 1] 18 | filter_shape = [filter_size[0], filter_size[1], int(x.get_shape()[3]), 19 | num_filters] 20 | 21 | # there are 'num input feature maps * filter height * filter width' 22 | # inputs to each hidden unit 23 | fan_in = np.prod(filter_shape[:3]) 24 | # each unit in the lower layer receives a gradient from: 25 | # 'num output feature maps * filter height * filter width' / 26 | # pooling size 27 | fan_out = np.prod(filter_shape[:2]) * num_filters 28 | # initialize weights with random weights 29 | w_bound = np.sqrt(6. / (fan_in + fan_out)) 30 | 31 | w = tf.get_variable('W', filter_shape, dtype, 32 | tf.random_uniform_initializer(-w_bound, w_bound), 33 | collections=collections, trainable=trainable) 34 | b = tf.get_variable('b', [1, 1, 1, num_filters], 35 | initializer=tf.constant_initializer(0.0), 36 | collections=collections, trainable=trainable) 37 | return tf.nn.conv2d(x, w, stride_shape, pad) + b 38 | 39 | 40 | def flatten(x): 41 | return tf.reshape(x, [-1, np.prod(x.get_shape().as_list()[1:])]) 42 | 43 | 44 | def normalized_columns_initializer(std=1.0): 45 | def _initializer(shape, dtype=None, partition_info=None): 46 | out = np.random.randn(*shape).astype(np.float32) 47 | out *= std / np.sqrt(np.square(out).sum(axis=0, keepdims=True)) 48 | return tf.constant(out) 49 | return _initializer 50 | 51 | 52 | def linear(x, size, name, initializer=None, bias_init=0): 53 | w = tf.get_variable(name + '/w', [x.get_shape()[1], size], 54 | initializer=initializer) 55 | b = tf.get_variable(name + '/b', [size], 56 | initializer=tf.constant_initializer(bias_init)) 57 | return tf.matmul(x, w) + b 58 | 59 | 60 | def universeHead(x, nConvs=4, trainable=True): 61 | """Universe agent example. 62 | 63 | Args: 64 | x: input image 65 | nConvs: number of convolutional layers 66 | trainable: whether conv2d variables are trainable 67 | 68 | Returns: 69 | [None, 288] embedding 70 | """ 71 | print('Using universe head design') 72 | x = tf.image.resize_images(x, [42, 42]) 73 | x = tf.cast(x, tf.float32) / 255. 
74 | for i in range(nConvs): 75 | x = tf.nn.elu(conv2d(x, 32, 'l{}'.format(i + 1), [3, 3], [2, 2], 76 | trainable=trainable)) 77 | # print('Loop{} '.format(i+1),tf.shape(x)) 78 | # print('Loop{}'.format(i+1),x.get_shape()) 79 | x = flatten(x) 80 | return x 81 | 82 | 83 | def icm_forward_model(encoded_state, action, num_actions, hidden_layer_size): 84 | action = tf.one_hot(action, num_actions) 85 | combined_input = tf.concat([encoded_state, action], axis=1) 86 | hidden = tf.nn.relu(linear(combined_input, hidden_layer_size, 'f1', 87 | normalized_columns_initializer(0.01))) 88 | pred_next_state = linear(hidden, encoded_state.get_shape()[1].value, 'flast', 89 | normalized_columns_initializer(0.01)) 90 | return pred_next_state 91 | 92 | 93 | def icm_inverse_model(encoded_state, encoded_next_state, num_actions, 94 | hidden_layer_size): 95 | combined_input = tf.concat([encoded_state, encoded_next_state], axis=1) 96 | hidden = tf.nn.relu(linear(combined_input, hidden_layer_size, 'g1', 97 | normalized_columns_initializer(0.01))) 98 | # Predicted action logits 99 | logits = linear(hidden, num_actions, 'glast', 100 | normalized_columns_initializer(0.01)) 101 | return logits 102 | -------------------------------------------------------------------------------- /third_party/baselines/ppo2/policies.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | from third_party.baselines.a2c.utils import conv, fc, conv_to_fc, batch_to_seq, seq_to_batch, lstm, lnlstm 7 | from third_party.baselines.common.distributions import make_pdtype 8 | from third_party.baselines.common.input import observation_input 9 | import numpy as np 10 | import tensorflow as tf 11 | 12 | def nature_cnn(unscaled_images, **conv_kwargs): 13 | """ 14 | CNN from Nature paper. 15 | """ 16 | scaled_images = tf.cast(unscaled_images, tf.float32) / 255. 
17 | activ = tf.nn.relu 18 | h = activ(conv(scaled_images, 'c1', nf=32, rf=8, stride=4, init_scale=np.sqrt(2), 19 | **conv_kwargs)) 20 | h2 = activ(conv(h, 'c2', nf=64, rf=4, stride=2, init_scale=np.sqrt(2), **conv_kwargs)) 21 | h3 = activ(conv(h2, 'c3', nf=64, rf=3, stride=1, init_scale=np.sqrt(2), **conv_kwargs)) 22 | h3 = conv_to_fc(h3) 23 | return activ(fc(h3, 'fc1', nh=512, init_scale=np.sqrt(2))) 24 | 25 | class LnLstmPolicy(object): 26 | def __init__(self, sess, ob_space, ac_space, nbatch, nsteps, nlstm=256, reuse=False): 27 | nenv = nbatch // nsteps 28 | X, processed_x = observation_input(ob_space, nbatch) 29 | M = tf.placeholder(tf.float32, [nbatch]) #mask (done t-1) 30 | S = tf.placeholder(tf.float32, [nenv, nlstm*2]) #states 31 | self.pdtype = make_pdtype(ac_space) 32 | with tf.variable_scope("model", reuse=reuse): 33 | h = nature_cnn(processed_x) 34 | xs = batch_to_seq(h, nenv, nsteps) 35 | ms = batch_to_seq(M, nenv, nsteps) 36 | h5, snew = lnlstm(xs, ms, S, 'lstm1', nh=nlstm) 37 | h5 = seq_to_batch(h5) 38 | vf = fc(h5, 'v', 1) 39 | self.pd, self.pi = self.pdtype.pdfromlatent(h5) 40 | 41 | v0 = vf[:, 0] 42 | a0 = self.pd.sample() 43 | neglogp0 = self.pd.neglogp(a0) 44 | self.initial_state = np.zeros((nenv, nlstm*2), dtype=np.float32) 45 | 46 | def step(ob, state, mask): 47 | return sess.run([a0, v0, snew, neglogp0], {X:ob, S:state, M:mask}) 48 | 49 | def value(ob, state, mask): 50 | return sess.run(v0, {X:ob, S:state, M:mask}) 51 | 52 | self.X = X 53 | self.M = M 54 | self.S = S 55 | self.vf = vf 56 | self.step = step 57 | self.value = value 58 | 59 | class LstmPolicy(object): 60 | 61 | def __init__(self, sess, ob_space, ac_space, nbatch, nsteps, nlstm=256, reuse=False): 62 | nenv = nbatch // nsteps 63 | self.pdtype = make_pdtype(ac_space) 64 | X, processed_x = observation_input(ob_space, nbatch) 65 | 66 | M = tf.placeholder(tf.float32, [nbatch]) #mask (done t-1) 67 | S = tf.placeholder(tf.float32, [nenv, nlstm*2]) #states 68 | with tf.variable_scope("model", reuse=reuse): 69 | h = nature_cnn(X) 70 | xs = batch_to_seq(h, nenv, nsteps) 71 | ms = batch_to_seq(M, nenv, nsteps) 72 | h5, snew = lstm(xs, ms, S, 'lstm1', nh=nlstm) 73 | h5 = seq_to_batch(h5) 74 | vf = fc(h5, 'v', 1) 75 | self.pd, self.pi = self.pdtype.pdfromlatent(h5) 76 | 77 | v0 = vf[:, 0] 78 | a0 = self.pd.sample() 79 | neglogp0 = self.pd.neglogp(a0) 80 | self.initial_state = np.zeros((nenv, nlstm*2), dtype=np.float32) 81 | 82 | def step(ob, state, mask): 83 | return sess.run([a0, v0, snew, neglogp0], {X:ob, S:state, M:mask}) 84 | 85 | def value(ob, state, mask): 86 | return sess.run(v0, {X:ob, S:state, M:mask}) 87 | 88 | self.X = X 89 | self.M = M 90 | self.S = S 91 | self.vf = vf 92 | self.step = step 93 | self.value = value 94 | 95 | class CnnPolicy(object): 96 | 97 | def __init__(self, sess, ob_space, ac_space, nbatch, nsteps, reuse=False, **conv_kwargs): #pylint: disable=W0613 98 | self.pdtype = make_pdtype(ac_space) 99 | X, processed_x = observation_input(ob_space, nbatch) 100 | with tf.variable_scope("model", reuse=reuse): 101 | h = nature_cnn(processed_x, **conv_kwargs) 102 | vf = fc(h, 'v', 1)[:,0] 103 | self.pd, self.pi = self.pdtype.pdfromlatent(h, init_scale=0.01) 104 | 105 | a0 = self.pd.sample() 106 | neglogp0 = self.pd.neglogp(a0) 107 | self.initial_state = None 108 | 109 | def step(ob, *_args, **_kwargs): 110 | a, v, neglogp = sess.run([a0, vf, neglogp0], {X:ob}) 111 | return a, v, self.initial_state, neglogp 112 | 113 | def value(ob, *_args, **_kwargs): 114 | return sess.run(vf, {X:ob}) 115 | 116 | 
self.X = X 117 | self.vf = vf 118 | self.step = step 119 | self.value = value 120 | 121 | class MlpPolicy(object): 122 | def __init__(self, sess, ob_space, ac_space, nbatch, nsteps, reuse=False): #pylint: disable=W0613 123 | self.pdtype = make_pdtype(ac_space) 124 | with tf.variable_scope("model", reuse=reuse): 125 | X, processed_x = observation_input(ob_space, nbatch) 126 | activ = tf.tanh 127 | processed_x = tf.layers.flatten(processed_x) 128 | pi_h1 = activ(fc(processed_x, 'pi_fc1', nh=64, init_scale=np.sqrt(2))) 129 | pi_h2 = activ(fc(pi_h1, 'pi_fc2', nh=64, init_scale=np.sqrt(2))) 130 | vf_h1 = activ(fc(processed_x, 'vf_fc1', nh=64, init_scale=np.sqrt(2))) 131 | vf_h2 = activ(fc(vf_h1, 'vf_fc2', nh=64, init_scale=np.sqrt(2))) 132 | vf = fc(vf_h2, 'vf', 1)[:,0] 133 | 134 | self.pd, self.pi = self.pdtype.pdfromlatent(pi_h2, init_scale=0.01) 135 | 136 | 137 | a0 = self.pd.sample() 138 | neglogp0 = self.pd.neglogp(a0) 139 | self.initial_state = None 140 | 141 | def step(ob, *_args, **_kwargs): 142 | a, v, neglogp = sess.run([a0, vf, neglogp0], {X:ob}) 143 | return a, v, self.initial_state, neglogp 144 | 145 | def value(ob, *_args, **_kwargs): 146 | return sess.run(vf, {X:ob}) 147 | 148 | self.X = X 149 | self.vf = vf 150 | self.step = step 151 | self.value = value 152 | -------------------------------------------------------------------------------- /third_party/dmlab/dmlab_min_goal_distance.patch: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2016 Google Inc. 2 | # Copyright (C) 2018 Google Inc. 3 | # Copyright (C) 2019 Google Inc. 4 | # 5 | # This program is free software; you can redistribute it and/or modify 6 | # it under the terms of the GNU General Public License as published by 7 | # the Free Software Foundation; either version 2 of the License, or 8 | # (at your option) any later version. 9 | # 10 | # This program is distributed in the hope that it will be useful, 11 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | # GNU General Public License for more details. 14 | # 15 | # You should have received a copy of the GNU General Public License along 16 | # with this program; if not, write to the Free Software Foundation, Inc., 17 | # 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 18 | # 19 | diff --git a/deepmind/engine/lua_maze_generation.cc b/deepmind/engine/lua_maze_generation.cc 20 | index 8c1b9a9..23c08f7 100644 21 | --- a/deepmind/engine/lua_maze_generation.cc 22 | +++ b/deepmind/engine/lua_maze_generation.cc 23 | @@ -242,6 +242,8 @@ lua::NResultsOr LuaMazeGeneration::CreateRandom(lua_State* L) { 24 | 25 | maze_generation::TextMaze maze(maze_generation::Size{height, width}); 26 | 27 | + table.LookUp("minGoalDistance", &maze.min_goal_distance_); 28 | + 29 | // Create random rooms. 30 | maze_generation::SeparateRectangleParams params{}; 31 | params.min_size = maze_generation::Size{room_min_size, room_min_size}; 32 | @@ -602,6 +604,14 @@ lua::NResultsOr LuaMazeGeneration::VisitRandomPath(lua_State* L) { 33 | return 1; 34 | } 35 | 36 | +lua::NResultsOr LuaMazeGeneration::MinGoalDistance(lua_State* L) { 37 | + if (lua_gettop(L) != 1) { 38 | + return "[MinGoalDistance] - No args expected"; 39 | + } 40 | + lua::Push(L, text_maze_.min_goal_distance_); 41 | + return 1; 42 | +} 43 | + 44 | // Registers classes metatable with Lua. 
45 | // [0, 0, -] 46 | void LuaMazeGeneration::Register(lua_State* L) { 47 | @@ -625,6 +635,7 @@ void LuaMazeGeneration::Register(lua_State* L) { 48 | {"fromWorldPos", Class::Member<&LuaMazeGeneration::FromWorldPos>}, 49 | {"visitFill", Class::Member<&LuaMazeGeneration::VisitFill>}, 50 | {"visitRandomPath", Class::Member<&LuaMazeGeneration::VisitRandomPath>}, 51 | + {"minGoalDistance", Class::Member<&LuaMazeGeneration::MinGoalDistance>}, 52 | }; 53 | Class::Register(L, regs); 54 | LuaRoom::Register(L); 55 | diff --git a/deepmind/engine/lua_maze_generation.h b/deepmind/engine/lua_maze_generation.h 56 | index 5e83253..a2915f2 100644 57 | --- a/deepmind/engine/lua_maze_generation.h 58 | +++ b/deepmind/engine/lua_maze_generation.h 59 | @@ -196,6 +196,9 @@ class LuaMazeGeneration : public lua::Class { 60 | // [1, 1, e] 61 | lua::NResultsOr CountVariations(lua_State* L); 62 | 63 | + // Minimum distance between a player spawn and a goal. 64 | + lua::NResultsOr MinGoalDistance(lua_State* L); 65 | + 66 | maze_generation::TextMaze text_maze_; 67 | 68 | static std::uint64_t mixer_seq_; 69 | diff --git a/deepmind/level_generation/text_maze_generation/text_maze.cc b/deepmind/level_generation/text_maze_generation/text_maze.cc 70 | index cc5234e..0d50b12 100644 71 | --- a/deepmind/level_generation/text_maze_generation/text_maze.cc 72 | +++ b/deepmind/level_generation/text_maze_generation/text_maze.cc 73 | @@ -22,7 +22,7 @@ namespace deepmind { 74 | namespace lab { 75 | namespace maze_generation { 76 | 77 | -TextMaze::TextMaze(Size extents) : area_{{0, 0}, extents} { 78 | +TextMaze::TextMaze(Size extents) : area_{{0, 0}, extents}, min_goal_distance_(0) { 79 | std::string level_layer(area_.size.height * (area_.size.width + 1), '*'); 80 | std::string variations_layer(area_.size.height * (area_.size.width + 1), '.'); 81 | for (int i = 0; i < area_.size.height; ++i) { 82 | diff --git a/deepmind/level_generation/text_maze_generation/text_maze.h b/deepmind/level_generation/text_maze_generation/text_maze.h 83 | index 6b32a6d..bb7396c 100644 84 | --- a/deepmind/level_generation/text_maze_generation/text_maze.h 85 | +++ b/deepmind/level_generation/text_maze_generation/text_maze.h 86 | @@ -242,6 +242,9 @@ class TextMaze { 87 | Rectangle area_; 88 | std::array text_; 89 | std::vector ids_; 90 | + 91 | + public: 92 | + int min_goal_distance_; 93 | }; 94 | 95 | } // namespace maze_generation 96 | diff --git a/game_scripts/factories/explore/factory.lua b/game_scripts/factories/explore/factory.lua 97 | index 2352483..57e759d 100644 98 | --- a/game_scripts/factories/explore/factory.lua 99 | +++ b/game_scripts/factories/explore/factory.lua 100 | @@ -69,6 +69,7 @@ function factory.createLevelApi(kwargs) 101 | kwargs.opts.quickRestart = true 102 | end 103 | kwargs.opts.randomSeed = false 104 | + kwargs.opts.minGoalDistance = kwargs.opts.minGoalDistance or 0 105 | kwargs.opts.roomCount = kwargs.opts.roomCount or 4 106 | kwargs.opts.roomMaxSize = kwargs.opts.roomMaxSize or 5 107 | kwargs.opts.roomMinSize = kwargs.opts.roomMinSize or 3 108 | @@ -115,6 +116,7 @@ function factory.createLevelApi(kwargs) 109 | roomMinSize = kwargs.opts.roomMinSize, 110 | roomMaxSize = kwargs.opts.roomMaxSize, 111 | extraConnectionProbability = kwargs.opts.extraConnectionProbability, 112 | + minGoalDistance = kwargs.opts.minGoalDistance, 113 | } 114 | 115 | if kwargs.opts.decalScale and kwargs.opts.decalScale ~= 1 then 116 | diff --git a/game_scripts/factories/explore/goal_locations_factory.lua 
b/game_scripts/factories/explore/goal_locations_factory.lua 117 | index d3dd9bb..86e7db5 100644 118 | --- a/game_scripts/factories/explore/goal_locations_factory.lua 119 | +++ b/game_scripts/factories/explore/goal_locations_factory.lua 120 | @@ -31,9 +31,9 @@ function level:restart(maze) 121 | local spawnLocations = {} 122 | maze:visitFill{ 123 | cell = level._goalLocations[level._goalId], 124 | - func = function(i, j) 125 | + func = function(i, j, distance) 126 | local c = maze:getEntityCell(i, j) 127 | - if c == 'P' then 128 | + if c == 'P' and distance >= maze:minGoalDistance() then 129 | table.insert(spawnLocations, {i, j}) 130 | end 131 | end 132 | -------------------------------------------------------------------------------- /third_party/gym/LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License 2 | 3 | Copyright (c) 2017 OpenAI (http://openai.com) 4 | Copyright (c) 2018 The TF-Agents Authors. 5 | Copyright (c) 2018 Google LLC (http://google.com) 6 | 7 | Permission is hereby granted, free of charge, to any person obtaining a copy 8 | of this software and associated documentation files (the "Software"), to deal 9 | in the Software without restriction, including without limitation the rights 10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | copies of the Software, and to permit persons to whom the Software is 12 | furnished to do so, subject to the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be included in 15 | all copies or substantial portions of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 23 | THE SOFTWARE. 24 | -------------------------------------------------------------------------------- /third_party/gym/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | -------------------------------------------------------------------------------- /third_party/gym/ant.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # The MIT License 3 | # 4 | # Copyright (c) 2016 OpenAI (https://openai.com) 5 | # Copyright (c) 2018 The TF-Agents Authors. 6 | # Copyright (c) 2018 Google LLC (http://google.com) 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to the following conditions: 14 | # 15 | # The above copyright notice and this permission notice shall be included in 16 | # all copies or substantial portions of the Software. 
17 | # 18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 24 | # THE SOFTWARE. 25 | 26 | """Ant environment.""" 27 | 28 | from __future__ import absolute_import 29 | from __future__ import division 30 | from __future__ import print_function 31 | 32 | from dm_control.mujoco.wrapper.mjbindings import enums 33 | from third_party.gym import mujoco_env 34 | from gym import utils 35 | import numpy as np 36 | 37 | 38 | # pylint: disable=missing-docstring 39 | class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle): 40 | 41 | def __init__(self, 42 | expose_all_qpos=False, 43 | expose_body_coms=None, 44 | expose_body_comvels=None, 45 | model_path="ant.xml"): 46 | self._expose_all_qpos = expose_all_qpos 47 | self._expose_body_coms = expose_body_coms 48 | self._expose_body_comvels = expose_body_comvels 49 | self._body_com_indices = {} 50 | self._body_comvel_indices = {} 51 | # Settings from 52 | # https://github.com/openai/gym/blob/master/gym/envs/__init__.py 53 | mujoco_env.MujocoEnv.__init__( 54 | self, model_path, 5, max_episode_steps=1000, reward_threshold=6000.0) 55 | utils.EzPickle.__init__(self) 56 | 57 | self.camera_setup() 58 | 59 | def step(self, a): 60 | xposbefore = self.get_body_com("torso")[0] 61 | self.do_simulation(a, self.frame_skip) 62 | xposafter = self.get_body_com("torso")[0] 63 | forward_reward = (xposafter - xposbefore) / self.dt 64 | ctrl_cost = .5 * np.square(a).sum() 65 | contact_cost = 0.5 * 1e-3 * np.sum( 66 | np.square(np.clip(self.physics.data.cfrc_ext, -1, 1))) 67 | survive_reward = 1.0 68 | reward = forward_reward - ctrl_cost - contact_cost + survive_reward 69 | state = self.state_vector() 70 | notdone = np.isfinite(state).all() and state[2] >= 0.2 and state[2] <= 1.0 71 | done = not notdone 72 | ob = self._get_obs() 73 | return ob, reward, done, dict( 74 | reward_forward=forward_reward, 75 | reward_ctrl=-ctrl_cost, 76 | reward_contact=-contact_cost, 77 | reward_survive=survive_reward) 78 | 79 | def _get_obs(self): 80 | if self._expose_all_qpos: 81 | obs = np.concatenate([ 82 | self.physics.data.qpos.flat, 83 | self.physics.data.qvel.flat, 84 | np.clip(self.physics.data.cfrc_ext, -1, 1).flat, 85 | ]) 86 | else: 87 | obs = np.concatenate([ 88 | self.physics.data.qpos.flat[2:], 89 | self.physics.data.qvel.flat, 90 | np.clip(self.physics.data.cfrc_ext, -1, 1).flat, 91 | ]) 92 | 93 | if self._expose_body_coms is not None: 94 | for name in self._expose_body_coms: 95 | com = self.get_body_com(name) 96 | if name not in self._body_com_indices: 97 | indices = range(len(obs), len(obs) + len(com)) 98 | self._body_com_indices[name] = indices 99 | obs = np.concatenate([obs, com]) 100 | 101 | if self._expose_body_comvels is not None: 102 | for name in self._expose_body_comvels: 103 | comvel = self.get_body_comvel(name) 104 | if name not in self._body_comvel_indices: 105 | indices = range(len(obs), len(obs) + len(comvel)) 106 | self._body_comvel_indices[name] = indices 107 | obs = np.concatenate([obs, comvel]) 108 | return obs 109 | 110 | def reset_model(self): 111 | qpos = self.init_qpos + self.np_random.uniform( 112 | size=self.physics.model.nq, low=-.1, 
high=.1) 113 | qvel = self.init_qvel + self.np_random.randn(self.physics.model.nv) * .1 114 | self.set_state(qpos, qvel) 115 | return self._get_obs() 116 | 117 | def camera_setup(self): 118 | # pylint: disable=protected-access 119 | self.camera._render_camera.type_ = enums.mjtCamera.mjCAMERA_TRACKING 120 | # pylint: disable=protected-access 121 | self.camera._render_camera.trackbodyid = 1 122 | # pylint: disable=protected-access 123 | self.camera._render_camera.distance = self.physics.model.stat.extent 124 | 125 | @property 126 | def body_com_indices(self): 127 | return self._body_com_indices 128 | 129 | @property 130 | def body_comvel_indices(self): 131 | return self._body_comvel_indices 132 | -------------------------------------------------------------------------------- /third_party/gym/ant_wrapper_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # The MIT License 3 | # 4 | # Copyright (c) 2016 OpenAI (https://openai.com) 5 | # Copyright (c) 2018 The TF-Agents Authors. 6 | # Copyright (c) 2018 Google LLC (http://google.com) 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to the following conditions: 14 | # 15 | # The above copyright notice and this permission notice shall be included in 16 | # all copies or substantial portions of the Software. 17 | # 18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 24 | # THE SOFTWARE.
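# NOTE: this test appears to depend on google3-internal tooling
# (google3.pyglib.resources to locate the bundled XML/texture assets and
# google3.testing.pybase.googletest as the test runner). Running it elsewhere
# would presumably mean substituting plain file paths and the standard
# unittest module -- an assumption, since the repository does not ship such a
# shim.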
25 | """Tests for google3.third_party.py.third_party.gym.ant_wrapper.""" 26 | 27 | from __future__ import absolute_import 28 | from __future__ import division 29 | from __future__ import print_function 30 | 31 | import os 32 | from third_party.gym import ant_wrapper 33 | from google3.pyglib import resources 34 | from google3.testing.pybase import googletest 35 | 36 | ASSETS_DIR = 'google3/third_party/py/third_party.gym/assets' 37 | 38 | 39 | def get_resource(filename): 40 | return resources.GetResourceFilenameInDirectoryTree( 41 | os.path.join(ASSETS_DIR, filename)) 42 | 43 | 44 | class AntWrapperTest(googletest.TestCase): 45 | 46 | def test_ant_wrapper(self): 47 | env = ant_wrapper.AntWrapper( 48 | get_resource('mujoco_ant_custom_texture_camerav2.xml'), 49 | texture_mode='fixed', 50 | texture_file_pattern=get_resource('texture.png')) 51 | env.reset() 52 | obs, unused_reward, unused_done, info = env.step(env.action_space.sample()) 53 | self.assertEqual(obs.shape, (27,)) 54 | self.assertIn('frame', info) 55 | self.assertEqual(info['frame'].shape, 56 | (120, 160, 3)) 57 | 58 | 59 | if __name__ == '__main__': 60 | googletest.main() 61 | -------------------------------------------------------------------------------- /third_party/gym/assets/mujoco_ant_custom_texture_camerav2.xml: -------------------------------------------------------------------------------- 1 | 26 | 27 | 28 | 109 | -------------------------------------------------------------------------------- /third_party/gym/assets/texture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/episodic-curiosity/3c406964473d98fb977b1617a170a447b3c548fd/third_party/gym/assets/texture.png -------------------------------------------------------------------------------- /third_party/gym/mujoco_env.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # The MIT License 3 | # 4 | # Copyright (c) 2016 OpenAI (https://openai.com) 5 | # Copyright (c) 2018 The TF-Agents Authors. 6 | # Copyright (c) 2018 Google LLC (http://google.com) 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to the following conditions: 14 | # 15 | # The above copyright notice and this permission notice shall be included in 16 | # all copies or substantial portions of the Software. 17 | # 18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 24 | # THE SOFTWARE. 25 | 26 | """Port of the openai gym mujoco envs to run with dm_control. 27 | 28 | In this process / port, the simulators are not the same exactly. 
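(Concretely, CustomPhysics.step below advances the simulation with mj_step
once per substep and, for non-Euler integrators, an additional mj_step1 so
that position- and velocity-dependent fields stay current.)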
29 | Possible reasons are how the integrators are run -- 30 | deepmind uses mj_step / mj_step1 / mj_step2 where as 31 | openai just uses mj_step. 32 | Openai does a compute subtree which I am not doing here 33 | 34 | In addition, I am not 100% confident in how often the 'MjData' is synced, 35 | as a result there could be a 1 frame offset. 36 | Finally, dm_control code appears to do resets slightly differently and use 37 | different mj functions. 38 | """ 39 | 40 | from __future__ import absolute_import 41 | from __future__ import division 42 | from __future__ import print_function 43 | 44 | import os 45 | 46 | from dm_control import mujoco 47 | from dm_control.mujoco.wrapper.mjbindings import enums 48 | from dm_control.mujoco.wrapper.mjbindings.functions import mjlib 49 | import gym 50 | from gym import spaces 51 | from gym.utils import seeding 52 | import numpy as np 53 | from six.moves import xrange 54 | 55 | 56 | class CustomPhysics(mujoco.Physics): 57 | 58 | def step(self, n_sub_steps=1): 59 | """Advances physics with up-to-date position and velocity dependent fields. 60 | 61 | The actuation can be updated by calling the `set_control` function first. 62 | 63 | Args: 64 | n_sub_steps: Optional number of times to advance the physics. Default 1. 65 | """ 66 | 67 | # This does not line up to how openai does things but instead to how 68 | # deepmind does. 69 | # There are configurations where environments like half cheetah become 70 | # unstable. 71 | # This is a rough proxy for what is supposed to happen but not perfect. 72 | for _ in xrange(n_sub_steps): 73 | mjlib.mj_step(self.model.ptr, self.data.ptr) 74 | 75 | if self.model.opt.integrator != enums.mjtIntegrator.mjINT_EULER: 76 | mjlib.mj_step1(self.model.ptr, self.data.ptr) 77 | 78 | 79 | class MujocoEnv(gym.Env): 80 | """Superclass MuJoCo environments modified to use deepmind's mujoco wrapper. 81 | """ 82 | 83 | def __init__(self, 84 | model_path, 85 | frame_skip, 86 | max_episode_steps=None, 87 | reward_threshold=None): 88 | if not os.path.exists(model_path): 89 | raise IOError('File %s does not exist' % model_path) 90 | 91 | self.frame_skip = frame_skip 92 | self.physics = CustomPhysics.from_xml_path(model_path) 93 | self.camera = mujoco.MovableCamera(self.physics, height=480, width=640) 94 | 95 | self.viewer = None 96 | 97 | self.metadata = { 98 | 'render.modes': ['human', 'rgb_array'], 99 | 'video.frames_per_second': int(np.round(1.0 / self.dt)) 100 | } 101 | 102 | self.init_qpos = self.physics.data.qpos.ravel().copy() 103 | self.init_qvel = self.physics.data.qvel.ravel().copy() 104 | observation, _, done, _ = self.step(np.zeros(self.physics.model.nu)) 105 | assert not done 106 | self.obs_dim = observation.size 107 | 108 | bounds = self.physics.model.actuator_ctrlrange.copy() 109 | low = bounds[:, 0] 110 | high = bounds[:, 1] 111 | self.action_space = spaces.Box(low, high, dtype=np.float32) 112 | 113 | high = np.inf * np.ones(self.obs_dim) 114 | low = -high 115 | self.observation_space = spaces.Box(low, high, dtype=np.float32) 116 | 117 | self.max_episode_steps = max_episode_steps 118 | self.reward_threshold = reward_threshold 119 | 120 | self.seed() 121 | self.camera_setup() 122 | 123 | def seed(self, seed=None): 124 | self.np_random, seed = seeding.np_random(seed) 125 | return [seed] 126 | 127 | def reset_model(self): 128 | """Reset the robot degrees of freedom (qpos and qvel). 129 | 130 | Implement this in each subclass. 
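For example, AntEnv.reset_model in third_party/gym/ant.py perturbs init_qpos
with small uniform noise and init_qvel with small Gaussian noise, calls
set_state, and returns self._get_obs().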
131 | """ 132 | raise NotImplementedError() 133 | 134 | def viewer_setup(self): 135 | """This method is called when the viewer is initialized and after all reset. 136 | 137 | Optionally implement this method, if you need to tinker with camera 138 | position 139 | and so forth. 140 | """ 141 | pass 142 | 143 | def reset(self): 144 | mjlib.mj_resetData(self.physics.model.ptr, self.physics.data.ptr) 145 | ob = self.reset_model() 146 | return ob 147 | 148 | def set_state(self, qpos, qvel): 149 | assert qpos.shape == (self.physics.model.nq,) and qvel.shape == ( 150 | self.physics.model.nv,) 151 | assert self.physics.get_state().size == qpos.size + qvel.size 152 | state = np.concatenate([qpos, qvel], 0) 153 | with self.physics.reset_context(): 154 | self.physics.set_state(state) 155 | 156 | @property 157 | def dt(self): 158 | return self.physics.model.opt.timestep * self.frame_skip 159 | 160 | def do_simulation(self, ctrl, n_frames): 161 | self.physics.set_control(ctrl) 162 | for _ in range(n_frames): 163 | self.physics.step() 164 | 165 | def render(self, mode='human'): 166 | if mode == 'rgb_array': 167 | data = self.camera.render() 168 | return np.copy(data) # render reuses the same memory space. 169 | elif mode == 'human': 170 | raise NotImplementedError( 171 | 'Currently no interactive renderings are allowed.') 172 | 173 | def get_body_com(self, body_name): 174 | idx = self.physics.model.name2id(body_name, 1) 175 | return self.physics.data.subtree_com[idx] 176 | 177 | def get_body_comvel(self, body_name): 178 | # As of MuJoCo v2.0, updates to `mjData->subtree_linvel` will be skipped 179 | # unless these quantities are needed by the simulation. We therefore call 180 | # `mj_subtreeVel` to update them explicitly. 181 | mjlib.mj_subtreeVel(self.physics.model.ptr, self.physics.data.ptr) 182 | idx = self.physics.model.name2id(body_name, 1) 183 | return self.physics.data.subtree_linvel[idx] 184 | 185 | def get_body_xmat(self, body_name): 186 | raise NotImplementedError() 187 | 188 | def state_vector(self): 189 | return np.concatenate( 190 | [self.physics.data.qpos.flat, self.physics.data.qvel.flat]) 191 | 192 | def get_state(self): 193 | return np.array(self.physics.data.qpos.flat), np.array( 194 | self.physics.data.qvel.flat) 195 | 196 | def camera_setup(self): 197 | pass # override this to set up camera 198 | -------------------------------------------------------------------------------- /third_party/keras_resnet/LICENSE: -------------------------------------------------------------------------------- 1 | COPYRIGHT 2 | 3 | All contributions by Raghavendra Kotikalapudi: 4 | Copyright (c) 2016, Raghavendra Kotikalapudi. 5 | All rights reserved. 6 | 7 | All other contributions: 8 | Copyright (c) 2016, the respective contributors. 9 | All rights reserved. 10 | 11 | Copyright (c) 2018 Google LLC 12 | All rights reserved. 13 | 14 | Each contributor holds copyright over their respective contributions. 15 | The project versioning (Git) records all such contribution source information. 
16 | 17 | LICENSE 18 | 19 | The MIT License (MIT) 20 | 21 | Permission is hereby granted, free of charge, to any person obtaining a copy 22 | of this software and associated documentation files (the "Software"), to deal 23 | in the Software without restriction, including without limitation the rights 24 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 25 | copies of the Software, and to permit persons to whom the Software is 26 | furnished to do so, subject to the following conditions: 27 | 28 | The above copyright notice and this permission notice shall be included in all 29 | copies or substantial portions of the Software. 30 | 31 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 32 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 33 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 34 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 35 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 36 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 37 | SOFTWARE. 38 | -------------------------------------------------------------------------------- /third_party/keras_resnet/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | --------------------------------------------------------------------------------
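A rough usage sketch of the Ant environment above, assuming dm_control and gym are installed and that a MuJoCo ant MJCF file is available locally. The path below is a placeholder (the custom model under third_party/gym/assets is one candidate); MujocoEnv.__init__ raises IOError if the file does not exist.

# Hypothetical driver script (sketch); not shipped with the repository.
from third_party.gym import ant

# Placeholder model path (assumption): any ant MJCF with a body named "torso".
env = ant.AntEnv(
    model_path='third_party/gym/assets/mujoco_ant_custom_texture_camerav2.xml')

obs = env.reset()
for _ in range(100):
    action = env.action_space.sample()  # uniform sample within actuator ctrlrange
    # reward = forward progress - control cost - contact cost + survive bonus
    obs, reward, done, info = env.step(action)
    if done:  # episode ends when the torso height leaves [0.2, 1.0] or the state is non-finite
        obs = env.reset()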