├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── fast_soft_sort
│   ├── __init__.py
│   ├── jax_ops.py
│   ├── numpy_ops.py
│   ├── pytorch_ops.py
│   ├── tf_ops.py
│   └── third_party
│       ├── LICENSE
│       ├── __init__.py
│       └── isotonic.py
├── setup.py
└── tests
    ├── isotonic_test.py
    ├── jax_ops_test.py
    ├── numpy_ops_test.py
    ├── pytorch_ops_test.py
    └── tf_ops_test.py
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement (CLA). You (or your employer) retain the copyright to your
10 | contribution; this simply gives us permission to use and redistribute your
11 | contributions as part of the project. Head over to
12 | <https://cla.developers.google.com/> to see your current agreements on file
13 | or to sign a new one.
14 |
15 | You generally only need to submit a CLA once, so if you've already submitted one
16 | (even if it was for a different project), you probably don't need to do it
17 | again.
18 |
19 | ## Code reviews
20 |
21 | All submissions, including submissions by project members, require review. We
22 | use GitHub pull requests for this purpose. Consult
23 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
24 | information on using pull requests.
25 |
26 | ## Community Guidelines
27 |
28 | This project follows
29 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/).
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Version 2.0, January 2004
2 | http://www.apache.org/licenses/
3 |
4 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
5 |
6 | 1. Definitions.
7 |
8 | "License" shall mean the terms and conditions for use, reproduction,
9 | and distribution as defined by Sections 1 through 9 of this document.
10 |
11 | "Licensor" shall mean the copyright owner or entity authorized by
12 | the copyright owner that is granting the License.
13 |
14 | "Legal Entity" shall mean the union of the acting entity and all
15 | other entities that control, are controlled by, or are under common
16 | control with that entity. For the purposes of this definition,
17 | "control" means (i) the power, direct or indirect, to cause the
18 | direction or management of such entity, whether by contract or
19 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
20 | outstanding shares, or (iii) beneficial ownership of such entity.
21 |
22 | "You" (or "Your") shall mean an individual or Legal Entity
23 | exercising permissions granted by this License.
24 |
25 | "Source" form shall mean the preferred form for making modifications,
26 | including but not limited to software source code, documentation
27 | source, and configuration files.
28 |
29 | "Object" form shall mean any form resulting from mechanical
30 | transformation or translation of a Source form, including but
31 | not limited to compiled object code, generated documentation,
32 | and conversions to other media types.
33 |
34 | "Work" shall mean the work of authorship, whether in Source or
35 | Object form, made available under the License, as indicated by a
36 | copyright notice that is included in or attached to the work
37 | (an example is provided in the Appendix below).
38 |
39 | "Derivative Works" shall mean any work, whether in Source or Object
40 | form, that is based on (or derived from) the Work and for which the
41 | editorial revisions, annotations, elaborations, or other modifications
42 | represent, as a whole, an original work of authorship. For the purposes
43 | of this License, Derivative Works shall not include works that remain
44 | separable from, or merely link (or bind by name) to the interfaces of,
45 | the Work and Derivative Works thereof.
46 |
47 | "Contribution" shall mean any work of authorship, including
48 | the original version of the Work and any modifications or additions
49 | to that Work or Derivative Works thereof, that is intentionally
50 | submitted to Licensor for inclusion in the Work by the copyright owner
51 | or by an individual or Legal Entity authorized to submit on behalf of
52 | the copyright owner. For the purposes of this definition, "submitted"
53 | means any form of electronic, verbal, or written communication sent
54 | to the Licensor or its representatives, including but not limited to
55 | communication on electronic mailing lists, source code control systems,
56 | and issue tracking systems that are managed by, or on behalf of, the
57 | Licensor for the purpose of discussing and improving the Work, but
58 | excluding communication that is conspicuously marked or otherwise
59 | designated in writing by the copyright owner as "Not a Contribution."
60 |
61 | "Contributor" shall mean Licensor and any individual or Legal Entity
62 | on behalf of whom a Contribution has been received by Licensor and
63 | subsequently incorporated within the Work.
64 |
65 | 2. Grant of Copyright License. Subject to the terms and conditions of
66 | this License, each Contributor hereby grants to You a perpetual,
67 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
68 | copyright license to reproduce, prepare Derivative Works of,
69 | publicly display, publicly perform, sublicense, and distribute the
70 | Work and such Derivative Works in Source or Object form.
71 |
72 | 3. Grant of Patent License. Subject to the terms and conditions of
73 | this License, each Contributor hereby grants to You a perpetual,
74 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
75 | (except as stated in this section) patent license to make, have made,
76 | use, offer to sell, sell, import, and otherwise transfer the Work,
77 | where such license applies only to those patent claims licensable
78 | by such Contributor that are necessarily infringed by their
79 | Contribution(s) alone or by combination of their Contribution(s)
80 | with the Work to which such Contribution(s) was submitted. If You
81 | institute patent litigation against any entity (including a
82 | cross-claim or counterclaim in a lawsuit) alleging that the Work
83 | or a Contribution incorporated within the Work constitutes direct
84 | or contributory patent infringement, then any patent licenses
85 | granted to You under this License for that Work shall terminate
86 | as of the date such litigation is filed.
87 |
88 | 4. Redistribution. You may reproduce and distribute copies of the
89 | Work or Derivative Works thereof in any medium, with or without
90 | modifications, and in Source or Object form, provided that You
91 | meet the following conditions:
92 |
93 | (a) You must give any other recipients of the Work or
94 | Derivative Works a copy of this License; and
95 |
96 | (b) You must cause any modified files to carry prominent notices
97 | stating that You changed the files; and
98 |
99 | (c) You must retain, in the Source form of any Derivative Works
100 | that You distribute, all copyright, patent, trademark, and
101 | attribution notices from the Source form of the Work,
102 | excluding those notices that do not pertain to any part of
103 | the Derivative Works; and
104 |
105 | (d) If the Work includes a "NOTICE" text file as part of its
106 | distribution, then any Derivative Works that You distribute must
107 | include a readable copy of the attribution notices contained
108 | within such NOTICE file, excluding those notices that do not
109 | pertain to any part of the Derivative Works, in at least one
110 | of the following places: within a NOTICE text file distributed
111 | as part of the Derivative Works; within the Source form or
112 | documentation, if provided along with the Derivative Works; or,
113 | within a display generated by the Derivative Works, if and
114 | wherever such third-party notices normally appear. The contents
115 | of the NOTICE file are for informational purposes only and
116 | do not modify the License. You may add Your own attribution
117 | notices within Derivative Works that You distribute, alongside
118 | or as an addendum to the NOTICE text from the Work, provided
119 | that such additional attribution notices cannot be construed
120 | as modifying the License.
121 |
122 | You may add Your own copyright statement to Your modifications and
123 | may provide additional or different license terms and conditions
124 | for use, reproduction, or distribution of Your modifications, or
125 | for any such Derivative Works as a whole, provided Your use,
126 | reproduction, and distribution of the Work otherwise complies with
127 | the conditions stated in this License.
128 |
129 | 5. Submission of Contributions. Unless You explicitly state otherwise,
130 | any Contribution intentionally submitted for inclusion in the Work
131 | by You to the Licensor shall be under the terms and conditions of
132 | this License, without any additional terms or conditions.
133 | Notwithstanding the above, nothing herein shall supersede or modify
134 | the terms of any separate license agreement you may have executed
135 | with Licensor regarding such Contributions.
136 |
137 | 6. Trademarks. This License does not grant permission to use the trade
138 | names, trademarks, service marks, or product names of the Licensor,
139 | except as required for reasonable and customary use in describing the
140 | origin of the Work and reproducing the content of the NOTICE file.
141 |
142 | 7. Disclaimer of Warranty. Unless required by applicable law or
143 | agreed to in writing, Licensor provides the Work (and each
144 | Contributor provides its Contributions) on an "AS IS" BASIS,
145 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
146 | implied, including, without limitation, any warranties or conditions
147 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
148 | PARTICULAR PURPOSE. You are solely responsible for determining the
149 | appropriateness of using or redistributing the Work and assume any
150 | risks associated with Your exercise of permissions under this License.
151 |
152 | 8. Limitation of Liability. In no event and under no legal theory,
153 | whether in tort (including negligence), contract, or otherwise,
154 | unless required by applicable law (such as deliberate and grossly
155 | negligent acts) or agreed to in writing, shall any Contributor be
156 | liable to You for damages, including any direct, indirect, special,
157 | incidental, or consequential damages of any character arising as a
158 | result of this License or out of the use or inability to use the
159 | Work (including but not limited to damages for loss of goodwill,
160 | work stoppage, computer failure or malfunction, or any and all
161 | other commercial damages or losses), even if such Contributor
162 | has been advised of the possibility of such damages.
163 |
164 | 9. Accepting Warranty or Additional Liability. While redistributing
165 | the Work or Derivative Works thereof, You may choose to offer,
166 | and charge a fee for, acceptance of support, warranty, indemnity,
167 | or other liability obligations and/or rights consistent with this
168 | License. However, in accepting such obligations, You may act only
169 | on Your own behalf and on Your sole responsibility, not on behalf
170 | of any other Contributor, and only if You agree to indemnify,
171 | defend, and hold each Contributor harmless for any liability
172 | incurred by, or claims asserted against, such Contributor by reason
173 | of your accepting any such warranty or additional liability.
174 |
175 | END OF TERMS AND CONDITIONS
176 |
177 | APPENDIX: How to apply the Apache License to your work.
178 |
179 | To apply the Apache License to your work, attach the following
180 | boilerplate notice, with the fields enclosed by brackets "[]"
181 | replaced with your own identifying information. (Don't include
182 | the brackets!) The text should be enclosed in the appropriate
183 | comment syntax for the file format. We also recommend that a
184 | file or class name and description of purpose be included on the
185 | same "printed page" as the copyright notice for easier
186 | identification within third-party archives.
187 |
188 | Copyright [yyyy] [name of copyright owner]
189 |
190 | Licensed under the Apache License, Version 2.0 (the "License");
191 | you may not use this file except in compliance with the License.
192 | You may obtain a copy of the License at
193 |
194 | http://www.apache.org/licenses/LICENSE-2.0
195 |
196 | Unless required by applicable law or agreed to in writing, software
197 | distributed under the License is distributed on an "AS IS" BASIS,
198 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
199 | See the License for the specific language governing permissions and
200 | limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | Fast Differentiable Sorting and Ranking
3 | =======================================
4 |
5 | Differentiable sorting and ranking operations in O(n log n).
6 |
7 | Dependencies
8 | ------------
9 |
10 | * NumPy
11 | * SciPy
12 | * Numba
13 | * Tensorflow (optional)
14 | * PyTorch (optional)
15 |
16 | TensorFlow Example
17 | -------------------
18 |
19 | ```python
20 | >>> import tensorflow as tf
21 | >>> from fast_soft_sort.tf_ops import soft_rank, soft_sort
22 | >>> values = tf.convert_to_tensor([[5., 1., 2.], [2., 1., 5.]], dtype=tf.float64)
23 | >>> soft_sort(values, regularization_strength=1.0)
24 |
25 | >>> soft_sort(values, regularization_strength=0.1)
26 |
27 | >>> soft_rank(values, regularization_strength=2.0)
28 |
29 | >>> soft_rank(values, regularization_strength=1.0)
30 |
31 | ```
32 |
33 | JAX Example
34 | -----------
35 |
36 | ```python
37 | >>> import jax.numpy as jnp
38 | >>> from fast_soft_sort.jax_ops import soft_rank, soft_sort
39 | >>> values = jnp.array([[5., 1., 2.], [2., 1., 5.]], dtype=jnp.float64)
40 | >>> soft_sort(values, regularization_strength=1.0)
41 | [[1.66666667 2.66666667 3.66666667]
42 | [1.66666667 2.66666667 3.66666667]]
43 | >>> soft_sort(values, regularization_strength=0.1)
44 | [[1. 2. 5.]
45 | [1. 2. 5.]]
46 | >>> soft_rank(values, regularization_strength=2.0)
47 | [[3. 1.25 1.75]
48 | [1.75 1.25 3. ]]
49 | >>> soft_rank(values, regularization_strength=1.0)
50 | [[3. 1. 2.]
51 | [2. 1. 3.]]
52 | ```
53 |
54 | PyTorch Example
55 | ---------------
56 |
57 | ```python
58 | >>> import torch
59 | >>> from fast_soft_sort.pytorch_ops import soft_rank, soft_sort
60 | >>> values = torch.tensor([[5., 1., 2.], [2., 1., 5.]], dtype=torch.float64)
61 | >>> soft_sort(values, regularization_strength=1.0)
62 | tensor([[1.6667, 2.6667, 3.6667]
63 | [1.6667, 2.6667, 3.6667]], dtype=torch.float64)
64 | >>> soft_sort(values, regularization_strength=0.1)
65 | tensor([[1., 2., 5.]
66 | [1., 2., 5.]], dtype=torch.float64)
67 | >>> soft_rank(values, regularization_strength=2.0)
68 | tensor([[3.0000, 1.2500, 1.7500],
69 | [1.7500, 1.2500, 3.0000]], dtype=torch.float64)
70 | >>> soft_rank(values, regularization_strength=1.0)
71 | tensor([[3., 1., 2.]
72 | [2., 1., 3.]], dtype=torch.float64)
73 | ```
74 |
75 |
76 | Install
77 | --------
78 |
79 | Run `python setup.py install` or copy the `fast_soft_sort/` folder to your
80 | project.
81 |
82 |
83 | Reference
84 | ------------
85 |
86 | > Fast Differentiable Sorting and Ranking
87 | > Mathieu Blondel, Olivier Teboul, Quentin Berthet, Josip Djolonga
88 | > In proceedings of ICML 2020
89 | > [arXiv:2002.08871](https://arxiv.org/abs/2002.08871)
90 |
--------------------------------------------------------------------------------
/fast_soft_sort/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/fast-soft-sort/6a52ce79869ab16e1e0f39149a84f50f8ad648c5/fast_soft_sort/__init__.py
--------------------------------------------------------------------------------
/fast_soft_sort/jax_ops.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """JAX operators for soft sorting and ranking.
16 |
17 | Fast Differentiable Sorting and Ranking
18 | Mathieu Blondel, Olivier Teboul, Quentin Berthet, Josip Djolonga
19 | https://arxiv.org/abs/2002.08871
20 | """
21 |
22 | from . import numpy_ops
23 | import jax
24 | import numpy as np
25 | import jax.numpy as jnp
26 | from jax import tree_util
27 |
28 |
def _wrap_numpy_op(cls, **kwargs):
  """Converts a NumPy operator class into a JAX-differentiable function.

  Args:
    cls: an operator class from `numpy_ops` exposing `compute()` and `vjp()`
      (e.g. `numpy_ops.SoftRank` or `numpy_ops.SoftSort`).
    **kwargs: keyword arguments forwarded to the `cls` constructor.

  Returns:
    A one-argument function with a custom JAX vjp rule attached.
  """

  def _func_fwd(values):
    """Forward pass: runs the NumPy op and saves its vjp as the residual."""
    dtype = values.dtype
    # Leave JAX land: the numpy_ops implementation operates on host arrays.
    values = np.array(values)
    obj = cls(values, **kwargs)
    result = obj.compute()
    # tree_util.Partial makes the bound method a valid JAX pytree, so it can
    # be carried as the residual to the backward pass.
    return jnp.array(result, dtype=dtype), tree_util.Partial(obj.vjp)

  def _func_bwd(vjp, g):
    """Backward pass: applies the saved NumPy vjp to the cotangent g."""
    g = np.array(g)
    result = jnp.array(vjp(g), dtype=g.dtype)
    # custom_vjp expects one cotangent per primal argument, as a tuple.
    return (result,)

  @jax.custom_vjp
  def _func(values):
    # Primal-only evaluation; the residual is discarded outside of grad.
    return _func_fwd(values)[0]

  _func.defvjp(_func_fwd, _func_bwd)

  return _func
52 |
53 |
def soft_rank(values, direction="ASCENDING", regularization_strength=1.0,
              regularization="l2"):
  r"""Soft rank the given values (array) along the second axis.

  The smaller the regularization strength, the closer the output is to the
  true ranks.

  Args:
    values: A 2d-array holding the numbers to be ranked.
    direction: Either 'ASCENDING' or 'DESCENDING'.
    regularization_strength: The regularization strength to be used. The smaller
      this number, the closer the values to the true ranks.
    regularization: Which regularization method to use. It
      must be set to one of ("l2", "kl", "log_kl").
  Returns:
    A 2d-array, soft-ranked along the second axis.
  """
  if len(values.shape) != 2:
    raise ValueError("'values' should be a 2d-array "
                     "but got %r." % values.shape)

  rank_fn = _wrap_numpy_op(numpy_ops.SoftRank,
                           regularization_strength=regularization_strength,
                           direction=direction,
                           regularization=regularization)

  # Rank each row independently, then stack the rows back together.
  ranked_rows = [rank_fn(row) for row in values]
  return jnp.vstack(ranked_rows)
81 |
82 |
def soft_sort(values, direction="ASCENDING",
              regularization_strength=1.0, regularization="l2"):
  r"""Soft sort the given values (array) along the second axis.

  The smaller the regularization strength, the closer the output is to the
  truly sorted values.

  Args:
    values: A 2d-array holding the numbers to be sorted.
    direction: Either 'ASCENDING' or 'DESCENDING'.
    regularization_strength: The regularization strength to be used. The smaller
      this number, the closer the values to the true sorted values.
    regularization: Which regularization method to use. It
      must be set to one of ("l2", "log_kl").
  Returns:
    A 2d-array, soft-sorted along the second axis.
  """
  if len(values.shape) != 2:
    raise ValueError("'values' should be a 2d-array "
                     "but got %s." % str(values.shape))

  sort_fn = _wrap_numpy_op(numpy_ops.SoftSort,
                           regularization_strength=regularization_strength,
                           direction=direction,
                           regularization=regularization)

  # Sort each row independently, then stack the rows back together.
  sorted_rows = [sort_fn(row) for row in values]
  return jnp.vstack(sorted_rows)
110 |
--------------------------------------------------------------------------------
/fast_soft_sort/numpy_ops.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Numpy operators for soft sorting and ranking.
16 |
17 | Fast Differentiable Sorting and Ranking
18 | Mathieu Blondel, Olivier Teboul, Quentin Berthet, Josip Djolonga
19 | https://arxiv.org/abs/2002.08871
20 |
21 | This implementation follows the notation of the paper whenever possible.
22 | """
23 |
24 | from .third_party import isotonic
25 | import numpy as np
26 | from scipy import special
27 |
28 |
def isotonic_l2(input_s, input_w=None):
  """Solves an isotonic regression problem using PAV.

  Formally, it solves argmin_{v_1 >= ... >= v_n} 0.5 ||v - (s-w)||^2.

  Args:
    input_s: input to isotonic regression, a 1d-array.
    input_w: input to isotonic regression, a 1d-array.
  Returns:
    solution to the optimization problem.
  """
  if input_w is None:
    # Default weights are n, n-1, ..., 1.
    input_w = np.arange(len(input_s), 0, -1)
  target = input_s - input_w.astype(input_s.dtype)
  solution = np.zeros_like(input_s)
  # The C-accelerated solver writes its result into `solution` in place.
  isotonic.isotonic_l2(target, solution)
  return solution
46 |
47 |
def isotonic_kl(input_s, input_w=None):
  """Solves isotonic optimization with KL divergence using PAV.

  Formally, it solves argmin_{v_1 >= ... >= v_n} <e^{s-v}, 1> + <e^w, v>.

  Args:
    input_s: input to isotonic optimization, a 1d-array.
    input_w: input to isotonic optimization, a 1d-array.
  Returns:
    solution to the optimization problem (same dtype as input_s).
  """
  if input_w is None:
    # Default weights are n, n-1, ..., 1.
    input_w = np.arange(len(input_s))[::-1] + 1
  input_w = input_w.astype(input_s.dtype)
  solution = np.zeros(len(input_s)).astype(input_s.dtype)
  # The solver writes its result into `solution` in place.
  isotonic.isotonic_kl(input_s, input_w.astype(input_s.dtype), solution)
  return solution
65 |
66 |
67 | def _partition(solution, eps=1e-9):
68 | """Returns partition corresponding to solution."""
69 | # pylint: disable=g-explicit-length-test
70 | if len(solution) == 0:
71 | return []
72 |
73 | sizes = [1]
74 |
75 | for i in range(1, len(solution)):
76 | if abs(solution[i] - solution[i - 1]) > eps:
77 | sizes.append(0)
78 | sizes[-1] += 1
79 |
80 | return sizes
81 |
82 |
83 | def _check_regularization(regularization):
84 | if regularization not in ("l2", "kl"):
85 | raise ValueError("'regularization' should be either 'l2' or 'kl' "
86 | "but got %s." % str(regularization))
87 |
88 |
89 | class _Differentiable(object):
90 | """Base class for differentiable operators."""
91 |
92 | def jacobian(self):
93 | """Computes Jacobian."""
94 | identity = np.eye(self.size)
95 | return np.array([self.jvp(identity[i]) for i in range(len(identity))]).T
96 |
97 | @property
98 | def size(self):
99 | raise NotImplementedError
100 |
101 | def compute(self):
102 | """Computes the desired quantity."""
103 | raise NotImplementedError
104 |
105 | def jvp(self, vector):
106 | """Computes Jacobian vector product."""
107 | raise NotImplementedError
108 |
109 | def vjp(self, vector):
110 | """Computes vector Jacobian product."""
111 | raise NotImplementedError
112 |
113 |
class Isotonic(_Differentiable):
  """Isotonic optimization.

  Solves an isotonic problem over `input_s` with weights `input_w`, with
  either L2 or KL regularization, and exposes derivatives of the cached
  solution with respect to `input_s`.
  """

  def __init__(self, input_s, input_w, regularization="l2"):
    self.input_s = input_s
    self.input_w = input_w
    _check_regularization(regularization)
    self.regularization = regularization
    self.solution_ = None  # Populated by compute().

  @property
  def size(self):
    return len(self.input_s)

  def compute(self):
    """Runs the PAV solver and caches the solution."""
    if self.regularization == "l2":
      self.solution_ = isotonic_l2(self.input_s, self.input_w)
    else:
      self.solution_ = isotonic_kl(self.input_s, self.input_w)
    return self.solution_

  def _check_computed(self):
    if self.solution_ is None:
      raise RuntimeError("Need to run compute() first.")

  def jvp(self, vector):
    """Computes the Jacobian-vector product w.r.t. `input_s`.

    The Jacobian is block-wise constant over the blocks of equal values in
    the solution.
    """
    self._check_computed()
    start = 0
    return_value = np.zeros_like(self.solution_)
    for size in _partition(self.solution_):
      end = start + size
      if self.regularization == "l2":
        # Within an L2 block, the solution is a (shifted) block mean.
        val = np.mean(vector[start:end])
      else:
        val = np.dot(special.softmax(self.input_s[start:end]),
                     vector[start:end])
      return_value[start:end] = val
      start = end
    return return_value

  def vjp(self, vector):
    """Computes the vector-Jacobian product w.r.t. `input_s`."""
    # Consistency fix: jvp() guarded against use before compute(), but vjp()
    # did not and crashed with an opaque TypeError inside _partition(None).
    self._check_computed()
    start = 0
    return_value = np.zeros_like(self.solution_)
    for size in _partition(self.solution_):
      end = start + size
      if self.regularization == "l2":
        val = 1. / size
      else:
        val = special.softmax(self.input_s[start:end])
      return_value[start:end] = val * np.sum(vector[start:end])
      start = end
    return return_value
167 |
168 |
169 | def _inv_permutation(permutation):
170 | """Returns inverse permutation of 'permutation'."""
171 | inv_permutation = np.zeros(len(permutation), dtype=int)
172 | inv_permutation[permutation] = np.arange(len(permutation))
173 | return inv_permutation
174 |
175 |
class Projection(_Differentiable):
  """Computes projection onto the permutahedron P(w)."""

  def __init__(self, input_theta, input_w=None, regularization="l2"):
    if input_w is None:
      # Default weights are n, n-1, ..., 1.
      input_w = np.arange(len(input_theta))[::-1] + 1
    self.input_theta = np.asarray(input_theta)
    self.input_w = np.asarray(input_w)
    _check_regularization(regularization)
    self.regularization = regularization
    # Bug fix: use the trailing-underscore name that _check_computed() reads.
    # The old `self.isotonic = None` left `isotonic_` undefined, so calling
    # jvp()/vjp() before compute() raised AttributeError instead of the
    # intended ValueError.
    self.isotonic_ = None

  def _check_computed(self):
    if self.isotonic_ is None:
      raise ValueError("Need to run compute() first.")

  @property
  def size(self):
    return len(self.input_theta)

  def compute(self):
    """Computes the projection, caching the isotonic subproblem."""
    # The isotonic solver assumes its input is sorted in decreasing order.
    self.permutation = np.argsort(self.input_theta)[::-1]
    input_s = self.input_theta[self.permutation]

    self.isotonic_ = Isotonic(input_s, self.input_w, self.regularization)
    dual_sol = self.isotonic_.compute()
    primal_sol = input_s - dual_sol

    # Undo the sort so the output is aligned with input_theta.
    self.inv_permutation = _inv_permutation(self.permutation)
    return primal_sol[self.inv_permutation]

  def jvp(self, vector):
    """Jacobian-vector product of the projection w.r.t. `input_theta`."""
    self._check_computed()
    ret = vector.copy()
    ret -= self.isotonic_.jvp(vector[self.permutation])[self.inv_permutation]
    return ret

  def vjp(self, vector):
    """Vector-Jacobian product of the projection w.r.t. `input_theta`."""
    self._check_computed()
    ret = vector.copy()
    ret -= self.isotonic_.vjp(vector[self.permutation])[self.inv_permutation]
    return ret
218 |
219 |
220 | def _check_direction(direction):
221 | if direction not in ("ASCENDING", "DESCENDING"):
222 | raise ValueError("direction should be either 'ASCENDING' or 'DESCENDING'")
223 |
224 |
class SoftRank(_Differentiable):
  """Soft ranking."""

  def __init__(self, values, direction="ASCENDING",
               regularization_strength=1.0, regularization="l2"):
    self.values = np.asarray(values)
    # Target permutahedron weights: n, n-1, ..., 1.
    self.input_w = np.arange(len(values))[::-1] + 1
    _check_direction(direction)
    direction_sign = 1 if direction == "ASCENDING" else -1
    self.scale = direction_sign / regularization_strength
    _check_regularization(regularization)
    self.regularization = regularization
    self.projection_ = None  # Populated by compute().

  @property
  def size(self):
    return len(self.values)

  def _check_computed(self):
    if self.projection_ is None:
      raise ValueError("Need to run compute() first.")

  def compute(self):
    """Computes the soft ranks via a projection onto the permutahedron."""
    theta = self.values * self.scale
    if self.regularization == "kl":
      # KL path works in log space: project against log-weights, then map
      # the result back with exp.
      self.projection_ = Projection(theta, np.log(self.input_w),
                                    regularization=self.regularization)
      self.factor = np.exp(self.projection_.compute())
      return self.factor
    self.projection_ = Projection(theta, self.input_w,
                                  regularization=self.regularization)
    self.factor = 1.0
    return self.projection_.compute()

  def jvp(self, vector):
    self._check_computed()
    return self.factor * self.projection_.jvp(vector) * self.scale

  def vjp(self, vector):
    self._check_computed()
    return self.projection_.vjp(self.factor * vector) * self.scale
269 |
270 |
class SoftSort(_Differentiable):
  """Soft sorting."""

  def __init__(self, values, direction="ASCENDING",
               regularization_strength=1.0, regularization="l2"):
    self.values = np.asarray(values)
    _check_direction(direction)
    self.sign = 1 if direction == "DESCENDING" else -1
    self.regularization_strength = regularization_strength
    _check_regularization(regularization)
    self.regularization = regularization
    self.isotonic_ = None  # Populated by compute().

  @property
  def size(self):
    return len(self.values)

  def _check_computed(self):
    if self.isotonic_ is None:
      raise ValueError("Need to run compute() first.")

  def compute(self):
    """Computes the soft-sorted values (in sorted order)."""
    size = len(self.values)
    input_w = np.arange(1, size + 1)[::-1] / self.regularization_strength
    values = self.sign * self.values
    self.permutation_ = np.argsort(values)[::-1]
    s = values[self.permutation_]

    self.isotonic_ = Isotonic(input_w, s, regularization=self.regularization)
    res = self.isotonic_.compute()

    # We overwrite the first argument with s, as we want the derivatives
    # w.r.t. s. Bug fix: the attribute is named `input_s` — the old
    # `self.isotonic_.s = s` was a no-op, so the KL-regularized jvp/vjp
    # computed softmax over the weights instead of the sorted values.
    self.isotonic_.input_s = s
    return self.sign * (input_w - res)

  def jvp(self, vector):
    self._check_computed()
    return self.isotonic_.jvp(vector[self.permutation_])

  def vjp(self, vector):
    self._check_computed()
    inv_permutation = _inv_permutation(self.permutation_)
    return self.isotonic_.vjp(vector)[inv_permutation]
314 |
315 |
316 | class Sort(_Differentiable):
317 | """Hard sorting."""
318 |
319 | def __init__(self, values, direction="ASCENDING"):
320 | _check_direction(direction)
321 | self.values = np.asarray(values)
322 | self.sign = 1 if direction == "DESCENDING" else -1
323 | self.permutation_ = None
324 |
325 | @property
326 | def size(self):
327 | return len(self.values)
328 |
329 | def _check_computed(self):
330 | if self.permutation_ is None:
331 | raise ValueError("Need to run compute() first.")
332 |
333 | def compute(self):
334 | self.permutation_ = np.argsort(self.sign * self.values)[::-1]
335 | return self.values[self.permutation_]
336 |
337 | def jvp(self, vector):
338 | self._check_computed()
339 | return vector[self.permutation_]
340 |
341 | def vjp(self, vector):
342 | self._check_computed()
343 | inv_permutation = _inv_permutation(self.permutation_)
344 | return vector[inv_permutation]
345 |
346 |
347 | # Small utility functions for the case when we just want the forward
348 | # computation.
349 |
350 |
def soft_rank(values, direction="ASCENDING", regularization_strength=1.0,
              regularization="l2"):
  r"""Soft rank the given values.

  The smaller the regularization strength, the closer the output is to the
  true ranks.

  Args:
    values: A 1d-array holding the numbers to be ranked.
    direction: Either 'ASCENDING' or 'DESCENDING'.
    regularization_strength: The regularization strength to be used. The smaller
      this number, the closer the values to the true ranks.
    regularization: Which regularization method to use. It
      must be set to one of ("l2", "kl", "log_kl").
  Returns:
    A 1d-array, soft-ranked.
  """
  op = SoftRank(values, regularization_strength=regularization_strength,
                direction=direction, regularization=regularization)
  return op.compute()
370 |
371 |
def soft_sort(values, direction="ASCENDING", regularization_strength=1.0,
              regularization="l2"):
  r"""Soft sort the given values.

  The smaller the regularization strength, the closer the returned values
  are to the true sorted values.

  Args:
    values: A 1d-array holding the numbers to be sorted.
    direction: Either 'ASCENDING' or 'DESCENDING'.
    regularization_strength: The regularization strength to be used. The smaller
      this number, the closer the values to the true sorted values.
    regularization: Which regularization method to use. It
      must be set to one of ("l2", "log_kl").

  Returns:
    A 1d-array, soft-sorted.
  """
  op = SoftSort(values, regularization_strength=regularization_strength,
                direction=direction, regularization=regularization)
  return op.compute()
388 |
389 |
def sort(values, direction="ASCENDING"):
  r"""Sort the given values (convenience wrapper around the Sort operator).

  Args:
    values: A 1d-array holding the numbers to be sorted.
    direction: Either 'ASCENDING' or 'DESCENDING'.

  Returns:
    A 1d-array, sorted.
  """
  op = Sort(values, direction=direction)
  return op.compute()
400 |
401 |
def rank(values, direction="ASCENDING"):
  r"""Rank the given values (hard, 1-based ranks).

  Args:
    values: A 1d-array holding the numbers to be ranked.
    direction: Either 'ASCENDING' or 'DESCENDING'.

  Returns:
    A 1d-array, ranked.
  """
  order = np.argsort(values)
  if direction == "DESCENDING":
    order = order[::-1]
  # The rank of values[i] is the 1-based position of i in the sorted order.
  return _inv_permutation(order) + 1
415 |
--------------------------------------------------------------------------------
/fast_soft_sort/pytorch_ops.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """PyTorch operators for soft sorting and ranking.
16 |
17 | Fast Differentiable Sorting and Ranking
18 | Mathieu Blondel, Olivier Teboul, Quentin Berthet, Josip Djolonga
19 | https://arxiv.org/abs/2002.08871
20 | """
21 |
22 | from . import numpy_ops
23 | import torch
24 |
25 |
def wrap_class(cls, **kwargs):
  """Builds a torch.autograd.Function around the NumPy operator `cls`.

  Args:
    cls: a NumPy operator class exposing compute() and vjp().
    **kwargs: constructor keyword arguments forwarded to `cls`.

  Returns:
    A torch.autograd.Function subclass whose forward/backward delegate
    to the NumPy operator.
  """

  class NumpyOpWrapper(torch.autograd.Function):
    """torch.autograd.Function delegating to a NumPy operator instance."""

    @staticmethod
    def forward(ctx, values):
      # Instantiate the NumPy op on the detached data and stash it on the
      # context so backward() can reuse its cached state.
      op = cls(values.detach().numpy(), **kwargs)
      ctx.numpy_obj = op
      return torch.from_numpy(op.compute())

    @staticmethod
    def backward(ctx, grad_output):
      grad = ctx.numpy_obj.vjp(grad_output.numpy())
      return torch.from_numpy(grad)

  return NumpyOpWrapper
43 |
44 |
def map_tensor(map_fn, tensor):
  """Applies `map_fn` to each row of `tensor` and restacks the results."""
  rows = torch.unbind(tensor)
  return torch.stack([map_fn(row) for row in rows])
47 |
48 |
def soft_rank(values, direction="ASCENDING", regularization_strength=1.0,
              regularization="l2"):
  r"""Soft rank the given values (tensor) along the second axis.

  The smaller the regularization strength, the closer the returned values
  are to the true ranks.

  Args:
    values: A 2d-tensor holding the numbers to be ranked.
    direction: Either 'ASCENDING' or 'DESCENDING'.
    regularization_strength: The regularization strength to be used. The smaller
      this number, the closer the values to the true ranks.
    regularization: Which regularization method to use. It
      must be set to one of ("l2", "kl", "log_kl").

  Returns:
    A 2d-tensor, soft-ranked along the second axis.
  """
  if values.ndim != 2:
    raise ValueError("'values' should be a 2d-tensor "
                     "but got %r." % values.shape)

  op = wrap_class(numpy_ops.SoftRank,
                  regularization_strength=regularization_strength,
                  direction=direction,
                  regularization=regularization)
  return map_tensor(op.apply, values)
75 |
76 |
def soft_sort(values, direction="ASCENDING",
              regularization_strength=1.0, regularization="l2"):
  r"""Soft sort the given values (tensor) along the second axis.

  The smaller the regularization strength, the closer the returned values
  are to the true sorted values.

  Args:
    values: A 2d-tensor holding the numbers to be sorted.
    direction: Either 'ASCENDING' or 'DESCENDING'.
    regularization_strength: The regularization strength to be used. The smaller
      this number, the closer the values to the true sorted values.
    regularization: Which regularization method to use. It
      must be set to one of ("l2", "log_kl").

  Returns:
    A 2d-tensor, soft-sorted along the second axis.
  """
  if values.ndim != 2:
    raise ValueError("'values' should be a 2d-tensor "
                     "but got %s." % str(values.shape))

  op = wrap_class(numpy_ops.SoftSort,
                  regularization_strength=regularization_strength,
                  direction=direction,
                  regularization=regularization)
  return map_tensor(op.apply, values)
104 |
--------------------------------------------------------------------------------
/fast_soft_sort/tf_ops.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tensorflow operators for soft sorting and ranking.
16 |
17 | Fast Differentiable Sorting and Ranking
18 | Mathieu Blondel, Olivier Teboul, Quentin Berthet, Josip Djolonga
19 | https://arxiv.org/abs/2002.08871
20 | """
21 |
22 | from . import numpy_ops
23 | import tensorflow.compat.v2 as tf
24 |
25 |
def _wrap_numpy_op(cls, regularization_strength, direction, regularization):
  """Wraps the NumPy operator `cls` as a TF function with a custom gradient."""

  @tf.custom_gradient
  def _func(values):
    """Runs the NumPy operator on `values` and defines its gradient."""
    dtype = values.dtype

    # In eager mode `values` is an EagerTensor; extract the ndarray.
    # Plain ndarrays pass through unchanged.
    try:
      values = values.numpy()
    except AttributeError:
      pass

    op = cls(values, regularization_strength=regularization_strength,
             direction=direction, regularization=regularization)
    outputs = op.compute()

    def grad(v):
      # Vector-Jacobian product computed by the cached NumPy operator.
      return tf.convert_to_tensor(op.vjp(v.numpy()), dtype=dtype)

    return tf.convert_to_tensor(outputs, dtype=dtype), grad

  return _func
50 |
51 |
def soft_rank(values, direction="ASCENDING", regularization_strength=1.0,
              regularization="l2"):
  r"""Soft rank the given values (tensor) along the second axis.

  The regularization strength determines how close are the returned values
  to the actual ranks.

  Args:
    values: A 2d-tensor holding the numbers to be ranked.
    direction: Either 'ASCENDING' or 'DESCENDING'.
    regularization_strength: The regularization strength to be used. The smaller
      this number, the closer the values to the true ranks.
    regularization: Which regularization method to use. It
      must be set to one of ("l2", "kl", "log_kl").

  Returns:
    A 2d-tensor, soft-ranked along the second axis.

  Raises:
    ValueError: If `values` is not a 2d-tensor.
    RuntimeError: If TensorFlow is not executing eagerly.
  """
  if len(values.shape) != 2:
    raise ValueError("'values' should be a 2d-tensor "
                     "but got %r." % values.shape)

  # The NumPy-backed forward/backward need concrete values (`.numpy()`),
  # which only exist in eager mode. A bare `assert` is stripped under
  # `python -O`, so raise explicitly instead.
  if not tf.executing_eagerly():
    raise RuntimeError("soft_rank requires TensorFlow eager execution.")

  func = _wrap_numpy_op(numpy_ops.SoftRank, regularization_strength, direction,
                        regularization)

  return tf.map_fn(func, values)
79 |
80 |
def soft_sort(values, direction="ASCENDING",
              regularization_strength=1.0, regularization="l2"):
  r"""Soft sort the given values (tensor) along the second axis.

  The regularization strength determines how close are the returned values
  to the actual sorted values.

  Args:
    values: A 2d-tensor holding the numbers to be sorted.
    direction: Either 'ASCENDING' or 'DESCENDING'.
    regularization_strength: The regularization strength to be used. The smaller
      this number, the closer the values to the true sorted values.
    regularization: Which regularization method to use. It
      must be set to one of ("l2", "log_kl").

  Returns:
    A 2d-tensor, soft-sorted along the second axis.

  Raises:
    ValueError: If `values` is not a 2d-tensor.
    RuntimeError: If TensorFlow is not executing eagerly.
  """
  if len(values.shape) != 2:
    raise ValueError("'values' should be a 2d-tensor "
                     "but got %s." % str(values.shape))

  # The NumPy-backed forward/backward need concrete values (`.numpy()`),
  # which only exist in eager mode. A bare `assert` is stripped under
  # `python -O`, so raise explicitly instead.
  if not tf.executing_eagerly():
    raise RuntimeError("soft_sort requires TensorFlow eager execution.")

  func = _wrap_numpy_op(numpy_ops.SoftSort, regularization_strength, direction,
                        regularization)

  return tf.map_fn(func, values)
108 |
--------------------------------------------------------------------------------
/fast_soft_sort/third_party/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2007–2020 The scikit-learn developers.
2 | Copyright (c) 2020 Google LLC.
3 | All rights reserved.
4 |
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | a. Redistributions of source code must retain the above copyright notice,
10 | this list of conditions and the following disclaimer.
11 | b. Redistributions in binary form must reproduce the above copyright
12 | notice, this list of conditions and the following disclaimer in the
13 | documentation and/or other materials provided with the distribution.
14 | c. Neither the name of the Scikit-learn Developers nor the names of
15 | its contributors may be used to endorse or promote products
16 | derived from this software without specific prior written
17 | permission.
18 |
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
23 | ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR
24 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
28 | LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
29 | OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
30 | DAMAGE.
--------------------------------------------------------------------------------
/fast_soft_sort/third_party/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/fast-soft-sort/6a52ce79869ab16e1e0f39149a84f50f8ad648c5/fast_soft_sort/third_party/__init__.py
--------------------------------------------------------------------------------
/fast_soft_sort/third_party/isotonic.py:
--------------------------------------------------------------------------------
1 | # Copyright 2007-2020 The scikit-learn developers.
2 | # Copyright 2020 Google LLC.
3 | # All rights reserved.
4 | #
5 | # Redistribution and use in source and binary forms, with or without
6 | # modification, are permitted provided that the following conditions are met:
7 | #
8 | # a. Redistributions of source code must retain the above copyright notice,
9 | # this list of conditions and the following disclaimer.
10 | # b. Redistributions in binary form must reproduce the above copyright
11 | # notice, this list of conditions and the following disclaimer in the
12 | # documentation and/or other materials provided with the distribution.
13 | # c. Neither the name of the Scikit-learn Developers nor the names of
14 | # its contributors may be used to endorse or promote products
15 | # derived from this software without specific prior written
16 | # permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
21 | # ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR
22 | # ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
26 | # LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
27 | # OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
28 | # DAMAGE.
29 |
"""Isotonic optimization routines in Numba.

The solvers below are decorated with `numba.njit` for speed. When Numba is
not installed, a warning is emitted and a no-op `njit` stub is used instead,
so the code still runs (in pure Python, much more slowly).
"""

import warnings
import numpy as np

# pylint: disable=g-import-not-at-top
try:
  from numba import njit
except ImportError:
  warnings.warn("Numba could not be imported. Code will run much more slowly."
                " To install, please run 'pip install numba'.")

  # If Numba is not available, we define a dummy 'njit' function.
  def njit(func):
    return func
45 |
46 |
47 | # Copied from scikit-learn with the following modifications:
48 | # - use decreasing constraints by default,
49 | # - do not return solution in place, rather save in array `sol`,
50 | # - avoid some needless multiplications.
51 |
52 |
@njit
def isotonic_l2(y, sol):
  """Solves an isotonic regression problem using PAV.

  Formally, it solves argmin_{v_1 >= ... >= v_n} 0.5 ||v - y||^2.

  Args:
    y: input to isotonic regression, a 1d-array.
    sol: where to write the solution, an array of the same size as y.
  """
  n = y.shape[0]
  target = np.arange(n)
  c = np.ones(n)      # c[i]: number of elements pooled into the block at i.
  sums = np.zeros(n)  # sums[i]: sum of y over the block starting at i.

  # target describes a list of blocks. At any time, if [i..j] (inclusive) is
  # an active block, then target[i] := j and target[j] := i.
  # The optimal value of a block is its average, sums[i] / c[i].

  for i in range(n):
    sol[i] = y[i]
    sums[i] = y[i]

  i = 0
  while i < n:
    k = target[i] + 1  # k: start index of the block following i's block.
    if k == n:
      break
    if sol[i] > sol[k]:
      # Decreasing constraint already satisfied here; advance to next block.
      i = k
      continue
    sum_y = sums[i]
    sum_c = c[i]
    while True:
      # We are within an increasing subsequence.
      prev_y = sol[k]
      sum_y += sums[k]
      sum_c += c[k]
      k = target[k] + 1
      if k == n or prev_y > sol[k]:
        # Non-singleton increasing subsequence is finished,
        # update first entry.
        sol[i] = sum_y / sum_c  # Pooled (averaged) value for the merged block.
        sums[i] = sum_y
        c[i] = sum_c
        target[i] = k - 1
        target[k - 1] = i
        if i > 0:
          # Backtrack if we can. This makes the algorithm
          # single-pass and ensures O(n) complexity.
          i = target[i - 1]
        # Otherwise, restart from the same point.
        break

  # Reconstruct the solution.
  i = 0
  while i < n:
    k = target[i] + 1
    sol[i + 1 : k] = sol[i]  # Broadcast each block's value to all its entries.
    i = k
112 |
113 |
@njit
def _log_add_exp(x, y):
  """Numerically stable log-add-exp."""
  # Factor out the larger argument so the exponent is never positive.
  if x >= y:
    return x + np.log1p(np.exp(y - x))
  return y + np.log1p(np.exp(x - y))
120 |
121 |
# Modified implementation for the KL geometry case.
@njit
def isotonic_kl(y, w, sol):
  """Solves isotonic optimization with KL divergence using PAV.

  Formally, it solves argmin_{v_1 >= ... >= v_n} <e^{y - v}, 1> + <e^w, v>.

  (This objective is consistent with the updates below: stationarity gives
  v_i = y_i - w_i for singleton blocks, and v = logsumexp(y) - logsumexp(w)
  within each pooled block.)

  Args:
    y: input to isotonic optimization, a 1d-array.
    w: input to isotonic optimization, a 1d-array.
    sol: where to write the solution, an array of the same size as y.
  """
  n = y.shape[0]
  target = np.arange(n)
  lse_y_ = np.zeros(n)  # lse_y_[i]: logsumexp of y over the block at i.
  lse_w_ = np.zeros(n)  # lse_w_[i]: logsumexp of w over the block at i.

  # target describes a list of blocks. At any time, if [i..j] (inclusive) is
  # an active block, then target[i] := j and target[j] := i.

  for i in range(n):
    sol[i] = y[i] - w[i]
    lse_y_[i] = y[i]
    lse_w_[i] = w[i]

  i = 0
  while i < n:
    k = target[i] + 1  # k: start index of the block following i's block.
    if k == n:
      break
    if sol[i] > sol[k]:
      # Decreasing constraint already satisfied here; advance to next block.
      i = k
      continue
    lse_y = lse_y_[i]
    lse_w = lse_w_[i]
    while True:
      # We are within an increasing subsequence.
      prev_y = sol[k]
      lse_y = _log_add_exp(lse_y, lse_y_[k])
      lse_w = _log_add_exp(lse_w, lse_w_[k])
      k = target[k] + 1
      if k == n or prev_y > sol[k]:
        # Non-singleton increasing subsequence is finished,
        # update first entry.
        sol[i] = lse_y - lse_w  # Pooled value for the merged block.
        lse_y_[i] = lse_y
        lse_w_[i] = lse_w
        target[i] = k - 1
        target[k - 1] = i
        if i > 0:
          # Backtrack if we can. This makes the algorithm
          # single-pass and ensures O(n) complexity.
          i = target[i - 1]
        # Otherwise, restart from the same point.
        break

  # Reconstruct the solution.
  i = 0
  while i < n:
    k = target[i] + 1
    sol[i + 1 : k] = sol[i]  # Broadcast each block's value to all its entries.
    i = k
184 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
"""Install fast_soft_sort."""

from setuptools import find_packages
from setuptools import setup

setup(
    name='fast_soft_sort',
    version='0.1',
    description=(
        'Differentiable sorting and ranking in O(n log n).'),
    author='Google LLC',
    author_email='no-reply@google.com',
    url='https://github.com/google-research/fast-soft-sort',
    license='BSD',
    packages=find_packages(),
    package_data={},
    # numba accelerates the isotonic solvers; without it the package falls
    # back to a much slower pure-Python path (see third_party/isotonic.py).
    install_requires=[
        'numba',
        'numpy',
        'scipy>=1.2.0',
    ],
    # TensorFlow is only needed for the tf_ops module.
    extras_require={
        'tf': ['tensorflow>=1.12'],
    },
    classifiers=[
        'Intended Audience :: Science/Research',
        'License :: OSI Approved :: BSD License',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
    ],
    keywords='machine learning sorting ranking',
)
46 |
--------------------------------------------------------------------------------
/tests/isotonic_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for isotonic.py."""
16 |
17 | import unittest
18 | from absl.testing import parameterized
19 | from fast_soft_sort.third_party import isotonic
20 | import numpy as np
21 | from sklearn.isotonic import isotonic_regression
22 |
23 |
class IsotonicTest(parameterized.TestCase):
  """Checks the Numba isotonic solver against scikit-learn's reference."""

  def test_l2_agrees_with_sklearn(self):
    rng = np.random.RandomState(0)
    values = rng.randn(10) * rng.randint(1, 5)
    ours = np.zeros_like(values)
    isotonic.isotonic_l2(values, ours)
    expected = isotonic_regression(values, increasing=False)
    np.testing.assert_array_almost_equal(ours, expected)
33 |
34 |
35 | if __name__ == "__main__":
36 | unittest.main()
37 |
--------------------------------------------------------------------------------
/tests/jax_ops_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for jax_ops.py."""
16 |
17 | import functools
18 | import itertools
19 | import unittest
20 |
21 | from absl.testing import absltest
22 | from absl.testing import parameterized
23 |
24 | import numpy as np
25 | import jax.numpy as jnp
26 | import jax
27 |
28 | from jax.config import config
29 | config.update("jax_enable_x64", True)
30 |
31 | from fast_soft_sort import jax_ops
32 |
33 | GAMMAS = (0.1, 1, 10.0)
34 | DIRECTIONS = ("ASCENDING", "DESCENDING")
35 | REGULARIZERS = ("l2", )
36 |
37 |
class JaxOpsTest(parameterized.TestCase):
  """Gradient checks for the jax soft operators."""

  def _test(self, func, regularization_strength, direction, regularization):
    """Compares jax autodiff with a central finite-difference estimate."""

    def loss_func(values):
      soft_values = func(values,
                         regularization_strength=regularization_strength,
                         direction=direction,
                         regularization=regularization)
      return jnp.sum(soft_values ** 2)

    rng = np.random.RandomState(0)
    values = jnp.array(rng.randn(5, 10))
    probe = jnp.array(rng.randn(5, 10))
    probe = probe / np.sqrt(np.vdot(probe, probe))  # Unit-norm direction.
    eps = 1e-5
    finite_diff = (loss_func(values + 0.5 * eps * probe) -
                   loss_func(values - 0.5 * eps * probe)) / eps
    autodiff = jnp.vdot(jax.grad(loss_func)(values), probe)
    np.testing.assert_almost_equal(finite_diff, autodiff)

  @parameterized.parameters(itertools.product(GAMMAS, DIRECTIONS, REGULARIZERS))
  def test_soft_rank(self, regularization_strength, direction, regularization):
    self._test(jax_ops.soft_rank,
               regularization_strength, direction, regularization)

  @parameterized.parameters(itertools.product(GAMMAS, DIRECTIONS, REGULARIZERS))
  def test_soft_sort(self, regularization_strength, direction, regularization):
    self._test(jax_ops.soft_sort,
               regularization_strength, direction, regularization)
69 |
70 |
71 | if __name__ == "__main__":
72 | absltest.main()
73 |
--------------------------------------------------------------------------------
/tests/numpy_ops_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for numpy_ops.py."""
16 |
17 | import itertools
18 | import unittest
19 | from absl.testing import parameterized
20 | from fast_soft_sort import numpy_ops
21 | import numpy as np
22 |
23 |
24 | def _num_jacobian(theta, f, eps=1e-9):
25 | n_classes = len(theta)
26 | ret = np.zeros((n_classes, n_classes))
27 |
28 | for i in range(n_classes):
29 | theta_ = theta.copy()
30 | theta_[i] += eps
31 | val = f(theta_)
32 | theta_[i] -= 2 * eps
33 | val2 = f(theta_)
34 | ret[i] = (val - val2) / (2 * eps)
35 |
36 | return ret.T
37 |
38 |
# Parameter grids shared by the parameterized test cases below.
GAMMAS = (0.1, 1.0, 10.0)
DIRECTIONS = ("ASCENDING", "DESCENDING")
REGULARIZERS = ("l2", "kl")
CLASSES = (numpy_ops.Isotonic, numpy_ops.Projection)
43 |
44 |
class IsotonicProjectionTest(parameterized.TestCase):
  """Verifies jvp/vjp of the low-level operators against finite differences."""

  @parameterized.parameters(itertools.product(CLASSES, REGULARIZERS))
  def test_jvp_and_vjp_against_numerical_jacobian(self, cls, regularization):
    rng = np.random.RandomState(0)
    theta = rng.randn(5)
    w = np.arange(5)[::-1]
    v = rng.randn(5)

    def forward(x):
      return cls(x, w, regularization=regularization).compute()

    jac = _num_jacobian(theta, forward)

    op = cls(theta, w, regularization=regularization)
    op.compute()

    np.testing.assert_array_almost_equal(jac.dot(v), op.jvp(v))
    np.testing.assert_array_almost_equal(v.dot(jac), op.vjp(v))
65 |
66 |
class SoftRankTest(parameterized.TestCase):
  """Tests for the SoftRank operator."""

  @parameterized.parameters(itertools.product(DIRECTIONS, REGULARIZERS))
  def test_soft_rank_converges_to_hard(self, direction, regularization):
    rng = np.random.RandomState(0)
    theta = rng.randn(5)
    soft_rank = numpy_ops.SoftRank(theta, regularization_strength=1e-3,
                                   direction=direction,
                                   regularization=regularization)
    out = numpy_ops.rank(theta, direction=direction)
    out2 = soft_rank.compute()
    np.testing.assert_array_almost_equal(out, out2)

  @parameterized.parameters(itertools.product(GAMMAS, DIRECTIONS, REGULARIZERS))
  def test_soft_rank_jvp_and_vjp_against_numerical_jacobian(self,
                                                            regularization_strength,
                                                            direction,
                                                            regularization):
    rng = np.random.RandomState(0)
    theta = rng.randn(5)
    v = rng.randn(5)

    f = lambda x: numpy_ops.SoftRank(
        x, regularization_strength=regularization_strength, direction=direction,
        regularization=regularization).compute()
    J = _num_jacobian(theta, f)

    soft_rank = numpy_ops.SoftRank(
        theta, regularization_strength=regularization_strength,
        direction=direction, regularization=regularization)
    soft_rank.compute()

    # NOTE: the third positional argument of assert_array_almost_equal is
    # `decimal` (an int), not an absolute tolerance. Passing 1e-6 made the
    # effective tolerance ~1.5, i.e. the checks were vacuous; use decimal=6.
    out = soft_rank.jvp(v)
    np.testing.assert_array_almost_equal(J.dot(v), out, decimal=6)

    out = soft_rank.vjp(v)
    np.testing.assert_array_almost_equal(v.dot(J), out, decimal=6)

    out = soft_rank.jacobian()
    np.testing.assert_array_almost_equal(J, out, decimal=6)

  @parameterized.parameters(itertools.product(GAMMAS, DIRECTIONS, REGULARIZERS))
  def test_soft_rank_works_with_lists(self, regularization_strength, direction,
                                      regularization):
    # The operator should accept plain Python lists as well as ndarrays.
    rng = np.random.RandomState(0)
    theta = rng.randn(5)
    ranks1 = numpy_ops.SoftRank(theta,
                                regularization_strength=regularization_strength,
                                direction=direction,
                                regularization=regularization).compute()
    ranks2 = numpy_ops.SoftRank(list(theta),
                                regularization_strength=regularization_strength,
                                direction=direction,
                                regularization=regularization).compute()
    np.testing.assert_array_almost_equal(ranks1, ranks2)
122 |
123 |
class SoftSortTest(parameterized.TestCase):
  """Tests for the SoftSort operator."""

  @parameterized.parameters(itertools.product(DIRECTIONS, REGULARIZERS))
  def test_soft_sort_converges_to_hard(self, direction, regularization):
    rng = np.random.RandomState(0)
    theta = rng.randn(5)
    soft_sort = numpy_ops.SoftSort(
        theta, regularization_strength=1e-3, direction=direction,
        regularization=regularization)
    sort = numpy_ops.Sort(theta, direction=direction)
    out = sort.compute()
    out2 = soft_sort.compute()
    np.testing.assert_array_almost_equal(out, out2)

  @parameterized.parameters(itertools.product(GAMMAS, DIRECTIONS, REGULARIZERS))
  def test_soft_sort_jvp(self, regularization_strength, direction,
                         regularization):
    rng = np.random.RandomState(0)
    theta = rng.randn(5)
    v = rng.randn(5)

    f = lambda x: numpy_ops.SoftSort(
        x, regularization_strength=regularization_strength,
        direction=direction, regularization=regularization).compute()
    J = _num_jacobian(theta, f)

    soft_sort = numpy_ops.SoftSort(
        theta, regularization_strength=regularization_strength,
        direction=direction, regularization=regularization)
    soft_sort.compute()

    # NOTE: the third positional argument of assert_array_almost_equal is
    # `decimal` (an int), not an absolute tolerance. Passing 1e-6 made the
    # effective tolerance ~1.5, i.e. the checks were vacuous; use decimal=6.
    out = soft_sort.jvp(v)
    np.testing.assert_array_almost_equal(J.dot(v), out, decimal=6)

    out = soft_sort.vjp(v)
    np.testing.assert_array_almost_equal(v.dot(J), out, decimal=6)

    out = soft_sort.jacobian()
    np.testing.assert_array_almost_equal(J, out, decimal=6)

  @parameterized.parameters(itertools.product(GAMMAS, DIRECTIONS, REGULARIZERS))
  def test_soft_sort_works_with_lists(self, regularization_strength, direction,
                                      regularization):
    # The operator should accept plain Python lists as well as ndarrays.
    rng = np.random.RandomState(0)
    theta = rng.randn(5)
    sort1 = numpy_ops.SoftSort(theta,
                               regularization_strength=regularization_strength,
                               direction=direction,
                               regularization=regularization).compute()
    sort2 = numpy_ops.SoftSort(list(theta),
                               regularization_strength=regularization_strength,
                               direction=direction,
                               regularization=regularization).compute()
    np.testing.assert_array_almost_equal(sort1, sort2)
178 |
179 |
class SortTest(parameterized.TestCase):
  """Tests for the hard Sort operator."""

  @parameterized.parameters(itertools.product(DIRECTIONS))
  def test_sort_jvp(self, direction):
    rng = np.random.RandomState(0)
    theta = rng.randn(5)
    v = rng.randn(5)

    def forward(x):
      return numpy_ops.Sort(x, direction=direction).compute()

    jac = _num_jacobian(theta, forward)

    op = numpy_ops.Sort(theta, direction=direction)
    op.compute()

    np.testing.assert_array_almost_equal(jac.dot(v), op.jvp(v))
    np.testing.assert_array_almost_equal(v.dot(jac), op.vjp(v))

  @parameterized.parameters(itertools.product(DIRECTIONS))
  def test_sort_works_with_lists(self, direction):
    rng = np.random.RandomState(0)
    theta = rng.randn(5)
    from_array = numpy_ops.Sort(theta, direction=direction).compute()
    from_list = numpy_ops.Sort(list(theta), direction=direction).compute()
    np.testing.assert_array_almost_equal(from_array, from_list)
207 |
208 |
# Run the test suite when this file is executed as a script.
if __name__ == "__main__":
  unittest.main()
211 |
--------------------------------------------------------------------------------
/tests/pytorch_ops_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for pytorch_ops.py."""
16 |
17 | import functools
18 | import itertools
19 | import unittest
20 |
21 | from absl.testing import parameterized
22 | from fast_soft_sort import pytorch_ops
23 | import torch
24 |
25 |
# Hyper-parameter grids swept by the parameterized tests below.
GAMMAS = (0.1, 1, 10.0)  # Regularization strengths.
DIRECTIONS = ("ASCENDING", "DESCENDING")
REGULARIZERS = ("l2",)  # The kl case is unstable for gradcheck.
DTYPES = (torch.float64,)  # Double precision only, as recommended for gradcheck.
30 |
31 |
class PyTorchOpsTest(parameterized.TestCase):
  """Gradient checks for the PyTorch soft rank/sort operators."""

  def _test(self, func, regularization_strength, direction, regularization,
            dtype, atol=1e-3, rtol=1e-3, eps=1e-5):
    """Runs torch.autograd.gradcheck on `func` and on a loss composed with it.

    Args:
      func: operator taking a values tensor plus `regularization_strength`,
        `direction` and `regularization` keyword arguments
        (e.g. `pytorch_ops.soft_rank`).
      regularization_strength: strength of the regularization.
      direction: "ASCENDING" or "DESCENDING".
      regularization: regularizer name (e.g. "l2").
      dtype: torch dtype of the random input tensor.
      atol: absolute tolerance passed to gradcheck.
      rtol: relative tolerance passed to gradcheck.
      eps: finite-difference perturbation passed to gradcheck.
    """
    x = torch.randn(5, 10, dtype=dtype, requires_grad=True)

    # Bind the hyper-parameters once; both gradient checks below then only
    # vary the differentiable input.
    func = functools.partial(
        func, regularization_strength=regularization_strength,
        direction=direction, regularization=regularization)

    # Gradient of the operator output itself.
    torch.autograd.gradcheck(func, [x], eps=eps, atol=atol, rtol=rtol)

    def _compute_loss(x):
      # The hyper-parameters are already bound by functools.partial above;
      # re-passing them here (as the original code did) was redundant.
      return torch.sum(func(x)**2)

    # Gradient through a scalar loss composed with the operator.
    torch.autograd.gradcheck(_compute_loss, x, eps=eps, atol=atol, rtol=rtol)

  @parameterized.parameters(itertools.product(GAMMAS, DIRECTIONS, REGULARIZERS,
                                              DTYPES))
  def test_rank_gradient(self, regularization_strength, direction,
                         regularization, dtype):
    self._test(pytorch_ops.soft_rank, regularization_strength, direction,
               regularization, dtype)

  @parameterized.parameters(itertools.product(GAMMAS, DIRECTIONS, REGULARIZERS,
                                              DTYPES))
  def test_sort_gradient(self, regularization_strength, direction,
                         regularization, dtype):
    self._test(pytorch_ops.soft_sort, regularization_strength, direction,
               regularization, dtype)
64 |
65 |
# Run the test suite when this file is executed as a script.
if __name__ == "__main__":
  unittest.main()
68 |
--------------------------------------------------------------------------------
/tests/tf_ops_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for tf_ops.py."""
16 |
17 | import functools
18 | import itertools
19 | from absl.testing import parameterized
20 | from fast_soft_sort import tf_ops
21 | import tensorflow.compat.v2 as tf
22 |
23 |
# Hyper-parameter grids swept by the parameterized tests below.
GAMMAS = (0.1, 1.0, 10.0)  # Regularization strengths.
DIRECTIONS = ("ASCENDING", "DESCENDING")
REGULARIZERS = ("l2", "kl")
DTYPES = (tf.float64,)  # Double precision only, for finite-difference accuracy.
28 |
29 |
class TfOpsTest(parameterized.TestCase, tf.test.TestCase):
  """Gradient checks for the TensorFlow soft rank/sort operators."""

  def _test(self, func, regularization_strength, direction, regularization,
            dtype):
    """Compares theoretical and numerical gradients of `func`.

    Args:
      func: operator taking a values tensor plus `regularization_strength`,
        `direction` and `regularization` keyword arguments
        (e.g. `tf_ops.soft_rank`).
      regularization_strength: strength of the regularization.
      direction: "ASCENDING" or "DESCENDING".
      regularization: regularizer name ("l2" or "kl").
      dtype: tf dtype of the random input tensor.
    """
    precision = 1e-6  # Tolerance for theoretical vs. numerical gradients.
    delta = 1e-4      # Finite-difference step for compute_gradient.

    x = tf.random.normal((5, 10), dtype=dtype)

    # Bind the hyper-parameters once; both gradient checks below then only
    # vary the input.
    func = functools.partial(
        func, regularization_strength=regularization_strength,
        direction=direction, regularization=regularization)

    # Gradient of the operator output itself.
    grad_theoretical, grad_numerical = tf.test.compute_gradient(
        func, [x], delta=delta)
    self.assertAllClose(grad_theoretical[0], grad_numerical[0], precision)

    def _compute_loss(x):
      # The hyper-parameters are already bound by functools.partial above;
      # re-passing them here (as the original code did) was redundant.
      return tf.reduce_mean(func(x)**2)

    # Gradient through a scalar loss composed with the operator.
    grad_theoretical, grad_numerical = tf.test.compute_gradient(
        _compute_loss, [x], delta=delta)
    self.assertAllClose(grad_theoretical[0], grad_numerical[0], precision)

  @parameterized.parameters(itertools.product(GAMMAS, DIRECTIONS, REGULARIZERS,
                                              DTYPES))
  def test_rank_gradient(self, regularization_strength, direction,
                         regularization, dtype):
    self._test(tf_ops.soft_rank, regularization_strength, direction,
               regularization, dtype)

  @parameterized.parameters(itertools.product(GAMMAS, DIRECTIONS, REGULARIZERS,
                                              DTYPES))
  def test_sort_gradient(self, regularization_strength, direction,
                         regularization, dtype):
    if regularization == "l2" or regularization_strength < 10:
      # We skip regularization_strength >= 10 when regularization = "kl",
      # due to numerical instability.
      self._test(tf_ops.soft_sort, regularization_strength, direction,
                 regularization, dtype)
75 |
76 |
# Run the test suite when this file is executed as a script.
if __name__ == "__main__":
  tf.enable_v2_behavior()  # Tests assume TF2 eager semantics.
  tf.test.main()
80 |
--------------------------------------------------------------------------------