├── plot_data
│   ├── fetch_push_final.mp4
│   ├── FetchPush-v1_DDPG.npy
│   ├── FetchReach-v1_DDPG.npy
│   ├── FetchSlide-v1_DDPG.npy
│   ├── FetchPush-v1_DDPG_HER.npy
│   ├── FetchPush-v1_DDPG_PER.npy
│   ├── FetchReach-v1_DDPG_HER.npy
│   ├── FetchReach-v1_DDPG_PER.npy
│   ├── FetchSlide-v1_DDPG_HER.npy
│   ├── FetchSlide-v1_DDPG_PER.npy
│   ├── FetchPickAndPlace-v1_DDPG.npy
│   ├── FetchPush-v1_DDPG_HER_PER.npy
│   ├── FetchSlide-v1_DDPG_HER_PER.npy
│   ├── FetchPickAndPlace-v1_DDPG_HER.npy
│   ├── FetchPickAndPlace-v1_DDPG_PER.npy
│   ├── FetchPush-v1_DDPG_HER_PER_10.npy
│   ├── FetchReach-v1_DDPG_HER_PER_10.npy
│   ├── FetchReach-v1_DDPG_HER_PER_2.npy
│   ├── FetchReach-v1_DDPG_HER_PER_4.npy
│   ├── FetchReach-v1_DDPG_HER_PER_6.npy
│   ├── FetchReach-v1_DDPG_HER_PER_8.npy
│   ├── FetchSlide-v1_DDPG_HER_PER_10.npy
│   └── FetchPickAndPlace-v1_DDPG_HER_PER.npy
├── plots
│   ├── all_plots.png
│   ├── all_plots_fp.png
│   ├── all_plots_fr.png
│   ├── all_plots_fs.png
│   ├── alpha_plots_fp.png
│   └── alpha_plots_fs.png
├── models
│   ├── fetch_push_model.pt
│   ├── fetch_reach_model.pt
│   ├── fetch_slide_model.pt
│   └── fetch_picknplace_model.pt
├── train.py
├── models.py
├── README.md
├── her.py
├── demo.py
├── utils.py
├── arguments.py
├── viz.py
├── normalizer.py
├── replay_buffer.py
└── ddpg_agent.py
/plots/all_plots.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sush1996/DDPG_Fetch/HEAD/plots/all_plots.png
--------------------------------------------------------------------------------
/plots/all_plots_fp.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sush1996/DDPG_Fetch/HEAD/plots/all_plots_fp.png
--------------------------------------------------------------------------------
/plots/all_plots_fr.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sush1996/DDPG_Fetch/HEAD/plots/all_plots_fr.png
--------------------------------------------------------------------------------
/plots/all_plots_fs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sush1996/DDPG_Fetch/HEAD/plots/all_plots_fs.png
--------------------------------------------------------------------------------
/plots/alpha_plots_fp.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sush1996/DDPG_Fetch/HEAD/plots/alpha_plots_fp.png
--------------------------------------------------------------------------------
/plots/alpha_plots_fs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sush1996/DDPG_Fetch/HEAD/plots/alpha_plots_fs.png
--------------------------------------------------------------------------------
/models/fetch_push_model.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sush1996/DDPG_Fetch/HEAD/models/fetch_push_model.pt
--------------------------------------------------------------------------------
/models/fetch_reach_model.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sush1996/DDPG_Fetch/HEAD/models/fetch_reach_model.pt
--------------------------------------------------------------------------------
/models/fetch_slide_model.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sush1996/DDPG_Fetch/HEAD/models/fetch_slide_model.pt
--------------------------------------------------------------------------------
/plot_data/FetchPush-v1_DDPG.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sush1996/DDPG_Fetch/HEAD/plot_data/FetchPush-v1_DDPG.npy
--------------------------------------------------------------------------------
/models/fetch_picknplace_model.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sush1996/DDPG_Fetch/HEAD/models/fetch_picknplace_model.pt
--------------------------------------------------------------------------------
/plot_data/FetchReach-v1_DDPG.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sush1996/DDPG_Fetch/HEAD/plot_data/FetchReach-v1_DDPG.npy
--------------------------------------------------------------------------------
/plot_data/FetchSlide-v1_DDPG.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sush1996/DDPG_Fetch/HEAD/plot_data/FetchSlide-v1_DDPG.npy
--------------------------------------------------------------------------------
/plot_data/FetchPush-v1_DDPG_HER.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sush1996/DDPG_Fetch/HEAD/plot_data/FetchPush-v1_DDPG_HER.npy
--------------------------------------------------------------------------------
/plot_data/FetchPush-v1_DDPG_PER.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sush1996/DDPG_Fetch/HEAD/plot_data/FetchPush-v1_DDPG_PER.npy
--------------------------------------------------------------------------------
/plot_data/FetchReach-v1_DDPG_HER.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sush1996/DDPG_Fetch/HEAD/plot_data/FetchReach-v1_DDPG_HER.npy
--------------------------------------------------------------------------------
/plot_data/FetchReach-v1_DDPG_PER.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sush1996/DDPG_Fetch/HEAD/plot_data/FetchReach-v1_DDPG_PER.npy
--------------------------------------------------------------------------------
/plot_data/FetchSlide-v1_DDPG_HER.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sush1996/DDPG_Fetch/HEAD/plot_data/FetchSlide-v1_DDPG_HER.npy
--------------------------------------------------------------------------------
/plot_data/FetchSlide-v1_DDPG_PER.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sush1996/DDPG_Fetch/HEAD/plot_data/FetchSlide-v1_DDPG_PER.npy
--------------------------------------------------------------------------------
/plot_data/FetchPickAndPlace-v1_DDPG.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sush1996/DDPG_Fetch/HEAD/plot_data/FetchPickAndPlace-v1_DDPG.npy
--------------------------------------------------------------------------------
/plot_data/FetchPush-v1_DDPG_HER_PER.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sush1996/DDPG_Fetch/HEAD/plot_data/FetchPush-v1_DDPG_HER_PER.npy
--------------------------------------------------------------------------------
/plot_data/FetchSlide-v1_DDPG_HER_PER.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sush1996/DDPG_Fetch/HEAD/plot_data/FetchSlide-v1_DDPG_HER_PER.npy
--------------------------------------------------------------------------------
/plot_data/FetchPickAndPlace-v1_DDPG_HER.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sush1996/DDPG_Fetch/HEAD/plot_data/FetchPickAndPlace-v1_DDPG_HER.npy
--------------------------------------------------------------------------------
/plot_data/FetchPickAndPlace-v1_DDPG_PER.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sush1996/DDPG_Fetch/HEAD/plot_data/FetchPickAndPlace-v1_DDPG_PER.npy
--------------------------------------------------------------------------------
/plot_data/FetchPush-v1_DDPG_HER_PER_10.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sush1996/DDPG_Fetch/HEAD/plot_data/FetchPush-v1_DDPG_HER_PER_10.npy
--------------------------------------------------------------------------------
/plot_data/FetchReach-v1_DDPG_HER_PER_10.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sush1996/DDPG_Fetch/HEAD/plot_data/FetchReach-v1_DDPG_HER_PER_10.npy
--------------------------------------------------------------------------------
/plot_data/FetchReach-v1_DDPG_HER_PER_2.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sush1996/DDPG_Fetch/HEAD/plot_data/FetchReach-v1_DDPG_HER_PER_2.npy
--------------------------------------------------------------------------------
/plot_data/FetchReach-v1_DDPG_HER_PER_4.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sush1996/DDPG_Fetch/HEAD/plot_data/FetchReach-v1_DDPG_HER_PER_4.npy
--------------------------------------------------------------------------------
/plot_data/FetchReach-v1_DDPG_HER_PER_6.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sush1996/DDPG_Fetch/HEAD/plot_data/FetchReach-v1_DDPG_HER_PER_6.npy
--------------------------------------------------------------------------------
/plot_data/FetchReach-v1_DDPG_HER_PER_8.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sush1996/DDPG_Fetch/HEAD/plot_data/FetchReach-v1_DDPG_HER_PER_8.npy
--------------------------------------------------------------------------------
/plot_data/FetchSlide-v1_DDPG_HER_PER_10.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sush1996/DDPG_Fetch/HEAD/plot_data/FetchSlide-v1_DDPG_HER_PER_10.npy
--------------------------------------------------------------------------------
/plot_data/FetchPickAndPlace-v1_DDPG_HER_PER.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sush1996/DDPG_Fetch/HEAD/plot_data/FetchPickAndPlace-v1_DDPG_HER_PER.npy
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import gym
3 | import os, sys
4 | from arguments import get_args
5 | from mpi4py import MPI
6 | from subprocess import CalledProcessError
7 | from ddpg_agent import ddpg_agent
8 |
9 | """
10 | Train the agent. The MPI-related code is adapted from OpenAI Baselines (https://github.com/openai/baselines/blob/master/baselines/her).
11 |
12 | """
13 | def get_env_params(env):
14 |
15 | obs = env.reset()
16 |
17 | # collect the relevant environment parameters from a sample observation
18 | params = {'obs': obs['observation'].shape[0],
19 | 'goal': obs['desired_goal'].shape[0],
20 | 'action': env.action_space.shape[0],
21 | 'action_max': env.action_space.high[0],
22 | }
23 | params['max_timesteps'] = env._max_episode_steps
24 |
25 | return params
26 |
27 | def launch(args):
28 | # create the ddpg_agent
29 | env = gym.make(args.env_name)
30 | env.reward_type = 'dense'
31 | # get the environment parameters
32 | env_params = get_env_params(env)
33 |
34 | # create the ddpg agent to interact with the environment
35 | ddpg_trainer = ddpg_agent(args, env, env_params)
36 | ddpg_trainer.learn()
37 |
38 | if __name__ == '__main__':
39 |
40 | # set environment variables for MPI-friendly threading
41 | os.environ['OMP_NUM_THREADS'] = '1'
42 | os.environ['MKL_NUM_THREADS'] = '1'
43 | os.environ['IN_MPI'] = '1'
44 |
45 | # get the params
46 | args = get_args()
47 | launch(args)
48 |
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | """
6 | the input x in both networks should be [o, g], where o is the observation and g is the goal.
7 |
8 | """
9 |
10 | # define the actor network
11 | class actor(nn.Module):
12 | def __init__(self, env_params):
13 | super(actor, self).__init__()
14 |
15 | self.max_action = env_params['action_max']
16 | self.fc1 = nn.Linear(env_params['obs'] + env_params['goal'], 256)
17 | self.fc2 = nn.Linear(256, 256)
18 | self.fc3 = nn.Linear(256, 256)
19 | self.action_out = nn.Linear(256, env_params['action'])
20 |
21 | def forward(self, x):
22 |
23 | x = F.relu(self.fc1(x))
24 | x = F.relu(self.fc2(x))
25 | x = F.relu(self.fc3(x))
26 | actions = self.max_action * torch.tanh(self.action_out(x))
27 |
28 | return actions
29 |
30 | class critic(nn.Module):
31 | def __init__(self, env_params):
32 | super(critic, self).__init__()
33 |
34 | self.max_action = env_params['action_max']
35 | self.fc1 = nn.Linear(env_params['obs'] + env_params['goal'] + env_params['action'], 256)
36 | self.fc2 = nn.Linear(256, 256)
37 | self.fc3 = nn.Linear(256, 256)
38 | self.q_out = nn.Linear(256, 1)
39 |
40 | def forward(self, x, actions):
41 | x = torch.cat([x, actions / self.max_action], dim=1)
42 | x = F.relu(self.fc1(x))
43 | x = F.relu(self.fc2(x))
44 | x = F.relu(self.fc3(x))
45 | q_value = self.q_out(x)
46 |
47 | return q_value
48 |
--------------------------------------------------------------------------------
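
To make the `[o, g]` input convention in `models.py` concrete, here is a minimal sketch (not part of the repository). The `env_params` sizes below are hypothetical, roughly matching FetchReach-v1 (10-d observation, 3-d goal, 4-d action):

```
import torch
from models import actor, critic

# hypothetical env_params; the real ones come from get_env_params(env) in train.py
env_params = {'obs': 10, 'goal': 3, 'action': 4, 'action_max': 1.0}

pi = actor(env_params)
q = critic(env_params)

o = torch.randn(2, env_params['obs'])    # batch of observations
g = torch.randn(2, env_params['goal'])   # batch of desired goals
x = torch.cat([o, g], dim=1)             # both networks expect x = [o, g]

a = pi(x)                                # (2, 4), squashed by tanh and scaled by action_max
q_value = q(x, a)                        # (2, 1)
print(a.shape, q_value.shape)
```
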
/README.md:
--------------------------------------------------------------------------------
1 | # DDPG_Fetch
2 | Exploring the performance of Prioritized Experience Replay (PER) with the DDPG+HER scheme on the Fetch Robotics environments
3 |
4 | Plots for Mean Success Rates for different Fetch Environments
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 | Performance plots when varying the alpha parameter of PER
18 |
19 |
20 |
21 |
22 |
23 | * Correction: The plot on the right is for FetchSlide but has been mistakenly labelled as FetchPush
24 |
25 |
26 |
27 |
28 | Adding PER, along with fine-tuning its alpha parameter (hard-coded in `ddpg_agent.py`), boosts performance.
29 |
30 | PER can be integrated into the DDPG+HER framework in several ways; a tighter integration could yield larger performance gains.
31 | (The PER integration in this code isn't polished; it is something I tried out over a weekend.)
32 |
33 | Use the command below to start training; HER and PER can be enabled with the `--her` and `--per` flags defined in `arguments.py`. (Avoid using sudo if you get an ```EXPORT LIBRARY.. .bashrc``` error.)
34 |
35 | ```
36 | mpirun -np 19 python3 train.py
37 | ```
38 |
--------------------------------------------------------------------------------
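
The repository also ships trained checkpoints under `models/`. The sketch below (not part of the repository) loads one of them directly. It assumes the shipped files follow the same `[o_mean, o_std, g_mean, g_std, state_dict]` format that `ddpg_agent.learn()` saves and `demo.py` loads, and that `FetchReach-v1` is available in your gym install:

```
import gym
import numpy as np
import torch
from models import actor

env = gym.make('FetchReach-v1')
# assumed checkpoint format: the 5-tuple saved by ddpg_agent.learn()
o_mean, o_std, g_mean, g_std, state_dict = torch.load('models/fetch_reach_model.pt',
                                                      map_location='cpu')

obs = env.reset()
env_params = {'obs': obs['observation'].shape[0],
              'goal': obs['desired_goal'].shape[0],
              'action': env.action_space.shape[0],
              'action_max': env.action_space.high[0]}
pi = actor(env_params)
pi.load_state_dict(state_dict)
pi.eval()

# normalize the observation/goal the same way demo.py does (clip_obs=200, clip_range=5)
o = np.clip((np.clip(obs['observation'], -200, 200) - o_mean) / o_std, -5, 5)
g = np.clip((np.clip(obs['desired_goal'], -200, 200) - g_mean) / g_std, -5, 5)
with torch.no_grad():
    action = pi(torch.tensor(np.concatenate([o, g]), dtype=torch.float32))
print(action.numpy())
```
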
/her.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | class her_sampler:
4 | def __init__(self, replay_strategy, replay_k, reward_func=None, her=False, per=False):
5 | self.replay_strategy = replay_strategy
6 | self.replay_k = replay_k
7 | self.her = her
8 | self.per = per
9 |
10 | if self.replay_strategy == 'future':
11 | self.future_p = 1 - (1. / (1 + replay_k))
12 | else:
13 | self.future_p = 0
14 |
15 | self.reward_func = reward_func
16 |
17 | def sample_her_transitions(self, episode_batch, batch_size_in_transitions):
18 |
19 | T = episode_batch['actions'].shape[1]
20 | rollout_batch_size = episode_batch['actions'].shape[0]
21 | batch_size = batch_size_in_transitions
22 |
23 | # select which rollouts and which timesteps to be used
24 | episode_idxs = np.random.randint(0, rollout_batch_size, batch_size)
25 |
26 | if self.per == False:
27 | t_samples = np.random.randint(T, size=batch_size)
28 | else:
29 | p = episode_batch['p']
30 | sum_p = sum(p[0])
31 | p_norm = [p_elem/sum_p for p_elem in p[0]]
32 | t_samples = np.random.choice(T, size=batch_size, p=p_norm)
33 |
34 | transitions = {key: episode_batch[key][episode_idxs, t_samples].copy() for key in episode_batch.keys()}
35 |
36 | if self.her==True:
37 |
38 | # her idx
39 | her_indexes = np.where(np.random.uniform(size=batch_size) < self.future_p)
40 | future_offset = np.random.uniform(size=batch_size) * (T - t_samples)
41 | future_offset = future_offset.astype(int)
42 | future_t = (t_samples + 1 + future_offset)[her_indexes]
43 |
44 | # replace the goal with a future achieved goal
45 | future_ag = episode_batch['ag'][episode_idxs[her_indexes], future_t]
46 | transitions['g'][her_indexes] = future_ag
47 |
48 | transitions['r'] = np.expand_dims(self.reward_func(transitions['ag_next'], transitions['g'], None), 1)
49 |
50 | # reshape the transitions back to (batch_size, ...)
51 | transitions = {k: transitions[k].reshape(batch_size, *transitions[k].shape[1:]) for k in transitions.keys()}
52 |
53 |
54 | return transitions
55 |
--------------------------------------------------------------------------------
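
As a quick illustration of the sampler above (a sketch, not repository code): with the default `replay_k = 4`, `future_p = 1 - 1/(1 + 4) = 0.8`, so roughly 80% of sampled transitions get their goal relabelled with a future achieved goal and their reward recomputed. The reward function and episode shapes below are hypothetical stand-ins for `env.compute_reward` and real rollouts:

```
import numpy as np
from her import her_sampler

# hypothetical sparse reward: 0 if the achieved goal is within 5 cm of the goal, else -1
def compute_reward(ag_next, g, info):
    return -(np.linalg.norm(ag_next - g, axis=-1) > 0.05).astype(np.float32)

sampler = her_sampler('future', replay_k=4, reward_func=compute_reward, her=True)

# a fake buffer of 2 episodes of length T=50 (obs/ag carry T+1 entries, g/actions carry T)
T, n_ep, obs_d, goal_d, act_d = 50, 2, 10, 3, 4
batch = {'obs': np.random.randn(n_ep, T + 1, obs_d),
         'ag': np.random.randn(n_ep, T + 1, goal_d),
         'g': np.random.randn(n_ep, T, goal_d),
         'actions': np.random.randn(n_ep, T, act_d)}
batch['obs_next'] = batch['obs'][:, 1:, :]
batch['ag_next'] = batch['ag'][:, 1:, :]

transitions = sampler.sample_her_transitions(batch, batch_size_in_transitions=8)
print(transitions['g'].shape, transitions['r'].shape)   # (8, 3) (8, 1)
```
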
/demo.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from models import actor
3 | from arguments import get_args
4 | import gym
5 | import numpy as np
6 |
7 | # process the inputs
8 | def process_inputs(o, g, o_mean, o_std, g_mean, g_std, args):
9 | o_clip = np.clip(o, -args.clip_obs, args.clip_obs)
10 | g_clip = np.clip(g, -args.clip_obs, args.clip_obs)
11 | o_norm = np.clip((o_clip - o_mean) / (o_std), -args.clip_range, args.clip_range)
12 | g_norm = np.clip((g_clip - g_mean) / (g_std), -args.clip_range, args.clip_range)
13 | inputs = np.concatenate([o_norm, g_norm])
14 | inputs = torch.tensor(inputs, dtype=torch.float32)
15 |
16 | return inputs
17 |
18 | if __name__ == '__main__':
19 | args = get_args()
20 |
21 | # load the model param
22 | model_path = args.save_dir + args.env_name + '/model.pt'
23 | o_mean, o_std, g_mean, g_std, model = torch.load(model_path, map_location=lambda storage, loc: storage)
24 |
25 | # create the environment
26 | env = gym.make(args.env_name)
27 | # get the env param
28 | observation = env.reset()
29 |
30 | # get the environment params
31 | env_params = {'obs': observation['observation'].shape[0],
32 | 'goal': observation['desired_goal'].shape[0],
33 | 'action': env.action_space.shape[0],
34 | 'action_max': env.action_space.high[0],
35 | }
36 |
37 | # create the actor network
38 | actor_network = actor(env_params)
39 | actor_network.load_state_dict(model)
40 | actor_network.eval()
41 |
42 | for i in range(args.demo_length):
43 | observation = env.reset()
44 |
45 | # start to do the demo
46 | obs = observation['observation']
47 | g = observation['desired_goal']
48 |
49 | for t in range(env._max_episode_steps):
50 | env.render()
51 | inputs = process_inputs(obs, g, o_mean, o_std, g_mean, g_std, args)
52 |
53 | with torch.no_grad():
54 | pi = actor_network(inputs)
55 | action = pi.detach().numpy().squeeze()
56 |
57 | # put actions into the environment
58 | observation_new, reward, _, info = env.step(action)
59 | obs = observation_new['observation']
60 |
61 | print('the episode is: {}, is success: {}'.format(i, info['is_success']))
62 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | from mpi4py import MPI
2 | import numpy as np
3 | import torch
4 |
5 | # sync_networks across the different cores
6 | def sync_networks(network):
7 | """
8 | network is the network you want to sync
9 |
10 | """
11 | comm = MPI.COMM_WORLD
12 | flat_params, params_shape = _get_flat_params(network)
13 | comm.Bcast(flat_params, root=0)
14 | # set the flat params back to the network
15 | _set_flat_params(network, params_shape, flat_params)
16 |
17 | # get the flat params from the network
18 | def _get_flat_params(network):
19 | param_shape = {}
20 | flat_params = None
21 | for key_name, value in network.named_parameters():
22 | param_shape[key_name] = value.detach().numpy().shape
23 | if flat_params is None:
24 | flat_params = value.detach().numpy().flatten()
25 | else:
26 | flat_params = np.append(flat_params, value.detach().numpy().flatten())
27 | return flat_params, param_shape
28 |
29 | # set the flat params back into the network
30 | def _set_flat_params(network, params_shape, params):
31 | pointer = 0
32 | for key_name, values in network.named_parameters():
33 |
34 | # get the length of the parameters
35 | len_param = np.prod(params_shape[key_name])
36 | copy_params = params[pointer:pointer + len_param].reshape(params_shape[key_name])
37 | copy_params = torch.tensor(copy_params)
38 |
39 | # copy the params
40 | values.data.copy_(copy_params.data)
41 |
42 | # update the pointer
43 | pointer += len_param
44 |
45 | # sync the gradients across the different cores
46 | def sync_grads(network):
47 | flat_grads, grads_shape = _get_flat_grads(network)
48 | comm = MPI.COMM_WORLD
49 | global_grads = np.zeros_like(flat_grads)
50 | comm.Allreduce(flat_grads, global_grads, op=MPI.SUM)
51 | _set_flat_grads(network, grads_shape, global_grads)
52 |
53 | def _set_flat_grads(network, grads_shape, flat_grads):
54 | pointer = 0
55 |
56 | for key_name, value in network.named_parameters():
57 | len_grads = np.prod(grads_shape[key_name])
58 | copy_grads = flat_grads[pointer:pointer + len_grads].reshape(grads_shape[key_name])
59 | copy_grads = torch.tensor(copy_grads)
60 |
61 | # copy the grads
62 | value.grad.data.copy_(copy_grads.data)
63 | pointer += len_grads
64 |
65 | def _get_flat_grads(network):
66 | grads_shape = {}
67 | flat_grads = None
68 |
69 | for key_name, value in network.named_parameters():
70 | grads_shape[key_name] = value.grad.data.cpu().numpy().shape
71 |
72 | if flat_grads is None:
73 | flat_grads = value.grad.data.cpu().numpy().flatten()
74 | else:
75 | flat_grads = np.append(flat_grads, value.grad.data.cpu().numpy().flatten())
76 |
77 | return flat_grads, grads_shape
78 |
--------------------------------------------------------------------------------
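
A minimal sketch (not part of the repository) of how these helpers are meant to be used; run it under `mpirun` to actually exercise the synchronization (with a single process the calls are effectively no-ops). Note that `sync_grads` sums gradients across workers rather than averaging them:

```
import torch
import torch.nn as nn
from utils import sync_networks, sync_grads

# a toy two-layer net standing in for the actor/critic
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
sync_networks(net)                           # broadcast rank 0's parameters to all workers

loss = net(torch.randn(16, 4)).pow(2).mean()
loss.backward()                              # populate .grad on every parameter
sync_grads(net)                              # sum the gradients across workers
torch.optim.SGD(net.parameters(), lr=0.01).step()
```
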
/arguments.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | """
4 | Here are the param for the training
5 |
6 | """
7 |
8 | def get_args():
9 | parser = argparse.ArgumentParser()
10 | # the environment setting
11 | parser.add_argument('--env-name', type=str, default='FetchReach-v1', help='the environment name')
12 | parser.add_argument('--n-epochs', type=int, default=50, help='the number of epochs to train the agent')
13 | parser.add_argument('--n-cycles', type=int, default=1, help='the times to collect samples per epoch')
14 | parser.add_argument('--n-batches', type=int, default=40, help='the times to update the network')
15 | parser.add_argument('--save-interval', type=int, default=5, help='the interval at which to save the trajectory')
16 | parser.add_argument('--seed', type=int, default=123, help='random seed')
17 | parser.add_argument('--num-workers', type=int, default=1, help='the number of cpus to collect samples')
18 | parser.add_argument('--replay-strategy', type=str, default='future', help='the HER strategy')
19 | parser.add_argument('--clip-return', type=float, default=50, help='if clip the returns')
20 | parser.add_argument('--save-dir', type=str, default='saved_models/', help='the path to save the models')
21 | parser.add_argument('--noise-eps', type=float, default=0.2, help='noise eps')
22 | parser.add_argument('--random-eps', type=float, default=0.3, help='random eps')
23 | parser.add_argument('--buffer-size', type=int, default=int(1e6), help='the size of the buffer')
24 | parser.add_argument('--replay-k', type=int, default=4, help='ratio of HER transitions to regular ones (replay k)')
25 | parser.add_argument('--clip-obs', type=float, default=200, help='the clip ratio')
26 | parser.add_argument('--batch-size', type=int, default=256, help='the sample batch size')
27 | parser.add_argument('--gamma', type=float, default=0.98, help='the discount factor')
28 | parser.add_argument('--action-l2', type=float, default=1, help='l2 reg')
29 | parser.add_argument('--lr-actor', type=float, default=0.001, help='the learning rate of the actor')
30 | parser.add_argument('--lr-critic', type=float, default=0.001, help='the learning rate of the critic')
31 | parser.add_argument('--polyak', type=float, default=0.95, help='the average coefficient')
32 | parser.add_argument('--n-test-rollouts', type=int, default=10, help='the number of tests')
33 | parser.add_argument('--clip-range', type=float, default=5, help='the clip range')
34 | parser.add_argument('--demo-length', type=int, default=10, help='the demo length')
35 | parser.add_argument('--cuda', action='store_true', help='if use gpu do the acceleration')
36 | parser.add_argument('--num-rollouts-per-mpi', type=int, default=1, help='the rollouts per mpi')
37 | parser.add_argument('--her', action='store_true', help='enable HER (flag, default False)')
38 | parser.add_argument('--per', action='store_true', help='enable PER (flag, default False)')
39 |
40 |
41 | args = parser.parse_args()
42 |
43 | return args
44 |
--------------------------------------------------------------------------------
/viz.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import seaborn as sns
3 | import matplotlib.pyplot as plt
4 | sns.set()
5 |
6 | # Fetch environment plots (set `env` below)
7 |
8 | env = 'FetchSlide-v1'
9 |
10 | '''
11 | ddpg = np.load('{}_DDPG.npy'.format(env))
12 | ddpg_per = np.load('{}_DDPG_PER.npy'.format(env))
13 | ddpg_her = np.load('{}_DDPG_HER.npy'.format(env))
14 | '''
15 |
16 |
17 | ddpg_her_per_0 = np.load('{}_DDPG_HER.npy'.format(env))
18 | ddpg_her_per_5 = np.load('{}_DDPG_HER_PER.npy'.format(env))
19 | ddpg_her_per_10 = np.load('{}_DDPG_HER_PER_10.npy'.format(env))
20 | #ddpg_her_per = np.load('{}_DDPG_HER_PER.npy'.format(env))
21 |
22 | '''
23 | ep_n_0 = np.where(ddpg_her_per_0 == 1.0)[0][0]#[i for i in ddpg_her_per_0 if i == 1.0][0]
24 | ep_n_2 = np.where(ddpg_her_per_2 == 1.0)[0][0]#[i for i in ddpg_her_per_2 if i == 1.0][0]
25 | ep_n_4 = np.where(ddpg_her_per_4 == 1.0)[0][0]#[i for i in ddpg_her_per_4 if i == 1.0][0]
26 | ep_n_6 = np.where(ddpg_her_per_6 == 1.0)[0][0]#[i for i in ddpg_her_per_6 if i == 1.0][0]
27 | ep_n_8 = np.where(ddpg_her_per_8 == 1.0)[0][0]#[i for i in ddpg_her_per_8 if i == 1.0][0]
28 | ep_n_10 = np.where(ddpg_her_per_10 == 1.0)[0][0]#[i for i in ddpg_her_per_10 if i == 1.0][0]
29 |
30 |
31 | plt.bar(np.arange(len([0.0,0.2,0.4,0.6,0.8,1.0])), [ep_n_0, ep_n_2, ep_n_4, ep_n_6, ep_n_8, ep_n_10])
32 | plt.show()
33 | '''
34 | n = range(50)#range(len(ddpg))
35 | plt.plot(n, ddpg_her_per_0, label='alpha = 0.0')
36 | #plt.plot(n, ddpg_her_per_2, label='alpha = 0.2')
37 | plt.plot(n, ddpg_her_per_5, label='alpha = 0.5')
38 | '''
39 | plt.plot(n, ddpg_her_per_6, label='alpha = 0.6')
40 | plt.plot(n, ddpg_her_per_8, label='alpha = 0.8')
41 | '''
42 | plt.plot(n, ddpg_her_per_10, label='alpha = 1.0')
43 | plt.xlabel('Epoch')
44 | plt.ylabel('Mean Success Rate')
45 | plt.title('Mean Success Rate variation with alpha for DDPG+HER+PER on {}'.format(env))
46 | plt.legend()
47 | plt.savefig('alpha_plots_fr')
48 | plt.show()
49 |
50 |
51 | '''
52 | #Plot 1 : All Plots
53 | plt.plot(n, ddpg, label='DDPG')
54 | plt.plot(n, ddpg_per, label='DDPG + PER')
55 | plt.plot(n, ddpg_her, label='DDPG + HER')
56 | plt.plot(n, ddpg_her_per, label='DDPG + HER + PER')
57 | plt.xlabel('Epoch')
58 | plt.ylabel('Mean Success Rate')
59 | plt.title('Mean Success Rate Plots for {}'.format(env))
60 | plt.legend()
61 | plt.savefig("{}_plots/all_plots".format(env))
62 | plt.show()
63 |
64 | #Plot 2 : DDPG and PER variants
65 | plt.plot(n, ddpg, label='DDPG')
66 | plt.plot(n, ddpg_per, label='DDPG + PER')
67 | plt.xlabel('Epoch')
68 | plt.ylabel('Mean Success Rate')
69 | plt.title('Mean Success Rate Plots for {}'.format(env))
70 | plt.legend()
71 | plt.savefig("{}_plots/ddpg_per_plots".format(env))
72 | plt.show()
73 |
74 | #Plot 3: HER variants
75 | plt.plot(n, ddpg_her, label='DDPG + HER')
76 | plt.plot(n, ddpg_her_per, label='DDPG + HER + PER')
77 | plt.xlabel('Epoch')
78 | plt.ylabel('Mean Success Rate')
79 | plt.title('Mean Success Rate Plots for {}'.format(env))
80 | plt.legend()
81 | plt.savefig("{}_plots/ddpg_her_per_plots".format(env))
82 | plt.show()
83 | '''
--------------------------------------------------------------------------------
/normalizer.py:
--------------------------------------------------------------------------------
1 | import threading
2 | import numpy as np
3 | from mpi4py import MPI
4 |
5 | class normalizer:
6 | def __init__(self, size, eps=1e-2, default_clip_range=np.inf):
7 |
8 | self.size = size
9 | self.eps = eps
10 | self.default_clip_range = default_clip_range
11 |
12 | # some local information
13 | self.local_sum = np.zeros(self.size, np.float32)
14 | self.local_sumsq = np.zeros(self.size, np.float32)
15 | self.local_count = np.zeros(1, np.float32)
16 |
17 | # get the total sum sumsq and sum count
18 | self.total_sum = np.zeros(self.size, np.float32)
19 | self.total_sumsq = np.zeros(self.size, np.float32)
20 | self.total_count = np.ones(1, np.float32)
21 |
22 | # get the mean and std
23 | self.mean = np.zeros(self.size, np.float32)
24 | self.std = np.ones(self.size, np.float32)
25 |
26 | # thread locker
27 | self.lock = threading.Lock()
28 |
29 | # update the parameters of the normalizer
30 | def update(self, v):
31 | v = v.reshape(-1, self.size)
32 |
33 | # do the computing
34 | with self.lock:
35 | self.local_sum += v.sum(axis=0)
36 | self.local_sumsq += (np.square(v)).sum(axis=0)
37 | self.local_count[0] += v.shape[0]
38 |
39 | # sync the parameters across the cpus
40 | def sync(self, local_sum, local_sumsq, local_count):
41 | local_sum[...] = self._mpi_average(local_sum)
42 | local_sumsq[...] = self._mpi_average(local_sumsq)
43 | local_count[...] = self._mpi_average(local_count)
44 |
45 | return local_sum, local_sumsq, local_count
46 |
47 | def recompute_stats(self):
48 | with self.lock:
49 | local_count = self.local_count.copy()
50 | local_sum = self.local_sum.copy()
51 | local_sumsq = self.local_sumsq.copy()
52 | # reset
53 | self.local_count[...] = 0
54 | self.local_sum[...] = 0
55 | self.local_sumsq[...] = 0
56 |
57 | # sync the stats
58 | sync_sum, sync_sumsq, sync_count = self.sync(local_sum, local_sumsq, local_count)
59 |
60 | # update the total stuff
61 | self.total_sum += sync_sum
62 | self.total_sumsq += sync_sumsq
63 | self.total_count += sync_count
64 |
65 | # calculate the new mean and std
66 | self.mean = self.total_sum / self.total_count
67 | self.std = np.sqrt(np.maximum(np.square(self.eps), (self.total_sumsq / self.total_count) - np.square(self.total_sum / self.total_count)))
68 |
69 | # average across the cpu's data
70 | def _mpi_average(self, x):
71 | buf = np.zeros_like(x)
72 | MPI.COMM_WORLD.Allreduce(x, buf, op=MPI.SUM)
73 | buf /= MPI.COMM_WORLD.Get_size()
74 |
75 | return buf
76 |
77 | # normalize the observation
78 | def normalize(self, v, clip_range=None):
79 | if clip_range is None:
80 | clip_range = self.default_clip_range
81 |
82 | return np.clip((v - self.mean) / (self.std), -clip_range, clip_range)
83 |
--------------------------------------------------------------------------------
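
A small usage sketch for the normalizer (not repository code; it assumes `mpi4py` is importable, and also works with a single process):

```
import numpy as np
from normalizer import normalizer

o_norm = normalizer(size=3, default_clip_range=5)
o_norm.update(np.random.randn(1000, 3) * 10.0 + 2.0)   # accumulate local sums/sumsqs
o_norm.recompute_stats()                                # MPI-average and refresh mean/std
print(o_norm.mean, o_norm.std)
print(o_norm.normalize(np.array([2.0, 2.0, 2.0])))      # roughly centred, clipped to [-5, 5]
```
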
/replay_buffer.py:
--------------------------------------------------------------------------------
1 | import threading
2 | import numpy as np
3 |
4 | """
5 | the replay buffer here is basically from the openai baselines code
6 |
7 | """
8 | class replay_buffer:
9 | def __init__(self, env_params, buffer_size, sample_func, her, per):
10 | self.env_params = env_params
11 | self.T = env_params['max_timesteps']
12 | self.size = buffer_size // self.T
13 |
14 | # memory management
15 | self.current_size = 0
16 | self.n_transitions_stored = 0
17 | self.sample_func = sample_func
18 | self.per = per
19 | self.her = her
20 |
21 | # create the buffer to store info
22 | if self.her == False:
23 | if self.per == True:
24 | self.buffers = {'obs': np.empty([self.size, self.T + 1, self.env_params['obs']]),
25 | 'ag': np.empty([self.size, self.T + 1, self.env_params['goal']]),
26 | 'g': np.empty([self.size, self.T, self.env_params['goal']]),
27 | 'actions': np.empty([self.size, self.T, self.env_params['action']]),
28 | 'r': np.empty([self.size, self.T]),
29 | 'p': np.empty([self.size, self.T])
30 | }
31 | else:
32 | self.buffers = {'obs': np.empty([self.size, self.T + 1, self.env_params['obs']]),
33 | 'ag': np.empty([self.size, self.T + 1, self.env_params['goal']]),
34 | 'g': np.empty([self.size, self.T, self.env_params['goal']]),
35 | 'actions': np.empty([self.size, self.T, self.env_params['action']]),
36 | 'r': np.empty([self.size, self.T])
37 | }
38 |
39 | elif self.her == True:
40 | if self.per == True:
41 | self.buffers = {'obs': np.empty([self.size, self.T + 1, self.env_params['obs']]),
42 | 'ag': np.empty([self.size, self.T + 1, self.env_params['goal']]),
43 | 'g': np.empty([self.size, self.T, self.env_params['goal']]),
44 | 'actions': np.empty([self.size, self.T, self.env_params['action']]),
45 | 'p': np.empty([self.size, self.T])
46 | }
47 | else:
48 | self.buffers = {'obs': np.empty([self.size, self.T + 1, self.env_params['obs']]),
49 | 'ag': np.empty([self.size, self.T + 1, self.env_params['goal']]),
50 | 'g': np.empty([self.size, self.T, self.env_params['goal']]),
51 | 'actions': np.empty([self.size, self.T, self.env_params['action']]),
52 | }
53 |
54 | # thread lock
55 | self.lock = threading.Lock()
56 |
57 | # store the episode
58 | def store_episode(self, episode_batch):
59 |
60 | if self.her == False:
61 | if self.per == True:
62 | mb_obs, mb_ag, mb_g, mb_actions, mb_r, mb_p = episode_batch
63 | else:
64 | mb_obs, mb_ag, mb_g, mb_actions, mb_r = episode_batch
65 | else:
66 | if self.per == True:
67 | mb_obs, mb_ag, mb_g, mb_actions, mb_p = episode_batch
68 | else:
69 | mb_obs, mb_ag, mb_g, mb_actions = episode_batch
70 |
71 | batch_size = mb_obs.shape[0]
72 |
73 | with self.lock:
74 | idxs = self._get_storage_idx(inc=batch_size)
75 |
76 | # store the information
77 | self.buffers['obs'][idxs] = mb_obs
78 | self.buffers['ag'][idxs] = mb_ag
79 | self.buffers['g'][idxs] = mb_g
80 | self.buffers['actions'][idxs] = mb_actions
81 |
82 | if self.her == False:
83 | self.buffers['r'][idxs] = mb_r
84 |
85 | if self.per == True:
86 | self.buffers['p'][idxs] = mb_p
87 |
88 | elif self.her == True:
89 | if self.per == True:
90 | self.buffers['p'][idxs] = mb_p
91 |
92 |
93 | self.n_transitions_stored += self.T * batch_size
94 |
95 | # sample the data from the replay buffer
96 | def sample(self, batch_size):
97 |
98 | temp_buffers = {}
99 |
100 | with self.lock:
101 | for key in self.buffers.keys():
102 | temp_buffers[key] = self.buffers[key][:self.current_size].copy()
103 |
104 | temp_buffers['obs_next'] = temp_buffers['obs'][:, 1:, :]
105 | temp_buffers['ag_next'] = temp_buffers['ag'][:, 1:, :]
106 |
107 | # sample transitions
108 | transitions = self.sample_func(temp_buffers, batch_size)
109 |
110 | return transitions
111 |
112 | def _get_storage_idx(self, inc=None):
113 |
114 | inc = inc or 1
115 |
116 | if self.current_size+inc <= self.size:
117 | idx = np.arange(self.current_size, self.current_size+inc)
118 |
119 | elif self.current_size < self.size:
120 | overflow = inc - (self.size - self.current_size)
121 | idx_a = np.arange(self.current_size, self.size)
122 | idx_b = np.random.randint(0, self.current_size, overflow)
123 | idx = np.concatenate([idx_a, idx_b])
124 | else:
125 | idx = np.random.randint(0, self.size, inc)
126 |
127 | self.current_size = min(self.size, self.current_size+inc)
128 |
129 | if inc == 1:
130 | idx = idx[0]
131 |
132 | return idx
--------------------------------------------------------------------------------
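
A sketch (not repository code) tying the buffer and the HER sampler together, with the same hypothetical sizes and stand-in reward function used in the `her.py` example above. Note the shape convention: `obs`/`ag` store `T + 1` steps per episode, while `g`/`actions` store `T`:

```
import numpy as np
from her import her_sampler
from replay_buffer import replay_buffer

def compute_reward(ag_next, g, info):   # hypothetical stand-in for env.compute_reward
    return -(np.linalg.norm(ag_next - g, axis=-1) > 0.05).astype(np.float32)

env_params = {'obs': 10, 'goal': 3, 'action': 4, 'max_timesteps': 50}
sampler = her_sampler('future', replay_k=4, reward_func=compute_reward, her=True)
buffer = replay_buffer(env_params, buffer_size=10000,
                       sample_func=sampler.sample_her_transitions, her=True, per=False)

T = env_params['max_timesteps']
mb_obs = np.random.randn(1, T + 1, env_params['obs'])    # one rollout per store call
mb_ag = np.random.randn(1, T + 1, env_params['goal'])
mb_g = np.random.randn(1, T, env_params['goal'])
mb_actions = np.random.randn(1, T, env_params['action'])

buffer.store_episode([mb_obs, mb_ag, mb_g, mb_actions])  # her=True, per=False layout
transitions = buffer.sample(batch_size=256)
print(transitions['obs'].shape, transitions['r'].shape)  # (256, 10) (256, 1)
```
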
/ddpg_agent.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | from datetime import datetime
4 | import numpy as np
5 | from mpi4py import MPI
6 | from models import actor, critic
7 | from utils import sync_networks, sync_grads
8 | from replay_buffer import replay_buffer
9 | from normalizer import normalizer
10 | from her import her_sampler
11 | import matplotlib.pyplot as plt
12 |
13 | """
14 | ddpg with HER (MPI-version)
15 |
16 | """
17 | class ddpg_agent:
18 | def __init__(self, args, env, env_params):
19 | self.args = args
20 | self.env = env
21 | self.env_params = env_params
22 |
23 | # create the network
24 | self.actor_network = actor(env_params)
25 | self.critic_network = critic(env_params)
26 |
27 | # sync the networks across the cpus
28 | sync_networks(self.actor_network)
29 | sync_networks(self.critic_network)
30 |
31 | # build up the target network
32 | self.actor_target_network = actor(env_params)
33 | self.critic_target_network = critic(env_params)
34 |
35 | # load the weights into the target networks
36 | self.actor_target_network.load_state_dict(self.actor_network.state_dict())
37 | self.critic_target_network.load_state_dict(self.critic_network.state_dict())
38 |
39 | # if use gpu
40 | if self.args.cuda:
41 | self.actor_network.cuda()
42 | self.critic_network.cuda()
43 | self.actor_target_network.cuda()
44 | self.critic_target_network.cuda()
45 |
46 | # create the optimizer
47 | self.actor_optim = torch.optim.Adam(self.actor_network.parameters(), lr=self.args.lr_actor)
48 | self.critic_optim = torch.optim.Adam(self.critic_network.parameters(), lr=self.args.lr_critic)
49 |
50 | # her sampler
51 | self.her_module = her_sampler(self.args.replay_strategy, self.args.replay_k, self.env.compute_reward, self.args.her, self.args.per)
52 |
53 | # create the replay buffer
54 | self.buffer = replay_buffer(self.env_params, self.args.buffer_size, self.her_module.sample_her_transitions, self.args.her, self.args.per)
55 |
56 | # create the normalizer
57 | self.o_norm = normalizer(size=env_params['obs'], default_clip_range=self.args.clip_range)
58 | self.g_norm = normalizer(size=env_params['goal'], default_clip_range=self.args.clip_range)
59 |
60 | # create the directory to store the model
61 | if MPI.COMM_WORLD.Get_rank() == 0:
62 | if not os.path.exists(self.args.save_dir):
63 | os.mkdir(self.args.save_dir)
64 |
65 | # path to save the model
66 | self.model_path = os.path.join(self.args.save_dir, self.args.env_name)
67 | if not os.path.exists(self.model_path):
68 | os.mkdir(self.model_path)
69 |
70 | def learn(self):
71 | """
72 | train the network
73 | """
74 |
75 | to_plot, sum_priority, reward_plot, i = [], 0, [], 0
76 | alpha = 0.5
77 | epsilon = 0.1
78 |
79 | # start to collect samples
80 |
81 | for epoch in range(self.args.n_epochs):
82 | self.r = 0
83 | for _ in range(self.args.n_cycles):
84 | mb_obs, mb_ag, mb_g, mb_actions, mb_r, mb_p = [], [], [], [], [], []
85 | for _ in range(self.args.num_rollouts_per_mpi):
86 |
87 | # reset the rollouts
88 | ep_obs, ep_ag, ep_g, ep_actions, ep_r, ep_p = [], [], [], [], [], []
89 |
90 | # reset the environment
91 | observation = self.env.reset()
92 | obs = observation['observation']
93 | ag = observation['achieved_goal']
94 | g = observation['desired_goal']
95 |
96 | # start to collect samples
97 | for t in range(self.env_params['max_timesteps']):
98 | with torch.no_grad():
99 | input_tensor = self._preproc_inputs(obs, g)
100 | pi = self.actor_network(input_tensor)
101 | action = self._select_actions(pi)
102 |
103 | # feed the actions into the environment
104 | #observation_new, _, _, info = self.env.step(action)
105 | observation_new, r, _, info = self.env.step(action)
106 |
107 | obs_new = observation_new['observation']
108 | ag_new = observation_new['achieved_goal']
109 |
110 | if self.args.per == True:
111 | with torch.no_grad():
112 |
113 | #obs, ag = self._preproc_og(obs, ag)
114 | obs_norm, ag_norm = self._preproc_og(obs, g)
115 | #obs_norm = self.o_norm.normalize(obs)
116 | #ag_norm = self.g_norm.normalize(ag)
117 |
118 | inputs_norm = list(obs_norm) + list(ag_norm)
119 |
120 | inputs_norm_tensor = torch.tensor(inputs_norm, dtype=torch.float32)
121 | action_tensor = torch.tensor(action, dtype=torch.float32)
122 |
123 | q_curr = self.critic_network(inputs_norm_tensor.view(1,-1), action_tensor.view(1,-1))
124 | q_curr_value = q_curr.detach()
125 |
126 | obs_next_norm, ag_next_norm = self._preproc_og(obs_new, g)
127 |
128 | #obs_next_norm = self.o_norm.normalize(obs_new)
129 | #ag_next_norm = self.g_norm.normalize(ag_new)
130 |
131 | inputs_next_norm = list(obs_next_norm) + list(ag_next_norm)
132 | inputs_next_norm_tensor = torch.tensor(inputs_next_norm, dtype=torch.float32)
133 |
134 | pi_next = self.actor_network(inputs_next_norm_tensor)
135 | action_next = self._select_actions(pi_next)
136 | action_next_tensor = torch.tensor(action_next, dtype=torch.float32)
137 |
138 | q_next_target = self.critic_target_network(inputs_next_norm_tensor.view(1,-1), action_next_tensor.view(1,-1))
139 | q_next_value = q_next_target.detach()
140 |
141 | td_error = np.abs(q_curr_value - q_next_value)
142 |
143 | priority = (td_error + epsilon) ** alpha
144 | sum_priority += priority
145 | p_priority = priority / sum_priority
146 |
147 | ep_p.append(p_priority)
148 |
149 | # append rollouts
150 | ep_obs.append(obs.copy())
151 | ep_ag.append(ag.copy())
152 | ep_g.append(g.copy())
153 | ep_actions.append(action.copy())
154 |
155 | #if self.args.her == False:
156 | ep_r.append(r.copy())
157 |
158 | # re-assign the observation
159 | obs = obs_new
160 | ag = ag_new
161 |
162 | reward_plot.append(sum(ep_r))
163 |
164 | '''
165 | if MPI.COMM_WORLD.Get_rank() == 0:
166 | print("epissode", i, "reward", sum(ep_r))
167 | '''
168 |
169 | i = i + 1
170 |
171 | ep_obs.append(obs.copy())
172 | ep_ag.append(ag.copy())
173 |
174 | mb_obs.append(ep_obs)
175 | mb_ag.append(ep_ag)
176 | mb_g.append(ep_g)
177 | mb_actions.append(ep_actions)
178 |
179 | if self.args.her == False:
180 | mb_r.append(ep_r)
181 |
182 | if self.args.per == True:
183 | mb_p.append(ep_p)
184 |
185 | # convert them into arrays
186 | mb_obs = np.array(mb_obs)
187 | mb_ag = np.array(mb_ag)
188 | mb_g = np.array(mb_g)
189 | mb_actions = np.array(mb_actions)
190 |
191 | if self.args.her == False:
192 |
193 | if self.args.per == True:
194 | mb_r = np.array(mb_r)
195 | mb_p = np.array(mb_p)
196 | self.buffer.store_episode([mb_obs, mb_ag, mb_g, mb_actions, mb_r, mb_p])
197 | self._update_normalizer([mb_obs, mb_ag, mb_g, mb_actions, mb_r, mb_p])
198 |
199 | else:
200 | mb_r = np.array(mb_r)
201 | self.buffer.store_episode([mb_obs, mb_ag, mb_g, mb_actions, mb_r])
202 | self._update_normalizer([mb_obs, mb_ag, mb_g, mb_actions, mb_r])
203 |
204 |
205 | elif self.args.her == True:
206 | if self.args.per == True:
207 | mb_p = np.array(mb_p)
208 | self.buffer.store_episode([mb_obs, mb_ag, mb_g, mb_actions, mb_p])
209 | self._update_normalizer([mb_obs, mb_ag, mb_g, mb_actions, mb_p])
210 | else:
211 | self.buffer.store_episode([mb_obs, mb_ag, mb_g, mb_actions])
212 | self._update_normalizer([mb_obs, mb_ag, mb_g, mb_actions])
213 |
214 | for _ in range(self.args.n_batches):
215 | # train the network
216 | self._update_network()
217 |
218 | # soft update
219 | self._soft_update_target_network(self.actor_target_network, self.actor_network)
220 | self._soft_update_target_network(self.critic_target_network, self.critic_network)
221 |
222 | # start to do the evaluation
223 | success_rate = self._eval_agent()
224 |
225 | if MPI.COMM_WORLD.Get_rank() == 0:
226 | print('[{}] epoch is: {}, eval success rate is: {:.3f}'.format(datetime.now(), epoch, success_rate))
227 | to_plot.append(success_rate)
228 |
229 |
230 | torch.save([self.o_norm.mean, self.o_norm.std, self.g_norm.mean, self.g_norm.std, self.actor_network.state_dict()], \
231 | self.model_path + '/model.pt')
232 |
233 | if MPI.COMM_WORLD.Get_rank() == 0:
234 | plt.plot(range(self.args.n_epochs * self.args.n_cycles * self.args.num_rollouts_per_mpi), reward_plot)
235 | plt.plot(range(self.args.n_epochs), to_plot)
236 |
237 | if self.args.env_name == 'FetchReach-v1':
238 | plt.xlabel('Episode')
239 | else:
240 | plt.xlabel('Epoch')
241 |
242 | plt.ylabel('Mean Success Rate')
243 |
244 | if self.args.per == True and self.args.her == True:
245 | plt.title("{} using DDPG + HER + PER".format(self.args.env_name))
246 | plt.savefig("{}_DDPG_HER_PER".format(self.args.env_name))
247 | np.save("{}_DDPG_HER_PER".format(self.args.env_name), to_plot)
248 |
249 | elif self.args.per == True and self.args.her == False:
250 | plt.title("{} using DDPG + PER".format(self.args.env_name))
251 | plt.savefig("{}_DDPG_PER".format(self.args.env_name))
252 | np.save("{}_DDPG_PER".format(self.args.env_name), to_plot)
253 |
254 | elif self.args.per == False and self.args.her == True:
255 | plt.title("{} using DDPG + HER".format(self.args.env_name))
256 | plt.savefig("{}_DDPG_HER".format(self.args.env_name))
257 | np.save("{}_DDPG_HER".format(self.args.env_name), to_plot)
258 |
259 | elif self.args.per == False and self.args.her == False:
260 | plt.title("{} using DDPG".format(self.args.env_name))
261 | plt.savefig("{}_DDPG".format(self.args.env_name))
262 | np.save("{}_DDPG".format(self.args.env_name), to_plot)
263 |
264 | plt.show()
265 |
266 |
267 |
268 | # pre_process the inputs
269 | def _preproc_inputs(self, obs, g):
270 | obs_norm = self.o_norm.normalize(obs)
271 | g_norm = self.g_norm.normalize(g)
272 |
273 | # concatenate the normalized observation and goal
274 | inputs = np.concatenate([obs_norm, g_norm])
275 | inputs = torch.tensor(inputs, dtype=torch.float32).unsqueeze(0)
276 | if self.args.cuda:
277 | inputs = inputs.cuda()
278 |
279 | return inputs
280 |
281 | # this function will choose action for the agent and do the exploration
282 | def _select_actions(self, pi):
283 | action = pi.cpu().numpy().squeeze()
284 |
285 | # add the gaussian
286 | action += self.args.noise_eps * self.env_params['action_max'] * np.random.randn(*action.shape)
287 | action = np.clip(action, -self.env_params['action_max'], self.env_params['action_max'])
288 |
289 | # random actions...
290 | random_actions = np.random.uniform(low=-self.env_params['action_max'], high=self.env_params['action_max'], \
291 | size=self.env_params['action'])
292 |
293 | # choose if use the random actions
294 | action += np.random.binomial(1, self.args.random_eps, 1)[0] * (random_actions - action)
295 |
296 | return action
297 |
298 | # update the normalizer
299 | def _update_normalizer(self, episode_batch):
300 | if self.args.her == False:
301 | if self.args.per == True:
302 | mb_obs, mb_ag, mb_g, mb_actions, mb_r, mb_p = episode_batch
303 | else:
304 | mb_obs, mb_ag, mb_g, mb_actions, mb_r = episode_batch
305 |
306 | elif self.args.her == True:
307 | if self.args.per == True:
308 | mb_obs, mb_ag, mb_g, mb_actions, mb_p = episode_batch
309 | else:
310 | mb_obs, mb_ag, mb_g, mb_actions = episode_batch
311 |
312 | mb_obs_next = mb_obs[:, 1:, :]
313 | mb_ag_next = mb_ag[:, 1:, :]
314 |
315 | # get the number of normalization transitions
316 | num_transitions = mb_actions.shape[1]
317 |
318 | # create the new buffer to store them
319 |
320 | if self.args.her == False:
321 | if self.args.per == True:
322 | buffer_temp = {'obs': mb_obs,
323 | 'ag': mb_ag,
324 | 'g': mb_g,
325 | 'actions': mb_actions,
326 | 'obs_next': mb_obs_next,
327 | 'ag_next': mb_ag_next,
328 | 'r' : mb_r,
329 | 'p' : mb_p
330 | }
331 | else:
332 | buffer_temp = {'obs': mb_obs,
333 | 'ag': mb_ag,
334 | 'g': mb_g,
335 | 'actions': mb_actions,
336 | 'obs_next': mb_obs_next,
337 | 'ag_next': mb_ag_next,
338 | 'r' : mb_r
339 | }
340 |
341 |
342 | elif self.args.her == True:
343 | if self.args.per == True:
344 | buffer_temp = {'obs': mb_obs,
345 | 'ag': mb_ag,
346 | 'g': mb_g,
347 | 'actions': mb_actions,
348 | 'obs_next': mb_obs_next,
349 | 'ag_next': mb_ag_next,
350 | 'p' : mb_p
351 | }
352 | else:
353 | buffer_temp = {'obs': mb_obs,
354 | 'ag': mb_ag,
355 | 'g': mb_g,
356 | 'actions': mb_actions,
357 | 'obs_next': mb_obs_next,
358 | 'ag_next': mb_ag_next,
359 | }
360 |
361 | transitions = self.her_module.sample_her_transitions(buffer_temp, num_transitions)
362 | obs, g = transitions['obs'], transitions['g']
363 |
364 | # pre process the obs and g
365 | transitions['obs'], transitions['g'] = self._preproc_og(obs, g)
366 |
367 | # update
368 | self.o_norm.update(transitions['obs'])
369 | self.g_norm.update(transitions['g'])
370 |
371 | # recompute the stats
372 | self.o_norm.recompute_stats()
373 | self.g_norm.recompute_stats()
374 |
375 | def _preproc_og(self, o, g):
376 |
377 | o = np.clip(o, -self.args.clip_obs, self.args.clip_obs)
378 | g = np.clip(g, -self.args.clip_obs, self.args.clip_obs)
379 |
380 | return o, g
381 |
382 | # soft update
383 | def _soft_update_target_network(self, target, source):
384 | for target_param, param in zip(target.parameters(), source.parameters()):
385 | target_param.data.copy_((1 - self.args.polyak) * param.data + self.args.polyak * target_param.data)
386 |
387 | # update the network
388 | def _update_network(self):
389 |
390 | # sample the episodes
391 | transitions = self.buffer.sample(self.args.batch_size)
392 |
393 | # pre-process the observation and goal
394 | o, o_next, g = transitions['obs'], transitions['obs_next'], transitions['g']
395 | transitions['obs'], transitions['g'] = self._preproc_og(o, g)
396 | transitions['obs_next'], transitions['g_next'] = self._preproc_og(o_next, g)
397 |
398 | # start to do the update
399 | obs_norm = self.o_norm.normalize(transitions['obs'])
400 | g_norm = self.g_norm.normalize(transitions['g'])
401 | inputs_norm = np.concatenate([obs_norm, g_norm], axis=1)
402 | obs_next_norm = self.o_norm.normalize(transitions['obs_next'])
403 | g_next_norm = self.g_norm.normalize(transitions['g_next'])
404 | inputs_next_norm = np.concatenate([obs_next_norm, g_next_norm], axis=1)
405 |
406 | # transfer them into the tensor
407 | inputs_norm_tensor = torch.tensor(inputs_norm, dtype=torch.float32)
408 | inputs_next_norm_tensor = torch.tensor(inputs_next_norm, dtype=torch.float32)
409 | actions_tensor = torch.tensor(transitions['actions'], dtype=torch.float32)
410 | r_tensor = torch.tensor(transitions['r'], dtype=torch.float32)
411 |
412 | if self.args.cuda:
413 | inputs_norm_tensor = inputs_norm_tensor.cuda()
414 | inputs_next_norm_tensor = inputs_next_norm_tensor.cuda()
415 | actions_tensor = actions_tensor.cuda()
416 | r_tensor = r_tensor.cuda()
417 |
418 | # calculate the target Q value function
419 | with torch.no_grad():
420 | # use the target networks to compute the next-state action
421 | # and the corresponding target Q value
422 | actions_next = self.actor_target_network(inputs_next_norm_tensor)
423 | q_next_value = self.critic_target_network(inputs_next_norm_tensor, actions_next)
424 | q_next_value = q_next_value.detach()
425 | target_q_value = r_tensor + self.args.gamma * q_next_value
426 | target_q_value = target_q_value.detach()
427 |
428 | # clip the q value
429 | clip_return = 1 / (1 - self.args.gamma)
430 | target_q_value = torch.clamp(target_q_value, -clip_return, 0)
431 |
432 | # the q loss
433 | real_q_value = self.critic_network(inputs_norm_tensor, actions_tensor)
434 | critic_loss = (target_q_value - real_q_value).pow(2).mean()
435 |
436 | # the actor loss
437 | actions_real = self.actor_network(inputs_norm_tensor)
438 | actor_loss = -self.critic_network(inputs_norm_tensor, actions_real).mean()
439 | actor_loss += self.args.action_l2 * (actions_real / self.env_params['action_max']).pow(2).mean()
440 |
441 | # start to update the network
442 | self.actor_optim.zero_grad()
443 | actor_loss.backward()
444 | sync_grads(self.actor_network)
445 | self.actor_optim.step()
446 |
447 | # update the critic_network
448 | self.critic_optim.zero_grad()
449 | critic_loss.backward()
450 | sync_grads(self.critic_network)
451 | self.critic_optim.step()
452 |
453 | # do the evaluation
454 | def _eval_agent(self):
455 | total_success_rate = []
456 | for _ in range(self.args.n_test_rollouts):
457 | per_success_rate = []
458 | observation = self.env.reset()
459 | obs = observation['observation']
460 | g = observation['desired_goal']
461 |
462 | for _ in range(self.env_params['max_timesteps']):
463 | with torch.no_grad():
464 | input_tensor = self._preproc_inputs(obs, g)
465 | pi = self.actor_network(input_tensor)
466 |
467 | # convert the actions
468 | actions = pi.detach().cpu().numpy().squeeze()
469 |
470 | observation_new, _, _, info = self.env.step(actions)
471 | obs = observation_new['observation']
472 | g = observation_new['desired_goal']
473 | per_success_rate.append(info['is_success'])
474 |
475 | total_success_rate.append(per_success_rate)
476 |
477 | total_success_rate = np.array(total_success_rate)
478 | local_success_rate = np.mean(total_success_rate[:, -1])
479 | global_success_rate = MPI.COMM_WORLD.allreduce(local_success_rate, op=MPI.SUM)
480 |
481 | return global_success_rate / MPI.COMM_WORLD.Get_size()
482 |
483 |
--------------------------------------------------------------------------------
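
For reference, a compact restatement (a sketch, not repository code) of the two update rules buried in `_update_network` and `_soft_update_target_network` above, using the default hyper-parameters from `arguments.py`:

```
import torch

gamma, polyak = 0.98, 0.95           # defaults for --gamma and --polyak
clip_return = 1.0 / (1.0 - gamma)    # = 50, so the clamp below matches --clip-return's default

def critic_target(r, q_next):
    # Bellman target, clamped to [-1/(1-gamma), 0] as in _update_network
    return torch.clamp(r + gamma * q_next, -clip_return, 0)

def soft_update(target, source):
    # target <- polyak * target + (1 - polyak) * source, as in _soft_update_target_network
    for t_param, param in zip(target.parameters(), source.parameters()):
        t_param.data.copy_((1 - polyak) * param.data + polyak * t_param.data)
```
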