├── .gitignore
├── LICENCE
├── README.md
├── __init__.py
├── arguments.py
├── configs
│   ├── eval-localization.conf
│   ├── eval-tracking.conf
│   └── train.conf
├── data
│   ├── rgb_model_trained.chk.data-00000-of-00001
│   ├── rgb_model_trained.chk.index
│   └── rgb_model_trained.chk.meta
├── display_data.py
├── evaluate.py
├── figs
│   ├── architecture.png
│   └── trajectory.jpg
├── pfnet.py
├── preprocess.py
├── requirements.txt
├── train.py
├── transformer
│   ├── README.md
│   ├── __init__.py
│   ├── cluttered_mnist.py
│   ├── data
│   │   └── README.md
│   ├── example.py
│   ├── spatial_transformer.py
│   └── tf_utils.py
└── utils
    ├── __init__.py
    ├── network_layers.py
    └── tfrecordfeatures.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.o
3 | .idea/
4 |
5 | # Windows image file caches
6 | Thumbs.db
7 | ehthumbs.db
8 |
9 | # Folder config file
10 | Desktop.ini
11 |
12 | # Recycle Bin used on file shares
13 | $RECYCLE.BIN/
14 |
15 | # Windows Installer files
16 | *.cab
17 | *.msi
18 | *.msm
19 | *.msp
20 |
21 | # Windows shortcuts
22 | *.lnk
23 |
24 | # =========================
25 | # Operating System Files
26 | # =========================
27 |
28 | # OSX
29 | # =========================
30 |
31 | .DS_Store
32 | .AppleDouble
33 | .LSOverride
34 |
35 | # Thumbnails
36 | ._*
37 |
38 | # Files that might appear in the root of a volume
39 | .DocumentRevisions-V100
40 | .fseventsd
41 | .Spotlight-V100
42 | .TemporaryItems
43 | .Trashes
44 | .VolumeIcon.icns
45 |
46 | # Directories potentially created on remote AFP share
47 | .AppleDB
48 | .AppleDesktop
49 | Network Trash Folder
50 | Temporary Items
51 | .apdisk
52 |
53 |
--------------------------------------------------------------------------------
/LICENCE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Peter Karkus, AdaCompNUS
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Particle Filter Networks
2 |
3 | Tensorflow implementation of Particle Filter Networks (PF-net)
4 |
5 | Peter Karkus, David Hsu, and Wee Sun Lee: Particle filter networks with application to visual localization.
6 | Conference on Robot Learning (CoRL), 2018. https://arxiv.org/abs/1805.08975
7 |
8 | ### PF-net architecture
9 | 
10 | PF-net encodes both a learned probabilistic system model and the particle filter algorithm in a single neural network.
11 |
12 | ### Localization example
13 | 
14 | Example of successful global localization
15 |
16 |
17 | ### Requirements
18 |
19 | Python 2.7, Tensorflow 1.5.0
20 |
21 | Additional packages can be installed with
22 | ```
23 | pip install -r requirements.txt
24 | ```
25 |
26 | ### Dataset
27 |
28 | Datasets for localization experiments are available at
29 | https://drive.google.com/open?id=1hSDRg7tEf2O1D_NIL8OmVdHLN2KqYFF7
30 |
31 | The folder contains data for training, validation and testing, in
32 | tfrecords format. Download to the ./data/ folder.
33 |
34 | A simple script is available to visualize the data:
35 | ```
36 | python display_data.py ./data/valid.tfrecords
37 | ```
38 |
39 |
40 | ### Training
41 |
42 | The ```./configs/``` folder contains default configuration files for training and evaluation.
43 | Training requires the datasets to be downloaded into the ./data/ folder.
44 | Note that training requires significant resources, and it may take several hours or days depending on the hardware.
45 |
46 | PF-net can be trained with the default configuration using the following command:
47 | ```
48 | python train.py -c ./configs/train.conf --obsmode rgb
49 | ```
50 |
51 | By default, logs will be saved in a new folder under ./logs/.
52 | For help on the input arguments run
53 | ```
54 | python train.py -h
55 | ```
56 |
57 |
58 |
59 | ### Evaluation
60 |
61 | A pre-trained model for RGB input is available under ./data/rgb_model_trained.chk. For evaluating a trained model run
62 | ```
63 | python evaluate.py -c ./configs/eval-tracking.conf --load ./data/rgb_model_trained.chk
64 | ```
65 | for the tracking task, and
66 | ```
67 | python evaluate.py -c ./configs/eval-localization.conf --load ./data/rgb_model_trained.chk
68 | ```
69 | for semi-global localization.
70 | The input arguments are the same as for train.py.
71 | Results will differ somewhat from those in
72 | the paper, because the initial uncertainty on the test trajectories
73 | is generated with a different random seed.
74 |
75 | ### Contact
76 |
77 | Peter Karkus
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
3 |
--------------------------------------------------------------------------------
/arguments.py:
--------------------------------------------------------------------------------
1 | import configargparse
2 | import numpy as np
3 |
4 |
5 | def parse_args(args=None):
6 | """
7 | Parse command line arguments
8 | :param args: command line arguments or None (default)
9 | :return: dictionary of parameters
10 | """
11 |
12 | p = configargparse.ArgParser(default_config_files=[])
13 |
14 | p.add('-c', '--config', required=True, is_config_file=True,
15 |       help='Config file. Use ./configs/train.conf for training.')
16 |
17 | p.add('--trainfiles', nargs='*', help='Data file(s) for training (tfrecord).')
18 | p.add('--testfiles', nargs='*', help='Data file(s) for validation or evaluation (tfrecord).')
19 |
20 | # input configuration
21 | p.add('--obsmode', type=str, default='rgb',
22 | help='Observation input type. Possible values: rgb / depth / rgb-depth / vrf.')
23 | p.add('--mapmode', type=str, default='wall',
24 | help='Map input type with different (semantic) channels. ' +
25 | 'Possible values: wall / wall-door / wall-roomtype / wall-door-roomtype')
26 | p.add('--map_pixel_in_meters', type=float, default=0.02,
27 | help='The width (and height) of a pixel of the map in meters. Defaults to 0.02 for House3D data.')
28 |
29 | p.add('--init_particles_distr', type=str, default='tracking',
30 | help='Distribution of initial particles. Possible values: tracking / one-room / two-rooms / all-rooms')
31 | p.add('--init_particles_std', nargs='*', default=["0.3", "0.523599"], # tracking setting, 30cm, 30deg
32 |       help='Standard deviations for generated initial particles. Only applies to the tracking setting. ' +
33 | 'Expects two float values: translation std (meters), rotation std (radians)')
34 | p.add('--trajlen', type=int, default=24,
35 | help='Length of trajectories. Assumes lower or equal to the trajectory length in the input data.')
36 |
37 | # PF-net configuration
38 | p.add('--num_particles', type=int, default=30, help='Number of particles in PF-net.')
39 | p.add('--resample', type=str, default='false',
40 | help='Resample particles in PF-net. Possible values: true / false.')
41 | p.add('--alpha_resample_ratio', type=float, default=1.0,
42 | help='Trade-off parameter for soft-resampling in PF-net. Only effective if resample == true. '
43 | 'Assumes values 0.0 < alpha <= 1.0. Alpha equal to 1.0 corresponds to hard-resampling.')
44 | p.add('--transition_std', nargs='*', default=["0.0", "0.0"],
45 | help='Standard deviations for transition model. Expects two float values: ' +
46 |           'translation std (meters), rotation std (radians). Defaults to zeros.')
47 |
48 | # training configuration
49 | p.add('--batchsize', type=int, default=24, help='Minibatch size for training. Must be 1 for evaluation.')
50 | p.add('--bptt_steps', type=int, default=4,
51 | help='Number of backpropagation steps for training with backpropagation through time (BPTT). '
52 | 'Assumed to be an integer divisor of the trajectory length (--trajlen).')
53 | p.add('--learningrate', type=float, default=0.0025, help='Initial learning rate for training.')
54 | p.add('--l2scale', type=float, default=4e-6, help='Scaling term for the L2 regularization loss.')
55 | p.add('--epochs', metavar='epochs', type=int, default=1, help='Number of epochs for training.')
56 | p.add('--decaystep', type=int, default=4, help='Decay the learning rate after every N epochs.')
57 | p.add('--decayrate', type=float, help='Rate of decaying the learning rate.')
58 |
59 | p.add('--load', type=str, default="", help='Load a previously trained model from a checkpoint file.')
60 | p.add('--logpath', type=str, default='',
61 | help='Specify path for logs. Makes a new directory under ./log/ if empty (default).')
62 | p.add('--seed', type=int, help='Fix the random seed of numpy and tensorflow if set to zero or larger.')
63 | p.add('--validseed', type=int,
64 | help='Fix the random seed for validation if set to larger than zero. ' +
65 | 'Useful to evaluate with a fixed set of initial particles, which reduces the validation error variance.')
66 | p.add('--gpu', type=int, default=0, help='Select a gpu on a multi-gpu machine. Defaults to zero.')
67 |
68 | params = p.parse_args(args=args)
69 |
70 | # fix numpy seed if needed
71 | if params.seed is not None and params.seed >= 0:
72 | np.random.seed(params.seed)
73 |
74 |     # convert multi-input fields to numpy arrays
75 | params.transition_std = np.array(params.transition_std, np.float32)
76 | params.init_particles_std = np.array(params.init_particles_std, np.float32)
77 |
78 | # convert boolean fields
79 |     if params.resample not in ['false', 'true']:
80 |         raise ValueError(
81 |             "The value of resample must be either 'false' or 'true'")
82 | params.resample = (params.resample == 'true')
83 |
84 | return params
85 |
--------------------------------------------------------------------------------
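For reference, a minimal usage sketch for parse_args. The flags and config path are the repo's own; the specific values below are illustrative, and command-line values override config-file values as usual with configargparse.
```
# Minimal sketch: driving parse_args programmatically. Flags and the config
# path come from arguments.py above; the values here are illustrative only.
from arguments import parse_args

params = parse_args([
    '-c', './configs/train.conf',  # required config file
    '--obsmode', 'rgb',            # command-line values override the config
    '--num_particles', '30',
])
print(params.obsmode, params.num_particles, params.resample)
```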
/configs/eval-localization.conf:
--------------------------------------------------------------------------------
1 | # config file for evaluating semi-global localization, with initial belief uniform over one room
2 |
3 | trainfiles = ./data/train.tfrecords
4 | testfiles = ./data/test.tfrecords
5 | logpath = ./log/
6 |
7 | init_particles_distr = one-room
8 |
9 | trajlen = 100
10 | num_particles = 1000
11 | resample = true
12 | transition_std = [0.04, 0.0872665] # 4cm, 5deg
13 |
14 | seed = 98
15 | validseed = 100
16 |
17 | batchsize = 1
18 | epochs = 1
19 |
--------------------------------------------------------------------------------
/configs/eval-tracking.conf:
--------------------------------------------------------------------------------
1 | # config file for evaluating tracking where initial belief is close to the true state
2 |
3 | trainfiles = ./data/train.tfrecords
4 | testfiles = ./data/test.tfrecords
5 | logpath = ./log/
6 |
7 | init_particles_distr = tracking
8 | init_particles_std = [0.3, 0.523599]
9 | trajlen = 24
10 | num_particles = 300
11 | resample = false
12 | transition_std = [0, 0]
13 |
14 | seed = 98
15 | validseed = 100
16 |
17 | batchsize = 1
18 | epochs = 1
19 |
--------------------------------------------------------------------------------
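The tracking configs concentrate the initial belief around the true pose. Below is an illustrative approximation of such an initialization, assuming Gaussian sampling with the configured standard deviations; the repo's actual sampling lives in House3DTrajData.random_particles (preprocess.py) and is not reproduced here.
```
# Illustrative sketch of a 'tracking' particle initialization: Gaussian noise
# around the true pose with init_particles_std = [0.3, 0.523599]. This is an
# assumption-laden approximation, not the repo's random_particles code.
import numpy as np

def tracking_particles(true_state, num_particles,
                       std_m=0.3, std_rad=0.523599, map_pixel_in_meters=0.02):
    std_px = std_m / map_pixel_in_meters  # 0.3 m is 15 px at 2 cm per pixel
    cov = np.diag([std_px ** 2, std_px ** 2, std_rad ** 2])
    return np.random.multivariate_normal(true_state, cov, size=num_particles)

particles = tracking_particles(np.array([100.0, 200.0, 0.0]), num_particles=300)
print(particles.shape)  # (300, 3): x, y in pixels, theta in radians
```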
/configs/train.conf:
--------------------------------------------------------------------------------
1 | # config file for the allobj dataset
2 |
3 | trainfiles = ./data/train.tfrecords
4 | testfiles = ./data/valid.tfrecords
5 | logpath = ./log/
6 |
7 | init_particles_distr = tracking
8 | init_particles_std = [0.3, 0.523599] #30cm, 30degrees
9 | trajlen = 24
10 | num_particles = 30
11 | resample = false
12 | transition_std = [0, 0]
13 |
14 | validseed = 1
15 |
16 | batchsize = 24
17 | bptt_steps = 4
18 | learningrate = 0.0001
19 | l2scale = 4e-6
20 |
21 | epochs = 12
22 | decaystep = 4
23 | decayrate = 0.5
24 |
25 |
--------------------------------------------------------------------------------
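With decaystep = 4 and decayrate = 0.5, the learning rate is halved every four epochs. train.py itself is not shown above, so the following is a sketch of the implied staircase schedule (consistent with the exponential_decay op in pfnet.py), not the actual training loop.
```
# Staircase schedule implied by train.conf: the learning rate is multiplied by
# decayrate every decaystep epochs. A sketch of the schedule, not train.py code.
learningrate, decaystep, decayrate, epochs = 1e-4, 4, 0.5, 12

for epoch in range(epochs):
    lr = learningrate * decayrate ** (epoch // decaystep)
    print(epoch, lr)  # epochs 0-3: 1e-4, epochs 4-7: 5e-5, epochs 8-11: 2.5e-5
```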
/data/rgb_model_trained.chk.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdaCompNUS/pfnet/861b398c58574cc3e415896f3dd278a76cb2b383/data/rgb_model_trained.chk.data-00000-of-00001
--------------------------------------------------------------------------------
/data/rgb_model_trained.chk.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdaCompNUS/pfnet/861b398c58574cc3e415896f3dd278a76cb2b383/data/rgb_model_trained.chk.index
--------------------------------------------------------------------------------
/data/rgb_model_trained.chk.meta:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdaCompNUS/pfnet/861b398c58574cc3e415896f3dd278a76cb2b383/data/rgb_model_trained.chk.meta
--------------------------------------------------------------------------------
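The three files above form a single TensorFlow 1.x checkpoint, addressed by their shared prefix, as in the README's `--load ./data/rgb_model_trained.chk`. A minimal restore sketch, assuming the PF-net graph and its variables have already been built (as evaluate.py does before restoring):
```
# Sketch: restoring the pre-trained checkpoint by its shared prefix.
# Assumes the PF-net graph (and its variables) have already been constructed.
import tensorflow as tf

saver = tf.train.Saver(var_list=tf.trainable_variables())
with tf.Session() as sess:
    saver.restore(sess, './data/rgb_model_trained.chk')
```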
/display_data.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import matplotlib.pyplot as plt
6 | import sys
7 | import tensorflow as tf
8 | import numpy as np
9 |
10 | # Fix Python 2.x.
11 | try: input = raw_input
12 | except NameError: pass
13 |
14 | from utils.tfrecordfeatures import *
15 | from preprocess import decode_image, raw_images_to_array
16 |
17 | try:
18 | import ipdb as pdb
19 | except Exception:
20 | import pdb
21 |
22 |
23 | def display_data(file):
24 | gen = tf.python_io.tf_record_iterator(file)
25 | for data_i, string_record in enumerate(gen):
26 | result = tf.train.Example.FromString(string_record)
27 | features = result.features.feature
28 |
29 | # maps are np.uint8 arrays. each has a different size.
30 |
31 | # wall map: 0 for free space, 255 for walls
32 | map_wall = decode_image(features['map_wall'].bytes_list.value[0])
33 |
34 | # door map: 0 for free space, 255 for doors
35 | map_door = decode_image(features['map_door'].bytes_list.value[0])
36 |
37 | # roomtype map: binary encoding of 8 possible room categories
38 | # one state may belong to multiple room categories
39 | map_roomtype = decode_image(features['map_roomtype'].bytes_list.value[0])
40 |
41 | # roomid map: pixels correspond to unique room ids.
42 | # for overlapping rooms the higher ids overwrite lower ids
43 | map_roomid = decode_image(features['map_roomid'].bytes_list.value[0])
44 |
45 | # true states
46 | # (x, y, theta). x,y: pixel coordinates; theta: radians
47 | # coordinates index the map as a numpy array: map[x, y]
48 | true_states = features['states'].bytes_list.value[0]
49 | true_states = np.frombuffer(true_states, np.float32).reshape((-1, 3))
50 |
51 | # odometry
52 | # each entry is true_states[i+1]-true_states[i].
53 | # last row is always [0,0,0]
54 | odometry = features['odometry'].bytes_list.value[0]
55 | odometry = np.frombuffer(odometry, np.float32).reshape((-1, 3))
56 |
57 |         # observations are encoded as a list of png images
58 | rgb = raw_images_to_array(list(features['rgb'].bytes_list.value))
59 | depth = raw_images_to_array(list(features['depth'].bytes_list.value))
60 |
61 | print ("True states (first three)")
62 | print (true_states[:3])
63 |
64 | print ("Odometry (first three)")
65 | print (odometry[:3])
66 |
67 | print("Plot map and first observation")
68 |
69 |         # note: when displayed as an image, the map should be transposed
70 | plt.figure()
71 | plt.imshow(map_wall.transpose())
72 |
73 | plt.figure()
74 | plt.imshow(rgb[0])
75 |
76 | plt.show()
77 |
78 | if input("proceed?") != 'y':
79 | break
80 |
81 |
82 | if __name__ == '__main__':
83 | if len(sys.argv) < 2:
84 | print ("Usage: display_data.py xxx.tfrecords")
85 | exit()
86 |
87 | display_data(sys.argv[1])
--------------------------------------------------------------------------------
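The odometry convention documented above (each entry is true_states[i+1] - true_states[i], with a zero last row) can be sanity-checked directly. The arrays below are hypothetical stand-ins for one decoded tfrecord entry.
```
# Sanity check of the odometry convention from display_data.py. The states
# here are hypothetical stand-ins for one decoded tfrecord entry.
import numpy as np

true_states = np.array([[10.0, 20.0, 0.0],
                        [11.0, 20.0, 0.1],
                        [12.0, 21.0, 0.2]], np.float32)
odometry = np.vstack([np.diff(true_states, axis=0),
                      np.zeros((1, 3), np.float32)])  # last row is [0, 0, 0]
assert np.allclose(true_states[:-1] + odometry[:-1], true_states[1:])
```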
/evaluate.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import os, tqdm
6 | import tensorflow as tf
7 | import numpy as np
8 |
9 | import pfnet
10 | from arguments import parse_args
11 | from preprocess import get_dataflow
12 |
13 | try:
14 | import ipdb as pdb
15 | except Exception:
16 | import pdb
17 |
18 |
19 | def run_evaluation(params):
20 | """ Run evaluation with the parsed arguments """
21 |
22 | # overwrite for evaluation
23 | params.batchsize = 1
24 | params.bptt_steps = params.trajlen
25 |
26 | with tf.Graph().as_default():
27 | if params.seed is not None:
28 | tf.set_random_seed(params.seed)
29 |
30 | # test data and network
31 | with tf.variable_scope(tf.get_variable_scope(), reuse=False):
32 | test_data, num_test_samples = get_dataflow(params.testfiles, params, is_training=False)
33 | test_brain = pfnet.PFNet(inputs=test_data[1:], labels=test_data[0], params=params, is_training=False)
34 |
35 | # Add the variable initializer Op.
36 | init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
37 |
38 | # Create a saver for writing training checkpoints.
39 | saver = tf.train.Saver(var_list=tf.trainable_variables(), max_to_keep=3)
40 |
41 | # Create a session for running Ops on the Graph.
42 | os.environ["CUDA_VISIBLE_DEVICES"] = "%d"%int(params.gpu)
43 | sess_config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
44 | sess_config.gpu_options.allow_growth = True
45 |
46 | # training session
47 | with tf.Session(config=sess_config) as sess:
48 |
49 | sess.run(init_op)
50 |
51 | # load model from checkpoint file
52 | if params.load:
53 | print("Loading model from " + params.load)
54 | saver.restore(sess, params.load)
55 |
56 | coord = tf.train.Coordinator()
57 | threads = tf.train.start_queue_runners(coord=coord)
58 |
59 | mse_list = [] # mean squared error
60 | success_list = [] # true for successful localization
61 |
62 | try:
63 | for step_i in tqdm.tqdm(range(num_test_samples)):
64 | all_distance2, _ = sess.run([test_brain.all_distance2_op, test_brain.update_state_op])
65 |
66 | # we have squared differences along the trajectory
67 | mse = np.mean(all_distance2[0])
68 | mse_list.append(mse)
69 |
70 |                     # localization is successful if the error stays below 1m for the last 25% of the trajectory
71 | successful = np.all(all_distance2[0][-params.trajlen//4:] < 1.0 ** 2) # below 1 meter
72 | success_list.append(successful)
73 |
74 | except KeyboardInterrupt:
75 | pass
76 |
77 | except tf.errors.OutOfRangeError:
78 | print("data exhausted")
79 |
80 | finally:
81 | coord.request_stop()
82 | coord.join(threads)
83 |
84 | # report results
85 | mean_rmse = np.mean(np.sqrt(mse_list))
86 | total_rmse = np.sqrt(np.mean(mse_list))
87 | print ("Mean RMSE (average RMSE per trajectory) = %fcm"%(mean_rmse * 100))
88 | print("Overall RMSE (reported value) = %fcm" % (total_rmse * 100))
89 | print("Success rate = %f%%" % (np.mean(np.array(success_list, 'i')) * 100))
90 |
91 |
92 | if __name__ == '__main__':
93 | params = parse_args()
94 |
95 | run_evaluation(params)
--------------------------------------------------------------------------------
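Note that the two numbers printed at the end of run_evaluation are different statistics: the mean RMSE averages per-trajectory RMSEs, while the overall RMSE (the value reported in the paper) pools the squared errors before taking the root. A small numeric illustration:
```
# Mean RMSE vs. overall RMSE, as printed by run_evaluation. Values illustrative.
import numpy as np

mse_list = np.array([0.01, 0.09])        # per-trajectory mean squared errors (m^2)
mean_rmse = np.mean(np.sqrt(mse_list))   # (0.1 + 0.3) / 2 = 0.20 m
total_rmse = np.sqrt(np.mean(mse_list))  # sqrt(0.05) ~= 0.224 m
print(mean_rmse * 100, total_rmse * 100) # printed in cm, as in evaluate.py
```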
/figs/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdaCompNUS/pfnet/861b398c58574cc3e415896f3dd278a76cb2b383/figs/architecture.png
--------------------------------------------------------------------------------
/figs/trajectory.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdaCompNUS/pfnet/861b398c58574cc3e415896f3dd278a76cb2b383/figs/trajectory.jpg
--------------------------------------------------------------------------------
/pfnet.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import tensorflow as tf
6 | import numpy as np
7 | from transformer.spatial_transformer import transformer
8 | from utils.network_layers import conv2_layer, locallyconn2_layer, dense_layer
9 |
10 |
11 | class PFCell(tf.nn.rnn_cell.RNNCell):
12 | """
13 | PF-net for localization implemented with the RNN interface.
14 | Implements the particle set update, the observation and transition models.
15 | Cell inputs: observation, odometry
16 | Cell states: particle_states, particle_weights
17 | Cell outputs: particle_states, particle_weights
18 | """
19 |
20 | def __init__(self, global_maps, params, batch_size, num_particles):
21 | """
22 | :param global_maps: tensorflow op (batch, None, None, ch), global maps input. Since the map is fixed
23 | through the trajectory it can be input to the cell here, instead of part of the cell input.
24 | :param params: parsed arguments
25 | :param batch_size: int, minibatch size
26 | :param num_particles: number of particles
27 | """
28 | super(PFCell, self).__init__()
29 | self.global_maps = global_maps
30 | self.params = params
31 | self.batch_size = batch_size
32 | self.num_particles = num_particles
33 |
34 | self.states_shape = (batch_size, num_particles, 3)
35 | self.weights_shape = (batch_size, num_particles, )
36 |
37 | @property
38 | def state_size(self):
39 | return (tf.TensorShape(self.states_shape[1:]), tf.TensorShape(self.weights_shape[1:]))
40 |
41 | @property
42 | def output_size(self):
43 | return (tf.TensorShape(self.states_shape[1:]), tf.TensorShape(self.weights_shape[1:]))
44 |
45 | def __call__(self, inputs, state, scope=None):
46 | """
47 | Implements a particle update.
48 | :param inputs: observation (batch, 56, 56, ch), odometry (batch, 3).
49 | observation is the sensor reading at time t, odometry is the relative motion from time t to time t+1
50 | :param state: particle states (batch, K, 3), particle weights (batch, K).
51 | weights are assumed to be in log space and they can be unnormalized
52 | :param scope: not used, only kept for the interface. Ops will be created in the current scope.
53 | :return: outputs, state
54 | outputs: particle states and weights after the observation update, but before the transition update
55 | state: updated particle states and weights after both observation and transition updates
56 | """
57 | with tf.variable_scope(tf.get_variable_scope()):
58 | particle_states, particle_weights = state
59 | observation, odometry = inputs
60 |
61 | # observation update
62 | lik = self.observation_model(self.global_maps, particle_states, observation)
63 | particle_weights += lik # unnormalized
64 |
65 | # resample
66 | if self.params.resample:
67 | particle_states, particle_weights = self.resample(
68 | particle_states, particle_weights, alpha=self.params.alpha_resample_ratio)
69 |
70 | # construct output before motion update
71 | outputs = particle_states, particle_weights
72 |
73 | # motion update. this will only affect the particle state input at the next step
74 | particle_states = self.transition_model(particle_states, odometry)
75 |
76 | # construct new state
77 | state = particle_states, particle_weights
78 |
79 | return outputs, state
80 |
81 | def transition_model(self, particle_states, odometry):
82 | """
83 | Implements a stochastic transition model for localization.
84 | :param particle_states: tf op (batch, K, 3), particle states before the update.
85 | :param odometry: tf op (batch, 3), odometry reading, relative motion in the robot coordinate frame
86 | :return: particle_states updated with the odometry and optionally transition noise
87 | """
88 | translation_std = self.params.transition_std[0] / self.params.map_pixel_in_meters # in pixels
89 | rotation_std = self.params.transition_std[1] # in radians
90 |
91 | with tf.name_scope('transition'):
92 | part_x, part_y, part_th = tf.unstack(particle_states, axis=-1, num=3)
93 |
94 | odometry = tf.expand_dims(odometry, axis=1)
95 | odom_x, odom_y, odom_th = tf.unstack(odometry, axis=-1, num=3)
96 |
97 | noise_th = tf.random_normal(part_th.get_shape(), mean=0.0, stddev=1.0) * rotation_std
98 |
99 | # add orientation noise before translation
100 | part_th += noise_th
101 |
102 | cos_th = tf.cos(part_th)
103 | sin_th = tf.sin(part_th)
104 | delta_x = cos_th * odom_x - sin_th * odom_y
105 | delta_y = sin_th * odom_x + cos_th * odom_y
106 | delta_th = odom_th
107 |
108 | delta_x += tf.random_normal(delta_x.get_shape(), mean=0.0, stddev=1.0) * translation_std
109 | delta_y += tf.random_normal(delta_y.get_shape(), mean=0.0, stddev=1.0) * translation_std
110 |
111 | return tf.stack([part_x+delta_x, part_y+delta_y, part_th+delta_th], axis=-1)
112 |
113 | def observation_model(self, global_maps, particle_states, observation):
114 | """
115 | Implements a discriminative observation model for localization.
116 | The model transforms the single global map to local maps for each particle, where a local map is a local
117 | view from the state defined by the particle.
118 | :param global_maps: tf op (batch, None, None, ch), global maps input.
119 | Assumes a scaling 0..2 where 0 is occupied, 2 is free space.
120 | :param particle_states: tf op (batch, K, 3), particle states before the update
121 | :param observation: tf op (batch, 56, 56, ch), image observation from a rgb, depth, or rgbd camera.
122 | :return: tf op (batch, K) particle likelihoods in the log space, unnormalized
123 | """
124 |
125 | # transform global maps to local maps
126 | local_maps = self.transform_maps(global_maps, particle_states, (28, 28))
127 |
128 | # rescale from 0..2 to -1..1. This is not actually necessary.
129 | local_maps = -(local_maps - 1)
130 | # flatten batch and particle dimensions
131 | local_maps = tf.reshape(local_maps,
132 | [self.batch_size * self.num_particles] + local_maps.shape.as_list()[2:])
133 |
134 | # get features from the map
135 | map_features = self.map_features(local_maps)
136 |
137 | # get features from the observation
138 | obs_features = self.observation_features(observation)
139 |
140 | # tile observation features
141 | obs_features = tf.tile(tf.expand_dims(obs_features, axis=1), [1, self.num_particles, 1, 1, 1])
142 | obs_features = tf.reshape(obs_features,
143 | [self.batch_size * self.num_particles] + obs_features.shape.as_list()[2:])
144 |
145 | # sanity check
146 | assert obs_features.shape.as_list()[:-1] == map_features.shape.as_list()[:-1]
147 |
148 | # merge features and process further
149 | joint_features = tf.concat([map_features, obs_features], axis=-1)
150 | joint_features = self.joint_matrix_features(joint_features)
151 |
152 | # reshape to a vector and process further
153 | joint_features = tf.reshape(joint_features, (self.batch_size * self.num_particles, -1))
154 | lik = self.joint_vector_features(joint_features)
155 | lik = tf.reshape(lik, [self.batch_size, self.num_particles])
156 |
157 | return lik
158 |
159 | @staticmethod
160 | def resample(particle_states, particle_weights, alpha):
161 | """
162 | Implements (soft)-resampling of particles.
163 | :param particle_states: tf op (batch, K, 3), particle states
164 | :param particle_weights: tf op (batch, K), unnormalized particle weights in log space
165 | :param alpha: float, trade-off parameter for soft-resampling. alpha == 1 corresponds to standard,
166 | hard-resampling. alpha == 0 corresponds to sampling particles uniformly, ignoring their weights.
167 | :return: particle_states, particle_weights
168 | """
169 | with tf.name_scope('resample'):
170 | assert 0.0 < alpha <= 1.0
171 | batch_size, num_particles = particle_states.get_shape().as_list()[:2]
172 |
173 | # normalize
174 | particle_weights = particle_weights - tf.reduce_logsumexp(particle_weights, axis=-1, keep_dims=True)
175 |
176 | uniform_weights = tf.constant(-np.log(num_particles), shape=(batch_size, num_particles), dtype=tf.float32)
177 |
178 | # build sampling distribution, q(s), and update particle weights
179 | if alpha < 1.0:
180 | # soft resampling
181 | q_weights = tf.stack([particle_weights + np.log(alpha), uniform_weights + np.log(1.0-alpha)], axis=-1)
182 | q_weights = tf.reduce_logsumexp(q_weights, axis=-1, keep_dims=False)
183 | q_weights = q_weights - tf.reduce_logsumexp(q_weights, axis=-1, keep_dims=True) # normalized
184 |
185 | particle_weights = particle_weights - q_weights # this is unnormalized
186 | else:
187 | # hard resampling. this will produce zero gradients
188 | q_weights = particle_weights
189 | particle_weights = uniform_weights
190 |
191 | # sample particle indices according to q(s)
192 | indices = tf.cast(tf.multinomial(q_weights, num_particles), tf.int32) # shape: (batch_size, num_particles)
193 |
194 | # index into particles
195 | helper = tf.range(0, batch_size*num_particles, delta=num_particles, dtype=tf.int32) # (batch, )
196 | indices = indices + tf.expand_dims(helper, axis=1)
197 |
198 | particle_states = tf.reshape(particle_states, (batch_size * num_particles, 3))
199 | particle_states = tf.gather(particle_states, indices=indices, axis=0) # (batch_size, num_particles, 3)
200 |
201 | particle_weights = tf.reshape(particle_weights, (batch_size * num_particles, ))
202 | particle_weights = tf.gather(particle_weights, indices=indices, axis=0) # (batch_size, num_particles,)
203 |
204 | return particle_states, particle_weights
205 |
206 | @staticmethod
207 | def transform_maps(global_maps, particle_states, local_map_size):
208 | """
209 | Implements global to local map transformation
210 | :param global_maps: tf op (batch, None, None, ch) global map input
211 | :param particle_states: tf op (batch, K, 3) particle states that define local views for the transformation
212 |         :param local_map_size: tuple, (height, width), size of the output local maps
213 |         :return: tf op (batch, K, local_map_size[0], local_map_size[1], ch). local maps, each showing a
214 |         different transformation of the global map corresponding to one particle state
215 | """
216 | batch_size, num_particles = particle_states.get_shape().as_list()[:2]
217 | total_samples = batch_size * num_particles
218 | flat_states = tf.reshape(particle_states, [total_samples, 3])
219 |
220 | # define some helper variables
221 | input_shape = tf.shape(global_maps)
222 | global_height = tf.cast(input_shape[1], tf.float32)
223 | global_width = tf.cast(input_shape[2], tf.float32)
224 | height_inverse = 1.0 / global_height
225 | width_inverse = 1.0 / global_width
226 | # at tf1.6 matmul still does not support broadcasting, so we need full vectors
227 | zero = tf.constant(0, dtype=tf.float32, shape=(total_samples, ))
228 | one = tf.constant(1, dtype=tf.float32, shape=(total_samples, ))
229 |
230 | # the global map will be down-scaled by some factor
231 | window_scaler = 8.0
232 |
233 | # normalize orientations and precompute cos and sin functions
234 | theta = -flat_states[:, 2] - 0.5 * np.pi
235 | costheta = tf.cos(theta)
236 | sintheta = tf.sin(theta)
237 |
238 | # construct an affine transformation matrix step-by-step.
239 | # 1, translate the global map s.t. the center is at the particle state
240 | translate_x = (flat_states[:, 0] * width_inverse * 2.0) - 1.0
241 | translate_y = (flat_states[:, 1] * height_inverse * 2.0) - 1.0
242 |
243 | transm1 = tf.stack((one, zero, translate_x, zero, one, translate_y, zero, zero, one), axis=1)
244 | transm1 = tf.reshape(transm1, (total_samples, 3, 3))
245 |
246 | # 2, rotate map s.t. the orientation matches that of the particles
247 | rotm = tf.stack((costheta, sintheta, zero, -sintheta, costheta, zero, zero, zero, one), axis=1)
248 | rotm = tf.reshape(rotm, (total_samples, 3, 3))
249 |
250 | # 3, scale down the map
251 | scale_x = tf.fill((total_samples, ), float(local_map_size[1] * window_scaler) * width_inverse)
252 | scale_y = tf.fill((total_samples, ), float(local_map_size[0] * window_scaler) * height_inverse)
253 |
254 | scalem = tf.stack((scale_x, zero, zero, zero, scale_y, zero, zero, zero, one), axis=1)
255 | scalem = tf.reshape(scalem, (total_samples, 3, 3))
256 |
257 | # 4, translate the local map s.t. the particle defines the bottom mid-point instead of the center
258 | translate_y2 = tf.constant(-1.0, dtype=tf.float32, shape=(total_samples, ))
259 |
260 | transm2 = tf.stack((one, zero, zero, zero, one, translate_y2, zero, zero, one), axis=1)
261 | transm2 = tf.reshape(transm2, (total_samples, 3, 3))
262 |
263 | # chain the transformation matrices into a single one: translate + rotate + scale + translate
264 | transform_m = tf.matmul(tf.matmul(tf.matmul(transm1, rotm), scalem), transm2)
265 |
266 | # reshape to the format expected by the spatial transform network
267 | transform_m = tf.reshape(transform_m[:, :2], (batch_size, num_particles, 6))
268 |
269 | # do the image transformation using the spatial transform network
270 |         # iterate over particles to avoid tiling large global maps
271 | output_list = []
272 | for i in range(num_particles):
273 | output_list.append(transformer(global_maps, transform_m[:,i], local_map_size))
274 |
275 | local_maps = tf.stack(output_list, axis=1)
276 |
277 | # set shape information that is lost in the spatial transform network
278 | local_maps = tf.reshape(local_maps, (batch_size, num_particles, local_map_size[0], local_map_size[1],
279 | global_maps.shape.as_list()[-1]))
280 |
281 | return local_maps
282 |
283 | @staticmethod
284 | def map_features(local_maps):
285 | assert local_maps.get_shape().as_list()[1:3] == [28, 28]
286 | data_format = 'channels_last'
287 |
288 | with tf.variable_scope("map"):
289 | x = local_maps
290 | layer_i = 1
291 | convs = [
292 | conv2_layer(
293 | 24, (3, 3), activation=None, padding='same', data_format=data_format,
294 | use_bias=True, layer_i=layer_i)(x),
295 | conv2_layer(
296 | 16, (5, 5), activation=None, padding='same', data_format=data_format,
297 | use_bias=True, layer_i=layer_i)(x),
298 | conv2_layer(
299 | 8, (7, 7), activation=None, padding='same', data_format=data_format,
300 | use_bias=True, layer_i=layer_i)(x),
301 | conv2_layer(
302 | 8, (7, 7), activation=None, padding='same', data_format=data_format,
303 | dilation_rate=(2, 2), use_bias=True, layer_i=layer_i)(x),
304 | conv2_layer(
305 | 8, (7, 7), activation=None, padding='same', data_format=data_format,
306 | dilation_rate=(3, 3), use_bias=True, layer_i=layer_i)(x),
307 | ]
308 | x = tf.concat(convs, axis=-1)
309 | x = tf.contrib.layers.layer_norm(x, activation_fn=tf.nn.relu)
310 | assert x.get_shape().as_list()[1:4] == [28, 28, 64]
311 | # (28x28x64)
312 |
313 | x = tf.layers.max_pooling2d(x, pool_size=(3, 3), strides=(2, 2), padding="same")
314 |
315 | layer_i+=1
316 | convs = [
317 | conv2_layer(
318 | 4, (3, 3), activation=None, padding='same', data_format=data_format,
319 | use_bias=True, layer_i=layer_i)(x),
320 | conv2_layer(
321 | 4, (5, 5), activation=None, padding='same', data_format=data_format,
322 | use_bias=True, layer_i=layer_i)(x),
323 | ]
324 | x = tf.concat(convs, axis=-1)
325 | x = tf.contrib.layers.layer_norm(x, activation_fn=tf.nn.relu)
326 |
327 | return x # (14x14x8)
328 |
329 | @staticmethod
330 | def observation_features(observation):
331 | data_format = 'channels_last'
332 | with tf.variable_scope("observation"):
333 | x = observation
334 | layer_i = 1
335 | convs = [
336 | conv2_layer(
337 | 128, (3, 3), activation=None, padding='same', data_format=data_format,
338 | use_bias=True, layer_i=layer_i)(x),
339 | conv2_layer(
340 | 128, (5, 5), activation=None, padding='same', data_format=data_format,
341 | use_bias=True, layer_i=layer_i)(x),
342 | conv2_layer(
343 | 64, (5, 5), activation=None, padding='same', data_format=data_format,
344 | dilation_rate=(2, 2), use_bias=True, layer_i=layer_i)(x),
345 | conv2_layer(
346 | 64, (5, 5), activation=None, padding='same', data_format=data_format,
347 | dilation_rate=(4, 4), use_bias=True, layer_i=layer_i)(x),
348 | ]
349 | x = tf.concat(convs, axis=-1)
350 | x = tf.layers.max_pooling2d(x, pool_size=(3, 3), strides=(2, 2), padding="same")
351 | x = tf.contrib.layers.layer_norm(x, activation_fn=tf.nn.relu)
352 |
353 | assert x.get_shape().as_list()[1:4] == [28, 28, 384]
354 |
355 | layer_i += 1
356 | x = conv2_layer(
357 | 16, (3, 3), activation=None, padding='same', data_format=data_format,
358 | use_bias=True, layer_i=layer_i)(x)
359 |
360 | x = tf.layers.max_pooling2d(x, pool_size=(3, 3), strides=(2, 2), padding="same")
361 | x = tf.contrib.layers.layer_norm(x, activation_fn=tf.nn.relu)
362 | assert x.get_shape().as_list()[1:4] == [14, 14, 16]
363 |
364 | return x # (14,14,16)
365 |
366 | @staticmethod
367 | def joint_matrix_features(joint_matrix):
368 | assert joint_matrix.get_shape().as_list()[1:4] == [14, 14, 24]
369 | data_format = 'channels_last'
370 |
371 | with tf.variable_scope("joint"):
372 | x = joint_matrix
373 | layer_i = 1
374 |
375 | # pad manually to match different kernel sizes
376 | x_pad1 = tf.pad(x, paddings=tf.constant([[0, 0], [1, 1,], [1, 1], [0, 0]]))
377 | convs = [
378 | locallyconn2_layer(
379 | 8, (3, 3), activation='relu', padding='valid', data_format=data_format,
380 | use_bias=True, layer_i=layer_i)(x),
381 | locallyconn2_layer(
382 | 8, (5, 5), activation='relu', padding='valid', data_format=data_format,
383 | use_bias=True, layer_i=layer_i)(x_pad1),
384 | ]
385 | x = tf.concat(convs, axis=-1)
386 |
387 | x = tf.layers.max_pooling2d(x, pool_size=(3, 3), strides=(2, 2), padding="valid")
388 | assert x.get_shape().as_list()[1:4] == [5, 5, 16]
389 |
390 | return x # (5, 5, 16)
391 |
392 | @staticmethod
393 | def joint_vector_features(joint_vector):
394 | with tf.variable_scope("joint"):
395 | x = joint_vector
396 | x = dense_layer(1, activation=None, use_bias=True, name='fc1')(x)
397 | return x
398 |
399 |
400 | class PFNet(object):
401 | """ Implements PF-net. Unrolls the PF-net RNN cell and defines losses and training ops."""
402 | def __init__(self, inputs, labels, params, is_training):
403 | """
404 | Calling this will create all tf ops for PF-net.
405 | :param inputs: list of tf ops, the inputs to PF-net. Assumed to have the following elements:
406 | global_maps, init_particle_states, observations, odometries, is_first_step
407 | :param labels: tf op, labels for training. Assumed to be the true states along the trajectory.
408 | :param params: parsed arguments
409 | :param is_training: bool, true for training.
410 | """
411 | self.params = params
412 |
413 | # define ops to be accessed conveniently from outside
414 | self.outputs = []
415 | self.hidden_states = []
416 |
417 | self.train_loss_op = None
418 | self.valid_loss_op = None
419 | self.all_distance2_op = None
420 |
421 | self.global_step_op = None
422 | self.learning_rate_op = None
423 | self.train_op = None
424 | self.update_state_op = tf.constant(0)
425 |
426 | # build the network. this will generate the ops defined above
427 | self.build(inputs, labels, is_training)
428 |
429 | def build(self, inputs, labels, is_training):
430 | """
431 | Unroll the PF-net RNN cell and create loss ops and optionally, training ops
432 | """
433 | self.outputs = self.build_rnn(*inputs)
434 |
435 | self.build_loss_op(self.outputs[0], self.outputs[1], true_states=labels)
436 |
437 | if is_training:
438 | self.build_train_op()
439 |
440 | def save_state(self, sess):
441 | """
442 | Returns a list, the hidden state of PF-net, i.e. the particle states and particle weights.
443 | The output can be used with load_state to restore the current hidden state.
444 | """
445 | return sess.run(self.hidden_states)
446 |
447 | def load_state(self, sess, saved_state):
448 | """
449 | Overwrite the hidden state of PF-net to that of saved_state.
450 | """
451 | return sess.run(self.hidden_states,
452 | feed_dict={self.hidden_states[i]: saved_state[i] for i in range(len(self.hidden_states))})
453 |
454 | def build_loss_op(self, particle_states, particle_weights, true_states):
455 | """
456 | Create tf ops for various losses. This should be called only once with is_training=True.
457 | """
458 | assert particle_weights.get_shape().ndims == 3
459 |
460 | lin_weights = tf.nn.softmax(particle_weights, dim=-1)
461 |
462 | true_coords = true_states[:, :, :2]
463 | mean_coords = tf.reduce_sum(tf.multiply(particle_states[:,:,:,:2], lin_weights[:,:,:,None]), axis=2)
464 | coord_diffs = mean_coords - true_coords
465 |
466 | # convert from pixel coordinates to meters
467 | coord_diffs *= self.params.map_pixel_in_meters
468 |
469 | # coordinate loss component: (x-x')^2 + (y-y')^2
470 | loss_coords = tf.reduce_sum(tf.square(coord_diffs), axis=2)
471 |
472 | true_orients = true_states[:, :, 2]
473 | orient_diffs = particle_states[:, :, :, 2] - true_orients[:,:,None]
474 | # normalize between -pi..+pi
475 | orient_diffs = tf.mod(orient_diffs + np.pi, 2*np.pi) - np.pi
476 |         # orientation loss component: (sum_k[(theta_k-theta')*weight_k])^2
477 | loss_orient = tf.square(tf.reduce_sum(orient_diffs * lin_weights, axis=2))
478 |
479 | # combine translational and orientation losses
480 | loss_combined = loss_coords + 0.36 * loss_orient
481 | loss_pred = tf.reduce_mean(loss_combined, name='prediction_loss')
482 |
483 | # add L2 regularization loss
484 | loss_reg = tf.multiply(tf.losses.get_regularization_loss(), self.params.l2scale, name='l2')
485 | loss_total = tf.add_n([loss_pred, loss_reg], name="training_loss")
486 |
487 | self.all_distance2_op = loss_coords
488 | self.valid_loss_op = loss_pred
489 | self.train_loss_op = loss_total
490 |
491 | return loss_total
492 |
493 | def build_train_op(self):
494 | """ Create optimizer and train op. This should be called only once. """
495 |
496 | # make sure this is only called once
497 | assert self.train_op is None and self.global_step_op is None and self.learning_rate_op is None
498 |
499 | # global step and learning rate
500 | with tf.device("/cpu:0"):
501 | self.global_step_op = tf.get_variable(
502 | initializer=tf.constant_initializer(0.0), shape=(), trainable=False, name='global_step',)
503 | self.learning_rate_op = tf.train.exponential_decay(
504 | self.params.learningrate, self.global_step_op, decay_steps=1, decay_rate=self.params.decayrate,
505 | staircase=True, name="learning_rate")
506 |
507 | # create gradient descent optimizer with the given learning rate.
508 | optimizer = tf.train.RMSPropOptimizer(self.learning_rate_op, decay=0.9)
509 |
510 | with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
511 | self.train_op = optimizer.minimize(self.train_loss_op, global_step=None, var_list=tf.trainable_variables())
512 |
513 | return self.train_op
514 |
515 | def build_rnn(self, global_maps, init_particle_states, observations, odometries, is_first_step):
516 | """
517 | Unroll the PF-net RNN cell through time. Input arguments are the inputs to PF-net. The time dependent
518 | fields are expected to be broken into fixed-length segments defined by params.bptt_steps
519 | """
520 | batch_size, trajlen = observations.shape.as_list()[:2]
521 | num_particles = init_particle_states.shape.as_list()[1]
522 | global_map_ch = global_maps.shape.as_list()[-1]
523 |
524 | init_particle_weights = tf.constant(np.log(1.0/float(num_particles)),
525 | shape=(batch_size, num_particles), dtype=tf.float32)
526 |
527 | # create hidden state variable
528 | assert len(self.hidden_states) == 0 # no hidden state should be set before
529 | self.hidden_states = [
530 | tf.get_variable("particle_states", shape=init_particle_states.get_shape(),
531 | dtype=init_particle_states.dtype, initializer=tf.constant_initializer(0), trainable=False),
532 | tf.get_variable("particle_weights", shape=init_particle_weights.get_shape(),
533 | dtype=init_particle_weights.dtype, initializer=tf.constant_initializer(0), trainable=False),
534 | ]
535 |
536 | # choose state for the current trajectory segment
537 | state = tf.cond(is_first_step,
538 | true_fn=lambda: (init_particle_states, init_particle_weights),
539 | false_fn=lambda: tuple(self.hidden_states))
540 |
541 | with tf.variable_scope("rnn"):
542 | # hack to create variables on GPU
543 | dummy_cell_func = PFCell(
544 | global_maps=tf.zeros((1, 1, 1, global_map_ch), dtype=global_maps.dtype),
545 | params=self.params, batch_size=1, num_particles=1)
546 |
547 | dummy_cell_func(
548 | (tf.zeros([1]+observations.get_shape().as_list()[2:], dtype=observations.dtype), # observation
549 | tf.zeros([1, 3], dtype=odometries.dtype)), # odometry
550 | (tf.zeros([1, 1, 3], dtype=init_particle_states.dtype), # particle_states
551 | tf.zeros([1, 1], dtype=init_particle_weights.dtype))) # particle_weights
552 |
553 | # variables are now created. set reuse
554 | tf.get_variable_scope().reuse_variables()
555 |
556 | # unroll real steps using the variables already created
557 | cell_func = PFCell(global_maps=global_maps, params=self.params, batch_size=batch_size,
558 | num_particles=num_particles)
559 |
560 | outputs, state = tf.nn.dynamic_rnn(cell=cell_func,
561 | inputs=(observations, odometries),
562 | initial_state=state,
563 | swap_memory=True,
564 | time_major=False,
565 | parallel_iterations=1,
566 | scope=tf.get_variable_scope())
567 |
568 | particle_states, particle_weights = outputs
569 |
570 | # define an op to update the hidden state, i.e. the particle states and particle weights.
571 | # this should be evaluated after every input
572 | with tf.control_dependencies([particle_states, particle_weights]):
573 | self.update_state_op = tf.group(
574 | *(self.hidden_states[i].assign(state[i]) for i in range(len(self.hidden_states))))
575 |
576 | return particle_states, particle_weights
577 |
578 |
--------------------------------------------------------------------------------
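PFCell.resample implements soft-resampling in log space: particles are drawn from the mixture q = alpha * w + (1 - alpha) / K and re-weighted by w / q, which keeps the weights differentiable for alpha < 1. A numpy sketch of the same computation, done in linear space for readability (an illustration, not the TF code above):
```
# Numpy sketch of soft-resampling as in PFCell.resample. Illustration only.
import numpy as np

def soft_resample(states, log_weights, alpha):
    K = log_weights.shape[0]
    w = np.exp(log_weights - np.logaddexp.reduce(log_weights))  # normalized weights
    q = alpha * w + (1.0 - alpha) / K       # sampling distribution q(s)
    idx = np.random.choice(K, size=K, p=q)  # sample particle indices from q
    new_log_weights = np.log(w[idx]) - np.log(q[idx])  # importance correction w/q
    return states[idx], new_log_weights

states = np.random.rand(5, 3)
log_w = np.log(np.array([0.5, 0.2, 0.1, 0.1, 0.1]))
new_states, new_log_w = soft_resample(states, log_w, alpha=0.5)
```
With alpha == 1, q equals w and the correction w/q equals 1, i.e. uniform weights, matching the zero-gradient hard-resampling branch in the TF code.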
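PFNet.build_loss_op combines a squared error on the weighted-mean particle coordinates (converted to meters) with 0.36 times the squared weighted orientation error, wrapping angle differences to [-pi, pi). A numpy sketch of the formula for a single batch element and timestep (not the TF code above):
```
# Numpy sketch of the pose loss from PFNet.build_loss_op for one batch element
# at one timestep. Illustration of the formula, not the TF code.
import numpy as np

def pose_loss(particle_states, lin_weights, true_state, map_pixel_in_meters=0.02):
    # weighted mean of particle coordinates; error converted from pixels to meters
    mean_xy = np.sum(particle_states[:, :2] * lin_weights[:, None], axis=0)
    coord_diff = (mean_xy - true_state[:2]) * map_pixel_in_meters
    # orientation differences wrapped to [-pi, pi), then weighted and squared
    orient_diff = np.mod(particle_states[:, 2] - true_state[2] + np.pi,
                         2.0 * np.pi) - np.pi
    loss_orient = np.sum(orient_diff * lin_weights) ** 2
    return np.sum(coord_diff ** 2) + 0.36 * loss_orient
```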
/preprocess.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import os
6 | import tensorflow as tf
7 | import numpy as np
8 | import cv2
9 |
10 | from tensorpack import dataflow
11 | from tensorpack.dataflow.base import RNGDataFlow, ProxyDataFlow
12 |
13 | try:
14 | import ipdb as pdb
15 | except Exception:
16 | import pdb
17 |
18 |
19 | def decode_image(img_str, resize=None):
20 | """
21 | Decode image from tfrecord data
22 | :param img_str: image encoded as a png in a string
23 |     :param resize: tuple with two elements that defines the new size of the image. optional
24 | :return: image as a numpy array
25 | """
26 | nparr = np.fromstring(img_str, np.uint8)
27 | img_str = cv2.imdecode(nparr, -1)
28 | if resize is not None:
29 | img_str = cv2.resize(img_str, resize)
30 | return img_str
31 |
32 |
33 | def raw_images_to_array(images):
34 | """
35 | Decode and normalize multiple images from tfrecord data
36 | :param images: list of images encoded as a png in a string
37 | :return: a numpy array of size (N, 56, 56, channels), normalized for training
38 | """
39 | image_list = []
40 | for image_str in images:
41 | image = decode_image(image_str, (56, 56))
42 | image = scale_observation(np.atleast_3d(image.astype(np.float32)))
43 | image_list.append(image)
44 |
45 | return np.stack(image_list, axis=0)
46 |
47 |
48 | def scale_observation(x):
49 | """
50 | Normalizes observation input, either an rgb image or a depth image
51 | :param x: observation input as numpy array, either an rgb image or a depth image
52 | :return: numpy array, a normalized observation
53 | """
54 | if x.ndim == 2 or x.shape[2] == 1: # depth
55 | return x * (2.0 / 100.0) - 1.0
56 | else: # rgb
57 | return x * (2.0/255.0) - 1.0
58 |
59 |
60 | def bounding_box(img):
61 | """
62 | Bounding box of non-zeros in an array (inclusive). Used with 2D maps
63 | :param img: numpy array
64 | :return: inclusive bounding box indices: top_row, bottom_row, leftmost_column, rightmost_column
65 | """
66 |     # find rows and columns that contain any non-zero entry
67 | rows = np.any(img, axis=1)
68 | cols = np.any(img, axis=0)
69 | rmin, rmax = np.where(rows)[0][[0, -1]]
70 | cmin, cmax = np.where(cols)[0][[0, -1]]
71 |
72 | return rmin, rmax, cmin, cmax
73 |
74 |
75 | class BatchDataWithPad(dataflow.BatchData):
76 | """
77 | Stacks datapoints into batches. Selected elements can be padded to the same size in each batch.
78 | """
79 |
80 | def __init__(self, ds, batch_size, remainder=False, use_list=False, padded_indices=()):
81 | """
82 | :param ds: input dataflow. Same as BatchData
83 | :param batch_size: mini batch size. Same as BatchData
84 | :param remainder: if data is not enough to form a full batch, it makes a smaller batch when true.
85 | Same as BatchData.
86 | :param use_list: if True, components will contain a list of datapoints instead of creating a new numpy array.
87 | Same as BatchData.
88 |         :param padded_indices: list of field indices for which all elements will be padded with zeros to match
89 | the largest in the batch. Each batch may produce a different size datapoint.
90 | """
91 | super(BatchDataWithPad, self).__init__(ds, batch_size, remainder, use_list)
92 | self.padded_indices = padded_indices
93 |
94 | def get_data(self):
95 | """
96 | Yields: Batched data by stacking each component on an extra 0th dimension.
97 | """
98 | holder = []
99 | for data in self.ds.get_data():
100 | holder.append(data)
101 | if len(holder) == self.batch_size:
102 | yield BatchDataWithPad._aggregate_batch(holder, self.use_list, self.padded_indices)
103 | del holder[:]
104 | if self.remainder and len(holder) > 0:
105 | yield BatchDataWithPad._aggregate_batch(holder, self.use_list, self.padded_indices)
106 |
107 | @staticmethod
108 | def _aggregate_batch(data_holder, use_list=False, padded_indices=()):
109 | """
110 | Re-implement the parent function with the option to pad selected fields to the largest in the batch.
111 | """
112 | assert not use_list # cannot match shape if they must be treated as lists
113 | size = len(data_holder[0])
114 | result = []
115 | for k in range(size):
116 | dt = data_holder[0][k]
117 | if type(dt) in [int, bool]:
118 | tp = 'int32'
119 | elif type(dt) == float:
120 | tp = 'float32'
121 | else:
122 | try:
123 | tp = dt.dtype
124 | except AttributeError:
125 | raise TypeError("Unsupported type to batch: {}".format(type(dt)))
126 | try:
127 | if k in padded_indices:
128 | # pad this field
129 | shapes = np.array([x[k].shape for x in data_holder], 'i') # assumes ndim are the same for all
130 | assert shapes.shape[1] == 3 # only supports 3D arrays for now, e.g. images (height, width, ch)
131 | matching_shape = shapes.max(axis=0).tolist()
132 | new_data = np.zeros([shapes.shape[0]] + matching_shape, dtype=tp)
133 | for i in range(len(data_holder)):
134 | shape = data_holder[i][k].shape
135 | new_data[i, :shape[0], :shape[1], :shape[2]] = data_holder[i][k]
136 | result.append(new_data)
137 | else:
138 | # no need to pad this field, simply create batch
139 | result.append(np.asarray([x[k] for x in data_holder], dtype=tp))
140 | except Exception as e:
141 | # exception handling. same as in parent class
142 | pdb.set_trace()
143 | dataflow.logger.exception("Cannot batch data. Perhaps they are of inconsistent shape?")
144 | if isinstance(dt, np.ndarray):
145 | s = dataflow.pprint.pformat([x[k].shape for x in data_holder])
146 | dataflow.logger.error("Shape of all arrays to be batched: " + s)
147 | try:
148 | # open an ipython shell if possible
149 | import IPython as IP; IP.embed() # noqa
150 | except ImportError:
151 | pass
152 | return result
153 |
154 |
155 | class BreakForBPTT(ProxyDataFlow):
156 | """
157 | Breaks long trajectories into multiple smaller segments for training with BPTT.
158 | Adds an extra field for indicating the first segment of a trajectory.
159 | """
160 | def __init__(self, ds, timed_indices, trajlen, bptt_steps):
161 | """
162 | :param ds: input dataflow
163 | :param timed_indices: field indices for which the second dimension corresponds to timestep along the trajectory
164 | :param trajlen: full length of trajectories
165 | :param bptt_steps: segment length, number of backprop steps for BPTT. Must be an integer divisor of trajlen
166 | """
167 | super(BreakForBPTT, self).__init__(ds)
168 |         self.timed_indices = timed_indices
169 | self.bptt_steps = bptt_steps
170 |
171 | assert trajlen % bptt_steps == 0
172 | self.num_segments = trajlen // bptt_steps
173 |
174 | def size(self):
175 | return self.ds.size() * self.num_segments
176 |
177 | def get_data(self):
178 | """
179 |         Yields multiple datapoints per input datapoint, corresponding to segments of the trajectory.
180 | Adds an extra field for indicating the first segment of a trajectory.
181 | """
182 |
183 | for data in self.ds.get_data():
184 | for split_i in range(self.num_segments):
185 | new_data = []
186 | for i in range(len(data)):
187 |                 if i in self.timed_indices:
188 | new_data.append(data[i][:, split_i*self.bptt_steps:(split_i+1)*self.bptt_steps])
189 | else:
190 | new_data.append(data[i])
191 |
192 | new_data.append((split_i == 0))
193 |
194 | yield new_data
195 |
196 |
197 | class House3DTrajData(RNGDataFlow):
198 | """
199 | Process tfrecords data of House3D trajectories. Produces a dataflow with the following fields:
200 | true state, global map, initial particles, observations, odometries
201 | """
202 |
203 | def __init__(self, files, mapmode, obsmode, trajlen, num_particles, init_particles_distr, init_particles_cov,
204 | seed=None):
205 | """
206 | :param files: list of data file names. assumed to be tfrecords files
207 | :param mapmode: string, map type. Possible values: wall / wall-door / wall-roomtype / wall-door-roomtype
208 | :param obsmode: string, observation type. Possible values: rgb / depth / rgb-depth. Vrf is not yet supported
209 | :param trajlen: int, length of trajectories
210 | :param num_particles: int, number of particles
211 | :param init_particles_distr: string, type of initial particle distribution.
212 | Possible values: tracking / one-room. Does not support two-rooms and all-rooms yet.
213 |         :param init_particles_cov: numpy array of shape (3,3), covariance matrix for the initial particles. Ignored
214 | when init_particles_distr != 'tracking'.
215 | :param seed: int or None. Random seed will be fixed if not None.
216 | """
217 | self.files = files
218 | self.mapmode = mapmode
219 | self.obsmode = obsmode
220 | self.trajlen = trajlen
221 | self.num_particles = num_particles
222 | self.init_particles_distr = init_particles_distr
223 | self.init_particles_cov = init_particles_cov
224 | self.seed = seed
225 |
226 | # count total number of entries
227 | count = 0
228 | for f in self.files:
229 | if not os.path.isfile(f):
230 | raise ValueError('Failed to find file: ' + f)
231 | record_iterator = tf.python_io.tf_record_iterator(f)
232 | for _ in record_iterator:
233 | count += 1
234 | self.count = count
235 |
236 | def size(self):
237 | return self.count
238 |
239 | def reset_state(self):
240 | """ Reset state. Fix numpy random seed if needed."""
241 | super(House3DTrajData, self).reset_state()
242 | if self.seed is not None:
243 | np.random.seed(1)
244 | else:
245 | np.random.seed(self.rng.randint(0, 99999999))
246 |
247 | def get_data(self):
248 | """
249 | Yields datapoints, all numpy arrays, with the following fields.
250 |
251 | true states: (trajlen, 3). Second dimension corresponds to x, y, theta coordinates.
252 |
253 | global map: (n, m, ch). shape is different for each map. number of channels depend on the mapmode setting
254 |
255 | initial particles: (num_particles, 3)
256 |
257 | observations: (trajlen, 56, 56, ch) number of channels depend on the obsmode setting
258 |
259 | odometries: (trajlen, 3) relative motion in the robot coordinate frame
260 | """
261 | for file in self.files:
262 | gen = tf.python_io.tf_record_iterator(file)
263 | for data_i, string_record in enumerate(gen):
264 | result = tf.train.Example.FromString(string_record)
265 | features = result.features.feature
266 |
267 | # process maps
268 | map_wall = self.process_wall_map(features['map_wall'].bytes_list.value[0])
269 | global_map_list = [map_wall]
270 | if 'door' in self.mapmode:
271 | map_door = self.process_door_map(features['map_door'].bytes_list.value[0])
272 | global_map_list.append(map_door)
273 | if 'roomtype' in self.mapmode:
274 | map_roomtype = self.process_roomtype_map(features['map_roomtype'].bytes_list.value[0])
275 | global_map_list.append(map_roomtype)
276 | if self.init_particles_distr == 'tracking':
277 | map_roomid = None
278 | else:
279 | map_roomid = self.process_roomid_map(features['map_roomid'].bytes_list.value[0])
280 |
281 |                 # input global map is a concatenation of semantic channels
282 | global_map = np.concatenate(global_map_list, axis=-1)
283 |
284 | # rescale to 0..2 range. this way zero padding will produce the equivalent of obstacles
285 | global_map = global_map.astype(np.float32) * (2.0 / 255.0)
286 |
287 | # process true states
288 | true_states = features['states'].bytes_list.value[0]
289 | true_states = np.frombuffer(true_states, np.float32).reshape((-1, 3))
290 |
291 | # trajectory may be longer than what we use for training
292 | data_trajlen = true_states.shape[0]
293 | assert data_trajlen >= self.trajlen
294 | true_states = true_states[:self.trajlen]
295 |
296 | # process odometry
297 | odometry = features['odometry'].bytes_list.value[0]
298 | odometry = np.frombuffer(odometry, np.float32).reshape((-1, 3))
299 |
300 | # process observations
301 | assert self.obsmode in ['rgb', 'depth', 'rgb-depth'] #TODO support for lidar
302 | if 'rgb' in self.obsmode:
303 | rgb = raw_images_to_array(list(features['rgb'].bytes_list.value)[:self.trajlen])
304 | observation = rgb
305 | if 'depth' in self.obsmode:
306 | depth = raw_images_to_array(list(features['depth'].bytes_list.value)[:self.trajlen])
307 | observation = depth
308 | if self.obsmode == 'rgb-depth':
309 | observation = np.concatenate((rgb, depth), axis=-1)
310 |
311 | # generate particle states
312 | init_particles = self.random_particles(true_states[0], self.init_particles_distr,
313 | self.init_particles_cov, self.num_particles,
314 | roomidmap=map_roomid,
315 | seed=self.get_sample_seed(self.seed, data_i), )
316 |
317 | yield (true_states, global_map, init_particles, observation, odometry)
318 |
319 | def process_wall_map(self, wallmap_feature):
320 | floormap = np.atleast_3d(decode_image(wallmap_feature))
321 | # transpose and invert
322 | floormap = 255 - np.transpose(floormap, axes=[1, 0, 2])
323 | return floormap
324 |
325 | def process_door_map(self, doormap_feature):
326 | return self.process_wall_map(doormap_feature)
327 |
328 | def process_roomtype_map(self, roomtypemap_feature):
329 |         binary_map = np.frombuffer(roomtypemap_feature, np.uint8)
330 |         binary_map = cv2.imdecode(binary_map, 2)  # flag 2 = cv2.IMREAD_ANYDEPTH, keeps the 16-bit encoding
331 | assert binary_map.dtype == np.uint16 and binary_map.ndim == 2
332 |         # room types are binary encoded into bits 0 .. 8
333 |
334 | room_map = np.zeros((binary_map.shape[0], binary_map.shape[1], 9), dtype=np.uint8)
335 | for i in range(9):
336 | room_map[:,:,i] = np.array((np.bitwise_and(binary_map, (1 << i)) > 0), dtype=np.uint8)
337 | room_map *= 255
338 |
339 | # transpose and invert
340 | room_map = np.transpose(room_map, axes=[1, 0, 2])
341 | return room_map
342 |
343 | def process_roomid_map(self, roomidmap_feature):
344 | # this is not transposed, unlike other maps
345 | roomidmap = np.atleast_3d(decode_image(roomidmap_feature))
346 | return roomidmap
347 |
348 | @staticmethod
349 | def random_particles(state, distr, particles_cov, num_particles, roomidmap, seed=None):
350 | """
351 | Generate a random set of particles
352 | :param state: true state, numpy array of x,y,theta coordinates
353 | :param distr: string, type of distribution. Possible values: tracking / one-room.
354 | For 'tracking' the distribution is a Gaussian centered near the true state.
355 | For 'one-room' the distribution is uniform over states in the room defined by the true state.
356 | :param particles_cov: numpy array of shape (3,3), defines the covariance matrix if distr == 'tracking'
357 | :param num_particles: number of particles
358 | :param roomidmap: numpy array, map of room ids. Values define a unique room id for each pixel of the map.
359 |         :param seed: int or None. If not None, the random seed is fixed for sampling the center of the
360 |             particle distribution; the global numpy random state is restored afterwards.
361 | :return: numpy array of particles (num_particles, 3)
362 | """
363 | assert distr in ["tracking", "one-room"] #TODO add support for two-room and all-room
364 |
365 | particles = np.zeros((num_particles, 3), np.float32)
366 |
367 | if distr == "tracking":
368 | # fix seed
369 | if seed is not None:
370 | random_state = np.random.get_state()
371 | np.random.seed(seed)
372 |
373 | # sample offset from the Gaussian
374 | center = np.random.multivariate_normal(mean=state, cov=particles_cov)
375 |
376 | # restore random seed
377 | if seed is not None:
378 | np.random.set_state(random_state)
379 |
380 | # sample particles from the Gaussian, centered around the offset
381 | particles = np.random.multivariate_normal(mean=center, cov=particles_cov, size=num_particles)
382 |
383 | elif distr == "one-room":
384 | # mask the room the initial state is in
385 | masked_map = (roomidmap == roomidmap[int(np.rint(state[0])), int(np.rint(state[1]))])
386 |
387 | # get bounding box for more efficient sampling
388 | rmin, rmax, cmin, cmax = bounding_box(masked_map)
389 |
390 | # rejection sampling inside bounding box
391 | sample_i = 0
392 | while sample_i < num_particles:
393 | particle = np.random.uniform(low=(rmin, cmin, 0.0), high=(rmax, cmax, 2.0*np.pi), size=(3, ),)
394 | # reject if mask is zero
395 | if not masked_map[int(np.rint(particle[0])), int(np.rint(particle[1]))]:
396 | continue
397 | particles[sample_i] = particle
398 | sample_i += 1
399 |         else:
400 |             raise ValueError("Unknown initial particle distribution: %s" % distr)
401 |
402 | return particles
403 |
404 | @staticmethod
405 | def get_sample_seed(seed, data_i):
406 | """
407 | Defines a random seed for each datapoint in a deterministic manner.
408 | :param seed: int or None, defining a random seed
409 | :param data_i: int, the index of the current data point
410 |         :return: None if seed is None or 0, otherwise an int that is a fixed function of both seed and data_i.
411 | """
412 | return (None if (seed is None or seed == 0) else ((data_i + 1) * 113 + seed))
413 |
414 |
415 | def get_dataflow(files, params, is_training):
416 | """
417 | Build a tensorflow Dataset from appropriate tfrecords files.
418 |     :param files: list of file paths to the tfrecords data files
419 | :param params: parsed arguments
420 | :param is_training: bool, true for training.
421 | :return: (nextdata, num_samples).
422 | nextdata: list of tensorflow ops that produce the next input with the following elements:
423 | true_states, global_map, init_particles, observations, odometries, is_first_step.
424 | See House3DTrajData.get_data for definitions.
425 | num_samples: number of samples that make an epoch
426 | """
427 |
428 | mapmode = params.mapmode
429 | obsmode = params.obsmode
430 | batchsize = params.batchsize
431 | num_particles = params.num_particles
432 | trajlen = params.trajlen
433 | bptt_steps = params.bptt_steps
434 |
435 | # build initial covariance matrix of particles, in pixels and radians
436 | particle_std = params.init_particles_std.copy()
437 | particle_std[0] = particle_std[0] / params.map_pixel_in_meters # convert meters to pixels
438 | particle_std2 = np.square(particle_std) # variance
439 | init_particles_cov = np.diag(particle_std2[(0, 0, 1),])
440 |
441 | df = House3DTrajData(files, mapmode, obsmode, trajlen, num_particles,
442 | params.init_particles_distr, init_particles_cov,
443 | seed=(params.seed if params.seed is not None and params.seed > 0
444 | else (params.validseed if not is_training else None)))
445 | # data: true_states, global_map, init_particles, observation, odometry
446 |
447 | # make it a multiple of batchsize
448 | df = dataflow.FixedSizeData(df, size=(df.size() // batchsize) * batchsize, keep_state=False)
449 |
450 | # shuffle
451 | if is_training:
452 | df = dataflow.LocallyShuffleData(df, 100 * batchsize)
453 |
454 | # repeat data for the number of epochs
455 | df = dataflow.RepeatedData(df, params.epochs)
456 |
457 | # batch
458 | df = BatchDataWithPad(df, batchsize, padded_indices=(1,))
459 |
460 |     # break trajectory into multiple segments for BPTT training. Augment df with is_first_step indicator
461 | df = BreakForBPTT(df, timed_indices=(0, 3, 4), trajlen=trajlen, bptt_steps=bptt_steps)
462 | # data: true_states, global_map, init_particles, observation, odometry, is_first_step
463 |
464 | num_samples = df.size() // params.epochs
465 |
466 | df.reset_state()
467 |
468 | # # test dataflow
469 | # df = dataflow.TestDataSpeed(dataflow.PrintData(df), 100)
470 | # df.start()
471 |
472 | obs_ch = {'rgb': 3, 'depth': 1, 'rgb-depth': 4}
473 | map_ch = {'wall': 1, 'wall-door': 2, 'wall-roomtype': 10, 'wall-door-roomtype': 11}
474 | types = [tf.float32, tf.float32, tf.float32, tf.float32, tf.float32, tf.bool]
475 | sizes = [(batchsize, bptt_steps, 3),
476 | (batchsize, None, None, map_ch[mapmode]),
477 | (batchsize, num_particles, 3),
478 | (batchsize, bptt_steps, 56, 56, obs_ch[obsmode]),
479 | (batchsize, bptt_steps, 3),
480 | (), ]
481 |
482 | # turn it into a tf dataset
483 | def tuplegen():
484 | for dp in df.get_data():
485 | yield tuple(dp)
486 |
487 | dataset = tf.data.Dataset.from_generator(tuplegen, tuple(types), tuple(sizes))
488 | iterator = dataset.make_one_shot_iterator()
489 | nextdata = iterator.get_next()
490 |
491 | return nextdata, num_samples
492 |
--------------------------------------------------------------------------------
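A minimal sketch (not part of the repo) of two preprocessing steps above: the initial-particle covariance assembled in `get_dataflow`, and the trajectory slicing performed by `BreakForBPTT`. The std and map-resolution values below are made-up placeholders.

```python
import numpy as np

# get_dataflow: std comes as (meters, radians); x/y are converted to pixels and
# the translation variance is repeated for both x and y on the diagonal.
particle_std = np.array([0.3, 0.523599])       # hypothetical (meters, radians)
map_pixel_in_meters = 0.02                     # hypothetical map resolution
particle_std[0] /= map_pixel_in_meters         # meters -> pixels
init_particles_cov = np.diag(np.square(particle_std)[(0, 0, 1),])  # shape (3, 3)

# BreakForBPTT: time-indexed fields are cut into trajlen//bptt_steps segments
# along the time axis; an is_first_step flag is appended to each segment.
trajlen, bptt_steps = 24, 4
states = np.zeros((8, trajlen, 3))             # toy (batch, time, 3) field
for split_i in range(trajlen // bptt_steps):
    segment = states[:, split_i * bptt_steps:(split_i + 1) * bptt_steps]
    is_first_step = (split_i == 0)
    assert segment.shape == (8, bptt_steps, 3)
```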
/requirements.txt:
--------------------------------------------------------------------------------
1 | ConfigArgParse==0.12.0
2 | numpy==1.14.2
3 | opencv_python==3.3.0.10
4 | matplotlib==2.1.0
5 | tensorpack==0.8.5
6 | tqdm==4.19.8
7 | six==1.11.0
8 | ipdb==0.10.3
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import os, tqdm
6 | import tensorflow as tf
7 | import numpy as np
8 | from datetime import datetime
9 |
10 | import pfnet
11 | from arguments import parse_args
12 | from preprocess import get_dataflow
13 |
14 | try:
15 | import ipdb as pdb
16 | except Exception:
17 | import pdb
18 |
19 |
20 | def validation(sess, brain, num_samples, params):
21 | """
22 | Run validation
23 | :param sess: tensorflow session
24 | :param brain: network object that provides loss and update ops, and functions to save and restore the hidden state.
25 | :param num_samples: int, number of samples in the validation set.
26 | :param params: parsed arguments
27 | :return: validation loss, averaged over the validation set
28 | """
29 |
30 | fix_seed = (params.validseed is not None and params.validseed >= 0)
31 | if fix_seed:
32 | np_random_state = np.random.get_state()
33 | np.random.seed(params.validseed)
34 | tf.set_random_seed(params.validseed)
35 |
36 | saved_state = brain.save_state(sess)
37 |
38 | total_loss = 0.0
39 | try:
40 | for eval_i in tqdm.tqdm(range(num_samples), desc="Validation"):
41 | loss, _ = sess.run([brain.valid_loss_op, brain.update_state_op])
42 | total_loss += loss
43 |
44 | print ("Validation loss = %f"%(total_loss/num_samples))
45 |
46 | except tf.errors.OutOfRangeError:
47 | print ("No more samples for evaluation. This should not happen")
48 | raise
49 |
50 | brain.load_state(sess, saved_state)
51 |
52 | # restore seed
53 | if fix_seed:
54 | np.random.set_state(np_random_state)
55 | tf.set_random_seed(np.random.randint(999999)) # cannot save tf seed, so generate random one from numpy
56 |
57 |     return total_loss / num_samples
58 |
59 |
60 | def run_training(params):
61 | """ Run training with the parsed arguments """
62 |
63 | with tf.Graph().as_default():
64 | if params.seed is not None:
65 | tf.set_random_seed(params.seed)
66 |
67 | # training data and network
68 | with tf.variable_scope(tf.get_variable_scope(), reuse=False):
69 | train_data, num_train_samples = get_dataflow(params.trainfiles, params, is_training=True)
70 | train_brain = pfnet.PFNet(inputs=train_data[1:], labels=train_data[0], params=params, is_training=True)
71 |
72 | # test data and network
73 | with tf.variable_scope(tf.get_variable_scope(), reuse=True):
74 | test_data, num_test_samples = get_dataflow(params.testfiles, params, is_training=False)
75 | test_brain = pfnet.PFNet(inputs=test_data[1:], labels=test_data[0], params=params, is_training=False)
76 |
77 | # Add the variable initializer Op.
78 | init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
79 |
80 | # Create a saver for writing training checkpoints.
81 | saver = tf.train.Saver(var_list=tf.trainable_variables(), max_to_keep=3)
82 |
83 | # Create a session for running Ops on the Graph.
84 | os.environ["CUDA_VISIBLE_DEVICES"] = "%d"%int(params.gpu)
85 | sess_config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
86 | sess_config.gpu_options.allow_growth = True
87 |
88 | # training session
89 | with tf.Session(config=sess_config) as sess:
90 | sess.run(init_op)
91 |
92 | # load model from checkpoint file
93 | if params.load:
94 | print("Loading model from " + params.load)
95 | saver.restore(sess, params.load)
96 |
97 | coord = tf.train.Coordinator()
98 | threads = tf.train.start_queue_runners(coord=coord)
99 |
100 | try:
101 | decay_step = 0
102 |
103 | # repeat for a fixed number of epochs
104 | for epoch_i in range(params.epochs):
105 | epoch_loss = 0.0
106 | periodic_loss = 0.0
107 |
108 | # run training over all samples in an epoch
109 | for step_i in tqdm.tqdm(range(num_train_samples)):
110 | _, loss, _ = sess.run([train_brain.train_op, train_brain.train_loss_op,
111 | train_brain.update_state_op])
112 | periodic_loss += loss
113 | epoch_loss += loss
114 |
115 | # print accumulated loss after every few hundred steps
116 | if step_i > 0 and (step_i % 500) == 0:
117 | tqdm.tqdm.write("Epoch %d, step %d. Training loss = %f" % (epoch_i + 1, step_i, periodic_loss / 500.0))
118 | periodic_loss = 0.0
119 |
120 |                     # print the average loss over the epoch
121 | tqdm.tqdm.write("Epoch %d done. Average training loss = %f" % (epoch_i + 1, epoch_loss / num_train_samples))
122 |
123 | # save model, validate and decrease learning rate after each epoch
124 | saver.save(sess, os.path.join(params.logpath, 'model.chk'), global_step=epoch_i + 1)
125 |
126 | # run validation
127 | validation(sess, test_brain, num_samples=num_test_samples, params=params)
128 |
129 | # decay learning rate
130 |                     if (epoch_i + 1) % params.decaystep == 0:
131 | decay_step += 1
132 | current_global_step = sess.run(tf.assign(train_brain.global_step_op, decay_step))
133 | current_learning_rate = sess.run(train_brain.learning_rate_op)
134 | tqdm.tqdm.write("Decreased learning rate to %f." % (current_learning_rate))
135 |
136 | except KeyboardInterrupt:
137 | pass
138 |
139 | except tf.errors.OutOfRangeError:
140 | print("data exhausted")
141 |
142 | finally:
143 |                 saver.save(sess, os.path.join(params.logpath, 'final.chk')) # don't pass a global step
144 | coord.request_stop()
145 |
146 | coord.join(threads)
147 |
148 | print ("Training done. Model is saved to %s"%(params.logpath))
149 |
150 |
151 | if __name__ == '__main__':
152 | params = parse_args()
153 |
154 | params.logpath = os.path.join(params.logpath, "log-" + datetime.now().strftime('%m%d-%H-%M-%S'))
155 | os.mkdir(params.logpath)
156 |
157 | run_training(params)
--------------------------------------------------------------------------------
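The learning-rate decay above assigns an epoch-based counter to `global_step_op`; assuming `pfnet.PFNet` builds `learning_rate_op` as a staircase exponential decay driven by that counter (pfnet.py is not shown here, so this is an assumption), the schedule behaves like the following sketch with made-up values:

```python
# Sketch of the schedule produced by the decay loop in run_training; the base
# rate and decay factor are hypothetical, and the exponential-decay form is an
# assumption about how pfnet.PFNet defines learning_rate_op.
base_lr, decay_rate = 1e-4, 0.9
decaystep, epochs = 4, 12                 # e.g. decay every 4 epochs
decay_step = 0
for epoch_i in range(epochs):
    if (epoch_i + 1) % decaystep == 0:
        decay_step += 1                   # mirrors sess.run(tf.assign(...))
    lr = base_lr * decay_rate ** decay_step
```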
/transformer/README.md:
--------------------------------------------------------------------------------
1 | # Spatial Transformer Network
2 |
3 | The Spatial Transformer Network [1] allows the spatial manipulation of data within the network.
4 |
5 |
6 | 
7 |
8 |
9 | ### API
10 |
11 | A Spatial Transformer Network implemented in Tensorflow 1.0 and based on [2].
12 |
13 | #### How to use
14 |
15 |
16 | 
17 |
18 |
19 | ```python
20 | transformer(U, theta, out_size)
21 | ```
22 |
23 | #### Parameters
24 |
25 | U : float
26 | The output of a convolutional net should have the
27 | shape [num_batch, height, width, num_channels].
28 | theta: float
29 | The output of the
30 | localisation network should be [num_batch, 6].
31 | out_size: tuple of two ints
32 | The size of the output of the network
33 |
34 |
35 | #### Notes
36 | To initialize the network to the identity transform, initialize ``theta`` to:
37 |
38 | ```python
39 | identity = np.array([[1., 0., 0.],
40 | [0., 1., 0.]])
41 | identity = identity.flatten()
42 | theta = tf.Variable(initial_value=identity)
43 | ```
44 |
45 | #### Experiments
46 |
47 |
48 | 
49 |
50 |
51 | We used cluttered MNIST. The left column shows the input images; the right column shows the image regions attended to by the STN.
52 |
53 | All experiments were run in Tensorflow 0.7.
54 |
55 | ### References
56 |
57 | [1] Jaderberg, Max, et al. "Spatial Transformer Networks." arXiv preprint arXiv:1506.02025 (2015)
58 |
59 | [2] https://github.com/skaae/transformer_network/blob/master/transformerlayer.py
60 |
--------------------------------------------------------------------------------
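For reference, a minimal usage sketch of the `transformer` API documented in the README above, with a made-up 40x40 single-channel batch and an identity transform:

```python
import numpy as np
import tensorflow as tf
from spatial_transformer import transformer

U = tf.placeholder(tf.float32, [None, 40, 40, 1])        # batch of 40x40 images
identity = np.array([[1., 0., 0.],
                     [0., 1., 0.]], dtype=np.float32).flatten()
theta = tf.tile(identity[None, :], [tf.shape(U)[0], 1])  # one theta per image
h_trans = transformer(U, theta, out_size=(40, 40))       # identity: output ~= input
```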
/transformer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdaCompNUS/pfnet/861b398c58574cc3e415896f3dd278a76cb2b383/transformer/__init__.py
--------------------------------------------------------------------------------
/transformer/cluttered_mnist.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # =============================================================================
15 | import tensorflow as tf
16 | from spatial_transformer import transformer
17 | import numpy as np
18 | from tf_utils import weight_variable, bias_variable, dense_to_one_hot
19 |
20 | # %% Load data
21 | mnist_cluttered = np.load('./data/mnist_sequence1_sample_5distortions5x5.npz')
22 |
23 | X_train = mnist_cluttered['X_train']
24 | y_train = mnist_cluttered['y_train']
25 | X_valid = mnist_cluttered['X_valid']
26 | y_valid = mnist_cluttered['y_valid']
27 | X_test = mnist_cluttered['X_test']
28 | y_test = mnist_cluttered['y_test']
29 |
30 | # % turn from dense to one hot representation
31 | Y_train = dense_to_one_hot(y_train, n_classes=10)
32 | Y_valid = dense_to_one_hot(y_valid, n_classes=10)
33 | Y_test = dense_to_one_hot(y_test, n_classes=10)
34 |
35 | # %% Graph representation of our network
36 |
37 | # %% Placeholders for 40x40 resolution
38 | x = tf.placeholder(tf.float32, [None, 1600])
39 | y = tf.placeholder(tf.float32, [None, 10])
40 |
41 | # %% Since x is currently [batch, height*width], we need to reshape to a
42 | # 4-D tensor to use it in a convolutional graph. If one component of
43 | # `shape` is the special value -1, the size of that dimension is
44 | # computed so that the total size remains constant. Since we haven't
45 | # defined the batch dimension's shape yet, we use -1 to denote this
46 | # dimension should not change size.
47 | x_tensor = tf.reshape(x, [-1, 40, 40, 1])
48 |
49 | # %% We'll setup the two-layer localisation network to figure out the
50 | # %% parameters for an affine transformation of the input
51 | # %% Create variables for fully connected layer
52 | W_fc_loc1 = weight_variable([1600, 20])
53 | b_fc_loc1 = bias_variable([20])
54 |
55 | W_fc_loc2 = weight_variable([20, 6])
56 | # Use identity transformation as starting point
57 | initial = np.array([[1., 0, 0], [0, 1., 0]])
58 | initial = initial.astype('float32')
59 | initial = initial.flatten()
60 | b_fc_loc2 = tf.Variable(initial_value=initial, name='b_fc_loc2')
61 |
62 | # %% Define the two layer localisation network
63 | h_fc_loc1 = tf.nn.tanh(tf.matmul(x, W_fc_loc1) + b_fc_loc1)
64 | # %% We can add dropout for regularizing and to reduce overfitting like so:
65 | keep_prob = tf.placeholder(tf.float32)
66 | h_fc_loc1_drop = tf.nn.dropout(h_fc_loc1, keep_prob)
67 | # %% Second layer
68 | h_fc_loc2 = tf.nn.tanh(tf.matmul(h_fc_loc1_drop, W_fc_loc2) + b_fc_loc2)
69 |
70 | # %% We'll create a spatial transformer module to identify discriminative
71 | # %% patches
72 | out_size = (40, 40)
73 | h_trans = transformer(x_tensor, h_fc_loc2, out_size)
74 |
75 | # %% We'll setup the first convolutional layer
76 | # Weight matrix is [height x width x input_channels x output_channels]
77 | filter_size = 3
78 | n_filters_1 = 16
79 | W_conv1 = weight_variable([filter_size, filter_size, 1, n_filters_1])
80 |
81 | # %% Bias is [output_channels]
82 | b_conv1 = bias_variable([n_filters_1])
83 |
84 | # %% Now we can build a graph which does the first layer of convolution:
85 | # we define our stride as batch x height x width x channels
86 | # instead of pooling, we use strides of 2 and more layers
87 | # with smaller filters.
88 |
89 | h_conv1 = tf.nn.relu(
90 | tf.nn.conv2d(input=h_trans,
91 | filter=W_conv1,
92 | strides=[1, 2, 2, 1],
93 | padding='SAME') +
94 | b_conv1)
95 |
96 | # %% And just like the first layer, add additional layers to create
97 | # a deep net
98 | n_filters_2 = 16
99 | W_conv2 = weight_variable([filter_size, filter_size, n_filters_1, n_filters_2])
100 | b_conv2 = bias_variable([n_filters_2])
101 | h_conv2 = tf.nn.relu(
102 | tf.nn.conv2d(input=h_conv1,
103 | filter=W_conv2,
104 | strides=[1, 2, 2, 1],
105 | padding='SAME') +
106 | b_conv2)
107 |
108 | # %% We'll now reshape so we can connect to a fully-connected layer:
109 | h_conv2_flat = tf.reshape(h_conv2, [-1, 10 * 10 * n_filters_2])
110 |
111 | # %% Create a fully-connected layer:
112 | n_fc = 1024
113 | W_fc1 = weight_variable([10 * 10 * n_filters_2, n_fc])
114 | b_fc1 = bias_variable([n_fc])
115 | h_fc1 = tf.nn.relu(tf.matmul(h_conv2_flat, W_fc1) + b_fc1)
116 |
117 | h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)
118 |
119 | # %% And finally our softmax layer:
120 | W_fc2 = weight_variable([n_fc, 10])
121 | b_fc2 = bias_variable([10])
122 | y_logits = tf.matmul(h_fc1_drop, W_fc2) + b_fc2
123 |
124 | # %% Define loss/eval/training functions
125 | cross_entropy = tf.reduce_mean(
126 | tf.nn.softmax_cross_entropy_with_logits(logits=y_logits, labels=y))
127 | opt = tf.train.AdamOptimizer()
128 | optimizer = opt.minimize(cross_entropy)
129 | grads = opt.compute_gradients(cross_entropy, [b_fc_loc2])
130 |
131 | # %% Monitor accuracy
132 | correct_prediction = tf.equal(tf.argmax(y_logits, 1), tf.argmax(y, 1))
133 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, 'float'))
134 |
135 | # %% We now create a new session to actually perform the initialization the
136 | # variables:
137 | sess = tf.Session()
138 | sess.run(tf.global_variables_initializer())
139 |
140 |
141 | # %% We'll now train in minibatches and report accuracy, loss:
142 | iter_per_epoch = 100
143 | n_epochs = 500
144 | train_size = 10000
145 |
146 | indices = np.linspace(0, train_size - 1, iter_per_epoch)
147 | indices = indices.astype('int')
148 |
149 | for epoch_i in range(n_epochs):
150 | for iter_i in range(iter_per_epoch - 1):
151 | batch_xs = X_train[indices[iter_i]:indices[iter_i+1]]
152 | batch_ys = Y_train[indices[iter_i]:indices[iter_i+1]]
153 |
154 | if iter_i % 10 == 0:
155 | loss = sess.run(cross_entropy,
156 | feed_dict={
157 | x: batch_xs,
158 | y: batch_ys,
159 | keep_prob: 1.0
160 | })
161 | print('Iteration: ' + str(iter_i) + ' Loss: ' + str(loss))
162 |
163 | sess.run(optimizer, feed_dict={
164 | x: batch_xs, y: batch_ys, keep_prob: 0.8})
165 |
166 | print('Accuracy (%d): ' % epoch_i + str(sess.run(accuracy,
167 | feed_dict={
168 | x: X_valid,
169 | y: Y_valid,
170 | keep_prob: 1.0
171 | })))
172 | # theta = sess.run(h_fc_loc2, feed_dict={
173 | # x: batch_xs, keep_prob: 1.0})
174 | # print(theta[0])
175 |
--------------------------------------------------------------------------------
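A quick shape check for the reshape in the script above: two stride-2 SAME convolutions reduce the 40x40 input to 10x10, which is why `h_conv2` is flattened to `10 * 10 * n_filters_2`.

```python
# SAME padding with stride s gives ceil(h / s) output rows/cols.
h = 40
for stride in (2, 2):                 # h_conv1, then h_conv2
    h = (h + stride - 1) // stride
assert h == 10
```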
/transformer/data/README.md:
--------------------------------------------------------------------------------
1 | ### How to get the data
2 |
3 | #### Cluttered MNIST
4 |
5 | The cluttered MNIST dataset can be found here [1] or can be generated via [2].
6 |
7 | Settings used for `cluttered_mnist.py` :
8 |
9 | ```python
10 |
11 | ORG_SHP = [28, 28]
12 | OUT_SHP = [40, 40]
13 | NUM_DISTORTIONS = 8
14 | dist_size = (5, 5)
15 |
16 | ```
17 |
18 | [1] https://github.com/daviddao/spatial-transformer-tensorflow
19 |
20 | [2] https://github.com/skaae/recurrent-spatial-transformer-code/blob/master/MNIST_SEQUENCE/create_mnist_sequence.py
--------------------------------------------------------------------------------
/transformer/example.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | from scipy import ndimage
16 | import tensorflow as tf
17 | from spatial_transformer import transformer
18 | import numpy as np
19 | import matplotlib.pyplot as plt
20 |
21 | # %% Create a batch of three images (1600 x 1200)
22 | # %% Image retrieved from:
23 | # %% https://raw.githubusercontent.com/skaae/transformer_network/master/cat.jpg
24 | im = ndimage.imread('cat.jpg')
25 | im = im / 255.
26 | im = im.reshape(1, 1200, 1600, 3)
27 | im = im.astype('float32')
28 |
29 | # %% Let the output size of the transformer be half the image size.
30 | out_size = (600, 800)
31 |
32 | # %% Simulate batch
33 | batch = np.append(im, im, axis=0)
34 | batch = np.append(batch, im, axis=0)
35 | num_batch = 3
36 |
37 | x = tf.placeholder(tf.float32, [None, 1200, 1600, 3])
38 |
39 |
40 | # %% Create localisation network and convolutional layer
41 | with tf.variable_scope('spatial_transformer_0'):
42 |
43 | # %% Create a fully-connected layer with 6 output nodes
44 | n_fc = 6
45 | W_fc1 = tf.Variable(tf.zeros([1200 * 1600 * 3, n_fc]), name='W_fc1')
46 |
47 | # %% Zoom into the image
48 | initial = np.array([[0.5, 0, 0], [0, 0.5, 0]])
49 | initial = initial.astype('float32')
50 | initial = initial.flatten()
51 |
52 | b_fc1 = tf.Variable(initial_value=initial, name='b_fc1')
53 | h_fc1 = tf.matmul(tf.zeros([num_batch, 1200 * 1600 * 3]), W_fc1) + b_fc1
54 | h_trans = transformer(x, h_fc1, out_size)
55 |
56 | # %% Run session
57 | sess = tf.Session()
58 | sess.run(tf.global_variables_initializer())
59 | y = sess.run(h_trans, feed_dict={x: batch})
60 |
61 | # plt.imshow(y[0])
62 |
--------------------------------------------------------------------------------
/transformer/spatial_transformer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | from six.moves import xrange
17 | import tensorflow as tf
18 |
19 |
20 | def transformer(U, theta, out_size, name='SpatialTransformer', **kwargs):
21 | """Spatial Transformer Layer
22 |
23 | Implements a spatial transformer layer as described in [1]_.
24 | Based on [2]_ and edited by David Dao for Tensorflow.
25 |
26 | Parameters
27 | ----------
28 | U : float
29 | The output of a convolutional net should have the
30 | shape [num_batch, height, width, num_channels].
31 | theta: float
32 | The output of the
33 | localisation network should be [num_batch, 6].
34 | out_size: tuple of two ints
35 | The size of the output of the network (height, width)
36 |
37 | References
38 | ----------
39 | .. [1] Spatial Transformer Networks
40 | Max Jaderberg, Karen Simonyan, Andrew Zisserman, Koray Kavukcuoglu
41 | Submitted on 5 Jun 2015
42 | .. [2] https://github.com/skaae/transformer_network/blob/master/transformerlayer.py
43 |
44 | Notes
45 | -----
46 | To initialize the network to the identity transform init
47 | ``theta`` to :
48 | identity = np.array([[1., 0., 0.],
49 | [0., 1., 0.]])
50 | identity = identity.flatten()
51 | theta = tf.Variable(initial_value=identity)
52 |
53 | """
54 |
55 | def _repeat(x, n_repeats):
56 | with tf.variable_scope('_repeat'):
57 | rep = tf.transpose(
58 | tf.expand_dims(tf.ones(shape=tf.stack([n_repeats, ])), 1), [1, 0])
59 | rep = tf.cast(rep, 'int32')
60 | x = tf.matmul(tf.reshape(x, (-1, 1)), rep)
61 | return tf.reshape(x, [-1])
62 |
63 | def _interpolate(im, x, y, out_size):
64 | with tf.variable_scope('_interpolate'):
65 | # constants
66 | num_batch = tf.shape(im)[0]
67 | height = tf.shape(im)[1]
68 | width = tf.shape(im)[2]
69 | channels = tf.shape(im)[3]
70 |
71 | x = tf.cast(x, 'float32')
72 | y = tf.cast(y, 'float32')
73 | height_f = tf.cast(height, 'float32')
74 | width_f = tf.cast(width, 'float32')
75 | out_height = out_size[0]
76 | out_width = out_size[1]
77 | zero = tf.zeros([], dtype='int32')
78 | max_y = tf.cast(tf.shape(im)[1] - 1, 'int32')
79 | max_x = tf.cast(tf.shape(im)[2] - 1, 'int32')
80 |
81 | # scale indices from [-1, 1] to [0, width/height]
82 | x = (x + 1.0)*(width_f) / 2.0
83 | y = (y + 1.0)*(height_f) / 2.0
84 |
85 | # do sampling
86 | x0 = tf.cast(tf.floor(x), 'int32')
87 | x1 = x0 + 1
88 | y0 = tf.cast(tf.floor(y), 'int32')
89 | y1 = y0 + 1
90 |
91 | x0 = tf.clip_by_value(x0, zero, max_x)
92 | x1 = tf.clip_by_value(x1, zero, max_x)
93 | y0 = tf.clip_by_value(y0, zero, max_y)
94 | y1 = tf.clip_by_value(y1, zero, max_y)
95 | dim2 = width
96 | dim1 = width*height
97 | base = _repeat(tf.range(num_batch)*dim1, out_height*out_width)
98 | base_y0 = base + y0*dim2
99 | base_y1 = base + y1*dim2
100 | idx_a = base_y0 + x0
101 | idx_b = base_y1 + x0
102 | idx_c = base_y0 + x1
103 | idx_d = base_y1 + x1
104 |
105 | # use indices to lookup pixels in the flat image and restore
106 | # channels dim
107 | im_flat = tf.reshape(im, tf.stack([-1, channels]))
108 | im_flat = tf.cast(im_flat, 'float32')
109 | Ia = tf.gather(im_flat, idx_a)
110 | Ib = tf.gather(im_flat, idx_b)
111 | Ic = tf.gather(im_flat, idx_c)
112 | Id = tf.gather(im_flat, idx_d)
113 |
114 | # and finally calculate interpolated values
115 | x0_f = tf.cast(x0, 'float32')
116 | x1_f = tf.cast(x1, 'float32')
117 | y0_f = tf.cast(y0, 'float32')
118 | y1_f = tf.cast(y1, 'float32')
119 | wa = tf.expand_dims(((x1_f-x) * (y1_f-y)), 1)
120 | wb = tf.expand_dims(((x1_f-x) * (y-y0_f)), 1)
121 | wc = tf.expand_dims(((x-x0_f) * (y1_f-y)), 1)
122 | wd = tf.expand_dims(((x-x0_f) * (y-y0_f)), 1)
123 | output = tf.add_n([wa*Ia, wb*Ib, wc*Ic, wd*Id])
124 | return output
125 |
126 | def _meshgrid(height, width):
127 | with tf.variable_scope('_meshgrid'):
128 | # This should be equivalent to:
129 | # x_t, y_t = np.meshgrid(np.linspace(-1, 1, width),
130 | # np.linspace(-1, 1, height))
131 | # ones = np.ones(np.prod(x_t.shape))
132 | # grid = np.vstack([x_t.flatten(), y_t.flatten(), ones])
133 | x_t = tf.matmul(tf.ones(shape=tf.stack([height, 1])),
134 | tf.transpose(tf.expand_dims(tf.linspace(-1.0, 1.0, width), 1), [1, 0]))
135 | y_t = tf.matmul(tf.expand_dims(tf.linspace(-1.0, 1.0, height), 1),
136 | tf.ones(shape=tf.stack([1, width])))
137 |
138 | x_t_flat = tf.reshape(x_t, (1, -1))
139 | y_t_flat = tf.reshape(y_t, (1, -1))
140 |
141 | ones = tf.ones_like(x_t_flat)
142 | grid = tf.concat(axis=0, values=[x_t_flat, y_t_flat, ones])
143 | return grid
144 |
145 | def _transform(theta, input_dim, out_size):
146 | with tf.variable_scope('_transform'):
147 | num_batch = tf.shape(input_dim)[0]
148 | height = tf.shape(input_dim)[1]
149 | width = tf.shape(input_dim)[2]
150 | num_channels = tf.shape(input_dim)[3]
151 | theta = tf.reshape(theta, (-1, 2, 3))
152 | theta = tf.cast(theta, 'float32')
153 |
154 | # grid of (x_t, y_t, 1), eq (1) in ref [1]
155 | height_f = tf.cast(height, 'float32')
156 | width_f = tf.cast(width, 'float32')
157 | out_height = out_size[0]
158 | out_width = out_size[1]
159 | grid = _meshgrid(out_height, out_width)
160 | grid = tf.expand_dims(grid, 0)
161 | grid = tf.reshape(grid, [-1])
162 | grid = tf.tile(grid, tf.stack([num_batch]))
163 | grid = tf.reshape(grid, tf.stack([num_batch, 3, -1]))
164 |
165 | # Transform A x (x_t, y_t, 1)^T -> (x_s, y_s)
166 | T_g = tf.matmul(theta, grid)
167 | x_s = tf.slice(T_g, [0, 0, 0], [-1, 1, -1])
168 | y_s = tf.slice(T_g, [0, 1, 0], [-1, 1, -1])
169 | x_s_flat = tf.reshape(x_s, [-1])
170 | y_s_flat = tf.reshape(y_s, [-1])
171 |
172 | input_transformed = _interpolate(
173 | input_dim, x_s_flat, y_s_flat,
174 | out_size)
175 |
176 | output = tf.reshape(
177 | input_transformed, tf.stack([num_batch, out_height, out_width, num_channels]))
178 | return output
179 |
180 | with tf.variable_scope(name):
181 | output = _transform(theta, U, out_size)
182 | return output
183 |
184 |
185 | def batch_transformer(U, thetas, out_size, name='BatchSpatialTransformer'):
186 | """Batch Spatial Transformer Layer
187 |
188 | Parameters
189 | ----------
190 |
191 | U : float
192 | tensor of inputs [num_batch,height,width,num_channels]
193 | thetas : float
194 | a set of transformations for each input [num_batch,num_transforms,6]
195 | out_size : int
196 | the size of the output [out_height,out_width]
197 |
198 | Returns: float
199 | Tensor of size [num_batch*num_transforms,out_height,out_width,num_channels]
200 | """
201 | with tf.variable_scope(name):
202 | num_batch, num_transforms = map(int, thetas.get_shape().as_list()[:2])
203 | indices = [[i]*num_transforms for i in xrange(num_batch)]
204 | input_repeated = tf.gather(U, tf.reshape(indices, [-1]))
205 | return transformer(input_repeated, thetas, out_size)
206 |
--------------------------------------------------------------------------------
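As a cross-check on `_meshgrid`, `_transform`, and `_interpolate`, here is an illustrative NumPy re-implementation for a single image: build a normalized sampling grid, map it through the affine `theta`, and bilinearly interpolate with the same four-corner weighting as the TF code. A reference sketch, not part of the module:

```python
import numpy as np

def np_transformer(im, theta, out_size):
    """Affine-warp one HxWxC image, mirroring the TF transformer above."""
    H, W, C = im.shape
    out_h, out_w = out_size
    # grid of (x_t, y_t, 1) in normalized [-1, 1] coordinates (_meshgrid)
    x_t, y_t = np.meshgrid(np.linspace(-1, 1, out_w), np.linspace(-1, 1, out_h))
    grid = np.vstack([x_t.ravel(), y_t.ravel(), np.ones(out_h * out_w)])
    # affine map to source coordinates, then scale to pixel units (_transform)
    x_s, y_s = theta.reshape(2, 3).dot(grid)
    x = (x_s + 1.0) * W / 2.0
    y = (y_s + 1.0) * H / 2.0
    # bilinear interpolation between the four neighbouring pixels (_interpolate)
    x0 = np.clip(np.floor(x).astype(int), 0, W - 1)
    x1 = np.clip(x0 + 1, 0, W - 1)
    y0 = np.clip(np.floor(y).astype(int), 0, H - 1)
    y1 = np.clip(y0 + 1, 0, H - 1)
    wa = ((x1 - x) * (y1 - y))[:, None]
    wb = ((x1 - x) * (y - y0))[:, None]
    wc = ((x - x0) * (y1 - y))[:, None]
    wd = ((x - x0) * (y - y0))[:, None]
    out = wa * im[y0, x0] + wb * im[y0, x1] + wc * im[y1, x0] + wd * im[y1, x1]
    return out.reshape(out_h, out_w, C)

# an identity theta reproduces the input (up to effects at the right/bottom border)
im = np.random.rand(8, 8, 1)
warped = np_transformer(im, np.array([1., 0., 0., 0., 1., 0.]), (8, 8))
```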
/transformer/tf_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | # %% Borrowed utils from here: https://github.com/pkmital/tensorflow_tutorials/
17 | import tensorflow as tf
18 | import numpy as np
19 |
20 | def conv2d(x, n_filters,
21 | k_h=5, k_w=5,
22 | stride_h=2, stride_w=2,
23 | stddev=0.02,
24 | activation=lambda x: x,
25 | bias=True,
26 | padding='SAME',
27 | name="Conv2D"):
28 | """2D Convolution with options for kernel size, stride, and init deviation.
29 | Parameters
30 | ----------
31 | x : Tensor
32 | Input tensor to convolve.
33 | n_filters : int
34 | Number of filters to apply.
35 | k_h : int, optional
36 | Kernel height.
37 | k_w : int, optional
38 | Kernel width.
39 | stride_h : int, optional
40 | Stride in rows.
41 | stride_w : int, optional
42 | Stride in cols.
43 | stddev : float, optional
44 | Initialization's standard deviation.
45 | activation : arguments, optional
46 | Function which applies a nonlinearity
47 | padding : str, optional
48 | 'SAME' or 'VALID'
49 | name : str, optional
50 | Variable scope to use.
51 | Returns
52 | -------
53 | x : Tensor
54 | Convolved input.
55 | """
56 | with tf.variable_scope(name):
57 | w = tf.get_variable(
58 | 'w', [k_h, k_w, x.get_shape()[-1], n_filters],
59 | initializer=tf.truncated_normal_initializer(stddev=stddev))
60 | conv = tf.nn.conv2d(
61 | x, w, strides=[1, stride_h, stride_w, 1], padding=padding)
62 | if bias:
63 | b = tf.get_variable(
64 | 'b', [n_filters],
65 | initializer=tf.truncated_normal_initializer(stddev=stddev))
66 | conv = conv + b
67 | return conv
68 |
69 | def linear(x, n_units, scope=None, stddev=0.02,
70 | activation=lambda x: x):
71 | """Fully-connected network.
72 | Parameters
73 | ----------
74 | x : Tensor
75 | Input tensor to the network.
76 | n_units : int
77 | Number of units to connect to.
78 | scope : str, optional
79 | Variable scope to use.
80 | stddev : float, optional
81 | Initialization's standard deviation.
82 | activation : arguments, optional
83 | Function which applies a nonlinearity
84 | Returns
85 | -------
86 | x : Tensor
87 | Fully-connected output.
88 | """
89 | shape = x.get_shape().as_list()
90 |
91 | with tf.variable_scope(scope or "Linear"):
92 | matrix = tf.get_variable("Matrix", [shape[1], n_units], tf.float32,
93 | tf.random_normal_initializer(stddev=stddev))
94 | return activation(tf.matmul(x, matrix))
95 |
96 | # %%
97 | def weight_variable(shape):
98 |     '''Helper function to create a weight variable, initialized to zeros
99 |     (a random-normal initializer is left commented out below)
100 | Parameters
101 | ----------
102 | shape : list
103 | Size of weight variable
104 | '''
105 | #initial = tf.random_normal(shape, mean=0.0, stddev=0.01)
106 | initial = tf.zeros(shape)
107 | return tf.Variable(initial)
108 |
109 | # %%
110 | def bias_variable(shape):
111 | '''Helper function to create a bias variable initialized with
112 |     small random values drawn from a normal distribution.
113 | Parameters
114 | ----------
115 | shape : list
116 |         Size of bias variable
117 | '''
118 | initial = tf.random_normal(shape, mean=0.0, stddev=0.01)
119 | return tf.Variable(initial)
120 |
121 | # %%
122 | def dense_to_one_hot(labels, n_classes=2):
123 | """Convert class labels from scalars to one-hot vectors."""
124 | labels = np.array(labels)
125 | n_labels = labels.shape[0]
126 | index_offset = np.arange(n_labels) * n_classes
127 | labels_one_hot = np.zeros((n_labels, n_classes), dtype=np.float32)
128 | labels_one_hot.flat[index_offset + labels.ravel()] = 1
129 | return labels_one_hot
130 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdaCompNUS/pfnet/861b398c58574cc3e415896f3dd278a76cb2b383/utils/__init__.py
--------------------------------------------------------------------------------
/utils/network_layers.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 |
4 | _L2_SCALE = 1.0  # l2-regularizer scale; the actual weighting is applied where the regularization terms are added to the loss
5 |
6 |
7 | # Helper functions for constructing layers
8 | def dense_layer(units, activation=None, use_bias=False, name=None):
9 |
10 | fn = lambda x: tf.layers.dense(
11 | x, units, activation=convert_activation_string(activation), use_bias=use_bias, name=name,
12 | kernel_regularizer=tf.contrib.layers.l2_regularizer(scale=_L2_SCALE))
13 | return fn
14 |
15 |
16 | def conv2_layer(filters, kernel_size, activation=None, padding='same', strides=(1,1), dilation_rate=(1,1),
17 | data_format='channels_last', use_bias=False, name=None, layer_i=0, name2=None):
18 | if name is None:
19 | name = "l%d_conv%d" % (layer_i, np.max(kernel_size))
20 | if np.max(dilation_rate) > 1:
21 | name += "_d%d" % np.max(dilation_rate)
22 | if name2 is not None:
23 | name += "_"+name2
24 | fn = lambda x: tf.layers.conv2d(
25 | x, filters, kernel_size, activation=convert_activation_string(activation),
26 | padding=padding, strides=strides, dilation_rate = dilation_rate, data_format=data_format,
27 | kernel_regularizer=tf.contrib.layers.l2_regularizer(scale=_L2_SCALE),
28 | kernel_initializer=tf.variance_scaling_initializer(),
29 | use_bias=use_bias, name=name)
30 | return fn
31 |
32 |
33 | def locallyconn2_layer(filters, kernel_size, activation=None, padding='same', strides=(1,1), dilation_rate=(1,1),
34 | data_format='channels_last', use_bias=False, name=None, layer_i=0, name2=None):
35 |     assert dilation_rate == (1, 1)  # the keras layer doesn't have this input; maybe under a different name?
36 | if name is None:
37 | name = "l%d_conv%d"%(layer_i, np.max(kernel_size))
38 | if np.max(dilation_rate) > 1:
39 | name += "_d%d"%np.max(dilation_rate)
40 | if name2 is not None:
41 | name += "_"+name2
42 | fn = tf.keras.layers.LocallyConnected2D(
43 | filters, kernel_size, activation=convert_activation_string(activation),
44 | padding=padding, strides=strides, data_format=data_format,
45 | kernel_regularizer=tf.contrib.layers.l2_regularizer(scale=_L2_SCALE),
46 | kernel_initializer=tf.variance_scaling_initializer(),
47 | use_bias=use_bias, name=name)
48 | return fn
49 |
50 |
51 | def convert_activation_string(activation):
52 | if isinstance(activation, str):
53 | if activation == 'relu':
54 | activation = tf.nn.relu
55 | elif activation == 'tanh':
56 | activation = tf.nn.tanh
57 | else:
58 | assert False
59 | return activation
60 |
--------------------------------------------------------------------------------
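A usage sketch for the helpers above (illustrative; the shapes and layer sizes are made up): the helpers return callables, so layers can be stacked functionally.

```python
import tensorflow as tf
from utils.network_layers import conv2_layer, dense_layer

x = tf.placeholder(tf.float32, [None, 56, 56, 3])             # e.g. an RGB observation
h = conv2_layer(32, (3, 3), activation='relu', layer_i=1)(x)
h = conv2_layer(64, (3, 3), activation='relu', strides=(2, 2), layer_i=2)(h)
h = tf.layers.flatten(h)
out = dense_layer(16, activation='relu', name='fc1')(h)
```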
/utils/tfrecordfeatures.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | def tf_bytes_feature(value):
5 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
6 |
7 |
8 | def tf_bytelist_feature(values):
9 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=values))
10 |
11 |
12 | def tf_int64_feature(value):
13 | return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
--------------------------------------------------------------------------------