├── README.md
├── predictron.py
└── train.py


/README.md:
--------------------------------------------------------------------------------
# Predictron

This is an implementation of "The Predictron: End-To-End Learning and Planning" (http://arxiv.org/abs/1612.08810) in Chainer. `train.py` trains the model on the random-maze task: predicting, for each cell on the main diagonal of a randomly generated maze, whether it is connected to the bottom-right corner.

# How to run

Install Chainer and NumPy, then run `python train.py`. Use `python train.py --help` to list the available options (maze size, number of model steps, number of channels, GPU ID, and so on).
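
The model can also be driven directly from Python. Below is a minimal sketch with placeholder data; the shapes follow `phi()` and `check_connectivity_to_right_bottom()` in `train.py` (3-channel maze encodings, one connectivity target per diagonal cell), and a Chainer 1.x environment is assumed since the layers are called through the `test=` flag.

```python
import numpy as np

from predictron import Predictron

# A Predictron with a 16-step internal model for 20x20 mazes (20 tasks)
model = Predictron(n_tasks=20, n_channels=32, model_steps=16)

x = np.random.rand(4, 3, 20, 20).astype(np.float32)  # placeholder maze encodings
t = np.zeros((4, 20), dtype=np.float32)  # placeholder connectivity targets

# Losses for the k-step preturns (g^k) and the lambda-preturn (g^lambda)
g_k_loss, g_lambda_loss = model.supervised_loss(x, t)
print(g_k_loss.data, g_lambda_loss.data)
```

`train.py` does the same with real maze batches and alternates supervised updates with optional unsupervised consistency updates (`--n-unsupervised-updates`).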

--------------------------------------------------------------------------------
/predictron.py:
--------------------------------------------------------------------------------
import inspect

import chainer
from chainer import functions as F
from chainer import links as L


class Sequence(chainer.ChainList):
    # Applies layers (links and functions) in order, forwarding the Chainer 1.x
    # `test` flag only to callables that accept it (e.g. BatchNormalization)

    def __init__(self, *layers):
        self.layers = layers
        links = [layer for layer in layers if isinstance(layer, chainer.Link)]
        super().__init__(*links)

    def __call__(self, x, test):
        h = x
        for layer in self.layers:
            argnames = inspect.getfullargspec(layer)[0]
            if 'test' in argnames:
                h = layer(h, test=test)
            else:
                h = layer(h)
        return h


class PredictronCore(chainer.Chain):
    # One step of the abstract model: maps a state to the next state together
    # with per-task reward, discount (gamma) and lambda predictions

    def __init__(self, n_tasks, n_channels):
        super().__init__(
            state2hidden=Sequence(
                L.Convolution2D(n_channels, n_channels, ksize=3, pad=1),
                L.BatchNormalization(n_channels),
                F.relu,
            ),
            hidden2nextstate=Sequence(
                L.Convolution2D(n_channels, n_channels, ksize=3, pad=1),
                L.BatchNormalization(n_channels),
                F.relu,
                L.Convolution2D(n_channels, n_channels, ksize=3, pad=1),
                L.BatchNormalization(n_channels),
                F.relu,
            ),
            hidden2reward=Sequence(
                L.Linear(None, n_channels),
                L.BatchNormalization(n_channels),
                F.relu,
                L.Linear(n_channels, n_tasks),
            ),
            hidden2gamma=Sequence(
                L.Linear(None, n_channels),
                L.BatchNormalization(n_channels),
                F.relu,
                L.Linear(n_channels, n_tasks),
                F.sigmoid,
            ),
            hidden2lambda=Sequence(
                L.Linear(None, n_channels),
                L.BatchNormalization(n_channels),
                F.relu,
                L.Linear(n_channels, n_tasks),
                F.sigmoid,
            ),
        )

    def __call__(self, x, test):
        hidden = self.state2hidden(x, test=test)
        # No skip
        nextstate = self.hidden2nextstate(hidden, test=test)
        reward = self.hidden2reward(hidden, test=test)
        gamma = self.hidden2gamma(hidden, test=test)
        # lambda doesn't backprop errors to states
        lmbda = self.hidden2lambda(
            chainer.Variable(hidden.data), test=test)
        return nextstate, reward, gamma, lmbda


class Predictron(chainer.Chain):

    def __init__(self, n_tasks, n_channels, model_steps,
                 use_reward_gamma=True, use_lambda=True, usage_weighting=True):
        self.model_steps = model_steps
        self.use_reward_gamma = use_reward_gamma
        self.use_lambda = use_lambda
        self.usage_weighting = usage_weighting
        super().__init__(
            obs2state=Sequence(
                L.Convolution2D(None, n_channels, ksize=3, pad=1),
                L.BatchNormalization(n_channels),
                F.relu,
                L.Convolution2D(n_channels, n_channels, ksize=3, pad=1),
                L.BatchNormalization(n_channels),
                F.relu,
            ),
            core=PredictronCore(n_tasks=n_tasks, n_channels=n_channels),
            state2value=Sequence(
                L.Linear(None, n_channels),
                L.BatchNormalization(n_channels),
                F.relu,
                L.Linear(n_channels, n_tasks),
            ),
        )

    def unroll(self, x, test):
        # Compute g^k and lambda^k for k=0,...,K, where
        # g^k = r^1 + gamma^1 r^2 + ... + (gamma^1 ... gamma^{k-1}) r^k
        #       + (gamma^1 ... gamma^k) v^k,
        # accumulated below via reward_sum and gamma_prod
        g_k = []
        lambda_k = []
        state = self.obs2state(x, test=test)
        g_k.append(self.state2value(state, test=test))  # g^0 = v^0
        reward_sum = 0
        gamma_prod = 1
        for k in range(self.model_steps):
            state, reward, gamma, lmbda = self.core(state, test=test)
            if not self.use_reward_gamma:
                reward = 0
                gamma = 1
            if not self.use_lambda:
                lmbda = 1
            lambda_k.append(lmbda)  # lambda^k
            v = self.state2value(state, test=test)
            reward_sum += gamma_prod * reward
            gamma_prod *= gamma
            g_k.append(reward_sum + gamma_prod * v)  # g^{k+1}
        lambda_k.append(0)  # lambda^K = 0
        # Compute g^lambda = sum_k w^k g^k,
        # where w^k = (1 - lambda^k) * lambda^0 ... lambda^{k-1}
        lambda_prod = 1
        g_lambda = 0
        w_k = []
        for k in range(self.model_steps + 1):
            w = (1 - lambda_k[k]) * lambda_prod
            w_k.append(w)
            lambda_prod *= lambda_k[k]
            # g^lambda doesn't backprop errors to g^k
            g_lambda += w * chainer.Variable(g_k[k].data)
        return g_k, g_lambda, w_k

    def supervised_loss(self, x, t):
        g_k, g_lambda, w_k = self.unroll(x, test=False)
        if self.usage_weighting:
            g_k_loss = sum(F.sum(w * (g - t) ** 2) / x.shape[0]
                           for g, w in zip(g_k, w_k))
        else:
            g_k_loss = sum(F.mean_squared_error(g, t) for g in g_k) / len(g_k)
        g_lambda_loss = F.mean_squared_error(g_lambda, t)
        return g_k_loss, g_lambda_loss

    def unsupervised_loss(self, x):
        g_k, g_lambda, w_k = self.unroll(x, test=False)
        # Only update g^k: cut the graph so the target g^lambda gets no gradient
        g_lambda.creator = None
        if self.usage_weighting:
            return sum(F.sum(w * (g - g_lambda) ** 2) / x.shape[0]
                       for g, w in zip(g_k, w_k))
        else:
            return sum(F.mean_squared_error(g, g_lambda) for g in g_k) / len(g_k)

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import argparse
from distutils.util import strtobool

import chainer
from chainer import optimizers
import numpy as np

from predictron import Predictron


def generate_maze(size=20):
    # 1 = wall, 0 = empty
    # A maze contains about 30% walls; the top-left and bottom-right cells
    # are always kept empty
    maze = (np.random.rand(size, size) < 0.3).astype(np.float32)
    maze[0, 0] = 0
    maze[-1, -1] = 0
    return maze


def check_connectivity_to_right_bottom(maze):
    # Depth-first flood fill from the bottom-right cell; returns, for each cell
    # on the main diagonal, whether it is connected to the bottom-right corner
    maze = np.pad(maze, 1, 'constant', constant_values=1)  # Pad with walls
    visited = np.zeros_like(maze)
    dst = (maze.shape[0] - 2, maze.shape[1] - 2)

    def search(pos):
        if visited[pos]:
            return
        if maze[pos]:
            return
        visited[pos] = 1
        search((pos[0] + 1, pos[1]))
        search((pos[0] - 1, pos[1]))
        search((pos[0], pos[1] + 1))
        search((pos[0], pos[1] - 1))

    search(dst)
    return np.diagonal(visited[1:-1, 1:-1])


def phi(maze):
    # 3 input channels: the wall indicator, its negation, and a constant plane
    # of ones
    return np.asarray([maze, -maze, np.ones_like(maze)])


def generate_supervised_batch(maze_size=20, batch_size=100):
    xs = []
    ts = []
    for b in range(batch_size):
        maze = generate_maze(maze_size)
        connectivity = check_connectivity_to_right_bottom(maze)
        xs.append(phi(maze))
        ts.append(connectivity)
    x = np.asarray(xs, dtype=np.float32)
    t = np.asarray(ts, dtype=np.float32)
    return x, t


def generate_unsupervised_batch(maze_size=20, batch_size=100):
    xs = []
    for b in range(batch_size):
        maze = generate_maze(maze_size)
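        # Unlike generate_supervised_batch, no connectivity targets are
        # computed here; Predictron.unsupervised_loss regresses each g^k
        # toward the model's own g^lambda estimate instead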
        xs.append(phi(maze))
    x = np.asarray(xs, dtype=np.float32)
    return x


def main():

    parser = argparse.ArgumentParser(description='Predictron on random mazes')
    parser.add_argument('--batchsize', '-b', type=int, default=100,
                        help='Number of mazes in each mini-batch')
    parser.add_argument('--max-iter', type=int, default=10000,
                        help='Number of iterations to run')
    parser.add_argument('--n-model-steps', type=int, default=16,
                        help='Number of model steps')
    parser.add_argument('--n-channels', type=int, default=32,
                        help='Number of channels for hidden units')
    parser.add_argument('--maze-size', type=int, default=20,
                        help='Size of random mazes')
    # argparse's type=bool would treat any non-empty string as True, so parse
    # boolean flags with strtobool instead
    parser.add_argument('--use-reward-gamma', type=strtobool, default=True,
                        help='Use reward and gamma')
    parser.add_argument('--use-lambda', type=strtobool, default=True,
                        help='Use lambda-network')
    parser.add_argument('--usage-weighting', type=strtobool, default=True,
                        help='Enable usage weighting')
    parser.add_argument('--n-unsupervised-updates', type=int, default=0,
                        help='Number of unsupervised updates per supervised '
                             'update')
    parser.add_argument('--gpu', '-g', type=int, default=-1,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--out', '-o', default='result',
                        help='Directory to output the result')
    args = parser.parse_args()

    # chainer.set_debug(True)
    model = Predictron(n_tasks=args.maze_size, n_channels=args.n_channels,
                       model_steps=args.n_model_steps,
                       use_reward_gamma=args.use_reward_gamma,
                       use_lambda=args.use_lambda,
                       usage_weighting=args.usage_weighting)
    if args.gpu >= 0:
        chainer.cuda.get_device(args.gpu).use()
        model.to_gpu(args.gpu)
    opt = optimizers.Adam()
    opt.setup(model)

    for i in range(args.max_iter):
        # Supervised update toward the true connectivity targets
        x, t = generate_supervised_batch(
            maze_size=args.maze_size, batch_size=args.batchsize)
        if args.gpu >= 0:
            x = chainer.cuda.to_gpu(x)
            t = chainer.cuda.to_gpu(t)
        model.cleargrads()
        g_k_loss, g_lambda_loss = model.supervised_loss(x, t)
        supervised_loss = g_k_loss + g_lambda_loss
        supervised_loss.backward()
        opt.update()
        # Optional unsupervised (consistency) updates on unlabeled mazes
        for _ in range(args.n_unsupervised_updates):
            x = generate_unsupervised_batch(
                maze_size=args.maze_size, batch_size=args.batchsize)
            if args.gpu >= 0:
                x = chainer.cuda.to_gpu(x)
            model.cleargrads()
            unsupervised_loss = model.unsupervised_loss(x)
            unsupervised_loss.backward()
            opt.update()
        # iteration, g^k loss, g^lambda loss, g^lambda RMSE
        print(i, g_k_loss.data, g_lambda_loss.data,
              (g_lambda_loss.data ** 0.5))


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------