├── plot_data ├── fetch_push_final.mp4 ├── FetchPush-v1_DDPG.npy ├── FetchReach-v1_DDPG.npy ├── FetchSlide-v1_DDPG.npy ├── FetchPush-v1_DDPG_HER.npy ├── FetchPush-v1_DDPG_PER.npy ├── FetchReach-v1_DDPG_HER.npy ├── FetchReach-v1_DDPG_PER.npy ├── FetchSlide-v1_DDPG_HER.npy ├── FetchSlide-v1_DDPG_PER.npy ├── FetchPickAndPlace-v1_DDPG.npy ├── FetchPush-v1_DDPG_HER_PER.npy ├── FetchSlide-v1_DDPG_HER_PER.npy ├── FetchPickAndPlace-v1_DDPG_HER.npy ├── FetchPickAndPlace-v1_DDPG_PER.npy ├── FetchPush-v1_DDPG_HER_PER_10.npy ├── FetchReach-v1_DDPG_HER_PER_10.npy ├── FetchReach-v1_DDPG_HER_PER_2.npy ├── FetchReach-v1_DDPG_HER_PER_4.npy ├── FetchReach-v1_DDPG_HER_PER_6.npy ├── FetchReach-v1_DDPG_HER_PER_8.npy ├── FetchSlide-v1_DDPG_HER_PER_10.npy └── FetchPickAndPlace-v1_DDPG_HER_PER.npy ├── plots ├── all_plots.png ├── all_plots_fp.png ├── all_plots_fr.png ├── all_plots_fs.png ├── alpha_plots_fp.png └── alpha_plots_fs.png ├── models ├── fetch_push_model.pt ├── fetch_reach_model.pt ├── fetch_slide_model.pt └── fetch_picknplace_model.pt ├── train.py ├── models.py ├── README.md ├── her.py ├── demo.py ├── utils.py ├── arguments.py ├── viz.py ├── normalizer.py ├── replay_buffer.py └── ddpg_agent.py /plot_data/fetch_push_final.mp4: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /plots/all_plots.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sush1996/DDPG_Fetch/HEAD/plots/all_plots.png -------------------------------------------------------------------------------- /plots/all_plots_fp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sush1996/DDPG_Fetch/HEAD/plots/all_plots_fp.png -------------------------------------------------------------------------------- /plots/all_plots_fr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sush1996/DDPG_Fetch/HEAD/plots/all_plots_fr.png -------------------------------------------------------------------------------- /plots/all_plots_fs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sush1996/DDPG_Fetch/HEAD/plots/all_plots_fs.png -------------------------------------------------------------------------------- /plots/alpha_plots_fp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sush1996/DDPG_Fetch/HEAD/plots/alpha_plots_fp.png -------------------------------------------------------------------------------- /plots/alpha_plots_fs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sush1996/DDPG_Fetch/HEAD/plots/alpha_plots_fs.png -------------------------------------------------------------------------------- /models/fetch_push_model.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sush1996/DDPG_Fetch/HEAD/models/fetch_push_model.pt -------------------------------------------------------------------------------- /models/fetch_reach_model.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sush1996/DDPG_Fetch/HEAD/models/fetch_reach_model.pt 
-------------------------------------------------------------------------------- /models/fetch_slide_model.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sush1996/DDPG_Fetch/HEAD/models/fetch_slide_model.pt -------------------------------------------------------------------------------- /plot_data/FetchPush-v1_DDPG.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sush1996/DDPG_Fetch/HEAD/plot_data/FetchPush-v1_DDPG.npy -------------------------------------------------------------------------------- /models/fetch_picknplace_model.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sush1996/DDPG_Fetch/HEAD/models/fetch_picknplace_model.pt -------------------------------------------------------------------------------- /plot_data/FetchReach-v1_DDPG.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sush1996/DDPG_Fetch/HEAD/plot_data/FetchReach-v1_DDPG.npy -------------------------------------------------------------------------------- /plot_data/FetchSlide-v1_DDPG.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sush1996/DDPG_Fetch/HEAD/plot_data/FetchSlide-v1_DDPG.npy -------------------------------------------------------------------------------- /plot_data/FetchPush-v1_DDPG_HER.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sush1996/DDPG_Fetch/HEAD/plot_data/FetchPush-v1_DDPG_HER.npy -------------------------------------------------------------------------------- /plot_data/FetchPush-v1_DDPG_PER.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sush1996/DDPG_Fetch/HEAD/plot_data/FetchPush-v1_DDPG_PER.npy -------------------------------------------------------------------------------- /plot_data/FetchReach-v1_DDPG_HER.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sush1996/DDPG_Fetch/HEAD/plot_data/FetchReach-v1_DDPG_HER.npy -------------------------------------------------------------------------------- /plot_data/FetchReach-v1_DDPG_PER.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sush1996/DDPG_Fetch/HEAD/plot_data/FetchReach-v1_DDPG_PER.npy -------------------------------------------------------------------------------- /plot_data/FetchSlide-v1_DDPG_HER.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sush1996/DDPG_Fetch/HEAD/plot_data/FetchSlide-v1_DDPG_HER.npy -------------------------------------------------------------------------------- /plot_data/FetchSlide-v1_DDPG_PER.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sush1996/DDPG_Fetch/HEAD/plot_data/FetchSlide-v1_DDPG_PER.npy -------------------------------------------------------------------------------- /plot_data/FetchPickAndPlace-v1_DDPG.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sush1996/DDPG_Fetch/HEAD/plot_data/FetchPickAndPlace-v1_DDPG.npy 
-------------------------------------------------------------------------------- /plot_data/FetchPush-v1_DDPG_HER_PER.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sush1996/DDPG_Fetch/HEAD/plot_data/FetchPush-v1_DDPG_HER_PER.npy -------------------------------------------------------------------------------- /plot_data/FetchSlide-v1_DDPG_HER_PER.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sush1996/DDPG_Fetch/HEAD/plot_data/FetchSlide-v1_DDPG_HER_PER.npy -------------------------------------------------------------------------------- /plot_data/FetchPickAndPlace-v1_DDPG_HER.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sush1996/DDPG_Fetch/HEAD/plot_data/FetchPickAndPlace-v1_DDPG_HER.npy -------------------------------------------------------------------------------- /plot_data/FetchPickAndPlace-v1_DDPG_PER.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sush1996/DDPG_Fetch/HEAD/plot_data/FetchPickAndPlace-v1_DDPG_PER.npy -------------------------------------------------------------------------------- /plot_data/FetchPush-v1_DDPG_HER_PER_10.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sush1996/DDPG_Fetch/HEAD/plot_data/FetchPush-v1_DDPG_HER_PER_10.npy -------------------------------------------------------------------------------- /plot_data/FetchReach-v1_DDPG_HER_PER_10.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sush1996/DDPG_Fetch/HEAD/plot_data/FetchReach-v1_DDPG_HER_PER_10.npy -------------------------------------------------------------------------------- /plot_data/FetchReach-v1_DDPG_HER_PER_2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sush1996/DDPG_Fetch/HEAD/plot_data/FetchReach-v1_DDPG_HER_PER_2.npy -------------------------------------------------------------------------------- /plot_data/FetchReach-v1_DDPG_HER_PER_4.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sush1996/DDPG_Fetch/HEAD/plot_data/FetchReach-v1_DDPG_HER_PER_4.npy -------------------------------------------------------------------------------- /plot_data/FetchReach-v1_DDPG_HER_PER_6.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sush1996/DDPG_Fetch/HEAD/plot_data/FetchReach-v1_DDPG_HER_PER_6.npy -------------------------------------------------------------------------------- /plot_data/FetchReach-v1_DDPG_HER_PER_8.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sush1996/DDPG_Fetch/HEAD/plot_data/FetchReach-v1_DDPG_HER_PER_8.npy -------------------------------------------------------------------------------- /plot_data/FetchSlide-v1_DDPG_HER_PER_10.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sush1996/DDPG_Fetch/HEAD/plot_data/FetchSlide-v1_DDPG_HER_PER_10.npy -------------------------------------------------------------------------------- /plot_data/FetchPickAndPlace-v1_DDPG_HER_PER.npy: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/sush1996/DDPG_Fetch/HEAD/plot_data/FetchPickAndPlace-v1_DDPG_HER_PER.npy -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import gym 3 | import os, sys 4 | from arguments import get_args 5 | from mpi4py import MPI 6 | from subprocess import CalledProcessError 7 | from ddpg_agent import ddpg_agent 8 | 9 | """ 10 | train the agent, the MPI part code is copy from openai baselines(https://github.com/openai/baselines/blob/master/baselines/her) 11 | 12 | """ 13 | def get_env_params(env): 14 | 15 | obs = env.reset() 16 | 17 | # close the environment 18 | params = {'obs': obs['observation'].shape[0], 19 | 'goal': obs['desired_goal'].shape[0], 20 | 'action': env.action_space.shape[0], 21 | 'action_max': env.action_space.high[0], 22 | } 23 | params['max_timesteps'] = env._max_episode_steps 24 | 25 | return params 26 | 27 | def launch(args): 28 | # create the ddpg_agent 29 | env = gym.make(args.env_name) 30 | env.reward_type = 'dense' 31 | # get the environment parameters 32 | env_params = get_env_params(env) 33 | 34 | # create the ddpg agent to interact with the environment 35 | ddpg_trainer = ddpg_agent(args, env, env_params) 36 | ddpg_trainer.learn() 37 | 38 | if __name__ == '__main__': 39 | 40 | # take the configuration for the HER 41 | os.environ['OMP_NUM_THREADS'] = '1' 42 | os.environ['MKL_NUM_THREADS'] = '1' 43 | os.environ['IN_MPI'] = '1' 44 | 45 | # get the params 46 | args = get_args() 47 | launch(args) 48 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | """ 6 | the input x in both networks should be [o, g], where o is the observation and g is the goal. 
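(For example, ddpg_agent._preproc_inputs and demo.process_inputs build this input by concatenating the normalized observation and goal with np.concatenate before converting it to a torch tensor.)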
7 | 8 | """ 9 | 10 | # define the actor network 11 | class actor(nn.Module): 12 | def __init__(self, env_params): 13 | super(actor, self).__init__() 14 | 15 | self.max_action = env_params['action_max'] 16 | self.fc1 = nn.Linear(env_params['obs'] + env_params['goal'], 256) 17 | self.fc2 = nn.Linear(256, 256) 18 | self.fc3 = nn.Linear(256, 256) 19 | self.action_out = nn.Linear(256, env_params['action']) 20 | 21 | def forward(self, x): 22 | 23 | x = F.relu(self.fc1(x)) 24 | x = F.relu(self.fc2(x)) 25 | x = F.relu(self.fc3(x)) 26 | actions = self.max_action * torch.tanh(self.action_out(x)) 27 | 28 | return actions 29 | 30 | class critic(nn.Module): 31 | def __init__(self, env_params): 32 | super(critic, self).__init__() 33 | 34 | self.max_action = env_params['action_max'] 35 | self.fc1 = nn.Linear(env_params['obs'] + env_params['goal'] + env_params['action'], 256) 36 | self.fc2 = nn.Linear(256, 256) 37 | self.fc3 = nn.Linear(256, 256) 38 | self.q_out = nn.Linear(256, 1) 39 | 40 | def forward(self, x, actions): 41 | x = torch.cat([x, actions / self.max_action], dim=1) 42 | x = F.relu(self.fc1(x)) 43 | x = F.relu(self.fc2(x)) 44 | x = F.relu(self.fc3(x)) 45 | q_value = self.q_out(x) 46 | 47 | return q_value 48 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DDPG_Fetch 2 | Exploring the performance of Prioritized Experience Replay (PER) with the DDPG+HER scheme on the Fetch Robotics Environment 3 | 4 | Plots for Mean Success Rates for different Fetch Environments 5 | 6 | 

7 | 8 | 9 | 10 |

11 | 12 |

13 | 14 | 15 |

16 | 17 | Performance Plots when varying the alpha parameter of PER (see the note after the plots below) 18 | 

19 | 20 | 21 |

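For reference, the alpha exponent controls how strongly sampling is skewed toward high-priority transitions. Below is a minimal sketch (the function name is illustrative, not part of the repo) of the proportional sampling implemented in `ddpg_agent.py` and `her.py`: each transition's priority is `(|TD error| + eps) ** alpha`, and the normalized priorities are the probabilities passed to `np.random.choice`, so `alpha = 0` recovers uniform replay and `alpha = 1` samples fully proportionally to priority.

```python
import numpy as np

def per_sampling_probs(td_errors, alpha, eps=0.1):
    # priority as computed in ddpg_agent.py: (|TD error| + eps) ** alpha
    p = (np.abs(td_errors) + eps) ** alpha
    # normalisation as done in her.py before np.random.choice(T, p=p_norm)
    return p / p.sum()
```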
22 | 23 | * Correction: The plot on the right is for FetchSlide but has been mistakenly labelled as FetchPush 24 | 25 | 26 | 27 | 28 | Adding PER, together with fine-tuning its alpha parameter, boosts the performance of the DDPG+HER scheme. 29 | 30 | PER can be integrated into the DDPG+HER framework in several ways; a more careful integration could yield larger performance gains. 31 | (The integration of PER in this code isn't perfect; it is just something I tried out over a weekend.) 32 | 33 | Use the command below to start training. (If you get an ```EXPORT LIBRARY.. .bashrc``` error, avoid running with sudo.) 34 | 35 | ``` 36 | mpirun -np 19 python3 train.py 37 | ``` 38 | -------------------------------------------------------------------------------- /her.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class her_sampler: 4 | def __init__(self, replay_strategy, replay_k, reward_func=None, her=False, per=False): 5 | self.replay_strategy = replay_strategy 6 | self.replay_k = replay_k 7 | self.her = her 8 | self.per = per 9 | 10 | if self.replay_strategy == 'future': 11 | self.future_p = 1 - (1. / (1 + replay_k)) 12 | else: 13 | self.future_p = 0 14 | 15 | self.reward_func = reward_func 16 | 17 | def sample_her_transitions(self, episode_batch, batch_size_in_transitions): 18 | 19 | T = episode_batch['actions'].shape[1] 20 | rollout_batch_size = episode_batch['actions'].shape[0] 21 | batch_size = batch_size_in_transitions 22 | 23 | # select which rollouts and which timesteps to use 24 | episode_idxs = np.random.randint(0, rollout_batch_size, batch_size) 25 | 26 | if self.per == False: 27 | t_samples = np.random.randint(T, size=batch_size) 28 | else: 29 | p = episode_batch['p'] 30 | sum_p = sum(p[0]) 31 | p_norm = [p_elem/sum_p for p_elem in p[0]] 32 | t_samples = np.random.choice(T, size=batch_size, p=p_norm) 33 | 34 | transitions = {key: episode_batch[key][episode_idxs, t_samples].copy() for key in episode_batch.keys()} 35 | 36 | if self.her==True: 37 | 38 | # her idx 39 | her_indexes = np.where(np.random.uniform(size=batch_size) < self.future_p) 40 | future_offset = np.random.uniform(size=batch_size) * (T - t_samples) 41 | future_offset = future_offset.astype(int) 42 | future_t = (t_samples + 1 + future_offset)[her_indexes] 43 | 44 | # replace the desired goal with a future achieved goal 45 | future_ag = episode_batch['ag'][episode_idxs[her_indexes], future_t] 46 | transitions['g'][her_indexes] = future_ag 47 | 48 | transitions['r'] = np.expand_dims(self.reward_func(transitions['ag_next'], transitions['g'], None), 1) 49 | 50 | # reshape the transitions back to (batch_size, ...) 51 | transitions = {k: transitions[k].reshape(batch_size, *transitions[k].shape[1:]) for k in transitions.keys()} 52 | 53 | 54 | return transitions 55 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from models import actor 3 | from arguments import get_args 4 | import gym 5 | import numpy as np 6 | 7 | # process the inputs 8 | def process_inputs(o, g, o_mean, o_std, g_mean, g_std, args): 9 | o_clip = np.clip(o, -args.clip_obs, args.clip_obs) 10 | g_clip = np.clip(g, -args.clip_obs, args.clip_obs) 11 | o_norm = np.clip((o_clip - o_mean) / (o_std), -args.clip_range, args.clip_range) 12 | g_norm = np.clip((g_clip - g_mean) / (g_std), -args.clip_range, args.clip_range) 13 | inputs = np.concatenate([o_norm, g_norm]) 14 | inputs = torch.tensor(inputs, 
dtype=torch.float32) 15 | 16 | return inputs 17 | 18 | if __name__ == '__main__': 19 | args = get_args() 20 | 21 | # load the model param 22 | model_path = args.save_dir + args.env_name + '/model.pt' 23 | o_mean, o_std, g_mean, g_std, model = torch.load(model_path, map_location=lambda storage, loc: storage) 24 | 25 | # create the environment 26 | env = gym.make(args.env_name) 27 | # get the env param 28 | observation = env.reset() 29 | 30 | # get the environment params 31 | env_params = {'obs': observation['observation'].shape[0], 32 | 'goal': observation['desired_goal'].shape[0], 33 | 'action': env.action_space.shape[0], 34 | 'action_max': env.action_space.high[0], 35 | } 36 | 37 | # create the actor network 38 | actor_network = actor(env_params) 39 | actor_network.load_state_dict(model) 40 | actor_network.eval() 41 | 42 | for i in range(args.demo_length): 43 | observation = env.reset() 44 | 45 | # start to do the demo 46 | obs = observation['observation'] 47 | g = observation['desired_goal'] 48 | 49 | for t in range(env._max_episode_steps): 50 | env.render() 51 | inputs = process_inputs(obs, g, o_mean, o_std, g_mean, g_std, args) 52 | 53 | with torch.no_grad(): 54 | pi = actor_network(inputs) 55 | action = pi.detach().numpy().squeeze() 56 | 57 | # put actions into the environment 58 | observation_new, reward, _, info = env.step(action) 59 | obs = observation_new['observation'] 60 | 61 | print('the episode is: {}, is success: {}'.format(i, info['is_success'])) 62 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from mpi4py import MPI 2 | import numpy as np 3 | import torch 4 | 5 | # sync_networks across the different cores 6 | def sync_networks(network): 7 | """ 8 | netowrk is the network you want to sync 9 | 10 | """ 11 | comm = MPI.COMM_WORLD 12 | flat_params, params_shape = _get_flat_params(network) 13 | comm.Bcast(flat_params, root=0) 14 | # set the flat params back to the network 15 | _set_flat_params(network, params_shape, flat_params) 16 | 17 | # get the flat params from the network 18 | def _get_flat_params(network): 19 | param_shape = {} 20 | flat_params = None 21 | for key_name, value in network.named_parameters(): 22 | param_shape[key_name] = value.detach().numpy().shape 23 | if flat_params is None: 24 | flat_params = value.detach().numpy().flatten() 25 | else: 26 | flat_params = np.append(flat_params, value.detach().numpy().flatten()) 27 | return flat_params, param_shape 28 | 29 | # set the params from the network 30 | def _set_flat_params(network, params_shape, params): 31 | pointer = 0 32 | for key_name, values in network.named_parameters(): 33 | 34 | # get the length of the parameters 35 | len_param = np.prod(params_shape[key_name]) 36 | copy_params = params[pointer:pointer + len_param].reshape(params_shape[key_name]) 37 | copy_params = torch.tensor(copy_params) 38 | 39 | # copy the params 40 | values.data.copy_(copy_params.data) 41 | 42 | # update the pointer 43 | pointer += len_param 44 | 45 | # sync the networks 46 | def sync_grads(network): 47 | flat_grads, grads_shape = _get_flat_grads(network) 48 | comm = MPI.COMM_WORLD 49 | global_grads = np.zeros_like(flat_grads) 50 | comm.Allreduce(flat_grads, global_grads, op=MPI.SUM) 51 | _set_flat_grads(network, grads_shape, global_grads) 52 | 53 | def _set_flat_grads(network, grads_shape, flat_grads): 54 | pointer = 0 55 | 56 | for key_name, value in network.named_parameters(): 57 | len_grads = 
np.prod(grads_shape[key_name]) 58 | copy_grads = flat_grads[pointer:pointer + len_grads].reshape(grads_shape[key_name]) 59 | copy_grads = torch.tensor(copy_grads) 60 | 61 | # copy the grads 62 | value.grad.data.copy_(copy_grads.data) 63 | pointer += len_grads 64 | 65 | def _get_flat_grads(network): 66 | grads_shape = {} 67 | flat_grads = None 68 | 69 | for key_name, value in network.named_parameters(): 70 | grads_shape[key_name] = value.grad.data.cpu().numpy().shape 71 | 72 | if flat_grads is None: 73 | flat_grads = value.grad.data.cpu().numpy().flatten() 74 | else: 75 | flat_grads = np.append(flat_grads, value.grad.data.cpu().numpy().flatten()) 76 | 77 | return flat_grads, grads_shape 78 | -------------------------------------------------------------------------------- /arguments.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | """ 4 | Here are the param for the training 5 | 6 | """ 7 | 8 | def get_args(): 9 | parser = argparse.ArgumentParser() 10 | # the environment setting 11 | parser.add_argument('--env-name', type=str, default='FetchReach-v1', help='the environment name') 12 | parser.add_argument('--n-epochs', type=int, default=50, help='the number of epochs to train the agent') 13 | parser.add_argument('--n-cycles', type=int, default=1, help='the times to collect samples per epoch') 14 | parser.add_argument('--n-batches', type=int, default=40, help='the times to update the network') 15 | parser.add_argument('--save-interval', type=int, default=5, help='the interval that save the trajectory') 16 | parser.add_argument('--seed', type=int, default=123, help='random seed') 17 | parser.add_argument('--num-workers', type=int, default=1, help='the number of cpus to collect samples') 18 | parser.add_argument('--replay-strategy', type=str, default='future', help='the HER strategy') 19 | parser.add_argument('--clip-return', type=float, default=50, help='if clip the returns') 20 | parser.add_argument('--save-dir', type=str, default='saved_models/', help='the path to save the models') 21 | parser.add_argument('--noise-eps', type=float, default=0.2, help='noise eps') 22 | parser.add_argument('--random-eps', type=float, default=0.3, help='random eps') 23 | parser.add_argument('--buffer-size', type=int, default=int(1e6), help='the size of the buffer') 24 | parser.add_argument('--replay-k', type=int, default=4, help='ratio to be replace') 25 | parser.add_argument('--clip-obs', type=float, default=200, help='the clip ratio') 26 | parser.add_argument('--batch-size', type=int, default=256, help='the sample batch size') 27 | parser.add_argument('--gamma', type=float, default=0.98, help='the discount factor') 28 | parser.add_argument('--action-l2', type=float, default=1, help='l2 reg') 29 | parser.add_argument('--lr-actor', type=float, default=0.001, help='the learning rate of the actor') 30 | parser.add_argument('--lr-critic', type=float, default=0.001, help='the learning rate of the critic') 31 | parser.add_argument('--polyak', type=float, default=0.95, help='the average coefficient') 32 | parser.add_argument('--n-test-rollouts', type=int, default=10, help='the number of tests') 33 | parser.add_argument('--clip-range', type=float, default=5, help='the clip range') 34 | parser.add_argument('--demo-length', type=int, default=10, help='the demo length') 35 | parser.add_argument('--cuda', action='store_true', help='if use gpu do the acceleration') 36 | parser.add_argument('--num-rollouts-per-mpi', type=int, default=1, help='the rollouts per 
mpi') 37 | parser.add_argument('--her', type=bool, default=False, help='is HER True or False') 38 | parser.add_argument('--per', type=bool, default=False, help='is PER True or False') 39 | 40 | 41 | args = parser.parse_args() 42 | 43 | return args 44 | -------------------------------------------------------------------------------- /viz.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import seaborn as sns 3 | import matplotlib.pyplot as plt 4 | sns.set() 5 | 6 | #FetchReach Plots 7 | 8 | env = 'FetchSlide-v1' 9 | 10 | ''' 11 | ddpg = np.load('{}_DDPG.npy'.format(env)) 12 | ddpg_per = np.load('{}_DDPG_PER.npy'.format(env)) 13 | ddpg_her = np.load('{}_DDPG_HER.npy'.format(env)) 14 | ''' 15 | 16 | 17 | ddpg_her_per_0 = np.load('{}_DDPG_HER.npy'.format(env)) 18 | ddpg_her_per_5 = np.load('{}_DDPG_HER_PER.npy'.format(env)) 19 | ddpg_her_per_10 = np.load('{}_DDPG_HER_PER_10.npy'.format(env)) 20 | #ddpg_her_per = np.load('{}_DDPG_HER_PER.npy'.format(env)) 21 | 22 | ''' 23 | ep_n_0 = np.where(ddpg_her_per_0 == 1.0)[0][0]#[i for i in ddpg_her_per_0 if i == 1.0][0] 24 | ep_n_2 = np.where(ddpg_her_per_2 == 1.0)[0][0]#[i for i in ddpg_her_per_2 if i == 1.0][0] 25 | ep_n_4 = np.where(ddpg_her_per_4 == 1.0)[0][0]#[i for i in ddpg_her_per_4 if i == 1.0][0] 26 | ep_n_6 = np.where(ddpg_her_per_6 == 1.0)[0][0]#[i for i in ddpg_her_per_6 if i == 1.0][0] 27 | ep_n_8 = np.where(ddpg_her_per_8 == 1.0)[0][0]#[i for i in ddpg_her_per_8 if i == 1.0][0] 28 | ep_n_10 = np.where(ddpg_her_per_10 == 1.0)[0][0]#[i for i in ddpg_her_per_10 if i == 1.0][0] 29 | 30 | 31 | plt.bar(np.arange(len([0.0,0.2,0.4,0.6,0.8,1.0])), [ep_n_0, ep_n_2, ep_n_4, ep_n_6, ep_n_8, ep_n_10]) 32 | plt.show() 33 | ''' 34 | n = range(50)#range(len(ddpg)) 35 | plt.plot(n, ddpg_her_per_0, label='alpha = 0.0') 36 | #plt.plot(n, ddpg_her_per_2, label='alpha = 0.2') 37 | plt.plot(n, ddpg_her_per_5, label='alpha = 0.5') 38 | ''' 39 | plt.plot(n, ddpg_her_per_6, label='alpha = 0.6') 40 | plt.plot(n, ddpg_her_per_8, label='alpha = 0.8') 41 | ''' 42 | plt.plot(n, ddpg_her_per_10, label='alpha = 1.0') 43 | plt.xlabel('Epoch') 44 | plt.ylabel('Mean Success Rate') 45 | plt.title('Mean Success Rate variation with alpha for DDPG+HER+PER on FetchPush-v1') 46 | plt.legend() 47 | plt.savefig('alpha_plots_fr') 48 | plt.show() 49 | 50 | 51 | ''' 52 | #Plot 1 : All Plots 53 | plt.plot(n, ddpg, label='DDPG') 54 | plt.plot(n, ddpg_per, label='DDPG + PER') 55 | plt.plot(n, ddpg_her, label='DDPG + HER') 56 | plt.plot(n, ddpg_her_per, label='DDPG + HER + PER') 57 | plt.xlabel('Epoch') 58 | plt.ylabel('Mean Success Rate') 59 | plt.title('Mean Success Rate Plots for {}'.format(env)) 60 | plt.legend() 61 | plt.savefig("{}_plots/all_plots".format(env)) 62 | plt.show() 63 | 64 | #Plot 2 : DDPG and PER variants 65 | plt.plot(n, ddpg, label='DDPG') 66 | plt.plot(n, ddpg_per, label='DDPG + PER') 67 | plt.xlabel('Epoch') 68 | plt.ylabel('Mean Success Rate') 69 | plt.title('Mean Success Rate Plots for {}'.format(env)) 70 | plt.legend() 71 | plt.savefig("{}_plots/ddpg_per_plots".format(env)) 72 | plt.show() 73 | 74 | #Plot 3: HER variants 75 | plt.plot(n, ddpg_her, label='DDPG + HER') 76 | plt.plot(n, ddpg_her_per, label='DDPG + HER + PER') 77 | plt.xlabel('Epoch') 78 | plt.ylabel('Mean Success Rate') 79 | plt.title('Mean Success Rate Plots for {}'.format(env)) 80 | plt.legend() 81 | plt.savefig("{}_plots/ddpg_her_per_plots".format(env)) 82 | plt.show() 83 | ''' 
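A note on the `--her` and `--per` flags defined in `arguments.py` above: because they are declared with `type=bool`, argparse applies Python's `bool()` to the raw string, and any non-empty string (including `'False'`) is truthy. The snippet below is a minimal illustration of this standard argparse behaviour (not code from the repo); in practice the flags are enabled by passing any value (e.g. `--her True`) and disabled by omitting them.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--her', type=bool, default=False)

print(parser.parse_args([]).her)                  # False (default)
print(parser.parse_args(['--her', 'True']).her)   # True
print(parser.parse_args(['--her', 'False']).her)  # True - bool('False') is still truthy
```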
-------------------------------------------------------------------------------- /normalizer.py: -------------------------------------------------------------------------------- 1 | import threading 2 | import numpy as np 3 | from mpi4py import MPI 4 | 5 | class normalizer: 6 | def __init__(self, size, eps=1e-2, default_clip_range=np.inf): 7 | 8 | self.size = size 9 | self.eps = eps 10 | self.default_clip_range = default_clip_range 11 | 12 | # some local information 13 | self.local_sum = np.zeros(self.size, np.float32) 14 | self.local_sumsq = np.zeros(self.size, np.float32) 15 | self.local_count = np.zeros(1, np.float32) 16 | 17 | # get the total sum sumsq and sum count 18 | self.total_sum = np.zeros(self.size, np.float32) 19 | self.total_sumsq = np.zeros(self.size, np.float32) 20 | self.total_count = np.ones(1, np.float32) 21 | 22 | # get the mean and std 23 | self.mean = np.zeros(self.size, np.float32) 24 | self.std = np.ones(self.size, np.float32) 25 | 26 | # thread locker 27 | self.lock = threading.Lock() 28 | 29 | # update the parameters of the normalizer 30 | def update(self, v): 31 | v = v.reshape(-1, self.size) 32 | 33 | # do the computing 34 | with self.lock: 35 | self.local_sum += v.sum(axis=0) 36 | self.local_sumsq += (np.square(v)).sum(axis=0) 37 | self.local_count[0] += v.shape[0] 38 | 39 | # sync the parameters across the cpus 40 | def sync(self, local_sum, local_sumsq, local_count): 41 | local_sum[...] = self._mpi_average(local_sum) 42 | local_sumsq[...] = self._mpi_average(local_sumsq) 43 | local_count[...] = self._mpi_average(local_count) 44 | 45 | return local_sum, local_sumsq, local_count 46 | 47 | def recompute_stats(self): 48 | with self.lock: 49 | local_count = self.local_count.copy() 50 | local_sum = self.local_sum.copy() 51 | local_sumsq = self.local_sumsq.copy() 52 | # reset 53 | self.local_count[...] = 0 54 | self.local_sum[...] = 0 55 | self.local_sumsq[...] 
= 0 56 | 57 | # synrc the stats 58 | sync_sum, sync_sumsq, sync_count = self.sync(local_sum, local_sumsq, local_count) 59 | 60 | # update the total stuff 61 | self.total_sum += sync_sum 62 | self.total_sumsq += sync_sumsq 63 | self.total_count += sync_count 64 | 65 | # calculate the new mean and std 66 | self.mean = self.total_sum / self.total_count 67 | self.std = np.sqrt(np.maximum(np.square(self.eps), (self.total_sumsq / self.total_count) - np.square(self.total_sum / self.total_count))) 68 | 69 | # average across the cpu's data 70 | def _mpi_average(self, x): 71 | buf = np.zeros_like(x) 72 | MPI.COMM_WORLD.Allreduce(x, buf, op=MPI.SUM) 73 | buf /= MPI.COMM_WORLD.Get_size() 74 | 75 | return buf 76 | 77 | # normalize the observation 78 | def normalize(self, v, clip_range=None): 79 | if clip_range is None: 80 | clip_range = self.default_clip_range 81 | 82 | return np.clip((v - self.mean) / (self.std), -clip_range, clip_range) 83 | -------------------------------------------------------------------------------- /replay_buffer.py: -------------------------------------------------------------------------------- 1 | import threading 2 | import numpy as np 3 | 4 | """ 5 | the replay buffer here is basically from the openai baselines code 6 | 7 | """ 8 | class replay_buffer: 9 | def __init__(self, env_params, buffer_size, sample_func, her, per): 10 | self.env_params = env_params 11 | self.T = env_params['max_timesteps'] 12 | self.size = buffer_size // self.T 13 | 14 | # memory management 15 | self.current_size = 0 16 | self.n_transitions_stored = 0 17 | self.sample_func = sample_func 18 | self.per = per 19 | self.her = her 20 | 21 | # create the buffer to store info 22 | if self.her == False: 23 | if self.per == True: 24 | self.buffers = {'obs': np.empty([self.size, self.T + 1, self.env_params['obs']]), 25 | 'ag': np.empty([self.size, self.T + 1, self.env_params['goal']]), 26 | 'g': np.empty([self.size, self.T, self.env_params['goal']]), 27 | 'actions': np.empty([self.size, self.T, self.env_params['action']]), 28 | 'r': np.empty([self.size, self.T]), 29 | 'p': np.empty([self.size, self.T]) 30 | } 31 | else: 32 | self.buffers = {'obs': np.empty([self.size, self.T + 1, self.env_params['obs']]), 33 | 'ag': np.empty([self.size, self.T + 1, self.env_params['goal']]), 34 | 'g': np.empty([self.size, self.T, self.env_params['goal']]), 35 | 'actions': np.empty([self.size, self.T, self.env_params['action']]), 36 | 'r': np.empty([self.size, self.T]) 37 | } 38 | 39 | elif self.her == True: 40 | if self.per == True: 41 | self.buffers = {'obs': np.empty([self.size, self.T + 1, self.env_params['obs']]), 42 | 'ag': np.empty([self.size, self.T + 1, self.env_params['goal']]), 43 | 'g': np.empty([self.size, self.T, self.env_params['goal']]), 44 | 'actions': np.empty([self.size, self.T, self.env_params['action']]), 45 | 'p': np.empty([self.size, self.T]) 46 | } 47 | else: 48 | self.buffers = {'obs': np.empty([self.size, self.T + 1, self.env_params['obs']]), 49 | 'ag': np.empty([self.size, self.T + 1, self.env_params['goal']]), 50 | 'g': np.empty([self.size, self.T, self.env_params['goal']]), 51 | 'actions': np.empty([self.size, self.T, self.env_params['action']]), 52 | } 53 | 54 | # thread lock 55 | self.lock = threading.Lock() 56 | 57 | # store the episode 58 | def store_episode(self, episode_batch): 59 | 60 | if self.her == False: 61 | if self.per == True: 62 | mb_obs, mb_ag, mb_g, mb_actions, mb_r, mb_p = episode_batch 63 | else: 64 | mb_obs, mb_ag, mb_g, mb_actions, mb_r = episode_batch 65 | else: 66 | if 
self.per == True: 67 | mb_obs, mb_ag, mb_g, mb_actions, mb_p = episode_batch 68 | else: 69 | mb_obs, mb_ag, mb_g, mb_actions = episode_batch 70 | 71 | batch_size = mb_obs.shape[0] 72 | 73 | with self.lock: 74 | idxs = self._get_storage_idx(inc=batch_size) 75 | 76 | # store the informations 77 | self.buffers['obs'][idxs] = mb_obs 78 | self.buffers['ag'][idxs] = mb_ag 79 | self.buffers['g'][idxs] = mb_g 80 | self.buffers['actions'][idxs] = mb_actions 81 | 82 | if self.her == False: 83 | self.buffers['r'][idxs] = mb_r 84 | 85 | if self.per == True: 86 | self.buffers['p'][idxs] = mb_p 87 | 88 | elif self.her == True: 89 | if self.per == True: 90 | self.buffers['p'][idxs] = mb_p 91 | 92 | 93 | self.n_transitions_stored += self.T * batch_size 94 | 95 | # sample the data from the replay buffer 96 | def sample(self, batch_size): 97 | 98 | temp_buffers = {} 99 | 100 | with self.lock: 101 | for key in self.buffers.keys(): 102 | temp_buffers[key] = self.buffers[key][:self.current_size].copy() 103 | 104 | temp_buffers['obs_next'] = temp_buffers['obs'][:, 1:, :] 105 | temp_buffers['ag_next'] = temp_buffers['ag'][:, 1:, :] 106 | 107 | # sample transitions 108 | transitions = self.sample_func(temp_buffers, batch_size) 109 | 110 | return transitions 111 | 112 | def _get_storage_idx(self, inc=None): 113 | 114 | inc = inc or 1 115 | 116 | if self.current_size+inc <= self.size: 117 | idx = np.arange(self.current_size, self.current_size+inc) 118 | 119 | elif self.current_size < self.size: 120 | overflow = inc - (self.size - self.current_size) 121 | idx_a = np.arange(self.current_size, self.size) 122 | idx_b = np.random.randint(0, self.current_size, overflow) 123 | idx = np.concatenate([idx_a, idx_b]) 124 | else: 125 | idx = np.random.randint(0, self.size, inc) 126 | 127 | self.current_size = min(self.size, self.current_size+inc) 128 | 129 | if inc == 1: 130 | idx = idx[0] 131 | 132 | return idx -------------------------------------------------------------------------------- /ddpg_agent.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from datetime import datetime 4 | import numpy as np 5 | from mpi4py import MPI 6 | from models import actor, critic 7 | from utils import sync_networks, sync_grads 8 | from replay_buffer import replay_buffer 9 | from normalizer import normalizer 10 | from her import her_sampler 11 | import matplotlib.pyplot as plt 12 | 13 | """ 14 | ddpg with HER (MPI-version) 15 | 16 | """ 17 | class ddpg_agent: 18 | def __init__(self, args, env, env_params): 19 | self.args = args 20 | self.env = env 21 | self.env_params = env_params 22 | 23 | # create the network 24 | self.actor_network = actor(env_params) 25 | self.critic_network = critic(env_params) 26 | 27 | # sync the networks across the cpus 28 | sync_networks(self.actor_network) 29 | sync_networks(self.critic_network) 30 | 31 | # build up the target network 32 | self.actor_target_network = actor(env_params) 33 | self.critic_target_network = critic(env_params) 34 | 35 | # load the weights into the target networks 36 | self.actor_target_network.load_state_dict(self.actor_network.state_dict()) 37 | self.critic_target_network.load_state_dict(self.critic_network.state_dict()) 38 | 39 | # if use gpu 40 | if self.args.cuda: 41 | self.actor_network.cuda() 42 | self.critic_network.cuda() 43 | self.actor_target_network.cuda() 44 | self.critic_target_network.cuda() 45 | 46 | # create the optimizer 47 | self.actor_optim = torch.optim.Adam(self.actor_network.parameters(), 
lr=self.args.lr_actor) 48 | self.critic_optim = torch.optim.Adam(self.critic_network.parameters(), lr=self.args.lr_critic) 49 | 50 | # her sampler 51 | self.her_module = her_sampler(self.args.replay_strategy, self.args.replay_k, self.env.compute_reward, self.args.her, self.args.per) 52 | 53 | # create the replay buffer 54 | self.buffer = replay_buffer(self.env_params, self.args.buffer_size, self.her_module.sample_her_transitions, self.args.her, self.args.per) 55 | 56 | # create the normalizer 57 | self.o_norm = normalizer(size=env_params['obs'], default_clip_range=self.args.clip_range) 58 | self.g_norm = normalizer(size=env_params['goal'], default_clip_range=self.args.clip_range) 59 | 60 | # create the dict for store the model 61 | if MPI.COMM_WORLD.Get_rank() == 0: 62 | if not os.path.exists(self.args.save_dir): 63 | os.mkdir(self.args.save_dir) 64 | 65 | # path to save the model 66 | self.model_path = os.path.join(self.args.save_dir, self.args.env_name) 67 | if not os.path.exists(self.model_path): 68 | os.mkdir(self.model_path) 69 | 70 | def learn(self): 71 | """ 72 | train the network 73 | """ 74 | 75 | to_plot, sum_priority, reward_plot, i = [], 0, [], 0 76 | alpha = 0.5 77 | epsilon = 0.1 78 | 79 | # start to collect samples 80 | 81 | for epoch in range(self.args.n_epochs): 82 | self.r = 0 83 | for _ in range(self.args.n_cycles): 84 | mb_obs, mb_ag, mb_g, mb_actions, mb_r, mb_p = [], [], [], [], [], [] 85 | for _ in range(self.args.num_rollouts_per_mpi): 86 | 87 | # reset the rollouts 88 | ep_obs, ep_ag, ep_g, ep_actions, ep_r, ep_p = [], [], [], [], [], [] 89 | 90 | # reset the environment 91 | observation = self.env.reset() 92 | obs = observation['observation'] 93 | ag = observation['achieved_goal'] 94 | g = observation['desired_goal'] 95 | 96 | # start to collect samples 97 | for t in range(self.env_params['max_timesteps']): 98 | with torch.no_grad(): 99 | input_tensor = self._preproc_inputs(obs, g) 100 | pi = self.actor_network(input_tensor) 101 | action = self._select_actions(pi) 102 | 103 | # feed the actions into the environment 104 | #observation_new, _, _, info = self.env.step(action) 105 | observation_new, r, _, info = self.env.step(action) 106 | 107 | obs_new = observation_new['observation'] 108 | ag_new = observation_new['achieved_goal'] 109 | 110 | if self.args.per == True: 111 | with torch.no_grad(): 112 | 113 | #obs, ag = self._preproc_og(obs, ag) 114 | obs_norm, ag_norm = self._preproc_og(obs, g) 115 | #obs_norm = self.o_norm.normalize(obs) 116 | #ag_norm = self.g_norm.normalize(ag) 117 | 118 | inputs_norm = list(obs_norm) + list(ag_norm) 119 | 120 | inputs_norm_tensor = torch.tensor(inputs_norm, dtype=torch.float32) 121 | action_tensor = torch.tensor(action, dtype=torch.float32) 122 | 123 | q_curr = self.critic_network(inputs_norm_tensor.view(1,-1), action_tensor.view(1,-1)) 124 | q_curr_value = q_curr.detach() 125 | 126 | obs_next_norm, ag_next_norm = self._preproc_og(obs_new, g) 127 | 128 | #obs_next_norm = self.o_norm.normalize(obs_new) 129 | #ag_next_norm = self.g_norm.normalize(ag_new) 130 | 131 | inputs_next_norm = list(obs_next_norm) + list(ag_next_norm) 132 | inputs_next_norm_tensor = torch.tensor(inputs_next_norm, dtype=torch.float32) 133 | 134 | pi_next = self.actor_network(inputs_norm_tensor) 135 | action_next = self._select_actions(pi_next) 136 | action_next_tensor = torch.tensor(action_next, dtype=torch.float32) 137 | 138 | q_next_target = self.critic_target_network(inputs_next_norm_tensor.view(1,-1), action_next_tensor.view(1,-1)) 139 | q_next_value = 
q_next_target.detach() 140 | 141 | td_error = np.abs(q_curr_value - q_next_value) 142 | 143 | priority = (td_error + epsilon) ** alpha 144 | sum_priority += priority 145 | p_priority = priority / sum_priority 146 | 147 | ep_p.append(p_priority) 148 | 149 | # append rollouts 150 | ep_obs.append(obs.copy()) 151 | ep_ag.append(ag.copy()) 152 | ep_g.append(g.copy()) 153 | ep_actions.append(action.copy()) 154 | 155 | #if self.args.her == False: 156 | ep_r.append(r.copy()) 157 | 158 | # re-assign the observation 159 | obs = obs_new 160 | ag = ag_new 161 | 162 | reward_plot.append(sum(ep_r)) 163 | 164 | ''' 165 | if MPI.COMM_WORLD.Get_rank() == 0: 166 | print("epissode", i, "reward", sum(ep_r)) 167 | ''' 168 | 169 | i = i + 1 170 | 171 | ep_obs.append(obs.copy()) 172 | ep_ag.append(ag.copy()) 173 | 174 | mb_obs.append(ep_obs) 175 | mb_ag.append(ep_ag) 176 | mb_g.append(ep_g) 177 | mb_actions.append(ep_actions) 178 | 179 | if self.args.her == False: 180 | mb_r.append(ep_r) 181 | 182 | if self.args.per == True: 183 | mb_p.append(ep_p) 184 | 185 | # convert them into arrays 186 | mb_obs = np.array(mb_obs) 187 | mb_ag = np.array(mb_ag) 188 | mb_g = np.array(mb_g) 189 | mb_actions = np.array(mb_actions) 190 | 191 | if self.args.her == False: 192 | 193 | if self.args.per == True: 194 | mb_r = np.array(mb_r) 195 | mb_p = np.array(mb_p) 196 | self.buffer.store_episode([mb_obs, mb_ag, mb_g, mb_actions, mb_r, mb_p]) 197 | self._update_normalizer([mb_obs, mb_ag, mb_g, mb_actions, mb_r, mb_p]) 198 | 199 | else: 200 | mb_r = np.array(mb_r) 201 | self.buffer.store_episode([mb_obs, mb_ag, mb_g, mb_actions, mb_r]) 202 | self._update_normalizer([mb_obs, mb_ag, mb_g, mb_actions, mb_r]) 203 | 204 | 205 | elif self.args.her == True: 206 | if self.args.per == True: 207 | mb_p = np.array(mb_p) 208 | self.buffer.store_episode([mb_obs, mb_ag, mb_g, mb_actions, mb_p]) 209 | self._update_normalizer([mb_obs, mb_ag, mb_g, mb_actions, mb_p]) 210 | else: 211 | self.buffer.store_episode([mb_obs, mb_ag, mb_g, mb_actions]) 212 | self._update_normalizer([mb_obs, mb_ag, mb_g, mb_actions]) 213 | 214 | for _ in range(self.args.n_batches): 215 | # train the network 216 | self._update_network() 217 | 218 | # soft update 219 | self._soft_update_target_network(self.actor_target_network, self.actor_network) 220 | self._soft_update_target_network(self.critic_target_network, self.critic_network) 221 | 222 | # start to do the evaluation 223 | success_rate = self._eval_agent() 224 | 225 | if MPI.COMM_WORLD.Get_rank() == 0: 226 | print('[{}] epoch is: {}, eval success rate is: {:.3f}'.format(datetime.now(), epoch, success_rate)) 227 | to_plot.append(success_rate) 228 | 229 | 230 | torch.save([self.o_norm.mean, self.o_norm.std, self.g_norm.mean, self.g_norm.std, self.actor_network.state_dict()], \ 231 | self.model_path + '/model.pt') 232 | 233 | if MPI.COMM_WORLD.Get_rank() == 0: 234 | plt.plot(range(50*1*1), reward_plot) 235 | plt.plot(range(self.args.n_epochs), to_plot) 236 | 237 | if self.args.env_name == 'FetchReach-v1': 238 | plt.xlabel('Episode') 239 | else: 240 | plt.xlabel('Epoch') 241 | 242 | plt.ylabel('Mean Success Rate') 243 | 244 | if self.args.per == True and self.args.her == True: 245 | plt.title("{} using DDPG + HER + PER".format(self.args.env_name)) 246 | plt.savefig("{}_DDPG_HER_PER".format(self.args.env_name)) 247 | np.save("{}_DDPG_HER_PER".format(self.args.env_name), to_plot) 248 | 249 | elif self.args.per == True and self.args.her == False: 250 | plt.title("{} using DDPG + PER".format(self.args.env_name)) 251 | 
plt.savefig("{}_DDPG_PER".format(self.args.env_name)) 252 | np.save("{}_DDPG_PER".format(self.args.env_name), to_plot) 253 | 254 | elif self.args.per == False and self.args.her == True: 255 | plt.title("{} using DDPG + HER".format(self.args.env_name)) 256 | plt.savefig("{}_DDPG_HER".format(self.args.env_name)) 257 | np.save("{}_DDPG_HER".format(self.args.env_name), to_plot) 258 | 259 | elif self.args.per == False and self.args.her == False: 260 | plt.title("{} using DDPG".format(self.args.env_name)) 261 | plt.savefig("{}_DDPG".format(self.args.env_name)) 262 | np.save("{}_DDPG".format(self.args.env_name), to_plot) 263 | 264 | plt.show() 265 | 266 | 267 | 268 | # pre_process the inputs 269 | def _preproc_inputs(self, obs, g): 270 | obs_norm = self.o_norm.normalize(obs) 271 | g_norm = self.g_norm.normalize(g) 272 | 273 | # concatenate the stuffs 274 | inputs = np.concatenate([obs_norm, g_norm]) 275 | inputs = torch.tensor(inputs, dtype=torch.float32).unsqueeze(0) 276 | if self.args.cuda: 277 | inputs = inputs.cuda() 278 | 279 | return inputs 280 | 281 | # this function will choose action for the agent and do the exploration 282 | def _select_actions(self, pi): 283 | action = pi.cpu().numpy().squeeze() 284 | 285 | # add the gaussian 286 | action += self.args.noise_eps * self.env_params['action_max'] * np.random.randn(*action.shape) 287 | action = np.clip(action, -self.env_params['action_max'], self.env_params['action_max']) 288 | 289 | # random actions... 290 | random_actions = np.random.uniform(low=-self.env_params['action_max'], high=self.env_params['action_max'], \ 291 | size=self.env_params['action']) 292 | 293 | # choose if use the random actions 294 | action += np.random.binomial(1, self.args.random_eps, 1)[0] * (random_actions - action) 295 | 296 | return action 297 | 298 | # update the normalizer 299 | def _update_normalizer(self, episode_batch): 300 | if self.args.her == False: 301 | if self.args.per == True: 302 | mb_obs, mb_ag, mb_g, mb_actions, mb_r, mb_p = episode_batch 303 | else: 304 | mb_obs, mb_ag, mb_g, mb_actions, mb_r = episode_batch 305 | 306 | elif self.args.her == True: 307 | if self.args.per == True: 308 | mb_obs, mb_ag, mb_g, mb_actions, mb_p = episode_batch 309 | else: 310 | mb_obs, mb_ag, mb_g, mb_actions = episode_batch 311 | 312 | mb_obs_next = mb_obs[:, 1:, :] 313 | mb_ag_next = mb_ag[:, 1:, :] 314 | 315 | # get the number of normalization transitions 316 | num_transitions = mb_actions.shape[1] 317 | 318 | # create the new buffer to store them 319 | 320 | if self.args.her == False: 321 | if self.args.per == True: 322 | buffer_temp = {'obs': mb_obs, 323 | 'ag': mb_ag, 324 | 'g': mb_g, 325 | 'actions': mb_actions, 326 | 'obs_next': mb_obs_next, 327 | 'ag_next': mb_ag_next, 328 | 'r' : mb_r, 329 | 'p' : mb_p 330 | } 331 | else: 332 | buffer_temp = {'obs': mb_obs, 333 | 'ag': mb_ag, 334 | 'g': mb_g, 335 | 'actions': mb_actions, 336 | 'obs_next': mb_obs_next, 337 | 'ag_next': mb_ag_next, 338 | 'r' : mb_r 339 | } 340 | 341 | 342 | elif self.args.her == True: 343 | if self.args.per == True: 344 | buffer_temp = {'obs': mb_obs, 345 | 'ag': mb_ag, 346 | 'g': mb_g, 347 | 'actions': mb_actions, 348 | 'obs_next': mb_obs_next, 349 | 'ag_next': mb_ag_next, 350 | 'p' : mb_p 351 | } 352 | else: 353 | buffer_temp = {'obs': mb_obs, 354 | 'ag': mb_ag, 355 | 'g': mb_g, 356 | 'actions': mb_actions, 357 | 'obs_next': mb_obs_next, 358 | 'ag_next': mb_ag_next, 359 | } 360 | 361 | transitions = self.her_module.sample_her_transitions(buffer_temp, num_transitions) 362 | obs, g = 
transitions['obs'], transitions['g'] 363 | 364 | # pre process the obs and g 365 | transitions['obs'], transitions['g'] = self._preproc_og(obs, g) 366 | 367 | # update 368 | self.o_norm.update(transitions['obs']) 369 | self.g_norm.update(transitions['g']) 370 | 371 | # recompute the stats 372 | self.o_norm.recompute_stats() 373 | self.g_norm.recompute_stats() 374 | 375 | def _preproc_og(self, o, g): 376 | 377 | o = np.clip(o, -self.args.clip_obs, self.args.clip_obs) 378 | g = np.clip(g, -self.args.clip_obs, self.args.clip_obs) 379 | 380 | return o, g 381 | 382 | # soft update 383 | def _soft_update_target_network(self, target, source): 384 | for target_param, param in zip(target.parameters(), source.parameters()): 385 | target_param.data.copy_((1 - self.args.polyak) * param.data + self.args.polyak * target_param.data) 386 | 387 | # update the network 388 | def _update_network(self): 389 | 390 | # sample the episodes 391 | transitions = self.buffer.sample(self.args.batch_size) 392 | 393 | # pre-process the observation and goal 394 | o, o_next, g = transitions['obs'], transitions['obs_next'], transitions['g'] 395 | transitions['obs'], transitions['g'] = self._preproc_og(o, g) 396 | transitions['obs_next'], transitions['g_next'] = self._preproc_og(o_next, g) 397 | 398 | # start tof do the update 399 | obs_norm = self.o_norm.normalize(transitions['obs']) 400 | g_norm = self.g_norm.normalize(transitions['g']) 401 | inputs_norm = np.concatenate([obs_norm, g_norm], axis=1) 402 | obs_next_norm = self.o_norm.normalize(transitions['obs_next']) 403 | g_next_norm = self.g_norm.normalize(transitions['g_next']) 404 | inputs_next_norm = np.concatenate([obs_next_norm, g_next_norm], axis=1) 405 | 406 | # transfer them into the tensor 407 | inputs_norm_tensor = torch.tensor(inputs_norm, dtype=torch.float32) 408 | inputs_next_norm_tensor = torch.tensor(inputs_next_norm, dtype=torch.float32) 409 | actions_tensor = torch.tensor(transitions['actions'], dtype=torch.float32) 410 | r_tensor = torch.tensor(transitions['r'], dtype=torch.float32) 411 | 412 | if self.args.cuda: 413 | inputs_norm_tensor = inputs_norm_tensor.cuda() 414 | inputs_next_norm_tensor = inputs_next_norm_tensor.cuda() 415 | actions_tensor = actions_tensor.cuda() 416 | r_tensor = r_tensor.cuda() 417 | 418 | # calculate the target Q value function 419 | with torch.no_grad(): 420 | # do the normalization 421 | # concatenate the stuffs 422 | actions_next = self.actor_target_network(inputs_next_norm_tensor) 423 | q_next_value = self.critic_target_network(inputs_next_norm_tensor, actions_next) 424 | q_next_value = q_next_value.detach() 425 | target_q_value = r_tensor + self.args.gamma * q_next_value 426 | target_q_value = target_q_value.detach() 427 | 428 | # clip the q value 429 | clip_return = 1 / (1 - self.args.gamma) 430 | target_q_value = torch.clamp(target_q_value, -clip_return, 0) 431 | 432 | # the q loss 433 | real_q_value = self.critic_network(inputs_norm_tensor, actions_tensor) 434 | critic_loss = (target_q_value - real_q_value).pow(2).mean() 435 | 436 | # the actor loss 437 | actions_real = self.actor_network(inputs_norm_tensor) 438 | actor_loss = -self.critic_network(inputs_norm_tensor, actions_real).mean() 439 | actor_loss += self.args.action_l2 * (actions_real / self.env_params['action_max']).pow(2).mean() 440 | 441 | # start to update the network 442 | self.actor_optim.zero_grad() 443 | actor_loss.backward() 444 | sync_grads(self.actor_network) 445 | self.actor_optim.step() 446 | 447 | # update the critic_network 448 | 
self.critic_optim.zero_grad() 449 | critic_loss.backward() 450 | sync_grads(self.critic_network) 451 | self.critic_optim.step() 452 | 453 | # do the evaluation 454 | def _eval_agent(self): 455 | total_success_rate = [] 456 | for _ in range(self.args.n_test_rollouts): 457 | per_success_rate = [] 458 | observation = self.env.reset() 459 | obs = observation['observation'] 460 | g = observation['desired_goal'] 461 | 462 | for _ in range(self.env_params['max_timesteps']): 463 | with torch.no_grad(): 464 | input_tensor = self._preproc_inputs(obs, g) 465 | pi = self.actor_network(input_tensor) 466 | 467 | # convert the actions 468 | actions = pi.detach().cpu().numpy().squeeze() 469 | 470 | observation_new, _, _, info = self.env.step(actions) 471 | obs = observation_new['observation'] 472 | g = observation_new['desired_goal'] 473 | per_success_rate.append(info['is_success']) 474 | 475 | total_success_rate.append(per_success_rate) 476 | 477 | total_success_rate = np.array(total_success_rate) 478 | local_success_rate = np.mean(total_success_rate[:, -1]) 479 | global_success_rate = MPI.COMM_WORLD.allreduce(local_success_rate, op=MPI.SUM) 480 | 481 | return global_success_rate / MPI.COMM_WORLD.Get_size() 482 | 483 | --------------------------------------------------------------------------------
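As a closing note on the PER integration: `ddpg_agent.py` above uses the absolute difference between the current critic and the target critic, `|Q(s,a) - Q_target(s',a')|`, as its TD-error proxy when computing priorities. For comparison, below is a minimal sketch of the conventional one-step TD error typically used for prioritisation; the helper name and signature are illustrative, not part of the repository.

```python
import numpy as np

def td_error_priority(r, gamma, q_curr, q_next_target, alpha=0.5, eps=0.1):
    # delta = r + gamma * Q_target(s', a') - Q(s, a)
    td_error = r + gamma * q_next_target - q_curr
    # priority p_i = (|delta| + eps) ** alpha; normalising p_i over the batch
    # gives the proportional sampling distribution
    return (np.abs(td_error) + eps) ** alpha
```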