├── agent ├── env.png ├── bc.yaml ├── crr.yaml ├── td3.yaml ├── td3_bc.yaml ├── ddpg.yaml ├── cql.yaml ├── cql_cdsz.yaml ├── cql_cds.yaml ├── pbrl.yaml ├── bc.py ├── td3_bc.py ├── td3.py ├── crr.py ├── ddpg.py ├── pbrl.py ├── cql_cdsz.py ├── cql_cds.py └── cql.py ├── custom_dmc_tasks ├── common │ ├── visual.xml │ ├── skybox.xml │ └── materials.xml ├── test.py ├── __init__.py ├── point_mass_maze_reach_bottom_left.xml ├── point_mass_maze_reach_top_left.xml ├── point_mass_maze_reach_top_right.xml ├── point_mass_maze_reach_bottom_right.xml ├── hopper.xml ├── walker.xml ├── test.xml ├── cheetah.xml ├── cheetah.py ├── point_mass_maze.py ├── hopper.py ├── jaco.py └── walker.py ├── conda_env.yml ├── LICENSE ├── task.json ├── config_cds.yaml ├── config_single.yaml ├── config.yaml ├── collect_data.yaml ├── video.py ├── train_offline_single.py ├── README.md ├── train_offline_share.py ├── replay_buffer.py ├── train_offline_cds.py ├── logger.py ├── replay_buffer_collect.py └── collect_data.py /agent/env.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Baichenjia/UTDS/HEAD/agent/env.png -------------------------------------------------------------------------------- /custom_dmc_tasks/common/visual.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /custom_dmc_tasks/common/skybox.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | 6 | 7 | -------------------------------------------------------------------------------- /agent/bc.yaml: -------------------------------------------------------------------------------- 1 | # @package agent 2 | _target_: agent.bc.BCAgent 3 | name: bc 4 | obs_shape: ??? # to be specified later 5 | action_shape: ??? 
# to be specified later 6 | device: ${device} 7 | lr: 1e-4 8 | use_tb: ${use_tb} 9 | hidden_dim: 1024 10 | batch_size: 1024 # 256 for pixels 11 | has_next_action: False 12 | stddev_schedule: 0.2 -------------------------------------------------------------------------------- /agent/crr.yaml: -------------------------------------------------------------------------------- 1 | # @package agent 2 | _target_: agent.crr.CRRAgent 3 | name: crr 4 | obs_shape: ??? # to be specified later 5 | action_shape: ??? # to be specified later 6 | device: ${device} 7 | lr: 1e-4 8 | critic_target_tau: 0.01 9 | use_tb: ${use_tb} 10 | hidden_dim: 1024 11 | stddev_schedule: 0.2 12 | stddev_clip: 0.3 13 | nstep: 1 14 | batch_size: 1024 # 256 for pixels 15 | num_value_samples: 10 16 | weight_func: indicator 17 | has_next_action: False -------------------------------------------------------------------------------- /agent/td3.yaml: -------------------------------------------------------------------------------- 1 | # @package agent 2 | _target_: agent.td3.TD3Agent 3 | name: td3 4 | #obs_type: ??? # to be specified later 5 | obs_shape: ??? # to be specified later 6 | action_shape: ??? # to be specified later 7 | device: ${device} 8 | lr: 1e-4 9 | critic_target_tau: 0.01 10 | use_tb: ${use_tb} 11 | hidden_dim: 1024 12 | stddev_schedule: 0.2 13 | stddev_clip: 0.3 14 | nstep: 1 15 | batch_size: 1024 # 256 for pixels 16 | has_next_action: False 17 | num_expl_steps: ??? # to be specified later 18 | -------------------------------------------------------------------------------- /agent/td3_bc.yaml: -------------------------------------------------------------------------------- 1 | # @package agent 2 | _target_: agent.td3_bc.TD3BCAgent 3 | name: td3_bc 4 | obs_shape: ??? # to be specified later 5 | action_shape: ??? 
# to be specified later 6 | device: ${device} 7 | alpha: 2.5 8 | lr: 1e-4 9 | critic_target_tau: 0.005 10 | actor_target_tau: 0.005 11 | policy_freq: 2 12 | use_tb: ${use_tb} 13 | hidden_dim: 1024 14 | #stddev_schedule: 0.2 15 | #stddev_clip: 0.3 16 | policy_noise: 0.2 17 | noise_clip: 0.5 18 | #nstep: 1 19 | batch_size: 1024 20 | #has_next_action: False 21 | num_expl_steps: ??? # to be specified later 22 | -------------------------------------------------------------------------------- /agent/ddpg.yaml: -------------------------------------------------------------------------------- 1 | # @package agent 2 | _target_: agent.ddpg.DDPGAgent 3 | name: ddpg 4 | #reward_free: ${reward_free} 5 | obs_type: ??? # to be specified later 6 | obs_shape: ??? # to be specified later 7 | action_shape: ??? # to be specified later 8 | device: ${device} 9 | lr: 1e-4 10 | critic_target_tau: 0.01 11 | update_every_steps: 2 12 | use_tb: ${use_tb} 13 | use_wandb: ${use_wandb} 14 | num_expl_steps: ??? # to be specified later 15 | hidden_dim: 1024 16 | feature_dim: 50 17 | stddev_schedule: 0.2 18 | stddev_clip: 0.3 19 | nstep: 3 20 | batch_size: 1024 # 256 for pixels 21 | init_critic: true 22 | update_encoder: ${update_encoder} 23 | -------------------------------------------------------------------------------- /agent/cql.yaml: -------------------------------------------------------------------------------- 1 | # @package agent 2 | _target_: agent.cql.CQLAgent 3 | name: cql 4 | obs_shape: ??? # to be specified later 5 | action_shape: ??? 
# to be specified later 6 | device: ${device} 7 | #lr: 1e-4 8 | actor_lr: 1e-4 9 | critic_lr: 3e-4 10 | critic_target_tau: 0.01 11 | 12 | n_samples: 3 13 | use_critic_lagrange: False 14 | alpha: 50 # used if use_critic_lagrange is False 15 | target_cql_penalty: 5.0 # used if use_critic_lagrange is True 16 | 17 | use_tb: True 18 | hidden_dim: 256 19 | #stddev_schedule: 0.2 20 | #stddev_clip: 0.3 21 | nstep: 1 22 | batch_size: 1024 23 | has_next_action: False 24 | 25 | num_expl_steps: ??? # to be specified later -------------------------------------------------------------------------------- /agent/cql_cdsz.yaml: -------------------------------------------------------------------------------- 1 | # @package agent 2 | _target_: agent.cql_cdsz.CQLCDSZeroAgent 3 | name: cql_cdsz 4 | obs_shape: ??? # to be specified later 5 | action_shape: ??? # to be specified later 6 | device: ${device} 7 | #lr: 1e-4 8 | actor_lr: 1e-4 9 | critic_lr: 3e-4 10 | critic_target_tau: 0.01 11 | 12 | n_samples: 3 13 | use_critic_lagrange: False 14 | alpha: 50 # used if use_critic_lagrange is False 15 | target_cql_penalty: 5.0 # used if use_critic_lagrange is True 16 | 17 | use_tb: True 18 | hidden_dim: 256 19 | #stddev_schedule: 0.2 20 | #stddev_clip: 0.3 21 | nstep: 1 22 | batch_size: 1024 23 | has_next_action: False 24 | 25 | num_expl_steps: ??? # to be specified later -------------------------------------------------------------------------------- /agent/cql_cds.yaml: -------------------------------------------------------------------------------- 1 | # @package agent 2 | _target_: agent.cql_cds.CQLCDSAgent 3 | name: cql_cds 4 | obs_shape: ??? # to be specified later 5 | action_shape: ??? 
# to be specified later 6 | device: ${device} 7 | #lr: 1e-4 8 | actor_lr: 1e-4 9 | critic_lr: 3e-4 10 | critic_target_tau: 0.01 11 | 12 | n_samples: 3 13 | use_critic_lagrange: False 14 | alpha: 50 # used if use_critic_lagrange is False 15 | target_cql_penalty: 5.0 # used if use_critic_lagrange is True 16 | 17 | use_tb: True 18 | hidden_dim: 256 # 1024 19 | #stddev_schedule: 0.2 20 | #stddev_clip: 0.3 21 | nstep: 1 22 | batch_size: 1024 # 1024 23 | has_next_action: False 24 | 25 | num_expl_steps: ??? # to be specified later -------------------------------------------------------------------------------- /conda_env.yml: -------------------------------------------------------------------------------- 1 | name: utds 2 | channels: 3 | - defaults 4 | dependencies: 5 | - python=3.8 6 | - pip=21.1.3 7 | - numpy=1.19.2 8 | - absl-py=0.13.0 9 | - pyparsing=2.4.7 10 | - jupyterlab=3.0.14 11 | - scikit-image=0.18.1 12 | - nvidia::cudatoolkit=11.1 13 | - pytorch::pytorch 14 | - pytorch::torchvision 15 | - pytorch::torchaudio 16 | - pip: 17 | - termcolor==1.1.0 18 | - git+https://github.com/deepmind/dm_control.git 19 | - tb-nightly 20 | - imageio==2.9.0 21 | - imageio-ffmpeg==0.4.4 22 | - hydra-core==1.1.0 23 | - hydra-submitit-launcher==1.1.5 24 | - pandas==1.3.0 25 | - ipdb==0.13.9 26 | - yapf==0.31.0 27 | - scikit-learn 28 | - matplotlib==3.4.2 29 | - opencv-python==4.5.3.56 30 | - moviepy==1.0.3 31 | -------------------------------------------------------------------------------- /agent/pbrl.yaml: -------------------------------------------------------------------------------- 1 | # @package agent 2 | _target_: agent.pbrl.PBRLAgent 3 | name: pbrl 4 | obs_shape: ??? # to be specified later 5 | action_shape: ???
# to be specified later 6 | device: ${device} 7 | #alpha: 2.5 8 | lr: 1e-4 9 | critic_target_tau: 0.005 10 | actor_target_tau: 0.005 11 | policy_freq: 2 12 | use_tb: True 13 | hidden_dim: 1024 14 | #stddev_schedule: 0.2 15 | #stddev_clip: 0.3 16 | policy_noise: 0.2 17 | noise_clip: 0.5 18 | #nstep: 1 19 | batch_size: 1024 20 | #has_next_action: False 21 | num_expl_steps: ??? # to be specified later 22 | 23 | # PBRL 24 | num_random: 3 25 | ucb_ratio_in: 0.001 26 | ensemble: 5 27 | ood_noise: 0.01 # action noise for sampling 28 | 29 | ucb_ratio_ood_init: 3.0 # 3.0 30 | ucb_ratio_ood_min: 0.1 # 0.1 31 | ood_decay_factor: 0.99995 # 0.99995 ucb ratio decay factor. 32 | 33 | share_ratio: 1.5 # penalty ratio for shared dataset 34 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /task.json: -------------------------------------------------------------------------------- 1 | { 2 | "walker": 3 | { 4 | "stand": {"expert": 986.8, "medium": 458.85}, 5 | "walk": {"expert": 980.0, "medium": 681.4}, 6 | "run": {"expert": 842.6, "medium": 411.0}, 7 | "flip": {"expert": 839.5, "medium": 583.5} 8 | }, 9 | "hopper": 10 | { 11 | "stand": 1000, 12 | "hop": 1000, 13 | "hop_backward": 1000, 14 | "flip": 1000, 15 | "flip_backward": 1000 16 | }, 17 | "cheetah": 18 | { 19 | "run": 1000, 20 | "run_backward": 1000, 21 | "flip": 1000 22 | }, 23 | "quadruped": 24 | { 25 | "jump": 1000, 26 | "run": 1000, 27 | "roll": 1000, 28 | "roll_fast": 1000, 29 | "stand": 1000, 30 | "walk": 1000, 31 | "escape": 1000, 32 | "fetch": 1000 33 | }, 34 | "jaco": 35 | { 36 | "reach_top_left": 1000, 37 | "reach_top_right": 1000, 38 | "reach_bottom_left": 1000, 39 | "reach_bottom_right": 1000 40 | }, 41 | "point_mass_maze": 42 | { 43 | "reach_top_left": 1000, 44 | "reach_top_right": 1000, 45 | "reach_bottom_left": 1000, 46 | "reach_bottom_right": 1000 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /custom_dmc_tasks/common/materials.xml: -------------------------------------------------------------------------------- 1 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | -------------------------------------------------------------------------------- /config_cds.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - agent: cql_cds # cql_cds, cql_cdsz 3 | - override hydra/launcher:
submitit_local 4 | 5 | # unsupervised exploration 6 | # expl_agent: td3 7 | # task settings 8 | task: walker_walk # main task to train (relabel other datasets to this task) 9 | share_task: [walker_walk, walker_run] # task for data sharing 10 | data_type: [medium, medium-replay] # dataset for data sharing (corresponding to each share_task) 11 | 12 | discount: 0.99 13 | # train settings 14 | num_grad_steps: 1000000 15 | log_every_steps: 1000 16 | # eval 17 | eval_every_steps: 10000 18 | num_eval_episodes: 10 19 | # dataset 20 | replay_buffer_dir: ../../collect # make sure to update this if you change hydra run dir 21 | replay_buffer_size: 10000000 # max: 10M 22 | replay_buffer_num_workers: 4 23 | batch_size: ${agent.batch_size} 24 | # misc 25 | seed: 1 26 | device: cuda 27 | save_video: False 28 | use_tb: False 29 | 30 | # used for train_offline_single 31 | data_main: expert 32 | 33 | wandb: False 34 | hydra: 35 | run: 36 | dir: ./result_cds/${task}-Share_${share_task[0]}_${share_task[1]}-${data_type[0]}-${agent.name}-${now:%m-%d-%H-%M-%S} 37 | -------------------------------------------------------------------------------- /config_single.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - agent: pbrl # td3_bc, cql, pbrl 3 | - override hydra/launcher: submitit_local 4 | 5 | # unsupervised exploration 6 | # expl_agent: td3 7 | # task settings 8 | task: walker_walk # main task to train (relabel other datasets to this task) 9 | share_task: [walker_walk, walker_run] # task for data sharing 10 | data_type: [medium, medium-replay] # dataset for data sharing (corresponding to each share_task) 11 | 12 | discount: 0.99 13 | # train settings 14 | num_grad_steps: 1000000 15 | log_every_steps: 1000 16 | # eval 17 | eval_every_steps: 10000 18 | num_eval_episodes: 10 19 | # dataset 20 | replay_buffer_dir: ../../collect # make sure to update this if you change hydra run dir 21 | replay_buffer_size: 10000000 # max: 10M 22 |
replay_buffer_num_workers: 4 23 | batch_size: ${agent.batch_size} 24 | # misc 25 | seed: 1 26 | device: cuda 27 | save_video: False 28 | use_tb: False 29 | 30 | # used for train_offline_single 31 | data_main: expert 32 | 33 | wandb: False 34 | hydra: 35 | run: 36 | # dir: ./result_cql/${task}-${data_main}-${agent.name}-${now:%m-%d-%H-%M-%S} 37 | dir: ./result_pbrl/${task}-${data_main}-${agent.name}-${now:%m-%d-%H-%M-%S} 38 | # dir: ./output/${task}-${data_main}-${agent.name}-${now:%m-%d-%H-%M-%S} 39 | -------------------------------------------------------------------------------- /custom_dmc_tasks/test.py: -------------------------------------------------------------------------------- 1 | from dm_control import manipulation, suite 2 | from dm_control.suite.wrappers import action_scale, pixels 3 | from dm_env import StepType, specs 4 | import numpy as np 5 | # Sanity check: step walker/walk once with a random action, then restore the saved physics state in a fresh walker/walk env and recompute the reward with task.get_reward so the two rewards can be compared. 6 | env = suite.load("walker", "walk") 7 | env.reset() 8 | action = np.random.random(6) 9 | time_step = env.step(action) 10 | for k, v in time_step.observation.items(): 11 | print("observation:", k, v.shape) 12 | print("original reward:", time_step.reward) 13 | 14 | state = env.physics.get_state() 15 | print("physics:", state.shape) 16 | print("-----") 17 | ########## 18 | 19 | new_env = suite.load("walker", "walk") 20 | reward_spec = new_env.reward_spec() 21 | new_env.reset() 22 | with new_env.physics.reset_context(): 23 | new_env.physics.set_state(state) 24 | new_reward = new_env.task.get_reward(new_env.physics) # the input is the env's current physics state; env.task.get_reward computes the reward from it 25 | new_reward = np.full(reward_spec.shape, new_reward, reward_spec.dtype) 26 | print("new reward:", new_reward) 27 | 28 | # states = episode['physics'] 29 | # for i in range(states.shape[0]): 30 | # with env.physics.reset_context(): 31 | # env.physics.set_state(states[i]) 32 | # reward = env.task.get_reward(env.physics) # the input is the env's current physics state; env.task.get_reward computes the reward from it 33 | # reward = np.full(reward_spec.shape, reward, reward_spec.dtype) #
change shape and dtype to match reward_spec 34 | # rewards.append(reward) 35 | # episode['reward'] = np.array(rewards, dtype=reward_spec.dtype) 36 | # return episode 37 | -------------------------------------------------------------------------------- /config.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - agent: pbrl # td3_bc, cql, pbrl 3 | - override hydra/launcher: submitit_local 4 | 5 | # unsupervised exploration 6 | # expl_agent: td3 7 | # task settings 8 | task: walker_walk # main task to train (relabel other datasets to this task) 9 | share_task: [walker_walk, walker_run] # task for data sharing 10 | data_type: [medium, medium-replay] # dataset for data sharing (corresponding to each share_task) 11 | 12 | discount: 0.99 13 | # train settings 14 | num_grad_steps: 1000000 15 | log_every_steps: 1000 16 | # eval 17 | eval_every_steps: 10000 18 | num_eval_episodes: 10 19 | # dataset 20 | replay_buffer_dir: ../../collect # make sure to update this if you change hydra run dir 21 | replay_buffer_size: 10000000 # max: 10M 22 | replay_buffer_num_workers: 4 23 | batch_size: ${agent.batch_size} 24 | # misc 25 | seed: 1 26 | device: cuda 27 | save_video: False 28 | use_tb: False 29 | 30 | # used for train_offline_single 31 | data_main: expert 32 | 33 | wandb: False 34 | hydra: 35 | run: 36 | # dir: ./result_td3bc_share/${task}-Share_${share_task[0]}_${share_task[1]}-Data_${data_type[0]}_${data_type[1]}-${agent.name}-${now:%m-%d-%H-%M-%S} 37 | # dir: ./result_cql/${task}-${data_main}-${agent.name}-${now:%m-%d-%H-%M-%S} 38 | # dir: ./result_pbrl/${task}-${data_main}-${agent.name}-${now:%m-%d-%H-%M-%S} 39 | dir: ./result_pbrl_share/${task}-Share_${share_task[0]}_${share_task[1]}-${data_type[0]}-${agent.name}-${now:%m-%d-%H-%M-%S} 40 | # dir: ./output/${task}-${data_main}-${agent.name}-${now:%m-%d-%H-%M-%S} 41 | 42 | -------------------------------------------------------------------------------- /collect_data.yaml:
-------------------------------------------------------------------------------- 1 | defaults: 2 | - agent: td3 3 | - override hydra/launcher: submitit_local 4 | 5 | # mode 6 | # reward_free: false 7 | # task settings 8 | task: walker_run 9 | #obs_type: states # [states, pixels] 10 | #frame_stack: 3 # only works if obs_type=pixels 11 | action_repeat: 1 # set to 2 for pixels 12 | discount: 0.99 13 | # train settings 14 | num_train_frames: 2000010 # 2M steps to converge 15 | num_seed_frames: 4000 16 | # eval 17 | eval_every_frames: 10000 18 | num_eval_episodes: 10 19 | # pretrained 20 | #snapshot_ts: 100000 21 | #snapshot_base_dir: ./pretrained_models 22 | # replay buffer 23 | replay_buffer_size: 1000000 24 | replay_buffer_num_workers: 4 25 | batch_size: ${agent.batch_size} 26 | nstep: ${agent.nstep} 27 | update_encoder: false # can be either true or false depending if we want to fine-tune encoder 28 | # misc 29 | seed: 1 30 | device: cuda 31 | save_video: true 32 | save_train_video: false 33 | use_tb: false 34 | use_wandb: false 35 | # experiment 36 | experiment: exp 37 | 38 | 39 | hydra: 40 | run: 41 | dir: ./collect/${task}-${agent.name}-${now:%m-%d-%H-%-M%S} 42 | # dir: ./collect/${task}-${agent.name}-medium-replay # for collect_data_fixed, medium-replay 43 | # dir: ./collect/${task}-${agent.name}-expert # for collect_data_fixed, expert 44 | # dir: ./video/${task}-${agent.name}-video 45 | sweep: 46 | dir: ./exp_sweep/${now:%Y.%m.%d}/${now:%H%M}_${agent.name}_${experiment} 47 | subdir: ${hydra.job.num} 48 | launcher: 49 | timeout_min: 4300 50 | cpus_per_task: 10 51 | gpus_per_node: 1 52 | tasks_per_node: 1 53 | mem_gb: 160 54 | nodes: 1 55 | submitit_folder: ./exp_sweep/${now:%Y.%m.%d}/${now:%H%M}_${agent.name}_${experiment}/.slurm 56 | -------------------------------------------------------------------------------- /custom_dmc_tasks/__init__.py: -------------------------------------------------------------------------------- 1 | from custom_dmc_tasks import 
cheetah 2 | from custom_dmc_tasks import walker 3 | from custom_dmc_tasks import hopper 4 | from custom_dmc_tasks import quadruped 5 | from custom_dmc_tasks import jaco 6 | from custom_dmc_tasks import point_mass_maze 7 | 8 | 9 | def make(domain, task, 10 | task_kwargs=None, 11 | environment_kwargs=None, 12 | visualize_reward=False): 13 | 14 | if domain == 'cheetah': 15 | return cheetah.make(task, 16 | task_kwargs=task_kwargs, 17 | environment_kwargs=environment_kwargs, 18 | visualize_reward=visualize_reward) 19 | elif domain == 'walker': 20 | return walker.make(task, 21 | task_kwargs=task_kwargs, 22 | environment_kwargs=environment_kwargs, 23 | visualize_reward=visualize_reward) 24 | elif domain == 'point_mass_maze': 25 | return point_mass_maze.make(task, 26 | task_kwargs=task_kwargs, 27 | environment_kwargs=environment_kwargs, 28 | visualize_reward=visualize_reward) 29 | elif domain == 'hopper': 30 | return hopper.make(task, 31 | task_kwargs=task_kwargs, 32 | environment_kwargs=environment_kwargs, 33 | visualize_reward=visualize_reward) 34 | elif domain == 'quadruped': 35 | return quadruped.make(task, 36 | task_kwargs=task_kwargs, 37 | environment_kwargs=environment_kwargs, 38 | visualize_reward=visualize_reward) 39 | else: 40 | raise f'{task} not found' 41 | 42 | assert None 43 | 44 | 45 | def make_jaco(task, obs_type, seed): 46 | return jaco.make(task, obs_type, seed) -------------------------------------------------------------------------------- /custom_dmc_tasks/point_mass_maze_reach_bottom_left.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | -------------------------------------------------------------------------------- 
/custom_dmc_tasks/point_mass_maze_reach_top_left.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | -------------------------------------------------------------------------------- /custom_dmc_tasks/point_mass_maze_reach_top_right.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | -------------------------------------------------------------------------------- /custom_dmc_tasks/point_mass_maze_reach_bottom_right.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | -------------------------------------------------------------------------------- /custom_dmc_tasks/hopper.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | -------------------------------------------------------------------------------- /custom_dmc_tasks/walker.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 
-------------------------------------------------------------------------------- /video.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import imageio 3 | import numpy as np 4 | import wandb 5 | 6 | 7 | class VideoRecorder: 8 | def __init__(self, 9 | root_dir, 10 | render_size=256, 11 | fps=20, 12 | camera_id=0): 13 | if root_dir is not None: 14 | self.save_dir = root_dir / 'eval_video' 15 | self.save_dir.mkdir(exist_ok=True) 16 | else: 17 | self.save_dir = None 18 | 19 | self.render_size = render_size 20 | self.fps = fps 21 | self.frames = [] 22 | self.camera_id = camera_id 23 | 24 | def init(self, env, enabled=True): 25 | self.frames = [] 26 | self.enabled = self.save_dir is not None and enabled 27 | self.record(env) 28 | 29 | def record(self, env): 30 | if self.enabled: 31 | if hasattr(env, 'physics'): 32 | frame = env.physics.render(height=self.render_size, 33 | width=self.render_size, 34 | camera_id=self.camera_id) 35 | else: 36 | frame = env.render() 37 | self.frames.append(frame) 38 | 39 | def log_to_wandb(self): 40 | frames = np.transpose(np.array(self.frames), (0, 3, 1, 2)) 41 | fps, skip = 6, 8 42 | wandb.log({ 43 | 'eval/video': 44 | wandb.Video(frames[::skip, :, ::2, ::2], fps=fps, format="gif") 45 | }) 46 | 47 | def save(self, file_name): 48 | if self.enabled: 49 | path = self.save_dir / file_name 50 | imageio.mimsave(str(path), self.frames, fps=self.fps) 51 | 52 | 53 | class TrainVideoRecorder: 54 | def __init__(self, 55 | root_dir, 56 | render_size=256, 57 | fps=20, 58 | camera_id=0, 59 | use_wandb=False): 60 | if root_dir is not None: 61 | self.save_dir = root_dir / 'train_video' 62 | self.save_dir.mkdir(exist_ok=True) 63 | else: 64 | self.save_dir = None 65 | 66 | self.render_size = render_size 67 | self.fps = fps 68 | self.frames = [] 69 | self.camera_id = camera_id 70 | self.use_wandb = use_wandb 71 | 72 | def init(self, obs, enabled=True): 73 | self.frames = [] 74 | self.enabled = 
self.save_dir is not None and enabled 75 | self.record(obs) 76 | 77 | def record(self, obs): 78 | if self.enabled: 79 | frame = cv2.resize(obs[-3:].transpose(1, 2, 0), 80 | dsize=(self.render_size, self.render_size), 81 | interpolation=cv2.INTER_CUBIC) 82 | self.frames.append(frame) 83 | 84 | def log_to_wandb(self): 85 | frames = np.transpose(np.array(self.frames), (0, 3, 1, 2)) 86 | fps, skip = 6, 8 87 | wandb.log({ 88 | 'train/video': 89 | wandb.Video(frames[::skip, :, ::2, ::2], fps=fps, format="gif") 90 | }) 91 | 92 | def save(self, file_name): 93 | if self.enabled: 94 | if self.use_wandb: 95 | self.log_to_wandb() 96 | path = self.save_dir / file_name 97 | imageio.mimsave(str(path), self.frames, fps=self.fps) 98 | -------------------------------------------------------------------------------- /agent/bc.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from collections import OrderedDict 7 | import functools 8 | 9 | import utils 10 | from dm_control.utils import rewards 11 | 12 | 13 | class Actor(nn.Module): 14 | def __init__(self, obs_dim, action_dim, hidden_dim): 15 | super().__init__() 16 | 17 | self.policy = nn.Sequential(nn.Linear(obs_dim, hidden_dim), 18 | nn.LayerNorm(hidden_dim), nn.Tanh(), 19 | nn.Linear(hidden_dim, hidden_dim), 20 | nn.ReLU(inplace=True), 21 | nn.Linear(hidden_dim, action_dim)) 22 | 23 | self.apply(utils.weight_init) 24 | 25 | def forward(self, obs, std): 26 | mu = self.policy(obs) 27 | mu = torch.tanh(mu) 28 | std = torch.ones_like(mu) * std 29 | 30 | dist = utils.TruncatedNormal(mu, std) 31 | return dist 32 | 33 | 34 | class BCAgent: 35 | def __init__(self, 36 | name, 37 | obs_shape, 38 | action_shape, 39 | device, 40 | lr, 41 | hidden_dim, 42 | batch_size, 43 | stddev_schedule, 44 | use_tb, 45 | has_next_action=False): 46 | self.action_dim = action_shape[0] 47 | 
self.hidden_dim = hidden_dim 48 | self.lr = lr 49 | self.device = device 50 | self.stddev_schedule = stddev_schedule 51 | self.use_tb = use_tb 52 | 53 | # models 54 | self.actor = Actor(obs_shape[0], action_shape[0], 55 | hidden_dim).to(device) 56 | 57 | # optimizers 58 | self.actor_opt = torch.optim.Adam(self.actor.parameters(), lr=lr) 59 | 60 | self.train() 61 | 62 | def train(self, training=True): 63 | self.training = training 64 | self.actor.train(training) 65 | 66 | def act(self, obs, step, eval_mode): 67 | obs = torch.as_tensor(obs, device=self.device).unsqueeze(0) 68 | stddev = utils.schedule(self.stddev_schedule, step) 69 | policy = self.actor(obs, stddev) 70 | if eval_mode: 71 | action = policy.mean 72 | else: 73 | action = policy.sample(clip=None) 74 | if step < self.num_expl_steps: 75 | action.uniform_(-1.0, 1.0) 76 | return action.cpu().numpy()[0] 77 | 78 | def update_actor(self, obs, action, step): 79 | metrics = dict() 80 | 81 | stddev = utils.schedule(self.stddev_schedule, step) 82 | policy = self.actor(obs, stddev) 83 | 84 | log_prob = policy.log_prob(action).sum(-1, keepdim=True) 85 | actor_loss = (-log_prob).mean() 86 | 87 | self.actor_opt.zero_grad(set_to_none=True) 88 | actor_loss.backward() 89 | self.actor_opt.step() 90 | 91 | if self.use_tb: 92 | metrics['actor_loss'] = actor_loss.item() 93 | metrics['actor_ent'] = policy.entropy().sum(dim=-1).mean().item() 94 | 95 | return metrics 96 | 97 | def update(self, replay_iter, step): 98 | metrics = dict() 99 | 100 | batch = next(replay_iter) 101 | obs, action, reward, discount, next_obs = utils.to_torch( 102 | batch, self.device) 103 | 104 | if self.use_tb: 105 | metrics['batch_reward'] = reward.mean().item() 106 | 107 | # update actor 108 | metrics.update(self.update_actor(obs, action, step)) 109 | 110 | return metrics 111 | -------------------------------------------------------------------------------- /custom_dmc_tasks/test.xml: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | -------------------------------------------------------------------------------- /custom_dmc_tasks/cheetah.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | -------------------------------------------------------------------------------- /train_offline_single.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | warnings.filterwarnings('ignore', category=DeprecationWarning) 4 | import os 5 | import random 6 | os.environ['MKL_SERVICE_FORCE_INTEL'] = '1' 7 | os.environ['MUJOCO_GL'] = 'egl' 8 | import json 9 | from pathlib import Path 10 | import hydra 11 | import numpy as np 12 | import torch 13 | from dm_env import specs 14 | import dmc 15 | import utils 16 | from logger import Logger 17 | from replay_buffer import make_replay_loader 18 | from video import VideoRecorder 19 | import wandb 20 | from omegaconf import OmegaConf 21 | 22 | 23 | torch.backends.cudnn.benchmark = True 24 | 25 | with open("task.json", "r") as f: 26 | task_dict = json.load(f) 27 | 28 | 29 | def get_domain(task): 30 | if task.startswith('point_mass_maze'): 31 | return 'point_mass_maze' 32 | return task.split('_', 1)[0] 33 | 34 | 35 | def get_data_seed(seed, num_data_seeds): 36 | return (seed - 1) % num_data_seeds + 1 37 | 38 | 39 | def eval(global_step, agent, env, logger, num_eval_episodes, video_recorder): 40 | step, episode, total_reward = 0, 0, 0 41 | eval_until_episode = utils.Until(num_eval_episodes) 42 | while eval_until_episode(episode): 43 | time_step = env.reset() 44 | video_recorder.init(env, enabled=(episode == 0)) 45 | while not time_step.last(): 46 | with torch.no_grad(), utils.eval_mode(agent): 47 | action = 
agent.act(time_step.observation, step=global_step, eval_mode=True) 48 | time_step = env.step(action) 49 | video_recorder.record(env) 50 | total_reward += time_step.reward 51 | step += 1 52 | 53 | episode += 1 54 | video_recorder.save(f'{global_step}.mp4') 55 | 56 | with logger.log_and_dump_ctx(global_step, ty='eval') as log: 57 | log('episode_reward', total_reward / episode) 58 | log('episode_length', step / episode) 59 | log('step', global_step) 60 | return {"eval_episode_reward": total_reward / episode} 61 | 62 | 63 | @hydra.main(config_path='.', config_name='config_single') 64 | def main(cfg): 65 | work_dir = Path.cwd() 66 | print(f'workspace: {work_dir}') 67 | 68 | # random seeds 69 | cfg.seed = random.randint(0, 100000) 70 | 71 | utils.set_seed_everywhere(cfg.seed) 72 | device = torch.device(cfg.device) 73 | 74 | # create logger 75 | logger = Logger(work_dir, use_tb=cfg.use_tb) 76 | 77 | # create envs 78 | env = dmc.make(cfg.task, seed=cfg.seed) 79 | 80 | # create agent 81 | agent = hydra.utils.instantiate(cfg.agent, 82 | obs_shape=env.observation_spec().shape, action_shape=env.action_spec().shape, 83 | num_expl_steps=0) 84 | 85 | # create replay buffer 86 | replay_dir_list = [] 87 | datasets_dir = work_dir / cfg.replay_buffer_dir 88 | replay_dir = datasets_dir.resolve() / Path(cfg.task+"-td3-"+str(cfg.data_main)) / 'data' 89 | print(f'replay dir: {replay_dir}') 90 | replay_dir_list.append(replay_dir) 91 | 92 | # 构建 replay buffer (single task) 93 | replay_loader = make_replay_loader(env, replay_dir_list, cfg.replay_buffer_size, 94 | cfg.batch_size, cfg.replay_buffer_num_workers, cfg.discount, 95 | main_task=cfg.task, task_list=[cfg.task]) 96 | replay_iter = iter(replay_loader) # OfflineReplayBuffer.sample function 97 | print("load data done.") 98 | 99 | # create video recorders 100 | video_recorder = VideoRecorder(work_dir if cfg.save_video else None) 101 | 102 | timer = utils.Timer() 103 | global_step = 0 104 | 105 | train_until_step = 
utils.Until(cfg.num_grad_steps) 106 | eval_every_step = utils.Every(cfg.eval_every_steps) 107 | log_every_step = utils.Every(cfg.log_every_steps) 108 | 109 | if cfg.wandb: 110 | wandb_dir = f"./wandb/{cfg.task}_{cfg.agent.name}_{cfg.data_main}_{cfg.seed}" 111 | if not os.path.exists(wandb_dir): 112 | os.makedirs(wandb_dir) 113 | wandb.init(project="UTDS", entity='', config=cfg, group=f'{cfg.task}_{cfg.agent.name}_{cfg.data_main}', 114 | name=f'{cfg.task}_{cfg.agent.name}_{cfg.data_main}', dir=wandb_dir) 115 | wandb.config.update(vars(cfg)) 116 | 117 | while train_until_step(global_step): 118 | # try to evaluate 119 | if eval_every_step(global_step): 120 | logger.log('eval_total_time', timer.total_time(), global_step) 121 | eval_metrics = eval(global_step, agent, env, logger, cfg.num_eval_episodes, video_recorder) 122 | if cfg.wandb: 123 | wandb.log(eval_metrics) 124 | 125 | # train the agent 126 | metrics = agent.update(replay_iter, global_step, cfg.num_grad_steps) 127 | if cfg.wandb: 128 | wandb.log(metrics) 129 | 130 | # log 131 | logger.log_metrics(metrics, global_step, ty='train') 132 | if log_every_step(global_step): 133 | elapsed_time, total_time = timer.reset() 134 | with logger.log_and_dump_ctx(global_step, ty='train') as log: 135 | log('fps', cfg.log_every_steps / elapsed_time) 136 | log('total_time', total_time) 137 | log('step', global_step) 138 | 139 | global_step += 1 140 | 141 | 142 | if __name__ == '__main__': 143 | main() 144 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Pessimistic Value Iteration for Multi-Task Data Sharing 2 | 3 | This repo contains a PyTorch implementation and the datasets for our paper titled "Pessimistic Value Iteration for Multi-Task Data Sharing in Offline Reinforcement Learning" published at *Artificial Intelligence* Journal. 
The paper is available at this [Link](https://www.sciencedirect.com/science/article/pii/S0004370223001947). 4 | 5 | ## Datasets 6 | 7 | We collect a Multi-Task Offline [Dataset](http://alchemist.wang-research.northwestern.edu/dataset/) based on the DeepMind Control Suite (DMC). 8 | - Download the [Dataset](http://alchemist.wang-research.northwestern.edu/dataset/) to `./collect` before you start training. 9 | - Users can collect new datasets with `collect_data.py`. The supported tasks 10 | include standard tasks from DMC and custom tasks from `./custom_dmc_tasks/`. 11 | 12 | Our dataset contains 3 domains with 4 tasks per domain, resulting in 12 tasks in total. 13 | 14 | | Domain | Available task names | 15 | |---|---| 16 | | Walker | `walker_stand`, `walker_walk`, `walker_run`, `walker_flip` | 17 | | Quadruped | `quadruped_jump`, `quadruped_roll_fast`, `quadruped_walk`, `quadruped_run` | 18 | | Jaco Arm | `jaco_reach_top_left`, `jaco_reach_top_right`, `jaco_reach_bottom_left`, `jaco_reach_bottom_right` | 19 | 20 | ![alt tasks](agent/env.png "tasks") 21 | 22 | 23 | For each task, we run `TD3` to collect five types of datasets: 24 | - `random`: data generated by a random agent. 25 | - `medium`: data generated by a medium-level TD3 agent. 26 | - `medium-replay`: all experiences collected while training a medium-level TD3 agent. 27 | - `medium-expert`: all experiences collected while training an expert-level TD3 agent. 28 | - `expert`: data generated by an expert-level TD3 agent. 29 | 30 | 31 | ## Prerequisites 32 | 33 | Install [MuJoCo](http://www.mujoco.org/): 34 | 35 | * Download MuJoCo binaries [here](https://mujoco.org/download). 36 | * Unzip the downloaded archive into `~/.mujoco/`. 37 | * Append the path of the MuJoCo `bin` subdirectory to the environment variable `LD_LIBRARY_PATH`.
38 | 39 | Install the following libraries: 40 | ```sh 41 | sudo apt update 42 | sudo apt install libosmesa6-dev libgl1-mesa-glx libglfw3 unzip 43 | ``` 44 | 45 | Install the Python dependencies: 46 | ```sh 47 | conda env create -f conda_env.yml 48 | conda activate utds 49 | ``` 50 | 51 | ## Algorithms 52 | 53 | We provide several algorithms for training both *single-task agents* and *multi-task data-sharing agents*. 54 | 55 | - For single-task training, we provide the following algorithms. 56 | 57 | | Algorithm | Name | Paper | 58 | |---|---|---| 59 | | Behavior Cloning | `bc` | [paper](https://arxiv.org/abs/1805.01954)| 60 | | CQL | `cql` | [paper](https://proceedings.neurips.cc/paper/2020/file/0d2b2061826a5df3221116a5085a6052-Paper.pdf)| 61 | | TD3-BC | `td3_bc` | [paper](https://proceedings.neurips.cc/paper/2021/hash/a8166da05c5a094f7dc03724b41886e5-Abstract.html)| 62 | | CRR | `crr` | [paper](http://proceedings.neurips.cc/paper/2020/file/588cb956d6bbe67078f29f8de420a13d-Paper.pdf)| 63 | | PBRL | `ddpg` | [paper](https://openreview.net/forum?id=Y4cs1Z3HnqL)| 64 | 65 | - For multi-task data sharing, we support the following algorithms.
66 | 67 | 68 | | Algorithm | Name | Paper | 69 | |---|---|---| 70 | | Direct Sharing | `cql` | [paper](https://proceedings.neurips.cc/paper/2020/file/0d2b2061826a5df3221116a5085a6052-Paper.pdf)| 71 | | CDS | `cql_cds` | [paper](https://proceedings.neurips.cc/paper/2021/hash/5fd2c06f558321eff612bbbe455f6fbd-Abstract.html)| 72 | | Unlabeled-CDS | `cql_cdsz` | [paper](https://proceedings.mlr.press/v162/yu22c/yu22c.pdf)| 73 | | UTDS | `pbrl` | our paper | 74 | 75 | 76 | ## Training 77 | 78 | #### Train CDS 79 | 80 | To train the CDS agent on the `quadruped_jump (random)` task with data shared from the `quadruped_roll_fast (replay)` dataset, run 81 | ``` 82 | python train_offline_cds.py task=quadruped_jump "+share_task=[quadruped_jump, quadruped_roll_fast]" "+data_type=[random, replay]" 83 | ``` 84 | 85 | #### Train UTDS 86 | 87 | To train the UTDS agent on the `quadruped_jump (random)` task with data shared from the `quadruped_roll_fast (replay)` dataset, run 88 | ``` 89 | python train_offline_share.py task=quadruped_jump "+share_task=[quadruped_jump, quadruped_roll_fast]" "+data_type=[random, replay]" 90 | ``` 91 | 92 | To enable wandb logging, set `wandb: True` in the corresponding `config*.yaml` file.
93 | 94 | ## Citation 95 | ``` 96 | @article{UTDS2023, 97 | title = {Pessimistic Value Iteration for Multi-Task Data Sharing in Offline Reinforcement Learning}, 98 | journal = {Artificial Intelligence}, 99 | author = {Chenjia Bai and Lingxiao Wang and Jianye Hao and Zhuoran Yang and Bin Zhao and Zhen Wang and Xuelong Li}, 100 | pages = {104048}, 101 | year = {2023}, 102 | issn = {0004-3702}, 103 | doi = {https://doi.org/10.1016/j.artint.2023.104048}, 104 | url = {https://www.sciencedirect.com/science/article/pii/S0004370223001947}, 105 | } 106 | ``` 107 | 108 | ## License 109 | MIT license 110 | -------------------------------------------------------------------------------- /train_offline_share.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | warnings.filterwarnings('ignore', category=DeprecationWarning) 4 | 5 | import os 6 | import random 7 | os.environ['MKL_SERVICE_FORCE_INTEL'] = '1' 8 | os.environ['MUJOCO_GL'] = 'egl' 9 | import json 10 | from pathlib import Path 11 | import hydra 12 | import numpy as np 13 | import torch 14 | from dm_env import specs 15 | import dmc 16 | import utils 17 | from logger import Logger 18 | from replay_buffer import make_replay_loader 19 | from video import VideoRecorder 20 | import wandb 21 | 22 | torch.backends.cudnn.benchmark = True 23 | 24 | with open("task.json", "r") as f: 25 | task_dict = json.load(f) 26 | 27 | 28 | def get_domain(task): 29 | if task.startswith('point_mass_maze'): 30 | return 'point_mass_maze' 31 | return task.split('_', 1)[0] 32 | 33 | 34 | def get_data_seed(seed, num_data_seeds): 35 | return (seed - 1) % num_data_seeds + 1 36 | 37 | 38 | def eval(global_step, agent, env, logger, num_eval_episodes, video_recorder): 39 | step, episode, total_reward = 0, 0, 0 40 | eval_until_episode = utils.Until(num_eval_episodes) 41 | while eval_until_episode(episode): 42 | time_step = env.reset() 43 | video_recorder.init(env, enabled=(episode == 0)) 44 | 
while not time_step.last(): 45 | with torch.no_grad(), utils.eval_mode(agent): 46 | action = agent.act(time_step.observation, step=global_step, eval_mode=True) 47 | time_step = env.step(action) 48 | video_recorder.record(env) 49 | total_reward += time_step.reward 50 | step += 1 51 | 52 | episode += 1 53 | video_recorder.save(f'{global_step}.mp4') 54 | 55 | with logger.log_and_dump_ctx(global_step, ty='eval') as log: 56 | log('episode_reward', total_reward / episode) 57 | log('episode_length', step / episode) 58 | log('step', global_step) 59 | 60 | 61 | @hydra.main(config_path='.', config_name='config') 62 | def main(cfg): 63 | work_dir = Path.cwd() 64 | print(f'workspace: {work_dir}') 65 | 66 | # random seeds 67 | cfg.seed = random.randint(0, 100000) 68 | 69 | utils.set_seed_everywhere(cfg.seed) 70 | device = torch.device(cfg.device) 71 | 72 | # create logger 73 | logger = Logger(work_dir, use_tb=cfg.use_tb) 74 | 75 | # create envs 76 | env = dmc.make(cfg.task, seed=cfg.seed) 77 | 78 | # create agent 79 | agent = hydra.utils.instantiate(cfg.agent, 80 | obs_shape=env.observation_spec().shape, action_shape=env.action_spec().shape, 81 | num_expl_steps=0) 82 | 83 | replay_dir_list = [] 84 | 85 | for task_id in range(len(cfg.share_task)): 86 | task = cfg.share_task[task_id] # dataset task 87 | data_type = cfg.data_type[task_id] # dataset type [random, medium, medium-replay, expert, replay] 88 | datasets_dir = work_dir / cfg.replay_buffer_dir 89 | replay_dir = datasets_dir.resolve() / Path(task+"-td3-"+str(data_type)) / 'data' 90 | print(f'replay dir: {replay_dir}') 91 | replay_dir_list.append(replay_dir) 92 | 93 | # construct the replay buffer. 
env is the main task, we use it to relabel the reward of other tasks 94 | replay_loader = make_replay_loader(env, replay_dir_list, cfg.replay_buffer_size, 95 | cfg.batch_size, cfg.replay_buffer_num_workers, cfg.discount, 96 | main_task=cfg.task, task_list=cfg.share_task) 97 | replay_iter = iter(replay_loader) # OfflineReplayBuffer sample function 98 | print("load data done.") 99 | 100 | # for i in replay_iter: 101 | # print(i) 102 | # break 103 | 104 | # create video recorders 105 | video_recorder = VideoRecorder(work_dir if cfg.save_video else None) 106 | 107 | timer = utils.Timer() 108 | global_step = 0 109 | 110 | train_until_step = utils.Until(cfg.num_grad_steps) 111 | eval_every_step = utils.Every(cfg.eval_every_steps) 112 | log_every_step = utils.Every(cfg.log_every_steps) 113 | 114 | if cfg.wandb: 115 | path_str = f'{cfg.agent.name}_{cfg.share_task[0]}_{cfg.share_task[1]}_{cfg.data_type[0]}_{cfg.data_type[1]}' 116 | wandb_dir = f"./wandb/{path_str}_{cfg.seed}" 117 | if not os.path.exists(wandb_dir): 118 | os.makedirs(wandb_dir) 119 | wandb.init(project="UTDS", entity='', config=cfg, name=f'{path_str}_1', dir=wandb_dir) 120 | wandb.config.update(vars(cfg)) 121 | 122 | while train_until_step(global_step): 123 | # try to evaluate 124 | if eval_every_step(global_step): 125 | logger.log('eval_total_time', timer.total_time(), global_step) 126 | eval(global_step, agent, env, logger, cfg.num_eval_episodes, video_recorder) 127 | 128 | # train the agent 129 | metrics = agent.update(replay_iter, global_step, cfg.num_grad_steps) 130 | 131 | # log 132 | logger.log_metrics(metrics, global_step, ty='train') 133 | if log_every_step(global_step): 134 | elapsed_time, total_time = timer.reset() 135 | with logger.log_and_dump_ctx(global_step, ty='train') as log: 136 | log('fps', cfg.log_every_steps / elapsed_time) 137 | log('total_time', total_time) 138 | log('step', global_step) 139 | 140 | global_step += 1 141 | 142 | 143 | if __name__ == '__main__': 144 | main() 145 | 
-------------------------------------------------------------------------------- /replay_buffer.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import io 3 | import random 4 | import traceback 5 | import copy 6 | from collections import defaultdict 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | from torch.utils.data import IterableDataset 12 | 13 | 14 | def episode_len(episode): 15 | # subtract -1 because the dummy first transition 16 | return next(iter(episode.values())).shape[0] - 1 17 | 18 | 19 | def save_episode(episode, fn): 20 | with io.BytesIO() as bs: 21 | np.savez_compressed(bs, **episode) 22 | bs.seek(0) 23 | with fn.open('wb') as f: 24 | f.write(bs.read()) 25 | 26 | 27 | def load_episode(fn): 28 | with fn.open('rb') as f: 29 | episode = np.load(f) 30 | episode = {k: episode[k] for k in episode.keys()} 31 | return episode 32 | 33 | 34 | def relable_episode(env, episode): # relabel the reward function 35 | rewards = [] 36 | reward_spec = env.reward_spec() 37 | states = episode['physics'] 38 | for i in range(states.shape[0]): 39 | with env.physics.reset_context(): 40 | env.physics.set_state(states[i]) 41 | # Input: the current physics of env, then use env.task.get_reward to calculate the reward 42 | reward = env.task.get_reward(env.physics) 43 | reward = np.full(reward_spec.shape, reward, reward_spec.dtype) # 改变shape和dtype 44 | rewards.append(reward) 45 | original_reward = np.mean(episode['reward']) 46 | episode['reward'] = np.array(rewards, dtype=reward_spec.dtype) 47 | # print("Reward difference after relabeling:", original_reward - episode['reward'].mean()) 48 | return episode 49 | 50 | 51 | class OfflineReplayBuffer(IterableDataset): 52 | # 用于 offline training 的 dataset 53 | def __init__(self, env, replay_dir_list, max_size, num_workers, discount, main_task, task_list): 54 | self._env = env 55 | self._replay_dir_list = replay_dir_list 56 | self._size = 0 57 | 
self._max_size = max_size 58 | self._num_workers = max(1, num_workers) 59 | self._episode_fns = [] 60 | self._episodes = dict() # save as episode 61 | self._discount = discount 62 | self._loaded = False 63 | self._main_task = main_task 64 | self._task_list = task_list 65 | 66 | def _load(self, relable=True): 67 | print("load data", self._replay_dir_list, self._task_list) 68 | for i in range(len(self._replay_dir_list)): # loop 69 | _replay_dir = self._replay_dir_list[i] 70 | _task_share = self._task_list[i] 71 | assert _task_share in str(_replay_dir) 72 | try: 73 | worker_id = torch.utils.data.get_worker_info().id 74 | except: 75 | worker_id = 0 76 | print(f'Loading data from {_replay_dir} and Relabel...', "worker_id:", worker_id) # each worker will run this function 77 | print(f"Need relabeling: {relable and _task_share != self._main_task}") 78 | eps_fns = sorted(_replay_dir.glob('*.npz')) 79 | for eps_fn in eps_fns: 80 | if self._size > self._max_size: 81 | break 82 | eps_idx, eps_len = [int(x) for x in eps_fn.stem.split('_')[1:]] 83 | if eps_idx % self._num_workers != worker_id: # read the npz file of the worker 84 | continue 85 | # load a npz file to represent an episodic sample. The keys include 'observation', 'action', 'reward', 'discount', 'physics' 86 | episode = load_episode(eps_fn) 87 | 88 | if relable and _task_share != self._main_task: 89 | # print(f"relabel {_replay_dir} for {self._main_task} task") 90 | episode = self._relable_reward(episode) # relabel 91 | data_flag = _task_share+str(eps_fn)+str(_task_share == self._main_task) 92 | self._episode_fns.append(data_flag) 93 | self._episodes[data_flag] = episode 94 | self._size += episode_len(episode) 95 | # if worker_id == 0: 96 | # print("data_flag:", data_flag) 97 | print("load done. 
Num of episodes", len(self._episode_fns)*self._num_workers) 98 | 99 | def _sample_episode(self): 100 | if not self._loaded: 101 | self._load() 102 | self._loaded = True 103 | eps_fn = random.choice(self._episode_fns) 104 | return self._episodes[eps_fn], eps_fn.endswith("True") # whether is the main buffer 105 | 106 | def _relable_reward(self, episode): 107 | return relable_episode(self._env, episode) 108 | 109 | def _sample(self): 110 | episode, eps_flag = self._sample_episode() # return the signal 111 | # add +1 for the first dummy transition 112 | idx = np.random.randint(0, episode_len(episode)) + 1 113 | obs = episode['observation'][idx - 1] 114 | action = episode['action'][idx] 115 | next_obs = episode['observation'][idx] 116 | reward = episode['reward'][idx] 117 | discount = episode['discount'][idx] * self._discount 118 | 119 | return (obs, action, reward, discount, next_obs, bool(eps_flag)) 120 | 121 | def __iter__(self): 122 | while True: 123 | yield self._sample() 124 | 125 | 126 | def _worker_init_fn(worker_id): 127 | seed = np.random.get_state()[1][0] + worker_id 128 | np.random.seed(seed) 129 | random.seed(seed) 130 | 131 | 132 | def make_replay_loader(env, replay_dir_list, max_size, batch_size, num_workers, discount, main_task, task_list): 133 | max_size_per_worker = max_size // max(1, num_workers) 134 | 135 | iterable = OfflineReplayBuffer(env, replay_dir_list, max_size_per_worker, 136 | num_workers, discount, main_task, task_list) # task 表示主任务 137 | 138 | loader = torch.utils.data.DataLoader(iterable, 139 | batch_size=batch_size, 140 | num_workers=num_workers, 141 | pin_memory=True, 142 | worker_init_fn=_worker_init_fn) 143 | return loader 144 | 145 | -------------------------------------------------------------------------------- /train_offline_cds.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | warnings.filterwarnings('ignore', category=DeprecationWarning) 4 | 5 | import os 6 | import 
random 7 | os.environ['MKL_SERVICE_FORCE_INTEL'] = '1' 8 | os.environ['MUJOCO_GL'] = 'egl' 9 | import json 10 | from pathlib import Path 11 | import hydra 12 | import numpy as np 13 | import torch 14 | from dm_env import specs 15 | import dmc 16 | import utils 17 | from logger import Logger 18 | from replay_buffer import make_replay_loader 19 | from video import VideoRecorder 20 | import wandb 21 | 22 | torch.backends.cudnn.benchmark = True 23 | 24 | with open("task.json", "r") as f: 25 | task_dict = json.load(f) 26 | 27 | 28 | def get_domain(task): 29 | if task.startswith('point_mass_maze'): 30 | return 'point_mass_maze' 31 | return task.split('_', 1)[0] 32 | 33 | 34 | def get_data_seed(seed, num_data_seeds): 35 | return (seed - 1) % num_data_seeds + 1 36 | 37 | 38 | def eval(global_step, agent, env, logger, num_eval_episodes, video_recorder): 39 | step, episode, total_reward = 0, 0, 0 40 | eval_until_episode = utils.Until(num_eval_episodes) 41 | while eval_until_episode(episode): 42 | time_step = env.reset() 43 | video_recorder.init(env, enabled=(episode == 0)) 44 | while not time_step.last(): 45 | with torch.no_grad(), utils.eval_mode(agent): 46 | action = agent.act(time_step.observation, step=global_step, eval_mode=True) 47 | time_step = env.step(action) 48 | video_recorder.record(env) 49 | total_reward += time_step.reward 50 | step += 1 51 | 52 | episode += 1 53 | video_recorder.save(f'{global_step}.mp4') 54 | 55 | with logger.log_and_dump_ctx(global_step, ty='eval') as log: 56 | log('episode_reward', total_reward / episode) 57 | log('episode_length', step / episode) 58 | log('step', global_step) 59 | 60 | 61 | @hydra.main(config_path='.', config_name='config_cds') 62 | def main(cfg): 63 | work_dir = Path.cwd() 64 | print(f'workspace: {work_dir}') 65 | 66 | # random seeds 67 | cfg.seed = random.randint(0, 100000) 68 | 69 | utils.set_seed_everywhere(cfg.seed) 70 | device = torch.device(cfg.device) 71 | 72 | # create logger 73 | logger = Logger(work_dir, 
use_tb=cfg.use_tb) 74 | 75 | # create envs 76 | env = dmc.make(cfg.task, seed=cfg.seed) 77 | 78 | # create agent 79 | agent = hydra.utils.instantiate(cfg.agent, obs_shape=env.observation_spec().shape, 80 | action_shape=env.action_spec().shape, num_expl_steps=0) 81 | 82 | replay_dir_list_main = [] 83 | replay_dir_list_share = [] 84 | 85 | share_tasks = [] 86 | for task_id in range(len(cfg.share_task)): 87 | task = cfg.share_task[task_id] # dataset task 88 | data_type = cfg.data_type[task_id] # dataset type [random, medium, medium-replay, expert, replay] 89 | datasets_dir = work_dir / cfg.replay_buffer_dir # 存储数据的目录 90 | replay_dir = datasets_dir.resolve() / Path(task+"-td3-"+str(data_type)) / 'data' 91 | print(f'replay dir: {replay_dir}') 92 | if task == cfg.task: 93 | replay_dir_list_main.append(replay_dir) 94 | else: 95 | replay_dir_list_share.append(replay_dir) 96 | share_tasks.append(task) 97 | 98 | print("CDS. load main dataset..", cfg.task) 99 | replay_loader_main = make_replay_loader(env, replay_dir_list_main, cfg.replay_buffer_size, 100 | cfg.batch_size // 2, cfg.replay_buffer_num_workers, cfg.discount, # batch size (half) 101 | main_task=cfg.task, task_list=[cfg.task]) 102 | replay_iter_main = iter(replay_loader_main) # run OfflineReplayBuffer.sample function 103 | 104 | print("CDS. 
load share dataset..", share_tasks) 105 | replay_loader_share = make_replay_loader(env, replay_dir_list_share, cfg.replay_buffer_size, 106 | cfg.batch_size // 2 * 10, cfg.replay_buffer_num_workers, cfg.discount, # batch size是10倍,后取top10 107 | main_task=cfg.task, task_list=share_tasks) 108 | replay_iter_share = iter(replay_loader_share) # run OfflineReplayBuffer.sample function 109 | print("load data done.") 110 | 111 | # for i in replay_iter_share: 112 | # print(i) 113 | # break 114 | 115 | # create video recorders 116 | video_recorder = VideoRecorder(work_dir if cfg.save_video else None) 117 | 118 | timer = utils.Timer() 119 | global_step = 0 120 | 121 | train_until_step = utils.Until(cfg.num_grad_steps) 122 | eval_every_step = utils.Every(cfg.eval_every_steps) 123 | log_every_step = utils.Every(cfg.log_every_steps) 124 | 125 | if cfg.wandb: 126 | path_str = f'{cfg.agent.name}_{cfg.share_task[0]}_{cfg.share_task[1]}_{cfg.data_type[0]}_{cfg.data_type[1]}' 127 | wandb_dir = f"./wandb/{path_str}_{cfg.seed}" 128 | if not os.path.exists(wandb_dir): 129 | os.makedirs(wandb_dir) 130 | wandb.init(project="UTDS", entity='', config=cfg, name=f'{path_str}_1', dir=wandb_dir) 131 | wandb.config.update(vars(cfg)) 132 | 133 | while train_until_step(global_step): 134 | # try to evaluate 135 | if eval_every_step(global_step): 136 | logger.log('eval_total_time', timer.total_time(), global_step) 137 | eval(global_step, agent, env, logger, cfg.num_eval_episodes, video_recorder) 138 | 139 | # train the agent 140 | metrics = agent.update(replay_iter_main, replay_iter_share, global_step, cfg.num_grad_steps) 141 | 142 | # log 143 | logger.log_metrics(metrics, global_step, ty='train') 144 | if log_every_step(global_step): 145 | elapsed_time, total_time = timer.reset() 146 | with logger.log_and_dump_ctx(global_step, ty='train') as log: 147 | log('fps', cfg.log_every_steps / elapsed_time) 148 | log('total_time', total_time) 149 | log('step', global_step) 150 | 151 | global_step += 1 152 | 
153 | 154 | if __name__ == '__main__': 155 | main() 156 | -------------------------------------------------------------------------------- /agent/td3_bc.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from collections import OrderedDict 7 | import copy 8 | import utils 9 | from dm_control.utils import rewards 10 | 11 | 12 | class Actor(nn.Module): 13 | def __init__(self, state_dim, action_dim, max_action=1): 14 | super(Actor, self).__init__() 15 | 16 | self.l1 = nn.Linear(state_dim, 256) 17 | self.l2 = nn.Linear(256, 256) 18 | self.l3 = nn.Linear(256, action_dim) 19 | 20 | self.max_action = max_action 21 | 22 | def forward(self, state): 23 | a = F.relu(self.l1(state)) 24 | a = F.relu(self.l2(a)) 25 | return self.max_action * torch.tanh(self.l3(a)) 26 | 27 | 28 | class Critic(nn.Module): 29 | def __init__(self, state_dim, action_dim): 30 | super(Critic, self).__init__() 31 | 32 | # Q1 architecture 33 | self.l1 = nn.Linear(state_dim + action_dim, 256) 34 | self.l2 = nn.Linear(256, 256) 35 | self.l3 = nn.Linear(256, 1) 36 | 37 | # Q2 architecture 38 | self.l4 = nn.Linear(state_dim + action_dim, 256) 39 | self.l5 = nn.Linear(256, 256) 40 | self.l6 = nn.Linear(256, 1) 41 | 42 | def forward(self, state, action): 43 | sa = torch.cat([state, action], 1) 44 | 45 | q1 = F.relu(self.l1(sa)) 46 | q1 = F.relu(self.l2(q1)) 47 | q1 = self.l3(q1) 48 | 49 | q2 = F.relu(self.l4(sa)) 50 | q2 = F.relu(self.l5(q2)) 51 | q2 = self.l6(q2) 52 | return q1, q2 53 | 54 | def Q1(self, state, action): 55 | sa = torch.cat([state, action], 1) 56 | 57 | q1 = F.relu(self.l1(sa)) 58 | q1 = F.relu(self.l2(q1)) 59 | q1 = self.l3(q1) 60 | return q1 61 | 62 | 63 | class TD3BCAgent: 64 | def __init__(self, 65 | name, 66 | obs_shape, 67 | action_shape, 68 | device, 69 | lr, 70 | hidden_dim, 71 | critic_target_tau, 72 | actor_target_tau, 73 | 
policy_freq, 74 | policy_noise, 75 | noise_clip, 76 | use_tb, 77 | alpha, 78 | batch_size, 79 | num_expl_steps): 80 | self.policy_noise = policy_noise 81 | self.policy_freq = policy_freq 82 | self.noise_clip = noise_clip 83 | self.num_expl_steps = num_expl_steps 84 | self.action_dim = action_shape[0] 85 | self.hidden_dim = hidden_dim 86 | self.lr = lr 87 | self.device = device 88 | self.critic_target_tau = critic_target_tau 89 | self.actor_target_tau = actor_target_tau 90 | self.use_tb = use_tb 91 | # self.stddev_schedule = stddev_schedule 92 | # self.stddev_clip = stddev_clip 93 | self.alpha = alpha 94 | self.max_action = 1.0 95 | 96 | # models 97 | self.actor = Actor(obs_shape[0], action_shape[0]).to(device) 98 | self.actor_target = copy.deepcopy(self.actor) 99 | 100 | self.critic = Critic(obs_shape[0], action_shape[0]).to(device) 101 | self.critic_target = copy.deepcopy(self.critic) 102 | self.critic_target.load_state_dict(self.critic.state_dict()) 103 | 104 | # optimizers 105 | self.actor_opt = torch.optim.Adam(self.actor.parameters(), lr=lr) 106 | self.critic_opt = torch.optim.Adam(self.critic.parameters(), lr=lr) 107 | 108 | self.train() 109 | self.critic_target.train() 110 | 111 | def train(self, training=True): 112 | self.training = training 113 | self.actor.train(training) 114 | self.critic.train(training) 115 | 116 | def act(self, obs, step, eval_mode): 117 | obs = torch.as_tensor(obs, device=self.device).unsqueeze(0) 118 | action = self.actor(obs) 119 | if step < self.num_expl_steps: 120 | action.uniform_(-1.0, 1.0) 121 | return action.cpu().numpy()[0] 122 | 123 | def update_critic(self, obs, action, reward, discount, next_obs): 124 | metrics = dict() 125 | 126 | with torch.no_grad(): 127 | # Select action according to policy and add clipped noise 128 | noise = ( 129 | torch.randn_like(action) * self.policy_noise 130 | ).clamp(-self.noise_clip, self.noise_clip) 131 | 132 | next_action = ( 133 | self.actor_target(next_obs) + noise 134 | 
).clamp(-self.max_action, self.max_action) 135 | 136 | target_Q1, target_Q2 = self.critic_target(next_obs, next_action) 137 | target_V = torch.min(target_Q1, target_Q2) 138 | target_Q = reward + (discount * target_V) 139 | 140 | Q1, Q2 = self.critic(obs, action) 141 | critic_loss = F.mse_loss(Q1, target_Q) + F.mse_loss(Q2, target_Q) 142 | 143 | if self.use_tb: 144 | metrics['critic_target_q'] = target_Q.mean().item() 145 | metrics['critic_q1'] = Q1.mean().item() 146 | metrics['critic_q2'] = Q2.mean().item() 147 | metrics['critic_loss'] = critic_loss.item() 148 | 149 | # optimize critic 150 | self.critic_opt.zero_grad(set_to_none=True) 151 | critic_loss.backward() 152 | self.critic_opt.step() 153 | return metrics 154 | 155 | def update_actor(self, obs, action): 156 | metrics = dict() 157 | 158 | # Compute actor loss 159 | pi = self.actor(obs) 160 | Q = self.critic.Q1(obs, pi) 161 | lmbda = self.alpha / Q.abs().mean().detach() 162 | 163 | actor_loss = -lmbda * Q.mean() + F.mse_loss(pi, action) 164 | 165 | # optimize actor 166 | self.actor_opt.zero_grad(set_to_none=True) 167 | actor_loss.backward() 168 | self.actor_opt.step() 169 | 170 | if self.use_tb: 171 | metrics['actor_loss'] = actor_loss.item() 172 | 173 | return metrics 174 | 175 | def update(self, replay_iter, step, total_step=None): 176 | metrics = dict() 177 | 178 | batch = next(replay_iter) 179 | # obs.shape=(1024,obs_dim), action.shape=(1024,1), reward.shape=(1024,1), discount.shape=(1024,1) 180 | obs, action, reward, discount, next_obs, _ = utils.to_torch( 181 | batch, self.device) 182 | 183 | if self.use_tb: 184 | metrics['batch_reward'] = reward.mean().item() 185 | 186 | # update critic 187 | metrics.update( 188 | self.update_critic(obs, action, reward, discount, next_obs)) 189 | 190 | # update actor 191 | if step % self.policy_freq == 0: 192 | metrics.update(self.update_actor(obs, action)) 193 | 194 | # update critic target 195 | utils.soft_update_params(self.critic, self.critic_target, 
self.critic_target_tau) 196 | utils.soft_update_params(self.actor, self.actor_target, self.actor_target_tau) 197 | 198 | return metrics 199 | -------------------------------------------------------------------------------- /agent/td3.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from collections import OrderedDict 7 | 8 | import utils 9 | from dm_control.utils import rewards 10 | 11 | 12 | class Actor(nn.Module): 13 | def __init__(self, obs_dim, action_dim, hidden_dim): 14 | super().__init__() 15 | 16 | self.policy = nn.Sequential(nn.Linear(obs_dim, hidden_dim), 17 | nn.LayerNorm(hidden_dim), nn.Tanh(), 18 | nn.Linear(hidden_dim, hidden_dim), 19 | nn.ReLU(inplace=True), 20 | nn.Linear(hidden_dim, action_dim)) 21 | 22 | self.apply(utils.weight_init) 23 | 24 | def forward(self, obs, std): 25 | mu = self.policy(obs) 26 | mu = torch.tanh(mu) 27 | std = torch.ones_like(mu) * std 28 | 29 | dist = utils.TruncatedNormal(mu, std) 30 | return dist 31 | 32 | 33 | class Critic(nn.Module): 34 | def __init__(self, obs_dim, action_dim, hidden_dim): 35 | super().__init__() 36 | 37 | self.q1_net = nn.Sequential( 38 | nn.Linear(obs_dim + action_dim, hidden_dim), 39 | nn.LayerNorm(hidden_dim), nn.Tanh(), 40 | nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True), 41 | nn.Linear(hidden_dim, 1)) 42 | 43 | self.q2_net = nn.Sequential( 44 | nn.Linear(obs_dim + action_dim, hidden_dim), 45 | nn.LayerNorm(hidden_dim), nn.Tanh(), 46 | nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True), 47 | nn.Linear(hidden_dim, 1)) 48 | 49 | self.apply(utils.weight_init) 50 | 51 | def forward(self, obs, action): 52 | obs_action = torch.cat([obs, action], dim=-1) 53 | q1 = self.q1_net(obs_action) 54 | q2 = self.q2_net(obs_action) 55 | 56 | return q1, q2 57 | 58 | 59 | class TD3Agent: 60 | def __init__(self, 61 | name, 62 | obs_shape, 63 | 
action_shape, 64 | device, 65 | lr, 66 | hidden_dim, 67 | critic_target_tau, 68 | stddev_schedule, 69 | num_expl_steps, # steps to before training 70 | nstep, 71 | batch_size, 72 | stddev_clip, 73 | use_tb, 74 | # obs_type, 75 | has_next_action=False): 76 | # self.obs_type = obs_type 77 | self.num_expl_steps = num_expl_steps 78 | self.action_dim = action_shape[0] 79 | self.hidden_dim = hidden_dim 80 | self.lr = lr 81 | self.device = device 82 | self.critic_target_tau = critic_target_tau 83 | self.use_tb = use_tb 84 | self.stddev_schedule = stddev_schedule 85 | self.stddev_clip = stddev_clip 86 | 87 | # models 88 | self.actor = Actor(obs_shape[0], action_shape[0], 89 | hidden_dim).to(device) 90 | 91 | self.critic = Critic(obs_shape[0], action_shape[0], 92 | hidden_dim).to(device) 93 | self.critic_target = Critic(obs_shape[0], action_shape[0], 94 | hidden_dim).to(device) 95 | self.critic_target.load_state_dict(self.critic.state_dict()) 96 | 97 | # optimizers 98 | self.actor_opt = torch.optim.Adam(self.actor.parameters(), lr=lr) 99 | self.critic_opt = torch.optim.Adam(self.critic.parameters(), lr=lr) 100 | 101 | self.train() 102 | self.critic_target.train() 103 | 104 | def train(self, training=True): 105 | self.training = training 106 | self.actor.train(training) 107 | self.critic.train(training) 108 | 109 | def get_meta_specs(self): 110 | return tuple() 111 | 112 | def init_meta(self): 113 | return OrderedDict() 114 | 115 | def update_meta(self, meta, global_step, time_step, finetune=False): 116 | return meta 117 | 118 | def act(self, obs, step, eval_mode, meta=None): 119 | obs = torch.as_tensor(obs, device=self.device).unsqueeze(0) 120 | stddev = utils.schedule(self.stddev_schedule, step) 121 | policy = self.actor(obs, stddev) 122 | if eval_mode: 123 | action = policy.mean 124 | else: 125 | action = policy.sample(clip=None) 126 | if step < self.num_expl_steps: 127 | action.uniform_(-1.0, 1.0) 128 | return action.cpu().numpy()[0] 129 | 130 | def update_critic(self, 
obs, action, reward, discount, next_obs, step): 131 | metrics = dict() 132 | 133 | with torch.no_grad(): 134 | stddev = utils.schedule(self.stddev_schedule, step) 135 | dist = self.actor(next_obs, stddev) 136 | next_action = dist.sample(clip=self.stddev_clip) 137 | target_Q1, target_Q2 = self.critic_target(next_obs, next_action) 138 | target_V = torch.min(target_Q1, target_Q2) 139 | target_Q = reward + (discount * target_V) 140 | 141 | Q1, Q2 = self.critic(obs, action) 142 | critic_loss = F.mse_loss(Q1, target_Q) + F.mse_loss(Q2, target_Q) 143 | 144 | if self.use_tb: 145 | metrics['critic_target_q'] = target_Q.mean().item() 146 | metrics['critic_q1'] = Q1.mean().item() 147 | metrics['critic_q2'] = Q2.mean().item() 148 | metrics['critic_loss'] = critic_loss.item() 149 | 150 | # optimize critic 151 | self.critic_opt.zero_grad(set_to_none=True) 152 | critic_loss.backward() 153 | self.critic_opt.step() 154 | return metrics 155 | 156 | def update_actor(self, obs, action, step): 157 | metrics = dict() 158 | 159 | stddev = utils.schedule(self.stddev_schedule, step) 160 | policy = self.actor(obs, stddev) 161 | 162 | Q1, Q2 = self.critic(obs, policy.sample(clip=self.stddev_clip)) 163 | Q = torch.min(Q1, Q2) 164 | 165 | actor_loss = -Q.mean() 166 | 167 | # optimize actor 168 | self.actor_opt.zero_grad(set_to_none=True) 169 | actor_loss.backward() 170 | self.actor_opt.step() 171 | 172 | if self.use_tb: 173 | metrics['actor_loss'] = actor_loss.item() 174 | metrics['actor_ent'] = policy.entropy().sum(dim=-1).mean().item() 175 | 176 | return metrics 177 | 178 | def update(self, replay_iter, step): 179 | metrics = dict() 180 | 181 | batch = next(replay_iter) 182 | obs, action, reward, discount, next_obs = utils.to_torch(batch, self.device) 183 | 184 | if self.use_tb: 185 | metrics['batch_reward'] = reward.mean().item() 186 | 187 | # update critic 188 | metrics.update( 189 | self.update_critic(obs, action, reward, discount, next_obs, step)) 190 | 191 | # update actor 192 | 
metrics.update(self.update_actor(obs, action, step)) 193 | 194 | # update critic target 195 | utils.soft_update_params(self.critic, self.critic_target, 196 | self.critic_target_tau) 197 | 198 | return metrics 199 | -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import datetime 3 | from collections import defaultdict 4 | 5 | import numpy as np 6 | import torch 7 | import torchvision 8 | from termcolor import colored 9 | from torch.utils.tensorboard import SummaryWriter 10 | 11 | 12 | TRAIN_FORMAT = [('step', 'S', 'int'), ('episode', 'E', 'int'), ('batch_reward', 'BR', 'float'), 13 | ('episode_length', 'L', 'int'), ('episode_reward', 'R', 'float'), 14 | ('fps', 'FPS', 'float'), ('total_time', 'T', 'time')] 15 | 16 | 17 | EVAL_FORMAT = [('step', 'S', 'int'), ('episode_length', 'L', 'int'), 18 | ('episode_reward', 'R', 'float'), 19 | ('dataset_reward', 'DR', 'float'), ('total_time', 'T', 'time')] 20 | 21 | 22 | class AverageMeter(object): 23 | def __init__(self): 24 | self._sum = 0 25 | self._count = 0 26 | 27 | def update(self, value, n=1): 28 | self._sum += value 29 | self._count += n 30 | 31 | def value(self): 32 | return self._sum / max(1, self._count) 33 | 34 | 35 | class MetersGroup(object): 36 | def __init__(self, csv_file_name, formating): 37 | self._csv_file_name = csv_file_name 38 | self._formating = formating 39 | self._meters = defaultdict(AverageMeter) 40 | self._csv_file = None 41 | self._csv_writer = None 42 | 43 | def log(self, key, value, n=1): 44 | self._meters[key].update(value, n) 45 | 46 | def _prime_meters(self): 47 | data = dict() 48 | for key, meter in self._meters.items(): 49 | if key.startswith('train'): 50 | key = key[len('train') + 1:] 51 | else: 52 | key = key[len('eval') + 1:] 53 | key = key.replace('/', '_') 54 | data[key] = meter.value() 55 | return data 56 | 57 | def _remove_old_entries(self, 
data): 58 | rows = [] 59 | with self._csv_file_name.open('r') as f: 60 | reader = csv.DictReader(f) 61 | for row in reader: 62 | if float(row['episode']) >= data['episode']: 63 | break 64 | rows.append(row) 65 | with self._csv_file_name.open('w') as f: 66 | writer = csv.DictWriter(f, 67 | fieldnames=sorted(data.keys()), 68 | restval=0.0) 69 | writer.writeheader() 70 | for row in rows: 71 | writer.writerow(row) 72 | 73 | def _dump_to_csv(self, data): 74 | if self._csv_writer is None: 75 | should_write_header = True 76 | if self._csv_file_name.exists(): 77 | self._remove_old_entries(data) 78 | should_write_header = False 79 | 80 | self._csv_file = self._csv_file_name.open('a') 81 | self._csv_writer = csv.DictWriter(self._csv_file, 82 | fieldnames=sorted(data.keys()), 83 | restval=0.0) 84 | if should_write_header: 85 | self._csv_writer.writeheader() 86 | 87 | self._csv_writer.writerow(data) 88 | self._csv_file.flush() 89 | 90 | def _format(self, key, value, ty): 91 | if ty == 'int': 92 | value = int(value) 93 | return f'{key}: {value}' 94 | elif ty == 'float': 95 | return f'{key}: {value:.04f}' 96 | elif ty == 'time': 97 | value = str(datetime.timedelta(seconds=int(value))) 98 | return f'{key}: {value}' 99 | else: 100 | raise f'invalid format type: {ty}' 101 | 102 | def _dump_to_console(self, data, prefix): 103 | prefix = colored(prefix, 'yellow' if prefix == 'train' else 'green') 104 | pieces = [f'| {prefix: <14}'] 105 | for key, disp_key, ty in self._formating: 106 | value = data.get(key, 0) 107 | pieces.append(self._format(disp_key, value, ty)) 108 | print(' | '.join(pieces)) 109 | 110 | def dump(self, step, prefix): 111 | if len(self._meters) == 0: 112 | return 113 | data = self._prime_meters() 114 | data['frame'] = step 115 | self._dump_to_csv(data) 116 | self._dump_to_console(data, prefix) 117 | self._meters.clear() 118 | 119 | 120 | class Logger(object): 121 | def __init__(self, log_dir, use_tb, offline=False): 122 | self._log_dir = log_dir 123 | self._train_mg 
= MetersGroup(log_dir / 'train.csv', formating=TRAIN_FORMAT) 124 | self._eval_mg = MetersGroup(log_dir / 'eval.csv', formating=EVAL_FORMAT) 125 | if use_tb: 126 | self._sw = SummaryWriter(str(log_dir / 'tb')) 127 | else: 128 | self._sw = None 129 | 130 | def _try_sw_log(self, key, value, step): 131 | if self._sw is not None: 132 | self._sw.add_scalar(key, value, step) 133 | 134 | def log(self, key, value, step): 135 | assert key.startswith('train') or key.startswith('eval') 136 | if type(value) == torch.Tensor: 137 | value = value.item() 138 | self._try_sw_log(key, value, step) 139 | mg = self._train_mg if key.startswith('train') else self._eval_mg 140 | mg.log(key, value) 141 | 142 | def log_metrics(self, metrics, step, ty): 143 | for key, value in metrics.items(): 144 | self.log(f'{ty}/{key}', value, step) 145 | 146 | def dump(self, step, ty=None): 147 | if ty is None or ty == 'eval': 148 | self._eval_mg.dump(step, 'eval') 149 | if ty is None or ty == 'train': 150 | self._train_mg.dump(step, 'train') 151 | 152 | def log_and_dump_ctx(self, step, ty): 153 | return LogAndDumpCtx(self, step, ty) 154 | 155 | 156 | class LogAndDumpCtx: 157 | def __init__(self, logger, step, ty): 158 | self._logger = logger 159 | self._step = step 160 | self._ty = ty 161 | 162 | def __enter__(self): 163 | return self 164 | 165 | def __call__(self, key, value): 166 | self._logger.log(f'{self._ty}/{key}', value, self._step) 167 | 168 | def __exit__(self, *args): 169 | self._logger.dump(self._step, self._ty) 170 | -------------------------------------------------------------------------------- /custom_dmc_tasks/cheetah.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Cheetah Domain.""" 16 | 17 | import collections 18 | import os 19 | 20 | from dm_control import mujoco 21 | from dm_control.rl import control 22 | from dm_control.suite import base 23 | from dm_control.suite import common 24 | from dm_control.utils import containers 25 | from dm_control.utils import rewards 26 | from dm_control.utils import io as resources 27 | 28 | # How long the simulation will run, in seconds. 29 | _DEFAULT_TIME_LIMIT = 10 30 | 31 | # Running speed above which reward is 1. 
32 | _RUN_SPEED = 10 33 | _SPIN_SPEED = 5 34 | 35 | SUITE = containers.TaggedTasks() 36 | 37 | 38 | def make(task, 39 | task_kwargs=None, 40 | environment_kwargs=None, 41 | visualize_reward=False): 42 | task_kwargs = task_kwargs or {} 43 | if environment_kwargs is not None: 44 | task_kwargs = task_kwargs.copy() 45 | task_kwargs['environment_kwargs'] = environment_kwargs 46 | env = SUITE[task](**task_kwargs) 47 | env.task.visualize_reward = visualize_reward 48 | return env 49 | 50 | 51 | def get_model_and_assets(): 52 | """Returns a tuple containing the model XML string and a dict of assets.""" 53 | root_dir = os.path.dirname(os.path.dirname(__file__)) 54 | xml = resources.GetResource( 55 | os.path.join(root_dir, 'custom_dmc_tasks', 'cheetah.xml')) 56 | return xml, common.ASSETS 57 | 58 | 59 | @SUITE.add('benchmarking') 60 | def run(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None): 61 | """Returns the run task.""" 62 | physics = Physics.from_xml_string(*get_model_and_assets()) 63 | task = Cheetah(forward=True, flip=False, random=random) 64 | environment_kwargs = environment_kwargs or {} 65 | return control.Environment(physics, task, time_limit=time_limit, 66 | **environment_kwargs) 67 | 68 | 69 | @SUITE.add('benchmarking') 70 | def run_backward(time_limit=_DEFAULT_TIME_LIMIT, 71 | random=None, 72 | environment_kwargs=None): 73 | """Returns the run task.""" 74 | physics = Physics.from_xml_string(*get_model_and_assets()) 75 | task = Cheetah(forward=False, flip=False, random=random) 76 | environment_kwargs = environment_kwargs or {} 77 | return control.Environment(physics, 78 | task, 79 | time_limit=time_limit, 80 | **environment_kwargs) 81 | 82 | 83 | @SUITE.add('benchmarking') 84 | def flip(time_limit=_DEFAULT_TIME_LIMIT, 85 | random=None, 86 | environment_kwargs=None): 87 | """Returns the run task.""" 88 | physics = Physics.from_xml_string(*get_model_and_assets()) 89 | task = Cheetah(forward=True, flip=True, random=random) 90 | 
environment_kwargs = environment_kwargs or {} 91 | return control.Environment(physics, 92 | task, 93 | time_limit=time_limit, 94 | **environment_kwargs) 95 | 96 | 97 | @SUITE.add('benchmarking') 98 | def flip_backward(time_limit=_DEFAULT_TIME_LIMIT, 99 | random=None, 100 | environment_kwargs=None): 101 | """Returns the run task.""" 102 | physics = Physics.from_xml_string(*get_model_and_assets()) 103 | task = Cheetah(forward=False, flip=True, random=random) 104 | environment_kwargs = environment_kwargs or {} 105 | return control.Environment(physics, 106 | task, 107 | time_limit=time_limit, 108 | **environment_kwargs) 109 | 110 | 111 | class Physics(mujoco.Physics): 112 | """Physics simulation with additional features for the Cheetah domain.""" 113 | def speed(self): 114 | """Returns the horizontal speed of the Cheetah.""" 115 | return self.named.data.sensordata['torso_subtreelinvel'][0] 116 | 117 | def angmomentum(self): # angmomentum 是角动量,用于鼓励旋转 118 | """Returns the angular momentum of torso of the Cheetah about Y axis.""" 119 | return self.named.data.subtree_angmom['torso'][1] 120 | 121 | 122 | class Cheetah(base.Task): 123 | """A `Task` to train a running Cheetah.""" 124 | def __init__(self, forward=True, flip=False, random=None): 125 | self._forward = 1 if forward else -1 126 | self._flip = flip 127 | super(Cheetah, self).__init__(random=random) 128 | 129 | def initialize_episode(self, physics): 130 | """Sets the state of the environment at the start of each episode.""" 131 | # The indexing below assumes that all joints have a single DOF. 132 | assert physics.model.nq == physics.model.njnt 133 | is_limited = physics.model.jnt_limited == 1 134 | lower, upper = physics.model.jnt_range[is_limited].T 135 | physics.data.qpos[is_limited] = self.random.uniform(lower, upper) 136 | 137 | # Stabilize the model before the actual simulation. 
138 | for _ in range(200): 139 | physics.step() 140 | 141 | physics.data.time = 0 142 | self._timeout_progress = 0 143 | super().initialize_episode(physics) 144 | 145 | def get_observation(self, physics): 146 | """Returns an observation of the state, ignoring horizontal position.""" 147 | obs = collections.OrderedDict() 148 | # Ignores horizontal position to maintain translational invariance. 149 | obs['position'] = physics.data.qpos[1:].copy() 150 | obs['velocity'] = physics.velocity() 151 | return obs 152 | 153 | def get_reward(self, physics): 154 | # TODO: 这里增加了选择. 原始的奖励为 physics.speed() 155 | """Returns a reward to the agent.""" 156 | if self._flip: 157 | reward = rewards.tolerance(self._forward * physics.angmomentum(), 158 | bounds=(_SPIN_SPEED, float('inf')), 159 | margin=_SPIN_SPEED, 160 | value_at_margin=0, 161 | sigmoid='linear') 162 | 163 | else: 164 | reward = rewards.tolerance(self._forward * physics.speed(), 165 | bounds=(_RUN_SPEED, float('inf')), 166 | margin=_RUN_SPEED, 167 | value_at_margin=0, 168 | sigmoid='linear') 169 | return reward 170 | -------------------------------------------------------------------------------- /custom_dmc_tasks/point_mass_maze.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
# ============================================================================
"""Point-mass domain."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

