├── sorting
│   ├── BUILD
│   ├── sorting_train.py
│   ├── sorting_eval.py
│   └── sorting_model.py
├── BUILD
├── CONTRIBUTING.md
├── optimizer.py
├── sinkhorn_ops_test.py
├── LICENSE
└── sinkhorn_ops.py

/sorting/BUILD:
--------------------------------------------------------------------------------
 1 | py_binary(
 2 |     name = "sorting_train",
 3 |     srcs = ["sorting_train.py"],
 4 |     deps = [
 5 |         ":sorting_model",
 6 |         "//:sinkhorn_ops",
 7 |     ],
 8 | )
 9 | 
10 | py_binary(
11 |     name = "sorting_eval",
12 |     srcs = ["sorting_eval.py"],
13 |     deps = [
14 |         ":sorting_model",
15 |         "//:sinkhorn_ops",
16 |     ],
17 | )
18 | 
19 | py_library(
20 |     name = "sorting_model",
21 |     srcs = ["sorting_model.py"],
22 |     deps = [
23 |         "//:optimizer",
24 |         "//:sinkhorn_ops",
25 |     ],
26 | )
27 | 
--------------------------------------------------------------------------------
/BUILD:
--------------------------------------------------------------------------------
 1 | py_library(
 2 |     name = "sinkhorn_ops",
 3 |     srcs = [
 4 |         "sinkhorn_ops.py",
 5 |     ],
 6 |     visibility = ["//visibility:public"],
 7 |     deps = [],
 8 | )
 9 | 
10 | py_library(
11 |     name = "optimizer",
12 |     srcs = [
13 |         "optimizer.py",
14 |     ],
15 |     visibility = ["//visibility:public"],
16 |     deps = [],
17 | )
18 | 
19 | py_test(
20 |     name = "sinkhorn_ops_test",
21 |     srcs = ["sinkhorn_ops_test.py"],
22 |     visibility = ["//visibility:public"],
23 |     deps = [
24 |         ":sinkhorn_ops",
25 |     ],
26 | )
27 | 
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
 1 | # How to Contribute
 2 | 
 3 | We'd love to accept your patches and contributions to this project. There are
 4 | just a few small guidelines you need to follow.
 5 | 
 6 | ## Contributor License Agreement
 7 | 
 8 | Contributions to this project must be accompanied by a Contributor License
 9 | Agreement. You (or your employer) retain the copyright to your contribution;
10 | this simply gives us permission to use and redistribute your contributions as
11 | part of the project. Head over to <https://cla.developers.google.com/> to see
12 | your current agreements on file or to sign a new one.
13 | 
14 | You generally only need to submit a CLA once, so if you've already submitted one
15 | (even if it was for a different project), you probably don't need to do it
16 | again.
17 | 
18 | ## Code reviews
19 | 
20 | All submissions, including submissions by project members, require review. We
21 | use GitHub pull requests for this purpose. Consult
22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
23 | information on using pull requests.
24 | 
--------------------------------------------------------------------------------
/optimizer.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2017 Google Inc.
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     https://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | 
16 | """Library with optimization definitions and functions."""
17 | 
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | 
22 | import tensorflow as tf
23 | 
24 | 
25 | def set_optimizer(optimizer, lr, opt_eps=1.0, opt_momentum=0.9, rms_decay=0.9,
26 |                   adam_beta1=0.9, adam_beta2=0.999):
27 |   """Sets the optimizer op.
28 | 
29 |   Args:
30 |     optimizer: A string (sgd, momentum, adagrad, adam, rmsprop).
31 |     lr: learning rate, a float.
32 |     opt_eps: Optimizer epsilon (for ADAM and RMSProp).
33 |     opt_momentum: Optimizer momentum. Common to Momentum and RMSProp.
34 |     rms_decay: RMSProp decay parameter.
35 |     adam_beta1: beta_1 parameter for ADAM.
36 |     adam_beta2: beta_2 parameter for ADAM.
37 |   Returns:
38 |     opt, the optimizer op.
39 | 
40 |   """
41 | 
42 |   if optimizer == "sgd":
43 |     opt = tf.train.GradientDescentOptimizer(lr)
44 |   elif optimizer == "momentum":
45 |     opt = tf.train.MomentumOptimizer(lr, opt_momentum)
46 |   elif optimizer == "adagrad":
47 |     opt = tf.train.AdagradOptimizer(lr)
48 |   elif optimizer == "adam":
49 |     opt = tf.train.AdamOptimizer(lr, beta1=adam_beta1, beta2=adam_beta2,
50 |                                  epsilon=opt_eps)
51 |   elif optimizer == "rmsprop":
52 |     opt = tf.train.RMSPropOptimizer(lr, rms_decay, opt_momentum, opt_eps)
53 |   else:
54 |     # Fail fast on an unknown name instead of raising a NameError below.
55 |     raise ValueError("Unknown optimizer: %s" % optimizer)
56 |   return opt
57 | 
--------------------------------------------------------------------------------
/sorting/sorting_train.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2017 Google Inc.
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     https://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """Trains a model that sorts numbers, keeping loss summaries in TensorBoard.
16 | 
17 | The flag hparams has to be passed as a string of comma-separated statements of
18 | the form hparam=value, where the hparams are any of those listed in the
19 | dictionary DEFAULT_HPARAMS.
20 | See the README.md file for further compilation and running instructions.
21 | """ 22 | 23 | from __future__ import absolute_import 24 | from __future__ import division 25 | from __future__ import print_function 26 | 27 | import tensorflow as tf 28 | 29 | import sorting_model 30 | 31 | flags = tf.app.flags 32 | gfile = tf.gfile 33 | FLAGS = flags.FLAGS 34 | 35 | flags.DEFINE_string('hparams', '', 'Hyperparameters') 36 | flags.DEFINE_integer('num_iters', 500, 'Number of iterations') 37 | flags.DEFINE_integer( 38 | 'save_summaries_secs', 30, 39 | 'The frequency with which summaries are saved, in seconds.') 40 | flags.DEFINE_integer( 41 | 'save_interval_secs', 30, 42 | 'The frequency with which the model is saved, in seconds.') 43 | flags.DEFINE_string('exp_log_dir', '/tmp/sorting/', 44 | 'Directory where to write event logs.') 45 | flags.DEFINE_integer('max_to_keep', 1, 'Maximum number of checkpoints to keep') 46 | 47 | DEFAULT_HPARAMS = tf.contrib.training.HParams(n_numbers=50, 48 | lr=0.1, 49 | temperature=1.0, 50 | batch_size=10, 51 | prob_inc=1.0, 52 | samples_per_num=5, 53 | n_iter_sinkhorn=10, 54 | n_units=32, 55 | noise_factor=1.0, 56 | optimizer='adam', 57 | keep_prob=1.) 58 | 59 | 60 | def main(_): 61 | 62 | hparams = DEFAULT_HPARAMS 63 | hparams.parse(FLAGS.hparams) 64 | 65 | if not gfile.Exists(FLAGS.exp_log_dir): 66 | gfile.MakeDirs(FLAGS.exp_log_dir) 67 | tf.reset_default_graph() 68 | g = tf.Graph() 69 | model = sorting_model.SortingModel(g, hparams) 70 | with g.as_default(): 71 | model.set_input() 72 | model.build_network() 73 | model.build_l2s_loss() 74 | model.build_optimizer() 75 | model.add_summaries_train() 76 | 77 | with tf.Session(): 78 | tf.contrib.slim.learning.train( 79 | train_op=model.train_op, 80 | logdir=FLAGS.exp_log_dir, 81 | global_step=model.global_step, 82 | saver=tf.train.Saver(max_to_keep=FLAGS.max_to_keep), 83 | number_of_steps=FLAGS.num_iters, 84 | save_summaries_secs=FLAGS.save_summaries_secs, 85 | save_interval_secs=FLAGS.save_interval_secs) 86 | 87 | 88 | if __name__ == '__main__': 89 | tf.app.run(main) 90 | -------------------------------------------------------------------------------- /sorting/sorting_eval.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Evaluates a sorting model, adding scalar summaries to tensorboard. 16 | 17 | It also outputs evaluation metrics as numpy scalars. 18 | The flag hparam has to be passed as a string of comma separated statements of 19 | the form hparam=value, where the hparam's are any of the listed in the 20 | dictionary DEFAULT_HPARAMS. 21 | See the README.md file for further compilation and running instructions. 
22 | 
23 | """
24 | 
25 | from __future__ import absolute_import
26 | from __future__ import division
27 | from __future__ import print_function
28 | 
29 | 
30 | import os
31 | import time
32 | 
33 | import tensorflow as tf
34 | 
35 | import sorting_model
36 | 
37 | gfile = tf.gfile
38 | flags = tf.app.flags
39 | flags.DEFINE_string("hparams", "", "Hyperparameters")
40 | flags.DEFINE_string("batch_transform_type",
41 |                     None, "Options: None, WeakMix, StrongMix")
42 | flags.DEFINE_boolean("evaluate_all",
43 |                      True, "If false, evaluate only the last checkpoint")
44 | flags.DEFINE_string("exp_log_dir", "/tmp/sorting/",
45 |                     "Directory where to write event logs.")
46 | flags.DEFINE_boolean("eval_once", True,
47 |                      "If true, only evaluate the model once")
48 | flags.DEFINE_integer("sec_sleep", 30,
49 |                      "If no new checkpoint, sleep for some seconds")
50 | flags.DEFINE_integer("secs_run_for", 10000,
51 |                      "Maximum time (seconds) the program will be active")
52 | 
53 | FLAGS = tf.app.flags.FLAGS
54 | DEFAULT_HPARAMS = tf.contrib.training.HParams(n_numbers=50,
55 |                                               lr=0.1,
56 |                                               temperature=1.0,
57 |                                               batch_size=50,
58 |                                               prob_inc=1.0,
59 |                                               samples_per_num=5,
60 |                                               n_iter_sinkhorn=10,
61 |                                               n_units=32,
62 |                                               noise_factor=0.0,
63 |                                               optimizer="adam",
64 |                                               keep_prob=1.)
65 | 
66 | 
67 | def log(s):
68 |   tf.logging.info(s)
69 |   print(s)
70 | 
71 | 
72 | def wait_for_new_checkpoint(saver, sess, logdir, global_step,
73 |                             last_step_evaluated, sleep_secs):
74 |   while True:
75 |     if restore_checkpoint_if_exists(saver, sess, logdir):
76 |       step = sess.run(global_step)
77 |       if step <= last_step_evaluated:
78 |         log("Found old checkpoint, sleeping %ds" % sleep_secs)
79 |         time.sleep(sleep_secs)
80 |       else:
81 |         return step
82 |     else:
83 |       log("Checkpoint not found in %s, "
84 |           "sleeping for %ds" % (logdir, sleep_secs))
85 |       time.sleep(sleep_secs)
86 | 
87 | 
88 | def restore_checkpoint_if_exists(saver, sess, logdir):
89 |   ckpt = tf.train.get_checkpoint_state(logdir)
90 |   if ckpt:
91 |     log("Restoring checkpoint from %s" % ckpt.model_checkpoint_path)
92 |     ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
93 |     full_checkpoint_path = os.path.join(logdir, ckpt_name)
94 |     saver.restore(sess, full_checkpoint_path)
95 |     return True
96 |   return False
97 | 
98 | 
99 | def main(_):
100 |   time_start = time.time()
101 |   hparams = DEFAULT_HPARAMS
102 |   hparams.parse(FLAGS.hparams)
103 | 
104 |   if not gfile.Exists(FLAGS.exp_log_dir):
105 |     gfile.MakeDirs(FLAGS.exp_log_dir)
106 | 
107 |   g = tf.Graph()
108 |   model = sorting_model.SortingModel(g, hparams)
109 | 
110 |   with model.graph.as_default():
111 |     model.set_input()
112 |     model.build_network()
113 |     model.build_hard_losses()
114 |     model.add_summaries_eval()
115 |     summaries_eval = tf.summary.merge_all()
116 |     saver = tf.train.Saver()
117 |     writer = tf.summary.FileWriter(FLAGS.exp_log_dir, model.graph)
118 |     last_step_evaluated = -1
119 |     with tf.Session() as session:
120 |       while time.time() - time_start < FLAGS.secs_run_for:
121 |         wait_for_new_checkpoint(
122 |             saver, session, FLAGS.exp_log_dir,
123 |             model.global_step, last_step_evaluated, FLAGS.sec_sleep)
124 |         (summaries, step, eval_measures) = session.run([
125 |             summaries_eval, model.global_step, model.get_eval_measures()])
126 |         (l1_diff, l2sh_diff, kendall_tau,
127 |          prop_wrong, prop_any_wrong) = eval_measures
128 |         log("Frequency of mistakes was %s" % prop_wrong)
129 |         log("Frequency of series with at least an error was %s" %
130 |             prop_any_wrong)
131 |         log("Kendall's tau was %s" % kendall_tau)
132 |         log(("Mean L2 squared difference between true and inferred series "
133 |              "was %s") % l2sh_diff)
134 |         log("Mean L1 difference between true and inferred series was %s"
135 |             % l1_diff)
136 |         writer.add_summary(summaries, global_step=step)
137 |         last_step_evaluated = step
138 |         if FLAGS.eval_once:
139 |           break
140 | 
141 | 
142 | if __name__ == "__main__":
143 |   tf.app.run()
144 | 
--------------------------------------------------------------------------------
/sinkhorn_ops_test.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2017 Google Inc.
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """Tests for sinkhorn_ops library."""
16 | from __future__ import absolute_import
17 | from __future__ import division
18 | from __future__ import print_function
19 | 
20 | 
21 | import numpy as np
22 | import tensorflow as tf
23 | 
24 | from tensorflow.python.platform import test
25 | import sinkhorn_ops
26 | 
27 | 
28 | class SinkhornTest(test.TestCase):
29 | 
30 |   def setUp(self):
31 |     self.rng = np.random.RandomState(0)
32 |     tf.set_random_seed(1)
33 | 
34 |   def test_approximately_stochastic(self):
35 |     with self.test_session(use_gpu=True):
36 |       for dims in [2, 5, 10]:
37 |         for batch_size in [1, 2, 10]:
38 |           log_alpha = self.rng.randn(batch_size, dims, dims)
39 |           result = sinkhorn_ops.sinkhorn(log_alpha)
40 | 
41 |           self.assertAllClose(np.sum(result.eval(), 1),
42 |                               np.tile([1.0], (batch_size, dims)),
43 |                               atol=1e-3)
44 |           self.assertAllClose(np.sum(result.eval(), 2),
45 |                               np.tile([1.0], (batch_size, dims)),
46 |                               atol=1e-3)
47 | 
48 |   def test_equivalence_gumbel_sinkhorn_and_sinkhorn(self):
49 |     """Tests the equivalence of sinkhorn and gumbel_sinkhorn in a special case.
50 | 
51 |     When noise_factor = 0.0 the output of gumbel_sinkhorn should be the same
52 |     as the output of sinkhorn, 'modulo' possibly many repetitions of the same
53 |     matrix given by gumbel_sinkhorn.
54 |     """
55 |     with self.test_session(use_gpu=True):
56 |       batch_size = 10
57 |       dims = 5
58 |       n_samples = 20
59 |       temp = 1.0
60 |       noise_factor = 0.0
61 |       log_alpha = self.rng.randn(batch_size, dims, dims)
62 |       result_sinkhorn = sinkhorn_ops.sinkhorn(log_alpha)
63 |       result_sinkhorn_reshaped = tf.reshape(
64 |           result_sinkhorn, [batch_size, 1, dims, dims])
65 |       result_sinkhorn_tiled = tf.tile(
66 |           result_sinkhorn_reshaped, [1, n_samples, 1, 1])
67 |       result_gumbel_sinkhorn, _ = sinkhorn_ops.gumbel_sinkhorn(
68 |           log_alpha, temp, n_samples, noise_factor)
69 | 
70 |       self.assertAllEqual(result_gumbel_sinkhorn.eval(),
71 |                           result_sinkhorn_tiled.eval())
72 | 
73 |   def test_gumbel_sinkhorn_high_temperature(self):
74 |     """At very high temperatures, the resulting matrix approaches the uniform.
75 |     """
76 |     n_samples = 1
77 |     temp = 100000.0
78 | 
79 |     with self.test_session(use_gpu=True):
80 | 
81 |       for dims in [2, 5, 10]:
82 |         for batch_size in [1, 2, 10]:
83 |           for noise_factor in [1.0, 5.0]:
84 |             log_alpha = tf.cast(self.rng.randn(batch_size, dims, dims),
85 |                                 dtype=tf.float32)
86 | 
87 |             result_gumbel_sinkhorn, _ = sinkhorn_ops.gumbel_sinkhorn(
88 |                 log_alpha, temp, n_samples, noise_factor, squeeze=True)
89 |             uniform = np.ones((batch_size, dims, dims)) / dims
90 |             self.assertAllClose(uniform, result_gumbel_sinkhorn.eval(),
91 |                                 atol=1e-3)
92 | 
93 |   def test_matching(self):
94 |     """The solution of the matching for the identity matrix is range(N).
95 | """ 96 | with self.test_session(use_gpu=True): 97 | dims = 10 98 | identity = np.eye(dims) 99 | result_matching = sinkhorn_ops.matching(identity) 100 | self.assertAllEqual(result_matching.eval(), 101 | np.reshape(range(dims), [1, dims])) 102 | 103 | def test_perm_inverse(self): 104 | """The product of a permutation and its inverse is the identity.""" 105 | 106 | with self.test_session(use_gpu=True): 107 | dims = 10 108 | permutation = np.reshape(self.rng.permutation(dims), [1, -1]) 109 | permutation_matrix = sinkhorn_ops.listperm2matperm(permutation) 110 | inverse = sinkhorn_ops.invert_listperm(permutation) 111 | inverse_matrix = sinkhorn_ops.listperm2matperm(inverse) 112 | prod = tf.matmul(permutation_matrix, inverse_matrix) 113 | 114 | self.assertAllEqual(prod.eval(), 115 | np.reshape(np.eye(dims), [1, dims, dims])) 116 | 117 | def test_listperm2matperm(self): 118 | """The matrix form of the permutation range(N) is the identity.""" 119 | 120 | with self.test_session(use_gpu=True): 121 | dims = 10 122 | permutation_list = np.reshape(np.arange(dims), [1, -1]) 123 | permutation_matrix = sinkhorn_ops.listperm2matperm(permutation_list) 124 | self.assertAllEqual(permutation_matrix.eval(), 125 | np.reshape(np.eye(dims), [1, dims, dims])) 126 | 127 | def test_matperm2listperm(self): 128 | """The list form of the matrix permutation identity is range(N).""" 129 | 130 | with self.test_session(use_gpu=True): 131 | dims = 10 132 | permutation_matrix = np.eye(dims) 133 | permutation_list = sinkhorn_ops.matperm2listperm(permutation_matrix) 134 | self.assertAllEqual(permutation_list.eval(), 135 | np.reshape(np.arange(dims), [1, dims])) 136 | 137 | def test_sample_uniform_and_order(self): 138 | """Ordered numbers form indeed an increasing sequence.""" 139 | n_lists = 1 140 | n_numbers = 10 141 | prob_inc = 1.0 142 | with self.test_session(use_gpu=True): 143 | ordered, _, _ = sinkhorn_ops.sample_uniform_and_order(n_lists, 144 | n_numbers, 145 | prob_inc) 146 | self.assertTrue(np.min(np.diff(ordered.eval())) > 0) 147 | 148 | def test_sample_permutations(self): 149 | """What is being sampled are indeed permutations of range(N).""" 150 | n_permutations = 10 151 | n_objects = 5 152 | 153 | with self.test_session(use_gpu=True): 154 | permutations = sinkhorn_ops.sample_permutations(n_permutations, n_objects) 155 | tiled_range = np.tile(np.reshape( 156 | np.arange(n_objects), [1, n_objects]), [n_permutations, 1]) 157 | self.assertAllEqual(np.sort(permutations.eval()), tiled_range) 158 | 159 | if __name__ == '__main__': 160 | tf.test.main() 161 | -------------------------------------------------------------------------------- /sorting/sorting_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | """Model class for sorting numbers. 
17 | """ 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import tensorflow as tf 24 | 25 | import optimizer 26 | import sinkhorn_ops 27 | 28 | 29 | class SortingModel(object): 30 | """Constructs the graph of tensors to learn to sort numbers.""" 31 | 32 | def __init__(self, graph, hparams): 33 | self.graph = graph 34 | self.hparams = hparams 35 | self.batch_size = hparams.batch_size 36 | self.n_numbers = hparams.n_numbers 37 | self.samples_per_num = hparams.samples_per_num 38 | self.n_iter_sinkhorn = hparams.n_iter_sinkhorn 39 | self.noise_factor = hparams.noise_factor 40 | self.prob_inc = hparams.prob_inc 41 | self.optimizer = hparams.optimizer 42 | self.n_units = hparams.n_units 43 | 44 | def set_input(self): 45 | with self.graph.as_default(): 46 | (self._ordered, self._random, 47 | self._hard_perms) = sinkhorn_ops.sample_uniform_and_order( 48 | self.batch_size, self.n_numbers, self.prob_inc) 49 | # tiled variables, to compare to many permutations 50 | self._ordered_tiled = tf.tile(self._ordered, [self.samples_per_num, 1]) 51 | self._random_tiled = tf.tile(self._random, [self.samples_per_num, 1]) 52 | 53 | def build_network(self): 54 | """The most important part, where the neural network is built.""" 55 | 56 | def _create_log_alpha(self): 57 | """Creates the variable log_alpha, through NN processing of input.""" 58 | 59 | with tf.variable_scope("model_params"): 60 | # each number is processed with the same network, so data is reshaped 61 | # so that numbers occupy the 'batch' position. 62 | random_flattened = tf.reshape(self._random, [-1, 1]) 63 | # net: output of the first neural network that connects numbers to a 64 | # 'latent' representation. 65 | net = dropout(fc(random_flattened, self.n_units), self.keep_prob) 66 | # now those latent representation is connected to rows of the matrix 67 | # log_alpha. 68 | processed = dropout( 69 | fc(net, self.n_numbers, activation_fn=None), self.keep_prob) 70 | 71 | # the matrix log_alpha is created by concatenation of the rows 72 | # corresponding to different numbers. 73 | return tf.reshape(processed, [-1, self.n_numbers, self.n_numbers]) 74 | 75 | with self.graph.as_default(): 76 | 77 | self.keep_prob = tf.constant(self.hparams.keep_prob, dtype=tf.float32) 78 | self.temperature = tf.constant(self.hparams.temperature, dtype=tf.float32) 79 | self._global_step = tf.Variable(0, trainable=False) 80 | fc = tf.contrib.layers.fully_connected 81 | dropout = tf.contrib.layers.dropout 82 | 83 | self._log_alpha = _create_log_alpha(self) 84 | # Now, we sample using gumbel_sinkhorn from the 85 | # constructed matrix log_alpha. 
86 |       (self._soft_perms_inf,
87 |        self._log_alpha_w_noise) = sinkhorn_ops.gumbel_sinkhorn(
88 |            self._log_alpha, self.temperature, self.samples_per_num,
89 |            self.noise_factor, self.n_iter_sinkhorn, squeeze=False)
90 | 
91 |   def build_initializer(self):
92 |     with self.graph.as_default():
93 |       return tf.global_variables_initializer()
94 | 
95 |   def build_l2s_loss(self):
96 |     """Builds loss tensor with soft permutations, for training."""
97 |     with self.graph.as_default():
98 |       inv_soft_perms = tf.transpose(self._soft_perms_inf, [0, 1, 3, 2])
99 |       inv_soft_perms_flat = tf.reshape(
100 |           tf.transpose(inv_soft_perms, [1, 0, 2, 3]),
101 |           [-1, self.n_numbers, self.n_numbers])
102 |       ordered_tiled = tf.reshape(self._ordered_tiled, [-1, self.n_numbers, 1])
103 |       random_tiled = tf.reshape(self._random_tiled, [-1, self.n_numbers, 1])
104 |       # squared l2 loss
105 |       self._l2s_diff = tf.reduce_mean(
106 |           tf.square(
107 |               ordered_tiled - tf.matmul(inv_soft_perms_flat, random_tiled)))
108 | 
109 |   def build_hard_losses(self):
110 |     """Losses based on hard reconstruction. Only for evaluation.
111 | 
112 |     Doubly stochastic matrices are rounded with
113 |     the matching function.
114 |     """
115 | 
116 |     log_alpha_w_noise_flat = tf.reshape(tf.transpose(self._log_alpha_w_noise,
117 |                                                      [1, 0, 2, 3]),
118 |                                         [-1, self.n_numbers, self.n_numbers])
119 | 
120 |     hard_perms_inf = sinkhorn_ops.matching(log_alpha_w_noise_flat)
121 |     inverse_hard_perms_inf = sinkhorn_ops.invert_listperm(hard_perms_inf)
122 |     hard_perms_tiled = tf.tile(self._hard_perms,
123 |                                [self.samples_per_num, 1])
124 | 
125 |     # The 3D output of permute_batch_split must be squeezed
126 |     self._ordered_inf_tiled = tf.reshape(
127 |         sinkhorn_ops.permute_batch_split(
128 |             self._random_tiled, inverse_hard_perms_inf),
129 |         [-1, self.n_numbers])
130 | 
131 |     self._l1_diff = tf.reduce_mean(
132 |         tf.abs(self._ordered_tiled - self._ordered_inf_tiled))
133 |     self._l2sh_diff = tf.reduce_mean(
134 |         tf.square(self._ordered_tiled - self._ordered_inf_tiled))
135 |     diff_perms = tf.cast(
136 |         tf.abs(hard_perms_tiled - inverse_hard_perms_inf), tf.float32)
137 |     self._prop_wrong = -tf.reduce_mean(tf.sign(-diff_perms))
138 |     self._prop_any_wrong = -tf.reduce_mean(
139 |         tf.sign(-tf.reduce_sum(diff_perms, 1)))
140 |     self._kendall_tau = tf.reduce_mean(
141 |         sinkhorn_ops.kendall_tau(hard_perms_tiled, inverse_hard_perms_inf))
142 | 
143 |   def build_optimizer(self):
144 |     with self.graph.as_default():
145 |       opt = optimizer.set_optimizer(self.hparams.optimizer,
146 |                                     self.hparams.lr, opt_eps=1e-8)
147 |       self._train_op = tf.contrib.training.create_train_op(
148 |           self._l2s_diff, opt, global_step=self._global_step)
149 | 
150 |   def build_train_ops(self):
151 |     with self.graph.as_default():
152 |       self._vars = tf.trainable_variables()
153 |       self._train_op = tf.train.AdamOptimizer(
154 |           self.hparams.lr).minimize(self._l2s_diff,
155 |                                     var_list=self._vars,
156 |                                     global_step=self._global_step)
157 | 
158 |   def add_summaries_train(self):
159 |     """Adds necessary summaries which will be computed during training."""
160 |     with tf.name_scope("Training"):
161 |       with self.graph.as_default():
162 |         tf.summary.scalar("Total_l2_squared_loss", self._l2s_diff)
163 | 
164 |   def add_summaries_eval(self):
165 |     """Adds necessary summaries which will be computed during evaluation."""
166 |     with tf.name_scope("Evaluation"):
167 |       with self.graph.as_default():
168 |         tf.summary.scalar("L1_diff", self._l1_diff)
169 |         tf.summary.scalar("L2_squared_diff", self._l2sh_diff)
170 |         tf.summary.scalar("Proportion_wrong",
self._prop_wrong) 171 | tf.summary.scalar("Proportion_where_any_is_wrong", 172 | self._prop_any_wrong) 173 | tf.summary.scalar("Kendall's_tau", 174 | self._kendall_tau) 175 | 176 | def get_eval_measures(self): 177 | """Getter method for evaluation measures.""" 178 | return (self._l1_diff, self._l2sh_diff, self._kendall_tau, 179 | self._prop_wrong, self._prop_any_wrong) 180 | 181 | @property 182 | def train_op(self): 183 | return self._train_op 184 | 185 | @property 186 | def ordered_inf(self): 187 | return tf.transpose( 188 | tf.reshape(self._ordered_inf_tiled, 189 | [self.samples_per_num, -1, self.n_numbers]), [1, 0, 2]) 190 | 191 | @property 192 | def ordered(self): 193 | return tf.transpose( 194 | tf.reshape(self._ordered_tiled, 195 | [self.samples_per_num, -1, self.n_numbers]), [1, 0, 2]) 196 | 197 | @property 198 | def random(self): 199 | return tf.transpose( 200 | tf.reshape(self._random_tiled, 201 | [self.samples_per_num, -1, self.n_numbers]), [1, 0, 2]) 202 | 203 | @property 204 | def global_step(self): 205 | return self._global_step 206 | 207 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 
48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /sinkhorn_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """A tensorflow lib of ops with permutations, and sinkhorn balancing.
16 | 
17 | A tensorflow library of operations and sampling with permutations
18 | and their approximation with doubly-stochastic matrices, through Sinkhorn
19 | balancing.
20 | 
21 | """
22 | 
23 | from __future__ import absolute_import
24 | from __future__ import division
25 | from __future__ import print_function
26 | 
27 | import numpy as np
28 | from scipy.optimize import linear_sum_assignment
29 | from scipy.stats import kendalltau
30 | import tensorflow as tf
31 | 
32 | 
33 | def sample_gumbel(shape, eps=1e-20):
34 |   """Samples arbitrary-shaped standard gumbel variables.
35 | 
36 |   Args:
37 |     shape: list of integers
38 |     eps: float, for numerical stability
39 |   Returns:
40 |     A sample of standard Gumbel random variables
41 |   """
42 | 
43 |   u = tf.random_uniform(shape, minval=0, maxval=1, dtype=tf.float32)
44 |   return -tf.log(-tf.log(u + eps) + eps)
45 | 
46 | 
47 | def matching(matrix_batch):
48 |   """Solves a matching problem for a batch of matrices.
49 | 
50 |   This is a wrapper for the scipy.optimize.linear_sum_assignment function. It
51 |   solves the optimization problem max_P sum_i,j M_i,j P_i,j with P a
52 |   permutation matrix. Notice the negative sign below; it is needed because
53 |   the scipy function solves a minimization problem.
54 | 
55 |   Args:
56 |     matrix_batch: A 3D tensor (a batch of matrices) with
57 |       shape = [batch_size, N, N]. If 2D, the input is reshaped to 3D with
58 |       batch_size = 1.
59 | 
60 |   Returns:
61 |     listperms, a 2D integer tensor of permutations with shape [batch_size, N]
62 |       so that listperms[n, :] is the permutation of range(N) that solves the
63 |       problem max_P sum_i,j M_i,j P_i,j with M = matrix_batch[n, :, :].
64 |   """
65 | 
66 |   def hungarian(x):
67 |     if x.ndim == 2:
68 |       x = np.reshape(x, [1, x.shape[0], x.shape[1]])
69 |     sol = np.zeros((x.shape[0], x.shape[1]), dtype=np.int32)
70 |     for i in range(x.shape[0]):
71 |       sol[i, :] = linear_sum_assignment(-x[i, :])[1].astype(np.int32)
72 |     return sol
73 | 
74 |   listperms = tf.py_func(hungarian, [matrix_batch], tf.int32)
75 |   return listperms
76 | 
77 | 
78 | def kendall_tau(batch_perm1, batch_perm2):
79 |   """Wraps the scipy.stats kendalltau function.
80 | 
81 |   Args:
82 |     batch_perm1: A 2D tensor (a batch of matrices) with
83 |       shape = [batch_size, N]
84 |     batch_perm2: same as batch_perm1
85 | 
86 |   Returns:
87 |     A list of Kendall's tau correlations between elements of the batch.
88 |   """
89 | 
90 |   def kendalltau_batch(x, y):
91 | 
92 |     if x.ndim == 1:
93 |       x = np.reshape(x, [1, x.shape[0]])
94 |     if y.ndim == 1:
95 |       y = np.reshape(y, [1, y.shape[0]])
96 |     kendall = np.zeros((x.shape[0], 1), dtype=np.float32)
97 |     for i in range(x.shape[0]):
98 |       kendall[i, :] = kendalltau(x[i, :], y[i, :])[0]
99 |     return kendall
100 | 
101 |   listkendall = tf.py_func(kendalltau_batch, [batch_perm1, batch_perm2],
102 |                            tf.float32)
103 |   return listkendall
104 | 
105 | 
106 | def sinkhorn(log_alpha, n_iters=20):
107 |   """Performs incomplete Sinkhorn normalization to log_alpha.
108 | 
109 |   By a theorem by Sinkhorn and Knopp [1], a sufficiently well-behaved matrix
110 |   with positive entries can be turned into a doubly-stochastic matrix
111 |   (i.e. its rows and columns add up to one) via successive row and column
112 |   normalization.
113 |   -To ensure positivity, the effective input to sinkhorn has to be
114 |   exp(log_alpha) (elementwise).
115 |   -However, for stability, sinkhorn works in the log-space. It is only at
116 |   return time that entries are exponentiated.
117 | 
118 |   [1] Sinkhorn, Richard and Knopp, Paul.
119 |   Concerning nonnegative matrices and doubly stochastic
120 |   matrices. Pacific Journal of Mathematics, 1967
121 | 
122 |   Args:
123 |     log_alpha: 2D tensor (a matrix of shape [N, N])
124 |       or 3D tensor (a batch of matrices of shape = [batch_size, N, N])
125 |     n_iters: number of sinkhorn iterations (in practice, as few as 20
126 |       iterations are needed to achieve decent convergence for N~100)
127 | 
128 |   Returns:
129 |     A 3D tensor of close-to-doubly-stochastic matrices (2D tensors are
130 |       converted to 3D tensors with batch_size equal to 1)
131 |   """
132 |   n = tf.shape(log_alpha)[1]
133 |   log_alpha = tf.reshape(log_alpha, [-1, n, n])
134 | 
135 |   for _ in range(n_iters):
136 |     log_alpha -= tf.reshape(tf.reduce_logsumexp(log_alpha, axis=2), [-1, n, 1])
137 |     log_alpha -= tf.reshape(tf.reduce_logsumexp(log_alpha, axis=1), [-1, 1, n])
138 |   return tf.exp(log_alpha)
139 | 
140 | 
141 | def gumbel_sinkhorn(log_alpha,
142 |                     temp=1.0, n_samples=1, noise_factor=1.0, n_iters=20,
143 |                     squeeze=True):
144 |   """Random doubly-stochastic matrices via gumbel noise.
145 | 
146 |   In the zero-temperature limit sinkhorn(log_alpha/temp) approaches
147 |   a permutation matrix. Therefore, for low temperatures this method can be
148 |   seen as an approximate sampling of permutation matrices, where the
149 |   distribution is parameterized by the matrix log_alpha.
150 | 
151 |   The deterministic case (noise_factor=0) is also interesting: it can be
152 |   shown that lim t->0 sinkhorn(log_alpha/t) = M, where M is a
153 |   permutation matrix, the solution of the
154 |   matching problem M = arg max_M sum_i,j log_alpha_i,j M_i,j.
155 |   Therefore, the deterministic limit case of gumbel_sinkhorn can be seen
156 |   as approximately solving a matching problem, otherwise solved via the
157 |   Hungarian algorithm.
158 | 
159 |   Warning: the convergence holds true in the limit case n_iters = infty.
160 |   Unfortunately, in practice n_iters is finite, which can lead to numerical
161 |   instabilities, mostly if temp is very low. Those manifest as
162 |   pseudo-convergence, with some rows or columns having fractional entries
163 |   (e.g. a row having two entries of 0.5, instead of a single 1.0).
164 |   To minimize those effects, try increasing n_iters as you decrease temp.
165 |   On the other hand, a too-low temperature usually leads to high variance
166 |   in the gradients, so it is better not to choose it too low.
167 | 
168 |   Args:
169 |     log_alpha: 2D tensor (a matrix of shape [N, N])
170 |       or 3D tensor (a batch of matrices of shape = [batch_size, N, N])
171 |     temp: temperature parameter, a float.
172 |     n_samples: number of samples
173 |     noise_factor: scaling factor for the gumbel samples. Mostly to explore
174 |       different degrees of randomness (and the absence of randomness, with
175 |       noise_factor=0)
176 |     n_iters: number of sinkhorn iterations. Should be chosen carefully, in
177 |       inverse correspondence with temp, to avoid numerical instabilities.
178 |     squeeze: a boolean; if True and there is a single sample, the output
179 |       will remain a 3D tensor.
180 | 
181 |   Returns:
182 |     sink: a 4D tensor of [batch_size, n_samples, N, N] i.e.
183 |       batch_size * n_samples doubly-stochastic matrices. If n_samples = 1 and
184 |       squeeze = True then the output is 3D.
185 |     log_alpha_w_noise: a 4D tensor of [batch_size, n_samples, N, N] of
186 |       noisy samples of log_alpha, divided by the temperature parameter. If
187 |       n_samples = 1 and squeeze = True then the output is 3D.
188 |   """
189 |   n = tf.shape(log_alpha)[1]
190 |   log_alpha = tf.reshape(log_alpha, [-1, n, n])
191 |   batch_size = tf.shape(log_alpha)[0]
192 |   log_alpha_w_noise = tf.tile(log_alpha, [n_samples, 1, 1])
193 |   if noise_factor == 0:
194 |     noise = 0.0
195 |   else:
196 |     noise = sample_gumbel([n_samples*batch_size, n, n])*noise_factor
197 |   log_alpha_w_noise += noise
198 |   log_alpha_w_noise /= temp
199 |   sink = sinkhorn(log_alpha_w_noise, n_iters)
200 |   if n_samples > 1 or squeeze is False:
201 |     sink = tf.reshape(sink, [n_samples, batch_size, n, n])
202 |     sink = tf.transpose(sink, [1, 0, 2, 3])
203 |     log_alpha_w_noise = tf.reshape(
204 |         log_alpha_w_noise, [n_samples, batch_size, n, n])
205 |     log_alpha_w_noise = tf.transpose(log_alpha_w_noise, [1, 0, 2, 3])
206 |   return sink, log_alpha_w_noise
207 | 
208 | 
209 | def sample_uniform_and_order(n_lists, n_numbers, prob_inc):
210 |   """Samples uniform random numbers and sorts them.
211 | 
212 |   Returns a 2-D tensor of n_lists lists of sorted numbers in the [0,1]
213 |   interval, each of them having n_numbers elements.
214 |   Lists are increasing with probability prob_inc.
215 |   It does so by first sampling uniform random numbers, and then sorting them.
216 |   Therefore, sorted numbers follow the distribution of the order statistics of
217 |   a uniform distribution.
218 |   It also returns the random numbers and the lists of permutations p such that
219 |   p(sorted) = random.
220 |   Notice that if one wants to build sorted numbers in different intervals, one
221 |   might just want to re-scale this canonical form.
222 | 
223 |   Args:
224 |     n_lists: An int, the number of lists to be sorted.
225 |     n_numbers: An int, the number of elements in the permutation.
226 |     prob_inc: A float, the probability that a list of numbers will be sorted
227 |       in increasing order.
228 | 
229 |   Returns:
230 |     ordered: a 2-D float tensor with shape = [n_list, n_numbers] of sorted
231 |       lists of numbers.
232 |     random: a 2-D float tensor with shape = [n_list, n_numbers] of uniform
233 |       random numbers.
234 |     permutations: a 2-D int tensor with shape = [n_list, n_numbers], row i
235 |       satisfies ordered[i, permutations[i]] = random[i, :].
236 | 
237 |   """
238 | 
239 |   bern = tf.contrib.distributions.Bernoulli(
240 |       probs=np.ones((n_lists, 1)) * prob_inc).sample()
241 |   sign = -1 * tf.cast(tf.multiply(bern, 2) - 1, dtype=tf.float32)
242 |   random = tf.random_uniform(shape=[n_lists, n_numbers], dtype=tf.float32)
243 |   random_with_sign = tf.multiply(random, sign)
244 |   ordered, permutations = tf.nn.top_k(random_with_sign, k=n_numbers)
245 |   ordered = tf.multiply(ordered, sign)
246 |   return ordered, random, permutations
247 | 
248 | 
249 | def sample_permutations(n_permutations, n_objects):
250 |   """Samples a batch of permutations from the uniform distribution.
251 | 
252 |   Returns a sample of n_permutations permutations of n_objects indices.
253 |   Permutations are assumed to be represented as lists of integers
254 |   (see 'listperm2matperm' and 'matperm2listperm' for conversion to the
255 |   alternative matricial representation). It does so by sampling from a
256 |   continuous distribution and then ranking the elements. By symmetry, the
257 |   resulting distribution over permutations must be uniform.
258 | 
259 |   Args:
260 |     n_permutations: An int, the number of permutations to sample.
261 |     n_objects: An int, the number of elements in each of the sampled
262 |       permutations.
263 | 
264 |   Returns:
265 |     A 2D integer tensor with shape [n_permutations, n_objects], where each
266 |     row is a permutation of range(n_objects)
267 | 
268 |   """
269 | 
270 |   random_pre_perm = tf.random_normal(shape=[n_permutations, n_objects])
271 |   _, permutations = tf.nn.top_k(random_pre_perm, k=n_objects)
272 |   return permutations
273 | 
274 | 
275 | def permute_batch_split(batch_split, permutations):
276 |   """Scrambles a batch of objects according to permutations.
277 | 
278 |   It takes a 3D tensor [batch_size, n_objects, object_size]
279 |   and permutes items in axis=1 according to the 2D integer tensor
280 |   permutations (with shape [batch_size, n_objects]), a batch of permutations
281 |   expressed as lists. Multi-dimensional objects (e.g. images) have
282 |   to be flattened first so they respect the 3D format, i.e. tf.reshape(
283 |   batch_split, [batch_size, n_objects, -1])
284 | 
285 |   Args:
286 |     batch_split: 3D tensor with shape = [batch_size, n_objects, object_size]
287 |       of split objects
288 |     permutations: a 2D integer tensor with shape = [batch_size, n_objects] of
289 |       permutations, so that permutations[n] is a permutation of range(n_objects)
290 | 
291 |   Returns:
292 |     A 3D tensor perm_batch_split with the same shape as batch_split,
293 |     so that perm_batch_split[n, j, :] = batch_split[n, permutations[n, j], :]
294 | 
295 |   """
296 | 
297 |   batch_split_size = tf.shape(batch_split, out_type=permutations.dtype)[0]
298 |   n_objects = tf.shape(batch_split)[1]
299 | 
300 |   ind_permutations = tf.reshape(permutations, [-1, 1])
301 | 
302 |   ind_batch = tf.reshape(tf.tile(tf.reshape(tf.range(batch_split_size),
303 |                                             [-1, 1]),
304 |                                  [1, n_objects]),
305 |                          [-1, 1])
306 | 
307 |   ind_batch_and_permutation = tf.concat((ind_batch, ind_permutations), axis=1)
308 | 
309 |   batch_split = tf.reshape(tf.gather_nd(batch_split, ind_batch_and_permutation),
310 |                            [batch_split_size, n_objects, -1])
311 | 
312 |   return batch_split
313 | 
314 | 
315 | def listperm2matperm(listperm):
316 |   """Converts a batch of permutations to its matricial form.
317 | 
318 |   Args:
319 |     listperm: 2D tensor of permutations of shape [batch_size, n_objects] so that
320 |       listperm[n] is a permutation of range(n_objects).
321 | 
322 |   Returns:
323 |     a 3D tensor of permutations matperm of
324 |     shape = [batch_size, n_objects, n_objects] so that matperm[n, :, :] is a
325 |     permutation of the identity matrix, with matperm[n, i, listperm[n,i]] = 1
326 |   """
327 |   n_objects = tf.shape(listperm)[1]
328 |   return tf.one_hot(listperm, n_objects)
329 | 
330 | 
331 | def matperm2listperm(matperm, dtype=tf.int32):
332 |   """Converts a batch of permutations to its enumeration (list) form.
333 | 
334 |   Args:
335 |     matperm: a 3D tensor of permutations of
336 |       shape = [batch_size, n_objects, n_objects] so that matperm[n, :, :] is a
337 |       permutation of the identity matrix. If the input is 2D, it is reshaped
338 |       to 3D with batch_size = 1.
339 |     dtype: output_type (tf.int32, tf.int64)
340 | 
341 |   Returns:
342 |     A 2D tensor of permutations listperm, where listperm[n, i]
343 |     is the index of the only non-zero entry in matperm[n, i, :]
344 |   """
345 |   matperm = tf.reshape(matperm, [-1,
346 |                                  tf.shape(matperm)[1], tf.shape(matperm)[1]])
347 |   batch_size = tf.shape(matperm)[0]
348 |   n_objects = tf.shape(matperm)[1]
349 | 
350 |   return tf.reshape(tf.argmax(matperm, axis=2, output_type=dtype),
351 |                     [batch_size, n_objects])
352 | 
353 | 
354 | def invert_listperm(listperm):
355 |   """Inverts a batch of permutations.
356 | 
357 |   Args:
358 |     listperm: a 2D integer tensor of permutations listperm of
359 |       shape = [batch_size, n_objects] so that listperm[n] is a permutation of
360 |       range(n_objects)
361 |   Returns:
362 |     A 2D integer tensor of inverse permutations, so that if
363 |     listperm[n, i] = j then the output satisfies output[n, j] = i.
364 |   """
365 |   return matperm2listperm(tf.transpose(listperm2matperm(listperm), [0, 2, 1]))
366 | 
--------------------------------------------------------------------------------
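
A minimal usage sketch (not one of the repository files above), assuming the
TF 1.x APIs used throughout this codebase; the names scores, soft_perms and
hard_perms are illustrative only:

  import tensorflow as tf

  import sinkhorn_ops

  # A batch of two random 4x4 matrices parameterizing permutations.
  scores = tf.random_normal([2, 4, 4])
  # Soft, close-to-doubly-stochastic approximations of permutation matrices;
  # with n_samples=1 and squeeze=True the result has shape [2, 4, 4].
  soft_perms, _ = sinkhorn_ops.gumbel_sinkhorn(
      scores, temp=0.5, n_samples=1, noise_factor=1.0, n_iters=20,
      squeeze=True)
  # Round to hard permutations (in list form) with the Hungarian algorithm.
  hard_perms = sinkhorn_ops.matching(soft_perms)

  with tf.Session() as sess:
    print(sess.run([soft_perms, hard_perms]))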