├── 10_nvidia.json
├── ACKNOWLEDGEMENTS
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── conda_env.yml
├── conda_env_robosuite.yml
├── configs
│   ├── dcs
│   │   └── core.yaml
│   └── robosuite
│       ├── core.yaml
│       └── core_imageonly.yaml
├── distracting_control
│   ├── README.md
│   ├── background.py
│   ├── camera.py
│   ├── camera_test.py
│   ├── color.py
│   ├── distracting_control_demo.py
│   ├── requirements.txt
│   ├── run.sh
│   ├── suite.py
│   ├── suite_test.py
│   └── suite_utils.py
├── environment_container_dcs.py
├── environment_container_robosuite.py
├── environments.py
├── modules.py
├── replay_buffer.py
├── setup.sh
├── train.py
├── train.sh
├── utils.py
├── videos
│   ├── hard-walker.gif
│   ├── medium-cartpole.gif
│   └── medium-cheetah.gif
└── world_model.py
/10_nvidia.json:
--------------------------------------------------------------------------------
1 | {
2 | "file_format_version" : "1.0.0",
3 | "ICD" : {
4 | "library_path" : "libEGL_nvidia.so.0"
5 | }
6 | }
7 |
8 |
--------------------------------------------------------------------------------
/ACKNOWLEDGEMENTS:
--------------------------------------------------------------------------------
1 | Acknowledgements
2 | Portions of this software may utilize the following copyrighted
3 | material, the use of which is hereby acknowledged.
4 |
5 | _____________________
6 | Distracting Control Suite
7 | Copyright 2021 The Google Research Authors.
8 | Licensed under the Apache License, Version 2.0 (the "License");
9 | you may not use this file except in compliance with the License.
10 | You may obtain a copy of the License at
11 |
12 | http://www.apache.org/licenses/LICENSE-2.0
13 |
14 | Unless required by applicable law or agreed to in writing, software
15 | distributed under the License is distributed on an "AS IS" BASIS,
16 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17 | See the License for the specific language governing permissions and
18 | limitations under the License.
19 |
20 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | In the interest of fostering an open and welcoming environment, we as
6 | contributors and maintainers pledge to making participation in our project and
7 | our community a harassment-free experience for everyone, regardless of age, body
8 | size, disability, ethnicity, sex characteristics, gender identity and expression,
9 | level of experience, education, socio-economic status, nationality, personal
10 | appearance, race, religion, or sexual identity and orientation.
11 |
12 | ## Our Standards
13 |
14 | Examples of behavior that contributes to creating a positive environment
15 | include:
16 |
17 | * Using welcoming and inclusive language
18 | * Being respectful of differing viewpoints and experiences
19 | * Gracefully accepting constructive criticism
20 | * Focusing on what is best for the community
21 | * Showing empathy towards other community members
22 |
23 | Examples of unacceptable behavior by participants include:
24 |
25 | * The use of sexualized language or imagery and unwelcome sexual attention or
26 | advances
27 | * Trolling, insulting/derogatory comments, and personal or political attacks
28 | * Public or private harassment
29 | * Publishing others' private information, such as a physical or electronic
30 | address, without explicit permission
31 | * Other conduct which could reasonably be considered inappropriate in a
32 | professional setting
33 |
34 | ## Our Responsibilities
35 |
36 | Project maintainers are responsible for clarifying the standards of acceptable
37 | behavior and are expected to take appropriate and fair corrective action in
38 | response to any instances of unacceptable behavior.
39 |
40 | Project maintainers have the right and responsibility to remove, edit, or
41 | reject comments, commits, code, wiki edits, issues, and other contributions
42 | that are not aligned to this Code of Conduct, or to ban temporarily or
43 | permanently any contributor for other behaviors that they deem inappropriate,
44 | threatening, offensive, or harmful.
45 |
46 | ## Scope
47 |
48 | This Code of Conduct applies within all project spaces, and it also applies when
49 | an individual is representing the project or its community in public spaces.
50 | Examples of representing a project or community include using an official
51 | project e-mail address, posting via an official social media account, or acting
52 | as an appointed representative at an online or offline event. Representation of
53 | a project may be further defined and clarified by project maintainers.
54 |
55 | ## Enforcement
56 |
57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
58 | reported by contacting the open source team at [opensource-conduct@group.apple.com](mailto:opensource-conduct@group.apple.com). All
59 | complaints will be reviewed and investigated and will result in a response that
60 | is deemed necessary and appropriate to the circumstances. The project team is
61 | obligated to maintain confidentiality with regard to the reporter of an incident.
62 | Further details of specific enforcement policies may be posted separately.
63 |
64 | Project maintainers who do not follow or enforce the Code of Conduct in good
65 | faith may face temporary or permanent repercussions as determined by other
66 | members of the project's leadership.
67 |
68 | ## Attribution
69 |
70 | This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org), version 1.4,
71 | available at [https://www.contributor-covenant.org/version/1/4/code-of-conduct.html](https://www.contributor-covenant.org/version/1/4/code-of-conduct.html)
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contribution Guide
2 |
3 | Thanks for your interest in contributing. This project was released to accompany a research paper for the purposes of reproducibility, and beyond its publication there are limited plans for future development of the repository.
4 |
5 | While we welcome new pull requests and issues, please note that our response may be limited. Forks and out-of-tree improvements are strongly encouraged.
6 |
7 | ## Before you get started
8 |
9 | By submitting a pull request, you represent that you have the right to license your contribution to Apple and the community, and agree by submitting the patch that your contributions are licensed under the [LICENSE](LICENSE).
10 |
11 | We ask that all community members read and observe our [Code of Conduct](CODE_OF_CONDUCT.md).
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (C) 2021 Apple Inc. All Rights Reserved.
2 |
3 | IMPORTANT: This Apple software is supplied to you by Apple
4 | Inc. ("Apple") in consideration of your agreement to the following
5 | terms, and your use, installation, modification or redistribution of
6 | this Apple software constitutes acceptance of these terms. If you do
7 | not agree with these terms, please do not use, install, modify or
8 | redistribute this Apple software.
9 |
10 | In consideration of your agreement to abide by the following terms, and
11 | subject to these terms, Apple grants you a personal, non-exclusive
12 | license, under Apple's copyrights in this original Apple software (the
13 | "Apple Software"), to use, reproduce, modify and redistribute the Apple
14 | Software, with or without modifications, in source and/or binary forms;
15 | provided that if you redistribute the Apple Software in its entirety and
16 | without modifications, you must retain this notice and the following
17 | text and disclaimers in all such redistributions of the Apple Software.
18 | Neither the name, trademarks, service marks or logos of Apple Inc. may
19 | be used to endorse or promote products derived from the Apple Software
20 | without specific prior written permission from Apple. Except as
21 | expressly stated in this notice, no other rights or licenses, express or
22 | implied, are granted by Apple herein, including but not limited to any
23 | patent rights that may be infringed by your derivative works or by other
24 | works in which the Apple Software may be incorporated.
25 |
26 | The Apple Software is provided by Apple on an "AS IS" basis. APPLE
27 | MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION
28 | THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS
29 | FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND
30 | OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS.
31 |
32 | IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL
33 | OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
34 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
35 | INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION,
36 | MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED
37 | AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE),
38 | STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE
39 | POSSIBILITY OF SUCH DAMAGE.
40 |
41 | -------------------------------------------------------------------------------
42 | SOFTWARE DISTRIBUTED IN THIS REPOSITORY:
43 |
44 | This software includes a number of subcomponents with separate
45 | copyright notices and license terms - please see the file ACKNOWLEDGEMENTS.
46 | -------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CoRe: Contrastive Recurrent State-Space Models
2 |
3 | This code implements the CoRe model and reproduces experimental results found in
4 | **Robust Robotic Control from Pixels using Contrastive Recurrent State-Space models**
5 | NeurIPS Deep Reinforcement Learning Workshop 2021
6 | Nitish Srivastava, Walter Talbott, Martin Bertran Lopez, Shuangfei Zhai & Joshua M. Susskind
7 | [[paper](https://arxiv.org/abs/2112.01163)]
8 |
9 | 
10 |
11 | 
12 |
13 | 
14 |
15 |
16 | ## Requirements and Installation
17 | Clone this repository and then execute the following steps. See `setup.sh` for an example of how to run these steps on an Ubuntu 18.04 machine.
18 |
19 | * Install dependencies.
20 | ```
21 | apt install -y libgl1-mesa-dev libgl1-mesa-glx libglew-dev \
22 | libosmesa6-dev software-properties-common net-tools unzip \
23 | virtualenv wget xpra xserver-xorg-dev libglfw3-dev patchelf xvfb ffmpeg
24 | ```
25 | * Download the [DAVIS 2017
26 | dataset](https://davischallenge.org/davis2017/code.html). Make sure to select the 2017 TrainVal - Images and Annotations (480p). The training images will be used as distracting backgrounds. The `DAVIS` directory should be in the same directory as the code. Check that `ls ./DAVIS/JPEGImages/480p/...` shows 90 video directories.
27 | * Install MuJoCo 2.1.
28 | - Download [MuJoCo version 2.1](https://mujoco.org/download) binaries for Linux or macOS.
29 | - Unzip the downloaded `mujoco210` directory into `~/.mujoco/mujoco210`.
30 | * Install MuJoCo 2.0 (For robosuite experiments only).
31 | - Download [MuJoCo version 2.0](https://roboti.us/download.html) binaries for Linux or macOS.
32 | - Unzip the downloaded directory and move it into `~/.mujoco/`.
33 | - Symlink `mujoco200_linux` (or `mujoco200_macos`) to `mujoco200`.
34 | ```
35 | ln -s ~/.mujoco/mujoco200_linux ~/.mujoco/mujoco200
36 | ```
37 | - Place the [license key](https://roboti.us/license.html) at `~/.mujoco/mjkey.txt`.
38 | - Add the MuJoCo binaries to `LD_LIBRARY_PATH`.
39 | ```
40 | export LD_LIBRARY_PATH=$HOME/.mujoco/mujoco200/bin:$LD_LIBRARY_PATH
41 | ```
42 | * Set up EGL GPU rendering (if a GPU is available). A short rendering sanity check is sketched after these installation steps.
43 | - To ensure that the GPU is prioritized over the CPU for EGL rendering, copy the bundled EGL vendor file:
44 | ```
45 | cp 10_nvidia.json /usr/share/glvnd/egl_vendor.d/
46 | ```
47 | - Create a dummy nvidia directory so that mujoco_py builds the extensions needed for GPU rendering.
48 | ```
49 | mkdir -p /usr/lib/nvidia-000
50 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/lib/nvidia-000
51 | ```
52 | * Create a conda environment.
53 |
54 | For Distracting Control Suite
55 | ```
56 | conda env create -f conda_env.yml
57 | ```
58 |
59 | For Robosuite
60 | ```
61 | conda env create -f conda_env_robosuite.yml
62 | ```
63 |
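* Verify the setup (optional). The following is a minimal sanity-check script, not part of the original setup; it assumes the `core` conda environment is active with the dependencies from `conda_env.yml` installed, and checks that `dm_control` can render offscreen through EGL.
```
# egl_check.py -- hypothetical sanity-check script, not part of this repository.
import os
os.environ.setdefault('MUJOCO_GL', 'egl')  # must be set before MuJoCo is loaded

from dm_control import suite

# Load a standard dm_control task and render a single offscreen frame.
env = suite.load('cheetah', 'run')
env.reset()
frame = env.physics.render(height=84, width=84, camera_id=0)
print('Rendered frame with shape', frame.shape)  # expected: (84, 84, 3)
```
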
64 | ## Training
65 |
66 | * The CoRe model can be trained on the Distracting Control Suite as follows:
67 |
68 | ```
69 | conda activate core
70 | MUJOCO_GL=egl CUDA_VISIBLE_DEVICES=0 python train.py --config configs/dcs/core.yaml
71 | ```
72 | The training artifacts, including TensorBoard logs and videos of validation rollouts, will be written to `./artifacts/`. A short sketch for monitoring these logs with TensorBoard appears at the end of this section.
73 |
74 | To change the distraction setting, modify the `difficulty` parameter in `configs/dcs/core.yaml`. Possible values are `['easy', 'medium', 'hard', 'none', 'hard_bg']`.
75 |
76 | To change the domain, modify the `domain` parameter in `configs/dcs/core.yaml`. Possible values are `['ball_in_cup', 'cartpole', 'cheetah', 'finger', 'reacher', 'walker']`.
77 |
78 | * To train on Robosuite (Door Task, Franka Panda Arm)
79 |
80 | - Using RGB image and proprioceptive inputs.
81 | ```
82 | conda activate core_robosuite
83 | MUJOCO_GL=egl CUDA_VISIBLE_DEVICES=0 python train.py --config configs/robosuite/core.yaml
84 | ```
85 | - Using RGB image inputs only.
86 | ```
87 | conda activate core_robosuite
88 | MUJOCO_GL=egl CUDA_VISIBLE_DEVICES=0 python train.py --config configs/robosuite/core_imageonly.yaml
89 | ```
90 |
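Training runs can be monitored with TensorBoard (pinned in the conda environments) pointed at `./artifacts/`. The snippet below is an optional sketch for launching it from Python; running the `tensorboard` command-line tool on the same directory works equally well.
```
# Optional convenience sketch, not part of this repository: launch TensorBoard
# on the training artifacts. Assumes training has written logs to ./artifacts/.
from tensorboard import program

tb = program.TensorBoard()
tb.configure(argv=[None, '--logdir', './artifacts'])
url = tb.launch()
print('TensorBoard listening on', url)
input('Press Enter to stop...')  # keep the process alive while browsing
```
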
91 | ## Citation
92 | ```
93 | @article{srivastava2021core,
94 | title={Robust Robotic Control from Pixels using Contrastive Recurrent State-Space Models},
95 | author={Nitish Srivastava and Walter Talbott and Martin Bertran Lopez and Shuangfei Zhai and Josh Susskind},
96 | journal={NeurIPS Deep Reinforcement Learning Workshop},
97 | year={2021}
98 | }
99 | ```
100 |
101 | ## License
102 | This code is released under the [LICENSE](LICENSE) terms.
103 |
--------------------------------------------------------------------------------
/conda_env.yml:
--------------------------------------------------------------------------------
1 | name: core
2 | channels:
3 | - defaults
4 | - pytorch
5 | dependencies:
6 | - pip=20.1.1
7 | - python=3.7.6
8 | - pytorch=1.10
9 | - torchvision=0.11.0
10 | - cudatoolkit=11.3
11 | - tensorflow=2.2.0
12 | - tensorboard=2.6.0
13 | - absl-py=0.13.0
14 | - pyparsing=2.4.7
15 | - pillow=6.1.0
16 | - h5py=2.10.0
17 | - pyyaml=5.4.1
18 | - moviepy=1.0.1
19 | - numpy=1.20
20 | - conda-forge::matplotlib=3.2.2
21 | - conda-forge::gym=0.19.0
22 | - conda-forge::ffmpeg=4.4.0
23 | - conda-forge::libiconv=1.16
24 | - conda-forge::imageio=2.6.1
25 | - pip:
26 | - termcolor==1.1.0
27 | - dm_control
28 |
--------------------------------------------------------------------------------
/conda_env_robosuite.yml:
--------------------------------------------------------------------------------
1 | name: core_robosuite
2 | channels:
3 | - defaults
4 | - pytorch
5 | dependencies:
6 | - pip=20.1.1
7 | - python=3.7.6
8 | - pytorch=1.10
9 | - torchvision=0.11.0
10 | - cudatoolkit=11.3
11 | - tensorflow=2.2.0
12 | - tensorboard=2.6.0
13 | - absl-py=0.13.0
14 | - pyparsing=2.4.7
15 | - pillow=6.1.0
16 | - h5py=2.10.0
17 | - pyyaml=5.4.1
18 | - moviepy=1.0.1
19 | - numpy=1.20
20 | - conda-forge::matplotlib=3.2.2
21 | - conda-forge::gym=0.19.0
22 | - conda-forge::ffmpeg=4.4.0
23 | - conda-forge::libiconv=1.16
24 | - conda-forge::imageio=2.6.1
25 | - pip:
26 | - termcolor==1.1.0
27 | - robosuite==1.2.0
28 |
--------------------------------------------------------------------------------
/configs/dcs/core.yaml:
--------------------------------------------------------------------------------
1 | parameters:
2 | seed: 42
3 | env:
4 | name: dcs
5 | domain: cheetah
6 | difficulty: hard
7 | dynamic: true
8 | background_dataset_path: DAVIS/JPEGImages/480p
9 | allow_background_distraction: true
10 | allow_camera_distraction: true
11 | allow_color_distraction: true
12 | image_height: 256
13 | image_width: 256
14 | num_frames_to_stack: 1
15 | start_rl_training_after: 0
16 | num_envs: 1
17 | num_val_envs: 5 #20
18 | num_episodes_per_val_env_for_reward: 1 #5
19 | max_steps: 2000000 # This divided by action_repeat is the max updates.
20 | episode_steps: 1000 # This divided by action_repeat is the max steps per episode.
21 | update_frequency_factor: 0.5
22 | initial_data_steps: 1000
23 | gamma: 0.99
24 | contrastive:
25 | inverse_temperature_init: 1.0
26 | softmax_over: both
27 | #mask_type: exclude_other_sequences
28 | include_model_params_in_critic: true
29 | sac_deterministic_state: true
30 | recon_from_prior: true
31 | lr: 3.e-4
32 | lr_inverse_temp: 2.e-3
33 | lr_actor: 1.e-3
34 | lr_critic: 1.e-3
35 | lr_alpha: 1.e-3
36 | momentum_alpha: 0.5
37 | momentum_critic: 0.9
38 | momentum_actor: 0.9
39 | max_grad_norm_wm: 10
40 | max_grad_norm_critic: 100
41 | max_grad_norm_actor: 10
42 | max_grad_norm_log_alpha: 1.0
43 | initial_log_alpha: -2.3
44 | max_log_alpha: 3.0
45 | update_target_critic_after: 1
46 | update_target_critic_tau: 0.005
47 | weight_decay: 0.0
48 | replay_buffer_size: 100000
49 | batch_size: 32
50 | dynamics_seq_len: 32
51 | random_crop: true
52 | random_crop_padding: 0
53 | same_crop_across_time: true
54 | crop_height: 84
55 | crop_width: 84
56 | loss_scales:
57 | eta_fwd: 0.01
58 | eta_r: 1.0
59 | eta_inv: 1.0
60 | eta_q: 0.5
61 | eta_s: 1.0
62 | wm:
63 | sample_state: true
64 | reward_prediction_from_prior: false
65 | decode_deterministic: false
66 | propagate_deterministic: false
67 | obs_encoder: # Input dims determined by binder symbol dims, or observation model's output dims.
68 | fc_activation: elu
69 | fc_hiddens: [256, 256]
70 | output_activation: identity
71 | layer_norm_output: true
72 | obs_decoder:
73 | fc_activation: elu
74 | fc_hiddens: [256, 256] # Will add another layer to make the output be the same dim as the encoder's input.
75 | output_activation: identity
76 | layer_norm_output: true
77 | reward_net:
78 | fc_activation: elu
79 | fc_hiddens: [128, 128, 1]
80 | output_activation: identity
81 | inverse_dynamics:
82 | fc_activation: elu
83 | fc_hiddens: [128, 128]
84 | output_activation: tanh
85 | dynamics:
86 | discrete: false
87 | rnn_state_dims: 200
88 | latent_dims: 64
89 | forward_dynamics_loss: kl #neg_log_prob
90 | free_kl: 1.0
91 | recurrent: true
92 | rnn_input: # Takes as input latent_dims + actions. The output of this is input to GRU.
93 | fc_activation: elu
94 | fc_hiddens: [400, 400]
95 | output_activation: elu
96 | posterior: # Takes as input rnn_state_dims + obs_embed
97 | fc_activation: elu
98 | fc_hiddens: [400, 400]
99 | output_activation: identity # Outputs logits. Num outputs = stoch_num_softmaxes * stoch_dims_per_softmax
100 | layer_norm_output: true
101 | layer_norm_affine: true
102 | prior: # takes as input rnn_state_dims
103 | fc_activation: elu
104 | fc_hiddens: [400, 400, 400]
105 | output_activation: identity # Outputs logits. Num outputs = stoch_num_softmaxes * stoch_dims_per_softmax
106 | layer_norm_output: true
107 | layer_norm_affine: true
108 | use_gating_network: false
109 | encoder_gating:
110 | conv_batch_norm: false
111 | conv_activation: elu
112 | base_num_hid: 1 # The number of filters in fc and conv layers is multiplied by this.
113 | conv_filters:
114 | - [16, [5, 5], 1, [2, 2]] # (64, 64)
115 | - [16, [5, 5], 1, [2, 2]] # (64, 64)
116 | - [16, [5, 5], 1, [2, 2]] # (64, 64)
117 | - [1, [1, 1], 1, [0, 0]] # (64, 64)
118 | encoder:
119 | conv_batch_norm: false
120 | conv_activation: relu
121 | base_num_hid: 1 # The number of filters in fc and conv layers is multiplied by this.
122 | conv_filters:
123 | - [32, [3, 3], 2, [0, 0]] # (31, 31)
124 | - [32, [3, 3], 1, [0, 0]] # (29, 29)
125 | - [32, [3, 3], 1, [0, 0]] # (27, 27)
126 | - [32, [3, 3], 1, [0, 0]] # (25, 25)
127 | fc_hiddens: [50]
128 | fc_activation: relu
129 | fc_batch_norm: false
130 | output_activation: identity
131 | layer_norm_output: true
132 | layer_norm_affine: false
133 | double_output_dims_for_std: false
134 | actor:
135 | fc_activation: relu
136 | fc_hiddens: [1024, 1024, 1024]
137 | output_activation: identity
138 | policy_max_logstd: 2
139 | policy_min_logstd: -10
140 | critic:
141 | fc_activation: relu
142 | fc_hiddens: [1024, 1024, 1024, 1]
143 | output_activation: identity
144 | validate_every_iters: 10
145 | print_every: 50
146 |
--------------------------------------------------------------------------------
/configs/robosuite/core.yaml:
--------------------------------------------------------------------------------
1 | parameters:
2 | seed: 42
3 | env:
4 | name: robosuite
5 | image_key: frontview_image
6 | other_key: robot0_proprio-state
7 | crop_image: true
8 | crop_width: 144
9 | crop_height: 144
10 | crop_center_xy: [104, 144] # For the Door task.
11 | #crop_center_xy: [128, 128]
12 | domain_randomize: true
13 | domain_randomization_config:
14 | randomize_color: true
15 | randomize_camera: true
16 | randomize_lighting: true
17 | randomize_dynamics: false
18 | #if True, randomize on every call to @reset. This, in
19 | # conjunction with setting @randomize_every_n_steps to 0, is useful to
20 | # generate a new domain per episode.
21 | randomize_on_reset: true
22 | randomize_every_n_steps: 0
23 | controller: OSC_POSE
24 | robosuite_config:
25 | control_freq: 20
26 | env_name: Door
27 | hard_reset: false
28 | horizon: 500
29 | ignore_done: true
30 | reward_scale: 1.0
31 | reward_shaping: true
32 | robots: [ Panda ]
33 | has_renderer: false
34 | has_offscreen_renderer: true
35 | use_object_obs: false
36 | use_camera_obs: true
37 | # ('frontview', 'birdview', 'agentview', 'sideview', 'robot0_robotview', 'robot0_eye_in_hand')
38 | camera_names: frontview
39 | camera_heights: 256
40 | camera_widths: 256
41 | camera_depths: false
42 | start_rl_training_after: 0
43 | num_envs: 1
44 | num_val_envs: 10
45 | num_episodes_per_val_env_for_reward: 2
46 | max_steps: 1000000 # This divided by action_repeat is the max updates.
47 | episode_steps: 500 # This divided by action_repeat is the max steps per episode.
48 | update_frequency_factor: 0.5
49 | initial_data_steps: 1000
50 | gamma: 0.99
51 | contrastive:
52 | inverse_temperature_init: 1.0
53 | softmax_over: both
54 | #mask_type: exclude_other_sequences
55 | include_model_params_in_critic: true
56 | sac_deterministic_state: true
57 | recon_from_prior: true
58 | lr: 3.e-4
59 | lr_inverse_temp: 2.e-3
60 | lr_actor: 1.e-3
61 | lr_critic: 1.e-3
62 | lr_alpha: 1.e-3
63 | max_grad_norm_wm: 10
64 | max_grad_norm_critic: 100
65 | max_grad_norm_actor: 10
66 | max_grad_norm_log_alpha: 1.0
67 | initial_log_alpha: -2.3
68 | max_log_alpha: 3.0
69 | update_target_critic_after: 1
70 | update_target_critic_tau: 0.005
71 | weight_decay: 1.e-8
72 | replay_buffer_size: 200000
73 | batch_size: 32
74 | dynamics_seq_len: 32
75 | random_crop: true
76 | random_crop_padding: 0
77 | same_crop_across_time: true
78 | crop_height: 128
79 | crop_width: 128
80 | rollout_prior_init_t: 10
81 | loss_scales:
82 | eta_fwd: 0.01
83 | eta_r: 1.0
84 | eta_inv: 1.0
85 | eta_q: 0.5
86 | eta_s: 1.0
87 | wm:
88 | reward_prediction_from_prior: false
89 | decode_deterministic: false
90 | propagate_deterministic: false
91 | obs_encoder: # Input dims determined by binder symbol dims, or observation model's output dims.
92 | fc_activation: elu
93 | fc_hiddens: [256, 256]
94 | output_activation: identity
95 | layer_norm_output: true
96 | obs_decoder:
97 | fc_activation: elu
98 | fc_hiddens: [256, 256] # Will add another layer to make the output be the same dim as the encoder's input.
99 | output_activation: identity
100 | layer_norm_output: true
101 | reward_net:
102 | fc_activation: elu
103 | fc_hiddens: [128, 128, 1]
104 | output_activation: identity
105 | inverse_dynamics:
106 | fc_activation: elu
107 | fc_hiddens: [128, 128]
108 | output_activation: tanh
109 | dynamics:
110 | discrete: false
111 | rnn_state_dims: 200
112 | latent_dims: 64
113 | forward_dynamics_loss: kl #neg_log_prob
114 | free_kl: 1.0
115 | recurrent: true
116 | rnn_input: # Takes as input latent_dims + actions. The output of this is input to GRU.
117 | fc_activation: elu
118 | fc_hiddens: [400, 400]
119 | output_activation: elu
120 | posterior: # Takes as input rnn_state_dims + obs_embed
121 | fc_activation: elu
122 | fc_hiddens: [400, 400]
123 | output_activation: identity # Outputs logits. Num outputs = stoch_num_softmaxes * stoch_dims_per_softmax
124 | layer_norm_output: true
125 | layer_norm_affine: true
126 | prior: # takes as input rnn_state_dims
127 | fc_activation: elu
128 | fc_hiddens: [400, 400, 400]
129 | output_activation: identity # Outputs logits. Num outputs = stoch_num_softmaxes * stoch_dims_per_softmax
130 | layer_norm_output: true
131 | layer_norm_affine: true
132 | encoder:
133 | conv_batch_norm: false
134 | conv_activation: relu
135 | base_num_hid: 1 # The number of filters in fc and conv layers is multiplied by this.
136 | conv_filters:
137 | - [32, [3, 3], 2, [1, 1]] # (64)
138 | - [32, [3, 3], 2, [1, 1]] # (32)
139 | - [32, [3, 3], 1, [0, 0]] # (30)
140 | - [32, [3, 3], 1, [0, 0]] # (28)
141 | - [32, [3, 3], 1, [0, 0]] # (26)
142 | fc_hiddens: [50]
143 | fc_activation: relu
144 | fc_batch_norm: false
145 | output_activation: identity
146 | layer_norm_output: true
147 | layer_norm_affine: false
148 | other_obs_model:
149 | fc_activation: elu
150 | fc_hiddens: [200, 200, 200]
151 | output_activation: elu
152 | actor:
153 | fc_activation: relu
154 | fc_hiddens: [1024, 1024, 1024]
155 | output_activation: identity
156 | policy_max_logstd: 2
157 | policy_min_logstd: -10
158 | critic:
159 | fc_activation: relu
160 | fc_hiddens: [1024, 1024, 1024, 1]
161 | output_activation: identity
162 | validate_every_iters: 10
163 | print_every: 50
164 |
--------------------------------------------------------------------------------
/configs/robosuite/core_imageonly.yaml:
--------------------------------------------------------------------------------
1 | parameters:
2 | seed: 42
3 | env:
4 | name: robosuite
5 | image_key: frontview_image
6 | crop_image: true
7 | crop_width: 144
8 | crop_height: 144
9 | crop_center_xy: [104, 144] # For the Door task center the crop at this location.
10 | #crop_center_xy: [128, 128]
11 | domain_randomize: true
12 | domain_randomization_config:
13 | randomize_color: true
14 | randomize_camera: true
15 | randomize_lighting: true
16 | randomize_dynamics: false
17 | #if True, randomize on every call to @reset. This, in
18 | # conjunction with setting @randomize_every_n_steps to 0, is useful to
19 | # generate a new domain per episode.
20 | randomize_on_reset: true
21 | randomize_every_n_steps: 0
22 | controller: OSC_POSE
23 | robosuite_config:
24 | control_freq: 20
25 | env_name: Door
26 | hard_reset: false
27 | horizon: 500
28 | ignore_done: true
29 | reward_scale: 1.0
30 | reward_shaping: true
31 | robots: [ Panda ]
32 | has_renderer: false
33 | has_offscreen_renderer: true
34 | use_object_obs: false
35 | use_camera_obs: true
36 | # ('frontview', 'birdview', 'agentview', 'sideview', 'robot0_robotview', 'robot0_eye_in_hand')
37 | camera_names: frontview
38 | camera_heights: 256
39 | camera_widths: 256
40 | camera_depths: false
41 | start_rl_training_after: 0
42 | num_envs: 1
43 | num_val_envs: 10
44 | num_episodes_per_val_env_for_reward: 2
45 | max_steps: 1000000 # This divided by action_repeat is the max updates.
46 | episode_steps: 500 # This divided by action_repeat is the max steps per episode.
47 | update_frequency_factor: 0.5
48 | initial_data_steps: 1000
49 | gamma: 0.99
50 | contrastive:
51 | inverse_temperature_init: 1.0
52 | softmax_over: both
53 | #mask_type: exclude_same_sequence
54 | #mask_type: exclude_other_sequences
55 | include_model_params_in_critic: true
56 | sac_deterministic_state: true
57 | recon_from_prior: true
58 | lr: 3.e-4
59 | lr_inverse_temp: 2.e-3
60 | lr_actor: 1.e-3
61 | lr_critic: 1.e-3
62 | lr_alpha: 1.e-3
63 | max_grad_norm_wm: 10
64 | max_grad_norm_critic: 100
65 | max_grad_norm_actor: 10
66 | max_grad_norm_log_alpha: 1.0
67 | initial_log_alpha: -2.3
68 | max_log_alpha: 3.0
69 | update_target_critic_after: 1
70 | update_target_critic_tau: 0.005
71 | weight_decay: 0.0
72 | replay_buffer_size: 100000
73 | batch_size: 32
74 | dynamics_seq_len: 32
75 | random_crop: true
76 | random_crop_padding: 0
77 | same_crop_across_time: true
78 | crop_height: 128
79 | crop_width: 128
80 | rollout_prior_init_t: 10
81 | loss_scales:
82 | eta_fwd: 0.01
83 | eta_r: 1.0
84 | eta_inv: 1.0
85 | eta_q: 0.5
86 | eta_s: 1.0
87 | wm:
88 | sample_state: true
89 | reward_prediction_from_prior: false
90 | decode_deterministic: false
91 | propagate_deterministic: false
92 | obs_encoder: # Input dims determined by binder symbol dims, or observation model's output dims.
93 | fc_activation: elu
94 | fc_hiddens: [256, 256]
95 | output_activation: identity
96 | layer_norm_output: true
97 | obs_decoder:
98 | fc_activation: elu
99 | fc_hiddens: [256, 256] # Will add another layer to make the output be the same dim as the encoder's input.
100 | output_activation: identity
101 | layer_norm_output: true
102 | reward_net:
103 | fc_activation: elu
104 | fc_hiddens: [128, 128, 1]
105 | output_activation: identity
106 | inverse_dynamics:
107 | fc_activation: elu
108 | fc_hiddens: [128, 128]
109 | output_activation: tanh
110 | dynamics:
111 | discrete: false
112 | rnn_state_dims: 200
113 | latent_dims: 64
114 | forward_dynamics_loss: kl #neg_log_prob
115 | free_kl: 1.0
116 | recurrent: true
117 | rnn_input: # Takes as input latent_dims + actions. The output of this is input to GRU.
118 | fc_activation: elu
119 | fc_hiddens: [400, 400]
120 | output_activation: elu
121 | posterior: # Takes as input rnn_state_dims + obs_embed
122 | fc_activation: elu
123 | fc_hiddens: [400, 400]
124 | output_activation: identity # Outputs logits. Num outputs = stoch_num_softmaxes * stoch_dims_per_softmax
125 | layer_norm_output: true
126 | layer_norm_affine: true
127 | prior: # takes as input rnn_state_dims
128 | fc_activation: elu
129 | fc_hiddens: [400, 400, 400]
130 | output_activation: identity # Outputs logits. Num outputs = stoch_num_softmaxes * stoch_dims_per_softmax
131 | layer_norm_output: true
132 | layer_norm_affine: true
133 | encoder:
134 | conv_batch_norm: false
135 | conv_activation: relu
136 | base_num_hid: 1 # The number of filters in fc and conv layers is multiplied by this.
137 | conv_filters:
138 | - [32, [3, 3], 2, [1, 1]] # (64)
139 | - [32, [3, 3], 2, [1, 1]] # (32)
140 | - [32, [3, 3], 1, [0, 0]] # (30)
141 | - [32, [3, 3], 1, [0, 0]] # (28)
142 | - [32, [3, 3], 1, [0, 0]] # (26)
143 | #conv_filters:
144 | # - [1, [4, 4], 2] # (64) # (112)
145 | # - [2, [4, 4], 2] # (32) # (56)
146 | # - [2, [4, 4], 2] # (16) #(28)
147 | # - [2, [4, 4], 2] # (8) # (14)
148 | # - [4, [4, 4], 2] # (4) # (7)
149 | fc_hiddens: [50]
150 | fc_activation: relu
151 | fc_batch_norm: false
152 | output_activation: identity
153 | layer_norm_output: true
154 | layer_norm_affine: false
155 | actor:
156 | fc_activation: relu
157 | fc_hiddens: [1024, 1024, 1024]
158 | output_activation: identity
159 | policy_max_logstd: 2
160 | policy_min_logstd: -10
161 | critic:
162 | fc_activation: relu
163 | fc_hiddens: [1024, 1024, 1024, 1]
164 | output_activation: identity
165 | validate_every_iters: 10
166 | print_every: 50
167 |
--------------------------------------------------------------------------------
/distracting_control/README.md:
--------------------------------------------------------------------------------
1 | # The Distracting Control Suite
2 |
3 | `distracting_control` extends `dm_control` with static or dynamic visual
4 | distractions in the form of changing colors, backgrounds, and camera poses.
5 | Details and experimental results can be found in our
6 | [paper](https://arxiv.org/pdf/2101.02722.pdf).
7 |
8 | ## Requirements and Installation
9 |
10 | * Clone this repository
11 | * `sh run.sh`
12 | * Follow the instructions and install
13 | [dm_control](https://github.com/deepmind/dm_control#requirements-and-installation). Make sure you set up your MuJoCo keys correctly.
14 | * Download the [DAVIS 2017
15 | dataset](https://davischallenge.org/davis2017/code.html). Make sure to select the 2017 TrainVal - Images and Annotations (480p). The training images will be used as distracting backgrounds.
16 |
17 | ## Instructions
18 |
19 | * You can run `distracting_control_demo.py` to generate sample images of the
20 | different tasks at different difficulties:
21 |
22 | ```
23 | python distracting_control_demo.py --davis_path=$HOME/DAVIS/JPEGImages/480p/ \
24 |   --output_dir=/tmp/distracting_control_demo
25 | ```
26 | * As seen in the demo, to generate an instance of the environment you simply
27 | need to import the suite and use `suite.load`, specifying the
28 | `dm_control` domain and task, choosing a difficulty, and providing the
29 | dataset_path, as sketched below.
30 |
31 | * Note the environment follows the dm_control environment APIs.
32 |
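The following is a minimal sketch of that usage. The keyword argument names and the `pixels` observation key are assumptions based on the demo and tests; consult `suite.py` and `distracting_control_demo.py` for the exact interface.
```
import os
import numpy as np

from distracting_control import suite

# Load cheetah-run with 'medium' dynamic distractions over DAVIS backgrounds.
# Keyword names are illustrative -- see suite.py for the exact signature.
env = suite.load(
    'cheetah', 'run',
    difficulty='medium',
    dynamic=True,
    background_dataset_path=os.path.expanduser('~/DAVIS/JPEGImages/480p/'))

# The returned environment follows the dm_control API.
spec = env.action_spec()
time_step = env.reset()
for _ in range(10):
    action = np.random.uniform(spec.minimum, spec.maximum, size=spec.shape)
    time_step = env.step(action)
    frame = time_step.observation['pixels']  # assuming the default pixels key
```
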
33 | ## Paper
34 |
35 | If you use this code, please cite the accompanying [paper](https://arxiv.org/pdf/2101.02722.pdf) as:
36 |
37 | ```
38 | @article{stone2021distracting,
39 | title={The Distracting Control Suite -- A Challenging Benchmark for Reinforcement Learning from Pixels},
40 | author={Austin Stone and Oscar Ramirez and Kurt Konolige and Rico Jonschkowski},
41 | year={2021},
42 | journal={arXiv preprint arXiv:2101.02722},
43 | }
44 | ```
45 |
46 | ## Disclaimer
47 |
48 | This is not an official Google product.
49 |
--------------------------------------------------------------------------------
/distracting_control/background.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # Lint as: python3
17 | """A wrapper for dm_control environments which applies color distractions."""
18 | import os
19 |
20 | from PIL import Image
21 | import collections
22 | from dm_control.rl import control
23 | import numpy as np
24 |
25 | import tensorflow as tf
26 | from dm_control.mujoco.wrapper import mjbindings
27 |
28 | DAVIS17_TRAINING_VIDEOS = [
29 | 'bear', 'bmx-bumps', 'boat', 'boxing-fisheye', 'breakdance-flare', 'bus',
30 | 'car-turn', 'cat-girl', 'classic-car', 'color-run', 'crossing',
31 | 'dance-jump', 'dancing', 'disc-jockey', 'dog-agility', 'dog-gooses',
32 | 'dogs-scale', 'drift-turn', 'drone', 'elephant', 'flamingo', 'hike',
33 | 'hockey', 'horsejump-low', 'kid-football', 'kite-walk', 'koala',
34 | 'lady-running', 'lindy-hop', 'longboard', 'lucia', 'mallard-fly',
35 | 'mallard-water', 'miami-surf', 'motocross-bumps', 'motorbike', 'night-race',
36 | 'paragliding', 'planes-water', 'rallye', 'rhino', 'rollerblade',
37 | 'schoolgirls', 'scooter-board', 'scooter-gray', 'sheep', 'skate-park',
38 | 'snowboard', 'soccerball', 'stroller', 'stunt', 'surf', 'swing', 'tennis',
39 | 'tractor-sand', 'train', 'tuk-tuk', 'upside-down', 'varanus-cage', 'walking'
40 | ]
41 | DAVIS17_VALIDATION_VIDEOS = [
42 | 'bike-packing', 'blackswan', 'bmx-trees', 'breakdance', 'camel',
43 | 'car-roundabout', 'car-shadow', 'cows', 'dance-twirl', 'dog', 'dogs-jump',
44 | 'drift-chicane', 'drift-straight', 'goat', 'gold-fish', 'horsejump-high',
45 | 'india', 'judo', 'kite-surf', 'lab-coat', 'libby', 'loading', 'mbike-trick',
46 | 'motocross-jump', 'paragliding-launch', 'parkour', 'pigs', 'scooter-black',
47 | 'shooting', 'soapbox'
48 | ]
49 | SKY_TEXTURE_INDEX = 0
50 | Texture = collections.namedtuple('Texture', ('size', 'address', 'textures'))
51 |
52 |
53 | def imread(filename):
54 | img = Image.open(filename)
55 | img_np = np.asarray(img)
56 | return img_np
57 |
58 |
59 | def size_and_flatten(image, ref_height, ref_width):
60 | # Resize image if necessary and flatten the result.
61 | image_height, image_width = image.shape[:2]
62 |
63 | if image_height != ref_height or image_width != ref_width:
64 | image = tf.cast(tf.image.resize(image, [ref_height, ref_width]), tf.uint8)
65 | return tf.reshape(image, [-1]).numpy()
66 |
67 |
68 | def blend_to_background(alpha, image, background):
69 | if alpha == 1.0:
70 | return image
71 | elif alpha == 0.0:
72 | return background
73 | else:
74 | return (alpha * image.astype(np.float32)
75 | + (1. - alpha) * background.astype(np.float32)).astype(np.uint8)
76 |
77 |
78 | class DistractingBackgroundEnv(control.Environment):
79 | """Environment wrapper for background visual distraction.
80 |
81 | **NOTE**: This wrapper should be applied BEFORE the pixel wrapper to make sure
82 | the background image changes are applied before rendering occurs.
83 | """
84 |
85 | def __init__(self,
86 | env,
87 | dataset_path=None,
88 | dataset_videos=None,
89 | video_alpha=1.0,
90 | ground_plane_alpha=1.0,
91 | num_videos=None,
92 | dynamic=False,
93 | seed=None,
94 | shuffle_buffer_size=None):
95 |
96 | if not 0 <= video_alpha <= 1:
97 | raise ValueError('`video_alpha` must be in the range [0, 1]')
98 |
99 | self._env = env
100 | self._video_alpha = video_alpha
101 | self._ground_plane_alpha = ground_plane_alpha
102 | self._random_state = np.random.RandomState(seed=seed)
103 | self._dynamic = dynamic
104 | self._shuffle_buffer_size = shuffle_buffer_size
105 | self._background = None
106 | self._current_img_index = 0
107 |
108 | if not dataset_path or num_videos == 0:
109 | # Allow running the wrapper without backgrounds to still set the ground
110 | # plane alpha value.
111 | self._video_paths = []
112 | else:
113 | # Use all videos if no specific ones were passed.
114 | if not dataset_videos:
115 | dataset_videos = sorted(tf.io.gfile.listdir(dataset_path))
116 | # Replace video placeholders 'train'/'val' with the list of videos.
117 | elif dataset_videos in ['train', 'training']:
118 | dataset_videos = DAVIS17_TRAINING_VIDEOS
119 | elif dataset_videos in ['val', 'validation']:
120 | dataset_videos = DAVIS17_VALIDATION_VIDEOS
121 | # Get complete paths for all videos.
122 | video_paths = [
123 | os.path.join(dataset_path, subdir) for subdir in dataset_videos
124 | ]
125 |
126 | # Optionally use only the first num_paths many paths.
127 | if num_videos is not None:
128 | if num_videos > len(video_paths) or num_videos < 0:
129 | raise ValueError(f'`num_videos` is {num_videos} but '
130 | 'should not be larger than the number of available '
131 | f'background paths ({len(video_paths)}) and at '
132 | 'least 0.')
133 | video_paths = video_paths[:num_videos]
134 |
135 | self._video_paths = video_paths
136 |
137 | def reset(self):
138 | """Reset the background state."""
139 | time_step = self._env.reset()
140 | self._reset_background()
141 | return time_step
142 |
143 | def _reset_background(self):
144 | # Make grid semi-transparent.
145 | if self._ground_plane_alpha is not None:
146 | self._env.physics.named.model.mat_rgba['grid',
147 | 'a'] = self._ground_plane_alpha
148 |
149 | # For some reason the height of the skybox is set to 4800 by default,
150 | # which does not work with new textures.
151 | self._env.physics.model.tex_height[SKY_TEXTURE_INDEX] = 800
152 |
153 | # Set the sky texture reference.
154 | sky_height = self._env.physics.model.tex_height[SKY_TEXTURE_INDEX]
155 | sky_width = self._env.physics.model.tex_width[SKY_TEXTURE_INDEX]
156 | sky_size = sky_height * sky_width * 3
157 | sky_address = self._env.physics.model.tex_adr[SKY_TEXTURE_INDEX]
158 |
159 | sky_texture = self._env.physics.model.tex_rgb[sky_address:sky_address +
160 | sky_size].astype(np.float32)
161 |
162 | if self._video_paths:
163 |
164 | if self._shuffle_buffer_size:
165 | # Shuffle images from all videos together to get background frames.
166 | file_names = [
167 | os.path.join(path, fn)
168 | for path in self._video_paths
169 | for fn in tf.io.gfile.listdir(path)
170 | ]
171 | self._random_state.shuffle(file_names)
172 | # Load only the first n images for performance reasons.
173 | file_names = file_names[:self._shuffle_buffer_size]
174 | images = [imread(fn) for fn in file_names]
175 | else:
176 | # Randomly pick a video and load all images.
177 | video_path = self._random_state.choice(self._video_paths)
178 | file_names = tf.io.gfile.listdir(video_path)
179 | if not self._dynamic:
180 | # Randomly pick a single static frame.
181 | file_names = [self._random_state.choice(file_names)]
182 | images = [imread(os.path.join(video_path, fn)) for fn in file_names]
183 |
184 | # Pick a random starting point and stepping direction.
185 | self._current_img_index = self._random_state.choice(len(images))
186 | self._step_direction = self._random_state.choice([-1, 1])
187 |
188 | # Prepare images in the texture format by resizing and flattening.
189 |
190 | # Generate image textures.
191 | texturized_images = []
192 | for image in images:
193 | image_flattened = size_and_flatten(image, sky_height, sky_width)
194 | new_texture = blend_to_background(self._video_alpha, image_flattened,
195 | sky_texture)
196 | texturized_images.append(new_texture)
197 |
198 | else:
199 |
200 | self._current_img_index = 0
201 | texturized_images = [sky_texture]
202 |
203 | self._background = Texture(sky_size, sky_address, texturized_images)
204 | self._apply()
205 |
206 | def step(self, action):
207 | time_step = self._env.step(action)
208 |
209 | if time_step.first():
210 | self._reset_background()
211 | return time_step
212 |
213 | if self._dynamic and self._video_paths:
214 | # Move forward / backward in the image sequence by updating the index.
215 | self._current_img_index += self._step_direction
216 |
217 | # Start moving forward if we are past the start of the images.
218 | if self._current_img_index <= 0:
219 | self._current_img_index = 0
220 | self._step_direction = abs(self._step_direction)
221 | # Start moving backwards if we are past the end of the images.
222 | if self._current_img_index >= len(self._background.textures):
223 | self._current_img_index = len(self._background.textures) - 1
224 | self._step_direction = -abs(self._step_direction)
225 |
226 | self._apply()
227 | return time_step
228 |
229 | def _apply(self):
230 | """Apply the background texture to the physics."""
231 |
232 | if self._background:
233 | start = self._background.address
234 | end = self._background.address + self._background.size
235 | texture = self._background.textures[self._current_img_index]
236 |
237 | self._env.physics.model.tex_rgb[start:end] = texture
238 | # Upload the new texture to the GPU. Note: we need to make sure that the
239 | # OpenGL context belonging to this Physics instance is the current one.
240 | with self._env.physics.contexts.gl.make_current() as ctx:
241 | ctx.call(
242 | mjbindings.mjlib.mjr_uploadTexture,
243 | self._env.physics.model.ptr,
244 | self._env.physics.contexts.mujoco.ptr,
245 | SKY_TEXTURE_INDEX,
246 | )
247 |
248 | # Forward property and method calls to self._env.
249 | def __getattr__(self, attr):
250 | if hasattr(self._env, attr):
251 | return getattr(self._env, attr)
252 | raise AttributeError("'{}' object has no attribute '{}'".format(
253 | type(self).__name__, attr))
254 |
--------------------------------------------------------------------------------
/distracting_control/camera.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # Lint as: python3
17 | """A wrapper for dm_control environments which applies camera distractions."""
18 |
19 | import copy
20 | from dm_control.rl import control
21 | import numpy as np
22 |
23 | CAMERA_MODES = ['fixed', 'track', 'trackcom', 'targetbody', 'targetbodycom']
24 |
25 |
26 | def eul2mat(theta):
27 | """Converts euler angles (x, y, z) to a rotation matrix."""
28 |
29 | return np.array([[
30 | np.cos(theta[1]) * np.cos(theta[2]),
31 | np.sin(theta[0]) * np.sin(theta[1]) * np.cos(theta[2]) -
32 | np.sin(theta[2]) * np.cos(theta[0]),
33 | np.sin(theta[1]) * np.cos(theta[0]) * np.cos(theta[2]) +
34 | np.sin(theta[0]) * np.sin(theta[2])
35 | ],
36 | [
37 | np.sin(theta[2]) * np.cos(theta[1]),
38 | np.sin(theta[0]) * np.sin(theta[1]) * np.sin(theta[2]) +
39 | np.cos(theta[0]) * np.cos(theta[2]),
40 | np.sin(theta[1]) * np.sin(theta[2]) * np.cos(theta[0]) -
41 | np.sin(theta[0]) * np.cos(theta[2])
42 | ],
43 | [
44 | -np.sin(theta[1]),
45 | np.sin(theta[0]) * np.cos(theta[1]),
46 | np.cos(theta[0]) * np.cos(theta[1])
47 | ]])
48 |
49 |
50 | def _mat_from_theta(cos_theta, sin_theta, a):
51 | """Builds a rotation matrix from theta and an orientation vector."""
52 |
53 | row1 = [
54 | cos_theta + a[0]**2. * (1. - cos_theta),
55 | a[0] * a[1] * (1 - cos_theta) - a[2] * sin_theta,
56 | a[0] * a[2] * (1 - cos_theta) + a[1] * sin_theta
57 | ]
58 | row2 = [
59 | a[1] * a[0] * (1 - cos_theta) + a[2] * sin_theta,
60 | cos_theta + a[1]**2. * (1 - cos_theta),
61 | a[1] * a[2] * (1. - cos_theta) - a[0] * sin_theta
62 | ]
63 | row3 = [
64 | a[2] * a[0] * (1. - cos_theta) - a[1] * sin_theta,
65 | a[2] * a[1] * (1. - cos_theta) + a[0] * sin_theta,
66 | cos_theta + (a[2]**2.) * (1. - cos_theta)
67 | ]
68 | return np.stack([row1, row2, row3])
69 |
70 |
71 | def rotvec2mat(theta, vec):
72 | """Converts a rotation around a vector to a rotation matrix."""
73 |
74 | a = vec / np.sqrt(np.sum(vec**2.))
75 | sin_theta = np.sin(theta)
76 | cos_theta = np.cos(theta)
77 |
78 | return _mat_from_theta(cos_theta, sin_theta, a)
79 |
80 |
81 | def get_lookat_xmat_no_roll(agent_pos, camera_pos):
82 | """Solves for the cam rotation centering the agent with 0 roll."""
83 |
84 | # NOTE(austinstone): This method leads to wild oscillations around the north
85 | # and south poles.
86 | # For example, if agent is at (0., 0., 0.) and the camera is at (.01, 0., 1.),
87 | # this will produce a yaw of 90 degrees whereas if the camera is slightly
88 | # adjacent at (-.01, 0., 1.) this will produce a yaw of -90 degrees. I'm
89 | # not sure what the fix is, as this seems like the behavior we want in all
90 | # environments except for reacher.
91 | delta_vec = agent_pos - camera_pos
92 | delta_vec /= np.sqrt(np.sum(delta_vec**2.))
93 | yaw = np.arctan2(delta_vec[0], delta_vec[1])
94 | pitch = np.arctan2(delta_vec[2], np.sqrt(np.sum(delta_vec[:2]**2.)))
95 | pitch += np.pi / 2. # Camera starts out looking at [0, 0, -1.]
96 | return eul2mat([pitch, 0., -yaw]).flatten()
97 |
98 |
99 | def get_lookat_xmat(agent_pos, camera_pos):
100 | """Solves for the cam rotation centering the agent, allowing roll."""
101 |
102 | # Solve for the rotation which centers the agent in the scene.
103 | delta_vec = agent_pos - camera_pos
104 | delta_vec /= np.sqrt(np.sum(delta_vec**2.))
105 | y_vec = np.array([0., 0., -1.]) # This is where the cam starts from.
106 | a = np.cross(y_vec, delta_vec)
107 | sin_theta = np.sqrt(np.sum(a**2.))
108 | cos_theta = np.dot(delta_vec, y_vec)
109 | a /= (np.sqrt(np.sum(a**2.)) + .0001)
110 | return _mat_from_theta(cos_theta, sin_theta, a)
111 |
112 |
113 | def cart2sphere(cart):
114 | r = np.sqrt(np.sum(cart**2.))
115 | h_angle = np.arctan2(cart[1], cart[0])
116 | v_angle = np.arctan2(np.sqrt(np.sum(cart[:2]**2.)), cart[2])
117 | return np.array([r, h_angle, v_angle])
118 |
119 |
120 | def sphere2cart(sphere):
121 | r, h_angle, v_angle = sphere
122 | x = r * np.sin(v_angle) * np.cos(h_angle)
123 | y = r * np.sin(v_angle) * np.sin(h_angle)
124 | z = r * np.cos(v_angle)
125 | return np.array([x, y, z])
126 |
127 |
128 | def clip_cam_position(position, min_radius, max_radius, min_h_angle,
129 | max_h_angle, min_v_angle, max_v_angle):
130 | new_position = [-1., -1., -1.]
131 | new_position[0] = np.clip(position[0], min_radius, max_radius)
132 | new_position[1] = np.clip(position[1], min_h_angle, max_h_angle)
133 | new_position[2] = np.clip(position[2], min_v_angle, max_v_angle)
134 | return new_position
135 |
136 |
137 | def get_lookat_point(physics, camera_id):
138 | """Get the point that the camera is looking at.
139 |
140 | It is assumed that the "point" the camera looks at is the agent's distance
141 | away, projected along the camera viewing matrix.
142 |
143 | Args:
144 | physics: mujoco physics objects
145 | camera_id: int
146 |
147 | Returns:
148 | position: float32 np.array of length 3
149 | """
150 | dist_to_agent = physics.named.data.cam_xpos[
151 | camera_id] - physics.named.data.subtree_com[1]
152 | dist_to_agent = np.sqrt(np.sum(dist_to_agent**2.))
153 | initial_viewing_mat = copy.deepcopy(physics.named.data.cam_xmat[camera_id])
154 | initial_viewing_mat = np.reshape(initial_viewing_mat, (3, 3))
155 | z_vec = np.array([0., 0., -dist_to_agent])
156 | rotated_vec = np.dot(initial_viewing_mat, z_vec)
157 | return rotated_vec + physics.named.data.cam_xpos[camera_id]
158 |
159 |
160 | class DistractingCameraEnv(control.Environment):
161 | """Environment wrapper for camera pose visual distraction.
162 |
163 | **NOTE**: This wrapper should be applied BEFORE the pixel wrapper to make sure
164 | the camera pose changes are applied before rendering occurs.
165 | """
166 |
167 | def __init__(self,
168 | env,
169 | camera_id,
170 | horizontal_delta,
171 | vertical_delta,
172 | max_vel,
173 | vel_std,
174 | roll_delta,
175 | max_roll_vel,
176 | roll_std,
177 | max_zoom_in_percent,
178 | max_zoom_out_percent,
179 | limit_to_upper_quadrant=False,
180 | seed=None):
181 | self._env = env
182 | self._camera_id = camera_id
183 | self._horizontal_delta = horizontal_delta
184 | self._vertical_delta = vertical_delta
185 |
186 | self._horizontal_delta = horizontal_delta
187 | self._vertical_delta = vertical_delta
188 | self._max_vel = max_vel
189 | self._vel_std = vel_std
190 | self._roll_delta = roll_delta
191 | self._max_roll_vel = max_roll_vel
192 | self._roll_vel_std = roll_std
193 | self._max_zoom_in_percent = max_zoom_in_percent
194 | self._max_zoom_out_percent = max_zoom_out_percent
195 | self._limit_to_upper_quadrant = limit_to_upper_quadrant
196 |
197 | self._random_state = np.random.RandomState(seed=seed)
198 |
199 | # These camera state parameters will be set on the first reset call.
200 | self._camera_type = None
201 | self._camera_initial_lookat_point = None
202 |
203 | self._camera_vel = None
204 | self._max_h_angle = None
205 | self._max_v_angle = None
206 | self._min_h_angle = None
207 | self._min_v_angle = None
208 | self._radius = None
209 | self._roll_vel = None
210 | self._vel_scaling = None
211 |
212 | def setup_camera(self):
213 | """Set up camera motion ranges and state."""
214 | # Define boundaries on the range of the camera motion.
215 | mode = self._env._physics.model.cam_mode[0]
216 |
217 | camera_type = CAMERA_MODES[mode]
218 | assert camera_type in ['fixed', 'trackcom']
219 |
220 | self._camera_type = camera_type
221 | self._cam_initial_lookat_point = get_lookat_point(self._env.physics,
222 | self._camera_id)
223 |
224 | start_pos = copy.deepcopy(
225 | self._env.physics.named.data.cam_xpos[self._camera_id])
226 |
227 | if self._camera_type != 'fixed':
228 | # Center the camera relative to the agent's center of mass.
229 | start_pos -= self._env.physics.named.data.subtree_com[1]
230 |
231 | start_r, start_h_angle, start_v_angle = cart2sphere(start_pos)
232 | # Scale the velocity by the starting radius. Most environments have radius 4,
233 | # but this downscales the velocity for the envs with radius < 4.
234 | self._vel_scaling = start_r / 4.
235 | self._max_h_angle = start_h_angle + self._horizontal_delta
236 | self._min_h_angle = start_h_angle - self._horizontal_delta
237 | self._max_v_angle = start_v_angle + self._vertical_delta
238 | self._min_v_angle = start_v_angle - self._vertical_delta
239 |
240 | if self._limit_to_upper_quadrant:
241 | # A centered cam is at np.pi / 2.
242 | self._max_v_angle = min(self._max_v_angle, np.pi / 2.)
243 | self._min_v_angle = max(self._min_v_angle, 0.)
244 | # A centered cam is at -np.pi / 2.
245 | self._max_h_angle = min(self._max_h_angle, 0.)
246 | self._min_h_angle = max(self._min_h_angle, -np.pi)
247 |
248 | self._max_roll = self._roll_delta
249 | self._min_roll = -self._roll_delta
250 | self._min_radius = max(start_r - start_r * self._max_zoom_in_percent, 0.)
251 | self._max_radius = start_r + start_r * self._max_zoom_out_percent
252 |
253 | # Decide the starting position for the camera.
254 | self._h_angle = self._random_state.uniform(self._min_h_angle,
255 | self._max_h_angle)
256 |
257 | self._v_angle = self._random_state.uniform(self._min_v_angle,
258 | self._max_v_angle)
259 |
260 | self._radius = self._random_state.uniform(self._min_radius,
261 | self._max_radius)
262 |
263 | self._roll = self._random_state.uniform(self._min_roll, self._max_roll)
264 |
265 | # Decide the starting velocity for the camera.
266 | vel = self._random_state.randn(3)
267 | vel /= np.sqrt(np.sum(vel**2.))
268 | vel *= self._random_state.uniform(0., self._max_vel)
269 | self._camera_vel = vel
270 | self._roll_vel = self._random_state.uniform(-self._max_roll_vel,
271 | self._max_roll_vel)
272 |
273 | def reset(self):
274 | """Reset the camera state. """
275 | time_step = self._env.reset()
276 | self.setup_camera()
277 | self._apply()
278 | return time_step
279 |
280 |
281 | def step(self, action):
282 | time_step = self._env.step(action)
283 |
284 | if time_step.first():
285 | self.setup_camera()
286 |
287 | self._apply()
288 | return time_step
289 |
290 | def _apply(self):
291 | if not self._camera_type:
292 | self.setup_camera()
293 |
294 | # Random walk the velocity.
295 | vel_delta = self._random_state.randn(3)
296 | self._camera_vel += vel_delta * self._vel_std * self._vel_scaling
297 | self._roll_vel += self._random_state.randn() * self._roll_vel_std
298 |
299 | # Clip velocity if it gets too big.
300 | vel_norm = np.sqrt(np.sum(self._camera_vel**2.))
301 | if vel_norm > self._max_vel * self._vel_scaling:
302 | self._camera_vel *= (self._max_vel * self._vel_scaling) / vel_norm
303 |
304 | self._roll_vel = np.clip(self._roll_vel, -self._max_roll_vel,
305 | self._max_roll_vel)
306 |
307 | cart_cam_pos = sphere2cart([self._radius, self._h_angle, self._v_angle])
308 | # Apply velocity vector to camera
309 | sphere_cam_pos2 = cart2sphere(cart_cam_pos + self._camera_vel)
310 | sphere_cam_pos2 = clip_cam_position(sphere_cam_pos2, self._min_radius,
311 | self._max_radius, self._min_h_angle,
312 | self._max_h_angle, self._min_v_angle,
313 | self._max_v_angle)
314 |
315 | self._camera_vel = sphere2cart(sphere_cam_pos2) - cart_cam_pos
316 |
317 | self._radius, self._h_angle, self._v_angle = sphere_cam_pos2
318 |
319 | roll2 = self._roll + self._roll_vel
320 | roll2 = np.clip(roll2, self._min_roll, self._max_roll)
321 |
322 | self._roll_vel = roll2 - self._roll
323 | self._roll = roll2
324 |
325 | cart_cam_pos = sphere2cart(sphere_cam_pos2)
326 |
327 | if self._limit_to_upper_quadrant:
328 | lookat_method = get_lookat_xmat_no_roll
329 | else:
330 | # This method avoids jitteriness at the pole but allows some roll
331 | # in the camera matrix. This is important for reacher.
332 | lookat_method = get_lookat_xmat
333 |
334 | if self._camera_type == 'fixed':
335 | lookat_mat = lookat_method(self._cam_initial_lookat_point,
336 | cart_cam_pos)
337 | else:
338 | # Go from agent centric to world coords
339 | cart_cam_pos += self._env.physics.named.data.subtree_com[1]
340 | lookat_mat = lookat_method(
341 | get_lookat_point(self._env.physics, self._camera_id), cart_cam_pos)
342 |
343 | lookat_mat = np.reshape(lookat_mat, (3, 3))
344 | roll_mat = rotvec2mat(self._roll, np.array([0., 0., 1.]))
345 | xmat = np.dot(lookat_mat, roll_mat)
346 | self._env.physics.named.data.cam_xpos[self._camera_id] = cart_cam_pos
347 | self._env.physics.named.data.cam_xmat[self._camera_id] = xmat.flatten()
348 |
349 | # Forward property and method calls to self._env.
350 | def __getattr__(self, attr):
351 | if hasattr(self._env, attr):
352 | return getattr(self._env, attr)
353 | raise AttributeError("'{}' object has no attribute '{}'".format(
354 | type(self).__name__, attr))
355 |
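356 | # Summary sketch (editor note) of the per-step camera update performed in
357 | # _apply() above, i.e. a clipped random walk in spherical coordinates plus a
358 | # rolled look-at orientation written directly into the MuJoCo camera pose:
359 | #   vel      <- clip(vel + N(0, vel_std) * vel_scaling, max_vel * vel_scaling)
360 | #   pos      <- clip_sphere(cart2sphere(sphere2cart(pos) + vel))
361 | #   roll     <- clip(roll + roll_vel, [-roll_delta, +roll_delta])
362 | #   cam_xmat <- lookat(target, pos) @ rotvec2mat(roll, z_axis)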
--------------------------------------------------------------------------------
/distracting_control/camera_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for camera movement code."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from absl.testing import absltest
23 | from dm_control import suite as dm_control_suite
24 | from dm_control.suite import cartpole
25 | from dm_control.suite.wrappers import pixels
26 | import numpy as np
27 |
28 | from distracting_control import camera
29 |
30 |
31 | def get_camera_params(domain_name, scale, dynamic):
32 | return dict(
33 | vertical_delta=np.pi / 2 * scale,
34 | horizontal_delta=np.pi / 2 * scale,
35 | # Limit camera to -90 / 90 degree rolls.
36 | roll_delta=np.pi / 2. * scale,
37 | vel_std=.1 * scale if dynamic else 0.,
38 | max_vel=.4 * scale if dynamic else 0.,
39 | roll_std=np.pi / 300 * scale if dynamic else 0.,
40 | max_roll_vel=np.pi / 50 * scale if dynamic else 0.,
41 | # Allow the camera to zoom in at most 50%.
42 | max_zoom_in_percent=.5 * scale,
43 | # Allow the camera to zoom out at most 200%.
44 | max_zoom_out_percent=1.5 * scale,
45 | limit_to_upper_quadrant='reacher' not in domain_name,
46 | )
47 |
48 |
49 | def distraction_wrap(env, domain_name):
50 | camera_kwargs = get_camera_params(
51 | domain_name=domain_name, scale=0.0, dynamic=True)
52 | return camera.DistractingCameraEnv(env, camera_id=0, **camera_kwargs)
53 |
54 |
55 | class CameraTest(absltest.TestCase):
56 |
57 | def test_dynamic(self):
58 | camera_kwargs = get_camera_params(
59 | domain_name='cartpole', scale=0.1, dynamic=True)
60 | env = cartpole.swingup()
61 | env = camera.DistractingCameraEnv(env, camera_id=0, **camera_kwargs)
62 | env = pixels.Wrapper(env, render_kwargs={'camera_id': 0})
63 | action_spec = env.action_spec()
64 | time_step = env.reset()
65 | frames = []
66 | while not time_step.last() and len(frames) < 10:
67 | action = np.random.uniform(
68 | action_spec.minimum, action_spec.maximum, size=action_spec.shape)
69 | time_step = env.step(action)
70 | frames.append(time_step.observation['pixels'])
71 | self.assertEqual(frames[0].shape, (240, 320, 3))
72 |
73 | def test_get_lookat_mat(self):
74 | agent_pos = np.array([1., -3., 4.])
75 | cam_position = np.array([0., 0., 0.])
76 | mat = camera.get_lookat_xmat_no_roll(agent_pos, cam_position)
77 | agent_pos = agent_pos / np.sqrt(np.sum(agent_pos**2.))
78 | start = np.array([0., 0., -1.]) # Cam starts looking down Z.
79 | out = np.dot(mat.reshape((3, 3)), start)
80 | self.assertTrue(np.isclose(np.max(np.abs(out - agent_pos)), 0.))
81 |
82 | def test_spherical_conversion(self):
83 | cart = np.array([1.4, -2.8, 3.9])
84 | sphere = camera.cart2sphere(cart)
85 | cart2 = camera.sphere2cart(sphere)
86 | self.assertTrue(np.isclose(np.max(np.abs(cart2 - cart)), 0.))
87 |
88 | def test_envs_same(self):
89 | # Test that camera augmentations with magnitude 0 give the same results
90 | # as when no camera augmentations are used.
91 | render_kwargs = {'width': 84, 'height': 84, 'camera_id': 0}
92 | domain_and_task = [('cartpole', 'swingup'),
93 | ('reacher', 'easy'),
94 | ('finger', 'spin'),
95 | ('cheetah', 'run'),
96 | ('ball_in_cup', 'catch'),
97 | ('walker', 'walk')]
98 | for (domain, task) in domain_and_task:
99 | seed = 42
100 | envs = [('baseline',
101 | pixels.Wrapper(
102 | dm_control_suite.load(
103 | domain, task, task_kwargs={'random': seed}),
104 | render_kwargs=render_kwargs)),
105 | ('no-wrapper',
106 | pixels.Wrapper(
107 | dm_control_suite.load(
108 | domain, task, task_kwargs={'random': seed}),
109 | render_kwargs=render_kwargs)),
110 | ('w/-camera_kwargs',
111 | pixels.Wrapper(
112 | distraction_wrap(
113 | dm_control_suite.load(
114 | domain, task, task_kwargs={'random': seed}), domain),
115 | render_kwargs=render_kwargs))]
116 | frames = []
117 | for _, env in envs:
118 | random_state = np.random.RandomState(42)
119 | action_spec = env.action_spec()
120 | time_step = env.reset()
121 | frames.append([])
122 | while not time_step.last() and len(frames[-1]) < 20:
123 | action = random_state.uniform(
124 | action_spec.minimum, action_spec.maximum, size=action_spec.shape)
125 | time_step = env.step(action)
126 | frame = time_step.observation['pixels'][:, :, 0:3]
127 | frames[-1].append(frame)
128 | frames_np = np.array(frames)
129 | for i in range(1, len(envs)):
130 | difference = np.mean(abs(frames_np[0] - frames_np[i]))
131 | self.assertEqual(difference, 0.)
132 |
133 | if __name__ == '__main__':
134 | absltest.main()
135 |
--------------------------------------------------------------------------------
/distracting_control/color.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # Lint as: python3
17 | """A wrapper for dm_control environments which applies color distractions."""
18 |
19 | from dm_control.rl import control
20 | import numpy as np
21 |
22 |
23 | class DistractingColorEnv(control.Environment):
24 | """Environment wrapper for color visual distraction.
25 |
26 | **NOTE**: This wrapper should be applied BEFORE the pixel wrapper to make sure
27 | the color changes are applied before rendering occurs.
28 | """
29 |
30 | def __init__(self, env, step_std, max_delta, seed=None):
31 | """Initialize the environment wrapper.
32 |
33 | Args:
34 | env: instance of dm_control Environment to wrap with augmentations. step_std: std of the per-step color random walk. max_delta: maximum deviation from the original material colors. seed: optional seed for the random state.
35 | """
36 | if step_std < 0:
37 | raise ValueError("`step_std` must be greater than or equal to 0.")
38 | if max_delta < 0:
39 | raise ValueError("`max_delta` must be greater than or equal to 0.")
40 |
41 | self._env = env
42 | self._step_std = step_std
43 | self._max_delta = max_delta
44 | self._random_state = np.random.RandomState(seed)
45 |
46 | self._cam_type = None
47 | self._current_rgb = None
48 | self._max_rgb = None
49 | self._min_rgb = None
50 | self._original_rgb = None
51 |
52 | def reset(self):
53 | """Reset the distractions state."""
54 | time_step = self._env.reset()
55 | self._reset_color()
56 | return time_step
57 |
58 | def _reset_color(self):
59 | # Save all original colors.
60 | if self._original_rgb is None:
61 | self._original_rgb = np.copy(self._env.physics.model.mat_rgba)[:, :3]
62 | # Determine minimum and maximum rgb values.
63 | self._max_rgb = np.clip(self._original_rgb + self._max_delta, 0.0, 1.0)
64 | self._min_rgb = np.clip(self._original_rgb - self._max_delta, 0.0, 1.0)
65 |
66 | # Pick random colors in the allowed ranges.
67 | r = self._random_state.uniform(size=self._min_rgb.shape)
68 | self._current_rgb = self._min_rgb + r * (self._max_rgb - self._min_rgb)
69 |
70 | # Apply the color changes.
71 | self._env.physics.model.mat_rgba[:, :3] = self._current_rgb
72 |
73 | def step(self, action):
74 | time_step = self._env.step(action)
75 |
76 | if time_step.first():
77 | self._reset_color()
78 | return time_step
79 |
80 | color_change = self._random_state.randn(*self._current_rgb.shape)
81 | color_change = color_change * self._step_std
82 |
83 | new_color = self._current_rgb + color_change
84 |
85 | self._current_rgb = np.clip(
86 | new_color,
87 | a_min=self._min_rgb,
88 | a_max=self._max_rgb,
89 | )
90 |
91 | # Apply the color changes.
92 | self._env.physics.model.mat_rgba[:, :3] = self._current_rgb
93 | return time_step
94 |
95 | # Forward property and method calls to self._env.
96 | def __getattr__(self, attr):
97 | if hasattr(self._env, attr):
98 | return getattr(self._env, attr)
99 | raise AttributeError("'{}' object has no attribute '{}'".format(
100 | type(self).__name__, attr))
101 |
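102 |
103 | if __name__ == '__main__':
104 |   # Usage sketch (editor addition, not part of the original module): apply
105 |   # the color wrapper BEFORE the pixel wrapper, as noted in the class
106 |   # docstring. The step_std / max_delta values are illustrative assumptions.
107 |   from dm_control.suite import cartpole
108 |   from dm_control.suite.wrappers import pixels
109 |
110 |   env = DistractingColorEnv(cartpole.swingup(), step_std=0.01, max_delta=0.2)
111 |   env = pixels.Wrapper(env, render_kwargs={'camera_id': 0})
112 |   first = env.reset()  # Material colors are re-randomized on every reset.
113 |   print(first.observation['pixels'].shape)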
--------------------------------------------------------------------------------
/distracting_control/distracting_control_demo.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """A simple demo that produces an image from the environment."""
17 | import os
18 | from absl import app
19 | from absl import flags
20 |
21 | import PIL
22 |
23 | from distracting_control import suite
24 |
25 | FLAGS = flags.FLAGS
26 |
27 | # NOTE: This flag is empty by default. You should download and extract the
28 | # DAVIS dataset and point this path to the location of DAVIS,
29 | # e.g. /tmp/DAVIS/JPEGImages/480p.
30 | flags.DEFINE_string(
31 | 'davis_path',
32 | '',
33 | 'Path to DAVIS images, used for background distractions.')
34 | flags.DEFINE_string('output_dir', '/tmp/distracting_control_demo',
35 | 'Directory where the results are being saved.')
36 |
37 |
38 | def main(unused_argv):
39 |
40 | if not FLAGS.davis_path:
41 | raise ValueError(
42 | 'You must download and extract the DAVIS dataset and pass a path '
43 | 'to the videos, e.g.: /tmp/DAVIS/JPEGImages/480p')
44 |
45 | for i, difficulty in enumerate(['easy', 'medium', 'hard']):
46 | for j, (domain, task) in enumerate([
47 | ('ball_in_cup', 'catch'),
48 | ('cartpole', 'swingup'),
49 | ('cheetah', 'run'),
50 | ('finger', 'spin'),
51 | ('reacher', 'easy'),
52 | ('walker', 'walk')]):
53 |
54 | env = suite.load(
55 | domain, task, difficulty, background_dataset_path=FLAGS.davis_path)
56 |
57 | # Get the first frame.
58 | time_step = env.reset()
59 | frame = time_step.observation['pixels'][:, :, 0:3]
60 |
61 | # Save the first frame.
62 | try:
63 | os.mkdir(FLAGS.output_dir)
64 | except OSError:
65 | pass
66 | filepath = os.path.join(FLAGS.output_dir, f'{i:02d}-{j:02d}.jpg')
67 | image = PIL.Image.fromarray(frame)
68 | image.save(filepath)
69 | print(f'Saved results to {FLAGS.output_dir}')
70 |
71 |
72 | if __name__ == '__main__':
73 | app.run(main)
74 |
--------------------------------------------------------------------------------
/distracting_control/requirements.txt:
--------------------------------------------------------------------------------
1 | dm-control==0.0.322773188
2 | absl-py>=0.9.0
3 | numpy>=1.18.3
4 | mock>=4.0.0
5 | pillow>=7.0.0
6 |
--------------------------------------------------------------------------------
/distracting_control/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright 2021 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | set -e
17 | set -x
18 |
19 | virtualenv -p python3 .
20 | source ./bin/activate
21 |
22 | pip install tensorflow
23 | pip install -r distracting_control/requirements.txt
24 | python -m distracting_control.suite_test
25 | python -m distracting_control.camera_test
26 |
--------------------------------------------------------------------------------
/distracting_control/suite.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """A collection of MuJoCo-based Reinforcement Learning environments.
17 |
18 | The suite provides a similar API to the original dm_control suite.
19 | Users can configure the distractions on top of the original tasks. The suite is
20 | targeted for loading environments directly with similar configurations as those
21 | used in the original paper. Each distraction wrapper can be used independently
22 | though.
23 | """
24 | try:
25 | from dm_control import suite # pylint: disable=g-import-not-at-top
26 | from dm_control.suite.wrappers import pixels # pylint: disable=g-import-not-at-top
27 | except ImportError:
28 | suite = None
29 |
30 | from distracting_control import background
31 | from distracting_control import camera
32 | from distracting_control import color
33 | from distracting_control import suite_utils
34 |
35 |
36 | def is_available():
37 | return suite is not None
38 |
39 |
40 | def load(domain_name,
41 | task_name,
42 | difficulty=None,
43 | dynamic=False,
44 | background_dataset_path=None,
45 | background_dataset_videos="train",
46 | background_kwargs=None,
47 | camera_kwargs=None,
48 | color_kwargs=None,
49 | task_kwargs=None,
50 | environment_kwargs=None,
51 | visualize_reward=False,
52 | render_kwargs=None,
53 | pixels_only=True,
54 | pixels_observation_key="pixels",
55 | allow_color_distraction=True,
56 | allow_background_distraction=True,
57 | allow_camera_distraction=True,
58 | ):
59 | """Returns an environment from a domain name, task name and optional settings.
60 |
61 | ```python
62 | env = suite.load('cartpole', 'balance')
63 | ```
64 |
65 | Adding a difficulty will configure distractions matching the reference paper
66 | for easy, medium, hard.
67 |
68 | Users can also toggle dynamic properties for distractions.
69 |
70 | Args:
71 | domain_name: A string containing the name of a domain.
72 | task_name: A string containing the name of a task.
73 | task_kwargs: Optional `dict` of keyword arguments for the task.
74 | difficulty: Difficulty for the suite. One of 'easy', 'medium', 'hard'.
75 | dynamic: Boolean controlling whether distractions are dynamic or static.
76 | background_dataset_path: String to the davis directory that contains the
77 | video directories.
78 | background_dataset_videos: String ('train'/'val') or list of strings of the
79 | DAVIS videos to be used for backgrounds.
80 | background_kwargs: Dict, overwrites settings for background distractions.
81 | camera_kwargs: Dict, overwrites settings for camera distractions.
82 | color_kwargs: Dict, overwrites settings for color distractions.
83 | allow_color_distraction / allow_background_distraction / allow_camera_distraction: Booleans that enable each distraction type individually.
84 | environment_kwargs: Optional `dict` specifying keyword arguments for the
85 | environment.
86 | visualize_reward: Optional `bool`. If `True`, object colours in rendered
87 | frames are set to indicate the reward at each step. Default `False`.
88 | render_kwargs: Dict, render kwargs for pixel wrapper.
89 | pixels_only: Boolean controlling the exclusion of states in the observation.
90 | pixels_observation_key: Key in the observation used for the rendered image.
91 |
92 | Returns:
93 | The requested environment.
94 | """
95 | if not is_available():
96 | raise ImportError("dm_control module is not available. Make sure you "
97 | "follow the installation instructions from the "
98 | "dm_control package.")
99 |
100 | if difficulty not in [None, "easy", "medium", "hard", "pse", "hard_bg"]:
101 | raise ValueError("Difficulty should be one of: 'easy', 'medium', 'hard', 'pse', 'hard_bg'.")
102 |
103 | render_kwargs = render_kwargs or {}
104 | if "camera_id" not in render_kwargs:
105 | render_kwargs["camera_id"] = 2 if domain_name == "quadruped" else 0
106 |
107 | env = suite.load(
108 | domain_name,
109 | task_name,
110 | task_kwargs=task_kwargs,
111 | environment_kwargs=environment_kwargs,
112 | visualize_reward=visualize_reward)
113 |
114 | # Apply background distractions.
115 | if (difficulty or background_kwargs) and allow_background_distraction:
116 | background_dataset_path = (
117 | background_dataset_path or suite_utils.DEFAULT_BACKGROUND_PATH)
118 | final_background_kwargs = dict()
119 | if difficulty:
120 | # Get kwargs for the given difficulty.
121 | if background_dataset_videos in ['val', 'validation']:
122 | num_videos = suite_utils.DIFFICULTY_NUM_VIDEOS_VAL[difficulty]
123 | else:
124 | num_videos = suite_utils.DIFFICULTY_NUM_VIDEOS[difficulty]
125 | final_background_kwargs.update(
126 | suite_utils.get_background_kwargs(domain_name, num_videos, dynamic,
127 | background_dataset_path,
128 | background_dataset_videos))
129 | else:
130 | # Set the dataset path and the videos.
131 | final_background_kwargs.update(
132 | dict(
133 | dataset_path=background_dataset_path,
134 | dataset_videos=background_dataset_videos))
135 | if background_kwargs:
136 | # Overwrite kwargs with those passed here.
137 | final_background_kwargs.update(background_kwargs)
138 | env = background.DistractingBackgroundEnv(env, **final_background_kwargs)
139 |
140 | # Apply camera distractions.
141 | if (difficulty or camera_kwargs) and allow_camera_distraction:
142 | final_camera_kwargs = dict(camera_id=render_kwargs["camera_id"])
143 | if difficulty:
144 | # Get kwargs for the given difficulty.
145 | scale = suite_utils.DIFFICULTY_SCALE[difficulty]
146 | final_camera_kwargs.update(
147 | suite_utils.get_camera_kwargs(domain_name, scale, dynamic))
148 | if camera_kwargs:
149 | # Overwrite kwargs with those passed here.
150 | final_camera_kwargs.update(camera_kwargs)
151 | env = camera.DistractingCameraEnv(env, **final_camera_kwargs)
152 |
153 | # Apply color distractions.
154 | if (difficulty or color_kwargs) and allow_color_distraction:
155 | final_color_kwargs = dict()
156 | if difficulty:
157 | # Get kwargs for the given difficulty.
158 | scale = suite_utils.DIFFICULTY_SCALE[difficulty]
159 | final_color_kwargs.update(suite_utils.get_color_kwargs(scale, dynamic))
160 | if color_kwargs:
161 | # Overwrite kwargs with those passed here.
162 | final_color_kwargs.update(color_kwargs)
163 | env = color.DistractingColorEnv(env, **final_color_kwargs)
164 |
165 | # Apply Pixel wrapper after distractions. This is needed to ensure the
166 | # changes from the distraction wrapper are applied to the MuJoCo environment
167 | # before the rendering occurs.
168 | env = pixels.Wrapper(
169 | env,
170 | pixels_only=pixels_only,
171 | render_kwargs=render_kwargs,
172 | observation_key=pixels_observation_key)
173 |
174 | return env
175 |
--------------------------------------------------------------------------------
/distracting_control/suite_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for suite code."""
17 |
18 | from absl.testing import absltest
19 | from absl.testing import parameterized
20 | import mock
21 | from distracting_control import suite
22 |
23 | DAVIS_PATH = '/tmp/davis'
24 |
25 | class SuiteTest(parameterized.TestCase):
26 |
27 | @parameterized.named_parameters(('none', None),
28 | ('easy', 'easy'),
29 | ('medium', 'medium'),
30 | ('hard', 'hard'))
31 | @mock.patch.object(suite, 'pixels')
32 | @mock.patch.object(suite, 'suite')
33 | def test_suite_load_with_difficulty(self, difficulty, mock_dm_suite,
34 | mock_pixels):
35 | domain_name = 'cartpole'
36 | task_name = 'balance'
37 | suite.load(
38 | domain_name,
39 | task_name,
40 | difficulty,
41 | background_dataset_path=DAVIS_PATH)
42 |
43 | mock_dm_suite.load.assert_called_with(
44 | domain_name,
45 | task_name,
46 | environment_kwargs=None,
47 | task_kwargs=None,
48 | visualize_reward=False)
49 |
50 | mock_pixels.Wrapper.assert_called_with(
51 | mock.ANY,
52 | observation_key='pixels',
53 | pixels_only=True,
54 | render_kwargs={'camera_id': 0})
55 |
56 |
57 | if __name__ == '__main__':
58 | absltest.main()
59 |
--------------------------------------------------------------------------------
/distracting_control/suite_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """A collection of MuJoCo-based Reinforcement Learning environments.
17 |
18 | The suite provides a similar API to the original dm_control suite.
19 | Users can configure the distractions on top of the original tasks. The suite is
20 | targeted for loading environments directly with similar configurations as those
21 | used in the original paper. Each distraction wrapper can be used independently
22 | though.
23 | """
24 | import numpy as np
25 |
26 | DIFFICULTY_SCALE = dict(pse=0.0, easy=0.1, medium=0.2, hard=0.3, hard_bg=0.0)
27 | DIFFICULTY_NUM_VIDEOS = dict(pse=2, easy=4, medium=8, hard=None, hard_bg=None)
28 | DIFFICULTY_NUM_VIDEOS_VAL = dict(pse=None, easy=4, medium=8, hard=None, hard_bg=None)
29 | DEFAULT_BACKGROUND_PATH = "$HOME/davis/"
30 |
31 |
32 | def get_color_kwargs(scale, dynamic):
33 | max_delta = scale
34 | step_std = 0.03 * scale if dynamic else 0.0
35 | return dict(max_delta=max_delta, step_std=step_std)
36 |
37 |
38 | def get_camera_kwargs(domain_name, scale, dynamic):
39 | assert domain_name in ['reacher', 'cartpole', 'finger', 'cheetah',
40 | 'ball_in_cup', 'walker']
41 | assert scale >= 0.0
42 | assert scale <= 1.0
43 | return dict(
44 | vertical_delta=np.pi / 2 * scale,
45 | horizontal_delta=np.pi / 2 * scale,
46 | # Limit camera to -90 / 90 degree rolls.
47 | roll_delta=np.pi / 2. * scale,
48 | vel_std=.1 * scale if dynamic else 0.,
49 | max_vel=.4 * scale if dynamic else 0.,
50 | roll_std=np.pi / 300 * scale if dynamic else 0.,
51 | max_roll_vel=np.pi / 50 * scale if dynamic else 0.,
52 | max_zoom_in_percent=.5 * scale,
53 | max_zoom_out_percent=1.5 * scale,
54 | limit_to_upper_quadrant='reacher' not in domain_name,
55 | )
56 |
57 |
58 | def get_background_kwargs(domain_name,
59 | num_videos,
60 | dynamic,
61 | dataset_path,
62 | dataset_videos=None,
63 | shuffle=False,
64 | video_alpha=1.0):
65 | assert domain_name in ['reacher', 'cartpole', 'finger', 'cheetah',
66 | 'ball_in_cup', 'walker']
67 | if domain_name == 'reacher':
68 | ground_plane_alpha = 0.0
69 | elif domain_name in ['walker', 'cheetah']:
70 | ground_plane_alpha = 1.0
71 | else:
72 | ground_plane_alpha = 0.3
73 |
74 | return dict(
75 | num_videos=num_videos,
76 | video_alpha=video_alpha,
77 | ground_plane_alpha=ground_plane_alpha,
78 | dynamic=dynamic,
79 | dataset_path=dataset_path,
80 | dataset_videos=dataset_videos,
81 | shuffle_buffer_size=100 if shuffle else None,
82 | )
83 |
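84 |
85 | if __name__ == '__main__':
86 |   # Illustrative check (editor addition): 'medium' maps to scale 0.2 and 8
87 |   # background videos, so dynamic color distractions use max_delta=0.2 and
88 |   # step_std=0.03 * 0.2 = 0.006.
89 |   scale = DIFFICULTY_SCALE['medium']
90 |   print(scale, DIFFICULTY_NUM_VIDEOS['medium'])
91 |   print(get_color_kwargs(scale, dynamic=True))
92 |   print(get_camera_kwargs('cheetah', scale, dynamic=True))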
--------------------------------------------------------------------------------
/environment_container_dcs.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved.
4 | #
5 |
6 | """ Wraps Distracting Control Suite in a Gym-like wrapper."""
7 | import numpy as np
8 | from distracting_control import suite as distracting_suite
9 | from collections import deque
10 | from PIL import Image
11 | import argparse
12 | import yaml
13 | import matplotlib.pyplot as plt
14 | plt.ion()
15 |
16 |
17 | class EnvironmentContainerDCS(object):
18 | """
19 | Wrapper around DCS.
20 | """
21 | def __init__(self, config, train=True, seed=None):
22 | self.domain = config['domain']
23 |
24 | self.get_other_obs = False
25 |
26 | # The standard task and action_repeat for each domain.
27 | task_info = {
28 | 'ball_in_cup': ('catch', 4),
29 | 'cartpole': ('swingup', 8),
30 | 'cheetah': ('run', 4),
31 | 'finger': ('spin', 2),
32 | 'reacher': ('easy', 4),
33 | 'walker': ('walk', 2),
34 | }
35 | self.task, self.action_repeat = task_info[self.domain]
36 |
37 | self.difficulty = config['difficulty']
38 | if self.difficulty in ['none', 'None']:
39 | self.difficulty = None
40 |
41 | self.dynamic = config['dynamic']
42 | self.background_dataset_path = config.get('background_dataset_path', 'DAVIS/JPEGImages/480p')
43 | allow_color_distraction = config.get('allow_color_distraction', True)
44 | allow_background_distraction = config.get('allow_background_distraction', True)
45 | allow_camera_distraction = config.get('allow_camera_distraction', True)
46 | if seed is not None:
47 | task_kwargs = {'random': seed}
48 | else:
49 | task_kwargs = None
50 | self.env = distracting_suite.load(
51 | self.domain, self.task, difficulty=self.difficulty, dynamic=self.dynamic,
52 | background_dataset_path=self.background_dataset_path,
53 | background_dataset_videos='train' if train else 'val',
54 | pixels_only=False,
55 | task_kwargs=task_kwargs,
56 | allow_color_distraction=allow_color_distraction,
57 | allow_camera_distraction=allow_camera_distraction,
58 | allow_background_distraction=allow_background_distraction)
59 |
60 | action_spec = self.env.action_spec()
61 | self.action_dims = len(action_spec.minimum)
62 | self.action_low = action_spec.minimum
63 | self.action_high = action_spec.maximum
64 | self.num_frames_to_stack = config.get('num_frames_to_stack', 1)
65 | if self.num_frames_to_stack > 1:
66 | self.frame_queue = deque([], maxlen=self.num_frames_to_stack)
67 | self.config = config
68 | self.image_height, self.image_width = self.config['image_height'], self.config['image_width']
69 | self.num_channels = 3 * self.num_frames_to_stack
70 | self.other_dims = 0
71 |
72 | def get_action_dims(self):
73 | return self.action_dims
74 |
75 | def get_action_repeat(self):
76 | return self.action_repeat
77 |
78 | def get_action_limits(self):
79 | return self.action_low, self.action_high
80 |
81 | def get_obs_chw(self):
82 | return self.num_channels, self.image_height, self.image_width
83 |
84 | def get_obs_other_dims(self):
85 | return self.other_dims
86 |
87 | def reset(self):
88 | time_step = self.env.reset()
89 | if self.num_frames_to_stack > 1:
90 | self.frame_queue.clear()
91 | obs = self._get_image(time_step) # C, H, W.
92 | obs_dict = {'image': obs}
93 | return obs_dict
94 |
95 | def step(self, action):
96 | reward = 0
97 | for _ in range(self.action_repeat):
98 | time_step = self.env.step(action)
99 | reward += time_step.reward
100 | obs = self._get_image(time_step)
101 | done = False
102 | info = {}
103 | obs_dict = {'image': obs}
104 | return obs_dict, reward, done, info
105 |
106 | def _get_image(self, time_step):
107 | image_height, image_width = self.config['image_height'], self.config['image_width']
108 | obs = time_step.observation['pixels'][:, :, 0:3] # (240, 320, 3).
109 | # Resize to image_height, image_width
110 | obs = Image.fromarray(obs).resize((image_width, image_height), resample=Image.BILINEAR)
111 | #obs = cv2.resize(obs, dsize=(image_width, image_height))
112 | obs = np.asarray(obs)
113 | obs = obs.transpose((2, 0, 1)).copy() # (C, H, W)
114 | obs = self._stack_images(obs)
115 | return obs
116 |
117 | def _stack_images(self, obs):
118 | if self.num_frames_to_stack > 1:
119 | if len(self.frame_queue) == 0: # Just after reset.
120 | for _ in range(self.num_frames_to_stack):
121 | self.frame_queue.append(obs)
122 | else:
123 | self.frame_queue.append(obs)
124 | obs = np.concatenate(list(self.frame_queue), axis=0)
125 | return obs
126 |
127 |
128 | class EnvironmentContainerDCS_DMC_paired(EnvironmentContainerDCS):
129 | def __init__(self, config, train=True, seed=1):
130 | super().__init__(config, train=train, seed=seed)
131 | config_dmc = config.copy()
132 | config_dmc['difficulty'] = 'none'
133 | self.dmc = EnvironmentContainerDCS(config_dmc, train=train, seed=seed)
134 |
135 | def reset(self):
136 | obs_dict_dmc = self.dmc.reset()
137 | obs_dict = super().reset()
138 | obs_dict['image_clean'] = obs_dict_dmc['image']
139 | return obs_dict
140 |
141 | def step(self, action):
142 | obs_dict_dmc, _, _, _ = self.dmc.step(action)
143 | obs_dict, reward, done, info = super().step(action)
144 | obs_dict['image_clean'] = obs_dict_dmc['image']
145 | return obs_dict, reward, done, info
146 |
147 |
148 | def argument_parser(argument):
149 | """ Argument parser """
150 | parser = argparse.ArgumentParser(description='Binder Network.')
151 | parser.add_argument('-c', '--config', default='', type=str, help='Training config')
152 | args = parser.parse_args(argument)
153 | return args
154 |
155 |
156 | def test():
157 | args = argument_parser(None)
158 | try:
159 | with open(args.config) as f:
160 | config = yaml.safe_load(f)
161 | except FileNotFoundError:
162 | print("Error opening specified config yaml at: {}. "
163 | "Please check filepath and try again.".format(args.config))
164 | raise  # Re-raise: config is undefined past this point.
165 | config = config['parameters']
166 | seed = config['seed']
167 | np.random.seed(seed)
168 | env = EnvironmentContainerDCS_DMC_paired(config['env'], train=True, seed=config['seed'])
169 | plt.figure(1)
170 | action_low, action_high = env.get_action_limits()
171 | action_dims = env.get_action_dims()
172 | for _ in range(1):
173 | env.reset()
174 | for _ in range(1):
175 | action = np.random.uniform(action_low, action_high, action_dims)
176 | obs_dict, reward, done, info = env.step(action)
177 | obs_dcs = obs_dict['image'].transpose((1, 2, 0))
178 | obs_dmc = obs_dict['image_clean'].transpose((1, 2, 0))
179 | plt.clf()
180 | obs = np.concatenate([obs_dcs, obs_dmc], axis=1)
181 | plt.imshow(obs)
182 | plt.pause(0.001)
183 | filename = '/Users/nitish/Desktop/binder_figures/sample_0.png'
184 | plt.savefig(filename)
185 |
186 |
187 | if __name__ == '__main__':
188 | test()
189 |
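190 | # Editor note: a hypothetical config dict for EnvironmentContainerDCS (the
191 | # repo's configs/dcs/core.yaml is the authoritative source of these keys):
192 | #   env:
193 | #     domain: cheetah
194 | #     difficulty: medium
195 | #     dynamic: true
196 | #     background_dataset_path: DAVIS/JPEGImages/480p
197 | #     image_height: 84
198 | #     image_width: 84
199 | #     num_frames_to_stack: 3
200 | # With 3 stacked RGB frames, reset()/step() return obs_dict['image'] with
201 | # shape (9, 84, 84).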
--------------------------------------------------------------------------------
/environment_container_robosuite.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved.
4 | #
5 | import argparse
6 | import os
7 | import numpy as np
8 | import yaml
9 | from collections import deque
10 | import robosuite as suite
11 | from robosuite.controllers import load_controller_config, ALL_CONTROLLERS
12 | from mujoco_py import MujocoException
13 | from robosuite.wrappers.domain_randomization_wrapper import DomainRandomizationWrapper, DEFAULT_CAMERA_ARGS,\
14 | DEFAULT_COLOR_ARGS, DEFAULT_LIGHTING_ARGS
15 | import matplotlib.pyplot as plt
16 | plt.ion()
17 |
18 |
19 | def dict_merge(default, user):
20 | d = default.copy()
21 | d.update(user)
22 | return d
23 |
24 |
25 | def load_robosuite_controller_config(controller):
26 | if controller in set(ALL_CONTROLLERS):
27 | # This is a default controller
28 | controller_config = load_controller_config(default_controller=controller)
29 | else:
30 | # This is a string to the custom controller
31 | controller_config = load_controller_config(custom_fpath=controller)
32 | return controller_config
33 |
34 |
35 | class EnvironmentContainerRobosuite(object):
36 | def __init__(self, config, train=True, seed=None):
37 | super().__init__()
38 | self.config = config
39 | robosuite_config = config['robosuite_config']
40 | self.crop_image = config.get('crop_image', False)
41 | if self.crop_image:
42 | self.image_height = config['crop_height']
43 | self.image_width = config['crop_width']
44 | self.crop_center_xy = config['crop_center_xy']
45 | cx, cy = self.crop_center_xy
46 | self.crop_left = cx - self.image_width // 2
47 | self.crop_top = cy - self.image_height // 2
48 | else:
49 | self.image_height = robosuite_config['camera_heights']
50 | self.image_width = robosuite_config['camera_widths']
51 |
52 | self.dr = config.get('domain_randomize', False)
53 | dr_config = config.get('domain_randomization_config', None)
54 | controller = config['controller']
55 | robosuite_config['controller_configs'] = load_robosuite_controller_config(controller)
56 |
57 | self.env = suite.make(**robosuite_config)
58 | self.robosuite_config = robosuite_config
59 | self.env_name = self.robosuite_config['env_name']
60 |
61 | self.image_key = config['image_key']
62 | if 'other_key' in config:
63 | self.other_key = config['other_key']
64 | self.get_other_obs = True
65 | ob_dict = self.env.reset()
66 | assert self.other_key in ob_dict
67 | self.other_dims = len(ob_dict[self.other_key])
68 | else:
69 | self.get_other_obs = False
70 | self.other_dims = 0
71 |
72 | low, high = self.env.action_spec
73 | self.action_dims = len(low)
74 | self.action_repeat = 1
75 |
76 | if self.dr:
77 | dr_config['color_randomization_args'] = dict_merge(DEFAULT_COLOR_ARGS,
78 | dr_config.get('color_randomization_args', {}))
79 | dr_config['camera_randomization_args'] = dict_merge(DEFAULT_CAMERA_ARGS,
80 | dr_config.get('camera_randomization_args', {}))
81 | dr_config['lighting_randomization_args'] = dict_merge(DEFAULT_LIGHTING_ARGS,
82 | dr_config.get('lighting_randomization_args', {}))
83 | self.env = DomainRandomizationWrapper(self.env, seed=seed, **dr_config)
84 |
85 | self.num_frames_to_stack = config.get('num_frames_to_stack', 1)
86 | if self.num_frames_to_stack > 1:
87 | self.frame_queue = deque([], maxlen=self.num_frames_to_stack)
88 | self.num_channels = 3 * self.num_frames_to_stack
89 | self.other_dims = self.other_dims * self.num_frames_to_stack
90 | self.ob_dict = None # To hold the last raw observation.
91 |
92 | def get_action_dims(self):
93 | return self.action_dims
94 |
95 | def get_action_repeat(self):
96 | return self.action_repeat
97 |
98 | def get_action_limits(self):
99 | low, high = self.env.action_spec
100 | return low, high
101 |
102 | def get_obs_chw(self):
103 | return self.num_channels, self.image_height, self.image_width
104 |
105 | def get_obs_other_dims(self):
106 | return self.other_dims
107 |
108 | def preprocess_image(self, img):
109 | # If this is an RGB image, make it C, H, W.
110 | # This is done here rather than just before fprop so that
111 | # if/when depth is added, we can concatenate it as a channel.
112 | if len(img.shape) == 3: # RGB image
113 | if self.crop_image:
114 | x, y = self.crop_left, self.crop_top
115 | img = img[y:y + self.image_height, x:x + self.image_width, :]
116 | elif len(img.shape) == 2: # Depth image
117 | if self.crop_image:
118 | x, y = self.crop_left, self.crop_top
119 | img = img[y:y + self.image_height, x:x + self.image_width]
120 | if len(img.shape) >= 2:
121 | img = img[::-1, ...].copy() # The frontview image is upside-down.
122 | img = img.transpose(2, 0, 1) # CHW
123 | return img
124 |
125 | def _get_obs(self, obs_dict, verbose=False):
126 | assert self.image_key in obs_dict, "key {} not found in obs".format(self.image_key)
127 | img = self.preprocess_image(obs_dict[self.image_key])
128 | res = dict(image=img)
129 | if self.get_other_obs:
130 | assert self.other_key in obs_dict, "key {} not found in obs".format(self.other_key)
131 | res['other'] = obs_dict[self.other_key]
132 | if self.num_frames_to_stack > 1:
133 | res = self._get_stacked_obs(res)
134 | return res
135 |
136 | def _get_stacked_obs(self, obs):
137 | if len(self.frame_queue) == 0:
138 | for _ in range(self.num_frames_to_stack):
139 | self.frame_queue.append(obs)
140 | else:
141 | self.frame_queue.append(obs)
142 | keys = obs.keys()
143 | res = {}
144 | for key in keys:
145 | res[key] = np.concatenate([frame[key] for frame in self.frame_queue])
146 | return res
147 |
148 | def reset(self):
149 | if self.num_frames_to_stack > 1:
150 | self.frame_queue.clear()
151 | self.ob_dict = self.env.reset()
152 | # Fix for a reset bug: the domain is not randomized for the obs coming from reset(), so take one zero-action step first.
153 | action = np.zeros(self.action_dims)
154 | self.ob_dict, _, _, _ = self.env.step(action)
155 | obs = self._get_obs(self.ob_dict)
156 | return obs
157 |
158 | def step(self, action):
159 | try:
160 | ob_dict, reward, done, info = self.env.step(action)
161 | except MujocoException as e:
162 | print('MujocoException', e)
163 | print('Will skip this action')
164 | if self.ob_dict is not None:
165 | ob_dict = self.ob_dict
166 | reward = 0
167 | done = False
168 | info = {}
169 | self.ob_dict = ob_dict
170 | # Additional reward shaping.
171 | #if self.env_name == 'Door':
172 | # # Additional reward shaping for door angle.
173 | # if reward < 1:
174 | # hinge_qpos = self.env.sim.data.qpos[self.env.hinge_qpos_addr]
175 | # reward += np.clip(0.5 * hinge_qpos / 0.3, 0, 0.5)
176 | obs = self._get_obs(ob_dict)
177 | return obs, reward, done, info
178 |
179 | def render(self, mode, **kwargs):
180 | if mode == 'rgb_array':
181 | image_list = []
182 | for (cam_name, cam_w, cam_h, cam_d) in \
183 | zip(self.env.camera_names, self.env.camera_widths, self.env.camera_heights, self.env.camera_depths):
184 |
185 | # Add camera observations to the dict
186 | camera_obs = self.env.sim.render(
187 | camera_name=cam_name,
188 | width=cam_w,
189 | height=cam_h,
190 | depth=cam_d
191 | )
192 | if cam_d:
193 | img, depth = camera_obs
194 | camera_obs = np.concatenate([img, depth[:, :, None]], axis=2)
195 | image_list.append(camera_obs)
196 | image = np.concatenate(image_list, axis=1)
197 | return image # return RGB frame suitable for video
198 | elif mode == 'human':
199 | self.env.render() # pop up a window and render
200 | else:
201 | raise NotImplementedError
202 |
203 | def argument_parser(argument):
204 | """ Argument parser """
205 | parser = argparse.ArgumentParser(description='Binder Network.')
206 | parser.add_argument('-c', '--config', default='', type=str, help='Training config')
207 | args = parser.parse_args(argument)
208 | return args
209 |
210 |
211 |
212 | def test2():
213 | args = argument_parser(None)
214 | try:
215 | with open(args.config) as f:
216 | config = yaml.safe_load(f)
217 | except FileNotFoundError:
218 | print("Error opening specified config yaml at: {}. "
219 | "Please check filepath and try again.".format(args.config))
220 | raise  # Re-raise: config is undefined past this point.
221 | config = config['parameters']
222 | seed = config['seed']
223 | np.random.seed(seed)
224 | env = EnvironmentContainerRobosuite(config['env'])
225 | obs_dict = env.reset()
226 | action_low, action_high = env.get_action_limits()
227 | action_dims = env.get_action_dims()
228 | plt.figure(1)
229 | obs_list = []
230 | for ii in range(2):
231 | obs = obs_dict['image'].transpose((1, 2, 0))
232 | obs_list.append(obs)
233 | plt.clf()
234 | plt.imshow(obs)
235 | plt.suptitle('Image {}'.format(ii))
236 | plt.pause(0.5)
237 | action = np.random.uniform(action_low, action_high, action_dims)
238 | obs_dict, reward, done, info = env.step(action)
239 | obs_list.append(np.abs(obs_list[-1] - obs_list[-2]))
240 | obs = np.concatenate(obs_list, axis=1)
241 | plt.imshow(obs)
242 | plt.axis('off')
243 | plt.show()
244 | input('Press enter')
245 |
246 |
247 | def test():
248 | args = argument_parser(None)
249 | try:
250 | with open(args.config) as f:
251 | config = yaml.safe_load(f)
252 | except FileNotFoundError:
253 | print("Error opening specified config yaml at: {}. "
254 | "Please check filepath and try again.".format(args.config))
255 | raise  # Re-raise: config is undefined past this point.
256 | config = config['parameters']
257 | seed = config['seed']
258 | np.random.seed(seed)
259 | plt.figure(1)
260 | randomize_settings = [(False, False), (False, True), (True, True)]
261 | obs_list = []
262 | for robot in ['Panda', 'Jaco']:
263 | for randomize_camera, randomize_other in randomize_settings:
264 | config['env']['domain_randomize'] = True
265 | config['env']['robosuite_config']['robots'] = [robot]
266 | config['env']['domain_randomization_config']['randomize_camera'] = randomize_camera
267 | config['env']['domain_randomization_config']['randomize_color'] = randomize_other
268 | config['env']['domain_randomization_config']['randomize_lighting'] = randomize_other
269 | env = EnvironmentContainerRobosuite(config['env'], seed=seed)
270 | env.reset()
271 | action_low, action_high = env.get_action_limits()
272 | action_dims = env.get_action_dims()
273 | action = np.random.uniform(action_low, action_high, action_dims)
274 | obs_dict, reward, done, info = env.step(action)
275 | obs = obs_dict['image'].transpose((1, 2, 0))
276 | plt.clf()
277 | plt.imshow(obs)
278 | plt.pause(0.001)
279 | obs_list.append(obs)
280 | obs1 = np.concatenate(obs_list[:3], axis=0)
281 | obs2 = np.concatenate(obs_list[3:], axis=0)
282 | obs = np.concatenate([obs1, obs2], axis=1)
283 | plt.draw()
284 | plt.imshow(obs)
285 | plt.axis('off')
286 | plt.show()
287 | plt.savefig('/Users/nitish/Desktop/binder_figures/robosuite_supp_new.png', bbox_inches='tight')
288 |
289 | if __name__ == '__main__':
290 | test()
291 |
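292 | # Editor note: a hypothetical summary of the config keys read above (see
293 | # configs/robosuite/core.yaml for the actual settings):
294 | #   robosuite_config: kwargs forwarded to robosuite.make, e.g. env_name,
295 | #     robots, use_camera_obs, camera_names, camera_heights, camera_widths.
296 | #   controller: a default robosuite controller name (e.g. 'OSC_POSE') or a
297 | #     path to a custom controller config file.
298 | #   image_key / other_key: observation keys for the image and, optionally,
299 | #     the proprioceptive state.
300 | #   crop_image (+ crop_height, crop_width, crop_center_xy): optional cropping.
301 | #   domain_randomize + domain_randomization_config: wraps the env in
302 | #     robosuite's DomainRandomizationWrapper.
303 | #   num_frames_to_stack: number of frames concatenated along channels.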
--------------------------------------------------------------------------------
/environments.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved.
4 | #
5 | """ Interface to all environments. """
6 |
7 | def make_environment(config, train=True, seed=None):
8 | env_name = config.get('name', 'dcs')
9 | if env_name == 'dcs':
10 | from environment_container_dcs import EnvironmentContainerDCS
11 | return EnvironmentContainerDCS(config, train=train, seed=seed)
12 | elif env_name == 'robosuite':
13 | from environment_container_robosuite import EnvironmentContainerRobosuite
14 | return EnvironmentContainerRobosuite(config, train=train, seed=seed)
15 | elif env_name == 'dcs_dmc_paired':
16 | from environment_container_dcs import EnvironmentContainerDCS_DMC_paired
17 | return EnvironmentContainerDCS_DMC_paired(config, train=train, seed=seed)
18 | else:
19 | raise ValueError
20 |
--------------------------------------------------------------------------------
/modules.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved.
4 | #
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import random
9 | import numpy as np
10 | from torch import distributions as pyd
11 | import math
12 |
13 |
14 | def weight_init(m):
15 | """Custom weight init for Conv2D and Linear layers."""
16 | if isinstance(m, nn.Linear):
17 | nn.init.orthogonal_(m.weight.data)
18 | if m.bias is not None:
19 | m.bias.data.fill_(0.0)
20 | elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
21 | # delta-orthogonal init from https://arxiv.org/pdf/1806.05393.pdf
22 | assert m.weight.size(2) == m.weight.size(3)
23 | m.weight.data.fill_(0.0)
24 | if m.bias is not None:
25 | m.bias.data.fill_(0.0)
26 | mid = m.weight.size(2) // 2
27 | gain = nn.init.calculate_gain('relu')
28 | nn.init.orthogonal_(m.weight.data[:, :, mid, mid], gain)
29 |
30 |
31 | def act_func(act_str):
32 | act = None
33 | if act_str == 'relu':
34 | act = nn.ReLU
35 | elif act_str == 'elu':
36 | act = nn.ELU
37 | elif act_str == 'identity':
38 | act = nn.Identity
39 | elif act_str == 'tanh':
40 | act = nn.Tanh
41 | elif act_str == 'sigmoid':
42 | act = nn.Sigmoid
43 | else:
44 | raise ValueError('Unknown activation function.')
45 | return act
46 |
47 |
48 | class GRUCell(nn.Module):
49 | """
50 | GRUCell with optional layer norm.
51 | """
52 | def __init__(self, input_dims, hidden_dims, norm=False, update_bias=-1):
53 | super().__init__()
54 | self.input_dims = input_dims
55 | self.hidden_dims = hidden_dims
56 | self.linear = nn.Linear(input_dims + hidden_dims, 3 * hidden_dims, bias=not norm)
57 | self.norm = norm
58 | self.update_bias = update_bias
59 | if norm:
60 | self.norm_layer = nn.LayerNorm(3 * hidden_dims)
61 |
62 | def forward(self, inputs, state):
63 | gate_inputs = self.linear(torch.cat([inputs, state], dim=-1))
64 | if self.norm:
65 | gate_inputs = self.norm_layer(gate_inputs)
66 | reset, cand, update = gate_inputs.chunk(3, -1)
67 | reset = torch.sigmoid(reset)
68 | cand = torch.tanh(reset * cand)
69 | update = torch.sigmoid(update + self.update_bias)
70 | output = update * cand + (1 - update) * state
71 | return output
72 |
73 |
74 | class FCNet(nn.Module):
75 | """ MLP with fully-connected layers."""
76 | def __init__(self, config, in_features, out_features=None):
77 | super().__init__()
78 | layers = []
79 | fc_act = act_func(config['fc_activation'])
80 | self.input_dims = in_features
81 | num_hids = list(config['fc_hiddens'])  # Copy so the shared config list is not mutated below.
82 | if out_features:
83 | num_hids.append(out_features)
84 | for i, num_hid in enumerate(num_hids):
85 | fc_layer = nn.Linear(in_features=in_features, out_features=num_hid)
86 | layers.append(fc_layer)
87 | if i < len(num_hids) - 1:
88 | layers.append(fc_act())
89 | if config.get('fc_batch_norm', False):
90 | layers.append(nn.BatchNorm1d(num_hid))
91 | else:
92 | output_act = act_func(config['output_activation'])
93 | layers.append(output_act())
94 | in_features = num_hid
95 | if config.get('layer_norm_output', False):
96 | layers.append(nn.LayerNorm([in_features], elementwise_affine=config.get('layer_norm_affine', False)))
97 | self.fc_net = nn.Sequential(*layers)
98 | self.output_dims = in_features
99 |
100 | def forward(self, x):
101 | assert x.shape[-1] == self.input_dims, "Last dim is {} but should be {}".format(x.shape[-1], self.input_dims)
102 | orig_shape = list(x.shape)
103 | x = x.view(-1, self.input_dims)
104 | x = self.fc_net(x)
105 | orig_shape[-1] = self.output_dims
106 | x = x.view(*orig_shape)
107 | return x
108 |
109 |
110 | class CNN(nn.Module):
111 | """ CNN, optionally followed by fully connected layers."""
112 | def __init__(self, config, chw):
113 | super().__init__()
114 | self.config = config
115 | channels, height, width = chw
116 | cnn_layers, channels, height, width = self.make_conv_layers((channels, height, width), config['conv_filters'])
117 | self.conv_net = nn.Sequential(*cnn_layers)
118 | num_features = channels * height * width
119 | if 'fc_hiddens' in config:
120 | fc_layers, num_features = self.make_fc_layers(num_features)
121 | self.fc_net = nn.Sequential(*fc_layers)
122 | else:
123 | self.fc_net = None
124 | self.output_dims = num_features
125 |
126 | def make_conv_layers(self, input_chw, conv_filters):
127 | channels, height, width = input_chw
128 | conv_act = act_func(self.config['conv_activation'])
129 | base_num_hid = self.config['base_num_hid']
130 | layers = []
131 | for i, filter_spec in enumerate(conv_filters):
132 | if len(filter_spec) == 3: # Padding is not specified.
133 | num_filters, kernel_size, stride = filter_spec
134 | padding = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2)
135 | elif len(filter_spec) == 4:
136 | num_filters, kernel_size, stride, padding = filter_spec
137 | num_filters = num_filters * base_num_hid
138 | conv_layer = nn.Conv2d(channels, out_channels=num_filters, kernel_size=kernel_size, stride=stride,
139 | padding=padding)
140 | height = (height - kernel_size[0] + 2 * padding[0]) // stride + 1
141 | width = (width - kernel_size[1] + 2 * padding[1]) // stride + 1
142 | layers.append(conv_layer)
143 | layers.append(conv_act())
144 | if self.config['conv_batch_norm']:
145 | bn = nn.BatchNorm2d(num_filters)
146 | layers.append(bn)
147 | channels = num_filters
148 | return layers, channels, height, width
149 |
150 | def make_fc_layers(self, in_features):
151 | fc_act = act_func(self.config['fc_activation'])
152 | base_num_hid = self.config['base_num_hid']
153 | layers = []
154 | for i, num_hid in enumerate(self.config['fc_hiddens']):
155 | num_hid = num_hid * base_num_hid
156 | fc_layer = nn.Linear(in_features=in_features, out_features=num_hid)
157 | layers.append(fc_layer)
158 | if i < len(self.config['fc_hiddens']) - 1:
159 | layers.append(fc_act())
160 | if self.config.get('fc_batch_norm', False):
161 | layers.append(nn.BatchNorm1d(num_hid))
162 | else:
163 | if self.config.get('layer_norm_output', False):
164 | layers.append(nn.LayerNorm([num_hid], elementwise_affine=False))
165 | output_act = act_func(self.config['output_activation'])
166 | layers.append(output_act())
167 | in_features = num_hid
168 | return layers, in_features
169 |
170 | def forward(self, x):
171 | """
172 | Args:
173 | x : (..., C, H, W)
174 | Output:
175 | x: (..., D)
176 | """
177 | orig_shape = list(x.size())
178 | x = x.view(-1, *orig_shape[-3:])
179 | x = self.conv_net(x)
180 | if self.fc_net is not None:
181 | x = x.view(x.size(0), -1) # Flatten.
182 | x = self.fc_net(x)
183 | x = x.view(*orig_shape[:-3], x.shape[-1])
184 | else:
185 | x = x.view(*orig_shape[:-3], *x.shape[-3:])
186 | return x
187 |
188 |
189 | class TransposeCNN(nn.Module):
190 | """Decode images from a vector.
191 | """
192 |
193 | def __init__(self, config, input_size):
194 | """Initializes a ConvDecoder instance.
195 |
196 | Args:
197 | input_size (int): Input size, usually feature size output from
198 | RSSM.
199 | depth (int): Number of channels in the first conv layer
200 | act (Any): Activation for Encoder, default ReLU
201 | shape (List): Shape of observation input
202 | """
203 | super().__init__()
204 |
205 | layers = []
206 | in_features = input_size
207 | fc_act = act_func(config['fc_activation'])
208 | base_num_hid = config['base_num_hid']
209 | for i, num_hid in enumerate(config['fc_hiddens']):
210 | num_hid = num_hid * base_num_hid
211 | fc_layer = nn.Linear(in_features=in_features, out_features=num_hid)
212 | layers.append(fc_layer)
213 | layers.append(fc_act())
214 | if config.get('fc_batch_norm', False):
215 | layers.append(nn.BatchNorm1d(num_hid))
216 | in_features = num_hid
217 | self.fc_net = nn.Sequential(*layers)
218 |
219 | filters = config['conv_filters']
220 | act = act_func(config['conv_activation'])
221 | output_act = act_func(config['output_activation'])
222 | bn = config.get('conv_batch_norm', False)
223 | in_channels = in_features
224 | layers = []
225 | in_size = [1, 1]
226 | for i, (out_channels, kernel, stride) in enumerate(filters):
227 | if i < len(filters) - 1:
228 | out_channels = out_channels * base_num_hid
229 | layer = nn.ConvTranspose2d(in_channels, out_channels, kernel, stride=stride)
230 | layers.append(layer)
231 | out_size = [kernel[0] + (in_size[0] - 1) * stride, kernel[1] + (in_size[1] - 1) * stride]
232 | if i < len(filters) - 1: # Don't put batch norm in the last layer, potentially different activation func.
233 | if act is not None:
234 | layers.append(act())
235 | if bn:
236 | layers.append(nn.BatchNorm2d(out_channels))
237 | else:
238 | layers.append(output_act())
239 | in_channels = out_channels
240 | in_size = out_size
241 |
242 | self.conv_transpose = nn.Sequential(*layers)
243 | self.out_size = in_size
244 | self.out_channels = in_channels
245 |
246 | def forward(self, x):
247 | """
248 | Args:
249 | x: (..., D)
250 | Output:
251 | x : (..., C, H, W)
252 | """
253 | orig_shape = list(x.size())
254 | x = x.view(-1, orig_shape[-1])
255 | x = self.fc_net(x)
256 | C = x.shape[-1]
257 | x = x.view(-1, C, 1, 1)
258 | x = self.conv_transpose(x)
259 | out_shape = list(x.size())
260 | res_shape = orig_shape[:-1] + out_shape[1:]
261 | x = x.view(*res_shape)
262 | return x
263 |
264 |
265 | class ObservationModel(nn.Module):
266 | """
267 | Module that encapsulates the observation encoder (and optionally, decoder).
268 | """
269 | def __init__(self, config, image_chw, other_dims=0):
270 | """
271 | Inputs:
272 | image_chw: (channels, height, width) for the image observations.
273 | other_dims : int, any extra dimensions, e.g. proprioceptive state.
274 | """
275 | super().__init__()
276 |
277 | # Gating network.
278 | self.use_gating_network = config.get('use_gating_network', False)
279 | if self.use_gating_network:
280 | self.gating_net = CNN(config['encoder_gating'], image_chw)
281 | self.gating_net_channels = image_chw[0]
282 |
283 | # Image encoder.
284 | self.encoder = CNN(config['encoder'], image_chw)
285 | self.output_dims = self.encoder.output_dims
286 |
287 | # (Optional) Proprioceptive state encoder.
288 | if 'other_obs_model' in config:
289 | print('Proprioceptive state has {} dims'.format(other_dims))
290 | self.other_model = FCNet(config['other_obs_model'], other_dims)
291 | self.output_dims += self.other_model.output_dims
292 | else:
293 | self.other_model = None
294 |
295 | if 'decoder' in config:
296 | # Image predictor P(x_t | z_t, h_t)
297 | self.decoder = TransposeCNN(config['decoder'], self.output_dims)
298 | else:
299 | self.decoder = None
300 |
301 | self.apply(weight_init)
302 |
303 | def encode_pixels(self, x, encoder=None):
304 | if self.use_gating_network:
305 | orig_shape = list(x.shape)
306 | input_channels = orig_shape[-3]
307 | num_frames = input_channels // self.gating_net_channels
308 | # 9 input_channels = 3 frames * 3 (RGB) channels.
309 | # Gating net is shared across stacked frames.
310 | x = x.view(*orig_shape[:-3], num_frames, self.gating_net_channels, *orig_shape[-2:])
311 | g = torch.sigmoid(self.gating_net(x))
312 | assert g.shape[-3] == 1 # Gating should give one scalar per location.
313 | x = x * g
314 | x = x.view(*orig_shape)
315 | else:
316 | g = None
317 |
318 | if encoder is None:
319 | x = self.encoder(x)
320 | else:
321 | x = encoder(x)
322 | return x, g
323 |
324 | def encode(self, batch):
325 | """
326 | Encode the observations.
327 | Inputs:
328 | batch: dict containing 'obs_image' and optionally, other observations.
329 | Returns:
330 | outputs: dict containing 'obs_encoding', and optionally 'obs_gating'.
331 | """
332 | outputs = {}
333 | x, g = self.encode_pixels(batch['obs_image'])
334 | if self.use_gating_network:
335 | outputs['obs_gating'] = g
336 |
337 | if self.other_model is not None:
338 | other_obs = batch['obs_other']
339 | other_encoding = self.other_model(other_obs)
340 | x = torch.cat([x, other_encoding], dim=-1)
341 | outputs['obs_features'] = x
342 | return outputs
343 |
344 | def decode(self, x):
345 | assert self.decoder is not None
346 | return self.decoder(x)
347 |
348 | def forward(self, x):
349 | return self.encode(x)
350 |
351 |
352 | class TanhTransform(pyd.transforms.Transform):
353 | domain = pyd.constraints.real
354 | codomain = pyd.constraints.interval(-1.0, 1.0)
355 | bijective = True
356 | sign = +1
357 |
358 | def __init__(self, cache_size=1):
359 | super().__init__(cache_size=cache_size)
360 |
361 | @staticmethod
362 | def atanh(x):
363 | return 0.5 * (x.log1p() - (-x).log1p())
364 |
365 | def __eq__(self, other):
366 | return isinstance(other, TanhTransform)
367 |
368 | def _call(self, x):
369 | return x.tanh()
370 |
371 | def _inverse(self, y):
372 |         # We do not clamp to the boundary here because that can degrade the performance of certain algorithms;
373 |         # rely on `cache_size=1` (set in __init__) to recover the cached x instead.
374 | return self.atanh(y)
375 |
376 | def log_abs_det_jacobian(self, x, y):
377 | # We use a formula that is more numerically stable, see details in the following link
378 | # https://github.com/tensorflow/probability/commit/ef6bb176e0ebd1cf6e25c6b5cecdd2428c22963f#diff-e120f70e92e6741bca649f04fcd907b7
379 | return 2. * (math.log(2.) - x - F.softplus(-2. * x))
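        # Derivation of the stable form above:
        #   log|d tanh(x)/dx| = log(1 - tanh(x)^2)
        #                     = log 4 - 2 * log(e^x + e^{-x})
        #                     = 2 * (log 2 - x - log(1 + e^{-2x}))
        #                     = 2 * (log 2 - x - softplus(-2x))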
380 |
381 |
382 | class SquashedNormal(pyd.transformed_distribution.TransformedDistribution):
383 | def __init__(self, loc, scale):
384 | self.loc = loc
385 | self.scale = scale
386 |
387 | try:
388 | self.base_dist = pyd.Normal(loc, scale)
389 | except Exception as e:
390 | print(e)
391 | print('Loc mean: ', loc.mean())
392 | print('Loc: ', loc)
393 | raise e
394 | transforms = [TanhTransform()]
395 | super().__init__(self.base_dist, transforms)
396 |
397 | @property
398 | def mean(self):
399 | mu = self.loc
400 | for tr in self.transforms:
401 | mu = tr(mu)
402 | return mu
403 |
404 |
405 | class TanhGaussianPolicy(nn.Module):
406 | def __init__(self, config, state_dims, action_dims):
407 | super().__init__()
408 | self.fc_net = FCNet(config, state_dims, action_dims * 2)
409 | self.log_std_min = config.get("policy_min_logstd", -10)
410 | self.log_std_max = config.get("policy_max_logstd", 2)
411 | self.tiny = 1.e-7
412 | self.clip = 1 - self.tiny
413 | self.apply(weight_init)
414 |
415 | def forward(self, x, sample=True):
416 | x = self.fc_net(x)
417 | mu, log_std = torch.chunk(x, 2, dim=-1)
418 | mu = 5 * torch.tanh(mu / 5)
419 | log_std = torch.sigmoid(log_std)
420 | log_std = self.log_std_min + (self.log_std_max - self.log_std_min) * log_std
421 | std = torch.exp(log_std)
422 | action_dist = SquashedNormal(mu, std)
423 | if sample:
424 | action_sample = action_dist.rsample()
425 | action_sample = action_sample.clamp(-self.clip, self.clip)
426 | else:
427 | action_sample = torch.tanh(mu)
428 | log_prob = action_dist.log_prob(action_sample).sum(dim=-1)
429 | return action_sample, log_prob, std
430 |
431 |
432 | class DoubleCritic(nn.Module):
433 | def __init__(self, config, state_dims, action_dims):
434 | super().__init__()
435 | self.critic1 = FCNet(config, state_dims + action_dims)
436 | self.critic2 = FCNet(config, state_dims + action_dims)
437 | self.state_dims = state_dims
438 | self.action_dims = action_dims
439 | self.apply(weight_init)
440 |
441 | def forward(self, states, actions):
442 | assert states.shape[-1] == self.state_dims
443 | assert actions.shape[-1] == self.action_dims
444 | inputs = torch.cat([states, actions], dim=-1)
445 | q1 = self.critic1(inputs).squeeze(-1)
446 | q2 = self.critic2(inputs).squeeze(-1)
447 | return q1, q2
448 |
449 |
450 | class ContrastivePrediction(nn.Module):
451 | def __init__(self, config, obs_dims):
452 | super().__init__()
453 | init_inverse_temp = config['inverse_temperature_init']
454 | inverse_temp = torch.tensor([float(init_inverse_temp)])
455 | self.inverse_temp = nn.parameter.Parameter(inverse_temp)
456 | self.output_dims = obs_dims
457 | self.softmax_over = config.get('softmax_over', 'both') # ['obs', 'symbol', 'both']
458 | mask_type = config.get('mask_type', None)
459 | if mask_type is None:
460 | self.mask = None
461 | else:
462 | if mask_type == 'exclude_same_sequence':
463 | mask = self.get_exclude_same_sequence_mask(31, 32)
464 | elif mask_type == 'exclude_other_sequences':
465 | mask = self.get_exclude_other_sequences_mask(31, 32)
466 | else:
467 | raise Exception('Unknown mask type')
468 | self.register_buffer('mask', mask)
469 |
470 | def get_exclude_same_sequence_mask(self, T, B):
471 | """
472 | Exclude other timesteps from the same sequence.
473 | mask[i, j] = True means replace that by -inf.
474 | """
475 | mask = torch.zeros(T, B, T, B)
476 |         per_b_mask = 1 - torch.eye(T)  # 1 (mask out) for other timesteps; keep the diagonal positive pair.
477 | for b in range(B):
478 | mask[:, b, :, b] = per_b_mask
479 | mask = mask.view(T*B, T*B)
480 | mask = mask == 1
481 | return mask
482 |
483 | def get_exclude_other_sequences_mask(self, T, B):
484 | """
485 | Exclude other sequences in the batch.
486 | mask[i, j] = True means replace that by -inf.
487 | """
488 | mask = torch.ones(T, B, T, B)
489 | for b in range(B):
490 | mask[:, b, :, b] = 0
491 | mask = mask.view(T*B, T*B)
492 | mask = mask == 1
493 | return mask
494 |
495 | def forward(self, x, y, train=False):
496 | """
497 | Inputs:
498 | x: (..., D) encoder features.
499 | y: (..., D) decoder features.
500 | Outputs:
501 | loss: contrastive loss.
502 | """
503 | x = x.view(-1, self.output_dims) # (T * B, D)
504 | y = y.view(-1, self.output_dims) # (T * B, D)
505 | inv_temp = F.softplus(self.inverse_temp)
506 |         logits = inv_temp * torch.matmul(x, y.T)  # (T*B, T*B)
507 | if self.mask is not None and train:
508 | logits[self.mask] = float('-inf')
509 | log_probs1 = F.log_softmax(logits, dim=1)
510 | log_probs2 = F.log_softmax(logits, dim=0)
511 | loss1 = -(log_probs1.diagonal().mean())
512 | loss2 = -(log_probs2.diagonal().mean())
513 | if self.softmax_over == 'symbol':
514 | loss = loss1
515 | elif self.softmax_over == 'obs':
516 | loss = loss2
517 | else:
518 | loss = (loss1 + loss2) / 2
519 | return loss
520 |
--------------------------------------------------------------------------------
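A minimal standalone sketch (illustrative only; it assumes the repository root is on PYTHONPATH and uses toy shapes) of how the ContrastivePrediction module above computes its bidirectional InfoNCE-style loss between encoder features and world-model reconstructions:

    import torch
    from modules import ContrastivePrediction

    # Config keys mirror the ones read in ContrastivePrediction.__init__.
    config = {'inverse_temperature_init': 1.0, 'softmax_over': 'both'}  # no mask_type -> no masking
    cpc = ContrastivePrediction(config, obs_dims=64)

    T, B, D = 8, 4, 64
    x = torch.randn(T, B, D)   # encoder features
    y = torch.randn(T, B, D)   # features reconstructed by the world model
    loss = cpc(x, y, train=True)
    loss.backward()            # gradients also flow into the learned inverse temperature
    print(loss.item())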
/replay_buffer.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved.
4 | #
5 | """ Replay Buffer for storing sequential data"""
6 | import random
7 | import torch
8 | from utils import torchify
9 |
10 |
11 | class SequenceReplayBuffer(object):
12 | def __init__(self, size=None):
13 | self.data = []
14 | self.size = size
15 | self.index = 0
16 |
17 | def __len__(self):
18 | return len(self.data)
19 |
20 | def add(self, seq):
21 | """
22 |         seq : list of dicts mapping keys to values, where each value is a numpy array or a float/int.
23 | """
24 | seq = torchify(seq) # dict of (T, ...)
25 | if self.size is None or len(self.data) < self.size:
26 | self.data.append(seq)
27 | else:
28 | self.data[self.index] = seq
29 | self.index += 1
30 | if self.index == self.size:
31 | self.index = 0
32 |
33 | def sample(self, num_seq, seq_len=0):
34 | """
35 | Sample a batch from the replay buffer.
36 | Args:
37 | num_seq: Batch size
38 | seq_len: Length of each sequence. Default=0 means pick the entire sequence (i.e. seq_len=T)
39 | Returns:
40 | res : dict of tensors of shape (num_seq, seq_len, *entity_shape)
41 | """
42 | # Pick seq_ids.
43 | seq_count = len(self.data)
44 | inds = list(range(seq_count))
45 | if num_seq < seq_count:
46 | inds = random.sample(inds, k=num_seq)
47 | elif num_seq > seq_count:
48 | inds = random.choices(inds, k=num_seq)
49 | batch = []
50 | for ind in inds:
51 | seq = self.data[ind]
52 | key = list(seq.keys())[0]
53 | T = len(seq[key])
54 | if seq_len <= 0 or T <= seq_len:
55 | seq_sample = seq
56 | else:
57 | start_pos = random.randint(0, T - seq_len)
58 | seq_sample = {k: v[start_pos:start_pos + seq_len] for k, v in seq.items()}
59 | batch.append(seq_sample)
60 | # pack the batch into a dict of tensors.
61 | keys = batch[0].keys()
62 | res = {}
63 | for key in keys:
64 | res[key] = torch.stack([sample[key] for sample in batch]) # (B, T, ...)
65 | return res
66 |
67 | def all(self):
68 | return self.sample(len(self.data))
69 |
--------------------------------------------------------------------------------
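A short usage sketch (toy data, assuming the repository's dependencies are installed) of how SequenceReplayBuffer stores whole episodes as lists of per-step dicts and samples fixed-length windows:

    import numpy as np
    from replay_buffer import SequenceReplayBuffer

    buffer = SequenceReplayBuffer(size=100)  # keep at most 100 episodes
    for _ in range(3):                       # three fake episodes of length 50
        episode = [dict(obs=np.random.rand(3, 64, 64).astype(np.float32),
                        action=np.random.rand(6).astype(np.float32),
                        reward=float(np.random.rand()))
                   for _ in range(50)]
        buffer.add(episode)

    batch = buffer.sample(num_seq=2, seq_len=16)
    print(batch['obs'].shape, batch['reward'].shape)  # (2, 16, 3, 64, 64) and (2, 16)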
/setup.sh:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved.
4 | #
5 | apt-get update
6 | apt install -y less tmux psmisc curl git libgl1-mesa-dev libgl1-mesa-glx libglew-dev \
7 | libosmesa6-dev software-properties-common net-tools unzip vim \
8 |     virtualenv wget xpra xserver-xorg-dev libglfw3-dev patchelf xvfb ffmpeg
9 |
10 | # Download distraction videos.
11 | wget https://data.vision.ee.ethz.ch/csergi/share/davis/DAVIS-2017-trainval-480p.zip
12 | unzip -q DAVIS-2017-trainval-480p.zip
13 |
14 | # Install MuJoCo
15 | mkdir ~/.mujoco
16 |
17 | # Mujoco 2.1.0 for dm_control
18 | wget https://mujoco.org/download/mujoco210-linux-x86_64.tar.gz
19 | tar -xvzf mujoco210-linux-x86_64.tar.gz
20 | mv mujoco210 ~/.mujoco/mujoco210
21 |
22 | # Install MuJoCo 2.0
23 | wget https://www.roboti.us/download/mujoco200_linux.zip
24 | wget https://roboti.us/file/mjkey.txt
25 | unzip mujoco200_linux.zip
26 | mv mujoco200_linux ~/.mujoco/mujoco200_linux
27 | mv mjkey.txt ~/.mujoco/
28 | ln -s ~/.mujoco/mujoco200_linux ~/.mujoco/mujoco200
29 | export LD_LIBRARY_PATH=$HOME/.mujoco/mujoco200/bin:$LD_LIBRARY_PATH
30 |
31 | # Put 10_nvidia.json in the right place.
32 | # This is needed to make the renderer use gpu.
33 | cp 10_nvidia.json /usr/share/glvnd/egl_vendor.d/
34 |
35 | # Mujoco-py uses a very brittle test to determine whether to use gpu rendering.
36 | # It looks for a directory named /usr/lib/nvidia-xxx (but doesn't really need or use any libraries present there).
37 | # So we just create a dummy one here.
38 | mkdir -p /usr/lib/nvidia-000
39 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/lib/nvidia-000
40 |
41 | # Create conda environment.
42 | conda config --set remote_read_timeout_secs 600
43 | conda env create -f conda_env_robosuite.yml
44 |
--------------------------------------------------------------------------------
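A quick sanity check (a sketch, assuming dm_control was installed by the conda environment above) that EGL offscreen rendering works after running setup.sh:

    import os
    os.environ.setdefault('MUJOCO_GL', 'egl')  # must be set before importing dm_control
    from dm_control import suite

    env = suite.load('cheetah', 'run')
    env.reset()
    frame = env.physics.render(height=64, width=64, camera_id=0)
    print(frame.shape)  # expected: (64, 64, 3)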
/train.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved.
4 | #
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from modules import TanhGaussianPolicy, FCNet, ObservationModel, DoubleCritic, ContrastivePrediction
9 | from world_model import WorldModel
10 | import numpy as np
11 | from replay_buffer import SequenceReplayBuffer
12 | import torch.optim
13 | import argparse
14 | import yaml
15 | import os
16 | from torch.utils.tensorboard import SummaryWriter
17 | import time
18 | import random
19 | from utils import write_video_mp4, torchify, add_weight_decay, generate_expt_id, crop_image_tensor, get_parameter_list
20 | from environments import make_environment
21 | from copy import deepcopy
22 | import sys
23 |
24 |
25 | class Trainer(object):
26 | """ Trainer for all models. """
27 | def __init__(self, config, device, debug):
28 | self.config = config
29 | self.device = device
30 | self.debug = debug
31 |
32 | # Where should artifacts be written out.
33 | artifact_dir = config.get('artifact_dir', os.environ.get('BOLT_ARTIFACT_DIR', 'artifacts'))
34 | if not debug:
35 | self.log_dir = os.path.join(artifact_dir, config['expt_id'])
36 | os.makedirs(self.log_dir, exist_ok=True)
37 | config_filename = os.path.join(self.log_dir, 'config.yaml')
38 | with open(config_filename, 'w') as f:
39 | yaml.dump(config, f, sort_keys=True)
40 | print('Results will be logged to {}'.format(self.log_dir))
41 | self.tb_writer = SummaryWriter(self.log_dir)
42 |
43 | self.num_envs = self.config['num_envs']
44 | self.num_val_envs = self.config['num_val_envs']
45 | seed = self.config['seed']
46 | self.train_env_containers = [make_environment(self.config['env'], train=True, seed=seed+i) for i in range(self.num_envs)]
47 | seed += self.num_envs
48 | self.val_env_containers = [make_environment(self.config['env'], train=False, seed=seed+i) for i in range(self.num_val_envs)]
49 | env = self.train_env_containers[0]
50 | self.action_repeat = env.get_action_repeat()
51 | action_dims = env.get_action_dims()
52 | self.obs_channels, self.obs_height, self.obs_width = env.get_obs_chw()
53 | self.obs_other_dims = env.get_obs_other_dims()
54 |
55 | # Setup the observation encoder.
56 | self.crop_height = config['crop_height']
57 | self.crop_width = config['crop_width']
58 | self.same_crop_across_time = self.config.get('same_crop_across_time', False)
59 | self.random_crop_padding = self.config.get('random_crop_padding', 0)
60 | chw = (self.obs_channels, self.crop_height, self.crop_width)
61 | self.observation_model = ObservationModel(config, chw, self.obs_other_dims)
62 | obs_dims = self.observation_model.output_dims
63 |
64 | # Setup Contrastive Prediction.
65 | self.contrastive_prediction = ContrastivePrediction(config['contrastive'], obs_dims)
66 |
67 | # Setup the recurrent dynamics model.
68 | if 'wm' in config:
69 | self.model = WorldModel(config['wm'], obs_dims, action_dims)
70 | state_dims = self.model.state_dims
71 | self.exclude_wm_loss = config.get('exclude_wm_loss', False)
72 | else: # For models like plain SAC.
73 | self.model = None
74 | state_dims = obs_dims
75 | self.exclude_wm_loss = True
76 |
77 | # Setup Actor and Critic.
78 | self.actor = TanhGaussianPolicy(config['actor'], state_dims, action_dims)
79 | self.critic = DoubleCritic(config['critic'], state_dims, action_dims)
80 | self.log_alpha = nn.parameter.Parameter(torch.tensor([float(self.config['initial_log_alpha'])], device=device))
81 | self.target_entropy = self.config.get('target_entropy', -action_dims)
82 |
83 | # Initialization.
84 | if 'initial_model_path' in config:
85 | model_path = os.path.join(artifact_dir, config['initial_model_path'])
86 | self.load(model_path)
87 |
88 | if self.model is not None:
89 | self.model = self.model.to(device)
90 | self.observation_model = self.observation_model.to(device)
91 | self.contrastive_prediction = self.contrastive_prediction.to(device)
92 | self.actor = self.actor.to(device)
93 | self.critic = self.critic.to(device)
94 |
95 | self.has_momentum_encoder = config.get('momentum_encoder', False)
96 |
97 | # Set up optimizers.
98 | # Model optimizer.
99 | params = add_weight_decay(self.observation_model, config['weight_decay'])
100 | if self.model is not None:
101 | params.extend(add_weight_decay(self.model, config['weight_decay']))
102 | contrastive_pred_params = list(self.contrastive_prediction.parameters())
103 | params.append({'params': contrastive_pred_params, 'lr': config['lr_inverse_temp'], 'weight_decay': 0.0})
104 | self.optimizer = torch.optim.Adam(params, lr=config['lr'], betas=(0.9, 0.999))
105 |
106 | # Actor optimizer.
107 | actor_params = add_weight_decay(self.actor, config['weight_decay'])
108 | self.optimizer_actor = torch.optim.Adam(actor_params, lr=config['lr_actor'], betas=(config.get('momentum_actor', 0.9), 0.999))
109 |
110 | # Critic optimizer.
111 | critic_params = add_weight_decay(self.critic, config['weight_decay'])
112 | if config.get('include_model_params_in_critic', False): # Include model params in critic.
113 | critic_params.extend(params)
114 | self.sac_detach_states = False
115 | else:
116 | self.sac_detach_states = True
117 | self.optimizer_critic = torch.optim.Adam(critic_params, lr=config['lr_critic'], betas=(config.get('momentum_critic', 0.9), 0.999))
118 | self.critic_optimizer_parameter_list = get_parameter_list(self.optimizer_critic)
119 | self.target_critic = deepcopy(self.critic).to(device)
120 |
121 | # Adaptive temperature optimizer for SAC.
122 | self.optimizer_alpha = torch.optim.Adam([self.log_alpha], lr=config['lr_alpha'],
123 | betas=(config.get('momentum_alpha', 0.9), 0.999))
124 |
125 | if self.has_momentum_encoder:
126 | self.target_encoder = deepcopy(self.observation_model.encoder)
127 | self.moco_dims = self.target_encoder.output_dims
128 | self.moco_W = nn.Parameter(torch.rand(self.moco_dims, self.moco_dims, device=self.device))
129 | curl_params = list(self.observation_model.encoder.parameters())
130 | curl_params.append(self.moco_W)
131 | self.curl_optimizer = torch.optim.Adam(curl_params, lr=config['lr_curl'], betas=(0.9, 0.999))
132 |
133 | self.optimizer_parameter_list = get_parameter_list(self.optimizer)
134 |
135 | # Whether SAC uses the mean or sampled state.
136 | self.sac_deterministic_state = config.get('sac_deterministic_state', False)
137 |
138 | # Whether to decode from the prior or posterior.
139 | self.recon_from_prior = self.config.get('recon_from_prior', False)
140 |
141 | # For saving the best model.
142 | self.best_loss = None
143 |
144 | def save(self, step, loss):
145 | """ Save a checkpoint. """
146 | if self.debug:
147 | return
148 | checkpoint_dir = os.path.join(self.log_dir, 'checkpoint_%08d' % step)
149 | if not os.path.isdir(checkpoint_dir):
150 | os.makedirs(checkpoint_dir)
151 | filename = os.path.join(checkpoint_dir, 'model.pt')
152 | print('Saved model to {}'.format(filename))
153 | info = {
154 | 'observation_model_state_dict': self.observation_model.state_dict(),
155 | 'contrastive_prediction_state_dict': self.contrastive_prediction.state_dict(),
156 | 'actor_state_dict': self.actor.state_dict(),
157 | 'critic_state_dict': self.critic.state_dict(),
158 | 'log_alpha': self.log_alpha.item(),
159 | 'loss': loss,
160 | 'step': step
161 | }
162 | if self.model is not None:
163 | info['model_state_dict'] = self.model.state_dict()
164 | torch.save(info, filename)
165 | if self.best_loss is None or loss < self.best_loss:
166 | self.best_loss = loss
167 | filename = os.path.join(self.log_dir, 'best_model.pt')
168 | print('Saved model to {}'.format(filename))
169 | torch.save(info, filename)
170 |
171 | def load(self, filename, skip_actor_critic=False):
172 | info = torch.load(filename, map_location=torch.device('cpu'))
173 | missing, unexpected = self.observation_model.load_state_dict(info['observation_model_state_dict'], strict=False)
174 | print('Missing keys', missing)
175 | print('Unexpected keys', unexpected)
176 | if 'contrastive_prediction_state_dict' in info:
177 | self.contrastive_prediction.load_state_dict(info['contrastive_prediction_state_dict'])
178 | if not skip_actor_critic:
179 | self.actor.load_state_dict(info['actor_state_dict'])
180 | self.critic.load_state_dict(info['critic_state_dict'])
181 | self.log_alpha[0] = info['log_alpha']
182 | if self.model is not None:
183 | self.model.load_state_dict(info['model_state_dict'])
184 |
185 | def normalize(self, obs):
186 | return obs.float() / 255
187 |
188 | def unnormalize(self, obs):
189 | obs = obs[..., -3:, :, :] # Select the last three channels. Sometimes we have stacked frames.
190 | return torch.clamp(obs * 255, 0, 255).to(torch.uint8)
191 |
192 | def forward_prop(self, batch, decoding_for_viz=False):
193 | """
194 | Fprop the batch through the model.
195 | """
196 | outputs = self.observation_model(batch) # (T, B, dims)
197 | batch['obs_features'] = outputs['obs_features']
198 |
199 | # Fprop through the world model.
200 | outputs.update(self.model(batch))
201 |
202 | if self.model.decoder is not None: # If the world model uses decoding.
203 | if self.recon_from_prior:
204 | obs_features_recon = outputs['obs_features_recon_prior']
205 | else:
206 | obs_features_recon = outputs['obs_features_recon_post']
207 | outputs['obs_features_recon'] = obs_features_recon
208 |
209 | if self.observation_model.decoder is not None: # If decoding all the way to pixels.
210 | if self.config.get('detach_pixel_decoder', False):
211 | obs_features_recon = obs_features_recon.detach()
212 | obs_recon = self.observation_model.decode(obs_features_recon)
213 | outputs['obs_recon'] = obs_recon
214 | if decoding_for_viz:
215 | # We want to visualize recon from both prior and posterior latent state.
216 | # One of them is computed above, so the other is computed here.
217 | if self.recon_from_prior:
218 | outputs['obs_recon_prior'] = obs_recon
219 | obs_features_recon = outputs['obs_features_recon_post'].detach()
220 | outputs['obs_recon_post'] = self.observation_model.decode(obs_features_recon)
221 | else:
222 | outputs['obs_recon_post'] = obs_recon
223 | obs_features_recon = outputs['obs_features_recon_prior'].detach()
224 | outputs['obs_recon_prior'] = self.observation_model.decode(obs_features_recon)
225 | return outputs
226 |
227 | def loss_reward(self, batch, outputs, loss, metrics):
228 | loss_scales = self.config['loss_scales']
229 | if loss_scales['eta_r'] == 0:
230 | metrics['loss_reward'] = 0.
231 | else:
232 |             reward_prediction = outputs['reward_prediction'][1:]  # Prediction at state t+1 of the reward received on entering it.
233 |             reward = batch['reward'][:-1]  # batch['reward'][t] is the reward for the transition from state t to t+1.
234 | loss_reward = nn.functional.smooth_l1_loss(reward_prediction, reward).mean()
235 | metrics['loss_reward'] = loss_reward.item()
236 | loss = loss + loss_reward * loss_scales['eta_r']
237 | return loss, metrics
238 |
239 | def loss_fwd_dynamics(self, batch, outputs, loss, metrics):
240 | loss_scales = self.config['loss_scales']
241 | if loss_scales['eta_fwd'] == 0:
242 | metrics['loss_fwd'] = 0.
243 | else:
244 | posterior_detached = {k: v.detach() for k, v in outputs['posterior'].items()}
245 |
246 | # skip t=0, because prior is uninformative there. No reason why posterior should match that.
247 | loss_fwd = self.model.dynamics.compute_forward_dynamics_loss(posterior_detached, outputs['prior'])[1:]
248 | loss_fwd = loss_fwd.mean()
249 |
250 | if 'eta_q' in loss_scales and loss_scales['eta_q'] > 0.0:
251 | eta_q = loss_scales['eta_q']
252 | prior_detached = {k: v.detach() for k, v in outputs['prior'].items()}
253 | loss_fwd_q = self.model.dynamics.compute_forward_dynamics_loss(outputs['posterior'], prior_detached)[1:]
254 | loss_fwd_q = loss_fwd_q.mean()
255 | loss_fwd = (1 - eta_q) * loss_fwd + loss_scales['eta_q'] * loss_fwd_q
256 | metrics['loss_fwd'] = loss_fwd.item()
257 | loss = loss + loss_fwd * loss_scales['eta_fwd']
258 | return loss, metrics
259 |
260 | def loss_contrastive(self, batch, outputs, train, loss, metrics):
261 | loss_scales = self.config['loss_scales']
262 | eta_s = loss_scales['eta_s']
263 | if eta_s > 0:
264 | x = outputs['obs_features']
265 | y = outputs['obs_features_recon']
266 | if self.recon_from_prior: # Skip t=0 when reconstructing from prior, because prior knows nothing at t=0.
267 | x = x[1:]
268 | y = y[1:]
269 | loss_contrastive = self.contrastive_prediction(x, y, train=train)
270 | loss = loss + loss_contrastive * eta_s
271 | metrics['loss_contrastive'] = loss_contrastive.item()
272 | return loss, metrics
273 |
274 | def loss_observation_recon(self, batch, outputs, loss, metrics):
275 | if 'obs_recon' not in outputs:
276 | return loss, metrics
277 | obs_recon = outputs['obs_recon']
278 | obs = batch.get('obs_image_clean', batch['obs_image'])
279 | if self.recon_from_prior: # Skip t=0.
280 | obs_recon = obs_recon[1:]
281 | obs = obs[1:]
282 | loss_obs = nn.functional.smooth_l1_loss(obs_recon, obs).mean()
283 | loss_scales = self.config['loss_scales']
284 | metrics['loss_obs'] = loss_obs.item()
285 | loss = loss + loss_scales['eta_x'] * loss_obs
286 | return loss, metrics
287 |
288 | def loss_inv_dynamics(self, batch, outputs, loss, metrics):
289 | if 'action_prediction' not in outputs:
290 | return loss, metrics
291 | loss_scales = self.config['loss_scales']
292 | if loss_scales['eta_inv'] == 0:
293 | metrics['loss_inverse_dynamics'] = 0.
294 | else:
295 | action_prediction = outputs['action_prediction'] # (T-1, B, a_dims) tanh valued.
296 | action = batch['action'][:-1]
297 | loss_inverse_dynamics = 0.5 * ((action - action_prediction) ** 2).sum(dim=-1).mean()
298 | loss = loss + loss_scales['eta_inv'] * loss_inverse_dynamics
299 | metrics['loss_inverse_dynamics'] = loss_inverse_dynamics.item()
300 | return loss, metrics
301 |
302 | def compute_loss(self, batch, train, decoding_for_viz=False):
303 | outputs = self.forward_prop(batch, decoding_for_viz=decoding_for_viz)
304 | metrics = {}
305 | loss = 0
306 | loss, metrics = self.loss_reward(batch, outputs, loss, metrics)
307 | loss, metrics = self.loss_inv_dynamics(batch, outputs, loss, metrics)
308 | loss, metrics = self.loss_fwd_dynamics(batch, outputs, loss, metrics)
309 | loss, metrics = self.loss_contrastive(batch, outputs, train, loss, metrics)
310 | loss, metrics = self.loss_observation_recon(batch, outputs, loss, metrics)
311 | metrics['loss_total'] = loss.item()
312 | return loss, metrics, outputs
313 |
314 | def update_curl(self, batch, step, heavy_logging=False):
315 | with torch.no_grad():
316 | f_k, _ = self.observation_model.encode_pixels(batch['obs_image_2'], encoder=self.target_encoder)
317 | f_k = f_k.detach().view(-1, self.moco_dims)
318 | f_q, _ = self.observation_model.encode_pixels(batch['obs_image'])
319 | f_q = f_q.view(-1, self.moco_dims)
320 | f_proj = torch.matmul(f_k, self.moco_W)
321 | logits = torch.matmul(f_q, f_proj.T)
322 | log_probs = F.log_softmax(logits, dim=1)
323 | loss = -(log_probs.diagonal().mean())
324 | metrics = {'moco_loss' : loss.item()}
325 |
326 | self.curl_optimizer.zero_grad()
327 | loss.backward()
328 | self.curl_optimizer.step()
329 |
330 | # Do momentum update.
331 | tau = self.config.get('update_target_encoder_tau', 1)
332 | self.update_target(self.target_encoder, self.observation_model.encoder, tau)
333 |
334 | return metrics
335 |
336 | def update_world_model(self, batch, step, heavy_logging=False):
337 | """
338 | Update the world model.
339 | batch : Dict containing keys ('action', 'obs_image', 'reward', etc)
340 | 'action' : (T, B action_dims)
341 | 'obs_image' : (T, B, C, H, W)
342 | 'reward': (T, B)
343 | """
344 | loss, metrics, outputs = self.compute_loss(batch, train=True, decoding_for_viz=heavy_logging)
345 | self.optimizer.zero_grad()
346 | loss.backward()
347 | if 'max_grad_norm_wm' in self.config:
348 | grad_norm = torch.nn.utils.clip_grad_norm_(self.optimizer_parameter_list, self.config['max_grad_norm_wm'])
349 | metrics['grad_norm_wm'] = grad_norm.item()
350 | self.optimizer.step()
351 |
352 | if step % self.config['print_every'] == 0:
353 | loss_str = ' '.join(['{}: {:.2f}'.format(k, v) for k, v in sorted(metrics.items())])
354 | print('Step {} {}'.format(step, loss_str))
355 | if not self.debug and self.tb_writer is not None:
356 | for k, v in metrics.items():
357 | self.tb_writer.add_scalar('metrics/{}'.format(k), v, step)
358 | if heavy_logging:
359 | max_B = 16
360 | self.tb_writer.add_video('obs/input',
361 | self.unnormalize(batch['obs_image']).transpose(0, 1)[:max_B], step)
362 | if self.observation_model.decoder is not None:
363 | self.tb_writer.add_video('obs/recon',
364 | self.unnormalize(outputs['obs_recon']).transpose(0, 1)[:max_B], step)
365 | self.tb_writer.add_video('obs/recon_post',
366 | self.unnormalize(outputs['obs_recon_post']).transpose(0, 1)[:max_B], step)
367 | self.tb_writer.add_video('obs/recon_prior',
368 | self.unnormalize(outputs['obs_recon_prior']).transpose(0, 1)[:max_B], step)
369 | return metrics, outputs
370 |
371 | def log_video(self, video_tag, frames, step):
372 | """
373 |         Log a video to tensorboard and write it to disk as an mp4.
374 |         Args:
375 |             frames : Tensor of shape (B, T, C, H, W).
376 | step: training step.
377 | video_tag: tag used for logging into tensorboard and as dir name for disk.
378 | """
379 | self.tb_writer.add_video(video_tag, frames, step)
380 |
381 | B, T, C, H, W = list(frames.shape)
382 | frames = frames.permute(1, 2, 3, 0, 4).contiguous().view(T, C, H, B*W) # Stack batch along width.
383 | video_dir = os.path.join(self.log_dir, video_tag)
384 | os.makedirs(video_dir, exist_ok=True)
385 | filename = os.path.join(video_dir, 'video_%08d.mp4' % step)
386 | write_video_mp4(filename, frames)
387 |
388 | def validate(self, step):
389 | self.observation_model.eval()
390 | if self.model is not None:
391 | self.model.eval()
392 | self.actor.eval()
393 | self.critic.eval()
394 | tic = time.time()
395 | metrics = {}
396 | # Collect data. One episode in each val environment.
397 | replay_buffer = SequenceReplayBuffer()
398 | num_episodes_per_val_env_for_reward = self.config.get('num_episodes_per_val_env_for_reward', 10)
399 | sample_policy = self.config.get('val_stochastic_policy', False)
400 | if sample_policy:
401 | print('Using stochastic policy for val')
402 | episode_reward = self.collect_data_from_actor(replay_buffer,
403 | num_episodes_per_env=num_episodes_per_val_env_for_reward,
404 | train=False, sample_policy=sample_policy)
405 | metrics['episode_reward'] = episode_reward
406 |
407 | # Take the first few episodes for computing the rest of the metrics. They are expensive to compute.
408 | num_episodes_for_model = self.config.get('num_episodes_val_for_model', 5)
409 | batch = replay_buffer.sample(num_episodes_for_model)
410 | batch = self.prep_batch(batch, random_crop=False)
411 | steps_per_episode = self.config['episode_steps'] // self.action_repeat
412 |
413 | if not self.exclude_wm_loss:
414 | with torch.no_grad():
415 | loss, model_metrics, outputs = self.compute_loss(batch, train=False, decoding_for_viz=True)
416 | metrics.update(model_metrics)
417 | # Generate rollout from prior.
418 | if self.observation_model.decoder is not None:
419 | init_t = self.config['rollout_prior_init_t']
420 | assert 0 < init_t < steps_per_episode - 1
421 | init_state = dict([(k, v[init_t-1]) for k, v in outputs['posterior'].items()])
422 | prior = self.model.dynamics.rollout_prior(init_state, batch['action'][init_t:, ...], deterministic=False)
423 | # Decode to images.
424 | latent = self.model.dynamics.get_state(prior, deterministic=False)
425 | obs_recon_imagined = self.observation_model.decode(self.model.decoder(latent))
426 | # Add the first init_t images from the posterior.
427 | obs_recon_imagined = torch.cat([outputs['obs_recon_prior'][:init_t, :], obs_recon_imagined], dim=0)
428 |
429 | elif self.observation_model.use_gating_network: # Even if model is None, we want outputs to have gating.
430 | with torch.no_grad():
431 | outputs = self.observation_model(batch) # (T, B, dims) # Used to visualize gating.
432 |
433 | toc = time.time()
434 | metrics.update({
435 | 'timing': toc - tic,
436 | })
437 |
438 | loss_str = ' '.join(['{}: {:.2f}'.format(k, v) for k, v in sorted(metrics.items())])
439 | print('Val Iter {} {}'.format(step, loss_str))
440 | if not self.debug and self.tb_writer is not None:
441 | for k, v in metrics.items():
442 | self.tb_writer.add_scalar('val_metrics/{}'.format(k), v, step)
443 | obs = self.unnormalize(batch['obs_image']).transpose(0, 1) # (B, T, C, H, W)
444 | if self.observation_model.use_gating_network:
445 | obs_gating = outputs['obs_gating'].transpose(0, 1) # (B, T, F, 1, H, W)
446 | obs_gating = obs_gating[:, :, -1, :, :, :] # The gating for the last frame.
447 | obs_gating = (obs_gating * 255).to(torch.uint8)
448 | obs_gating = obs_gating.expand_as(obs).contiguous() # replicate along RGB.
449 | obs = torch.cat([obs, obs_gating], dim=3)
450 | if self.model is not None and self.observation_model.decoder is not None:
451 | obs_recon = self.unnormalize(outputs['obs_recon']).transpose(0, 1)
452 | obs_recon_post = self.unnormalize(outputs['obs_recon_post']).transpose(0, 1)
453 | obs_recon_prior = self.unnormalize(outputs['obs_recon_prior']).transpose(0, 1)
454 | obs_recon_imagined = self.unnormalize(obs_recon_imagined).transpose(0, 1)
455 | obs = torch.cat([obs, obs_recon, obs_recon_post, obs_recon_prior, obs_recon_imagined], dim=3)
456 | self.log_video('obs/val', obs, step)
457 | return -episode_reward
458 |
459 | def collect_data_random_policy(self, replay_buffer, num_episodes_per_env=1, train=True):
460 | steps_per_episode = self.config['episode_steps'] // self.action_repeat
461 | env_containers = self.train_env_containers if train else self.val_env_containers
462 | total_reward = 0
463 | for env_container in env_containers:
464 | action_low, action_high = env_container.get_action_limits()
465 | action_dims = env_container.get_action_dims()
466 | for _ in range(num_episodes_per_env):
467 | obs = env_container.reset()
468 | seq = []
469 | for _ in range(steps_per_episode):
470 | action = np.random.uniform(action_low, action_high, action_dims)
471 | next_obs, reward, _, _ = env_container.step(action)
472 | seq.append(dict(obs=obs, action=action, reward=reward))
473 | obs = next_obs
474 | total_reward += reward
475 | replay_buffer.add(seq)
476 | avg_reward = total_reward / (num_episodes_per_env * len(env_containers))
477 | return avg_reward
478 |
479 | def prep_batch(self, batch, random_crop=False):
480 | """ Prepare batch of data for input to the model.
481 | Inputs:
482 |             batch : Dict containing 'obs_image', etc.
483 | Returns:
484 | batch: Same dict, but with images randomly cropped, moved to GPU, normalized.
485 | """
486 | for key in batch.keys():
487 | batch[key] = batch[key].to(self.device)
488 | obs_image_cropped = crop_image_tensor(batch['obs_image'], self.crop_height, self.crop_width,
489 | random_crop=random_crop,
490 | same_crop_across_time=self.same_crop_across_time,
491 | padding=self.random_crop_padding)
492 | if self.has_momentum_encoder:
493 | batch['obs_image_2'] = crop_image_tensor(batch['obs_image'], self.crop_height, self.crop_width,
494 | random_crop=random_crop,
495 | same_crop_across_time=self.same_crop_across_time,
496 | padding=self.random_crop_padding)
497 |
498 | if 'obs_image_clean' in batch: # When we have paired distraction-free and distracting obs.
499 | batch['obs_image_clean'] = crop_image_tensor(batch['obs_image_clean'], self.crop_height, self.crop_width, random_crop=False, same_crop_across_time=True, padding=0)
500 | else:
501 | batch['obs_image_clean'] = crop_image_tensor(batch['obs_image'], self.crop_height, self.crop_width, random_crop=False, same_crop_across_time=True, padding=0)
502 |
503 | batch['obs_image'] = obs_image_cropped
504 | if len(batch['obs_image'].shape) == 5: # (B, T, C, H, W) -> (T, B, C, H, W)
505 | swap_first_two_dims = True
506 | else: # (B, C, H, W) -> no change.
507 | swap_first_two_dims = False
508 | for key in batch.keys():
509 | if swap_first_two_dims:
510 | batch[key] = batch[key].transpose(0, 1)
511 | batch[key] = batch[key].contiguous().float().detach()
512 | batch['obs_image'] = self.normalize(batch['obs_image'])
513 | if 'obs_image_clean' in batch:
514 | batch['obs_image_clean'] = self.normalize(batch['obs_image_clean'])
515 |         if 'obs_image_2' in batch:
516 | batch['obs_image_2'] = self.normalize(batch['obs_image_2'])
517 | return batch
518 |
519 | def collect_data_from_actor(self, replay_buffer, num_episodes_per_env=1, train=True, sample_policy=True):
520 | steps_per_episode = self.config['episode_steps'] // self.action_repeat
521 | self.observation_model.eval()
522 | if self.model is not None:
523 | self.model.eval()
524 | self.actor.eval()
525 | reward_total = 0
526 | env_containers = self.train_env_containers if train else self.val_env_containers
527 | num_env = len(env_containers)
528 |
529 | for _ in range(num_episodes_per_env):
530 | seq_list = []
531 | obs_list = []
532 | for env_container in env_containers:
533 | obs = env_container.reset()
534 | seq_list.append(list())
535 | obs_list.append(dict(obs=obs))
536 | posterior = None
537 | action = None
538 | for _ in range(steps_per_episode):
539 | # Find the action to take for a batch of environments.
540 | batch = torchify(obs_list) # Dict of (B, ...)
541 | batch = self.prep_batch(batch, random_crop=False)
542 | outputs = self.observation_model(batch)
543 | obs_features = outputs['obs_features']
544 | if self.model is not None: # If using a dynamics model.
545 | latent, posterior = self.model.forward_one_step(obs_features, posterior, action,
546 | deterministic_latent=self.sac_deterministic_state)
547 | else:
548 | latent = obs_features
549 | action, _, _ = self.actor(latent, sample=sample_policy)
550 | action_npy = action.detach().cpu().numpy() # (B, a_dims)
551 |
552 | # Step each environment with the computed action.
553 | for i, env_container in enumerate(env_containers):
554 | current_action = action_npy[i]
555 | obs, reward, _, _ = env_container.step(current_action)
556 | seq_list[i].append(dict(obs=obs_list[i]['obs'], action=current_action, reward=reward))
557 | obs_list[i]['obs'] = obs
558 | reward_total += reward
559 | for seq in seq_list:
560 | replay_buffer.add(seq)
561 | episode_reward = reward_total / (num_env * num_episodes_per_env)
562 | return episode_reward
563 |
564 | def update_target(self, target, critic, tau):
565 | target_params_dict = dict(target.named_parameters())
566 | for n, p in critic.named_parameters():
567 | target_params_dict[n].data.copy_(
568 | (1 - tau) * target_params_dict[n] + tau * p
569 | )
570 |
571 | def update_actor_critic_sac(self, batch, step, heavy_logging=False):
572 | """
573 | Inputs:
574 |             batch : Dict containing keys ('action', 'obs_image', 'reward')
575 |                 'action' : (T, B, action_dims)
576 |                 'obs_image' : (T, B, C, H, W)
577 | 'reward': (T, B)
578 | """
579 | metrics = {}
580 |
581 | outputs = self.observation_model(batch) # (T, B, dims)
582 | obs_features = outputs['obs_features']
583 | batch['obs_features'] = obs_features
584 | if self.model is not None:
585 | outputs = self.model(batch) # Dict containing prior (stoch, logits), posterior(..)
586 | states = self.model.dynamics.get_state(outputs['posterior'],
587 | deterministic=self.sac_deterministic_state)
588 | else:
589 | states = obs_features
590 |
591 | # Update critic (potentially including the encoder).
592 | current_states = states[:-1]
593 | if self.sac_detach_states:
594 | current_states = current_states.detach()
595 | current_actions = batch['action'][:-1]
596 | reward = batch['reward'][:-1] # (T-1, B)
597 | next_states = states[1:].detach()
598 | alpha = torch.exp(self.log_alpha).detach()
599 | gamma = self.config['gamma']
600 | with torch.no_grad():
601 | if torch.isnan(next_states).any():
602 | raise Exception('Next states contains nan')
603 | next_actions, next_action_log_probs, _ = self.actor(next_states)
604 | target_q1, target_q2 = self.target_critic(next_states, next_actions)
605 | target_v = torch.min(target_q1, target_q2) - alpha * next_action_log_probs
606 | target_q = reward + gamma * target_v
607 | q1, q2 = self.critic(current_states, current_actions) # (T-1, B)
608 | critic_loss = F.mse_loss(q1, target_q) + F.mse_loss(q2, target_q)
609 | self.optimizer_critic.zero_grad()
610 | critic_loss.backward()
611 | if 'max_grad_norm_critic' in self.config:
612 | grad_norm = torch.nn.utils.clip_grad_norm_(self.critic.parameters(), self.config['max_grad_norm_critic'])
613 | metrics['grad_norm_critic'] = grad_norm.item()
614 | self.optimizer_critic.step()
615 |
616 | # Update actor.
617 | current_states_detached = current_states.detach() # Actor loss does not backpropagate into encoder or dynamics.
618 | policy_actions, policy_action_log_probs, policy_action_std = self.actor(current_states_detached) # (T-1, B, action_dims)
619 | q1, q2 = self.critic(current_states_detached, policy_actions)
620 | q = torch.min(q1, q2)
621 | q_loss = -q.mean()
622 | entropy_loss = policy_action_log_probs.mean()
623 | entropy_loss_wt = torch.exp(self.log_alpha).detach()
624 | actor_loss = q_loss + entropy_loss_wt * entropy_loss
625 | self.optimizer_actor.zero_grad()
626 | actor_loss.backward()
627 | if 'max_grad_norm_actor' in self.config:
628 | grad_norm = torch.nn.utils.clip_grad_norm_(self.actor.parameters(), self.config['max_grad_norm_actor'])
629 | metrics['grad_norm_actor'] = grad_norm.item()
630 | self.optimizer_actor.step()
631 |
632 | # Update alpha (adaptive entropy loss wt)
633 | alpha_loss = -(torch.exp(self.log_alpha) * (self.target_entropy + entropy_loss.detach()))
634 | self.optimizer_alpha.zero_grad()
635 | alpha_loss.backward()
636 | if 'max_grad_norm_log_alpha' in self.config:
637 | grad_norm = torch.nn.utils.clip_grad_norm_([self.log_alpha], self.config['max_grad_norm_log_alpha'])
638 | metrics['grad_norm_log_alpha'] = grad_norm.item()
639 | self.optimizer_alpha.step()
640 | if 'max_log_alpha' in self.config:
641 | with torch.no_grad():
642 | self.log_alpha.clamp_(max=self.config['max_log_alpha'])
643 |
644 | if step % self.config['update_target_critic_after'] == 0:
645 | tau = self.config.get('update_target_critic_tau', 1)
646 | self.update_target(self.target_critic, self.critic, tau)
647 |
648 | metrics.update({
649 | 'critic_loss': critic_loss.item(),
650 | 'actor_loss': actor_loss.item(),
651 | 'q_loss': q_loss.item(),
652 | 'entropy_loss': entropy_loss.item(),
653 | 'log_alpha': self.log_alpha.item(),
654 | })
655 | if not self.debug and self.tb_writer is not None:
656 | for k, v in metrics.items():
657 | self.tb_writer.add_scalar('rl_metrics/{}'.format(k), v, step)
658 | if heavy_logging:
659 | self.tb_writer.add_histogram('rl_metrics/reward', reward.view(-1), step)
660 | self.tb_writer.add_histogram('rl_metrics/q_targets', target_q.view(-1), step)
661 | self.tb_writer.add_histogram('rl_metrics/critic_scores', q.view(-1), step)
662 | self.tb_writer.add_histogram('rl_metrics/action', policy_actions.view(-1), step)
663 | self.tb_writer.add_histogram('rl_metrics/action_log_probs', policy_action_log_probs.view(-1), step)
664 | self.tb_writer.add_histogram('rl_metrics/action_std', policy_action_std.view(-1), step)
665 | return metrics
666 |
667 | def train(self):
668 | """ Train the model."""
669 | # Setup replay buffer.
670 | steps_per_episode = self.config['episode_steps'] // self.action_repeat
671 | replay_buffer_size = self.config['replay_buffer_size']
672 | num_episodes_in_replay_buffer = replay_buffer_size // steps_per_episode
673 | replay_buffer = SequenceReplayBuffer(size=num_episodes_in_replay_buffer)
674 |
675 |         # Find out how many data collection iterations to do.
676 | max_steps = self.config['max_steps'] // self.action_repeat
677 | num_iters = max_steps // (self.num_envs * steps_per_episode)
678 |
679 | # How many gradients updates per iteration.
680 | num_updates_per_iter = int(steps_per_episode * self.config.get('update_frequency_factor', 1.0))
681 |
682 | random_crop = self.config.get('random_crop', False)
683 | B = self.config['batch_size']
684 | T = self.config['dynamics_seq_len']
685 | train_step = 0
686 |
687 | # Initial data collection.
688 | initial_episodes_per_env = self.config['initial_data_steps'] // (self.num_envs * steps_per_episode) # Used to delay both world model and rl training.
689 | start_rl_training_after = self.config['start_rl_training_after'] # Used to delay rl training until world model has updated for a bit.
690 |
691 | for ii in range(num_iters):
692 | if ii % self.config['validate_every_iters'] == 0:
693 | loss = self.validate(ii)
694 | if ii == 0:
695 | print('Completed validation')
696 | self.save(ii, loss)
697 |
698 | # Collect data. One episode in each environment.
699 | with torch.no_grad():
700 | if ii < initial_episodes_per_env or train_step < start_rl_training_after:
701 | episode_reward = self.collect_data_random_policy(replay_buffer, num_episodes_per_env=1, train=True)
702 | else:
703 | episode_reward = self.collect_data_from_actor(replay_buffer, num_episodes_per_env=1, train=True,
704 | sample_policy=True)
705 | if not self.debug and self.tb_writer is not None:
706 | self.tb_writer.add_scalar('rl_metrics/episode_reward', episode_reward, ii)
707 |
708 | if ii < initial_episodes_per_env: # No updates until a few episodes have been collected.
709 | continue
710 | self.observation_model.train()
711 | if self.model is not None:
712 | self.model.train()
713 | self.actor.train()
714 | self.critic.train()
715 | for i in range(num_updates_per_iter):
716 | # Train world model.
717 | tic = time.time()
718 | train_step += 1
719 | batch = replay_buffer.sample(B, T) # Dict of (B, T, ..)
720 | batch = self.prep_batch(batch, random_crop=random_crop)
721 | tic1 = time.time()
722 | if not self.exclude_wm_loss: # Skip for model-free variants, like SAC, RSAC.
723 | self.update_world_model(batch, train_step, heavy_logging=(i == 0))
724 | tic2 = time.time()
725 | if self.has_momentum_encoder:
726 | self.update_curl(batch, train_step, heavy_logging=(i == 0))
727 | tic3 = time.time()
728 | if train_step >= start_rl_training_after:
729 | self.update_actor_critic_sac(batch, train_step, heavy_logging=(i == 0))
730 | toc = time.time()
731 | timing_metrics = {
732 | 'time_data_prep': tic1 - tic,
733 | 'time_wm_update': tic2 - tic1,
734 | 'time_curl_update': tic3 - tic2,
735 | 'time_ac_update': toc - tic3,
736 | 'time_per_update': toc - tic,
737 | }
738 | if not self.debug and self.tb_writer is not None:
739 | for k, v in timing_metrics.items():
740 | self.tb_writer.add_scalar('timing_metrics/{}'.format(k), v, train_step)
741 | if train_step == 1:
742 | print('Completed one step')
743 |
744 |
745 | def argument_parser(argument):
746 | """ Argument parser """
747 | parser = argparse.ArgumentParser(description='Binder Network.')
748 | parser.add_argument('--disable-cuda', action='store_true', help='Disable CUDA')
749 | parser.add_argument('-c', '--config', default='', type=str, help='Training config')
750 | parser.add_argument('--debug', action='store_true', help='Debug mode. Disable logging.')
751 | args = parser.parse_args(argument)
752 | return args
753 |
754 |
755 | def main():
756 | args = argument_parser(None)
757 | if not args.disable_cuda and torch.cuda.is_available():
758 | device = torch.device('cuda')
759 | print('Running on GPU {}'.format(torch.cuda.get_device_name(0)))
760 | else:
761 | device = torch.device('cpu')
762 | print('Running on CPU')
763 |
764 | try:
765 | with open(args.config) as f:
766 | config = yaml.safe_load(f)
767 | except FileNotFoundError:
768 | print("Error opening specified config yaml at: {}. "
769 | "Please check filepath and try again.".format(args.config))
770 | sys.exit(1)
771 |
772 | config = config['parameters']
773 | config['expt_id'] = generate_expt_id()
774 | seed = config['seed']
775 | random.seed(seed)
776 | np.random.seed(seed)
777 | torch.manual_seed(seed)
778 |
779 | trainer = Trainer(config, device, args.debug)
780 | trainer.train()
781 |
782 |
783 | if __name__ == '__main__':
784 | main()
785 |
786 |
--------------------------------------------------------------------------------
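For reference, the actor-critic updates implemented in Trainer.update_actor_critic_sac above correspond to the standard SAC objectives (this is only a summary of the code; here s_t is the latent state, \alpha = \exp(\log\alpha), \mathcal{H}^{*} the target entropy, and a' \sim \pi(\cdot \mid s_{t+1})):

    \begin{aligned}
    \hat{Q}_t &= r_t + \gamma\big(\min_{i\in\{1,2\}} Q^{\mathrm{target}}_i(s_{t+1}, a') - \alpha \log\pi(a' \mid s_{t+1})\big)\\
    \mathcal{L}_{\mathrm{critic}} &= \textstyle\sum_{i=1}^{2}\big(Q_i(s_t, a_t) - \hat{Q}_t\big)^2\\
    \mathcal{L}_{\mathrm{actor}} &= \mathbb{E}_{a\sim\pi}\big[\alpha\log\pi(a \mid s_t) - \min_i Q_i(s_t, a)\big]\\
    \mathcal{L}_{\alpha} &= -\alpha\big(\mathcal{H}^{*} + \mathbb{E}_{a\sim\pi}[\log\pi(a \mid s_t)]\big)
    \end{aligned}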
/train.sh:
--------------------------------------------------------------------------------
1 | source /miniconda/etc/profile.d/conda.sh
2 | conda init bash
3 | conda activate core
4 | export PYTHONPATH=.:$PYTHONPATH
5 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/root/.mujoco/mujoco210/bin:/usr/lib/nvidia-000
6 | tensorboard --logdir ${BOLT_ARTIFACT_DIR} --bind_all --port ${TENSORBOARD_PORT} &
7 | MUJOCO_GL=egl CUDA_VISIBLE_DEVICES=0 python train.py
8 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved.
4 | #
5 | import random
6 | import string
7 | import time
8 | import torch
9 | from torchvision.transforms import Resize, RandomCrop, CenterCrop
10 | import subprocess as sp
11 | import numpy as np
12 | import os
13 | import imageio
14 |
15 |
16 | def torchify(seq):
17 | """
18 | Convert list of dict of numpy arrays/floats/dicts to dict of tensors.
19 | Args:
20 | seq : List of dicts.
21 | Returns:
22 | batch : Dict of tensors of shape (T, ..).
23 | """
24 | keys = seq[0].keys()
25 | batch = {}
26 | for key in keys:
27 | value = seq[0][key]
28 | if isinstance(value, np.ndarray):
29 | batch[key] = torch.stack([torch.from_numpy(frame[key]) for frame in seq]) # (T, ..)
30 | elif isinstance(value, float) or isinstance(value, int):
31 | batch[key] = torch.tensor([frame[key] for frame in seq])
32 | elif isinstance(value, dict):
33 | sub_batch = torchify([frame[key] for frame in seq])
34 | for sub_key, val in sub_batch.items():
35 | batch[key + '_' + sub_key] = val
36 | else:
37 | raise Exception('Unknown type of value in torchify for key ', key)
38 | return batch
39 |
40 |
41 | class FreezeParameters:
42 | def __init__(self, parameters):
43 | self.parameters = parameters
44 | self.param_states = [p.requires_grad for p in self.parameters]
45 |
46 | def __enter__(self):
47 | for param in self.parameters:
48 | param.requires_grad = False
49 |
50 | def __exit__(self, exc_type, exc_val, exc_tb):
51 | for i, param in enumerate(self.parameters):
52 | param.requires_grad = self.param_states[i]
53 |
54 |
55 | def freeze_parameters(parameters):
56 | for param in parameters:
57 | param.requires_grad = False
58 |
59 |
60 | def add_weight_decay(net, l2_value, skip_list=(), exclude_both_list=()):
61 | decay, no_decay = [], []
62 | for name, param in net.named_parameters():
63 | if not param.requires_grad or name in exclude_both_list:
64 | continue # frozen weights
65 | if len(param.shape) == 1 or name.endswith('.bias') or name in skip_list:
66 | no_decay.append(param)
67 | else:
68 | decay.append(param)
69 | return [{'params': no_decay, 'weight_decay': 0.}, {'params': decay, 'weight_decay': l2_value}]
70 |
71 |
72 | def get_random_string(length=12):
73 | # choose from all lowercase letter
74 | letters = string.ascii_lowercase + string.digits
75 | result_str = ''.join(random.choice(letters) for i in range(length))
76 | return result_str
77 |
78 |
79 | def generate_expt_id():
80 | task_id = get_random_string()
81 | return time.strftime('%Y-%m-%d_%H-%M-%S') + '_' + task_id
82 |
83 |
84 | def features_as_image(x, x_min=None, x_max=None):
85 | """
86 | x : (B, D) or (1, B, D)
87 | """
88 | assert len(x.shape) == 2 or (len(x.shape) == 3 and x.shape[0] == 1)
89 | if x_max is None:
90 | x_max = x.max()
91 | if x_min is None:
92 | x_min = x.min()
93 | x = (x - x_min) / (x_max - x_min)
94 | if len(x.shape) == 2:
95 | x = x.unsqueeze(0) # (1, H, W)
96 | return x
97 |
98 |
99 | def probs_as_images(probs, W=64):
100 | """
101 | Inputs:
102 | probs: (..., D)
103 | Outputs:
104 |         vid : (..., C, H, W), C = 3, W = 64 by default.
105 | """
106 | orig_shape = list(probs.shape)
107 | D = orig_shape[-1]
108 | H = D // W
109 | assert D % W == 0
110 | probs = (probs * 255).to(torch.uint8)
111 | C = 3
112 | probs = probs.view(-1, 1, H, W).expand(-1, C, -1, -1).contiguous()
113 | probs = probs.view(*orig_shape[:-1], C, H, W)
114 | return probs
115 |
116 |
117 | def resize_image(images, height, width):
118 | """
119 | Resize images.
120 | Inputs:
121 | images: (..., C, H, W)
122 | Outputs:
123 | images: (..., C, height, width)
124 | """
125 | orig_shape = list(images.shape)
126 | C, H, W = orig_shape[-3:]
127 | images = images.view(-1, C, H, W)
128 | resize_op = Resize((height, width))
129 | images = resize_op(images)
130 | images = images.view(*orig_shape[:-2], height, width)
131 | return images
132 |
133 |
134 | def crop_image_tensor(obs, crop_height, crop_width, random_crop=False, same_crop_across_time=False, padding=0):
135 | """
136 | Crop a tensor of images.
137 | Args:
138 | obs: (B, T, C, H, W), or (B, C, H, W) or (C, H, W).
139 | crop_height: Height of the cropped image.
140 | crop_width: Width of the cropped image.
141 | random_crop: If true, crop random patch. Otherwise the center crop is returned.
142 | same_crop_across_time: Maintain the same crop across time for temporal sequences.
143 | padding: How much edge padding to add.
144 | Returns:
145 | cropped_obs: (B, T, C, crop_height, crop_width)
146 | """
147 | assert len(obs.shape) >= 3
148 | channels, height, width = obs.shape[-3], obs.shape[-2], obs.shape[-1]
149 | if random_crop:
150 | transform = RandomCrop((crop_height, crop_width), padding=padding, padding_mode='edge')
151 | orig_shape = list(obs.shape[:-2])
152 | if same_crop_across_time and len(obs.shape) >= 5:
153 | T = obs.shape[-4]
154 | channels = channels * T
155 | obs = obs.view(-1, channels, height, width)
156 | cropped_obs = torch.zeros(obs.size(0), channels, crop_height, crop_width, dtype=obs.dtype, device=obs.device)
157 | for i in range(obs.size(0)):
158 | cropped_obs[i, ...] = transform(obs[i, ...])
159 | cropped_obs = cropped_obs.view(*orig_shape, crop_height, crop_width)
160 | else:
161 | transform = CenterCrop((crop_height, crop_width))
162 | cropped_obs = transform(obs)
163 | return cropped_obs
164 |
165 |
166 | def get_parameter_list(optimizer):
167 | params_list = []
168 | for group in optimizer.param_groups:
169 | params_list.extend(list(group['params']))
170 | return params_list
171 |
172 |
173 | def stack_tensor_dict_list(tensor_dict_list):
174 | """ Stack tensors in a list of dictionaries. """
175 | keys = tensor_dict_list[0].keys()
176 | res = {}
177 | for key in keys:
178 | res[key] = torch.stack([d[key] for d in tensor_dict_list])
179 | return res
180 |
181 |
182 | def write_video_mp4_command(filename, frames):
183 | """
184 | frames : T, C, H, W
185 | """
186 | if isinstance(frames, np.ndarray):
187 | T, channels, height, width = frames.shape
188 | elif isinstance(frames, torch.Tensor):
189 | T, channels, height, width = frames.shape
190 | frames = frames.detach().cpu().numpy()
191 | elif isinstance(frames, list):
192 | channels, height, width = frames[0].shape
193 | assert channels >= 3
194 | frames = np.stack([frame[:3, ...] for frame in frames], axis=0)
195 | else:
196 | raise Exception('Unknown frame specification.')
197 | frames = frames.astype(np.uint8)
198 | frames = frames.transpose((0, 2, 3, 1))
199 | print('Writing video {}'.format(filename))
200 | # Write out as a mp4 video.
201 | command = ['ffmpeg',
202 | '-y', # (optional) overwrite output file if it exists
203 | '-f', 'rawvideo',
204 | '-vcodec', 'rawvideo',
205 | '-s', '{}x{}'.format(width, height), # size of one frame
206 | '-pix_fmt', 'rgb24',
207 | '-r', '20', # frames per second
208 | '-an', # Tells FFMPEG not to expect any audio
209 | '-i', '-', # The input comes from a pipe
210 | '-vcodec', 'libx264',
211 | '-pix_fmt', 'yuv420p',
212 | filename]
213 | print(' '.join(command))
214 | proc = sp.Popen(command, stdin=sp.PIPE, stderr=sp.PIPE)
215 | outs, errs = proc.communicate(input=frames.tobytes())
216 |
217 |
218 | def write_video_mp4(filename, frames):
219 | """
220 | frames : T, C, H, W
221 | """
222 | if isinstance(frames, np.ndarray):
223 | T, channels, height, width = frames.shape
224 | elif isinstance(frames, torch.Tensor):
225 | T, channels, height, width = frames.shape
226 | frames = frames.detach().cpu().numpy()
227 | elif isinstance(frames, list):
228 | channels, height, width = frames[0].shape
229 | assert channels >= 3
230 | frames = np.stack([frame[:3, ...] for frame in frames], axis=0)
231 | else:
232 | raise Exception('Unknown frame specification.')
233 | frames = frames.astype(np.uint8)
234 | frames = frames.transpose((0, 2, 3, 1))
235 | print('Writing video {}'.format(filename))
236 | # Write out as a mp4 video.
237 | writer = imageio.get_writer(filename, fps=20)
238 | for frame in frames:
239 | writer.append_data(frame)
240 | writer.close()
241 |
242 |
243 | def test_write_video_mp4():
244 | filename = 'test.mp4'
245 | frames = np.random.rand(100, 3, 128, 128) * 255
246 | write_video_mp4(filename, frames)
247 |
248 |
249 | if __name__ == '__main__':
250 | test_write_video_mp4()
251 |
--------------------------------------------------------------------------------
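A small sketch (toy shapes, assuming the repository's dependencies are installed) of the two utilities train.py leans on most for batching: torchify flattens nested per-step dicts into (T, ...) tensors (nested keys are joined with '_', which is where the 'obs_image' key comes from), and crop_image_tensor crops the trailing spatial dims:

    import numpy as np
    from utils import torchify, crop_image_tensor

    episode = [dict(obs=dict(image=np.zeros((9, 100, 100), dtype=np.float32)), reward=0.0)
               for _ in range(5)]
    batch = torchify(episode)              # keys: 'obs_image' (5, 9, 100, 100), 'reward' (5,)
    obs = batch['obs_image'].unsqueeze(0)  # add a batch dim: (1, 5, 9, 100, 100)
    cropped = crop_image_tensor(obs, 84, 84, random_crop=True,
                                same_crop_across_time=True, padding=4)
    print(cropped.shape)                   # (1, 5, 9, 84, 84)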
/videos/hard-walker.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-core/22879e81e2b670d3b2e7bced2a2f735ae8f81971/videos/hard-walker.gif
--------------------------------------------------------------------------------
/videos/medium-cartpole.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-core/22879e81e2b670d3b2e7bced2a2f735ae8f81971/videos/medium-cartpole.gif
--------------------------------------------------------------------------------
/videos/medium-cheetah.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-core/22879e81e2b670d3b2e7bced2a2f735ae8f81971/videos/medium-cheetah.gif
--------------------------------------------------------------------------------
/world_model.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2021 Apple Inc. All Rights Reserved.
4 | #
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from utils import stack_tensor_dict_list
9 | from modules import FCNet, CNN, TransposeCNN, GRUCell, weight_init
10 |
11 |
12 | def get_pairwise_smooth_l1(x):
13 | dims = x.shape[0]
14 | x_shape = list(x.shape)
15 | x1 = x.unsqueeze(0).expand(dims, *x_shape)
16 | x2 = x.unsqueeze(1).expand(dims, *x_shape)
17 | return F.smooth_l1_loss(x1, x2, reduction='none')
18 |
19 |
20 | def get_pairwise_l2(x):
21 | dims = x.shape[0]
22 | x_shape = list(x.shape)
23 | x1 = x.unsqueeze(0).expand(dims, *x_shape)
24 | x2 = x.unsqueeze(1).expand(dims, *x_shape)
25 | return F.mse_loss(x1, x2, reduction='none')
26 |
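# Note (added annotation): for an input x of shape (B, ...), both pairwise helpers above
# broadcast x against itself and return a (B, B, ...) tensor of elementwise distances
# between every pair of rows, presumably for bisimulation-style pairwise losses.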
27 |
28 | def sample_softmax(probs):
29 | dist = torch.distributions.one_hot_categorical.OneHotCategorical(probs=probs)
30 | stoch = dist.sample() + probs - probs.detach() # Trick to send straight-through gradients into probs.
31 | return stoch
32 |
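# Hedged illustration (annotation, not part of the original module): sample_softmax
# returns an exact one-hot sample in the forward pass, while the `+ probs - probs.detach()`
# term routes gradients straight through into `probs`.
def _sample_softmax_demo():
    logits = torch.randn(2, 4, requires_grad=True)
    probs = torch.softmax(logits, dim=-1)
    sample = sample_softmax(probs)
    assert ((sample == 0) | (sample == 1)).all()          # forward value is exactly one-hot
    (sample * torch.randn_like(sample)).sum().backward()  # arbitrary scalar loss
    assert logits.grad is not None                        # gradients reached the logits via probs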
33 |
34 | class Dynamics(nn.Module):
35 | def __init__(self, config, obs_embed_dims, action_dims):
36 | super().__init__()
37 | self.config = config
38 | self.action_dims = action_dims
39 | self.obs_embed_dims = obs_embed_dims
40 |
41 | self.forward_dynamics_loss = config['forward_dynamics_loss']
42 | assert self.forward_dynamics_loss in ['neg_log_prob', 'kl']
43 | self.discrete = config.get('discrete', False)
44 | if self.discrete:
45 | self.num_softmaxes = config['num_softmaxes']
46 | self.dims_per_softmax = config['dims_per_softmax']
47 | self.latent_dims = self.dims_per_softmax * self.num_softmaxes
48 | self.suff_stats_dims = self.latent_dims # dims needed to specify a distribution over the latent state.
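# (Annotation, hypothetical sizes:) e.g. 32 softmaxes with 32 classes each would give a
# 1024-dim flattened one-hot latent; the logits themselves parameterize the distribution,
# so no extra sufficient-statistics dimensions are needed in the discrete case.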
49 | else:
50 | self.latent_dims = config['latent_dims']
51 | self.suff_stats_dims = 2 * self.latent_dims # mu, log_std
52 |
53 | self.recurrent = config.get('recurrent', False)
54 | if self.recurrent:
55 | self.rnn_state_dims = config['rnn_state_dims']
56 | else:
57 | self.rnn_state_dims = 0
58 | self.state_dims = self.latent_dims + self.rnn_state_dims
59 |
60 | if self.recurrent:
61 | # Input to recurrent model.
62 | self.rnn_input_net = FCNet(config['rnn_input'], self.latent_dims + action_dims) # Inputs: Prev latent + act.
63 | num_features = self.rnn_input_net.output_dims
64 | print('RNN Input Net input_dims: {} output_dims {}'.format(self.latent_dims + action_dims, num_features))
65 |
66 | # Recurrent model.
67 | print('GRUCell input_dims {} hidden dims {}'.format(num_features, self.rnn_state_dims))
68 | self.cell = GRUCell(num_features, self.rnn_state_dims, norm=True) # h_t = cell([z_{t-1},a_{t-1}], h_{t-1})
69 |
70 | # Prior model (Transition function) P(z_t | h_t)
71 | print('Prior Net input dims {} output dims {}'.format(self.rnn_state_dims, self.suff_stats_dims))
72 | self.prior_net = FCNet(config['prior'], self.rnn_state_dims, out_features=self.suff_stats_dims)
73 | else:
74 | # Prior model (Transition function) P(z_t | h_{t-1}, a_{t-1})
75 | print('Prior Net --------')
76 | print('Prior Net input dims {} output dims {}'.format(self.latent_dims + action_dims, self.suff_stats_dims))
77 | self.prior_net = FCNet(config['prior'], self.latent_dims + action_dims, self.suff_stats_dims)
78 |
79 | # Posterior model (Representation model). P(z_t | x_t, h_t)
80 | print('Posterior Net input dims {} output dims {}'.format(self.obs_embed_dims + self.rnn_state_dims,
81 | self.suff_stats_dims))
82 | self.posterior_net = FCNet(config['posterior'], self.obs_embed_dims + self.rnn_state_dims,
83 | out_features=self.suff_stats_dims)
84 |
85 | # Initial_state.
86 | if self.discrete:
87 | initial_logits = torch.zeros(self.num_softmaxes, self.dims_per_softmax)
88 | self.register_buffer('initial_logits', initial_logits)
89 | else:
90 | initial_mu = torch.zeros(self.latent_dims)
91 | initial_log_std = torch.zeros(self.latent_dims)
92 | self.register_buffer('initial_mu', initial_mu)
93 | self.register_buffer('initial_log_std', initial_log_std)
94 | if self.recurrent:
95 | initial_rnn_state = torch.zeros(self.rnn_state_dims)
96 | self.register_buffer('initial_rnn_state', initial_rnn_state)
97 |
98 | self.elementwise_bisim = False
99 | self.free_kl = config.get('free_kl', 0.0)
100 |
101 | def initial_state(self, batch_size, deterministic=False):
102 | if self.discrete:
103 | logits = self.initial_logits.unsqueeze(0).expand(batch_size, -1, -1)
104 | probs = torch.softmax(logits, dim=-1)
105 | if deterministic:
106 | state = probs
107 | else:
108 | state = sample_softmax(probs)
109 | state = dict(logits=logits, probs=probs.detach(), state=state.detach())
110 | else:
111 | mu = self.initial_mu.unsqueeze(0).expand(batch_size, -1)
112 | log_std = self.initial_log_std.unsqueeze(0).expand(batch_size, -1)
113 | if deterministic:
114 | state = mu
115 | else:
116 | state = mu + torch.exp(log_std) * torch.randn_like(log_std)
117 | state = dict(log_std=log_std, mu=mu, state=state.detach())
118 | if self.recurrent:
119 | state['rnn_state'] = self.initial_rnn_state.unsqueeze(0).expand(batch_size, -1)
120 | return state
121 |
122 | def get_state(self, state_dict, deterministic=False):
123 | """
124 | Get the latent vector by concatenating rnn state and flattened latent state.
125 | Args:
126 | state_dict: Dict containing 'state', 'rnn_state', and ('mu', 'log_std') OR ('logits', 'probs')
127 | Returns:
128 | state: (B, D)
129 | """
130 | if deterministic:
131 | if self.discrete:
132 | state = state_dict['probs'].flatten(-2, -1)
133 | else:
134 | state = state_dict['mu']
135 | else:
136 | if self.discrete:
137 | state = state_dict['state'].flatten(-2, -1)
138 | else:
139 | state = state_dict['state']
140 | if self.recurrent:
141 | state = torch.cat([state, state_dict['rnn_state']], dim=-1)
142 | return state
143 |
144 | def compute_kl(self, state_p, state_q):
145 | """
146 | KL(p || q) = E_p[log p - log q].
147 | """
148 | if self.discrete:
149 | logits_p = state_p['logits']
150 | p = state_p['probs']
151 | logits_q = state_q['logits']
152 | log_p = nn.functional.log_softmax(logits_p, dim=-1)
153 | log_q = nn.functional.log_softmax(logits_q, dim=-1)
154 | kld = (p * (log_p - log_q)).sum(dim=-1)
155 | else:
156 | mu_1 = state_p['mu']
157 | mu_2 = state_q['mu']
158 | log_std_1 = state_p['log_std']
159 | log_std_2 = state_q['log_std']
160 | var_1 = torch.exp(2 * log_std_1)
161 | var_2 = torch.exp(2 * log_std_2)
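# Closed-form KL between diagonal Gaussians, summed over latent dimensions below:
# KL(N(mu_1, var_1) || N(mu_2, var_2)) = log_std_2 - log_std_1 + (var_1 + (mu_1 - mu_2)^2) / (2 * var_2) - 1/2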
162 | kld = log_std_2 - log_std_1 + (var_1 + (mu_1 - mu_2) ** 2) / (2 * var_2) - 0.5
163 | kld = kld.sum(dim=-1)
164 | if self.free_kl > 0.0:
165 | kld = kld.clamp_min(self.free_kl)
166 | return kld
167 |
168 | def compute_neg_log_prob(self, state_p, state_q):
169 | """
170 | Negative Log prob of mean of p under the q distribution.
171 | """
172 | if self.discrete:
173 | raise NotImplementedError
174 | else:
175 | x = state_p['mu']
176 | mu = state_q['mu']
177 | log_std = state_q['log_std'].clamp(-10., 10.)
178 | var = torch.exp(2 * log_std)
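# Gaussian negative log-likelihood of x under N(mu, var), dropping the constant
# 0.5 * log(2 * pi) term and averaging (rather than summing) over dimensions.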
179 | neg_log_prob = log_std + ((x - mu) ** 2) / (2 * var)
180 | neg_log_prob = neg_log_prob.mean(dim=-1)
181 | return neg_log_prob
182 |
183 | def compute_forward_dynamics_loss(self, state_p, state_q):
184 | if self.forward_dynamics_loss == 'kl':
185 | return self.compute_kl(state_p, state_q)
186 | elif self.forward_dynamics_loss == 'neg_log_prob':
187 | return self.compute_neg_log_prob(state_p, state_q)
188 | else:
189 | raise Exception('Unknown forward dynamics loss')
190 |
191 | def process_suff_stats(self, suff_stats):
192 | """
193 | Process the sufficient statistics (obtained from the output of posterior or prior networks).
194 | If discrete, the output is interpreted as logits.
195 | Otherwise, the output is chunked into two parts: mu and log_std.
196 | """
197 | # Get the prior state.
198 | res = {}
199 | if self.discrete:
200 | logits = suff_stats.view(-1, self.num_softmaxes, self.dims_per_softmax)
201 | probs = nn.functional.softmax(logits, dim=-1)
202 | state = sample_softmax(probs)  # module-level helper, not a method of this class
203 | res['logits'] = logits
204 | res['probs'] = probs
205 | else:
206 | mu, log_std = suff_stats.chunk(2, -1)
207 | log_std = log_std.clamp(-10.0, 10.0)
208 | state = mu + torch.exp(log_std) * torch.randn_like(log_std)
209 | res['mu'] = mu
210 | res['log_std'] = log_std
211 | res['state'] = state
212 | return res
213 |
214 | def imagine_step(self, prev_state_dict, prev_action, deterministic=False):
215 | """
216 | Args:
217 | prev_state_dict: Dict containing 'state', 'rnn_state', and distribution parameters ('mu', 'log_std' OR 'logits', 'probs').
218 | prev_action: (B, action_dims) or None
219 | deterministic: If True, use the mode of the prev_state distribution; otherwise sample from it.
220 | Returns:
221 | prior: Dict containing 'state', 'rnn_state', ('logits', OR 'mu', 'log_std')
222 | """
223 |
224 | # prev_action is None at the first time step, where we don't know what action led to the first state.
225 | # In this case, we just pass the previous state dict through as our best guess for the prior.
226 | if prev_action is None:
227 | return prev_state_dict
228 |
229 | if deterministic:
230 | if self.discrete:
231 | prev_state = prev_state_dict['probs']
232 | else:
233 | prev_state = prev_state_dict['mu']
234 | else:
235 | prev_state = prev_state_dict['state']
236 | if self.discrete:
237 | prev_state = prev_state.flatten(-2, -1)
238 | x = torch.cat([prev_state, prev_action], dim=-1)
239 |
240 | res = {}
241 | if self.recurrent:
242 | # Step the RNN to get rnn state.
243 | prev_rnn_state = prev_state_dict['rnn_state']
244 | x = self.rnn_input_net(x)
245 | x = self.cell(x, prev_rnn_state)
246 | res['rnn_state'] = x
247 |
248 | # Get the prior state.
249 | x = self.prior_net(x)
250 | res.update(self.process_suff_stats(x))
251 | return res
252 |
253 | def forward(self, prev_state_dict, prev_action, obs_embed, deterministic=False):
254 | """
255 | Take one step of the dynamics, and return the updated prior and posterior.
256 | Args:
257 | prev_state_dict : Dict containing 'state' (B, num_softmaxes, dims_per_softmax), 'rnn_state' (B, deter_dims)
258 | prev_action: (B, action_dims)
259 | obs_embed: (B, obs_embed_dims)
260 | deterministic: If False, estimate the prior from samples of the prev_state distribution; otherwise use its mode.
261 | Returns:
262 | prior: Dict containing 'state', 'rnn_state', ('logits', OR 'mu', 'log_std')
263 | posterior: Dict containing 'state', 'rnn_state', ('logits', OR 'mu', 'log_std')
264 | """
265 | prior = self.imagine_step(prev_state_dict, prev_action, deterministic=deterministic)
266 | posterior = {}
267 | x = obs_embed
268 | if self.recurrent:
269 | rnn_state = prior['rnn_state']
270 | posterior['rnn_state'] = rnn_state
271 | x = torch.cat([x, rnn_state], dim=-1)
272 | x = self.posterior_net(x)
273 | posterior.update(self.process_suff_stats(x))
274 | return prior, posterior
275 |
276 | def rollout_prior(self, initial_state, actions, deterministic=False):
277 | """
278 | Rollout the dynamics given actions, starting from an initial state.
279 | Args:
280 | initial_state: Dict containing 'state', 'rnn_state', ('logits', OR 'mu', 'log_std') (B, ...)
281 | actions: (T, B, action_dims)
282 | deterministic: If True, use the mode of the state distribution at each step; otherwise propagate samples.
283 | Return:
284 | state: Dict containing 'state', 'rnn_state', ('logits', OR 'mu', 'log_std') (T, B, ...)
285 | """
286 | T = actions.shape[0]
287 | state = initial_state
288 | state_list = []
289 | for t in range(T):
290 | state = self.imagine_step(state, actions[t], deterministic=deterministic)
291 | state_list.append(state)
292 | state = stack_tensor_dict_list(state_list)
293 | return state
294 |
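# (Annotation:) a typical imagination rollout takes a filtered posterior state as
# `initial_state` and calls rollout_prior with actions of shape (T, B, action_dims);
# the per-step prior dicts are stacked along a new leading time dimension by
# stack_tensor_dict_list, giving (T, B, ...) entries.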
295 |
296 | class WorldModel(nn.Module):
297 | def __init__(self, config, obs_dims, action_dims):
298 | super().__init__()
299 |
300 | # Obs encoder.
301 | if 'obs_encoder' in config:
302 | self.encoder = FCNet(config['obs_encoder'], obs_dims)
303 | obs_embed_dims = self.encoder.output_dims
304 | print('Observation feature encoder ---- Input dims {} Output dims {}'.format(obs_dims, obs_embed_dims))
305 | else:
306 | self.encoder = None
307 | obs_embed_dims = obs_dims
308 |
309 | # Dynamics Model.
310 | self.dynamics = Dynamics(config['dynamics'], obs_embed_dims, action_dims)
311 | state_dims = self.dynamics.state_dims
312 | print('Dynamics ---- obs dims {} action dims {} state dims {}'.format(obs_embed_dims, action_dims, state_dims))
313 |
314 | # Image predictor P(x_t | z_t, h_t)
315 | if 'obs_decoder' in config:
316 | self.decoder = FCNet(config['obs_decoder'], state_dims, out_features=obs_dims)
317 | print('Observation feature decoder ---- input dims {} output dims {}'.format(state_dims, obs_dims))
318 | else:
319 | self.decoder = None
320 |
321 | print('Reward predictor ---- input dims {}'.format(state_dims))
322 | # Reward predictor P(r | z_t, h_t)
323 | self.reward_net = FCNet(config['reward_net'], state_dims)
324 | self.reward_prediction_from_prior = config['reward_prediction_from_prior']
325 |
326 | if 'inverse_dynamics' in config:
327 | print('Inverse dynamics model ------ input dims 2 * {} action dims {}'.format(state_dims, action_dims))
328 | self.inv_dynamics = FCNet(config['inverse_dynamics'], 2 * state_dims, action_dims)
329 | else:
330 | self.inv_dynamics = None
331 |
332 | self.propagate_deterministic = config['propagate_deterministic']
333 | self.decode_deterministic = config['decode_deterministic']
334 | self.state_dims = state_dims
335 |
336 | if self.dynamics.forward_dynamics_loss == 'neg_log_prob':
337 | assert self.propagate_deterministic, "Posterior variance is not trained, propagate_deterministic should be True"
338 | self.apply(weight_init)
339 |
340 | def forward_one_step(self, obs_encoding, posterior=None, action=None, deterministic_latent=True):
341 | """
342 | Inputs:
343 | batch: Dict containing
344 | """
345 | if posterior is None:
346 | B = obs_encoding.size(0)
347 | posterior = self.dynamics.initial_state(B)
348 | obs_embed = self.encoder(obs_encoding)
349 | _, posterior = self.dynamics(posterior, action, obs_embed, deterministic=self.propagate_deterministic)
350 | latent = self.dynamics.get_state(posterior, deterministic=deterministic_latent)
351 | return latent, posterior
352 |
353 | def forward(self, batch):
354 | """
355 | Args:
356 | batch: Dict containing
357 | 'obs': (T, B, C, H, W)
358 | 'obs_features': (T, B, feat_dims)
359 | 'action': (T, B, action_dims)
360 | 'reward': (T, B)
361 | Return:
362 | output: Dict containing 'prior', 'posterior', 'reward_prediction', and optionally 'obs_features_recon_post', 'obs_features_recon_prior', 'action_prediction'.
363 | """
364 | obs_features = batch['obs_features']
365 | action = batch['action']
366 | T, B, _ = action.shape
367 |
368 | if self.encoder is not None:
369 | obs_features = self.encoder(obs_features) # (T, B, obs_embed_dims)
370 | prev_state = self.dynamics.initial_state(B)
371 | prior_list = []
372 | posterior_list = []
373 | for t in range(T):
374 | prev_action = action[t-1] if t > 0 else None
375 | current_obs = obs_features[t]
376 | prior, posterior = self.dynamics(prev_state, prev_action, current_obs, deterministic=self.propagate_deterministic)
377 | prior_list.append(prior)
378 | posterior_list.append(posterior)
379 | prev_state = posterior
380 | prior = stack_tensor_dict_list(prior_list)
381 | posterior = stack_tensor_dict_list(posterior_list)
382 | latent_post = self.dynamics.get_state(posterior, deterministic=self.decode_deterministic) # T, B, D
383 | latent_prior = self.dynamics.get_state(prior, deterministic=False)
384 | if self.reward_prediction_from_prior:
385 | reward_prediction = self.reward_net(latent_prior).squeeze(-1) # DBC uses latent prior.
386 | else:
387 | reward_prediction = self.reward_net(latent_post).squeeze(-1)
388 | outputs = dict(prior=prior,
389 | posterior=posterior,
390 | reward_prediction=reward_prediction)
391 | if self.decoder is not None:
392 | outputs['obs_features_recon_post'] = self.decoder(latent_post) # (T, B, dims)
393 | outputs['obs_features_recon_prior'] = self.decoder(latent_prior) # (T, B, dims)
394 | if self.inv_dynamics is not None:
395 | paired_states = torch.cat([latent_post[:-1], latent_post[1:]], dim=-1)
396 | action_prediction = self.inv_dynamics(paired_states)
397 | outputs['action_prediction'] = action_prediction
398 | return outputs
399 |
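# Sketch (illustration only; the repository's actual training objective lives elsewhere,
# e.g. in train.py): one plausible way to turn WorldModel.forward outputs into losses,
# using only functions defined in this file. The action[:-1] inverse-dynamics target is
# an assumption based on this file's indexing convention, not verified against train.py.
def _example_world_model_losses(model, batch):
    outputs = model(batch)
    # KL (or negative log prob) between per-step posteriors and priors.
    dyn_loss = model.dynamics.compute_forward_dynamics_loss(outputs['posterior'], outputs['prior']).mean()
    # Reward prediction against observed rewards of shape (T, B).
    reward_loss = F.mse_loss(outputs['reward_prediction'], batch['reward'])
    losses = {'dynamics': dyn_loss, 'reward': reward_loss}
    if 'obs_features_recon_post' in outputs:   # present only when an obs_decoder is configured
        losses['recon'] = F.mse_loss(outputs['obs_features_recon_post'], batch['obs_features'])
    if 'action_prediction' in outputs:         # present only when inverse dynamics is configured
        losses['inverse_dynamics'] = F.mse_loss(outputs['action_prediction'], batch['action'][:-1])
    return losses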
--------------------------------------------------------------------------------