├── atari_modules
    ├── __init__.py
    ├── her.py
    ├── models.py
    ├── replay_buffer.py
    ├── dqn_agent.py
    ├── her_dqn_agent.py
    ├── wrappers.py
    └── fb_agent.py
├── continuous_world_modules
    ├── __init__.py
    ├── test_env.py
    ├── featurizer.py
    ├── geometry.py
    ├── env.py
    └── dqn_agent.py
├── discrete_action_robots_modules
    ├── __init__.py
    ├── normalizer.py
    ├── replay_buffer.py
    ├── models.py
    ├── robots.py
    └── dqn_agent.py
├── .gitignore
├── grid_modules
    ├── __init__.py
    ├── gridworld
    │   ├── actions.py
    │   ├── __init__.py
    │   ├── txt_utilities.py
    │   ├── env.py
    │   ├── helper_utilities.py
    │   └── builder_tools.py
    ├── exceptions.py
    ├── utils.py
    ├── mdp_utils.py
    ├── her.py
    ├── common.py
    ├── plotting.py
    ├── replay_buffer.py
    ├── dqn_agent.py
    └── fb_agent.py
├── discrete_robots_main.py
├── atari_main.py
├── continuous_main.py
├── grid_main.py
├── README.md
└── arguments.py
/atari_modules/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/continuous_world_modules/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/discrete_action_robots_modules/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | *.pyc
3 | saved_models
4 | 
--------------------------------------------------------------------------------
/grid_modules/__init__.py:
--------------------------------------------------------------------------------
1 | # from grid_modules.common import MDP
--------------------------------------------------------------------------------
/grid_modules/gridworld/actions.py:
--------------------------------------------------------------------------------
1 | LEFT = 0
2 | RIGHT = 1
3 | UP = 2
4 | DOWN = 3
5 | STAY = 4
--------------------------------------------------------------------------------
/grid_modules/exceptions.py:
--------------------------------------------------------------------------------
1 | class EpisodeDoneError(TimeoutError):
2 |     """An error for when the episode is over."""
3 |     pass
4 | 
5 | 
6 | class InvalidActionError(ValueError):
7 |     """An error for when an invalid action is taken"""
8 |     pass
--------------------------------------------------------------------------------
/grid_modules/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | # 1D utilities.
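# Illustrative usage of the helpers defined below (added example):
# convert_int_rep_to_onehot(2, 4) gives array([0., 0., 1., 0.]), and
# convert_onehot_to_int([0., 0., 1., 0.]) recovers the integer 2.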
4 | 5 | 6 | def convert_int_rep_to_onehot(state, vector_size): 7 | s = np.zeros(vector_size) 8 | s[state] = 1 9 | return s 10 | 11 | 12 | def convert_onehot_to_int(state): 13 | if type(state) is not np.ndarray: 14 | state = np.array(state) 15 | return state.argmax().item() 16 | 17 | -------------------------------------------------------------------------------- /grid_modules/mdp_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def value_iteration(R, P, gamma, atol=0.0001, max_iteration=1000): 6 | q = torch.zeros_like(R) 7 | for i in range(max_iteration): 8 | q_old = q 9 | v = torch.max(q, dim=1)[0] 10 | q = R + gamma * torch.einsum('ijk, k->ij', P, v) 11 | if torch.allclose(q, q_old, atol=atol): 12 | break 13 | return q 14 | 15 | 16 | def extract_policy(q, policy_type='boltzmann', temp=1, eps=0): 17 | action_space = q.shape[-1] 18 | if policy_type == 'boltzmann': 19 | policy = F.softmax(q / temp, dim=-1) 20 | elif policy_type == 'greedy': 21 | max_idx = torch.argmax(q, 1, keepdim=True) 22 | policy = torch.zeros_like(q).fill_(eps / action_space) 23 | policy.scatter_(1, max_idx, 1 - eps + eps / action_space) 24 | else: 25 | raise NotImplementedError() 26 | return policy 27 | 28 | 29 | def compute_successor_reps(P, pi, gamma): 30 | state_space, action_space = P.shape[:2] 31 | P_pi = torch.einsum('sax, xu -> saxu', P, pi) # S x A x S x A 32 | P_pi = P_pi.transpose(0, 1).transpose(2, 3).reshape(state_space * action_space, 33 | state_space * action_space) 34 | Id = torch.eye(*P_pi.size(), out=torch.empty_like(P_pi)) 35 | sr_pi = torch.inverse(Id - gamma * P_pi) 36 | return sr_pi -------------------------------------------------------------------------------- /discrete_robots_main.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from arguments import get_args 3 | from discrete_action_robots_modules.dqn_agent import DQNAgent 4 | from discrete_action_robots_modules.fb_agent import FBAgent 5 | from discrete_action_robots_modules.robots import FetchReach 6 | import random 7 | import torch 8 | 9 | 10 | def get_env_params(env): 11 | obs = env.reset() 12 | # close the environment 13 | params = {'obs': obs['observation'].shape[0], 14 | 'goal': obs['desired_goal'].shape[0], 15 | 'action': env.num_actions, 16 | } 17 | params['max_timesteps'] = env._max_episode_steps 18 | return params 19 | 20 | 21 | def launch(args): 22 | env = FetchReach() 23 | # import pdb 24 | # pdb.set_trace() 25 | # set random seeds for reproduce 26 | env.seed(args.seed) 27 | random.seed(args.seed) 28 | np.random.seed(args.seed) 29 | torch.manual_seed(args.seed) 30 | if args.cuda: 31 | torch.cuda.manual_seed(args.seed) 32 | # get the environment parameters 33 | env_params = get_env_params(env) 34 | # create the agent to interact with the environment 35 | if args.agent == 'DQN': 36 | dqn_trainer = DQNAgent(args, env, env_params) 37 | dqn_trainer.learn() 38 | elif args.agent == 'FB': 39 | fb_trainer = FBAgent(args, env, env_params) 40 | fb_trainer.learn() 41 | else: 42 | raise NotImplementedError() 43 | 44 | 45 | if __name__ == '__main__': 46 | # get the params 47 | args = get_args() 48 | launch(args) 49 | -------------------------------------------------------------------------------- /atari_main.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from arguments import get_args 3 | from 
atari_modules.dqn_agent import dqn_agent 4 | from atari_modules.fb_agent import FBAgent 5 | from atari_modules.her_dqn_agent import HerDQNAgent 6 | from atari_modules.wrappers import make_goalPacman 7 | import random 8 | import torch 9 | 10 | 11 | 12 | def get_env_params(env): 13 | params = {'obs': env.observation_space['observation'].shape, 14 | 'goal': 2, 15 | 'action': env.action_space.n, 16 | } 17 | params['max_timesteps'] = 50 18 | return params 19 | 20 | 21 | def launch(args): 22 | 23 | env = make_goalPacman() 24 | # set random seeds for reproduce 25 | env.seed(args.seed) 26 | random.seed(args.seed) 27 | np.random.seed(args.seed) 28 | torch.manual_seed(args.seed) 29 | if args.cuda: 30 | torch.cuda.manual_seed(args.seed) 31 | # get the environment parameters 32 | env_params = get_env_params(env) 33 | # create the agent to interact with the environment 34 | if args.agent == 'DQN': 35 | dqn_trainer = dqn_agent(args, env, env_params) 36 | dqn_trainer.learn() 37 | elif args.agent == 'FB': 38 | fb_trainer = FBAgent(args, env, env_params) 39 | fb_trainer.learn() 40 | elif args.agent == 'HerDQN': 41 | her_agent = HerDQNAgent(args, env, env_params) 42 | her_agent.learn() 43 | else: 44 | raise NotImplementedError() 45 | 46 | 47 | if __name__ == '__main__': 48 | # get the params 49 | args = get_args() 50 | launch(args) 51 | -------------------------------------------------------------------------------- /continuous_world_modules/test_env.py: -------------------------------------------------------------------------------- 1 | from continuous_world_modules.geometry import Point 2 | from continuous_world_modules.env import ContinuousWorld, visualize_environment 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | 6 | AGENT = 'random' 7 | env = ContinuousWorld(1, wall_pairs=[ 8 | (Point(0.25, 0.0), Point(0.25, 0.4)), 9 | (Point(0.75, 1), Point(0.75, 0.6))], 10 | movement_noise=0.01, 11 | threshold_distance=0.05) 12 | 13 | fig = plt.figure(figsize=(5, 5)) 14 | ax = fig.add_subplot(111) 15 | 16 | if AGENT == 'random': 17 | current_state = env.reset() 18 | # env.set_initial_position(Point(0.2, 0.1)) 19 | # env.set_goal(Point(0.9, 0.9)) 20 | # env.set_initial_position(Point(2, 1)) 21 | # current_state = env.current_position 22 | # env.set_agent_position(Point(3, 4)) 23 | visualize_environment(env, ax) 24 | for _ in range(100): 25 | x_pos, y_pos = current_state[0], current_state[1] 26 | action = np.random.choice(5) 27 | state, reward, done, info = env.step(action) 28 | # print(state) 29 | perturbed_action = state - current_state 30 | # action = ACTIONS.STAY 31 | # x_pos, y_pos = env.agent_position 32 | ax.quiver(x_pos, y_pos, perturbed_action[0], perturbed_action[1], color='#1ABC9C', alpha=1.0, 33 | angles='xy', scale_units='xy', scale=1, 34 | headwidth=5, linewidths=1, 35 | headlength=4) 36 | 37 | current_state = state 38 | 39 | plt.show() -------------------------------------------------------------------------------- /grid_modules/gridworld/__init__.py: -------------------------------------------------------------------------------- 1 | FOUR_ROOM_TXT = """########### 2 | # # # 3 | # # 4 | # # # 5 | # # # 6 | ## #### ### 7 | # # # 8 | # # # 9 | # # 10 | #s # # 11 | ###########""".split('\n') 12 | 13 | 14 | BIG_FOUR_ROOM_TXT = """##################### 15 | # # # 16 | # # # 17 | # # 18 | # # # 19 | # # # 20 | # # # 21 | # # # 22 | # # # 23 | # # # 24 | #### ######### ###### 25 | # # # 26 | # # # 27 | # # # 28 | # # 29 | # # # 30 | # s # # 31 | # # # 32 | # # # 33 | # # # 34 | 
#####################""".split('\n') 35 | 36 | ONE_ROOM_TXT = """##################### 37 | # # 38 | # # 39 | # # 40 | # # 41 | # # 42 | # # 43 | # # 44 | # # 45 | # # 46 | # # 47 | # # 48 | # # 49 | #s # 50 | # # 51 | # # 52 | # # 53 | # # 54 | # # 55 | # # 56 | #####################""".split('\n') 57 | 58 | STAIR_TXT = """########### 59 | # # 60 | # ###### 61 | # # 62 | ###### # 63 | # # 64 | # ###### 65 | # # 66 | ###### # 67 | #s # 68 | ###########""".split('\n') -------------------------------------------------------------------------------- /continuous_main.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from arguments import get_args 4 | from continuous_world_modules.dqn_agent import DQNAgent 5 | from continuous_world_modules.env import ContinuousWorld 6 | from continuous_world_modules.geometry import Point 7 | from continuous_world_modules.fb_agent import FBAgent 8 | 9 | import random 10 | import torch 11 | 12 | 13 | def get_env_params(env): 14 | params = {'obs': 441, 15 | 'goal': 441, 16 | 'action': 5, 17 | } 18 | params['max_timesteps'] = 30 19 | return params 20 | 21 | 22 | def launch(args): 23 | 24 | # set random seeds for reproduce 25 | # env.seed(args.seed) 26 | random.seed(args.seed) 27 | np.random.seed(args.seed) 28 | torch.manual_seed(args.seed) 29 | if args.cuda: 30 | torch.cuda.manual_seed(args.seed) 31 | 32 | env = ContinuousWorld(1, wall_pairs=[ 33 | (Point(0.25, 0.0), Point(0.25, 0.4)), 34 | (Point(0.75, 1.0), Point(0.75, 0.6))], 35 | movement_noise=0.01, 36 | threshold_distance=0.05, 37 | seed=args.seed) 38 | 39 | # get the environment parameters 40 | env_params = get_env_params(env) 41 | # create the agent to interact with the environment 42 | if args.agent == 'DQN': 43 | dqn_trainer = DQNAgent(args, env, env_params) 44 | dqn_trainer.learn() 45 | elif args.agent == 'FB': 46 | fb_trainer = FBAgent(args, env, env_params) 47 | fb_trainer.learn() 48 | else: 49 | raise NotImplementedError() 50 | 51 | 52 | if __name__ == '__main__': 53 | # get the params 54 | args = get_args() 55 | launch(args) 56 | -------------------------------------------------------------------------------- /grid_main.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from arguments import get_args 4 | from grid_modules.dqn_agent import DQNAgent 5 | from grid_modules.fb_agent import FBAgent 6 | from grid_modules.gridworld.txt_utilities import get_char_matrix, build_gridworld_from_char_matrix 7 | from grid_modules.gridworld import FOUR_ROOM_TXT, BIG_FOUR_ROOM_TXT 8 | import random 9 | import torch 10 | 11 | 12 | def build_grid(gamma=0.99, seed=123, p_success=1.0): 13 | char_matrix = get_char_matrix(FOUR_ROOM_TXT) 14 | return build_gridworld_from_char_matrix(char_matrix, p_success=p_success, seed=seed, gamma=gamma) 15 | 16 | 17 | def get_env_params(env): 18 | params = {'obs': env.state_space, 19 | 'goal': env.state_space, 20 | 'action': env.action_space, 21 | } 22 | params['max_timesteps'] = 50 23 | return params 24 | 25 | 26 | def launch(args): 27 | 28 | env = build_grid(gamma=args.gamma, seed=args.seed) 29 | # set random seeds for reproduce 30 | # env.seed(args.seed) 31 | random.seed(args.seed) 32 | np.random.seed(args.seed) 33 | torch.manual_seed(args.seed) 34 | if args.cuda: 35 | torch.cuda.manual_seed(args.seed) 36 | # get the environment parameters 37 | env_params = get_env_params(env) 38 | # create the agent to interact with the environment 39 | 
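# For FOUR_ROOM_TXT this is an 11x11 grid, so 'obs' and 'goal' are 121-dimensional
# one-hot encodings, 'action' is 5 (LEFT/RIGHT/UP/DOWN/STAY from gridworld/actions.py),
# and episodes last 50 steps. The --agent flag (see arguments.py and the README)
# selects either the DQN baseline or the forward-backward (FB) agent below.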
if args.agent == 'DQN': 40 | dqn_trainer = DQNAgent(args, env, env_params) 41 | dqn_trainer.learn() 42 | elif args.agent == 'FB': 43 | fb_trainer = FBAgent(args, env, env_params) 44 | fb_trainer.learn() 45 | else: 46 | raise NotImplementedError() 47 | 48 | 49 | if __name__ == '__main__': 50 | # get the params 51 | args = get_args() 52 | launch(args) 53 | -------------------------------------------------------------------------------- /grid_modules/her.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class her_sampler: 5 | def __init__(self, replay_strategy, replay_k, reward_func=None): 6 | self.replay_strategy = replay_strategy 7 | self.replay_k = replay_k 8 | if self.replay_strategy == 'future': 9 | self.future_p = 1 - (1. / (1 + replay_k)) 10 | elif self.replay_strategy == 'none': 11 | self.future_p = 0 12 | else: 13 | raise NotImplementedError() 14 | self.reward_func = reward_func 15 | 16 | def sample_her_transitions(self, episode_batch, batch_size_in_transitions): 17 | T = episode_batch['action'].shape[1] 18 | rollout_batch_size = episode_batch['action'].shape[0] 19 | batch_size = batch_size_in_transitions 20 | # select which rollouts and which timesteps to be used 21 | episode_idxs = np.random.randint(0, rollout_batch_size, batch_size) 22 | t_samples = np.random.randint(T, size=batch_size) 23 | transitions = {key: episode_batch[key][episode_idxs, t_samples].copy() for key in episode_batch.keys()} 24 | # her idx 25 | her_indexes = np.where(np.random.uniform(size=batch_size) < self.future_p) 26 | # import pdb 27 | # pdb.set_trace() 28 | future_offset = np.random.uniform(size=batch_size) * (T - t_samples) 29 | future_offset = future_offset.astype(int) 30 | future_t = (t_samples + 1 + future_offset)[her_indexes] 31 | # replace go with achieved goal 32 | future_obs = episode_batch['obs'][episode_idxs[her_indexes], future_t] 33 | transitions['g'][her_indexes] = future_obs 34 | # to get the params to re-compute reward 35 | transitions['reward'] = np.expand_dims(self.reward_func(transitions['obs_next'], transitions['g']), 1) 36 | transitions = {k: transitions[k].reshape(batch_size, *transitions[k].shape[1:]) for k in transitions.keys()} 37 | 38 | return transitions 39 | -------------------------------------------------------------------------------- /continuous_world_modules/featurizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class RadialBasisFunction2D(object): 6 | def __init__(self, size, dim, sigma, cuda=False): 7 | FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor 8 | self.size = size 9 | self.dim = dim 10 | self.sigma = sigma 11 | 12 | xlist = np.linspace(0, size, dim) 13 | ylist = np.linspace(0, size, dim) 14 | XX, YY = np.meshgrid(xlist, ylist) 15 | 16 | self.XX = XX 17 | self.YY = YY 18 | 19 | self.x_mu = FloatTensor(XX.flatten()) 20 | self.y_mu = FloatTensor(YY.flatten()) 21 | 22 | def transform(self, A): 23 | distance = (A[:, 0].reshape(-1, 1) - self.x_mu[None])**2 + (A[:, 1].reshape(-1, 1) - self.y_mu[None])**2 24 | # X_mu = np.broadcast_to(self.x_mu, (A.shape[0], self.dim**2)) 25 | # Y_mu = np.broadcast_to(self.y_mu, (A.shape[0], self.dim**2)) 26 | # X = np.broadcast_to(A[:,0].reshape(-1,1), (A.shape[0], self.dim**2)) 27 | # Y = np.broadcast_to(A[:,1].reshape(-1,1), (A.shape[0], self.dim**2)) 28 | # 29 | # distance = ((X - X_mu)**2+((Y - Y_mu)**2)) 30 | weights = torch.exp(-distance/(2 * 
(self.sigma**2))) 31 | return weights / weights.sum(axis=1, keepdims=True) 32 | 33 | def inverse_transform(self, A): 34 | index = np.argmax(A, axis=1) 35 | result = [] 36 | for idx in index: 37 | i, j = self._1d_index_to_2d_index(idx) 38 | result.append([self.XX[i][j], self.YY[i][j]]) 39 | 40 | return np.array(result) 41 | 42 | def _1d_index_to_2d_index(self, index): 43 | i = index // self.dim 44 | j = index % self.dim 45 | 46 | return i, j 47 | 48 | 49 | if __name__ == '__main__': 50 | featurizer = RadialBasisFunction2D(10, 11, 1) 51 | A = torch.FloatTensor(np.array([[5, 5], [3, 6]])) 52 | print(featurizer.inverse_transform(featurizer.transform(A))) 53 | 54 | 55 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Learning One Representation to Optimize All Rewards 2 | This repo contains code for the paper 3 | 4 | [Learning One Representation to Optimize All Rewards. 5 | Ahmed Touati, Yann Ollivier. NeurIPS 2021](https://arxiv.org/pdf/2103.07945.pdf) 6 | 7 | ## Install Requirements 8 | 9 | ```bash 10 | pip install 'gym[atari]' 11 | pip install torch 12 | pip install opencv-python 13 | # Baselines for Atari preprocessing 14 | # Tensorflow is a dependency, but you don't need to install the GPU version 15 | conda install tensorflow 16 | pip install git+git://github.com/openai/baselines 17 | # AtariARI (Atari Annotated RAM Interface) 18 | pip install git+git://github.com/mila-iqia/atari-representation-learning.git 19 | ``` 20 | 21 | ## Instruction to run the code 22 | If you want to use GPU, just add the flag `--cuda`. 23 | 1. train **discrete maze**: 24 | ```bash 25 | python grid_main.py \ 26 | --agent FB \ 27 | --n-cycles 25 \ 28 | --n-test-rollouts 10 \ 29 | --num-rollouts-per-cycle 4 \ 30 | --update-eps 1 \ 31 | --soft-update \ 32 | --temp 200 \ 33 | --seed 0 \ 34 | --gamma 0.99 \ 35 | --lr 0.0005 \ 36 | --polyak 0.95 \ 37 | --embed-dim 100 \ 38 | --w-sampling cauchy_ball \ 39 | --n-epochs 200 \ 40 | ``` 41 | 2. train **continuous maze**: 42 | ```bash 43 | python continuous_main.py \ 44 | --agent FB \ 45 | --n-cycles 25 \ 46 | --n-test-rollouts 10 \ 47 | --num-rollouts-per-cycle 4 \ 48 | --update-eps 1 \ 49 | --soft-update \ 50 | --temp 200 \ 51 | --seed 0 \ 52 | --gamma 0.99 \ 53 | --lr 0.0005 \ 54 | --polyak 0.95 \ 55 | --embed-dim 100 \ 56 | --w-sampling cauchy_ball \ 57 | --n-epochs 200 \ 58 | ``` 59 | 3. train **atari**: 60 | ```bash 61 | python atari_main.py \ 62 | --agent FB \ 63 | --n-cycles 25 \ 64 | --n-test-rollouts 10 \ 65 | --num-rollouts-per-cycle 2 \ 66 | --update-eps 0.2 \ 67 | --soft-update \ 68 | --temp 200 \ 69 | --seed 0 \ 70 | --gamma 0.9 \ 71 | --lr 0.0005 \ 72 | --polyak 0.95 \ 73 | --embed-dim 100 \ 74 | --w-sampling cauchy_ball \ 75 | --n-epochs 200 \ 76 | ``` 77 | -------------------------------------------------------------------------------- /atari_modules/her.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class her_sampler: 5 | def __init__(self, replay_strategy, replay_k, reward_func=None): 6 | self.replay_strategy = replay_strategy 7 | self.replay_k = replay_k 8 | if self.replay_strategy == 'future': 9 | self.future_p = 1 - (1. 
/ (1 + replay_k)) 10 | elif self.replay_strategy == 'none': 11 | self.future_p = 0 12 | else: 13 | raise NotImplementedError() 14 | self.reward_func = reward_func 15 | 16 | def sample_her_transitions(self, episode_batch, batch_size_in_transitions): 17 | T = episode_batch['actions'].shape[1] 18 | rollout_batch_size = episode_batch['actions'].shape[0] 19 | batch_size = batch_size_in_transitions 20 | # select which rollouts and which timesteps to be used 21 | episode_idxs = np.random.randint(0, rollout_batch_size, batch_size) 22 | t_samples = np.random.randint(T, size=batch_size) 23 | transitions = {key: episode_batch[key][episode_idxs, t_samples].copy() for key in episode_batch.keys()} 24 | # her idx 25 | her_indexes = np.where(np.random.uniform(size=batch_size) < self.future_p) 26 | # import pdb 27 | # pdb.set_trace() 28 | # future_offset = np.random.uniform(size=batch_size) * (T - t_samples) 29 | # future_offset = future_offset.astype(int) 30 | future_offset = T - t_samples 31 | for i in range(batch_size): 32 | ends = np.where(episode_batch['done'][episode_idxs[i]][t_samples[i]:])[0] 33 | if len(ends) > 0: 34 | future_offset[i] = ends[0] 35 | future_offset = (np.random.uniform(size=batch_size) * future_offset).astype(int) 36 | future_t = (t_samples + 1 + future_offset)[her_indexes] 37 | # replace go with achieved goal 38 | future_ag = episode_batch['ag'][episode_idxs[her_indexes], future_t] 39 | transitions['g'][her_indexes] = future_ag 40 | # to get the params to re-compute reward 41 | transitions['r'] = np.expand_dims(self.reward_func(transitions['ag_next'], transitions['g'], None), 1) 42 | transitions = {k: transitions[k].reshape(batch_size, *transitions[k].shape[1:]) for k in transitions.keys()} 43 | 44 | return transitions 45 | -------------------------------------------------------------------------------- /discrete_action_robots_modules/normalizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class normalizer: 4 | def __init__(self, size, eps=1e-2, default_clip_range=np.inf): 5 | self.size = size 6 | self.eps = eps 7 | self.default_clip_range = default_clip_range 8 | # some local information 9 | self.local_sum = np.zeros(self.size, np.float32) 10 | self.local_sumsq = np.zeros(self.size, np.float32) 11 | self.local_count = np.zeros(1, np.float32) 12 | # get the total sum sumsq and sum count 13 | self.total_sum = np.zeros(self.size, np.float32) 14 | self.total_sumsq = np.zeros(self.size, np.float32) 15 | self.total_count = np.ones(1, np.float32) 16 | # get the mean and std 17 | self.mean = np.zeros(self.size, np.float32) 18 | self.std = np.ones(self.size, np.float32) 19 | 20 | # update the parameters of the normalizer 21 | def update(self, v): 22 | v = v.reshape(-1, self.size) 23 | # do the computing 24 | self.local_sum += v.sum(axis=0) 25 | self.local_sumsq += (np.square(v)).sum(axis=0) 26 | self.local_count[0] += v.shape[0] 27 | 28 | def recompute_stats(self): 29 | local_count = self.local_count.copy() 30 | local_sum = self.local_sum.copy() 31 | local_sumsq = self.local_sumsq.copy() 32 | # reset 33 | self.local_count[...] = 0 34 | self.local_sum[...] = 0 35 | self.local_sumsq[...] 
= 0 36 | # synrc the stats 37 | sync_sum, sync_sumsq, sync_count = local_sum, local_sumsq, local_count 38 | # update the total stuff 39 | self.total_sum += sync_sum 40 | self.total_sumsq += sync_sumsq 41 | self.total_count += sync_count 42 | # calculate the new mean and std 43 | self.mean = self.total_sum / self.total_count 44 | self.std = np.sqrt(np.maximum(np.square(self.eps), (self.total_sumsq / self.total_count) - np.square( 45 | self.total_sum / self.total_count))) 46 | 47 | # normalize the observation 48 | def normalize(self, v, clip_range=None): 49 | if clip_range is None: 50 | clip_range = self.default_clip_range 51 | return np.clip((v - self.mean) / (self.std), -clip_range, clip_range) 52 | # return np.clip(v, -clip_range, clip_range) 53 | -------------------------------------------------------------------------------- /atari_modules/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | # define the critic network 7 | class critic(nn.Module): 8 | def __init__(self, env_params): 9 | super(critic, self).__init__() 10 | self.conv1 = nn.Conv2d(env_params['obs'][-1], 32, 8, stride=4, padding=1) 11 | self.conv2 = nn.Conv2d(32, 64, 4, stride=2) 12 | self.conv3 = nn.Conv2d(64, 64, 3) 13 | self.fc1 = nn.Linear(3136 + env_params['goal'], 512) 14 | self.fc2 = nn.Linear(512, env_params['action']) 15 | 16 | def forward(self, obs, goal): 17 | x = F.relu(self.conv1(obs)) 18 | x = F.relu(self.conv2(x)) 19 | x = F.relu(self.conv3(x)) 20 | x = x.reshape(-1, 3136) 21 | x = torch.cat([x, goal], dim=1) 22 | x = F.relu(self.fc1(x)) 23 | x = self.fc2(x) 24 | return x 25 | 26 | 27 | class BackwardMap(nn.Module): 28 | def __init__(self, env_params, embed_dim): 29 | super(BackwardMap, self).__init__() 30 | self.fc1 = nn.Linear(env_params['goal'], 256) 31 | self.fc2 = nn.Linear(256, 256) 32 | self.fc3 = nn.Linear(256, 256) 33 | self.backward_out = nn.Linear(256, embed_dim) 34 | 35 | def forward(self, x): 36 | x = F.relu(self.fc1(x)) 37 | x = F.relu(self.fc2(x)) 38 | x = F.relu(self.fc3(x)) 39 | backward_value = self.backward_out(x) 40 | return backward_value 41 | 42 | 43 | class ForwardMap(nn.Module): 44 | def __init__(self, env_params, embed_dim): 45 | super(ForwardMap, self).__init__() 46 | self.embed_dim = embed_dim 47 | self.num_actions = env_params['action'] 48 | self.conv1 = nn.Conv2d(env_params['obs'][-1], 32, 8, stride=4, padding=1) 49 | self.conv2 = nn.Conv2d(32, 64, 4, stride=2) 50 | self.conv3 = nn.Conv2d(64, 64, 3) 51 | self.fc1 = nn.Linear(3136 + embed_dim, 512) 52 | self.forward_out = nn.Linear(512, embed_dim * env_params['action']) 53 | 54 | def forward(self, obs, w): 55 | w = w / torch.sqrt(1 + torch.norm(w, dim=-1, keepdim=True) ** 2 / self.embed_dim) 56 | x = F.relu(self.conv1(obs)) 57 | x = F.relu(self.conv2(x)) 58 | x = F.relu(self.conv3(x)) 59 | x = x.reshape(-1, 3136) 60 | x = torch.cat([x, w], dim=1) 61 | x = F.relu(self.fc1(x)) 62 | forward_value = self.forward_out(x) 63 | return forward_value.reshape(-1, self.embed_dim, self.num_actions) 64 | -------------------------------------------------------------------------------- /discrete_action_robots_modules/replay_buffer.py: -------------------------------------------------------------------------------- 1 | import threading 2 | import numpy as np 3 | 4 | """ 5 | the replay buffer here is basically from the openai baselines code 6 | 7 | """ 8 | 9 | 10 | class replay_buffer: 11 | def __init__(self, env_params, 
buffer_size, sample_func): 12 | self.env_params = env_params 13 | self.T = env_params['max_timesteps'] 14 | self.size = buffer_size // self.T 15 | # memory management 16 | self.current_size = 0 17 | self.n_transitions_stored = 0 18 | self.sample_func = sample_func 19 | # create the buffer to store info 20 | self.buffers = {'obs': np.empty([self.size, self.T + 1, self.env_params['obs']]), 21 | 'ag': np.empty([self.size, self.T + 1, self.env_params['goal']]), 22 | 'g': np.empty([self.size, self.T, self.env_params['goal']]), 23 | 'actions': np.empty([self.size, self.T]), 24 | } 25 | # thread lock 26 | self.lock = threading.Lock() 27 | 28 | # store the episode 29 | def store_episode(self, episode_batch): 30 | mb_obs, mb_ag, mb_g, mb_actions = episode_batch 31 | batch_size = mb_obs.shape[0] 32 | with self.lock: 33 | idxs = self._get_storage_idx(inc=batch_size) 34 | # store the informations 35 | self.buffers['obs'][idxs] = mb_obs 36 | self.buffers['ag'][idxs] = mb_ag 37 | self.buffers['g'][idxs] = mb_g 38 | self.buffers['actions'][idxs] = mb_actions 39 | self.n_transitions_stored += self.T * batch_size 40 | 41 | # sample the data from the replay buffer 42 | def sample(self, batch_size): 43 | temp_buffers = {} 44 | with self.lock: 45 | for key in self.buffers.keys(): 46 | temp_buffers[key] = self.buffers[key][:self.current_size] 47 | temp_buffers['obs_next'] = temp_buffers['obs'][:, 1:, :] 48 | temp_buffers['ag_next'] = temp_buffers['ag'][:, 1:, :] 49 | # sample transitions 50 | transitions = self.sample_func(temp_buffers, batch_size) 51 | return transitions 52 | 53 | def _get_storage_idx(self, inc=None): 54 | inc = inc or 1 55 | if self.current_size + inc <= self.size: 56 | idx = np.arange(self.current_size, self.current_size + inc) 57 | elif self.current_size < self.size: 58 | overflow = inc - (self.size - self.current_size) 59 | idx_a = np.arange(self.current_size, self.size) 60 | idx_b = np.random.randint(0, self.current_size, overflow) 61 | idx = np.concatenate([idx_a, idx_b]) 62 | else: 63 | idx = np.random.randint(0, self.size, inc) 64 | self.current_size = min(self.size, self.current_size + inc) 65 | if inc == 1: 66 | idx = idx[0] 67 | return idx 68 | -------------------------------------------------------------------------------- /grid_modules/gridworld/txt_utilities.py: -------------------------------------------------------------------------------- 1 | """Utilities to help load gridworlds from a text file with random goal. 2 | """ 3 | from grid_modules.gridworld.helper_utilities import flatten_state 4 | from grid_modules.gridworld.builder_tools import TransitionMatrixBuilder, create_reward_matrix 5 | from grid_modules.gridworld.env import GridWorldMDP 6 | 7 | 8 | def get_char_matrix(raw_file): 9 | """ 10 | :param raw_file: Either a python file object (open) 11 | or a list of strings containing the lines. 12 | """ 13 | return [[c for c in line.strip('\n')] for line in raw_file] 14 | 15 | 16 | def build_gridworld_from_char_matrix(char_matrix, p_success=1, seed=2017, 17 | gamma=1, skip_checks=False, 18 | transition_matrix_builder_cls=TransitionMatrixBuilder): 19 | """ 20 | A parser to build a gridworld from a text file. 21 | Each grid has ONE start and goal location. 22 | A reward of +1 is positioned at the goal location. 23 | :param char_matrix: Matrix of characters. 24 | :param p_success: Probability that the action is successful. 25 | :param seed: The seed for the GridWorldMDP object. 26 | :param skip_checks: Skips assertion checks. 
27 | :transition_matrix_builder_cls: The transition matrix builder to use. 28 | :return: 29 | """ 30 | grid_size = len(char_matrix[0]) 31 | 32 | if not skip_checks: 33 | assert(len(char_matrix) == grid_size), 'Mismatch in the columns.' 34 | for row in char_matrix: 35 | assert(len(row) == grid_size), 'Mismatch in the rows.' 36 | # ... 37 | wall_locs = [] 38 | start_loc = None 39 | goal_loc = None 40 | for r in range(grid_size): 41 | for c in range(grid_size): 42 | char = char_matrix[r][c] 43 | if char == '#': 44 | wall_locs.append((r, c)) 45 | elif char == 's': 46 | assert start_loc is None, 'Start loc was overwritten!' 47 | start_loc = (r, c) 48 | elif char == 'g': 49 | assert goal_loc is None, 'Goal loc was overwritten!' 50 | goal_loc = (r, c) 51 | elif char != ' ': 52 | raise ValueError('Unknown character {} in grid.'.format(char)) 53 | # Attempt to make the desired gridworld. 54 | if goal_loc: 55 | reward_spec = {(goal_loc[0], goal_loc[1]): +1} 56 | else: 57 | reward_spec = {} 58 | 59 | tmb = transition_matrix_builder_cls(grid_size, has_terminal_state=False) 60 | tmb.add_grid(terminal_states=[], p_success=p_success) 61 | for (r, c) in wall_locs: 62 | tmb.add_wall_at((r, c)) 63 | P = tmb.P 64 | R = create_reward_matrix(P.shape[0], grid_size, reward_spec, action_space=5) 65 | p0 = flatten_state(start_loc, grid_size, R.shape[0]) 66 | 67 | gw = GridWorldMDP(P, R, gamma, p0, terminal_states=[], 68 | size=grid_size, wall_locs=wall_locs, 69 | goal_loc=goal_loc, seed=seed) 70 | return gw 71 | -------------------------------------------------------------------------------- /arguments.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | """ 4 | Here are the param for the training 5 | 6 | """ 7 | 8 | 9 | def get_args(): 10 | parser = argparse.ArgumentParser() 11 | # the environment setting 12 | parser.add_argument('--agent', type=str, default='FB', help='[DQN, Z, FB]') 13 | parser.add_argument('--num-gpi', type=int, default=10, help='number of behaviors for GPI') 14 | parser.add_argument('--w-sampling', type=str, default='cauchy_ball', help='[uniform_ball, goal_oriented]') 15 | parser.add_argument('--embed-dim', type=int, default=500, help='embedding dimension') 16 | parser.add_argument('--reg-coef', type=float, default=1.0, help='backward regularization coefficient') 17 | parser.add_argument('--n-epochs', type=int, default=200, help='the number of epochs to train the agent') 18 | parser.add_argument('--n-cycles', type=int, default=5, help='the times to collect samples per epoch') 19 | parser.add_argument('--n-batches', type=int, default=40, help='the times to update the network') 20 | parser.add_argument('--save-interval', type=int, default=5, help='the interval that save the trajectory') 21 | parser.add_argument('--seed', type=int, default=123, help='random seed') 22 | parser.add_argument('--num-workers', type=int, default=1, help='the number of cpus to collect samples') 23 | parser.add_argument('--replay-strategy', type=str, default='none', help='the HER strategy') 24 | parser.add_argument('--clip-return', type=float, default=50, help='if clip the returns') 25 | parser.add_argument('--save-dir', type=str, default='saved_models', help='the path to save the models') 26 | parser.add_argument('--noise-eps', type=float, default=0.2, help='noise eps') 27 | parser.add_argument('--random-eps', type=float, default=0.3, help='random eps') 28 | parser.add_argument('--buffer-size', type=int, default=int(1e6), help='the size of the buffer') 29 | 
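# Note: with the 'future' HER strategy, replay_k controls the fraction of relabelled
# goals via future_p = 1 - 1/(1 + replay_k), i.e. about 0.8 of sampled transitions
# are relabelled for the default replay_k=4 (see her.py).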
parser.add_argument('--replay-k', type=int, default=4, help='ratio to be replace') 30 | parser.add_argument('--clip-obs', type=float, default=200, help='the clip ratio') 31 | parser.add_argument('--batch-size', type=int, default=128, help='the sample batch size') 32 | parser.add_argument('--test-batch-size', type=int, default=1000, help='the sample test batch size') 33 | parser.add_argument('--gamma', type=float, default=0.9, help='the discount factor') 34 | parser.add_argument('--action-l2', type=float, default=1, help='l2 reg') 35 | parser.add_argument('--lr', type=float, default=0.001, help='the learning rate') 36 | parser.add_argument('--polyak', type=float, default=0.95, help='the average coefficient') 37 | parser.add_argument('--n-test-rollouts', type=int, default=10, help='the number of tests') 38 | parser.add_argument('--clip-range', type=float, default=5, help='the clip range') 39 | parser.add_argument('--demo-length', type=int, default=20, help='the demo length') 40 | parser.add_argument('--cuda', action='store_true', help='if use gpu do the acceleration') 41 | parser.add_argument('--soft-update', action='store_true', help='if use soft bellman backup') 42 | parser.add_argument('--num-rollouts-per-cycle', type=int, default=2, help='the rollouts per mpi') 43 | parser.add_argument('--temp', type=float, default=200, help='Boltzmann temperature') 44 | parser.add_argument('--update-eps', type=float, default=0.2, help='exploration epsilon') 45 | 46 | args = parser.parse_args() 47 | 48 | return args 49 | -------------------------------------------------------------------------------- /discrete_action_robots_modules/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | """ 6 | the input x in both networks should be [o, g], where o is the observation and g is the goal. 
7 | 8 | """ 9 | 10 | 11 | class critic(nn.Module): 12 | def __init__(self, env_params): 13 | super(critic, self).__init__() 14 | self.fc1 = nn.Linear(env_params['obs'] + env_params['goal'], 256) 15 | self.fc2 = nn.Linear(256, 256) 16 | self.fc3 = nn.Linear(256, 256) 17 | self.q_out = nn.Linear(256, env_params['action']) 18 | 19 | def forward(self, obs, g): 20 | x = torch.cat([obs, g], dim=1) 21 | x = F.relu(self.fc1(x)) 22 | x = F.relu(self.fc2(x)) 23 | x = F.relu(self.fc3(x)) 24 | q_value = self.q_out(x) 25 | return q_value 26 | 27 | 28 | class VMap(nn.Module): 29 | def __init__(self, env_params, embed_dim): 30 | super(VMap, self).__init__() 31 | self.embed_dim = embed_dim 32 | self.fc1 = nn.Linear(env_params['obs'] + embed_dim + env_params['goal'], 256) 33 | self.fc2 = nn.Linear(256, 256) 34 | self.fc3 = nn.Linear(256, 256) 35 | self.v_out = nn.Linear(256, 1) 36 | 37 | def forward(self, obs, w, g): 38 | w = w / torch.sqrt(1 + torch.norm(w, dim=-1, keepdim=True) ** 2 / self.embed_dim) 39 | x = torch.cat([obs, w, g], dim=1) 40 | x = F.relu(self.fc1(x)) 41 | x = F.relu(self.fc2(x)) 42 | x = F.relu(self.fc3(x)) 43 | v_value = self.v_out(x) 44 | return v_value 45 | 46 | 47 | class ZMap(nn.Module): 48 | def __init__(self, env_params): 49 | super(ZMap, self).__init__() 50 | self.fc1 = nn.Linear(env_params['obs'] + env_params['goal'] + env_params['goal'], 256) 51 | self.fc2 = nn.Linear(256, 256) 52 | self.fc3 = nn.Linear(256, 256) 53 | self.z_out = nn.Linear(256, env_params['action']) 54 | 55 | def forward(self, obs, g, g_other): 56 | assert g.shape[-1] == g_other.shape[-1] 57 | x = torch.cat([obs, g, g_other], dim=1) 58 | x = F.relu(self.fc1(x)) 59 | x = F.relu(self.fc2(x)) 60 | x = F.relu(self.fc3(x)) 61 | z_value = self.z_out(x) 62 | return z_value 63 | 64 | 65 | class BackwardMap(nn.Module): 66 | def __init__(self, env_params, embed_dim): 67 | super(BackwardMap, self).__init__() 68 | self.fc1 = nn.Linear(env_params['goal'], 256) 69 | self.fc2 = nn.Linear(256, 256) 70 | self.fc3 = nn.Linear(256, 256) 71 | self.backward_out = nn.Linear(256, embed_dim) 72 | 73 | def forward(self, x): 74 | x = F.relu(self.fc1(x)) 75 | x = F.relu(self.fc2(x)) 76 | x = F.relu(self.fc3(x)) 77 | backward_value = self.backward_out(x) 78 | return backward_value 79 | 80 | 81 | class ForwardMap(nn.Module): 82 | def __init__(self, env_params, embed_dim): 83 | super(ForwardMap, self).__init__() 84 | self.embed_dim = embed_dim 85 | self.num_actions = env_params['action'] 86 | self.fc1 = nn.Linear(env_params['obs'] + embed_dim, 256) 87 | self.fc2 = nn.Linear(256, 256) 88 | self.fc3 = nn.Linear(256, 256) 89 | self.forward_out = nn.Linear(256, embed_dim * env_params['action']) 90 | 91 | def forward(self, obs, w): 92 | w = w / torch.sqrt(1 + torch.norm(w, dim=-1, keepdim=True) ** 2 / self.embed_dim) 93 | x = torch.cat([obs, w], dim=1) 94 | x = F.relu(self.fc1(x)) 95 | x = F.relu(self.fc2(x)) 96 | x = F.relu(self.fc3(x)) 97 | forward_value = self.forward_out(x) 98 | 99 | return forward_value.reshape(-1, self.embed_dim, self.num_actions) -------------------------------------------------------------------------------- /grid_modules/gridworld/env.py: -------------------------------------------------------------------------------- 1 | """ 2 | A simple grid world environment 3 | """ 4 | 5 | import numpy as np 6 | from grid_modules.common import MDP 7 | from .helper_utilities import flatten_state, unflatten_state, get_reachable_id,\ 8 | from_xy_to_id, from_id_to_xy 9 | 10 | 11 | class GridWorldMDP(MDP): 12 | def __init__(self, P, R, 
gamma, p0, terminal_states, size, wall_locs, goal_loc=None, 13 | seed=1337, skip_check=False, convert_terminal_states_to_ints=False): 14 | """ 15 | (!) if terminal_states is not empty then there will be an absorbing state. So 16 | the actual number of states will be size x size + 1 17 | if there is a terminal state, it should be the last one. 18 | :param P: Transition matrix |S| x |A| x |S| 19 | :param R: Transition matrix |S| x |A| 20 | :param gamma: discount factor 21 | :param p0: initial starting distribution 22 | :param terminal_states: Must be a list of (x,y) tuples. use skip_terminal_state_conversion if giving ints 23 | :param size: the size of the grid world (i.e there are size x size (+ 1)= |S| states) 24 | :param seed: 25 | :param skip_check: 26 | """ 27 | if not convert_terminal_states_to_ints: 28 | terminal_states = list(map(lambda tupl: int(size * tupl[0] + tupl[1]), terminal_states)) 29 | self.size = size 30 | self.wall_locs = wall_locs 31 | if goal_loc is not None: 32 | self.goal_loc = goal_loc 33 | self.goal_id = from_xy_to_id(goal_loc, self.size) 34 | self.goal = flatten_state(self.goal_loc, self.size, self.state_space) 35 | self.human_state = (None, None) 36 | self.has_absorbing_state = len(terminal_states) > 0 37 | super().__init__(P, R, gamma, p0, terminal_states, seed=seed, skip_check=skip_check) 38 | self.reachable_states = get_reachable_id(self.state_space, self.size, self.wall_locs) 39 | # set initial state distribution to uniform 40 | p0 = np.zeros(self.state_space) 41 | p0[self.reachable_states] = 1 42 | p0 /= p0.sum() 43 | self.p0 = p0 44 | 45 | def set_goal(self, goal_id): 46 | self.goal_id = goal_id 47 | self.goal_loc = from_id_to_xy(goal_id, self.size) 48 | self.goal = flatten_state(self.goal_loc, self.size, self.state_space) 49 | R = np.zeros((self.state_space, self.action_space)) # S x A 50 | R[goal_id, :] = 1 51 | self.R = R 52 | 53 | def reset(self): 54 | super().reset() 55 | self.goal_id = self.sample_goal() 56 | self.goal_loc = from_id_to_xy(self.goal_id, self.size) 57 | self.goal = flatten_state(self.goal_loc, self.size, self.state_space) 58 | R = np.zeros((self.state_space, self.action_space)) # S x A 59 | R[self.goal_id, :] = 1 60 | self.R = R 61 | self.human_state = self.unflatten_state(self.current_state) 62 | return self.current_state 63 | 64 | def sample_goal(self): 65 | goal = np.random.choice(self.reachable_states, size=1, replace=False) 66 | return goal 67 | 68 | def flatten_state(self, state): 69 | """Flatten state (x,y) into a one hot vector""" 70 | return flatten_state(state, self.size, self.state_space) 71 | 72 | def unflatten_state(self, onehot): 73 | """Unflatten a one hot vector into a (x,y) pair""" 74 | return unflatten_state(onehot, self.size, self.has_absorbing_state) 75 | 76 | def step(self, action): 77 | state, reward, done, info = super().step(action) 78 | self.human_state = self.unflatten_state(self.current_state) 79 | return state, reward, done, info 80 | 81 | def set_current_state_to(self, tuple_state): 82 | return super().set_current_state_to(self.flatten_state(tuple_state).argmax()) -------------------------------------------------------------------------------- /grid_modules/common.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from grid_modules import utils 3 | from grid_modules.exceptions import InvalidActionError, EpisodeDoneError 4 | 5 | 6 | class Env(object): 7 | """ 8 | Abstract Environment wrapper. 
9 | """ 10 | def __init__(self, seed): 11 | """ 12 | :param seed: A seed for the random number generator. 13 | """ 14 | self.set_seed(seed) 15 | 16 | def set_seed(self, seed): 17 | self.rng = np.random.RandomState(seed) 18 | 19 | 20 | class MDP(Env): 21 | def __init__(self, P, R, gamma, p0, terminal_states, seed=1337, skip_check=False): 22 | """ 23 | A simple MDP simulator. 24 | :param P: The transition matrix of size |S|x|A|x|S| 25 | :param R: The reward criterion |S|x|A| 26 | :param gamma: the discount factor. 27 | :param p0: the distribution over starting states |S| (must sum to 1.) 28 | :param terminal_states: A list of integers which indicate terminal states, used to end episodes. 29 | Note that in the transition matrix these 30 | should be absorbing states to ensure calculations are correct. 31 | :param seed: the random seed for simulations. 32 | """ 33 | super().__init__(seed) 34 | if not skip_check: assert np.allclose(P.sum(axis=2), 1), 'Transition matrix does not seem to be a stochastic matrix ' \ 35 | '(i.e. the sum over states for each action doesn not equal 1' 36 | self.P = P 37 | self.R = R 38 | self.state_space = P.shape[0] 39 | self.action_space = R.shape[1] 40 | if not skip_check: assert self.state_space == P.shape[2], '3rd Dimension of Transition Matrix is not of size |S|' 41 | if not skip_check: assert self.action_space == P.shape[1], '2nd Dimension of Transition Matrix is not of size |A|' 42 | if not skip_check: assert self.state_space == R.shape[0], '1st Dimesnion of Reward Matrix is not of size |S|' 43 | self.gamma = gamma 44 | if not skip_check: assert self.state_space == p0.shape[0], 'Distribution over initial states is not over |S|' 45 | self.p0 = p0 46 | self.terminal_states = terminal_states 47 | self.current_state = None 48 | # self.reset() 49 | 50 | def reset(self): 51 | integer_representation = np.random.choice(np.arange(self.state_space), p=self.p0) 52 | self.current_state = utils.convert_int_rep_to_onehot(integer_representation, self.state_space) 53 | self.done = False 54 | return self.current_state 55 | 56 | def set_current_state_to(self, state): 57 | self.current_state = utils.convert_int_rep_to_onehot(state, self.state_space) 58 | self.done = False 59 | return self.current_state 60 | 61 | def step(self, action): 62 | """ 63 | :param action: An integer representing the action taken. 64 | :return: 65 | """ 66 | if self.done: 67 | raise EpisodeDoneError('The episode has terminated. Use .reset() to restart the episode.') 68 | if action >= self.action_space or not isinstance(action, int): 69 | raise InvalidActionError('Invalid action {}. It must be an integer between 0 and {}'.format(action, self.action_space-1)) 70 | 71 | # we end from this episode onwards. 72 | # this check is done after entering terminal state 73 | # because we can only give the reward after leaving 74 | # a terminal state. 
75 | if self.current_state.argmax() in self.terminal_states: 76 | self.done = True 77 | 78 | # get the vector representing the next state probabilities: 79 | current_state_idx = utils.convert_onehot_to_int(self.current_state) 80 | next_state_probs = self.P[current_state_idx, action] 81 | 82 | # sample the next state 83 | sampled_next_state = self.rng.choice(np.arange(self.state_space), p=next_state_probs) 84 | # observe the reward 85 | reward = self.R[current_state_idx, action] 86 | 87 | self.current_state = utils.convert_int_rep_to_onehot(sampled_next_state, self.state_space) 88 | 89 | return self.current_state, reward, self.done, {'gamma':self.gamma} -------------------------------------------------------------------------------- /continuous_world_modules/geometry.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import numpy as np 3 | from typing import Optional, List, Tuple, Any, Dict, Union, Callable 4 | 5 | #@title Point Class 6 | 7 | 8 | @dataclasses.dataclass(order=True, frozen=True) 9 | class Point: 10 | """A class representing a point in 2D space. 11 | 12 | Comes with some convenience functions. 13 | """ 14 | x: float 15 | y: float 16 | 17 | def sum(self): 18 | return self.x + self.y 19 | 20 | def l2norm(self): 21 | """Computes the L2 norm of the point.""" 22 | return np.sqrt(self.x * self.x + self.y * self.y) 23 | 24 | def __add__(self, other: 'Point'): 25 | return Point(self.x + other.x, self.y + other.y) 26 | 27 | def __sub__(self, other: 'Point'): 28 | return Point(self.x - other.x, self.y - other.y) 29 | 30 | def normal_sample_around(self, scale: float): 31 | """Samples a point around the current point based on some noise.""" 32 | new_coords = np.random.normal(dataclasses.astuple(self), scale) 33 | new_coords = new_coords.astype(np.float32) 34 | return Point(*new_coords) 35 | 36 | def is_close_to(self, other: 'Point', diff: float = 1e-4): 37 | """Determines if one point is close to another.""" 38 | point_diff = self - other 39 | if abs(point_diff.x) <= diff and abs(point_diff.y) <= diff: 40 | return True 41 | else: 42 | return False 43 | 44 | # # Intersection code. 45 | # See Sedgewick, Robert, and Kevin Wayne. Algorithms. , 2011. 
46 | # Chapter 6.1 on Geometric Primitives 47 | # https://algs4.cs.princeton.edu/91primitives/ 48 | 49 | 50 | def on_segment(a: Point, b: Point, c: Point): 51 | x1, x2, x3 = a.x, b.x, c.x 52 | 53 | y1, y2, y3 = a.y, b.y, c.y 54 | 55 | if x1 == x2: 56 | on_and_between = (x3 == x2) and (y1 <= y3 <= y2) 57 | else: 58 | slope = (y2 - y1) / (x2 - x1) 59 | 60 | pt3_on = (y3 - y1) == slope * (x3 - x1) 61 | 62 | pt3_between = (min(x1, x2) <= x3 <= max(x1, x2)) and (min(y1, y2) <= y3 <= max(y1, y2)) 63 | on_and_between = pt3_on and pt3_between 64 | 65 | return on_and_between 66 | 67 | 68 | def _check_counter_clockwise(a: Point, b: Point, c: Point): 69 | """Checks if 3 points are counter clockwise to each other.""" 70 | slope_AB_numerator = (b.y - a.y) 71 | slope_AB_denominator = (b.x - a.x) 72 | slope_AC_numerator = (c.y - a.y) 73 | slope_AC_denominator = (c.x - a.x) 74 | return (slope_AC_numerator * slope_AB_denominator >= \ 75 | slope_AB_numerator * slope_AC_denominator) 76 | 77 | 78 | def intersect(segment_1: Tuple[Point, Point], segment_2: Tuple[Point, Point]): 79 | """Checks if two line segments intersect.""" 80 | a, b = segment_1 81 | c, d = segment_2 82 | 83 | if on_segment(a, b, c) or on_segment(a, b, d) or on_segment(c, d, a) or on_segment(c, d, b): 84 | return True 85 | 86 | # Checking if there is an intersection is equivalent to: 87 | # Exactly one counter clockwise path to D (from A or B) via C. 88 | AC_ccw_CD = _check_counter_clockwise(a, c, d) 89 | BC_ccw_CD = _check_counter_clockwise(b, c, d) 90 | toD_via_C = AC_ccw_CD != BC_ccw_CD 91 | 92 | # AND 93 | # Exactly one counterclockwise path from A (to C or D) via B. 94 | AB_ccw_BC = _check_counter_clockwise(a, b, c) 95 | AB_ccw_BD = _check_counter_clockwise(a, b, d) 96 | 97 | fromA_via_B = AB_ccw_BC != AB_ccw_BD 98 | 99 | return toD_via_C and fromA_via_B 100 | 101 | 102 | # Test the points. 103 | z1 = Point(0.4, 0.1) 104 | assert z1.is_close_to(z1) 105 | assert z1.is_close_to(Point(0.5, 0.0), 1.0) 106 | assert not z1.is_close_to(Point(5.0, 0.0), 1.0) 107 | z2 = Point(0.1, 0.1) 108 | z3 = z1 - z2 109 | assert isinstance(z3, Point) 110 | assert z3.is_close_to(Point(0.3, 0.0)) 111 | assert isinstance(z3.normal_sample_around(0.1), Point) 112 | 113 | 114 | # Some simple tests to ensure everything is working. 115 | assert not intersect((Point(1, 0), Point(1, 1)), (Point(0,0), Point(0, 1))), \ 116 | 'Parallel lines detected as intersecting.' 117 | assert not intersect((Point(0, 0), Point(1, 0)), (Point(0,1), Point(1, 1))), \ 118 | 'Parallel lines detected as intersecting.' 119 | assert intersect((Point(3, 5), Point(1, 1)), (Point(2, 2), Point(0, 1))), \ 120 | 'Lines that intersect not detected.' 121 | assert not intersect((Point(0, 0), Point(2, 2)), (Point(3, 3), Point(5, 1))), \ 122 | 'Lines that do not intersect detected as intersecting' 123 | assert intersect((Point(0, .5), Point(0, -.5)), (Point(.5, 0), Point(-.5, 0.))), \ 124 | 'Lines that intersect not detected.' 
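# Extra illustrative checks (hypothetical usage): the same primitives can test
# whether a proposed move crosses a wall segment, e.g. the wall from
# (0.25, 0.0) to (0.25, 0.4) used in test_env.py and continuous_main.py.
wall = (Point(0.25, 0.0), Point(0.25, 0.4))
crossing_move = (Point(0.2, 0.2), Point(0.3, 0.2))
assert intersect(wall, crossing_move), 'A move through the wall should intersect it.'
clear_move = (Point(0.2, 0.6), Point(0.3, 0.6))
assert not intersect(wall, clear_move), 'A move beyond the wall end should not intersect it.'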
-------------------------------------------------------------------------------- /grid_modules/plotting.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from grid_modules.gridworld import actions as ACTIONS 3 | from grid_modules.gridworld.helper_utilities import from_id_to_xy 4 | 5 | asbestos = (127 / 255, 140 / 255, 141 / 255, 0.8) 6 | carrot = (235 / 255, 137 / 255, 33 / 255, 0.8) 7 | emerald = (80 / 255, 200 / 255, 120 / 255, 0.8) 8 | red = (255 / 255, 0 / 255, 0 / 255, 0.8) 9 | 10 | marker_style = dict(linestyle=':', color=carrot, markersize=15) 11 | 12 | DEFAULT_ARROW_COLOR = '#1ABC9C' 13 | 14 | 15 | def plot_environment( 16 | mdp, ax, wall_locs=None, goal_ids=None, initial_state=None, door_ids=None, plot_grid=False, 17 | grid_kwargs=None, 18 | wall_color=asbestos # (127 / 255, 140 /255, 141 / 255 , 0.8), # R, G, B, alpha 19 | ): 20 | """Function to plot emdp environment 21 | 22 | Args: 23 | mdp: The MDP to use. 24 | ax: The axes to plot this on. 25 | wall_locs: Locations of the walls for plotting them in a different color. 26 | plot_grid: Boolean indicating if the overlay grid should be plotted. 27 | grid_kwargs: Grid keyword argrument specification. 28 | wall_color: RGB color of the walls. 29 | 30 | Returns: 31 | ax: The axes of the final plot. 32 | imshow_ax: The final plot. 33 | """ 34 | grid_kwargs = grid_kwargs or {} 35 | 36 | # Plot states with white background. 37 | state_background = np.ones((mdp.size, mdp.size)) 38 | 39 | # Walls appear in a different color. 40 | wall_img = np.ones((mdp.size, mdp.size, 4)) 41 | if wall_locs is not None: 42 | for state in wall_locs: 43 | y_coord = state[0] 44 | x_coord = state[1] 45 | wall_img[y_coord, x_coord, :] = np.array(wall_color) 46 | 47 | # Render the heatmap and overlay the walls. 
48 | imshow_ax = ax.imshow(state_background, interpolation=None) 49 | imshow_ax = ax.imshow(wall_img, interpolation=None) 50 | 51 | # add initial state 52 | if initial_state is None: 53 | initial_state = mdp.reset() 54 | y_coord, x_coord = mdp.unflatten_state(initial_state) 55 | ax.plot(x_coord, y_coord, marker='H', **marker_style) 56 | 57 | # add door state 58 | if door_ids is not None: 59 | for door_id in door_ids: 60 | y_coord, x_coord = from_id_to_xy(door_id, size=mdp.size) 61 | ax.plot(x_coord, y_coord, marker='s', color=red, markersize=10) 62 | 63 | # add goal state 64 | if goal_ids is not None: 65 | for goal_id in goal_ids: 66 | y_coord, x_coord = from_id_to_xy(goal_id, size=mdp.size) 67 | ax.plot(x_coord, y_coord, marker='*', **marker_style) 68 | 69 | ax.grid(False) 70 | 71 | # Switch on flag if you want to plot grid 72 | if plot_grid: 73 | for i in range(mdp.size + 1): 74 | ax.plot( 75 | np.arange(mdp.size + 1) - 0.5, 76 | np.ones(mdp.size + 1) * i - 0.5, 77 | **grid_kwargs) 78 | for i in range(mdp.size + 1): 79 | ax.plot( 80 | np.ones(mdp.size + 1) * i - 0.5, 81 | np.arange(mdp.size + 1) - 0.5, 82 | **grid_kwargs) 83 | ax.set_xlabel('x') 84 | ax.set_ylabel('y') 85 | 86 | return ax, imshow_ax 87 | 88 | 89 | def get_current_state_integer(state_): 90 | return np.argmax(state_, axis=0) 91 | 92 | 93 | def _is_absorbing(state_int, mdp_size): 94 | """Checks if the state_int is an absorbing state""" 95 | return state_int == mdp_size * mdp_size 96 | 97 | 98 | def _checking_P(P): 99 | """Checks if the P matrix is valid.""" 100 | assert np.all(P <= 1.0) and np.all(P >= 0.0) 101 | assert not np.allclose(P, 1.0) 102 | assert not np.allclose(P, 0.0) 103 | 104 | 105 | def plot_action(ax, y_pos, x_pos, a, headwidth=5, linewidths=1, scale=1.9, headlength=4): 106 | left_arrow = (-0.6, 0) 107 | right_arrow = (0.6, 0) 108 | up_arrow = (0, -0.6) 109 | down_arrow = (0, 0.6) 110 | if a == ACTIONS.LEFT: # Left 111 | ax.quiver( 112 | x_pos, y_pos, *left_arrow, color=DEFAULT_ARROW_COLOR, alpha=1.0, 113 | angles='xy', scale_units='xy', scale=scale, 114 | headwidth=headwidth, linewidths=linewidths, 115 | headlength=headlength) #L 116 | if a == ACTIONS.RIGHT: #Right 117 | ax.quiver( 118 | x_pos, y_pos, *right_arrow, color=DEFAULT_ARROW_COLOR, alpha=1.0, 119 | angles='xy', scale_units='xy', scale=scale, 120 | headwidth=headwidth, linewidths=linewidths, 121 | headlength=headlength) #R 122 | if a == ACTIONS.UP: #Up 123 | ax.quiver( 124 | x_pos, y_pos, *up_arrow, color=DEFAULT_ARROW_COLOR, alpha=1.0, 125 | angles='xy', scale_units='xy', scale=scale, 126 | headwidth=headwidth, linewidths=linewidths, 127 | headlength=headlength) #U 128 | if a == ACTIONS.DOWN: #Down 129 | ax.quiver( 130 | x_pos, y_pos, *down_arrow, color=DEFAULT_ARROW_COLOR, alpha=1.0, 131 | angles='xy', scale_units='xy', scale=scale, 132 | headwidth=headwidth, linewidths=linewidths, 133 | headlength=headlength) #D -------------------------------------------------------------------------------- /grid_modules/replay_buffer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | 4 | """ 5 | Slightly different the replay buffer here is basically from the openai baselines code 6 | in order to include done for atari games and discrete actions 7 | 8 | """ 9 | 10 | 11 | class ReplayBuffer(object): 12 | """taken from https://github.com/openai/baselines/blob/master/baselines/deepq/replay_buffer.py""" 13 | 14 | def __init__(self, size): 15 | """Create Replay buffer. 
16 | Parameters 17 | ---------- 18 | size: int 19 | Max number of transitions to store in the buffer. When the buffer 20 | overflows the old memories are dropped. 21 | """ 22 | self._storage = [] 23 | self._maxsize = size 24 | self._next_idx = 0 25 | 26 | def __len__(self): 27 | return len(self._storage) 28 | 29 | def add(self, obs, g, action, reward, obs_next, done): 30 | data = (obs, g, action, reward, obs_next, done) 31 | 32 | if self._next_idx >= len(self._storage): 33 | self._storage.append(data) 34 | else: 35 | self._storage[self._next_idx] = data 36 | self._next_idx = (self._next_idx + 1) % self._maxsize 37 | 38 | def _encode_sample(self, idxes): 39 | obses, gs, actions, rewards, obses_next, dones = [], [], [], [], [], [] 40 | for i in idxes: 41 | data = self._storage[i] 42 | obs, g, action, reward, obs_next, done = data 43 | obses.append(np.array(obs, copy=False)) 44 | # gs.append(g.copy()) 45 | gs.append(g) 46 | actions.append(np.array(action, copy=False)) 47 | rewards.append(reward) 48 | obses_next.append(np.array(obs_next, copy=False)) 49 | dones.append(done) 50 | transitions = {'obs': np.array(obses), 51 | 'g': np.array(gs), 52 | 'action': np.array(actions), 53 | 'reward': np.array(rewards), 54 | 'obs_next': np.array(obses_next), 55 | 'done': np.array(dones)} 56 | return transitions 57 | 58 | def sample(self, batch_size): 59 | """Sample a batch of experiences. 60 | Parameters 61 | ---------- 62 | batch_size: int 63 | How many transitions to sample. 64 | Returns 65 | ------- 66 | obs_batch: np.array 67 | batch of observations 68 | act_batch: np.array 69 | batch of actions executed given obs_batch 70 | rew_batch: np.array 71 | rewards received as results of executing act_batch 72 | next_obs_batch: np.array 73 | next set of observations seen after executing act_batch 74 | done_mask: np.array 75 | done_mask[i] = 1 if executing act_batch[i] resulted in 76 | the end of an episode and 0 otherwise. 
77 | """ 78 | idxes = [random.randint(0, len(self._storage) - 1) for _ in range(batch_size)] 79 | return self._encode_sample(idxes) 80 | 81 | 82 | class her_replay_buffer: 83 | def __init__(self, env_params, buffer_size, sample_func): 84 | self.env_params = env_params 85 | self.T = env_params['max_timesteps'] 86 | self.size = buffer_size // self.T 87 | # memory management 88 | self.current_size = 0 89 | self.n_transitions_stored = 0 90 | self.sample_func = sample_func 91 | # create the buffer to store info 92 | self.buffers = {'obs': np.empty([self.size, self.T + 1, 2]), 93 | 'g': np.empty([self.size, self.T, 2]), 94 | 'action': np.empty([self.size, self.T]) 95 | } 96 | 97 | # store the episode 98 | def store_episode(self, episode_batch): 99 | mb_obs, mb_g, mb_actions = episode_batch 100 | batch_size = mb_obs.shape[0] 101 | idxs = self._get_storage_idx(inc=batch_size) 102 | # store the informations 103 | self.buffers['obs'][idxs] = mb_obs 104 | self.buffers['g'][idxs] = mb_g 105 | self.buffers['action'][idxs] = mb_actions 106 | self.n_transitions_stored += self.T * batch_size 107 | 108 | # sample the data from the replay buffer 109 | def sample(self, batch_size): 110 | temp_buffers = {} 111 | for key in self.buffers.keys(): 112 | temp_buffers[key] = self.buffers[key][:self.current_size] 113 | temp_buffers['obs_next'] = temp_buffers['obs'][:, 1:, :] 114 | # sample transitions 115 | transitions = self.sample_func(temp_buffers, batch_size) 116 | return transitions 117 | 118 | def _get_storage_idx(self, inc=None): 119 | inc = inc or 1 120 | if self.current_size + inc <= self.size: 121 | idx = np.arange(self.current_size, self.current_size + inc) 122 | elif self.current_size < self.size: 123 | overflow = inc - (self.size - self.current_size) 124 | idx_a = np.arange(self.current_size, self.size) 125 | idx_b = np.random.randint(0, self.current_size, overflow) 126 | idx = np.concatenate([idx_a, idx_b]) 127 | else: 128 | idx = np.random.randint(0, self.size, inc) 129 | self.current_size = min(self.size, self.current_size + inc) 130 | if inc == 1: 131 | idx = idx[0] 132 | return idx -------------------------------------------------------------------------------- /atari_modules/replay_buffer.py: -------------------------------------------------------------------------------- 1 | import threading 2 | import numpy as np 3 | import random 4 | 5 | """ 6 | Slightly different the replay buffer here is basically from the openai baselines code 7 | in order to include done for atari games and discrete actions 8 | 9 | """ 10 | 11 | 12 | class ReplayBuffer(object): 13 | """taken from https://github.com/openai/baselines/blob/master/baselines/deepq/replay_buffer.py""" 14 | 15 | def __init__(self, size): 16 | """Create Replay buffer. 17 | Parameters 18 | ---------- 19 | size: int 20 | Max number of transitions to store in the buffer. When the buffer 21 | overflows the old memories are dropped. 
22 | """ 23 | self._storage = [] 24 | self._maxsize = size 25 | self._next_idx = 0 26 | 27 | def __len__(self): 28 | return len(self._storage) 29 | 30 | def add(self, obs, ag, g, action, reward, obs_next, done): 31 | data = (obs, ag, g, action, reward, obs_next, done) 32 | 33 | if self._next_idx >= len(self._storage): 34 | self._storage.append(data) 35 | else: 36 | self._storage[self._next_idx] = data 37 | self._next_idx = (self._next_idx + 1) % self._maxsize 38 | 39 | def _encode_sample(self, idxes): 40 | obses, ags, gs, actions, rewards, obses_next, dones = [], [], [], [], [], [], [] 41 | for i in idxes: 42 | data = self._storage[i] 43 | obs, ag, g, action, reward, obs_next, done = data 44 | obses.append(np.array(obs, copy=False)) 45 | ags.append(ag.copy()) 46 | gs.append(g.copy()) 47 | actions.append(np.array(action, copy=False)) 48 | rewards.append(reward) 49 | obses_next.append(np.array(obs_next, copy=False)) 50 | dones.append(done) 51 | transitions = {'obs': np.array(obses), 52 | 'ag': np.array(ags), 53 | 'g': np.array(gs), 54 | 'action': np.array(actions), 55 | 'reward': np.array(rewards), 56 | 'obs_next': np.array(obses_next), 57 | 'done': np.array(dones)} 58 | return transitions 59 | 60 | def sample(self, batch_size): 61 | """Sample a batch of experiences. 62 | Parameters 63 | ---------- 64 | batch_size: int 65 | How many transitions to sample. 66 | Returns 67 | ------- 68 | obs_batch: np.array 69 | batch of observations 70 | act_batch: np.array 71 | batch of actions executed given obs_batch 72 | rew_batch: np.array 73 | rewards received as results of executing act_batch 74 | next_obs_batch: np.array 75 | next set of observations seen after executing act_batch 76 | done_mask: np.array 77 | done_mask[i] = 1 if executing act_batch[i] resulted in 78 | the end of an episode and 0 otherwise. 
79 | """ 80 | idxes = [random.randint(0, len(self._storage) - 1) for _ in range(batch_size)] 81 | return self._encode_sample(idxes) 82 | 83 | 84 | class her_replay_buffer: 85 | def __init__(self, env_params, buffer_size, sample_func): 86 | self.env_params = env_params 87 | self.T = env_params['max_timesteps'] 88 | self.size = buffer_size // self.T 89 | # memory management 90 | self.current_size = 0 91 | self.n_transitions_stored = 0 92 | self.sample_func = sample_func 93 | # create the buffer to store info 94 | self.buffers = {'obs': np.empty([self.size, self.T + 1, *self.env_params['obs']]), 95 | 'ag': np.empty([self.size, self.T + 1, self.env_params['goal']]), 96 | 'g': np.empty([self.size, self.T, self.env_params['goal']]), 97 | 'actions': np.empty([self.size, self.T]), 98 | 'done': np.empty([self.size, self.T]) 99 | } 100 | # thread lock 101 | self.lock = threading.Lock() 102 | 103 | # store the episode 104 | def store_episode(self, episode_batch): 105 | mb_obs, mb_ag, mb_g, mb_actions, dones = episode_batch 106 | batch_size = mb_obs.shape[0] 107 | with self.lock: 108 | idxs = self._get_storage_idx(inc=batch_size) 109 | # store the informations 110 | self.buffers['obs'][idxs] = mb_obs 111 | self.buffers['ag'][idxs] = mb_ag 112 | self.buffers['g'][idxs] = mb_g 113 | self.buffers['actions'][idxs] = mb_actions 114 | self.buffers['done'][idxs] = dones 115 | self.n_transitions_stored += self.T * batch_size 116 | 117 | # sample the data from the replay buffer 118 | def sample(self, batch_size): 119 | temp_buffers = {} 120 | with self.lock: 121 | for key in self.buffers.keys(): 122 | temp_buffers[key] = self.buffers[key][:self.current_size] 123 | temp_buffers['obs_next'] = temp_buffers['obs'][:, 1:, :] 124 | temp_buffers['ag_next'] = temp_buffers['ag'][:, 1:, :] 125 | # sample transitions 126 | transitions = self.sample_func(temp_buffers, batch_size) 127 | return transitions 128 | 129 | def _get_storage_idx(self, inc=None): 130 | inc = inc or 1 131 | if self.current_size + inc <= self.size: 132 | idx = np.arange(self.current_size, self.current_size + inc) 133 | elif self.current_size < self.size: 134 | overflow = inc - (self.size - self.current_size) 135 | idx_a = np.arange(self.current_size, self.size) 136 | idx_b = np.random.randint(0, self.current_size, overflow) 137 | idx = np.concatenate([idx_a, idx_b]) 138 | else: 139 | idx = np.random.randint(0, self.size, inc) 140 | self.current_size = min(self.size, self.current_size + inc) 141 | if inc == 1: 142 | idx = idx[0] 143 | return idx 144 | -------------------------------------------------------------------------------- /grid_modules/gridworld/helper_utilities.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from grid_modules.gridworld.actions import LEFT, RIGHT, UP, DOWN, STAY 3 | from grid_modules.exceptions import InvalidActionError 4 | n_actions = 5 5 | 6 | 7 | def flatten_state(state, size, state_space): 8 | """Flatten state (x,y) into a one hot vector""" 9 | idx = size * state[0] + state[1] 10 | one_hot = np.zeros(state_space) 11 | one_hot[idx] = 1 12 | return one_hot 13 | 14 | 15 | def unflatten_state(onehot, size, has_absorbing_state=False): 16 | """Unflatten a one hot vector into a (x,y) pair""" 17 | if has_absorbing_state: 18 | onehot = onehot[:-1] 19 | onehot = onehot.reshape(size, size) 20 | x = onehot.argmax(0).max() 21 | y = onehot.argmax(1).max() 22 | return (x, y) 23 | 24 | 25 | def from_id_to_xy(state, size): 26 | return state // size, state % size 27 | 28 | 29 
| def from_xy_to_id(state_loc, size): 30 | return state_loc[0] * size + state_loc[1] 31 | 32 | 33 | def get_reachable_id(state_space, size, wall_locs): 34 | return [s for s in range(state_space) if from_id_to_xy(s, size) not in wall_locs] 35 | 36 | 37 | def get_state_after_executing_action(action, state, grid_size): 38 | """ 39 | Gets the state after executing an action 40 | :param action: 41 | :param state: 42 | :param grid_size: 43 | :return: 44 | """ 45 | if check_can_take_action(action, state, grid_size): 46 | if action == LEFT: 47 | return state-1 48 | elif action == RIGHT: 49 | return state+1 50 | elif action == UP: 51 | return state - grid_size 52 | elif action == DOWN: 53 | return state + grid_size 54 | elif action == STAY: 55 | return state 56 | else: 57 | # cant execute action, stay in the same place. 58 | return state 59 | 60 | 61 | def check_can_take_action(action, state, grid_size): 62 | """ 63 | checks if you can take an action in a state. 64 | :param action: 65 | :param state: 66 | :param grid_size: 67 | :return: 68 | """ 69 | LAST_ROW = list(range(grid_size*(grid_size-1), grid_size*grid_size)) 70 | FIRST_ROW = list(range(0, grid_size)) 71 | LEFT_EDGE = list(range(0, grid_size*grid_size, grid_size)) 72 | RIGHT_EDGE = list(range(grid_size-1, grid_size*grid_size, grid_size)) 73 | 74 | if action == DOWN: 75 | if state in LAST_ROW: 76 | return False 77 | elif action == RIGHT: 78 | if state in RIGHT_EDGE: 79 | return False 80 | elif action == UP: 81 | if state in FIRST_ROW: 82 | return False 83 | elif action == LEFT: 84 | if state in LEFT_EDGE: 85 | return False 86 | elif action == STAY: 87 | return True 88 | else: 89 | raise InvalidActionError('Cannot take action {} in a grid world of size {}x{}'.format(action, grid_size, grid_size)) 90 | 91 | return True 92 | 93 | 94 | def get_possible_actions(state, grid_size): 95 | """ 96 | Gets all possible actions at a given state. 97 | :param state: 98 | :param grid_size: 99 | :return: 100 | """ 101 | LAST_ROW = list(range(grid_size*(grid_size-1), grid_size*grid_size)) 102 | FIRST_ROW = list(range(0, grid_size)) 103 | LEFT_EDGE = list(range(0, grid_size*grid_size, grid_size)) 104 | RIGHT_EDGE = list(range(grid_size-1, grid_size*grid_size, grid_size)) 105 | 106 | available_actions = [LEFT, RIGHT, UP, DOWN, STAY] 107 | if state in LAST_ROW: 108 | available_actions.remove(DOWN) 109 | if state in FIRST_ROW: 110 | available_actions.remove(UP) 111 | if state in RIGHT_EDGE: 112 | available_actions.remove(RIGHT) 113 | if state in LEFT_EDGE: 114 | available_actions.remove(LEFT) 115 | return available_actions 116 | 117 | 118 | # def flatten_state(state, n_states, grid_size): 119 | # """Flatten state (x,y) into a one hot vector""" 120 | # idx = 121 | # one_hot = np.zeros(n_states) 122 | # one_hot[idx] = 1 123 | # return one_hot 124 | 125 | def build_simple_grid(size=5, terminal_states=[], p_success=1): 126 | """ 127 | Builds a simple grid where an agent can move LEFT, RIGHT, UP or DOWN 128 | and actions success with probability p_success. 129 | A terminal state is added if len(terminal_states) > 0 and will return matrix of 130 | size (|S|+1)x|A|x(|S|+1) 131 | Moving into walls does nothing. 132 | :param size: size of the grid world 133 | :param terminal_state: the location of terminal states: a list of (x, y) tuples 134 | :param p_success: the probabilty that an action will be successful. 
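    The failure probability (1 - p_success) is split evenly among the states
    reached by the other available actions in that state. The returned
    transition matrix P has shape n_states x 5 x n_states, where
    n_states = size * size (plus one absorbing state when terminal_states
    is non-empty).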
135 | :return: 136 | """ 137 | p_fail = 1 - p_success 138 | 139 | n_states = size*size 140 | grid_states = n_states # the number of entries of the state vector 141 | # corresponding to the grid itself. 142 | if len(terminal_states) > 0: n_states += 1 # add an entry to state vector for terminal state 143 | terminal_states = list(map(lambda tupl: int(size * tupl[0] + tupl[1]), terminal_states)) 144 | 145 | # this helper function creates the state transition list for 146 | # taking an action in a state 147 | def create_state_list_for_action(state_idx, action): 148 | transition_probs = np.zeros(n_states) 149 | if state_idx in terminal_states: 150 | # no matter what action you take you should go to the absorbing state 151 | transition_probs[-1] = 1 152 | elif state_idx == n_states-1 and len(terminal_states) > 0: 153 | # absorbing state, you should just transition back here whatever action you take. 154 | transition_probs[-1] = 1 155 | 156 | elif action in [LEFT, RIGHT, UP, DOWN, STAY]: 157 | # valid action, now see if we can actually execute this action 158 | # in this state: 159 | # TODO: distinguish between capability of slipping and taking wrong action vs failing to execute action. 160 | if check_can_take_action(action, state_idx, size): 161 | # yes we can 162 | possible_actions = get_possible_actions(state_idx, size) 163 | if action in possible_actions: 164 | transition_probs[get_state_after_executing_action(action, state_idx, size)] = p_success 165 | possible_actions.remove(action) 166 | for other_action in possible_actions: 167 | transition_probs[get_state_after_executing_action(other_action, state_idx, size)] = p_fail/len(possible_actions) 168 | 169 | else: 170 | possible_actions = get_possible_actions(state_idx, size) 171 | 172 | for other_action in possible_actions: 173 | transition_probs[get_state_after_executing_action(other_action, state_idx, size)] = p_fail/len(possible_actions) 174 | transition_probs[state_idx] += p_success # cant take action, stay in same place 175 | else: 176 | raise InvalidActionError('Invalid action {} in the 2D gridworld'.format(action)) 177 | return transition_probs 178 | 179 | P = np.zeros((n_states, n_actions, n_states)) 180 | for s in range(n_states): 181 | for a in range(n_actions): 182 | P[s, a, :] = create_state_list_for_action(s, a) 183 | # 184 | # T = {s: {a: create_state_list_for_action(s, a) for a in range(n_actions)} for s in range(n_states)} 185 | # T[0][LEFT][0], T[0][RIGHT][0], T[0][DOWN][0], T[0][UP][0] = 1, 1, 1, 1 186 | # T[15][LEFT][15], T[15][RIGHT][15], T[15][DOWN][15], T[15][UP][15] = 1, 1, 1, 1 187 | return P 188 | 189 | 190 | def add_walls(): 191 | pass -------------------------------------------------------------------------------- /grid_modules/gridworld/builder_tools.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities to help build more complex grid worlds. 
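Provides a TransitionMatrixBuilder for composing a base grid with walls,
plus helpers for creating reward matrices and sampling goal / reachable
states on an existing GridWorldMDP.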
3 | """ 4 | import numpy as np 5 | from grid_modules.gridworld.helper_utilities import (build_simple_grid, 6 | get_possible_actions, 7 | check_can_take_action, 8 | get_state_after_executing_action, 9 | flatten_state, 10 | unflatten_state) 11 | from grid_modules.gridworld.env import GridWorldMDP 12 | 13 | 14 | class TransitionMatrixBuilder(object): 15 | """ 16 | Builder object to build a transition matrix for a grid world 17 | """ 18 | def __init__(self, grid_size, action_space=5, has_terminal_state=False): 19 | self.has_terminal_state = has_terminal_state 20 | self.grid_size = grid_size 21 | self.action_space = action_space 22 | self.state_space = grid_size * grid_size + int(has_terminal_state) 23 | self._P = np.zeros((self.state_space, self.action_space, self.state_space)) 24 | self.grid_added = False 25 | self.P_modified = False 26 | 27 | def add_grid(self, terminal_states=[], p_success=1): 28 | """ 29 | Adds a grid so that you cant walk off the edges of the grid 30 | :param terminal_states: 31 | :param p_success: 32 | :return: 33 | """ 34 | if self.has_terminal_state and len(terminal_states) == 0: 35 | raise ValueError('has_terminal_states is true, but no terminal states supplied.') 36 | 37 | if self.grid_added: 38 | raise ValueError('Grid has already been added') 39 | 40 | if self.P_modified: 41 | raise ValueError('transition matrix has already been modified. ' 42 | 'Adding a grid now can lead to weird behaviour') 43 | 44 | self._P = build_simple_grid(size=self.grid_size, p_success=p_success, terminal_states=terminal_states) 45 | self.grid_added = True 46 | self.P_modified = True 47 | 48 | def add_wall_at(self, tuple_location): 49 | """ 50 | Add a blockade at this position 51 | :param tuple_location: (x,y) location of the wall 52 | :return: 53 | """ 54 | target_state = flatten_state(tuple_location, self.grid_size, self.state_space) 55 | target_state = target_state.argmax() 56 | # find all the ways to go to "target_state" 57 | # from_states contains states that can lead you to target_state by executing from_action 58 | from_states, from_actions = np.where(self._P[:, :, target_state] != 0) 59 | 60 | # get the transition probability distributions that go from s--> t via some action 61 | transition_probs_from = self._P[from_states, from_actions, :] 62 | # TODO: optimize this loop 63 | for i, from_state in enumerate(from_states): # enumerate over states 64 | tmp = transition_probs_from[i, target_state] # get the prob of transitioning 65 | transition_probs_from[i, target_state] = 0 # set it to zero 66 | transition_probs_from[i, from_state] += tmp # add the transition prob to staying in the same place 67 | 68 | self._P[from_states, from_actions, :] = transition_probs_from 69 | 70 | # Get the probability of going to any state for all actions from target_state. 71 | transition_probs_from_wall = self._P[target_state, :, :] 72 | for i, probs_from_action in enumerate(transition_probs_from_wall): 73 | # Reset the probabilities. 74 | transition_probs_from_wall[i, :] = 0.0 75 | # Set the probability of going to the target state to be 1.0 76 | transition_probs_from_wall[i, target_state] = 1.0 77 | # Now set the probs of going to any state from target state as above (i.e only targets). 78 | self._P[target_state, :, :] = transition_probs_from_wall 79 | 80 | # renormalize and update transition matrix. 
81 | normalization = self._P.sum(2) 82 | # normalization[normalization == 0] = 1 83 | normalization = 1/normalization 84 | self._P = (self._P * np.repeat(normalization, self._P.shape[0]).reshape(*self._P.shape)) 85 | 86 | assert np.allclose(self._P.sum(2), 1), 'Normalization did not occur correctly: {}'.format(self._P.sum(2)) 87 | assert np.allclose(self._P[target_state, :, target_state], 1.0), 'All actions from wall should lead to wall!' 88 | self._P_modified = True 89 | 90 | @property 91 | def P(self, nocopy=False): 92 | """ 93 | Returns a new array with the transition matrix built so far. 94 | :param nocopy: 95 | :return: 96 | """ 97 | if nocopy: 98 | return self._P 99 | else: 100 | return self._P.copy() 101 | 102 | def add_wall_between(self, start, end): 103 | """ 104 | Adds a wall between the starting and ending location 105 | :param start: tuple (x,y) representing the starting position of the wall 106 | :param end: tuple (x,y) representing the ending position of the wall 107 | :return: 108 | """ 109 | if not(start[0] == end[0] or start[1] == end[1]): 110 | raise ValueError('Walls can only be drawn in straight lines. ' 111 | 'Therefore, at least one of the x or y between ' 112 | 'the states should match.') 113 | 114 | if start[0] == end[0]: 115 | direction = 1 116 | else: 117 | direction = 0 118 | 119 | constant_idx = start[int(not direction)] 120 | start_idx = start[direction] 121 | end_idx = end[direction] 122 | 123 | if end_idx < start_idx: 124 | # flip start and end directions 125 | # to ensure we can still draw walls 126 | start_idx, end_idx = end_idx, start_idx 127 | 128 | for i in range(start_idx, end_idx+1): 129 | my_location = [None, None] 130 | my_location[direction] = i 131 | my_location[int(not direction)] = constant_idx 132 | print(my_location) 133 | self.add_wall_at(tuple(my_location)) 134 | 135 | 136 | def create_reward_matrix(state_space, size, reward_spec, action_space=5): 137 | """ 138 | Abstraction to create reward matrices. 139 | :param state_space: Size of the state space 140 | :param size: Size of the gird world (width) 141 | :param reward_spec: The reward specification 142 | :param action_space: The size of the action space 143 | :return: 144 | """ 145 | R = np.zeros((state_space, action_space)) 146 | for (reward_location, reward_value) in reward_spec.items(): 147 | reward_location = flatten_state(reward_location, size, state_space).argmax() 148 | R[reward_location, :] = reward_value 149 | 150 | return R 151 | 152 | 153 | def sample_goal(mdp): 154 | # Sample goal 155 | n = 0 156 | max_tries = 1e5 157 | found = False 158 | start_loc = unflatten_state(mdp.p0, mdp.size) 159 | while n < max_tries and not found: 160 | goal_loc = (np.random.randint(mdp.size), np.random.randint(mdp.size)) 161 | if goal_loc not in mdp.wall_locs and goal_loc != start_loc: 162 | # Reachable state found! 163 | found = True 164 | n += 1 165 | if not found: 166 | raise ValueError('Failed to sample a goal state.') 167 | reward_spec = {(goal_loc[0], goal_loc[1]): 1} 168 | R = create_reward_matrix(mdp.state_space, mdp.size, reward_spec, action_space=mdp.action_space) 169 | mdp.R = R 170 | mdp.goal_loc = goal_loc 171 | 172 | 173 | def sample_reachable_states(mdp, nb_goals=5, dist=np.random.randn): 174 | n = 0 175 | max_tries = 1e5 176 | reward_spec = dict() 177 | for _ in range(nb_goals): 178 | found = False 179 | while n < max_tries and not found: 180 | state_loc = (np.random.randint(mdp.size), np.random.randint(mdp.size)) 181 | if state_loc not in mdp.wall_locs: 182 | # Reachable state found! 
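                # The reward attached to this state is drawn from `dist`
                # (a standard normal by default), so R below becomes a
                # random reward matrix over nb_goals reachable states.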
183 | reward_spec[state_loc] = dist() 184 | found = True 185 | n += 1 186 | if not found: 187 | raise ValueError('Failed to sample a reachable state.') 188 | 189 | R = create_reward_matrix(mdp.state_space, mdp.size, reward_spec, action_space=mdp.action_space) 190 | return R 191 | 192 | 193 | 194 | -------------------------------------------------------------------------------- /continuous_world_modules/env.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List, Tuple, Any, Dict, Union, Callable 2 | import numpy as np 3 | import dataclasses 4 | import operator 5 | import random 6 | 7 | from continuous_world_modules.geometry import Point, intersect, on_segment 8 | 9 | # d = {'LEFT': 0, 'RIGHT': 1, 'UP': 2, 'DOWN': 3, 'STAY': 4} 10 | # ACTIONS = SimpleNamespace(**d) 11 | CARDINAL_ACTIONS = [Point(-0.1, 0), Point(0.1, 0.), Point(0., 0.1), Point(0., -0.1), Point(0., 0.)] 12 | ACTIONS_STR = ['LEFT', 'RIGHT', 'UP', 'DOWN', 'STAY'] 13 | 14 | carrot = (235 / 255, 137 / 255, 33 / 255, 0.8) 15 | marker_style = dict(linestyle=':', color=carrot, markersize=15) 16 | 17 | 18 | class ContinuousWorld(object): 19 | r"""The ContinuousWorld Environment. 20 | 21 | An agent can be anywhere in the grid. The agent provides Forces to move. When 22 | the agent provides a force, it is applied and the final position is jittered. 23 | 24 | Walls can be specified in this environment. Detection works by checking if the 25 | agents action forces it to go in a direction which collides with a wall. 26 | """ 27 | 28 | def __init__( 29 | self, 30 | size: float, 31 | wall_pairs: Optional[List[Tuple[Point, Point]]] = None, 32 | movement_noise: float = 0.01, 33 | seed: int = 1, 34 | reset_noise: Optional[float] = None, 35 | verbose_reset: bool = False, 36 | threshold_distance: float = 0.5 37 | ): 38 | """Initializes the Continuous World Environment. 39 | 40 | Args: 41 | size: The size of the world. 42 | wall_pairs: A list of tuple of points representing the start and end 43 | positions of the wall. 44 | movement_noise: The noise around each position after movement. 45 | seed: The seed for the random number generator. 46 | max_episode_length: The maximum length of the episode before resetting. 47 | max_action_force: If using random_step() this will be the maximum random 48 | force applied in the x and y direction. 49 | verbose_reset: Prints out every time the global starting position is 50 | reset. 
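      reset_noise: Optional noise level for resets; falls back to
        movement_noise when not given.
      threshold_distance: Distance to the goal within which a step is
        rewarded with 1.0.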
51 | """ 52 | self.actions = CARDINAL_ACTIONS 53 | self.actions_str = ACTIONS_STR 54 | self._size = size 55 | self._wall_pairs = wall_pairs or [] 56 | self._verbose_reset = verbose_reset 57 | self.threshold_distance = threshold_distance 58 | 59 | self._noise = movement_noise 60 | self._reset_noise = reset_noise or movement_noise 61 | self._rng = np.random.RandomState(seed) 62 | random.seed(seed) 63 | 64 | # self.update_start_position() 65 | # self.set_agent_position(self._start_position) 66 | 67 | def _past_edge(self, x: float) -> Tuple[bool, float]: 68 | """Checks if coordinate is beyond the edges.""" 69 | if x >= self._size: 70 | return True, self._size 71 | elif x <= 0.0: 72 | return True, 0.0 73 | else: 74 | return False, x 75 | 76 | def _wrap_coordinate(self, point: Point) -> Point: 77 | """Wraps coordinates that are beyond edges.""" 78 | wrapped_coordinates = map(self._past_edge, dataclasses.astuple(point)) 79 | return Point(*map(operator.itemgetter(1), wrapped_coordinates)) 80 | 81 | # def set_agent_position(self, new_position: Point): 82 | # self._current_position = self._wrap_coordinate(new_position) 83 | 84 | def set_goal(self, new_position: Point): 85 | self._goal = self._wrap_coordinate(new_position) 86 | 87 | def set_initial_position(self, new_position: Point): 88 | self._current_position = self._wrap_coordinate(new_position) 89 | 90 | # def update_start_position(self): 91 | # self._start_position = Point(*np.random.uniform(0, self._size, 2)) 92 | 93 | def reset(self): 94 | """Reset the current position of the agent and move the global mu.""" 95 | self._current_position = self.sample_point() 96 | self._goal = self.sample_point() 97 | return self.current_position 98 | 99 | def sample_point(self): 100 | on_wall = True 101 | while on_wall: 102 | p = Point(*np.random.uniform(0, self._size, 2)) 103 | on_wall = self._check_on_wall(p) 104 | return p 105 | 106 | @property 107 | def goal(self): 108 | return np.array(dataclasses.astuple(self._goal)) 109 | # def agent_position(self): 110 | # return dataclasses.astuple(self._current_position) 111 | 112 | @property 113 | def current_position(self): 114 | return np.array(dataclasses.astuple(self._current_position)) 115 | 116 | # @property 117 | # def start_position(self): 118 | # return dataclasses.astuple(self._start_position) 119 | 120 | @property 121 | def size(self): 122 | return self._size 123 | 124 | @property 125 | def walls(self): 126 | return self._wall_pairs 127 | 128 | def _check_goes_through_wall(self, start: Point, end: Point): 129 | if not self._wall_pairs: 130 | return False 131 | 132 | for pair in self._wall_pairs: 133 | if intersect((start, end), pair): 134 | return True 135 | return False 136 | 137 | def _check_on_wall(self, p: Point): 138 | if not self._wall_pairs: 139 | return False 140 | for pair in self._wall_pairs: 141 | if on_segment(pair[0], pair[1], p): 142 | return True 143 | 144 | def step(self, id_action) -> Tuple[Tuple[float, float], Optional[float], bool, Dict[str, Any]]: 145 | """Does a step in the environment using the action. 146 | 147 | Args: 148 | action: action's index to be executed. 149 | 150 | Returns: 151 | Agent position: A tuple of two floats. 152 | The reward. 153 | An indicator if the episode terminated. 154 | A dictionary containing any information about the step. 
155 | """ 156 | perturbed_action = self.actions[id_action].normal_sample_around(self._noise) 157 | proposed_position = self._wrap_coordinate(self._current_position + perturbed_action) 158 | goes_through_wall = self._check_goes_through_wall(self._current_position, proposed_position) 159 | 160 | if not goes_through_wall: 161 | self._current_position = proposed_position 162 | done = False 163 | reward = 1.0 if self._current_position.is_close_to(self._goal, diff=self.threshold_distance) else 0.0 164 | return self.current_position, reward, done, {'goes_through_wall': goes_through_wall, 'proposed_position': proposed_position} 165 | 166 | 167 | def visualize_environment( 168 | world, 169 | ax, 170 | scaling=1.0, 171 | agent_color='r', 172 | agent_size=0.2, 173 | start_color='g', 174 | draw_initial_position=True, 175 | draw_goal=True, 176 | write_text=False): 177 | """Visualize the continuous grid world. 178 | 179 | The agent will be drawn as a circle. The start and target 180 | locations will be drawn by a cross. Walls will be drawn in 181 | black. 182 | 183 | Args: 184 | world: The continuous gridworld to visualize. 185 | ax: The matplotlib axes to draw the gridworld. 186 | scaling: Scale the plot by this factor. 187 | agent_color: Color of the agent. 188 | agent_size: Size of the agent in the world. 189 | start_color: Color of the start marker. 190 | draw_agent: Boolean that controls drawing agent. 191 | draw_start_mu: Boolean that controls drawing starting position. 192 | draw_target_mu: Boolean that controls drawing ending position. 193 | draw_walls: Boolean that controls drawing walls. 194 | write_text: Boolean to write text for each component being drawn. 195 | """ 196 | carrot = (235 / 255, 137 / 255, 33 / 255, 0.8) 197 | marker_style = dict(linestyle=':', color=carrot, markersize=15) 198 | 199 | scaled_size = scaling * world.size 200 | 201 | # Draw the outer walls. 202 | ax.hlines(0, 0, scaled_size, color='k') 203 | ax.hlines(scaled_size, 0, scaled_size, color='k') 204 | ax.vlines(scaled_size, 0, scaled_size, color='k') 205 | ax.vlines(0, 0, scaled_size, color='k') 206 | 207 | for wall_pair in world.walls: 208 | ax.plot( 209 | [p.x * scaling for p in wall_pair], 210 | [p.y * scaling for p in wall_pair], 211 | color='k') 212 | 213 | if draw_initial_position: 214 | # Draw the position of the start dist. 215 | x, y = [p * scaling for p in world.current_position] 216 | ax.plot(x, y, marker='o', **marker_style) 217 | if write_text: 218 | ax.text(x, y, 'starting position.') 219 | 220 | if draw_goal: 221 | # Draw the target position. 222 | x, y = [p * scaling for p in world.goal] 223 | ax.plot(x, y, marker='*', **marker_style) 224 | if write_text: 225 | ax.text(x, y, 'target position.') 226 | 227 | # if draw_agent: 228 | # # Draw the position of the agent as a circle. 
229 | # x, y = [scaling * p for p in world.current_position] 230 | # ax.plot(x, y, marker='H', **marker_style) 231 | # # agent_circle = plt.Circle((x, y), agent_size, color=agent_color) 232 | # # ax.add_artist(agent_circle) 233 | # if write_text: 234 | # ax.text(x, y, 'current position.') 235 | 236 | ax.set_xlabel('x') 237 | ax.set_ylabel('y') 238 | ax.axis('off') 239 | ax.grid(False) 240 | return ax -------------------------------------------------------------------------------- /discrete_action_robots_modules/robots.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import numpy as np 3 | import gym 4 | from gym import spaces 5 | 6 | """ Modification from https://github.com/paulorauber/hpg/blob/master/hpg/environments/robotics.py""" 7 | 8 | 9 | def goal_distance(goal_a, goal_b): 10 | assert goal_a.shape == goal_b.shape 11 | return np.linalg.norm(goal_a - goal_b, axis=-1) 12 | 13 | 14 | def generate_itoa_dict( 15 | bucket_values=[-0.33, 0, 0.33], valid_movement_direction=[1, 1, 1, 1]): 16 | """ 17 | Set cartesian product to generate action combination 18 | spaces for the fetch environments 19 | valid_movement_direction: To set 20 | """ 21 | action_space_extended = [bucket_values if m == 1 else [0] 22 | for m in valid_movement_direction] 23 | return list(itertools.product(*action_space_extended)) 24 | 25 | 26 | class FetchReach: 27 | def __init__(self, 28 | action_mode="cart", action_buckets=[-1, 0, 1], 29 | action_stepsize=1.0, 30 | reward_type="sparse"): 31 | """ 32 | Parameters: 33 | action_mode {"cart","cartmixed","cartprod","impulse","impulsemixed"} 34 | action_stepsize: Step size of the action to perform. 35 | Int for cart and impulse 36 | List for cartmixed and impulsemixed 37 | action_buckets: List of buckets used when mode is cartprod 38 | reward_mode = {"sparse","dense"} 39 | 40 | Reward Mode: 41 | `sparse` rewards are like the standard HPG rewards. 42 | `dense` rewards (from the paper/gym) give -(distance to goal) at every timestep. 43 | 44 | Modes: 45 | `cart` is for manhattan style movement where an action moves the arm in one direction 46 | for every action. 47 | 48 | `impulse` treats the action dimensions as velocity and adds/decreases 49 | the velocity by action_stepsize depending on the direction picked. 50 | Adds current direction 51 | velocity to state 52 | 53 | 54 | `impulsemixed` and `cartmixed` does the above with multiple magnitudes of action_stepsize. 55 | Needs the action_stepsize as a list. 56 | 57 | `cartprod` takes combinations of actions as input 58 | """ 59 | 60 | try: 61 | self.env = gym.make("FetchReach-v1") 62 | except Exception as e: 63 | print( 64 | "You do not have the latest version of gym (gym-0.10.5). 
Falling back to v0 with movable table") 65 | self.env = gym.make("FetchReach-v0") 66 | 67 | self.action_directions = np.array( 68 | [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]]) 69 | self.valid_action_directions = np.float32( 70 | np.any(self.action_directions, axis=0)) 71 | 72 | self.distance_threshold = self.env.distance_threshold 73 | # self.distance_threshold = 0.06 74 | self.action_mode = action_mode 75 | self.num_actions = self.generate_action_map(action_buckets, action_stepsize) 76 | 77 | obs_dim = 10 + 4 * (action_mode == "impulse" or action_mode == "impulsemixed") 78 | self.observation_space = spaces.Dict(dict( 79 | desired_goal=spaces.Box(-np.inf, np.inf, shape=(3, ), dtype='float32'), 80 | achieved_goal=spaces.Box(-np.inf, np.inf, shape=(3, ), dtype='float32'), 81 | observation=spaces.Box(-np.inf, np.inf, shape=(obs_dim, ), dtype='float32'), 82 | )) 83 | self._max_episode_steps = self.env._max_episode_steps 84 | self.reward_type = reward_type 85 | 86 | def generate_action_map(self, action_buckets, action_stepsize=1.): 87 | 88 | action_directions = self.action_directions 89 | if self.action_mode == "cart" or self.action_mode == "impulse": 90 | assert isinstance(action_stepsize, float) 91 | self.action_space = np.vstack( 92 | (action_directions * action_stepsize, -action_directions * action_stepsize)) 93 | 94 | elif self.action_mode == "cartmixed" or self.action_mode == "impulsemixed": 95 | assert isinstance(action_stepsize, list) 96 | action_space_list = [] 97 | for ai in action_stepsize: 98 | action_space_list += [action_directions * ai, 99 | -action_directions * ai] 100 | self.action_space = np.vstack(action_space_list) 101 | 102 | elif self.action_mode == "cartprod": 103 | self.action_space = generate_itoa_dict( 104 | action_buckets, self.valid_action_directions) 105 | 106 | return len(self.action_space) 107 | 108 | def seed(self, seed): 109 | self.env.seed(seed) 110 | 111 | def action_map(self, action): 112 | # If the modes are direct, just map the action as an index 113 | # else, accumulate them 114 | 115 | if self.action_mode in ["cartprod", "cart", "cartmixed"]: 116 | return self.action_space[action] 117 | else: 118 | self.action_vel += self.action_space[action] 119 | self.action_vel = np.clip(self.action_vel, -1, 1) 120 | return self.action_vel 121 | 122 | def reset(self): 123 | self.action_vel = np.zeros(4) # Initialize/reset 124 | obs = self.env.reset() 125 | if self.action_mode == "impulse" or self.action_mode == "impulsemixed": 126 | obs["observation"] = np.hstack((obs["observation"], self.action_vel)) 127 | return obs 128 | 129 | def step(self, a): 130 | 131 | action_vec = self.action_map(a) 132 | obs, reward, done, info = self.env.step(action_vec) 133 | if self.action_mode == "impulse" or self.action_mode == "impulsemixed": 134 | obs["observation"] = np.hstack( 135 | (obs["observation"], np.clip(self.action_vel, -1, 1))) 136 | 137 | done = False 138 | info = { 139 | 'is_success': self._is_success(obs['achieved_goal'], self.env.goal), 140 | } 141 | reward = self.compute_reward(obs['achieved_goal'], self.env.goal, info) 142 | return obs, reward, done, info 143 | 144 | def _is_success(self, achieved_goal, desired_goal): 145 | d = goal_distance(achieved_goal, desired_goal) 146 | return (d <= self.distance_threshold).astype(np.float32) 147 | 148 | def compute_reward(self, achieved_goal, goal, info): 149 | # Compute distance between goal and the achieved goal. 
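        # 'sparse' returns 0 when the goal is reached (d <= distance_threshold)
        # and -1 otherwise; 'dense' returns the negative distance itself.
        # Illustrative numbers only (assuming a 0.05 threshold):
        #   d = 0.03 -> sparse 0.0, dense -0.03
        #   d = 0.20 -> sparse -1.0, dense -0.20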
150 | d = goal_distance(achieved_goal, goal) 151 | if self.reward_type == 'sparse': 152 | return -(d > self.distance_threshold).astype(np.float32) 153 | else: 154 | return -d 155 | 156 | def __del__(self): 157 | self.env.close() 158 | 159 | 160 | class FetchPush(FetchReach): 161 | def __init__(self, 162 | action_mode="impulsemixed", action_buckets=[-1, 0, 1], 163 | action_stepsize=[0.1, 1.0], 164 | reward_type="sparse"): 165 | 166 | try: 167 | self.env = gym.make("FetchPush-v1") 168 | except Exception as e: 169 | print( 170 | "You do not have the latest version of gym (gym-0.10.5). Falling back to v0 with movable table") 171 | self.env = gym.make("FetchPush-v0") 172 | 173 | self.action_directions = np.array([[1, 0, 0, 0], [0, 1, 0, 0]]) 174 | self.valid_action_directions = np.float32( 175 | np.any(self.action_directions, axis=0)) 176 | 177 | self.goal = self.env.goal 178 | self.distance_threshold = self.env.distance_threshold 179 | self.action_mode = action_mode 180 | self.num_actions = self.generate_action_map(action_buckets, action_stepsize) 181 | 182 | obs_dim = 25 + 4 * (action_mode == "impulse" or action_mode == "impulsemixed") 183 | self.observation_space = spaces.Dict(dict( 184 | desired_goal=spaces.Box(-np.inf, np.inf, shape=(3, ), dtype='float32'), 185 | achieved_goal=spaces.Box(-np.inf, np.inf, shape=(3, ), dtype='float32'), 186 | observation=spaces.Box(-np.inf, np.inf, shape=(obs_dim, ), dtype='float32'), 187 | )) 188 | self._max_episode_steps = self.env._max_episode_steps 189 | self.reward_type = reward_type 190 | self.is_train = False 191 | 192 | 193 | class FetchSlide(FetchReach): 194 | def __init__(self, 195 | action_mode="cart", action_buckets=[-1, 0, 1], 196 | action_stepsize=1.0, 197 | reward_type="sparse"): 198 | 199 | try: 200 | self.env = gym.make("FetchSlide-v1") 201 | except Exception as e: 202 | print( 203 | "You do not have the latest version of gym (gym-0.10.5). 
Falling back to v0 with movable table") 204 | self.env = gym.make("FetchSlide-v0") 205 | 206 | self.action_directions = np.array([[1, 0, 0, 0], [0, 1, 0, 0]]) 207 | self.valid_action_directions = np.float32( 208 | np.any(self.action_directions, axis=0)) 209 | 210 | self.goal = self.env.goal 211 | self.distance_threshold = self.env.distance_threshold 212 | self.action_mode = action_mode 213 | self.num_actions = self.generate_action_map(action_buckets, action_stepsize) 214 | obs_dim = 25 + 4 * (action_mode == "impulse" or action_mode == "impulsemixed") 215 | self.observation_space = spaces.Dict(dict( 216 | desired_goal=spaces.Box(-np.inf, np.inf, shape=(3, ), dtype='float32'), 217 | achieved_goal=spaces.Box(-np.inf, np.inf, shape=(3, ), dtype='float32'), 218 | observation=spaces.Box(-np.inf, np.inf, shape=(obs_dim, ), dtype='float32'), 219 | )) 220 | self._max_episode_steps = self.env._max_episode_steps 221 | self.reward_type = reward_type 222 | 223 | 224 | if __name__=='__main__': 225 | env = FetchReach() 226 | obs = env.reset() 227 | -------------------------------------------------------------------------------- /continuous_world_modules/dqn_agent.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from datetime import datetime 4 | import numpy as np 5 | import random 6 | import pickle 7 | import csv 8 | from grid_modules.replay_buffer import ReplayBuffer, her_replay_buffer 9 | from grid_modules.her import her_sampler 10 | from discrete_action_robots_modules.models import critic 11 | from continuous_world_modules.featurizer import RadialBasisFunction2D 12 | 13 | """ 14 | DQN agent 15 | 16 | """ 17 | 18 | 19 | def goal_distance(goal_a, goal_b): 20 | assert goal_a.shape == goal_b.shape 21 | return np.max(np.abs(goal_a - goal_b), axis=-1) 22 | 23 | 24 | class DQNAgent: 25 | def __init__(self, args, env, env_params): 26 | self.args = args 27 | self.env = env 28 | self.env_params = env_params 29 | self.featuriser = RadialBasisFunction2D(1, 21, 0.05, cuda=args.cuda) 30 | # create the network 31 | self.critic_network = critic(env_params) 32 | # build up the target network 33 | self.critic_target_network = critic(env_params) 34 | # load the weights into the target networks 35 | self.critic_target_network.load_state_dict(self.critic_network.state_dict()) 36 | # if use gpu 37 | if self.args.cuda: 38 | self.critic_network.cuda() 39 | self.critic_target_network.cuda() 40 | # create the optimizer 41 | self.critic_optim = torch.optim.Adam(self.critic_network.parameters(), lr=self.args.lr) 42 | # her sampler 43 | compute_reward = lambda g_1, g_2: (goal_distance(g_1, g_2) <= env.threshold_distance).astype(np.float32) 44 | self.her_module = her_sampler(self.args.replay_strategy, self.args.replay_k, compute_reward) 45 | # create the replay buffer 46 | self.buffer = her_replay_buffer(self.env_params, self.args.buffer_size, self.her_module.sample_her_transitions) 47 | # create the replay buffer 48 | # self.buffer = ReplayBuffer(self.args.buffer_size) 49 | 50 | if args.save_dir is not None: 51 | if not os.path.exists(self.args.save_dir): 52 | os.mkdir(self.args.save_dir) 53 | 54 | print(' ' * 26 + 'Options') 55 | for k, v in vars(self.args).items(): 56 | print(' ' * 26 + k + ': ' + str(v)) 57 | 58 | with open(self.args.save_dir + "/arguments.pkl", 'wb') as f: 59 | pickle.dump(self.args, f) 60 | 61 | with open('{}/score_monitor.csv'.format(self.args.save_dir), "wt") as monitor_file: 62 | monitor = csv.writer(monitor_file) 63 | 
monitor.writerow(['epoch', 'eval', 'dist']) 64 | 65 | def learn(self): 66 | """ 67 | train the network 68 | 69 | """ 70 | # start to collect samples 71 | for epoch in range(self.args.n_epochs): 72 | for _ in range(self.args.n_cycles): 73 | mb_obs, mb_g, mb_actions = [], [], [] 74 | for _ in range(self.args.num_rollouts_per_cycle): 75 | # reset the rollouts 76 | ep_obs, ep_g, ep_actions = [], [], [] 77 | # reset the environment 78 | obs = self.env.reset() 79 | g = self.env.goal 80 | # start to collect samples 81 | for t in range(self.env_params['max_timesteps']): 82 | with torch.no_grad(): 83 | obs_tensor = self._preproc_o(obs) 84 | g_tensor = self._preproc_g(g) 85 | action = self.act_e_greedy(obs_tensor, g_tensor, update_eps=0.2) 86 | # feed the actions into the environment 87 | obs_new, reward, done, info = self.env.step(action) 88 | # append rollouts 89 | ep_obs.append(obs.copy()) 90 | ep_g.append(g.copy()) 91 | ep_actions.append(action) 92 | obs = obs_new 93 | ep_obs.append(obs.copy()) 94 | mb_obs.append(ep_obs) 95 | mb_g.append(ep_g) 96 | mb_actions.append(ep_actions) 97 | # convert them into arrays 98 | mb_obs = np.array(mb_obs) 99 | mb_g = np.array(mb_g) 100 | mb_actions = np.array(mb_actions) 101 | # store the episodes 102 | self.buffer.store_episode([mb_obs, mb_g, mb_actions]) 103 | 104 | for _ in range(self.args.n_batches): 105 | # train the network 106 | self._update_network() 107 | # soft update 108 | self._soft_update_target_network(self.critic_target_network, self.critic_network) 109 | # start to do the evaluation 110 | success_rate, dist = self._eval_agent() 111 | 112 | print('[{}] epoch is: {}, eval: {:.3f}, dist: {:.3f}'.format(datetime.now(), epoch, success_rate, dist)) 113 | with open('{}/score_monitor.csv'.format(self.args.save_dir), "a") as monitor_file: 114 | monitor = csv.writer(monitor_file) 115 | monitor.writerow([epoch, success_rate, dist]) 116 | torch.save(self.critic_network.state_dict(), 117 | os.path.join(self.args.save_dir, 'model.pt')) 118 | 119 | # pre_process the inputs 120 | def _preproc_o(self, obs): 121 | obs_tensor = torch.tensor(obs, dtype=torch.float32).unsqueeze(0) 122 | if self.args.cuda: 123 | obs_tensor = obs_tensor.cuda() 124 | obs_tensor = self.featuriser.transform(obs_tensor) 125 | return obs_tensor 126 | 127 | def _preproc_g(self, g): 128 | g_tensor = torch.tensor(g, dtype=torch.float32).unsqueeze(0) 129 | if self.args.cuda: 130 | g_tensor = g_tensor.cuda() 131 | g_tensor = self.featuriser.transform(g_tensor) 132 | return g_tensor 133 | 134 | # Acts based on single state (no batch) 135 | def act(self, obs, g, target_network=False): 136 | if target_network: 137 | q = self.critic_target_network(obs, g) 138 | else: 139 | q = self.critic_network(obs, g) 140 | return q.max(1)[1] 141 | 142 | # Acts with an epsilon-greedy policy 143 | def act_e_greedy(self, obs, g, update_eps=0.2): 144 | return random.randrange(self.env_params['action']) if random.random() < update_eps else self.act(obs, g).item() 145 | 146 | # soft update 147 | def _soft_update_target_network(self, target, source): 148 | for target_param, param in zip(target.parameters(), source.parameters()): 149 | target_param.data.copy_((1 - self.args.polyak) * param.data + self.args.polyak * target_param.data) 150 | 151 | def _hard_update_target_network(self, target, source): 152 | for target_param, param in zip(target.parameters(), source.parameters()): 153 | target_param.data.copy_(param.data) 154 | 155 | # update the network 156 | def _update_network(self): 157 | # sample the episodes 158 | 
transitions = self.buffer.sample(self.args.batch_size) 159 | 160 | # transfer them into the tensor 161 | obs_tensor = torch.tensor(transitions['obs'], dtype=torch.float32) 162 | g_tensor = torch.tensor(transitions['g'], dtype=torch.float32) 163 | obs_next_tensor = torch.tensor(transitions['obs_next'], dtype=torch.float32) 164 | actions_tensor = torch.tensor(transitions['action'], dtype=torch.long) 165 | r_tensor = torch.tensor(transitions['reward'], dtype=torch.float32) 166 | if self.args.cuda: 167 | obs_tensor = obs_tensor.cuda() 168 | g_tensor = g_tensor.cuda() 169 | obs_next_tensor = obs_next_tensor.cuda() 170 | actions_tensor = actions_tensor.cuda() 171 | r_tensor = r_tensor.cuda() 172 | 173 | obs_tensor = self.featuriser.transform(obs_tensor) 174 | g_tensor = self.featuriser.transform(g_tensor) 175 | obs_next_tensor = self.featuriser.transform(obs_next_tensor) 176 | # calculate the target Q value function 177 | with torch.no_grad(): 178 | q_next_value = self.critic_target_network(obs_next_tensor, g_tensor).max(1)[0].reshape(-1, 1) 179 | q_next_value = q_next_value.detach() 180 | target_q_value = r_tensor + self.args.gamma * q_next_value 181 | target_q_value = target_q_value.detach() 182 | # clip the q value 183 | clip_return = 1 / (1 - self.args.gamma) 184 | target_q_value = torch.clamp(target_q_value, 0, clip_return) 185 | # the q loss 186 | real_q_value = self.critic_network(obs_tensor, g_tensor).gather(1, actions_tensor.reshape(-1, 1)) 187 | critic_loss = (target_q_value - real_q_value).pow(2).mean() 188 | # update the critic_network 189 | self.critic_optim.zero_grad() 190 | critic_loss.backward() 191 | self.critic_optim.step() 192 | 193 | # do the evaluation 194 | def _eval_agent(self): 195 | total_success_rate = [] 196 | total_dist = [] 197 | for _ in range(self.args.n_test_rollouts): 198 | obs = self.env.reset() 199 | g = self.env.goal 200 | # self.env.set_initial_position(Point(0.2, 0.1)) 201 | # self.env.set_goal(Point(0.9, 0.9)) 202 | # obs = self.env.current_position 203 | # g = self.env.goal 204 | for _ in range(self.env_params['max_timesteps']): 205 | with torch.no_grad(): 206 | obs_norm_tensor = self._preproc_o(obs) 207 | g_norm_tensor = self._preproc_g(g) 208 | action = self.act_e_greedy(obs_norm_tensor, g_norm_tensor, update_eps=0.02) 209 | obs, reward, _, info = self.env.step(action) 210 | if reward > 0: 211 | break 212 | 213 | total_success_rate.append(reward) 214 | dist = goal_distance(obs, g) 215 | total_dist.append(dist) 216 | 217 | total_success_rate = np.array(total_success_rate) 218 | total_success_rate = np.mean(total_success_rate) 219 | 220 | total_dist = np.array(total_dist) 221 | total_dist = np.mean(total_dist) 222 | 223 | return total_success_rate, total_dist -------------------------------------------------------------------------------- /grid_modules/dqn_agent.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from datetime import datetime 4 | import numpy as np 5 | import random 6 | import pickle 7 | import csv 8 | from grid_modules.replay_buffer import ReplayBuffer, her_replay_buffer 9 | from grid_modules.her import her_sampler 10 | from grid_modules.mdp_utils import extract_policy, value_iteration, compute_successor_reps 11 | from discrete_action_robots_modules.models import critic 12 | 13 | 14 | class DQNAgent: 15 | def __init__(self, args, env, env_params): 16 | self.args = args 17 | self.env = env 18 | self.env_params = env_params 19 | # create the network 20 | self.critic_network 
= critic(env_params) 21 | # build up the target network 22 | self.critic_target_network = critic(env_params) 23 | # load the weights into the target networks 24 | self.critic_target_network.load_state_dict(self.critic_network.state_dict()) 25 | # if use gpu 26 | if self.args.cuda: 27 | self.critic_network.cuda() 28 | self.critic_target_network.cuda() 29 | # create the optimizer 30 | self.critic_optim = torch.optim.Adam(self.critic_network.parameters(), lr=self.args.lr) 31 | # her sampler 32 | compute_reward = lambda g_1, g_2: (g_1.argmax(-1) == g_2.argmax(-1)).astype(np.float32) 33 | self.her_module = her_sampler(self.args.replay_strategy, self.args.replay_k, compute_reward) 34 | # create the replay buffer 35 | self.buffer = her_replay_buffer(self.env_params, self.args.buffer_size, self.her_module.sample_her_transitions) 36 | # create the replay buffer 37 | # self.buffer = ReplayBuffer(self.args.buffer_size) 38 | 39 | if args.save_dir is not None: 40 | if not os.path.exists(self.args.save_dir): 41 | os.mkdir(self.args.save_dir) 42 | 43 | print(' ' * 26 + 'Options') 44 | for k, v in vars(self.args).items(): 45 | print(' ' * 26 + k + ': ' + str(v)) 46 | 47 | with open(self.args.save_dir + "/arguments.pkl", 'wb') as f: 48 | pickle.dump(self.args, f) 49 | 50 | with open('{}/score_monitor.csv'.format(self.args.save_dir), "wt") as monitor_file: 51 | monitor = csv.writer(monitor_file) 52 | monitor.writerow(['epoch', 'eval']) 53 | 54 | def learn(self): 55 | """ 56 | train the network 57 | 58 | """ 59 | best_perf = 0 60 | # start to collect samples 61 | for epoch in range(self.args.n_epochs): 62 | for _ in range(self.args.n_cycles): 63 | mb_obs, mb_g, mb_actions = [], [], [] 64 | for _ in range(self.args.num_rollouts_per_cycle): 65 | # reset the rollouts 66 | ep_obs, ep_g, ep_actions = [], [], [] 67 | # reset the environment 68 | obs = self.env.reset() 69 | g = self.env.goal 70 | # start to collect samples 71 | for t in range(self.env_params['max_timesteps']): 72 | with torch.no_grad(): 73 | obs_tensor = self._preproc_o(obs) 74 | g_tensor = self._preproc_g(g) 75 | action = self.act_e_greedy(obs_tensor, g_tensor, update_eps=0.2) 76 | # feed the actions into the environment 77 | obs_new, reward, done, info = self.env.step(action) 78 | # append rollouts 79 | ep_obs.append(obs.copy()) 80 | ep_g.append(g.copy()) 81 | ep_actions.append(action) 82 | obs = obs_new 83 | ep_obs.append(obs.copy()) 84 | mb_obs.append(ep_obs) 85 | mb_g.append(ep_g) 86 | mb_actions.append(ep_actions) 87 | # convert them into arrays 88 | mb_obs = np.array(mb_obs) 89 | mb_g = np.array(mb_g) 90 | mb_actions = np.array(mb_actions) 91 | # store the episodes 92 | self.buffer.store_episode([mb_obs, mb_g, mb_actions]) 93 | 94 | for _ in range(self.args.n_batches): 95 | # train the network 96 | self._update_network() 97 | # soft update 98 | self._soft_update_target_network(self.critic_target_network, self.critic_network) 99 | # start to do the evaluation 100 | perf = self._eval_agent() 101 | 102 | print('[{}] epoch is: {}, eval: {:.3f}'.format(datetime.now(), epoch, perf)) 103 | with open('{}/score_monitor.csv'.format(self.args.save_dir), "a") as monitor_file: 104 | monitor = csv.writer(monitor_file) 105 | monitor.writerow([epoch, perf]) 106 | torch.save(self.critic_network.state_dict(), 107 | os.path.join(self.args.save_dir, 'model.pt')) 108 | if perf > best_perf: 109 | torch.save(self.critic_network.state_dict(), 110 | os.path.join(self.args.save_dir, 'best_model.pt')) 111 | 112 | # pre_process the inputs 113 | def _preproc_o(self, 
obs): 114 | obs_tensor = torch.tensor(obs, dtype=torch.float32).unsqueeze(0) 115 | if self.args.cuda: 116 | obs_tensor = obs_tensor.cuda() 117 | return obs_tensor 118 | 119 | def _preproc_g(self, g): 120 | g_tensor = torch.tensor(g, dtype=torch.float32).unsqueeze(0) 121 | if self.args.cuda: 122 | g_tensor = g_tensor.cuda() 123 | return g_tensor 124 | 125 | def get_policy(self, g, obs=None, policy_type='boltzmann', temp=1, eps=0.01, target_network=False): 126 | if obs is None: 127 | obs = torch.eye(self.env.state_space) # S x S 128 | g = g.repeat(self.env.state_space, 1) 129 | if self.args.cuda: 130 | obs = obs.cuda() # S x S 131 | if target_network: 132 | q = self.critic_target_network(obs, g) 133 | else: 134 | q = self.critic_network(obs, g) 135 | return extract_policy(q, policy_type=policy_type, temp=temp, eps=eps) 136 | 137 | # Acts based on single state (no batch) 138 | def act(self, obs, g, target_network=False): 139 | if target_network: 140 | q = self.critic_target_network(obs, g) 141 | else: 142 | q = self.critic_network(obs, g) 143 | return q.max(1)[1] 144 | 145 | # Acts with an epsilon-greedy policy 146 | def act_e_greedy(self, obs, g, update_eps=0.2): 147 | return random.randrange(self.env_params['action']) if random.random() < update_eps else self.act(obs, g).item() 148 | 149 | # soft update 150 | def _soft_update_target_network(self, target, source): 151 | for target_param, param in zip(target.parameters(), source.parameters()): 152 | target_param.data.copy_((1 - self.args.polyak) * param.data + self.args.polyak * target_param.data) 153 | 154 | def _hard_update_target_network(self, target, source): 155 | for target_param, param in zip(target.parameters(), source.parameters()): 156 | target_param.data.copy_(param.data) 157 | 158 | # update the network 159 | def _update_network(self): 160 | # sample the episodes 161 | transitions = self.buffer.sample(self.args.batch_size) 162 | 163 | # transfer them into the tensor 164 | obs_tensor = torch.tensor(transitions['obs'], dtype=torch.float32) 165 | g_tensor = torch.tensor(transitions['g'], dtype=torch.float32) 166 | obs_next_tensor = torch.tensor(transitions['obs_next'], dtype=torch.float32) 167 | actions_tensor = torch.tensor(transitions['action'], dtype=torch.long) 168 | r_tensor = torch.tensor(transitions['reward'], dtype=torch.float32) 169 | if self.args.cuda: 170 | obs_tensor = obs_tensor.cuda() 171 | g_tensor = g_tensor.cuda() 172 | obs_next_tensor = obs_next_tensor.cuda() 173 | actions_tensor = actions_tensor.cuda() 174 | r_tensor = r_tensor.cuda() 175 | 176 | # calculate the target Q value function 177 | with torch.no_grad(): 178 | q_next_value = self.critic_target_network(obs_next_tensor, g_tensor).max(1)[0].reshape(-1, 1) 179 | q_next_value = q_next_value.detach() 180 | target_q_value = r_tensor + self.args.gamma * q_next_value 181 | target_q_value = target_q_value.detach() 182 | # clip the q value 183 | clip_return = 1 / (1 - self.args.gamma) 184 | target_q_value = torch.clamp(target_q_value, 0, clip_return) 185 | # the q loss 186 | real_q_value = self.critic_network(obs_tensor, g_tensor).gather(1, actions_tensor.reshape(-1, 1)) 187 | critic_loss = (target_q_value - real_q_value).pow(2).mean() 188 | # update the critic_network 189 | self.critic_optim.zero_grad() 190 | critic_loss.backward() 191 | self.critic_optim.step() 192 | 193 | # do the evaluation 194 | def _eval_agent(self): 195 | total_perf = [] 196 | for _ in range(self.args.n_test_rollouts): 197 | init_obs = self.env.reset() 198 | g = self.env.goal 199 | R = 
torch.tensor(self.env.R, dtype=torch.float32) 200 | P = torch.tensor(self.env.P, dtype=torch.float32) 201 | if self.args.cuda: 202 | R = R.cuda() 203 | P = P.cuda() 204 | opt_q = value_iteration(R, P, self.args.gamma, atol=1e-8, max_iteration=5000) 205 | opt_perf = opt_q[self.env.reachable_states].max(1)[0].mean() 206 | 207 | g_tensor = self._preproc_g(g) 208 | pi = self.get_policy(g_tensor, policy_type='boltzmann', temp=0.1) 209 | sr_pi = compute_successor_reps(P, pi, self.args.gamma) 210 | q_pi = torch.matmul(sr_pi, R.t().reshape(self.env.state_space * self.env.action_space)) 211 | q_pi = q_pi.reshape(self.env.action_space, self.env.state_space).t() 212 | 213 | # score = torch.dot(q_pi[init_obs.argmax()], pi[init_obs.argmax()]) 214 | score = torch.einsum('sa, sa -> s', q_pi, pi)[self.env.reachable_states].mean() 215 | score /= opt_perf 216 | total_perf.append(score.item()) 217 | 218 | total_perf = np.array(total_perf) 219 | return np.mean(total_perf) -------------------------------------------------------------------------------- /atari_modules/dqn_agent.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from datetime import datetime 4 | import numpy as np 5 | import random 6 | from atari_modules.replay_buffer import ReplayBuffer 7 | from atari_modules.models import critic 8 | import pickle 9 | import csv 10 | from atari_modules.wrappers import goal_distance 11 | 12 | 13 | class dqn_agent: 14 | def __init__(self, args, env, env_params): 15 | self.args = args 16 | self.env = env 17 | self.env_params = env_params 18 | # create the network 19 | self.critic_network = critic(env_params) 20 | # build up the target network 21 | self.critic_target_network = critic(env_params) 22 | # load the weights into the target networks 23 | self.critic_target_network.load_state_dict(self.critic_network.state_dict()) 24 | # if use gpu 25 | if self.args.cuda: 26 | self.critic_network.cuda() 27 | self.critic_target_network.cuda() 28 | # create the optimizer 29 | self.critic_optim = torch.optim.Adam(self.critic_network.parameters(), lr=self.args.lr) 30 | # her sampler 31 | # self.her_module = her_sampler(self.args.replay_strategy, self.args.replay_k, self.env.compute_reward) 32 | # create the replay buffer 33 | # self.buffer = her_replay_buffer(self.env_params, self.args.buffer_size, self.her_module.sample_her_transitions) 34 | self.buffer = ReplayBuffer(self.args.buffer_size) 35 | # create the dict for store the model 36 | if self.args.save_dir is not None: 37 | if not os.path.exists(self.args.save_dir): 38 | os.mkdir(self.args.save_dir) 39 | 40 | print(' ' * 26 + 'Options') 41 | for k, v in vars(self.args).items(): 42 | print(' ' * 26 + k + ': ' + str(v)) 43 | 44 | with open(self.args.save_dir + "/arguments.pkl", 'wb') as f: 45 | pickle.dump(self.args, f) 46 | 47 | with open('{}/score_monitor.csv'.format(self.args.save_dir), "wt") as monitor_file: 48 | monitor = csv.writer(monitor_file) 49 | monitor.writerow(['epoch', 'eval', 'avg dist']) 50 | 51 | def learn(self): 52 | """ 53 | train the network 54 | 55 | """ 56 | # start to collect samples 57 | for epoch in range(self.args.n_epochs): 58 | for _ in range(self.args.n_cycles): 59 | # mb_obs, mb_ag, mb_g, mb_actions, mb_dones = [], [], [], [], [] 60 | for _ in range(self.args.num_rollouts_per_cycle): 61 | # reset the rollouts 62 | # ep_obs, ep_ag, ep_g, ep_actions, ep_dones = [], [], [], [], [] 63 | # reset the environment 64 | observation = self.env.reset() 65 | obs = 
observation['observation'] 66 | ag = observation['achieved_goal'] 67 | g = observation['desired_goal'] 68 | # start to collect samples 69 | for t in range(self.env_params['max_timesteps']): 70 | with torch.no_grad(): 71 | obs_tensor = self._preproc_o(obs) 72 | g_tensor = self._preproc_g(g) 73 | action = self.act_e_greedy(obs_tensor, g_tensor, update_eps=0.2) 74 | # feed the actions into the environment 75 | observation_new, reward, done, info = self.env.step(action) 76 | obs_new = observation_new['observation'] 77 | ag_new = observation_new['achieved_goal'] 78 | # add transition to replay buffer 79 | self.buffer.add(obs, ag, g, action, reward, obs_new, done) 80 | # append rollouts 81 | # ep_obs.append(np.array(obs, dtype=np.uint8)) 82 | # ep_ag.append(ag.copy()) 83 | # ep_g.append(g.copy()) 84 | # ep_actions.append(action) 85 | # ep_dones.append(float(done)) 86 | # re-assign the observation 87 | if done: 88 | observation = self.env.reset() 89 | obs = observation['observation'] 90 | ag = observation['achieved_goal'] 91 | g = observation['desired_goal'] 92 | else: 93 | obs = obs_new 94 | ag = ag_new 95 | # ep_obs.append(np.array(obs, dtype=np.uint8)) 96 | # ep_ag.append(ag.copy()) 97 | # mb_obs.append(ep_obs) 98 | # mb_ag.append(ep_ag) 99 | # mb_g.append(ep_g) 100 | # mb_actions.append(ep_actions) 101 | # mb_dones.append(ep_dones) 102 | # convert them into arrays 103 | # mb_obs = np.array(mb_obs, dtype=np.uint8) 104 | # mb_ag = np.array(mb_ag) 105 | # mb_g = np.array(mb_g) 106 | # mb_actions = np.array(mb_actions) 107 | # mb_dones = np.array(mb_dones) 108 | # store the episodes 109 | # self.buffer.store_episode([mb_obs, mb_ag, mb_g, mb_actions, mb_dones]) 110 | for _ in range(self.args.n_batches): 111 | # train the network 112 | self._update_network() 113 | # soft update 114 | self._soft_update_target_network(self.critic_target_network, self.critic_network) 115 | # start to do the evaluation 116 | success_rate, avg_dist = self._eval_agent() 117 | print('[{}] epoch is: {}, eval success rate is: {:.3f}, avg dist: {:.3f}'.format(datetime.now(), epoch, success_rate, avg_dist)) 118 | with open('{}/score_monitor.csv'.format(self.args.save_dir), "a") as monitor_file: 119 | monitor = csv.writer(monitor_file) 120 | monitor.writerow([epoch, success_rate, avg_dist]) 121 | torch.save([self.critic_network.state_dict()], 122 | os.path.join(self.args.save_dir, 'model.pt')) 123 | # print('n_transitions_stored: {}'.format(self.buffer.n_transitions_stored)) 124 | # print('current replay size: {}, percentage: {}'.format(self.buffer.current_size, self.buffer.current_size / self.buffer.size * 100)) 125 | # torch.save([self.o_norm.mean, self.o_norm.std, self.g_norm.mean, self.g_norm.std, 126 | # self.actor_network.state_dict()], \ 127 | # self.model_path + '/model.pt') 128 | 129 | # pre_process the inputs 130 | def _preproc_o(self, obs): 131 | obs = np.transpose(np.array(obs)[None] / 255., [0, 3, 1, 2]) 132 | obs_tensor = torch.tensor(obs, dtype=torch.float32) 133 | if self.args.cuda: 134 | obs_tensor = obs_tensor.cuda() 135 | return obs_tensor 136 | 137 | def _preproc_g(self, g): 138 | g_tensor = torch.tensor(g[None] / 170, dtype=torch.float32) 139 | if self.args.cuda: 140 | g_tensor = g_tensor.cuda() 141 | return g_tensor 142 | 143 | # Acts based on single state (no batch) 144 | def act(self, obs, g): 145 | return self.critic_network(obs, g).data.max(1)[1] 146 | 147 | # Acts with an epsilon-greedy policy 148 | def act_e_greedy(self, obs, g, update_eps=0.2): 149 | return 
random.randrange(self.env_params['action']) if random.random() < update_eps else self.act(obs, g).item() 150 | 151 | # soft update 152 | def _soft_update_target_network(self, target, source): 153 | for target_param, param in zip(target.parameters(), source.parameters()): 154 | target_param.data.copy_((1 - self.args.polyak) * param.data + self.args.polyak * target_param.data) 155 | 156 | # update the network 157 | def _update_network(self): 158 | # sample the episodes 159 | transitions = self.buffer.sample(self.args.batch_size) 160 | # pre-process the observation and goal 161 | obs_tensor = torch.tensor(np.transpose(transitions['obs'], [0, 3, 1, 2]) / 255, dtype=torch.float32) 162 | obs_next_tensor = torch.tensor(np.transpose(transitions['obs_next'], [0, 3, 1, 2]) / 255, dtype=torch.float32) 163 | g_tensor = torch.tensor(transitions['g'] / 170, dtype=torch.float32) 164 | dones_tensor = torch.tensor(transitions['done'], dtype=torch.float32).reshape(-1, 1) 165 | actions_tensor = torch.tensor(transitions['action'], dtype=torch.long) 166 | r_tensor = torch.tensor(transitions['reward'], dtype=torch.float32) 167 | if self.args.cuda: 168 | obs_tensor = obs_tensor.cuda() 169 | obs_next_tensor = obs_next_tensor.cuda() 170 | g_tensor = g_tensor.cuda() 171 | dones_tensor = dones_tensor.cuda() 172 | actions_tensor = actions_tensor.cuda() 173 | r_tensor = r_tensor.cuda() 174 | # calculate the target Q value function 175 | with torch.no_grad(): 176 | q_next_value = self.critic_target_network(obs_next_tensor, g_tensor).max(1)[0].reshape(-1, 1) 177 | q_next_value = q_next_value.detach() 178 | target_q_value = r_tensor + (1 - dones_tensor) * self.args.gamma * q_next_value 179 | target_q_value = target_q_value.detach() 180 | # clip the q value 181 | clip_return = 1 / (1 - self.args.gamma) 182 | target_q_value = torch.clamp(target_q_value, 0, clip_return) 183 | # the q loss 184 | real_q_value = self.critic_network(obs_tensor, g_tensor).gather(1, actions_tensor.reshape(-1, 1)) 185 | critic_loss = (target_q_value - real_q_value).pow(2).mean() 186 | # update the critic_network 187 | self.critic_optim.zero_grad() 188 | critic_loss.backward() 189 | self.critic_optim.step() 190 | 191 | # do the evaluation 192 | def _eval_agent(self): 193 | total_success_rate = [] 194 | total_dist = [] 195 | for _ in range(self.args.n_test_rollouts): 196 | observation = self.env.reset() 197 | obs = observation['observation'] 198 | g = observation['desired_goal'] 199 | for _ in range(self.env_params['max_timesteps']): 200 | with torch.no_grad(): 201 | # import pdb 202 | # pdb.set_trace() 203 | obs_tensor = self._preproc_o(obs) 204 | g_tensor = self._preproc_g(g) 205 | action = self.act_e_greedy(obs_tensor, g_tensor, update_eps=0.01) 206 | observation_new, _, done, info = self.env.step(action) 207 | obs = observation_new['observation'] 208 | g = observation_new['desired_goal'] 209 | dist = goal_distance(observation_new['achieved_goal'], observation_new['desired_goal']) 210 | if info['is_success'] > 0 or done: 211 | break 212 | 213 | total_success_rate.append(info['is_success']) 214 | total_dist.append(dist) 215 | total_success_rate = np.array(total_success_rate) 216 | success_rate = np.mean(total_success_rate) 217 | total_dist = np.array(total_dist) 218 | dist = np.mean(total_dist) 219 | return success_rate, dist 220 | -------------------------------------------------------------------------------- /atari_modules/her_dqn_agent.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 
import os 3 | from datetime import datetime 4 | import numpy as np 5 | import random 6 | from atari_modules.replay_buffer import her_replay_buffer 7 | from atari_modules.models import critic 8 | from atari_modules.her import her_sampler 9 | import pickle 10 | import csv 11 | 12 | from atari_modules.wrappers import goal_distance 13 | 14 | 15 | class HerDQNAgent: 16 | def __init__(self, args, env, env_params): 17 | self.args = args 18 | self.env = env 19 | self.env_params = env_params 20 | # create the network 21 | self.critic_network = critic(env_params) 22 | # build up the target network 23 | self.critic_target_network = critic(env_params) 24 | # load the weights into the target networks 25 | self.critic_target_network.load_state_dict(self.critic_network.state_dict()) 26 | # if use gpu 27 | if self.args.cuda: 28 | self.critic_network.cuda() 29 | self.critic_target_network.cuda() 30 | # create the optimizer 31 | self.critic_optim = torch.optim.Adam(self.critic_network.parameters(), lr=self.args.lr) 32 | # her sampler 33 | self.her_module = her_sampler('future', self.args.replay_k, self.env.compute_reward) 34 | # create the replay buffer 35 | self.buffer = her_replay_buffer(self.env_params, self.args.buffer_size, self.her_module.sample_her_transitions) 36 | # self.buffer = ReplayBuffer(self.args.buffer_size) 37 | # create the dict for store the model 38 | if self.args.save_dir is not None: 39 | if not os.path.exists(self.args.save_dir): 40 | os.mkdir(self.args.save_dir) 41 | 42 | print(' ' * 26 + 'Options') 43 | for k, v in vars(self.args).items(): 44 | print(' ' * 26 + k + ': ' + str(v)) 45 | 46 | with open(self.args.save_dir + "/arguments.pkl", 'wb') as f: 47 | pickle.dump(self.args, f) 48 | 49 | with open('{}/score_monitor.csv'.format(self.args.save_dir), "wt") as monitor_file: 50 | monitor = csv.writer(monitor_file) 51 | monitor.writerow(['epoch', 'eval', 'avg dist']) 52 | 53 | def learn(self): 54 | """ 55 | train the network 56 | 57 | """ 58 | # start to collect samples 59 | for epoch in range(self.args.n_epochs): 60 | for _ in range(self.args.n_cycles): 61 | mb_obs, mb_ag, mb_g, mb_actions, mb_dones = [], [], [], [], [] 62 | for _ in range(self.args.num_rollouts_per_cycle): 63 | # reset the rollouts 64 | ep_obs, ep_ag, ep_g, ep_actions, ep_dones = [], [], [], [], [] 65 | # reset the environment 66 | observation = self.env.reset() 67 | obs = observation['observation'] 68 | ag = observation['achieved_goal'] 69 | g = observation['desired_goal'] 70 | # start to collect samples 71 | for t in range(self.env_params['max_timesteps']): 72 | with torch.no_grad(): 73 | obs_tensor = self._preproc_o(obs) 74 | g_tensor = self._preproc_g(g) 75 | action = self.act_e_greedy(obs_tensor, g_tensor, update_eps=0.2) 76 | # feed the actions into the environment 77 | observation_new, reward, done, info = self.env.step(action) 78 | obs_new = observation_new['observation'] 79 | ag_new = observation_new['achieved_goal'] 80 | # add transition to replay buffer 81 | # self.buffer.add(obs, ag, g, action, reward, obs_new, done) 82 | # append rollouts 83 | ep_obs.append(np.array(obs, dtype=np.uint8)) 84 | ep_ag.append(ag.copy()) 85 | ep_g.append(g.copy()) 86 | ep_actions.append(action) 87 | ep_dones.append(float(done)) 88 | # re-assign the observation 89 | if done: 90 | observation = self.env.reset() 91 | obs = observation['observation'] 92 | ag = observation['achieved_goal'] 93 | g = observation['desired_goal'] 94 | else: 95 | obs = obs_new 96 | ag = ag_new 97 | ep_obs.append(np.array(obs, dtype=np.uint8)) 98 | 
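# Note: appending the final obs/ag after the time loop leaves each episode with T+1
# observations and achieved goals versus T goals/actions, which is the layout the
# 'future'-strategy her_sampler above presumably relies on when it relabels desired
# goals with achieved goals drawn from later in the same episode.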
ep_ag.append(ag.copy()) 99 | mb_obs.append(ep_obs) 100 | mb_ag.append(ep_ag) 101 | mb_g.append(ep_g) 102 | mb_actions.append(ep_actions) 103 | mb_dones.append(ep_dones) 104 | # convert them into arrays 105 | mb_obs = np.array(mb_obs, dtype=np.uint8) 106 | mb_ag = np.array(mb_ag) 107 | mb_g = np.array(mb_g) 108 | mb_actions = np.array(mb_actions) 109 | mb_dones = np.array(mb_dones) 110 | # store the episodes 111 | self.buffer.store_episode([mb_obs, mb_ag, mb_g, mb_actions, mb_dones]) 112 | for _ in range(self.args.n_batches): 113 | # train the network 114 | self._update_network() 115 | # soft update 116 | self._soft_update_target_network(self.critic_target_network, self.critic_network) 117 | # start to do the evaluation 118 | success_rate, avg_dist = self._eval_agent() 119 | print('[{}] epoch is: {}, eval success rate is: {:.3f}, avg dist: {:.3f}'.format(datetime.now(), epoch, success_rate, avg_dist)) 120 | with open('{}/score_monitor.csv'.format(self.args.save_dir), "a") as monitor_file: 121 | monitor = csv.writer(monitor_file) 122 | monitor.writerow([epoch, success_rate, avg_dist]) 123 | torch.save([self.critic_network.state_dict()], 124 | os.path.join(self.args.save_dir, 'model.pt')) 125 | # print('n_transitions_stored: {}'.format(self.buffer.n_transitions_stored)) 126 | # print('current replay size: {}, percentage: {}'.format(self.buffer.current_size, self.buffer.current_size / self.buffer.size * 100)) 127 | # torch.save([self.o_norm.mean, self.o_norm.std, self.g_norm.mean, self.g_norm.std, 128 | # self.actor_network.state_dict()], \ 129 | # self.model_path + '/model.pt') 130 | 131 | # pre_process the inputs 132 | def _preproc_o(self, obs): 133 | obs = np.transpose(np.array(obs)[None] / 255., [0, 3, 1, 2]) 134 | obs_tensor = torch.tensor(obs, dtype=torch.float32) 135 | if self.args.cuda: 136 | obs_tensor = obs_tensor.cuda() 137 | return obs_tensor 138 | 139 | def _preproc_g(self, g): 140 | g_tensor = torch.tensor(g[None] / 170, dtype=torch.float32) 141 | if self.args.cuda: 142 | g_tensor = g_tensor.cuda() 143 | return g_tensor 144 | 145 | 146 | # Acts based on single state (no batch) 147 | def act(self, obs, g): 148 | return self.critic_network(obs, g).data.max(1)[1] 149 | 150 | # Acts with an epsilon-greedy policy 151 | def act_e_greedy(self, obs, g, update_eps=0.2): 152 | return random.randrange(self.env_params['action']) if random.random() < update_eps else self.act(obs, g).item() 153 | 154 | # soft update 155 | def _soft_update_target_network(self, target, source): 156 | for target_param, param in zip(target.parameters(), source.parameters()): 157 | target_param.data.copy_((1 - self.args.polyak) * param.data + self.args.polyak * target_param.data) 158 | 159 | # update the network 160 | def _update_network(self): 161 | # sample the episodes 162 | transitions = self.buffer.sample(self.args.batch_size) 163 | # pre-process the observation and goal 164 | obs_tensor = torch.tensor(np.transpose(transitions['obs'], [0, 3, 1, 2]) / 255, dtype=torch.float32) 165 | obs_next_tensor = torch.tensor(np.transpose(transitions['obs_next'], [0, 3, 1, 2]) / 255, dtype=torch.float32) 166 | g_tensor = torch.tensor(transitions['g'] / 170, dtype=torch.float32) 167 | # import pdb 168 | # pdb.set_trace() 169 | dones_tensor = torch.tensor(transitions['done'], dtype=torch.float32).reshape(-1, 1) 170 | actions_tensor = torch.tensor(transitions['actions'], dtype=torch.long) 171 | r_tensor = torch.tensor(transitions['r'], dtype=torch.float32) 172 | if self.args.cuda: 173 | obs_tensor = obs_tensor.cuda() 174 | 
obs_next_tensor = obs_next_tensor.cuda() 175 | g_tensor = g_tensor.cuda() 176 | dones_tensor = dones_tensor.cuda() 177 | actions_tensor = actions_tensor.cuda() 178 | r_tensor = r_tensor.cuda() 179 | # calculate the target Q value function 180 | with torch.no_grad(): 181 | # do the normalization 182 | q_next_value = self.critic_target_network(obs_next_tensor, g_tensor).max(1)[0].reshape(-1, 1) 183 | q_next_value = q_next_value.detach() 184 | target_q_value = r_tensor + (1 - dones_tensor) * self.args.gamma * q_next_value 185 | target_q_value = target_q_value.detach() 186 | # clip the q value 187 | clip_return = 1 / (1 - self.args.gamma) 188 | target_q_value = torch.clamp(target_q_value, 0, clip_return) 189 | # the q loss 190 | real_q_value = self.critic_network(obs_tensor, g_tensor).gather(1, actions_tensor.reshape(-1, 1)) 191 | critic_loss = (target_q_value - real_q_value).pow(2).mean() 192 | # update the critic_network 193 | self.critic_optim.zero_grad() 194 | critic_loss.backward() 195 | self.critic_optim.step() 196 | 197 | # do the evaluation 198 | def _eval_agent(self): 199 | total_success_rate = [] 200 | total_dist = [] 201 | for _ in range(self.args.n_test_rollouts): 202 | observation = self.env.reset() 203 | obs = observation['observation'] 204 | g = observation['desired_goal'] 205 | for _ in range(self.env_params['max_timesteps']): 206 | with torch.no_grad(): 207 | # import pdb 208 | # pdb.set_trace() 209 | obs_tensor = self._preproc_o(obs) 210 | g_tensor = self._preproc_g(g) 211 | action = self.act_e_greedy(obs_tensor, g_tensor, update_eps=0.01) 212 | observation_new, _, done, info = self.env.step(action) 213 | obs = observation_new['observation'] 214 | g = observation_new['desired_goal'] 215 | dist = goal_distance(observation_new['achieved_goal'], observation_new['desired_goal']) 216 | if info['is_success'] > 0 or done: 217 | break 218 | 219 | total_success_rate.append(info['is_success']) 220 | total_dist.append(dist) 221 | total_success_rate = np.array(total_success_rate) 222 | success_rate = np.mean(total_success_rate) 223 | total_dist = np.array(total_dist) 224 | dist = np.mean(total_dist) 225 | return success_rate, dist -------------------------------------------------------------------------------- /atari_modules/wrappers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | os.environ.setdefault('PATH', '') 4 | from collections import deque 5 | import gym 6 | from gym import spaces 7 | import cv2 8 | cv2.ocl.setUseOpenCL(False) 9 | from atariari.benchmark.wrapper import AtariARIWrapper, ram2label 10 | from baselines.common.atari_wrappers import MaxAndSkipEnv, WarpFrame, LazyFrames 11 | 12 | 13 | class NoopResetEnv(gym.Wrapper): 14 | def __init__(self, env, noops=30): 15 | """Sample initial states by taking random number of no-ops on reset. 16 | No-op is assumed to be action 0. 17 | modification from baselines 18 | """ 19 | gym.Wrapper.__init__(self, env) 20 | self.noops = noops 21 | self.noop_action = 0 22 | assert env.unwrapped.get_action_meanings()[0] == 'NOOP' 23 | 24 | def reset(self, **kwargs): 25 | """ Do no-op action for a number of steps in [1, noop_max].""" 26 | self.env.reset(**kwargs) 27 | obs = None 28 | for _ in range(self.noops): 29 | obs, _, done, _ = self.env.step(self.noop_action) 30 | if done: 31 | obs = self.env.reset(**kwargs) 32 | return obs 33 | 34 | 35 | class FrameStack(gym.Wrapper): 36 | def __init__(self, env, k): 37 | """Stack k last frames. 
38 | Returns lazy array, which is much more memory efficient. 39 | -------- 40 | modification from baselines.common.atari_wrappers (step function) 41 | """ 42 | gym.Wrapper.__init__(self, env) 43 | self.k = k 44 | self.frames = deque([], maxlen=k) 45 | shp = env.observation_space.shape 46 | self.observation_space = spaces.Box(low=0, high=255, shape=(shp[:-1] + (shp[-1] * k,)), dtype=env.observation_space.dtype) 47 | 48 | def reset(self): 49 | ob = self.env.reset() 50 | for _ in range(self.k): 51 | self.frames.append(ob) 52 | return self._get_ob() 53 | 54 | def step(self, action): 55 | total_reward = 0.0 56 | done = None 57 | for i in range(3): 58 | obs, reward, done, info = self.env.step(action) 59 | total_reward += reward 60 | self.frames.append(obs) 61 | if done: 62 | break 63 | return self._get_ob(), total_reward, done, info 64 | 65 | def _get_ob(self): 66 | assert len(self.frames) == self.k 67 | return LazyFrames(list(self.frames)) 68 | 69 | 70 | def make_atari(env_id): 71 | env = gym.make(env_id) 72 | assert 'NoFrameskip' in env.spec.id 73 | env = NoopResetEnv(env, noops=240) 74 | env = MaxAndSkipEnv(env, skip=4) 75 | return env 76 | 77 | 78 | # def goal_distance(goal_a, goal_b): 79 | # assert goal_a.shape == goal_b.shape 80 | # return np.linalg.norm(goal_a - goal_b, axis=-1) 81 | 82 | def goal_distance(goal_a, goal_b): 83 | assert goal_a.shape == goal_b.shape 84 | return np.max(np.abs(goal_a - goal_b), axis=-1) 85 | 86 | 87 | class CroppedFrame(gym.ObservationWrapper): 88 | def __init__(self, env): 89 | gym.ObservationWrapper.__init__(self, env) 90 | self.observation_space = gym.spaces.Box(low=0, high=255, shape=(170, 160, 3), dtype=np.uint8) 91 | 92 | def observation(self, observation): 93 | # careful! This undoes the memory optimization, use 94 | # with smaller replay buffers only. 95 | return observation[: 170, :, :] 96 | 97 | 98 | class LifeLossEnv(gym.Wrapper): 99 | def __init__(self, env): 100 | """Make a life loss an end of the episode. 101 | """ 102 | gym.Wrapper.__init__(self, env) 103 | self.lives = 0 104 | 105 | def step(self, action): 106 | obs, reward, done, info = self.env.step(action) 107 | # check current lives, make loss of life terminal, 108 | # then update lives to handle bonus lives 109 | lives = self.env.unwrapped.ale.lives() 110 | if lives < self.lives and lives > 0: 111 | # for Qbert sometimes we stay in lives == 0 condition for a few frames 112 | # so it's important to keep lives > 0, so that we only reset once 113 | # the environment advertises done. 
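# Note: this wrapper does not override reset(), so once the training loop sees
# done=True here and calls env.reset(), the whole game restarts rather than
# resuming with the remaining lives.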
114 | done = True 115 | self.lives = lives 116 | return obs, reward, done, info 117 | 118 | 119 | class GoalMsPacman(gym.Wrapper): 120 | def __init__(self, env, distance_threshold=6, reward_type='sparse'): 121 | self.distance_threshold = distance_threshold 122 | self.reward_type = reward_type 123 | self.env = env 124 | # Maintain list of reachable goals 125 | # is_valid_idx = np.array([ 126 | # [1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1], 127 | # [1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1], 128 | # [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 129 | # [0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0], 130 | # [1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1], 131 | # [0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0], 132 | # [0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0], 133 | # [0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0], 134 | # [1, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1], 135 | # [0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0], 136 | # [1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1], 137 | # [1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1], 138 | # [1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1], 139 | # [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=np.bool) 140 | # valid_idx = np.transpose(np.nonzero(is_valid_idx)) 141 | # 142 | # # columns and rows correspond to x and y respectively 143 | # self.all_goals = [] # Translate goal space to pixel space range 144 | # prev_row = 0 145 | # for row_i, row in enumerate(range(14, 171, 12)): 146 | # prev_col = 0 147 | # for col_i, col in enumerate(range(14, 160, 8)): 148 | # if np.sum(np.all(valid_idx == np.array([row_i, col_i]), axis=1)) > 0: 149 | # self.all_goals.append(np.array([int((prev_col + col)/2), int((prev_row + row)/2)])) 150 | # prev_col = col 151 | # prev_row = row 152 | # self.total_num_goals = len(self.all_goals) 153 | 154 | all_goals = [[11, 9], [18, 9], [27, 9], [34, 9], [50, 9], [58, 9], [66, 9], [76, 9], [86, 9], [95, 9], [104, 9], [112, 9], [127, 9], [135, 9], [143, 9], [150, 9], 155 | [36, 21], [50, 21], [112, 21], [128, 21], [152, 21], 156 | [11, 34], [18, 34], [35, 34], [43, 34], [51, 34], [59, 34], [68, 34], [76, 34], [87, 34], [95, 34], [103, 34], [111, 34], [119, 34], [127, 34], [135, 34], [143, 34], [151, 34], 157 | [20, 46], [34, 46], [60, 46], [103, 46], [129, 46], [144, 46], 158 | [11, 57], [19, 57], [35, 57], [43, 57], [51, 57], [59, 57], [67, 57], [75, 57], [87, 57], [95, 57], [103, 57], [111, 57], [118, 57], [127, 57], [143, 57], [151, 57], 159 | [20, 69], [59, 69], [104, 69], [144, 69], 160 | [20, 82], [27, 82], [35, 82], [43, 82], [43, 82], [51, 82], [59, 82], [103, 82], [111, 82], [128, 82], [135, 82], [142, 82], 161 | [20, 93], [58, 93], [103, 93], [143, 93], 162 | [12, 105], [19, 105], [35, 105], [43, 105], [51, 105], [110, 105], [118, 105], [126, 105], [143, 105], [151, 105], 163 | [19, 117], [35, 117], [50, 117], [50, 117], [67, 117], [94, 117], [110, 117], [127, 117], [143, 117], 164 | [12, 129], [19, 129], [27, 129], [35, 129], [35, 129], [51, 129], [67, 129], [75, 129], [86, 129], [95, 129], [111, 129], [126, 129], [135, 129], [143, 129], [151, 129], 165 | [12, 141], [35, 141], [50, 141], [67, 141], [95, 141], [102, 141], [112, 141], [127, 141], [151, 141], 166 | [12, 153], [35, 153], [67, 153], [95, 153], [127, 153], [151, 153], 167 | [12, 165], [19, 165], [27, 165], [35, 165], [43, 165], [50, 165], [59, 165], [68, 165], [76, 165], [87, 
165], [87, 165], [96, 165], [103, 165], 168 | [112, 165], [119, 165], [127, 165], [136, 165], [142, 165], [151, 165]] 169 | self.all_goals = np.array(all_goals) 170 | self.all_goals += np.array([[0, -2]]) 171 | self.total_num_goals = len(self.all_goals) 172 | 173 | self.action_space = spaces.Discrete(5) 174 | obs = self.reset() 175 | self.observation_space = dict( 176 | desired_goal=spaces.Box(-np.inf, np.inf, shape=obs['achieved_goal'].shape, dtype='float32'), 177 | achieved_goal=spaces.Box(-np.inf, np.inf, shape=obs['achieved_goal'].shape, dtype='float32'), 178 | observation=self.env.observation_space, 179 | ) 180 | 181 | # sanity check 182 | # obs = env.reset() 183 | # for g in all_goals: 184 | # obs[g[1]-2: g[1]+2, g[0]-2:g[0]+2, :] = 255 185 | 186 | def _get_pos(self): 187 | ram = self.env.unwrapped.ale.getRAM() 188 | label_dict = ram2label(self.env.spec.id, ram) 189 | return np.array([label_dict['player_x'], label_dict['player_y']]) + np.array([-8, 6]) 190 | 191 | def reset(self): 192 | lazy_obs = self.env.reset() 193 | achieved_goal = self._get_pos() 194 | self.goal = self._sample_goal() 195 | return { 196 | 'observation': lazy_obs, 197 | 'achieved_goal': achieved_goal.copy(), 198 | 'desired_goal': self.goal.copy() 199 | } 200 | 201 | def step(self, action): 202 | lazy_obs, reward, done, info = self.env.step(action) 203 | achieved_goal = self._get_pos() 204 | obs = { 205 | 'observation': lazy_obs, 206 | 'achieved_goal': achieved_goal.copy(), 207 | 'desired_goal': self.goal.copy() 208 | } 209 | info['is_success'] = self._is_success(obs['achieved_goal'], self.goal) 210 | reward = self.compute_reward(obs['achieved_goal'], self.goal, None) 211 | return obs, reward, done, info 212 | 213 | def _sample_goal(self): 214 | id = np.random.randint(self.total_num_goals) 215 | return self.all_goals[id] 216 | 217 | def set_goal(self, g): 218 | self.goal = g 219 | 220 | def _is_success(self, achieved_goal, desired_goal): 221 | d = goal_distance(achieved_goal, desired_goal) 222 | return (d <= self.distance_threshold).astype(np.float32) 223 | 224 | def compute_reward(self, achieved_goal, goal, info): 225 | # Compute distance between goal and the achieved goal. 
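# With the default 'sparse' setting the reward is 1.0 whenever the player is within
# distance_threshold of the goal under the Chebyshev metric defined by goal_distance
# above, and 0.0 otherwise; the 'dense' branch returns the negative distance instead.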
226 | d = goal_distance(achieved_goal, goal) 227 | if self.reward_type == 'sparse': 228 | return (d <= self.distance_threshold).astype(np.float32) 229 | else: 230 | return -d 231 | 232 | 233 | def make_goalPacman(): 234 | env = make_atari('MsPacmanNoFrameskip-v4') 235 | env = LifeLossEnv(env) 236 | env = CroppedFrame(env) 237 | env = WarpFrame(env) 238 | env = FrameStack(env, 4) 239 | env = GoalMsPacman(env) 240 | return env 241 | 242 | 243 | if __name__=='__main__': 244 | import matplotlib as mpl 245 | # mpl.use('TkAgg') 246 | import matplotlib.pyplot as plt 247 | env = make_atari('MsPacmanNoFrameskip-v4') 248 | env = LifeLossEnv(env) 249 | env = CroppedFrame(env) 250 | env = WarpFrame(env) 251 | env = FrameStack(env, 4) 252 | env = GoalMsPacman(env) 253 | obs = env.reset() 254 | 255 | env.set_goal(np.array([40, 160])) 256 | for i in range(100): 257 | plt.imsave('plots/step_{}.jpg'.format(i), env.unwrapped._get_obs()) 258 | if i < 30: 259 | obs, reward, done, info = env.step(3) 260 | else: 261 | obs, reward, done, info = env.step(4) 262 | if done: 263 | print('Death') 264 | if info['is_success'] > 0: 265 | print('Success !!') 266 | print(i) 267 | break 268 | 269 | # raw_obs = env.unwrapped._get_obs() 270 | # plt.imshow(raw_obs) 271 | # plt.show() 272 | 273 | -------------------------------------------------------------------------------- /discrete_action_robots_modules/dqn_agent.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from datetime import datetime 4 | import numpy as np 5 | import random 6 | from discrete_action_robots_modules.replay_buffer import replay_buffer 7 | from discrete_action_robots_modules.models import critic 8 | from discrete_action_robots_modules.normalizer import normalizer 9 | from her_modules.her import her_sampler 10 | from discrete_action_robots_modules.robots import goal_distance 11 | import csv 12 | import pickle 13 | 14 | 15 | class DQNAgent: 16 | def __init__(self, args, env, env_params): 17 | self.args = args 18 | self.env = env 19 | self.env_params = env_params 20 | # create the network 21 | self.critic_network = critic(env_params) 22 | # build up the target network 23 | self.critic_target_network = critic(env_params) 24 | # load the weights into the target networks 25 | self.critic_target_network.load_state_dict(self.critic_network.state_dict()) 26 | # if use gpu 27 | if self.args.cuda: 28 | self.critic_network.cuda() 29 | self.critic_target_network.cuda() 30 | # create the optimizer 31 | self.critic_optim = torch.optim.Adam(self.critic_network.parameters(), lr=self.args.lr_critic) 32 | # her sampler 33 | self.her_module = her_sampler(self.args.replay_strategy, self.args.replay_k, self.env.compute_reward) 34 | # create the replay buffer 35 | self.buffer = replay_buffer(self.env_params, self.args.buffer_size, self.her_module.sample_her_transitions) 36 | # create the normalizer 37 | self.o_norm = normalizer(size=env_params['obs'], default_clip_range=self.args.clip_range) 38 | self.g_norm = normalizer(size=env_params['goal'], default_clip_range=self.args.clip_range) 39 | # create the dict for store the model 40 | if args.save_dir is not None: 41 | if not os.path.exists(self.args.save_dir): 42 | os.mkdir(self.args.save_dir) 43 | 44 | print(' ' * 26 + 'Options') 45 | for k, v in vars(self.args).items(): 46 | print(' ' * 26 + k + ': ' + str(v)) 47 | 48 | with open(self.args.save_dir + "/arguments.pkl", 'wb') as f: 49 | pickle.dump(self.args, f) 50 | 51 | with 
open('{}/score_monitor.csv'.format(self.args.save_dir), "wt") as monitor_file: 52 | monitor = csv.writer(monitor_file) 53 | monitor.writerow(['epoch', 'eval', 'dist']) 54 | 55 | def learn(self): 56 | """ 57 | train the network 58 | 59 | """ 60 | # start to collect samples 61 | for epoch in range(self.args.n_epochs): 62 | for _ in range(self.args.n_cycles): 63 | mb_obs, mb_ag, mb_g, mb_actions = [], [], [], [] 64 | for _ in range(self.args.num_rollouts_per_cycle): 65 | # reset the rollouts 66 | ep_obs, ep_ag, ep_g, ep_actions = [], [], [], [] 67 | # reset the environment 68 | observation = self.env.reset() 69 | obs = observation['observation'] 70 | ag = observation['achieved_goal'] 71 | g = observation['desired_goal'] 72 | # start to collect samples 73 | for t in range(self.env_params['max_timesteps']): 74 | with torch.no_grad(): 75 | obs_norm_tensor = self._preproc_o(obs) 76 | g_norm_tensor = self._preproc_g(g) 77 | action = self.act_e_greedy(obs_norm_tensor, g_norm_tensor, update_eps=0.2) 78 | # feed the actions into the environment 79 | observation_new, _, _, info = self.env.step(action) 80 | obs_new = observation_new['observation'] 81 | ag_new = observation_new['achieved_goal'] 82 | # append rollouts 83 | ep_obs.append(obs.copy()) 84 | ep_ag.append(ag.copy()) 85 | ep_g.append(g.copy()) 86 | ep_actions.append(action) 87 | # re-assign the observation 88 | obs = obs_new 89 | ag = ag_new 90 | ep_obs.append(obs.copy()) 91 | ep_ag.append(ag.copy()) 92 | mb_obs.append(ep_obs) 93 | mb_ag.append(ep_ag) 94 | mb_g.append(ep_g) 95 | mb_actions.append(ep_actions) 96 | # convert them into arrays 97 | mb_obs = np.array(mb_obs) 98 | mb_ag = np.array(mb_ag) 99 | mb_g = np.array(mb_g) 100 | mb_actions = np.array(mb_actions) 101 | # store the episodes 102 | self.buffer.store_episode([mb_obs, mb_ag, mb_g, mb_actions]) 103 | self._update_normalizer([mb_obs, mb_ag, mb_g, mb_actions]) 104 | for _ in range(self.args.n_batches): 105 | # train the network 106 | self._update_network() 107 | # soft update 108 | self._soft_update_target_network(self.critic_target_network, self.critic_network) 109 | # start to do the evaluation 110 | # import pdb 111 | # pdb.set_trace() 112 | success_rate, dist = self._eval_agent() 113 | print('[{}] epoch is: {}, train success rate is: {:.3f},' 114 | ' dist: {:.3f}'.format(datetime.now(), epoch, success_rate, dist)) 115 | with open('{}/score_monitor.csv'.format(self.args.save_dir), "a") as monitor_file: 116 | monitor = csv.writer(monitor_file) 117 | monitor.writerow([epoch, success_rate, dist]) 118 | # torch.save([self.o_norm.mean, self.o_norm.std, self.g_norm.mean, self.g_norm.std, 119 | # self.critic_network.state_dict()], \ 120 | # self.model_path + '/model.pt') 121 | 122 | # pre_process the inputs 123 | def _preproc_o(self, obs): 124 | obs_norm = self.o_norm.normalize(obs) 125 | obs_norm_tensor = torch.tensor(obs_norm, dtype=torch.float32).unsqueeze(0) 126 | if self.args.cuda: 127 | obs_norm_tensor = obs_norm_tensor.cuda() 128 | return obs_norm_tensor 129 | 130 | def _preproc_g(self, g): 131 | g_norm = self.g_norm.normalize(g) 132 | g_norm_tensor = torch.tensor(g_norm, dtype=torch.float32).unsqueeze(0) 133 | if self.args.cuda: 134 | g_norm_tensor = g_norm_tensor.cuda() 135 | return g_norm_tensor 136 | 137 | # Acts based on single state (no batch) 138 | def act(self, obs, g): 139 | return self.critic_network(obs, g).data.max(1)[1][0] 140 | 141 | # Acts with an epsilon-greedy policy 142 | def act_e_greedy(self, obs, g, update_eps=0.2): 143 | return 
random.randrange(self.env_params['action']) if random.random() < update_eps else self.act(obs, g) 144 | 145 | # update the normalizer 146 | def _update_normalizer(self, episode_batch): 147 | mb_obs, mb_ag, mb_g, mb_actions = episode_batch 148 | mb_obs_next = mb_obs[:, 1:, :] 149 | mb_ag_next = mb_ag[:, 1:, :] 150 | # get the number of normalization transitions 151 | num_transitions = mb_actions.shape[1] 152 | # create the new buffer to store them 153 | buffer_temp = {'obs': mb_obs, 154 | 'ag': mb_ag, 155 | 'g': mb_g, 156 | 'actions': mb_actions, 157 | 'obs_next': mb_obs_next, 158 | 'ag_next': mb_ag_next, 159 | } 160 | transitions = self.her_module.sample_her_transitions(buffer_temp, num_transitions) 161 | obs, g = transitions['obs'], transitions['g'] 162 | # pre process the obs and g 163 | transitions['obs'], transitions['g'] = self._preproc_og(obs, g) 164 | # update 165 | self.o_norm.update(transitions['obs']) 166 | self.g_norm.update(transitions['g']) 167 | # recompute the stats 168 | self.o_norm.recompute_stats() 169 | self.g_norm.recompute_stats() 170 | 171 | def _preproc_og(self, o, g): 172 | o = np.clip(o, -self.args.clip_obs, self.args.clip_obs) 173 | g = np.clip(g, -self.args.clip_obs, self.args.clip_obs) 174 | return o, g 175 | 176 | # soft update 177 | def _soft_update_target_network(self, target, source): 178 | for target_param, param in zip(target.parameters(), source.parameters()): 179 | target_param.data.copy_((1 - self.args.polyak) * param.data + self.args.polyak * target_param.data) 180 | 181 | # update the network 182 | def _update_network(self): 183 | # sample the episodes 184 | transitions = self.buffer.sample(self.args.batch_size) 185 | # pre-process the observation and goal 186 | o, o_next, g = transitions['obs'], transitions['obs_next'], transitions['g'] 187 | transitions['obs'], transitions['g'] = self._preproc_og(o, g) 188 | transitions['obs_next'], _ = self._preproc_og(o_next, g) 189 | # start to do the update 190 | obs_norm = self.o_norm.normalize(transitions['obs']) 191 | g_norm = self.g_norm.normalize(transitions['g']) 192 | obs_next_norm = self.o_norm.normalize(transitions['obs_next']) 193 | # transfer them into the tensor 194 | obs_norm_tensor = torch.tensor(obs_norm, dtype=torch.float32) 195 | g_norm_tensor = torch.tensor(g_norm, dtype=torch.float32) 196 | obs_next_norm_tensor = torch.tensor(obs_next_norm, dtype=torch.float32) 197 | actions_tensor = torch.tensor(transitions['actions'], dtype=torch.long) 198 | r_tensor = torch.tensor(transitions['r'], dtype=torch.float32) 199 | if self.args.cuda: 200 | obs_norm_tensor = obs_norm_tensor.cuda() 201 | g_norm_tensor = g_norm_tensor.cuda() 202 | obs_next_norm_tensor = obs_next_norm_tensor.cuda() 203 | actions_tensor = actions_tensor.cuda() 204 | r_tensor = r_tensor.cuda() 205 | # calculate the target Q value function 206 | with torch.no_grad(): 207 | q_next_value = self.critic_target_network(obs_next_norm_tensor, g_norm_tensor).max(1)[0].reshape(-1, 1) 208 | q_next_value = q_next_value.detach() 209 | target_q_value = r_tensor + self.args.gamma * q_next_value 210 | target_q_value = target_q_value.detach() 211 | # clip the q value 212 | clip_return = 1 / (1 - self.args.gamma) 213 | target_q_value = torch.clamp(target_q_value, -clip_return, 0) 214 | # the q loss 215 | real_q_value = self.critic_network(obs_norm_tensor, g_norm_tensor).gather(1, actions_tensor.reshape(-1, 1)) 216 | critic_loss = (target_q_value - real_q_value).pow(2).mean() 217 | # update the critic_network 218 | self.critic_optim.zero_grad() 219 | 
critic_loss.backward() 220 | self.critic_optim.step() 221 | 222 | # do the evaluation 223 | def _eval_agent(self): 224 | total_success_rate = [] 225 | total_dist = [] 226 | for _ in range(self.args.n_test_rollouts): 227 | per_success_rate = [] 228 | per_dist = [] 229 | observation = self.env.reset() 230 | obs = observation['observation'] 231 | g = observation['desired_goal'] 232 | # for _ in range(self.env_params['max_timesteps']): 233 | for _ in range(25): 234 | with torch.no_grad(): 235 | obs_norm_tensor = self._preproc_o(obs) 236 | g_norm_tensor = self._preproc_g(g) 237 | action = self.act(obs_norm_tensor, g_norm_tensor) 238 | observation_new, _, _, info = self.env.step(action) 239 | obs = observation_new['observation'] 240 | g = observation_new['desired_goal'] 241 | dist = goal_distance(observation_new['achieved_goal'], observation_new['desired_goal']) 242 | # per_dist.append(dist) 243 | # per_success_rate.append(info['is_success']) 244 | per_dist = dist 245 | per_success_rate = info['is_success'] 246 | if info['is_success'] > 0: 247 | break 248 | total_success_rate.append(per_success_rate) 249 | total_dist.append(per_dist) 250 | total_success_rate = np.array(total_success_rate) 251 | avg_success_rate = np.mean(total_success_rate) 252 | total_dist = np.array(total_dist) 253 | avg_dist = np.mean(total_dist) 254 | return avg_success_rate, avg_dist 255 | -------------------------------------------------------------------------------- /atari_modules/fb_agent.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from datetime import datetime 4 | import numpy as np 5 | import random 6 | import pickle 7 | import csv 8 | from atari_modules.replay_buffer import ReplayBuffer 9 | from atari_modules.models import ForwardMap, BackwardMap 10 | from atari_modules.wrappers import goal_distance 11 | from grid_modules.mdp_utils import extract_policy 12 | from torch.distributions.cauchy import Cauchy 13 | 14 | 15 | def compute_entropy(policy): 16 | p_log_p = torch.log(policy) * policy 17 | return -p_log_p.sum(-1) 18 | 19 | 20 | def nanmean(v, *args, inplace=False, **kwargs): 21 | if not inplace: 22 | v = v.clone() 23 | is_nan = torch.isnan(v) 24 | v[is_nan] = 1e-10 25 | return v.mean(*args, **kwargs) 26 | 27 | 28 | class FBAgent: 29 | def __init__(self, args, env, env_params): 30 | self.args = args 31 | self.env = env 32 | self.env_params = env_params 33 | self.cauchy = Cauchy(torch.tensor([0.0]), torch.tensor([0.5])) 34 | # create the network 35 | self.forward_network = ForwardMap(env_params, args.embed_dim) 36 | self.backward_network = BackwardMap(env_params, args.embed_dim) 37 | # build up the target network 38 | self.forward_target_network = ForwardMap(env_params, args.embed_dim) 39 | self.backward_target_network = BackwardMap(env_params, args.embed_dim) 40 | # load the weights into the target networks 41 | self.forward_target_network.load_state_dict(self.forward_network.state_dict()) 42 | self.backward_target_network.load_state_dict(self.backward_network.state_dict()) 43 | # if use gpu 44 | if self.args.cuda: 45 | self.forward_network.cuda() 46 | self.backward_network.cuda() 47 | self.forward_target_network.cuda() 48 | self.backward_target_network.cuda() 49 | # create the optimizer 50 | f_params = [param for param in self.forward_network.parameters()] 51 | b_params = [param for param in self.backward_network.parameters()] 52 | self.fb_optim = torch.optim.Adam(f_params + b_params, lr=self.args.lr) 53 | # self.backward_optim = 
torch.optim.Adam(self.backward_network.parameters(), lr=self.args.lr_backward) 54 | # create the replay buffer 55 | self.buffer = ReplayBuffer(self.args.buffer_size) 56 | # create the dict for store the model 57 | if self.args.save_dir is not None: 58 | if not os.path.exists(self.args.save_dir): 59 | os.mkdir(self.args.save_dir) 60 | 61 | # print(' ' * 26 + 'Options') 62 | # for k, v in vars(self.args).items(): 63 | # print(' ' * 26 + k + ': ' + str(v)) 64 | 65 | with open(self.args.save_dir + "/arguments.pkl", 'wb') as f: 66 | pickle.dump(self.args, f) 67 | 68 | with open('{}/score_monitor.csv'.format(self.args.save_dir), "wt") as monitor_file: 69 | monitor = csv.writer(monitor_file) 70 | monitor.writerow(['epoch', 'eval', 'avg dist', 'eval (GPI)', 'avg dist (GPI)']) 71 | 72 | def learn(self): 73 | """ 74 | train the network 75 | 76 | """ 77 | # start to collect samples 78 | # print('MPI SIZE: ', MPI.COMM_WORLD.Get_size()) 79 | for epoch in range(self.args.n_epochs): 80 | for _ in range(self.args.n_cycles): 81 | for _ in range(self.args.num_rollouts_per_cycle): 82 | # reset the rollouts 83 | # reset the environment 84 | observation = self.env.reset() 85 | obs = observation['observation'] 86 | ag = observation['achieved_goal'] 87 | g = observation['desired_goal'] 88 | if self.args.w_sampling == 'goal_oriented': 89 | g_tensor = self._preproc_g(g) 90 | with torch.no_grad(): 91 | w = self.backward_network(g_tensor) 92 | elif self.args.w_sampling == 'uniform_ball': 93 | w = self.sample_uniform_ball(1) 94 | elif self.args.w_sampling == 'cauchy_ball': 95 | w = self.sample_cauchy_ball(1) 96 | # start to collect samples 97 | for t in range(self.env_params['max_timesteps']): 98 | with torch.no_grad(): 99 | obs_tensor = self._preproc_o(obs) 100 | action = self.act_e_greedy(obs_tensor, w, update_eps=0.2) 101 | # feed the actions into the environment 102 | observation_new, reward, done, info = self.env.step(action) 103 | obs_new = observation_new['observation'] 104 | ag_new = observation_new['achieved_goal'] 105 | # add transition 106 | self.buffer.add(obs, ag, g, action, reward, obs_new, done) 107 | if done: 108 | observation = self.env.reset() 109 | obs = observation['observation'] 110 | ag = observation['achieved_goal'] 111 | g = observation['desired_goal'] 112 | else: 113 | obs = obs_new 114 | ag = ag_new 115 | for _ in range(self.args.n_batches): 116 | # train the network 117 | self._update_network() 118 | # soft update 119 | self._soft_update_target_network(self.forward_target_network, self.forward_network) 120 | self._soft_update_target_network(self.backward_target_network, self.backward_network) 121 | # start to do the evaluation 122 | success_rate, avg_dist = self._eval_agent() 123 | success_rate_gpi, avg_dist_gpi = self._eval_gpi_agent(num_gpi=self.args.num_gpi) 124 | print('[{}] epoch is: {}, eval: {:.3f}, avg_dist : {:.3f}, ' 125 | 'eval (GPI): {:.3f}, avg_dist (GPI): {:.3f}'.format(datetime.now(), epoch, success_rate, avg_dist, 126 | success_rate_gpi, avg_dist_gpi)) 127 | with open('{}/score_monitor.csv'.format(self.args.save_dir), "a") as monitor_file: 128 | monitor = csv.writer(monitor_file) 129 | monitor.writerow([epoch, success_rate, avg_dist, success_rate_gpi, avg_dist_gpi]) 130 | torch.save([self.forward_network.state_dict(), self.backward_network.state_dict()], 131 | os.path.join(self.args.save_dir, 'model.pt')) 132 | 133 | def sample_uniform_ball(self, n, eps=1e-10): 134 | gaussian_rdv = torch.FloatTensor(n, self.args.embed_dim).normal_(mean=0, std=1) 135 | gaussian_rdv /= 
torch.norm(gaussian_rdv, dim=-1, keepdim=True) + eps 136 | uniform_rdv = torch.FloatTensor(n, 1).uniform_() 137 | w = np.sqrt(self.args.embed_dim) * gaussian_rdv * uniform_rdv 138 | if self.args.cuda: 139 | w = w.cuda() 140 | return w 141 | 142 | def sample_cauchy_ball(self, n, eps=1e-10): 143 | gaussian_rdv = torch.FloatTensor(n, self.args.embed_dim).normal_(mean=0, std=1) 144 | gaussian_rdv /= torch.norm(gaussian_rdv, dim=-1, keepdim=True) + eps 145 | cauchy_rdv = self.cauchy.sample((n, )) 146 | w = np.sqrt(self.args.embed_dim) * gaussian_rdv * cauchy_rdv 147 | if self.args.cuda: 148 | w = w.cuda() 149 | return w 150 | 151 | # pre_process the inputs 152 | def _preproc_o(self, obs): 153 | obs = np.transpose(np.array(obs)[None] / 255., [0, 3, 1, 2]) 154 | obs_tensor = torch.tensor(obs, dtype=torch.float32) 155 | if self.args.cuda: 156 | obs_tensor = obs_tensor.cuda() 157 | return obs_tensor 158 | 159 | def _preproc_g(self, g): 160 | g_tensor = torch.tensor(g[None] / 170, dtype=torch.float32) 161 | if self.args.cuda: 162 | g_tensor = g_tensor.cuda() 163 | return g_tensor 164 | 165 | def act_gpi(self, obs, w_train, w_eval): 166 | # import pdb 167 | # pdb.set_trace() 168 | num_gpi = w_train.shape[0] 169 | obs_repeat = obs.repeat(num_gpi, 1, 1, 1) 170 | w_eval_repeat = w_eval.repeat(num_gpi, 1) 171 | f = self.forward_network(obs_repeat, w_train) 172 | z = torch.einsum('sda, sd -> sa', f, w_eval_repeat).max(0)[0] 173 | return z.max(0)[1] 174 | 175 | # Acts based on single state (no batch) 176 | def act(self, obs, w, target_network=False): 177 | if target_network: 178 | f = self.forward_target_network(obs, w) 179 | else: 180 | f = self.forward_network(obs, w) 181 | z = torch.einsum('sda, sd -> sa', f, w) 182 | return z.max(1)[1] 183 | 184 | # Acts with an epsilon-greedy policy 185 | def act_e_greedy(self, obs, w, update_eps=0.2): 186 | return random.randrange(self.env_params['action']) if random.random() < update_eps else self.act(obs, w).item() 187 | 188 | def act_gpi_e_greedy(self, obs, w_train, w, update_eps=0.2): 189 | return random.randrange(self.env_params['action']) if random.random() < update_eps else self.act_gpi(obs, w_train, w).item() 190 | 191 | def get_policy(self, w, obs=None, policy_type='boltzmann', temp=1, eps=0.01, target_network=False): 192 | if target_network: 193 | f = self.forward_target_network(obs, w) 194 | else: 195 | f = self.forward_network(obs, w) 196 | z = torch.einsum('sda, sd -> sa', f, w) 197 | return extract_policy(z, policy_type=policy_type, temp=temp, eps=eps) 198 | 199 | # soft update 200 | def _soft_update_target_network(self, target, source): 201 | for target_param, param in zip(target.parameters(), source.parameters()): 202 | target_param.data.copy_((1 - self.args.polyak) * param.data + self.args.polyak * target_param.data) 203 | 204 | # update the network 205 | def _update_network(self): 206 | # sample transitions 207 | transitions = self.buffer.sample(self.args.batch_size) 208 | transitions_other = self.buffer.sample(self.args.batch_size) 209 | # pre-process the observation and goal 210 | obs_tensor = torch.tensor(np.transpose(transitions['obs'], [0, 3, 1, 2]) / 255, dtype=torch.float32) 211 | obs_next_tensor = torch.tensor(np.transpose(transitions['obs_next'], [0, 3, 1, 2]) / 255, dtype=torch.float32) 212 | g_tensor = torch.tensor(transitions['g'] / 170, dtype=torch.float32) 213 | ag_tensor = torch.tensor(transitions['ag'] / 170, dtype=torch.float32) 214 | ag_other_tensor = torch.tensor(transitions_other['ag'] / 170, dtype=torch.float32) 215 | 
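# Note: ag_other comes from an independently sampled batch; in the loss below its
# embedding B(ag_other) provides the off-diagonal terms of the squared Bellman
# residual (and the target z_next), while B(ag) of the current transition enters
# through the -z_diag.mean() term and the orthonormality regulariser.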
dones_tensor = torch.tensor(transitions['done'], dtype=torch.float32).reshape(-1, 1) 216 | actions_tensor = torch.tensor(transitions['action'], dtype=torch.long) 217 | if self.args.cuda: 218 | obs_tensor = obs_tensor.cuda() 219 | obs_next_tensor = obs_next_tensor.cuda() 220 | g_tensor = g_tensor.cuda() 221 | ag_tensor = ag_tensor.cuda() 222 | ag_other_tensor = ag_other_tensor.cuda() 223 | dones_tensor = dones_tensor.cuda() 224 | actions_tensor = actions_tensor.cuda() 225 | 226 | if self.args.w_sampling == 'goal_oriented': 227 | with torch.no_grad(): 228 | w = self.backward_network(g_tensor) 229 | w = w.detach() 230 | elif self.args.w_sampling == 'uniform_ball': 231 | w = self.sample_uniform_ball(self.args.batch_size) 232 | elif self.args.w_sampling == 'cauchy_ball': 233 | w = self.sample_cauchy_ball(self.args.batch_size) 234 | 235 | # calculate the target Q value function 236 | with torch.no_grad(): 237 | if self.args.soft_update: 238 | pi = self.get_policy(w, obs=obs_next_tensor, policy_type='boltzmann', temp=self.args.temp, 239 | target_network=True) 240 | # entropy = nanmean(compute_entropy(pi)) 241 | f_next = torch.einsum('sda, sa -> sd', self.forward_target_network(obs_next_tensor, w), pi) 242 | else: 243 | actions_next_tensor = self.act(obs_next_tensor, w, target_network=True) 244 | next_idxs = actions_next_tensor[:, None].repeat(1, self.args.embed_dim)[:, :, None] 245 | f_next = self.forward_target_network(obs_next_tensor, w).gather(-1, next_idxs).squeeze() # batch x dim 246 | 247 | b_next = self.backward_target_network(ag_other_tensor) # batch x dim 248 | z_next = torch.einsum('sd, td -> st', f_next, b_next) # batch x batch 249 | z_next = z_next.detach() 250 | 251 | # the forward loss 252 | idxs = actions_tensor[:, None].repeat(1, self.args.embed_dim)[:, :, None] 253 | f = self.forward_network(obs_tensor, w).gather(-1, idxs).squeeze() 254 | b = self.backward_network(ag_tensor) 255 | b_other = self.backward_network(ag_other_tensor) 256 | z_diag = torch.einsum('sd, sd -> s', f, b) # batch 257 | z = torch.einsum('sd, td -> st', f, b_other) # batch x batch 258 | fb_loss = 0.5 * (z - (1 - dones_tensor) * self.args.gamma * z_next).pow(2).mean() - z_diag.mean() 259 | # compute orthonormality's regularisation loss 260 | b_b_other = torch.einsum('sd, xd -> sx', b, b_other) # batch x batch 261 | b_b_other_detach = torch.einsum('sd, xd -> sx', b, b_other.detach()) # batch x batch 262 | b_b_detach = torch.einsum('sd, sd -> s', b, b.detach()) # batch 263 | reg_loss = (b_b_detach * b_b_other.detach()).mean() - b_b_other_detach.mean() 264 | fb_loss += self.args.reg_coef * reg_loss 265 | 266 | # update the forward_network 267 | self.fb_optim.zero_grad() 268 | fb_loss.backward() 269 | self.fb_optim.step() 270 | 271 | # the backward loss 272 | # f = self.forward_network(obs_norm_tensor, actions_tensor, w) 273 | # b = self.backward_network(ag_norm_tensor) 274 | # b_other = self.backward_network(g_other_norm_tensor) 275 | # z_diag = torch.einsum('sd, sd -> s', f, b) # batch 276 | # z = torch.einsum('sd, td -> st', f, b_other) # batch x batch 277 | # b_loss = 0.5 * (z - self.args.gamma * z_next).pow(2).mean() - z_diag.mean() 278 | # compute orthonormality's regularisation loss 279 | # b_b_other = torch.einsum('sd, xd -> sx', b, b_other) # batch x batch 280 | # b_b_other_detach = torch.einsum('sd, xd -> sx', b, b_other.detach()) # batch x batch 281 | # b_b_detach = torch.einsum('sd, sd -> s', b, b.detach()) # batch 282 | # reg_loss = (b_b_detach * b_b_other.detach()).mean() - b_b_other_detach.mean() 283 
| # b_loss += self.args.reg_coef * reg_loss 284 | # 285 | # # update the backward_network 286 | # self.backward_optim.zero_grad() 287 | # b_loss.backward() 288 | # sync_grads(self.backward_network) 289 | # self.backward_optim.step() 290 | 291 | # print('f_loss: {}, b_loss: {}'.format(f_loss.item(), b_loss.item())) 292 | 293 | # do the evaluation 294 | def _eval_agent(self): 295 | total_success_rate = [] 296 | total_dist = [] 297 | for _ in range(self.args.n_test_rollouts): 298 | observation = self.env.reset() 299 | obs = observation['observation'] 300 | g = observation['desired_goal'] 301 | 302 | for _ in range(self.env_params['max_timesteps']): 303 | with torch.no_grad(): 304 | g_tensor = self._preproc_g(g) 305 | w = self.backward_network(g_tensor) 306 | obs_tensor = self._preproc_o(obs) 307 | action = self.act_e_greedy(obs_tensor, w, update_eps=0.02) 308 | observation_new, _, done, info = self.env.step(action) 309 | obs = observation_new['observation'] 310 | g = observation_new['desired_goal'] 311 | dist = goal_distance(observation_new['achieved_goal'], observation_new['desired_goal']) 312 | if info['is_success'] > 0 or done: 313 | break 314 | total_success_rate.append(info['is_success']) 315 | total_dist.append(dist) 316 | total_success_rate = np.array(total_success_rate) 317 | success_rate = np.mean(total_success_rate) 318 | total_dist = np.array(total_dist) 319 | dist = np.mean(total_dist) 320 | return success_rate, dist 321 | 322 | def _eval_gpi_agent(self, num_gpi=20): 323 | total_success_rate = [] 324 | total_dist = [] 325 | for _ in range(self.args.n_test_rollouts): 326 | observation = self.env.reset() 327 | obs = observation['observation'] 328 | g = observation['desired_goal'] 329 | if self.args.w_sampling == 'goal_oriented': 330 | transitions = self.buffer.sample(self.args.batch_size) 331 | g_train = transitions['g'] 332 | g_train_tensor = torch.tensor(g_train / 170, dtype=torch.float32) 333 | if self.args.cuda: 334 | g_train_tensor = g_train_tensor.cuda() 335 | w_train = self.backward_network(g_train_tensor) 336 | elif self.args.w_sampling == 'uniform_ball': 337 | w_train = self.sample_uniform_ball(num_gpi) 338 | elif self.args.w_sampling == 'cauchy_ball': 339 | w_train = self.sample_cauchy_ball(num_gpi) 340 | 341 | for _ in range(self.env_params['max_timesteps']): 342 | with torch.no_grad(): 343 | g_tensor = self._preproc_g(g) 344 | w = self.backward_network(g_tensor) 345 | obs_tensor = self._preproc_o(obs) 346 | action = self.act_gpi_e_greedy(obs_tensor, w_train, w, update_eps=0.02) 347 | observation_new, _, done, info = self.env.step(action) 348 | obs = observation_new['observation'] 349 | g = observation_new['desired_goal'] 350 | dist = goal_distance(observation_new['achieved_goal'], observation_new['desired_goal']) 351 | if info['is_success'] > 0 or done: 352 | break 353 | 354 | total_success_rate.append(info['is_success']) 355 | total_dist.append(dist) 356 | total_success_rate = np.array(total_success_rate) 357 | success_rate = np.mean(total_success_rate) 358 | total_dist = np.array(total_dist) 359 | dist = np.mean(total_dist) 360 | return success_rate, dist 361 | -------------------------------------------------------------------------------- /grid_modules/fb_agent.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from datetime import datetime 4 | import numpy as np 5 | import random 6 | import pickle 7 | import csv 8 | from grid_modules.replay_buffer import ReplayBuffer 9 | from grid_modules.mdp_utils 
import extract_policy, value_iteration, compute_successor_reps 10 | from discrete_action_robots_modules.models import ForwardMap, BackwardMap 11 | # from grid_modules.models import ForwardMap, BackwardMap 12 | 13 | from torch.distributions.cauchy import Cauchy 14 | 15 | """ 16 | FB agent 17 | 18 | """ 19 | 20 | def compute_entropy(policy): 21 | p_log_p = torch.log(policy) * policy 22 | return -p_log_p.sum(-1) 23 | 24 | 25 | def nanmean(v, *args, inplace=False, **kwargs): 26 | if not inplace: 27 | v = v.clone() 28 | is_nan = torch.isnan(v) 29 | v[is_nan] = 1e-10 30 | return v.mean(*args, **kwargs) 31 | 32 | 33 | class FBAgent: 34 | def __init__(self, args, env, env_params): 35 | self.args = args 36 | self.env = env 37 | self.env_params = env_params 38 | self.cauchy = Cauchy(torch.tensor([0.0]), torch.tensor([0.5])) 39 | # create the network 40 | self.forward_network = ForwardMap(env_params, args.embed_dim) 41 | self.backward_network = BackwardMap(env_params, args.embed_dim) 42 | # build up the target network 43 | self.forward_target_network = ForwardMap(env_params, args.embed_dim) 44 | self.backward_target_network = BackwardMap(env_params, args.embed_dim) 45 | # load the weights into the target networks 46 | self.forward_target_network.load_state_dict(self.forward_network.state_dict()) 47 | self.backward_target_network.load_state_dict(self.backward_network.state_dict()) 48 | # if use gpu 49 | if self.args.cuda: 50 | self.forward_network.cuda() 51 | self.backward_network.cuda() 52 | self.forward_target_network.cuda() 53 | self.backward_target_network.cuda() 54 | # create the optimizer 55 | f_params = [param for param in self.forward_network.parameters()] 56 | b_params = [param for param in self.backward_network.parameters()] 57 | self.f_optim = torch.optim.Adam(f_params, lr=self.args.lr) 58 | self.b_optim = torch.optim.Adam(b_params, lr=self.args.lr) 59 | self.fb_optim = torch.optim.Adam(f_params + b_params, lr=self.args.lr) 60 | # self.backward_optim = torch.optim.Adam(self.backward_network.parameters(), lr=self.args.lr_backward) 61 | # her sampler 62 | 63 | # create the replay buffer 64 | self.buffer = ReplayBuffer(self.args.buffer_size) 65 | 66 | if args.save_dir is not None: 67 | if not os.path.exists(self.args.save_dir): 68 | os.mkdir(self.args.save_dir) 69 | 70 | print(' ' * 26 + 'Options') 71 | for k, v in vars(self.args).items(): 72 | print(' ' * 26 + k + ': ' + str(v)) 73 | 74 | with open(self.args.save_dir + "/arguments.pkl", 'wb') as f: 75 | pickle.dump(self.args, f) 76 | 77 | with open('{}/score_monitor.csv'.format(self.args.save_dir), "wt") as monitor_file: 78 | monitor = csv.writer(monitor_file) 79 | monitor.writerow(['epoch', 'eval', 'eval (GPI)', 'loss', 'entropy']) 80 | 81 | def learn(self): 82 | """ 83 | train the network 84 | 85 | """ 86 | best_perf = 0 87 | # start to collect samples 88 | for epoch in range(self.args.n_epochs): 89 | for _ in range(self.args.n_cycles): 90 | for _ in range(self.args.num_rollouts_per_cycle): 91 | # reset the rollouts 92 | # reset the environment 93 | obs = self.env.reset() 94 | g = self.env.goal 95 | if self.args.w_sampling == 'goal_oriented': 96 | g_tensor = self._preproc_g(g) 97 | with torch.no_grad(): 98 | w = self.backward_network(g_tensor) 99 | elif self.args.w_sampling == 'uniform_ball': 100 | w = self.sample_uniform_ball(1) 101 | elif self.args.w_sampling == 'cauchy_ball': 102 | w = self.sample_cauchy_ball(1) 103 | # start to collect samples 104 | for t in range(self.env_params['max_timesteps']): 105 | with torch.no_grad(): 106 | 
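# Actions during data collection are chosen epsilon-greedily with respect to
# z(s, a), the dot product of the forward embedding F(s, a, w) with w, where w is
# either B(goal) or a vector sampled from a ball, as selected above.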
obs_tensor = self._preproc_o(obs) 107 | action = self.act_e_greedy(obs_tensor, w, update_eps=self.args.update_eps) 108 | # feed the actions into the environment 109 | obs_new, reward, done, info = self.env.step(action) 110 | # add transition 111 | self.buffer.add(obs, g, action, reward, obs_new, done) 112 | if done: 113 | obs = self.env.reset() 114 | g = self.env.goal 115 | else: 116 | obs = obs_new 117 | for _ in range(self.args.n_batches): 118 | # train the network 119 | fb_loss, entropy = self._update_network() 120 | # soft update 121 | self._soft_update_target_network(self.forward_target_network, self.forward_network) 122 | self._soft_update_target_network(self.backward_target_network, self.backward_network) 123 | # self._hard_update_target_network(self.forward_target_network, self.forward_network) 124 | # self._hard_update_target_network(self.backward_target_network, self.backward_network) 125 | # start to do the evaluation 126 | perf, gpi_perf = self._eval_agent(num_gpi=self.args.num_gpi) 127 | 128 | print('[{}] epoch is: {}, eval: {:.3f}, ' 129 | 'eval (GPI): {:.3f}, loss: {:.3f}, entropy: {:.3f}'.format(datetime.now(), epoch, perf, gpi_perf, fb_loss, entropy)) 130 | with open('{}/score_monitor.csv'.format(self.args.save_dir), "a") as monitor_file: 131 | monitor = csv.writer(monitor_file) 132 | monitor.writerow([epoch, perf, gpi_perf, fb_loss, entropy]) 133 | torch.save([self.forward_network.state_dict(), self.backward_network.state_dict()], 134 | os.path.join(self.args.save_dir, 'model.pt')) 135 | if perf > best_perf: 136 | torch.save([self.forward_network.state_dict(), self.backward_network.state_dict()], 137 | os.path.join(self.args.save_dir, 'best_model.pt')) 138 | best_perf = perf 139 | 140 | def sample_uniform_ball(self, n, eps=1e-10): 141 | gaussian_rdv = torch.FloatTensor(n, self.args.embed_dim).normal_(mean=0, std=1) 142 | gaussian_rdv /= torch.norm(gaussian_rdv, dim=-1, keepdim=True) + eps 143 | uniform_rdv = torch.FloatTensor(n, 1).uniform_() 144 | w = np.sqrt(self.args.embed_dim) * gaussian_rdv * uniform_rdv 145 | # w = gaussian_rdv * uniform_rdv 146 | # w = w.repeat(n, 1) 147 | if self.args.cuda: 148 | w = w.cuda() 149 | return w 150 | 151 | def sample_cauchy_ball(self, n, eps=1e-10): 152 | gaussian_rdv = torch.FloatTensor(n, self.args.embed_dim).normal_(mean=0, std=1) 153 | gaussian_rdv /= torch.norm(gaussian_rdv, dim=-1, keepdim=True) + eps 154 | cauchy_rdv = self.cauchy.sample((n, )) 155 | w = np.sqrt(self.args.embed_dim) * gaussian_rdv * cauchy_rdv 156 | # w = gaussian_rdv * uniform_rdv 157 | # w = w.repeat(n, 1) 158 | if self.args.cuda: 159 | w = w.cuda() 160 | return w 161 | 162 | # pre-process the inputs 163 | def _preproc_o(self, obs): 164 | obs_tensor = torch.tensor(obs, dtype=torch.float32).unsqueeze(0) 165 | if self.args.cuda: 166 | obs_tensor = obs_tensor.cuda() 167 | return obs_tensor 168 | 169 | def _preproc_g(self, g): 170 | g_tensor = torch.tensor(g, dtype=torch.float32).unsqueeze(0) 171 | if self.args.cuda: 172 | g_tensor = g_tensor.cuda() 173 | return g_tensor 174 | 175 | def get_policy(self, w, obs=None, policy_type='boltzmann', temp=1, eps=0.01, target_network=False): 176 | if obs is None: 177 | obs = torch.eye(self.env.state_space) # S x S 178 | w = w.repeat(self.env.state_space, 1) 179 | if self.args.cuda: 180 | obs = obs.cuda() # S x S 181 | if target_network: 182 | f = self.forward_target_network(obs, w) 183 | else: 184 | f = self.forward_network(obs, w) 185 | z = torch.einsum('sda, sd -> sa', f, w) 186 | return extract_policy(z, policy_type=policy_type,
temp=temp, eps=eps) 187 | 188 | def get_gpi_policy(self, w_train, w_eval, obs=None, policy_type='boltzmann', temp=0.1, eps=0.01): 189 | if obs is None: 190 | obs = torch.eye(self.env.state_space) # S x S 191 | if self.args.cuda: 192 | obs = obs.cuda() # S x S 193 | num_gpi = w_train.shape[0] 194 | obs_repeat = obs.repeat(1, num_gpi).reshape(num_gpi * self.env.state_space, -1) 195 | w_eval_repeat = w_eval.repeat(num_gpi * self.env.state_space, 1) 196 | w_train_repeat = w_train.repeat(self.env.state_space, 1) 197 | f = self.forward_network(obs_repeat, w_train_repeat) 198 | z = torch.einsum('sda, sd -> sa', f, w_eval_repeat).reshape(self.env.state_space, 199 | num_gpi, 200 | self.env.action_space) 201 | z = z.max(1)[0] 202 | return extract_policy(z, policy_type=policy_type, temp=temp, eps=eps) 203 | 204 | def act_gpi(self, obs, w_train, w_eval): 205 | # import pdb 206 | # pdb.set_trace() 207 | num_gpi = w_train.shape[0] 208 | obs_repeat = obs.repeat(num_gpi, 1) 209 | w_eval_repeat = w_eval.repeat(num_gpi, 1) 210 | f = self.forward_network(obs_repeat, w_train) 211 | z = torch.einsum('sda, sd -> sa', f, w_eval_repeat).max(0)[0] 212 | return z.max(0)[1] 213 | 214 | # Acts based on single state (no batch) 215 | def act(self, obs, w, target_network=False): 216 | if target_network: 217 | f = self.forward_target_network(obs, w) 218 | else: 219 | f = self.forward_network(obs, w) 220 | z = torch.einsum('sda, sd -> sa', f, w) 221 | # import pdb 222 | # pdb.set_trace() 223 | y = z.max(1)[1] 224 | return y 225 | 226 | # Acts with an epsilon-greedy policy 227 | def act_e_greedy(self, obs, g, update_eps=0.2): 228 | return random.randrange(self.env_params['action']) if random.random() < update_eps else self.act(obs, g).item() 229 | 230 | # soft update 231 | def _soft_update_target_network(self, target, source): 232 | for target_param, param in zip(target.parameters(), source.parameters()): 233 | target_param.data.copy_((1 - self.args.polyak) * param.data + self.args.polyak * target_param.data) 234 | 235 | def _hard_update_target_network(self, target, source): 236 | for target_param, param in zip(target.parameters(), source.parameters()): 237 | target_param.data.copy_(param.data) 238 | 239 | # update the network 240 | def _update_network(self): 241 | # sample the episodes 242 | transitions = self.buffer.sample(self.args.batch_size) 243 | other_transitions = self.buffer.sample(self.args.batch_size) 244 | 245 | # transfer them into the tensor 246 | obs_tensor = torch.tensor(transitions['obs'], dtype=torch.float32) 247 | g_tensor = torch.tensor(transitions['g'], dtype=torch.float32) 248 | obs_next_tensor = torch.tensor(transitions['obs_next'], dtype=torch.float32) 249 | actions_tensor = torch.tensor(transitions['action'], dtype=torch.long) 250 | obs_other_tensor = torch.tensor(other_transitions['obs'], dtype=torch.float32) 251 | actions_other_tensor = torch.tensor(other_transitions['action'], dtype=torch.long) 252 | if self.args.cuda: 253 | obs_tensor = obs_tensor.cuda() 254 | g_tensor = g_tensor.cuda() 255 | obs_next_tensor = obs_next_tensor.cuda() 256 | actions_tensor = actions_tensor.cuda() 257 | obs_other_tensor = obs_other_tensor.cuda() 258 | actions_other_tensor = actions_other_tensor.cuda() 259 | 260 | if self.args.w_sampling == 'goal_oriented': 261 | with torch.no_grad(): 262 | w = self.backward_network(g_tensor) 263 | w = w.detach() 264 | elif self.args.w_sampling == 'uniform_ball': 265 | w = self.sample_uniform_ball(self.args.batch_size) 266 | elif self.args.w_sampling == 'cauchy_ball': 267 | w = 
self.sample_cauchy_ball(self.args.batch_size) 268 | 269 | # calculate the target Q value function 270 | with torch.no_grad(): 271 | # import pdb 272 | # pdb.set_trace() 273 | # actions_next_tensor = self.act(obs_next_tensor, w, target_network=True) 274 | # next_idxs = actions_next_tensor[:, None].repeat(1, self.args.embed_dim)[:, :, None] 275 | # f_next = self.forward_target_network(obs_next_tensor, w).gather(-1, next_idxs).squeeze() # batch x dim 276 | pi = self.get_policy(w, obs=obs_next_tensor, policy_type='boltzmann', temp=self.args.temp, target_network=True) 277 | entropy = nanmean(compute_entropy(pi)) 278 | f_next = torch.einsum('sda, sa -> sd', self.forward_target_network(obs_next_tensor, w), pi) 279 | b_next = self.backward_target_network(obs_other_tensor) # batch x dim 280 | # idxs_other = actions_other_tensor[:, None].repeat(1, self.args.embed_dim)[:, :, None] 281 | # b_next = self.backward_target_network(obs_other_tensor).gather(-1, idxs_other).squeeze() # batch x dim 282 | z_next = torch.einsum('sd, td -> st', f_next, b_next) # batch x batch 283 | z_next = z_next.detach() 284 | 285 | # the forward loss 286 | idxs = actions_tensor[:, None].repeat(1, self.args.embed_dim)[:, :, None] 287 | f = self.forward_network(obs_tensor, w).gather(-1, idxs).squeeze() 288 | b = self.backward_network(obs_tensor) 289 | b_other = self.backward_network(obs_other_tensor) 290 | # b = self.backward_network(obs_tensor).gather(-1, idxs).squeeze() 291 | # b_other = self.backward_network(obs_other_tensor).gather(-1, idxs_other).squeeze() 292 | z_diag = torch.einsum('sd, sd -> s', f, b) # batch 293 | z = torch.einsum('sd, td -> st', f, b_other) # batch x batch 294 | fb_loss = 0.5 * (z - self.args.gamma * z_next).pow(2).mean() - z_diag.mean() 295 | # compute orthonormality's regularisation loss 296 | b_b_other = torch.einsum('sd, xd -> sx', b, b_other) # batch x batch 297 | b_b_other_detach = torch.einsum('sd, xd -> sx', b, b_other.detach()) # batch x batch 298 | b_b_detach = torch.einsum('sd, sd -> s', b, b.detach()) # batch 299 | reg_loss = (b_b_detach * b_b_other.detach()).mean() - b_b_other_detach.mean() 300 | fb_loss += self.args.reg_coef * reg_loss 301 | 302 | # update the forward_network 303 | self.fb_optim.zero_grad() 304 | fb_loss.backward() 305 | # clip_grad_norm_(self.forward_network.parameters(), 5) 306 | self.fb_optim.step() 307 | 308 | return fb_loss.item(), entropy.item() 309 | 310 | # the backward loss 311 | # f = self.forward_network(obs_tensor, w).gather(-1, idxs).squeeze() 312 | # f = f.detach() 313 | # b = self.backward_network(obs_tensor) 314 | # b_other = self.backward_network(obs_other_tensor) 315 | # z_diag = torch.einsum('sd, sd -> s', f, b) # batch 316 | # z = torch.einsum('sd, td -> st', f, b_other) # batch x batch 317 | # b_loss = 0.5 * (z - self.args.gamma * z_next).pow(2).mean() - z_diag.mean() 318 | # # compute orthonormality's regularisation loss 319 | # b_b_other = torch.einsum('sd, xd -> sx', b, b_other) # batch x batch 320 | # b_b_other_detach = torch.einsum('sd, xd -> sx', b, b_other.detach()) # batch x batch 321 | # b_b_detach = torch.einsum('sd, sd -> s', b, b.detach()) # batch 322 | # reg_loss = (b_b_detach * b_b_other.detach()).mean() - b_b_other_detach.mean() 323 | # b_loss += self.args.reg_coef * reg_loss 324 | # 325 | # # update the backward_network 326 | # self.b_optim.zero_grad() 327 | # b_loss.backward() 328 | # clip_grad_norm_(self.backward_network.parameters(), 5) 329 | # self.b_optim.step() 330 | 331 | # do the evaluation 332 | def _eval_agent(self, 
num_gpi=20): 333 | total_perf = [] 334 | total_gpi_perf = [] 335 | for _ in range(self.args.n_test_rollouts): 336 | init_obs = self.env.reset() 337 | g = self.env.goal 338 | R = torch.tensor(self.env.R, dtype=torch.float32) 339 | P = torch.tensor(self.env.P, dtype=torch.float32) 340 | if self.args.cuda: 341 | R = R.cuda() 342 | P = P.cuda() 343 | opt_q = value_iteration(R, P, self.args.gamma, atol=1e-8, max_iteration=5000) 344 | opt_perf = opt_q[self.env.reachable_states].max(1)[0].mean() 345 | 346 | g_tensor = self._preproc_g(g) 347 | w = self.backward_network(g_tensor) 348 | pi = self.get_policy(w, policy_type='boltzmann', temp=1) 349 | sr_pi = compute_successor_reps(P, pi, self.args.gamma) 350 | q_pi = torch.matmul(sr_pi, R.t().reshape(self.env.state_space * self.env.action_space)) 351 | q_pi = q_pi.reshape(self.env.action_space, self.env.state_space).t() 352 | 353 | # score = torch.dot(q_pi[init_obs.argmax()], pi[init_obs.argmax()]) 354 | score = torch.einsum('sa, sa -> s', q_pi, pi)[self.env.reachable_states].mean() 355 | score /= opt_perf 356 | total_perf.append(score.item()) 357 | 358 | # with GPI 359 | 360 | if self.args.w_sampling == 'goal_oriented': 361 | transitions = self.buffer.sample(num_gpi) 362 | g_train = transitions['g'] 363 | g_train_tensor = torch.tensor(g_train, dtype=torch.float32) 364 | if self.args.cuda: 365 | g_train_tensor = g_train_tensor.cuda() 366 | w_train = self.backward_network(g_train_tensor) 367 | elif self.args.w_sampling == 'uniform_ball': 368 | w_train = self.sample_uniform_ball(num_gpi) 369 | elif self.args.w_sampling == 'cauchy_ball': 370 | w_train = w + self.sample_cauchy_ball(num_gpi) / np.sqrt(self.args.embed_dim) 371 | 372 | gpi_pi = self.get_gpi_policy(w_train, w, policy_type='boltzmann', temp=1) 373 | sr_gpi_pi = compute_successor_reps(P, gpi_pi, self.args.gamma) 374 | q_gpi_pi = torch.matmul(sr_gpi_pi, R.t().reshape(self.env.state_space * self.env.action_space)) 375 | q_gpi_pi = q_gpi_pi.reshape(self.env.action_space, self.env.state_space).t() 376 | 377 | # gpi_score = torch.dot(q_gpi_pi[init_obs.argmax()], gpi_pi[init_obs.argmax()]) 378 | gpi_score = torch.einsum('sa, sa -> s', q_gpi_pi, gpi_pi)[self.env.reachable_states].mean() 379 | gpi_score /= opt_perf 380 | total_gpi_perf.append(gpi_score.item()) 381 | 382 | total_perf = np.array(total_perf) 383 | total_gpi_perf = np.array(total_gpi_perf) 384 | return np.mean(total_perf), np.mean(total_gpi_perf) --------------------------------------------------------------------------------
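Note on the action-selection path above: both FB agents turn a goal into a task vector w = B(g) and rank actions by F(s, ., w)^T w (see act, act_e_greedy and the evaluation loops). The snippet below is a minimal, self-contained sketch of that scoring; the linear stand-ins for ForwardMap/BackwardMap and the concatenated (obs, w) input are illustrative assumptions, not the repository's models.

import torch

# toy stand-ins for the ForwardMap / BackwardMap networks (assumptions for illustration)
embed_dim, num_actions, obs_dim, goal_dim = 16, 5, 10, 3
backward_map = torch.nn.Linear(goal_dim, embed_dim)
forward_map = torch.nn.Linear(obs_dim + embed_dim, embed_dim * num_actions)

def q_for_goal(obs, goal):
    # Q(s, a; g) = F(s, a, w)^T w with w = B(g), as in FBAgent.act / the eval loops
    w = backward_map(goal)                          # 1 x embed_dim
    f = forward_map(torch.cat([obs, w], dim=-1))    # 1 x (embed_dim * num_actions)
    f = f.reshape(-1, embed_dim, num_actions)       # 1 x embed_dim x num_actions
    return torch.einsum('sda, sd -> sa', f, w)      # 1 x num_actions

with torch.no_grad():
    obs = torch.randn(1, obs_dim)
    goal = torch.randn(1, goal_dim)
    q = q_for_goal(obs, goal)
action = q.argmax(dim=-1).item()                    # greedy action toward the goal
print(action)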
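sample_uniform_ball draws a direction uniformly on the unit sphere and a radius uniformly in [0, 1], then rescales by sqrt(embed_dim). A standalone sketch follows; the volume_uniform flag is an added illustration (not in the repository) of the u ** (1/d) radius that would make samples uniform over the ball's volume.

import torch

def sample_w(n, embed_dim, eps=1e-10, volume_uniform=False):
    # direction: uniform on the unit sphere
    direction = torch.randn(n, embed_dim)
    direction = direction / (direction.norm(dim=-1, keepdim=True) + eps)
    # radius: sample_uniform_ball draws it uniformly in [0, 1];
    # u ** (1 / d) would instead make the samples uniform over the ball's volume
    u = torch.rand(n, 1)
    radius = u.pow(1.0 / embed_dim) if volume_uniform else u
    return (embed_dim ** 0.5) * direction * radius

w = sample_w(4, 16)
print(w.shape, w.norm(dim=-1))   # norms fall in [0, sqrt(16)] = [0, 4]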
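act_gpi and get_gpi_policy implement generalised policy improvement: evaluate F(s, a, w_i)^T w_eval for every training task vector w_i, take the best value per action, then act greedily. A self-contained sketch, with a toy random projection standing in for the ForwardMap network (an assumption for illustration only):

import torch

embed_dim, num_actions, obs_dim = 16, 5, 10

def toy_forward_map(obs, w):
    # deterministic random projection standing in for the ForwardMap network
    torch.manual_seed(0)
    proj = torch.randn(obs_dim + embed_dim, embed_dim * num_actions)
    return (torch.cat([obs, w], dim=-1) @ proj).reshape(-1, embed_dim, num_actions)

def gpi_action(forward_map, obs, w_train, w_eval):
    # argmax_a max_i F(obs, a, w_i)^T w_eval, mirroring FBAgent.act_gpi
    num_gpi = w_train.shape[0]
    obs_repeat = obs.repeat(num_gpi, 1)                    # num_gpi x obs_dim
    w_eval_repeat = w_eval.repeat(num_gpi, 1)              # num_gpi x embed_dim
    f = forward_map(obs_repeat, w_train)                   # num_gpi x embed_dim x num_actions
    q = torch.einsum('sda, sd -> sa', f, w_eval_repeat)    # num_gpi x num_actions
    return q.max(0)[0].argmax().item()                     # best policy per action, then best action

obs = torch.randn(1, obs_dim)
w_train = torch.randn(7, embed_dim)       # task vectors seen during training
w_eval = torch.randn(1, embed_dim)        # task vector of the goal being evaluated
print(gpi_action(toy_forward_map, obs, w_train, w_eval))   # an int in [0, num_actions)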
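The objective assembled in _update_network can be exercised in isolation. The sketch below mirrors its forward-backward TD loss and orthonormality regulariser on random tensors of the same shapes; batch, embed_dim, gamma and reg_coef are illustrative values, not the repository's defaults.

import torch

batch, embed_dim, gamma, reg_coef = 8, 16, 0.99, 1.0        # illustrative values

# stand-ins for network outputs (random here; produced by ForwardMap/BackwardMap in the agent)
f = torch.randn(batch, embed_dim, requires_grad=True)        # F(s_t, a_t, w), taken action
b = torch.randn(batch, embed_dim, requires_grad=True)        # B(s_t), online backward net
b_other = torch.randn(batch, embed_dim, requires_grad=True)  # B(s~), independent batch, online net
f_next = torch.randn(batch, embed_dim)                       # E_pi[F(s_{t+1}, ., w)], target net
b_next = torch.randn(batch, embed_dim)                       # B(s~), target net

with torch.no_grad():
    z_next = torch.einsum('sd, td -> st', f_next, b_next)    # bootstrapped measure, batch x batch

z_diag = torch.einsum('sd, sd -> s', f, b)                   # on-diagonal (matched) term
z = torch.einsum('sd, td -> st', f, b_other)                 # off-diagonal term
fb_loss = 0.5 * (z - gamma * z_next).pow(2).mean() - z_diag.mean()

# orthonormality regulariser on B, with the same stop-gradient pattern as in the agent
b_b_other = torch.einsum('sd, xd -> sx', b, b_other)
b_b_other_detach = torch.einsum('sd, xd -> sx', b, b_other.detach())
b_b_detach = torch.einsum('sd, sd -> s', b, b.detach())
reg_loss = (b_b_detach * b_b_other.detach()).mean() - b_b_other_detach.mean()

(fb_loss + reg_coef * reg_loss).backward()                   # gradients reach f, b and b_other
print(float(fb_loss), float(reg_loss))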
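One helper detail worth flagging: nanmean at the top of grid_modules/fb_agent.py does not exclude NaN entries the way numpy.nanmean does, it clamps them to 1e-10 before averaging. This matters for the entropy logged by learn() whenever a policy row contains exact zeros, as the short example below shows.

import torch

def compute_entropy(policy):
    # same formula as in fb_agent.py: H(pi(.|s)) = -sum_a pi(a|s) log pi(a|s)
    p_log_p = torch.log(policy) * policy
    return -p_log_p.sum(-1)

def nanmean(v):
    # the agent's nanmean clamps NaN entries to ~0 before averaging (it does not drop them)
    v = v.clone()
    v[torch.isnan(v)] = 1e-10
    return v.mean()

pi = torch.tensor([[0.25, 0.25, 0.25, 0.25],   # uniform: entropy log(4)
                   [1.00, 0.00, 0.00, 0.00]])  # greedy: 0 * log(0) -> NaN in torch
print(compute_entropy(pi))             # tensor([1.3863, nan])
print(nanmean(compute_entropy(pi)))    # NaN row counted as ~0, not excluded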