├── .gitignore
├── img
│   └── vin.png
├── data
│   └── gridworld_8.mat
├── utils.py
├── README.md
├── data.py
├── model.py
└── train_main.py

/.gitignore:
--------------------------------------------------------------------------------
#ignore
*.pyc

--------------------------------------------------------------------------------
/img/vin.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/onlytailei/Value-Iteration-Networks-PyTorch/HEAD/img/vin.png

--------------------------------------------------------------------------------
/data/gridworld_8.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/onlytailei/Value-Iteration-Networks-PyTorch/HEAD/data/gridworld_8.mat

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import numpy as np

# TensorFlow is only needed by the layout-conversion helpers at the bottom of
# this file; keep it optional so the PyTorch training script can run without it
try:
    import tensorflow as tf
except ImportError:
    tf = None

# helper methods to print a nice table (taken from CGT code)
def fmt_item(x, l):
    # right-align a single value in a field of width l
    if isinstance(x, np.ndarray):
        assert x.ndim == 0
        x = x.item()
    if isinstance(x, float):
        rep = "%g" % x
    else:
        rep = str(x)
    return " " * (l - len(rep)) + rep

def fmt_row(width, row):
    # format one table row, separating columns with " | "
    return " | ".join(fmt_item(x, width) for x in row)

def flipkernel(kern):
    # flip a convolution kernel along both spatial axes
    return kern[(slice(None, None, -1),) * 2 + (slice(None), slice(None))]

def theano_to_tf(tensor):
    # NCHW -> NHWC
    return tf.transpose(tensor, [0, 2, 3, 1])

def tf_to_theano(tensor):
    # NHWC -> NCHW
    return tf.transpose(tensor, [0, 3, 1, 2])

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# [Value Iteration Networks](https://arxiv.org/abs/1602.02867) in PyTorch

> Tamar, A., Wu, Y., Thomas, G., Levine, S., and Abbeel, P. _Value Iteration Networks_. Neural Information Processing Systems (NIPS) 2016

This repository contains an implementation of Value Iteration Networks (VIN) in PyTorch, based on the original [Theano implementation](https://github.com/avivt/VIN) by the authors and the [TensorFlow implementation](https://github.com/TheAbhiKumar/tensorflow-value-iteration-networks) by [Abhishek Kumar](https://github.com/TheAbhiKumar).

VIN won the Best Paper Award at NIPS 2016.

![Value Iteration Network and Module](img/vin.png)

## Dependencies
* Python 2.7
* PyTorch
* SciPy >= 0.18.1 (to load the data)

## Datasets

* The dataset comes from the [authors' repository](https://github.com/avivt/VIN/tree/master/data). The 8x8 GridWorld dataset is included in this repository for convenience, since it is small.
* utils.py and data.py are taken from [Abhishek Kumar's repository](https://github.com/TheAbhiKumar/tensorflow-value-iteration-networks).

## Training

```
python train_main.py
```

Several arguments, such as the learning rate, can be set on the command line; see the argparse definitions in train_main.py for the full list.
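For example, to change the learning rate: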
```
python train_main.py --lr 0.001
```

## References

* [Value Iteration Networks on arXiv](https://arxiv.org/abs/1602.02867)
* [Aviv Tamar's (author) implementation in Theano](https://github.com/avivt/VIN)
* [Abhishek Kumar's implementation in TensorFlow](https://github.com/TheAbhiKumar/tensorflow-value-iteration-networks)

--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
import numpy as np
import scipy.io as sio

def process_gridworld_data(input, imsize):
    # Load the GridWorld data from the given Matlab file and split it into a
    # shuffled training set and a test set. The file contains:
    #   batch_im_data:    flattened obstacle images
    #   batch_value_data: flattened value maps
    #   state_x_data / state_y_data: state variables (x, y position)
    #   batch_label_data: one-hot vector for the action (state difference)
    im_size = [imsize, imsize]
    matlab_data = sio.loadmat(input)
    im_data = matlab_data["batch_im_data"]
    im_data = (im_data - 1) / 255  # obstacles = 1, free space = 0
    value_data = matlab_data["batch_value_data"]
    state1_data = matlab_data["state_x_data"]
    state2_data = matlab_data["state_y_data"]
    label_data = matlab_data["batch_label_data"]
    ydata = label_data.astype('int8')
    Xim_data = im_data.astype('float32')
    Xim_data = Xim_data.reshape(-1, 1, im_size[0], im_size[1])
    Xval_data = value_data.astype('float32')
    Xval_data = Xval_data.reshape(-1, 1, im_size[0], im_size[1])
    # stack the obstacle image and the value map as two channels
    Xdata = np.append(Xim_data, Xval_data, axis=1)
    # Transpose NCHW -> NHWC (kept from the TensorFlow port;
    # train_main.py transposes back to NCHW for PyTorch).
    Xdata = np.transpose(Xdata, (0, 2, 3, 1))
    S1data = state1_data.astype('int8')
    S2data = state2_data.astype('int8')

    # 6/7 of the samples for training, the rest for testing
    all_training_samples = int(6 / 7.0 * Xdata.shape[0])
    training_samples = all_training_samples
    Xtrain = Xdata[0:training_samples]
    S1train = S1data[0:training_samples]
    S2train = S2data[0:training_samples]
    ytrain = ydata[0:training_samples]

    Xtest = Xdata[all_training_samples:]
    S1test = S1data[all_training_samples:]
    S2test = S2data[all_training_samples:]
    ytest = ydata[all_training_samples:]
    ytest = ytest.flatten()

    # shuffle the training set
    sortinds = np.random.permutation(training_samples)
    Xtrain = Xtrain[sortinds]
    S1train = S1train[sortinds]
    S2train = S2train[sortinds]
    ytrain = ytrain[sortinds]
    ytrain = ytrain.flatten()
    return Xtrain, S1train, S2train, ytrain, Xtest, S1test, S2test, ytest

--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# coding=utf-8
'''
Author: Tai Lei
Date: Thu 09 Mar 2017 05:38:33 PM WAT
Info: VIN module
'''
from __future__ import print_function
import numpy as np
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
import torch.autograd as autograd
import torch.nn.functional as F
# torch functional layers are used throughout

def agLT(x):
    # wrap a Python int in a LongTensor Variable, for use with index_select
    return autograd.Variable(torch.LongTensor([x]))

class VIN_Block(nn.Module):
    def __init__(self, arg):
        super(VIN_Block, self).__init__()
        self.k = arg.k
        self.ch_i = arg.ch_i
        self.ch_h = arg.ch_h
        self.ch_q = arg.ch_q
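        # Hyperparameters, with defaults defined in train_main.py:
        #   k    - number of value-iteration recurrences
        #   ch_i - channels in the input layer (obstacle image + value map)
        #   ch_h - channels in the initial hidden conv layer
        #   ch_q - channels in the q layer (~actions)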
        self.state_batch_size = arg.statebatchsize  # state queries per sample

        # small random normal initialization for the weights (stddev 0.01);
        # assigning a Parameter attribute on an nn.Module registers it
        self.bias = Parameter(torch.randn(self.ch_h) * 0.01)
        self.w0 = Parameter(torch.randn(self.ch_h, self.ch_i, 3, 3) * 0.01)
        self.w1 = Parameter(torch.randn(1, self.ch_h, 1, 1) * 0.01)
        self.w = Parameter(torch.randn(self.ch_q, 1, 3, 3) * 0.01)
        self.w_fb = Parameter(torch.randn(self.ch_q, 1, 3, 3) * 0.01)
        self.w_o = Parameter(torch.randn(8, self.ch_q) * 0.01)
        self.softmax = nn.Softmax()  # softmax over the 8 output actions

    def forward(self, X, S1, S2):
        X = autograd.Variable(X)
        # reward image: conv over the two input channels, then a 1x1 conv
        h = F.conv2d(X.float(), self.w0, bias=self.bias, padding=1)
        r = F.conv2d(h, self.w1)
        # first value-iteration step: q from r alone, v = max over q channels
        q = F.conv2d(r, self.w, padding=1)
        v, _ = torch.max(q, 1)  # old (pre-0.2) torch.max keeps the reduced dim

        # remaining k-1 recurrences, feeding v back in through w_fb
        for i in range(0, self.k - 1):
            rv = torch.cat((r, v), 1)
            wwfb = torch.cat((self.w, self.w_fb), 1)
            q = F.conv2d(rv, wwfb, padding=1)
            v, _ = torch.max(q, 1)

        # one more q update so q reflects the final value map
        q = F.conv2d(torch.cat((r, v), 1), torch.cat((self.w, self.w_fb), 1), padding=1)

        # gather the q vectors at the queried (S1, S2) states of each sample
        bs = q.size(0)
        len_ = self.state_batch_size * bs
        rprn = np.array([[item] * self.state_batch_size
                         for item in np.arange(bs)], dtype=np.int64).reshape(len_)
        ins1 = S1.reshape(len_).astype(np.int64)
        ins2 = S2.reshape(len_).astype(np.int64)

        # reorder q from (bs, ch_q, H, W) to (H, W, bs, ch_q) for indexing
        q_ = torch.transpose(q, 0, 2)
        q__ = torch.transpose(q_, 1, 3)

        # TODO: needs optimizing, since PyTorch has no gather_nd; the states
        # are gathered one at a time (see gather_state_q below for a
        # vectorized sketch)
        abs_q = torch.index_select(
            torch.index_select(
                torch.index_select(q__, 0, agLT(ins1[0])),
                1, agLT(ins2[0])),
            2, agLT(rprn[0]))
        for item in np.arange(1, len_):
            abs_q_ = torch.index_select(
                torch.index_select(
                    torch.index_select(q__, 0, agLT(ins1[item])),
                    1, agLT(ins2[item])),
                2, agLT(rprn[item]))
            abs_q = torch.cat((abs_q, abs_q_), 0)

        final_q = torch.squeeze(abs_q)
        output = F.linear(final_q, self.w_o)
        prediction = self.softmax(output)
        return output, prediction
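
# A vectorized sketch of the state lookup done in VIN_Block.forward above.
# gather_state_q is a hypothetical helper (it is not wired into forward); it
# assumes a PyTorch version whose Variables support contiguous() and view().
# It flattens q so that a single index_select picks q[b, :, S1[b, j], S2[b, j]]
# for all queries at once, instead of looping query by query.
def gather_state_q(q, S1, S2, state_batch_size):
    # q: Variable of shape (bs, ch_q, H, W); S1, S2: (bs, state_batch_size)
    bs, ch_q, H, W = q.size()
    len_ = bs * state_batch_size
    batch_idx = np.repeat(np.arange(bs), state_batch_size).astype(np.int64)
    ins1 = S1.reshape(len_).astype(np.int64)
    ins2 = S2.reshape(len_).astype(np.int64)
    # (bs, ch_q, H, W) -> (bs, H, W, ch_q) -> (bs*H*W, ch_q)
    q_rows = torch.transpose(torch.transpose(q, 1, 2), 2, 3).contiguous()
    q_rows = q_rows.view(bs * H * W, ch_q)
    # one flat row index per (batch, x, y) query
    flat_idx = (batch_idx * H + ins1) * W + ins2
    idx = autograd.Variable(torch.from_numpy(flat_idx))
    return torch.index_select(q_rows, 0, idx)  # shape (len_, ch_q)
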
if __name__ == "__main__":
    # smoke test: build the module with the default hyperparameters
    # used by train_main.py
    from argparse import Namespace
    arg = Namespace(k=10, ch_i=2, ch_h=150, ch_q=10, statebatchsize=10)
    obj = VIN_Block(arg)
    print(list(obj.parameters()))

--------------------------------------------------------------------------------
/train_main.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# coding=utf-8

'''
Author: Tai Lei
Date: Thu 09 Mar 2017 04:37:17 PM WAT
Info: Implement VIN in PyTorch
'''

from __future__ import print_function
import time
import numpy as np
import torch
import argparse
from model import VIN_Block
from data import *
import torch.optim as optim
import torch.autograd as autograd
import torch.nn as nn
from utils import *

parser = argparse.ArgumentParser()
parser.add_argument("--input_path", type=str,
                    default='data/gridworld_8.mat',
                    help="path of the training data")
parser.add_argument("--imsize", type=int, default=8,
                    help="size of the input image")
parser.add_argument('--lr', type=float, default=0.001,
                    help='learning rate for RMSProp')
parser.add_argument('--epochs', type=int, default=30,
                    help='maximum number of epochs to train for')
parser.add_argument('--k', type=int, default=10,
                    help='number of value iterations')
parser.add_argument('--ch_i', type=int, default=2,
                    help='channels in the input layer')
parser.add_argument('--ch_h', type=int, default=150,
                    help='channels in the initial hidden layer')
parser.add_argument('--ch_q', type=int, default=10,
                    help='channels in the q layer (~actions)')
parser.add_argument('--batchsize', type=int, default=12,
                    help='batch size')
parser.add_argument('--statebatchsize', type=int, default=10,
                    help='number of state inputs for each sample')
# argparse's type=bool treats any non-empty string as True; the
# untied_weights, log, and logdir flags are kept from the TensorFlow
# implementation and are currently unused here
parser.add_argument('--untied_weights', type=bool, default=False,
                    help='untie weights of VI network (unused)')
parser.add_argument('--display_step', type=int, default=1,
                    help='print summary output every n epochs')
parser.add_argument('--log', type=bool, default=False,
                    help='enable tensorboard summary (unused)')
parser.add_argument('--logdir', type=str, default='/tmp/vintf/',
                    help='directory to store tensorboard summary (unused)')

args = parser.parse_args()

model = VIN_Block(args)
optimizer = optim.RMSprop(model.parameters(), args.lr)
criterion = nn.CrossEntropyLoss()
dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
model.type(dtype)
# note: the tensors below stay on the CPU, so GPU training would also require
# moving X, y, and the index tensors created in model.py onto the GPU

Xtrain, S1train, S2train, ytrain, Xtest, S1test, S2test, ytest = process_gridworld_data(
    input=args.input_path,
    imsize=args.imsize)
batch_size = args.batchsize

print(fmt_row(10, ["Epoch", "Train Cost", "Train Err", "Epoch Time"]))
for epoch in range(args.epochs):
    tstart = time.time()
    avg_err, avg_cost = 0.0, 0.0
    num_batches = int(Xtrain.shape[0] / batch_size)
    for i in range(0, Xtrain.shape[0], batch_size):
        j = i + batch_size
        if j <= Xtrain.shape[0]:  # drop the last, incomplete batch
            # data.py stores images as NHWC; transpose back to NCHW for PyTorch
            X = torch.from_numpy(
                np.transpose(Xtrain[i:j].astype(float), [0, 3, 1, 2]))
            S1 = S1train[i:j]
            S2 = S2train[i:j]
            # each sample carries statebatchsize state queries, so the labels
            # are sliced in steps of statebatchsize
            y_origin = ytrain[i * args.statebatchsize:
                              j * args.statebatchsize].astype(np.int64)
            y = torch.from_numpy(y_origin)

            output, prediction = model(X, S1, S2)
            loss = criterion(output, autograd.Variable(y))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            cp = np.argmax(prediction.data.numpy(), 1)
            err = np.mean(cp != y_origin)
            avg_cost += loss.data.numpy()[0]
            avg_err += err

    if epoch % args.display_step == 0:
        elapsed = time.time() - tstart
        print(fmt_row(10, [epoch, avg_cost / num_batches,
                           avg_err / num_batches, elapsed]))

# evaluate on the held-out test set
Xtest_ = torch.from_numpy(np.transpose(Xtest.astype(float), [0, 3, 1, 2]))
output_test, prediction_test = model(Xtest_, S1test, S2test)
cp_test = np.argmax(prediction_test.data.numpy(), 1)
err_test = np.mean(cp_test != ytest)  # fraction of mispredicted actions
print("Accuracy: {}%".format(100 * (1 - err_test)))

--------------------------------------------------------------------------------