├── agent
│   ├── env.png
│   ├── bc.yaml
│   ├── crr.yaml
│   ├── td3.yaml
│   ├── td3_bc.yaml
│   ├── ddpg.yaml
│   ├── cql.yaml
│   ├── cql_cdsz.yaml
│   ├── cql_cds.yaml
│   ├── pbrl.yaml
│   ├── bc.py
│   ├── td3_bc.py
│   ├── td3.py
│   ├── crr.py
│   ├── ddpg.py
│   ├── pbrl.py
│   ├── cql_cdsz.py
│   ├── cql_cds.py
│   └── cql.py
├── custom_dmc_tasks
│   ├── common
│   │   ├── visual.xml
│   │   ├── skybox.xml
│   │   └── materials.xml
│   ├── test.py
│   ├── __init__.py
│   ├── point_mass_maze_reach_bottom_left.xml
│   ├── point_mass_maze_reach_top_left.xml
│   ├── point_mass_maze_reach_top_right.xml
│   ├── point_mass_maze_reach_bottom_right.xml
│   ├── hopper.xml
│   ├── walker.xml
│   ├── test.xml
│   ├── cheetah.xml
│   ├── cheetah.py
│   ├── point_mass_maze.py
│   ├── hopper.py
│   ├── jaco.py
│   └── walker.py
├── conda_env.yml
├── LICENSE
├── task.json
├── config_cds.yaml
├── config_single.yaml
├── config.yaml
├── collect_data.yaml
├── video.py
├── train_offline_single.py
├── README.md
├── train_offline_share.py
├── replay_buffer.py
├── train_offline_cds.py
├── logger.py
├── replay_buffer_collect.py
└── collect_data.py
/agent/env.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Baichenjia/UTDS/HEAD/agent/env.png
--------------------------------------------------------------------------------
/custom_dmc_tasks/common/visual.xml:
--------------------------------------------------------------------------------
(XML content omitted in this listing)
--------------------------------------------------------------------------------
/custom_dmc_tasks/common/skybox.xml:
--------------------------------------------------------------------------------
(XML content omitted in this listing)
--------------------------------------------------------------------------------
/agent/bc.yaml:
--------------------------------------------------------------------------------
1 | # @package agent
2 | _target_: agent.bc.BCAgent
3 | name: bc
4 | obs_shape: ??? # to be specified later
5 | action_shape: ??? # to be specified later
6 | device: ${device}
7 | lr: 1e-4
8 | use_tb: ${use_tb}
9 | hidden_dim: 1024
10 | batch_size: 1024 # 256 for pixels
11 | has_next_action: False
12 | stddev_schedule: 0.2
--------------------------------------------------------------------------------
/agent/crr.yaml:
--------------------------------------------------------------------------------
1 | # @package agent
2 | _target_: agent.crr.CRRAgent
3 | name: crr
4 | obs_shape: ??? # to be specified later
5 | action_shape: ??? # to be specified later
6 | device: ${device}
7 | lr: 1e-4
8 | critic_target_tau: 0.01
9 | use_tb: ${use_tb}
10 | hidden_dim: 1024
11 | stddev_schedule: 0.2
12 | stddev_clip: 0.3
13 | nstep: 1
14 | batch_size: 1024 # 256 for pixels
15 | num_value_samples: 10
16 | weight_func: indicator
17 | has_next_action: False
--------------------------------------------------------------------------------
/agent/td3.yaml:
--------------------------------------------------------------------------------
1 | # @package agent
2 | _target_: agent.td3.TD3Agent
3 | name: td3
4 | #obs_type: ??? # to be specified later
5 | obs_shape: ??? # to be specified later
6 | action_shape: ??? # to be specified later
7 | device: ${device}
8 | lr: 1e-4
9 | critic_target_tau: 0.01
10 | use_tb: ${use_tb}
11 | hidden_dim: 1024
12 | stddev_schedule: 0.2
13 | stddev_clip: 0.3
14 | nstep: 1
15 | batch_size: 1024 # 256 for pixels
16 | has_next_action: False
17 | num_expl_steps: ??? # to be specified later
18 |
--------------------------------------------------------------------------------
/agent/td3_bc.yaml:
--------------------------------------------------------------------------------
1 | # @package agent
2 | _target_: agent.td3_bc.TD3BCAgent
3 | name: td3_bc
4 | obs_shape: ??? # to be specified later
5 | action_shape: ??? # to be specified later
6 | device: ${device}
7 | alpha: 2.5
8 | lr: 1e-4
9 | critic_target_tau: 0.005
10 | actor_target_tau: 0.005
11 | policy_freq: 2
12 | use_tb: ${use_tb}
13 | hidden_dim: 1024
14 | #stddev_schedule: 0.2
15 | #stddev_clip: 0.3
16 | policy_noise: 0.2
17 | noise_clip: 0.5
18 | #nstep: 1
19 | batch_size: 1024
20 | #has_next_action: False
21 | num_expl_steps: ??? # to be specified later
22 |
--------------------------------------------------------------------------------
/agent/ddpg.yaml:
--------------------------------------------------------------------------------
1 | # @package agent
2 | _target_: agent.ddpg.DDPGAgent
3 | name: ddpg
4 | #reward_free: ${reward_free}
5 | obs_type: ??? # to be specified later
6 | obs_shape: ??? # to be specified later
7 | action_shape: ??? # to be specified later
8 | device: ${device}
9 | lr: 1e-4
10 | critic_target_tau: 0.01
11 | update_every_steps: 2
12 | use_tb: ${use_tb}
13 | use_wandb: ${use_wandb}
14 | num_expl_steps: ??? # to be specified later
15 | hidden_dim: 1024
16 | feature_dim: 50
17 | stddev_schedule: 0.2
18 | stddev_clip: 0.3
19 | nstep: 3
20 | batch_size: 1024 # 256 for pixels
21 | init_critic: true
22 | update_encoder: ${update_encoder}
23 |
--------------------------------------------------------------------------------
/agent/cql.yaml:
--------------------------------------------------------------------------------
1 | # @package agent
2 | _target_: agent.cql.CQLAgent
3 | name: cql
4 | obs_shape: ??? # to be specified later
5 | action_shape: ??? # to be specified later
6 | device: ${device}
7 | #lr: 1e-4
8 | actor_lr: 1e-4
9 | critic_lr: 3e-4
10 | critic_target_tau: 0.01
11 |
12 | n_samples: 3
13 | use_critic_lagrange: False
14 | alpha: 50 # used if use_critic_lagrange is False
15 | target_cql_penalty: 5.0 # used if use_critic_lagrange is True
16 |
17 | use_tb: True
18 | hidden_dim: 256
19 | #stddev_schedule: 0.2
20 | #stddev_clip: 0.3
21 | nstep: 1
22 | batch_size: 1024
23 | has_next_action: False
24 |
25 | num_expl_steps: ??? # to be specified later
--------------------------------------------------------------------------------
/agent/cql_cdsz.yaml:
--------------------------------------------------------------------------------
1 | # @package agent
2 | _target_: agent.cql_cdsz.CQLCDSZeroAgent
3 | name: cql_cdsz
4 | obs_shape: ??? # to be specified later
5 | action_shape: ??? # to be specified later
6 | device: ${device}
7 | #lr: 1e-4
8 | actor_lr: 1e-4
9 | critic_lr: 3e-4
10 | critic_target_tau: 0.01
11 |
12 | n_samples: 3
13 | use_critic_lagrange: False
14 | alpha: 50 # used if use_critic_lagrange is False
15 | target_cql_penalty: 5.0 # used if use_critic_lagrange is True
16 |
17 | use_tb: True
18 | hidden_dim: 256
19 | #stddev_schedule: 0.2
20 | #stddev_clip: 0.3
21 | nstep: 1
22 | batch_size: 1024
23 | has_next_action: False
24 |
25 | num_expl_steps: ??? # to be specified later
--------------------------------------------------------------------------------
/agent/cql_cds.yaml:
--------------------------------------------------------------------------------
1 | # @package agent
2 | _target_: agent.cql_cds.CQLCDSAgent
3 | name: cql_cds
4 | obs_shape: ??? # to be specified later
5 | action_shape: ??? # to be specified later
6 | device: ${device}
7 | #lr: 1e-4
8 | actor_lr: 1e-4
9 | critic_lr: 3e-4
10 | critic_target_tau: 0.01
11 |
12 | n_samples: 3
13 | use_critic_lagrange: False
14 | alpha: 50 # used if use_critic_lagrange is False
15 | target_cql_penalty: 5.0 # used if use_critic_lagrange is True
16 |
17 | use_tb: True
18 | hidden_dim: 256 # 1024
19 | #stddev_schedule: 0.2
20 | #stddev_clip: 0.3
21 | nstep: 1
22 | batch_size: 1024 # 1024
23 | has_next_action: False
24 |
25 | num_expl_steps: ??? # to be specified later
--------------------------------------------------------------------------------
/conda_env.yml:
--------------------------------------------------------------------------------
1 | name: utds
2 | channels:
3 | - defaults
4 | dependencies:
5 | - python=3.8
6 | - pip=21.1.3
7 | - numpy=1.19.2
8 | - absl-py=0.13.0
9 | - pyparsing=2.4.7
10 | - jupyterlab=3.0.14
11 | - scikit-image=0.18.1
12 | - nvidia::cudatoolkit=11.1
13 | - pytorch::pytorch
14 | - pytorch::torchvision
15 | - pytorch::torchaudio
16 | - pip:
17 | - termcolor==1.1.0
18 |     - git+https://github.com/deepmind/dm_control.git
19 | - tb-nightly
20 | - imageio==2.9.0
21 | - imageio-ffmpeg==0.4.4
22 | - hydra-core==1.1.0
23 | - hydra-submitit-launcher==1.1.5
24 | - pandas==1.3.0
25 | - ipdb==0.13.9
26 | - yapf==0.31.0
27 | - sklearn==0.0
28 | - matplotlib==3.4.2
29 | - opencv-python==4.5.3.56
30 | - moviepy==1.0.3
31 |
--------------------------------------------------------------------------------
/agent/pbrl.yaml:
--------------------------------------------------------------------------------
1 | # @package agent
2 | _target_: agent.pbrl.PBRLAgent
3 | name: pbrl
4 | obs_shape: ??? # to be specified later
5 | action_shape: ??? # to be specified later
6 | device: ${device}
7 | #alpha: 2.5
8 | lr: 1e-4
9 | critic_target_tau: 0.005
10 | actor_target_tau: 0.005
11 | policy_freq: 2
12 | use_tb: True
13 | hidden_dim: 1024
14 | #stddev_schedule: 0.2
15 | #stddev_clip: 0.3
16 | policy_noise: 0.2
17 | noise_clip: 0.5
18 | #nstep: 1
19 | batch_size: 1024
20 | #has_next_action: False
21 | num_expl_steps: ??? # to be specified later
22 |
23 | # PBRL
24 | num_random: 3
25 | ucb_ratio_in: 0.001
26 | ensemble: 5
27 | ood_noise: 0.01 # action noise for sampling
28 |
29 | ucb_ratio_ood_init: 3.0 # 3.0
30 | ucb_ratio_ood_min: 0.1 # 0.1
31 | ood_decay_factor: 0.99995 # 0.99995 ucb ratio decay factor.
32 |
33 | share_ratio: 1.5 # penalty ratio for shared dataset
34 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/task.json:
--------------------------------------------------------------------------------
1 | {
2 | "walker":
3 | {
4 | "stand": {"expert": 986.8, "medium": 458.85},
5 | "walk": {"expert": 980.0, "medium": 681.4},
6 | "run": {"expert": 842.6, "medium": 411.0},
7 | "filp": {"expert": 839.5, "medium": 583.5}
8 | },
9 | "hopper":
10 | {
11 | "stand": 1000,
12 | "hop": 1000,
13 | "hop_backward": 1000,
14 | "flip": 1000,
15 | "flip_backward": 1000
16 | },
17 | "cheetah":
18 | {
19 | "run": 1000,
20 | "run_backward": 1000,
21 | "flip": 1000
22 | },
23 | "quadruped":
24 | {
25 | "jump": 1000,
26 | "run": 1000,
27 | "roll": 1000,
28 | "roll_fast": 1000,
29 | "stand": 1000,
30 | "walk": 1000,
31 | "escape": 1000,
32 | "fetch": 1000
33 | },
34 | "jaco":
35 | {
36 | "reach_top_left": 1000,
37 | "reach_top_right": 1000,
38 | "reach_bottom_left": 1000,
39 | "reach_bottom_right": 1000
40 | },
41 | "point_mass_maze":
42 | {
43 | "reach_top_left": 1000,
44 | "reach_top_right": 1000,
45 | "reach_bottom_left": 1000,
46 | "reach_bottom_right": 1000
47 | }
48 | }
49 |
--------------------------------------------------------------------------------
/custom_dmc_tasks/common/materials.xml:
--------------------------------------------------------------------------------
(XML content omitted in this listing)
--------------------------------------------------------------------------------
/config_cds.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - agent: cql_cds # cql_cds, cql_cdsz
3 | - override hydra/launcher: submitit_local
4 |
5 | # unsupervised exploration
6 | # expl_agent: td3
7 | # task settings
8 | task: walker_walk # main task to train (relabel the other datasets to this task)
9 | share_task: [walker_walk, walker_run] # tasks used for data sharing
10 | data_type: [medium, medium-replay] # dataset type for data sharing (corresponding to each share_task)
11 |
12 | discount: 0.99
13 | # train settings
14 | num_grad_steps: 1000000
15 | log_every_steps: 1000
16 | # eval
17 | eval_every_steps: 10000
18 | num_eval_episodes: 10
19 | # dataset
20 | replay_buffer_dir: ../../collect # make sure to update this if you change hydra run dir
21 | replay_buffer_size: 10000000 # max: 10M
22 | replay_buffer_num_workers: 4
23 | batch_size: ${agent.batch_size}
24 | # misc
25 | seed: 1
26 | device: cuda
27 | save_video: False
28 | use_tb: False
29 |
30 | # used for train_offline_single
31 | data_main: expert
32 |
33 | wandb: False
34 | hydra:
35 | run:
36 | dir: ./result_cds/${task}-Share_${share_task[0]}_${share_task[1]}-${data_type[0]}-${agent.name}-${now:%m-%d-%H-%M-%S}
37 |
--------------------------------------------------------------------------------
/config_single.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - agent: pbrl # td3_bc, cql, pbrl
3 | - override hydra/launcher: submitit_local
4 |
5 | # unsupervised exploration
6 | # expl_agent: td3
7 | # task settings
8 | task: walker_walk # main task to train (relabel the other datasets to this task)
9 | share_task: [walker_walk, walker_run] # tasks used for data sharing
10 | data_type: [medium, medium-replay] # dataset type for data sharing (corresponding to each share_task)
11 |
12 | discount: 0.99
13 | # train settings
14 | num_grad_steps: 1000000
15 | log_every_steps: 1000
16 | # eval
17 | eval_every_steps: 10000
18 | num_eval_episodes: 10
19 | # dataset
20 | replay_buffer_dir: ../../collect # make sure to update this if you change hydra run dir
21 | replay_buffer_size: 10000000 # max: 10M
22 | replay_buffer_num_workers: 4
23 | batch_size: ${agent.batch_size}
24 | # misc
25 | seed: 1
26 | device: cuda
27 | save_video: False
28 | use_tb: False
29 |
30 | # used for train_offline_single
31 | data_main: expert
32 |
33 | wandb: False
34 | hydra:
35 | run:
36 | # dir: ./result_cql/${task}-${data_main}-${agent.name}-${now:%m-%d-%H-%M-%S}
37 | dir: ./result_pbrl/${task}-${data_main}-${agent.name}-${now:%m-%d-%H-%M-%S}
38 | # dir: ./output/${task}-${data_main}-${agent.name}-${now:%m-%d-%H-%M-%S}
39 |
--------------------------------------------------------------------------------
/custom_dmc_tasks/test.py:
--------------------------------------------------------------------------------
1 | from dm_control import manipulation, suite
2 | from dm_control.suite.wrappers import action_scale, pixels
3 | from dm_env import StepType, specs
4 | import numpy as np
5 |
6 | env = suite.load("walker", "walk")
7 | env.reset()
8 | action = np.random.random(6)
9 | time_step = env.step(action)
10 | for k, v in time_step.observation.items():
11 | print("observation:", k, v.shape)
12 | print("original reward:", time_step.reward)
13 |
14 | state = env.physics.get_state()
15 | print("physics:", state.shape)
16 | print("-----")
17 | ##########
18 |
19 | new_env = suite.load("walker", "walk")
20 | reward_spec = new_env.reward_spec()
21 | new_env.reset()
22 | with new_env.physics.reset_context():
23 | new_env.physics.set_state(state)
24 | new_reward = new_env.task.get_reward(new_env.physics) # 输入是 env 当前的状态 physics, 通过 env.task.get_reward 函数输出奖励
25 | new_reward = np.full(reward_spec.shape, new_reward, reward_spec.dtype)
26 | print("new reward:", new_reward)
27 |
28 | # states = episode['physics']
29 | # for i in range(states.shape[0]):
30 | # with env.physics.reset_context():
31 | # env.physics.set_state(states[i])
32 | # reward = env.task.get_reward(env.physics) # 输入是 env 当前的状态 physics, 通过 env.task.get_reward 函数输出奖励
33 | # reward = np.full(reward_spec.shape, reward, reward_spec.dtype) # 改变shape和dtype
34 | # rewards.append(reward)
35 | # episode['reward'] = np.array(rewards, dtype=reward_spec.dtype)
36 | # return episode
37 |
--------------------------------------------------------------------------------
/config.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - agent: pbrl # td3_bc, cql, pbrl
3 | - override hydra/launcher: submitit_local
4 |
5 | # unsupervised exploration
6 | # expl_agent: td3
7 | # task settings
8 | task: walker_walk # main task to train (relabel the other datasets to this task)
9 | share_task: [walker_walk, walker_run] # tasks used for data sharing
10 | data_type: [medium, medium-replay] # dataset type for data sharing (corresponding to each share_task)
11 |
12 | discount: 0.99
13 | # train settings
14 | num_grad_steps: 1000000
15 | log_every_steps: 1000
16 | # eval
17 | eval_every_steps: 10000
18 | num_eval_episodes: 10
19 | # dataset
20 | replay_buffer_dir: ../../collect # make sure to update this if you change hydra run dir
21 | replay_buffer_size: 10000000 # max: 10M
22 | replay_buffer_num_workers: 4
23 | batch_size: ${agent.batch_size}
24 | # misc
25 | seed: 1
26 | device: cuda
27 | save_video: False
28 | use_tb: False
29 |
30 | # used for train_offline_single
31 | data_main: expert
32 |
33 | wandb: False
34 | hydra:
35 | run:
36 | # dir: ./result_td3bc_share/${task}-Share_${share_task[0]}_${share_task[1]}-Data_${data_type[0]}_${data_type[1]}-${agent.name}-${now:%m-%d-%H-%M-%S}
37 | # dir: ./result_cql/${task}-${data_main}-${agent.name}-${now:%m-%d-%H-%M-%S}
38 | # dir: ./result_pbrl/${task}-${data_main}-${agent.name}-${now:%m-%d-%H-%M-%S}
39 | dir: ./result_pbrl_share/${task}-Share_${share_task[0]}_${share_task[1]}-${data_type[0]}-${agent.name}-${now:%m-%d-%H-%M-%S}
40 | # dir: ./output/${task}-${data_main}-${agent.name}-${now:%m-%d-%H-%M-%S}
41 |
42 |
--------------------------------------------------------------------------------
/collect_data.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - agent: td3
3 | - override hydra/launcher: submitit_local
4 |
5 | # mode
6 | # reward_free: false
7 | # task settings
8 | task: walker_run
9 | #obs_type: states # [states, pixels]
10 | #frame_stack: 3 # only works if obs_type=pixels
11 | action_repeat: 1 # set to 2 for pixels
12 | discount: 0.99
13 | # train settings
14 | num_train_frames: 2000010 # 2M steps to converge
15 | num_seed_frames: 4000
16 | # eval
17 | eval_every_frames: 10000
18 | num_eval_episodes: 10
19 | # pretrained
20 | #snapshot_ts: 100000
21 | #snapshot_base_dir: ./pretrained_models
22 | # replay buffer
23 | replay_buffer_size: 1000000
24 | replay_buffer_num_workers: 4
25 | batch_size: ${agent.batch_size}
26 | nstep: ${agent.nstep}
27 | update_encoder: false # can be either true or false, depending on whether we want to fine-tune the encoder
28 | # misc
29 | seed: 1
30 | device: cuda
31 | save_video: true
32 | save_train_video: false
33 | use_tb: false
34 | use_wandb: false
35 | # experiment
36 | experiment: exp
37 |
38 |
39 | hydra:
40 | run:
41 |     dir: ./collect/${task}-${agent.name}-${now:%m-%d-%H-%M-%S}
42 | # dir: ./collect/${task}-${agent.name}-medium-replay # for collect_data_fixed, medium-replay
43 | # dir: ./collect/${task}-${agent.name}-expert # for collect_data_fixed, expert
44 | # dir: ./video/${task}-${agent.name}-video
45 | sweep:
46 | dir: ./exp_sweep/${now:%Y.%m.%d}/${now:%H%M}_${agent.name}_${experiment}
47 | subdir: ${hydra.job.num}
48 | launcher:
49 | timeout_min: 4300
50 | cpus_per_task: 10
51 | gpus_per_node: 1
52 | tasks_per_node: 1
53 | mem_gb: 160
54 | nodes: 1
55 | submitit_folder: ./exp_sweep/${now:%Y.%m.%d}/${now:%H%M}_${agent.name}_${experiment}/.slurm
56 |
--------------------------------------------------------------------------------
/custom_dmc_tasks/__init__.py:
--------------------------------------------------------------------------------
1 | from custom_dmc_tasks import cheetah
2 | from custom_dmc_tasks import walker
3 | from custom_dmc_tasks import hopper
4 | from custom_dmc_tasks import quadruped
5 | from custom_dmc_tasks import jaco
6 | from custom_dmc_tasks import point_mass_maze
7 |
8 |
9 | def make(domain, task,
10 | task_kwargs=None,
11 | environment_kwargs=None,
12 | visualize_reward=False):
13 |
14 | if domain == 'cheetah':
15 | return cheetah.make(task,
16 | task_kwargs=task_kwargs,
17 | environment_kwargs=environment_kwargs,
18 | visualize_reward=visualize_reward)
19 | elif domain == 'walker':
20 | return walker.make(task,
21 | task_kwargs=task_kwargs,
22 | environment_kwargs=environment_kwargs,
23 | visualize_reward=visualize_reward)
24 | elif domain == 'point_mass_maze':
25 | return point_mass_maze.make(task,
26 | task_kwargs=task_kwargs,
27 | environment_kwargs=environment_kwargs,
28 | visualize_reward=visualize_reward)
29 | elif domain == 'hopper':
30 | return hopper.make(task,
31 | task_kwargs=task_kwargs,
32 | environment_kwargs=environment_kwargs,
33 | visualize_reward=visualize_reward)
34 | elif domain == 'quadruped':
35 | return quadruped.make(task,
36 | task_kwargs=task_kwargs,
37 | environment_kwargs=environment_kwargs,
38 | visualize_reward=visualize_reward)
39 | else:
40 |         raise ValueError(f'{domain} not found')
41 |
42 | assert None
43 |
44 |
45 | def make_jaco(task, obs_type, seed):
46 | return jaco.make(task, obs_type, seed)
--------------------------------------------------------------------------------
/custom_dmc_tasks/point_mass_maze_reach_bottom_left.xml:
--------------------------------------------------------------------------------
(XML content omitted in this listing)
--------------------------------------------------------------------------------
/custom_dmc_tasks/point_mass_maze_reach_top_left.xml:
--------------------------------------------------------------------------------
(XML content omitted in this listing)
--------------------------------------------------------------------------------
/custom_dmc_tasks/point_mass_maze_reach_top_right.xml:
--------------------------------------------------------------------------------
(XML content omitted in this listing)
--------------------------------------------------------------------------------
/custom_dmc_tasks/point_mass_maze_reach_bottom_right.xml:
--------------------------------------------------------------------------------
(XML content omitted in this listing)
--------------------------------------------------------------------------------
/custom_dmc_tasks/hopper.xml:
--------------------------------------------------------------------------------
(XML content omitted in this listing)
--------------------------------------------------------------------------------
/custom_dmc_tasks/walker.xml:
--------------------------------------------------------------------------------
(XML content omitted in this listing)
--------------------------------------------------------------------------------
/video.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import imageio
3 | import numpy as np
4 | import wandb
5 |
6 |
7 | class VideoRecorder:
8 | def __init__(self,
9 | root_dir,
10 | render_size=256,
11 | fps=20,
12 | camera_id=0):
13 | if root_dir is not None:
14 | self.save_dir = root_dir / 'eval_video'
15 | self.save_dir.mkdir(exist_ok=True)
16 | else:
17 | self.save_dir = None
18 |
19 | self.render_size = render_size
20 | self.fps = fps
21 | self.frames = []
22 | self.camera_id = camera_id
23 |
24 | def init(self, env, enabled=True):
25 | self.frames = []
26 | self.enabled = self.save_dir is not None and enabled
27 | self.record(env)
28 |
29 | def record(self, env):
30 | if self.enabled:
31 | if hasattr(env, 'physics'):
32 | frame = env.physics.render(height=self.render_size,
33 | width=self.render_size,
34 | camera_id=self.camera_id)
35 | else:
36 | frame = env.render()
37 | self.frames.append(frame)
38 |
39 | def log_to_wandb(self):
40 | frames = np.transpose(np.array(self.frames), (0, 3, 1, 2))
41 | fps, skip = 6, 8
42 | wandb.log({
43 | 'eval/video':
44 | wandb.Video(frames[::skip, :, ::2, ::2], fps=fps, format="gif")
45 | })
46 |
47 | def save(self, file_name):
48 | if self.enabled:
49 | path = self.save_dir / file_name
50 | imageio.mimsave(str(path), self.frames, fps=self.fps)
51 |
52 |
53 | class TrainVideoRecorder:
54 | def __init__(self,
55 | root_dir,
56 | render_size=256,
57 | fps=20,
58 | camera_id=0,
59 | use_wandb=False):
60 | if root_dir is not None:
61 | self.save_dir = root_dir / 'train_video'
62 | self.save_dir.mkdir(exist_ok=True)
63 | else:
64 | self.save_dir = None
65 |
66 | self.render_size = render_size
67 | self.fps = fps
68 | self.frames = []
69 | self.camera_id = camera_id
70 | self.use_wandb = use_wandb
71 |
72 | def init(self, obs, enabled=True):
73 | self.frames = []
74 | self.enabled = self.save_dir is not None and enabled
75 | self.record(obs)
76 |
77 | def record(self, obs):
78 | if self.enabled:
79 | frame = cv2.resize(obs[-3:].transpose(1, 2, 0),
80 | dsize=(self.render_size, self.render_size),
81 | interpolation=cv2.INTER_CUBIC)
82 | self.frames.append(frame)
83 |
84 | def log_to_wandb(self):
85 | frames = np.transpose(np.array(self.frames), (0, 3, 1, 2))
86 | fps, skip = 6, 8
87 | wandb.log({
88 | 'train/video':
89 | wandb.Video(frames[::skip, :, ::2, ::2], fps=fps, format="gif")
90 | })
91 |
92 | def save(self, file_name):
93 | if self.enabled:
94 | if self.use_wandb:
95 | self.log_to_wandb()
96 | path = self.save_dir / file_name
97 | imageio.mimsave(str(path), self.frames, fps=self.fps)
98 |
--------------------------------------------------------------------------------
/agent/bc.py:
--------------------------------------------------------------------------------
1 | import hydra
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from collections import OrderedDict
7 | import functools
8 |
9 | import utils
10 | from dm_control.utils import rewards
11 |
12 |
13 | class Actor(nn.Module):
14 | def __init__(self, obs_dim, action_dim, hidden_dim):
15 | super().__init__()
16 |
17 | self.policy = nn.Sequential(nn.Linear(obs_dim, hidden_dim),
18 | nn.LayerNorm(hidden_dim), nn.Tanh(),
19 | nn.Linear(hidden_dim, hidden_dim),
20 | nn.ReLU(inplace=True),
21 | nn.Linear(hidden_dim, action_dim))
22 |
23 | self.apply(utils.weight_init)
24 |
25 | def forward(self, obs, std):
26 | mu = self.policy(obs)
27 | mu = torch.tanh(mu)
28 | std = torch.ones_like(mu) * std
29 |
30 | dist = utils.TruncatedNormal(mu, std)
31 | return dist
32 |
33 |
34 | class BCAgent:
35 | def __init__(self,
36 | name,
37 | obs_shape,
38 | action_shape,
39 | device,
40 | lr,
41 | hidden_dim,
42 | batch_size,
43 | stddev_schedule,
44 | use_tb,
45 | has_next_action=False):
46 | self.action_dim = action_shape[0]
47 | self.hidden_dim = hidden_dim
48 | self.lr = lr
49 | self.device = device
50 | self.stddev_schedule = stddev_schedule
51 | self.use_tb = use_tb
52 |
53 | # models
54 | self.actor = Actor(obs_shape[0], action_shape[0],
55 | hidden_dim).to(device)
56 |
57 | # optimizers
58 | self.actor_opt = torch.optim.Adam(self.actor.parameters(), lr=lr)
59 |
60 | self.train()
61 |
62 | def train(self, training=True):
63 | self.training = training
64 | self.actor.train(training)
65 |
66 | def act(self, obs, step, eval_mode):
67 | obs = torch.as_tensor(obs, device=self.device).unsqueeze(0)
68 | stddev = utils.schedule(self.stddev_schedule, step)
69 | policy = self.actor(obs, stddev)
70 | if eval_mode:
71 | action = policy.mean
72 | else:
73 | action = policy.sample(clip=None)
74 |             if step < getattr(self, 'num_expl_steps', 0):  # BCAgent does not set num_expl_steps; default to 0 (offline)
75 | action.uniform_(-1.0, 1.0)
76 | return action.cpu().numpy()[0]
77 |
78 | def update_actor(self, obs, action, step):
79 | metrics = dict()
80 |
81 | stddev = utils.schedule(self.stddev_schedule, step)
82 | policy = self.actor(obs, stddev)
83 |
84 | log_prob = policy.log_prob(action).sum(-1, keepdim=True)
85 | actor_loss = (-log_prob).mean()
86 |
87 | self.actor_opt.zero_grad(set_to_none=True)
88 | actor_loss.backward()
89 | self.actor_opt.step()
90 |
91 | if self.use_tb:
92 | metrics['actor_loss'] = actor_loss.item()
93 | metrics['actor_ent'] = policy.entropy().sum(dim=-1).mean().item()
94 |
95 | return metrics
96 |
97 |     def update(self, replay_iter, step, total_steps=None):  # total_steps is unused; accepted for compatibility with the training scripts
98 | metrics = dict()
99 |
100 | batch = next(replay_iter)
101 | obs, action, reward, discount, next_obs = utils.to_torch(
102 | batch, self.device)
103 |
104 | if self.use_tb:
105 | metrics['batch_reward'] = reward.mean().item()
106 |
107 | # update actor
108 | metrics.update(self.update_actor(obs, action, step))
109 |
110 | return metrics
111 |
--------------------------------------------------------------------------------
/custom_dmc_tasks/test.xml:
--------------------------------------------------------------------------------
(XML content omitted in this listing)
--------------------------------------------------------------------------------
/custom_dmc_tasks/cheetah.xml:
--------------------------------------------------------------------------------
(XML content omitted in this listing)
--------------------------------------------------------------------------------
/train_offline_single.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | warnings.filterwarnings('ignore', category=DeprecationWarning)
4 | import os
5 | import random
6 | os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
7 | os.environ['MUJOCO_GL'] = 'egl'
8 | import json
9 | from pathlib import Path
10 | import hydra
11 | import numpy as np
12 | import torch
13 | from dm_env import specs
14 | import dmc
15 | import utils
16 | from logger import Logger
17 | from replay_buffer import make_replay_loader
18 | from video import VideoRecorder
19 | import wandb
20 | from omegaconf import OmegaConf
21 |
22 |
23 | torch.backends.cudnn.benchmark = True
24 |
25 | with open("task.json", "r") as f:
26 | task_dict = json.load(f)
27 |
28 |
29 | def get_domain(task):
30 | if task.startswith('point_mass_maze'):
31 | return 'point_mass_maze'
32 | return task.split('_', 1)[0]
33 |
34 |
35 | def get_data_seed(seed, num_data_seeds):
36 | return (seed - 1) % num_data_seeds + 1
37 |
38 |
39 | def eval(global_step, agent, env, logger, num_eval_episodes, video_recorder):
40 | step, episode, total_reward = 0, 0, 0
41 | eval_until_episode = utils.Until(num_eval_episodes)
42 | while eval_until_episode(episode):
43 | time_step = env.reset()
44 | video_recorder.init(env, enabled=(episode == 0))
45 | while not time_step.last():
46 | with torch.no_grad(), utils.eval_mode(agent):
47 | action = agent.act(time_step.observation, step=global_step, eval_mode=True)
48 | time_step = env.step(action)
49 | video_recorder.record(env)
50 | total_reward += time_step.reward
51 | step += 1
52 |
53 | episode += 1
54 | video_recorder.save(f'{global_step}.mp4')
55 |
56 | with logger.log_and_dump_ctx(global_step, ty='eval') as log:
57 | log('episode_reward', total_reward / episode)
58 | log('episode_length', step / episode)
59 | log('step', global_step)
60 | return {"eval_episode_reward": total_reward / episode}
61 |
62 |
63 | @hydra.main(config_path='.', config_name='config_single')
64 | def main(cfg):
65 | work_dir = Path.cwd()
66 | print(f'workspace: {work_dir}')
67 |
68 | # random seeds
69 | cfg.seed = random.randint(0, 100000)
70 |
71 | utils.set_seed_everywhere(cfg.seed)
72 | device = torch.device(cfg.device)
73 |
74 | # create logger
75 | logger = Logger(work_dir, use_tb=cfg.use_tb)
76 |
77 | # create envs
78 | env = dmc.make(cfg.task, seed=cfg.seed)
79 |
80 | # create agent
81 | agent = hydra.utils.instantiate(cfg.agent,
82 | obs_shape=env.observation_spec().shape, action_shape=env.action_spec().shape,
83 | num_expl_steps=0)
84 |
85 | # create replay buffer
86 | replay_dir_list = []
87 | datasets_dir = work_dir / cfg.replay_buffer_dir
88 | replay_dir = datasets_dir.resolve() / Path(cfg.task+"-td3-"+str(cfg.data_main)) / 'data'
89 | print(f'replay dir: {replay_dir}')
90 | replay_dir_list.append(replay_dir)
91 |
92 |     # build the replay buffer (single task)
93 | replay_loader = make_replay_loader(env, replay_dir_list, cfg.replay_buffer_size,
94 | cfg.batch_size, cfg.replay_buffer_num_workers, cfg.discount,
95 | main_task=cfg.task, task_list=[cfg.task])
96 | replay_iter = iter(replay_loader) # OfflineReplayBuffer.sample function
97 | print("load data done.")
98 |
99 | # create video recorders
100 | video_recorder = VideoRecorder(work_dir if cfg.save_video else None)
101 |
102 | timer = utils.Timer()
103 | global_step = 0
104 |
105 | train_until_step = utils.Until(cfg.num_grad_steps)
106 | eval_every_step = utils.Every(cfg.eval_every_steps)
107 | log_every_step = utils.Every(cfg.log_every_steps)
108 |
109 | if cfg.wandb:
110 | wandb_dir = f"./wandb/{cfg.task}_{cfg.agent.name}_{cfg.data_main}_{cfg.seed}"
111 | if not os.path.exists(wandb_dir):
112 | os.makedirs(wandb_dir)
113 | wandb.init(project="UTDS", entity='', config=cfg, group=f'{cfg.task}_{cfg.agent.name}_{cfg.data_main}',
114 | name=f'{cfg.task}_{cfg.agent.name}_{cfg.data_main}', dir=wandb_dir)
115 | wandb.config.update(vars(cfg))
116 |
117 | while train_until_step(global_step):
118 | # try to evaluate
119 | if eval_every_step(global_step):
120 | logger.log('eval_total_time', timer.total_time(), global_step)
121 | eval_metrics = eval(global_step, agent, env, logger, cfg.num_eval_episodes, video_recorder)
122 | if cfg.wandb:
123 | wandb.log(eval_metrics)
124 |
125 | # train the agent
126 | metrics = agent.update(replay_iter, global_step, cfg.num_grad_steps)
127 | if cfg.wandb:
128 | wandb.log(metrics)
129 |
130 | # log
131 | logger.log_metrics(metrics, global_step, ty='train')
132 | if log_every_step(global_step):
133 | elapsed_time, total_time = timer.reset()
134 | with logger.log_and_dump_ctx(global_step, ty='train') as log:
135 | log('fps', cfg.log_every_steps / elapsed_time)
136 | log('total_time', total_time)
137 | log('step', global_step)
138 |
139 | global_step += 1
140 |
141 |
142 | if __name__ == '__main__':
143 | main()
144 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Pessimistic Value Iteration for Multi-Task Data Sharing
2 |
3 | This repo contains a PyTorch implementation and the datasets for our paper "Pessimistic Value Iteration for Multi-Task Data Sharing in Offline Reinforcement Learning", published in the *Artificial Intelligence* journal. The paper is available at this [link](https://www.sciencedirect.com/science/article/pii/S0004370223001947).
4 |
5 | ## Datasets
6 |
7 | We collect a Multi-Task Offline [Dataset](http://alchemist.wang-research.northwestern.edu/dataset/) based on DeepMind Control Suite (DMC).
8 | - Download the [Dataset](http://alchemist.wang-research.northwestern.edu/dataset/) to `./collect` before you start training.
9 | - Users can also collect new datasets with `collect_data.py`. The supported tasks
10 |   include the standard DMC tasks and the custom tasks in `./custom_dmc_tasks/`.
11 |
12 | Our dataset contains 3 domains with 4 tasks per domain, resulting in 12 tasks in total.
13 |
14 | | Domain | Available task names |
15 | |---|---|
16 | | Walker | `walker_stand`, `walker_walk`, `walker_run`, `walker_flip` |
17 | | Quadruped | `quadruped_jump`, `quadruped_roll_fast`, `quadruped_walk`, `quadruped_run` |
18 | | Jaco Arm | `jaco_reach_top_left`, `jaco_reach_top_right`, `jaco_reach_bottom_left`, `jaco_reach_bottom_right` |
19 |
20 | ![Environments](./agent/env.png)
21 |
22 |
23 | For each task, we run `TD3` to collect five types of datasets (the expected on-disk layout is sketched after this list):
24 | - `random` data generated by a random agent.
25 | - `medium` data generated by a medium-level TD3 agent.
26 | - `medium-replay` data containing all experiences collected while training a medium-level TD3 agent.
27 | - `medium-expert` data containing all experiences collected while training an expert-level TD3 agent.
28 | - `expert` data generated by an expert-level TD3 agent.
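
As a sketch of the expected on-disk layout (inferred from `replay_buffer_dir` in `config*.yaml` and the path construction in `train_offline_share.py`), each dataset lives in a `<task>-td3-<data_type>` folder whose `data/` subdirectory holds the `*.npz` episode files:
```sh
ls ./collect/walker_walk-td3-medium/data/         # *.npz episodes for walker_walk (medium)
ls ./collect/walker_run-td3-medium-replay/data/   # *.npz episodes for walker_run (medium-replay)
```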
29 |
30 |
31 | ## Prerequisites
32 |
33 | Install [MuJoCo](http://www.mujoco.org/):
34 |
35 | * Download MuJoCo binaries [here](https://mujoco.org/download).
36 | * Unzip the downloaded archive into `~/.mujoco/`.
37 | * Append the path of the MuJoCo `bin` subdirectory to the `LD_LIBRARY_PATH` environment variable (see the example below).
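
For example (a sketch, assuming MuJoCo 2.1.0 was unpacked to `~/.mujoco/mujoco210`; adjust the folder name to your version):
```sh
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/.mujoco/mujoco210/bin
```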
38 |
39 | Install the following libraries:
40 | ```sh
41 | sudo apt update
42 | sudo apt install libosmesa6-dev libgl1-mesa-glx libglfw3 unzip
43 | ```
44 |
45 | Install dependencies:
46 | ```sh
47 | conda env create -f conda_env.yml
48 | conda activate utds
49 | ```
50 |
51 | ## Algorithms
52 |
53 | We provide several algorithms for training *single-task* agents and *multi-task data-sharing* agents.
54 |
55 | - For single-agent training, we provide the following algorithms.
56 |
57 | | Algorithm | Name | Paper |
58 | |---|---|---|
59 | | Behavior Cloning | `bc` | [paper](https://arxiv.org/abs/1805.01954)|
60 | | CQL | `cql` | [paper](https://proceedings.neurips.cc/paper/2020/file/0d2b2061826a5df3221116a5085a6052-Paper.pdf)|
61 | | TD3-BC | `td3_bc` | [paper](https://proceedings.neurips.cc/paper/2021/hash/a8166da05c5a094f7dc03724b41886e5-Abstract.html)|
62 | | CRR | `crr` | [paper](http://proceedings.neurips.cc/paper/2020/file/588cb956d6bbe67078f29f8de420a13d-Paper.pdf)|
63 | | PBRL | `pbrl` | [paper](https://openreview.net/forum?id=Y4cs1Z3HnqL)|
64 |
65 | - For multi-task data sharing, we support the following algorithms (an agent-selection example follows the table).
66 |
67 |
68 | | Algorithm | Name | Paper |
69 | |---|---|---|
70 | | Direct Sharing | `cql` | [paper](https://proceedings.neurips.cc/paper/2020/file/0d2b2061826a5df3221116a5085a6052-Paper.pdf)|
71 | | CDS | `cql_cds` | [paper](https://proceedings.neurips.cc/paper/2021/hash/5fd2c06f558321eff612bbbe455f6fbd-Abstract.html)|
72 | | Unlabeled-CDS | `cql_cdsz` | [paper](https://proceedings.mlr.press/v162/yu22c/yu22c.pdf)|
73 | | UTDS | `pbrl` | our paper |
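
The agent is selected through the Hydra `agent` config group (see the `defaults` list in `config*.yaml`). As a sketch, switching the CDS runner to the unlabeled-CDS agent, mirroring the command in the Training section:
```sh
# agent=cql_cdsz selects agent/cql_cdsz.yaml; the remaining arguments match the CDS example below
python train_offline_cds.py agent=cql_cdsz task=quadruped_jump "+share_task=[quadruped_jump, quadruped_roll_fast]" "+data_type=[random, replay]"
```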
74 |
75 |
76 | ## Training
77 |
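#### Train a single-task agent

To train a single-task offline agent on one dataset, a minimal sketch (assuming the corresponding `walker_walk-td3-expert` dataset has been downloaded to `./collect`; `task`, `data_main`, and `agent` are keys in `config_single.yaml`):
```sh
python train_offline_single.py task=walker_walk data_main=expert agent=td3_bc
```
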
78 | #### Train CDS
79 |
80 | To train the CDS agent on the `quadruped_jump` (`random`) task with data shared from the `quadruped_roll_fast` (`replay`) dataset, run:
81 | ```
82 | python train_offline_cds.py task=quadruped_jump "+share_task=[quadruped_jump, quadruped_roll_fast]" "+data_type=[random, replay]"
83 | ```
84 |
85 | #### Train UTDS
86 |
87 | To train the UTDS agent on the `quadruped_jump` (`random`) task with data shared from the `quadruped_roll_fast` (`replay`) dataset, run:
88 | ```
89 | python train_offline_share.py task=quadruped_jump "+share_task=[quadruped_jump, quadruped_roll_fast]" "+data_type=[random, replay]"
90 | ```
91 |
92 | We support Weights & Biases logging by setting `wandb: True` in the corresponding `config*.yaml` file, or from the command line as shown below.
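
For example (a sketch; `wandb` is an existing key in `config.yaml`, so a plain Hydra override works):
```sh
python train_offline_share.py task=quadruped_jump "+share_task=[quadruped_jump, quadruped_roll_fast]" "+data_type=[random, replay]" wandb=True
```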
93 |
94 | ## Citation
95 | ```
96 | @article{UTDS2023,
97 | title = {Pessimistic Value Iteration for Multi-Task Data Sharing in Offline Reinforcement Learning},
98 | journal = {Artificial Intelligence},
99 | author = {Chenjia Bai and Lingxiao Wang and Jianye Hao and Zhuoran Yang and Bin Zhao and Zhen Wang and Xuelong Li},
100 | pages = {104048},
101 | year = {2023},
102 | issn = {0004-3702},
103 | doi = {https://doi.org/10.1016/j.artint.2023.104048},
104 | url = {https://www.sciencedirect.com/science/article/pii/S0004370223001947},
105 | }
106 | ```
107 |
108 | ## License
109 | MIT license
110 |
--------------------------------------------------------------------------------
/train_offline_share.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | warnings.filterwarnings('ignore', category=DeprecationWarning)
4 |
5 | import os
6 | import random
7 | os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
8 | os.environ['MUJOCO_GL'] = 'egl'
9 | import json
10 | from pathlib import Path
11 | import hydra
12 | import numpy as np
13 | import torch
14 | from dm_env import specs
15 | import dmc
16 | import utils
17 | from logger import Logger
18 | from replay_buffer import make_replay_loader
19 | from video import VideoRecorder
20 | import wandb
21 |
22 | torch.backends.cudnn.benchmark = True
23 |
24 | with open("task.json", "r") as f:
25 | task_dict = json.load(f)
26 |
27 |
28 | def get_domain(task):
29 | if task.startswith('point_mass_maze'):
30 | return 'point_mass_maze'
31 | return task.split('_', 1)[0]
32 |
33 |
34 | def get_data_seed(seed, num_data_seeds):
35 | return (seed - 1) % num_data_seeds + 1
36 |
37 |
38 | def eval(global_step, agent, env, logger, num_eval_episodes, video_recorder):
39 | step, episode, total_reward = 0, 0, 0
40 | eval_until_episode = utils.Until(num_eval_episodes)
41 | while eval_until_episode(episode):
42 | time_step = env.reset()
43 | video_recorder.init(env, enabled=(episode == 0))
44 | while not time_step.last():
45 | with torch.no_grad(), utils.eval_mode(agent):
46 | action = agent.act(time_step.observation, step=global_step, eval_mode=True)
47 | time_step = env.step(action)
48 | video_recorder.record(env)
49 | total_reward += time_step.reward
50 | step += 1
51 |
52 | episode += 1
53 | video_recorder.save(f'{global_step}.mp4')
54 |
55 | with logger.log_and_dump_ctx(global_step, ty='eval') as log:
56 | log('episode_reward', total_reward / episode)
57 | log('episode_length', step / episode)
58 | log('step', global_step)
59 |
60 |
61 | @hydra.main(config_path='.', config_name='config')
62 | def main(cfg):
63 | work_dir = Path.cwd()
64 | print(f'workspace: {work_dir}')
65 |
66 | # random seeds
67 | cfg.seed = random.randint(0, 100000)
68 |
69 | utils.set_seed_everywhere(cfg.seed)
70 | device = torch.device(cfg.device)
71 |
72 | # create logger
73 | logger = Logger(work_dir, use_tb=cfg.use_tb)
74 |
75 | # create envs
76 | env = dmc.make(cfg.task, seed=cfg.seed)
77 |
78 | # create agent
79 | agent = hydra.utils.instantiate(cfg.agent,
80 | obs_shape=env.observation_spec().shape, action_shape=env.action_spec().shape,
81 | num_expl_steps=0)
82 |
83 | replay_dir_list = []
84 |
85 | for task_id in range(len(cfg.share_task)):
86 | task = cfg.share_task[task_id] # dataset task
87 | data_type = cfg.data_type[task_id] # dataset type [random, medium, medium-replay, expert, replay]
88 | datasets_dir = work_dir / cfg.replay_buffer_dir
89 | replay_dir = datasets_dir.resolve() / Path(task+"-td3-"+str(data_type)) / 'data'
90 | print(f'replay dir: {replay_dir}')
91 | replay_dir_list.append(replay_dir)
92 |
93 | # construct the replay buffer. env is the main task, we use it to relabel the reward of other tasks
94 | replay_loader = make_replay_loader(env, replay_dir_list, cfg.replay_buffer_size,
95 | cfg.batch_size, cfg.replay_buffer_num_workers, cfg.discount,
96 | main_task=cfg.task, task_list=cfg.share_task)
97 | replay_iter = iter(replay_loader) # OfflineReplayBuffer sample function
98 | print("load data done.")
99 |
100 | # for i in replay_iter:
101 | # print(i)
102 | # break
103 |
104 | # create video recorders
105 | video_recorder = VideoRecorder(work_dir if cfg.save_video else None)
106 |
107 | timer = utils.Timer()
108 | global_step = 0
109 |
110 | train_until_step = utils.Until(cfg.num_grad_steps)
111 | eval_every_step = utils.Every(cfg.eval_every_steps)
112 | log_every_step = utils.Every(cfg.log_every_steps)
113 |
114 | if cfg.wandb:
115 | path_str = f'{cfg.agent.name}_{cfg.share_task[0]}_{cfg.share_task[1]}_{cfg.data_type[0]}_{cfg.data_type[1]}'
116 | wandb_dir = f"./wandb/{path_str}_{cfg.seed}"
117 | if not os.path.exists(wandb_dir):
118 | os.makedirs(wandb_dir)
119 | wandb.init(project="UTDS", entity='', config=cfg, name=f'{path_str}_1', dir=wandb_dir)
120 | wandb.config.update(vars(cfg))
121 |
122 | while train_until_step(global_step):
123 | # try to evaluate
124 | if eval_every_step(global_step):
125 | logger.log('eval_total_time', timer.total_time(), global_step)
126 | eval(global_step, agent, env, logger, cfg.num_eval_episodes, video_recorder)
127 |
128 | # train the agent
129 | metrics = agent.update(replay_iter, global_step, cfg.num_grad_steps)
130 |
131 | # log
132 | logger.log_metrics(metrics, global_step, ty='train')
133 | if log_every_step(global_step):
134 | elapsed_time, total_time = timer.reset()
135 | with logger.log_and_dump_ctx(global_step, ty='train') as log:
136 | log('fps', cfg.log_every_steps / elapsed_time)
137 | log('total_time', total_time)
138 | log('step', global_step)
139 |
140 | global_step += 1
141 |
142 |
143 | if __name__ == '__main__':
144 | main()
145 |
--------------------------------------------------------------------------------
/replay_buffer.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import io
3 | import random
4 | import traceback
5 | import copy
6 | from collections import defaultdict
7 |
8 | import numpy as np
9 | import torch
10 | import torch.nn as nn
11 | from torch.utils.data import IterableDataset
12 |
13 |
14 | def episode_len(episode):
15 |     # subtract 1 because of the dummy first transition
16 | return next(iter(episode.values())).shape[0] - 1
17 |
18 |
19 | def save_episode(episode, fn):
20 | with io.BytesIO() as bs:
21 | np.savez_compressed(bs, **episode)
22 | bs.seek(0)
23 | with fn.open('wb') as f:
24 | f.write(bs.read())
25 |
26 |
27 | def load_episode(fn):
28 | with fn.open('rb') as f:
29 | episode = np.load(f)
30 | episode = {k: episode[k] for k in episode.keys()}
31 | return episode
32 |
33 |
34 | def relable_episode(env, episode): # relabel the reward function
35 | rewards = []
36 | reward_spec = env.reward_spec()
37 | states = episode['physics']
38 | for i in range(states.shape[0]):
39 | with env.physics.reset_context():
40 | env.physics.set_state(states[i])
41 | # Input: the current physics of env, then use env.task.get_reward to calculate the reward
42 | reward = env.task.get_reward(env.physics)
43 |         reward = np.full(reward_spec.shape, reward, reward_spec.dtype)  # cast to the reward spec's shape and dtype
44 | rewards.append(reward)
45 | original_reward = np.mean(episode['reward'])
46 | episode['reward'] = np.array(rewards, dtype=reward_spec.dtype)
47 | # print("Reward difference after relabeling:", original_reward - episode['reward'].mean())
48 | return episode
49 |
50 |
51 | class OfflineReplayBuffer(IterableDataset):
52 |     # dataset used for offline training
53 | def __init__(self, env, replay_dir_list, max_size, num_workers, discount, main_task, task_list):
54 | self._env = env
55 | self._replay_dir_list = replay_dir_list
56 | self._size = 0
57 | self._max_size = max_size
58 | self._num_workers = max(1, num_workers)
59 | self._episode_fns = []
60 | self._episodes = dict() # save as episode
61 | self._discount = discount
62 | self._loaded = False
63 | self._main_task = main_task
64 | self._task_list = task_list
65 |
66 | def _load(self, relable=True):
67 | print("load data", self._replay_dir_list, self._task_list)
68 | for i in range(len(self._replay_dir_list)): # loop
69 | _replay_dir = self._replay_dir_list[i]
70 | _task_share = self._task_list[i]
71 | assert _task_share in str(_replay_dir)
72 |             worker_info = torch.utils.data.get_worker_info()
73 |             # get_worker_info() returns None when loading in the main process
74 |             worker_id = worker_info.id if worker_info is not None else 0
76 | print(f'Loading data from {_replay_dir} and Relabel...', "worker_id:", worker_id) # each worker will run this function
77 | print(f"Need relabeling: {relable and _task_share != self._main_task}")
78 | eps_fns = sorted(_replay_dir.glob('*.npz'))
79 | for eps_fn in eps_fns:
80 | if self._size > self._max_size:
81 | break
82 | eps_idx, eps_len = [int(x) for x in eps_fn.stem.split('_')[1:]]
83 | if eps_idx % self._num_workers != worker_id: # read the npz file of the worker
84 | continue
85 | # load a npz file to represent an episodic sample. The keys include 'observation', 'action', 'reward', 'discount', 'physics'
86 | episode = load_episode(eps_fn)
87 |
88 | if relable and _task_share != self._main_task:
89 | # print(f"relabel {_replay_dir} for {self._main_task} task")
90 | episode = self._relable_reward(episode) # relabel
91 | data_flag = _task_share+str(eps_fn)+str(_task_share == self._main_task)
92 | self._episode_fns.append(data_flag)
93 | self._episodes[data_flag] = episode
94 | self._size += episode_len(episode)
95 | # if worker_id == 0:
96 | # print("data_flag:", data_flag)
97 | print("load done. Num of episodes", len(self._episode_fns)*self._num_workers)
98 |
99 | def _sample_episode(self):
100 | if not self._loaded:
101 | self._load()
102 | self._loaded = True
103 | eps_fn = random.choice(self._episode_fns)
104 |         return self._episodes[eps_fn], eps_fn.endswith("True")  # True if the episode comes from the main-task buffer
105 |
106 | def _relable_reward(self, episode):
107 | return relable_episode(self._env, episode)
108 |
109 | def _sample(self):
110 | episode, eps_flag = self._sample_episode() # return the signal
111 | # add +1 for the first dummy transition
112 | idx = np.random.randint(0, episode_len(episode)) + 1
113 | obs = episode['observation'][idx - 1]
114 | action = episode['action'][idx]
115 | next_obs = episode['observation'][idx]
116 | reward = episode['reward'][idx]
117 | discount = episode['discount'][idx] * self._discount
118 |
119 | return (obs, action, reward, discount, next_obs, bool(eps_flag))
120 |
121 | def __iter__(self):
122 | while True:
123 | yield self._sample()
124 |
125 |
126 | def _worker_init_fn(worker_id):
127 | seed = np.random.get_state()[1][0] + worker_id
128 | np.random.seed(seed)
129 | random.seed(seed)
130 |
131 |
132 | def make_replay_loader(env, replay_dir_list, max_size, batch_size, num_workers, discount, main_task, task_list):
133 | max_size_per_worker = max_size // max(1, num_workers)
134 |
135 | iterable = OfflineReplayBuffer(env, replay_dir_list, max_size_per_worker,
136 |                                    num_workers, discount, main_task, task_list)  # main_task is the primary task
137 |
138 | loader = torch.utils.data.DataLoader(iterable,
139 | batch_size=batch_size,
140 | num_workers=num_workers,
141 | pin_memory=True,
142 | worker_init_fn=_worker_init_fn)
143 | return loader
144 |
145 |
--------------------------------------------------------------------------------
/train_offline_cds.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | warnings.filterwarnings('ignore', category=DeprecationWarning)
4 |
5 | import os
6 | import random
7 | os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
8 | os.environ['MUJOCO_GL'] = 'egl'
9 | import json
10 | from pathlib import Path
11 | import hydra
12 | import numpy as np
13 | import torch
14 | from dm_env import specs
15 | import dmc
16 | import utils
17 | from logger import Logger
18 | from replay_buffer import make_replay_loader
19 | from video import VideoRecorder
20 | import wandb
21 |
22 | torch.backends.cudnn.benchmark = True
23 |
24 | with open("task.json", "r") as f:
25 | task_dict = json.load(f)
26 |
27 |
28 | def get_domain(task):
29 | if task.startswith('point_mass_maze'):
30 | return 'point_mass_maze'
31 | return task.split('_', 1)[0]
32 |
33 |
34 | def get_data_seed(seed, num_data_seeds):
35 | return (seed - 1) % num_data_seeds + 1
36 |
37 |
38 | def eval(global_step, agent, env, logger, num_eval_episodes, video_recorder):
39 | step, episode, total_reward = 0, 0, 0
40 | eval_until_episode = utils.Until(num_eval_episodes)
41 | while eval_until_episode(episode):
42 | time_step = env.reset()
43 | video_recorder.init(env, enabled=(episode == 0))
44 | while not time_step.last():
45 | with torch.no_grad(), utils.eval_mode(agent):
46 | action = agent.act(time_step.observation, step=global_step, eval_mode=True)
47 | time_step = env.step(action)
48 | video_recorder.record(env)
49 | total_reward += time_step.reward
50 | step += 1
51 |
52 | episode += 1
53 | video_recorder.save(f'{global_step}.mp4')
54 |
55 | with logger.log_and_dump_ctx(global_step, ty='eval') as log:
56 | log('episode_reward', total_reward / episode)
57 | log('episode_length', step / episode)
58 | log('step', global_step)
59 |
60 |
61 | @hydra.main(config_path='.', config_name='config_cds')
62 | def main(cfg):
63 | work_dir = Path.cwd()
64 | print(f'workspace: {work_dir}')
65 |
66 | # random seeds
67 | cfg.seed = random.randint(0, 100000)
68 |
69 | utils.set_seed_everywhere(cfg.seed)
70 | device = torch.device(cfg.device)
71 |
72 | # create logger
73 | logger = Logger(work_dir, use_tb=cfg.use_tb)
74 |
75 | # create envs
76 | env = dmc.make(cfg.task, seed=cfg.seed)
77 |
78 | # create agent
79 | agent = hydra.utils.instantiate(cfg.agent, obs_shape=env.observation_spec().shape,
80 | action_shape=env.action_spec().shape, num_expl_steps=0)
81 |
82 | replay_dir_list_main = []
83 | replay_dir_list_share = []
84 |
85 | share_tasks = []
86 | for task_id in range(len(cfg.share_task)):
87 | task = cfg.share_task[task_id] # dataset task
88 | data_type = cfg.data_type[task_id] # dataset type [random, medium, medium-replay, expert, replay]
89 |         datasets_dir = work_dir / cfg.replay_buffer_dir  # directory that stores the datasets
90 | replay_dir = datasets_dir.resolve() / Path(task+"-td3-"+str(data_type)) / 'data'
91 | print(f'replay dir: {replay_dir}')
92 | if task == cfg.task:
93 | replay_dir_list_main.append(replay_dir)
94 | else:
95 | replay_dir_list_share.append(replay_dir)
96 | share_tasks.append(task)
97 |
98 | print("CDS. load main dataset..", cfg.task)
99 | replay_loader_main = make_replay_loader(env, replay_dir_list_main, cfg.replay_buffer_size,
100 | cfg.batch_size // 2, cfg.replay_buffer_num_workers, cfg.discount, # batch size (half)
101 | main_task=cfg.task, task_list=[cfg.task])
102 | replay_iter_main = iter(replay_loader_main) # run OfflineReplayBuffer.sample function
103 |
104 | print("CDS. load share dataset..", share_tasks)
105 | replay_loader_share = make_replay_loader(env, replay_dir_list_share, cfg.replay_buffer_size,
106 |                                    cfg.batch_size // 2 * 10, cfg.replay_buffer_num_workers, cfg.discount,  # 10x the (half) batch size; the top samples are selected afterwards (CDS)
107 | main_task=cfg.task, task_list=share_tasks)
108 | replay_iter_share = iter(replay_loader_share) # run OfflineReplayBuffer.sample function
109 | print("load data done.")
110 |
111 | # for i in replay_iter_share:
112 | # print(i)
113 | # break
114 |
115 | # create video recorders
116 | video_recorder = VideoRecorder(work_dir if cfg.save_video else None)
117 |
118 | timer = utils.Timer()
119 | global_step = 0
120 |
121 | train_until_step = utils.Until(cfg.num_grad_steps)
122 | eval_every_step = utils.Every(cfg.eval_every_steps)
123 | log_every_step = utils.Every(cfg.log_every_steps)
124 |
125 | if cfg.wandb:
126 | path_str = f'{cfg.agent.name}_{cfg.share_task[0]}_{cfg.share_task[1]}_{cfg.data_type[0]}_{cfg.data_type[1]}'
127 | wandb_dir = f"./wandb/{path_str}_{cfg.seed}"
128 | if not os.path.exists(wandb_dir):
129 | os.makedirs(wandb_dir)
130 | wandb.init(project="UTDS", entity='', config=cfg, name=f'{path_str}_1', dir=wandb_dir)
131 | wandb.config.update(vars(cfg))
132 |
133 | while train_until_step(global_step):
134 | # try to evaluate
135 | if eval_every_step(global_step):
136 | logger.log('eval_total_time', timer.total_time(), global_step)
137 | eval(global_step, agent, env, logger, cfg.num_eval_episodes, video_recorder)
138 |
139 | # train the agent
140 | metrics = agent.update(replay_iter_main, replay_iter_share, global_step, cfg.num_grad_steps)
141 |
142 | # log
143 | logger.log_metrics(metrics, global_step, ty='train')
144 | if log_every_step(global_step):
145 | elapsed_time, total_time = timer.reset()
146 | with logger.log_and_dump_ctx(global_step, ty='train') as log:
147 | log('fps', cfg.log_every_steps / elapsed_time)
148 | log('total_time', total_time)
149 | log('step', global_step)
150 |
151 | global_step += 1
152 |
153 |
154 | if __name__ == '__main__':
155 | main()
156 |
--------------------------------------------------------------------------------
/agent/td3_bc.py:
--------------------------------------------------------------------------------
1 | import hydra
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from collections import OrderedDict
7 | import copy
8 | import utils
9 | from dm_control.utils import rewards
10 |
11 |
12 | class Actor(nn.Module):
13 | def __init__(self, state_dim, action_dim, max_action=1):
14 | super(Actor, self).__init__()
15 |
16 | self.l1 = nn.Linear(state_dim, 256)
17 | self.l2 = nn.Linear(256, 256)
18 | self.l3 = nn.Linear(256, action_dim)
19 |
20 | self.max_action = max_action
21 |
22 | def forward(self, state):
23 | a = F.relu(self.l1(state))
24 | a = F.relu(self.l2(a))
25 | return self.max_action * torch.tanh(self.l3(a))
26 |
27 |
28 | class Critic(nn.Module):
29 | def __init__(self, state_dim, action_dim):
30 | super(Critic, self).__init__()
31 |
32 | # Q1 architecture
33 | self.l1 = nn.Linear(state_dim + action_dim, 256)
34 | self.l2 = nn.Linear(256, 256)
35 | self.l3 = nn.Linear(256, 1)
36 |
37 | # Q2 architecture
38 | self.l4 = nn.Linear(state_dim + action_dim, 256)
39 | self.l5 = nn.Linear(256, 256)
40 | self.l6 = nn.Linear(256, 1)
41 |
42 | def forward(self, state, action):
43 | sa = torch.cat([state, action], 1)
44 |
45 | q1 = F.relu(self.l1(sa))
46 | q1 = F.relu(self.l2(q1))
47 | q1 = self.l3(q1)
48 |
49 | q2 = F.relu(self.l4(sa))
50 | q2 = F.relu(self.l5(q2))
51 | q2 = self.l6(q2)
52 | return q1, q2
53 |
54 | def Q1(self, state, action):
55 | sa = torch.cat([state, action], 1)
56 |
57 | q1 = F.relu(self.l1(sa))
58 | q1 = F.relu(self.l2(q1))
59 | q1 = self.l3(q1)
60 | return q1
61 |
62 |
63 | class TD3BCAgent:
64 | def __init__(self,
65 | name,
66 | obs_shape,
67 | action_shape,
68 | device,
69 | lr,
70 | hidden_dim,
71 | critic_target_tau,
72 | actor_target_tau,
73 | policy_freq,
74 | policy_noise,
75 | noise_clip,
76 | use_tb,
77 | alpha,
78 | batch_size,
79 | num_expl_steps):
80 | self.policy_noise = policy_noise
81 | self.policy_freq = policy_freq
82 | self.noise_clip = noise_clip
83 | self.num_expl_steps = num_expl_steps
84 | self.action_dim = action_shape[0]
85 | self.hidden_dim = hidden_dim
86 | self.lr = lr
87 | self.device = device
88 | self.critic_target_tau = critic_target_tau
89 | self.actor_target_tau = actor_target_tau
90 | self.use_tb = use_tb
91 | # self.stddev_schedule = stddev_schedule
92 | # self.stddev_clip = stddev_clip
93 | self.alpha = alpha
94 | self.max_action = 1.0
95 |
96 | # models
97 | self.actor = Actor(obs_shape[0], action_shape[0]).to(device)
98 | self.actor_target = copy.deepcopy(self.actor)
99 |
100 | self.critic = Critic(obs_shape[0], action_shape[0]).to(device)
101 | self.critic_target = copy.deepcopy(self.critic)
102 | self.critic_target.load_state_dict(self.critic.state_dict())
103 |
104 | # optimizers
105 | self.actor_opt = torch.optim.Adam(self.actor.parameters(), lr=lr)
106 | self.critic_opt = torch.optim.Adam(self.critic.parameters(), lr=lr)
107 |
108 | self.train()
109 | self.critic_target.train()
110 |
111 | def train(self, training=True):
112 | self.training = training
113 | self.actor.train(training)
114 | self.critic.train(training)
115 |
116 | def act(self, obs, step, eval_mode):
117 | obs = torch.as_tensor(obs, device=self.device).unsqueeze(0)
118 | action = self.actor(obs)
119 | if step < self.num_expl_steps:
120 | action.uniform_(-1.0, 1.0)
121 | return action.cpu().numpy()[0]
122 |
123 | def update_critic(self, obs, action, reward, discount, next_obs):
124 | metrics = dict()
125 |
126 | with torch.no_grad():
127 | # Select action according to policy and add clipped noise
128 | noise = (
129 | torch.randn_like(action) * self.policy_noise
130 | ).clamp(-self.noise_clip, self.noise_clip)
131 |
132 | next_action = (
133 | self.actor_target(next_obs) + noise
134 | ).clamp(-self.max_action, self.max_action)
135 |
136 | target_Q1, target_Q2 = self.critic_target(next_obs, next_action)
137 | target_V = torch.min(target_Q1, target_Q2)
138 | target_Q = reward + (discount * target_V)
139 |
140 | Q1, Q2 = self.critic(obs, action)
141 | critic_loss = F.mse_loss(Q1, target_Q) + F.mse_loss(Q2, target_Q)
142 |
143 | if self.use_tb:
144 | metrics['critic_target_q'] = target_Q.mean().item()
145 | metrics['critic_q1'] = Q1.mean().item()
146 | metrics['critic_q2'] = Q2.mean().item()
147 | metrics['critic_loss'] = critic_loss.item()
148 |
149 | # optimize critic
150 | self.critic_opt.zero_grad(set_to_none=True)
151 | critic_loss.backward()
152 | self.critic_opt.step()
153 | return metrics
154 |
155 | def update_actor(self, obs, action):
156 | metrics = dict()
157 |
158 | # Compute actor loss
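  | # TD3+BC actor objective: maximize Q under the policy while staying close to the
  | # dataset action via the MSE behavioral-cloning term; lmbda = alpha / mean|Q|
  | # normalizes the Q term so that alpha controls the RL-vs-BC trade-off.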
159 | pi = self.actor(obs)
160 | Q = self.critic.Q1(obs, pi)
161 | lmbda = self.alpha / Q.abs().mean().detach()
162 |
163 | actor_loss = -lmbda * Q.mean() + F.mse_loss(pi, action)
164 |
165 | # optimize actor
166 | self.actor_opt.zero_grad(set_to_none=True)
167 | actor_loss.backward()
168 | self.actor_opt.step()
169 |
170 | if self.use_tb:
171 | metrics['actor_loss'] = actor_loss.item()
172 |
173 | return metrics
174 |
175 | def update(self, replay_iter, step, total_step=None):
176 | metrics = dict()
177 |
178 | batch = next(replay_iter)
179 | # obs.shape=(1024,obs_dim), action.shape=(1024,action_dim), reward.shape=(1024,1), discount.shape=(1024,1)
180 | obs, action, reward, discount, next_obs, _ = utils.to_torch(
181 | batch, self.device)
182 |
183 | if self.use_tb:
184 | metrics['batch_reward'] = reward.mean().item()
185 |
186 | # update critic
187 | metrics.update(
188 | self.update_critic(obs, action, reward, discount, next_obs))
189 |
190 | # update actor
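  | # delayed policy update (TD3): the actor is refreshed only every policy_freq critic updates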
191 | if step % self.policy_freq == 0:
192 | metrics.update(self.update_actor(obs, action))
193 |
194 | # update critic target
195 | utils.soft_update_params(self.critic, self.critic_target, self.critic_target_tau)
196 | utils.soft_update_params(self.actor, self.actor_target, self.actor_target_tau)
197 |
198 | return metrics
199 |
--------------------------------------------------------------------------------
/agent/td3.py:
--------------------------------------------------------------------------------
1 | import hydra
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from collections import OrderedDict
7 |
8 | import utils
9 | from dm_control.utils import rewards
10 |
11 |
12 | class Actor(nn.Module):
13 | def __init__(self, obs_dim, action_dim, hidden_dim):
14 | super().__init__()
15 |
16 | self.policy = nn.Sequential(nn.Linear(obs_dim, hidden_dim),
17 | nn.LayerNorm(hidden_dim), nn.Tanh(),
18 | nn.Linear(hidden_dim, hidden_dim),
19 | nn.ReLU(inplace=True),
20 | nn.Linear(hidden_dim, action_dim))
21 |
22 | self.apply(utils.weight_init)
23 |
24 | def forward(self, obs, std):
25 | mu = self.policy(obs)
26 | mu = torch.tanh(mu)
27 | std = torch.ones_like(mu) * std
28 |
29 | dist = utils.TruncatedNormal(mu, std)
30 | return dist
31 |
32 |
33 | class Critic(nn.Module):
34 | def __init__(self, obs_dim, action_dim, hidden_dim):
35 | super().__init__()
36 |
37 | self.q1_net = nn.Sequential(
38 | nn.Linear(obs_dim + action_dim, hidden_dim),
39 | nn.LayerNorm(hidden_dim), nn.Tanh(),
40 | nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True),
41 | nn.Linear(hidden_dim, 1))
42 |
43 | self.q2_net = nn.Sequential(
44 | nn.Linear(obs_dim + action_dim, hidden_dim),
45 | nn.LayerNorm(hidden_dim), nn.Tanh(),
46 | nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True),
47 | nn.Linear(hidden_dim, 1))
48 |
49 | self.apply(utils.weight_init)
50 |
51 | def forward(self, obs, action):
52 | obs_action = torch.cat([obs, action], dim=-1)
53 | q1 = self.q1_net(obs_action)
54 | q2 = self.q2_net(obs_action)
55 |
56 | return q1, q2
57 |
58 |
59 | class TD3Agent:
60 | def __init__(self,
61 | name,
62 | obs_shape,
63 | action_shape,
64 | device,
65 | lr,
66 | hidden_dim,
67 | critic_target_tau,
68 | stddev_schedule,
69 | num_expl_steps, # number of random-exploration steps before training
70 | nstep,
71 | batch_size,
72 | stddev_clip,
73 | use_tb,
74 | # obs_type,
75 | has_next_action=False):
76 | # self.obs_type = obs_type
77 | self.num_expl_steps = num_expl_steps
78 | self.action_dim = action_shape[0]
79 | self.hidden_dim = hidden_dim
80 | self.lr = lr
81 | self.device = device
82 | self.critic_target_tau = critic_target_tau
83 | self.use_tb = use_tb
84 | self.stddev_schedule = stddev_schedule
85 | self.stddev_clip = stddev_clip
86 |
87 | # models
88 | self.actor = Actor(obs_shape[0], action_shape[0],
89 | hidden_dim).to(device)
90 |
91 | self.critic = Critic(obs_shape[0], action_shape[0],
92 | hidden_dim).to(device)
93 | self.critic_target = Critic(obs_shape[0], action_shape[0],
94 | hidden_dim).to(device)
95 | self.critic_target.load_state_dict(self.critic.state_dict())
96 |
97 | # optimizers
98 | self.actor_opt = torch.optim.Adam(self.actor.parameters(), lr=lr)
99 | self.critic_opt = torch.optim.Adam(self.critic.parameters(), lr=lr)
100 |
101 | self.train()
102 | self.critic_target.train()
103 |
104 | def train(self, training=True):
105 | self.training = training
106 | self.actor.train(training)
107 | self.critic.train(training)
108 |
109 | def get_meta_specs(self):
110 | return tuple()
111 |
112 | def init_meta(self):
113 | return OrderedDict()
114 |
115 | def update_meta(self, meta, global_step, time_step, finetune=False):
116 | return meta
117 |
118 | def act(self, obs, step, eval_mode, meta=None):
119 | obs = torch.as_tensor(obs, device=self.device).unsqueeze(0)
120 | stddev = utils.schedule(self.stddev_schedule, step)
121 | policy = self.actor(obs, stddev)
122 | if eval_mode:
123 | action = policy.mean
124 | else:
125 | action = policy.sample(clip=None)
126 | if step < self.num_expl_steps:
127 | action.uniform_(-1.0, 1.0)
128 | return action.cpu().numpy()[0]
129 |
130 | def update_critic(self, obs, action, reward, discount, next_obs, step):
131 | metrics = dict()
132 |
133 | with torch.no_grad():
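  | # clipped double-Q target: the next action is drawn from the current policy
  | # (scheduled stddev, noise clipped to stddev_clip) and the target is
  | # reward + discount * min(Q1', Q2').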
134 | stddev = utils.schedule(self.stddev_schedule, step)
135 | dist = self.actor(next_obs, stddev)
136 | next_action = dist.sample(clip=self.stddev_clip)
137 | target_Q1, target_Q2 = self.critic_target(next_obs, next_action)
138 | target_V = torch.min(target_Q1, target_Q2)
139 | target_Q = reward + (discount * target_V)
140 |
141 | Q1, Q2 = self.critic(obs, action)
142 | critic_loss = F.mse_loss(Q1, target_Q) + F.mse_loss(Q2, target_Q)
143 |
144 | if self.use_tb:
145 | metrics['critic_target_q'] = target_Q.mean().item()
146 | metrics['critic_q1'] = Q1.mean().item()
147 | metrics['critic_q2'] = Q2.mean().item()
148 | metrics['critic_loss'] = critic_loss.item()
149 |
150 | # optimize critic
151 | self.critic_opt.zero_grad(set_to_none=True)
152 | critic_loss.backward()
153 | self.critic_opt.step()
154 | return metrics
155 |
156 | def update_actor(self, obs, action, step):
157 | metrics = dict()
158 |
159 | stddev = utils.schedule(self.stddev_schedule, step)
160 | policy = self.actor(obs, stddev)
161 |
162 | Q1, Q2 = self.critic(obs, policy.sample(clip=self.stddev_clip))
163 | Q = torch.min(Q1, Q2)
164 |
165 | actor_loss = -Q.mean()
166 |
167 | # optimize actor
168 | self.actor_opt.zero_grad(set_to_none=True)
169 | actor_loss.backward()
170 | self.actor_opt.step()
171 |
172 | if self.use_tb:
173 | metrics['actor_loss'] = actor_loss.item()
174 | metrics['actor_ent'] = policy.entropy().sum(dim=-1).mean().item()
175 |
176 | return metrics
177 |
178 | def update(self, replay_iter, step):
179 | metrics = dict()
180 |
181 | batch = next(replay_iter)
182 | obs, action, reward, discount, next_obs = utils.to_torch(batch, self.device)
183 |
184 | if self.use_tb:
185 | metrics['batch_reward'] = reward.mean().item()
186 |
187 | # update critic
188 | metrics.update(
189 | self.update_critic(obs, action, reward, discount, next_obs, step))
190 |
191 | # update actor
192 | metrics.update(self.update_actor(obs, action, step))
193 |
194 | # update critic target
195 | utils.soft_update_params(self.critic, self.critic_target,
196 | self.critic_target_tau)
197 |
198 | return metrics
199 |
--------------------------------------------------------------------------------
/logger.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import datetime
3 | from collections import defaultdict
4 |
5 | import numpy as np
6 | import torch
7 | import torchvision
8 | from termcolor import colored
9 | from torch.utils.tensorboard import SummaryWriter
10 |
11 |
12 | TRAIN_FORMAT = [('step', 'S', 'int'), ('episode', 'E', 'int'), ('batch_reward', 'BR', 'float'),
13 | ('episode_length', 'L', 'int'), ('episode_reward', 'R', 'float'),
14 | ('fps', 'FPS', 'float'), ('total_time', 'T', 'time')]
15 |
16 |
17 | EVAL_FORMAT = [('step', 'S', 'int'), ('episode_length', 'L', 'int'),
18 | ('episode_reward', 'R', 'float'),
19 | ('dataset_reward', 'DR', 'float'), ('total_time', 'T', 'time')]
20 |
21 |
22 | class AverageMeter(object):
23 | def __init__(self):
24 | self._sum = 0
25 | self._count = 0
26 |
27 | def update(self, value, n=1):
28 | self._sum += value
29 | self._count += n
30 |
31 | def value(self):
32 | return self._sum / max(1, self._count)
33 |
34 |
35 | class MetersGroup(object):
36 | def __init__(self, csv_file_name, formating):
37 | self._csv_file_name = csv_file_name
38 | self._formating = formating
39 | self._meters = defaultdict(AverageMeter)
40 | self._csv_file = None
41 | self._csv_writer = None
42 |
43 | def log(self, key, value, n=1):
44 | self._meters[key].update(value, n)
45 |
46 | def _prime_meters(self):
47 | data = dict()
48 | for key, meter in self._meters.items():
49 | if key.startswith('train'):
50 | key = key[len('train') + 1:]
51 | else:
52 | key = key[len('eval') + 1:]
53 | key = key.replace('/', '_')
54 | data[key] = meter.value()
55 | return data
56 |
57 | def _remove_old_entries(self, data):
58 | rows = []
59 | with self._csv_file_name.open('r') as f:
60 | reader = csv.DictReader(f)
61 | for row in reader:
62 | if float(row['episode']) >= data['episode']:
63 | break
64 | rows.append(row)
65 | with self._csv_file_name.open('w') as f:
66 | writer = csv.DictWriter(f,
67 | fieldnames=sorted(data.keys()),
68 | restval=0.0)
69 | writer.writeheader()
70 | for row in rows:
71 | writer.writerow(row)
72 |
73 | def _dump_to_csv(self, data):
74 | if self._csv_writer is None:
75 | should_write_header = True
76 | if self._csv_file_name.exists():
77 | self._remove_old_entries(data)
78 | should_write_header = False
79 |
80 | self._csv_file = self._csv_file_name.open('a')
81 | self._csv_writer = csv.DictWriter(self._csv_file,
82 | fieldnames=sorted(data.keys()),
83 | restval=0.0)
84 | if should_write_header:
85 | self._csv_writer.writeheader()
86 |
87 | self._csv_writer.writerow(data)
88 | self._csv_file.flush()
89 |
90 | def _format(self, key, value, ty):
91 | if ty == 'int':
92 | value = int(value)
93 | return f'{key}: {value}'
94 | elif ty == 'float':
95 | return f'{key}: {value:.04f}'
96 | elif ty == 'time':
97 | value = str(datetime.timedelta(seconds=int(value)))
98 | return f'{key}: {value}'
99 | else:
100 | raise ValueError(f'invalid format type: {ty}')
101 |
102 | def _dump_to_console(self, data, prefix):
103 | prefix = colored(prefix, 'yellow' if prefix == 'train' else 'green')
104 | pieces = [f'| {prefix: <14}']
105 | for key, disp_key, ty in self._formating:
106 | value = data.get(key, 0)
107 | pieces.append(self._format(disp_key, value, ty))
108 | print(' | '.join(pieces))
109 |
110 | def dump(self, step, prefix):
111 | if len(self._meters) == 0:
112 | return
113 | data = self._prime_meters()
114 | data['frame'] = step
115 | self._dump_to_csv(data)
116 | self._dump_to_console(data, prefix)
117 | self._meters.clear()
118 |
119 |
120 | class Logger(object):
121 | def __init__(self, log_dir, use_tb, offline=False):
122 | self._log_dir = log_dir
123 | self._train_mg = MetersGroup(log_dir / 'train.csv', formating=TRAIN_FORMAT)
124 | self._eval_mg = MetersGroup(log_dir / 'eval.csv', formating=EVAL_FORMAT)
125 | if use_tb:
126 | self._sw = SummaryWriter(str(log_dir / 'tb'))
127 | else:
128 | self._sw = None
129 |
130 | def _try_sw_log(self, key, value, step):
131 | if self._sw is not None:
132 | self._sw.add_scalar(key, value, step)
133 |
134 | def log(self, key, value, step):
135 | assert key.startswith('train') or key.startswith('eval')
136 | if type(value) == torch.Tensor:
137 | value = value.item()
138 | self._try_sw_log(key, value, step)
139 | mg = self._train_mg if key.startswith('train') else self._eval_mg
140 | mg.log(key, value)
141 |
142 | def log_metrics(self, metrics, step, ty):
143 | for key, value in metrics.items():
144 | self.log(f'{ty}/{key}', value, step)
145 |
146 | def dump(self, step, ty=None):
147 | if ty is None or ty == 'eval':
148 | self._eval_mg.dump(step, 'eval')
149 | if ty is None or ty == 'train':
150 | self._train_mg.dump(step, 'train')
151 |
152 | def log_and_dump_ctx(self, step, ty):
153 | return LogAndDumpCtx(self, step, ty)
154 |
155 |
156 | class LogAndDumpCtx:
157 | def __init__(self, logger, step, ty):
158 | self._logger = logger
159 | self._step = step
160 | self._ty = ty
161 |
162 | def __enter__(self):
163 | return self
164 |
165 | def __call__(self, key, value):
166 | self._logger.log(f'{self._ty}/{key}', value, self._step)
167 |
168 | def __exit__(self, *args):
169 | self._logger.dump(self._step, self._ty)
170 |
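  | # Usage sketch (illustrative; mirrors how the training scripts drive the logger):
  | #   logger = Logger(work_dir, use_tb=True)
  | #   logger.log_metrics(metrics, global_step, ty='train')
  | #   with logger.log_and_dump_ctx(global_step, ty='eval') as log:
  | #       log('episode_reward', total_reward / episode)
  | # Values are averaged per key and flushed to CSV/console (and TensorBoard) on dump.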
--------------------------------------------------------------------------------
/custom_dmc_tasks/cheetah.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """Cheetah Domain."""
16 |
17 | import collections
18 | import os
19 |
20 | from dm_control import mujoco
21 | from dm_control.rl import control
22 | from dm_control.suite import base
23 | from dm_control.suite import common
24 | from dm_control.utils import containers
25 | from dm_control.utils import rewards
26 | from dm_control.utils import io as resources
27 |
28 | # How long the simulation will run, in seconds.
29 | _DEFAULT_TIME_LIMIT = 10
30 |
31 | # Running speed above which reward is 1.
32 | _RUN_SPEED = 10
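  | # Torso angular momentum above which the flip reward is 1.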
33 | _SPIN_SPEED = 5
34 |
35 | SUITE = containers.TaggedTasks()
36 |
37 |
38 | def make(task,
39 | task_kwargs=None,
40 | environment_kwargs=None,
41 | visualize_reward=False):
42 | task_kwargs = task_kwargs or {}
43 | if environment_kwargs is not None:
44 | task_kwargs = task_kwargs.copy()
45 | task_kwargs['environment_kwargs'] = environment_kwargs
46 | env = SUITE[task](**task_kwargs)
47 | env.task.visualize_reward = visualize_reward
48 | return env
49 |
50 |
51 | def get_model_and_assets():
52 | """Returns a tuple containing the model XML string and a dict of assets."""
53 | root_dir = os.path.dirname(os.path.dirname(__file__))
54 | xml = resources.GetResource(
55 | os.path.join(root_dir, 'custom_dmc_tasks', 'cheetah.xml'))
56 | return xml, common.ASSETS
57 |
58 |
59 | @SUITE.add('benchmarking')
60 | def run(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
61 | """Returns the run task."""
62 | physics = Physics.from_xml_string(*get_model_and_assets())
63 | task = Cheetah(forward=True, flip=False, random=random)
64 | environment_kwargs = environment_kwargs or {}
65 | return control.Environment(physics, task, time_limit=time_limit,
66 | **environment_kwargs)
67 |
68 |
69 | @SUITE.add('benchmarking')
70 | def run_backward(time_limit=_DEFAULT_TIME_LIMIT,
71 | random=None,
72 | environment_kwargs=None):
73 | """Returns the run task."""
74 | physics = Physics.from_xml_string(*get_model_and_assets())
75 | task = Cheetah(forward=False, flip=False, random=random)
76 | environment_kwargs = environment_kwargs or {}
77 | return control.Environment(physics,
78 | task,
79 | time_limit=time_limit,
80 | **environment_kwargs)
81 |
82 |
83 | @SUITE.add('benchmarking')
84 | def flip(time_limit=_DEFAULT_TIME_LIMIT,
85 | random=None,
86 | environment_kwargs=None):
87 | """Returns the run task."""
88 | physics = Physics.from_xml_string(*get_model_and_assets())
89 | task = Cheetah(forward=True, flip=True, random=random)
90 | environment_kwargs = environment_kwargs or {}
91 | return control.Environment(physics,
92 | task,
93 | time_limit=time_limit,
94 | **environment_kwargs)
95 |
96 |
97 | @SUITE.add('benchmarking')
98 | def flip_backward(time_limit=_DEFAULT_TIME_LIMIT,
99 | random=None,
100 | environment_kwargs=None):
101 | """Returns the run task."""
102 | physics = Physics.from_xml_string(*get_model_and_assets())
103 | task = Cheetah(forward=False, flip=True, random=random)
104 | environment_kwargs = environment_kwargs or {}
105 | return control.Environment(physics,
106 | task,
107 | time_limit=time_limit,
108 | **environment_kwargs)
109 |
110 |
111 | class Physics(mujoco.Physics):
112 | """Physics simulation with additional features for the Cheetah domain."""
113 | def speed(self):
114 | """Returns the horizontal speed of the Cheetah."""
115 | return self.named.data.sensordata['torso_subtreelinvel'][0]
116 |
117 | def angmomentum(self): # angular momentum, used to encourage flipping
118 | """Returns the angular momentum of the Cheetah's torso about the Y axis."""
119 | return self.named.data.subtree_angmom['torso'][1]
120 |
121 |
122 | class Cheetah(base.Task):
123 | """A `Task` to train a running Cheetah."""
124 | def __init__(self, forward=True, flip=False, random=None):
125 | self._forward = 1 if forward else -1
126 | self._flip = flip
127 | super(Cheetah, self).__init__(random=random)
128 |
129 | def initialize_episode(self, physics):
130 | """Sets the state of the environment at the start of each episode."""
131 | # The indexing below assumes that all joints have a single DOF.
132 | assert physics.model.nq == physics.model.njnt
133 | is_limited = physics.model.jnt_limited == 1
134 | lower, upper = physics.model.jnt_range[is_limited].T
135 | physics.data.qpos[is_limited] = self.random.uniform(lower, upper)
136 |
137 | # Stabilize the model before the actual simulation.
138 | for _ in range(200):
139 | physics.step()
140 |
141 | physics.data.time = 0
142 | self._timeout_progress = 0
143 | super().initialize_episode(physics)
144 |
145 | def get_observation(self, physics):
146 | """Returns an observation of the state, ignoring horizontal position."""
147 | obs = collections.OrderedDict()
148 | # Ignores horizontal position to maintain translational invariance.
149 | obs['position'] = physics.data.qpos[1:].copy()
150 | obs['velocity'] = physics.velocity()
151 | return obs
152 |
153 | def get_reward(self, physics):
154 | # TODO: a task selection was added here; the original reward was based on physics.speed()
155 | """Returns a reward to the agent."""
156 | if self._flip:
157 | reward = rewards.tolerance(self._forward * physics.angmomentum(),
158 | bounds=(_SPIN_SPEED, float('inf')),
159 | margin=_SPIN_SPEED,
160 | value_at_margin=0,
161 | sigmoid='linear')
162 |
163 | else:
164 | reward = rewards.tolerance(self._forward * physics.speed(),
165 | bounds=(_RUN_SPEED, float('inf')),
166 | margin=_RUN_SPEED,
167 | value_at_margin=0,
168 | sigmoid='linear')
169 | return reward
170 |
--------------------------------------------------------------------------------
/custom_dmc_tasks/point_mass_maze.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """Point-mass domain."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 |
23 | from dm_control import mujoco
24 | from dm_control.rl import control
25 | from dm_control.suite import base
26 | from dm_control.suite import common
27 | from dm_control.suite.utils import randomizers
28 | from dm_control.utils import containers
29 | from dm_control.utils import rewards
30 | from dm_control.utils import io as resources
31 | from dm_env import specs
32 | import numpy as np
33 | import os
34 |
35 | _DEFAULT_TIME_LIMIT = 20
36 | SUITE = containers.TaggedTasks()
37 |
38 |
39 | TASKS = [('reach_top_left', np.array([-0.15, 0.15, 0.01])),
40 | ('reach_top_right', np.array([0.15, 0.15, 0.01])),
41 | ('reach_bottom_left', np.array([-0.15, -0.15, 0.01])),
42 | ('reach_bottom_right', np.array([0.15, -0.15, 0.01]))]
43 |
44 |
45 | def make(task,
46 | task_kwargs=None,
47 | environment_kwargs=None,
48 | visualize_reward=False):
49 | task_kwargs = task_kwargs or {}
50 | if environment_kwargs is not None:
51 | task_kwargs = task_kwargs.copy()
52 | task_kwargs['environment_kwargs'] = environment_kwargs
53 | env = SUITE[task](**task_kwargs)
54 | env.task.visualize_reward = visualize_reward
55 | return env
56 |
57 |
58 | def get_model_and_assets(task):
59 | """Returns a tuple containing the model XML string and a dict of assets."""
60 | root_dir = os.path.dirname(os.path.dirname(__file__))
61 | xml = resources.GetResource(
62 | os.path.join(root_dir, 'custom_dmc_tasks', f'point_mass_maze_{task}.xml'))
63 | return xml, common.ASSETS
64 |
65 |
66 | @SUITE.add('benchmarking')
67 | def reach_top_left(time_limit=_DEFAULT_TIME_LIMIT,
68 | random=None,
69 | environment_kwargs=None):
70 | """Returns the Run task."""
71 | physics = Physics.from_xml_string(*get_model_and_assets('reach_top_left'))
72 | task = MultiTaskPointMassMaze(target_id=0, random=random)
73 | environment_kwargs = environment_kwargs or {}
74 | return control.Environment(physics,
75 | task,
76 | time_limit=time_limit,
77 | **environment_kwargs)
78 |
79 |
80 | @SUITE.add('benchmarking')
81 | def reach_top_right(time_limit=_DEFAULT_TIME_LIMIT,
82 | random=None,
83 | environment_kwargs=None):
84 | """Returns the Run task."""
85 | physics = Physics.from_xml_string(*get_model_and_assets('reach_top_right'))
86 | task = MultiTaskPointMassMaze(target_id=1, random=random)
87 | environment_kwargs = environment_kwargs or {}
88 | return control.Environment(physics,
89 | task,
90 | time_limit=time_limit,
91 | **environment_kwargs)
92 |
93 |
94 | @SUITE.add('benchmarking')
95 | def reach_bottom_left(time_limit=_DEFAULT_TIME_LIMIT,
96 | random=None,
97 | environment_kwargs=None):
98 | """Returns the Run task."""
99 | physics = Physics.from_xml_string(*get_model_and_assets('reach_bottom_left'))
100 | task = MultiTaskPointMassMaze(target_id=2, random=random)
101 | environment_kwargs = environment_kwargs or {}
102 | return control.Environment(physics,
103 | task,
104 | time_limit=time_limit,
105 | **environment_kwargs)
106 |
107 |
108 | @SUITE.add('benchmarking')
109 | def reach_bottom_right(time_limit=_DEFAULT_TIME_LIMIT,
110 | random=None,
111 | environment_kwargs=None):
112 | """Returns the Run task."""
113 | physics = Physics.from_xml_string(*get_model_and_assets('reach_bottom_right'))
114 | task = MultiTaskPointMassMaze(target_id=3, random=random)
115 | environment_kwargs = environment_kwargs or {}
116 | return control.Environment(physics,
117 | task,
118 | time_limit=time_limit,
119 | **environment_kwargs)
120 |
121 |
122 | class Physics(mujoco.Physics):
123 | """physics for the point_mass domain."""
124 |
125 | def mass_to_target_dist(self, target):
126 | """Returns the distance from mass to the target."""
127 | d = target - self.named.data.geom_xpos['pointmass']
128 | return np.linalg.norm(d)
129 |
130 |
131 | class MultiTaskPointMassMaze(base.Task):
132 | """A point_mass `Task` to reach target with smooth reward."""
133 | def __init__(self, target_id, random=None):
134 | """Initialize an instance of `PointMassMaze`.
135 |
136 | Args:
137 | randomize_gains: A `bool`, whether to randomize the actuator gains.
138 | random: Optional, either a `numpy.random.RandomState` instance, an
139 | integer seed for creating a new `RandomState`, or None to select a seed
140 | automatically (default).
141 | """
142 | self._target = TASKS[target_id][1]
143 | super().__init__(random=random)
144 |
145 | def initialize_episode(self, physics):
146 | """Sets the state of the environment at the start of each episode.
147 |
148 | The joints are randomized, the point mass is placed in the top-left region
149 | of the maze, and the target geom is moved to the task-specific goal
150 | position.
151 |
152 | Args:
153 | physics: An instance of `mujoco.Physics`.
154 | """
155 | randomizers.randomize_limited_and_rotational_joints(
156 | physics, self.random)
157 | physics.data.qpos[0] = np.random.uniform(-0.29, -0.15)
158 | physics.data.qpos[1] = np.random.uniform(0.15, 0.29)
159 | #import ipdb; ipdb.set_trace()
160 | physics.named.data.geom_xpos['target'][:] = self._target
161 |
162 | super().initialize_episode(physics)
163 |
164 | def get_observation(self, physics):
165 | """Returns an observation of the state."""
166 | obs = collections.OrderedDict()
167 | obs['position'] = physics.position()
168 | obs['velocity'] = physics.velocity()
169 | return obs
170 |
171 | def get_reward_spec(self):
172 | return specs.Array(shape=(1,), dtype=np.float32, name='reward')
173 |
174 | def get_reward(self, physics):
175 | """Returns a reward to the agent."""
176 | target_size = .015
177 | control_reward = rewards.tolerance(physics.control(), margin=1,
178 | value_at_margin=0,
179 | sigmoid='quadratic').mean()
180 | small_control = (control_reward + 4) / 5
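  | # control_reward lies in [0, 1], so small_control lies in [0.8, 1]: the control
  | # penalty softly modulates the reach reward instead of dominating it.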
181 | near_target = rewards.tolerance(physics.mass_to_target_dist(self._target),
182 | bounds=(0, target_size), margin=target_size)
183 | reward = near_target * small_control
184 | return reward
185 |
--------------------------------------------------------------------------------
/custom_dmc_tasks/hopper.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """Hopper domain."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 | import os
23 |
24 | from dm_control import mujoco
25 | from dm_control.rl import control
26 | from dm_control.suite import base
27 | from dm_control.suite import common
28 | from dm_control.suite.utils import randomizers
29 | from dm_control.utils import containers
30 | from dm_control.utils import rewards
31 | from dm_control.utils import io as resources
32 | import numpy as np
33 |
34 | SUITE = containers.TaggedTasks()
35 |
36 | _CONTROL_TIMESTEP = .02 # (Seconds)
37 |
38 | # Default duration of an episode, in seconds.
39 | _DEFAULT_TIME_LIMIT = 20
40 |
41 | # Minimal height of torso over foot above which stand reward is 1.
42 | _STAND_HEIGHT = 0.6
43 |
44 | # Hopping speed above which hop reward is 1.
45 | _HOP_SPEED = 2
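  | # Torso angular momentum above which the flip reward is 1.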
46 | _SPIN_SPEED = 5
47 |
48 |
49 | def make(task,
50 | task_kwargs=None,
51 | environment_kwargs=None,
52 | visualize_reward=False):
53 | task_kwargs = task_kwargs or {}
54 | if environment_kwargs is not None:
55 | task_kwargs = task_kwargs.copy()
56 | task_kwargs['environment_kwargs'] = environment_kwargs
57 | env = SUITE[task](**task_kwargs)
58 | env.task.visualize_reward = visualize_reward
59 | return env
60 |
61 | def get_model_and_assets():
62 | """Returns a tuple containing the model XML string and a dict of assets."""
63 | root_dir = os.path.dirname(os.path.dirname(__file__))
64 | xml = resources.GetResource(
65 | os.path.join(root_dir, 'custom_dmc_tasks', 'hopper.xml'))
66 | return xml, common.ASSETS
67 |
68 |
69 |
70 | @SUITE.add('benchmarking')
71 | def hop_backward(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
72 | """Returns a Hopper that strives to hop forward."""
73 | physics = Physics.from_xml_string(*get_model_and_assets())
74 | task = Hopper(hopping=True, forward=False, flip=False, random=random)
75 | environment_kwargs = environment_kwargs or {}
76 | return control.Environment(physics,
77 | task,
78 | time_limit=time_limit,
79 | control_timestep=_CONTROL_TIMESTEP,
80 | **environment_kwargs)
81 |
82 | @SUITE.add('benchmarking')
83 | def flip(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
84 | """Returns a Hopper that strives to hop forward."""
85 | physics = Physics.from_xml_string(*get_model_and_assets())
86 | task = Hopper(hopping=True, forward=True, flip=True, random=random)
87 | environment_kwargs = environment_kwargs or {}
88 | return control.Environment(physics,
89 | task,
90 | time_limit=time_limit,
91 | control_timestep=_CONTROL_TIMESTEP,
92 | **environment_kwargs)
93 |
94 | @SUITE.add('benchmarking')
95 | def flip_backward(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
96 | """Returns a Hopper that strives to hop forward."""
97 | physics = Physics.from_xml_string(*get_model_and_assets())
98 | task = Hopper(hopping=True, forward=False, flip=True, random=random)
99 | environment_kwargs = environment_kwargs or {}
100 | return control.Environment(physics,
101 | task,
102 | time_limit=time_limit,
103 | control_timestep=_CONTROL_TIMESTEP,
104 | **environment_kwargs)
105 |
106 |
107 | class Physics(mujoco.Physics):
108 | """Physics simulation with additional features for the Hopper domain."""
109 | def height(self):
110 | """Returns height of torso with respect to foot."""
111 | return (self.named.data.xipos['torso', 'z'] -
112 | self.named.data.xipos['foot', 'z'])
113 |
114 | def speed(self):
115 | """Returns horizontal speed of the Hopper."""
116 | return self.named.data.sensordata['torso_subtreelinvel'][0]
117 |
118 | def touch(self):
119 | """Returns the signals from two foot touch sensors."""
120 | return np.log1p(self.named.data.sensordata[['touch_toe',
121 | 'touch_heel']])
122 |
123 | def angmomentum(self): # angular momentum, used to reward flipping
124 | """Returns the angular momentum of the Hopper's torso about the Y axis."""
125 | return self.named.data.subtree_angmom['torso'][1]
126 |
127 |
128 |
129 | class Hopper(base.Task):
130 | """A Hopper's `Task` to train a standing and a jumping Hopper."""
131 | def __init__(self, hopping, forward=True, flip=False, random=None):
132 | """Initialize an instance of `Hopper`.
133 |
134 | Args:
135 | hopping: Boolean, if True the task is to hop (direction and style are set by
136 | `forward` and `flip`), otherwise it is to balance upright.
137 | random: Optional, either a `numpy.random.RandomState` instance, an
138 | integer seed for creating a new `RandomState`, or None to select a seed
139 | automatically (default).
140 | """
141 | self._hopping = hopping
142 | self._forward = 1 if forward else -1
143 | self._flip = flip
144 | super(Hopper, self).__init__(random=random)
145 |
146 | def initialize_episode(self, physics):
147 | """Sets the state of the environment at the start of each episode."""
148 | randomizers.randomize_limited_and_rotational_joints(
149 | physics, self.random)
150 | self._timeout_progress = 0
151 | super(Hopper, self).initialize_episode(physics)
152 |
153 | def get_observation(self, physics):
154 | """Returns an observation of positions, velocities and touch sensors."""
155 | obs = collections.OrderedDict()
156 | # Ignores horizontal position to maintain translational invariance:
157 | obs['position'] = physics.data.qpos[1:].copy()
158 | obs['velocity'] = physics.velocity()
159 | obs['touch'] = physics.touch()
160 | return obs
161 |
162 | def get_reward(self, physics):
163 | """Returns a reward applicable to the performed task."""
164 | standing = rewards.tolerance(physics.height(), (_STAND_HEIGHT, 2))
165 | assert self._hopping
166 | if self._flip:
167 | hopping = rewards.tolerance(self._forward * physics.angmomentum(),
168 | bounds=(_SPIN_SPEED, float('inf')),
169 | margin=_SPIN_SPEED,
170 | value_at_margin=0,
171 | sigmoid='linear')
172 | else:
173 | hopping = rewards.tolerance(self._forward * physics.speed(),
174 | bounds=(_HOP_SPEED, float('inf')),
175 | margin=_HOP_SPEED / 2,
176 | value_at_margin=0.5,
177 | sigmoid='linear')
178 | return standing * hopping
--------------------------------------------------------------------------------
/replay_buffer_collect.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import io
3 | import random
4 | import traceback
5 | from collections import defaultdict
6 |
7 | import numpy as np
8 | import torch
9 | import torch.nn as nn
10 | from torch.utils.data import IterableDataset
11 |
12 |
13 | def episode_len(episode):
14 | # subtract 1 because of the dummy first transition
15 | return next(iter(episode.values())).shape[0] - 1
16 |
17 |
18 | def save_episode(episode, fn):
19 | with io.BytesIO() as bs:
20 | np.savez_compressed(bs, **episode)
21 | bs.seek(0)
22 | with fn.open('wb') as f:
23 | f.write(bs.read())
24 |
25 |
26 | def load_episode(fn):
27 | with fn.open('rb') as f:
28 | episode = np.load(f)
29 | episode = {k: episode[k] for k in episode.keys()}
30 | return episode
31 |
32 |
33 | class ReplayBufferStorage:
34 | def __init__(self, data_specs, meta_specs, replay_dir, dataset_dir):
35 | self._data_specs = data_specs
36 | self._meta_specs = meta_specs
37 | self._replay_dir = replay_dir
38 | replay_dir.mkdir(exist_ok=True)
39 | self._current_episode = defaultdict(list)
40 | self._preload()
41 | # per-episode buffers for exporting the raw dataset (including physics states)
42 | self._dataset_episodic = {"observation": [], "action": [], "reward": [], "discount": [], "physics": []}
43 | self._dataset_dir = dataset_dir
44 | dataset_dir.mkdir(exist_ok=True)
45 |
46 | def __len__(self):
47 | return self._num_transitions
48 |
49 | def add(self, time_step, meta, physics):
50 | self._dataset_episodic['physics'].append(physics) # add physics
51 |
52 | for key, value in meta.items():
53 | self._current_episode[key].append(value)
54 | for spec in self._data_specs:
55 | value = time_step[spec.name]
56 | if np.isscalar(value):
57 | value = np.full(spec.shape, value, spec.dtype)
58 | assert spec.shape == value.shape and spec.dtype == value.dtype
59 | self._current_episode[spec.name].append(value)
60 | self._dataset_episodic[spec.name].append(value) # append value
61 |
62 | if time_step.last():
63 | episode = dict()
64 | for spec in self._data_specs:
65 | value = self._current_episode[spec.name]
66 | episode[spec.name] = np.array(value, spec.dtype)
67 | for spec in self._meta_specs:
68 | value = self._current_episode[spec.name]
69 | episode[spec.name] = np.array(value, spec.dtype)
70 | self._current_episode = defaultdict(list)
71 | self._store_episode(episode)
72 | self._store_dataset()
73 |
74 | def _preload(self):
75 | self._num_episodes = 0
76 | self._num_transitions = 0
77 | for fn in self._replay_dir.glob('*.npz'):
78 | _, _, eps_len = fn.stem.split('_')
79 | self._num_episodes += 1
80 | self._num_transitions += int(eps_len)
81 |
82 | def _store_episode(self, episode):
83 | eps_idx = self._num_episodes
84 | eps_len = episode_len(episode)
85 | self._num_episodes += 1
86 | self._num_transitions += eps_len
87 | ts = datetime.datetime.now().strftime('%Y%m%dT%H%M%S')
88 | eps_fn = f'{ts}_{eps_idx}_{eps_len}.npz'
89 | save_episode(episode, self._replay_dir / eps_fn)
90 |
91 | def _store_dataset(self): # save the data
92 | # print("save data to", self._dataset_dir)
93 | dataset_episodic = dict()
94 | for key in self._dataset_episodic.keys():
95 | value = self._dataset_episodic[key]
96 | dataset_episodic[key] = np.array(value, np.float32)
97 | # print("save:", key, np.array(value, np.float32).shape)
98 | # ts = datetime.datetime.now().strftime('%m%dT%H%M%S')
99 | episode_str = f"{self._num_episodes-1:06d}"
100 | eps_fn = f'episode_{episode_str}_{episode_len(dataset_episodic)}.npz'
101 | np.savez_compressed(self._dataset_dir / eps_fn, **dataset_episodic)
102 |
103 | for k in self._dataset_episodic.keys(): # clear
104 | self._dataset_episodic[k] = []
105 |
106 |
107 | class ReplayBuffer(IterableDataset):
108 | def __init__(self, storage, max_size, num_workers, nstep, discount,
109 | fetch_every, save_snapshot):
110 | self._storage = storage
111 | self._size = 0
112 | self._max_size = max_size
113 | self._num_workers = max(1, num_workers)
114 | self._episode_fns = []
115 | self._episodes = dict()
116 | self._nstep = nstep
117 | self._discount = discount
118 | self._fetch_every = fetch_every
119 | self._samples_since_last_fetch = fetch_every
120 | self._save_snapshot = save_snapshot
121 |
122 | def _sample_episode(self):
123 | eps_fn = random.choice(self._episode_fns)
124 | return self._episodes[eps_fn]
125 |
126 | def _store_episode(self, eps_fn):
127 | # store new episodic samples to self._episode_fn
128 | try:
129 | episode = load_episode(eps_fn)
130 | except:
131 | return False
132 | eps_len = episode_len(episode) # number of transitions in the episode (e.g. 1000)
133 | while eps_len + self._size > self._max_size:
134 | early_eps_fn = self._episode_fns.pop(0)
135 | early_eps = self._episodes.pop(early_eps_fn)
136 | self._size -= episode_len(early_eps)
137 | early_eps_fn.unlink(missing_ok=True)
138 | self._episode_fns.append(eps_fn) # each worker keeps roughly (num episodes / num workers) files
139 | self._episode_fns.sort()
140 | self._episodes[eps_fn] = episode
141 | self._size += eps_len
142 | # print("store ", eps_len, episode.keys(), ", total:", len(self._episode_fns))
143 | # print("self._storage._replay_dir.glob('*.npz')", self._storage._replay_dir.glob('*.npz'))
144 |
145 | if not self._save_snapshot:
146 | eps_fn.unlink(missing_ok=True)
147 | return True
148 |
149 | def _try_fetch(self):
150 | if self._samples_since_last_fetch < self._fetch_every:
151 | return
152 | self._samples_since_last_fetch = 0
153 | try:
154 | worker_id = torch.utils.data.get_worker_info().id
155 | except:
156 | worker_id = 0
157 | # print("\nworker id:", worker_id)
158 | eps_fns = sorted(self._storage._replay_dir.glob('*.npz'), reverse=True)
159 | # print("lens of eps_fns:", len(eps_fns))
160 | fetched_size = 0
161 | for eps_fn in eps_fns:
162 | eps_idx, eps_len = [int(x) for x in eps_fn.stem.split('_')[1:]]
163 | if eps_idx % self._num_workers != worker_id:
164 | continue
165 | if eps_fn in self._episodes.keys():
166 | break
167 | if fetched_size + eps_len > self._max_size:
168 | break
169 | fetched_size += eps_len
170 | # print("eps_fns:", eps_fn, fetched_size, eps_len)
171 | if not self._store_episode(eps_fn):
172 | break
173 |
174 | def _sample(self):
175 | try:
176 | self._try_fetch()
177 | except:
178 | traceback.print_exc()
179 | self._samples_since_last_fetch += 1
180 | episode = self._sample_episode()
181 | # add +1 for the first dummy transition
182 | idx = np.random.randint(0, episode_len(episode) - self._nstep + 1) + 1
183 | meta = []
184 | for spec in self._storage._meta_specs:
185 | meta.append(episode[spec.name][idx - 1])
186 | obs = episode['observation'][idx - 1]
187 | action = episode['action'][idx]
188 | next_obs = episode['observation'][idx + self._nstep - 1]
189 | reward = np.zeros_like(episode['reward'][idx])
190 | discount = np.ones_like(episode['discount'][idx])
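  | # n-step return: each step's reward is scaled by the discount accumulated so far,
  | # and `discount` multiplies the per-step dataset discounts by self._discount (gamma),
  | # ending as the factor used to bootstrap from next_obs.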
191 | for i in range(self._nstep):
192 | step_reward = episode['reward'][idx + i]
193 | reward += discount * step_reward
194 | discount *= episode['discount'][idx + i] * self._discount
195 | return (obs, action, reward, discount, next_obs, *meta)
196 |
197 | def __iter__(self):
198 | while True:
199 | yield self._sample()
200 |
201 |
202 | def _worker_init_fn(worker_id):
203 | seed = np.random.get_state()[1][0] + worker_id
204 | np.random.seed(seed)
205 | random.seed(seed)
206 |
207 |
208 | def make_replay_loader(storage, max_size, batch_size, num_workers,
209 | save_snapshot, nstep, discount):
210 | max_size_per_worker = max_size // max(1, num_workers)
211 |
212 | iterable = ReplayBuffer(storage,
213 | max_size_per_worker,
214 | num_workers,
215 | nstep,
216 | discount,
217 | fetch_every=1000,
218 | save_snapshot=save_snapshot)
219 |
220 | loader = torch.utils.data.DataLoader(iterable,
221 | batch_size=batch_size,
222 | num_workers=num_workers,
223 | pin_memory=True,
224 | worker_init_fn=_worker_init_fn)
225 | return loader
226 |
--------------------------------------------------------------------------------
/agent/crr.py:
--------------------------------------------------------------------------------
1 | import hydra
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from collections import OrderedDict
7 | from einops import repeat, rearrange
8 |
9 | import utils
10 |
11 |
12 | class Actor(nn.Module):
13 | def __init__(self, obs_dim, action_dim, hidden_dim):
14 | super().__init__()
15 |
16 | self.policy = nn.Sequential(nn.Linear(obs_dim, hidden_dim),
17 | nn.LayerNorm(hidden_dim), nn.Tanh(),
18 | nn.Linear(hidden_dim, hidden_dim),
19 | nn.ReLU(inplace=True),
20 | nn.Linear(hidden_dim, action_dim))
21 |
22 | self.apply(utils.weight_init)
23 |
24 | def forward(self, obs, std):
25 | mu = self.policy(obs)
26 | mu = torch.tanh(mu)
27 | std = torch.ones_like(mu) * std
28 |
29 | dist = utils.TruncatedNormal(mu, std)
30 | return dist
31 |
32 |
33 | class Critic(nn.Module):
34 | def __init__(self, obs_dim, action_dim, hidden_dim):
35 | super().__init__()
36 |
37 | self.q1_net = nn.Sequential(
38 | nn.Linear(obs_dim + action_dim, hidden_dim),
39 | nn.LayerNorm(hidden_dim), nn.Tanh(),
40 | nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True),
41 | nn.Linear(hidden_dim, 1))
42 |
43 | self.q2_net = nn.Sequential(
44 | nn.Linear(obs_dim + action_dim, hidden_dim),
45 | nn.LayerNorm(hidden_dim), nn.Tanh(),
46 | nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True),
47 | nn.Linear(hidden_dim, 1))
48 |
49 | self.apply(utils.weight_init)
50 |
51 | def forward(self, obs, action):
52 | obs_action = torch.cat([obs, action], dim=-1)
53 | q1 = self.q1_net(obs_action)
54 | q2 = self.q2_net(obs_action)
55 |
56 | return q1, q2
57 |
58 |
59 | class CRRAgent:
60 | def __init__(self,
61 | name,
62 | obs_shape,
63 | action_shape,
64 | device,
65 | lr,
66 | hidden_dim,
67 | critic_target_tau,
68 | num_value_samples,
69 | weight_func,
70 | stddev_schedule,
71 | nstep,
72 | batch_size,
73 | stddev_clip,
74 | use_tb,
75 | has_next_action=False):
76 | self.action_dim = action_shape[0]
77 | self.hidden_dim = hidden_dim
78 | self.lr = lr
79 | self.device = device
80 | self.critic_target_tau = critic_target_tau
81 | self.use_tb = use_tb
82 | self.stddev_schedule = stddev_schedule
83 | self.stddev_clip = stddev_clip
84 | self.num_value_samples = num_value_samples
85 | self.weight_func = weight_func
86 |
87 | # models
88 | self.actor = Actor(obs_shape[0], action_shape[0],
89 | hidden_dim).to(device)
90 |
91 | self.critic = Critic(obs_shape[0], action_shape[0],
92 | hidden_dim).to(device)
93 | self.critic_target = Critic(obs_shape[0], action_shape[0],
94 | hidden_dim).to(device)
95 | self.critic_target.load_state_dict(self.critic.state_dict())
96 |
97 | # optimizers
98 | self.actor_opt = torch.optim.Adam(self.actor.parameters(), lr=lr)
99 | self.critic_opt = torch.optim.Adam(self.critic.parameters(), lr=lr)
100 |
101 | self.train()
102 | self.critic_target.train()
103 |
104 | def train(self, training=True):
105 | self.training = training
106 | self.actor.train(training)
107 | self.critic.train(training)
108 |
109 | def act(self, obs, step, eval_mode):
110 | obs = torch.as_tensor(obs, device=self.device).unsqueeze(0)
111 | stddev = utils.schedule(self.stddev_schedule, step)
112 | policy = self.actor(obs, stddev)
113 | if eval_mode:
114 | action = policy.mean
115 | else:
116 | action = policy.sample(clip=None)
117 | if step < getattr(self, 'num_expl_steps', 0): # CRRAgent defines no num_expl_steps, so default to 0 (no random exploration)
118 | action.uniform_(-1.0, 1.0)
119 | return action.cpu().numpy()[0]
120 |
121 | def compute_value(self, obs, step):
122 | obses = repeat(obs, 'b x -> (b n) x', n=self.num_value_samples)
123 | stddev = utils.schedule(self.stddev_schedule, step)
124 | dists = self.actor(obses, stddev)
125 | actions = dists.sample(clip=self.stddev_clip)
126 | Q1, Q2 = self.critic(obses, actions)
127 | Q = torch.min(Q1, Q2)
128 | V = rearrange(Q, '(b n) x -> b n x',
129 | n=self.num_value_samples).mean(dim=1)
130 |
131 | return V
132 |
133 | def adv_transform(self, A):
134 | assert self.weight_func in ['identity', 'indicator', 'exp']
135 | if self.weight_func == 'identity':
136 | return A
137 | elif self.weight_func == 'indicator':
138 | return torch.sign(torch.relu(A))
139 | elif self.weight_func == 'exp':
140 | return torch.clamp(A.exp(), 0, 20.0)
141 | else:
142 | assert False, f'wrong weight function: {self.weight_func}'
143 |
144 | def update_critic(self, obs, action, reward, discount, next_obs, step):
145 | metrics = dict()
146 |
147 | with torch.no_grad():
148 | stddev = utils.schedule(self.stddev_schedule, step)
149 | dist = self.actor(next_obs, stddev)
150 | next_action = dist.sample(clip=self.stddev_clip)
151 | target_Q1, target_Q2 = self.critic_target(next_obs, next_action)
152 | target_V = torch.min(target_Q1, target_Q2)
153 | target_Q = reward + (discount * target_V)
154 |
155 | Q1, Q2 = self.critic(obs, action)
156 | critic_loss = F.mse_loss(Q1, target_Q) + F.mse_loss(Q2, target_Q)
157 |
158 | if self.use_tb:
159 | metrics['critic_target_q'] = target_Q.mean().item()
160 | metrics['critic_q1'] = Q1.mean().item()
161 | metrics['critic_q2'] = Q2.mean().item()
162 | metrics['critic_loss'] = critic_loss.item()
163 |
164 | # optimize critic
165 | self.critic_opt.zero_grad(set_to_none=True)
166 | critic_loss.backward()
167 | self.critic_opt.step()
168 | return metrics
169 |
170 | def update_actor(self, obs, action, step):
171 | metrics = dict()
172 |
174 | with torch.no_grad():
175 | V = self.compute_value(obs, step)
176 | Q1, Q2 = self.critic(obs, action)
177 | Q = torch.min(Q1, Q2)
178 | A = Q - V
179 | w = self.adv_transform(A)
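  | # CRR: advantage-weighted regression. A = Q(s, a) - V(s), where V averages the
  | # critic over num_value_samples policy actions; the actor then maximizes
  | # w * log pi(a|s) with w = adv_transform(A) (identity / indicator / exp).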
180 |
181 | stddev = utils.schedule(self.stddev_schedule, step)
182 | policy = self.actor(obs, stddev)
183 |
184 | log_prob = policy.log_prob(action).sum(-1, keepdim=True)
185 | actor_loss = -(log_prob * w).mean()
186 |
187 | # optimize actor
188 | self.actor_opt.zero_grad(set_to_none=True)
189 | actor_loss.backward()
190 | self.actor_opt.step()
191 |
192 | if self.use_tb:
193 | metrics['actor_loss'] = actor_loss.item()
194 | metrics['actor_ent'] = policy.entropy().sum(dim=-1).mean().item()
195 |
196 | return metrics
197 |
198 | def update(self, replay_iter, step):
199 | metrics = dict()
200 |
201 | batch = next(replay_iter)
202 | obs, action, reward, discount, next_obs = utils.to_torch(
203 | batch, self.device)
204 |
205 | if self.use_tb:
206 | metrics['batch_reward'] = reward.mean().item()
207 |
208 | # update critic
209 | metrics.update(
210 | self.update_critic(obs, action, reward, discount, next_obs, step))
211 |
212 | # update actor
213 | metrics.update(self.update_actor(obs, action, step))
214 |
215 | # update critic target
216 | utils.soft_update_params(self.critic, self.critic_target,
217 | self.critic_target_tau)
218 |
219 | return metrics
220 |
--------------------------------------------------------------------------------
/custom_dmc_tasks/jaco.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """A task where the goal is to move the hand close to a target prop or site."""
16 |
17 | import collections
18 |
19 | from dm_control import composer
20 | from dm_control.composer import initializers
21 | from dm_control.composer.observation import observable
22 | from dm_control.composer.variation import distributions
23 | from dm_control.entities import props
24 | from dm_control.manipulation.shared import arenas
25 | from dm_control.manipulation.shared import cameras
26 | from dm_control.manipulation.shared import constants
27 | from dm_control.manipulation.shared import observations
28 | from dm_control.manipulation.shared import registry
29 | from dm_control.manipulation.shared import robots
30 | from dm_control.manipulation.shared import tags
31 | from dm_control.manipulation.shared import workspaces
32 | from dm_control.utils import rewards
33 | from dm_env import specs
34 | import numpy as np
35 |
36 | _ReachWorkspace = collections.namedtuple(
37 | '_ReachWorkspace', ['target_bbox', 'tcp_bbox', 'arm_offset'])
38 |
39 | # Ensures that the props are not touching the table before settling.
40 | _PROP_Z_OFFSET = 0.001
41 |
42 | _DUPLO_WORKSPACE = _ReachWorkspace(
43 | target_bbox=workspaces.BoundingBox(lower=(-0.1, -0.1, _PROP_Z_OFFSET),
44 | upper=(0.1, 0.1, _PROP_Z_OFFSET)),
45 | tcp_bbox=workspaces.BoundingBox(lower=(-0.1, -0.1, 0.2),
46 | upper=(0.1, 0.1, 0.4)),
47 | arm_offset=robots.ARM_OFFSET)
48 |
49 | _SITE_WORKSPACE = _ReachWorkspace(
50 | target_bbox=workspaces.BoundingBox(lower=(-0.2, -0.2, 0.02),
51 | upper=(0.2, 0.2, 0.4)),
52 | tcp_bbox=workspaces.BoundingBox(lower=(-0.2, -0.2, 0.02),
53 | upper=(0.2, 0.2, 0.4)),
54 | arm_offset=robots.ARM_OFFSET)
55 |
56 | _TARGET_RADIUS = 0.05
57 | _TIME_LIMIT = 10.
58 |
59 | TASKS = [('reach_top_left', np.array([-0.09, 0.09, _PROP_Z_OFFSET])),
60 | ('reach_top_right', np.array([0.09, 0.09, _PROP_Z_OFFSET])),
61 | ('reach_bottom_left', np.array([-0.09, -0.09, _PROP_Z_OFFSET])),
62 | ('reach_bottom_right', np.array([0.09, -0.09, _PROP_Z_OFFSET]))]
63 |
64 |
65 | def make(task_id, obs_type, seed):
66 | obs_settings = observations.VISION if obs_type == 'pixels' else observations.PERFECT_FEATURES
67 | task = _reach(task_id, obs_settings=obs_settings, use_site=True)
68 | return composer.Environment(task,
69 | time_limit=_TIME_LIMIT,
70 | random_state=seed)
71 |
72 |
73 | class MultiTaskReach(composer.Task):
74 | """Bring the hand close to a target prop or site."""
75 | def __init__(self, task_id, arena, arm, hand, prop, obs_settings,
76 | workspace, control_timestep):
 77 |     """Initializes a new `MultiTaskReach` task.
78 |
79 | Args:
80 | arena: `composer.Entity` instance.
81 | arm: `robot_base.RobotArm` instance.
82 | hand: `robot_base.RobotHand` instance.
83 | prop: `composer.Entity` instance specifying the prop to reach to, or None
84 | in which case the target is a fixed site whose position is specified by
85 | the workspace.
86 | obs_settings: `observations.ObservationSettings` instance.
87 | workspace: `_ReachWorkspace` specifying the placement of the prop and TCP.
88 | control_timestep: Float specifying the control timestep in seconds.
89 | """
90 | self._arena = arena
91 | self._arm = arm
92 | self._hand = hand
93 | self._arm.attach(self._hand)
94 | self._arena.attach_offset(self._arm, offset=workspace.arm_offset)
95 | self.control_timestep = control_timestep
96 | self._tcp_initializer = initializers.ToolCenterPointInitializer(
97 | self._hand,
98 | self._arm,
99 | position=distributions.Uniform(*workspace.tcp_bbox),
100 | quaternion=workspaces.DOWN_QUATERNION)
101 |
102 | # Add custom camera observable.
103 | self._task_observables = cameras.add_camera_observables(
104 | arena, obs_settings, cameras.FRONT_CLOSE)
105 |
106 | if task_id == 'reach_multitask':
107 | self._targets = [target for (_, target) in TASKS]
108 | else:
109 | self._targets = [
110 | target for (task, target) in TASKS if task == task_id
111 | ]
112 |
113 |     target_pos_distribution = distributions.Uniform(*workspace.target_bbox)  # placement distribution used by the prop branch below
114 | self._prop = prop
115 | if prop:
116 | # The prop itself is used to visualize the target location.
117 | self._make_target_site(parent_entity=prop, visible=False)
118 | self._target = self._arena.add_free_entity(prop)
119 | self._prop_placer = initializers.PropPlacer(
120 | props=[prop],
121 | position=target_pos_distribution,
122 | quaternion=workspaces.uniform_z_rotation,
123 | settle_physics=True)
124 | else:
125 | if len(self._targets) == 1:
126 | self._target = self._make_target_site(parent_entity=arena,
127 | visible=True)
128 |
129 | # obs = observable.MJCFFeature('pos', self._target)
130 | # obs.configure(**obs_settings.prop_pose._asdict())
131 | # self._task_observables['target_position'] = obs
132 |
133 | # Add sites for visualizing the prop and target bounding boxes.
134 | workspaces.add_bbox_site(body=self.root_entity.mjcf_model.worldbody,
135 | lower=workspace.tcp_bbox.lower,
136 | upper=workspace.tcp_bbox.upper,
137 | rgba=constants.GREEN,
138 | name='tcp_spawn_area')
139 | workspaces.add_bbox_site(body=self.root_entity.mjcf_model.worldbody,
140 | lower=workspace.target_bbox.lower,
141 | upper=workspace.target_bbox.upper,
142 | rgba=constants.BLUE,
143 | name='target_spawn_area')
144 |
145 | def _make_target_site(self, parent_entity, visible):
146 | return workspaces.add_target_site(
147 | body=parent_entity.mjcf_model.worldbody,
148 | radius=_TARGET_RADIUS,
149 | visible=visible,
150 | rgba=constants.RED,
151 | name='target_site')
152 |
153 | @property
154 | def root_entity(self):
155 | return self._arena
156 |
157 | @property
158 | def arm(self):
159 | return self._arm
160 |
161 | @property
162 | def hand(self):
163 | return self._hand
164 |
165 | def get_reward_spec(self):
166 | n = len(self._targets)
167 | return specs.Array(shape=(n,), dtype=np.float32, name='reward')
168 |
169 | @property
170 | def task_observables(self):
171 | return self._task_observables
172 |
173 | def get_reward(self, physics):
174 | hand_pos = physics.bind(self._hand.tool_center_point).xpos
175 | rews = []
176 | for target_pos in self._targets:
177 | distance = np.linalg.norm(hand_pos - target_pos)
178 | reward = rewards.tolerance(distance,
179 | bounds=(0, _TARGET_RADIUS),
180 | margin=_TARGET_RADIUS)
181 | rews.append(reward)
182 | rews = np.array(rews).astype(np.float32)
183 | if len(self._targets) == 1:
184 | return rews[0]
185 | return rews
186 |
187 | def initialize_episode(self, physics, random_state):
188 | self._hand.set_grasp(physics, close_factors=random_state.uniform())
189 | self._tcp_initializer(physics, random_state)
190 | if self._prop:
191 | self._prop_placer(physics, random_state)
192 | else:
193 | if len(self._targets) == 1:
194 | physics.bind(self._target).pos = self._targets[0]
195 |
196 |
197 | def _reach(task_id, obs_settings, use_site):
198 |   """Configure and instantiate a `MultiTaskReach` task.
199 |
200 | Args:
201 | obs_settings: An `observations.ObservationSettings` instance.
202 | use_site: Boolean, if True then the target will be a fixed site, otherwise
203 | it will be a moveable Duplo brick.
204 |
205 | Returns:
206 |     An instance of `MultiTaskReach`.
207 | """
208 | arena = arenas.Standard()
209 | arm = robots.make_arm(obs_settings=obs_settings)
210 | hand = robots.make_hand(obs_settings=obs_settings)
211 | if use_site:
212 | workspace = _SITE_WORKSPACE
213 | prop = None
214 | else:
215 | workspace = _DUPLO_WORKSPACE
216 | prop = props.Duplo(observable_options=observations.make_options(
217 | obs_settings, observations.FREEPROP_OBSERVABLES))
218 | task = MultiTaskReach(task_id,
219 | arena=arena,
220 | arm=arm,
221 | hand=hand,
222 | prop=prop,
223 | obs_settings=obs_settings,
224 | workspace=workspace,
225 | control_timestep=constants.CONTROL_TIMESTEP)
226 | return task
227 |
--------------------------------------------------------------------------------
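A minimal usage sketch for the Jaco reach factory above. It assumes dm_control and its manipulation assets are installed and that this module is importable as `custom_dmc_tasks.jaco`; the task id and the midpoint action below are illustrative only.

    from custom_dmc_tasks import jaco

    env = jaco.make('reach_top_left', obs_type='states', seed=0)
    spec = env.action_spec()
    time_step = env.reset()
    while not time_step.last():
        # drive the arm with the midpoint of the action range, just to exercise the env
        action = (spec.minimum + spec.maximum) / 2.0
        time_step = env.step(action)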
/collect_data.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | import wandb
3 | warnings.filterwarnings('ignore', category=DeprecationWarning)
4 |
5 | import os
6 |
7 | os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
8 | os.environ['MUJOCO_GL'] = 'egl'
9 | import pickle
10 | from pathlib import Path
11 |
12 | import hydra
13 | import numpy as np
14 | import torch
15 | from dm_env import specs
16 |
17 | import dmc
18 | import utils
19 | from logger import Logger
20 | from replay_buffer_collect import ReplayBufferStorage, make_replay_loader
21 | from video import TrainVideoRecorder, VideoRecorder
22 |
23 | torch.backends.cudnn.benchmark = True
24 |
25 |
26 | def make_agent(obs_spec, action_spec, num_expl_steps, cfg):
27 | cfg.obs_shape = obs_spec.shape
28 | cfg.action_shape = action_spec.shape
29 | cfg.num_expl_steps = num_expl_steps
30 | return hydra.utils.instantiate(cfg)
31 |
32 |
33 | class Workspace:
34 | def __init__(self, cfg):
35 | self.work_dir = Path.cwd()
36 | print(f'workspace: {self.work_dir}')
37 |
38 | self.cfg = cfg
39 | utils.set_seed_everywhere(cfg.seed)
40 | self.device = torch.device(cfg.device)
41 |
42 | # create logger
43 | self.logger = Logger(self.work_dir, use_tb=cfg.use_tb)
44 | self.train_env = dmc.make(cfg.task, action_repeat=cfg.action_repeat, seed=cfg.seed)
45 | self.eval_env = dmc.make(cfg.task, action_repeat=cfg.action_repeat, seed=cfg.seed)
46 |
47 | # create agent
48 | self.agent = make_agent(self.train_env.observation_spec(),
49 | self.train_env.action_spec(),
50 | cfg.num_seed_frames // cfg.action_repeat,
51 | cfg.agent)
52 |
53 | # get meta specs
54 | meta_specs = self.agent.get_meta_specs()
55 | # create replay buffer
56 | data_specs = (self.train_env.observation_spec(), self.train_env.action_spec(),
57 | specs.Array((1,), np.float32, 'reward'),
58 | specs.Array((1,), np.float32, 'discount'))
59 |
60 | # create data storage
61 | self.replay_storage = ReplayBufferStorage(data_specs, meta_specs,
62 | replay_dir=self.work_dir / 'buffer', dataset_dir=self.work_dir / 'data')
63 |
64 | # create replay buffer
65 | self.replay_loader = make_replay_loader(self.replay_storage, cfg.replay_buffer_size,
66 | cfg.batch_size, cfg.replay_buffer_num_workers, False, cfg.nstep, cfg.discount)
67 | self._replay_iter = None
68 |
69 | # create video recorders
70 | self.video_recorder = VideoRecorder(self.work_dir if cfg.save_video else None)
71 | self.train_video_recorder = TrainVideoRecorder(self.work_dir if cfg.save_train_video else None)
72 |
73 | self.timer = utils.Timer()
74 | self._global_step = 0
75 | self._global_episode = 0
76 |
77 | # TODO: save agent
78 | self._agent_dir = self.work_dir / 'agent'
79 | self._agent_dir.mkdir(exist_ok=True)
80 | self.change_freq = True
81 |
82 | @property
83 | def global_step(self):
84 | return self._global_step
85 |
86 | @property
87 | def global_episode(self):
88 | return self._global_episode
89 |
90 | @property
91 | def global_frame(self):
92 | return self.global_step * self.cfg.action_repeat
93 |
94 | @property
95 | def replay_iter(self):
96 | if self._replay_iter is None:
97 | self._replay_iter = iter(self.replay_loader)
98 | return self._replay_iter
99 |
100 | def eval(self):
101 | step, episode, total_reward = 0, 0, 0
102 | eval_until_episode = utils.Until(self.cfg.num_eval_episodes)
103 | meta = self.agent.init_meta()
104 |         while eval_until_episode(episode):  # evaluate for cfg.num_eval_episodes episodes
105 | time_step = self.eval_env.reset()
106 | self.video_recorder.init(self.eval_env, enabled=(episode == 0))
107 | while not time_step.last():
108 | with torch.no_grad(), utils.eval_mode(self.agent):
109 | action = self.agent.act(time_step.observation, step=self.global_step, eval_mode=True, meta=meta)
110 | time_step = self.eval_env.step(action)
111 | self.video_recorder.record(self.eval_env)
112 | total_reward += time_step.reward
113 | step += 1
114 | episode += 1
115 | self.video_recorder.save(f'{self.global_frame}.mp4')
116 |
117 | with self.logger.log_and_dump_ctx(self.global_frame, ty='eval') as log:
118 | log('episode_reward', total_reward / episode)
119 | log('episode_length', step * self.cfg.action_repeat / episode)
120 | log('episode', self.global_episode)
121 | log('step', self.global_step)
122 | if self.cfg.use_wandb:
123 | wandb.log({"eval_return": total_reward / episode})
124 |
125 | return total_reward / episode
126 |
127 | def train(self):
128 | train_until_step = utils.Until(self.cfg.num_train_frames, self.cfg.action_repeat)
129 | seed_until_step = utils.Until(self.cfg.num_seed_frames, self.cfg.action_repeat)
130 | eval_every_step = utils.Every(self.cfg.eval_every_frames, self.cfg.action_repeat)
131 | save_every_step = utils.Every(self.cfg.eval_every_frames, self.cfg.action_repeat) # TODO: save agent
132 |
133 | episode_step, episode_reward = 0, 0
134 | time_step = self.train_env.reset()
135 | meta = self.agent.init_meta()
136 |         self.replay_storage.add(time_step, meta, physics=self.train_env.physics.get_state())  # the physics state is stored as well
137 | self.train_video_recorder.init(time_step.observation)
138 | metrics = None
139 | eval_rew = 0
140 | while train_until_step(self.global_step):
141 | if time_step.last():
142 | self._global_episode += 1
143 | self.train_video_recorder.save(f'{self.global_frame}.mp4')
144 | # wait until all the metrics schema is populated
145 | if metrics is not None:
146 | # log stats
147 | elapsed_time, total_time = self.timer.reset()
148 | episode_frame = episode_step * self.cfg.action_repeat
149 | with self.logger.log_and_dump_ctx(self.global_frame, ty='train') as log:
150 | log('fps', episode_frame / elapsed_time)
151 | log('total_time', total_time)
152 | log('episode_reward', episode_reward)
153 | log('episode_length', episode_frame)
154 | log('episode', self.global_episode)
155 | log('buffer_size', len(self.replay_storage))
156 | log('step', self.global_step)
157 |
158 | # reset env
159 | time_step = self.train_env.reset()
160 | meta = self.agent.init_meta()
161 | self.replay_storage.add(time_step, meta, physics=self.train_env.physics.get_state())
162 | self.train_video_recorder.init(time_step.observation)
163 |
164 | episode_step = 0
165 | episode_reward = 0
166 |
167 | # try to evaluate
168 | if eval_every_step(self.global_step):
169 | self.logger.log('eval_total_time', self.timer.total_time(), self.global_frame)
170 | eval_rew = self.eval()
171 |
172 | # TODO: save policy
173 | if save_every_step(self.global_step):
174 | agent_stamp = self._agent_dir / f'agent-{int(self.global_step/1000)}K-{round(eval_rew, 2)}.pkl'
175 | with open(str(agent_stamp), 'wb') as f_agent:
176 | pickle.dump(self.agent, f_agent)
177 | print("Save agent to", agent_stamp)
178 | if self.global_step >= 200000 and self.change_freq:
179 | save_every_step.change_every(freq=10) # decrease the freq
180 | self.change_freq = False
181 |
182 | meta = self.agent.update_meta(meta, self.global_step, time_step)
183 | if hasattr(self.agent, "regress_meta"):
184 | repeat = self.cfg.action_repeat
185 | every = self.agent.update_task_every_step // repeat
186 | init_step = self.agent.num_init_steps
187 | if self.global_step > (init_step // repeat) and self.global_step % every == 0:
188 | meta = self.agent.regress_meta(self.replay_iter, self.global_step)
189 |
190 | # sample action
191 | with torch.no_grad(), utils.eval_mode(self.agent):
192 | action = self.agent.act(time_step.observation,
193 | meta=meta, step=self.global_step, eval_mode=False)
194 |
195 | # try to update the agent
196 | if not seed_until_step(self.global_step):
197 | metrics = self.agent.update(self.replay_iter, self.global_step)
198 | self.logger.log_metrics(metrics, self.global_frame, ty='train')
199 | if self.cfg.use_wandb:
200 | wandb.log(metrics)
201 |
202 | # take env step
203 | time_step = self.train_env.step(action)
204 | episode_reward += time_step.reward
205 | self.replay_storage.add(time_step, meta, physics=self.train_env.physics.get_state())
206 | self.train_video_recorder.record(time_step.observation)
207 | episode_step += 1
208 | self._global_step += 1
209 |
210 |
211 | @hydra.main(config_path='.', config_name='collect_data')
212 | def main(cfg):
213 | from collect_data import Workspace as W
214 | root_dir = Path.cwd()
215 | workspace = W(cfg)
216 | snapshot = root_dir / 'snapshot.pt'
217 | if snapshot.exists():
218 | print(f'resuming: {snapshot}')
219 | workspace.load_snapshot()
220 |
221 | if cfg.use_wandb:
222 | wandb_dir = f"./wandb/collect_{cfg.task}_{cfg.agent.name}_{cfg.seed}"
223 | if not os.path.exists(wandb_dir):
224 | os.makedirs(wandb_dir)
225 | wandb.init(project="UTDS", entity='', config=cfg, group=f'{cfg.task}_{cfg.agent.name}',
226 | name=f'{cfg.task}_{cfg.agent.name}', dir=wandb_dir)
227 | wandb.config.update(vars(cfg))
228 |
229 | workspace.train()
230 |
231 |
232 | if __name__ == '__main__':
233 | main()
234 |
--------------------------------------------------------------------------------
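`collect_data.py` pickles the full agent into the `agent/` directory at the save points in `train`. A hedged sketch of reloading such a snapshot for a single evaluation step is shown below; the snapshot file name and task are placeholders, and the CUDA device the agent was trained on is assumed to be available.

    import pickle
    import torch

    import dmc    # the repository's dm_control wrapper used above
    import utils  # for utils.eval_mode, as in Workspace.eval

    env = dmc.make('walker_run', action_repeat=1, seed=0)   # illustrative task name
    time_step = env.reset()

    with open('agent/agent-200K-850.0.pkl', 'rb') as f:     # hypothetical snapshot name
        agent = pickle.load(f)

    meta = agent.init_meta()
    with torch.no_grad(), utils.eval_mode(agent):
        action = agent.act(time_step.observation, meta=meta, step=0, eval_mode=True)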
/agent/ddpg.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import hydra
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 | import utils
10 |
11 |
12 | class Encoder(nn.Module):
13 | def __init__(self, obs_shape):
14 | super().__init__()
15 |
16 | assert len(obs_shape) == 3
17 | self.repr_dim = 32 * 35 * 35
18 |
19 | self.convnet = nn.Sequential(nn.Conv2d(obs_shape[0], 32, 3, stride=2),
20 | nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1),
21 | nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1),
22 | nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1),
23 | nn.ReLU())
24 |
25 | self.apply(utils.weight_init)
26 |
27 | def forward(self, obs):
28 | obs = obs / 255.0 - 0.5
29 | h = self.convnet(obs)
30 | h = h.view(h.shape[0], -1)
31 | return h
32 |
33 |
34 | class Actor(nn.Module):
35 | def __init__(self, obs_type, obs_dim, action_dim, feature_dim, hidden_dim):
36 | super().__init__()
37 |
38 | feature_dim = feature_dim if obs_type == 'pixels' else hidden_dim
39 |
40 | self.trunk = nn.Sequential(nn.Linear(obs_dim, feature_dim),
41 | nn.LayerNorm(feature_dim), nn.Tanh())
42 |
43 | policy_layers = []
44 | policy_layers += [
45 | nn.Linear(feature_dim, hidden_dim),
46 | nn.ReLU(inplace=True)
47 | ]
48 | # add additional hidden layer for pixels
49 | if obs_type == 'pixels':
50 | policy_layers += [
51 | nn.Linear(hidden_dim, hidden_dim),
52 | nn.ReLU(inplace=True)
53 | ]
54 | policy_layers += [nn.Linear(hidden_dim, action_dim)]
55 |
56 | self.policy = nn.Sequential(*policy_layers)
57 |
58 | self.apply(utils.weight_init)
59 |
60 | def forward(self, obs, std):
61 | h = self.trunk(obs)
62 |
63 | mu = self.policy(h)
64 | mu = torch.tanh(mu)
65 | std = torch.ones_like(mu) * std
66 |
67 | dist = utils.TruncatedNormal(mu, std)
68 | return dist
69 |
70 |
71 | class Critic(nn.Module):
72 | def __init__(self, obs_type, obs_dim, action_dim, feature_dim, hidden_dim):
73 | super().__init__()
74 |
75 | self.obs_type = obs_type
76 |
77 | if obs_type == 'pixels':
78 | # for pixels actions will be added after trunk
79 | self.trunk = nn.Sequential(nn.Linear(obs_dim, feature_dim),
80 | nn.LayerNorm(feature_dim), nn.Tanh())
81 | trunk_dim = feature_dim + action_dim
82 | else:
83 | # for states actions come in the beginning
84 | self.trunk = nn.Sequential(
85 | nn.Linear(obs_dim + action_dim, hidden_dim),
86 | nn.LayerNorm(hidden_dim), nn.Tanh())
87 | trunk_dim = hidden_dim
88 |
89 | def make_q():
90 | q_layers = []
91 | q_layers += [
92 | nn.Linear(trunk_dim, hidden_dim),
93 | nn.ReLU(inplace=True)
94 | ]
95 | if obs_type == 'pixels':
96 | q_layers += [
97 | nn.Linear(hidden_dim, hidden_dim),
98 | nn.ReLU(inplace=True)
99 | ]
100 | q_layers += [nn.Linear(hidden_dim, 1)]
101 | return nn.Sequential(*q_layers)
102 |
103 | self.Q1 = make_q()
104 | self.Q2 = make_q()
105 |
106 | self.apply(utils.weight_init)
107 |
108 | def forward(self, obs, action):
109 | inpt = obs if self.obs_type == 'pixels' else torch.cat([obs, action], dim=-1)
110 | h = self.trunk(inpt)
111 | h = torch.cat([h, action], dim=-1) if self.obs_type == 'pixels' else h
112 |
113 | q1 = self.Q1(h)
114 | q2 = self.Q2(h)
115 |
116 | return q1, q2
117 |
118 |
119 | class DDPGAgent:
120 | def __init__(self,
121 | name,
122 | # reward_free,
123 | obs_type,
124 | obs_shape,
125 | action_shape,
126 | device,
127 | lr,
128 | feature_dim,
129 | hidden_dim,
130 | critic_target_tau,
131 | num_expl_steps,
132 | update_every_steps,
133 | stddev_schedule,
134 | nstep,
135 | batch_size,
136 | stddev_clip,
137 | init_critic,
138 | use_tb,
139 | use_wandb,
140 | update_encoder,
141 | meta_dim=0):
142 | self.update_encoder = update_encoder
143 | # self.reward_free = reward_free
144 | self.obs_type = obs_type
145 | self.obs_shape = obs_shape
146 | self.action_dim = action_shape[0]
147 | self.hidden_dim = hidden_dim
148 | self.lr = lr
149 | self.device = device
150 | self.critic_target_tau = critic_target_tau
151 | self.update_every_steps = update_every_steps
152 | self.use_tb = use_tb
153 | self.use_wandb = use_wandb
154 | self.num_expl_steps = num_expl_steps
155 | self.stddev_schedule = stddev_schedule
156 | self.stddev_clip = stddev_clip
157 | self.init_critic = init_critic
158 | self.feature_dim = feature_dim
159 | self.solved_meta = None
160 |
161 | # models
162 | if obs_type == 'pixels':
163 | self.aug = utils.RandomShiftsAug(pad=4)
164 | self.encoder = Encoder(obs_shape).to(device)
165 | self.obs_dim = self.encoder.repr_dim + meta_dim
166 | else:
167 | self.aug = nn.Identity()
168 | self.encoder = nn.Identity()
169 | self.obs_dim = obs_shape[0] + meta_dim
170 |
171 | self.actor = Actor(obs_type, self.obs_dim, self.action_dim,
172 | feature_dim, hidden_dim).to(device)
173 |
174 | self.critic = Critic(obs_type, self.obs_dim, self.action_dim,
175 | feature_dim, hidden_dim).to(device)
176 | self.critic_target = Critic(obs_type, self.obs_dim, self.action_dim,
177 | feature_dim, hidden_dim).to(device)
178 | self.critic_target.load_state_dict(self.critic.state_dict())
179 |
180 | # optimizers
181 | if obs_type == 'pixels':
182 | self.encoder_opt = torch.optim.Adam(self.encoder.parameters(), lr=lr)
183 | else:
184 | self.encoder_opt = None
185 | self.actor_opt = torch.optim.Adam(self.actor.parameters(), lr=lr)
186 | self.critic_opt = torch.optim.Adam(self.critic.parameters(), lr=lr)
187 |
188 | self.train()
189 | self.critic_target.train()
190 |
191 | def train(self, training=True):
192 | self.training = training
193 | self.encoder.train(training)
194 | self.actor.train(training)
195 | self.critic.train(training)
196 |
197 | def init_from(self, other):
198 | # copy parameters over
199 | utils.hard_update_params(other.encoder, self.encoder)
200 | utils.hard_update_params(other.actor, self.actor)
201 | if self.init_critic:
202 | utils.hard_update_params(other.critic.trunk, self.critic.trunk)
203 |
204 | def get_meta_specs(self):
205 | return tuple()
206 |
207 | def init_meta(self):
208 | return OrderedDict()
209 |
210 | def update_meta(self, meta, global_step, time_step, finetune=False):
211 | return meta
212 |
213 | def act(self, obs, meta, step, eval_mode):
214 | obs = torch.as_tensor(obs, device=self.device).unsqueeze(0)
215 | h = self.encoder(obs)
216 | inputs = [h]
217 | for value in meta.values():
218 | value = torch.as_tensor(value, device=self.device).unsqueeze(0)
219 | inputs.append(value)
220 | inpt = torch.cat(inputs, dim=-1)
221 | #assert obs.shape[-1] == self.obs_shape[-1]
222 | stddev = utils.schedule(self.stddev_schedule, step)
223 | dist = self.actor(inpt, stddev)
224 | if eval_mode:
225 | action = dist.mean
226 | else:
227 | action = dist.sample(clip=None)
228 | if step < self.num_expl_steps:
229 | action.uniform_(-1.0, 1.0)
230 | return action.cpu().numpy()[0]
231 |
232 | def update_critic(self, obs, action, reward, discount, next_obs, step):
233 | metrics = dict()
234 |
235 | with torch.no_grad():
236 | stddev = utils.schedule(self.stddev_schedule, step)
237 | dist = self.actor(next_obs, stddev)
238 | next_action = dist.sample(clip=self.stddev_clip)
239 | target_Q1, target_Q2 = self.critic_target(next_obs, next_action)
240 | target_V = torch.min(target_Q1, target_Q2)
241 | target_Q = reward + (discount * target_V)
242 |
243 | Q1, Q2 = self.critic(obs, action)
244 | critic_loss = F.mse_loss(Q1, target_Q) + F.mse_loss(Q2, target_Q)
245 |
246 | if self.use_tb or self.use_wandb:
247 | metrics['critic_target_q'] = target_Q.mean().item()
248 | metrics['critic_q1'] = Q1.mean().item()
249 | metrics['critic_q2'] = Q2.mean().item()
250 | metrics['critic_loss'] = critic_loss.item()
251 |
252 | # optimize critic
253 | if self.encoder_opt is not None:
254 | self.encoder_opt.zero_grad(set_to_none=True)
255 | self.critic_opt.zero_grad(set_to_none=True)
256 | critic_loss.backward()
257 | self.critic_opt.step()
258 | if self.encoder_opt is not None:
259 | self.encoder_opt.step()
260 | return metrics
261 |
262 | def update_actor(self, obs, step):
263 | metrics = dict()
264 |
265 | stddev = utils.schedule(self.stddev_schedule, step)
266 | dist = self.actor(obs, stddev)
267 | action = dist.sample(clip=self.stddev_clip)
268 | log_prob = dist.log_prob(action).sum(-1, keepdim=True)
269 | Q1, Q2 = self.critic(obs, action)
270 | Q = torch.min(Q1, Q2)
271 |
272 | actor_loss = -Q.mean()
273 |
274 | # optimize actor
275 | self.actor_opt.zero_grad(set_to_none=True)
276 | actor_loss.backward()
277 | self.actor_opt.step()
278 |
279 | if self.use_tb or self.use_wandb:
280 | metrics['actor_loss'] = actor_loss.item()
281 | metrics['actor_logprob'] = log_prob.mean().item()
282 | metrics['actor_ent'] = dist.entropy().sum(dim=-1).mean().item()
283 |
284 | return metrics
285 |
286 | def aug_and_encode(self, obs):
287 | obs = self.aug(obs)
288 | return self.encoder(obs)
289 |
290 | def update(self, replay_iter, step):
291 | metrics = dict()
292 | #import ipdb; ipdb.set_trace()
293 |
294 | if step % self.update_every_steps != 0:
295 | return metrics
296 |
297 | batch = next(replay_iter)
298 | obs, action, reward, discount, next_obs = utils.to_torch(batch, self.device)
299 |
300 | # augment and encode
301 | obs = self.aug_and_encode(obs)
302 | with torch.no_grad():
303 | next_obs = self.aug_and_encode(next_obs)
304 |
305 | if self.use_tb or self.use_wandb:
306 | metrics['batch_reward'] = reward.mean().item()
307 |
308 | if not self.update_encoder:
309 | obs = obs.detach()
310 | next_obs = next_obs.detach()
311 |
312 | # update critic
313 | metrics.update(
314 | self.update_critic(obs, action, reward, discount, next_obs, step))
315 |
316 | # update actor
317 | metrics.update(self.update_actor(obs.detach(), step))
318 |
319 | # update critic target
320 | utils.soft_update_params(self.critic, self.critic_target,
321 | self.critic_target_tau)
322 |
323 | return metrics
324 |
--------------------------------------------------------------------------------
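`DDPGAgent.update` delegates the target-network update to `utils.soft_update_params`, whose implementation is not shown in this dump. The sketch below is only the standard Polyak rule such a helper typically applies, target <- tau * online + (1 - tau) * target.

    import torch

    def soft_update(online: torch.nn.Module, target: torch.nn.Module, tau: float) -> None:
        # blend the online parameters into the target network in place
        with torch.no_grad():
            for p, p_target in zip(online.parameters(), target.parameters()):
                p_target.data.mul_(1.0 - tau).add_(tau * p.data)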
/custom_dmc_tasks/walker.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """Planar Walker Domain."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 | import os
23 |
24 | from dm_control import mujoco
25 | from dm_control.rl import control
26 | from dm_control.suite import base
27 | from dm_control.suite import common
28 | from dm_control.suite.utils import randomizers
29 | from dm_control.utils import containers
30 | from dm_control.utils import rewards
31 | from dm_control.utils import io as resources
32 | from dm_control import suite
33 | from dm_env import specs
34 |
35 | import numpy as np
36 |
37 | _DEFAULT_TIME_LIMIT = 25
38 | _CONTROL_TIMESTEP = .025
39 |
40 | # Minimal height of torso over foot above which stand reward is 1.
41 | _STAND_HEIGHT = 1.2
42 |
43 | # Horizontal speeds (meters/second) above which move reward is 1.
44 | _WALK_SPEED = 1
45 | _RUN_SPEED = 8
46 | _SPIN_SPEED = 5
47 |
48 | SUITE = containers.TaggedTasks()
49 |
50 |
51 | def make(task,
52 | task_kwargs=None,
53 | environment_kwargs=None,
54 | visualize_reward=False):
55 | task_kwargs = task_kwargs or {}
56 | if environment_kwargs is not None:
57 | task_kwargs = task_kwargs.copy()
58 | task_kwargs['environment_kwargs'] = environment_kwargs
59 | env = SUITE[task](**task_kwargs)
60 | env.task.visualize_reward = visualize_reward
61 | return env
62 |
63 |
64 | def get_model_and_assets():
65 | """Returns a tuple containing the model XML string and a dict of assets."""
66 | root_dir = os.path.dirname(os.path.dirname(__file__))
67 | xml = resources.GetResource(
68 | os.path.join(root_dir, 'custom_dmc_tasks', 'walker.xml'))
69 | return xml, common.ASSETS
70 |
71 |
72 | @SUITE.add('benchmarking')
73 | def flip(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
 74 |     """Returns the Flip task."""
75 | physics = Physics.from_xml_string(*get_model_and_assets())
76 | task = PlanarWalker(move_speed=_RUN_SPEED, flip=True, random=random)
77 | environment_kwargs = environment_kwargs or {}
78 | return control.Environment(physics,
79 | task,
80 | time_limit=time_limit,
81 | control_timestep=_CONTROL_TIMESTEP,
82 | **environment_kwargs)
83 |
84 |
85 | @SUITE.add('benchmarking')
86 | def multitask(time_limit=_DEFAULT_TIME_LIMIT,
87 | random=None,
88 | environment_kwargs=None):
 89 |     """Returns the multitask Walker task with stand/walk/run/flip rewards."""
90 | physics = Physics.from_xml_string(*get_model_and_assets())
91 | task = MultiTaskPlanarWalker(random=random)
92 | environment_kwargs = environment_kwargs or {}
93 | return control.Environment(physics,
94 | task,
95 | time_limit=time_limit,
96 | control_timestep=_CONTROL_TIMESTEP,
97 | **environment_kwargs)
98 |
99 |
100 | class Physics(mujoco.Physics):
101 | """Physics simulation with additional features for the Walker domain."""
102 | def torso_upright(self):
103 | """Returns projection from z-axes of torso to the z-axes of world."""
104 | return self.named.data.xmat['torso', 'zz']
105 |
106 | def torso_height(self):
107 | """Returns the height of the torso."""
108 | return self.named.data.xpos['torso', 'z']
109 |
110 | def horizontal_velocity(self):
111 | """Returns the horizontal velocity of the center-of-mass."""
112 | return self.named.data.sensordata['torso_subtreelinvel'][0]
113 |
114 | def orientations(self):
115 | """Returns planar orientations of all bodies."""
116 | return self.named.data.xmat[1:, ['xx', 'xz']].ravel()
117 |
118 | def angmomentum(self):
119 |         """Returns the angular momentum of the torso about the Y axis."""
120 | return self.named.data.subtree_angmom['torso'][1]
121 |
122 |
123 | class PlanarWalker(base.Task):
124 | """A planar walker task."""
125 | def __init__(self, move_speed, flip=False, random=None):
126 | """Initializes an instance of `PlanarWalker`.
127 |
128 | Args:
129 |           move_speed: A float. If zero, reward is given simply for standing up;
130 |             otherwise it specifies the target horizontal velocity for walking.
131 |           flip: Boolean. If True, reward spinning (angular momentum) instead.
132 |           random: Optional, either a `numpy.random.RandomState` instance, an
133 |             integer seed for creating a new `RandomState`, or None to select a seed
134 |             automatically (default).
135 |         """
136 | self._move_speed = move_speed
137 | self._flip = flip
138 | super().__init__(random=random)
139 |
140 | def initialize_episode(self, physics):
141 | """Sets the state of the environment at the start of each episode.
142 |
143 | In 'standing' mode, use initial orientation and small velocities.
144 | In 'random' mode, randomize joint angles and let fall to the floor.
145 |
146 | Args:
147 | physics: An instance of `Physics`.
148 |
149 | """
150 | randomizers.randomize_limited_and_rotational_joints(
151 | physics, self.random)
152 | super().initialize_episode(physics)
153 |
154 | def get_observation(self, physics):
155 |         """Returns an observation of body orientations, height and velocities."""
156 | obs = collections.OrderedDict()
157 | obs['orientations'] = physics.orientations()
158 | obs['height'] = physics.torso_height()
159 | obs['velocity'] = physics.velocity()
160 | return obs
161 |
162 | def get_reward(self, physics):
163 | """Returns a reward to the agent."""
164 | standing = rewards.tolerance(physics.torso_height(),
165 | bounds=(_STAND_HEIGHT, float('inf')),
166 | margin=_STAND_HEIGHT / 2)
167 | upright = (1 + physics.torso_upright()) / 2
168 | stand_reward = (3 * standing + upright) / 4
169 |
170 | if self._flip:
171 | move_reward = rewards.tolerance(physics.angmomentum(),
172 | bounds=(_SPIN_SPEED, float('inf')),
173 | margin=_SPIN_SPEED,
174 | value_at_margin=0,
175 | sigmoid='linear')
176 | else:
177 | move_reward = rewards.tolerance(physics.horizontal_velocity(),
178 | bounds=(self._move_speed,
179 | float('inf')),
180 | margin=self._move_speed / 2,
181 | value_at_margin=0.5,
182 | sigmoid='linear')
183 |
184 | return stand_reward * (5 * move_reward + 1) / 6
185 |
186 |
187 | class MultiTaskPlanarWalker(base.Task):
188 | """A planar walker task."""
189 | def __init__(self, random=None):
190 |         """Initializes an instance of `MultiTaskPlanarWalker`.
191 | 
192 |         Unlike `PlanarWalker`, this task takes no `move_speed`; its reward is a
193 |         four-dimensional vector covering the stand, walk, run and flip tasks.
194 | 
195 |         Args:
196 |           random: Optional, either a `numpy.random.RandomState` instance, an
197 |             integer seed for creating a new `RandomState`, or None to select a seed
198 |             automatically (default).
199 |         """
200 | super().__init__(random=random)
201 |
202 | def initialize_episode(self, physics):
203 | """Sets the state of the environment at the start of each episode.
204 |
205 | In 'standing' mode, use initial orientation and small velocities.
206 | In 'random' mode, randomize joint angles and let fall to the floor.
207 |
208 | Args:
209 | physics: An instance of `Physics`.
210 |
211 | """
212 | randomizers.randomize_limited_and_rotational_joints(
213 | physics, self.random)
214 | super().initialize_episode(physics)
215 |
216 | def get_observation(self, physics):
217 |         """Returns an observation of body orientations, height and velocities."""
218 | obs = collections.OrderedDict()
219 | obs['orientations'] = physics.orientations()
220 | obs['height'] = physics.torso_height()
221 | obs['velocity'] = physics.velocity()
222 | return obs
223 |
224 | def get_reward_spec(self):
225 | return specs.Array(shape=(4,), dtype=np.float32, name='reward')
226 |
227 | def get_reward(self, physics):
228 | """Returns a reward to the agent."""
229 |
230 | # compute stand reward
231 | standing = rewards.tolerance(physics.torso_height(),
232 | bounds=(_STAND_HEIGHT, float('inf')),
233 | margin=_STAND_HEIGHT / 2)
234 | upright = (1 + physics.torso_upright()) / 2
235 |
236 | stand_reward = (3 * standing + upright) / 4
237 |
238 | # compute walk reward
239 | walking = rewards.tolerance(physics.horizontal_velocity(),
240 | bounds=(_WALK_SPEED, float('inf')),
241 | margin=_WALK_SPEED / 2,
242 | value_at_margin=0.5,
243 | sigmoid='linear')
244 | walk_reward = stand_reward * (5 * walking + 1) / 6
245 |
246 | # compute run reward
247 | running = rewards.tolerance(physics.horizontal_velocity(),
248 | bounds=(_RUN_SPEED, float('inf')),
249 | margin=_RUN_SPEED / 2,
250 | value_at_margin=0.5,
251 | sigmoid='linear')
252 | run_reward = stand_reward * (5 * running + 1) / 6
253 |
254 | # compute flip reward
255 | flipping = rewards.tolerance(physics.angmomentum(),
256 | bounds=(_SPIN_SPEED, float('inf')),
257 | margin=_SPIN_SPEED,
258 | value_at_margin=0,
259 | sigmoid='linear')
260 |
261 | flip_reward = stand_reward * (5 * flipping + 1) / 6
262 |
263 | reward = np.array([stand_reward, walk_reward, run_reward, flip_reward])
264 | return reward.astype(np.float32)
265 |
--------------------------------------------------------------------------------
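A short usage sketch for the custom Walker tasks registered above, assuming `walker.xml` resolves relative to the repository as in `get_model_and_assets`. The multitask variant reports a 4-dimensional reward ordered as [stand, walk, run, flip].

    from custom_dmc_tasks import walker

    env = walker.make('multitask')
    spec = env.action_spec()
    time_step = env.reset()
    # step once with the midpoint action; the reward is a length-4 float32 array
    time_step = env.step((spec.minimum + spec.maximum) / 2.0)
    print(time_step.reward)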
/agent/pbrl.py:
--------------------------------------------------------------------------------
1 | import hydra
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from collections import OrderedDict
7 | import copy
8 | import utils
9 | from dm_control.utils import rewards
10 |
11 |
12 | class Actor(nn.Module):
13 | def __init__(self, state_dim, action_dim, max_action=1):
14 | super(Actor, self).__init__()
15 |
16 | self.l1 = nn.Linear(state_dim, 256)
17 | self.l2 = nn.Linear(256, 256)
18 | self.l3 = nn.Linear(256, action_dim)
19 |
20 | self.max_action = max_action
21 |
22 | def forward(self, state):
23 | a = F.relu(self.l1(state))
24 | a = F.relu(self.l2(a))
25 | return self.max_action * torch.tanh(self.l3(a))
26 |
27 |
28 | class Critic(nn.Module):
29 | def __init__(self, state_dim, action_dim):
30 | super(Critic, self).__init__()
31 | self.l1 = nn.Linear(state_dim + action_dim, 256)
32 | self.l2 = nn.Linear(256, 256)
33 | self.l3 = nn.Linear(256, 1)
34 |
35 | def forward(self, state, action):
36 | sa = torch.cat([state, action], 1)
37 | q1 = F.relu(self.l1(sa))
38 | q1 = F.relu(self.l2(q1))
39 | q1 = self.l3(q1)
40 | return q1
41 |
42 |
43 | class PBRLAgent:
44 | def __init__(self,
45 | name,
46 | obs_shape,
47 | action_shape,
48 | device,
49 | lr,
50 | hidden_dim,
51 | critic_target_tau,
52 | actor_target_tau,
53 | policy_freq,
54 | policy_noise,
55 | noise_clip,
56 | use_tb,
57 | # alpha,
58 | batch_size,
59 | num_expl_steps,
60 | # PBRL parameters
61 | num_random,
62 | ucb_ratio_in,
63 | ucb_ratio_ood_init,
64 | ucb_ratio_ood_min,
65 | ood_decay_factor,
66 | ensemble,
67 | ood_noise,
68 | share_ratio,
69 | has_next_action=False):
70 | self.policy_noise = policy_noise
71 | self.policy_freq = policy_freq
72 | self.noise_clip = noise_clip
73 | self.num_expl_steps = num_expl_steps
74 | self.action_dim = action_shape[0]
75 | self.hidden_dim = hidden_dim
76 | self.lr = lr
77 | self.device = device
78 | self.critic_target_tau = critic_target_tau
79 | self.actor_target_tau = actor_target_tau
80 | self.use_tb = use_tb
81 | # self.stddev_schedule = stddev_schedule
82 | # self.stddev_clip = stddev_clip
83 | # self.alpha = alpha
84 | self.max_action = 1.0
85 | self.share_ratio = share_ratio
86 | self.share_ratio_now = None
87 |
88 | # PBRL parameters
89 | self.num_random = num_random # for ood action
90 | self.ucb_ratio_in = ucb_ratio_in
91 | self.ensemble = ensemble
92 | self.ood_noise = ood_noise
93 |
94 | # PBRL parameters: control ood ratio
95 | self.ucb_ratio_ood_init = ucb_ratio_ood_init
96 | self.ucb_ratio_ood_min = ucb_ratio_ood_min
97 | self.ood_decay_factor = ood_decay_factor
98 | self.ucb_ratio_ood = ucb_ratio_ood_init
99 | self.ucb_ratio_ood_linear_steps = None
100 |
101 | # models
102 | self.actor = Actor(obs_shape[0], action_shape[0]).to(device)
103 | self.actor_target = copy.deepcopy(self.actor)
104 |
105 | # initialize ensemble of critic
106 | self.critic, self.critic_target = [], []
107 | for _ in range(self.ensemble):
108 | single_critic = Critic(obs_shape[0], action_shape[0]).to(device)
109 | single_critic_target = copy.deepcopy(single_critic)
110 | single_critic_target.load_state_dict(single_critic.state_dict())
111 | self.critic.append(single_critic)
112 | self.critic_target.append(single_critic_target)
113 | print("Actor parameters:", utils.total_parameters(self.actor))
114 | print("Critic parameters: single", utils.total_parameters(self.critic[0]), ", total:", utils.total_parameters(self.critic))
115 |
116 | # optimizers
117 | self.actor_opt = torch.optim.Adam(self.actor.parameters(), lr=lr)
118 | self.critic_opt = []
119 | for i in range(self.ensemble): # each ensemble member has its optimizer
120 | self.critic_opt.append(torch.optim.Adam(self.critic[i].parameters(), lr=lr))
121 |
122 | self.train()
123 | for ct_single in self.critic_target:
124 | ct_single.train()
125 |
126 | def train(self, training=True):
127 | self.training = training
128 | self.actor.train(training)
129 | for c_single in self.critic:
130 | c_single.train(training)
131 |
132 | def act(self, obs, step, eval_mode):
133 | obs = torch.as_tensor(obs, device=self.device).unsqueeze(0)
134 | action = self.actor(obs)
135 | if step < self.num_expl_steps:
136 | action.uniform_(-1.0, 1.0)
137 | return action.cpu().numpy()[0]
138 |
139 | def ucb_func(self, obs, action, mean=False):
140 | action_shape = action.shape[0] # 1024*num_random
141 | obs_shape = obs.shape[0] # 1024
142 | assert int(action_shape / obs_shape) in [1, self.num_random]
143 | if int(action_shape / obs_shape) != 1:
144 | obs = obs.unsqueeze(1).repeat(1, self.num_random, 1).view(obs.shape[0] * self.num_random, obs.shape[1])
145 | # Bootstrapped uncertainty
146 | q_pred = []
147 | for i in range(self.ensemble):
148 | q_pred.append(self.critic[i](obs.cuda(), action.cuda()))
149 | ucb = torch.std(torch.hstack(q_pred), dim=1, keepdim=True) # (1024, ensemble) -> (1024, 1)
150 | assert ucb.size() == (action_shape, 1)
151 | if mean:
152 | q_pred = torch.mean(torch.hstack(q_pred), dim=1, keepdim=True)
153 | return ucb, q_pred
154 |
155 | def ucb_func_target(self, obs_next, act_next):
156 | action_shape = act_next.shape[0] # 2560
157 | obs_shape = obs_next.shape[0] # 256
158 | assert int(action_shape / obs_shape) in [1, self.num_random]
159 | if int(action_shape / obs_shape) != 1:
160 | obs_next = obs_next.unsqueeze(1).repeat(1, self.num_random, 1).view(obs_next.shape[0] * self.num_random, obs_next.shape[1]) # (2560, obs_dim)
161 | # Bootstrapped uncertainty
162 | target_q_pred = []
163 | for i in range(self.ensemble):
164 | target_q_pred.append(self.critic[i](obs_next.cuda(), act_next.cuda()))
165 | ucb_t = torch.std(torch.hstack(target_q_pred), dim=1, keepdim=True)
166 | assert ucb_t.size() == (action_shape, 1)
167 | return ucb_t, target_q_pred
168 |
169 | def update_critic(self, obs, action, reward, discount, next_obs, step, total_step, bool_flag):
170 | self.share_ratio_now = utils.decay_linear(t=step, init=self.share_ratio, minimum=1.0, total_steps=total_step // 2)
171 |
172 | metrics = dict()
173 | with torch.no_grad():
174 | # Select action according to policy and add clipped noise
175 | noise = (torch.randn_like(action) * self.policy_noise).clamp(-self.noise_clip, self.noise_clip)
176 | next_action = (self.actor_target(next_obs) + noise).clamp(-self.max_action, self.max_action)
177 |
178 | # ood sample 1
179 | sampled_current_actions = self.actor(obs).unsqueeze(1).repeat(1, self.num_random, 1).view(
180 | action.shape[0]*self.num_random, action.shape[1])
181 | noise_current = (torch.randn_like(sampled_current_actions) * self.ood_noise).clamp(-self.noise_clip, self.noise_clip)
182 | sampled_current_actions = (sampled_current_actions + noise_current).clamp(-self.max_action, self.max_action)
183 |
184 | # ood sample 2
185 | sampled_next_actions = self.actor(next_obs).unsqueeze(1).repeat(1, self.num_random, 1).view(
186 | action.shape[0]*self.num_random, action.shape[1])
187 | noise_next = (torch.randn_like(sampled_next_actions) * self.ood_noise).clamp(-self.noise_clip, self.noise_clip)
188 | sampled_next_actions = (sampled_next_actions + noise_next).clamp(-self.max_action, self.max_action)
189 |
190 | # random sample
191 | random_actions = torch.FloatTensor(action.shape[0]*self.num_random, action.shape[1]).uniform_(
192 | -self.max_action, self.max_action).to(self.device)
193 |
194 | # TODO: UCB and Q-values
195 |             ucb_current, q_pred = self.ucb_func(obs, action)  # (1024,1). length=ensemble, q_pred[0].shape=(1024,1)
196 |             ucb_next, target_q_pred = self.ucb_func_target(next_obs, next_action)  # (1024,1). length=ensemble, target_q_pred[0].shape=(1024,1)
197 |
198 | ucb_curr_actions_ood, qf_curr_actions_all_ood = self.ucb_func(obs, sampled_current_actions) # (1024*num_random, 1), length=ensemble, (1024*num_random, 1)
199 |             ucb_next_actions_ood, qf_next_actions_all_ood = self.ucb_func(next_obs, sampled_next_actions)  # same as above
200 | # ucb_rand_ood, qf_rand_actions_all_ood = self.ucb_func(obs, random_actions)
201 |
202 | for qf_index in np.arange(self.ensemble):
203 | ucb_ratio_in_flag = bool_flag * self.ucb_ratio_in + (1 - bool_flag) * self.ucb_ratio_in * self.share_ratio_now
204 | ucb_ratio_in_flag = np.expand_dims(ucb_ratio_in_flag, 1)
205 | q_target = reward + discount * (target_q_pred[qf_index] - torch.from_numpy(ucb_ratio_in_flag.astype(np.float32)).cuda() * ucb_next) # (1024, 1), (1024, 1), (1024, 1)
206 | # print("bool flag", bool_flag[:10], bool_flag[-10:])
207 | # print("ucb_ratio_in_flag", q_target.shape, ucb_ratio_in_flag.shape, ucb_next.shape, (torch.from_numpy(ucb_ratio_in_flag.astype(np.float32)).cuda() * ucb_next).shape, ucb_ratio_in_flag[:10])
208 |
209 | # q_target = reward + discount * (target_q_pred[qf_index] - self.ucb_ratio_in * ucb_next) # (1024, 1), (1024, 1), (1024, 1)
210 | q_target = q_target.detach()
211 | qf_loss_in = F.mse_loss(q_pred[qf_index], q_target)
212 |
213 | # TODO: ood loss
214 | cat_qf_ood = torch.cat([qf_curr_actions_all_ood[qf_index],
215 | qf_next_actions_all_ood[qf_index]], 0)
216 | # assert cat_qf_ood.size() == (1024*self.num_random*3, 1)
217 |
218 | ucb_ratio_ood_flag = bool_flag * self.ucb_ratio_ood + (1 - bool_flag) * self.ucb_ratio_ood * self.share_ratio_now
219 | ucb_ratio_ood_flag = np.expand_dims(ucb_ratio_ood_flag, 1).repeat(self.num_random, axis=1).reshape(-1, 1).astype(np.float32)
220 | # print("ucb_ratio_ood_flag 1", ucb_ratio_ood_flag.shape, ucb_curr_actions_ood.shape)
221 |
222 | cat_qf_ood_target = torch.cat([
223 | torch.maximum(qf_curr_actions_all_ood[qf_index] - torch.from_numpy(ucb_ratio_ood_flag).cuda() * ucb_curr_actions_ood, torch.zeros(1).cuda()),
224 | torch.maximum(qf_next_actions_all_ood[qf_index] - torch.from_numpy(ucb_ratio_ood_flag).cuda() * ucb_next_actions_ood, torch.zeros(1).cuda())], 0)
225 | # print("ucb_ratio_ood_flag 2", cat_qf_ood_target.shape, qf_curr_actions_all_ood[qf_index].shape)
226 | cat_qf_ood_target = cat_qf_ood_target.detach()
227 |
228 | # assert cat_qf_ood_target.size() == (1024*self.num_random*3, 1)
229 | qf_loss_ood = F.mse_loss(cat_qf_ood, cat_qf_ood_target)
230 | critic_loss = qf_loss_in + qf_loss_ood
231 |
232 | # Update the Q-functions
233 | self.critic_opt[qf_index].zero_grad()
234 | critic_loss.backward(retain_graph=True)
235 | self.critic_opt[qf_index].step()
236 |
237 | # change the ood ratio
238 | self.ucb_ratio_ood = max(self.ucb_ratio_ood_init * self.ood_decay_factor ** step, self.ucb_ratio_ood_min)
239 |
240 | if self.use_tb:
241 | metrics['critic_target_q'] = q_target.mean().item()
242 | metrics['critic_q1'] = q_pred[0].mean().item()
243 | # metrics['critic_q2'] = q_pred[1].mean().item()
244 | # ucb
245 | metrics['ucb_current'] = ucb_current.mean().item()
246 | metrics['ucb_next'] = ucb_next.mean().item()
247 | metrics['ucb_curr_actions_ood'] = ucb_curr_actions_ood.mean().item()
248 | metrics['ucb_next_actions_ood'] = ucb_next_actions_ood.mean().item()
249 | # loss
250 | metrics['critic_loss_in'] = qf_loss_in.item()
251 | metrics['critic_loss_ood'] = qf_loss_ood.item()
252 | metrics['ucb_ratio_ood'] = self.ucb_ratio_ood
253 | metrics['share_ratio_now'] = self.share_ratio_now
254 | return metrics
255 |
256 | def update_actor(self, obs, action):
257 | metrics = dict()
258 |
259 | # Compute actor loss
260 | pi = self.actor(obs)
261 |
262 | Qvalues = []
263 | for i in range(self.ensemble):
264 | Qvalues.append(self.critic[i](obs, pi)) # (1024, 1)
265 | Qvalues_min = torch.min(torch.hstack(Qvalues), dim=1, keepdim=True).values
266 | assert Qvalues_min.size() == (1024, 1)
267 |
268 | actor_loss = -1. * Qvalues_min.mean()
269 |
270 | # optimize actor
271 | self.actor_opt.zero_grad(set_to_none=True)
272 | actor_loss.backward()
273 | self.actor_opt.step()
274 |
275 | if self.use_tb:
276 | metrics['actor_loss'] = actor_loss.item()
277 |
278 | return metrics
279 |
280 | def update(self, replay_iter, step, total_step):
281 | metrics = dict()
282 |
283 | batch = next(replay_iter)
284 | obs, action, reward, discount, next_obs, bool_flag = utils.to_torch(
285 | batch, self.device)
286 | bool_flag = bool_flag.cpu().detach().numpy()
287 |
288 | if self.use_tb:
289 | metrics['batch_reward'] = reward.mean().item()
290 |
291 | # update critic
292 | metrics.update(
293 | self.update_critic(obs, action, reward, discount, next_obs, step, total_step, bool_flag))
294 |
295 | # update actor
296 | if step % self.policy_freq == 0:
297 | metrics.update(self.update_actor(obs, action))
298 |
299 | # update actor target
300 | utils.soft_update_params(self.actor, self.actor_target, self.actor_target_tau)
301 |
302 | # update critic target
303 | for i in range(self.ensemble):
304 | utils.soft_update_params(self.critic[i], self.critic_target[i], self.critic_target_tau)
305 |
306 | return metrics
307 |
--------------------------------------------------------------------------------
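The pessimism in `PBRLAgent` comes from the ensemble-disagreement bonus computed in `ucb_func` / `ucb_func_target` above: the bonus is the per-sample standard deviation of the ensemble's Q-estimates. A self-contained sketch with random stand-ins for the critic outputs (sizes are illustrative):

    import torch

    batch_size, ensemble = 1024, 5  # illustrative sizes
    q_pred = [torch.randn(batch_size, 1) for _ in range(ensemble)]  # stands in for critic[i](obs, action)

    ucb = torch.std(torch.hstack(q_pred), dim=1, keepdim=True)      # (batch_size, 1) uncertainty bonus
    q_mean = torch.mean(torch.hstack(q_pred), dim=1, keepdim=True)  # (batch_size, 1) mean Q estimate
    assert ucb.shape == (batch_size, 1)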
/agent/cql_cdsz.py:
--------------------------------------------------------------------------------
1 | import hydra
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from collections import OrderedDict
7 |
8 | import utils
9 | from dm_control.utils import rewards
10 |
11 |
12 | class Actor(nn.Module):
13 | def __init__(self, obs_dim, action_dim, hidden_dim, init_w=1e-3):
14 | super().__init__()
15 |
16 | self.policy = nn.Sequential(
17 | nn.Linear(obs_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.Tanh(),
18 | nn.Linear(hidden_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.Tanh(),
19 | nn.Linear(hidden_dim, hidden_dim), nn.LeakyReLU())
20 |
21 | self.fc_mu = nn.Linear(hidden_dim, action_dim)
22 | self.fc_logstd = nn.Linear(hidden_dim, action_dim)
23 |
24 | def forward(self, obs):
25 | f = self.policy(obs)
26 | mu = self.fc_mu(f)
27 | log_std = self.fc_logstd(f)
28 |
29 | mu = torch.clamp(mu, -5, 5)
30 |
31 | std = log_std.clamp(-5, 2).exp()
32 | dist = utils.SquashedNormal2(mu, std)
33 | return dist
34 |
35 |
36 | class Critic(nn.Module):
37 | def __init__(self, obs_dim, action_dim, hidden_dim, init_w=1e-3):
38 | super().__init__()
39 |
40 | self.q1_net = nn.Sequential(
41 | nn.Linear(obs_dim + action_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.Tanh(),
42 | nn.Linear(hidden_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.LeakyReLU(),
43 | nn.Linear(hidden_dim, hidden_dim), nn.LeakyReLU())
44 | self.q1_last = nn.Linear(hidden_dim, 1)
45 |
46 | self.q2_net = nn.Sequential(
47 | nn.Linear(obs_dim + action_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.Tanh(),
48 | nn.Linear(hidden_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.LeakyReLU(),
49 | nn.Linear(hidden_dim, hidden_dim), nn.LeakyReLU())
50 | self.q2_last = nn.Linear(hidden_dim, 1)
51 |
52 | def forward(self, obs, action):
53 | obs_action = torch.cat([obs, action], dim=-1)
54 | q1 = self.q1_net(obs_action)
55 | q1 = self.q1_last(q1)
56 |
57 | q2 = self.q2_net(obs_action)
58 | q2 = self.q2_last(q2)
59 |
60 | return q1, q2
61 |
62 |
63 | class CQLCDSZeroAgent:
64 | def __init__(self,
65 | name,
66 | obs_shape,
67 | action_shape,
68 | device,
69 | actor_lr,
70 | critic_lr,
71 | hidden_dim,
72 | critic_target_tau,
73 | nstep,
74 | batch_size,
75 | use_tb,
76 | alpha,
77 | n_samples,
78 | target_cql_penalty,
79 | use_critic_lagrange,
80 | num_expl_steps,
81 | has_next_action=False):
82 | self.num_expl_steps = num_expl_steps
83 | self.action_dim = action_shape[0]
84 | self.hidden_dim = hidden_dim
85 | self.actor_lr = actor_lr
86 | self.critic_lr = critic_lr
87 | self.device = device
88 | self.critic_target_tau = critic_target_tau
89 | self.use_tb = use_tb
90 | self.use_critic_lagrange = use_critic_lagrange
91 | self.target_cql_penalty = target_cql_penalty
92 |
93 | self.alpha = alpha
94 | self.n_samples = n_samples
95 |
96 | state_dim = obs_shape[0]
97 | action_dim = action_shape[0]
98 |
99 | # models
100 | self.actor = Actor(state_dim, action_dim, hidden_dim).to(device)
101 | self.critic = Critic(state_dim, action_dim, hidden_dim).to(device)
102 | self.critic_target = Critic(state_dim, action_dim, hidden_dim).to(device)
103 | self.critic_target.load_state_dict(self.critic.state_dict())
104 |
105 | # lagrange multipliers
106 | self.target_entropy = -self.action_dim
107 | self.log_actor_alpha = torch.zeros(1, requires_grad=True, device=device)
108 | self.log_critic_alpha = torch.zeros(1, requires_grad=True, device=device)
109 |
110 | # optimizers
111 | self.actor_opt = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)
112 | self.critic_opt = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)
113 | self.actor_alpha_opt = torch.optim.Adam([self.log_actor_alpha], lr=actor_lr)
114 | self.critic_alpha_opt = torch.optim.Adam([self.log_critic_alpha], lr=actor_lr)
115 |
116 | self.train()
117 | # self.critic_target.train()
118 |
119 | def train(self, training=True):
120 | self.training = training
121 | self.actor.train(training)
122 | self.critic.train(training)
123 |
124 | def act(self, obs, step, eval_mode):
125 | obs = torch.as_tensor(obs, device=self.device).unsqueeze(0)
126 | policy = self.actor(obs)
127 | if eval_mode:
128 | action = policy.mean
129 | else:
130 | action = policy.sample()
131 | if step < self.num_expl_steps:
132 | action.uniform_(-1.0, 1.0)
133 | return action.cpu().numpy()[0]
134 |
135 | def _repeated_critic_apply(self, obs, actions):
136 | """
137 | obs is (batch_size, obs_dim)
138 | actions is (n_samples, batch_size, action_dim)
139 |
140 | output tensors are (n_samples, batch_size, 1)
141 | """
142 | batch_size = obs.shape[0]
143 | n_samples = actions.shape[0]
144 |
145 | reshaped_actions = actions.reshape((n_samples * batch_size, -1))
146 | repeated_obs = obs.unsqueeze(0).repeat((n_samples, 1, 1))
147 | repeated_obs = repeated_obs.reshape((n_samples * batch_size, -1))
148 |
149 | Q1, Q2 = self.critic(repeated_obs, reshaped_actions)
150 | Q1 = Q1.reshape((n_samples, batch_size, 1))
151 | Q2 = Q2.reshape((n_samples, batch_size, 1))
152 |
153 | return Q1, Q2
154 |
155 | def update_critic(self, obs, action, reward, discount, next_obs, step):
156 | metrics = dict()
157 |
158 | # Compute standard SAC loss
159 | with torch.no_grad():
160 |             dist = self.actor(next_obs)  # SquashedNormal distribution
161 | sampled_next_action = dist.sample() # (1024, act_dim)
162 | # print("sampled_next_action:", sampled_next_action.shape)
163 | target_Q1, target_Q2 = self.critic_target(next_obs, sampled_next_action) # (1024,1), (1024,1)
164 | target_V = torch.min(target_Q1, target_Q2) # (1024,1)
165 | target_Q = reward + (discount * target_V) # (1024,1)
166 |
167 | Q1, Q2 = self.critic(obs, action)
168 |         critic_loss = F.mse_loss(Q1, target_Q) + F.mse_loss(Q2, target_Q)  # scalar
169 |
170 | # Add CQL penalty
171 | with torch.no_grad():
172 | random_actions = torch.FloatTensor(self.n_samples, Q1.shape[0],
173 | action.shape[-1]).uniform_(-1, 1).to(self.device) # (n_samples, 1024, act_dim)
174 | sampled_actions = self.actor(obs).sample(
175 | sample_shape=(self.n_samples,)) # (n_samples, 1024, act_dim)
176 | # print("sampled_actions:", sampled_actions.shape)
177 | next_sampled_actions = self.actor(next_obs).sample(
178 | sample_shape=(self.n_samples,)) # (n_samples, 1024, act_dim)
179 | # print("next_sampled_actions:", next_sampled_actions.shape)
180 |
181 | rand_Q1, rand_Q2 = self._repeated_critic_apply(obs, random_actions) # (n_samples, 1024, 1)
182 | sampled_Q1, sampled_Q2 = self._repeated_critic_apply( # (n_samples, 1024, 1)
183 | obs, sampled_actions)
184 | next_sampled_Q1, next_sampled_Q2 = self._repeated_critic_apply( # (n_samples, 1024, 1)
185 | obs, next_sampled_actions)
186 |
187 |         # default case 1
188 | cat_Q1 = torch.cat([rand_Q1, sampled_Q1, next_sampled_Q1,
189 | Q1.unsqueeze(0)], dim=0) # (1+3*n_samples, 1024, 1)
190 | cat_Q2 = torch.cat([rand_Q2, sampled_Q2, next_sampled_Q2,
191 | Q2.unsqueeze(0)], dim=0) # (1+3*n_samples, 1024, 1)
192 | assert (not torch.isnan(cat_Q1).any()) and (not torch.isnan(cat_Q2).any())
193 |
194 | cql_logsumexp1 = torch.logsumexp(cat_Q1, dim=0).mean()
195 | cql_logsumexp2 = torch.logsumexp(cat_Q2, dim=0).mean()
196 | cql_logsumexp = cql_logsumexp1 + cql_logsumexp2
197 |         cql_penalty = cql_logsumexp - (Q1 + Q2).mean()  # scalar
198 |
199 | # Update lagrange multiplier
200 | if self.use_critic_lagrange:
201 | alpha = torch.clamp(self.log_critic_alpha.exp(), min=0.0, max=1000000.0)
202 | alpha_loss = -0.5 * alpha * (cql_penalty - self.target_cql_penalty)
203 |
204 | self.critic_alpha_opt.zero_grad()
205 | alpha_loss.backward(retain_graph=True)
206 | self.critic_alpha_opt.step()
207 | alpha = torch.clamp(self.log_critic_alpha.exp(),
208 | min=0.0,
209 | max=1000000.0).detach()
210 | else:
211 | alpha = self.alpha
212 |
213 | # Combine losses
214 | critic_loss = critic_loss + alpha * cql_penalty
215 |
216 | # optimize critic
217 | self.critic_opt.zero_grad(set_to_none=True)
218 | critic_loss.backward()
219 | # torch.nn.utils.clip_grad_norm_(self.critic.parameters(), 50)
220 | critic_grad_norm = utils.grad_norm(self.critic.parameters())
221 | self.critic_opt.step()
222 |
223 | if self.use_tb:
224 | metrics['critic_target_q'] = target_Q.mean().item()
225 | metrics['critic_q1'] = Q1.mean().item()
226 | # metrics['critic_q2'] = Q2.mean().item()
227 | metrics['critic_loss'] = critic_loss.item()
228 | metrics['critic_cql'] = cql_penalty.item()
229 | metrics['critic_cql_logsum1'] = cql_logsumexp1.item()
230 | # metrics['critic_cql_logsum2'] = cql_logsumexp1.item()
231 | metrics['rand_Q1'] = rand_Q1.mean().item()
232 | # metrics['rand_Q2'] = rand_Q2.mean().item()
233 | metrics['sampled_Q1'] = sampled_Q1.mean().item()
234 | # metrics['sampled_Q2'] = sampled_Q2.mean().item()
235 | metrics['critic_grad_norm'] = critic_grad_norm
236 |
237 | return metrics
238 |
239 | def update_actor(self, obs, action, step):
240 | metrics = dict()
241 |
242 | policy = self.actor(obs)
243 | sampled_action = policy.rsample() # (1024, 6)
244 | log_pi = policy.log_prob(sampled_action) # (1024, 6)
245 |
246 | # update lagrange multiplier
247 | alpha_loss = -(self.log_actor_alpha * (log_pi + self.target_entropy).detach()).mean()
248 | self.actor_alpha_opt.zero_grad(set_to_none=True)
249 | alpha_loss.backward()
250 | self.actor_alpha_opt.step()
251 | alpha = self.log_actor_alpha.exp().detach()
252 |
253 | # optimize actor
254 | Q1, Q2 = self.critic(obs, sampled_action) # (1024, 1)
255 | Q = torch.min(Q1, Q2) # (1024, 1)
256 |         actor_loss = (alpha * log_pi - Q).mean()
257 | self.actor_opt.zero_grad(set_to_none=True)
258 | actor_loss.backward()
259 | actor_grad_norm = utils.grad_norm(self.actor.parameters())
260 | self.actor_opt.step()
261 |
262 | if self.use_tb:
263 | metrics['actor_loss'] = actor_loss.item()
264 | metrics['actor_ent'] = -log_pi.mean().item()
265 | metrics['actor_alpha'] = alpha.item()
266 | metrics['actor_alpha_loss'] = alpha_loss.item()
267 | metrics['actor_mean'] = policy.loc.mean().item()
268 | metrics['actor_std'] = policy.scale.mean().item()
269 | metrics['actor_action'] = sampled_action.mean().item()
270 | metrics['actor_atanh_action'] = utils.atanh(sampled_action).mean().item()
271 | metrics['actor_grad_norm'] = actor_grad_norm
272 |
273 | return metrics
274 |
275 | # conservative data sharing (CDS)
276 | def conservative_data_share(self, batch_main, batch_share):
277 | obs_all, action_all, next_obs_all, reward_all, discount_all, obs_k_all = [], [], [], [], [], []
278 |
279 | # 1. sample examples from the main task
280 | # shape = (512, 24), (512, 6), (512, 1), (512, 1), (512, 24)
281 | obs_m, action_m, reward_m, discount_m, next_obs_m, _ = utils.to_torch(batch_main, self.device)
282 | obs_all.append(obs_m)
283 | action_all.append(action_m)
284 | next_obs_all.append(next_obs_m)
285 | reward_all.append(reward_m)
286 | discount_all.append(discount_m)
287 |
288 | # 2. sample examples from other tasks
289 | # sample 10 times samples, and select the top-0.1 samples. (batch_size_split*10, xx, xx)
290 | # shape = (5120, 24) (5120, 6) (5120, 1) (5120, 1) (5120, 24)
291 | obs, action, reward, discount, next_obs, _ = utils.to_torch(batch_share, self.device)
292 | # print("obs:", obs.shape, action.shape, reward.shape, discount.shape, next_obs.shape)
293 | # calculate the conservative Q value
294 | with torch.no_grad():
295 | conservative_q_value, _ = self.critic(obs, action) # (5120, 1)
296 | conservative_q_value = conservative_q_value.squeeze().detach().cpu().numpy()
297 | # choose the top 0.1 top_index.shape=(512,)
298 | top_index = np.argpartition(conservative_q_value, -obs_m.shape[0])[-obs_m.shape[0]:] # find the top 0.1 index. (batch_size_split,)
299 | # extract the samples
300 | obs_all.append(obs[top_index])
301 | action_all.append(action[top_index])
302 | next_obs_all.append(next_obs[top_index])
303 | discount_all.append(discount[top_index])
304 |
305 | # zero the reward for the shared samples (the only difference from cql_cds)
306 | reward_all.append(torch.zeros(obs_m.shape[0], 1, device=self.device))
307 |
308 | return torch.cat(obs_all, dim=0), torch.cat(action_all, dim=0), torch.cat(reward_all, dim=0), \
309 | torch.cat(discount_all, dim=0), torch.cat(next_obs_all, dim=0)
310 |
311 | def update(self, replay_iter_main, replay_iter_share, step, total_step):
312 | metrics = dict()
313 |
314 | batch_main = next(replay_iter_main)
315 | batch_share = next(replay_iter_share)
316 |
317 | # print("conservative data sharing...") # obs.shape=(1024, 24), action.shape=(1024, 6) reward.shape=(1024, 1)
318 | obs, action, reward, discount, next_obs = self.conservative_data_share(batch_main, batch_share)
319 |
320 | if self.use_tb:
321 | metrics['batch_reward'] = reward.mean().item()
322 |
323 | # update critic
324 | metrics.update(
325 | self.update_critic(obs, action, reward, discount, next_obs, step))
326 |
327 | # update actor
328 | metrics.update(self.update_actor(obs, action, step))
329 |
330 | # update critic target
331 | utils.soft_update_params(self.critic, self.critic_target, self.critic_target_tau)
332 |
333 | return metrics
334 |
--------------------------------------------------------------------------------
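The file above selects shared-task transitions by their conservative Q-value and relabels their reward to zero. Below is a minimal standalone sketch of that selection step, assuming a critic with the same (Q1, Q2) interface; the helper name select_top_shared and its arguments are illustrative and not part of the repository.

import torch

def select_top_shared(critic, shared_batch, k, device):
    # shared_batch: (obs, action, reward, discount, next_obs) tensors from other tasks
    obs, action, reward, discount, next_obs = shared_batch
    with torch.no_grad():
        q1, _ = critic(obs, action)                        # (N, 1) conservative Q estimates
    top_index = torch.topk(q1.squeeze(-1), k).indices      # k transitions with the largest Q
    zero_reward = torch.zeros(k, 1, device=device)         # relabelled (zeroed) reward
    return (obs[top_index], action[top_index], zero_reward,
            discount[top_index], next_obs[top_index])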
/agent/cql_cds.py:
--------------------------------------------------------------------------------
1 | import hydra
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from collections import OrderedDict
7 |
8 | import utils
9 | from dm_control.utils import rewards
10 |
11 |
12 | class Actor(nn.Module):
13 | def __init__(self, obs_dim, action_dim, hidden_dim, init_w=1e-3):
14 | super().__init__()
15 |
16 | self.policy = nn.Sequential(
17 | nn.Linear(obs_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.Tanh(),
18 | nn.Linear(hidden_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.Tanh(), # newly added layer
19 | nn.Linear(hidden_dim, hidden_dim), nn.LeakyReLU())
20 |
21 | self.fc_mu = nn.Linear(hidden_dim, action_dim)
22 | self.fc_logstd = nn.Linear(hidden_dim, action_dim)
23 |
24 | def forward(self, obs):
25 | f = self.policy(obs)
26 | mu = self.fc_mu(f)
27 | log_std = self.fc_logstd(f)
28 |
29 | mu = torch.clamp(mu, -5, 5)
30 |
31 | std = log_std.clamp(-5, 2).exp()
32 | dist = utils.SquashedNormal2(mu, std)
33 | return dist
34 |
35 |
36 | class Critic(nn.Module):
37 | def __init__(self, obs_dim, action_dim, hidden_dim, init_w=1e-3):
38 | super().__init__()
39 |
40 | self.q1_net = nn.Sequential(
41 | nn.Linear(obs_dim + action_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.Tanh(),
42 | nn.Linear(hidden_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.LeakyReLU(),
43 | nn.Linear(hidden_dim, hidden_dim), nn.LeakyReLU())
44 | self.q1_last = nn.Linear(hidden_dim, 1)
45 |
46 | self.q2_net = nn.Sequential(
47 | nn.Linear(obs_dim + action_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.Tanh(),
48 | nn.Linear(hidden_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.LeakyReLU(),
49 | nn.Linear(hidden_dim, hidden_dim), nn.LeakyReLU())
50 | self.q2_last = nn.Linear(hidden_dim, 1)
51 |
52 | def forward(self, obs, action):
53 | obs_action = torch.cat([obs, action], dim=-1)
54 | q1 = self.q1_net(obs_action)
55 | q1 = self.q1_last(q1)
56 |
57 | q2 = self.q2_net(obs_action)
58 | q2 = self.q2_last(q2)
59 |
60 | return q1, q2
61 |
62 |
63 | class CQLCDSAgent:
64 | def __init__(self,
65 | name,
66 | obs_shape,
67 | action_shape,
68 | device,
69 | actor_lr,
70 | critic_lr,
71 | hidden_dim,
72 | critic_target_tau,
73 | nstep,
74 | batch_size,
75 | use_tb,
76 | alpha,
77 | n_samples,
78 | target_cql_penalty,
79 | use_critic_lagrange,
80 | num_expl_steps,
81 | has_next_action=False):
82 | self.num_expl_steps = num_expl_steps
83 | self.action_dim = action_shape[0]
84 | self.hidden_dim = hidden_dim
85 | self.actor_lr = actor_lr
86 | self.critic_lr = critic_lr
87 | self.device = device
88 | self.critic_target_tau = critic_target_tau
89 | self.use_tb = use_tb
90 | self.use_critic_lagrange = use_critic_lagrange
91 | self.target_cql_penalty = target_cql_penalty
92 |
93 | self.alpha = alpha
94 | self.n_samples = n_samples
95 |
96 | state_dim = obs_shape[0]
97 | action_dim = action_shape[0]
98 |
99 | # models
100 | self.actor = Actor(state_dim, action_dim, hidden_dim).to(device)
101 | self.critic = Critic(state_dim, action_dim, hidden_dim).to(device)
102 | self.critic_target = Critic(state_dim, action_dim, hidden_dim).to(device)
103 | self.critic_target.load_state_dict(self.critic.state_dict())
104 |
105 | # lagrange multipliers
106 | self.target_entropy = -self.action_dim
107 | self.log_actor_alpha = torch.zeros(1, requires_grad=True, device=device)
108 | self.log_critic_alpha = torch.zeros(1, requires_grad=True, device=device)
109 |
110 | # optimizers
111 | self.actor_opt = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)
112 | self.critic_opt = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)
113 | self.actor_alpha_opt = torch.optim.Adam([self.log_actor_alpha], lr=actor_lr)
114 | self.critic_alpha_opt = torch.optim.Adam([self.log_critic_alpha], lr=actor_lr)
115 |
116 | self.train()
117 | # self.critic_target.train()
118 |
119 | def train(self, training=True):
120 | self.training = training
121 | self.actor.train(training)
122 | self.critic.train(training)
123 |
124 | def act(self, obs, step, eval_mode):
125 | obs = torch.as_tensor(obs, device=self.device).unsqueeze(0)
126 | policy = self.actor(obs)
127 | if eval_mode:
128 | action = policy.mean
129 | else:
130 | action = policy.sample()
131 | if step < self.num_expl_steps:
132 | action.uniform_(-1.0, 1.0)
133 | return action.cpu().numpy()[0]
134 |
135 | def _repeated_critic_apply(self, obs, actions):
136 | """
137 | obs is (batch_size, obs_dim)
138 | actions is (n_samples, batch_size, action_dim)
139 |
140 | output tensors are (n_samples, batch_size, 1)
141 | """
142 | batch_size = obs.shape[0]
143 | n_samples = actions.shape[0]
144 |
145 | reshaped_actions = actions.reshape((n_samples * batch_size, -1))
146 | repeated_obs = obs.unsqueeze(0).repeat((n_samples, 1, 1))
147 | repeated_obs = repeated_obs.reshape((n_samples * batch_size, -1))
148 |
149 | Q1, Q2 = self.critic(repeated_obs, reshaped_actions)
150 | Q1 = Q1.reshape((n_samples, batch_size, 1))
151 | Q2 = Q2.reshape((n_samples, batch_size, 1))
152 |
153 | return Q1, Q2
154 |
155 | def update_critic(self, obs, action, reward, discount, next_obs, step):
156 | metrics = dict()
157 |
158 | # Compute standard SAC loss
159 | with torch.no_grad():
160 | dist = self.actor(next_obs) # SquashedNormal distribution
161 | sampled_next_action = dist.sample() # (1024, act_dim)
162 | # print("sampled_next_action:", sampled_next_action.shape)
163 | target_Q1, target_Q2 = self.critic_target(next_obs, sampled_next_action) # (1024,1), (1024,1)
164 | target_V = torch.min(target_Q1, target_Q2) # (1024,1)
165 | target_Q = reward + (discount * target_V) # (1024,1)
166 |
167 | Q1, Q2 = self.critic(obs, action)
168 | critic_loss = F.mse_loss(Q1, target_Q) + F.mse_loss(Q2, target_Q) # scalar
169 |
170 | # Add CQL penalty
171 | with torch.no_grad():
172 | random_actions = torch.FloatTensor(self.n_samples, Q1.shape[0],
173 | action.shape[-1]).uniform_(-1, 1).to(self.device) # (n_samples, 1024, act_dim)
174 | sampled_actions = self.actor(obs).sample(
175 | sample_shape=(self.n_samples,)) # (n_samples, 1024, act_dim)
176 | # print("sampled_actions:", sampled_actions.shape)
177 | next_sampled_actions = self.actor(next_obs).sample(
178 | sample_shape=(self.n_samples,)) # (n_samples, 1024, act_dim)
179 | # print("next_sampled_actions:", next_sampled_actions.shape)
180 |
181 | rand_Q1, rand_Q2 = self._repeated_critic_apply(obs, random_actions) # (n_samples, 1024, 1)
182 | sampled_Q1, sampled_Q2 = self._repeated_critic_apply( # (n_samples, 1024, 1)
183 | obs, sampled_actions)
184 | next_sampled_Q1, next_sampled_Q2 = self._repeated_critic_apply( # (n_samples, 1024, 1)
185 | obs, next_sampled_actions)
186 |
187 | # situation 1
188 | cat_Q1 = torch.cat([rand_Q1, sampled_Q1, next_sampled_Q1,
189 | Q1.unsqueeze(0)], dim=0) # (1+3*n_samples, 1024, 1)
190 | cat_Q2 = torch.cat([rand_Q2, sampled_Q2, next_sampled_Q2,
191 | Q2.unsqueeze(0)], dim=0) # (1+3*n_samples, 1024, 1)
192 |
193 | assert (not torch.isnan(cat_Q1).any()) and (not torch.isnan(cat_Q2).any())
194 |
195 | cql_logsumexp1 = torch.logsumexp(cat_Q1, dim=0).mean()
196 | cql_logsumexp2 = torch.logsumexp(cat_Q2, dim=0).mean()
197 | cql_logsumexp = cql_logsumexp1 + cql_logsumexp2
198 |
199 | cql_penalty = cql_logsumexp - (Q1 + Q2).mean() # scalar
200 |
201 | # Update lagrange multiplier
202 | if self.use_critic_lagrange:
203 | alpha = torch.clamp(self.log_critic_alpha.exp(), min=0.0, max=1000000.0)
204 | alpha_loss = -0.5 * alpha * (cql_penalty - self.target_cql_penalty)
205 |
206 | self.critic_alpha_opt.zero_grad()
207 | alpha_loss.backward(retain_graph=True)
208 | self.critic_alpha_opt.step()
209 | alpha = torch.clamp(self.log_critic_alpha.exp(),
210 | min=0.0,
211 | max=1000000.0).detach()
212 | else:
213 | alpha = self.alpha
214 |
215 | # Combine losses
216 | critic_loss = critic_loss + alpha * cql_penalty
217 |
218 | # optimize critic
219 | self.critic_opt.zero_grad(set_to_none=True)
220 | critic_loss.backward()
221 | critic_grad_norm = utils.grad_norm(self.critic.parameters())
222 | self.critic_opt.step()
223 |
224 | if self.use_tb:
225 | metrics['critic_target_q'] = target_Q.mean().item()
226 | metrics['critic_q1'] = Q1.mean().item()
227 | # metrics['critic_q2'] = Q2.mean().item()
228 | metrics['critic_loss'] = critic_loss.item()
229 | metrics['critic_cql'] = cql_penalty.item()
230 | metrics['critic_cql_logsum1'] = cql_logsumexp1.item()
231 | # metrics['critic_cql_logsum2'] = cql_logsumexp2.item()
232 | metrics['rand_Q1'] = rand_Q1.mean().item()
233 | # metrics['rand_Q2'] = rand_Q2.mean().item()
234 | metrics['sampled_Q1'] = sampled_Q1.mean().item()
235 | # metrics['sampled_Q2'] = sampled_Q2.mean().item()
236 | metrics['critic_grad_norm'] = critic_grad_norm
237 |
238 | return metrics
239 |
240 | def update_actor(self, obs, action, step):
241 | metrics = dict()
242 |
243 | policy = self.actor(obs)
244 | sampled_action = policy.rsample() # (1024, 6)
245 | # print("sampled_action:", sampled_action.shape)
246 | log_pi = policy.log_prob(sampled_action) # (1024, 6)
247 |
248 | # update lagrange multiplier
249 | alpha_loss = -(self.log_actor_alpha * (log_pi + self.target_entropy).detach()).mean()
250 | self.actor_alpha_opt.zero_grad(set_to_none=True)
251 | alpha_loss.backward()
252 | self.actor_alpha_opt.step()
253 | alpha = self.log_actor_alpha.exp().detach()
254 |
255 | # optimize actor
256 | Q1, Q2 = self.critic(obs, sampled_action) # (1024, 1)
257 | Q = torch.min(Q1, Q2) # (1024, 1)
258 | actor_loss = (alpha * log_pi - Q).mean() # scalar
259 | self.actor_opt.zero_grad(set_to_none=True)
260 | actor_loss.backward()
261 | # torch.nn.utils.clip_grad_norm_(self.actor.parameters(), 5)
262 | actor_grad_norm = utils.grad_norm(self.actor.parameters())
263 | self.actor_opt.step()
264 |
265 | if self.use_tb:
266 | metrics['actor_loss'] = actor_loss.item()
267 | metrics['actor_ent'] = -log_pi.mean().item()
268 | metrics['actor_alpha'] = alpha.item()
269 | metrics['actor_alpha_loss'] = alpha_loss.item()
270 | metrics['actor_mean'] = policy.loc.mean().item()
271 | metrics['actor_std'] = policy.scale.mean().item()
272 | metrics['actor_action'] = sampled_action.mean().item()
273 | metrics['actor_atanh_action'] = utils.atanh(sampled_action).mean().item()
274 | metrics['actor_grad_norm'] = actor_grad_norm
275 |
276 | return metrics
277 |
278 | # conservative data sharing (CDS)
279 | def conservative_data_share(self, batch_main, batch_share):
280 | obs_all, action_all, next_obs_all, reward_all, discount_all, obs_k_all = [], [], [], [], [], []
281 |
282 | # 1. sample examples from the main task
283 | # shape = (512, 24), (512, 6), (512, 1), (512, 1), (512, 24)
284 | obs_m, action_m, reward_m, discount_m, next_obs_m, _ = utils.to_torch(batch_main, self.device)
285 | # print("main samples:", obs_m.shape, action_m.shape, reward_m.shape, discount_m.shape, next_obs_m.shape, eps_flag_m.shape)
286 | obs_all.append(obs_m)
287 | action_all.append(action_m)
288 | next_obs_all.append(next_obs_m)
289 | reward_all.append(reward_m)
290 | discount_all.append(discount_m)
291 |
292 | # 2. sample examples from other tasks
293 | # sample 10 times samples, and select the top-0.1 samples. (batch_size_split*10, xx, xx)
294 | # shape = (5120, 24) (5120, 6) (5120, 1) (5120, 1) (5120, 24)
295 | obs, action, reward, discount, next_obs, _ = utils.to_torch(batch_share, self.device)
296 | # print("obs:", obs.shape, action.shape, reward.shape, discount.shape, next_obs.shape)
297 | # calculate the conservative Q value
298 | with torch.no_grad():
299 | conservative_q_value, _ = self.critic(obs, action) # (5120, 1)
300 | conservative_q_value = conservative_q_value.squeeze().detach().cpu().numpy()
301 | # choose the top 0.1 top_index.shape=(512,)
302 | top_index = np.argpartition(conservative_q_value, -obs_m.shape[0])[-obs_m.shape[0]:] # find the top 0.1 index. (batch_size_split,)
303 | # extract the samples
304 | obs_all.append(obs[top_index])
305 | action_all.append(action[top_index])
306 | next_obs_all.append(next_obs[top_index])
307 | reward_all.append(reward[top_index])
308 | discount_all.append(discount[top_index])
309 |
310 | return torch.cat(obs_all, dim=0), torch.cat(action_all, dim=0), torch.cat(reward_all, dim=0), \
311 | torch.cat(discount_all, dim=0), torch.cat(next_obs_all, dim=0)
312 |
313 | def update(self, replay_iter_main, replay_iter_share, step, total_step):
314 | metrics = dict()
315 |
316 | batch_main = next(replay_iter_main)
317 | batch_share = next(replay_iter_share)
318 |
319 | # print("conservative data sharing...") # obs.shape=(1024, 24), action.shape=(1024, 6) reward.shape=(1024, 1)
320 | obs, action, reward, discount, next_obs = self.conservative_data_share(batch_main, batch_share)
321 |
322 | if self.use_tb:
323 | metrics['batch_reward'] = reward.mean().item()
324 |
325 | # update critic
326 | metrics.update(
327 | self.update_critic(obs, action, reward, discount, next_obs, step))
328 |
329 | # update actor
330 | metrics.update(self.update_actor(obs, action, step))
331 |
332 | # update critic target
333 | utils.soft_update_params(self.critic, self.critic_target, self.critic_target_tau)
334 |
335 | return metrics
336 |
--------------------------------------------------------------------------------
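For reference, the conservative term assembled in update_critic above reduces, per critic network, to a logsumexp over the Q-values of random, current-policy, and next-state-policy actions minus the mean dataset Q-value (the two critic terms are then summed). A condensed sketch under the same tensor shapes follows; the helper name is illustrative and not part of the repository.

import torch

def cql_penalty_single(rand_Q, sampled_Q, next_sampled_Q, data_Q):
    # rand_Q, sampled_Q, next_sampled_Q: (n_samples, B, 1); data_Q: (B, 1)
    cat_Q = torch.cat([rand_Q, sampled_Q, next_sampled_Q, data_Q.unsqueeze(0)], dim=0)
    return torch.logsumexp(cat_Q, dim=0).mean() - data_Q.mean()

# e.g. penalty = cql_penalty_single(torch.randn(3, 4, 1), torch.randn(3, 4, 1),
#                                   torch.randn(3, 4, 1), torch.randn(4, 1))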
/agent/cql.py:
--------------------------------------------------------------------------------
1 | import hydra
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from collections import OrderedDict
7 |
8 | import utils
9 | from dm_control.utils import rewards
10 |
11 |
12 | class Actor(nn.Module):
13 | def __init__(self, obs_dim, action_dim, hidden_dim, init_w=1e-3):
14 | super().__init__()
15 | self.policy = nn.Sequential(
16 | nn.Linear(obs_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.Tanh(),
17 | nn.Linear(hidden_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.Tanh(), # newly added layer
18 | nn.Linear(hidden_dim, hidden_dim), nn.LeakyReLU())
19 |
20 | self.fc_mu = nn.Linear(hidden_dim, action_dim)
21 | self.fc_logstd = nn.Linear(hidden_dim, action_dim)
22 |
23 | def forward(self, obs):
24 | f = self.policy(obs)
25 | mu = self.fc_mu(f)
26 | log_std = self.fc_logstd(f)
27 | mu = torch.clamp(mu, -5, 5)
28 |
29 | std = log_std.clamp(-5, 2).exp()
30 | dist = utils.SquashedNormal2(mu, std)
31 | return dist
32 |
33 |
34 | class Critic(nn.Module):
35 | def __init__(self, obs_dim, action_dim, hidden_dim, init_w=1e-3):
36 | super().__init__()
37 |
38 | self.q1_net = nn.Sequential(
39 | nn.Linear(obs_dim + action_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.Tanh(),
40 | nn.Linear(hidden_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.LeakyReLU(),
41 | nn.Linear(hidden_dim, hidden_dim), nn.LeakyReLU())
42 | self.q1_last = nn.Linear(hidden_dim, 1)
43 |
44 | self.q2_net = nn.Sequential(
45 | nn.Linear(obs_dim + action_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.Tanh(),
46 | nn.Linear(hidden_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.LeakyReLU(),
47 | nn.Linear(hidden_dim, hidden_dim), nn.LeakyReLU())
48 | self.q2_last = nn.Linear(hidden_dim, 1)
49 |
50 | def forward(self, obs, action):
51 | obs_action = torch.cat([obs, action], dim=-1)
52 | q1 = self.q1_net(obs_action)
53 | q1 = self.q1_last(q1)
54 |
55 | q2 = self.q2_net(obs_action)
56 | q2 = self.q2_last(q2)
57 |
58 | return q1, q2
59 |
60 |
61 | class CQLAgent:
62 | def __init__(self,
63 | name,
64 | obs_shape,
65 | action_shape,
66 | device,
67 | actor_lr,
68 | critic_lr,
69 | hidden_dim,
70 | critic_target_tau,
71 | nstep,
72 | batch_size,
73 | use_tb,
74 | alpha,
75 | n_samples,
76 | target_cql_penalty,
77 | use_critic_lagrange,
78 | num_expl_steps,
79 | has_next_action=False):
80 | self.num_expl_steps = num_expl_steps
81 | self.action_dim = action_shape[0]
82 | self.hidden_dim = hidden_dim
83 | self.actor_lr = actor_lr
84 | self.critic_lr = critic_lr
85 | self.device = device
86 | self.critic_target_tau = critic_target_tau
87 | self.use_tb = use_tb
88 | self.use_critic_lagrange = use_critic_lagrange
89 | self.target_cql_penalty = target_cql_penalty
90 |
91 | self.alpha = alpha
92 | self.n_samples = n_samples
93 |
94 | state_dim = obs_shape[0]
95 | action_dim = action_shape[0]
96 |
97 | # models
98 | self.actor = Actor(state_dim, action_dim, hidden_dim).to(device)
99 | self.critic = Critic(state_dim, action_dim, hidden_dim).to(device)
100 | self.critic_target = Critic(state_dim, action_dim, hidden_dim).to(device)
101 | self.critic_target.load_state_dict(self.critic.state_dict())
102 |
103 | # lagrange multipliers
104 | self.target_entropy = -self.action_dim
105 | self.log_actor_alpha = torch.zeros(1, requires_grad=True, device=device)
106 | self.log_critic_alpha = torch.zeros(1, requires_grad=True, device=device)
107 |
108 | # optimizers
109 | self.actor_opt = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)
110 | self.critic_opt = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)
111 | self.actor_alpha_opt = torch.optim.Adam([self.log_actor_alpha], lr=actor_lr)
112 | self.critic_alpha_opt = torch.optim.Adam([self.log_critic_alpha], lr=actor_lr)
113 |
114 | self.train()
115 |
116 | def train(self, training=True):
117 | self.training = training
118 | self.actor.train(training)
119 | self.critic.train(training)
120 |
121 | def act(self, obs, step, eval_mode):
122 | obs = torch.as_tensor(obs, device=self.device).unsqueeze(0)
123 | policy = self.actor(obs)
124 | if eval_mode:
125 | action = policy.mean
126 | else:
127 | action = policy.sample()
128 | if step < self.num_expl_steps:
129 | action.uniform_(-1.0, 1.0)
130 | return action.cpu().numpy()[0]
131 |
132 | def _repeated_critic_apply(self, obs, actions):
133 | """
134 | obs is (batch_size, obs_dim)
135 | actions is (n_samples, batch_size, action_dim)
136 |
137 | output tensors are (n_samples, batch_size, 1)
138 | """
139 | batch_size = obs.shape[0]
140 | n_samples = actions.shape[0]
141 |
142 | reshaped_actions = actions.reshape((n_samples * batch_size, -1))
143 | repeated_obs = obs.unsqueeze(0).repeat((n_samples, 1, 1))
144 | repeated_obs = repeated_obs.reshape((n_samples * batch_size, -1))
145 |
146 | Q1, Q2 = self.critic(repeated_obs, reshaped_actions)
147 | Q1 = Q1.reshape((n_samples, batch_size, 1))
148 | Q2 = Q2.reshape((n_samples, batch_size, 1))
149 |
150 | return Q1, Q2
151 |
152 | def update_critic(self, obs, action, reward, discount, next_obs, step):
153 | metrics = dict()
154 |
155 | # Compute standard SAC loss
156 | with torch.no_grad():
157 | dist = self.actor(next_obs) # SquashedNormal distribution
158 | sampled_next_action = dist.sample() # (1024, act_dim)
159 | # print("sampled_next_action:", sampled_next_action.shape)
160 | target_Q1, target_Q2 = self.critic_target(next_obs, sampled_next_action) # (1024,1), (1024,1)
161 | target_V = torch.min(target_Q1, target_Q2) # (1024,1)
162 | target_Q = reward + (discount * target_V) # (1024,1)
163 |
164 | Q1, Q2 = self.critic(obs, action)
165 | critic_loss = F.mse_loss(Q1, target_Q) + F.mse_loss(Q2, target_Q) # scalar
166 |
167 | # Add CQL penalty
168 | with torch.no_grad():
169 | random_actions = torch.FloatTensor(self.n_samples, Q1.shape[0],
170 | action.shape[-1]).uniform_(-1, 1).to(self.device) # (n_samples, 1024, act_dim)
171 | sampled_actions = self.actor(obs).sample(
172 | sample_shape=(self.n_samples,)) # (n_samples, 1024, act_dim)
173 | # print("sampled_actions:", sampled_actions.shape)
174 | next_sampled_actions = self.actor(next_obs).sample(
175 | sample_shape=(self.n_samples,)) # (n_samples, 1024, act_dim)
176 | # print("next_sampled_actions:", next_sampled_actions.shape)
177 |
178 | rand_Q1, rand_Q2 = self._repeated_critic_apply(obs, random_actions) # (n_samples, 1024, 1)
179 | sampled_Q1, sampled_Q2 = self._repeated_critic_apply( # (n_samples, 1024, 1)
180 | obs, sampled_actions)
181 | next_sampled_Q1, next_sampled_Q2 = self._repeated_critic_apply( # (n_samples, 1024, 1)
182 | obs, next_sampled_actions)
183 |
184 | # situation 1
185 | cat_Q1 = torch.cat([rand_Q1, sampled_Q1, next_sampled_Q1,
186 | Q1.unsqueeze(0)], dim=0) # (1+3*n_samples, 1024, 1)
187 | cat_Q2 = torch.cat([rand_Q2, sampled_Q2, next_sampled_Q2,
188 | Q2.unsqueeze(0)], dim=0) # (1+3*n_samples, 1024, 1)
189 |
190 |
191 | # situation 2
192 | # cat_Q1 = torch.cat([rand_Q1 - np.log(0.5 ** self.action_dim),
193 | # sampled_Q1 - self.actor(obs).log_prob(sampled_actions).unsqueeze(-1).detach(),
194 | # next_sampled_Q1 - self.actor(obs).log_prob(next_sampled_actions).unsqueeze(-1).detach()], dim=0) # (1+3*n_samples, 1024, 1)
195 | # cat_Q2 = torch.cat([rand_Q2 - np.log(0.5 ** self.action_dim),
196 | # sampled_Q2 - self.actor(obs).log_prob(sampled_actions).unsqueeze(-1).detach(),
197 | # next_sampled_Q2 - self.actor(obs).log_prob(next_sampled_actions).unsqueeze(-1).detach()], dim=0) # (1+3*n_samples, 1024, 1)
198 |
199 | # cat_Q1 = torch.cat([rand_Q1 - np.log(0.5 ** self.action_dim),
200 | # sampled_Q1 - self.actor(obs).log_prob(sampled_actions).sum(-1, keepdims=True).detach(),
201 | # next_sampled_Q1 - self.actor(obs).log_prob(next_sampled_actions).sum(-1, keepdims=True).detach()], dim=0) # (1+3*n_samples, 1024, 1)
202 | # cat_Q2 = torch.cat([rand_Q2 - np.log(0.5 ** self.action_dim),
203 | # sampled_Q2 - self.actor(obs).log_prob(sampled_actions).sum(-1, keepdims=True).detach(),
204 | # next_sampled_Q2 - self.actor(obs).log_prob(next_sampled_actions).sum(-1, keepdims=True).detach()], dim=0) # (1+3*n_samples, 1024, 1)
205 |
206 | # print("cat Q1:", cat_Q1.shape, cat_Q2.shape)
207 | # if torch.isnan(cat_Q1).any():
208 | # print("in cql 1:", torch.isnan(rand_Q1).any(), torch.isnan(sampled_Q1).any(), torch.isnan(rand_Q1).any(), torch.isnan(next_sampled_Q1).any())
209 | # print("in cql 2:", torch.isnan(self.actor(obs).log_prob(sampled_actions)).any())
210 | # print("in cql 3:", torch.isnan(self.actor(obs).log_prob(next_sampled_actions)).any())
211 |
212 | assert (not torch.isnan(cat_Q1).any()) and (not torch.isnan(cat_Q2).any())
213 |
214 | cql_logsumexp1 = torch.logsumexp(cat_Q1, dim=0).mean()
215 | cql_logsumexp2 = torch.logsumexp(cat_Q2, dim=0).mean()
216 | cql_logsumexp = cql_logsumexp1 + cql_logsumexp2
217 | # print("train 1:", rand_Q1.mean(), np.log(0.5 ** self.action_dim))
218 | # dist_test = self.actor(obs)
219 | # print("train 2:", dist_test.loc.shape, dist_test.scale.shape, torch.isnan(dist_test.loc).any(), torch.isnan(dist_test.scale).any())
220 | # print("train 3:", sampled_Q1.mean(), self.actor(obs).log_prob(sampled_actions).detach().mean())
221 | # print("train 4:", next_sampled_Q1.mean(), self.actor(obs).log_prob(next_sampled_actions).detach().mean())
222 |
223 | cql_penalty = cql_logsumexp - (Q1 + Q2).mean() # scalar
224 |
225 | # Update lagrange multiplier
226 | if self.use_critic_lagrange:
227 | alpha = torch.clamp(self.log_critic_alpha.exp(), min=0.0, max=1000000.0)
228 | alpha_loss = -0.5 * alpha * (cql_penalty - self.target_cql_penalty)
229 |
230 | self.critic_alpha_opt.zero_grad()
231 | alpha_loss.backward(retain_graph=True)
232 | self.critic_alpha_opt.step()
233 | alpha = torch.clamp(self.log_critic_alpha.exp(),
234 | min=0.0,
235 | max=1000000.0).detach()
236 | else:
237 | alpha = self.alpha
238 |
239 | # Combine losses
240 | critic_loss = critic_loss + alpha * cql_penalty
241 |
242 | # optimize critic
243 | self.critic_opt.zero_grad(set_to_none=True)
244 | critic_loss.backward()
245 | # torch.nn.utils.clip_grad_norm_(self.critic.parameters(), 50)
246 | critic_grad_norm = utils.grad_norm(self.critic.parameters())
247 | self.critic_opt.step()
248 |
249 | if self.use_tb:
250 | metrics['critic_target_q'] = target_Q.mean().item()
251 | metrics['critic_q1'] = Q1.mean().item()
252 | # metrics['critic_q2'] = Q2.mean().item()
253 | metrics['critic_loss'] = critic_loss.item()
254 | metrics['critic_cql'] = cql_penalty.item()
255 | metrics['critic_cql_logsum1'] = cql_logsumexp1.item()
256 | # metrics['critic_cql_logsum2'] = cql_logsumexp2.item()
257 | metrics['rand_Q1'] = rand_Q1.mean().item()
258 | # metrics['rand_Q2'] = rand_Q2.mean().item()
259 | metrics['sampled_Q1'] = sampled_Q1.mean().item()
260 | # metrics['sampled_Q2'] = sampled_Q2.mean().item()
261 | metrics['critic_grad_norm'] = critic_grad_norm
262 |
263 | return metrics
264 |
265 | def update_actor(self, obs, action, step):
266 | metrics = dict()
267 |
268 | policy = self.actor(obs)
269 | sampled_action = policy.rsample() # (1024, 6)
270 | # print("sampled_action:", sampled_action.shape)
271 | log_pi = policy.log_prob(sampled_action) # (1024, 6)
272 |
273 | # update lagrange multiplier
274 | alpha_loss = -(self.log_actor_alpha * (log_pi + self.target_entropy).detach()).mean()
275 | self.actor_alpha_opt.zero_grad(set_to_none=True)
276 | alpha_loss.backward()
277 | self.actor_alpha_opt.step()
278 | alpha = self.log_actor_alpha.exp().detach()
279 |
280 | # optimize actor
281 | Q1, Q2 = self.critic(obs, sampled_action) # (1024, 1)
282 | Q = torch.min(Q1, Q2) # (1024, 1)
283 | actor_loss = (alpha * log_pi - Q).mean() # scalar
284 | self.actor_opt.zero_grad(set_to_none=True)
285 | actor_loss.backward()
286 | # torch.nn.utils.clip_grad_norm_(self.actor.parameters(), 5)
287 | actor_grad_norm = utils.grad_norm(self.actor.parameters())
288 | self.actor_opt.step()
289 |
290 | if self.use_tb:
291 | metrics['actor_loss'] = actor_loss.item()
292 | metrics['actor_ent'] = -log_pi.mean().item()
293 | metrics['actor_alpha'] = alpha.item()
294 | metrics['actor_alpha_loss'] = alpha_loss.item()
295 | metrics['actor_mean'] = policy.loc.mean().item()
296 | metrics['actor_std'] = policy.scale.mean().item()
297 | metrics['actor_action'] = sampled_action.mean().item()
298 | metrics['actor_atanh_action'] = utils.atanh(sampled_action).mean().item()
299 | metrics['actor_grad_norm'] = actor_grad_norm
300 |
301 | return metrics
302 |
303 | def update(self, replay_iter, step, total_step):
304 | metrics = dict()
305 |
306 | batch = next(replay_iter)
307 | # obs.shape=(1024,obs_dim), action.shape=(1024,1), reward.shape=(1024,1), discount.shape=(1024,1)
308 | obs, action, reward, discount, next_obs, _ = utils.to_torch(
309 | batch, self.device)
310 |
311 | if self.use_tb:
312 | metrics['batch_reward'] = reward.mean().item()
313 |
314 | # update critic
315 | metrics.update(
316 | self.update_critic(obs, action, reward, discount, next_obs, step))
317 |
318 | # update actor
319 | metrics.update(self.update_actor(obs, action, step))
320 |
321 | # update critic target
322 | utils.soft_update_params(self.critic, self.critic_target, self.critic_target_tau)
323 |
324 | return metrics
325 |
--------------------------------------------------------------------------------
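The actor update in these agents also tunes an entropy temperature toward a target entropy of -action_dim. A condensed sketch of that sub-step, with illustrative variable names, is:

import torch

log_alpha = torch.zeros(1, requires_grad=True)
alpha_opt = torch.optim.Adam([log_alpha], lr=1e-4)

def update_temperature(log_pi, target_entropy):
    # log_pi: log-probabilities of actions sampled from the current policy
    loss = -(log_alpha * (log_pi + target_entropy).detach()).mean()
    alpha_opt.zero_grad(set_to_none=True)
    loss.backward()
    alpha_opt.step()
    return log_alpha.exp().detach()    # temperature used in the actor loss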