├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── fast_soft_sort
│   ├── __init__.py
│   ├── jax_ops.py
│   ├── numpy_ops.py
│   ├── pytorch_ops.py
│   ├── tf_ops.py
│   └── third_party
│       ├── LICENSE
│       ├── __init__.py
│       └── isotonic.py
├── setup.py
└── tests
    ├── isotonic_test.py
    ├── jax_ops_test.py
    ├── numpy_ops_test.py
    ├── pytorch_ops_test.py
    └── tf_ops_test.py

/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 | 
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 | 
6 | ## Contributor License Agreement
7 | 
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement (CLA). You (or your employer) retain the copyright to your
10 | contribution; this simply gives us permission to use and redistribute your
11 | contributions as part of the project. Head over to
12 | <https://cla.developers.google.com/> to see your current agreements on file or
13 | to sign a new one.
14 | 
15 | You generally only need to submit a CLA once, so if you've already submitted one
16 | (even if it was for a different project), you probably don't need to do it
17 | again.
18 | 
19 | ## Code reviews
20 | 
21 | All submissions, including submissions by project members, require review. We
22 | use GitHub pull requests for this purpose. Consult
23 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
24 | information on using pull requests.
25 | 
26 | ## Community Guidelines
27 | 
28 | This project follows
29 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/).
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License, Version 2.0, January 2004
2 | http://www.apache.org/licenses/
3 | 
4 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
5 | 
6 | 1. Definitions.
7 | 
8 | "License" shall mean the terms and conditions for use, reproduction,
9 | and distribution as defined by Sections 1 through 9 of this document.
10 | 
11 | "Licensor" shall mean the copyright owner or entity authorized by
12 | the copyright owner that is granting the License.
13 | 
14 | "Legal Entity" shall mean the union of the acting entity and all
15 | other entities that control, are controlled by, or are under common
16 | control with that entity. For the purposes of this definition,
17 | "control" means (i) the power, direct or indirect, to cause the
18 | direction or management of such entity, whether by contract or
19 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
20 | outstanding shares, or (iii) beneficial ownership of such entity.
21 | 
22 | "You" (or "Your") shall mean an individual or Legal Entity
23 | exercising permissions granted by this License.
24 | 
25 | "Source" form shall mean the preferred form for making modifications,
26 | including but not limited to software source code, documentation
27 | source, and configuration files.
28 | 
29 | "Object" form shall mean any form resulting from mechanical
30 | transformation or translation of a Source form, including but
31 | not limited to compiled object code, generated documentation,
32 | and conversions to other media types.
33 | 34 | "Work" shall mean the work of authorship, whether in Source or 35 | Object form, made available under the License, as indicated by a 36 | copyright notice that is included in or attached to the work 37 | (an example is provided in the Appendix below). 38 | 39 | "Derivative Works" shall mean any work, whether in Source or Object 40 | form, that is based on (or derived from) the Work and for which the 41 | editorial revisions, annotations, elaborations, or other modifications 42 | represent, as a whole, an original work of authorship. For the purposes 43 | of this License, Derivative Works shall not include works that remain 44 | separable from, or merely link (or bind by name) to the interfaces of, 45 | the Work and Derivative Works thereof. 46 | 47 | "Contribution" shall mean any work of authorship, including 48 | the original version of the Work and any modifications or additions 49 | to that Work or Derivative Works thereof, that is intentionally 50 | submitted to Licensor for inclusion in the Work by the copyright owner 51 | or by an individual or Legal Entity authorized to submit on behalf of 52 | the copyright owner. For the purposes of this definition, "submitted" 53 | means any form of electronic, verbal, or written communication sent 54 | to the Licensor or its representatives, including but not limited to 55 | communication on electronic mailing lists, source code control systems, 56 | and issue tracking systems that are managed by, or on behalf of, the 57 | Licensor for the purpose of discussing and improving the Work, but 58 | excluding communication that is conspicuously marked or otherwise 59 | designated in writing by the copyright owner as "Not a Contribution." 60 | 61 | "Contributor" shall mean Licensor and any individual or Legal Entity 62 | on behalf of whom a Contribution has been received by Licensor and 63 | subsequently incorporated within the Work. 64 | 65 | 2. Grant of Copyright License. Subject to the terms and conditions of 66 | this License, each Contributor hereby grants to You a perpetual, 67 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 68 | copyright license to reproduce, prepare Derivative Works of, 69 | publicly display, publicly perform, sublicense, and distribute the 70 | Work and such Derivative Works in Source or Object form. 71 | 72 | 3. Grant of Patent License. Subject to the terms and conditions of 73 | this License, each Contributor hereby grants to You a perpetual, 74 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 75 | (except as stated in this section) patent license to make, have made, 76 | use, offer to sell, sell, import, and otherwise transfer the Work, 77 | where such license applies only to those patent claims licensable 78 | by such Contributor that are necessarily infringed by their 79 | Contribution(s) alone or by combination of their Contribution(s) 80 | with the Work to which such Contribution(s) was submitted. If You 81 | institute patent litigation against any entity (including a 82 | cross-claim or counterclaim in a lawsuit) alleging that the Work 83 | or a Contribution incorporated within the Work constitutes direct 84 | or contributory patent infringement, then any patent licenses 85 | granted to You under this License for that Work shall terminate 86 | as of the date such litigation is filed. 87 | 88 | 4. Redistribution. 
You may reproduce and distribute copies of the 89 | Work or Derivative Works thereof in any medium, with or without 90 | modifications, and in Source or Object form, provided that You 91 | meet the following conditions: 92 | 93 | (a) You must give any other recipients of the Work or 94 | Derivative Works a copy of this License; and 95 | 96 | (b) You must cause any modified files to carry prominent notices 97 | stating that You changed the files; and 98 | 99 | (c) You must retain, in the Source form of any Derivative Works 100 | that You distribute, all copyright, patent, trademark, and 101 | attribution notices from the Source form of the Work, 102 | excluding those notices that do not pertain to any part of 103 | the Derivative Works; and 104 | 105 | (d) If the Work includes a "NOTICE" text file as part of its 106 | distribution, then any Derivative Works that You distribute must 107 | include a readable copy of the attribution notices contained 108 | within such NOTICE file, excluding those notices that do not 109 | pertain to any part of the Derivative Works, in at least one 110 | of the following places: within a NOTICE text file distributed 111 | as part of the Derivative Works; within the Source form or 112 | documentation, if provided along with the Derivative Works; or, 113 | within a display generated by the Derivative Works, if and 114 | wherever such third-party notices normally appear. The contents 115 | of the NOTICE file are for informational purposes only and 116 | do not modify the License. You may add Your own attribution 117 | notices within Derivative Works that You distribute, alongside 118 | or as an addendum to the NOTICE text from the Work, provided 119 | that such additional attribution notices cannot be construed 120 | as modifying the License. 121 | 122 | You may add Your own copyright statement to Your modifications and 123 | may provide additional or different license terms and conditions 124 | for use, reproduction, or distribution of Your modifications, or 125 | for any such Derivative Works as a whole, provided Your use, 126 | reproduction, and distribution of the Work otherwise complies with 127 | the conditions stated in this License. 128 | 129 | 5. Submission of Contributions. Unless You explicitly state otherwise, 130 | any Contribution intentionally submitted for inclusion in the Work 131 | by You to the Licensor shall be under the terms and conditions of 132 | this License, without any additional terms or conditions. 133 | Notwithstanding the above, nothing herein shall supersede or modify 134 | the terms of any separate license agreement you may have executed 135 | with Licensor regarding such Contributions. 136 | 137 | 6. Trademarks. This License does not grant permission to use the trade 138 | names, trademarks, service marks, or product names of the Licensor, 139 | except as required for reasonable and customary use in describing the 140 | origin of the Work and reproducing the content of the NOTICE file. 141 | 142 | 7. Disclaimer of Warranty. Unless required by applicable law or 143 | agreed to in writing, Licensor provides the Work (and each 144 | Contributor provides its Contributions) on an "AS IS" BASIS, 145 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 146 | implied, including, without limitation, any warranties or conditions 147 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 148 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 149 | appropriateness of using or redistributing the Work and assume any 150 | risks associated with Your exercise of permissions under this License. 151 | 152 | 8. Limitation of Liability. In no event and under no legal theory, 153 | whether in tort (including negligence), contract, or otherwise, 154 | unless required by applicable law (such as deliberate and grossly 155 | negligent acts) or agreed to in writing, shall any Contributor be 156 | liable to You for damages, including any direct, indirect, special, 157 | incidental, or consequential damages of any character arising as a 158 | result of this License or out of the use or inability to use the 159 | Work (including but not limited to damages for loss of goodwill, 160 | work stoppage, computer failure or malfunction, or any and all 161 | other commercial damages or losses), even if such Contributor 162 | has been advised of the possibility of such damages. 163 | 164 | 9. Accepting Warranty or Additional Liability. While redistributing 165 | the Work or Derivative Works thereof, You may choose to offer, 166 | and charge a fee for, acceptance of support, warranty, indemnity, 167 | or other liability obligations and/or rights consistent with this 168 | License. However, in accepting such obligations, You may act only 169 | on Your own behalf and on Your sole responsibility, not on behalf 170 | of any other Contributor, and only if You agree to indemnify, 171 | defend, and hold each Contributor harmless for any liability 172 | incurred by, or claims asserted against, such Contributor by reason 173 | of your accepting any such warranty or additional liability. 174 | 175 | END OF TERMS AND CONDITIONS 176 | 177 | APPENDIX: How to apply the Apache License to your work. 178 | 179 | To apply the Apache License to your work, attach the following 180 | boilerplate notice, with the fields enclosed by brackets "[]" 181 | replaced with your own identifying information. (Don't include 182 | the brackets!) The text should be enclosed in the appropriate 183 | comment syntax for the file format. We also recommend that a 184 | file or class name and description of purpose be included on the 185 | same "printed page" as the copyright notice for easier 186 | identification within third-party archives. 187 | 188 | Copyright [yyyy] [name of copyright owner] 189 | 190 | Licensed under the Apache License, Version 2.0 (the "License"); 191 | you may not use this file except in compliance with the License. 192 | You may obtain a copy of the License at 193 | 194 | http://www.apache.org/licenses/LICENSE-2.0 195 | 196 | Unless required by applicable law or agreed to in writing, software 197 | distributed under the License is distributed on an "AS IS" BASIS, 198 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 199 | See the License for the specific language governing permissions and 200 | limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | Fast Differentiable Sorting and Ranking 3 | ======================================= 4 | 5 | Differentiable sorting and ranking operations in O(n log n). 
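
The heavy lifting is done by the pure-NumPy operators in
`fast_soft_sort/numpy_ops.py`; the TensorFlow, JAX and PyTorch modules are thin
wrappers around them. As a quick illustration, a minimal NumPy sketch (1d
inputs; the printed arrays carry the same values as the 2d framework examples
below, up to float formatting):

```python
>>> import numpy as np
>>> from fast_soft_sort import numpy_ops
>>> values = np.array([5., 1., 2.])
>>> # The smaller the regularization strength, the closer the output
>>> # is to the hard rank/sort.
>>> numpy_ops.soft_rank(values, regularization_strength=1.0)
array([3., 1., 2.])
>>> numpy_ops.soft_sort(values, regularization_strength=0.1)
array([1., 2., 5.])
```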
6 | 
7 | Dependencies
8 | ------------
9 | 
10 | * NumPy
11 | * SciPy
12 | * Numba
13 | * TensorFlow (optional)
14 | * PyTorch (optional)
15 | 
16 | TensorFlow Example
17 | -------------------
18 | 
19 | ```python
20 | >>> import tensorflow as tf
21 | >>> from fast_soft_sort.tf_ops import soft_rank, soft_sort
22 | >>> values = tf.convert_to_tensor([[5., 1., 2.], [2., 1., 5.]], dtype=tf.float64)
23 | >>> soft_sort(values, regularization_strength=1.0)
24 | <tf.Tensor: shape=(2, 3), dtype=float64, numpy=array([[1.66666667, 2.66666667, 3.66666667], [1.66666667, 2.66666667, 3.66666667]])>
25 | >>> soft_sort(values, regularization_strength=0.1)
26 | <tf.Tensor: shape=(2, 3), dtype=float64, numpy=array([[1., 2., 5.], [1., 2., 5.]])>
27 | >>> soft_rank(values, regularization_strength=2.0)
28 | <tf.Tensor: shape=(2, 3), dtype=float64, numpy=array([[3., 1.25, 1.75], [1.75, 1.25, 3.]])>
29 | >>> soft_rank(values, regularization_strength=1.0)
30 | <tf.Tensor: shape=(2, 3), dtype=float64, numpy=array([[3., 1., 2.], [2., 1., 3.]])>
31 | ```
32 | 
33 | JAX Example
34 | -----------
35 | 
36 | ```python
37 | >>> import jax.numpy as jnp
38 | >>> from fast_soft_sort.jax_ops import soft_rank, soft_sort
39 | >>> values = jnp.array([[5., 1., 2.], [2., 1., 5.]], dtype=jnp.float64)
40 | >>> soft_sort(values, regularization_strength=1.0)
41 | [[1.66666667 2.66666667 3.66666667]
42 |  [1.66666667 2.66666667 3.66666667]]
43 | >>> soft_sort(values, regularization_strength=0.1)
44 | [[1. 2. 5.]
45 |  [1. 2. 5.]]
46 | >>> soft_rank(values, regularization_strength=2.0)
47 | [[3.   1.25 1.75]
48 |  [1.75 1.25 3.  ]]
49 | >>> soft_rank(values, regularization_strength=1.0)
50 | [[3. 1. 2.]
51 |  [2. 1. 3.]]
52 | ```
53 | 
54 | PyTorch Example
55 | ---------------
56 | 
57 | ```python
58 | >>> import torch
59 | >>> from fast_soft_sort.pytorch_ops import soft_rank, soft_sort
60 | >>> values = torch.tensor([[5., 1., 2.], [2., 1., 5.]], dtype=torch.float64)
61 | >>> soft_sort(values, regularization_strength=1.0)
62 | tensor([[1.6667, 2.6667, 3.6667],
63 |         [1.6667, 2.6667, 3.6667]], dtype=torch.float64)
64 | >>> soft_sort(values, regularization_strength=0.1)
65 | tensor([[1., 2., 5.],
66 |         [1., 2., 5.]], dtype=torch.float64)
67 | >>> soft_rank(values, regularization_strength=2.0)
68 | tensor([[3.0000, 1.2500, 1.7500],
69 |         [1.7500, 1.2500, 3.0000]], dtype=torch.float64)
70 | >>> soft_rank(values, regularization_strength=1.0)
71 | tensor([[3., 1., 2.],
72 |         [2., 1., 3.]], dtype=torch.float64)
73 | ```
74 | 
75 | 
76 | Install
77 | --------
78 | 
79 | Run `python setup.py install` or copy the `fast_soft_sort/` folder to your
80 | project.
81 | 
82 | 
83 | Reference
84 | ------------
85 | 
86 | > Fast Differentiable Sorting and Ranking
87 | > Mathieu Blondel, Olivier Teboul, Quentin Berthet, Josip Djolonga
88 | > In Proceedings of ICML 2020
89 | > [arXiv:2002.08871](https://arxiv.org/abs/2002.08871)
90 | 
--------------------------------------------------------------------------------
/fast_soft_sort/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/fast-soft-sort/6a52ce79869ab16e1e0f39149a84f50f8ad648c5/fast_soft_sort/__init__.py
--------------------------------------------------------------------------------
/fast_soft_sort/jax_ops.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """JAX operators for soft sorting and ranking.
16 | 
17 | Fast Differentiable Sorting and Ranking
18 | Mathieu Blondel, Olivier Teboul, Quentin Berthet, Josip Djolonga
19 | https://arxiv.org/abs/2002.08871
20 | """
21 | 
22 | from . import numpy_ops
23 | import jax
24 | import numpy as np
25 | import jax.numpy as jnp
26 | from jax import tree_util
27 | 
28 | 
29 | def _wrap_numpy_op(cls, **kwargs):
30 |   """Converts NumPy operator to a JAX one."""
31 | 
32 |   def _func_fwd(values):
33 |     """Converts values to numpy array, applies function and returns array."""
34 |     dtype = values.dtype
35 |     values = np.array(values)
36 |     obj = cls(values, **kwargs)
37 |     result = obj.compute()
38 |     return jnp.array(result, dtype=dtype), tree_util.Partial(obj.vjp)
39 | 
40 |   def _func_bwd(vjp, g):
41 |     g = np.array(g)
42 |     result = jnp.array(vjp(g), dtype=g.dtype)
43 |     return (result,)
44 | 
45 |   @jax.custom_vjp
46 |   def _func(values):
47 |     return _func_fwd(values)[0]
48 | 
49 |   _func.defvjp(_func_fwd, _func_bwd)
50 | 
51 |   return _func
52 | 
53 | 
54 | def soft_rank(values, direction="ASCENDING", regularization_strength=1.0,
55 |               regularization="l2"):
56 |   r"""Soft rank the given values (array) along the second axis.
57 | 
58 |   The regularization strength determines how close the returned values are
59 |   to the actual ranks.
60 | 
61 |   Args:
62 |     values: A 2d-array holding the numbers to be ranked.
63 |     direction: Either 'ASCENDING' or 'DESCENDING'.
64 |     regularization_strength: The regularization strength to be used. The smaller
65 |       this number, the closer the values are to the true ranks.
66 |     regularization: Which regularization method to use. It
67 |       must be set to one of ("l2", "kl").
68 |   Returns:
69 |     A 2d-array, soft-ranked along the second axis.
70 |   """
71 |   if len(values.shape) != 2:
72 |     raise ValueError("'values' should be a 2d-array "
73 |                      "but got %r." % values.shape)
74 | 
75 |   func = _wrap_numpy_op(numpy_ops.SoftRank,
76 |                         regularization_strength=regularization_strength,
77 |                         direction=direction,
78 |                         regularization=regularization)
79 | 
80 |   return jnp.vstack([func(val) for val in values])
81 | 
82 | 
83 | def soft_sort(values, direction="ASCENDING",
84 |               regularization_strength=1.0, regularization="l2"):
85 |   r"""Soft sort the given values (array) along the second axis.
86 | 
87 |   The regularization strength determines how close the returned values are
88 |   to the actual sorted values.
89 | 
90 |   Args:
91 |     values: A 2d-array holding the numbers to be sorted.
92 |     direction: Either 'ASCENDING' or 'DESCENDING'.
93 |     regularization_strength: The regularization strength to be used. The smaller
94 |       this number, the closer the values are to the true sorted values.
95 |     regularization: Which regularization method to use. It
96 |       must be set to one of ("l2", "kl").
97 |   Returns:
98 |     A 2d-array, soft-sorted along the second axis.
99 |   """
100 |   if len(values.shape) != 2:
101 |     raise ValueError("'values' should be a 2d-array "
102 |                      "but got %s." % str(values.shape))
103 | 
104 |   func = _wrap_numpy_op(numpy_ops.SoftSort,
105 |                         regularization_strength=regularization_strength,
106 |                         direction=direction,
107 |                         regularization=regularization)
108 | 
109 |   return jnp.vstack([func(val) for val in values])
--------------------------------------------------------------------------------
/fast_soft_sort/numpy_ops.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """Numpy operators for soft sorting and ranking.
16 | 
17 | Fast Differentiable Sorting and Ranking
18 | Mathieu Blondel, Olivier Teboul, Quentin Berthet, Josip Djolonga
19 | https://arxiv.org/abs/2002.08871
20 | 
21 | This implementation follows the notation of the paper whenever possible.
22 | """
23 | 
24 | from .third_party import isotonic
25 | import numpy as np
26 | from scipy import special
27 | 
28 | 
29 | def isotonic_l2(input_s, input_w=None):
30 |   """Solves an isotonic regression problem using PAV.
31 | 
32 |   Formally, it solves argmin_{v_1 >= ... >= v_n} 0.5 ||v - (s-w)||^2.
33 | 
34 |   Args:
35 |     input_s: input to isotonic regression, a 1d-array.
36 |     input_w: input to isotonic regression, a 1d-array.
37 |   Returns:
38 |     solution to the optimization problem.
39 |   """
40 |   if input_w is None:
41 |     input_w = np.arange(len(input_s))[::-1] + 1
42 |   input_w = input_w.astype(input_s.dtype)
43 |   solution = np.zeros_like(input_s)
44 |   isotonic.isotonic_l2(input_s - input_w, solution)
45 |   return solution
46 | 
47 | 
48 | def isotonic_kl(input_s, input_w=None):
49 |   """Solves isotonic optimization with KL divergence using PAV.
50 | 
51 |   Formally, it solves argmin_{v_1 >= ... >= v_n} <e^{input_s - v}, 1> + <e^{input_w}, v>.
52 | 
53 |   Args:
54 |     input_s: input to isotonic optimization, a 1d-array.
55 |     input_w: input to isotonic optimization, a 1d-array.
56 |   Returns:
57 |     solution to the optimization problem (same dtype as input_s).
58 |   """
59 |   if input_w is None:
60 |     input_w = np.arange(len(input_s))[::-1] + 1
61 |   input_w = input_w.astype(input_s.dtype)
62 |   solution = np.zeros(len(input_s)).astype(input_s.dtype)
63 |   isotonic.isotonic_kl(input_s, input_w.astype(input_s.dtype), solution)
64 |   return solution
65 | 
66 | 
67 | def _partition(solution, eps=1e-9):
68 |   """Returns partition corresponding to solution."""
69 |   # pylint: disable=g-explicit-length-test
70 |   if len(solution) == 0:
71 |     return []
72 | 
73 |   sizes = [1]
74 | 
75 |   for i in range(1, len(solution)):
76 |     if abs(solution[i] - solution[i - 1]) > eps:
77 |       sizes.append(0)
78 |     sizes[-1] += 1
79 | 
80 |   return sizes
81 | 
82 | 
83 | def _check_regularization(regularization):
84 |   if regularization not in ("l2", "kl"):
85 |     raise ValueError("'regularization' should be either 'l2' or 'kl' "
86 |                      "but got %s."
% str(regularization)) 87 | 88 | 89 | class _Differentiable(object): 90 | """Base class for differentiable operators.""" 91 | 92 | def jacobian(self): 93 | """Computes Jacobian.""" 94 | identity = np.eye(self.size) 95 | return np.array([self.jvp(identity[i]) for i in range(len(identity))]).T 96 | 97 | @property 98 | def size(self): 99 | raise NotImplementedError 100 | 101 | def compute(self): 102 | """Computes the desired quantity.""" 103 | raise NotImplementedError 104 | 105 | def jvp(self, vector): 106 | """Computes Jacobian vector product.""" 107 | raise NotImplementedError 108 | 109 | def vjp(self, vector): 110 | """Computes vector Jacobian product.""" 111 | raise NotImplementedError 112 | 113 | 114 | class Isotonic(_Differentiable): 115 | """Isotonic optimization.""" 116 | 117 | def __init__(self, input_s, input_w, regularization="l2"): 118 | self.input_s = input_s 119 | self.input_w = input_w 120 | _check_regularization(regularization) 121 | self.regularization = regularization 122 | self.solution_ = None 123 | 124 | @property 125 | def size(self): 126 | return len(self.input_s) 127 | 128 | def compute(self): 129 | 130 | if self.regularization == "l2": 131 | self.solution_ = isotonic_l2(self.input_s, self.input_w) 132 | else: 133 | self.solution_ = isotonic_kl(self.input_s, self.input_w) 134 | return self.solution_ 135 | 136 | def _check_computed(self): 137 | if self.solution_ is None: 138 | raise RuntimeError("Need to run compute() first.") 139 | 140 | def jvp(self, vector): 141 | self._check_computed() 142 | start = 0 143 | return_value = np.zeros_like(self.solution_) 144 | for size in _partition(self.solution_): 145 | end = start + size 146 | if self.regularization == "l2": 147 | val = np.mean(vector[start:end]) 148 | else: 149 | val = np.dot(special.softmax(self.input_s[start:end]), 150 | vector[start:end]) 151 | return_value[start:end] = val 152 | start = end 153 | return return_value 154 | 155 | def vjp(self, vector): 156 | start = 0 157 | return_value = np.zeros_like(self.solution_) 158 | for size in _partition(self.solution_): 159 | end = start + size 160 | if self.regularization == "l2": 161 | val = 1. 
/ size
162 |       else:
163 |         val = special.softmax(self.input_s[start:end])
164 |       return_value[start:end] = val * np.sum(vector[start:end])
165 |       start = end
166 |     return return_value
167 | 
168 | 
169 | def _inv_permutation(permutation):
170 |   """Returns inverse permutation of 'permutation'."""
171 |   inv_permutation = np.zeros(len(permutation), dtype=int)
172 |   inv_permutation[permutation] = np.arange(len(permutation))
173 |   return inv_permutation
174 | 
175 | 
176 | class Projection(_Differentiable):
177 |   """Computes projection onto the permutahedron P(w)."""
178 | 
179 |   def __init__(self, input_theta, input_w=None, regularization="l2"):
180 |     if input_w is None:
181 |       input_w = np.arange(len(input_theta))[::-1] + 1
182 |     self.input_theta = np.asarray(input_theta)
183 |     self.input_w = np.asarray(input_w)
184 |     _check_regularization(regularization)
185 |     self.regularization = regularization
186 |     self.isotonic_ = None
187 | 
188 |   def _check_computed(self):
189 |     if self.isotonic_ is None:
190 |       raise ValueError("Need to run compute() first.")
191 | 
192 |   @property
193 |   def size(self):
194 |     return len(self.input_theta)
195 | 
196 |   def compute(self):
197 |     self.permutation = np.argsort(self.input_theta)[::-1]
198 |     input_s = self.input_theta[self.permutation]
199 | 
200 |     self.isotonic_ = Isotonic(input_s, self.input_w, self.regularization)
201 |     dual_sol = self.isotonic_.compute()
202 |     primal_sol = input_s - dual_sol
203 | 
204 |     self.inv_permutation = _inv_permutation(self.permutation)
205 |     return primal_sol[self.inv_permutation]
206 | 
207 |   def jvp(self, vector):
208 |     self._check_computed()
209 |     ret = vector.copy()
210 |     ret -= self.isotonic_.jvp(vector[self.permutation])[self.inv_permutation]
211 |     return ret
212 | 
213 |   def vjp(self, vector):
214 |     self._check_computed()
215 |     ret = vector.copy()
216 |     ret -= self.isotonic_.vjp(vector[self.permutation])[self.inv_permutation]
217 |     return ret
218 | 
219 | 
220 | def _check_direction(direction):
221 |   if direction not in ("ASCENDING", "DESCENDING"):
222 |     raise ValueError("direction should be either 'ASCENDING' or 'DESCENDING'")
223 | 
224 | 
225 | class SoftRank(_Differentiable):
226 |   """Soft ranking."""
227 | 
228 |   def __init__(self, values, direction="ASCENDING",
229 |                regularization_strength=1.0, regularization="l2"):
230 |     self.values = np.asarray(values)
231 |     self.input_w = np.arange(len(values))[::-1] + 1
232 |     _check_direction(direction)
233 |     sign = 1 if direction == "ASCENDING" else -1
234 |     self.scale = sign / regularization_strength
235 |     _check_regularization(regularization)
236 |     self.regularization = regularization
237 |     self.projection_ = None
238 | 
239 |   @property
240 |   def size(self):
241 |     return len(self.values)
242 | 
243 |   def _check_computed(self):
244 |     if self.projection_ is None:
245 |       raise ValueError("Need to run compute() first.")
246 | 
247 |   def compute(self):
248 |     if self.regularization == "kl":
249 |       self.projection_ = Projection(
250 |           self.values * self.scale,
251 |           np.log(self.input_w),
252 |           regularization=self.regularization)
253 |       self.factor = np.exp(self.projection_.compute())
254 |       return self.factor
255 |     else:
256 |       self.projection_ = Projection(
257 |           self.values * self.scale, self.input_w,
258 |           regularization=self.regularization)
259 |       self.factor = 1.0
260 |       return self.projection_.compute()
261 | 
262 |   def jvp(self, vector):
263 |     self._check_computed()
264 |     return self.factor * self.projection_.jvp(vector) * self.scale
265 | 
266 |   def vjp(self, vector):
267 |     self._check_computed()
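    # Chain rule: for "l2", compute() returns projection(values * scale) and
    # factor == 1; for "kl" it returns exp(projection(values * scale)), so the
    # adjoint first multiplies the incoming vector by factor (derivative of
    # exp) and finally by scale (derivative of the input scaling).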
268 |     return self.projection_.vjp(self.factor * vector) * self.scale
269 | 
270 | 
271 | class SoftSort(_Differentiable):
272 |   """Soft sorting."""
273 | 
274 |   def __init__(self, values, direction="ASCENDING",
275 |                regularization_strength=1.0, regularization="l2"):
276 |     self.values = np.asarray(values)
277 |     _check_direction(direction)
278 |     self.sign = 1 if direction == "DESCENDING" else -1
279 |     self.regularization_strength = regularization_strength
280 |     _check_regularization(regularization)
281 |     self.regularization = regularization
282 |     self.isotonic_ = None
283 | 
284 |   @property
285 |   def size(self):
286 |     return len(self.values)
287 | 
288 |   def _check_computed(self):
289 |     if self.isotonic_ is None:
290 |       raise ValueError("Need to run compute() first.")
291 | 
292 |   def compute(self):
293 |     size = len(self.values)
294 |     input_w = np.arange(1, size + 1)[::-1] / self.regularization_strength
295 |     values = self.sign * self.values
296 |     self.permutation_ = np.argsort(values)[::-1]
297 |     s = values[self.permutation_]
298 | 
299 |     self.isotonic_ = Isotonic(input_w, s, regularization=self.regularization)
300 |     res = self.isotonic_.compute()
301 | 
302 |     # We set s as the first argument as we want the derivatives w.r.t. s.
303 |     self.isotonic_.input_s = s
304 |     return self.sign * (input_w - res)
305 | 
306 |   def jvp(self, vector):
307 |     self._check_computed()
308 |     return self.isotonic_.jvp(vector[self.permutation_])
309 | 
310 |   def vjp(self, vector):
311 |     self._check_computed()
312 |     inv_permutation = _inv_permutation(self.permutation_)
313 |     return self.isotonic_.vjp(vector)[inv_permutation]
314 | 
315 | 
316 | class Sort(_Differentiable):
317 |   """Hard sorting."""
318 | 
319 |   def __init__(self, values, direction="ASCENDING"):
320 |     _check_direction(direction)
321 |     self.values = np.asarray(values)
322 |     self.sign = 1 if direction == "DESCENDING" else -1
323 |     self.permutation_ = None
324 | 
325 |   @property
326 |   def size(self):
327 |     return len(self.values)
328 | 
329 |   def _check_computed(self):
330 |     if self.permutation_ is None:
331 |       raise ValueError("Need to run compute() first.")
332 | 
333 |   def compute(self):
334 |     self.permutation_ = np.argsort(self.sign * self.values)[::-1]
335 |     return self.values[self.permutation_]
336 | 
337 |   def jvp(self, vector):
338 |     self._check_computed()
339 |     return vector[self.permutation_]
340 | 
341 |   def vjp(self, vector):
342 |     self._check_computed()
343 |     inv_permutation = _inv_permutation(self.permutation_)
344 |     return vector[inv_permutation]
345 | 
346 | 
347 | # Small utility functions for the case when we just want the forward
348 | # computation.
349 | 
350 | 
351 | def soft_rank(values, direction="ASCENDING", regularization_strength=1.0,
352 |               regularization="l2"):
353 |   r"""Soft rank the given values.
354 | 
355 |   The regularization strength determines how close the returned values are
356 |   to the actual ranks.
357 | 
358 |   Args:
359 |     values: A 1d-array holding the numbers to be ranked.
360 |     direction: Either 'ASCENDING' or 'DESCENDING'.
361 |     regularization_strength: The regularization strength to be used. The smaller
362 |       this number, the closer the values are to the true ranks.
363 |     regularization: Which regularization method to use. It
364 |       must be set to one of ("l2", "kl").
365 |   Returns:
366 |     A 1d-array, soft-ranked.
367 |   """
368 |   return SoftRank(values, regularization_strength=regularization_strength,
369 |                   direction=direction, regularization=regularization).compute()
370 | 
371 | 
372 | def soft_sort(values, direction="ASCENDING", regularization_strength=1.0,
373 |               regularization="l2"):
374 |   r"""Soft sort the given values.
375 | 
376 |   Args:
377 |     values: A 1d-array holding the numbers to be sorted.
378 |     direction: Either 'ASCENDING' or 'DESCENDING'.
379 |     regularization_strength: The regularization strength to be used. The smaller
380 |       this number, the closer the values are to the true sorted values.
381 |     regularization: Which regularization method to use. It
382 |       must be set to one of ("l2", "kl").
383 |   Returns:
384 |     A 1d-array, soft-sorted.
385 |   """
386 |   return SoftSort(values, regularization_strength=regularization_strength,
387 |                   direction=direction, regularization=regularization).compute()
388 | 
389 | 
390 | def sort(values, direction="ASCENDING"):
391 |   r"""Sort the given values.
392 | 
393 |   Args:
394 |     values: A 1d-array holding the numbers to be sorted.
395 |     direction: Either 'ASCENDING' or 'DESCENDING'.
396 |   Returns:
397 |     A 1d-array, sorted.
398 |   """
399 |   return Sort(values, direction=direction).compute()
400 | 
401 | 
402 | def rank(values, direction="ASCENDING"):
403 |   r"""Rank the given values.
404 | 
405 |   Args:
406 |     values: A 1d-array holding the numbers to be ranked.
407 |     direction: Either 'ASCENDING' or 'DESCENDING'.
408 |   Returns:
409 |     A 1d-array, ranked.
410 |   """
411 |   permutation = np.argsort(values)
412 |   if direction == "DESCENDING":
413 |     permutation = permutation[::-1]
414 |   return _inv_permutation(permutation) + 1  # We use 1-based indexing.
415 | 
--------------------------------------------------------------------------------
/fast_soft_sort/pytorch_ops.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """PyTorch operators for soft sorting and ranking.
16 | 
17 | Fast Differentiable Sorting and Ranking
18 | Mathieu Blondel, Olivier Teboul, Quentin Berthet, Josip Djolonga
19 | https://arxiv.org/abs/2002.08871
20 | """
21 | 
22 | from . import numpy_ops
23 | import torch
24 | 
25 | 
26 | def wrap_class(cls, **kwargs):
27 |   """Wraps the given NumpyOp in a torch Function."""
28 | 
29 |   class NumpyOpWrapper(torch.autograd.Function):
30 |     """A torch Function wrapping a NumpyOp."""
31 | 
32 |     @staticmethod
33 |     def forward(ctx, values):
34 |       obj = cls(values.detach().numpy(), **kwargs)
35 |       ctx.numpy_obj = obj
36 |       return torch.from_numpy(obj.compute())
37 | 
38 |     @staticmethod
39 |     def backward(ctx, grad_output):
40 |       return torch.from_numpy(ctx.numpy_obj.vjp(grad_output.numpy()))
41 | 
42 |   return NumpyOpWrapper
43 | 
44 | 
45 | def map_tensor(map_fn, tensor):
46 |   return torch.stack([map_fn(tensor_i) for tensor_i in torch.unbind(tensor)])
47 | 
48 | 
49 | def soft_rank(values, direction="ASCENDING", regularization_strength=1.0,
50 |               regularization="l2"):
51 |   r"""Soft rank the given values (tensor) along the second axis.
52 | 
53 |   The regularization strength determines how close the returned values are
54 |   to the actual ranks.
55 | 
56 |   Args:
57 |     values: A 2d-tensor holding the numbers to be ranked.
58 |     direction: Either 'ASCENDING' or 'DESCENDING'.
59 |     regularization_strength: The regularization strength to be used. The smaller
60 |       this number, the closer the values are to the true ranks.
61 |     regularization: Which regularization method to use. It
62 |       must be set to one of ("l2", "kl").
63 |   Returns:
64 |     A 2d-tensor, soft-ranked along the second axis.
65 |   """
66 |   if len(values.shape) != 2:
67 |     raise ValueError("'values' should be a 2d-tensor "
68 |                      "but got %r." % values.shape)
69 | 
70 |   wrapped_fn = wrap_class(numpy_ops.SoftRank,
71 |                           regularization_strength=regularization_strength,
72 |                           direction=direction,
73 |                           regularization=regularization)
74 |   return map_tensor(wrapped_fn.apply, values)
75 | 
76 | 
77 | def soft_sort(values, direction="ASCENDING",
78 |               regularization_strength=1.0, regularization="l2"):
79 |   r"""Soft sort the given values (tensor) along the second axis.
80 | 
81 |   The regularization strength determines how close the returned values are
82 |   to the actual sorted values.
83 | 
84 |   Args:
85 |     values: A 2d-tensor holding the numbers to be sorted.
86 |     direction: Either 'ASCENDING' or 'DESCENDING'.
87 |     regularization_strength: The regularization strength to be used. The smaller
88 |       this number, the closer the values are to the true sorted values.
89 |     regularization: Which regularization method to use. It
90 |       must be set to one of ("l2", "kl").
91 |   Returns:
92 |     A 2d-tensor, soft-sorted along the second axis.
93 |   """
94 |   if len(values.shape) != 2:
95 |     raise ValueError("'values' should be a 2d-tensor "
96 |                      "but got %s." % str(values.shape))
97 | 
98 |   wrapped_fn = wrap_class(numpy_ops.SoftSort,
99 |                           regularization_strength=regularization_strength,
100 |                           direction=direction,
101 |                           regularization=regularization)
102 | 
103 |   return map_tensor(wrapped_fn.apply, values)
104 | 
--------------------------------------------------------------------------------
/fast_soft_sort/tf_ops.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """TensorFlow operators for soft sorting and ranking.
16 | 
17 | Fast Differentiable Sorting and Ranking
18 | Mathieu Blondel, Olivier Teboul, Quentin Berthet, Josip Djolonga
19 | https://arxiv.org/abs/2002.08871
20 | """
21 | 
22 | from . import numpy_ops
23 | import tensorflow.compat.v2 as tf
24 | 
25 | 
26 | def _wrap_numpy_op(cls, regularization_strength, direction, regularization):
27 |   """Converts NumPy operator to a TF one."""
28 | 
29 |   @tf.custom_gradient
30 |   def _func(values):
31 |     """Converts values to numpy array, applies function and returns tensor."""
32 |     dtype = values.dtype
33 | 
34 |     try:
35 |       values = values.numpy()
36 |     except AttributeError:
37 |       pass
38 | 
39 |     obj = cls(values, regularization_strength=regularization_strength,
40 |               direction=direction, regularization=regularization)
41 |     result = obj.compute()
42 | 
43 |     def grad(v):
44 |       v = v.numpy()
45 |       return tf.convert_to_tensor(obj.vjp(v), dtype=dtype)
46 | 
47 |     return tf.convert_to_tensor(result, dtype=dtype), grad
48 | 
49 |   return _func
50 | 
51 | 
52 | def soft_rank(values, direction="ASCENDING", regularization_strength=1.0,
53 |               regularization="l2"):
54 |   r"""Soft rank the given values (tensor) along the second axis.
55 | 
56 |   The regularization strength determines how close the returned values are
57 |   to the actual ranks.
58 | 
59 |   Args:
60 |     values: A 2d-tensor holding the numbers to be ranked.
61 |     direction: Either 'ASCENDING' or 'DESCENDING'.
62 |     regularization_strength: The regularization strength to be used. The smaller
63 |       this number, the closer the values are to the true ranks.
64 |     regularization: Which regularization method to use. It
65 |       must be set to one of ("l2", "kl").
66 |   Returns:
67 |     A 2d-tensor, soft-ranked along the second axis.
68 |   """
69 |   if len(values.shape) != 2:
70 |     raise ValueError("'values' should be a 2d-tensor "
71 |                      "but got %r." % values.shape)
72 | 
73 |   assert tf.executing_eagerly()
74 | 
75 |   func = _wrap_numpy_op(numpy_ops.SoftRank, regularization_strength, direction,
76 |                         regularization)
77 | 
78 |   return tf.map_fn(func, values)
79 | 
80 | 
81 | def soft_sort(values, direction="ASCENDING",
82 |               regularization_strength=1.0, regularization="l2"):
83 |   r"""Soft sort the given values (tensor) along the second axis.
84 | 
85 |   The regularization strength determines how close the returned values are
86 |   to the actual sorted values.
87 | 
88 |   Args:
89 |     values: A 2d-tensor holding the numbers to be sorted.
90 |     direction: Either 'ASCENDING' or 'DESCENDING'.
91 |     regularization_strength: The regularization strength to be used. The smaller
92 |       this number, the closer the values are to the true sorted values.
93 |     regularization: Which regularization method to use. It
94 |       must be set to one of ("l2", "kl").
95 |   Returns:
96 |     A 2d-tensor, soft-sorted along the second axis.
97 |   """
98 |   if len(values.shape) != 2:
99 |     raise ValueError("'values' should be a 2d-tensor "
100 |                      "but got %s."
% str(values.shape)) 101 | 102 | assert tf.executing_eagerly() 103 | 104 | func = _wrap_numpy_op(numpy_ops.SoftSort, regularization_strength, direction, 105 | regularization) 106 | 107 | return tf.map_fn(func, values) 108 | -------------------------------------------------------------------------------- /fast_soft_sort/third_party/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2007–2020 The scikit-learn developers. 2 | Copyright (c) 2020 Google LLC. 3 | All rights reserved. 4 | 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | a. Redistributions of source code must retain the above copyright notice, 10 | this list of conditions and the following disclaimer. 11 | b. Redistributions in binary form must reproduce the above copyright 12 | notice, this list of conditions and the following disclaimer in the 13 | documentation and/or other materials provided with the distribution. 14 | c. Neither the name of the Scikit-learn Developers nor the names of 15 | its contributors may be used to endorse or promote products 16 | derived from this software without specific prior written 17 | permission. 18 | 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 23 | ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR 24 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 28 | LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY 29 | OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH 30 | DAMAGE. -------------------------------------------------------------------------------- /fast_soft_sort/third_party/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/fast-soft-sort/6a52ce79869ab16e1e0f39149a84f50f8ad648c5/fast_soft_sort/third_party/__init__.py -------------------------------------------------------------------------------- /fast_soft_sort/third_party/isotonic.py: -------------------------------------------------------------------------------- 1 | # Copyright 2007-2020 The scikit-learn developers. 2 | # Copyright 2020 Google LLC. 3 | # All rights reserved. 4 | # 5 | # Redistribution and use in source and binary forms, with or without 6 | # modification, are permitted provided that the following conditions are met: 7 | # 8 | # a. Redistributions of source code must retain the above copyright notice, 9 | # this list of conditions and the following disclaimer. 10 | # b. Redistributions in binary form must reproduce the above copyright 11 | # notice, this list of conditions and the following disclaimer in the 12 | # documentation and/or other materials provided with the distribution. 13 | # c. Neither the name of the Scikit-learn Developers nor the names of 14 | # its contributors may be used to endorse or promote products 15 | # derived from this software without specific prior written 16 | # permission. 
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
21 | # ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR
22 | # ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
26 | # LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
27 | # OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
28 | # DAMAGE.
29 | 
30 | """Isotonic optimization routines in Numba."""
31 | 
32 | import warnings
33 | import numpy as np
34 | 
35 | # pylint: disable=g-import-not-at-top
36 | try:
37 |   from numba import njit
38 | except ImportError:
39 |   warnings.warn("Numba could not be imported. Code will run much more slowly."
40 |                 " To install, please run 'pip install numba'.")
41 | 
42 |   # If Numba is not available, we define a dummy 'njit' function.
43 |   def njit(func):
44 |     return func
45 | 
46 | 
47 | # Copied from scikit-learn with the following modifications:
48 | # - use decreasing constraints by default,
49 | # - do not return solution in place, rather save in array `sol`,
50 | # - avoid some needless multiplications.
51 | 
52 | 
53 | @njit
54 | def isotonic_l2(y, sol):
55 |   """Solves an isotonic regression problem using PAV.
56 | 
57 |   Formally, it solves argmin_{v_1 >= ... >= v_n} 0.5 ||v - y||^2.
58 | 
59 |   Args:
60 |     y: input to isotonic regression, a 1d-array.
61 |     sol: where to write the solution, an array of the same size as y.
62 |   """
63 |   n = y.shape[0]
64 |   target = np.arange(n)
65 |   c = np.ones(n)
66 |   sums = np.zeros(n)
67 | 
68 |   # target describes a list of blocks. At any time, if [i..j] (inclusive) is
69 |   # an active block, then target[i] := j and target[j] := i.
70 | 
71 |   for i in range(n):
72 |     sol[i] = y[i]
73 |     sums[i] = y[i]
74 | 
75 |   i = 0
76 |   while i < n:
77 |     k = target[i] + 1
78 |     if k == n:
79 |       break
80 |     if sol[i] > sol[k]:
81 |       i = k
82 |       continue
83 |     sum_y = sums[i]
84 |     sum_c = c[i]
85 |     while True:
86 |       # We are within an increasing subsequence.
87 |       prev_y = sol[k]
88 |       sum_y += sums[k]
89 |       sum_c += c[k]
90 |       k = target[k] + 1
91 |       if k == n or prev_y > sol[k]:
92 |         # Non-singleton increasing subsequence is finished,
93 |         # update first entry.
94 |         sol[i] = sum_y / sum_c
95 |         sums[i] = sum_y
96 |         c[i] = sum_c
97 |         target[i] = k - 1
98 |         target[k - 1] = i
99 |         if i > 0:
100 |           # Backtrack if we can. This makes the algorithm
101 |           # single-pass and ensures O(n) complexity.
102 |           i = target[i - 1]
103 |         # Otherwise, restart from the same point.
104 |         break
105 | 
106 |   # Reconstruct the solution.
107 |   i = 0
108 |   while i < n:
109 |     k = target[i] + 1
110 |     sol[i + 1 : k] = sol[i]
111 |     i = k
112 | 
113 | 
114 | @njit
115 | def _log_add_exp(x, y):
116 |   """Numerically stable log-add-exp."""
117 |   larger = max(x, y)
118 |   smaller = min(x, y)
119 |   return larger + np.log1p(np.exp(smaller - larger))
120 | 
121 | 
122 | # Modified implementation for the KL geometry case.
123 | @njit
124 | def isotonic_kl(y, w, sol):
125 |   """Solves isotonic optimization with KL divergence using PAV.
126 | 
127 |   Formally, it solves argmin_{v_1 >= ... >= v_n} <e^{y - v}, 1> + <e^{w}, v>.
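  Blockwise, PAV pools an increasing run into the common value
  log-sum-exp(y) - log-sum-exp(w) over the block, which the code below
  maintains incrementally with _log_add_exp.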
128 | 129 | Args: 130 | y: input to isotonic optimization, a 1d-array. 131 | w: input to isotonic optimization, a 1d-array. 132 | sol: where to write the solution, an array of the same size as y. 133 | """ 134 | n = y.shape[0] 135 | target = np.arange(n) 136 | lse_y_ = np.zeros(n) 137 | lse_w_ = np.zeros(n) 138 | 139 | # target describes a list of blocks. At any time, if [i..j] (inclusive) is 140 | # an active block, then target[i] := j and target[j] := i. 141 | 142 | for i in range(n): 143 | sol[i] = y[i] - w[i] 144 | lse_y_[i] = y[i] 145 | lse_w_[i] = w[i] 146 | 147 | i = 0 148 | while i < n: 149 | k = target[i] + 1 150 | if k == n: 151 | break 152 | if sol[i] > sol[k]: 153 | i = k 154 | continue 155 | lse_y = lse_y_[i] 156 | lse_w = lse_w_[i] 157 | while True: 158 | # We are within an increasing subsequence. 159 | prev_y = sol[k] 160 | lse_y = _log_add_exp(lse_y, lse_y_[k]) 161 | lse_w = _log_add_exp(lse_w, lse_w_[k]) 162 | k = target[k] + 1 163 | if k == n or prev_y > sol[k]: 164 | # Non-singleton increasing subsequence is finished, 165 | # update first entry. 166 | sol[i] = lse_y - lse_w 167 | lse_y_[i] = lse_y 168 | lse_w_[i] = lse_w 169 | target[i] = k - 1 170 | target[k - 1] = i 171 | if i > 0: 172 | # Backtrack if we can. This makes the algorithm 173 | # single-pass and ensures O(n) complexity. 174 | i = target[i - 1] 175 | # Otherwise, restart from the same point. 176 | break 177 | 178 | # Reconstruct the solution. 179 | i = 0 180 | while i < n: 181 | k = target[i] + 1 182 | sol[i + 1 : k] = sol[i] 183 | i = k 184 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 
15 | """Install fast_soft_sort."""
16 | 
17 | from setuptools import find_packages
18 | from setuptools import setup
19 | 
20 | setup(
21 |     name='fast_soft_sort',
22 |     version='0.1',
23 |     description=(
24 |         'Differentiable sorting and ranking in O(n log n).'),
25 |     author='Google LLC',
26 |     author_email='no-reply@google.com',
27 |     url='https://github.com/google-research/fast-soft-sort',
28 |     license='Apache 2.0',
29 |     packages=find_packages(),
30 |     package_data={},
31 |     install_requires=[
32 |         'numba',
33 |         'numpy',
34 |         'scipy>=1.2.0',
35 |     ],
36 |     extras_require={
37 |         'tf': ['tensorflow>=1.12'],
38 |     },
39 |     classifiers=[
40 |         'Intended Audience :: Science/Research',
41 |         'License :: OSI Approved :: Apache Software License',
42 |         'Topic :: Scientific/Engineering :: Artificial Intelligence',
43 |     ],
44 |     keywords='machine learning sorting ranking',
45 | )
46 | 
--------------------------------------------------------------------------------
/tests/isotonic_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """Tests for isotonic.py."""
16 | 
17 | import unittest
18 | from absl.testing import parameterized
19 | from fast_soft_sort.third_party import isotonic
20 | import numpy as np
21 | from sklearn.isotonic import isotonic_regression
22 | 
23 | 
24 | class IsotonicTest(parameterized.TestCase):
25 | 
26 |   def test_l2_agrees_with_sklearn(self):
27 |     rng = np.random.RandomState(0)
28 |     y = rng.randn(10) * rng.randint(1, 5)
29 |     sol = np.zeros_like(y)
30 |     isotonic.isotonic_l2(y, sol)
31 |     sol_skl = isotonic_regression(y, increasing=False)
32 |     np.testing.assert_array_almost_equal(sol, sol_skl)
33 | 
34 | 
35 | if __name__ == "__main__":
36 |   unittest.main()
37 | 
--------------------------------------------------------------------------------
/tests/jax_ops_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
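# The directional-derivative check used below: for a random direction U of
# unit Frobenius norm, the finite difference
#   (loss(values + 0.5 * eps * U) - loss(values - 0.5 * eps * U)) / eps
# should agree with the autodiff inner product <grad(loss)(values), U>.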
14 | 15 | """Tests for jax_ops.py.""" 16 | 17 | import functools 18 | import itertools 19 | import unittest 20 | 21 | from absl.testing import absltest 22 | from absl.testing import parameterized 23 | 24 | import numpy as np 25 | import jax.numpy as jnp 26 | import jax 27 | 28 | from jax.config import config 29 | config.update("jax_enable_x64", True) 30 | 31 | from fast_soft_sort import jax_ops 32 | 33 | GAMMAS = (0.1, 1, 10.0) 34 | DIRECTIONS = ("ASCENDING", "DESCENDING") 35 | REGULARIZERS = ("l2", ) 36 | 37 | 38 | class JaxOpsTest(parameterized.TestCase): 39 | 40 | def _test(self, func, regularization_strength, direction, regularization): 41 | 42 | def loss_func(values): 43 | soft_values = func(values, 44 | regularization_strength=regularization_strength, 45 | direction=direction, 46 | regularization=regularization) 47 | return jnp.sum(soft_values ** 2) 48 | 49 | rng = np.random.RandomState(0) 50 | values = jnp.array(rng.randn(5, 10)) 51 | mat = jnp.array(rng.randn(5, 10)) 52 | unitmat = mat / np.sqrt(np.vdot(mat, mat)) 53 | eps = 1e-5 54 | numerical = (loss_func(values + 0.5 * eps * unitmat) - 55 | loss_func(values - 0.5 * eps * unitmat)) / eps 56 | autodiff = jnp.vdot(jax.grad(loss_func)(values), unitmat) 57 | np.testing.assert_almost_equal(numerical, autodiff) 58 | 59 | 60 | @parameterized.parameters(itertools.product(GAMMAS, DIRECTIONS, REGULARIZERS)) 61 | def test_soft_rank(self, regularization_strength, direction, regularization): 62 | self._test(jax_ops.soft_rank, 63 | regularization_strength, direction, regularization) 64 | 65 | @parameterized.parameters(itertools.product(GAMMAS, DIRECTIONS, REGULARIZERS)) 66 | def test_soft_sort(self, regularization_strength, direction, regularization): 67 | self._test(jax_ops.soft_sort, 68 | regularization_strength, direction, regularization) 69 | 70 | 71 | if __name__ == "__main__": 72 | absltest.main() 73 | -------------------------------------------------------------------------------- /tests/numpy_ops_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
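# The tests below compare the analytical jvp/vjp/jacobian implementations
# against a numerical Jacobian built column by column from central finite
# differences (see _num_jacobian).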
14 | 15 | """Tests for numpy_ops.py.""" 16 | 17 | import itertools 18 | import unittest 19 | from absl.testing import parameterized 20 | from fast_soft_sort import numpy_ops 21 | import numpy as np 22 | 23 | 24 | def _num_jacobian(theta, f, eps=1e-9): 25 | n_classes = len(theta) 26 | ret = np.zeros((n_classes, n_classes)) 27 | 28 | for i in range(n_classes): 29 | theta_ = theta.copy() 30 | theta_[i] += eps 31 | val = f(theta_) 32 | theta_[i] -= 2 * eps 33 | val2 = f(theta_) 34 | ret[i] = (val - val2) / (2 * eps) 35 | 36 | return ret.T 37 | 38 | 39 | GAMMAS = (0.1, 1.0, 10.0) 40 | DIRECTIONS = ("ASCENDING", "DESCENDING") 41 | REGULARIZERS = ("l2", "kl") 42 | CLASSES = (numpy_ops.Isotonic, numpy_ops.Projection) 43 | 44 | 45 | class IsotonicProjectionTest(parameterized.TestCase): 46 | 47 | @parameterized.parameters(itertools.product(CLASSES, REGULARIZERS)) 48 | def test_jvp_and_vjp_against_numerical_jacobian(self, cls, regularization): 49 | rng = np.random.RandomState(0) 50 | theta = rng.randn(5) 51 | w = np.arange(5)[::-1] 52 | v = rng.randn(5) 53 | 54 | f = lambda x: cls(x, w, regularization=regularization).compute() 55 | J = _num_jacobian(theta, f) 56 | 57 | obj = cls(theta, w, regularization=regularization) 58 | obj.compute() 59 | 60 | out = obj.jvp(v) 61 | np.testing.assert_array_almost_equal(J.dot(v), out) 62 | 63 | out = obj.vjp(v) 64 | np.testing.assert_array_almost_equal(v.dot(J), out) 65 | 66 | 67 | class SoftRankTest(parameterized.TestCase): 68 | 69 | @parameterized.parameters(itertools.product(DIRECTIONS, REGULARIZERS)) 70 | def test_soft_rank_converges_to_hard(self, direction, regularization): 71 | rng = np.random.RandomState(0) 72 | theta = rng.randn(5) 73 | soft_rank = numpy_ops.SoftRank(theta, regularization_strength=1e-3, 74 | direction=direction, 75 | regularization=regularization) 76 | out = numpy_ops.rank(theta, direction=direction) 77 | out2 = soft_rank.compute() 78 | np.testing.assert_array_almost_equal(out, out2) 79 | 80 | @parameterized.parameters(itertools.product(GAMMAS, DIRECTIONS, REGULARIZERS)) 81 | def test_soft_rank_jvp_and_vjp_against_numerical_jacobian(self, 82 | regularization_strength, 83 | direction, 84 | regularization): 85 | rng = np.random.RandomState(0) 86 | theta = rng.randn(5) 87 | v = rng.randn(5) 88 | 89 | f = lambda x: numpy_ops.SoftRank( 90 | x, regularization_strength=regularization_strength, direction=direction, 91 | regularization=regularization).compute() 92 | J = _num_jacobian(theta, f) 93 | 94 | soft_rank = numpy_ops.SoftRank( 95 | theta, regularization_strength=regularization_strength, 96 | direction=direction, regularization=regularization) 97 | soft_rank.compute() 98 | 99 | out = soft_rank.jvp(v) 100 | np.testing.assert_array_almost_equal(J.dot(v), out, 1e-6) 101 | 102 | out = soft_rank.vjp(v) 103 | np.testing.assert_array_almost_equal(v.dot(J), out, 1e-6) 104 | 105 | out = soft_rank.jacobian() 106 | np.testing.assert_array_almost_equal(J, out, 1e-6) 107 | 108 | @parameterized.parameters(itertools.product(GAMMAS, DIRECTIONS, REGULARIZERS)) 109 | def test_soft_rank_works_with_lists(self, regularization_strength, direction, 110 | regularization): 111 | rng = np.random.RandomState(0) 112 | theta = rng.randn(5) 113 | ranks1 = numpy_ops.SoftRank(theta, 114 | regularization_strength=regularization_strength, 115 | direction=direction, 116 | regularization=regularization).compute() 117 | ranks2 = numpy_ops.SoftRank(list(theta), 118 | regularization_strength=regularization_strength, 119 | direction=direction, 120 | 
regularization=regularization).compute() 121 | np.testing.assert_array_almost_equal(ranks1, ranks2) 122 | 123 | 124 | class SoftSortTest(parameterized.TestCase): 125 | 126 | @parameterized.parameters(itertools.product(DIRECTIONS, REGULARIZERS)) 127 | def test_soft_sort_converges_to_hard(self, direction, regularization): 128 | rng = np.random.RandomState(0) 129 | theta = rng.randn(5) 130 | soft_sort = numpy_ops.SoftSort( 131 | theta, regularization_strength=1e-3, direction=direction, 132 | regularization=regularization) 133 | sort = numpy_ops.Sort(theta, direction=direction) 134 | out = sort.compute() 135 | out2 = soft_sort.compute() 136 | np.testing.assert_array_almost_equal(out, out2) 137 | 138 | @parameterized.parameters(itertools.product(GAMMAS, DIRECTIONS, REGULARIZERS)) 139 | def test_soft_sort_jvp(self, regularization_strength, direction, 140 | regularization): 141 | rng = np.random.RandomState(0) 142 | theta = rng.randn(5) 143 | v = rng.randn(5) 144 | 145 | f = lambda x: numpy_ops.SoftSort( 146 | x, regularization_strength=regularization_strength, 147 | direction=direction, regularization=regularization).compute() 148 | J = _num_jacobian(theta, f) 149 | 150 | soft_sort = numpy_ops.SoftSort( 151 | theta, regularization_strength=regularization_strength, 152 | direction=direction, regularization=regularization) 153 | soft_sort.compute() 154 | 155 | out = soft_sort.jvp(v) 156 | np.testing.assert_array_almost_equal(J.dot(v), out, decimal=6) 157 | 158 | out = soft_sort.vjp(v) 159 | np.testing.assert_array_almost_equal(v.dot(J), out, decimal=6) 160 | 161 | out = soft_sort.jacobian() 162 | np.testing.assert_array_almost_equal(J, out, decimal=6) 163 | 164 | @parameterized.parameters(itertools.product(GAMMAS, DIRECTIONS, REGULARIZERS)) 165 | def test_soft_sort_works_with_lists(self, regularization_strength, direction, 166 | regularization): 167 | rng = np.random.RandomState(0) 168 | theta = rng.randn(5) 169 | sort1 = numpy_ops.SoftSort(theta, 170 | regularization_strength=regularization_strength, 171 | direction=direction, 172 | regularization=regularization).compute() 173 | sort2 = numpy_ops.SoftSort(list(theta), 174 | regularization_strength=regularization_strength, 175 | direction=direction, 176 | regularization=regularization).compute() 177 | np.testing.assert_array_almost_equal(sort1, sort2) 178 | 179 | 180 | class SortTest(parameterized.TestCase): 181 | 182 | @parameterized.parameters(itertools.product(DIRECTIONS)) 183 | def test_sort_jvp(self, direction): 184 | rng = np.random.RandomState(0) 185 | theta = rng.randn(5) 186 | v = rng.randn(5) 187 | 188 | f = lambda x: numpy_ops.Sort(x, direction=direction).compute() 189 | J = _num_jacobian(theta, f) 190 | 191 | sort = numpy_ops.Sort(theta, direction=direction) 192 | sort.compute() 193 | 194 | out = sort.jvp(v) 195 | np.testing.assert_array_almost_equal(J.dot(v), out) 196 | 197 | out = sort.vjp(v) 198 | np.testing.assert_array_almost_equal(v.dot(J), out) 199 | 200 | @parameterized.parameters(itertools.product(DIRECTIONS)) 201 | def test_sort_works_with_lists(self, direction): 202 | rng = np.random.RandomState(0) 203 | theta = rng.randn(5) 204 | sort_numpy = numpy_ops.Sort(theta, direction=direction).compute() 205 | sort_list = numpy_ops.Sort(list(theta), direction=direction).compute() 206 | np.testing.assert_array_almost_equal(sort_numpy, sort_list) 207 | 208 | 209 | if __name__ == "__main__": 210 | unittest.main() 211 | -------------------------------------------------------------------------------- /tests/pytorch_ops_test.py:
-------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for pytorch_ops.py.""" 16 | 17 | import functools 18 | import itertools 19 | import unittest 20 | 21 | from absl.testing import parameterized 22 | from fast_soft_sort import pytorch_ops 23 | import torch 24 | 25 | 26 | GAMMAS = (0.1, 1.0, 10.0) 27 | DIRECTIONS = ("ASCENDING", "DESCENDING") 28 | REGULARIZERS = ("l2",) # The kl case is unstable for gradcheck. 29 | DTYPES = (torch.float64,) # gradcheck needs double precision to be reliable. 30 | 31 | 32 | class PyTorchOpsTest(parameterized.TestCase): 33 | 34 | def _test(self, func, regularization_strength, direction, regularization, 35 | dtype, atol=1e-3, rtol=1e-3, eps=1e-5): 36 | x = torch.randn(5, 10, dtype=dtype, requires_grad=True) 37 | 38 | func = functools.partial( 39 | func, regularization_strength=regularization_strength, 40 | direction=direction, regularization=regularization) 41 | 42 | torch.autograd.gradcheck(func, [x], eps=eps, atol=atol, rtol=rtol) 43 | 44 | def _compute_loss(x): 45 | y = func(x) # The partial above already binds the keyword 46 | # arguments, so they are not repeated here. 47 | return torch.sum(y**2) 48 | 49 | torch.autograd.gradcheck(_compute_loss, x, eps=eps, atol=atol, rtol=rtol) 50 | 51 | @parameterized.parameters(itertools.product(GAMMAS, DIRECTIONS, REGULARIZERS, 52 | DTYPES)) 53 | def test_rank_gradient(self, regularization_strength, direction, 54 | regularization, dtype): 55 | self._test(pytorch_ops.soft_rank, regularization_strength, direction, 56 | regularization, dtype) 57 | 58 | @parameterized.parameters(itertools.product(GAMMAS, DIRECTIONS, REGULARIZERS, 59 | DTYPES)) 60 | def test_sort_gradient(self, regularization_strength, direction, 61 | regularization, dtype): 62 | self._test(pytorch_ops.soft_sort, regularization_strength, direction, 63 | regularization, dtype) 64 | 65 | 66 | if __name__ == "__main__": 67 | unittest.main() 68 | -------------------------------------------------------------------------------- /tests/tf_ops_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | 15 | """Tests for tf_ops.py.""" 16 | 17 | import functools 18 | import itertools 19 | from absl.testing import parameterized 20 | from fast_soft_sort import tf_ops 21 | import tensorflow.compat.v2 as tf 22 | 23 | 24 | GAMMAS = (0.1, 1.0, 10.0) 25 | DIRECTIONS = ("ASCENDING", "DESCENDING") 26 | REGULARIZERS = ("l2", "kl") 27 | DTYPES = (tf.float64,) 28 | 29 | 30 | class TfOpsTest(parameterized.TestCase, tf.test.TestCase): 31 | 32 | def _test(self, func, regularization_strength, direction, regularization, 33 | dtype): 34 | 35 | precision = 1e-6 36 | delta = 1e-4 37 | 38 | x = tf.random.normal((5, 10), dtype=dtype) 39 | 40 | func = functools.partial( 41 | func, regularization_strength=regularization_strength, 42 | direction=direction, regularization=regularization) 43 | 44 | grad_theoretical, grad_numerical = tf.test.compute_gradient( 45 | func, [x], delta=delta) 46 | 47 | self.assertAllClose(grad_theoretical[0], grad_numerical[0], precision) 48 | 49 | def _compute_loss(x): 50 | y = func(x, regularization_strength=regularization_strength, 51 | direction=direction, regularization=regularization) 52 | return tf.reduce_mean(y**2) 53 | 54 | grad_theoretical, grad_numerical = tf.test.compute_gradient( 55 | _compute_loss, [x], delta=delta) 56 | 57 | self.assertAllClose(grad_theoretical[0], grad_numerical[0], precision) 58 | 59 | @parameterized.parameters(itertools.product(GAMMAS, DIRECTIONS, REGULARIZERS, 60 | DTYPES)) 61 | def test_rank_gradient(self, regularization_strength, direction, 62 | regularization, dtype): 63 | self._test(tf_ops.soft_rank, regularization_strength, direction, 64 | regularization, dtype) 65 | 66 | @parameterized.parameters(itertools.product(GAMMAS, DIRECTIONS, REGULARIZERS, 67 | DTYPES)) 68 | def test_sort_gradient(self, regularization_strength, direction, 69 | regularization, dtype): 70 | if regularization == "l2" or regularization_strength < 10: 71 | # We skip regularization_strength >= 10 when regularization = "kl", 72 | # due to numerical instability. 73 | self._test(tf_ops.soft_sort, regularization_strength, direction, 74 | regularization, dtype) 75 | 76 | 77 | if __name__ == "__main__": 78 | tf.enable_v2_behavior() 79 | tf.test.main() 80 | --------------------------------------------------------------------------------