├── README.md ├── cfgs ├── config_dmc.yaml ├── config_rlbench.yaml ├── config_rlbench_drqv2plus.yaml ├── dmc_task │ ├── acrobot_swingup.yaml │ ├── cartpole_balance.yaml │ ├── cartpole_balance_sparse.yaml │ ├── cartpole_swingup.yaml │ ├── cartpole_swingup_sparse.yaml │ ├── cheetah_run.yaml │ ├── cup_catch.yaml │ ├── easy.yaml │ ├── finger_spin.yaml │ ├── finger_turn_easy.yaml │ ├── finger_turn_hard.yaml │ ├── hard.yaml │ ├── hopper_hop.yaml │ ├── hopper_stand.yaml │ ├── humanoid_run.yaml │ ├── humanoid_stand.yaml │ ├── humanoid_walk.yaml │ ├── medium.yaml │ ├── pendulum_swingup.yaml │ ├── quadruped_run.yaml │ ├── quadruped_walk.yaml │ ├── reach_duplo.yaml │ ├── reacher_easy.yaml │ ├── reacher_hard.yaml │ ├── walker_run.yaml │ ├── walker_stand.yaml │ └── walker_walk.yaml └── rlbench_task │ ├── basketball_in_hoop.yaml │ ├── default.yaml │ ├── insert_usb_in_computer.yaml │ ├── lamp_on.yaml │ ├── meat_on_grill.yaml │ ├── open_door.yaml │ ├── open_drawer.yaml │ ├── open_microwave.yaml │ ├── open_oven.yaml │ ├── phone_on_base.yaml │ ├── pick_up_cup.yaml │ ├── press_switch.yaml │ ├── put_books_on_bookshelf.yaml │ ├── put_money_in_safe.yaml │ ├── put_rubbish_in_bin.yaml │ ├── reach_target.yaml │ ├── slide_block_to_target.yaml │ ├── stack_wine.yaml │ ├── sweep_to_dustpan.yaml │ ├── take_lid_off_saucepan.yaml │ ├── take_plate_off_colored_dish_rack.yaml │ ├── toilet_seat_up.yaml │ └── turn_tap.yaml ├── conda_env.yml ├── cqn.py ├── cqn_dmc.py ├── cqn_utils.py ├── dmc.py ├── drqv2plus.py ├── logger.py ├── media ├── cqn_gif1.gif └── cqn_gif2.gif ├── replay_buffer.py ├── replay_buffer_dmc.py ├── rlbench_env.py ├── train_dmc.py ├── train_rlbench.py ├── train_rlbench_drqv2plus.py ├── utils.py └── video.py /README.md: -------------------------------------------------------------------------------- 1 | # Continuous Control with Coarse-to-fine Reinforcement Learning 2 | 3 | A re-implementation of **Coarse-to-fine Q-Network (CQN)**, a sample-efficient value-based RL algorithm for continuous control, introduced in: 4 | 5 | [**Continuous Control with Coarse-to-fine Reinforcement Learning**](https://younggyo.me/cqn/) 6 | 7 | [Younggyo Seo](https://younggyo.me/), [Jafar Uruç](https://github.com/JafarAbdi), [Stephen James](https://stepjam.github.io/) 8 | 9 | Our key idea is to learn RL agents that zoom into the continuous action space in a coarse-to-fine manner, enabling us to train value-based RL agents for continuous control with only a few discrete actions at each level. 10 | 11 | See our project webpage https://younggyo.me/cqn/ for more information. 12 | 13 | ![gif1](media/cqn_gif1.gif) 14 | ![gif2](media/cqn_gif2.gif) 15 | 16 | ## Instructions for RLBench experiments 17 | 18 | Install the conda environment: 19 | ``` 20 | conda env create -f conda_env.yml 21 | conda activate cqn 22 | ``` 23 | 24 | Install RLBench and PyRep (use the latest versions as of July 10, 2024). 25 | Follow the guides in the original repositories for (1) installing RLBench and PyRep and (2) enabling headless mode. (See the READMEs in [RLBench](https://github.com/stepjam/RLBench) \& [Robobase](https://github.com/robobase-org/robobase?tab=readme-ov-file#rlbench) for information on installing RLBench.) 26 | 27 | ``` 28 | git clone https://github.com/stepjam/RLBench 29 | git clone https://github.com/stepjam/PyRep 30 | # Install PyRep 31 | cd PyRep 32 | git checkout 8f420be8064b1970aae18a9cfbc978dfb15747ef 33 | pip install . 34 | # Install RLBench 35 | cd ../RLBench 36 | git checkout b80e51feb3694d9959cb8c0408cd385001b01382 37 | pip install . 38 | ```
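For intuition, here is a toy sketch of the coarse-to-fine discretization that CQN builds on: at every level the current interval is split into a small number of bins and the agent zooms into the selected bin. This is an illustrative example with made-up numbers, not the training code; the actual implementation lives in `cqn_utils.py` (`encode_action`, `decode_action`, `zoom_in`).

```
# Toy illustration of coarse-to-fine discretization (not the repository's training code)
import torch

levels, bins = 3, 5                                      # default config values
low, high = torch.tensor([-1.0]), torch.tensor([1.0])    # 1-D action space in [-1, 1]
target = torch.tensor([0.37])                            # continuous action to represent

for level in range(levels):
    width = (high - low) / bins                          # width of each bin at this level
    idx = ((target - low) / width).floor().clamp(0, bins - 1)
    low = low + idx * width                              # zoom into the selected bin
    high = low + width
    print(f"level {level}: bin {int(idx)}, interval [{low.item():.3f}, {high.item():.3f}]")

print("decoded action:", ((low + high) / 2).item())      # midpoint of final interval, ~0.368
```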
39 | 40 | Pre-collect demonstrations: 41 | ``` 42 | cd RLBench/rlbench 43 | CUDA_VISIBLE_DEVICES=0 DISPLAY=:0.0 python dataset_generator.py --save_path=/your/own/directory --image_size 84 84 --renderer opengl3 --episodes_per_task 100 --variations 1 --processes 1 --tasks take_lid_off_saucepan --arm_max_velocity 2.0 --arm_max_acceleration 8.0 44 | ``` 45 | 46 | Run experiments (CQN): 47 | ``` 48 | CUDA_VISIBLE_DEVICES=0 DISPLAY=:0.0 python train_rlbench.py rlbench_task=take_lid_off_saucepan num_demos=100 dataset_root=/your/own/directory 49 | ``` 50 | 51 | Run baseline experiments (DrQv2+): 52 | ``` 53 | CUDA_VISIBLE_DEVICES=0 DISPLAY=:0.0 python train_rlbench_drqv2plus.py rlbench_task=take_lid_off_saucepan num_demos=100 dataset_root=/your/own/directory 54 | ``` 55 | 56 | ## Instructions for DMC experiments 57 | 58 | Run experiments: 59 | ``` 60 | CUDA_VISIBLE_DEVICES=0 python train_dmc.py dmc_task=cartpole_swingup 61 | ``` 62 | 63 | Warning: CQN has not been extensively tested on DMC. 64 | 65 | 66 | ## Acknowledgements 67 | This repository is based on the public implementation of [DrQ-v2](https://github.com/facebookresearch/drqv2). 68 | 69 | 70 | ## Citation 71 | ``` 72 | @article{seo2024continuous, 73 | title={Continuous Control with Coarse-to-fine Reinforcement Learning}, 74 | author={Seo, Younggyo and Uru{\c{c}}, Jafar and James, Stephen}, 75 | journal={arXiv preprint arXiv:2407.07787}, 76 | year={2024} 77 | } 78 | ``` -------------------------------------------------------------------------------- /cfgs/config_dmc.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - dmc_task@_global_: quadruped_walk 4 | - override hydra/launcher: submitit_local 5 | 6 | # task settings 7 | frame_stack: 3 8 | action_repeat: 2 9 | discount: 0.99 10 | # train settings 11 | num_seed_frames: 4000 12 | # eval 13 | eval_every_frames: 10000 14 | num_eval_episodes: 10 15 | # snapshot 16 | save_snapshot: false 17 | # replay buffer 18 | replay_buffer_size: 1000000 19 | replay_buffer_num_workers: 4 20 | nstep: 3 21 | batch_size: 256 22 | # misc 23 | seed: 1 24 | device: cuda 25 | save_video: true 26 | save_train_video: false 27 | use_tb: true 28 | use_wandb: false 29 | # experiment 30 | experiment: exp 31 | # agent 32 | lr: 1e-4 33 | feature_dim: 64 34 | 35 | agent: 36 | _target_: cqn_dmc.CQNAgent 37 | obs_shape: ??? # to be specified later 38 | action_shape: ??? # to be specified later 39 | device: ${device} 40 | lr: ${lr} 41 | critic_target_tau: 0.02 42 | update_every_steps: 2 43 | use_logger: ???
# to be specified later 44 | num_expl_steps: 2000 45 | feature_dim: ${feature_dim} 46 | hidden_dim: 512 47 | levels: 3 48 | bins: 5 49 | atoms: 51 50 | v_min: 0 51 | v_max: 200 52 | stddev_schedule: 0.1 53 | 54 | wandb: 55 | project: cqn 56 | entity: rll 57 | name: name 58 | 59 | hydra: 60 | run: 61 | dir: ./exp_local/${now:%Y.%m.%d}/${now:%H%M%S}_${hydra.job.override_dirname} 62 | sweep: 63 | dir: ./exp/${now:%Y.%m.%d}/${now:%H%M}_${agent_cfg.experiment} 64 | subdir: ${hydra.job.num} 65 | launcher: 66 | timeout_min: 4300 67 | cpus_per_task: 10 68 | gpus_per_node: 1 69 | tasks_per_node: 1 70 | mem_gb: 160 71 | nodes: 1 72 | submitit_folder: ./exp/${now:%Y.%m.%d}/${now:%H%M%S}_${agent_cfg.experiment}/.slurm 73 | -------------------------------------------------------------------------------- /cfgs/config_rlbench.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - rlbench_task@_global_: reach_target 4 | - override hydra/launcher: submitit_local 5 | 6 | # task settings 7 | frame_stack: 8 8 | action_repeat: 1 9 | discount: 0.99 10 | # train settings 11 | num_seed_frames: 0 12 | # eval 13 | eval_every_frames: 2500 14 | num_eval_episodes: 25 15 | # snapshot 16 | save_snapshot: false 17 | # replay buffer 18 | replay_buffer_size: 1000000 19 | replay_buffer_num_workers: 4 20 | nstep: 1 21 | batch_size: 256 22 | use_relabeling: true 23 | # misc 24 | seed: 1 25 | device: cuda 26 | save_video: true 27 | save_train_video: false 28 | use_tb: true 29 | use_wandb: false 30 | # experiment 31 | experiment: exp 32 | # agent 33 | lr: 5e-5 34 | weight_decay: 0.1 35 | feature_dim: 64 36 | do_always_bootstrap: false # setting to True can sometimes work better 37 | num_update_steps: 1 38 | # environment 39 | arm_max_velocity: 2.0 40 | arm_max_acceleration: 8.0 41 | 42 | agent: 43 | _target_: cqn.CQNAgent 44 | rgb_obs_shape: ??? # to be specified later 45 | low_dim_obs_shape: ??? # to be specified later 46 | action_shape: ??? # to be specified later 47 | device: ${device} 48 | lr: ${lr} 49 | weight_decay: ${weight_decay} 50 | critic_target_tau: 0.02 51 | update_every_steps: 1 52 | use_logger: ??? # to be specified later 53 | num_expl_steps: 0 54 | feature_dim: ${feature_dim} 55 | hidden_dim: 512 56 | levels: 3 57 | bins: 5 58 | atoms: 51 59 | v_min: -2.0 # maybe -1.0/1.0? 
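  # Note (added comment): v_min/v_max define the support of the distributional (C51)
  # critic head (support = torch.linspace(v_min, v_max, atoms) in cqn.py), and Bellman
  # targets are clamped to this range, so it should cover the task's discounted returns.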
60 | v_max: 2.0 61 | critic_lambda: 0.1 62 | stddev_schedule: 0.01 63 | bc_lambda: 1.0 64 | bc_margin: 0.01 65 | 66 | wandb: 67 | project: cqn 68 | entity: rll 69 | name: name 70 | 71 | hydra: 72 | run: 73 | dir: ./exp_local/cqn_pixel_rlbench/${now:%Y%m%d%H%M%S} 74 | sweep: 75 | dir: ./exp/${now:%Y.%m.%d}/${now:%H%M}_${agent_cfg.experiment} 76 | subdir: ${hydra.job.num} 77 | launcher: 78 | timeout_min: 4300 79 | cpus_per_task: 10 80 | gpus_per_node: 1 81 | tasks_per_node: 1 82 | mem_gb: 160 83 | nodes: 1 84 | submitit_folder: ./exp/${now:%Y.%m.%d}/${now:%H%M%S}_${agent_cfg.experiment}/.slurm 85 | -------------------------------------------------------------------------------- /cfgs/config_rlbench_drqv2plus.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - rlbench_task@_global_: reach_target 4 | - override hydra/launcher: submitit_local 5 | 6 | # task settings 7 | frame_stack: 8 8 | action_repeat: 1 9 | discount: 0.99 10 | # train settings 11 | num_seed_frames: 0 12 | # eval 13 | eval_every_frames: 2500 14 | num_eval_episodes: 25 15 | # snapshot 16 | save_snapshot: false 17 | # replay buffer 18 | replay_buffer_size: 1000000 19 | replay_buffer_num_workers: 4 20 | nstep: 1 21 | batch_size: 256 22 | use_relabeling: true 23 | # misc 24 | seed: 1 25 | device: cuda 26 | save_video: true 27 | save_train_video: false 28 | use_tb: true 29 | use_wandb: false 30 | # experiment 31 | experiment: exp 32 | # agent 33 | lr: 1e-4 34 | weight_decay: 0.1 35 | feature_dim: 64 36 | do_always_bootstrap: false # setting to True can sometimes work better 37 | num_update_steps: 1 38 | # environment 39 | arm_max_velocity: 2.0 40 | arm_max_acceleration: 8.0 41 | 42 | agent: 43 | _target_: drqv2plus.DrQV2Agent 44 | rgb_obs_shape: ??? # to be specified later 45 | low_dim_obs_shape: ??? # to be specified later 46 | action_shape: ??? # to be specified later 47 | device: ${device} 48 | lr: ${lr} 49 | weight_decay: ${weight_decay} 50 | feature_dim: ${feature_dim} 51 | hidden_dim: 1024 52 | use_distributional_critic: true 53 | distributional_critic_limit: 2.0 54 | distributional_critic_atoms: 101 55 | distributional_critic_transform: false 56 | bc_lambda: 1.0 57 | critic_target_tau: 0.01 58 | num_expl_steps: 0 59 | update_every_steps: 1 60 | stddev_schedule: 0.01 61 | stddev_clip: 0.3 62 | use_logger: ??? 
# to be specified later 63 | 64 | wandb: 65 | project: cqn 66 | entity: younggyo 67 | name: name 68 | 69 | hydra: 70 | run: 71 | dir: ./exp_local/drqv2plus_pixel_rlbench/${now:%Y%m%d%H%M%S} 72 | sweep: 73 | dir: ./exp/${now:%Y.%m.%d}/${now:%H%M}_${agent_cfg.experiment} 74 | subdir: ${hydra.job.num} 75 | launcher: 76 | timeout_min: 4300 77 | cpus_per_task: 10 78 | gpus_per_node: 1 79 | tasks_per_node: 1 80 | mem_gb: 160 81 | nodes: 1 82 | submitit_folder: ./exp/${now:%Y.%m.%d}/${now:%H%M%S}_${agent_cfg.experiment}/.slurm 83 | -------------------------------------------------------------------------------- /cfgs/dmc_task/acrobot_swingup.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - medium 3 | - _self_ 4 | 5 | task_name: acrobot_swingup 6 | -------------------------------------------------------------------------------- /cfgs/dmc_task/cartpole_balance.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - easy 3 | - _self_ 4 | 5 | task_name: cartpole_balance 6 | -------------------------------------------------------------------------------- /cfgs/dmc_task/cartpole_balance_sparse.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - easy 3 | - _self_ 4 | 5 | task_name: cartpole_balance_sparse 6 | -------------------------------------------------------------------------------- /cfgs/dmc_task/cartpole_swingup.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - easy 3 | - _self_ 4 | 5 | task_name: cartpole_swingup 6 | -------------------------------------------------------------------------------- /cfgs/dmc_task/cartpole_swingup_sparse.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - medium 3 | - _self_ 4 | 5 | task_name: cartpole_swingup_sparse 6 | -------------------------------------------------------------------------------- /cfgs/dmc_task/cheetah_run.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - medium 3 | - _self_ 4 | 5 | task_name: cheetah_run 6 | -------------------------------------------------------------------------------- /cfgs/dmc_task/cup_catch.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - easy 3 | - _self_ 4 | 5 | task_name: cup_catch 6 | -------------------------------------------------------------------------------- /cfgs/dmc_task/easy.yaml: -------------------------------------------------------------------------------- 1 | num_train_frames: 1100000 2 | stddev_schedule: 'linear(1.0,0.1,100000)' -------------------------------------------------------------------------------- /cfgs/dmc_task/finger_spin.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - easy 3 | - _self_ 4 | 5 | task_name: finger_spin 6 | -------------------------------------------------------------------------------- /cfgs/dmc_task/finger_turn_easy.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - medium 3 | - _self_ 4 | 5 | task_name: finger_turn_easy -------------------------------------------------------------------------------- /cfgs/dmc_task/finger_turn_hard.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - medium 3 | - _self_ 4 | 5 | task_name: 
finger_turn_hard 6 | -------------------------------------------------------------------------------- /cfgs/dmc_task/hard.yaml: -------------------------------------------------------------------------------- 1 | num_train_frames: 30100000 2 | stddev_schedule: 'linear(1.0,0.1,2000000)' -------------------------------------------------------------------------------- /cfgs/dmc_task/hopper_hop.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - medium 3 | - _self_ 4 | 5 | task_name: hopper_hop 6 | -------------------------------------------------------------------------------- /cfgs/dmc_task/hopper_stand.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - easy 3 | - _self_ 4 | 5 | task_name: hopper_stand 6 | -------------------------------------------------------------------------------- /cfgs/dmc_task/humanoid_run.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - hard 3 | - _self_ 4 | 5 | task_name: humanoid_run 6 | lr: 8e-5 7 | feature_dim: 100 8 | -------------------------------------------------------------------------------- /cfgs/dmc_task/humanoid_stand.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - hard 3 | - _self_ 4 | 5 | task_name: humanoid_stand 6 | lr: 8e-5 7 | feature_dim: 100 8 | -------------------------------------------------------------------------------- /cfgs/dmc_task/humanoid_walk.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - hard 3 | - _self_ 4 | 5 | task_name: humanoid_walk 6 | lr: 8e-5 7 | feature_dim: 100 8 | -------------------------------------------------------------------------------- /cfgs/dmc_task/medium.yaml: -------------------------------------------------------------------------------- 1 | num_train_frames: 3100000 2 | stddev_schedule: 'linear(1.0,0.1,500000)' -------------------------------------------------------------------------------- /cfgs/dmc_task/pendulum_swingup.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - easy 3 | - _self_ 4 | 5 | task_name: pendulum_swingup 6 | -------------------------------------------------------------------------------- /cfgs/dmc_task/quadruped_run.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - medium 3 | - _self_ 4 | 5 | task_name: quadruped_run 6 | replay_buffer_size: 100000 -------------------------------------------------------------------------------- /cfgs/dmc_task/quadruped_walk.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - medium 3 | - _self_ 4 | 5 | task_name: quadruped_walk 6 | -------------------------------------------------------------------------------- /cfgs/dmc_task/reach_duplo.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - medium 3 | - _self_ 4 | 5 | task_name: reach_duplo 6 | -------------------------------------------------------------------------------- /cfgs/dmc_task/reacher_easy.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - medium 3 | - _self_ 4 | 5 | task_name: reacher_easy 6 | -------------------------------------------------------------------------------- /cfgs/dmc_task/reacher_hard.yaml: 
-------------------------------------------------------------------------------- 1 | defaults: 2 | - medium 3 | - _self_ 4 | 5 | task_name: reacher_hard 6 | -------------------------------------------------------------------------------- /cfgs/dmc_task/walker_run.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - medium 3 | - _self_ 4 | 5 | task_name: walker_run 6 | nstep: 1 7 | batch_size: 512 -------------------------------------------------------------------------------- /cfgs/dmc_task/walker_stand.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - easy 3 | - _self_ 4 | 5 | task_name: walker_stand 6 | nstep: 1 7 | batch_size: 512 8 | -------------------------------------------------------------------------------- /cfgs/dmc_task/walker_walk.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - easy 3 | - _self_ 4 | 5 | task_name: walker_walk 6 | nstep: 1 7 | batch_size: 512 8 | -------------------------------------------------------------------------------- /cfgs/rlbench_task/basketball_in_hoop.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | - _self_ 4 | 5 | task_name: basketball_in_hoop 6 | episode_length: 125 -------------------------------------------------------------------------------- /cfgs/rlbench_task/default.yaml: -------------------------------------------------------------------------------- 1 | dataset_root: '' 2 | camera_shape: [84,84] 3 | camera_keys: [front,wrist,left_shoulder,right_shoulder] 4 | state_keys: [joint_positions,gripper_open] 5 | renderer: opengl3 6 | num_demos: 100 7 | num_train_frames: 30250 -------------------------------------------------------------------------------- /cfgs/rlbench_task/insert_usb_in_computer.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | - _self_ 4 | 5 | task_name: insert_usb_in_computer 6 | episode_length: 100 -------------------------------------------------------------------------------- /cfgs/rlbench_task/lamp_on.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | - _self_ 4 | 5 | task_name: lamp_on 6 | episode_length: 100 -------------------------------------------------------------------------------- /cfgs/rlbench_task/meat_on_grill.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | - _self_ 4 | 5 | task_name: meat_on_grill 6 | episode_length: 150 -------------------------------------------------------------------------------- /cfgs/rlbench_task/open_door.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | - _self_ 4 | 5 | task_name: open_door 6 | episode_length: 125 -------------------------------------------------------------------------------- /cfgs/rlbench_task/open_drawer.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | - _self_ 4 | 5 | task_name: open_drawer 6 | episode_length: 100 -------------------------------------------------------------------------------- /cfgs/rlbench_task/open_microwave.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | - _self_ 4 | 5 | task_name: open_microwave 6 |
episode_length: 125 -------------------------------------------------------------------------------- /cfgs/rlbench_task/open_oven.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | - _self_ 4 | 5 | task_name: open_oven 6 | episode_length: 225 -------------------------------------------------------------------------------- /cfgs/rlbench_task/phone_on_base.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | - _self_ 4 | 5 | task_name: phone_on_base 6 | episode_length: 175 -------------------------------------------------------------------------------- /cfgs/rlbench_task/pick_up_cup.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | - _self_ 4 | 5 | task_name: pick_up_cup 6 | episode_length: 100 -------------------------------------------------------------------------------- /cfgs/rlbench_task/press_switch.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | - _self_ 4 | 5 | task_name: press_switch 6 | episode_length: 100 -------------------------------------------------------------------------------- /cfgs/rlbench_task/put_books_on_bookshelf.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | - _self_ 4 | 5 | task_name: put_books_on_bookshelf 6 | episode_length: 175 -------------------------------------------------------------------------------- /cfgs/rlbench_task/put_money_in_safe.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | - _self_ 4 | 5 | task_name: put_money_in_safe 6 | episode_length: 150 -------------------------------------------------------------------------------- /cfgs/rlbench_task/put_rubbish_in_bin.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | - _self_ 4 | 5 | task_name: put_rubbish_in_bin 6 | episode_length: 150 -------------------------------------------------------------------------------- /cfgs/rlbench_task/reach_target.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | - _self_ 4 | 5 | task_name: reach_target 6 | episode_length: 50 -------------------------------------------------------------------------------- /cfgs/rlbench_task/slide_block_to_target.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | - _self_ 4 | 5 | task_name: slide_block_to_target 6 | episode_length: 150 -------------------------------------------------------------------------------- /cfgs/rlbench_task/stack_wine.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | - _self_ 4 | 5 | task_name: stack_wine 6 | episode_length: 150 -------------------------------------------------------------------------------- /cfgs/rlbench_task/sweep_to_dustpan.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | - _self_ 4 | 5 | task_name: sweep_to_dustpan 6 | episode_length: 100 -------------------------------------------------------------------------------- /cfgs/rlbench_task/take_lid_off_saucepan.yaml: -------------------------------------------------------------------------------- 1 
| defaults: 2 | - default 3 | - _self_ 4 | 5 | task_name: take_lid_off_saucepan 6 | episode_length: 100 -------------------------------------------------------------------------------- /cfgs/rlbench_task/take_plate_off_colored_dish_rack.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | - _self_ 4 | 5 | task_name: take_plate_off_colored_dish_rack 6 | episode_length: 150 -------------------------------------------------------------------------------- /cfgs/rlbench_task/toilet_seat_up.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | - _self_ 4 | 5 | task_name: toilet_seat_up 6 | episode_length: 150 -------------------------------------------------------------------------------- /cfgs/rlbench_task/turn_tap.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | - _self_ 4 | 5 | task_name: turn_tap 6 | episode_length: 125 -------------------------------------------------------------------------------- /conda_env.yml: -------------------------------------------------------------------------------- 1 | name: cqn 2 | channels: 3 | - defaults 4 | dependencies: 5 | - python=3.10 6 | - pip 7 | - pip: 8 | - "torch==2.3.1" 9 | - "torchvision==0.18.1" 10 | - "termcolor==2.4.0" 11 | - "imageio==2.34.1" 12 | - "imageio_ffmpeg==0.5.1" 13 | - "hydra-core==1.1.0" 14 | - "hydra-submitit-launcher==1.1.5" 15 | - "opencv-python-headless==4.10.0.82" 16 | - "numpy==1.26.4" 17 | - "tensorboard==2.17.0" 18 | - "dm_env==1.6" 19 | - "dm_control==1.0.20" 20 | - "gymnasium==0.29.1" 21 | - "wandb==0.17.4" -------------------------------------------------------------------------------- /cqn.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | import utils 7 | from cqn_utils import ( 8 | random_action_if_within_delta, 9 | zoom_in, 10 | encode_action, 11 | decode_action, 12 | ) 13 | 14 | 15 | class RandomShiftsAug(nn.Module): 16 | def __init__(self, pad): 17 | super().__init__() 18 | self.pad = pad 19 | 20 | def forward(self, x): 21 | n, c, h, w = x.size() 22 | assert h == w 23 | padding = tuple([self.pad] * 4) 24 | x = F.pad(x, padding, "replicate") 25 | eps = 1.0 / (h + 2 * self.pad) 26 | arange = torch.linspace( 27 | -1.0 + eps, 1.0 - eps, h + 2 * self.pad, device=x.device, dtype=x.dtype 28 | )[:h] 29 | arange = arange.unsqueeze(0).repeat(h, 1).unsqueeze(2) 30 | base_grid = torch.cat([arange, arange.transpose(1, 0)], dim=2) 31 | base_grid = base_grid.unsqueeze(0).repeat(n, 1, 1, 1) 32 | 33 | shift = torch.randint( 34 | 0, 2 * self.pad + 1, size=(n, 1, 1, 2), device=x.device, dtype=x.dtype 35 | ) 36 | shift *= 2.0 / (h + 2 * self.pad) 37 | 38 | grid = base_grid + shift 39 | return F.grid_sample(x, grid, padding_mode="zeros", align_corners=False) 40 | 41 | 42 | class ImgChLayerNorm(nn.Module): 43 | def __init__(self, num_channels, eps: float = 1e-5): 44 | super().__init__() 45 | self.weight = nn.Parameter(torch.ones(num_channels)) 46 | self.bias = nn.Parameter(torch.zeros(num_channels)) 47 | self.eps = eps 48 | 49 | def forward(self, x): 50 | # x: [B, C, H, W] 51 | u = x.mean(1, keepdim=True) 52 | s = (x - u).pow(2).mean(1, keepdim=True) 53 | x = (x - u) / torch.sqrt(s + self.eps) 54 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 55 | return x 56 | 57 | 58 | 
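# Multi-view pixel encoder (added comment): each camera view (front, wrist, and the two
# shoulder cameras in the RLBench configs) is processed by its own small CNN, and the
# flattened per-view features are concatenated into a single vector of size
# num_views * 256 * 5 * 5 for 84x84 inputs.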
class MultiViewCNNEncoder(nn.Module): 59 | def __init__(self, obs_shape): 60 | super().__init__() 61 | 62 | assert len(obs_shape) == 4 63 | self.num_views = obs_shape[0] 64 | self.repr_dim = self.num_views * 256 * 5 * 5 # for 84,84. hard-coded 65 | 66 | self.conv_nets = nn.ModuleList() 67 | for _ in range(self.num_views): 68 | conv_net = nn.Sequential( 69 | nn.Conv2d(obs_shape[1], 32, 4, stride=2, padding=1), 70 | ImgChLayerNorm(32), 71 | nn.SiLU(), 72 | nn.Conv2d(32, 64, 4, stride=2, padding=1), 73 | ImgChLayerNorm(64), 74 | nn.SiLU(), 75 | nn.Conv2d(64, 128, 4, stride=2, padding=1), 76 | ImgChLayerNorm(128), 77 | nn.SiLU(), 78 | nn.Conv2d(128, 256, 4, stride=2, padding=1), 79 | ImgChLayerNorm(256), 80 | nn.SiLU(), 81 | ) 82 | self.conv_nets.append(conv_net) 83 | 84 | self.apply(utils.weight_init) 85 | 86 | def forward(self, obs: torch.Tensor): 87 | # obs: [B, V, C, H, W] 88 | obs = obs / 255.0 - 0.5 89 | hs = [] 90 | for v in range(self.num_views): 91 | h = self.conv_nets[v](obs[:, v]) 92 | h = h.view(h.shape[0], -1) 93 | hs.append(h) 94 | h = torch.cat(hs, -1) 95 | return h 96 | 97 | 98 | class C2FCriticNetwork(nn.Module): 99 | def __init__( 100 | self, 101 | repr_dim: int, 102 | low_dim: int, 103 | action_shape: Tuple, 104 | feature_dim: int, 105 | hidden_dim: int, 106 | levels: int, 107 | bins: int, 108 | atoms: int, 109 | ): 110 | super().__init__() 111 | self._levels = levels 112 | self._actor_dim = action_shape[0] 113 | self._bins = bins 114 | 115 | # Advantage stream in Dueling network 116 | self.adv_rgb_encoder = nn.Sequential( 117 | nn.Linear(repr_dim, feature_dim, bias=False), 118 | nn.LayerNorm(feature_dim), 119 | nn.Tanh(), 120 | ) 121 | self.adv_low_dim_encoder = nn.Sequential( 122 | nn.Linear(low_dim, feature_dim, bias=False), 123 | nn.LayerNorm(feature_dim), 124 | nn.Tanh(), 125 | ) 126 | self.adv_net = nn.Sequential( 127 | nn.Linear( 128 | feature_dim * 2 + self._actor_dim + levels, hidden_dim, bias=False 129 | ), 130 | nn.LayerNorm(hidden_dim), 131 | nn.SiLU(), 132 | nn.Linear(hidden_dim, hidden_dim, bias=False), 133 | nn.LayerNorm(hidden_dim), 134 | nn.SiLU(), 135 | ) 136 | self.adv_head = nn.Linear(hidden_dim, self._actor_dim * bins * atoms) 137 | self.adv_output_shape = (self._actor_dim, bins, atoms) 138 | 139 | # Value stream in Dueling network 140 | self.value_rgb_encoder = nn.Sequential( 141 | nn.Linear(repr_dim, feature_dim, bias=False), 142 | nn.LayerNorm(feature_dim), 143 | nn.Tanh(), 144 | ) 145 | self.value_low_dim_encoder = nn.Sequential( 146 | nn.Linear(low_dim, feature_dim, bias=False), 147 | nn.LayerNorm(feature_dim), 148 | nn.Tanh(), 149 | ) 150 | self.value_net = nn.Sequential( 151 | nn.Linear( 152 | feature_dim * 2 + self._actor_dim + levels, hidden_dim, bias=False 153 | ), 154 | nn.LayerNorm(hidden_dim), 155 | nn.SiLU(), 156 | nn.Linear(hidden_dim, hidden_dim, bias=False), 157 | nn.LayerNorm(hidden_dim), 158 | nn.SiLU(), 159 | ) 160 | self.value_head = nn.Linear(hidden_dim, self._actor_dim * 1 * atoms) 161 | self.value_output_shape = (self._actor_dim, 1, atoms) 162 | 163 | self.apply(utils.weight_init) 164 | self.adv_head.weight.data.fill_(0.0) 165 | self.adv_head.bias.data.fill_(0.0) 166 | self.value_head.weight.data.fill_(0.0) 167 | self.value_head.bias.data.fill_(0.0) 168 | 169 | def forward( 170 | self, level: int, rgb_obs: torch.Tensor, low_dim_obs, prev_action: torch.Tensor 171 | ): 172 | """ 173 | Inputs: 174 | - level: level index 175 | - obs: features from visual encoder 176 | - low_dim_obs: low-dimensional observations 177 | - prev_action: 
actions from previous level 178 | 179 | Outputs: 180 | - q_logits: (batch_size, action_dim, bins, atoms) 181 | """ 182 | level_id = ( 183 | torch.eye(self._levels, device=rgb_obs.device, dtype=rgb_obs.dtype)[level] 184 | .unsqueeze(0) 185 | .repeat_interleave(rgb_obs.shape[0], 0) 186 | ) 187 | 188 | value_h = torch.cat( 189 | [self.value_rgb_encoder(rgb_obs), self.value_low_dim_encoder(low_dim_obs)], 190 | -1, 191 | ) 192 | value_x = torch.cat([value_h, prev_action, level_id], -1) 193 | values = self.value_head(self.value_net(value_x)).view( 194 | -1, *self.value_output_shape 195 | ) 196 | 197 | adv_h = torch.cat( 198 | [self.adv_rgb_encoder(rgb_obs), self.adv_low_dim_encoder(low_dim_obs)], -1 199 | ) 200 | adv_x = torch.cat([adv_h, prev_action, level_id], -1) 201 | advs = self.adv_head(self.adv_net(adv_x)).view(-1, *self.adv_output_shape) 202 | 203 | q_logits = values + advs - advs.mean(-2, keepdim=True) 204 | return q_logits 205 | 206 | 207 | class C2FCritic(nn.Module): 208 | def __init__( 209 | self, 210 | action_shape: tuple, 211 | repr_dim: int, 212 | low_dim: int, 213 | feature_dim: int, 214 | hidden_dim: int, 215 | levels: int, 216 | bins: int, 217 | atoms: int, 218 | v_min: float, 219 | v_max: float, 220 | ): 221 | super().__init__() 222 | 223 | self.levels = levels 224 | self.bins = bins 225 | self.atoms = atoms 226 | self.v_min = v_min 227 | self.v_max = v_max 228 | actor_dim = action_shape[0] 229 | self.initial_low = nn.Parameter( 230 | torch.FloatTensor([-1.0] * actor_dim), requires_grad=False 231 | ) 232 | self.initial_high = nn.Parameter( 233 | torch.FloatTensor([1.0] * actor_dim), requires_grad=False 234 | ) 235 | self.support = nn.Parameter( 236 | torch.linspace(v_min, v_max, atoms), requires_grad=False 237 | ) 238 | self.delta_z = (v_max - v_min) / (atoms - 1) 239 | 240 | self.network = C2FCriticNetwork( 241 | repr_dim, 242 | low_dim, 243 | action_shape, 244 | feature_dim, 245 | hidden_dim, 246 | levels, 247 | bins, 248 | atoms, 249 | ) 250 | 251 | def get_action(self, rgb_obs: torch.Tensor, low_dim_obs: torch.Tensor): 252 | metrics = dict() 253 | low = self.initial_low.repeat(rgb_obs.shape[0], 1).detach() 254 | high = self.initial_high.repeat(rgb_obs.shape[0], 1).detach() 255 | 256 | for level in range(self.levels): 257 | q_logits = self.network(level, rgb_obs, low_dim_obs, (low + high) / 2) 258 | q_probs = F.softmax(q_logits, 3) 259 | qs = (q_probs * self.support.expand_as(q_probs).detach()).sum(3) 260 | argmax_q = random_action_if_within_delta(qs) 261 | if argmax_q is None: 262 | argmax_q = qs.max(-1)[1] # [..., D] 263 | # Zoom-in 264 | low, high = zoom_in(low, high, argmax_q, self.bins) 265 | 266 | # for logging 267 | qs_a = torch.gather(qs, dim=-1, index=argmax_q.unsqueeze(-1))[ 268 | ..., 0 269 | ] # [..., D] 270 | metrics[f"critic_target_q_level{level}"] = qs_a.mean().item() 271 | continuous_action = (high + low) / 2.0 # [..., D] 272 | return continuous_action, metrics 273 | 274 | def forward( 275 | self, 276 | rgb_obs: torch.Tensor, 277 | low_dim_obs: torch.Tensor, 278 | continuous_action: torch.Tensor, 279 | ): 280 | """Compute value distributions for given obs and action. 
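        The continuous action is first encoded into per-level bin indices. The critic
        network is then queried once per level; after each level the (low, high) interval
        is zoomed into the bin given by the encoded action and its midpoint is passed to
        the next level as prev_action, mirroring get_action.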
281 | 282 | Args: 283 | obs: [B, repr_dim] shaped feature tensor 284 | low_dim_obs: [B, low_dim] shaped feature tensor 285 | continuous_action: [B, D] shaped action tensor 286 | 287 | Return: 288 | q_probs: [B, L, D, bins, atoms] for value distribution at all bins 289 | q_probs_a: [B, L, D, atoms] for value distribution at given bin 290 | log_q_probs: [B, L, D, bins, atoms] with log probabilities 291 | log_q_probs_a: [B, L, D, atoms] with log probabilities 292 | """ 293 | 294 | discrete_action = encode_action( 295 | continuous_action, 296 | self.initial_low, 297 | self.initial_high, 298 | self.levels, 299 | self.bins, 300 | ) 301 | 302 | q_probs_per_level = [] 303 | q_probs_a_per_level = [] 304 | log_q_probs_per_level = [] 305 | log_q_probs_a_per_level = [] 306 | 307 | low = self.initial_low.repeat(rgb_obs.shape[0], 1).detach() 308 | high = self.initial_high.repeat(rgb_obs.shape[0], 1).detach() 309 | for level in range(self.levels): 310 | q_logits = self.network(level, rgb_obs, low_dim_obs, (low + high) / 2) 311 | argmax_q = discrete_action[..., level, :].long() # [..., L, D] -> [..., D] 312 | 313 | # (Log) Probs [..., D, bins, atoms] 314 | # (Log) Probs_a [..., D, atoms] 315 | q_probs = F.softmax(q_logits, 3) # [B, D, bins, atoms] 316 | q_probs_a = torch.gather( 317 | q_probs, 318 | dim=-2, 319 | index=argmax_q.unsqueeze(-1) 320 | .unsqueeze(-1) 321 | .repeat_interleave(self.atoms, -1), 322 | ) 323 | q_probs_a = q_probs_a[..., 0, :] # [B, D, atoms] 324 | 325 | log_q_probs = F.log_softmax(q_logits, 3) # [B, D, bins, atoms] 326 | log_q_probs_a = torch.gather( 327 | log_q_probs, 328 | dim=-2, 329 | index=argmax_q.unsqueeze(-1) 330 | .unsqueeze(-1) 331 | .repeat_interleave(self.atoms, -1), 332 | ) 333 | log_q_probs_a = log_q_probs_a[..., 0, :] # [B, D, atoms] 334 | 335 | q_probs_per_level.append(q_probs) 336 | q_probs_a_per_level.append(q_probs_a) 337 | log_q_probs_per_level.append(log_q_probs) 338 | log_q_probs_a_per_level.append(log_q_probs_a) 339 | 340 | # Zoom-in 341 | low, high = zoom_in(low, high, argmax_q, self.bins) 342 | 343 | q_probs = torch.stack(q_probs_per_level, -4) # [B, L, D, bins, atoms] 344 | q_probs_a = torch.stack(q_probs_a_per_level, -3) # [B, L, D, atoms] 345 | log_q_probs = torch.stack(log_q_probs_per_level, -4) 346 | log_q_probs_a = torch.stack(log_q_probs_a_per_level, -3) 347 | return q_probs, q_probs_a, log_q_probs, log_q_probs_a 348 | 349 | def compute_target_q_dist( 350 | self, 351 | next_rgb_obs: torch.Tensor, 352 | next_low_dim_obs: torch.Tensor, 353 | next_continuous_action: torch.Tensor, 354 | reward: torch.Tensor, 355 | discount: torch.Tensor, 356 | ): 357 | """Compute target distribution for distributional critic 358 | based on https://github.com/Kaixhin/Rainbow/blob/master/agent.py implementation 359 | 360 | Args: 361 | next_rgb_obs: [B, repr_dim] shaped feature tensor 362 | next_low_dim_obs: [B, low_dim] shaped feature tensor 363 | next_continuous_action: [B, D] shaped action tensor 364 | reward: [B, 1] shaped reward tensor 365 | discount: [B, 1] shaped discount tensor 366 | 367 | Return: 368 | m: [B, L, D, atoms] shaped tensor for value distribution 369 | """ 370 | next_q_probs_a = self.forward( 371 | next_rgb_obs, next_low_dim_obs, next_continuous_action 372 | )[1] 373 | 374 | shape = next_q_probs_a.shape # [B, L, D, atoms] 375 | next_q_probs_a = next_q_probs_a.view(-1, self.atoms) 376 | batch_size = next_q_probs_a.shape[0] 377 | 378 | # Compute Tz for [B, atoms] 379 | Tz = reward + discount * self.support.unsqueeze(0).detach() 380 | Tz = 
Tz.clamp(min=self.v_min, max=self.v_max) 381 | # Compute L2 projection of Tz onto fixed support z 382 | b = (Tz - self.v_min) / self.delta_z 383 | lower, upper = b.floor().to(torch.int64), b.ceil().to(torch.int64) 384 | # Fix disappearing probability mass when l =b = u (b is int) 385 | lower[(upper > 0) * (lower == upper)] -= 1 386 | upper[(lower < (self.atoms - 1)) * (lower == upper)] += 1 387 | 388 | # Repeat Tz for (L * D) times -> [B * L * D, atoms] 389 | multiplier = batch_size // lower.shape[0] 390 | b = torch.repeat_interleave(b, multiplier, 0) 391 | lower = torch.repeat_interleave(lower, multiplier, 0) 392 | upper = torch.repeat_interleave(upper, multiplier, 0) 393 | 394 | # Distribute probability of Tz 395 | m = torch.zeros_like(next_q_probs_a) 396 | offset = ( 397 | torch.linspace( 398 | 0, 399 | ((batch_size - 1) * self.atoms), 400 | batch_size, 401 | device=lower.device, 402 | dtype=lower.dtype, 403 | ) 404 | .unsqueeze(1) 405 | .expand(batch_size, self.atoms) 406 | ) 407 | m.view(-1).index_add_( 408 | 0, 409 | (lower + offset).view(-1), 410 | (next_q_probs_a * (upper.float() - b)).view(-1), 411 | ) # m_l = m_l + p(s_t+n, a*)(u - b) 412 | m.view(-1).index_add_( 413 | 0, 414 | (upper + offset).view(-1), 415 | (next_q_probs_a * (b - lower.float())).view(-1), 416 | ) # m_u = m_u + p(s_t+n, a*)(b - l) 417 | 418 | m = m.view(*shape) # [B, L, D, atoms] 419 | return m 420 | 421 | def encode_decode_action(self, continuous_action: torch.Tensor): 422 | """Encode and decode actions""" 423 | discrete_action = encode_action( 424 | continuous_action, 425 | self.initial_low, 426 | self.initial_high, 427 | self.levels, 428 | self.bins, 429 | ) 430 | continuous_action = decode_action( 431 | discrete_action, 432 | self.initial_low, 433 | self.initial_high, 434 | self.levels, 435 | self.bins, 436 | ) 437 | return continuous_action 438 | 439 | 440 | class CQNAgent: 441 | def __init__( 442 | self, 443 | rgb_obs_shape, 444 | low_dim_obs_shape, 445 | action_shape, 446 | device, 447 | lr, 448 | feature_dim, 449 | hidden_dim, 450 | levels, 451 | bins, 452 | atoms, 453 | v_min, 454 | v_max, 455 | bc_lambda, 456 | bc_margin, 457 | critic_lambda, 458 | critic_target_tau, 459 | weight_decay, 460 | num_expl_steps, 461 | update_every_steps, 462 | stddev_schedule, 463 | use_logger, 464 | ): 465 | self.device = device 466 | self.critic_target_tau = critic_target_tau 467 | self.update_every_steps = update_every_steps 468 | self.use_logger = use_logger 469 | self.num_expl_steps = num_expl_steps 470 | self.stddev_schedule = stddev_schedule 471 | self.bc_lambda = bc_lambda 472 | self.bc_margin = bc_margin 473 | self.critic_lambda = critic_lambda 474 | 475 | # models 476 | self.encoder = MultiViewCNNEncoder(rgb_obs_shape).to(device) 477 | self.critic = C2FCritic( 478 | action_shape, 479 | self.encoder.repr_dim, 480 | low_dim_obs_shape[-1], 481 | feature_dim, 482 | hidden_dim, 483 | levels, 484 | bins, 485 | atoms, 486 | v_min, 487 | v_max, 488 | ).to(device) 489 | self.critic_target = C2FCritic( 490 | action_shape, 491 | self.encoder.repr_dim, 492 | low_dim_obs_shape[-1], 493 | feature_dim, 494 | hidden_dim, 495 | levels, 496 | bins, 497 | atoms, 498 | v_min, 499 | v_max, 500 | ).to(device) 501 | self.critic_target.load_state_dict(self.critic.state_dict()) 502 | 503 | # optimizers 504 | self.encoder_opt = torch.optim.AdamW( 505 | self.encoder.parameters(), lr=lr, weight_decay=weight_decay 506 | ) 507 | self.critic_opt = torch.optim.AdamW( 508 | self.critic.parameters(), lr=lr, weight_decay=weight_decay 509 | ) 510 
| 511 | # data augmentation 512 | self.aug = RandomShiftsAug(pad=4) 513 | 514 | self.train() 515 | self.critic_target.eval() 516 | 517 | print(self.encoder) 518 | print(self.critic) 519 | 520 | def train(self, training=True): 521 | self.training = training 522 | self.encoder.train(training) 523 | self.critic.train(training) 524 | 525 | def act(self, rgb_obs, low_dim_obs, step, eval_mode): 526 | rgb_obs = torch.as_tensor(rgb_obs, device=self.device).unsqueeze(0) 527 | low_dim_obs = torch.as_tensor(low_dim_obs, device=self.device).unsqueeze(0) 528 | rgb_obs = self.encoder(rgb_obs) 529 | stddev = utils.schedule(self.stddev_schedule, step) 530 | action, _ = self.critic_target.get_action( 531 | rgb_obs, low_dim_obs 532 | ) # use critic_target 533 | stddev = torch.ones_like(action) * stddev 534 | dist = utils.TruncatedNormal(action, stddev) 535 | if eval_mode: 536 | action = dist.mean 537 | else: 538 | action = dist.sample(clip=None) 539 | if step < self.num_expl_steps: 540 | action.uniform_(-1.0, 1.0) 541 | action = self.critic.encode_decode_action(action) 542 | return action.cpu().numpy()[0] 543 | 544 | def update_critic( 545 | self, 546 | rgb_obs, 547 | low_dim_obs, 548 | action, 549 | reward, 550 | discount, 551 | next_rgb_obs, 552 | next_low_dim_obs, 553 | demos, 554 | ): 555 | metrics = dict() 556 | 557 | with torch.no_grad(): 558 | next_action, mets = self.critic.get_action(next_rgb_obs, next_low_dim_obs) 559 | target_q_probs_a = self.critic_target.compute_target_q_dist( 560 | next_rgb_obs, next_low_dim_obs, next_action, reward, discount 561 | ) 562 | if self.use_logger: 563 | metrics.update(**mets) 564 | 565 | # Cross entropy loss for C51 566 | q_probs, q_probs_a, log_q_probs, log_q_probs_a = self.critic( 567 | rgb_obs, low_dim_obs, action 568 | ) 569 | q_critic_loss = -torch.sum(target_q_probs_a * log_q_probs_a, 3).mean() 570 | critic_loss = self.critic_lambda * q_critic_loss 571 | 572 | if self.use_logger: 573 | metrics["q_critic_loss"] = q_critic_loss.item() 574 | 575 | if self.bc_lambda > 0.0: 576 | qs = None 577 | demos = demos.float().squeeze(1) # [B,] 578 | if self.use_logger: 579 | metrics["ratio_of_demos"] = demos.mean().item() 580 | 581 | if torch.sum(demos) > 0: 582 | # q_probs: [B, L, D, bins, atoms], q_probs_a: [B, L, D, atoms] 583 | q_probs_cdf = torch.cumsum(q_probs, -1) 584 | q_probs_a_cdf = torch.cumsum(q_probs_a, -1) 585 | # q_probs_{a_{i}} is stochastically dominant over q_probs_{a_{-i}} 586 | bc_fosd_loss = ( 587 | (q_probs_a_cdf.unsqueeze(-2) - q_probs_cdf) 588 | .clamp(min=0) 589 | .sum(-1) 590 | .mean([-1, -2, -3]) 591 | ) 592 | bc_fosd_loss = (bc_fosd_loss * demos).sum() / demos.sum() 593 | critic_loss = critic_loss + self.bc_lambda * bc_fosd_loss 594 | if self.use_logger: 595 | metrics["bc_fosd_loss"] = bc_fosd_loss.item() 596 | 597 | if self.bc_margin > 0: 598 | qs = (q_probs * self.critic.support.expand_as(q_probs)).sum(-1) 599 | qs_a = (q_probs_a * self.critic.support.expand_as(q_probs_a)).sum( 600 | -1 601 | ) 602 | margin_loss = torch.clamp( 603 | self.bc_margin - (qs_a.unsqueeze(-1) - qs), min=0 604 | ).mean([-1, -2, -3]) 605 | margin_loss = (margin_loss * demos).sum() / demos.sum() 606 | critic_loss = critic_loss + self.bc_lambda * margin_loss 607 | if self.use_logger: 608 | metrics["bc_margin_loss"] = margin_loss.item() 609 | 610 | # optimize encoder and critic 611 | self.encoder_opt.zero_grad(set_to_none=True) 612 | self.critic_opt.zero_grad(set_to_none=True) 613 | critic_loss.backward() 614 | self.critic_opt.step() 615 | self.encoder_opt.step() 616 | 
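        # Loss recap (added comment): critic_loss = critic_lambda * C51 cross-entropy to the
        # projected target distribution, plus (on demonstration transitions, when enabled)
        # bc_lambda-weighted stochastic-dominance and Q-margin behavior-cloning terms; the
        # shared encoder is updated through the same backward pass.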
617 | return metrics 618 | 619 | def update(self, replay_iter, step): 620 | metrics = dict() 621 | 622 | if step % self.update_every_steps != 0: 623 | return metrics 624 | 625 | batch = next(replay_iter) 626 | ( 627 | rgb_obs, 628 | low_dim_obs, 629 | action, 630 | reward, 631 | discount, 632 | next_rgb_obs, 633 | next_low_dim_obs, 634 | demos, 635 | ) = utils.to_torch(batch, self.device) 636 | 637 | # augment 638 | rgb_obs = rgb_obs.float() 639 | next_rgb_obs = next_rgb_obs.float() 640 | rgb_obs = torch.stack( 641 | [self.aug(rgb_obs[:, v]) for v in range(rgb_obs.shape[1])], 1 642 | ) 643 | next_rgb_obs = torch.stack( 644 | [self.aug(next_rgb_obs[:, v]) for v in range(next_rgb_obs.shape[1])], 1 645 | ) 646 | # encode 647 | rgb_obs = self.encoder(rgb_obs) 648 | with torch.no_grad(): 649 | next_rgb_obs = self.encoder(next_rgb_obs) 650 | 651 | if self.use_logger: 652 | metrics["batch_reward"] = reward.mean().item() 653 | 654 | # update critic 655 | metrics.update( 656 | self.update_critic( 657 | rgb_obs, 658 | low_dim_obs, 659 | action, 660 | reward, 661 | discount, 662 | next_rgb_obs, 663 | next_low_dim_obs, 664 | demos, 665 | ) 666 | ) 667 | 668 | # update critic target 669 | utils.soft_update_params( 670 | self.critic, self.critic_target, self.critic_target_tau 671 | ) 672 | 673 | return metrics 674 | -------------------------------------------------------------------------------- /cqn_dmc.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | from functools import partial 3 | from typing import Tuple 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | import utils 10 | from cqn_utils import ( 11 | random_action_if_within_delta, 12 | zoom_in, 13 | encode_action, 14 | decode_action, 15 | ) 16 | 17 | 18 | class RandomShiftsAug(nn.Module): 19 | def __init__(self, pad): 20 | super().__init__() 21 | self.pad = pad 22 | 23 | def forward(self, x): 24 | n, c, h, w = x.size() 25 | assert h == w 26 | padding = tuple([self.pad] * 4) 27 | x = F.pad(x, padding, "replicate") 28 | eps = 1.0 / (h + 2 * self.pad) 29 | arange = torch.linspace( 30 | -1.0 + eps, 1.0 - eps, h + 2 * self.pad, device=x.device, dtype=x.dtype 31 | )[:h] 32 | arange = arange.unsqueeze(0).repeat(h, 1).unsqueeze(2) 33 | base_grid = torch.cat([arange, arange.transpose(1, 0)], dim=2) 34 | base_grid = base_grid.unsqueeze(0).repeat(n, 1, 1, 1) 35 | 36 | shift = torch.randint( 37 | 0, 2 * self.pad + 1, size=(n, 1, 1, 2), device=x.device, dtype=x.dtype 38 | ) 39 | shift *= 2.0 / (h + 2 * self.pad) 40 | 41 | grid = base_grid + shift 42 | return F.grid_sample(x, grid, padding_mode="zeros", align_corners=False) 43 | 44 | 45 | class ImgChLayerNorm(nn.Module): 46 | def __init__(self, num_channels, eps: float = 1e-5): 47 | super().__init__() 48 | self.weight = nn.Parameter(torch.ones(num_channels)) 49 | self.bias = nn.Parameter(torch.zeros(num_channels)) 50 | self.eps = eps 51 | 52 | def forward(self, x): 53 | # x: [B, C, H, W] 54 | u = x.mean(1, keepdim=True) 55 | s = (x - u).pow(2).mean(1, keepdim=True) 56 | x = (x - u) / torch.sqrt(s + self.eps) 57 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 58 | return x 59 | 60 | 61 | class Encoder(nn.Module): 62 | def __init__(self, obs_shape): 63 | super().__init__() 64 | 65 | assert len(obs_shape) == 3 66 | self.repr_dim = 32 * 35 * 35 67 | 68 | self.convnet = nn.Sequential( 69 | nn.Conv2d(obs_shape[0], 32, 3, stride=2), 70 | nn.GroupNorm(1, 32), 71 | nn.SiLU(inplace=True), 72 
| nn.Conv2d(32, 32, 3, stride=1), 73 | nn.GroupNorm(1, 32), 74 | nn.SiLU(inplace=True), 75 | nn.Conv2d(32, 32, 3, stride=1), 76 | nn.GroupNorm(1, 32), 77 | nn.SiLU(inplace=True), 78 | nn.Conv2d(32, 32, 3, stride=1), 79 | nn.GroupNorm(1, 32), 80 | nn.SiLU(inplace=True), 81 | ) 82 | 83 | self.apply(utils.weight_init) 84 | 85 | def forward(self, obs): 86 | obs = obs / 255.0 - 0.5 87 | h = self.convnet(obs) 88 | h = h.view(h.shape[0], -1) 89 | return h 90 | 91 | 92 | class C2FCriticNetwork(nn.Module): 93 | def __init__( 94 | self, 95 | repr_dim: int, 96 | action_shape: Tuple, 97 | feature_dim: int, 98 | hidden_dim: int, 99 | levels: int, 100 | bins: int, 101 | atoms: int, 102 | ): 103 | super().__init__() 104 | self._levels = levels 105 | self._actor_dim = action_shape[0] 106 | self._bins = bins 107 | 108 | # Advantage stream in Dueling network 109 | self.adv_trunk = nn.Sequential( 110 | nn.Linear(repr_dim, feature_dim, bias=False), 111 | nn.LayerNorm(feature_dim), 112 | nn.Tanh(), 113 | ) 114 | self.adv_net = nn.Sequential( 115 | nn.Linear(feature_dim + self._actor_dim + levels, hidden_dim, bias=False), 116 | nn.LayerNorm(hidden_dim), 117 | nn.SiLU(inplace=True), 118 | nn.Linear(hidden_dim, hidden_dim, bias=False), 119 | nn.LayerNorm(hidden_dim), 120 | nn.SiLU(inplace=True), 121 | ) 122 | self.adv_head = nn.Linear(hidden_dim, self._actor_dim * bins * atoms) 123 | self.adv_output_shape = (self._actor_dim, bins, atoms) 124 | 125 | # Value stream in Dueling network 126 | self.value_trunk = nn.Sequential( 127 | nn.Linear(repr_dim, feature_dim, bias=False), 128 | nn.LayerNorm(feature_dim), 129 | nn.Tanh(), 130 | ) 131 | self.value_net = nn.Sequential( 132 | nn.Linear(feature_dim + self._actor_dim + levels, hidden_dim, bias=False), 133 | nn.LayerNorm(hidden_dim), 134 | nn.SiLU(inplace=True), 135 | nn.Linear(hidden_dim, hidden_dim, bias=False), 136 | nn.LayerNorm(hidden_dim), 137 | nn.SiLU(inplace=True), 138 | ) 139 | self.value_head = nn.Linear(hidden_dim, self._actor_dim * 1 * atoms) 140 | self.value_output_shape = (self._actor_dim, 1, atoms) 141 | 142 | self.apply(utils.weight_init) 143 | self.adv_head.weight.data.fill_(0.0) 144 | self.adv_head.bias.data.fill_(0.0) 145 | self.value_head.weight.data.fill_(0.0) 146 | self.value_head.bias.data.fill_(0.0) 147 | 148 | def forward(self, level: int, obs: torch.Tensor, prev_action: torch.Tensor): 149 | """ 150 | Inputs: 151 | - level: level index 152 | - obs: features from visual encoder 153 | - prev_action: actions from previous level 154 | 155 | Outputs: 156 | - q_logits: (batch_size, action_dim, bins, atoms) 157 | """ 158 | level_id = ( 159 | torch.eye(self._levels, device=obs.device, dtype=obs.dtype)[level] 160 | .unsqueeze(0) 161 | .repeat_interleave(obs.shape[0], 0) 162 | ) 163 | 164 | value_h = self.value_trunk(obs) 165 | value_x = torch.cat([value_h, prev_action, level_id], -1) 166 | values = self.value_head(self.value_net(value_x)).view( 167 | -1, *self.value_output_shape 168 | ) 169 | 170 | adv_h = self.adv_trunk(obs) 171 | adv_x = torch.cat([adv_h, prev_action, level_id], -1) 172 | advs = self.adv_head(self.adv_net(adv_x)).view(-1, *self.adv_output_shape) 173 | 174 | q_logits = values + advs - advs.mean(-2, keepdim=True) 175 | return q_logits 176 | 177 | 178 | class C2FCritic(nn.Module): 179 | def __init__( 180 | self, 181 | action_shape: tuple, 182 | repr_dim: int, 183 | feature_dim: int, 184 | hidden_dim: int, 185 | levels: int, 186 | bins: int, 187 | atoms: int, 188 | v_min: float, 189 | v_max: float, 190 | ): 191 | super().__init__() 192 | 
193 | self.levels = levels 194 | self.bins = bins 195 | self.atoms = atoms 196 | self.v_min = v_min 197 | self.v_max = v_max 198 | actor_dim = action_shape[0] 199 | self.initial_low = nn.Parameter( 200 | torch.FloatTensor([-1.0] * actor_dim), requires_grad=False 201 | ) 202 | self.initial_high = nn.Parameter( 203 | torch.FloatTensor([1.0] * actor_dim), requires_grad=False 204 | ) 205 | self.support = nn.Parameter( 206 | torch.linspace(v_min, v_max, atoms), requires_grad=False 207 | ) 208 | self.delta_z = (v_max - v_min) / (atoms - 1) 209 | 210 | self.network = C2FCriticNetwork( 211 | repr_dim, action_shape, feature_dim, hidden_dim, levels, bins, atoms 212 | ) 213 | 214 | def get_action(self, obs: torch.Tensor): 215 | metrics = dict() 216 | low = self.initial_low.repeat(obs.shape[0], 1).detach() 217 | high = self.initial_high.repeat(obs.shape[0], 1).detach() 218 | 219 | for level in range(self.levels): 220 | q_logits = self.network(level, obs, (low + high) / 2) 221 | q_probs = F.softmax(q_logits, 3) 222 | qs = (q_probs * self.support.expand_as(q_probs).detach()).sum(3) 223 | argmax_q = random_action_if_within_delta(qs) 224 | if argmax_q is None: 225 | argmax_q = qs.max(-1)[1] # [..., D] 226 | # Zoom-in 227 | low, high = zoom_in(low, high, argmax_q, self.bins) 228 | 229 | # for logging 230 | qs_a = torch.gather(qs, dim=-1, index=argmax_q.unsqueeze(-1))[ 231 | ..., 0 232 | ] # [..., D] 233 | metrics[f"critic_target_q_level{level}"] = qs_a.mean().item() 234 | continuous_action = (high + low) / 2.0 # [..., D] 235 | return continuous_action, metrics 236 | 237 | def forward( 238 | self, 239 | obs: torch.Tensor, 240 | continuous_action: torch.Tensor, 241 | ): 242 | """Compute value distributions for given obs and action. 243 | 244 | Args: 245 | obs: [B, F] shaped feature tensor 246 | continuous_action: [B, D] shaped action tensor 247 | 248 | Return: 249 | q_probs: [B, L, D, bins, atoms] for value distribution at all bins 250 | q_probs_a: [B, L, D, atoms] for value distribution at given bin 251 | log_q_probs: [B, L, D, bins, atoms] with log probabilities 252 | log_q_probs_a: [B, L, D, atoms] with log probabilities 253 | """ 254 | 255 | discrete_action = encode_action( 256 | continuous_action, 257 | self.initial_low, 258 | self.initial_high, 259 | self.levels, 260 | self.bins, 261 | ) 262 | 263 | q_probs_per_level = [] 264 | q_probs_a_per_level = [] 265 | log_q_probs_per_level = [] 266 | log_q_probs_a_per_level = [] 267 | 268 | low = self.initial_low.repeat(obs.shape[0], 1).detach() 269 | high = self.initial_high.repeat(obs.shape[0], 1).detach() 270 | for level in range(self.levels): 271 | q_logits = self.network(level, obs, (low + high) / 2) 272 | argmax_q = discrete_action[..., level, :].long() # [..., L, D] -> [..., D] 273 | 274 | # (Log) Probs [..., D, bins, atoms] 275 | # (Log) Probs_a [..., D, atoms] 276 | q_probs = F.softmax(q_logits, 3) # [B, D, bins, atoms] 277 | q_probs_a = torch.gather( 278 | q_probs, 279 | dim=-2, 280 | index=argmax_q.unsqueeze(-1) 281 | .unsqueeze(-1) 282 | .repeat_interleave(self.atoms, -1), 283 | ) 284 | q_probs_a = q_probs_a[..., 0, :] # [B, D, atoms] 285 | 286 | log_q_probs = F.log_softmax(q_logits, 3) # [B, D, bins, atoms] 287 | log_q_probs_a = torch.gather( 288 | log_q_probs, 289 | dim=-2, 290 | index=argmax_q.unsqueeze(-1) 291 | .unsqueeze(-1) 292 | .repeat_interleave(self.atoms, -1), 293 | ) 294 | log_q_probs_a = log_q_probs_a[..., 0, :] # [B, D, atoms] 295 | 296 | q_probs_per_level.append(q_probs) 297 | q_probs_a_per_level.append(q_probs_a) 298 | 
log_q_probs_per_level.append(log_q_probs) 299 | log_q_probs_a_per_level.append(log_q_probs_a) 300 | 301 | # Zoom-in 302 | low, high = zoom_in(low, high, argmax_q, self.bins) 303 | 304 | q_probs = torch.stack(q_probs_per_level, -4) # [B, L, D, bins, atoms] 305 | q_probs_a = torch.stack(q_probs_a_per_level, -3) # [B, L, D, atoms] 306 | log_q_probs = torch.stack(log_q_probs_per_level, -4) 307 | log_q_probs_a = torch.stack(log_q_probs_a_per_level, -3) 308 | return q_probs, q_probs_a, log_q_probs, log_q_probs_a 309 | 310 | def compute_target_q_dist( 311 | self, 312 | next_obs: torch.Tensor, 313 | next_continuous_action: torch.Tensor, 314 | reward: torch.Tensor, 315 | discount: torch.Tensor, 316 | ): 317 | """Compute target distribution for distributional critic 318 | based on https://github.com/Kaixhin/Rainbow/blob/master/agent.py implementation 319 | 320 | Args: 321 | next_obs: [B, F] shaped feature tensor 322 | next_continuous_action: [B, D] shaped action tensor 323 | reward: [B, 1] shaped reward tensor 324 | discount: [B, 1] shaped discount tensor 325 | 326 | Return: 327 | m: [B, L, D, atoms] shaped tensor for value distribution 328 | """ 329 | next_q_probs_a = self.forward(next_obs, next_continuous_action)[1] 330 | 331 | shape = next_q_probs_a.shape # [B, L, D, atoms] 332 | next_q_probs_a = next_q_probs_a.view(-1, self.atoms) 333 | batch_size = next_q_probs_a.shape[0] 334 | 335 | # Compute Tz for [B, atoms] 336 | Tz = reward + discount * self.support.unsqueeze(0).detach() 337 | Tz = Tz.clamp(min=self.v_min, max=self.v_max) 338 | # Compute L2 projection of Tz onto fixed support z 339 | b = (Tz - self.v_min) / self.delta_z 340 | lower, upper = b.floor().to(torch.int64), b.ceil().to(torch.int64) 341 | # Fix disappearing probability mass when l =b = u (b is int) 342 | lower[(upper > 0) * (lower == upper)] -= 1 343 | upper[(lower < (self.atoms - 1)) * (lower == upper)] += 1 344 | 345 | # Repeat Tz for (L * D) times -> [B * L * D, atoms] 346 | multiplier = batch_size // lower.shape[0] 347 | b = torch.repeat_interleave(b, multiplier, 0) 348 | lower = torch.repeat_interleave(lower, multiplier, 0) 349 | upper = torch.repeat_interleave(upper, multiplier, 0) 350 | 351 | # Distribute probability of Tz 352 | m = torch.zeros_like(next_q_probs_a) 353 | offset = ( 354 | torch.linspace( 355 | 0, 356 | ((batch_size - 1) * self.atoms), 357 | batch_size, 358 | device=lower.device, 359 | dtype=lower.dtype, 360 | ) 361 | .unsqueeze(1) 362 | .expand(batch_size, self.atoms) 363 | ) 364 | m.view(-1).index_add_( 365 | 0, 366 | (lower + offset).view(-1), 367 | (next_q_probs_a * (upper.float() - b)).view(-1), 368 | ) # m_l = m_l + p(s_t+n, a*)(u - b) 369 | m.view(-1).index_add_( 370 | 0, 371 | (upper + offset).view(-1), 372 | (next_q_probs_a * (b - lower.float())).view(-1), 373 | ) # m_u = m_u + p(s_t+n, a*)(b - l) 374 | 375 | m = m.view(*shape) # [B, L, D, atoms] 376 | return m 377 | 378 | def encode_decode_action(self, continuous_action: torch.Tensor): 379 | """Encode and decode actions""" 380 | discrete_action = encode_action( 381 | continuous_action, 382 | self.initial_low, 383 | self.initial_high, 384 | self.levels, 385 | self.bins, 386 | ) 387 | continuous_action = decode_action( 388 | discrete_action, 389 | self.initial_low, 390 | self.initial_high, 391 | self.levels, 392 | self.bins, 393 | ) 394 | return continuous_action 395 | 396 | 397 | class CQNAgent: 398 | def __init__( 399 | self, 400 | obs_shape, 401 | action_shape, 402 | device, 403 | lr, 404 | feature_dim, 405 | hidden_dim, 406 | levels, 407 | bins, 
408 | atoms, 409 | v_min, 410 | v_max, 411 | critic_target_tau, 412 | num_expl_steps, 413 | update_every_steps, 414 | stddev_schedule, 415 | use_logger, 416 | ): 417 | self.device = device 418 | self.critic_target_tau = critic_target_tau 419 | self.update_every_steps = update_every_steps 420 | self.use_logger = use_logger 421 | self.num_expl_steps = num_expl_steps 422 | self.stddev_schedule = stddev_schedule 423 | 424 | # models 425 | self.encoder = Encoder(obs_shape).to(device) 426 | self.critic = C2FCritic( 427 | action_shape, 428 | self.encoder.repr_dim, 429 | feature_dim, 430 | hidden_dim, 431 | levels, 432 | bins, 433 | atoms, 434 | v_min, 435 | v_max, 436 | ).to(device) 437 | self.critic_target = C2FCritic( 438 | action_shape, 439 | self.encoder.repr_dim, 440 | feature_dim, 441 | hidden_dim, 442 | levels, 443 | bins, 444 | atoms, 445 | v_min, 446 | v_max, 447 | ).to(device) 448 | self.critic_target.load_state_dict(self.critic.state_dict()) 449 | 450 | # optimizers 451 | self.encoder_opt = torch.optim.Adam(self.encoder.parameters(), lr=lr) 452 | self.critic_opt = torch.optim.Adam(self.critic.parameters(), lr=lr) 453 | 454 | # data augmentation 455 | self.aug = RandomShiftsAug(pad=4) 456 | 457 | self.train() 458 | self.critic_target.eval() 459 | 460 | print(self.encoder) 461 | print(self.critic) 462 | 463 | def train(self, training=True): 464 | self.training = training 465 | self.encoder.train(training) 466 | self.critic.train(training) 467 | 468 | def act(self, obs, step, eval_mode): 469 | obs = torch.as_tensor(obs, device=self.device) 470 | obs = self.encoder(obs.unsqueeze(0)) 471 | stddev = utils.schedule(self.stddev_schedule, step) 472 | action, _ = self.critic.get_action(obs) # use critic_target 473 | stddev = torch.ones_like(action) * stddev 474 | dist = utils.TruncatedNormal(action, stddev) 475 | if eval_mode: 476 | action = dist.mean 477 | else: 478 | action = dist.sample(clip=None) 479 | if step < self.num_expl_steps: 480 | action.uniform_(-1.0, 1.0) 481 | action = self.critic.encode_decode_action(action) 482 | return action.cpu().numpy()[0] 483 | 484 | def update_critic(self, obs, action, reward, discount, next_obs): 485 | metrics = dict() 486 | 487 | with torch.no_grad(): 488 | next_action, mets = self.critic.get_action(next_obs) 489 | target_q_probs_a = self.critic_target.compute_target_q_dist( 490 | next_obs, next_action, reward, discount 491 | ) 492 | if self.use_logger: 493 | metrics.update(**mets) 494 | 495 | # Cross entropy loss for C51 496 | log_q_probs_a = self.critic(obs, action)[3] 497 | critic_loss = -torch.sum(target_q_probs_a * log_q_probs_a, 3).mean() 498 | 499 | if self.use_logger: 500 | metrics["critic_loss"] = critic_loss.item() 501 | 502 | # optimize encoder and critic 503 | self.encoder_opt.zero_grad(set_to_none=True) 504 | self.critic_opt.zero_grad(set_to_none=True) 505 | critic_loss.backward() 506 | self.critic_opt.step() 507 | self.encoder_opt.step() 508 | 509 | return metrics 510 | 511 | def update(self, replay_iter, step): 512 | metrics = dict() 513 | 514 | if step % self.update_every_steps != 0: 515 | return metrics 516 | 517 | batch = next(replay_iter) 518 | obs, action, reward, discount, next_obs = utils.to_torch(batch, self.device) 519 | 520 | # augment 521 | obs = self.aug(obs.float()) 522 | next_obs = self.aug(next_obs.float()) 523 | # encode 524 | obs = self.encoder(obs) 525 | with torch.no_grad(): 526 | next_obs = self.encoder(next_obs) 527 | 528 | if self.use_logger: 529 | metrics["batch_reward"] = reward.mean().item() 530 | 531 | # update 
critic 532 | metrics.update(self.update_critic(obs, action, reward, discount, next_obs)) 533 | 534 | # update critic target 535 | utils.soft_update_params( 536 | self.critic, self.critic_target, self.critic_target_tau 537 | ) 538 | 539 | return metrics 540 | -------------------------------------------------------------------------------- /cqn_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def random_action_if_within_delta(qs, delta=0.0001): 5 | q_diff = qs.max(-1).values - qs.min(-1).values 6 | random_action_mask = q_diff < delta 7 | if random_action_mask.sum() == 0: 8 | return None 9 | argmax_q = qs.max(-1)[1] 10 | random_actions = torch.randint(0, qs.size(-1), random_action_mask.shape).to( 11 | qs.device 12 | ) 13 | argmax_q = torch.where(random_action_mask, random_actions, argmax_q) 14 | return argmax_q 15 | 16 | 17 | def encode_action( 18 | continuous_action: torch.Tensor, 19 | initial_low: torch.Tensor, 20 | initial_high: torch.Tensor, 21 | levels: int, 22 | bins: int, 23 | ): 24 | """Encode continuous action to discrete action 25 | 26 | Args: 27 | continuous_action: [..., D] shape tensor 28 | initial_low: [D] shape tensor consisting of -1 29 | initial_high: [D] shape tensor consisting of 1 30 | Returns: 31 | discrete_action: [..., L, D] shape tensor where L is the level 32 | """ 33 | low = initial_low.repeat(*continuous_action.shape[:-1], 1) 34 | high = initial_high.repeat(*continuous_action.shape[:-1], 1) 35 | 36 | idxs = [] 37 | for _ in range(levels): 38 | # Put continuous values into bin 39 | slice_range = (high - low) / bins 40 | idx = torch.floor((continuous_action - low) / slice_range) 41 | idx = torch.clip(idx, 0, bins - 1) 42 | idxs.append(idx) 43 | 44 | # Re-compute low/high for each bin (i.e., Zoom-in) 45 | recalculated_action = low + slice_range * idx 46 | recalculated_action = torch.clip(recalculated_action, -1.0, 1.0) 47 | low = recalculated_action 48 | high = recalculated_action + slice_range 49 | low = torch.maximum(-torch.ones_like(low), low) 50 | high = torch.minimum(torch.ones_like(high), high) 51 | discrete_action = torch.stack(idxs, -2) 52 | return discrete_action 53 | 54 | 55 | def decode_action( 56 | discrete_action: torch.Tensor, 57 | initial_low: torch.Tensor, 58 | initial_high: torch.Tensor, 59 | levels: int, 60 | bins: int, 61 | ): 62 | """Decode discrete action to continuous action 63 | 64 | Args: 65 | discrete_action: [..., L, D] shape tensor 66 | initial_low: [D] shape tensor consisting of -1 67 | initial_high: [D] shape tensor consisting of 1 68 | Returns: 69 | continuous_action: [..., D] shape tensor 70 | """ 71 | low = initial_low.repeat(*discrete_action.shape[:-2], 1) 72 | high = initial_high.repeat(*discrete_action.shape[:-2], 1) 73 | for i in range(levels): 74 | slice_range = (high - low) / bins 75 | continuous_action = low + slice_range * discrete_action[..., i, :] 76 | low = continuous_action 77 | high = continuous_action + slice_range 78 | low = torch.maximum(-torch.ones_like(low), low) 79 | high = torch.minimum(torch.ones_like(high), high) 80 | continuous_action = (high + low) / 2.0 81 | return continuous_action 82 | 83 | 84 | def zoom_in(low: torch.Tensor, high: torch.Tensor, argmax_q: torch.Tensor, bins: int): 85 | """Zoom-in to the selected interval 86 | 87 | Args: 88 | low: [D] shape tensor that denotes minimum of the current interval 89 | high: [D] shape tensor that denotes maximum of the current interval 90 | Returns: 91 | low: [D] shape tensor that denotes minimum 
of the *next* interval 92 | high: [D] shape tensor that denotes maximum of the *next* interval 93 | """ 94 | slice_range = (high - low) / bins 95 | continuous_action = low + slice_range * argmax_q 96 | low = continuous_action 97 | high = continuous_action + slice_range 98 | low = torch.maximum(-torch.ones_like(low), low) 99 | high = torch.minimum(torch.ones_like(high), high) 100 | return low, high 101 | -------------------------------------------------------------------------------- /dmc.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | from typing import Any, NamedTuple 3 | 4 | import dm_env 5 | import numpy as np 6 | from dm_control import manipulation, suite 7 | from dm_control.suite.wrappers import action_scale, pixels 8 | from dm_env import StepType, specs 9 | 10 | 11 | class ExtendedTimeStep(NamedTuple): 12 | step_type: Any 13 | reward: Any 14 | discount: Any 15 | observation: Any 16 | action: Any 17 | 18 | def first(self): 19 | return self.step_type == StepType.FIRST 20 | 21 | def mid(self): 22 | return self.step_type == StepType.MID 23 | 24 | def last(self): 25 | return self.step_type == StepType.LAST 26 | 27 | def __getitem__(self, attr): 28 | if isinstance(attr, str): 29 | return getattr(self, attr) 30 | else: 31 | return tuple.__getitem__(self, attr) 32 | 33 | 34 | class ActionRepeatWrapper(dm_env.Environment): 35 | def __init__(self, env, num_repeats): 36 | self._env = env 37 | self._num_repeats = num_repeats 38 | 39 | def step(self, action): 40 | reward = 0.0 41 | discount = 1.0 42 | for i in range(self._num_repeats): 43 | time_step = self._env.step(action) 44 | reward += (time_step.reward or 0.0) * discount 45 | discount *= time_step.discount 46 | if time_step.last(): 47 | break 48 | 49 | return time_step._replace(reward=reward, discount=discount) 50 | 51 | def observation_spec(self): 52 | return self._env.observation_spec() 53 | 54 | def action_spec(self): 55 | return self._env.action_spec() 56 | 57 | def reset(self): 58 | return self._env.reset() 59 | 60 | def __getattr__(self, name): 61 | return getattr(self._env, name) 62 | 63 | 64 | class FrameStackWrapper(dm_env.Environment): 65 | def __init__(self, env, num_frames, pixels_key="pixels"): 66 | self._env = env 67 | self._num_frames = num_frames 68 | self._frames = deque([], maxlen=num_frames) 69 | self._pixels_key = pixels_key 70 | 71 | wrapped_obs_spec = env.observation_spec() 72 | assert pixels_key in wrapped_obs_spec 73 | 74 | pixels_shape = wrapped_obs_spec[pixels_key].shape 75 | # remove batch dim 76 | if len(pixels_shape) == 4: 77 | pixels_shape = pixels_shape[1:] 78 | self._obs_spec = specs.BoundedArray( 79 | shape=np.concatenate( 80 | [[pixels_shape[2] * num_frames], pixels_shape[:2]], axis=0 81 | ), 82 | dtype=np.uint8, 83 | minimum=0, 84 | maximum=255, 85 | name="observation", 86 | ) 87 | 88 | def _transform_observation(self, time_step): 89 | assert len(self._frames) == self._num_frames 90 | obs = np.concatenate(list(self._frames), axis=0) 91 | return time_step._replace(observation=obs) 92 | 93 | def _extract_pixels(self, time_step): 94 | pixels = time_step.observation[self._pixels_key] 95 | # remove batch dim 96 | if len(pixels.shape) == 4: 97 | pixels = pixels[0] 98 | return pixels.transpose(2, 0, 1).copy() 99 | 100 | def reset(self): 101 | time_step = self._env.reset() 102 | pixels = self._extract_pixels(time_step) 103 | for _ in range(self._num_frames): 104 | self._frames.append(pixels) 105 | return 
self._transform_observation(time_step) 106 | 107 | def step(self, action): 108 | time_step = self._env.step(action) 109 | pixels = self._extract_pixels(time_step) 110 | self._frames.append(pixels) 111 | return self._transform_observation(time_step) 112 | 113 | def observation_spec(self): 114 | return self._obs_spec 115 | 116 | def action_spec(self): 117 | return self._env.action_spec() 118 | 119 | def __getattr__(self, name): 120 | return getattr(self._env, name) 121 | 122 | 123 | class ActionDTypeWrapper(dm_env.Environment): 124 | def __init__(self, env, dtype): 125 | self._env = env 126 | wrapped_action_spec = env.action_spec() 127 | self._action_spec = specs.BoundedArray( 128 | wrapped_action_spec.shape, 129 | dtype, 130 | wrapped_action_spec.minimum, 131 | wrapped_action_spec.maximum, 132 | "action", 133 | ) 134 | 135 | def step(self, action): 136 | action = action.astype(self._env.action_spec().dtype) 137 | return self._env.step(action) 138 | 139 | def observation_spec(self): 140 | return self._env.observation_spec() 141 | 142 | def action_spec(self): 143 | return self._action_spec 144 | 145 | def reset(self): 146 | return self._env.reset() 147 | 148 | def __getattr__(self, name): 149 | return getattr(self._env, name) 150 | 151 | 152 | class ExtendedTimeStepWrapper(dm_env.Environment): 153 | def __init__(self, env): 154 | self._env = env 155 | 156 | def reset(self): 157 | time_step = self._env.reset() 158 | return self._augment_time_step(time_step) 159 | 160 | def step(self, action): 161 | time_step = self._env.step(action) 162 | return self._augment_time_step(time_step, action) 163 | 164 | def _augment_time_step(self, time_step, action=None): 165 | if action is None: 166 | action_spec = self.action_spec() 167 | action = np.zeros(action_spec.shape, dtype=action_spec.dtype) 168 | return ExtendedTimeStep( 169 | observation=time_step.observation, 170 | step_type=time_step.step_type, 171 | action=action, 172 | reward=time_step.reward or 0.0, 173 | discount=time_step.discount or 1.0, 174 | ) 175 | 176 | def observation_spec(self): 177 | return self._env.observation_spec() 178 | 179 | def action_spec(self): 180 | return self._env.action_spec() 181 | 182 | def __getattr__(self, name): 183 | return getattr(self._env, name) 184 | 185 | 186 | def make(name, frame_stack, action_repeat, seed): 187 | domain, task = name.split("_", 1) 188 | # overwrite cup to ball_in_cup 189 | domain = dict(cup="ball_in_cup").get(domain, domain) 190 | # make sure reward is not visualized 191 | if (domain, task) in suite.ALL_TASKS: 192 | env = suite.load( 193 | domain, task, task_kwargs={"random": seed}, visualize_reward=False 194 | ) 195 | pixels_key = "pixels" 196 | else: 197 | name = f"{domain}_{task}_vision" 198 | env = manipulation.load(name, seed=seed) 199 | pixels_key = "front_close" 200 | # add wrappers 201 | env = ActionDTypeWrapper(env, np.float32) 202 | env = ActionRepeatWrapper(env, action_repeat) 203 | env = action_scale.Wrapper(env, minimum=-1.0, maximum=+1.0) 204 | # add renderings for classical tasks 205 | if (domain, task) in suite.ALL_TASKS: 206 | # zoom in camera for quadruped 207 | camera_id = dict(quadruped=2).get(domain, 0) 208 | render_kwargs = dict(height=84, width=84, camera_id=camera_id) 209 | env = pixels.Wrapper(env, pixels_only=True, render_kwargs=render_kwargs) 210 | # stack several frames 211 | env = FrameStackWrapper(env, frame_stack, pixels_key) 212 | env = ExtendedTimeStepWrapper(env) 213 | return env 214 |
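
The `make` helper above composes the DMC wrappers in a fixed order (dtype cast, action repeat, action rescaling to [-1, 1], pixel rendering, frame stacking, extended time steps). Below is a minimal usage sketch, assuming `dm_control` and MuJoCo are installed; the task name, `frame_stack`, `action_repeat`, and `seed` values are illustrative, and the random-action loop merely stands in for an agent's `act()` call.

```python
# Illustrative sketch only: task name and wrapper settings are example values.
import numpy as np

import dmc

env = dmc.make("cartpole_swingup", frame_stack=3, action_repeat=2, seed=1)

# Stacked-pixel observations: (3 * frame_stack, 84, 84), uint8.
obs_spec = env.observation_spec()
# Actions are float32 and rescaled to [-1, 1] by action_scale.Wrapper.
act_spec = env.action_spec()

time_step = env.reset()
episode_reward = 0.0
while not time_step.last():
    # Random action in [-1, 1]; a trained agent would be queried here instead.
    action = np.random.uniform(-1.0, 1.0, size=act_spec.shape).astype(np.float32)
    time_step = env.step(action)
    episode_reward += time_step.reward
print(obs_spec.shape, act_spec.shape, episode_reward)
```
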
-------------------------------------------------------------------------------- /drqv2plus.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | import utils 7 | 8 | 9 | class RandomShiftsAug(nn.Module): 10 | def __init__(self, pad): 11 | super().__init__() 12 | self.pad = pad 13 | 14 | def forward(self, x): 15 | n, c, h, w = x.size() 16 | assert h == w 17 | padding = tuple([self.pad] * 4) 18 | x = F.pad(x, padding, "replicate") 19 | eps = 1.0 / (h + 2 * self.pad) 20 | arange = torch.linspace( 21 | -1.0 + eps, 1.0 - eps, h + 2 * self.pad, device=x.device, dtype=x.dtype 22 | )[:h] 23 | arange = arange.unsqueeze(0).repeat(h, 1).unsqueeze(2) 24 | base_grid = torch.cat([arange, arange.transpose(1, 0)], dim=2) 25 | base_grid = base_grid.unsqueeze(0).repeat(n, 1, 1, 1) 26 | 27 | shift = torch.randint( 28 | 0, 2 * self.pad + 1, size=(n, 1, 1, 2), device=x.device, dtype=x.dtype 29 | ) 30 | shift *= 2.0 / (h + 2 * self.pad) 31 | 32 | grid = base_grid + shift 33 | return F.grid_sample(x, grid, padding_mode="zeros", align_corners=False) 34 | 35 | 36 | class MultiViewCNNEncoder(nn.Module): 37 | def __init__(self, obs_shape): 38 | super().__init__() 39 | 40 | assert len(obs_shape) == 4 41 | self.num_views = obs_shape[0] 42 | self.repr_dim = self.num_views * 256 * 5 * 5 # for 84,84. hard-coded 43 | 44 | self.conv_nets = nn.ModuleList() 45 | for _ in range(self.num_views): 46 | conv_net = nn.Sequential( 47 | nn.Conv2d(obs_shape[1], 32, 4, stride=2, padding=1), 48 | nn.ReLU(inplace=True), 49 | nn.Conv2d(32, 64, 4, stride=2, padding=1), 50 | nn.ReLU(inplace=True), 51 | nn.Conv2d(64, 128, 4, stride=2, padding=1), 52 | nn.ReLU(inplace=True), 53 | nn.Conv2d(128, 256, 4, stride=2, padding=1), 54 | nn.ReLU(inplace=True), 55 | ) 56 | self.conv_nets.append(conv_net) 57 | 58 | self.apply(utils.weight_init) 59 | 60 | def forward(self, obs: torch.Tensor): 61 | # obs: [B, V, C, H, W] 62 | obs = obs / 255.0 - 0.5 63 | hs = [] 64 | for v in range(self.num_views): 65 | h = self.conv_nets[v](obs[:, v]) 66 | h = h.view(h.shape[0], -1) 67 | hs.append(h) 68 | h = torch.cat(hs, -1) 69 | return h 70 | 71 | 72 | class Actor(nn.Module): 73 | def __init__( 74 | self, 75 | repr_dim: int, 76 | low_dim: int, 77 | action_shape: Tuple, 78 | feature_dim: int, 79 | hidden_dim: int, 80 | ): 81 | super().__init__() 82 | self._actor_dim = action_shape[0] 83 | 84 | self.rgb_encoder = nn.Sequential( 85 | nn.Linear(repr_dim, feature_dim), 86 | nn.LayerNorm(feature_dim), 87 | nn.Tanh(), 88 | ) 89 | self.low_dim_encoder = nn.Sequential( 90 | nn.Linear(low_dim, feature_dim), 91 | nn.LayerNorm(feature_dim), 92 | nn.Tanh(), 93 | ) 94 | self.policy = nn.Sequential( 95 | nn.Linear(feature_dim * 2, hidden_dim), 96 | nn.ReLU(inplace=True), 97 | nn.Linear(hidden_dim, hidden_dim), 98 | nn.ReLU(inplace=True), 99 | nn.Linear(hidden_dim, self._actor_dim), 100 | ) 101 | self.apply(utils.weight_init) 102 | 103 | def forward(self, rgb_obs: torch.Tensor, low_dim_obs: torch.Tensor, std: float): 104 | """ 105 | Inputs: 106 | - rgb_obs: features from visual encoder 107 | - low_dim_obs: low-dimensional observations 108 | 109 | Outputs: 110 | - dist: torch distribution for policy 111 | """ 112 | rgb_h = self.rgb_encoder(rgb_obs) 113 | low_dim_h = self.low_dim_encoder(low_dim_obs) 114 | h = torch.cat([rgb_h, low_dim_h], -1) 115 | 116 | mu = self.policy(h) 117 | mu = torch.tanh(mu) 118 | std = torch.ones_like(mu) * std 
119 | 120 | dist = utils.TruncatedNormal(mu, std) 121 | return dist 122 | 123 | 124 | class Critic(nn.Module): 125 | def __init__( 126 | self, 127 | repr_dim: int, 128 | low_dim: int, 129 | action_shape: tuple, 130 | feature_dim: int, 131 | hidden_dim: int, 132 | out_shape: tuple, 133 | ): 134 | super().__init__() 135 | self._actor_dim = action_shape[0] 136 | self._out_shape = out_shape 137 | out_dim = 1 138 | for s in out_shape: 139 | out_dim *= s 140 | 141 | self.rgb_encoder = nn.Sequential( 142 | nn.Linear(repr_dim, feature_dim), 143 | nn.LayerNorm(feature_dim), 144 | nn.Tanh(), 145 | ) 146 | self.low_dim_encoder = nn.Sequential( 147 | nn.Linear(low_dim, feature_dim), 148 | nn.LayerNorm(feature_dim), 149 | nn.Tanh(), 150 | ) 151 | self.Q1 = nn.Sequential( 152 | nn.Linear(feature_dim * 2 + self._actor_dim, hidden_dim), 153 | nn.ReLU(inplace=True), 154 | nn.Linear(hidden_dim, hidden_dim), 155 | nn.ReLU(inplace=True), 156 | nn.Linear(hidden_dim, out_dim), 157 | ) 158 | self.Q2 = nn.Sequential( 159 | nn.Linear(feature_dim * 2 + self._actor_dim, hidden_dim), 160 | nn.ReLU(inplace=True), 161 | nn.Linear(hidden_dim, hidden_dim), 162 | nn.ReLU(inplace=True), 163 | nn.Linear(hidden_dim, out_dim), 164 | ) 165 | self.apply(utils.weight_init) 166 | 167 | def forward( 168 | self, rgb_obs: torch.Tensor, low_dim_obs: torch.Tensor, actions: torch.Tensor 169 | ): 170 | """ 171 | Inputs: 172 | - rgb_obs: features from visual encoder 173 | - low_dim_obs: low-dimensional observations 174 | - actions: actions 175 | 176 | Outputs: 177 | - qs: (batch_size, 2) 178 | """ 179 | rgb_h = self.rgb_encoder(rgb_obs) 180 | low_dim_h = self.low_dim_encoder(low_dim_obs) 181 | h = torch.cat([rgb_h, low_dim_h, actions], -1) 182 | q1 = self.Q1(h).view(h.shape[0], *self._out_shape) 183 | q2 = self.Q2(h).view(h.shape[0], *self._out_shape) 184 | qs = torch.cat([q1, q2], -1) 185 | return qs 186 | 187 | 188 | class DistributionalCritic(Critic): 189 | def __init__( 190 | self, 191 | distributional_critic_limit: float, 192 | distributional_critic_atoms: int, 193 | distributional_critic_transform: bool, 194 | *args, 195 | **kwargs 196 | ): 197 | super().__init__( 198 | *args, 199 | **kwargs, 200 | ) 201 | self.limit = distributional_critic_limit 202 | self.atoms = distributional_critic_atoms 203 | self.transform = distributional_critic_transform 204 | 205 | def to_dist(self, qs): 206 | return torch.cat( 207 | [ 208 | utils.to_categorical( 209 | qs[:, q_idx].unsqueeze(-1), 210 | limit=self.limit, 211 | num_atoms=self.atoms, 212 | transformation=self.transform, 213 | ) 214 | for q_idx in range(qs.size(-1)) 215 | ], 216 | dim=-1, 217 | ) 218 | 219 | def from_dist(self, qs): 220 | return torch.cat( 221 | [ 222 | utils.from_categorical( 223 | qs[..., q_idx], 224 | limit=self.limit, 225 | transformation=self.transform, 226 | ) 227 | for q_idx in range(qs.size(-1)) 228 | ], 229 | dim=-1, 230 | ) 231 | 232 | def compute_distributional_critic_loss(self, qs, target_qs): 233 | loss = 0.0 234 | for q_idx in range(qs.size(-1)): 235 | loss += -torch.sum( 236 | torch.log_softmax(qs[..., q_idx], -1) 237 | * target_qs.squeeze(-1).detach(), 238 | -1, 239 | ) 240 | return loss.unsqueeze(-1) 241 | 242 | 243 | class DrQV2Agent: 244 | def __init__( 245 | self, 246 | rgb_obs_shape, 247 | low_dim_obs_shape, 248 | action_shape, 249 | device, 250 | lr, 251 | weight_decay, 252 | feature_dim, 253 | hidden_dim, 254 | use_distributional_critic, 255 | distributional_critic_limit, 256 | distributional_critic_atoms, 257 | distributional_critic_transform, 258 |
bc_lambda, 259 | critic_target_tau, 260 | num_expl_steps, 261 | update_every_steps, 262 | stddev_schedule, 263 | stddev_clip, 264 | use_logger, 265 | ): 266 | self.device = device 267 | self.critic_target_tau = critic_target_tau 268 | self.update_every_steps = update_every_steps 269 | self.use_logger = use_logger 270 | self.num_expl_steps = num_expl_steps 271 | self.stddev_schedule = stddev_schedule 272 | self.stddev_clip = stddev_clip 273 | self.bc_lambda = bc_lambda 274 | self.use_distributional_critic = use_distributional_critic 275 | self.distributional_critic_limit = distributional_critic_limit 276 | self.distributional_critic_atoms = distributional_critic_atoms 277 | self.distributional_critic_transform = distributional_critic_transform 278 | 279 | # models 280 | low_dim = low_dim_obs_shape[-1] 281 | self.encoder = MultiViewCNNEncoder(rgb_obs_shape).to(device) 282 | self.actor = Actor( 283 | self.encoder.repr_dim, low_dim, action_shape, feature_dim, hidden_dim 284 | ).to(device) 285 | 286 | if use_distributional_critic: 287 | self.critic = DistributionalCritic( 288 | self.distributional_critic_limit, 289 | self.distributional_critic_atoms, 290 | self.distributional_critic_transform, 291 | self.encoder.repr_dim, 292 | low_dim, 293 | action_shape, 294 | feature_dim, 295 | hidden_dim, 296 | out_shape=(self.distributional_critic_atoms, 1), 297 | ).to(device) 298 | self.critic_target = DistributionalCritic( 299 | self.distributional_critic_limit, 300 | self.distributional_critic_atoms, 301 | self.distributional_critic_transform, 302 | self.encoder.repr_dim, 303 | low_dim, 304 | action_shape, 305 | feature_dim, 306 | hidden_dim, 307 | out_shape=(self.distributional_critic_atoms, 1), 308 | ).to(device) 309 | else: 310 | self.critic = Critic( 311 | self.encoder.repr_dim, 312 | low_dim, 313 | action_shape, 314 | feature_dim, 315 | hidden_dim, 316 | out_shape=(1,), 317 | ).to(device) 318 | self.critic_target = Critic( 319 | self.encoder.repr_dim, 320 | low_dim, 321 | action_shape, 322 | feature_dim, 323 | hidden_dim, 324 | out_shape=(1,), 325 | ).to(device) 326 | self.critic_target.load_state_dict(self.critic.state_dict()) 327 | 328 | # optimizers 329 | self.encoder_opt = torch.optim.AdamW( 330 | self.encoder.parameters(), lr=lr, weight_decay=weight_decay 331 | ) 332 | self.actor_opt = torch.optim.AdamW( 333 | self.actor.parameters(), lr=lr, weight_decay=weight_decay 334 | ) 335 | self.critic_opt = torch.optim.AdamW( 336 | self.critic.parameters(), lr=lr, weight_decay=weight_decay 337 | ) 338 | 339 | # data augmentation 340 | self.aug = RandomShiftsAug(pad=4) 341 | 342 | self.train() 343 | self.critic_target.eval() 344 | 345 | print(self.encoder) 346 | print(self.critic) 347 | 348 | def train(self, training=True): 349 | self.training = training 350 | self.encoder.train(training) 351 | self.actor.train(training) 352 | self.critic.train(training) 353 | 354 | def act(self, rgb_obs, low_dim_obs, step, eval_mode): 355 | rgb_obs = torch.as_tensor(rgb_obs, device=self.device).unsqueeze(0) 356 | low_dim_obs = torch.as_tensor(low_dim_obs, device=self.device).unsqueeze(0) 357 | rgb_obs = self.encoder(rgb_obs) 358 | stddev = utils.schedule(self.stddev_schedule, step) 359 | dist = self.actor(rgb_obs, low_dim_obs, stddev) 360 | if eval_mode: 361 | action = dist.mean 362 | else: 363 | action = dist.sample(clip=None) 364 | if step < self.num_expl_steps: 365 | action.uniform_(-1.0, 1.0) 366 | return action.cpu().numpy()[0] 367 | 368 | def update_critic( 369 | self, 370 | rgb_obs, 371 | low_dim_obs, 372 | 
action, 373 | reward, 374 | discount, 375 | next_rgb_obs, 376 | next_low_dim_obs, 377 | step, 378 | ): 379 | metrics = dict() 380 | 381 | with torch.no_grad(): 382 | stddev = utils.schedule(self.stddev_schedule, step) 383 | dist = self.actor(next_rgb_obs, next_low_dim_obs, stddev) 384 | next_action = dist.sample(clip=self.stddev_clip) 385 | target_qs = self.critic_target(next_rgb_obs, next_low_dim_obs, next_action) 386 | if self.use_distributional_critic: 387 | target_qs = self.critic_target.from_dist(target_qs) 388 | target_Q1, target_Q2 = target_qs[..., 0], target_qs[..., 1] 389 | target_V = torch.min(target_Q1, target_Q2).unsqueeze(1) 390 | target_Q = reward + (discount * target_V) 391 | if self.use_logger: 392 | metrics["critic_target_q"] = target_Q.mean().item() 393 | if self.use_distributional_critic: 394 | target_Q = self.critic_target.to_dist(target_Q) 395 | 396 | qs = self.critic(rgb_obs, low_dim_obs, action) 397 | 398 | if self.use_distributional_critic: 399 | critic_loss = self.critic.compute_distributional_critic_loss( 400 | qs, target_Q 401 | ).mean() 402 | else: 403 | Q1, Q2 = qs[..., 0], qs[..., 1] 404 | target_Q = target_Q.squeeze(1) 405 | critic_loss = F.mse_loss(Q1, target_Q) + F.mse_loss(Q2, target_Q) 406 | if self.use_logger: 407 | metrics["critic_q1"] = Q1.mean().item() 408 | metrics["critic_q2"] = Q2.mean().item() 409 | 410 | if self.use_logger: 411 | metrics["critic_loss"] = critic_loss.item() 412 | 413 | # optimize encoder and critic 414 | self.encoder_opt.zero_grad(set_to_none=True) 415 | self.critic_opt.zero_grad(set_to_none=True) 416 | critic_loss.backward() 417 | self.critic_opt.step() 418 | self.encoder_opt.step() 419 | 420 | return metrics 421 | 422 | def update_actor(self, rgb_obs, low_dim_obs, demo_action, demos, step): 423 | metrics = dict() 424 | 425 | stddev = utils.schedule(self.stddev_schedule, step) 426 | dist = self.actor(rgb_obs, low_dim_obs, stddev) 427 | action = dist.sample(clip=self.stddev_clip) 428 | log_prob = dist.log_prob(action).sum(-1, keepdim=True) 429 | qs = self.critic(rgb_obs, low_dim_obs, action) 430 | if self.use_distributional_critic: 431 | qs = self.critic.from_dist(qs) 432 | Q1, Q2 = qs[..., 0], qs[..., 1] 433 | Q = torch.min(Q1, Q2) 434 | 435 | base_actor_loss = -Q.mean() 436 | 437 | bc_metrics, bc_loss = self.get_bc_loss(dist.mean, demo_action, demos) 438 | metrics.update(bc_metrics) 439 | actor_loss = base_actor_loss + self.bc_lambda * bc_loss 440 | 441 | # optimize actor 442 | self.actor_opt.zero_grad(set_to_none=True) 443 | actor_loss.backward() 444 | self.actor_opt.step() 445 | 446 | if self.use_logger: 447 | metrics["actor_loss"] = base_actor_loss.mean().item() 448 | metrics["actor_logprob"] = log_prob.mean().item() 449 | metrics["actor_ent"] = dist.entropy().sum(dim=-1).mean().item() 450 | 451 | return metrics 452 | 453 | def get_bc_loss(self, predicted_action, buffer_action, demos): 454 | metrics = dict() 455 | bc_loss = 0 456 | if demos is not None: 457 | # Only apply loss to demo items 458 | demos = demos.float() 459 | bs = demos.shape[0] 460 | 461 | if demos.sum() > 0: 462 | bc_loss = ( 463 | F.mse_loss( 464 | predicted_action.view(bs, -1), 465 | buffer_action.view(bs, -1), 466 | reduction="none", 467 | ) 468 | * demos 469 | ) 470 | bc_loss = bc_loss.sum() / demos.sum() 471 | if self.use_logger: 472 | metrics["actor_bc_loss"] = bc_loss.item() 473 | if self.use_logger: 474 | metrics["ratio_of_demos"] = demos.mean().item() 475 | return metrics, bc_loss 476 | 477 | def update(self, replay_iter, step): 478 | metrics = 
dict() 479 | 480 | if step % self.update_every_steps != 0: 481 | return metrics 482 | 483 | batch = next(replay_iter) 484 | ( 485 | rgb_obs, 486 | low_dim_obs, 487 | action, 488 | reward, 489 | discount, 490 | next_rgb_obs, 491 | next_low_dim_obs, 492 | demos, 493 | ) = utils.to_torch(batch, self.device) 494 | 495 | # augment 496 | rgb_obs = rgb_obs.float() 497 | next_rgb_obs = next_rgb_obs.float() 498 | rgb_obs = torch.stack( 499 | [self.aug(rgb_obs[:, v]) for v in range(rgb_obs.shape[1])], 1 500 | ) 501 | next_rgb_obs = torch.stack( 502 | [self.aug(next_rgb_obs[:, v]) for v in range(next_rgb_obs.shape[1])], 1 503 | ) 504 | # encode 505 | rgb_obs = self.encoder(rgb_obs) 506 | with torch.no_grad(): 507 | next_rgb_obs = self.encoder(next_rgb_obs) 508 | 509 | if self.use_logger: 510 | metrics["batch_reward"] = reward.mean().item() 511 | 512 | # update critic 513 | metrics.update( 514 | self.update_critic( 515 | rgb_obs, 516 | low_dim_obs, 517 | action, 518 | reward, 519 | discount, 520 | next_rgb_obs, 521 | next_low_dim_obs, 522 | step, 523 | ) 524 | ) 525 | 526 | # update actor 527 | metrics.update( 528 | self.update_actor( 529 | rgb_obs.detach(), 530 | low_dim_obs.detach(), 531 | action, 532 | demos, 533 | step, 534 | ) 535 | ) 536 | 537 | # update critic target 538 | utils.soft_update_params( 539 | self.critic, self.critic_target, self.critic_target_tau 540 | ) 541 | 542 | return metrics 543 | -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import datetime 3 | from collections import defaultdict 4 | 5 | import wandb 6 | from torch.utils.tensorboard import SummaryWriter 7 | from omegaconf import OmegaConf 8 | 9 | import torch 10 | from termcolor import colored 11 | 12 | COMMON_TRAIN_FORMAT = [ 13 | ("frame", "F", "int"), 14 | ("step", "S", "int"), 15 | ("episode", "E", "int"), 16 | ("episode_length", "L", "int"), 17 | ("episode_reward", "R", "float"), 18 | ("buffer_size", "BS", "int"), 19 | ("fps", "FPS", "float"), 20 | ("total_time", "T", "time"), 21 | ] 22 | 23 | COMMON_EVAL_FORMAT = [ 24 | ("frame", "F", "int"), 25 | ("step", "S", "int"), 26 | ("episode", "E", "int"), 27 | ("episode_length", "L", "int"), 28 | ("episode_reward", "R", "float"), 29 | ("total_time", "T", "time"), 30 | ] 31 | 32 | 33 | class AverageMeter(object): 34 | def __init__(self): 35 | self._sum = 0 36 | self._count = 0 37 | 38 | def update(self, value, n=1): 39 | self._sum += value 40 | self._count += n 41 | 42 | def value(self): 43 | return self._sum / max(1, self._count) 44 | 45 | 46 | class MetersGroup(object): 47 | def __init__(self, csv_file_name, formating): 48 | self._csv_file_name = csv_file_name 49 | self._formating = formating 50 | self._meters = defaultdict(AverageMeter) 51 | self._csv_file = None 52 | self._csv_writer = None 53 | 54 | def log(self, key, value, n=1): 55 | self._meters[key].update(value, n) 56 | 57 | def _prime_meters(self): 58 | data = dict() 59 | for key, meter in self._meters.items(): 60 | if key.startswith("train"): 61 | key = key[len("train") + 1 :] 62 | else: 63 | key = key[len("eval") + 1 :] 64 | key = key.replace("/", "_") 65 | data[key] = meter.value() 66 | return data 67 | 68 | def _remove_old_entries(self, data): 69 | rows = [] 70 | with self._csv_file_name.open("r") as f: 71 | reader = csv.DictReader(f) 72 | for row in reader: 73 | if float(row["episode"]) >= data["episode"]: 74 | break 75 | rows.append(row) 76 | with 
self._csv_file_name.open("w") as f: 77 | writer = csv.DictWriter(f, fieldnames=sorted(data.keys()), restval=0.0) 78 | writer.writeheader() 79 | for row in rows: 80 | writer.writerow(row) 81 | 82 | def _dump_to_csv(self, data): 83 | if self._csv_writer is None: 84 | should_write_header = True 85 | if self._csv_file_name.exists(): 86 | self._remove_old_entries(data) 87 | should_write_header = False 88 | 89 | self._csv_file = self._csv_file_name.open("a") 90 | self._csv_writer = csv.DictWriter( 91 | self._csv_file, fieldnames=sorted(data.keys()), restval=0.0 92 | ) 93 | if should_write_header: 94 | self._csv_writer.writeheader() 95 | 96 | self._csv_writer.writerow(data) 97 | self._csv_file.flush() 98 | 99 | def _format(self, key, value, ty): 100 | if ty == "int": 101 | value = int(value) 102 | return f"{key}: {value}" 103 | elif ty == "float": 104 | return f"{key}: {value:.04f}" 105 | elif ty == "time": 106 | value = str(datetime.timedelta(seconds=int(value))) 107 | return f"{key}: {value}" 108 | else: 109 | raise f"invalid format type: {ty}" 110 | 111 | def _dump_to_console(self, data, prefix): 112 | prefix = colored(prefix, "yellow" if prefix == "train" else "green") 113 | pieces = [f"| {prefix: <14}"] 114 | for key, disp_key, ty in self._formating: 115 | value = data.get(key, 0) 116 | pieces.append(self._format(disp_key, value, ty)) 117 | print(" | ".join(pieces)) 118 | 119 | def dump(self, step, prefix): 120 | if len(self._meters) == 0: 121 | return 122 | data = self._prime_meters() 123 | data["frame"] = step 124 | self._dump_to_csv(data) 125 | self._dump_to_console(data, prefix) 126 | self._meters.clear() 127 | 128 | 129 | class Logger(object): 130 | def __init__(self, log_dir, use_tb, use_wandb, config): 131 | self._log_dir = log_dir 132 | self._use_tb = use_tb 133 | self._use_wandb = use_wandb 134 | self._train_mg = MetersGroup( 135 | log_dir / "train.csv", formating=COMMON_TRAIN_FORMAT 136 | ) 137 | self._eval_mg = MetersGroup(log_dir / "eval.csv", formating=COMMON_EVAL_FORMAT) 138 | if use_tb: 139 | self._sw = SummaryWriter(str(log_dir / "tb")) 140 | 141 | config_dict = OmegaConf.to_container(config, resolve=True) 142 | 143 | if use_wandb: 144 | wandb.init( 145 | project=config.wandb.project, 146 | entity=config.wandb.entity, 147 | name=config.wandb.name, 148 | config=config_dict, 149 | ) 150 | self._wandb_logs = dict() 151 | 152 | def _try_sw_log(self, key, value, step): 153 | if self._use_tb: 154 | self._sw.add_scalar(key, value, step) 155 | 156 | def _try_wandb_log(self, key, value, step): 157 | if self._use_wandb: 158 | self._wandb_logs[key] = value 159 | 160 | def log(self, key, value, step): 161 | assert key.startswith("train") or key.startswith("eval") 162 | if type(value) == torch.Tensor: 163 | value = value.item() 164 | self._try_sw_log(key, value, step) 165 | self._try_wandb_log(key, value, step) 166 | mg = self._train_mg if key.startswith("train") else self._eval_mg 167 | mg.log(key, value) 168 | 169 | def log_metrics(self, metrics, step, ty): 170 | for key, value in metrics.items(): 171 | self.log(f"{ty}/{key}", value, step) 172 | 173 | def dump(self, step, ty=None): 174 | if ty is None or ty == "eval": 175 | self._eval_mg.dump(step, "eval") 176 | if ty is None or ty == "train": 177 | self._train_mg.dump(step, "train") 178 | if self._use_wandb and len(self._wandb_logs): 179 | wandb.log(self._wandb_logs, step=step) 180 | self._wandb_logs = dict() 181 | 182 | def log_and_dump_ctx(self, step, ty): 183 | return LogAndDumpCtx(self, step, ty) 184 | 185 | 186 | class 
LogAndDumpCtx: 187 | def __init__(self, logger, step, ty): 188 | self._logger = logger 189 | self._step = step 190 | self._ty = ty 191 | 192 | def __enter__(self): 193 | return self 194 | 195 | def __call__(self, key, value): 196 | self._logger.log(f"{self._ty}/{key}", value, self._step) 197 | 198 | def __exit__(self, *args): 199 | self._logger.dump(self._step, self._ty) 200 | -------------------------------------------------------------------------------- /media/cqn_gif1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/younggyoseo/CQN/1932a05d87fa368639c4a4158bd3d5182e46c4c9/media/cqn_gif1.gif -------------------------------------------------------------------------------- /media/cqn_gif2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/younggyoseo/CQN/1932a05d87fa368639c4a4158bd3d5182e46c4c9/media/cqn_gif2.gif -------------------------------------------------------------------------------- /replay_buffer.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import io 3 | import random 4 | import traceback 5 | from collections import defaultdict 6 | 7 | import numpy as np 8 | import torch 9 | from torch.utils.data import IterableDataset 10 | 11 | 12 | def episode_len(episode): 13 | # subtract -1 because the dummy first transition 14 | return next(iter(episode.values())).shape[0] - 1 15 | 16 | 17 | def save_episode(episode, fn): 18 | with io.BytesIO() as bs: 19 | np.savez_compressed(bs, **episode) 20 | bs.seek(0) 21 | with fn.open("wb") as f: 22 | f.write(bs.read()) 23 | 24 | 25 | def load_episode(fn): 26 | with fn.open("rb") as f: 27 | episode = np.load(f) 28 | episode = {k: episode[k] for k in episode.keys()} 29 | return episode 30 | 31 | 32 | class ReplayBufferStorage: 33 | def __init__(self, data_specs, replay_dir, use_relabeling, is_demo_buffer=False): 34 | self._data_specs = data_specs 35 | self._replay_dir = replay_dir 36 | self._use_relabeling = use_relabeling 37 | self._is_demo_buffer = is_demo_buffer 38 | replay_dir.mkdir(exist_ok=True) 39 | self._current_episode = defaultdict(list) 40 | self._preload() 41 | 42 | def __len__(self): 43 | return self._num_transitions 44 | 45 | def add(self, time_step): 46 | for spec in self._data_specs: 47 | value = time_step[spec.name] 48 | # Remove frame stacking 49 | if spec.name == "low_dim_obs": 50 | low_dim = 8 # hard-coded 51 | value = value[..., -low_dim:] 52 | elif spec.name == "rgb_obs": 53 | rgb_dim = 3 # hard-coded 54 | value = value[:, -rgb_dim:] 55 | if np.isscalar(value): 56 | value = np.full(spec.shape, value, spec.dtype) 57 | assert spec.shape == value.shape and spec.dtype == value.dtype, ( 58 | spec.name, 59 | spec.shape, 60 | value.shape, 61 | spec.dtype, 62 | value.dtype, 63 | ) 64 | self._current_episode[spec.name].append(value) 65 | if time_step.last(): 66 | episode = dict() 67 | for spec in self._data_specs: 68 | value = self._current_episode[spec.name] 69 | episode[spec.name] = np.array(value, spec.dtype) 70 | self._current_episode = defaultdict(list) 71 | if self._use_relabeling: 72 | episode = self._relabel_episode(episode) 73 | if self._is_demo_buffer: 74 | # If this is demo replay buffer, save only when it's successful 75 | if self._check_if_successful(episode): 76 | self._store_episode(episode) 77 | else: 78 | self._store_episode(episode) 79 | 80 | def _relabel_episode(self, episode): 81 | if 
self._check_if_successful(episode): 82 | episode["demo"] = np.ones_like(episode["demo"]) 83 | return episode 84 | 85 | def _check_if_successful(self, episode): 86 | reward = episode["reward"] 87 | return np.isclose(reward[-1], 1.0) 88 | 89 | def _preload(self): 90 | self._num_episodes = 0 91 | self._num_transitions = 0 92 | for fn in self._replay_dir.glob("*.npz"): 93 | _, _, eps_len = fn.stem.split("_") 94 | self._num_episodes += 1 95 | self._num_transitions += int(eps_len) 96 | 97 | def _store_episode(self, episode): 98 | eps_idx = self._num_episodes 99 | eps_len = episode_len(episode) 100 | self._num_episodes += 1 101 | self._num_transitions += eps_len 102 | ts = datetime.datetime.now().strftime("%Y%m%dT%H%M%S") 103 | eps_fn = f"{ts}_{eps_idx}_{eps_len}.npz" 104 | save_episode(episode, self._replay_dir / eps_fn) 105 | 106 | 107 | class ReplayBuffer(IterableDataset): 108 | def __init__( 109 | self, 110 | replay_dir, 111 | max_size, 112 | num_workers, 113 | nstep, 114 | discount, 115 | do_always_bootstrap, 116 | frame_stack, 117 | fetch_every, 118 | save_snapshot, 119 | ): 120 | self._replay_dir = replay_dir 121 | self._size = 0 122 | self._max_size = max_size 123 | self._num_workers = max(1, num_workers) 124 | self._episode_fns = [] 125 | self._episodes = dict() 126 | self._nstep = nstep 127 | self._discount = discount 128 | self._do_always_bootstrap = do_always_bootstrap 129 | self._frame_stack = frame_stack 130 | self._fetch_every = fetch_every 131 | self._samples_since_last_fetch = fetch_every 132 | self._save_snapshot = save_snapshot 133 | 134 | def _sample_episode(self): 135 | eps_fn = random.choice(self._episode_fns) 136 | return self._episodes[eps_fn] 137 | 138 | def _store_episode(self, eps_fn): 139 | try: 140 | episode = load_episode(eps_fn) 141 | except: 142 | return False 143 | eps_len = episode_len(episode) 144 | while eps_len + self._size > self._max_size: 145 | early_eps_fn = self._episode_fns.pop(0) 146 | early_eps = self._episodes.pop(early_eps_fn) 147 | self._size -= episode_len(early_eps) 148 | early_eps_fn.unlink(missing_ok=True) 149 | self._episode_fns.append(eps_fn) 150 | self._episode_fns.sort() 151 | self._episodes[eps_fn] = episode 152 | self._size += eps_len 153 | 154 | if not self._save_snapshot: 155 | eps_fn.unlink(missing_ok=True) 156 | return True 157 | 158 | def _try_fetch(self): 159 | if self._samples_since_last_fetch < self._fetch_every: 160 | return 161 | self._samples_since_last_fetch = 0 162 | try: 163 | worker_id = torch.utils.data.get_worker_info().id 164 | except: 165 | worker_id = 0 166 | eps_fns = sorted(self._replay_dir.glob("*.npz"), reverse=True) 167 | fetched_size = 0 168 | for eps_fn in eps_fns: 169 | eps_idx, eps_len = [int(x) for x in eps_fn.stem.split("_")[1:]] 170 | if eps_idx % self._num_workers != worker_id: 171 | continue 172 | if eps_fn in self._episodes.keys(): 173 | break 174 | if fetched_size + eps_len > self._max_size: 175 | break 176 | fetched_size += eps_len 177 | if not self._store_episode(eps_fn): 178 | break 179 | 180 | def _sample(self): 181 | try: 182 | self._try_fetch() 183 | except: 184 | traceback.print_exc() 185 | self._samples_since_last_fetch += 1 186 | episode = self._sample_episode() 187 | # add +1 for the first dummy transition 188 | idx = np.random.randint(0, episode_len(episode) - self._nstep + 1) + 1 189 | next_idx = idx + self._nstep - 1 190 | 191 | obs_idxs = list( 192 | map( 193 | lambda x: np.clip(x, 0, None), 194 | range((idx - 1) - self._frame_stack + 1, (idx - 1) + 1), 195 | ) 196 | ) 197 | obs_next_idxs = 
list( 198 | map( 199 | lambda x: np.clip(x, 0, None), 200 | range(next_idx - self._frame_stack + 1, next_idx + 1), 201 | ) 202 | ) 203 | 204 | # rgb_obs stacking -- channel-wise concat 205 | rgb_obs = np.concatenate(episode["rgb_obs"][obs_idxs], 1) 206 | next_rgb_obs = np.concatenate(episode["rgb_obs"][obs_next_idxs], 1) 207 | # low_dim_obs stacking -- last-dim-wise concat 208 | low_dim_obs = np.concatenate(episode["low_dim_obs"][obs_idxs], -1) 209 | next_low_dim_obs = np.concatenate(episode["low_dim_obs"][obs_next_idxs], -1) 210 | 211 | action = episode["action"][idx] 212 | reward = np.zeros_like(episode["reward"][idx]) 213 | discount = np.ones_like(episode["discount"][idx]) 214 | for i in range(self._nstep): 215 | step_reward = episode["reward"][idx + i] 216 | reward += discount * step_reward 217 | if self._do_always_bootstrap: 218 | _discount = 1.0 219 | else: 220 | _discount = episode["discount"][idx + i] 221 | discount *= _discount * self._discount 222 | demo = episode["demo"][idx] 223 | return ( 224 | rgb_obs, 225 | low_dim_obs, 226 | action, 227 | reward, 228 | discount, 229 | next_rgb_obs, 230 | next_low_dim_obs, 231 | demo, 232 | ) 233 | 234 | def __iter__(self): 235 | while True: 236 | yield self._sample() 237 | 238 | 239 | def _worker_init_fn(worker_id): 240 | seed = np.random.get_state()[1][0] + worker_id 241 | np.random.seed(seed) 242 | random.seed(seed) 243 | 244 | 245 | def make_replay_loader( 246 | replay_dir, 247 | max_size, 248 | batch_size, 249 | num_workers, 250 | save_snapshot, 251 | nstep, 252 | discount, 253 | do_always_bootstrap, 254 | frame_stack, 255 | ): 256 | max_size_per_worker = max_size // max(1, num_workers) 257 | 258 | iterable = ReplayBuffer( 259 | replay_dir, 260 | max_size_per_worker, 261 | num_workers, 262 | nstep, 263 | discount, 264 | do_always_bootstrap, 265 | frame_stack, 266 | fetch_every=100, 267 | save_snapshot=save_snapshot, 268 | ) 269 | 270 | loader = torch.utils.data.DataLoader( 271 | iterable, 272 | batch_size=batch_size, 273 | num_workers=num_workers, 274 | pin_memory=True, 275 | worker_init_fn=_worker_init_fn, 276 | ) 277 | return loader 278 | -------------------------------------------------------------------------------- /replay_buffer_dmc.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import io 3 | import random 4 | import traceback 5 | from collections import defaultdict 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | from torch.utils.data import IterableDataset 11 | 12 | 13 | def episode_len(episode): 14 | # subtract -1 because the dummy first transition 15 | return next(iter(episode.values())).shape[0] - 1 16 | 17 | 18 | def save_episode(episode, fn): 19 | with io.BytesIO() as bs: 20 | np.savez_compressed(bs, **episode) 21 | bs.seek(0) 22 | with fn.open("wb") as f: 23 | f.write(bs.read()) 24 | 25 | 26 | def load_episode(fn): 27 | with fn.open("rb") as f: 28 | episode = np.load(f) 29 | episode = {k: episode[k] for k in episode.keys()} 30 | return episode 31 | 32 | 33 | class ReplayBufferStorage: 34 | def __init__(self, data_specs, replay_dir): 35 | self._data_specs = data_specs 36 | self._replay_dir = replay_dir 37 | replay_dir.mkdir(exist_ok=True) 38 | self._current_episode = defaultdict(list) 39 | self._preload() 40 | 41 | def __len__(self): 42 | return self._num_transitions 43 | 44 | def add(self, time_step): 45 | for spec in self._data_specs: 46 | value = time_step[spec.name] 47 | if np.isscalar(value): 48 | value = np.full(spec.shape, value, 
spec.dtype) 49 | assert spec.shape == value.shape and spec.dtype == value.dtype 50 | self._current_episode[spec.name].append(value) 51 | if time_step.last(): 52 | episode = dict() 53 | for spec in self._data_specs: 54 | value = self._current_episode[spec.name] 55 | episode[spec.name] = np.array(value, spec.dtype) 56 | self._current_episode = defaultdict(list) 57 | self._store_episode(episode) 58 | 59 | def _preload(self): 60 | self._num_episodes = 0 61 | self._num_transitions = 0 62 | for fn in self._replay_dir.glob("*.npz"): 63 | _, _, eps_len = fn.stem.split("_") 64 | self._num_episodes += 1 65 | self._num_transitions += int(eps_len) 66 | 67 | def _store_episode(self, episode): 68 | eps_idx = self._num_episodes 69 | eps_len = episode_len(episode) 70 | self._num_episodes += 1 71 | self._num_transitions += eps_len 72 | ts = datetime.datetime.now().strftime("%Y%m%dT%H%M%S") 73 | eps_fn = f"{ts}_{eps_idx}_{eps_len}.npz" 74 | save_episode(episode, self._replay_dir / eps_fn) 75 | 76 | 77 | class ReplayBuffer(IterableDataset): 78 | def __init__( 79 | self, 80 | replay_dir, 81 | max_size, 82 | num_workers, 83 | nstep, 84 | discount, 85 | fetch_every, 86 | save_snapshot, 87 | ): 88 | self._replay_dir = replay_dir 89 | self._size = 0 90 | self._max_size = max_size 91 | self._num_workers = max(1, num_workers) 92 | self._episode_fns = [] 93 | self._episodes = dict() 94 | self._nstep = nstep 95 | self._discount = discount 96 | self._fetch_every = fetch_every 97 | self._samples_since_last_fetch = fetch_every 98 | self._save_snapshot = save_snapshot 99 | 100 | def _sample_episode(self): 101 | eps_fn = random.choice(self._episode_fns) 102 | return self._episodes[eps_fn] 103 | 104 | def _store_episode(self, eps_fn): 105 | try: 106 | episode = load_episode(eps_fn) 107 | except: 108 | return False 109 | eps_len = episode_len(episode) 110 | while eps_len + self._size > self._max_size: 111 | early_eps_fn = self._episode_fns.pop(0) 112 | early_eps = self._episodes.pop(early_eps_fn) 113 | self._size -= episode_len(early_eps) 114 | early_eps_fn.unlink(missing_ok=True) 115 | self._episode_fns.append(eps_fn) 116 | self._episode_fns.sort() 117 | self._episodes[eps_fn] = episode 118 | self._size += eps_len 119 | 120 | if not self._save_snapshot: 121 | eps_fn.unlink(missing_ok=True) 122 | return True 123 | 124 | def _try_fetch(self): 125 | if self._samples_since_last_fetch < self._fetch_every: 126 | return 127 | self._samples_since_last_fetch = 0 128 | try: 129 | worker_id = torch.utils.data.get_worker_info().id 130 | except: 131 | worker_id = 0 132 | eps_fns = sorted(self._replay_dir.glob("*.npz"), reverse=True) 133 | fetched_size = 0 134 | for eps_fn in eps_fns: 135 | eps_idx, eps_len = [int(x) for x in eps_fn.stem.split("_")[1:]] 136 | if eps_idx % self._num_workers != worker_id: 137 | continue 138 | if eps_fn in self._episodes.keys(): 139 | break 140 | if fetched_size + eps_len > self._max_size: 141 | break 142 | fetched_size += eps_len 143 | if not self._store_episode(eps_fn): 144 | break 145 | 146 | def _sample(self): 147 | try: 148 | self._try_fetch() 149 | except: 150 | traceback.print_exc() 151 | self._samples_since_last_fetch += 1 152 | episode = self._sample_episode() 153 | # add +1 for the first dummy transition 154 | idx = np.random.randint(0, episode_len(episode) - self._nstep + 1) + 1 155 | obs = episode["observation"][idx - 1] 156 | action = episode["action"][idx] 157 | next_obs = episode["observation"][idx + self._nstep - 1] 158 | reward = np.zeros_like(episode["reward"][idx]) 159 | discount = 
np.ones_like(episode["discount"][idx]) 160 | for i in range(self._nstep): 161 | step_reward = episode["reward"][idx + i] 162 | reward += discount * step_reward 163 | discount *= episode["discount"][idx + i] * self._discount 164 | return (obs, action, reward, discount, next_obs) 165 | 166 | def __iter__(self): 167 | while True: 168 | yield self._sample() 169 | 170 | 171 | def _worker_init_fn(worker_id): 172 | seed = np.random.get_state()[1][0] + worker_id 173 | np.random.seed(seed) 174 | random.seed(seed) 175 | 176 | 177 | def make_replay_loader( 178 | replay_dir, max_size, batch_size, num_workers, save_snapshot, nstep, discount 179 | ): 180 | max_size_per_worker = max_size // max(1, num_workers) 181 | 182 | iterable = ReplayBuffer( 183 | replay_dir, 184 | max_size_per_worker, 185 | num_workers, 186 | nstep, 187 | discount, 188 | fetch_every=1000, 189 | save_snapshot=save_snapshot, 190 | ) 191 | 192 | loader = torch.utils.data.DataLoader( 193 | iterable, 194 | batch_size=batch_size, 195 | num_workers=num_workers, 196 | pin_memory=True, 197 | worker_init_fn=_worker_init_fn, 198 | ) 199 | return loader 200 | -------------------------------------------------------------------------------- /rlbench_env.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from gymnasium import spaces 4 | import numpy as np 5 | from typing import Union, Dict, Any, NamedTuple 6 | from collections import deque 7 | 8 | from pyrep.const import RenderMode 9 | from pyrep.objects.vision_sensor import VisionSensor 10 | from pyrep.objects.dummy import Dummy 11 | 12 | from rlbench.action_modes.action_mode import MoveArmThenGripper 13 | from rlbench.action_modes.gripper_action_modes import Discrete 14 | from rlbench.environment import Environment 15 | from rlbench.observation_config import ObservationConfig 16 | from rlbench.utils import name_to_task_class 17 | from rlbench.action_modes.arm_action_modes import ( 18 | JointPosition, 19 | ) 20 | 21 | from dm_env import StepType, specs 22 | 23 | 24 | class TimeStep(NamedTuple): 25 | step_type: Any 26 | reward: Any 27 | discount: Any 28 | rgb_obs: Any 29 | low_dim_obs: Any 30 | demo: Any 31 | 32 | def first(self): 33 | return self.step_type == StepType.FIRST 34 | 35 | def mid(self): 36 | return self.step_type == StepType.MID 37 | 38 | def last(self): 39 | return self.step_type == StepType.LAST 40 | 41 | def __getitem__(self, attr): 42 | if isinstance(attr, str): 43 | return getattr(self, attr) 44 | else: 45 | return tuple.__getitem__(self, attr) 46 | 47 | 48 | class ExtendedTimeStep(NamedTuple): 49 | step_type: Any 50 | reward: Any 51 | discount: Any 52 | rgb_obs: Any 53 | low_dim_obs: Any 54 | action: Any 55 | demo: Any 56 | 57 | def first(self): 58 | return self.step_type == StepType.FIRST 59 | 60 | def mid(self): 61 | return self.step_type == StepType.MID 62 | 63 | def last(self): 64 | return self.step_type == StepType.LAST 65 | 66 | def __getitem__(self, attr): 67 | if isinstance(attr, str): 68 | return getattr(self, attr) 69 | else: 70 | return tuple.__getitem__(self, attr) 71 | 72 | 73 | class ExtendedTimeStepWrapper: 74 | def __init__(self, env): 75 | self._env = env 76 | 77 | def reset(self): 78 | time_step = self._env.reset() 79 | return self._augment_time_step(time_step) 80 | 81 | def step(self, action): 82 | time_step = self._env.step(action) 83 | return self._augment_time_step(time_step, action) 84 | 85 | def _augment_time_step(self, time_step, action=None): 86 | if action is None: 87 | action_spec = 
self.action_spec() 88 | action = np.zeros(action_spec.shape, dtype=action_spec.dtype) 89 | return ExtendedTimeStep( 90 | rgb_obs=time_step.rgb_obs, 91 | low_dim_obs=time_step.low_dim_obs, 92 | step_type=time_step.step_type, 93 | action=action, 94 | reward=time_step.reward or 0.0, 95 | discount=time_step.discount or 1.0, 96 | demo=time_step.demo or 0.0, 97 | ) 98 | 99 | def low_dim_observation_spec(self): 100 | return self._env.low_dim_observation_spec() 101 | 102 | def rgb_observation_spec(self): 103 | return self._env.rgb_observation_spec() 104 | 105 | def low_dim_raw_observation_spec(self): 106 | return self._env.low_dim_raw_observation_spec() 107 | 108 | def rgb_raw_observation_spec(self): 109 | return self._env.rgb_raw_observation_spec() 110 | 111 | def action_spec(self): 112 | return self._env.action_spec() 113 | 114 | def __getattr__(self, name): 115 | return getattr(self._env, name) 116 | 117 | 118 | class RLBench: 119 | def __init__( 120 | self, 121 | task_name: str, 122 | episode_length: int = 200, 123 | frame_stack: int = 1, 124 | dataset_root: str = "", 125 | arm_max_velocity: float = 1.0, 126 | arm_max_acceleration: float = 4.0, 127 | camera_shape: tuple[int] = (84, 84), 128 | camera_keys: tuple[str] = ("front", "wrist", "left_shoulder", "right_shoulder"), 129 | state_keys: tuple[str] = ("joint_positions", "gripper_open"), 130 | renderer: str = "opengl3", 131 | render_mode: Union[None, str] = "rgb_array", 132 | ): 133 | self._task_name = task_name 134 | self._episode_length = episode_length 135 | self._frame_stack = frame_stack 136 | self._dataset_root = dataset_root 137 | self._arm_max_velocity = arm_max_velocity 138 | self._arm_max_acceleration = arm_max_acceleration 139 | self._camera_shape = camera_shape 140 | self._camera_keys = camera_keys 141 | self._state_keys = state_keys 142 | self._renderer = renderer 143 | self._render_mode = render_mode 144 | 145 | self._launch() 146 | self._add_gym_camera() 147 | self._initialize_frame_stack() 148 | self._construct_action_and_observation_spaces() 149 | 150 | def low_dim_observation_spec(self) -> spaces.Box: 151 | shape = self.low_dim_observation_space.shape 152 | spec = specs.Array(shape, np.float32, "low_dim_obs") 153 | return spec 154 | 155 | def low_dim_raw_observation_spec(self) -> spaces.Box: 156 | shape = self.low_dim_raw_observation_space.shape 157 | spec = specs.Array(shape, np.float32, "low_dim_obs") 158 | return spec 159 | 160 | def rgb_observation_spec(self) -> spaces.Box: 161 | shape = self.rgb_observation_space.shape 162 | spec = specs.Array(shape, np.uint8, "rgb_obs") 163 | return spec 164 | 165 | def rgb_raw_observation_spec(self) -> spaces.Box: 166 | shape = self.rgb_raw_observation_space.shape 167 | spec = specs.Array(shape, np.uint8, "rgb_obs") 168 | return spec 169 | 170 | def action_spec(self) -> spaces.Box: 171 | shape = self.action_space.shape 172 | spec = specs.Array(shape, np.float32, "action") 173 | return spec 174 | 175 | def step(self, action): 176 | action = self._convert_action_to_raw(action) 177 | rlb_obs, reward, terminated = self.task.step(action) 178 | obs = self._extract_obs(rlb_obs) 179 | self._step_counter += 1 180 | 181 | # Timelimit 182 | if self._step_counter >= self._episode_length: 183 | truncated = True 184 | else: 185 | truncated = False 186 | 187 | # Handle bootstrap 188 | if terminated or truncated: 189 | step_type = StepType.LAST 190 | else: 191 | step_type = StepType.MID 192 | discount = float(1 - terminated) 193 | 194 | return TimeStep( 195 | rgb_obs=obs["rgb_obs"], 196 | 
low_dim_obs=obs["low_dim_obs"], 197 | step_type=step_type, 198 | reward=reward, 199 | discount=discount, 200 | demo=0.0, 201 | ) 202 | 203 | def reset(self, **kwargs): 204 | # Clear deques used for frame stacking 205 | self._low_dim_obses.clear() 206 | for frames in self._frames.values(): 207 | frames.clear() 208 | 209 | _, rlb_obs = self.task.reset(**kwargs) 210 | obs = self._extract_obs(rlb_obs) 211 | self._step_counter = 0 212 | 213 | return TimeStep( 214 | rgb_obs=obs["rgb_obs"], 215 | low_dim_obs=obs["low_dim_obs"], 216 | step_type=StepType.FIRST, 217 | reward=0.0, 218 | discount=1.0, 219 | demo=0.0, 220 | ) 221 | 222 | def render(self, mode="rgb_array") -> Union[None, np.ndarray]: 223 | if mode != self._render_mode: 224 | raise ValueError( 225 | "The render mode must match the render mode selected in the " 226 | 'constructor. \nI.e. if you want "human" render mode, then ' 227 | "create the env by calling: " 228 | 'gym.make("reach_target-state-v0", render_mode="human").\n' 229 | "You passed in mode %s, but expected %s." % (mode, self._render_mode) 230 | ) 231 | if mode == "rgb_array": 232 | frame = self._gym_cam.capture_rgb() 233 | frame = np.clip((frame * 255.0).astype(np.uint8), 0, 255) 234 | return frame 235 | 236 | def get_demos(self, num_demos): 237 | """ 238 | 1. Collect or fetch demonstrations 239 | 2. Compute action stats from demonstrations, override self._action_stats 240 | 3. Rescale actions in demonstrations to [-1, 1] space 241 | """ 242 | live_demos = not self._dataset_root 243 | if live_demos: 244 | logging.warning("Live demo collection.. Takes a while..") 245 | raw_demos = self.task.get_demos(num_demos, live_demos) 246 | demos = [] 247 | for raw_demo in raw_demos: 248 | demo = self.convert_demo_to_timesteps(raw_demo) 249 | if demo is not None: 250 | demos.append(demo) 251 | else: 252 | print("Skipping demo for large delta action") 253 | # override action stats with demonstration-based stats 254 | self._action_stats = self.extract_action_stats(demos) 255 | # rescale actions with action stats 256 | demos = [self.rescale_demo_actions(demo) for demo in demos] 257 | return demos 258 | 259 | def extract_action_stats(self, demos: list[list[ExtendedTimeStep]]): 260 | actions = [] 261 | for demo in demos: 262 | for timestep in demo: 263 | actions.append(timestep.action) 264 | actions = np.stack(actions) 265 | 266 | # Gripper one-hot action's stats are hard-coded 267 | action_max = np.hstack([np.max(actions, 0)[:-1], 1]) 268 | action_min = np.hstack([np.min(actions, 0)[:-1], 0]) 269 | action_stats = { 270 | "max": action_max, 271 | "min": action_min, 272 | } 273 | return action_stats 274 | 275 | def extract_delta_joint_action(self, obs, next_obs): 276 | action = np.concatenate( 277 | [ 278 | ( 279 | next_obs.misc["joint_position_action"][:-1] - obs.joint_positions 280 | if "joint_position_action" in next_obs.misc 281 | else next_obs.joint_positions - obs.joint_positions 282 | ), 283 | [1.0 if next_obs.gripper_open == 1 else 0.0], 284 | ] 285 | ).astype(np.float32) 286 | return action 287 | 288 | def convert_demo_to_timesteps(self, demo): 289 | timesteps = [] 290 | 291 | # Clear deques used for frame stacking 292 | self._low_dim_obses.clear() 293 | for frames in self._frames.values(): 294 | frames.clear() 295 | 296 | for i in range(len(demo)): 297 | rlb_obs = demo[i] 298 | 299 | obs = self._extract_obs(rlb_obs) 300 | reward, discount = 0.0, 1.0 301 | if i == 0: 302 | # zero action for the first timestep 303 | action_spec = self.action_spec() 304 | action = 
np.zeros(action_spec.shape, dtype=action_spec.dtype) 305 | step_type = StepType.FIRST 306 | else: 307 | prev_rlb_obs = demo[i - 1] 308 | action = self.extract_delta_joint_action(prev_rlb_obs, rlb_obs) 309 | if np.any(action[:-1] > 0.2) or np.any(action[:-1] < -0.2): 310 | return None 311 | if i == len(demo) - 1: 312 | step_type = StepType.LAST 313 | reward = 1.0 314 | discount = 0.0 315 | else: 316 | step_type = StepType.MID 317 | 318 | timestep = ExtendedTimeStep( 319 | rgb_obs=obs["rgb_obs"], 320 | low_dim_obs=obs["low_dim_obs"], 321 | step_type=step_type, 322 | action=action, 323 | reward=reward, 324 | discount=discount, 325 | demo=1.0, 326 | ) 327 | timesteps.append(timestep) 328 | 329 | return timesteps 330 | 331 | def rescale_demo_actions( 332 | self, demo: list[ExtendedTimeStep] 333 | ) -> list[ExtendedTimeStep]: 334 | new_timesteps = [] 335 | for timestep in demo: 336 | action = self._convert_action_from_raw(timestep.action) 337 | new_timesteps.append(timestep._replace(action=action)) 338 | return new_timesteps 339 | 340 | def close(self) -> None: 341 | self._env.shutdown() 342 | 343 | def _launch(self): 344 | task_class = name_to_task_class(self._task_name) 345 | 346 | # Setup observation configs 347 | obs_config = ObservationConfig() 348 | obs_config.set_all_high_dim(False) 349 | obs_config.set_all_low_dim(False) 350 | assert ( 351 | "joint_positions" in self._state_keys 352 | ), "joint position is required as this code assumes joint control" 353 | for state_key in self._state_keys: 354 | setattr(obs_config, state_key, True) 355 | for camera_key in self._camera_keys: 356 | camera_config = getattr(obs_config, f"{camera_key}_camera") 357 | camera_config.rgb = True 358 | camera_config.image_size = self._camera_shape 359 | camera_config.render_mode = ( 360 | RenderMode.OPENGL3 if self._renderer == "opengl3" else RenderMode.OPENGL 361 | ) 362 | setattr(obs_config, f"{camera_key}_camera", camera_config) 363 | 364 | # Setup action mode 365 | action_mode = MoveArmThenGripper( 366 | arm_action_mode=JointPosition(False), gripper_action_mode=Discrete() 367 | ) 368 | 369 | # Launch environment and setup spaces 370 | self._env = Environment( 371 | action_mode, 372 | arm_max_velocity=self._arm_max_velocity, 373 | arm_max_acceleration=self._arm_max_acceleration, 374 | obs_config=obs_config, 375 | dataset_root=self._dataset_root, 376 | headless=True, 377 | ) 378 | self._env.launch() 379 | self.task = self._env.get_task(task_class) 380 | 381 | # Episode length counter 382 | self._step_counter = 0 383 | 384 | def _add_gym_camera(self): 385 | if self._render_mode is not None: 386 | # Add the camera to the scene 387 | cam_placeholder = Dummy("cam_cinematic_placeholder") 388 | self._gym_cam = VisionSensor.create([320, 192]) 389 | self._gym_cam.set_pose(cam_placeholder.get_pose()) 390 | if self._render_mode == "human": 391 | self._gym_cam.set_render_mode(RenderMode.OPENGL3_WINDOWED) 392 | else: 393 | self._gym_cam.set_render_mode(RenderMode.OPENGL3) 394 | 395 | def _initialize_frame_stack(self): 396 | # Create deques for frame stacking 397 | self._low_dim_obses = deque([], maxlen=self._frame_stack) 398 | self._frames = { 399 | camera_key: deque([], maxlen=self._frame_stack) 400 | for camera_key in self._camera_keys 401 | } 402 | 403 | def _construct_action_and_observation_spaces(self): 404 | # Setup action/observation spaces 405 | self.action_space = spaces.Box(low=-1.0, high=1.0, shape=self._env.action_shape) 406 | self.low_dim_observation_space = spaces.Box( 407 | low=-np.inf, high=np.inf, shape=(8 
* self._frame_stack,), dtype=np.float32 408 | ) # hard-coded: joint: 7, gripper_open: 1 409 | self.low_dim_raw_observation_space = spaces.Box( 410 | low=-np.inf, high=np.inf, shape=(8,), dtype=np.float32 411 | ) # without frame stacking 412 | self.rgb_observation_space = spaces.Box( 413 | low=0, 414 | high=255, 415 | shape=(len(self._camera_keys), 3 * self._frame_stack, *self._camera_shape), 416 | dtype=np.uint8, 417 | ) 418 | self.rgb_raw_observation_space = spaces.Box( 419 | low=0, 420 | high=255, 421 | shape=(len(self._camera_keys), 3, *self._camera_shape), 422 | dtype=np.uint8, 423 | ) # without frame stacking 424 | 425 | # Set default action stats, which will be overridden by demonstration action stats 426 | # Required for the case where we don't use demonstrations 427 | action_min = ( 428 | -np.ones(self.action_space.shape, dtype=self.action_space.dtype) * 0.2 429 | ) 430 | action_max = ( 431 | np.ones(self.action_space.shape, dtype=self.action_space.dtype) * 0.2 432 | ) 433 | action_min[-1] = 0 434 | action_max[-1] = 1 435 | self._action_stats = {"min": action_min, "max": action_max} 436 | 437 | def _convert_action_to_raw(self, action): 438 | """Convert [-1, 1] action to raw joint space using action stats""" 439 | assert (max(action) <= 1) and (min(action) >= -1) 440 | action_min, action_max = self._action_stats["min"], self._action_stats["max"] 441 | _action_min = action_min - np.fabs(action_min) * 0.2 442 | _action_max = action_max + np.fabs(action_max) * 0.2 443 | new_action = (action + 1) / 2.0 # to [0, 1] 444 | new_action = new_action * (_action_max - _action_min) + _action_min # to original action range 445 | return new_action.astype(action.dtype, copy=False) 446 | 447 | def _convert_action_from_raw(self, action): 448 | """Convert raw action in joint space to [-1, 1] using action stats""" 449 | action_min, action_max = self._action_stats["min"], self._action_stats["max"] 450 | _action_min = action_min - np.fabs(action_min) * 0.2 451 | _action_max = action_max + np.fabs(action_max) * 0.2 452 | 453 | new_action = (action - _action_min) / (_action_max - _action_min) # to [0, 1] 454 | new_action = new_action * 2 - 1 # to [-1, 1] 455 | return new_action.astype(action.dtype, copy=False) 456 | 457 | def _extract_obs(self, obs) -> Dict[str, np.ndarray]: 458 | obs = vars(obs) 459 | out = dict() 460 | 461 | # Get low-dimensional state with stacking 462 | low_dim_obs = np.hstack( 463 | [obs[key] for key in self._state_keys], dtype=np.float32 464 | ) 465 | if len(self._low_dim_obses) == 0: 466 | for _ in range(self._frame_stack): 467 | self._low_dim_obses.append(low_dim_obs) 468 | else: 469 | self._low_dim_obses.append(low_dim_obs) 470 | out["low_dim_obs"] = np.concatenate(list(self._low_dim_obses), axis=0) 471 | 472 | # Get rgb observations with stacking 473 | for camera_key in self._camera_keys: 474 | pixels = obs[f"{camera_key}_rgb"].transpose(2, 0, 1).copy() 475 | if len(self._frames[camera_key]) == 0: 476 | for _ in range(self._frame_stack): 477 | self._frames[camera_key].append(pixels) 478 | else: 479 | self._frames[camera_key].append(pixels) 480 | out["rgb_obs"] = np.stack( 481 | [ 482 | np.concatenate(list(self._frames[camera_key]), axis=0) 483 | for camera_key in self._camera_keys 484 | ], 485 | 0, 486 | ) 487 | return out 488 | 489 | def __del__( 490 | self, 491 | ) -> None: 492 | self.close() 493 | 494 | 495 | def make( 496 | task_name, 497 | episode_length, 498 | frame_stack, 499 | dataset_root, 500 | arm_max_velocity, 501 | arm_max_acceleration, 502 | camera_shape, 503 | camera_keys, 504 |
state_keys, 505 | renderer, 506 | ): 507 | env = RLBench( 508 | task_name, 509 | episode_length=episode_length, 510 | frame_stack=frame_stack, 511 | dataset_root=dataset_root, 512 | arm_max_velocity=arm_max_velocity, 513 | arm_max_acceleration=arm_max_acceleration, 514 | camera_shape=camera_shape, 515 | camera_keys=camera_keys, 516 | state_keys=state_keys, 517 | renderer=renderer, 518 | ) 519 | env = ExtendedTimeStepWrapper(env) 520 | return env 521 | -------------------------------------------------------------------------------- /train_dmc.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | warnings.filterwarnings("ignore", category=DeprecationWarning) 4 | 5 | import os 6 | 7 | os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" 8 | os.environ["MUJOCO_GL"] = "egl" 9 | 10 | from pathlib import Path 11 | 12 | import hydra 13 | import numpy as np 14 | import torch 15 | from dm_env import specs 16 | 17 | import dmc 18 | import utils 19 | from logger import Logger 20 | from replay_buffer_dmc import ReplayBufferStorage, make_replay_loader 21 | from video import TrainVideoRecorder, VideoRecorder 22 | 23 | torch.backends.cudnn.benchmark = True 24 | 25 | 26 | def make_agent(obs_spec, action_spec, use_logger, cfg): 27 | cfg.obs_shape = obs_spec.shape 28 | cfg.action_shape = action_spec.shape 29 | cfg.use_logger = use_logger 30 | return hydra.utils.instantiate(cfg) 31 | 32 | 33 | class Workspace: 34 | def __init__(self, cfg): 35 | self.work_dir = Path.cwd() 36 | print(f"workspace: {self.work_dir}") 37 | 38 | self.cfg = cfg 39 | utils.set_seed_everywhere(cfg.seed) 40 | self.device = torch.device(cfg.device) 41 | self.setup() 42 | 43 | self.agent = make_agent( 44 | self.train_env.observation_spec(), 45 | self.train_env.action_spec(), 46 | self.cfg.use_tb or self.cfg.use_wandb, 47 | self.cfg.agent, 48 | ) 49 | self.timer = utils.Timer() 50 | self.logger = Logger( 51 | self.work_dir, self.cfg.use_tb, self.cfg.use_wandb, self.cfg 52 | ) 53 | self._global_step = 0 54 | self._global_episode = 0 55 | 56 | def setup(self): 57 | # create envs 58 | self.train_env = dmc.make( 59 | self.cfg.task_name, 60 | self.cfg.frame_stack, 61 | self.cfg.action_repeat, 62 | self.cfg.seed, 63 | ) 64 | self.eval_env = dmc.make( 65 | self.cfg.task_name, 66 | self.cfg.frame_stack, 67 | self.cfg.action_repeat, 68 | self.cfg.seed, 69 | ) 70 | # create replay buffer 71 | data_specs = ( 72 | self.train_env.observation_spec(), 73 | self.train_env.action_spec(), 74 | specs.Array((1,), np.float32, "reward"), 75 | specs.Array((1,), np.float32, "discount"), 76 | ) 77 | 78 | self.replay_storage = ReplayBufferStorage(data_specs, self.work_dir / "buffer") 79 | 80 | self.replay_loader = make_replay_loader( 81 | self.work_dir / "buffer", 82 | self.cfg.replay_buffer_size, 83 | self.cfg.batch_size, 84 | self.cfg.replay_buffer_num_workers, 85 | self.cfg.save_snapshot, 86 | self.cfg.nstep, 87 | self.cfg.discount, 88 | ) 89 | self._replay_iter = None 90 | 91 | self.video_recorder = VideoRecorder( 92 | self.work_dir if self.cfg.save_video else None 93 | ) 94 | self.train_video_recorder = TrainVideoRecorder( 95 | self.work_dir if self.cfg.save_train_video else None 96 | ) 97 | 98 | @property 99 | def global_step(self): 100 | return self._global_step 101 | 102 | @property 103 | def global_episode(self): 104 | return self._global_episode 105 | 106 | @property 107 | def global_frame(self): 108 | return self.global_step * self.cfg.action_repeat 109 | 110 | @property 111 | def replay_iter(self): 112 | 
if self._replay_iter is None: 113 | self._replay_iter = iter(self.replay_loader) 114 | return self._replay_iter 115 | 116 | def eval(self): 117 | step, episode, total_reward = 0, 0, 0 118 | eval_until_episode = utils.Until(self.cfg.num_eval_episodes) 119 | 120 | while eval_until_episode(episode): 121 | time_step = self.eval_env.reset() 122 | self.video_recorder.init(self.eval_env, enabled=(episode == 0)) 123 | while not time_step.last(): 124 | with torch.no_grad(), utils.eval_mode(self.agent): 125 | action = self.agent.act( 126 | time_step.observation, self.global_step, eval_mode=True 127 | ) 128 | time_step = self.eval_env.step(action) 129 | self.video_recorder.record(self.eval_env) 130 | total_reward += time_step.reward 131 | step += 1 132 | 133 | episode += 1 134 | self.video_recorder.save(f"{self.global_frame}.mp4") 135 | 136 | with self.logger.log_and_dump_ctx(self.global_frame, ty="eval") as log: 137 | log("episode_reward", total_reward / episode) 138 | log("episode_length", step * self.cfg.action_repeat / episode) 139 | log("episode", self.global_episode) 140 | log("step", self.global_step) 141 | 142 | def train(self): 143 | # predicates 144 | train_until_step = utils.Until( 145 | self.cfg.num_train_frames, self.cfg.action_repeat 146 | ) 147 | seed_until_step = utils.Until(self.cfg.num_seed_frames, self.cfg.action_repeat) 148 | eval_every_step = utils.Every( 149 | self.cfg.eval_every_frames, self.cfg.action_repeat 150 | ) 151 | 152 | episode_step, episode_reward = 0, 0 153 | time_step = self.train_env.reset() 154 | self.replay_storage.add(time_step) 155 | self.train_video_recorder.init(time_step.observation) 156 | metrics = None 157 | while train_until_step(self.global_step): 158 | if time_step.last(): 159 | self._global_episode += 1 160 | self.train_video_recorder.save(f"{self.global_frame}.mp4") 161 | # wait until all the metrics schema is populated 162 | if metrics is not None: 163 | # log stats 164 | elapsed_time, total_time = self.timer.reset() 165 | episode_frame = episode_step * self.cfg.action_repeat 166 | with self.logger.log_and_dump_ctx( 167 | self.global_frame, ty="train" 168 | ) as log: 169 | log("fps", episode_frame / elapsed_time) 170 | log("total_time", total_time) 171 | log("episode_reward", episode_reward) 172 | log("episode_length", episode_frame) 173 | log("episode", self.global_episode) 174 | log("buffer_size", len(self.replay_storage)) 175 | log("step", self.global_step) 176 | 177 | # reset env 178 | time_step = self.train_env.reset() 179 | self.replay_storage.add(time_step) 180 | self.train_video_recorder.init(time_step.observation) 181 | # try to save snapshot 182 | if self.cfg.save_snapshot: 183 | self.save_snapshot() 184 | episode_step = 0 185 | episode_reward = 0 186 | 187 | # try to evaluate 188 | if eval_every_step(self.global_step): 189 | self.logger.log( 190 | "eval_total_time", self.timer.total_time(), self.global_frame 191 | ) 192 | self.eval() 193 | 194 | # sample action 195 | with torch.no_grad(), utils.eval_mode(self.agent): 196 | action = self.agent.act( 197 | time_step.observation, self.global_step, eval_mode=False 198 | ) 199 | 200 | # try to update the agent 201 | if not seed_until_step(self.global_step): 202 | metrics = self.agent.update(self.replay_iter, self.global_step) 203 | self.logger.log_metrics(metrics, self.global_frame, ty="train") 204 | 205 | # take env step 206 | time_step = self.train_env.step(action) 207 | episode_reward += time_step.reward 208 | self.replay_storage.add(time_step) 209 | 
self.train_video_recorder.record(time_step.observation) 210 | episode_step += 1 211 | self._global_step += 1 212 | 213 | def save_snapshot(self): 214 | snapshot = self.work_dir / "snapshot.pt" 215 | keys_to_save = ["agent", "timer", "_global_step", "_global_episode"] 216 | payload = {k: self.__dict__[k] for k in keys_to_save} 217 | with snapshot.open("wb") as f: 218 | torch.save(payload, f) 219 | 220 | def load_snapshot(self): 221 | snapshot = self.work_dir / "snapshot.pt" 222 | with snapshot.open("rb") as f: 223 | payload = torch.load(f) 224 | for k, v in payload.items(): 225 | self.__dict__[k] = v 226 | 227 | 228 | @hydra.main(config_path="cfgs", config_name="config_dmc") 229 | def main(cfg): 230 | from train_dmc import Workspace as W 231 | 232 | root_dir = Path.cwd() 233 | workspace = W(cfg) 234 | snapshot = root_dir / "snapshot.pt" 235 | if snapshot.exists(): 236 | print(f"resuming: {snapshot}") 237 | workspace.load_snapshot() 238 | workspace.train() 239 | 240 | 241 | if __name__ == "__main__": 242 | main() 243 | -------------------------------------------------------------------------------- /train_rlbench.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import warnings 3 | 4 | warnings.filterwarnings("ignore", category=DeprecationWarning) 5 | 6 | import os 7 | 8 | os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" 9 | os.environ["MUJOCO_GL"] = "egl" 10 | 11 | from pathlib import Path 12 | 13 | import hydra 14 | import numpy as np 15 | import torch 16 | from dm_env import specs 17 | 18 | import rlbench_env 19 | import utils 20 | from logger import Logger 21 | from replay_buffer import ReplayBufferStorage, make_replay_loader 22 | from video import TrainVideoRecorder, VideoRecorder 23 | 24 | torch.backends.cudnn.benchmark = True 25 | 26 | 27 | def make_agent(rgb_obs_spec, low_dim_obs_spec, action_spec, use_logger, cfg): 28 | cfg.rgb_obs_shape = rgb_obs_spec.shape 29 | cfg.low_dim_obs_shape = low_dim_obs_spec.shape 30 | cfg.action_shape = action_spec.shape 31 | cfg.use_logger = use_logger 32 | return hydra.utils.instantiate(cfg) 33 | 34 | 35 | class Workspace: 36 | def __init__(self, cfg): 37 | self.work_dir = Path.cwd() 38 | print(f"workspace: {self.work_dir}") 39 | 40 | self.cfg = cfg 41 | utils.set_seed_everywhere(cfg.seed) 42 | self.device = torch.device(cfg.device) 43 | self.setup() 44 | 45 | self.agent = make_agent( 46 | self.train_env.rgb_observation_spec(), 47 | self.train_env.low_dim_observation_spec(), 48 | self.train_env.action_spec(), 49 | self.cfg.use_tb or self.cfg.use_wandb, 50 | self.cfg.agent, 51 | ) 52 | self.timer = utils.Timer() 53 | self.logger = Logger( 54 | self.work_dir, self.cfg.use_tb, self.cfg.use_wandb, self.cfg 55 | ) 56 | self._global_step = 0 57 | self._global_episode = 0 58 | 59 | def setup(self): 60 | # create envs 61 | self.train_env = rlbench_env.make( 62 | self.cfg.task_name, 63 | self.cfg.episode_length, 64 | self.cfg.frame_stack, 65 | self.cfg.dataset_root, 66 | self.cfg.arm_max_velocity, 67 | self.cfg.arm_max_acceleration, 68 | self.cfg.camera_shape, 69 | self.cfg.camera_keys, 70 | self.cfg.state_keys, 71 | self.cfg.renderer, 72 | ) 73 | # create replay buffer 74 | data_specs = ( 75 | self.train_env.rgb_raw_observation_spec(), 76 | self.train_env.low_dim_raw_observation_spec(), 77 | self.train_env.action_spec(), 78 | specs.Array((1,), np.float32, "reward"), 79 | specs.Array((1,), np.float32, "discount"), 80 | specs.Array((1,), np.float32, "demo"), 81 | ) 82 | 83 | self.replay_storage = 
ReplayBufferStorage( 84 | data_specs, self.work_dir / "buffer", self.cfg.use_relabeling 85 | ) 86 | self.demo_replay_storage = ReplayBufferStorage( 87 | data_specs, 88 | self.work_dir / "demo_buffer", 89 | self.cfg.use_relabeling, 90 | is_demo_buffer=True, 91 | ) 92 | 93 | self.replay_loader = make_replay_loader( 94 | self.work_dir / "buffer", 95 | self.cfg.replay_buffer_size, 96 | self.cfg.batch_size, 97 | self.cfg.replay_buffer_num_workers, 98 | self.cfg.save_snapshot, 99 | self.cfg.nstep, 100 | self.cfg.discount, 101 | self.cfg.do_always_bootstrap, 102 | self.cfg.frame_stack, 103 | ) 104 | self.demo_replay_loader = make_replay_loader( 105 | self.work_dir / "demo_buffer", 106 | self.cfg.replay_buffer_size, 107 | self.cfg.batch_size, 108 | self.cfg.replay_buffer_num_workers, 109 | self.cfg.save_snapshot, 110 | self.cfg.nstep, 111 | self.cfg.discount, 112 | self.cfg.do_always_bootstrap, 113 | self.cfg.frame_stack, 114 | ) 115 | self._replay_iter = None 116 | 117 | self.video_recorder = VideoRecorder( 118 | self.work_dir if self.cfg.save_video else None 119 | ) 120 | self.train_video_recorder = TrainVideoRecorder( 121 | self.work_dir if self.cfg.save_train_video else None 122 | ) 123 | 124 | @property 125 | def global_step(self): 126 | return self._global_step 127 | 128 | @property 129 | def global_episode(self): 130 | return self._global_episode 131 | 132 | @property 133 | def global_frame(self): 134 | return self.global_step * self.cfg.action_repeat 135 | 136 | @property 137 | def replay_iter(self): 138 | if self._replay_iter is None: 139 | replay_iter = iter(self.replay_loader) 140 | demo_replay_iter = iter(self.demo_replay_loader) 141 | self._replay_iter = utils.DemoMergedIterator(replay_iter, demo_replay_iter) 142 | return self._replay_iter 143 | 144 | def eval(self): 145 | """We use train env for evaluation, because it's convenient""" 146 | step, episode, total_reward = 0, 0, 0 147 | eval_until_episode = utils.Until(self.cfg.num_eval_episodes) 148 | 149 | while eval_until_episode(episode): 150 | time_step = self.train_env.reset() 151 | self.video_recorder.init(self.train_env, enabled=(episode == 0)) 152 | while not time_step.last(): 153 | with torch.no_grad(), utils.eval_mode(self.agent): 154 | action = self.agent.act( 155 | time_step.rgb_obs, 156 | time_step.low_dim_obs, 157 | self.global_step, 158 | eval_mode=True, 159 | ) 160 | time_step = self.train_env.step(action) 161 | self.video_recorder.record(self.train_env) 162 | total_reward += time_step.reward 163 | step += 1 164 | 165 | episode += 1 166 | self.video_recorder.save(f"{self.global_frame}.mp4") 167 | 168 | with self.logger.log_and_dump_ctx(self.global_frame, ty="eval") as log: 169 | log("episode_reward", total_reward / episode) 170 | log("episode_length", step * self.cfg.action_repeat / episode) 171 | log("episode", self.global_episode) 172 | log("step", self.global_step) 173 | 174 | def train(self): 175 | # predicates 176 | train_until_step = utils.Until( 177 | self.cfg.num_train_frames, self.cfg.action_repeat 178 | ) 179 | seed_until_step = utils.Until(self.cfg.num_seed_frames, self.cfg.action_repeat) 180 | eval_every_step = utils.Every( 181 | self.cfg.eval_every_frames, self.cfg.action_repeat 182 | ) 183 | 184 | do_eval = False 185 | 186 | episode_step, episode_reward = 0, 0 187 | time_step = self.train_env.reset() 188 | self.replay_storage.add(time_step) 189 | self.demo_replay_storage.add(time_step) 190 | self.train_video_recorder.init(time_step.rgb_obs[0]) 191 | metrics = None 192 | while 
train_until_step(self.global_step): 193 | if time_step.last(): 194 | self._global_episode += 1 195 | self.train_video_recorder.save(f"{self.global_frame}.mp4") 196 | # wait until all the metrics schema is populated 197 | if metrics is not None: 198 | # log stats 199 | elapsed_time, total_time = self.timer.reset() 200 | episode_frame = episode_step * self.cfg.action_repeat 201 | with self.logger.log_and_dump_ctx( 202 | self.global_frame, ty="train" 203 | ) as log: 204 | log("fps", episode_frame / elapsed_time) 205 | log("total_time", total_time) 206 | log("episode_reward", episode_reward) 207 | log("episode_length", episode_frame) 208 | log("episode", self.global_episode) 209 | log("buffer_size", len(self.replay_storage)) 210 | log("demo_buffer_size", len(self.demo_replay_storage)) 211 | log("step", self.global_step) 212 | 213 | # do evaluation before resetting the environment 214 | if do_eval: 215 | self.logger.log( 216 | "eval_total_time", self.timer.total_time(), self.global_frame 217 | ) 218 | self.eval() 219 | do_eval = False 220 | 221 | # reset env 222 | time_step = self.train_env.reset() 223 | self.replay_storage.add(time_step) 224 | self.demo_replay_storage.add(time_step) 225 | self.train_video_recorder.init(time_step.rgb_obs[0]) 226 | # try to save snapshot 227 | if self.cfg.save_snapshot: 228 | self.save_snapshot() 229 | episode_step = 0 230 | episode_reward = 0 231 | 232 | # set a flag to initiate evaluation when the current episode terminates 233 | if self.global_step >= self.cfg.eval_every_frames and eval_every_step( 234 | self.global_step 235 | ): 236 | do_eval = True 237 | 238 | # sample action 239 | with torch.no_grad(), utils.eval_mode(self.agent): 240 | action = self.agent.act( 241 | time_step.rgb_obs, 242 | time_step.low_dim_obs, 243 | self.global_step, 244 | eval_mode=False, 245 | ) 246 | 247 | # try to update the agent 248 | if not seed_until_step(self.global_step): 249 | for _ in range(self.cfg.num_update_steps): 250 | metrics = self.agent.update(self.replay_iter, self.global_step) 251 | self.logger.log_metrics(metrics, self.global_frame, ty="train") 252 | 253 | # take env step 254 | time_step = self.train_env.step(action) 255 | episode_reward += time_step.reward 256 | self.replay_storage.add(time_step) 257 | self.demo_replay_storage.add(time_step) 258 | self.train_video_recorder.record(time_step.rgb_obs[0]) 259 | episode_step += 1 260 | self._global_step += 1 261 | 262 | def load_rlbench_demos(self): 263 | if self.cfg.num_demos > 0: 264 | demos = self.train_env.get_demos(self.cfg.num_demos) 265 | for demo in demos: 266 | for time_step in demo: 267 | self.replay_storage.add(time_step) 268 | self.demo_replay_storage.add(time_step) 269 | else: 270 | logging.warning("Not using demonstrations") 271 | 272 | def save_snapshot(self): 273 | snapshot = self.work_dir / "snapshot.pt" 274 | keys_to_save = ["agent", "timer", "_global_step", "_global_episode"] 275 | payload = {k: self.__dict__[k] for k in keys_to_save} 276 | with snapshot.open("wb") as f: 277 | torch.save(payload, f) 278 | 279 | def load_snapshot(self): 280 | snapshot = self.work_dir / "snapshot.pt" 281 | with snapshot.open("rb") as f: 282 | payload = torch.load(f) 283 | for k, v in payload.items(): 284 | self.__dict__[k] = v 285 | 286 | 287 | @hydra.main(config_path="cfgs", config_name="config_rlbench") 288 | def main(cfg): 289 | from train_rlbench import Workspace as W 290 | 291 | root_dir = Path.cwd() 292 | workspace = W(cfg) 293 | snapshot = root_dir / "snapshot.pt" 294 | if snapshot.exists(): 295 |
print(f"resuming: {snapshot}") 296 | workspace.load_snapshot() 297 | workspace.load_rlbench_demos() 298 | workspace.train() 299 | 300 | 301 | if __name__ == "__main__": 302 | main() 303 | -------------------------------------------------------------------------------- /train_rlbench_drqv2plus.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import warnings 3 | 4 | warnings.filterwarnings("ignore", category=DeprecationWarning) 5 | 6 | import os 7 | 8 | os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" 9 | os.environ["MUJOCO_GL"] = "egl" 10 | 11 | from pathlib import Path 12 | 13 | import hydra 14 | import numpy as np 15 | import torch 16 | from dm_env import specs 17 | 18 | import rlbench_env 19 | import utils 20 | from logger import Logger 21 | from replay_buffer import ReplayBufferStorage, make_replay_loader 22 | from video import TrainVideoRecorder, VideoRecorder 23 | 24 | torch.backends.cudnn.benchmark = True 25 | 26 | 27 | def make_agent(rgb_obs_spec, low_dim_obs_spec, action_spec, use_logger, cfg): 28 | cfg.rgb_obs_shape = rgb_obs_spec.shape 29 | cfg.low_dim_obs_shape = low_dim_obs_spec.shape 30 | cfg.action_shape = action_spec.shape 31 | cfg.use_logger = use_logger 32 | return hydra.utils.instantiate(cfg) 33 | 34 | 35 | class Workspace: 36 | def __init__(self, cfg): 37 | self.work_dir = Path.cwd() 38 | print(f"workspace: {self.work_dir}") 39 | 40 | self.cfg = cfg 41 | utils.set_seed_everywhere(cfg.seed) 42 | self.device = torch.device(cfg.device) 43 | self.setup() 44 | 45 | self.agent = make_agent( 46 | self.train_env.rgb_observation_spec(), 47 | self.train_env.low_dim_observation_spec(), 48 | self.train_env.action_spec(), 49 | self.cfg.use_tb or self.cfg.use_wandb, 50 | self.cfg.agent, 51 | ) 52 | self.timer = utils.Timer() 53 | self.logger = Logger( 54 | self.work_dir, self.cfg.use_tb, self.cfg.use_wandb, self.cfg 55 | ) 56 | self._global_step = 0 57 | self._global_episode = 0 58 | 59 | def setup(self): 60 | # create envs 61 | self.train_env = rlbench_env.make( 62 | self.cfg.task_name, 63 | self.cfg.episode_length, 64 | self.cfg.frame_stack, 65 | self.cfg.dataset_root, 66 | self.cfg.arm_max_velocity, 67 | self.cfg.arm_max_acceleration, 68 | self.cfg.camera_shape, 69 | self.cfg.camera_keys, 70 | self.cfg.state_keys, 71 | self.cfg.renderer, 72 | ) 73 | # create replay buffer 74 | data_specs = ( 75 | self.train_env.rgb_raw_observation_spec(), 76 | self.train_env.low_dim_raw_observation_spec(), 77 | self.train_env.action_spec(), 78 | specs.Array((1,), np.float32, "reward"), 79 | specs.Array((1,), np.float32, "discount"), 80 | specs.Array((1,), np.float32, "demo"), 81 | ) 82 | 83 | self.replay_storage = ReplayBufferStorage( 84 | data_specs, self.work_dir / "buffer", self.cfg.use_relabeling 85 | ) 86 | self.demo_replay_storage = ReplayBufferStorage( 87 | data_specs, 88 | self.work_dir / "demo_buffer", 89 | self.cfg.use_relabeling, 90 | is_demo_buffer=True, 91 | ) 92 | 93 | self.replay_loader = make_replay_loader( 94 | self.work_dir / "buffer", 95 | self.cfg.replay_buffer_size, 96 | self.cfg.batch_size, 97 | self.cfg.replay_buffer_num_workers, 98 | self.cfg.save_snapshot, 99 | self.cfg.nstep, 100 | self.cfg.discount, 101 | self.cfg.do_always_bootstrap, 102 | self.cfg.frame_stack, 103 | ) 104 | self.demo_replay_loader = make_replay_loader( 105 | self.work_dir / "demo_buffer", 106 | self.cfg.replay_buffer_size, 107 | self.cfg.batch_size, 108 | self.cfg.replay_buffer_num_workers, 109 | self.cfg.save_snapshot, 110 | self.cfg.nstep, 111 | 
self.cfg.discount, 112 | self.cfg.do_always_bootstrap, 113 | self.cfg.frame_stack, 114 | ) 115 | self._replay_iter = None 116 | 117 | self.video_recorder = VideoRecorder( 118 | self.work_dir if self.cfg.save_video else None 119 | ) 120 | self.train_video_recorder = TrainVideoRecorder( 121 | self.work_dir if self.cfg.save_train_video else None 122 | ) 123 | 124 | @property 125 | def global_step(self): 126 | return self._global_step 127 | 128 | @property 129 | def global_episode(self): 130 | return self._global_episode 131 | 132 | @property 133 | def global_frame(self): 134 | return self.global_step * self.cfg.action_repeat 135 | 136 | @property 137 | def replay_iter(self): 138 | if self._replay_iter is None: 139 | replay_iter = iter(self.replay_loader) 140 | demo_replay_iter = iter(self.demo_replay_loader) 141 | self._replay_iter = utils.DemoMergedIterator(replay_iter, demo_replay_iter) 142 | return self._replay_iter 143 | 144 | def eval(self): 145 | """We use train env for evaluation, because it's convenient""" 146 | step, episode, total_reward = 0, 0, 0 147 | eval_until_episode = utils.Until(self.cfg.num_eval_episodes) 148 | 149 | while eval_until_episode(episode): 150 | time_step = self.train_env.reset() 151 | self.video_recorder.init(self.train_env, enabled=(episode == 0)) 152 | while not time_step.last(): 153 | with torch.no_grad(), utils.eval_mode(self.agent): 154 | action = self.agent.act( 155 | time_step.rgb_obs, 156 | time_step.low_dim_obs, 157 | self.global_step, 158 | eval_mode=True, 159 | ) 160 | time_step = self.train_env.step(action) 161 | self.video_recorder.record(self.train_env) 162 | total_reward += time_step.reward 163 | step += 1 164 | 165 | episode += 1 166 | self.video_recorder.save(f"{self.global_frame}.mp4") 167 | 168 | with self.logger.log_and_dump_ctx(self.global_frame, ty="eval") as log: 169 | log("episode_reward", total_reward / episode) 170 | log("episode_length", step * self.cfg.action_repeat / episode) 171 | log("episode", self.global_episode) 172 | log("step", self.global_step) 173 | 174 | def train(self): 175 | # predicates 176 | train_until_step = utils.Until( 177 | self.cfg.num_train_frames, self.cfg.action_repeat 178 | ) 179 | seed_until_step = utils.Until(self.cfg.num_seed_frames, self.cfg.action_repeat) 180 | eval_every_step = utils.Every( 181 | self.cfg.eval_every_frames, self.cfg.action_repeat 182 | ) 183 | 184 | do_eval = False 185 | 186 | episode_step, episode_reward = 0, 0 187 | time_step = self.train_env.reset() 188 | self.replay_storage.add(time_step) 189 | self.demo_replay_storage.add(time_step) 190 | self.train_video_recorder.init(time_step.rgb_obs[0]) 191 | metrics = None 192 | while train_until_step(self.global_step): 193 | if time_step.last(): 194 | self._global_episode += 1 195 | self.train_video_recorder.save(f"{self.global_frame}.mp4") 196 | # wait until all the metrics schema is populated 197 | if metrics is not None: 198 | # log stats 199 | elapsed_time, total_time = self.timer.reset() 200 | episode_frame = episode_step * self.cfg.action_repeat 201 | with self.logger.log_and_dump_ctx( 202 | self.global_frame, ty="train" 203 | ) as log: 204 | log("fps", episode_frame / elapsed_time) 205 | log("total_time", total_time) 206 | log("episode_reward", episode_reward) 207 | log("episode_length", episode_frame) 208 | log("episode", self.global_episode) 209 | log("buffer_size", len(self.replay_storage)) 210 | log("demo_buffer_size", len(self.demo_replay_storage)) 211 | log("step", self.global_step) 212 | 213 | # do evaluation before resetting the 
environment 214 | if do_eval: 215 | self.logger.log( 216 | "eval_total_time", self.timer.total_time(), self.global_frame 217 | ) 218 | self.eval() 219 | do_eval = False 220 | 221 | # reset env 222 | time_step = self.train_env.reset() 223 | self.replay_storage.add(time_step) 224 | self.demo_replay_storage.add(time_step) 225 | self.train_video_recorder.init(time_step.rgb_obs[0]) 226 | # try to save snapshot 227 | if self.cfg.save_snapshot: 228 | self.save_snapshot() 229 | episode_step = 0 230 | episode_reward = 0 231 | 232 | # set a flag to initiate evaluation when the current episode terminates 233 | if self.global_step >= self.cfg.eval_every_frames and eval_every_step( 234 | self.global_step 235 | ): 236 | do_eval = True 237 | 238 | # sample action 239 | with torch.no_grad(), utils.eval_mode(self.agent): 240 | action = self.agent.act( 241 | time_step.rgb_obs, 242 | time_step.low_dim_obs, 243 | self.global_step, 244 | eval_mode=False, 245 | ) 246 | 247 | # try to update the agent 248 | if not seed_until_step(self.global_step): 249 | for _ in range(self.cfg.num_update_steps): 250 | metrics = self.agent.update(self.replay_iter, self.global_step) 251 | self.logger.log_metrics(metrics, self.global_frame, ty="train") 252 | 253 | # take env step 254 | time_step = self.train_env.step(action) 255 | episode_reward += time_step.reward 256 | self.replay_storage.add(time_step) 257 | self.demo_replay_storage.add(time_step) 258 | self.train_video_recorder.record(time_step.rgb_obs[0]) 259 | episode_step += 1 260 | self._global_step += 1 261 | 262 | def load_rlbench_demos(self): 263 | if self.cfg.num_demos > 0: 264 | demos = self.train_env.get_demos(self.cfg.num_demos) 265 | for demo in demos: 266 | for time_step in demo: 267 | self.replay_storage.add(time_step) 268 | self.demo_replay_storage.add(time_step) 269 | else: 270 | logging.warning("Not using demonstrations") 271 | 272 | def save_snapshot(self): 273 | snapshot = self.work_dir / "snapshot.pt" 274 | keys_to_save = ["agent", "timer", "_global_step", "_global_episode"] 275 | payload = {k: self.__dict__[k] for k in keys_to_save} 276 | with snapshot.open("wb") as f: 277 | torch.save(payload, f) 278 | 279 | def load_snapshot(self): 280 | snapshot = self.work_dir / "snapshot.pt" 281 | with snapshot.open("rb") as f: 282 | payload = torch.load(f) 283 | for k, v in payload.items(): 284 | self.__dict__[k] = v 285 | 286 | 287 | @hydra.main(config_path="cfgs", config_name="config_rlbench_drqv2plus") 288 | def main(cfg): 289 | from train_rlbench_drqv2plus import Workspace as W 290 | 291 | root_dir = Path.cwd() 292 | workspace = W(cfg) 293 | snapshot = root_dir / "snapshot.pt" 294 | if snapshot.exists(): 295 | print(f"resuming: {snapshot}") 296 | workspace.load_snapshot() 297 | workspace.load_rlbench_demos() 298 | workspace.train() 299 | 300 | 301 | if __name__ == "__main__": 302 | main() 303 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import re 3 | import time 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | from torch import distributions as pyd 9 | from torch.distributions.utils import _standard_normal 10 | 11 | 12 | class eval_mode: 13 | def __init__(self, *models): 14 | self.models = models 15 | 16 | def __enter__(self): 17 | self.prev_states = [] 18 | for model in self.models: 19 | self.prev_states.append(model.training) 20 | model.train(False) 21 | 22 | def __exit__(self, *args): 23
| for model, state in zip(self.models, self.prev_states): 24 | model.train(state) 25 | return False 26 | 27 | 28 | def set_seed_everywhere(seed): 29 | torch.manual_seed(seed) 30 | if torch.cuda.is_available(): 31 | torch.cuda.manual_seed_all(seed) 32 | np.random.seed(seed) 33 | random.seed(seed) 34 | 35 | 36 | def soft_update_params(net, target_net, tau): 37 | for param, target_param in zip(net.parameters(), target_net.parameters()): 38 | target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data) 39 | 40 | 41 | def to_torch(xs, device): 42 | return tuple(torch.as_tensor(x, device=device) for x in xs) 43 | 44 | 45 | def weight_init(m): 46 | if isinstance(m, nn.Linear): 47 | nn.init.orthogonal_(m.weight.data) 48 | if hasattr(m.bias, "data"): 49 | m.bias.data.fill_(0.0) 50 | elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 51 | gain = nn.init.calculate_gain("relu") 52 | nn.init.orthogonal_(m.weight.data, gain) 53 | if hasattr(m.bias, "data"): 54 | m.bias.data.fill_(0.0) 55 | elif isinstance(m, nn.LayerNorm): 56 | m.weight.data.fill_(1.0) 57 | if hasattr(m.bias, "data"): 58 | m.bias.data.fill_(0.0) 59 | 60 | 61 | class Until: 62 | def __init__(self, until, action_repeat=1): 63 | self._until = until 64 | self._action_repeat = action_repeat 65 | 66 | def __call__(self, step): 67 | if self._until is None: 68 | return True 69 | until = self._until // self._action_repeat 70 | return step < until 71 | 72 | 73 | class Every: 74 | def __init__(self, every, action_repeat=1): 75 | self._every = every 76 | self._action_repeat = action_repeat 77 | 78 | def __call__(self, step): 79 | if self._every is None: 80 | return False 81 | every = self._every // self._action_repeat 82 | if step % every == 0: 83 | return True 84 | return False 85 | 86 | 87 | class Timer: 88 | def __init__(self): 89 | self._start_time = time.time() 90 | self._last_time = time.time() 91 | 92 | def reset(self): 93 | elapsed_time = time.time() - self._last_time 94 | self._last_time = time.time() 95 | total_time = time.time() - self._start_time 96 | return elapsed_time, total_time 97 | 98 | def total_time(self): 99 | return time.time() - self._start_time 100 | 101 | 102 | class TruncatedNormal(pyd.Normal): 103 | def __init__(self, loc, scale, low=-1.0, high=1.0, eps=1e-6): 104 | super().__init__(loc, scale, validate_args=False) 105 | self.low = low 106 | self.high = high 107 | self.eps = eps 108 | 109 | def _clamp(self, x): 110 | clamped_x = torch.clamp(x, self.low + self.eps, self.high - self.eps) 111 | x = x - x.detach() + clamped_x.detach() 112 | return x 113 | 114 | def sample(self, clip=None, sample_shape=torch.Size()): 115 | shape = self._extended_shape(sample_shape) 116 | eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device) 117 | eps *= self.scale 118 | if clip is not None: 119 | eps = torch.clamp(eps, -clip, clip) 120 | x = self.loc + eps 121 | return self._clamp(x) 122 | 123 | 124 | def schedule(schdl, step): 125 | try: 126 | return float(schdl) 127 | except ValueError: 128 | match = re.match(r"linear\((.+),(.+),(.+)\)", schdl) 129 | if match: 130 | init, final, duration = [float(g) for g in match.groups()] 131 | mix = np.clip(step / duration, 0.0, 1.0) 132 | return (1.0 - mix) * init + mix * final 133 | match = re.match(r"step_linear\((.+),(.+),(.+),(.+),(.+)\)", schdl) 134 | if match: 135 | init, final1, duration1, final2, duration2 = [ 136 | float(g) for g in match.groups() 137 | ] 138 | if step <= duration1: 139 | mix = np.clip(step / duration1, 0.0, 1.0) 140 | return 
(1.0 - mix) * init + mix * final1 141 | else: 142 | mix = np.clip((step - duration1) / duration2, 0.0, 1.0) 143 | return (1.0 - mix) * final1 + mix * final2 144 | raise NotImplementedError(schdl) 145 | 146 | 147 | class DemoMergedIterator: 148 | def __init__(self, replay_iter, demo_replay_iter): 149 | self.replay_iter = replay_iter 150 | self.demo_replay_iter = demo_replay_iter 151 | 152 | def __iter__(self): 153 | return self 154 | 155 | def __next__(self): 156 | items1 = next(self.replay_iter) 157 | items2 = next(self.demo_replay_iter) 158 | 159 | new_items = [] 160 | for i in range(len(items1)): 161 | new_items.append(np.concatenate([items1[i], items2[i]], 0)) 162 | return tuple(new_items) 163 | 164 | 165 | """ 166 | For distributional critic: https://arxiv.org/pdf/1707.06887.pdf 167 | """ 168 | 169 | 170 | def signed_hyperbolic(x: torch.Tensor, eps: float = 1e-3) -> torch.Tensor: 171 | """Signed hyperbolic transform, inverse of signed_parabolic.""" 172 | return torch.sign(x) * (torch.sqrt(torch.abs(x) + 1) - 1) + eps * x 173 | 174 | 175 | def signed_parabolic(x: torch.Tensor, eps: float = 1e-3) -> torch.Tensor: 176 | """Signed parabolic transform, inverse of signed_hyperbolic.""" 177 | z = torch.sqrt(1 + 4 * eps * (eps + 1 + torch.abs(x))) / 2 / eps - 1 / 2 / eps 178 | return torch.sign(x) * (torch.square(z) - 1) 179 | 180 | 181 | def from_categorical( 182 | distribution, limit=20, offset=0.0, logits=True, transformation=True 183 | ): 184 | distribution = distribution.float().squeeze(-1) # Avoid any fp16 shenanigans 185 | if logits: 186 | distribution = torch.softmax(distribution, -1) 187 | num_atoms = distribution.shape[-1] 188 | shift = limit * 2 / (num_atoms - 1) 189 | weights = ( 190 | torch.linspace( 191 | -(num_atoms // 2), num_atoms // 2, num_atoms, device=distribution.device 192 | ) 193 | .float() 194 | .unsqueeze(-1) 195 | ) 196 | if transformation: 197 | out = signed_parabolic((distribution @ weights) * shift) - offset 198 | else: 199 | out = (distribution @ weights) * shift - offset 200 | return out 201 | 202 | 203 | def to_categorical(value, limit=20, offset=0.0, num_atoms=251, transformation=True): 204 | value = value.float() + offset # Avoid any fp16 shenanigans 205 | shift = limit * 2 / (num_atoms - 1) 206 | if transformation: 207 | value = signed_hyperbolic(value) / shift 208 | else: 209 | value = value / shift 210 | value = value.clamp(-(num_atoms // 2), num_atoms // 2) 211 | distribution = torch.zeros(value.shape[0], num_atoms, 1, device=value.device) 212 | lower = value.floor().long() + num_atoms // 2 213 | upper = value.ceil().long() + num_atoms // 2 214 | upper_weight = value % 1 215 | lower_weight = 1 - upper_weight 216 | distribution.scatter_add_(-2, lower.unsqueeze(-1), lower_weight.unsqueeze(-1)) 217 | distribution.scatter_add_(-2, upper.unsqueeze(-1), upper_weight.unsqueeze(-1)) 218 | return distribution 219 | -------------------------------------------------------------------------------- /video.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import imageio 3 | 4 | 5 | class VideoRecorder: 6 | def __init__(self, root_dir, render_size=256, fps=20): 7 | if root_dir is not None: 8 | self.save_dir = root_dir / "eval_video" 9 | self.save_dir.mkdir(exist_ok=True) 10 | else: 11 | self.save_dir = None 12 | 13 | self.render_size = render_size 14 | self.fps = fps 15 | self.frames = [] 16 | 17 | def init(self, env, enabled=True): 18 | self.frames = [] 19 | self.enabled = self.save_dir is not None and enabled 20 | 
self.record(env) 21 | 22 | def record(self, env): 23 | if self.enabled: 24 | if hasattr(env, "physics"): 25 | frame = env.physics.render( 26 | height=self.render_size, width=self.render_size, camera_id=0 27 | ) 28 | else: 29 | frame = env.render() 30 | self.frames.append(frame) 31 | 32 | def save(self, file_name): 33 | if self.enabled: 34 | path = self.save_dir / file_name 35 | imageio.mimsave(str(path), self.frames, fps=self.fps) 36 | 37 | 38 | class TrainVideoRecorder: 39 | def __init__(self, root_dir, render_size=256, fps=20): 40 | if root_dir is not None: 41 | self.save_dir = root_dir / "train_video" 42 | self.save_dir.mkdir(exist_ok=True) 43 | else: 44 | self.save_dir = None 45 | 46 | self.render_size = render_size 47 | self.fps = fps 48 | self.frames = [] 49 | 50 | def init(self, obs, enabled=True): 51 | self.frames = [] 52 | self.enabled = self.save_dir is not None and enabled 53 | self.record(obs) 54 | 55 | def record(self, obs): 56 | if self.enabled: 57 | frame = cv2.resize( 58 | obs[-3:].transpose(1, 2, 0), 59 | dsize=(self.render_size, self.render_size), 60 | interpolation=cv2.INTER_CUBIC, 61 | ) 62 | self.frames.append(frame) 63 | 64 | def save(self, file_name): 65 | if self.enabled: 66 | path = self.save_dir / file_name 67 | imageio.mimsave(str(path), self.frames, fps=self.fps) 68 | --------------------------------------------------------------------------------
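
The sketch below is not a file from the repository; it is a minimal standalone illustration of the action rescaling implemented by `RLBench._convert_action_to_raw` and `_convert_action_from_raw` in `rlbench_env.py`: the per-dimension action bounds (taken from demonstrations, or from the hard-coded defaults) are widened by 20% of their magnitude, and actions are then mapped linearly between that raw range and [-1, 1]. The action stats used here are made-up numbers for a 3-dimensional action.

```
# Minimal sketch of the [-1, 1] <-> raw action mapping from rlbench_env.py.
# The stats below are hypothetical; in the repository they come from demonstrations
# or from the defaults set in _construct_action_and_observation_spaces.
import numpy as np

stats = {
    "min": np.array([-0.1, -0.2, 0.0], dtype=np.float32),  # last dim = gripper open in [0, 1]
    "max": np.array([0.1, 0.2, 1.0], dtype=np.float32),
}

def to_raw(action, stats):
    lo = stats["min"] - np.fabs(stats["min"]) * 0.2  # widen bounds by 20%
    hi = stats["max"] + np.fabs(stats["max"]) * 0.2
    return (action + 1) / 2.0 * (hi - lo) + lo       # [-1, 1] -> [lo, hi]

def from_raw(raw, stats):
    lo = stats["min"] - np.fabs(stats["min"]) * 0.2
    hi = stats["max"] + np.fabs(stats["max"]) * 0.2
    return (raw - lo) / (hi - lo) * 2 - 1            # [lo, hi] -> [-1, 1]

a = np.array([0.5, -1.0, 1.0], dtype=np.float32)
assert np.allclose(from_raw(to_raw(a, stats), stats), a, atol=1e-5)  # round-trip holds
```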
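
Similarly, the following snippet is only an illustrative round-trip through the distributional-critic helpers at the bottom of `utils.py` (`to_categorical` / `from_categorical`, which follow the C51 paper referenced there). It assumes it is run from the repository root so that `utils.py` is importable, and it passes values with shape `(batch, 1)`, which is the layout the scatter calls expect; the numbers themselves are arbitrary.

```
# Round-trip through utils.to_categorical / utils.from_categorical (illustration only).
# Assumes the repository's utils.py is on the Python path.
import torch

import utils

values = torch.tensor([[-5.0], [0.0], [1.5], [12.0]])             # shape (batch, 1)
dist = utils.to_categorical(values, limit=20, num_atoms=251)      # (batch, 251, 1) probabilities
recovered = utils.from_categorical(dist, limit=20, logits=False)  # (batch, 1)
# values come back up to small numerical error from the signed hyperbolic transform
print(torch.allclose(recovered, values, atol=1e-2))               # expected: True
```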
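
Lastly, `utils.schedule` accepts either a plain float or a small spec string such as `linear(init,final,duration)` (plus a two-phase `step_linear(...)` variant). The calls below are hypothetical examples, again assuming `utils.py` is importable.

```
# Hypothetical calls to utils.schedule; spec strings and steps are examples only.
import utils

print(utils.schedule("0.1", step=0))                   # plain float -> 0.1
print(utils.schedule("linear(1.0,0.1,100)", step=50))  # halfway through -> 0.55
print(utils.schedule("linear(1.0,0.1,100)", step=200)) # past duration -> 0.1
```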