├── requirements.txt ├── changelog.md ├── cartpole_demo.gif ├── TML_Presesntation.pdf ├── TOML_Final_Report.pdf ├── reward_vs_steps_k1.png ├── reward_vs_steps_k10.png ├── normalized_env.py ├── random_process.py ├── model.py ├── evaluator.py ├── util.py ├── README.md ├── action_space.py ├── ddpg.py ├── ContinuousCartPole.py ├── wolp.py └── memory.py /requirements.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /changelog.md: -------------------------------------------------------------------------------- 1 | # Changes 2 | 1. Added Presentation and report to the repo. 3 | 2. Updated README 4 | -------------------------------------------------------------------------------- /cartpole_demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nikhil3456/Deep-Reinforcement-Learning-in-Large-Discrete-Action-Spaces/HEAD/cartpole_demo.gif -------------------------------------------------------------------------------- /TML_Presesntation.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nikhil3456/Deep-Reinforcement-Learning-in-Large-Discrete-Action-Spaces/HEAD/TML_Presesntation.pdf -------------------------------------------------------------------------------- /TOML_Final_Report.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nikhil3456/Deep-Reinforcement-Learning-in-Large-Discrete-Action-Spaces/HEAD/TOML_Final_Report.pdf -------------------------------------------------------------------------------- /reward_vs_steps_k1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nikhil3456/Deep-Reinforcement-Learning-in-Large-Discrete-Action-Spaces/HEAD/reward_vs_steps_k1.png -------------------------------------------------------------------------------- /reward_vs_steps_k10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nikhil3456/Deep-Reinforcement-Learning-in-Large-Discrete-Action-Spaces/HEAD/reward_vs_steps_k10.png -------------------------------------------------------------------------------- /normalized_env.py: -------------------------------------------------------------------------------- 1 | import gym 2 | 3 | # https://github.com/openai/gym/blob/master/gym/core.py 4 | class NormalizedEnv(gym.ActionWrapper): 5 | """ Wrap action """ 6 | 7 | def action(self, action): 8 | act_k = (self.action_space.high - self.action_space.low)/ 2. 9 | act_b = (self.action_space.high + self.action_space.low)/ 2. 10 | return act_k * action + act_b 11 | 12 | def reverse_action(self, action): 13 | act_k_inv = 2./(self.action_space.high - self.action_space.low) 14 | act_b = (self.action_space.high + self.action_space.low)/ 2. 
15 | return act_k_inv * (action - act_b) 16 | -------------------------------------------------------------------------------- /random_process.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | # [reference] https://github.com/matthiasplappert/keras-rl/blob/master/rl/random.py 4 | 5 | class RandomProcess(object): 6 | def reset_states(self): 7 | pass 8 | 9 | class AnnealedGaussianProcess(RandomProcess): 10 | def __init__(self, mu, sigma, sigma_min, n_steps_annealing): 11 | self.mu = mu 12 | self.sigma = sigma 13 | self.n_steps = 0 14 | 15 | if sigma_min is not None: 16 | self.m = -float(sigma - sigma_min) / float(n_steps_annealing) 17 | self.c = sigma 18 | self.sigma_min = sigma_min 19 | else: 20 | self.m = 0. 21 | self.c = sigma 22 | self.sigma_min = sigma 23 | 24 | @property 25 | def current_sigma(self): 26 | sigma = max(self.sigma_min, self.m * float(self.n_steps) + self.c) 27 | return sigma 28 | 29 | 30 | # Based on http://math.stackexchange.com/questions/1287634/implementing-ornstein-uhlenbeck-in-matlab 31 | class OrnsteinUhlenbeckProcess(AnnealedGaussianProcess): 32 | def __init__(self, theta, mu=0., sigma=1., dt=1e-2, x0=None, size=1, sigma_min=None, n_steps_annealing=1000): 33 | super(OrnsteinUhlenbeckProcess, self).__init__(mu=mu, sigma=sigma, sigma_min=sigma_min, n_steps_annealing=n_steps_annealing) 34 | self.theta = theta 35 | self.mu = mu 36 | self.dt = dt 37 | self.x0 = x0 38 | self.size = size 39 | self.reset_states() 40 | 41 | def sample(self): 42 | x = self.x_prev + self.theta * (self.mu - self.x_prev) * self.dt + self.current_sigma * np.sqrt(self.dt) * np.random.normal(size=self.size) 43 | self.x_prev = x 44 | self.n_steps += 1 45 | return x 46 | 47 | def reset_states(self): 48 | self.x_prev = self.x0 if self.x0 is not None else np.zeros(self.size) 49 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | #from ipdb import set_trace as debug 9 | 10 | def fanin_init(size, fanin=None): 11 | fanin = fanin or size[0] 12 | v = 1. 
/ np.sqrt(fanin) 13 | return torch.Tensor(size).uniform_(-v, v) 14 | 15 | class Actor(nn.Module): 16 | def __init__(self, nb_states, nb_actions, hidden1=400, hidden2=300, init_w=3e-3): 17 | super(Actor, self).__init__() 18 | self.fc1 = nn.Linear(nb_states, hidden1) 19 | self.fc2 = nn.Linear(hidden1, hidden2) 20 | self.fc3 = nn.Linear(hidden2, nb_actions) 21 | self.relu = nn.ReLU() 22 | self.tanh = nn.Tanh() 23 | self.init_weights(init_w) 24 | 25 | def init_weights(self, init_w): 26 | self.fc1.weight.data = fanin_init(self.fc1.weight.data.size()) 27 | self.fc2.weight.data = fanin_init(self.fc2.weight.data.size()) 28 | self.fc3.weight.data.uniform_(-init_w, init_w) 29 | 30 | def forward(self, x): 31 | out = self.fc1(x) 32 | out = self.relu(out) 33 | out = self.fc2(out) 34 | out = self.relu(out) 35 | out = self.fc3(out) 36 | out = self.tanh(out) 37 | return out 38 | 39 | class Critic(nn.Module): 40 | def __init__(self, nb_states, nb_actions, hidden1=400, hidden2=300, init_w=3e-3): 41 | super(Critic, self).__init__() 42 | self.fc1 = nn.Linear(nb_states, hidden1) 43 | self.fc2 = nn.Linear(hidden1+nb_actions, hidden2) 44 | self.fc3 = nn.Linear(hidden2, 1) 45 | self.relu = nn.ReLU() 46 | self.init_weights(init_w) 47 | 48 | def init_weights(self, init_w): 49 | self.fc1.weight.data = fanin_init(self.fc1.weight.data.size()) 50 | self.fc2.weight.data = fanin_init(self.fc2.weight.data.size()) 51 | self.fc3.weight.data.uniform_(-init_w, init_w) 52 | 53 | def forward(self, xs): 54 | x, a = xs 55 | out = self.fc1(x) 56 | out = self.relu(out) 57 | # debug() 58 | out = self.fc2(torch.cat([out,a],1)) 59 | out = self.relu(out) 60 | out = self.fc3(out) 61 | return out 62 | -------------------------------------------------------------------------------- /evaluator.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | from scipy.io import savemat 5 | 6 | from util import * 7 | 8 | class Evaluator(object): 9 | 10 | def __init__(self, num_episodes, interval, save_path='', max_episode_length=None): 11 | self.num_episodes = num_episodes 12 | self.max_episode_length = max_episode_length 13 | self.interval = interval 14 | self.save_path = save_path 15 | self.results = np.array([]).reshape(num_episodes,0) 16 | 17 | def __call__(self, env, policy, debug=False, visualize=False, save=True): 18 | 19 | self.is_training = False 20 | observation = None 21 | result = [] 22 | 23 | for episode in range(self.num_episodes): 24 | 25 | # reset at the start of episode 26 | observation = env.reset() 27 | episode_steps = 0 28 | episode_reward = 0. 29 | 30 | assert observation is not None 31 | 32 | # start episode 33 | done = False 34 | while not done: 35 | # basic operation, action ,reward, blablabla ... 
36 | action = policy(observation) 37 | 38 | observation, reward, done, info = env.step(action) 39 | if self.max_episode_length and episode_steps >= self.max_episode_length -1: 40 | done = True 41 | 42 | if visualize: 43 | env.render(mode='human') 44 | 45 | # update 46 | episode_reward += reward 47 | episode_steps += 1 48 | 49 | if debug: prYellow('[Evaluate] #Episode{}: episode_reward:{}'.format(episode,episode_reward)) 50 | result.append(episode_reward) 51 | 52 | result = np.array(result).reshape(-1,1) 53 | self.results = np.hstack([self.results, result]) 54 | 55 | if save: 56 | self.save_results('{}/validate_reward'.format(self.save_path)) 57 | return np.mean(result) 58 | 59 | def save_results(self, fn): 60 | 61 | y = np.mean(self.results, axis=0) 62 | error=np.std(self.results, axis=0) 63 | 64 | x = range(0,self.results.shape[1]*self.interval,self.interval) 65 | fig, ax = plt.subplots(1, 1, figsize=(6, 5)) 66 | plt.xlabel('Timestep') 67 | plt.ylabel('Average Reward') 68 | ax.errorbar(x, y, yerr=error, fmt='-o') 69 | plt.savefig(fn+'.png') 70 | savemat(fn+'.mat', {'reward':self.results}) 71 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import torch 4 | from torch.autograd import Variable 5 | 6 | USE_CUDA = torch.cuda.is_available() 7 | FLOAT = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor 8 | 9 | def prRed(prt): print("\033[91m {}\033[00m" .format(prt)) 10 | def prGreen(prt): print("\033[92m {}\033[00m" .format(prt)) 11 | def prYellow(prt): print("\033[93m {}\033[00m" .format(prt)) 12 | def prLightPurple(prt): print("\033[94m {}\033[00m" .format(prt)) 13 | def prPurple(prt): print("\033[95m {}\033[00m" .format(prt)) 14 | def prCyan(prt): print("\033[96m {}\033[00m" .format(prt)) 15 | def prLightGray(prt): print("\033[97m {}\033[00m" .format(prt)) 16 | def prBlack(prt): print("\033[98m {}\033[00m" .format(prt)) 17 | 18 | def to_numpy(var): 19 | return var.cpu().data.numpy() if USE_CUDA else var.data.numpy() 20 | 21 | def to_tensor(ndarray, volatile=False, requires_grad=False, dtype=FLOAT): 22 | return Variable( 23 | torch.from_numpy(ndarray), volatile=volatile, requires_grad=requires_grad 24 | ).type(dtype) 25 | 26 | def soft_update(target, source, tau): 27 | for target_param, param in zip(target.parameters(), source.parameters()): 28 | target_param.data.copy_( 29 | target_param.data * (1.0 - tau) + param.data * tau 30 | ) 31 | 32 | def hard_update(target, source): 33 | for target_param, param in zip(target.parameters(), source.parameters()): 34 | target_param.data.copy_(param.data) 35 | 36 | def get_output_folder(parent_dir, env_name): 37 | """Return save folder. 38 | Assumes folders in the parent_dir have suffix -run{run 39 | number}. Finds the highest run number and sets the output folder 40 | to that number + 1. This is just convenient so that if you run the 41 | same script multiple times tensorboard can plot all of the results 42 | on the same plots with different names. 43 | Parameters 44 | ---------- 45 | parent_dir: str 46 | Path of the directory containing all experiment runs. 47 | Returns 48 | ------- 49 | parent_dir/run_dir 50 | Path to this run's save directory. 
51 | """ 52 | os.makedirs(parent_dir, exist_ok=True) 53 | experiment_id = 0 54 | for folder_name in os.listdir(parent_dir): 55 | if not os.path.isdir(os.path.join(parent_dir, folder_name)): 56 | continue 57 | try: 58 | folder_name = int(folder_name.split('-run')[-1]) 59 | if folder_name > experiment_id: 60 | experiment_id = folder_name 61 | except: 62 | pass 63 | experiment_id += 1 64 | 65 | parent_dir = os.path.join(parent_dir, env_name) 66 | parent_dir = parent_dir + '-run{}'.format(experiment_id) 67 | os.makedirs(parent_dir, exist_ok=True) 68 | return parent_dir 69 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deep Reinforcement Learning in Large Discrete Action Spaces 2 | 3 | This is a PyTorch implementation of the [paper](https://arxiv.org/abs/1512.07679) "Deep Reinforcement Learning in Large Discrete Action Spaces" (Gabriel Dulac-Arnold, Richard Evans, Hado van Hasselt, Peter Sunehag, Timothy Lillicrap, Jonathan Hunt, Timothy Mann, Theophane Weber, Thomas Degris, Ben Coppin). 4 | 5 | ## Installation 6 | 7 | To install the relevant libraries, run the following command: 8 | 9 | ``` 10 | pip install -r requirements.txt 11 | ``` 12 | 13 | ## Demonstration of Model 14 | 15 | Demonstration video: 16 | 17 | ![cartpole_demo](./cartpole_demo.gif) 18 | 19 | Results for `k = 1` (K is the number of nearest neighbours) | Results for `k = 10` (K is the number of nearest neighbours) 20 | :-------------------------:|:-------------------------: 21 | ![reward_k_1.png](./reward_vs_steps_k1.png) | ![reward_k_1.png](./reward_vs_steps_k10.png) 22 | 23 | 24 | ## Train the agent 25 | 26 | To train the agent, simply run the `main.ipynb` file provided in the repository. The parameters can be updated by changing the values in the `Arguments` class. 27 | 28 | ## Test the agent 29 | After training the agent using the above code, run the following code to test it on the cartpole environment. 30 | 31 | ```python 32 | import gym 33 | from gym import wrappers 34 | env_to_wrap = ContinuousCartPoleEnv() 35 | env = wrappers.Monitor(env_to_wrap, './demo', force = True) 36 | env.reset() 37 | for i_episode in range(1): 38 | observation = env.reset() 39 | ep_reward = 0 40 | for t in range(500): 41 | env.render() 42 | action = agent.select_action(observation) 43 | observation, reward, done, info = env.step(action) 44 | ep_reward += reward 45 | if done: 46 | print("Episode finished after {} timesteps".format(t+1)) 47 | break 48 | print(ep_reward) 49 | env_to_wrap.close() 50 | env.close() 51 | ``` 52 | 53 | ## Acknowledgements 54 | 55 | - Our DDPG code is based on the excellent implementation provided by [ghliu/pytorch-ddpg](https://github.com/ghliu/pytorch-ddpg). 
53 | ## Acknowledgements
54 | 
55 | - Our DDPG code is based on the excellent implementation provided by [ghliu/pytorch-ddpg](https://github.com/ghliu/pytorch-ddpg).
56 | 
57 | - The WOLPERTINGER agent code and `action_space.py` are based on the excellent implementation of the paper provided by [jimkon/Deep-Reinforcement-Learning-in-Large-Discrete-Action-Spaces](https://github.com/jimkon/Deep-Reinforcement-Learning-in-Large-Discrete-Action-Spaces).
58 | 
59 | ## Reference
60 | If you are interested in this work and want to cite it, please use the following reference:
61 | 
62 | ```
63 | @article{DBLP:journals/corr/Dulac-ArnoldESC15,
64 |   author    = {Gabriel Dulac{-}Arnold and
65 |                Richard Evans and
66 |                Peter Sunehag and
67 |                Ben Coppin},
68 |   title     = {Reinforcement Learning in Large Discrete Action Spaces},
69 |   journal   = {CoRR},
70 |   volume    = {abs/1512.07679},
71 |   year      = {2015},
72 |   url       = {http://arxiv.org/abs/1512.07679},
73 |   archivePrefix = {arXiv},
74 |   eprint    = {1512.07679},
75 |   timestamp = {Mon, 13 Aug 2018 16:46:25 +0200},
76 |   biburl    = {https://dblp.org/rec/bib/journals/corr/Dulac-ArnoldESC15},
77 |   bibsource = {dblp computer science bibliography, https://dblp.org}
78 | }
79 | ```
80 | 
81 | ## Collaborators
82 | 
83 | 1. [Shashank Srikanth](https://github.com/talsperre)
84 | 2. [Nikhil Bansal](https://github.com/nikhil3456)
85 | 
--------------------------------------------------------------------------------
/action_space.py:
--------------------------------------------------------------------------------
 1 | import pyflann
 2 | from gym.spaces import Box
 3 | import numpy as np
 4 | import itertools
 5 | import matplotlib.pyplot as plt  # needed by plot_space() below
 6 | class Space:
 7 | 
 8 |     def __init__(self, low, high, points):
 9 | 
10 |         self._low = np.array(low)
11 |         self._high = np.array(high)
12 |         self._range = self._high - self._low
13 |         self._dimensions = len(low)
14 |         self.__space = init_uniform_space([0] * self._dimensions,
15 |                                           [1] * self._dimensions,
16 |                                           points)
17 |         # print("self.__space: {}, self.__space.shape: {}, self.__space.dtype: {}".format(self.__space, self.__space.shape, self.__space.dtype))
18 |         self._flann = pyflann.FLANN()
19 |         self.rebuild_flann()
20 | 
21 |     def rebuild_flann(self):
22 |         self._index = self._flann.build_index(self.__space, algorithm='kdtree')
23 |         # print("Index type: {}".format(type(self._index)))
24 | 
25 |     def search_point(self, point, k):
26 |         p_in = self.import_point(point).reshape(1, -1).astype('float64')
27 |         # print("p_in: {}, p_in.shape: {}, p_in.dtype: {}".format(p_in, p_in.shape, p_in.dtype))
28 |         search_res, _ = self._flann.nn_index(p_in, k)
29 |         knns = self.__space[search_res]
30 |         p_out = []
31 |         for p in knns:
32 |             p_out.append(self.export_point(p))
33 | 
34 |         if k == 1:
35 |             p_out = [p_out]
36 |         return np.array(p_out)
37 | 
38 |     def import_point(self, point):
39 |         return (point - self._low) / self._range
40 | 
41 |     def export_point(self, point):
42 |         return self._low + point * self._range
43 | 
44 |     def get_space(self):
45 |         return self.__space
46 | 
47 |     def shape(self):
48 |         return self.__space.shape
49 | 
50 |     def get_number_of_actions(self):
51 |         return self.shape()[0]
52 | 
53 |     def plot_space(self, additional_points=None):
54 | 
55 |         dims = self._dimensions
56 | 
57 |         if dims > 3:
58 |             print(
59 |                 'Cannot plot a {}-dimensional space. 
Max 3 dimensions'.format(dims)) 60 | return 61 | 62 | space = self.get_space() 63 | if additional_points is not None: 64 | for i in additional_points: 65 | space = np.append(space, additional_points, axis=0) 66 | 67 | if dims == 1: 68 | for x in space: 69 | plt.plot([x], [0], 'o') 70 | 71 | plt.show() 72 | elif dims == 2: 73 | for x, y in space: 74 | plt.plot([x], [y], 'o') 75 | 76 | plt.show() 77 | else: 78 | plot_3d_points(space) 79 | 80 | 81 | class Discrete_space(Space): 82 | """ 83 | Discrete action space with n actions (the integers in the range [0, n)) 84 | 0, 1, 2, ..., n-2, n-1 85 | """ 86 | 87 | def __init__(self, n): # n: the number of the discrete actions 88 | super().__init__([0], [n - 1], n) 89 | 90 | def export_point(self, point): 91 | return super().export_point(point).astype(int) 92 | 93 | 94 | def init_uniform_space(low, high, points): 95 | dims = len(low) 96 | points_in_each_axis = round(points**(1 / dims)) 97 | 98 | axis = [] 99 | for i in range(dims): 100 | axis.append(list(np.linspace(low[i], high[i], points_in_each_axis))) 101 | 102 | space = [] 103 | for _ in itertools.product(*axis): 104 | space.append(list(_)) 105 | 106 | return np.array(space) -------------------------------------------------------------------------------- /ddpg.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.optim import Adam 7 | 8 | from model import (Actor, Critic) 9 | from memory import SequentialMemory 10 | from random_process import OrnsteinUhlenbeckProcess 11 | from util import * 12 | 13 | # from ipdb import set_trace as debug 14 | 15 | criterion = nn.MSELoss() 16 | 17 | class DDPG(object): 18 | def __init__(self, nb_states, nb_actions, args): 19 | 20 | if args.seed > 0: 21 | self.seed(args.seed) 22 | 23 | self.nb_states = nb_states 24 | self.nb_actions= nb_actions 25 | 26 | # Create Actor and Critic Network 27 | net_cfg = { 28 | 'hidden1':args.hidden1, 29 | 'hidden2':args.hidden2, 30 | 'init_w':args.init_w 31 | } 32 | self.actor = Actor(self.nb_states, self.nb_actions, **net_cfg) 33 | self.actor_target = Actor(self.nb_states, self.nb_actions, **net_cfg) 34 | self.actor_optim = Adam(self.actor.parameters(), lr=args.prate) 35 | 36 | self.critic = Critic(self.nb_states, self.nb_actions, **net_cfg) 37 | self.critic_target = Critic(self.nb_states, self.nb_actions, **net_cfg) 38 | self.critic_optim = Adam(self.critic.parameters(), lr=args.rate) 39 | 40 | hard_update(self.actor_target, self.actor) # Make sure target is with the same weight 41 | hard_update(self.critic_target, self.critic) 42 | 43 | #Create replay buffer 44 | self.memory = SequentialMemory(limit=args.rmsize, window_length=args.window_length) 45 | self.random_process = OrnsteinUhlenbeckProcess(size=nb_actions, theta=args.ou_theta, mu=args.ou_mu, sigma=args.ou_sigma) 46 | 47 | # Hyper-parameters 48 | self.batch_size = args.bsize 49 | self.tau = args.tau 50 | self.discount = args.discount 51 | self.depsilon = 1.0 / args.epsilon 52 | 53 | # 54 | self.epsilon = 1.0 55 | self.s_t = None # Most recent state 56 | self.a_t = None # Most recent action 57 | self.is_training = True 58 | 59 | # 60 | if USE_CUDA: self.cuda() 61 | 62 | def update_policy(self): 63 | # Sample batch 64 | state_batch, action_batch, reward_batch, \ 65 | next_state_batch, terminal_batch = self.memory.sample_and_split(self.batch_size) 66 | 67 | # Prepare for the target q batch 68 | next_q_values = self.critic_target([ 69 | 
to_tensor(next_state_batch, volatile=True), 70 | self.actor_target(to_tensor(next_state_batch, volatile=True)), 71 | ]) 72 | next_q_values.volatile=False 73 | 74 | target_q_batch = to_tensor(reward_batch) + \ 75 | self.discount*to_tensor(terminal_batch.astype(np.float))*next_q_values 76 | 77 | # Critic update 78 | self.critic.zero_grad() 79 | 80 | q_batch = self.critic([ to_tensor(state_batch), to_tensor(action_batch) ]) 81 | 82 | value_loss = criterion(q_batch, target_q_batch) 83 | value_loss.backward() 84 | self.critic_optim.step() 85 | 86 | # Actor update 87 | self.actor.zero_grad() 88 | 89 | policy_loss = -self.critic([ 90 | to_tensor(state_batch), 91 | self.actor(to_tensor(state_batch)) 92 | ]) 93 | 94 | policy_loss = policy_loss.mean() 95 | policy_loss.backward() 96 | self.actor_optim.step() 97 | 98 | # Target update 99 | soft_update(self.actor_target, self.actor, self.tau) 100 | soft_update(self.critic_target, self.critic, self.tau) 101 | 102 | def eval(self): 103 | self.actor.eval() 104 | self.actor_target.eval() 105 | self.critic.eval() 106 | self.critic_target.eval() 107 | 108 | def cuda(self): 109 | self.actor.cuda() 110 | self.actor_target.cuda() 111 | self.critic.cuda() 112 | self.critic_target.cuda() 113 | 114 | def observe(self, r_t, s_t1, done): 115 | if self.is_training: 116 | self.memory.append(self.s_t, self.a_t, r_t, done) 117 | self.s_t = s_t1 118 | 119 | def random_action(self): 120 | action = np.random.uniform(-1.,1.,self.nb_actions) 121 | self.a_t = action 122 | return action 123 | 124 | def select_action(self, s_t, decay_epsilon=True): 125 | action = to_numpy( 126 | self.actor(to_tensor(np.array([s_t]))) 127 | ).squeeze(0) 128 | action += self.is_training*max(self.epsilon, 0)*self.random_process.sample() 129 | action = np.clip(action, -1., 1.) 130 | 131 | if decay_epsilon: 132 | self.epsilon -= self.depsilon 133 | 134 | self.a_t = action 135 | return action 136 | 137 | def reset(self, obs): 138 | self.s_t = obs 139 | self.random_process.reset_states() 140 | 141 | def load_weights(self, output): 142 | if output is None: return 143 | 144 | self.actor.load_state_dict( 145 | torch.load('{}/actor.pkl'.format(output)) 146 | ) 147 | 148 | self.critic.load_state_dict( 149 | torch.load('{}/critic.pkl'.format(output)) 150 | ) 151 | 152 | 153 | def save_model(self,output): 154 | torch.save( 155 | self.actor.state_dict(), 156 | '{}/actor.pkl'.format(output) 157 | ) 158 | torch.save( 159 | self.critic.state_dict(), 160 | '{}/critic.pkl'.format(output) 161 | ) 162 | 163 | def seed(self,s): 164 | torch.manual_seed(s) 165 | if USE_CUDA: 166 | torch.cuda.manual_seed(s) 167 | -------------------------------------------------------------------------------- /ContinuousCartPole.py: -------------------------------------------------------------------------------- 1 | ## code from https://gist.github.com/iandanforth/e3ffb67cf3623153e968f2afdfb01dc8#file-continuous_cartpole-py 2 | """ 3 | Classic cart-pole system implemented by Rich Sutton et al. 
4 | Copied from http://incompleteideas.net/sutton/book/code/pole.c 5 | permalink: https://perma.cc/C9ZM-652R 6 | 7 | Continuous version by Ian Danforth 8 | """ 9 | 10 | import math 11 | import gym 12 | from gym import spaces, logger 13 | from gym.utils import seeding 14 | import numpy as np 15 | 16 | 17 | class ContinuousCartPoleEnv(gym.Env): 18 | metadata = { 19 | 'render.modes': ['human', 'rgb_array'], 20 | 'video.frames_per_second': 50 21 | } 22 | 23 | def __init__(self): 24 | self.gravity = 9.8 25 | self.masscart = 1.0 26 | self.masspole = 0.1 27 | self.total_mass = (self.masspole + self.masscart) 28 | self.length = 0.5 # actually half the pole's length 29 | self.polemass_length = (self.masspole * self.length) 30 | self.force_mag = 30.0 31 | self.tau = 0.02 # seconds between state updates 32 | self.min_action = -1.0 33 | self.max_action = 1.0 34 | 35 | # Angle at which to fail the episode 36 | self.theta_threshold_radians = 12 * 2 * math.pi / 360 37 | self.x_threshold = 2.4 38 | 39 | # Angle limit set to 2 * theta_threshold_radians so failing observation 40 | # is still within bounds 41 | high = np.array([ 42 | self.x_threshold * 2, 43 | np.finfo(np.float32).max, 44 | self.theta_threshold_radians * 2, 45 | np.finfo(np.float32).max]) 46 | 47 | self.action_space = spaces.Box( 48 | low=self.min_action, 49 | high=self.max_action, 50 | shape=(1,) 51 | ) 52 | self.observation_space = spaces.Box(-high, high) 53 | 54 | self.seed() 55 | self.viewer = None 56 | self.state = None 57 | 58 | self.steps_beyond_done = None 59 | 60 | self.steps = 0 61 | self._max_episode_steps = 500 62 | 63 | def seed(self, seed=None): 64 | self.np_random, seed = seeding.np_random(seed) 65 | return [seed] 66 | 67 | def stepPhysics(self, force): 68 | x, x_dot, theta, theta_dot = self.state 69 | costheta = math.cos(theta) 70 | sintheta = math.sin(theta) 71 | temp = (force + self.polemass_length * theta_dot * theta_dot * sintheta) / self.total_mass 72 | thetaacc = (self.gravity * sintheta - costheta * temp) / \ 73 | (self.length * (4.0/3.0 - self.masspole * costheta * costheta / self.total_mass)) 74 | xacc = temp - self.polemass_length * thetaacc * costheta / self.total_mass 75 | x = x + self.tau * x_dot 76 | x_dot = x_dot + self.tau * xacc 77 | theta = theta + self.tau * theta_dot 78 | theta_dot = theta_dot + self.tau * thetaacc 79 | return (x, x_dot, theta, theta_dot) 80 | 81 | def step(self, action): 82 | # print(action) 83 | assert self.action_space.contains(action), \ 84 | "%r (%s) invalid" % (action, type(action)) 85 | # Cast action to float to strip np trappings 86 | force = self.force_mag * float(action) 87 | self.state = self.stepPhysics(force) 88 | x, x_dot, theta, theta_dot = self.state 89 | 90 | # self.steps += 1 91 | done = x < -self.x_threshold \ 92 | or x > self.x_threshold \ 93 | or theta < -self.theta_threshold_radians \ 94 | or theta > self.theta_threshold_radians \ 95 | # or self.steps >= self._max_episode_steps 96 | done = bool(done) 97 | 98 | if not done: 99 | reward = 1.0 100 | elif self.steps_beyond_done is None: 101 | # Pole just fell! 102 | self.steps_beyond_done = 0 103 | reward = 1.0 104 | else: 105 | if self.steps_beyond_done == 0: 106 | logger.warn(""" 107 | You are calling 'step()' even though this environment has already returned 108 | done = True. You should always call 'reset()' once you receive 'done = True' 109 | Any further steps are undefined behavior. 
110 | """) 111 | self.steps_beyond_done += 1 112 | reward = 0.0 113 | 114 | 115 | return np.array(self.state), reward, done, {} 116 | 117 | def reset(self): 118 | self.state = self.np_random.uniform(low=-0.05, high=0.05, size=(4,)) 119 | self.steps_beyond_done = None 120 | return np.array(self.state) 121 | 122 | def render(self, mode='human'): 123 | screen_width = 600 124 | screen_height = 400 125 | 126 | world_width = self.x_threshold * 2 127 | scale = screen_width /world_width 128 | carty = 100 # TOP OF CART 129 | polewidth = 10.0 130 | polelen = scale * 1.0 131 | cartwidth = 50.0 132 | cartheight = 30.0 133 | 134 | if self.viewer is None: 135 | from gym.envs.classic_control import rendering 136 | self.viewer = rendering.Viewer(screen_width, screen_height) 137 | l, r, t, b = -cartwidth / 2, cartwidth / 2, cartheight / 2, -cartheight / 2 138 | axleoffset = cartheight / 4.0 139 | cart = rendering.FilledPolygon([(l, b), (l, t), (r, t), (r, b)]) 140 | self.carttrans = rendering.Transform() 141 | cart.add_attr(self.carttrans) 142 | self.viewer.add_geom(cart) 143 | l, r, t, b = -polewidth / 2, polewidth / 2, polelen-polewidth / 2, -polewidth / 2 144 | pole = rendering.FilledPolygon([(l, b), (l, t), (r, t), (r, b)]) 145 | pole.set_color(.8, .6, .4) 146 | self.poletrans = rendering.Transform(translation=(0, axleoffset)) 147 | pole.add_attr(self.poletrans) 148 | pole.add_attr(self.carttrans) 149 | self.viewer.add_geom(pole) 150 | self.axle = rendering.make_circle(polewidth / 2) 151 | self.axle.add_attr(self.poletrans) 152 | self.axle.add_attr(self.carttrans) 153 | self.axle.set_color(.5, .5, .8) 154 | self.viewer.add_geom(self.axle) 155 | self.track = rendering.Line((0, carty), (screen_width, carty)) 156 | self.track.set_color(0, 0, 0) 157 | self.viewer.add_geom(self.track) 158 | 159 | if self.state is None: 160 | return None 161 | 162 | x = self.state 163 | cartx = x[0] * scale + screen_width / 2.0 # MIDDLE OF CART 164 | self.carttrans.set_translation(cartx, carty) 165 | self.poletrans.set_rotation(-x[2]) 166 | 167 | return self.viewer.render(return_rgb_array=(mode == 'rgb_array')) 168 | 169 | def close(self): 170 | if self.viewer: 171 | self.viewer.close() -------------------------------------------------------------------------------- /wolp.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.optim import Adam 7 | 8 | from model import (Actor, Critic) 9 | from memory import SequentialMemory 10 | from random_process import OrnsteinUhlenbeckProcess 11 | from util import * 12 | import action_space 13 | 14 | # from ipdb import set_trace as debug 15 | 16 | criterion = nn.MSELoss() 17 | 18 | class WOLPAgent(object): 19 | def __init__(self, nb_states, nb_actions, args): 20 | 21 | if args.seed > 0: 22 | self.seed(args.seed) 23 | 24 | self.nb_states = nb_states 25 | self.nb_actions= nb_actions 26 | 27 | # Create Actor and Critic Network 28 | net_cfg = { 29 | 'hidden1':args.hidden1, 30 | 'hidden2':args.hidden2, 31 | 'init_w':args.init_w 32 | } 33 | 34 | ################################## Our Code Start ################################################ 35 | self.low = args.low 36 | self.high = args.high 37 | self.action_space = action_space.Space(self.low, self.high, args.max_actions) 38 | self.k_nearest_neighbors = max(1, int(args.max_actions * args.k_ratio)) 39 | ################################## Our Code End ################################################ 40 | 41 | self.actor = 
Actor(self.nb_states, self.nb_actions, **net_cfg) 42 | self.actor_target = Actor(self.nb_states, self.nb_actions, **net_cfg) 43 | self.actor_optim = Adam(self.actor.parameters(), lr=args.prate) 44 | 45 | self.critic = Critic(self.nb_states, self.nb_actions, **net_cfg) 46 | self.critic_target = Critic(self.nb_states, self.nb_actions, **net_cfg) 47 | self.critic_optim = Adam(self.critic.parameters(), lr=args.rate) 48 | 49 | hard_update(self.actor_target, self.actor) # Make sure target is with the same weight 50 | hard_update(self.critic_target, self.critic) 51 | 52 | #Create replay buffer 53 | self.memory = SequentialMemory(limit=args.rmsize, window_length=args.window_length) 54 | self.random_process = OrnsteinUhlenbeckProcess(size=nb_actions, theta=args.ou_theta, mu=args.ou_mu, sigma=args.ou_sigma) 55 | 56 | # Hyper-parameters 57 | self.batch_size = args.bsize 58 | self.tau = args.tau 59 | self.discount = args.discount 60 | self.depsilon = 1.0 / args.epsilon 61 | 62 | # 63 | self.epsilon = 1.0 64 | self.s_t = None # Most recent state 65 | self.a_t = None # Most recent action 66 | self.is_training = True 67 | 68 | # 69 | if USE_CUDA: self.cuda() 70 | 71 | def get_action_space(self): 72 | return self.action_space 73 | 74 | def update_policy(self): 75 | # Sample batch 76 | state_batch, action_batch, reward_batch, \ 77 | next_state_batch, terminal_batch = self.memory.sample_and_split(self.batch_size) 78 | 79 | # Prepare for the target q batch 80 | next_q_values = self.critic_target([ 81 | to_tensor(next_state_batch, volatile=True), 82 | self.actor_target(to_tensor(next_state_batch, volatile=True)), 83 | ]) 84 | next_q_values.volatile=False 85 | 86 | target_q_batch = to_tensor(reward_batch) + \ 87 | self.discount*to_tensor(terminal_batch.astype(np.float))*next_q_values 88 | 89 | # Critic update 90 | self.critic.zero_grad() 91 | 92 | q_batch = self.critic([ to_tensor(state_batch), to_tensor(action_batch) ]) 93 | 94 | value_loss = criterion(q_batch, target_q_batch) 95 | value_loss.backward() 96 | self.critic_optim.step() 97 | 98 | # Actor update 99 | self.actor.zero_grad() 100 | 101 | policy_loss = -self.critic([ 102 | to_tensor(state_batch), 103 | self.actor(to_tensor(state_batch)) 104 | ]) 105 | 106 | policy_loss = policy_loss.mean() 107 | policy_loss.backward() 108 | self.actor_optim.step() 109 | 110 | # Target update 111 | soft_update(self.actor_target, self.actor, self.tau) 112 | soft_update(self.critic_target, self.critic, self.tau) 113 | 114 | def eval(self): 115 | self.actor.eval() 116 | self.actor_target.eval() 117 | self.critic.eval() 118 | self.critic_target.eval() 119 | 120 | def cuda(self): 121 | self.actor.cuda() 122 | self.actor_target.cuda() 123 | self.critic.cuda() 124 | self.critic_target.cuda() 125 | 126 | def observe(self, r_t, s_t1, done): 127 | if self.is_training: 128 | self.memory.append(self.s_t, self.a_t, r_t, done) 129 | self.s_t = s_t1 130 | 131 | def random_action(self): 132 | action = np.random.uniform(-1.,1.,self.nb_actions) 133 | self.a_t = action 134 | return action 135 | 136 | def select_action(self, s_t, decay_epsilon=True): 137 | proto_action = self.ddpg_select_action(s_t, decay_epsilon=decay_epsilon) 138 | # print("Proto action: {}, proto action.shape: {}".format(proto_action, proto_action.shape)) 139 | # print(proto_action) 140 | 141 | actions = self.action_space.search_point(proto_action, self.k_nearest_neighbors)[0] 142 | # print("len(actions): {}".format(len(actions))) 143 | states = np.tile(s_t, [len(actions), 1]) 144 | 145 | a = [to_tensor(states), 
to_tensor(actions)] 146 | # print("states: {}, actions: {}".format(a[0].size(), a[1].size())) 147 | actions_evaluation = self.critic([to_tensor(states), to_tensor(actions)]) 148 | # print("actions_evaluation: {}, actions_evaluation.size(): {}".format(actions_evaluation, actions_evaluation.size())) 149 | actions_evaluation_np = actions_evaluation.detach().numpy() 150 | max_index = np.argmax(actions_evaluation_np) 151 | 152 | self.a_t = actions[max_index] 153 | return self.a_t 154 | 155 | def ddpg_select_action(self, s_t, decay_epsilon=True): 156 | action = to_numpy( 157 | self.actor(to_tensor(np.array([s_t]))) 158 | ).squeeze(0) 159 | action += self.is_training*max(self.epsilon, 0)*self.random_process.sample() 160 | action = np.clip(action, -1., 1.) 161 | 162 | if decay_epsilon: 163 | self.epsilon -= self.depsilon 164 | 165 | return action 166 | 167 | def reset(self, obs): 168 | self.s_t = obs 169 | self.random_process.reset_states() 170 | 171 | def load_weights(self, output): 172 | if output is None: return 173 | 174 | self.actor.load_state_dict( 175 | torch.load('{}/actor.pkl'.format(output)) 176 | ) 177 | 178 | self.critic.load_state_dict( 179 | torch.load('{}/critic.pkl'.format(output)) 180 | ) 181 | 182 | 183 | def save_model(self,output): 184 | torch.save( 185 | self.actor.state_dict(), 186 | '{}/actor.pkl'.format(output) 187 | ) 188 | torch.save( 189 | self.critic.state_dict(), 190 | '{}/critic.pkl'.format(output) 191 | ) 192 | 193 | def seed(self,s): 194 | torch.manual_seed(s) 195 | if USE_CUDA: 196 | torch.cuda.manual_seed(s) 197 | -------------------------------------------------------------------------------- /memory.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from collections import deque, namedtuple 3 | import warnings 4 | import random 5 | 6 | import numpy as np 7 | 8 | # [reference] https://github.com/matthiasplappert/keras-rl/blob/master/rl/memory.py 9 | 10 | # This is to be understood as a transition: Given `state0`, performing `action` 11 | # yields `reward` and results in `state1`, which might be `terminal`. 12 | Experience = namedtuple('Experience', 'state0, action, reward, state1, terminal1') 13 | 14 | 15 | def sample_batch_indexes(low, high, size): 16 | if high - low >= size: 17 | # We have enough data. Draw without replacement, that is each index is unique in the 18 | # batch. We cannot use `np.random.choice` here because it is horribly inefficient as 19 | # the memory grows. See https://github.com/numpy/numpy/issues/2764 for a discussion. 20 | # `random.sample` does the same thing (drawing without replacement) and is way faster. 21 | try: 22 | r = xrange(low, high) 23 | except NameError: 24 | r = range(low, high) 25 | batch_idxs = random.sample(r, size) 26 | else: 27 | # Not enough data. Help ourselves with sampling from the range, but the same index 28 | # can occur multiple times. This is not good and should be avoided by picking a 29 | # large enough warm-up phase. 30 | warnings.warn('Not enough entries to sample without replacement. 
Consider increasing your warm-up phase to avoid oversampling!') 31 | batch_idxs = np.random.random_integers(low, high - 1, size=size) 32 | assert len(batch_idxs) == size 33 | return batch_idxs 34 | 35 | 36 | class RingBuffer(object): 37 | def __init__(self, maxlen): 38 | self.maxlen = maxlen 39 | self.start = 0 40 | self.length = 0 41 | self.data = [None for _ in range(maxlen)] 42 | 43 | def __len__(self): 44 | return self.length 45 | 46 | def __getitem__(self, idx): 47 | if idx < 0 or idx >= self.length: 48 | raise KeyError() 49 | return self.data[(self.start + idx) % self.maxlen] 50 | 51 | def append(self, v): 52 | if self.length < self.maxlen: 53 | # We have space, simply increase the length. 54 | self.length += 1 55 | elif self.length == self.maxlen: 56 | # No space, "remove" the first item. 57 | self.start = (self.start + 1) % self.maxlen 58 | else: 59 | # This should never happen. 60 | raise RuntimeError() 61 | self.data[(self.start + self.length - 1) % self.maxlen] = v 62 | 63 | 64 | def zeroed_observation(observation): 65 | if hasattr(observation, 'shape'): 66 | return np.zeros(observation.shape) 67 | elif hasattr(observation, '__iter__'): 68 | out = [] 69 | for x in observation: 70 | out.append(zeroed_observation(x)) 71 | return out 72 | else: 73 | return 0. 74 | 75 | 76 | class Memory(object): 77 | def __init__(self, window_length, ignore_episode_boundaries=False): 78 | self.window_length = window_length 79 | self.ignore_episode_boundaries = ignore_episode_boundaries 80 | 81 | self.recent_observations = deque(maxlen=window_length) 82 | self.recent_terminals = deque(maxlen=window_length) 83 | 84 | def sample(self, batch_size, batch_idxs=None): 85 | raise NotImplementedError() 86 | 87 | def append(self, observation, action, reward, terminal, training=True): 88 | self.recent_observations.append(observation) 89 | self.recent_terminals.append(terminal) 90 | 91 | def get_recent_state(self, current_observation): 92 | # This code is slightly complicated by the fact that subsequent observations might be 93 | # from different episodes. We ensure that an experience never spans multiple episodes. 94 | # This is probably not that important in practice but it seems cleaner. 95 | state = [current_observation] 96 | idx = len(self.recent_observations) - 1 97 | for offset in range(0, self.window_length - 1): 98 | current_idx = idx - offset 99 | current_terminal = self.recent_terminals[current_idx - 1] if current_idx - 1 >= 0 else False 100 | if current_idx < 0 or (not self.ignore_episode_boundaries and current_terminal): 101 | # The previously handled observation was terminal, don't add the current one. 102 | # Otherwise we would leak into a different episode. 103 | break 104 | state.insert(0, self.recent_observations[current_idx]) 105 | while len(state) < self.window_length: 106 | state.insert(0, zeroed_observation(state[0])) 107 | return state 108 | 109 | def get_config(self): 110 | config = { 111 | 'window_length': self.window_length, 112 | 'ignore_episode_boundaries': self.ignore_episode_boundaries, 113 | } 114 | return config 115 | 116 | class SequentialMemory(Memory): 117 | def __init__(self, limit, **kwargs): 118 | super(SequentialMemory, self).__init__(**kwargs) 119 | 120 | self.limit = limit 121 | 122 | # Do not use deque to implement the memory. This data structure may seem convenient but 123 | # it is way too slow on random access. Instead, we use our own ring buffer implementation. 
124 | self.actions = RingBuffer(limit) 125 | self.rewards = RingBuffer(limit) 126 | self.terminals = RingBuffer(limit) 127 | self.observations = RingBuffer(limit) 128 | 129 | def sample(self, batch_size, batch_idxs=None): 130 | if batch_idxs is None: 131 | # Draw random indexes such that we have at least a single entry before each 132 | # index. 133 | batch_idxs = sample_batch_indexes(0, self.nb_entries - 1, size=batch_size) 134 | batch_idxs = np.array(batch_idxs) + 1 135 | assert np.min(batch_idxs) >= 1 136 | assert np.max(batch_idxs) < self.nb_entries 137 | assert len(batch_idxs) == batch_size 138 | 139 | # Create experiences 140 | experiences = [] 141 | for idx in batch_idxs: 142 | terminal0 = self.terminals[idx - 2] if idx >= 2 else False 143 | while terminal0: 144 | # Skip this transition because the environment was reset here. Select a new, random 145 | # transition and use this instead. This may cause the batch to contain the same 146 | # transition twice. 147 | idx = sample_batch_indexes(1, self.nb_entries, size=1)[0] 148 | terminal0 = self.terminals[idx - 2] if idx >= 2 else False 149 | assert 1 <= idx < self.nb_entries 150 | 151 | # This code is slightly complicated by the fact that subsequent observations might be 152 | # from different episodes. We ensure that an experience never spans multiple episodes. 153 | # This is probably not that important in practice but it seems cleaner. 154 | state0 = [self.observations[idx - 1]] 155 | for offset in range(0, self.window_length - 1): 156 | current_idx = idx - 2 - offset 157 | current_terminal = self.terminals[current_idx - 1] if current_idx - 1 > 0 else False 158 | if current_idx < 0 or (not self.ignore_episode_boundaries and current_terminal): 159 | # The previously handled observation was terminal, don't add the current one. 160 | # Otherwise we would leak into a different episode. 161 | break 162 | state0.insert(0, self.observations[current_idx]) 163 | while len(state0) < self.window_length: 164 | state0.insert(0, zeroed_observation(state0[0])) 165 | action = self.actions[idx - 1] 166 | reward = self.rewards[idx - 1] 167 | terminal1 = self.terminals[idx - 1] 168 | 169 | # Okay, now we need to create the follow-up state. This is state0 shifted on timestep 170 | # to the right. Again, we need to be careful to not include an observation from the next 171 | # episode if the last state is terminal. 172 | state1 = [np.copy(x) for x in state0[1:]] 173 | state1.append(self.observations[idx]) 174 | 175 | assert len(state0) == self.window_length 176 | assert len(state1) == len(state0) 177 | experiences.append(Experience(state0=state0, action=action, reward=reward, 178 | state1=state1, terminal1=terminal1)) 179 | assert len(experiences) == batch_size 180 | return experiences 181 | 182 | def sample_and_split(self, batch_size, batch_idxs=None): 183 | experiences = self.sample(batch_size, batch_idxs) 184 | 185 | state0_batch = [] 186 | reward_batch = [] 187 | action_batch = [] 188 | terminal1_batch = [] 189 | state1_batch = [] 190 | for e in experiences: 191 | state0_batch.append(e.state0) 192 | state1_batch.append(e.state1) 193 | reward_batch.append(e.reward) 194 | action_batch.append(e.action) 195 | terminal1_batch.append(0. if e.terminal1 else 1.) 196 | 197 | # Prepare and validate parameters. 
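        # Note on the reshapes below: each array is flattened to (batch_size, -1).
        # With window_length == 1, state0_batch and state1_batch come out as
        # (batch_size, nb_states), action_batch as (batch_size, nb_actions), and
        # reward_batch / terminal1_batch as (batch_size, 1).  terminal1_batch is a
        # continuation mask (0.0 where the episode ended at state1, 1.0 otherwise),
        # which is why DDPG.update_policy and WOLPAgent.update_policy can multiply
        # it directly into the discounted next-state Q values.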
198 |         state0_batch = np.array(state0_batch).reshape(batch_size,-1)
199 |         state1_batch = np.array(state1_batch).reshape(batch_size,-1)
200 |         terminal1_batch = np.array(terminal1_batch).reshape(batch_size,-1)
201 |         reward_batch = np.array(reward_batch).reshape(batch_size,-1)
202 |         action_batch = np.array(action_batch).reshape(batch_size,-1)
203 | 
204 |         return state0_batch, action_batch, reward_batch, state1_batch, terminal1_batch
205 | 
206 | 
207 |     def append(self, observation, action, reward, terminal, training=True):
208 |         super(SequentialMemory, self).append(observation, action, reward, terminal, training=training)
209 | 
210 |         # This needs to be understood as follows: in `observation`, take `action`, obtain `reward`
211 |         # and whether the next state is `terminal` or not.
212 |         if training:
213 |             self.observations.append(observation)
214 |             self.actions.append(action)
215 |             self.rewards.append(reward)
216 |             self.terminals.append(terminal)
217 | 
218 |     @property
219 |     def nb_entries(self):
220 |         return len(self.observations)
221 | 
222 |     def get_config(self):
223 |         config = super(SequentialMemory, self).get_config()
224 |         config['limit'] = self.limit
225 |         return config
226 | 
227 | 
228 | class EpisodeParameterMemory(Memory):
229 |     def __init__(self, limit, **kwargs):
230 |         super(EpisodeParameterMemory, self).__init__(**kwargs)
231 |         self.limit = limit
232 | 
233 |         self.params = RingBuffer(limit)
234 |         self.intermediate_rewards = []
235 |         self.total_rewards = RingBuffer(limit)
236 | 
237 |     def sample(self, batch_size, batch_idxs=None):
238 |         if batch_idxs is None:
239 |             batch_idxs = sample_batch_indexes(0, self.nb_entries, size=batch_size)
240 |         assert len(batch_idxs) == batch_size
241 | 
242 |         batch_params = []
243 |         batch_total_rewards = []
244 |         for idx in batch_idxs:
245 |             batch_params.append(self.params[idx])
246 |             batch_total_rewards.append(self.total_rewards[idx])
247 |         return batch_params, batch_total_rewards
248 | 
249 |     def append(self, observation, action, reward, terminal, training=True):
250 |         super(EpisodeParameterMemory, self).append(observation, action, reward, terminal, training=training)
251 |         if training:
252 |             self.intermediate_rewards.append(reward)
253 | 
254 |     def finalize_episode(self, params):
255 |         total_reward = sum(self.intermediate_rewards)
256 |         self.total_rewards.append(total_reward)
257 |         self.params.append(params)
258 |         self.intermediate_rewards = []
259 | 
260 |     @property
261 |     def nb_entries(self):
262 |         return len(self.total_rewards)
263 | 
264 |     def get_config(self):
265 |         config = super(EpisodeParameterMemory, self).get_config()
266 |         config['limit'] = self.limit
267 |         return config
268 | 
--------------------------------------------------------------------------------
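The training notebook `main.ipynb` referenced in the README is not part of this listing. As a rough guide to how the modules above fit together, the following sketch wires `ContinuousCartPoleEnv`, `WOLPAgent`, and `Evaluator` into a minimal training loop; the hyper-parameter values and step counts are placeholders rather than the settings behind the reported results.

```python
# Minimal training-loop sketch.  Assumptions: the step counts and hyper-parameter
# values are placeholders; only the WOLPAgent / Evaluator / environment calls are
# taken from the files above.
import os
from types import SimpleNamespace

from ContinuousCartPole import ContinuousCartPoleEnv
from evaluator import Evaluator
from wolp import WOLPAgent

args = SimpleNamespace(
    seed=-1, hidden1=400, hidden2=300, init_w=3e-3,           # networks
    prate=1e-4, rate=1e-3, bsize=64, tau=1e-3,                # optimisation
    discount=0.99, epsilon=50000,                             # discounting / noise decay
    rmsize=100000, window_length=1,                           # replay memory
    ou_theta=0.15, ou_mu=0.0, ou_sigma=0.2,                   # OU exploration noise
    low=[-1.0], high=[1.0], max_actions=1000, k_ratio=0.01,   # Wolpertinger action set
)

env = ContinuousCartPoleEnv()
nb_states = env.observation_space.shape[0]
nb_actions = env.action_space.shape[0]

agent = WOLPAgent(nb_states, nb_actions, args)
os.makedirs('output', exist_ok=True)

warmup, total_steps = 100, 20000            # placeholders
step, observation = 0, None
while step < total_steps:
    if observation is None:
        observation = env.reset()
        agent.reset(observation)

    # Random actions while the replay buffer warms up, Wolpertinger actions afterwards.
    if step < warmup:
        action = agent.random_action()
    else:
        action = agent.select_action(observation)

    observation2, reward, done, info = env.step(action)
    agent.observe(reward, observation2, done)
    if step > warmup:
        agent.update_policy()

    observation = None if done else observation2
    step += 1

agent.save_model('output')

# Evaluate the trained agent and write output/validate_reward.{png,mat}.
evaluate = Evaluator(num_episodes=10, interval=total_steps, save_path='output', max_episode_length=500)
policy = lambda obs: agent.select_action(obs, decay_epsilon=False)
print('average validation reward:', evaluate(env, policy, debug=True, visualize=False, save=True))
```

The same loop also works with the plain `DDPG` agent from `ddpg.py` by swapping the constructor and dropping the Wolpertinger-specific fields (`low`, `high`, `max_actions`, `k_ratio`).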