├── her_modules ├── __init__.py └── her.py ├── mpi_utils ├── __init__.py ├── mpi_utils.py └── normalizer.py ├── rl_modules ├── __init__.py ├── models.py ├── replay_buffer.py └── ddpg_agent.py ├── figures ├── pick.gif ├── push.gif ├── reach.gif ├── results.png └── slide.gif ├── LICENSE ├── train.py ├── .gitignore ├── demo.py ├── README.md └── arguments.py /her_modules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mpi_utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rl_modules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /figures/pick.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TianhongDai/hindsight-experience-replay/HEAD/figures/pick.gif -------------------------------------------------------------------------------- /figures/push.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TianhongDai/hindsight-experience-replay/HEAD/figures/push.gif -------------------------------------------------------------------------------- /figures/reach.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TianhongDai/hindsight-experience-replay/HEAD/figures/reach.gif -------------------------------------------------------------------------------- /figures/results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TianhongDai/hindsight-experience-replay/HEAD/figures/results.png -------------------------------------------------------------------------------- /figures/slide.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TianhongDai/hindsight-experience-replay/HEAD/figures/slide.gif -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Tianhong Dai 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /rl_modules/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | """ 6 | the input x in both networks should be [o, g], where o is the observation and g is the goal. 7 | 8 | """ 9 | 10 | # define the actor network 11 | class actor(nn.Module): 12 | def __init__(self, env_params): 13 | super(actor, self).__init__() 14 | self.max_action = env_params['action_max'] 15 | self.fc1 = nn.Linear(env_params['obs'] + env_params['goal'], 256) 16 | self.fc2 = nn.Linear(256, 256) 17 | self.fc3 = nn.Linear(256, 256) 18 | self.action_out = nn.Linear(256, env_params['action']) 19 | 20 | def forward(self, x): 21 | x = F.relu(self.fc1(x)) 22 | x = F.relu(self.fc2(x)) 23 | x = F.relu(self.fc3(x)) 24 | actions = self.max_action * torch.tanh(self.action_out(x)) 25 | 26 | return actions 27 | 28 | class critic(nn.Module): 29 | def __init__(self, env_params): 30 | super(critic, self).__init__() 31 | self.max_action = env_params['action_max'] 32 | self.fc1 = nn.Linear(env_params['obs'] + env_params['goal'] + env_params['action'], 256) 33 | self.fc2 = nn.Linear(256, 256) 34 | self.fc3 = nn.Linear(256, 256) 35 | self.q_out = nn.Linear(256, 1) 36 | 37 | def forward(self, x, actions): 38 | x = torch.cat([x, actions / self.max_action], dim=1) 39 | x = F.relu(self.fc1(x)) 40 | x = F.relu(self.fc2(x)) 41 | x = F.relu(self.fc3(x)) 42 | q_value = self.q_out(x) 43 | 44 | return q_value 45 | -------------------------------------------------------------------------------- /mpi_utils/mpi_utils.py: -------------------------------------------------------------------------------- 1 | from mpi4py import MPI 2 | import numpy as np 3 | import torch 4 | 5 | # sync_networks across the different cores 6 | def sync_networks(network): 7 | """ 8 | network is the network you want to sync 9 | 10 | """ 11 | comm = MPI.COMM_WORLD 12 | flat_params = _get_flat_params_or_grads(network, mode='params') 13 | comm.Bcast(flat_params, root=0) 14 | # set the flat params back to the network 15 | _set_flat_params_or_grads(network, flat_params, mode='params') 16 | 17 | def sync_grads(network): 18 | flat_grads = _get_flat_params_or_grads(network, mode='grads') 19 | comm = MPI.COMM_WORLD 20 | global_grads = np.zeros_like(flat_grads) 21 | comm.Allreduce(flat_grads, global_grads, op=MPI.SUM) 22 | _set_flat_params_or_grads(network, global_grads, mode='grads') 23 | 24 | # get the flat grads or params 25 | def _get_flat_params_or_grads(network, mode='params'): 26 | """ 27 | include two kinds: grads and params 28 | 29 | """ 30 | attr = 'data' if mode == 'params' else 'grad' 31 | return np.concatenate([getattr(param, attr).cpu().numpy().flatten() for param in network.parameters()]) 32 | 33 | def _set_flat_params_or_grads(network, flat_params, mode='params'): 34 | """ 35 | include two kinds: grads and params 36 | 37 | """ 38 | attr = 'data' if mode == 'params' else 'grad' 39 | # the pointer 40 | pointer = 0 41 | for param in network.parameters(): 42 | getattr(param, attr).copy_(torch.tensor(flat_params[pointer:pointer + param.data.numel()]).view_as(param.data)) 43 | pointer +=
param.data.numel() 44 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import gym 3 | import os, sys 4 | from arguments import get_args 5 | from mpi4py import MPI 6 | from rl_modules.ddpg_agent import ddpg_agent 7 | import random 8 | import torch 9 | 10 | """ 11 | train the agent; the MPI part of the code is copied from openai baselines (https://github.com/openai/baselines/blob/master/baselines/her) 12 | 13 | """ 14 | def get_env_params(env): 15 | obs = env.reset() 16 | # build the env params from a reset observation 17 | params = {'obs': obs['observation'].shape[0], 18 | 'goal': obs['desired_goal'].shape[0], 19 | 'action': env.action_space.shape[0], 20 | 'action_max': env.action_space.high[0], 21 | } 22 | params['max_timesteps'] = env._max_episode_steps 23 | return params 24 | 25 | def launch(args): 26 | # create the environment 27 | env = gym.make(args.env_name) 28 | # set random seeds for reproducibility 29 | env.seed(args.seed + MPI.COMM_WORLD.Get_rank()) 30 | random.seed(args.seed + MPI.COMM_WORLD.Get_rank()) 31 | np.random.seed(args.seed + MPI.COMM_WORLD.Get_rank()) 32 | torch.manual_seed(args.seed + MPI.COMM_WORLD.Get_rank()) 33 | if args.cuda: 34 | torch.cuda.manual_seed(args.seed + MPI.COMM_WORLD.Get_rank()) 35 | # get the environment parameters 36 | env_params = get_env_params(env) 37 | # create the ddpg agent to interact with the environment 38 | ddpg_trainer = ddpg_agent(args, env, env_params) 39 | ddpg_trainer.learn() 40 | 41 | if __name__ == '__main__': 42 | # take the configuration for HER 43 | os.environ['OMP_NUM_THREADS'] = '1' 44 | os.environ['MKL_NUM_THREADS'] = '1' 45 | os.environ['IN_MPI'] = '1' 46 | # get the params 47 | args = get_args() 48 | launch(args) 49 | -------------------------------------------------------------------------------- /her_modules/her.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class her_sampler: 4 | def __init__(self, replay_strategy, replay_k, reward_func=None): 5 | self.replay_strategy = replay_strategy 6 | self.replay_k = replay_k 7 | if self.replay_strategy == 'future': 8 | self.future_p = 1 - (1. / (1 + replay_k)) 9 | else: 10 | self.future_p = 0 11 | self.reward_func = reward_func 12 | 13 | def sample_her_transitions(self, episode_batch, batch_size_in_transitions): 14 | T = episode_batch['actions'].shape[1] 15 | rollout_batch_size = episode_batch['actions'].shape[0] 16 | batch_size = batch_size_in_transitions 17 | # select which rollouts and which timesteps to be used 18 | episode_idxs = np.random.randint(0, rollout_batch_size, batch_size) 19 | t_samples = np.random.randint(T, size=batch_size) 20 | transitions = {key: episode_batch[key][episode_idxs, t_samples].copy() for key in episode_batch.keys()} 21 | # her indexes 22 | her_indexes = np.where(np.random.uniform(size=batch_size) < self.future_p) 23 | future_offset = np.random.uniform(size=batch_size) * (T - t_samples) 24 | future_offset = future_offset.astype(int) 25 | future_t = (t_samples + 1 + future_offset)[her_indexes] 26 | # replace the goal with a future achieved goal 27 | future_ag = episode_batch['ag'][episode_idxs[her_indexes], future_t] 28 | transitions['g'][her_indexes] = future_ag 29 | # re-compute the reward with the substituted goals 30 | transitions['r'] = np.expand_dims(self.reward_func(transitions['ag_next'], transitions['g'], None), 1) 31 | transitions = {k: transitions[k].reshape(batch_size, *transitions[k].shape[1:]) for k in transitions.keys()} 32 | 33 | return transitions 34 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # DS Store 107 | .DS_Store 108 | 109 | #saved_model 110 | *.pth 111 | 112 | *.pt 113 | 114 | *.log 115 | 116 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from rl_modules.models import actor 3 | from arguments import get_args 4 | import gym 5 | import numpy as np 6 | 7 | # process the inputs 8 | def process_inputs(o, g, o_mean, o_std, g_mean, g_std, args): 9 | o_clip = np.clip(o, -args.clip_obs, args.clip_obs) 10 | g_clip = np.clip(g, -args.clip_obs, args.clip_obs) 11 | o_norm = np.clip((o_clip - o_mean) / (o_std), -args.clip_range, args.clip_range) 12 | g_norm = np.clip((g_clip - g_mean) / (g_std), -args.clip_range, args.clip_range) 13 | inputs = np.concatenate([o_norm, g_norm]) 14 | inputs = torch.tensor(inputs, dtype=torch.float32) 15 | return inputs 16 | 17 | if __name__ == '__main__': 18 | args = get_args() 19 | # load the model param 20 | model_path = args.save_dir + args.env_name + '/model.pt' 21 | o_mean, o_std, g_mean, g_std, model = torch.load(model_path, map_location=lambda storage, loc: storage) 22 | # create the environment 23 | env = gym.make(args.env_name) 24 | # get the env param 25 | observation = env.reset() 26 | # get the environment params 27 | env_params = {'obs': observation['observation'].shape[0], 28 | 'goal': observation['desired_goal'].shape[0], 29 | 'action': env.action_space.shape[0], 30 | 'action_max': env.action_space.high[0], 31 | } 32 | # create the actor network 33 | actor_network = actor(env_params) 34 | actor_network.load_state_dict(model) 35 | actor_network.eval() 36 | for i in range(args.demo_length): 37 | observation = env.reset() 38 | # start to do the demo 39 | obs = observation['observation'] 40 | g = observation['desired_goal'] 41 | for t in range(env._max_episode_steps): 42 | env.render() 43 | inputs = process_inputs(obs, g, o_mean, o_std, g_mean, g_std, args) 44 | with torch.no_grad(): 45 | pi = actor_network(inputs) 46 | action = pi.detach().numpy().squeeze() 47 | # put actions into the environment 48 | observation_new, reward, _, info = env.step(action) 49 | obs = observation_new['observation'] 50 | print('the episode is: {}, is success: {}'.format(i, info['is_success'])) 51 | -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | # Hindsight Experience Replay (HER) 2 | This is a PyTorch implementation of [Hindsight Experience Replay](https://arxiv.org/abs/1707.01495). 3 | 4 | ## Acknowledgement: 5 | - [OpenAI Baselines](https://github.com/openai/baselines) 6 | 7 | ## Requirements 8 | - python=3.5.2 9 | - openai-gym=0.12.5 (mujoco200 is supported, but you need gym >= 0.12.5; earlier versions have a bug.) 10 | - mujoco-py=1.50.1.56 (~~**Please use this version; with mujoco200 you may fail on FetchSlide-v1**~~) 11 | - pytorch=1.0.0 (**If you use pytorch-0.4.1, you may get data type errors. I will fix it later.**) 12 | - mpi4py 13 | 14 | ## TODO List 15 | - [x] support GPU acceleration - GPU support has been added, but it is not recommended unless you have a powerful machine. 16 | - [x] add multiple environments per MPI process. 17 | - [x] add the plot and demo of **FetchSlide-v1**. 18 | 19 | ## Instructions to run the code 20 | If you want to use the GPU, just add the flag `--cuda` **(not recommended, better to use the CPU)**. 21 | 1. Train **FetchReach-v1**: 22 | ```bash 23 | mpirun -np 1 python -u train.py --env-name='FetchReach-v1' --n-cycles=10 2>&1 | tee reach.log 24 | ``` 25 | 2. Train **FetchPush-v1**: 26 | ```bash 27 | mpirun -np 8 python -u train.py --env-name='FetchPush-v1' 2>&1 | tee push.log 28 | ``` 29 | 3. Train **FetchPickAndPlace-v1**: 30 | ```bash 31 | mpirun -np 16 python -u train.py --env-name='FetchPickAndPlace-v1' 2>&1 | tee pick.log 32 | ``` 33 | 4. Train **FetchSlide-v1**: 34 | ```bash 35 | mpirun -np 8 python -u train.py --env-name='FetchSlide-v1' --n-epochs=200 2>&1 | tee slide.log 36 | ``` 37 | 38 | ### Play Demo 39 | ```bash 40 | python demo.py --env-name= 41 | ``` 42 | ### Download the Pre-trained Models 43 | Please download them from [Google Drive](https://drive.google.com/open?id=1dNzIpIcL4x1im8dJcUyNO30m_lhzO9K4), then put the `saved_models` folder under the current directory. 44 | 45 | ## Results 46 | ### Training Performance 47 | The curves were plotted using 5 different seeds; the solid line is the median value. 48 | ![Training_Curve](figures/results.png) 49 | ### Demo: 50 | **Tip**: while watching the demo, you can press **TAB** to switch the camera in MuJoCo.
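For example, to replay a trained **FetchPush-v1** policy (assuming `saved_models/FetchPush-v1/model.pt` exists, either trained locally or downloaded as above):
```bash
python demo.py --env-name='FetchPush-v1'
```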
51 | 52 | FetchPush-v1| FetchPickAndPlace-v1| FetchSlide-v1 53 | -----------------------|-----------------------|-----------------------| 54 | ![](figures/push.gif)| ![](figures/pick.gif)| ![](figures/slide.gif) 55 | -------------------------------------------------------------------------------- /arguments.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | """ 4 | Here are the params for training 5 | 6 | """ 7 | 8 | def get_args(): 9 | parser = argparse.ArgumentParser() 10 | # the environment setting 11 | parser.add_argument('--env-name', type=str, default='FetchReach-v1', help='the environment name') 12 | parser.add_argument('--n-epochs', type=int, default=50, help='the number of epochs to train the agent') 13 | parser.add_argument('--n-cycles', type=int, default=50, help='the number of cycles to collect samples per epoch') 14 | parser.add_argument('--n-batches', type=int, default=40, help='the number of network updates per cycle') 15 | parser.add_argument('--save-interval', type=int, default=5, help='the interval at which to save the trajectory') 16 | parser.add_argument('--seed', type=int, default=123, help='random seed') 17 | parser.add_argument('--num-workers', type=int, default=1, help='the number of cpus to collect samples') 18 | parser.add_argument('--replay-strategy', type=str, default='future', help='the HER strategy') 19 | parser.add_argument('--clip-return', type=float, default=50, help='whether to clip the returns') 20 | parser.add_argument('--save-dir', type=str, default='saved_models/', help='the path to save the models') 21 | parser.add_argument('--noise-eps', type=float, default=0.2, help='the scale of the gaussian action noise') 22 | parser.add_argument('--random-eps', type=float, default=0.3, help='the probability of taking a random action') 23 | parser.add_argument('--buffer-size', type=int, default=int(1e6), help='the size of the replay buffer') 24 | parser.add_argument('--replay-k', type=int, default=4, help='the ratio of goals to be replaced (HER)') 25 | parser.add_argument('--clip-obs', type=float, default=200, help='the clipping value for the observations') 26 | parser.add_argument('--batch-size', type=int, default=256, help='the sample batch size') 27 | parser.add_argument('--gamma', type=float, default=0.98, help='the discount factor') 28 | parser.add_argument('--action-l2', type=float, default=1, help='the l2 regularization coefficient for the actions') 29 | parser.add_argument('--lr-actor', type=float, default=0.001, help='the learning rate of the actor') 30 | parser.add_argument('--lr-critic', type=float, default=0.001, help='the learning rate of the critic') 31 | parser.add_argument('--polyak', type=float, default=0.95, help='the polyak averaging coefficient') 32 | parser.add_argument('--n-test-rollouts', type=int, default=10, help='the number of test rollouts') 33 | parser.add_argument('--clip-range', type=float, default=5, help='the clip range after normalization') 34 | parser.add_argument('--demo-length', type=int, default=20, help='the number of demo episodes') 35 | parser.add_argument('--cuda', action='store_true', help='whether to use the GPU') 36 | parser.add_argument('--num-rollouts-per-mpi', type=int, default=2, help='the number of rollouts per MPI process') 37 | 38 | args = parser.parse_args() 39 | 40 | return args 41 | -------------------------------------------------------------------------------- /rl_modules/replay_buffer.py: -------------------------------------------------------------------------------- 1 | import threading 2 | import numpy as np 3 | 4 | """ 5 | the replay buffer here is basically from the openai baselines code 6 | 7 | """ 8 | class replay_buffer: 9 | def __init__(self, env_params, buffer_size,
sample_func): 10 | self.env_params = env_params 11 | self.T = env_params['max_timesteps'] 12 | self.size = buffer_size // self.T 13 | # memory management 14 | self.current_size = 0 15 | self.n_transitions_stored = 0 16 | self.sample_func = sample_func 17 | # create the buffer to store info 18 | self.buffers = {'obs': np.empty([self.size, self.T + 1, self.env_params['obs']]), 19 | 'ag': np.empty([self.size, self.T + 1, self.env_params['goal']]), 20 | 'g': np.empty([self.size, self.T, self.env_params['goal']]), 21 | 'actions': np.empty([self.size, self.T, self.env_params['action']]), 22 | } 23 | # thread lock 24 | self.lock = threading.Lock() 25 | 26 | # store the episode 27 | def store_episode(self, episode_batch): 28 | mb_obs, mb_ag, mb_g, mb_actions = episode_batch 29 | batch_size = mb_obs.shape[0] 30 | with self.lock: 31 | idxs = self._get_storage_idx(inc=batch_size) 32 | # store the informations 33 | self.buffers['obs'][idxs] = mb_obs 34 | self.buffers['ag'][idxs] = mb_ag 35 | self.buffers['g'][idxs] = mb_g 36 | self.buffers['actions'][idxs] = mb_actions 37 | self.n_transitions_stored += self.T * batch_size 38 | 39 | # sample the data from the replay buffer 40 | def sample(self, batch_size): 41 | temp_buffers = {} 42 | with self.lock: 43 | for key in self.buffers.keys(): 44 | temp_buffers[key] = self.buffers[key][:self.current_size] 45 | temp_buffers['obs_next'] = temp_buffers['obs'][:, 1:, :] 46 | temp_buffers['ag_next'] = temp_buffers['ag'][:, 1:, :] 47 | # sample transitions 48 | transitions = self.sample_func(temp_buffers, batch_size) 49 | return transitions 50 | 51 | def _get_storage_idx(self, inc=None): 52 | inc = inc or 1 53 | if self.current_size+inc <= self.size: 54 | idx = np.arange(self.current_size, self.current_size+inc) 55 | elif self.current_size < self.size: 56 | overflow = inc - (self.size - self.current_size) 57 | idx_a = np.arange(self.current_size, self.size) 58 | idx_b = np.random.randint(0, self.current_size, overflow) 59 | idx = np.concatenate([idx_a, idx_b]) 60 | else: 61 | idx = np.random.randint(0, self.size, inc) 62 | self.current_size = min(self.size, self.current_size+inc) 63 | if inc == 1: 64 | idx = idx[0] 65 | return idx 66 | -------------------------------------------------------------------------------- /mpi_utils/normalizer.py: -------------------------------------------------------------------------------- 1 | import threading 2 | import numpy as np 3 | from mpi4py import MPI 4 | 5 | class normalizer: 6 | def __init__(self, size, eps=1e-2, default_clip_range=np.inf): 7 | self.size = size 8 | self.eps = eps 9 | self.default_clip_range = default_clip_range 10 | # some local information 11 | self.local_sum = np.zeros(self.size, np.float32) 12 | self.local_sumsq = np.zeros(self.size, np.float32) 13 | self.local_count = np.zeros(1, np.float32) 14 | # get the total sum sumsq and sum count 15 | self.total_sum = np.zeros(self.size, np.float32) 16 | self.total_sumsq = np.zeros(self.size, np.float32) 17 | self.total_count = np.ones(1, np.float32) 18 | # get the mean and std 19 | self.mean = np.zeros(self.size, np.float32) 20 | self.std = np.ones(self.size, np.float32) 21 | # thread locker 22 | self.lock = threading.Lock() 23 | 24 | # update the parameters of the normalizer 25 | def update(self, v): 26 | v = v.reshape(-1, self.size) 27 | # do the computing 28 | with self.lock: 29 | self.local_sum += v.sum(axis=0) 30 | self.local_sumsq += (np.square(v)).sum(axis=0) 31 | self.local_count[0] += v.shape[0] 32 | 33 | # sync the parameters across the cpus 34 | def 
sync(self, local_sum, local_sumsq, local_count): 35 | local_sum[...] = self._mpi_average(local_sum) 36 | local_sumsq[...] = self._mpi_average(local_sumsq) 37 | local_count[...] = self._mpi_average(local_count) 38 | return local_sum, local_sumsq, local_count 39 | 40 | def recompute_stats(self): 41 | with self.lock: 42 | local_count = self.local_count.copy() 43 | local_sum = self.local_sum.copy() 44 | local_sumsq = self.local_sumsq.copy() 45 | # reset 46 | self.local_count[...] = 0 47 | self.local_sum[...] = 0 48 | self.local_sumsq[...] = 0 49 | # synrc the stats 50 | sync_sum, sync_sumsq, sync_count = self.sync(local_sum, local_sumsq, local_count) 51 | # update the total stuff 52 | self.total_sum += sync_sum 53 | self.total_sumsq += sync_sumsq 54 | self.total_count += sync_count 55 | # calculate the new mean and std 56 | self.mean = self.total_sum / self.total_count 57 | self.std = np.sqrt(np.maximum(np.square(self.eps), (self.total_sumsq / self.total_count) - np.square(self.total_sum / self.total_count))) 58 | 59 | # average across the cpu's data 60 | def _mpi_average(self, x): 61 | buf = np.zeros_like(x) 62 | MPI.COMM_WORLD.Allreduce(x, buf, op=MPI.SUM) 63 | buf /= MPI.COMM_WORLD.Get_size() 64 | return buf 65 | 66 | # normalize the observation 67 | def normalize(self, v, clip_range=None): 68 | if clip_range is None: 69 | clip_range = self.default_clip_range 70 | return np.clip((v - self.mean) / (self.std), -clip_range, clip_range) 71 | -------------------------------------------------------------------------------- /rl_modules/ddpg_agent.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from datetime import datetime 4 | import numpy as np 5 | from mpi4py import MPI 6 | from mpi_utils.mpi_utils import sync_networks, sync_grads 7 | from rl_modules.replay_buffer import replay_buffer 8 | from rl_modules.models import actor, critic 9 | from mpi_utils.normalizer import normalizer 10 | from her_modules.her import her_sampler 11 | 12 | """ 13 | ddpg with HER (MPI-version) 14 | 15 | """ 16 | class ddpg_agent: 17 | def __init__(self, args, env, env_params): 18 | self.args = args 19 | self.env = env 20 | self.env_params = env_params 21 | # create the network 22 | self.actor_network = actor(env_params) 23 | self.critic_network = critic(env_params) 24 | # sync the networks across the cpus 25 | sync_networks(self.actor_network) 26 | sync_networks(self.critic_network) 27 | # build up the target network 28 | self.actor_target_network = actor(env_params) 29 | self.critic_target_network = critic(env_params) 30 | # load the weights into the target networks 31 | self.actor_target_network.load_state_dict(self.actor_network.state_dict()) 32 | self.critic_target_network.load_state_dict(self.critic_network.state_dict()) 33 | # if use gpu 34 | if self.args.cuda: 35 | self.actor_network.cuda() 36 | self.critic_network.cuda() 37 | self.actor_target_network.cuda() 38 | self.critic_target_network.cuda() 39 | # create the optimizer 40 | self.actor_optim = torch.optim.Adam(self.actor_network.parameters(), lr=self.args.lr_actor) 41 | self.critic_optim = torch.optim.Adam(self.critic_network.parameters(), lr=self.args.lr_critic) 42 | # her sampler 43 | self.her_module = her_sampler(self.args.replay_strategy, self.args.replay_k, self.env.compute_reward) 44 | # create the replay buffer 45 | self.buffer = replay_buffer(self.env_params, self.args.buffer_size, self.her_module.sample_her_transitions) 46 | # create the normalizer 47 | self.o_norm = 
normalizer(size=env_params['obs'], default_clip_range=self.args.clip_range) 48 | self.g_norm = normalizer(size=env_params['goal'], default_clip_range=self.args.clip_range) 49 | # create the dict for store the model 50 | if MPI.COMM_WORLD.Get_rank() == 0: 51 | if not os.path.exists(self.args.save_dir): 52 | os.mkdir(self.args.save_dir) 53 | # path to save the model 54 | self.model_path = os.path.join(self.args.save_dir, self.args.env_name) 55 | if not os.path.exists(self.model_path): 56 | os.mkdir(self.model_path) 57 | 58 | def learn(self): 59 | """ 60 | train the network 61 | 62 | """ 63 | # start to collect samples 64 | for epoch in range(self.args.n_epochs): 65 | for _ in range(self.args.n_cycles): 66 | mb_obs, mb_ag, mb_g, mb_actions = [], [], [], [] 67 | for _ in range(self.args.num_rollouts_per_mpi): 68 | # reset the rollouts 69 | ep_obs, ep_ag, ep_g, ep_actions = [], [], [], [] 70 | # reset the environment 71 | observation = self.env.reset() 72 | obs = observation['observation'] 73 | ag = observation['achieved_goal'] 74 | g = observation['desired_goal'] 75 | # start to collect samples 76 | for t in range(self.env_params['max_timesteps']): 77 | with torch.no_grad(): 78 | input_tensor = self._preproc_inputs(obs, g) 79 | pi = self.actor_network(input_tensor) 80 | action = self._select_actions(pi) 81 | # feed the actions into the environment 82 | observation_new, _, _, info = self.env.step(action) 83 | obs_new = observation_new['observation'] 84 | ag_new = observation_new['achieved_goal'] 85 | # append rollouts 86 | ep_obs.append(obs.copy()) 87 | ep_ag.append(ag.copy()) 88 | ep_g.append(g.copy()) 89 | ep_actions.append(action.copy()) 90 | # re-assign the observation 91 | obs = obs_new 92 | ag = ag_new 93 | ep_obs.append(obs.copy()) 94 | ep_ag.append(ag.copy()) 95 | mb_obs.append(ep_obs) 96 | mb_ag.append(ep_ag) 97 | mb_g.append(ep_g) 98 | mb_actions.append(ep_actions) 99 | # convert them into arrays 100 | mb_obs = np.array(mb_obs) 101 | mb_ag = np.array(mb_ag) 102 | mb_g = np.array(mb_g) 103 | mb_actions = np.array(mb_actions) 104 | # store the episodes 105 | self.buffer.store_episode([mb_obs, mb_ag, mb_g, mb_actions]) 106 | self._update_normalizer([mb_obs, mb_ag, mb_g, mb_actions]) 107 | for _ in range(self.args.n_batches): 108 | # train the network 109 | self._update_network() 110 | # soft update 111 | self._soft_update_target_network(self.actor_target_network, self.actor_network) 112 | self._soft_update_target_network(self.critic_target_network, self.critic_network) 113 | # start to do the evaluation 114 | success_rate = self._eval_agent() 115 | if MPI.COMM_WORLD.Get_rank() == 0: 116 | print('[{}] epoch is: {}, eval success rate is: {:.3f}'.format(datetime.now(), epoch, success_rate)) 117 | torch.save([self.o_norm.mean, self.o_norm.std, self.g_norm.mean, self.g_norm.std, self.actor_network.state_dict()], \ 118 | self.model_path + '/model.pt') 119 | 120 | # pre_process the inputs 121 | def _preproc_inputs(self, obs, g): 122 | obs_norm = self.o_norm.normalize(obs) 123 | g_norm = self.g_norm.normalize(g) 124 | # concatenate the stuffs 125 | inputs = np.concatenate([obs_norm, g_norm]) 126 | inputs = torch.tensor(inputs, dtype=torch.float32).unsqueeze(0) 127 | if self.args.cuda: 128 | inputs = inputs.cuda() 129 | return inputs 130 | 131 | # this function will choose action for the agent and do the exploration 132 | def _select_actions(self, pi): 133 | action = pi.cpu().numpy().squeeze() 134 | # add the gaussian 135 | action += self.args.noise_eps * self.env_params['action_max'] * 
np.random.randn(*action.shape) 136 | action = np.clip(action, -self.env_params['action_max'], self.env_params['action_max']) 137 | # random actions... 138 | random_actions = np.random.uniform(low=-self.env_params['action_max'], high=self.env_params['action_max'], \ 139 | size=self.env_params['action']) 140 | # choose if use the random actions 141 | action += np.random.binomial(1, self.args.random_eps, 1)[0] * (random_actions - action) 142 | return action 143 | 144 | # update the normalizer 145 | def _update_normalizer(self, episode_batch): 146 | mb_obs, mb_ag, mb_g, mb_actions = episode_batch 147 | mb_obs_next = mb_obs[:, 1:, :] 148 | mb_ag_next = mb_ag[:, 1:, :] 149 | # get the number of normalization transitions 150 | num_transitions = mb_actions.shape[1] 151 | # create the new buffer to store them 152 | buffer_temp = {'obs': mb_obs, 153 | 'ag': mb_ag, 154 | 'g': mb_g, 155 | 'actions': mb_actions, 156 | 'obs_next': mb_obs_next, 157 | 'ag_next': mb_ag_next, 158 | } 159 | transitions = self.her_module.sample_her_transitions(buffer_temp, num_transitions) 160 | obs, g = transitions['obs'], transitions['g'] 161 | # pre process the obs and g 162 | transitions['obs'], transitions['g'] = self._preproc_og(obs, g) 163 | # update 164 | self.o_norm.update(transitions['obs']) 165 | self.g_norm.update(transitions['g']) 166 | # recompute the stats 167 | self.o_norm.recompute_stats() 168 | self.g_norm.recompute_stats() 169 | 170 | def _preproc_og(self, o, g): 171 | o = np.clip(o, -self.args.clip_obs, self.args.clip_obs) 172 | g = np.clip(g, -self.args.clip_obs, self.args.clip_obs) 173 | return o, g 174 | 175 | # soft update 176 | def _soft_update_target_network(self, target, source): 177 | for target_param, param in zip(target.parameters(), source.parameters()): 178 | target_param.data.copy_((1 - self.args.polyak) * param.data + self.args.polyak * target_param.data) 179 | 180 | # update the network 181 | def _update_network(self): 182 | # sample the episodes 183 | transitions = self.buffer.sample(self.args.batch_size) 184 | # pre-process the observation and goal 185 | o, o_next, g = transitions['obs'], transitions['obs_next'], transitions['g'] 186 | transitions['obs'], transitions['g'] = self._preproc_og(o, g) 187 | transitions['obs_next'], transitions['g_next'] = self._preproc_og(o_next, g) 188 | # start to do the update 189 | obs_norm = self.o_norm.normalize(transitions['obs']) 190 | g_norm = self.g_norm.normalize(transitions['g']) 191 | inputs_norm = np.concatenate([obs_norm, g_norm], axis=1) 192 | obs_next_norm = self.o_norm.normalize(transitions['obs_next']) 193 | g_next_norm = self.g_norm.normalize(transitions['g_next']) 194 | inputs_next_norm = np.concatenate([obs_next_norm, g_next_norm], axis=1) 195 | # transfer them into the tensor 196 | inputs_norm_tensor = torch.tensor(inputs_norm, dtype=torch.float32) 197 | inputs_next_norm_tensor = torch.tensor(inputs_next_norm, dtype=torch.float32) 198 | actions_tensor = torch.tensor(transitions['actions'], dtype=torch.float32) 199 | r_tensor = torch.tensor(transitions['r'], dtype=torch.float32) 200 | if self.args.cuda: 201 | inputs_norm_tensor = inputs_norm_tensor.cuda() 202 | inputs_next_norm_tensor = inputs_next_norm_tensor.cuda() 203 | actions_tensor = actions_tensor.cuda() 204 | r_tensor = r_tensor.cuda() 205 | # calculate the target Q value function 206 | with torch.no_grad(): 207 | # do the normalization 208 | # concatenate the stuffs 209 | actions_next = self.actor_target_network(inputs_next_norm_tensor) 210 | q_next_value = 
self.critic_target_network(inputs_next_norm_tensor, actions_next) 211 | q_next_value = q_next_value.detach() 212 | target_q_value = r_tensor + self.args.gamma * q_next_value 213 | target_q_value = target_q_value.detach() 214 | # clip the q value 215 | clip_return = 1 / (1 - self.args.gamma) 216 | target_q_value = torch.clamp(target_q_value, -clip_return, 0) 217 | # the q loss 218 | real_q_value = self.critic_network(inputs_norm_tensor, actions_tensor) 219 | critic_loss = (target_q_value - real_q_value).pow(2).mean() 220 | # the actor loss 221 | actions_real = self.actor_network(inputs_norm_tensor) 222 | actor_loss = -self.critic_network(inputs_norm_tensor, actions_real).mean() 223 | actor_loss += self.args.action_l2 * (actions_real / self.env_params['action_max']).pow(2).mean() 224 | # start to update the network 225 | self.actor_optim.zero_grad() 226 | actor_loss.backward() 227 | sync_grads(self.actor_network) 228 | self.actor_optim.step() 229 | # update the critic_network 230 | self.critic_optim.zero_grad() 231 | critic_loss.backward() 232 | sync_grads(self.critic_network) 233 | self.critic_optim.step() 234 | 235 | # do the evaluation 236 | def _eval_agent(self): 237 | total_success_rate = [] 238 | for _ in range(self.args.n_test_rollouts): 239 | per_success_rate = [] 240 | observation = self.env.reset() 241 | obs = observation['observation'] 242 | g = observation['desired_goal'] 243 | for _ in range(self.env_params['max_timesteps']): 244 | with torch.no_grad(): 245 | input_tensor = self._preproc_inputs(obs, g) 246 | pi = self.actor_network(input_tensor) 247 | # convert the actions 248 | actions = pi.detach().cpu().numpy().squeeze() 249 | observation_new, _, _, info = self.env.step(actions) 250 | obs = observation_new['observation'] 251 | g = observation_new['desired_goal'] 252 | per_success_rate.append(info['is_success']) 253 | total_success_rate.append(per_success_rate) 254 | total_success_rate = np.array(total_success_rate) 255 | local_success_rate = np.mean(total_success_rate[:, -1]) 256 | global_success_rate = MPI.COMM_WORLD.allreduce(local_success_rate, op=MPI.SUM) 257 | return global_success_rate / MPI.COMM_WORLD.Get_size() 258 | --------------------------------------------------------------------------------
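To see the HER relabelling from `her_modules/her.py` in isolation, below is a minimal, self-contained sketch that runs `her_sampler` on a fake episode batch. The shapes and the `toy_reward` function are illustrative stand-ins (the real code passes `env.compute_reward`); only the `her_sampler` API and the buffer keys (`obs`, `ag`, `g`, `actions`, `obs_next`, `ag_next`) come from the code above.

```python
import numpy as np
from her_modules.her import her_sampler

# hypothetical sparse reward standing in for env.compute_reward:
# 0 when the next achieved goal matches the (possibly relabelled) goal, -1 otherwise
def toy_reward(ag_next, g, info):
    return -(np.linalg.norm(ag_next - g, axis=-1) > 0.05).astype(np.float32)

# fake batch: 2 episodes of 5 timesteps, 4-d observations, 3-d goals, 1-d actions
n_eps, T, dim_o, dim_g, dim_a = 2, 5, 4, 3, 1
episode_batch = {
    'obs': np.random.randn(n_eps, T + 1, dim_o),
    'ag': np.random.randn(n_eps, T + 1, dim_g),
    'g': np.random.randn(n_eps, T, dim_g),
    'actions': np.random.randn(n_eps, T, dim_a),
}
# next-step views, built the same way as in replay_buffer.sample()
episode_batch['obs_next'] = episode_batch['obs'][:, 1:, :]
episode_batch['ag_next'] = episode_batch['ag'][:, 1:, :]

# with replay_k=4, future_p = 1 - 1/(1+4) = 0.8, so about 80% of the sampled goals are relabelled
sampler = her_sampler(replay_strategy='future', replay_k=4, reward_func=toy_reward)
transitions = sampler.sample_her_transitions(episode_batch, batch_size_in_transitions=8)
print(transitions['obs'].shape, transitions['g'].shape, transitions['r'].shape)
# expected output: (8, 4) (8, 3) (8, 1)
```

The relabelled transitions are exactly what `ddpg_agent._update_network` consumes: goals in `transitions['g']` are partly replaced by future achieved goals, and `transitions['r']` is recomputed against those new goals.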