├── figs ├── DMC.png ├── Atari_Full.png ├── framework.png ├── Atari_IQM_OG.png └── Atari_Distribution.png ├── Atari ├── scripts │ ├── __init__.py │ └── run.py ├── src │ ├── __init__.py │ ├── agent.py │ ├── utils.py │ ├── rlpyt_buffer.py │ ├── vit_modules.py │ ├── rlpyt_atari_env.py │ ├── masking_generator.py │ ├── algos.py │ └── rlpyt_utils.py ├── requirements.txt └── README.md ├── DMControl ├── requirements.txt ├── src │ ├── video.py │ ├── decoder.py │ ├── configs │ │ ├── cheetah_run.yaml │ │ ├── finger_spin.yaml │ │ ├── walker_walk.yaml │ │ ├── reacher_easy.yaml │ │ ├── ball_in_cup_catch.yaml │ │ └── cartpole_swingup.yaml │ ├── config.yaml │ ├── transition_model.py │ ├── logger.py │ ├── vit_modules.py │ ├── masking_generator.py │ ├── encoder.py │ ├── base_sac.py │ └── curl_sac.py └── README.md ├── CODE_OF_CONDUCT.md ├── LICENSE ├── SUPPORT.md ├── .gitignore ├── SECURITY.md ├── .github └── workflows │ └── codeql-analysis.yml ├── NOTICE.md └── README.md /figs/DMC.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/Mask-based-Latent-Reconstruction/HEAD/figs/DMC.png -------------------------------------------------------------------------------- /figs/Atari_Full.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/Mask-based-Latent-Reconstruction/HEAD/figs/Atari_Full.png -------------------------------------------------------------------------------- /figs/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/Mask-based-Latent-Reconstruction/HEAD/figs/framework.png -------------------------------------------------------------------------------- /figs/Atari_IQM_OG.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/Mask-based-Latent-Reconstruction/HEAD/figs/Atari_IQM_OG.png -------------------------------------------------------------------------------- /figs/Atari_Distribution.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/Mask-based-Latent-Reconstruction/HEAD/figs/Atari_Distribution.png -------------------------------------------------------------------------------- /Atari/scripts/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Copyright (c) 2022 Microsoft 3 | # Licensed under The MIT License 4 | # -------------------------------------------------------- -------------------------------------------------------------------------------- /Atari/src/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Copyright (c) 2022 Microsoft 3 | # Licensed under The MIT License 4 | # -------------------------------------------------------- 5 | 6 | from gym.envs.registration import register 7 | 8 | register( 9 | id='atari-v0', 10 | entry_point='src.envs:AtariEnv', 11 | ) -------------------------------------------------------------------------------- /DMControl/requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.7.1 2 | absl-py 3 | pyparsing 4 | torchelastic==0.2.1 5 | torchvision==0.8.2 6 | wandb==0.10.14 7 | hydra-core 8 | termcolor 
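# NOTE (added comment, not part of the original requirements): the two `git+git://` pins below use
# GitHub's unauthenticated git protocol, which GitHub no longer serves. On a current pip/Git setup
# they will likely need to be installed as git+https://github.com/deepmind/dm_control.git and
# git+https://github.com/1nadequacy/dmc2gym.git instead.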
9 | mujoco_py 10 | git+git://github.com/deepmind/dm_control.git 11 | git+git://github.com/1nadequacy/dmc2gym.git 12 | tb-nightly 13 | imageio 14 | imageio-ffmpeg 15 | scikit-image 16 | kornia==0.4.1 17 | wandb 18 | tensorboard 19 | kornia 20 | # numpy==1.19 21 | timm -------------------------------------------------------------------------------- /Atari/requirements.txt: -------------------------------------------------------------------------------- 1 | gym==0.18.0 2 | atari-py==0.2.6 3 | torch==1.7.1 4 | torchelastic 5 | torchvision==0.8.2 6 | tensorboard==1.14.0 7 | wandb==0.10.14 8 | opencv-python==4.5.1.48 9 | scikit-learn==0.24.1 10 | scipy==1.6.0 11 | recordclass==0.14.3 12 | matplotlib==3.3.3 13 | kornia==0.4.1 14 | numpy 15 | pyprind 16 | -e git+https://github.com/astooke/rlpyt.git@b32d589d12d31ba3c8a9cfb7a3c85c6e350b2904#egg=rlpyt 17 | tqdm 18 | future==0.18.2 19 | timm 20 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | # TODO: The maintainer of this repo has not yet edited this file 2 | 3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project? 4 | 5 | - **No CSS support:** Fill out this template with information about how to file issues and get help. 6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps. 
7 | - **Not sure?** Fill out an intake as though the answer were "Yes". CSS will help you decide. 8 | 9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.* 10 | 11 | # Support 12 | 13 | ## How to file issues and get help 14 | 15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing 16 | issues before filing new issues to avoid duplicates. For new issues, file your bug or 17 | feature request as a new Issue. 18 | 19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE 20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER 21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**. 22 | 23 | ## Microsoft Support Policy 24 | 25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above. 26 | -------------------------------------------------------------------------------- /DMControl/src/video.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # This code is borrowed from https://github.com/MishaLaskin/curl 3 | # -------------------------------------------------------- 4 | 5 | import os 6 | import imageio 7 | import numpy as np 8 | 9 | 10 | class VideoRecorder(object): 11 | def __init__(self, dir_name, height=256, width=256, camera_id=0, fps=30): 12 | self.dir_name = dir_name 13 | self.height = height 14 | self.width = width 15 | self.camera_id = camera_id 16 | self.fps = fps 17 | self.frames = [] 18 | 19 | def init(self, enabled=True): 20 | self.frames = [] 21 | self.enabled = self.dir_name is not None and enabled 22 | 23 | def record(self, env): 24 | if self.enabled: 25 | try: 26 | frame = env.render(mode='rgb_array', 27 | height=self.height, 28 | width=self.width, 29 | camera_id=self.camera_id) 30 | except Exception: 31 | frame = env.render(mode='rgb_array') 32 | 33 | self.frames.append(frame) 34 | 35 | def save(self, file_name): 36 | if self.enabled: 37 | path = os.path.join(self.dir_name, file_name) 38 | imageio.mimsave(path, self.frames, fps=self.fps) 39 | -------------------------------------------------------------------------------- /DMControl/README.md: -------------------------------------------------------------------------------- 1 | # MLR for DMControl 2 | ## Installation 3 | Install the requirements: 4 | ~~~ 5 | pip install -r requirements.txt 6 | ~~~ 7 | 8 | Uninstall glfw if you encounter a core dump: 9 | ~~~ 10 | pip uninstall glfw -y 11 | ~~~ 12 | Alternatively, add this line to `train.py` if you run into glfw errors: 13 | ~~~ 14 | os.environ['MUJOCO_GL'] = 'egl' 15 | ~~~ 16 | 17 | ## Usage 18 | We provide configuration files in the `./configs` folder for the six environments used in the paper. 19 | 20 | ``` 21 | cd ./src 22 | python train.py --config-path ./configs --config-name cartpole_swingup jumps=15 \ 23 | seed=1 agent=mtm_sac num_env_steps=100000 work_dir=output wandb=false 24 | ``` 25 | 26 | Some important options in configuration files: 27 | * `agent`: `mtm_sac` is our MLR agent; 28 | * `jumps`: sequence length (For example, jumps=15 means we sample a trajectory from t=0:15, which will include 16 consecutive observations); 29 | * `num_env_steps`: total environment steps.
Note that environment steps = interaction steps * action repeat; 30 | * `mask_ratio`: mask ratio; 31 | * `patch_size`: spatial patch size; 32 | * `block_size`: temporal block size; 33 | * `work_dir`: output directory; 34 | * `wandb`: if True, then wandb visualization is turned on. 35 | 36 | 37 | ## Result 38 | We achieve the best result on both DMControl-100k and DMControl-500k benchmarks. Our result is averaged over 10 random seeds. 39 | 40 | ![image](../figs/DMC.png) 41 | 42 | ## Acknowledgement 43 | Our implementation on DMControl is partially based on [CURL](https://github.com/MishaLaskin/curl) by Michael Laskin & Aravind Srinivas. We sincerely thank the authors. -------------------------------------------------------------------------------- /DMControl/src/decoder.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Copyright (c) 2022 Microsoft 3 | # Licensed under The MIT License 4 | # -------------------------------------------------------- 5 | 6 | import random 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torch.optim as optim 12 | 13 | 14 | class Decoder(nn.Module): 15 | def __init__(self, 16 | # obs_shape, 17 | feature_dim, 18 | # num_layers=2, 19 | num_filters=32, 20 | min_spat_size=11, 21 | ): 22 | super().__init__() 23 | 24 | self.num_filters = num_filters 25 | self.min_spat_size = min_spat_size 26 | self.fc_expand = nn.Linear(feature_dim, num_filters * min_spat_size * min_spat_size) 27 | 28 | 29 | self.deconv = nn.Sequential( 30 | nn.Conv2d(num_filters, num_filters*16, kernel_size=3, stride=1, padding=1), 31 | nn.PixelShuffle(4), 32 | nn.ReLU(True), 33 | 34 | nn.Conv2d(num_filters, num_filters*4, kernel_size=3, stride=1, padding=1), 35 | nn.PixelShuffle(2), 36 | nn.ReLU(True), 37 | 38 | nn.ReflectionPad2d(1), 39 | nn.Conv2d(in_channels=num_filters, out_channels=9, kernel_size=7, padding=0) 40 | ) 41 | 42 | def forward(self, x): 43 | x = self.fc_expand(x) 44 | x = x.view(x.size(0), self.num_filters, self.min_spat_size, self.min_spat_size) 45 | x = self.deconv(x) 46 | x = (torch.tanh(x) + 1) / 2 47 | return x 48 | 49 | 50 | if __name__ == '__main__': 51 | decoder = Decoder(feature_dim=50, min_spat_size=12, num_filters=32) 52 | x = torch.randn(2, 50) 53 | y = decoder(x) 54 | print(y.size()) 55 | -------------------------------------------------------------------------------- /DMControl/src/configs/cheetah_run.yaml: -------------------------------------------------------------------------------- 1 | # environment 2 | env_name: cheetah/run 3 | domain_name: null 4 | task_name: null 5 | pre_transform_image_size: 100 6 | image_size: 84 7 | action_repeat: 4 8 | frame_stack: 3 9 | # replay buffer 10 | replay_buffer_capacity: 100000 11 | # train 12 | agent: ??? 
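# NOTE (added comment, not part of the original config): `???` is Hydra/OmegaConf's marker for a
# mandatory value, so `agent` here (and `work_dir` further down) must be supplied on the command
# line, e.g. `agent=mtm_sac work_dir=output` as in the launch command shown in ../README.md.
# Also, with `action_repeat: 4` above, the `num_env_steps: 105000` setting below corresponds to
# 105000 / 4 = 26250 agent interaction steps (environment steps = interaction steps * action repeat).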
13 | init_steps: 1000 14 | num_env_steps: 105000 # 500100 15 | batch_size: 512 16 | hidden_dim: 1024 17 | # eval 18 | eval_freq: 5000 19 | num_eval_episodes: 10 20 | # critic 21 | critic_lr: 2e-4 22 | critic_beta: 0.9 23 | critic_tau: 0.01 24 | critic_target_update_freq: 2 25 | # actor 26 | actor_lr: 2e-4 27 | actor_beta: 0.9 28 | actor_log_std_min: -10 29 | actor_log_std_max: 2 30 | actor_update_freq: 2 31 | # encoder 32 | encoder_type: pixel 33 | encoder_feature_dim: 50 34 | encoder_lr: 2e-4 35 | encoder_tau: 0.05 36 | num_layers: 4 37 | num_filters: 64 38 | curl_latent_dim: 128 39 | # sac 40 | discount: 0.99 41 | init_temperature: 0.1 42 | alpha_lr: 1e-4 43 | alpha_beta: 0.5 44 | # cycdm 45 | augmentation: ["crop", "intensity"] 46 | jumps: 6 47 | transition_model_type: 'deterministic' 48 | transition_model_layer_width: 512 49 | latent_dim: 100 50 | auxiliary_task_batch_size: 128 51 | time_offset: 0 52 | momentum_tau: 0.05 53 | aug_prob: 1.0 54 | num_aug_actions: 10 55 | loss_space: y 56 | bp_mode: gt 57 | cycle_steps: 6 58 | cycle_mode: fp+cycle # fp # fp+cycle # fp+bp+cycle 59 | fp_loss_weight: 1.0 60 | bp_loss_weight: 1.0 61 | rc_loss_weight: 0.0 62 | vc_loss_weight: 1.0 63 | reward_loss_weight: 0.0 64 | auxiliary_task_lr: 2e-4 65 | sigma: 0.1 66 | warmup: false 67 | # misc 68 | seed: null 69 | gpuid: null 70 | seed_and_gpuid: [-1, 0] 71 | work_dir: ??? 72 | save_tb: true 73 | save_buffer: false 74 | save_video: false 75 | save_model: true 76 | detach_encoder: false 77 | # log 78 | log_interval: 100 79 | # wandb 80 | wandb: false 81 | # MTP 82 | mask_ratio: 0.5 83 | patch_size: 10 84 | block_size: 8 85 | num_attn_layers: 2 -------------------------------------------------------------------------------- /DMControl/src/configs/finger_spin.yaml: -------------------------------------------------------------------------------- 1 | # environment 2 | env_name: finger/spin 3 | domain_name: null 4 | task_name: null 5 | pre_transform_image_size: 100 6 | image_size: 84 7 | action_repeat: 2 8 | frame_stack: 3 9 | # replay buffer 10 | replay_buffer_capacity: 100000 11 | # train 12 | agent: ??? 13 | init_steps: 1000 14 | num_env_steps: 105000 # 500100 15 | batch_size: 512 16 | hidden_dim: 1024 17 | # eval 18 | eval_freq: 5000 19 | num_eval_episodes: 10 20 | # critic 21 | critic_lr: 1e-3 22 | critic_beta: 0.9 23 | critic_tau: 0.01 24 | critic_target_update_freq: 2 25 | # actor 26 | actor_lr: 1e-3 27 | actor_beta: 0.9 28 | actor_log_std_min: -10 29 | actor_log_std_max: 2 30 | actor_update_freq: 2 31 | # encoder 32 | encoder_type: pixel 33 | encoder_feature_dim: 50 34 | encoder_lr: 1e-3 35 | encoder_tau: 0.05 36 | num_layers: 4 37 | num_filters: 64 38 | curl_latent_dim: 128 39 | # sac 40 | discount: 0.99 41 | init_temperature: 0.1 42 | alpha_lr: 1e-4 43 | alpha_beta: 0.5 44 | # cycdm 45 | augmentation: ["crop", "intensity"] 46 | jumps: 6 47 | transition_model_type: 'deterministic' 48 | transition_model_layer_width: 512 49 | latent_dim: 100 50 | auxiliary_task_batch_size: 128 51 | time_offset: 0 52 | momentum_tau: 0.05 53 | aug_prob: 1.0 54 | num_aug_actions: 10 55 | loss_space: y 56 | bp_mode: gt 57 | cycle_steps: 6 58 | cycle_mode: fp+cycle # fp # fp+cycle # fp+bp+cycle 59 | fp_loss_weight: 1.0 60 | bp_loss_weight: 1.0 61 | rc_loss_weight: 0.0 62 | vc_loss_weight: 1.0 63 | reward_loss_weight: 0.0 64 | auxiliary_task_lr: 1e-3 65 | sigma: 0.1 66 | warmup: false 67 | # misc 68 | seed: null 69 | gpuid: null 70 | seed_and_gpuid: [-1, 0] 71 | work_dir: ??? 
72 | save_tb: true 73 | save_buffer: false 74 | save_video: false 75 | save_model: true 76 | detach_encoder: false 77 | # log 78 | log_interval: 100 79 | # wandb 80 | wandb: false 81 | # MTP 82 | mask_ratio: 0.5 83 | patch_size: 10 84 | block_size: 8 85 | num_attn_layers: 2 -------------------------------------------------------------------------------- /DMControl/src/configs/walker_walk.yaml: -------------------------------------------------------------------------------- 1 | # environment 2 | env_name: walker/walk 3 | domain_name: null 4 | task_name: null 5 | pre_transform_image_size: 100 6 | image_size: 84 7 | action_repeat: 2 8 | frame_stack: 3 9 | # replay buffer 10 | replay_buffer_capacity: 100000 11 | # train 12 | agent: ??? 13 | init_steps: 1000 14 | num_env_steps: 105000 # 500100 15 | batch_size: 512 16 | hidden_dim: 1024 17 | # eval 18 | eval_freq: 5000 19 | num_eval_episodes: 10 20 | # critic 21 | critic_lr: 1e-3 22 | critic_beta: 0.9 23 | critic_tau: 0.01 24 | critic_target_update_freq: 2 25 | # actor 26 | actor_lr: 1e-3 27 | actor_beta: 0.9 28 | actor_log_std_min: -10 29 | actor_log_std_max: 2 30 | actor_update_freq: 2 31 | # encoder 32 | encoder_type: pixel 33 | encoder_feature_dim: 50 34 | encoder_lr: 1e-3 35 | encoder_tau: 0.05 36 | num_layers: 4 37 | num_filters: 64 38 | curl_latent_dim: 128 39 | # sac 40 | discount: 0.99 41 | init_temperature: 0.1 42 | alpha_lr: 1e-4 43 | alpha_beta: 0.5 44 | # cycdm 45 | augmentation: ["crop", "intensity"] 46 | jumps: 6 47 | transition_model_type: 'deterministic' 48 | transition_model_layer_width: 512 49 | latent_dim: 100 50 | auxiliary_task_batch_size: 128 51 | time_offset: 0 52 | momentum_tau: 0.1 53 | aug_prob: 1.0 54 | num_aug_actions: 10 55 | loss_space: y 56 | bp_mode: gt 57 | cycle_steps: 6 58 | cycle_mode: fp+cycle # fp # fp+cycle # fp+bp+cycle 59 | fp_loss_weight: 1.0 60 | bp_loss_weight: 1.0 61 | rc_loss_weight: 0.0 62 | vc_loss_weight: 1.0 63 | reward_loss_weight: 0.0 64 | auxiliary_task_lr: 1e-3 65 | sigma: 0.1 66 | warmup: false 67 | # misc 68 | seed: null 69 | gpuid: null 70 | seed_and_gpuid: [-1, 0] 71 | work_dir: ??? 72 | save_tb: true 73 | save_buffer: false 74 | save_video: false 75 | save_model: true 76 | detach_encoder: false 77 | # log 78 | log_interval: 100 79 | # wandb 80 | wandb: false 81 | # MTP 82 | mask_ratio: 0.5 83 | patch_size: 10 84 | block_size: 8 85 | num_attn_layers: 2 -------------------------------------------------------------------------------- /DMControl/src/configs/reacher_easy.yaml: -------------------------------------------------------------------------------- 1 | # environment 2 | env_name: reacher/easy 3 | domain_name: null 4 | task_name: null 5 | pre_transform_image_size: 100 6 | image_size: 84 7 | action_repeat: 4 8 | frame_stack: 3 9 | # replay buffer 10 | replay_buffer_capacity: 100000 11 | # train 12 | agent: ??? 
13 | init_steps: 1000 14 | num_env_steps: 105000 # 500100 15 | batch_size: 512 16 | hidden_dim: 1024 17 | # eval 18 | eval_freq: 5000 19 | num_eval_episodes: 10 20 | # critic 21 | critic_lr: 1e-3 22 | critic_beta: 0.9 23 | critic_tau: 0.01 24 | critic_target_update_freq: 2 25 | # actor 26 | actor_lr: 1e-3 27 | actor_beta: 0.9 28 | actor_log_std_min: -10 29 | actor_log_std_max: 2 30 | actor_update_freq: 2 31 | # encoder 32 | encoder_type: pixel 33 | encoder_feature_dim: 50 34 | encoder_lr: 1e-3 35 | encoder_tau: 0.05 36 | num_layers: 4 37 | num_filters: 64 38 | curl_latent_dim: 128 39 | # sac 40 | discount: 0.99 41 | init_temperature: 0.1 42 | alpha_lr: 1e-4 43 | alpha_beta: 0.5 44 | # cycdm 45 | augmentation: ["crop", "intensity"] 46 | jumps: 6 47 | transition_model_type: 'deterministic' 48 | transition_model_layer_width: 512 49 | latent_dim: 100 50 | auxiliary_task_batch_size: 128 51 | time_offset: 0 52 | momentum_tau: 0.05 53 | aug_prob: 1.0 54 | num_aug_actions: 10 55 | loss_space: y 56 | bp_mode: gt 57 | cycle_steps: 6 58 | cycle_mode: fp+cycle # fp # fp+cycle # fp+bp+cycle 59 | fp_loss_weight: 1.0 60 | bp_loss_weight: 1.0 61 | rc_loss_weight: 0.0 62 | vc_loss_weight: 1.0 63 | reward_loss_weight: 0.0 64 | auxiliary_task_lr: 1e-3 65 | sigma: 0.1 66 | warmup: false 67 | # misc 68 | seed: null 69 | gpuid: null 70 | seed_and_gpuid: [-1, 0] 71 | work_dir: ??? 72 | save_tb: true 73 | save_buffer: false 74 | save_video: false 75 | save_model: true 76 | detach_encoder: false 77 | # log 78 | log_interval: 100 79 | # wandb 80 | wandb: false 81 | # MTP 82 | mask_ratio: 0.5 83 | patch_size: 10 84 | block_size: 4 85 | num_attn_layers: 2 -------------------------------------------------------------------------------- /DMControl/src/configs/ball_in_cup_catch.yaml: -------------------------------------------------------------------------------- 1 | # environment 2 | env_name: ball_in_cup/catch 3 | domain_name: null 4 | task_name: null 5 | pre_transform_image_size: 100 6 | image_size: 84 7 | action_repeat: 4 8 | frame_stack: 3 9 | # replay buffer 10 | replay_buffer_capacity: 100000 11 | # train 12 | agent: ??? 13 | init_steps: 1000 14 | num_env_steps: 105000 # 500100 15 | batch_size: 512 16 | hidden_dim: 1024 17 | # eval 18 | eval_freq: 5000 19 | num_eval_episodes: 10 20 | # critic 21 | critic_lr: 1e-3 22 | critic_beta: 0.9 23 | critic_tau: 0.01 24 | critic_target_update_freq: 2 25 | # actor 26 | actor_lr: 1e-3 27 | actor_beta: 0.9 28 | actor_log_std_min: -10 29 | actor_log_std_max: 2 30 | actor_update_freq: 2 31 | # encoder 32 | encoder_type: pixel 33 | encoder_feature_dim: 50 34 | encoder_lr: 1e-3 35 | encoder_tau: 0.05 36 | num_layers: 4 37 | num_filters: 64 38 | curl_latent_dim: 128 39 | # sac 40 | discount: 0.99 41 | init_temperature: 0.1 42 | alpha_lr: 1e-4 43 | alpha_beta: 0.5 44 | # cycdm 45 | augmentation: ["crop", "intensity"] 46 | jumps: 6 47 | transition_model_type: 'deterministic' 48 | transition_model_layer_width: 512 49 | latent_dim: 100 50 | auxiliary_task_batch_size: 128 51 | time_offset: 0 52 | momentum_tau: 0.05 53 | aug_prob: 1.0 54 | num_aug_actions: 10 55 | loss_space: y 56 | bp_mode: gt 57 | cycle_steps: 6 58 | cycle_mode: fp+cycle # fp # fp+cycle # fp+bp+cycle 59 | fp_loss_weight: 1.0 60 | bp_loss_weight: 1.0 61 | rc_loss_weight: 0.0 62 | vc_loss_weight: 1.0 63 | reward_loss_weight: 0.0 64 | auxiliary_task_lr: 1e-3 65 | sigma: 0.1 66 | warmup: false 67 | # misc 68 | seed: null 69 | gpuid: null 70 | seed_and_gpuid: [-1, 0] 71 | work_dir: ??? 
72 | save_tb: true 73 | save_buffer: false 74 | save_video: false 75 | save_model: true 76 | detach_encoder: false 77 | # log 78 | log_interval: 100 79 | # wandb 80 | wandb: false 81 | # MTP 82 | mask_ratio: 0.5 83 | patch_size: 10 84 | block_size: 8 85 | num_attn_layers: 2 -------------------------------------------------------------------------------- /DMControl/src/configs/cartpole_swingup.yaml: -------------------------------------------------------------------------------- 1 | # environment 2 | env_name: cartpole/swingup 3 | domain_name: null 4 | task_name: null 5 | pre_transform_image_size: 100 6 | image_size: 84 7 | action_repeat: 8 8 | frame_stack: 3 9 | # replay buffer 10 | replay_buffer_capacity: 100000 11 | # train 12 | agent: ??? 13 | init_steps: 1000 14 | num_env_steps: 105000 # 500100 15 | batch_size: 512 16 | hidden_dim: 1024 17 | # eval 18 | eval_freq: 5000 19 | num_eval_episodes: 10 20 | # critic 21 | critic_lr: 1e-3 22 | critic_beta: 0.9 23 | critic_tau: 0.01 24 | critic_target_update_freq: 2 25 | # actor 26 | actor_lr: 1e-3 27 | actor_beta: 0.9 28 | actor_log_std_min: -10 29 | actor_log_std_max: 2 30 | actor_update_freq: 2 31 | # encoder 32 | encoder_type: pixel 33 | encoder_feature_dim: 50 34 | encoder_lr: 1e-3 35 | encoder_tau: 0.05 36 | num_layers: 4 37 | num_filters: 64 38 | curl_latent_dim: 128 39 | # sac 40 | discount: 0.99 41 | init_temperature: 0.1 42 | alpha_lr: 1e-4 43 | alpha_beta: 0.5 44 | # cycdm 45 | augmentation: ["crop", "intensity"] 46 | jumps: 6 47 | transition_model_type: 'deterministic' 48 | transition_model_layer_width: 512 49 | latent_dim: 100 50 | auxiliary_task_batch_size: 128 51 | time_offset: 0 52 | momentum_tau: 0.05 53 | aug_prob: 1.0 54 | num_aug_actions: 10 55 | loss_space: y 56 | bp_mode: gt 57 | cycle_steps: 6 58 | cycle_mode: fp+cycle # fp # fp+cycle # fp+bp+cycle 59 | fp_loss_weight: 1.0 60 | bp_loss_weight: 1.0 61 | rc_loss_weight: 0.0 62 | vc_loss_weight: 1.0 63 | reward_loss_weight: 0.0 64 | auxiliary_task_lr: 1e-3 65 | sigma: 0.1 66 | warmup: false 67 | # misc 68 | seed: null 69 | gpuid: null 70 | seed_and_gpuid: [-1, 0] 71 | work_dir: ??? 72 | save_tb: true 73 | save_buffer: false 74 | save_video: false 75 | save_model: true 76 | detach_encoder: false 77 | # log 78 | log_interval: 100 79 | # wandb 80 | wandb: false 81 | # MTP 82 | mask_ratio: 0.5 83 | patch_size: 10 84 | block_size: 4 85 | num_attn_layers: 2 86 | -------------------------------------------------------------------------------- /Atari/README.md: -------------------------------------------------------------------------------- 1 | # MLR for Atari 2 | ## Installation 3 | Install the requirements: 4 | ~~~ 5 | pip install -r requirements.txt 6 | ~~~ 7 | 8 | ## Usage 9 | Train MLR: 10 | ~~~ 11 | python -m scripts.run --public --mlr-weight 1 --jumps 15 --game alien --seed 1 --final-eval-only 1 12 | ~~~ 13 | 14 | Some important options: 15 | * `jumps`: sequence length (For example, jumps=15 means we sample a trajectory from t=0:15, which will include 16 consecutive observations); 16 | * `mlr-weight`: MLR loss weight; 17 | * `game`: to specify the game; 18 | * `seed`: to specify the seed; 19 | * `final-eval-only`: evaluating agent at the end step ("1") or every 10k steps ("0"); 20 | * Other hyperparameters (input image size, mask ratio, block size & patch size): simply modify Line 368-371 in ./src/models.py. 21 | 22 | ## Result 23 | ### IQM and OG 24 | Comparison results on Atari-100k. 
Aggregate metrics (IQM and optimality gap (OG)) with 95\% confidence intervals (CIs) are used for the evaluation. Higher IQM and lower OG are better. 25 | 26 | ![image](../figs/Atari_IQM_OG.png) 27 | 28 | ### Full Scores 29 | Comparison on the Atari-100k benchmark. Our method reaches the highest scores on 11 out of 26 games and the best performance concerning the aggregate metrics, i.e., IQM and OG with 95\% confidence intervals. Our method augments Baseline with the MLR objective and achieves a 47.9\% relative improvement on IQM. 30 | 31 | ![image](../figs/Atari_Full.png) 32 | 33 | ### Performance Profiles 34 | 35 | Performance profiles on the Atari-100k benchmark based on human-normalized score distributions. Shaded regions indicates 95\% confidence bands. The score distribution of MLR is clearly superior to previous methods and Baseline. 36 | 37 | ![image](../figs/Atari_Distribution.png) 38 | 39 | ## Acknowledgement 40 | The implementation on Atari is partially based on [SPR](https://github.com/mila-iqia/spr) by Max Schwarzer & Ankesh Anand. We sincerely thank the authors. -------------------------------------------------------------------------------- /DMControl/src/config.yaml: -------------------------------------------------------------------------------- 1 | # environment 2 | env_name: cartpole/swingup 3 | domain_name: null 4 | task_name: null 5 | pre_transform_image_size: 100 6 | image_size: 84 7 | action_repeat: 8 8 | frame_stack: 3 9 | # replay buffer 10 | replay_buffer_capacity: 100000 11 | # train 12 | agent: cycdm_sac 13 | init_steps: 1000 14 | num_env_steps: 110000 # 500100 15 | batch_size: 512 16 | hidden_dim: 1024 17 | # eval 18 | eval_freq: 5000 19 | num_eval_episodes: 10 20 | # critic 21 | critic_lr: 1e-3 22 | critic_beta: 0.9 23 | critic_tau: 0.01 24 | critic_target_update_freq: 2 25 | # actor 26 | actor_lr: 1e-3 27 | actor_beta: 0.9 28 | actor_log_std_min: -10 29 | actor_log_std_max: 2 30 | actor_update_freq: 2 31 | # encoder 32 | encoder_type: pixel 33 | encoder_feature_dim: 50 34 | encoder_lr: 1e-3 35 | encoder_tau: 0.05 36 | num_layers: 4 37 | num_filters: 64 38 | curl_latent_dim: 128 39 | # sac 40 | discount: 0.99 41 | init_temperature: 0.1 42 | alpha_lr: 1e-4 43 | alpha_beta: 0.5 44 | # cycdm 45 | augmentation: ["crop", "intensity"] 46 | jumps: 6 47 | transition_model_type: 'deterministic' 48 | transition_model_layer_width: 512 49 | latent_dim: 100 50 | auxiliary_task_batch_size: 64 51 | time_offset: 0 52 | momentum_tau: 0.05 53 | aug_prob: 1.0 54 | num_aug_actions: 10 55 | loss_space: y 56 | bp_mode: gt 57 | cycle_steps: 6 58 | cycle_mode: fp+cycle # fp+cycle # fp+bp+cycle 59 | fp_loss_weight: 6.0 60 | bp_loss_weight: 1.0 61 | rc_loss_weight: 0.0 62 | vc_loss_weight: 1.0 63 | reward_loss_weight: 0.0 64 | auxiliary_task_lr: 1e-3 65 | warmup: false 66 | # misc 67 | seed: null 68 | gpuid: null 69 | seed_and_gpuid: [-1, 0] 70 | work_dir: /mnt/output/ 71 | save_tb: true 72 | save_buffer: false 73 | save_video: false 74 | save_model: true 75 | detach_encoder: false 76 | # log 77 | log_interval: 100 78 | # wandb 79 | wandb: false 80 | 81 | # SAR 82 | num_attn_layer: 2 83 | emb_action: true 84 | n_embd: 50 85 | adam_warmup_step: 1e3 # 0 for disabling warmup 86 | sar_loss_weight: 1 87 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 
| *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report). 
14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd). 40 | 41 | 42 | -------------------------------------------------------------------------------- /.github/workflows/codeql-analysis.yml: -------------------------------------------------------------------------------- 1 | # For most projects, this workflow file will not need changing; you simply need 2 | # to commit it to your repository. 3 | # 4 | # You may wish to alter this file to override the set of languages analyzed, 5 | # or to provide custom queries or build logic. 6 | # 7 | # ******** NOTE ******** 8 | # We have attempted to detect the languages in your repository. Please check 9 | # the `language` matrix defined below to confirm you have the correct set of 10 | # supported CodeQL languages. 11 | # 12 | name: "CodeQL" 13 | 14 | on: 15 | push: 16 | branches: [ "main" ] 17 | pull_request: 18 | # The branches below must be a subset of the branches above 19 | branches: [ "main" ] 20 | schedule: 21 | - cron: '45 12 * * 0' 22 | 23 | jobs: 24 | analyze: 25 | name: Analyze 26 | runs-on: ubuntu-latest 27 | permissions: 28 | actions: read 29 | contents: read 30 | security-events: write 31 | 32 | strategy: 33 | fail-fast: false 34 | matrix: 35 | language: [ 'python' ] 36 | # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ] 37 | # Learn more about CodeQL language support at https://aka.ms/codeql-docs/language-support 38 | 39 | steps: 40 | - name: Checkout repository 41 | uses: actions/checkout@v3 42 | 43 | # Initializes the CodeQL tools for scanning. 44 | - name: Initialize CodeQL 45 | uses: github/codeql-action/init@v2 46 | with: 47 | languages: ${{ matrix.language }} 48 | # If you wish to specify custom queries, you can do so here or in a config file. 
49 | # By default, queries listed here will override any specified in a config file. 50 | # Prefix the list here with "+" to use these queries and those in the config file. 51 | 52 | # Details on CodeQL's query packs refer to : https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs 53 | # queries: security-extended,security-and-quality 54 | 55 | 56 | # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). 57 | # If this step fails, then you should remove it and run the build manually (see below) 58 | - name: Autobuild 59 | uses: github/codeql-action/autobuild@v2 60 | 61 | # ℹ️ Command-line programs to run using the OS shell. 62 | # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun 63 | 64 | # If the Autobuild fails above, remove it and uncomment the following three lines. 65 | # modify them (or add more) to build your code if your project, please refer to the EXAMPLE below for guidance. 66 | 67 | # - run: | 68 | # echo "Run, Build Application using script" 69 | # ./location_of_script_within_repo/buildscript.sh 70 | 71 | - name: Perform CodeQL Analysis 72 | uses: github/codeql-action/analyze@v2 73 | with: 74 | category: "/language:${{matrix.language}}" 75 | -------------------------------------------------------------------------------- /NOTICE.md: -------------------------------------------------------------------------------- 1 | # NOTICES AND INFORMATION 2 | 3 | Do Not Translate or Localize 4 | 5 | This software incorporates material from third parties. Microsoft makes certain 6 | open source code available at http://3rdpartysource.microsoft.com, or you may 7 | send a check or money order for US $5.00, including the product name, the open 8 | source component name, and version number, to: 9 | 10 | Source Code Compliance Team 11 | Microsoft Corporation 12 | One Microsoft Way 13 | Redmond, WA 98052 14 | USA 15 | 16 | Notwithstanding any other terms, you may reverse engineer this software to the 17 | extent required to debug changes to any libraries licensed under the GNU Lesser 18 | General Public License. 19 | 20 | =============================================================================== 21 | 22 | ## SPR 23 | 24 | **Source:** https://github.com/mila-iqia/spr 25 | 26 | **License:** 27 | ``` 28 | MIT License 29 | 30 | Copyright (c) 2019 Ankesh Anand 31 | 32 | Permission is hereby granted, free of charge, to any person obtaining a copy 33 | of this software and associated documentation files (the "Software"), to deal 34 | in the Software without restriction, including without limitation the rights 35 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 36 | copies of the Software, and to permit persons to whom the Software is 37 | furnished to do so, subject to the following conditions: 38 | 39 | The above copyright notice and this permission notice shall be included in all 40 | copies or substantial portions of the Software. 41 | 42 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 43 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 44 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 45 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 46 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 47 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 48 | SOFTWARE. 49 | ``` 50 | 51 | ## CURL 52 | 53 | **Source:** https://github.com/MishaLaskin/curl 54 | 55 | **License:** 56 | ``` 57 | MIT License 58 | 59 | Copyright (c) 2020 CURL (Contrastive Unsupervised Representations for Reinforcement Learning) Authors (https://arxiv.org/abs/2004.04136) 60 | 61 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 62 | 63 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 64 | 65 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 66 | ``` 67 | 68 | ## rlpyt 69 | 70 | **Source:** https://github.com/astooke/rlpyt 71 | 72 | **License:** 73 | ``` 74 | MIT License 75 | 76 | Copyright (c) 2019 astooke 77 | 78 | Permission is hereby granted, free of charge, to any person obtaining a copy 79 | of this software and associated documentation files (the "Software"), to deal 80 | in the Software without restriction, including without limitation the rights 81 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 82 | copies of the Software, and to permit persons to whom the Software is 83 | furnished to do so, subject to the following conditions: 84 | 85 | The above copyright notice and this permission notice shall be included in all 86 | copies or substantial portions of the Software. 87 | 88 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 89 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 90 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 91 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 92 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 93 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 94 | SOFTWARE. 
95 | ``` -------------------------------------------------------------------------------- /DMControl/src/transition_model.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # This code is borrowed from https://github.com/facebookresearch/deep_bisim4control/blob/main/transition_model.py 3 | # -------------------------------------------------------- 4 | 5 | import random 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.optim as optim 11 | 12 | 13 | class DeterministicTransitionModel(nn.Module): 14 | def __init__(self, encoder_feature_dim, action_shape, layer_width): 15 | super().__init__() 16 | self.fc = nn.Linear(encoder_feature_dim + action_shape[0], layer_width) 17 | self.ln = nn.LayerNorm(layer_width) 18 | self.fc_mu = nn.Linear(layer_width, encoder_feature_dim) 19 | print("Deterministic transition model chosen.") 20 | 21 | def forward(self, x): 22 | x = self.fc(x) 23 | x = self.ln(x) 24 | x = torch.relu(x) 25 | 26 | mu = self.fc_mu(x) 27 | sigma = None 28 | return mu, sigma 29 | 30 | def sample_prediction(self, x): 31 | mu, sigma = self(x) 32 | return mu 33 | 34 | 35 | class ProbabilisticTransitionModel(nn.Module): 36 | def __init__(self, 37 | encoder_feature_dim, 38 | action_shape, 39 | layer_width, 40 | announce=True, 41 | max_sigma=1e1, 42 | min_sigma=1e-4): 43 | super().__init__() 44 | self.fc = nn.Linear(encoder_feature_dim + action_shape[0], layer_width) 45 | self.ln = nn.LayerNorm(layer_width) 46 | self.fc_mu = nn.Linear(layer_width, encoder_feature_dim) 47 | self.fc_sigma = nn.Linear(layer_width, encoder_feature_dim) 48 | 49 | self.max_sigma = max_sigma 50 | self.min_sigma = min_sigma 51 | assert (self.max_sigma >= self.min_sigma) 52 | if announce: 53 | print("Probabilistic transition model chosen.") 54 | 55 | def forward(self, x): 56 | x = self.fc(x) 57 | x = self.ln(x) 58 | x = torch.relu(x) 59 | 60 | mu = self.fc_mu(x) 61 | sigma = torch.sigmoid(self.fc_sigma(x)) # range (0, 1.) 
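        # The rescaled sigma below parameterizes a diagonal Gaussian over the next latent state;
        # sample_prediction() draws from it via the reparameterization mu + sigma * eps.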
62 | sigma = self.min_sigma + ( 63 | self.max_sigma - 64 | self.min_sigma) * sigma # scaled range (min_sigma, max_sigma) 65 | return mu, sigma 66 | 67 | def sample_prediction(self, x): 68 | mu, sigma = self(x) 69 | eps = torch.randn_like(sigma) 70 | return mu + sigma * eps 71 | 72 | 73 | class EnsembleOfProbabilisticTransitionModels(object): 74 | def __init__(self, 75 | encoder_feature_dim, 76 | action_shape, 77 | layer_width, 78 | ensemble_size=5): 79 | self.models = [ 80 | ProbabilisticTransitionModel(encoder_feature_dim, 81 | action_shape, 82 | layer_width, 83 | announce=False) 84 | for _ in range(ensemble_size) 85 | ] 86 | print("Ensemble of probabilistic transition models chosen.") 87 | 88 | def __call__(self, x): 89 | mu_sigma_list = [model.forward(x) for model in self.models] 90 | mus, sigmas = zip(*mu_sigma_list) 91 | mus, sigmas = torch.stack(mus), torch.stack(sigmas) 92 | return mus, sigmas 93 | 94 | def sample_prediction(self, x): 95 | model = random.choice(self.models) 96 | return model.sample_prediction(x) 97 | 98 | def to(self, device): 99 | for model in self.models: 100 | model.to(device) 101 | return self 102 | 103 | def parameters(self): 104 | list_of_parameters = [ 105 | list(model.parameters()) for model in self.models 106 | ] 107 | parameters = [p for ps in list_of_parameters for p in ps] 108 | return parameters 109 | 110 | 111 | _AVAILABLE_TRANSITION_MODELS = { 112 | '': DeterministicTransitionModel, 113 | 'deterministic': DeterministicTransitionModel, 114 | 'probabilistic': ProbabilisticTransitionModel, 115 | 'ensemble': EnsembleOfProbabilisticTransitionModels 116 | } 117 | 118 | 119 | def make_transition_model(transition_model_type, 120 | encoder_feature_dim, 121 | action_shape, 122 | layer_width=512): 123 | assert transition_model_type in _AVAILABLE_TRANSITION_MODELS 124 | return _AVAILABLE_TRANSITION_MODELS[transition_model_type]( 125 | encoder_feature_dim, action_shape, layer_width) 126 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Mask-based Latent Reconstruction for Reinforcement Learning 2 | 3 | This is the official implementation of *[Mask-based Latent Reconstruction for Reinforcement Learning](https://arxiv.org/abs/2201.12096)* (accepted by NeurIPS 2022), which outperforms the state-of-the-art sample-efficient reinforcement learning methods such as [CURL](https://arxiv.org/abs/2004.04136), [DrQ](https://arxiv.org/abs/2004.13649), [SPR](https://openreview.net/forum?id=uCQfPZwRaUu), [PlayVirtual](https://arxiv.org/abs/2106.04152), etc. 4 | 5 | - [arXiv](https://arxiv.org/abs/2201.12096) 6 | - [OpenReview](https://openreview.net/forum?id=-zlJOVc580) 7 | - [SlidesLive](https://recorder-v3.slideslive.com/#/share?share=74702&s=2295c61d-8048-439f-a718-54adb5b8b629) 8 | 9 | ## Abstract 10 | For deep reinforcement learning (RL) from pixels, learning effective state representations is crucial for achieving high performance. However, in practice, limited experience and high-dimensional inputs prevent effective representation learning. To address this, motivated by the success of mask-based modeling in other research fields, we introduce mask-based reconstruction to promote state representation learning in RL.
Specifically, we propose a simple yet effective self-supervised method, Mask-based Latent Reconstruction (MLR), to predict complete state representations in the latent space from the observations with spatially and temporally masked pixels. MLR enables better use of context information when learning state representations to make them more informative, which facilitates the training of RL agents. Extensive experiments show that our MLR significantly improves the sample efficiency in RL and outperforms the state-of-the-art sample-efficient RL methods on multiple continuous and discrete control benchmarks. 11 | 12 | ## Framework 13 | 14 | ![image](./figs/framework.png) 15 | 16 | Figure 1. The framework of the proposed MLR. We perform a random spatial-temporal masking (i.e., *cube* masking) on the sequence of consecutive observations in the pixel space. The masked observations are encoded to be the latent states through an online encoder. We further introduce a predictive latent decoder to decode/predict the latent states conditioned on the corresponding action sequence and temporal positional embeddings. Our method trains the networks to reconstruct the information available in the missing contents in an appropriate *latent* space using a cosine similarity based distance metric applied between the predicted features of the reconstructed states and the target features inferred from original observations by momentum networks. 17 | 18 | 19 | ## Run MLR 20 | We provide codes for two benchmarks: Atari and DMControl. 21 | ~~~ 22 | . 23 | ├── Atari 24 | | ├── README.md 25 | | └── ... 26 | |── DMControl 27 | | ├── README.md 28 | | └── ... 29 | ├── CODE_OF_CONDUCT.md 30 | ├── LICENSE 31 | ├── README.md 32 | ├── SUPPORT.md 33 | └── SECURITY.md 34 | ~~~ 35 | 36 | Run Atari code: enter ./Atari for more information. 37 | ~~~ 38 | cd ./Atari 39 | ~~~ 40 | Run DMControl code: enter ./DMControl for more information. 41 | ~~~ 42 | cd ./DMControl 43 | ~~~ 44 | 45 | ## Citation 46 | Please use the following BibTeX to cite our work. 47 | ``` 48 | @article{yu2022mask, 49 | title={Mask-based latent reconstruction for reinforcement learning}, 50 | author={Yu, Tao and Zhang, Zhizheng and Lan, Cuiling and Lu, Yan and Chen, Zhibo}, 51 | journal={Advances in Neural Information Processing Systems}, 52 | volume={35}, 53 | pages={25117--25131}, 54 | year={2022} 55 | } 56 | ``` 57 | 58 | ## Contributing 59 | 60 | This project welcomes contributions and suggestions. Most contributions require you to agree to a 61 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us 62 | the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com. 63 | 64 | When you submit a pull request, a CLA bot will automatically determine whether you need to provide 65 | a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions 66 | provided by the bot. You will only need to do this once across all repos using our CLA. 67 | 68 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 69 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or 70 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. 71 | 72 | ## Trademarks 73 | 74 | This project may contain trademarks or logos for projects, products, or services. 
Authorized use of Microsoft 75 | trademarks or logos is subject to and must follow 76 | [Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general). 77 | Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. 78 | Any use of third-party trademarks or logos are subject to those third-party's policies. 79 | -------------------------------------------------------------------------------- /Atari/src/agent.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Copyright (c) 2022 Microsoft 3 | # Licensed under The MIT License 4 | # -------------------------------------------------------- 5 | 6 | import copy 7 | import torch 8 | from rlpyt.agents.dqn.atari.atari_catdqn_agent import AtariCatDqnAgent 9 | from rlpyt.utils.buffer import buffer_to 10 | from rlpyt.utils.collections import namedarraytuple 11 | AgentInputs = namedarraytuple("AgentInputs", 12 | ["observation", "prev_action", "prev_reward"]) 13 | AgentInfo = namedarraytuple("AgentInfo", "p") 14 | AgentStep = namedarraytuple("AgentStep", ["action", "agent_info"]) 15 | 16 | 17 | class PVAgent(AtariCatDqnAgent): 18 | """Agent for Categorical DQN algorithm with search.""" 19 | 20 | def __init__(self, eval=False, **kwargs): 21 | """Standard init, and set the number of probability atoms (bins).""" 22 | super().__init__(**kwargs) 23 | self.eval = eval 24 | 25 | def __call__(self, observation, prev_action, prev_reward, train=False): 26 | """Returns Q-values for states/observations (with grad).""" 27 | if train: 28 | model_inputs = buffer_to((observation, prev_action, prev_reward), 29 | device=self.device) 30 | return self.model(*model_inputs, train=train) 31 | else: 32 | prev_action = self.distribution.to_onehot(prev_action) 33 | model_inputs = buffer_to((observation, prev_action, prev_reward), 34 | device=self.device) 35 | return self.model(*model_inputs).cpu() 36 | 37 | def initialize(self, 38 | env_spaces, 39 | share_memory=False, 40 | global_B=1, 41 | env_ranks=None): 42 | super().initialize(env_spaces, share_memory, global_B, env_ranks) 43 | # Overwrite distribution. 44 | self.search = SPRActionSelection(self.model, self.distribution) 45 | 46 | def to_device(self, cuda_idx=None): 47 | """Moves the model to the specified cuda device, if not ``None``. If 48 | sharing memory, instantiates a new model to preserve the shared (CPU) 49 | model. Agents with additional model components (beyond 50 | ``self.model``) for action-selection or for use during training should 51 | extend this method to move those to the device, as well. 52 | 53 | Typically called in the runner during startup. 
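        In this agent, the override below also moves the action-selection wrapper
        (``self.search``) to the same device and re-points its network reference at the moved model.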
54 | """ 55 | super().to_device(cuda_idx) 56 | self.search.to_device(cuda_idx) 57 | self.search.network = self.model 58 | 59 | def eval_mode(self, itr): 60 | """Extend method to set epsilon for evaluation, using 1 for 61 | pre-training eval.""" 62 | super().eval_mode(itr) 63 | self.search.epsilon = self.distribution.epsilon 64 | self.search.network.head.set_sampling(False) 65 | self.itr = itr 66 | 67 | def sample_mode(self, itr): 68 | """Extend method to set epsilon for sampling (including annealing).""" 69 | super().sample_mode(itr) 70 | self.search.epsilon = self.distribution.epsilon 71 | self.search.network.head.set_sampling(True) 72 | self.itr = itr 73 | 74 | def train_mode(self, itr): 75 | super().train_mode(itr) 76 | self.search.network.head.set_sampling(True) 77 | self.itr = itr 78 | 79 | @torch.no_grad() 80 | def step(self, observation, prev_action, prev_reward): 81 | """Compute the discrete distribution for the Q-value for each 82 | action for each state/observation (no grad).""" 83 | action, p = self.search.run(observation.to(self.search.device)) 84 | p = p.cpu() 85 | action = action.cpu() 86 | 87 | agent_info = AgentInfo(p=p) 88 | action, agent_info = buffer_to((action, agent_info), device="cpu") 89 | return AgentStep(action=action, agent_info=agent_info) 90 | 91 | 92 | class SPRActionSelection(torch.nn.Module): 93 | def __init__(self, network, distribution, device="cpu"): 94 | super().__init__() 95 | self.network = network 96 | self.epsilon = distribution._epsilon 97 | self.device = device 98 | self.first_call = True 99 | 100 | def to_device(self, idx): 101 | self.device = idx 102 | 103 | @torch.no_grad() 104 | def run(self, obs): 105 | while len(obs.shape) <= 4: 106 | obs.unsqueeze_(0) 107 | obs = obs.to(self.device).float() / 255. 108 | 109 | value = self.network.select_action(obs) 110 | action = self.select_action(value) 111 | # Stupid, stupid hack because rlpyt does _not_ handle batch_b=1 well. 112 | if self.first_call: 113 | action = action.squeeze() 114 | self.first_call = False 115 | return action, value.squeeze() 116 | 117 | def select_action(self, value): 118 | """Input can be shaped [T,B,Q] or [B,Q], and vector epsilon of length 119 | B will apply across the Batch dimension (same epsilon for all T).""" 120 | arg_select = torch.argmax(value, dim=-1) 121 | mask = torch.rand(arg_select.shape, device=value.device) < self.epsilon 122 | arg_rand = torch.randint(low=0, high=value.shape[-1], size=(mask.sum(),), device=value.device) 123 | arg_select[mask] = arg_rand 124 | return arg_select 125 | -------------------------------------------------------------------------------- /Atari/src/utils.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Copyright (c) 2022 Microsoft 3 | # Licensed under The MIT License 4 | # -------------------------------------------------------- 5 | 6 | import copy 7 | from rlpyt.experiments.configs.atari.dqn.atari_dqn import configs 8 | 9 | 10 | def count_parameters(model): 11 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 12 | 13 | 14 | class dummy_context_mgr: 15 | def __enter__(self): 16 | return None 17 | 18 | def __exit__(self, exc_type, exc_value, traceback): 19 | return False 20 | 21 | 22 | def set_config(args, game): 23 | # TODO: Use Hydra to manage configs 24 | ''' e.g. 
"algo" args will be valid in algos.py ''' 25 | config = configs['ernbw'] 26 | config['env']['game'] = game 27 | config["env"]["grayscale"] = args.grayscale 28 | config["env"]["num_img_obs"] = args.framestack 29 | config["eval_env"]["game"] = config["env"]["game"] 30 | config["eval_env"]["grayscale"] = args.grayscale 31 | config["eval_env"]["num_img_obs"] = args.framestack 32 | config['env']['imagesize'] = args.imagesize 33 | config['eval_env']['imagesize'] = args.imagesize 34 | config['env']['seed'] = args.seed 35 | config['eval_env']['seed'] = args.seed 36 | config["model"]["dueling"] = bool(args.dueling) 37 | config["algo"]["min_steps_learn"] = args.min_steps_learn 38 | config["algo"]["n_step_return"] = args.n_step 39 | config["algo"]["batch_size"] = args.batch_size 40 | config["algo"]["learning_rate"] = 0.0001 41 | config['algo']['replay_ratio'] = args.replay_ratio 42 | config['algo']['target_update_interval'] = args.target_update_interval 43 | config['algo']['target_update_tau'] = args.target_update_tau 44 | config['algo']['eps_steps'] = args.eps_steps 45 | config["algo"]["clip_grad_norm"] = args.max_grad_norm 46 | config['algo']['pri_alpha'] = 0.5 47 | config['algo']['pri_beta_steps'] = int(10e4) 48 | config['optim']['eps'] = 0.00015 49 | config["sampler"]["eval_max_trajectories"] = 100 50 | config["sampler"]["eval_n_envs"] = 100 51 | config["sampler"]["eval_max_steps"] = 100*28000 # 28k is just a safe ceiling 52 | config['sampler']['batch_B'] = args.batch_b 53 | config['sampler']['batch_T'] = args.batch_t 54 | 55 | config['agent']['eps_init'] = args.eps_init 56 | config['agent']['eps_final'] = args.eps_final 57 | config["model"]["noisy_nets_std"] = args.noisy_nets_std 58 | config["model"]["cycle_step"] = args.cycle_step 59 | config["model"]["space"] = args.space 60 | config["model"]["real_cycle"] = args.real_cycle 61 | config["model"]["virtual_cycle"] = args.virtual_cycle 62 | config["model"]["aug_num"] = args.aug_num 63 | config["model"]["fp"] = args.fp 64 | config["model"]["bp"] = args.bp 65 | config["model"]["bp_mode"] = args.bp_mode 66 | config["model"]["aug_type"] = args.aug_type 67 | 68 | if args.noisy_nets: 69 | config['agent']['eps_eval'] = 0.001 70 | 71 | # New SPR Arguments 72 | config["model"]["imagesize"] = args.imagesize 73 | config["model"]["jumps"] = args.jumps 74 | config["model"]["dynamics_blocks"] = args.dynamics_blocks 75 | config["model"]["spr"] = args.spr 76 | config["model"]["noisy_nets"] = args.noisy_nets 77 | config["model"]["momentum_encoder"] = args.momentum_encoder 78 | config["model"]["shared_encoder"] = args.shared_encoder 79 | config["model"]["local_spr"] = args.local_spr 80 | config["model"]["global_spr"] = args.global_spr 81 | config["model"]["distributional"] = args.distributional 82 | config["model"]["renormalize"] = args.renormalize 83 | config["model"]["norm_type"] = args.norm_type 84 | config["model"]["augmentation"] = args.augmentation 85 | config["model"]["q_l1_type"] = args.q_l1_type 86 | config["model"]["dropout"] = args.dropout 87 | config["model"]["time_offset"] = args.time_offset 88 | config["model"]["aug_prob"] = args.aug_prob 89 | config["model"]["target_augmentation"] = args.target_augmentation 90 | config["model"]["eval_augmentation"] = args.eval_augmentation 91 | config["model"]["classifier"] = args.classifier 92 | config["model"]["final_classifier"] = args.final_classifier 93 | config['model']['momentum_tau'] = args.momentum_tau 94 | config["model"]["dqn_hidden_size"] = args.dqn_hidden_size 95 | config["model"]["model_rl"] = 
args.model_rl_weight 96 | config["model"]["residual_tm"] = args.residual_tm 97 | config["algo"]["model_rl_weight"] = args.model_rl_weight 98 | config["algo"]["reward_loss_weight"] = args.reward_loss_weight 99 | config["algo"]["model_spr_weight"] = args.model_spr_weight 100 | config["algo"]["t0_spr_loss_weight"] = args.t0_spr_loss_weight 101 | config["algo"]["time_offset"] = args.time_offset 102 | config["algo"]["distributional"] = args.distributional 103 | config["algo"]["delta_clip"] = args.delta_clip 104 | config["algo"]["prioritized_replay"] = args.prioritized_replay 105 | config["algo"]["rc_weight"] = args.rc_weight 106 | config["algo"]["vc_weight"] = args.vc_weight 107 | config["algo"]["fp_weight"] = args.fp_weight 108 | config["algo"]["bp_weight"] = args.bp_weight 109 | config["algo"]["mlr_weight"] = args.mlr_weight 110 | config["algo"]["cycle_jumps"] = args.cycle_step 111 | config["algo"]["warmup"] = args.warmup 112 | config["algo"]["bp_warm"] = args.bp_warm 113 | 114 | 115 | return config -------------------------------------------------------------------------------- /Atari/src/rlpyt_buffer.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # This code is borrowed from https://github.com/mila-iqia 3 | # -------------------------------------------------------- 4 | 5 | from __future__ import division 6 | import torch 7 | 8 | from rlpyt.replays.sequence.prioritized import SamplesFromReplayPri 9 | 10 | from rlpyt.replays.sequence.n_step import SamplesFromReplay 11 | from rlpyt.replays.sequence.frame import AsyncPrioritizedSequenceReplayFrameBuffer, \ 12 | AsyncUniformSequenceReplayFrameBuffer, PrioritizedSequenceReplayFrameBuffer 13 | from rlpyt.utils.buffer import torchify_buffer, numpify_buffer 14 | from rlpyt.utils.collections import namedarraytuple 15 | from rlpyt.utils.misc import extract_sequences 16 | import traceback 17 | 18 | PrioritizedSamples = namedarraytuple("PrioritizedSamples", 19 | ["samples", "priorities"]) 20 | SamplesToBuffer = namedarraytuple("SamplesToBuffer", 21 | ["observation", "action", "reward", "done", "policy_probs", "value"]) 22 | SamplesFromReplayExt = namedarraytuple("SamplesFromReplayPriExt", 23 | SamplesFromReplay._fields + ("values", "age")) 24 | SamplesFromReplayPriExt = namedarraytuple("SamplesFromReplayPriExt", 25 | SamplesFromReplayPri._fields + ("values", "age")) 26 | EPS = 1e-6 27 | 28 | 29 | def samples_to_buffer(observation, action, reward, done, policy_probs, value, priorities=None): 30 | samples = SamplesToBuffer( 31 | observation=observation, 32 | action=action, 33 | reward=reward, 34 | done=done, 35 | policy_probs=policy_probs, 36 | value=value 37 | ) 38 | if priorities is not None: 39 | return PrioritizedSamples(samples=samples, 40 | priorities=priorities) 41 | else: 42 | return samples 43 | 44 | class AsyncUniformSequenceReplayFrameBufferExtended(AsyncUniformSequenceReplayFrameBuffer): 45 | """ 46 | Extends AsyncPrioritizedSequenceReplayFrameBuffer to return policy_logits and values too during sampling. 47 | """ 48 | def sample_batch(self, batch_B): 49 | while True: 50 | try: 51 | self._async_pull() # Updates from writers. 
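# Sample the (time, env) start index of each sequence below; extraction failures are
# caught by the except branch, which prints the offending indices for debugging.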
52 | batch_T = self.batch_T 53 | T_idxs, B_idxs = self.sample_idxs(batch_B, batch_T) 54 | sampled_indices = True 55 | if self.rnn_state_interval > 1: 56 | T_idxs = T_idxs * self.rnn_state_interval 57 | batch = self.extract_batch(T_idxs, B_idxs, self.batch_T) 58 | 59 | except Exception as _: 60 | print("FAILED TO LOAD BATCH") 61 | if sampled_indices: 62 | print("B_idxs:", B_idxs, flush=True) 63 | print("T_idxs:", T_idxs, flush=True) 64 | print("Batch_T:", self.batch_T, flush=True) 65 | print("Buffer T:", self.T, flush=True) 66 | 67 | elapsed_iters = self.t + self.T - T_idxs % self.T 68 | elapsed_samples = self.B*(elapsed_iters) 69 | values = torch.from_numpy(extract_sequences(self.samples.value, T_idxs, B_idxs, self.batch_T+self.n_step_return+1)) 70 | batch = SamplesFromReplayExt(*batch, values=values, age=elapsed_samples) 71 | if self.batch_T > 1: 72 | batch = self.sanitize_batch(batch) 73 | return batch 74 | 75 | def sanitize_batch(self, batch): 76 | has_dones, inds = torch.max(batch.done, 0) 77 | for i, (has_done, ind) in enumerate(zip(has_dones, inds)): 78 | if not has_done: 79 | continue 80 | batch.all_observation[ind+1:, i] = batch.all_observation[ind, i] 81 | batch.all_reward[ind+1:, i] = 0 82 | batch.return_[ind+1:, i] = 0 83 | batch.done_n[ind+1:, i] = True 84 | batch.values[ind+1:, i] = 0 85 | return batch 86 | 87 | 88 | class AsyncPrioritizedSequenceReplayFrameBufferExtended(AsyncPrioritizedSequenceReplayFrameBuffer): 89 | """ 90 | Extends AsyncPrioritizedSequenceReplayFrameBuffer to return policy_logits and values too during sampling. 91 | """ 92 | def sample_batch(self, batch_B): 93 | while True: 94 | try: 95 | self._async_pull() # Updates from writers. 96 | (T_idxs, B_idxs), priorities = self.priority_tree.sample( 97 | batch_B, unique=self.unique) 98 | sampled_indices = True 99 | if self.rnn_state_interval > 1: 100 | T_idxs = T_idxs * self.rnn_state_interval 101 | 102 | batch = self.extract_batch(T_idxs, B_idxs, self.batch_T) 103 | 104 | except Exception as _: 105 | print("FAILED TO LOAD BATCH") 106 | traceback.print_exc() 107 | if sampled_indices: 108 | print("B_idxs:", B_idxs, flush=True) 109 | print("T_idxs:", T_idxs, flush=True) 110 | print("Batch_T:", self.batch_T, flush=True) 111 | print("Buffer T:", self.T, flush=True) 112 | 113 | is_weights = (1. / (priorities + 1e-5)) ** self.beta 114 | is_weights /= max(is_weights) # Normalize. 
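# Importance-sampling correction for prioritized replay: w_i = (1 / (p_i + 1e-5)) ** beta,
# then divided by the batch maximum so the largest weight is exactly 1.
# e.g. beta = 1 and priorities [0.1, 0.4] give raw weights [~10, 2.5] -> normalized [1.0, 0.25].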
115 | is_weights = torchify_buffer(is_weights).float() 116 | 117 | elapsed_iters = self.t + self.T - T_idxs % self.T 118 | elapsed_samples = self.B*(elapsed_iters) 119 | values = torch.from_numpy(extract_sequences(self.samples.value, T_idxs, B_idxs, self.batch_T+self.n_step_return+1)) 120 | batch = SamplesFromReplayPriExt(*batch, 121 | values=values, 122 | is_weights=is_weights, 123 | age=elapsed_samples) 124 | if self.batch_T > 1: 125 | batch = self.sanitize_batch(batch) 126 | return batch 127 | 128 | def sanitize_batch(self, batch): 129 | has_dones, inds = torch.max(batch.done, 0) 130 | for i, (has_done, ind) in enumerate(zip(has_dones, inds)): 131 | if not has_done: 132 | continue 133 | batch.all_observation[ind+1:, i] = batch.all_observation[ind, i] 134 | batch.all_reward[ind+1:, i] = 0 135 | batch.return_[ind+1:, i] = 0 136 | batch.done_n[ind+1:, i] = True 137 | batch.values[ind+1:, i] = 0 138 | return batch 139 | -------------------------------------------------------------------------------- /DMControl/src/logger.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Copyright (c) 2022 Microsoft 3 | # Licensed under The MIT License 4 | # -------------------------------------------------------- 5 | 6 | import json 7 | import os 8 | import shutil 9 | from collections import defaultdict 10 | 11 | import numpy as np 12 | import torch 13 | import torchvision 14 | import wandb 15 | from termcolor import colored 16 | from torch.utils.tensorboard import SummaryWriter 17 | 18 | FORMAT_CONFIG = { 19 | 'rl': { 20 | 'train': [('episode', 'E', 'int'), ('step', 'S', 'int'), 21 | ('duration', 'D', 'time'), ('episode_reward', 'R', 'float'), 22 | ('batch_reward', 'BR', 'float'), 23 | ('actor_loss', 'A_LOSS', 'float'), 24 | ('critic_loss', 'C_LOSS', 'float'), 25 | ('curl_loss', 'CURL_LOSS', 'float'), 26 | ('spr_loss', 'SPR_LOSS', 'float'), 27 | ('idm_loss', 'IDM_LOSS', 'float'), 28 | ('cycdm_loss', 'CYCDM_LOSS', 'float'), 29 | ], 30 | 'eval': [('step', 'S', 'int'), ('episode_reward', 'ER', 'float')] 31 | } 32 | } 33 | 34 | 35 | class AverageMeter(object): 36 | def __init__(self): 37 | self._sum = 0 38 | self._count = 0 39 | 40 | def update(self, value, n=1): 41 | self._sum += value 42 | self._count += n 43 | 44 | def value(self): 45 | return self._sum / max(1, self._count) 46 | 47 | 48 | class MetersGroup(object): 49 | def __init__(self, file_name, formating): 50 | self._file_name = file_name 51 | if os.path.exists(file_name): 52 | os.remove(file_name) 53 | self._formating = formating 54 | self._meters = defaultdict(AverageMeter) 55 | 56 | def log(self, key, value, n=1): 57 | self._meters[key].update(value, n) 58 | 59 | def _prime_meters(self): 60 | data = dict() 61 | for key, meter in self._meters.items(): 62 | if key.startswith('train'): 63 | key = key[len('train') + 1:] 64 | else: 65 | key = key[len('eval') + 1:] 66 | key = key.replace('/', '_') 67 | data[key] = meter.value() 68 | return data 69 | 70 | def _dump_to_file(self, data): 71 | with open(self._file_name, 'a') as f: 72 | f.write(json.dumps(data) + '\n') 73 | 74 | def _format(self, key, value, ty): 75 | template = '%s: ' 76 | if ty == 'int': 77 | template += '%d' 78 | elif ty == 'float': 79 | template += '%.04f' 80 | elif ty == 'time': 81 | template += '%.01f s' 82 | else: 83 | raise 'invalid format type: %s' % ty 84 | return template % (key, value) 85 | 86 | def _dump_to_console(self, data, prefix): 87 | prefix = colored(prefix, 'yellow' if 
prefix == 'train' else 'green') 88 | pieces = ['{:5}'.format(prefix)] 89 | for key, disp_key, ty in self._formating: 90 | value = data.get(key, 0) 91 | pieces.append(self._format(disp_key, value, ty)) 92 | print('| %s' % (' | '.join(pieces))) 93 | 94 | def dump(self, step, prefix): 95 | if len(self._meters) == 0: 96 | return 97 | data = self._prime_meters() 98 | data['step'] = step 99 | self._dump_to_file(data) 100 | self._dump_to_console(data, prefix) 101 | self._meters.clear() 102 | 103 | 104 | class Logger(object): 105 | def __init__(self, log_dir, use_tb=True, use_wandb=False, config='rl'): 106 | self._log_dir = log_dir 107 | if use_tb: 108 | tb_dir = os.path.join(log_dir, 'tb') 109 | if os.path.exists(tb_dir): 110 | shutil.rmtree(tb_dir) 111 | self._sw = SummaryWriter(tb_dir) 112 | else: 113 | self._sw = None 114 | self._train_mg = MetersGroup(os.path.join(log_dir, 'train_info.log'), 115 | formating=FORMAT_CONFIG[config]['train']) 116 | self._eval_mg = MetersGroup(os.path.join(log_dir, 'eval_info.log'), 117 | formating=FORMAT_CONFIG[config]['eval']) 118 | self.use_wandb = use_wandb 119 | 120 | def _try_sw_log(self, key, value, step): 121 | if self._sw is not None: 122 | self._sw.add_scalar(key, value, step) 123 | 124 | def _try_sw_log_image(self, key, image, step): 125 | if self._sw is not None: 126 | assert image.dim() == 3 127 | grid = torchvision.utils.make_grid(image.unsqueeze(1)) 128 | self._sw.add_image(key, grid, step) 129 | 130 | def _try_sw_log_video(self, key, frames, step): 131 | if self._sw is not None: 132 | frames = torch.from_numpy(np.array(frames)) 133 | frames = frames.unsqueeze(0) 134 | self._sw.add_video(key, frames, step, fps=30) 135 | 136 | def _try_sw_log_histogram(self, key, histogram, step): 137 | if self._sw is not None: 138 | self._sw.add_histogram(key, histogram, step) 139 | 140 | def log(self, key, value, step, n=1): 141 | assert key.startswith('train') or key.startswith('eval') 142 | if type(value) == torch.Tensor: 143 | value = value.item() 144 | self._try_sw_log(key, value / n, step) 145 | mg = self._train_mg if key.startswith('train') else self._eval_mg 146 | mg.log(key, value, n) 147 | if self.use_wandb: 148 | wandb.log({key: value}, step=step) 149 | 150 | def log_param(self, key, param, step): 151 | self.log_histogram(key + '_w', param.weight.data, step) 152 | if hasattr(param.weight, 'grad') and param.weight.grad is not None: 153 | self.log_histogram(key + '_w_g', param.weight.grad.data, step) 154 | if hasattr(param, 'bias'): 155 | self.log_histogram(key + '_b', param.bias.data, step) 156 | if hasattr(param.bias, 'grad') and param.bias.grad is not None: 157 | self.log_histogram(key + '_b_g', param.bias.grad.data, step) 158 | 159 | def log_image(self, key, image, step): 160 | assert key.startswith('train') or key.startswith('eval') 161 | self._try_sw_log_image(key, image, step) 162 | 163 | def log_video(self, key, frames, step): 164 | assert key.startswith('train') or key.startswith('eval') 165 | self._try_sw_log_video(key, frames, step) 166 | 167 | def log_histogram(self, key, histogram, step): 168 | assert key.startswith('train') or key.startswith('eval') 169 | self._try_sw_log_histogram(key, histogram, step) 170 | 171 | def dump(self, step): 172 | self._train_mg.dump(step, 'train') 173 | self._eval_mg.dump(step, 'eval') 174 | -------------------------------------------------------------------------------- /Atari/src/vit_modules.py: -------------------------------------------------------------------------------- 1 | # 
-------------------------------------------------------- 2 | # Based on BEiT, timm, DINO and DeiT code bases 3 | # https://github.com/microsoft/unilm/tree/master/beit 4 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm 5 | # https://github.com/facebookresearch/deit 6 | # https://github.com/facebookresearch/dino 7 | # --------------------------------------------------------' 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from timm.models.layers import trunc_normal_ as __call_trunc_normal_ 13 | from timm.models.layers import drop_path, to_2tuple, trunc_normal_ 14 | import numpy as np 15 | 16 | 17 | def trunc_normal_(tensor, mean=0., std=1.): 18 | __call_trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std) 19 | 20 | 21 | def get_sinusoid_encoding_table(n_position, d_hid): 22 | ''' Sinusoid position encoding table ''' 23 | # TODO: make it with torch instead of numpy 24 | def get_position_angle_vec(position): 25 | return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)] 26 | 27 | sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)]) 28 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i 29 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 30 | 31 | return torch.FloatTensor(sinusoid_table).unsqueeze(0) 32 | 33 | 34 | class Block(nn.Module): 35 | 36 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 37 | drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm, 38 | attn_head_dim=None): 39 | super().__init__() 40 | self.norm1 = norm_layer(dim) 41 | self.attn = Attention( 42 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, 43 | attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim) 44 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 45 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 46 | self.norm2 = norm_layer(dim) 47 | mlp_hidden_dim = int(dim * mlp_ratio) 48 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 49 | 50 | if init_values > 0: 51 | self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) 52 | self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) 53 | else: 54 | self.gamma_1, self.gamma_2 = None, None 55 | 56 | def forward(self, x): 57 | if self.gamma_1 is None: 58 | x = x + self.drop_path(self.attn(self.norm1(x))) 59 | x = x + self.drop_path(self.mlp(self.norm2(x))) 60 | else: 61 | x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x))) 62 | x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) 63 | return x 64 | 65 | class PatchEmbed(nn.Module): 66 | """ Image to Patch Embedding 67 | """ 68 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 69 | super().__init__() 70 | img_size = to_2tuple(img_size) 71 | patch_size = to_2tuple(patch_size) 72 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) 73 | self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) 74 | self.img_size = img_size 75 | self.patch_size = patch_size 76 | self.num_patches = num_patches 77 | 78 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 79 | 80 | def forward(self, x, **kwargs): 81 | B, C, H, W = x.shape 82 | # FIXME look at relaxing size constraints 83 | assert H == self.img_size[0] and W == self.img_size[1], \ 84 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 85 | x = self.proj(x).flatten(2).transpose(1, 2) 86 | return x 87 | 88 | 89 | class Mlp(nn.Module): 90 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 91 | super().__init__() 92 | out_features = out_features or in_features 93 | hidden_features = hidden_features or in_features 94 | self.fc1 = nn.Linear(in_features, hidden_features) 95 | self.act = act_layer() 96 | self.fc2 = nn.Linear(hidden_features, out_features) 97 | self.drop = nn.Dropout(drop) 98 | 99 | def forward(self, x): 100 | x = self.fc1(x) 101 | x = self.act(x) 102 | # x = self.drop(x) 103 | # commit this for the orignal BERT implement 104 | x = self.fc2(x) 105 | x = self.drop(x) 106 | return x 107 | 108 | 109 | class DropPath(nn.Module): 110 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
111 | """ 112 | def __init__(self, drop_prob=None): 113 | super(DropPath, self).__init__() 114 | self.drop_prob = drop_prob 115 | 116 | def forward(self, x): 117 | return drop_path(x, self.drop_prob, self.training) 118 | 119 | def extra_repr(self) -> str: 120 | return 'p={}'.format(self.drop_prob) 121 | 122 | 123 | class Attention(nn.Module): 124 | def __init__( 125 | self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., 126 | proj_drop=0., attn_head_dim=None): 127 | super().__init__() 128 | self.num_heads = num_heads 129 | head_dim = dim // num_heads 130 | if attn_head_dim is not None: 131 | head_dim = attn_head_dim 132 | all_head_dim = head_dim * self.num_heads 133 | self.scale = qk_scale or head_dim ** -0.5 134 | 135 | self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False) 136 | if qkv_bias: 137 | self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) 138 | self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) 139 | else: 140 | self.q_bias = None 141 | self.v_bias = None 142 | 143 | self.attn_drop = nn.Dropout(attn_drop) 144 | self.proj = nn.Linear(all_head_dim, dim) 145 | self.proj_drop = nn.Dropout(proj_drop) 146 | 147 | def forward(self, x): 148 | B, N, C = x.shape 149 | qkv_bias = None 150 | if self.q_bias is not None: 151 | qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) 152 | # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 153 | qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) 154 | qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) 155 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 156 | 157 | q = q * self.scale 158 | attn = (q @ k.transpose(-2, -1)) 159 | 160 | 161 | attn = attn.softmax(dim=-1) 162 | attn = self.attn_drop(attn) 163 | 164 | x = (attn @ v).transpose(1, 2).reshape(B, N, -1) 165 | x = self.proj(x) 166 | x = self.proj_drop(x) 167 | return x -------------------------------------------------------------------------------- /DMControl/src/vit_modules.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Based on BEiT, timm, DINO and DeiT code bases 3 | # https://github.com/microsoft/unilm/tree/master/beit 4 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm 5 | # https://github.com/facebookresearch/deit 6 | # https://github.com/facebookresearch/dino 7 | # --------------------------------------------------------' 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from timm.models.layers import trunc_normal_ as __call_trunc_normal_ 13 | from timm.models.layers import drop_path, to_2tuple, trunc_normal_ 14 | import numpy as np 15 | 16 | 17 | def trunc_normal_(tensor, mean=0., std=1.): 18 | __call_trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std) 19 | 20 | 21 | def get_sinusoid_encoding_table(n_position, d_hid): 22 | ''' Sinusoid position encoding table ''' 23 | # TODO: make it with torch instead of numpy 24 | def get_position_angle_vec(position): 25 | return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)] 26 | 27 | sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)]) 28 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i 29 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 30 | 31 | return 
torch.FloatTensor(sinusoid_table).unsqueeze(0) 32 | 33 | 34 | class Block(nn.Module): 35 | 36 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 37 | drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm, 38 | attn_head_dim=None): 39 | super().__init__() 40 | self.norm1 = norm_layer(dim) 41 | self.attn = Attention( 42 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, 43 | attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim) 44 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 45 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 46 | self.norm2 = norm_layer(dim) 47 | mlp_hidden_dim = int(dim * mlp_ratio) 48 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 49 | 50 | if init_values > 0: 51 | self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) 52 | self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) 53 | else: 54 | self.gamma_1, self.gamma_2 = None, None 55 | 56 | def forward(self, x): 57 | if self.gamma_1 is None: 58 | x = x + self.drop_path(self.attn(self.norm1(x))) 59 | x = x + self.drop_path(self.mlp(self.norm2(x))) 60 | else: 61 | x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x))) 62 | x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) 63 | return x 64 | 65 | class PatchEmbed(nn.Module): 66 | """ Image to Patch Embedding 67 | """ 68 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 69 | super().__init__() 70 | img_size = to_2tuple(img_size) 71 | patch_size = to_2tuple(patch_size) 72 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) 73 | self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) 74 | self.img_size = img_size 75 | self.patch_size = patch_size 76 | self.num_patches = num_patches 77 | 78 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 79 | 80 | def forward(self, x, **kwargs): 81 | B, C, H, W = x.shape 82 | # FIXME look at relaxing size constraints 83 | assert H == self.img_size[0] and W == self.img_size[1], \ 84 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 85 | x = self.proj(x).flatten(2).transpose(1, 2) 86 | return x 87 | 88 | 89 | class Mlp(nn.Module): 90 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 91 | super().__init__() 92 | out_features = out_features or in_features 93 | hidden_features = hidden_features or in_features 94 | self.fc1 = nn.Linear(in_features, hidden_features) 95 | self.act = act_layer() 96 | self.fc2 = nn.Linear(hidden_features, out_features) 97 | self.drop = nn.Dropout(drop) 98 | 99 | def forward(self, x): 100 | x = self.fc1(x) 101 | x = self.act(x) 102 | # x = self.drop(x) 103 | # commit this for the orignal BERT implement 104 | x = self.fc2(x) 105 | x = self.drop(x) 106 | return x 107 | 108 | 109 | class DropPath(nn.Module): 110 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
111 | """ 112 | def __init__(self, drop_prob=None): 113 | super(DropPath, self).__init__() 114 | self.drop_prob = drop_prob 115 | 116 | def forward(self, x): 117 | return drop_path(x, self.drop_prob, self.training) 118 | 119 | def extra_repr(self) -> str: 120 | return 'p={}'.format(self.drop_prob) 121 | 122 | 123 | class Attention(nn.Module): 124 | def __init__( 125 | self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., 126 | proj_drop=0., attn_head_dim=None): 127 | super().__init__() 128 | self.num_heads = num_heads 129 | head_dim = dim // num_heads 130 | if attn_head_dim is not None: 131 | head_dim = attn_head_dim 132 | all_head_dim = head_dim * self.num_heads 133 | self.scale = qk_scale or head_dim ** -0.5 134 | 135 | self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False) 136 | if qkv_bias: 137 | self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) 138 | self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) 139 | else: 140 | self.q_bias = None 141 | self.v_bias = None 142 | 143 | self.attn_drop = nn.Dropout(attn_drop) 144 | self.proj = nn.Linear(all_head_dim, dim) 145 | self.proj_drop = nn.Dropout(proj_drop) 146 | 147 | def forward(self, x): 148 | B, N, C = x.shape 149 | qkv_bias = None 150 | if self.q_bias is not None: 151 | qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) 152 | # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 153 | qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) 154 | qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) 155 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 156 | 157 | q = q * self.scale 158 | attn = (q @ k.transpose(-2, -1)) 159 | 160 | 161 | attn = attn.softmax(dim=-1) 162 | attn = self.attn_drop(attn) 163 | 164 | x = (attn @ v).transpose(1, 2).reshape(B, N, -1) 165 | x = self.proj(x) 166 | x = self.proj_drop(x) 167 | return x -------------------------------------------------------------------------------- /Atari/scripts/run.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Copyright (c) 2022 Microsoft 3 | # Licensed under The MIT License 4 | # -------------------------------------------------------- 5 | 6 | import copy 7 | from rlpyt.experiments.configs.atari.dqn.atari_dqn import configs 8 | from rlpyt.samplers.serial.sampler import SerialSampler 9 | from rlpyt.envs.atari.atari_env import AtariTrajInfo 10 | from rlpyt.utils.logging.context import logger_context 11 | 12 | import os 13 | import wandb 14 | import torch 15 | import numpy as np 16 | 17 | from src.models import PVCatDqnModel 18 | from src.rlpyt_utils import OneToOneSerialEvalCollector, SerialSampler, MinibatchRlEvalWandb 19 | from src.algos import PVCategoricalDQN 20 | from src.agent import PVAgent 21 | from src.rlpyt_atari_env import AtariEnv 22 | from src.utils import set_config 23 | 24 | 25 | def build_and_train(game="pong", run_ID=0, cuda_idx=0, args=None): 26 | np.random.seed(args.seed) 27 | torch.manual_seed(args.seed) 28 | env = AtariEnv 29 | # env.seed(1) 30 | config = set_config(args, game) 31 | 32 | sampler = SerialSampler( 33 | EnvCls=env, 34 | TrajInfoCls=AtariTrajInfo, # default traj info + GameScore 35 | env_kwargs=config["env"], 36 | eval_env_kwargs=config["eval_env"], 37 | batch_T=config['sampler']['batch_T'], 38 | batch_B=config['sampler']['batch_B'], 39 | max_decorrelation_steps=0, 
40 | eval_CollectorCls=OneToOneSerialEvalCollector, 41 | eval_n_envs=config["sampler"]["eval_n_envs"], 42 | eval_max_steps=config['sampler']['eval_max_steps'], 43 | eval_max_trajectories=config["sampler"]["eval_max_trajectories"], 44 | ) 45 | args.discount = config["algo"]["discount"] # 0.99 46 | algo = PVCategoricalDQN(optim_kwargs=config["optim"], jumps=args.jumps, **config["algo"]) # Run with defaults. 47 | agent = PVAgent(ModelCls=PVCatDqnModel, model_kwargs=config["model"], **config["agent"]) 48 | 49 | wandb.config.update(config) 50 | runner = MinibatchRlEvalWandb( 51 | algo=algo, 52 | agent=agent, 53 | sampler=sampler, 54 | n_steps=args.n_steps, 55 | affinity=dict(cuda_idx=cuda_idx), 56 | log_interval_steps=args.n_steps//args.num_logs, 57 | seed=args.seed, 58 | final_eval_only=args.final_eval_only, 59 | ) 60 | config = dict(game=game) 61 | name = "dqn_" + game 62 | log_dir = "logs" 63 | print("=======") 64 | with logger_context(log_dir, run_ID, name, config, snapshot_mode="last"): 65 | runner.train() 66 | 67 | quit() 68 | 69 | 70 | if __name__ == "__main__": 71 | import argparse 72 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 73 | parser.add_argument('--name', help='Experiment name', default='Atari') 74 | parser.add_argument('--game', help='Atari game', default='ms_pacman') 75 | parser.add_argument('--seed', type=int, default=0) 76 | parser.add_argument('--grayscale', type=int, default=1) 77 | parser.add_argument('--framestack', type=int, default=4) 78 | parser.add_argument('--imagesize', type=int, default=84) 79 | parser.add_argument('--n-steps', type=int, default=100000) 80 | parser.add_argument('--dqn-hidden-size', type=int, default=256) 81 | parser.add_argument('--target-update-interval', type=int, default=1) 82 | parser.add_argument('--target-update-tau', type=float, default=1.) 83 | parser.add_argument('--momentum-tau', type=float, default=1.) 84 | parser.add_argument('--batch-b', type=int, default=1) 85 | parser.add_argument('--batch-t', type=int, default=1) 86 | parser.add_argument('--beluga', action="store_true") 87 | parser.add_argument('--jumps', type=int, default=9) 88 | parser.add_argument('--num-logs', type=int, default=10) 89 | parser.add_argument('--renormalize', type=int, default=1) 90 | parser.add_argument('--dueling', type=int, default=1) 91 | parser.add_argument('--replay-ratio', type=int, default=64) 92 | parser.add_argument('--dynamics-blocks', type=int, default=0) 93 | parser.add_argument('--residual-tm', type=int, default=0.) 
94 | parser.add_argument('--n-step', type=int, default=10) 95 | parser.add_argument('--batch-size', type=int, default=32) 96 | parser.add_argument('--tag', type=str, default='', help='Tag for wandb run.') 97 | parser.add_argument('--wandb-dir', type=str, default='./', help='Directory for wandb files.') 98 | parser.add_argument('--norm-type', type=str, default='bn', choices=["bn", "ln", "in", "none"], help='Normalization') 99 | parser.add_argument('--aug-prob', type=float, default=1., help='Probability to apply augmentation') 100 | parser.add_argument('--dropout', type=float, default=0., help='Dropout probability in convnet.') 101 | parser.add_argument('--spr', type=int, default=1) 102 | parser.add_argument('--distributional', type=int, default=1) 103 | parser.add_argument('--delta-clip', type=float, default=1., help="Huber Delta") 104 | parser.add_argument('--prioritized-replay', type=int, default=1) 105 | parser.add_argument('--momentum-encoder', type=int, default=1) 106 | parser.add_argument('--shared-encoder', type=int, default=0) 107 | parser.add_argument('--local-spr', type=int, default=0) 108 | parser.add_argument('--global-spr', type=int, default=1) 109 | parser.add_argument('--noisy-nets', type=int, default=1) 110 | parser.add_argument('--noisy-nets-std', type=float, default=0.1) 111 | parser.add_argument('--classifier', type=str, default='q_l1', choices=["mlp", "bilinear", "q_l1", "q_l2", "none"], help='Style of NCE classifier') 112 | parser.add_argument('--final-classifier', type=str, default='linear', choices=["mlp", "linear", "none"], help='Style of NCE classifier') 113 | parser.add_argument('--augmentation', type=str, default=["shift", "intensity"], nargs="+", 114 | choices=["none", "rrc", "affine", "crop", "blur", "shift", "intensity"], 115 | help='Style of augmentation') 116 | parser.add_argument('--q-l1-type', type=str, default=["value", "advantage"], nargs="+", 117 | choices=["noisy", "value", "advantage", "relu"], 118 | help='Style of q_l1 projection') 119 | parser.add_argument('--target-augmentation', type=int, default=1, help='Use augmentation on inputs to target networks') 120 | parser.add_argument('--eval-augmentation', type=int, default=0, help='Use augmentation on inputs at evaluation time') 121 | parser.add_argument('--reward-loss-weight', type=float, default=0.) 122 | parser.add_argument('--model-rl-weight', type=float, default=0.) 123 | parser.add_argument('--model-spr-weight', type=float, default=5.) 124 | parser.add_argument('--t0-spr-loss-weight', type=float, default=0.) 125 | parser.add_argument('--eps-steps', type=int, default=2001) 126 | parser.add_argument('--min-steps-learn', type=int, default=2000) 127 | parser.add_argument('--eps-init', type=float, default=1.) 128 | parser.add_argument('--eps-final', type=float, default=0.) 
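# Note on the exploration flags above: epsilon anneals from --eps-init to --eps-final over
# --eps-steps environment steps, and set_config() pins the evaluation epsilon to 0.001
# whenever --noisy-nets=1 (the default), so exploration largely comes from the noisy layers.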
129 | parser.add_argument('--final-eval-only', type=int, default=1) 130 | parser.add_argument('--time-offset', type=int, default=0) 131 | parser.add_argument('--project', type=str, default="SPR") 132 | parser.add_argument('--entity', type=str, default="abs-world-models") 133 | parser.add_argument('--cuda_idx', help='gpu to use ', type=int, default=0) 134 | parser.add_argument('--max-grad-norm', type=float, default=10., help='Max Grad Norm') 135 | parser.add_argument('--public', action='store_true', help='If set, uses anonymous wandb logging') 136 | parser.add_argument('--real-cycle', action='store_true') 137 | parser.add_argument('--virtual-cycle', action='store_true') 138 | parser.add_argument('--fp', action='store_true', help='if set, then calculate forward prediction loss') 139 | parser.add_argument('--bp', action='store_true', help='if set, then calculate backward prediction loss') 140 | parser.add_argument('--bp-mode', type=str, default='gt', help='train RDM starting from gt s_t+K(gt) or estimated s_t+K(esti)') 141 | parser.add_argument('--rc-weight', type=float, default=0.) 142 | parser.add_argument('--vc-weight', type=float, default=1.) 143 | parser.add_argument('--fp-weight', type=float, default=1., help='Forward Prediction Loss weight') 144 | parser.add_argument('--bp-weight', type=float, default=0., help='Backward Prediction Loss weight') 145 | parser.add_argument('--mlr-weight', type=float, default=0., help='MLR Loss weight') 146 | parser.add_argument('--cycle-step', type=int, default=9) 147 | parser.add_argument('--aug-num', type=float, default=2.) 148 | parser.add_argument('--space', type=str, default='y', help='calculate similarity loss in cycle on which space: mse, z or y') 149 | parser.add_argument('--aug-type', type=str, default='random', help='random, nonaug or hybrid') 150 | parser.add_argument('--warmup', type=int, default=50000) 151 | parser.add_argument('--bp-warm', action='store_true', help='if set, then annealing decreasing bp-weight') 152 | args = parser.parse_args() 153 | print(args) 154 | 155 | print("============= Experiment:{} =============".format(args.name)) 156 | os.environ['WANDB_MODE'] = 'dryrun' 157 | 158 | if args.public: 159 | wandb.init(anonymous="allow", config=args, tags=[args.tag] if args.tag else None, dir=args.wandb_dir) 160 | else: 161 | # wandb.init(project=args.project, entity=args.entity, config=args, tags=[args.tag] if args.tag else None, dir=args.wandb_dir) 162 | wandb.init(project=args.project, config=args, tags=[args.tag] if args.tag else None, dir=args.wandb_dir, name=args.name) 163 | wandb.config.update(vars(args)) 164 | build_and_train(game=args.game, 165 | cuda_idx=args.cuda_idx, 166 | args=args) 167 | 168 | -------------------------------------------------------------------------------- /Atari/src/rlpyt_atari_env.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # This code is borrowed from https://github.com/mila-iqia 3 | # -------------------------------------------------------- 4 | 5 | import numpy as np 6 | import os 7 | import atari_py 8 | import cv2 9 | from collections import namedtuple 10 | from gym.utils import seeding 11 | 12 | from rlpyt.envs.base import Env, EnvStep 13 | from rlpyt.spaces.int_box import IntBox 14 | from rlpyt.utils.quick_args import save__init__args 15 | from rlpyt.samplers.collections import TrajInfo 16 | 17 | 18 | EnvInfo = namedtuple("EnvInfo", ["game_score", "traj_done"]) 19 | 20 | 21 | class 
AtariTrajInfo(TrajInfo): 22 | """TrajInfo class for use with Atari Env, to store raw game score separate 23 | from clipped reward signal.""" 24 | 25 | def __init__(self, **kwargs): 26 | super().__init__(**kwargs) 27 | self.GameScore = 0 28 | 29 | def step(self, observation, action, reward, done, agent_info, env_info): 30 | super().step(observation, action, reward, done, agent_info, env_info) 31 | self.GameScore += getattr(env_info, "game_score", 0) 32 | 33 | 34 | class AtariEnv(Env): 35 | """An efficient implementation of the classic Atari RL envrionment using the 36 | Arcade Learning Environment (ALE). 37 | 38 | Output `env_info` includes: 39 | * `game_score`: raw game score, separate from reward clipping. 40 | * `traj_done`: special signal which signals game-over or timeout, so that sampler doesn't reset the environment when ``done==True`` but ``traj_done==False``, which can happen when ``episodic_lives==True``. 41 | 42 | Always performs 2-frame max to avoid flickering (this is pretty fast). 43 | 44 | Screen size downsampling is done by cropping two rows and then 45 | downsampling by 2x using `cv2`: (210, 160) --> (80, 104). Downsampling by 46 | 2x is much faster than the old scheme to (84, 84), and the (80, 104) shape 47 | is fairly convenient for convolution filter parameters which don't cut off 48 | edges. 49 | 50 | The action space is an `IntBox` for the number of actions. The observation 51 | space is an `IntBox` with ``dtype=uint8`` to save memory; conversion to float 52 | should happen inside the agent's model's ``forward()`` method. 53 | 54 | (See the file for implementation details.) 55 | 56 | 57 | Args: 58 | game (str): game name 59 | frame_skip (int): frames per step (>=1) 60 | num_img_obs (int): number of frames in observation (>=1) 61 | clip_reward (bool): if ``True``, clip reward to np.sign(reward) 62 | episodic_lives (bool): if ``True``, output ``done=True`` but ``env_info[traj_done]=False`` when a life is lost 63 | max_start_noops (int): upper limit for random number of noop actions after reset 64 | repeat_action_probability (0-1): probability for sticky actions 65 | horizon (int): max number of steps before timeout / ``traj_done=True`` 66 | """ 67 | 68 | def __init__(self, 69 | game="pong", 70 | frame_skip=4, # Frames per step (>=1). 71 | num_img_obs=4, # Number of (past) frames in observation (>=1). 
72 | clip_reward=True, 73 | episodic_lives=True, 74 | max_start_noops=30, 75 | repeat_action_probability=0., 76 | horizon=27000, 77 | stack_actions=0, 78 | grayscale=True, 79 | imagesize=84, 80 | seed=42, 81 | id=0, 82 | ): 83 | save__init__args(locals(), underscore=True) 84 | # ALE 85 | game_path = atari_py.get_game_path(game) 86 | if not os.path.exists(game_path): 87 | raise IOError("You asked for game {} but path {} does not " 88 | " exist".format(game, game_path)) 89 | self.ale = atari_py.ALEInterface() 90 | self.seed(seed, id) 91 | self.ale.setFloat(b'repeat_action_probability', repeat_action_probability) 92 | self.ale.loadROM(game_path) 93 | 94 | # Spaces 95 | self.stack_actions = stack_actions 96 | self._action_set = self.ale.getMinimalActionSet() 97 | self._action_space = IntBox(low=0, high=len(self._action_set)) 98 | self.channels = 1 if grayscale else 3 99 | self.grayscale = grayscale 100 | self.imagesize = imagesize 101 | if self.stack_actions: self.channels += 1 102 | obs_shape = (num_img_obs, self.channels, imagesize, imagesize) 103 | self._observation_space = IntBox(low=0, high=255, shape=obs_shape, 104 | dtype="uint8") 105 | self._max_frame = self.ale.getScreenGrayscale() if self.grayscale \ 106 | else self.ale.getScreenRGB() 107 | self._raw_frame_1 = self._max_frame.copy() 108 | self._raw_frame_2 = self._max_frame.copy() 109 | self._obs = np.zeros(shape=obs_shape, dtype="uint8") 110 | 111 | # Settings 112 | self._has_fire = "FIRE" in self.get_action_meanings() 113 | self._has_up = "UP" in self.get_action_meanings() 114 | self._horizon = int(horizon) 115 | self.reset() 116 | 117 | def seed(self, seed=None, id=0): 118 | _, seed1 = seeding.np_random(seed) 119 | if id > 0: 120 | seed = seed*100 + id 121 | self.np_random, _ = seeding.np_random(seed) 122 | # Derive a random seed. This gets passed as a uint, but gets 123 | # checked as an int elsewhere, so we need to keep it below 124 | # 2**31. 125 | seed2 = seeding.hash_seed(seed1 + 1) % 2**31 126 | # Empirically, we need to seed before loading the ROM. 127 | self.ale.setInt(b'random_seed', seed2) 128 | 129 | def reset(self): 130 | """Performs hard reset of ALE game.""" 131 | self.ale.reset_game() 132 | self._reset_obs() 133 | self._life_reset() 134 | if self._max_start_noops > 0: 135 | for _ in range(self.np_random.randint(1, self._max_start_noops + 1)): 136 | self.ale.act(0) 137 | if self._check_life(): 138 | self.reset() 139 | self._update_obs(0) # (don't bother to populate any frame history) 140 | self._step_counter = 0 141 | return self.get_obs() 142 | 143 | def step(self, action): 144 | a = self._action_set[action] 145 | game_score = np.array(0., dtype="float32") 146 | for _ in range(self._frame_skip - 1): 147 | game_score += self.ale.act(a) 148 | self._get_screen(1) 149 | game_score += self.ale.act(a) 150 | lost_life = self._check_life() # Advances from lost_life state. 151 | if lost_life and self._episodic_lives: 152 | self._reset_obs() # Internal reset. 
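# With episodic_lives=True, losing a life clears the frame stack and reports done=True,
# while traj_done (set from game_over below) stays False, so the sampler does not
# hard-reset the underlying ALE game between lives.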
153 | self._update_obs(action) 154 | reward = np.sign(game_score) if self._clip_reward else game_score 155 | game_over = self.ale.game_over() or self._step_counter >= self.horizon 156 | done = game_over or (self._episodic_lives and lost_life) 157 | info = EnvInfo(game_score=game_score, traj_done=game_over) 158 | self._step_counter += 1 159 | return EnvStep(self.get_obs(), reward, done, info) 160 | 161 | def render(self, wait=10, show_full_obs=False): 162 | """Shows game screen via cv2, with option to show all frames in observation.""" 163 | img = self.get_obs() 164 | if show_full_obs: 165 | shape = img.shape 166 | img = img.reshape(shape[0] * shape[1], shape[2]) 167 | else: 168 | img = img[-1] 169 | cv2.imshow(self._game, img) 170 | cv2.waitKey(wait) 171 | 172 | def get_obs(self): 173 | return self._obs.copy() 174 | 175 | ########################################################################### 176 | # Helpers 177 | 178 | def _get_screen(self, frame=1): 179 | frame = self._raw_frame_1 if frame == 1 else self._raw_frame_2 180 | if self.grayscale: 181 | self.ale.getScreenGrayscale(frame) 182 | else: 183 | self.ale.getScreenRGB(frame) 184 | 185 | def _update_obs(self, action): 186 | """Max of last two frames; crop two rows; downsample by 2x.""" 187 | self._get_screen(2) 188 | np.maximum(self._raw_frame_1, self._raw_frame_2, self._max_frame) 189 | img = cv2.resize(self._max_frame, (self.imagesize, self.imagesize), cv2.INTER_LINEAR) 190 | if len(img.shape) == 2: 191 | img = img[np.newaxis] 192 | else: 193 | img = np.transpose(img, (2, 0, 1)) 194 | if self.stack_actions: 195 | action = int(255.*action/self._action_space.n) 196 | action = np.ones_like(img[:1])*action 197 | img = np.concatenate([img, action], 0) 198 | # NOTE: order OLDEST to NEWEST should match use in frame-wise buffer. 
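# Rolling frame stack: drop the oldest of the num_img_obs stored frames and append the
# newly maxed/resized frame, so self._obs keeps shape (num_img_obs, channels, H, W),
# e.g. (4, 1, 84, 84) with the default grayscale 84x84 settings and framestack 4.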
199 | self._obs = np.concatenate([self._obs[1:], img[np.newaxis]]) 200 | 201 | def _reset_obs(self): 202 | self._obs[:] = 0 203 | self._max_frame[:] = 0 204 | self._raw_frame_1[:] = 0 205 | self._raw_frame_2[:] = 0 206 | 207 | def _check_life(self): 208 | lives = self.ale.lives() 209 | lost_life = (lives < self._lives) and (lives > 0) 210 | if lost_life: 211 | self._life_reset() 212 | return lost_life 213 | 214 | def _life_reset(self): 215 | self.ale.act(0) 216 | self._lives = self.ale.lives() 217 | 218 | ########################################################################### 219 | # Properties 220 | 221 | @property 222 | def game(self): 223 | return self._game 224 | 225 | @property 226 | def frame_skip(self): 227 | return self._frame_skip 228 | 229 | @property 230 | def num_img_obs(self): 231 | return self._num_img_obs 232 | 233 | @property 234 | def clip_reward(self): 235 | return self._clip_reward 236 | 237 | @property 238 | def max_start_noops(self): 239 | return self._max_start_noops 240 | 241 | @property 242 | def episodic_lives(self): 243 | return self._episodic_lives 244 | 245 | @property 246 | def repeat_action_probability(self): 247 | return self._repeat_action_probability 248 | 249 | @property 250 | def horizon(self): 251 | return self._horizon 252 | 253 | def get_action_meanings(self): 254 | return [ACTION_MEANING[i] for i in self._action_set] 255 | 256 | 257 | ACTION_MEANING = { 258 | 0: "NOOP", 259 | 1: "FIRE", 260 | 2: "UP", 261 | 3: "RIGHT", 262 | 4: "LEFT", 263 | 5: "DOWN", 264 | 6: "UPRIGHT", 265 | 7: "UPLEFT", 266 | 8: "DOWNRIGHT", 267 | 9: "DOWNLEFT", 268 | 10: "UPFIRE", 269 | 11: "RIGHTFIRE", 270 | 12: "LEFTFIRE", 271 | 13: "DOWNFIRE", 272 | 14: "UPRIGHTFIRE", 273 | 15: "UPLEFTFIRE", 274 | 16: "DOWNRIGHTFIRE", 275 | 17: "DOWNLEFTFIRE", 276 | } 277 | 278 | ACTION_INDEX = {v: k for k, v in ACTION_MEANING.items()} 279 | -------------------------------------------------------------------------------- /Atari/src/masking_generator.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Copyright (c) 2022 Microsoft 3 | # Licensed under The MIT License 4 | # -------------------------------------------------------- 5 | 6 | import random 7 | import math 8 | import numpy as np 9 | from numpy.core.shape_base import block 10 | import torch 11 | import torch.nn as nn 12 | 13 | 14 | class MaskingGenerator: 15 | def __init__( 16 | self, input_size, num_masking_patches, min_num_patches=4, max_num_patches=None, 17 | min_aspect=0.3, max_aspect=None): 18 | if not isinstance(input_size, tuple): 19 | input_size = (input_size, ) * 2 20 | self.height, self.width = input_size 21 | 22 | self.num_patches = self.height * self.width 23 | self.num_masking_patches = num_masking_patches 24 | 25 | self.min_num_patches = min_num_patches 26 | self.max_num_patches = num_masking_patches if max_num_patches is None else max_num_patches 27 | 28 | max_aspect = max_aspect or 1 / min_aspect 29 | self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) 30 | 31 | def __repr__(self): 32 | repr_str = "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % ( 33 | self.height, self.width, self.min_num_patches, self.max_num_patches, 34 | self.num_masking_patches, self.log_aspect_ratio[0], self.log_aspect_ratio[1]) 35 | return repr_str 36 | 37 | def get_shape(self): 38 | return self.height, self.width 39 | 40 | def _mask(self, mask, max_mask_patches): 41 | delta = 0 42 | for attempt in range(10): 43 | 
target_area = random.uniform(self.min_num_patches, max_mask_patches) 44 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) 45 | h = int(round(math.sqrt(target_area * aspect_ratio))) 46 | w = int(round(math.sqrt(target_area / aspect_ratio))) 47 | if w < self.width and h < self.height: 48 | top = random.randint(0, self.height - h) 49 | left = random.randint(0, self.width - w) 50 | 51 | num_masked = mask[top: top + h, left: left + w].sum() 52 | # Overlap 53 | if 0 < h * w - num_masked <= max_mask_patches: 54 | for i in range(top, top + h): 55 | for j in range(left, left + w): 56 | if mask[i, j] == 0: 57 | mask[i, j] = 1 58 | delta += 1 59 | 60 | if delta > 0: 61 | break 62 | return delta 63 | 64 | def __call__(self): 65 | mask = np.zeros(shape=self.get_shape(), dtype=np.int) 66 | mask_count = 0 67 | while mask_count < self.num_masking_patches: 68 | max_mask_patches = self.num_masking_patches - mask_count 69 | max_mask_patches = min(max_mask_patches, self.max_num_patches) 70 | 71 | delta = self._mask(mask, max_mask_patches) 72 | if delta == 0: 73 | break 74 | else: 75 | mask_count += delta 76 | 77 | return mask 78 | 79 | class RandomMaskingGenerator: 80 | def __init__(self, input_size, mask_ratio): 81 | if not isinstance(input_size, tuple): 82 | input_size = (input_size,) * 2 83 | 84 | self.height, self.width = input_size 85 | 86 | self.num_patches = self.height * self.width 87 | self.num_mask = int(mask_ratio * self.num_patches) 88 | 89 | def __repr__(self): 90 | repr_str = "Maks: total patches {}, mask patches {}".format( 91 | self.num_patches, self.num_mask 92 | ) 93 | return repr_str 94 | 95 | def __call__(self): 96 | mask = np.hstack([ 97 | np.zeros(self.num_patches - self.num_mask), 98 | np.ones(self.num_mask), 99 | ]) 100 | np.random.shuffle(mask) 101 | return mask.astype(np.bool) # [196] 102 | 103 | class RandomMaskingMapGenerator: 104 | def __init__(self, input_size, mask_ratio, image_size): 105 | if not isinstance(input_size, tuple): 106 | input_size = (input_size,) * 2 107 | 108 | self.height, self.width = input_size 109 | 110 | self.num_patches = self.height * self.width 111 | self.num_mask = int(mask_ratio * self.num_patches) 112 | 113 | self.image_size = image_size 114 | self.upsampler = nn.Upsample((image_size, image_size)) 115 | 116 | def __repr__(self): 117 | repr_str = "Maks: total patches {}, mask patches {}".format( 118 | self.num_patches, self.num_mask 119 | ) 120 | return repr_str 121 | 122 | def __call__(self): 123 | mask = np.hstack([ 124 | np.zeros(self.num_patches - self.num_mask), 125 | np.ones(self.num_mask), 126 | ]) 127 | np.random.shuffle(mask) 128 | mask = torch.from_numpy(mask).reshape(self.height, self.width) 129 | mask = self.upsampler(mask[None, None].float()) 130 | return mask # [196] 131 | 132 | class DiverseRandomMaskingMapGenerator: 133 | def __init__(self, input_size, mask_ratio, image_size, time_span): 134 | if not isinstance(input_size, tuple): 135 | input_size = (input_size,) * 2 136 | 137 | self.height, self.width = input_size 138 | 139 | self.num_patches = self.height * self.width 140 | self.num_mask = int(mask_ratio * self.num_patches) 141 | 142 | self.image_size = image_size 143 | self.upsampler = nn.Upsample((image_size, image_size)) 144 | 145 | self.time_span = time_span 146 | 147 | def __repr__(self): 148 | repr_str = "Maks: total patches {}, mask patches {}".format( 149 | self.num_patches, self.num_mask 150 | ) 151 | return repr_str 152 | 153 | def __call__(self): 154 | mask = np.hstack([ 155 | np.zeros(self.num_patches - 
self.num_mask), 156 | np.ones(self.num_mask), 157 | ]) 158 | mask = np.vstack([mask] * self.time_span) 159 | 160 | for i in range(self.time_span): 161 | mask[i] = mask[i, torch.randperm(mask.shape[-1])] 162 | 163 | mask = torch.from_numpy(mask)[:, :, None].reshape(self.time_span, self.height, self.width) 164 | mask = self.upsampler(mask[None].float()) 165 | return mask 166 | 167 | class RandomMaskingListGenerator: 168 | def __init__(self, list_len, mask_ratio): 169 | self.num_patches = list_len 170 | self.num_mask = int(mask_ratio * self.num_patches) 171 | 172 | def __repr__(self): 173 | repr_str = "Maks: total patches {}, mask patches {}".format( 174 | self.num_patches, self.num_mask 175 | ) 176 | return repr_str 177 | 178 | def __call__(self): 179 | mask = np.hstack([ 180 | np.zeros(self.num_patches - self.num_mask), 181 | np.ones(self.num_mask), 182 | ]) 183 | np.random.shuffle(mask) 184 | return mask.astype(np.bool) 185 | # return torch.from_numpy(mask).float() # [196] 186 | 187 | 188 | class RestMaskingListGenerator: 189 | def __init__(self, list_len, mask_ratio): 190 | self.num_patches = list_len 191 | self.num_mask = int(mask_ratio * self.num_patches) 192 | 193 | def __repr__(self): 194 | repr_str = "Maks: total patches {}, mask patches {}".format( 195 | self.num_patches, self.num_mask 196 | ) 197 | return repr_str 198 | 199 | def __call__(self): 200 | mask = np.hstack([ 201 | np.zeros(self.num_patches - self.num_mask), 202 | np.ones(self.num_mask), 203 | ]) 204 | # np.random.shuffle(mask) 205 | return mask.astype(np.bool) 206 | # return torch.from_numpy(mask).float() # [196] 207 | 208 | class RandomRestMaskingListGenerator: 209 | ''' 210 | t > offset, set to True; t < offset, set to False 211 | ''' 212 | def __init__(self, list_len): 213 | self.num_patches = list_len 214 | 215 | def __repr__(self): 216 | repr_str = "Maks: total patches {}".format( 217 | self.num_patches 218 | ) 219 | return repr_str 220 | 221 | def __call__(self): 222 | offset = np.random.randint(1, self.num_patches) 223 | mask = np.hstack([ 224 | np.zeros(offset), 225 | np.ones(self.num_patches - offset), 226 | ]) 227 | # np.random.shuffle(mask) 228 | return mask.astype(np.bool) 229 | # return torch.from_numpy(mask).float() # [196] 230 | 231 | class RandomBlockMaskingListGenerator: 232 | def __init__(self, list_len, mask_ratio, block_size): 233 | assert list_len % block_size == 0 234 | self.list_len = list_len 235 | self.block_size = block_size 236 | self.num_blocks = list_len // block_size 237 | 238 | self.num_mask = int(mask_ratio * self.num_blocks) 239 | 240 | def __call__(self): 241 | mask = np.hstack([ 242 | np.zeros(self.num_blocks - self.num_mask), 243 | np.ones(self.num_mask), 244 | ]) 245 | np.random.shuffle(mask) 246 | mask = np.repeat(mask, self.block_size) 247 | return mask.astype(np.bool) 248 | 249 | class RandomMaskGenerator: 250 | def __init__(self, list_len, block_size, mask_ratio): 251 | assert mask_ratio <= 1.0 252 | self.block_size = block_size 253 | self.num_blocks = list_len // block_size 254 | self.num_masked_blocks = int(mask_ratio * self.num_blocks) 255 | self.num_rest = list_len % block_size 256 | 257 | def __call__(self): 258 | mask = np.hstack([ 259 | np.zeros(self.num_blocks - self.num_masked_blocks), 260 | np.ones(self.num_masked_blocks), 261 | ]) 262 | np.random.shuffle(mask) 263 | mask = np.repeat(mask, self.block_size) 264 | mask = np.concatenate([mask, np.zeros(self.num_rest)], 0) 265 | mask = np.roll(mask, np.random.randint(0, self.block_size)) 266 | # mask[0] = 0 # set the first 
elsement always to unmasked 267 | return mask.astype(np.bool) 268 | 269 | class CubeMaskGenerator: 270 | def __init__(self, input_size, image_size, clip_size, block_size, mask_ratio): 271 | assert mask_ratio <= 1.0 272 | 273 | if not isinstance(input_size, tuple): 274 | input_size = (input_size,) * 2 275 | self.height, self.width = input_size 276 | self.num_patches = self.height * self.width 277 | self.num_mask = int(mask_ratio * self.num_patches) 278 | self.image_size = image_size 279 | self.upsampler = nn.Upsample((image_size, image_size)) 280 | 281 | self.block_size = block_size 282 | self.num_blocks = clip_size // block_size 283 | 284 | 285 | def __call__(self): 286 | mask = np.hstack([ 287 | np.zeros(self.num_patches - self.num_mask), 288 | np.ones(self.num_mask), 289 | ]) 290 | for i in range(self.num_blocks): 291 | np.random.shuffle(mask) 292 | cur_mask = torch.from_numpy(mask).reshape(self.height, self.width) 293 | cur_mask = self.upsampler(cur_mask[None, None].float()) # (1, 1, h, w) 294 | cur_mask = cur_mask.expand(self.block_size, *cur_mask.size()[1:]) 295 | cube_mask = torch.cat([cube_mask, cur_mask]) if i > 0 else cur_mask 296 | return cube_mask 297 | 298 | if __name__ == '__main__': 299 | masker = DiverseRandomMaskingMapGenerator(4, 0.5, 8, 2) 300 | print(masker()) -------------------------------------------------------------------------------- /DMControl/src/masking_generator.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Copyright (c) 2022 Microsoft 3 | # Licensed under The MIT License 4 | # -------------------------------------------------------- 5 | 6 | import random 7 | import math 8 | import numpy as np 9 | from numpy.core.shape_base import block 10 | import torch 11 | import torch.nn as nn 12 | 13 | 14 | class MaskingGenerator: 15 | ''' Borrowed from https://github.com/pengzhiliang/MAE-pytorch ''' 16 | def __init__( 17 | self, input_size, num_masking_patches, min_num_patches=4, max_num_patches=None, 18 | min_aspect=0.3, max_aspect=None): 19 | if not isinstance(input_size, tuple): 20 | input_size = (input_size, ) * 2 21 | self.height, self.width = input_size 22 | 23 | self.num_patches = self.height * self.width 24 | self.num_masking_patches = num_masking_patches 25 | 26 | self.min_num_patches = min_num_patches 27 | self.max_num_patches = num_masking_patches if max_num_patches is None else max_num_patches 28 | 29 | max_aspect = max_aspect or 1 / min_aspect 30 | self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) 31 | 32 | def __repr__(self): 33 | repr_str = "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % ( 34 | self.height, self.width, self.min_num_patches, self.max_num_patches, 35 | self.num_masking_patches, self.log_aspect_ratio[0], self.log_aspect_ratio[1]) 36 | return repr_str 37 | 38 | def get_shape(self): 39 | return self.height, self.width 40 | 41 | def _mask(self, mask, max_mask_patches): 42 | delta = 0 43 | for attempt in range(10): 44 | target_area = random.uniform(self.min_num_patches, max_mask_patches) 45 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) 46 | h = int(round(math.sqrt(target_area * aspect_ratio))) 47 | w = int(round(math.sqrt(target_area / aspect_ratio))) 48 | if w < self.width and h < self.height: 49 | top = random.randint(0, self.height - h) 50 | left = random.randint(0, self.width - w) 51 | 52 | num_masked = mask[top: top + h, left: left + w].sum() 53 | # Overlap 54 | if 0 < h * w - num_masked 
<= max_mask_patches: 55 | for i in range(top, top + h): 56 | for j in range(left, left + w): 57 | if mask[i, j] == 0: 58 | mask[i, j] = 1 59 | delta += 1 60 | 61 | if delta > 0: 62 | break 63 | return delta 64 | 65 | def __call__(self): 66 | mask = np.zeros(shape=self.get_shape(), dtype=np.int) 67 | mask_count = 0 68 | while mask_count < self.num_masking_patches: 69 | max_mask_patches = self.num_masking_patches - mask_count 70 | max_mask_patches = min(max_mask_patches, self.max_num_patches) 71 | 72 | delta = self._mask(mask, max_mask_patches) 73 | if delta == 0: 74 | break 75 | else: 76 | mask_count += delta 77 | 78 | return mask 79 | 80 | class RandomMaskingGenerator: 81 | ''' Borrowed from https://github.com/pengzhiliang/MAE-pytorch ''' 82 | def __init__(self, input_size, mask_ratio): 83 | if not isinstance(input_size, tuple): 84 | input_size = (input_size,) * 2 85 | 86 | self.height, self.width = input_size 87 | 88 | self.num_patches = self.height * self.width 89 | self.num_mask = int(mask_ratio * self.num_patches) 90 | 91 | def __repr__(self): 92 | repr_str = "Maks: total patches {}, mask patches {}".format( 93 | self.num_patches, self.num_mask 94 | ) 95 | return repr_str 96 | 97 | def __call__(self): 98 | mask = np.hstack([ 99 | np.zeros(self.num_patches - self.num_mask), 100 | np.ones(self.num_mask), 101 | ]) 102 | np.random.shuffle(mask) 103 | return mask.astype(np.bool) # [196] 104 | 105 | class RandomMaskingMapGenerator: 106 | def __init__(self, input_size, mask_ratio, image_size): 107 | if not isinstance(input_size, tuple): 108 | input_size = (input_size,) * 2 109 | 110 | self.height, self.width = input_size 111 | 112 | self.num_patches = self.height * self.width 113 | self.num_mask = int(mask_ratio * self.num_patches) 114 | 115 | self.image_size = image_size 116 | self.upsampler = nn.Upsample((image_size, image_size)) 117 | 118 | def __repr__(self): 119 | repr_str = "Maks: total patches {}, mask patches {}".format( 120 | self.num_patches, self.num_mask 121 | ) 122 | return repr_str 123 | 124 | def __call__(self): 125 | mask = np.hstack([ 126 | np.zeros(self.num_patches - self.num_mask), 127 | np.ones(self.num_mask), 128 | ]) 129 | np.random.shuffle(mask) 130 | mask = torch.from_numpy(mask).reshape(self.height, self.width) 131 | mask = self.upsampler(mask[None, None].float()) 132 | return mask # [196] 133 | 134 | class DiverseRandomMaskingMapGenerator: 135 | def __init__(self, input_size, mask_ratio, image_size, time_span): 136 | if not isinstance(input_size, tuple): 137 | input_size = (input_size,) * 2 138 | 139 | self.height, self.width = input_size 140 | 141 | self.num_patches = self.height * self.width 142 | self.num_mask = int(mask_ratio * self.num_patches) 143 | 144 | self.image_size = image_size 145 | self.upsampler = nn.Upsample((image_size, image_size)) 146 | 147 | self.time_span = time_span 148 | 149 | def __repr__(self): 150 | repr_str = "Maks: total patches {}, mask patches {}".format( 151 | self.num_patches, self.num_mask 152 | ) 153 | return repr_str 154 | 155 | def __call__(self): 156 | mask = np.hstack([ 157 | np.zeros(self.num_patches - self.num_mask), 158 | np.ones(self.num_mask), 159 | ]) 160 | mask = np.vstack([mask] * self.time_span) 161 | 162 | for i in range(self.time_span): 163 | mask[i] = mask[i, torch.randperm(mask.shape[-1])] 164 | 165 | mask = torch.from_numpy(mask)[:, :, None].reshape(self.time_span, self.height, self.width) 166 | mask = self.upsampler(mask[None].float()) 167 | return mask 168 | 169 | class RandomMaskingListGenerator: 170 | def 
__init__(self, list_len, mask_ratio): 171 | self.num_patches = list_len 172 | self.num_mask = int(mask_ratio * self.num_patches) 173 | 174 | def __repr__(self): 175 | repr_str = "Maks: total patches {}, mask patches {}".format( 176 | self.num_patches, self.num_mask 177 | ) 178 | return repr_str 179 | 180 | def __call__(self): 181 | mask = np.hstack([ 182 | np.zeros(self.num_patches - self.num_mask), 183 | np.ones(self.num_mask), 184 | ]) 185 | np.random.shuffle(mask) 186 | return mask.astype(np.bool) 187 | # return torch.from_numpy(mask).float() # [196] 188 | 189 | 190 | class RestMaskingListGenerator: 191 | def __init__(self, list_len, mask_ratio): 192 | self.num_patches = list_len 193 | self.num_mask = int(mask_ratio * self.num_patches) 194 | 195 | def __repr__(self): 196 | repr_str = "Maks: total patches {}, mask patches {}".format( 197 | self.num_patches, self.num_mask 198 | ) 199 | return repr_str 200 | 201 | def __call__(self): 202 | mask = np.hstack([ 203 | np.zeros(self.num_patches - self.num_mask), 204 | np.ones(self.num_mask), 205 | ]) 206 | # np.random.shuffle(mask) 207 | return mask.astype(np.bool) 208 | # return torch.from_numpy(mask).float() # [196] 209 | 210 | class RandomRestMaskingListGenerator: 211 | ''' 212 | t > offset, set to True; t < offset, set to False 213 | ''' 214 | def __init__(self, list_len): 215 | self.num_patches = list_len 216 | 217 | def __repr__(self): 218 | repr_str = "Maks: total patches {}".format( 219 | self.num_patches 220 | ) 221 | return repr_str 222 | 223 | def __call__(self): 224 | offset = np.random.randint(1, self.num_patches) 225 | mask = np.hstack([ 226 | np.zeros(offset), 227 | np.ones(self.num_patches - offset), 228 | ]) 229 | # np.random.shuffle(mask) 230 | return mask.astype(np.bool) 231 | # return torch.from_numpy(mask).float() # [196] 232 | 233 | class RandomBlockMaskingListGenerator: 234 | def __init__(self, list_len, mask_ratio, block_size): 235 | assert list_len % block_size == 0 236 | self.list_len = list_len 237 | self.block_size = block_size 238 | self.num_blocks = list_len // block_size 239 | 240 | self.num_mask = int(mask_ratio * self.num_blocks) 241 | 242 | def __call__(self): 243 | mask = np.hstack([ 244 | np.zeros(self.num_blocks - self.num_mask), 245 | np.ones(self.num_mask), 246 | ]) 247 | np.random.shuffle(mask) 248 | mask = np.repeat(mask, self.block_size) 249 | return mask.astype(np.bool) 250 | 251 | class RandomMaskGenerator: 252 | def __init__(self, list_len, block_size, mask_ratio): 253 | assert mask_ratio <= 1.0 254 | self.block_size = block_size 255 | self.num_blocks = list_len // block_size 256 | self.num_masked_blocks = int(mask_ratio * self.num_blocks) 257 | self.num_rest = list_len % block_size 258 | 259 | def __call__(self): 260 | mask = np.hstack([ 261 | np.zeros(self.num_blocks - self.num_masked_blocks), 262 | np.ones(self.num_masked_blocks), 263 | ]) 264 | np.random.shuffle(mask) 265 | mask = np.repeat(mask, self.block_size) 266 | mask = np.concatenate([mask, np.zeros(self.num_rest)], 0) 267 | mask = np.roll(mask, np.random.randint(0, self.block_size)) 268 | # mask[0] = 0 # set the first elsement always to unmasked 269 | return mask.astype(np.bool) 270 | 271 | class CubeMaskGenerator: 272 | def __init__(self, input_size, image_size, clip_size, block_size, mask_ratio): 273 | assert mask_ratio <= 1.0 274 | 275 | if not isinstance(input_size, tuple): 276 | input_size = (input_size,) * 2 277 | self.height, self.width = input_size 278 | self.num_patches = self.height * self.width 279 | self.num_mask = 
int(mask_ratio * self.num_patches) 280 | self.image_size = image_size 281 | self.upsampler = nn.Upsample((image_size, image_size)) 282 | 283 | self.block_size = block_size 284 | self.num_blocks = clip_size // block_size 285 | 286 | 287 | def __call__(self): 288 | mask = np.hstack([ 289 | np.zeros(self.num_patches - self.num_mask), 290 | np.ones(self.num_mask), 291 | ]) 292 | for i in range(self.num_blocks): 293 | np.random.shuffle(mask) 294 | cur_mask = torch.from_numpy(mask).reshape(self.height, self.width) 295 | cur_mask = self.upsampler(cur_mask[None, None].float()) # (1, 1, h, w) 296 | cur_mask = cur_mask.expand(self.block_size, *cur_mask.size()[1:]) 297 | cube_mask = torch.cat([cube_mask, cur_mask]) if i > 0 else cur_mask 298 | return cube_mask 299 | 300 | if __name__ == '__main__': 301 | masker = DiverseRandomMaskingMapGenerator(4, 0.5, 8, 2) 302 | print(masker()) -------------------------------------------------------------------------------- /DMControl/src/encoder.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # This code is borrowed from https://github.com/MishaLaskin/curl 3 | # -------------------------------------------------------- 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.nn.modules import module 9 | 10 | from vit_modules import * 11 | from masking_generator import RandomMaskingGenerator 12 | 13 | def tie_weights(src, trg): 14 | assert type(src) == type(trg) 15 | trg.weight = src.weight 16 | trg.bias = src.bias 17 | 18 | 19 | # for 84 x 84 inputs 20 | OUT_DIM = {2: 39, 4: 35, 6: 31} 21 | # for 64 x 64 inputs 22 | OUT_DIM_64 = {2: 29, 4: 25, 6: 21} 23 | 24 | 25 | class PixelEncoder(nn.Module): 26 | """Convolutional encoder of pixels observations.""" 27 | def __init__(self, 28 | obs_shape, 29 | feature_dim, 30 | num_layers=2, 31 | num_filters=32, 32 | output_logits=False): 33 | super().__init__() 34 | 35 | assert len(obs_shape) == 3 36 | self.obs_shape = obs_shape 37 | self.feature_dim = feature_dim 38 | self.num_layers = num_layers 39 | 40 | self.convs = nn.ModuleList( 41 | [nn.Conv2d(obs_shape[0], num_filters, 3, stride=2)]) 42 | for i in range(num_layers - 1): 43 | self.convs.append(nn.Conv2d(num_filters, num_filters, 3, stride=1)) 44 | 45 | out_dim = OUT_DIM_64[num_layers] if obs_shape[-1] == 64 else OUT_DIM[ 46 | num_layers] 47 | self.fc = nn.Linear(num_filters * out_dim * out_dim, self.feature_dim) 48 | self.ln = nn.LayerNorm(self.feature_dim) 49 | 50 | self.outputs = dict() 51 | self.output_logits = output_logits 52 | 53 | def reparameterize(self, mu, logstd): 54 | std = torch.exp(logstd) 55 | eps = torch.randn_like(std) 56 | return mu + eps * std 57 | 58 | def forward_conv(self, obs, flatten=True): 59 | obs = obs / 255. 
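        # Pixel observations come in as frames with values in [0, 255]; they are
        # scaled to [0, 1] before the conv stack. Intermediate activations are
        # cached in self.outputs so log() can emit per-layer histograms/images.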
60 | self.outputs['obs'] = obs 61 | 62 | conv = torch.relu(self.convs[0](obs)) 63 | self.outputs['conv1'] = conv 64 | 65 | for i in range(1, self.num_layers): 66 | conv = torch.relu(self.convs[i](conv)) 67 | self.outputs['conv%s' % (i + 1)] = conv 68 | 69 | h = conv.view(conv.size(0), -1) if flatten else conv 70 | return h 71 | 72 | def forward(self, obs, detach=False): 73 | h = self.forward_conv(obs) 74 | 75 | if detach: 76 | h = h.detach() 77 | 78 | h_fc = self.fc(h) 79 | self.outputs['fc'] = h_fc 80 | 81 | h_norm = self.ln(h_fc) 82 | self.outputs['ln'] = h_norm 83 | 84 | if self.output_logits: 85 | out = h_norm 86 | else: 87 | out = torch.tanh(h_norm) 88 | self.outputs['tanh'] = out 89 | 90 | return out 91 | 92 | def copy_conv_weights_from(self, source): 93 | """Tie convolutional layers""" 94 | # only tie conv layers 95 | for i in range(self.num_layers): 96 | tie_weights(src=source.convs[i], trg=self.convs[i]) 97 | 98 | def log(self, L, step, log_freq): 99 | if step % log_freq != 0: 100 | return 101 | 102 | for k, v in self.outputs.items(): 103 | L.log_histogram('train_encoder/%s_hist' % k, v, step) 104 | if len(v.shape) > 2: 105 | L.log_image('train_encoder/%s_img' % k, v[0], step) 106 | 107 | for i in range(self.num_layers): 108 | L.log_param('train_encoder/conv%s' % (i + 1), self.convs[i], step) 109 | L.log_param('train_encoder/fc', self.fc, step) 110 | L.log_param('train_encoder/ln', self.ln, step) 111 | 112 | 113 | class IdentityEncoder(nn.Module): 114 | def __init__(self, obs_shape, feature_dim, num_layers, num_filters, *args): 115 | super().__init__() 116 | 117 | assert len(obs_shape) == 1 118 | self.feature_dim = obs_shape[0] 119 | 120 | def forward(self, obs, detach=False): 121 | return obs 122 | 123 | def copy_conv_weights_from(self, source): 124 | pass 125 | 126 | def log(self, L, step, log_freq): 127 | pass 128 | 129 | 130 | 131 | class PretrainVisionTransformerEncoder(nn.Module): 132 | """ Vision Transformer with support for patch or hybrid CNN input stage 133 | """ 134 | # def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12, 135 | # num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 136 | # drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, 137 | # use_learnable_pos_emb=False): 138 | def __init__(self, img_size=84, patch_size=7, in_chans=9, num_classes=0, embed_dim=441, depth=4, 139 | num_heads=8, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 140 | drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=0, 141 | use_learnable_pos_emb=False, feature_dim=50): 142 | super().__init__() 143 | self.num_classes = num_classes 144 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 145 | 146 | # embed_dim = 128 147 | # self.conv = nn.Conv2d(in_chans, embed_dim, patch_size, stride=patch_size) 148 | 149 | self.patch_embed = PatchEmbed( 150 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 151 | num_patches = self.patch_embed.num_patches 152 | 153 | # # TODO: Add the cls token 154 | # self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 155 | # num_patches += 1 156 | 157 | if use_learnable_pos_emb: 158 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 159 | else: 160 | # sine-cosine positional embeddings 161 | self.pos_embed = get_sinusoid_encoding_table(num_patches, embed_dim) 162 | 163 | dpr = [x.item() for x in torch.linspace(0, 
drop_path_rate, depth)] # stochastic depth decay rule 164 | self.blocks = nn.ModuleList([ 165 | Block( 166 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 167 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, 168 | init_values=init_values) 169 | for i in range(depth)]) 170 | self.norm = norm_layer(embed_dim) 171 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 172 | 173 | # self.state = nn.Sequential( 174 | # nn.Linear(embed_dim, feature_dim), 175 | # nn.LayerNorm(feature_dim) 176 | # ) 177 | 178 | if use_learnable_pos_emb: 179 | trunc_normal_(self.pos_embed, std=.02) 180 | 181 | # trunc_normal_(self.cls_token, std=.02) 182 | self.apply(self._init_weights) 183 | 184 | 185 | def _init_weights(self, m): 186 | if isinstance(m, nn.Linear): 187 | trunc_normal_(m.weight, std=.02) 188 | if isinstance(m, nn.Linear) and m.bias is not None: 189 | nn.init.constant_(m.bias, 0) 190 | elif isinstance(m, nn.LayerNorm): 191 | nn.init.constant_(m.bias, 0) 192 | nn.init.constant_(m.weight, 1.0) 193 | 194 | def get_num_layers(self): 195 | return len(self.blocks) 196 | 197 | @torch.jit.ignore 198 | def no_weight_decay(self): 199 | return {'pos_embed', 'cls_token'} 200 | 201 | def get_classifier(self): 202 | return self.head 203 | 204 | def reset_classifier(self, num_classes, global_pool=''): 205 | self.num_classes = num_classes 206 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 207 | 208 | def forward_features(self, x, mask): 209 | x = self.patch_embed(x) # (B, NumPatches*NumPatches, C=9*7*7) 210 | 211 | # cls_tokens = self.cls_token.expand(x.size(0), -1, -1) 212 | # x = torch.cat((cls_tokens, x), dim=1) 213 | 214 | x = x + self.pos_embed.type_as(x).to(x.device).clone().detach() 215 | 216 | B, _, C = x.shape 217 | # x_vis = x[~mask].reshape(B, -1, C) # ~mask means visible 218 | x_vis = x if mask is None else x[:, ~mask].reshape(B, -1, C) 219 | 220 | for blk in self.blocks: 221 | x_vis = blk(x_vis) 222 | 223 | x_vis = self.norm(x_vis) 224 | return x_vis 225 | 226 | def forward(self, x, mask=None): 227 | x = self.forward_features(x, mask) 228 | x = self.head(x) 229 | # x = x.mean(1) 230 | # # x = x[:, 0] 231 | # x = x.detach() if detach else x 232 | # x = self.state(x) 233 | return x 234 | 235 | class ViTEncoder(nn.Module): 236 | def __init__(self, img_size=84, patch_size=12, in_chans=9, num_classes=0, embed_dim=512, depth=4, 237 | num_heads=6, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 238 | drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=0, 239 | use_learnable_pos_emb=False, feature_dim=50): 240 | super().__init__() 241 | self.feature_dim = feature_dim 242 | 243 | self.vit = PretrainVisionTransformerEncoder( 244 | img_size, patch_size, in_chans, num_classes, embed_dim, depth, 245 | num_heads, mlp_ratio, qkv_bias, qk_scale, drop_rate, attn_drop_rate, 246 | drop_path_rate, norm_layer, init_values, use_learnable_pos_emb, feature_dim 247 | ) 248 | 249 | 250 | self.state = nn.Sequential( 251 | nn.Linear(embed_dim, feature_dim), 252 | nn.LayerNorm(feature_dim) 253 | ) 254 | 255 | def forward(self, x, mask=None, detach=False): 256 | x = x / 255. 
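        # Scale pixels to [0, 1]. When a boolean patch mask is supplied, the ViT
        # keeps only the visible tokens (x[:, ~mask] in forward_features); the
        # resulting token sequence is mean-pooled below into one feature vector.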
257 | x = self.vit(x, mask) 258 | x = x.mean(1) 259 | x = x.detach() if detach else x 260 | # x = self.state(x) 261 | return x 262 | 263 | def copy_conv_weights_from(self, source): 264 | """Tie convolutional layers""" 265 | # only tie conv layers 266 | # for (src_child, tgt_child) in zip(source.vit.children(), self.vit.children()): 267 | for (src_module, tgt_module) in zip(source.vit.modules(), self.vit.modules()): 268 | if isinstance(src_module, nn.Module): 269 | try: 270 | # print("Tie: ", src_module) 271 | tie_weights(src=src_module, trg=tgt_module) 272 | except: 273 | # print("Skip: ", src_module) 274 | pass 275 | # for i in range(self.num_layers): 276 | # tie_weights(src=source.convs[i], trg=self.convs[i]) 277 | 278 | def log(self, L, step, log_freq): 279 | pass 280 | 281 | _AVAILABLE_ENCODERS = { 282 | 'pixel': PixelEncoder, 283 | # 'pixel': ViTEncoder, 284 | 'identity': IdentityEncoder 285 | } 286 | 287 | def make_encoder(encoder_type, 288 | obs_shape, 289 | feature_dim, 290 | num_layers, 291 | num_filters, 292 | output_logits=False): 293 | assert encoder_type in _AVAILABLE_ENCODERS 294 | return _AVAILABLE_ENCODERS[encoder_type](obs_shape, feature_dim, 295 | num_layers, num_filters, 296 | output_logits) 297 | # return _AVAILABLE_ENCODERS[encoder_type]( 298 | # img_size=obs_shape[1], 299 | # patch_size=8, 300 | # embed_dim=128, 301 | # depth=4, 302 | # num_heads=8, 303 | # feature_dim=feature_dim) 304 | 305 | 306 | if __name__ == '__main__': 307 | vit_encoder = PretrainVisionTransformerEncoder() 308 | x = torch.randn(2, 9, 84, 84) 309 | masked_position_generator = RandomMaskingGenerator(input_size=12, mask_ratio=0) 310 | mask = masked_position_generator() # (input_size*input_size, ) 311 | num_valid_patch = mask.sum() 312 | inv_mask = ~(mask.astype(np.bool)) 313 | # mask = mask[None] 314 | f = vit_encoder(x, mask.astype(np.bool)) -------------------------------------------------------------------------------- /DMControl/src/base_sac.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Copyright (c) 2022 Microsoft 3 | # Licensed under The MIT License 4 | # -------------------------------------------------------- 5 | 6 | import copy 7 | import math 8 | 9 | import numpy as np 10 | from numpy.core.fromnumeric import shape 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | from torch.nn import Parameter 15 | from torch.nn import LayerNorm 16 | from kornia.augmentation import (CenterCrop, RandomAffine, RandomCrop, 17 | RandomResizedCrop) 18 | from kornia.filters import GaussianBlur2d 19 | 20 | import utils 21 | from utils import PositionalEmbedding, InverseSquareRootSchedule, AnneallingSchedule 22 | from encoder import make_encoder 23 | from transition_model import make_transition_model 24 | import torchvision.transforms._transforms_video as v_transform 25 | from masking_generator import RandomMaskingListGenerator, RandomMaskingMapGenerator 26 | from vit_modules import Block, trunc_normal_ 27 | 28 | LOG_FREQ = 10000 29 | 30 | 31 | def gaussian_logprob(noise, log_std): 32 | """Compute Gaussian log probability.""" 33 | residual = (-0.5 * noise.pow(2) - log_std).sum(-1, keepdim=True) 34 | return residual - 0.5 * np.log(2 * np.pi) * noise.size(-1) 35 | 36 | 37 | def squash(mu, pi, log_pi): 38 | """Apply squashing function. 39 | See appendix C from https://arxiv.org/pdf/1812.05905.pdf. 
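    The tanh change of variables adjusts the Gaussian log-density as
    log pi(a) = log mu(u) - sum_i log(1 - tanh(u_i)^2); the implementation
    below adds a small epsilon inside the log for numerical stability.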
40 | """ 41 | mu = torch.tanh(mu) 42 | if pi is not None: 43 | pi = torch.tanh(pi) 44 | if log_pi is not None: 45 | log_pi -= torch.log(F.relu(1 - pi.pow(2)) + 1e-6).sum(-1, keepdim=True) 46 | return mu, pi, log_pi 47 | 48 | 49 | def weight_init(m): 50 | """Custom weight init for Conv2D and Linear layers.""" 51 | if isinstance(m, nn.Linear): 52 | nn.init.orthogonal_(m.weight.data) 53 | m.bias.data.fill_(0.0) 54 | elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 55 | # delta-orthogonal init from https://arxiv.org/pdf/1806.05393.pdf 56 | assert m.weight.size(2) == m.weight.size(3) 57 | m.weight.data.fill_(0.0) 58 | m.bias.data.fill_(0.0) 59 | mid = m.weight.size(2) // 2 60 | gain = nn.init.calculate_gain('relu') 61 | nn.init.orthogonal_(m.weight.data[:, :, mid, mid], gain) 62 | 63 | 64 | class Actor(nn.Module): 65 | """MLP actor network.""" 66 | def __init__(self, obs_shape, action_shape, hidden_dim, encoder_type, 67 | encoder_feature_dim, log_std_min, log_std_max, num_layers, 68 | num_filters): 69 | super().__init__() 70 | 71 | self.encoder = make_encoder(encoder_type, 72 | obs_shape, 73 | encoder_feature_dim, 74 | num_layers, 75 | num_filters, 76 | output_logits=True) 77 | print(self.encoder) 78 | self.log_std_min = log_std_min 79 | self.log_std_max = log_std_max 80 | 81 | self.trunk = nn.Sequential( 82 | nn.Linear(self.encoder.feature_dim, hidden_dim), nn.ReLU(), 83 | nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), 84 | nn.Linear(hidden_dim, 2 * action_shape[0])) 85 | 86 | self.outputs = dict() 87 | self.apply(weight_init) 88 | 89 | def forward(self, 90 | obs, 91 | compute_pi=True, 92 | compute_log_pi=True, 93 | detach_encoder=False): 94 | obs = self.encoder(obs, detach=detach_encoder) 95 | 96 | mu, log_std = self.trunk(obs).chunk(2, dim=-1) 97 | 98 | # constrain log_std inside [log_std_min, log_std_max] 99 | log_std = torch.tanh(log_std) 100 | log_std = self.log_std_min + 0.5 * (self.log_std_max - 101 | self.log_std_min) * (log_std + 1) 102 | 103 | self.outputs['mu'] = mu 104 | self.outputs['std'] = log_std.exp() 105 | 106 | if compute_pi: 107 | std = log_std.exp() 108 | noise = torch.randn_like(mu) 109 | pi = mu + noise * std 110 | else: 111 | pi = None 112 | entropy = None 113 | 114 | if compute_log_pi: 115 | log_pi = gaussian_logprob(noise, log_std) 116 | else: 117 | log_pi = None 118 | 119 | mu, pi, log_pi = squash(mu, pi, log_pi) 120 | 121 | return mu, pi, log_pi, log_std 122 | 123 | def log(self, L, step, log_freq=LOG_FREQ): 124 | if step % log_freq != 0: 125 | return 126 | 127 | for k, v in self.outputs.items(): 128 | L.log_histogram('train_actor/%s_hist' % k, v, step) 129 | 130 | L.log_param('train_actor/fc1', self.trunk[0], step) 131 | L.log_param('train_actor/fc2', self.trunk[2], step) 132 | L.log_param('train_actor/fc3', self.trunk[4], step) 133 | 134 | 135 | class QFunction(nn.Module): 136 | """MLP for q-function.""" 137 | def __init__(self, obs_dim, action_dim, hidden_dim): 138 | super().__init__() 139 | 140 | self.trunk = nn.Sequential(nn.Linear(obs_dim + action_dim, 141 | hidden_dim), nn.ReLU(), 142 | nn.Linear(hidden_dim, hidden_dim), 143 | nn.ReLU(), nn.Linear(hidden_dim, 1)) 144 | 145 | def forward(self, obs, action): 146 | assert obs.size(0) == action.size(0) 147 | 148 | obs_action = torch.cat([obs, action], dim=1) 149 | return self.trunk(obs_action) 150 | 151 | 152 | class Critic(nn.Module): 153 | """Critic network, employes two q-functions.""" 154 | def __init__(self, obs_shape, action_shape, hidden_dim, encoder_type, 155 | encoder_feature_dim, 
num_layers, num_filters): 156 | super().__init__() 157 | 158 | self.encoder = make_encoder(encoder_type, 159 | obs_shape, 160 | encoder_feature_dim, 161 | num_layers, 162 | num_filters, 163 | output_logits=True) 164 | 165 | self.Q1 = QFunction(self.encoder.feature_dim, action_shape[0], 166 | hidden_dim) 167 | self.Q2 = QFunction(self.encoder.feature_dim, action_shape[0], 168 | hidden_dim) 169 | 170 | self.outputs = dict() 171 | self.apply(weight_init) 172 | 173 | def forward(self, obs, action, detach_encoder=False): 174 | # detach_encoder allows to stop gradient propogation to encoder 175 | obs = self.encoder(obs, detach=detach_encoder) 176 | 177 | q1 = self.Q1(obs, action) 178 | q2 = self.Q2(obs, action) 179 | 180 | self.outputs['q1'] = q1 181 | self.outputs['q2'] = q2 182 | 183 | return q1, q2 184 | 185 | def log(self, L, step, log_freq=LOG_FREQ): 186 | if step % log_freq != 0: 187 | return 188 | 189 | self.encoder.log(L, step, log_freq) 190 | 191 | for k, v in self.outputs.items(): 192 | L.log_histogram('train_critic/%s_hist' % k, v, step) 193 | 194 | for i in range(3): 195 | L.log_param('train_critic/q1_fc%d' % i, self.Q1.trunk[i * 2], step) 196 | L.log_param('train_critic/q2_fc%d' % i, self.Q2.trunk[i * 2], step) 197 | 198 | 199 | class Intensity(nn.Module): 200 | def __init__(self, scale): 201 | super().__init__() 202 | self.scale = scale 203 | 204 | def forward(self, x): 205 | r = torch.randn((x.size(0), 1, 1, 1), device=x.device) 206 | noise = 1.0 + (self.scale * r.clamp(-2.0, 2.0)) 207 | return x * noise 208 | 209 | class BaseSacAgent(object): 210 | def __init__( 211 | self, 212 | obs_shape, 213 | action_shape, 214 | device, 215 | augmentation=[], 216 | transition_model_type='probabilistic', 217 | transition_model_layer_width=512, 218 | jumps=5, 219 | latent_dim=512, 220 | time_offset=0, 221 | momentum_tau=1.0, 222 | aug_prob=1.0, 223 | auxiliary_task_lr=1e-3, 224 | action_aug_type='random', 225 | num_aug_actions=None, 226 | loss_space='y', 227 | bp_mode='gt', 228 | cycle_steps=5, 229 | cycle_mode='fp+cycle', 230 | fp_loss_weight=1.0, 231 | bp_loss_weight=1.0, 232 | rc_loss_weight=1.0, 233 | vc_loss_weight=1.0, 234 | reward_loss_weight=0.0, 235 | # from curl 236 | hidden_dim=256, 237 | discount=0.99, 238 | init_temperature=0.01, 239 | alpha_lr=1e-3, 240 | alpha_beta=0.9, 241 | actor_lr=1e-3, 242 | actor_beta=0.9, 243 | actor_log_std_min=-10, 244 | actor_log_std_max=2, 245 | actor_update_freq=2, 246 | critic_lr=1e-3, 247 | critic_beta=0.9, 248 | critic_tau=0.005, 249 | critic_target_update_freq=2, 250 | encoder_type='pixel', 251 | encoder_feature_dim=50, 252 | encoder_lr=1e-3, 253 | encoder_tau=0.005, 254 | num_layers=4, 255 | num_filters=32, 256 | cpc_update_freq=1, 257 | log_interval=100, 258 | detach_encoder=False, 259 | curl_latent_dim=128, 260 | sigma=0.05): 261 | self.device = device 262 | self.discount = discount 263 | self.critic_tau = critic_tau 264 | self.encoder_tau = encoder_tau 265 | self.actor_update_freq = actor_update_freq 266 | self.critic_target_update_freq = critic_target_update_freq 267 | self.log_interval = log_interval 268 | self.image_size = obs_shape[-1] 269 | self.curl_latent_dim = curl_latent_dim 270 | self.detach_encoder = detach_encoder 271 | self.encoder_type = encoder_type 272 | self.encoder_feature_dim = encoder_feature_dim 273 | 274 | self.jumps = jumps 275 | self.momentum_tau = momentum_tau 276 | 277 | self.actor = Actor(obs_shape, action_shape, hidden_dim, encoder_type, 278 | encoder_feature_dim, actor_log_std_min, 279 | actor_log_std_max, 
num_layers, 280 | num_filters).to(device) 281 | 282 | self.critic = Critic(obs_shape, action_shape, hidden_dim, encoder_type, 283 | encoder_feature_dim, num_layers, 284 | num_filters).to(device) 285 | 286 | self.critic_target = Critic(obs_shape, action_shape, hidden_dim, 287 | encoder_type, encoder_feature_dim, 288 | num_layers, num_filters).to(device) 289 | 290 | self.critic_target.load_state_dict(self.critic.state_dict()) 291 | 292 | # tie encoders between actor and critic, and CURL and critic 293 | self.actor.encoder.copy_conv_weights_from(self.critic.encoder) 294 | 295 | self.log_alpha = torch.tensor(np.log(init_temperature)).to(device) 296 | self.log_alpha.requires_grad = True 297 | # set target entropy to -|A| 298 | self.target_entropy = -np.prod(action_shape) 299 | 300 | # optimizers 301 | self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), 302 | lr=actor_lr, 303 | betas=(actor_beta, 0.999)) 304 | 305 | self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), 306 | lr=critic_lr, 307 | betas=(critic_beta, 0.999)) 308 | 309 | self.log_alpha_optimizer = torch.optim.Adam([self.log_alpha], 310 | lr=alpha_lr, 311 | betas=(alpha_beta, 0.999)) 312 | 313 | self.train() 314 | self.critic_target.train() 315 | 316 | 317 | def train(self, training=True): 318 | self.training = training 319 | self.actor.train(training) 320 | self.critic.train(training) 321 | 322 | @property 323 | def alpha(self): 324 | return self.log_alpha.exp() 325 | 326 | def select_action(self, obs): 327 | with torch.no_grad(): 328 | obs = torch.FloatTensor(obs).to(self.device) 329 | obs = obs.unsqueeze(0) 330 | mu, _, _, _ = self.actor(obs, 331 | compute_pi=False, 332 | compute_log_pi=False) 333 | return mu.cpu().data.numpy().flatten() 334 | 335 | def sample_action(self, obs): 336 | if obs.shape[-1] != self.image_size: 337 | obs = utils.center_crop_image(obs, self.image_size) 338 | 339 | with torch.no_grad(): 340 | obs = torch.FloatTensor(obs).to(self.device) 341 | obs = obs.unsqueeze(0) 342 | mu, pi, _, _ = self.actor(obs, compute_log_pi=False) 343 | return pi.cpu().data.numpy().flatten() 344 | 345 | def update_critic(self, obs, action, reward, next_obs, not_done, L, step): 346 | with torch.no_grad(): 347 | _, policy_action, log_pi, _ = self.actor(next_obs) 348 | target_Q1, target_Q2 = self.critic_target(next_obs, policy_action) 349 | target_V = torch.min(target_Q1, 350 | target_Q2) - self.alpha.detach() * log_pi 351 | target_Q = reward + (not_done * self.discount * target_V) 352 | 353 | # get current Q estimates 354 | current_Q1, current_Q2 = self.critic( 355 | obs, action, detach_encoder=self.detach_encoder) 356 | critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss( 357 | current_Q2, target_Q) 358 | if step % self.log_interval == 0: 359 | L.log('train_critic/loss', critic_loss, step) 360 | 361 | # Optimize the critic 362 | self.critic_optimizer.zero_grad() 363 | critic_loss.backward() 364 | self.critic_optimizer.step() 365 | 366 | self.critic.log(L, step) 367 | 368 | def update_actor_and_alpha(self, obs, L, step): 369 | # detach encoder, so we don't update it with the actor loss 370 | _, pi, log_pi, log_std = self.actor(obs, detach_encoder=True) 371 | actor_Q1, actor_Q2 = self.critic(obs, pi, detach_encoder=True) 372 | 373 | actor_Q = torch.min(actor_Q1, actor_Q2) 374 | actor_loss = (self.alpha.detach() * log_pi - actor_Q).mean() 375 | 376 | if step % self.log_interval == 0: 377 | L.log('train_actor/loss', actor_loss, step) 378 | L.log('train_actor/target_entropy', self.target_entropy, step) 
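        # Differential entropy of the diagonal Gaussian policy (pre-squash):
        # H = 0.5 * k * (1 + ln(2*pi)) + sum_i log(sigma_i), with k the action
        # dimension; it is logged for monitoring only and does not enter the
        # actor loss.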
379 | entropy = 0.5 * log_std.shape[1] * \ 380 | (1.0 + np.log(2 * np.pi)) + log_std.sum(dim=-1) 381 | if step % self.log_interval == 0: 382 | L.log('train_actor/entropy', entropy.mean(), step) 383 | 384 | # optimize the actor 385 | self.actor_optimizer.zero_grad() 386 | actor_loss.backward() 387 | self.actor_optimizer.step() 388 | 389 | self.actor.log(L, step) 390 | 391 | self.log_alpha_optimizer.zero_grad() 392 | alpha_loss = (self.alpha * 393 | (-log_pi - self.target_entropy).detach()).mean() 394 | if step % self.log_interval == 0: 395 | L.log('train_alpha/loss', alpha_loss, step) 396 | L.log('train_alpha/value', self.alpha, step) 397 | alpha_loss.backward() 398 | self.log_alpha_optimizer.step() 399 | 400 | def update(self, replay_buffer, L, step): 401 | if self.encoder_type == 'pixel': 402 | elements = replay_buffer.sample_spr() 403 | obs, action, reward, next_obs, not_done, mtm_kwargs = elements 404 | else: 405 | elements = replay_buffer.sample_proprio() 406 | obs, action, reward, next_obs, not_done = elements 407 | 408 | if step % self.log_interval == 0: 409 | L.log('train/batch_reward', reward.mean(), step) 410 | 411 | self.update_critic(obs, action, reward, next_obs, not_done, L, step) 412 | 413 | if step % self.actor_update_freq == 0: 414 | self.update_actor_and_alpha(obs, L, step) 415 | 416 | if step % self.critic_target_update_freq == 0: 417 | utils.soft_update_params(self.critic.Q1, self.critic_target.Q1, 418 | self.critic_tau) 419 | utils.soft_update_params(self.critic.Q2, self.critic_target.Q2, 420 | self.critic_tau) 421 | utils.soft_update_params(self.critic.encoder, 422 | self.critic_target.encoder, 423 | self.encoder_tau) 424 | 425 | def save(self, model_dir, step): 426 | torch.save(self.actor.state_dict(), 427 | '%s/actor_%s.pt' % (model_dir, step)) 428 | torch.save(self.critic.state_dict(), 429 | '%s/critic_%s.pt' % (model_dir, step)) 430 | 431 | def save_cycdm(self, model_dir, step): 432 | pass 433 | 434 | def load(self, model_dir, step): 435 | self.actor.load_state_dict( 436 | torch.load('%s/actor_%s.pt' % (model_dir, step))) 437 | self.critic.load_state_dict( 438 | torch.load('%s/critic_%s.pt' % (model_dir, step))) 439 | 440 | 441 | -------------------------------------------------------------------------------- /Atari/src/algos.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Copyright (c) 2022 Microsoft 3 | # Licensed under The MIT License 4 | # -------------------------------------------------------- 5 | 6 | import copy 7 | import torch 8 | import numpy as np 9 | 10 | from rlpyt.utils.collections import namedarraytuple 11 | from collections import namedtuple 12 | from rlpyt.algos.dqn.cat_dqn import CategoricalDQN 13 | from rlpyt.utils.tensor import select_at_indexes, valid_mean 14 | from rlpyt.algos.utils import valid_from_done 15 | from rlpyt.utils.logging import logger 16 | from src.rlpyt_buffer import AsyncPrioritizedSequenceReplayFrameBufferExtended, \ 17 | AsyncUniformSequenceReplayFrameBufferExtended 18 | from src.models import from_categorical, to_categorical 19 | 20 | from src.masking_generator import CubeMaskGenerator 21 | 22 | SamplesToBuffer = namedarraytuple("SamplesToBuffer", 23 | ["observation", "action", "reward", "done"]) 24 | ModelSamplesToBuffer = namedarraytuple("SamplesToBuffer", 25 | ["observation", "action", "reward", "done", "value"]) 26 | 27 | OptInfo = namedtuple("OptInfo", ["loss", "gradNorm", "tdAbsErr"]) 28 | ModelOptInfo = 
namedtuple("OptInfo", ["loss", "gradNorm", 29 | "tdAbsErr", 30 | "modelRLLoss", 31 | "RewardLoss", 32 | "modelGradNorm", 33 | # "SPRLoss", 34 | # "ModelSPRLoss"]) 35 | "MLRLoss"]) 36 | EPS = 1e-6 # (NaN-guard) 37 | 38 | 39 | class PVCategoricalDQN(CategoricalDQN): 40 | """Distributional DQN with fixed probability bins for the Q-value of each 41 | action, a.k.a. categorical.""" 42 | 43 | def __init__(self, 44 | t0_spr_loss_weight=1., 45 | model_rl_weight=1., 46 | reward_loss_weight=1., 47 | model_spr_weight=1., 48 | time_offset=0, 49 | distributional=1, 50 | jumps=0, 51 | rc_weight=1., 52 | vc_weight=1., 53 | cycle_jumps=1, 54 | fp_weight=1, 55 | bp_weight=1, 56 | mlr_weight=1, 57 | warmup=50000, 58 | bp_warm=False, 59 | **kwargs): 60 | super().__init__(**kwargs) 61 | self.opt_info_fields = tuple(f for f in ModelOptInfo._fields) # copy 62 | self.t0_spr_loss_weight = t0_spr_loss_weight 63 | self.model_spr_weight = model_spr_weight 64 | self.jumps = jumps 65 | self.cycle_jumps = cycle_jumps 66 | self.fp_weight = fp_weight * self.jumps # * K (as we mean it across time later) 67 | self.bp_weight = bp_weight * self.jumps 68 | self.mlr_weight = mlr_weight 69 | self.bp_warm = bp_warm 70 | 71 | self.reward_loss_weight = reward_loss_weight 72 | self.model_rl_weight = model_rl_weight 73 | self.time_offset = time_offset 74 | self.warmup = warmup 75 | 76 | # img_size = 84 77 | # mask_ratio = 0.5 78 | # block_size = 4 79 | # patch_size = 12 80 | # input_size = img_size // patch_size 81 | # self.masker = CubeMaskGenerator( 82 | # input_size=input_size, image_size=img_size, clip_size=self.jumps+1, \ 83 | # block_size=block_size, mask_ratio=mask_ratio) # 1 for mask, num_grid=input_size 84 | 85 | if not distributional: 86 | self.rl_loss = self.dqn_rl_loss 87 | else: 88 | self.rl_loss = self.dist_rl_loss 89 | 90 | self.rc_weight = rc_weight 91 | self.vc_weight = vc_weight 92 | 93 | def initialize_replay_buffer(self, examples, batch_spec, async_=False): 94 | example_to_buffer = ModelSamplesToBuffer( 95 | observation=examples["observation"], 96 | action=examples["action"], 97 | reward=examples["reward"], 98 | done=examples["done"], 99 | value=examples["agent_info"].p, 100 | ) 101 | replay_kwargs = dict( 102 | example=example_to_buffer, 103 | size=self.replay_size, 104 | B=batch_spec.B, 105 | batch_T=self.jumps+1+self.time_offset, 106 | discount=self.discount, 107 | n_step_return=self.n_step_return, 108 | rnn_state_interval=0, 109 | ) 110 | 111 | if self.prioritized_replay: 112 | replay_kwargs['alpha'] = self.pri_alpha 113 | replay_kwargs['beta'] = self.pri_beta_init 114 | # replay_kwargs["input_priorities"] = self.input_priorities 115 | buffer = AsyncPrioritizedSequenceReplayFrameBufferExtended(**replay_kwargs) 116 | else: 117 | buffer = AsyncUniformSequenceReplayFrameBufferExtended(**replay_kwargs) 118 | 119 | self.replay_buffer = buffer 120 | 121 | def optim_initialize(self, rank=0): 122 | """Called in initilize or by async runner after forking sampler.""" 123 | self.rank = rank 124 | try: 125 | # We're probably dealing with DDP 126 | self.optimizer = self.OptimCls(self.agent.model.module.parameters(), 127 | lr=self.learning_rate, **self.optim_kwargs) 128 | self.model = self.agent.model.module 129 | except: 130 | self.optimizer = self.OptimCls(self.agent.model.parameters(), 131 | lr=self.learning_rate, **self.optim_kwargs) 132 | self.model = self.agent.model 133 | if self.initial_optim_state_dict is not None: 134 | self.optimizer.load_state_dict(self.initial_optim_state_dict) 135 | if 
self.prioritized_replay: 136 | self.pri_beta_itr = max(1, self.pri_beta_steps // self.sampler_bs) 137 | 138 | def samples_to_buffer(self, samples): 139 | """Defines how to add data from sampler into the replay buffer. Called 140 | in optimize_agent() if samples are provided to that method. In 141 | asynchronous mode, will be called in the memory_copier process.""" 142 | return ModelSamplesToBuffer( 143 | observation=samples.env.observation, 144 | action=samples.agent.action, 145 | reward=samples.env.reward, 146 | done=samples.env.done, 147 | value=samples.agent.agent_info.p, 148 | ) 149 | 150 | def optimize_agent(self, itr, samples=None, sampler_itr=None): 151 | """ 152 | Extracts the needed fields from input samples and stores them in the 153 | replay buffer. Then samples from the replay buffer to train the agent 154 | by gradient updates (with the number of updates determined by replay 155 | ratio, sampler batch size, and training batch size). If using prioritized 156 | replay, updates the priorities for sampled training batches. 157 | """ 158 | itr = itr if sampler_itr is None else sampler_itr # Async uses sampler_itr.= 159 | if samples is not None: 160 | samples_to_buffer = self.samples_to_buffer(samples) 161 | self.replay_buffer.append_samples(samples_to_buffer) 162 | opt_info = ModelOptInfo(*([] for _ in range(len(ModelOptInfo._fields)))) 163 | if itr < self.min_itr_learn: 164 | return opt_info 165 | 166 | # Gaussian ramp up for cycle loss 167 | if self.warmup > 0 and itr < self.warmup: 168 | warm_weight = np.exp(-5 * ((1 - (itr+1)/self.warmup)**2)) 169 | # real_rc_weight = self.rc_weight * warm_weight 170 | real_vc_weight = self.vc_weight * warm_weight 171 | real_bp_weight = self.bp_weight * (1 - warm_weight) 172 | else: 173 | # real_rc_weight = self.rc_weight 174 | real_vc_weight = self.vc_weight 175 | real_bp_weight = 0 176 | real_rc_weight = self.rc_weight 177 | 178 | for _ in range(self.updates_per_optimize): 179 | samples_from_replay = self.replay_buffer.sample_batch(self.batch_size) 180 | loss, td_abs_errors, model_rl_loss, mlr_loss = self.loss(samples_from_replay) 181 | 182 | total_loss = loss + self.model_rl_weight * model_rl_loss 183 | total_loss = total_loss + self.mlr_weight * mlr_loss 184 | 185 | if itr % 500 == 0: 186 | print("Itr {}, RL loss: {:.6f}, MLR loss: {:.6f}"\ 187 | .format(itr, loss.item(), mlr_loss.item()) 188 | ) 189 | # print(real_bp_weight, real_vc_weight) 190 | 191 | self.optimizer.zero_grad() 192 | total_loss.backward() 193 | grad_norm = torch.nn.utils.clip_grad_norm_( 194 | self.model.stem_parameters(), self.clip_grad_norm) 195 | if len(list(self.model.transformer.parameters())) > 0: 196 | model_grad_norm = torch.nn.utils.clip_grad_norm_( 197 | self.model.transformer.parameters(), self.clip_grad_norm) 198 | else: 199 | model_grad_norm = 0 200 | self.optimizer.step() 201 | if self.prioritized_replay: 202 | self.replay_buffer.update_batch_priorities(td_abs_errors) 203 | opt_info.loss.append(loss.item()) 204 | opt_info.gradNorm.append(torch.tensor(grad_norm).item()) # grad_norm is a float sometimes, so wrap in tensor 205 | opt_info.modelRLLoss.append(model_rl_loss.item()) 206 | # opt_info.RewardLoss.append(reward_loss.item()) 207 | opt_info.modelGradNorm.append(torch.tensor(model_grad_norm).item()) 208 | opt_info.MLRLoss.append(mlr_loss.item()) 209 | opt_info.tdAbsErr.extend(td_abs_errors[::8].numpy()) # Downsample. 
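            # Periodically refresh the target network: every target_update_interval
            # gradient updates, agent.update_target() moves the target parameters
            # toward the online ones with coefficient target_update_tau (a full
            # copy when tau = 1).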
210 | self.update_counter += 1 211 | if self.update_counter % self.target_update_interval == 0: 212 | self.agent.update_target(self.target_update_tau) 213 | self.update_itr_hyperparams(itr) 214 | return opt_info 215 | 216 | def dqn_rl_loss(self, qs, samples, index): 217 | """ 218 | Computes the Q-learning loss, based on: 0.5 * (Q - target_Q) ^ 2. 219 | Implements regular DQN or Double-DQN for computing target_Q values 220 | using the agent's target network. Computes the Huber loss using 221 | ``delta_clip``, or if ``None``, uses MSE. When using prioritized 222 | replay, multiplies losses by importance sample weights. 223 | 224 | Input ``samples`` have leading batch dimension [B,..] (but not time). 225 | 226 | Calls the agent to compute forward pass on training inputs, and calls 227 | ``agent.target()`` to compute target values. 228 | 229 | Returns loss and TD-absolute-errors for use in prioritization. 230 | 231 | Warning: 232 | If not using mid_batch_reset, the sampler will only reset environments 233 | between iterations, so some samples in the replay buffer will be 234 | invalid. This case is not supported here currently. 235 | """ 236 | q = select_at_indexes(samples.all_action[index+1], qs).cpu() 237 | with torch.no_grad(): 238 | target_qs = self.agent.target(samples.all_observation[index + self.n_step_return], 239 | samples.all_action[index + self.n_step_return], 240 | samples.all_reward[index + self.n_step_return]) # [B,A,P'] 241 | if self.double_dqn: 242 | next_qs = self.agent(samples.all_observation[index + self.n_step_return], 243 | samples.all_action[index + self.n_step_return], 244 | samples.all_reward[index + self.n_step_return]) # [B,A,P'] 245 | next_a = torch.argmax(next_qs, dim=-1) 246 | target_q = select_at_indexes(next_a, target_qs) 247 | else: 248 | target_q = torch.max(target_qs, dim=-1).values 249 | 250 | disc_target_q = (self.discount ** self.n_step_return) * target_q 251 | y = samples.return_[index] + (1 - samples.done_n[index].float()) * disc_target_q 252 | 253 | delta = y - q 254 | losses = 0.5 * delta ** 2 255 | abs_delta = abs(delta) 256 | if self.delta_clip > 0: # Huber loss. 257 | b = self.delta_clip * (abs_delta - self.delta_clip / 2) 258 | losses = torch.where(abs_delta <= self.delta_clip, losses, b) 259 | td_abs_errors = abs_delta.detach() 260 | if self.delta_clip > 0: 261 | td_abs_errors = torch.clamp(td_abs_errors, 0, self.delta_clip) 262 | return losses, td_abs_errors 263 | 264 | def dist_rl_loss(self, log_pred_ps, samples, index): 265 | delta_z = (self.V_max - self.V_min) / (self.agent.n_atoms - 1) 266 | z = torch.linspace(self.V_min, self.V_max, self.agent.n_atoms) 267 | # Make 2-D tensor of contracted z_domain for each data point, 268 | # with zeros where next value should not be added. 269 | next_z = z * (self.discount ** self.n_step_return) # [P'] 270 | next_z = torch.ger(1 - samples.done_n[index].float(), next_z) # [B,P'] 271 | ret = samples.return_[index].unsqueeze(1) # [B,1] 272 | next_z = torch.clamp(ret + next_z, self.V_min, self.V_max) # [B,P'] 273 | 274 | z_bc = z.view(1, -1, 1) # [1,P,1] 275 | next_z_bc = next_z.unsqueeze(1) # [B,1,P'] 276 | abs_diff_on_delta = abs(next_z_bc - z_bc) / delta_z 277 | projection_coeffs = torch.clamp(1 - abs_diff_on_delta, 0, 1) # Most 0. 
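        # Categorical (C51) projection: each target atom Tz_j spreads its mass onto
        # the neighbouring support atoms z_i with weight
        # clamp(1 - |Tz_j - z_i| / delta_z, 0, 1), so for every target atom inside
        # [V_min, V_max] the coefficients below sum to 1 over i.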
278 | # projection_coeffs is a 3-D tensor: [B,P,P'] 279 | # dim-0: independent data entries 280 | # dim-1: base_z atoms (remains after projection) 281 | # dim-2: next_z atoms (summed in projection) 282 | 283 | with torch.no_grad(): 284 | target_ps = self.agent.target(samples.all_observation[index + self.n_step_return], 285 | samples.all_action[index + self.n_step_return], 286 | samples.all_reward[index + self.n_step_return]) # [B,A,P'] 287 | if self.double_dqn: 288 | next_ps = self.agent(samples.all_observation[index + self.n_step_return], 289 | samples.all_action[index + self.n_step_return], 290 | samples.all_reward[index + self.n_step_return]) # [B,A,P'] 291 | next_qs = torch.tensordot(next_ps, z, dims=1) # [B,A] 292 | next_a = torch.argmax(next_qs, dim=-1) # [B] 293 | else: 294 | target_qs = torch.tensordot(target_ps, z, dims=1) # [B,A] 295 | next_a = torch.argmax(target_qs, dim=-1) # [B] 296 | target_p_unproj = select_at_indexes(next_a, target_ps) # [B,P'] 297 | target_p_unproj = target_p_unproj.unsqueeze(1) # [B,1,P'] 298 | target_p = (target_p_unproj * projection_coeffs).sum(-1) # [B,P] 299 | p = select_at_indexes(samples.all_action[index + 1].squeeze(-1), 300 | log_pred_ps.cpu()) # [B,P] 301 | # p = torch.clamp(p, EPS, 1) # NaN-guard. 302 | losses = -torch.sum(target_p * p, dim=1) # Cross-entropy. 303 | 304 | target_p = torch.clamp(target_p, EPS, 1) 305 | KL_div = torch.sum(target_p * 306 | (torch.log(target_p) - p.detach()), dim=1) 307 | KL_div = torch.clamp(KL_div, EPS, 1 / EPS) # Avoid <0 from NaN-guard. 308 | 309 | return losses, KL_div.detach() 310 | 311 | def loss(self, samples): 312 | """ 313 | Computes the Distributional Q-learning loss, based on projecting the 314 | discounted rewards + target Q-distribution into the current Q-domain, 315 | with cross-entropy loss. 316 | 317 | Returns loss and KL-divergence-errors for use in prioritization. 
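        With the distributional head, the step-0 term is the cross-entropy
        -sum_i target_p_i * log pred_p_i; the jump-step RL terms and the MLR
        loss are returned separately and weighted in optimize_agent() by
        model_rl_weight and mlr_weight.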
318 | """ 319 | if self.model.noisy: 320 | self.model.head.reset_noise() 321 | # start = time.time() 322 | # spr_loss: (1+jumps, bs); idm or bdm loss: (bs, ) 323 | # mask = self.masker() # [jumps+1, 1, 84, 84], T = jumps+1 324 | # mask = mask[:, None, None].expand(mask.size(0), *samples.all_observation.size()[1:]) 325 | # # print(mask.size()) # [12, B, 4, 1, 84, 84] 326 | # print(samples.all_observation.size()) # [2T-1, B, 4, 1, 84, 84] 327 | log_pred_ps, pred_rew, mlr_loss = self.agent( 328 | samples.all_observation.to(self.agent.device), # (T,bs,4,1,100,100) 329 | samples.all_action.to(self.agent.device), 330 | samples.all_reward.to(self.agent.device), 331 | # mask.to(self.agent.device), 332 | train=True) # [B,A,P] 333 | 334 | rl_loss, KL = self.rl_loss(log_pred_ps[0], samples, index=0) 335 | # if len(pred_rew) > 0: 336 | # pred_rew = torch.stack(pred_rew, 0) 337 | # with torch.no_grad(): 338 | # reward_target = to_categorical( 339 | # samples.all_reward[:self.jumps+1].flatten().to(self.agent.device), 340 | # limit=1).view(*pred_rew.shape) 341 | # reward_loss = -torch.sum(reward_target * pred_rew, 2).mean(0).cpu() 342 | # else: 343 | # reward_loss = torch.zeros(samples.all_observation.shape[1],) 344 | model_rl_loss = torch.zeros(samples.all_observation.size(1)) 345 | 346 | if self.model_rl_weight > 0: 347 | for i in range(1, self.jumps+1): 348 | jump_rl_loss, model_KL = self.rl_loss(log_pred_ps[i], 349 | samples, 350 | i) 351 | model_rl_loss = model_rl_loss + jump_rl_loss 352 | 353 | mlr_loss = mlr_loss.cpu() 354 | 355 | if self.prioritized_replay: 356 | weights = samples.is_weights 357 | # reward_loss = reward_loss * weights 358 | 359 | # RL losses are no longer scaled in the c51 function 360 | rl_loss = rl_loss * weights 361 | model_rl_loss = model_rl_loss * weights 362 | mlr_loss = mlr_loss * weights 363 | 364 | return rl_loss.mean(), KL, model_rl_loss.mean(), mlr_loss.mean() 365 | -------------------------------------------------------------------------------- /Atari/src/rlpyt_utils.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # This code is borrowed from https://github.com/mila-iqia 3 | # -------------------------------------------------------- 4 | 5 | from rlpyt.samplers.base import BaseSampler 6 | from rlpyt.samplers.buffer import build_samples_buffer 7 | from rlpyt.samplers.parallel.cpu.collectors import CpuResetCollector 8 | from rlpyt.samplers.serial.collectors import SerialEvalCollector 9 | from rlpyt.utils.buffer import buffer_from_example, torchify_buffer, numpify_buffer 10 | from rlpyt.utils.logging import logger 11 | from rlpyt.utils.quick_args import save__init__args 12 | from rlpyt.utils.seed import set_seed 13 | from rlpyt.runners.minibatch_rl import MinibatchRlEval 14 | 15 | import wandb 16 | import psutil 17 | 18 | import torch 19 | import numpy as np 20 | import time 21 | 22 | 23 | atari_human_scores = dict( 24 | alien=7127.7, amidar=1719.5, assault=742.0, asterix=8503.3, 25 | bank_heist=753.1, battle_zone=37187.5, boxing=12.1, 26 | breakout=30.5, chopper_command=7387.8, crazy_climber=35829.4, 27 | demon_attack=1971.0, freeway=29.6, frostbite=4334.7, 28 | gopher=2412.5, hero=30826.4, jamesbond=302.8, kangaroo=3035.0, 29 | krull=2665.5, kung_fu_master=22736.3, ms_pacman=6951.6, pong=14.6, 30 | private_eye=69571.3, qbert=13455.0, road_runner=7845.0, 31 | seaquest=42054.7, up_n_down=11693.2 32 | ) 33 | 34 | atari_der_scores = dict( 35 | alien=739.9, amidar=188.6, 
assault=431.2, asterix=470.8, 36 | bank_heist=51.0, battle_zone=10124.6, boxing=0.2, 37 | breakout=1.9, chopper_command=861.8, crazy_climber=16185.3, 38 | demon_attack=508, freeway=27.9, frostbite=866.8, 39 | gopher=349.5, hero=6857.0, jamesbond=301.6, 40 | kangaroo=779.3, krull=2851.5, kung_fu_master=14346.1, 41 | ms_pacman=1204.1, pong=-19.3, private_eye=97.8, qbert=1152.9, 42 | road_runner=9600.0, seaquest=354.1, up_n_down=2877.4, 43 | ) 44 | 45 | atari_nature_scores = dict( 46 | alien=3069, amidar=739.5, assault=3359, 47 | asterix=6012, bank_heist=429.7, battle_zone=26300., 48 | boxing=71.8, breakout=401.2, chopper_command=6687., 49 | crazy_climber=114103, demon_attack=9711., freeway=30.3, 50 | frostbite=328.3, gopher=8520., hero=19950., jamesbond=576.7, 51 | kangaroo=6740., krull=3805., kung_fu_master=23270., 52 | ms_pacman=2311., pong=18.9, private_eye=1788., 53 | qbert=10596., road_runner=18257., seaquest=5286., up_n_down=8456. 54 | ) 55 | 56 | atari_random_scores = dict( 57 | alien=227.8, amidar=5.8, assault=222.4, 58 | asterix=210.0, bank_heist=14.2, battle_zone=2360.0, 59 | boxing=0.1, breakout=1.7, chopper_command=811.0, 60 | crazy_climber=10780.5, demon_attack=152.1, freeway=0.0, 61 | frostbite=65.2, gopher=257.6, hero=1027.0, jamesbond=29.0, 62 | kangaroo=52.0, krull=1598.0, kung_fu_master=258.5, 63 | ms_pacman=307.3, pong=-20.7, private_eye=24.9, 64 | qbert=163.9, road_runner=11.5, seaquest=68.4, up_n_down=533.4 65 | ) 66 | 67 | 68 | def maybe_update_summary(key, value): 69 | if key not in wandb.run.summary.keys(): 70 | wandb.run.summary[key] = value 71 | else: 72 | wandb.run.summary[key] = max(value, wandb.run.summary[key]) 73 | 74 | 75 | class MinibatchRlEvalWandb(MinibatchRlEval): 76 | 77 | def __init__(self, final_eval_only=False, *args, **kwargs): 78 | super().__init__(*args, **kwargs) 79 | self.final_eval_only = final_eval_only 80 | 81 | def log_diagnostics(self, itr, eval_traj_infos, eval_time): 82 | cum_steps = (itr + 1) * self.sampler.batch_size * self.world_size 83 | self.wandb_info = {'cum_steps': cum_steps} 84 | super().log_diagnostics(itr, eval_traj_infos, eval_time) 85 | wandb.log(self.wandb_info) 86 | 87 | def startup(self): 88 | """ 89 | Sets hardware affinities, initializes the following: 1) sampler (which 90 | should initialize the agent), 2) agent device and data-parallel wrapper (if applicable), 91 | 3) algorithm, 4) logger. 92 | """ 93 | p = psutil.Process() 94 | try: 95 | if (self.affinity.get("master_cpus", None) is not None and 96 | self.affinity.get("set_affinity", True)): 97 | p.cpu_affinity(self.affinity["master_cpus"]) 98 | cpu_affin = p.cpu_affinity() 99 | except AttributeError: 100 | cpu_affin = "UNAVAILABLE MacOS" 101 | logger.log(f"Runner {getattr(self, 'rank', '')} master CPU affinity: " 102 | f"{cpu_affin}.") 103 | if self.affinity.get("master_torch_threads", None) is not None: 104 | torch.set_num_threads(self.affinity["master_torch_threads"]) 105 | logger.log(f"Runner {getattr(self, 'rank', '')} master Torch threads: " 106 | f"{torch.get_num_threads()}.") 107 | set_seed(self.seed) 108 | torch.backends.cudnn.deterministic = True 109 | torch.backends.cudnn.benchmark = False 110 | 111 | self.rank = rank = getattr(self, "rank", 0) 112 | self.world_size = world_size = getattr(self, "world_size", 1) 113 | examples = self.sampler.initialize( 114 | agent=self.agent, # Agent gets initialized in sampler. 
115 | affinity=self.affinity, 116 | seed=self.seed + 1, 117 | bootstrap_value=getattr(self.algo, "bootstrap_value", False), 118 | traj_info_kwargs=self.get_traj_info_kwargs(), 119 | rank=rank, 120 | world_size=world_size, 121 | ) 122 | self.itr_batch_size = self.sampler.batch_spec.size * world_size 123 | n_itr = self.get_n_itr() 124 | self.agent.to_device(self.affinity.get("cuda_idx", None)) 125 | if world_size > 1: 126 | self.agent.data_parallel() 127 | self.algo.initialize( 128 | agent=self.agent, 129 | n_itr=n_itr, 130 | batch_spec=self.sampler.batch_spec, 131 | mid_batch_reset=self.sampler.mid_batch_reset, 132 | examples=examples, 133 | world_size=world_size, 134 | rank=rank, 135 | ) 136 | self.initialize_logging() 137 | return n_itr 138 | 139 | def _log_infos(self, traj_infos=None): 140 | """ 141 | Writes trajectory info and optimizer info into csv via the logger. 142 | Resets stored optimizer info. 143 | """ 144 | if traj_infos is None: 145 | traj_infos = self._traj_infos 146 | if traj_infos: 147 | for k in traj_infos[0]: 148 | if not k.startswith("_"): 149 | values = [info[k] for info in traj_infos] 150 | logger.record_tabular_misc_stat(k, 151 | values) 152 | 153 | wandb.run.summary[k] = np.average(values) 154 | self.wandb_info[k + "Average"] = np.average(values) 155 | self.wandb_info[k + "Std"] = np.std(values) 156 | self.wandb_info[k + "Min"] = np.min(values) 157 | self.wandb_info[k + "Max"] = np.max(values) 158 | self.wandb_info[k + "Median"] = np.median(values) 159 | if k == 'GameScore': 160 | game = self.sampler.env_kwargs['game'] 161 | random_score = atari_random_scores[game] 162 | der_score = atari_der_scores[game] 163 | nature_score = atari_nature_scores[game] 164 | human_score = atari_human_scores[game] 165 | normalized_score = (np.average(values) - random_score) / (human_score - random_score) 166 | der_normalized_score = (np.average(values) - random_score) / (der_score - random_score) 167 | nature_normalized_score = (np.average(values) - random_score) / (nature_score - random_score) 168 | self.wandb_info[k + "Normalized"] = normalized_score 169 | self.wandb_info[k + "DERNormalized"] = der_normalized_score 170 | self.wandb_info[k + "NatureNormalized"] = nature_normalized_score 171 | 172 | maybe_update_summary(k+"Best", np.average(values)) 173 | maybe_update_summary(k+"NormalizedBest", normalized_score) 174 | maybe_update_summary(k+"DERNormalizedBest", der_normalized_score) 175 | maybe_update_summary(k+"NatureNormalizedBest", nature_normalized_score) 176 | 177 | if self._opt_infos: 178 | for k, v in self._opt_infos.items(): 179 | logger.record_tabular_misc_stat(k, v) 180 | self.wandb_info[k] = np.average(v) 181 | wandb.run.summary[k] = np.average(v) 182 | self._opt_infos = {k: list() for k in self._opt_infos} # (reset) 183 | 184 | def evaluate_agent(self, itr): 185 | """ 186 | Record offline evaluation of agent performance, by ``sampler.evaluate_agent()``. 187 | """ 188 | if itr > 0: 189 | self.pbar.stop() 190 | 191 | if self.final_eval_only: 192 | eval = itr == 0 or itr >= self.n_itr - 1 193 | else: 194 | eval = itr == 0 or itr >= self.min_itr_learn - 1 195 | 196 | # eval = False 197 | 198 | if eval: 199 | logger.log("Evaluating agent...") 200 | self.agent.eval_mode(itr) # Might be agent in sampler. 
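# Timing idiom: eval_time is seeded with -time.time() and time.time() is added back after the
# evaluation rollout below, so it ends up holding the elapsed wall-clock seconds of evaluation.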
201 | eval_time = -time.time() 202 | traj_infos = self.sampler.evaluate_agent(itr) 203 | eval_time += time.time() 204 | else: 205 | traj_infos = [] 206 | eval_time = 0.0 207 | logger.log("Evaluation runs complete.") 208 | return traj_infos, eval_time 209 | 210 | def train(self): 211 | """ 212 | Performs startup, evaluates the initial agent, then loops by 213 | alternating between ``sampler.obtain_samples()`` and 214 | ``algo.optimize_agent()``. Pauses to evaluate the agent at the 215 | specified log interval. 216 | """ 217 | n_itr = self.startup() 218 | self.n_itr = n_itr 219 | with logger.prefix(f"itr #0 "): 220 | eval_traj_infos, eval_time = self.evaluate_agent(0) 221 | self.log_diagnostics(0, eval_traj_infos, eval_time) 222 | for itr in range(n_itr): 223 | logger.set_iteration(itr) 224 | with logger.prefix(f"itr #{itr} "): 225 | self.agent.sample_mode(itr) 226 | samples, traj_infos = self.sampler.obtain_samples(itr) 227 | self.agent.train_mode(itr) 228 | opt_info = self.algo.optimize_agent(itr, samples) 229 | self.store_diagnostics(itr, traj_infos, opt_info) 230 | if (itr + 1) % self.log_interval_itrs == 0: 231 | eval_traj_infos, eval_time = self.evaluate_agent(itr) 232 | self.log_diagnostics(itr, eval_traj_infos, eval_time) 233 | self.shutdown() 234 | 235 | 236 | def delete_ind_from_tensor(tensor, ind): 237 | tensor = torch.cat([tensor[:ind], tensor[ind+1:]], 0) 238 | return tensor 239 | 240 | 241 | def delete_ind_from_array(array, ind): 242 | tensor = np.concatenate([array[:ind], array[ind+1:]], 0) 243 | return tensor 244 | 245 | 246 | class OneToOneSerialEvalCollector(SerialEvalCollector): 247 | def collect_evaluation(self, itr): 248 | assert self.max_trajectories == len(self.envs) 249 | traj_infos = [self.TrajInfoCls() for _ in range(len(self.envs))] 250 | completed_traj_infos = list() 251 | observations = list() 252 | for env in self.envs: 253 | observations.append(env.reset()) 254 | observation = buffer_from_example(observations[0], len(self.envs)) 255 | for b, o in enumerate(observations): 256 | observation[b] = o 257 | action = buffer_from_example(self.envs[0].action_space.null_value(), 258 | len(self.envs)) 259 | reward = np.zeros(len(self.envs), dtype="float32") 260 | obs_pyt, act_pyt, rew_pyt = torchify_buffer((observation, action, reward)) 261 | self.agent.reset() 262 | self.agent.eval_mode(itr) 263 | live_envs = list(range(len(self.envs))) 264 | for t in range(self.max_T): 265 | act_pyt, agent_info = self.agent.step(obs_pyt, act_pyt, rew_pyt) 266 | action = numpify_buffer(act_pyt) 267 | 268 | b = 0 269 | while b < len(live_envs): # don't want to do a for loop since live envs changes over time 270 | env_id = live_envs[b] 271 | o, r, d, env_info = self.envs[env_id].step(action[b]) 272 | traj_infos[env_id].step(observation[b], 273 | action[b], r, d, 274 | agent_info[b], env_info) 275 | if getattr(env_info, "traj_done", d): 276 | completed_traj_infos.append(traj_infos[env_id].terminate(o)) 277 | 278 | observation = delete_ind_from_array(observation, b) 279 | reward = delete_ind_from_array(reward, b) 280 | action = delete_ind_from_array(action, b) 281 | obs_pyt, act_pyt, rew_pyt = torchify_buffer((observation, action, reward)) 282 | 283 | del live_envs[b] 284 | b -= 1 # live_envs[b] is now the next env, so go back one. 
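# NOTE (descriptive comment): delete_ind_from_array() builds fresh arrays via np.concatenate,
# so the torch views from the earlier torchify_buffer() call no longer alias the numpy buffers;
# that is why obs_pyt/act_pyt/rew_pyt are re-created right after the finished env's rows are
# dropped above, before stepping the index back.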
285 | else: 286 | observation[b] = o 287 | reward[b] = r 288 | 289 | b += 1 290 | 291 | if (self.max_trajectories is not None and 292 | len(completed_traj_infos) >= self.max_trajectories): 293 | logger.log("Evaluation reached max num trajectories " 294 | f"({self.max_trajectories}).") 295 | return completed_traj_infos 296 | 297 | if t == self.max_T - 1: 298 | logger.log("Evaluation reached max num time steps " 299 | f"({self.max_T}).") 300 | return completed_traj_infos 301 | 302 | 303 | class SerialSampler(BaseSampler): 304 | """The simplest sampler; no parallelism, everything occurs in same, master 305 | Python process. This can be easier for debugging (e.g. can use 306 | ``breakpoint()`` in master process) and might be fast enough for 307 | experiment purposes. Should be used with collectors which generate the 308 | agent's actions internally, i.e. CPU-based collectors but not GPU-based 309 | ones. 310 | NOTE: We modify this class from rlpyt to pass an id to EnvCls when creating 311 | environments. 312 | """ 313 | 314 | def __init__(self, *args, CollectorCls=CpuResetCollector, 315 | eval_CollectorCls=SerialEvalCollector, **kwargs): 316 | super().__init__(*args, CollectorCls=CollectorCls, 317 | eval_CollectorCls=eval_CollectorCls, **kwargs) 318 | 319 | def initialize( 320 | self, 321 | agent, 322 | affinity=None, 323 | seed=None, 324 | bootstrap_value=False, 325 | traj_info_kwargs=None, 326 | rank=0, 327 | world_size=1, 328 | ): 329 | """Store the input arguments. Instantiate the specified number of environment 330 | instances (``batch_B``). Initialize the agent, and pre-allocate a memory buffer 331 | to hold the samples collected in each batch. Applies ``traj_info_kwargs`` settings 332 | to the `TrajInfoCls` by direct class attribute assignment. Instantiates the Collector 333 | and, if applicable, the evaluation Collector. 334 | 335 | Returns a structure of inidividual examples for data fields such as `observation`, 336 | `action`, etc, which can be used to allocate a replay buffer. 337 | """ 338 | B = self.batch_spec.B 339 | envs = [self.EnvCls(id=i, **self.env_kwargs) for i in range(B)] 340 | global_B = B * world_size 341 | env_ranks = list(range(rank * B, (rank + 1) * B)) 342 | agent.initialize(envs[0].spaces, share_memory=False, 343 | global_B=global_B, env_ranks=env_ranks) 344 | samples_pyt, samples_np, examples = build_samples_buffer(agent, envs[0], 345 | self.batch_spec, bootstrap_value, agent_shared=False, 346 | env_shared=False, subprocess=False) 347 | if traj_info_kwargs: 348 | for k, v in traj_info_kwargs.items(): 349 | setattr(self.TrajInfoCls, "_" + k, v) # Avoid passing at init. 350 | collector = self.CollectorCls( 351 | rank=0, 352 | envs=envs, 353 | samples_np=samples_np, 354 | batch_T=self.batch_spec.T, 355 | TrajInfoCls=self.TrajInfoCls, 356 | agent=agent, 357 | global_B=global_B, 358 | env_ranks=env_ranks, # Might get applied redundantly to agent. 359 | ) 360 | if self.eval_n_envs > 0: # May do evaluation. 
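# Evaluation envs are likewise constructed with an explicit id=i (see the class NOTE above), and
# each eval env runs for max_T = eval_max_steps // eval_n_envs steps, so total evaluation
# interaction is bounded by roughly eval_max_steps across all eval envs.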
361 | eval_envs = [self.EnvCls(id=i, **self.eval_env_kwargs) 362 | for i in range(self.eval_n_envs)] 363 | eval_CollectorCls = self.eval_CollectorCls or SerialEvalCollector 364 | self.eval_collector = eval_CollectorCls( 365 | envs=eval_envs, 366 | agent=agent, 367 | TrajInfoCls=self.TrajInfoCls, 368 | max_T=self.eval_max_steps // self.eval_n_envs, 369 | max_trajectories=self.eval_max_trajectories, 370 | ) 371 | 372 | agent_inputs, traj_infos = collector.start_envs( 373 | self.max_decorrelation_steps) 374 | collector.start_agent() 375 | 376 | self.agent = agent 377 | self.samples_pyt = samples_pyt 378 | self.samples_np = samples_np 379 | self.collector = collector 380 | self.agent_inputs = agent_inputs 381 | self.traj_infos = traj_infos 382 | logger.log("Serial Sampler initialized.") 383 | return examples 384 | 385 | def obtain_samples(self, itr): 386 | """Call the collector to execute a batch of agent-environment interactions. 387 | Return data in torch tensors, and a list of trajectory-info objects from 388 | episodes which ended. 389 | """ 390 | # self.samples_np[:] = 0 # Unnecessary and may take time. 391 | agent_inputs, traj_infos, completed_infos = self.collector.collect_batch( 392 | self.agent_inputs, self.traj_infos, itr) 393 | self.collector.reset_if_needed(agent_inputs) 394 | self.agent_inputs = agent_inputs 395 | self.traj_infos = traj_infos 396 | return self.samples_pyt, completed_infos 397 | 398 | def evaluate_agent(self, itr): 399 | """Call the evaluation collector to execute agent-environment interactions.""" 400 | return self.eval_collector.collect_evaluation(itr) 401 | -------------------------------------------------------------------------------- /DMControl/src/curl_sac.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # This code is borrowed from https://github.com/MishaLaskin/curl 3 | # -------------------------------------------------------- 4 | 5 | import copy 6 | import math 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | import utils 14 | from encoder import make_encoder 15 | 16 | LOG_FREQ = 10000 17 | 18 | 19 | def gaussian_logprob(noise, log_std): 20 | """Compute Gaussian log probability.""" 21 | residual = (-0.5 * noise.pow(2) - log_std).sum(-1, keepdim=True) 22 | return residual - 0.5 * np.log(2 * np.pi) * noise.size(-1) 23 | 24 | 25 | def squash(mu, pi, log_pi): 26 | """Apply squashing function. 27 | See appendix C from https://arxiv.org/pdf/1812.05905.pdf. 
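Concretely, with u ~ N(mu, std) and a = tanh(u), the change of variables gives
    log pi(a|s) = log N(u; mu, std) - sum_i log(1 - tanh(u_i)^2),
which is the correction applied to log_pi below; the relu and the 1e-6 term keep the
argument of the log positive for numerical stability.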
28 | """ 29 | mu = torch.tanh(mu) 30 | if pi is not None: 31 | pi = torch.tanh(pi) 32 | if log_pi is not None: 33 | log_pi -= torch.log(F.relu(1 - pi.pow(2)) + 1e-6).sum(-1, keepdim=True) 34 | return mu, pi, log_pi 35 | 36 | 37 | def weight_init(m): 38 | """Custom weight init for Conv2D and Linear layers.""" 39 | if isinstance(m, nn.Linear): 40 | nn.init.orthogonal_(m.weight.data) 41 | m.bias.data.fill_(0.0) 42 | elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 43 | # delta-orthogonal init from https://arxiv.org/pdf/1806.05393.pdf 44 | assert m.weight.size(2) == m.weight.size(3) 45 | m.weight.data.fill_(0.0) 46 | m.bias.data.fill_(0.0) 47 | mid = m.weight.size(2) // 2 48 | gain = nn.init.calculate_gain('relu') 49 | nn.init.orthogonal_(m.weight.data[:, :, mid, mid], gain) 50 | 51 | 52 | class Actor(nn.Module): 53 | """MLP actor network.""" 54 | def __init__(self, obs_shape, action_shape, hidden_dim, encoder_type, 55 | encoder_feature_dim, log_std_min, log_std_max, num_layers, 56 | num_filters): 57 | super().__init__() 58 | 59 | self.encoder = make_encoder(encoder_type, 60 | obs_shape, 61 | encoder_feature_dim, 62 | num_layers, 63 | num_filters, 64 | output_logits=True) 65 | 66 | self.log_std_min = log_std_min 67 | self.log_std_max = log_std_max 68 | 69 | self.trunk = nn.Sequential( 70 | nn.Linear(self.encoder.feature_dim, hidden_dim), nn.ReLU(), 71 | nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), 72 | nn.Linear(hidden_dim, 2 * action_shape[0])) 73 | 74 | self.outputs = dict() 75 | self.apply(weight_init) 76 | 77 | def forward(self, 78 | obs, 79 | compute_pi=True, 80 | compute_log_pi=True, 81 | detach_encoder=False): 82 | obs = self.encoder(obs, detach=detach_encoder) 83 | 84 | mu, log_std = self.trunk(obs).chunk(2, dim=-1) 85 | 86 | # constrain log_std inside [log_std_min, log_std_max] 87 | log_std = torch.tanh(log_std) 88 | log_std = self.log_std_min + 0.5 * (self.log_std_max - 89 | self.log_std_min) * (log_std + 1) 90 | 91 | self.outputs['mu'] = mu 92 | self.outputs['std'] = log_std.exp() 93 | 94 | if compute_pi: 95 | std = log_std.exp() 96 | noise = torch.randn_like(mu) 97 | pi = mu + noise * std 98 | else: 99 | pi = None 100 | entropy = None 101 | 102 | if compute_log_pi: 103 | log_pi = gaussian_logprob(noise, log_std) 104 | else: 105 | log_pi = None 106 | 107 | mu, pi, log_pi = squash(mu, pi, log_pi) 108 | 109 | return mu, pi, log_pi, log_std 110 | 111 | def log(self, L, step, log_freq=LOG_FREQ): 112 | if step % log_freq != 0: 113 | return 114 | 115 | for k, v in self.outputs.items(): 116 | L.log_histogram('train_actor/%s_hist' % k, v, step) 117 | 118 | L.log_param('train_actor/fc1', self.trunk[0], step) 119 | L.log_param('train_actor/fc2', self.trunk[2], step) 120 | L.log_param('train_actor/fc3', self.trunk[4], step) 121 | 122 | 123 | class QFunction(nn.Module): 124 | """MLP for q-function.""" 125 | def __init__(self, obs_dim, action_dim, hidden_dim): 126 | super().__init__() 127 | 128 | self.trunk = nn.Sequential(nn.Linear(obs_dim + action_dim, 129 | hidden_dim), nn.ReLU(), 130 | nn.Linear(hidden_dim, hidden_dim), 131 | nn.ReLU(), nn.Linear(hidden_dim, 1)) 132 | 133 | def forward(self, obs, action): 134 | assert obs.size(0) == action.size(0) 135 | 136 | obs_action = torch.cat([obs, action], dim=1) 137 | return self.trunk(obs_action) 138 | 139 | 140 | class Critic(nn.Module): 141 | """Critic network, employes two q-functions.""" 142 | def __init__(self, obs_shape, action_shape, hidden_dim, encoder_type, 143 | encoder_feature_dim, num_layers, num_filters): 144 | 
super().__init__() 145 | 146 | self.encoder = make_encoder(encoder_type, 147 | obs_shape, 148 | encoder_feature_dim, 149 | num_layers, 150 | num_filters, 151 | output_logits=True) 152 | 153 | self.Q1 = QFunction(self.encoder.feature_dim, action_shape[0], 154 | hidden_dim) 155 | self.Q2 = QFunction(self.encoder.feature_dim, action_shape[0], 156 | hidden_dim) 157 | 158 | self.outputs = dict() 159 | self.apply(weight_init) 160 | 161 | def forward(self, obs, action, detach_encoder=False): 162 | # detach_encoder allows to stop gradient propogation to encoder 163 | obs = self.encoder(obs, detach=detach_encoder) 164 | 165 | q1 = self.Q1(obs, action) 166 | q2 = self.Q2(obs, action) 167 | 168 | self.outputs['q1'] = q1 169 | self.outputs['q2'] = q2 170 | 171 | return q1, q2 172 | 173 | def log(self, L, step, log_freq=LOG_FREQ): 174 | if step % log_freq != 0: 175 | return 176 | 177 | self.encoder.log(L, step, log_freq) 178 | 179 | for k, v in self.outputs.items(): 180 | L.log_histogram('train_critic/%s_hist' % k, v, step) 181 | 182 | for i in range(3): 183 | L.log_param('train_critic/q1_fc%d' % i, self.Q1.trunk[i * 2], step) 184 | L.log_param('train_critic/q2_fc%d' % i, self.Q2.trunk[i * 2], step) 185 | 186 | 187 | class CURL(nn.Module): 188 | """ 189 | CURL 190 | """ 191 | def __init__(self, 192 | obs_shape, 193 | z_dim, 194 | batch_size, 195 | critic, 196 | critic_target, 197 | output_type="continuous"): 198 | super(CURL, self).__init__() 199 | self.batch_size = batch_size 200 | 201 | self.encoder = critic.encoder 202 | 203 | self.encoder_target = critic_target.encoder 204 | 205 | self.W = nn.Parameter(torch.rand(z_dim, z_dim)) 206 | self.output_type = output_type 207 | 208 | def encode(self, x, detach=False, ema=False): 209 | """ 210 | Encoder: z_t = e(x_t) 211 | :param x: x_t, x y coordinates 212 | :return: z_t, value in r2 213 | """ 214 | if ema: 215 | with torch.no_grad(): 216 | z_out = self.encoder_target(x) 217 | else: 218 | z_out = self.encoder(x) 219 | 220 | if detach: 221 | z_out = z_out.detach() 222 | return z_out 223 | 224 | def compute_logits(self, z_a, z_pos): 225 | """ 226 | Uses logits trick for CURL: 227 | - compute (B,B) matrix z_a (W z_pos.T) 228 | - positives are all diagonal elements 229 | - negatives are all other elements 230 | - to compute loss use multiclass cross entropy with identity matrix for labels 231 | """ 232 | Wz = torch.matmul(self.W, z_pos.T) # (z_dim,B) 233 | logits = torch.matmul(z_a, Wz) # (B,B) 234 | logits = logits - torch.max(logits, 1)[0][:, None] 235 | return logits 236 | 237 | 238 | class CurlSacAgent(object): 239 | """CURL representation learning with SAC.""" 240 | def __init__(self, 241 | obs_shape, 242 | action_shape, 243 | device, 244 | hidden_dim=256, 245 | discount=0.99, 246 | init_temperature=0.01, 247 | alpha_lr=1e-3, 248 | alpha_beta=0.9, 249 | actor_lr=1e-3, 250 | actor_beta=0.9, 251 | actor_log_std_min=-10, 252 | actor_log_std_max=2, 253 | actor_update_freq=2, 254 | critic_lr=1e-3, 255 | critic_beta=0.9, 256 | critic_tau=0.005, 257 | critic_target_update_freq=2, 258 | encoder_type='pixel', 259 | encoder_feature_dim=50, 260 | encoder_lr=1e-3, 261 | encoder_tau=0.005, 262 | num_layers=4, 263 | num_filters=32, 264 | cpc_update_freq=1, 265 | log_interval=100, 266 | detach_encoder=False, 267 | curl_latent_dim=128): 268 | self.device = device 269 | self.discount = discount 270 | self.critic_tau = critic_tau 271 | self.encoder_tau = encoder_tau 272 | self.actor_update_freq = actor_update_freq 273 | self.critic_target_update_freq = 
critic_target_update_freq 274 | self.cpc_update_freq = cpc_update_freq 275 | self.log_interval = log_interval 276 | self.image_size = obs_shape[-1] 277 | self.curl_latent_dim = curl_latent_dim 278 | self.detach_encoder = detach_encoder 279 | self.encoder_type = encoder_type 280 | 281 | self.actor = Actor(obs_shape, action_shape, hidden_dim, encoder_type, 282 | encoder_feature_dim, actor_log_std_min, 283 | actor_log_std_max, num_layers, 284 | num_filters).to(device) 285 | 286 | self.critic = Critic(obs_shape, action_shape, hidden_dim, encoder_type, 287 | encoder_feature_dim, num_layers, 288 | num_filters).to(device) 289 | 290 | self.critic_target = Critic(obs_shape, action_shape, hidden_dim, 291 | encoder_type, encoder_feature_dim, 292 | num_layers, num_filters).to(device) 293 | 294 | self.critic_target.load_state_dict(self.critic.state_dict()) 295 | 296 | # tie encoders between actor and critic, and CURL and critic 297 | self.actor.encoder.copy_conv_weights_from(self.critic.encoder) 298 | 299 | self.log_alpha = torch.tensor(np.log(init_temperature)).to(device) 300 | self.log_alpha.requires_grad = True 301 | # set target entropy to -|A| 302 | self.target_entropy = -np.prod(action_shape) 303 | 304 | # optimizers 305 | self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), 306 | lr=actor_lr, 307 | betas=(actor_beta, 0.999)) 308 | 309 | self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), 310 | lr=critic_lr, 311 | betas=(critic_beta, 0.999)) 312 | 313 | self.log_alpha_optimizer = torch.optim.Adam([self.log_alpha], 314 | lr=alpha_lr, 315 | betas=(alpha_beta, 0.999)) 316 | 317 | if self.encoder_type == 'pixel': 318 | # create CURL encoder (the 128 batch size is probably unnecessary) 319 | self.CURL = CURL(obs_shape, 320 | encoder_feature_dim, 321 | self.curl_latent_dim, 322 | self.critic, 323 | self.critic_target, 324 | output_type='continuous').to(self.device) 325 | 326 | # optimizer for critic encoder for reconstruction loss 327 | self.encoder_optimizer = torch.optim.Adam( 328 | self.critic.encoder.parameters(), lr=encoder_lr) 329 | 330 | self.cpc_optimizer = torch.optim.Adam(self.CURL.parameters(), 331 | lr=encoder_lr) 332 | self.cross_entropy_loss = nn.CrossEntropyLoss() 333 | 334 | self.train() 335 | self.critic_target.train() 336 | 337 | def train(self, training=True): 338 | self.training = training 339 | self.actor.train(training) 340 | self.critic.train(training) 341 | if self.encoder_type == 'pixel': 342 | self.CURL.train(training) 343 | 344 | @property 345 | def alpha(self): 346 | return self.log_alpha.exp() 347 | 348 | def select_action(self, obs): 349 | with torch.no_grad(): 350 | obs = torch.FloatTensor(obs).to(self.device) 351 | obs = obs.unsqueeze(0) 352 | mu, _, _, _ = self.actor(obs, 353 | compute_pi=False, 354 | compute_log_pi=False) 355 | return mu.cpu().data.numpy().flatten() 356 | 357 | def sample_action(self, obs): 358 | if obs.shape[-1] != self.image_size: 359 | obs = utils.center_crop_image(obs, self.image_size) 360 | 361 | with torch.no_grad(): 362 | obs = torch.FloatTensor(obs).to(self.device) 363 | obs = obs.unsqueeze(0) 364 | mu, pi, _, _ = self.actor(obs, compute_log_pi=False) 365 | return pi.cpu().data.numpy().flatten() 366 | 367 | def update_critic(self, obs, action, reward, next_obs, not_done, L, step): 368 | with torch.no_grad(): 369 | _, policy_action, log_pi, _ = self.actor(next_obs) 370 | target_Q1, target_Q2 = self.critic_target(next_obs, policy_action) 371 | target_V = torch.min(target_Q1, 372 | target_Q2) - self.alpha.detach() * 
log_pi 373 | target_Q = reward + (not_done * self.discount * target_V) 374 | 375 | # get current Q estimates 376 | current_Q1, current_Q2 = self.critic( 377 | obs, action, detach_encoder=self.detach_encoder) 378 | critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss( 379 | current_Q2, target_Q) 380 | if step % self.log_interval == 0: 381 | L.log('train_critic/loss', critic_loss, step) 382 | 383 | # Optimize the critic 384 | self.critic_optimizer.zero_grad() 385 | critic_loss.backward() 386 | self.critic_optimizer.step() 387 | 388 | self.critic.log(L, step) 389 | 390 | def update_actor_and_alpha(self, obs, L, step): 391 | # detach encoder, so we don't update it with the actor loss 392 | _, pi, log_pi, log_std = self.actor(obs, detach_encoder=True) 393 | actor_Q1, actor_Q2 = self.critic(obs, pi, detach_encoder=True) 394 | 395 | actor_Q = torch.min(actor_Q1, actor_Q2) 396 | actor_loss = (self.alpha.detach() * log_pi - actor_Q).mean() 397 | 398 | if step % self.log_interval == 0: 399 | L.log('train_actor/loss', actor_loss, step) 400 | L.log('train_actor/target_entropy', self.target_entropy, step) 401 | entropy = 0.5 * log_std.shape[1] * \ 402 | (1.0 + np.log(2 * np.pi)) + log_std.sum(dim=-1) 403 | if step % self.log_interval == 0: 404 | L.log('train_actor/entropy', entropy.mean(), step) 405 | 406 | # optimize the actor 407 | self.actor_optimizer.zero_grad() 408 | actor_loss.backward() 409 | self.actor_optimizer.step() 410 | 411 | self.actor.log(L, step) 412 | 413 | self.log_alpha_optimizer.zero_grad() 414 | alpha_loss = (self.alpha * 415 | (-log_pi - self.target_entropy).detach()).mean() 416 | if step % self.log_interval == 0: 417 | L.log('train_alpha/loss', alpha_loss, step) 418 | L.log('train_alpha/value', self.alpha, step) 419 | alpha_loss.backward() 420 | self.log_alpha_optimizer.step() 421 | 422 | def update_cpc(self, obs_anchor, obs_pos, cpc_kwargs, L, step): 423 | 424 | z_a = self.CURL.encode(obs_anchor) 425 | z_pos = self.CURL.encode(obs_pos, ema=True) 426 | 427 | logits = self.CURL.compute_logits(z_a, z_pos) 428 | labels = torch.arange(logits.shape[0]).long().to(self.device) 429 | loss = self.cross_entropy_loss(logits, labels) 430 | 431 | self.encoder_optimizer.zero_grad() 432 | self.cpc_optimizer.zero_grad() 433 | loss.backward() 434 | 435 | self.encoder_optimizer.step() 436 | self.cpc_optimizer.step() 437 | if step % self.log_interval == 0: 438 | L.log('train/curl_loss', loss, step) 439 | 440 | def update(self, replay_buffer, L, step): 441 | if self.encoder_type == 'pixel': 442 | obs, action, reward, next_obs, not_done, cpc_kwargs = replay_buffer.sample_cpc( 443 | ) 444 | else: 445 | obs, action, reward, next_obs, not_done = replay_buffer.sample_proprio( 446 | ) 447 | 448 | if step % self.log_interval == 0: 449 | L.log('train/batch_reward', reward.mean(), step) 450 | 451 | self.update_critic(obs, action, reward, next_obs, not_done, L, step) 452 | 453 | if step % self.actor_update_freq == 0: 454 | self.update_actor_and_alpha(obs, L, step) 455 | 456 | if step % self.critic_target_update_freq == 0: 457 | utils.soft_update_params(self.critic.Q1, self.critic_target.Q1, 458 | self.critic_tau) 459 | utils.soft_update_params(self.critic.Q2, self.critic_target.Q2, 460 | self.critic_tau) 461 | utils.soft_update_params(self.critic.encoder, 462 | self.critic_target.encoder, 463 | self.encoder_tau) 464 | 465 | if step % self.cpc_update_freq == 0 and self.encoder_type == 'pixel': 466 | obs_anchor, obs_pos = cpc_kwargs["obs_anchor"], cpc_kwargs[ 467 | "obs_pos"] 468 | 
self.update_cpc(obs_anchor, obs_pos, cpc_kwargs, L, step) 469 | 470 | def save(self, model_dir, step): 471 | torch.save(self.actor.state_dict(), 472 | '%s/actor_%s.pt' % (model_dir, step)) 473 | torch.save(self.critic.state_dict(), 474 | '%s/critic_%s.pt' % (model_dir, step)) 475 | 476 | def save_curl(self, model_dir, step): 477 | # if hasattr(self, 'CURL'): 478 | # torch.save(self.CURL.state_dict(), '%s/curl_%s.pt' % (model_dir, step)) 479 | torch.save(self.CURL.state_dict(), '%s/curl_%s.pt' % (model_dir, step)) 480 | 481 | def load(self, model_dir, step): 482 | self.actor.load_state_dict( 483 | torch.load('%s/actor_%s.pt' % (model_dir, step))) 484 | self.critic.load_state_dict( 485 | torch.load('%s/critic_%s.pt' % (model_dir, step))) 486 | --------------------------------------------------------------------------------