├── .gitignore ├── LICENSE ├── README.md ├── images ├── Pytorch_CPU_softsort.png ├── Pytorch_GPU_softsort.png ├── TensorFlow_CPU_softsort.png ├── TensorFlow_GPU_softsort.png ├── laplace_and_gaussian.py ├── laplace_and_gaussian_softsort.png ├── run_median_learning_curve.png ├── run_sort_learning_curve.png └── synthetic_experiment_learning_curves.png ├── pytorch ├── .gitignore ├── checkpoint │ └── .gitignore ├── dknn_layer.py ├── experiments │ └── .gitignore ├── logs │ └── .gitignore ├── models │ ├── easy_net.py │ └── preact_resnet.py ├── neuralsort.py ├── neuralsort_cpu_or_gpu.py ├── pl.py ├── run_baseline.py ├── run_dknn.py ├── run_dknn.sh ├── run_dknn_table_of_results.py ├── softsort.py ├── synthetic_experiment_learning_curves.py ├── synthetic_experiment_speed_comparison.py └── utils.py ├── requirements.txt ├── synthetic_experiment_learning_curves.sh ├── synthetic_experiment_learning_curves_plot.py ├── synthetic_experiment_speed_comparison.sh ├── synthetic_experiment_speed_comparison_plot.py └── tf ├── .gitignore ├── checkpoints └── .gitignore ├── logs └── .gitignore ├── mnist_input.py ├── multi_mnist_cnn.py ├── predictions └── .gitignore ├── run_median.py ├── run_median.sh ├── run_median_learning_curves.py ├── run_median_table_of_results.py ├── run_sort.py ├── run_sort.sh ├── run_sort_learning_curves.py ├── run_sort_table_of_results.py ├── sinkhorn.py ├── synthetic_experiment_learning_curves.py ├── synthetic_experiment_speed_comparison.py └── util.py /.gitignore: -------------------------------------------------------------------------------- 1 | **/__pycache__/ 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Sebastián Prillo and Julián Eisenschlos 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Code for "SoftSort: A Continuous Relaxation for the argsort Operator", ICML 2020. 2 | --- 3 | 4 | This repository is a fork of ```ermongroup/neuralsort``` implementing the SoftSort operator and reproducing all the results reported in the paper "SoftSort: A Continuous Relaxation for the argsort Operator". 5 | 6 | ## Requirements 7 | 8 | The codebase is implemented in Python 3.7. 
To install the necessary requirements, run the following command: 9 | 10 | ``` 11 | pip3 install -r requirements.txt 12 | ``` 13 | 14 | ## Sorting Handwritten Numbers Experiment 15 | 16 | To reproduce the results in Table 1, just run: 17 | 18 | ``` 19 | cd tf 20 | bash run_sort.sh 21 | python3 run_sort_table_of_results.py 22 | ``` 23 | 24 | The first script (bash) will train all models. This takes a long time. You can inspect this script to see what parameters were used to train each model (which are the ones reported in the paper). The second script (python) will process the results from the models and print Table 1. 25 | 26 | To train a single model directly, you can use the `tf/run_sort.py` script, with the following arguments: 27 | 28 | ``` 29 | --M INT Minibatch size 30 | --n INT Number of elements to compare at a time 31 | --l INT Number of digits in each multi-mnist dataset element 32 | --tau FLOAT Temperature (either of sinkhorn or neuralsort relaxation) 33 | --method STRING One of 'deterministic_neuralsort', 'stochastic_neuralsort', 'deterministic_softsort', 'stochastic_softsort' 34 | --n_s INT Number of samples for stochastic methods 35 | --num_epochs INT Number of epochs to train 36 | --lr FLOAT Initial learning rate 37 | ``` 38 | 39 | ## Quantile Regression Experiment 40 | 41 | To reproduce the results in Table 2, just run: 42 | 43 | ``` 44 | cd tf 45 | bash run_median.sh 46 | python3 run_median_table_of_results.py 47 | ``` 48 | 49 | The first script (bash) will train all models. This takes a long time. You can inspect this script to see what parameters were used to train each model (which are the ones reported in the paper). The second script (python) will process the results from the models and print Table 2. 50 | 51 | To train a single model directly, you can use the `tf/run_median.py` script, with the following arguments: 52 | 53 | ``` 54 | --M INT Minibatch size 55 | --n INT Number of elements to compare at a time 56 | --l INT Number of digits in each multi-mnist dataset element 57 | --tau FLOAT Temperature (either of sinkhorn or neuralsort relaxation) 58 | --method STRING One of 'deterministic_neuralsort', 'stochastic_neuralsort', 'deterministic_softsort', 'stochastic_softsort' 59 | --n_s INT Number of samples for stochastic methods 60 | --num_epochs INT Number of epochs to train 61 | --lr FLOAT Initial learning rate 62 | ``` 63 | 64 | ## Differentiable kNN Experiment 65 | 66 | To reproduce the results in Table 3, run: 67 | 68 | ``` 69 | cd pytorch 70 | bash run_dknn.sh 71 | python3 run_dknn_table_of_results.py 72 | ``` 73 | 74 | The first script (bash) will train all the models. This takes about two days to sequentially test the different hyperparameter configurations. The second script (python) iterates through the logs and prints the best results. 75 | 76 | To train a single model directly, you can use the `pytorch/run_dknn.py` script, with the following arguments: 77 | 78 | ``` 79 | --simple Whether to use our softsort, or the baseline neuralsort 80 | --k INT Number of nearest neighbors 81 | --tau FLOAT Temperature of sorting operator 82 | --nloglr FLOAT Negative log10 of learning rate 83 | --method STRING One of 'deterministic', 'stochastic' 84 | --dataset STRING One of 'mnist', 'fashion-mnist', 'cifar10' 85 | --num_train_queries INT Number of queries to evaluate during training. 86 | --num_train_neighbors INT Number of neighbors to consider during training.
87 | --num_samples INT Number of samples for stochastic methods 88 | --num_epochs INT Number of epochs to train 89 | ``` 90 | 91 | ## Speed Comparison Experiment 92 | 93 | To reproduce the results in Figure 6, just run: 94 | 95 | ``` 96 | bash synthetic_experiment_speed_comparison.sh 97 | python3 synthetic_experiment_speed_comparison_plot.py 98 | ``` 99 | 100 | The first script (bash) will train all models. This takes some time. The second script (python) will process the results and print the graphs in Figure 6 under the ```images/``` directory. 101 | 102 | ## Learning Curves 103 | 104 | For the synthetic experiment learning curves, run: 105 | 106 | ``` 107 | bash synthetic_experiment_learning_curves.sh 108 | ``` 109 | 110 | Then, to generate the plot in Figure 8, run: 111 | 112 | ``` 113 | python3 synthetic_experiment_learning_curves_plot.py 114 | ``` 115 | 116 | To generate the run_sort and run_median learning curve plots (Figure 7), run: 117 | 118 | ``` 119 | cd tf 120 | python3 run_sort_learning_curves.py 121 | python3 run_median_learning_curves.py 122 | ``` 123 | -------------------------------------------------------------------------------- /images/Pytorch_CPU_softsort.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sprillo/softsort/9590d8e8b0bbf4573aa02b7b19d9f34ed87c3f30/images/Pytorch_CPU_softsort.png -------------------------------------------------------------------------------- /images/Pytorch_GPU_softsort.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sprillo/softsort/9590d8e8b0bbf4573aa02b7b19d9f34ed87c3f30/images/Pytorch_GPU_softsort.png -------------------------------------------------------------------------------- /images/TensorFlow_CPU_softsort.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sprillo/softsort/9590d8e8b0bbf4573aa02b7b19d9f34ed87c3f30/images/TensorFlow_CPU_softsort.png -------------------------------------------------------------------------------- /images/TensorFlow_GPU_softsort.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sprillo/softsort/9590d8e8b0bbf4573aa02b7b19d9f34ed87c3f30/images/TensorFlow_GPU_softsort.png -------------------------------------------------------------------------------- /images/laplace_and_gaussian.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | from scipy.stats import laplace, norm 4 | 5 | 6 | def plot( 7 | x_min, 8 | x_max, 9 | s_xs, 10 | s_ticks, 11 | distribution, 12 | legend, 13 | ylabel): 14 | fontsize = 16 15 | x = np.arange(start=x_min, stop=x_max, step=0.01) 16 | y = distribution.pdf(x) 17 | plt.plot(x, y) 18 | plt.ylabel(ylabel, fontsize=fontsize) 19 | plt.yticks([], []) 20 | plt.xticks(s_xs, s_ticks, fontsize=fontsize) 21 | plt.fill_between(x, y, facecolor='blue', alpha=0.2) 22 | for s_x in s_xs: 23 | plt.vlines(x=s_x, ymin=0, ymax=distribution.pdf(s_x)) 24 | plt.legend([legend], fontsize=fontsize) 25 | plt.tight_layout() 26 | 27 | 28 | plt.figure(figsize=(13, 5)) 29 | 30 | plt.subplot(1, 2, 1) 31 | plot(x_min=-4, 32 | x_max=4, 33 | s_xs=[-2, 0, 1, 2.5], 34 | s_ticks=[r'$s_4$', '$s_3$', '$s_2$', '$s_1$'], 35 | distribution=laplace, 36 | legend=r'$\propto\phi_{Laplace(s_3, \tau)}$', 37 | ylabel=r'$SoftSort^{|\cdot|}_\tau(s)[3, :]$') 38 
| 39 | plt.subplot(1, 2, 2) 40 | plot(x_min=-4, 41 | x_max=4, 42 | s_xs=[-1.2, 0, 1.2, 2.4], 43 | s_ticks=[r'$s_4$', '$s_3$', '$s_2$', '$s_1$'], 44 | distribution=norm, 45 | legend=r'$\propto\phi_{\mathcal{N}(s_3, a\tau)}$', 46 | ylabel=r'$NeuralSort_\tau(s)[3, :]$') 47 | 48 | plt.savefig('laplace_and_gaussian_softsort') 49 | -------------------------------------------------------------------------------- /images/laplace_and_gaussian_softsort.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sprillo/softsort/9590d8e8b0bbf4573aa02b7b19d9f34ed87c3f30/images/laplace_and_gaussian_softsort.png -------------------------------------------------------------------------------- /images/run_median_learning_curve.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sprillo/softsort/9590d8e8b0bbf4573aa02b7b19d9f34ed87c3f30/images/run_median_learning_curve.png -------------------------------------------------------------------------------- /images/run_sort_learning_curve.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sprillo/softsort/9590d8e8b0bbf4573aa02b7b19d9f34ed87c3f30/images/run_sort_learning_curve.png -------------------------------------------------------------------------------- /images/synthetic_experiment_learning_curves.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sprillo/softsort/9590d8e8b0bbf4573aa02b7b19d9f34ed87c3f30/images/synthetic_experiment_learning_curves.png -------------------------------------------------------------------------------- /pytorch/.gitignore: -------------------------------------------------------------------------------- 1 | /data* 2 | -------------------------------------------------------------------------------- /pytorch/checkpoint/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | */ 3 | !.gitignore 4 | -------------------------------------------------------------------------------- /pytorch/dknn_layer.py: -------------------------------------------------------------------------------- 1 | ''' 2 | "Differentiable k-nearest neighbors" layer. 3 | 4 | Given a set of M queries and a set of N neighbors, 5 | returns an M x N matrix whose rows sum to k, indicating to what degree 6 | a certain neighbor is one of the k nearest neighbors to the query. 7 | At the limit of tau = 0, each entry is a binary value representing 8 | whether each neighbor is actually one of the k closest to each query. 
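Internally, the distances from each query to the neighbors are negated and passed through a relaxed sort operator (SoftSort or NeuralSort, or a relaxed Plackett-Luce sample in the stochastic case), and the first k rows of the resulting relaxed permutation matrix are summed to obtain the differentiable top-k indicator.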
9 | ''' 10 | 11 | import torch 12 | from pl import PL 13 | from neuralsort import NeuralSort 14 | from softsort import SoftSort 15 | 16 | 17 | class DKNN (torch.nn.Module): 18 | def __init__(self, k, tau=1.0, hard=False, method='deterministic', num_samples=-1, simple=False): 19 | super(DKNN, self).__init__() 20 | self.k = k 21 | self.soft_sort = SoftSort(tau=tau, hard=hard) if simple else NeuralSort(tau=tau, hard=hard) 22 | self.method = method 23 | self.num_samples = num_samples 24 | 25 | # query: M x p 26 | # neighbors: N x p 27 | # 28 | # returns: 29 | def forward(self, query, neighbors, tau=1.0): 30 | diffs = (query.unsqueeze(1) - neighbors.unsqueeze(0)) 31 | squared_diffs = diffs ** 2 32 | l2_norms = squared_diffs.sum(2) 33 | norms = l2_norms 34 | scores = -norms 35 | 36 | if self.method == 'deterministic': 37 | P_hat = self.soft_sort(scores) 38 | top_k = P_hat[:, :self.k, :].sum(1) 39 | return top_k 40 | if self.method == 'stochastic': 41 | pl_s = PL(scores, tau, hard=False) 42 | P_hat = pl_s.sample((self.num_samples,)) 43 | top_k = P_hat[:, :, :self.k, :].sum(2) 44 | return top_k 45 | -------------------------------------------------------------------------------- /pytorch/experiments/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore -------------------------------------------------------------------------------- /pytorch/logs/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore -------------------------------------------------------------------------------- /pytorch/models/easy_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class ConvNet(nn.Module): 7 | def __init__(self): 8 | super(ConvNet, self).__init__() 9 | self.conv1 = nn.Conv2d(1, 20, 5, 1) 10 | self.conv2 = nn.Conv2d(20, 50, 5, 1) 11 | self.fc1 = nn.Linear(4 * 4 * 50, 500) 12 | 13 | def forward(self, x): 14 | x = F.relu(self.conv1(x)) 15 | x = F.max_pool2d(x, 2, 2) 16 | x = F.relu(self.conv2(x)) 17 | x = F.max_pool2d(x, 2, 2) 18 | x = x.view(-1, 4 * 4 * 50) 19 | x = F.relu(self.fc1(x)) 20 | return x 21 | 22 | 23 | cn = ConvNet() 24 | cn(torch.zeros(1, 1, 28, 28)) 25 | -------------------------------------------------------------------------------- /pytorch/models/preact_resnet.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Code is due to Kuang Liu (https://github.com/kuangliu). 3 | 4 | Pre-activation ResNet in PyTorch. 5 | Reference: 6 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 7 | Identity Mappings in Deep Residual Networks. 
arXiv:1603.05027 8 | ''' 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | 14 | class PreActBlock(nn.Module): 15 | '''Pre-activation version of the BasicBlock.''' 16 | expansion = 1 17 | 18 | def __init__(self, in_planes, planes, stride=1): 19 | super(PreActBlock, self).__init__() 20 | self.bn1 = nn.BatchNorm2d(in_planes) 21 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 24 | 25 | if stride != 1 or in_planes != self.expansion * planes: 26 | self.shortcut = nn.Sequential( 27 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False) 28 | ) 29 | 30 | def forward(self, x): 31 | out = F.relu(self.bn1(x)) 32 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 33 | out = self.conv1(out) 34 | out = self.conv2(F.relu(self.bn2(out))) 35 | out += shortcut 36 | return out 37 | 38 | 39 | class PreActBottleneck(nn.Module): 40 | '''Pre-activation version of the original Bottleneck module.''' 41 | expansion = 4 42 | 43 | def __init__(self, in_planes, planes, stride=1): 44 | super(PreActBottleneck, self).__init__() 45 | self.bn1 = nn.BatchNorm2d(in_planes) 46 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 47 | self.bn2 = nn.BatchNorm2d(planes) 48 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 49 | self.bn3 = nn.BatchNorm2d(planes) 50 | self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False) 51 | 52 | if stride != 1 or in_planes != self.expansion * planes: 53 | self.shortcut = nn.Sequential( 54 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False) 55 | ) 56 | 57 | def forward(self, x): 58 | out = F.relu(self.bn1(x)) 59 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 60 | out = self.conv1(out) 61 | out = self.conv2(F.relu(self.bn2(out))) 62 | out = self.conv3(F.relu(self.bn3(out))) 63 | out += shortcut 64 | return out 65 | 66 | 67 | class PreActResNet(nn.Module): 68 | def __init__(self, block, num_blocks, num_classes=10, num_channels=3): 69 | super(PreActResNet, self).__init__() 70 | self.in_planes = 64 71 | 72 | self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=3, stride=1, padding=1, bias=False) 73 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 74 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 75 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 76 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 77 | self.linear = nn.Linear(512 * block.expansion, num_classes) 78 | 79 | def _make_layer(self, block, planes, num_blocks, stride): 80 | strides = [stride] + [1] * (num_blocks - 1) 81 | layers = [] 82 | for stride in strides: 83 | layers.append(block(self.in_planes, planes, stride)) 84 | self.in_planes = planes * block.expansion 85 | return nn.Sequential(*layers) 86 | 87 | def forward(self, x): 88 | out = self.conv1(x) 89 | out = self.layer1(out) 90 | out = self.layer2(out) 91 | out = self.layer3(out) 92 | out = self.layer4(out) 93 | out = F.avg_pool2d(out, 4) 94 | out = out.view(out.size(0), -1) 95 | # removed this so we have a kNN space. 
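# (i.e., the commented-out classification head below is skipped so the forward pass returns the penultimate embedding, which serves as the feature space for the differentiable kNN experiments)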
96 | # out = self.linear(out) 97 | return out 98 | 99 | 100 | def PreActResNet18(num_channels=3): 101 | return PreActResNet(PreActBlock, [2, 2, 2, 2], num_channels=num_channels) 102 | 103 | 104 | def PreActResNet34(): 105 | return PreActResNet(PreActBlock, [3, 4, 6, 3]) 106 | 107 | 108 | def PreActResNet50(): 109 | return PreActResNet(PreActBottleneck, [3, 4, 6, 3]) 110 | 111 | 112 | def PreActResNet101(): 113 | return PreActResNet(PreActBottleneck, [3, 4, 23, 3]) 114 | 115 | 116 | def PreActResNet152(): 117 | return PreActResNet(PreActBottleneck, [3, 8, 36, 3]) 118 | 119 | 120 | def test(): 121 | net = PreActResNet18() 122 | y = net((torch.randn(1, 3, 32, 32))) 123 | print(y.size()) 124 | 125 | 126 | # test() 127 | 128 | ''' 129 | 130 | MIT License 131 | 132 | Copyright (c) 2017 liukuang 133 | 134 | Permission is hereby granted, free of charge, to any person obtaining a copy 135 | of this software and associated documentation files (the "Software"), to deal 136 | in the Software without restriction, including without limitation the rights 137 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 138 | copies of the Software, and to permit persons to whom the Software is 139 | furnished to do so, subject to the following conditions: 140 | 141 | The above copyright notice and this permission notice shall be included in all 142 | copies or substantial portions of the Software. 143 | 144 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 145 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 146 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 147 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 148 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 149 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 150 | SOFTWARE. 151 | 152 | ''' 153 | -------------------------------------------------------------------------------- /pytorch/neuralsort.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import torch 4 | from torch import Tensor 5 | 6 | 7 | class NeuralSort(torch.nn.Module): 8 | def __init__(self, tau=1.0, hard=False): 9 | super(NeuralSort, self).__init__() 10 | self.hard = hard 11 | self.tau = tau 12 | 13 | def forward(self, scores: Tensor): 14 | """ 15 | scores: elements to be sorted. Typical shape: batch_size x n 16 | """ 17 | scores = scores.unsqueeze(-1) 18 | bsize = scores.size()[0] 19 | dim = scores.size()[1] 20 | one = torch.cuda.FloatTensor(dim, 1).fill_(1) 21 | 22 | A_scores = torch.abs(scores - scores.permute(0, 2, 1)) 23 | # B = torch.matmul(A_scores, torch.matmul( 24 | # one, torch.transpose(one, 0, 1))) # => NeuralSort O(n^3) BUG! 
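# Re-associating the product as (A_scores @ one) @ one^T below replaces the n x n by n x n multiplication above with a matrix-vector product followed by an outer product, cutting the cost from O(n^3) to O(n^2) per batch element.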
25 | B = torch.matmul(torch.matmul(A_scores, 26 | one), torch.transpose(one, 0, 1)) # => Bugfix 27 | scaling = (dim + 1 - 2 * (torch.arange(dim) + 1) 28 | ).type(torch.cuda.FloatTensor) 29 | C = torch.matmul(scores, scaling.unsqueeze(0)) 30 | 31 | P_max = (C - B).permute(0, 2, 1) 32 | sm = torch.nn.Softmax(-1) 33 | P_hat = sm(P_max / self.tau) 34 | 35 | if self.hard: 36 | P = torch.zeros_like(P_hat, device='cuda') 37 | b_idx = torch.arange(bsize).repeat([1, dim]).view(dim, bsize).transpose( 38 | dim0=1, dim1=0).flatten().type(torch.cuda.LongTensor) 39 | r_idx = torch.arange(dim).repeat( 40 | [bsize, 1]).flatten().type(torch.cuda.LongTensor) 41 | c_idx = torch.argmax(P_hat, dim=-1).flatten() # this is on cuda 42 | brc_idx = torch.stack((b_idx, r_idx, c_idx)) 43 | 44 | P[brc_idx[0], brc_idx[1], brc_idx[2]] = 1 45 | P_hat = (P - P_hat).detach() + P_hat 46 | return P_hat 47 | -------------------------------------------------------------------------------- /pytorch/neuralsort_cpu_or_gpu.py: -------------------------------------------------------------------------------- 1 | r''' 2 | This is the same as neuralsort.py, but instead of being hardcoded into GPU it 3 | allows using either GPU or CPU. (Compare 'forward' method to the on in neuralsort.py) 4 | ''' 5 | 6 | import torch 7 | from torch import Tensor 8 | 9 | 10 | class NeuralSort(torch.nn.Module): 11 | def __init__(self, tau=1.0, hard=False, device='cuda'): 12 | super(NeuralSort, self).__init__() 13 | self.hard = hard 14 | self.tau = tau 15 | self.device = device 16 | if device == 'cuda': 17 | self.torch = torch.cuda 18 | elif device == 'cpu': 19 | self.torch = torch 20 | else: 21 | raise ValueError('Unknown device: %s' % device) 22 | 23 | def forward(self, scores: Tensor): 24 | """ 25 | scores: elements to be sorted. Typical shape: batch_size x n 26 | """ 27 | scores = scores.unsqueeze(-1) 28 | bsize = scores.size()[0] 29 | dim = scores.size()[1] 30 | one = self.torch.FloatTensor(dim, 1).fill_(1) 31 | 32 | A_scores = torch.abs(scores - scores.permute(0, 2, 1)) 33 | # B = torch.matmul(A_scores, torch.matmul( 34 | # one, torch.transpose(one, 0, 1))) # => NeuralSort O(n^3) BUG! 35 | B = torch.matmul(torch.matmul(A_scores, 36 | one), torch.transpose(one, 0, 1)) # => Bugfix 37 | scaling = (dim + 1 - 2 * (torch.arange(dim) + 1) 38 | ).type(self.torch.FloatTensor) 39 | C = torch.matmul(scores, scaling.unsqueeze(0)) 40 | 41 | P_max = (C - B).permute(0, 2, 1) 42 | sm = torch.nn.Softmax(-1) 43 | P_hat = sm(P_max / self.tau) 44 | 45 | if self.hard: 46 | P = torch.zeros_like(P_hat, device=self.device) 47 | b_idx = torch.arange(bsize).repeat([1, dim]).view(dim, bsize).transpose( 48 | dim0=1, dim1=0).flatten().type(self.torch.LongTensor) 49 | r_idx = torch.arange(dim).repeat( 50 | [bsize, 1]).flatten().type(self.torch.LongTensor) 51 | c_idx = torch.argmax(P_hat, dim=-1).flatten() # this is on cuda 52 | brc_idx = torch.stack((b_idx, r_idx, c_idx)) 53 | 54 | P[brc_idx[0], brc_idx[1], brc_idx[2]] = 1 55 | P_hat = (P - P_hat).detach() + P_hat 56 | return P_hat 57 | -------------------------------------------------------------------------------- /pytorch/pl.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Relaxed Plackett-Luce distribution. 
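Samples are drawn by perturbing the (log) scores with Gumbel noise and passing the perturbed scores through the same relaxed sort construction used by NeuralSort; when hard=True a straight-through estimator is applied.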
3 | ''' 4 | 5 | from numbers import Number 6 | 7 | import torch 8 | from torch.distributions.distribution import Distribution 9 | from torch.distributions import constraints 10 | 11 | # use GPU if available 12 | USE_CUDA = torch.cuda.is_available() 13 | FloatTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor 14 | LongTensor = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor 15 | ByteTensor = torch.cuda.ByteTensor if USE_CUDA else torch.ByteTensor 16 | 17 | 18 | class PL(Distribution): 19 | 20 | arg_constraints = {'scores': constraints.positive, 21 | 'tau': constraints.positive} 22 | has_rsample = True 23 | 24 | @property 25 | def mean(self): 26 | # mode of the PL distribution 27 | return self.relaxed_sort(self.scores) 28 | 29 | def __init__(self, scores, tau, hard=True, validate_args=None): 30 | """ 31 | scores. Shape: (batch_size x) n 32 | tau: temperature for the relaxation. Scalar. 33 | hard: use straight-through estimation if True 34 | """ 35 | self.scores = scores.unsqueeze(-1) 36 | self.tau = tau 37 | self.hard = hard 38 | self.n = self.scores.size()[1] 39 | 40 | if isinstance(scores, Number): 41 | batch_shape = torch.Size() 42 | else: 43 | batch_shape = self.scores.size() 44 | super(PL, self).__init__(batch_shape, validate_args=validate_args) 45 | 46 | if self._validate_args: 47 | if not torch.gt(self.scores, torch.zeros_like(self.scores)).all(): 48 | raise ValueError("PL is not defined when scores <= 0") 49 | 50 | def relaxed_sort(self, inp): 51 | """ 52 | inp: elements to be sorted. Typical shape: batch_size x n x 1 53 | """ 54 | bsize = inp.size()[0] 55 | dim = inp.size()[1] 56 | one = FloatTensor(dim, 1).fill_(1) 57 | 58 | A_inp = torch.abs(inp - inp.permute(0, 2, 1)) 59 | B = torch.matmul(A_inp, torch.matmul(one, torch.transpose(one, 0, 1))) 60 | scaling = (dim + 1 - 2 * (torch.arange(dim) + 1)).type(FloatTensor) 61 | C = torch.matmul(inp, scaling.unsqueeze(0)) 62 | 63 | P_max = (C - B).permute(0, 2, 1) 64 | sm = torch.nn.Softmax(-1) 65 | P_hat = sm(P_max / self.tau) 66 | 67 | if self.hard: 68 | P = torch.zeros_like(P_hat) 69 | b_idx = torch.arange(bsize).repeat([1, dim]).view( 70 | dim, bsize).transpose(dim0=1, dim1=0).flatten().type(LongTensor) 71 | r_idx = torch.arange(dim).repeat( 72 | [bsize, 1]).flatten().type(LongTensor) 73 | c_idx = torch.argmax(P_hat, dim=-1).flatten() # this is on cuda 74 | brc_idx = torch.stack((b_idx, r_idx, c_idx)) 75 | 76 | P[brc_idx[0], brc_idx[1], brc_idx[2]] = 1 77 | P_hat = (P - P_hat).detach() + P_hat 78 | return P_hat 79 | 80 | def rsample(self, sample_shape, log_score=True): 81 | """ 82 | sample_shape: number of samples from the PL distribution. Scalar. 83 | """ 84 | with torch.enable_grad(): # torch.distributions turns off autograd 85 | n_samples = sample_shape[0] 86 | 87 | def sample_gumbel(samples_shape, eps=1e-20): 88 | U = torch.zeros(samples_shape, device='cuda').uniform_() 89 | return -torch.log(-torch.log(U + eps) + eps) 90 | if not log_score: 91 | log_s_perturb = torch.log(self.scores.unsqueeze( 92 | 0)) + sample_gumbel([n_samples, 1, self.n, 1]) 93 | else: 94 | log_s_perturb = self.scores.unsqueeze( 95 | 0) + sample_gumbel([n_samples, 1, self.n, 1]) 96 | log_s_perturb = log_s_perturb.view(-1, self.n, 1) 97 | P_hat = self.relaxed_sort(log_s_perturb) 98 | P_hat = P_hat.view(n_samples, -1, self.n, self.n) 99 | 100 | return P_hat.squeeze() 101 | 102 | def log_prob(self, value): 103 | """ 104 | value: permutation matrix. 
shape: batch_size x n x n 105 | """ 106 | permuted_scores = torch.squeeze(torch.matmul(value, self.scores)) 107 | log_numerator = torch.sum(torch.log(permuted_scores), dim=-1) 108 | idx = LongTensor([i for i in range(self.n - 1, -1, -1)]) 109 | invert_permuted_scores = permuted_scores.index_select(-1, idx) 110 | denominators = torch.cumsum(invert_permuted_scores, dim=-1) 111 | log_denominator = torch.sum(torch.log(denominators), dim=-1) 112 | return (log_numerator - log_denominator) 113 | 114 | 115 | if __name__ == '__main__': 116 | 117 | scores = torch.Tensor([[100.8, 0.3, 11111.9]]).unsqueeze(-1) 118 | tau = 0.1 119 | 120 | # hard = True is necessary 121 | pl_dist = PL(scores, tau, hard=True) 122 | 123 | # check helper sorting function 124 | sorted_scores = pl_dist.relaxed_sort(scores) 125 | print(sorted_scores) 126 | 127 | # check if we get mode of distribution 128 | print(pl_dist.mean) 129 | 130 | # check log prob function 131 | good_pm = torch.Tensor([[[0., 0., 1.], 132 | [1., 0., 0.], 133 | [0., 1., 0.]]]) 134 | intermediate_pm = torch.Tensor([[[0., 0., 1.], 135 | [0., 1., 0.], 136 | [1., 0., 0.]]]) 137 | bad_pm = torch.Tensor([[[0., 1., 0.], 138 | [1., 0., 0.], 139 | [0., 0., 1.]]]) 140 | print(pl_dist.log_prob(good_pm), pl_dist.log_prob( 141 | intermediate_pm), pl_dist.log_prob(bad_pm)) 142 | print() 143 | 144 | # check sample 145 | scores_bimodal = torch.Tensor([[11111.92, 0.3, 11111.9]]).unsqueeze(-1) 146 | pl_dist_bimodal = PL(scores_bimodal, tau, hard=True) 147 | samples = pl_dist_bimodal.sample((5,)) 148 | print(samples) 149 | print() 150 | 151 | # code for kl(q, p) 152 | scores_prior = torch.Tensor([[0.3, 10.8, 1111.9]]).unsqueeze(-1) 153 | tau_prior = 0.1 154 | 155 | pl_dist_prior = PL(scores_prior, tau_prior, hard=True) 156 | print(pl_dist_prior.mean) 157 | print(pl_dist_prior.log_prob(good_pm), pl_dist_prior.log_prob( 158 | intermediate_pm), pl_dist_prior.log_prob(bad_pm)) 159 | 160 | # kl (q, p) 161 | empirical_kl = pl_dist.log_prob(good_pm) - pl_dist_prior.log_prob(good_pm) 162 | print(empirical_kl) 163 | -------------------------------------------------------------------------------- /pytorch/run_baseline.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | import os 4 | import random 5 | 6 | import numpy as np 7 | 8 | from models.preact_resnet import PreActResNet18 9 | from models.easy_net import ConvNet 10 | from dataset import DataSplit 11 | 12 | torch.manual_seed(94305) 13 | torch.cuda.manual_seed(94305) 14 | np.random.seed(94305) 15 | random.seed(94305) 16 | 17 | parser = argparse.ArgumentParser( 18 | description="Differentiable k-nearest neighbors.") 19 | parser.add_argument("--k", type=int, metavar="k") 20 | parser.add_argument("--tau", type=float, metavar="tau") 21 | parser.add_argument("--nloglr", type=float, metavar="-log10(beta)") 22 | parser.add_argument("--method", type=str) 23 | parser.add_argument("-resume", action='store_true') 24 | parser.add_argument("--dataset", type=str) 25 | 26 | args = parser.parse_args() 27 | dataset = args.dataset 28 | split = DataSplit(dataset) 29 | print(args) 30 | 31 | k = args.k 32 | tau = args.tau 33 | NUM_TRAIN_QUERIES = 100 34 | NUM_TEST_QUERIES = 10 35 | NUM_TRAIN_NEIGHBORS = 100 36 | LEARNING_RATE = 10 ** -args.nloglr 37 | NUM_SAMPLES = 5 38 | resume = args.resume 39 | method = args.method 40 | 41 | NUM_EPOCHS = 150 if dataset == 'cifar10' else 50 42 | EMBEDDING_SIZE = 500 if dataset == 'mnist' else 512 43 | 44 | 45 | def experiment_id(dataset, k, 
tau, nloglr, method): 46 | return 'baseline-resnet-%s-%s-k%d-t%d-b%d' % (dataset, method, k, tau, nloglr) 47 | 48 | 49 | e_id = experiment_id(dataset, k, tau * 10, args.nloglr, method) 50 | 51 | 52 | gpu = torch.device('cuda') 53 | 54 | if dataset == 'mnist': 55 | h_phi = ConvNet().to(gpu) 56 | else: 57 | h_phi = PreActResNet18(num_channels=3 if dataset == 'cifar10' else 1).to(gpu) 58 | 59 | optimizer = torch.optim.SGD( 60 | h_phi.parameters(), lr=LEARNING_RATE, momentum=0.9, weight_decay=5e-4) 61 | 62 | linear_layer = torch.nn.Linear(EMBEDDING_SIZE, 10).to(device=gpu) 63 | ce_loss = torch.nn.CrossEntropyLoss() 64 | 65 | batched_train = split.get_train_loader(NUM_TRAIN_QUERIES) 66 | 67 | 68 | def train(epoch): 69 | h_phi.train() 70 | to_average = [] 71 | # train 72 | for x, y in batched_train: 73 | optimizer.zero_grad() 74 | x = x.to(device=gpu) 75 | y = y.to(device=gpu) 76 | logits = linear_layer(h_phi(x)) 77 | loss = ce_loss(logits, y) 78 | loss = loss.mean() 79 | loss.backward() 80 | optimizer.step() 81 | to_average.append((-loss).item()) 82 | print('train', sum(to_average) / len(to_average)) 83 | 84 | 85 | logfile = open('./logs/%s.log' % e_id, 'a' if resume else 'w') 86 | 87 | batched_val = split.get_valid_loader(NUM_TEST_QUERIES) 88 | batched_test = split.get_test_loader(NUM_TEST_QUERIES) 89 | 90 | best_acc = 0 91 | 92 | 93 | def test(epoch, val=False): 94 | h_phi.eval() 95 | global best_acc 96 | data = batched_val if val else batched_test 97 | 98 | accs = [] 99 | 100 | for x, y in data: 101 | x = x.to(device=gpu) 102 | y = y.to(device=gpu) 103 | logits = linear_layer(h_phi(x)) 104 | pred = logits.argmax(dim=-1) 105 | acc = (pred == y).float().mean() 106 | accs.append(acc.item()) 107 | avg_acc = sum(accs) / len(accs) 108 | print('val' if val else 'test', avg_acc) 109 | if avg_acc > best_acc and val: 110 | print('Saving...') 111 | state = { 112 | 'net': h_phi.state_dict(), 113 | 'acc': avg_acc, 114 | 'epoch': epoch, 115 | } 116 | if not os.path.isdir('checkpoint'): 117 | os.mkdir('checkpoint') 118 | torch.save(state, './checkpoint/ckpt-%s.t7' % e_id) 119 | best_acc = avg_acc 120 | 121 | 122 | for t in range(NUM_EPOCHS): 123 | print('Beginning epoch %d: ' % t, e_id) 124 | print('Beginning epoch %d: ' % t, e_id, file=logfile) 125 | logfile.flush() 126 | train(t) 127 | test(t, val=True) 128 | 129 | 130 | checkpoint = torch.load('./checkpoint/ckpt-%s.t7' % e_id) 131 | h_phi.load_state_dict(checkpoint['net']) 132 | test(-1, val=True) 133 | test(-1, val=False) 134 | logfile.close() 135 | -------------------------------------------------------------------------------- /pytorch/run_dknn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | import os 4 | import random 5 | 6 | import numpy as np 7 | 8 | from utils import one_hot 9 | from models.preact_resnet import PreActResNet18 10 | from models.easy_net import ConvNet 11 | from dataset import DataSplit 12 | from dknn_layer import DKNN 13 | 14 | torch.manual_seed(94305) 15 | torch.cuda.manual_seed(94305) 16 | np.random.seed(94305) 17 | random.seed(94305) 18 | 19 | parser = argparse.ArgumentParser( 20 | description="Differentiable k-nearest neighbors.") 21 | parser.add_argument("--k", type=int, metavar="k", required=True) 22 | parser.add_argument("--tau", type=float, metavar="tau", default=16.) 23 | parser.add_argument("--nloglr", type=float, metavar="-log10(beta)", default=3.) 
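# nloglr is the negative base-10 exponent of the learning rate: the default of 3. corresponds to lr = 1e-3 (see LEARNING_RATE = 10 ** -args.nloglr below).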
24 | parser.add_argument("--method", type=str, default="deterministic") 25 | parser.add_argument("--simple", action='store_true') 26 | parser.add_argument("-resume", action='store_true') 27 | parser.add_argument("--dataset", type=str, required=True) 28 | parser.add_argument("--use_cross_entropy_loss", action='store_true') 29 | 30 | parser.add_argument("--num_train_queries", type=int, default=100) 31 | # no effect on training, but massive effect on memory usage 32 | parser.add_argument("--num_test_queries", type=int, default=10) 33 | parser.add_argument("--num_train_neighbors", type=int, default=100) 34 | parser.add_argument("--num_samples", type=int, default=5) 35 | parser.add_argument("--num_epochs", type=int, default=200) 36 | 37 | args = parser.parse_args() 38 | dataset = args.dataset 39 | split = DataSplit(dataset) 40 | print(args) 41 | 42 | k = args.k 43 | tau = args.tau 44 | NUM_TRAIN_QUERIES = args.num_train_queries 45 | NUM_TEST_QUERIES = args.num_test_queries 46 | NUM_TRAIN_NEIGHBORS = args.num_train_neighbors 47 | LEARNING_RATE = 10 ** -args.nloglr 48 | NUM_SAMPLES = args.num_samples 49 | resume = args.resume 50 | method = args.method 51 | simple = args.simple 52 | use_cross_entropy_loss = args.use_cross_entropy_loss 53 | NUM_EPOCHS = args.num_epochs 54 | EMBEDDING_SIZE = 500 if dataset == 'mnist' else 512 55 | 56 | 57 | def experiment_id(dataset, k, tau, nloglr, method, simple): 58 | return 'dknn-resnet-%s-%s-k%d-t%d-b%d-%s' % (dataset, method, k, tau * 100, nloglr, 59 | 'simple' if simple else 'neural') 60 | 61 | 62 | e_id = experiment_id(dataset, k, tau, args.nloglr, method, simple) 63 | 64 | 65 | dknn_layer = DKNN(k, tau, method=method, num_samples=NUM_SAMPLES, simple=simple) 66 | 67 | 68 | def dknn_loss(query, neighbors, query_label, neighbor_labels, method=method): 69 | # query: batch_size x p 70 | # neighbors: 10k x p 71 | # query_labels: batch_size x [10] one-hot 72 | # neighbor_labels: n x [10] one-hot 73 | if method == 'deterministic': 74 | top_k_ness = dknn_layer(query, neighbors) 75 | correct = (query_label.unsqueeze(1) * neighbor_labels.unsqueeze(0)).sum(-1) 76 | correct_in_top_k = (correct * top_k_ness).sum(-1) 77 | loss = -correct_in_top_k 78 | return loss 79 | elif method == 'stochastic': 80 | top_k_ness = dknn_layer(query, neighbors) 81 | correct = (query_label.unsqueeze(1) * neighbor_labels.unsqueeze(0)).sum(-1) 82 | correct_in_top_k = (correct.unsqueeze(0) * top_k_ness).sum(-1) 83 | loss = -correct_in_top_k 84 | return loss 85 | else: 86 | raise ValueError(method) 87 | 88 | 89 | gpu = torch.device('cuda') 90 | 91 | if dataset == 'mnist': 92 | h_phi = ConvNet().to(gpu) 93 | else: 94 | h_phi = PreActResNet18(num_channels=3 if dataset == 'cifar10' else 1).to(gpu) 95 | 96 | if resume: 97 | # Load checkpoint. 98 | print('==> Resuming from checkpoint..') 99 | assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!' 
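# Restore the embedding network weights, the best validation accuracy, and the epoch to resume from.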
100 | checkpoint = torch.load('./checkpoint/ckpt-%s.t7' % e_id) 101 | h_phi.load_state_dict(checkpoint['net']) 102 | best_acc = checkpoint['acc'] 103 | start_epoch = checkpoint['epoch'] 104 | else: 105 | best_acc = 0 106 | start_epoch = 0 107 | 108 | 109 | optimizer = torch.optim.SGD( 110 | h_phi.parameters(), lr=LEARNING_RATE, momentum=0.9, weight_decay=5e-4) 111 | 112 | unit_test_linear_layer = torch.nn.Linear(EMBEDDING_SIZE, 10).to(device=gpu) 113 | unit_test_ce_loss = torch.nn.CrossEntropyLoss() 114 | 115 | ema_factor = .999 116 | ema_num = 0 117 | 118 | 119 | batched_query_train = split.get_train_loader(NUM_TRAIN_QUERIES, drop_last=True) 120 | batched_neighbor_train = split.get_train_loader(NUM_TRAIN_NEIGHBORS, drop_last=True) 121 | 122 | 123 | def loopy(dl): 124 | while True: 125 | for x in dl: 126 | yield x 127 | 128 | 129 | def train(epoch): 130 | h_phi.train() 131 | to_average = [] 132 | # train 133 | for query, candidates in zip(batched_query_train, loopy(batched_neighbor_train)): 134 | optimizer.zero_grad() 135 | cand_x, cand_y = candidates 136 | query_x, query_y = query 137 | 138 | cand_x = cand_x.to(device=gpu) 139 | cand_y = cand_y.to(device=gpu) 140 | query_x = query_x.to(device=gpu) 141 | query_y = query_y.to(device=gpu) 142 | 143 | neighbor_e = h_phi(cand_x).reshape(NUM_TRAIN_NEIGHBORS, EMBEDDING_SIZE) 144 | query_e = h_phi(query_x).reshape(NUM_TRAIN_QUERIES, EMBEDDING_SIZE) 145 | 146 | neighbor_y_oh = one_hot(cand_y).reshape(NUM_TRAIN_NEIGHBORS, 10) 147 | query_y_oh = one_hot(query_y).reshape(NUM_TRAIN_QUERIES, 10) 148 | 149 | losses = dknn_loss(query_e, neighbor_e, query_y_oh, neighbor_y_oh) 150 | if use_cross_entropy_loss: 151 | losses = losses.neg().log().neg() 152 | loss = losses.mean() 153 | loss.backward() 154 | optimizer.step() 155 | to_average.append((-loss).item() / k) 156 | 157 | print('Avg. train correctness of top k:', 158 | sum(to_average) / len(to_average)) 159 | print('Avg. 
train correctness of top k:', sum( 160 | to_average) / len(to_average), file=logfile) 161 | logfile.flush() 162 | 163 | 164 | def majority(lst): 165 | return max(set(lst), key=lst.count) 166 | 167 | 168 | def new_predict(query, neighbors, neighbor_labels): 169 | ''' 170 | query: p 171 | neighbors: n x p 172 | neighbor_labels: n (int) 173 | ''' 174 | diffs = (query.unsqueeze(1) - neighbors.unsqueeze(0)) 175 | squared_diffs = diffs ** 2 176 | norms = squared_diffs.sum(-1) 177 | indices = torch.argsort(norms, dim=-1) 178 | labels = neighbor_labels.take(indices[:, :k]) 179 | prediction = [majority(l.tolist()) for l in labels] 180 | return torch.Tensor(prediction).to(device=gpu).long() 181 | 182 | 183 | def acc(query, neighbors, query_label, neighbor_labels): 184 | prediction = new_predict(query, neighbors, neighbor_labels) 185 | return (prediction == query_label).float().cpu().numpy() 186 | 187 | 188 | logfile = open('./logs/%s.log' % e_id, 'a' if resume else 'w') 189 | 190 | batched_query_val = split.get_valid_loader(NUM_TEST_QUERIES) 191 | batched_query_test = split.get_test_loader(NUM_TEST_QUERIES) 192 | 193 | 194 | def test(epoch, val=False): 195 | h_phi.eval() 196 | global best_acc 197 | with torch.no_grad(): 198 | embeddings = [] 199 | labels = [] 200 | for neighbor_x, neighbor_y in batched_neighbor_train: 201 | neighbor_x = neighbor_x.to(device=gpu) 202 | neighbor_y = neighbor_y.to(device=gpu) 203 | embeddings.append(h_phi(neighbor_x)) 204 | labels.append(neighbor_y) 205 | neighbors_e = torch.stack(embeddings).reshape(-1, EMBEDDING_SIZE) 206 | labels = torch.stack(labels).reshape(-1) 207 | 208 | results = [] 209 | for queries in batched_query_val if val else batched_query_test: 210 | query_x, query_y = queries 211 | query_x = query_x.to(device=gpu) 212 | query_y = query_y.to(device=gpu) 213 | query_e = h_phi(query_x) 214 | results.append(acc(query_e, neighbors_e, query_y, labels)) 215 | total_acc = np.mean(np.array(results)) 216 | 217 | split = 'val' if val else 'test' 218 | print('Avg. %s acc:' % split, total_acc) 219 | print('Avg. 
%s acc:' % split, total_acc, file=logfile) 220 | if total_acc > best_acc and val: 221 | print('Saving...') 222 | state = { 223 | 'net': h_phi.state_dict(), 224 | 'acc': total_acc, 225 | 'epoch': epoch, 226 | } 227 | if not os.path.isdir('checkpoint'): 228 | os.mkdir('checkpoint') 229 | torch.save(state, './checkpoint/ckpt-%s.t7' % e_id) 230 | best_acc = total_acc 231 | 232 | 233 | for t in range(start_epoch, NUM_EPOCHS): 234 | print('Beginning epoch %d: ' % t, e_id) 235 | print('Beginning epoch %d: ' % t, e_id, file=logfile) 236 | logfile.flush() 237 | train(t) 238 | test(t, val=True) 239 | 240 | 241 | checkpoint = torch.load('./checkpoint/ckpt-%s.t7' % e_id) 242 | h_phi.load_state_dict(checkpoint['net']) 243 | test(-1, val=True) 244 | test(-1, val=False) 245 | logfile.close() 246 | -------------------------------------------------------------------------------- /pytorch/run_dknn.sh: -------------------------------------------------------------------------------- 1 | ROOT_DIR=run_dknn_results 2 | mkdir -p ${ROOT_DIR} 3 | 4 | NUM_EPOCHS=200 5 | for DATASET in mnist cifar10 fashion-mnist 6 | do 7 | for METHOD in stochastic deterministic 8 | do 9 | GRID_SEARCH_RESULTS_DIR="${DATASET}_${METHOD}" 10 | mkdir ${ROOT_DIR}/${GRID_SEARCH_RESULTS_DIR} 11 | for TAU in 1 4 16 64 128 512 12 | do 13 | for k in 1 3 5 9 14 | do 15 | for LR in 3 4 5 16 | do 17 | OUTPUT_FILE="${ROOT_DIR}/${GRID_SEARCH_RESULTS_DIR}/${DATASET}_${METHOD}_TAU_${TAU}_k_${k}_lr_${LR}.txt" 18 | python3 run_dknn.py --k=${k} --tau=${TAU} --nloglr=${LR} --method=${METHOD} --dataset=${DATASET} --num_epochs=${NUM_EPOCHS} --simple 2>&1 | tee ${OUTPUT_FILE} 19 | done 20 | done 21 | done 22 | done 23 | done 24 | -------------------------------------------------------------------------------- /pytorch/run_dknn_table_of_results.py: -------------------------------------------------------------------------------- 1 | # Selects best runs according to validation accuracy 2 | import glob 3 | import os 4 | 5 | 6 | def get_last_accuracy(file): 7 | with open(file) as f: 8 | lines = f.readlines() 9 | if len(lines) < 2: 10 | return (0, 0) 11 | val_acc_text, test_acc_text = lines[-2:] 12 | if 'val acc: ' not in val_acc_text: 13 | return (0, 0) 14 | val_acc = float(val_acc_text.split('val acc: ')[1]) 15 | if 'test acc: ' not in test_acc_text: 16 | return (0, 0) 17 | test_acc = float(test_acc_text.split('test acc: ')[1]) 18 | return val_acc, test_acc, os.path.basename(file) 19 | 20 | 21 | def get_best_run(dir): 22 | return list(sorted( 23 | get_last_accuracy(file) for file in glob.glob(os.path.join(dir, '*')) 24 | ))[-1][1:] 25 | 26 | 27 | for dir in glob.glob(os.path.join('run_dknn_results', '*')): 28 | print(dir, *get_best_run(dir)) 29 | -------------------------------------------------------------------------------- /pytorch/softsort.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | 4 | 5 | class SoftSort(torch.nn.Module): 6 | def __init__(self, tau=1.0, hard=False, pow=1.0): 7 | super(SoftSort, self).__init__() 8 | self.hard = hard 9 | self.tau = tau 10 | self.pow = pow 11 | 12 | def forward(self, scores: Tensor): 13 | """ 14 | scores: elements to be sorted. 
Typical shape: batch_size x n 15 | """ 16 | scores = scores.unsqueeze(-1) 17 | sorted = scores.sort(descending=True, dim=1)[0] 18 | pairwise_diff = (scores.transpose(1, 2) - sorted).abs().pow(self.pow).neg() / self.tau 19 | P_hat = pairwise_diff.softmax(-1) 20 | 21 | if self.hard: 22 | P = torch.zeros_like(P_hat, device=P_hat.device) 23 | P.scatter_(-1, P_hat.topk(1, -1)[1], value=1) 24 | P_hat = (P - P_hat).detach() + P_hat 25 | return P_hat 26 | 27 | 28 | class SoftSort_p1(torch.nn.Module): 29 | def __init__(self, tau=1.0, hard=False): 30 | super(SoftSort_p1, self).__init__() 31 | self.hard = hard 32 | self.tau = tau 33 | 34 | def forward(self, scores: Tensor): 35 | """ 36 | scores: elements to be sorted. Typical shape: batch_size x n 37 | """ 38 | scores = scores.unsqueeze(-1) 39 | sorted = scores.sort(descending=True, dim=1)[0] 40 | pairwise_diff = (scores.transpose(1, 2) - sorted).abs().neg() / self.tau 41 | P_hat = pairwise_diff.softmax(-1) 42 | 43 | if self.hard: 44 | P = torch.zeros_like(P_hat, device=P_hat.device) 45 | P.scatter_(-1, P_hat.topk(1, -1)[1], value=1) 46 | P_hat = (P - P_hat).detach() + P_hat 47 | return P_hat 48 | 49 | 50 | class SoftSort_p2(torch.nn.Module): 51 | def __init__(self, tau=1.0, hard=False): 52 | super(SoftSort_p2, self).__init__() 53 | self.hard = hard 54 | self.tau = tau 55 | 56 | def forward(self, scores: Tensor): 57 | """ 58 | scores: elements to be sorted. Typical shape: batch_size x n 59 | """ 60 | scores = scores.unsqueeze(-1) 61 | sorted = scores.sort(descending=True, dim=1)[0] 62 | pairwise_diff = ((scores.transpose(1, 2) - sorted) ** 2).neg() / self.tau 63 | P_hat = pairwise_diff.softmax(-1) 64 | 65 | if self.hard: 66 | P = torch.zeros_like(P_hat, device=P_hat.device) 67 | P.scatter_(-1, P_hat.topk(1, -1)[1], value=1) 68 | P_hat = (P - P_hat).detach() + P_hat 69 | return P_hat 70 | -------------------------------------------------------------------------------- /pytorch/synthetic_experiment_learning_curves.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | from scipy import stats 4 | 5 | import numpy as np 6 | import torch 7 | from torch.autograd import Variable 8 | 9 | from neuralsort_cpu_or_gpu import NeuralSort 10 | from softsort import SoftSort 11 | 12 | parser = argparse.ArgumentParser(description="Benchmark speed of softsort vs" 13 | " neuralsort") 14 | 15 | parser.add_argument("--batch_size", type=int, default=20) 16 | parser.add_argument("--n", type=int, default=2000) 17 | parser.add_argument("--epochs", type=int, default=100) 18 | parser.add_argument("--device", type=str, default='cpu') 19 | parser.add_argument("--method", type=str, default='neuralsort') 20 | parser.add_argument("--tau", type=float, default=1.0) 21 | parser.add_argument("--pow", type=float, default=1.0) 22 | 23 | args = parser.parse_args() 24 | 25 | print("Benchmarking with:\n" 26 | "\tbatch_size = %d\n" 27 | "\tn = %d\n" 28 | "\tepochs = %d\n" 29 | "\tdevice = %s\n" 30 | "\tmethod = %s\n" 31 | "\ttau = %s\n" 32 | "\tpow = %f\n" % 33 | (args.batch_size, 34 | args.n, 35 | args.epochs, 36 | args.device, 37 | args.method, 38 | args.tau, 39 | args.pow)) 40 | 41 | np.random.seed(1) 42 | torch.manual_seed(1) 43 | 44 | sort_op = None 45 | if args.method == 'neuralsort': 46 | sort_op = NeuralSort(tau=args.tau, device=args.device) 47 | elif args.method == 'softsort': 48 | sort_op = SoftSort(tau=args.tau, pow=args.pow) 49 | else: 50 | raise ValueError('method %s not found' % args.method) 51 | 52 | scores = 
Variable(torch.rand(size=(args.batch_size, args.n), 53 | device=args.device) * 2.0 - 1.0, requires_grad=True) 54 | optimizer = torch.optim.SGD([scores], lr=10.0, momentum=0.5, weight_decay=0.01) 55 | 56 | 57 | def evaluate(scores): 58 | r''' 59 | Returns the mean spearman correlation over the batch. 60 | ''' 61 | scores_eval = scores.cpu().detach().numpy() 62 | rank_correlations = [] 63 | for i in range(args.batch_size): 64 | rank_correlation, _ = stats.spearmanr(scores_eval[i], range(args.n, 0, -1)) 65 | rank_correlations.append(rank_correlation) 66 | mean_rank_correlation = np.mean(rank_correlations) 67 | return mean_rank_correlation 68 | 69 | 70 | # train 71 | start_time = time.time() 72 | log = "" 73 | for epoch in range(args.epochs): 74 | optimizer.zero_grad() 75 | 76 | # Normalize scores before feeding them into the sorting op for increased stability. 77 | min_scores, _ = torch.min(scores, dim=1, keepdim=True) 78 | min_scores = min_scores.detach() 79 | max_scores, _ = torch.max(scores, dim=1, keepdim=True) 80 | max_scores = max_scores.detach() 81 | scores_normalized = (scores - min_scores) / (max_scores - min_scores) 82 | P_hat = sort_op(scores_normalized) 83 | 84 | loss = torch.mean(1.0 - torch.log(torch.diagonal(P_hat, dim1=1, dim2=2))) 85 | loss.backward(retain_graph=True) 86 | optimizer.step() 87 | spearmanr = evaluate(scores) 88 | log += "Epoch %d loss = %f spearmanr = %f\n" % (epoch, loss, spearmanr) 89 | spearmanr = evaluate(scores) 90 | if args.device == 'cuda': 91 | torch.cuda.synchronize() 92 | end_time = time.time() 93 | total_time = end_time - start_time 94 | 95 | log += "Epochs: %d\n" % args.epochs 96 | log += "Loss: %f\n" % loss 97 | log += "Spearmanr: %f\n" % spearmanr 98 | log += "Total time: %f\n" % total_time 99 | log += "Time per epoch: %f\n" % (total_time / args.epochs) 100 | 101 | print(log) 102 | -------------------------------------------------------------------------------- /pytorch/synthetic_experiment_speed_comparison.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | from scipy import stats 4 | 5 | import numpy as np 6 | import torch 7 | from torch.autograd import Variable 8 | 9 | from neuralsort_cpu_or_gpu import NeuralSort 10 | from softsort import SoftSort_p2 11 | 12 | parser = argparse.ArgumentParser(description="Benchmark speed of softsort vs" 13 | " neuralsort") 14 | 15 | parser.add_argument("--batch_size", type=int, default=20) 16 | parser.add_argument("--n", type=int, default=2000) 17 | parser.add_argument("--epochs", type=int, default=100) 18 | parser.add_argument("--device", type=str, default='cpu') 19 | parser.add_argument("--method", type=str, default='neuralsort') 20 | parser.add_argument("--burnin", type=int, default=100) 21 | 22 | args = parser.parse_args() 23 | 24 | print("Benchmarking with:\n" 25 | "\tbatch_size = %d\n" 26 | "\tn = %d\n" 27 | "\tepochs = %d\n" 28 | "\tdevice = %s\n" 29 | "\tmethod = %s\n" 30 | "\tburnin = %d" % 31 | (args.batch_size, 32 | args.n, 33 | args.epochs, 34 | args.device, 35 | args.method, 36 | args.burnin)) 37 | 38 | np.random.seed(1) 39 | torch.manual_seed(1) 40 | 41 | sort_op = None 42 | if args.method == 'neuralsort': 43 | sort_op = NeuralSort(tau=100.0, device=args.device) 44 | elif args.method == 'softsort': 45 | sort_op = SoftSort_p2(tau=0.1) 46 | else: 47 | raise ValueError('method %s not found' % args.method) 48 | 49 | scores = Variable(torch.rand(size=(args.batch_size, args.n), 50 | device=args.device) * 2.0 - 1.0, 
requires_grad=True) 51 | optimizer = torch.optim.SGD([scores], lr=10.0, momentum=0.5, weight_decay=0.01) 52 | 53 | 54 | def evaluate(scores): 55 | r''' 56 | Returns the mean spearman correlation over the batch. 57 | ''' 58 | scores_eval = scores.cpu().detach().numpy() 59 | rank_correlations = [] 60 | for i in range(args.batch_size): 61 | rank_correlation, _ = stats.spearmanr(scores_eval[i], range(args.n, 0, -1)) 62 | rank_correlations.append(rank_correlation) 63 | mean_rank_correlation = np.mean(rank_correlations) 64 | return mean_rank_correlation 65 | 66 | 67 | def training_step(): 68 | optimizer.zero_grad() 69 | 70 | # Normalize scores before feeding them into the sorting op for increased stability. 71 | min_scores, _ = torch.min(scores, dim=1, keepdim=True) 72 | min_scores = min_scores.detach() 73 | max_scores, _ = torch.max(scores, dim=1, keepdim=True) 74 | max_scores = max_scores.detach() 75 | scores_normalized = (scores - min_scores) / (max_scores - min_scores) 76 | P_hat = sort_op(scores_normalized) 77 | 78 | loss = torch.mean(1.0 - torch.log(torch.diagonal(P_hat, dim1=1, dim2=2))) 79 | loss.backward(retain_graph=True) 80 | optimizer.step() 81 | return loss 82 | 83 | 84 | # burn-in 85 | for epoch in range(args.burnin): 86 | training_step() 87 | if args.device == 'cuda': 88 | torch.cuda.synchronize() 89 | 90 | # train 91 | start_time = time.time() 92 | log = "" 93 | for epoch in range(args.epochs): 94 | loss = training_step() 95 | spearmanr = evaluate(scores) 96 | if args.device == 'cuda': 97 | torch.cuda.synchronize() 98 | end_time = time.time() 99 | total_time = end_time - start_time 100 | 101 | log += "Epochs: %d\n" % args.epochs 102 | log += "Loss: %f\n" % loss 103 | log += "Spearmanr: %f\n" % spearmanr 104 | log += "Total time: %f\n" % total_time 105 | log += "Time per epoch: %f\n" % (total_time / args.epochs) 106 | 107 | print(log) 108 | -------------------------------------------------------------------------------- /pytorch/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | # labels is a 1-dimensional tensor 4 | 5 | 6 | def one_hot(labels, l=10): 7 | n = labels.shape[0] 8 | labels = labels.unsqueeze(-1) 9 | oh = torch.zeros(n, l, device='cuda').scatter_(1, labels, 1) 10 | return oh 11 | 12 | 13 | generate_nothing = iter(int, 1) 14 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib==3.0.3 2 | numpy==1.15.4 3 | pillow==9.0.0 4 | scipy==1.1.0 5 | sklearn==0.0 6 | tensorflow==1.12.0 7 | tensorflow-gpu==1.14.0 8 | torch==1.0.1 9 | torchvision==0.2.1 10 | -------------------------------------------------------------------------------- /synthetic_experiment_learning_curves.sh: -------------------------------------------------------------------------------- 1 | ROOT_DIR=benchmark_results_learning_curve 2 | BATCH_SIZE=20 3 | N=4000 4 | EPOCHS=100 5 | DEVICE=cpu 6 | 7 | mkdir -p ${ROOT_DIR} 8 | 9 | FRAMEWORK=tf 10 | python3 ${FRAMEWORK}/synthetic_experiment_learning_curves.py \ 11 | --batch_size ${BATCH_SIZE} \ 12 | --n ${N} \ 13 | --epochs ${EPOCHS} \ 14 | --device ${DEVICE} \ 15 | --method softsort \ 16 | --tau 0.03 \ 17 | --pow 2.0 \ 18 | 2>&1 | tee ${ROOT_DIR}/benchmark_results_learning_curve_softsort_p2_${FRAMEWORK}_${N}_${BATCH_SIZE}.txt 19 | 20 | 21 | python3 ${FRAMEWORK}/synthetic_experiment_learning_curves.py \ 22 | --batch_size ${BATCH_SIZE} \ 23 | --n ${N} \ 24 | 
--epochs ${EPOCHS} \ 25 | --device ${DEVICE} \ 26 | --method softsort \ 27 | --tau 0.1 \ 28 | --pow 1.0 \ 29 | 2>&1 | tee ${ROOT_DIR}/benchmark_results_learning_curve_softsort_p1_${FRAMEWORK}_${N}_${BATCH_SIZE}.txt 30 | 31 | 32 | python3 ${FRAMEWORK}/synthetic_experiment_learning_curves.py \ 33 | --batch_size ${BATCH_SIZE} \ 34 | --n ${N} \ 35 | --epochs ${EPOCHS} \ 36 | --device ${DEVICE} \ 37 | --method neuralsort \ 38 | --tau 100.0 \ 39 | 2>&1 | tee ${ROOT_DIR}/benchmark_results_learning_curve_neuralsort_${FRAMEWORK}_${N}_${BATCH_SIZE}.txt 40 | -------------------------------------------------------------------------------- /synthetic_experiment_learning_curves_plot.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | 3 | 4 | N = 4000 5 | BATCH_SIZE = 20 6 | ROOT_DIR = "benchmark_results_learning_curve" 7 | 8 | 9 | def get_spearmanr2s_from_file(filename): 10 | spearmanr2s = [] 11 | for line in open(filename): 12 | if line.startswith('Epoch '): 13 | spearmanr2s.append(float(line.split(' ')[-1])) 14 | return spearmanr2s 15 | 16 | 17 | plt.figure(figsize=(7, 5)) 18 | fontsize = 16 19 | spearmanr2s_neuralsort =\ 20 | get_spearmanr2s_from_file(f"{ROOT_DIR}/benchmark_results_learning_curve_neuralsort_tf_{N}_{BATCH_SIZE}.txt") 21 | plt.plot(spearmanr2s_neuralsort, label='NeuralSort', color='red', linestyle='--') 22 | spearmanr2s_softsort_p1 =\ 23 | get_spearmanr2s_from_file(f"{ROOT_DIR}/benchmark_results_learning_curve_softsort_p1_tf_{N}_{BATCH_SIZE}.txt") 24 | plt.plot(spearmanr2s_softsort_p1, label='SoftSort, p=1', color='blue', linestyle='-.') 25 | spearmanr2s_softsort_p2 =\ 26 | get_spearmanr2s_from_file(f"{ROOT_DIR}/benchmark_results_learning_curve_softsort_p2_tf_{N}_{BATCH_SIZE}.txt") 27 | plt.plot(spearmanr2s_softsort_p2, label='SoftSort, p=2', color='blue', linestyle='-') 28 | # plt.title('Learning Curves for Benchmark Task') 29 | plt.xlabel('Epoch', fontsize=fontsize) 30 | plt.ylabel('Spearman R2', fontsize=fontsize) 31 | plt.xticks(fontsize=fontsize) 32 | plt.yticks(fontsize=fontsize) 33 | plt.legend(fontsize=fontsize) 34 | plt.tight_layout() 35 | plt.savefig("images/synthetic_experiment_learning_curves.png") 36 | -------------------------------------------------------------------------------- /synthetic_experiment_speed_comparison.sh: -------------------------------------------------------------------------------- 1 | for FRAMEWORK in tf pytorch 2 | do 3 | ROOT_DIR=benchmark_results_${FRAMEWORK} 4 | mkdir -p ${ROOT_DIR} 5 | for ((N=100;N<=4000;N+=100)) 6 | do 7 | for METHOD in neuralsort softsort 8 | do 9 | RESULTS_DIR="N_${N}_${METHOD}" 10 | mkdir ${ROOT_DIR}/${RESULTS_DIR} 11 | for DEVICE in cpu cuda 12 | do 13 | OUTPUT_FILE="${ROOT_DIR}/${RESULTS_DIR}/N_${N}_${METHOD}_DEVICE_${DEVICE}.txt" 14 | python3 ${FRAMEWORK}/synthetic_experiment_speed_comparison.py --batch_size 20 --n ${N} --epochs 100 --device ${DEVICE} --method ${METHOD} --burnin 1 2>&1 | tee ${OUTPUT_FILE} 15 | done 16 | done 17 | done 18 | done 19 | -------------------------------------------------------------------------------- /synthetic_experiment_speed_comparison_plot.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib 3 | import matplotlib.pyplot as plt 4 | 5 | 6 | class BenchmarkSoftsortBackwardsResultsParser: 7 | r''' 8 | Parses an individual results (i.e. log) file and stores the results. 
9 | ''' 10 | def __init__(self): 11 | self.epochs = None 12 | self.loss = None 13 | self.spearmanr = None 14 | self.total_time = None 15 | self.time_per_epoch = None 16 | self.oom = False 17 | 18 | def parse(self, file_path, expected_length=None): 19 | r''' 20 | :param file_path: path to the results (i.e. log) file 21 | ''' 22 | with open(file_path) as file: 23 | for line in file: 24 | line_tokens = line.replace(',', '').replace('\n', '').split(' ') 25 | if line.startswith("Epochs"): 26 | assert self.epochs is None 27 | self.epochs = line_tokens[1] 28 | elif line.startswith("Loss"): 29 | assert self.loss is None 30 | self.loss = line_tokens[1] 31 | elif line.startswith("Spearmanr"): 32 | assert self.spearmanr is None 33 | self.spearmanr = line_tokens[1] 34 | elif line.startswith("Total time"): 35 | assert self.total_time is None 36 | self.total_time = line_tokens[2] 37 | elif line.startswith("Time per epoch"): 38 | assert self.time_per_epoch is None 39 | self.time_per_epoch = line_tokens[3] 40 | if line.startswith("RuntimeError:"): 41 | self.oom = True 42 | return 43 | if expected_length: 44 | assert(int(self.epochs) == expected_length) 45 | assert self.epochs is not None 46 | assert self.loss is not None 47 | assert self.spearmanr is not None 48 | assert self.total_time is not None 49 | assert self.time_per_epoch is not None 50 | 51 | def get_epochs(self): 52 | return self.epochs if not self.oom else '-' 53 | 54 | def get_loss(self): 55 | return self.loss if not self.oom else '-' 56 | 57 | def get_spearmanr(self): 58 | return self.spearmanr if not self.oom else '-' 59 | 60 | def get_total_time(self): 61 | return self.total_time if not self.oom else '-' 62 | 63 | def get_time_per_epoch(self): 64 | r''' 65 | Returns the time per epoch in ms 66 | ''' 67 | return ("%.5f" % (1000.0 * float(self.time_per_epoch))) if not self.oom else '-' 68 | 69 | 70 | num_epochs = 100 71 | frameworks = ['pytorch', 'pytorch', 'tf', 'tf'] 72 | devices = ['cpu', 'cuda', 'cpu', 'cuda'] 73 | ns_lists = \ 74 | [[str(i) for i in range(100, 4001, 100)]] * 4 75 | methods = ['neuralsort', 'softsort'] 76 | 77 | res = dict() 78 | 79 | for framework, device, ns in zip(frameworks, devices, ns_lists): 80 | for n in ns: 81 | for method in methods: 82 | filename = "./benchmark_results_%s/N_%s_%s/N_%s_%s_DEVICE_%s.txt" %\ 83 | (framework, n, method, n, method, device) 84 | print("Processing " + str(filename)) 85 | results_parser = BenchmarkSoftsortBackwardsResultsParser() 86 | results_parser.parse(filename, expected_length=int(num_epochs)) 87 | epochs = results_parser.get_epochs() 88 | loss = results_parser.get_loss() 89 | spearmanr = results_parser.get_spearmanr() 90 | total_time = results_parser.get_total_time() 91 | time_per_epoch = results_parser.get_time_per_epoch() 92 | res[(framework, device, n, method, 'epochs')] = epochs 93 | res[(framework, device, n, method, 'loss')] = loss 94 | res[(framework, device, n, method, 'spearmanr')] = spearmanr 95 | res[(framework, device, n, method, 'total_time')] = total_time 96 | res[(framework, device, n, method, 'time_per_epoch')] = time_per_epoch 97 | 98 | 99 | def get_times_for_device_framework_and_method(device, framework, method): 100 | times = [] 101 | for n in ns: 102 | time = res[(framework, device, n, method, 'time_per_epoch')] 103 | if time == '-': 104 | break 105 | times.append(time) 106 | times = np.array(times) 107 | return times 108 | 109 | 110 | ns = np.array([str(i) for i in range(100, 4001, 100)]) 111 | 112 | for device in ['cpu', 'cuda']: 113 | time_normalization = 
1000 if device == 'cpu' else 1 114 | for framework in ['pytorch', 'tf']: 115 | times_neuralsort = get_times_for_device_framework_and_method( 116 | device=device, 117 | framework=framework, 118 | method='neuralsort') 119 | times_softsort = get_times_for_device_framework_and_method( 120 | device=device, 121 | framework=framework, 122 | method='softsort') 123 | fig1, ax1 = plt.subplots(figsize=(7, 5)) 124 | fontsize = 16 125 | ax1.plot(ns[:len(times_neuralsort)].astype('int'), times_neuralsort.astype('float') / time_normalization, 126 | color='red', linestyle='--') 127 | ax1.plot(ns[:len(times_softsort)].astype('int'), times_softsort.astype('float') / time_normalization, 128 | color='blue', linestyle='-') 129 | plt.xticks(rotation=70, fontsize=fontsize) 130 | ax1.set_xticks(ns.astype('int')) 131 | ax1.get_xaxis().set_major_formatter(matplotlib.ticker.ScalarFormatter()) 132 | plt.xlabel(r'$n$', fontsize=fontsize) 133 | plt.xticks(range(200, 4001, 200), fontsize=fontsize) 134 | plt.yticks(fontsize=fontsize) 135 | if device == 'cuda': 136 | plt.ylim(0, 150) 137 | plt.ylabel('time per epoch (ms)', fontsize=fontsize) 138 | else: 139 | plt.ylim(0, 30) 140 | plt.ylabel('time per epoch (s)', fontsize=fontsize) 141 | title = "" 142 | if framework == 'pytorch': 143 | title += 'Pytorch' 144 | elif framework == 'tf': 145 | title += 'TensorFlow' 146 | if device == 'cuda': 147 | title += ' GPU' 148 | elif device == 'cpu': 149 | title += ' CPU' 150 | # plt.title(title) # Title should go in the figure latex caption 151 | plt.legend(['NeuralSort', 'SoftSort'], fontsize=fontsize) 152 | plt.tight_layout() 153 | plt.savefig('images/' + title.replace(' ', '_') + '_softsort') 154 | -------------------------------------------------------------------------------- /tf/.gitignore: -------------------------------------------------------------------------------- 1 | MNIST_data/ 2 | -------------------------------------------------------------------------------- /tf/checkpoints/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore -------------------------------------------------------------------------------- /tf/logs/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore -------------------------------------------------------------------------------- /tf/mnist_input.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import tensorflow as tf 4 | from statistics import median 5 | from tensorflow.examples.tutorials.mnist import input_data 6 | 7 | TRAIN_SET_SIZE = 55000 8 | VAL_SET_SIZE = 5000 9 | TEST_SET_SIZE = 10000 10 | 11 | 12 | def select_digit(split, d): 13 | return split.images[np.nonzero(split.labels[:, d])] 14 | 15 | 16 | def split_digits(split): 17 | return [select_digit(split, d) for d in range(10)] 18 | 19 | 20 | mnist = input_data.read_data_sets('MNIST_data', one_hot=True) 21 | 22 | train_digits = split_digits(mnist.train) 23 | validation_digits = split_digits(mnist.validation) 24 | test_digits = split_digits(mnist.test) 25 | 26 | 27 | def get_multi_mnist_input(l, n, low, high, digset=train_digits): 28 | multi_mnist_sequences = [] 29 | values = [] 30 | for i in range(n): 31 | mnist_digits = [] 32 | num = random.randint(low, high) 33 | values.append(num) 34 | 35 | for i in range(l): 36 | 
digit = num % 10 37 | num //= 10 38 | ref = digset[digit] 39 | mnist_digits.insert(0, ref[np.random.randint(0, ref.shape[0])]) 40 | multi_mnist_sequence = np.concatenate(mnist_digits) 41 | multi_mnist_sequence = np.reshape(multi_mnist_sequence, (-1, 28)) 42 | multi_mnist_sequences.append(multi_mnist_sequence) 43 | multi_mnist_batch = np.stack(multi_mnist_sequences) 44 | vals = np.array(values) 45 | med = int(median(values)) 46 | arg_med = np.equal(vals, med).astype('float32') 47 | arg_med /= np.sum(arg_med) 48 | return multi_mnist_batch, med, arg_med, vals 49 | 50 | 51 | def get_iterator(l, n, window_size, digset, minibatch_size=None): 52 | low, high = 0, 10 ** l - 1 53 | 54 | def input_generator(): 55 | while True: 56 | window_begin = random.randint(low, high - window_size) 57 | ret = get_multi_mnist_input( 58 | l, n, window_begin, window_begin + window_size, digset) 59 | yield ret 60 | mm_data = tf.data.Dataset.from_generator( 61 | input_generator, 62 | (tf.float32, tf.float32, tf.float32, tf.float32), 63 | ((n, l * 28, 28), (), (n,), (n,)) 64 | ) 65 | if minibatch_size: 66 | mm_data = mm_data.batch(minibatch_size) 67 | mm_data = mm_data.prefetch(10) 68 | return mm_data.make_one_shot_iterator() 69 | 70 | 71 | def get_iterators(l, n, window_size, minibatch_size=None, val_repeat=None): 72 | return get_iterator(l, n, window_size, train_digits, minibatch_size=minibatch_size), \ 73 | get_iterator(l, n, window_size, validation_digits, minibatch_size=minibatch_size), \ 74 | get_iterator(l, n, window_size, test_digits, 75 | minibatch_size=minibatch_size) 76 | 77 | 78 | def test_iterators(): 79 | a, b, c = get_iterators(5, 10, 100) 80 | with tf.Session() as sess: 81 | for d in [a, b, c]: 82 | print(sess.run(d.get_next())) 83 | -------------------------------------------------------------------------------- /tf/multi_mnist_cnn.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def deepnn(l, x, final_dim=1): 5 | """deepnn builds the graph for a deep net for classifying digits. 6 | Args: 7 | x: an input tensor with the dimensions (N_examples, 784), where 784 is the 8 | number of pixels in a standard MNIST image. 9 | Returns: 10 | A tuple (y, keep_prob). y is a tensor of shape (N_examples, 10), with values 11 | equal to the logits of classifying the digit into one of 10 classes (the 12 | digits 0-9). keep_prob is a scalar placeholder for the probability of 13 | dropout. 14 | """ 15 | # Reshape to use within a convolutional neural net. 16 | # Last dimension is for "features" - there is only one here, since images are 17 | # grayscale -- it would be 3 for an RGB image, 4 for RGBA, etc. 18 | 19 | with tf.name_scope('reshape'): 20 | x_image = tf.reshape(x, [-1, l * 28, 28, 1]) 21 | 22 | # First convolutional layer - maps one grayscale image to 32 feature maps. 23 | with tf.name_scope('conv1'): 24 | W_conv1 = weight_variable([5, 5, 1, 32]) 25 | b_conv1 = bias_variable([32]) 26 | h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1) 27 | 28 | # Pooling layer - downsamples by 2X. 29 | with tf.name_scope('pool1'): 30 | h_pool1 = max_pool_2x2(h_conv1) 31 | 32 | # Second convolutional layer -- maps 32 feature maps to 64. 33 | with tf.name_scope('conv2'): 34 | W_conv2 = weight_variable([5, 5, 32, 64]) 35 | b_conv2 = bias_variable([64]) 36 | h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2) 37 | 38 | # Second pooling layer. 
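# Note: each 2x2 max-pool (stride 2, SAME padding) halves both spatial dimensions,
# so the (l*28)x28 input becomes (l*14)x14 after pool1 and (l*7)x7 with 64 channels
# after pool2 -- which is where the l * 7 * 7 * 64 flatten size in fc1 comes from.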
39 | with tf.name_scope('pool2'): 40 | h_pool2 = max_pool_2x2(h_conv2) 41 | 42 | # Fully connected layer 1 -- after 2 round of downsampling, our 28x28 image 43 | # is down to 7x7x64 feature maps -- maps this to 64 features. 44 | with tf.name_scope('fc1'): 45 | W_fc1 = weight_variable([l * 7 * 7 * 64, 64]) 46 | b_fc1 = bias_variable([64]) 47 | 48 | h_pool2_flat = tf.reshape(h_pool2, [-1, l * 7 * 7 * 64]) 49 | h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1) 50 | 51 | with tf.name_scope('fc2'): 52 | W_fc2 = weight_variable([64, final_dim]) 53 | b_fc2 = bias_variable([final_dim]) 54 | 55 | h_fc1_flat = tf.reshape(h_fc1, [-1, 64]) 56 | h_fc2 = tf.matmul(h_fc1_flat, W_fc2) + b_fc2 57 | 58 | return h_fc2 59 | 60 | # Dropout - controls the complexity of the model, prevents co-adaptation of 61 | # features. 62 | with tf.name_scope('dropout'): 63 | keep_prob = tf.placeholder(tf.float32) 64 | h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob) 65 | 66 | return h_fc1_drop, keep_prob 67 | 68 | 69 | def conv2d(x, W): 70 | """conv2d returns a 2d convolution layer with full stride.""" 71 | return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME') 72 | 73 | 74 | def max_pool_2x2(x): 75 | """max_pool_2x2 downsamples a feature map by 2X.""" 76 | return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], 77 | strides=[1, 2, 2, 1], padding='SAME') 78 | 79 | 80 | def weight_variable(shape): 81 | """weight_variable generates a weight variable of a given shape.""" 82 | initial = tf.truncated_normal(shape, stddev=0.1) 83 | with tf.name_scope("reg"): 84 | return tf.Variable(initial) 85 | 86 | 87 | def bias_variable(shape): 88 | """bias_variable generates a bias variable of a given shape.""" 89 | initial = tf.constant(0.1, shape=shape) 90 | return tf.Variable(initial) 91 | -------------------------------------------------------------------------------- /tf/predictions/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore -------------------------------------------------------------------------------- /tf/run_median.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import tensorflow as tf 4 | import numpy as np 5 | from scipy.stats import spearmanr 6 | from sklearn.metrics import r2_score 7 | import mnist_input 8 | import multi_mnist_cnn 9 | from sinkhorn import sinkhorn_operator 10 | 11 | import util 12 | import random 13 | 14 | os.environ['TF_CUDNN_DETERMINISTIC'] = 'true' 15 | tf.set_random_seed(94305) 16 | random.seed(94305) 17 | np.random.seed(94305) 18 | 19 | flags = tf.app.flags 20 | flags.DEFINE_integer('M', 1, 'batch size') 21 | flags.DEFINE_integer('n', 3, 'number of elements to compare at a time') 22 | flags.DEFINE_integer('l', 5, 'number of digits') 23 | flags.DEFINE_integer('repetition', 0, 'number of repetition') 24 | flags.DEFINE_float('pow', 1, 'softsort exponent for pairwise difference') 25 | flags.DEFINE_float('tau', 5, 'temperature (dependent meaning)') 26 | flags.DEFINE_string('method', 'deterministic_neuralsort', 27 | 'which method to use?') 28 | flags.DEFINE_integer('n_s', 5, 'number of samples') 29 | flags.DEFINE_integer('num_epochs', 200, 'number of epochs to train') 30 | flags.DEFINE_float('lr', 1e-4, 'initial learning rate') 31 | 32 | FLAGS = flags.FLAGS 33 | 34 | n_s = FLAGS.n_s 35 | NUM_EPOCHS = FLAGS.num_epochs 36 | M = FLAGS.M 37 | n = FLAGS.n 38 | l = FLAGS.l 39 | repetition = FLAGS.repetition 40 | 
power = FLAGS.pow 41 | tau = FLAGS.tau 42 | method = FLAGS.method 43 | initial_rate = FLAGS.lr 44 | 45 | train_iterator, val_iterator, test_iterator = mnist_input.get_iterators( 46 | l, n, 10 ** l - 1, minibatch_size=M) 47 | 48 | false_tensor = tf.convert_to_tensor(False) 49 | evaluation = tf.placeholder_with_default(false_tensor, ()) 50 | temp = tf.cond(evaluation, 51 | false_fn=lambda: tf.convert_to_tensor(tau, dtype=tf.float32), 52 | true_fn=lambda: tf.convert_to_tensor(1e-10, dtype=tf.float32) 53 | ) 54 | 55 | experiment_id = 'median-%s-M%d-n%d-l%d-t%d-p%.2f' % (method, M, n, l, tau * 10, power) 56 | checkpoint_path = 'checkpoints/%s/' % experiment_id 57 | predictions_path = 'predictions/' 58 | 59 | handle = tf.placeholder(tf.string, ()) 60 | X_iterator = tf.data.Iterator.from_string_handle( 61 | handle, 62 | (tf.float32, tf.float32, tf.float32, tf.float32), 63 | ((M, n, l * 28, 28), (M,), (M, n), (M, n)) 64 | ) 65 | 66 | X, y, median_scores, true_scores = X_iterator.get_next() 67 | 68 | true_scores = tf.expand_dims(true_scores, 2) 69 | P_true = util.neuralsort(true_scores, 1e-10) 70 | n_prime = n 71 | 72 | 73 | def get_median_probs(P): 74 | median_strip = P[:, n // 2, :] 75 | median_total = tf.reduce_sum(median_strip, axis=1, keepdims=True) 76 | probs = median_strip / median_total 77 | # print(probs) 78 | return probs 79 | 80 | 81 | if method == 'vanilla': 82 | with tf.variable_scope("phi"): 83 | representations = multi_mnist_cnn.deepnn(l, X, 10) 84 | representations = tf.reshape(representations, [M, n * 10]) 85 | fc1 = tf.layers.dense(representations, 10, tf.nn.relu) 86 | fc2 = tf.layers.dense(fc1, 10, tf.nn.relu) 87 | fc3 = tf.layers.dense(fc2, 10, tf.nn.relu) 88 | y_hat = tf.layers.dense(fc3, 1) 89 | y_hat = tf.squeeze(y_hat) 90 | loss_phi = tf.reduce_sum(tf.squared_difference(y_hat, y)) 91 | loss_theta = loss_phi 92 | prob_median_eval = 0 93 | 94 | elif method == 'sinkhorn': 95 | with tf.variable_scope('phi'): 96 | representations = multi_mnist_cnn.deepnn(l, X, n) 97 | pre_sinkhorn = tf.reshape(representations, [M, n, n]) 98 | with tf.variable_scope('theta'): 99 | regression_candidates = multi_mnist_cnn.deepnn(l, X, 1) 100 | regression_candidates = tf.reshape( 101 | regression_candidates, [M, n]) 102 | 103 | P_hat = sinkhorn_operator(pre_sinkhorn, temp=temp) 104 | prob_median = get_median_probs(P_hat) 105 | 106 | point_estimates = tf.reduce_sum( 107 | prob_median * regression_candidates, axis=1) 108 | exp_loss = tf.squared_difference(y, point_estimates) 109 | 110 | loss_phi = tf.reduce_mean(exp_loss) 111 | loss_theta = loss_phi 112 | 113 | P_hat_eval = sinkhorn_operator(pre_sinkhorn, temp=1e-20) 114 | prob_median_eval = get_median_probs(P_hat_eval) 115 | 116 | elif method == 'gumbel_sinkhorn': 117 | with tf.variable_scope('phi'): 118 | representations = multi_mnist_cnn.deepnn(l, X, n) 119 | pre_sinkhorn_orig = tf.reshape(representations, [M, n, n]) 120 | pre_sinkhorn = tf.tile(pre_sinkhorn_orig, [ 121 | n_s, 1, 1]) 122 | pre_sinkhorn += util.sample_gumbel([n_s * M, n, n]) 123 | 124 | with tf.variable_scope('theta'): 125 | regression_candidates = multi_mnist_cnn.deepnn(l, X, 1) 126 | regression_candidates = tf.reshape( 127 | regression_candidates, [M, n]) 128 | 129 | P_hat = sinkhorn_operator(pre_sinkhorn, temp=temp) 130 | prob_median = get_median_probs(P_hat) 131 | prob_median = tf.reshape(prob_median, [n_s, M, n]) 132 | 133 | point_estimates = tf.reduce_sum( 134 | prob_median * regression_candidates, axis=2) 135 | exp_loss = tf.squared_difference(y, point_estimates) 136 | 137 | 
loss_phi = tf.reduce_mean(exp_loss) 138 | loss_theta = loss_phi 139 | 140 | P_hat_eval = sinkhorn_operator(pre_sinkhorn_orig, temp=1e-20) 141 | prob_median_eval = get_median_probs(P_hat_eval) 142 | 143 | elif method == 'deterministic_neuralsort': 144 | with tf.variable_scope('phi'): 145 | scores = multi_mnist_cnn.deepnn(l, X, 1) 146 | scores = tf.reshape(scores, [M, n, 1]) 147 | 148 | P_hat = util.neuralsort(scores, temp) 149 | P_hat_eval = util.neuralsort(scores, 1e-20) 150 | 151 | with tf.variable_scope('theta'): 152 | regression_candidates = multi_mnist_cnn.deepnn(l, X, 1) 153 | regression_candidates = tf.reshape( 154 | regression_candidates, [M, n]) 155 | 156 | losses = tf.squared_difference( 157 | regression_candidates, tf.expand_dims(y, 1)) 158 | prob_median = get_median_probs(P_hat) 159 | prob_median_eval = get_median_probs(P_hat_eval) 160 | 161 | point_estimates = tf.reduce_sum( 162 | prob_median * regression_candidates, axis=1) 163 | exp_loss = tf.squared_difference(y, point_estimates) 164 | 165 | point_estimates_eval = tf.reduce_sum( 166 | prob_median_eval * regression_candidates, axis=1) 167 | exp_loss_eval = tf.squared_difference(y, point_estimates) 168 | 169 | loss_phi = tf.reduce_mean(exp_loss) 170 | loss_theta = tf.reduce_mean(exp_loss_eval) 171 | 172 | elif method == 'deterministic_softsort': 173 | with tf.variable_scope('phi'): 174 | scores = multi_mnist_cnn.deepnn(l, X, 1) 175 | scores = tf.reshape(scores, [M, n, 1]) 176 | 177 | P_hat = util.softsort(scores, temp, power) 178 | P_hat_eval = util.softsort(scores, 1e-20, power) 179 | 180 | with tf.variable_scope('theta'): 181 | regression_candidates = multi_mnist_cnn.deepnn(l, X, 1) 182 | regression_candidates = tf.reshape( 183 | regression_candidates, [M, n]) 184 | 185 | losses = tf.squared_difference( 186 | regression_candidates, tf.expand_dims(y, 1)) 187 | prob_median = get_median_probs(P_hat) 188 | prob_median_eval = get_median_probs(P_hat_eval) 189 | 190 | point_estimates = tf.reduce_sum( 191 | prob_median * regression_candidates, axis=1) 192 | exp_loss = tf.squared_difference(y, point_estimates) 193 | 194 | point_estimates_eval = tf.reduce_sum( 195 | prob_median_eval * regression_candidates, axis=1) 196 | exp_loss_eval = tf.squared_difference(y, point_estimates) 197 | 198 | loss_phi = tf.reduce_mean(exp_loss) 199 | loss_theta = tf.reduce_mean(exp_loss_eval) 200 | 201 | elif method == 'stochastic_neuralsort': 202 | with tf.variable_scope('phi'): 203 | scores = multi_mnist_cnn.deepnn(l, X, 1) 204 | scores = tf.reshape(scores, [M, n, 1]) 205 | scores = tf.tile(scores, [n_s, 1, 1]) 206 | scores += util.sample_gumbel([M * n_s, n, 1]) 207 | 208 | P_hat = util.neuralsort(scores, temp) 209 | P_hat_eval = util.neuralsort(scores, 1e-20) 210 | 211 | with tf.variable_scope('theta'): 212 | regression_candidates = multi_mnist_cnn.deepnn(l, X, 1) 213 | regression_candidates = tf.reshape( 214 | regression_candidates, [M, n]) 215 | 216 | res_y = tf.expand_dims(y, 1) 217 | 218 | losses = tf.squared_difference(regression_candidates, res_y) 219 | 220 | prob_median = get_median_probs(P_hat) 221 | prob_median = tf.reshape(prob_median, [n_s, M, n]) 222 | prob_median_eval = get_median_probs(P_hat_eval) 223 | prob_median_eval = tf.reshape(prob_median_eval, [n_s, M, n]) 224 | 225 | exp_losses = tf.reduce_sum(prob_median * losses, axis=2) 226 | exp_losses_eval = tf.reduce_sum( 227 | prob_median_eval * losses, axis=2) 228 | 229 | point_estimates_eval = tf.reduce_mean(tf.reduce_sum(prob_median_eval * regression_candidates, axis=2), axis=0) 230 | 
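# loss_phi below averages the temperature-relaxed losses over the n_s Gumbel-perturbed
# samples and the minibatch, while loss_theta averages the losses computed from the
# near-hard relaxation (temperature 1e-20) of the same samples.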
231 | loss_phi = tf.reduce_mean(exp_losses) 232 | loss_theta = tf.reduce_mean(exp_losses_eval) 233 | 234 | elif method == 'stochastic_softsort': 235 | with tf.variable_scope('phi'): 236 | scores = multi_mnist_cnn.deepnn(l, X, 1) 237 | scores = tf.reshape(scores, [M, n, 1]) 238 | scores = tf.tile(scores, [n_s, 1, 1]) 239 | scores += util.sample_gumbel([M * n_s, n, 1]) 240 | 241 | P_hat = util.softsort(scores, temp, power) 242 | P_hat_eval = util.softsort(scores, 1e-20, power) 243 | 244 | with tf.variable_scope('theta'): 245 | regression_candidates = multi_mnist_cnn.deepnn(l, X, 1) 246 | regression_candidates = tf.reshape( 247 | regression_candidates, [M, n]) 248 | 249 | res_y = tf.expand_dims(y, 1) 250 | 251 | losses = tf.squared_difference(regression_candidates, res_y) 252 | 253 | prob_median = get_median_probs(P_hat) 254 | prob_median = tf.reshape(prob_median, [n_s, M, n]) 255 | prob_median_eval = get_median_probs(P_hat_eval) 256 | prob_median_eval = tf.reshape(prob_median_eval, [n_s, M, n]) 257 | 258 | exp_losses = tf.reduce_sum(prob_median * losses, axis=2) 259 | exp_losses_eval = tf.reduce_sum( 260 | prob_median_eval * losses, axis=2) 261 | 262 | point_estimates_eval = tf.reduce_mean(tf.reduce_sum(prob_median_eval * regression_candidates, axis=2), axis=0) 263 | 264 | loss_phi = tf.reduce_mean(exp_losses) 265 | loss_theta = tf.reduce_mean(exp_losses_eval) 266 | else: 267 | raise ValueError("No such method.") 268 | 269 | num_losses = M * n_s if method == 'stochastic_neuralsort' \ 270 | or method == 'stochastic_softsort' \ 271 | or method == 'gumbel_sinkhorn' else M 272 | 273 | correctly_identified = tf.reduce_sum( 274 | prob_median_eval * median_scores) / num_losses 275 | 276 | phi = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='phi') 277 | theta = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='theta') 278 | 279 | train_phi = tf.train.AdamOptimizer( 280 | initial_rate).minimize(loss_phi, var_list=phi) 281 | 282 | if method != 'vanilla': 283 | train_theta = tf.train.AdamOptimizer(initial_rate).minimize( 284 | loss_phi, var_list=theta) 285 | train_step = tf.group(train_phi, train_theta) 286 | else: 287 | train_step = train_phi 288 | 289 | saver = tf.train.Saver() 290 | 291 | sess = tf.Session() 292 | logfile = open('./logs/%s.log' % experiment_id, 'w') 293 | 294 | 295 | def prnt(*args): 296 | print(*args) 297 | print(*args, file=logfile) 298 | 299 | 300 | sess.run(tf.global_variables_initializer()) 301 | train_sh, validate_sh, test_sh = sess.run([ 302 | train_iterator.string_handle(), 303 | val_iterator.string_handle(), 304 | test_iterator.string_handle() 305 | ]) 306 | 307 | TRAIN_PER_EPOCH = mnist_input.TRAIN_SET_SIZE // (l * M) 308 | VAL_PER_EPOCH = mnist_input.VAL_SET_SIZE // (l * M) 309 | TEST_PER_EPOCH = mnist_input.TEST_SET_SIZE // (l * M) 310 | best_val = float('inf') 311 | tiebreaker_val = -1 312 | 313 | 314 | def save_model(epoch): 315 | saver.save(sess, checkpoint_path + 'checkpoint', global_step=epoch) 316 | 317 | 318 | def load_model(): 319 | filename = tf.train.latest_checkpoint(checkpoint_path) 320 | if filename is None: 321 | raise Exception("No model found.") 322 | print("Loaded model %s." 
% filename) 323 | saver.restore(sess, filename) 324 | 325 | 326 | def train(epoch): 327 | loss_train = [] 328 | for _ in range(TRAIN_PER_EPOCH): 329 | _, l = sess.run([train_step, loss_phi], 330 | feed_dict={handle: train_sh}) 331 | loss_train.append(l) 332 | prnt('Average loss:', sum(loss_train) / len(loss_train)) 333 | 334 | 335 | def test(epoch, val=False): 336 | global best_val 337 | c_is = [] 338 | l_vs = [] 339 | y_evals = [] 340 | point_estimates_eval_evals = [] 341 | for _ in range(VAL_PER_EPOCH if val else TEST_PER_EPOCH): 342 | if method.startswith('deterministic'): 343 | c_i, l_v, y_eval, point_estimates_eval_eval =\ 344 | sess.run([correctly_identified, loss_phi, y, point_estimates_eval], feed_dict={ 345 | handle: validate_sh if val else test_sh, evaluation: True}) 346 | elif method.startswith('stochastic'): 347 | c_i, l_v, y_eval, point_estimates_eval_eval =\ 348 | sess.run([correctly_identified, loss_phi, res_y, point_estimates_eval], feed_dict={ 349 | handle: validate_sh if val else test_sh, evaluation: True}) 350 | else: 351 | raise ValueError('Cannot handle other methods because I need their prediction tensors and they are ' 352 | 'named differently.') 353 | c_is.append(c_i) 354 | l_vs.append(l_v) 355 | y_evals.append(y_eval.reshape(-1)) 356 | point_estimates_eval_evals.append(point_estimates_eval_eval.reshape(-1)) 357 | y_eval = np.concatenate(y_evals) 358 | point_estimates_eval_eval = np.concatenate(point_estimates_eval_evals) 359 | id_suffix = "_N_%s_%s_TAU_%s_LR_%s_E_%s_REP_%s.txt" % ( 360 | str(n), str(method), str(tau), str(initial_rate), str(NUM_EPOCHS), str(repetition)) 361 | if not val: 362 | np.savetxt(predictions_path + 'y_eval' + id_suffix, y_eval) 363 | np.savetxt(predictions_path + 'point_estimates_eval_eval' + id_suffix, point_estimates_eval_eval) 364 | 365 | c_i = sum(c_is) / len(c_is) 366 | l_v = sum(l_vs) / len(l_vs) 367 | r2 = r2_score(y_eval, point_estimates_eval_eval) 368 | spearman_r = spearmanr(y_eval, point_estimates_eval_eval).correlation 369 | 370 | if val: 371 | prnt("Validation set: correctly identified %f, mean squared error %f, R2 %f, spearmanr %f" % 372 | (c_i, l_v, r2, spearman_r)) 373 | if l_v < best_val: 374 | best_val = l_v 375 | prnt('Saving...') 376 | save_model(epoch) 377 | else: 378 | prnt("Test set: correctly identified %f, mean squared error %f, R2 %f, spearmanr %f" % 379 | (c_i, l_v, r2, spearman_r)) 380 | 381 | 382 | total_training_time = 0 383 | for epoch in range(1, NUM_EPOCHS + 1): 384 | prnt('Epoch', epoch, '(%s)' % experiment_id) 385 | start_time = time.time() 386 | train(epoch) 387 | end_time = time.time() 388 | total_training_time += (end_time - start_time) 389 | test(epoch, val=True) 390 | logfile.flush() 391 | load_model() 392 | test(epoch, val=False) 393 | training_time_per_epoch = total_training_time / NUM_EPOCHS 394 | print("total_training_time: %f" % total_training_time) 395 | print("training_time_per_epoch: %f" % training_time_per_epoch) 396 | 397 | sess.close() 398 | logfile.close() 399 | -------------------------------------------------------------------------------- /tf/run_median.sh: -------------------------------------------------------------------------------- 1 | ROOT_DIR=run_median_results 2 | mkdir -p ${ROOT_DIR} 3 | 4 | NUM_EPOCHS=100 5 | L=4 6 | M=5 7 | for N in 5 8 | do 9 | for METHOD in deterministic_neuralsort 10 | do 11 | GRID_SEARCH_RESULTS_DIR="N_${N}_${METHOD}" 12 | mkdir ${ROOT_DIR}/${GRID_SEARCH_RESULTS_DIR} 13 | for LR in 0.001 14 | do 15 | for TAU in 1024 16 | do 17 | for REPETITION in 0 18 | 
do 19 | OUTPUT_FILE="${ROOT_DIR}/${GRID_SEARCH_RESULTS_DIR}/N_${N}_${METHOD}_TAU_${TAU}_LR_${LR}_E_${NUM_EPOCHS}_REP_${REPETITION}.txt" 20 | python3 run_median.py --num_epochs ${NUM_EPOCHS} --l=${L} --lr=${LR} --M=${M} --n=${N} --tau=${TAU} --method=${METHOD} --repetition=${REPETITION} 2>&1 | tee ${OUTPUT_FILE} 21 | done 22 | done 23 | done 24 | done 25 | done 26 | 27 | NUM_EPOCHS=100 28 | L=4 29 | M=5 30 | for N in 5 31 | do 32 | for METHOD in stochastic_neuralsort 33 | do 34 | GRID_SEARCH_RESULTS_DIR="N_${N}_${METHOD}" 35 | mkdir ${ROOT_DIR}/${GRID_SEARCH_RESULTS_DIR} 36 | for LR in 0.001 37 | do 38 | for TAU in 2048 39 | do 40 | for REPETITION in 0 41 | do 42 | OUTPUT_FILE="${ROOT_DIR}/${GRID_SEARCH_RESULTS_DIR}/N_${N}_${METHOD}_TAU_${TAU}_LR_${LR}_E_${NUM_EPOCHS}_REP_${REPETITION}.txt" 43 | python3 run_median.py --num_epochs ${NUM_EPOCHS} --l=${L} --lr=${LR} --M=${M} --n=${N} --tau=${TAU} --method=${METHOD} --repetition=${REPETITION} 2>&1 | tee ${OUTPUT_FILE} 44 | done 45 | done 46 | done 47 | done 48 | done 49 | 50 | NUM_EPOCHS=100 51 | L=4 52 | M=5 53 | for N in 5 54 | do 55 | for METHOD in deterministic_softsort 56 | do 57 | GRID_SEARCH_RESULTS_DIR="N_${N}_${METHOD}" 58 | mkdir ${ROOT_DIR}/${GRID_SEARCH_RESULTS_DIR} 59 | for LR in 0.001 60 | do 61 | for TAU in 2048 62 | do 63 | for REPETITION in 0 64 | do 65 | OUTPUT_FILE="${ROOT_DIR}/${GRID_SEARCH_RESULTS_DIR}/N_${N}_${METHOD}_TAU_${TAU}_LR_${LR}_E_${NUM_EPOCHS}_REP_${REPETITION}.txt" 66 | python3 run_median.py --num_epochs ${NUM_EPOCHS} --l=${L} --lr=${LR} --M=${M} --n=${N} --tau=${TAU} --method=${METHOD} --repetition=${REPETITION} 2>&1 | tee ${OUTPUT_FILE} 67 | done 68 | done 69 | done 70 | done 71 | done 72 | 73 | NUM_EPOCHS=100 74 | L=4 75 | M=5 76 | for N in 5 77 | do 78 | for METHOD in stochastic_softsort 79 | do 80 | GRID_SEARCH_RESULTS_DIR="N_${N}_${METHOD}" 81 | mkdir ${ROOT_DIR}/${GRID_SEARCH_RESULTS_DIR} 82 | for LR in 0.001 83 | do 84 | for TAU in 4096 85 | do 86 | for REPETITION in 0 87 | do 88 | OUTPUT_FILE="${ROOT_DIR}/${GRID_SEARCH_RESULTS_DIR}/N_${N}_${METHOD}_TAU_${TAU}_LR_${LR}_E_${NUM_EPOCHS}_REP_${REPETITION}.txt" 89 | python3 run_median.py --num_epochs ${NUM_EPOCHS} --l=${L} --lr=${LR} --M=${M} --n=${N} --tau=${TAU} --method=${METHOD} --repetition=${REPETITION} 2>&1 | tee ${OUTPUT_FILE} 90 | done 91 | done 92 | done 93 | done 94 | done 95 | 96 | NUM_EPOCHS=100 97 | L=4 98 | M=5 99 | for N in 9 100 | do 101 | for METHOD in deterministic_neuralsort 102 | do 103 | GRID_SEARCH_RESULTS_DIR="N_${N}_${METHOD}" 104 | mkdir ${ROOT_DIR}/${GRID_SEARCH_RESULTS_DIR} 105 | for LR in 0.001 106 | do 107 | for TAU in 512 108 | do 109 | for REPETITION in 0 110 | do 111 | OUTPUT_FILE="${ROOT_DIR}/${GRID_SEARCH_RESULTS_DIR}/N_${N}_${METHOD}_TAU_${TAU}_LR_${LR}_E_${NUM_EPOCHS}_REP_${REPETITION}.txt" 112 | python3 run_median.py --num_epochs ${NUM_EPOCHS} --l=${L} --lr=${LR} --M=${M} --n=${N} --tau=${TAU} --method=${METHOD} --repetition=${REPETITION} 2>&1 | tee ${OUTPUT_FILE} 113 | done 114 | done 115 | done 116 | done 117 | done 118 | 119 | NUM_EPOCHS=100 120 | L=4 121 | M=5 122 | for N in 9 123 | do 124 | for METHOD in stochastic_neuralsort 125 | do 126 | GRID_SEARCH_RESULTS_DIR="N_${N}_${METHOD}" 127 | mkdir ${ROOT_DIR}/${GRID_SEARCH_RESULTS_DIR} 128 | for LR in 0.001 129 | do 130 | for TAU in 512 131 | do 132 | for REPETITION in 0 133 | do 134 | OUTPUT_FILE="${ROOT_DIR}/${GRID_SEARCH_RESULTS_DIR}/N_${N}_${METHOD}_TAU_${TAU}_LR_${LR}_E_${NUM_EPOCHS}_REP_${REPETITION}.txt" 135 | python3 run_median.py --num_epochs ${NUM_EPOCHS} 
--l=${L} --lr=${LR} --M=${M} --n=${N} --tau=${TAU} --method=${METHOD} --repetition=${REPETITION} 2>&1 | tee ${OUTPUT_FILE} 136 | done 137 | done 138 | done 139 | done 140 | done 141 | 142 | NUM_EPOCHS=100 143 | L=4 144 | M=5 145 | for N in 9 146 | do 147 | for METHOD in deterministic_softsort 148 | do 149 | GRID_SEARCH_RESULTS_DIR="N_${N}_${METHOD}" 150 | mkdir ${ROOT_DIR}/${GRID_SEARCH_RESULTS_DIR} 151 | for LR in 0.001 152 | do 153 | for TAU in 2048 154 | do 155 | for REPETITION in 0 156 | do 157 | OUTPUT_FILE="${ROOT_DIR}/${GRID_SEARCH_RESULTS_DIR}/N_${N}_${METHOD}_TAU_${TAU}_LR_${LR}_E_${NUM_EPOCHS}_REP_${REPETITION}.txt" 158 | python3 run_median.py --num_epochs ${NUM_EPOCHS} --l=${L} --lr=${LR} --M=${M} --n=${N} --tau=${TAU} --method=${METHOD} --repetition=${REPETITION} 2>&1 | tee ${OUTPUT_FILE} 159 | done 160 | done 161 | done 162 | done 163 | done 164 | 165 | NUM_EPOCHS=100 166 | L=4 167 | M=5 168 | for N in 9 169 | do 170 | for METHOD in stochastic_softsort 171 | do 172 | GRID_SEARCH_RESULTS_DIR="N_${N}_${METHOD}" 173 | mkdir ${ROOT_DIR}/${GRID_SEARCH_RESULTS_DIR} 174 | for LR in 0.001 175 | do 176 | for TAU in 2048 177 | do 178 | for REPETITION in 0 179 | do 180 | OUTPUT_FILE="${ROOT_DIR}/${GRID_SEARCH_RESULTS_DIR}/N_${N}_${METHOD}_TAU_${TAU}_LR_${LR}_E_${NUM_EPOCHS}_REP_${REPETITION}.txt" 181 | python3 run_median.py --num_epochs ${NUM_EPOCHS} --l=${L} --lr=${LR} --M=${M} --n=${N} --tau=${TAU} --method=${METHOD} --repetition=${REPETITION} 2>&1 | tee ${OUTPUT_FILE} 182 | done 183 | done 184 | done 185 | done 186 | done 187 | 188 | NUM_EPOCHS=100 189 | L=4 190 | M=5 191 | for N in 15 192 | do 193 | for METHOD in deterministic_neuralsort 194 | do 195 | GRID_SEARCH_RESULTS_DIR="N_${N}_${METHOD}" 196 | mkdir ${ROOT_DIR}/${GRID_SEARCH_RESULTS_DIR} 197 | for LR in 0.001 198 | do 199 | for TAU in 1024 200 | do 201 | for REPETITION in 0 202 | do 203 | OUTPUT_FILE="${ROOT_DIR}/${GRID_SEARCH_RESULTS_DIR}/N_${N}_${METHOD}_TAU_${TAU}_LR_${LR}_E_${NUM_EPOCHS}_REP_${REPETITION}.txt" 204 | python3 run_median.py --num_epochs ${NUM_EPOCHS} --l=${L} --lr=${LR} --M=${M} --n=${N} --tau=${TAU} --method=${METHOD} --repetition=${REPETITION} 2>&1 | tee ${OUTPUT_FILE} 205 | done 206 | done 207 | done 208 | done 209 | done 210 | 211 | NUM_EPOCHS=100 212 | L=4 213 | M=5 214 | for N in 15 215 | do 216 | for METHOD in stochastic_neuralsort 217 | do 218 | GRID_SEARCH_RESULTS_DIR="N_${N}_${METHOD}" 219 | mkdir ${ROOT_DIR}/${GRID_SEARCH_RESULTS_DIR} 220 | for LR in 0.001 221 | do 222 | for TAU in 4096 223 | do 224 | for REPETITION in 0 225 | do 226 | OUTPUT_FILE="${ROOT_DIR}/${GRID_SEARCH_RESULTS_DIR}/N_${N}_${METHOD}_TAU_${TAU}_LR_${LR}_E_${NUM_EPOCHS}_REP_${REPETITION}.txt" 227 | python3 run_median.py --num_epochs ${NUM_EPOCHS} --l=${L} --lr=${LR} --M=${M} --n=${N} --tau=${TAU} --method=${METHOD} --repetition=${REPETITION} 2>&1 | tee ${OUTPUT_FILE} 228 | done 229 | done 230 | done 231 | done 232 | done 233 | 234 | NUM_EPOCHS=100 235 | L=4 236 | M=5 237 | for N in 15 238 | do 239 | for METHOD in deterministic_softsort 240 | do 241 | GRID_SEARCH_RESULTS_DIR="N_${N}_${METHOD}" 242 | mkdir ${ROOT_DIR}/${GRID_SEARCH_RESULTS_DIR} 243 | for LR in 0.001 244 | do 245 | for TAU in 256 246 | do 247 | for REPETITION in 0 248 | do 249 | OUTPUT_FILE="${ROOT_DIR}/${GRID_SEARCH_RESULTS_DIR}/N_${N}_${METHOD}_TAU_${TAU}_LR_${LR}_E_${NUM_EPOCHS}_REP_${REPETITION}.txt" 250 | python3 run_median.py --num_epochs ${NUM_EPOCHS} --l=${L} --lr=${LR} --M=${M} --n=${N} --tau=${TAU} --method=${METHOD} --repetition=${REPETITION} 2>&1 | tee 
${OUTPUT_FILE} 251 | done 252 | done 253 | done 254 | done 255 | done 256 | 257 | NUM_EPOCHS=100 258 | L=4 259 | M=5 260 | for N in 15 261 | do 262 | for METHOD in stochastic_softsort 263 | do 264 | GRID_SEARCH_RESULTS_DIR="N_${N}_${METHOD}" 265 | mkdir ${ROOT_DIR}/${GRID_SEARCH_RESULTS_DIR} 266 | for LR in 0.001 267 | do 268 | for TAU in 2048 269 | do 270 | for REPETITION in 0 271 | do 272 | OUTPUT_FILE="${ROOT_DIR}/${GRID_SEARCH_RESULTS_DIR}/N_${N}_${METHOD}_TAU_${TAU}_LR_${LR}_E_${NUM_EPOCHS}_REP_${REPETITION}.txt" 273 | python3 run_median.py --num_epochs ${NUM_EPOCHS} --l=${L} --lr=${LR} --M=${M} --n=${N} --tau=${TAU} --method=${METHOD} --repetition=${REPETITION} 2>&1 | tee ${OUTPUT_FILE} 274 | done 275 | done 276 | done 277 | done 278 | done 279 | 280 | -------------------------------------------------------------------------------- /tf/run_median_learning_curves.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | from run_median_table_of_results import RunMedianResultsParser 5 | 6 | 7 | def get_filename(n, method, tau, lr, num_epochs, repetition): 8 | filename = "./run_median_results/N_%s_%s/N_%s_%s_TAU_%s_LR_%s_E_%s_REP_%s.txt" %\ 9 | (n, method, n, method, tau, lr, num_epochs, repetition) 10 | return filename 11 | 12 | 13 | def get_learning_curves(n, method, tau, lr, num_epochs, repetitions): 14 | average_losses = np.zeros(shape=(num_epochs, len(repetitions))) 15 | val_set_correctly_identifieds = np.zeros(shape=(num_epochs, len(repetitions))) 16 | val_set_mean_squared_errors = np.zeros(shape=(num_epochs, len(repetitions))) 17 | val_set_r2s = np.zeros(shape=(num_epochs, len(repetitions))) 18 | val_set_spearmanrs = np.zeros(shape=(num_epochs, len(repetitions))) 19 | for r_id, repetition in enumerate(repetitions): 20 | filename = get_filename(n, method, tau, lr, num_epochs, repetition) 21 | parser = RunMedianResultsParser() 22 | parser.parse(filename, expected_length=num_epochs) 23 | for i in range(num_epochs): 24 | average_losses[i, r_id] = float(parser.average_loss[i]) 25 | val_set_correctly_identifieds[i, r_id] = float(parser.val_set_correctly_identified[i]) 26 | val_set_mean_squared_errors[i, r_id] = float(parser.val_set_mean_squared_error[i]) 27 | val_set_r2s[i, r_id] = float(parser.val_set_r2[i]) 28 | val_set_spearmanrs[i, r_id] = float(parser.val_set_spearmanr[i]) 29 | 30 | average_losses_mean = average_losses.mean(axis=1) 31 | average_losses_std = average_losses.std(axis=1) 32 | 33 | val_set_correctly_identifieds_mean = val_set_correctly_identifieds.mean(axis=1) 34 | val_set_correctly_identifieds_std = val_set_correctly_identifieds.std(axis=1) 35 | 36 | val_set_mean_squared_errors_mean = val_set_mean_squared_errors.mean(axis=1) 37 | val_set_mean_squared_errors_std = val_set_mean_squared_errors.std(axis=1) 38 | 39 | val_set_r2s_mean = val_set_r2s.mean(axis=1) 40 | val_set_r2s_std = val_set_r2s.std(axis=1) 41 | 42 | val_set_spearmanrs_mean = val_set_spearmanrs.mean(axis=1) 43 | val_set_spearmanrs_std = val_set_spearmanrs.std(axis=1) 44 | 45 | return average_losses_mean, average_losses_std, val_set_correctly_identifieds_mean,\ 46 | val_set_correctly_identifieds_std, val_set_mean_squared_errors_mean, val_set_mean_squared_errors_std,\ 47 | val_set_r2s_mean, val_set_r2s_std, val_set_spearmanrs_mean, val_set_spearmanrs_std 48 | 49 | 50 | ns = ['5', '9', '15'] 51 | lr = '0.001' 52 | num_epochs = 100 53 | repetitions = ['0'] 54 | 55 | 56 | def get_tau_for_n_and_method(n, method): 57 | ns 
= ['5', '9', '15'] 58 | methods = ['deterministic_neuralsort', 'stochastic_neuralsort', 'deterministic_softsort', 'stochastic_softsort'] 59 | taus = [ 60 | ['1024', '2048', '2048', '4096'], 61 | ['512', '512', '2048', '2048'], 62 | ['1024', '4096', '256', '2048']] 63 | 64 | for col, method2 in enumerate(methods): 65 | if method2 == method: 66 | for row, n2 in enumerate(ns): 67 | if n2 == n: 68 | return taus[row][col] 69 | 70 | 71 | for n in ns: 72 | if n != '15': 73 | continue 74 | for det_or_stoch in ["deterministic", "stochastic"]: 75 | if det_or_stoch != "deterministic": 76 | continue 77 | average_losses_mean = {} 78 | average_losses_std = {} 79 | 80 | val_set_correctly_identifieds_mean = {} 81 | val_set_correctly_identifieds_std = {} 82 | 83 | val_set_mean_squared_errors_mean = {} 84 | val_set_mean_squared_errors_std = {} 85 | 86 | val_set_r2s_mean = {} 87 | val_set_r2s_std = {} 88 | 89 | val_set_spearmanrs_mean = {} 90 | val_set_spearmanrs_std = {} 91 | 92 | for method in [det_or_stoch + "_softsort", det_or_stoch + "_neuralsort"]: 93 | tau = get_tau_for_n_and_method(n, method) 94 | average_losses_mean[method],\ 95 | average_losses_std[method],\ 96 | val_set_correctly_identifieds_mean[method],\ 97 | val_set_correctly_identifieds_std[method],\ 98 | val_set_mean_squared_errors_mean[method],\ 99 | val_set_mean_squared_errors_std[method],\ 100 | val_set_r2s_mean[method],\ 101 | val_set_r2s_std[method],\ 102 | val_set_spearmanrs_mean[method],\ 103 | val_set_spearmanrs_std[method] = get_learning_curves(n, method, tau, lr, num_epochs, repetitions) 104 | 105 | title_prefix = f"N = {n}\n" 106 | title_suffix = "\nTraining curve" 107 | 108 | plt.figure(figsize=(7, 5)) 109 | plt.plot(1e-6 * average_losses_mean[det_or_stoch + "_neuralsort"], label="NeuralSort", color='red', 110 | linestyle='--') 111 | plt.plot(1e-6 * average_losses_mean[det_or_stoch + "_softsort"], label="SoftSort", color='blue', 112 | linestyle='-') 113 | fontsize = 17 114 | plt.ylabel(r'Loss ($\times 10^{-6}$)', fontsize=fontsize) 115 | plt.xlabel('Epoch', fontsize=fontsize) 116 | plt.xticks(fontsize=fontsize) 117 | plt.yticks(fontsize=fontsize) 118 | plt.legend(fontsize=fontsize) 119 | plt.tight_layout() 120 | plt.savefig("../images/run_median_learning_curve.png") 121 | -------------------------------------------------------------------------------- /tf/run_median_table_of_results.py: -------------------------------------------------------------------------------- 1 | from statistics import mean 2 | 3 | 4 | class RunMedianResultsParser: 5 | r''' 6 | Parses an individual results (i.e. log) file and stores the results. 7 | ''' 8 | def __init__(self): 9 | self.average_loss = [] 10 | self.val_set_correctly_identified = [] 11 | self.val_set_mean_squared_error = [] 12 | self.val_set_r2 = [] 13 | self.val_set_spearmanr = [] 14 | self.test_set_correctly_identified = -1 15 | self.test_set_mean_squared_error = -1 16 | self.test_set_r2 = -1 17 | self.test_set_spearmanr = -1 18 | 19 | def parse(self, file_path, expected_length=None): 20 | r''' 21 | :param file_path: path to the results (i.e. 
log) file 22 | ''' 23 | with open(file_path) as file: 24 | for line in file: 25 | line_tokens = line.replace(',', '').replace('\n', '').split(' ') 26 | if line.startswith("Average loss"): 27 | self.average_loss.append(line_tokens[2][:8]) 28 | elif line.startswith("Validation set"): 29 | self.val_set_correctly_identified.append(line_tokens[4][:8]) 30 | self.val_set_mean_squared_error.append(line_tokens[8][:8]) 31 | self.val_set_r2.append(line_tokens[10][:8]) 32 | self.val_set_spearmanr.append(line_tokens[12][:8]) 33 | elif line.startswith("Test set"): 34 | # print(line_tokens) 35 | self.test_set_correctly_identified = line_tokens[4][:8] 36 | self.test_set_mean_squared_error = line_tokens[8][:8] 37 | self.test_set_r2 = line_tokens[10][:8] 38 | self.test_set_spearmanr = line_tokens[12][:8] 39 | # print("file_path = %s" % file_path) 40 | if expected_length: 41 | assert(len(self.val_set_correctly_identified) == expected_length) 42 | assert(len(self.val_set_mean_squared_error) == expected_length) 43 | assert(len(self.val_set_r2) == expected_length) 44 | assert(len(self.val_set_spearmanr) == expected_length) 45 | # Validate data 46 | for list_name in ["average_loss", "val_set_correctly_identified", "val_set_mean_squared_error", 47 | "val_set_r2", "val_set_spearmanr"]: 48 | for i, elem in enumerate(self.__dict__[list_name]): 49 | try: 50 | float(elem) 51 | except ValueError: 52 | print(f"path:\n{file_path}\n{i}: list {list_name} contains non-float: {elem}") 53 | raise ValueError 54 | 55 | def get_val_set_correctly_identified(self): 56 | return self.val_set_correctly_identified[-1] 57 | 58 | def get_val_set_mean_squared_error(self): 59 | return self.val_set_mean_squared_error[-1] 60 | 61 | def get_val_set_r2(self): 62 | return self.val_set_r2[-1] 63 | 64 | def get_val_set_spearmanr(self): 65 | return self.val_set_spearmanr[-1] 66 | 67 | def get_test_set_correctly_identified(self): 68 | return self.test_set_correctly_identified 69 | 70 | def get_test_set_mean_squared_error(self): 71 | return self.test_set_mean_squared_error 72 | 73 | def get_test_set_r2(self): 74 | return self.test_set_r2 75 | 76 | def get_test_set_spearmanr(self): 77 | return self.test_set_spearmanr 78 | 79 | 80 | num_epochs = '100' 81 | l = '4' 82 | m = '20' 83 | ns = ['5', '9', '15'] 84 | methods = ['deterministic_neuralsort', 'stochastic_neuralsort', 'deterministic_softsort', 'stochastic_softsort'] 85 | lr = '0.001' 86 | taus = [ 87 | ['1024', '2048', '2048', '4096'], 88 | ['512', '512', '2048', '2048'], 89 | ['1024', '4096', '256', '2048']] 90 | repetitions = ['0'] 91 | 92 | res = dict() 93 | 94 | for n, taus_for_each_method in zip(ns, taus): 95 | for method, tau in zip(methods, taus_for_each_method): 96 | val_set_correctly_identified = [] 97 | val_set_mean_squared_error = [] 98 | val_set_r2 = [] 99 | val_set_spearmanr = [] 100 | test_set_correctly_identified = [] 101 | test_set_mean_squared_error = [] 102 | test_set_r2 = [] 103 | test_set_spearmanr = [] 104 | for repetition in repetitions: 105 | filename = "./run_median_results/N_%s_%s/N_%s_%s_TAU_%s_LR_%s_E_%s_REP_%s.txt" %\ 106 | (n, method, n, method, tau, lr, num_epochs, repetition) 107 | # print("Processing " + str(filename)) 108 | results_parser = RunMedianResultsParser() 109 | results_parser.parse(filename, expected_length=int(num_epochs)) 110 | val_set_correctly_identified.append(float(results_parser.get_val_set_correctly_identified())) 111 | val_set_mean_squared_error.append(float(results_parser.get_val_set_mean_squared_error())) 112 | 
val_set_r2.append(float(results_parser.get_val_set_r2())) 113 | val_set_spearmanr.append(float(results_parser.get_val_set_spearmanr())) 114 | test_set_correctly_identified.append(float(results_parser.get_test_set_correctly_identified())) 115 | test_set_mean_squared_error.append(float(results_parser.get_test_set_mean_squared_error())) 116 | test_set_r2.append(float(results_parser.get_test_set_r2())) 117 | test_set_spearmanr.append(float(results_parser.get_test_set_spearmanr())) 118 | res[(n, tau, method, 'test_set_correctly_identified')] = mean(test_set_correctly_identified) 119 | res[(n, tau, method, 'test_set_mean_squared_error')] = mean(test_set_mean_squared_error) 120 | res[(n, tau, method, 'test_set_r2')] = mean(test_set_r2) 121 | res[(n, tau, method, 'test_set_spearmanr')] = mean(test_set_spearmanr) 122 | 123 | 124 | def pretty_print_table(table): 125 | r''' 126 | Pretty prints the given table (of size (1 + #methods) x (1 + #ns)) 127 | ''' 128 | res = "" 129 | nrow = len(table) 130 | ncol = len(table[0]) 131 | # Print header 132 | header = table[0] 133 | for c in range(ncol): 134 | if c == 0: 135 | res += "{:<31}".format('') 136 | else: 137 | res += "| n = " + "{:<10}".format(header[c]) 138 | res += "\n" 139 | for r in range(1, nrow): 140 | # Method name 141 | res += "{:<31}".format(table[r][0]) 142 | for c in range(1, ncol): 143 | res += "| " + table[r][c].replace('\\', '') + " " 144 | res += "\n" 145 | print(res) 146 | 147 | 148 | def pretty_print_table_latex(table): 149 | r''' 150 | Pretty prints the given table (of size (1 + #methods) x (1 + #ns)) 151 | ''' 152 | algorithm_names = { 153 | "deterministic_neuralsort": "Deterministic NeuralSort", 154 | "stochastic_neuralsort": "Stochastic NeuralSort", 155 | "deterministic_softsort": "Deterministic SoftSort", 156 | "stochastic_softsort": "Stochastic SoftSort" 157 | } 158 | res = "" \ 159 | "\\begin{tabular}{lccc}\n" \ 160 | "\\toprule\n" 161 | nrow = len(table) 162 | ncol = len(table[0]) 163 | # Print header 164 | header = table[0] 165 | for c in range(ncol): 166 | if c == 0: 167 | res += "Algorithm " 168 | else: 169 | res += "& $n = " + "{}$ ".format(header[c]) 170 | res += "\\\\\n" 171 | res += "\\midrule\n" 172 | for r in range(1, nrow): 173 | # Method name 174 | res += "{} ".format(algorithm_names[table[r][0]]) 175 | for c in range(1, ncol): 176 | res += "& $" + table[r][c] + "$ " 177 | res += "\\\\\n" 178 | res += "" \ 179 | "\\bottomrule\n" \ 180 | "\\end{tabular}\n" 181 | print(res) 182 | 183 | 184 | def print_table(latex=False): 185 | table = [] 186 | # Add table header 187 | header = ["algorithm"] + [n for n in ns] 188 | table.append(header) 189 | for i, method in enumerate(methods): 190 | row = [method] 191 | for j, n in enumerate(ns): 192 | tau = taus[j][i] 193 | test_set_mean_squared_error = res[(n, tau, method, 'test_set_mean_squared_error')] 194 | test_set_spearmanr = res[(n, tau, method, 'test_set_spearmanr')] 195 | table_entry = "%.2f\\ (%.2f)" % (test_set_mean_squared_error * 1e-4, test_set_spearmanr) 196 | row.append(table_entry) 197 | table.append(row) 198 | if latex: 199 | pretty_print_table_latex(table) 200 | else: 201 | pretty_print_table(table) 202 | 203 | 204 | print_table(latex=True) 205 | print_table(latex=False) 206 | -------------------------------------------------------------------------------- /tf/run_sort.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import time 3 | import mnist_input 4 | import multi_mnist_cnn 5 | from sinkhorn 
import gumbel_sinkhorn, sinkhorn_operator 6 | 7 | import util 8 | import random 9 | 10 | tf.set_random_seed(94305) 11 | random.seed(94305) 12 | 13 | flags = tf.app.flags 14 | flags.DEFINE_integer('M', 1, 'batch size') 15 | flags.DEFINE_integer('n', 3, 'number of elements to compare at a time') 16 | flags.DEFINE_integer('l', 4, 'number of digits') 17 | flags.DEFINE_integer('tau', 5, 'temperature (dependent meaning)') 18 | flags.DEFINE_string('method', 'deterministic_neuralsort', 19 | 'which method to use?') 20 | flags.DEFINE_integer('n_s', 5, 'number of samples') 21 | flags.DEFINE_integer('num_epochs', 200, 'number of epochs to train') 22 | flags.DEFINE_float('lr', 1e-4, 'initial learning rate') 23 | 24 | FLAGS = flags.FLAGS 25 | 26 | n_s = FLAGS.n_s 27 | NUM_EPOCHS = FLAGS.num_epochs 28 | M = FLAGS.M 29 | n = FLAGS.n 30 | l = FLAGS.l 31 | tau = FLAGS.tau 32 | method = FLAGS.method 33 | initial_rate = FLAGS.lr 34 | 35 | train_iterator, val_iterator, test_iterator = mnist_input.get_iterators( 36 | l, n, 10 ** l - 1, minibatch_size=M) 37 | 38 | false_tensor = tf.convert_to_tensor(False) 39 | evaluation = tf.placeholder_with_default(false_tensor, ()) 40 | temperature = tf.cond(evaluation, 41 | false_fn=lambda: tf.convert_to_tensor( 42 | tau, dtype=tf.float32), 43 | true_fn=lambda: tf.convert_to_tensor( 44 | 1e-10, dtype=tf.float32) # simulate hard sort 45 | ) 46 | 47 | experiment_id = 'sort-%s-M%d-n%d-l%d-t%d' % (method, M, n, l, tau * 10) 48 | checkpoint_path = 'checkpoints/%s/' % experiment_id 49 | 50 | handle = tf.placeholder(tf.string, ()) 51 | X_iterator = tf.data.Iterator.from_string_handle( 52 | handle, 53 | (tf.float32, tf.float32, tf.float32, tf.float32), 54 | ((M, n, l * 28, 28), (M,), (M, n), (M, n)) 55 | ) 56 | 57 | X, y, median_scores, true_scores = X_iterator.get_next() 58 | true_scores = tf.expand_dims(true_scores, 2) 59 | P_true = util.neuralsort(true_scores, 1e-10) 60 | 61 | if method == 'vanilla': 62 | representations = multi_mnist_cnn.deepnn(l, X, n) 63 | concat_reps = tf.reshape(representations, [M, n * n]) 64 | fc1 = tf.layers.dense(concat_reps, n * n) 65 | fc2 = tf.layers.dense(fc1, n * n) 66 | P_hat_raw = tf.layers.dense(fc2, n * n) 67 | P_hat_raw_square = tf.reshape(P_hat_raw, [M, n, n]) 68 | 69 | P_hat = tf.nn.softmax(P_hat_raw_square, dim=-1) # row-stochastic! 
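# The 'vanilla' baseline predicts the n x n permutation matrix directly (one softmax
# per row) with no sorting relaxation; the row-wise cross-entropy against P_true
# below is its training loss.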
70 | 71 | losses = tf.nn.softmax_cross_entropy_with_logits_v2( 72 | labels=P_true, logits=P_hat_raw_square, dim=2) 73 | losses = tf.reduce_mean(losses, axis=-1) 74 | loss = tf.reduce_mean(losses) 75 | 76 | elif method == 'sinkhorn': 77 | representations = multi_mnist_cnn.deepnn(l, X, n) 78 | pre_sinkhorn = tf.reshape(representations, [M, n, n]) 79 | P_hat = sinkhorn_operator(pre_sinkhorn, temp=temperature) 80 | P_hat_logit = tf.log(P_hat) 81 | 82 | losses = tf.nn.softmax_cross_entropy_with_logits_v2( 83 | labels=P_true, logits=P_hat_logit, dim=2) 84 | losses = tf.reduce_mean(losses, axis=-1) 85 | loss = tf.reduce_mean(losses) 86 | 87 | elif method == 'gumbel_sinkhorn': 88 | representations = multi_mnist_cnn.deepnn(l, X, n) 89 | pre_sinkhorn = tf.reshape(representations, [M, n, n]) 90 | P_hat = sinkhorn_operator(pre_sinkhorn, temp=temperature) 91 | 92 | P_hat_sample, _ = gumbel_sinkhorn( 93 | pre_sinkhorn, temp=temperature, n_samples=n_s) 94 | P_hat_sample_logit = tf.log(P_hat_sample) 95 | 96 | P_true_sample = tf.expand_dims(P_true, 1) 97 | P_true_sample = tf.tile(P_true_sample, [1, n_s, 1, 1]) 98 | 99 | losses = tf.nn.softmax_cross_entropy_with_logits_v2( 100 | labels=P_true_sample, logits=P_hat_sample_logit, dim=3) 101 | losses = tf.reduce_mean(losses, axis=-1) 102 | losses = tf.reshape(losses, [-1]) 103 | loss = tf.reduce_mean(losses) 104 | 105 | elif method == 'deterministic_neuralsort': 106 | scores = multi_mnist_cnn.deepnn(l, X, 1) 107 | scores = tf.reshape(scores, [M, n, 1]) 108 | P_hat = util.neuralsort(scores, temperature) 109 | 110 | losses = tf.nn.softmax_cross_entropy_with_logits_v2( 111 | labels=P_true, logits=tf.log(P_hat + 1e-20), dim=2) 112 | losses = tf.reduce_mean(losses, axis=-1) 113 | loss = tf.reduce_mean(losses) 114 | 115 | elif method == 'deterministic_softsort': 116 | scores = multi_mnist_cnn.deepnn(l, X, 1) 117 | scores = tf.reshape(scores, [M, n, 1]) 118 | P_hat = util.softsort(scores, temperature) 119 | 120 | losses = tf.nn.softmax_cross_entropy_with_logits_v2( 121 | labels=P_true, logits=tf.log(P_hat + 1e-20), dim=2) 122 | losses = tf.reduce_mean(losses, axis=-1) 123 | loss = tf.reduce_mean(losses) 124 | 125 | elif method == 'stochastic_neuralsort': 126 | scores = multi_mnist_cnn.deepnn(l, X, 1) 127 | scores = tf.reshape(scores, [M, n, 1]) 128 | P_hat = util.neuralsort(scores, temperature) 129 | 130 | scores_sample = tf.tile(scores, [n_s, 1, 1]) 131 | scores_sample += util.sample_gumbel([M * n_s, n, 1]) 132 | P_hat_sample = util.neuralsort( 133 | scores_sample, temperature) 134 | 135 | P_true_sample = tf.tile(P_true, [n_s, 1, 1]) 136 | losses = tf.nn.softmax_cross_entropy_with_logits_v2( 137 | labels=P_true_sample, logits=tf.log(P_hat_sample + 1e-20), dim=2) 138 | losses = tf.reduce_mean(losses, axis=-1) 139 | loss = tf.reduce_mean(losses) 140 | 141 | elif method == 'stochastic_softsort': 142 | scores = multi_mnist_cnn.deepnn(l, X, 1) 143 | scores = tf.reshape(scores, [M, n, 1]) 144 | P_hat = util.softsort(scores, temperature) 145 | 146 | scores_sample = tf.tile(scores, [n_s, 1, 1]) 147 | scores_sample += util.sample_gumbel([M * n_s, n, 1]) 148 | P_hat_sample = util.softsort( 149 | scores_sample, temperature) 150 | 151 | P_true_sample = tf.tile(P_true, [n_s, 1, 1]) 152 | losses = tf.nn.softmax_cross_entropy_with_logits_v2( 153 | labels=P_true_sample, logits=tf.log(P_hat_sample + 1e-20), dim=2) 154 | losses = tf.reduce_mean(losses, axis=-1) 155 | loss = tf.reduce_mean(losses) 156 | else: 157 | raise ValueError("No such method.") 158 | 159 | 160 | def 
vec_gradient(l): # l is a scalar 161 | gradient = tf.gradients(l, tf.trainable_variables()) 162 | vec_grads = [tf.reshape(grad, [-1]) for grad in gradient] # flatten 163 | z = tf.concat(vec_grads, 0) # n_params 164 | return z 165 | 166 | 167 | prop_correct = util.prop_correct(P_true, P_hat) 168 | prop_any_correct = util.prop_any_correct(P_true, P_hat) 169 | 170 | opt = tf.train.AdamOptimizer(initial_rate) 171 | train_step = opt.minimize(loss) 172 | saver = tf.train.Saver() 173 | 174 | # MAIN BEGINS 175 | 176 | sess = tf.Session() 177 | logfile = open('./logs/%s.log' % experiment_id, 'w') 178 | 179 | 180 | def prnt(*args): 181 | print(*args) 182 | print(*args, file=logfile) 183 | 184 | 185 | sess.run(tf.global_variables_initializer()) 186 | train_sh, validate_sh, test_sh = sess.run([ 187 | train_iterator.string_handle(), 188 | val_iterator.string_handle(), 189 | test_iterator.string_handle() 190 | ]) 191 | 192 | 193 | TRAIN_PER_EPOCH = mnist_input.TRAIN_SET_SIZE // (l * M) 194 | VAL_PER_EPOCH = mnist_input.VAL_SET_SIZE // (l * M) 195 | TEST_PER_EPOCH = mnist_input.TEST_SET_SIZE // (l * M) 196 | best_correct_val = 0 197 | 198 | 199 | def save_model(epoch): 200 | saver.save(sess, checkpoint_path + 'checkpoint', global_step=epoch) 201 | 202 | 203 | def load_model(): 204 | filename = tf.train.latest_checkpoint(checkpoint_path) 205 | if filename is None: 206 | raise Exception("No model found.") 207 | prnt("Loaded model %s." % filename) 208 | saver.restore(sess, filename) 209 | 210 | 211 | def train(epoch): 212 | loss_train = [] 213 | for _ in range(TRAIN_PER_EPOCH): 214 | _, l = sess.run([train_step, loss], 215 | feed_dict={handle: train_sh}) 216 | loss_train.append(l) 217 | prnt('Average loss:', sum(loss_train) / len(loss_train)) 218 | 219 | 220 | def test(epoch, val=False): 221 | global best_correct_val 222 | p_cs = [] 223 | p_acs = [] 224 | for _ in range(VAL_PER_EPOCH if val else TEST_PER_EPOCH): 225 | p_c, p_ac = sess.run([prop_correct, prop_any_correct], feed_dict={ 226 | handle: validate_sh if val else test_sh, 227 | evaluation: True}) 228 | p_cs.append(p_c) 229 | p_acs.append(p_ac) 230 | 231 | p_c = sum(p_cs) / len(p_cs) 232 | p_ac = sum(p_acs) / len(p_acs) 233 | 234 | if val: 235 | prnt("Validation set: prop. all correct %f, prop. any correct %f" % 236 | (p_c, p_ac)) 237 | if p_c > best_correct_val: 238 | best_correct_val = p_c 239 | prnt('Saving...') 240 | save_model(epoch) 241 | else: 242 | prnt("Test set: prop. all correct %f, prop. 
any correct %f" % (p_c, p_ac)) 243 | 244 | 245 | total_training_time = 0 246 | for epoch in range(1, NUM_EPOCHS + 1): 247 | prnt('Epoch', epoch, '(%s)' % experiment_id) 248 | start_time = time.time() 249 | train(epoch) 250 | end_time = time.time() 251 | total_training_time += (end_time - start_time) 252 | test(epoch, val=True) 253 | logfile.flush() 254 | load_model() 255 | test(epoch, val=False) 256 | training_time_per_epoch = total_training_time / NUM_EPOCHS 257 | print("total_training_time: %f" % total_training_time) 258 | print("training_time_per_epoch: %f" % training_time_per_epoch) 259 | 260 | sess.close() 261 | logfile.close() 262 | -------------------------------------------------------------------------------- /tf/run_sort.sh: -------------------------------------------------------------------------------- 1 | ROOT_DIR=run_sort_results 2 | mkdir -p ${ROOT_DIR} 3 | 4 | NUM_EPOCHS=100 5 | L=4 6 | M=20 7 | for N in 3 5 7 8 | do 9 | for METHOD in deterministic_neuralsort deterministic_softsort stochastic_neuralsort stochastic_softsort 10 | do 11 | GRID_SEARCH_RESULTS_DIR="N_${N}_${METHOD}" 12 | mkdir ${ROOT_DIR}/${GRID_SEARCH_RESULTS_DIR} 13 | for LR in 0.005 14 | do 15 | for TAU in 1024 16 | do 17 | for REPETITION in 0 1 2 3 4 5 6 7 8 9 18 | do 19 | OUTPUT_FILE="${ROOT_DIR}/${GRID_SEARCH_RESULTS_DIR}/N_${N}_${METHOD}_TAU_${TAU}_LR_${LR}_E_${NUM_EPOCHS}_REP_${REPETITION}.txt" 20 | python3 run_sort.py --num_epochs ${NUM_EPOCHS} --l=${L} --lr=${LR} --M=${M} --n=${N} --tau=${TAU} --method=${METHOD} 2>&1 | tee ${OUTPUT_FILE} 21 | done 22 | done 23 | done 24 | done 25 | done 26 | 27 | NUM_EPOCHS=100 28 | L=4 29 | M=20 30 | for N in 9 15 31 | do 32 | for METHOD in deterministic_neuralsort deterministic_softsort stochastic_neuralsort stochastic_softsort 33 | do 34 | GRID_SEARCH_RESULTS_DIR="N_${N}_${METHOD}" 35 | mkdir ${ROOT_DIR}/${GRID_SEARCH_RESULTS_DIR} 36 | for LR in 0.005 37 | do 38 | for TAU in 128 39 | do 40 | for REPETITION in 0 1 2 3 4 5 6 7 8 9 41 | do 42 | OUTPUT_FILE="${ROOT_DIR}/${GRID_SEARCH_RESULTS_DIR}/N_${N}_${METHOD}_TAU_${TAU}_LR_${LR}_E_${NUM_EPOCHS}_REP_${REPETITION}.txt" 43 | python3 run_sort.py --num_epochs ${NUM_EPOCHS} --l=${L} --lr=${LR} --M=${M} --n=${N} --tau=${TAU} --method=${METHOD} 2>&1 | tee ${OUTPUT_FILE} 44 | done 45 | done 46 | done 47 | done 48 | done -------------------------------------------------------------------------------- /tf/run_sort_learning_curves.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | from run_sort_table_of_results import RunSortResultsParser 5 | 6 | 7 | def get_filename(n, method, tau, lr, num_epochs, repetition): 8 | filename = "./run_sort_results/N_%s_%s/N_%s_%s_TAU_%s_LR_%s_E_%s_REP_%s.txt" %\ 9 | (n, method, n, method, tau, lr, num_epochs, repetition) 10 | return filename 11 | 12 | 13 | def get_learning_curves(n, method, tau, lr, num_epochs, repetitions): 14 | average_losses = np.zeros(shape=(num_epochs, len(repetitions))) 15 | val_set_prop_all_corrects = np.zeros(shape=(num_epochs, len(repetitions))) 16 | val_set_prop_any_corrects = np.zeros(shape=(num_epochs, len(repetitions))) 17 | for r_id, repetition in enumerate(repetitions): 18 | filename = get_filename(n, method, tau, lr, num_epochs, repetition) 19 | parser = RunSortResultsParser() 20 | parser.parse(filename, expected_length=num_epochs) 21 | for i in range(num_epochs): 22 | average_losses[i, r_id] = float(parser.average_loss[i]) 23 | val_set_prop_all_corrects[i, r_id] = 
float(parser.val_set_prop_all_correct[i]) 24 | val_set_prop_any_corrects[i, r_id] = float(parser.val_set_prop_any_correct[i]) 25 | 26 | average_losses_mean = average_losses.mean(axis=1) 27 | average_losses_std = average_losses.std(axis=1) 28 | 29 | val_set_prop_all_corrects_mean = val_set_prop_all_corrects.mean(axis=1) 30 | val_set_prop_all_corrects_std = val_set_prop_all_corrects.std(axis=1) 31 | 32 | val_set_prop_any_corrects_mean = val_set_prop_any_corrects.mean(axis=1) 33 | val_set_prop_any_corrects_std = val_set_prop_any_corrects.std(axis=1) 34 | 35 | return average_losses_mean, average_losses_std, val_set_prop_all_corrects_mean, val_set_prop_all_corrects_std,\ 36 | val_set_prop_any_corrects_mean, val_set_prop_any_corrects_std 37 | 38 | 39 | ns = ['3', '5', '7', '9', '15'] 40 | taus = ['1024', '1024', '1024', '128', '128'] 41 | lr = '0.005' 42 | num_epochs = 100 43 | repetitions = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'] 44 | 45 | 46 | for n, tau in zip(ns, taus): 47 | if n != '15': 48 | continue 49 | print(f"n = {n}") 50 | 51 | for det_or_stoch in ["deterministic", "stochastic"]: 52 | if det_or_stoch != "deterministic": 53 | continue 54 | average_losses_mean = {} 55 | average_losses_std = {} 56 | val_set_prop_all_corrects_mean = {} 57 | val_set_prop_all_corrects_std = {} 58 | val_set_prop_any_corrects_mean = {} 59 | val_set_prop_any_corrects_std = {} 60 | 61 | for method in [det_or_stoch + "_softsort", det_or_stoch + "_neuralsort"]: 62 | average_losses_mean[method],\ 63 | average_losses_std[method],\ 64 | val_set_prop_all_corrects_mean[method],\ 65 | val_set_prop_all_corrects_std[method],\ 66 | val_set_prop_any_corrects_mean[method],\ 67 | val_set_prop_any_corrects_std[method] = get_learning_curves(n, method, tau, lr, num_epochs, repetitions) 68 | 69 | title_prefix = f"N = {n}\n" 70 | title_suffix = "\nAverage training curve over 10 repetitions" 71 | 72 | plt.figure(figsize=(7, 5)) 73 | plt.plot(average_losses_mean[det_or_stoch + "_neuralsort"], label="NeuralSort", color='red', linestyle='--') 74 | plt.plot(average_losses_mean[det_or_stoch + "_softsort"], label="SoftSort", color='blue', linestyle='-') 75 | fontsize = 17 76 | plt.ylabel('Loss', fontsize=fontsize) 77 | plt.xlabel('Epoch', fontsize=fontsize) 78 | plt.xticks(fontsize=fontsize) 79 | plt.yticks(fontsize=fontsize) 80 | plt.legend(fontsize=fontsize) 81 | plt.tight_layout() 82 | plt.savefig("../images/run_sort_learning_curve.png") 83 | -------------------------------------------------------------------------------- /tf/run_sort_table_of_results.py: -------------------------------------------------------------------------------- 1 | from statistics import mean, median, stdev 2 | 3 | 4 | class RunSortResultsParser: 5 | r''' 6 | Parses an individual results (i.e. log) file and stores the results. 7 | ''' 8 | def __init__(self): 9 | self.average_loss = [] 10 | self.val_set_prop_all_correct = [] 11 | self.val_set_prop_any_correct = [] 12 | self.test_set_prop_all_correct = -1 13 | self.test_set_prop_any_correct = -1 14 | 15 | def parse(self, file_path, expected_length=None): 16 | r''' 17 | :param file_path: path to the results (i.e. 
log) file 18 | ''' 19 | with open(file_path) as file: 20 | for line in file: 21 | line_tokens = line.replace(',', ' ').replace('\n', '').split(' ') 22 | if line.startswith("Average loss"): 23 | self.average_loss.append(line_tokens[2][:8]) 24 | elif line.startswith("Validation set"): 25 | self.val_set_prop_all_correct.append(line_tokens[5][:8]) 26 | self.val_set_prop_any_correct.append(line_tokens[10][:8]) 27 | elif line.startswith("Test set"): 28 | self.test_set_prop_all_correct = line_tokens[5][:8] 29 | self.test_set_prop_any_correct = line_tokens[10][:8] 30 | # print("file_path = %s" % file_path) 31 | if expected_length: 32 | assert(len(self.val_set_prop_all_correct) == expected_length) 33 | assert(len(self.val_set_prop_any_correct) == expected_length) 34 | # Check that all parsed entries are floats 35 | for list_name in ["average_loss", "val_set_prop_all_correct", "val_set_prop_any_correct"]: 36 | for i, elem in enumerate(self.__dict__[list_name]): 37 | try: 38 | float(elem) 39 | except ValueError: 40 | print(f"path:\n{file_path}\n{i}: list {list_name} contains non-float: {elem}") 41 | raise ValueError 42 | 43 | def get_val_set_prop_all_correct(self): 44 | return self.val_set_prop_all_correct[-1] 45 | 46 | def get_val_set_prop_any_correct(self): 47 | return self.val_set_prop_any_correct[-1] 48 | 49 | def get_test_set_prop_all_correct(self): 50 | return self.test_set_prop_all_correct 51 | 52 | def get_test_set_prop_any_correct(self): 53 | return self.test_set_prop_any_correct 54 | 55 | 56 | num_epochs = '100' 57 | l = '4' 58 | m = '20' 59 | ns = ['3', '5', '7', '9', '15'] 60 | methods = ['deterministic_neuralsort', 'stochastic_neuralsort', 'deterministic_softsort', 'stochastic_softsort'] 61 | lr = '0.005' 62 | taus = ['1024', '1024', '1024', '128', '128'] 63 | repetitions = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'] 64 | 65 | res = dict() 66 | 67 | for n, tau in zip(ns, taus): 68 | for method in methods: 69 | val_set_all_correct = [] 70 | val_set_any_correct = [] 71 | test_set_all_correct = [] 72 | test_set_any_correct = [] 73 | for repetition in repetitions: 74 | filename = "./run_sort_results/N_%s_%s/N_%s_%s_TAU_%s_LR_%s_E_%s_REP_%s.txt" %\ 75 | (n, method, n, method, tau, lr, num_epochs, repetition) 76 | # print("Processing " + str(filename)) 77 | results_parser = RunSortResultsParser() 78 | results_parser.parse(filename, expected_length=int(num_epochs)) 79 | val_set_all_correct.append(float(results_parser.get_val_set_prop_all_correct())) 80 | val_set_any_correct.append(float(results_parser.get_val_set_prop_any_correct())) 81 | test_set_all_correct.append(float(results_parser.get_test_set_prop_all_correct())) 82 | test_set_any_correct.append(float(results_parser.get_test_set_prop_any_correct())) 83 | res[(n, tau, method, 'test_set_all_correct_median')] = median(test_set_all_correct) 84 | res[(n, tau, method, 'test_set_any_correct_median')] = median(test_set_any_correct) 85 | res[(n, tau, method, 'test_set_all_correct_mean')] = mean(test_set_all_correct) 86 | res[(n, tau, method, 'test_set_any_correct_mean')] = mean(test_set_any_correct) 87 | # Add SDs 88 | res[(n, tau, method, 'test_set_all_correct_sd')] = stdev(test_set_all_correct) 89 | res[(n, tau, method, 'test_set_any_correct_sd')] = stdev(test_set_any_correct) 90 | 91 | 92 | def pretty_print_table(table): 93 | r''' 94 | Pretty prints the given table (of size (1 + #methods) x (1 + #ns)) 95 | ''' 96 | res = "" 97 | nrow = len(table) 98 | ncol = len(table[0]) 99 | # Print header 100 | header = table[0] 101 | for c in 
range(ncol): 102 | if c == 0: 103 | res += "{:<31}".format('') 104 | else: 105 | res += "| n = " + "{:<11}".format(header[c]) 106 | res += "\n" 107 | for r in range(1, nrow): 108 | # Method name 109 | res += "{:<31}".format(table[r][0]) 110 | for c in range(1, ncol): 111 | res += "| " + table[r][c].replace('\\', '') + " " 112 | res += "\n" 113 | print(res) 114 | 115 | 116 | def pretty_print_table_latex(table): 117 | r''' 118 | Pretty prints the given table (of size (1 + #methods) x (1 + #ns)) 119 | ''' 120 | algorithm_names = { 121 | "deterministic_neuralsort": "Deterministic NeuralSort", 122 | "stochastic_neuralsort": "Stochastic NeuralSort", 123 | "deterministic_softsort": "Deterministic SoftSort", 124 | "stochastic_softsort": "Stochastic SoftSort" 125 | } 126 | res = "" \ 127 | "\\begin{tabular}{lccccc}\n" \ 128 | "\\toprule\n" 129 | nrow = len(table) 130 | ncol = len(table[0]) 131 | # Print header 132 | header = table[0] 133 | for c in range(ncol): 134 | if c == 0: 135 | res += "Algorithm " 136 | else: 137 | res += "& $n = " + "{}$ ".format(header[c]) 138 | res += "\\\\\n" 139 | res += "\\midrule\n" 140 | for r in range(1, nrow): 141 | # Method name 142 | res += "{} ".format(algorithm_names[table[r][0]]) 143 | for c in range(1, ncol): 144 | res += "& $" + table[r][c] + "$ " 145 | res += "\\\\\n" 146 | res += "" \ 147 | "\\bottomrule\n" \ 148 | "\\end{tabular}\n" 149 | print(res) 150 | 151 | 152 | def print_table_for_metric( 153 | mean_or_median, 154 | legacy=False, # Prints as in SoftSort paper v0 (which is same as NeuralSort and OT paper) 155 | show_test_set_all_correct=True, # Ignored if legacy=True 156 | test_set_all_any_correct=False, # Ignored if legacy=True 157 | latex=False): 158 | print('*' * 30 + (' %s over %d runs ' % (mean_or_median, len(repetitions))) + '*' * 30) 159 | table = [] 160 | # Add table header 161 | header = ["algorithm"] + [n for n in ns] 162 | table.append(header) 163 | for method in methods: 164 | row = [method] 165 | for n, tau in zip(ns, taus): 166 | test_set_all_correct = res[(n, tau, method, 'test_set_all_correct_' + mean_or_median)] 167 | test_set_all_correct_sd = res[(n, tau, method, 'test_set_all_correct_sd')] 168 | test_set_any_correct = res[(n, tau, method, 'test_set_any_correct_' + mean_or_median)] 169 | test_set_any_correct_sd = res[(n, tau, method, 'test_set_any_correct_sd')] 170 | if legacy: 171 | table_entry = "%.3f\\ (%.3f)" % (test_set_all_correct, test_set_any_correct) 172 | else: 173 | table_entry = "" 174 | if show_test_set_all_correct: 175 | table_entry += "%.3f\\ \\pm\\ %.3f" % (test_set_all_correct, test_set_all_correct_sd) 176 | if test_set_all_any_correct: 177 | table_entry += "%.3f\\ \\pm\\ %.3f" % (test_set_any_correct, test_set_any_correct_sd) 178 | row.append(table_entry) 179 | table.append(row) 180 | if latex: 181 | pretty_print_table_latex(table) 182 | else: 183 | pretty_print_table(table) 184 | 185 | 186 | print('*' * 90) 187 | print('*' * 30 + ' Printing Legacy Tables (no SDs) ' + '*' * 30) 188 | print('*' * 90) 189 | print_table_for_metric('mean', legacy=True, latex=True) 190 | print_table_for_metric('mean', legacy=True, latex=False) 191 | 192 | print('*' * 90) 193 | print('*' * 30 + ' Printing New Tables with SDs ' + '*' * 30) 194 | print('*' * 90) 195 | 196 | for show_metric in [(True, False), (False, True)]: 197 | print_table_for_metric( 198 | 'mean', 199 | legacy=False, 200 | show_test_set_all_correct=show_metric[0], 201 | test_set_all_any_correct=show_metric[1], 202 | latex=True) 203 | 204 | print_table_for_metric( 205 | 
'mean', 206 | legacy=False, 207 | show_test_set_all_correct=show_metric[0], 208 | test_set_all_any_correct=show_metric[1], 209 | latex=False) 210 | -------------------------------------------------------------------------------- /tf/sinkhorn.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # Parts of the original code have been removed. 16 | 17 | import tensorflow as tf 18 | from util import sample_gumbel 19 | 20 | 21 | def sinkhorn_operator(log_alpha, n_iters=20, temp=1.0): 22 | """Performs incomplete Sinkhorn normalization to log_alpha. 23 | By a theorem by Sinkhorn and Knopp [1], a sufficiently well-behaved matrix 24 | with positive entries can be turned into a doubly-stochastic matrix 25 | (i.e. its rows and columns add up to one) via the successive row and column 26 | normalization. 27 | -To ensure positivity, the effective input to sinkhorn has to be 28 | exp(log_alpha) (elementwise). 29 | -However, for stability, sinkhorn works in the log-space. It is only at 30 | return time that entries are exponentiated. 31 | [1] Sinkhorn, Richard and Knopp, Paul. 32 | Concerning nonnegative matrices and doubly stochastic 33 | matrices. Pacific Journal of Mathematics, 1967 34 | Args: 35 | log_alpha: 2D tensor (a matrix of shape [N, N]) 36 | or 3D tensor (a batch of matrices of shape = [batch_size, N, N]) 37 | n_iters: number of sinkhorn iterations (in practice, as few as 20 38 | iterations are needed to achieve decent convergence for N~100) 39 | Returns: 40 | A 3D tensor of close-to-doubly-stochastic matrices (2D tensors are 41 | converted to 3D tensors with batch_size equal to 1) 42 | """ 43 | 44 | n = tf.shape(log_alpha)[1] 45 | log_alpha = tf.reshape(log_alpha, [-1, n, n]) / temp 46 | 47 | for _ in range(n_iters): 48 | log_alpha -= tf.reshape(tf.reduce_logsumexp(log_alpha, axis=2), [-1, n, 1]) 49 | log_alpha -= tf.reshape(tf.reduce_logsumexp(log_alpha, axis=1), [-1, 1, n]) 50 | return tf.exp(log_alpha) 51 | 52 | 53 | def gumbel_sinkhorn(log_alpha, 54 | temp=1.0, n_samples=1, noise_factor=1.0, n_iters=20, 55 | squeeze=True): 56 | """Random doubly-stochastic matrices via gumbel noise. 57 | In the zero-temperature limit sinkhorn(log_alpha/temp) approaches 58 | a permutation matrix. Therefore, for low temperatures this method can be 59 | seen as an approximate sampling of permutation matrices, where the 60 | distribution is parameterized by the matrix log_alpha. 61 | The deterministic case (noise_factor=0) is also interesting: it can be 62 | shown that lim t->0 sinkhorn(log_alpha/t) = M, where M is a 63 | permutation matrix, the solution of the 64 | matching problem M=arg max_M sum_i,j log_alpha_i,j M_i,j. 65 | Therefore, the deterministic limit case of gumbel_sinkhorn can be seen 66 | as approximate solving of a matching problem, otherwise solved via the 67 | Hungarian algorithm. 
68 | Warning: the convergence holds true in the limit case n_iters = infty. 69 | Unfortunately, in practice n_iters is finite, which can lead to numerical 70 | instabilities, mostly if temp is very low. These manifest as 71 | pseudo-convergence, or convergence of some rows/columns to fractional 72 | entries (e.g. a row having two entries of 0.5 instead of a single 1.0). 73 | To minimize these effects, try increasing n_iters when temp is decreased. 74 | On the other hand, too low a temperature usually leads to high-variance 75 | gradients, so it is better not to choose temperatures that are too low. 76 | Args: 77 | log_alpha: 2D tensor (a matrix of shape [N, N]) 78 | or 3D tensor (a batch of matrices of shape = [batch_size, N, N]) 79 | temp: temperature parameter, a float. 80 | n_samples: number of samples 81 | noise_factor: scaling factor for the gumbel samples. Mostly to explore 82 | different degrees of randomness (and the absence of randomness, with 83 | noise_factor=0) 84 | n_iters: number of sinkhorn iterations. Should be chosen carefully, in 85 | inverse correspondence with temp to avoid numerical instabilities. 86 | squeeze: a boolean; if True and there is a single sample, the output will 87 | remain a 3D tensor. 88 | Returns: 89 | sink: a 4D tensor of [batch_size, n_samples, N, N] i.e. 90 | batch_size * n_samples doubly-stochastic matrices. If n_samples = 1 and 91 | squeeze = True then the output is 3D. 92 | log_alpha_w_noise: a 4D tensor of [batch_size, n_samples, N, N] of 93 | noisy samples of log_alpha, divided by the temperature parameter. If 94 | n_samples = 1 then the output is 3D. 95 | """ 96 | n = tf.shape(log_alpha)[1] 97 | log_alpha = tf.reshape(log_alpha, [-1, n, n]) 98 | batch_size = tf.shape(log_alpha)[0] 99 | log_alpha_w_noise = tf.tile(log_alpha, [n_samples, 1, 1]) 100 | if noise_factor == 0: 101 | noise = 0.0 102 | else: 103 | noise = sample_gumbel([n_samples * batch_size, n, n]) * noise_factor 104 | log_alpha_w_noise += noise 105 | log_alpha_w_noise /= temp 106 | sink = sinkhorn_operator(log_alpha_w_noise, n_iters) 107 | if n_samples > 1 or squeeze is False: 108 | sink = tf.reshape(sink, [n_samples, batch_size, n, n]) 109 | sink = tf.transpose(sink, [1, 0, 2, 3]) 110 | log_alpha_w_noise = tf.reshape( 111 | log_alpha_w_noise, [n_samples, batch_size, n, n]) 112 | log_alpha_w_noise = tf.transpose(log_alpha_w_noise, [1, 0, 2, 3]) 113 | return sink, log_alpha_w_noise 114 | -------------------------------------------------------------------------------- /tf/synthetic_experiment_learning_curves.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | 5 | import numpy as np 6 | import tensorflow as tf 7 | from scipy import stats 8 | 9 | import util 10 | 11 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' 12 | 13 | parser = argparse.ArgumentParser(description="Benchmark speed of softsort vs" 14 | " neuralsort") 15 | 16 | parser.add_argument("--batch_size", type=int, default=20) 17 | parser.add_argument("--n", type=int, default=2000) 18 | parser.add_argument("--epochs", type=int, default=100) 19 | parser.add_argument("--device", type=str, default='cpu') 20 | parser.add_argument("--method", type=str, default='neuralsort') 21 | parser.add_argument("--tau", type=float, default=1.0) 22 | parser.add_argument("--pow", type=float, default=1.0) 23 | 24 | args = parser.parse_args() 25 | 26 | print("Benchmarking with:\n" 27 | "\tbatch_size = %d\n" 28 | "\tn = %d\n" 29 | "\tepochs = %d\n" 30 | "\tdevice = %s\n" 31 | "\tmethod = %s\n" 32 | "\ttau = 
%s\n" 33 | "\tpow = %f\n" % 34 | (args.batch_size, 35 | args.n, 36 | args.epochs, 37 | args.device, 38 | args.method, 39 | args.tau, 40 | args.pow)) 41 | 42 | sort_op = None 43 | if args.method == 'neuralsort': 44 | sort_op = util.neuralsort 45 | elif args.method == 'softsort': 46 | sort_op = util.softsort 47 | else: 48 | raise ValueError('method %s not found' % args.method) 49 | 50 | device_str = '/GPU:0' if args.device == 'cuda' else '/CPU:0' 51 | 52 | 53 | def evaluate(scores_eval): 54 | r''' 55 | Returns the mean spearman correlation over the batch. 56 | ''' 57 | rank_correlations = [] 58 | for i in range(args.batch_size): 59 | rank_correlation, _ = stats.spearmanr(scores_eval[i, :, 0], 60 | range(args.n, 0, -1)) 61 | rank_correlations.append(rank_correlation) 62 | mean_rank_correlation = np.mean(rank_correlations) 63 | return mean_rank_correlation 64 | 65 | 66 | log = "" 67 | with tf.Session() as sess: 68 | with tf.device(device_str): 69 | np.random.seed(1) 70 | tf.set_random_seed(1) 71 | # Define model 72 | scores = tf.get_variable( 73 | shape=[args.batch_size, args.n, 1], 74 | initializer=tf.random_uniform_initializer(-1.0, 1.0), 75 | name='scores') 76 | 77 | # Normalize scores before feeding them into the sorting op for increased stability. 78 | min_scores = tf.math.reduce_min(scores, axis=1, keepdims=True) 79 | min_scores = tf.stop_gradient(min_scores) 80 | max_scores = tf.math.reduce_max(scores, axis=1, keepdims=True) 81 | max_scores = tf.stop_gradient(max_scores) 82 | scores_normalized = (scores - min_scores) / (max_scores - min_scores) 83 | 84 | if args.method == 'softsort': 85 | P_hat = sort_op(scores_normalized, tau=args.tau, pow=args.pow) 86 | else: 87 | P_hat = sort_op(scores_normalized, tau=args.tau) 88 | 89 | wd = 5.0 90 | loss = (tf.reduce_mean(1.0 - tf.log(tf.matrix_diag_part(P_hat))) 91 | + wd * tf.reduce_mean(tf.multiply(scores, scores))) * args.batch_size 92 | optimizer = tf.train.MomentumOptimizer( 93 | learning_rate=10.0, 94 | momentum=0.5).\ 95 | minimize(loss, var_list=[scores]) 96 | # Train model 97 | tf.global_variables_initializer().run() 98 | # Train 99 | start_time = time.time() 100 | for epoch in range(args.epochs): 101 | _, loss_eval, scores_eval = sess.run([optimizer, loss, scores]) 102 | spearmanr = evaluate(scores_eval) 103 | log += "Epoch %d loss = %f spearmanr = %f\n" % (epoch, loss_eval, spearmanr) 104 | loss_eval, scores_eval = sess.run([loss, scores]) 105 | spearmanr = evaluate(scores_eval) 106 | end_time = time.time() 107 | total_time = end_time - start_time 108 | log += "Epochs: %d\n" % args.epochs 109 | log += "Loss: %f\n" % loss_eval 110 | log += "Spearmanr: %f\n" % spearmanr 111 | log += "Total time: %f\n" % total_time 112 | log += "Time per epoch: %f\n" % (total_time / args.epochs) 113 | 114 | print(log) 115 | -------------------------------------------------------------------------------- /tf/synthetic_experiment_speed_comparison.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | 5 | import numpy as np 6 | import tensorflow as tf 7 | from scipy import stats 8 | 9 | import util 10 | 11 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' 12 | 13 | parser = argparse.ArgumentParser(description="Benchmark speed of softsort vs" 14 | " neuralsort") 15 | 16 | parser.add_argument("--batch_size", type=int, default=20) 17 | parser.add_argument("--n", type=int, default=2000) 18 | parser.add_argument("--epochs", type=int, default=100) 19 | parser.add_argument("--device", 
type=str, default='cpu') 20 | parser.add_argument("--method", type=str, default='neuralsort') 21 | parser.add_argument("--burnin", type=int, default=100) 22 | 23 | args = parser.parse_args() 24 | 25 | print("Benchmarking with:\n" 26 | "\tbatch_size = %d\n" 27 | "\tn = %d\n" 28 | "\tepochs = %d\n" 29 | "\tdevice = %s\n" 30 | "\tmethod = %s\n" 31 | "\tburnin = %d" % 32 | (args.batch_size, 33 | args.n, 34 | args.epochs, 35 | args.device, 36 | args.method, 37 | args.burnin)) 38 | 39 | sort_op = None 40 | if args.method == 'neuralsort': 41 | sort_op = util.neuralsort 42 | args.tau = 100.0 43 | elif args.method == 'softsort': 44 | sort_op = util.softsort_p2 45 | args.tau = 0.03 46 | else: 47 | raise ValueError('method %s not found' % args.method) 48 | 49 | device_str = '/GPU:0' if args.device == 'cuda' else '/CPU:0' 50 | 51 | 52 | def evaluate(scores_eval): 53 | r''' 54 | Returns the mean spearman correlation over the batch. 55 | ''' 56 | rank_correlations = [] 57 | for i in range(args.batch_size): 58 | rank_correlation, _ = stats.spearmanr(scores_eval[i, :, 0], 59 | range(args.n, 0, -1)) 60 | rank_correlations.append(rank_correlation) 61 | mean_rank_correlation = np.mean(rank_correlations) 62 | return mean_rank_correlation 63 | 64 | 65 | log = "" 66 | with tf.Session() as sess: 67 | with tf.device(device_str): 68 | np.random.seed(1) 69 | tf.set_random_seed(1) 70 | # Define model 71 | scores = tf.get_variable( 72 | shape=[args.batch_size, args.n, 1], 73 | initializer=tf.random_uniform_initializer(-1.0, 1.0), 74 | name='scores') 75 | 76 | # Normalize scores before feeding them into the sorting op for increased stability. 77 | min_scores = tf.math.reduce_min(scores, axis=1, keepdims=True) 78 | min_scores = tf.stop_gradient(min_scores) 79 | max_scores = tf.math.reduce_max(scores, axis=1, keepdims=True) 80 | max_scores = tf.stop_gradient(max_scores) 81 | scores_normalized = (scores - min_scores) / (max_scores - min_scores) 82 | 83 | P_hat = sort_op(scores_normalized, tau=args.tau) 84 | 85 | wd = 5.0 86 | loss = (tf.reduce_mean(1.0 - tf.log(tf.matrix_diag_part(P_hat))) 87 | + wd * tf.reduce_mean(tf.multiply(scores, scores))) * args.batch_size 88 | optimizer = tf.train.MomentumOptimizer( 89 | learning_rate=10.0, 90 | momentum=0.5).\ 91 | minimize(loss, var_list=[scores]) 92 | # Train model 93 | tf.global_variables_initializer().run() 94 | # Burn-in 95 | for _ in range(args.burnin): 96 | sess.run(optimizer) 97 | # Train 98 | start_time = time.time() 99 | for epoch in range(args.epochs): 100 | sess.run(optimizer) 101 | loss_eval, scores_eval = sess.run([loss, scores]) 102 | spearmanr = evaluate(scores_eval) 103 | end_time = time.time() 104 | total_time = end_time - start_time 105 | log += "Epochs: %d\n" % args.epochs 106 | log += "Loss: %f\n" % loss_eval 107 | log += "Spearmanr: %f\n" % spearmanr 108 | log += "Total time: %f\n" % total_time 109 | log += "Time per epoch: %f\n" % (total_time / args.epochs) 110 | 111 | print(log) 112 | -------------------------------------------------------------------------------- /tf/util.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | # M: minibatch size 5 | # n: number of items in each sequence 6 | # s: scores 7 | 8 | np.set_printoptions(precision=4, suppress=True) 9 | eps = 1e-20 10 | 11 | 12 | def bl_matmul(A, B): 13 | return tf.einsum('mij,jk->mik', A, B) 14 | 15 | 16 | def br_matmul(A, B): 17 | return tf.einsum('ij,mjk->mik', A, B) 18 | 19 | 20 | # s: M x n x 1 21 | # 
neuralsort(s): M x n x n 22 | def neuralsort(s, tau=1): 23 | A_s = s - tf.transpose(s, perm=[0, 2, 1]) 24 | A_s = tf.abs(A_s) 25 | # As_ij = |s_i - s_j| 26 | 27 | n = tf.shape(s)[1] 28 | one = tf.ones((n, 1), dtype=tf.float32) 29 | 30 | # B = bl_matmul(A_s, one @ tf.transpose(one)) # => NeuralSort O(n^3) BUG! 31 | B = (A_s @ one) @ tf.transpose(one) # => Bugfix 32 | # B_:k = (A_s)(one) 33 | 34 | K = tf.range(n) + 1 35 | # K_k = k 36 | 37 | # C = bl_matmul( 38 | # s, tf.expand_dims(tf.cast(n + 1 - 2 * K, dtype=tf.float32), 0) 39 | # ) 40 | C = ( 41 | s @ tf.expand_dims(tf.cast(n + 1 - 2 * K, dtype=tf.float32), 0) 42 | ) 43 | # C_:k = (n + 1 - 2k)s 44 | 45 | P = tf.transpose(C - B, perm=[0, 2, 1]) 46 | # P_k: = (n + 1 - 2k)s - (A_s)(one) 47 | 48 | P = tf.nn.softmax(P / tau, -1) 49 | # P_k: = softmax( ((n + 1 - 2k)s - (A_s)(one)) / tau ) 50 | 51 | return P 52 | 53 | 54 | # s: M x n x 1 55 | # softsort(s): M x n x n 56 | def softsort(s, tau=1, pow=1): 57 | s_sorted = tf.sort(s, direction='DESCENDING', axis=1) 58 | pairwise_distances = -tf.pow(tf.abs(tf.transpose(s, perm=[0, 2, 1]) - s_sorted), pow) 59 | P_hat = tf.nn.softmax(pairwise_distances / tau, -1) 60 | return P_hat 61 | 62 | 63 | # s: M x n x 1 64 | # softsort_p1(s): M x n x n 65 | def softsort_p1(s, tau=1): 66 | s_sorted = tf.sort(s, direction='DESCENDING', axis=1) 67 | pairwise_distances = -tf.abs(tf.transpose(s, perm=[0, 2, 1]) - s_sorted) 68 | P_hat = tf.nn.softmax(pairwise_distances / tau, -1) 69 | return P_hat 70 | 71 | 72 | # s: M x n x 1 73 | # softsort_p2(s): M x n x n 74 | def softsort_p2(s, tau=1): 75 | s_sorted = tf.sort(s, direction='DESCENDING', axis=1) 76 | pairwise_distances = -tf.square(tf.transpose(s, perm=[0, 2, 1]) - s_sorted) 77 | P_hat = tf.nn.softmax(pairwise_distances / tau, -1) 78 | return P_hat 79 | 80 | 81 | # Pi: M x n x n row-stochastic 82 | def prop_any_correct(P1, P2): 83 | z1 = tf.argmax(P1, axis=-1) 84 | z2 = tf.argmax(P2, axis=-1) 85 | eq = tf.equal(z1, z2) 86 | eq = tf.cast(eq, dtype=tf.float32) 87 | correct = tf.reduce_mean(eq, axis=-1) 88 | return tf.reduce_mean(correct) 89 | 90 | 91 | # Pi: M x n x n row-stochastic 92 | def prop_correct(P1, P2): 93 | z1 = tf.argmax(P1, axis=-1) 94 | z2 = tf.argmax(P2, axis=-1) 95 | eq = tf.equal(z1, z2) 96 | correct = tf.reduce_all(eq, axis=-1) 97 | return tf.reduce_mean(tf.cast(correct, tf.float32)) 98 | 99 | 100 | def sample_gumbel(shape, eps=1e-20): 101 | U = tf.random_uniform(shape, minval=0, maxval=1) 102 | return -tf.log(-tf.log(U + eps) + eps) 103 | 104 | 105 | # s: M x n 106 | # P: M x n x n 107 | # returns: M 108 | def pl_log_density(log_s, P): 109 | log_s = tf.expand_dims(log_s, 2) # M x n x 1 110 | ordered_log_s = P @ log_s # M x n x 1 111 | ordered_log_s = tf.squeeze(ordered_log_s, squeeze_dims=[-1]) # M x n 112 | potentials = tf.exp(ordered_log_s) 113 | n = log_s.get_shape().as_list()[1] 114 | max_log_s = [ 115 | tf.reduce_max(ordered_log_s[:, k:], axis=1, keepdims=True) 116 | for k in range(n) 117 | ] # [M x 1] x n 118 | adj_log_s = [ 119 | ordered_log_s - max_log_s[k] 120 | for k in range(n) 121 | ] # [M x n] x n 122 | potentials = [ 123 | tf.exp(adj_log_s[k][:, k:]) 124 | for k in range(n) 125 | ] # [M x n] x n 126 | denominators = [ 127 | tf.reduce_sum(potentials[k], axis=1, keepdims=True) 128 | for k in range(n) 129 | ] # [M x 1] x n 130 | log_denominators = [ 131 | tf.squeeze(tf.log(denominators[k]) + max_log_s[k], squeeze_dims=[1]) 132 | for k in range(n) 133 | ] # [M] x n 134 | log_denominator = tf.add_n(log_denominators) # M 135 | 
log_potentials = ordered_log_s # M x n x 1 136 | log_potential = tf.reduce_sum(log_potentials, 1) # M 137 | log_likelihood = log_potential - log_denominator 138 | return log_likelihood 139 | --------------------------------------------------------------------------------
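For readers who want to try the relaxation outside of TensorFlow, below is a minimal NumPy sketch of the SoftSort operator defined in `tf/util.py`. The function name `softsort_np` and the example scores are illustrative only and are not part of the repository: each row of the returned matrix is a softmax over the negative distances between the sorted scores and the original scores, so at low temperature the row-wise argmax recovers the descending argsort.

```
import numpy as np


def softsort_np(s, tau=1.0, pow=1.0):
    """NumPy sketch of util.softsort for a single score vector s of length n.

    Returns an n x n row-stochastic matrix whose i-th row is a softmax over
    -|s_sorted_i - s_j|^pow / tau, i.e. a relaxation of the argsort permutation.
    """
    s = np.asarray(s, dtype=np.float64)
    s_sorted = np.sort(s)[::-1]  # descending sort, as in util.softsort
    logits = -(np.abs(s_sorted[:, None] - s[None, :]) ** pow) / tau
    logits -= logits.max(axis=1, keepdims=True)  # numerically stable softmax
    P_hat = np.exp(logits)
    return P_hat / P_hat.sum(axis=1, keepdims=True)


if __name__ == '__main__':
    scores = [0.1, 2.0, -1.3, 0.7]
    P_hat = softsort_np(scores, tau=0.1)
    # At low temperature the rows are nearly one-hot, and their argmax matches
    # the exact descending argsort of the scores, i.e. [1, 3, 0, 2] here.
    print(np.round(P_hat, 3))
    print(P_hat.argmax(axis=1), np.argsort(scores)[::-1])
```

The TensorFlow version in `tf/util.py` performs the same computation batch-wise on scores of shape M x n x 1, with `tf.nn.softmax` handling the row normalization.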