├── __init__.py
├── bin
│   ├── __init__.py
│   ├── e2e_test.py
│   ├── train_eval.py
│   └── eval.py
├── loops
│   ├── __init__.py
│   └── train_eval.py
├── tasks
│   ├── __init__.py
│   └── tasks.py
├── utils
│   ├── __init__.py
│   ├── npz.py
│   ├── nested.py
│   ├── visualization.py
│   └── wrappers.py
├── agents
│   ├── __init__.py
│   ├── jax_pure_reward.py
│   ├── pure_reward.py
│   └── sv2p.py
├── objectives
│   ├── __init__.py
│   └── objectives.py
├── planners
│   └── __init__.py
├── simulate
│   ├── __init__.py
│   └── simulate.py
├── imported_models
│   ├── __init__.py
│   ├── sv2p_tests.py
│   ├── sv2p_hparams.py
│   ├── common.py
│   ├── reward_models.py
│   ├── tests_utils.py
│   └── sv2p.py
├── configs
│   ├── pure_reward.gin
│   ├── rewnet_offline_eval.gin
│   ├── jax_pure_reward.gin
│   ├── tests
│   │   ├── jax_pure_reward.gin
│   │   ├── sv2p.gin
│   │   ├── planet.gin
│   │   ├── sv2p_cem_atari.gin
│   │   ├── planet_cem_atari.gin
│   │   ├── planet_cem_cheetah.gin
│   │   ├── sv2p_tfcem_cheetah.gin
│   │   └── jax_pure_reward_cem_cheetah.gin
│   ├── planet_offline_eval.gin
│   ├── pure_reward_offline_eval.gin
│   ├── sv2p.gin
│   ├── sv2p_offline_eval.gin
│   ├── planet.gin
│   ├── rewnet.gin
│   ├── sv2p_cem_atari.gin
│   ├── planet_cem_atari.gin
│   ├── rewnet_cem_atari.gin
│   ├── pure_reward_cem_atari.gin
│   ├── jax_pure_reward_cem_cheetah.gin
│   ├── sv2p_pycem_cheetah.gin
│   ├── planet_cem_walker.gin
│   ├── planet_mppi_cheetah.gin
│   ├── planet_cem_cheetah.gin
│   ├── rewnet_cem_cheetah.gin
│   ├── pure_reward_cem_cheetah.gin
│   └── sv2p_tfcem_cheetah.gin
├── CONTRIBUTING.md
├── setup.py
├── README.md
└── LICENSE
/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /bin/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /loops/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /tasks/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /agents/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /objectives/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /planners/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /simulate/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /imported_models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /configs/pure_reward.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import world_models.agents.pure_reward 16 | 17 | # Parameters for model: 18 | pure_reward.PureReward.recurrent = True 19 | pure_reward.PureReward.output_length = 12 20 | pure_reward.PureReward.task = %TASK 21 | pure_reward.PureReward.model_dir = %model_dir 22 | 23 | # Parameters for train_fn: 24 | pure_reward.create_train_fn.train_steps = 100 25 | pure_reward.create_train_fn.batch = 32 26 | 27 | # Parameters for predict_fn: 28 | pure_reward.create_predict_fn.batch = 1024 29 | pure_reward.create_predict_fn.proposals = 1024 30 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement (CLA). You (or your employer) retain the copyright to your 10 | contribution; this simply gives us permission to use and redistribute your 11 | contributions as part of the project. Head over to 12 | https://cla.developers.google.com/ to see your current agreements on file or 13 | to sign a new one. 14 | 15 | You generally only need to submit a CLA once, so if you've already submitted one 16 | (even if it was for a different project), you probably don't need to do it 17 | again. 18 | 19 | ## Code reviews 20 | 21 | All submissions, including submissions by project members, require review. We 22 | use GitHub pull requests for this purpose. Consult 23 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 24 | information on using pull requests. 25 | 26 | ## Community Guidelines 27 | 28 | This project follows 29 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/). -------------------------------------------------------------------------------- /configs/rewnet_offline_eval.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | 15 | include 'rewnet.gin' 16 | 17 | # model changes 18 | RecurrentStateSpaceModel.include_frames_in_prediction = True 19 | create_planet_train_fn.train_steps = 100 20 | 21 | offline_evaluate.train_eval_iterations = 1000 22 | offline_evaluate.predict_fn = @create_planet_predict_fn() 23 | offline_evaluate.observe_fn = @create_planet_observe_fn() 24 | offline_evaluate.reset_fn = @create_planet_reset_fn() 25 | offline_evaluate.train_fn = @create_planet_train_fn() 26 | offline_evaluate.episode_length = 250 27 | offline_evaluate.prediction_horizon = 12 28 | offline_evaluate.batch = 500 29 | offline_evaluate.num_episodes = 1000 -------------------------------------------------------------------------------- /configs/jax_pure_reward.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import world_models.agents.jax_pure_reward 16 | 17 | # Parameters for model: 18 | JaxPureReward.num_layers = 3 19 | JaxPureReward.hidden_dims = 64 20 | create_model.module = @JaxPureReward() 21 | create_model.task = %TASK 22 | MODEL = @model/singleton() 23 | model/singleton.constructor = @create_model 24 | 25 | # Parameters for predict, observe and reset 26 | create_predict_fn.model = %MODEL 27 | 28 | # Parameters for train_fn: 29 | create_train_fn.train_steps = 100 30 | create_train_fn.batch = 64 31 | create_train_fn.duration = 12 32 | create_train_fn.learning_rate = 1e-3 33 | create_train_fn.model_dir = %model_dir 34 | create_train_fn.model = %MODEL 35 | -------------------------------------------------------------------------------- /configs/tests/jax_pure_reward.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import world_models.agents.jax_pure_reward 16 | 17 | # Parameters for model: 18 | JaxPureReward.num_layers = 3 19 | JaxPureReward.hidden_dims = 64 20 | create_model.module = @JaxPureReward() 21 | create_model.task = %TASK 22 | MODEL = @model/singleton() 23 | model/singleton.constructor = @create_model 24 | 25 | # Parameters for predict, observe and reset 26 | create_predict_fn.model = %MODEL 27 | 28 | # Parameters for train_fn: 29 | create_train_fn.train_steps = 100 30 | create_train_fn.batch = 64 31 | create_train_fn.duration = 12 32 | create_train_fn.learning_rate = 1e-3 33 | create_train_fn.model_dir = %model_dir 34 | create_train_fn.model = %MODEL 35 | -------------------------------------------------------------------------------- /imported_models/sv2p_tests.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Lint as: python3 16 | """End 2 end tests.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import tensorflow.compat.v1 as tf 23 | 24 | from world_models.imported_models import sv2p 25 | from world_models.imported_models import sv2p_hparams 26 | from world_models.imported_models import tests_utils 27 | 28 | 29 | class SV2PTest(tests_utils.BaseModelTest): 30 | """SV2P tests.""" 31 | 32 | def testSV2PWithActionsAndRewards(self): 33 | self.TestWithActionAndRewards(sv2p_hparams.sv2p_hparams(), sv2p.SV2P) 34 | 35 | 36 | if __name__ == '__main__': 37 | tf.test.main() 38 | -------------------------------------------------------------------------------- /configs/planet_offline_eval.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | include 'planet_cem_cheetah.gin' 16 | 17 | # model changes 18 | RecurrentStateSpaceModel.include_frames_in_prediction = True 19 | create_planet_train_fn.train_steps = 100 20 | 21 | offline_evaluate.train_eval_iterations = 1000 22 | offline_evaluate.predict_fn = @create_planet_predict_fn() 23 | offline_evaluate.observe_fn = @create_planet_observe_fn() 24 | offline_evaluate.reset_fn = @create_planet_reset_fn() 25 | offline_evaluate.train_fn = @create_planet_train_fn() 26 | offline_evaluate.episode_length = 250 27 | offline_evaluate.prediction_horizon = 12 28 | offline_evaluate.batch = 500 29 | offline_evaluate.num_episodes = 1000 30 | offline_evaluate.online_eval_task = %TASK 31 | offline_evaluate.online_eval_planner = %EVAL_PLANNER 32 | offline_evaluate.online_eval_episodes = 1 33 | offline_evaluate.enable_train = True -------------------------------------------------------------------------------- /configs/pure_reward_offline_eval.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | include 'pure_reward.gin' 16 | 17 | pure_reward.create_train_fn.train_steps = 100 18 | 19 | # Parameters for task: 20 | DeepMindControl.domain_name = 'cheetah' 21 | DeepMindControl.task_name = 'run' 22 | DeepMindControl.max_duration = 1000 23 | DeepMindControl.action_repeat = 4 24 | TASK = @task/singleton() 25 | task/singleton.constructor = @DeepMindControl 26 | 27 | offline_evaluate.train_eval_iterations = 1000 28 | offline_evaluate.predict_fn = @pure_reward.create_predict_fn() 29 | offline_evaluate.observe_fn = @pure_reward.observe_fn 30 | offline_evaluate.reset_fn = @pure_reward.reset_fn 31 | offline_evaluate.train_fn = @pure_reward.create_train_fn() 32 | offline_evaluate.episode_length = 250 33 | offline_evaluate.prediction_horizon = 12 34 | offline_evaluate.batch = 500 35 | offline_evaluate.num_episodes = 1000 36 | -------------------------------------------------------------------------------- /configs/sv2p.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import world_models.agents.sv2p 16 | 17 | # Parameters for model: 18 | sv2p.SV2P.input_length = 2 19 | sv2p.SV2P.output_length = 12 20 | sv2p.SV2P.task = %TASK 21 | sv2p.SV2P.model_dir = %model_dir 22 | MODEL = @model/singleton() 23 | model/singleton.constructor = @SV2P 24 | 25 | # Parameters for train_fn, observe, reset_fn and predict_fn: 26 | STRATEGY = @strategy/singleton() 27 | strategy/singleton.constructor = @tf.distribute.MirroredStrategy 28 | sv2p.create_train_fn.train_steps = 100 29 | sv2p.create_train_fn.batch = 64 30 | sv2p.create_train_fn.model = %MODEL 31 | sv2p.create_train_fn.strategy = %STRATEGY 32 | sv2p.create_reset_fn.model = %MODEL 33 | sv2p.create_observe_fn.model = %MODEL 34 | sv2p.create_predict_fn.model = %MODEL 35 | sv2p.create_predict_fn.prediction_horizon = 12 36 | sv2p.create_predict_fn.strategy = %STRATEGY 37 | -------------------------------------------------------------------------------- /configs/tests/sv2p.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import world_models.agents.sv2p 16 | 17 | # Parameters for model: 18 | sv2p.SV2P.input_length = 2 19 | sv2p.SV2P.output_length = 2 20 | sv2p.SV2P.task = %TASK 21 | sv2p.SV2P.model_dir = %model_dir 22 | MODEL = @model/singleton() 23 | model/singleton.constructor = @SV2P 24 | 25 | # Parameters for train_fn, observe, reset_fn and predict_fn: 26 | STRATEGY = @strategy/singleton() 27 | strategy/singleton.constructor = @tf.distribute.MirroredStrategy 28 | sv2p.create_train_fn.train_steps = 1 29 | sv2p.create_train_fn.batch = 1 30 | sv2p.create_train_fn.model = %MODEL 31 | sv2p.create_train_fn.strategy = %STRATEGY 32 | sv2p.create_reset_fn.model = %MODEL 33 | sv2p.create_observe_fn.model = %MODEL 34 | sv2p.create_predict_fn.model = %MODEL 35 | sv2p.create_predict_fn.prediction_horizon = 2 36 | sv2p.create_predict_fn.strategy = %STRATEGY 37 | -------------------------------------------------------------------------------- /configs/sv2p_offline_eval.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | include 'sv2p.gin' 16 | 17 | # model changes 18 | sv2p.SV2P.include_frames_in_prediction = True 19 | sv2p.create_train_fn.train_steps = 100 20 | 21 | # Parameters for task: 22 | DeepMindControl.domain_name = 'cheetah' 23 | DeepMindControl.task_name = 'run' 24 | DeepMindControl.max_duration = 1000 25 | DeepMindControl.action_repeat = 4 26 | TASK = @task/singleton() 27 | task/singleton.constructor = @DeepMindControl 28 | 29 | offline_evaluate.train_eval_iterations = 1000 30 | offline_evaluate.predict_fn = @sv2p.create_predict_fn() 31 | offline_evaluate.observe_fn = @sv2p.create_observe_fn() 32 | offline_evaluate.reset_fn = @sv2p.create_reset_fn() 33 | offline_evaluate.train_fn = @sv2p.create_train_fn() 34 | offline_evaluate.episode_length = 250 35 | offline_evaluate.prediction_horizon = 12 36 | offline_evaluate.batch = 500 37 | offline_evaluate.num_episodes = 1000 38 | -------------------------------------------------------------------------------- /imported_models/sv2p_hparams.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Forked from T2T. Waiting to be cleaned up.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import types 22 | 23 | 24 | def sv2p_hparams(): 25 | """SV2P model hparams.""" 26 | hparams = types.SimpleNamespace() 27 | hparams.video_num_input_frames = 4 28 | hparams.video_num_target_frames = 4 29 | 30 | hparams.merged_reward_model = False 31 | hparams.reward_model_stop_gradient = True 32 | hparams.reward_prediction_classes = 1 33 | 34 | hparams.loss_reward_multiplier = 1.0 35 | hparams.loss_extra_multiplier = 1e-3 36 | 37 | hparams.stochastic = False 38 | hparams.latent_channels = 1 39 | hparams.latent_min_logvar = -5.0 40 | 41 | hparams.num_masks = 10 42 | hparams.relu_shift = 1e-12 43 | hparams.dna_kernel_size = 5 44 | 45 | hparams.scheduled_sampling_iterations = 10000 46 | return hparams 47 | -------------------------------------------------------------------------------- /configs/planet.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import world_models.agents.planet 16 | 17 | # Parameters for model: 18 | RecurrentStateSpaceModel.frame_size = (64, 64, 3) 19 | RecurrentStateSpaceModel.reward_stop_gradient = False 20 | RecurrentStateSpaceModel.task = %TASK 21 | MODEL = @model/singleton() 22 | model/singleton.constructor = @RecurrentStateSpaceModel 23 | 24 | # Parameters for predict, observe and reset 25 | STRATEGY = @strategy/singleton() 26 | strategy/singleton.constructor = @tf.distribute.MirroredStrategy 27 | create_planet_predict_fn.model = %MODEL 28 | create_planet_predict_fn.strategy = %STRATEGY 29 | create_planet_observe_fn.model = %MODEL 30 | create_planet_observe_fn.model_dir = %model_dir 31 | create_planet_observe_fn.strategy = %STRATEGY 32 | create_planet_reset_fn.model = %MODEL 33 | 34 | # Parameters for train_fn: 35 | create_planet_train_fn.train_steps = 100 36 | create_planet_train_fn.batch = 64 37 | create_planet_train_fn.duration = 12 38 | create_planet_train_fn.learning_rate = 1e-3 39 | create_planet_train_fn.model_dir = %model_dir 40 | create_planet_train_fn.model = %MODEL 41 | create_planet_train_fn.strategy = %STRATEGY 42 | -------------------------------------------------------------------------------- /configs/tests/planet.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import world_models.agents.planet 16 | 17 | # Parameters for model: 18 | RecurrentStateSpaceModel.frame_size = (64, 64, 3) 19 | RecurrentStateSpaceModel.reward_stop_gradient = False 20 | RecurrentStateSpaceModel.task = %TASK 21 | MODEL = @model/singleton() 22 | model/singleton.constructor = @RecurrentStateSpaceModel 23 | 24 | # Parameters for predict, observe and reset 25 | STRATEGY = @strategy/singleton() 26 | strategy/singleton.constructor = @tf.distribute.MirroredStrategy 27 | create_planet_predict_fn.model = %MODEL 28 | create_planet_predict_fn.strategy = %STRATEGY 29 | create_planet_observe_fn.model = %MODEL 30 | create_planet_observe_fn.model_dir = %model_dir 31 | create_planet_observe_fn.strategy = %STRATEGY 32 | create_planet_reset_fn.model = %MODEL 33 | 34 | # Parameters for train_fn: 35 | create_planet_train_fn.train_steps = 1 36 | create_planet_train_fn.batch = 2 37 | create_planet_train_fn.duration = 2 38 | create_planet_train_fn.learning_rate = 1e-3 39 | create_planet_train_fn.model_dir = %model_dir 40 | create_planet_train_fn.model = %MODEL 41 | create_planet_train_fn.strategy = %STRATEGY 42 | -------------------------------------------------------------------------------- /imported_models/common.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Basic utils for video.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow.compat.v1 as tf 22 | 23 | 24 | def shape_list(x): 25 | """Return list of dims, statically where possible.""" 26 | x = tf.convert_to_tensor(x) 27 | 28 | # If unknown rank, return dynamic shape 29 | if x.get_shape().dims is None: 30 | return tf.shape(x) 31 | 32 | static = x.get_shape().as_list() 33 | shape = tf.shape(x) 34 | 35 | ret = [] 36 | for i, dim in enumerate(static): 37 | if dim is None: 38 | dim = shape[i] 39 | ret.append(dim) 40 | return ret 41 | 42 | 43 | def to_int32(tensor): 44 | return tf.cast(tensor, tf.int32) 45 | 46 | 47 | def to_uint8(tensor): 48 | return tf.cast(tensor, tf.uint8) 49 | 50 | 51 | def to_float(tensor): 52 | return tf.cast(tensor, tf.float32) 53 | 54 | 55 | def tinyify(array, tiny_mode, small_mode): 56 | if tiny_mode: 57 | return [1 for _ in array] 58 | if small_mode: 59 | return [max(x // 4, 1) for x in array] 60 | return array 61 | -------------------------------------------------------------------------------- /configs/rewnet.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import world_models.agents.planet 16 | 17 | # Parameters for model: 18 | RecurrentStateSpaceModel.frame_size = (64, 64, 3) 19 | RecurrentStateSpaceModel.reward_from_frames = True 20 | RecurrentStateSpaceModel.reward_stop_gradient = True 21 | RecurrentStateSpaceModel.task = %TASK 22 | MODEL = @model/singleton() 23 | model/singleton.constructor = @RecurrentStateSpaceModel 24 | 25 | # Parameters for predict, observe and reset 26 | STRATEGY = @strategy/singleton() 27 | strategy/singleton.constructor = @tf.distribute.MirroredStrategy 28 | create_planet_predict_fn.model = %MODEL 29 | create_planet_predict_fn.strategy = %STRATEGY 30 | create_planet_observe_fn.model = %MODEL 31 | create_planet_observe_fn.strategy = %STRATEGY 32 | create_planet_observe_fn.model_dir = %model_dir 33 | create_planet_reset_fn.model = %MODEL 34 | 35 | # Parameters for train_fn: 36 | create_planet_train_fn.train_steps = 100 37 | create_planet_train_fn.batch = 50 38 | create_planet_train_fn.duration = 12 39 | create_planet_train_fn.learning_rate = 1e-3 40 | create_planet_train_fn.model_dir = %model_dir 41 | create_planet_train_fn.model = %MODEL 42 | create_planet_train_fn.strategy = %STRATEGY 43 | -------------------------------------------------------------------------------- /bin/e2e_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """End to end test suite.""" 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | from absl.testing import absltest 20 | from absl.testing import parameterized 21 | import gin 22 | import os 23 | import tensorflow.compat.v1 as tf 24 | from world_models.loops import train_eval 25 | 26 | 27 | class E2ETest(parameterized.TestCase): 28 | 29 | @classmethod 30 | def setUpClass(cls): 31 | tf.enable_eager_execution() 32 | super(E2ETest, cls).setUpClass() 33 | 34 | @parameterized.parameters( 35 | 'configs/tests/planet_cem_cheetah.gin', 36 | 'configs/tests/planet_cem_atari.gin', 37 | 'configs/tests/sv2p_tfcem_cheetah.gin', 38 | 'configs/tests/sv2p_cem_atari.gin', 39 | 'configs/tests/jax_pure_reward_cem_cheetah.gin', 40 | ) 41 | def testConfig(self, config_path): 42 | tmp_dir = self.create_tempdir() 43 | config_params = train_eval.get_gin_override_params(tmp_dir) 44 | test_srcdir = absltest.get_default_test_srcdir() 45 | config_path = os.path.join(test_srcdir, config_path) 46 | gin.parse_config_files_and_bindings([config_path], config_params) 47 | train_eval.train_eval_loop() 48 | 49 | 50 | if __name__ == '__main__': 51 | absltest.main() 52 | -------------------------------------------------------------------------------- /configs/tests/sv2p_cem_atari.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import world_models.objectives.objectives 16 | import world_models.planners.planners 17 | import world_models.simulate.simulate 18 | import world_models.tasks.tasks 19 | 20 | include 'sv2p.gin' 21 | 22 | # Parameters for task: 23 | Atari.game = 'PongNoFrameskip-v4' 24 | Atari.max_duration = 40 25 | Atari.action_repeat = 4 26 | TASK = @task/singleton() 27 | task/singleton.constructor = @Atari 28 | 29 | # Parameters for planner: 30 | CEM.horizon = 3 31 | CEM.proposals = 10 32 | CEM.fraction = 0.1 33 | CEM.iterations = 1 34 | CEM.objective_fn = @objectives.DiscountedReward() 35 | CEM.predict_fn = @sv2p.create_predict_fn() 36 | CEM.observe_fn = @sv2p.create_observe_fn() 37 | CEM.reset_fn = @sv2p.create_reset_fn() 38 | CEM.task = %TASK 39 | TRAIN_PLANNER = @train_planner/singleton() 40 | train_planner/singleton.constructor = @CEM 41 | 42 | # Parameters for train_eval_loop: 43 | train_eval_loop.task = %TASK 44 | train_eval_loop.train_planner = %TRAIN_PLANNER 45 | train_eval_loop.eval_planner = %TRAIN_PLANNER 46 | train_eval_loop.num_train_episodes_per_iteration = 1 47 | train_eval_loop.eval_every_n_iterations = 0 48 | train_eval_loop.num_iterations = 1 49 | train_eval_loop.model_dir = %model_dir 50 | train_eval_loop.episodes_dir = %episodes_dir 51 | train_eval_loop.train_fn = @sv2p.create_train_fn() 52 | -------------------------------------------------------------------------------- /configs/tests/planet_cem_atari.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import world_models.objectives.objectives 16 | import world_models.planners.planners 17 | import world_models.simulate.simulate 18 | import world_models.tasks.tasks 19 | 20 | include 'planet.gin' 21 | 22 | # Parameters for task: 23 | Atari.game = 'PongNoFrameskip-v4' 24 | Atari.max_duration = 40 25 | Atari.action_repeat = 4 26 | TASK = @task/singleton() 27 | task/singleton.constructor = @Atari 28 | 29 | # Parameters for planner: 30 | CEM.horizon = 2 31 | CEM.proposals = 10 32 | CEM.fraction = 0.1 33 | CEM.iterations = 10 34 | CEM.objective_fn = @objectives.DiscountedReward() 35 | CEM.predict_fn = @create_planet_predict_fn() 36 | CEM.observe_fn = @create_planet_observe_fn() 37 | CEM.reset_fn = @create_planet_reset_fn() 38 | CEM.task = %TASK 39 | TRAIN_PLANNER = @train_planner/singleton() 40 | train_planner/singleton.constructor = @CEM 41 | 42 | # Parameters for train_eval_loop: 43 | train_eval_loop.task = %TASK 44 | train_eval_loop.train_planner = %TRAIN_PLANNER 45 | train_eval_loop.eval_planner = %TRAIN_PLANNER 46 | train_eval_loop.num_train_episodes_per_iteration = 1 47 | train_eval_loop.eval_every_n_iterations = 0 48 | train_eval_loop.num_iterations = 1 49 | train_eval_loop.model_dir = %model_dir 50 | train_eval_loop.episodes_dir = %episodes_dir 51 | train_eval_loop.train_fn = @create_planet_train_fn() 52 | -------------------------------------------------------------------------------- /configs/tests/planet_cem_cheetah.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import world_models.objectives.objectives 16 | import world_models.planners.planners 17 | import world_models.simulate.simulate 18 | import world_models.tasks.tasks 19 | 20 | include 'planet.gin' 21 | 22 | # Parameters for task: 23 | DeepMindControl.domain_name = 'cheetah' 24 | DeepMindControl.task_name = 'run' 25 | DeepMindControl.max_duration = 40 26 | DeepMindControl.action_repeat = 4 27 | TASK = @task/singleton() 28 | task/singleton.constructor = @DeepMindControl 29 | 30 | # Parameters for planner: 31 | TensorFlowCEM.horizon = 2 32 | TensorFlowCEM.proposals = 10 33 | TensorFlowCEM.fraction = 0.1 34 | TensorFlowCEM.iterations = 10 35 | TensorFlowCEM.objective_fn = @objectives.TensorFlowDiscountedReward() 36 | TensorFlowCEM.predict_fn = @create_planet_predict_fn() 37 | TensorFlowCEM.observe_fn = @create_planet_observe_fn() 38 | TensorFlowCEM.reset_fn = @create_planet_reset_fn() 39 | TensorFlowCEM.task = %TASK 40 | TRAIN_PLANNER = @train_planner/singleton() 41 | train_planner/singleton.constructor = @TensorFlowCEM 42 | 43 | # Parameters for train_eval_loop: 44 | train_eval_loop.task = %TASK 45 | train_eval_loop.train_planner = %TRAIN_PLANNER 46 | train_eval_loop.eval_planner = %TRAIN_PLANNER 47 | train_eval_loop.num_train_episodes_per_iteration = 1 48 | train_eval_loop.eval_every_n_iterations = 0 49 | train_eval_loop.num_iterations = 1 50 | train_eval_loop.model_dir = %model_dir 51 | train_eval_loop.episodes_dir = %episodes_dir 52 | train_eval_loop.train_fn = @create_planet_train_fn() 53 | -------------------------------------------------------------------------------- /configs/tests/sv2p_tfcem_cheetah.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import world_models.objectives.objectives 16 | import world_models.planners.planners 17 | import world_models.simulate.simulate 18 | import world_models.tasks.tasks 19 | 20 | include 'sv2p.gin' 21 | 22 | # Parameters for task: 23 | DeepMindControl.domain_name = 'cheetah' 24 | DeepMindControl.task_name = 'run' 25 | DeepMindControl.max_duration = 40 26 | DeepMindControl.action_repeat = 4 27 | TASK = @task/singleton() 28 | task/singleton.constructor = @DeepMindControl 29 | 30 | # Parameters for planner: 31 | TensorFlowCEM.horizon = 2 32 | TensorFlowCEM.proposals = 5 33 | TensorFlowCEM.fraction = 0.5 34 | TensorFlowCEM.iterations = 2 35 | TensorFlowCEM.objective_fn = @objectives.TensorFlowDiscountedReward() 36 | TensorFlowCEM.predict_fn = @sv2p.create_predict_fn() 37 | TensorFlowCEM.observe_fn = @sv2p.create_observe_fn() 38 | TensorFlowCEM.reset_fn = @sv2p.create_reset_fn() 39 | TensorFlowCEM.task = %TASK 40 | TRAIN_PLANNER = @train_planner/singleton() 41 | train_planner/singleton.constructor = @TensorFlowCEM 42 | 43 | # Parameters for train_eval_loop: 44 | train_eval_loop.task = %TASK 45 | train_eval_loop.train_planner = %TRAIN_PLANNER 46 | train_eval_loop.eval_planner = %TRAIN_PLANNER 47 | train_eval_loop.num_train_episodes_per_iteration = 1 48 | train_eval_loop.eval_every_n_iterations = 0 # 0 disables the evaluation phase 49 | train_eval_loop.num_iterations = 1 50 | train_eval_loop.model_dir = %model_dir 51 | train_eval_loop.episodes_dir = %episodes_dir 52 | train_eval_loop.train_fn = @sv2p.create_train_fn() 53 | -------------------------------------------------------------------------------- /objectives/objectives.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Implementation of objectives.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import gin 21 | import numpy as np 22 | import tensorflow.compat.v1 as tf 23 | from typing import Dict, Text 24 | 25 | 26 | class Objective(object): 27 | """Base class for objectives.""" 28 | 29 | def __call__(self, predictions: Dict[Text, np.ndarray]): 30 | """Calculates the reward from predictions. 31 | 32 | Args: 33 | predictions: a dictionary with possibly the following entries: 34 | * "image": [batch, steps, height, width, channels] np array. 35 | * "reward": [batch, steps, 1] np array. 36 | 37 | Returns: 38 | a [batch, 1] ndarray for the rewards. 
39 | """ 40 | raise NotImplementedError 41 | 42 | 43 | @gin.configurable 44 | class RandomObjective(Objective): 45 | """A test objective that returns random rewards sampled from a normal dist.""" 46 | 47 | def __call__(self, predictions): 48 | batch = predictions["image"].shape[0] 49 | return np.random.normal(size=[batch, 1]) 50 | 51 | 52 | @gin.configurable 53 | class DiscountedReward(Objective): 54 | """To be used with world model already predicting rewards.""" 55 | 56 | def __call__(self, predictions): 57 | return np.sum(predictions["reward"], axis=1) 58 | 59 | 60 | @gin.configurable 61 | class TensorFlowDiscountedReward(Objective): 62 | """TensorFlow version of discounted reward.""" 63 | 64 | @tf.function 65 | def __call__(self, predictions): 66 | return tf.reduce_sum(predictions["reward"], axis=1) 67 | -------------------------------------------------------------------------------- /configs/sv2p_cem_atari.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import world_models.loops.train_eval 16 | import world_models.objectives.objectives 17 | import world_models.planners.planners 18 | import world_models.simulate.simulate 19 | import world_models.tasks.tasks 20 | 21 | include 'sv2p.gin' 22 | 23 | # Parameters for task: 24 | Atari.game = 'PongNoFrameskip-v4' 25 | Atari.max_duration = 1000 26 | Atari.action_repeat = 4 27 | TASK = @task/singleton() 28 | task/singleton.constructor = @Atari 29 | 30 | # Parameters for planner: 31 | CEM.horizon = 12 32 | CEM.proposals = 128 33 | CEM.fraction = 0.1 34 | CEM.iterations = 10 35 | CEM.objective_fn = @objectives.DiscountedReward() 36 | CEM.predict_fn = @sv2p.create_predict_fn() 37 | CEM.observe_fn = @sv2p.create_observe_fn() 38 | CEM.reset_fn = @sv2p.create_reset_fn() 39 | CEM.task = %TASK 40 | EVAL_PLANNER = @eval_planner/singleton() 41 | eval_planner/singleton.constructor = @CEM 42 | Randomizer.base_planner = %EVAL_PLANNER 43 | Randomizer.task = %TASK 44 | Randomizer.rand_prob = 0.2 45 | RandomColdStart.task = %TASK 46 | RandomColdStart.random_episodes = 5 47 | RandomColdStart.base_planner = @Randomizer() 48 | TRAIN_PLANNER = @train_planner/singleton() 49 | train_planner/singleton.constructor = @RandomColdStart 50 | 51 | # Parameters for train_eval_loop: 52 | train_eval_loop.task = %TASK 53 | train_eval_loop.train_planner = %TRAIN_PLANNER 54 | train_eval_loop.eval_planner = %EVAL_PLANNER 55 | train_eval_loop.num_train_episodes_per_iteration = 1 56 | train_eval_loop.eval_every_n_iterations = 10 57 | train_eval_loop.num_iterations = 1010 58 | train_eval_loop.model_dir = %model_dir 59 | train_eval_loop.episodes_dir = %episodes_dir 60 | train_eval_loop.train_fn = @sv2p.create_train_fn() 61 | -------------------------------------------------------------------------------- /configs/planet_cem_atari.gin: -------------------------------------------------------------------------------- 1 | # 
Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import world_models.loops.train_eval 16 | import world_models.objectives.objectives 17 | import world_models.planners.planners 18 | import world_models.simulate.simulate 19 | import world_models.tasks.tasks 20 | 21 | include 'planet.gin' 22 | 23 | # Parameters for task: 24 | Atari.game = 'PongNoFrameskip-v4' 25 | Atari.max_duration = 1000 26 | Atari.action_repeat = 4 27 | TASK = @task/singleton() 28 | task/singleton.constructor = @Atari 29 | 30 | # Parameters for planner: 31 | CEM.horizon = 12 32 | CEM.proposals = 128 33 | CEM.fraction = 0.1 34 | CEM.iterations = 10 35 | CEM.objective_fn = @objectives.DiscountedReward() 36 | CEM.predict_fn = @create_planet_predict_fn() 37 | CEM.observe_fn = @create_planet_observe_fn() 38 | CEM.reset_fn = @create_planet_reset_fn() 39 | CEM.task = %TASK 40 | EVAL_PLANNER = @eval_planner/singleton() 41 | eval_planner/singleton.constructor = @CEM 42 | Randomizer.base_planner = %EVAL_PLANNER 43 | Randomizer.task = %TASK 44 | Randomizer.rand_prob = 0.2 45 | RandomColdStart.task = %TASK 46 | RandomColdStart.random_episodes = 5 47 | RandomColdStart.base_planner = @Randomizer() 48 | TRAIN_PLANNER = @train_planner/singleton() 49 | train_planner/singleton.constructor = @RandomColdStart 50 | 51 | # Parameters for train_eval_loop: 52 | train_eval_loop.task = %TASK 53 | train_eval_loop.train_planner = %TRAIN_PLANNER 54 | train_eval_loop.eval_planner = %EVAL_PLANNER 55 | train_eval_loop.num_train_episodes_per_iteration = 1 56 | train_eval_loop.eval_every_n_iterations = 10 57 | train_eval_loop.num_iterations = 1010 58 | train_eval_loop.model_dir = %model_dir 59 | train_eval_loop.episodes_dir = %episodes_dir 60 | train_eval_loop.train_fn = @create_planet_train_fn() 61 | -------------------------------------------------------------------------------- /configs/rewnet_cem_atari.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import world_models.loops.train_eval 16 | import world_models.objectives.objectives 17 | import world_models.planners.planners 18 | import world_models.simulate.simulate 19 | import world_models.tasks.tasks 20 | 21 | include 'rewnet.gin' 22 | 23 | # Parameters for task: 24 | Atari.game = 'PongNoFrameskip-v4' 25 | Atari.max_duration = 1000 26 | Atari.action_repeat = 4 27 | TASK = @task/singleton() 28 | task/singleton.constructor = @Atari 29 | 30 | # Parameters for planner: 31 | CEM.horizon = 12 32 | CEM.proposals = 128 33 | CEM.fraction = 0.1 34 | CEM.iterations = 10 35 | CEM.objective_fn = @objectives.DiscountedReward() 36 | CEM.predict_fn = @create_planet_predict_fn() 37 | CEM.observe_fn = @create_planet_observe_fn() 38 | CEM.reset_fn = @create_planet_reset_fn() 39 | CEM.task = %TASK 40 | EVAL_PLANNER = @eval_planner/singleton() 41 | eval_planner/singleton.constructor = @CEM 42 | Randomizer.base_planner = %EVAL_PLANNER 43 | Randomizer.task = %TASK 44 | Randomizer.rand_prob = 0.2 45 | RandomColdStart.task = %TASK 46 | RandomColdStart.random_episodes = 5 47 | RandomColdStart.base_planner = @Randomizer() 48 | TRAIN_PLANNER = @train_planner/singleton() 49 | train_planner/singleton.constructor = @RandomColdStart 50 | 51 | # Parameters for train_eval_loop: 52 | train_eval_loop.task = %TASK 53 | train_eval_loop.train_planner = %TRAIN_PLANNER 54 | train_eval_loop.eval_planner = %EVAL_PLANNER 55 | train_eval_loop.num_train_episodes_per_iteration = 1 56 | train_eval_loop.eval_every_n_iterations = 10 57 | train_eval_loop.num_iterations = 1010 58 | train_eval_loop.model_dir = %model_dir 59 | train_eval_loop.episodes_dir = %episodes_dir 60 | train_eval_loop.train_fn = @create_planet_train_fn() 61 | -------------------------------------------------------------------------------- /configs/pure_reward_cem_atari.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import world_models.loops.train_eval 16 | import world_models.objectives.objectives 17 | import world_models.planners.planners 18 | import world_models.simulate.simulate 19 | import world_models.tasks.tasks 20 | 21 | include 'pure_reward.gin' 22 | 23 | # Parameters for task: 24 | Atari.game = 'PongNoFrameskip-v4' 25 | Atari.max_duration = 1000 26 | Atari.action_repeat = 4 27 | TASK = @task/singleton() 28 | task/singleton.constructor = @Atari 29 | 30 | # Parameters for planner: 31 | CEM.horizon = 12 32 | CEM.proposals = 128 33 | CEM.fraction = 0.1 34 | CEM.iterations = 10 35 | CEM.objective_fn = @objectives.DiscountedReward() 36 | CEM.objective_fn = @objectives.TensorFlowDiscountedReward() 37 | CEM.predict_fn = @pure_reward.create_predict_fn() 38 | CEM.observe_fn = @pure_reward.observe_fn 39 | CEM.reset_fn = @pure_reward.reset_fn 40 | CEM.task = %TASK 41 | EVAL_PLANNER = @eval_planner/singleton() 42 | eval_planner/singleton.constructor = @CEM 43 | Randomizer.base_planner = %EVAL_PLANNER 44 | Randomizer.task = %TASK 45 | Randomizer.rand_prob = 0.2 46 | RandomColdStart.task = %TASK 47 | RandomColdStart.random_episodes = 5 48 | RandomColdStart.base_planner = @Randomizer() 49 | TRAIN_PLANNER = @train_planner/singleton() 50 | train_planner/singleton.constructor = @RandomColdStart 51 | 52 | # Parameters for train_eval_loop: 53 | train_eval_loop.task = %TASK 54 | train_eval_loop.train_planner = %TRAIN_PLANNER 55 | train_eval_loop.eval_planner = %EVAL_PLANNER 56 | train_eval_loop.num_train_episodes_per_iteration = 1 57 | train_eval_loop.eval_every_n_iterations = 10 58 | train_eval_loop.num_iterations = 1010 59 | train_eval_loop.model_dir = %model_dir 60 | train_eval_loop.episodes_dir = %episodes_dir 61 | train_eval_loop.train_fn = @pure_reward.create_train_fn() 62 | -------------------------------------------------------------------------------- /configs/jax_pure_reward_cem_cheetah.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
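`objectives.DiscountedReward()` (and its in-graph counterpart `objectives.TensorFlowDiscountedReward()`, used with `TensorFlowCEM` in the cheetah configs) scores a candidate action sequence by the discounted sum of the rewards the model predicts for it. As a worked sketch, assuming a discount factor `gamma`; the actual class and its default discount live in `objectives.py`, which is not shown here:

```
# Illustrative discounted-reward objective; names and defaults are assumptions.
import numpy as np


def discounted_reward(predicted_rewards, gamma=0.99):
  """predicted_rewards: [proposals, horizon] per-step rewards; returns [proposals]."""
  discounts = gamma ** np.arange(predicted_rewards.shape[1])
  return np.sum(predicted_rewards * discounts, axis=1)
```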
14 | 15 | import world_models.loops.train_eval 16 | import world_models.objectives.objectives 17 | import world_models.planners.planners 18 | import world_models.simulate.simulate 19 | import world_models.tasks.tasks 20 | 21 | include 'jax_pure_reward.gin' 22 | 23 | # Parameters for task: 24 | DeepMindControl.domain_name = 'cheetah' 25 | DeepMindControl.task_name = 'run' 26 | DeepMindControl.max_duration = 1000 27 | DeepMindControl.action_repeat = 4 28 | TASK = @task/singleton() 29 | task/singleton.constructor = @DeepMindControl 30 | 31 | # Parameters for planner: 32 | CEM.horizon = 12 33 | CEM.proposals = 128 34 | CEM.fraction = 0.1 35 | CEM.iterations = 10 36 | CEM.objective_fn = @objectives.DiscountedReward() 37 | CEM.predict_fn = @create_predict_fn() 38 | CEM.observe_fn = @observe_fn 39 | CEM.reset_fn = @reset_fn 40 | CEM.task = %TASK 41 | EVAL_PLANNER = @eval_planner/singleton() 42 | eval_planner/singleton.constructor = @CEM 43 | GaussianRandomNoise.base_planner = %EVAL_PLANNER 44 | GaussianRandomNoise.task = %TASK 45 | GaussianRandomNoise.stdev = 0.3 46 | RandomColdStart.task = %TASK 47 | RandomColdStart.random_episodes = 5 48 | RandomColdStart.base_planner = @GaussianRandomNoise() 49 | TRAIN_PLANNER = @train_planner/singleton() 50 | train_planner/singleton.constructor = @RandomColdStart 51 | 52 | # Parameters for train_eval_loop: 53 | train_eval_loop.task = %TASK 54 | train_eval_loop.train_planner = %TRAIN_PLANNER 55 | train_eval_loop.eval_planner = %EVAL_PLANNER 56 | train_eval_loop.num_train_episodes_per_iteration = 1 57 | train_eval_loop.eval_every_n_iterations = 10 58 | train_eval_loop.num_iterations = 1010 59 | train_eval_loop.model_dir = %model_dir 60 | train_eval_loop.episodes_dir = %episodes_dir 61 | train_eval_loop.train_fn = @create_train_fn() 62 | -------------------------------------------------------------------------------- /configs/tests/jax_pure_reward_cem_cheetah.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
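For the continuous-control (DM Control) configs the training-time exploration wrapper is `GaussianRandomNoise` rather than `Randomizer`: instead of occasionally replacing the planned action, it perturbs every action with zero-mean Gaussian noise of standard deviation `stdev = 0.3`. Roughly, and hypothetically (continuous actions would normally also be clipped back into the action bounds):

```
# Hypothetical sketch of Gaussian exploration noise on top of a planned action.
import numpy as np


def noisy_action(planned_action, low, high, stdev=0.3):
  noise = np.random.normal(scale=stdev, size=planned_action.shape)
  return np.clip(planned_action + noise, low, high)
```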
14 | 15 | import world_models.loops.train_eval 16 | import world_models.objectives.objectives 17 | import world_models.planners.planners 18 | import world_models.simulate.simulate 19 | import world_models.tasks.tasks 20 | 21 | include 'jax_pure_reward.gin' 22 | 23 | # Parameters for task: 24 | DeepMindControl.domain_name = 'cheetah' 25 | DeepMindControl.task_name = 'run' 26 | DeepMindControl.max_duration = 1000 27 | DeepMindControl.action_repeat = 4 28 | TASK = @task/singleton() 29 | task/singleton.constructor = @DeepMindControl 30 | 31 | # Parameters for planner: 32 | CEM.horizon = 12 33 | CEM.proposals = 128 34 | CEM.fraction = 0.1 35 | CEM.iterations = 10 36 | CEM.objective_fn = @objectives.DiscountedReward() 37 | CEM.predict_fn = @create_predict_fn() 38 | CEM.observe_fn = @observe_fn 39 | CEM.reset_fn = @reset_fn 40 | CEM.task = %TASK 41 | EVAL_PLANNER = @eval_planner/singleton() 42 | eval_planner/singleton.constructor = @CEM 43 | GaussianRandomNoise.base_planner = %EVAL_PLANNER 44 | GaussianRandomNoise.task = %TASK 45 | GaussianRandomNoise.stdev = 0.3 46 | RandomColdStart.task = %TASK 47 | RandomColdStart.random_episodes = 0 48 | RandomColdStart.base_planner = @GaussianRandomNoise() 49 | TRAIN_PLANNER = @train_planner/singleton() 50 | train_planner/singleton.constructor = @RandomColdStart 51 | 52 | # Parameters for train_eval_loop: 53 | train_eval_loop.task = %TASK 54 | train_eval_loop.train_planner = %TRAIN_PLANNER 55 | train_eval_loop.eval_planner = %EVAL_PLANNER 56 | train_eval_loop.num_train_episodes_per_iteration = 1 57 | train_eval_loop.eval_every_n_iterations = 0 58 | train_eval_loop.num_iterations = 1 59 | train_eval_loop.model_dir = %model_dir 60 | train_eval_loop.episodes_dir = %episodes_dir 61 | train_eval_loop.train_fn = @create_train_fn() 62 | -------------------------------------------------------------------------------- /configs/sv2p_pycem_cheetah.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import world_models.loops.train_eval 16 | import world_models.objectives.objectives 17 | import world_models.planners.planners 18 | import world_models.simulate.simulate 19 | import world_models.tasks.tasks 20 | 21 | include 'sv2p.gin' 22 | 23 | # Parameters for task: 24 | DeepMindControl.domain_name = 'cheetah' 25 | DeepMindControl.task_name = 'run' 26 | DeepMindControl.max_duration = 1000 27 | DeepMindControl.action_repeat = 4 28 | TASK = @task/singleton() 29 | task/singleton.constructor = @DeepMindControl 30 | 31 | # Parameters for planner: 32 | CEM.horizon = 12 33 | CEM.proposals = 128 34 | CEM.fraction = 0.1 35 | CEM.iterations = 10 36 | CEM.objective_fn = @objectives.DiscountedReward() 37 | CEM.predict_fn = @sv2p.create_predict_fn() 38 | CEM.observe_fn = @sv2p.create_observe_fn() 39 | CEM.reset_fn = @sv2p.create_reset_fn() 40 | CEM.task = %TASK 41 | EVAL_PLANNER = @eval_planner/singleton() 42 | eval_planner/singleton.constructor = @CEM 43 | GaussianRandomNoise.base_planner = %EVAL_PLANNER 44 | GaussianRandomNoise.task = %TASK 45 | GaussianRandomNoise.stdev = 0.3 46 | RandomColdStart.task = %TASK 47 | RandomColdStart.random_episodes = 1 48 | RandomColdStart.base_planner = @GaussianRandomNoise() 49 | TRAIN_PLANNER = @train_planner/singleton() 50 | train_planner/singleton.constructor = @RandomColdStart 51 | 52 | # Parameters for train_eval_loop: 53 | train_eval_loop.task = %TASK 54 | train_eval_loop.train_planner = %TRAIN_PLANNER 55 | train_eval_loop.eval_planner = %EVAL_PLANNER 56 | train_eval_loop.num_train_episodes_per_iteration = 1 57 | train_eval_loop.eval_every_n_iterations = 10 58 | train_eval_loop.num_iterations = 1010 59 | train_eval_loop.model_dir = %model_dir 60 | train_eval_loop.episodes_dir = %episodes_dir 61 | train_eval_loop.train_fn = @sv2p.create_train_fn() 62 | -------------------------------------------------------------------------------- /configs/planet_cem_walker.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
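Every config wires a world model into its planner through the same three callables: `reset_fn` builds the model's per-episode planning state, `observe_fn` folds the latest frames, actions, and rewards into that state, and `predict_fn` rolls candidate action sequences forward and returns predicted per-step rewards (and, for video models such as SV2P, possibly predicted images). The JAX agent later in this dump shows the concrete signatures; a stripped-down stand-in that satisfies the same interface might look like this (shapes and keys are assumptions based on that agent, not the SV2P or PlaNet implementations):

```
# Minimal stand-in model exposing the reset/observe/predict interface the
# planners expect; illustrative only.
import numpy as np


def reset_fn(**unused_kwargs):
  return {}  # whatever per-episode state the model needs


def observe_fn(images, actions, rewards, state):
  del images, actions, rewards  # a real model would update its belief here
  return state


def predict_fn(actions, state):
  # actions: [proposals, horizon, action_dim] candidate plans.
  del state
  return {"reward": np.zeros(actions.shape[:2] + (1,))}
```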
14 | 15 | import world_models.loops.train_eval 16 | import world_models.objectives.objectives 17 | import world_models.planners.planners 18 | import world_models.simulate.simulate 19 | import world_models.tasks.tasks 20 | 21 | include 'planet.gin' 22 | 23 | # Parameters for task: 24 | DeepMindControl.domain_name = 'walker' 25 | DeepMindControl.task_name = 'walk' 26 | DeepMindControl.max_duration = 1000 27 | DeepMindControl.action_repeat = 2 28 | TASK = @task/singleton() 29 | task/singleton.constructor = @DeepMindControl 30 | 31 | # Parameters for planner: 32 | CEM.horizon = 12 33 | CEM.proposals = 1000 34 | CEM.fraction = 0.1 35 | CEM.iterations = 10 36 | CEM.objective_fn = @objectives.DiscountedReward() 37 | CEM.predict_fn = @create_planet_predict_fn() 38 | CEM.observe_fn = @create_planet_observe_fn() 39 | CEM.reset_fn = @create_planet_reset_fn() 40 | CEM.task = %TASK 41 | EVAL_PLANNER = @eval_planner/singleton() 42 | eval_planner/singleton.constructor = @CEM 43 | GaussianRandomNoise.base_planner = %EVAL_PLANNER 44 | GaussianRandomNoise.task = %TASK 45 | GaussianRandomNoise.stdev = 0.3 46 | RandomColdStart.task = %TASK 47 | RandomColdStart.random_episodes = 5 48 | RandomColdStart.base_planner = @GaussianRandomNoise() 49 | TRAIN_PLANNER = @train_planner/singleton() 50 | train_planner/singleton.constructor = @RandomColdStart 51 | 52 | # Parameters for train_eval_loop: 53 | train_eval_loop.task = %TASK 54 | train_eval_loop.train_planner = %TRAIN_PLANNER 55 | train_eval_loop.eval_planner = %EVAL_PLANNER 56 | train_eval_loop.num_train_episodes_per_iteration = 1 57 | train_eval_loop.eval_every_n_iterations = 10 58 | train_eval_loop.num_iterations = 1000000 59 | train_eval_loop.model_dir = %model_dir 60 | train_eval_loop.episodes_dir = %episodes_dir 61 | train_eval_loop.train_fn = @create_planet_train_fn() 62 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import setuptools 16 | 17 | with open("README.md", "r") as fh: 18 | long_description = fh.read() 19 | 20 | setuptools.setup( 21 | name="world_models", 22 | version="1.0.0", 23 | author="Google LLC", 24 | author_email="no-reply@google.com", 25 | description="World Models Library", 26 | long_description=long_description, 27 | long_description_content_type="text/markdown", 28 | url="https://github.com/google-research/tree/master/world_models", 29 | packages=["world_models", "world_models.agents", "world_models.bin", 30 | "world_models.imported_models", "world_models.loops", 31 | "world_models.objectives", "world_models.planners", 32 | "world_models.simulate", "world_models.tasks", 33 | "world_models.utils"], 34 | package_dir={"world_models": "", "world_models.agents": "agents", 35 | "world_models.bin": "bin", 36 | "world_models.imported_models": "imported_models", 37 | "world_models.loops": "loops", 38 | "world_models.objectives": "objectives", 39 | "world_models.planners": "planners", 40 | "world_models.simulate": "simulate", 41 | "world_models.tasks": "tasks", 42 | "world_models.utils": "utils"}, 43 | classifiers=[ 44 | "Programming Language :: Python :: 3", 45 | "License :: OSI Approved :: Apache Software License", 46 | "Operating System :: OS Independent", 47 | ], 48 | install_requires=[ 49 | "absl-py", 50 | "gin-config", 51 | "numpy", 52 | "tensorflow==1.15", 53 | "tensorflow-probability==0.7", 54 | "gym", 55 | "dm_control", 56 | "mujoco-py==2.0.2.8", 57 | ], 58 | python_requires="<3.8", 59 | ) 60 | -------------------------------------------------------------------------------- /configs/planet_mppi_cheetah.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
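setup.py maps the repository root onto the `world_models` package, so installing it is what makes the `world_models.*` imports used by the configs and binaries resolve. For example, from the repository root (note the pins to TensorFlow 1.15, tensorflow-probability 0.7, and Python < 3.8 above):

```
pip install .
python3 -c "import world_models.tasks.tasks"
```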
14 | 15 | import world_models.loops.train_eval 16 | import world_models.objectives.objectives 17 | import world_models.planners.planners 18 | import world_models.simulate.simulate 19 | import world_models.tasks.tasks 20 | 21 | include 'planet.gin' 22 | 23 | # Parameters for task: 24 | DeepMindControl.domain_name = 'cheetah' 25 | DeepMindControl.task_name = 'run' 26 | DeepMindControl.max_duration = 1000 27 | DeepMindControl.action_repeat = 4 28 | TASK = @task/singleton() 29 | task/singleton.constructor = @DeepMindControl 30 | 31 | # Parameters for planner: 32 | MPPI.horizon = 12 33 | MPPI.proposals = 128 34 | MPPI.fraction = 0.1 35 | MPPI.iterations = 10 36 | MPPI.beta = [0.6, 0.4, 0.0] 37 | MPPI.gamma = 0.1 38 | MPPI.objective_fn = @objectives.DiscountedReward() 39 | MPPI.predict_fn = @create_planet_predict_fn() 40 | MPPI.observe_fn = @create_planet_observe_fn() 41 | MPPI.reset_fn = @create_planet_reset_fn() 42 | MPPI.task = %TASK 43 | EVAL_PLANNER = @eval_planner/singleton() 44 | eval_planner/singleton.constructor = @MPPI 45 | GaussianRandomNoise.base_planner = %EVAL_PLANNER 46 | GaussianRandomNoise.task = %TASK 47 | GaussianRandomNoise.stdev = 0.3 48 | RandomColdStart.task = %TASK 49 | RandomColdStart.random_episodes = 5 50 | RandomColdStart.base_planner = @GaussianRandomNoise() 51 | TRAIN_PLANNER = @train_planner/singleton() 52 | train_planner/singleton.constructor = @RandomColdStart 53 | 54 | # Parameters for train_eval_loop: 55 | train_eval_loop.task = %TASK 56 | train_eval_loop.train_planner = %TRAIN_PLANNER 57 | train_eval_loop.eval_planner = %EVAL_PLANNER 58 | train_eval_loop.num_train_episodes_per_iteration = 1 59 | train_eval_loop.eval_every_n_iterations = 10 60 | train_eval_loop.num_iterations = 1010 61 | train_eval_loop.model_dir = %model_dir 62 | train_eval_loop.episodes_dir = %episodes_dir 63 | train_eval_loop.train_fn = @create_planet_train_fn() 64 | -------------------------------------------------------------------------------- /configs/planet_cem_cheetah.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
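Unlike CEM, which refits a Gaussian to the elite fraction of plans, MPPI reweights every proposal with a softmax over returns. In this config `MPPI.gamma = 0.1` plausibly plays the role of the softmax temperature and `MPPI.beta = [0.6, 0.4, 0.0]` the coefficients for smoothing consecutive sampled actions; that reading is an assumption based on the standard MPPI update, not on the (unshown) `planners.py` implementation:

```
# Textbook MPPI-style reweighting; how gamma is used here is an assumption.
import numpy as np


def mppi_update(plans, scores, gamma=0.1):
  """plans: [proposals, horizon, action_dim]; scores: [proposals]."""
  weights = np.exp((scores - scores.max()) / gamma)
  weights /= weights.sum()
  return np.sum(weights[:, None, None] * plans, axis=0)  # new mean plan
```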
14 | 15 | import world_models.loops.train_eval 16 | import world_models.objectives.objectives 17 | import world_models.planners.planners 18 | import world_models.simulate.simulate 19 | import world_models.tasks.tasks 20 | 21 | include 'planet.gin' 22 | 23 | # Parameters for task: 24 | DeepMindControl.domain_name = 'cheetah' 25 | DeepMindControl.task_name = 'run' 26 | DeepMindControl.max_duration = 1000 27 | DeepMindControl.action_repeat = 4 28 | TASK = @task/singleton() 29 | task/singleton.constructor = @DeepMindControl 30 | 31 | # Parameters for planner: 32 | TensorFlowCEM.horizon = 12 33 | TensorFlowCEM.proposals = 128 34 | TensorFlowCEM.fraction = 0.1 35 | TensorFlowCEM.iterations = 10 36 | TensorFlowCEM.objective_fn = @objectives.TensorFlowDiscountedReward() 37 | TensorFlowCEM.predict_fn = @create_planet_predict_fn() 38 | TensorFlowCEM.observe_fn = @create_planet_observe_fn() 39 | TensorFlowCEM.reset_fn = @create_planet_reset_fn() 40 | TensorFlowCEM.task = %TASK 41 | EVAL_PLANNER = @eval_planner/singleton() 42 | eval_planner/singleton.constructor = @TensorFlowCEM 43 | GaussianRandomNoise.base_planner = %EVAL_PLANNER 44 | GaussianRandomNoise.task = %TASK 45 | GaussianRandomNoise.stdev = 0.3 46 | RandomColdStart.task = %TASK 47 | RandomColdStart.random_episodes = 5 48 | RandomColdStart.base_planner = @GaussianRandomNoise() 49 | TRAIN_PLANNER = @train_planner/singleton() 50 | train_planner/singleton.constructor = @RandomColdStart 51 | 52 | # Parameters for train_eval_loop: 53 | train_eval_loop.task = %TASK 54 | train_eval_loop.train_planner = %TRAIN_PLANNER 55 | train_eval_loop.eval_planner = %EVAL_PLANNER 56 | train_eval_loop.num_train_episodes_per_iteration = 1 57 | train_eval_loop.eval_every_n_iterations = 10 58 | train_eval_loop.num_iterations = 1010 59 | train_eval_loop.model_dir = %model_dir 60 | train_eval_loop.episodes_dir = %episodes_dir 61 | train_eval_loop.train_fn = @create_planet_train_fn() 62 | -------------------------------------------------------------------------------- /configs/rewnet_cem_cheetah.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import world_models.loops.train_eval 16 | import world_models.objectives.objectives 17 | import world_models.planners.planners 18 | import world_models.simulate.simulate 19 | import world_models.tasks.tasks 20 | 21 | include 'rewnet.gin' 22 | 23 | # Parameters for task: 24 | DeepMindControl.domain_name = 'cheetah' 25 | DeepMindControl.task_name = 'run' 26 | DeepMindControl.max_duration = 1000 27 | DeepMindControl.action_repeat = 4 28 | TASK = @task/singleton() 29 | task/singleton.constructor = @DeepMindControl 30 | 31 | # Parameters for planner: 32 | TensorFlowCEM.horizon = 12 33 | TensorFlowCEM.proposals = 128 34 | TensorFlowCEM.fraction = 0.1 35 | TensorFlowCEM.iterations = 10 36 | TensorFlowCEM.objective_fn = @objectives.TensorFlowDiscountedReward() 37 | TensorFlowCEM.predict_fn = @create_planet_predict_fn() 38 | TensorFlowCEM.observe_fn = @create_planet_observe_fn() 39 | TensorFlowCEM.reset_fn = @create_planet_reset_fn() 40 | TensorFlowCEM.task = %TASK 41 | EVAL_PLANNER = @eval_planner/singleton() 42 | eval_planner/singleton.constructor = @TensorFlowCEM 43 | GaussianRandomNoise.base_planner = %EVAL_PLANNER 44 | GaussianRandomNoise.task = %TASK 45 | GaussianRandomNoise.stdev = 0.3 46 | RandomColdStart.task = %TASK 47 | RandomColdStart.random_episodes = 5 48 | RandomColdStart.base_planner = @GaussianRandomNoise() 49 | TRAIN_PLANNER = @train_planner/singleton() 50 | train_planner/singleton.constructor = @RandomColdStart 51 | 52 | # Parameters for train_eval_loop: 53 | train_eval_loop.task = %TASK 54 | train_eval_loop.train_planner = %TRAIN_PLANNER 55 | train_eval_loop.eval_planner = %EVAL_PLANNER 56 | train_eval_loop.num_train_episodes_per_iteration = 1 57 | train_eval_loop.eval_every_n_iterations = 10 58 | train_eval_loop.num_iterations = 1010 59 | train_eval_loop.model_dir = %model_dir 60 | train_eval_loop.episodes_dir = %episodes_dir 61 | train_eval_loop.train_fn = @create_planet_train_fn() 62 | -------------------------------------------------------------------------------- /configs/pure_reward_cem_cheetah.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import world_models.loops.train_eval 16 | import world_models.objectives.objectives 17 | import world_models.planners.planners 18 | import world_models.simulate.simulate 19 | import world_models.tasks.tasks 20 | 21 | include 'pure_reward.gin' 22 | 23 | # Parameters for task: 24 | DeepMindControl.domain_name = 'cheetah' 25 | DeepMindControl.task_name = 'run' 26 | DeepMindControl.max_duration = 1000 27 | DeepMindControl.action_repeat = 4 28 | TASK = @task/singleton() 29 | task/singleton.constructor = @DeepMindControl 30 | 31 | # Parameters for planner: 32 | TensorFlowCEM.horizon = 12 33 | TensorFlowCEM.proposals = 128 34 | TensorFlowCEM.fraction = 0.1 35 | TensorFlowCEM.iterations = 10 36 | TensorFlowCEM.objective_fn = @objectives.TensorFlowDiscountedReward() 37 | TensorFlowCEM.predict_fn = @pure_reward.create_predict_fn() 38 | TensorFlowCEM.observe_fn = @pure_reward.observe_fn 39 | TensorFlowCEM.reset_fn = @pure_reward.reset_fn 40 | TensorFlowCEM.task = %TASK 41 | EVAL_PLANNER = @eval_planner/singleton() 42 | eval_planner/singleton.constructor = @TensorFlowCEM 43 | GaussianRandomNoise.base_planner = %EVAL_PLANNER 44 | GaussianRandomNoise.task = %TASK 45 | GaussianRandomNoise.stdev = 0.3 46 | RandomColdStart.task = %TASK 47 | RandomColdStart.random_episodes = 1 48 | RandomColdStart.base_planner = @GaussianRandomNoise() 49 | TRAIN_PLANNER = @train_planner/singleton() 50 | train_planner/singleton.constructor = @RandomColdStart 51 | 52 | # Parameters for train_eval_loop: 53 | train_eval_loop.task = %TASK 54 | train_eval_loop.train_planner = %TRAIN_PLANNER 55 | train_eval_loop.eval_planner = %EVAL_PLANNER 56 | train_eval_loop.num_train_episodes_per_iteration = 1 57 | train_eval_loop.eval_every_n_iterations = 10 58 | train_eval_loop.num_iterations = 1010 59 | train_eval_loop.model_dir = %model_dir 60 | train_eval_loop.episodes_dir = %episodes_dir 61 | train_eval_loop.train_fn = @pure_reward.create_train_fn() 62 | -------------------------------------------------------------------------------- /configs/sv2p_tfcem_cheetah.gin: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import world_models.loops.train_eval 16 | import world_models.objectives.objectives 17 | import world_models.planners.planners 18 | import world_models.simulate.simulate 19 | import world_models.tasks.tasks 20 | 21 | include 'sv2p.gin' 22 | 23 | # Parameters for task: 24 | DeepMindControl.domain_name = 'cheetah' 25 | DeepMindControl.task_name = 'run' 26 | DeepMindControl.max_duration = 1000 27 | DeepMindControl.action_repeat = 4 28 | TASK = @task/singleton() 29 | task/singleton.constructor = @DeepMindControl 30 | 31 | # Parameters for planner: 32 | TensorFlowCEM.horizon = 12 33 | TensorFlowCEM.proposals = 128 34 | TensorFlowCEM.fraction = 0.1 35 | TensorFlowCEM.iterations = 10 36 | TensorFlowCEM.objective_fn = @objectives.TensorFlowDiscountedReward() 37 | TensorFlowCEM.predict_fn = @sv2p.create_predict_fn() 38 | TensorFlowCEM.observe_fn = @sv2p.create_observe_fn() 39 | TensorFlowCEM.reset_fn = @sv2p.create_reset_fn() 40 | TensorFlowCEM.task = %TASK 41 | EVAL_PLANNER = @eval_planner/singleton() 42 | eval_planner/singleton.constructor = @TensorFlowCEM 43 | GaussianRandomNoise.base_planner = %EVAL_PLANNER 44 | GaussianRandomNoise.task = %TASK 45 | GaussianRandomNoise.stdev = 0.3 46 | RandomColdStart.task = %TASK 47 | RandomColdStart.random_episodes = 1 48 | RandomColdStart.base_planner = @GaussianRandomNoise() 49 | TRAIN_PLANNER = @train_planner/singleton() 50 | train_planner/singleton.constructor = @RandomColdStart 51 | 52 | # Parameters for train_eval_loop: 53 | train_eval_loop.task = %TASK 54 | train_eval_loop.train_planner = %TRAIN_PLANNER 55 | train_eval_loop.eval_planner = %EVAL_PLANNER 56 | train_eval_loop.num_train_episodes_per_iteration = 1 57 | train_eval_loop.eval_every_n_iterations = 10 # 0 disables the evaluation phase 58 | train_eval_loop.num_iterations = 1010 59 | train_eval_loop.model_dir = %model_dir 60 | train_eval_loop.episodes_dir = %episodes_dir 61 | train_eval_loop.train_fn = @sv2p.create_train_fn() 62 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # World Models Library 2 | 3 | World Models is a platform-agnostic library to facilitate visual based 4 | agents for planning. This notebook 5 | ([run it in colab](https://colab.research.google.com/github/google-research/world_models/blob/master/intro.ipynb)) 6 | shows how to use World Models library and its different 7 | components. 8 | 9 | To run locally, use the following command: 10 | 11 | ```$xslt 12 | python3 -m world_models.bin.train_eval \ 13 | --config_path=/path/to/config \ 14 | --output_dir=/path/to/output_dir \ 15 | --logtostderr 16 | ``` 17 | 18 | ## Experiment Results 19 | Below is a summary of our findings. For full discussion please see our paper: 20 | [Models, Pixels, and Rewards: Evaluating Design Trade-offs in Visual Model-Based Reinforcement Learning](https://arxiv.org/abs/2012.04603) 21 | 22 | Is predicting future rewards sufficient for achieving success in visual 23 | model-based reinforcement learning? We experimentally demonstrate that this 24 | is usually **not** the case in the online settings and the key is to 25 | predict future images too. 26 | 27 | ![](https://user-images.githubusercontent.com/4112440/101852808-4ef49000-3b13-11eb-9266-8ea3ed291bd9.gif) 28 | 29 | Amazingly, this also means there is a weak correlation between reward 30 | prediction accuracy and performance of the agent. 
However, we show that there 31 | is a much stronger correlation between image reconstruction error and the 32 | performance of the agent. 33 | 34 | ![](https://user-images.githubusercontent.com/4112440/101852932-9713b280-3b13-11eb-8003-d0080a482872.png) 35 | 36 | We show how this phenomenon is directly related to exploration: models that 37 | fit the data better usually perform better in an *offline* setup. 38 | Surprisingly, these are often not the same models that perform the best 39 | when learning and exploring from scratch! 40 | 41 | ![](https://user-images.githubusercontent.com/4112440/101853015-c32f3380-3b13-11eb-9823-47befb7745ba.jpeg) 42 | 43 | 44 | ## How to Cite 45 | If you use this work, please cite the following paper where it was first introduced: 46 | ``` 47 | @article{2020worldmodels, 48 | title = {Models, Pixels, and Rewards: Evaluating Design Trade-offs in Visual Model-Based Reinforcement Learning}, 49 | author = {Mohammad Babaeizadeh and Mohammad Taghi Saffar and Danijar Hafner and Harini Kannan and Chelsea Finn and Sergey Levine and Dumitru Erhan}, 50 | year = {2020}, 51 | url = {https://arxiv.org/abs/2012.04603} 52 | } 53 | ``` 54 | 55 | You can reach us at wm-core@google.com 56 | ## Dependencies 57 | * absl 58 | * gin-config 59 | * TensorFlow==1.15 60 | * TensorFlow probability==0.7 61 | * gym 62 | * dm_control 63 | * MuJoCo 64 | 65 | Disclaimer: This is not an official Google product. 66 | -------------------------------------------------------------------------------- /bin/train_eval.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Train, simulate and evaluate loop with a model.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | from absl import app 23 | from absl import flags 24 | import gin 25 | import tensorflow.compat.v1 as tf 26 | 27 | from world_models.loops import train_eval 28 | 29 | FLAGS = flags.FLAGS 30 | 31 | flags.DEFINE_string("output_dir", None, "Output directory.") 32 | flags.DEFINE_multi_string( 33 | "config_path", None, 34 | "Newline separated list of paths to a world models gin configs.") 35 | flags.DEFINE_multi_string("config_param", None, 36 | "Newline separated list of Gin parameter bindings.") 37 | flags.DEFINE_bool("enable_eager", True, "Enable eager execution mode.") 38 | flags.DEFINE_integer("num_virtual_gpus", -1, "If >1, enables virtual gpus.") 39 | flags.DEFINE_boolean("offline_train", False, "Train model on offline data.") 40 | flags.DEFINE_string("offline_train_data_dir", None, 41 | "Data dir to be used for offline training.") 42 | 43 | def main(argv): 44 | del argv # Unused 45 | if FLAGS.enable_eager: 46 | tf.enable_eager_execution() 47 | tf.config.set_soft_device_placement(True) 48 | 49 | config_params = FLAGS.config_param or [] 50 | config_params += train_eval.get_gin_override_params(FLAGS.output_dir) 51 | base_config_path = os.path.dirname(FLAGS.config_path[0]) 52 | gin.add_config_file_search_path(base_config_path) 53 | gin.parse_config_files_and_bindings(FLAGS.config_path, config_params) 54 | 55 | if FLAGS.num_virtual_gpus > -1: 56 | gpus = tf.config.experimental.list_physical_devices("GPU") 57 | 58 | total_gpu_mem_limit = 8192 59 | per_gpu_mem_limit = total_gpu_mem_limit / FLAGS.num_virtual_gpus 60 | virtual_gpus = [ 61 | tf.config.experimental.VirtualDeviceConfiguration( 62 | memory_limit=per_gpu_mem_limit) 63 | ] * FLAGS.num_virtual_gpus 64 | tf.config.experimental.set_virtual_device_configuration( 65 | gpus[0], virtual_gpus) 66 | logical_gpus = tf.config.experimental.list_logical_devices("GPU") 67 | print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs") 68 | 69 | train_eval.train_eval_loop( 70 | offline_train=FLAGS.offline_train, 71 | offline_train_data_dir=FLAGS.offline_train_data_dir) 72 | 73 | 74 | if __name__ == "__main__": 75 | flags.mark_flags_as_required(["output_dir"]) 76 | flags.mark_flags_as_required(["config_path"]) 77 | 78 | app.run(main) 79 | 80 | -------------------------------------------------------------------------------- /tasks/tasks.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Implementation of tasks.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | from typing import Text 21 | 22 | from dm_control import suite 23 | import gin 24 | import gym 25 | from world_models.utils import wrappers 26 | 27 | 28 | class Task(object): 29 | """Base class for tasks.""" 30 | 31 | @property 32 | def name(self) -> Text: 33 | raise NotImplementedError 34 | 35 | def create_env(self) -> gym.Env: 36 | raise NotImplementedError 37 | 38 | def create_nonvisual_env(self) -> gym.Env: 39 | raise NotImplementedError 40 | 41 | 42 | @gin.configurable 43 | class DeepMindControl(Task): 44 | """Deepmind Control Suite environment.""" 45 | 46 | def __init__(self, 47 | domain_name: Text = gin.REQUIRED, 48 | task_name: Text = gin.REQUIRED, 49 | camera_id: int = 0, 50 | max_duration: int = 1000, 51 | action_repeat: int = 1): 52 | self._domain_name = domain_name 53 | self._task_name = task_name 54 | self._camera_id = camera_id 55 | self._max_duration = max_duration 56 | self._action_repeat = action_repeat 57 | 58 | @property 59 | def name(self): 60 | return self._domain_name + ":" + self._task_name 61 | 62 | def create_env(self): 63 | env = suite.load(self._domain_name, self._task_name) 64 | env = wrappers.DeepMindEnv(env, camera_id=self._camera_id) 65 | env = wrappers.MaximumDuration(env, duration=self._max_duration) 66 | env = wrappers.ActionRepeat(env, n=self._action_repeat) 67 | env = wrappers.RenderObservation(env) 68 | env = wrappers.ConvertTo32Bit(env) 69 | return env 70 | 71 | def create_nonvisual_env(self): 72 | env = suite.load(self._domain_name, self._task_name) 73 | env = wrappers.DeepMindEnv(env, camera_id=self._camera_id) 74 | env = wrappers.ActionRepeat(env, n=self._action_repeat) 75 | return env 76 | 77 | def __reduce__(self): 78 | args = (self._domain_name, self._task_name, self._camera_id, 79 | self._max_duration, self._action_repeat) 80 | return self.__class__, args 81 | 82 | 83 | @gin.configurable 84 | class Atari(Task): 85 | """ATARI envs from OpenAI gym.""" 86 | 87 | def __init__(self, 88 | game: Text = gin.REQUIRED, 89 | width: int = 64, 90 | height: int = 64, 91 | channels: int = 3, 92 | max_duration: int = 1000, 93 | action_repeat: int = 1): 94 | import atari_py # pylint: disable=unused-import, unused-variable 95 | assert channels == 1 or channels == 3 96 | self._game = game 97 | self._width = width 98 | self._height = height 99 | self._channels = channels 100 | self._max_duration = max_duration 101 | self._action_repeat = action_repeat 102 | 103 | @property 104 | def name(self): 105 | return "atari_%s" % self._game 106 | 107 | def create_env(self): 108 | env = gym.make(self._game) 109 | env = wrappers.ObservationDict(env) 110 | env = wrappers.MaximumDuration(env, duration=self._max_duration) 111 | env = wrappers.ActionRepeat(env, n=self._action_repeat) 112 | env = wrappers.RenderObservation(env) 113 | env = wrappers.ConvertTo32Bit(env) 114 | return env 115 | -------------------------------------------------------------------------------- /utils/npz.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Data utils.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import datetime 22 | import functools 23 | import io 24 | import os 25 | import random 26 | import uuid 27 | 28 | import numpy as np 29 | import tensorflow.compat.v1 as tf 30 | 31 | from world_models.utils import nested 32 | 33 | tfdd = tf.data.Dataset 34 | 35 | # pylint:disable=missing-docstring 36 | 37 | 38 | def save_dictionaries(dictionaries, directory): 39 | for dictionary in dictionaries: 40 | save_dictionary(dictionary, directory) 41 | 42 | 43 | def save_dictionary(dictionary, directory): 44 | """Save a dictionary as npz.""" 45 | timestamp = datetime.datetime.now().strftime('%Y%m%dT%H%M%S') 46 | identifier = str(uuid.uuid4()).replace('-', '') 47 | filename = '{}-{}.npz'.format(timestamp, identifier) 48 | filename = os.path.join(directory, filename) 49 | if not tf.gfile.Exists(directory): 50 | tf.gfile.MakeDirs(directory) 51 | with io.BytesIO() as file_: 52 | np.savez_compressed(file_, **dictionary) 53 | file_.seek(0) 54 | with tf.gfile.Open(filename, 'w') as ff: 55 | ff.write(file_.read()) 56 | return filename 57 | 58 | 59 | def load_dataset_from_directory(directory, 60 | length, 61 | batch, 62 | cache_update_every=1000, 63 | buffer_size=10): 64 | loader = functools.partial(numpy_loader, directory, cache_update_every) 65 | dtypes, shapes = read_spec(loader) 66 | dtypes = {key: tf.as_dtype(value) for key, value in dtypes.items()} 67 | shapes = {key: (None,) + shape[1:] for key, shape in shapes.items()} 68 | chunking = functools.partial(chunk_sequence, length=length) 69 | dataset = tfdd.from_generator(loader, dtypes, shapes) 70 | dataset = dataset.flat_map(chunking) 71 | dataset = dataset.batch(batch, drop_remainder=True) 72 | dataset = dataset.prefetch(buffer_size) 73 | return dataset 74 | 75 | 76 | def numpy_loader(directory, cache_update_every=1000): 77 | """A generator for loading npzs from a directory.""" 78 | cache = {} 79 | while True: 80 | data = _sample(list(cache.values()), cache_update_every) 81 | for dictionary in _permuted(data, cache_update_every): 82 | yield dictionary 83 | directory = os.path.expanduser(directory) 84 | filenames = tf.gfile.Glob(os.path.join(directory, '*.npz')) 85 | filenames = [filename for filename in filenames if filename not in cache] 86 | for filename in filenames: 87 | with tf.gfile.Open(filename, 'rb') as file_: 88 | cache[filename] = dict(np.load(file_)) 89 | 90 | 91 | def _sample(sequence, amount): 92 | amount = min(amount, len(sequence)) 93 | return random.sample(sequence, amount) 94 | 95 | 96 | def _permuted(sequence, amount): 97 | """a generator for `amount` elements from permuted elements in `sequence`.""" 98 | if not sequence: 99 | return 100 | index = 0 101 | while True: 102 | for element in np.random.permutation(sequence): 103 | if index >= amount: 104 | return 105 | yield element 106 | index += 1 107 | 108 | 109 | def read_spec(loader): 110 | dictionaries = loader() 111 | dictionary = next(dictionaries) 112 | dictionaries.close() 113 
| dtypes = {key: value.dtype for key, value in dictionary.items()} 114 | shapes = {key: value.shape for key, value in dictionary.items()} 115 | return dtypes, shapes 116 | 117 | 118 | def chunk_sequence(sequence, length): 119 | """Randomly chunks a sequence into smaller ones. 120 | 121 | This is useful for sampling short videos from long ones. 122 | 123 | Args: 124 | sequence: the original dataset with long sequences. 125 | length: length of the desired chunks. 126 | 127 | Returns: 128 | chuncked dataset. 129 | """ 130 | with tf.device('/cpu:0'): 131 | seq_length = tf.shape(nested.flatten(sequence)[0])[0] 132 | max_offset = seq_length - length 133 | op = tf.Assert(tf.greater_equal(max_offset, 0), data=[length, seq_length]) 134 | with tf.control_dependencies([op]): 135 | offset = tf.random_uniform((), 0, max_offset + 1, dtype=tf.int32) 136 | clipped = nested.map(lambda x: x[offset:offset + length], sequence) 137 | chunks = tfdd.from_tensor_slices( 138 | nested.map( 139 | lambda x: tf.reshape(x, [-1, length] + x.shape[1:].as_list()), 140 | clipped)) 141 | return chunks 142 | -------------------------------------------------------------------------------- /simulate/simulate.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
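The round trip through these utilities is: each episode dictionary is written as a compressed, timestamp-named .npz by `save_dictionary`, and `load_dataset_from_directory` later streams fixed-length chunks of those episodes back as a batched tf.data pipeline (here, batches of 8 chunks of 10 timesteps each). A small usage sketch with illustrative shapes:

```
# Illustrative round trip through the npz utilities above.
import numpy as np
from world_models.utils import npz

episode = {
    "image": np.zeros((50, 64, 64, 3), np.int32),
    "action": np.zeros((50, 6), np.float32),
    "reward": np.zeros((50, 1), np.float32),
}
path = npz.save_dictionary(episode, "/tmp/episodes")

dataset = npz.load_dataset_from_directory("/tmp/episodes", length=10, batch=8)
```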
14 | 15 | """Implements the main logic for running a simulation with a world model.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import time 22 | from typing import List, Dict, Text, Tuple 23 | 24 | from absl import logging 25 | import gin 26 | import numpy as np 27 | from world_models.planners import planners 28 | from world_models.tasks import tasks 29 | 30 | 31 | def get_timestep_data(obs, reward, action): 32 | """Form a dict to hold the info of a single timestep in simulation.""" 33 | timestep = {} 34 | if isinstance(obs, dict): 35 | for key, value in obs.items(): 36 | timestep[key] = value 37 | else: 38 | timestep['image'] = obs 39 | if action.ndim == 0: 40 | action = np.expand_dims(action, axis=-1) 41 | timestep['action'] = action 42 | timestep['reward'] = np.asarray([reward], dtype=np.float32) 43 | return timestep 44 | 45 | 46 | @gin.configurable(blacklist=['episode']) 47 | def approximate_value(episode, gamma=0.99): 48 | last_value = 0.0 49 | for d in reversed(episode): 50 | d['value'] = d['reward'] + gamma*last_value 51 | last_value = d['value'] 52 | return episode 53 | 54 | 55 | def get_prediction_data(prediction, prediction_keys): 56 | """Form a dict to hold predictions of the model for a single timestep.""" 57 | timestep = {} 58 | if 'image' in prediction_keys: 59 | timestep['image'] = prediction.get('image', np.zeros(1)) 60 | if 'reward' in prediction_keys: 61 | timestep['reward'] = prediction.get('reward', np.zeros(1)) 62 | return timestep 63 | 64 | 65 | def preprocess(image, reward): 66 | reward = float(reward) 67 | # TPUs do not support uint8, convert images to int32. 68 | image = image.astype(np.int32) 69 | return image, reward 70 | 71 | 72 | def single_episode(planner, env): 73 | """Simulate a single episode. 74 | 75 | Args: 76 | planner: a `Planner` object that uses a world model for planning. 77 | env: the environment. 78 | 79 | Returns: 80 | episode: a dictionary with `image`, `action` and `reward` keys 81 | and np.ndarray values. may include other keys if the env has additional 82 | information. 
83 | """ 84 | data = [] 85 | 86 | planner.reset() 87 | obs, reward, done = env.reset(), 0.0, False 88 | obs['image'], reward = preprocess(obs['image'], reward) 89 | action = np.zeros(env.action_space.shape, dtype=env.action_space.dtype) 90 | data.append(get_timestep_data(obs, reward, action)) 91 | 92 | prediction_keys = set() 93 | predictions = [] 94 | step = 0 95 | start_time = time.time() 96 | while not done: 97 | action, prediction = planner(obs['image'], action, reward) 98 | obs, reward, done, _ = env.step(action) 99 | obs['image'], reward = preprocess(obs['image'], reward) 100 | data.append(get_timestep_data(obs, reward, action)) 101 | prediction_keys.update(prediction.keys()) 102 | predictions.append(prediction) 103 | step += 1 104 | if step % 10 == 0: 105 | step_per_sec = step / np.float(time.time() - start_time) 106 | logging.info('Environment step %d, step per sec %.4f', step, step_per_sec) 107 | 108 | data = approximate_value(data) 109 | # stack timesteps in the episode to form numpy arrays 110 | episode = { 111 | key: np.stack([d[key] for d in data], axis=0) for key in data[0].keys() 112 | } 113 | predictions = [get_prediction_data(p, prediction_keys) for p in predictions] 114 | predictions = { 115 | key: 116 | np.stack(np.broadcast_arrays(*(p[key] for p in predictions)), axis=0) 117 | for key in prediction_keys 118 | } 119 | 120 | return episode, predictions 121 | 122 | 123 | def simulate( 124 | task: tasks.Task, planner: planners.Planner, num_episodes: int 125 | ) -> Tuple[List[Dict[Text, np.ndarray]], List[Dict[Text, np.ndarray]], float]: 126 | """Simulate the world. 127 | 128 | Args: 129 | task: a `Task` object. 130 | planner: a `Planner` object that uses a world model for planning. 131 | num_episodes: how many episodes to simulate. Each episode continues until it 132 | is done. 133 | 134 | Returns: 135 | episodes: a list of episodes. each episode is a dictionary with `image`, 136 | `action` and `reward` keys and np.ndarray values. may include other 137 | keys if the env has additional information. 138 | predictions: a list of episode predictions. each episode prediction is a 139 | dictionary containing the model predictions at every step of the 140 | environment in that episode. 141 | score: the average score of a complete episode. 142 | """ 143 | env = task.create_env() 144 | episodes = [] 145 | predictions = [] 146 | for i in range(num_episodes): 147 | logging.info('Starting episode %d', i) 148 | episode, prediction = single_episode(planner, env) 149 | episodes.append(episode) 150 | predictions.append(prediction) 151 | 152 | score = np.mean(list(np.sum(episode['reward']) for episode in episodes)) 153 | return episodes, predictions, score 154 | -------------------------------------------------------------------------------- /agents/jax_pure_reward.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | # python3 16 | """A simple example of an agent implemented in Jax.""" 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import functools 22 | 23 | from flax import jax_utils 24 | from flax import nn 25 | from flax import optim 26 | from flax import struct 27 | from flax.training import checkpoints 28 | import gin 29 | import jax 30 | from jax import lax 31 | from jax import numpy as jnp 32 | import numpy as np 33 | 34 | from world_models.utils import npz 35 | 36 | 37 | # pylint:disable=missing-function-docstring 38 | 39 | 40 | def configurable_module(module): 41 | if not issubclass(module, nn.Module): 42 | raise ValueError("this decorator can only be used on flax.nn.Module class.") 43 | 44 | def wrapper(**kwargs): 45 | return module.partial(**kwargs) 46 | 47 | wrapper.__name__ = module.__name__ 48 | 49 | return gin.configurable(wrapper) 50 | 51 | 52 | @struct.dataclass 53 | class TrainState: 54 | step: int 55 | optimizer: optim.Optimizer 56 | 57 | 58 | @configurable_module 59 | class JaxPureReward(nn.Module): 60 | """A simple pure reward predictor with Jax.""" 61 | 62 | def apply(self, actions, num_layers, hidden_dims): 63 | timesteps = actions.shape[1] 64 | # flatten time into batch 65 | actions = jnp.reshape(actions, (-1,) + actions.shape[2:]) 66 | # embed actions 67 | x = nn.Dense(actions, hidden_dims) 68 | for _ in range(num_layers): 69 | x = nn.Dense(x, hidden_dims) 70 | x = nn.LayerNorm(x) 71 | x = nn.relu(x) 72 | x = nn.Dense(x, 1) 73 | x = jnp.reshape(x, (-1, timesteps, 1)) 74 | return x 75 | 76 | 77 | @gin.configurable 78 | def create_model(module, task): 79 | """Initializes the model and returns it.""" 80 | action_space = task.create_env().action_space 81 | rng_key = jax.random.PRNGKey(0) 82 | _, params = module.init_by_shape(rng_key, 83 | [((1, 1) + action_space.shape, jnp.float32)]) 84 | model = nn.Model(module, params) 85 | return model 86 | 87 | 88 | @gin.configurable 89 | def reset_fn(**kwargs): 90 | del kwargs # Unused 91 | return jax_utils.replicate({}) 92 | 93 | 94 | @gin.configurable(blacklist=["images", "actions", "rewards", "state"]) 95 | def observe_fn(images, actions, rewards, state): 96 | del images, actions, rewards # Unused 97 | return state 98 | 99 | 100 | @gin.configurable 101 | def create_predict_fn(model): 102 | @functools.partial(jax.pmap) 103 | def predict(actions): 104 | return model(actions) 105 | 106 | def predict_fn(actions, state): 107 | del state # Unused 108 | actions = jnp.reshape(actions, 109 | (jax.local_device_count(), -1) + actions.shape[-2:]) 110 | predictions = predict(actions) 111 | predictions = jnp.reshape(predictions, (-1,) + predictions.shape[-2:]) 112 | return {"reward": jax.device_get(predictions)} 113 | 114 | return predict_fn 115 | 116 | 117 | @gin.configurable 118 | def create_train_fn(model, model_dir, duration, batch, train_steps, 119 | learning_rate): 120 | optimizer = optim.Adam() 121 | opt = optimizer.create(model) 122 | state = TrainState(optimizer=opt, step=0) # pytype:disable=wrong-keyword-args 123 | state = checkpoints.restore_checkpoint(model_dir, state) 124 | state = jax_utils.replicate(state) 125 | iterator = None 126 | 127 | @functools.partial(jax.pmap, axis_name="batch") 128 | def train_step(obs, state): 129 | actions = obs["action"] 130 | rewards = obs["reward"] 131 | step = state.step 132 | optimizer = state.optimizer 133 | 134 | def loss(model): 135 | predictions = model(actions) 136 | l = (rewards - predictions) ** 2 137 | l 
= jnp.mean(l) 138 | return l 139 | 140 | grad_fn = jax.value_and_grad(loss) 141 | l, grads = grad_fn(state.optimizer.target) 142 | grads = lax.pmean(grads, axis_name="batch") 143 | new_optimizer = optimizer.apply_gradient(grads, learning_rate=learning_rate) 144 | new_state = state.replace(step=step + 1, optimizer=new_optimizer) 145 | return new_state, l 146 | 147 | def train(data_path): 148 | nonlocal iterator 149 | nonlocal state 150 | 151 | if iterator is None: 152 | dataset = npz.load_dataset_from_directory(data_path, duration, batch) 153 | iterator = dataset.make_one_shot_iterator() 154 | iterator = map( 155 | lambda x: jax.tree_map( 156 | lambda x: np.reshape( 157 | x, (jax.local_device_count(), -1) + x.numpy().shape[1:]), 158 | x), 159 | iterator) 160 | iterator = jax_utils.prefetch_to_device(iterator, 2) 161 | 162 | for _ in range(train_steps): 163 | obs = next(iterator) 164 | state, l = train_step(obs, state) 165 | local_state = get_first_device(state) 166 | l = get_first_device(l) 167 | checkpoints.save_checkpoint(model_dir, local_state, local_state.step) 168 | 169 | return train 170 | 171 | 172 | def get_first_device(value): 173 | return jax.tree_map(lambda x: x[0], value) 174 | -------------------------------------------------------------------------------- /imported_models/reward_models.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Reward Models.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow.compat.v1 as tf 22 | import tensorflow.compat.v1.layers as tfl 23 | 24 | from world_models.imported_models import common 25 | from world_models.imported_models import layers 26 | from tensorflow.contrib import layers as tfcl 27 | 28 | 29 | 30 | def reward_prediction_basic(prediction): 31 | """The most simple reward predictor. 32 | 33 | This works by averaging the pixels and running a dense layer on top. 34 | 35 | Args: 36 | prediction: The predicted image. 37 | 38 | Returns: 39 | the predicted reward. 40 | """ 41 | x = prediction 42 | x = tf.reduce_mean(x, axis=[1, 2], keepdims=True) 43 | x = tf.squeeze(x, axis=[1, 2]) 44 | x = tfl.dense(x, 128, activation=tf.nn.relu, name="reward_pred") 45 | return x 46 | 47 | 48 | def reward_prediction_mid(input_images): 49 | """A reward predictor network from intermediate layers. 50 | 51 | The inputs can be any image size (usually the intermediate conv outputs). 52 | The model runs 3 conv layers on top of each with a dense layer at the end. 53 | All of these are combined with 2 additional dense layer. 54 | 55 | Args: 56 | input_images: the input images. size is arbitrary. 57 | 58 | Returns: 59 | the predicted reward. 
60 | """ 61 | encoded = [] 62 | for i, x in enumerate(input_images): 63 | enc = x 64 | enc = tfl.conv2d(enc, 16, [3, 3], strides=(1, 1), activation=tf.nn.relu) 65 | enc = tfl.conv2d(enc, 8, [3, 3], strides=(2, 2), activation=tf.nn.relu) 66 | enc = tfl.conv2d(enc, 4, [3, 3], strides=(2, 2), activation=tf.nn.relu) 67 | enc = tfl.flatten(enc) 68 | enc = tfl.dense(enc, 8, activation=tf.nn.relu, name="rew_enc_%d" % i) 69 | encoded.append(enc) 70 | x = encoded 71 | x = tf.stack(x, axis=1) 72 | x = tfl.flatten(x) 73 | x = tfl.dense(x, 32, activation=tf.nn.relu, name="rew_dense1") 74 | x = tfl.dense(x, 16, activation=tf.nn.relu, name="rew_dense2") 75 | return x 76 | 77 | 78 | def reward_prediction_big(input_images, input_reward, action, latent, 79 | action_injection, small_mode): 80 | """A big reward predictor network that incorporates lots of additional info. 81 | 82 | Args: 83 | input_images: context frames. 84 | input_reward: context rewards. 85 | action: next action. 86 | latent: predicted latent vector for this frame. 87 | action_injection: action injection method. 88 | small_mode: smaller convs for faster runtiume. 89 | 90 | Returns: 91 | the predicted reward. 92 | """ 93 | conv_size = common.tinyify([32, 32, 16, 8], False, small_mode) 94 | 95 | x = tf.concat(input_images, axis=3) 96 | x = tfcl.layer_norm(x) 97 | 98 | if not small_mode: 99 | x = tfl.conv2d( 100 | x, 101 | conv_size[1], [3, 3], 102 | strides=(2, 2), 103 | activation=tf.nn.relu, 104 | name="reward_conv1") 105 | x = tfcl.layer_norm(x) 106 | 107 | # Inject additional inputs 108 | if action is not None: 109 | x = layers.inject_additional_input(x, action, "action_enc", 110 | action_injection) 111 | if input_reward is not None: 112 | x = layers.inject_additional_input(x, input_reward, "reward_enc") 113 | if latent is not None: 114 | latent = tfl.flatten(latent) 115 | latent = tf.expand_dims(latent, axis=1) 116 | latent = tf.expand_dims(latent, axis=1) 117 | x = layers.inject_additional_input(x, latent, "latent_enc") 118 | 119 | x = tfl.conv2d( 120 | x, 121 | conv_size[2], [3, 3], 122 | strides=(2, 2), 123 | activation=tf.nn.relu, 124 | name="reward_conv2") 125 | x = tfcl.layer_norm(x) 126 | x = tfl.conv2d( 127 | x, 128 | conv_size[3], [3, 3], 129 | strides=(2, 2), 130 | activation=tf.nn.relu, 131 | name="reward_conv3") 132 | return x 133 | 134 | 135 | def reward_prediction_video_conv(frames, rewards, prediction_len): 136 | """A reward predictor network from observed/predicted images. 137 | 138 | The inputs is a list of frames. 139 | 140 | Args: 141 | frames: the list of input images. 142 | rewards: previously observed rewards. 143 | prediction_len: the length of the reward vector. 144 | 145 | Returns: 146 | the predicted rewards. 
147 | """ 148 | x = tf.concat(frames, axis=-1) 149 | x = tfl.conv2d(x, 32, [3, 3], strides=(2, 2), activation=tf.nn.relu) 150 | x = tfl.conv2d(x, 32, [3, 3], strides=(2, 2), activation=tf.nn.relu) 151 | x = tfl.conv2d(x, 16, [3, 3], strides=(2, 2), activation=tf.nn.relu) 152 | x = tfl.conv2d(x, 8, [3, 3], strides=(2, 2), activation=tf.nn.relu) 153 | x = tfl.flatten(x) 154 | 155 | y = tf.concat(rewards, axis=-1) 156 | y = tfl.dense(y, 32, activation=tf.nn.relu) 157 | y = tfl.dense(y, 16, activation=tf.nn.relu) 158 | y = tfl.dense(y, 8, activation=tf.nn.relu) 159 | 160 | z = tf.concat([x, y], axis=-1) 161 | z = tfl.dense(z, 32, activation=tf.nn.relu) 162 | z = tfl.dense(z, 16, activation=tf.nn.relu) 163 | z = tfl.dense(z, prediction_len, activation=None) 164 | z = tf.expand_dims(z, axis=-1) 165 | return z 166 | -------------------------------------------------------------------------------- /imported_models/tests_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Utilties for testing video models.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import numpy as np 22 | 23 | import tensorflow.compat.v1 as tf 24 | tf.enable_eager_execution() 25 | 26 | 27 | def fill_hparams(hparams, in_frames, out_frames): 28 | hparams.video_num_input_frames = in_frames 29 | hparams.video_num_target_frames = out_frames 30 | hparams.tiny_mode = True 31 | hparams.reward_prediction = False 32 | return hparams 33 | 34 | 35 | def create_basic_features(in_frames, out_frames, is_training): 36 | video_len = in_frames + out_frames if is_training else in_frames 37 | x = np.random.randint(0, 256, size=(8, video_len, 64, 64, 3)) 38 | features = { 39 | "frames": tf.constant(x, dtype=tf.int32), 40 | } 41 | return features 42 | 43 | 44 | def create_action_features(in_frames, out_frames, is_training): 45 | features = create_basic_features(in_frames, out_frames, is_training) 46 | video_len = in_frames + out_frames # future actions should be present 47 | x = np.random.randint(0, 5, size=(8, video_len, 1)) 48 | features["actions"] = tf.constant(x, dtype=tf.float32) 49 | return features 50 | 51 | 52 | def create_full_features(in_frames, out_frames, is_training): 53 | features = create_action_features(in_frames, out_frames, is_training) 54 | video_len = in_frames + out_frames if is_training else in_frames 55 | x = np.random.randint(0, 5, size=(8, video_len, 1)) 56 | features["rewards"] = tf.constant(x, dtype=tf.float32) 57 | return features 58 | 59 | 60 | def get_shape_list(tensor): 61 | return [d.value for d in tensor.shape] 62 | 63 | 64 | def get_expected_shape(video, expected_len): 65 | shape = get_shape_list(video) 66 | shape[1] = expected_len 67 | return shape 68 | 69 | 70 | class BaseModelTest(tf.test.TestCase): 71 | """Base helper class for next frame tests.""" 72 | 
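  # Usage sketch: a concrete test subclasses this helper and points the Test*
  # drivers at its model class and hparams. The names below are illustrative
  # assumptions, not definitions from this repository:
  #
  #   class MyVideoModelTest(BaseModelTest):
  #
  #     def testShapes(self):
  #       self.TestOnVariousInputOutputSizes(my_hparams(), MyVideoModel)
  #
  #     def testRewards(self):
  #       self.TestWithActionAndRewards(my_hparams(), MyVideoModel)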
73 | def TrainModel(self, model_cls, hparams, features): 74 | model = model_cls(hparams) 75 | tf.train.get_or_create_global_step() 76 | with tf.GradientTape() as tape: 77 | loss = model.train(features) 78 | variables = tape.watched_variables() 79 | grads = tape.gradient(loss, variables) 80 | # Make sure the backward pass works as well. 81 | optimizer = tf.train.AdamOptimizer(1e-3) 82 | optimizer.apply_gradients(zip(grads, variables)) 83 | return loss 84 | 85 | def InferModel(self, model_cls, hparams, features): 86 | model = model_cls(hparams) 87 | tf.train.get_or_create_global_step() 88 | predictions = model.infer(features) 89 | return predictions 90 | 91 | def TestVideoModel(self, in_frames, out_frames, hparams, model): 92 | hparams = fill_hparams(hparams, in_frames, out_frames) 93 | 94 | features = create_basic_features(in_frames, out_frames, True) 95 | loss = self.TrainModel(model, hparams, features) 96 | 97 | self.assertEqual(get_shape_list(loss), [8]) 98 | self.assertEqual(loss.dtype, tf.float32) 99 | 100 | def TestVideoModelInfer(self, in_frames, out_frames, hparams, model): 101 | hparams = fill_hparams(hparams, in_frames, out_frames) 102 | 103 | features = create_basic_features(in_frames, out_frames, False) 104 | output, _ = self.InferModel(model, hparams, features) 105 | 106 | self.assertIsInstance(output, dict) 107 | self.assertIn("frames", output) 108 | expected_shape = get_expected_shape(features["frames"], out_frames) 109 | output_shape = get_shape_list(output["frames"]) 110 | self.assertEqual(output_shape, expected_shape) 111 | 112 | def TestVideoModelWithActions(self, in_frames, out_frames, hparams, model): 113 | hparams = fill_hparams(hparams, in_frames, out_frames) 114 | hparams.reward_prediction = False 115 | 116 | features = create_action_features(in_frames, out_frames, True) 117 | loss = self.TrainModel(model, hparams, features) 118 | 119 | self.assertEqual(get_shape_list(loss), [8]) 120 | self.assertEqual(loss.dtype, tf.float32) 121 | 122 | def TestVideoModelWithActionsInfer(self, in_frames, out_frames, hparams, 123 | model): 124 | hparams = fill_hparams(hparams, in_frames, out_frames) 125 | hparams.reward_prediction = False 126 | 127 | features = create_action_features(in_frames, out_frames, False) 128 | output = self.InferModel(model, hparams, features) 129 | 130 | self.assertIsInstance(output, dict) 131 | self.assertIn("frames", output) 132 | expected_shape = get_expected_shape(features["frames"], out_frames) 133 | output_shape = get_shape_list(output["frames"]) 134 | self.assertEqual(output_shape, expected_shape) 135 | 136 | def TestVideoModelWithActionAndRewards(self, in_frames, out_frames, hparams, 137 | model): 138 | hparams = fill_hparams(hparams, in_frames, out_frames) 139 | hparams.reward_prediction = True 140 | 141 | features = create_full_features(in_frames, out_frames, True) 142 | loss, _ = self.TrainModel(model, hparams, features) 143 | 144 | self.assertEqual(get_shape_list(loss), [8]) 145 | self.assertEqual(loss.dtype, tf.float32) 146 | 147 | def TestVideoModelWithActionAndRewardsInfer(self, in_frames, out_frames, 148 | hparams, model): 149 | hparams = fill_hparams(hparams, in_frames, out_frames) 150 | hparams.reward_prediction = True 151 | 152 | features = create_full_features(in_frames, out_frames, False) 153 | 154 | output = self.InferModel(model, hparams, features) 155 | 156 | self.assertIsInstance(output, dict) 157 | self.assertIn("frames", output) 158 | self.assertIn("rewards", output) 159 | expected_shape = get_expected_shape(features["frames"], 
out_frames) 160 | output_shape = get_shape_list(output["frames"]) 161 | self.assertEqual(output_shape, expected_shape) 162 | expected_shape = get_expected_shape(features["rewards"], out_frames) 163 | output_shape = get_shape_list(output["rewards"]) 164 | self.assertEqual(output_shape, expected_shape) 165 | 166 | def TestOnVariousInputOutputSizes(self, hparams, model): 167 | test_funcs = [self.TestVideoModel] 168 | test_funcs += [self.TestVideoModelInfer] 169 | for test_func in test_funcs: 170 | test_func(1, 1, hparams, model) 171 | test_func(1, 6, hparams, model) 172 | test_func(4, 1, hparams, model) 173 | test_func(7, 5, hparams, model) 174 | 175 | def TestWithActions(self, hparams, model): 176 | test_funcs = [self.TestVideoModelWithActions] 177 | test_funcs += [self.TestVideoModelWithActionsInfer] 178 | for test_func in test_funcs: 179 | test_func(1, 1, hparams, model) 180 | test_func(1, 6, hparams, model) 181 | test_func(4, 1, hparams, model) 182 | test_func(7, 5, hparams, model) 183 | 184 | def TestWithActionAndRewards(self, hparams, model): 185 | test_funcs = [self.TestVideoModelWithActionAndRewards] 186 | test_funcs += [self.TestVideoModelWithActionAndRewardsInfer] 187 | for test_func in test_funcs: 188 | test_func(1, 1, hparams, model) 189 | test_func(1, 6, hparams, model) 190 | test_func(4, 1, hparams, model) 191 | test_func(7, 5, hparams, model) 192 | -------------------------------------------------------------------------------- /utils/nested.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tools for manipulating nested tuples, list, and dictionaries.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | _builtin_zip = zip 22 | _builtin_map = map 23 | _builtin_filter = filter 24 | 25 | 26 | def zip_(*structures, **kwargs): 27 | """Combine corresponding elements in multiple nested structure to tuples. 28 | 29 | The nested structures can consist of any combination of lists, tuples, and 30 | dicts. All provided structures must have the same nesting. 31 | 32 | Args: 33 | *structures: Nested structures. 34 | **kwargs: other args: 35 | "flatten": Whether to flatten the resulting structure into a tuple. 36 | Keys of dictionaries will be discarded. 37 | 38 | Returns: 39 | Nested structure. 40 | """ 41 | # Named keyword arguments are not allowed after *args in Python 2. 42 | flatten_flag = kwargs.pop('flatten', False) 43 | assert not kwargs, 'zip() got unexpected keyword arguments.' 44 | return map( 45 | lambda *x: x if len(x) > 1 else x[0], 46 | *structures, 47 | flatten=flatten_flag) 48 | 49 | 50 | def map_(function, *structures, **kwargs): 51 | """Apply a function to every element in a nested structure. 
52 | 53 | If multiple structures are provided as input, their structure must match and 54 | the function will be applied to corresponding groups of elements. The nested 55 | structure can consist of any combination of lists, tuples, and dicts. 56 | 57 | Args: 58 | function: The function to apply to the elements of the structure. Receives 59 | one argument for every structure that is provided. 60 | *structures: One of more nested structures. 61 | **kwargs: other args: 62 | "flatten": Whether to flatten the resulting structure into a tuple. 63 | Keys of dictionaries will be discarded. 64 | 65 | Returns: 66 | Nested structure. 67 | """ 68 | # Named keyword arguments are not allowed after *args in Python 2. 69 | flatten_flag = kwargs.pop('flatten', False) 70 | assert not kwargs, 'map() got unexpected keyword arguments.' 71 | 72 | def impl(function, *structures): 73 | """impl.""" 74 | if not structures: 75 | return structures 76 | if all(isinstance(s, (tuple, list)) for s in structures): 77 | if len(set(len(x) for x in structures)) > 1: 78 | raise ValueError('Cannot merge tuples or lists of different length.') 79 | args = tuple((impl(function, *x) for x in _builtin_zip(*structures))) 80 | if hasattr(structures[0], '_fields'): # namedtuple 81 | return type(structures[0])(*args) 82 | else: # tuple, list 83 | return type(structures[0])(args) 84 | if all(isinstance(s, dict) for s in structures): 85 | if len(set(frozenset(x.keys()) for x in structures)) > 1: 86 | raise ValueError('Cannot merge dicts with different keys.') 87 | merged = { 88 | k: impl(function, *(s[k] for s in structures)) 89 | for k in structures[0]} 90 | return type(structures[0])(merged) 91 | return function(*structures) 92 | 93 | result = impl(function, *structures) 94 | if flatten_flag: 95 | result = flatten_(result) 96 | return result 97 | 98 | 99 | def flatten_(structure): 100 | """Combine all leaves of a nested structure into a tuple. 101 | 102 | The nested structure can consist of any combination of tuples, lists, and 103 | dicts. Dictionary keys will be discarded but values will ordered by the 104 | sorting of the keys. 105 | 106 | Args: 107 | structure: Nested structure. 108 | 109 | Returns: 110 | Flat tuple. 111 | """ 112 | if isinstance(structure, dict): 113 | result = () 114 | for key in sorted(list(structure.keys())): 115 | result += flatten_(structure[key]) 116 | return result 117 | if isinstance(structure, (tuple, list)): 118 | result = () 119 | for element in structure: 120 | result += flatten_(element) 121 | return result 122 | return (structure,) 123 | 124 | 125 | def filter_(predicate, *structures, **kwargs): 126 | """Select elements of a nested structure based on a predicate function. 127 | 128 | If multiple structures are provided as input, their structure must match and 129 | the function will be applied to corresponding groups of elements. The nested 130 | structure can consist of any combination of lists, tuples, and dicts. 131 | 132 | Args: 133 | predicate: The function to determine whether an element should be kept. 134 | Receives one argument for every structure that is provided. 135 | *structures: One of more nested structures. 136 | **kwargs: other args: 137 | "flatten": Whether to flatten the resulting structure into a tuple. 138 | Keys of dictionaries will be discarded. 139 | 140 | Returns: 141 | Nested structure. 142 | """ 143 | # Named keyword arguments are not allowed after *args in Python 2. 
144 | flatten_flag = kwargs.pop('flatten', False) 145 | assert not kwargs, 'filter() got unexpected keyword arguments.' 146 | 147 | def impl(predicate, *structures): 148 | """impl.""" 149 | if not structures: 150 | return structures 151 | if all(isinstance(s, (tuple, list)) for s in structures): 152 | if len(set(len(x) for x in structures)) > 1: 153 | raise ValueError('Cannot merge tuples or lists of different length.') 154 | # Only wrap in tuples if more than one structure provided. 155 | if len(structures) > 1: 156 | filtered = (impl(predicate, *x) for x in _builtin_zip(*structures)) 157 | else: 158 | filtered = (impl(predicate, x) for x in structures[0]) 159 | # Remove empty containers and construct result structure. 160 | if hasattr(structures[0], '_fields'): # namedtuple 161 | filtered = (x if x != () else None for x in filtered) # pylint: disable=g-explicit-bool-comparison 162 | return type(structures[0])(*filtered) 163 | else: # tuple, list 164 | filtered = ( 165 | x for x in filtered if not isinstance(x, (tuple, list, dict)) or x) 166 | return type(structures[0])(filtered) 167 | if all(isinstance(s, dict) for s in structures): 168 | if len(set(frozenset(x.keys()) for x in structures)) > 1: 169 | raise ValueError('Cannot merge dicts with different keys.') 170 | # Only wrap in tuples if more than one structure provided. 171 | if len(structures) > 1: 172 | filtered = { 173 | k: impl(predicate, *(s[k] for s in structures)) 174 | for k in structures[0]} 175 | else: 176 | filtered = {k: impl(predicate, v) for k, v in structures[0].items()} 177 | # Remove empty containers and construct result structure. 178 | filtered = { 179 | k: v for k, v in filtered.items() 180 | if not isinstance(v, (tuple, list, dict)) or v} 181 | return type(structures[0])(filtered) 182 | if len(structures) > 1: 183 | return structures if predicate(*structures) else () 184 | else: 185 | return structures[0] if predicate(structures[0]) else () 186 | 187 | result = impl(predicate, *structures) 188 | if flatten_flag: 189 | result = flatten_(result) 190 | return result 191 | 192 | 193 | zip = zip_ # pylint: disable=redefined-builtin 194 | map = map_ # pylint: disable=redefined-builtin 195 | flatten = flatten_ # pylint: disable=redefined-builtin 196 | filter = filter_ # pylint: disable=redefined-builtin 197 | 198 | -------------------------------------------------------------------------------- /loops/train_eval.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Train, simulate and evaluate loop with a model.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import json 22 | import os 23 | import time 24 | from typing import Text, Callable 25 | 26 | from absl import flags 27 | from absl import logging 28 | import gin 29 | import numpy as np 30 | import tensorflow.compat.v1 as tf 31 | import tensorflow.compat.v2.summary as tfs 32 | 33 | from world_models.planners import planners 34 | from world_models.simulate import simulate 35 | from world_models.tasks import tasks 36 | from world_models.utils import npz 37 | from world_models.utils import visualization 38 | 39 | 40 | FLAGS = flags.FLAGS 41 | flags.DEFINE_string("tpu", None, "gRPC address of the TPU worker.") 42 | 43 | 44 | @gin.configurable 45 | def get_tpu_strategy(): 46 | """Creates a TPUStrategy for distribution on TPUs.""" 47 | resolver = tf.distribute.cluster_resolver.TPUClusterResolver(FLAGS.tpu) 48 | tf.config.experimental_connect_to_cluster(resolver) 49 | tf.tpu.experimental.initialize_tpu_system(resolver) 50 | strategy = tf.distribute.experimental.TPUStrategy(resolver) 51 | return strategy 52 | 53 | 54 | def visualize(summary_dir, global_step, episodes, predictions, scalars): 55 | """Visualizes the episodes in TensorBoard.""" 56 | if tf.executing_eagerly(): 57 | writer = tfs.create_file_writer(summary_dir) 58 | with writer.as_default(): 59 | videos = np.stack([e["image"] for e in episodes]) 60 | video_summary = visualization.py_gif_summary( 61 | tag="episodes/video", images=videos, max_outputs=20, fps=20) 62 | tfs.experimental.write_raw_pb(video_summary, step=global_step) 63 | for k in scalars: 64 | tfs.scalar(name="episodes/%s" % k, data=scalars[k], step=global_step) 65 | if "image" in predictions[0]: 66 | videos = np.stack([e["image"] for e in predictions]) 67 | video_summary = visualization.py_gif_summary( 68 | tag="episodes/video_prediction", 69 | images=videos, 70 | max_outputs=6, 71 | fps=20) 72 | tfs.experimental.write_raw_pb(video_summary, step=global_step) 73 | if "reward" in predictions[0]: 74 | rewards = np.stack([e["reward"][1:] for e in episodes]) 75 | predicted_rewards = np.stack([p["reward"] for p in predictions]) 76 | signals = np.stack([rewards, predicted_rewards], axis=1) 77 | signals = signals[:, :, :, 0] 78 | visualization.py_plot_1d_signal( 79 | name="episodes/reward", 80 | signals=signals, 81 | labels=["reward", "prediction"], 82 | max_outputs=6, 83 | step=global_step) 84 | reward_dir = os.path.join(summary_dir, "rewards") 85 | rewards_to_save = {"true": rewards, "pred": predicted_rewards} 86 | npz.save_dictionary(rewards_to_save, reward_dir) 87 | else: 88 | summary_writer = tf.summary.FileWriter(summary_dir) 89 | for k in scalars: 90 | s = tf.Summary() 91 | s.value.add(tag="episodes/" + k, simple_value=scalars[k]) 92 | summary_writer.add_summary(s, global_step) 93 | videos = np.stack([e["image"] for e in episodes]) 94 | video_summary = visualization.py_gif_summary( 95 | tag="episodes/video", images=videos, max_outputs=20, fps=30) 96 | summary_writer.add_summary(video_summary, global_step) 97 | summary_writer.flush() 98 | 99 | 100 | def simulate_and_persist(task, planner, num_episodes, data_dir): 101 | """Runs the simulation, persists the results and returns the output.""" 102 | start_time = time.time() 103 | episodes, predictions, score = simulate.simulate( 104 | task=task, planner=planner, num_episodes=num_episodes) 105 | simulate_time = time.time() - 
start_time 106 | npz.save_dictionaries(episodes, data_dir) 107 | return episodes, predictions, score, simulate_time 108 | 109 | 110 | @gin.configurable(blacklist=["offline_train", "offline_train_data_dir"]) 111 | def train_eval_loop(task: tasks.Task = gin.REQUIRED, 112 | train_planner: planners.Planner = gin.REQUIRED, 113 | eval_planner: planners.Planner = gin.REQUIRED, 114 | train_fn: Callable[[Text], None] = gin.REQUIRED, 115 | num_train_episodes_per_iteration: int = 1, 116 | eval_every_n_iterations: int = 1, 117 | num_iterations: int = 1, 118 | episodes_dir: Text = None, 119 | model_dir: Text = None, 120 | offline_train: bool = False, 121 | offline_train_data_dir: Text = None): 122 | """train and eval loop.""" 123 | assert episodes_dir, "episodes_dir is required" 124 | assert model_dir, "model_dir is required" 125 | 126 | # Load iteration info if exists 127 | iterations_info = [] 128 | iterations_datafile = os.path.join(episodes_dir, "info.json") 129 | if tf.io.gfile.exists(iterations_datafile): 130 | with tf.io.gfile.GFile(iterations_datafile, "r") as fp: 131 | iterations_info = json.load(fp) 132 | current_iteration = len(iterations_info) 133 | train_planner.set_episode_num(current_iteration * 134 | num_train_episodes_per_iteration) 135 | 136 | logging.info("Starting the simulation of %s", task.name) 137 | for i in range(current_iteration, num_iterations): 138 | logging.info("=" * 30) 139 | logging.info("Starting Iteration %08d", i) 140 | logging.info("=" * 30) 141 | iteration_info = {"iteration_num": i} 142 | if num_train_episodes_per_iteration: 143 | iteration_start_time = time.time() 144 | if offline_train: 145 | assert offline_train_data_dir, ("offline_train_data_dir is required in " 146 | "offline training mode") 147 | train_dir = offline_train_data_dir 148 | else: 149 | train_dir = os.path.join(episodes_dir, "train") 150 | episodes, predictions, score, simulate_time = simulate_and_persist( 151 | task, train_planner, num_train_episodes_per_iteration, train_dir) 152 | logging.info("Average score during training at iteration %d was: %f", i, 153 | score) 154 | 155 | logging.info("Training model at iteration %d", i) 156 | training_start_time = time.time() 157 | train_fn(train_dir) 158 | train_time = time.time() - training_start_time 159 | iteration_time = time.time() - iteration_start_time 160 | 161 | if not offline_train: 162 | scalars = { 163 | "score": score, 164 | "train_time": train_time, 165 | "simulate_time": simulate_time, 166 | "iteration_time": iteration_time, 167 | "iterations_per_hour": 3600.0 / iteration_time, 168 | } 169 | visualize( 170 | os.path.join(model_dir, "train"), i, episodes, predictions, scalars) 171 | 172 | if eval_every_n_iterations and i % eval_every_n_iterations == 0: 173 | eval_dir = os.path.join(episodes_dir, "eval") 174 | episodes, predictions, score, simulate_time = simulate_and_persist( 175 | task, eval_planner, 1, eval_dir) 176 | logging.info("Average score during evaluation at iteration %d was: %f", i, 177 | score) 178 | 179 | scalars = { 180 | "score": score, 181 | "simulate_time": simulate_time, 182 | } 183 | visualize( 184 | os.path.join(model_dir, "eval"), i, episodes, predictions, scalars) 185 | iterations_info.append(iteration_info) 186 | with tf.io.gfile.GFile(iterations_datafile, "w") as fp: 187 | json.dump(iterations_info, fp) 188 | 189 | 190 | def get_gin_override_params(output_dir): 191 | """Get gin config params to override.""" 192 | model_dir = os.path.join(output_dir, "model") 193 | episodes_dir = os.path.join(output_dir, "episodes") 
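  # These strings are gin bindings: callers typically pass them to
  # gin.parse_config_files_and_bindings() together with the config files so
  # that model_dir/episodes_dir resolve under `output_dir` (bin/eval.py builds
  # the same kind of bindings directly).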
194 | args = [ 195 | "model_dir='%s'" % model_dir, 196 | "episodes_dir='%s'" % episodes_dir, 197 | ] 198 | return args 199 | -------------------------------------------------------------------------------- /utils/visualization.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Gif summary utils.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import io 22 | import subprocess 23 | 24 | import numpy as np 25 | import tensorflow.compat.v1 as tf 26 | import tensorflow.compat.v2.summary as tfs 27 | from tensorflow.python.ops import summary_op_util 28 | 29 | 30 | def encode_gif(images, fps): 31 | """Encodes numpy images into gif string. 32 | 33 | Args: 34 | images: A 5-D `uint8` `np.array` (or a list of 4-D images) of shape 35 | `[batch_size, time, height, width, channels]` where `channels` is 1 or 3. 36 | fps: frames per second of the animation 37 | 38 | Returns: 39 | The encoded gif string. 40 | 41 | Raises: 42 | IOError: If the ffmpeg command returns an error. 43 | """ 44 | ffmpeg = 'ffmpeg' 45 | h, w, c = images[0].shape 46 | cmd = [ 47 | ffmpeg, '-y', '-f', 'rawvideo', '-vcodec', 'rawvideo', '-r', 48 | '%.02f' % fps, '-s', 49 | '%dx%d' % (w, h), '-pix_fmt', { 50 | 1: 'gray', 51 | 3: 'rgb24' 52 | }[c], '-i', '-', '-filter_complex', 53 | '[0:v]split[x][z];[z]palettegen[y];[x]fifo[x];[x][y]paletteuse', '-r', 54 | '%.02f' % fps, '-f', 'gif', '-' 55 | ] 56 | proc = subprocess.Popen( 57 | cmd, 58 | stdin=subprocess.PIPE, 59 | stdout=subprocess.PIPE, 60 | stderr=subprocess.PIPE) 61 | for image in images: 62 | proc.stdin.write(image.tostring()) 63 | out, err = proc.communicate() 64 | if proc.returncode: 65 | err = '\n'.join([' '.join(cmd), err.decode('utf8')]) 66 | raise IOError(err) 67 | del proc 68 | return out 69 | 70 | 71 | def py_gif_summary(tag, images, max_outputs, fps): 72 | """Outputs a `Summary` protocol buffer with gif animations. 73 | 74 | Args: 75 | tag: Name of the summary. 76 | images: A 5-D `uint8` `np.array` of shape `[batch_size, time, height, width, 77 | channels]` where `channels` is 1 or 3. 78 | max_outputs: Max number of batch elements to generate gifs for. 79 | fps: frames per second of the animation 80 | 81 | Returns: 82 | The serialized `Summary` protocol buffer. 83 | 84 | Raises: 85 | ValueError: If `images` is not a 5-D `uint8` array with 1 or 3 channels. 
86 | """ 87 | is_bytes = isinstance(tag, bytes) 88 | if is_bytes: 89 | tag = tag.decode('utf-8') 90 | images = np.asarray(images, dtype=np.uint8) 91 | if images.ndim != 5: 92 | raise ValueError('Tensor must be 5-D for gif summary.') 93 | batch_size, _, height, width, channels = images.shape 94 | if channels not in (1, 3): 95 | raise ValueError('Tensors must have 1 or 3 channels for gif summary.') 96 | summ = tf.Summary() 97 | num_outputs = min(batch_size, max_outputs) 98 | for i in range(num_outputs): 99 | image_summ = tf.Summary.Image() 100 | image_summ.height = height 101 | image_summ.width = width 102 | image_summ.colorspace = channels # 1: grayscale, 3: RGB 103 | try: 104 | image_summ.encoded_image_string = encode_gif(images[i], fps) 105 | except (IOError, OSError) as e: 106 | tf.logging.warning( 107 | 'Unable to encode images to a gif string because either ffmpeg is ' 108 | 'not installed or ffmpeg returned an error: %s. Falling back to an ' 109 | 'image summary of the first frame in the sequence.', e) 110 | try: 111 | from PIL import Image # pylint: disable=g-import-not-at-top 112 | with io.BytesIO() as output: 113 | Image.fromarray(images[i][0]).save(output, 'PNG') 114 | image_summ.encoded_image_string = output.getvalue() 115 | except Exception: # pylint: disable=broad-except 116 | tf.logging.warning( 117 | 'Gif summaries requires ffmpeg or PIL to be installed: %s', e) 118 | image_summ.encoded_image_string = (''.encode('utf-8') 119 | if is_bytes else '') 120 | if num_outputs == 1: 121 | summ_tag = '{}/gif'.format(tag) 122 | else: 123 | summ_tag = '{}/gif/{}'.format(tag, i) 124 | summ.value.add(tag=summ_tag, image=image_summ) 125 | summ_str = summ.SerializeToString() 126 | return summ_str 127 | 128 | 129 | def gif_summary(name, tensor, max_outputs, fps, collections=None, family=None): 130 | """Outputs a `Summary` protocol buffer with gif animations. 131 | 132 | Args: 133 | name: Name of the summary. 134 | tensor: A 5-D `uint8` `Tensor` of shape `[batch_size, time, height, width, 135 | channels]` where `channels` is 1 or 3. 136 | max_outputs: Max number of batch elements to generate gifs for. 137 | fps: frames per second of the animation 138 | collections: Optional list of tf.GraphKeys. The collections to add the 139 | summary to. Defaults to [tf.GraphKeys.SUMMARIES] 140 | family: Optional; if provided, used as the prefix of the summary tag name, 141 | which controls the tab name used for display on Tensorboard. 142 | 143 | Returns: 144 | A scalar `Tensor` of type `string`. The serialized `Summary` protocol 145 | buffer. 
146 | """ 147 | tensor = tf.convert_to_tensor(tensor) 148 | if tensor.dtype in (tf.float32, tf.float64): 149 | tensor = tf.cast(255.0 * tensor, tf.uint8) 150 | with summary_op_util.summary_scope( 151 | name, family, values=[tensor]) as (tag, scope): 152 | val = tf.py_func( 153 | py_gif_summary, [tag, tensor, max_outputs, fps], 154 | tf.string, 155 | stateful=False, 156 | name=scope) 157 | summary_op_util.collect(val, collections, [tf.GraphKeys.SUMMARIES]) 158 | return val 159 | 160 | 161 | def matplot_figure_to_tensor(fig): 162 | buf = io.BytesIO() 163 | fig.savefig(buf, format='png') 164 | buf.seek(0) 165 | return tf.image.decode_png(buf.getvalue(), channels=4) 166 | 167 | 168 | def plot_1d_signals(signals, labels, size): 169 | """Plot a 1d signals and converts into an image tensor.""" 170 | import matplotlib # pylint: disable=g-import-not-at-top 171 | matplotlib.use('Agg') 172 | import matplotlib.pyplot as plt # pylint: disable=g-import-not-at-top 173 | 174 | images = [] 175 | for i in range(size): 176 | fig = plt.figure() 177 | ax = fig.add_subplot(1, 1, 1) 178 | for x, l in zip(signals[i], labels): 179 | ax.plot(x, label=l) 180 | ax.legend() 181 | image = matplot_figure_to_tensor(fig) 182 | plt.close(fig) 183 | images.append(image) 184 | stacked = tf.stack(images, axis=0) 185 | return stacked 186 | 187 | 188 | def py_plot_1d_signal(name, signals, labels, max_outputs=3, step=None): 189 | """Visualizes a list of 1d signals. 190 | 191 | Args: 192 | name: name of the summary. 193 | signals: a [batch, lines, steps] np.array list of 1d arrays. 194 | labels: a [lines] list of labels for each signal. 195 | max_outputs: the maximum number of plots to add to summaries. 196 | step: an explicit step or None. 197 | 198 | Returns: 199 | the summary result. 200 | """ 201 | image = plot_1d_signals(signals, labels, min(max_outputs, signals.shape[0])) 202 | return tfs.image(name, image, step, max_outputs=max_outputs) 203 | 204 | 205 | def tf_plot_1d_signal(name, signals, labels, max_outputs=3, step=None): 206 | """Visualizes a list of 1d signals. 207 | 208 | Args: 209 | name: name of the summary. 210 | signals: a [batch, lines, steps] tensor, each line a 1d signal. 211 | labels: a [lines] list of labels for each signal. 212 | max_outputs: the maximum number of plots to add to summaries. 213 | step: an explicit step or None. 214 | 215 | Returns: 216 | the summary result. 217 | """ 218 | image = tf.py_function( 219 | plot_1d_signals, 220 | (signals, labels, tf.math.minimum(max_outputs, tf.shape(signals)[0])), 221 | tf.uint8) 222 | return tfs.image(name, image, step, max_outputs=max_outputs) 223 | 224 | 225 | def side_by_side_frames(name, tensors): 226 | """Visualizes frames side by side. 227 | 228 | Args: 229 | name: name of the summary. 230 | tensors: a list of video tensors to be merged side by side. 231 | 232 | Returns: 233 | the summary result. 234 | """ 235 | x = tf.concat(tensors, axis=3) 236 | x = tf.concat(tf.unstack(x, axis=1), axis=1) 237 | return tfs.image(name, x) 238 | 239 | 240 | 241 | -------------------------------------------------------------------------------- /agents/pure_reward.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # python3 16 | """A simple forward reward predictor.""" 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | from typing import Callable, Text 23 | 24 | import gin 25 | import gym 26 | import tensorflow.compat.v1 as tf 27 | import tensorflow.compat.v2.keras as tfk 28 | 29 | 30 | from world_models.tasks import tasks 31 | from world_models.utils import npz 32 | 33 | tfkl = tfk.layers 34 | 35 | _model = None 36 | 37 | 38 | def get_model(): 39 | global _model 40 | if _model is None: 41 | _model = PureReward() 42 | _model.compile( 43 | optimizer=tfk.optimizers.Adam(), 44 | loss=tfk.losses.MeanSquaredError(), 45 | metrics=[tfk.metrics.MeanSquaredError()]) 46 | return _model 47 | 48 | 49 | @gin.configurable 50 | class PureReward(tfk.Model): 51 | """Pure reward model.""" 52 | 53 | def __init__(self, 54 | recurrent: bool = gin.REQUIRED, 55 | task: tasks.Task = gin.REQUIRED, 56 | output_length: int = gin.REQUIRED, 57 | model_dir: Text = gin.REQUIRED): 58 | super(PureReward, self).__init__() 59 | self._action_space = task.create_env().action_space 60 | self.output_length = output_length 61 | self.model_dir = model_dir 62 | self.ckpt_file = os.path.join(self.model_dir, "ckpt.hd5") 63 | self.epoch = 0 64 | self.callbacks = [ 65 | tfk.callbacks.TensorBoard( 66 | log_dir=model_dir, write_graph=False, profile_batch=0), 67 | ] 68 | self.recurrent = recurrent 69 | if self.recurrent: 70 | self._init_recurrent_model() 71 | else: 72 | self._init_model() 73 | 74 | def _init_model(self): 75 | x = 16 76 | self.frame_enc = tfk.Sequential([ 77 | tfkl.Conv2D(2 * x, 3, 2, input_shape=(64, 64, 3)), 78 | tfkl.LeakyReLU(), 79 | tfkl.Conv2D(4 * x, 3, 2), 80 | tfkl.LeakyReLU(), 81 | tfkl.Conv2D(1 * x, 3, 2), 82 | tfkl.LeakyReLU(), 83 | tfkl.Flatten(), 84 | tfkl.Dense(1 * x), 85 | ]) 86 | 87 | if self.is_discrete_action: 88 | action_space_size = self._action_space.n 89 | else: 90 | action_space_size = self._action_space.shape[0] 91 | self.action_enc = tfk.Sequential([ 92 | tfkl.Flatten( 93 | input_shape=(self.output_length, action_space_size)), 94 | tfkl.Dense(4 * x), 95 | tfkl.LeakyReLU(), 96 | tfkl.Dense(2 * x), 97 | tfkl.LeakyReLU(), 98 | tfkl.Dense(1 * x), 99 | tfkl.LeakyReLU(), 100 | ]) 101 | 102 | self.reward_pred = tfk.Sequential([ 103 | tfkl.Flatten(input_shape=(2 * x,)), 104 | tfkl.Dense(8 * x), 105 | tfkl.LeakyReLU(), 106 | tfkl.Dropout(0.2), 107 | tfkl.Dense(2 * x), 108 | tfkl.LeakyReLU(), 109 | tfkl.Dropout(0.2), 110 | tfkl.Dense(self.output_length) 111 | ]) 112 | 113 | def _init_recurrent_model(self): 114 | x = 16 115 | self.frame_enc = tfk.Sequential([ 116 | tfkl.Conv2D(2 * x, 3, 2, input_shape=(64, 64, 3)), 117 | tfkl.LeakyReLU(), 118 | tfkl.Conv2D(4 * x, 3, 2), 119 | tfkl.LeakyReLU(), 120 | tfkl.Conv2D(1 * x, 3, 2), 121 | tfkl.LeakyReLU(), 122 | tfkl.Flatten(), 123 | tfkl.Dense(1 * x), 124 | ]) 125 | 126 | if self.is_discrete_action: 127 | action_space_size = self._action_space.n 128 | else: 129 | action_space_size = self._action_space.shape[0] 130 | 
self.action_enc = tfk.Sequential([ 131 | tfkl.Dense(4 * x, input_shape=(self.output_length, action_space_size)), 132 | tfkl.LeakyReLU(), 133 | tfkl.Dense(2 * x), 134 | tfkl.LeakyReLU(), 135 | tfkl.Dense(1 * x), 136 | tfkl.LeakyReLU(), 137 | ]) 138 | 139 | self.reward_pred = tfk.Sequential([ 140 | tfkl.LSTM(256, return_sequences=True, 141 | input_shape=(self.output_length, 2 * x)), 142 | tfkl.LayerNormalization(), 143 | tfkl.LeakyReLU(), 144 | tfkl.LSTM(128, return_sequences=True), 145 | tfkl.LayerNormalization(), 146 | tfkl.LeakyReLU(), 147 | tfkl.LSTM(64, return_sequences=True), 148 | tfkl.LayerNormalization(), 149 | tfkl.LeakyReLU(), 150 | tfkl.Dense(8 * x), 151 | tfkl.LeakyReLU(), 152 | tfkl.Dropout(0.2), 153 | tfkl.Dense(2 * x), 154 | tfkl.LeakyReLU(), 155 | tfkl.Dropout(0.2), 156 | tfkl.Dense(1) 157 | ]) 158 | 159 | @property 160 | def is_discrete_action(self): 161 | return isinstance(self._action_space, gym.spaces.Discrete) 162 | 163 | @tf.function 164 | def preprocess(self, inputs): 165 | frames, actions = inputs 166 | frames = tf.image.convert_image_dtype(frames, tf.float32) 167 | if self.is_discrete_action: 168 | actions = tf.one_hot( 169 | actions[:, :, 0], self._action_space.n, dtype=tf.float32) 170 | else: 171 | actions = tf.to_float(actions) 172 | return frames, actions 173 | 174 | @tf.function 175 | def call(self, inputs): 176 | frames, actions = self.preprocess(inputs) 177 | enc_frame = self.frame_enc(frames) 178 | enc_actions = self.action_enc(actions) 179 | if self.recurrent: 180 | # Add fake time dimension 181 | enc_frame = tf.expand_dims(enc_frame, axis=1) 182 | enc_frame = tf.tile(enc_frame, [1, self.output_length, 1]) 183 | stacked = tf.concat([enc_frame, enc_actions], axis=-1) 184 | output = self.reward_pred(stacked) 185 | if not self.recurrent: 186 | output = tf.expand_dims(output, axis=-1) 187 | return output 188 | 189 | 190 | @gin.configurable 191 | def observe_fn(last_image, last_action, last_reward, state): 192 | """the observe_fn for the model.""" 193 | del last_action, last_reward, state 194 | state = last_image[:, 0] 195 | return state 196 | 197 | 198 | @gin.configurable 199 | def create_predict_fn(batch: int = gin.REQUIRED, proposals: int = gin.REQUIRED): 200 | """Create predict fn.""" 201 | 202 | del batch, proposals 203 | model = get_model() 204 | if tf.io.gfile.exists(model.ckpt_file + ".index"): 205 | model.load_weights(model.ckpt_file) 206 | 207 | @tf.function 208 | def predict_fn(future_action, state): 209 | """A predict_fn for the model. 210 | 211 | Args: 212 | future_action: a [batch, time, action_dims] np array. 213 | state: a dictionary generated by `observe_fn`. 214 | 215 | Returns: 216 | predictions: a dictionary with possibly the following entries: 217 | * "image": [batch, time, height, width, channels] np array. 218 | * "reward": [batch, time] np array. 219 | """ 220 | model = get_model() 221 | rewards = model((state, future_action)) 222 | return {"reward": rewards} 223 | 224 | return predict_fn 225 | 226 | 227 | @gin.configurable 228 | def create_train_fn( 229 | train_steps: int = gin.REQUIRED, 230 | batch: int = gin.REQUIRED, 231 | ) -> Callable[[Text], None]: 232 | """creates a train_fn to train SV2P model referenced in state. 233 | 234 | Args: 235 | train_steps: number of training steps. 236 | batch: the batch size. 237 | 238 | Returns: 239 | A train_fn with the following positional arguments: 240 | * data_path: the path to the directory containing all episodes. 241 | This function returns nothing. 
242 | """ 243 | 244 | iterator = None 245 | 246 | def generator(iterator): 247 | while True: 248 | yield next(iterator) 249 | 250 | def train_fn(data_path: Text, save_rewards: bool = True): 251 | """Training function.""" 252 | nonlocal iterator 253 | model = get_model() 254 | if iterator is None: 255 | duration = 1 + model.output_length 256 | dataset = npz.load_dataset_from_directory(data_path, duration, batch) 257 | dataset = dataset.map(lambda x: ( # pylint: disable=g-long-lambda 258 | (x["image"][:, 0], x["action"][:, 1:]), x["reward"][:, 1:])) 259 | iterator = iter(dataset) 260 | 261 | if tf.io.gfile.exists(model.ckpt_file + ".index"): 262 | model.load_weights(model.ckpt_file) 263 | model.fit_generator( 264 | generator(iterator), 265 | callbacks=model.callbacks, 266 | initial_epoch=model.epoch, 267 | epochs=model.epoch + 1, 268 | steps_per_epoch=train_steps) 269 | 270 | if save_rewards: 271 | reward_dir = os.path.join(model.model_dir, "train_rewards") 272 | x, y = next(iterator) 273 | model.fit(x, y) 274 | p = model(x) 275 | rewards_to_save = {"true": y, "pred": p} 276 | npz.save_dictionary(rewards_to_save, reward_dir) 277 | model.epoch += 1 278 | model.save_weights(model.ckpt_file) 279 | 280 | return train_fn 281 | 282 | 283 | @gin.configurable(blacklist=["state", "batch_size"]) 284 | def reset_fn(**kwargs): 285 | """A reset_fn for SV2P. 286 | 287 | Args: 288 | **kwargs: a dictionary of inputs, including previous state. 289 | 290 | Returns: 291 | a new dictionary with posteriors removed from the state. 292 | """ 293 | return kwargs["state"] 294 | -------------------------------------------------------------------------------- /bin/eval.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Evaluate a world model on an offline dataset.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | from typing import Any, Callable, Dict, Text 23 | 24 | from absl import app 25 | from absl import flags 26 | from absl import logging 27 | import gin 28 | import numpy as np 29 | import tensorflow.compat.v1 as tf 30 | 31 | from world_models.loops import train_eval 32 | from world_models.planners import planners 33 | from world_models.simulate import simulate 34 | from world_models.tasks import tasks 35 | from world_models.utils import npz 36 | 37 | FLAGS = flags.FLAGS 38 | 39 | flags.DEFINE_string("model_dir", None, "Model checkpoint directory.") 40 | flags.DEFINE_string("data_dir", None, "data directory.") 41 | flags.DEFINE_string("output_dir", None, "output directory.") 42 | flags.DEFINE_multi_string( 43 | "config_path", None, 44 | "Newline separated list of paths to a world models gin configs.") 45 | flags.DEFINE_multi_string("config_param", None, 46 | "Newline separated list of Gin parameter bindings.") 47 | flags.DEFINE_bool("enable_eager", True, "Enable eager execution mode.") 48 | flags.DEFINE_integer("num_virtual_gpus", -1, "If >1, enables virtual gpus.") 49 | flags.DEFINE_boolean("train", False, "Train the model on data before eval.") 50 | flags.DEFINE_string("train_data_dir", None, "train data path.") 51 | 52 | 53 | def frame_error(predicted_frames, ground_truth_frames): 54 | """Frame prediction error as average L2 norm between pixels.""" 55 | batch, prediction_horizon = predicted_frames.shape[:2] 56 | return np.mean( 57 | np.linalg.norm( 58 | np.reshape( 59 | np.asarray(predicted_frames, dtype=np.float32), 60 | [batch, prediction_horizon, -1, 1]) - np.reshape( 61 | np.asarray(ground_truth_frames, dtype=np.float32), 62 | [batch, prediction_horizon, -1, 1]), 63 | axis=-1), 64 | axis=-1) 65 | 66 | 67 | def reward_error(predicted_rewards, ground_truth_rewards): 68 | """Reward prediction error as L2 norm.""" 69 | return np.linalg.norm(predicted_rewards - ground_truth_rewards, axis=-1) 70 | 71 | 72 | @gin.configurable( 73 | blacklist=["eval_dir", "train_dir", "model_dir", "result_dir"]) 74 | def offline_evaluate( 75 | predict_fn: Callable[[np.ndarray, Any], Dict[Text, np.ndarray]], 76 | observe_fn: Callable[[np.ndarray, np.ndarray, np.ndarray, Any], Any], 77 | reset_fn: Callable[..., Any], 78 | train_fn: Callable[[Text], None] = None, 79 | train_dir: Text = None, 80 | enable_train: bool = False, 81 | train_eval_iterations: int = 0, 82 | online_eval_task: tasks.Task = None, 83 | online_eval_planner: planners.Planner = None, 84 | online_eval_episodes: int = 0, 85 | eval_dir: Text = None, 86 | model_dir: Text = None, 87 | result_dir: Text = None, 88 | episode_length: int = None, 89 | num_episodes: int = 100, 90 | prediction_horizon: int = 1, 91 | batch: int = 128): 92 | """offline model evaluation.""" 93 | assert eval_dir, "eval_dir is required" 94 | assert model_dir, "model_dir is required" 95 | assert result_dir, "result_dir is required" 96 | assert episode_length, "episode_length is required" 97 | 98 | if enable_train: 99 | assert train_dir, "train_dir is required for training" 100 | assert train_eval_iterations, ("train_eval_iterations is required for " 101 | "training") 102 | for i in range(train_eval_iterations): 103 | train_fn(train_dir) 104 | result_dir_at_step = os.path.join(result_dir, "%d" % i) 105 | eval_once( 106 | result_dir=result_dir_at_step, 107 | 
eval_dir=eval_dir, 108 | episode_length=episode_length, 109 | prediction_horizon=prediction_horizon, 110 | batch=batch, 111 | num_episodes=num_episodes, 112 | reset_fn=reset_fn, 113 | observe_fn=observe_fn, 114 | predict_fn=predict_fn) 115 | if online_eval_episodes: 116 | summary_dir = os.path.join(result_dir, "online_eval") 117 | episodes, predictions, score = simulate.simulate( 118 | online_eval_task, online_eval_planner, online_eval_episodes) 119 | train_eval.visualize(summary_dir, i, episodes, predictions, 120 | {"score": score}) 121 | else: 122 | eval_once( 123 | result_dir=result_dir, 124 | eval_dir=eval_dir, 125 | episode_length=episode_length, 126 | prediction_horizon=prediction_horizon, 127 | batch=batch, 128 | num_episodes=num_episodes, 129 | reset_fn=reset_fn, 130 | observe_fn=observe_fn, 131 | predict_fn=predict_fn) 132 | 133 | 134 | def eval_once(result_dir, eval_dir, episode_length, prediction_horizon, batch, 135 | num_episodes, reset_fn, observe_fn, predict_fn): 136 | """Run offline eval once and store the results in `result_dir`.""" 137 | dataset = npz.load_dataset_from_directory(eval_dir, episode_length, batch) 138 | iterator = dataset.as_numpy_iterator() 139 | 140 | state = None 141 | reward_path = os.path.join(result_dir, "rewards") 142 | reward_error_at_prediction_horizon = np.zeros((prediction_horizon)) 143 | frame_error_at_prediction_horizon = np.zeros((prediction_horizon)) 144 | logging.info("Staring evaluation") 145 | predictions = {} 146 | for b, episodes in enumerate(iterator): 147 | if b * batch >= num_episodes: 148 | break 149 | if episodes["image"].dtype != np.uint8: 150 | episodes["image"] = np.clip(episodes["image"] * 255, 0, 151 | 255).astype(np.uint8) 152 | state = reset_fn(state=state, proposals=batch) 153 | for i in range(episode_length - prediction_horizon): 154 | timestep = {key: value[:, i:i + 1] for key, value in episodes.items()} 155 | frame = timestep["image"] 156 | reward = timestep["reward"] 157 | action = timestep["action"] 158 | future_actions = episodes["action"][:, i:i + prediction_horizon] 159 | future_frames = episodes["image"][:, i:i + prediction_horizon] 160 | future_rewards = episodes["reward"][:, i:i + prediction_horizon] 161 | state = observe_fn(frame, action, reward, state) 162 | predictions = predict_fn(future_actions, state) 163 | if "reward" in predictions: 164 | npz.save_dictionary( 165 | { 166 | "pred": predictions["reward"], 167 | "true": future_rewards 168 | }, reward_path) 169 | reward_error_at_prediction_horizon += np.sum( 170 | reward_error(predictions["reward"], future_rewards), axis=0) 171 | if "image" in predictions: 172 | frame_error_at_prediction_horizon += np.sum( 173 | frame_error(predictions["image"], future_frames), axis=0) 174 | logging.info("Finished evaluation on %d episodes", batch) 175 | 176 | reward_error_at_prediction_horizon /= num_episodes * ( 177 | episode_length - prediction_horizon) 178 | frame_error_at_prediction_horizon /= num_episodes * ( 179 | episode_length - prediction_horizon) 180 | logging.info("Finished evaluation") 181 | results = {} 182 | if "reward" in predictions: 183 | logging.info( 184 | "Average reward L2 norm error for different prediction horizons: %s", 185 | reward_error_at_prediction_horizon) 186 | results["reward_error"] = reward_error_at_prediction_horizon 187 | else: 188 | logging.info("predict_fn does not predict rewards." 
189 | " L2 norm on reward prediction could not be calculated.") 190 | if "image" in predictions: 191 | logging.info( 192 | "Average frame L2 norm error for different prediction horizons: %s", 193 | frame_error_at_prediction_horizon) 194 | results["image_error"] = frame_error_at_prediction_horizon 195 | else: 196 | logging.info("predict_fn does not predict frames." 197 | " L2 norm on frame prediction could not be calculated.") 198 | npz.save_dictionary(results, result_dir) 199 | 200 | 201 | def main(argv): 202 | del argv # Unused 203 | if FLAGS.enable_eager: 204 | tf.enable_eager_execution() 205 | 206 | config_params = FLAGS.config_param or [] 207 | config_params += [ 208 | "model_dir='%s'" % FLAGS.model_dir, 209 | "episodes_dir='%s'" % FLAGS.output_dir 210 | ] 211 | gin.parse_config_files_and_bindings(FLAGS.config_path, config_params) 212 | 213 | if FLAGS.num_virtual_gpus > -1: 214 | gpus = tf.config.experimental.list_physical_devices("GPU") 215 | 216 | total_gpu_mem_limit = 8192 217 | per_gpu_mem_limit = total_gpu_mem_limit / FLAGS.num_virtual_gpus 218 | virtual_gpus = [ 219 | tf.config.experimental.VirtualDeviceConfiguration( 220 | memory_limit=per_gpu_mem_limit) 221 | ] * FLAGS.num_virtual_gpus 222 | tf.config.experimental.set_virtual_device_configuration( 223 | gpus[0], virtual_gpus) 224 | logical_gpus = tf.config.experimental.list_logical_devices("GPU") 225 | print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs") 226 | 227 | offline_evaluate( # pylint:disable=no-value-for-parameter 228 | result_dir=FLAGS.output_dir, 229 | model_dir=FLAGS.model_dir, 230 | eval_dir=FLAGS.data_dir, 231 | train_dir=FLAGS.train_data_dir, 232 | enable_train=FLAGS.train) 233 | 234 | 235 | if __name__ == "__main__": 236 | flags.mark_flags_as_required(["output_dir"]) 237 | flags.mark_flags_as_required(["config_path"]) 238 | app.run(main) 239 | -------------------------------------------------------------------------------- /utils/wrappers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Environment wrappers.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import gym 22 | import numpy as np 23 | from PIL import Image 24 | from world_models.utils import nested 25 | 26 | 27 | class ObservationDict(gym.Wrapper): 28 | """Changes the observation space to be a dict.""" 29 | 30 | def __init__(self, env, key='observ'): 31 | self._key = key 32 | self.env = env 33 | 34 | def __getattr__(self, name): 35 | return getattr(self.env, name) 36 | 37 | @property 38 | def observation_space(self): 39 | spaces = {self._key: self.env.observation_space} 40 | return gym.spaces.Dict(spaces) 41 | 42 | @property 43 | def action_space(self): 44 | return self.env.action_space 45 | 46 | def step(self, action): 47 | obs, reward, done, info = self.env.step(action) 48 | obs = {self._key: np.array(obs)} 49 | return obs, reward, done, info 50 | 51 | def reset(self): 52 | obs = self.env.reset() 53 | obs = {self._key: np.array(obs)} 54 | return obs 55 | 56 | 57 | class ActionRepeat(gym.Wrapper): 58 | """Repeats the same action `n` times and returns the last step results.""" 59 | 60 | def __init__(self, env, n): 61 | super(ActionRepeat, self).__init__(env) 62 | assert n >= 1 63 | self._n = n 64 | 65 | def __getattr__(self, name): 66 | return getattr(self.env, name) 67 | 68 | def step(self, action): 69 | done = False 70 | total_reward = 0 71 | current_step = 0 72 | while current_step < self._n and not done: 73 | observ, reward, done, info = self.env.step(action) 74 | total_reward += reward 75 | current_step += 1 76 | return observ, total_reward, done, info 77 | 78 | 79 | class ActionNormalize(gym.Env): 80 | """Normalizes the action space.""" 81 | 82 | def __init__(self, env): 83 | self._env = env 84 | self._mask = np.logical_and( 85 | np.isfinite(env.action_space.low), np.isfinite(env.action_space.high)) 86 | self._low = np.where(self._mask, env.action_space.low, -1) 87 | self._high = np.where(self._mask, env.action_space.high, 1) 88 | 89 | def __getattr__(self, name): 90 | return getattr(self._env, name) 91 | 92 | @property 93 | def action_space(self): 94 | low = np.where(self._mask, -np.ones_like(self._low), self._low) 95 | high = np.where(self._mask, np.ones_like(self._low), self._high) 96 | return gym.spaces.Box(low, high, dtype=np.float32) 97 | 98 | def step(self, action): 99 | original = (action + 1) / 2 * (self._high - self._low) + self._low 100 | original = np.where(self._mask, original, action) 101 | return self._env.step(original) 102 | 103 | def reset(self): 104 | return self._env.reset() 105 | 106 | def render(self, mode='human'): 107 | return self._env.render(mode=mode) 108 | 109 | 110 | class MaximumDuration(gym.Wrapper): 111 | """Force sets `done` after the specified duration.""" 112 | 113 | def __init__(self, env, duration): 114 | super(MaximumDuration, self).__init__(env) 115 | self._duration = duration 116 | self._step = None 117 | 118 | def __getattr__(self, name): 119 | return getattr(self.env, name) 120 | 121 | def step(self, action): 122 | if self._step is None: 123 | raise RuntimeError('Must reset environment.') 124 | observ, reward, done, info = self.env.step(action) 125 | self._step += 1 126 | if self._step >= self._duration: 127 | done = True 128 | self._step = None 129 | return observ, reward, done, info 130 | 131 | def reset(self): 132 | self._step = 0 133 | return self.env.reset() 134 | 135 | 136 | class MinimumDuration(gym.Wrapper): 137 | """Force resets `done` before the 
specified duration.""" 138 | 139 | def __init__(self, env, duration): 140 | super(MinimumDuration, self).__init__(env) 141 | self._duration = duration 142 | self._step = None 143 | 144 | def __getattr__(self, name): 145 | return getattr(self.env, name) 146 | 147 | def step(self, action): 148 | observ, reward, done, info = self.env.step(action) 149 | self._step += 1 150 | if self._step < self._duration: 151 | done = False 152 | return observ, reward, done, info 153 | 154 | def reset(self): 155 | self._step = 0 156 | return self.env.reset() 157 | 158 | 159 | class ConvertTo32Bit(gym.Wrapper): 160 | """Converts observation and rewards to int/float32.""" 161 | 162 | def __getattr__(self, name): 163 | return getattr(self.env, name) 164 | 165 | def step(self, action): 166 | observ, reward, done, info = self.env.step(action) 167 | observ = nested.map(self._convert_observ, observ) 168 | reward = self._convert_reward(reward) 169 | return observ, reward, done, info 170 | 171 | def reset(self): 172 | observ = self.env.reset() 173 | observ = nested.map(self._convert_observ, observ) 174 | return observ 175 | 176 | def _convert_observ(self, observ): 177 | if not np.isfinite(observ).all(): 178 | raise ValueError('Infinite observation encountered.') 179 | if observ.dtype == np.float64: 180 | return observ.astype(np.float32) 181 | if observ.dtype == np.int64: 182 | return observ.astype(np.int32) 183 | return observ 184 | 185 | def _convert_reward(self, reward): 186 | if not np.isfinite(reward).all(): 187 | raise ValueError('Infinite reward encountered.') 188 | return np.array(reward, dtype=np.float32) 189 | 190 | 191 | class RenderObservation(gym.Env): 192 | """Changes the observation space to rendered frames.""" 193 | 194 | def __init__(self, env, size=(64, 64), dtype=np.uint8, key='image'): 195 | assert isinstance(env.observation_space, gym.spaces.Dict) 196 | self.env = env 197 | self._size = size 198 | self._dtype = dtype 199 | self._key = key 200 | 201 | def __getattr__(self, name): 202 | return getattr(self.env, name) 203 | 204 | @property 205 | def observation_space(self): 206 | high = {np.uint8: 255, np.float: 1.0}[self._dtype] 207 | image = gym.spaces.Box(0, high, self._size + (3,), dtype=self._dtype) 208 | spaces = self.env.observation_space.spaces.copy() 209 | assert self._key not in spaces 210 | spaces[self._key] = image 211 | return gym.spaces.Dict(spaces) 212 | 213 | @property 214 | def action_space(self): 215 | return self.env.action_space 216 | 217 | def step(self, action): 218 | obs, reward, done, info = self.env.step(action) 219 | obs[self._key] = self._render_image() 220 | return obs, reward, done, info 221 | 222 | def reset(self): 223 | obs = self.env.reset() 224 | obs[self._key] = self._render_image() 225 | return obs 226 | 227 | def _render_image(self): 228 | """Renders the environment and processes the image.""" 229 | image = self.env.render('rgb_array') 230 | if image.shape[:2] != self._size: 231 | image = np.array(Image.fromarray(image).resize(self._size)) 232 | if self._dtype and image.dtype != self._dtype: 233 | if image.dtype in (np.float32, np.float64) and self._dtype == np.uint8: 234 | image = (image * 255).astype(self._dtype) 235 | elif image.dtype == np.uint8 and self._dtype in (np.float32, np.float64): 236 | image = image.astype(self._dtype) / 255 237 | else: 238 | message = 'Cannot convert observations from {} to {}.' 
239 | raise NotImplementedError(message.format(image.dtype, self._dtype)) 240 | return image 241 | 242 | class DeepMindEnv(gym.Env): 243 | """Wrapper for deepmind MuJoCo environments to expose gym env methods.""" 244 | metadata = {'render.modes': ['rgb_array']} 245 | reward_range = (-np.inf, np.inf) 246 | 247 | def __init__(self, env, render_size=(64, 64), camera_id=0): 248 | self._env = env 249 | self._render_size = render_size 250 | self._camera_id = camera_id 251 | 252 | def __getattr__(self, name): 253 | return getattr(self._env, name) 254 | 255 | @property 256 | def observation_space(self): 257 | components = {} 258 | for key, value in self._env.observation_spec().items(): 259 | components[key] = gym.spaces.Box( 260 | -np.inf, np.inf, value.shape, dtype=np.float32) 261 | return gym.spaces.Dict(components) 262 | 263 | @property 264 | def action_space(self): 265 | action_spec = self._env.action_spec() 266 | return gym.spaces.Box( 267 | action_spec.minimum, action_spec.maximum, dtype=np.float32) 268 | 269 | def step(self, action): 270 | time_step = self._env.step(action) 271 | obs = dict(time_step.observation) 272 | reward = time_step.reward or 0 273 | done = time_step.last() 274 | info = {'discount': time_step.discount} 275 | return obs, reward, done, info 276 | 277 | def reset(self): 278 | time_step = self._env.reset() 279 | return dict(time_step.observation) 280 | 281 | def render(self, *args, **kwargs): 282 | if kwargs.get('mode', 'rgb_array') != 'rgb_array': 283 | raise ValueError("Only render mode 'rgb_array' is supported.") 284 | del args # Unused 285 | del kwargs # Unused 286 | return self._env.physics.render( 287 | *self._render_size, camera_id=self._camera_id) 288 | 289 | def get_state(self): 290 | return ( 291 | np.array(self.physics.data.qpos), 292 | np.array(self.physics.data.qvel), 293 | np.array(self.physics.data.ctrl), 294 | np.array(self.physics.data.act)) 295 | 296 | def set_state(self, state): 297 | with self.physics.reset_context(): 298 | self.physics.data.qpos[:] = state[0] 299 | self.physics.data.qvel[:] = state[1] 300 | self.physics.data.ctrl[:] = state[2] 301 | self.physics.data.act[:] = state[3] 302 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 
26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. 
If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 
-------------------------------------------------------------------------------- /agents/sv2p.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # python3 16 | """The wrapper for SV2P model.""" 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | from typing import Callable, Text, Tuple 23 | 24 | import gin 25 | import gym 26 | import numpy as np 27 | import tensorflow.compat.v1 as tf 28 | import tensorflow 29 | from world_models.imported_models import sv2p 30 | from world_models.imported_models import sv2p_hparams 31 | from world_models.tasks import tasks 32 | from world_models.utils import npz 33 | 34 | from tensorflow.python.distribute import values 35 | 36 | tfs = tensorflow.compat.v2.summary 37 | 38 | gin.external_configurable(tf.distribute.MirroredStrategy, 39 | "tf.distribute.MirroredStrategy") 40 | 41 | 42 | @gin.configurable 43 | class SV2P(object): 44 | """Wrapper for SV2P.""" 45 | 46 | def __init__(self, 47 | task: tasks.Task = gin.REQUIRED, 48 | input_length: int = gin.REQUIRED, 49 | output_length: int = gin.REQUIRED, 50 | frame_size: Tuple[int] = (64, 64, 3), 51 | include_frames_in_prediction=False, 52 | model_dir: Text = gin.REQUIRED): 53 | self.action_space = task.create_env().action_space 54 | self.frame_size = frame_size 55 | self.input_length = input_length 56 | self.output_length = output_length 57 | self.model_dir = model_dir 58 | self._include_frames_in_prediction = include_frames_in_prediction 59 | self._hparams = sv2p_hparams.sv2p_hparams() 60 | self._hparams.video_num_input_frames = input_length 61 | self._hparams.video_num_target_frames = output_length 62 | self._model = sv2p.SV2P(self._hparams) 63 | self._train = self._model.train 64 | self._infer = self._model.infer 65 | self.optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3, epsilon=1e-3) 66 | self.summary_writer = tfs.create_file_writer(self.model_dir) 67 | 68 | def train(self, features): 69 | return self._train(features) 70 | 71 | def infer(self, features, prediction_len): 72 | """Enable predicting more frames than training width.""" 73 | # Adjust hparams. 74 | hp = self._hparams 75 | hp.latent_num_frames = self.input_length + self.output_length 76 | hp.video_num_target_frames = prediction_len 77 | # Call model. 
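    # (`video_num_target_frames` has been raised to `prediction_len` above, so this single
    #  call produces the full requested rollout even though the model was trained on shorter
    #  clips; `latent_num_frames` is pinned to the training clip length (input + output
    #  frames). Both edits are rolled back immediately after the call.)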
78 | output = self._infer(features) 79 | # Roll-back hparams changes after infer 80 | hp.latent_num_frames = 0 81 | hp.video_num_target_frames = self.output_length 82 | result = {"reward": output["rewards"]} 83 | if self._include_frames_in_prediction: 84 | result["image"] = output["frames"] 85 | return result 86 | 87 | @property 88 | def is_discrete_action(self): 89 | return isinstance(self.action_space, gym.spaces.Discrete) 90 | 91 | def format_actions(self, actions): 92 | if self.is_discrete_action: 93 | return tf.one_hot(actions[:, :, 0], self.action_space.n, dtype=tf.float32) 94 | else: 95 | return tf.to_float(actions) 96 | 97 | def create_features(self, images, rewards, actions): 98 | return { 99 | "frames": images, 100 | "rewards": rewards, 101 | "actions": actions, 102 | } 103 | 104 | def _get_trackables(self, global_step, optimizer): 105 | trackables = self._model.trackables 106 | if global_step is not None: 107 | trackables["global_step"] = global_step 108 | if optimizer is not None: 109 | trackables["optimizer"] = optimizer 110 | return trackables 111 | 112 | def _get_checkpoint_manager(self, global_step, optimizer): 113 | trackables = self._get_trackables(global_step, optimizer) 114 | checkpoint = tf.train.Checkpoint(**trackables) 115 | manager = tf.train.CheckpointManager( 116 | checkpoint, self.model_dir, max_to_keep=1) 117 | return checkpoint, manager 118 | 119 | def restore_latest_checkpoint(self, global_step=None, optimizer=None): 120 | checkpoint, manager = self._get_checkpoint_manager(global_step, optimizer) 121 | checkpoint.restore(manager.latest_checkpoint) 122 | return global_step, optimizer 123 | 124 | def save_checkpoint(self, global_step, optimizer=None): 125 | _, manager = self._get_checkpoint_manager(global_step, optimizer) 126 | manager.save(global_step) 127 | 128 | 129 | @gin.configurable 130 | def create_observe_fn(model=gin.REQUIRED): 131 | """Creates an observe function for SV2P.""" 132 | 133 | @tf.function 134 | def observe_fn(last_image, last_action, last_reward, state): 135 | """the observe_fn for sv2p.""" 136 | last_action = model.format_actions(last_action) 137 | last_reward = tf.to_float(last_reward) 138 | new_state = { 139 | "images": tf.concat([state["images"], last_image], axis=1)[:, 1:], 140 | "actions": tf.concat([state["actions"], last_action], axis=1)[:, 1:], 141 | "rewards": tf.concat([state["rewards"], last_reward], axis=1)[:, 1:] 142 | } 143 | return new_state 144 | 145 | return observe_fn 146 | 147 | 148 | @gin.configurable 149 | def create_predict_fn(model=gin.REQUIRED, 150 | prediction_horizon=gin.REQUIRED, 151 | strategy=gin.REQUIRED): 152 | """Creates a predict function for SV2P.""" 153 | model.restore_latest_checkpoint(global_step=None, optimizer=None) 154 | 155 | @tf.function 156 | def predict_fn(future_action, state): 157 | """A predict_fn for SV2P model referenced in state. 158 | 159 | Args: 160 | future_action: a [batch, time, action_dims] tensor. 161 | state: a dictionary generated by `observe_fn`. 162 | 163 | Returns: 164 | predictions: a dictionary with possibly the following entries: 165 | * "reward": [batch, time, 1] tensor. 166 | """ 167 | future_action = model.format_actions(future_action) 168 | actions = tf.concat((state["actions"], future_action), axis=1) 169 | infer_data = model.create_features(state["images"], state["rewards"], 170 | actions) 171 | 172 | # break down the inputs along the batch dimension to form equal sized 173 | # tensors in each replica. 
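    # Each replica runs `model.infer` on its slice of the batch; the per-replica results are
    # gathered with `experimental_local_results` and concatenated back along the batch axis
    # below, so `predict_fn` still returns full-batch tensors.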
174 | num_replicas = strategy.num_replicas_in_sync 175 | inputs = { 176 | key: tf.split(value, num_replicas) for key, value in infer_data.items() 177 | } 178 | dist_inputs = [] 179 | for i in range(num_replicas): 180 | dist_inputs.append({key: value[i] for key, value in inputs.items()}) 181 | devices = values.ReplicaDeviceMap(strategy.extended.worker_devices) 182 | dist_inputs = values.PerReplica(devices, tuple(dist_inputs)) 183 | dist_predictions = strategy.experimental_run_v2( 184 | model.infer, args=(dist_inputs, prediction_horizon)) 185 | dist_predictions = { 186 | key: strategy.experimental_local_results(value) 187 | for key, value in dist_predictions.items() 188 | } 189 | predictions = { 190 | key: tf.concat(value, axis=0) 191 | for key, value in dist_predictions.items() 192 | } 193 | return predictions 194 | 195 | return predict_fn 196 | 197 | 198 | @gin.configurable 199 | def create_train_fn(train_steps: int = gin.REQUIRED, 200 | batch: int = gin.REQUIRED, 201 | model: SV2P = gin.REQUIRED, 202 | strategy: tf.distribute.Strategy = gin.REQUIRED, 203 | save_rewards: bool = True) -> Callable[[Text], None]: 204 | """creates a train_fn to train SV2P model referenced in state. 205 | 206 | Args: 207 | train_steps: number of training steps. 208 | batch: the batch size. 209 | model: an SV2P model reference. 210 | strategy: a tf.distribute.Strategy object. 211 | save_rewards: whether or not to save the predicted rewards. 212 | 213 | Returns: 214 | A train_fn with the following positional arguments: 215 | * data_path: the path to the directory containing all episodes. 216 | This function returns nothing. 217 | """ 218 | iterator = None 219 | 220 | @tf.function 221 | def train_step(obs): 222 | """Single training step.""" 223 | 224 | def train_iter(obs): 225 | with tf.GradientTape() as tape: 226 | actions = model.format_actions(obs["action"]) 227 | features = model.create_features(obs["image"], obs["reward"], actions) 228 | loss, pred_rewards = model.train(features) 229 | loss = tf.reduce_mean(loss) 230 | variables = tape.watched_variables() 231 | grads = tape.gradient(loss, variables) 232 | grads, _ = tf.clip_by_global_norm(grads, 1000) 233 | model.optimizer.apply_gradients(zip(grads, variables)) 234 | return loss, pred_rewards, obs["reward"] 235 | 236 | return strategy.experimental_run_v2(train_iter, args=(obs,)) 237 | 238 | def train_fn(data_path: Text): 239 | """Training function for SV2P.""" 240 | nonlocal iterator 241 | if iterator is None: 242 | duration = model.input_length + model.output_length 243 | dataset = npz.load_dataset_from_directory(data_path, duration, batch) 244 | dataset = strategy.experimental_distribute_dataset(dataset) 245 | iterator = dataset 246 | 247 | with strategy.scope(): 248 | global_step = tf.train.get_or_create_global_step() 249 | tfs.experimental.set_step(global_step) 250 | global_step, model.optimizer = model.restore_latest_checkpoint( 251 | global_step, model.optimizer) 252 | with model.summary_writer.as_default(), tfs.record_if( 253 | tf.math.equal(tf.math.mod(global_step, 100), 0)): 254 | true_rewards, pred_rewards = None, None 255 | for step, data in enumerate(iterator): 256 | if step > train_steps: 257 | if save_rewards: 258 | # We are only saving the last training batch. 
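        # The per-replica reward tensors are pulled back as numpy arrays, concatenated along
        # the batch axis, and written via `npz.save_dictionary` (keys "true" and "pred")
        # under the "train_rewards" subdirectory of the model directory.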
259 | reward_dir = os.path.join(model.model_dir, "train_rewards") 260 | true_rewards = strategy.experimental_local_results(true_rewards) 261 | pred_rewards = strategy.experimental_local_results(pred_rewards) 262 | true_rewards = np.concatenate([x.numpy() for x in true_rewards]) 263 | pred_rewards = np.concatenate([x.numpy() for x in pred_rewards]) 264 | rewards_to_save = {"true": true_rewards, "pred": pred_rewards} 265 | npz.save_dictionary(rewards_to_save, reward_dir) 266 | break 267 | loss, pred_rewards, true_rewards = train_step(data) 268 | if step % 100 == 0: 269 | loss = strategy.reduce(tf.distribute.ReduceOp.MEAN, loss) 270 | tf.logging.info("Training loss at %d step: %f", step, loss) 271 | global_step.assign_add(1) 272 | 273 | model.save_checkpoint(global_step, model.optimizer) 274 | 275 | return train_fn 276 | 277 | 278 | @gin.configurable 279 | def create_reset_fn(model=gin.REQUIRED): 280 | """Creates a reset_fn function.""" 281 | 282 | @tf.function 283 | def reset_fn(**kwargs): 284 | """A reset_fn for SV2P. 285 | 286 | Args: 287 | **kwargs: a dictionary of inputs, including previous state. 288 | 289 | Returns: 290 | a new dictionary with posteriors removed from the state. 291 | """ 292 | batch_size = kwargs["proposals"] 293 | input_len = model.input_length 294 | image_shape = model.frame_size 295 | if model.is_discrete_action: 296 | action_shape = (model.action_space.n,) 297 | else: 298 | action_shape = model.action_space.shape 299 | action_dtype = tf.float32 300 | return { 301 | "images": 302 | tf.zeros((batch_size, input_len) + image_shape, dtype=tf.int32), 303 | "actions": 304 | tf.zeros( 305 | (batch_size, input_len) + action_shape, 306 | dtype=tf.as_dtype(action_dtype)), 307 | "rewards": 308 | tf.zeros((batch_size, input_len, 1)), 309 | } 310 | 311 | return reset_fn 312 | -------------------------------------------------------------------------------- /imported_models/sv2p.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """SV2P: Stochastic Variational Video Prediction. 16 | 17 | based on the following paper: 18 | https://arxiv.org/abs/1710.11252 19 | by Mohammad Babaeizadeh, Chelsea Finn, Dumitru Erhan, 20 | Roy H. 
Campbell and Sergey Levine 21 | """ 22 | 23 | from __future__ import absolute_import 24 | from __future__ import division 25 | from __future__ import print_function 26 | 27 | import tensorflow.compat.v1 as tf 28 | import tensorflow 29 | 30 | from world_models.imported_models import common 31 | from world_models.imported_models import layers 32 | from world_models.imported_models import reward_models 33 | from world_models.utils import visualization 34 | from tensorflow.contrib import layers as contrib_layers 35 | 36 | tfl = tf.layers 37 | tfcl = contrib_layers 38 | tfs = tensorflow.compat.v2.summary 39 | 40 | 41 | class SV2P(object): 42 | """Stochastic Variational Video Prediction From Basic Model!""" 43 | 44 | def __init__(self, hparams): 45 | self.hparams = hparams 46 | if hparams.merged_reward_model: 47 | self._body = tf.make_template("body", self.__body_merged_model) 48 | else: 49 | self._body = tf.make_template("body", self.__body_separate_model) 50 | self.infer = tf.make_template("infer", self.__infer) 51 | self.train = tf.make_template("train", self.__train) 52 | self.loss = tf.make_template("loss", self.__loss) 53 | 54 | self.posterior = tf.make_template("posterior", self.conv_latent_tower) 55 | self.prior = tf.make_template("prior", self.gaussian_prior) 56 | self.latent_tow = self.posterior if self.hparams.stochastic else self.prior 57 | 58 | def __train(self, features): 59 | """Trains the model.""" 60 | self.is_training = True 61 | frames, actions, rewards = self.extract_features(features) 62 | latent_mean, latent_logvar = self.latent_tow(features["frames"]) 63 | latent = layers.get_gaussian_tensor(latent_mean, latent_logvar) 64 | extra_loss = layers.kl_divergence(latent_mean, latent_logvar) 65 | pred_frames, pred_rewards = self._body(frames, actions, rewards, latent) 66 | predictions = { 67 | "frames": tf.stack(pred_frames[:-1], axis=1), 68 | "rewards": tf.stack(pred_rewards[:-1], axis=1) 69 | } 70 | return self.loss(predictions, features, extra_loss), predictions["rewards"] 71 | 72 | def __infer(self, features): 73 | """Produce predictions from the model by running it.""" 74 | self.is_training = False 75 | frames, actions, rewards = self.extract_features(features) 76 | mean, logvar = self.gaussian_prior(features["frames"]) 77 | latent = layers.get_gaussian_tensor(mean, logvar) 78 | 79 | pred_frames, pred_rewards = self._body(frames, actions, rewards, latent) 80 | 81 | extra_predicted_frames = self.hparams.video_num_input_frames 82 | predictions = { 83 | "frames": tf.stack(pred_frames[extra_predicted_frames:], axis=1), 84 | "rewards": tf.stack(pred_rewards[extra_predicted_frames:], axis=1) 85 | } 86 | return predictions 87 | 88 | def __loss(self, predictions, features, extra_loss): 89 | """Calculates the loss.""" 90 | reward_loss = tf.constant(0.0, tf.float32) 91 | 92 | pred = predictions["rewards"] 93 | loss_func = tf.keras.losses.MSE 94 | targ = features["rewards"][:, 1:] 95 | reward_loss += tf.reduce_mean(tfl.flatten(loss_func(targ, pred)), axis=-1) 96 | 97 | pred = predictions["frames"] 98 | targ = tf.image.convert_image_dtype(features["frames"][:, 1:], tf.float32) 99 | loss_func = tf.keras.losses.MSE 100 | frames_loss = tf.reduce_mean(tfl.flatten(loss_func(targ, pred)), axis=-1) 101 | 102 | total_loss = ( 103 | frames_loss + reward_loss * self.hparams.loss_reward_multiplier + 104 | extra_loss * self.hparams.loss_extra_multiplier) 105 | 106 | tfs.scalar("loss/total", tf.reduce_mean(total_loss)) 107 | tfs.scalar("loss/frames", tf.reduce_mean(frames_loss)) 108 | 
tfs.scalar("loss/reward", tf.reduce_mean(reward_loss)) 109 | tfs.scalar("loss/extra", tf.reduce_mean(extra_loss)) 110 | visualization.side_by_side_frames("vis/frames", [targ, pred]) 111 | 112 | return total_loss 113 | 114 | def extract_features(self, features): 115 | frames = tf.unstack(features["frames"], axis=1) 116 | actions = tf.unstack(features["actions"], axis=1) 117 | rewards = tf.unstack(features["rewards"], axis=1) 118 | return frames, actions, rewards 119 | 120 | @property 121 | def trackables(self): 122 | return {"model": self._body} 123 | 124 | @property 125 | def video_len(self): 126 | hp = self.hparams 127 | return hp.video_num_input_frames + hp.video_num_target_frames 128 | 129 | def get_iteration_num(self): 130 | return tf.train.get_global_step() 131 | 132 | def reward_prediction(self, mid_outputs): 133 | """Select reward predictor based on hparams.""" 134 | x = reward_models.reward_prediction_mid(mid_outputs) 135 | x = tfl.flatten(x) 136 | x = tfl.dense( 137 | x, 138 | self.hparams.reward_prediction_classes, 139 | activation=None, 140 | name="reward_map") 141 | return x 142 | 143 | def upsample(self, x, num_outputs, strides): 144 | x = tfl.conv2d_transpose( 145 | x, num_outputs, (3, 3), strides=strides, activation=tf.nn.relu) 146 | return x[:, 1:, 1:, :] 147 | 148 | def shape_list(self, x): 149 | return x.shape.as_list() 150 | 151 | def inject_additional_input(self, layer, inputs, name): 152 | """Injects the additional input into the layer. 153 | 154 | Args: 155 | layer: layer that the input should be injected to. 156 | inputs: inputs to be injected. 157 | name: TF scope name. 158 | 159 | Returns: 160 | updated layer. 161 | 162 | Raises: 163 | ValueError: in case of unknown mode. 164 | """ 165 | inputs = common.to_float(inputs) 166 | layer_shape = self.shape_list(layer) 167 | emb = layers.encode_to_shape(inputs, layer_shape, name) 168 | layer = tf.concat(values=[layer, emb], axis=-1) 169 | return layer 170 | 171 | def bottom_part_tower(self, input_image, action, reward, latent, lstm_state, 172 | lstm_size, conv_size): 173 | """The bottom part of predictive towers. 174 | 175 | With the current (early) design, the main prediction tower and 176 | the reward prediction tower share the same architecture. TF Scope can be 177 | adjusted as required to either share or not share the weights between 178 | the two towers. 179 | 180 | Args: 181 | input_image: the current image. 182 | action: the action taken by the agent. 183 | reward: the previous reward, observed or predicted. 184 | latent: the latent vector. 185 | lstm_state: the current internal states of conv lstms. 186 | lstm_size: the size of lstms. 187 | conv_size: the size of convolutions. 188 | 189 | Returns: 190 | - the output of the partial network. 191 | - intermediate outputs for skip connections.
192 | """ 193 | lstm_func = layers.conv_lstm_2d 194 | input_image = layers.make_even_size(input_image) 195 | 196 | layer_id = 0 197 | enc0 = tfl.conv2d( 198 | input_image, 199 | conv_size[0], [5, 5], 200 | strides=(2, 2), 201 | activation=tf.nn.relu, 202 | padding="SAME", 203 | name="scale1_conv1") 204 | enc0 = tfcl.layer_norm(enc0, scope="layer_norm1") 205 | 206 | hidden1, lstm_state[layer_id] = lstm_func( 207 | enc0, lstm_state[layer_id], lstm_size[layer_id], name="state1") 208 | hidden1 = tfcl.layer_norm(hidden1, scope="layer_norm2") 209 | layer_id += 1 210 | 211 | hidden2, lstm_state[layer_id] = lstm_func( 212 | hidden1, lstm_state[layer_id], lstm_size[layer_id], name="state2") 213 | hidden2 = tfcl.layer_norm(hidden2, scope="layer_norm3") 214 | hidden2 = layers.make_even_size(hidden2) 215 | enc1 = tfl.conv2d( 216 | hidden2, 217 | hidden2.get_shape()[3], [3, 3], 218 | strides=(2, 2), 219 | padding="SAME", 220 | activation=tf.nn.relu, 221 | name="conv2") 222 | layer_id += 1 223 | 224 | hidden3, lstm_state[layer_id] = lstm_func( 225 | enc1, lstm_state[layer_id], lstm_size[layer_id], name="state3") 226 | hidden3 = tfcl.layer_norm(hidden3, scope="layer_norm4") 227 | layer_id += 1 228 | 229 | hidden4, lstm_state[layer_id] = lstm_func( 230 | hidden3, lstm_state[layer_id], lstm_size[layer_id], name="state4") 231 | hidden4 = tfcl.layer_norm(hidden4, scope="layer_norm5") 232 | hidden4 = layers.make_even_size(hidden4) 233 | enc2 = tfl.conv2d( 234 | hidden4, 235 | hidden4.get_shape()[3], [3, 3], 236 | strides=(2, 2), 237 | padding="SAME", 238 | activation=tf.nn.relu, 239 | name="conv3") 240 | layer_id += 1 241 | 242 | enc2 = self.inject_additional_input(enc2, action, "action_enc") 243 | if reward is not None: 244 | enc2 = self.inject_additional_input(enc2, reward, "reward_enc") 245 | with tf.control_dependencies([latent]): 246 | enc2 = tf.concat([enc2, latent], axis=3) 247 | 248 | enc3 = tfl.conv2d( 249 | enc2, 250 | hidden4.get_shape()[3], [1, 1], 251 | strides=(1, 1), 252 | padding="SAME", 253 | activation=tf.nn.relu, 254 | name="conv4") 255 | 256 | hidden5, lstm_state[layer_id] = lstm_func( 257 | enc3, lstm_state[layer_id], lstm_size[layer_id], name="state5") 258 | hidden5 = tfcl.layer_norm(hidden5, scope="layer_norm6") 259 | layer_id += 1 260 | return hidden5, (enc0, enc1), layer_id 261 | 262 | def construct_predictive_tower(self, input_image, action, reward, lstm_state, 263 | latent): 264 | """Main prediction tower.""" 265 | lstm_func = layers.conv_lstm_2d 266 | frame_shape = self.shape_list(input_image) 267 | _, img_height, img_width, color_channels = frame_shape 268 | batch_size = tf.shape(input_image)[0] 269 | # the number of different pixel motion predictions 270 | # and the number of masks for each of those predictions 271 | num_masks = self.hparams.num_masks 272 | 273 | lstm_size = [32, 32, 64, 64, 128, 64, 32] 274 | conv_size = [32] 275 | 276 | with tf.variable_scope("bottom", reuse=tf.AUTO_REUSE): 277 | hidden5, skips, layer_id = self.bottom_part_tower(input_image, action, 278 | reward, latent, 279 | lstm_state, lstm_size, 280 | conv_size) 281 | enc0, enc1 = skips 282 | 283 | enc4 = self.upsample(hidden5, self.shape_list(hidden5)[-1], [2, 2]) 284 | 285 | enc1_shape = self.shape_list(enc1) 286 | enc4 = enc4[:, :enc1_shape[1], :enc1_shape[2], :] # Cut to shape. 
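    # Crop the upsampled map to enc1's spatial size so that the skip-connection
    # concatenation with enc1 below has matching shapes.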
287 | 288 | hidden6, lstm_state[layer_id] = lstm_func( 289 | enc4, 290 | lstm_state[layer_id], 291 | lstm_size[5], 292 | name="state6", 293 | spatial_dims=enc1_shape[1:-1]) # 16x16 294 | hidden6 = tfcl.layer_norm(hidden6, scope="layer_norm7") 295 | # Skip connection. 296 | hidden6 = tf.concat(axis=3, values=[hidden6, enc1]) # both 16x16 297 | layer_id += 1 298 | 299 | enc5 = self.upsample(hidden6, self.shape_list(hidden6)[-1], [2, 2]) 300 | 301 | enc0_shape = self.shape_list(enc0) 302 | hidden7, lstm_state[layer_id] = lstm_func( 303 | enc5, 304 | lstm_state[layer_id], 305 | lstm_size[6], 306 | name="state7", 307 | spatial_dims=enc0_shape[1:-1]) # 32x32 308 | hidden7 = tfcl.layer_norm(hidden7, scope="layer_norm8") 309 | layer_id += 1 310 | 311 | # Skip connection. 312 | hidden7 = tf.concat(axis=3, values=[hidden7, enc0]) # both 32x32 313 | 314 | enc6 = self.upsample(hidden7, self.shape_list(hidden7)[-1], [2, 2]) 315 | enc6 = tfcl.layer_norm(enc6, scope="layer_norm9") 316 | 317 | enc7 = tfl.conv2d_transpose( 318 | enc6, 319 | color_channels, [1, 1], 320 | strides=(1, 1), 321 | padding="SAME", 322 | name="convt4", 323 | activation=None) 324 | # This allows the network to also generate one image from scratch, 325 | # which is useful when regions of the image become unoccluded. 326 | transformed = [tf.nn.sigmoid(enc7)] 327 | 328 | cdna_input = tfl.flatten(hidden5) 329 | transformed += layers.cdna_transformation(input_image, cdna_input, 330 | num_masks, int(color_channels), 331 | self.hparams.dna_kernel_size, 332 | self.hparams.relu_shift) 333 | 334 | masks = tfl.conv2d( 335 | enc6, 336 | filters=num_masks + 1, 337 | kernel_size=[1, 1], 338 | strides=(1, 1), 339 | name="convt7", 340 | padding="SAME") 341 | masks = masks[:, :img_height, :img_width, ...] 342 | shape = tf.stack( 343 | [batch_size, int(img_height), 344 | int(img_width), num_masks + 1]) 345 | masks = tf.reshape( 346 | tf.nn.softmax(tf.reshape(masks, [-1, num_masks + 1])), shape) 347 | mask_list = tf.split(axis=3, num_or_size_splits=num_masks + 1, value=masks) 348 | output = mask_list[0] * input_image 349 | for layer, mask in zip(transformed, mask_list[1:]): 350 | output += layer * mask 351 | 352 | mid_outputs = [enc0, enc1, enc4, enc5, enc6] 353 | return output, lstm_state, mid_outputs 354 | 355 | def gaussian_prior(self, images): 356 | batch_size = tf.shape(images)[0] 357 | # TODO(mbz): this only works for 64x64 image size. 358 | assert images.shape[2] == 64 359 | assert images.shape[3] == 64 360 | shape = tf.stack([batch_size, 8, 8, 1]) 361 | return tf.zeros(shape), tf.zeros(shape) 362 | 363 | def conv_latent_tower(self, images): 364 | """Builds convolutional latent tower for stochastic model. 365 | 366 | At training time this tower generates a latent distribution (mean and std) 367 | conditioned on the entire video. This latent variable will be fed to the 368 | main tower as an extra variable to be used for future frames prediction. 369 | At inference time, the tower is disabled and only returns latents sampled 370 | from N(0,1). 371 | If the multi_latent flag is on, a different latent for every timestep would 372 | be generated. 
373 | 374 | Args: 375 | images: tensor of ground truth image sequences 376 | 377 | Returns: 378 | latent_mean: predicted latent mean 379 | latent_logvar: predicted latent log variance 380 | """ 381 | conv_size = [32, 64, 64] 382 | latent_channels = self.hparams.latent_channels 383 | min_logvar = self.hparams.latent_min_logvar 384 | images = tf.concat(tf.unstack(common.to_float(images), axis=1), axis=-1) 385 | with tf.variable_scope("latent", reuse=tf.AUTO_REUSE): 386 | x = images 387 | x = tfl.conv2d( 388 | x, 389 | conv_size[0], [3, 3], 390 | strides=(2, 2), 391 | padding="SAME", 392 | activation=tf.nn.relu, 393 | name="latent_conv1") 394 | x = tfcl.layer_norm(x) 395 | x = tfl.conv2d( 396 | x, 397 | conv_size[1], [3, 3], 398 | strides=(2, 2), 399 | padding="SAME", 400 | activation=tf.nn.relu, 401 | name="latent_conv2") 402 | x = tfcl.layer_norm(x) 403 | x = tfl.conv2d( 404 | x, 405 | conv_size[2], [3, 3], 406 | strides=(1, 1), 407 | padding="SAME", 408 | activation=tf.nn.relu, 409 | name="latent_conv3") 410 | x = tfcl.layer_norm(x) 411 | 412 | nc = latent_channels 413 | mean = tfl.conv2d( 414 | x, 415 | nc, [3, 3], 416 | strides=(2, 2), 417 | padding="SAME", 418 | activation=None, 419 | name="latent_mean") 420 | logv = tfl.conv2d( 421 | x, 422 | nc, [3, 3], 423 | strides=(2, 2), 424 | padding="SAME", 425 | activation=tf.nn.relu, 426 | name="latent_std") 427 | logvar = logv + min_logvar 428 | return mean, logvar 429 | 430 | def scheduled_sample_prob(self, gt_frame, pred_frame): 431 | prob = tf.math.divide_no_nan( 432 | common.to_float(self.get_iteration_num()), 433 | common.to_float(self.hparams.scheduled_sampling_iterations)) 434 | prob = tf.nn.relu(1.0 - prob) 435 | return tf.cond( 436 | tf.math.less_equal(tf.random.uniform([]), prob), lambda: gt_frame, 437 | lambda: pred_frame) 438 | 439 | def build_merged_model(self, all_frames, all_actions, all_rewards, latent): 440 | """Main video processing function.""" 441 | hparams = self.hparams 442 | 443 | res_frames, res_rewards = [], [] 444 | internal_states = [None] * 7 445 | 446 | pred_image = all_frames[0] 447 | pred_reward = all_rewards[0] 448 | for i in range(self.video_len): 449 | cur_action = all_actions[i] 450 | 451 | done_warm_start = (i >= hparams.video_num_input_frames) 452 | if done_warm_start: 453 | if self.is_training: 454 | cur_frame = self.scheduled_sample_prob(all_frames[i], pred_image) 455 | cur_reward = self.scheduled_sample_prob(all_rewards[i], pred_reward) 456 | else: 457 | cur_frame = pred_image 458 | cur_reward = pred_reward 459 | else: 460 | cur_frame = all_frames[i] 461 | cur_reward = all_rewards[i] 462 | 463 | with tf.variable_scope("main", reuse=tf.AUTO_REUSE): 464 | pred_image, internal_states, mids = self.construct_predictive_tower( 465 | cur_frame, cur_action, cur_reward, internal_states, latent) 466 | if hparams.reward_model_stop_gradient: 467 | mids = [tf.stop_gradient(x) for x in mids] 468 | pred_reward = self.reward_prediction(mids) 469 | 470 | res_frames.append(pred_image) 471 | res_rewards.append(pred_reward) 472 | 473 | return [res_frames, res_rewards] 474 | 475 | def build_video_model(self, all_frames, all_actions, latent): 476 | """Main video processing function.""" 477 | hparams = self.hparams 478 | 479 | res_frames = [] 480 | internal_states = [None] * 7 481 | 482 | pred_image = all_frames[0] 483 | for i in range(self.video_len): 484 | cur_action = all_actions[i] 485 | 486 | done_warm_start = (i >= hparams.video_num_input_frames) 487 | if done_warm_start: 488 | if self.is_training: 489 | cur_frame = 
self.scheduled_sample_prob(all_frames[i], pred_image) 490 | else: 491 | cur_frame = pred_image 492 | else: 493 | cur_frame = all_frames[i] 494 | 495 | with tf.variable_scope("main", reuse=tf.AUTO_REUSE): 496 | pred_image, internal_states, _ = self.construct_predictive_tower( 497 | cur_frame, cur_action, None, internal_states, latent) 498 | res_frames.append(pred_image) 499 | return res_frames 500 | 501 | def build_reward_model(self, frames, rewards): 502 | frames = frames[:self.video_len - 1] 503 | res_rewards = reward_models.reward_prediction_video_conv( 504 | frames, rewards, self.video_len) 505 | return tf.unstack(res_rewards, axis=1) 506 | 507 | def preprocess(self, frames, actions, rewards): 508 | frames = [tf.image.convert_image_dtype(x, tf.float32) for x in frames] 509 | return frames, actions, rewards 510 | 511 | def __body_merged_model(self, frames, actions, rewards, latent): 512 | """Body function.""" 513 | frames, actions, rewards = self.preprocess(frames, actions, rewards) 514 | res_frames, res_rewards = self.build_merged_model(frames, actions, rewards, 515 | latent) 516 | return res_frames, res_rewards 517 | 518 | def __body_separate_model(self, frames, actions, rewards, latent): 519 | """Body function.""" 520 | frames, actions, rewards = self.preprocess(frames, actions, rewards) 521 | res_frames = self.build_video_model(frames, actions, latent) 522 | if self.hparams.reward_model_stop_gradient: 523 | input_frames = [tf.stop_gradient(x) for x in res_frames] 524 | else: 525 | input_frames = res_frames 526 | input_rewards = rewards[:self.hparams.video_num_input_frames] 527 | res_rewards = self.build_reward_model(input_frames, input_rewards) 528 | return res_frames, res_rewards 529 | --------------------------------------------------------------------------------