├── 10_nvidia.json ├── ACKNOWLEDGEMENTS ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── conda_env.yml ├── conda_env_robosuite.yml ├── configs ├── dcs │ └── core.yaml └── robosuite │ ├── core.yaml │ └── core_imageonly.yaml ├── distracting_control ├── README.md ├── background.py ├── camera.py ├── camera_test.py ├── color.py ├── distracting_control_demo.py ├── requirements.txt ├── run.sh ├── suite.py ├── suite_test.py └── suite_utils.py ├── environment_container_dcs.py ├── environment_container_robosuite.py ├── environments.py ├── modules.py ├── replay_buffer.py ├── setup.sh ├── train.py ├── train.sh ├── utils.py ├── videos ├── hard-walker.gif ├── medium-cartpole.gif └── medium-cheetah.gif └── world_model.py /10_nvidia.json: -------------------------------------------------------------------------------- 1 | { 2 | "file_format_version" : "1.0.0", 3 | "ICD" : { 4 | "library_path" : "libEGL_nvidia.so.0" 5 | } 6 | } 7 | 8 | -------------------------------------------------------------------------------- /ACKNOWLEDGEMENTS: -------------------------------------------------------------------------------- 1 | Acknowledgements 2 | Portions of this software may utilize the following copyrighted 3 | material, the use of which is hereby acknowledged. 4 | 5 | _____________________ 6 | Distracting Control Suite 7 | Copyright 2021 The Google Research Authors. 8 | Licensed under the Apache License, Version 2.0 (the "License"); 9 | you may not use this file except in compliance with the License. 10 | You may obtain a copy of the License at 11 | 12 | http://www.apache.org/licenses/LICENSE-2.0 13 | 14 | Unless required by applicable law or agreed to in writing, software 15 | distributed under the License is distributed on an "AS IS" BASIS, 16 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | See the License for the specific language governing permissions and 18 | limitations under the License. 19 | 20 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the open source team at [opensource-conduct@group.apple.com](mailto:opensource-conduct@group.apple.com). All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org), version 1.4, 71 | available at [https://www.contributor-covenant.org/version/1/4/code-of-conduct.html](https://www.contributor-covenant.org/version/1/4/code-of-conduct.html) -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contribution Guide 2 | 3 | Thanks for your interest in contributing. 
This project was released to accompany a research paper for purposes of reproducibility, and beyond its publication there are limited plans for future development of the repository. 4 | 5 | While we welcome new pull requests and issues, please note that our response may be limited. Forks and out-of-tree improvements are strongly encouraged. 6 | 7 | ## Before you get started 8 | 9 | By submitting a pull request, you represent that you have the right to license your contribution to Apple and the community, and agree by submitting the patch that your contributions are licensed under the [LICENSE](LICENSE). 10 | 11 | We ask that all community members read and observe our [Code of Conduct](CODE_OF_CONDUCT.md). -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (C) 2021 Apple Inc. All Rights Reserved. 2 | 3 | IMPORTANT: This Apple software is supplied to you by Apple 4 | Inc. ("Apple") in consideration of your agreement to the following 5 | terms, and your use, installation, modification or redistribution of 6 | this Apple software constitutes acceptance of these terms. If you do 7 | not agree with these terms, please do not use, install, modify or 8 | redistribute this Apple software. 9 | 10 | In consideration of your agreement to abide by the following terms, and 11 | subject to these terms, Apple grants you a personal, non-exclusive 12 | license, under Apple's copyrights in this original Apple software (the 13 | "Apple Software"), to use, reproduce, modify and redistribute the Apple 14 | Software, with or without modifications, in source and/or binary forms; 15 | provided that if you redistribute the Apple Software in its entirety and 16 | without modifications, you must retain this notice and the following 17 | text and disclaimers in all such redistributions of the Apple Software. 18 | Neither the name, trademarks, service marks or logos of Apple Inc. may 19 | be used to endorse or promote products derived from the Apple Software 20 | without specific prior written permission from Apple. Except as 21 | expressly stated in this notice, no other rights or licenses, express or 22 | implied, are granted by Apple herein, including but not limited to any 23 | patent rights that may be infringed by your derivative works or by other 24 | works in which the Apple Software may be incorporated. 25 | 26 | The Apple Software is provided by Apple on an "AS IS" basis. APPLE 27 | MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION 28 | THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS 29 | FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND 30 | OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS. 31 | 32 | IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL 33 | OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 34 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 35 | INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION, 36 | MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED 37 | AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE), 38 | STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE 39 | POSSIBILITY OF SUCH DAMAGE. 
40 | 41 | ------------------------------------------------------------------------------- 42 | SOFTWARE DISTRIBUTED IN THIS REPOSITORY: 43 | 44 | This software includes a number of subcomponents with separate 45 | copyright notices and license terms - please see the file ACKNOWLEDGEMENTS. 46 | ------------------------------------------------------------------------------- -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CoRe: Contrastive Recurrent State-Space Models 2 | 3 | This code implements the CoRe model and reproduces experimental results found in
4 | **Robust Robotic Control from Pixels using Contrastive Recurrent State-Space Models**
5 | NeurIPS Deep Reinforcement Learning Workshop 2021
6 | Nitish Srivastava, Walter Talbott, Martin Bertran Lopez, Shuangfei Zhai & Joshua M. Susskind
7 | [[paper](https://arxiv.org/abs/2112.01163)]
8 | 9 | ![cartpole](videos/medium-cartpole.gif) 10 | 11 | ![cheetah](videos/medium-cheetah.gif) 12 | 13 | ![walker](videos/hard-walker.gif) 14 | 15 | 16 | ## Requirements and Installation 17 | Clone this repository and then execute the following steps. See `setup.sh` for an example of how to run these steps on an Ubuntu 18.04 machine. 18 | 19 | * Install dependencies. 20 | ``` 21 | apt install -y libgl1-mesa-dev libgl1-mesa-glx libglew-dev \ 22 | libosmesa6-dev software-properties-common net-tools unzip \ 23 | virtualenv wget xpra xserver-xorg-dev libglfw3-dev patchelf xvfb ffmpeg 24 | ``` 25 | * Download the [DAVIS 2017 26 | dataset](https://davischallenge.org/davis2017/code.html). Make sure to select the 2017 TrainVal - Images and Annotations (480p). The training images will be used as distracting backgrounds. The `DAVIS` directory should be in the same directory as the code. Check that `ls ./DAVIS/JPEGImages/480p/...` shows 90 video directories. 27 | * Install MuJoCo 2.1. 28 | - Download [MuJoCo version 2.1](https://mujoco.org/download) binaries for Linux or macOS. 29 | - Unzip the downloaded `mujoco210` directory into `~/.mujoco/mujoco210`. 30 | * Install MuJoCo 2.0 (for Robosuite experiments only). 31 | - Download [MuJoCo version 2.0](https://roboti.us/download.html) binaries for Linux or macOS. 32 | - Unzip the downloaded directory and move it into `~/.mujoco/`. 33 | - Symlink `mujoco200_linux` (or `mujoco200_macos`) to `mujoco200`. 34 | ``` 35 | ln -s ~/.mujoco/mujoco200_linux ~/.mujoco/mujoco200 36 | ``` 37 | - Place the [license key](https://roboti.us/license.html) at `~/.mujoco/mjkey.txt`. 38 | - Add the MuJoCo binaries to `LD_LIBRARY_PATH`. 39 | ``` 40 | export LD_LIBRARY_PATH=$HOME/.mujoco/mujoco200/bin:$LD_LIBRARY_PATH 41 | ``` 42 | * Set up EGL GPU rendering (if a GPU is available). 43 | - To ensure that the GPU is prioritized over the CPU for EGL rendering, copy the provided `10_nvidia.json` into the GLVND EGL vendor directory. 44 | ``` 45 | cp 10_nvidia.json /usr/share/glvnd/egl_vendor.d/ 46 | ``` 47 | - Create a dummy nvidia directory so that mujoco_py builds the extensions needed for GPU rendering. 48 | ``` 49 | mkdir -p /usr/lib/nvidia-000 50 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/lib/nvidia-000 51 | ``` 52 | * Create a conda environment. 53 | 54 | For the Distracting Control Suite 55 | ``` 56 | conda env create -f conda_env.yml 57 | ``` 58 | 59 | For Robosuite 60 | ``` 61 | conda env create -f conda_env_robosuite.yml 62 | ``` 63 | 64 | ## Training 65 | 66 | * The CoRe model can be trained on the Distracting Control Suite as follows: 67 | 68 | ``` 69 | conda activate core 70 | MUJOCO_GL=egl CUDA_VISIBLE_DEVICES=0 python train.py --config configs/dcs/core.yaml 71 | ``` 72 | The training artifacts, including TensorBoard logs and videos of validation rollouts, will be written to `./artifacts/`. 73 | 74 | To change the distraction setting, modify the `difficulty` parameter in `configs/dcs/core.yaml`. Possible values are `['easy', 'medium', 'hard', 'none', 'hard_bg']`. 75 | 76 | To change the domain, modify the `domain` parameter in `configs/dcs/core.yaml`. Possible values are `['ball_in_cup', 'cartpole', 'cheetah', 'finger', 'reacher', 'walker']`. 77 | 78 | * To train on Robosuite (Door Task, Franka Panda Arm): 79 | 80 | - Using RGB image and proprioceptive inputs. 81 | ``` 82 | conda activate core_robosuite 83 | MUJOCO_GL=egl CUDA_VISIBLE_DEVICES=0 python train.py --config configs/robosuite/core.yaml 84 | ``` 85 | - Using RGB image inputs only. 
86 | ``` 87 | conda activate core_robosuite 88 | MUJOCO_GL=egl CUDA_VISIBLE_DEVICES=0 python train.py --config configs/robosuite/core_imageonly.yaml 89 | ``` 90 | 91 | ## Citation 92 | ``` 93 | @article{srivastava2021core, 94 | title={Robust Robotic Control from Pixels using Contrastive Recurrent State-Space Models}, 95 | author={Nitish Srivastava and Walter Talbott and Martin Bertran Lopez and Shuangfei Zhai and Josh Susskind}, 96 | journal={NeurIPS Deep Reinforcement Learning Workshop}, 97 | year={2021} 98 | } 99 | ``` 100 | 101 | ## License 102 | This code is released under the [LICENSE](LICENSE) terms. 103 | -------------------------------------------------------------------------------- /conda_env.yml: -------------------------------------------------------------------------------- 1 | name: core 2 | channels: 3 | - defaults 4 | - pytorch 5 | dependencies: 6 | - pip=20.1.1 7 | - python=3.7.6 8 | - pytorch=1.10 9 | - torchvision=0.11.0 10 | - cudatoolkit=11.3 11 | - tensorflow=2.2.0 12 | - tensorboard=2.6.0 13 | - absl-py=0.13.0 14 | - pyparsing=2.4.7 15 | - pillow=6.1.0 16 | - h5py=2.10.0 17 | - pyyaml=5.4.1 18 | - moviepy=1.0.1 19 | - numpy=1.20 20 | - conda-forge::matplotlib=3.2.2 21 | - conda-forge::gym=0.19.0 22 | - conda-forge::ffmpeg=4.4.0 23 | - conda-forge::libiconv=1.16 24 | - conda-forge::imageio=2.6.1 25 | - pip: 26 | - termcolor==1.1.0 27 | - dm_control 28 | -------------------------------------------------------------------------------- /conda_env_robosuite.yml: -------------------------------------------------------------------------------- 1 | name: core 2 | channels: 3 | - defaults 4 | - pytorch 5 | dependencies: 6 | - pip=20.1.1 7 | - python=3.7.6 8 | - pytorch=1.10 9 | - torchvision=0.11.0 10 | - cudatoolkit=11.3 11 | - tensorflow=2.2.0 12 | - tensorboard=2.6.0 13 | - absl-py=0.13.0 14 | - pyparsing=2.4.7 15 | - pillow=6.1.0 16 | - h5py=2.10.0 17 | - pyyaml=5.4.1 18 | - moviepy=1.0.1 19 | - numpy=1.20 20 | - conda-forge::matplotlib=3.2.2 21 | - conda-forge::gym=0.19.0 22 | - conda-forge::ffmpeg=4.4.0 23 | - conda-forge::libiconv=1.16 24 | - conda-forge::imageio=2.6.1 25 | - pip: 26 | - termcolor==1.1.0 27 | - robosuite==1.2.0 28 | -------------------------------------------------------------------------------- /configs/dcs/core.yaml: -------------------------------------------------------------------------------- 1 | parameters: 2 | seed: 42 3 | env: 4 | name: dcs 5 | domain: cheetah 6 | difficulty: hard 7 | dynamic: true 8 | background_dataset_path: DAVIS/JPEGImages/480p 9 | allow_background_distraction: true 10 | allow_camera_distraction: true 11 | allow_color_distraction: true 12 | image_height: 256 13 | image_width: 256 14 | num_frames_to_stack: 1 15 | start_rl_training_after: 0 16 | num_envs: 1 17 | num_val_envs: 5 #20 18 | num_episodes_per_val_env_for_reward: 1 #5 19 | max_steps: 2000000 # This divided by action_repeat is the max updates. 20 | episode_steps: 1000 # This divided by action_repeat is the max steps per episode. 
21 | update_frequency_factor: 0.5 22 | initial_data_steps: 1000 23 | gamma: 0.99 24 | contrastive: 25 | inverse_temperature_init: 1.0 26 | softmax_over: both 27 | #mask_type: exclude_other_sequences 28 | include_model_params_in_critic: true 29 | sac_deterministic_state: true 30 | recon_from_prior: true 31 | lr: 3.e-4 32 | lr_inverse_temp: 2.e-3 33 | lr_actor: 1.e-3 34 | lr_critic: 1.e-3 35 | lr_alpha: 1.e-3 36 | momentum_alpha: 0.5 37 | momentum_critic: 0.9 38 | momentum_actor: 0.9 39 | max_grad_norm_wm: 10 40 | max_grad_norm_critic: 100 41 | max_grad_norm_actor: 10 42 | max_grad_norm_log_alpha: 1.0 43 | initial_log_alpha: -2.3 44 | max_log_alpha: 3.0 45 | update_target_critic_after: 1 46 | update_target_critic_tau: 0.005 47 | weight_decay: 0.0 48 | replay_buffer_size: 100000 49 | batch_size: 32 50 | dynamics_seq_len: 32 51 | random_crop: true 52 | random_crop_padding: 0 53 | same_crop_across_time: true 54 | crop_height: 84 55 | crop_width: 84 56 | loss_scales: 57 | eta_fwd: 0.01 58 | eta_r: 1.0 59 | eta_inv: 1.0 60 | eta_q: 0.5 61 | eta_s: 1.0 62 | wm: 63 | sample_state: true 64 | reward_prediction_from_prior: false 65 | decode_deterministic: false 66 | propagate_deterministic: false 67 | obs_encoder: # Input dims determined by binder symbol dims, or observation model's output dims. 68 | fc_activation: elu 69 | fc_hiddens: [256, 256] 70 | output_activation: identity 71 | layer_norm_output: true 72 | obs_decoder: 73 | fc_activation: elu 74 | fc_hiddens: [256, 256] # Will add another layer to make the output be be the same dim as enocder's input. 75 | output_activation: identity 76 | layer_norm_output: true 77 | reward_net: 78 | fc_activation: elu 79 | fc_hiddens: [128, 128, 1] 80 | output_activation: identity 81 | inverse_dynamics: 82 | fc_activation: elu 83 | fc_hiddens: [128, 128] 84 | output_activation: tanh 85 | dynamics: 86 | discrete: false 87 | rnn_state_dims: 200 88 | latent_dims: 64 89 | forward_dynamics_loss: kl #neg_log_prob 90 | free_kl: 1.0 91 | recurrent: true 92 | rnn_input: # Takes as input latent_dims + actions. The output of this is input to GRU. 93 | fc_activation: elu 94 | fc_hiddens: [400, 400] 95 | output_activation: elu 96 | posterior: # Takes as input rnn_state_dims + obs_embed 97 | fc_activation: elu 98 | fc_hiddens: [400, 400] 99 | output_activation: identity # Outputs logits. Num outputs = stoch_num_softmaxes * stoch_dims_per_softmax 100 | layer_norm_output: true 101 | layer_norm_affine: true 102 | prior: # takes as input rnn_state_dims 103 | fc_activation: elu 104 | fc_hiddens: [400, 400, 400] 105 | output_activation: identity # Outputs logits. Num outputs = stoch_num_softmaxes * stoch_dims_per_softmax 106 | layer_norm_output: true 107 | layer_norm_affine: true 108 | use_gating_network: false 109 | encoder_gating: 110 | conv_batch_norm: false 111 | conv_activation: elu 112 | base_num_hid: 1 # The number of filters in fc and conv layers is multiplied by this. 113 | conv_filters: 114 | - [16, [5, 5], 1, [2, 2]] # (64, 64) 115 | - [16, [5, 5], 1, [2, 2]] # (64, 64) 116 | - [16, [5, 5], 1, [2, 2]] # (64, 64) 117 | - [1, [1, 1], 1, [0, 0]] # (64, 64) 118 | encoder: 119 | conv_batch_norm: false 120 | conv_activation: relu 121 | base_num_hid: 1 # The number of filters in fc and conv layers is multiplied by this. 
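# (The conv_filters entries below appear to follow the format [num_filters, [kernel_h, kernel_w], stride, [pad_h, pad_w]],
#  consistent with the Robosuite configs later in this repository; the trailing comment on each entry gives the expected
#  feature-map size. Those sizes look like they assume a 64x64 input: with the 84x84 crop configured above, the first
#  3x3 stride-2 unpadded layer would instead produce floor((84 - 3) / 2) + 1 = 41.)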
122 | conv_filters: 123 | - [32, [3, 3], 2, [0, 0]] # (31, 31) 124 | - [32, [3, 3], 1, [0, 0]] # (29, 29) 125 | - [32, [3, 3], 1, [0, 0]] # (27, 27) 126 | - [32, [3, 3], 1, [0, 0]] # (25, 25) 127 | fc_hiddens: [50] 128 | fc_activation: relu 129 | fc_batch_norm: false 130 | output_activation: identity 131 | layer_norm_output: true 132 | layer_norm_affine: false 133 | double_output_dims_for_std: false 134 | actor: 135 | fc_activation: relu 136 | fc_hiddens: [1024, 1024, 1024] 137 | output_activation: identity 138 | policy_max_logstd: 2 139 | policy_min_logstd: -10 140 | critic: 141 | fc_activation: relu 142 | fc_hiddens: [1024, 1024, 1024, 1] 143 | output_activation: identity 144 | validate_every_iters: 10 145 | print_every: 50 146 | -------------------------------------------------------------------------------- /configs/robosuite/core.yaml: -------------------------------------------------------------------------------- 1 | parameters: 2 | seed: 42 3 | env: 4 | name: robosuite 5 | image_key: frontview_image 6 | other_key: robot0_proprio-state 7 | crop_image: true 8 | crop_width: 144 9 | crop_height: 144 10 | crop_center_xy: [104, 144] # For the Door task. 11 | #crop_center_xy: [128, 128] 12 | domain_randomize: true 13 | domain_randomization_config: 14 | randomize_color: true 15 | randomize_camera: true 16 | randomize_lighting: true 17 | randomize_dynamics: false 18 | #if True, randomize on every call to @reset. This, in 19 | # conjunction with setting @randomize_every_n_steps to 0, is useful to 20 | # generate a new domain per episode. 21 | randomize_on_reset: true 22 | randomize_every_n_steps: 0 23 | controller: OSC_POSE 24 | robosuite_config: 25 | control_freq: 20 26 | env_name: Door 27 | hard_reset: false 28 | horizon: 500 29 | ignore_done: true 30 | reward_scale: 1.0 31 | reward_shaping: true 32 | robots: [ Panda ] 33 | has_renderer: false 34 | has_offscreen_renderer: true 35 | use_object_obs: false 36 | use_camera_obs: true 37 | # ('frontview', 'birdview', 'agentview', 'sideview', 'robot0_robotview', 'robot0_eye_in_hand') 38 | camera_names: frontview 39 | camera_heights: 256 40 | camera_widths: 256 41 | camera_depths: false 42 | start_rl_training_after: 0 43 | num_envs: 1 44 | num_val_envs: 10 45 | num_episodes_per_val_env_for_reward: 2 46 | max_steps: 1000000 # This divided by action_repeat is the max updates. 47 | episode_steps: 500 # This divided by action_repeat is the max steps per episode. 
48 | update_frequency_factor: 0.5 49 | initial_data_steps: 1000 50 | gamma: 0.99 51 | contrastive: 52 | inverse_temperature_init: 1.0 53 | softmax_over: both 54 | #mask_type: exclude_other_sequences 55 | include_model_params_in_critic: true 56 | sac_deterministic_state: true 57 | recon_from_prior: true 58 | lr: 3.e-4 59 | lr_inverse_temp: 2.e-3 60 | lr_actor: 1.e-3 61 | lr_critic: 1.e-3 62 | lr_alpha: 1.e-3 63 | max_grad_norm_wm: 10 64 | max_grad_norm_critic: 100 65 | max_grad_norm_actor: 10 66 | max_grad_norm_log_alpha: 1.0 67 | initial_log_alpha: -2.3 68 | max_log_alpha: 3.0 69 | update_target_critic_after: 1 70 | update_target_critic_tau: 0.005 71 | weight_decay: 1.e-8 72 | replay_buffer_size: 200000 73 | batch_size: 32 74 | dynamics_seq_len: 32 75 | random_crop: true 76 | random_crop_padding: 0 77 | same_crop_across_time: true 78 | crop_height: 128 79 | crop_width: 128 80 | rollout_prior_init_t: 10 81 | loss_scales: 82 | eta_fwd: 0.01 83 | eta_r: 1.0 84 | eta_inv: 1.0 85 | eta_q: 0.5 86 | eta_s: 1.0 87 | wm: 88 | reward_prediction_from_prior: false 89 | decode_deterministic: false 90 | propagate_deterministic: false 91 | obs_encoder: # Input dims determined by binder symbol dims, or observation model's output dims. 92 | fc_activation: elu 93 | fc_hiddens: [256, 256] 94 | output_activation: identity 95 | layer_norm_output: true 96 | obs_decoder: 97 | fc_activation: elu 98 | fc_hiddens: [256, 256] # Will add another layer to make the output be be the same dim as enocder's input. 99 | output_activation: identity 100 | layer_norm_output: true 101 | reward_net: 102 | fc_activation: elu 103 | fc_hiddens: [128, 128, 1] 104 | output_activation: identity 105 | inverse_dynamics: 106 | fc_activation: elu 107 | fc_hiddens: [128, 128] 108 | output_activation: tanh 109 | dynamics: 110 | discrete: false 111 | rnn_state_dims: 200 112 | latent_dims: 64 113 | forward_dynamics_loss: kl #neg_log_prob 114 | free_kl: 1.0 115 | recurrent: true 116 | rnn_input: # Takes as input latent_dims + actions. The output of this is input to GRU. 117 | fc_activation: elu 118 | fc_hiddens: [400, 400] 119 | output_activation: elu 120 | posterior: # Takes as input rnn_state_dims + obs_embed 121 | fc_activation: elu 122 | fc_hiddens: [400, 400] 123 | output_activation: identity # Outputs logits. Num outputs = stoch_num_softmaxes * stoch_dims_per_softmax 124 | layer_norm_output: true 125 | layer_norm_affine: true 126 | prior: # takes as input rnn_state_dims 127 | fc_activation: elu 128 | fc_hiddens: [400, 400, 400] 129 | output_activation: identity # Outputs logits. Num outputs = stoch_num_softmaxes * stoch_dims_per_softmax 130 | layer_norm_output: true 131 | layer_norm_affine: true 132 | encoder: 133 | conv_batch_norm: false 134 | conv_activation: relu 135 | base_num_hid: 1 # The number of filters in fc and conv layers is multiplied by this. 
136 | conv_filters: 137 | - [32, [3, 3], 2, [1, 1]] # (64) 138 | - [32, [3, 3], 2, [1, 1]] # (32) 139 | - [32, [3, 3], 1, [0, 0]] # (30) 140 | - [32, [3, 3], 1, [0, 0]] # (28) 141 | - [32, [3, 3], 1, [0, 0]] # (26) 142 | fc_hiddens: [50] 143 | fc_activation: relu 144 | fc_batch_norm: false 145 | output_activation: identity 146 | layer_norm_output: true 147 | layer_norm_affine: false 148 | other_obs_model: 149 | fc_activation: elu 150 | fc_hiddens: [200, 200, 200] 151 | output_activation: elu 152 | actor: 153 | fc_activation: relu 154 | fc_hiddens: [1024, 1024, 1024] 155 | output_activation: identity 156 | policy_max_logstd: 2 157 | policy_min_logstd: -10 158 | critic: 159 | fc_activation: relu 160 | fc_hiddens: [1024, 1024, 1024, 1] 161 | output_activation: identity 162 | validate_every_iters: 10 163 | print_every: 50 164 | -------------------------------------------------------------------------------- /configs/robosuite/core_imageonly.yaml: -------------------------------------------------------------------------------- 1 | parameters: 2 | seed: 42 3 | env: 4 | name: robosuite 5 | image_key: frontview_image 6 | crop_image: true 7 | crop_width: 144 8 | crop_height: 144 9 | crop_center_xy: [104, 144] # For the Door task center the crop at this location. 10 | #crop_center_xy: [128, 128] 11 | domain_randomize: true 12 | domain_randomization_config: 13 | randomize_color: true 14 | randomize_camera: true 15 | randomize_lighting: true 16 | randomize_dynamics: false 17 | #if True, randomize on every call to @reset. This, in 18 | # conjunction with setting @randomize_every_n_steps to 0, is useful to 19 | # generate a new domain per episode. 20 | randomize_on_reset: true 21 | randomize_every_n_steps: 0 22 | controller: OSC_POSE 23 | robosuite_config: 24 | control_freq: 20 25 | env_name: Door 26 | hard_reset: false 27 | horizon: 500 28 | ignore_done: true 29 | reward_scale: 1.0 30 | reward_shaping: true 31 | robots: [ Panda ] 32 | has_renderer: false 33 | has_offscreen_renderer: true 34 | use_object_obs: false 35 | use_camera_obs: true 36 | # ('frontview', 'birdview', 'agentview', 'sideview', 'robot0_robotview', 'robot0_eye_in_hand') 37 | camera_names: frontview 38 | camera_heights: 256 39 | camera_widths: 256 40 | camera_depths: false 41 | start_rl_training_after: 0 42 | num_envs: 1 43 | num_val_envs: 10 44 | num_episodes_per_val_env_for_reward: 2 45 | max_steps: 1000000 # This divided by action_repeat is the max updates. 46 | episode_steps: 500 # This divided by action_repeat is the max steps per episode. 
47 | update_frequency_factor: 0.5 48 | initial_data_steps: 1000 49 | gamma: 0.99 50 | contrastive: 51 | inverse_temperature_init: 1.0 52 | softmax_over: both 53 | #mask_type: exclude_same_sequence 54 | #mask_type: exclude_other_sequences 55 | include_model_params_in_critic: true 56 | sac_deterministic_state: true 57 | recon_from_prior: true 58 | lr: 3.e-4 59 | lr_inverse_temp: 2.e-3 60 | lr_actor: 1.e-3 61 | lr_critic: 1.e-3 62 | lr_alpha: 1.e-3 63 | max_grad_norm_wm: 10 64 | max_grad_norm_critic: 100 65 | max_grad_norm_actor: 10 66 | max_grad_norm_log_alpha: 1.0 67 | initial_log_alpha: -2.3 68 | max_log_alpha: 3.0 69 | update_target_critic_after: 1 70 | update_target_critic_tau: 0.005 71 | weight_decay: 0.0 72 | replay_buffer_size: 100000 73 | batch_size: 32 74 | dynamics_seq_len: 32 75 | random_crop: true 76 | random_crop_padding: 0 77 | same_crop_across_time: true 78 | crop_height: 128 79 | crop_width: 128 80 | rollout_prior_init_t: 10 81 | loss_scales: 82 | eta_fwd: 0.01 83 | eta_r: 1.0 84 | eta_inv: 1.0 85 | eta_q: 0.5 86 | eta_s: 1.0 87 | wm: 88 | sample_state: true 89 | reward_prediction_from_prior: false 90 | decode_deterministic: false 91 | propagate_deterministic: false 92 | obs_encoder: # Input dims determined by binder symbol dims, or observation model's output dims. 93 | fc_activation: elu 94 | fc_hiddens: [256, 256] 95 | output_activation: identity 96 | layer_norm_output: true 97 | obs_decoder: 98 | fc_activation: elu 99 | fc_hiddens: [256, 256] # Will add another layer to make the output be be the same dim as enocder's input. 100 | output_activation: identity 101 | layer_norm_output: true 102 | reward_net: 103 | fc_activation: elu 104 | fc_hiddens: [128, 128, 1] 105 | output_activation: identity 106 | inverse_dynamics: 107 | fc_activation: elu 108 | fc_hiddens: [128, 128] 109 | output_activation: tanh 110 | dynamics: 111 | discrete: false 112 | rnn_state_dims: 200 113 | latent_dims: 64 114 | forward_dynamics_loss: kl #neg_log_prob 115 | free_kl: 1.0 116 | recurrent: true 117 | rnn_input: # Takes as input latent_dims + actions. The output of this is input to GRU. 118 | fc_activation: elu 119 | fc_hiddens: [400, 400] 120 | output_activation: elu 121 | posterior: # Takes as input rnn_state_dims + obs_embed 122 | fc_activation: elu 123 | fc_hiddens: [400, 400] 124 | output_activation: identity # Outputs logits. Num outputs = stoch_num_softmaxes * stoch_dims_per_softmax 125 | layer_norm_output: true 126 | layer_norm_affine: true 127 | prior: # takes as input rnn_state_dims 128 | fc_activation: elu 129 | fc_hiddens: [400, 400, 400] 130 | output_activation: identity # Outputs logits. Num outputs = stoch_num_softmaxes * stoch_dims_per_softmax 131 | layer_norm_output: true 132 | layer_norm_affine: true 133 | encoder: 134 | conv_batch_norm: false 135 | conv_activation: relu 136 | base_num_hid: 1 # The number of filters in fc and conv layers is multiplied by this. 
137 | conv_filters: 138 | - [32, [3, 3], 2, [1, 1]] # (64) 139 | - [32, [3, 3], 2, [1, 1]] # (32) 140 | - [32, [3, 3], 1, [0, 0]] # (30) 141 | - [32, [3, 3], 1, [0, 0]] # (28) 142 | - [32, [3, 3], 1, [0, 0]] # (26) 143 | #conv_filters: 144 | # - [1, [4, 4], 2] # (64) # (112) 145 | # - [2, [4, 4], 2] # (32) # (56) 146 | # - [2, [4, 4], 2] # (16) #(28) 147 | # - [2, [4, 4], 2] # (8) # (14) 148 | # - [4, [4, 4], 2] # (4) # (7) 149 | fc_hiddens: [50] 150 | fc_activation: relu 151 | fc_batch_norm: false 152 | output_activation: identity 153 | layer_norm_output: true 154 | layer_norm_affine: false 155 | actor: 156 | fc_activation: relu 157 | fc_hiddens: [1024, 1024, 1024] 158 | output_activation: identity 159 | policy_max_logstd: 2 160 | policy_min_logstd: -10 161 | critic: 162 | fc_activation: relu 163 | fc_hiddens: [1024, 1024, 1024, 1] 164 | output_activation: identity 165 | validate_every_iters: 10 166 | print_every: 50 167 | -------------------------------------------------------------------------------- /distracting_control/README.md: -------------------------------------------------------------------------------- 1 | # The Distracting Control Suite 2 | 3 | `distracting_control` extends `dm_control` with static or dynamic visual 4 | distractions in the form of changing colors, backgrounds, and camera poses. 5 | Details and experimental results can be found in our 6 | [paper](https://arxiv.org/pdf/2101.02722.pdf). 7 | 8 | ## Requirements and Installation 9 | 10 | * Clone this repository 11 | * `sh run.sh` 12 | * Follow the instructions and install 13 | [dm_control](https://github.com/deepmind/dm_control#requirements-and-installation). Make sure you setup your MuJoCo keys correctly. 14 | * Download the [DAVIS 2017 15 | dataset](https://davischallenge.org/davis2017/code.html). Make sure to select the 2017 TrainVal - Images and Annotations (480p). The training images will be used as distracting backgrounds. 16 | 17 | ## Instructions 18 | 19 | * You can run the `distracting_control_demo` to generate sample images of the 20 | different tasks at different difficulties: 21 | 22 | ``` 23 | python distracting_control_demo --davis_path=$HOME/DAVIS/JPEGImages/480p/ 24 | --output_dir=/tmp/distrtacting_control_demo 25 | ``` 26 | * As seen from the demo to generate an instance of the environment you simply 27 | need to import the suite and use `suite.load` while specifying the 28 | `dm_control` domain and task, then choosing a difficulty and providing the 29 | dataset_path. 30 | 31 | * Note the environment follows the dm_control environment APIs. 32 | 33 | ## Paper 34 | 35 | If you use this code, please cite the accompanying [paper](https://arxiv.org/pdf/2101.02722.pdf) as: 36 | 37 | ``` 38 | @article{stone2021distracting, 39 | title={The Distracting Control Suite -- A Challenging Benchmark for Reinforcement Learning from Pixels}, 40 | author={Austin Stone and Oscar Ramirez and Kurt Konolige and Rico Jonschkowski}, 41 | year={2021}, 42 | journal={arXiv preprint arXiv:2101.02722}, 43 | } 44 | ``` 45 | 46 | ## Disclaimer 47 | 48 | This is not an official Google product. 49 | -------------------------------------------------------------------------------- /distracting_control/background.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """A wrapper for dm_control environments which applies color distractions.""" 18 | import os 19 | 20 | from PIL import Image 21 | import collections 22 | from dm_control.rl import control 23 | import numpy as np 24 | 25 | import tensorflow as tf 26 | from dm_control.mujoco.wrapper import mjbindings 27 | 28 | DAVIS17_TRAINING_VIDEOS = [ 29 | 'bear', 'bmx-bumps', 'boat', 'boxing-fisheye', 'breakdance-flare', 'bus', 30 | 'car-turn', 'cat-girl', 'classic-car', 'color-run', 'crossing', 31 | 'dance-jump', 'dancing', 'disc-jockey', 'dog-agility', 'dog-gooses', 32 | 'dogs-scale', 'drift-turn', 'drone', 'elephant', 'flamingo', 'hike', 33 | 'hockey', 'horsejump-low', 'kid-football', 'kite-walk', 'koala', 34 | 'lady-running', 'lindy-hop', 'longboard', 'lucia', 'mallard-fly', 35 | 'mallard-water', 'miami-surf', 'motocross-bumps', 'motorbike', 'night-race', 36 | 'paragliding', 'planes-water', 'rallye', 'rhino', 'rollerblade', 37 | 'schoolgirls', 'scooter-board', 'scooter-gray', 'sheep', 'skate-park', 38 | 'snowboard', 'soccerball', 'stroller', 'stunt', 'surf', 'swing', 'tennis', 39 | 'tractor-sand', 'train', 'tuk-tuk', 'upside-down', 'varanus-cage', 'walking' 40 | ] 41 | DAVIS17_VALIDATION_VIDEOS = [ 42 | 'bike-packing', 'blackswan', 'bmx-trees', 'breakdance', 'camel', 43 | 'car-roundabout', 'car-shadow', 'cows', 'dance-twirl', 'dog', 'dogs-jump', 44 | 'drift-chicane', 'drift-straight', 'goat', 'gold-fish', 'horsejump-high', 45 | 'india', 'judo', 'kite-surf', 'lab-coat', 'libby', 'loading', 'mbike-trick', 46 | 'motocross-jump', 'paragliding-launch', 'parkour', 'pigs', 'scooter-black', 47 | 'shooting', 'soapbox' 48 | ] 49 | SKY_TEXTURE_INDEX = 0 50 | Texture = collections.namedtuple('Texture', ('size', 'address', 'textures')) 51 | 52 | 53 | def imread(filename): 54 | img = Image.open(filename) 55 | img_np = np.asarray(img) 56 | return img_np 57 | 58 | 59 | def size_and_flatten(image, ref_height, ref_width): 60 | # Resize image if necessary and flatten the result. 61 | image_height, image_width = image.shape[:2] 62 | 63 | if image_height != ref_height or image_width != ref_width: 64 | image = tf.cast(tf.image.resize(image, [ref_height, ref_width]), tf.uint8) 65 | return tf.reshape(image, [-1]).numpy() 66 | 67 | 68 | def blend_to_background(alpha, image, background): 69 | if alpha == 1.0: 70 | return image 71 | elif alpha == 0.0: 72 | return background 73 | else: 74 | return (alpha * image.astype(np.float32) 75 | + (1. - alpha) * background.astype(np.float32)).astype(np.uint8) 76 | 77 | 78 | class DistractingBackgroundEnv(control.Environment): 79 | """Environment wrapper for background visual distraction. 80 | 81 | **NOTE**: This wrapper should be applied BEFORE the pixel wrapper to make sure 82 | the background image changes are applied before rendering occurs. 
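  A minimal usage sketch (illustrative only; the task name, DAVIS path and wrapper
  arguments are example values, not taken from this repository's training code):

    from dm_control import suite
    from dm_control.suite.wrappers import pixels
    from distracting_control import background

    env = suite.load('cheetah', 'run')
    # Apply the background distraction before the pixel wrapper, per the note above.
    env = background.DistractingBackgroundEnv(
        env,
        dataset_path='DAVIS/JPEGImages/480p',
        dataset_videos='train',
        dynamic=True)
    env = pixels.Wrapper(env, render_kwargs={'camera_id': 0})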
83 | """ 84 | 85 | def __init__(self, 86 | env, 87 | dataset_path=None, 88 | dataset_videos=None, 89 | video_alpha=1.0, 90 | ground_plane_alpha=1.0, 91 | num_videos=None, 92 | dynamic=False, 93 | seed=None, 94 | shuffle_buffer_size=None): 95 | 96 | if not 0 <= video_alpha <= 1: 97 | raise ValueError('`video_alpha` must be in the range [0, 1]') 98 | 99 | self._env = env 100 | self._video_alpha = video_alpha 101 | self._ground_plane_alpha = ground_plane_alpha 102 | self._random_state = np.random.RandomState(seed=seed) 103 | self._dynamic = dynamic 104 | self._shuffle_buffer_size = shuffle_buffer_size 105 | self._background = None 106 | self._current_img_index = 0 107 | 108 | if not dataset_path or num_videos == 0: 109 | # Allow running the wrapper without backgrounds to still set the ground 110 | # plane alpha value. 111 | self._video_paths = [] 112 | else: 113 | # Use all videos if no specific ones were passed. 114 | if not dataset_videos: 115 | dataset_videos = sorted(tf.io.gfile.listdir(dataset_path)) 116 | # Replace video placeholders 'train'/'val' with the list of videos. 117 | elif dataset_videos in ['train', 'training']: 118 | dataset_videos = DAVIS17_TRAINING_VIDEOS 119 | elif dataset_videos in ['val', 'validation']: 120 | dataset_videos = DAVIS17_VALIDATION_VIDEOS 121 | # Get complete paths for all videos. 122 | video_paths = [ 123 | os.path.join(dataset_path, subdir) for subdir in dataset_videos 124 | ] 125 | 126 | # Optionally use only the first num_paths many paths. 127 | if num_videos is not None: 128 | if num_videos > len(video_paths) or num_videos < 0: 129 | raise ValueError(f'`num_bakground_paths` is {num_videos} but ' 130 | 'should not be larger than the number of available ' 131 | f'background paths ({len(video_paths)}) and at ' 132 | 'least 0.') 133 | video_paths = video_paths[:num_videos] 134 | 135 | self._video_paths = video_paths 136 | 137 | def reset(self): 138 | """Reset the background state.""" 139 | time_step = self._env.reset() 140 | self._reset_background() 141 | return time_step 142 | 143 | def _reset_background(self): 144 | # Make grid semi-transparent. 145 | if self._ground_plane_alpha is not None: 146 | self._env.physics.named.model.mat_rgba['grid', 147 | 'a'] = self._ground_plane_alpha 148 | 149 | # For some reason the height of the skybox is set to 4800 by default, 150 | # which does not work with new textures. 151 | self._env.physics.model.tex_height[SKY_TEXTURE_INDEX] = 800 152 | 153 | # Set the sky texture reference. 154 | sky_height = self._env.physics.model.tex_height[SKY_TEXTURE_INDEX] 155 | sky_width = self._env.physics.model.tex_width[SKY_TEXTURE_INDEX] 156 | sky_size = sky_height * sky_width * 3 157 | sky_address = self._env.physics.model.tex_adr[SKY_TEXTURE_INDEX] 158 | 159 | sky_texture = self._env.physics.model.tex_rgb[sky_address:sky_address + 160 | sky_size].astype(np.float32) 161 | 162 | if self._video_paths: 163 | 164 | if self._shuffle_buffer_size: 165 | # Shuffle images from all videos together to get background frames. 166 | file_names = [ 167 | os.path.join(path, fn) 168 | for path in self._video_paths 169 | for fn in tf.io.gfile.listdir(path) 170 | ] 171 | self._random_state.shuffle(file_names) 172 | # Load only the first n images for performance reasons. 173 | file_names = file_names[:self._shuffle_buffer_size] 174 | images = [imread(fn) for fn in file_names] 175 | else: 176 | # Randomly pick a video and load all images. 
177 | video_path = self._random_state.choice(self._video_paths) 178 | file_names = tf.io.gfile.listdir(video_path) 179 | if not self._dynamic: 180 | # Randomly pick a single static frame. 181 | file_names = [self._random_state.choice(file_names)] 182 | images = [imread(os.path.join(video_path, fn)) for fn in file_names] 183 | 184 | # Pick a random starting point and steping direction. 185 | self._current_img_index = self._random_state.choice(len(images)) 186 | self._step_direction = self._random_state.choice([-1, 1]) 187 | 188 | # Prepare images in the texture format by resizing and flattening. 189 | 190 | # Generate image textures. 191 | texturized_images = [] 192 | for image in images: 193 | image_flattened = size_and_flatten(image, sky_height, sky_width) 194 | new_texture = blend_to_background(self._video_alpha, image_flattened, 195 | sky_texture) 196 | texturized_images.append(new_texture) 197 | 198 | else: 199 | 200 | self._current_img_index = 0 201 | texturized_images = [sky_texture] 202 | 203 | self._background = Texture(sky_size, sky_address, texturized_images) 204 | self._apply() 205 | 206 | def step(self, action): 207 | time_step = self._env.step(action) 208 | 209 | if time_step.first(): 210 | self._reset_background() 211 | return time_step 212 | 213 | if self._dynamic and self._video_paths: 214 | # Move forward / backward in the image sequence by updating the index. 215 | self._current_img_index += self._step_direction 216 | 217 | # Start moving forward if we are past the start of the images. 218 | if self._current_img_index <= 0: 219 | self._current_img_index = 0 220 | self._step_direction = abs(self._step_direction) 221 | # Start moving backwards if we are past the end of the images. 222 | if self._current_img_index >= len(self._background.textures): 223 | self._current_img_index = len(self._background.textures) - 1 224 | self._step_direction = -abs(self._step_direction) 225 | 226 | self._apply() 227 | return time_step 228 | 229 | def _apply(self): 230 | """Apply the background texture to the physics.""" 231 | 232 | if self._background: 233 | start = self._background.address 234 | end = self._background.address + self._background.size 235 | texture = self._background.textures[self._current_img_index] 236 | 237 | self._env.physics.model.tex_rgb[start:end] = texture 238 | # Upload the new texture to the GPU. Note: we need to make sure that the 239 | # OpenGL context belonging to this Physics instance is the current one. 240 | with self._env.physics.contexts.gl.make_current() as ctx: 241 | ctx.call( 242 | mjbindings.mjlib.mjr_uploadTexture, 243 | self._env.physics.model.ptr, 244 | self._env.physics.contexts.mujoco.ptr, 245 | SKY_TEXTURE_INDEX, 246 | ) 247 | 248 | # Forward property and method calls to self._env. 249 | def __getattr__(self, attr): 250 | if hasattr(self._env, attr): 251 | return getattr(self._env, attr) 252 | raise AttributeError("'{}' object has no attribute '{}'".format( 253 | type(self).__name__, attr)) 254 | -------------------------------------------------------------------------------- /distracting_control/camera.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """A wrapper for dm_control environments which applies camera distractions.""" 18 | 19 | import copy 20 | from dm_control.rl import control 21 | import numpy as np 22 | 23 | CAMERA_MODES = ['fixed', 'track', 'trackcom', 'targetbody', 'targetbodycom'] 24 | 25 | 26 | def eul2mat(theta): 27 | """Converts euler angles (x, y, z) to a rotation matrix.""" 28 | 29 | return np.array([[ 30 | np.cos(theta[1]) * np.cos(theta[2]), 31 | np.sin(theta[0]) * np.sin(theta[1]) * np.cos(theta[2]) - 32 | np.sin(theta[2]) * np.cos(theta[0]), 33 | np.sin(theta[1]) * np.cos(theta[0]) * np.cos(theta[2]) + 34 | np.sin(theta[0]) * np.sin(theta[2]) 35 | ], 36 | [ 37 | np.sin(theta[2]) * np.cos(theta[1]), 38 | np.sin(theta[0]) * np.sin(theta[1]) * np.sin(theta[2]) + 39 | np.cos(theta[0]) * np.cos(theta[2]), 40 | np.sin(theta[1]) * np.sin(theta[2]) * np.cos(theta[0]) - 41 | np.sin(theta[0]) * np.cos(theta[2]) 42 | ], 43 | [ 44 | -np.sin(theta[1]), 45 | np.sin(theta[0]) * np.cos(theta[1]), 46 | np.cos(theta[0]) * np.cos(theta[1]) 47 | ]]) 48 | 49 | 50 | def _mat_from_theta(cos_theta, sin_theta, a): 51 | """Builds a rotation matrix from theta and an orientation vector.""" 52 | 53 | row1 = [ 54 | cos_theta + a[0]**2. * (1. - cos_theta), 55 | a[0] * a[1] * (1 - cos_theta) - a[2] * sin_theta, 56 | a[0] * a[2] * (1 - cos_theta) + a[1] * sin_theta 57 | ] 58 | row2 = [ 59 | a[1] * a[0] * (1 - cos_theta) + a[2] * sin_theta, 60 | cos_theta + a[1]**2. * (1 - cos_theta), 61 | a[1] * a[2] * (1. - cos_theta) - a[0] * sin_theta 62 | ] 63 | row3 = [ 64 | a[2] * a[0] * (1. - cos_theta) - a[1] * sin_theta, 65 | a[2] * a[1] * (1. - cos_theta) + a[0] * sin_theta, 66 | cos_theta + (a[2]**2.) * (1. - cos_theta) 67 | ] 68 | return np.stack([row1, row2, row3]) 69 | 70 | 71 | def rotvec2mat(theta, vec): 72 | """Converts a rotation around a vector to a rotation matrix.""" 73 | 74 | a = vec / np.sqrt(np.sum(vec**2.)) 75 | sin_theta = np.sin(theta) 76 | cos_theta = np.cos(theta) 77 | 78 | return _mat_from_theta(cos_theta, sin_theta, a) 79 | 80 | 81 | def get_lookat_xmat_no_roll(agent_pos, camera_pos): 82 | """Solves for the cam rotation centering the agent with 0 roll.""" 83 | 84 | # NOTE(austinstone): This method leads to wild oscillations around the north 85 | # and south polls. 86 | # For example, if agent is at (0., 0., 0.) and the camera is at (.01, 0., 1.), 87 | # this will produce a yaw of 90 degrees whereas if the camera is slightly 88 | # adjacent at (-.01, 0., 1.) this will produce a yaw of -90 degrees. I'm 89 | # not sure what the fix is, as this seems like the behavior we want in all 90 | # environments except for reacher. 91 | delta_vec = agent_pos - camera_pos 92 | delta_vec /= np.sqrt(np.sum(delta_vec**2.)) 93 | yaw = np.arctan2(delta_vec[0], delta_vec[1]) 94 | pitch = np.arctan2(delta_vec[2], np.sqrt(np.sum(delta_vec[:2]**2.))) 95 | pitch += np.pi / 2. # Camera starts out looking at [0, 0, -1.] 
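  # Compose the pitch and negated yaw (with roll fixed at 0) into a rotation matrix and return it flattened.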
96 | return eul2mat([pitch, 0., -yaw]).flatten() 97 | 98 | 99 | def get_lookat_xmat(agent_pos, camera_pos): 100 | """Solves for the cam rotation centering the agent, allowing roll.""" 101 | 102 | # Solve for the rotation which centers the agent in the scene. 103 | delta_vec = agent_pos - camera_pos 104 | delta_vec /= np.sqrt(np.sum(delta_vec**2.)) 105 | y_vec = np.array([0., 0., -1.]) # This is where the cam starts from. 106 | a = np.cross(y_vec, delta_vec) 107 | sin_theta = np.sqrt(np.sum(a**2.)) 108 | cos_theta = np.dot(delta_vec, y_vec) 109 | a /= (np.sqrt(np.sum(a**2.)) + .0001) 110 | return _mat_from_theta(cos_theta, sin_theta, a) 111 | 112 | 113 | def cart2sphere(cart): 114 | r = np.sqrt(np.sum(cart**2.)) 115 | h_angle = np.arctan2(cart[1], cart[0]) 116 | v_angle = np.arctan2(np.sqrt(np.sum(cart[:2]**2.)), cart[2]) 117 | return np.array([r, h_angle, v_angle]) 118 | 119 | 120 | def sphere2cart(sphere): 121 | r, h_angle, v_angle = sphere 122 | x = r * np.sin(v_angle) * np.cos(h_angle) 123 | y = r * np.sin(v_angle) * np.sin(h_angle) 124 | z = r * np.cos(v_angle) 125 | return np.array([x, y, z]) 126 | 127 | 128 | def clip_cam_position(position, min_radius, max_radius, min_h_angle, 129 | max_h_angle, min_v_angle, max_v_angle): 130 | new_position = [-1., -1., -1.] 131 | new_position[0] = np.clip(position[0], min_radius, max_radius) 132 | new_position[1] = np.clip(position[1], min_h_angle, max_h_angle) 133 | new_position[2] = np.clip(position[2], min_v_angle, max_v_angle) 134 | return new_position 135 | 136 | 137 | def get_lookat_point(physics, camera_id): 138 | """Get the point that the camera is looking at. 139 | 140 | It is assumed that the "point" the camera looks at the agent distance 141 | away and projected along the camera viewing matrix. 142 | 143 | Args: 144 | physics: mujoco physics objects 145 | camera_id: int 146 | 147 | Returns: 148 | position: float32 np.array of length 3 149 | """ 150 | dist_to_agent = physics.named.data.cam_xpos[ 151 | camera_id] - physics.named.data.subtree_com[1] 152 | dist_to_agent = np.sqrt(np.sum(dist_to_agent**2.)) 153 | initial_viewing_mat = copy.deepcopy(physics.named.data.cam_xmat[camera_id]) 154 | initial_viewing_mat = np.reshape(initial_viewing_mat, (3, 3)) 155 | z_vec = np.array([0., 0., -dist_to_agent]) 156 | rotated_vec = np.dot(initial_viewing_mat, z_vec) 157 | return rotated_vec + physics.named.data.cam_xpos[camera_id] 158 | 159 | 160 | class DistractingCameraEnv(control.Environment): 161 | """Environment wrapper for camera pose visual distraction. 162 | 163 | **NOTE**: This wrapper should be applied BEFORE the pixel wrapper to make sure 164 | the camera pose changes are applied before rendering occurs. 
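  A minimal usage sketch (illustrative values chosen in the spirit of camera_test.py
  with scale 0.1; not taken verbatim from this repository's training code):

    from dm_control.suite import cartpole
    from dm_control.suite.wrappers import pixels

    env = cartpole.swingup()
    # Apply the camera distraction before the pixel wrapper, per the note above.
    env = DistractingCameraEnv(
        env,
        camera_id=0,
        horizontal_delta=np.pi / 20,
        vertical_delta=np.pi / 20,
        max_vel=0.04,
        vel_std=0.01,
        roll_delta=np.pi / 20,
        max_roll_vel=np.pi / 500,
        roll_std=np.pi / 3000,
        max_zoom_in_percent=0.05,
        max_zoom_out_percent=0.15,
        limit_to_upper_quadrant=True)
    env = pixels.Wrapper(env, render_kwargs={'camera_id': 0})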
165 | """ 166 | 167 | def __init__(self, 168 | env, 169 | camera_id, 170 | horizontal_delta, 171 | vertical_delta, 172 | max_vel, 173 | vel_std, 174 | roll_delta, 175 | max_roll_vel, 176 | roll_std, 177 | max_zoom_in_percent, 178 | max_zoom_out_percent, 179 | limit_to_upper_quadrant=False, 180 | seed=None): 181 | self._env = env 182 | self._camera_id = camera_id 183 | self._horizontal_delta = horizontal_delta 184 | self._vertical_delta = vertical_delta 185 | 186 | self._horizontal_delta = horizontal_delta 187 | self._vertical_delta = vertical_delta 188 | self._max_vel = max_vel 189 | self._vel_std = vel_std 190 | self._roll_delta = roll_delta 191 | self._max_roll_vel = max_roll_vel 192 | self._roll_vel_std = roll_std 193 | self._max_zoom_in_percent = max_zoom_in_percent 194 | self._max_zoom_out_percent = max_zoom_out_percent 195 | self._limit_to_upper_quadrant = limit_to_upper_quadrant 196 | 197 | self._random_state = np.random.RandomState(seed=seed) 198 | 199 | # These camera state parameters will be set on the first reset call. 200 | self._camera_type = None 201 | self._camera_initial_lookat_point = None 202 | 203 | self._camera_vel = None 204 | self._max_h_angle = None 205 | self._max_v_angle = None 206 | self._min_h_angle = None 207 | self._min_v_angle = None 208 | self._radius = None 209 | self._roll_vel = None 210 | self._vel_scaling = None 211 | 212 | def setup_camera(self): 213 | """Set up camera motion ranges and state.""" 214 | # Define boundaries on the range of the camera motion. 215 | mode = self._env._physics.model.cam_mode[0] 216 | 217 | camera_type = CAMERA_MODES[mode] 218 | assert camera_type in ['fixed', 'trackcom'] 219 | 220 | self._camera_type = camera_type 221 | self._cam_initial_lookat_point = get_lookat_point(self._env.physics, 222 | self._camera_id) 223 | 224 | start_pos = copy.deepcopy( 225 | self._env.physics.named.data.cam_xpos[self._camera_id]) 226 | 227 | if self._camera_type != 'fixed': 228 | # Center the camera relative to the agent's center of mass. 229 | start_pos -= self._env.physics.named.data.subtree_com[1] 230 | 231 | start_r, start_h_angle, start_v_angle = cart2sphere(start_pos) 232 | # Scale the velocity by the starting radius. Most environments have radius 4, 233 | # but this downscales the velocity for the envs with radius < 4. 234 | self._vel_scaling = start_r / 4. 235 | self._max_h_angle = start_h_angle + self._horizontal_delta 236 | self._min_h_angle = start_h_angle - self._horizontal_delta 237 | self._max_v_angle = start_v_angle + self._vertical_delta 238 | self._min_v_angle = start_v_angle - self._vertical_delta 239 | 240 | if self._limit_to_upper_quadrant: 241 | # A centered cam is at np.pi / 2. 242 | self._max_v_angle = min(self._max_v_angle, np.pi / 2.) 243 | self._min_v_angle = max(self._min_v_angle, 0.) 244 | # A centered cam is at -np.pi / 2. 245 | self._max_h_angle = min(self._max_h_angle, 0.) 246 | self._min_h_angle = max(self._min_h_angle, -np.pi) 247 | 248 | self._max_roll = self._roll_delta 249 | self._min_roll = -self._roll_delta 250 | self._min_radius = max(start_r - start_r * self._max_zoom_in_percent, 0.) 251 | self._max_radius = start_r + start_r * self._max_zoom_out_percent 252 | 253 | # Decide the starting position for the camera. 
254 | self._h_angle = self._random_state.uniform(self._min_h_angle, 255 | self._max_h_angle) 256 | 257 | self._v_angle = self._random_state.uniform(self._min_v_angle, 258 | self._max_v_angle) 259 | 260 | self._radius = self._random_state.uniform(self._min_radius, 261 | self._max_radius) 262 | 263 | self._roll = self._random_state.uniform(self._min_roll, self._max_roll) 264 | 265 | # Decide the starting velocity for the camera. 266 | vel = self._random_state.randn(3) 267 | vel /= np.sqrt(np.sum(vel**2.)) 268 | vel *= self._random_state.uniform(0., self._max_vel) 269 | self._camera_vel = vel 270 | self._roll_vel = self._random_state.uniform(-self._max_roll_vel, 271 | self._max_roll_vel) 272 | 273 | def reset(self): 274 | """Reset the camera state. """ 275 | time_step = self._env.reset() 276 | self.setup_camera() 277 | self._apply() 278 | return time_step 279 | 280 | 281 | def step(self, action): 282 | time_step = self._env.step(action) 283 | 284 | if time_step.first(): 285 | self.setup_camera() 286 | 287 | self._apply() 288 | return time_step 289 | 290 | def _apply(self): 291 | if not self._camera_type: 292 | self.setup_camera() 293 | 294 | # Random walk the velocity. 295 | vel_delta = self._random_state.randn(3) 296 | self._camera_vel += vel_delta * self._vel_std * self._vel_scaling 297 | self._roll_vel += self._random_state.randn() * self._roll_vel_std 298 | 299 | # Clip velocity if it gets too big. 300 | vel_norm = np.sqrt(np.sum(self._camera_vel**2.)) 301 | if vel_norm > self._max_vel * self._vel_scaling: 302 | self._camera_vel *= (self._max_vel * self._vel_scaling) / vel_norm 303 | 304 | self._roll_vel = np.clip(self._roll_vel, -self._max_roll_vel, 305 | self._max_roll_vel) 306 | 307 | cart_cam_pos = sphere2cart([self._radius, self._h_angle, self._v_angle]) 308 | # Apply velocity vector to camera 309 | sphere_cam_pos2 = cart2sphere(cart_cam_pos + self._camera_vel) 310 | sphere_cam_pos2 = clip_cam_position(sphere_cam_pos2, self._min_radius, 311 | self._max_radius, self._min_h_angle, 312 | self._max_h_angle, self._min_v_angle, 313 | self._max_v_angle) 314 | 315 | self._camera_vel = sphere2cart(sphere_cam_pos2) - cart_cam_pos 316 | 317 | self._radius, self._h_angle, self._v_angle = sphere_cam_pos2 318 | 319 | roll2 = self._roll + self._roll_vel 320 | roll2 = np.clip(roll2, self._min_roll, self._max_roll) 321 | 322 | self._roll_vel = roll2 - self._roll 323 | self._roll = roll2 324 | 325 | cart_cam_pos = sphere2cart(sphere_cam_pos2) 326 | 327 | if self._limit_to_upper_quadrant: 328 | lookat_method = get_lookat_xmat_no_roll 329 | else: 330 | # This method avoids jitteriness at the pole but allows some roll 331 | # in the camera matrix. This is important for reacher. 332 | lookat_method = get_lookat_xmat 333 | 334 | if self._camera_type == 'fixed': 335 | lookat_mat = lookat_method(self._cam_initial_lookat_point, 336 | cart_cam_pos) 337 | else: 338 | # Go from agent centric to world coords 339 | cart_cam_pos += self._env.physics.named.data.subtree_com[1] 340 | lookat_mat = lookat_method( 341 | get_lookat_point(self._env.physics, self._camera_id), cart_cam_pos) 342 | 343 | lookat_mat = np.reshape(lookat_mat, (3, 3)) 344 | roll_mat = rotvec2mat(self._roll, np.array([0., 0., 1.])) 345 | xmat = np.dot(lookat_mat, roll_mat) 346 | self._env.physics.named.data.cam_xpos[self._camera_id] = cart_cam_pos 347 | self._env.physics.named.data.cam_xmat[self._camera_id] = xmat.flatten() 348 | 349 | # Forward property and method calls to self._env. 
350 | def __getattr__(self, attr): 351 | if hasattr(self._env, attr): 352 | return getattr(self._env, attr) 353 | raise AttributeError("'{}' object has no attribute '{}'".format( 354 | type(self).__name__, attr)) 355 | -------------------------------------------------------------------------------- /distracting_control/camera_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for camera movement code.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | from absl.testing import absltest 23 | from dm_control import suite as dm_control_suite 24 | from dm_control.suite import cartpole 25 | from dm_control.suite.wrappers import pixels 26 | import numpy as np 27 | 28 | from distracting_control import camera 29 | 30 | 31 | def get_camera_params(domain_name, scale, dynamic): 32 | return dict( 33 | vertical_delta=np.pi / 2 * scale, 34 | horizontal_delta=np.pi / 2 * scale, 35 | # Limit camera to -90 / 90 degree rolls. 36 | roll_delta=np.pi / 2. * scale, 37 | vel_std=.1 * scale if dynamic else 0., 38 | max_vel=.4 * scale if dynamic else 0., 39 | roll_std=np.pi / 300 * scale if dynamic else 0., 40 | max_roll_vel=np.pi / 50 * scale if dynamic else 0., 41 | # Allow the camera to zoom in at most 50%. 42 | max_zoom_in_percent=.5 * scale, 43 | # Allow the camera to zoom out at most 200%. 44 | max_zoom_out_percent=1.5 * scale, 45 | limit_to_upper_quadrant='reacher' not in domain_name, 46 | ) 47 | 48 | 49 | def distraction_wrap(env, domain_name): 50 | camera_kwargs = get_camera_params( 51 | domain_name=domain_name, scale=0.0, dynamic=True) 52 | return camera.DistractingCameraEnv(env, camera_id=0, **camera_kwargs) 53 | 54 | 55 | class CameraTest(absltest.TestCase): 56 | 57 | def test_dynamic(self): 58 | camera_kwargs = get_camera_params( 59 | domain_name='cartpole', scale=0.1, dynamic=True) 60 | env = cartpole.swingup() 61 | env = camera.DistractingCameraEnv(env, camera_id=0, **camera_kwargs) 62 | env = pixels.Wrapper(env, render_kwargs={'camera_id': 0}) 63 | action_spec = env.action_spec() 64 | time_step = env.reset() 65 | frames = [] 66 | while not time_step.last() and len(frames) < 10: 67 | action = np.random.uniform( 68 | action_spec.minimum, action_spec.maximum, size=action_spec.shape) 69 | time_step = env.step(action) 70 | frames.append(time_step.observation['pixels']) 71 | self.assertEqual(frames[0].shape, (240, 320, 3)) 72 | 73 | def test_get_lookat_mat(self): 74 | agent_pos = np.array([1., -3., 4.]) 75 | cam_position = np.array([0., 0., 0.]) 76 | mat = camera.get_lookat_xmat_no_roll(agent_pos, cam_position) 77 | agent_pos = agent_pos / np.sqrt(np.sum(agent_pos**2.)) 78 | start = np.array([0., 0., -1.]) # Cam starts looking down Z. 
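    # Rotating the camera's default view direction by the lookat matrix should
    # point it at the (normalized) agent position, which is checked below.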
79 | out = np.dot(mat.reshape((3, 3)), start) 80 | self.assertTrue(np.isclose(np.max(np.abs(out - agent_pos)), 0.)) 81 | 82 | def test_spherical_conversion(self): 83 | cart = np.array([1.4, -2.8, 3.9]) 84 | sphere = camera.cart2sphere(cart) 85 | cart2 = camera.sphere2cart(sphere) 86 | self.assertTrue(np.isclose(np.max(np.abs(cart2 - cart)), 0.)) 87 | 88 | def test_envs_same(self): 89 | # Test that the camera augmentations with magnitude 0 gives the same results 90 | # as when no camera augmentations are used. 91 | render_kwargs = {'width': 84, 'height': 84, 'camera_id': 0} 92 | domain_and_task = [('cartpole', 'swingup'), 93 | ('reacher', 'easy'), 94 | ('finger', 'spin'), 95 | ('cheetah', 'run'), 96 | ('ball_in_cup', 'catch'), 97 | ('walker', 'walk')] 98 | for (domain, task) in domain_and_task: 99 | seed = 42 100 | envs = [('baseline', 101 | pixels.Wrapper( 102 | dm_control_suite.load( 103 | domain, task, task_kwargs={'random': seed}), 104 | render_kwargs=render_kwargs)), 105 | ('no-wrapper', 106 | pixels.Wrapper( 107 | dm_control_suite.load( 108 | domain, task, task_kwargs={'random': seed}), 109 | render_kwargs=render_kwargs)), 110 | ('w/-camera_kwargs', 111 | pixels.Wrapper( 112 | distraction_wrap( 113 | dm_control_suite.load( 114 | domain, task, task_kwargs={'random': seed}), domain), 115 | render_kwargs=render_kwargs))] 116 | frames = [] 117 | for _, env in envs: 118 | random_state = np.random.RandomState(42) 119 | action_spec = env.action_spec() 120 | time_step = env.reset() 121 | frames.append([]) 122 | while not time_step.last() and len(frames[-1]) < 20: 123 | action = random_state.uniform( 124 | action_spec.minimum, action_spec.maximum, size=action_spec.shape) 125 | time_step = env.step(action) 126 | frame = time_step.observation['pixels'][:, :, 0:3] 127 | frames[-1].append(frame) 128 | frames_np = np.array(frames) 129 | for i in range(1, len(envs)): 130 | difference = np.mean(abs(frames_np[0] - frames_np[i])) 131 | self.assertEqual(difference, 0.) 132 | 133 | if __name__ == '__main__': 134 | absltest.main() 135 | -------------------------------------------------------------------------------- /distracting_control/color.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """A wrapper for dm_control environments which applies color distractions.""" 18 | 19 | from dm_control.rl import control 20 | import numpy as np 21 | 22 | 23 | class DistractingColorEnv(control.Environment): 24 | """Environment wrapper for color visual distraction. 25 | 26 | **NOTE**: This wrapper should be applied BEFORE the pixel wrapper to make sure 27 | the color changes are applied before rendering occurs. 28 | """ 29 | 30 | def __init__(self, env, step_std, max_delta, seed=None): 31 | """Initialize the environment wrapper. 
32 | 
33 |     Args:
34 |       env: instance of dm_control Environment to wrap with augmentations.
35 |     """
36 |     if step_std < 0:
37 |       raise ValueError("`step_std` must be greater than or equal to 0.")
38 |     if max_delta < 0:
39 |       raise ValueError("`max_delta` must be greater than or equal to 0.")
40 | 
41 |     self._env = env
42 |     self._step_std = step_std
43 |     self._max_delta = max_delta
44 |     self._random_state = np.random.RandomState(seed=seed)
45 | 
46 |     self._cam_type = None
47 |     self._current_rgb = None
48 |     self._max_rgb = None
49 |     self._min_rgb = None
50 |     self._original_rgb = None
51 | 
52 |   def reset(self):
53 |     """Reset the distractions state."""
54 |     time_step = self._env.reset()
55 |     self._reset_color()
56 |     return time_step
57 | 
58 |   def _reset_color(self):
59 |     # Save all original colors.
60 |     if self._original_rgb is None:
61 |       self._original_rgb = np.copy(self._env.physics.model.mat_rgba)[:, :3]
62 |       # Determine minimum and maximum rgb values.
63 |       self._max_rgb = np.clip(self._original_rgb + self._max_delta, 0.0, 1.0)
64 |       self._min_rgb = np.clip(self._original_rgb - self._max_delta, 0.0, 1.0)
65 | 
66 |     # Pick random colors in the allowed ranges.
67 |     r = self._random_state.uniform(size=self._min_rgb.shape)
68 |     self._current_rgb = self._min_rgb + r * (self._max_rgb - self._min_rgb)
69 | 
70 |     # Apply the color changes.
71 |     self._env.physics.model.mat_rgba[:, :3] = self._current_rgb
72 | 
73 |   def step(self, action):
74 |     time_step = self._env.step(action)
75 | 
76 |     if time_step.first():
77 |       self._reset_color()
78 |       return time_step
79 | 
80 |     color_change = self._random_state.randn(*self._current_rgb.shape)
81 |     color_change = color_change * self._step_std
82 | 
83 |     new_color = self._current_rgb + color_change
84 | 
85 |     self._current_rgb = np.clip(
86 |         new_color,
87 |         a_min=self._min_rgb,
88 |         a_max=self._max_rgb,
89 |     )
90 | 
91 |     # Apply the color changes.
92 |     self._env.physics.model.mat_rgba[:, :3] = self._current_rgb
93 |     return time_step
94 | 
95 |   # Forward property and method calls to self._env.
96 |   def __getattr__(self, attr):
97 |     if hasattr(self._env, attr):
98 |       return getattr(self._env, attr)
99 |     raise AttributeError("'{}' object has no attribute '{}'".format(
100 |         type(self).__name__, attr))
101 | 
--------------------------------------------------------------------------------
/distracting_control/distracting_control_demo.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | """A simple demo that produces an image from the environment."""
17 | import os
18 | from absl import app
19 | from absl import flags
20 | 
21 | import PIL
22 | 
23 | from distracting_control import suite
24 | 
25 | FLAGS = flags.FLAGS
26 | 
27 | # NOTE: The default below is empty; you should download and extract the
28 | # DAVIS dataset and point this path to the location of the DAVIS video
29 | # frames (e.g. the JPEGImages/480p directory).
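# Example invocation (a sketch; the paths below are placeholders for your own
# DAVIS extraction and output directory):
#   python -m distracting_control.distracting_control_demo \
#     --davis_path=/tmp/DAVIS/JPEGImages/480p \
#     --output_dir=/tmp/distracting_control_demo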
30 | flags.DEFINE_string( 31 | 'davis_path', 32 | '', 33 | 'Path to DAVIS images, used for background distractions.') 34 | flags.DEFINE_string('output_dir', '/tmp/distracting_control_demo', 35 | 'Directory where the results are being saved.') 36 | 37 | 38 | def main(unused_argv): 39 | 40 | if not FLAGS.davis_path: 41 | raise ValueError( 42 | 'You must download and extract the DAVIS dataset and pass a path ' 43 | 'a path to the videos, e.g.: /tmp/DAVIS/JPEGImages/480p') 44 | 45 | for i, difficulty in enumerate(['easy', 'medium', 'hard']): 46 | for j, (domain, task) in enumerate([ 47 | ('ball_in_cup', 'catch'), 48 | ('cartpole', 'swingup'), 49 | ('cheetah', 'run'), 50 | ('finger', 'spin'), 51 | ('reacher', 'easy'), 52 | ('walker', 'walk')]): 53 | 54 | env = suite.load( 55 | domain, task, difficulty, background_dataset_path=FLAGS.davis_path) 56 | 57 | # Get the first frame. 58 | time_step = env.reset() 59 | frame = time_step.observation['pixels'][:, :, 0:3] 60 | 61 | # Save the first frame. 62 | try: 63 | os.mkdir(FLAGS.output_dir) 64 | except OSError: 65 | pass 66 | filepath = os.path.join(FLAGS.output_dir, f'{i:02d}-{j:02d}.jpg') 67 | image = PIL.Image.fromarray(frame) 68 | image.save(filepath) 69 | print(f'Saved results to {FLAGS.output_dir}') 70 | 71 | 72 | if __name__ == '__main__': 73 | app.run(main) 74 | -------------------------------------------------------------------------------- /distracting_control/requirements.txt: -------------------------------------------------------------------------------- 1 | dm-control==0.0.322773188 2 | absl-py>=0.9.0 3 | numpy>=1.18.3 4 | mock>=mock-4.0 5 | pillow>=7.0.0 6 | -------------------------------------------------------------------------------- /distracting_control/run.sh: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The Google Research Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | #!/bin/bash 16 | set -e 17 | set -x 18 | 19 | virtualenv -p python3 . 20 | source ./bin/activate 21 | 22 | pip install tensorflow 23 | pip install -r distracting_control/requirements.txt 24 | python -m distracting_control.suite_test 25 | python -m distracting_control.camera_test 26 | -------------------------------------------------------------------------------- /distracting_control/suite.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """A collection of MuJoCo-based Reinforcement Learning environments. 17 | 18 | The suite provides a similar API to the original dm_control suite. 19 | Users can configure the distractions on top of the original tasks. The suite is 20 | targeted for loading environments directly with similar configurations as those 21 | used in the original paper. Each distraction wrapper can be used independently 22 | though. 23 | """ 24 | try: 25 | from dm_control import suite # pylint: disable=g-import-not-at-top 26 | from dm_control.suite.wrappers import pixels # pylint: disable=g-import-not-at-top 27 | except ImportError: 28 | suite = None 29 | 30 | from distracting_control import background 31 | from distracting_control import camera 32 | from distracting_control import color 33 | from distracting_control import suite_utils 34 | 35 | 36 | def is_available(): 37 | return suite is not None 38 | 39 | 40 | def load(domain_name, 41 | task_name, 42 | difficulty=None, 43 | dynamic=False, 44 | background_dataset_path=None, 45 | background_dataset_videos="train", 46 | background_kwargs=None, 47 | camera_kwargs=None, 48 | color_kwargs=None, 49 | task_kwargs=None, 50 | environment_kwargs=None, 51 | visualize_reward=False, 52 | render_kwargs=None, 53 | pixels_only=True, 54 | pixels_observation_key="pixels", 55 | allow_color_distraction=True, 56 | allow_background_distraction=True, 57 | allow_camera_distraction=True, 58 | ): 59 | """Returns an environment from a domain name, task name and optional settings. 60 | 61 | ```python 62 | env = suite.load('cartpole', 'balance') 63 | ``` 64 | 65 | Adding a difficulty will configure distractions matching the reference paper 66 | for easy, medium, hard. 67 | 68 | Users can also toggle dynamic properties for distractions. 69 | 70 | Args: 71 | domain_name: A string containing the name of a domain. 72 | task_name: A string containing the name of a task. 73 | task_kwargs: Optional `dict` of keyword arguments for the task. 74 | difficulty: Difficulty for the suite. One of 'easy', 'medium', 'hard'. 75 | dynamic: Boolean controlling whether distractions are dynamic or static. 76 | backgound_dataset_path: String to the davis directory that contains the 77 | video directories. 78 | background_dataset_videos: String ('train'/'val') or list of strings of the 79 | DAVIS videos to be used for backgrounds. 80 | background_kwargs: Dict, overwrites settings for background distractions. 81 | camera_kwargs: Dict, overwrites settings for camera distractions. 82 | color_kwargs: Dict, overwrites settings for color distractions. 83 | task_kwargs: Dict, dm control task kwargs. 84 | environment_kwargs: Optional `dict` specifying keyword arguments for the 85 | environment. 86 | visualize_reward: Optional `bool`. If `True`, object colours in rendered 87 | frames are set to indicate the reward at each step. Default `False`. 88 | render_kwargs: Dict, render kwargs for pixel wrapper. 89 | pixels_only: Boolean controlling the exclusion of states in the observation. 90 | pixels_observation_key: Key in the observation used for the rendered image. 91 | 92 | Returns: 93 | The requested environment. 94 | """ 95 | if not is_available(): 96 | raise ImportError("dm_control module is not available. 
Make sure you " 97 | "follow the installation instructions from the " 98 | "dm_control package.") 99 | 100 | if difficulty not in [None, "easy", "medium", "hard", "pse", "hard_bg"]: 101 | raise ValueError("Difficulty should be one of: 'easy', 'medium', 'hard', 'pse', 'hard_bg'.") 102 | 103 | render_kwargs = render_kwargs or {} 104 | if "camera_id" not in render_kwargs: 105 | render_kwargs["camera_id"] = 2 if domain_name == "quadruped" else 0 106 | 107 | env = suite.load( 108 | domain_name, 109 | task_name, 110 | task_kwargs=task_kwargs, 111 | environment_kwargs=environment_kwargs, 112 | visualize_reward=visualize_reward) 113 | 114 | # Apply background distractions. 115 | if (difficulty or background_kwargs) and allow_background_distraction: 116 | background_dataset_path = ( 117 | background_dataset_path or suite_utils.DEFAULT_BACKGROUND_PATH) 118 | final_background_kwargs = dict() 119 | if difficulty: 120 | # Get kwargs for the given difficulty. 121 | if background_dataset_videos in ['val', 'validation']: 122 | num_videos = suite_utils.DIFFICULTY_NUM_VIDEOS_VAL[difficulty] 123 | else: 124 | num_videos = suite_utils.DIFFICULTY_NUM_VIDEOS[difficulty] 125 | final_background_kwargs.update( 126 | suite_utils.get_background_kwargs(domain_name, num_videos, dynamic, 127 | background_dataset_path, 128 | background_dataset_videos)) 129 | else: 130 | # Set the dataset path and the videos. 131 | final_background_kwargs.update( 132 | dict( 133 | dataset_path=background_dataset_path, 134 | dataset_videos=background_dataset_videos)) 135 | if background_kwargs: 136 | # Overwrite kwargs with those passed here. 137 | final_background_kwargs.update(background_kwargs) 138 | env = background.DistractingBackgroundEnv(env, **final_background_kwargs) 139 | 140 | # Apply camera distractions. 141 | if (difficulty or camera_kwargs) and allow_camera_distraction: 142 | final_camera_kwargs = dict(camera_id=render_kwargs["camera_id"]) 143 | if difficulty: 144 | # Get kwargs for the given difficulty. 145 | scale = suite_utils.DIFFICULTY_SCALE[difficulty] 146 | final_camera_kwargs.update( 147 | suite_utils.get_camera_kwargs(domain_name, scale, dynamic)) 148 | if camera_kwargs: 149 | # Overwrite kwargs with those passed here. 150 | final_camera_kwargs.update(camera_kwargs) 151 | env = camera.DistractingCameraEnv(env, **final_camera_kwargs) 152 | 153 | # Apply color distractions. 154 | if (difficulty or color_kwargs) and allow_color_distraction: 155 | final_color_kwargs = dict() 156 | if difficulty: 157 | # Get kwargs for the given difficulty. 158 | scale = suite_utils.DIFFICULTY_SCALE[difficulty] 159 | final_color_kwargs.update(suite_utils.get_color_kwargs(scale, dynamic)) 160 | if color_kwargs: 161 | # Overwrite kwargs with those passed here. 162 | final_color_kwargs.update(color_kwargs) 163 | env = color.DistractingColorEnv(env, **final_color_kwargs) 164 | 165 | # Apply Pixel wrapper after distractions. This is needed to ensure the 166 | # changes from the distraction wrapper are applied to the MuJoCo environment 167 | # before the rendering occurs. 168 | env = pixels.Wrapper( 169 | env, 170 | pixels_only=pixels_only, 171 | render_kwargs=render_kwargs, 172 | observation_key=pixels_observation_key) 173 | 174 | return env 175 | -------------------------------------------------------------------------------- /distracting_control/suite_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The Google Research Authors. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for suite code.""" 17 | 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | import mock 21 | from distracting_control import suite 22 | 23 | DAVIS_PATH = '/tmp/davis' 24 | 25 | class SuiteTest(parameterized.TestCase): 26 | 27 | @parameterized.named_parameters(('none', None), 28 | ('easy', 'easy'), 29 | ('medium', 'medium'), 30 | ('hard', 'hard')) 31 | @mock.patch.object(suite, 'pixels') 32 | @mock.patch.object(suite, 'suite') 33 | def test_suite_load_with_difficulty(self, difficulty, mock_dm_suite, 34 | mock_pixels): 35 | domain_name = 'cartpole' 36 | task_name = 'balance' 37 | suite.load( 38 | domain_name, 39 | task_name, 40 | difficulty, 41 | background_dataset_path=DAVIS_PATH) 42 | 43 | mock_dm_suite.load.assert_called_with( 44 | domain_name, 45 | task_name, 46 | environment_kwargs=None, 47 | task_kwargs=None, 48 | visualize_reward=False) 49 | 50 | mock_pixels.Wrapper.assert_called_with( 51 | mock.ANY, 52 | observation_key='pixels', 53 | pixels_only=True, 54 | render_kwargs={'camera_id': 0}) 55 | 56 | 57 | if __name__ == '__main__': 58 | absltest.main() 59 | -------------------------------------------------------------------------------- /distracting_control/suite_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """A collection of MuJoCo-based Reinforcement Learning environments. 17 | 18 | The suite provides a similar API to the original dm_control suite. 19 | Users can configure the distractions on top of the original tasks. The suite is 20 | targeted for loading environments directly with similar configurations as those 21 | used in the original paper. Each distraction wrapper can be used independently 22 | though. 
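
Each difficulty preset maps to a distraction scale and a number of background
videos through the constants defined below; for example, 'easy' corresponds to
scale 0.1 with 4 background videos.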
23 | """ 24 | import numpy as np 25 | 26 | DIFFICULTY_SCALE = dict(pse=0.0, easy=0.1, medium=0.2, hard=0.3, hard_bg=0.0) 27 | DIFFICULTY_NUM_VIDEOS = dict(pse=2, easy=4, medium=8, hard=None, hard_bg=None) 28 | DIFFICULTY_NUM_VIDEOS_VAL = dict(pse=None, easy=4, medium=8, hard=None, hard_bg=None) 29 | DEFAULT_BACKGROUND_PATH = "$HOME/davis/" 30 | 31 | 32 | def get_color_kwargs(scale, dynamic): 33 | max_delta = scale 34 | step_std = 0.03 * scale if dynamic else 0.0 35 | return dict(max_delta=max_delta, step_std=step_std) 36 | 37 | 38 | def get_camera_kwargs(domain_name, scale, dynamic): 39 | assert domain_name in ['reacher', 'cartpole', 'finger', 'cheetah', 40 | 'ball_in_cup', 'walker'] 41 | assert scale >= 0.0 42 | assert scale <= 1.0 43 | return dict( 44 | vertical_delta=np.pi / 2 * scale, 45 | horizontal_delta=np.pi / 2 * scale, 46 | # Limit camera to -90 / 90 degree rolls. 47 | roll_delta=np.pi / 2. * scale, 48 | vel_std=.1 * scale if dynamic else 0., 49 | max_vel=.4 * scale if dynamic else 0., 50 | roll_std=np.pi / 300 * scale if dynamic else 0., 51 | max_roll_vel=np.pi / 50 * scale if dynamic else 0., 52 | max_zoom_in_percent=.5 * scale, 53 | max_zoom_out_percent=1.5 * scale, 54 | limit_to_upper_quadrant='reacher' not in domain_name, 55 | ) 56 | 57 | 58 | def get_background_kwargs(domain_name, 59 | num_videos, 60 | dynamic, 61 | dataset_path, 62 | dataset_videos=None, 63 | shuffle=False, 64 | video_alpha=1.0): 65 | assert domain_name in ['reacher', 'cartpole', 'finger', 'cheetah', 66 | 'ball_in_cup', 'walker'] 67 | if domain_name == 'reacher': 68 | ground_plane_alpha = 0.0 69 | elif domain_name in ['walker', 'cheetah']: 70 | ground_plane_alpha = 1.0 71 | else: 72 | ground_plane_alpha = 0.3 73 | 74 | return dict( 75 | num_videos=num_videos, 76 | video_alpha=video_alpha, 77 | ground_plane_alpha=ground_plane_alpha, 78 | dynamic=dynamic, 79 | dataset_path=dataset_path, 80 | dataset_videos=dataset_videos, 81 | shuffle_buffer_size=100 if shuffle else None, 82 | ) 83 | -------------------------------------------------------------------------------- /environment_container_dcs.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | """ Wraps Distracting Control Suite in a Gym-like wrapper.""" 7 | import numpy as np 8 | from distracting_control import suite as distracting_suite 9 | from collections import deque 10 | from PIL import Image 11 | import argparse 12 | import yaml 13 | import matplotlib.pyplot as plt 14 | plt.ion() 15 | 16 | 17 | class EnvironmentContainerDCS(object): 18 | """ 19 | Wrapper around DCS. 20 | """ 21 | def __init__(self, config, train=True, seed=None): 22 | self.domain = config['domain'] 23 | 24 | self.get_other_obs = False 25 | 26 | # The standard task and action_repeat for each domain. 
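        # Each entry maps domain -> (task_name, action_repeat); e.g. 'walker'
        # runs the 'walk' task with every policy action repeated twice.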
27 | task_info = { 28 | 'ball_in_cup': ('catch', 4), 29 | 'cartpole': ('swingup', 8), 30 | 'cheetah': ('run', 4), 31 | 'finger': ('spin', 2), 32 | 'reacher': ('easy', 4), 33 | 'walker': ('walk', 2), 34 | } 35 | self.task, self.action_repeat = task_info[self.domain] 36 | 37 | self.difficulty = config['difficulty'] 38 | if self.difficulty in ['none', 'None']: 39 | self.difficulty = None 40 | 41 | self.dynamic = config['dynamic'] 42 | self.background_dataset_path = config.get('background_dataset_path', 'DAVIS/JPEGImages/480p') 43 | allow_color_distraction = config.get('allow_color_distraction', True) 44 | allow_background_distraction = config.get('allow_background_distraction', True) 45 | allow_camera_distraction = config.get('allow_camera_distraction', True) 46 | if seed is not None: 47 | task_kwargs = {'random': seed} 48 | else: 49 | task_kwargs = None 50 | self.env = distracting_suite.load( 51 | self.domain, self.task, difficulty=self.difficulty, dynamic=self.dynamic, 52 | background_dataset_path=self.background_dataset_path, 53 | background_dataset_videos='train' if train else 'val', 54 | pixels_only=False, 55 | task_kwargs=task_kwargs, 56 | allow_color_distraction=allow_color_distraction, 57 | allow_camera_distraction=allow_camera_distraction, 58 | allow_background_distraction=allow_background_distraction) 59 | 60 | action_spec = self.env.action_spec() 61 | self.action_dims = len(action_spec.minimum) 62 | self.action_low = action_spec.minimum 63 | self.action_high = action_spec.maximum 64 | self.num_frames_to_stack = config.get('num_frames_to_stack', 1) 65 | if self.num_frames_to_stack > 1: 66 | self.frame_queue = deque([], maxlen=self.num_frames_to_stack) 67 | self.config = config 68 | self.image_height, self.image_width = self.config['image_height'], self.config['image_width'] 69 | self.num_channels = 3 * self.num_frames_to_stack 70 | self.other_dims = 0 71 | 72 | def get_action_dims(self): 73 | return self.action_dims 74 | 75 | def get_action_repeat(self): 76 | return self.action_repeat 77 | 78 | def get_action_limits(self): 79 | return self.action_low, self.action_high 80 | 81 | def get_obs_chw(self): 82 | return self.num_channels, self.image_height, self.image_width 83 | 84 | def get_obs_other_dims(self): 85 | return self.other_dims 86 | 87 | def reset(self): 88 | time_step = self.env.reset() 89 | if self.num_frames_to_stack > 1: 90 | self.frame_queue.clear() 91 | obs = self._get_image(time_step) # C, H, W. 92 | obs_dict = {'image': obs} 93 | return obs_dict 94 | 95 | def step(self, action): 96 | reward = 0 97 | for _ in range(self.action_repeat): 98 | time_step = self.env.step(action) 99 | reward += time_step.reward 100 | obs = self._get_image(time_step) 101 | done = False 102 | info = {} 103 | obs_dict = {'image': obs} 104 | return obs_dict, reward, done, info 105 | 106 | def _get_image(self, time_step): 107 | image_height, image_width = self.config['image_height'], self.config['image_width'] 108 | obs = time_step.observation['pixels'][:, :, 0:3] # (240, 320, 3). 109 | # Resize to image_height, image_width 110 | obs = Image.fromarray(obs).resize((image_width, image_height), resample=Image.BILINEAR) 111 | #obs = cv2.resize(obs, dsize=(image_width, image_height)) 112 | obs = np.asarray(obs) 113 | obs = obs.transpose((2, 0, 1)).copy() # (C, H, W) 114 | obs = self._stack_images(obs) 115 | return obs 116 | 117 | def _stack_images(self, obs): 118 | if self.num_frames_to_stack > 1: 119 | if len(self.frame_queue) == 0: # Just after reset. 
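                # Fill the queue with copies of the first frame so the stacked
                # observation has its full channel depth from the first step.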
120 | for _ in range(self.num_frames_to_stack): 121 | self.frame_queue.append(obs) 122 | else: 123 | self.frame_queue.append(obs) 124 | obs = np.concatenate(list(self.frame_queue), axis=0) 125 | return obs 126 | 127 | 128 | class EnvironmentContainerDCS_DMC_paired(EnvironmentContainerDCS): 129 | def __init__(self, config, train=True, seed=1): 130 | super().__init__(config, train=train, seed=seed) 131 | config_dmc = config.copy() 132 | config_dmc['difficulty'] = 'none' 133 | self.dmc = EnvironmentContainerDCS(config_dmc, train=train, seed=seed) 134 | 135 | def reset(self): 136 | obs_dict_dmc = self.dmc.reset() 137 | obs_dict = super().reset() 138 | obs_dict['image_clean'] = obs_dict_dmc['image'] 139 | return obs_dict 140 | 141 | def step(self, action): 142 | obs_dict_dmc, _, _, _ = self.dmc.step(action) 143 | obs_dict, reward, done, info = super().step(action) 144 | obs_dict['image_clean'] = obs_dict_dmc['image'] 145 | return obs_dict, reward, done, info 146 | 147 | 148 | def argument_parser(argument): 149 | """ Argument parser """ 150 | parser = argparse.ArgumentParser(description='Binder Network.') 151 | parser.add_argument('-c', '--config', default='', type=str, help='Training config') 152 | args = parser.parse_args(argument) 153 | return args 154 | 155 | 156 | def test(): 157 | args = argument_parser(None) 158 | try: 159 | with open(args.config) as f: 160 | config = yaml.safe_load(f) 161 | except FileNotFoundError: 162 | print("Error opening specified config yaml at: {}. " 163 | "Please check filepath and try again.".format(args.config)) 164 | 165 | config = config['parameters'] 166 | seed = config['seed'] 167 | np.random.seed(seed) 168 | env = EnvironmentContainerDCS_DMC_paired(config['env'], train=True, seed=config['seed']) 169 | plt.figure(1) 170 | action_low, action_high = env.get_action_limits() 171 | action_dims = env.get_action_dims() 172 | for _ in range(1): 173 | env.reset() 174 | for _ in range(1): 175 | action = np.random.uniform(action_low, action_high, action_dims) 176 | obs_dict, reward, done, info = env.step(action) 177 | obs_dcs = obs_dict['image'].transpose((1, 2, 0)) 178 | obs_dmc = obs_dict['image_clean'].transpose((1, 2, 0)) 179 | plt.clf() 180 | obs = np.concatenate([obs_dcs, obs_dmc], axis=1) 181 | plt.imshow(obs) 182 | plt.pause(0.001) 183 | filename = '/Users/nitish/Desktop/binder_figures/sample_0.png' 184 | plt.savefig(filename) 185 | 186 | 187 | if __name__ == '__main__': 188 | test() 189 | -------------------------------------------------------------------------------- /environment_container_robosuite.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved. 
4 | # 5 | import argparse 6 | import os 7 | import numpy as np 8 | import yaml 9 | from collections import deque 10 | import robosuite as suite 11 | from robosuite.controllers import load_controller_config, ALL_CONTROLLERS 12 | from mujoco_py import MujocoException 13 | from robosuite.wrappers.domain_randomization_wrapper import DomainRandomizationWrapper, DEFAULT_CAMERA_ARGS,\ 14 | DEFAULT_COLOR_ARGS, DEFAULT_LIGHTING_ARGS 15 | import matplotlib.pyplot as plt 16 | plt.ion() 17 | 18 | 19 | def dict_merge(default, user): 20 | d = default.copy() 21 | d.update(user) 22 | return d 23 | 24 | 25 | def load_robosuite_controller_config(controller): 26 | if controller in set(ALL_CONTROLLERS): 27 | # This is a default controller 28 | controller_config = load_controller_config(default_controller=controller) 29 | else: 30 | # This is a string to the custom controller 31 | controller_config = load_controller_config(custom_fpath=controller) 32 | return controller_config 33 | 34 | 35 | class EnvironmentContainerRobosuite(object): 36 | def __init__(self, config, train=True, seed=None): 37 | super().__init__() 38 | self.config = config 39 | robosuite_config = config['robosuite_config'] 40 | self.crop_image = config.get('crop_image', False) 41 | if self.crop_image: 42 | self.image_height = config['crop_height'] 43 | self.image_width = config['crop_width'] 44 | self.crop_center_xy = config['crop_center_xy'] 45 | cx, cy = self.crop_center_xy 46 | self.crop_left = cx - self.image_width // 2 47 | self.crop_top = cy - self.image_height // 2 48 | else: 49 | self.image_height = robosuite_config['camera_heights'] 50 | self.image_width = robosuite_config['camera_widths'] 51 | 52 | self.dr = config.get('domain_randomize', False) 53 | dr_config = config.get('domain_randomization_config', None) 54 | controller = config['controller'] 55 | robosuite_config['controller_configs'] = load_robosuite_controller_config(controller) 56 | 57 | self.env = suite.make(**robosuite_config) 58 | self.robosuite_config = robosuite_config 59 | self.env_name = self.robosuite_config['env_name'] 60 | 61 | self.image_key = config['image_key'] 62 | if 'other_key' in config: 63 | self.other_key = config['other_key'] 64 | self.get_other_obs = True 65 | ob_dict = self.env.reset() 66 | assert self.other_key in ob_dict 67 | self.other_dims = len(ob_dict[self.other_key]) 68 | else: 69 | self.get_other_obs = False 70 | self.other_dims = 0 71 | 72 | low, high = self.env.action_spec 73 | self.action_dims = len(low) 74 | self.action_repeat = 1 75 | 76 | if self.dr: 77 | dr_config['color_randomization_args'] = dict_merge(DEFAULT_COLOR_ARGS, 78 | dr_config.get('color_randomization_args', {})) 79 | dr_config['camera_randomization_args'] = dict_merge(DEFAULT_CAMERA_ARGS, 80 | dr_config.get('camera_randomization_args', {})) 81 | dr_config['lighting_randomization_args'] = dict_merge(DEFAULT_LIGHTING_ARGS, 82 | dr_config.get('lighting_randomization_args', {})) 83 | self.env = DomainRandomizationWrapper(self.env, seed=seed, **dr_config) 84 | 85 | self.num_frames_to_stack = config.get('num_frames_to_stack', 1) 86 | if self.num_frames_to_stack > 1: 87 | self.frame_queue = deque([], maxlen=self.num_frames_to_stack) 88 | self.num_channels = 3 * self.num_frames_to_stack 89 | self.other_dims = self.other_dims * self.num_frames_to_stack 90 | self.ob_dict = None # To hold the last raw observation. 
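
    # Minimal usage sketch (hedged; assumes a config dict providing the keys
    # read above, e.g. 'robosuite_config', 'controller' and 'image_key'):
    #   env = EnvironmentContainerRobosuite(config['env'], train=True, seed=0)
    #   obs = env.reset()   # dict with 'image' (C, H, W) and optionally 'other'
    #   low, high = env.get_action_limits()
    #   obs, reward, done, info = env.step(
    #       np.random.uniform(low, high, env.get_action_dims()))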
91 | 92 | def get_action_dims(self): 93 | return self.action_dims 94 | 95 | def get_action_repeat(self): 96 | return self.action_repeat 97 | 98 | def get_action_limits(self): 99 | low, high = self.env.action_spec 100 | return low, high 101 | 102 | def get_obs_chw(self): 103 | return self.num_channels, self.image_height, self.image_width 104 | 105 | def get_obs_other_dims(self): 106 | return self.other_dims 107 | 108 | def preprocess_image(self, img): 109 | # If this is an RGB image, make it C, H, W. 110 | # This is done here rather than just before fprop so that 111 | # if/when depth is added, we can concatenate it as a channel. 112 | if len(img.shape) == 3: # RGB image 113 | if self.crop_image: 114 | x, y = self.crop_left, self.crop_top 115 | img = img[y:y + self.image_height, x:x + self.image_width, :] 116 | elif len(img.shape) == 2: # Depth image 117 | if self.crop_image: 118 | x, y = self.crop_left, self.crop_top 119 | img = img[y:y + self.image_height, x:x + self.image_width] 120 | if len(img.shape) >= 2: 121 | img = img[::-1, ...].copy() # The frontview image is upside-down. 122 | img = img.transpose(2, 0, 1) # CHW 123 | return img 124 | 125 | def _get_obs(self, obs_dict, verbose=False): 126 | assert self.image_key in obs_dict, "key {} not found in obs".format(self.image_key) 127 | img = self.preprocess_image(obs_dict[self.image_key]) 128 | res = dict(image=img) 129 | if self.get_other_obs: 130 | assert self.other_key in obs_dict, "key {} not found in obs".format(self.other_key) 131 | res['other'] = obs_dict[self.other_key] 132 | if self.num_frames_to_stack > 1: 133 | res = self._get_stacked_obs(res) 134 | return res 135 | 136 | def _get_stacked_obs(self, obs): 137 | if len(self.frame_queue) == 0: 138 | for _ in range(self.num_frames_to_stack): 139 | self.frame_queue.append(obs) 140 | else: 141 | self.frame_queue.append(obs) 142 | keys = obs.keys() 143 | res = {} 144 | for key in keys: 145 | res[key] = np.concatenate([frame[key] for frame in self.frame_queue]) 146 | return res 147 | 148 | def reset(self): 149 | if self.num_frames_to_stack > 1: 150 | self.frame_queue.clear() 151 | self.ob_dict = self.env.reset() 152 | #if self.dr: # Fix for resetting bug. Domain is not getting randmized for the obs coming from reset. 153 | action = np.zeros(self.action_dims) 154 | self.ob_dict, _, _, _ = self.env.step(action) 155 | obs = self._get_obs(self.ob_dict) 156 | return obs 157 | 158 | def step(self, action): 159 | try: 160 | ob_dict, reward, done, info = self.env.step(action) 161 | except MujocoException as e: 162 | print('MujocoException', e) 163 | print('Will skip this action') 164 | if self.ob_dict is not None: 165 | ob_dict = self.ob_dict 166 | reward = 0 167 | done = False 168 | info = {} 169 | self.ob_dict = ob_dict 170 | # Additional reward shaping. 171 | #if self.env_name == 'Door': 172 | # # Additional reward shaping for door angle. 
173 | # if reward < 1: 174 | # hinge_qpos = self.env.sim.data.qpos[self.env.hinge_qpos_addr] 175 | # reward += np.clip(0.5 * hinge_qpos / 0.3, 0, 0.5) 176 | obs = self._get_obs(ob_dict) 177 | return obs, reward, done, info 178 | 179 | def render(self, mode, **kwargs): 180 | if mode == 'rgb_array': 181 | image_list = [] 182 | for (cam_name, cam_w, cam_h, cam_d) in \ 183 | zip(self.env.camera_names, self.env.camera_widths, self.env.camera_heights, self.env.camera_depths): 184 | 185 | # Add camera observations to the dict 186 | camera_obs = self.env.sim.render( 187 | camera_name=cam_name, 188 | width=cam_w, 189 | height=cam_h, 190 | depth=cam_d 191 | ) 192 | if cam_d: 193 | img, depth = camera_obs 194 | camera_obs = np.concatenate([img, depth[:, :, None]], axis=2) 195 | image_list.append(camera_obs) 196 | image = np.concatenate(image_list, axis=1) 197 | return image # return RGB frame suitable for video 198 | elif mode == 'human': 199 | self.env.render() # pop up a window and render 200 | else: 201 | raise NotImplementedError 202 | 203 | def argument_parser(argument): 204 | """ Argument parser """ 205 | parser = argparse.ArgumentParser(description='Binder Network.') 206 | parser.add_argument('-c', '--config', default='', type=str, help='Training config') 207 | args = parser.parse_args(argument) 208 | return args 209 | 210 | 211 | 212 | def test2(): 213 | args = argument_parser(None) 214 | try: 215 | with open(args.config) as f: 216 | config = yaml.safe_load(f) 217 | except FileNotFoundError: 218 | print("Error opening specified config yaml at: {}. " 219 | "Please check filepath and try again.".format(args.config)) 220 | 221 | config = config['parameters'] 222 | seed = config['seed'] 223 | np.random.seed(seed) 224 | env = EnvironmentContainerRobosuite(config['env']) 225 | obs_dict = env.reset() 226 | action_low, action_high = env.get_action_limits() 227 | action_dims = env.get_action_dims() 228 | plt.figure(1) 229 | obs_list = [] 230 | for ii in range(2): 231 | obs = obs_dict['image'].transpose((1, 2, 0)) 232 | obs_list.append(obs) 233 | plt.clf() 234 | plt.imshow(obs) 235 | plt.suptitle('Image {}'.format(ii)) 236 | plt.pause(0.5) 237 | action = np.random.uniform(action_low, action_high, action_dims) 238 | obs_dict, reward, done, info = env.step(action) 239 | obs_list.append(np.abs(obs_list[-1] - obs_list[-2])) 240 | obs = np.concatenate(obs_list, axis=1) 241 | plt.imshow(obs) 242 | plt.axis('off') 243 | plt.show() 244 | input('Press enter') 245 | 246 | 247 | def test(): 248 | args = argument_parser(None) 249 | try: 250 | with open(args.config) as f: 251 | config = yaml.safe_load(f) 252 | except FileNotFoundError: 253 | print("Error opening specified config yaml at: {}. 
" 254 | "Please check filepath and try again.".format(args.config)) 255 | 256 | config = config['parameters'] 257 | seed = config['seed'] 258 | np.random.seed(seed) 259 | plt.figure(1) 260 | randomize_settings = [(False, False), (False, True), (True, True)] 261 | obs_list = [] 262 | for robot in ['Panda', 'Jaco']: 263 | for randomize_camera, randomize_other in randomize_settings: 264 | config['env']['domain_randomize'] = True 265 | config['env']['robosuite_config']['robots'] = [robot] 266 | config['env']['domain_randomization_config']['randomize_camera'] = randomize_camera 267 | config['env']['domain_randomization_config']['randomize_color'] = randomize_other 268 | config['env']['domain_randomization_config']['randomize_lighting'] = randomize_other 269 | env = EnvironmentContainerRobosuite(config['env'], seed=seed) 270 | env.reset() 271 | action_low, action_high = env.get_action_limits() 272 | action_dims = env.get_action_dims() 273 | action = np.random.uniform(action_low, action_high, action_dims) 274 | obs_dict, reward, done, info = env.step(action) 275 | obs = obs_dict['image'].transpose((1, 2, 0)) 276 | plt.clf() 277 | plt.imshow(obs) 278 | plt.pause(0.001) 279 | obs_list.append(obs) 280 | obs1 = np.concatenate(obs_list[:3], axis=0) 281 | obs2 = np.concatenate(obs_list[3:], axis=0) 282 | obs = np.concatenate([obs1, obs2], axis=1) 283 | plt.draw() 284 | plt.imshow(obs) 285 | plt.axis('off') 286 | plt.show() 287 | plt.savefig('/Users/nitish/Desktop/binder_figures/robosuite_supp_new.png', bbox_inches='tight') 288 | 289 | if __name__ == '__main__': 290 | test() 291 | -------------------------------------------------------------------------------- /environments.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved. 4 | # 5 | """ Interface to all environments. """ 6 | 7 | def make_environment(config, train=True, seed=None): 8 | env_name = config.get('name', 'dcs') 9 | if env_name == 'dcs': 10 | from environment_container_dcs import EnvironmentContainerDCS 11 | return EnvironmentContainerDCS(config, train=train, seed=seed) 12 | elif env_name == 'robosuite': 13 | from environment_container_robosuite import EnvironmentContainerRobosuite 14 | return EnvironmentContainerRobosuite(config, train=train, seed=seed) 15 | elif env_name == 'dcs_dmc_paired': 16 | from environment_container_dcs import EnvironmentContainerDCS_DMC_paired 17 | return EnvironmentContainerDCS_DMC_paired(config, train=train, seed=seed) 18 | else: 19 | raise ValueError 20 | -------------------------------------------------------------------------------- /modules.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved. 
4 | # 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import random 9 | import numpy as np 10 | from torch import distributions as pyd 11 | import math 12 | 13 | 14 | def weight_init(m): 15 | """Custom weight init for Conv2D and Linear layers.""" 16 | if isinstance(m, nn.Linear): 17 | nn.init.orthogonal_(m.weight.data) 18 | if m.bias is not None: 19 | m.bias.data.fill_(0.0) 20 | elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 21 | # delta-orthogonal init from https://arxiv.org/pdf/1806.05393.pdf 22 | assert m.weight.size(2) == m.weight.size(3) 23 | m.weight.data.fill_(0.0) 24 | if m.bias is not None: 25 | m.bias.data.fill_(0.0) 26 | mid = m.weight.size(2) // 2 27 | gain = nn.init.calculate_gain('relu') 28 | nn.init.orthogonal_(m.weight.data[:, :, mid, mid], gain) 29 | 30 | 31 | def act_func(act_str): 32 | act = None 33 | if act_str == 'relu': 34 | act = nn.ReLU 35 | elif act_str == 'elu': 36 | act = nn.ELU 37 | elif act_str == 'identity': 38 | act = nn.Identity 39 | elif act_str == 'tanh': 40 | act = nn.Tanh 41 | elif act_str == 'sigmoid': 42 | act = nn.Sigmoid 43 | else: 44 | raise ValueError('Unknown activation function.') 45 | return act 46 | 47 | 48 | class GRUCell(nn.Module): 49 | """ 50 | GRUCell with optional layer norm. 51 | """ 52 | def __init__(self, input_dims, hidden_dims, norm=False, update_bias=-1): 53 | super().__init__() 54 | self.input_dims = input_dims 55 | self.hidden_dims = hidden_dims 56 | self.linear = nn.Linear(input_dims + hidden_dims, 3 * hidden_dims, bias=not norm) 57 | self.norm = norm 58 | self.update_bias = update_bias 59 | if norm: 60 | self.norm_layer = nn.LayerNorm(3 * hidden_dims) 61 | 62 | def forward(self, inputs, state): 63 | gate_inputs = self.linear(torch.cat([inputs, state], dim=-1)) 64 | if self.norm: 65 | gate_inputs = self.norm_layer(gate_inputs) 66 | reset, cand, update = gate_inputs.chunk(3, -1) 67 | reset = torch.sigmoid(reset) 68 | cand = torch.tanh(reset * cand) 69 | update = torch.sigmoid(update + self.update_bias) 70 | output = update * cand + (1 - update) * state 71 | return output 72 | 73 | 74 | class FCNet(nn.Module): 75 | """ MLP with fully-connected layers.""" 76 | def __init__(self, config, in_features, out_features=None): 77 | super().__init__() 78 | layers = [] 79 | fc_act = act_func(config['fc_activation']) 80 | self.input_dims = in_features 81 | num_hids = config['fc_hiddens'] 82 | if out_features: 83 | num_hids.append(out_features) 84 | for i, num_hid in enumerate(num_hids): 85 | fc_layer = nn.Linear(in_features=in_features, out_features=num_hid) 86 | layers.append(fc_layer) 87 | if i < len(num_hids) - 1: 88 | layers.append(fc_act()) 89 | if config.get('fc_batch_norm', False): 90 | layers.append(nn.BatchNorm1d(num_hid)) 91 | else: 92 | output_act = act_func(config['output_activation']) 93 | layers.append(output_act()) 94 | in_features = num_hid 95 | if config.get('layer_norm_output', False): 96 | layers.append(nn.LayerNorm([in_features], elementwise_affine=config.get('layer_norm_affine', False))) 97 | self.fc_net = nn.Sequential(*layers) 98 | self.output_dims = in_features 99 | 100 | def forward(self, x): 101 | assert x.shape[-1] == self.input_dims, "Last dim is {} but should be {}".format(x.shape[-1], self.input_dims) 102 | orig_shape = list(x.shape) 103 | x = x.view(-1, self.input_dims) 104 | x = self.fc_net(x) 105 | orig_shape[-1] = self.output_dims 106 | x = x.view(*orig_shape) 107 | return x 108 | 109 | 110 | class CNN(nn.Module): 111 | """ CNN, optionally followed by 
fully connected layers.""" 112 | def __init__(self, config, chw): 113 | super().__init__() 114 | self.config = config 115 | channels, height, width = chw 116 | cnn_layers, channels, height, width = self.make_conv_layers((channels, height, width), config['conv_filters']) 117 | self.conv_net = nn.Sequential(*cnn_layers) 118 | num_features = channels * height * width 119 | if 'fc_hiddens' in config: 120 | fc_layers, num_features = self.make_fc_layers(num_features) 121 | self.fc_net = nn.Sequential(*fc_layers) 122 | else: 123 | self.fc_net = None 124 | self.output_dims = num_features 125 | 126 | def make_conv_layers(self, input_chw, conv_filters): 127 | channels, height, width = input_chw 128 | conv_act = act_func(self.config['conv_activation']) 129 | base_num_hid = self.config['base_num_hid'] 130 | layers = [] 131 | for i, filter_spec in enumerate(conv_filters): 132 | if len(filter_spec) == 3: # Padding is not specified. 133 | num_filters, kernel_size, stride = filter_spec 134 | padding = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2) 135 | elif len(filter_spec) == 4: 136 | num_filters, kernel_size, stride, padding = filter_spec 137 | num_filters = num_filters * base_num_hid 138 | conv_layer = nn.Conv2d(channels, out_channels=num_filters, kernel_size=kernel_size, stride=stride, 139 | padding=padding) 140 | height = (height - kernel_size[0] + 2 * padding[0]) // stride + 1 141 | width = (width - kernel_size[1] + 2 * padding[1]) // stride + 1 142 | layers.append(conv_layer) 143 | layers.append(conv_act()) 144 | if self.config['conv_batch_norm']: 145 | bn = nn.BatchNorm2d(num_filters) 146 | layers.append(bn) 147 | channels = num_filters 148 | return layers, channels, height, width 149 | 150 | def make_fc_layers(self, in_features): 151 | fc_act = act_func(self.config['fc_activation']) 152 | base_num_hid = self.config['base_num_hid'] 153 | layers = [] 154 | for i, num_hid in enumerate(self.config['fc_hiddens']): 155 | num_hid = num_hid * base_num_hid 156 | fc_layer = nn.Linear(in_features=in_features, out_features=num_hid) 157 | layers.append(fc_layer) 158 | if i < len(self.config['fc_hiddens']) - 1: 159 | layers.append(fc_act()) 160 | if self.config.get('fc_batch_norm', False): 161 | layers.append(nn.BatchNorm1d(num_hid)) 162 | else: 163 | if self.config.get('layer_norm_output', False): 164 | layers.append(nn.LayerNorm([num_hid], elementwise_affine=False)) 165 | output_act = act_func(self.config['output_activation']) 166 | layers.append(output_act()) 167 | in_features = num_hid 168 | return layers, in_features 169 | 170 | def forward(self, x): 171 | """ 172 | Args: 173 | x : (..., C, H, W) 174 | Output: 175 | x: (..., D) 176 | """ 177 | orig_shape = list(x.size()) 178 | x = x.view(-1, *orig_shape[-3:]) 179 | x = self.conv_net(x) 180 | if self.fc_net is not None: 181 | x = x.view(x.size(0), -1) # Flatten. 182 | x = self.fc_net(x) 183 | x = x.view(*orig_shape[:-3], x.shape[-1]) 184 | else: 185 | x = x.view(*orig_shape[:-3], *x.shape[-3:]) 186 | return x 187 | 188 | 189 | class TransposeCNN(nn.Module): 190 | """Decode images from a vector. 191 | """ 192 | 193 | def __init__(self, config, input_size): 194 | """Initializes a ConvDecoder instance. 195 | 196 | Args: 197 | input_size (int): Input size, usually feature size output from 198 | RSSM. 
199 | depth (int): Number of channels in the first conv layer 200 | act (Any): Activation for Encoder, default ReLU 201 | shape (List): Shape of observation input 202 | """ 203 | super().__init__() 204 | 205 | layers = [] 206 | in_features = input_size 207 | fc_act = act_func(config['fc_activation']) 208 | base_num_hid = config['base_num_hid'] 209 | for i, num_hid in enumerate(config['fc_hiddens']): 210 | num_hid = num_hid * base_num_hid 211 | fc_layer = nn.Linear(in_features=in_features, out_features=num_hid) 212 | layers.append(fc_layer) 213 | layers.append(fc_act()) 214 | if config.get('fc_batch_norm', False): 215 | layers.append(nn.BatchNorm1d(num_hid)) 216 | in_features = num_hid 217 | self.fc_net = nn.Sequential(*layers) 218 | 219 | filters = config['conv_filters'] 220 | act = act_func(config['conv_activation']) 221 | output_act = act_func(config['output_activation']) 222 | bn = config.get('conv_batch_norm', False) 223 | in_channels = in_features 224 | layers = [] 225 | in_size = [1, 1] 226 | for i, (out_channels, kernel, stride) in enumerate(filters): 227 | if i < len(filters) - 1: 228 | out_channels = out_channels * base_num_hid 229 | layer = nn.ConvTranspose2d(in_channels, out_channels, kernel, stride=stride) 230 | layers.append(layer) 231 | out_size = [kernel[0] + (in_size[0] - 1) * stride, kernel[1] + (in_size[1] - 1) * stride] 232 | if i < len(filters) - 1: # Don't put batch norm in the last layer, potentially different activation func. 233 | if act is not None: 234 | layers.append(act()) 235 | if bn: 236 | layers.append(nn.BatchNorm2d(out_channels)) 237 | else: 238 | layers.append(output_act()) 239 | in_channels = out_channels 240 | in_size = out_size 241 | 242 | self.conv_transpose = nn.Sequential(*layers) 243 | self.out_size = in_size 244 | self.out_channels = in_channels 245 | 246 | def forward(self, x): 247 | """ 248 | Args: 249 | x: (..., D) 250 | Output: 251 | x : (..., C, H, W) 252 | """ 253 | orig_shape = list(x.size()) 254 | x = x.view(-1, orig_shape[-1]) 255 | x = self.fc_net(x) 256 | C = x.shape[-1] 257 | x = x.view(-1, C, 1, 1) 258 | x = self.conv_transpose(x) 259 | out_shape = list(x.size()) 260 | res_shape = orig_shape[:-1] + out_shape[1:] 261 | x = x.view(*res_shape) 262 | return x 263 | 264 | 265 | class ObservationModel(nn.Module): 266 | """ 267 | Module that encapsulates the observation encoder (and optionally, decoder). 268 | """ 269 | def __init__(self, config, image_chw, other_dims=0): 270 | """ 271 | Inputs: 272 | image_chw: (channels, height, width) for the image observations. 273 | other_dims : int, any extra dimensions, e.g. proprioceptive state. 274 | """ 275 | super().__init__() 276 | 277 | # Gating network. 278 | self.use_gating_network = config.get('use_gating_network', False) 279 | if self.use_gating_network: 280 | self.gating_net = CNN(config['encoder_gating'], image_chw) 281 | self.gating_net_channels = image_chw[0] 282 | 283 | # Image encoder. 284 | self.encoder = CNN(config['encoder'], image_chw) 285 | self.output_dims = self.encoder.output_dims 286 | 287 | # (Optional) Proprioceptive state encoder. 
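        # When 'other_obs_model' is configured, low-dimensional observations
        # (e.g. proprioceptive state) are encoded by an MLP whose features are
        # concatenated with the image features in encode().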
288 | if 'other_obs_model' in config: 289 | print('Proprioceptive state has {} dims'.format(other_dims)) 290 | self.other_model = FCNet(config['other_obs_model'], other_dims) 291 | self.output_dims += self.other_model.output_dims 292 | else: 293 | self.other_model = None 294 | 295 | if 'decoder' in config: 296 | # Image predictor P(x_t | z_t, h_t) 297 | self.decoder = TransposeCNN(config['decoder'], self.output_dims) 298 | else: 299 | self.decoder = None 300 | 301 | self.apply(weight_init) 302 | 303 | def encode_pixels(self, x, encoder=None): 304 | if self.use_gating_network: 305 | orig_shape = list(x.shape) 306 | input_channels = orig_shape[-3] 307 | num_frames = input_channels // self.gating_net_channels 308 | # 9 input_channels = 3 frames * 3 (RGB) channels. 309 | # Gating net is shared across stacked frames. 310 | x = x.view(*orig_shape[:-3], num_frames, self.gating_net_channels, *orig_shape[-2:]) 311 | g = torch.sigmoid(self.gating_net(x)) 312 | assert g.shape[-3] == 1 # Gating should give one scalar per location. 313 | x = x * g 314 | x = x.view(*orig_shape) 315 | else: 316 | g = None 317 | 318 | if encoder is None: 319 | x = self.encoder(x) 320 | else: 321 | x = encoder(x) 322 | return x, g 323 | 324 | def encode(self, batch): 325 | """ 326 | Encode the observations. 327 | Inputs: 328 | batch: dict containing 'obs_image' and optionally, other observations. 329 | Returns: 330 | outputs: dict containing 'obs_encoding', and optionally 'obs_gating'. 331 | """ 332 | outputs = {} 333 | x, g = self.encode_pixels(batch['obs_image']) 334 | if self.use_gating_network: 335 | outputs['obs_gating'] = g 336 | 337 | if self.other_model is not None: 338 | other_obs = batch['obs_other'] 339 | other_encoding = self.other_model(other_obs) 340 | x = torch.cat([x, other_encoding], dim=-1) 341 | outputs['obs_features'] = x 342 | return outputs 343 | 344 | def decode(self, x): 345 | assert self.decoder is not None 346 | return self.decoder(x) 347 | 348 | def forward(self, x): 349 | return self.encode(x) 350 | 351 | 352 | class TanhTransform(pyd.transforms.Transform): 353 | domain = pyd.constraints.real 354 | codomain = pyd.constraints.interval(-1.0, 1.0) 355 | bijective = True 356 | sign = +1 357 | 358 | def __init__(self, cache_size=1): 359 | super().__init__(cache_size=cache_size) 360 | 361 | @staticmethod 362 | def atanh(x): 363 | return 0.5 * (x.log1p() - (-x).log1p()) 364 | 365 | def __eq__(self, other): 366 | return isinstance(other, TanhTransform) 367 | 368 | def _call(self, x): 369 | return x.tanh() 370 | 371 | def _inverse(self, y): 372 | # We do not clamp to the boundary here as it may degrade the performance of certain algorithms. 373 | # one should use `cache_size=1` instead 374 | return self.atanh(y) 375 | 376 | def log_abs_det_jacobian(self, x, y): 377 | # We use a formula that is more numerically stable, see details in the following link 378 | # https://github.com/tensorflow/probability/commit/ef6bb176e0ebd1cf6e25c6b5cecdd2428c22963f#diff-e120f70e92e6741bca649f04fcd907b7 379 | return 2. * (math.log(2.) - x - F.softplus(-2. 
* x)) 380 | 381 | 382 | class SquashedNormal(pyd.transformed_distribution.TransformedDistribution): 383 | def __init__(self, loc, scale): 384 | self.loc = loc 385 | self.scale = scale 386 | 387 | try: 388 | self.base_dist = pyd.Normal(loc, scale) 389 | except Exception as e: 390 | print(e) 391 | print('Loc mean: ', loc.mean()) 392 | print('Loc: ', loc) 393 | raise e 394 | transforms = [TanhTransform()] 395 | super().__init__(self.base_dist, transforms) 396 | 397 | @property 398 | def mean(self): 399 | mu = self.loc 400 | for tr in self.transforms: 401 | mu = tr(mu) 402 | return mu 403 | 404 | 405 | class TanhGaussianPolicy(nn.Module): 406 | def __init__(self, config, state_dims, action_dims): 407 | super().__init__() 408 | self.fc_net = FCNet(config, state_dims, action_dims * 2) 409 | self.log_std_min = config.get("policy_min_logstd", -10) 410 | self.log_std_max = config.get("policy_max_logstd", 2) 411 | self.tiny = 1.e-7 412 | self.clip = 1 - self.tiny 413 | self.apply(weight_init) 414 | 415 | def forward(self, x, sample=True): 416 | x = self.fc_net(x) 417 | mu, log_std = torch.chunk(x, 2, dim=-1) 418 | mu = 5 * torch.tanh(mu / 5) 419 | log_std = torch.sigmoid(log_std) 420 | log_std = self.log_std_min + (self.log_std_max - self.log_std_min) * log_std 421 | std = torch.exp(log_std) 422 | action_dist = SquashedNormal(mu, std) 423 | if sample: 424 | action_sample = action_dist.rsample() 425 | action_sample = action_sample.clamp(-self.clip, self.clip) 426 | else: 427 | action_sample = torch.tanh(mu) 428 | log_prob = action_dist.log_prob(action_sample).sum(dim=-1) 429 | return action_sample, log_prob, std 430 | 431 | 432 | class DoubleCritic(nn.Module): 433 | def __init__(self, config, state_dims, action_dims): 434 | super().__init__() 435 | self.critic1 = FCNet(config, state_dims + action_dims) 436 | self.critic2 = FCNet(config, state_dims + action_dims) 437 | self.state_dims = state_dims 438 | self.action_dims = action_dims 439 | self.apply(weight_init) 440 | 441 | def forward(self, states, actions): 442 | assert states.shape[-1] == self.state_dims 443 | assert actions.shape[-1] == self.action_dims 444 | inputs = torch.cat([states, actions], dim=-1) 445 | q1 = self.critic1(inputs).squeeze(-1) 446 | q2 = self.critic2(inputs).squeeze(-1) 447 | return q1, q2 448 | 449 | 450 | class ContrastivePrediction(nn.Module): 451 | def __init__(self, config, obs_dims): 452 | super().__init__() 453 | init_inverse_temp = config['inverse_temperature_init'] 454 | inverse_temp = torch.tensor([float(init_inverse_temp)]) 455 | self.inverse_temp = nn.parameter.Parameter(inverse_temp) 456 | self.output_dims = obs_dims 457 | self.softmax_over = config.get('softmax_over', 'both') # ['obs', 'symbol', 'both'] 458 | mask_type = config.get('mask_type', None) 459 | if mask_type is None: 460 | self.mask = None 461 | else: 462 | if mask_type == 'exclude_same_sequence': 463 | mask = self.get_exclude_same_sequence_mask(31, 32) 464 | elif mask_type == 'exclude_other_sequences': 465 | mask = self.get_exclude_other_sequences_mask(31, 32) 466 | else: 467 | raise Exception('Unknown mask type') 468 | self.register_buffer('mask', mask) 469 | 470 | def get_exclude_same_sequence_mask(self, T, B): 471 | """ 472 | Exclude other timesteps from the same sequence. 473 | mask[i, j] = True means replace that by -inf. 
474 | """ 475 | mask = torch.zeros(T, B, T, B) 476 | per_b_mask = 1 - torch.eye(T) # Mask other timesteps of the same sequence, but keep the same-timestep positive pair. 477 | for b in range(B): 478 | mask[:, b, :, b] = per_b_mask 479 | mask = mask.view(T*B, T*B) 480 | mask = mask == 1 481 | return mask 482 | 483 | def get_exclude_other_sequences_mask(self, T, B): 484 | """ 485 | Exclude other sequences in the batch. 486 | mask[i, j] = True means replace that by -inf. 487 | """ 488 | mask = torch.ones(T, B, T, B) 489 | for b in range(B): 490 | mask[:, b, :, b] = 0 491 | mask = mask.view(T*B, T*B) 492 | mask = mask == 1 493 | return mask 494 | 495 | def forward(self, x, y, train=False): 496 | """ 497 | Inputs: 498 | x: (..., D) encoder features. 499 | y: (..., D) decoder features. 500 | Outputs: 501 | loss: contrastive loss. 502 | """ 503 | x = x.view(-1, self.output_dims) # (T * B, D) 504 | y = y.view(-1, self.output_dims) # (T * B, D) 505 | inv_temp = F.softplus(self.inverse_temp) 506 | logits = inv_temp * torch.matmul(x, y.T) # (B', B') 507 | if self.mask is not None and train: 508 | logits[self.mask] = float('-inf') 509 | log_probs1 = F.log_softmax(logits, dim=1) 510 | log_probs2 = F.log_softmax(logits, dim=0) 511 | loss1 = -(log_probs1.diagonal().mean()) 512 | loss2 = -(log_probs2.diagonal().mean()) 513 | if self.softmax_over == 'symbol': 514 | loss = loss1 515 | elif self.softmax_over == 'obs': 516 | loss = loss2 517 | else: 518 | loss = (loss1 + loss2) / 2 519 | return loss 520 | -------------------------------------------------------------------------------- /replay_buffer.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved. 4 | # 5 | """ Replay Buffer for storing sequential data""" 6 | import random 7 | import torch 8 | from utils import torchify 9 | 10 | 11 | class SequenceReplayBuffer(object): 12 | def __init__(self, size=None): 13 | self.data = [] 14 | self.size = size 15 | self.index = 0 16 | 17 | def __len__(self): 18 | return len(self.data) 19 | 20 | def add(self, seq): 21 | """ 22 | seq : list of dict of key, value, where value is numpy array or float/int. 23 | """ 24 | seq = torchify(seq) # dict of (T, ...) 25 | if self.size is None or len(self.data) < self.size: 26 | self.data.append(seq) 27 | else: 28 | self.data[self.index] = seq 29 | self.index += 1 30 | if self.index == self.size: 31 | self.index = 0 32 | 33 | def sample(self, num_seq, seq_len=0): 34 | """ 35 | Sample a batch from the replay buffer. 36 | Args: 37 | num_seq: Batch size 38 | seq_len: Length of each sequence. Default=0 means pick the entire sequence (i.e. seq_len=T) 39 | Returns: 40 | res : dict of tensors of shape (num_seq, seq_len, *entity_shape) 41 | """ 42 | # Pick seq_ids. 43 | seq_count = len(self.data) 44 | inds = list(range(seq_count)) 45 | if num_seq < seq_count: 46 | inds = random.sample(inds, k=num_seq) 47 | elif num_seq > seq_count: 48 | inds = random.choices(inds, k=num_seq) 49 | batch = [] 50 | for ind in inds: 51 | seq = self.data[ind] 52 | key = list(seq.keys())[0] 53 | T = len(seq[key]) 54 | if seq_len <= 0 or T <= seq_len: 55 | seq_sample = seq 56 | else: 57 | start_pos = random.randint(0, T - seq_len) 58 | seq_sample = {k: v[start_pos:start_pos + seq_len] for k, v in seq.items()} 59 | batch.append(seq_sample) 60 | # pack the batch into a dict of tensors. 61 | keys = batch[0].keys() 62 | res = {} 63 | for key in keys: 64 | res[key] = torch.stack([sample[key] for sample in batch]) # (B, T, ...)
65 | return res 66 | 67 | def all(self): 68 | return self.sample(len(self.data)) 69 | -------------------------------------------------------------------------------- /setup.sh: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved. 4 | # 5 | apt-get update 6 | apt install -y less tmux psmisc curl git libgl1-mesa-dev libgl1-mesa-glx libglew-dev \ 7 | libosmesa6-dev software-properties-common net-tools unzip vim \ 8 | virtualenv wget xpra xserver-xorg-dev libglfw3-dev patchelf xvfb ffmpeg git 9 | 10 | # Download distraction videos. 11 | wget https://data.vision.ee.ethz.ch/csergi/share/davis/DAVIS-2017-trainval-480p.zip 12 | unzip -q DAVIS-2017-trainval-480p.zip 13 | 14 | # Install MuJoCo 15 | mkdir ~/.mujoco 16 | 17 | # Mujoco 2.1.0 for dm_control 18 | wget https://mujoco.org/download/mujoco210-linux-x86_64.tar.gz 19 | tar -xvzf mujoco210-linux-x86_64.tar.gz 20 | mv mujoco210 ~/.mujoco/mujoco210 21 | 22 | # Install MuJoCo 2.0 23 | wget https://www.roboti.us/download/mujoco200_linux.zip 24 | wget https://roboti.us/file/mjkey.txt 25 | unzip mujoco200_linux.zip 26 | mv mujoco200_linux ~/.mujoco/mujoco200_linux 27 | mv mjkey.txt ~/.mujoco/ 28 | ln -s ~/.mujoco/mujoco200_linux ~/.mujoco/mujoco200 29 | export LD_LIBRARY_PATH=$HOME/.mujoco/mujoco200/bin:$LD_LIBRARY_PATH 30 | 31 | # Put 10_nvidia.json in the right place. 32 | # This is needed to make the renderer use gpu. 33 | cp 10_nvidia.json /usr/share/glvnd/egl_vendor.d/ 34 | 35 | # Mujoco-py uses a very brittle test to determine whether to use gpu rendering. 36 | # It looks for a directory named /usr/lib/nvidia-xxx (but doesn't really need or use any libraries present there). 37 | # So we just create a dummy one here. 38 | mkdir -p /usr/lib/nvidia-000 39 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/lib/nvidia-000 40 | 41 | # Create conda environment. 42 | conda config --set remote_read_timeout_secs 600 43 | conda env create -f conda_env_robosuite.yml 44 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved. 4 | # 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from modules import TanhGaussianPolicy, FCNet, ObservationModel, DoubleCritic, ContrastivePrediction 9 | from world_model import WorldModel 10 | import numpy as np 11 | from replay_buffer import SequenceReplayBuffer 12 | import torch.optim 13 | import argparse 14 | import yaml 15 | import os 16 | from torch.utils.tensorboard import SummaryWriter 17 | import time 18 | import random 19 | from utils import write_video_mp4, torchify, add_weight_decay, generate_expt_id, crop_image_tensor, get_parameter_list 20 | from environments import make_environment 21 | from copy import deepcopy 22 | import sys 23 | 24 | 25 | class Trainer(object): 26 | """ Trainer for all models. """ 27 | def __init__(self, config, device, debug): 28 | self.config = config 29 | self.device = device 30 | self.debug = debug 31 | 32 | # Where should artifacts be written out. 
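# Resolved in order: config['artifact_dir'] if present, else the BOLT_ARTIFACT_DIR environment variable, else a local 'artifacts' directory.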
33 | artifact_dir = config.get('artifact_dir', os.environ.get('BOLT_ARTIFACT_DIR', 'artifacts')) 34 | if not debug: 35 | self.log_dir = os.path.join(artifact_dir, config['expt_id']) 36 | os.makedirs(self.log_dir, exist_ok=True) 37 | config_filename = os.path.join(self.log_dir, 'config.yaml') 38 | with open(config_filename, 'w') as f: 39 | yaml.dump(config, f, sort_keys=True) 40 | print('Results will be logged to {}'.format(self.log_dir)) 41 | self.tb_writer = SummaryWriter(self.log_dir) 42 | 43 | self.num_envs = self.config['num_envs'] 44 | self.num_val_envs = self.config['num_val_envs'] 45 | seed = self.config['seed'] 46 | self.train_env_containers = [make_environment(self.config['env'], train=True, seed=seed+i) for i in range(self.num_envs)] 47 | seed += self.num_envs 48 | self.val_env_containers = [make_environment(self.config['env'], train=False, seed=seed+i) for i in range(self.num_val_envs)] 49 | env = self.train_env_containers[0] 50 | self.action_repeat = env.get_action_repeat() 51 | action_dims = env.get_action_dims() 52 | self.obs_channels, self.obs_height, self.obs_width = env.get_obs_chw() 53 | self.obs_other_dims = env.get_obs_other_dims() 54 | 55 | # Setup the observation encoder. 56 | self.crop_height = config['crop_height'] 57 | self.crop_width = config['crop_width'] 58 | self.same_crop_across_time = self.config.get('same_crop_across_time', False) 59 | self.random_crop_padding = self.config.get('random_crop_padding', 0) 60 | chw = (self.obs_channels, self.crop_height, self.crop_width) 61 | self.observation_model = ObservationModel(config, chw, self.obs_other_dims) 62 | obs_dims = self.observation_model.output_dims 63 | 64 | # Setup Contrastive Prediction. 65 | self.contrastive_prediction = ContrastivePrediction(config['contrastive'], obs_dims) 66 | 67 | # Setup the recurrent dynamics model. 68 | if 'wm' in config: 69 | self.model = WorldModel(config['wm'], obs_dims, action_dims) 70 | state_dims = self.model.state_dims 71 | self.exclude_wm_loss = config.get('exclude_wm_loss', False) 72 | else: # For models like plain SAC. 73 | self.model = None 74 | state_dims = obs_dims 75 | self.exclude_wm_loss = True 76 | 77 | # Setup Actor and Critic. 78 | self.actor = TanhGaussianPolicy(config['actor'], state_dims, action_dims) 79 | self.critic = DoubleCritic(config['critic'], state_dims, action_dims) 80 | self.log_alpha = nn.parameter.Parameter(torch.tensor([float(self.config['initial_log_alpha'])], device=device)) 81 | self.target_entropy = self.config.get('target_entropy', -action_dims) 82 | 83 | # Initialization. 84 | if 'initial_model_path' in config: 85 | model_path = os.path.join(artifact_dir, config['initial_model_path']) 86 | self.load(model_path) 87 | 88 | if self.model is not None: 89 | self.model = self.model.to(device) 90 | self.observation_model = self.observation_model.to(device) 91 | self.contrastive_prediction = self.contrastive_prediction.to(device) 92 | self.actor = self.actor.to(device) 93 | self.critic = self.critic.to(device) 94 | 95 | self.has_momentum_encoder = config.get('momentum_encoder', False) 96 | 97 | # Set up optimizers. 98 | # Model optimizer. 
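# A single Adam optimizer updates the observation model, the world model (if any) and the contrastive head; the contrastive parameters get their own learning rate ('lr_inverse_temp') and no weight decay.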
99 | params = add_weight_decay(self.observation_model, config['weight_decay']) 100 | if self.model is not None: 101 | params.extend(add_weight_decay(self.model, config['weight_decay'])) 102 | contrastive_pred_params = list(self.contrastive_prediction.parameters()) 103 | params.append({'params': contrastive_pred_params, 'lr': config['lr_inverse_temp'], 'weight_decay': 0.0}) 104 | self.optimizer = torch.optim.Adam(params, lr=config['lr'], betas=(0.9, 0.999)) 105 | 106 | # Actor optimizer. 107 | actor_params = add_weight_decay(self.actor, config['weight_decay']) 108 | self.optimizer_actor = torch.optim.Adam(actor_params, lr=config['lr_actor'], betas=(config.get('momentum_actor', 0.9), 0.999)) 109 | 110 | # Critic optimizer. 111 | critic_params = add_weight_decay(self.critic, config['weight_decay']) 112 | if config.get('include_model_params_in_critic', False): # Include model params in critic. 113 | critic_params.extend(params) 114 | self.sac_detach_states = False 115 | else: 116 | self.sac_detach_states = True 117 | self.optimizer_critic = torch.optim.Adam(critic_params, lr=config['lr_critic'], betas=(config.get('momentum_critic', 0.9), 0.999)) 118 | self.critic_optimizer_parameter_list = get_parameter_list(self.optimizer_critic) 119 | self.target_critic = deepcopy(self.critic).to(device) 120 | 121 | # Adaptive temperature optimizer for SAC. 122 | self.optimizer_alpha = torch.optim.Adam([self.log_alpha], lr=config['lr_alpha'], 123 | betas=(config.get('momentum_alpha', 0.9), 0.999)) 124 | 125 | if self.has_momentum_encoder: 126 | self.target_encoder = deepcopy(self.observation_model.encoder) 127 | self.moco_dims = self.target_encoder.output_dims 128 | self.moco_W = nn.Parameter(torch.rand(self.moco_dims, self.moco_dims, device=self.device)) 129 | curl_params = list(self.observation_model.encoder.parameters()) 130 | curl_params.append(self.moco_W) 131 | self.curl_optimizer = torch.optim.Adam(curl_params, lr=config['lr_curl'], betas=(0.9, 0.999)) 132 | 133 | self.optimizer_parameter_list = get_parameter_list(self.optimizer) 134 | 135 | # Whether SAC uses the mean or sampled state. 136 | self.sac_deterministic_state = config.get('sac_deterministic_state', False) 137 | 138 | # Whether to decode from the prior or posterior. 139 | self.recon_from_prior = self.config.get('recon_from_prior', False) 140 | 141 | # For saving the best model. 142 | self.best_loss = None 143 | 144 | def save(self, step, loss): 145 | """ Save a checkpoint. 
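Checkpoints are written to <log_dir>/checkpoint_<step>/model.pt; whenever the given loss improves on the best seen so far, the same state is also saved as <log_dir>/best_model.pt. This is a no-op in debug mode.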
""" 146 | if self.debug: 147 | return 148 | checkpoint_dir = os.path.join(self.log_dir, 'checkpoint_%08d' % step) 149 | if not os.path.isdir(checkpoint_dir): 150 | os.makedirs(checkpoint_dir) 151 | filename = os.path.join(checkpoint_dir, 'model.pt') 152 | print('Saved model to {}'.format(filename)) 153 | info = { 154 | 'observation_model_state_dict': self.observation_model.state_dict(), 155 | 'contrastive_prediction_state_dict': self.contrastive_prediction.state_dict(), 156 | 'actor_state_dict': self.actor.state_dict(), 157 | 'critic_state_dict': self.critic.state_dict(), 158 | 'log_alpha': self.log_alpha.item(), 159 | 'loss': loss, 160 | 'step': step 161 | } 162 | if self.model is not None: 163 | info['model_state_dict'] = self.model.state_dict() 164 | torch.save(info, filename) 165 | if self.best_loss is None or loss < self.best_loss: 166 | self.best_loss = loss 167 | filename = os.path.join(self.log_dir, 'best_model.pt') 168 | print('Saved model to {}'.format(filename)) 169 | torch.save(info, filename) 170 | 171 | def load(self, filename, skip_actor_critic=False): 172 | info = torch.load(filename, map_location=torch.device('cpu')) 173 | missing, unexpected = self.observation_model.load_state_dict(info['observation_model_state_dict'], strict=False) 174 | print('Missing keys', missing) 175 | print('Unexpected keys', unexpected) 176 | if 'contrastive_prediction_state_dict' in info: 177 | self.contrastive_prediction.load_state_dict(info['contrastive_prediction_state_dict']) 178 | if not skip_actor_critic: 179 | self.actor.load_state_dict(info['actor_state_dict']) 180 | self.critic.load_state_dict(info['critic_state_dict']) 181 | self.log_alpha[0] = info['log_alpha'] 182 | if self.model is not None: 183 | self.model.load_state_dict(info['model_state_dict']) 184 | 185 | def normalize(self, obs): 186 | return obs.float() / 255 187 | 188 | def unnormalize(self, obs): 189 | obs = obs[..., -3:, :, :] # Select the last three channels. Sometimes we have stacked frames. 190 | return torch.clamp(obs * 255, 0, 255).to(torch.uint8) 191 | 192 | def forward_prop(self, batch, decoding_for_viz=False): 193 | """ 194 | Fprop the batch through the model. 195 | """ 196 | outputs = self.observation_model(batch) # (T, B, dims) 197 | batch['obs_features'] = outputs['obs_features'] 198 | 199 | # Fprop through the world model. 200 | outputs.update(self.model(batch)) 201 | 202 | if self.model.decoder is not None: # If the world model uses decoding. 203 | if self.recon_from_prior: 204 | obs_features_recon = outputs['obs_features_recon_prior'] 205 | else: 206 | obs_features_recon = outputs['obs_features_recon_post'] 207 | outputs['obs_features_recon'] = obs_features_recon 208 | 209 | if self.observation_model.decoder is not None: # If decoding all the way to pixels. 210 | if self.config.get('detach_pixel_decoder', False): 211 | obs_features_recon = obs_features_recon.detach() 212 | obs_recon = self.observation_model.decode(obs_features_recon) 213 | outputs['obs_recon'] = obs_recon 214 | if decoding_for_viz: 215 | # We want to visualize recon from both prior and posterior latent state. 216 | # One of them is computed above, so the other is computed here. 
217 | if self.recon_from_prior: 218 | outputs['obs_recon_prior'] = obs_recon 219 | obs_features_recon = outputs['obs_features_recon_post'].detach() 220 | outputs['obs_recon_post'] = self.observation_model.decode(obs_features_recon) 221 | else: 222 | outputs['obs_recon_post'] = obs_recon 223 | obs_features_recon = outputs['obs_features_recon_prior'].detach() 224 | outputs['obs_recon_prior'] = self.observation_model.decode(obs_features_recon) 225 | return outputs 226 | 227 | def loss_reward(self, batch, outputs, loss, metrics): 228 | loss_scales = self.config['loss_scales'] 229 | if loss_scales['eta_r'] == 0: 230 | metrics['loss_reward'] = 0. 231 | else: 232 | reward_prediction = outputs['reward_prediction'][1:] # The reward is the reward at time t. 233 | reward = batch['reward'][:-1] # reward at index t corresponds to state t+1. 234 | loss_reward = nn.functional.smooth_l1_loss(reward_prediction, reward).mean() 235 | metrics['loss_reward'] = loss_reward.item() 236 | loss = loss + loss_reward * loss_scales['eta_r'] 237 | return loss, metrics 238 | 239 | def loss_fwd_dynamics(self, batch, outputs, loss, metrics): 240 | loss_scales = self.config['loss_scales'] 241 | if loss_scales['eta_fwd'] == 0: 242 | metrics['loss_fwd'] = 0. 243 | else: 244 | posterior_detached = {k: v.detach() for k, v in outputs['posterior'].items()} 245 | 246 | # skip t=0, because prior is uninformative there. No reason why posterior should match that. 247 | loss_fwd = self.model.dynamics.compute_forward_dynamics_loss(posterior_detached, outputs['prior'])[1:] 248 | loss_fwd = loss_fwd.mean() 249 | 250 | if 'eta_q' in loss_scales and loss_scales['eta_q'] > 0.0: 251 | eta_q = loss_scales['eta_q'] 252 | prior_detached = {k: v.detach() for k, v in outputs['prior'].items()} 253 | loss_fwd_q = self.model.dynamics.compute_forward_dynamics_loss(outputs['posterior'], prior_detached)[1:] 254 | loss_fwd_q = loss_fwd_q.mean() 255 | loss_fwd = (1 - eta_q) * loss_fwd + loss_scales['eta_q'] * loss_fwd_q 256 | metrics['loss_fwd'] = loss_fwd.item() 257 | loss = loss + loss_fwd * loss_scales['eta_fwd'] 258 | return loss, metrics 259 | 260 | def loss_contrastive(self, batch, outputs, train, loss, metrics): 261 | loss_scales = self.config['loss_scales'] 262 | eta_s = loss_scales['eta_s'] 263 | if eta_s > 0: 264 | x = outputs['obs_features'] 265 | y = outputs['obs_features_recon'] 266 | if self.recon_from_prior: # Skip t=0 when reconstructing from prior, because prior knows nothing at t=0. 267 | x = x[1:] 268 | y = y[1:] 269 | loss_contrastive = self.contrastive_prediction(x, y, train=train) 270 | loss = loss + loss_contrastive * eta_s 271 | metrics['loss_contrastive'] = loss_contrastive.item() 272 | return loss, metrics 273 | 274 | def loss_observation_recon(self, batch, outputs, loss, metrics): 275 | if 'obs_recon' not in outputs: 276 | return loss, metrics 277 | obs_recon = outputs['obs_recon'] 278 | obs = batch.get('obs_image_clean', batch['obs_image']) 279 | if self.recon_from_prior: # Skip t=0. 
280 | obs_recon = obs_recon[1:] 281 | obs = obs[1:] 282 | loss_obs = nn.functional.smooth_l1_loss(obs_recon, obs).mean() 283 | loss_scales = self.config['loss_scales'] 284 | metrics['loss_obs'] = loss_obs.item() 285 | loss = loss + loss_scales['eta_x'] * loss_obs 286 | return loss, metrics 287 | 288 | def loss_inv_dynamics(self, batch, outputs, loss, metrics): 289 | if 'action_prediction' not in outputs: 290 | return loss, metrics 291 | loss_scales = self.config['loss_scales'] 292 | if loss_scales['eta_inv'] == 0: 293 | metrics['loss_inverse_dynamics'] = 0. 294 | else: 295 | action_prediction = outputs['action_prediction'] # (T-1, B, a_dims) tanh valued. 296 | action = batch['action'][:-1] 297 | loss_inverse_dynamics = 0.5 * ((action - action_prediction) ** 2).sum(dim=-1).mean() 298 | loss = loss + loss_scales['eta_inv'] * loss_inverse_dynamics 299 | metrics['loss_inverse_dynamics'] = loss_inverse_dynamics.item() 300 | return loss, metrics 301 | 302 | def compute_loss(self, batch, train, decoding_for_viz=False): 303 | outputs = self.forward_prop(batch, decoding_for_viz=decoding_for_viz) 304 | metrics = {} 305 | loss = 0 306 | loss, metrics = self.loss_reward(batch, outputs, loss, metrics) 307 | loss, metrics = self.loss_inv_dynamics(batch, outputs, loss, metrics) 308 | loss, metrics = self.loss_fwd_dynamics(batch, outputs, loss, metrics) 309 | loss, metrics = self.loss_contrastive(batch, outputs, train, loss, metrics) 310 | loss, metrics = self.loss_observation_recon(batch, outputs, loss, metrics) 311 | metrics['loss_total'] = loss.item() 312 | return loss, metrics, outputs 313 | 314 | def update_curl(self, batch, step, heavy_logging=False): 315 | with torch.no_grad(): 316 | f_k, _ = self.observation_model.encode_pixels(batch['obs_image_2'], encoder=self.target_encoder) 317 | f_k = f_k.detach().view(-1, self.moco_dims) 318 | f_q, _ = self.observation_model.encode_pixels(batch['obs_image']) 319 | f_q = f_q.view(-1, self.moco_dims) 320 | f_proj = torch.matmul(f_k, self.moco_W) 321 | logits = torch.matmul(f_q, f_proj.T) 322 | log_probs = F.log_softmax(logits, dim=1) 323 | loss = -(log_probs.diagonal().mean()) 324 | metrics = {'moco_loss' : loss.item()} 325 | 326 | self.curl_optimizer.zero_grad() 327 | loss.backward() 328 | self.curl_optimizer.step() 329 | 330 | # Do momentum update. 331 | tau = self.config.get('update_target_encoder_tau', 1) 332 | self.update_target(self.target_encoder, self.observation_model.encoder, tau) 333 | 334 | return metrics 335 | 336 | def update_world_model(self, batch, step, heavy_logging=False): 337 | """ 338 | Update the world model. 
339 | batch : Dict containing keys ('action', 'obs_image', 'reward', etc) 340 | 'action' : (T, B action_dims) 341 | 'obs_image' : (T, B, C, H, W) 342 | 'reward': (T, B) 343 | """ 344 | loss, metrics, outputs = self.compute_loss(batch, train=True, decoding_for_viz=heavy_logging) 345 | self.optimizer.zero_grad() 346 | loss.backward() 347 | if 'max_grad_norm_wm' in self.config: 348 | grad_norm = torch.nn.utils.clip_grad_norm_(self.optimizer_parameter_list, self.config['max_grad_norm_wm']) 349 | metrics['grad_norm_wm'] = grad_norm.item() 350 | self.optimizer.step() 351 | 352 | if step % self.config['print_every'] == 0: 353 | loss_str = ' '.join(['{}: {:.2f}'.format(k, v) for k, v in sorted(metrics.items())]) 354 | print('Step {} {}'.format(step, loss_str)) 355 | if not self.debug and self.tb_writer is not None: 356 | for k, v in metrics.items(): 357 | self.tb_writer.add_scalar('metrics/{}'.format(k), v, step) 358 | if heavy_logging: 359 | max_B = 16 360 | self.tb_writer.add_video('obs/input', 361 | self.unnormalize(batch['obs_image']).transpose(0, 1)[:max_B], step) 362 | if self.observation_model.decoder is not None: 363 | self.tb_writer.add_video('obs/recon', 364 | self.unnormalize(outputs['obs_recon']).transpose(0, 1)[:max_B], step) 365 | self.tb_writer.add_video('obs/recon_post', 366 | self.unnormalize(outputs['obs_recon_post']).transpose(0, 1)[:max_B], step) 367 | self.tb_writer.add_video('obs/recon_prior', 368 | self.unnormalize(outputs['obs_recon_prior']).transpose(0, 1)[:max_B], step) 369 | return metrics, outputs 370 | 371 | def log_video(self, video_tag, frames, step): 372 | """ 373 | Log a video to disk. 374 | Args: 375 | frames : List of (B, T, C, H, W) 376 | step: training step. 377 | video_tag: tag used for logging into tensorboard and as dir name for disk. 378 | """ 379 | self.tb_writer.add_video(video_tag, frames, step) 380 | 381 | B, T, C, H, W = list(frames.shape) 382 | frames = frames.permute(1, 2, 3, 0, 4).contiguous().view(T, C, H, B*W) # Stack batch along width. 383 | video_dir = os.path.join(self.log_dir, video_tag) 384 | os.makedirs(video_dir, exist_ok=True) 385 | filename = os.path.join(video_dir, 'video_%08d.mp4' % step) 386 | write_video_mp4(filename, frames) 387 | 388 | def validate(self, step): 389 | self.observation_model.eval() 390 | if self.model is not None: 391 | self.model.eval() 392 | self.actor.eval() 393 | self.critic.eval() 394 | tic = time.time() 395 | metrics = {} 396 | # Collect data. One episode in each val environment. 397 | replay_buffer = SequenceReplayBuffer() 398 | num_episodes_per_val_env_for_reward = self.config.get('num_episodes_per_val_env_for_reward', 10) 399 | sample_policy = self.config.get('val_stochastic_policy', False) 400 | if sample_policy: 401 | print('Using stochastic policy for val') 402 | episode_reward = self.collect_data_from_actor(replay_buffer, 403 | num_episodes_per_env=num_episodes_per_val_env_for_reward, 404 | train=False, sample_policy=sample_policy) 405 | metrics['episode_reward'] = episode_reward 406 | 407 | # Take the first few episodes for computing the rest of the metrics. They are expensive to compute. 
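# When a pixel decoder is available, validation also rolls the prior forward from timestep 'rollout_prior_init_t' and decodes it to pixels, to visualize open-loop predictions (see below).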
408 | num_episodes_for_model = self.config.get('num_episodes_val_for_model', 5) 409 | batch = replay_buffer.sample(num_episodes_for_model) 410 | batch = self.prep_batch(batch, random_crop=False) 411 | steps_per_episode = self.config['episode_steps'] // self.action_repeat 412 | 413 | if not self.exclude_wm_loss: 414 | with torch.no_grad(): 415 | loss, model_metrics, outputs = self.compute_loss(batch, train=False, decoding_for_viz=True) 416 | metrics.update(model_metrics) 417 | # Generate rollout from prior. 418 | if self.observation_model.decoder is not None: 419 | init_t = self.config['rollout_prior_init_t'] 420 | assert 0 < init_t < steps_per_episode - 1 421 | init_state = dict([(k, v[init_t-1]) for k, v in outputs['posterior'].items()]) 422 | prior = self.model.dynamics.rollout_prior(init_state, batch['action'][init_t:, ...], deterministic=False) 423 | # Decode to images. 424 | latent = self.model.dynamics.get_state(prior, deterministic=False) 425 | obs_recon_imagined = self.observation_model.decode(self.model.decoder(latent)) 426 | # Add the first init_t images from the posterior. 427 | obs_recon_imagined = torch.cat([outputs['obs_recon_prior'][:init_t, :], obs_recon_imagined], dim=0) 428 | 429 | elif self.observation_model.use_gating_network: # Even if model is None, we want outputs to have gating. 430 | with torch.no_grad(): 431 | outputs = self.observation_model(batch) # (T, B, dims) # Used to visualize gating. 432 | 433 | toc = time.time() 434 | metrics.update({ 435 | 'timing': toc - tic, 436 | }) 437 | 438 | loss_str = ' '.join(['{}: {:.2f}'.format(k, v) for k, v in sorted(metrics.items())]) 439 | print('Val Iter {} {}'.format(step, loss_str)) 440 | if not self.debug and self.tb_writer is not None: 441 | for k, v in metrics.items(): 442 | self.tb_writer.add_scalar('val_metrics/{}'.format(k), v, step) 443 | obs = self.unnormalize(batch['obs_image']).transpose(0, 1) # (B, T, C, H, W) 444 | if self.observation_model.use_gating_network: 445 | obs_gating = outputs['obs_gating'].transpose(0, 1) # (B, T, F, 1, H, W) 446 | obs_gating = obs_gating[:, :, -1, :, :, :] # The gating for the last frame. 447 | obs_gating = (obs_gating * 255).to(torch.uint8) 448 | obs_gating = obs_gating.expand_as(obs).contiguous() # replicate along RGB. 
449 | obs = torch.cat([obs, obs_gating], dim=3) 450 | if self.model is not None and self.observation_model.decoder is not None: 451 | obs_recon = self.unnormalize(outputs['obs_recon']).transpose(0, 1) 452 | obs_recon_post = self.unnormalize(outputs['obs_recon_post']).transpose(0, 1) 453 | obs_recon_prior = self.unnormalize(outputs['obs_recon_prior']).transpose(0, 1) 454 | obs_recon_imagined = self.unnormalize(obs_recon_imagined).transpose(0, 1) 455 | obs = torch.cat([obs, obs_recon, obs_recon_post, obs_recon_prior, obs_recon_imagined], dim=3) 456 | self.log_video('obs/val', obs, step) 457 | return -episode_reward 458 | 459 | def collect_data_random_policy(self, replay_buffer, num_episodes_per_env=1, train=True): 460 | steps_per_episode = self.config['episode_steps'] // self.action_repeat 461 | env_containers = self.train_env_containers if train else self.val_env_containers 462 | total_reward = 0 463 | for env_container in env_containers: 464 | action_low, action_high = env_container.get_action_limits() 465 | action_dims = env_container.get_action_dims() 466 | for _ in range(num_episodes_per_env): 467 | obs = env_container.reset() 468 | seq = [] 469 | for _ in range(steps_per_episode): 470 | action = np.random.uniform(action_low, action_high, action_dims) 471 | next_obs, reward, _, _ = env_container.step(action) 472 | seq.append(dict(obs=obs, action=action, reward=reward)) 473 | obs = next_obs 474 | total_reward += reward 475 | replay_buffer.add(seq) 476 | avg_reward = total_reward / (num_episodes_per_env * len(env_containers)) 477 | return avg_reward 478 | 479 | def prep_batch(self, batch, random_crop=False): 480 | """ Prepare batch of data for input to the model. 481 | Inputs: 482 | batch : Dict containing 'obs', etc. 483 | Returns: 484 | batch: Same dict, but with images randomly cropped, moved to GPU, normalized. 485 | """ 486 | for key in batch.keys(): 487 | batch[key] = batch[key].to(self.device) 488 | obs_image_cropped = crop_image_tensor(batch['obs_image'], self.crop_height, self.crop_width, 489 | random_crop=random_crop, 490 | same_crop_across_time=self.same_crop_across_time, 491 | padding=self.random_crop_padding) 492 | if self.has_momentum_encoder: 493 | batch['obs_image_2'] = crop_image_tensor(batch['obs_image'], self.crop_height, self.crop_width, 494 | random_crop=random_crop, 495 | same_crop_across_time=self.same_crop_across_time, 496 | padding=self.random_crop_padding) 497 | 498 | if 'obs_image_clean' in batch: # When we have paired distraction-free and distracting obs. 499 | batch['obs_image_clean'] = crop_image_tensor(batch['obs_image_clean'], self.crop_height, self.crop_width, random_crop=False, same_crop_across_time=True, padding=0) 500 | else: 501 | batch['obs_image_clean'] = crop_image_tensor(batch['obs_image'], self.crop_height, self.crop_width, random_crop=False, same_crop_across_time=True, padding=0) 502 | 503 | batch['obs_image'] = obs_image_cropped 504 | if len(batch['obs_image'].shape) == 5: # (B, T, C, H, W) -> (T, B, C, H, W) 505 | swap_first_two_dims = True 506 | else: # (B, C, H, W) -> no change. 
507 | swap_first_two_dims = False 508 | for key in batch.keys(): 509 | if swap_first_two_dims: 510 | batch[key] = batch[key].transpose(0, 1) 511 | batch[key] = batch[key].contiguous().float().detach() 512 | batch['obs_image'] = self.normalize(batch['obs_image']) 513 | if 'obs_image_clean' in batch: 514 | batch['obs_image_clean'] = self.normalize(batch['obs_image_clean']) 515 | if 'obs_image_2' in batch: 516 | batch['obs_image_2'] = self.normalize(batch['obs_image_2']) 517 | return batch 518 | 519 | def collect_data_from_actor(self, replay_buffer, num_episodes_per_env=1, train=True, sample_policy=True): 520 | steps_per_episode = self.config['episode_steps'] // self.action_repeat 521 | self.observation_model.eval() 522 | if self.model is not None: 523 | self.model.eval() 524 | self.actor.eval() 525 | reward_total = 0 526 | env_containers = self.train_env_containers if train else self.val_env_containers 527 | num_env = len(env_containers) 528 | 529 | for _ in range(num_episodes_per_env): 530 | seq_list = [] 531 | obs_list = [] 532 | for env_container in env_containers: 533 | obs = env_container.reset() 534 | seq_list.append(list()) 535 | obs_list.append(dict(obs=obs)) 536 | posterior = None 537 | action = None 538 | for _ in range(steps_per_episode): 539 | # Find the action to take for a batch of environments. 540 | batch = torchify(obs_list) # Dict of (B, ...) 541 | batch = self.prep_batch(batch, random_crop=False) 542 | outputs = self.observation_model(batch) 543 | obs_features = outputs['obs_features'] 544 | if self.model is not None: # If using a dynamics model. 545 | latent, posterior = self.model.forward_one_step(obs_features, posterior, action, 546 | deterministic_latent=self.sac_deterministic_state) 547 | else: 548 | latent = obs_features 549 | action, _, _ = self.actor(latent, sample=sample_policy) 550 | action_npy = action.detach().cpu().numpy() # (B, a_dims) 551 | 552 | # Step each environment with the computed action. 553 | for i, env_container in enumerate(env_containers): 554 | current_action = action_npy[i] 555 | obs, reward, _, _ = env_container.step(current_action) 556 | seq_list[i].append(dict(obs=obs_list[i]['obs'], action=current_action, reward=reward)) 557 | obs_list[i]['obs'] = obs 558 | reward_total += reward 559 | for seq in seq_list: 560 | replay_buffer.add(seq) 561 | episode_reward = reward_total / (num_env * num_episodes_per_env) 562 | return episode_reward 563 | 564 | def update_target(self, target, critic, tau): 565 | target_params_dict = dict(target.named_parameters()) 566 | for n, p in critic.named_parameters(): 567 | target_params_dict[n].data.copy_( 568 | (1 - tau) * target_params_dict[n] + tau * p 569 | ) 570 | 571 | def update_actor_critic_sac(self, batch, step, heavy_logging=False): 572 | """ 573 | Inputs: 574 | batch : Dict containing keys ('action', 'obs', 'reward') 575 | 'action' : (T, B, action_dims) 576 | 'obs' : (T, B, C, H, W) 577 | 'reward': (T, B) 578 | """ 579 | metrics = {} 580 | 581 | outputs = self.observation_model(batch) # (T, B, dims) 582 | obs_features = outputs['obs_features'] 583 | batch['obs_features'] = obs_features 584 | if self.model is not None: 585 | outputs = self.model(batch) # Dict containing prior (stoch, logits), posterior(..) 586 | states = self.model.dynamics.get_state(outputs['posterior'], 587 | deterministic=self.sac_deterministic_state) 588 | else: 589 | states = obs_features 590 | 591 | # Update critic (potentially including the encoder).
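# Soft actor-critic target, computed without gradients for each timestep t: a' ~ pi(.|s_{t+1}), y_t = r_t + gamma * (min(Q1'(s_{t+1}, a'), Q2'(s_{t+1}, a')) - alpha * log pi(a'|s_{t+1})), where Q1', Q2' are the target critics. Both critics are then regressed towards y_t with an MSE loss.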
592 | current_states = states[:-1] 593 | if self.sac_detach_states: 594 | current_states = current_states.detach() 595 | current_actions = batch['action'][:-1] 596 | reward = batch['reward'][:-1] # (T-1, B) 597 | next_states = states[1:].detach() 598 | alpha = torch.exp(self.log_alpha).detach() 599 | gamma = self.config['gamma'] 600 | with torch.no_grad(): 601 | if torch.isnan(next_states).any(): 602 | raise Exception('Next states contains nan') 603 | next_actions, next_action_log_probs, _ = self.actor(next_states) 604 | target_q1, target_q2 = self.target_critic(next_states, next_actions) 605 | target_v = torch.min(target_q1, target_q2) - alpha * next_action_log_probs 606 | target_q = reward + gamma * target_v 607 | q1, q2 = self.critic(current_states, current_actions) # (T-1, B) 608 | critic_loss = F.mse_loss(q1, target_q) + F.mse_loss(q2, target_q) 609 | self.optimizer_critic.zero_grad() 610 | critic_loss.backward() 611 | if 'max_grad_norm_critic' in self.config: 612 | grad_norm = torch.nn.utils.clip_grad_norm_(self.critic.parameters(), self.config['max_grad_norm_critic']) 613 | metrics['grad_norm_critic'] = grad_norm.item() 614 | self.optimizer_critic.step() 615 | 616 | # Update actor. 617 | current_states_detached = current_states.detach() # Actor loss does not backpropagate into encoder or dynamics. 618 | policy_actions, policy_action_log_probs, policy_action_std = self.actor(current_states_detached) # (T-1, B, action_dims) 619 | q1, q2 = self.critic(current_states_detached, policy_actions) 620 | q = torch.min(q1, q2) 621 | q_loss = -q.mean() 622 | entropy_loss = policy_action_log_probs.mean() 623 | entropy_loss_wt = torch.exp(self.log_alpha).detach() 624 | actor_loss = q_loss + entropy_loss_wt * entropy_loss 625 | self.optimizer_actor.zero_grad() 626 | actor_loss.backward() 627 | if 'max_grad_norm_actor' in self.config: 628 | grad_norm = torch.nn.utils.clip_grad_norm_(self.actor.parameters(), self.config['max_grad_norm_actor']) 629 | metrics['grad_norm_actor'] = grad_norm.item() 630 | self.optimizer_actor.step() 631 | 632 | # Update alpha (adaptive entropy loss wt) 633 | alpha_loss = -(torch.exp(self.log_alpha) * (self.target_entropy + entropy_loss.detach())) 634 | self.optimizer_alpha.zero_grad() 635 | alpha_loss.backward() 636 | if 'max_grad_norm_log_alpha' in self.config: 637 | grad_norm = torch.nn.utils.clip_grad_norm_([self.log_alpha], self.config['max_grad_norm_log_alpha']) 638 | metrics['grad_norm_log_alpha'] = grad_norm.item() 639 | self.optimizer_alpha.step() 640 | if 'max_log_alpha' in self.config: 641 | with torch.no_grad(): 642 | self.log_alpha.clamp_(max=self.config['max_log_alpha']) 643 | 644 | if step % self.config['update_target_critic_after'] == 0: 645 | tau = self.config.get('update_target_critic_tau', 1) 646 | self.update_target(self.target_critic, self.critic, tau) 647 | 648 | metrics.update({ 649 | 'critic_loss': critic_loss.item(), 650 | 'actor_loss': actor_loss.item(), 651 | 'q_loss': q_loss.item(), 652 | 'entropy_loss': entropy_loss.item(), 653 | 'log_alpha': self.log_alpha.item(), 654 | }) 655 | if not self.debug and self.tb_writer is not None: 656 | for k, v in metrics.items(): 657 | self.tb_writer.add_scalar('rl_metrics/{}'.format(k), v, step) 658 | if heavy_logging: 659 | self.tb_writer.add_histogram('rl_metrics/reward', reward.view(-1), step) 660 | self.tb_writer.add_histogram('rl_metrics/q_targets', target_q.view(-1), step) 661 | self.tb_writer.add_histogram('rl_metrics/critic_scores', q.view(-1), step) 662 | 
self.tb_writer.add_histogram('rl_metrics/action', policy_actions.view(-1), step) 663 | self.tb_writer.add_histogram('rl_metrics/action_log_probs', policy_action_log_probs.view(-1), step) 664 | self.tb_writer.add_histogram('rl_metrics/action_std', policy_action_std.view(-1), step) 665 | return metrics 666 | 667 | def train(self): 668 | """ Train the model.""" 669 | # Setup replay buffer. 670 | steps_per_episode = self.config['episode_steps'] // self.action_repeat 671 | replay_buffer_size = self.config['replay_buffer_size'] 672 | num_episodes_in_replay_buffer = replay_buffer_size // steps_per_episode 673 | replay_buffer = SequenceReplayBuffer(size=num_episodes_in_replay_buffer) 674 | 675 | # Find out how many data collection iterations to do use. 676 | max_steps = self.config['max_steps'] // self.action_repeat 677 | num_iters = max_steps // (self.num_envs * steps_per_episode) 678 | 679 | # How many gradients updates per iteration. 680 | num_updates_per_iter = int(steps_per_episode * self.config.get('update_frequency_factor', 1.0)) 681 | 682 | random_crop = self.config.get('random_crop', False) 683 | B = self.config['batch_size'] 684 | T = self.config['dynamics_seq_len'] 685 | train_step = 0 686 | 687 | # Initial data collection. 688 | initial_episodes_per_env = self.config['initial_data_steps'] // (self.num_envs * steps_per_episode) # Used to delay both world model and rl training. 689 | start_rl_training_after = self.config['start_rl_training_after'] # Used to delay rl training until world model has updated for a bit. 690 | 691 | for ii in range(num_iters): 692 | if ii % self.config['validate_every_iters'] == 0: 693 | loss = self.validate(ii) 694 | if ii == 0: 695 | print('Completed validation') 696 | self.save(ii, loss) 697 | 698 | # Collect data. One episode in each environment. 699 | with torch.no_grad(): 700 | if ii < initial_episodes_per_env or train_step < start_rl_training_after: 701 | episode_reward = self.collect_data_random_policy(replay_buffer, num_episodes_per_env=1, train=True) 702 | else: 703 | episode_reward = self.collect_data_from_actor(replay_buffer, num_episodes_per_env=1, train=True, 704 | sample_policy=True) 705 | if not self.debug and self.tb_writer is not None: 706 | self.tb_writer.add_scalar('rl_metrics/episode_reward', episode_reward, ii) 707 | 708 | if ii < initial_episodes_per_env: # No updates until a few episodes have been collected. 709 | continue 710 | self.observation_model.train() 711 | if self.model is not None: 712 | self.model.train() 713 | self.actor.train() 714 | self.critic.train() 715 | for i in range(num_updates_per_iter): 716 | # Train world model. 717 | tic = time.time() 718 | train_step += 1 719 | batch = replay_buffer.sample(B, T) # Dict of (B, T, ..) 720 | batch = self.prep_batch(batch, random_crop=random_crop) 721 | tic1 = time.time() 722 | if not self.exclude_wm_loss: # Skip for model-free variants, like SAC, RSAC. 
723 | self.update_world_model(batch, train_step, heavy_logging=(i == 0)) 724 | tic2 = time.time() 725 | if self.has_momentum_encoder: 726 | self.update_curl(batch, train_step, heavy_logging=(i == 0)) 727 | tic3 = time.time() 728 | if train_step >= start_rl_training_after: 729 | self.update_actor_critic_sac(batch, train_step, heavy_logging=(i == 0)) 730 | toc = time.time() 731 | timing_metrics = { 732 | 'time_data_prep': tic1 - tic, 733 | 'time_wm_update': tic2 - tic1, 734 | 'time_curl_update': tic3 - tic2, 735 | 'time_ac_update': toc - tic3, 736 | 'time_per_update': toc - tic, 737 | } 738 | if not self.debug and self.tb_writer is not None: 739 | for k, v in timing_metrics.items(): 740 | self.tb_writer.add_scalar('timing_metrics/{}'.format(k), v, train_step) 741 | if train_step == 1: 742 | print('Completed one step') 743 | 744 | 745 | def argument_parser(argument): 746 | """ Argument parser """ 747 | parser = argparse.ArgumentParser(description='Binder Network.') 748 | parser.add_argument('--disable-cuda', action='store_true', help='Disable CUDA') 749 | parser.add_argument('-c', '--config', default='', type=str, help='Training config') 750 | parser.add_argument('--debug', action='store_true', help='Debug mode. Disable logging.') 751 | args = parser.parse_args(argument) 752 | return args 753 | 754 | 755 | def main(): 756 | args = argument_parser(None) 757 | if not args.disable_cuda and torch.cuda.is_available(): 758 | device = torch.device('cuda') 759 | print('Running on GPU {}'.format(torch.cuda.get_device_name(0))) 760 | else: 761 | device = torch.device('cpu') 762 | print('Running on CPU') 763 | 764 | try: 765 | with open(args.config) as f: 766 | config = yaml.safe_load(f) 767 | except FileNotFoundError: 768 | print("Error opening specified config yaml at: {}. " 769 | "Please check filepath and try again.".format(args.config)) 770 | sys.exit(1) 771 | 772 | config = config['parameters'] 773 | config['expt_id'] = generate_expt_id() 774 | seed = config['seed'] 775 | random.seed(seed) 776 | np.random.seed(seed) 777 | torch.manual_seed(seed) 778 | 779 | trainer = Trainer(config, device, args.debug) 780 | trainer.train() 781 | 782 | 783 | if __name__ == '__main__': 784 | main() 785 | 786 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | source /miniconda/etc/profile.d/conda.sh 2 | conda init bash 3 | conda activate core 4 | export PYTHONPATH=.:$PYTHONPATH 5 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/root/.mujoco/mujoco210/bin:/usr/lib/nvidia-000 6 | tensorboard --logdir ${BOLT_ARTIFACT_DIR} --bind_all --port ${TENSORBOARD_PORT} & 7 | MUJOCO_GL=egl CUDA_VISIBLE_DEVICES=0 python train.py 8 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved. 4 | # 5 | import random 6 | import string 7 | import time 8 | import torch 9 | from torchvision.transforms import Resize, RandomCrop, CenterCrop 10 | import subprocess as sp 11 | import numpy as np 12 | import os 13 | import imageio 14 | 15 | 16 | def torchify(seq): 17 | """ 18 | Convert list of dict of numpy arrays/floats/dicts to dict of tensors. 19 | Args: 20 | seq : List of dicts. 21 | Returns: 22 | batch : Dict of tensors of shape (T, ..). 
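Example (illustrative only; the shapes are assumptions): seq = [dict(obs=np.zeros((3, 64, 64)), reward=0.0) for _ in range(10)] gives batch['obs'] of shape (10, 3, 64, 64) and batch['reward'] of shape (10,). Nested dicts are flattened into joined keys, e.g. dict(obs=dict(image=...)) becomes 'obs_image'.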
23 | """ 24 | keys = seq[0].keys() 25 | batch = {} 26 | for key in keys: 27 | value = seq[0][key] 28 | if isinstance(value, np.ndarray): 29 | batch[key] = torch.stack([torch.from_numpy(frame[key]) for frame in seq]) # (T, ..) 30 | elif isinstance(value, float) or isinstance(value, int): 31 | batch[key] = torch.tensor([frame[key] for frame in seq]) 32 | elif isinstance(value, dict): 33 | sub_batch = torchify([frame[key] for frame in seq]) 34 | for sub_key, val in sub_batch.items(): 35 | batch[key + '_' + sub_key] = val 36 | else: 37 | raise Exception('Unknown type of value in torchify for key ', key) 38 | return batch 39 | 40 | 41 | class FreezeParameters: 42 | def __init__(self, parameters): 43 | self.parameters = parameters 44 | self.param_states = [p.requires_grad for p in self.parameters] 45 | 46 | def __enter__(self): 47 | for param in self.parameters: 48 | param.requires_grad = False 49 | 50 | def __exit__(self, exc_type, exc_val, exc_tb): 51 | for i, param in enumerate(self.parameters): 52 | param.requires_grad = self.param_states[i] 53 | 54 | 55 | def freeze_parameters(parameters): 56 | for param in parameters: 57 | param.requires_grad = False 58 | 59 | 60 | def add_weight_decay(net, l2_value, skip_list=(), exclude_both_list=()): 61 | decay, no_decay = [], [] 62 | for name, param in net.named_parameters(): 63 | if not param.requires_grad or name in exclude_both_list: 64 | continue # frozen weights 65 | if len(param.shape) == 1 or name.endswith('.bias') or name in skip_list: 66 | no_decay.append(param) 67 | else: 68 | decay.append(param) 69 | return [{'params': no_decay, 'weight_decay': 0.}, {'params': decay, 'weight_decay': l2_value}] 70 | 71 | 72 | def get_random_string(length=12): 73 | # choose from all lowercase letter 74 | letters = string.ascii_lowercase + string.digits 75 | result_str = ''.join(random.choice(letters) for i in range(length)) 76 | return result_str 77 | 78 | 79 | def generate_expt_id(): 80 | task_id = get_random_string() 81 | return time.strftime('%Y-%m-%d_%H-%M-%S') + '_' + task_id 82 | 83 | 84 | def features_as_image(x, x_min=None, x_max=None): 85 | """ 86 | x : (B, D) or (1, B, D) 87 | """ 88 | assert len(x.shape) == 2 or (len(x.shape) == 3 and x.shape[0] == 1) 89 | if x_max is None: 90 | x_max = x.max() 91 | if x_min is None: 92 | x_min = x.min() 93 | x = (x - x_min) / (x_max - x_min) 94 | if len(x.shape) == 2: 95 | x = x.unsqueeze(0) # (1, H, W) 96 | return x 97 | 98 | 99 | def probs_as_images(probs, W=64): 100 | """ 101 | Inputs: 102 | probs: (..., D) 103 | Outputs: 104 | vid : (..., C, H, W), C = 3, W = 32 105 | """ 106 | orig_shape = list(probs.shape) 107 | D = orig_shape[-1] 108 | H = D // W 109 | assert D % W == 0 110 | probs = (probs * 255).to(torch.uint8) 111 | C = 3 112 | probs = probs.view(-1, 1, H, W).expand(-1, C, -1, -1).contiguous() 113 | probs = probs.view(*orig_shape[:-1], C, H, W) 114 | return probs 115 | 116 | 117 | def resize_image(images, height, width): 118 | """ 119 | Resize images. 120 | Inputs: 121 | images: (..., C, H, W) 122 | Outputs: 123 | images: (..., C, height, width) 124 | """ 125 | orig_shape = list(images.shape) 126 | C, H, W = orig_shape[-3:] 127 | images = images.view(-1, C, H, W) 128 | resize_op = Resize((height, width)) 129 | images = resize_op(images) 130 | images = images.view(*orig_shape[:-2], height, width) 131 | return images 132 | 133 | 134 | def crop_image_tensor(obs, crop_height, crop_width, random_crop=False, same_crop_across_time=False, padding=0): 135 | """ 136 | Crop a tensor of images. 
137 | Args: 138 | obs: (B, T, C, H, W), or (B, C, H, W) or (C, H, W). 139 | crop_height: Height of the cropped image. 140 | crop_width: Width of the cropped image. 141 | random_crop: If true, crop random patch. Otherwise the center crop is returned. 142 | same_crop_across_time: Maintain the same crop across time for temporal sequences. 143 | padding: How much edge padding to add. 144 | Returns: 145 | cropped_obs: (B, T, C, crop_height, crop_width) 146 | """ 147 | assert len(obs.shape) >= 3 148 | channels, height, width = obs.shape[-3], obs.shape[-2], obs.shape[-1] 149 | if random_crop: 150 | transform = RandomCrop((crop_height, crop_width), padding=padding, padding_mode='edge') 151 | orig_shape = list(obs.shape[:-2]) 152 | if same_crop_across_time and len(obs.shape) >= 5: 153 | T = obs.shape[-4] 154 | channels = channels * T 155 | obs = obs.view(-1, channels, height, width) 156 | cropped_obs = torch.zeros(obs.size(0), channels, crop_height, crop_width, dtype=obs.dtype, device=obs.device) 157 | for i in range(obs.size(0)): 158 | cropped_obs[i, ...] = transform(obs[i, ...]) 159 | cropped_obs = cropped_obs.view(*orig_shape, crop_height, crop_width) 160 | else: 161 | transform = CenterCrop((crop_height, crop_width)) 162 | cropped_obs = transform(obs) 163 | return cropped_obs 164 | 165 | 166 | def get_parameter_list(optimizer): 167 | params_list = [] 168 | for group in optimizer.param_groups: 169 | params_list.extend(list(group['params'])) 170 | return params_list 171 | 172 | 173 | def stack_tensor_dict_list(tensor_dict_list): 174 | """ Stack tensors in a list of dictionaries. """ 175 | keys = tensor_dict_list[0].keys() 176 | res = {} 177 | for key in keys: 178 | res[key] = torch.stack([d[key] for d in tensor_dict_list]) 179 | return res 180 | 181 | 182 | def write_video_mp4_command(filename, frames): 183 | """ 184 | frames : T, C, H, W 185 | """ 186 | if isinstance(frames, np.ndarray): 187 | T, channels, height, width = frames.shape 188 | elif isinstance(frames, torch.Tensor): 189 | T, channels, height, width = frames.shape 190 | frames = frames.detach().cpu().numpy() 191 | elif isinstance(frames, list): 192 | channels, height, width = frames[0].shape 193 | assert channels >= 3 194 | frames = np.stack([frame[:3, ...] for frame in frames], axis=0) 195 | else: 196 | raise Exception('Unknown frame specification.') 197 | frames = frames.astype(np.uint8) 198 | frames = frames.transpose((0, 2, 3, 1)) 199 | print('Writing video {}'.format(filename)) 200 | # Write out as a mp4 video. 
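# Raw RGB24 frames are piped to ffmpeg over stdin and encoded with libx264 at 20 frames per second.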
201 | command = ['ffmpeg', 202 | '-y', # (optional) overwrite output file if it exists 203 | '-f', 'rawvideo', 204 | '-vcodec', 'rawvideo', 205 | '-s', '{}x{}'.format(width, height), # size of one frame 206 | '-pix_fmt', 'rgb24', 207 | '-r', '20', # frames per second 208 | '-an', # Tells FFMPEG not to expect any audio 209 | '-i', '-', # The input comes from a pipe 210 | '-vcodec', 'libx264', 211 | '-pix_fmt', 'yuv420p', 212 | filename] 213 | print(' '.join(command)) 214 | proc = sp.Popen(command, stdin=sp.PIPE, stderr=sp.PIPE) 215 | outs, errs = proc.communicate(input=frames.tobytes()) 216 | 217 | 218 | def write_video_mp4(filename, frames): 219 | """ 220 | frames : T, C, H, W 221 | """ 222 | if isinstance(frames, np.ndarray): 223 | T, channels, height, width = frames.shape 224 | elif isinstance(frames, torch.Tensor): 225 | T, channels, height, width = frames.shape 226 | frames = frames.detach().cpu().numpy() 227 | elif isinstance(frames, list): 228 | channels, height, width = frames[0].shape 229 | assert channels >= 3 230 | frames = np.stack([frame[:3, ...] for frame in frames], axis=0) 231 | else: 232 | raise Exception('Unknown frame specification.') 233 | frames = frames.astype(np.uint8) 234 | frames = frames.transpose((0, 2, 3, 1)) 235 | print('Writing video {}'.format(filename)) 236 | # Write out as a mp4 video. 237 | writer = imageio.get_writer(filename, fps=20) 238 | for frame in frames: 239 | writer.append_data(frame) 240 | writer.close() 241 | 242 | 243 | def test_write_video_mp4(): 244 | filename = 'test.mp4' 245 | frames = np.random.rand(100, 3, 128, 128) * 255 246 | write_video_mp4(filename, frames) 247 | 248 | 249 | if __name__ == '__main__': 250 | test_write_video_mp4() 251 | -------------------------------------------------------------------------------- /videos/hard-walker.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-core/22879e81e2b670d3b2e7bced2a2f735ae8f81971/videos/hard-walker.gif -------------------------------------------------------------------------------- /videos/medium-cartpole.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-core/22879e81e2b670d3b2e7bced2a2f735ae8f81971/videos/medium-cartpole.gif -------------------------------------------------------------------------------- /videos/medium-cheetah.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-core/22879e81e2b670d3b2e7bced2a2f735ae8f81971/videos/medium-cheetah.gif -------------------------------------------------------------------------------- /world_model.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved. 
4 | # 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from utils import stack_tensor_dict_list 9 | from modules import FCNet, CNN, TransposeCNN, GRUCell, weight_init 10 | 11 | 12 | def get_pairwise_smooth_l1(x): 13 | dims = x.shape[0] 14 | x_shape = list(x.shape) 15 | x1 = x.unsqueeze(0).expand(dims, *x_shape) 16 | x2 = x.unsqueeze(1).expand(dims, *x_shape) 17 | return F.smooth_l1_loss(x1, x2, reduction='none') 18 | 19 | 20 | def get_pairwise_l2(x): 21 | dims = x.shape[0] 22 | x_shape = list(x.shape) 23 | x1 = x.unsqueeze(0).expand(dims, *x_shape) 24 | x2 = x.unsqueeze(1).expand(dims, *x_shape) 25 | return F.mse_loss(x1, x2, reduction='none') 26 | 27 | 28 | def sample_softmax(probs): 29 | dist = torch.distributions.one_hot_categorical.OneHotCategorical(probs=probs) 30 | stoch = dist.sample() + probs - probs.detach() # Trick to send straight-through gradients into probs. 31 | return stoch 32 | 33 | 34 | class Dynamics(nn.Module): 35 | def __init__(self, config, obs_embed_dims, action_dims): 36 | super().__init__() 37 | self.config = config 38 | self.action_dims = action_dims 39 | self.obs_embed_dims = obs_embed_dims 40 | 41 | self.forward_dynamics_loss = config['forward_dynamics_loss'] 42 | assert self.forward_dynamics_loss in ['neg_log_prob', 'kl'] 43 | self.discrete = config.get('discrete', False) 44 | if self.discrete: 45 | self.num_softmaxes = config['num_softmaxes'] 46 | self.dims_per_softmax = config['dims_per_softmax'] 47 | self.latent_dims = self.dims_per_softmax * self.num_softmaxes 48 | self.suff_stats_dims = self.latent_dims # dims needed to specify a distribution over the latent state. 49 | else: 50 | self.latent_dims = config['latent_dims'] 51 | self.suff_stats_dims = 2 * self.latent_dims # mu, log_std 52 | 53 | self.recurrent = config.get('recurrent', False) 54 | if self.recurrent: 55 | self.rnn_state_dims = config['rnn_state_dims'] 56 | else: 57 | self.rnn_state_dims = 0 58 | self.state_dims = self.latent_dims + self.rnn_state_dims 59 | 60 | if self.recurrent: 61 | # Input to recurrent model. 62 | self.rnn_input_net = FCNet(config['rnn_input'], self.latent_dims + action_dims) # Inputs: Prev latent + act. 63 | num_features = self.rnn_input_net.output_dims 64 | print('RNN Input Net input_dims: {} output_dims {}'.format(self.latent_dims + action_dims, num_features)) 65 | 66 | # Recurrent model. 67 | print('GRUCell input_dims {} hidden dims {}'.format(num_features, self.rnn_state_dims)) 68 | self.cell = GRUCell(num_features, self.rnn_state_dims, norm=True) # h_t = cell([z_{t-1},a_{t-1}], h_{t-1}) 69 | 70 | # Prior model (Transition function) P(z_t | h_t) 71 | print('Prior Net input dims {} output dims {}'.format(self.rnn_state_dims, self.suff_stats_dims)) 72 | self.prior_net = FCNet(config['prior'], self.rnn_state_dims, out_features=self.suff_stats_dims) 73 | else: 74 | # Prior model (Transition function) P(z_t | h_{t-1}, a_{t-1}) 75 | print('Prior Net --------') 76 | print('Prior Net input dims {} output dims {}'.format(self.latent_dims + action_dims, self.suff_stats_dims)) 77 | self.prior_net = FCNet(config['prior'], self.latent_dims + action_dims, self.suff_stats_dims) 78 | 79 | # Posterior model (Representation model). P(z_t | x_t, h_t) 80 | print('Posterior Net input dims {} output dims {}'.format(self.obs_embed_dims + self.rnn_state_dims, 81 | self.suff_stats_dims)) 82 | self.posterior_net = FCNet(config['posterior'], self.obs_embed_dims + self.rnn_state_dims, 83 | out_features=self.suff_stats_dims) 84 | 85 | # Initial_state. 
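# The initial latent distribution is stored in non-trainable buffers: uniform logits in the discrete case, a unit Gaussian (zero mean, zero log-std) otherwise, plus a zero RNN state when the model is recurrent.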
86 | if self.discrete: 87 | initial_logits = torch.zeros(self.num_softmaxes, self.dims_per_softmax) 88 | self.register_buffer('initial_logits', initial_logits) 89 | else: 90 | initial_mu = torch.zeros(self.latent_dims) 91 | initial_log_std = torch.zeros(self.latent_dims) 92 | self.register_buffer('initial_mu', initial_mu) 93 | self.register_buffer('initial_log_std', initial_log_std) 94 | if self.recurrent: 95 | initial_rnn_state = torch.zeros(self.rnn_state_dims) 96 | self.register_buffer('initial_rnn_state', initial_rnn_state) 97 | 98 | self.elementwise_bisim = False 99 | self.free_kl = config.get('free_kl', 0.0) 100 | 101 | def initial_state(self, batch_size, deterministic=False): 102 | if self.discrete: 103 | logits = self.initial_logits.unsqueeze(0).expand(batch_size, -1, -1) 104 | probs = torch.softmax(logits, dim=-1) 105 | if deterministic: 106 | state = probs 107 | else: 108 | state = sample_softmax(probs) 109 | state = dict(logits=logits, probs=probs.detach(), state=state.detach()) 110 | else: 111 | mu = self.initial_mu.unsqueeze(0).expand(batch_size, -1) 112 | log_std = self.initial_log_std.unsqueeze(0).expand(batch_size, -1) 113 | if deterministic: 114 | state = mu 115 | else: 116 | state = mu + torch.exp(log_std) * torch.randn_like(log_std) 117 | state = dict(log_std=log_std, mu=mu, state=state.detach()) 118 | if self.recurrent: 119 | state['rnn_state'] = self.initial_rnn_state.unsqueeze(0).expand(batch_size, -1) 120 | return state 121 | 122 | def get_state(self, state_dict, deterministic=False): 123 | """ 124 | Get the latent vector by concatenating rnn state and flattened latent state. 125 | Args: 126 | state_dict: Dict containing 'state', 'rnn_state', 'mu', 'log_std', (OR 'logits') 127 | Returns: 128 | state: (B, D) 129 | """ 130 | if deterministic: 131 | if self.discrete: 132 | state = state_dict['probs'].flatten(-2, -1) 133 | else: 134 | state = state_dict['mu'] 135 | else: 136 | if self.discrete: 137 | state = state_dict['state'].flatten(-2, -1) 138 | else: 139 | state = state_dict['state'] 140 | if self.recurrent: 141 | state = torch.cat([state, state_dict['rnn_state']], dim=-1) 142 | return state 143 | 144 | def compute_kl(self, state_p, state_q): 145 | """ 146 | p log (p/q) 147 | """ 148 | if self.discrete: 149 | logits_p = state_p['logits'] 150 | p = state_p['probs'] 151 | logits_q = state_q['logits'] 152 | log_p = nn.functional.log_softmax(logits_p, dim=-1) 153 | log_q = nn.functional.log_softmax(logits_q, dim=-1) 154 | kld = (p * (log_p - log_q)).sum(dim=-1) 155 | else: 156 | mu_1 = state_p['mu'] 157 | mu_2 = state_q['mu'] 158 | log_std_1 = state_p['log_std'] 159 | log_std_2 = state_q['log_std'] 160 | var_1 = torch.exp(2 * log_std_1) 161 | var_2 = torch.exp(2 * log_std_2) 162 | kld = log_std_2 - log_std_1 + (var_1 + (mu_1 - mu_2) ** 2) / (2 * var_2) - 0.5 163 | kld = kld.sum(dim=-1) 164 | if self.free_kl > 0.0: 165 | kld = kld.clamp_min(self.free_kl) 166 | return kld 167 | 168 | def compute_neg_log_prob(self, state_p, state_q): 169 | """ 170 | Negative Log prob of mean of p under the q distribution. 171 | """ 172 | if self.discrete: 173 | raise NotImplementedError 174 | else: 175 | x = state_p['mu'] 176 | mu = state_q['mu'] 177 | log_std = state_q['log_std'].clamp(-10., 10.) 
178 |             var = torch.exp(2 * log_std)
179 |             neg_log_prob = log_std + ((x - mu) ** 2) / (2 * var)
180 |             neg_log_prob = neg_log_prob.mean(dim=-1)
181 |         return neg_log_prob
182 | 
183 |     def compute_forward_dynamics_loss(self, state_p, state_q):
184 |         if self.forward_dynamics_loss == 'kl':
185 |             return self.compute_kl(state_p, state_q)
186 |         elif self.forward_dynamics_loss == 'neg_log_prob':
187 |             return self.compute_neg_log_prob(state_p, state_q)
188 |         else:
189 |             raise Exception('Unknown forward dynamics loss')
190 | 
191 |     def process_suff_stats(self, suff_stats):
192 |         """
193 |         Process the sufficient statistics (obtained from the output of the posterior or prior networks).
194 |         If discrete, the output is interpreted as logits.
195 |         Otherwise, the output is chunked into two parts: mu and log_std.
196 |         """
197 |         # Get the prior state.
198 |         res = {}
199 |         if self.discrete:
200 |             logits = suff_stats.view(-1, self.num_softmaxes, self.dims_per_softmax)
201 |             probs = nn.functional.softmax(logits, dim=-1)
202 |             state = sample_softmax(probs)  # Module-level straight-through sampler.
203 |             res['logits'] = logits
204 |             res['probs'] = probs
205 |         else:
206 |             mu, log_std = suff_stats.chunk(2, -1)
207 |             log_std = log_std.clamp(-10.0, 10.0)
208 |             state = mu + torch.exp(log_std) * torch.randn_like(log_std)
209 |             res['mu'] = mu
210 |             res['log_std'] = log_std
211 |         res['state'] = state
212 |         return res
213 | 
214 |     def imagine_step(self, prev_state_dict, prev_action, deterministic=False):
215 |         """
216 |         Args:
217 |             prev_state_dict : Dict containing 'state', 'mu', etc (B, num_softmaxes, dims_per_softmax).
218 |             prev_action: (B, action_dims) or None
219 |             deterministic: If True, propagate the mode of the prev_state distribution, otherwise propagate a sample from it.
220 |         Returns:
221 |             prior: Dict containing 'state', 'rnn_state' (if recurrent), and 'logits' OR 'mu', 'log_std'
222 |         """
223 | 
224 |         # prev_action is None at the first time step where we don't know what action led to the first state.
225 |         # In this case, we just pass the previous state through as our best guess.
226 |         if prev_action is None:
227 |             return prev_state_dict
228 | 
229 |         if deterministic:
230 |             if self.discrete:
231 |                 prev_state = prev_state_dict['probs']
232 |             else:
233 |                 prev_state = prev_state_dict['mu']
234 |         else:
235 |             prev_state = prev_state_dict['state']
236 |         if self.discrete:
237 |             prev_state = prev_state.flatten(-2, -1)
238 |         x = torch.cat([prev_state, prev_action], dim=-1)
239 | 
240 |         res = {}
241 |         if self.recurrent:
242 |             # Step the RNN to get the rnn state.
243 |             prev_rnn_state = prev_state_dict['rnn_state']
244 |             x = self.rnn_input_net(x)
245 |             x = self.cell(x, prev_rnn_state)
246 |             res['rnn_state'] = x
247 | 
248 |         # Get the prior state.
249 |         x = self.prior_net(x)
250 |         res.update(self.process_suff_stats(x))
251 |         return res
252 | 
253 |     def forward(self, prev_state_dict, prev_action, obs_embed, deterministic=False):
254 |         """
255 |         Take one step of the dynamics, and return the updated prior and posterior.
256 |         Args:
257 |             prev_state_dict : Dict containing 'state' (B, num_softmaxes, dims_per_softmax), 'rnn_state' (B, deter_dims)
258 |             prev_action: (B, action_dims)
259 |             obs_embed: (B, obs_embed_dims)
260 |             deterministic: If False, estimate the prior using a sample from the prev_state distribution, otherwise use the mode.
261 |         Returns:
262 |             prior: Dict containing 'state', 'rnn_state', ('logits', OR 'mu', 'log_std')
263 |             posterior: Dict containing 'state', 'rnn_state', ('logits', OR 'mu', 'log_std')
264 |         """
265 |         prior = self.imagine_step(prev_state_dict, prev_action, deterministic=deterministic)
266 |         posterior = {}
267 |         x = obs_embed
268 |         if self.recurrent:
269 |             rnn_state = prior['rnn_state']
270 |             posterior['rnn_state'] = rnn_state
271 |             x = torch.cat([x, rnn_state], dim=-1)
272 |         x = self.posterior_net(x)
273 |         posterior.update(self.process_suff_stats(x))
274 |         return prior, posterior
275 | 
276 |     def rollout_prior(self, initial_state, actions, deterministic=False):
277 |         """
278 |         Roll out the dynamics under a sequence of actions, starting from an initial state.
279 |         Args:
280 |             initial_state: Dict containing 'state', 'rnn_state', ('logits', OR 'mu', 'log_std') (B, ...)
281 |             actions: (T, B, action_dims)
282 |             deterministic: If True, propagate the mode of the state distribution at each step, otherwise propagate samples.
283 |         Returns:
284 |             state: Dict containing 'state', 'rnn_state', ('logits', OR 'mu', 'log_std') (T, B, ...)
285 |         """
286 |         T = actions.shape[0]
287 |         state = initial_state
288 |         state_list = []
289 |         for t in range(T):
290 |             state = self.imagine_step(state, actions[t], deterministic=deterministic)
291 |             state_list.append(state)
292 |         state = stack_tensor_dict_list(state_list)
293 |         return state
294 | 
295 | 
296 | class WorldModel(nn.Module):
297 |     def __init__(self, config, obs_dims, action_dims):
298 |         super().__init__()
299 | 
300 |         # Obs encoder.
301 |         if 'obs_encoder' in config:
302 |             self.encoder = FCNet(config['obs_encoder'], obs_dims)
303 |             obs_embed_dims = self.encoder.output_dims
304 |             print('Observation feature encoder ---- Input dims {} Output dims {}'.format(obs_dims, obs_embed_dims))
305 |         else:
306 |             self.encoder = None
307 |             obs_embed_dims = obs_dims
308 | 
309 |         # Dynamics Model.
310 |         self.dynamics = Dynamics(config['dynamics'], obs_embed_dims, action_dims)
311 |         state_dims = self.dynamics.state_dims
312 |         print('Dynamics ---- obs dims {} action dims {} state dims {}'.format(obs_embed_dims, action_dims, state_dims))
313 | 
314 |         # Observation feature decoder P(x_t | z_t, h_t)
315 |         if 'obs_decoder' in config:
316 |             self.decoder = FCNet(config['obs_decoder'], state_dims, out_features=obs_dims)
317 |             print('Observation feature decoder ---- input dims {} output dims {}'.format(state_dims, obs_dims))
318 |         else:
319 |             self.decoder = None
320 | 
321 |         print('Reward predictor ---- input dims {}'.format(state_dims))
322 |         # Reward predictor P(r | z_t, h_t)
323 |         self.reward_net = FCNet(config['reward_net'], state_dims)
324 |         self.reward_prediction_from_prior = config['reward_prediction_from_prior']
325 | 
326 |         if 'inverse_dynamics' in config:
327 |             print('Inverse dynamics model ------ input dims 2 * {} action dims {}'.format(state_dims, action_dims))
328 |             self.inv_dynamics = FCNet(config['inverse_dynamics'], 2 * state_dims, action_dims)
329 |         else:
330 |             self.inv_dynamics = None
331 | 
332 |         self.propagate_deterministic = config['propagate_deterministic']
333 |         self.decode_deterministic = config['decode_deterministic']
334 |         self.state_dims = state_dims
335 | 
336 |         if self.dynamics.forward_dynamics_loss == 'neg_log_prob':
337 |             assert self.propagate_deterministic, "Posterior variance is not trained, propagate_deterministic should be True"
338 |         self.apply(weight_init)
339 | 
340 |     def forward_one_step(self, obs_encoding, posterior=None, action=None, deterministic_latent=True):
341 |         """
342 |         One filtering step: update the posterior from obs_encoding (B, obs_dims), the previous posterior (or None to
343 |         start from the initial state) and the previous action (or None). Returns the latent (B, state_dims) and the new posterior.
344 |         """
345 |         if posterior is None:
346 |             B = obs_encoding.size(0)
347 |             posterior = self.dynamics.initial_state(B)
348 |         obs_embed = self.encoder(obs_encoding) if self.encoder is not None else obs_encoding  # Mirror the encoder check in forward().
349 |         _, posterior = self.dynamics(posterior, action, obs_embed, deterministic=self.propagate_deterministic)
350 |         latent = self.dynamics.get_state(posterior, deterministic=deterministic_latent)
351 |         return latent, posterior
352 | 
353 |     def forward(self, batch):
354 |         """
355 |         Args:
356 |             batch: Dict containing
357 |                 'obs': (T, B, C, H, W)
358 |                 'obs_features': (T, B, feat_dims)
359 |                 'action': (T, B, action_dims)
360 |                 'reward': (T, B)
361 |         Returns:
362 |             outputs: Dict with 'prior', 'posterior', 'reward_prediction' and, when configured, feature reconstructions and action predictions.
363 |         """
364 |         obs_features = batch['obs_features']
365 |         action = batch['action']
366 |         T, B, _ = action.shape
367 | 
368 |         if self.encoder is not None:
369 |             obs_features = self.encoder(obs_features)  # (T, B, obs_embed_dims)
370 |         prev_state = self.dynamics.initial_state(B)
371 |         prior_list = []
372 |         posterior_list = []
373 |         for t in range(T):
374 |             prev_action = action[t-1] if t > 0 else None
375 |             current_obs = obs_features[t]
376 |             prior, posterior = self.dynamics(prev_state, prev_action, current_obs, deterministic=self.propagate_deterministic)
377 |             prior_list.append(prior)
378 |             posterior_list.append(posterior)
379 |             prev_state = posterior
380 |         prior = stack_tensor_dict_list(prior_list)
381 |         posterior = stack_tensor_dict_list(posterior_list)
382 |         latent_post = self.dynamics.get_state(posterior, deterministic=self.decode_deterministic)  # T, B, D
383 |         latent_prior = self.dynamics.get_state(prior, deterministic=False)
384 |         if self.reward_prediction_from_prior:
385 |             reward_prediction = self.reward_net(latent_prior).squeeze(-1)  # DBC uses latent prior.
386 | else: 387 | reward_prediction = self.reward_net(latent_post).squeeze(-1) 388 | outputs = dict(prior=prior, 389 | posterior=posterior, 390 | reward_prediction=reward_prediction) 391 | if self.decoder is not None: 392 | outputs['obs_features_recon_post'] = self.decoder(latent_post) # (T, B, dims) 393 | outputs['obs_features_recon_prior'] = self.decoder(latent_prior) # (T, B, dims) 394 | if self.inv_dynamics is not None: 395 | paired_states = torch.cat([latent_post[:-1], latent_post[1:]], dim=-1) 396 | action_prediction = self.inv_dynamics(paired_states) 397 | outputs['action_prediction'] = action_prediction 398 | return outputs 399 | --------------------------------------------------------------------------------
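Two of the helpers in world_model.py are self-contained enough to sanity-check in isolation: sample_softmax (one-hot categorical sampling with straight-through gradients) and get_pairwise_l2 (element-wise squared differences between every pair of rows). The sketch below is illustrative only, not part of the repository; it assumes the repo environment is set up (PyTorch installed, world_model.py importable from the working directory) and uses arbitrary, made-up tensor sizes.

# Minimal sanity-check sketch (not part of the repo): exercises sample_softmax and
# get_pairwise_l2 from world_model.py with arbitrary tensor shapes.
import torch

from world_model import get_pairwise_l2, sample_softmax

# Straight-through categorical sampling: the sample is one-hot per softmax,
# but the `+ probs - probs.detach()` term lets gradients reach the logits.
logits = torch.randn(4, 8, requires_grad=True)      # (num_softmaxes, dims_per_softmax)
probs = torch.softmax(logits, dim=-1)
onehot = sample_softmax(probs)
print(onehot.sum(dim=-1))                            # all ones: exactly one active category per softmax
((onehot - torch.randn_like(onehot)) ** 2).sum().backward()
print(logits.grad.abs().sum() > 0)                   # True: gradients flowed through the sample

# Pairwise squared differences: for x of shape (B, D) the result is (B, B, D),
# with entry [i, j] equal to (x[i] - x[j]) ** 2 element-wise.
x = torch.randn(5, 3)
d = get_pairwise_l2(x)
print(d.shape)                                       # torch.Size([5, 5, 3])
print(torch.allclose(d[1, 2], (x[1] - x[2]) ** 2))   # True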