├── .gitignore
├── README.md
└── Upside Down Reinforcement Learning.py
/.gitignore:
--------------------------------------------------------------------------------
.ipynb_checkpoints/
__pycache__/
runs/
checkpoints/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Upside-Down-Reinforcement-Learning
Implementation of Schmidhuber's Upside Down Reinforcement Learning paper.

Link to the paper with the theory: https://arxiv.org/pdf/1912.02875.pdf

Link to the paper with implementation details and results: https://arxiv.org/pdf/1912.02877.pdf

Use as you wish. Tweet (@mfharoon) or email (hshams@hotmail.co.uk) me any interesting results you find and any sets of hyperparameters that work for particular environments, and I will share them here. Thanks!

### Working Hyper-Parameters

#### CartPole
- replay_size = 600
- last_few = 50
- batch_size = 64
- n_warm_up_episodes = 50
- n_episodes_per_iter = 50
- n_updates_per_iter = 100
- command_scale = 0.02
- lr = 0.001
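These map directly onto the variables in the `## HYPERPARAMS` block of `Upside Down Reinforcement Learning.py`. A sketch of that block with the working CartPole values substituted (the comments summarise how each value is used in the script):

```python
## HYPERPARAMS
replay_size = 600          # keep only the 600 highest-return episodes in the replay buffer
last_few = 50              # sample exploratory commands from the 50 best episodes
batch_size = 64            # transitions per gradient update
n_warm_up_episodes = 50    # episodes collected with the random policy before training starts
n_episodes_per_iter = 50   # episodes collected with the stochastic policy each iteration
n_updates_per_iter = 100   # gradient updates per iteration
command_scale = 0.02       # commands are multiplied by this before entering the network
lr = 0.001                 # Adam learning rate
```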
--------------------------------------------------------------------------------
/Upside Down Reinforcement Learning.py:
--------------------------------------------------------------------------------
import gym
import time
import torch
import numpy as np
from copy import deepcopy
import torch.nn.functional as F

env = gym.make('CartPole-v1')
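# Note: this script assumes the classic gym API (pre-0.26), where env.reset() returns
# only the observation and env.step() returns (observation, reward, done, info).
# With gym>=0.26 or gymnasium, reset()/step() return extra values and the calls below
# would need adjusting.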
def random_policy(obs, command):
    return np.random.randint(env.action_space.n)
#Visualise agent function
def visualise_agent(policy, command, n=5):
    try:
        for trial_i in range(n):
            current_command = deepcopy(command)
            observation = env.reset()
            done=False
            t=0
            episode_return=0
            while not done:
                env.render()
                action = policy(torch.tensor([observation]).double(), torch.tensor([current_command]).double())
                observation, reward, done, info = env.step(action)
                episode_return+=reward
                current_command[0]-= reward
                current_command[1] = max(1, current_command[1]-1)
                t+=1
            env.render()
            time.sleep(1.5)
            print("Episode {} finished after {} timesteps. Return = {}".format(trial_i, t, episode_return))
        env.close()
    except KeyboardInterrupt:
        env.close()
#Behaviour function - Neural Network
class FCNN_AGENT(torch.nn.Module):
    def __init__(self, command_scale):
        super().__init__()
        hidden_size=64
        self.command_scale=command_scale
        self.observation_embedding = torch.nn.Sequential(
            torch.nn.Linear(np.prod(env.observation_space.shape), hidden_size),
            torch.nn.Tanh()
        )
        self.command_embedding = torch.nn.Sequential(
            torch.nn.Linear(2, hidden_size),
            torch.nn.Sigmoid()
        )
        self.to_output = torch.nn.Sequential(
            torch.nn.Linear(hidden_size, hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_size, env.action_space.n)
        )

    def forward(self, observation, command):
        obs_embedding = self.observation_embedding(observation)
        cmd_embedding = self.command_embedding(command*self.command_scale)
        embedding = torch.mul(obs_embedding, cmd_embedding)
        action_prob_logits = self.to_output(embedding)
        return action_prob_logits

    def create_optimizer(self, lr):
        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr)
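# The command [desired return, desired horizon] is scaled, pushed through a sigmoid
# layer and multiplied element-wise with the observation embedding, so it acts as a
# gate on the observation features (a multiplicative interaction in the spirit of the
# implementation paper's behaviour function). Illustrative shape check for CartPole-v1
# (4-dimensional observations, 2 discrete actions):
#   agent = FCNN_AGENT(command_scale=0.02).double()
#   logits = agent(torch.zeros(1, 4).double(), torch.tensor([[200.0, 200.0]]).double())
#   logits.shape  # torch.Size([1, 2])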
#Fill the replay buffer with more experience
def collect_experience(policy, replay_buffer, replay_size, last_few, n_episodes=100, log_to_tensorboard=True):
    global i_episode
    init_replay_buffer = deepcopy(replay_buffer)
    try:
        for _ in range(n_episodes):
            command = sample_command(init_replay_buffer, last_few)
            if log_to_tensorboard: writer.add_scalar('Command desired reward/Episode', command[0], i_episode) # log the sampled desired return
            if log_to_tensorboard: writer.add_scalar('Command horizon/Episode', command[1], i_episode) # log the sampled desired horizon
            observation = env.reset()
            episode_mem = {'observation':[],
                           'action':[],
                           'reward':[],}
            done=False
            while not done:
                action = policy(torch.tensor([observation]).double(), torch.tensor([command]).double())
                new_observation, reward, done, info = env.step(action)

                episode_mem['observation'].append(observation)
                episode_mem['action'].append(action)
                episode_mem['reward'].append(reward)

                observation=new_observation
                command[0]-= reward
                command[1] = max(1, command[1]-1)
            episode_mem['return']=sum(episode_mem['reward'])
            episode_mem['episode_len']=len(episode_mem['observation'])
            replay_buffer.append(episode_mem)
            i_episode+=1
            if log_to_tensorboard: writer.add_scalar('Return/Episode', sum(episode_mem['reward']), i_episode) # log the achieved episode return
            print("Episode {} finished after {} timesteps. Return = {}".format(i_episode, len(episode_mem['observation']), sum(episode_mem['reward'])))
        env.close()
    except KeyboardInterrupt:
        env.close()
    replay_buffer = sorted(replay_buffer, key=lambda x:x['return'])[-replay_size:]
    return replay_buffer
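# The replay buffer is kept sorted by episode return and trimmed to the best
# `replay_size` episodes, so sample_command() below always draws its exploratory
# commands from the `last_few` highest-return episodes seen so far.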
#Sample exploratory command
def sample_command(replay_buffer, last_few):
    if len(replay_buffer)==0:
        return [1, 1]
    else:
        command_samples = replay_buffer[-last_few:]
        lengths = [mem['episode_len'] for mem in command_samples]
        returns = [mem['return'] for mem in command_samples]
        mean_return, std_return = np.mean(returns), np.std(returns)
        command_horizon = np.mean(lengths)
        desired_reward = np.random.uniform(mean_return, mean_return+std_return)
        return [desired_reward, command_horizon]
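# The desired horizon is the mean length of the `last_few` best episodes, and the
# desired return is drawn uniformly from [mean, mean + std] of their returns, i.e.
# the command asks for slightly more than the agent typically achieves. For example,
# if those episodes had mean return 120, std 30 and mean length 130, a sampled
# command might be roughly [135, 130].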
#Improve behaviour function by training on replay buffer
def train_net(policy_net, replay_buffer, n_updates=100, batch_size=64, log_to_tensorboard=True):
    global i_updates
    all_costs = []
    for i in range(n_updates):
        batch_observations = np.zeros((batch_size, np.prod(env.observation_space.shape)))
        batch_commands = np.zeros((batch_size, 2))
        batch_label = np.zeros((batch_size))
        for b in range(batch_size):
            sample_episode = np.random.randint(0, len(replay_buffer))
            sample_t1 = np.random.randint(0, len(replay_buffer[sample_episode]['observation']))
            sample_t2 = len(replay_buffer[sample_episode]['observation']) #episodic setting: t2 is the episode end
            ##sample_t2 = np.random.randint(sample_t1+1, len(replay_buffer[sample_episode]['observation'])+1) #alternative: sample a random segment end
            sample_horizon = sample_t2-sample_t1
            sample_mem = replay_buffer[sample_episode]['observation'][sample_t1]
            sample_desired_reward = sum(replay_buffer[sample_episode]['reward'][sample_t1:sample_t2])
            label = replay_buffer[sample_episode]['action'][sample_t1]
            batch_observations[b] = sample_mem
            batch_commands[b] = [sample_desired_reward, sample_horizon]
            batch_label[b] = label
        batch_observations = torch.tensor(batch_observations).double()
        batch_commands = torch.tensor(batch_commands).double()
        batch_label = torch.tensor(batch_label).long()
        pred = policy_net(batch_observations, batch_commands)
        cost = F.cross_entropy(pred, batch_label)
        if log_to_tensorboard: writer.add_scalar('Cost/NN update', cost.item(), i_updates) # log the training cost
        all_costs.append(cost.item())
        cost.backward()
        policy_net.optimizer.step()
        policy_net.optimizer.zero_grad()
        i_updates+=1
    return np.mean(all_costs)
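# Training treats UDRL as supervised learning: for a randomly chosen episode and time
# step t1, the input is the observation at t1 plus the command [return actually
# obtained from t1 to the episode end, remaining horizon], and the target is the
# action that was taken at t1, optimised with cross-entropy.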
#Return a greedy policy from a given network
def create_greedy_policy(policy_network):
    def policy(obs, command):
        action_logits = policy_network(obs, command)
        action_probs = F.softmax(action_logits, dim=-1)
        action = np.argmax(action_probs.detach().numpy())
        return action
    return policy

#Return a stochastic policy from a given network
def create_stochastic_policy(policy_network):
    def policy(obs, command):
        action_logits = policy_network(obs, command)
        action_probs = F.softmax(action_logits, dim=-1)
        action = torch.distributions.Categorical(action_probs).sample().item()
        return action
    return policy
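# The stochastic policy samples actions from the softmax distribution and is the one
# used to collect experience during training; the greedy policy takes the argmax and
# is only used to visualise the trained agent at the end of the script.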
#Initialize vars
i_episode=0 #number of episodes trained so far
i_updates=0 #number of parameter updates to the neural network so far
replay_buffer = []
log_to_tensorboard = False

## HYPERPARAMS
replay_size = 700
last_few = 50
batch_size = 256
n_warm_up_episodes = 50
n_episodes_per_iter = 25
n_updates_per_iter = 15
command_scale = 0.02
lr = 0.001
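# Note: these defaults differ from the "working" CartPole hyperparameters listed in
# the README (replay_size=600, batch_size=64, n_episodes_per_iter=50,
# n_updates_per_iter=100); swap those in to reproduce that configuration.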
# Initialize behaviour function
agent = FCNN_AGENT(command_scale).double()
agent.create_optimizer(lr)

stochastic_policy = create_stochastic_policy(agent)
greedy_policy = create_greedy_policy(agent)

# SET UP TRAINING VISUALISATION
if log_to_tensorboard: from torch.utils.tensorboard import SummaryWriter
if log_to_tensorboard: writer = SummaryWriter() # we will use this to show our model's performance on a graph using tensorboard

#Collect warm up episodes
replay_buffer = collect_experience(random_policy, replay_buffer, replay_size, last_few, n_warm_up_episodes, log_to_tensorboard)
train_net(agent, replay_buffer, n_updates_per_iter, batch_size, log_to_tensorboard)

#Collect experience and train behaviour function for given number of iterations
n_iters = 1000
for i in range(n_iters):
    replay_buffer = collect_experience(stochastic_policy, replay_buffer, replay_size, last_few, n_episodes_per_iter, log_to_tensorboard)
    train_net(agent, replay_buffer, n_updates_per_iter, batch_size, log_to_tensorboard)

#Visualise final trained agent with greedy policy
visualise_agent(greedy_policy, command=[150, 400], n=5)

#Visualise final trained agent with stochastic policy
visualise_agent(stochastic_policy, command=[150, 400], n=5)
--------------------------------------------------------------------------------