├── .gitignore
├── LICENCE
├── README.md
├── __init__.py
├── arguments.py
├── configs
│   ├── eval-localization.conf
│   ├── eval-tracking.conf
│   └── train.conf
├── data
│   ├── rgb_model_trained.chk.data-00000-of-00001
│   ├── rgb_model_trained.chk.index
│   └── rgb_model_trained.chk.meta
├── display_data.py
├── evaluate.py
├── figs
│   ├── architecture.png
│   └── trajectory.jpg
├── pfnet.py
├── preprocess.py
├── requirements.txt
├── train.py
├── transformer
│   ├── README.md
│   ├── __init__.py
│   ├── cluttered_mnist.py
│   ├── data
│   │   └── README.md
│   ├── example.py
│   ├── spatial_transformer.py
│   └── tf_utils.py
└── utils
    ├── __init__.py
    ├── network_layers.py
    └── tfrecordfeatures.py

/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.o
3 | .idea/
4 |
5 | # Windows image file caches
6 | Thumbs.db
7 | ehthumbs.db
8 |
9 | # Folder config file
10 | Desktop.ini
11 |
12 | # Recycle Bin used on file shares
13 | $RECYCLE.BIN/
14 |
15 | # Windows Installer files
16 | *.cab
17 | *.msi
18 | *.msm
19 | *.msp
20 |
21 | # Windows shortcuts
22 | *.lnk
23 |
24 | # =========================
25 | # Operating System Files
26 | # =========================
27 |
28 | # OSX
29 | # =========================
30 |
31 | .DS_Store
32 | .AppleDouble
33 | .LSOverride
34 |
35 | # Thumbnails
36 | ._*
37 |
38 | # Files that might appear in the root of a volume
39 | .DocumentRevisions-V100
40 | .fseventsd
41 | .Spotlight-V100
42 | .TemporaryItems
43 | .Trashes
44 | .VolumeIcon.icns
45 |
46 | # Directories potentially created on remote AFP share
47 | .AppleDB
48 | .AppleDesktop
49 | Network Trash Folder
50 | Temporary Items
51 | .apdisk
52 |
53 |
--------------------------------------------------------------------------------
/LICENCE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Peter Karkus, AdaCompNUS
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Particle Filter Networks
2 |
3 | Tensorflow implementation of Particle Filter Networks (PF-net)
4 |
5 | Peter Karkus, David Hsu, and Wee Sun Lee: Particle filter networks with application to visual localization.
6 | Conference on Robot Learning (CoRL), 2018. https://arxiv.org/abs/1805.08975
7 |
8 | ### PF-net architecture
9 | ![PF-net architecture](figs/architecture.png)
10 | PF-net encodes both a learned probabilistic system model and the particle filter algorithm in a single neural network.
11 |
12 | ### Localization example
13 | ![Localization example](figs/trajectory.jpg)
14 | Example of a successful global localization run.
15 |
16 |
17 | ### Requirements
18 |
19 | Python 2.7, Tensorflow 1.5.0
20 |
21 | Additional packages can be installed with
22 | ```
23 | pip install -r requirements.txt
24 | ```
25 |
26 | ### Dataset
27 |
28 | Datasets for localization experiments are available at
29 | https://drive.google.com/open?id=1hSDRg7tEf2O1D_NIL8OmVdHLN2KqYFF7
30 |
31 | The folder contains data for training, validation and testing, in
32 | tfrecords format. Download to the ./data/ folder.
33 |
34 | A simple script is available to visualize the data:
35 | ```
36 | python display_data.py ./data/valid.tfrecords
37 | ```
38 |
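Each record stores the maps and trajectory data described in display_data.py. As a quick reference, a minimal sketch (assuming TF 1.x and the downloaded ./data/valid.tfrecords) that decodes the true states and odometry of the first trajectory:
```
import numpy as np
import tensorflow as tf

record = next(tf.python_io.tf_record_iterator('./data/valid.tfrecords'))
features = tf.train.Example.FromString(record).features.feature

# (x, y, theta) per timestep; x, y index the map in pixels, theta is in radians
states = np.frombuffer(features['states'].bytes_list.value[0], np.float32).reshape((-1, 3))
odometry = np.frombuffer(features['odometry'].bytes_list.value[0], np.float32).reshape((-1, 3))
print(states.shape, odometry.shape)
```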
39 |
40 | ### Training
41 |
42 | The ```./configs/``` folder contains default configuration files for training and evaluation.
43 | Training requires the datasets to be downloaded into the ./data/ folder.
44 | Note that training requires significant resources, and it may take several hours or days depending on the hardware.
45 |
46 | PF-net can be trained with the default configuration using the following command:
47 | ```
48 | python train.py -c ./configs/train.conf --obsmode rgb
49 | ```
50 |
51 | By default, logs will be saved in a new folder under ./log/.
52 | For help on the input arguments run
53 | ```
54 | python train.py -h
55 | ```
56 |
57 |
58 |
59 | ### Evaluation
60 |
61 | A pre-trained model for RGB input is available under ./data/rgb_model_trained.chk. For evaluating a trained model run
62 | ```
63 | python evaluate.py -c ./configs/eval-tracking.conf --load ./data/rgb_model_trained.chk
64 | ```
65 | for the tracking task, and
66 | ```
67 | python evaluate.py -c ./configs/eval-localization.conf --load ./data/rgb_model_trained.chk
68 | ```
69 | for semi-global localization.
70 | The input arguments are the same as for train.py.
71 | Results will be somewhat different from the ones in
72 | the paper, because the initial uncertainty on the test trajectories
73 | uses a different random seed.
74 |
75 | ### Contact
76 |
77 | Peter Karkus
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
3 |
--------------------------------------------------------------------------------
/arguments.py:
--------------------------------------------------------------------------------
1 | import configargparse
2 | import numpy as np
3 |
4 |
5 | def parse_args(args=None):
6 |     """
7 |     Parse command line arguments
8 |     :param args: command line arguments or None (default)
9 |     :return: parsed parameters, a configargparse namespace
10 |     """
11 |
12 |     p = configargparse.ArgParser(default_config_files=[])
13 |
14 |     p.add('-c', '--config', required=True, is_config_file=True,
15 |           help='Config file. use ./configs/train.conf for training')
16 |
17 |     p.add('--trainfiles', nargs='*', help='Data file(s) for training (tfrecord).')
18 |     p.add('--testfiles', nargs='*', help='Data file(s) for validation or evaluation (tfrecord).')
19 |
20 |     # input configuration
21 |     p.add('--obsmode', type=str, default='rgb',
22 |           help='Observation input type. Possible values: rgb / depth / rgb-depth / vrf.')
23 |     p.add('--mapmode', type=str, default='wall',
24 |           help='Map input type with different (semantic) channels. ' +
25 |                'Possible values: wall / wall-door / wall-roomtype / wall-door-roomtype')
26 |     p.add('--map_pixel_in_meters', type=float, default=0.02,
27 |           help='The width (and height) of a pixel of the map in meters. Defaults to 0.02 for House3D data.')
28 |
29 |     p.add('--init_particles_distr', type=str, default='tracking',
30 |           help='Distribution of initial particles. Possible values: tracking / one-room / two-rooms / all-rooms')
31 |     p.add('--init_particles_std', nargs='*', default=["0.3", "0.523599"],  # tracking setting, 30cm, 30deg
32 |           help='Standard deviations for generated initial particles. Only applies to the tracking setting. ' +
33 |                'Expects two float values: translation std (meters), rotation std (radians)')
34 |     p.add('--trajlen', type=int, default=24,
35 |           help='Length of trajectories. Assumed to be lower than or equal to the trajectory length in the input data.')
36 |
37 |     # PF-net configuration
38 |     p.add('--num_particles', type=int, default=30, help='Number of particles in PF-net.')
39 |     p.add('--resample', type=str, default='false',
40 |           help='Resample particles in PF-net. Possible values: true / false.')
41 |     p.add('--alpha_resample_ratio', type=float, default=1.0,
42 |           help='Trade-off parameter for soft-resampling in PF-net. Only effective if resample == true. '
43 |                'Assumes values 0.0 < alpha <= 1.0. Alpha equal to 1.0 corresponds to hard-resampling.')
44 |     p.add('--transition_std', nargs='*', default=["0.0", "0.0"],
45 |           help='Standard deviations for transition model. Expects two float values: ' +
46 |                'translation std (meters), rotation std (radians). Defaults to zeros.')
47 |
48 |     # training configuration
49 |     p.add('--batchsize', type=int, default=24, help='Minibatch size for training. Must be 1 for evaluation.')
50 |     p.add('--bptt_steps', type=int, default=4,
51 |           help='Number of backpropagation steps for training with backpropagation through time (BPTT). '
52 |                'Assumed to be an integer divisor of the trajectory length (--trajlen).')
53 |     p.add('--learningrate', type=float, default=0.0025, help='Initial learning rate for training.')
54 |     p.add('--l2scale', type=float, default=4e-6, help='Scaling term for the L2 regularization loss.')
55 |     p.add('--epochs', metavar='epochs', type=int, default=1, help='Number of epochs for training.')
56 |     p.add('--decaystep', type=int, default=4, help='Decay the learning rate after every N epochs.')
57 |     p.add('--decayrate', type=float, help='Rate of decaying the learning rate.')
58 |
59 |     p.add('--load', type=str, default="", help='Load a previously trained model from a checkpoint file.')
60 |     p.add('--logpath', type=str, default='',
61 |           help='Specify path for logs. Makes a new directory under ./log/ if empty (default).')
62 |     p.add('--seed', type=int, help='Fix the random seed of numpy and tensorflow if set to larger than zero.')
63 |     p.add('--validseed', type=int,
64 |           help='Fix the random seed for validation if set to larger than zero. ' +
65 |                'Useful to evaluate with a fixed set of initial particles, which reduces the validation error variance.')
66 |     p.add('--gpu', type=int, default=0, help='Select a gpu on a multi-gpu machine. Defaults to zero.')
67 |
68 |     params = p.parse_args(args=args)
69 |
70 |     # fix numpy seed if needed
71 |     if params.seed is not None and params.seed >= 0:
72 |         np.random.seed(params.seed)
73 |
74 |     # convert multi-input fields to numpy arrays
75 |     params.transition_std = np.array(params.transition_std, np.float32)
76 |     params.init_particles_std = np.array(params.init_particles_std, np.float32)
77 |
78 |     # convert boolean fields
79 |     if params.resample not in ['false', 'true']:
80 |         print ("The value of resample must be either 'false' or 'true'")
81 |         raise ValueError
82 |     params.resample = (params.resample == 'true')
83 |
84 |     return params
85 |
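# Usage sketch (hypothetical call, mirroring the README commands): parse_args
# also accepts an explicit argument list, which is convenient for tests:
#   params = parse_args(['-c', './configs/train.conf', '--obsmode', 'rgb'])
#   params.resample  # -> False, converted from the 'false' string above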
--------------------------------------------------------------------------------
/configs/eval-localization.conf:
--------------------------------------------------------------------------------
1 | # config file for evaluating semi-global localization, with initial belief uniform over one room
2 |
3 | trainfiles = ./data/train.tfrecords
4 | testfiles = ./data/test.tfrecords
5 | logpath = ./log/
6 |
7 | init_particles_distr = one-room
8 |
9 | trajlen = 100
10 | num_particles = 1000
11 | resample = true
12 | transition_std = [0.04, 0.0872665]  # 4cm, 5deg
13 |
14 | seed = 98
15 | validseed = 100
16 |
17 | batchsize = 1
18 | epochs = 1
--------------------------------------------------------------------------------
/configs/eval-tracking.conf:
--------------------------------------------------------------------------------
1 | # config file for evaluating tracking where initial belief is close to the true state
2 |
3 | trainfiles = ./data/train.tfrecords
4 | testfiles = ./data/test.tfrecords
5 | logpath = ./log/
6 |
7 | init_particles_distr = tracking
8 | init_particles_std = [0.3, 0.523599]
9 | trajlen = 24
10 | num_particles = 300
11 | resample = false
12 | transition_std = [0, 0]
13 |
14 | seed = 98
15 | validseed = 100
16 |
17 | batchsize = 1
18 | epochs = 1
--------------------------------------------------------------------------------
/configs/train.conf:
--------------------------------------------------------------------------------
1 | # config file for the allobj dataset
2 |
3 | trainfiles = ./data/train.tfrecords
4 | testfiles = ./data/valid.tfrecords
5 | logpath = ./log/
6 |
7 | init_particles_distr = tracking
8 | init_particles_std = [0.3, 0.523599]  # 30cm, 30degrees
9 | trajlen = 24
10 | num_particles = 30
11 | resample = false
12 | transition_std = [0, 0]
13 |
14 | validseed = 1
15 |
16 | batchsize = 24
17 | bptt_steps = 4
18 | learningrate = 0.0001
19 | l2scale = 4e-6
20 |
21 | epochs = 12
22 | decaystep = 4
23 | decayrate = 0.5
24 |
--------------------------------------------------------------------------------
/data/rgb_model_trained.chk.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdaCompNUS/pfnet/861b398c58574cc3e415896f3dd278a76cb2b383/data/rgb_model_trained.chk.data-00000-of-00001
--------------------------------------------------------------------------------
/data/rgb_model_trained.chk.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdaCompNUS/pfnet/861b398c58574cc3e415896f3dd278a76cb2b383/data/rgb_model_trained.chk.index
--------------------------------------------------------------------------------
/data/rgb_model_trained.chk.meta:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdaCompNUS/pfnet/861b398c58574cc3e415896f3dd278a76cb2b383/data/rgb_model_trained.chk.meta
--------------------------------------------------------------------------------
/display_data.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import matplotlib.pyplot as plt
6 | import sys
7 | import tensorflow as tf
8 | import numpy as np
9 |
10 | # Fix Python 2.x.
11 | try: input = raw_input
12 | except NameError: pass
13 |
14 | from utils.tfrecordfeatures import *
15 | from preprocess import decode_image, raw_images_to_array
16 |
17 | try:
18 |     import ipdb as pdb
19 | except Exception:
20 |     import pdb
21 |
22 |
23 | def display_data(file):
24 |     gen = tf.python_io.tf_record_iterator(file)
25 |     for data_i, string_record in enumerate(gen):
26 |         result = tf.train.Example.FromString(string_record)
27 |         features = result.features.feature
28 |
29 |         # maps are np.uint8 arrays. each has a different size.
30 |
31 |         # wall map: 0 for free space, 255 for walls
32 |         map_wall = decode_image(features['map_wall'].bytes_list.value[0])
33 |
34 |         # door map: 0 for free space, 255 for doors
35 |         map_door = decode_image(features['map_door'].bytes_list.value[0])
36 |
37 |         # roomtype map: binary encoding of 8 possible room categories
38 |         # one state may belong to multiple room categories
39 |         map_roomtype = decode_image(features['map_roomtype'].bytes_list.value[0])
40 |
41 |         # roomid map: pixels correspond to unique room ids.
42 |         # for overlapping rooms the higher ids overwrite lower ids
43 |         map_roomid = decode_image(features['map_roomid'].bytes_list.value[0])
44 |
45 |         # true states
46 |         # (x, y, theta). x,y: pixel coordinates; theta: radians
47 |         # coordinates index the map as a numpy array: map[x, y]
48 |         true_states = features['states'].bytes_list.value[0]
49 |         true_states = np.frombuffer(true_states, np.float32).reshape((-1, 3))
50 |
51 |         # odometry
52 |         # each entry is true_states[i+1]-true_states[i].
53 |         # last row is always [0,0,0]
54 |         odometry = features['odometry'].bytes_list.value[0]
55 |         odometry = np.frombuffer(odometry, np.float32).reshape((-1, 3))
56 |
57 |         # observations are encoded as a list of png images
58 |         rgb = raw_images_to_array(list(features['rgb'].bytes_list.value))
59 |         depth = raw_images_to_array(list(features['depth'].bytes_list.value))
60 |
61 |         print ("True states (first three)")
62 |         print (true_states[:3])
63 |
64 |         print ("Odometry (first three)")
65 |         print (odometry[:3])
66 |
67 |         print("Plot map and first observation")
68 |
69 |         # note: when printed as an image, map should be transposed
70 |         plt.figure()
71 |         plt.imshow(map_wall.transpose())
72 |
73 |         plt.figure()
74 |         plt.imshow(rgb[0])
75 |
76 |         plt.show()
77 |
78 |         if input("proceed?") != 'y':
79 |             break
80 |
81 |
82 | if __name__ == '__main__':
83 |     if len(sys.argv) < 2:
84 |         print ("Usage: display_data.py xxx.tfrecords")
85 |         exit()
86 |
87 |     display_data(sys.argv[1])
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import os, tqdm
6 | import tensorflow as tf
7 | import numpy as np
8 |
9 | import pfnet
10 | from arguments import parse_args
11 | from preprocess import get_dataflow
12 |
13 | try:
14 |     import ipdb as pdb
15 | except Exception:
16 |     import pdb
17 |
18 |
19 | def run_evaluation(params):
20 |     """ Run evaluation with the parsed arguments """
21 |
22 |     # overwrite for evaluation
23 |     params.batchsize = 1
24 |     params.bptt_steps = params.trajlen
25 |
26 |     with tf.Graph().as_default():
27 |         if params.seed is not None:
28 |             tf.set_random_seed(params.seed)
29 |
30 |         # test data and network
31 |         with tf.variable_scope(tf.get_variable_scope(), reuse=False):
32 |             test_data, num_test_samples = get_dataflow(params.testfiles, params, is_training=False)
33 |             test_brain = pfnet.PFNet(inputs=test_data[1:], labels=test_data[0], params=params, is_training=False)
34 |
35 |         # Add the variable initializer Op.
36 |         init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
37 |
38 |         # Create a saver for writing training checkpoints.
39 |         saver = tf.train.Saver(var_list=tf.trainable_variables(), max_to_keep=3)
40 |
41 |         # Create a session for running Ops on the Graph.
42 |         os.environ["CUDA_VISIBLE_DEVICES"] = "%d"%int(params.gpu)
43 |         sess_config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
44 |         sess_config.gpu_options.allow_growth = True
45 |
46 |         # training session
47 |         with tf.Session(config=sess_config) as sess:
48 |
49 |             sess.run(init_op)
50 |
51 |             # load model from checkpoint file
52 |             if params.load:
53 |                 print("Loading model from " + params.load)
54 |                 saver.restore(sess, params.load)
55 |
56 |             coord = tf.train.Coordinator()
57 |             threads = tf.train.start_queue_runners(coord=coord)
58 |
59 |             mse_list = []  # mean squared error
60 |             success_list = []  # true for successful localization
61 |
62 |             try:
63 |                 for step_i in tqdm.tqdm(range(num_test_samples)):
64 |                     all_distance2, _ = sess.run([test_brain.all_distance2_op, test_brain.update_state_op])
65 |
66 |                     # we have squared differences along the trajectory
67 |                     mse = np.mean(all_distance2[0])
68 |                     mse_list.append(mse)
69 |
70 |                     # localization is successful if the RMSE is below 1m for the last 25% of the trajectory
71 |                     successful = np.all(all_distance2[0][-params.trajlen//4:] < 1.0 ** 2)  # below 1 meter
72 |                     success_list.append(successful)
73 |
74 |             except KeyboardInterrupt:
75 |                 pass
76 |
77 |             except tf.errors.OutOfRangeError:
78 |                 print("data exhausted")
79 |
80 |             finally:
81 |                 coord.request_stop()
82 |                 coord.join(threads)
83 |
84 |             # report results
85 |             mean_rmse = np.mean(np.sqrt(mse_list))
86 |             total_rmse = np.sqrt(np.mean(mse_list))
87 |             print ("Mean RMSE (average RMSE per trajectory) = %fcm"%(mean_rmse * 100))
88 |             print("Overall RMSE (reported value) = %fcm" % (total_rmse * 100))
89 |             print("Success rate = %f%%" % (np.mean(np.array(success_list, 'i')) * 100))
90 |
91 |
92 | if __name__ == '__main__':
93 |     params = parse_args()
94 |
95 |     run_evaluation(params)
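# Note on the two aggregates above (hypothetical numbers): if the per-trajectory
# mean squared errors are [0.04, 0.16] m^2, then
#   mean RMSE    = mean(sqrt([0.04, 0.16]))  = 0.30m -> 30cm
#   overall RMSE = sqrt(mean([0.04, 0.16])) ~= 0.32m -> ~32cm (the reported value)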
--------------------------------------------------------------------------------
/figs/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdaCompNUS/pfnet/861b398c58574cc3e415896f3dd278a76cb2b383/figs/architecture.png
--------------------------------------------------------------------------------
/figs/trajectory.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdaCompNUS/pfnet/861b398c58574cc3e415896f3dd278a76cb2b383/figs/trajectory.jpg
--------------------------------------------------------------------------------
/pfnet.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import tensorflow as tf
6 | import numpy as np
7 | from transformer.spatial_transformer import transformer
8 | from utils.network_layers import conv2_layer, locallyconn2_layer, dense_layer
9 |
10 |
11 | class PFCell(tf.nn.rnn_cell.RNNCell):
12 |     """
13 |     PF-net for localization implemented with the RNN interface.
14 |     Implements the particle set update, the observation and transition models.
15 |     Cell inputs: observation, odometry
16 |     Cell states: particle_states, particle_weights
17 |     Cell outputs: particle_states, particle_weights
18 |     """
19 |
20 |     def __init__(self, global_maps, params, batch_size, num_particles):
21 |         """
22 |         :param global_maps: tensorflow op (batch, None, None, ch), global maps input. Since the map is fixed
23 |         through the trajectory it can be input to the cell here, instead of part of the cell input.
24 |         :param params: parsed arguments
25 |         :param batch_size: int, minibatch size
26 |         :param num_particles: number of particles
27 |         """
28 |         super(PFCell, self).__init__()
29 |         self.global_maps = global_maps
30 |         self.params = params
31 |         self.batch_size = batch_size
32 |         self.num_particles = num_particles
33 |
34 |         self.states_shape = (batch_size, num_particles, 3)
35 |         self.weights_shape = (batch_size, num_particles, )
36 |
37 |     @property
38 |     def state_size(self):
39 |         return (tf.TensorShape(self.states_shape[1:]), tf.TensorShape(self.weights_shape[1:]))
40 |
41 |     @property
42 |     def output_size(self):
43 |         return (tf.TensorShape(self.states_shape[1:]), tf.TensorShape(self.weights_shape[1:]))
44 |
45 |     def __call__(self, inputs, state, scope=None):
46 |         """
47 |         Implements a particle update.
48 |         :param inputs: observation (batch, 56, 56, ch), odometry (batch, 3).
49 |         observation is the sensor reading at time t, odometry is the relative motion from time t to time t+1
50 |         :param state: particle states (batch, K, 3), particle weights (batch, K).
51 |         weights are assumed to be in log space and they can be unnormalized
52 |         :param scope: not used, only kept for the interface. Ops will be created in the current scope.
53 |         :return: outputs, state
54 |         outputs: particle states and weights after the observation update, but before the transition update
55 |         state: updated particle states and weights after both observation and transition updates
56 |         """
57 |         with tf.variable_scope(tf.get_variable_scope()):
58 |             particle_states, particle_weights = state
59 |             observation, odometry = inputs
60 |
61 |             # observation update
62 |             lik = self.observation_model(self.global_maps, particle_states, observation)
63 |             particle_weights += lik  # unnormalized
64 |
65 |             # resample
66 |             if self.params.resample:
67 |                 particle_states, particle_weights = self.resample(
68 |                     particle_states, particle_weights, alpha=self.params.alpha_resample_ratio)
69 |
70 |             # construct output before motion update
71 |             outputs = particle_states, particle_weights
72 |
73 |             # motion update. this will only affect the particle state input at the next step
74 |             particle_states = self.transition_model(particle_states, odometry)
75 |
76 |             # construct new state
77 |             state = particle_states, particle_weights
78 |
79 |         return outputs, state
80 |
81 |     def transition_model(self, particle_states, odometry):
82 |         """
83 |         Implements a stochastic transition model for localization.
84 |         :param particle_states: tf op (batch, K, 3), particle states before the update.
85 |         :param odometry: tf op (batch, 3), odometry reading, relative motion in the robot coordinate frame
86 |         :return: particle_states updated with the odometry and optionally transition noise
87 |         """
88 |         translation_std = self.params.transition_std[0] / self.params.map_pixel_in_meters  # in pixels
89 |         rotation_std = self.params.transition_std[1]  # in radians
90 |
91 |         with tf.name_scope('transition'):
92 |             part_x, part_y, part_th = tf.unstack(particle_states, axis=-1, num=3)
93 |
94 |             odometry = tf.expand_dims(odometry, axis=1)
95 |             odom_x, odom_y, odom_th = tf.unstack(odometry, axis=-1, num=3)
96 |
97 |             noise_th = tf.random_normal(part_th.get_shape(), mean=0.0, stddev=1.0) * rotation_std
98 |
99 |             # add orientation noise before translation
100 |             part_th += noise_th
101 |
102 |             cos_th = tf.cos(part_th)
103 |             sin_th = tf.sin(part_th)
104 |             delta_x = cos_th * odom_x - sin_th * odom_y
105 |             delta_y = sin_th * odom_x + cos_th * odom_y
106 |             delta_th = odom_th
107 |
108 |             delta_x += tf.random_normal(delta_x.get_shape(), mean=0.0, stddev=1.0) * translation_std
109 |             delta_y += tf.random_normal(delta_y.get_shape(), mean=0.0, stddev=1.0) * translation_std
110 |
111 |             return tf.stack([part_x+delta_x, part_y+delta_y, part_th+delta_th], axis=-1)
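    # Worked example of the unit handling above, using eval-localization.conf
    # (transition_std = [0.04, 0.0872665], map_pixel_in_meters = 0.02):
    # translation_std = 0.04 / 0.02 = 2.0 pixels, while the 5 degree rotation
    # noise (0.0872665 rad) is applied to part_th directly, before the odometry
    # is rotated into the map frame.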
112 |
113 |     def observation_model(self, global_maps, particle_states, observation):
114 |         """
115 |         Implements a discriminative observation model for localization.
116 |         The model transforms the single global map to local maps for each particle, where a local map is a local
117 |         view from the state defined by the particle.
118 |         :param global_maps: tf op (batch, None, None, ch), global maps input.
119 |         Assumes a scaling 0..2 where 0 is occupied, 2 is free space.
120 |         :param particle_states: tf op (batch, K, 3), particle states before the update
121 |         :param observation: tf op (batch, 56, 56, ch), image observation from an rgb, depth, or rgb-d camera.
122 |         :return: tf op (batch, K) particle likelihoods in the log space, unnormalized
123 |         """
124 |
125 |         # transform global maps to local maps
126 |         local_maps = self.transform_maps(global_maps, particle_states, (28, 28))
127 |
128 |         # rescale from 0..2 to -1..1. This is not actually necessary.
129 |         local_maps = -(local_maps - 1)
130 |         # flatten batch and particle dimensions
131 |         local_maps = tf.reshape(local_maps,
132 |                                 [self.batch_size * self.num_particles] + local_maps.shape.as_list()[2:])
133 |
134 |         # get features from the map
135 |         map_features = self.map_features(local_maps)
136 |
137 |         # get features from the observation
138 |         obs_features = self.observation_features(observation)
139 |
140 |         # tile observation features
141 |         obs_features = tf.tile(tf.expand_dims(obs_features, axis=1), [1, self.num_particles, 1, 1, 1])
142 |         obs_features = tf.reshape(obs_features,
143 |                                   [self.batch_size * self.num_particles] + obs_features.shape.as_list()[2:])
144 |
145 |         # sanity check
146 |         assert obs_features.shape.as_list()[:-1] == map_features.shape.as_list()[:-1]
147 |
148 |         # merge features and process further
149 |         joint_features = tf.concat([map_features, obs_features], axis=-1)
150 |         joint_features = self.joint_matrix_features(joint_features)
151 |
152 |         # reshape to a vector and process further
153 |         joint_features = tf.reshape(joint_features, (self.batch_size * self.num_particles, -1))
154 |         lik = self.joint_vector_features(joint_features)
155 |         lik = tf.reshape(lik, [self.batch_size, self.num_particles])
156 |
157 |         return lik
158 |
159 |     @staticmethod
160 |     def resample(particle_states, particle_weights, alpha):
161 |         """
162 |         Implements (soft)-resampling of particles.
163 |         :param particle_states: tf op (batch, K, 3), particle states
164 |         :param particle_weights: tf op (batch, K), unnormalized particle weights in log space
165 |         :param alpha: float, trade-off parameter for soft-resampling. alpha == 1 corresponds to standard,
166 |         hard-resampling. alpha == 0 corresponds to sampling particles uniformly, ignoring their weights.
167 |         :return: particle_states, particle_weights
168 |         """
169 |         with tf.name_scope('resample'):
170 |             assert 0.0 < alpha <= 1.0
171 |             batch_size, num_particles = particle_states.get_shape().as_list()[:2]
172 |
173 |             # normalize
174 |             particle_weights = particle_weights - tf.reduce_logsumexp(particle_weights, axis=-1, keep_dims=True)
175 |
176 |             uniform_weights = tf.constant(-np.log(num_particles), shape=(batch_size, num_particles), dtype=tf.float32)
177 |
178 |             # build sampling distribution, q(s), and update particle weights
179 |             if alpha < 1.0:
180 |                 # soft resampling
181 |                 q_weights = tf.stack([particle_weights + np.log(alpha), uniform_weights + np.log(1.0-alpha)], axis=-1)
182 |                 q_weights = tf.reduce_logsumexp(q_weights, axis=-1, keep_dims=False)
183 |                 q_weights = q_weights - tf.reduce_logsumexp(q_weights, axis=-1, keep_dims=True)  # normalized
184 |
185 |                 particle_weights = particle_weights - q_weights  # this is unnormalized
186 |             else:
187 |                 # hard resampling. this will produce zero gradients
188 |                 q_weights = particle_weights
189 |                 particle_weights = uniform_weights
190 |
191 |             # sample particle indices according to q(s)
192 |             indices = tf.cast(tf.multinomial(q_weights, num_particles), tf.int32)  # shape: (batch_size, num_particles)
193 |
194 |             # index into particles
195 |             helper = tf.range(0, batch_size*num_particles, delta=num_particles, dtype=tf.int32)  # (batch, )
196 |             indices = indices + tf.expand_dims(helper, axis=1)
197 |
198 |             particle_states = tf.reshape(particle_states, (batch_size * num_particles, 3))
199 |             particle_states = tf.gather(particle_states, indices=indices, axis=0)  # (batch_size, num_particles, 3)
200 |
201 |             particle_weights = tf.reshape(particle_weights, (batch_size * num_particles, ))
202 |             particle_weights = tf.gather(particle_weights, indices=indices, axis=0)  # (batch_size, num_particles,)
203 |
204 |         return particle_states, particle_weights
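    # Numeric sketch of soft-resampling (K = 2 particles, alpha = 0.5): with
    # normalized weights p = [0.8, 0.2], indices are drawn from
    # q = 0.5*p + 0.5*[0.5, 0.5] = [0.65, 0.35], and sampled particles keep the
    # importance weights p/q ~= [1.23, 0.57], so gradients flow through p.
    # With alpha = 1.0 the new weights are uniform and gradients are zero.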
205 |
206 |     @staticmethod
207 |     def transform_maps(global_maps, particle_states, local_map_size):
208 |         """
209 |         Implements global to local map transformation
210 |         :param global_maps: tf op (batch, None, None, ch) global map input
211 |         :param particle_states: tf op (batch, K, 3) particle states that define local views for the transformation
212 |         :param local_map_size: tuple, (height, width), size of the output local maps
213 |         :return: tf op (batch, K, local_map_size[0], local_map_size[1], ch). local maps, each shows a
214 |         different transformation of the global map corresponding to the particle states
215 |         """
216 |         batch_size, num_particles = particle_states.get_shape().as_list()[:2]
217 |         total_samples = batch_size * num_particles
218 |         flat_states = tf.reshape(particle_states, [total_samples, 3])
219 |
220 |         # define some helper variables
221 |         input_shape = tf.shape(global_maps)
222 |         global_height = tf.cast(input_shape[1], tf.float32)
223 |         global_width = tf.cast(input_shape[2], tf.float32)
224 |         height_inverse = 1.0 / global_height
225 |         width_inverse = 1.0 / global_width
226 |         # at tf1.6 matmul still does not support broadcasting, so we need full vectors
227 |         zero = tf.constant(0, dtype=tf.float32, shape=(total_samples, ))
228 |         one = tf.constant(1, dtype=tf.float32, shape=(total_samples, ))
229 |
230 |         # the global map will be down-scaled by some factor
231 |         window_scaler = 8.0
232 |
233 |         # normalize orientations and precompute cos and sin functions
234 |         theta = -flat_states[:, 2] - 0.5 * np.pi
235 |         costheta = tf.cos(theta)
236 |         sintheta = tf.sin(theta)
237 |
238 |         # construct an affine transformation matrix step-by-step.
239 |         # 1, translate the global map s.t. the center is at the particle state
240 |         translate_x = (flat_states[:, 0] * width_inverse * 2.0) - 1.0
241 |         translate_y = (flat_states[:, 1] * height_inverse * 2.0) - 1.0
242 |
243 |         transm1 = tf.stack((one, zero, translate_x, zero, one, translate_y, zero, zero, one), axis=1)
244 |         transm1 = tf.reshape(transm1, (total_samples, 3, 3))
245 |
246 |         # 2, rotate map s.t. the orientation matches that of the particles
247 |         rotm = tf.stack((costheta, sintheta, zero, -sintheta, costheta, zero, zero, zero, one), axis=1)
248 |         rotm = tf.reshape(rotm, (total_samples, 3, 3))
249 |
250 |         # 3, scale down the map
251 |         scale_x = tf.fill((total_samples, ), float(local_map_size[1] * window_scaler) * width_inverse)
252 |         scale_y = tf.fill((total_samples, ), float(local_map_size[0] * window_scaler) * height_inverse)
253 |
254 |         scalem = tf.stack((scale_x, zero, zero, zero, scale_y, zero, zero, zero, one), axis=1)
255 |         scalem = tf.reshape(scalem, (total_samples, 3, 3))
256 |
257 |         # 4, translate the local map s.t. the particle defines the bottom mid-point instead of the center
258 |         translate_y2 = tf.constant(-1.0, dtype=tf.float32, shape=(total_samples, ))
259 |
260 |         transm2 = tf.stack((one, zero, zero, zero, one, translate_y2, zero, zero, one), axis=1)
261 |         transm2 = tf.reshape(transm2, (total_samples, 3, 3))
262 |
263 |         # chain the transformation matrices into a single one: translate + rotate + scale + translate
264 |         transform_m = tf.matmul(tf.matmul(tf.matmul(transm1, rotm), scalem), transm2)
265 |
266 |         # reshape to the format expected by the spatial transform network
267 |         transform_m = tf.reshape(transform_m[:, :2], (batch_size, num_particles, 6))
268 |
269 |         # do the image transformation using the spatial transform network
270 |         # iterate over particle to avoid tiling large global maps
271 |         output_list = []
272 |         for i in range(num_particles):
273 |             output_list.append(transformer(global_maps, transform_m[:,i], local_map_size))
274 |
275 |         local_maps = tf.stack(output_list, axis=1)
276 |
277 |         # set shape information that is lost in the spatial transform network
278 |         local_maps = tf.reshape(local_maps, (batch_size, num_particles, local_map_size[0], local_map_size[1],
279 |                                              global_maps.shape.as_list()[-1]))
280 |
281 |         return local_maps
282 |
283 |     @staticmethod
284 |     def map_features(local_maps):
285 |         assert local_maps.get_shape().as_list()[1:3] == [28, 28]
286 |         data_format = 'channels_last'
287 |
288 |         with tf.variable_scope("map"):
289 |             x = local_maps
290 |             layer_i = 1
291 |             convs = [
292 |                 conv2_layer(
293 |                     24, (3, 3), activation=None, padding='same', data_format=data_format,
294 |                     use_bias=True, layer_i=layer_i)(x),
295 |                 conv2_layer(
296 |                     16, (5, 5), activation=None, padding='same', data_format=data_format,
297 |                     use_bias=True, layer_i=layer_i)(x),
298 |                 conv2_layer(
299 |                     8, (7, 7), activation=None, padding='same', data_format=data_format,
300 |                     use_bias=True, layer_i=layer_i)(x),
301 |                 conv2_layer(
302 |                     8, (7, 7), activation=None, padding='same', data_format=data_format,
303 |                     dilation_rate=(2, 2), use_bias=True, layer_i=layer_i)(x),
304 |                 conv2_layer(
305 |                     8, (7, 7), activation=None, padding='same', data_format=data_format,
306 |                     dilation_rate=(3, 3), use_bias=True, layer_i=layer_i)(x),
307 |             ]
308 |             x = tf.concat(convs, axis=-1)
309 |             x = tf.contrib.layers.layer_norm(x, activation_fn=tf.nn.relu)
310 |             assert x.get_shape().as_list()[1:4] == [28, 28, 64]
311 |             # (28x28x64)
312 |
313 |             x = tf.layers.max_pooling2d(x, pool_size=(3, 3), strides=(2, 2), padding="same")
314 |
315 |             layer_i += 1
316 |             convs = [
317 |                 conv2_layer(
318 |                     4, (3, 3), activation=None, padding='same', data_format=data_format,
319 |                     use_bias=True, layer_i=layer_i)(x),
320 |                 conv2_layer(
321 |                     4, (5, 5), activation=None, padding='same', data_format=data_format,
322 |                     use_bias=True, layer_i=layer_i)(x),
323 |             ]
324 |             x = tf.concat(convs, axis=-1)
325 |             x = tf.contrib.layers.layer_norm(x, activation_fn=tf.nn.relu)
326 |
327 |         return x  # (14x14x8)
328 |
329 |     @staticmethod
330 |     def observation_features(observation):
331 |         data_format = 'channels_last'
332 |         with tf.variable_scope("observation"):
333 |             x = observation
334 |             layer_i = 1
335 |             convs = [
336 |                 conv2_layer(
337 |                     128, (3, 3), activation=None, padding='same', data_format=data_format,
338 |                     use_bias=True, layer_i=layer_i)(x),
339 |                 conv2_layer(
340 |                     128, (5, 5), activation=None, padding='same', data_format=data_format,
341 |                     use_bias=True, layer_i=layer_i)(x),
342 |                 conv2_layer(
343 |                     64, (5, 5), activation=None, padding='same', data_format=data_format,
344 |                     dilation_rate=(2, 2), use_bias=True, layer_i=layer_i)(x),
345 |                 conv2_layer(
346 |                     64, (5, 5), activation=None, padding='same', data_format=data_format,
347 |                     dilation_rate=(4, 4), use_bias=True, layer_i=layer_i)(x),
348 |             ]
349 |             x = tf.concat(convs, axis=-1)
350 |             x = tf.layers.max_pooling2d(x, pool_size=(3, 3), strides=(2, 2), padding="same")
351 |             x = tf.contrib.layers.layer_norm(x, activation_fn=tf.nn.relu)
352 |
353 |             assert x.get_shape().as_list()[1:4] == [28, 28, 384]
354 |
355 |             layer_i += 1
356 |             x = conv2_layer(
357 |                 16, (3, 3), activation=None, padding='same', data_format=data_format,
358 |                 use_bias=True, layer_i=layer_i)(x)
359 |
360 |             x = tf.layers.max_pooling2d(x, pool_size=(3, 3), strides=(2, 2), padding="same")
361 |             x = tf.contrib.layers.layer_norm(x, activation_fn=tf.nn.relu)
362 |             assert x.get_shape().as_list()[1:4] == [14, 14, 16]
363 |
364 |         return x  # (14,14,16)
365 |
366 |     @staticmethod
367 |     def joint_matrix_features(joint_matrix):
368 |         assert joint_matrix.get_shape().as_list()[1:4] == [14, 14, 24]
369 |         data_format = 'channels_last'
370 |
371 |         with tf.variable_scope("joint"):
372 |             x = joint_matrix
373 |             layer_i = 1
374 |
375 |             # pad manually to match different kernel sizes
376 |             x_pad1 = tf.pad(x, paddings=tf.constant([[0, 0], [1, 1,], [1, 1], [0, 0]]))
377 |             convs = [
378 |                 locallyconn2_layer(
379 |                     8, (3, 3), activation='relu', padding='valid', data_format=data_format,
380 |                     use_bias=True, layer_i=layer_i)(x),
381 |                 locallyconn2_layer(
382 |                     8, (5, 5), activation='relu', padding='valid', data_format=data_format,
383 |                     use_bias=True, layer_i=layer_i)(x_pad1),
384 |             ]
385 |             x = tf.concat(convs, axis=-1)
386 |
387 |             x = tf.layers.max_pooling2d(x, pool_size=(3, 3), strides=(2, 2), padding="valid")
388 |             assert x.get_shape().as_list()[1:4] == [5, 5, 16]
389 |
390 |         return x  # (5, 5, 16)
391 |
392 |     @staticmethod
393 |     def joint_vector_features(joint_vector):
394 |         with tf.variable_scope("joint"):
395 |             x = joint_vector
396 |             x = dense_layer(1, activation=None, use_bias=True, name='fc1')(x)
397 |         return x
398 |
399 |
400 | class PFNet(object):
401 |     """ Implements PF-net. Unrolls the PF-net RNN cell and defines losses and training ops."""
402 |     def __init__(self, inputs, labels, params, is_training):
403 |         """
404 |         Calling this will create all tf ops for PF-net.
405 |         :param inputs: list of tf ops, the inputs to PF-net. Assumed to have the following elements:
406 |         global_maps, init_particle_states, observations, odometries, is_first_step
407 |         :param labels: tf op, labels for training. Assumed to be the true states along the trajectory.
408 |         :param params: parsed arguments
409 |         :param is_training: bool, true for training.
410 | """ 411 | self.params = params 412 | 413 | # define ops to be accessed conveniently from outside 414 | self.outputs = [] 415 | self.hidden_states = [] 416 | 417 | self.train_loss_op = None 418 | self.valid_loss_op = None 419 | self.all_distance2_op = None 420 | 421 | self.global_step_op = None 422 | self.learning_rate_op = None 423 | self.train_op = None 424 | self.update_state_op = tf.constant(0) 425 | 426 | # build the network. this will generate the ops defined above 427 | self.build(inputs, labels, is_training) 428 | 429 | def build(self, inputs, labels, is_training): 430 | """ 431 | Unroll the PF-net RNN cell and create loss ops and optionally, training ops 432 | """ 433 | self.outputs = self.build_rnn(*inputs) 434 | 435 | self.build_loss_op(self.outputs[0], self.outputs[1], true_states=labels) 436 | 437 | if is_training: 438 | self.build_train_op() 439 | 440 | def save_state(self, sess): 441 | """ 442 | Returns a list, the hidden state of PF-net, i.e. the particle states and particle weights. 443 | The output can be used with load_state to restore the current hidden state. 444 | """ 445 | return sess.run(self.hidden_states) 446 | 447 | def load_state(self, sess, saved_state): 448 | """ 449 | Overwrite the hidden state of PF-net to that of saved_state. 450 | """ 451 | return sess.run(self.hidden_states, 452 | feed_dict={self.hidden_states[i]: saved_state[i] for i in range(len(self.hidden_states))}) 453 | 454 | def build_loss_op(self, particle_states, particle_weights, true_states): 455 | """ 456 | Create tf ops for various losses. This should be called only once with is_training=True. 457 | """ 458 | assert particle_weights.get_shape().ndims == 3 459 | 460 | lin_weights = tf.nn.softmax(particle_weights, dim=-1) 461 | 462 | true_coords = true_states[:, :, :2] 463 | mean_coords = tf.reduce_sum(tf.multiply(particle_states[:,:,:,:2], lin_weights[:,:,:,None]), axis=2) 464 | coord_diffs = mean_coords - true_coords 465 | 466 | # convert from pixel coordinates to meters 467 | coord_diffs *= self.params.map_pixel_in_meters 468 | 469 | # coordinate loss component: (x-x')^2 + (y-y')^2 470 | loss_coords = tf.reduce_sum(tf.square(coord_diffs), axis=2) 471 | 472 | true_orients = true_states[:, :, 2] 473 | orient_diffs = particle_states[:, :, :, 2] - true_orients[:,:,None] 474 | # normalize between -pi..+pi 475 | orient_diffs = tf.mod(orient_diffs + np.pi, 2*np.pi) - np.pi 476 | # orintation loss component: (sum_k[(theta_k-theta')*weight_k] )^2 477 | loss_orient = tf.square(tf.reduce_sum(orient_diffs * lin_weights, axis=2)) 478 | 479 | # combine translational and orientation losses 480 | loss_combined = loss_coords + 0.36 * loss_orient 481 | loss_pred = tf.reduce_mean(loss_combined, name='prediction_loss') 482 | 483 | # add L2 regularization loss 484 | loss_reg = tf.multiply(tf.losses.get_regularization_loss(), self.params.l2scale, name='l2') 485 | loss_total = tf.add_n([loss_pred, loss_reg], name="training_loss") 486 | 487 | self.all_distance2_op = loss_coords 488 | self.valid_loss_op = loss_pred 489 | self.train_loss_op = loss_total 490 | 491 | return loss_total 492 | 493 | def build_train_op(self): 494 | """ Create optimizer and train op. This should be called only once. 
""" 495 | 496 | # make sure this is only called once 497 | assert self.train_op is None and self.global_step_op is None and self.learning_rate_op is None 498 | 499 | # global step and learning rate 500 | with tf.device("/cpu:0"): 501 | self.global_step_op = tf.get_variable( 502 | initializer=tf.constant_initializer(0.0), shape=(), trainable=False, name='global_step',) 503 | self.learning_rate_op = tf.train.exponential_decay( 504 | self.params.learningrate, self.global_step_op, decay_steps=1, decay_rate=self.params.decayrate, 505 | staircase=True, name="learning_rate") 506 | 507 | # create gradient descent optimizer with the given learning rate. 508 | optimizer = tf.train.RMSPropOptimizer(self.learning_rate_op, decay=0.9) 509 | 510 | with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)): 511 | self.train_op = optimizer.minimize(self.train_loss_op, global_step=None, var_list=tf.trainable_variables()) 512 | 513 | return self.train_op 514 | 515 | def build_rnn(self, global_maps, init_particle_states, observations, odometries, is_first_step): 516 | """ 517 | Unroll the PF-net RNN cell through time. Input arguments are the inputs to PF-net. The time dependent 518 | fields are expected to be broken into fixed-length segments defined by params.bptt_steps 519 | """ 520 | batch_size, trajlen = observations.shape.as_list()[:2] 521 | num_particles = init_particle_states.shape.as_list()[1] 522 | global_map_ch = global_maps.shape.as_list()[-1] 523 | 524 | init_particle_weights = tf.constant(np.log(1.0/float(num_particles)), 525 | shape=(batch_size, num_particles), dtype=tf.float32) 526 | 527 | # create hidden state variable 528 | assert len(self.hidden_states) == 0 # no hidden state should be set before 529 | self.hidden_states = [ 530 | tf.get_variable("particle_states", shape=init_particle_states.get_shape(), 531 | dtype=init_particle_states.dtype, initializer=tf.constant_initializer(0), trainable=False), 532 | tf.get_variable("particle_weights", shape=init_particle_weights.get_shape(), 533 | dtype=init_particle_weights.dtype, initializer=tf.constant_initializer(0), trainable=False), 534 | ] 535 | 536 | # choose state for the current trajectory segment 537 | state = tf.cond(is_first_step, 538 | true_fn=lambda: (init_particle_states, init_particle_weights), 539 | false_fn=lambda: tuple(self.hidden_states)) 540 | 541 | with tf.variable_scope("rnn"): 542 | # hack to create variables on GPU 543 | dummy_cell_func = PFCell( 544 | global_maps=tf.zeros((1, 1, 1, global_map_ch), dtype=global_maps.dtype), 545 | params=self.params, batch_size=1, num_particles=1) 546 | 547 | dummy_cell_func( 548 | (tf.zeros([1]+observations.get_shape().as_list()[2:], dtype=observations.dtype), # observation 549 | tf.zeros([1, 3], dtype=odometries.dtype)), # odometry 550 | (tf.zeros([1, 1, 3], dtype=init_particle_states.dtype), # particle_states 551 | tf.zeros([1, 1], dtype=init_particle_weights.dtype))) # particle_weights 552 | 553 | # variables are now created. 
554 |             tf.get_variable_scope().reuse_variables()
555 |
556 |             # unroll real steps using the variables already created
557 |             cell_func = PFCell(global_maps=global_maps, params=self.params, batch_size=batch_size,
558 |                                num_particles=num_particles)
559 |
560 |             outputs, state = tf.nn.dynamic_rnn(cell=cell_func,
561 |                                                inputs=(observations, odometries),
562 |                                                initial_state=state,
563 |                                                swap_memory=True,
564 |                                                time_major=False,
565 |                                                parallel_iterations=1,
566 |                                                scope=tf.get_variable_scope())
567 |
568 |         particle_states, particle_weights = outputs
569 |
570 |         # define an op to update the hidden state, i.e. the particle states and particle weights.
571 |         # this should be evaluated after every input
572 |         with tf.control_dependencies([particle_states, particle_weights]):
573 |             self.update_state_op = tf.group(
574 |                 *(self.hidden_states[i].assign(state[i]) for i in range(len(self.hidden_states))))
575 |
576 |         return particle_states, particle_weights
577 |
578 |
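# A minimal training-step sketch (assumed wiring, following the pattern in
# evaluate.py; train.py is the actual entry point): after building
#   net = PFNet(inputs=data[1:], labels=data[0], params=params, is_training=True)
# each BPTT segment is processed by evaluating the train op together with
# update_state_op, so the particle set carries over to the next segment:
#   _, _ = sess.run([net.train_op, net.update_state_op])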
--------------------------------------------------------------------------------
/preprocess.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import os
6 | import tensorflow as tf
7 | import numpy as np
8 | import cv2
9 |
10 | from tensorpack import dataflow
11 | from tensorpack.dataflow.base import RNGDataFlow, ProxyDataFlow
12 |
13 | try:
14 |     import ipdb as pdb
15 | except Exception:
16 |     import pdb
17 |
18 |
19 | def decode_image(img_str, resize=None):
20 |     """
21 |     Decode image from tfrecord data
22 |     :param img_str: image encoded as a png in a string
23 |     :param resize: tuple with two elements that defines the new size of the image. optional
24 |     :return: image as a numpy array
25 |     """
26 |     nparr = np.fromstring(img_str, np.uint8)
27 |     img_str = cv2.imdecode(nparr, -1)
28 |     if resize is not None:
29 |         img_str = cv2.resize(img_str, resize)
30 |     return img_str
31 |
32 |
33 | def raw_images_to_array(images):
34 |     """
35 |     Decode and normalize multiple images from tfrecord data
36 |     :param images: list of images encoded as a png in a string
37 |     :return: a numpy array of size (N, 56, 56, channels), normalized for training
38 |     """
39 |     image_list = []
40 |     for image_str in images:
41 |         image = decode_image(image_str, (56, 56))
42 |         image = scale_observation(np.atleast_3d(image.astype(np.float32)))
43 |         image_list.append(image)
44 |
45 |     return np.stack(image_list, axis=0)
46 |
47 |
48 | def scale_observation(x):
49 |     """
50 |     Normalizes observation input, either an rgb image or a depth image
51 |     :param x: observation input as numpy array, either an rgb image or a depth image
52 |     :return: numpy array, a normalized observation
53 |     """
54 |     if x.ndim == 2 or x.shape[2] == 1:  # depth
55 |         return x * (2.0 / 100.0) - 1.0
56 |     else:  # rgb
57 |         return x * (2.0/255.0) - 1.0
58 |
59 |
60 | def bounding_box(img):
61 |     """
62 |     Bounding box of non-zeros in an array (inclusive). Used with 2D maps
63 |     :param img: numpy array
64 |     :return: inclusive bounding box indices: top_row, bottom_row, leftmost_column, rightmost_column
65 |     """
66 |     # find rows and columns with any non-zero element
67 |     rows = np.any(img, axis=1)
68 |     cols = np.any(img, axis=0)
69 |     rmin, rmax = np.where(rows)[0][[0, -1]]
70 |     cmin, cmax = np.where(cols)[0][[0, -1]]
71 |
72 |     return rmin, rmax, cmin, cmax
73 |
74 |
75 | class BatchDataWithPad(dataflow.BatchData):
76 |     """
77 |     Stacks datapoints into batches. Selected elements can be padded to the same size in each batch.
78 |     """
79 |
80 |     def __init__(self, ds, batch_size, remainder=False, use_list=False, padded_indices=()):
81 |         """
82 |         :param ds: input dataflow. Same as BatchData
83 |         :param batch_size: mini batch size. Same as BatchData
84 |         :param remainder: if data is not enough to form a full batch, it makes a smaller batch when true.
85 |         Same as BatchData.
86 |         :param use_list: if True, components will contain a list of datapoints instead of creating a new numpy array.
87 |         Same as BatchData.
88 |         :param padded_indices: list of field indices for which all elements will be padded with zeros to match
89 |         the largest in the batch. Each batch may produce a different size datapoint.
90 |         """
91 |         super(BatchDataWithPad, self).__init__(ds, batch_size, remainder, use_list)
92 |         self.padded_indices = padded_indices
93 |
94 |     def get_data(self):
95 |         """
96 |         Yields: Batched data by stacking each component on an extra 0th dimension.
97 |         """
98 |         holder = []
99 |         for data in self.ds.get_data():
100 |             holder.append(data)
101 |             if len(holder) == self.batch_size:
102 |                 yield BatchDataWithPad._aggregate_batch(holder, self.use_list, self.padded_indices)
103 |                 del holder[:]
104 |         if self.remainder and len(holder) > 0:
105 |             yield BatchDataWithPad._aggregate_batch(holder, self.use_list, self.padded_indices)
106 |
107 |     @staticmethod
108 |     def _aggregate_batch(data_holder, use_list=False, padded_indices=()):
109 |         """
110 |         Re-implement the parent function with the option to pad selected fields to the largest in the batch.
111 |         """
112 |         assert not use_list  # cannot match shape if they must be treated as lists
113 |         size = len(data_holder[0])
114 |         result = []
115 |         for k in range(size):
116 |             dt = data_holder[0][k]
117 |             if type(dt) in [int, bool]:
118 |                 tp = 'int32'
119 |             elif type(dt) == float:
120 |                 tp = 'float32'
121 |             else:
122 |                 try:
123 |                     tp = dt.dtype
124 |                 except AttributeError:
125 |                     raise TypeError("Unsupported type to batch: {}".format(type(dt)))
126 |             try:
127 |                 if k in padded_indices:
128 |                     # pad this field
129 |                     shapes = np.array([x[k].shape for x in data_holder], 'i')  # assumes ndim are the same for all
130 |                     assert shapes.shape[1] == 3  # only supports 3D arrays for now, e.g. images (height, width, ch)
131 |                     matching_shape = shapes.max(axis=0).tolist()
132 |                     new_data = np.zeros([shapes.shape[0]] + matching_shape, dtype=tp)
133 |                     for i in range(len(data_holder)):
134 |                         shape = data_holder[i][k].shape
135 |                         new_data[i, :shape[0], :shape[1], :shape[2]] = data_holder[i][k]
136 |                     result.append(new_data)
137 |                 else:
138 |                     # no need to pad this field, simply create batch
139 |                     result.append(np.asarray([x[k] for x in data_holder], dtype=tp))
140 |             except Exception as e:
141 |                 # exception handling. same as in parent class
142 |                 pdb.set_trace()
143 |                 dataflow.logger.exception("Cannot batch data. Perhaps they are of inconsistent shape?")
144 |                 if isinstance(dt, np.ndarray):
145 |                     s = dataflow.pprint.pformat([x[k].shape for x in data_holder])
146 |                     dataflow.logger.error("Shape of all arrays to be batched: " + s)
147 |                 try:
148 |                     # open an ipython shell if possible
149 |                     import IPython as IP; IP.embed()  # noqa
150 |                 except ImportError:
151 |                     pass
152 |         return result
153 |
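# For instance (hypothetical shapes): with padded_indices=(1,), two global maps
# of shapes (300, 250, 1) and (280, 320, 1) are zero-padded and stacked into a
# (2, 300, 320, 1) batch; the zeros act as obstacles after the 0..2 map
# rescaling in House3DTrajData below.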
159 | """ 160 | def __init__(self, ds, timed_indices, trajlen, bptt_steps): 161 | """ 162 | :param ds: input dataflow 163 | :param timed_indices: field indices for which the second dimension corresponds to timestep along the trajectory 164 | :param trajlen: full length of trajectories 165 | :param bptt_steps: segment length, number of backprop steps for BPTT. Must be an integer divisor of trajlen 166 | """ 167 | super(BreakForBPTT, self).__init__(ds) 168 | self.timed_indiced = timed_indices 169 | self.bptt_steps = bptt_steps 170 | 171 | assert trajlen % bptt_steps == 0 172 | self.num_segments = trajlen // bptt_steps 173 | 174 | def size(self): 175 | return self.ds.size() * self.num_segments 176 | 177 | def get_data(self): 178 | """ 179 | Yields multiple datapoints per input datapoints corresponding segments of the trajectory. 180 | Adds an extra field for indicating the first segment of a trajectory. 181 | """ 182 | 183 | for data in self.ds.get_data(): 184 | for split_i in range(self.num_segments): 185 | new_data = [] 186 | for i in range(len(data)): 187 | if i in self.timed_indiced: 188 | new_data.append(data[i][:, split_i*self.bptt_steps:(split_i+1)*self.bptt_steps]) 189 | else: 190 | new_data.append(data[i]) 191 | 192 | new_data.append((split_i == 0)) 193 | 194 | yield new_data 195 | 196 | 197 | class House3DTrajData(RNGDataFlow): 198 | """ 199 | Process tfrecords data of House3D trajectories. Produces a dataflow with the following fields: 200 | true state, global map, initial particles, observations, odometries 201 | """ 202 | 203 | def __init__(self, files, mapmode, obsmode, trajlen, num_particles, init_particles_distr, init_particles_cov, 204 | seed=None): 205 | """ 206 | :param files: list of data file names. assumed to be tfrecords files 207 | :param mapmode: string, map type. Possible values: wall / wall-door / wall-roomtype / wall-door-roomtype 208 | :param obsmode: string, observation type. Possible values: rgb / depth / rgb-depth. Vrf is not yet supported 209 | :param trajlen: int, length of trajectories 210 | :param num_particles: int, number of particles 211 | :param init_particles_distr: string, type of initial particle distribution. 212 | Possible values: tracking / one-room. Does not support two-rooms and all-rooms yet. 213 | :param init_particles_cov: numpy array of shape (3,3), coveriance matrix for the initial particles. Ignored 214 | when init_particles_distr != 'tracking'. 215 | :param seed: int or None. Random seed will be fixed if not None. 216 | """ 217 | self.files = files 218 | self.mapmode = mapmode 219 | self.obsmode = obsmode 220 | self.trajlen = trajlen 221 | self.num_particles = num_particles 222 | self.init_particles_distr = init_particles_distr 223 | self.init_particles_cov = init_particles_cov 224 | self.seed = seed 225 | 226 | # count total number of entries 227 | count = 0 228 | for f in self.files: 229 | if not os.path.isfile(f): 230 | raise ValueError('Failed to find file: ' + f) 231 | record_iterator = tf.python_io.tf_record_iterator(f) 232 | for _ in record_iterator: 233 | count += 1 234 | self.count = count 235 | 236 | def size(self): 237 | return self.count 238 | 239 | def reset_state(self): 240 | """ Reset state. Fix numpy random seed if needed.""" 241 | super(House3DTrajData, self).reset_state() 242 | if self.seed is not None: 243 | np.random.seed(1) 244 | else: 245 | np.random.seed(self.rng.randint(0, 99999999)) 246 | 247 | def get_data(self): 248 | """ 249 | Yields datapoints, all numpy arrays, with the following fields. 
196 |
197 | class House3DTrajData(RNGDataFlow):
198 |     """
199 |     Process tfrecords data of House3D trajectories. Produces a dataflow with the following fields:
200 |     true state, global map, initial particles, observations, odometries
201 |     """
202 |
203 |     def __init__(self, files, mapmode, obsmode, trajlen, num_particles, init_particles_distr, init_particles_cov,
204 |                  seed=None):
205 |         """
206 |         :param files: list of data file names. assumed to be tfrecords files
207 |         :param mapmode: string, map type. Possible values: wall / wall-door / wall-roomtype / wall-door-roomtype
208 |         :param obsmode: string, observation type. Possible values: rgb / depth / rgb-depth. Vrf is not yet supported
209 |         :param trajlen: int, length of trajectories
210 |         :param num_particles: int, number of particles
211 |         :param init_particles_distr: string, type of initial particle distribution.
212 |         Possible values: tracking / one-room. Does not support two-rooms and all-rooms yet.
213 |         :param init_particles_cov: numpy array of shape (3,3), covariance matrix for the initial particles. Ignored
214 |         when init_particles_distr != 'tracking'.
215 |         :param seed: int or None. Random seed will be fixed if not None.
216 |         """
217 |         self.files = files
218 |         self.mapmode = mapmode
219 |         self.obsmode = obsmode
220 |         self.trajlen = trajlen
221 |         self.num_particles = num_particles
222 |         self.init_particles_distr = init_particles_distr
223 |         self.init_particles_cov = init_particles_cov
224 |         self.seed = seed
225 |
226 |         # count total number of entries
227 |         count = 0
228 |         for f in self.files:
229 |             if not os.path.isfile(f):
230 |                 raise ValueError('Failed to find file: ' + f)
231 |             record_iterator = tf.python_io.tf_record_iterator(f)
232 |             for _ in record_iterator:
233 |                 count += 1
234 |         self.count = count
235 |
236 |     def size(self):
237 |         return self.count
238 |
239 |     def reset_state(self):
240 |         """ Reset state. Fix numpy random seed if needed."""
241 |         super(House3DTrajData, self).reset_state()
242 |         if self.seed is not None:
243 |             np.random.seed(1)
244 |         else:
245 |             np.random.seed(self.rng.randint(0, 99999999))
246 |
247 |     def get_data(self):
248 |         """
249 |         Yields datapoints, all numpy arrays, with the following fields.
250 |
251 |         true states: (trajlen, 3). Second dimension corresponds to x, y, theta coordinates.
252 |
253 |         global map: (n, m, ch). shape is different for each map. number of channels depends on the mapmode setting
254 |
255 |         initial particles: (num_particles, 3)
256 |
257 |         observations: (trajlen, 56, 56, ch) number of channels depends on the obsmode setting
258 |
259 |         odometries: (trajlen, 3) relative motion in the robot coordinate frame
260 |         """
261 |         for file in self.files:
262 |             gen = tf.python_io.tf_record_iterator(file)
263 |             for data_i, string_record in enumerate(gen):
264 |                 result = tf.train.Example.FromString(string_record)
265 |                 features = result.features.feature
266 |
267 |                 # process maps
268 |                 map_wall = self.process_wall_map(features['map_wall'].bytes_list.value[0])
269 |                 global_map_list = [map_wall]
270 |                 if 'door' in self.mapmode:
271 |                     map_door = self.process_door_map(features['map_door'].bytes_list.value[0])
272 |                     global_map_list.append(map_door)
273 |                 if 'roomtype' in self.mapmode:
274 |                     map_roomtype = self.process_roomtype_map(features['map_roomtype'].bytes_list.value[0])
275 |                     global_map_list.append(map_roomtype)
276 |                 if self.init_particles_distr == 'tracking':
277 |                     map_roomid = None
278 |                 else:
279 |                     map_roomid = self.process_roomid_map(features['map_roomid'].bytes_list.value[0])
280 |
281 |                 # input global map is a concatenation of semantic channels
282 |                 global_map = np.concatenate(global_map_list, axis=-1)
283 |
284 |                 # rescale to 0..2 range. this way zero padding will produce the equivalent of obstacles
285 |                 global_map = global_map.astype(np.float32) * (2.0 / 255.0)
286 |
287 |                 # process true states
288 |                 true_states = features['states'].bytes_list.value[0]
289 |                 true_states = np.frombuffer(true_states, np.float32).reshape((-1, 3))
290 |
291 |                 # trajectory may be longer than what we use for training
292 |                 data_trajlen = true_states.shape[0]
293 |                 assert data_trajlen >= self.trajlen
294 |                 true_states = true_states[:self.trajlen]
295 |
296 |                 # process odometry
297 |                 odometry = features['odometry'].bytes_list.value[0]
298 |                 odometry = np.frombuffer(odometry, np.float32).reshape((-1, 3))
299 |
300 |                 # process observations
301 |                 assert self.obsmode in ['rgb', 'depth', 'rgb-depth']  # TODO support for lidar
302 |                 if 'rgb' in self.obsmode:
303 |                     rgb = raw_images_to_array(list(features['rgb'].bytes_list.value)[:self.trajlen])
304 |                     observation = rgb
305 |                 if 'depth' in self.obsmode:
306 |                     depth = raw_images_to_array(list(features['depth'].bytes_list.value)[:self.trajlen])
307 |                     observation = depth
308 |                 if self.obsmode == 'rgb-depth':
309 |                     observation = np.concatenate((rgb, depth), axis=-1)
310 |
311 |                 # generate particle states
312 |                 init_particles = self.random_particles(true_states[0], self.init_particles_distr,
313 |                                                        self.init_particles_cov, self.num_particles,
314 |                                                        roomidmap=map_roomid,
315 |                                                        seed=self.get_sample_seed(self.seed, data_i), )
316 |
317 |                 yield (true_states, global_map, init_particles, observation, odometry)
318 |
319 |     def process_wall_map(self, wallmap_feature):
320 |         floormap = np.atleast_3d(decode_image(wallmap_feature))
321 |         # transpose and invert
322 |         floormap = 255 - np.transpose(floormap, axes=[1, 0, 2])
323 |         return floormap
324 |
325 |     def process_door_map(self, doormap_feature):
326 |         return self.process_wall_map(doormap_feature)
327 |
328 |     def process_roomtype_map(self, roomtypemap_feature):
329 |         binary_map = np.fromstring(roomtypemap_feature, np.uint8)
330 |         binary_map = cv2.imdecode(binary_map, 2)  # 16-bit image
331 |         assert binary_map.dtype == np.uint16 and binary_map.ndim == 2
332 |         # binary encoding from bit 0 .. 9
333 | 
334 |         room_map = np.zeros((binary_map.shape[0], binary_map.shape[1], 9), dtype=np.uint8)
335 |         for i in range(9):
336 |             room_map[:, :, i] = np.array((np.bitwise_and(binary_map, (1 << i)) > 0), dtype=np.uint8)
337 |         room_map *= 255
338 | 
339 |         # transpose and invert
340 |         room_map = np.transpose(room_map, axes=[1, 0, 2])
341 |         return room_map
342 | 
343 |     def process_roomid_map(self, roomidmap_feature):
344 |         # this is not transposed, unlike other maps
345 |         roomidmap = np.atleast_3d(decode_image(roomidmap_feature))
346 |         return roomidmap
347 | 
348 |     @staticmethod
349 |     def random_particles(state, distr, particles_cov, num_particles, roomidmap, seed=None):
350 |         """
351 |         Generate a random set of particles
352 |         :param state: true state, numpy array of x,y,theta coordinates
353 |         :param distr: string, type of distribution. Possible values: tracking / one-room.
354 |             For 'tracking' the distribution is a Gaussian centered near the true state.
355 |             For 'one-room' the distribution is uniform over states in the room defined by the true state.
356 |         :param particles_cov: numpy array of shape (3,3), defines the covariance matrix if distr == 'tracking'
357 |         :param num_particles: number of particles
358 |         :param roomidmap: numpy array, map of room ids. Values define a unique room id for each pixel of the map.
359 |         :param seed: int or None. If not None, the random seed will be fixed for generating the particles.
360 |             The random state is restored to its original value afterwards.
361 |         :return: numpy array of particles (num_particles, 3)
362 |         """
363 |         assert distr in ["tracking", "one-room"]  # TODO add support for two-room and all-room
364 | 
365 |         particles = np.zeros((num_particles, 3), np.float32)
366 | 
367 |         if distr == "tracking":
368 |             # fix seed
369 |             if seed is not None:
370 |                 random_state = np.random.get_state()
371 |                 np.random.seed(seed)
372 | 
373 |             # sample offset from the Gaussian
374 |             center = np.random.multivariate_normal(mean=state, cov=particles_cov)
375 | 
376 |             # restore random seed
377 |             if seed is not None:
378 |                 np.random.set_state(random_state)
379 | 
380 |             # sample particles from the Gaussian, centered around the offset
381 |             particles = np.random.multivariate_normal(mean=center, cov=particles_cov, size=num_particles)
382 | 
383 |         elif distr == "one-room":
384 |             # mask the room the initial state is in
385 |             masked_map = (roomidmap == roomidmap[int(np.rint(state[0])), int(np.rint(state[1]))])
386 | 
387 |             # get bounding box for more efficient sampling
388 |             rmin, rmax, cmin, cmax = bounding_box(masked_map)
389 | 
390 |             # rejection sampling inside bounding box
391 |             sample_i = 0
392 |             while sample_i < num_particles:
393 |                 particle = np.random.uniform(low=(rmin, cmin, 0.0), high=(rmax, cmax, 2.0*np.pi), size=(3, ),)
394 |                 # reject if mask is zero
395 |                 if not masked_map[int(np.rint(particle[0])), int(np.rint(particle[1]))]:
396 |                     continue
397 |                 particles[sample_i] = particle
398 |                 sample_i += 1
399 |         else:
400 |             raise ValueError('Unknown distribution: ' + distr)
401 | 
402 |         return particles
403 | 
404 |     @staticmethod
405 |     def get_sample_seed(seed, data_i):
406 |         """
407 |         Defines a random seed for each datapoint in a deterministic manner.
408 |         :param seed: int or None, defining a random seed
409 |         :param data_i: int, the index of the current data point
410 |         :return: None if seed is None, otherwise an int, a fixed function of both seed and data_i inputs, e.g. seed=7 and data_i=2 give (2 + 1) * 113 + 7 = 346.
411 | """ 412 | return (None if (seed is None or seed == 0) else ((data_i + 1) * 113 + seed)) 413 | 414 | 415 | def get_dataflow(files, params, is_training): 416 | """ 417 | Build a tensorflow Dataset from appropriate tfrecords files. 418 | :param files: list a file paths corresponding to appropriate tfrecords data 419 | :param params: parsed arguments 420 | :param is_training: bool, true for training. 421 | :return: (nextdata, num_samples). 422 | nextdata: list of tensorflow ops that produce the next input with the following elements: 423 | true_states, global_map, init_particles, observations, odometries, is_first_step. 424 | See House3DTrajData.get_data for definitions. 425 | num_samples: number of samples that make an epoch 426 | """ 427 | 428 | mapmode = params.mapmode 429 | obsmode = params.obsmode 430 | batchsize = params.batchsize 431 | num_particles = params.num_particles 432 | trajlen = params.trajlen 433 | bptt_steps = params.bptt_steps 434 | 435 | # build initial covariance matrix of particles, in pixels and radians 436 | particle_std = params.init_particles_std.copy() 437 | particle_std[0] = particle_std[0] / params.map_pixel_in_meters # convert meters to pixels 438 | particle_std2 = np.square(particle_std) # variance 439 | init_particles_cov = np.diag(particle_std2[(0, 0, 1),]) 440 | 441 | df = House3DTrajData(files, mapmode, obsmode, trajlen, num_particles, 442 | params.init_particles_distr, init_particles_cov, 443 | seed=(params.seed if params.seed is not None and params.seed > 0 444 | else (params.validseed if not is_training else None))) 445 | # data: true_states, global_map, init_particles, observation, odometry 446 | 447 | # make it a multiple of batchsize 448 | df = dataflow.FixedSizeData(df, size=(df.size() // batchsize) * batchsize, keep_state=False) 449 | 450 | # shuffle 451 | if is_training: 452 | df = dataflow.LocallyShuffleData(df, 100 * batchsize) 453 | 454 | # repeat data for the number of epochs 455 | df = dataflow.RepeatedData(df, params.epochs) 456 | 457 | # batch 458 | df = BatchDataWithPad(df, batchsize, padded_indices=(1,)) 459 | 460 | # break trajectory into multiple segments for BPPT training. 
Augment df with is_first_step indicator 461 | df = BreakForBPTT(df, timed_indices=(0, 3, 4), trajlen=trajlen, bptt_steps=bptt_steps) 462 | # data: true_states, global_map, init_particles, observation, odometry, is_first_step 463 | 464 | num_samples = df.size() // params.epochs 465 | 466 | df.reset_state() 467 | 468 | # # test dataflow 469 | # df = dataflow.TestDataSpeed(dataflow.PrintData(df), 100) 470 | # df.start() 471 | 472 | obs_ch = {'rgb': 3, 'depth': 1, 'rgb-depth': 4} 473 | map_ch = {'wall': 1, 'wall-door': 2, 'wall-roomtype': 10, 'wall-door-roomtype': 11} 474 | types = [tf.float32, tf.float32, tf.float32, tf.float32, tf.float32, tf.bool] 475 | sizes = [(batchsize, bptt_steps, 3), 476 | (batchsize, None, None, map_ch[mapmode]), 477 | (batchsize, num_particles, 3), 478 | (batchsize, bptt_steps, 56, 56, obs_ch[obsmode]), 479 | (batchsize, bptt_steps, 3), 480 | (), ] 481 | 482 | # turn it into a tf dataset 483 | def tuplegen(): 484 | for dp in df.get_data(): 485 | yield tuple(dp) 486 | 487 | dataset = tf.data.Dataset.from_generator(tuplegen, tuple(types), tuple(sizes)) 488 | iterator = dataset.make_one_shot_iterator() 489 | nextdata = iterator.get_next() 490 | 491 | return nextdata, num_samples 492 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ConfigArgParse==0.12.0 2 | numpy==1.14.2 3 | opencv_python==3.3.0.10 4 | matplotlib==2.1.0 5 | tensorpack==0.8.5 6 | tqdm==4.19.8 7 | six==1.11.0 8 | ipdb==0.10.3 -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os, tqdm 6 | import tensorflow as tf 7 | import numpy as np 8 | from datetime import datetime 9 | 10 | import pfnet 11 | from arguments import parse_args 12 | from preprocess import get_dataflow 13 | 14 | try: 15 | import ipdb as pdb 16 | except Exception: 17 | import pdb 18 | 19 | 20 | def validation(sess, brain, num_samples, params): 21 | """ 22 | Run validation 23 | :param sess: tensorflow session 24 | :param brain: network object that provides loss and update ops, and functions to save and restore the hidden state. 25 | :param num_samples: int, number of samples in the validation set. 26 | :param params: parsed arguments 27 | :return: validation loss, averaged over the validation set 28 | """ 29 | 30 | fix_seed = (params.validseed is not None and params.validseed >= 0) 31 | if fix_seed: 32 | np_random_state = np.random.get_state() 33 | np.random.seed(params.validseed) 34 | tf.set_random_seed(params.validseed) 35 | 36 | saved_state = brain.save_state(sess) 37 | 38 | total_loss = 0.0 39 | try: 40 | for eval_i in tqdm.tqdm(range(num_samples), desc="Validation"): 41 | loss, _ = sess.run([brain.valid_loss_op, brain.update_state_op]) 42 | total_loss += loss 43 | 44 | print ("Validation loss = %f"%(total_loss/num_samples)) 45 | 46 | except tf.errors.OutOfRangeError: 47 | print ("No more samples for evaluation. 
This should not happen") 48 | raise 49 | 50 | brain.load_state(sess, saved_state) 51 | 52 | # restore seed 53 | if fix_seed: 54 | np.random.set_state(np_random_state) 55 | tf.set_random_seed(np.random.randint(999999)) # cannot save tf seed, so generate random one from numpy 56 | 57 | return total_loss 58 | 59 | 60 | def run_training(params): 61 | """ Run training with the parsed arguments """ 62 | 63 | with tf.Graph().as_default(): 64 | if params.seed is not None: 65 | tf.set_random_seed(params.seed) 66 | 67 | # training data and network 68 | with tf.variable_scope(tf.get_variable_scope(), reuse=False): 69 | train_data, num_train_samples = get_dataflow(params.trainfiles, params, is_training=True) 70 | train_brain = pfnet.PFNet(inputs=train_data[1:], labels=train_data[0], params=params, is_training=True) 71 | 72 | # test data and network 73 | with tf.variable_scope(tf.get_variable_scope(), reuse=True): 74 | test_data, num_test_samples = get_dataflow(params.testfiles, params, is_training=False) 75 | test_brain = pfnet.PFNet(inputs=test_data[1:], labels=test_data[0], params=params, is_training=False) 76 | 77 | # Add the variable initializer Op. 78 | init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) 79 | 80 | # Create a saver for writing training checkpoints. 81 | saver = tf.train.Saver(var_list=tf.trainable_variables(), max_to_keep=3) 82 | 83 | # Create a session for running Ops on the Graph. 84 | os.environ["CUDA_VISIBLE_DEVICES"] = "%d"%int(params.gpu) 85 | sess_config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False) 86 | sess_config.gpu_options.allow_growth = True 87 | 88 | # training session 89 | with tf.Session(config=sess_config) as sess: 90 | sess.run(init_op) 91 | 92 | # load model from checkpoint file 93 | if params.load: 94 | print("Loading model from " + params.load) 95 | saver.restore(sess, params.load) 96 | 97 | coord = tf.train.Coordinator() 98 | threads = tf.train.start_queue_runners(coord=coord) 99 | 100 | try: 101 | decay_step = 0 102 | 103 | # repeat for a fixed number of epochs 104 | for epoch_i in range(params.epochs): 105 | epoch_loss = 0.0 106 | periodic_loss = 0.0 107 | 108 | # run training over all samples in an epoch 109 | for step_i in tqdm.tqdm(range(num_train_samples)): 110 | _, loss, _ = sess.run([train_brain.train_op, train_brain.train_loss_op, 111 | train_brain.update_state_op]) 112 | periodic_loss += loss 113 | epoch_loss += loss 114 | 115 | # print accumulated loss after every few hundred steps 116 | if step_i > 0 and (step_i % 500) == 0: 117 | tqdm.tqdm.write("Epoch %d, step %d. Training loss = %f" % (epoch_i + 1, step_i, periodic_loss / 500.0)) 118 | periodic_loss = 0.0 119 | 120 | # print the avarage loss over the epoch 121 | tqdm.tqdm.write("Epoch %d done. Average training loss = %f" % (epoch_i + 1, epoch_loss / num_train_samples)) 122 | 123 | # save model, validate and decrease learning rate after each epoch 124 | saver.save(sess, os.path.join(params.logpath, 'model.chk'), global_step=epoch_i + 1) 125 | 126 | # run validation 127 | validation(sess, test_brain, num_samples=num_test_samples, params=params) 128 | 129 | # decay learning rate 130 | if epoch_i + 1 % params.decaystep == 0: 131 | decay_step += 1 132 | current_global_step = sess.run(tf.assign(train_brain.global_step_op, decay_step)) 133 | current_learning_rate = sess.run(train_brain.learning_rate_op) 134 | tqdm.tqdm.write("Decreased learning rate to %f." 
135 | 
136 |             except KeyboardInterrupt:
137 |                 pass
138 | 
139 |             except tf.errors.OutOfRangeError:
140 |                 print("data exhausted")
141 | 
142 |             finally:
143 |                 saver.save(sess, os.path.join(params.logpath, 'final.chk'))  # don't pass global step
144 |                 coord.request_stop()
145 | 
146 |             coord.join(threads)
147 | 
148 |     print ("Training done. Model is saved to %s"%(params.logpath))
149 | 
150 | 
151 | if __name__ == '__main__':
152 |     params = parse_args()
153 | 
154 |     params.logpath = os.path.join(params.logpath, "log-" + datetime.now().strftime('%m%d-%H-%M-%S'))
155 |     os.mkdir(params.logpath)
156 | 
157 |     run_training(params)
--------------------------------------------------------------------------------
/transformer/README.md:
--------------------------------------------------------------------------------
1 | # Spatial Transformer Network
2 | 
3 | The Spatial Transformer Network [1] allows the spatial manipulation of data within the network.
4 | 
8 | 
9 | ### API
10 | 
11 | A Spatial Transformer Network implemented in Tensorflow 1.0 and based on [2].
12 | 
13 | #### How to use
14 | 
18 | 
19 | ```python
20 | transformer(U, theta, out_size)
21 | ```
22 | 
23 | #### Parameters
24 | 
25 |     U : float
26 |         The output of a convolutional net should have the
27 |         shape [num_batch, height, width, num_channels].
28 |     theta: float
29 |         The output of the
30 |         localisation network should be [num_batch, 6].
31 |     out_size: tuple of two ints
32 |         The size of the output of the network (height, width)
33 | 
34 | 
35 | #### Notes
36 | To initialize the network to the identity transform, initialize ``theta`` to:
37 | 
38 | ```python
39 | identity = np.array([[1., 0., 0.],
40 |                      [0., 1., 0.]])
41 | identity = identity.flatten()
42 | theta = tf.Variable(initial_value=identity)
43 | ```
44 | 
45 | #### Experiments
46 | 
50 | 
51 | We used cluttered MNIST. The left column shows the input images; the right column shows the parts of the image attended to by the STN.
52 | 
53 | All experiments were run in Tensorflow 0.7.
54 | 
55 | ### References
56 | 
57 | [1] Jaderberg, Max, et al. "Spatial Transformer Networks." arXiv preprint arXiv:1506.02025 (2015)
58 | 
59 | [2] https://github.com/skaae/transformer_network/blob/master/transformerlayer.py
--------------------------------------------------------------------------------
/transformer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdaCompNUS/pfnet/861b398c58574cc3e415896f3dd278a76cb2b383/transformer/__init__.py
--------------------------------------------------------------------------------
/transformer/cluttered_mnist.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # =============================================================================
15 | import tensorflow as tf
16 | from spatial_transformer import transformer
17 | import numpy as np
18 | from tf_utils import weight_variable, bias_variable, dense_to_one_hot
19 | 
20 | # %% Load data
21 | mnist_cluttered = np.load('./data/mnist_sequence1_sample_5distortions5x5.npz')
22 | 
23 | X_train = mnist_cluttered['X_train']
24 | y_train = mnist_cluttered['y_train']
25 | X_valid = mnist_cluttered['X_valid']
26 | y_valid = mnist_cluttered['y_valid']
27 | X_test = mnist_cluttered['X_test']
28 | y_test = mnist_cluttered['y_test']
29 | 
30 | # % turn from dense to one hot representation
31 | Y_train = dense_to_one_hot(y_train, n_classes=10)
32 | Y_valid = dense_to_one_hot(y_valid, n_classes=10)
33 | Y_test = dense_to_one_hot(y_test, n_classes=10)
34 | 
35 | # %% Graph representation of our network
36 | 
37 | # %% Placeholders for 40x40 resolution
38 | x = tf.placeholder(tf.float32, [None, 1600])
39 | y = tf.placeholder(tf.float32, [None, 10])
40 | 
41 | # %% Since x is currently [batch, height*width], we need to reshape to a
42 | # 4-D tensor to use it in a convolutional graph. If one component of
43 | # `shape` is the special value -1, the size of that dimension is
44 | # computed so that the total size remains constant. Since we haven't
45 | # defined the batch dimension's shape yet, we use -1 to denote this
46 | # dimension should not change size.
47 | x_tensor = tf.reshape(x, [-1, 40, 40, 1])
48 | 
49 | # %% We'll set up the two-layer localisation network to figure out the
50 | # %% parameters for an affine transformation of the input
51 | # %% Create variables for fully connected layer
52 | W_fc_loc1 = weight_variable([1600, 20])
53 | b_fc_loc1 = bias_variable([20])
54 | 
55 | W_fc_loc2 = weight_variable([20, 6])
56 | # Use identity transformation as starting point
57 | initial = np.array([[1., 0, 0], [0, 1., 0]])
58 | initial = initial.astype('float32')
59 | initial = initial.flatten()
60 | b_fc_loc2 = tf.Variable(initial_value=initial, name='b_fc_loc2')
61 | 
62 | # %% Define the two layer localisation network
63 | h_fc_loc1 = tf.nn.tanh(tf.matmul(x, W_fc_loc1) + b_fc_loc1)
64 | # %% We can add dropout for regularizing and to reduce overfitting like so:
65 | keep_prob = tf.placeholder(tf.float32)
66 | h_fc_loc1_drop = tf.nn.dropout(h_fc_loc1, keep_prob)
67 | # %% Second layer
68 | h_fc_loc2 = tf.nn.tanh(tf.matmul(h_fc_loc1_drop, W_fc_loc2) + b_fc_loc2)
69 | 
70 | # %% We'll create a spatial transformer module to identify discriminative
71 | # %% patches
72 | out_size = (40, 40)
73 | h_trans = transformer(x_tensor, h_fc_loc2, out_size)
74 | 
75 | # %% We'll set up the first convolutional layer
76 | # Weight matrix is [height x width x input_channels x output_channels]
77 | filter_size = 3
78 | n_filters_1 = 16
79 | W_conv1 = weight_variable([filter_size, filter_size, 1, n_filters_1])
80 | 
81 | # %% Bias is [output_channels]
82 | b_conv1 = bias_variable([n_filters_1])
83 | 
84 | # %% Now we can build a graph which does the first layer of convolution:
85 | # we define our stride as batch x height x width x channels
86 | # instead of pooling, we use strides of 2 and more layers
87 | # with smaller filters.
88 | 
89 | h_conv1 = tf.nn.relu(
90 |     tf.nn.conv2d(input=h_trans,
91 |                  filter=W_conv1,
92 |                  strides=[1, 2, 2, 1],
93 |                  padding='SAME') +
94 |     b_conv1)
95 | 
96 | # %% And just like the first layer, add additional layers to create
97 | # a deep net
98 | n_filters_2 = 16
99 | W_conv2 = weight_variable([filter_size, filter_size, n_filters_1, n_filters_2])
100 | b_conv2 = bias_variable([n_filters_2])
101 | h_conv2 = tf.nn.relu(
102 |     tf.nn.conv2d(input=h_conv1,
103 |                  filter=W_conv2,
104 |                  strides=[1, 2, 2, 1],
105 |                  padding='SAME') +
106 |     b_conv2)
107 | 
108 | # %% We'll now reshape so we can connect to a fully-connected layer:
109 | h_conv2_flat = tf.reshape(h_conv2, [-1, 10 * 10 * n_filters_2])
110 | 
111 | # %% Create a fully-connected layer:
112 | n_fc = 1024
113 | W_fc1 = weight_variable([10 * 10 * n_filters_2, n_fc])
114 | b_fc1 = bias_variable([n_fc])
115 | h_fc1 = tf.nn.relu(tf.matmul(h_conv2_flat, W_fc1) + b_fc1)
116 | 
117 | h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)
118 | 
119 | # %% And finally our softmax layer:
120 | W_fc2 = weight_variable([n_fc, 10])
121 | b_fc2 = bias_variable([10])
122 | y_logits = tf.matmul(h_fc1_drop, W_fc2) + b_fc2
123 | 
124 | # %% Define loss/eval/training functions
125 | cross_entropy = tf.reduce_mean(
126 |     tf.nn.softmax_cross_entropy_with_logits(logits=y_logits, labels=y))
127 | opt = tf.train.AdamOptimizer()
128 | optimizer = opt.minimize(cross_entropy)
129 | grads = opt.compute_gradients(cross_entropy, [b_fc_loc2])
130 | 
131 | # %% Monitor accuracy
132 | correct_prediction = tf.equal(tf.argmax(y_logits, 1), tf.argmax(y, 1))
133 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, 'float'))
134 | 
135 | # %% We now create a new session to actually perform the initialization of
136 | # the variables:
137 | sess = tf.Session()
138 | sess.run(tf.global_variables_initializer())
139 | 
140 | 
141 | # %% We'll now train in minibatches and report accuracy, loss:
142 | iter_per_epoch = 100
143 | n_epochs = 500
144 | train_size = 10000
145 | 
146 | indices = np.linspace(0, 10000 - 1, iter_per_epoch)
147 | indices = indices.astype('int')
148 | 
149 | for epoch_i in range(n_epochs):
150 |     for iter_i in range(iter_per_epoch - 1):
151 |         batch_xs = X_train[indices[iter_i]:indices[iter_i+1]]
152 |         batch_ys = Y_train[indices[iter_i]:indices[iter_i+1]]
153 | 
154 |         if iter_i % 10 == 0:
155 |             loss = sess.run(cross_entropy,
156 |                             feed_dict={
157 |                                 x: batch_xs,
158 |                                 y: batch_ys,
159 |                                 keep_prob: 1.0
160 |                             })
161 |             print('Iteration: ' + str(iter_i) + ' Loss: ' + str(loss))
162 | 
163 |         sess.run(optimizer, feed_dict={
164 |             x: batch_xs, y: batch_ys, keep_prob: 0.8})
165 | 
166 |     print('Accuracy (%d): ' % epoch_i + str(sess.run(accuracy,
167 |                                                      feed_dict={
168 |                                                          x: X_valid,
169 |                                                          y: Y_valid,
170 |                                                          keep_prob: 1.0
171 |                                                      })))
172 |     # theta = sess.run(h_fc_loc2, feed_dict={
173 |     #     x: batch_xs, keep_prob: 1.0})
174 |     # print(theta[0])
--------------------------------------------------------------------------------
/transformer/data/README.md:
--------------------------------------------------------------------------------
1 | ### How to get the data
2 | 
3 | #### Cluttered MNIST
4 | 
5 | The cluttered MNIST dataset can be found here [1] or can be generated via [2].
6 | 
7 | Settings used for `cluttered_mnist.py`:
8 | 
9 | ```python
10 | 
11 | ORG_SHP = [28, 28]
12 | OUT_SHP = [40, 40]
13 | NUM_DISTORTIONS = 8
14 | dist_size = (5, 5)
15 | 
16 | ```
17 | 
18 | [1] https://github.com/daviddao/spatial-transformer-tensorflow
19 | 
20 | [2] https://github.com/skaae/recurrent-spatial-transformer-code/blob/master/MNIST_SEQUENCE/create_mnist_sequence.py
--------------------------------------------------------------------------------
/transformer/example.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | from scipy import ndimage
16 | import tensorflow as tf
17 | from spatial_transformer import transformer
18 | import numpy as np
19 | import matplotlib.pyplot as plt
20 | 
21 | # %% Create a batch of three images (1600 x 1200)
22 | # %% Image retrieved from:
23 | # %% https://raw.githubusercontent.com/skaae/transformer_network/master/cat.jpg
24 | im = ndimage.imread('cat.jpg')
25 | im = im / 255.
26 | im = im.reshape(1, 1200, 1600, 3)
27 | im = im.astype('float32')
28 | 
29 | # %% Let the output size of the transformer be half the image size.
30 | out_size = (600, 800)
31 | 
32 | # %% Simulate batch
33 | batch = np.append(im, im, axis=0)
34 | batch = np.append(batch, im, axis=0)
35 | num_batch = 3
36 | 
37 | x = tf.placeholder(tf.float32, [None, 1200, 1600, 3])
38 | # (the batch is fed through feed_dict below, so x must stay a placeholder)
39 | 
40 | # %% Create localisation network and convolutional layer
41 | with tf.variable_scope('spatial_transformer_0'):
42 | 
43 |     # %% Create a fully-connected layer with 6 output nodes
44 |     n_fc = 6
45 |     W_fc1 = tf.Variable(tf.zeros([1200 * 1600 * 3, n_fc]), name='W_fc1')
46 | 
47 |     # %% Zoom into the image
48 |     initial = np.array([[0.5, 0, 0], [0, 0.5, 0]])
49 |     initial = initial.astype('float32')
50 |     initial = initial.flatten()
51 | 
52 |     b_fc1 = tf.Variable(initial_value=initial, name='b_fc1')
53 |     h_fc1 = tf.matmul(tf.zeros([num_batch, 1200 * 1600 * 3]), W_fc1) + b_fc1
54 |     h_trans = transformer(x, h_fc1, out_size)
55 | 
56 | # %% Run session
57 | sess = tf.Session()
58 | sess.run(tf.global_variables_initializer())
59 | y = sess.run(h_trans, feed_dict={x: batch})
60 | 
61 | # plt.imshow(y[0])
--------------------------------------------------------------------------------
/transformer/spatial_transformer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | from six.moves import xrange 17 | import tensorflow as tf 18 | 19 | 20 | def transformer(U, theta, out_size, name='SpatialTransformer', **kwargs): 21 | """Spatial Transformer Layer 22 | 23 | Implements a spatial transformer layer as described in [1]_. 24 | Based on [2]_ and edited by David Dao for Tensorflow. 25 | 26 | Parameters 27 | ---------- 28 | U : float 29 | The output of a convolutional net should have the 30 | shape [num_batch, height, width, num_channels]. 31 | theta: float 32 | The output of the 33 | localisation network should be [num_batch, 6]. 34 | out_size: tuple of two ints 35 | The size of the output of the network (height, width) 36 | 37 | References 38 | ---------- 39 | .. [1] Spatial Transformer Networks 40 | Max Jaderberg, Karen Simonyan, Andrew Zisserman, Koray Kavukcuoglu 41 | Submitted on 5 Jun 2015 42 | .. [2] https://github.com/skaae/transformer_network/blob/master/transformerlayer.py 43 | 44 | Notes 45 | ----- 46 | To initialize the network to the identity transform init 47 | ``theta`` to : 48 | identity = np.array([[1., 0., 0.], 49 | [0., 1., 0.]]) 50 | identity = identity.flatten() 51 | theta = tf.Variable(initial_value=identity) 52 | 53 | """ 54 | 55 | def _repeat(x, n_repeats): 56 | with tf.variable_scope('_repeat'): 57 | rep = tf.transpose( 58 | tf.expand_dims(tf.ones(shape=tf.stack([n_repeats, ])), 1), [1, 0]) 59 | rep = tf.cast(rep, 'int32') 60 | x = tf.matmul(tf.reshape(x, (-1, 1)), rep) 61 | return tf.reshape(x, [-1]) 62 | 63 | def _interpolate(im, x, y, out_size): 64 | with tf.variable_scope('_interpolate'): 65 | # constants 66 | num_batch = tf.shape(im)[0] 67 | height = tf.shape(im)[1] 68 | width = tf.shape(im)[2] 69 | channels = tf.shape(im)[3] 70 | 71 | x = tf.cast(x, 'float32') 72 | y = tf.cast(y, 'float32') 73 | height_f = tf.cast(height, 'float32') 74 | width_f = tf.cast(width, 'float32') 75 | out_height = out_size[0] 76 | out_width = out_size[1] 77 | zero = tf.zeros([], dtype='int32') 78 | max_y = tf.cast(tf.shape(im)[1] - 1, 'int32') 79 | max_x = tf.cast(tf.shape(im)[2] - 1, 'int32') 80 | 81 | # scale indices from [-1, 1] to [0, width/height] 82 | x = (x + 1.0)*(width_f) / 2.0 83 | y = (y + 1.0)*(height_f) / 2.0 84 | 85 | # do sampling 86 | x0 = tf.cast(tf.floor(x), 'int32') 87 | x1 = x0 + 1 88 | y0 = tf.cast(tf.floor(y), 'int32') 89 | y1 = y0 + 1 90 | 91 | x0 = tf.clip_by_value(x0, zero, max_x) 92 | x1 = tf.clip_by_value(x1, zero, max_x) 93 | y0 = tf.clip_by_value(y0, zero, max_y) 94 | y1 = tf.clip_by_value(y1, zero, max_y) 95 | dim2 = width 96 | dim1 = width*height 97 | base = _repeat(tf.range(num_batch)*dim1, out_height*out_width) 98 | base_y0 = base + y0*dim2 99 | base_y1 = base + y1*dim2 100 | idx_a = base_y0 + x0 101 | idx_b = base_y1 + x0 102 | idx_c = base_y0 + x1 103 | idx_d = base_y1 + x1 104 | 105 | # use indices to lookup pixels in the flat image and restore 106 | # channels dim 107 | im_flat = tf.reshape(im, tf.stack([-1, channels])) 108 | im_flat = tf.cast(im_flat, 'float32') 109 | Ia = 
tf.gather(im_flat, idx_a) 110 | Ib = tf.gather(im_flat, idx_b) 111 | Ic = tf.gather(im_flat, idx_c) 112 | Id = tf.gather(im_flat, idx_d) 113 | 114 | # and finally calculate interpolated values 115 | x0_f = tf.cast(x0, 'float32') 116 | x1_f = tf.cast(x1, 'float32') 117 | y0_f = tf.cast(y0, 'float32') 118 | y1_f = tf.cast(y1, 'float32') 119 | wa = tf.expand_dims(((x1_f-x) * (y1_f-y)), 1) 120 | wb = tf.expand_dims(((x1_f-x) * (y-y0_f)), 1) 121 | wc = tf.expand_dims(((x-x0_f) * (y1_f-y)), 1) 122 | wd = tf.expand_dims(((x-x0_f) * (y-y0_f)), 1) 123 | output = tf.add_n([wa*Ia, wb*Ib, wc*Ic, wd*Id]) 124 | return output 125 | 126 | def _meshgrid(height, width): 127 | with tf.variable_scope('_meshgrid'): 128 | # This should be equivalent to: 129 | # x_t, y_t = np.meshgrid(np.linspace(-1, 1, width), 130 | # np.linspace(-1, 1, height)) 131 | # ones = np.ones(np.prod(x_t.shape)) 132 | # grid = np.vstack([x_t.flatten(), y_t.flatten(), ones]) 133 | x_t = tf.matmul(tf.ones(shape=tf.stack([height, 1])), 134 | tf.transpose(tf.expand_dims(tf.linspace(-1.0, 1.0, width), 1), [1, 0])) 135 | y_t = tf.matmul(tf.expand_dims(tf.linspace(-1.0, 1.0, height), 1), 136 | tf.ones(shape=tf.stack([1, width]))) 137 | 138 | x_t_flat = tf.reshape(x_t, (1, -1)) 139 | y_t_flat = tf.reshape(y_t, (1, -1)) 140 | 141 | ones = tf.ones_like(x_t_flat) 142 | grid = tf.concat(axis=0, values=[x_t_flat, y_t_flat, ones]) 143 | return grid 144 | 145 | def _transform(theta, input_dim, out_size): 146 | with tf.variable_scope('_transform'): 147 | num_batch = tf.shape(input_dim)[0] 148 | height = tf.shape(input_dim)[1] 149 | width = tf.shape(input_dim)[2] 150 | num_channels = tf.shape(input_dim)[3] 151 | theta = tf.reshape(theta, (-1, 2, 3)) 152 | theta = tf.cast(theta, 'float32') 153 | 154 | # grid of (x_t, y_t, 1), eq (1) in ref [1] 155 | height_f = tf.cast(height, 'float32') 156 | width_f = tf.cast(width, 'float32') 157 | out_height = out_size[0] 158 | out_width = out_size[1] 159 | grid = _meshgrid(out_height, out_width) 160 | grid = tf.expand_dims(grid, 0) 161 | grid = tf.reshape(grid, [-1]) 162 | grid = tf.tile(grid, tf.stack([num_batch])) 163 | grid = tf.reshape(grid, tf.stack([num_batch, 3, -1])) 164 | 165 | # Transform A x (x_t, y_t, 1)^T -> (x_s, y_s) 166 | T_g = tf.matmul(theta, grid) 167 | x_s = tf.slice(T_g, [0, 0, 0], [-1, 1, -1]) 168 | y_s = tf.slice(T_g, [0, 1, 0], [-1, 1, -1]) 169 | x_s_flat = tf.reshape(x_s, [-1]) 170 | y_s_flat = tf.reshape(y_s, [-1]) 171 | 172 | input_transformed = _interpolate( 173 | input_dim, x_s_flat, y_s_flat, 174 | out_size) 175 | 176 | output = tf.reshape( 177 | input_transformed, tf.stack([num_batch, out_height, out_width, num_channels])) 178 | return output 179 | 180 | with tf.variable_scope(name): 181 | output = _transform(theta, U, out_size) 182 | return output 183 | 184 | 185 | def batch_transformer(U, thetas, out_size, name='BatchSpatialTransformer'): 186 | """Batch Spatial Transformer Layer 187 | 188 | Parameters 189 | ---------- 190 | 191 | U : float 192 | tensor of inputs [num_batch,height,width,num_channels] 193 | thetas : float 194 | a set of transformations for each input [num_batch,num_transforms,6] 195 | out_size : int 196 | the size of the output [out_height,out_width] 197 | 198 | Returns: float 199 | Tensor of size [num_batch*num_transforms,out_height,out_width,num_channels] 200 | """ 201 | with tf.variable_scope(name): 202 | num_batch, num_transforms = map(int, thetas.get_shape().as_list()[:2]) 203 | indices = [[i]*num_transforms for i in xrange(num_batch)] 204 | input_repeated = 
tf.gather(U, tf.reshape(indices, [-1]))
205 |         return transformer(input_repeated, thetas, out_size)
206 | 
--------------------------------------------------------------------------------
/transformer/tf_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | 
16 | # %% Borrowed utils from here: https://github.com/pkmital/tensorflow_tutorials/
17 | import tensorflow as tf
18 | import numpy as np
19 | 
20 | def conv2d(x, n_filters,
21 |            k_h=5, k_w=5,
22 |            stride_h=2, stride_w=2,
23 |            stddev=0.02,
24 |            activation=lambda x: x,
25 |            bias=True,
26 |            padding='SAME',
27 |            name="Conv2D"):
28 |     """2D Convolution with options for kernel size, stride, and init deviation.
29 |     Parameters
30 |     ----------
31 |     x : Tensor
32 |         Input tensor to convolve.
33 |     n_filters : int
34 |         Number of filters to apply.
35 |     k_h : int, optional
36 |         Kernel height.
37 |     k_w : int, optional
38 |         Kernel width.
39 |     stride_h : int, optional
40 |         Stride in rows.
41 |     stride_w : int, optional
42 |         Stride in cols.
43 |     stddev : float, optional
44 |         Initialization's standard deviation.
45 |     activation : arguments, optional
46 |         Function which applies a nonlinearity
47 |     padding : str, optional
48 |         'SAME' or 'VALID'
49 |     name : str, optional
50 |         Variable scope to use.
51 |     Returns
52 |     -------
53 |     x : Tensor
54 |         Convolved input.
55 |     """
56 |     with tf.variable_scope(name):
57 |         w = tf.get_variable(
58 |             'w', [k_h, k_w, x.get_shape()[-1], n_filters],
59 |             initializer=tf.truncated_normal_initializer(stddev=stddev))
60 |         conv = tf.nn.conv2d(
61 |             x, w, strides=[1, stride_h, stride_w, 1], padding=padding)
62 |         if bias:
63 |             b = tf.get_variable(
64 |                 'b', [n_filters],
65 |                 initializer=tf.truncated_normal_initializer(stddev=stddev))
66 |             conv = conv + b
67 |         return activation(conv)  # apply the requested nonlinearity
68 | 
69 | def linear(x, n_units, scope=None, stddev=0.02,
70 |            activation=lambda x: x):
71 |     """Fully-connected network.
72 |     Parameters
73 |     ----------
74 |     x : Tensor
75 |         Input tensor to the network.
76 |     n_units : int
77 |         Number of units to connect to.
78 |     scope : str, optional
79 |         Variable scope to use.
80 |     stddev : float, optional
81 |         Initialization's standard deviation.
82 |     activation : arguments, optional
83 |         Function which applies a nonlinearity
84 |     Returns
85 |     -------
86 |     x : Tensor
87 |         Fully-connected output.
88 | """ 89 | shape = x.get_shape().as_list() 90 | 91 | with tf.variable_scope(scope or "Linear"): 92 | matrix = tf.get_variable("Matrix", [shape[1], n_units], tf.float32, 93 | tf.random_normal_initializer(stddev=stddev)) 94 | return activation(tf.matmul(x, matrix)) 95 | 96 | # %% 97 | def weight_variable(shape): 98 | '''Helper function to create a weight variable initialized with 99 | a normal distribution 100 | Parameters 101 | ---------- 102 | shape : list 103 | Size of weight variable 104 | ''' 105 | #initial = tf.random_normal(shape, mean=0.0, stddev=0.01) 106 | initial = tf.zeros(shape) 107 | return tf.Variable(initial) 108 | 109 | # %% 110 | def bias_variable(shape): 111 | '''Helper function to create a bias variable initialized with 112 | a constant value. 113 | Parameters 114 | ---------- 115 | shape : list 116 | Size of weight variable 117 | ''' 118 | initial = tf.random_normal(shape, mean=0.0, stddev=0.01) 119 | return tf.Variable(initial) 120 | 121 | # %% 122 | def dense_to_one_hot(labels, n_classes=2): 123 | """Convert class labels from scalars to one-hot vectors.""" 124 | labels = np.array(labels) 125 | n_labels = labels.shape[0] 126 | index_offset = np.arange(n_labels) * n_classes 127 | labels_one_hot = np.zeros((n_labels, n_classes), dtype=np.float32) 128 | labels_one_hot.flat[index_offset + labels.ravel()] = 1 129 | return labels_one_hot 130 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdaCompNUS/pfnet/861b398c58574cc3e415896f3dd278a76cb2b383/utils/__init__.py -------------------------------------------------------------------------------- /utils/network_layers.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | _L2_SCALE = 1.0 # scale when adding to loss instead of a global scaler 5 | 6 | 7 | # Helper functions for constructing layers 8 | def dense_layer(units, activation=None, use_bias=False, name=None): 9 | 10 | fn = lambda x: tf.layers.dense( 11 | x, units, activation=convert_activation_string(activation), use_bias=use_bias, name=name, 12 | kernel_regularizer=tf.contrib.layers.l2_regularizer(scale=_L2_SCALE)) 13 | return fn 14 | 15 | 16 | def conv2_layer(filters, kernel_size, activation=None, padding='same', strides=(1,1), dilation_rate=(1,1), 17 | data_format='channels_last', use_bias=False, name=None, layer_i=0, name2=None): 18 | if name is None: 19 | name = "l%d_conv%d" % (layer_i, np.max(kernel_size)) 20 | if np.max(dilation_rate) > 1: 21 | name += "_d%d" % np.max(dilation_rate) 22 | if name2 is not None: 23 | name += "_"+name2 24 | fn = lambda x: tf.layers.conv2d( 25 | x, filters, kernel_size, activation=convert_activation_string(activation), 26 | padding=padding, strides=strides, dilation_rate = dilation_rate, data_format=data_format, 27 | kernel_regularizer=tf.contrib.layers.l2_regularizer(scale=_L2_SCALE), 28 | kernel_initializer=tf.variance_scaling_initializer(), 29 | use_bias=use_bias, name=name) 30 | return fn 31 | 32 | 33 | def locallyconn2_layer(filters, kernel_size, activation=None, padding='same', strides=(1,1), dilation_rate=(1,1), 34 | data_format='channels_last', use_bias=False, name=None, layer_i=0, name2=None): 35 | assert dilation_rate == (1, 1) # keras layer doesnt have this input. maybe different name? 
36 | if name is None: 37 | name = "l%d_conv%d"%(layer_i, np.max(kernel_size)) 38 | if np.max(dilation_rate) > 1: 39 | name += "_d%d"%np.max(dilation_rate) 40 | if name2 is not None: 41 | name += "_"+name2 42 | fn = tf.keras.layers.LocallyConnected2D( 43 | filters, kernel_size, activation=convert_activation_string(activation), 44 | padding=padding, strides=strides, data_format=data_format, 45 | kernel_regularizer=tf.contrib.layers.l2_regularizer(scale=_L2_SCALE), 46 | kernel_initializer=tf.variance_scaling_initializer(), 47 | use_bias=use_bias, name=name) 48 | return fn 49 | 50 | 51 | def convert_activation_string(activation): 52 | if isinstance(activation, str): 53 | if activation == 'relu': 54 | activation = tf.nn.relu 55 | elif activation == 'tanh': 56 | activation = tf.nn.tanh 57 | else: 58 | assert False 59 | return activation 60 | -------------------------------------------------------------------------------- /utils/tfrecordfeatures.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def tf_bytes_feature(value): 5 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 6 | 7 | 8 | def tf_bytelist_feature(values): 9 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=values)) 10 | 11 | 12 | def tf_int64_feature(value): 13 | return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) --------------------------------------------------------------------------------