from dm_control import mujoco
from dm_control.rl import control
from dm_control.suite import base
from dm_control.suite import common
from dm_control.suite.utils import randomizers
from dm_control.utils import containers
from dm_control.utils import rewards
from dm_control.utils import io as resources
from dm_env import specs
import numpy as np
import os

_DEFAULT_TIME_LIMIT = 20
SUITE = containers.TaggedTasks()


# (task name, target position) pairs. A task's index here is the `target_id`
# of its factory below, and its name selects point_mass_maze_<name>.xml.
TASKS = [('reach_top_left', np.array([-0.15, 0.15, 0.01])),
         ('reach_top_right', np.array([0.15, 0.15, 0.01])),
         ('reach_bottom_left', np.array([-0.15, -0.15, 0.01])),
         ('reach_bottom_right', np.array([0.15, -0.15, 0.01]))]


def make(task,
         task_kwargs=None,
         environment_kwargs=None,
         visualize_reward=False):
    """Instantiates a point-mass maze task registered in `SUITE`.

    Args:
      task: Name of a registered task factory, e.g. 'reach_top_left'.
      task_kwargs: Optional dict of keyword arguments for the factory.
      environment_kwargs: Optional dict handed to the factory as its
        `environment_kwargs` argument.
      visualize_reward: Assigned to `env.task.visualize_reward`.

    Returns:
      The environment returned by the task factory.
    """
    factory_kwargs = dict(task_kwargs or {})
    if environment_kwargs is not None:
        factory_kwargs['environment_kwargs'] = environment_kwargs
    env = SUITE[task](**factory_kwargs)
    env.task.visualize_reward = visualize_reward
    return env


