├── .gitignore ├── README.md ├── experience_replay.py ├── fourrooms.py ├── logger.py ├── main.py ├── models │   ├── option_critic_seed=0_1k │   └── option_critic_seed=0_2k ├── option_critic.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/* 2 | .DS_store 3 | !.vscode/settings.json 4 | !.vscode/tasks.json 5 | !.vscode/launch.json 6 | !.vscode/extensions.json 7 | *.code-workspace 8 | settings.json 9 | .pylintrc 10 | *.pyc 11 | runs/* 12 | analysis/* 13 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Option Critic 2 | This repository is a PyTorch implementation of the paper "The Option-Critic Architecture" by Pierre-Luc Bacon, Jean Harb and Doina Precup [arXiv](https://arxiv.org/abs/1609.05140). It is mostly a rewrite of the original Theano code found [here](https://github.com/jeanharb/option_critic) in PyTorch, focused on readability and on making the logic behind the Option-Critic easy to follow. 3 | 4 | 5 | ## CartPole 6 | Currently, the dense architecture can learn CartPole-v0 with a learning rate of 0.005; this has, however, only been tested with two options. (I don't see any reason to use more than two in the CartPole environment.) The runs directory holds the training results for this environment with learning rates of 0.005 and 0.006. Run it as follows: 7 | 8 | ``` 9 | python main.py 10 | ``` 11 | 12 | ## Atari Environments 13 | I suspect a grid search over the learning rate will be needed to make this work on Pong and similar games. Just supply the right `--env` argument and the model will switch to the convolutional architecture if the environment is Atari-compatible. 14 | 15 | ## Four Rooms experiment 16 | There are plenty of NumPy versions of the four-rooms experiment around; this one is a little different: the state is represented as a one-hot encoded vector, and a deep network learns to solve the grid world.
Run this experiment as follows 17 | 18 | ``` 19 | python main.py --switch-goal True --env fourrooms 20 | ``` 21 | 22 | ## Requirements 23 | 24 | ``` 25 | pytorch>=1.12.1 26 | tensorboard>=2.0.2 27 | gym>=0.15.3 28 | ``` 29 | -------------------------------------------------------------------------------- /experience_replay.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | from collections import deque 4 | 5 | class ReplayBuffer(object): 6 | def __init__(self, capacity, seed=42): 7 | self.rng = random.SystemRandom(seed) 8 | self.buffer = deque(maxlen=capacity) 9 | 10 | def push(self, obs, option, reward, next_obs, done): 11 | self.buffer.append((obs, option, reward, next_obs, done)) 12 | 13 | def sample(self, batch_size): 14 | obs, option, reward, next_obs, done = zip(*self.rng.sample(self.buffer, batch_size)) 15 | return np.stack(obs), option, reward, np.stack(next_obs), done 16 | 17 | def __len__(self): 18 | return len(self.buffer) 19 | -------------------------------------------------------------------------------- /fourrooms.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | import gym 4 | from gym import spaces 5 | from gym.utils import seeding 6 | import numpy as np 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | class Fourrooms(gym.Env): 11 | metadata = { 12 | 'render.modes': ['human', 'rgb_array'], 13 | 'video.frames_per_second' : 50 14 | } 15 | 16 | def __init__(self): 17 | 18 | layout = """\ 19 | wwwwwwwwwwwww 20 | w w w 21 | w w w 22 | w w 23 | w w w 24 | w w w 25 | ww wwww w 26 | w www www 27 | w w w 28 | w w w 29 | w w 30 | w w w 31 | wwwwwwwwwwwww 32 | """ 33 | self.occupancy = np.array([list(map(lambda c: 1 if c=='w' else 0, line)) for line in layout.splitlines()]) 34 | 35 | # From any state the agent can perform one of four actions, up, down, left or right 36 | self.action_space = spaces.Discrete(4) 37 | self.observation_space = spaces.Box(low=0., high=1., shape=(np.sum(self.occupancy == 0),)) 38 | 39 | self.directions = [np.array((-1,0)), np.array((1,0)), np.array((0,-1)), np.array((0,1))] 40 | self.rng = np.random.RandomState(1234) 41 | 42 | self.tostate = {} 43 | statenum = 0 44 | for i in range(13): 45 | for j in range(13): 46 | if self.occupancy[i, j] == 0: 47 | self.tostate[(i,j)] = statenum 48 | statenum += 1 49 | self.tocell = {v:k for k,v in self.tostate.items()} 50 | 51 | self.goal = 62 # East doorway 52 | self.init_states = list(range(self.observation_space.shape[0])) 53 | self.init_states.remove(self.goal) 54 | self.ep_steps = 0 55 | 56 | def seed(self, seed=None): 57 | return self._seed(seed) 58 | 59 | def _seed(self, seed=None): 60 | self.np_random, seed = seeding.np_random(seed) 61 | return [seed] 62 | 63 | def empty_around(self, cell): 64 | avail = [] 65 | for action in range(self.action_space.n): 66 | nextcell = tuple(cell + self.directions[action]) 67 | if not self.occupancy[nextcell]: 68 | avail.append(nextcell) 69 | return avail 70 | 71 | def reset(self): 72 | state = self.rng.choice(self.init_states) 73 | self.currentcell = self.tocell[state] 74 | self.ep_steps = 0 75 | return self.get_state(state) 76 | 77 | def switch_goal(self): 78 | prev_goal = self.goal 79 | self.goal = self.rng.choice(self.init_states) 80 | self.init_states.append(prev_goal) 81 | self.init_states.remove(self.goal) 82 | assert prev_goal in self.init_states 83 | assert self.goal not in self.init_states 84 | 85 | def 
get_state(self, state): 86 | s = np.zeros(self.observation_space.shape[0]) 87 | s[state] = 1 88 | return s 89 | 90 | def render(self, show_goal=True): 91 | current_grid = np.array(self.occupancy) 92 | current_grid[self.currentcell[0], self.currentcell[1]] = -1 93 | if show_goal: 94 | goal_cell = self.tocell[self.goal] 95 | current_grid[goal_cell[0], goal_cell[1]] = -1 96 | return current_grid 97 | 98 | def step(self, action): 99 | """ 100 | The agent can perform one of four actions, 101 | up, down, left or right, which have a stochastic effect. With probability 2/3, the actions 102 | cause the agent to move one cell in the corresponding direction, and with probability 1/3, 103 | the agent moves instead in one of the other three directions, each with 1/9 probability. In 104 | either case, if the movement would take the agent into a wall then the agent remains in the 105 | same cell. 106 | We consider a case in which rewards are zero on all state transitions. 107 | """ 108 | self.ep_steps += 1 109 | 110 | nextcell = tuple(self.currentcell + self.directions[action]) 111 | if not self.occupancy[nextcell]: 112 | if self.rng.uniform() < 1/3.: 113 | empty_cells = self.empty_around(self.currentcell) 114 | self.currentcell = empty_cells[self.rng.randint(len(empty_cells))] 115 | else: 116 | self.currentcell = nextcell 117 | 118 | state = self.tostate[self.currentcell] 119 | done = state == self.goal 120 | reward = float(done) 121 | 122 | if not done and self.ep_steps >= 1000: 123 | done = True ; reward = 0.0 124 | 125 | return self.get_state(state), reward, done, None 126 | 127 | if __name__=="__main__": 128 | env = Fourrooms() 129 | env.seed(3) 130 | -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import time 4 | import numpy as np 5 | from torch.utils.tensorboard import SummaryWriter 6 | 7 | class Logger(): 8 | def __init__(self, logdir, run_name): 9 | self.log_name = logdir + '/' + run_name 10 | self.tf_writer = None 11 | self.start_time = time.time() 12 | self.n_eps = 0 13 | 14 | if not os.path.exists(self.log_name): 15 | os.makedirs(self.log_name) 16 | 17 | self.writer = SummaryWriter(self.log_name) 18 | 19 | logging.basicConfig( 20 | level=logging.DEBUG, 21 | format='%(asctime)s %(message)s', 22 | handlers=[ 23 | logging.StreamHandler(), 24 | logging.FileHandler(self.log_name + '/logger.log'), 25 | ], 26 | datefmt='%Y/%m/%d %I:%M:%S %p' 27 | ) 28 | 29 | def log_episode(self, steps, reward, option_lengths, ep_steps, epsilon): 30 | self.n_eps += 1 31 | logging.info(f"> ep {self.n_eps} done. 
total_steps={steps} | reward={reward} | episode_steps={ep_steps} "\ 32 | f"| hours={(time.time()-self.start_time) / 60 / 60:.3f} | epsilon={epsilon:.3f}") 33 | self.writer.add_scalar(tag="episodic_rewards", scalar_value=reward, global_step=self.n_eps) 34 | self.writer.add_scalar(tag='episode_lengths', scalar_value=ep_steps, global_step=self.n_eps) 35 | 36 | # Keep track of options statistics 37 | for option, lens in option_lengths.items(): 38 | # Need better statistics for this one; a point average is terrible in this case 39 | self.writer.add_scalar(tag=f"option_{option}_avg_length", scalar_value=np.mean(lens) if len(lens)>0 else 0, global_step=self.n_eps) 40 | self.writer.add_scalar(tag=f"option_{option}_active", scalar_value=sum(lens)/ep_steps, global_step=self.n_eps) 41 | def log_data(self, step, actor_loss, critic_loss, entropy, epsilon): 42 | if actor_loss: 43 | self.writer.add_scalar(tag="actor_loss", scalar_value=actor_loss.item(), global_step=step) 44 | if critic_loss: 45 | self.writer.add_scalar(tag="critic_loss", scalar_value=critic_loss.item(), global_step=step) 46 | self.writer.add_scalar(tag="policy_entropy", scalar_value=entropy, global_step=step) 47 | self.writer.add_scalar(tag="epsilon", scalar_value=epsilon, global_step=step) 48 | 49 | if __name__=="__main__": 50 | logger = Logger(logdir='runs/', run_name='test_model-test_env') 51 | steps = 200 ; reward = 5 ; option_lengths = {opt: np.random.randint(0,5,size=(5)) for opt in range(5)} ; ep_steps = 50 ; epsilon = 0.9 52 | logger.log_episode(steps, reward, option_lengths, ep_steps, epsilon) 53 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import argparse 3 | import torch 4 | from copy import deepcopy 5 | 6 | from option_critic import OptionCriticFeatures, OptionCriticConv 7 | from option_critic import critic_loss as critic_loss_fn 8 | from option_critic import actor_loss as actor_loss_fn 9 | 10 | from experience_replay import ReplayBuffer 11 | from utils import make_env, to_tensor 12 | from logger import Logger 13 | 14 | import time 15 | 16 | parser = argparse.ArgumentParser(description="Option Critic PyTorch") 17 | parser.add_argument('--env', default='CartPole-v0', help='Environment to run (Gym id or fourrooms)') 18 | parser.add_argument('--optimal-eps', type=float, default=0.05, help='Epsilon when playing optimally') 19 | parser.add_argument('--frame-skip', default=4, type=int, help='Every how many frames to process') 20 | parser.add_argument('--learning-rate', type=float, default=.0005, help='Learning rate') 21 | parser.add_argument('--gamma', type=float, default=.99, help='Discount rate') 22 | parser.add_argument('--epsilon-start', type=float, default=1.0, help=('Starting value for epsilon.')) 23 | parser.add_argument('--epsilon-min', type=float, default=.1, help='Minimum epsilon.') 24 | parser.add_argument('--epsilon-decay', type=float, default=20000, help=('Number of steps to minimum epsilon.')) 25 | parser.add_argument('--max-history', type=int, default=10000, help=('Maximum number of steps stored in replay')) 26 | parser.add_argument('--batch-size', type=int, default=32, help='Batch size.') 27 | parser.add_argument('--freeze-interval', type=int, default=200, help=('Interval between target freezes.')) 28 | parser.add_argument('--update-frequency', type=int, default=4, help=('Number of actions before each SGD update.')) 29 | parser.add_argument('--termination-reg', type=float, default=0.01, help=('Regularization to decrease 
termination prob.')) 30 | parser.add_argument('--entropy-reg', type=float, default=0.01, help=('Regularization to increase policy entropy.')) 31 | parser.add_argument('--num-options', type=int, default=2, help=('Number of options to create.')) 32 | parser.add_argument('--temp', type=float, default=1, help='Action distribution softmax temperature param.') 33 | 34 | parser.add_argument('--max_steps_ep', type=int, default=18000, help='Maximum number of steps per episode.') 35 | parser.add_argument('--max_steps_total', type=int, default=int(4e6), help='Maximum number of training steps to take.') # about 4 million 36 | parser.add_argument('--cuda', type=bool, default=True, help='Enable CUDA training (recommended if possible).') 37 | parser.add_argument('--seed', type=int, default=0, help='Random seed for numpy, torch, random.') 38 | parser.add_argument('--logdir', type=str, default='runs', help='Directory for logging statistics') 39 | parser.add_argument('--exp', type=str, default=None, help='Optional experiment name') 40 | parser.add_argument('--switch-goal', type=bool, default=False, help='Switch the fourrooms goal after 1k episodes') 41 | 42 | def run(args): 43 | env, is_atari = make_env(args.env) 44 | option_critic = OptionCriticConv if is_atari else OptionCriticFeatures 45 | device = torch.device('cuda' if torch.cuda.is_available() and args.cuda else 'cpu') 46 | 47 | option_critic = option_critic( 48 | in_features=env.observation_space.shape[0], 49 | num_actions=env.action_space.n, 50 | num_options=args.num_options, 51 | temperature=args.temp, 52 | eps_start=args.epsilon_start, 53 | eps_min=args.epsilon_min, 54 | eps_decay=args.epsilon_decay, 55 | eps_test=args.optimal_eps, 56 | device=device 57 | ) 58 | # Create a prime network for more stable Q values 59 | option_critic_prime = deepcopy(option_critic) 60 | 61 | optim = torch.optim.RMSprop(option_critic.parameters(), lr=args.learning_rate) 62 | 63 | np.random.seed(args.seed) 64 | torch.manual_seed(args.seed) 65 | env.seed(args.seed) 66 | 67 | buffer = ReplayBuffer(capacity=args.max_history, seed=args.seed) 68 | logger = Logger(logdir=args.logdir, run_name=f"{OptionCriticFeatures.__name__}-{args.env}-{args.exp}-{time.ctime()}") 69 | 70 | steps = 0 ; 71 | if args.switch_goal: print(f"Current goal {env.goal}") 72 | while steps < args.max_steps_total: 73 | 74 | rewards = 0 ; option_lengths = {opt:[] for opt in range(args.num_options)} 75 | 76 | obs = env.reset() 77 | state = option_critic.get_state(to_tensor(obs)) 78 | greedy_option = option_critic.greedy_option(state) 79 | current_option = 0 80 | 81 | # Goal switching experiment: run for 1k episodes in fourrooms, switch the goal and run for another 82 | # 1k episodes. In option-critic, if the options have some meaning, only the policy-over-options 83 | # should be fine-tuned (this is what we would hope). 
84 | if args.switch_goal and logger.n_eps == 1000: 85 | torch.save({'model_params': option_critic.state_dict(), 86 | 'goal_state': env.goal}, 87 | f'models/option_critic_seed={args.seed}_1k') 88 | env.switch_goal() 89 | print(f"New goal {env.goal}") 90 | 91 | if args.switch_goal and logger.n_eps > 2000: 92 | torch.save({'model_params': option_critic.state_dict(), 93 | 'goal_state': env.goal}, 94 | f'models/option_critic_seed={args.seed}_2k') 95 | break 96 | 97 | done = False ; ep_steps = 0 ; option_termination = True ; curr_op_len = 0 98 | while not done and ep_steps < args.max_steps_ep: 99 | epsilon = option_critic.epsilon 100 | 101 | if option_termination: 102 | option_lengths[current_option].append(curr_op_len) 103 | current_option = np.random.choice(args.num_options) if np.random.rand() < epsilon else greedy_option 104 | curr_op_len = 0 105 | 106 | action, logp, entropy = option_critic.get_action(state, current_option) 107 | 108 | next_obs, reward, done, _ = env.step(action) 109 | buffer.push(obs, current_option, reward, next_obs, done) 110 | rewards += reward 111 | 112 | actor_loss, critic_loss = None, None 113 | if len(buffer) > args.batch_size: 114 | actor_loss = actor_loss_fn(obs, current_option, logp, entropy, \ 115 | reward, done, next_obs, option_critic, option_critic_prime, args) 116 | loss = actor_loss 117 | 118 | if steps % args.update_frequency == 0: 119 | data_batch = buffer.sample(args.batch_size) 120 | critic_loss = critic_loss_fn(option_critic, option_critic_prime, data_batch, args) 121 | loss += critic_loss 122 | 123 | optim.zero_grad() 124 | loss.backward() 125 | optim.step() 126 | 127 | if steps % args.freeze_interval == 0: 128 | option_critic_prime.load_state_dict(option_critic.state_dict()) 129 | 130 | state = option_critic.get_state(to_tensor(next_obs)) 131 | option_termination, greedy_option = option_critic.predict_option_termination(state, current_option) 132 | 133 | # update global steps etc 134 | steps += 1 135 | ep_steps += 1 136 | curr_op_len += 1 137 | obs = next_obs 138 | 139 | logger.log_data(steps, actor_loss, critic_loss, entropy.item(), epsilon) 140 | 141 | logger.log_episode(steps, rewards, option_lengths, ep_steps, epsilon) 142 | 143 | if __name__=="__main__": 144 | args = parser.parse_args() 145 | run(args) 146 | -------------------------------------------------------------------------------- /models/option_critic_seed=0_1k: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lweitkamp/option-critic-pytorch/fab40f7aae0ff45cf5945b7de79d5ae5446d31a0/models/option_critic_seed=0_1k -------------------------------------------------------------------------------- /models/option_critic_seed=0_2k: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lweitkamp/option-critic-pytorch/fab40f7aae0ff45cf5945b7de79d5ae5446d31a0/models/option_critic_seed=0_2k -------------------------------------------------------------------------------- /option_critic.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.distributions import Categorical, Bernoulli 4 | 5 | from math import exp 6 | import numpy as np 7 | 8 | from utils import to_tensor 9 | 10 | 11 | class OptionCriticConv(nn.Module): 12 | def __init__(self, 13 | in_features, 14 | num_actions, 15 | num_options, 16 | temperature=1.0, 17 | eps_start=1.0, 18 | eps_min=0.1, 19 | eps_decay=int(1e6), 20 | 
eps_test=0.05, 21 | device='cpu', 22 | testing=False): 23 | 24 | super(OptionCriticConv, self).__init__() 25 | 26 | self.in_channels = in_features 27 | self.num_actions = num_actions 28 | self.num_options = num_options 29 | self.magic_number = 7 * 7 * 64 30 | self.device = device 31 | self.testing = testing 32 | 33 | self.temperature = temperature 34 | self.eps_min = eps_min 35 | self.eps_start = eps_start 36 | self.eps_decay = eps_decay 37 | self.eps_test = eps_test 38 | self.num_steps = 0 39 | 40 | self.features = nn.Sequential( 41 | nn.Conv2d(self.in_channels, 32, kernel_size=8, stride=4), 42 | nn.ReLU(), 43 | nn.Conv2d(32, 64, kernel_size=4, stride=2), 44 | nn.ReLU(), 45 | nn.Conv2d(64, 64, kernel_size=3, stride=1), 46 | nn.ReLU(), 47 | nn.modules.Flatten(), 48 | nn.Linear(self.magic_number, 512), 49 | nn.ReLU() 50 | ) 51 | 52 | self.Q = nn.Linear(512, num_options) # Policy-Over-Options 53 | self.terminations = nn.Linear(512, num_options) # Option-Termination 54 | self.options_W = nn.Parameter(torch.zeros(num_options, 512, num_actions)) 55 | self.options_b = nn.Parameter(torch.zeros(num_options, num_actions)) 56 | 57 | self.to(device) 58 | self.train(not testing) 59 | 60 | def get_state(self, obs): 61 | if obs.ndim < 4: 62 | obs = obs.unsqueeze(0) 63 | obs = obs.to(self.device) 64 | state = self.features(obs) 65 | return state 66 | 67 | def get_Q(self, state): 68 | return self.Q(state) 69 | 70 | def predict_option_termination(self, state, current_option): 71 | termination = self.terminations(state)[:, current_option].sigmoid() 72 | option_termination = Bernoulli(termination).sample() 73 | 74 | Q = self.get_Q(state) 75 | next_option = Q.argmax(dim=-1) 76 | return bool(option_termination.item()), next_option.item() 77 | 78 | def get_terminations(self, state): 79 | return self.terminations(state).sigmoid() 80 | 81 | def get_action(self, state, option): 82 | logits = state.data @ self.options_W[option] + self.options_b[option] 83 | action_dist = (logits / self.temperature).softmax(dim=-1) 84 | action_dist = Categorical(action_dist) 85 | 86 | action = action_dist.sample() 87 | logp = action_dist.log_prob(action) 88 | entropy = action_dist.entropy() 89 | 90 | return action.item(), logp, entropy 91 | 92 | def greedy_option(self, state): 93 | Q = self.get_Q(state) 94 | return Q.argmax(dim=-1).item() 95 | 96 | @property 97 | def epsilon(self): 98 | if not self.testing: 99 | eps = self.eps_min + (self.eps_start - self.eps_min) * exp(-self.num_steps / self.eps_decay) 100 | self.num_steps += 1 101 | else: 102 | eps = self.eps_test 103 | return eps 104 | 105 | 106 | class OptionCriticFeatures(nn.Module): 107 | def __init__(self, 108 | in_features, 109 | num_actions, 110 | num_options, 111 | temperature=1.0, 112 | eps_start=1.0, 113 | eps_min=0.1, 114 | eps_decay=int(1e6), 115 | eps_test=0.05, 116 | device='cpu', 117 | testing=False): 118 | 119 | super(OptionCriticFeatures, self).__init__() 120 | 121 | self.in_features = in_features 122 | self.num_actions = num_actions 123 | self.num_options = num_options 124 | self.device = device 125 | self.testing = testing 126 | 127 | self.temperature = temperature 128 | self.eps_min = eps_min 129 | self.eps_start = eps_start 130 | self.eps_decay = eps_decay 131 | self.eps_test = eps_test 132 | self.num_steps = 0 133 | 134 | self.features = nn.Sequential( 135 | nn.Linear(in_features, 32), 136 | nn.ReLU(), 137 | nn.Linear(32, 64), 138 | nn.ReLU() 139 | ) 140 | 141 | self.Q = nn.Linear(64, num_options) # Policy-Over-Options 142 | self.terminations = nn.Linear(64, 
num_options) # Option-Termination 143 | self.options_W = nn.Parameter(torch.zeros(num_options, 64, num_actions)) 144 | self.options_b = nn.Parameter(torch.zeros(num_options, num_actions)) 145 | 146 | self.to(device) 147 | self.train(not testing) 148 | 149 | def get_state(self, obs): 150 | if obs.ndim < 4: 151 | obs = obs.unsqueeze(0) 152 | obs = obs.to(self.device) 153 | state = self.features(obs) 154 | return state 155 | 156 | def get_Q(self, state): 157 | return self.Q(state) 158 | 159 | def predict_option_termination(self, state, current_option): 160 | termination = self.terminations(state)[:, current_option].sigmoid() 161 | option_termination = Bernoulli(termination).sample() 162 | Q = self.get_Q(state) 163 | next_option = Q.argmax(dim=-1) 164 | return bool(option_termination.item()), next_option.item() 165 | 166 | def get_terminations(self, state): 167 | return self.terminations(state).sigmoid() 168 | 169 | def get_action(self, state, option): 170 | logits = state.data @ self.options_W[option] + self.options_b[option] 171 | action_dist = (logits / self.temperature).softmax(dim=-1) 172 | action_dist = Categorical(action_dist) 173 | 174 | action = action_dist.sample() 175 | logp = action_dist.log_prob(action) 176 | entropy = action_dist.entropy() 177 | 178 | return action.item(), logp, entropy 179 | 180 | def greedy_option(self, state): 181 | Q = self.get_Q(state) 182 | return Q.argmax(dim=-1).item() 183 | 184 | @property 185 | def epsilon(self): 186 | if not self.testing: 187 | eps = self.eps_min + (self.eps_start - self.eps_min) * exp(-self.num_steps / self.eps_decay) 188 | self.num_steps += 1 189 | else: 190 | eps = self.eps_test 191 | return eps 192 | 193 | 194 | def critic_loss(model, model_prime, data_batch, args): 195 | obs, options, rewards, next_obs, dones = data_batch 196 | batch_idx = torch.arange(len(options)).long() 197 | options = torch.LongTensor(options).to(model.device) 198 | rewards = torch.FloatTensor(rewards).to(model.device) 199 | masks = 1 - torch.FloatTensor(dones).to(model.device) 200 | 201 | # The loss is the TD loss of Q and the update target, so we need to calculate Q 202 | states = model.get_state(to_tensor(obs)).squeeze(0) 203 | Q = model.get_Q(states) 204 | 205 | # the update target contains Q_next, but for stable learning we use prime network for this 206 | next_states_prime = model_prime.get_state(to_tensor(next_obs)).squeeze(0) 207 | next_Q_prime = model_prime.get_Q(next_states_prime) # detach? 
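# not strictly needed: gt is detached before the TD error below, so no gradient flows into the prime network, which is only refreshed via load_state_dict every freeze_interval steps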
208 | 209 | # Additionally, we need the beta probabilities of the next state 210 | next_states = model.get_state(to_tensor(next_obs)).squeeze(0) 211 | next_termination_probs = model.get_terminations(next_states).detach() 212 | next_options_term_prob = next_termination_probs[batch_idx, options] 213 | 214 | # Now we can calculate the update target gt 215 | gt = rewards + masks * args.gamma * \ 216 | ((1 - next_options_term_prob) * next_Q_prime[batch_idx, options] + next_options_term_prob * next_Q_prime.max(dim=-1)[0]) 217 | 218 | # to update Q we want to use the actual network, not the prime 219 | td_err = (Q[batch_idx, options] - gt.detach()).pow(2).mul(0.5).mean() 220 | return td_err 221 | 222 | def actor_loss(obs, option, logp, entropy, reward, done, next_obs, model, model_prime, args): 223 | state = model.get_state(to_tensor(obs)) 224 | next_state = model.get_state(to_tensor(next_obs)) 225 | next_state_prime = model_prime.get_state(to_tensor(next_obs)) 226 | 227 | option_term_prob = model.get_terminations(state)[:, option] 228 | next_option_term_prob = model.get_terminations(next_state)[:, option].detach() 229 | 230 | Q = model.get_Q(state).detach().squeeze() 231 | next_Q_prime = model_prime.get_Q(next_state_prime).detach().squeeze() 232 | 233 | # Target update gt 234 | gt = reward + (1 - done) * args.gamma * \ 235 | ((1 - next_option_term_prob) * next_Q_prime[option] + next_option_term_prob * next_Q_prime.max(dim=-1)[0]) 236 | 237 | # The termination loss 238 | termination_loss = option_term_prob * (Q[option].detach() - Q.max(dim=-1)[0].detach() + args.termination_reg) * (1 - done) 239 | 240 | # actor-critic policy gradient with entropy regularization 241 | policy_loss = -logp * (gt.detach() - Q[option]) - args.entropy_reg * entropy 242 | actor_loss = termination_loss + policy_loss 243 | return actor_loss 244 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import numpy as np 3 | import torch 4 | 5 | from gym.wrappers import AtariPreprocessing, TransformReward 6 | from gym.wrappers import FrameStack as FrameStack_ 7 | 8 | from fourrooms import Fourrooms 9 | 10 | 11 | class LazyFrames(object): 12 | def __init__(self, frames): 13 | self._frames = frames 14 | 15 | def __array__(self, dtype=None): 16 | out = np.concatenate(self._frames, axis=0) 17 | if dtype is not None: 18 | out = out.astype(dtype) 19 | return out 20 | 21 | def __len__(self): 22 | return len(self.__array__()) 23 | 24 | def __getitem__(self, i): 25 | return self.__array__()[i] 26 | 27 | 28 | class FrameStack(FrameStack_): 29 | def __init__(self, env, k): 30 | FrameStack_.__init__(self, env, k) 31 | 32 | def _get_ob(self): 33 | assert len(self.frames) == self.k 34 | return LazyFrames(list(self.frames)) 35 | 36 | def make_env(env_name): 37 | 38 | if env_name == 'fourrooms': 39 | return Fourrooms(), False 40 | 41 | env = gym.make(env_name) 42 | is_atari = hasattr(gym.envs, 'atari') and isinstance(env.unwrapped, gym.envs.atari.atari_env.AtariEnv) 43 | if is_atari: 44 | env = AtariPreprocessing(env, grayscale_obs=True, scale_obs=True, terminal_on_life_loss=True) 45 | env = TransformReward(env, lambda r: np.clip(r, -1, 1)) 46 | env = FrameStack(env, 4) 47 | return env, is_atari 48 | 49 | def to_tensor(obs): 50 | obs = np.asarray(obs) 51 | obs = torch.from_numpy(obs).float() 52 | return obs 53 | --------------------------------------------------------------------------------
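For reference, here is a minimal evaluation sketch that is not part of the repository: it shows how the saved four-rooms checkpoints can be loaded and rolled out with greedy option selection using the components above. The checkpoint path and the `num_options=2` setting mirror the defaults in `main.py`; the file name `eval_fourrooms.py` and the rollout loop itself are illustrative assumptions.

```
# eval_fourrooms.py -- illustrative sketch only, not a file in this repository.
# Loads a checkpoint written by main.py (e.g. models/option_critic_seed=0_1k)
# and rolls out one episode in Fourrooms with greedy option selection.
# Assumes the checkpoint was trained with the default --num-options 2.
import torch

from fourrooms import Fourrooms
from option_critic import OptionCriticFeatures
from utils import to_tensor

env = Fourrooms()
option_critic = OptionCriticFeatures(
    in_features=env.observation_space.shape[0],  # one-hot state vector over empty cells
    num_actions=env.action_space.n,              # 4 movement actions
    num_options=2,                               # must match the value used during training
    testing=True,                                # use eps_test instead of the decay schedule
)

checkpoint = torch.load('models/option_critic_seed=0_1k', map_location='cpu')
option_critic.load_state_dict(checkpoint['model_params'])
env.goal = checkpoint['goal_state']  # evaluate against the goal in use when the model was saved

obs = env.reset()
state = option_critic.get_state(to_tensor(obs))
current_option = option_critic.greedy_option(state)
done, ep_steps, reward = False, 0, 0.0

while not done and ep_steps < 1000:
    # sample an action from the current option's intra-option policy
    action, _, _ = option_critic.get_action(state, current_option)
    obs, reward, done, _ = env.step(action)
    state = option_critic.get_state(to_tensor(obs))
    # terminate the option stochastically and, if so, pick the greedy next option
    terminate, greedy_option = option_critic.predict_option_termination(state, current_option)
    if terminate:
        current_option = greedy_option
    ep_steps += 1

print(f"steps={ep_steps}, reached goal: {reward > 0}")
```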