├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── colab
│   └── plot_training_graphs.ipynb
├── episodic_curiosity
│   ├── __init__.py
│   ├── constants.py
│   ├── constants_test.py
│   ├── curiosity_env_wrapper.py
│   ├── curiosity_env_wrapper_test.py
│   ├── curiosity_evaluation.py
│   ├── curiosity_evaluation_test.py
│   ├── env_factory.py
│   ├── env_factory_test.py
│   ├── environments
│   │   ├── __init__.py
│   │   ├── dmlab_utils.py
│   │   └── fake_gym_env.py
│   ├── episodic_memory.py
│   ├── episodic_memory_test.py
│   ├── eval_policy.py
│   ├── generate_r_training_data.py
│   ├── generate_r_training_data_test.py
│   ├── keras_checkpoint.py
│   ├── logging.py
│   ├── oracle.py
│   ├── r_network.py
│   ├── r_network_training.py
│   ├── r_network_training_test.py
│   ├── train_policy.py
│   ├── train_r.py
│   ├── train_r_test.py
│   ├── utils.py
│   └── visualize_curiosity_reward.py
├── misc
│   ├── ant_github.gif
│   └── navigation_github.gif
├── scripts
│   ├── gs_sync.py
│   ├── launch_cloud_vms.py
│   ├── launcher_script.py
│   ├── vm_drop_root.sh
│   └── vm_start.sh
├── setup.py
└── third_party
    ├── __init__.py
    ├── baselines
    │   ├── LICENSE
    │   ├── __init__.py
    │   ├── a2c
    │   │   ├── __init__.py
    │   │   └── utils.py
    │   ├── bench
    │   │   ├── __init__.py
    │   │   └── monitor.py
    │   ├── common
    │   │   ├── __init__.py
    │   │   ├── atari_wrappers.py
    │   │   ├── cmd_util.py
    │   │   ├── distributions.py
    │   │   ├── input.py
    │   │   ├── math_util.py
    │   │   ├── misc_util.py
    │   │   ├── runners.py
    │   │   ├── tf_util.py
    │   │   ├── tile_images.py
    │   │   └── vec_env
    │   │       ├── __init__.py
    │   │       ├── dummy_vec_env.py
    │   │       ├── subproc_vec_env.py
    │   │       └── threaded_vec_env.py
    │   ├── logger.py
    │   └── ppo2
    │       ├── __init__.py
    │       ├── pathak_utils.py
    │       ├── policies.py
    │       └── ppo2.py
    ├── dmlab
    │   ├── LICENSE
    │   └── dmlab_min_goal_distance.patch
    ├── gym
    │   ├── LICENSE
    │   ├── __init__.py
    │   ├── ant.py
    │   ├── ant_wrapper.py
    │   ├── ant_wrapper_test.py
    │   ├── assets
    │   │   ├── mujoco_ant_custom_texture_camerav2.xml
    │   │   └── texture.png
    │   └── mujoco_env.py
    └── keras_resnet
        ├── LICENSE
        ├── __init__.py
        └── models.py
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | You are welcome to contribute patches to this project. There are
4 | just a few small guidelines you need to follow.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement. You (or your employer) retain the copyright to your contribution;
10 | this simply gives us permission to use and redistribute your contributions as
11 | part of the project. Head over to <https://cla.developers.google.com/> to see
12 | your current agreements on file or to sign a new one.
13 |
14 | You generally only need to submit a CLA once, so if you've already submitted one
15 | (even if it was for a different project), you probably don't need to do it
16 | again.
17 |
18 | ## Code reviews
19 |
20 | All submissions, including submissions by project members, require review. We
21 | use GitHub pull requests for this purpose. Consult
22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
23 | information on using pull requests.
24 |
25 | ## Community Guidelines
26 |
27 | This project follows
28 | [Google's Open Source Community Guidelines](https://opensource.google.com/conduct/).
29 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Episodic Curiosity Through Reachability
2 |
3 | #### In ICLR 2019 [[Project Website](https://sites.google.com/corp/view/episodic-curiosity)][[Paper](https://arxiv.org/abs/1810.02274)]
4 |
5 | [Nikolay Savinov¹](http://people.inf.ethz.ch/nsavinov/), [Anton Raichuk²](https://ai.google/research/people/AntonRaichuk), [Raphaël Marinier²](https://ai.google/research/people/105955), Damien Vincent², [Marc Pollefeys¹](https://www.inf.ethz.ch/personal/marc.pollefeys/), [Timothy Lillicrap³](http://contrastiveconvergence.net/~timothylillicrap/index.php), [Sylvain Gelly²](https://ai.google/research/people/SylvainGelly)
6 | ¹ETH Zurich, ²Google AI, ³DeepMind
7 |
8 | Navigation out of curiosity | Locomotion out of curiosity
9 | --------------------------------------------------- | ---------------------------
10 | ![Navigation](misc/navigation_github.gif) | ![Ant locomotion](misc/ant_github.gif)
11 |
12 | This is an implementation of our ICLR 2019 paper
13 | [Episodic Curiosity Through Reachability](https://arxiv.org/abs/1810.02274).
14 | If you use this work, please cite:
15 |
16 | @inproceedings{Savinov2019_EC,
17 | Author = {Savinov, Nikolay and Raichuk, Anton and Marinier, Rapha{\"e}l and Vincent, Damien and Pollefeys, Marc and Lillicrap, Timothy and Gelly, Sylvain},
18 | Title = {Episodic Curiosity through Reachability},
19 | Booktitle = {International Conference on Learning Representations ({ICLR})},
20 | Year = {2019}
21 | }
22 |
23 | ### Requirements
24 |
25 | The code was tested on Linux only. It assumes that the command `python`
26 | invokes Python 2.7. We recommend using virtualenv:
27 |
28 | ```shell
29 | sudo apt-get install python-pip
30 | pip install virtualenv
31 | python -m virtualenv episodic_curiosity_env
32 | source episodic_curiosity_env/bin/activate
33 | ```
34 |
35 | ### Installation
36 |
37 | Clone this repository:
38 |
39 | ```shell
40 | git clone https://github.com/google-research/episodic-curiosity.git
41 | cd episodic-curiosity
42 | ```
43 |
44 | We require a modified version of
45 | [DeepMind Lab](https://github.com/deepmind/lab):
46 |
47 | Clone DeepMind Lab:
48 |
49 | ```shell
50 | git clone https://github.com/deepmind/lab
51 | cd lab
52 | ```
53 |
54 | Apply our patch to DeepMind Lab:
55 |
56 | ```shell
57 | git checkout 7b851dcbf6171fa184bf8a25bf2c87fe6d3f5380
58 | git checkout -b modified_dmlab
59 | git apply ../third_party/dmlab/dmlab_min_goal_distance.patch
60 | ```
61 |
62 | Install DMLab as a PIP module by following
63 | [these instructions](https://github.com/deepmind/lab/tree/master/python/pip_package)
64 |
65 | In a nutshell, once you've installed DMLab dependencies, you need to run:
66 |
67 | ```shell
68 | bazel build -c opt python/pip_package:build_pip_package
69 | ./bazel-bin/python/pip_package/build_pip_package /tmp/dmlab_pkg
70 | pip install /tmp/dmlab_pkg/DeepMind_Lab-1.0-py2-none-any.whl --force-reinstall
71 | ```
72 |
73 | If you wish to run MuJoCo experiments (section S1 of the paper), you need to
74 | install dm_control and its dependencies. See [this
75 | documentation](https://github.com/deepmind/dm_control#requirements-and-installation),
76 | and replace `pip install -e .` with `pip install -e .[mujoco]` in the command
77 | below.
78 |
79 | Finally, install episodic curiosity and its pip dependencies:
80 |
81 | ```shell
82 | cd episodic-curiosity
83 | pip install -e .
84 | ```
85 |
86 |
87 |
88 | ### Resource requirements for training
89 |
90 | | Environment | Training method | Required GPU | Recommended RAM |
91 | | ----------- | ---------- | -------------------- | -------------------------- |
92 | | DMLab | PPO | No | 32 GB |
93 | | DMLab | PPO + Grid Oracle | No | 32 GB |
94 | | DMLab | PPO + EC using already trained R-networks | No | 32 GB |
95 | | DMLab | PPO + EC with R-network training | Yes; otherwise training is slower by >20x. Required GPU RAM: 5 GB | 50 GB. Tip: reduce `dataset_buffer_size` to use less RAM, at the expense of policy performance. |
96 | | DMLab | PPO + ECO | Yes; otherwise training is slower by >20x. Required GPU RAM: 5 GB | 80 GB. Tip: reduce `observation_history_size` to use less RAM, at the expense of policy performance. |
97 | | Mujoco | PPO + EC using already trained R-networks | No | 32 GB |
98 |
99 |
100 | ## Trained models
101 |
102 | Trained R-networks and policies can be found in the
103 | `episodic-curiosity` Google cloud bucket. You can access them via the
104 | [web interface](https://console.cloud.google.com/storage/browser/episodic-curiosity),
105 | or copy them with the `gsutil` command from the
106 | [Google Cloud SDK](https://cloud.google.com/sdk):
107 |
108 | ```shell
109 | gsutil -m cp -r gs://episodic-curiosity/r_networks .
110 | gsutil -m cp -r gs://episodic-curiosity/policies .
111 | ```
112 |
113 | Example command to visualize a trained policy over two episodes of
114 | 1000 steps and create videos similar to the ones at the top of this
115 | page (fill in the `--r_net_weights` and `--policy_path` values):
116 |
117 | ``` shell
118 | python -m episodic_curiosity.visualize_curiosity_reward --workdir=/tmp/ec_visualizations --r_net_weights= --policy_path= --alsologtostderr --num_episodes=2 --num_steps=1000 --visualization_type=surrogate_reward --trajectory_mode=do_nothing
119 | ```
120 |
121 | This requires that you install the extra dependencies for generating
122 | videos, with `pip install -e .[video]`.
123 |
124 |
125 | ## Training
126 |
127 | ### On a single machine
128 |
129 | [scripts/launcher_script.py](https://github.com/google-research/episodic-curiosity/blob/master/scripts/launcher_script.py)
130 | is the main entry point to reproduce the results of Table 1 in the
131 | [paper](https://arxiv.org/abs/1810.02274). For instance, the following command
132 | line launches training of the *PPO + EC* method on the *Sparse+Doors* scenario:
133 |
134 | ```sh
135 | python episodic-curiosity/scripts/launcher_script.py --workdir=/tmp/ec_workdir --method=ppo_plus_ec --scenario=sparseplusdoors
136 | ```
137 |
138 | Main flags:
139 |
140 | | Flag | Descriptions |
141 | | :----------- | :--------- |
142 | | --method | Solving method to use, corresponds to the rows in table 1 of the [paper](https://arxiv.org/abs/1810.02274). Possible values: `ppo, ppo_plus_ec, ppo_plus_eco, ppo_plus_grid_oracle` |
143 | | --scenario | Scenario to launch. Corresponds to the columns in table 1 of the [paper](https://arxiv.org/abs/1810.02274). Possible values: `noreward, norewardnofire, sparse, verysparse, sparseplusdoors, dense1, dense2`. `ant_no_reward` is also supported; it corresponds to the first row of table S1. |
144 | | --workdir | Directory where logs and checkpoints will be stored. |
145 | | --run_number | Run number of the current run. This is used to create an appropriate subdir in workdir. |
146 | | --r_networks_path | Only meaningful for the `ppo_plus_ec` method. Path to the root dir for pre-trained r networks. If specified, we train the policy using those pre-trained r networks. If not specified, we first generate the R network training data, train the R network and then train the policy. |
147 |
148 |
149 | Training takes a couple of days. We used CPUs with 16 hyper-threads, but smaller
150 | CPUs should do.
151 |
152 | Under the hood,
153 | [launcher_script.py](https://github.com/google-research/episodic-curiosity/blob/master/scripts/launcher_script.py)
154 | launches
155 | [train_policy.py](https://github.com/google-research/episodic-curiosity/blob/master/episodic_curiosity/train_policy.py)
156 | with the right hyperparameters. For the method `ppo_plus_ec`, it first launches
157 | [generate_r_training_data.py](https://github.com/google-research/episodic-curiosity/blob/master/episodic_curiosity/generate_r_training_data.py)
158 | to accumulate training data for the R-network using a random policy, then
159 | launches
160 | [train_r.py](https://github.com/google-research/episodic-curiosity/blob/master/episodic_curiosity/train_r.py)
161 | to train the R-network, and finally
162 | [train_policy.py](https://github.com/google-research/episodic-curiosity/blob/master/episodic_curiosity/train_policy.py)
163 | for the policy. In the method `ppo_plus_eco`, all this happens online as part of
164 | the policy training.
165 |
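For intuition, here is a minimal, self-contained sketch of how reachability training pairs can be formed from a single episode's observations. This is *not* the repository's actual implementation (see `generate_r_training_data.py` and `r_network_training.py` for that); it only illustrates the labeling idea from the paper, reusing the constants defined in `episodic_curiosity/constants.py`:

```python
import random

# These values mirror episodic_curiosity/constants.py.
MAX_ACTION_DISTANCE = 5
NEGATIVE_SAMPLE_MULTIPLIER = 5


def make_reachability_pairs(trajectory, num_pairs, rng=random):
  """Builds illustrative (obs_a, obs_b, label) pairs from one episode.

  Label 1 means the two observations are at most MAX_ACTION_DISTANCE steps
  apart ("reachable"); label 0 means they are separated by a much larger
  temporal gap. This is only a sketch of the labeling idea.
  """
  min_negative_gap = NEGATIVE_SAMPLE_MULTIPLIER * MAX_ACTION_DISTANCE
  assert len(trajectory) > min_negative_gap + 1, 'episode too short'
  pairs = []
  for _ in range(num_pairs):
    if rng.random() < 0.5:
      # Positive pair: temporally close observations.
      i = rng.randrange(len(trajectory) - MAX_ACTION_DISTANCE)
      j = i + rng.randint(1, MAX_ACTION_DISTANCE)
      label = 1
    else:
      # Negative pair: observations far apart in time.
      gap = rng.randint(min_negative_gap, len(trajectory) - 1)
      i = rng.randrange(len(trajectory) - gap)
      j = i + gap
      label = 0
    pairs.append((trajectory[i], trajectory[j], label))
  return pairs


# Example: 1000 fake "observations", 64 training pairs.
pairs = make_reachability_pairs(list(range(1000)), num_pairs=64)
```

The R-network is then trained as a binary classifier on such pairs, and during policy training its similarity output is compared against the observations stored in episodic memory to compute the curiosity bonus.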
166 | ### On Google Cloud
167 |
168 | First, make sure you have the [Google Cloud SDK](https://cloud.google.com/sdk)
169 | installed.
170 |
171 | [scripts/launch_cloud_vms.py](https://github.com/google-research/episodic-curiosity/blob/master/scripts/launch_cloud_vms.py)
172 | is the main entry point. Edit the script and replace the `FILL-ME`s with the
173 | details of your GCP project. In particular, you will need to point it to a GCP
174 | disk snapshot with the installed dependencies as described in the
175 | [Installation](#Installation) section.
176 |
177 | IMPORTANT: By default the script reproduces all results in table 1 and launches
178 | ~300 VMs with GPUs in the cloud (7 scenarios x 4 methods x 10 runs). The cost of
179 | running all those VMs is very significant: on the order of USD 30 **per day**
180 | **per VM** based on early 2019 GCP pricing. Pass
181 | `--i_understand_launching_vms_is_expensive` to
182 | [scripts/launch_cloud_vms.py](https://github.com/google-research/episodic-curiosity/blob/master/scripts/launch_cloud_vms.py)
183 | to acknowledge this.
184 |
185 | Under the hood, `launch_cloud_vms.py` launches one VM for each (scenario,
186 | method, run_number) tuple. The VMs use startup scripts to launch training, and
187 | retrieve the parameters of the run through
188 | [Instance Metadata](https://cloud.google.com/compute/docs/storing-retrieving-metadata).
189 |
190 | TIP: Use `sudo journalctl -u google-startup-scripts.service` to see the logs of
191 | the startup script.
192 |
193 | ### Training logs
194 |
195 | Each training job stores logs and checkpoints in a workdir. The workdir is
196 | organized as follows:
197 |
198 | | File or Directory | Description |
199 | | :----------------------------------------- | :------------------------------ |
200 | | `r_training_data/{R_TRAINING,VALIDATION}/` | TF Records with data generated from a random policy for R-network training. Only for method `ppo_plus_ec` without supplying pre-trained R-networks. |
201 | | `r_networks/` | Keras checkpoints of trained R-networks. Only for method `ppo_plus_ec` without supplying pre-trained R-networks. |
202 | | `reward_{train,valid,test}.csv` | CSV files with {train,valid,test} rewards, tracking the performance of the policy at multiple training steps. |
203 | | `checkpoints/` | Checkpoints of the policy. |
204 | | `log.txt`, `progress.csv` | Training logs and CSV from OpenAI's PPO2 code. |
205 |
206 | On cloud, the workdir of each job will be synced to a cloud bucket directory of
207 | the form `////run_number_/`.
208 |
209 | We provide a
210 | [colab](https://github.com/google-research/episodic-curiosity/blob/master/colab/plot_training_graphs.ipynb)
211 | to plot graphs during training of the policies, using data from the
212 | `reward_{train,valid,test}.csv` files.
213 |
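If you prefer plotting outside the colab, a rough equivalent with pandas/matplotlib is sketched below. The column names `step` and `reward` are placeholders for illustration only; check the actual header of the CSV files in your workdir first:

```python
import matplotlib.pyplot as plt
import pandas as pd

# NOTE: 'step' and 'reward' are assumed column names for this sketch.
workdir = '/tmp/ec_workdir'
for split in ('train', 'valid', 'test'):
  df = pd.read_csv('{}/reward_{}.csv'.format(workdir, split))
  plt.plot(df['step'], df['reward'], label=split)
plt.xlabel('training step')
plt.ylabel('episode reward')
plt.legend()
plt.show()
```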
214 | ### Related projects
215 | Check out the code for [Semi-parametric Topological Memory](
216 | https://github.com/nsavinov/SPTM), which uses graph-based episodic memory
217 | constructed from a short video to navigate in novel environments (thus providing
218 | an exploitation policy, complementary to the exploration policy in this work).
219 |
220 | ### Known limitations
221 |
222 | - As of 2019/02/20, the `ppo_plus_eco` method is not robust to restarts, because
223 | the R-network trained online is not checkpointed.
224 |
225 |
--------------------------------------------------------------------------------
/episodic_curiosity/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 Google LLC.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from __future__ import absolute_import
17 | from __future__ import division
18 | from __future__ import print_function
19 |
--------------------------------------------------------------------------------
/episodic_curiosity/constants.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 Google LLC.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Constants for episodic curiosity."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 |
21 | from __future__ import print_function
22 |
23 | from enum import Enum
24 |
25 |
26 | class Level(object):
27 | """Represents a DMLab level, possibly with additional non-standard settings.
28 |
29 | Attributes:
30 | dmlab_level_name: Name of the DMLab level
31 | fully_qualified_name: Unique name used to distinguish between multiple DMLab
32 | levels with the same name but different settings.
33 | extra_env_settings: dict, additional DMLab environment settings for this
34 | level.
35 | random_maze: Whether the geometry of the maze is supposed to change when we
36 | change the seed.
37 | use_r_net_from_level: If provided, don't train an R-net for this level, but
38 | instead, use the trained R-net from another level
39 | (identified by its fully qualified name).
40 | include_in_paper: Whether this level is included in the paper.
41 | scenarios: Optional list of scenarios this level is used for.
42 | """
43 |
44 | def __init__(self,
45 | dmlab_level_name,
46 | fully_qualified_name = None,
47 | extra_env_settings = None,
48 | random_maze = False,
49 | use_r_net_from_level = None,
50 | include_in_paper = False,
51 | scenarios = None):
52 | self.dmlab_level_name = dmlab_level_name
53 | self.fully_qualified_name = fully_qualified_name or dmlab_level_name
54 | self.extra_env_settings = extra_env_settings or {}
55 | self.random_maze = random_maze
56 | self.use_r_net_from_level = use_r_net_from_level
57 | self.include_in_paper = include_in_paper
58 | self.scenarios = scenarios
59 |
60 | def asdict(self):
61 | return vars(self)
62 |
63 |
64 | class SplitType(Enum):
65 | R_TRAINING = 0
66 | POLICY_TRAINING = 3
67 | VALIDATION = 1
68 | TEST = 2
69 |
70 |
71 | class Const(object):
72 | """Constants"""
73 | MAX_ACTION_DISTANCE = 5
74 | NEGATIVE_SAMPLE_MULTIPLIER = 5
75 | # env
76 | OBSERVATION_HEIGHT = 120
77 | OBSERVATION_WIDTH = 160
78 | OBSERVATION_CHANNELS = 3
79 | OBSERVATION_SHAPE = (OBSERVATION_HEIGHT, OBSERVATION_WIDTH,
80 | OBSERVATION_CHANNELS)
81 | # model and training
82 | BATCH_SIZE = 64
83 | EDGE_CLASSES = 2
84 | DUMP_AFTER_BATCHES = 100
85 | EDGE_MAX_EPOCHS = 2000
86 | ADAM_PARAMS = {
87 | 'lr': 1e-04,
88 | 'beta_1': 0.9,
89 | 'beta_2': 0.999,
90 | 'epsilon': 1e-08,
91 | 'decay': 0.0
92 | }
93 | ACTION_REPEAT = 4
94 | STORE_CHECKPOINT_EVERY_N_EPOCHS = 30
95 |
96 | LEVELS = [
97 | # Levels on which we evaluate episodic curiosity.
98 | # Corresponds to 'Sparse' setting in the paper
99 | # (arxiv.org/pdf/1810.02274.pdf).
100 | Level('contributed/dmlab30/explore_goal_locations_large',
101 | fully_qualified_name='explore_goal_locations_large',
102 | random_maze=True,
103 | include_in_paper=True,
104 | scenarios=['sparse', 'noreward', 'norewardnofire']),
105 |
106 | # WARNING!! For explore_goal_locations_large_sparse and
107 | # explore_goal_locations_large_verysparse to work properly (i.e. taking
108 | # into account minGoalDistance), you need to use the dmlab MPM:
109 | # learning/brain/research/dune/rl/dmlab_env_package.
110 | # Corresponds to 'Very Sparse' setting in the paper.
111 | Level(
112 | 'contributed/dmlab30/explore_goal_locations_large',
113 | fully_qualified_name='explore_goal_locations_large_verysparse',
114 | extra_env_settings={
115 | # Forces the spawn and goals to be further apart.
116 | # Unfortunately, we cannot go much higher, because we need to
117 | # guarantee that for any goal location, we can at least find one
118 | # spawn location that is further than this number (the goal
119 | # location might be in the middle of the map...).
120 | 'minGoalDistance': 10,
121 | },
122 | use_r_net_from_level='explore_goal_locations_large',
123 | random_maze=True, include_in_paper=True,
124 | scenarios=['verysparse']),
125 |
126 | # Corresponds to 'Sparse+Doors' setting in the paper.
127 | Level('contributed/dmlab30/explore_obstructed_goals_large',
128 | fully_qualified_name='explore_obstructed_goals_large',
129 | random_maze=True,
130 | include_in_paper=True,
131 | scenarios=['sparseplusdoors']),
132 |
133 | # Two levels where we expect to show episodic curiosity does not hurt.
134 | # Corresponds to 'Dense 1' setting in the paper.
135 | Level('contributed/dmlab30/rooms_keys_doors_puzzle',
136 | fully_qualified_name='rooms_keys_doors_puzzle',
137 | include_in_paper=True,
138 | scenarios=['dense1']),
139 | # Corresponds to 'Dense 2' setting in the paper.
140 | Level('contributed/dmlab30/rooms_collect_good_objects_train',
141 | fully_qualified_name='rooms_collect_good_objects_train',
142 | include_in_paper=True,
143 | scenarios=['dense2']),
144 | ]
145 |
146 | MIXER_SEEDS = {
147 | # Equivalent to not setting a mixer seed. Mixer seed to train the
148 | # R-network.
149 | SplitType.R_TRAINING: 0,
150 | # Mixer seed for training the policy.
151 | SplitType.POLICY_TRAINING: 0x3D23BE66,
152 | SplitType.VALIDATION: 0x2B79ED94, # Invented.
153 | SplitType.TEST: 0x600D5EED, # Same as DM's.
154 | }
155 |
156 | @staticmethod
157 | def find_level(fully_qualified_name):
158 | """Finds a DMLab level by fully qualified name."""
159 | for level in Const.LEVELS:
160 | if level.fully_qualified_name == fully_qualified_name:
161 | return level
162 | # Fallback to the DMLab level with the corresponding name.
163 | return Level(fully_qualified_name,
164 | extra_env_settings = {
165 | # Make 'rooms_exploit_deferred_effects_test',
166 | # 'rooms_collect_good_objects_test' work.
167 | 'allowHoldOutLevels': True
168 | })
169 |
170 | @staticmethod
171 | def find_level_by_scenario(scenario):
172 | """Finds a DMLab level by scenario name."""
173 | for level in Const.LEVELS:
174 | if level.scenarios and scenario in level.scenarios:
175 | return level
176 | raise ValueError('Scenario "{}" not found.'.format(scenario))
177 |
--------------------------------------------------------------------------------
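A quick usage sketch of the lookup helpers defined above, using the same import path as the tests that follow:

```python
from episodic_curiosity import constants

# Look up the DMLab level backing the 'verysparse' scenario from Table 1.
level = constants.Const.find_level_by_scenario('verysparse')
print(level.fully_qualified_name)   # explore_goal_locations_large_verysparse
print(level.extra_env_settings)     # {'minGoalDistance': 10}
print(level.use_r_net_from_level)   # explore_goal_locations_large

# Unknown names fall back to a bare Level with allowHoldOutLevels=True.
fallback = constants.Const.find_level('rooms_collect_good_objects_test')
print(fallback.extra_env_settings)  # {'allowHoldOutLevels': True}
```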
/episodic_curiosity/constants_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 Google LLC.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for dune.rl.episodic_curiosity.constants."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from absl.testing import absltest
23 | from episodic_curiosity import constants
24 |
25 |
26 | class ConstantsTest(absltest.TestCase):
27 |
28 | def test_unique_levels(self):
29 | unique_levels = set()
30 | for level in constants.Const.LEVELS:
31 | self.assertNotIn(level.fully_qualified_name, unique_levels)
32 | unique_levels.add(level.fully_qualified_name)
33 |
34 | def test_find_level(self):
35 | self.assertEqual(
36 | constants.Const.find_level('explore_goal_locations_large')
37 | .dmlab_level_name, 'contributed/dmlab30/explore_goal_locations_large')
38 |
39 |
40 | if __name__ == '__main__':
41 | absltest.main()
42 |
--------------------------------------------------------------------------------
/episodic_curiosity/curiosity_env_wrapper_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 Google LLC.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Test of curiosity_env_wrapper.py."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from episodic_curiosity import curiosity_env_wrapper
23 | from episodic_curiosity import episodic_memory
24 | from third_party.baselines.common.vec_env.dummy_vec_env import DummyVecEnv
25 | import gym
26 | import numpy as np
27 | import tensorflow as tf
28 |
29 |
30 | class DummyImageEnv(gym.Env):
31 |
32 | def __init__(self):
33 | self._num_actions = 4
34 | self._image_shape = (28, 28, 3)
35 | self._done_prob = 0.01
36 |
37 | self.action_space = gym.spaces.Discrete(self._num_actions)
38 | self.observation_space = gym.spaces.Box(
39 | 0, 255, self._image_shape, dtype=np.float32)
40 |
41 | def seed(self, seed=None):
42 | pass
43 |
44 | def step(self, action):
45 | observation = np.random.normal(size=self._image_shape)
46 | reward = 0.0
47 | done = (np.random.rand() < self._done_prob)
48 | info = {}
49 | return observation, reward, done, info
50 |
51 | def reset(self):
52 | return np.random.normal(size=self._image_shape)
53 |
54 | def render(self, mode='human'):
55 | raise NotImplementedError('Rendering not implemented')
56 |
57 |
58 | # TODO(damienv): To be removed once the code in third_party
59 | # is compatible with python 2.
60 | class HackDummyVecEnv(DummyVecEnv):
61 |
62 | def step_wait(self):
63 | for e in range(self.num_envs):
64 | action = self.actions[e]
65 | if isinstance(self.envs[e].action_space, gym.spaces.Discrete):
66 | action = int(action)
67 |
68 | obs, self.buf_rews[e], self.buf_dones[e], self.buf_infos[e] = (
69 | self.envs[e].step(action))
70 | if self.buf_dones[e]:
71 | obs = self.envs[e].reset()
72 | self._save_obs(e, obs)
73 | return (np.copy(self._obs_from_buf()),
74 | np.copy(self.buf_rews),
75 | np.copy(self.buf_dones),
76 | list(self.buf_infos))
77 |
78 |
79 | def embedding_similarity(x1, x2):
80 | assert x1.shape[0] == x2.shape[0]
81 | epsilon = 1e-6
82 |
83 | # Inner product between the embeddings in x1
84 | # and the embeddings in x2.
85 | s = np.sum(x1 * x2, axis=-1)
86 |
87 | s /= np.linalg.norm(x1, axis=-1) * np.linalg.norm(x2, axis=-1) + epsilon
88 | return 0.5 * (s + 1.0)
89 |
90 |
91 | def linear_embedding(m, x):
92 | # Flatten all but the batch dimension if needed.
93 | if len(x.shape) > 2:
94 | x = np.reshape(x, [x.shape[0], -1])
95 | return np.matmul(x, m)
96 |
97 |
98 | class EpisodicEnvWrapperTest(tf.test.TestCase):
99 |
100 | def EnvFactory(self):
101 | return DummyImageEnv()
102 |
103 | def testResizeObservation(self):
104 | img_grayscale = np.random.randint(low=0, high=256, size=[64, 48, 1])
105 | img_grayscale = img_grayscale.astype(np.uint8)
106 | resized_img = curiosity_env_wrapper.resize_observation(img_grayscale,
107 | [16, 12, 1])
108 | self.assertAllEqual([16, 12, 1], resized_img.shape)
109 |
110 | img_color = np.random.randint(low=0, high=256, size=[64, 48, 3])
111 | img_color = img_color.astype(np.uint8)
112 | resized_img = curiosity_env_wrapper.resize_observation(img_color,
113 | [16, 12, 1])
114 | self.assertAllEqual([16, 12, 1], resized_img.shape)
115 | resized_img = curiosity_env_wrapper.resize_observation(img_color,
116 | [16, 12, 3])
117 | self.assertAllEqual([16, 12, 3], resized_img.shape)
118 |
119 | def testEpisodicEnvWrapperSimple(self):
120 | num_envs = 10
121 | vec_env = HackDummyVecEnv([self.EnvFactory] * num_envs)
122 |
123 | embedding_size = 16
124 | vec_episodic_memory = [episodic_memory.EpisodicMemory(
125 | capacity=1000,
126 | observation_shape=[embedding_size],
127 | observation_compare_fn=embedding_similarity)
128 | for _ in range(num_envs)]
129 |
130 | mat = np.random.normal(size=[28 * 28 * 3, embedding_size])
131 | observation_embedding = lambda x, m=mat: linear_embedding(m, x)
132 |
133 | target_image_shape = [14, 14, 1]
134 | env_wrapper = curiosity_env_wrapper.CuriosityEnvWrapper(
135 | vec_env, vec_episodic_memory,
136 | observation_embedding,
137 | target_image_shape)
138 |
139 | observations = env_wrapper.reset()
140 | self.assertAllEqual([num_envs] + target_image_shape,
141 | observations.shape)
142 |
143 | dummy_actions = [1] * num_envs
144 | for _ in range(100):
145 | previous_mem_length = [len(mem) for mem in vec_episodic_memory]
146 | observations, unused_rewards, dones, unused_infos = (
147 | env_wrapper.step(dummy_actions))
148 | current_mem_length = [len(mem) for mem in vec_episodic_memory]
149 |
150 | self.assertAllEqual([num_envs] + target_image_shape,
151 | observations.shape)
152 | for k in range(num_envs):
153 | if dones[k]:
154 | self.assertEqual(1, current_mem_length[k])
155 | else:
156 | self.assertGreaterEqual(current_mem_length[k],
157 | previous_mem_length[k])
158 |
159 |
160 | if __name__ == '__main__':
161 | tf.test.main()
162 |
--------------------------------------------------------------------------------
/episodic_curiosity/curiosity_evaluation.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 Google LLC.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Library to evaluate exploration."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 |
21 | from __future__ import print_function
22 |
23 | from episodic_curiosity.constants import Const
24 | import numpy as np
25 | import tensorflow as tf
26 |
27 |
28 | class OracleExplorationReward(object):
29 | """Class that computes the ideal exploration bonus."""
30 |
31 | def __init__(self, reward_grid_size):
32 | """Creates a new oracle to compute the exploration reward."""
33 | self._reward_grid_size = reward_grid_size
34 | self._collected_positions = set()
35 |
36 | def update_state(self, agent_position):
37 | """Set the new state (i.e. the position).
38 |
39 | Args:
40 | agent_position: x,y,z position of the agent.
41 |
42 | Returns:
43 | The exploration bonus for having visited this position.
44 | """
45 | x, y, z = agent_position
46 | quantized_x = int(x / self._reward_grid_size)
47 | quantized_y = int(y / self._reward_grid_size)
48 | quantized_z = int(z / self._reward_grid_size)
49 | position_id = (quantized_x, quantized_y, quantized_z)
50 | if position_id in self._collected_positions:
51 | # No reward if the position has already been explored.
52 | return 0.0
53 | else:
54 | self._collected_positions.add(position_id)
55 | return 1.0
56 |
57 |
58 | def policy_state_coverage(env, policy_action, reward_grid_size,
59 | eval_time_steps):
60 | """Computes the maze coverage by a given policy.
61 |
62 | Args:
63 | env: A Gym environment.
64 | policy_action: Function which, given a state, returns an action.
65 | e.g. For DMLab, the policy will return an Actions protobuf.
66 | reward_grid_size: L1 distance between 2 consecutive curiosity rewards.
67 | eval_time_steps: List of times after which state coverage should be
68 | computed.
69 |
70 | Returns:
71 | A dict mapping step counts (from eval_time_steps, or the final step if the
72 | episode ended early) to the cumulative exploration reward at that point.
73 | """
74 | max_episode_length = max(eval_time_steps) + 1
75 |
76 | # During test, the exploration reward is given by an oracle that has
77 | # access to the agent coordinates.
78 | cumulative_exploration_reward = 0.0
79 | oracle_exploration_reward = OracleExplorationReward(
80 | reward_grid_size=reward_grid_size)
81 |
82 | # Initial observation.
83 | observation = env.reset()
84 |
85 | reward_list = {}
86 | for k in range(max_episode_length):
87 | action = policy_action(observation)
88 |
89 | # TODO(damienv): Repeating an action should be a wrapper
90 | # of the environment.
91 | repeat_action_count = Const.ACTION_REPEAT
92 | done = False
93 | for _ in range(repeat_action_count):
94 | observation, _, done, metadata = env.step(action)
95 | if done:
96 | break
97 |
98 | # Abort if we've reached the end of the episode.
99 | if done:
100 | reward_list[k] = cumulative_exploration_reward
101 | break
102 |
103 | # Convert the new agent position into a possible exploration bonus.
104 | # Note: getting the position of the agent is specific to DMLab.
105 | agent_position = metadata['position']
106 | cumulative_exploration_reward += oracle_exploration_reward.update_state(
107 | agent_position)
108 |
109 | step_count = k + 1
110 | if step_count in eval_time_steps:
111 | reward_list[step_count] = cumulative_exploration_reward
112 |
113 | return reward_list
114 |
115 |
116 | class PolicyWrapper(object):
117 | """Wrap a policy defined by some input/output nodes in a TF graph."""
118 |
119 | def __init__(self,
120 | input_observation,
121 | input_state,
122 | output_state,
123 | output_actions):
124 | # The tensors to feed the policy and to retrieve the relevant outputs.
125 | self._input_observation = input_observation
126 | self._input_state = input_state
127 | self._output_state = output_state
128 | self._output_actions = output_actions
129 |
130 | # TensorFlow session.
131 | self._sess = None
132 |
133 | def set_session(self, tf_session):
134 | self._sess = tf_session
135 |
136 | def reset(self):
137 | self._current_state = np.zeros(
138 | self._input_state.get_shape().as_list(),
139 | dtype=np.float32)
140 |
141 | def action(self, observation):
142 | """Action to perform given an observation."""
143 | # Converts to a batched observation (with batch_size=1).
144 | observation = np.expand_dims(observation, axis=0)
145 |
146 | # Run the underlying policy and update the state of the policy.
147 | actions, next_state = self._sess.run(
148 | [self._output_actions, self._output_state],
149 | feed_dict={self._input_observation: observation,
150 | self._input_state: self._current_state})
151 | self._current_state = next_state
152 |
153 | # Un-batch the action.
154 | action = actions[0]
155 | return action
156 |
157 |
158 | def load_policy(graph_def,
159 | input_observation_name,
160 | input_state_name,
161 | output_state_name,
162 | output_pd_params_name,
163 | tf_sampling_fn):
164 | """Load a policy from a graph file.
165 |
166 | Args:
167 | graph_def: Graph definition.
168 | input_observation_name: Name in the graph definition of the tensor
169 | of observations.
170 | input_state_name: Name in the graph definition of the tensor
171 | of input states.
172 | output_state_name: Name in the graph definition of the tensor of output
173 | states.
174 | output_pd_params_name: Name in the graph definition of the tensor
175 | representing the parameters of the distribution over actions.
176 | tf_sampling_fn: Function which samples an action based on the probability
177 | distribution parameters (given by output_pd_params_name).
178 |
179 | Returns:
180 | Returns a python policy.
181 | """
182 | tensors = tf.import_graph_def(
183 | graph_def,
184 | return_elements=[input_observation_name,
185 | input_state_name,
186 | output_state_name,
187 | output_pd_params_name],
188 | name='')
189 |
190 | input_observation, input_state, output_state, output_pd_params = tensors
191 | output_actions = tf_sampling_fn(output_pd_params)
192 |
193 | return PolicyWrapper(input_observation,
194 | input_state,
195 | output_state,
196 | output_actions)
197 |
--------------------------------------------------------------------------------
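For illustration, a tiny worked example of `OracleExplorationReward` above: with `reward_grid_size=10.0`, positions are quantized into 10x10x10 cells, and only the first visit to a cell yields a bonus:

```python
from episodic_curiosity import curiosity_evaluation

oracle = curiosity_evaluation.OracleExplorationReward(reward_grid_size=10.0)
print(oracle.update_state((3.0, 4.0, 0.0)))    # 1.0: cell (0, 0, 0) is new.
print(oracle.update_state((9.9, 0.1, 0.0)))    # 0.0: still cell (0, 0, 0).
print(oracle.update_state((12.0, 4.0, 0.0)))   # 1.0: cell (1, 0, 0) is new.
```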
/episodic_curiosity/curiosity_evaluation_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 Google LLC.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Simple test of the curiosity evaluation."""
17 |
18 | from __future__ import division
19 | from __future__ import print_function
20 | from __future__ import unicode_literals
21 |
22 | from episodic_curiosity import curiosity_evaluation
23 | from episodic_curiosity.environments import fake_gym_env
24 | import numpy as np
25 | import tensorflow as tf
26 | from tensorflow.contrib import layers as contrib_layers
27 |
28 |
29 | def random_policy(unused_observation):
30 | action = np.random.randint(low=0, high=fake_gym_env.FakeGymEnv.NUM_ACTIONS)
31 | return action
32 |
33 |
34 | class CuriosityEvaluationTest(tf.test.TestCase):
35 |
36 | def EvalPolicy(self, policy):
37 | env = fake_gym_env.FakeGymEnv()
38 |
39 | # Distance between 2 consecutive curiosity rewards.
40 | reward_grid_size = 10.0
41 |
42 | # Times of evaluation.
43 | eval_time_steps = [100, 500, 3000]
44 |
45 | rewards = curiosity_evaluation.policy_state_coverage(
46 | env, policy, reward_grid_size, eval_time_steps)
47 |
48 | # The exploration reward is at most the number of steps.
49 | # It is equal to the number of steps when the policy explores a new state
50 | # at every time step.
51 | print('Curiosity reward: {}'.format(rewards))
52 | for k, r in rewards.items():
53 | self.assertGreaterEqual(k, r)
54 |
55 | def testRandomPolicy(self):
56 | self.EvalPolicy(random_policy)
57 |
58 | def testNNPolicy(self):
59 | batch_size = 1
60 | x = tf.placeholder(
61 | tf.float32,
62 | shape=(batch_size,) + fake_gym_env.FakeGymEnv.OBSERVATION_SHAPE)
63 | x = tf.div(x, 255.0)
64 |
65 | # This is just to make the test run fast enough.
66 | x_downscaled = tf.image.resize_images(x, [8, 8])
67 | x_downscaled = tf.reshape(x_downscaled, [batch_size, -1])
68 |
69 | # Logits to select the action.
70 | num_actions = 7
71 | h = contrib_layers.fully_connected(
72 | inputs=x_downscaled, num_outputs=32, activation_fn=None, scope='fc0')
73 | h = tf.nn.relu(h)
74 | y_logits = contrib_layers.fully_connected(
75 | inputs=h, num_outputs=num_actions, activation_fn=None, scope='fc1')
76 | temperature = 100.0
77 | y_logits /= temperature
78 |
79 | # Draw the action according to the distribution inferred by the logits.
80 | r = tf.random_uniform(tf.shape(y_logits),
81 | minval=0.001, maxval=0.999)
82 | y_logits -= tf.log(-tf.log(r))
83 | y = tf.argmax(y_logits, axis=-1)
84 |
85 | input_state = tf.placeholder(tf.float32, shape=(37))
86 | output_state = input_state
87 |
88 | # Policy from the previous network.
89 | policy = curiosity_evaluation.PolicyWrapper(
90 | x, input_state, output_state, y)
91 |
92 | global_init_op = tf.global_variables_initializer()
93 | with self.test_session() as sess:
94 | sess.run(global_init_op)
95 |
96 | policy.set_session(sess)
97 | policy.reset()
98 |
99 | self.EvalPolicy(policy.action)
100 |
101 |
102 | if __name__ == '__main__':
103 | tf.test.main()
104 |
--------------------------------------------------------------------------------
/episodic_curiosity/env_factory_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 Google LLC.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Test of env_factory.py."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import os
23 |
24 | from absl import flags
25 | from episodic_curiosity import env_factory
26 | from episodic_curiosity import r_network
27 | from episodic_curiosity.environments import fake_gym_env
28 | from third_party.keras_resnet import models
29 | import numpy as np
30 | import tensorflow as tf
31 | from tensorflow import keras
32 |
33 |
34 | FLAGS = flags.FLAGS
35 |
36 |
37 | class EnvFactoryTest(tf.test.TestCase):
38 |
39 | def setUp(self):
40 | super(EnvFactoryTest, self).setUp()
41 | keras.backend.clear_session()
42 | self.weight_path = os.path.join(tf.test.get_temp_dir(), 'weights.h5')
43 | self.input_shape = fake_gym_env.FakeGymEnv.OBSERVATION_SHAPE
44 | self.dumped_r_network, _, _ = models.ResnetBuilder.build_siamese_resnet_18(
45 | self.input_shape)
46 | self.dumped_r_network.compile(
47 | loss='categorical_crossentropy', optimizer=keras.optimizers.Adam())
48 | self.dumped_r_network.save_weights(self.weight_path)
49 |
50 | def testCreateRNetwork(self):
51 | r_network.RNetwork(self.input_shape, self.weight_path)
52 |
53 | def testCreateAndRunEnvironment(self):
54 | # pylint: disable=g-long-lambda
55 | env_factory.create_single_env = (
56 | lambda level_name, seed, dmlab_homepath, use_monitor, split, action_set:
57 | fake_gym_env.FakeGymEnv())
58 | # pylint: enable=g-long-lambda
59 |
60 | env, env_valid, env_test = env_factory.create_environments(
61 | 'explore_object_locations_small', 1, self.weight_path)
62 | env.reset()
63 | actions = [0]
64 | for _ in range(5):
65 | env.step(actions)
66 | env.close()
67 | env_valid.close()
68 | env_test.close()
69 |
70 | def testRNetworkLazyLoading(self):
71 | """Tests that the RNetwork weights are lazy loaded."""
72 | # pylint: disable=g-long-lambda
73 | rand_obs = lambda: np.random.uniform(low=-0.01, high=0.01,
74 | size=(1,) + self.input_shape)
75 | obs1 = rand_obs()
76 | obs2 = rand_obs()
77 | expected_similarity = self.dumped_r_network.predict([obs1, obs2])
78 | r_net = r_network.RNetwork(self.input_shape, self.weight_path)
79 | with self.test_session() as sess:
80 | sess.run(tf.global_variables_initializer())
81 | emb1 = r_net.embed_observation(obs1)
82 | emb2 = r_net.embed_observation(obs2)
83 | similarity = r_net.embedding_similarity(emb1, emb2)
84 | self.assertAlmostEqual(similarity[0], expected_similarity[0, 1])
85 |
86 |
87 | if __name__ == '__main__':
88 | tf.test.main()
89 |
--------------------------------------------------------------------------------
/episodic_curiosity/environments/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 Google LLC.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from __future__ import absolute_import
17 | from __future__ import division
18 | from __future__ import print_function
19 |
--------------------------------------------------------------------------------
/episodic_curiosity/environments/fake_gym_env.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 Google LLC.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Fake gym environment.
17 | """
18 |
19 | from __future__ import absolute_import
20 | from __future__ import division
21 | from __future__ import print_function
22 |
23 | import gym
24 | import numpy as np
25 |
26 |
27 | # This class is shared among multiple tests. Please refrain from adding logic
28 | # that is specific to your test in this class. Instead, you should create a new
29 | # local fake env that is specific to your use-case (possibly inheriting from
30 | # this one).
31 | class FakeGymEnv(gym.Env):
32 | """Fake gym environment."""
33 | OBSERVATION_HEIGHT = 120
34 | OBSERVATION_WIDTH = 160
35 | OBSERVATION_CHANNELS = 3
36 | OBSERVATION_SHAPE = (OBSERVATION_HEIGHT, OBSERVATION_WIDTH,
37 | OBSERVATION_CHANNELS)
38 | NUM_ACTIONS = 4
39 | EPISODE_LENGTH = 100
40 |
41 | def __init__(self):
42 | self.action_space = gym.spaces.Discrete(self.NUM_ACTIONS)
43 | self.observation_space = gym.spaces.Box(
44 | 0, 255, self.OBSERVATION_SHAPE, dtype=np.float32)
45 | self.episode_step = 0
46 |
47 | def seed(self, seed=None):
48 | pass
49 |
50 | def _observation(self):
51 | return np.random.randint(
52 | low=0, high=255, size=self.OBSERVATION_SHAPE, dtype=np.uint8)
53 |
54 | def step(self, action):
55 | observation = self._observation()
56 | reward = 0.0
57 | done = self.episode_step >= self.EPISODE_LENGTH
58 | self.episode_step += 1
59 | info = {'position': np.random.uniform(low=0, high=1000, size=[3])}
60 | return observation, reward, done, info
61 |
62 | def reset(self):
63 | self.episode_step = 0
64 | return self._observation()
65 |
66 | def render(self, mode='human'):
67 | raise NotImplementedError('Rendering not implemented')
68 |
--------------------------------------------------------------------------------
/episodic_curiosity/episodic_memory.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 Google LLC.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Class that represents an episodic memory."""
17 | from __future__ import absolute_import
18 | from __future__ import division
19 |
20 | from __future__ import print_function
21 |
22 | import gin
23 | import numpy as np
24 |
25 |
26 | @gin.configurable
27 | class EpisodicMemory(object):
28 | """Episodic memory."""
29 |
30 | def __init__(self,
31 | observation_shape,
32 | observation_compare_fn,
33 | replacement='fifo',
34 | capacity=200):
35 | """Creates an episodic memory.
36 |
37 | Args:
38 | observation_shape: Shape of an observation.
39 | observation_compare_fn: Function used to measure similarity between
40 | two observations. This function returns the estimated probability that
41 | two observations are similar.
42 | replacement: String selecting the replacement behavior when a sample is
43 | added to a memory that is already full.
44 | Can be one of: 'fifo', 'random'.
45 | 'fifo' keeps the last "capacity" samples into the memory.
46 | 'random' results in a geometric distribution of the age of the samples
47 | present in the memory.
48 | capacity: Capacity of the episodic memory.
49 |
50 | Raises:
51 | ValueError: when the replacement scheme is invalid.
52 | """
53 | self._capacity = capacity
54 | self._replacement = replacement
55 | if self._replacement not in ['fifo', 'random']:
56 | raise ValueError('Invalid replacement scheme')
57 | self._observation_shape = observation_shape
58 | self._observation_compare_fn = observation_compare_fn
59 | self.reset(False)
60 |
61 | def reset(self, show_stats=True):
62 | """Resets the memory."""
63 | if show_stats:
64 | size = len(self)
65 | age_histogram, _ = np.histogram(self._memory_age[:size],
66 | 10, [0, self._count])
67 | age_histogram = age_histogram.astype(np.float32)
68 | age_histogram = age_histogram / np.sum(age_histogram)
69 | print('Number of samples added in the previous trajectory: {}'.format(
70 | self._count))
71 | print('Histogram of sample freshness (old to fresh): {}'.format(
72 | age_histogram))
73 |
74 | self._count = 0
75 | # Stores environment observations.
76 | self._obs_memory = np.zeros([self._capacity] + self._observation_shape)
77 | # Stores the infos returned by the environment. For debugging and
78 | # visualization purposes.
79 | self._info_memory = [None] * self._capacity
80 | self._memory_age = np.zeros([self._capacity], dtype=np.int32)
81 |
82 | @property
83 | def capacity(self):
84 | return self._capacity
85 |
86 | def __len__(self):
87 | return min(self._count, self._capacity)
88 |
89 | @property
90 | def info_memory(self):
91 | return self._info_memory
92 |
93 | def add(self, observation, info):
94 | """Adds an observation to the memory.
95 |
96 | Args:
97 | observation: Observation to add to the episodic memory.
98 | info: Info returned by the environment together with the observation,
99 | for debugging and visualization purposes.
100 |
101 | Raises:
102 | ValueError: when the capacity of the memory is exceeded.
103 | """
104 | if self._count >= self._capacity:
105 | if self._replacement == 'random':
106 | # By using random replacement, the age of elements inside the memory
107 | # follows a geometric distribution (more fresh samples compared to
108 | # old samples).
109 | index = np.random.randint(low=0, high=self._capacity)
110 | elif self._replacement == 'fifo':
111 | # In this scheme, only the last self._capacity elements are kept.
112 | # Samples are replaced using a FIFO scheme (implemented as a circular
113 | # buffer).
114 | index = self._count % self._capacity
115 | else:
116 | raise ValueError('Invalid replacement scheme')
117 | else:
118 | index = self._count
119 |
120 | self._obs_memory[index] = observation
121 | self._info_memory[index] = info
122 | self._memory_age[index] = self._count
123 | self._count += 1
124 |
125 | def similarity(self, observation):
126 | """Similarity between the input observation and the ones from the memory.
127 |
128 | Args:
129 | observation: The input observation.
130 |
131 | Returns:
132 | A numpy array of similarities corresponding to the similarity between
133 | the input and each of the element in the memory.
134 | """
135 | # Make the observation batched with batch_size equal to the current memory
136 | # size before computing the similarities.
137 | # TODO(damienv): could we avoid replicating the observation ?
138 | # (with some form of broadcasting).
139 | size = len(self)
140 | observation = np.array([observation] * size)
141 | similarities = self._observation_compare_fn(observation,
142 | self._obs_memory[:size])
143 | return similarities
144 |
145 |
146 | @gin.configurable
147 | def similarity_to_memory(observation,
148 | episodic_memory,
149 | similarity_aggregation='percentile'):
150 | """Returns the similarity of the observation to the episodic memory.
151 |
152 | Args:
153 | observation: The observation the agent transitions to.
154 | episodic_memory: Episodic memory.
155 | similarity_aggregation: Aggregation method to turn the multiple
156 | similarities to each observation in the memory into a scalar.
157 |
158 | Returns:
159 | A scalar corresponding to the similarity to episodic memory. This is
160 | computed by aggregating the similarities between the new observation
161 | and every observation in the memory, according to 'similarity_aggregation'.
162 | """
163 | # Computes the similarities between the current observation and the past
164 | # observations in the memory.
165 | memory_length = len(episodic_memory)
166 | if memory_length == 0:
167 | return 0.0
168 | similarities = episodic_memory.similarity(observation)
169 | # Implements different surrogate aggregated similarities.
170 | # TODO(damienv): Implement other types of surrogate aggregated similarities.
171 | if similarity_aggregation == 'max':
172 | aggregated = np.max(similarities)
173 | elif similarity_aggregation == 'nth_largest':
174 | n = min(10, memory_length)
175 | aggregated = np.partition(similarities, -n)[-n]
176 | elif similarity_aggregation == 'percentile':
177 | percentile = 90
178 | aggregated = np.percentile(similarities, percentile)
179 | elif similarity_aggregation == 'relative_count':
180 | # Number of samples in the memory similar to the input observation.
181 | count = sum(similarities > 0.5)
182 | aggregated = float(count) / len(similarities)
183 |
184 | return aggregated
185 |
--------------------------------------------------------------------------------
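A short usage sketch of `EpisodicMemory` and `similarity_to_memory` above, using a simple cosine-based comparison function in place of the R-network's learned `embedding_similarity`:

```python
import numpy as np
from episodic_curiosity import episodic_memory


def cosine_similarity(x1, x2):
  # Maps the cosine between two batches of embeddings to [0, 1].
  s = np.sum(x1 * x2, axis=-1)
  s /= np.linalg.norm(x1, axis=-1) * np.linalg.norm(x2, axis=-1) + 1e-6
  return 0.5 * (s + 1.0)


memory = episodic_memory.EpisodicMemory(
    observation_shape=[16],
    observation_compare_fn=cosine_similarity,
    capacity=200)

for _ in range(50):
  memory.add(np.random.normal(size=[16]), info={})

new_obs = np.random.normal(size=[16])
# Aggregated similarity in [0, 1]; roughly speaking, the curiosity wrapper
# grants a larger bonus when this value is low (observation far from memory).
score = episodic_memory.similarity_to_memory(
    new_obs, memory, similarity_aggregation='percentile')
print(score)
```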
/episodic_curiosity/episodic_memory_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 Google LLC.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Test of episodic_memory.py."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from episodic_curiosity import episodic_memory
23 | import numpy as np
24 | import tensorflow as tf
25 |
26 |
27 | def embedding_similarity(x1, x2):
28 | assert x1.shape[0] == x2.shape[0]
29 | epsilon = 1e-6
30 |
31 | # Inner product between the embeddings in x1
32 | # and the embeddings in x2.
33 | s = np.sum(x1 * x2, axis=-1)
34 |
35 | s /= np.linalg.norm(x1, axis=-1) * np.linalg.norm(x2, axis=-1) + epsilon
36 | return 0.5 * (s + 1.0)
37 |
38 |
39 | class EpisodicMemoryTest(tf.test.TestCase):
40 |
41 | def RunTest(self, memory, observation_shape, add_count):
42 | expected_size = min(add_count, memory.capacity)
43 |
44 | for _ in range(add_count):
45 | observation = np.random.normal(size=observation_shape)
46 | memory.add(observation, dict())
47 | self.assertEqual(expected_size, len(memory))
48 |
49 | current_observation = np.random.normal(size=observation_shape)
50 | similarities = memory.similarity(current_observation)
51 | self.assertEqual(expected_size, len(similarities))
52 | self.assertAllLessEqual(similarities, 1.0)
53 | self.assertAllGreaterEqual(similarities, 0.0)
54 |
55 | def testEpisodicMemory(self):
56 | observation_shape = [9]
57 | memory = episodic_memory.EpisodicMemory(
58 | observation_shape=observation_shape,
59 | observation_compare_fn=embedding_similarity,
60 | capacity=150)
61 |
62 | self.RunTest(memory,
63 | observation_shape,
64 | add_count=100)
65 | memory.reset()
66 |
67 | self.RunTest(memory,
68 | observation_shape,
69 | add_count=200)
70 | memory.reset()
71 |
72 |
73 | if __name__ == '__main__':
74 | tf.test.main()
75 |
--------------------------------------------------------------------------------
/episodic_curiosity/eval_policy.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 Google LLC.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Evaluation of a policy on a GYM environment."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import gym
23 | import numpy as np
24 |
25 |
26 | class DummyVideoWriter(object):
27 |
28 | def add(self, obs):
29 | pass
30 |
31 | def close(self, filename):
32 | pass
33 |
34 |
35 | class PolicyEvaluator(object):
36 | """Evaluate a policy on a GYM environment."""
37 |
38 | def __init__(self, vec_env,
39 | metric_callback=None,
40 | video_filename=None, grayscale=False,
41 | eval_frequency=25):
42 | """New policy evaluator.
43 |
44 | Args:
45 |       vec_env: baselines.VecEnv corresponding to a vector of GYM environments.
46 | metric_callback: Function that is given the average reward and the time
47 | step of the evaluation.
48 | video_filename: Prefix of filenames used to record video.
49 | grayscale: Whether the observation is grayscale or color.
50 |       eval_frequency: Only run an evaluation once every eval_frequency calls.
51 | """
52 | self._vec_env = vec_env
53 | self._metric_callback = metric_callback
54 | self._video_filename = video_filename
55 | self._grayscale = grayscale
56 |
57 | self._eval_count = 0
58 | self._eval_frequency = eval_frequency
59 |     self._discrete_actions = isinstance(self._vec_env.action_space,
60 | gym.spaces.Discrete)
61 |
62 | def evaluate(self, model_step_fn, global_step):
63 | """Evaluate the policy as given by its step function.
64 |
65 | Args:
66 |       model_step_fn: Function which, given a batch of observations,
67 |         a batch of policy states and a batch of done flags, returns
68 |         a batch of selected actions and updated policy states.
69 | global_step: The global step of the training process.
70 | """
71 | if self._eval_count % self._eval_frequency != 0:
72 | self._eval_count += 1
73 | return
74 | self._eval_count += 1
75 |
76 | video_writer = DummyVideoWriter()
77 | video_writer2 = DummyVideoWriter()
78 | has_alternative_video = False
79 | if self._video_filename:
80 | video_filename = '{}_{}.mp4'.format(self._video_filename, global_step)
81 | video_filename2 = '{}_{}_v2.mp4'.format(self._video_filename, global_step)
82 | else:
83 | video_filename = 'dummy.mp4'
84 | video_filename2 = 'dummy2.mp4'
85 |
86 | # Initial state of the policy.
87 | # TODO(damienv): make the policy state dimension part of the constructor.
88 | policy_state_dim = 512
89 | policy_states = np.zeros((self._vec_env.num_envs, policy_state_dim),
90 | dtype=np.float32)
91 |
92 | # Reset the environments before starting the evaluation.
93 | dones = [False] * self._vec_env.num_envs
94 | sticky_dones = [False] * self._vec_env.num_envs
95 | obs = self._vec_env.reset()
96 |
97 | # Evaluation loop.
98 | total_reward = np.zeros((self._vec_env.num_envs,), dtype=np.float32)
99 | step_iter = 0
100 | action_distribution = {}
101 | while not all(sticky_dones):
102 | actions, _, policy_states, _ = model_step_fn(obs, policy_states, dones)
103 |
104 | # Update the distribution of actions seen along the trajectory.
105 | if self._discrete_actions:
106 | for action in actions:
107 | if action not in action_distribution:
108 | action_distribution[action] = 0
109 | action_distribution[action] += 1
110 |
111 | # Update the states of the environment based on the selected actions.
112 | obs, rewards, dones, infos = self._vec_env.step(actions)
113 | step_iter += 1
114 |     for k in range(self._vec_env.num_envs):
115 | if not sticky_dones[k]:
116 | total_reward[k] += rewards[k]
117 | sticky_dones = [sd or d for (sd, d) in zip(sticky_dones, dones)]
118 |
119 | # Optionally record the frames of the 1st environment.
120 | if not sticky_dones[0]:
121 | if infos[0].get('frame') is not None:
122 | frame = infos[0]['frame']
123 | else:
124 | frame = obs[0]
125 | if self._grayscale:
126 | video_writer.add(frame[:, :, 0])
127 | else:
128 | video_writer.add(frame)
129 | if infos[0].get('frame:track') is not None:
130 | has_alternative_video = True
131 | frame = infos[0]['frame:track']
132 | if self._grayscale:
133 | video_writer2.add(frame[:, :, 0])
134 | else:
135 | video_writer2.add(frame)
136 |
137 | if self._metric_callback:
138 | self._metric_callback(np.mean(total_reward), global_step)
139 |
140 | print('Average reward: {}, total reward: {}'.format(np.mean(total_reward),
141 | total_reward))
142 | if self._discrete_actions:
143 | print('Action distribution: {}'.format(action_distribution))
144 | video_writer.close(video_filename)
145 | if has_alternative_video:
146 | video_writer2.close(video_filename2)
147 |
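148 | # Usage sketch (illustrative only). Assuming `vec_env` is a baselines VecEnv
149 | # and `model` is a trained ppo2 model whose step function returns
150 | # (actions, values, states, neglogpacs), an evaluation could look like:
151 | #
152 | #   evaluator = PolicyEvaluator(vec_env, video_filename='/tmp/eval')
153 | #   evaluator.evaluate(model.step, global_step=0)
154 | #
155 | # The path '/tmp/eval' and the variable names above are hypothetical.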
--------------------------------------------------------------------------------
/episodic_curiosity/generate_r_training_data_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 Google LLC.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for generate_r_training_data."""
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | from absl.testing import absltest
22 | from episodic_curiosity.environments import fake_gym_env
23 | from episodic_curiosity.generate_r_training_data import generate_random_episode_buffer
24 |
25 |
26 | class TestGenerateRTrainingData(absltest.TestCase):
27 |
28 | def setUp(self):
29 | # Fake Environment with an infinite length episode.
30 | self.env = fake_gym_env.FakeGymEnv()
31 |
32 | def test_generate_random_episode(self):
33 | for _ in range(2):
34 | episode_buffer = generate_random_episode_buffer(self.env)
35 | self.assertEqual(len(episode_buffer), self.env.EPISODE_LENGTH)
36 | self.assertTupleEqual(episode_buffer[0][0].shape,
37 | self.env.OBSERVATION_SHAPE)
38 |
39 |
40 | if __name__ == '__main__':
41 | absltest.main()
42 |
--------------------------------------------------------------------------------
/episodic_curiosity/keras_checkpoint.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 Google LLC.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Keras checkpointing using GFile."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 |
21 | from __future__ import print_function
22 |
23 | import os
24 | import tempfile
25 | import time
26 |
27 | from absl import logging
28 | import tensorflow as tf
29 | from tensorflow import keras
30 |
31 |
32 | # Taken from: dos/ml/lstm/train_util.py, but also supports unformatted strings
33 | # and writing summary files.
34 | class GFileModelCheckpoint(keras.callbacks.ModelCheckpoint):
35 | """Keras callback to checkpoint model to a gfile location.
36 |
37 | Makes the keras ModelCheckpoint callback compatible with google filesystem
38 | paths, such as CNS files.
39 | Models will be saved to tmp_file_path and copied from there to file_path.
40 | Also writes a summary file with model performance along with the checkpoint.
41 | """
42 |
43 | def __init__(self,
44 | file_path,
45 | save_summary,
46 |                summary=None,
47 | *args,
48 | **kwargs): # pylint: disable=keyword-arg-before-vararg
49 | """Initializes checkpointer with appropriate filepaths.
50 |
51 | Args:
52 | file_path: gfile location to save model to. Supports unformatted strings
53 | similarly to keras ModelCheckpoint.
54 | save_summary: Whether we should generate and save a summary file.
55 | summary: Additional items to write to the summary file.
56 | *args: positional args passed to the underlying ModelCheckpoint.
57 | **kwargs: named args passed to the underlying ModelCheckpoint.
58 | """
59 | self.save_summary = save_summary
60 | self.summary = summary
61 | # We assume that this directory is not used by anybody else, so we uniquify
62 | # it (a bit overkill, but hey).
63 | self.tmp_dir = os.path.join(
64 | tempfile.gettempdir(),
65 | 'tmp_keras_weights_%d_%d' % (int(time.time() * 1e6), id(self)))
66 | tf.gfile.MakeDirs(self.tmp_dir)
67 | self.tmp_path = os.path.join(self.tmp_dir, os.path.basename(file_path))
68 | self.gfile_dir = os.path.dirname(file_path)
69 | super(GFileModelCheckpoint, self).__init__(self.tmp_path, *args, **kwargs)
70 |
71 |   def on_epoch_end(self, epoch, logs=None):
72 | """At end of epoch, performs the gfile checkpointing."""
73 |     super(GFileModelCheckpoint, self).on_epoch_end(epoch, logs=logs)
74 | if self.epochs_since_last_save == 0: # ModelCheckpoint just saved
75 | tmp_dir_contents = tf.gfile.ListDirectory(self.tmp_dir)
76 | for tmp_weights_filename in tmp_dir_contents:
77 | src = os.path.join(self.tmp_dir, tmp_weights_filename)
78 | dst = os.path.join(self.gfile_dir, tmp_weights_filename)
79 | logging.info('Copying saved keras model weights from %s to %s', src,
80 | dst)
81 | tf.gfile.Copy(src, dst, overwrite=True)
82 | tf.gfile.Remove(src)
83 | if self.save_summary:
84 | merged_summary = {}
85 |         merged_summary.update(self.summary or {})
86 | if logs:
87 | merged_summary.update(logs)
88 | with tf.gfile.Open(dst.replace('.h5', '.summary.txt'),
89 | 'w') as summary_file:
90 | summary_file.write('\n'.join(
91 | ['{}: {}'.format(k, v) for k, v in merged_summary.items()]))
92 |
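93 | # Usage sketch (illustrative only; see train_r_test.py for a real example).
94 | # `model` is any compiled Keras model and the .h5 path is hypothetical:
95 | #
96 | #   checkpoint = GFileModelCheckpoint(
97 | #       '/tmp/r_network_weights.{epoch:05d}.h5',
98 | #       save_summary=True,
99 | #       summary={'level': 'explore_goal_locations_small'},
100 | #       save_weights_only=True,
101 | #       period=1)
102 | #   model.fit(x, y, callbacks=[checkpoint])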
--------------------------------------------------------------------------------
/episodic_curiosity/logging.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 Google LLC.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Logging for episodic curiosity."""
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 | import datetime
21 | import numpy as np
22 |
23 |
24 | class VideoWriter(object):
25 | """Wrapper around video writer APIs."""
26 |
27 | def __init__(self, filename):
28 | # We don't have this library inside Google, so we only import it in the
29 | # Open-source version when we need it, and tell pytype to ignore the
30 | # fact that it's missing.
31 | # pylint: disable=g-import-not-at-top
32 | import skvideo.io # type: ignore
33 | self._writer = skvideo.io.FFmpegWriter(filename)
34 |
35 | def add(self, frame):
36 | self._writer.writeFrame(frame)
37 |
38 | def close(self):
39 | self._writer.close()
40 |
41 |
42 |
43 |
44 | def get_video_writer(video_filename):
45 | return VideoWriter(video_filename) # pylint:disable=unreachable
46 |
47 |
48 | def save_episode_buffer_as_video(episode_buffer, video_filename):
49 | """Saves episode_buffer."""
50 | video_writer = get_video_writer(video_filename)
51 | for frame in episode_buffer:
52 | video_writer.add(frame)
53 | video_writer.close()
54 |
55 |
56 | def save_training_examples_as_video(training_examples, video_filename):
57 | """Split example into two images and show side-by-side for a while."""
58 | video_writer = get_video_writer(video_filename)
59 | for example in training_examples:
60 | first = example[Ellipsis, :3]
61 | second = example[Ellipsis, 3:]
62 | side_by_side = np.concatenate((first, second), axis=0)
63 | video_writer.add(side_by_side)
64 | video_writer.close()
65 |
66 |
67 | def get_logger_dir(exp_id):
68 | return datetime.datetime.now().strftime('ec-%Y-%m-%d-%H-%M-%S-%f_') + exp_id
69 |
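70 | # For example (illustrative), get_logger_dir('exp42') returns something like
71 | # 'ec-2019-03-01-12-34-56-000123_exp42'.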
--------------------------------------------------------------------------------
/episodic_curiosity/oracle.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 Google LLC.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Computes some oracle reward based on the actual agent position."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import gin
23 |
24 |
25 | @gin.configurable
26 | class OracleExplorationReward(object):
27 | """Class that computes the ideal exploration bonus."""
28 |
29 | def __init__(self, reward_grid_size=30.0, cell_reward_normalizer=900.0):
30 | """Creates a new oracle to compute the exploration reward.
31 |
32 | Args:
33 | reward_grid_size: Size of a cell that contains a unique reward.
34 |       cell_reward_normalizer: Denominator for the computation of a cell reward.
35 | """
36 | self._reward_grid_size = reward_grid_size
37 |
38 | # Make the total sum of exploration reward that can be collected
39 | # independent of the grid size.
40 |     # Here, we assume that the position is lying on a 2D manifold,
41 | # hence the multiplication by the area of a 2D cell.
42 | self._cell_reward = float(reward_grid_size * reward_grid_size)
43 |
44 | # Somewhat normalize the exploration reward so that it is neither
45 |     # too big nor too small.
46 | self._cell_reward /= cell_reward_normalizer
47 |
48 | self.reset()
49 |
50 | def reset(self):
51 | self._collected_positions = set()
52 |
53 | def update_position(self, agent_position):
54 | """Set the new state (i.e. the position).
55 |
56 | Args:
57 | agent_position: x,y,z position of the agent.
58 |
59 | Returns:
60 | The exploration bonus for having visited this position.
61 | """
62 | x, y, z = agent_position
63 | quantized_x = int(x / self._reward_grid_size)
64 | quantized_y = int(y / self._reward_grid_size)
65 | quantized_z = int(z / self._reward_grid_size)
66 | position_id = (quantized_x, quantized_y, quantized_z)
67 | if position_id in self._collected_positions:
68 | # No reward if the position has already been explored.
69 | return 0.0
70 | else:
71 | self._collected_positions.add(position_id)
72 | return self._cell_reward
73 |
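74 | # Usage sketch (illustrative only). With the default reward_grid_size=30 and
75 | # cell_reward_normalizer=900, each newly visited 30x30 cell is worth
76 | # 30 * 30 / 900 = 1.0:
77 | #
78 | #   oracle = OracleExplorationReward()
79 | #   oracle.update_position((10.0, 5.0, 0.0))   # -> 1.0 (new cell)
80 | #   oracle.update_position((12.0, 7.0, 0.0))   # -> 0.0 (same cell)
81 | #   oracle.update_position((45.0, 5.0, 0.0))   # -> 1.0 (new cell)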
--------------------------------------------------------------------------------
/episodic_curiosity/r_network.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 Google LLC.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """R-network and some related functions to train R-networks."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 |
21 | from __future__ import print_function
22 |
23 | import tempfile
24 |
25 | from absl import logging
26 | from third_party.keras_resnet import models
27 | import numpy as np
28 | import tensorflow as tf
29 | from tensorflow import keras
30 |
31 |
32 | class RNetwork(object):
33 | """Encapsulates a trained R network, with lazy loading of weights."""
34 |
35 | def __init__(self, input_shape, weight_path):
36 | """Inits the RNetwork.
37 |
38 | Args:
39 | input_shape: (height, width, channel)
40 | weight_path: Path to the weights of the r_network.
41 | """
42 | self._weight_path = weight_path
43 | (self._r_network, self._embedding_network,
44 | self._similarity_network) = models.ResnetBuilder.build_siamese_resnet_18(
45 | input_shape)
46 | self._r_network.compile(
47 | loss='categorical_crossentropy', optimizer=keras.optimizers.Adam())
48 | self._weights_loaded = False
49 |
50 | def _maybe_load_weights(self):
51 | """Loads R-network weights if needed.
52 |
53 | The RNetwork is used together with an environment used by ppo2.learn.
54 | Unfortunately, ppo2.learn initializes all global TF variables at the
55 | beginning of the training, which in particular, random-initializes the
56 | weights of the R Network. We therefore load the weights lazily, to make sure
57 | they are loaded after the global initialization happens in ppo2.learn.
58 | """
59 | if self._weights_loaded:
60 | return
61 | if self._weight_path is None:
62 | # Typically the case when doing online training of the R-network.
63 | return
64 | # Keras does not support reading weights from CNS, so we have to copy the
65 | # weights to a temporary local file.
66 | with tempfile.NamedTemporaryFile(prefix='r_net', suffix='.h5',
67 | delete=False) as tmp_file:
68 | tmp_path = tmp_file.name
69 | tf.gfile.Copy(self._weight_path, tmp_path, overwrite=True)
70 | logging.info('Loading weights from %s...', tmp_path)
71 | print('Loading into R network:')
72 | self._r_network.summary()
73 | self._r_network.load_weights(tmp_path)
74 | tf.gfile.Remove(tmp_path)
75 | self._weights_loaded = True
76 |
77 | def embed_observation(self, x):
78 | """Embeds an observation.
79 |
80 | Args:
81 | x: batched input observations. Expected to have the shape specified when
82 |       the RNetwork was constructed (plus the batch dimension as first dim).
83 |
84 | Returns:
85 | embedding, shape [batch, models.EMBEDDING_DIM]
86 | """
87 | self._maybe_load_weights()
88 | return self._embedding_network.predict(x)
89 |
90 | def embedding_similarity(self, x, y):
91 | """Computes the similarity between two embeddings.
92 |
93 | Args:
94 | x: batch of the first embedding. Shape [batch, models.EMBEDDING_DIM].
95 |       y: batch of the second embedding. Shape [batch, models.EMBEDDING_DIM].
96 |
97 | Returns:
98 | Similarity probabilities. 1 means very similar according to the net.
99 | 0 means very dissimilar. Shape [batch].
100 | """
101 | self._maybe_load_weights()
102 | return self._similarity_network.predict([x, y],
103 | batch_size=1024)[:, 1]
104 |
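105 | # Usage sketch (illustrative only). The observation shape and weight path are
106 | # hypothetical:
107 | #
108 | #   r_net = RNetwork(input_shape=(120, 160, 3),
109 | #                    weight_path='/path/to/r_network_weights.h5')
110 | #   emb_x = r_net.embed_observation(observations_x)  # [batch, EMBEDDING_DIM]
111 | #   emb_y = r_net.embed_observation(observations_y)
112 | #   reachability = r_net.embedding_similarity(emb_x, emb_y)  # in [0, 1]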
--------------------------------------------------------------------------------
/episodic_curiosity/r_network_training_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 Google LLC.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for r_network_training."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import random
23 |
24 | from absl.testing import absltest
25 | from episodic_curiosity import r_network_training
26 | from episodic_curiosity.constants import Const
27 | from episodic_curiosity.r_network_training import create_training_data_from_episode_buffer_v123
28 | from episodic_curiosity.r_network_training import create_training_data_from_episode_buffer_v4
29 | from episodic_curiosity.r_network_training import generate_negative_example
30 | import numpy as np
31 |
32 |
33 | _OBSERVATION_SHAPE = (120, 160, 3)
34 |
35 |
36 | class TestRNetworkTraining(absltest.TestCase):
37 |
38 | def test_generate_negative_example(self):
39 | max_action_distance = 5
40 | len_episode_buffer = 100
41 | for _ in range(1000):
42 | buffer_position = np.random.randint(low=0, high=len_episode_buffer)
43 | first, second = generate_negative_example(buffer_position,
44 | len_episode_buffer,
45 | max_action_distance)
46 | self.assertGreater(
47 | abs(second - first),
48 | Const.NEGATIVE_SAMPLE_MULTIPLIER * max_action_distance)
49 | self.assertGreaterEqual(second, 0)
50 | self.assertLess(second, len_episode_buffer)
51 |
52 | def test_generate_negative_example2(self):
53 | max_action_distance = 5
54 | range_max = 5 * Const.NEGATIVE_SAMPLE_MULTIPLIER * max_action_distance
55 | for buffer_position in range(0, range_max):
56 | for buffer_length in range(buffer_position + 1, 3 * range_max):
57 | # Mainly check that it does not raise an exception.
58 | new_pos, index = generate_negative_example(buffer_position,
59 | buffer_length,
60 | max_action_distance)
61 | msg = 'buffer_pos={}, buffer_length={}'.format(buffer_position,
62 | buffer_length)
63 | self.assertLess(new_pos, buffer_length, msg)
64 | self.assertGreaterEqual(new_pos, 0, msg)
65 | if index is not None:
66 | self.assertLess(index, buffer_length, msg)
67 | self.assertGreaterEqual(index, 0, msg)
68 |
69 | def test_create_training_data_from_episode_buffer_v1(self):
70 | # Make the test deterministic.
71 | random.seed(7)
72 | max_action_distance_mode = 'v1_affect_num_training_examples'
73 | max_action_distance = 5
74 | len_episode_buffer = 1000
75 | episode_buffer = [(np.zeros(_OBSERVATION_SHAPE), {
76 | 'position': i
77 | }) for i in range(len_episode_buffer)]
78 | x1, x2, labels = create_training_data_from_episode_buffer_v123(
79 | episode_buffer, max_action_distance, max_action_distance_mode)
80 | self.assertEqual(len(x1), len(x2))
81 | self.assertEqual(len(labels), len(x1))
82 | # 4 is the average of 1 + randint(1, 5).
83 | self.assertGreater(len(x1), len_episode_buffer / 4 * 0.8)
84 | self.assertLess(len(x1), len_episode_buffer / 4 * 1.2)
85 | for x in [x1, x2]:
86 | self.assertTupleEqual(x[0][0].shape, _OBSERVATION_SHAPE)
87 | max_previous_pos = -1
88 | for xx1, xx2, label in zip(x1, x2, labels):
89 | if not label:
90 | continue
91 | obs1, info1 = xx1
92 | del obs1 # unused
93 | pos1 = info1['position']
94 | obs2, info2 = xx2
95 | del obs2 # unused
96 | pos2 = info2['position']
97 | assert max_previous_pos < pos1, (
98 | 'In v1 mode, intervals of positive examples should have no overlap.')
99 | assert max_previous_pos < pos2, (
100 | 'In v1 mode, intervals of positive examples should have no overlap.')
101 | max_previous_pos = max(pos1, pos2)
102 | self._check_example_pairs(x1, x2, labels, max_action_distance)
103 |
104 | def test_create_training_data_from_episode_buffer_v2(self):
105 | # Make the test deterministic.
106 | random.seed(7)
107 | max_action_distance_mode = 'v2_fixed_num_training_examples'
108 | max_action_distance = 2
109 | len_episode_buffer = 1000
110 | episode_buffer = [(np.zeros(_OBSERVATION_SHAPE), {
111 | 'position': i
112 | }) for i in range(len_episode_buffer)]
113 | x1, x2, labels = create_training_data_from_episode_buffer_v123(
114 | episode_buffer, max_action_distance, max_action_distance_mode)
115 | self.assertEqual(len(x1), len(x2))
116 | self.assertEqual(len(labels), len(x1))
117 | # 4 is the average of 1 + randint(1, 5), where 5 is used in the v2 sampling
118 |     # algorithm in order to keep the same number of training examples regardless
119 | # of max_action_distance.
120 | self.assertGreater(len(x1), len_episode_buffer / 4 * 0.8)
121 | self.assertLess(len(x1), len_episode_buffer / 4 * 1.2)
122 | self._check_example_pairs(x1, x2, labels, max_action_distance)
123 |
124 | def test_create_training_data_from_episode_buffer_v4(self):
125 | # Make the test deterministic.
126 | random.seed(7)
127 | for avg_num_examples_per_env_step in (0.2, 0.5, 1, 2):
128 | max_action_distance = 5
129 | len_episode_buffer = 1000
130 | episode_buffer = [(np.zeros(_OBSERVATION_SHAPE), {
131 | 'position': i
132 | }) for i in range(len_episode_buffer)]
133 | x1, x2, labels = create_training_data_from_episode_buffer_v4(
134 | episode_buffer, max_action_distance, avg_num_examples_per_env_step)
135 | self.assertEqual(len(x1), len(x2))
136 | self.assertEqual(len(labels), len(x1))
137 | num_positives = len([label for label in labels if label == 1])
138 | self._assert_within(num_positives, len(x1) // 2, 5)
139 | self._assert_within(
140 | len(labels),
141 | len(episode_buffer) * avg_num_examples_per_env_step, 5)
142 | self._check_example_pairs(x1, x2, labels, max_action_distance)
143 |
144 | def test_create_training_data_from_episode_buffer_too_short(self):
145 | max_action_distance = 5
146 | buff = [np.zeros(_OBSERVATION_SHAPE)] * max_action_distance
147 | # Repeat the test multiple times. create_training_data_from_episode_buffer
148 | # uses randomness, so we want to hit the case where it tries to generate
149 | # negative examples.
150 | for _ in range(50):
151 | _, _, labels = create_training_data_from_episode_buffer_v123(
152 | buff, max_action_distance, mode='v1_affect_num_training_examples')
153 | for label in labels:
154 |         # Not enough buffer to generate negative examples, so we should only get
155 |         # positive examples (possibly none at all).
156 | self.assertEqual(label, 1)
157 |
158 | def _check_example_pairs(self, x1, x2, labels, max_action_distance):
159 | for xx1, xx2, label in zip(x1, x2, labels):
160 | obs1, info1 = xx1
161 | del obs1 # unused
162 | pos1 = info1['position']
163 | obs2, info2 = xx2
164 | del obs2 # unused
165 | pos2 = info2['position']
166 | diff = abs(pos1 - pos2)
167 | if label:
168 | self.assertLessEqual(diff, max_action_distance)
169 | else:
170 | self.assertGreater(
171 | diff, Const.NEGATIVE_SAMPLE_MULTIPLIER * max_action_distance)
172 |
173 | def _assert_within(self, value, expected, percentage):
174 | self.assertGreater(
175 | value,
176 | expected * (100 - percentage) / 100, '{} vs {} within {}%'.format(
177 | value, expected, percentage))
178 | self.assertLess(value,
179 | expected * (100 + percentage) / 100,
180 | '{} vs {} within {}%'.format(value, expected, percentage))
181 |
182 | def fit_generator(self,
183 | batch_gen, steps_per_epoch, epochs,
184 | validation_data): # pylint: disable=unused-argument
185 | max_distance = 5
186 | for _ in range(epochs):
187 | for _ in range(steps_per_epoch):
188 | pairs, labels = next(batch_gen)
189 |
190 | # Make sure all the pairs are coming from the same trajectory / env
191 | # and that they are compatible with the given labels.
192 | batch_x1 = pairs[0]
193 | batch_x2 = pairs[1]
194 | batch_size = batch_x1.shape[0]
195 | assert batch_x2.shape[0] == batch_size
196 | for k in range(batch_size):
197 | x1 = batch_x1[k]
198 | x2 = batch_x2[k]
199 | env_x1, trajectory_x1, step_x1 = x1
200 | env_x2, trajectory_x2, step_x2 = x2
201 | self.assertEqual(env_x1, env_x2)
202 | self.assertEqual(trajectory_x1, trajectory_x2)
203 |
204 | distance = abs(step_x2 - step_x1)
205 | if labels[k][1]:
206 | self.assertLessEqual(distance, max_distance)
207 | else:
208 | self.assertGreater(distance, max_distance)
209 |
210 | def test_r_network_training(self):
211 | r_model = self
212 | r_trainer = r_network_training.RNetworkTrainer(
213 | r_model,
214 | observation_history_size=10000,
215 | training_interval=20000)
216 |
217 | # Every observation is a vector of dimension 3:
218 | # (worker_id, trajectory_id, step_num)
219 | feed_count = 20000
220 | env_count = 8
221 | worker_id = list(range(env_count))
222 | trajectory_id = [0] * env_count
223 | step_idx = [0] * env_count
224 | proba_done = 0.01
225 | for _ in range(feed_count):
226 | # Observations: size = env_count x 3
227 | observations = np.stack([worker_id, trajectory_id, step_idx])
228 | observations = np.transpose(observations)
229 | dones = np.random.choice([True, False], size=env_count,
230 | p=[proba_done, 1.0 - proba_done])
231 | r_trainer.on_new_observation(observations, None, dones, None)
232 |
233 | step_idx = [s + 1 for s in step_idx]
234 |
235 | # Update the trajectory index and the environment step.
236 | for k in range(env_count):
237 | if dones[k]:
238 | step_idx[k] = 0
239 | trajectory_id[k] += 1
240 |
241 |
242 | if __name__ == '__main__':
243 | absltest.main()
244 |
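245 | # Note on the pairing rule exercised above (illustrative): a pair of
246 | # observations is labeled positive when the two observations are at most
247 | # max_action_distance steps apart in the episode, and negative when they are
248 | # more than Const.NEGATIVE_SAMPLE_MULTIPLIER * max_action_distance steps
249 | # apart. For instance, with max_action_distance = 5 and a multiplier of 5,
250 | # positives would be at most 5 steps apart and negatives more than 25.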
--------------------------------------------------------------------------------
/episodic_curiosity/train_policy.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 Google LLC.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | r"""Main file for training policies.
17 |
18 | Many hyperparameters need to be passed through gin flags.
19 | Consider using scripts/launcher_script.py to invoke train_policy with the
20 | right hyperparameters and flags.
21 | """
22 |
23 | from __future__ import absolute_import
24 | from __future__ import division
25 | from __future__ import print_function
26 |
27 | import os
28 | import tempfile
29 | import time
30 |
31 | from absl import flags
32 | from episodic_curiosity import env_factory
33 | from episodic_curiosity import eval_policy
34 | from episodic_curiosity import utils
35 | from third_party.baselines import logger
36 | from third_party.baselines.ppo2 import policies
37 | from third_party.baselines.ppo2 import ppo2
38 | import gin
39 | import tensorflow as tf
40 |
41 |
42 | flags.DEFINE_string('workdir', None,
43 | 'Root directory for writing logs/summaries/checkpoints.')
44 | flags.DEFINE_string('env_name', 'CartPole-v0', 'What environment to run')
45 | flags.DEFINE_string('policy_architecture', 'cnn',
46 | 'What model architecture to use')
47 | flags.DEFINE_string('r_checkpoint', '', 'Location of the R-network checkpoint')
48 | flags.DEFINE_integer('num_env', 12, 'Number of environment copies to run in '
49 | 'subprocesses.')
50 | flags.DEFINE_string('dmlab_homepath', '', '')
51 | flags.DEFINE_integer('num_timesteps', 10000000, 'Number of frames to run '
52 | 'training for.')
53 | flags.DEFINE_string('action_set', '',
54 | '(small|nofire|) - which action set to use')
55 | flags.DEFINE_bool('use_curiosity', False,
56 | 'Whether to enable Pathak\'s curiosity')
57 | flags.DEFINE_bool('random_state_predictor', False,
58 | 'Whether to use random state predictor for Pathak\'s '
59 | 'curiosity')
60 | flags.DEFINE_float('curiosity_strength', 0.01,
61 | 'Strength of the intrinsic reward in Pathak\'s algorithm.')
62 | flags.DEFINE_float('forward_inverse_ratio', 0.2,
63 | 'Weighting of forward vs inverse loss in Pathak\'s '
64 | 'algorithm')
65 | flags.DEFINE_float('curiosity_loss_strength', 10,
66 | 'Weight of the curiosity loss in Pathak\'s algorithm.')
67 |
68 |
69 | # pylint: disable=g-inconsistent-quotes
70 | flags.DEFINE_multi_string(
71 | 'gin_files', [], 'List of paths to gin configuration files')
72 | flags.DEFINE_multi_string(
73 | 'gin_bindings', [],
74 | 'Gin bindings to override the values set in the config files '
75 |     '(e.g. "DQNAgent.epsilon_train=0.1", '
76 |     '"create_environment.game_name=\'Pong\'").')
77 | # pylint: enable=g-inconsistent-quotes
78 |
79 | FLAGS = flags.FLAGS
80 |
81 |
82 | def get_environment(env_name):
83 | dmlab_prefix = 'dmlab:'
84 | atari_prefix = 'atari:'
85 | parkour_prefix = 'parkour:'
86 | if env_name.startswith(dmlab_prefix):
87 | level_name = env_name[len(dmlab_prefix):]
88 | return env_factory.create_environments(
89 | level_name,
90 | FLAGS.num_env,
91 | FLAGS.r_checkpoint,
92 | FLAGS.dmlab_homepath,
93 | action_set=FLAGS.action_set,
94 | r_network_weights_store_path=FLAGS.workdir)
95 | elif env_name.startswith(atari_prefix):
96 | level_name = env_name[len(atari_prefix):]
97 | return env_factory.create_environments(
98 | level_name,
99 | FLAGS.num_env,
100 | FLAGS.r_checkpoint,
101 | environment_engine='atari',
102 | r_network_weights_store_path=FLAGS.workdir)
103 | if env_name.startswith(parkour_prefix):
104 | return env_factory.create_environments(
105 | env_name[len(parkour_prefix):],
106 | FLAGS.num_env,
107 | FLAGS.r_checkpoint,
108 | environment_engine='parkour',
109 | r_network_weights_store_path=FLAGS.workdir)
110 | raise ValueError('Unknown environment: {}'.format(env_name))
111 |
112 |
113 | @gin.configurable
114 | def train(workdir, env_name, num_timesteps,
115 | nsteps=256,
116 | nminibatches=4,
117 | noptepochs=4,
118 | learning_rate=2.5e-4,
119 | ent_coef=0.01):
120 | """Runs PPO training.
121 |
122 | Args:
123 | workdir: where to store experiment results/logs
124 | env_name: the name of the environment to run
125 | num_timesteps: for how many timesteps to run training
126 | nsteps: Number of consecutive environment steps to use during training.
127 |     nminibatches: Number of training minibatches per update.
128 | noptepochs: Number of optimization epochs.
129 | learning_rate: Initial learning rate.
130 | ent_coef: Entropy coefficient.
131 | """
132 | train_measurements = utils.create_measurement_series(workdir, 'reward_train')
133 | valid_measurements = utils.create_measurement_series(workdir, 'reward_valid')
134 | test_measurements = utils.create_measurement_series(workdir, 'reward_test')
135 |
136 | def measurement_callback(unused_eplenmean, eprewmean, global_step_val):
137 | if train_measurements:
138 | train_measurements.create_measurement(
139 | objective_value=eprewmean, step=global_step_val)
140 |
141 | def eval_callback_on_valid(eprewmean, global_step_val):
142 | if valid_measurements:
143 | valid_measurements.create_measurement(
144 | objective_value=eprewmean, step=global_step_val)
145 |
146 | def eval_callback_on_test(eprewmean, global_step_val):
147 | if test_measurements:
148 | test_measurements.create_measurement(
149 | objective_value=eprewmean, step=global_step_val)
150 |
151 | logger_dir = workdir
152 | logger.configure(logger_dir)
153 |
154 | env, valid_env, test_env = get_environment(env_name)
155 | is_ant = env_name.startswith('parkour:')
156 |
157 | # Validation metric.
158 | policy_evaluator_on_valid = eval_policy.PolicyEvaluator(
159 | valid_env,
160 | metric_callback=eval_callback_on_valid,
161 | video_filename=None)
162 |
163 | # Test metric (+ videos).
164 | video_filename = os.path.join(FLAGS.workdir, 'video')
165 | policy_evaluator_on_test = eval_policy.PolicyEvaluator(
166 | test_env,
167 | metric_callback=eval_callback_on_test,
168 | video_filename=video_filename,
169 | grayscale=(env_name.startswith('atari:')))
170 |
171 |   # Delay to make sure that all the DMLab environments acquire
172 |   # the GPU resources before TensorFlow acquires the rest of the memory.
173 |   # TODO(damienv): Possibly use allow_growth in a TensorFlow session
174 |   # so that there is no such problem anymore.
175 | time.sleep(15)
176 |
177 | cloud_sync_callback = lambda: None
178 |
179 | def evaluate_valid_test(model_step_fn, global_step):
180 | if not is_ant:
181 | policy_evaluator_on_valid.evaluate(model_step_fn, global_step)
182 | policy_evaluator_on_test.evaluate(model_step_fn, global_step)
183 |
184 | with tf.Session():
185 | policy = {'cnn': policies.CnnPolicy,
186 | 'lstm': policies.LstmPolicy,
187 | 'lnlstm': policies.LnLstmPolicy,
188 | 'mlp': policies.MlpPolicy}[FLAGS.policy_architecture]
189 |
190 |     # OpenAI baselines never performs num_timesteps env steps because
191 |     # of the way it samples training data in batches. The number of timesteps
192 |     # is multiplied by 1.1 (hacky) to ensure at least num_timesteps are
193 |     # performed.
194 |
195 | ppo2.learn(policy, env=env, nsteps=nsteps, nminibatches=nminibatches,
196 | lam=0.95, gamma=0.99, noptepochs=noptepochs, log_interval=1,
197 | ent_coef=ent_coef,
198 | lr=learning_rate if is_ant else lambda f: f * learning_rate,
199 | cliprange=0.2 if is_ant else lambda f: f * 0.1,
200 | total_timesteps=int(num_timesteps * 1.1),
201 | train_callback=measurement_callback,
202 | eval_callback=evaluate_valid_test,
203 | cloud_sync_callback=cloud_sync_callback,
204 | save_interval=200, workdir=workdir,
205 | use_curiosity=FLAGS.use_curiosity,
206 | curiosity_strength=FLAGS.curiosity_strength,
207 | forward_inverse_ratio=FLAGS.forward_inverse_ratio,
208 | curiosity_loss_strength=FLAGS.curiosity_loss_strength,
209 | random_state_predictor=FLAGS.random_state_predictor)
210 | cloud_sync_callback()
211 | test_env.close()
212 | valid_env.close()
213 | utils.maybe_close_measurements(train_measurements)
214 | utils.maybe_close_measurements(valid_measurements)
215 | utils.maybe_close_measurements(test_measurements)
216 |
217 |
218 |
219 |
220 | def main(_):
221 | utils.dump_flags_to_file(os.path.join(FLAGS.workdir, 'flags.txt'))
222 | tf.logging.set_verbosity(tf.logging.INFO)
223 | gin.parse_config_files_and_bindings(FLAGS.gin_files,
224 | FLAGS.gin_bindings)
225 | train(FLAGS.workdir, env_name=FLAGS.env_name,
226 | num_timesteps=FLAGS.num_timesteps)
227 |
228 |
229 | if __name__ == '__main__':
230 | tf.app.run()
231 |
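232 | # Example invocation (illustrative; paths and bindings are placeholders, and
233 | # scripts/launcher_script.py is the recommended entry point):
234 | #
235 | #   python episodic_curiosity/train_policy.py \
236 | #     --workdir=/tmp/ec_workdir \
237 | #     --env_name=dmlab:explore_goal_locations_small \
238 | #     --r_checkpoint=/path/to/r_network_weights.h5 \
239 | #     --gin_bindings='train.ent_coef=0.01'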
--------------------------------------------------------------------------------
/episodic_curiosity/train_r_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 Google LLC.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for dune.rl.episodic_curiosity.train_r."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | import os
22 | from absl import flags
23 | from absl.testing import absltest
24 | from episodic_curiosity import constants
25 | from episodic_curiosity import keras_checkpoint
26 | from episodic_curiosity import train_r
27 | import mock
28 | import numpy as np
29 | import tensorflow as tf
30 | from tensorflow import keras
31 |
32 | FLAGS = flags.FLAGS
33 |
34 |
35 | class TrainRTest(absltest.TestCase):
36 |
37 | def setUp(self):
38 | super(TrainRTest, self).setUp()
39 | keras.backend.clear_session()
40 |
41 | def test_export_stats_to_xm(self):
42 | xm_series = train_r.XmSeries(
43 | loss=mock.MagicMock(),
44 | acc=mock.MagicMock(),
45 | val_loss=mock.MagicMock(),
46 | val_acc=mock.MagicMock())
47 | self._fit_model_with_callback(train_r.ExportStatsToXm(xm_series))
48 | for series in xm_series._asdict().values():
49 | self.assertEqual(series.create_measurement.call_count, 1)
50 |
51 | def test_model_weights_checkpoint(self):
52 | path = os.path.join(FLAGS.test_tmpdir, 'r_network_weights.{epoch:05d}.h5')
53 | self._fit_model_with_callback(
54 | keras_checkpoint.GFileModelCheckpoint(
55 | path,
56 | save_summary=True,
57 | summary=constants.Level('explore_goal_locations_small').asdict(),
58 | save_weights_only=True,
59 | period=1))
60 | self.assertTrue(tf.gfile.Exists(path.format(epoch=1)))
61 | self.assertTrue(
62 | tf.gfile.Exists(path.format(epoch=1).replace('h5', 'summary.txt')))
63 |
64 | def test_full_model_checkpoint(self):
65 | path = os.path.join(FLAGS.test_tmpdir, 'r_network_full.{epoch:05d}.h5')
66 | self._fit_model_with_callback(
67 | keras_checkpoint.GFileModelCheckpoint(
68 | path, save_summary=False, save_weights_only=False, period=1))
69 | self.assertTrue(tf.gfile.Exists(path.format(epoch=1)))
70 | self.assertFalse(
71 | tf.gfile.Exists(path.format(epoch=1).replace('h5', 'summary.txt')))
72 |
73 | def _fit_model_with_callback(self, callback):
74 | inpt = keras.layers.Input(shape=(1,))
75 | model = keras.models.Model(inputs=inpt, outputs=keras.layers.Dense(1)(inpt))
76 | model.compile(
77 | loss='binary_crossentropy', optimizer='sgd', metrics=['accuracy'])
78 | model.fit(
79 | x=np.random.rand(100, 1),
80 | y=np.random.rand(100, 1),
81 | validation_data=(np.random.rand(100, 1), np.random.rand(100, 1)),
82 | callbacks=[callback])
83 |
84 |
85 | if __name__ == '__main__':
86 | absltest.main()
87 |
--------------------------------------------------------------------------------
/episodic_curiosity/utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 Google LLC.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """A few utilities for episodic curiosity.
17 | """
18 |
19 | from __future__ import absolute_import
20 | from __future__ import division
21 |
22 | from __future__ import print_function
23 | import csv
24 | import os
25 | import time
26 | from absl import flags
27 | import numpy as np
28 | import tensorflow as tf
29 |
30 | FLAGS = flags.FLAGS
31 |
32 |
33 | def get_frame(env_observation, info):
34 | """Searches for a rendered frame in 'info', fallbacks to the env obs."""
35 | info_frame = info.get('frame')
36 | if info_frame is not None:
37 | return info_frame
38 | return env_observation
39 |
40 |
41 | def dump_flags_to_file(filename):
42 | """Dumps FLAGS to a file."""
43 | with tf.gfile.Open(filename, 'w') as output:
44 | output.write('\n'.join([
45 | '{}={}'.format(flag_name, flag_value)
46 | for flag_name, flag_value in FLAGS.flag_values_dict().items()
47 | ]))
48 |
49 |
50 | class MeasurementsWriter(object):
51 | """Writes measurements to CSV."""
52 |
53 | def __init__(self, workdir, measurement_name):
54 | """Initializes a MeasurementsWriter.
55 |
56 | Args:
57 | workdir: Directory to which the CSV file will be written.
58 | measurement_name: Name of the measurement.
59 | """
60 | filename = os.path.join(workdir, measurement_name + '.csv')
61 | file_exists = tf.gfile.Exists(filename)
62 | self._out_file = tf.gfile.Open(filename, mode='a+')
63 | self._csv_writer = csv.writer(self._out_file)
64 | if not file_exists:
65 | self._csv_writer.writerow(['step', measurement_name, 'timestamp_s'])
66 | self._out_file.flush()
67 | self._measurement_name = measurement_name
68 | self._last_flush_time = 0
69 |
70 | def create_measurement(self, objective_value, step):
71 | """Adds a measurement.
72 |
73 | Args:
74 | objective_value: Value to report for the given training step.
75 | step: Training step.
76 | """
77 | flush_every_s = 5
78 | self._csv_writer.writerow(
79 | [str(step), str(objective_value), str(int(time.time()))])
80 | if time.time() - self._last_flush_time >= flush_every_s:
81 | self._last_flush_time = time.time()
82 | self._out_file.flush()
83 |
84 | def close(self):
85 | del self._csv_writer
86 | self._out_file.close()
87 |
88 |
89 | def create_measurement_series(workdir, label):
90 | """Creates an object for exporting a per-training-step metric."""
91 | return MeasurementsWriter(workdir, label) # pylint:disable=unreachable
92 |
93 |
94 | def maybe_close_measurements(measurements):
95 | if isinstance(measurements, MeasurementsWriter):
96 | measurements.close()
97 |
98 |
99 | def load_keras_model(path):
100 | """Loads a keras model from a h5 file path."""
101 | # pylint:disable=unreachable
102 | return tf.keras.models.load_model(path, compile=True)
103 |
--------------------------------------------------------------------------------
/misc/ant_github.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/episodic-curiosity/3c406964473d98fb977b1617a170a447b3c548fd/misc/ant_github.gif
--------------------------------------------------------------------------------
/misc/navigation_github.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/episodic-curiosity/3c406964473d98fb977b1617a170a447b3c548fd/misc/navigation_github.gif
--------------------------------------------------------------------------------
/scripts/gs_sync.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 Google LLC.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Script that periodically syncs a dir to a cloud storage bucket.
17 |
18 | We only use standard python deps so that this script can be executed in any
19 | setup.
20 | This script repeatedly calls the 'gsutil rsync' command.
21 | """
22 | from __future__ import absolute_import
23 | from __future__ import division
24 | from __future__ import print_function
25 |
26 | import argparse
27 | import os
28 | import subprocess
29 | import time
30 |
31 | parser = argparse.ArgumentParser()
32 | parser.add_argument('--workdir', help='Source to sync to the cloud bucket')
33 | parser.add_argument(
34 | '--sync_to_cloud_bucket',
35 | help='Cloud bucket (format gs://bucket_name) to sync the workdir to')
36 | FLAGS = parser.parse_args()
37 |
38 |
39 | def sync_to_cloud_bucket():
40 | """Repeatedly syncs a path to a cloud bucket using gsutil."""
41 | sync_cmd = (
42 | 'gsutil -m rsync -r {src} {dst}'.format(
43 | src=os.path.expanduser(FLAGS.workdir),
44 | dst=os.path.join(
45 | FLAGS.sync_to_cloud_bucket, os.path.basename(FLAGS.workdir))))
46 | while True:
47 | print('Syncing to cloud bucket:', sync_cmd)
48 |     # We don't stop on failure: it can be a transient issue, and a subsequent
49 |     # run might work.
50 | subprocess.call(sync_cmd, shell=True)
51 | time.sleep(60)
52 |
53 |
54 | if __name__ == '__main__':
55 | sync_to_cloud_bucket()
56 |
--------------------------------------------------------------------------------
/scripts/launch_cloud_vms.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 Google LLC.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """py3 script that creates all the VMs on GCP for episodic curiosity.
17 |
18 | Right now, this launches the experiments to reproduce table 1 of
19 | https://arxiv.org/pdf/1810.02274.pdf.
20 |
21 | This script deliberately only depends on the Python standard library, so that
22 | it can be executed under any setup.
23 |
24 | The gcloud command should be accessible (see:
25 | https://cloud.google.com/sdk/gcloud/).
26 |
27 | Invoke this script at the root of episodic-curiosity:
28 | python3 scripts/launch_cloud_vms.py.
29 |
30 | Tip: inspect the logs of the startup script on the VMs with:
31 | sudo journalctl -u google-startup-scripts.service
32 | """
33 |
34 | from __future__ import absolute_import
35 | from __future__ import division
36 | from __future__ import print_function
37 | import argparse
38 | import subprocess
39 | import concurrent.futures
40 |
41 | parser = argparse.ArgumentParser()
42 | parser.add_argument('--i_understand_launching_vms_is_expensive',
43 | action='store_true',
44 | help=('Each VM costs on the order of USD 30 per day based '
45 | 'on early 2019 GCP pricing. This script launches '
46 | 'many VMs at once, which can cost significant money. '
47 | 'Pass this flag to show that you understood this.'))
48 | FLAGS = parser.parse_args()
49 |
50 | # Information about your GCP account/project:
51 | GCLOUD_PROJECT = 'FILL-ME'
52 | SERVICE_ACCOUNT = 'FILL-ME'
53 | ZONE = 'FILL-ME'
54 | # GCP snapshot with episodic-curiosity (and its dependencies) installed as
55 | # explained in README.md.
56 | SOURCE_SNAPSHOT = 'FILL-ME'
57 | # User to use on the VMs.
58 | VM_USER = 'FILL-ME'
59 | # Training logs and checkpoints will be synced to this Google Cloud bucket path.
60 | # E.g. 'gs://my-episodic-curiosity-logs/training_logs'
61 | SYNC_LOGS_TO_PATH = 'FILL-ME'
62 |
63 | # Path on a Google Cloud bucket to the pre-trained R-networks. If empty,
64 | # R-networks will be re-trained.
65 | PRETRAINED_R_NETS_PATH = 'gs://episodic-curiosity/r_networks'
66 |
67 |
68 | # Name templates for instances and disks.
69 | NAME_TEMPLATE = 'ec-20190301-{method}-{scenario}-{run_number}'
70 |
71 | # Scenarios to launch.
72 | SCENARIOS = [
73 | 'noreward',
74 | 'norewardnofire',
75 | 'sparse',
76 | 'verysparse',
77 | 'sparseplusdoors',
78 | 'dense1',
79 | 'dense2',
80 | ]
81 |
82 | # Methods to launch.
83 | METHODS = [
84 | 'ppo',
85 | # This is the online version of episodic curiosity.
86 | 'ppo_plus_eco',
87 | # This is the version of episodic curiosity where the R-network is trained
88 | # before the policy.
89 | 'ppo_plus_ec',
90 | 'ppo_plus_grid_oracle'
91 | ]
92 |
93 | # Number of identical training jobs to launch for each scenario and method.
94 | # Given the variance across runs, multiple runs are needed in order to get
95 | # confidence in the results.
96 | NUM_REPEATS = 10
97 |
98 |
99 | def required_resources_for_method(method, uses_pretrained_r_net):
100 | """Returns the required resources for the given training method.
101 |
102 | Args:
103 | method: str, training method.
104 | uses_pretrained_r_net: bool, whether we use pre-trained r-net.
105 |
106 | Returns:
107 | Tuple (RAM (MBs), num CPUs, num GPUs)
108 | """
109 | if method == 'ppo_plus_eco':
110 | # We need to rent 2 GPUs, because with this amount of RAM, GCP won't allow
111 | # us to rent only one.
112 | return (105472, 16, 2)
113 | if method == 'ppo_plus_ec' and not uses_pretrained_r_net:
114 | return (52224, 12, 1)
115 | return (32768, 12, 1)
116 |
117 |
118 | def launch_vm(vm_id, vm_metadata):
119 | """Creates and launches a VM on Google Cloud compute engine.
120 |
121 | Args:
122 | vm_id: str, unique ID of the vm.
123 | vm_metadata: Dict[str, str], metadata key/value pairs passed to the vm.
124 | """
125 | print('\nCreating disk and vm with ID:', vm_id)
126 | vm_metadata['vm_id'] = vm_id
127 | ram_mbs, num_cpus, num_gpus = required_resources_for_method(
128 | vm_metadata['method'],
129 | bool(vm_metadata['pretrained_r_nets_path']))
130 |
131 | create_disk_cmd = (
132 | 'gcloud compute disks create '
133 | '"{disk_name}" --zone "{zone}" --source-snapshot "{source_snapshot}" '
134 | '--type "pd-standard" --project="{gcloud_project}" '
135 | '--size=200GB'.format(
136 | disk_name=vm_id,
137 | zone=ZONE,
138 | source_snapshot=SOURCE_SNAPSHOT,
139 | gcloud_project=GCLOUD_PROJECT,
140 | ))
141 | print('Calling', create_disk_cmd)
142 | # Don't fail if disk already exists.
143 | subprocess.call(create_disk_cmd, shell=True)
144 |
145 | create_instance_cmd = (
146 | 'gcloud compute --project={gcloud_project} instances create '
147 | '{instance_name} --zone={zone} --machine-type={machine_type} '
148 | '--subnet=default --network-tier=PREMIUM --maintenance-policy=TERMINATE '
149 | '--service-account={service_account} '
150 | '--scopes=storage-full,compute-rw '
151 | '--accelerator=type=nvidia-tesla-p100,count={gpu_count} '
152 | '--disk=name={disk_name},device-name={disk_name},mode=rw,boot=yes,'
153 | 'auto-delete=yes --restart-on-failure '
154 | '--metadata-from-file startup-script=./scripts/vm_drop_root.sh '
155 | '--metadata {vm_metadata} --async'.format(
156 | instance_name=vm_id,
157 | zone=ZONE,
158 | machine_type='custom-{num_cpus}-{ram_mbs}'.format(
159 | num_cpus=num_cpus, ram_mbs=ram_mbs),
160 | gpu_count=num_gpus,
161 | disk_name=vm_id,
162 | vm_metadata=(
163 | ','.join('{}={}'.format(k, v) for k, v in vm_metadata.items())),
164 | gcloud_project=GCLOUD_PROJECT,
165 | service_account=SERVICE_ACCOUNT,
166 | ))
167 |
168 | print('Calling', create_instance_cmd)
169 | subprocess.check_call(create_instance_cmd, shell=True)
170 |
171 |
172 | def main():
173 | launch_args = []
174 | for method in METHODS:
175 | for scenario in SCENARIOS:
176 | for run_number in range(NUM_REPEATS):
177 | vm_id = NAME_TEMPLATE.format(
178 | method=method.replace('_', '-'),
179 | scenario=scenario.replace('_', '-'),
180 | run_number=run_number)
181 | launch_args.append((
182 | vm_id,
183 | {
184 | 'method': method,
185 | 'scenario': scenario,
186 | 'run_number': str(run_number),
187 | 'user': VM_USER,
188 | 'pretrained_r_nets_path': PRETRAINED_R_NETS_PATH,
189 | 'sync_logs_to_path': SYNC_LOGS_TO_PATH
190 | }))
191 | print('YOU ARE ABOUT TO START', len(launch_args), 'VMs on GCP.')
192 | if not FLAGS.i_understand_launching_vms_is_expensive:
193 | print('Please pass --i_understand_launching_vms_is_expensive to specify '
194 | 'that you understood the cost implications of launching',
195 | len(launch_args), 'VMs')
196 | return
197 | # We use many threads in order to quickly start many instances.
198 | with concurrent.futures.ThreadPoolExecutor(max_workers=50) as executor:
199 | futures = []
200 | for args in launch_args:
201 | futures.append(executor.submit(launch_vm, *args))
202 | for f in futures:
203 | assert f.result() is None
204 |
205 |
206 | if __name__ == '__main__':
207 | main()
208 |
--------------------------------------------------------------------------------
/scripts/vm_drop_root.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 |
3 | # Copyright 2019 Google LLC.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | # This script runs as root when the VM instance starts.
18 | # It launches a child script, after dropping the root user.
19 |
20 | set -x
21 |
22 | USER=$(curl "http://metadata.google.internal/computeMetadata/v1/instance/attributes/user" -H "Metadata-Flavor: Google")
23 | EPISODIC_CURIOSITY_DIR="/home/${USER}/episodic-curiosity"
24 | PRETRAINED_R_NETS_PATH=$(curl "http://metadata.google.internal/computeMetadata/v1/instance/attributes/pretrained_r_nets_path" -H "Metadata-Flavor: Google")
25 |
26 | # Note: during development, you could sync code from your local machine to a
27 | # cloud bucket, sync it here from the bucket to the VM, and pip install it.
28 |
29 | if [[ "${PRETRAINED_R_NETS_PATH}" ]]
30 | then
31 | gsutil -m cp -r "${PRETRAINED_R_NETS_PATH}" "/home/${USER}"
32 | fi
33 |
34 | chmod a+x "${EPISODIC_CURIOSITY_DIR}/scripts/vm_start.sh"
35 |
36 | # Launch vm_start under the given user.
37 | su - "${USER}" -c "${EPISODIC_CURIOSITY_DIR}/scripts/vm_start.sh"
38 |
--------------------------------------------------------------------------------
/scripts/vm_start.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 |
3 | # Copyright 2019 Google LLC.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | # This script runs when the VM instance starts. It launches episodic-curiosity
18 | # training via launcher_script.py, using the per-instance metadata (preferably
19 | # set by launch_cloud_vms.py).
20 |
21 | set -x
22 |
23 | whoami
24 |
25 | cd
26 |
27 | VM_ID=$(curl "http://metadata.google.internal/computeMetadata/v1/instance/attributes/vm_id" -H "Metadata-Flavor: Google")
28 | METHOD=$(curl "http://metadata.google.internal/computeMetadata/v1/instance/attributes/method" -H "Metadata-Flavor: Google")
29 | SCENARIO=$(curl "http://metadata.google.internal/computeMetadata/v1/instance/attributes/scenario" -H "Metadata-Flavor: Google")
30 | RUN_NUMBER=$(curl "http://metadata.google.internal/computeMetadata/v1/instance/attributes/run_number" -H "Metadata-Flavor: Google")
31 | PRETRAINED_R_NETS_PATH=$(curl "http://metadata.google.internal/computeMetadata/v1/instance/attributes/pretrained_r_nets_path" -H "Metadata-Flavor: Google")
32 | SYNC_LOGS_TO_PATH=$(curl "http://metadata.google.internal/computeMetadata/v1/instance/attributes/sync_logs_to_path" -H "Metadata-Flavor: Google")
33 | HOMEDIR=$(pwd)
34 |
35 | WORKDIR="${HOMEDIR}/${VM_ID}"
36 |
37 | EPISODIC_CURIOSITY_DIR="${HOMEDIR}/episodic-curiosity"
38 |
39 | mkdir -p "${WORKDIR}"
40 |
41 | # This must happen before we activate the virtual env; otherwise, gsutil does not work.
42 | python "${EPISODIC_CURIOSITY_DIR}/scripts/gs_sync.py" --workdir="${WORKDIR}" --sync_to_cloud_bucket="${SYNC_LOGS_TO_PATH}" &
43 |
44 | source episodic_curiosity_env/bin/activate
45 |
46 | cd "${EPISODIC_CURIOSITY_DIR}"
47 |
48 | if [[ "${PRETRAINED_R_NETS_PATH}" ]]
49 | then
50 | BASENAME=$(basename "${PRETRAINED_R_NETS_PATH}")
51 | R_NETWORKS_PATH_FLAG="--r_networks_path=${HOMEDIR}/${BASENAME}"
52 | else
53 | R_NETWORKS_PATH_FLAG=""
54 | fi
55 |
56 | python "${EPISODIC_CURIOSITY_DIR}/scripts/launcher_script.py" --workdir="${WORKDIR}" --method="${METHOD}" --scenario="${SCENARIO}" --run_number="${RUN_NUMBER}" ${R_NETWORKS_PATH_FLAG}
57 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Setup script for installing episodic_curiosity as a pip module."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 | import setuptools
21 |
22 | VERSION = '1.0.0'
23 |
24 | install_requires = [
25 | # See installation instructions:
26 | # https://github.com/deepmind/lab/tree/master/python/pip_package
27 | 'DeepMind-Lab',
28 | 'absl-py>=0.7.0',
29 | 'dill>=0.2.9',
30 | 'enum>=0.4.7',
31 | # Won't be needed anymore when moving to python3.
32 | 'futures>=3.2.0',
33 | 'gin-config>=0.1.2',
34 | 'gym>=0.10.9',
35 | 'numpy>=1.16.0',
36 | 'opencv-python>=4.0.0.21',
37 | 'pypng>=0.0.19',
38 | 'pytype>=2019.1.18',
39 | 'scikit-image>=0.14.2',
40 | 'six>=1.12.0',
41 | 'tensorflow-gpu>=1.12.0',
42 | ]
43 |
44 | description = ('Episodic Curiosity. This is the code that allows reproducing '
45 | 'the results in the scientific paper '
46 | 'https://arxiv.org/pdf/1810.02274.pdf.')
47 |
48 |
49 | setuptools.setup(
50 | name='episodic-curiosity',
51 | version=VERSION,
52 | packages=setuptools.find_packages(),
53 | description=description,
54 | long_description=description,
55 | url='https://github.com/google-research/episodic-curiosity',
56 | author='Google LLC',
57 | author_email='opensource@google.com',
58 | install_requires=install_requires,
59 | extras_require={
60 | 'video': ['sk-video'],
61 | 'mujoco': [
62 | # For installation of dm_control see:
63 | # https://github.com/deepmind/dm_control#requirements-and-installation.
64 | 'dm_control',
65 | 'functools32',
66 | 'scikit-image',
67 | ],
68 | },
69 | license='Apache 2.0',
70 | keywords='reinforcement-learning curiosity exploration deepmind-lab',
71 | )
72 |
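# Usage note (illustrative, not from the original file): the optional extras
# declared above can be installed alongside the package, e.g.:
#   pip install -e .[video,mujoco]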
--------------------------------------------------------------------------------
/third_party/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 Google LLC.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from __future__ import absolute_import
17 | from __future__ import division
18 | from __future__ import print_function
19 |
20 |
--------------------------------------------------------------------------------
/third_party/baselines/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License
2 |
3 | Copyright (c) 2017 OpenAI (http://openai.com)
4 | Copyright (c) 2018 Google LLC (http://google.com)
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in
14 | all copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22 | THE SOFTWARE.
23 |
--------------------------------------------------------------------------------
/third_party/baselines/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | from __future__ import print_function
5 |
--------------------------------------------------------------------------------
/third_party/baselines/a2c/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 |
--------------------------------------------------------------------------------
/third_party/baselines/a2c/utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | import os
3 | import gym
4 | import numpy as np
5 | import tensorflow as tf
6 | from gym import spaces
7 | from collections import deque
8 |
9 | def sample(logits):
10 | noise = tf.random_uniform(tf.shape(logits))
11 | return tf.argmax(logits - tf.log(-tf.log(noise)), 1)
12 |
13 | def cat_entropy(logits):
14 | a0 = logits - tf.reduce_max(logits, 1, keep_dims=True)
15 | ea0 = tf.exp(a0)
16 | z0 = tf.reduce_sum(ea0, 1, keep_dims=True)
17 | p0 = ea0 / z0
18 | return tf.reduce_sum(p0 * (tf.log(z0) - a0), 1)
19 |
20 | def cat_entropy_softmax(p0):
21 | return - tf.reduce_sum(p0 * tf.log(p0 + 1e-6), axis = 1)
22 |
23 | def mse(pred, target):
24 | return tf.square(pred-target)/2.
25 |
26 | def ortho_init(scale=1.0):
27 | def _ortho_init(shape, dtype, partition_info=None):
28 | #lasagne ortho init for tf
29 | shape = tuple(shape)
30 | if len(shape) == 2:
31 | flat_shape = shape
32 | elif len(shape) == 4: # assumes NHWC
33 | flat_shape = (np.prod(shape[:-1]), shape[-1])
34 | else:
35 | raise NotImplementedError
36 | a = np.random.normal(0.0, 1.0, flat_shape)
37 | u, _, v = np.linalg.svd(a, full_matrices=False)
38 | q = u if u.shape == flat_shape else v # pick the one with the correct shape
39 | q = q.reshape(shape)
40 | return (scale * q[:shape[0], :shape[1]]).astype(np.float32)
41 | return _ortho_init
42 |
43 | def conv(x, scope, nf, rf, stride, pad='VALID', init_scale=1.0, data_format='NHWC', one_dim_bias=False):
44 | if data_format == 'NHWC':
45 | channel_ax = 3
46 | strides = [1, stride, stride, 1]
47 | bshape = [1, 1, 1, nf]
48 | elif data_format == 'NCHW':
49 | channel_ax = 1
50 | strides = [1, 1, stride, stride]
51 | bshape = [1, nf, 1, 1]
52 | else:
53 | raise NotImplementedError
54 | bias_var_shape = [nf] if one_dim_bias else [1, nf, 1, 1]
55 | nin = x.get_shape()[channel_ax].value
56 | wshape = [rf, rf, nin, nf]
57 | with tf.variable_scope(scope):
58 | w = tf.get_variable("w", wshape, initializer=ortho_init(init_scale))
59 | b = tf.get_variable("b", bias_var_shape, initializer=tf.constant_initializer(0.0))
60 | if not one_dim_bias and data_format == 'NHWC':
61 | b = tf.reshape(b, bshape)
62 | return b + tf.nn.conv2d(x, w, strides=strides, padding=pad, data_format=data_format)
63 |
64 | def fc(x, scope, nh, init_scale=1.0, init_bias=0.0):
65 | with tf.variable_scope(scope):
66 | nin = x.get_shape()[1].value
67 | w = tf.get_variable("w", [nin, nh], initializer=ortho_init(init_scale))
68 | b = tf.get_variable("b", [nh], initializer=tf.constant_initializer(init_bias))
69 | return tf.matmul(x, w)+b
70 |
71 | def batch_to_seq(h, nbatch, nsteps, flat=False):
72 | if flat:
73 | h = tf.reshape(h, [nbatch, nsteps])
74 | else:
75 | h = tf.reshape(h, [nbatch, nsteps, -1])
76 | return [tf.squeeze(v, [1]) for v in tf.split(axis=1, num_or_size_splits=nsteps, value=h)]
77 |
78 | def seq_to_batch(h, flat = False):
79 | shape = h[0].get_shape().as_list()
80 | if not flat:
81 | assert(len(shape) > 1)
82 | nh = h[0].get_shape()[-1].value
83 | return tf.reshape(tf.concat(axis=1, values=h), [-1, nh])
84 | else:
85 | return tf.reshape(tf.stack(values=h, axis=1), [-1])
86 |
87 | def lstm(xs, ms, s, scope, nh, init_scale=1.0):
88 | nbatch, nin = [v.value for v in xs[0].get_shape()]
89 | nsteps = len(xs)
90 | with tf.variable_scope(scope):
91 | wx = tf.get_variable("wx", [nin, nh*4], initializer=ortho_init(init_scale))
92 | wh = tf.get_variable("wh", [nh, nh*4], initializer=ortho_init(init_scale))
93 | b = tf.get_variable("b", [nh*4], initializer=tf.constant_initializer(0.0))
94 |
95 | c, h = tf.split(axis=1, num_or_size_splits=2, value=s)
96 | for idx, (x, m) in enumerate(zip(xs, ms)):
97 | c = c*(1-m)
98 | h = h*(1-m)
99 | z = tf.matmul(x, wx) + tf.matmul(h, wh) + b
100 | i, f, o, u = tf.split(axis=1, num_or_size_splits=4, value=z)
101 | i = tf.nn.sigmoid(i)
102 | f = tf.nn.sigmoid(f)
103 | o = tf.nn.sigmoid(o)
104 | u = tf.tanh(u)
105 | c = f*c + i*u
106 | h = o*tf.tanh(c)
107 | xs[idx] = h
108 | s = tf.concat(axis=1, values=[c, h])
109 | return xs, s
110 |
111 | def _ln(x, g, b, e=1e-5, axes=[1]):
112 | u, s = tf.nn.moments(x, axes=axes, keep_dims=True)
113 | x = (x-u)/tf.sqrt(s+e)
114 | x = x*g+b
115 | return x
116 |
117 | def lnlstm(xs, ms, s, scope, nh, init_scale=1.0):
118 | nbatch, nin = [v.value for v in xs[0].get_shape()]
119 | nsteps = len(xs)
120 | with tf.variable_scope(scope):
121 | wx = tf.get_variable("wx", [nin, nh*4], initializer=ortho_init(init_scale))
122 | gx = tf.get_variable("gx", [nh*4], initializer=tf.constant_initializer(1.0))
123 | bx = tf.get_variable("bx", [nh*4], initializer=tf.constant_initializer(0.0))
124 |
125 | wh = tf.get_variable("wh", [nh, nh*4], initializer=ortho_init(init_scale))
126 | gh = tf.get_variable("gh", [nh*4], initializer=tf.constant_initializer(1.0))
127 | bh = tf.get_variable("bh", [nh*4], initializer=tf.constant_initializer(0.0))
128 |
129 | b = tf.get_variable("b", [nh*4], initializer=tf.constant_initializer(0.0))
130 |
131 | gc = tf.get_variable("gc", [nh], initializer=tf.constant_initializer(1.0))
132 | bc = tf.get_variable("bc", [nh], initializer=tf.constant_initializer(0.0))
133 |
134 | c, h = tf.split(axis=1, num_or_size_splits=2, value=s)
135 | for idx, (x, m) in enumerate(zip(xs, ms)):
136 | c = c*(1-m)
137 | h = h*(1-m)
138 | z = _ln(tf.matmul(x, wx), gx, bx) + _ln(tf.matmul(h, wh), gh, bh) + b
139 | i, f, o, u = tf.split(axis=1, num_or_size_splits=4, value=z)
140 | i = tf.nn.sigmoid(i)
141 | f = tf.nn.sigmoid(f)
142 | o = tf.nn.sigmoid(o)
143 | u = tf.tanh(u)
144 | c = f*c + i*u
145 | h = o*tf.tanh(_ln(c, gc, bc))
146 | xs[idx] = h
147 | s = tf.concat(axis=1, values=[c, h])
148 | return xs, s
149 |
150 | def conv_to_fc(x):
151 | nh = np.prod([v.value for v in x.get_shape()[1:]])
152 | x = tf.reshape(x, [-1, nh])
153 | return x
154 |
155 | def discount_with_dones(rewards, dones, gamma):
156 | discounted = []
157 | r = 0
158 | for reward, done in zip(rewards[::-1], dones[::-1]):
159 | r = reward + gamma*r*(1.-done) # fixed off by one bug
160 | discounted.append(r)
161 | return discounted[::-1]
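# Worked example (illustration only): an episode that terminates on the last of
# three unit-reward steps, with gamma = 0.99:
#   discount_with_dones([1., 1., 1.], [0, 0, 1], 0.99)
#   -> [1 + 0.99*(1 + 0.99), 1 + 0.99, 1.0] == [2.9701, 1.99, 1.0] (up to float rounding)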
162 |
163 | def find_trainable_variables(key):
164 | with tf.variable_scope(key):
165 | return tf.trainable_variables()
166 |
167 | def make_path(f):
168 | return os.makedirs(f, exist_ok=True)
169 |
170 | def constant(p):
171 | return 1
172 |
173 | def linear(p):
174 | return 1-p
175 |
176 | def middle_drop(p):
177 | eps = 0.75
178 | if 1-p < eps:
179 | return eps*0.1
180 | return 1-p
--------------------------------------------------------------------------------
/third_party/baselines/common/atari_wrappers.py:
--------------------------------------------------------------------------------
27 | assert noops > 0
28 | obs = None
29 | for _ in range(noops):
30 | obs, _, done, _ = self.env.step(self.noop_action)
31 | if done:
32 | obs = self.env.reset(**kwargs)
33 | return obs
34 |
35 | def step(self, ac):
36 | return self.env.step(ac)
37 |
38 | class FireResetEnv(gym.Wrapper):
39 | def __init__(self, env):
40 | """Take action on reset for environments that are fixed until firing."""
41 | gym.Wrapper.__init__(self, env)
42 | assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
43 | assert len(env.unwrapped.get_action_meanings()) >= 3
44 |
45 | def reset(self, **kwargs):
46 | self.env.reset(**kwargs)
47 | obs, _, done, _ = self.env.step(1)
48 | if done:
49 | self.env.reset(**kwargs)
50 | obs, _, done, _ = self.env.step(2)
51 | if done:
52 | self.env.reset(**kwargs)
53 | return obs
54 |
55 | def step(self, ac):
56 | return self.env.step(ac)
57 |
58 | class EpisodicLifeEnv(gym.Wrapper):
59 | def __init__(self, env):
60 | """Make end-of-life == end-of-episode, but only reset on true game over.
61 | Done by DeepMind for the DQN and co. since it helps value estimation.
62 | """
63 | gym.Wrapper.__init__(self, env)
64 | self.lives = 0
65 | self.was_real_done = True
66 |
67 | def step(self, action):
68 | obs, reward, done, info = self.env.step(action)
69 | self.was_real_done = done
70 | # check current lives, make loss of life terminal,
71 | # then update lives to handle bonus lives
72 | lives = self.env.unwrapped.ale.lives()
73 | if lives < self.lives and lives > 0:
74 | # for Qbert sometimes we stay in lives == 0 condtion for a few frames
75 | # so its important to keep lives > 0, so that we only reset once
76 | # the environment advertises done.
77 | done = True
78 | self.lives = lives
79 | return obs, reward, done, info
80 |
81 | def reset(self, **kwargs):
82 | """Reset only when lives are exhausted.
83 | This way all states are still reachable even though lives are episodic,
84 | and the learner need not know about any of this behind-the-scenes.
85 | """
86 | if self.was_real_done:
87 | obs = self.env.reset(**kwargs)
88 | else:
89 | # no-op step to advance from terminal/lost life state
90 | obs, _, _, _ = self.env.step(0)
91 | self.lives = self.env.unwrapped.ale.lives()
92 | return obs
93 |
94 | class MaxAndSkipEnv(gym.Wrapper):
95 | def __init__(self, env, skip=4):
96 | """Return only every `skip`-th frame"""
97 | gym.Wrapper.__init__(self, env)
98 | # most recent raw observations (for max pooling across time steps)
99 | self._obs_buffer = np.zeros((2,)+env.observation_space.shape, dtype=np.uint8)
100 | self._skip = skip
101 |
102 | def step(self, action):
103 | """Repeat action, sum reward, and max over last observations."""
104 | total_reward = 0.0
105 | done = None
106 | for i in range(self._skip):
107 | obs, reward, done, info = self.env.step(action)
108 | if i == self._skip - 2: self._obs_buffer[0] = obs
109 | if i == self._skip - 1: self._obs_buffer[1] = obs
110 | total_reward += reward
111 | if done:
112 | break
113 | # Note that the observation on the done=True frame
114 | # doesn't matter
115 | max_frame = self._obs_buffer.max(axis=0)
116 |
117 | return max_frame, total_reward, done, info
118 |
119 | def reset(self, **kwargs):
120 | return self.env.reset(**kwargs)
121 |
122 | class ClipRewardEnv(gym.RewardWrapper):
123 | def __init__(self, env):
124 | gym.RewardWrapper.__init__(self, env)
125 |
126 | def reward(self, reward):
127 | """Bin reward to {+1, 0, -1} by its sign."""
128 | return np.sign(reward)
129 |
130 | class WarpFrame(gym.ObservationWrapper):
131 | def __init__(self, env):
132 | """Warp frames to 84x84 as done in the Nature paper and later work."""
133 | gym.ObservationWrapper.__init__(self, env)
134 | self.width = 84
135 | self.height = 84
136 | self.observation_space = spaces.Box(low=0, high=255,
137 | shape=(self.height, self.width, 1), dtype=np.uint8)
138 |
139 | def observation(self, frame):
140 | frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
141 | frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA)
142 | return frame[:, :, None]
143 |
144 | class FrameStack(gym.Wrapper):
145 | def __init__(self, env, k):
146 | """Stack k last frames.
147 |
148 | Returns lazy array, which is much more memory efficient.
149 |
150 | See Also
151 | --------
152 | baselines.common.atari_wrappers.LazyFrames
153 | """
154 | gym.Wrapper.__init__(self, env)
155 | self.k = k
156 | self.frames = deque([], maxlen=k)
157 | shp = env.observation_space.shape
158 | self.observation_space = spaces.Box(low=0, high=255, shape=(shp[0], shp[1], shp[2] * k), dtype=np.uint8)
159 |
160 | def reset(self):
161 | ob = self.env.reset()
162 | for _ in range(self.k):
163 | self.frames.append(ob)
164 | return self._get_ob()
165 |
166 | def step(self, action):
167 | ob, reward, done, info = self.env.step(action)
168 | self.frames.append(ob)
169 | return self._get_ob(), reward, done, info
170 |
171 | def _get_ob(self):
172 | assert len(self.frames) == self.k
173 | return LazyFrames(list(self.frames))
174 |
175 | class ScaledFloatFrame(gym.ObservationWrapper):
176 | def __init__(self, env):
177 | gym.ObservationWrapper.__init__(self, env)
178 |
179 | def observation(self, observation):
180 | # careful! This undoes the memory optimization, use
181 | # with smaller replay buffers only.
182 | return np.array(observation).astype(np.float32) / 255.0
183 |
184 | class LazyFrames(object):
185 | def __init__(self, frames):
186 | """This object ensures that common frames between the observations are only stored once.
187 | It exists purely to optimize memory usage which can be huge for DQN's 1M frames replay
188 | buffers.
189 |
190 | This object should only be converted to numpy array before being passed to the model.
191 |
192 | You'd not believe how complex the previous solution was."""
193 | self._frames = frames
194 | self._out = None
195 |
196 | def _force(self):
197 | if self._out is None:
198 | self._out = np.concatenate(self._frames, axis=2)
199 | self._frames = None
200 | return self._out
201 |
202 | def __array__(self, dtype=None):
203 | out = self._force()
204 | if dtype is not None:
205 | out = out.astype(dtype)
206 | return out
207 |
208 | def __len__(self):
209 | return len(self._force())
210 |
211 | def __getitem__(self, i):
212 | return self._force()[i]
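# Illustrative usage (not part of the original file): frames stay lazy in the
# buffer and are only concatenated (along the channel axis) when converted to a
# numpy array.
_example_frames = LazyFrames([np.zeros((84, 84, 1), dtype=np.uint8)] * 4)
assert np.array(_example_frames).shape == (84, 84, 4)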
213 |
214 | def make_atari(env_id):
215 | env = gym.make(env_id)
216 | assert 'NoFrameskip' in env.spec.id
217 | env = NoopResetEnv(env, noop_max=30)
218 | env = MaxAndSkipEnv(env, skip=4)
219 | return env
220 |
221 | def wrap_deepmind(env, episode_life=True, clip_rewards=True, frame_stack=False, scale=False):
222 | """Configure environment for DeepMind-style Atari.
223 | """
224 | if episode_life:
225 | env = EpisodicLifeEnv(env)
226 | if 'FIRE' in env.unwrapped.get_action_meanings():
227 | env = FireResetEnv(env)
228 | env = WarpFrame(env)
229 | if scale:
230 | env = ScaledFloatFrame(env)
231 | if clip_rewards:
232 | env = ClipRewardEnv(env)
233 | if frame_stack:
234 | env = FrameStack(env, 4)
235 | return env
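# Hedged usage sketch (assumes gym[atari] provides 'BreakoutNoFrameskip-v4'):
#   env = make_atari('BreakoutNoFrameskip-v4')
#   env = wrap_deepmind(env, frame_stack=True)
#   obs = env.reset()   # LazyFrames; np.array(obs).shape == (84, 84, 4)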
236 |
237 |
--------------------------------------------------------------------------------
/third_party/baselines/common/cmd_util.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | """Helpers for scripts like run_atari.py.
3 | """
4 |
5 | from __future__ import absolute_import
6 | from __future__ import division
7 | from __future__ import print_function
8 |
9 | import os
10 | from third_party.baselines import logger
11 | from third_party.baselines.bench import Monitor
12 | from third_party.baselines.common import set_global_seeds
13 | from third_party.baselines.common.atari_wrappers import make_atari
14 | from third_party.baselines.common.atari_wrappers import wrap_deepmind
15 | from third_party.baselines.common.vec_env.subproc_vec_env import SubprocVecEnv
16 |
17 |
18 | def make_atari_env(env_id, num_env, seed, wrapper_kwargs=None, start_index=0,
19 | use_monitor=True):
20 | """Create a wrapped, monitored SubprocVecEnv for Atari.
21 | """
22 | if wrapper_kwargs is None: wrapper_kwargs = {}
23 | def make_env(rank): # pylint: disable=C0111
24 | def _thunk():
25 | env = make_atari(env_id)
26 | env.seed(seed + rank)
27 | if use_monitor:
28 | env = Monitor(env, logger.get_dir() and os.path.join(logger.get_dir(),
29 | str(rank)))
30 | return wrap_deepmind(env, **wrapper_kwargs)
31 | return _thunk
32 | set_global_seeds(seed)
33 | return SubprocVecEnv([make_env(i + start_index) for i in range(num_env)])
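# Hedged usage sketch (same Atari assumption as in atari_wrappers.py): eight
# monitored, DeepMind-wrapped environments stepped in lockstep:
#   venv = make_atari_env('BreakoutNoFrameskip-v4', num_env=8, seed=0)
#   obs = venv.reset()   # shape (8, 84, 84, 1) with the default wrapper_kwargs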
34 |
--------------------------------------------------------------------------------
/third_party/baselines/common/input.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | import tensorflow as tf
3 | from gym.spaces import Discrete, Box
4 |
5 | def observation_input(ob_space, batch_size=None, name='Ob'):
6 | '''
7 | Build observation input with encoding depending on the
8 | observation space type
9 | Params:
10 |
11 | ob_space: observation space (should be one of gym.spaces)
12 | batch_size: batch size for input (default is None, so that resulting input placeholder can take tensors with any batch size)
13 | name: tensorflow variable name for input placeholder
14 |
15 | returns: tuple (input_placeholder, processed_input_tensor)
16 | '''
17 | if isinstance(ob_space, Discrete):
18 | input_x = tf.placeholder(shape=(batch_size,), dtype=tf.int32, name=name)
19 | processed_x = tf.to_float(tf.one_hot(input_x, ob_space.n))
20 | return input_x, processed_x
21 |
22 | elif isinstance(ob_space, Box):
23 | input_shape = (batch_size,) + ob_space.shape
24 | input_x = tf.placeholder(shape=input_shape, dtype=ob_space.dtype, name=name)
25 | processed_x = tf.to_float(input_x)
26 | return input_x, processed_x
27 |
28 | else:
29 | raise NotImplementedError
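# Illustrative example (TF1-style placeholders; not part of the original file):
#   box = Box(low=0, high=255, shape=(84, 84, 4), dtype='uint8')
#   x_ph, x_float = observation_input(box, batch_size=None)
#   x_ph is a uint8 placeholder of shape (None, 84, 84, 4); x_float holds the
#   same values cast to float32.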
30 |
31 |
32 |
--------------------------------------------------------------------------------
/third_party/baselines/common/math_util.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | import numpy as np
3 | import scipy.signal
4 |
5 |
6 | def discount(x, gamma):
7 | """
8 | computes discounted sums along 0th dimension of x.
9 |
10 | inputs
11 | ------
12 | x: ndarray
13 | gamma: float
14 |
15 | outputs
16 | -------
17 | y: ndarray with same shape as x, satisfying
18 |
19 | y[t] = x[t] + gamma*x[t+1] + gamma^2*x[t+2] + ... + gamma^k x[t+k],
20 | where k = len(x) - t - 1
21 |
22 | """
23 | assert x.ndim >= 1
24 | return scipy.signal.lfilter([1],[1,-gamma],x[::-1], axis=0)[::-1]
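# Worked example (illustration only), with gamma = 0.9:
#   discount(np.array([1., 2., 3.]), 0.9)
#   -> [1 + 0.9*(2 + 0.9*3), 2 + 0.9*3, 3] == [5.23, 4.7, 3.0] (up to float rounding)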
25 |
26 | def explained_variance(ypred,y):
27 | """
28 | Computes fraction of variance that ypred explains about y.
29 | Returns 1 - Var[y-ypred] / Var[y]
30 |
31 | interpretation:
32 | ev=0 => might as well have predicted zero
33 | ev=1 => perfect prediction
34 | ev<0 => worse than just predicting zero
35 |
36 | """
37 | assert y.ndim == 1 and ypred.ndim == 1
38 | vary = np.var(y)
39 | return np.nan if vary==0 else 1 - np.var(y-ypred)/vary
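# Worked example (illustration only): a perfect prediction gives ev = 1, while
# predicting a constant zero for a zero-mean target gives ev = 0.
_example_y = np.array([1.0, -1.0, 2.0, -2.0])
assert explained_variance(_example_y, _example_y) == 1.0
assert explained_variance(np.zeros_like(_example_y), _example_y) == 0.0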
40 |
41 | def explained_variance_2d(ypred, y):
42 | assert y.ndim == 2 and ypred.ndim == 2
43 | vary = np.var(y, axis=0)
44 | out = 1 - np.var(y-ypred)/vary
45 | out[vary < 1e-10] = 0
46 | return out
47 |
48 | def ncc(ypred, y):
49 | return np.corrcoef(ypred, y)[1,0]
50 |
51 | def flatten_arrays(arrs):
52 | return np.concatenate([arr.flat for arr in arrs])
53 |
54 | def unflatten_vector(vec, shapes):
55 | i=0
56 | arrs = []
57 | for shape in shapes:
58 | size = np.prod(shape)
59 | arr = vec[i:i+size].reshape(shape)
60 | arrs.append(arr)
61 | i += size
62 | return arrs
63 |
64 | def discount_with_boundaries(X, New, gamma):
65 | """
66 | X: 2d array of floats, time x features
67 | New: 2d array of bools, indicating when a new episode has started
68 | """
69 | Y = np.zeros_like(X)
70 | T = X.shape[0]
71 | Y[T-1] = X[T-1]
72 | for t in range(T-2, -1, -1):
73 | Y[t] = X[t] + gamma * Y[t+1] * (1 - New[t+1])
74 | return Y
75 |
76 | def test_discount_with_boundaries():
77 | gamma=0.9
78 | x = np.array([1.0, 2.0, 3.0, 4.0], 'float32')
79 | starts = [1.0, 0.0, 0.0, 1.0]
80 | y = discount_with_boundaries(x, starts, gamma)
81 | assert np.allclose(y, [
82 | 1 + gamma * 2 + gamma**2 * 3,
83 | 2 + gamma * 3,
84 | 3,
85 | 4
86 | ])
87 |
--------------------------------------------------------------------------------
/third_party/baselines/common/misc_util.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | import gym
3 | import numpy as np
4 | import os
5 | import pickle
6 | import random
7 | import tempfile
8 | import zipfile
9 |
10 |
11 | def zipsame(*seqs):
12 | L = len(seqs[0])
13 | assert all(len(seq) == L for seq in seqs[1:])
14 | return zip(*seqs)
15 |
16 |
17 | def unpack(seq, sizes):
18 | """
19 | Unpack 'seq' into a sequence of lists, with lengths specified by 'sizes'.
20 | None = just one bare element, not a list
21 |
22 | Example:
23 | unpack([1,2,3,4,5,6], [3,None,2]) -> ([1,2,3], 4, [5,6])
24 | """
25 | seq = list(seq)
26 | it = iter(seq)
27 | assert sum(1 if s is None else s for s in sizes) == len(seq), "Trying to unpack %s into %s" % (seq, sizes)
28 | for size in sizes:
29 | if size is None:
30 | yield it.__next__()
31 | else:
32 | li = []
33 | for _ in range(size):
34 | li.append(it.__next__())
35 | yield li
36 |
37 |
38 | class EzPickle(object):
39 | """Objects that are pickled and unpickled via their constructor
40 | arguments.
41 |
42 | Example usage:
43 |
44 | class Dog(Animal, EzPickle):
45 | def __init__(self, furcolor, tailkind="bushy"):
46 | Animal.__init__()
47 | EzPickle.__init__(furcolor, tailkind)
48 | ...
49 |
50 | When this object is unpickled, a new Dog will be constructed by passing the provided
51 | furcolor and tailkind into the constructor. However, philosophers are still not sure
52 | whether it is still the same dog.
53 |
54 | This is generally needed only for environments which wrap C/C++ code, such as MuJoCo
55 | and Atari.
56 | """
57 |
58 | def __init__(self, *args, **kwargs):
59 | self._ezpickle_args = args
60 | self._ezpickle_kwargs = kwargs
61 |
62 | def __getstate__(self):
63 | return {"_ezpickle_args": self._ezpickle_args, "_ezpickle_kwargs": self._ezpickle_kwargs}
64 |
65 | def __setstate__(self, d):
66 | out = type(self)(*d["_ezpickle_args"], **d["_ezpickle_kwargs"])
67 | self.__dict__.update(out.__dict__)
68 |
69 |
70 | def set_global_seeds(i):
71 | try:
72 | import tensorflow as tf
73 | except ImportError:
74 | pass
75 | else:
76 | tf.set_random_seed(i)
77 | np.random.seed(i)
78 | random.seed(i)
79 |
80 |
81 | def pretty_eta(seconds_left):
82 | """Print the number of seconds in human readable format.
83 |
84 | Examples:
85 | 2 days
86 | 2 hours and 37 minutes
87 | less than a minute
88 |
89 | Parameters
90 | ----------
91 | seconds_left: int
92 | Number of seconds to be converted to the ETA
93 | Returns
94 | -------
95 | eta: str
96 | String representing the pretty ETA.
97 | """
98 | minutes_left = seconds_left // 60
99 | seconds_left %= 60
100 | hours_left = minutes_left // 60
101 | minutes_left %= 60
102 | days_left = hours_left // 24
103 | hours_left %= 24
104 |
105 | def helper(cnt, name):
106 | return "{} {}{}".format(str(cnt), name, ('s' if cnt > 1 else ''))
107 |
108 | if days_left > 0:
109 | msg = helper(days_left, 'day')
110 | if hours_left > 0:
111 | msg += ' and ' + helper(hours_left, 'hour')
112 | return msg
113 | if hours_left > 0:
114 | msg = helper(hours_left, 'hour')
115 | if minutes_left > 0:
116 | msg += ' and ' + helper(minutes_left, 'minute')
117 | return msg
118 | if minutes_left > 0:
119 | return helper(minutes_left, 'minute')
120 | return 'less than a minute'
121 |
122 |
123 | class RunningAvg(object):
124 | def __init__(self, gamma, init_value=None):
125 | """Keep a running estimate of a quantity. This is a bit like mean
126 | but more sensitive to recent changes.
127 |
128 | Parameters
129 | ----------
130 | gamma: float
131 | Must be between 0 and 1, where 0 is the most sensitive to recent
132 | changes.
133 | init_value: float or None
134 | Initial value of the estimate. If None, it will be set on the first update.
135 | """
136 | self._value = init_value
137 | self._gamma = gamma
138 |
139 | def update(self, new_val):
140 | """Update the estimate.
141 |
142 | Parameters
143 | ----------
144 | new_val: float
145 | new observed value of the estimated quantity.
146 | """
147 | if self._value is None:
148 | self._value = new_val
149 | else:
150 | self._value = self._gamma * self._value + (1.0 - self._gamma) * new_val
151 |
152 | def __float__(self):
153 | """Get the current estimate"""
154 | return self._value
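# Illustrative example: with gamma = 0.9, each update keeps 90% of the previous
# estimate and mixes in 10% of the new value:
#   avg = RunningAvg(gamma=0.9, init_value=0.0)
#   avg.update(10.0)
#   float(avg)   # ~= 1.0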
155 |
156 | def boolean_flag(parser, name, default=False, help=None):
157 | """Add a boolean flag to argparse parser.
158 |
159 | Parameters
160 | ----------
161 | parser: argparse.Parser
162 | parser to add the flag to
163 | name: str
164 | --<name> will enable the flag, while --no-<name> will disable it
165 | default: bool or None
166 | default value of the flag
167 | help: str
168 | help string for the flag
169 | """
170 | dest = name.replace('-', '_')
171 | parser.add_argument("--" + name, action="store_true", default=default, dest=dest, help=help)
172 | parser.add_argument("--no-" + name, action="store_false", dest=dest)
173 |
174 |
175 | def get_wrapper_by_name(env, classname):
176 | """Given an a gym environment possibly wrapped multiple times, returns a wrapper
177 | of class named classname or raises ValueError if no such wrapper was applied
178 |
179 | Parameters
180 | ----------
181 | env: gym.Env or gym.Wrapper
182 | gym environment
183 | classname: str
184 | name of the wrapper
185 |
186 | Returns
187 | -------
188 | wrapper: gym.Wrapper
189 | wrapper named classname
190 | """
191 | currentenv = env
192 | while True:
193 | if classname == currentenv.class_name():
194 | return currentenv
195 | elif isinstance(currentenv, gym.Wrapper):
196 | currentenv = currentenv.env
197 | else:
198 | raise ValueError("Couldn't find wrapper named %s" % classname)
199 |
200 |
201 | def relatively_safe_pickle_dump(obj, path, compression=False):
202 | """This is just like regular pickle dump, except from the fact that failure cases are
203 | different:
204 |
205 | - It's never possible that we end up with a pickle in corrupted state.
206 | - If there was a different file at the path, that file will remain unchanged in the
207 | event of failure (provided that the filesystem rename is atomic).
208 | - It is sometimes possible that we end up with a useless temp file which needs to be
209 | deleted manually (it will be removed automatically on the next function call).
210 |
211 | The intended use case is periodic checkpoints of experiment state, such that we never
212 | corrupt previous checkpoints if the current one fails.
213 |
214 | Parameters
215 | ----------
216 | obj: object
217 | object to pickle
218 | path: str
219 | path to the output file
220 | compression: bool
221 | if true pickle will be compressed
222 | """
223 | temp_storage = path + ".relatively_safe"
224 | if compression:
225 | # Using gzip here would be simpler, but the size is limited to 2GB
226 | with tempfile.NamedTemporaryFile() as uncompressed_file:
227 | pickle.dump(obj, uncompressed_file)
228 | uncompressed_file.file.flush()
229 | with zipfile.ZipFile(temp_storage, "w", compression=zipfile.ZIP_DEFLATED) as myzip:
230 | myzip.write(uncompressed_file.name, "data")
231 | else:
232 | with open(temp_storage, "wb") as f:
233 | pickle.dump(obj, f)
234 | os.rename(temp_storage, path)
235 |
236 |
237 | def pickle_load(path, compression=False):
238 | """Unpickle a possible compressed pickle.
239 |
240 | Parameters
241 | ----------
242 | path: str
243 | path to the file to load
244 | compression: bool
245 | if true assumes that pickle was compressed when created and attempts decompression.
246 |
247 | Returns
248 | -------
249 | obj: object
250 | the unpickled object
251 | """
252 |
253 | if compression:
254 | with zipfile.ZipFile(path, "r", compression=zipfile.ZIP_DEFLATED) as myzip:
255 | with myzip.open("data") as f:
256 | return pickle.load(f)
257 | else:
258 | with open(path, "rb") as f:
259 | return pickle.load(f)
260 |
--------------------------------------------------------------------------------
/third_party/baselines/common/runners.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | import numpy as np
3 | import abc
4 | from abc import abstractmethod
5 |
6 | class AbstractEnvRunner(object):
7 | __metaclass__ = abc.ABCMeta
8 |
9 | def __init__(self, env, model, nsteps):
10 | self.env = env
11 | self.model = model
12 | nenv = env.num_envs
13 | self.batch_ob_shape = (nenv*nsteps,) + env.observation_space.shape
14 | self.obs = np.zeros((nenv,) + env.observation_space.shape, dtype=env.observation_space.dtype.name)
15 | self.obs[:] = env.reset()
16 | self.nsteps = nsteps
17 | self.states = model.initial_state
18 | self.dones = [False for _ in range(nenv)]
19 |
20 | @abstractmethod
21 | def run(self):
22 | raise NotImplementedError
23 |
--------------------------------------------------------------------------------
/third_party/baselines/common/tile_images.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | import numpy as np
3 |
4 | def tile_images(img_nhwc):
5 | """
6 | Tile N images into one big PxQ image.
7 | (P,Q) are chosen to be as close as possible, and if N
8 | is square, then P=Q.
9 |
10 | input: img_nhwc, list or array of images, ndim=4 once turned into array
11 | n = batch index, h = height, w = width, c = channel
12 | returns:
13 | bigim_HWc, ndarray with ndim=3
14 | """
15 | img_nhwc = np.asarray(img_nhwc)
16 | N, h, w, c = img_nhwc.shape
17 | H = int(np.ceil(np.sqrt(N)))
18 | W = int(np.ceil(float(N)/H))
19 | img_nhwc = np.array(list(img_nhwc) + [img_nhwc[0]*0 for _ in range(N, H*W)])
20 | img_HWhwc = img_nhwc.reshape(H, W, h, w, c)
21 | img_HhWwc = img_HWhwc.transpose(0, 2, 1, 3, 4)
22 | img_Hh_Ww_c = img_HhWwc.reshape(H*h, W*w, c)
23 | return img_Hh_Ww_c
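# Worked example (illustration only): five 64x64 RGB frames are padded with one
# blank frame and tiled into a 3x2 grid.
_example_imgs = np.zeros((5, 64, 64, 3), dtype=np.uint8)
assert tile_images(_example_imgs).shape == (3 * 64, 2 * 64, 3)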
24 |
25 |
--------------------------------------------------------------------------------
/third_party/baselines/common/vec_env/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | from __future__ import print_function
5 |
6 | from abc import ABCMeta, abstractmethod
7 | from third_party.baselines import logger
8 |
9 | class AlreadySteppingError(Exception):
10 | """
11 | Raised when an asynchronous step is running while
12 | step_async() is called again.
13 | """
14 | def __init__(self):
15 | msg = 'already running an async step'
16 | Exception.__init__(self, msg)
17 |
18 | class NotSteppingError(Exception):
19 | """
20 | Raised when an asynchronous step is not running but
21 | step_wait() is called.
22 | """
23 | def __init__(self):
24 | msg = 'not running an async step'
25 | Exception.__init__(self, msg)
26 |
27 | class VecEnv(object):
28 | """
29 | An abstract asynchronous, vectorized environment.
30 | """
31 | __metaclass__ = ABCMeta
32 |
33 | def __init__(self, num_envs, observation_space, action_space):
34 | self.num_envs = num_envs
35 | self.observation_space = observation_space
36 | self.action_space = action_space
37 |
38 | @abstractmethod
39 | def reset(self):
40 | """
41 | Reset all the environments and return an array of
42 | observations, or a tuple of observation arrays.
43 |
44 | If step_async is still doing work, that work will
45 | be cancelled and step_wait() should not be called
46 | until step_async() is invoked again.
47 | """
48 | pass
49 |
50 | @abstractmethod
51 | def step_async(self, actions):
52 | """
53 | Tell all the environments to start taking a step
54 | with the given actions.
55 | Call step_wait() to get the results of the step.
56 |
57 | You should not call this if a step_async run is
58 | already pending.
59 | """
60 | pass
61 |
62 | @abstractmethod
63 | def step_wait(self):
64 | """
65 | Wait for the step taken with step_async().
66 |
67 | Returns (obs, rews, dones, infos):
68 | - obs: an array of observations, or a tuple of
69 | arrays of observations.
70 | - rews: an array of rewards
71 | - dones: an array of "episode done" booleans
72 | - infos: a sequence of info objects
73 | """
74 | pass
75 |
76 | @abstractmethod
77 | def close(self):
78 | """
79 | Clean up the environments' resources.
80 | """
81 | pass
82 |
83 | def step(self, actions):
84 | self.step_async(actions)
85 | return self.step_wait()
86 |
87 | def render(self, mode='human'):
88 | logger.warn('Render not defined for %s'%self)
89 |
90 | @property
91 | def unwrapped(self):
92 | if isinstance(self, VecEnvWrapper):
93 | return self.venv.unwrapped
94 | else:
95 | return self
96 |
97 | class VecEnvWrapper(VecEnv):
98 | def __init__(self, venv, observation_space=None, action_space=None):
99 | self.venv = venv
100 | VecEnv.__init__(self,
101 | num_envs=venv.num_envs,
102 | observation_space=observation_space or venv.observation_space,
103 | action_space=action_space or venv.action_space)
104 |
105 | def step_async(self, actions):
106 | self.venv.step_async(actions)
107 |
108 | @abstractmethod
109 | def reset(self):
110 | pass
111 |
112 | @abstractmethod
113 | def step_wait(self):
114 | pass
115 |
116 | def close(self):
117 | return self.venv.close()
118 |
119 | def render(self):
120 | self.venv.render()
121 |
122 | class CloudpickleWrapper(object):
123 | """
124 | Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle)
125 | """
126 | def __init__(self, x):
127 | self.x = x
128 | def __getstate__(self):
129 | import cloudpickle
130 | return cloudpickle.dumps(self.x)
131 | def __setstate__(self, ob):
132 | import pickle
133 | self.x = pickle.loads(ob)
134 |
--------------------------------------------------------------------------------
/third_party/baselines/common/vec_env/dummy_vec_env.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | from __future__ import print_function
5 |
6 | import numpy as np
7 | from gym import spaces
8 | from collections import OrderedDict
9 | from third_party.baselines.common.vec_env import VecEnv
10 |
11 | class DummyVecEnv(VecEnv):
12 | def __init__(self, env_fns):
13 | self.envs = [fn() for fn in env_fns]
14 | env = self.envs[0]
15 | VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space)
16 | shapes, dtypes = {}, {}
17 | self.keys = []
18 | obs_space = env.observation_space
19 |
20 | if isinstance(obs_space, spaces.Dict):
21 | assert isinstance(obs_space.spaces, OrderedDict)
22 | subspaces = obs_space.spaces
23 | else:
24 | subspaces = {None: obs_space}
25 |
26 | for key, box in subspaces.items():
27 | shapes[key] = box.shape
28 | dtypes[key] = box.dtype
29 | self.keys.append(key)
30 |
31 | self.buf_obs = { k: np.zeros((self.num_envs,) + tuple(shapes[k]), dtype=dtypes[k]) for k in self.keys }
32 | self.buf_dones = np.zeros((self.num_envs,), dtype=np.bool)
33 | self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32)
34 | self.buf_infos = [{} for _ in range(self.num_envs)]
35 | self.actions = None
36 |
37 | def step_async(self, actions):
38 | self.actions = actions
39 |
40 | def step_wait(self):
41 | for e in range(self.num_envs):
42 | obs, self.buf_rews[e], self.buf_dones[e], self.buf_infos[e] = self.envs[e].step(self.actions[e])
43 | if self.buf_dones[e]:
44 | obs = self.envs[e].reset()
45 | self._save_obs(e, obs)
46 | return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones),
47 | self.buf_infos.copy())
48 |
49 | def reset(self):
50 | for e in range(self.num_envs):
51 | obs = self.envs[e].reset()
52 | self._save_obs(e, obs)
53 | return self._obs_from_buf()
54 |
55 | def close(self):
56 | return
57 |
58 | def render(self, mode='human'):
59 | return [e.render(mode=mode) for e in self.envs]
60 |
61 | def _save_obs(self, e, obs):
62 | for k in self.keys:
63 | if k is None:
64 | self.buf_obs[k][e] = obs
65 | else:
66 | self.buf_obs[k][e] = obs[k]
67 |
68 | def _obs_from_buf(self):
69 | if self.keys==[None]:
70 | return self.buf_obs[None]
71 | else:
72 | return self.buf_obs
73 |
--------------------------------------------------------------------------------
/third_party/baselines/common/vec_env/subproc_vec_env.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | from __future__ import print_function
5 |
6 | import numpy as np
7 | from multiprocessing import Process, Pipe
8 | from third_party.baselines.common.vec_env import VecEnv, CloudpickleWrapper
9 | from third_party.baselines.common.tile_images import tile_images
10 |
11 |
12 | def worker(remote, parent_remote, env_fn_wrapper):
13 | parent_remote.close()
14 | env = env_fn_wrapper.x()
15 | while True:
16 | cmd, data = remote.recv()
17 | if cmd == 'step':
18 | ob, reward, done, info = env.step(data)
19 | if done:
20 | ob = env.reset()
21 | remote.send((ob, reward, done, info))
22 | elif cmd == 'reset':
23 | ob = env.reset()
24 | remote.send(ob)
25 | elif cmd == 'render':
26 | remote.send(env.render(mode='rgb_array'))
27 | elif cmd == 'close':
28 | remote.close()
29 | break
30 | elif cmd == 'get_spaces':
31 | remote.send((env.observation_space, env.action_space))
32 | else:
33 | raise NotImplementedError
34 |
35 |
36 | class SubprocVecEnv(VecEnv):
37 | def __init__(self, env_fns, spaces=None):
38 | """
39 | envs: list of gym environments to run in subprocesses
40 | """
41 | self.waiting = False
42 | self.closed = False
43 | nenvs = len(env_fns)
44 | self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)])
45 | self.ps = [Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))
46 | for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)]
47 | for p in self.ps:
48 | p.daemon = True # if the main process crashes, we should not cause things to hang
49 | p.start()
50 | for remote in self.work_remotes:
51 | remote.close()
52 |
53 | self.remotes[0].send(('get_spaces', None))
54 | observation_space, action_space = self.remotes[0].recv()
55 | VecEnv.__init__(self, len(env_fns), observation_space, action_space)
56 |
57 | def step_async(self, actions):
58 | for remote, action in zip(self.remotes, actions):
59 | remote.send(('step', action))
60 | self.waiting = True
61 |
62 | def step_wait(self):
63 | results = [remote.recv() for remote in self.remotes]
64 | self.waiting = False
65 | obs, rews, dones, infos = zip(*results)
66 | return np.stack(obs), np.stack(rews), np.stack(dones), infos
67 |
68 | def reset(self):
69 | for remote in self.remotes:
70 | remote.send(('reset', None))
71 | return np.stack([remote.recv() for remote in self.remotes])
72 |
73 | def reset_task(self):
74 | for remote in self.remotes:
75 | remote.send(('reset_task', None))
76 | return np.stack([remote.recv() for remote in self.remotes])
77 |
78 | def close(self):
79 | if self.closed:
80 | return
81 | if self.waiting:
82 | for remote in self.remotes:
83 | remote.recv()
84 | for remote in self.remotes:
85 | remote.send(('close', None))
86 | for p in self.ps:
87 | p.join()
88 | self.closed = True
89 |
90 | def render(self, mode='human'):
91 | for pipe in self.remotes:
92 | pipe.send(('render', None))
93 | imgs = [pipe.recv() for pipe in self.remotes]
94 | bigimg = tile_images(imgs)
95 | if mode == 'human':
96 | import cv2
97 | cv2.imshow('vecenv', bigimg[:,:,::-1])
98 | cv2.waitKey(1)
99 | elif mode == 'rgb_array':
100 | return bigimg
101 | else:
102 | raise NotImplementedError
103 |
--------------------------------------------------------------------------------
/third_party/baselines/common/vec_env/threaded_vec_env.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | """VecEnv implementation using python threads instead of subprocesses."""
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | import os
9 | import threading
10 | from third_party.baselines.common.vec_env import VecEnv
11 | import numpy as np
12 | from six.moves import queue as Queue # pylint: disable=redefined-builtin
13 |
14 |
15 | def thread_worker(send_q, recv_q, env_fn):
16 | """Similar to SubprocVecEnv.worker(), but for TreadedVecEnv.
17 |
18 | Args:
19 | send_q: Queue which ThreadedVecEnv sends commands to.
20 | recv_q: Queue which ThreadedVecEnv receives commands from.
21 | env_fn: Callable that creates an instance of the environment.
22 | """
23 | env = env_fn()
24 | while True:
25 | cmd, data = send_q.get()
26 | if cmd == 'step':
27 | ob, reward, done, info = env.step(data)
28 | if done:
29 | ob = env.reset()
30 | recv_q.put((ob, reward, done, info))
31 | elif cmd == 'reset':
32 | ob = env.reset()
33 | recv_q.put(ob)
34 | elif cmd == 'render':
35 | recv_q.put(env.render(mode='rgb_array'))
36 | elif cmd == 'close':
37 | break
38 | elif cmd == 'get_spaces':
39 | recv_q.put((env.observation_space, env.action_space))
40 | else:
41 | raise NotImplementedError
42 |
43 |
44 | class ThreadedVecEnv(VecEnv):
45 | """Similar to SubprocVecEnv, but uses python threads instead of subprocs.
46 |
47 | Sub-processes involve forks, and a lot of code (incl. google3's) is not
48 | fork-safe, leading to deadlocks. The drawback of python threads is that the
49 | python code is still executed serially because of the GIL. However, many
50 | environments do the heavy lifting in C++ (where the GIL is released, and
51 | hence execution can happen in parallel), so python threads are not often
52 | limiting.
53 | """
54 |
55 | def __init__(self, env_fns, spaces=None):
56 | """
57 | envs: list of gym environments to run in python threads.
58 | """
59 | self.waiting = False
60 | self.closed = False
61 | nenvs = len(env_fns)
62 | self.send_queues = [Queue.Queue() for _ in range(nenvs)]
63 | self.recv_queues = [Queue.Queue() for _ in range(nenvs)]
64 | self.threads = [threading.Thread(target=thread_worker,
65 | args=(send_q, recv_q, env_fn))
66 | for (send_q, recv_q, env_fn) in
67 | zip(self.send_queues, self.recv_queues, env_fns)]
68 | for thread in self.threads:
69 | thread.daemon = True
70 | thread.start()
71 |
72 | self.send_queues[0].put(('get_spaces', None))
73 | observation_space, action_space = self.recv_queues[0].get()
74 | VecEnv.__init__(self, len(env_fns), observation_space, action_space)
75 |
76 | def step_async(self, actions):
77 | for send_q, action in zip(self.send_queues, actions):
78 | send_q.put(('step', action))
79 | self.waiting = True
80 |
81 | def step_wait(self):
82 | results = self._receive_all()
83 | self.waiting = False
84 | obs, rews, dones, infos = zip(*results)
85 | return np.stack(obs), np.stack(rews), np.stack(dones), infos
86 |
87 | def reset(self):
88 | self._send_all(('reset', None))
89 | return np.stack(self._receive_all())
90 |
91 | def reset_task(self):
92 | self._send_all(('reset_task', None))
93 | return np.stack(self._receive_all())
94 |
95 | def close(self):
96 | if self.closed:
97 | return
98 | if self.waiting:
99 | self._receive_all()
100 | self._send_all(('close', None))
101 | for thread in self.threads:
102 | thread.join()
103 | self.closed = True
104 |
105 | def render(self, mode='human'):
106 | raise NotImplementedError
107 |
108 | def _send_all(self, item):
109 | for send_q in self.send_queues:
110 | send_q.put(item)
111 |
112 | def _receive_all(self):
113 | return [recv_q.get() for recv_q in self.recv_queues]
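# Usage note (illustration only): ThreadedVecEnv is built exactly like
# SubprocVecEnv, from a list of environment factory callables, and exposes the
# same reset()/step_async()/step_wait() interface, e.g.:
#   venv = ThreadedVecEnv([lambda: gym.make('CartPole-v0') for _ in range(4)])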
114 |
--------------------------------------------------------------------------------
/third_party/baselines/ppo2/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | from __future__ import print_function
5 |
--------------------------------------------------------------------------------
/third_party/baselines/ppo2/pathak_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | """Utility functions used by Pathak's curiosity algorithm.
3 | """
4 |
5 | from __future__ import absolute_import
6 | from __future__ import division
7 | from __future__ import print_function
8 |
9 | import numpy as np
10 | import tensorflow as tf
11 |
12 |
13 | def conv2d(x, num_filters, name, filter_size=(3, 3), stride=(1, 1), pad='SAME',
14 | dtype=tf.float32, collections=None, trainable=True):
15 | with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
16 | x = tf.to_float(x)
17 | stride_shape = [1, stride[0], stride[1], 1]
18 | filter_shape = [filter_size[0], filter_size[1], int(x.get_shape()[3]),
19 | num_filters]
20 |
21 | # there are 'num input feature maps * filter height * filter width'
22 | # inputs to each hidden unit
23 | fan_in = np.prod(filter_shape[:3])
24 | # each unit in the lower layer receives a gradient from:
25 | # 'num output feature maps * filter height * filter width' /
26 | # pooling size
27 | fan_out = np.prod(filter_shape[:2]) * num_filters
28 | # initialize weights with random weights
29 | w_bound = np.sqrt(6. / (fan_in + fan_out))
30 |
31 | w = tf.get_variable('W', filter_shape, dtype,
32 | tf.random_uniform_initializer(-w_bound, w_bound),
33 | collections=collections, trainable=trainable)
34 | b = tf.get_variable('b', [1, 1, 1, num_filters],
35 | initializer=tf.constant_initializer(0.0),
36 | collections=collections, trainable=trainable)
37 | return tf.nn.conv2d(x, w, stride_shape, pad) + b
38 |
39 |
40 | def flatten(x):
41 | return tf.reshape(x, [-1, np.prod(x.get_shape().as_list()[1:])])
42 |
43 |
44 | def normalized_columns_initializer(std=1.0):
45 | def _initializer(shape, dtype=None, partition_info=None):
46 | out = np.random.randn(*shape).astype(np.float32)
47 | out *= std / np.sqrt(np.square(out).sum(axis=0, keepdims=True))
48 | return tf.constant(out)
49 | return _initializer
50 |
51 |
52 | def linear(x, size, name, initializer=None, bias_init=0):
53 | w = tf.get_variable(name + '/w', [x.get_shape()[1], size],
54 | initializer=initializer)
55 | b = tf.get_variable(name + '/b', [size],
56 | initializer=tf.constant_initializer(bias_init))
57 | return tf.matmul(x, w) + b
58 |
59 |
60 | def universeHead(x, nConvs=4, trainable=True):
61 | """Universe agent example.
62 |
63 | Args:
64 | x: input image
65 | nConvs: number of convolutional layers
66 | trainable: whether conv2d variables are trainable
67 |
68 | Returns:
69 | [None, 288] embedding
70 | """
71 | print('Using universe head design')
72 | x = tf.image.resize_images(x, [42, 42])
73 | x = tf.cast(x, tf.float32) / 255.
74 | for i in range(nConvs):
75 | x = tf.nn.elu(conv2d(x, 32, 'l{}'.format(i + 1), [3, 3], [2, 2],
76 | trainable=trainable))
77 | # print('Loop{} '.format(i+1),tf.shape(x))
78 | # print('Loop{}'.format(i+1),x.get_shape())
79 | x = flatten(x)
80 | return x
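# Shape check (illustration only): inputs are resized to 42x42 and each of the
# four stride-2 SAME convolutions halves the spatial size, rounding up:
#   42 -> 21 -> 11 -> 6 -> 3, with 32 channels, so flatten() gives 3*3*32 = 288,
#   matching the [None, 288] embedding documented above.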
81 |
82 |
83 | def icm_forward_model(encoded_state, action, num_actions, hidden_layer_size):
84 | action = tf.one_hot(action, num_actions)
85 | combined_input = tf.concat([encoded_state, action], axis=1)
86 | hidden = tf.nn.relu(linear(combined_input, hidden_layer_size, 'f1',
87 | normalized_columns_initializer(0.01)))
88 | pred_next_state = linear(hidden, encoded_state.get_shape()[1].value, 'flast',
89 | normalized_columns_initializer(0.01))
90 | return pred_next_state
91 |
92 |
93 | def icm_inverse_model(encoded_state, encoded_next_state, num_actions,
94 | hidden_layer_size):
95 | combined_input = tf.concat([encoded_state, encoded_next_state], axis=1)
96 | hidden = tf.nn.relu(linear(combined_input, hidden_layer_size, 'g1',
97 | normalized_columns_initializer(0.01)))
98 | # Predicted action logits
99 | logits = linear(hidden, num_actions, 'glast',
100 | normalized_columns_initializer(0.01))
101 | return logits
102 |
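The helpers above are the ICM building blocks: a shared embedding (`universeHead`) plus forward and inverse models. As a rough illustration, and not part of the repository, they could be wired into the usual ICM losses as follows; the function name, the integer `actions` tensor, and the hyperparameters are assumptions:

```python
import tensorflow as tf

def icm_losses(obs, next_obs, actions, num_actions, hidden_layer_size=256):
  """Sketch of ICM losses built from the helpers above (illustrative only)."""
  # Shared embedding for the current and next observations; conv2d uses
  # AUTO_REUSE, so the two universeHead calls share weights.
  phi = universeHead(obs)
  phi_next = universeHead(next_obs)
  # Forward model: predict the next embedding from (embedding, action).
  pred_phi_next = icm_forward_model(phi, actions, num_actions,
                                    hidden_layer_size)
  forward_loss = 0.5 * tf.reduce_mean(
      tf.square(tf.stop_gradient(phi_next) - pred_phi_next))
  # Inverse model: predict the taken action from the two embeddings.
  logits = icm_inverse_model(phi, phi_next, num_actions, hidden_layer_size)
  inverse_loss = tf.reduce_mean(
      tf.nn.sparse_softmax_cross_entropy_with_logits(
          labels=actions, logits=logits))  # actions: int32/int64 tensor
  return forward_loss, inverse_loss
```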
--------------------------------------------------------------------------------
/third_party/baselines/ppo2/policies.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | from __future__ import print_function
5 |
6 | from third_party.baselines.a2c.utils import conv, fc, conv_to_fc, batch_to_seq, seq_to_batch, lstm, lnlstm
7 | from third_party.baselines.common.distributions import make_pdtype
8 | from third_party.baselines.common.input import observation_input
9 | import numpy as np
10 | import tensorflow as tf
11 |
12 | def nature_cnn(unscaled_images, **conv_kwargs):
13 | """
14 | CNN from Nature paper.
15 | """
16 | scaled_images = tf.cast(unscaled_images, tf.float32) / 255.
17 | activ = tf.nn.relu
18 | h = activ(conv(scaled_images, 'c1', nf=32, rf=8, stride=4, init_scale=np.sqrt(2),
19 | **conv_kwargs))
20 | h2 = activ(conv(h, 'c2', nf=64, rf=4, stride=2, init_scale=np.sqrt(2), **conv_kwargs))
21 | h3 = activ(conv(h2, 'c3', nf=64, rf=3, stride=1, init_scale=np.sqrt(2), **conv_kwargs))
22 | h3 = conv_to_fc(h3)
23 | return activ(fc(h3, 'fc1', nh=512, init_scale=np.sqrt(2)))
24 |
25 | class LnLstmPolicy(object):
26 | def __init__(self, sess, ob_space, ac_space, nbatch, nsteps, nlstm=256, reuse=False):
27 | nenv = nbatch // nsteps
28 | X, processed_x = observation_input(ob_space, nbatch)
29 | M = tf.placeholder(tf.float32, [nbatch]) #mask (done t-1)
30 | S = tf.placeholder(tf.float32, [nenv, nlstm*2]) #states
31 | self.pdtype = make_pdtype(ac_space)
32 | with tf.variable_scope("model", reuse=reuse):
33 | h = nature_cnn(processed_x)
34 | xs = batch_to_seq(h, nenv, nsteps)
35 | ms = batch_to_seq(M, nenv, nsteps)
36 | h5, snew = lnlstm(xs, ms, S, 'lstm1', nh=nlstm)
37 | h5 = seq_to_batch(h5)
38 | vf = fc(h5, 'v', 1)
39 | self.pd, self.pi = self.pdtype.pdfromlatent(h5)
40 |
41 | v0 = vf[:, 0]
42 | a0 = self.pd.sample()
43 | neglogp0 = self.pd.neglogp(a0)
44 | self.initial_state = np.zeros((nenv, nlstm*2), dtype=np.float32)
45 |
46 | def step(ob, state, mask):
47 | return sess.run([a0, v0, snew, neglogp0], {X:ob, S:state, M:mask})
48 |
49 | def value(ob, state, mask):
50 | return sess.run(v0, {X:ob, S:state, M:mask})
51 |
52 | self.X = X
53 | self.M = M
54 | self.S = S
55 | self.vf = vf
56 | self.step = step
57 | self.value = value
58 |
59 | class LstmPolicy(object):
60 |
61 | def __init__(self, sess, ob_space, ac_space, nbatch, nsteps, nlstm=256, reuse=False):
62 | nenv = nbatch // nsteps
63 | self.pdtype = make_pdtype(ac_space)
64 | X, processed_x = observation_input(ob_space, nbatch)
65 |
66 | M = tf.placeholder(tf.float32, [nbatch]) #mask (done t-1)
67 | S = tf.placeholder(tf.float32, [nenv, nlstm*2]) #states
68 | with tf.variable_scope("model", reuse=reuse):
69 | h = nature_cnn(X)
70 | xs = batch_to_seq(h, nenv, nsteps)
71 | ms = batch_to_seq(M, nenv, nsteps)
72 | h5, snew = lstm(xs, ms, S, 'lstm1', nh=nlstm)
73 | h5 = seq_to_batch(h5)
74 | vf = fc(h5, 'v', 1)
75 | self.pd, self.pi = self.pdtype.pdfromlatent(h5)
76 |
77 | v0 = vf[:, 0]
78 | a0 = self.pd.sample()
79 | neglogp0 = self.pd.neglogp(a0)
80 | self.initial_state = np.zeros((nenv, nlstm*2), dtype=np.float32)
81 |
82 | def step(ob, state, mask):
83 | return sess.run([a0, v0, snew, neglogp0], {X:ob, S:state, M:mask})
84 |
85 | def value(ob, state, mask):
86 | return sess.run(v0, {X:ob, S:state, M:mask})
87 |
88 | self.X = X
89 | self.M = M
90 | self.S = S
91 | self.vf = vf
92 | self.step = step
93 | self.value = value
94 |
95 | class CnnPolicy(object):
96 |
97 | def __init__(self, sess, ob_space, ac_space, nbatch, nsteps, reuse=False, **conv_kwargs): #pylint: disable=W0613
98 | self.pdtype = make_pdtype(ac_space)
99 | X, processed_x = observation_input(ob_space, nbatch)
100 | with tf.variable_scope("model", reuse=reuse):
101 | h = nature_cnn(processed_x, **conv_kwargs)
102 | vf = fc(h, 'v', 1)[:,0]
103 | self.pd, self.pi = self.pdtype.pdfromlatent(h, init_scale=0.01)
104 |
105 | a0 = self.pd.sample()
106 | neglogp0 = self.pd.neglogp(a0)
107 | self.initial_state = None
108 |
109 | def step(ob, *_args, **_kwargs):
110 | a, v, neglogp = sess.run([a0, vf, neglogp0], {X:ob})
111 | return a, v, self.initial_state, neglogp
112 |
113 | def value(ob, *_args, **_kwargs):
114 | return sess.run(vf, {X:ob})
115 |
116 | self.X = X
117 | self.vf = vf
118 | self.step = step
119 | self.value = value
120 |
121 | class MlpPolicy(object):
122 | def __init__(self, sess, ob_space, ac_space, nbatch, nsteps, reuse=False): #pylint: disable=W0613
123 | self.pdtype = make_pdtype(ac_space)
124 | with tf.variable_scope("model", reuse=reuse):
125 | X, processed_x = observation_input(ob_space, nbatch)
126 | activ = tf.tanh
127 | processed_x = tf.layers.flatten(processed_x)
128 | pi_h1 = activ(fc(processed_x, 'pi_fc1', nh=64, init_scale=np.sqrt(2)))
129 | pi_h2 = activ(fc(pi_h1, 'pi_fc2', nh=64, init_scale=np.sqrt(2)))
130 | vf_h1 = activ(fc(processed_x, 'vf_fc1', nh=64, init_scale=np.sqrt(2)))
131 | vf_h2 = activ(fc(vf_h1, 'vf_fc2', nh=64, init_scale=np.sqrt(2)))
132 | vf = fc(vf_h2, 'vf', 1)[:,0]
133 |
134 | self.pd, self.pi = self.pdtype.pdfromlatent(pi_h2, init_scale=0.01)
135 |
136 |
137 | a0 = self.pd.sample()
138 | neglogp0 = self.pd.neglogp(a0)
139 | self.initial_state = None
140 |
141 | def step(ob, *_args, **_kwargs):
142 | a, v, neglogp = sess.run([a0, vf, neglogp0], {X:ob})
143 | return a, v, self.initial_state, neglogp
144 |
145 | def value(ob, *_args, **_kwargs):
146 | return sess.run(vf, {X:ob})
147 |
148 | self.X = X
149 | self.vf = vf
150 | self.step = step
151 | self.value = value
152 |
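As a reference for how these policy classes are constructed, here is a minimal TF1 graph-mode sketch (not from the repository); the Atari-style observation and action spaces and the batch sizes are assumptions:

```python
import gym
import numpy as np
import tensorflow as tf

from third_party.baselines.ppo2.policies import CnnPolicy

# Assumed spaces; nbatch must be a multiple of nsteps.
ob_space = gym.spaces.Box(low=0, high=255, shape=(84, 84, 4), dtype=np.uint8)
ac_space = gym.spaces.Discrete(6)

with tf.Session() as sess:
  policy = CnnPolicy(sess, ob_space, ac_space, nbatch=8, nsteps=1)
  sess.run(tf.global_variables_initializer())
  obs = np.zeros((8, 84, 84, 4), dtype=np.uint8)
  actions, values, _, neglogps = policy.step(obs)
```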
--------------------------------------------------------------------------------
/third_party/dmlab/dmlab_min_goal_distance.patch:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2016 Google Inc.
2 | # Copyright (C) 2018 Google Inc.
3 | # Copyright (C) 2019 Google Inc.
4 | #
5 | # This program is free software; you can redistribute it and/or modify
6 | # it under the terms of the GNU General Public License as published by
7 | # the Free Software Foundation; either version 2 of the License, or
8 | # (at your option) any later version.
9 | #
10 | # This program is distributed in the hope that it will be useful,
11 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | # GNU General Public License for more details.
14 | #
15 | # You should have received a copy of the GNU General Public License along
16 | # with this program; if not, write to the Free Software Foundation, Inc.,
17 | # 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
18 | #
19 | diff --git a/deepmind/engine/lua_maze_generation.cc b/deepmind/engine/lua_maze_generation.cc
20 | index 8c1b9a9..23c08f7 100644
21 | --- a/deepmind/engine/lua_maze_generation.cc
22 | +++ b/deepmind/engine/lua_maze_generation.cc
23 | @@ -242,6 +242,8 @@ lua::NResultsOr LuaMazeGeneration::CreateRandom(lua_State* L) {
24 |
25 | maze_generation::TextMaze maze(maze_generation::Size{height, width});
26 |
27 | + table.LookUp("minGoalDistance", &maze.min_goal_distance_);
28 | +
29 | // Create random rooms.
30 | maze_generation::SeparateRectangleParams params{};
31 | params.min_size = maze_generation::Size{room_min_size, room_min_size};
32 | @@ -602,6 +604,14 @@ lua::NResultsOr LuaMazeGeneration::VisitRandomPath(lua_State* L) {
33 | return 1;
34 | }
35 |
36 | +lua::NResultsOr LuaMazeGeneration::MinGoalDistance(lua_State* L) {
37 | + if (lua_gettop(L) != 1) {
38 | + return "[MinGoalDistance] - No args expected";
39 | + }
40 | + lua::Push(L, text_maze_.min_goal_distance_);
41 | + return 1;
42 | +}
43 | +
44 | // Registers classes metatable with Lua.
45 | // [0, 0, -]
46 | void LuaMazeGeneration::Register(lua_State* L) {
47 | @@ -625,6 +635,7 @@ void LuaMazeGeneration::Register(lua_State* L) {
48 | {"fromWorldPos", Class::Member<&LuaMazeGeneration::FromWorldPos>},
49 | {"visitFill", Class::Member<&LuaMazeGeneration::VisitFill>},
50 | {"visitRandomPath", Class::Member<&LuaMazeGeneration::VisitRandomPath>},
51 | + {"minGoalDistance", Class::Member<&LuaMazeGeneration::MinGoalDistance>},
52 | };
53 | Class::Register(L, regs);
54 | LuaRoom::Register(L);
55 | diff --git a/deepmind/engine/lua_maze_generation.h b/deepmind/engine/lua_maze_generation.h
56 | index 5e83253..a2915f2 100644
57 | --- a/deepmind/engine/lua_maze_generation.h
58 | +++ b/deepmind/engine/lua_maze_generation.h
59 | @@ -196,6 +196,9 @@ class LuaMazeGeneration : public lua::Class {
60 | // [1, 1, e]
61 | lua::NResultsOr CountVariations(lua_State* L);
62 |
63 | + // Minimum distance between a player spawn and a goal.
64 | + lua::NResultsOr MinGoalDistance(lua_State* L);
65 | +
66 | maze_generation::TextMaze text_maze_;
67 |
68 | static std::uint64_t mixer_seq_;
69 | diff --git a/deepmind/level_generation/text_maze_generation/text_maze.cc b/deepmind/level_generation/text_maze_generation/text_maze.cc
70 | index cc5234e..0d50b12 100644
71 | --- a/deepmind/level_generation/text_maze_generation/text_maze.cc
72 | +++ b/deepmind/level_generation/text_maze_generation/text_maze.cc
73 | @@ -22,7 +22,7 @@ namespace deepmind {
74 | namespace lab {
75 | namespace maze_generation {
76 |
77 | -TextMaze::TextMaze(Size extents) : area_{{0, 0}, extents} {
78 | +TextMaze::TextMaze(Size extents) : area_{{0, 0}, extents}, min_goal_distance_(0) {
79 | std::string level_layer(area_.size.height * (area_.size.width + 1), '*');
80 | std::string variations_layer(area_.size.height * (area_.size.width + 1), '.');
81 | for (int i = 0; i < area_.size.height; ++i) {
82 | diff --git a/deepmind/level_generation/text_maze_generation/text_maze.h b/deepmind/level_generation/text_maze_generation/text_maze.h
83 | index 6b32a6d..bb7396c 100644
84 | --- a/deepmind/level_generation/text_maze_generation/text_maze.h
85 | +++ b/deepmind/level_generation/text_maze_generation/text_maze.h
86 | @@ -242,6 +242,9 @@ class TextMaze {
87 | Rectangle area_;
88 | std::array text_;
89 | std::vector ids_;
90 | +
91 | + public:
92 | + int min_goal_distance_;
93 | };
94 |
95 | } // namespace maze_generation
96 | diff --git a/game_scripts/factories/explore/factory.lua b/game_scripts/factories/explore/factory.lua
97 | index 2352483..57e759d 100644
98 | --- a/game_scripts/factories/explore/factory.lua
99 | +++ b/game_scripts/factories/explore/factory.lua
100 | @@ -69,6 +69,7 @@ function factory.createLevelApi(kwargs)
101 | kwargs.opts.quickRestart = true
102 | end
103 | kwargs.opts.randomSeed = false
104 | + kwargs.opts.minGoalDistance = kwargs.opts.minGoalDistance or 0
105 | kwargs.opts.roomCount = kwargs.opts.roomCount or 4
106 | kwargs.opts.roomMaxSize = kwargs.opts.roomMaxSize or 5
107 | kwargs.opts.roomMinSize = kwargs.opts.roomMinSize or 3
108 | @@ -115,6 +116,7 @@ function factory.createLevelApi(kwargs)
109 | roomMinSize = kwargs.opts.roomMinSize,
110 | roomMaxSize = kwargs.opts.roomMaxSize,
111 | extraConnectionProbability = kwargs.opts.extraConnectionProbability,
112 | + minGoalDistance = kwargs.opts.minGoalDistance,
113 | }
114 |
115 | if kwargs.opts.decalScale and kwargs.opts.decalScale ~= 1 then
116 | diff --git a/game_scripts/factories/explore/goal_locations_factory.lua b/game_scripts/factories/explore/goal_locations_factory.lua
117 | index d3dd9bb..86e7db5 100644
118 | --- a/game_scripts/factories/explore/goal_locations_factory.lua
119 | +++ b/game_scripts/factories/explore/goal_locations_factory.lua
120 | @@ -31,9 +31,9 @@ function level:restart(maze)
121 | local spawnLocations = {}
122 | maze:visitFill{
123 | cell = level._goalLocations[level._goalId],
124 | - func = function(i, j)
125 | + func = function(i, j, distance)
126 | local c = maze:getEntityCell(i, j)
127 | - if c == 'P' then
128 | + if c == 'P' and distance >= maze:minGoalDistance() then
129 | table.insert(spawnLocations, {i, j})
130 | end
131 | end
132 |
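The patch above threads a `minGoalDistance` option through maze generation so that spawn cells closer than that distance to the goal are rejected. A hypothetical illustration of supplying the option through the DMLab Python API is sketched below; the level name, and the assumption that config entries reach the level script's `kwargs.opts`, are not taken from the repository:

```python
import deepmind_lab

# Config values are passed to the level script as strings.
env = deepmind_lab.Lab(
    'contributed/dmlab30/explore_goal_locations_small',  # assumed level name
    ['RGB_INTERLEAVED'],
    config={'minGoalDistance': '8'})
env.reset()
```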
--------------------------------------------------------------------------------
/third_party/gym/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License
2 |
3 | Copyright (c) 2017 OpenAI (http://openai.com)
4 | Copyright (c) 2018 The TF-Agents Authors.
5 | Copyright (c) 2018 Google LLC (http://google.com)
6 |
7 | Permission is hereby granted, free of charge, to any person obtaining a copy
8 | of this software and associated documentation files (the "Software"), to deal
9 | in the Software without restriction, including without limitation the rights
10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | copies of the Software, and to permit persons to whom the Software is
12 | furnished to do so, subject to the following conditions:
13 |
14 | The above copyright notice and this permission notice shall be included in
15 | all copies or substantial portions of the Software.
16 |
17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23 | THE SOFTWARE.
24 |
--------------------------------------------------------------------------------
/third_party/gym/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | from __future__ import print_function
5 |
--------------------------------------------------------------------------------
/third_party/gym/ant.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # The MIT License
3 | #
4 | # Copyright (c) 2016 OpenAI (https://openai.com)
5 | # Copyright (c) 2018 The TF-Agents Authors.
6 | # Copyright (c) 2018 Google LLC (http://google.com)
7 | #
8 | # Permission is hereby granted, free of charge, to any person obtaining a copy
9 | # of this software and associated documentation files (the "Software"), to deal
10 | # in the Software without restriction, including without limitation the rights
11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 | # copies of the Software, and to permit persons to whom the Software is
13 | # furnished to do so, subject to the following conditions:
14 | #
15 | # The above copyright notice and this permission notice shall be included in
16 | # all copies or substantial portions of the Software.
17 | #
18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
24 | # THE SOFTWARE.
25 |
26 | """Ant environment."""
27 |
28 | from __future__ import absolute_import
29 | from __future__ import division
30 | from __future__ import print_function
31 |
32 | from dm_control.mujoco.wrapper.mjbindings import enums
33 | from third_party.gym import mujoco_env
34 | from gym import utils
35 | import numpy as np
36 |
37 |
38 | # pylint: disable=missing-docstring
39 | class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
40 |
41 | def __init__(self,
42 | expose_all_qpos=False,
43 | expose_body_coms=None,
44 | expose_body_comvels=None,
45 | model_path="ant.xml"):
46 | self._expose_all_qpos = expose_all_qpos
47 | self._expose_body_coms = expose_body_coms
48 | self._expose_body_comvels = expose_body_comvels
49 | self._body_com_indices = {}
50 | self._body_comvel_indices = {}
51 | # Settings from
52 | # https://github.com/openai/gym/blob/master/gym/envs/__init__.py
53 | mujoco_env.MujocoEnv.__init__(
54 | self, model_path, 5, max_episode_steps=1000, reward_threshold=6000.0)
55 | utils.EzPickle.__init__(self)
56 |
57 | self.camera_setup()
58 |
59 | def step(self, a):
60 | xposbefore = self.get_body_com("torso")[0]
61 | self.do_simulation(a, self.frame_skip)
62 | xposafter = self.get_body_com("torso")[0]
63 | forward_reward = (xposafter - xposbefore) / self.dt
64 | ctrl_cost = .5 * np.square(a).sum()
65 | contact_cost = 0.5 * 1e-3 * np.sum(
66 | np.square(np.clip(self.physics.data.cfrc_ext, -1, 1)))
67 | survive_reward = 1.0
68 | reward = forward_reward - ctrl_cost - contact_cost + survive_reward
69 | state = self.state_vector()
70 | notdone = np.isfinite(state).all() and state[2] >= 0.2 and state[2] <= 1.0
71 | done = not notdone
72 | ob = self._get_obs()
73 | return ob, reward, done, dict(
74 | reward_forward=forward_reward,
75 | reward_ctrl=-ctrl_cost,
76 | reward_contact=-contact_cost,
77 | reward_survive=survive_reward)
78 |
79 | def _get_obs(self):
80 | if self._expose_all_qpos:
81 | obs = np.concatenate([
82 | self.physics.data.qpos.flat,
83 | self.physics.data.qvel.flat,
84 | np.clip(self.physics.data.cfrc_ext, -1, 1).flat,
85 | ])
86 | else:
87 | obs = np.concatenate([
88 | self.physics.data.qpos.flat[2:],
89 | self.physics.data.qvel.flat,
90 | np.clip(self.physics.data.cfrc_ext, -1, 1).flat,
91 | ])
92 |
93 | if self._expose_body_coms is not None:
94 | for name in self._expose_body_coms:
95 | com = self.get_body_com(name)
96 | if name not in self._body_com_indices:
97 | indices = range(len(obs), len(obs) + len(com))
98 | self._body_com_indices[name] = indices
99 | obs = np.concatenate([obs, com])
100 |
101 | if self._expose_body_comvels is not None:
102 | for name in self._expose_body_comvels:
103 | comvel = self.get_body_comvel(name)
104 | if name not in self._body_comvel_indices:
105 | indices = range(len(obs), len(obs) + len(comvel))
106 | self._body_comvel_indices[name] = indices
107 | obs = np.concatenate([obs, comvel])
108 | return obs
109 |
110 | def reset_model(self):
111 | qpos = self.init_qpos + self.np_random.uniform(
112 | size=self.physics.model.nq, low=-.1, high=.1)
113 | qvel = self.init_qvel + self.np_random.randn(self.physics.model.nv) * .1
114 | self.set_state(qpos, qvel)
115 | return self._get_obs()
116 |
117 | def camera_setup(self):
118 | # pylint: disable=protected-access
119 | self.camera._render_camera.type_ = enums.mjtCamera.mjCAMERA_TRACKING
120 | # pylint: disable=protected-access
121 | self.camera._render_camera.trackbodyid = 1
122 | # pylint: disable=protected-access
123 | self.camera._render_camera.distance = self.physics.model.stat.extent
124 |
125 | @property
126 | def body_com_indices(self):
127 | return self._body_com_indices
128 |
129 | @property
130 | def body_comvel_indices(self):
131 | return self._body_comvel_indices
132 |
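A minimal usage sketch for the environment above (not part of the repository), assuming the custom ant XML asset is available locally; the path below is a placeholder:

```python
from third_party.gym import ant

env = ant.AntEnv(
    model_path='/path/to/mujoco_ant_custom_texture_camerav2.xml',
    expose_all_qpos=False)
obs = env.reset()
# Step with random actions; the info dict reports the reward components.
for _ in range(10):
  obs, reward, done, info = env.step(env.action_space.sample())
  if done:
    obs = env.reset()
```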
--------------------------------------------------------------------------------
/third_party/gym/ant_wrapper_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # The MIT License
3 | #
4 | # Copyright (c) 2016 OpenAI (https://openai.com)
5 | # Copyright (c) 2018 The TF-Agents Authors.
6 | # Copyright (c) 2018 Google LLC (http://google.com)
7 | #
8 | # Permission is hereby granted, free of charge, to any person obtaining a copy
9 | # of this software and associated documentation files (the "Software"), to deal
10 | # in the Software without restriction, including without limitation the rights
11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 | # copies of the Software, and to permit persons to whom the Software is
13 | # furnished to do so, subject to the following conditions:
14 | #
15 | # The above copyright notice and this permission notice shall be included in
16 | # all copies or substantial portions of the Software.
17 | #
18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
24 | # THE SOFTWARE.
25 | """Tests for google3.third_party.py.third_party.gym.ant_wrapper."""
26 |
27 | from __future__ import absolute_import
28 | from __future__ import division
29 | from __future__ import print_function
30 |
31 | import os
32 | from third_party.gym import ant_wrapper
33 | from google3.pyglib import resources
34 | from google3.testing.pybase import googletest
35 |
36 | ASSETS_DIR = 'google3/third_party/py/third_party.gym/assets'
37 |
38 |
39 | def get_resource(filename):
40 | return resources.GetResourceFilenameInDirectoryTree(
41 | os.path.join(ASSETS_DIR, filename))
42 |
43 |
44 | class AntWrapperTest(googletest.TestCase):
45 |
46 | def test_ant_wrapper(self):
47 | env = ant_wrapper.AntWrapper(
48 | get_resource('mujoco_ant_custom_texture_camerav2.xml'),
49 | texture_mode='fixed',
50 | texture_file_pattern=get_resource('texture.png'))
51 | env.reset()
52 | obs, unused_reward, unused_done, info = env.step(env.action_space.sample())
53 | self.assertEqual(obs.shape, (27,))
54 | self.assertIn('frame', info)
55 | self.assertEqual(info['frame'].shape,
56 | (120, 160, 3))
57 |
58 |
59 | if __name__ == '__main__':
60 | googletest.main()
61 |
--------------------------------------------------------------------------------
/third_party/gym/assets/mujoco_ant_custom_texture_camerav2.xml:
--------------------------------------------------------------------------------
(XML markup not preserved in this listing; the file defines the custom-textured MuJoCo ant model and camera used by AntWrapper. See the repository for the full contents.)
--------------------------------------------------------------------------------
/third_party/gym/assets/texture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/episodic-curiosity/3c406964473d98fb977b1617a170a447b3c548fd/third_party/gym/assets/texture.png
--------------------------------------------------------------------------------
/third_party/gym/mujoco_env.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # The MIT License
3 | #
4 | # Copyright (c) 2016 OpenAI (https://openai.com)
5 | # Copyright (c) 2018 The TF-Agents Authors.
6 | # Copyright (c) 2018 Google LLC (http://google.com)
7 | #
8 | # Permission is hereby granted, free of charge, to any person obtaining a copy
9 | # of this software and associated documentation files (the "Software"), to deal
10 | # in the Software without restriction, including without limitation the rights
11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 | # copies of the Software, and to permit persons to whom the Software is
13 | # furnished to do so, subject to the following conditions:
14 | #
15 | # The above copyright notice and this permission notice shall be included in
16 | # all copies or substantial portions of the Software.
17 | #
18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
24 | # THE SOFTWARE.
25 |
26 | """Port of the openai gym mujoco envs to run with dm_control.
27 |
28 | In this port, the simulators are not exactly identical to the originals.
29 | One reason is how the integrators are run:
30 | deepmind uses mj_step / mj_step1 / mj_step2, whereas
31 | openai uses only mj_step.
32 | Openai also performs a compute-subtree step that is not done here.
33 |
34 | In addition, it is not fully clear how often 'MjData' is synced,
35 | so there could be a one-frame offset.
36 | Finally, dm_control code appears to do resets slightly differently and to
37 | use different mj functions.
38 | """
39 |
40 | from __future__ import absolute_import
41 | from __future__ import division
42 | from __future__ import print_function
43 |
44 | import os
45 |
46 | from dm_control import mujoco
47 | from dm_control.mujoco.wrapper.mjbindings import enums
48 | from dm_control.mujoco.wrapper.mjbindings.functions import mjlib
49 | import gym
50 | from gym import spaces
51 | from gym.utils import seeding
52 | import numpy as np
53 | from six.moves import xrange
54 |
55 |
56 | class CustomPhysics(mujoco.Physics):
57 |
58 | def step(self, n_sub_steps=1):
59 | """Advances physics with up-to-date position and velocity dependent fields.
60 |
61 | The actuation can be updated by calling the `set_control` function first.
62 |
63 | Args:
64 | n_sub_steps: Optional number of times to advance the physics. Default 1.
65 | """
66 |
67 | # This does not line up to how openai does things but instead to how
68 | # deepmind does.
69 | # There are configurations where environments like half cheetah become
70 | # unstable.
71 | # This is a rough proxy for what is supposed to happen but not perfect.
72 | for _ in xrange(n_sub_steps):
73 | mjlib.mj_step(self.model.ptr, self.data.ptr)
74 |
75 | if self.model.opt.integrator != enums.mjtIntegrator.mjINT_EULER:
76 | mjlib.mj_step1(self.model.ptr, self.data.ptr)
77 |
78 |
79 | class MujocoEnv(gym.Env):
80 | """Superclass MuJoCo environments modified to use deepmind's mujoco wrapper.
81 | """
82 |
83 | def __init__(self,
84 | model_path,
85 | frame_skip,
86 | max_episode_steps=None,
87 | reward_threshold=None):
88 | if not os.path.exists(model_path):
89 | raise IOError('File %s does not exist' % model_path)
90 |
91 | self.frame_skip = frame_skip
92 | self.physics = CustomPhysics.from_xml_path(model_path)
93 | self.camera = mujoco.MovableCamera(self.physics, height=480, width=640)
94 |
95 | self.viewer = None
96 |
97 | self.metadata = {
98 | 'render.modes': ['human', 'rgb_array'],
99 | 'video.frames_per_second': int(np.round(1.0 / self.dt))
100 | }
101 |
102 | self.init_qpos = self.physics.data.qpos.ravel().copy()
103 | self.init_qvel = self.physics.data.qvel.ravel().copy()
104 | observation, _, done, _ = self.step(np.zeros(self.physics.model.nu))
105 | assert not done
106 | self.obs_dim = observation.size
107 |
108 | bounds = self.physics.model.actuator_ctrlrange.copy()
109 | low = bounds[:, 0]
110 | high = bounds[:, 1]
111 | self.action_space = spaces.Box(low, high, dtype=np.float32)
112 |
113 | high = np.inf * np.ones(self.obs_dim)
114 | low = -high
115 | self.observation_space = spaces.Box(low, high, dtype=np.float32)
116 |
117 | self.max_episode_steps = max_episode_steps
118 | self.reward_threshold = reward_threshold
119 |
120 | self.seed()
121 | self.camera_setup()
122 |
123 | def seed(self, seed=None):
124 | self.np_random, seed = seeding.np_random(seed)
125 | return [seed]
126 |
127 | def reset_model(self):
128 | """Reset the robot degrees of freedom (qpos and qvel).
129 |
130 | Implement this in each subclass.
131 | """
132 | raise NotImplementedError()
133 |
134 | def viewer_setup(self):
135 | """This method is called when the viewer is initialized and after all reset.
136 |
137 | Optionally implement this method, if you need to tinker with camera
138 | position
139 | and so forth.
140 | """
141 | pass
142 |
143 | def reset(self):
144 | mjlib.mj_resetData(self.physics.model.ptr, self.physics.data.ptr)
145 | ob = self.reset_model()
146 | return ob
147 |
148 | def set_state(self, qpos, qvel):
149 | assert qpos.shape == (self.physics.model.nq,) and qvel.shape == (
150 | self.physics.model.nv,)
151 | assert self.physics.get_state().size == qpos.size + qvel.size
152 | state = np.concatenate([qpos, qvel], 0)
153 | with self.physics.reset_context():
154 | self.physics.set_state(state)
155 |
156 | @property
157 | def dt(self):
158 | return self.physics.model.opt.timestep * self.frame_skip
159 |
160 | def do_simulation(self, ctrl, n_frames):
161 | self.physics.set_control(ctrl)
162 | for _ in range(n_frames):
163 | self.physics.step()
164 |
165 | def render(self, mode='human'):
166 | if mode == 'rgb_array':
167 | data = self.camera.render()
168 | return np.copy(data) # render reuses the same memory space.
169 | elif mode == 'human':
170 | raise NotImplementedError(
171 | 'Currently no interactive renderings are allowed.')
172 |
173 | def get_body_com(self, body_name):
174 | idx = self.physics.model.name2id(body_name, 1)
175 | return self.physics.data.subtree_com[idx]
176 |
177 | def get_body_comvel(self, body_name):
178 | # As of MuJoCo v2.0, updates to `mjData->subtree_linvel` will be skipped
179 | # unless these quantities are needed by the simulation. We therefore call
180 | # `mj_subtreeVel` to update them explicitly.
181 | mjlib.mj_subtreeVel(self.physics.model.ptr, self.physics.data.ptr)
182 | idx = self.physics.model.name2id(body_name, 1)
183 | return self.physics.data.subtree_linvel[idx]
184 |
185 | def get_body_xmat(self, body_name):
186 | raise NotImplementedError()
187 |
188 | def state_vector(self):
189 | return np.concatenate(
190 | [self.physics.data.qpos.flat, self.physics.data.qvel.flat])
191 |
192 | def get_state(self):
193 | return np.array(self.physics.data.qpos.flat), np.array(
194 | self.physics.data.qvel.flat)
195 |
196 | def camera_setup(self):
197 | pass # override this to set up camera
198 |
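To make the base-class contract explicit: because the constructor performs an initial `step`, a subclass must implement both `step` and `reset_model` (and may override `camera_setup`). A bare-bones sketch, not taken from the repository, with a placeholder XML path:

```python
import numpy as np

from third_party.gym import mujoco_env


class MinimalEnv(mujoco_env.MujocoEnv):
  """Hypothetical subclass used only to illustrate the required methods."""

  def __init__(self):
    mujoco_env.MujocoEnv.__init__(self, '/path/to/model.xml', frame_skip=5)

  def step(self, action):
    self.do_simulation(action, self.frame_skip)
    # A real environment would compute a task reward and termination here.
    return self._get_obs(), 0.0, False, {}

  def reset_model(self):
    # Slightly perturb the initial state, as the ant environment above does.
    qpos = self.init_qpos + self.np_random.uniform(
        size=self.physics.model.nq, low=-.01, high=.01)
    qvel = self.init_qvel + self.np_random.randn(self.physics.model.nv) * .01
    self.set_state(qpos, qvel)
    return self._get_obs()

  def _get_obs(self):
    return np.concatenate(
        [self.physics.data.qpos.flat, self.physics.data.qvel.flat])
```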
--------------------------------------------------------------------------------
/third_party/keras_resnet/LICENSE:
--------------------------------------------------------------------------------
1 | COPYRIGHT
2 |
3 | All contributions by Raghavendra Kotikalapudi:
4 | Copyright (c) 2016, Raghavendra Kotikalapudi.
5 | All rights reserved.
6 |
7 | All other contributions:
8 | Copyright (c) 2016, the respective contributors.
9 | All rights reserved.
10 |
11 | Copyright (c) 2018 Google LLC
12 | All rights reserved.
13 |
14 | Each contributor holds copyright over their respective contributions.
15 | The project versioning (Git) records all such contribution source information.
16 |
17 | LICENSE
18 |
19 | The MIT License (MIT)
20 |
21 | Permission is hereby granted, free of charge, to any person obtaining a copy
22 | of this software and associated documentation files (the "Software"), to deal
23 | in the Software without restriction, including without limitation the rights
24 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
25 | copies of the Software, and to permit persons to whom the Software is
26 | furnished to do so, subject to the following conditions:
27 |
28 | The above copyright notice and this permission notice shall be included in all
29 | copies or substantial portions of the Software.
30 |
31 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
32 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
33 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
34 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
35 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
36 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
37 | SOFTWARE.
38 |
--------------------------------------------------------------------------------
/third_party/keras_resnet/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | from __future__ import print_function
5 |
--------------------------------------------------------------------------------