def get_model_and_assets(task):
    """Returns a tuple containing the model XML string and a dict of assets.

    Args:
      task: Task name; selects custom_dmc_tasks/point_mass_maze_<task>.xml.
    """
    repo_root = os.path.dirname(os.path.dirname(__file__))
    xml_file = f'point_mass_maze_{task}.xml'
    xml = resources.GetResource(
        os.path.join(repo_root, 'custom_dmc_tasks', xml_file))
    return xml, common.ASSETS


def _reach_environment(target_id, time_limit, random, environment_kwargs):
    """Builds the environment for the reach task stored at `TASKS[target_id]`."""
    task_name = TASKS[target_id][0]
    physics = Physics.from_xml_string(*get_model_and_assets(task_name))
    task = MultiTaskPointMassMaze(target_id=target_id, random=random)
    return control.Environment(physics, task, time_limit=time_limit,
                               **(environment_kwargs or {}))


@SUITE.add('benchmarking')
def reach_top_left(time_limit=_DEFAULT_TIME_LIMIT,
                   random=None,
                   environment_kwargs=None):
    """Returns the task of reaching the top-left target."""
    return _reach_environment(0, time_limit, random, environment_kwargs)


@SUITE.add('benchmarking')
def reach_top_right(time_limit=_DEFAULT_TIME_LIMIT,
                    random=None,
                    environment_kwargs=None):
    """Returns the task of reaching the top-right target."""
    return _reach_environment(1, time_limit, random, environment_kwargs)


