├── .gitignore
├── README.md
└── Upside Down Reinforcement Learning.py

/.gitignore:
--------------------------------------------------------------------------------
.ipynb_checkpoints/
__pycache__/
runs/
checkpoints/
--------------------------------------------------------------------------------

/README.md:
--------------------------------------------------------------------------------
# Upside-Down-Reinforcement-Learning
A PyTorch implementation of Schmidhuber's Upside-Down Reinforcement Learning (UDRL).

Link to the paper with the theory: https://arxiv.org/pdf/1912.02875.pdf

Link to the paper with implementation details and results: https://arxiv.org/pdf/1912.02877.pdf

Use as you wish. If you find interesting results or sets of hyperparameters that work well for particular environments, tweet me (@mfharoon) or email me (hshams@hotmail.co.uk) and I will share them here. Thanks!

### Working Hyper-Parameters

#### CartPole
replay_size = 600
last_few = 50
batch_size = 64
n_warm_up_episodes = 50
n_episodes_per_iter = 50
n_updates_per_iter = 100
command_scale = 0.02
lr = 0.001
--------------------------------------------------------------------------------

/Upside Down Reinforcement Learning.py:
--------------------------------------------------------------------------------
import gym
import time
import torch
import numpy as np
from copy import deepcopy
import torch.nn.functional as F

# Uses the classic gym API (env.reset() returns obs, env.step() returns a 4-tuple)
env = gym.make('CartPole-v1')

# Random policy used for warm-up episodes; ignores the observation and command
def random_policy(obs, command):
    return np.random.randint(env.action_space.n)

# Visualise agent function: roll out the policy n times while rendering,
# decrementing the command (desired return, remaining horizon) after every step
def visualise_agent(policy, command, n=5):
    try:
        for trial_i in range(n):
            current_command = deepcopy(command)
            observation = env.reset()
            done = False
            t = 0
            episode_return = 0
            while not done:
                env.render()
                action = policy(torch.tensor([observation]).double(), torch.tensor([current_command]).double())
                observation, reward, done, info = env.step(action)
                episode_return += reward
                current_command[0] -= reward
                current_command[1] = max(1, current_command[1] - 1)
                t += 1
            env.render()
            time.sleep(1.5)
            print("Episode {} finished after {} timesteps. Return = {}".format(trial_i, t, episode_return))
        env.close()
    except KeyboardInterrupt:
        env.close()

# Behaviour function - a fully connected neural network mapping (observation, command) to action logits
class FCNN_AGENT(torch.nn.Module):
    def __init__(self, command_scale):
        super().__init__()
        hidden_size = 64
        self.command_scale = command_scale
        self.observation_embedding = torch.nn.Sequential(
            torch.nn.Linear(int(np.prod(env.observation_space.shape)), hidden_size),
            torch.nn.Tanh()
        )
        self.command_embedding = torch.nn.Sequential(
            torch.nn.Linear(2, hidden_size),
            torch.nn.Sigmoid()
        )
        self.to_output = torch.nn.Sequential(
            torch.nn.Linear(hidden_size, hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_size, env.action_space.n)
        )

    def forward(self, observation, command):
        obs_embedding = self.observation_embedding(observation)
        cmd_embedding = self.command_embedding(command * self.command_scale)
        embedding = torch.mul(obs_embedding, cmd_embedding)
        action_prob_logits = self.to_output(embedding)
        return action_prob_logits

    def create_optimizer(self, lr):
        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr)
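
# Architecture note: the command (desired return, desired horizon) is scaled by
# command_scale, embedded through a sigmoid layer, and multiplied element-wise with
# the tanh observation embedding, so the command acts as a gate on the state
# features before the output head (cf. the gating scheme described in arXiv:1912.02877).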

# Fill the replay buffer with more experience by rolling out the given policy
def collect_experience(policy, replay_buffer, replay_size, last_few, n_episodes=100, log_to_tensorboard=True):
    global i_episode
    init_replay_buffer = deepcopy(replay_buffer)
    try:
        for _ in range(n_episodes):
            command = sample_command(init_replay_buffer, last_few)
            if log_to_tensorboard: writer.add_scalar('Command desired reward/Episode', command[0], i_episode)  # log the sampled desired return
            if log_to_tensorboard: writer.add_scalar('Command horizon/Episode', command[1], i_episode)  # log the sampled horizon
            observation = env.reset()
            episode_mem = {'observation': [],
                           'action': [],
                           'reward': []}
            done = False
            while not done:
                action = policy(torch.tensor([observation]).double(), torch.tensor([command]).double())
                new_observation, reward, done, info = env.step(action)

                episode_mem['observation'].append(observation)
                episode_mem['action'].append(action)
                episode_mem['reward'].append(reward)

                observation = new_observation
                command[0] -= reward
                command[1] = max(1, command[1] - 1)
            episode_mem['return'] = sum(episode_mem['reward'])
            episode_mem['episode_len'] = len(episode_mem['observation'])
            replay_buffer.append(episode_mem)
            i_episode += 1
            if log_to_tensorboard: writer.add_scalar('Return/Episode', episode_mem['return'], i_episode)  # log the achieved return
            print("Episode {} finished after {} timesteps. Return = {}".format(i_episode, episode_mem['episode_len'], episode_mem['return']))
        env.close()
    except KeyboardInterrupt:
        env.close()
    # Keep only the replay_size highest-return episodes, sorted by return
    replay_buffer = sorted(replay_buffer, key=lambda x: x['return'])[-replay_size:]
    return replay_buffer

# Sample an exploratory command from the last_few highest-return episodes
def sample_command(replay_buffer, last_few):
    if len(replay_buffer) == 0:
        return [1, 1]
    else:
        command_samples = replay_buffer[-last_few:]
        lengths = [mem['episode_len'] for mem in command_samples]
        returns = [mem['return'] for mem in command_samples]
        mean_return, std_return = np.mean(returns), np.std(returns)
        command_horizon = np.mean(lengths)
        desired_reward = np.random.uniform(mean_return, mean_return + std_return)
        return [desired_reward, command_horizon]
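
# Training note: train_net below casts UDRL as supervised learning. For a random
# timestep t1 of a stored episode it forms the input (observation at t1, command =
# (return actually achieved from t1 to the end of the episode, remaining episode
# length)), uses the action taken at t1 as the label, and minimises cross-entropy,
# so the network learns which action was taken given a state and a (return, horizon) command.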

# Improve the behaviour function by training on the replay buffer
def train_net(policy_net, replay_buffer, n_updates=100, batch_size=64, log_to_tensorboard=True):
    global i_updates
    all_costs = []
    for i in range(n_updates):
        batch_observations = np.zeros((batch_size, int(np.prod(env.observation_space.shape))))
        batch_commands = np.zeros((batch_size, 2))
        batch_label = np.zeros((batch_size))
        for b in range(batch_size):
            sample_episode = np.random.randint(0, len(replay_buffer))
            sample_t1 = np.random.randint(0, len(replay_buffer[sample_episode]['observation']))
            sample_t2 = len(replay_buffer[sample_episode]['observation'])
            ##sample_t2 = np.random.randint(sample_t1+1, len(replay_buffer[sample_episode]['observation'])+1)
            sample_horizon = sample_t2 - sample_t1
            sample_mem = replay_buffer[sample_episode]['observation'][sample_t1]
            sample_desired_reward = sum(replay_buffer[sample_episode]['reward'][sample_t1:sample_t2])
            label = replay_buffer[sample_episode]['action'][sample_t1]
            batch_observations[b] = sample_mem
            batch_commands[b] = [sample_desired_reward, sample_horizon]
            batch_label[b] = label
        batch_observations = torch.tensor(batch_observations).double()
        batch_commands = torch.tensor(batch_commands).double()
        batch_label = torch.tensor(batch_label).long()
        pred = policy_net(batch_observations, batch_commands)
        cost = F.cross_entropy(pred, batch_label)
        if log_to_tensorboard: writer.add_scalar('Cost/NN update', cost.item(), i_updates)  # log the training loss
        all_costs.append(cost.item())
        cost.backward()
        policy_net.optimizer.step()
        policy_net.optimizer.zero_grad()
        i_updates += 1
    return np.mean(all_costs)

# Return a greedy policy (argmax over action probabilities) from a given network
def create_greedy_policy(policy_network):
    def policy(obs, command):
        action_logits = policy_network(obs, command)
        action_probs = F.softmax(action_logits, dim=-1)
        action = np.argmax(action_probs.detach().numpy())
        return action
    return policy

# Return a stochastic policy (sample from the action distribution) from a given network
def create_stochastic_policy(policy_network):
    def policy(obs, command):
        action_logits = policy_network(obs, command)
        action_probs = F.softmax(action_logits, dim=-1)
        action = torch.distributions.Categorical(action_probs).sample().item()
        return action
    return policy

# Initialize vars
i_episode = 0  # number of episodes collected so far
i_updates = 0  # number of parameter updates to the neural network so far
replay_buffer = []
log_to_tensorboard = False

## HYPERPARAMS
replay_size = 700
last_few = 50
batch_size = 256
n_warm_up_episodes = 50
n_episodes_per_iter = 25
n_updates_per_iter = 15
command_scale = 0.02
lr = 0.001

# Initialize behaviour function
agent = FCNN_AGENT(command_scale).double()
agent.create_optimizer(lr)

stochastic_policy = create_stochastic_policy(agent)
greedy_policy = create_greedy_policy(agent)

# SET UP TRAINING VISUALISATION
if log_to_tensorboard: from torch.utils.tensorboard import SummaryWriter
if log_to_tensorboard: writer = SummaryWriter()  # used to plot the model's performance in TensorBoard

# Collect warm-up episodes with the random policy, then do an initial round of training
replay_buffer = collect_experience(random_policy, replay_buffer, replay_size, last_few, n_warm_up_episodes, log_to_tensorboard)
train_net(agent, replay_buffer, n_updates_per_iter, batch_size, log_to_tensorboard)

# Alternately collect experience and train the behaviour function for a given number of iterations
n_iters = 1000
for i in range(n_iters):
    replay_buffer = collect_experience(stochastic_policy, replay_buffer, replay_size, last_few, n_episodes_per_iter, log_to_tensorboard)
    train_net(agent, replay_buffer, n_updates_per_iter, batch_size, log_to_tensorboard)

# Visualise the final trained agent with the greedy policy
visualise_agent(greedy_policy, command=[150, 400], n=5)

# Visualise the final trained agent with the stochastic policy
visualise_agent(stochastic_policy, command=[150, 400], n=5)
--------------------------------------------------------------------------------
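
A minimal evaluation sketch, not part of the repository: it assumes the `env` and `greedy_policy` objects from the script above are in scope, sweeps a few commanded returns (reusing the desired return as the horizon, which is a rough but reasonable choice for CartPole where the reward is one per timestep), and reports the mean return actually achieved for each command.

# Hypothetical helper: measure how closely the trained behaviour function obeys its commands.
import numpy as np
import torch
from copy import deepcopy

def evaluate_command(policy, command, n_episodes=20):
    returns = []
    for _ in range(n_episodes):
        current_command = deepcopy(command)
        observation = env.reset()
        done = False
        episode_return = 0
        while not done:
            action = policy(torch.tensor([observation]).double(),
                            torch.tensor([current_command]).double())
            observation, reward, done, info = env.step(action)
            episode_return += reward
            current_command[0] -= reward                          # remaining desired return
            current_command[1] = max(1, current_command[1] - 1)   # remaining horizon
        returns.append(episode_return)
    return np.mean(returns)

for desired_return in [50, 100, 150, 200]:
    achieved = evaluate_command(greedy_policy, command=[desired_return, desired_return])
    print("commanded return {:>3} -> mean achieved return {:.1f}".format(desired_return, achieved))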