@SUITE.add('benchmarking')
def reach_bottom_left(time_limit=_DEFAULT_TIME_LIMIT,
                      random=None,
                      environment_kwargs=None):
    """Returns the task of reaching the bottom-left target."""
    return _reach_environment(2, time_limit, random, environment_kwargs)


@SUITE.add('benchmarking')
def reach_bottom_right(time_limit=_DEFAULT_TIME_LIMIT,
                       random=None,
                       environment_kwargs=None):
    """Returns the task of reaching the bottom-right target."""
    return _reach_environment(3, time_limit, random, environment_kwargs)


class Physics(mujoco.Physics):
    """Physics for the point-mass maze domain."""

    def mass_to_target_dist(self, target):
        """Returns the Euclidean distance from the point mass to `target`."""
        offset = target - self.named.data.geom_xpos['pointmass']
        return np.linalg.norm(offset)
130 | 131 | class MultiTaskPointMassMaze(base.Task): 132 | """A point_mass `Task` to reach target with smooth reward.""" 133 | def __init__(self, target_id, random=None): 134 | """Initialize an instance of `PointMassMaze`. 135 | 136 | Args: 137 | randomize_gains: A `bool`, whether to randomize the actuator gains. 138 | random: Optional, either a `numpy.random.RandomState` instance, an 139 | integer seed for creating a new `RandomState`, or None to select a seed 140 | automatically (default). 141 | """ 142 | self._target = TASKS[target_id][1] 143 | super().__init__(random=random) 144 | 145 | def initialize_episode(self, physics): 146 | """Sets the state of the environment at the start of each episode. 147 | 148 | If _randomize_gains is True, the relationship between the controls and 149 | the joints is randomized, so that each control actuates a random linear 150 | combination of joints. 151 | 152 | Args: 153 | physics: An instance of `mujoco.Physics`. 154 | """ 155 | randomizers.randomize_limited_and_rotational_joints( 156 | physics, self.random) 157 | physics.data.qpos[0] = np.random.uniform(-0.29, -0.15) 158 | physics.data.qpos[1] = np.random.uniform(0.15, 0.29) 159 | #import ipdb; ipdb.set_trace() 160 | physics.named.data.geom_xpos['target'][:] = self._target 161 | 162 | super().initialize_episode(physics) 163 | 164 | def get_observation(self, physics): 165 | """Returns an observation of the state.""" 166 | obs = collections.OrderedDict() 167 | obs['position'] = physics.position() 168 | obs['velocity'] = physics.velocity() 169 | return obs 170 | 171 | def get_reward_spec(self): 172 | return specs.Array(shape=(1,), dtype=np.float32, name='reward') 173 | 174 | def get_reward(self, physics): 175 | """Returns a reward to the agent.""" 176 | target_size = .015 177 | control_reward = rewards.tolerance(physics.control(), margin=1, 178 | value_at_margin=0, 179 | sigmoid='quadratic').mean() 180 | small_control = (control_reward + 4) / 5 181 | near_target = 
rewards.tolerance(physics.mass_to_target_dist(self._target), 182 | bounds=(0, target_size), margin=target_size) 183 | reward = near_target * small_control 184 | return reward 185 | -------------------------------------------------------------------------------- /custom_dmc_tasks/hopper.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Hopper domain.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import os 23 | 24 | from dm_control import mujoco 25 | from dm_control.rl import control 26 | from dm_control.suite import base 27 | from dm_control.suite import common 28 | from dm_control.suite.utils import randomizers 29 | from dm_control.utils import containers 30 | from dm_control.utils import rewards 31 | from dm_control.utils import io as resources 32 | import numpy as np 33 | 34 | SUITE = containers.TaggedTasks() 35 | 36 | _CONTROL_TIMESTEP = .02 # (Seconds) 37 | 38 | # Default duration of an episode, in seconds. 39 | _DEFAULT_TIME_LIMIT = 20 40 | 41 | # Minimal height of torso over foot above which stand reward is 1. 42 | _STAND_HEIGHT = 0.6 43 | 44 | # Hopping speed above which hop reward is 1. 
45 | _HOP_SPEED = 2 46 | _SPIN_SPEED = 5 47 | 48 | 49 | def make(task, 50 | task_kwargs=None, 51 | environment_kwargs=None, 52 | visualize_reward=False): 53 | task_kwargs = task_kwargs or {} 54 | if environment_kwargs is not None: 55 | task_kwargs = task_kwargs.copy() 56 | task_kwargs['environment_kwargs'] = environment_kwargs 57 | env = SUITE[task](**task_kwargs) 58 | env.task.visualize_reward = visualize_reward 59 | return env 60 | 61 | def get_model_and_assets(): 62 | """Returns a tuple containing the model XML string and a dict of assets.""" 63 | root_dir = os.path.dirname(os.path.dirname(__file__)) 64 | xml = resources.GetResource( 65 | os.path.join(root_dir, 'custom_dmc_tasks', 'hopper.xml')) 66 | return xml, common.ASSETS 67 | 68 | 69 | 70 | @SUITE.add('benchmarking') 71 | def hop_backward(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None): 72 | """Returns a Hopper that strives to hop forward.""" 73 | physics = Physics.from_xml_string(*get_model_and_assets()) 74 | task = Hopper(hopping=True, forward=False, flip=False, random=random) 75 | environment_kwargs = environment_kwargs or {} 76 | return control.Environment(physics, 77 | task, 78 | time_limit=time_limit, 79 | control_timestep=_CONTROL_TIMESTEP, 80 | **environment_kwargs) 81 | 82 | @SUITE.add('benchmarking') 83 | def flip(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None): 84 | """Returns a Hopper that strives to hop forward.""" 85 | physics = Physics.from_xml_string(*get_model_and_assets()) 86 | task = Hopper(hopping=True, forward=True, flip=True, random=random) 87 | environment_kwargs = environment_kwargs or {} 88 | return control.Environment(physics, 89 | task, 90 | time_limit=time_limit, 91 | control_timestep=_CONTROL_TIMESTEP, 92 | **environment_kwargs) 93 | 94 | @SUITE.add('benchmarking') 95 | def flip_backward(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None): 96 | """Returns a Hopper that strives to hop forward.""" 97 | physics = 
Physics.from_xml_string(*get_model_and_assets()) 98 | task = Hopper(hopping=True, forward=False, flip=True, random=random) 99 | environment_kwargs = environment_kwargs or {} 100 | return control.Environment(physics, 101 | task, 102 | time_limit=time_limit, 103 | control_timestep=_CONTROL_TIMESTEP, 104 | **environment_kwargs) 105 | 106 | 107 | class Physics(mujoco.Physics): 108 | """Physics simulation with additional features for the Hopper domain.""" 109 | def height(self): 110 | """Returns height of torso with respect to foot.""" 111 | return (self.named.data.xipos['torso', 'z'] - 112 | self.named.data.xipos['foot', 'z']) 113 | 114 | def speed(self): 115 | """Returns horizontal speed of the Hopper.""" 116 | return self.named.data.sensordata['torso_subtreelinvel'][0] 117 | 118 | def touch(self): 119 | """Returns the signals from two foot touch sensors.""" 120 | return np.log1p(self.named.data.sensordata[['touch_toe', 121 | 'touch_heel']]) 122 | 123 | def angmomentum(self): # TODO: 角动量 124 | """Returns the angular momentum of torso of the Cheetah about Y axis.""" 125 | return self.named.data.subtree_angmom['torso'][1] 126 | 127 | 128 | 129 | class Hopper(base.Task): 130 | """A Hopper's `Task` to train a standing and a jumping Hopper.""" 131 | def __init__(self, hopping, forward=True, flip=False, random=None): 132 | """Initialize an instance of `Hopper`. 133 | 134 | Args: 135 | hopping: Boolean, if True the task is to hop forwards, otherwise it is to 136 | balance upright. 137 | random: Optional, either a `numpy.random.RandomState` instance, an 138 | integer seed for creating a new `RandomState`, or None to select a seed 139 | automatically (default). 
140 | """ 141 | self._hopping = hopping 142 | self._forward = 1 if forward else -1 143 | self._flip = flip 144 | super(Hopper, self).__init__(random=random) 145 | 146 | def initialize_episode(self, physics): 147 | """Sets the state of the environment at the start of each episode.""" 148 | randomizers.randomize_limited_and_rotational_joints( 149 | physics, self.random) 150 | self._timeout_progress = 0 151 | super(Hopper, self).initialize_episode(physics) 152 | 153 | def get_observation(self, physics): 154 | """Returns an observation of positions, velocities and touch sensors.""" 155 | obs = collections.OrderedDict() 156 | # Ignores horizontal position to maintain translational invariance: 157 | obs['position'] = physics.data.qpos[1:].copy() 158 | obs['velocity'] = physics.velocity() 159 | obs['touch'] = physics.touch() 160 | return obs 161 | 162 | def get_reward(self, physics): 163 | """Returns a reward applicable to the performed task.""" 164 | standing = rewards.tolerance(physics.height(), (_STAND_HEIGHT, 2)) 165 | assert self._hopping 166 | if self._flip: 167 | hopping = rewards.tolerance(self._forward * physics.angmomentum(), 168 | bounds=(_SPIN_SPEED, float('inf')), 169 | margin=_SPIN_SPEED, 170 | value_at_margin=0, 171 | sigmoid='linear') 172 | else: 173 | hopping = rewards.tolerance(self._forward * physics.speed(), 174 | bounds=(_HOP_SPEED, float('inf')), 175 | margin=_HOP_SPEED / 2, 176 | value_at_margin=0.5, 177 | sigmoid='linear') 178 | return standing * hopping -------------------------------------------------------------------------------- /replay_buffer_collect.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import io 3 | import random 4 | import traceback 5 | from collections import defaultdict 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | from torch.utils.data import IterableDataset 11 | 12 | 13 | def episode_len(episode): 14 | # subtract -1 because the dummy 
first transition 15 | return next(iter(episode.values())).shape[0] - 1 16 | 17 | 18 | def save_episode(episode, fn): 19 | with io.BytesIO() as bs: 20 | np.savez_compressed(bs, **episode) 21 | bs.seek(0) 22 | with fn.open('wb') as f: 23 | f.write(bs.read()) 24 | 25 | 26 | def load_episode(fn): 27 | with fn.open('rb') as f: 28 | episode = np.load(f) 29 | episode = {k: episode[k] for k in episode.keys()} 30 | return episode 31 | 32 | 33 | class ReplayBufferStorage: 34 | def __init__(self, data_specs, meta_specs, replay_dir, dataset_dir): 35 | self._data_specs = data_specs 36 | self._meta_specs = meta_specs 37 | self._replay_dir = replay_dir 38 | replay_dir.mkdir(exist_ok=True) 39 | self._current_episode = defaultdict(list) 40 | self._preload() 41 | # TODO 42 | self._dataset_episodic = {"observation": [], "action": [], "reward": [], "discount": [], "physics": []} 43 | self._dataset_dir = dataset_dir 44 | dataset_dir.mkdir(exist_ok=True) 45 | 46 | def __len__(self): 47 | return self._num_transitions 48 | 49 | def add(self, time_step, meta, physics): 50 | self._dataset_episodic['physics'].append(physics) # add physics 51 | 52 | for key, value in meta.items(): 53 | self._current_episode[key].append(value) 54 | for spec in self._data_specs: 55 | value = time_step[spec.name] 56 | if np.isscalar(value): 57 | value = np.full(spec.shape, value, spec.dtype) 58 | assert spec.shape == value.shape and spec.dtype == value.dtype 59 | self._current_episode[spec.name].append(value) 60 | self._dataset_episodic[spec.name].append(value) # append value 61 | 62 | if time_step.last(): 63 | episode = dict() 64 | for spec in self._data_specs: 65 | value = self._current_episode[spec.name] 66 | episode[spec.name] = np.array(value, spec.dtype) 67 | for spec in self._meta_specs: 68 | value = self._current_episode[spec.name] 69 | episode[spec.name] = np.array(value, spec.dtype) 70 | self._current_episode = defaultdict(list) 71 | self._store_episode(episode) 72 | self._store_dataset() 73 | 74 | def 
_preload(self): 75 | self._num_episodes = 0 76 | self._num_transitions = 0 77 | for fn in self._replay_dir.glob('*.npz'): 78 | _, _, eps_len = fn.stem.split('_') 79 | self._num_episodes += 1 80 | self._num_transitions += int(eps_len) 81 | 82 | def _store_episode(self, episode): 83 | eps_idx = self._num_episodes 84 | eps_len = episode_len(episode) 85 | self._num_episodes += 1 86 | self._num_transitions += eps_len 87 | ts = datetime.datetime.now().strftime('%Y%m%dT%H%M%S') 88 | eps_fn = f'{ts}_{eps_idx}_{eps_len}.npz' 89 | save_episode(episode, self._replay_dir / eps_fn) 90 | 91 | def _store_dataset(self): # save the data 92 | # print("save data to", self._dataset_dir) 93 | dataset_episodic = dict() 94 | for key in self._dataset_episodic.keys(): 95 | value = self._dataset_episodic[key] 96 | dataset_episodic[key] = np.array(value, np.float32) 97 | # print("save:", key, np.array(value, np.float32).shape) 98 | # ts = datetime.datetime.now().strftime('%m%dT%H%M%S') 99 | episode_str = f"{self._num_episodes-1:06d}" 100 | eps_fn = f'episode_{episode_str}_{episode_len(dataset_episodic)}.npz' 101 | np.savez_compressed(self._dataset_dir / eps_fn, **dataset_episodic) 102 | 103 | for k in self._dataset_episodic.keys(): # clear 104 | self._dataset_episodic[k] = [] 105 | 106 | 107 | class ReplayBuffer(IterableDataset): 108 | def __init__(self, storage, max_size, num_workers, nstep, discount, 109 | fetch_every, save_snapshot): 110 | self._storage = storage 111 | self._size = 0 112 | self._max_size = max_size 113 | self._num_workers = max(1, num_workers) 114 | self._episode_fns = [] 115 | self._episodes = dict() 116 | self._nstep = nstep 117 | self._discount = discount 118 | self._fetch_every = fetch_every 119 | self._samples_since_last_fetch = fetch_every 120 | self._save_snapshot = save_snapshot 121 | 122 | def _sample_episode(self): 123 | eps_fn = random.choice(self._episode_fns) 124 | return self._episodes[eps_fn] 125 | 126 | def _store_episode(self, eps_fn): 127 | # store new 
episodic samples to self._episode_fn 128 | try: 129 | episode = load_episode(eps_fn) 130 | except: 131 | return False 132 | eps_len = episode_len(episode) # 1000 133 | while eps_len + self._size > self._max_size: 134 | early_eps_fn = self._episode_fns.pop(0) 135 | early_eps = self._episodes.pop(early_eps_fn) 136 | self._size -= episode_len(early_eps) 137 | early_eps_fn.unlink(missing_ok=True) 138 | self._episode_fns.append(eps_fn) # num of episodes / num of worker 139 | self._episode_fns.sort() 140 | self._episodes[eps_fn] = episode 141 | self._size += eps_len 142 | # print("store ", eps_len, episode.keys(), ", total:", len(self._episode_fns)) 143 | # print("self._storage._replay_dir.glob('*.npz')", self._storage._replay_dir.glob('*.npz')) 144 | 145 | if not self._save_snapshot: 146 | eps_fn.unlink(missing_ok=True) 147 | return True 148 | 149 | def _try_fetch(self): 150 | if self._samples_since_last_fetch < self._fetch_every: 151 | return 152 | self._samples_since_last_fetch = 0 153 | try: 154 | worker_id = torch.utils.data.get_worker_info().id 155 | except: 156 | worker_id = 0 157 | # print("\nworker id:", worker_id) 158 | eps_fns = sorted(self._storage._replay_dir.glob('*.npz'), reverse=True) 159 | # print("lens of eps_fns:", len(eps_fns)) 160 | fetched_size = 0 161 | for eps_fn in eps_fns: 162 | eps_idx, eps_len = [int(x) for x in eps_fn.stem.split('_')[1:]] 163 | if eps_idx % self._num_workers != worker_id: 164 | continue 165 | if eps_fn in self._episodes.keys(): 166 | break 167 | if fetched_size + eps_len > self._max_size: 168 | break 169 | fetched_size += eps_len 170 | # print("eps_fns:", eps_fn, fetched_size, eps_len) 171 | if not self._store_episode(eps_fn): 172 | break 173 | 174 | def _sample(self): 175 | try: 176 | self._try_fetch() 177 | except: 178 | traceback.print_exc() 179 | self._samples_since_last_fetch += 1 180 | episode = self._sample_episode() 181 | # add +1 for the first dummy transition 182 | idx = np.random.randint(0, episode_len(episode) - 
self._nstep + 1) + 1 183 | meta = [] 184 | for spec in self._storage._meta_specs: 185 | meta.append(episode[spec.name][idx - 1]) 186 | obs = episode['observation'][idx - 1] 187 | action = episode['action'][idx] 188 | next_obs = episode['observation'][idx + self._nstep - 1] 189 | reward = np.zeros_like(episode['reward'][idx]) 190 | discount = np.ones_like(episode['discount'][idx]) 191 | for i in range(self._nstep): 192 | step_reward = episode['reward'][idx + i] 193 | reward += discount * step_reward 194 | discount *= episode['discount'][idx + i] * self._discount 195 | return (obs, action, reward, discount, next_obs, *meta) 196 | 197 | def __iter__(self): 198 | while True: 199 | yield self._sample() 200 | 201 | 202 | def _worker_init_fn(worker_id): 203 | seed = np.random.get_state()[1][0] + worker_id 204 | np.random.seed(seed) 205 | random.seed(seed) 206 | 207 | 208 | def make_replay_loader(storage, max_size, batch_size, num_workers, 209 | save_snapshot, nstep, discount): 210 | max_size_per_worker = max_size // max(1, num_workers) 211 | 212 | iterable = ReplayBuffer(storage, 213 | max_size_per_worker, 214 | num_workers, 215 | nstep, 216 | discount, 217 | fetch_every=1000, 218 | save_snapshot=save_snapshot) 219 | 220 | loader = torch.utils.data.DataLoader(iterable, 221 | batch_size=batch_size, 222 | num_workers=num_workers, 223 | pin_memory=True, 224 | worker_init_fn=_worker_init_fn) 225 | return loader 226 | -------------------------------------------------------------------------------- /agent/crr.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from collections import OrderedDict 7 | from einops import repeat, rearrange 8 | 9 | import utils 10 | 11 | 12 | class Actor(nn.Module): 13 | def __init__(self, obs_dim, action_dim, hidden_dim): 14 | super().__init__() 15 | 16 | self.policy = 
nn.Sequential(nn.Linear(obs_dim, hidden_dim), 17 | nn.LayerNorm(hidden_dim), nn.Tanh(), 18 | nn.Linear(hidden_dim, hidden_dim), 19 | nn.ReLU(inplace=True), 20 | nn.Linear(hidden_dim, action_dim)) 21 | 22 | self.apply(utils.weight_init) 23 | 24 | def forward(self, obs, std): 25 | mu = self.policy(obs) 26 | mu = torch.tanh(mu) 27 | std = torch.ones_like(mu) * std 28 | 29 | dist = utils.TruncatedNormal(mu, std) 30 | return dist 31 | 32 | 33 | class Critic(nn.Module): 34 | def __init__(self, obs_dim, action_dim, hidden_dim): 35 | super().__init__() 36 | 37 | self.q1_net = nn.Sequential( 38 | nn.Linear(obs_dim + action_dim, hidden_dim), 39 | nn.LayerNorm(hidden_dim), nn.Tanh(), 40 | nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True), 41 | nn.Linear(hidden_dim, 1)) 42 | 43 | self.q2_net = nn.Sequential( 44 | nn.Linear(obs_dim + action_dim, hidden_dim), 45 | nn.LayerNorm(hidden_dim), nn.Tanh(), 46 | nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True), 47 | nn.Linear(hidden_dim, 1)) 48 | 49 | self.apply(utils.weight_init) 50 | 51 | def forward(self, obs, action): 52 | obs_action = torch.cat([obs, action], dim=-1) 53 | q1 = self.q1_net(obs_action) 54 | q2 = self.q2_net(obs_action) 55 | 56 | return q1, q2 57 | 58 | 59 | class CRRAgent: 60 | def __init__(self, 61 | name, 62 | obs_shape, 63 | action_shape, 64 | device, 65 | lr, 66 | hidden_dim, 67 | critic_target_tau, 68 | num_value_samples, 69 | weight_func, 70 | stddev_schedule, 71 | nstep, 72 | batch_size, 73 | stddev_clip, 74 | use_tb, 75 | has_next_action=False): 76 | self.action_dim = action_shape[0] 77 | self.hidden_dim = hidden_dim 78 | self.lr = lr 79 | self.device = device 80 | self.critic_target_tau = critic_target_tau 81 | self.use_tb = use_tb 82 | self.stddev_schedule = stddev_schedule 83 | self.stddev_clip = stddev_clip 84 | self.num_value_samples = num_value_samples 85 | self.weight_func = weight_func 86 | 87 | # models 88 | self.actor = Actor(obs_shape[0], action_shape[0], 89 | hidden_dim).to(device) 
90 | 91 | self.critic = Critic(obs_shape[0], action_shape[0], 92 | hidden_dim).to(device) 93 | self.critic_target = Critic(obs_shape[0], action_shape[0], 94 | hidden_dim).to(device) 95 | self.critic_target.load_state_dict(self.critic.state_dict()) 96 | 97 | # optimizers 98 | self.actor_opt = torch.optim.Adam(self.actor.parameters(), lr=lr) 99 | self.critic_opt = torch.optim.Adam(self.critic.parameters(), lr=lr) 100 | 101 | self.train() 102 | self.critic_target.train() 103 | 104 | def train(self, training=True): 105 | self.training = training 106 | self.actor.train(training) 107 | self.critic.train(training) 108 | 109 | def act(self, obs, step, eval_mode): 110 | obs = torch.as_tensor(obs, device=self.device).unsqueeze(0) 111 | stddev = utils.schedule(self.stddev_schedule, step) 112 | policy = self.actor(obs, stddev) 113 | if eval_mode: 114 | action = policy.mean 115 | else: 116 | action = policy.sample(clip=None) 117 | if step < self.num_expl_steps: 118 | action.uniform_(-1.0, 1.0) 119 | return action.cpu().numpy()[0] 120 | 121 | def compute_value(self, obs, step): 122 | obses = repeat(obs, 'b x -> (b n) x', n=self.num_value_samples) 123 | stddev = utils.schedule(self.stddev_schedule, step) 124 | dists = self.actor(obses, stddev) 125 | actions = dists.sample(clip=self.stddev_clip) 126 | Q1, Q2 = self.critic(obses, actions) 127 | Q = torch.min(Q1, Q2) 128 | V = rearrange(Q, '(b n) x -> b n x', 129 | n=self.num_value_samples).mean(dim=1) 130 | 131 | return V 132 | 133 | def adv_transform(self, A): 134 | assert self.weight_func in ['identity', 'indicator', 'exp'] 135 | if self.weight_func == 'identity': 136 | return A 137 | elif self.weight_func == 'indicator': 138 | return torch.sign(torch.relu(A)) 139 | elif self.weight_func == 'exp': 140 | return torch.clamp(A.exp(), 0, 20.0) 141 | else: 142 | assert False, f'wrong weight function: {self.weight_func}' 143 | 144 | def update_critic(self, obs, action, reward, discount, next_obs, step): 145 | metrics = dict() 146 | 147 
| with torch.no_grad(): 148 | stddev = utils.schedule(self.stddev_schedule, step) 149 | dist = self.actor(next_obs, stddev) 150 | next_action = dist.sample(clip=self.stddev_clip) 151 | target_Q1, target_Q2 = self.critic_target(next_obs, next_action) 152 | target_V = torch.min(target_Q1, target_Q2) 153 | target_Q = reward + (discount * target_V) 154 | 155 | Q1, Q2 = self.critic(obs, action) 156 | critic_loss = F.mse_loss(Q1, target_Q) + F.mse_loss(Q2, target_Q) 157 | 158 | if self.use_tb: 159 | metrics['critic_target_q'] = target_Q.mean().item() 160 | metrics['critic_q1'] = Q1.mean().item() 161 | metrics['critic_q2'] = Q2.mean().item() 162 | metrics['critic_loss'] = critic_loss.item() 163 | 164 | # optimize critic 165 | self.critic_opt.zero_grad(set_to_none=True) 166 | critic_loss.backward() 167 | self.critic_opt.step() 168 | return metrics 169 | 170 | def update_actor(self, obs, action, step): 171 | metrics = dict() 172 | 173 | metrics = dict() 174 | with torch.no_grad(): 175 | V = self.compute_value(obs, step) 176 | Q1, Q2 = self.critic(obs, action) 177 | Q = torch.min(Q1, Q2) 178 | A = Q - V 179 | w = self.adv_transform(A) 180 | 181 | stddev = utils.schedule(self.stddev_schedule, step) 182 | policy = self.actor(obs, stddev) 183 | 184 | log_prob = policy.log_prob(action).sum(-1, keepdim=True) 185 | actor_loss = -(log_prob * w).mean() 186 | 187 | # optimize actor 188 | self.actor_opt.zero_grad(set_to_none=True) 189 | actor_loss.backward() 190 | self.actor_opt.step() 191 | 192 | if self.use_tb: 193 | metrics['actor_loss'] = actor_loss.item() 194 | metrics['actor_ent'] = policy.entropy().sum(dim=-1).mean().item() 195 | 196 | return metrics 197 | 198 | def update(self, replay_iter, step): 199 | metrics = dict() 200 | 201 | batch = next(replay_iter) 202 | obs, action, reward, discount, next_obs = utils.to_torch( 203 | batch, self.device) 204 | 205 | if self.use_tb: 206 | metrics['batch_reward'] = reward.mean().item() 207 | 208 | # update critic 209 | metrics.update( 
210 | self.update_critic(obs, action, reward, discount, next_obs, step)) 211 | 212 | # update actor 213 | metrics.update(self.update_actor(obs, action, step)) 214 | 215 | # update critic target 216 | utils.soft_update_params(self.critic, self.critic_target, 217 | self.critic_target_tau) 218 | 219 | return metrics 220 | -------------------------------------------------------------------------------- /custom_dmc_tasks/jaco.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================ 15 | """A task where the goal is to move the hand close to a target prop or site.""" 16 | 17 | import collections 18 | 19 | from dm_control import composer 20 | from dm_control.composer import initializers 21 | from dm_control.composer.observation import observable 22 | from dm_control.composer.variation import distributions 23 | from dm_control.entities import props 24 | from dm_control.manipulation.shared import arenas 25 | from dm_control.manipulation.shared import cameras 26 | from dm_control.manipulation.shared import constants 27 | from dm_control.manipulation.shared import observations 28 | from dm_control.manipulation.shared import registry 29 | from dm_control.manipulation.shared import robots 30 | from dm_control.manipulation.shared import tags 31 | from dm_control.manipulation.shared import workspaces 32 | from dm_control.utils import rewards 33 | from dm_env import specs 34 | import numpy as np 35 | 36 | _ReachWorkspace = collections.namedtuple( 37 | '_ReachWorkspace', ['target_bbox', 'tcp_bbox', 'arm_offset']) 38 | 39 | # Ensures that the props are not touching the table before settling. 40 | _PROP_Z_OFFSET = 0.001 41 | 42 | _DUPLO_WORKSPACE = _ReachWorkspace( 43 | target_bbox=workspaces.BoundingBox(lower=(-0.1, -0.1, _PROP_Z_OFFSET), 44 | upper=(0.1, 0.1, _PROP_Z_OFFSET)), 45 | tcp_bbox=workspaces.BoundingBox(lower=(-0.1, -0.1, 0.2), 46 | upper=(0.1, 0.1, 0.4)), 47 | arm_offset=robots.ARM_OFFSET) 48 | 49 | _SITE_WORKSPACE = _ReachWorkspace( 50 | target_bbox=workspaces.BoundingBox(lower=(-0.2, -0.2, 0.02), 51 | upper=(0.2, 0.2, 0.4)), 52 | tcp_bbox=workspaces.BoundingBox(lower=(-0.2, -0.2, 0.02), 53 | upper=(0.2, 0.2, 0.4)), 54 | arm_offset=robots.ARM_OFFSET) 55 | 56 | _TARGET_RADIUS = 0.05 57 | _TIME_LIMIT = 10. 
58 | 59 | TASKS = [('reach_top_left', np.array([-0.09, 0.09, _PROP_Z_OFFSET])), 60 | ('reach_top_right', np.array([0.09, 0.09, _PROP_Z_OFFSET])), 61 | ('reach_bottom_left', np.array([-0.09, -0.09, _PROP_Z_OFFSET])), 62 | ('reach_bottom_right', np.array([0.09, -0.09, _PROP_Z_OFFSET]))] 63 | 64 | 65 | def make(task_id, obs_type, seed): 66 | obs_settings = observations.VISION if obs_type == 'pixels' else observations.PERFECT_FEATURES 67 | task = _reach(task_id, obs_settings=obs_settings, use_site=True) 68 | return composer.Environment(task, 69 | time_limit=_TIME_LIMIT, 70 | random_state=seed) 71 | 72 | 73 | class MultiTaskReach(composer.Task): 74 | """Bring the hand close to a target prop or site.""" 75 | def __init__(self, task_id, arena, arm, hand, prop, obs_settings, 76 | workspace, control_timestep): 77 | """Initializes a new `Reach` task. 78 | 79 | Args: 80 | arena: `composer.Entity` instance. 81 | arm: `robot_base.RobotArm` instance. 82 | hand: `robot_base.RobotHand` instance. 83 | prop: `composer.Entity` instance specifying the prop to reach to, or None 84 | in which case the target is a fixed site whose position is specified by 85 | the workspace. 86 | obs_settings: `observations.ObservationSettings` instance. 87 | workspace: `_ReachWorkspace` specifying the placement of the prop and TCP. 88 | control_timestep: Float specifying the control timestep in seconds. 89 | """ 90 | self._arena = arena 91 | self._arm = arm 92 | self._hand = hand 93 | self._arm.attach(self._hand) 94 | self._arena.attach_offset(self._arm, offset=workspace.arm_offset) 95 | self.control_timestep = control_timestep 96 | self._tcp_initializer = initializers.ToolCenterPointInitializer( 97 | self._hand, 98 | self._arm, 99 | position=distributions.Uniform(*workspace.tcp_bbox), 100 | quaternion=workspaces.DOWN_QUATERNION) 101 | 102 | # Add custom camera observable. 
103 | self._task_observables = cameras.add_camera_observables( 104 | arena, obs_settings, cameras.FRONT_CLOSE) 105 | 106 | if task_id == 'reach_multitask': 107 | self._targets = [target for (_, target) in TASKS] 108 | else: 109 | self._targets = [ 110 | target for (task, target) in TASKS if task == task_id 111 | ] 112 | 113 | # target_pos_distribution = distributions.Uniform(*TASKS[task_id]) 114 | self._prop = prop 115 | if prop: 116 | # The prop itself is used to visualize the target location. 117 | self._make_target_site(parent_entity=prop, visible=False) 118 | self._target = self._arena.add_free_entity(prop) 119 | self._prop_placer = initializers.PropPlacer( 120 | props=[prop], 121 | position=target_pos_distribution, 122 | quaternion=workspaces.uniform_z_rotation, 123 | settle_physics=True) 124 | else: 125 | if len(self._targets) == 1: 126 | self._target = self._make_target_site(parent_entity=arena, 127 | visible=True) 128 | 129 | # obs = observable.MJCFFeature('pos', self._target) 130 | # obs.configure(**obs_settings.prop_pose._asdict()) 131 | # self._task_observables['target_position'] = obs 132 | 133 | # Add sites for visualizing the prop and target bounding boxes. 
134 | workspaces.add_bbox_site(body=self.root_entity.mjcf_model.worldbody, 135 | lower=workspace.tcp_bbox.lower, 136 | upper=workspace.tcp_bbox.upper, 137 | rgba=constants.GREEN, 138 | name='tcp_spawn_area') 139 | workspaces.add_bbox_site(body=self.root_entity.mjcf_model.worldbody, 140 | lower=workspace.target_bbox.lower, 141 | upper=workspace.target_bbox.upper, 142 | rgba=constants.BLUE, 143 | name='target_spawn_area') 144 | 145 | def _make_target_site(self, parent_entity, visible): 146 | return workspaces.add_target_site( 147 | body=parent_entity.mjcf_model.worldbody, 148 | radius=_TARGET_RADIUS, 149 | visible=visible, 150 | rgba=constants.RED, 151 | name='target_site') 152 | 153 | @property 154 | def root_entity(self): 155 | return self._arena 156 | 157 | @property 158 | def arm(self): 159 | return self._arm 160 | 161 | @property 162 | def hand(self): 163 | return self._hand 164 | 165 | def get_reward_spec(self): 166 | n = len(self._targets) 167 | return specs.Array(shape=(n,), dtype=np.float32, name='reward') 168 | 169 | @property 170 | def task_observables(self): 171 | return self._task_observables 172 | 173 | def get_reward(self, physics): 174 | hand_pos = physics.bind(self._hand.tool_center_point).xpos 175 | rews = [] 176 | for target_pos in self._targets: 177 | distance = np.linalg.norm(hand_pos - target_pos) 178 | reward = rewards.tolerance(distance, 179 | bounds=(0, _TARGET_RADIUS), 180 | margin=_TARGET_RADIUS) 181 | rews.append(reward) 182 | rews = np.array(rews).astype(np.float32) 183 | if len(self._targets) == 1: 184 | return rews[0] 185 | return rews 186 | 187 | def initialize_episode(self, physics, random_state): 188 | self._hand.set_grasp(physics, close_factors=random_state.uniform()) 189 | self._tcp_initializer(physics, random_state) 190 | if self._prop: 191 | self._prop_placer(physics, random_state) 192 | else: 193 | if len(self._targets) == 1: 194 | physics.bind(self._target).pos = self._targets[0] 195 | 196 | 197 | def _reach(task_id, 
obs_settings, use_site): 198 | """Configure and instantiate a `Reach` task. 199 | 200 | Args: 201 | obs_settings: An `observations.ObservationSettings` instance. 202 | use_site: Boolean, if True then the target will be a fixed site, otherwise 203 | it will be a moveable Duplo brick. 204 | 205 | Returns: 206 | An instance of `reach.Reach`. 207 | """ 208 | arena = arenas.Standard() 209 | arm = robots.make_arm(obs_settings=obs_settings) 210 | hand = robots.make_hand(obs_settings=obs_settings) 211 | if use_site: 212 | workspace = _SITE_WORKSPACE 213 | prop = None 214 | else: 215 | workspace = _DUPLO_WORKSPACE 216 | prop = props.Duplo(observable_options=observations.make_options( 217 | obs_settings, observations.FREEPROP_OBSERVABLES)) 218 | task = MultiTaskReach(task_id, 219 | arena=arena, 220 | arm=arm, 221 | hand=hand, 222 | prop=prop, 223 | obs_settings=obs_settings, 224 | workspace=workspace, 225 | control_timestep=constants.CONTROL_TIMESTEP) 226 | return task 227 | -------------------------------------------------------------------------------- /collect_data.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import wandb 3 | warnings.filterwarnings('ignore', category=DeprecationWarning) 4 | 5 | import os 6 | 7 | os.environ['MKL_SERVICE_FORCE_INTEL'] = '1' 8 | os.environ['MUJOCO_GL'] = 'egl' 9 | import pickle 10 | from pathlib import Path 11 | 12 | import hydra 13 | import numpy as np 14 | import torch 15 | from dm_env import specs 16 | 17 | import dmc 18 | import utils 19 | from logger import Logger 20 | from replay_buffer_collect import ReplayBufferStorage, make_replay_loader 21 | from video import TrainVideoRecorder, VideoRecorder 22 | 23 | torch.backends.cudnn.benchmark = True 24 | 25 | 26 | def make_agent(obs_spec, action_spec, num_expl_steps, cfg): 27 | cfg.obs_shape = obs_spec.shape 28 | cfg.action_shape = action_spec.shape 29 | cfg.num_expl_steps = num_expl_steps 30 | return 
hydra.utils.instantiate(cfg) 31 | 32 | 33 | class Workspace: 34 | def __init__(self, cfg): 35 | self.work_dir = Path.cwd() 36 | print(f'workspace: {self.work_dir}') 37 | 38 | self.cfg = cfg 39 | utils.set_seed_everywhere(cfg.seed) 40 | self.device = torch.device(cfg.device) 41 | 42 | # create logger 43 | self.logger = Logger(self.work_dir, use_tb=cfg.use_tb) 44 | self.train_env = dmc.make(cfg.task, action_repeat=cfg.action_repeat, seed=cfg.seed) 45 | self.eval_env = dmc.make(cfg.task, action_repeat=cfg.action_repeat, seed=cfg.seed) 46 | 47 | # create agent 48 | self.agent = make_agent(self.train_env.observation_spec(), 49 | self.train_env.action_spec(), 50 | cfg.num_seed_frames // cfg.action_repeat, 51 | cfg.agent) 52 | 53 | # get meta specs 54 | meta_specs = self.agent.get_meta_specs() 55 | # create replay buffer 56 | data_specs = (self.train_env.observation_spec(), self.train_env.action_spec(), 57 | specs.Array((1,), np.float32, 'reward'), 58 | specs.Array((1,), np.float32, 'discount')) 59 | 60 | # create data storage 61 | self.replay_storage = ReplayBufferStorage(data_specs, meta_specs, 62 | replay_dir=self.work_dir / 'buffer', dataset_dir=self.work_dir / 'data') 63 | 64 | # create replay buffer 65 | self.replay_loader = make_replay_loader(self.replay_storage, cfg.replay_buffer_size, 66 | cfg.batch_size, cfg.replay_buffer_num_workers, False, cfg.nstep, cfg.discount) 67 | self._replay_iter = None 68 | 69 | # create video recorders 70 | self.video_recorder = VideoRecorder(self.work_dir if cfg.save_video else None) 71 | self.train_video_recorder = TrainVideoRecorder(self.work_dir if cfg.save_train_video else None) 72 | 73 | self.timer = utils.Timer() 74 | self._global_step = 0 75 | self._global_episode = 0 76 | 77 | # TODO: save agent 78 | self._agent_dir = self.work_dir / 'agent' 79 | self._agent_dir.mkdir(exist_ok=True) 80 | self.change_freq = True 81 | 82 | @property 83 | def global_step(self): 84 | return self._global_step 85 | 86 | @property 87 | def 
global_episode(self): 88 | return self._global_episode 89 | 90 | @property 91 | def global_frame(self): 92 | return self.global_step * self.cfg.action_repeat 93 | 94 | @property 95 | def replay_iter(self): 96 | if self._replay_iter is None: 97 | self._replay_iter = iter(self.replay_loader) 98 | return self._replay_iter 99 | 100 | def eval(self): 101 | step, episode, total_reward = 0, 0, 0 102 | eval_until_episode = utils.Until(self.cfg.num_eval_episodes) 103 | meta = self.agent.init_meta() 104 | while eval_until_episode(episode): # eval 10 episodes 105 | time_step = self.eval_env.reset() 106 | self.video_recorder.init(self.eval_env, enabled=(episode == 0)) 107 | while not time_step.last(): 108 | with torch.no_grad(), utils.eval_mode(self.agent): 109 | action = self.agent.act(time_step.observation, step=self.global_step, eval_mode=True, meta=meta) 110 | time_step = self.eval_env.step(action) 111 | self.video_recorder.record(self.eval_env) 112 | total_reward += time_step.reward 113 | step += 1 114 | episode += 1 115 | self.video_recorder.save(f'{self.global_frame}.mp4') 116 | 117 | with self.logger.log_and_dump_ctx(self.global_frame, ty='eval') as log: 118 | log('episode_reward', total_reward / episode) 119 | log('episode_length', step * self.cfg.action_repeat / episode) 120 | log('episode', self.global_episode) 121 | log('step', self.global_step) 122 | if self.cfg.use_wandb: 123 | wandb.log({"eval_return": total_reward / episode}) 124 | 125 | return total_reward / episode 126 | 127 | def train(self): 128 | train_until_step = utils.Until(self.cfg.num_train_frames, self.cfg.action_repeat) 129 | seed_until_step = utils.Until(self.cfg.num_seed_frames, self.cfg.action_repeat) 130 | eval_every_step = utils.Every(self.cfg.eval_every_frames, self.cfg.action_repeat) 131 | save_every_step = utils.Every(self.cfg.eval_every_frames, self.cfg.action_repeat) # TODO: save agent 132 | 133 | episode_step, episode_reward = 0, 0 134 | time_step = self.train_env.reset() 135 | meta = 
self.agent.init_meta() 136 | self.replay_storage.add(time_step, meta, physics=self.train_env.physics.get_state()) # 这里加入了physics信息 137 | self.train_video_recorder.init(time_step.observation) 138 | metrics = None 139 | eval_rew = 0 140 | while train_until_step(self.global_step): 141 | if time_step.last(): 142 | self._global_episode += 1 143 | self.train_video_recorder.save(f'{self.global_frame}.mp4') 144 | # wait until all the metrics schema is populated 145 | if metrics is not None: 146 | # log stats 147 | elapsed_time, total_time = self.timer.reset() 148 | episode_frame = episode_step * self.cfg.action_repeat 149 | with self.logger.log_and_dump_ctx(self.global_frame, ty='train') as log: 150 | log('fps', episode_frame / elapsed_time) 151 | log('total_time', total_time) 152 | log('episode_reward', episode_reward) 153 | log('episode_length', episode_frame) 154 | log('episode', self.global_episode) 155 | log('buffer_size', len(self.replay_storage)) 156 | log('step', self.global_step) 157 | 158 | # reset env 159 | time_step = self.train_env.reset() 160 | meta = self.agent.init_meta() 161 | self.replay_storage.add(time_step, meta, physics=self.train_env.physics.get_state()) 162 | self.train_video_recorder.init(time_step.observation) 163 | 164 | episode_step = 0 165 | episode_reward = 0 166 | 167 | # try to evaluate 168 | if eval_every_step(self.global_step): 169 | self.logger.log('eval_total_time', self.timer.total_time(), self.global_frame) 170 | eval_rew = self.eval() 171 | 172 | # TODO: save policy 173 | if save_every_step(self.global_step): 174 | agent_stamp = self._agent_dir / f'agent-{int(self.global_step/1000)}K-{round(eval_rew, 2)}.pkl' 175 | with open(str(agent_stamp), 'wb') as f_agent: 176 | pickle.dump(self.agent, f_agent) 177 | print("Save agent to", agent_stamp) 178 | if self.global_step >= 200000 and self.change_freq: 179 | save_every_step.change_every(freq=10) # decrease the freq 180 | self.change_freq = False 181 | 182 | meta = 
self.agent.update_meta(meta, self.global_step, time_step) 183 | if hasattr(self.agent, "regress_meta"): 184 | repeat = self.cfg.action_repeat 185 | every = self.agent.update_task_every_step // repeat 186 | init_step = self.agent.num_init_steps 187 | if self.global_step > (init_step // repeat) and self.global_step % every == 0: 188 | meta = self.agent.regress_meta(self.replay_iter, self.global_step) 189 | 190 | # sample action 191 | with torch.no_grad(), utils.eval_mode(self.agent): 192 | action = self.agent.act(time_step.observation, 193 | meta=meta, step=self.global_step, eval_mode=False) 194 | 195 | # try to update the agent 196 | if not seed_until_step(self.global_step): 197 | metrics = self.agent.update(self.replay_iter, self.global_step) 198 | self.logger.log_metrics(metrics, self.global_frame, ty='train') 199 | if self.cfg.use_wandb: 200 | wandb.log(metrics) 201 | 202 | # take env step 203 | time_step = self.train_env.step(action) 204 | episode_reward += time_step.reward 205 | self.replay_storage.add(time_step, meta, physics=self.train_env.physics.get_state()) 206 | self.train_video_recorder.record(time_step.observation) 207 | episode_step += 1 208 | self._global_step += 1 209 | 210 | 211 | @hydra.main(config_path='.', config_name='collect_data') 212 | def main(cfg): 213 | from collect_data import Workspace as W 214 | root_dir = Path.cwd() 215 | workspace = W(cfg) 216 | snapshot = root_dir / 'snapshot.pt' 217 | if snapshot.exists(): 218 | print(f'resuming: {snapshot}') 219 | workspace.load_snapshot() 220 | 221 | if cfg.use_wandb: 222 | wandb_dir = f"./wandb/collect_{cfg.task}_{cfg.agent.name}_{cfg.seed}" 223 | if not os.path.exists(wandb_dir): 224 | os.makedirs(wandb_dir) 225 | wandb.init(project="UTDS", entity='', config=cfg, group=f'{cfg.task}_{cfg.agent.name}', 226 | name=f'{cfg.task}_{cfg.agent.name}', dir=wandb_dir) 227 | wandb.config.update(vars(cfg)) 228 | 229 | workspace.train() 230 | 231 | 232 | if __name__ == '__main__': 233 | main() 234 | 
# --------------------------------------------------------------------------------
# /agent/ddpg.py
# --------------------------------------------------------------------------------
from collections import OrderedDict

import hydra
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import utils


class Encoder(nn.Module):
    """Convolutional encoder for pixel observations.

    Four 3x3 conv layers (the first with stride 2) followed by a flatten.
    ``repr_dim`` is hard-coded to ``32 * 35 * 35``, which is the flattened
    output size for 84x84 inputs (84 -> 41 -> 39 -> 37 -> 35); other input
    resolutions will not match it.

    Args:
        obs_shape: observation shape; must have length 3 and its first entry
            is used as the number of input channels.
    """

    def __init__(self, obs_shape):
        super().__init__()

        assert len(obs_shape) == 3
        self.repr_dim = 32 * 35 * 35

        self.convnet = nn.Sequential(nn.Conv2d(obs_shape[0], 32, 3, stride=2),
                                     nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1),
                                     nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1),
                                     nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1),
                                     nn.ReLU())

        self.apply(utils.weight_init)

    def forward(self, obs):
        """Encode a batch of images into flat feature vectors."""
        # Map raw pixel values in [0, 255] to roughly [-0.5, 0.5].
        obs = obs / 255.0 - 0.5
        h = self.convnet(obs)
        h = h.view(h.shape[0], -1)
        return h


class Actor(nn.Module):
    """Policy network returning a ``utils.TruncatedNormal`` action distribution.

    The mean is ``tanh`` of the MLP output; the standard deviation is not
    learned but supplied by the caller on every forward pass.

    Args:
        obs_type: ``'pixels'`` adds a trunk of width ``feature_dim`` and an
            extra hidden layer; any other value uses ``hidden_dim`` for the
            trunk width.
        obs_dim: input dimensionality (encoder output plus any meta inputs).
        action_dim: number of action dimensions.
        feature_dim: trunk width for pixel observations.
        hidden_dim: width of the hidden layers.
    """

    def __init__(self, obs_type, obs_dim, action_dim, feature_dim, hidden_dim):
        super().__init__()

        feature_dim = feature_dim if obs_type == 'pixels' else hidden_dim

        self.trunk = nn.Sequential(nn.Linear(obs_dim, feature_dim),
                                   nn.LayerNorm(feature_dim), nn.Tanh())

        policy_layers = []
        policy_layers += [
            nn.Linear(feature_dim, hidden_dim),
            nn.ReLU(inplace=True)
        ]
        # add additional hidden layer for pixels
        if obs_type == 'pixels':
            policy_layers += [
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(inplace=True)
            ]
        policy_layers += [nn.Linear(hidden_dim, action_dim)]

        self.policy = nn.Sequential(*policy_layers)

        self.apply(utils.weight_init)

    def forward(self, obs, std):
        """Return the action distribution for ``obs`` with constant ``std``."""
        h = self.trunk(obs)

        mu = self.policy(h)
        mu = torch.tanh(mu)
        std = torch.ones_like(mu) * std

        dist = utils.TruncatedNormal(mu, std)
        return dist


class Critic(nn.Module):
    """Twin Q-network returning ``(q1, q2)``.

    For pixel observations the action is concatenated *after* the shared
    trunk; for state observations it is concatenated with the observation
    *before* the trunk.
    """

    def __init__(self, obs_type, obs_dim, action_dim, feature_dim, hidden_dim):
        super().__init__()

        self.obs_type = obs_type

        if obs_type == 'pixels':
            # for pixels actions will be added after trunk
            self.trunk = nn.Sequential(nn.Linear(obs_dim, feature_dim),
                                       nn.LayerNorm(feature_dim), nn.Tanh())
            trunk_dim = feature_dim + action_dim
        else:
            # for states actions come in the beginning
            self.trunk = nn.Sequential(
                nn.Linear(obs_dim + action_dim, hidden_dim),
                nn.LayerNorm(hidden_dim), nn.Tanh())
            trunk_dim = hidden_dim

        def make_q():
            # One Q head; pixel observations get an extra hidden layer.
            q_layers = []
            q_layers += [
                nn.Linear(trunk_dim, hidden_dim),
                nn.ReLU(inplace=True)
            ]
            if obs_type == 'pixels':
                q_layers += [
                    nn.Linear(hidden_dim, hidden_dim),
                    nn.ReLU(inplace=True)
                ]
            q_layers += [nn.Linear(hidden_dim, 1)]
            return nn.Sequential(*q_layers)

        self.Q1 = make_q()
        self.Q2 = make_q()

        self.apply(utils.weight_init)

    def forward(self, obs, action):
        """Return both Q estimates for the (obs, action) batch."""
        inpt = obs if self.obs_type == 'pixels' else torch.cat([obs, action], dim=-1)
        h = self.trunk(inpt)
        h = torch.cat([h, action], dim=-1) if self.obs_type == 'pixels' else h

        q1 = self.Q1(h)
        q2 = self.Q2(h)

        return q1, q2


class DDPGAgent:
    """DDPG-style agent with twin critics and a scheduled-std stochastic actor.

    Args:
        name: agent name (unused here; kept for config compatibility).
        obs_type: ``'pixels'`` enables the conv encoder and random-shift
            augmentation; anything else uses identity modules.
        obs_shape: observation shape.
        action_shape: action shape; ``action_shape[0]`` is the action dim.
        device: torch device for all modules.
        lr: Adam learning rate for encoder, actor and critic.
        feature_dim: trunk width for pixel observations.
        hidden_dim: hidden layer width.
        critic_target_tau: coefficient passed to ``utils.soft_update_params``.
        num_expl_steps: before this step, non-eval actions are replaced by
            uniform noise in [-1, 1].
        update_every_steps: ``update`` is a no-op unless
            ``step % update_every_steps == 0``.
        stddev_schedule: passed to ``utils.schedule`` to get the actor std.
        nstep: unused here; kept for config compatibility.
        batch_size: unused here; kept for config compatibility.
        stddev_clip: noise clip used when sampling actions during updates.
        init_critic: if True, ``init_from`` also copies the critic trunk.
        use_tb, use_wandb: if either is set, update methods return metrics.
        update_encoder: if False, encoder outputs are detached before the
            critic update so the encoder receives no gradient.
        meta_dim: size of extra "meta" inputs appended to the encoded obs.
    """

    def __init__(self,
                 name,
                 obs_type,
                 obs_shape,
                 action_shape,
                 device,
                 lr,
                 feature_dim,
                 hidden_dim,
                 critic_target_tau,
                 num_expl_steps,
                 update_every_steps,
                 stddev_schedule,
                 nstep,
                 batch_size,
                 stddev_clip,
                 init_critic,
                 use_tb,
                 use_wandb,
                 update_encoder,
                 meta_dim=0):
        self.update_encoder = update_encoder
        self.obs_type = obs_type
        self.obs_shape = obs_shape
        self.action_dim = action_shape[0]
        self.hidden_dim = hidden_dim
        self.lr = lr
        self.device = device
        self.critic_target_tau = critic_target_tau
        self.update_every_steps = update_every_steps
        self.use_tb = use_tb
        self.use_wandb = use_wandb
        self.num_expl_steps = num_expl_steps
        self.stddev_schedule = stddev_schedule
        self.stddev_clip = stddev_clip
        self.init_critic = init_critic
        self.feature_dim = feature_dim
        self.solved_meta = None

        # models
        if obs_type == 'pixels':
            self.aug = utils.RandomShiftsAug(pad=4)
            self.encoder = Encoder(obs_shape).to(device)
            self.obs_dim = self.encoder.repr_dim + meta_dim
        else:
            self.aug = nn.Identity()
            self.encoder = nn.Identity()
            self.obs_dim = obs_shape[0] + meta_dim

        self.actor = Actor(obs_type, self.obs_dim, self.action_dim,
                           feature_dim, hidden_dim).to(device)

        self.critic = Critic(obs_type, self.obs_dim, self.action_dim,
                             feature_dim, hidden_dim).to(device)
        self.critic_target = Critic(obs_type, self.obs_dim, self.action_dim,
                                    feature_dim, hidden_dim).to(device)
        self.critic_target.load_state_dict(self.critic.state_dict())

        # optimizers (the identity encoder has no parameters to optimize)
        if obs_type == 'pixels':
            self.encoder_opt = torch.optim.Adam(self.encoder.parameters(), lr=lr)
        else:
            self.encoder_opt = None
        self.actor_opt = torch.optim.Adam(self.actor.parameters(), lr=lr)
        self.critic_opt = torch.optim.Adam(self.critic.parameters(), lr=lr)

        self.train()
        self.critic_target.train()

    def train(self, training=True):
        """Set train/eval mode on the encoder, actor and critic."""
        self.training = training
        self.encoder.train(training)
        self.actor.train(training)
        self.critic.train(training)

    def init_from(self, other):
        """Copy encoder and actor (and optionally critic trunk) from ``other``."""
        utils.hard_update_params(other.encoder, self.encoder)
        utils.hard_update_params(other.actor, self.actor)
        if self.init_critic:
            utils.hard_update_params(other.critic.trunk, self.critic.trunk)

    def get_meta_specs(self):
        """This agent uses no meta inputs."""
        return tuple()

    def init_meta(self):
        """Return an empty meta dict."""
        return OrderedDict()

    def update_meta(self, meta, global_step, time_step, finetune=False):
        """Meta is static for this agent; return it unchanged."""
        return meta

    def act(self, obs, meta, step, eval_mode):
        """Select an action for a single (unbatched) observation.

        Args:
            obs: one observation (array-like); a batch dim is added here.
            meta: mapping whose values are appended to the encoded obs.
            step: global step, used for the std schedule and exploration.
            eval_mode: if True, return the distribution mean.

        Returns:
            A NumPy array with the action for this observation.
        """
        # Pure inference: without no_grad the action tensor carries autograd
        # history and ``.numpy()`` would raise unless every caller wrapped
        # this call in torch.no_grad() themselves.
        with torch.no_grad():
            obs = torch.as_tensor(obs, device=self.device).unsqueeze(0)
            h = self.encoder(obs)
            inputs = [h]
            for value in meta.values():
                value = torch.as_tensor(value, device=self.device).unsqueeze(0)
                inputs.append(value)
            inpt = torch.cat(inputs, dim=-1)
            stddev = utils.schedule(self.stddev_schedule, step)
            dist = self.actor(inpt, stddev)
            if eval_mode:
                action = dist.mean
            else:
                action = dist.sample(clip=None)
                if step < self.num_expl_steps:
                    # Early exploration: overwrite with uniform random actions.
                    action.uniform_(-1.0, 1.0)
        return action.cpu().numpy()[0]

    def update_critic(self, obs, action, reward, discount, next_obs, step):
        """One TD update of both critics (and the encoder, if present).

        The target is ``reward + discount * min(Q1', Q2')`` where the next
        action is sampled from the current actor with clipped noise and the
        Q values come from ``critic_target``.
        NOTE(review): ``discount`` is assumed to already include gamma and
        any termination masking — confirm against the replay buffer.
        """
        metrics = dict()

        with torch.no_grad():
            stddev = utils.schedule(self.stddev_schedule, step)
            dist = self.actor(next_obs, stddev)
            next_action = dist.sample(clip=self.stddev_clip)
            target_Q1, target_Q2 = self.critic_target(next_obs, next_action)
            target_V = torch.min(target_Q1, target_Q2)
            target_Q = reward + (discount * target_V)

        Q1, Q2 = self.critic(obs, action)
        critic_loss = F.mse_loss(Q1, target_Q) + F.mse_loss(Q2, target_Q)

        if self.use_tb or self.use_wandb:
            metrics['critic_target_q'] = target_Q.mean().item()
            metrics['critic_q1'] = Q1.mean().item()
            metrics['critic_q2'] = Q2.mean().item()
            metrics['critic_loss'] = critic_loss.item()

        # optimize critic (and encoder, whose gradient flows through obs)
        if self.encoder_opt is not None:
            self.encoder_opt.zero_grad(set_to_none=True)
        self.critic_opt.zero_grad(set_to_none=True)
        critic_loss.backward()
        self.critic_opt.step()
        if self.encoder_opt is not None:
            self.encoder_opt.step()
        return metrics

    def update_actor(self, obs, step):
        """One actor update maximizing ``min(Q1, Q2)`` of sampled actions."""
        metrics = dict()

        stddev = utils.schedule(self.stddev_schedule, step)
        dist = self.actor(obs, stddev)
        action = dist.sample(clip=self.stddev_clip)
        Q1, Q2 = self.critic(obs, action)
        Q = torch.min(Q1, Q2)

        actor_loss = -Q.mean()

        # optimize actor
        self.actor_opt.zero_grad(set_to_none=True)
        actor_loss.backward()
        self.actor_opt.step()

        if self.use_tb or self.use_wandb:
            # Diagnostics only: computed without autograd and only when they
            # are reported. The values match computing them before the step,
            # because ``dist`` and ``action`` were produced before it.
            with torch.no_grad():
                log_prob = dist.log_prob(action).sum(-1, keepdim=True)
                metrics['actor_loss'] = actor_loss.item()
                metrics['actor_logprob'] = log_prob.mean().item()
                metrics['actor_ent'] = dist.entropy().sum(dim=-1).mean().item()

        return metrics

    def aug_and_encode(self, obs):
        """Apply augmentation (identity for states) and then the encoder."""
        obs = self.aug(obs)
        return self.encoder(obs)

    def update(self, replay_iter, step):
        """Run one critic, actor and target update on a replay batch.

        Returns an empty dict without sampling when ``step`` is not a
        multiple of ``update_every_steps``.
        """
        metrics = dict()

        if step % self.update_every_steps != 0:
            return metrics

        batch = next(replay_iter)
        obs, action, reward, discount, next_obs = utils.to_torch(batch, self.device)

        # augment and encode
        obs = self.aug_and_encode(obs)
        with torch.no_grad():
            next_obs = self.aug_and_encode(next_obs)

        if self.use_tb or self.use_wandb:
            metrics['batch_reward'] = reward.mean().item()

        if not self.update_encoder:
            obs = obs.detach()
            next_obs = next_obs.detach()

        # update critic
        metrics.update(
            self.update_critic(obs, action, reward, discount, next_obs, step))

        # update actor (encoder features are detached: only the critic trains it)
        metrics.update(self.update_actor(obs.detach(), step))

        # update critic target
        utils.soft_update_params(self.critic, self.critic_target,
                                 self.critic_target_tau)

        return metrics

-------------------------------------------------------------------------------- /custom_dmc_tasks/walker.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Planar Walker Domain.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import os 23 | 24 | from dm_control import mujoco 25 | from dm_control.rl import control 26 | from dm_control.suite import base 27 | from dm_control.suite import common 28 | from dm_control.suite.utils import randomizers 29 | from dm_control.utils import containers 30 | from dm_control.utils import rewards 31 | from dm_control.utils import io as resources 32 | from dm_control import suite 33 | from dm_env import specs 34 | 35 | import numpy as np 36 | 37 | _DEFAULT_TIME_LIMIT = 25 38 | _CONTROL_TIMESTEP = .025 39 | 40 | # Minimal height of torso over foot above which stand reward is 1. 41 | _STAND_HEIGHT = 1.2 42 | 43 | # Horizontal speeds (meters/second) above which move reward is 1. 
44 | _WALK_SPEED = 1 45 | _RUN_SPEED = 8 46 | _SPIN_SPEED = 5 47 | 48 | SUITE = containers.TaggedTasks() 49 | 50 | 51 | def make(task, 52 | task_kwargs=None, 53 | environment_kwargs=None, 54 | visualize_reward=False): 55 | task_kwargs = task_kwargs or {} 56 | if environment_kwargs is not None: 57 | task_kwargs = task_kwargs.copy() 58 | task_kwargs['environment_kwargs'] = environment_kwargs 59 | env = SUITE[task](**task_kwargs) 60 | env.task.visualize_reward = visualize_reward 61 | return env 62 | 63 | 64 | def get_model_and_assets(): 65 | """Returns a tuple containing the model XML string and a dict of assets.""" 66 | root_dir = os.path.dirname(os.path.dirname(__file__)) 67 | xml = resources.GetResource( 68 | os.path.join(root_dir, 'custom_dmc_tasks', 'walker.xml')) 69 | return xml, common.ASSETS 70 | 71 | 72 | @SUITE.add('benchmarking') 73 | def flip(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None): 74 | """Returns the Run task.""" 75 | physics = Physics.from_xml_string(*get_model_and_assets()) 76 | task = PlanarWalker(move_speed=_RUN_SPEED, flip=True, random=random) 77 | environment_kwargs = environment_kwargs or {} 78 | return control.Environment(physics, 79 | task, 80 | time_limit=time_limit, 81 | control_timestep=_CONTROL_TIMESTEP, 82 | **environment_kwargs) 83 | 84 | 85 | @SUITE.add('benchmarking') 86 | def multitask(time_limit=_DEFAULT_TIME_LIMIT, 87 | random=None, 88 | environment_kwargs=None): 89 | """Returns the Run task.""" 90 | physics = Physics.from_xml_string(*get_model_and_assets()) 91 | task = MultiTaskPlanarWalker(random=random) 92 | environment_kwargs = environment_kwargs or {} 93 | return control.Environment(physics, 94 | task, 95 | time_limit=time_limit, 96 | control_timestep=_CONTROL_TIMESTEP, 97 | **environment_kwargs) 98 | 99 | 100 | class Physics(mujoco.Physics): 101 | """Physics simulation with additional features for the Walker domain.""" 102 | def torso_upright(self): 103 | """Returns projection from z-axes of torso 
to the z-axes of world.""" 104 | return self.named.data.xmat['torso', 'zz'] 105 | 106 | def torso_height(self): 107 | """Returns the height of the torso.""" 108 | return self.named.data.xpos['torso', 'z'] 109 | 110 | def horizontal_velocity(self): 111 | """Returns the horizontal velocity of the center-of-mass.""" 112 | return self.named.data.sensordata['torso_subtreelinvel'][0] 113 | 114 | def orientations(self): 115 | """Returns planar orientations of all bodies.""" 116 | return self.named.data.xmat[1:, ['xx', 'xz']].ravel() 117 | 118 | def angmomentum(self): 119 | """Returns the angular momentum of torso of the Cheetah about Y axis.""" 120 | return self.named.data.subtree_angmom['torso'][1] 121 | 122 | 123 | class PlanarWalker(base.Task): 124 | """A planar walker task.""" 125 | def __init__(self, move_speed, flip=False, random=None): 126 | """Initializes an instance of `PlanarWalker`. 127 | 128 | Args: 129 | move_speed: A float. If this value is zero, reward is given simply for 130 | standing up. Otherwise this specifies a target horizontal velocity for 131 | the walking task. 132 | random: Optional, either a `numpy.random.RandomState` instance, an 133 | integer seed for creating a new `RandomState`, or None to select a seed 134 | automatically (default). 135 | """ 136 | self._move_speed = move_speed 137 | self._flip = flip 138 | super().__init__(random=random) 139 | 140 | def initialize_episode(self, physics): 141 | """Sets the state of the environment at the start of each episode. 142 | 143 | In 'standing' mode, use initial orientation and small velocities. 144 | In 'random' mode, randomize joint angles and let fall to the floor. 145 | 146 | Args: 147 | physics: An instance of `Physics`. 
148 | 149 | """ 150 | randomizers.randomize_limited_and_rotational_joints( 151 | physics, self.random) 152 | super().initialize_episode(physics) 153 | 154 | def get_observation(self, physics): 155 | """Returns an observation of body orientations, height and velocites.""" 156 | obs = collections.OrderedDict() 157 | obs['orientations'] = physics.orientations() 158 | obs['height'] = physics.torso_height() 159 | obs['velocity'] = physics.velocity() 160 | return obs 161 | 162 | def get_reward(self, physics): 163 | """Returns a reward to the agent.""" 164 | standing = rewards.tolerance(physics.torso_height(), 165 | bounds=(_STAND_HEIGHT, float('inf')), 166 | margin=_STAND_HEIGHT / 2) 167 | upright = (1 + physics.torso_upright()) / 2 168 | stand_reward = (3 * standing + upright) / 4 169 | 170 | if self._flip: 171 | move_reward = rewards.tolerance(physics.angmomentum(), 172 | bounds=(_SPIN_SPEED, float('inf')), 173 | margin=_SPIN_SPEED, 174 | value_at_margin=0, 175 | sigmoid='linear') 176 | else: 177 | move_reward = rewards.tolerance(physics.horizontal_velocity(), 178 | bounds=(self._move_speed, 179 | float('inf')), 180 | margin=self._move_speed / 2, 181 | value_at_margin=0.5, 182 | sigmoid='linear') 183 | 184 | return stand_reward * (5 * move_reward + 1) / 6 185 | 186 | 187 | class MultiTaskPlanarWalker(base.Task): 188 | """A planar walker task.""" 189 | def __init__(self, random=None): 190 | """Initializes an instance of `PlanarWalker`. 191 | 192 | Args: 193 | move_speed: A float. If this value is zero, reward is given simply for 194 | standing up. Otherwise this specifies a target horizontal velocity for 195 | the walking task. 196 | random: Optional, either a `numpy.random.RandomState` instance, an 197 | integer seed for creating a new `RandomState`, or None to select a seed 198 | automatically (default). 
199 | """ 200 | super().__init__(random=random) 201 | 202 | def initialize_episode(self, physics): 203 | """Sets the state of the environment at the start of each episode. 204 | 205 | In 'standing' mode, use initial orientation and small velocities. 206 | In 'random' mode, randomize joint angles and let fall to the floor. 207 | 208 | Args: 209 | physics: An instance of `Physics`. 210 | 211 | """ 212 | randomizers.randomize_limited_and_rotational_joints( 213 | physics, self.random) 214 | super().initialize_episode(physics) 215 | 216 | def get_observation(self, physics): 217 | """Returns an observation of body orientations, height and velocites.""" 218 | obs = collections.OrderedDict() 219 | obs['orientations'] = physics.orientations() 220 | obs['height'] = physics.torso_height() 221 | obs['velocity'] = physics.velocity() 222 | return obs 223 | 224 | def get_reward_spec(self): 225 | return specs.Array(shape=(4,), dtype=np.float32, name='reward') 226 | 227 | def get_reward(self, physics): 228 | """Returns a reward to the agent.""" 229 | 230 | # compute stand reward 231 | standing = rewards.tolerance(physics.torso_height(), 232 | bounds=(_STAND_HEIGHT, float('inf')), 233 | margin=_STAND_HEIGHT / 2) 234 | upright = (1 + physics.torso_upright()) / 2 235 | 236 | stand_reward = (3 * standing + upright) / 4 237 | 238 | # compute walk reward 239 | walking = rewards.tolerance(physics.horizontal_velocity(), 240 | bounds=(_WALK_SPEED, float('inf')), 241 | margin=_WALK_SPEED / 2, 242 | value_at_margin=0.5, 243 | sigmoid='linear') 244 | walk_reward = stand_reward * (5 * walking + 1) / 6 245 | 246 | # compute run reward 247 | running = rewards.tolerance(physics.horizontal_velocity(), 248 | bounds=(_RUN_SPEED, float('inf')), 249 | margin=_RUN_SPEED / 2, 250 | value_at_margin=0.5, 251 | sigmoid='linear') 252 | run_reward = stand_reward * (5 * running + 1) / 6 253 | 254 | # compute flip reward 255 | flipping = rewards.tolerance(physics.angmomentum(), 256 | bounds=(_SPIN_SPEED, 
float('inf')), 257 | margin=_SPIN_SPEED, 258 | value_at_margin=0, 259 | sigmoid='linear') 260 | 261 | flip_reward = stand_reward * (5 * flipping + 1) / 6 262 | 263 | reward = np.array([stand_reward, walk_reward, run_reward, flip_reward]) 264 | return reward.astype(np.float32) 265 | -------------------------------------------------------------------------------- /agent/pbrl.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from collections import OrderedDict 7 | import copy 8 | import utils 9 | from dm_control.utils import rewards 10 | 11 | 12 | class Actor(nn.Module): 13 | def __init__(self, state_dim, action_dim, max_action=1): 14 | super(Actor, self).__init__() 15 | 16 | self.l1 = nn.Linear(state_dim, 256) 17 | self.l2 = nn.Linear(256, 256) 18 | self.l3 = nn.Linear(256, action_dim) 19 | 20 | self.max_action = max_action 21 | 22 | def forward(self, state): 23 | a = F.relu(self.l1(state)) 24 | a = F.relu(self.l2(a)) 25 | return self.max_action * torch.tanh(self.l3(a)) 26 | 27 | 28 | class Critic(nn.Module): 29 | def __init__(self, state_dim, action_dim): 30 | super(Critic, self).__init__() 31 | self.l1 = nn.Linear(state_dim + action_dim, 256) 32 | self.l2 = nn.Linear(256, 256) 33 | self.l3 = nn.Linear(256, 1) 34 | 35 | def forward(self, state, action): 36 | sa = torch.cat([state, action], 1) 37 | q1 = F.relu(self.l1(sa)) 38 | q1 = F.relu(self.l2(q1)) 39 | q1 = self.l3(q1) 40 | return q1 41 | 42 | 43 | class PBRLAgent: 44 | def __init__(self, 45 | name, 46 | obs_shape, 47 | action_shape, 48 | device, 49 | lr, 50 | hidden_dim, 51 | critic_target_tau, 52 | actor_target_tau, 53 | policy_freq, 54 | policy_noise, 55 | noise_clip, 56 | use_tb, 57 | # alpha, 58 | batch_size, 59 | num_expl_steps, 60 | # PBRL parameters 61 | num_random, 62 | ucb_ratio_in, 63 | ucb_ratio_ood_init, 64 | ucb_ratio_ood_min, 65 | 
ood_decay_factor, 66 | ensemble, 67 | ood_noise, 68 | share_ratio, 69 | has_next_action=False): 70 | self.policy_noise = policy_noise 71 | self.policy_freq = policy_freq 72 | self.noise_clip = noise_clip 73 | self.num_expl_steps = num_expl_steps 74 | self.action_dim = action_shape[0] 75 | self.hidden_dim = hidden_dim 76 | self.lr = lr 77 | self.device = device 78 | self.critic_target_tau = critic_target_tau 79 | self.actor_target_tau = actor_target_tau 80 | self.use_tb = use_tb 81 | # self.stddev_schedule = stddev_schedule 82 | # self.stddev_clip = stddev_clip 83 | # self.alpha = alpha 84 | self.max_action = 1.0 85 | self.share_ratio = share_ratio 86 | self.share_ratio_now = None 87 | 88 | # PBRL parameters 89 | self.num_random = num_random # for ood action 90 | self.ucb_ratio_in = ucb_ratio_in 91 | self.ensemble = ensemble 92 | self.ood_noise = ood_noise 93 | 94 | # PBRL parameters: control ood ratio 95 | self.ucb_ratio_ood_init = ucb_ratio_ood_init 96 | self.ucb_ratio_ood_min = ucb_ratio_ood_min 97 | self.ood_decay_factor = ood_decay_factor 98 | self.ucb_ratio_ood = ucb_ratio_ood_init 99 | self.ucb_ratio_ood_linear_steps = None 100 | 101 | # models 102 | self.actor = Actor(obs_shape[0], action_shape[0]).to(device) 103 | self.actor_target = copy.deepcopy(self.actor) 104 | 105 | # initialize ensemble of critic 106 | self.critic, self.critic_target = [], [] 107 | for _ in range(self.ensemble): 108 | single_critic = Critic(obs_shape[0], action_shape[0]).to(device) 109 | single_critic_target = copy.deepcopy(single_critic) 110 | single_critic_target.load_state_dict(single_critic.state_dict()) 111 | self.critic.append(single_critic) 112 | self.critic_target.append(single_critic_target) 113 | print("Actor parameters:", utils.total_parameters(self.actor)) 114 | print("Critic parameters: single", utils.total_parameters(self.critic[0]), ", total:", utils.total_parameters(self.critic)) 115 | 116 | # optimizers 117 | self.actor_opt = torch.optim.Adam(self.actor.parameters(), 
lr=lr) 118 | self.critic_opt = [] 119 | for i in range(self.ensemble): # each ensemble member has its optimizer 120 | self.critic_opt.append(torch.optim.Adam(self.critic[i].parameters(), lr=lr)) 121 | 122 | self.train() 123 | for ct_single in self.critic_target: 124 | ct_single.train() 125 | 126 | def train(self, training=True): 127 | self.training = training 128 | self.actor.train(training) 129 | for c_single in self.critic: 130 | c_single.train(training) 131 | 132 | def act(self, obs, step, eval_mode): 133 | obs = torch.as_tensor(obs, device=self.device).unsqueeze(0) 134 | action = self.actor(obs) 135 | if step < self.num_expl_steps: 136 | action.uniform_(-1.0, 1.0) 137 | return action.cpu().numpy()[0] 138 | 139 | def ucb_func(self, obs, action, mean=False): 140 | action_shape = action.shape[0] # 1024*num_random 141 | obs_shape = obs.shape[0] # 1024 142 | assert int(action_shape / obs_shape) in [1, self.num_random] 143 | if int(action_shape / obs_shape) != 1: 144 | obs = obs.unsqueeze(1).repeat(1, self.num_random, 1).view(obs.shape[0] * self.num_random, obs.shape[1]) 145 | # Bootstrapped uncertainty 146 | q_pred = [] 147 | for i in range(self.ensemble): 148 | q_pred.append(self.critic[i](obs.cuda(), action.cuda())) 149 | ucb = torch.std(torch.hstack(q_pred), dim=1, keepdim=True) # (1024, ensemble) -> (1024, 1) 150 | assert ucb.size() == (action_shape, 1) 151 | if mean: 152 | q_pred = torch.mean(torch.hstack(q_pred), dim=1, keepdim=True) 153 | return ucb, q_pred 154 | 155 | def ucb_func_target(self, obs_next, act_next): 156 | action_shape = act_next.shape[0] # 2560 157 | obs_shape = obs_next.shape[0] # 256 158 | assert int(action_shape / obs_shape) in [1, self.num_random] 159 | if int(action_shape / obs_shape) != 1: 160 | obs_next = obs_next.unsqueeze(1).repeat(1, self.num_random, 1).view(obs_next.shape[0] * self.num_random, obs_next.shape[1]) # (2560, obs_dim) 161 | # Bootstrapped uncertainty 162 | target_q_pred = [] 163 | for i in range(self.ensemble): 164 | 
target_q_pred.append(self.critic[i](obs_next.cuda(), act_next.cuda())) 165 | ucb_t = torch.std(torch.hstack(target_q_pred), dim=1, keepdim=True) 166 | assert ucb_t.size() == (action_shape, 1) 167 | return ucb_t, target_q_pred 168 | 169 | def update_critic(self, obs, action, reward, discount, next_obs, step, total_step, bool_flag): 170 | self.share_ratio_now = utils.decay_linear(t=step, init=self.share_ratio, minimum=1.0, total_steps=total_step // 2) 171 | 172 | metrics = dict() 173 | with torch.no_grad(): 174 | # Select action according to policy and add clipped noise 175 | noise = (torch.randn_like(action) * self.policy_noise).clamp(-self.noise_clip, self.noise_clip) 176 | next_action = (self.actor_target(next_obs) + noise).clamp(-self.max_action, self.max_action) 177 | 178 | # ood sample 1 179 | sampled_current_actions = self.actor(obs).unsqueeze(1).repeat(1, self.num_random, 1).view( 180 | action.shape[0]*self.num_random, action.shape[1]) 181 | noise_current = (torch.randn_like(sampled_current_actions) * self.ood_noise).clamp(-self.noise_clip, self.noise_clip) 182 | sampled_current_actions = (sampled_current_actions + noise_current).clamp(-self.max_action, self.max_action) 183 | 184 | # ood sample 2 185 | sampled_next_actions = self.actor(next_obs).unsqueeze(1).repeat(1, self.num_random, 1).view( 186 | action.shape[0]*self.num_random, action.shape[1]) 187 | noise_next = (torch.randn_like(sampled_next_actions) * self.ood_noise).clamp(-self.noise_clip, self.noise_clip) 188 | sampled_next_actions = (sampled_next_actions + noise_next).clamp(-self.max_action, self.max_action) 189 | 190 | # random sample 191 | random_actions = torch.FloatTensor(action.shape[0]*self.num_random, action.shape[1]).uniform_( 192 | -self.max_action, self.max_action).to(self.device) 193 | 194 | # TODO: UCB and Q-values 195 | ucb_current, q_pred = self.ucb_func(obs, action) # (1024,1). 
lenth=ensemble, q_pred[0].shape=(1024,1) 196 | ucb_next, target_q_pred = self.ucb_func_target(next_obs, next_action) # (1024,1). lenth=ensemble, target_q_pred[0].shape=(1024,1) 197 | 198 | ucb_curr_actions_ood, qf_curr_actions_all_ood = self.ucb_func(obs, sampled_current_actions) # (1024*num_random, 1), length=ensemble, (1024*num_random, 1) 199 | ucb_next_actions_ood, qf_next_actions_all_ood = self.ucb_func(next_obs, sampled_next_actions) # 同上 200 | # ucb_rand_ood, qf_rand_actions_all_ood = self.ucb_func(obs, random_actions) 201 | 202 | for qf_index in np.arange(self.ensemble): 203 | ucb_ratio_in_flag = bool_flag * self.ucb_ratio_in + (1 - bool_flag) * self.ucb_ratio_in * self.share_ratio_now 204 | ucb_ratio_in_flag = np.expand_dims(ucb_ratio_in_flag, 1) 205 | q_target = reward + discount * (target_q_pred[qf_index] - torch.from_numpy(ucb_ratio_in_flag.astype(np.float32)).cuda() * ucb_next) # (1024, 1), (1024, 1), (1024, 1) 206 | # print("bool flag", bool_flag[:10], bool_flag[-10:]) 207 | # print("ucb_ratio_in_flag", q_target.shape, ucb_ratio_in_flag.shape, ucb_next.shape, (torch.from_numpy(ucb_ratio_in_flag.astype(np.float32)).cuda() * ucb_next).shape, ucb_ratio_in_flag[:10]) 208 | 209 | # q_target = reward + discount * (target_q_pred[qf_index] - self.ucb_ratio_in * ucb_next) # (1024, 1), (1024, 1), (1024, 1) 210 | q_target = q_target.detach() 211 | qf_loss_in = F.mse_loss(q_pred[qf_index], q_target) 212 | 213 | # TODO: ood loss 214 | cat_qf_ood = torch.cat([qf_curr_actions_all_ood[qf_index], 215 | qf_next_actions_all_ood[qf_index]], 0) 216 | # assert cat_qf_ood.size() == (1024*self.num_random*3, 1) 217 | 218 | ucb_ratio_ood_flag = bool_flag * self.ucb_ratio_ood + (1 - bool_flag) * self.ucb_ratio_ood * self.share_ratio_now 219 | ucb_ratio_ood_flag = np.expand_dims(ucb_ratio_ood_flag, 1).repeat(self.num_random, axis=1).reshape(-1, 1).astype(np.float32) 220 | # print("ucb_ratio_ood_flag 1", ucb_ratio_ood_flag.shape, ucb_curr_actions_ood.shape) 221 | 222 | 
cat_qf_ood_target = torch.cat([ 223 | torch.maximum(qf_curr_actions_all_ood[qf_index] - torch.from_numpy(ucb_ratio_ood_flag).cuda() * ucb_curr_actions_ood, torch.zeros(1).cuda()), 224 | torch.maximum(qf_next_actions_all_ood[qf_index] - torch.from_numpy(ucb_ratio_ood_flag).cuda() * ucb_next_actions_ood, torch.zeros(1).cuda())], 0) 225 | # print("ucb_ratio_ood_flag 2", cat_qf_ood_target.shape, qf_curr_actions_all_ood[qf_index].shape) 226 | cat_qf_ood_target = cat_qf_ood_target.detach() 227 | 228 | # assert cat_qf_ood_target.size() == (1024*self.num_random*3, 1) 229 | qf_loss_ood = F.mse_loss(cat_qf_ood, cat_qf_ood_target) 230 | critic_loss = qf_loss_in + qf_loss_ood 231 | 232 | # Update the Q-functions 233 | self.critic_opt[qf_index].zero_grad() 234 | critic_loss.backward(retain_graph=True) 235 | self.critic_opt[qf_index].step() 236 | 237 | # change the ood ratio 238 | self.ucb_ratio_ood = max(self.ucb_ratio_ood_init * self.ood_decay_factor ** step, self.ucb_ratio_ood_min) 239 | 240 | if self.use_tb: 241 | metrics['critic_target_q'] = q_target.mean().item() 242 | metrics['critic_q1'] = q_pred[0].mean().item() 243 | # metrics['critic_q2'] = q_pred[1].mean().item() 244 | # ucb 245 | metrics['ucb_current'] = ucb_current.mean().item() 246 | metrics['ucb_next'] = ucb_next.mean().item() 247 | metrics['ucb_curr_actions_ood'] = ucb_curr_actions_ood.mean().item() 248 | metrics['ucb_next_actions_ood'] = ucb_next_actions_ood.mean().item() 249 | # loss 250 | metrics['critic_loss_in'] = qf_loss_in.item() 251 | metrics['critic_loss_ood'] = qf_loss_ood.item() 252 | metrics['ucb_ratio_ood'] = self.ucb_ratio_ood 253 | metrics['share_ratio_now'] = self.share_ratio_now 254 | return metrics 255 | 256 | def update_actor(self, obs, action): 257 | metrics = dict() 258 | 259 | # Compute actor loss 260 | pi = self.actor(obs) 261 | 262 | Qvalues = [] 263 | for i in range(self.ensemble): 264 | Qvalues.append(self.critic[i](obs, pi)) # (1024, 1) 265 | Qvalues_min = 
torch.min(torch.hstack(Qvalues), dim=1, keepdim=True).values 266 | assert Qvalues_min.size() == (1024, 1) 267 | 268 | actor_loss = -1. * Qvalues_min.mean() 269 | 270 | # optimize actor 271 | self.actor_opt.zero_grad(set_to_none=True) 272 | actor_loss.backward() 273 | self.actor_opt.step() 274 | 275 | if self.use_tb: 276 | metrics['actor_loss'] = actor_loss.item() 277 | 278 | return metrics 279 | 280 | def update(self, replay_iter, step, total_step): 281 | metrics = dict() 282 | 283 | batch = next(replay_iter) 284 | obs, action, reward, discount, next_obs, bool_flag = utils.to_torch( 285 | batch, self.device) 286 | bool_flag = bool_flag.cpu().detach().numpy() 287 | 288 | if self.use_tb: 289 | metrics['batch_reward'] = reward.mean().item() 290 | 291 | # update critic 292 | metrics.update( 293 | self.update_critic(obs, action, reward, discount, next_obs, step, total_step, bool_flag)) 294 | 295 | # update actor 296 | if step % self.policy_freq == 0: 297 | metrics.update(self.update_actor(obs, action)) 298 | 299 | # update actor target 300 | utils.soft_update_params(self.actor, self.actor_target, self.actor_target_tau) 301 | 302 | # update critic target 303 | for i in range(self.ensemble): 304 | utils.soft_update_params(self.critic[i], self.critic_target[i], self.critic_target_tau) 305 | 306 | return metrics 307 | -------------------------------------------------------------------------------- /agent/cql_cdsz.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from collections import OrderedDict 7 | 8 | import utils 9 | from dm_control.utils import rewards 10 | 11 | 12 | class Actor(nn.Module): 13 | def __init__(self, obs_dim, action_dim, hidden_dim, init_w=1e-3): 14 | super().__init__() 15 | 16 | self.policy = nn.Sequential( 17 | nn.Linear(obs_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.Tanh(), 18 | 
nn.Linear(hidden_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.Tanh(), 19 | nn.Linear(hidden_dim, hidden_dim), nn.LeakyReLU()) 20 | 21 | self.fc_mu = nn.Linear(hidden_dim, action_dim) 22 | self.fc_logstd = nn.Linear(hidden_dim, action_dim) 23 | 24 | def forward(self, obs): 25 | f = self.policy(obs) 26 | mu = self.fc_mu(f) 27 | log_std = self.fc_logstd(f) 28 | 29 | mu = torch.clamp(mu, -5, 5) 30 | 31 | std = log_std.clamp(-5, 2).exp() 32 | dist = utils.SquashedNormal2(mu, std) 33 | return dist 34 | 35 | 36 | class Critic(nn.Module): 37 | def __init__(self, obs_dim, action_dim, hidden_dim, init_w=1e-3): 38 | super().__init__() 39 | 40 | self.q1_net = nn.Sequential( 41 | nn.Linear(obs_dim + action_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.Tanh(), 42 | nn.Linear(hidden_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.LeakyReLU(), 43 | nn.Linear(hidden_dim, hidden_dim), nn.LeakyReLU()) 44 | self.q1_last = nn.Linear(hidden_dim, 1) 45 | 46 | self.q2_net = nn.Sequential( 47 | nn.Linear(obs_dim + action_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.Tanh(), 48 | nn.Linear(hidden_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.LeakyReLU(), 49 | nn.Linear(hidden_dim, hidden_dim), nn.LeakyReLU()) 50 | self.q2_last = nn.Linear(hidden_dim, 1) 51 | 52 | def forward(self, obs, action): 53 | obs_action = torch.cat([obs, action], dim=-1) 54 | q1 = self.q1_net(obs_action) 55 | q1 = self.q1_last(q1) 56 | 57 | q2 = self.q2_net(obs_action) 58 | q2 = self.q2_last(q2) 59 | 60 | return q1, q2 61 | 62 | 63 | class CQLCDSZeroAgent: 64 | def __init__(self, 65 | name, 66 | obs_shape, 67 | action_shape, 68 | device, 69 | actor_lr, 70 | critic_lr, 71 | hidden_dim, 72 | critic_target_tau, 73 | nstep, 74 | batch_size, 75 | use_tb, 76 | alpha, 77 | n_samples, 78 | target_cql_penalty, 79 | use_critic_lagrange, 80 | num_expl_steps, 81 | has_next_action=False): 82 | self.num_expl_steps = num_expl_steps 83 | self.action_dim = action_shape[0] 84 | self.hidden_dim = hidden_dim 85 | self.actor_lr = 
actor_lr 86 | self.critic_lr = critic_lr 87 | self.device = device 88 | self.critic_target_tau = critic_target_tau 89 | self.use_tb = use_tb 90 | self.use_critic_lagrange = use_critic_lagrange 91 | self.target_cql_penalty = target_cql_penalty 92 | 93 | self.alpha = alpha 94 | self.n_samples = n_samples 95 | 96 | state_dim = obs_shape[0] 97 | action_dim = action_shape[0] 98 | 99 | # models 100 | self.actor = Actor(state_dim, action_dim, hidden_dim).to(device) 101 | self.critic = Critic(state_dim, action_dim, hidden_dim).to(device) 102 | self.critic_target = Critic(state_dim, action_dim, hidden_dim).to(device) 103 | self.critic_target.load_state_dict(self.critic.state_dict()) 104 | 105 | # lagrange multipliers 106 | self.target_entropy = -self.action_dim 107 | self.log_actor_alpha = torch.zeros(1, requires_grad=True, device=device) 108 | self.log_critic_alpha = torch.zeros(1, requires_grad=True, device=device) 109 | 110 | # optimizers 111 | self.actor_opt = torch.optim.Adam(self.actor.parameters(), lr=actor_lr) 112 | self.critic_opt = torch.optim.Adam(self.critic.parameters(), lr=critic_lr) 113 | self.actor_alpha_opt = torch.optim.Adam([self.log_actor_alpha], lr=actor_lr) 114 | self.critic_alpha_opt = torch.optim.Adam([self.log_critic_alpha], lr=actor_lr) 115 | 116 | self.train() 117 | # self.critic_target.train() 118 | 119 | def train(self, training=True): 120 | self.training = training 121 | self.actor.train(training) 122 | self.critic.train(training) 123 | 124 | def act(self, obs, step, eval_mode): 125 | obs = torch.as_tensor(obs, device=self.device).unsqueeze(0) 126 | policy = self.actor(obs) 127 | if eval_mode: 128 | action = policy.mean 129 | else: 130 | action = policy.sample() 131 | if step < self.num_expl_steps: 132 | action.uniform_(-1.0, 1.0) 133 | return action.cpu().numpy()[0] 134 | 135 | def _repeated_critic_apply(self, obs, actions): 136 | """ 137 | obs is (batch_size, obs_dim) 138 | actions is (n_samples, batch_size, action_dim) 139 | 140 | output 
tensors are (n_samples, batch_size, 1) 141 | """ 142 | batch_size = obs.shape[0] 143 | n_samples = actions.shape[0] 144 | 145 | reshaped_actions = actions.reshape((n_samples * batch_size, -1)) 146 | repeated_obs = obs.unsqueeze(0).repeat((n_samples, 1, 1)) 147 | repeated_obs = repeated_obs.reshape((n_samples * batch_size, -1)) 148 | 149 | Q1, Q2 = self.critic(repeated_obs, reshaped_actions) 150 | Q1 = Q1.reshape((n_samples, batch_size, 1)) 151 | Q2 = Q2.reshape((n_samples, batch_size, 1)) 152 | 153 | return Q1, Q2 154 | 155 | def update_critic(self, obs, action, reward, discount, next_obs, step): 156 | metrics = dict() 157 | 158 | # Compute standard SAC loss 159 | with torch.no_grad(): 160 | dist = self.actor(next_obs) # SquashedNormal分布 161 | sampled_next_action = dist.sample() # (1024, act_dim) 162 | # print("sampled_next_action:", sampled_next_action.shape) 163 | target_Q1, target_Q2 = self.critic_target(next_obs, sampled_next_action) # (1024,1), (1024,1) 164 | target_V = torch.min(target_Q1, target_Q2) # (1024,1) 165 | target_Q = reward + (discount * target_V) # (1024,1) 166 | 167 | Q1, Q2 = self.critic(obs, action) 168 | critic_loss = F.mse_loss(Q1, target_Q) + F.mse_loss(Q2, target_Q) # 标量 169 | 170 | # Add CQL penalty 171 | with torch.no_grad(): 172 | random_actions = torch.FloatTensor(self.n_samples, Q1.shape[0], 173 | action.shape[-1]).uniform_(-1, 1).to(self.device) # (n_samples, 1024, act_dim) 174 | sampled_actions = self.actor(obs).sample( 175 | sample_shape=(self.n_samples,)) # (n_samples, 1024, act_dim) 176 | # print("sampled_actions:", sampled_actions.shape) 177 | next_sampled_actions = self.actor(next_obs).sample( 178 | sample_shape=(self.n_samples,)) # (n_samples, 1024, act_dim) 179 | # print("next_sampled_actions:", next_sampled_actions.shape) 180 | 181 | rand_Q1, rand_Q2 = self._repeated_critic_apply(obs, random_actions) # (n_samples, 1024, 1) 182 | sampled_Q1, sampled_Q2 = self._repeated_critic_apply( # (n_samples, 1024, 1) 183 | obs, 
sampled_actions) 184 | next_sampled_Q1, next_sampled_Q2 = self._repeated_critic_apply( # (n_samples, 1024, 1) 185 | obs, next_sampled_actions) 186 | 187 | # 默认情况1 188 | cat_Q1 = torch.cat([rand_Q1, sampled_Q1, next_sampled_Q1, 189 | Q1.unsqueeze(0)], dim=0) # (1+3*n_samples, 1024, 1) 190 | cat_Q2 = torch.cat([rand_Q2, sampled_Q2, next_sampled_Q2, 191 | Q2.unsqueeze(0)], dim=0) # (1+3*n_samples, 1024, 1) 192 | assert (not torch.isnan(cat_Q1).any()) and (not torch.isnan(cat_Q2).any()) 193 | 194 | cql_logsumexp1 = torch.logsumexp(cat_Q1, dim=0).mean() 195 | cql_logsumexp2 = torch.logsumexp(cat_Q2, dim=0).mean() 196 | cql_logsumexp = cql_logsumexp1 + cql_logsumexp2 197 | cql_penalty = cql_logsumexp - (Q1 + Q2).mean() # 标量 198 | 199 | # Update lagrange multiplier 200 | if self.use_critic_lagrange: 201 | alpha = torch.clamp(self.log_critic_alpha.exp(), min=0.0, max=1000000.0) 202 | alpha_loss = -0.5 * alpha * (cql_penalty - self.target_cql_penalty) 203 | 204 | self.critic_alpha_opt.zero_grad() 205 | alpha_loss.backward(retain_graph=True) 206 | self.critic_alpha_opt.step() 207 | alpha = torch.clamp(self.log_critic_alpha.exp(), 208 | min=0.0, 209 | max=1000000.0).detach() 210 | else: 211 | alpha = self.alpha 212 | 213 | # Combine losses 214 | critic_loss = critic_loss + alpha * cql_penalty 215 | 216 | # optimize critic 217 | self.critic_opt.zero_grad(set_to_none=True) 218 | critic_loss.backward() 219 | # torch.nn.utils.clip_grad_norm_(self.critic.parameters(), 50) 220 | critic_grad_norm = utils.grad_norm(self.critic.parameters()) 221 | self.critic_opt.step() 222 | 223 | if self.use_tb: 224 | metrics['critic_target_q'] = target_Q.mean().item() 225 | metrics['critic_q1'] = Q1.mean().item() 226 | # metrics['critic_q2'] = Q2.mean().item() 227 | metrics['critic_loss'] = critic_loss.item() 228 | metrics['critic_cql'] = cql_penalty.item() 229 | metrics['critic_cql_logsum1'] = cql_logsumexp1.item() 230 | # metrics['critic_cql_logsum2'] = cql_logsumexp1.item() 231 | 
metrics['rand_Q1'] = rand_Q1.mean().item() 232 | # metrics['rand_Q2'] = rand_Q2.mean().item() 233 | metrics['sampled_Q1'] = sampled_Q1.mean().item() 234 | # metrics['sampled_Q2'] = sampled_Q2.mean().item() 235 | metrics['critic_grad_norm'] = critic_grad_norm 236 | 237 | return metrics 238 | 239 | def update_actor(self, obs, action, step): 240 | metrics = dict() 241 | 242 | policy = self.actor(obs) 243 | sampled_action = policy.rsample() # (1024, 6) 244 | log_pi = policy.log_prob(sampled_action) # (1024, 6) 245 | 246 | # update lagrange multiplier 247 | alpha_loss = -(self.log_actor_alpha * (log_pi + self.target_entropy).detach()).mean() 248 | self.actor_alpha_opt.zero_grad(set_to_none=True) 249 | alpha_loss.backward() 250 | self.actor_alpha_opt.step() 251 | alpha = self.log_actor_alpha.exp().detach() 252 | 253 | # optimize actor 254 | Q1, Q2 = self.critic(obs, sampled_action) # (1024, 1) 255 | Q = torch.min(Q1, Q2) # (1024, 1) 256 | actor_loss = (alpha * log_pi - Q).mean() # 257 | self.actor_opt.zero_grad(set_to_none=True) 258 | actor_loss.backward() 259 | actor_grad_norm = utils.grad_norm(self.actor.parameters()) 260 | self.actor_opt.step() 261 | 262 | if self.use_tb: 263 | metrics['actor_loss'] = actor_loss.item() 264 | metrics['actor_ent'] = -log_pi.mean().item() 265 | metrics['actor_alpha'] = alpha.item() 266 | metrics['actor_alpha_loss'] = alpha_loss.item() 267 | metrics['actor_mean'] = policy.loc.mean().item() 268 | metrics['actor_std'] = policy.scale.mean().item() 269 | metrics['actor_action'] = sampled_action.mean().item() 270 | metrics['actor_atanh_action'] = utils.atanh(sampled_action).mean().item() 271 | metrics['actor_grad_norm'] = actor_grad_norm 272 | 273 | return metrics 274 | 275 | # TODO: conservative data sharing 276 | def conservative_data_share(self, batch_main, batch_share): 277 | obs_all, action_all, next_obs_all, reward_all, discount_all, obs_k_all = [], [], [], [], [], [] 278 | 279 | # 1. 
sample examples from the main task 280 | # shape = (512, 24), (512, 6), (512, 1), (512, 1), (512, 24) 281 | obs_m, action_m, reward_m, discount_m, next_obs_m, _ = utils.to_torch(batch_main, self.device) 282 | obs_all.append(obs_m) 283 | action_all.append(action_m) 284 | next_obs_all.append(next_obs_m) 285 | reward_all.append(reward_m) 286 | discount_all.append(discount_m) 287 | 288 | # 2. sample examples from other tasks 289 | # sample 10 times samples, and select the top-0.1 samples. (batch_size_split*10, xx, xx) 290 | # shape = (5120, 24) (5120, 6) (5120, 1) (5120, 1) (5120, 24) 291 | obs, action, reward, discount, next_obs, _ = utils.to_torch(batch_share, self.device) 292 | # print("obs:", obs.shape, action.shape, reward.shape, discount.shape, next_obs.shape) 293 | # calculate the conservative Q value 294 | with torch.no_grad(): 295 | conservative_q_value, _ = self.critic(obs, action) # (5120, 1) 296 | conservative_q_value = conservative_q_value.squeeze().detach().cpu().numpy() 297 | # choose the top 0.1 top_index.shape=(512,) 298 | top_index = np.argpartition(conservative_q_value, -obs_m.shape[0])[-obs_m.shape[0]:] # find the top 0.1 index. 
(batch_size_split,) 299 | # extract the samples 300 | obs_all.append(obs[top_index]) 301 | action_all.append(action[top_index]) 302 | next_obs_all.append(next_obs[top_index]) 303 | discount_all.append(discount[top_index]) 304 | 305 | # TODO: zero the reward (the only difference between cql-cds) 306 | reward_all.append(torch.zeros(obs_m.shape[0], 1).cuda()) 307 | 308 | return torch.cat(obs_all, dim=0), torch.cat(action_all, dim=0), torch.cat(reward_all, dim=0), \ 309 | torch.cat(discount_all, dim=0), torch.cat(next_obs_all, dim=0) 310 | 311 | def update(self, replay_iter_main, replay_iter_share, step, total_step): 312 | metrics = dict() 313 | 314 | batch_main = next(replay_iter_main) 315 | batch_share = next(replay_iter_share) 316 | 317 | # print("conservative data sharing...") # obs.shape=(1024, 24), action.shape=(1024, 6) reward.shape=(1024, 1) 318 | obs, action, reward, discount, next_obs = self.conservative_data_share(batch_main, batch_share) 319 | 320 | if self.use_tb: 321 | metrics['batch_reward'] = reward.mean().item() 322 | 323 | # update critic 324 | metrics.update( 325 | self.update_critic(obs, action, reward, discount, next_obs, step)) 326 | 327 | # update actor 328 | metrics.update(self.update_actor(obs, action, step)) 329 | 330 | # update critic target 331 | utils.soft_update_params(self.critic, self.critic_target, self.critic_target_tau) 332 | 333 | return metrics 334 | -------------------------------------------------------------------------------- /agent/cql_cds.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from collections import OrderedDict 7 | 8 | import utils 9 | from dm_control.utils import rewards 10 | 11 | 12 | class Actor(nn.Module): 13 | def __init__(self, obs_dim, action_dim, hidden_dim, init_w=1e-3): 14 | super().__init__() 15 | 16 | self.policy = nn.Sequential( 17 | 
nn.Linear(obs_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.Tanh(), 18 | nn.Linear(hidden_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.Tanh(), # 新增 19 | nn.Linear(hidden_dim, hidden_dim), nn.LeakyReLU()) 20 | 21 | self.fc_mu = nn.Linear(hidden_dim, action_dim) 22 | self.fc_logstd = nn.Linear(hidden_dim, action_dim) 23 | 24 | def forward(self, obs): 25 | f = self.policy(obs) 26 | mu = self.fc_mu(f) 27 | log_std = self.fc_logstd(f) 28 | 29 | mu = torch.clamp(mu, -5, 5) 30 | 31 | std = log_std.clamp(-5, 2).exp() 32 | dist = utils.SquashedNormal2(mu, std) 33 | return dist 34 | 35 | 36 | class Critic(nn.Module): 37 | def __init__(self, obs_dim, action_dim, hidden_dim, init_w=1e-3): 38 | super().__init__() 39 | 40 | self.q1_net = nn.Sequential( 41 | nn.Linear(obs_dim + action_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.Tanh(), 42 | nn.Linear(hidden_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.LeakyReLU(), 43 | nn.Linear(hidden_dim, hidden_dim), nn.LeakyReLU()) 44 | self.q1_last = nn.Linear(hidden_dim, 1) 45 | 46 | self.q2_net = nn.Sequential( 47 | nn.Linear(obs_dim + action_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.Tanh(), 48 | nn.Linear(hidden_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.LeakyReLU(), 49 | nn.Linear(hidden_dim, hidden_dim), nn.LeakyReLU()) 50 | self.q2_last = nn.Linear(hidden_dim, 1) 51 | 52 | def forward(self, obs, action): 53 | obs_action = torch.cat([obs, action], dim=-1) 54 | q1 = self.q1_net(obs_action) 55 | q1 = self.q1_last(q1) 56 | 57 | q2 = self.q2_net(obs_action) 58 | q2 = self.q2_last(q2) 59 | 60 | return q1, q2 61 | 62 | 63 | class CQLCDSAgent: 64 | def __init__(self, 65 | name, 66 | obs_shape, 67 | action_shape, 68 | device, 69 | actor_lr, 70 | critic_lr, 71 | hidden_dim, 72 | critic_target_tau, 73 | nstep, 74 | batch_size, 75 | use_tb, 76 | alpha, 77 | n_samples, 78 | target_cql_penalty, 79 | use_critic_lagrange, 80 | num_expl_steps, 81 | has_next_action=False): 82 | self.num_expl_steps = num_expl_steps 83 | self.action_dim = 
action_shape[0] 84 | self.hidden_dim = hidden_dim 85 | self.actor_lr = actor_lr 86 | self.critic_lr = critic_lr 87 | self.device = device 88 | self.critic_target_tau = critic_target_tau 89 | self.use_tb = use_tb 90 | self.use_critic_lagrange = use_critic_lagrange 91 | self.target_cql_penalty = target_cql_penalty 92 | 93 | self.alpha = alpha 94 | self.n_samples = n_samples 95 | 96 | state_dim = obs_shape[0] 97 | action_dim = action_shape[0] 98 | 99 | # models 100 | self.actor = Actor(state_dim, action_dim, hidden_dim).to(device) 101 | self.critic = Critic(state_dim, action_dim, hidden_dim).to(device) 102 | self.critic_target = Critic(state_dim, action_dim, hidden_dim).to(device) 103 | self.critic_target.load_state_dict(self.critic.state_dict()) 104 | 105 | # lagrange multipliers 106 | self.target_entropy = -self.action_dim 107 | self.log_actor_alpha = torch.zeros(1, requires_grad=True, device=device) 108 | self.log_critic_alpha = torch.zeros(1, requires_grad=True, device=device) 109 | 110 | # optimizers 111 | self.actor_opt = torch.optim.Adam(self.actor.parameters(), lr=actor_lr) 112 | self.critic_opt = torch.optim.Adam(self.critic.parameters(), lr=critic_lr) 113 | self.actor_alpha_opt = torch.optim.Adam([self.log_actor_alpha], lr=actor_lr) 114 | self.critic_alpha_opt = torch.optim.Adam([self.log_critic_alpha], lr=actor_lr) 115 | 116 | self.train() 117 | # self.critic_target.train() 118 | 119 | def train(self, training=True): 120 | self.training = training 121 | self.actor.train(training) 122 | self.critic.train(training) 123 | 124 | def act(self, obs, step, eval_mode): 125 | obs = torch.as_tensor(obs, device=self.device).unsqueeze(0) 126 | policy = self.actor(obs) 127 | if eval_mode: 128 | action = policy.mean 129 | else: 130 | action = policy.sample() 131 | if step < self.num_expl_steps: 132 | action.uniform_(-1.0, 1.0) 133 | return action.cpu().numpy()[0] 134 | 135 | def _repeated_critic_apply(self, obs, actions): 136 | """ 137 | obs is (batch_size, obs_dim) 138 
| actions is (n_samples, batch_size, action_dim) 139 | 140 | output tensors are (n_samples, batch_size, 1) 141 | """ 142 | batch_size = obs.shape[0] 143 | n_samples = actions.shape[0] 144 | 145 | reshaped_actions = actions.reshape((n_samples * batch_size, -1)) 146 | repeated_obs = obs.unsqueeze(0).repeat((n_samples, 1, 1)) 147 | repeated_obs = repeated_obs.reshape((n_samples * batch_size, -1)) 148 | 149 | Q1, Q2 = self.critic(repeated_obs, reshaped_actions) 150 | Q1 = Q1.reshape((n_samples, batch_size, 1)) 151 | Q2 = Q2.reshape((n_samples, batch_size, 1)) 152 | 153 | return Q1, Q2 154 | 155 | def update_critic(self, obs, action, reward, discount, next_obs, step): 156 | metrics = dict() 157 | 158 | # Compute standard SAC loss 159 | with torch.no_grad(): 160 | dist = self.actor(next_obs) # SquashedNormal分布 161 | sampled_next_action = dist.sample() # (1024, act_dim) 162 | # print("sampled_next_action:", sampled_next_action.shape) 163 | target_Q1, target_Q2 = self.critic_target(next_obs, sampled_next_action) # (1024,1), (1024,1) 164 | target_V = torch.min(target_Q1, target_Q2) # (1024,1) 165 | target_Q = reward + (discount * target_V) # (1024,1) 166 | 167 | Q1, Q2 = self.critic(obs, action) 168 | critic_loss = F.mse_loss(Q1, target_Q) + F.mse_loss(Q2, target_Q) # 标量 169 | 170 | # Add CQL penalty 171 | with torch.no_grad(): 172 | random_actions = torch.FloatTensor(self.n_samples, Q1.shape[0], 173 | action.shape[-1]).uniform_(-1, 1).to(self.device) # (n_samples, 1024, act_dim) 174 | sampled_actions = self.actor(obs).sample( 175 | sample_shape=(self.n_samples,)) # (n_samples, 1024, act_dim) 176 | # print("sampled_actions:", sampled_actions.shape) 177 | next_sampled_actions = self.actor(next_obs).sample( 178 | sample_shape=(self.n_samples,)) # (n_samples, 1024, act_dim) 179 | # print("next_sampled_actions:", next_sampled_actions.shape) 180 | 181 | rand_Q1, rand_Q2 = self._repeated_critic_apply(obs, random_actions) # (n_samples, 1024, 1) 182 | sampled_Q1, sampled_Q2 = 
self._repeated_critic_apply( # (n_samples, 1024, 1) 183 | obs, sampled_actions) 184 | next_sampled_Q1, next_sampled_Q2 = self._repeated_critic_apply( # (n_samples, 1024, 1) 185 | obs, next_sampled_actions) 186 | 187 | # situation 1 188 | cat_Q1 = torch.cat([rand_Q1, sampled_Q1, next_sampled_Q1, 189 | Q1.unsqueeze(0)], dim=0) # (1+3*n_samples, 1024, 1) 190 | cat_Q2 = torch.cat([rand_Q2, sampled_Q2, next_sampled_Q2, 191 | Q2.unsqueeze(0)], dim=0) # (1+3*n_samples, 1024, 1) 192 | 193 | assert (not torch.isnan(cat_Q1).any()) and (not torch.isnan(cat_Q2).any()) 194 | 195 | cql_logsumexp1 = torch.logsumexp(cat_Q1, dim=0).mean() 196 | cql_logsumexp2 = torch.logsumexp(cat_Q2, dim=0).mean() 197 | cql_logsumexp = cql_logsumexp1 + cql_logsumexp2 198 | 199 | cql_penalty = cql_logsumexp - (Q1 + Q2).mean() # 标量 200 | 201 | # Update lagrange multiplier 202 | if self.use_critic_lagrange: 203 | alpha = torch.clamp(self.log_critic_alpha.exp(), min=0.0, max=1000000.0) 204 | alpha_loss = -0.5 * alpha * (cql_penalty - self.target_cql_penalty) 205 | 206 | self.critic_alpha_opt.zero_grad() 207 | alpha_loss.backward(retain_graph=True) 208 | self.critic_alpha_opt.step() 209 | alpha = torch.clamp(self.log_critic_alpha.exp(), 210 | min=0.0, 211 | max=1000000.0).detach() 212 | else: 213 | alpha = self.alpha 214 | 215 | # Combine losses 216 | critic_loss = critic_loss + alpha * cql_penalty 217 | 218 | # optimize critic 219 | self.critic_opt.zero_grad(set_to_none=True) 220 | critic_loss.backward() 221 | critic_grad_norm = utils.grad_norm(self.critic.parameters()) 222 | self.critic_opt.step() 223 | 224 | if self.use_tb: 225 | metrics['critic_target_q'] = target_Q.mean().item() 226 | metrics['critic_q1'] = Q1.mean().item() 227 | # metrics['critic_q2'] = Q2.mean().item() 228 | metrics['critic_loss'] = critic_loss.item() 229 | metrics['critic_cql'] = cql_penalty.item() 230 | metrics['critic_cql_logsum1'] = cql_logsumexp1.item() 231 | # metrics['critic_cql_logsum2'] = cql_logsumexp1.item() 232 | 
metrics['rand_Q1'] = rand_Q1.mean().item() 233 | # metrics['rand_Q2'] = rand_Q2.mean().item() 234 | metrics['sampled_Q1'] = sampled_Q1.mean().item() 235 | # metrics['sampled_Q2'] = sampled_Q2.mean().item() 236 | metrics['critic_grad_norm'] = critic_grad_norm 237 | 238 | return metrics 239 | 240 | def update_actor(self, obs, action, step): 241 | metrics = dict() 242 | 243 | policy = self.actor(obs) 244 | sampled_action = policy.rsample() # (1024, 6) 245 | # print("sampled_action:", sampled_action.shape) 246 | log_pi = policy.log_prob(sampled_action) # (1024, 6) 247 | 248 | # update lagrange multiplier 249 | alpha_loss = -(self.log_actor_alpha * (log_pi + self.target_entropy).detach()).mean() 250 | self.actor_alpha_opt.zero_grad(set_to_none=True) 251 | alpha_loss.backward() 252 | self.actor_alpha_opt.step() 253 | alpha = self.log_actor_alpha.exp().detach() 254 | 255 | # optimize actor 256 | Q1, Q2 = self.critic(obs, sampled_action) # (1024, 1) 257 | Q = torch.min(Q1, Q2) # (1024, 1) 258 | actor_loss = (alpha * log_pi - Q).mean() # 标量 259 | self.actor_opt.zero_grad(set_to_none=True) 260 | actor_loss.backward() 261 | # torch.nn.utils.clip_grad_norm_(self.actor.parameters(), 5) 262 | actor_grad_norm = utils.grad_norm(self.actor.parameters()) 263 | self.actor_opt.step() 264 | 265 | if self.use_tb: 266 | metrics['actor_loss'] = actor_loss.item() 267 | metrics['actor_ent'] = -log_pi.mean().item() 268 | metrics['actor_alpha'] = alpha.item() 269 | metrics['actor_alpha_loss'] = alpha_loss.item() 270 | metrics['actor_mean'] = policy.loc.mean().item() 271 | metrics['actor_std'] = policy.scale.mean().item() 272 | metrics['actor_action'] = sampled_action.mean().item() 273 | metrics['actor_atanh_action'] = utils.atanh(sampled_action).mean().item() 274 | metrics['actor_grad_norm'] = actor_grad_norm 275 | 276 | return metrics 277 | 278 | # TODO: conservative data sharing 279 | def conservative_data_share(self, batch_main, batch_share): 280 | obs_all, action_all, next_obs_all, 
reward_all, discount_all, obs_k_all = [], [], [], [], [], [] 281 | 282 | # 1. sample examples from the main task 283 | # shape = (512, 24), (512, 6), (512, 1), (512, 1), (512, 24) 284 | obs_m, action_m, reward_m, discount_m, next_obs_m, _ = utils.to_torch(batch_main, self.device) 285 | # print("main samples:", obs_m.shape, action_m.shape, reward_m.shape, discount_m.shape, next_obs_m.shape, eps_flag_m.shape) 286 | obs_all.append(obs_m) 287 | action_all.append(action_m) 288 | next_obs_all.append(next_obs_m) 289 | reward_all.append(reward_m) 290 | discount_all.append(discount_m) 291 | 292 | # 2. sample examples from other tasks 293 | # sample 10 times samples, and select the top-0.1 samples. (batch_size_split*10, xx, xx) 294 | # shape = (5120, 24) (5120, 6) (5120, 1) (5120, 1) (5120, 24) 295 | obs, action, reward, discount, next_obs, _ = utils.to_torch(batch_share, self.device) 296 | # print("obs:", obs.shape, action.shape, reward.shape, discount.shape, next_obs.shape) 297 | # calculate the conservative Q value 298 | with torch.no_grad(): 299 | conservative_q_value, _ = self.critic(obs, action) # (5120, 1) 300 | conservative_q_value = conservative_q_value.squeeze().detach().cpu().numpy() 301 | # choose the top 0.1 top_index.shape=(512,) 302 | top_index = np.argpartition(conservative_q_value, -obs_m.shape[0])[-obs_m.shape[0]:] # find the top 0.1 index. 
(batch_size_split,) 303 | # extract the samples 304 | obs_all.append(obs[top_index]) 305 | action_all.append(action[top_index]) 306 | next_obs_all.append(next_obs[top_index]) 307 | reward_all.append(reward[top_index]) 308 | discount_all.append(discount[top_index]) 309 | 310 | return torch.cat(obs_all, dim=0), torch.cat(action_all, dim=0), torch.cat(reward_all, dim=0), \ 311 | torch.cat(discount_all, dim=0), torch.cat(next_obs_all, dim=0) 312 | 313 | def update(self, replay_iter_main, replay_iter_share, step, total_step): 314 | metrics = dict() 315 | 316 | batch_main = next(replay_iter_main) 317 | batch_share = next(replay_iter_share) 318 | 319 | # print("conservative data sharing...") # obs.shape=(1024, 24), action.shape=(1024, 6) reward.shape=(1024, 1) 320 | obs, action, reward, discount, next_obs = self.conservative_data_share(batch_main, batch_share) 321 | 322 | if self.use_tb: 323 | metrics['batch_reward'] = reward.mean().item() 324 | 325 | # update critic 326 | metrics.update( 327 | self.update_critic(obs, action, reward, discount, next_obs, step)) 328 | 329 | # update actor 330 | metrics.update(self.update_actor(obs, action, step)) 331 | 332 | # update critic target 333 | utils.soft_update_params(self.critic, self.critic_target, self.critic_target_tau) 334 | 335 | return metrics 336 | -------------------------------------------------------------------------------- /agent/cql.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from collections import OrderedDict 7 | 8 | import utils 9 | from dm_control.utils import rewards 10 | 11 | 12 | class Actor(nn.Module): 13 | def __init__(self, obs_dim, action_dim, hidden_dim, init_w=1e-3): 14 | super().__init__() 15 | self.policy = nn.Sequential( 16 | nn.Linear(obs_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.Tanh(), 17 | nn.Linear(hidden_dim, hidden_dim), 
nn.LayerNorm(hidden_dim), nn.Tanh(), # 新增 18 | nn.Linear(hidden_dim, hidden_dim), nn.LeakyReLU()) 19 | 20 | self.fc_mu = nn.Linear(hidden_dim, action_dim) 21 | self.fc_logstd = nn.Linear(hidden_dim, action_dim) 22 | 23 | def forward(self, obs): 24 | f = self.policy(obs) 25 | mu = self.fc_mu(f) 26 | log_std = self.fc_logstd(f) 27 | mu = torch.clamp(mu, -5, 5) 28 | 29 | std = log_std.clamp(-5, 2).exp() 30 | dist = utils.SquashedNormal2(mu, std) 31 | return dist 32 | 33 | 34 | class Critic(nn.Module): 35 | def __init__(self, obs_dim, action_dim, hidden_dim, init_w=1e-3): 36 | super().__init__() 37 | 38 | self.q1_net = nn.Sequential( 39 | nn.Linear(obs_dim + action_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.Tanh(), 40 | nn.Linear(hidden_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.LeakyReLU(), 41 | nn.Linear(hidden_dim, hidden_dim), nn.LeakyReLU()) 42 | self.q1_last = nn.Linear(hidden_dim, 1) 43 | 44 | self.q2_net = nn.Sequential( 45 | nn.Linear(obs_dim + action_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.Tanh(), 46 | nn.Linear(hidden_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.LeakyReLU(), 47 | nn.Linear(hidden_dim, hidden_dim), nn.LeakyReLU()) 48 | self.q2_last = nn.Linear(hidden_dim, 1) 49 | 50 | def forward(self, obs, action): 51 | obs_action = torch.cat([obs, action], dim=-1) 52 | q1 = self.q1_net(obs_action) 53 | q1 = self.q1_last(q1) 54 | 55 | q2 = self.q2_net(obs_action) 56 | q2 = self.q2_last(q2) 57 | 58 | return q1, q2 59 | 60 | 61 | class CQLAgent: 62 | def __init__(self, 63 | name, 64 | obs_shape, 65 | action_shape, 66 | device, 67 | actor_lr, 68 | critic_lr, 69 | hidden_dim, 70 | critic_target_tau, 71 | nstep, 72 | batch_size, 73 | use_tb, 74 | alpha, 75 | n_samples, 76 | target_cql_penalty, 77 | use_critic_lagrange, 78 | num_expl_steps, 79 | has_next_action=False): 80 | self.num_expl_steps = num_expl_steps 81 | self.action_dim = action_shape[0] 82 | self.hidden_dim = hidden_dim 83 | self.actor_lr = actor_lr 84 | self.critic_lr = critic_lr 85 | 
self.device = device 86 | self.critic_target_tau = critic_target_tau 87 | self.use_tb = use_tb 88 | self.use_critic_lagrange = use_critic_lagrange 89 | self.target_cql_penalty = target_cql_penalty 90 | 91 | self.alpha = alpha 92 | self.n_samples = n_samples 93 | 94 | state_dim = obs_shape[0] 95 | action_dim = action_shape[0] 96 | 97 | # models 98 | self.actor = Actor(state_dim, action_dim, hidden_dim).to(device) 99 | self.critic = Critic(state_dim, action_dim, hidden_dim).to(device) 100 | self.critic_target = Critic(state_dim, action_dim, hidden_dim).to(device) 101 | self.critic_target.load_state_dict(self.critic.state_dict()) 102 | 103 | # lagrange multipliers 104 | self.target_entropy = -self.action_dim 105 | self.log_actor_alpha = torch.zeros(1, requires_grad=True, device=device) 106 | self.log_critic_alpha = torch.zeros(1, requires_grad=True, device=device) 107 | 108 | # optimizers 109 | self.actor_opt = torch.optim.Adam(self.actor.parameters(), lr=actor_lr) 110 | self.critic_opt = torch.optim.Adam(self.critic.parameters(), lr=critic_lr) 111 | self.actor_alpha_opt = torch.optim.Adam([self.log_actor_alpha], lr=actor_lr) 112 | self.critic_alpha_opt = torch.optim.Adam([self.log_critic_alpha], lr=actor_lr) 113 | 114 | self.train() 115 | 116 | def train(self, training=True): 117 | self.training = training 118 | self.actor.train(training) 119 | self.critic.train(training) 120 | 121 | def act(self, obs, step, eval_mode): 122 | obs = torch.as_tensor(obs, device=self.device).unsqueeze(0) 123 | policy = self.actor(obs) 124 | if eval_mode: 125 | action = policy.mean 126 | else: 127 | action = policy.sample() 128 | if step < self.num_expl_steps: 129 | action.uniform_(-1.0, 1.0) 130 | return action.cpu().numpy()[0] 131 | 132 | def _repeated_critic_apply(self, obs, actions): 133 | """ 134 | obs is (batch_size, obs_dim) 135 | actions is (n_samples, batch_size, action_dim) 136 | 137 | output tensors are (n_samples, batch_size, 1) 138 | """ 139 | batch_size = obs.shape[0] 140 | 
n_samples = actions.shape[0] 141 | 142 | reshaped_actions = actions.reshape((n_samples * batch_size, -1)) 143 | repeated_obs = obs.unsqueeze(0).repeat((n_samples, 1, 1)) 144 | repeated_obs = repeated_obs.reshape((n_samples * batch_size, -1)) 145 | 146 | Q1, Q2 = self.critic(repeated_obs, reshaped_actions) 147 | Q1 = Q1.reshape((n_samples, batch_size, 1)) 148 | Q2 = Q2.reshape((n_samples, batch_size, 1)) 149 | 150 | return Q1, Q2 151 | 152 | def update_critic(self, obs, action, reward, discount, next_obs, step): 153 | metrics = dict() 154 | 155 | # Compute standard SAC loss 156 | with torch.no_grad(): 157 | dist = self.actor(next_obs) # SquashedNormal distribution 158 | sampled_next_action = dist.sample() # (1024, act_dim) 159 | # print("sampled_next_action:", sampled_next_action.shape) 160 | target_Q1, target_Q2 = self.critic_target(next_obs, sampled_next_action) # (1024,1), (1024,1) 161 | target_V = torch.min(target_Q1, target_Q2) # (1024,1) 162 | target_Q = reward + (discount * target_V) # (1024,1) 163 | 164 | Q1, Q2 = self.critic(obs, action) 165 | critic_loss = F.mse_loss(Q1, target_Q) + F.mse_loss(Q2, target_Q) # scalar 166 | 167 | # Add CQL penalty 168 | with torch.no_grad(): 169 | random_actions = torch.FloatTensor(self.n_samples, Q1.shape[0], 170 | action.shape[-1]).uniform_(-1, 1).to(self.device) # (n_samples, 1024, act_dim) 171 | sampled_actions = self.actor(obs).sample( 172 | sample_shape=(self.n_samples,)) # (n_samples, 1024, act_dim) 173 | # print("sampled_actions:", sampled_actions.shape) 174 | next_sampled_actions = self.actor(next_obs).sample( 175 | sample_shape=(self.n_samples,)) # (n_samples, 1024, act_dim) 176 | # print("next_sampled_actions:", next_sampled_actions.shape) 177 | 178 | rand_Q1, rand_Q2 = self._repeated_critic_apply(obs, random_actions) # (n_samples, 1024, 1) 179 | sampled_Q1, sampled_Q2 = self._repeated_critic_apply( # (n_samples, 1024, 1) 180 | obs, sampled_actions) 181 | next_sampled_Q1, next_sampled_Q2 = 
self._repeated_critic_apply( # (n_samples, 1024, 1) 182 | obs, next_sampled_actions) 183 | 184 | # situation 1 185 | cat_Q1 = torch.cat([rand_Q1, sampled_Q1, next_sampled_Q1, 186 | Q1.unsqueeze(0)], dim=0) # (1+3*n_samples, 1024, 1) 187 | cat_Q2 = torch.cat([rand_Q2, sampled_Q2, next_sampled_Q2, 188 | Q2.unsqueeze(0)], dim=0) # (1+3*n_samples, 1024, 1) 189 | 190 | 191 | # situation 2 192 | # cat_Q1 = torch.cat([rand_Q1 - np.log(0.5 ** self.action_dim), 193 | # sampled_Q1 - self.actor(obs).log_prob(sampled_actions).unsqueeze(-1).detach(), 194 | # next_sampled_Q1 - self.actor(obs).log_prob(next_sampled_actions).unsqueeze(-1).detach()], dim=0) # (1+3*n_samples, 1024, 1) 195 | # cat_Q2 = torch.cat([rand_Q2 - np.log(0.5 ** self.action_dim), 196 | # sampled_Q2 - self.actor(obs).log_prob(sampled_actions).unsqueeze(-1).detach(), 197 | # next_sampled_Q2 - self.actor(obs).log_prob(next_sampled_actions).unsqueeze(-1).detach()], dim=0) # (1+3*n_samples, 1024, 1) 198 | 199 | # cat_Q1 = torch.cat([rand_Q1 - np.log(0.5 ** self.action_dim), 200 | # sampled_Q1 - self.actor(obs).log_prob(sampled_actions).sum(-1, keepdims=True).detach(), 201 | # next_sampled_Q1 - self.actor(obs).log_prob(next_sampled_actions).sum(-1, keepdims=True).detach()], dim=0) # (1+3*n_samples, 1024, 1) 202 | # cat_Q2 = torch.cat([rand_Q2 - np.log(0.5 ** self.action_dim), 203 | # sampled_Q2 - self.actor(obs).log_prob(sampled_actions).sum(-1, keepdims=True).detach(), 204 | # next_sampled_Q2 - self.actor(obs).log_prob(next_sampled_actions).sum(-1, keepdims=True).detach()], dim=0) # (1+3*n_samples, 1024, 1) 205 | 206 | # print("cat Q1:", cat_Q1.shape, cat_Q2.shape) 207 | # if torch.isnan(cat_Q1).any(): 208 | # print("in cql 1:", torch.isnan(rand_Q1).any(), torch.isnan(sampled_Q1).any(), torch.isnan(rand_Q1).any(), torch.isnan(next_sampled_Q1).any()) 209 | # print("in cql 2:", torch.isnan(self.actor(obs).log_prob(sampled_actions)).any()) 210 | # print("in cql 3:", 
torch.isnan(self.actor(obs).log_prob(next_sampled_actions)).any()) 211 | 212 | assert (not torch.isnan(cat_Q1).any()) and (not torch.isnan(cat_Q2).any()) 213 | 214 | cql_logsumexp1 = torch.logsumexp(cat_Q1, dim=0).mean() 215 | cql_logsumexp2 = torch.logsumexp(cat_Q2, dim=0).mean() 216 | cql_logsumexp = cql_logsumexp1 + cql_logsumexp2 217 | # print("train 1:", rand_Q1.mean(), np.log(0.5 ** self.action_dim)) 218 | # dist_test = self.actor(obs) 219 | # print("train 2:", dist_test.loc.shape, dist_test.scale.shape, torch.isnan(dist_test.loc).any(), torch.isnan(dist_test.scale).any()) 220 | # print("train 3:", sampled_Q1.mean(), self.actor(obs).log_prob(sampled_actions).detach().mean()) 221 | # print("train 4:", next_sampled_Q1.mean(), self.actor(obs).log_prob(next_sampled_actions).detach().mean()) 222 | 223 | cql_penalty = cql_logsumexp - (Q1 + Q2).mean() # 标量 224 | 225 | # Update lagrange multiplier 226 | if self.use_critic_lagrange: 227 | alpha = torch.clamp(self.log_critic_alpha.exp(), min=0.0, max=1000000.0) 228 | alpha_loss = -0.5 * alpha * (cql_penalty - self.target_cql_penalty) 229 | 230 | self.critic_alpha_opt.zero_grad() 231 | alpha_loss.backward(retain_graph=True) 232 | self.critic_alpha_opt.step() 233 | alpha = torch.clamp(self.log_critic_alpha.exp(), 234 | min=0.0, 235 | max=1000000.0).detach() 236 | else: 237 | alpha = self.alpha 238 | 239 | # Combine losses 240 | critic_loss = critic_loss + alpha * cql_penalty 241 | 242 | # optimize critic 243 | self.critic_opt.zero_grad(set_to_none=True) 244 | critic_loss.backward() 245 | # torch.nn.utils.clip_grad_norm_(self.critic.parameters(), 50) 246 | critic_grad_norm = utils.grad_norm(self.critic.parameters()) 247 | self.critic_opt.step() 248 | 249 | if self.use_tb: 250 | metrics['critic_target_q'] = target_Q.mean().item() 251 | metrics['critic_q1'] = Q1.mean().item() 252 | # metrics['critic_q2'] = Q2.mean().item() 253 | metrics['critic_loss'] = critic_loss.item() 254 | metrics['critic_cql'] = cql_penalty.item() 255 
| metrics['critic_cql_logsum1'] = cql_logsumexp1.item() 256 | # metrics['critic_cql_logsum2'] = cql_logsumexp1.item() 257 | metrics['rand_Q1'] = rand_Q1.mean().item() 258 | # metrics['rand_Q2'] = rand_Q2.mean().item() 259 | metrics['sampled_Q1'] = sampled_Q1.mean().item() 260 | # metrics['sampled_Q2'] = sampled_Q2.mean().item() 261 | metrics['critic_grad_norm'] = critic_grad_norm 262 | 263 | return metrics 264 | 265 | def update_actor(self, obs, action, step): 266 | metrics = dict() 267 | 268 | policy = self.actor(obs) 269 | sampled_action = policy.rsample() # (1024, 6) 270 | # print("sampled_action:", sampled_action.shape) 271 | log_pi = policy.log_prob(sampled_action) # (1024, 6) 272 | 273 | # update lagrange multiplier 274 | alpha_loss = -(self.log_actor_alpha * (log_pi + self.target_entropy).detach()).mean() 275 | self.actor_alpha_opt.zero_grad(set_to_none=True) 276 | alpha_loss.backward() 277 | self.actor_alpha_opt.step() 278 | alpha = self.log_actor_alpha.exp().detach() 279 | 280 | # optimize actor 281 | Q1, Q2 = self.critic(obs, sampled_action) # (1024, 1) 282 | Q = torch.min(Q1, Q2) # (1024, 1) 283 | actor_loss = (alpha * log_pi - Q).mean() # 标量 284 | self.actor_opt.zero_grad(set_to_none=True) 285 | actor_loss.backward() 286 | # torch.nn.utils.clip_grad_norm_(self.actor.parameters(), 5) 287 | actor_grad_norm = utils.grad_norm(self.actor.parameters()) 288 | self.actor_opt.step() 289 | 290 | if self.use_tb: 291 | metrics['actor_loss'] = actor_loss.item() 292 | metrics['actor_ent'] = -log_pi.mean().item() 293 | metrics['actor_alpha'] = alpha.item() 294 | metrics['actor_alpha_loss'] = alpha_loss.item() 295 | metrics['actor_mean'] = policy.loc.mean().item() 296 | metrics['actor_std'] = policy.scale.mean().item() 297 | metrics['actor_action'] = sampled_action.mean().item() 298 | metrics['actor_atanh_action'] = utils.atanh(sampled_action).mean().item() 299 | metrics['actor_grad_norm'] = actor_grad_norm 300 | 301 | return metrics 302 | 303 | def update(self, 
replay_iter, step, total_step): 304 | metrics = dict() 305 | 306 | batch = next(replay_iter) 307 | # obs.shape=(1024,obs_dim), action.shape=(1024,1), reward.shape=(1024,1), discount.shape=(1024,1) 308 | obs, action, reward, discount, next_obs, _ = utils.to_torch( 309 | batch, self.device) 310 | 311 | if self.use_tb: 312 | metrics['batch_reward'] = reward.mean().item() 313 | 314 | # update critic 315 | metrics.update( 316 | self.update_critic(obs, action, reward, discount, next_obs, step)) 317 | 318 | # update actor 319 | metrics.update(self.update_actor(obs, action, step)) 320 | 321 | # update critic target 322 | utils.soft_update_params(self.critic, self.critic_target, self.critic_target_tau) 323 | 324 | return metrics 325 | --------------------------------------------------------------------------------