├── .github └── workflows │ └── python-package.yml ├── .gitignore ├── LICENSE.txt ├── README.md ├── c ├── .gitignore ├── Makefile ├── d.enc ├── flip.c ├── flip.h ├── macros.c ├── main.c ├── readio.c ├── readio.h ├── sample.c ├── sample.h └── sstructs.h ├── check.sh ├── clean.sh ├── examples ├── __init__.py ├── maxerror.py └── sampling.py ├── pythenv.sh ├── setup.py ├── src ├── .gitignore ├── __init__.py ├── construct.py ├── divergences.py ├── flip.py ├── matrix.py ├── opt.py ├── orderm2 ├── packing.py ├── phi ├── sample.py ├── tree.py ├── utils.py └── writeio.py └── tests ├── __init__.py ├── test_divergences.py ├── test_matrix.py ├── test_opt.py ├── test_packing.py ├── test_sample.py ├── test_tree.py ├── test_utils.py └── utils.py /.github/workflows/python-package.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Python package 5 | 6 | on: 7 | push: 8 | branches: [ master ] 9 | pull_request: 10 | branches: [ master ] 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | strategy: 17 | matrix: 18 | python-version: [3.6] 19 | 20 | steps: 21 | - uses: actions/checkout@v2 22 | - name: Set up Python ${{ matrix.python-version }} 23 | uses: actions/setup-python@v2 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | - name: Test src 27 | run: | 28 | python -m pip install --upgrade pip 29 | pip install '.[tests]' 30 | ./check.sh 31 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /*.egg-info 2 | /src/version.py 3 | /VERSION 4 | 5 | __pycache__/ 6 | 7 | build/ 8 | dist/ 9 | 
-------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 
40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Optimal Approximate Sampling From Discrete Probability Distributions 2 | 3 | This repository contains a prototype implementation of the optimal 4 | sampling algorithms from: 5 | 6 | > Feras A. Saad, Cameron E. Freer, Martin C. Rinard, and Vikash K. Mansinghka. 7 | [Optimal Approximate Sampling From Discrete Probability 8 | Distributions](https://doi.org/10.1145/3371104). 9 | _Proc. ACM Program. Lang._ 4, POPL, Article 36 (January 2020), 33 pages. 10 | 11 | ## Installing 12 | 13 | The Python 3 library can be installed via pip: 14 | 15 | pip install optas 16 | 17 | The C code for the main sampler is in the `c/` directory and the 18 | Python 3 libraries are in the `src/` directory. 19 | 20 | Only Python 3 is required to build and use the software from source. 
21 | 22 | $ git clone https://github.com/probcomp/optimal-approximate-sampling 23 | $ cd optimal-approximate-sampling 24 | $ python setup.py install 25 | 26 | To build the C sampler 27 | 28 | $ cd c && make all 29 | 30 | ## Usage 31 | 32 | Please refer to the examples in the [examples](./examples) directory. 33 | Given a fixed target distribution and error measure: 34 | 35 | 1. [./examples/sampling.py](./examples/sampling.py) shows an example of how 36 | to find an optimal distribution and sample from it, given a 37 | user-specified precision. 38 | 39 | 2. [./examples/maxerror.py](./examples/maxerror.py) shows an example of how 40 | to find an optimal distribution that uses the least possible precision 41 | and obtains an error that is less than a user-specified maximum 42 | allowable error. 43 | 44 | These examples can be run directly as follows: 45 | 46 | $ ./pythenv.sh python examples/sampling.py 47 | $ ./pythenv.sh python examples/maxerror.py 48 | 49 | ## Tests 50 | 51 | To test the Python library and run a crash test in C (requires 52 | [pytest](https://docs.pytest.org/en/latest/) and 53 | [scipy](https://scipy.org/)): 54 | 55 | $ ./check.sh 56 | 57 | ## Experiments 58 | 59 | The code for experiments in the POPL publication is available in a tarball 60 | on the ACM Digital Library. Please refer to the online supplementary 61 | material at https://doi.org/10.1145/3371104. 62 | 63 | ## Citing 64 | 65 | Please use the following BibTeX to cite this work. 66 | 67 | @article{saad2020sampling, 68 | title = {Optimal approximate sampling from discrete probability distributions}, 69 | author = {Saad, Feras A. and Freer, Cameron E. and Rinard, Martin C. and Mansinghka, Vikash K.}, 70 | journal = {Proc. ACM Program. 
Lang.}, 71 | volume = 4, 72 | number = {POPL}, 73 | month = jan, 74 | year = 2020, 75 | pages = {36:1--36:31}, 76 | numpages = 31, 77 | publisher = {ACM}, 78 | doi = {10.1145/3371104}, 79 | abstract = {This paper addresses a fundamental problem in random variate generation: given access to a random source that emits a stream of independent fair bits, what is the most accurate and entropy-efficient algorithm for sampling from a discrete probability distribution $(p_1, \dots, p_n)$, where the output distribution $(\hat{p}_1, \dots, \hat{p}_n)$ of the sampling algorithm can be specified with a given level of bit precision? We present a theoretical framework for formulating this problem and provide new techniques for finding sampling algorithms that are optimal both statistically (in the sense of sampling accuracy) and information-theoretically (in the sense of entropy consumption). We leverage these results to build a system that, for a broad family of measures of statistical accuracy, delivers a sampling algorithm whose expected entropy usage is minimal among those that induce the same distribution (i.e., is ``entropy-optimal'') and whose output distribution $(\hat{p}_1, \dots, \hat{p}_n)$ is a closest approximation to the target distribution $(p_1, \dots, p_n)$ among all entropy-optimal sampling algorithms that operate within the specified precision budget. This optimal approximate sampler is also a closer approximation than any (possibly entropy-suboptimal) sampler that consumes a bounded amount of entropy with the specified precision, a class which includes floating-point implementations of inversion sampling and related methods found in many standard software libraries. 
We evaluate the accuracy, entropy consumption, precision requirements, and wall-clock runtime of our optimal approximate sampling algorithms on a broad set of probability distributions, demonstrating the ways that they are superior to existing approximate samplers and establishing that they often consume significantly fewer resources than are needed by exact samplers.}, 80 | } 81 | 82 | ## Related Repositories 83 | 84 | For a near-optimal exact dice rolling algorithm see 85 | https://github.com/probcomp/fast-loaded-dice-roller. 86 | -------------------------------------------------------------------------------- /c/.gitignore: -------------------------------------------------------------------------------- 1 | *.gch 2 | *.gcda 3 | *.gcna 4 | *.gcno 5 | *.o 6 | *.out 7 | mainc 8 | mainc.opt 9 | maincpp 10 | maincpp.opt 11 | preprocessc 12 | preprocessc.opt 13 | preprocesscpp 14 | preprocesscpp.opt 15 | -------------------------------------------------------------------------------- /c/Makefile: -------------------------------------------------------------------------------- 1 | all: mainc mainc.opt test 2 | 3 | SRC_C = flip.c main.c macros.c readio.c sample.c 4 | 5 | mainc: $(SRC_C) 6 | gcc -pg -o mainc $^ 7 | 8 | mainc.opt: $(SRC_C) 9 | gcc -O3 -Wno-unused-result -o mainc.opt $^ 10 | 11 | .PHONY: clean 12 | clean: 13 | rm -rf *.out *.gcno *.gcnh *.gcda *.gch mainc main.opt 14 | 15 | # Test 16 | .PHONY: test 17 | test: mainc.opt 18 | ./mainc.opt 1 1000000 ky.enc ./d.enc 19 | -------------------------------------------------------------------------------- /c/d.enc: -------------------------------------------------------------------------------- 1 | 4 16 2 | 91 2 87 4 83 6 79 8 75 10 71 12 67 14 63 16 59 18 55 20 51 22 47 24 43 26 39 28 35 30 34 32 33 -4 -2 -4 37 38 -3 -1 41 42 -3 -2 45 46 -2 -1 49 50 -4 -1 53 54 -4 -3 57 58 -3 -2 61 62 -2 -1 65 66 -4 -1 69 70 -4 -3 73 74 -3 -2 77 78 -2 -1 81 82 -4 -1 85 86 -4 -3 89 90 -3 -2 3 | 
-------------------------------------------------------------------------------- /c/flip.c: -------------------------------------------------------------------------------- 1 | // Flipping a coin. 2 | // ** @author: fsaad@mit.edu 3 | 4 | #include 5 | 6 | #include "flip.h" 7 | 8 | unsigned long NUM_RNG_CALLS = 0; 9 | 10 | // RAND_MAX is 2**31-1, so bits are 0,...30 11 | static int k = 31; 12 | static int flip_word = 0; 13 | static int flip_pos = 0; 14 | 15 | int flip(void){ 16 | if (flip_pos == 0) { 17 | NUM_RNG_CALLS++; 18 | flip_word = rand(); 19 | flip_pos = k; 20 | } 21 | --flip_pos; 22 | return (flip_word >> flip_pos) & 1; 23 | } 24 | 25 | int randint(int k) { 26 | int n = 0; 27 | 28 | for (int i = 0; i < k; i++) { 29 | int b = flip(); 30 | n <<= 1; 31 | n += b; 32 | } 33 | 34 | return n; 35 | } 36 | -------------------------------------------------------------------------------- /c/flip.h: -------------------------------------------------------------------------------- 1 | // Flipping a coin. 2 | // ** @author: fsaad@mit.edu 3 | 4 | #ifndef FLIP_H 5 | #define FLIP_H 6 | 7 | extern unsigned long NUM_RNG_CALLS; 8 | 9 | int flip(void); 10 | int randint(int k); 11 | 12 | #endif 13 | -------------------------------------------------------------------------------- /c/macros.c: -------------------------------------------------------------------------------- 1 | // Macro for reading, sampling, and timing. 
2 | // ** @author: fsaad@mit.edu 3 | 4 | #define READ_SAMPLE_TIME(key, \ 5 | var_sampler, \ 6 | struct_name, \ 7 | func_read, \ 8 | func_sample, \ 9 | func_free, \ 10 | var_path, \ 11 | var_steps, \ 12 | var_t, \ 13 | var_x) \ 14 | if(strcmp(var_sampler, key) == 0) { \ 15 | struct struct_name s = func_read(var_path); \ 16 | var_t = clock(); \ 17 | for (int i = 0; i < var_steps; i++) { \ 18 | var_x += func_sample(&s); \ 19 | } \ 20 | var_t = clock() - var_t; \ 21 | func_free(s); \ 22 | } 23 | -------------------------------------------------------------------------------- /c/main.c: -------------------------------------------------------------------------------- 1 | // Main harness for C samplers. 2 | // ** @author: fsaad@mit.edu 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | #include "flip.h" 12 | #include "readio.h" 13 | #include "sample.h" 14 | #include "sstructs.h" 15 | 16 | #include "macros.c" 17 | 18 | int main(int argc, char **argv) { 19 | // Read command line arguments. 
20 | if (argc != 5) { 21 | printf("usage: ./mainc seed steps sampler path\n"); 22 | exit(0); 23 | } 24 | int seed = atoi(argv[1]); 25 | int steps = atoi(argv[2]); 26 | char *sampler = argv[3]; 27 | char *path = argv[4]; 28 | 29 | printf("%d %d %s %s\n", seed, steps, sampler, path); 30 | srand(seed); 31 | 32 | int x = 0; 33 | clock_t t; 34 | READ_SAMPLE_TIME("ky.enc", 35 | sampler, 36 | sample_ky_encoding_s, 37 | read_sample_ky_encoding, 38 | sample_ky_encoding, 39 | free_sample_ky_encoding_s, 40 | path, steps, t, x) 41 | else READ_SAMPLE_TIME("ky.mat", 42 | sampler, 43 | sample_ky_matrix_s, 44 | read_sample_ky_matrix, 45 | sample_ky_matrix, 46 | free_sample_ky_matrix_s, 47 | path, steps, t, x) 48 | else READ_SAMPLE_TIME("ky.matc", 49 | sampler, 50 | sample_ky_matrix_cached_s, 51 | read_sample_ky_matrix_cached, 52 | sample_ky_matrix_cached, 53 | free_sample_ky_matrix_cached_s, 54 | path, steps, t, x) 55 | else { 56 | printf("Unknown sampler: %s\n", sampler); 57 | exit(1); 58 | } 59 | 60 | double e = ((double)t) / CLOCKS_PER_SEC; 61 | printf("%s %1.5f %ld\n", sampler, e, NUM_RNG_CALLS); 62 | 63 | return 0; 64 | } 65 | -------------------------------------------------------------------------------- /c/readio.c: -------------------------------------------------------------------------------- 1 | // Loading sampling data structures from disk. 2 | // ** @author: fsaad@mit.edu 3 | 4 | #include 5 | 6 | #include "readio.h" 7 | #include "sstructs.h" 8 | 9 | // Load matrix from file. 
10 | struct matrix_s load_matrix(FILE *fp) { 11 | 12 | struct matrix_s mat; 13 | fscanf(fp, "%d %d", &(mat.nrows), &(mat.ncols)); 14 | 15 | mat.P = (int **) calloc(mat.nrows, sizeof(int **)); 16 | for(int r = 0; r < mat.nrows; ++r) { 17 | mat.P[r] = (int *) calloc(mat.ncols, sizeof(int)); 18 | for (int c = 0; c < mat.ncols; ++c){ 19 | fscanf(fp, "%d", &(mat.P[r][c])); 20 | } 21 | } 22 | 23 | return mat; 24 | } 25 | 26 | void free_matrix_s (struct matrix_s x) { 27 | for (int i = 0; i < x.nrows; i++) { 28 | free(x.P[i]); 29 | } 30 | free(x.P); 31 | } 32 | 33 | // Load matrix from file. 34 | struct array_s load_array(FILE *fp) { 35 | 36 | struct array_s arr; 37 | fscanf(fp, "%d", &(arr.length)); 38 | 39 | arr.a = (int *) calloc(arr.length, sizeof(int)); 40 | for (int i = 0; i < arr.length; i++) { 41 | fscanf(fp, "%d", &arr.a[i]); 42 | } 43 | 44 | return arr; 45 | } 46 | 47 | void free_array_s (struct array_s x) { 48 | free(x.a); 49 | } 50 | 51 | // Load sample_ky_encoding data structure from file path. 52 | struct sample_ky_encoding_s read_sample_ky_encoding(char *fname) { 53 | FILE *fp = fopen(fname, "r"); 54 | 55 | struct sample_ky_encoding_s x; 56 | fscanf(fp, "%d %d", &(x.n), &(x.k)); 57 | x.encoding = load_array(fp); 58 | 59 | fclose(fp); 60 | return x; 61 | } 62 | 63 | void free_sample_ky_encoding_s (struct sample_ky_encoding_s x) { 64 | free_array_s(x.encoding); 65 | } 66 | 67 | // Load sample_ky_matrix data structure from file path. 68 | struct sample_ky_matrix_s read_sample_ky_matrix(char *fname) { 69 | FILE *fp = fopen(fname, "r"); 70 | 71 | struct sample_ky_matrix_s x; 72 | fscanf(fp, "%d %d", &(x.k), &(x.l)); 73 | x.P = load_matrix(fp); 74 | 75 | fclose(fp); 76 | return x; 77 | } 78 | 79 | void free_sample_ky_matrix_s (struct sample_ky_matrix_s x) { 80 | free_matrix_s(x.P); 81 | } 82 | 83 | // Load sample_ky_matrix_cached data structure from file path. 
84 | struct sample_ky_matrix_cached_s read_sample_ky_matrix_cached(char *fname) { 85 | FILE *fp = fopen(fname, "r"); 86 | 87 | struct sample_ky_matrix_cached_s x; 88 | fscanf(fp, "%d %d", &(x.k), &(x.l)); 89 | x.h = load_array(fp); 90 | x.T = load_matrix(fp); 91 | 92 | fclose(fp); 93 | return x; 94 | } 95 | 96 | void free_sample_ky_matrix_cached_s (struct sample_ky_matrix_cached_s x) { 97 | free_array_s(x.h); 98 | free_matrix_s(x.T); 99 | } 100 | -------------------------------------------------------------------------------- /c/readio.h: -------------------------------------------------------------------------------- 1 | // Loading sampling data structures from disk. 2 | // ** @author: fsaad@mit.edu 3 | 4 | #ifndef READIO_H 5 | #define READIO_H 6 | 7 | #include 8 | #include "sstructs.h" 9 | 10 | struct matrix_s load_matrix(FILE *fp); 11 | struct array_s load_array(FILE *fp); 12 | struct sample_ky_encoding_s read_sample_ky_encoding(char *fname); 13 | struct sample_ky_matrix_s read_sample_ky_matrix(char *fname); 14 | struct sample_ky_matrix_cached_s read_sample_ky_matrix_cached(char *fname); 15 | 16 | void free_matrix_s(struct matrix_s x); 17 | void free_array_s(struct array_s x); 18 | void free_sample_ky_encoding_s(struct sample_ky_encoding_s x); 19 | void free_sample_ky_matrix_s(struct sample_ky_matrix_s x); 20 | void free_sample_ky_matrix_cached_s(struct sample_ky_matrix_cached_s x); 21 | 22 | #endif 23 | -------------------------------------------------------------------------------- /c/sample.c: -------------------------------------------------------------------------------- 1 | // Sample from distribution data structures. 
2 | // ** @author: fsaad@mit.edu 3 | 4 | #include 5 | #include 6 | 7 | #include "flip.h" 8 | #include "sample.h" 9 | #include "sstructs.h" 10 | 11 | int sample_ky_encoding(struct sample_ky_encoding_s *x) { 12 | 13 | if (x->encoding.length == 1) { 14 | return 1; 15 | } 16 | 17 | int *enc = x->encoding.a; 18 | int c = 0; 19 | while (true) { 20 | int b = flip(); 21 | c = enc[c+b]; 22 | if (enc[c] < 0) { 23 | return -enc[c]; 24 | } 25 | } 26 | } 27 | 28 | int sample_ky_matrix(struct sample_ky_matrix_s *x) { 29 | if (x->P.nrows == 1) { 30 | return 1; 31 | } 32 | 33 | int **P = x->P.P; 34 | int c = 0; 35 | int d = 0; 36 | 37 | while (true) { 38 | int b = flip(); 39 | d = 2 * d + (1-b); 40 | for (int r = 0; r < x->P.nrows; r++) { 41 | d = d - P[r][c]; 42 | if (d == - 1) { 43 | return r + 1; 44 | } 45 | } 46 | if (c == x->k - 1) { 47 | c = x->l; 48 | } else { 49 | c = c + 1; 50 | } 51 | } 52 | } 53 | 54 | int sample_ky_matrix_cached(struct sample_ky_matrix_cached_s *x) { 55 | if (x->T.nrows == 1) { 56 | return 1; 57 | } 58 | 59 | int **T = x->T.P; 60 | int *h = x->h.a; 61 | 62 | int c = 0; 63 | int d = 0; 64 | 65 | while (true) { 66 | int b = flip(); 67 | d = 2 * d + (1-b); 68 | if (d < h[c]) { 69 | return T[d][c] + 1; 70 | } 71 | d = d - h[c]; 72 | if (c == x->k - 1) { 73 | c = x->l; 74 | } else { 75 | c = c + 1; 76 | } 77 | } 78 | } 79 | -------------------------------------------------------------------------------- /c/sample.h: -------------------------------------------------------------------------------- 1 | // Loading sampling data structures from disk. 
2 | // ** @author: fsaad@mit.edu 3 | 4 | #ifndef SAMPLE_H 5 | #define SAMPLE_H 6 | 7 | #include "sstructs.h" 8 | 9 | int sample_ky_encoding(struct sample_ky_encoding_s *x); 10 | int sample_ky_matrix(struct sample_ky_matrix_s *x); 11 | int sample_ky_matrix_cached(struct sample_ky_matrix_cached_s *x); 12 | #endif 13 | -------------------------------------------------------------------------------- /c/sstructs.h: -------------------------------------------------------------------------------- 1 | // Defining structs for samplers. 2 | // ** @author: fsaad@mit.edu 3 | 4 | #ifndef SSTRUCTS_H 5 | #define SSTRUCTS_H 6 | 7 | #include 8 | #include 9 | 10 | // matrix 11 | struct matrix_s { 12 | int nrows; 13 | int ncols; 14 | int **P; 15 | }; 16 | 17 | // array 18 | struct array_s { 19 | int length; 20 | int *a; 21 | }; 22 | 23 | 24 | // sample_ky_encoding 25 | struct sample_ky_encoding_s { 26 | int n; 27 | int k; 28 | struct array_s encoding; 29 | }; 30 | 31 | // sample_ky_matrix 32 | struct sample_ky_matrix_s { 33 | int k; 34 | int l; 35 | struct matrix_s P; 36 | }; 37 | 38 | // sample_ky_matrix_cached 39 | struct sample_ky_matrix_cached_s { 40 | int k; 41 | int l; 42 | struct array_s h; 43 | struct matrix_s T; 44 | }; 45 | 46 | #endif 47 | -------------------------------------------------------------------------------- /check.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | set -Ceu 4 | 5 | : ${PYTHON:=python} 6 | 7 | root=`cd -- "$(dirname -- "$0")" && pwd` 8 | 9 | ( 10 | set -Ceu 11 | cd -- "${root}" 12 | rm -rf build 13 | 14 | # Run Python tests. 15 | "$PYTHON" setup.py build 16 | if [ $# -eq 0 ]; then 17 | # By default run all tests. 18 | # Any test which uses this flag should end with __ci_() which 19 | # activates integration testing code path. If --integration is 20 | # not specified then a __ci_() test will either run as a crash test 21 | # or not run at all. (Use git grep '__ci_' to find these tests.) 
22 | ./pythenv.sh "$PYTHON" -m pytest --pyargs optas 23 | # Run C crash test. 24 | cd c/ && make test 25 | elif [ ${1} = 'crash' ]; then 26 | ./pythenv.sh "$PYTHON" -m pytest -k 'not __ci_' --pyargs optas 27 | elif [ ${1} = 'release' ]; then 28 | # Make a release 29 | rm -rf dist 30 | "$PYTHON" setup.py sdist bdist_wheel 31 | twine upload --repository pypi dist/* 32 | else 33 | # If args are specified delegate control to user. 34 | ./pythenv.sh "$PYTHON" -m pytest "$@" 35 | fi 36 | ) 37 | -------------------------------------------------------------------------------- /clean.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | rm -rf \ 4 | *.egg-info \ 5 | *.so \ 6 | .eggs/ \ 7 | build/ \ 8 | dist/ \ 9 | tests/__pycache__ \ 10 | VERSION 11 | -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 MIT Probabilistic Computing Project. 
2 | # Released under Apache 2.0; refer to LICENSE.txt 3 | -------------------------------------------------------------------------------- /examples/maxerror.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Example of finding optimal distribution given a maximum allowed error.""" 4 | 5 | from math import ceil 6 | from math import log2 7 | 8 | from optas.divergences import KERNELS 9 | from optas.divergences import compute_divergence_kernel 10 | from optas.opt import get_optimal_probabilities 11 | from optas.tests.utils import allclose 12 | from optas.utils import argmin 13 | from optas.utils import get_Zkl 14 | 15 | 16 | def find_optimal_approximation(p_target, kernel, maxerror, dyadic): 17 | """Return optimal approximation 18 | 19 | Inputs: 20 | - p_target : list of target probabilities 21 | - kernel : name of f-divergence to use, see KERNELS from divergences.py 22 | - maxerror : maximum permitted approximation error 23 | - dyadic : True if sum of weights must be a power of two. 24 | 25 | Returns: 26 | - p_approx : the optimal approximation 27 | - error : the achieved error 28 | - Z : the sum of weights 29 | """ 30 | assert allclose(sum(p_target), 1) 31 | 32 | # These divergence measures require the approximate distribution 33 | # to have full support over the domain of the target distribution 34 | # in order to achieve a finite error. 35 | strict = kernel in ['kl', 'nchi2'] 36 | n = len(p_target) 37 | 38 | # Initial error and precision k. 39 | error = float('inf') 40 | k = 1 if not strict else ceil(log2(n)) 41 | 42 | # Keep doubling precision 43 | while maxerror < error: 44 | # Possible sum of weights for given precision k. 45 | Z_list = [pow(2, k)] if dyadic else [get_Zkl(k, l) for l in range(1, k+1)] 46 | if strict: 47 | Z_list = [Z for Z in Z_list if len(p_target) <= Z] 48 | 49 | # List of approximate distributions, one for each Z. 
50 | p_approx_list = [ 51 | get_optimal_probabilities(Z, p_target, KERNELS[kernel]) 52 | for Z in Z_list 53 | ] 54 | 55 | # List of errors, one of for each approximate distribution 56 | error_list = [ 57 | compute_divergence_kernel(p_target, p_approx, KERNELS[kernel]) 58 | for p_approx in p_approx_list 59 | ] 60 | 61 | # Record the lowest error. 62 | i = argmin(error_list) 63 | p_approx = p_approx_list[i] 64 | error = error_list[i] 65 | Z = Z_list[i] 66 | k += 1 67 | 68 | return (p_approx, error, Z) 69 | 70 | (p_approx, error, Z) = find_optimal_approximation( 71 | p_target=[.07, .91, .02], 72 | kernel='nchi2', 73 | maxerror=2**-10, 74 | dyadic=False) 75 | 76 | print('optimal approximate distribution', p_approx) 77 | print('achieved error', error) 78 | -------------------------------------------------------------------------------- /examples/sampling.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Example of finding optimal distribution given a fixed precision.""" 4 | 5 | from tempfile import NamedTemporaryFile 6 | from fractions import Fraction 7 | from collections import Counter 8 | 9 | from optas.divergences import KERNELS 10 | from optas.opt import get_optimal_probabilities 11 | from optas.construct import construct_sample_ky_encoding 12 | from optas.sample import sample_ky_encoding 13 | 14 | from optas.writeio import write_sample_ky_encoding 15 | 16 | # Target probability distribution. 17 | p_target = [Fraction(1, 10), Fraction(3, 10), Fraction(4, 10), Fraction(2, 10)] 18 | 19 | # Obtain optimal probabilities (Algorithm 3). 20 | precision = 32 21 | kernel = 'hellinger' 22 | p_approx = get_optimal_probabilities(2**precision, p_target, KERNELS[kernel]) 23 | 24 | # Construct the sampler (Section 5). 25 | enc, n, k = construct_sample_ky_encoding(p_approx) 26 | 27 | # Run the sampler. 
28 | num_samples = 50000 29 | samples = [sample_ky_encoding(enc) for _i in range(num_samples)] 30 | counts = Counter(samples) 31 | 32 | f_expect = [float(p) for p in p_target] 33 | f_actual = [counts[i]/num_samples for i in sorted(counts.keys())] 34 | 35 | print('generated %d samples' % (num_samples,)) 36 | print('average frequencies: %s' % (f_expect,)) 37 | print('sampled frequencies: %s' % (f_actual,)) 38 | 39 | # Write sampler to disk (for the C command-line interface). 40 | with NamedTemporaryFile(delete=False) as f: 41 | write_sample_ky_encoding(enc, n, k, f.name) 42 | print('\nsampler written to: %s' % (f.name,)) 43 | print('to generate %d samples in C, run this command from c/ directory:' 44 | % (num_samples,)) 45 | print('$ ./mainc.opt 1 %d ky.enc %s' % (num_samples, f.name)) 46 | -------------------------------------------------------------------------------- /pythenv.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | set -Ceu 4 | 5 | : ${PYTHON:=python} 6 | root=`cd -- "$(dirname -- "$0")" && pwd` 7 | platform=$("${PYTHON}" -c 'import distutils.util as u; print(u.get_platform())') 8 | version=$("${PYTHON}" -c 'import sys; print(sys.version[0:3])') 9 | 10 | # The lib directory varies depending on 11 | # 12 | # (a) whether there are extension modules (here, no); and 13 | # (b) whether some Debian maintainer decided to patch the local Python 14 | # to behave as though there were. 15 | # 16 | # But there's no obvious way to just ask distutils what the name will 17 | # be. There's no harm in naming a pathname that doesn't exist, other 18 | # than a handful of microseconds of runtime, so we'll add both. 
19 | libdir="${root}/build/lib" 20 | plat_libdir="${libdir}.${platform}-${version}" 21 | export PYTHONPATH="${libdir}:${plat_libdir}${PYTHONPATH:+:${PYTHONPATH}}" 22 | 23 | bindir="${root}/build/scripts-${version}" 24 | export PATH="${bindir}${PATH:+:${PATH}}" 25 | 26 | exec "$@" 27 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 MIT Probabilistic Computing Project. 2 | # See LICENSE.txt 3 | 4 | import os 5 | import re 6 | from setuptools import setup 7 | 8 | # Determine the version (hardcoded). 9 | dirname = os.path.dirname(os.path.realpath(__file__)) 10 | vre = re.compile('__version__ = \'(.*?)\'') 11 | m = open(os.path.join(dirname, 'src', '__init__.py')).read() 12 | __version__ = vre.findall(m)[0] 13 | 14 | setup( 15 | name='optas', 16 | version=__version__, 17 | description='Optimal Approximate Sampling from Discrete Probability Distributions', 18 | long_description=open('README.md').read(), 19 | long_description_content_type='text/markdown', 20 | url='https://github.com/probcomp/optimal-approximate-sampling', 21 | maintainer='Feras A. 
Saad', 22 | maintainer_email='fsaad@mit.edu', 23 | license='Apache-2.0', 24 | classifiers=[ 25 | 'Development Status :: 2 - Pre-Alpha', 26 | 'Intended Audience :: Science/Research', 27 | 'License :: OSI Approved :: Apache Software License', 28 | 'Programming Language :: Python :: 3', 29 | ], 30 | packages=[ 31 | 'optas', 32 | 'optas.tests', 33 | 'optas.examples', 34 | ], 35 | package_dir={ 36 | 'optas': 'src', 37 | 'optas.tests': 'tests', 38 | 'optas.examples': 'examples', 39 | }, 40 | package_data={ 41 | 'optas': [ 42 | 'orderm2', 43 | 'phi', 44 | '../c/*.c', 45 | ], 46 | }, 47 | extras_require={ 48 | 'tests': ['pytest', 'scipy'] 49 | } 50 | ) 51 | -------------------------------------------------------------------------------- /src/.gitignore: -------------------------------------------------------------------------------- 1 | *.out 2 | *.pickle 3 | *.png 4 | 5 | sampling.tar.gz 6 | 7 | dists.*/ 8 | profile.* 9 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 MIT Probabilistic Computing Project. 2 | # Released under Apache 2.0; refer to LICENSE.txt 3 | 4 | __version__ = '1.0.3' 5 | -------------------------------------------------------------------------------- /src/construct.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 MIT Probabilistic Computing Project. 
2 | # Released under Apache 2.0; refer to LICENSE.txt 3 | 4 | from .matrix import make_ddg_matrix 5 | from .matrix import make_hamming_matrix 6 | from .matrix import make_hamming_vector 7 | 8 | from .utils import get_Zkl 9 | from .utils import get_binary_expansion_length 10 | from .utils import get_common_denominator 11 | from .utils import get_common_numerators 12 | 13 | from .packing import pack_tree 14 | from .tree import make_ddg_tree 15 | 16 | def construct_sample_ky_encoding(p_target): 17 | P, k, l = construct_sample_ky_matrix(p_target) 18 | root = make_ddg_tree(P, k, l) 19 | enc = {} 20 | pack_tree(enc, root, 0) 21 | n = len(P) 22 | encoding = [enc[i] for i in range(len(enc))] 23 | return encoding, n, k 24 | 25 | def construct_sample_ky_matrix(p_target): 26 | Z = get_common_denominator(p_target) 27 | k, l = get_binary_expansion_length(Z) 28 | Zkl = get_Zkl(k, l) 29 | Ms = get_common_numerators(Zkl, p_target) 30 | P, kp, lp = make_ddg_matrix(Ms, k, l) 31 | return P, kp, lp 32 | 33 | def construct_sample_ky_matrix_cached(p_target): 34 | P, k, l = construct_sample_ky_matrix(p_target) 35 | h = make_hamming_vector(P) 36 | T = make_hamming_matrix(P) 37 | return k, l, h, T 38 | -------------------------------------------------------------------------------- /src/divergences.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 MIT Probabilistic Computing Project. 
# Released under Apache 2.0; refer to LICENSE.txt

from math import log2
from math import sqrt

from collections import OrderedDict

# Prefer mpmath's arbitrary-precision arithmetic when it is installed;
# otherwise fall back to the float-based functions from math.
try:
    import mpmath
    mpf = mpmath.mpf
    mplog2 = lambda x: mpmath.log(x, b=2)
    mpsqrt = mpmath.sqrt
except ImportError:
    mpf = float
    # BUG FIX: this branch previously bound `mplog` (a name nothing else
    # uses), leaving `mplog2` undefined and crashing the stable generators
    # (sg_kl, sg_reverse_kl, sg_js) whenever mpmath is not installed.
    mplog2 = log2
    mpsqrt = sqrt

# ============================================================================
# Various f-divergences from the following sources
# https://arxiv.org/pdf/math/0505238.pdf
# https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=7552457
# https://arxiv.org/pdf/1309.3029.pdf

# Human-readable name of each supported f-divergence, keyed by the short
# identifier used throughout this module (and by KERNELS / GENERATORS).
LABELS = OrderedDict([
    ('tv'           , 'Total Variation'),
    ('hellinger'    , 'Hellinger Divergence'),
    ('pchi2'        , 'Pearson Chi-Square'),
    ('nchi2'        , 'Neyman Chi-Square'),
    ('td'           , 'Triangular Discrimination'),
    ('kl'           , 'Relative Entropy'),
    ('reverse_kl'   , 'Reverse Relative Entropy'),
    ('js'           , 'Jensen-Shannon'),
    ('jf'           , 'Jeffrey (Symmetric KL)'),
    ('mt'           , 'Matern'),
    ('alpha'        , 'Alpha Divergence'),
    ('x2'           , 'Quadratic'),
])


# Kernels expressed in direct form.

def kernel_tv(a, b):
    """Total-variation kernel: |a - b| / 2."""
    return 1/2 * abs(a-b)

def kernel_hellinger(a, b):
    """Hellinger kernel: (sqrt(a) - sqrt(b))**2."""
    return (sqrt(a) - sqrt(b))**2

def kernel_pchi2(a, b):
    """Pearson chi-square kernel: (a - b)**2 / a.

    Infinite when the reference mass `a` is zero; computed through mpf
    (mpmath when available) for numerical stability.
    """
    if a == 0:
        return float('inf')
    if b == 0:
        return a
    ax = mpf(float(a))
    bx = mpf(float(b))
    r = (ax - bx)**2 / ax
    return float(r)

def kernel_nchi2(a, b):
    """Neyman chi-square kernel: Pearson chi-square with arguments swapped."""
    return kernel_pchi2(b, a)

def kernel_td(a, b):
    """Triangular-discrimination kernel: (a - b)**2 / (a + b)."""
    if a + b == 0:
        return 0
    ax = mpf(float(a))
    bx = mpf(float(b))
    return float((ax - bx)**2 / (ax + bx))

def kernel_kl(a, b):
    """Relative-entropy (KL) kernel: a * log2(a / b), with 0*log(0) = 0."""
    if a == 0:
        return 0
    if b == 0:
        return float('inf')
    return a * (log2(a) - log2(b))

def kernel_reverse_kl(a, b):
    """Reverse KL kernel: KL with the arguments swapped."""
    return kernel_kl(b, a)

def kernel_js(a, b):
    """Jensen-Shannon kernel: KL of each argument against their midpoint."""
    m = (a + b)/2
    return kernel_kl(a, m) + kernel_kl(b, m)

def kernel_jf(a, b):
    raise NotImplementedError()

def kernel_mt(a, b):
    raise NotImplementedError()

def kernel_alpha(a, b):
    raise NotImplementedError()

def kernel_x2(a, b):
    raise NotImplementedError()

# Direct-form kernel for each divergence listed in LABELS.
# BUG FIX: 'jf' previously mapped to kernel_x2 and the 'mt' entry was
# missing, disagreeing with LABELS and GENERATORS.  Both kernel_jf and
# kernel_mt raise NotImplementedError, exactly as the old mapping did,
# so no caller-observable behavior changes.
KERNELS = OrderedDict([
    ('tv'           , kernel_tv),
    ('hellinger'    , kernel_hellinger),
    ('pchi2'        , kernel_pchi2),
    ('nchi2'        , kernel_nchi2),
    ('td'           , kernel_td),
    ('kl'           , kernel_kl),
    ('reverse_kl'   , kernel_reverse_kl),
    ('js'           , kernel_js),
    ('jf'           , kernel_jf),
    ('mt'           , kernel_mt),
    ('alpha'        , kernel_alpha),
    ('x2'           , kernel_x2),
])

# Kernels expressed in generator form.

# Generator (f-function) form of each divergence: a generator g maps the
# likelihood ratio t = q/p to g(t), so the divergence is sum_i p_i * g(t_i).

def g_tv(t):
    """Total-variation generator."""
    return .5 * abs(t-1)

def g_hellinger(t):
    """Hellinger generator."""
    return (sqrt(t)-1)**2

def g_pchi2(t):
    """Pearson chi-square generator."""
    return (t-1)**2

def g_nchi2(t):
    """Neyman chi-square generator; infinite at t == 0."""
    return (1-t)**2/t if t > 0 else float('inf')

def g_td(t):
    """Triangular-discrimination generator."""
    return (t-1)**2/(t+1)

def g_kl(t):
    """Relative-entropy generator; infinite at t == 0."""
    return -log2(t) if t > 0 else float('inf')

def g_reverse_kl(t):
    """Reverse relative-entropy generator, with 0*log(0) = 0."""
    return t*log2(t) if t > 0 else 0

def g_js(t):
    """Jensen-Shannon generator, built from the reverse-KL generator."""
    return g_reverse_kl(t) - (1+t)*log2((1+t)/2)

def g_jf(t):
    """Jeffrey (symmetric KL) generator: sum of KL and reverse KL."""
    return g_kl(t) + g_reverse_kl(t)

def g_mt(t):
    """Matern generator; nonzero only for t < 1."""
    return (t-1)**2 * (t < 1)

def g_alpha(t):
    """Alpha-divergence generator (alpha fixed at 0.3)."""
    return 4/(1-.3**2) * (1 - t**((1+.3)/2))

def g_x2(t):
    """Quadratic generator."""
    return t**2 - 1

# Generator-form function for each divergence listed in LABELS.
GENERATORS = OrderedDict([
    ('tv'           , g_tv),
    ('hellinger'    , g_hellinger),
    ('pchi2'        , g_pchi2),
    ('nchi2'        , g_nchi2),
    ('td'           , g_td),
    ('kl'           , g_kl),
    ('reverse_kl'   , g_reverse_kl),
    ('js'           , g_js),
    ('jf'           , g_jf),
    ('mt'           , g_mt),
    ('alpha'        , g_alpha),
    ('x2'           , g_x2),
])

# Kernels expressed in generator form using mpmath for stability.
140 | 141 | def make_stable(f): 142 | def f_stable(p, q): 143 | px = mpf(float(p)) 144 | qx = mpf(float(q)) 145 | t = qx/px 146 | return float(f(t)) 147 | return f_stable 148 | 149 | sg_tv = g_tv 150 | sg_hellinger = lambda t: (mpsqrt(t)-1)**2 151 | sg_pchi2 = g_pchi2 152 | sg_nchi2 = g_nchi2 153 | sg_td = g_td 154 | sg_kl = lambda t: 0 if t == 0 else -mplog2(t) 155 | sg_reverse_kl = lambda t: 0 if t == 0 else t*mplog2(t) 156 | sg_js = lambda t: sg_reverse_kl(t) - (1+t)*mplog2((1+t)/2) 157 | sg_jf = lambda t: sg_kl(t) + sg_reverse_kl(t) 158 | sg_mt = g_mt 159 | sg_alpha = g_alpha 160 | sg_x2 = g_x2 161 | 162 | GENERATORS_STABLE = OrderedDict([ 163 | ('tv' , make_stable(sg_tv)), 164 | ('hellinger' , make_stable(sg_hellinger)), 165 | ('pchi2' , make_stable(sg_pchi2)), 166 | ('nchi2' , make_stable(sg_nchi2)), 167 | ('td' , make_stable(sg_td)), 168 | ('kl' , make_stable(sg_kl)), 169 | ('reverse_kl' , make_stable(sg_reverse_kl)), 170 | ('js' , make_stable(sg_js)), 171 | ('jf' , make_stable(sg_jf)), 172 | ('mt' , make_stable(sg_mt)), 173 | ('alpha' , make_stable(sg_alpha)), 174 | ('x2' , make_stable(sg_x2)), 175 | ]) 176 | 177 | def compute_divergence_kernel(p, q, kernel): 178 | # assert allclose(float(sum(p)), 1) 179 | # assert allclose(float(sum(q)), 1) 180 | # 181 | # TODO: Handle f-divergences with no direct kernel. 182 | # ['alpha', 'patho', 'x2', 'jf']: 183 | terms = [kernel(a, b) for (a, b) in zip(p, q) if a > 0] 184 | return sum(terms) 185 | 186 | def compute_divergence_generator(p, q, g): 187 | # assert allclose(float(sum(p)), 1) 188 | # assert allclose(float(sum(q)), 1) 189 | ratios = [b/a if a > 0 else float('inf') for a, b in zip(p, q)] 190 | terms = [a*g(t) for (a, t) in zip(p, ratios) if a > 0] 191 | return sum(terms) 192 | -------------------------------------------------------------------------------- /src/flip.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 MIT Probabilistic Computing Project. 
2 | # Released under Apache 2.0; refer to LICENSE.txt 3 | 4 | """Algorithm for generating random bits lazily, adapted from 5 | 6 | Optimal Discrete Uniform Generation from Coin Flips, and Applications 7 | Jérémie Lumbroso, April 9, 2013 8 | https://arxiv.org/abs/1304.1916 9 | """ 10 | 11 | import random 12 | 13 | k = 32 14 | word = 0 15 | pos = 0 16 | 17 | def flip(): 18 | global pos 19 | global word 20 | if pos == 0: 21 | word = random.getrandbits(k) 22 | pos = k 23 | pos -= 1 24 | return (word & (1 << pos)) >> pos 25 | -------------------------------------------------------------------------------- /src/matrix.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 MIT Probabilistic Computing Project. 2 | # Released under Apache 2.0; refer to LICENSE.txt 3 | 4 | from .utils import frac_to_bits 5 | from .utils import get_Zkl 6 | from .utils import reduce_fractions 7 | 8 | # Algorithm 3. 9 | 10 | def make_matrix(Ms, k, l): 11 | assert sum(Ms) == get_Zkl(k, l) 12 | return [frac_to_bits(M, k, l) for M in Ms] 13 | 14 | def make_ddg_matrix(Ms, k, l): 15 | Ms_prime, kp, lp = reduce_fractions(Ms, k, l) 16 | P = make_matrix(Ms_prime, kp, lp) if (kp, lp) != (1, 0) else [[1]] 17 | return P, kp, lp 18 | 19 | def make_hamming_vector(P): 20 | N, k = len(P), len(P[0]) 21 | return [sum(P[r][c] for r in range(N)) for c in range(k)] 22 | 23 | def make_hamming_matrix(P): 24 | N, k = len(P), len(P[0]) 25 | T = [[-1 for c in range(k)] for r in range(N)] 26 | for c in range(k): 27 | d = 0 28 | for r in range(N): 29 | if P[r][c] == 1: 30 | T[d][c] = r 31 | d += 1 32 | return T 33 | -------------------------------------------------------------------------------- /src/opt.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 MIT Probabilistic Computing Project. 
2 | # Released under Apache 2.0; refer to LICENSE.txt 3 | 4 | from .utils import argmin 5 | from .utils import argmin2 6 | from .utils import normalize_vector 7 | 8 | def get_delta_error(Z, p, M, delta, kernel): 9 | assert delta in [-1, 1] 10 | if delta == -1 and M == 0: 11 | return float('inf') 12 | if delta == 1 and M == Z: 13 | return float('inf') 14 | v1 = kernel(p, (M+delta)/Z) 15 | v0 = kernel(p, M/Z) 16 | return v1 - v0 17 | 18 | def get_initial_Ms(Z, p_target, kernel): 19 | Ms = [0] * len(p_target) 20 | for i, p in enumerate(p_target): 21 | Ms[i] = int(Z*p) 22 | if get_delta_error(Z, p, Ms[i], +1, kernel) < 0: 23 | Ms[i] += 1 24 | return tuple(Ms) 25 | 26 | def find_optimal_indexes(errs_dec, errs_inc): 27 | # Find the indexes of the lowest cost decrements and increments. 28 | j_min_dec0, j_min_dec1 = argmin2(errs_dec) 29 | j_min_inc0, j_min_inc1 = argmin2(errs_inc) 30 | # Ensure optimal indexes are distinct (optimally). 31 | if j_min_dec0 != j_min_inc0: 32 | j_min_dec = j_min_dec0 33 | j_min_inc = j_min_inc0 34 | else: 35 | cost0 = errs_dec[j_min_dec0] + errs_inc[j_min_inc1] 36 | cost1 = errs_dec[j_min_dec1] + errs_inc[j_min_inc0] 37 | if cost0 <= cost1: 38 | j_min_dec = j_min_dec0 39 | j_min_inc = j_min_inc1 40 | else: 41 | j_min_dec = j_min_dec1 42 | j_min_inc = j_min_inc0 43 | assert j_min_inc != j_min_dec 44 | return j_min_dec, j_min_inc 45 | 46 | def prune_initial_Ms(Z, p_target, Ms, kernel): 47 | Ms = list(Ms) 48 | # Compute cost of decrements and increments. 49 | errs_dec = [ 50 | get_delta_error(Z, p, M, -1, kernel) 51 | for M, p in zip(Ms, p_target) 52 | ] 53 | errs_inc = [ 54 | get_delta_error(Z, p, M, +1, kernel) 55 | for M, p in zip(Ms, p_target) 56 | ] 57 | # Find optimal indexes. 58 | j_min_dec, j_min_inc = find_optimal_indexes(errs_dec, errs_inc) 59 | # Begin the loop. 60 | MAXITER = len(p_target) + 1 61 | iters = 0 62 | while errs_dec[j_min_dec] + errs_inc[j_min_inc] < 0: 63 | # Apply the optimal move. 
64 | Ms[j_min_dec] -= 1 65 | Ms[j_min_inc] += 1 66 | # Update the costs. 67 | errs_dec[j_min_dec] = get_delta_error( 68 | Z, p_target[j_min_dec], Ms[j_min_dec], -1, kernel) 69 | errs_inc[j_min_inc] = get_delta_error( 70 | Z, p_target[j_min_inc], Ms[j_min_inc], +1, kernel) 71 | # Update the optimal indexes. 72 | j_min_dec, j_min_inc = find_optimal_indexes(errs_dec, errs_inc) 73 | # Update the iteration counter. 74 | iters += 1 75 | # Fail if exceeded theoretical number of iterations. 76 | # Will fire in cases of severe numerical instability. 77 | if iters > MAXITER: 78 | assert False, 'Fatal error: pruning exceeding MAXITER.' 79 | return tuple(Ms) 80 | 81 | def fix_shortfall(Z, p_target, Ms, kernel): 82 | Ms = list(Ms) 83 | shortfall = sum(Ms) - Z 84 | delta = 1 if shortfall < 0 else -1 85 | errs_delta = [ 86 | get_delta_error(Z, p, M, delta, kernel) 87 | for M, p in zip(Ms, p_target) 88 | ] 89 | while shortfall != 0: 90 | j_min = argmin(errs_delta) 91 | Ms[j_min] += delta 92 | errs_delta[j_min] = get_delta_error( 93 | Z, p_target[j_min], Ms[j_min], delta, kernel) 94 | shortfall += delta 95 | assert sum(Ms) == Z 96 | return tuple(Ms) 97 | 98 | def optimize_unorm_strict(Z, p_target, kernel): 99 | """Run the optimization algorithm (requires p_target > 0 element-wise).""" 100 | # STEP 1: Initial guess. 101 | Ms_initial = get_initial_Ms(Z, p_target, kernel) 102 | # STEP 2: Pruning. 103 | Ms_prune = prune_initial_Ms(Z, p_target, Ms_initial, kernel) 104 | # STEP 3: Making up shortfall 105 | Ms_opt = fix_shortfall(Z, p_target, Ms_prune, kernel) 106 | # Return the result. 107 | return Ms_opt 108 | 109 | def optimize_unorm(Z, p_target, kernel): 110 | """Run the optimization algorithm.""" 111 | # STEP 1: Filter out zeros. 112 | p_nonzero_idx_vals = [(i, p) for (i, p) in enumerate(p_target) if p > 0] 113 | p_nonzero_idx = [i for i, _p in p_nonzero_idx_vals] 114 | p_nonzero_vals = [p for _i, p in p_nonzero_idx_vals] 115 | # STEP 2: Get the solution on the non-zero elements. 
116 | Ms_opt_trunc = optimize_unorm_strict(Z, p_nonzero_vals, kernel) 117 | # STEP 3: Pad the solution 118 | Ms_opt = [0] * len(p_target) 119 | for j, idx in enumerate(p_nonzero_idx): 120 | Ms_opt[idx] = Ms_opt_trunc[j] 121 | # Return the result. 122 | return tuple(map(int, Ms_opt)) 123 | 124 | def get_optimal_probabilities_strict(Z, p_target, kernel): 125 | """Return optimal Z-type approximation of p_target under f-divergence.""" 126 | Ms = optimize_unorm_strict(Z, p_target, kernel) 127 | return normalize_vector(Z, Ms) 128 | 129 | def get_optimal_probabilities(Z, p_target, kernel): 130 | """Return optimal Z-type approximation of p_target under f-divergence.""" 131 | Ms = optimize_unorm(Z, p_target, kernel) 132 | return normalize_vector(Z, Ms) 133 | -------------------------------------------------------------------------------- /src/orderm2: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | d=$(dirname $(realpath "${0}")) 3 | echo "scale = 0; orderm(2, ${1})" | bc -q -l ${d}/phi 4 | -------------------------------------------------------------------------------- /src/packing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 MIT Probabilistic Computing Project. 2 | # Released under Apache 2.0; refer to LICENSE.txt 3 | 4 | def pack_tree(enc, node, offset): 5 | assert node.loc is None 6 | node.loc = offset 7 | # Encode a leaf. 8 | if node.label is not None: 9 | enc[offset] = -node.label 10 | return offset + 1 11 | # Encode left child. 12 | if node.left.loc is not None: 13 | enc[offset] = node.left.loc 14 | w = offset + 2 15 | else: 16 | enc[offset] = offset + 2 17 | w = pack_tree(enc, node.left, offset+2) 18 | # Encode right child. 19 | if node.right.loc is not None: 20 | enc[offset + 1] = node.right.loc 21 | else: 22 | enc[offset + 1] = w 23 | w = pack_tree(enc, node.right, w) 24 | # Return the next offset. 
25 | return w 26 | -------------------------------------------------------------------------------- /src/phi: -------------------------------------------------------------------------------- 1 | /* gnubc program: phi 2 | 3 | This file was written by Keith Matthews and obtained from: 4 | http://www.numbertheory.org/gnubc/bc_programs.html 5 | 6 | omega(n) returns the number of distinct prime factors of n 7 | phi(n) returns the value of Euler's function 8 | tau(n) returns the number of divisors of n 9 | sigma(n) returns the sum of the divisors of n 10 | mu(n) returns the value of the Mobius Function 11 | lprimroot(p) returns the least primitive root mod p 12 | orderp(a,p) returns the order of a mod p 13 | orderpn(a,p,n) returns the order of a mod p^n, p a prime 14 | orderm(a,m) returns the order of a mod m. 15 | sigmak(k,n) returns the sum of the kth power of the divisors of n 16 | */ 17 | 18 | pglobal[0]=2 19 | pglobal[1]=3 20 | pglobal[2]=5 21 | pglobal[3]=7 22 | pglobal[4]=11 23 | pglobal[5]=13 24 | pglobal[6]=17 25 | pglobal[7]=19 26 | pglobal[8]=23 27 | pglobal[9]=29 28 | pglobal[10]=31 29 | pglobal[11]=37 30 | pglobal[12]=41 31 | pglobal[13]=43 32 | pglobal[14]=47 33 | pglobal[15]=53 34 | pglobal[16]=59 35 | pglobal[17]=61 36 | pglobal[18]=67 37 | pglobal[19]=71 38 | pglobal[20]=73 39 | pglobal[21]=79 40 | pglobal[22]=83 41 | pglobal[23]=89 42 | pglobal[24]=97 43 | pglobal[25]=101 44 | pglobal[26]=103 45 | pglobal[27]=107 46 | pglobal[28]=109 47 | pglobal[29]=113 48 | pglobal[30]=127 49 | pglobal[31]=131 50 | pglobal[32]=137 51 | pglobal[33]=139 52 | pglobal[34]=149 53 | pglobal[35]=151 54 | pglobal[36]=157 55 | pglobal[37]=163 56 | pglobal[38]=167 57 | pglobal[39]=173 58 | pglobal[40]=179 59 | pglobal[41]=181 60 | pglobal[42]=191 61 | pglobal[43]=193 62 | pglobal[44]=197 63 | pglobal[45]=199 64 | pglobal[46]=211 65 | pglobal[47]=223 66 | pglobal[48]=227 67 | pglobal[49]=229 68 | pglobal[50]=233 69 | pglobal[51]=239 70 | pglobal[52]=241 71 | pglobal[53]=251 72 | 
pglobal[54]=257 73 | pglobal[55]=263 74 | pglobal[56]=269 75 | pglobal[57]=271 76 | pglobal[58]=277 77 | pglobal[59]=281 78 | pglobal[60]=283 79 | pglobal[61]=293 80 | pglobal[62]=307 81 | pglobal[63]=311 82 | pglobal[64]=313 83 | pglobal[65]=317 84 | pglobal[66]=331 85 | pglobal[67]=337 86 | pglobal[68]=347 87 | pglobal[69]=349 88 | pglobal[70]=353 89 | pglobal[71]=359 90 | pglobal[72]=367 91 | pglobal[73]=373 92 | pglobal[74]=379 93 | pglobal[75]=383 94 | pglobal[76]=389 95 | pglobal[77]=397 96 | pglobal[78]=401 97 | pglobal[79]=409 98 | pglobal[80]=419 99 | pglobal[81]=421 100 | pglobal[82]=431 101 | pglobal[83]=433 102 | pglobal[84]=439 103 | pglobal[85]=443 104 | pglobal[86]=449 105 | pglobal[87]=457 106 | pglobal[88]=461 107 | pglobal[89]=463 108 | pglobal[90]=467 109 | pglobal[91]=479 110 | pglobal[92]=487 111 | pglobal[93]=491 112 | pglobal[94]=499 113 | pglobal[95]=503 114 | pglobal[96]=509 115 | pglobal[97]=521 116 | pglobal[98]=523 117 | pglobal[99]=541 118 | pglobal[100]=547 119 | pglobal[101]=557 120 | pglobal[102]=563 121 | pglobal[103]=569 122 | pglobal[104]=571 123 | pglobal[105]=577 124 | pglobal[106]=587 125 | pglobal[107]=593 126 | pglobal[108]=599 127 | pglobal[109]=601 128 | pglobal[110]=607 129 | pglobal[111]=613 130 | pglobal[112]=617 131 | pglobal[113]=619 132 | pglobal[114]=631 133 | pglobal[115]=641 134 | pglobal[116]=643 135 | pglobal[117]=647 136 | pglobal[118]=653 137 | pglobal[119]=659 138 | pglobal[120]=661 139 | pglobal[121]=673 140 | pglobal[122]=677 141 | pglobal[123]=683 142 | pglobal[124]=691 143 | pglobal[125]=701 144 | pglobal[126]=709 145 | pglobal[127]=719 146 | pglobal[128]=727 147 | pglobal[129]=733 148 | pglobal[130]=739 149 | pglobal[131]=743 150 | pglobal[132]=751 151 | pglobal[133]=757 152 | pglobal[134]=761 153 | pglobal[135]=769 154 | pglobal[136]=773 155 | pglobal[137]=787 156 | pglobal[138]=797 157 | pglobal[139]=809 158 | pglobal[140]=811 159 | pglobal[141]=821 160 | pglobal[142]=823 161 | pglobal[143]=827 162 | 
pglobal[144]=829 163 | pglobal[145]=839 164 | pglobal[146]=853 165 | pglobal[147]=857 166 | pglobal[148]=859 167 | pglobal[149]=863 168 | pglobal[150]=877 169 | pglobal[151]=881 170 | pglobal[152]=883 171 | pglobal[153]=887 172 | pglobal[154]=907 173 | pglobal[155]=911 174 | pglobal[156]=919 175 | pglobal[157]=929 176 | pglobal[158]=937 177 | pglobal[159]=941 178 | pglobal[160]=947 179 | pglobal[161]=953 180 | pglobal[162]=967 181 | pglobal[163]=971 182 | pglobal[164]=977 183 | pglobal[165]=983 184 | pglobal[166]=991 185 | pglobal[167]=997 186 | 187 | 188 | /* mod(a,b)=the least non-negative remainder when an integer 189 | a is divided by a positive integer b */ 190 | 191 | define mod(a,b){ 192 | auto c 193 | c=a%b 194 | if(a>=0) return(c) 195 | if(c==0) return(0) 196 | return(c+b) 197 | } 198 | 199 | 200 | /* lcm(a,b) for any integers a and b */ 201 | 202 | define lcm(a,b){ 203 | auto g 204 | g=gcd(a,b) 205 | if(g==0) return(0) 206 | return(abs(a*b)/g) 207 | } 208 | 209 | 210 | /* 211 | * the bth power of a, where a is an integer, b a positive integer. 212 | * 213 | */ 214 | define exp(a,b){ 215 | auto x,y,z 216 | x=a 217 | y=b 218 | z=1 219 | while(y>0){ 220 | while(y%2==0){ 221 | y=y/2 222 | x=x*x 223 | } 224 | y=y-1 225 | z=z*x 226 | } 227 | return(z) 228 | } 229 | 230 | 231 | /* 232 | * a^b (mod c), a,b,c integers, a,b>=0,c>=1 233 | */ 234 | 235 | define mpower(a,b,c){ 236 | auto x,y,z 237 | x=a 238 | y=b 239 | z=1 240 | while(y>0){ 241 | while(y%2==0){ 242 | y=y/2 243 | x=(x*x)%c 244 | } 245 | y=y-1 246 | z=(z*x)%c 247 | } 248 | return(z) 249 | } 250 | 251 | 252 | /* absolute value of an integer n */ 253 | 254 | define abs(n){ 255 | if(n>=0) return(n) 256 | return(-n) 257 | } 258 | 259 | /*NOTE: in bc we have */ 260 | 261 | /* gcd(m,n) for any integers m and n */ 262 | /* Euclid's division algorithm is used. 
*/ 263 | /* We use gcd(m,n)=gcd(m,|n|) */ 264 | 265 | define gcd(m,n){ 266 | auto a,b,c 267 | a=abs(m) /* a=r[0] */ 268 | if(n==0) return(a) 269 | b=abs(n) /* b=r[1] */ 270 | c=a%b /* c=r[2]=r[0] mod(r[1]) */ 271 | while(c>0){ 272 | a=b 273 | b=c 274 | c=a%b /* c=r[j]=r[j-2] mod(r[j-1]) */ 275 | } 276 | return(b) 277 | } 278 | 279 | /* min(x,y) */ 280 | 281 | define min(x,y){ 282 | if(y1)f=1 309 | } 310 | if(k>=r)f=1 311 | } 312 | r=2*r;print "r=",r,"\n" 313 | if(g==n || r == 16384){ 314 | g=1 315 | r=1 316 | y=0 317 | a=a+1 318 | print "increasing a\n" 319 | if(a==3){ 320 | print "Brent-Pollard failed to find a factor\n" 321 | return(0) 322 | } 323 | } 324 | } 325 | "-- 326 | " 327 | "FINISHED ";g 328 | "is a proper factor of ";n 329 | "-- 330 | " 331 | return(g) 332 | } 333 | 334 | define pollard(n){ 335 | auto i,p,t,b 336 | b=t=2 337 | p=1 338 | for(i=2;i<=10^4;i++){ 339 | if(i%10==0){"i=";i} 340 | t=mpower(t,i,n)/* now t=b ^(i!) */ 341 | p=gcd(n,t-1) 342 | if(p>1){ 343 | if(p 1 and odd, gcd(n,b)=1, or more generally, n does not divide b. 389 | * If miller(n,b)=1, then n passes Miller's test for base b and n is 390 | * either a prime or a strong pseudoprime to base b. 391 | * if miller(n,b)=0, then n is composite. 392 | */ 393 | 394 | define miller(n,b){ 395 | auto a,s 396 | s=(n-1)/2 397 | a=s 398 | while(a%2==0)a=a/2 399 | b=mpower(b,a,n) 400 | if(b==1)return(1) 401 | while(a<=s){ 402 | if(b==n-1)return(1) 403 | b=mod(b*b,n) 404 | a=2*a 405 | } 406 | return(0) 407 | } 408 | 409 | /* 410 | * n > 1 is distinct from pglobal[0],...,pglobal[4]. 411 | * if q(n)=1, then n passes Miller's test for bases pglobal[0],...,pglobal[4] 412 | * and is likely to be prime. 413 | * if q(n)=0, then n is composite. 414 | */ 415 | 416 | define q(n){ 417 | auto i 418 | for(i=0;i<=4;i++){ 419 | if(miller(n,pglobal[i])==0){ 420 | return(0) 421 | } 422 | } 423 | return(1) 424 | } 425 | 426 | /* 427 | * n>1 is not divisible by pglobal[0],...,pglobal[167]. 
428 | * v(n) returns a factor of n which is < 1,000,000 (and hence prime) 429 | * or which passes Miller's test for bases pglobal[0],...,pglobal[4] and is 430 | * therefore likely to be prime. 431 | */ 432 | 433 | define v(n){ 434 | auto f,x,y,b 435 | b=1000 436 | if(n 1,000,000, passes Miller's test 459 | * and is not divisible by pglobal[0],...,pglobal[167]. It is likely to be 460 | * prime. 461 | * primefactors(n) returns the number lglobal-t of q-prime factors of n. 462 | * rglobal[] and lglobal are global variables. 463 | * The prime and q-prime factors of n are qglobal[i],...,qglobal[jglobal-1] 464 | * with exponents 465 | * kglobal[i],...,kglobal[jglobal-1] where i is the value of the global 466 | * variable jglobal before primefactors(n) is called. 467 | */ 468 | 469 | define primefactors(n){ 470 | auto b,p,x,k,t 471 | b=1000 472 | x=babydivide(n) 473 | t=lglobal 474 | while(x!=1){ 475 | k=0 476 | p=v(x) 477 | while(x%p==0){ 478 | k=k+1 479 | x=x/p 480 | } 481 | if(p>b*b){ 482 | rglobal[lglobal]=p 483 | lglobal=lglobal+1 484 | } 485 | qglobal[jglobal]=p 486 | kglobal[jglobal]=k 487 | jglobal=jglobal+1 488 | } 489 | return(lglobal-t) 490 | } 491 | 492 | /* 493 | * Selfridge's test for primality - see "Prime Numbers and Computer 494 | * Methods for Factorization" by H. Riesel, Theorem 4.4, p.106. 495 | * input: n (q-prime) 496 | * first finds the prime and q-prime factors of n-1. If no q-prime factor 497 | * is present and 1 is returned, then n is prime. However if at least one 498 | * q-prime is present and 1 is returned, then n retains "q-prime" status. 499 | * If 0 is returned, then either n or one of the q-prime factors of n-1 is 500 | * composite. 501 | */ 502 | 503 | define selfridge(n){ 504 | auto i,x,s,t,u 505 | i=jglobal 506 | u=primefactors(n-1) 507 | cglobal=u+cglobal 508 | /* cglobal,jglobal,lglobal are global variables. 
*/ 509 | /* primefactors(n-1) returns jglobal-i primes and q-primes 510 | qglobal[i],...,qglobal[jglobal-1] */ 511 | /* and q-primes rglobal[l-u],...,rglobal[lglobal-1], where u>=0. */ 512 | while(i<=jglobal-1){ 513 | for(x=2;x0){ 553 | t=selfridge(rglobal[i]) 554 | if(t==0){ 555 | return(0) 556 | } 557 | i=i+1 558 | cglobal=cglobal-1 559 | } 560 | return(s) 561 | } 562 | 563 | /* 564 | * phi(n) returns the value of Euler's function. 565 | */ 566 | define phi(n){ 567 | auto i,u,d,m,t 568 | u = 1 569 | d = omega(n) 570 | t = d-1 571 | print "factorization: " 572 | for (i=0;i1){ 576 | print "^",kglobal[i] 577 | } 578 | if(i=2) return(0)} 615 | return((-1)^(d)) 616 | } 617 | 618 | /* This finds the least primitive root mod p. */ 619 | define lprimroot(p){ 620 | auto d,i,r,q,f,m,u,w,pminus1 621 | 622 | if(p%2==0 && p>2){print p," is even\n";return(0)} 623 | w=omega(p) 624 | if(w>1 || (w==1 && kglobal[0]>1)){print p," is not a prime\n";return(0)} 625 | f = 1 626 | pminus1=p-1 627 | q = pminus1/2 628 | d=omega(pminus1) 629 | for (i = 1; i < p; i++){ 630 | /* "i=";i;*/ 631 | r = mpower(i,q,p) 632 | if (r-p == -1){ 633 | for (u=0;u2){d=orderp(a,p)} 675 | if(p==2){ 676 | if((a-1)%4==0)d=1 677 | if((a+1)%4==0)d=2 678 | } 679 | q=p 680 | e=0 681 | h=0 682 | while(e==0){ 683 | q=q*p 684 | e=mpower(a,d,q)-1 685 | h=h+1 686 | } 687 | if(n<=h)return(d) 688 | return(exp(p,n-h)*d) 689 | } 690 | 691 | /* 692 | * orderm(a,m) returns the order of a mod m. This is the lcm of the orders of 693 | * a modulo the prime powers exactly dividing m. 
694 | */ 695 | define orderm(a,n){ 696 | auto x[],y[],i,s,o 697 | 698 | if(gcd(a,n)!=1){"a is not relatively prime to n: returning "; return(0)} 699 | if(a==1)return(1) 700 | if (a== -1){ 701 | if(n==2)return(1) 702 | return(2) 703 | } 704 | s=omega(n) 705 | for(i=0;i2){print p," is even\n";return(0)} 791 | w=omega(p) 792 | if(w>1 || (w==1 && kglobal[0]>1)){print p," is not a prime\n";return(0)} 793 | f = 1 794 | pminus1=p-1 795 | q = pminus1/2 796 | d=omega(pminus1) 797 | for (i = 1; i < p; i++){ 798 | /* "i=";i;*/ 799 | minusi=-i 800 | r = mpower(minusi,q,p) 801 | if (r-p == -1){ 802 | for (u=0;un){ 817 | temp=m 818 | m=n 819 | n=temp 820 | } 821 | /* now m<=n */ 822 | if(m<3){ 823 | print "m<3\n" 824 | return 825 | } 826 | 827 | for(i=m;i<=n;i++){ 828 | h=lucasnonverbose(i) 829 | if(h){ 830 | g=lprimroot(i) 831 | print "g=",g,",p=",i,"\n" 832 | } 833 | } 834 | } 835 | 836 | /* This lists the least negative primitive root mod p for all primes in the range 837 | 3 <= m <= p <= n. */ 838 | define lprimrootnegmn(m,n){ 839 | auto i,g,h,temp 840 | if(m>n){ 841 | temp=m 842 | m=n 843 | n=temp 844 | } 845 | /* now m<=n */ 846 | if(m<2){ 847 | print "m<2\n" 848 | return 849 | } 850 | 851 | for(i=m;i<=n;i++){ 852 | h=lucasnonverbose(i) 853 | if(h){ 854 | g=lprimrootneg(i) 855 | print "g=",g,",p=",i,"\n" 856 | } 857 | } 858 | } 859 | 860 | /* This prints the order of i and -i mod p for 1 <= i < p. 
def encode_binary(x, width):
    """Convert nonnegative integer x to a binary string of exactly `width` digits.

    Raises AssertionError if x is not an int, or does not fit in `width`
    bits (for width == 0 only x == 0 is accepted, yielding '').
    """
    assert isinstance(x, int)
    if width == 0:
        assert x == 0
        return ''
    xb = bin(x)[2:]
    assert len(xb) <= width
    return xb.zfill(width)


def frac_to_bits(M, k, l):
    """Return the binary expansion of M / Zkl as a list of 0/1 ints.

    The first l bits are the non-repeating (dyadic) prefix and the
    remaining k-l bits are the repeating suffix of the expansion.
    """
    assert 0 <= M < get_Zkl(k, l) or (k == 1 and l == 0)
    if l == k:
        # Purely dyadic fraction: no repeating suffix.
        x, y = M, 0
    elif l == 0:
        # Purely repeating fraction: no dyadic prefix.
        x, y = 0, M
    else:
        Zb = pow(2, k - l) - 1
        x, y = divmod(M, Zb)
    bits = encode_binary(x, l) + encode_binary(y, k - l)
    return [int(b) for b in bits]


def reduce_fractions(Ms, k, l):
    """Simplify (M/Zkl | M in Ms) to lowest terms.

    Returns (Ms', k', l') such that sum(Ms') equals Z(k', l').  Uses
    exact integer arithmetic throughout: the previous implementation
    detected a power-of-two quotient via float division and log2, which
    can lose precision once Zkl exceeds 2**53.
    """
    Zkl = get_Zkl(k, l)
    assert sum(Ms) == Zkl
    if any(M == Zkl for M in Ms):
        # Degenerate distribution: all mass on a single outcome.
        return ([M // Zkl for M in Ms], 1, 0)
    if l == 0:
        return (Ms, k, l)
    if all(M % 2 == 0 for M in Ms):
        # Common factor of two: halve every numerator and recurse.
        return reduce_fractions([M // 2 for M in Ms], k - 1, l - 1)
    if all(M == Ms[0] for M in Ms):
        # Uniform distribution: Zkl / Ms[0] must be an exact power of two.
        quotient, rem = divmod(Zkl, Ms[0])
        assert rem == 0
        assert quotient & (quotient - 1) == 0
        k_prime = quotient.bit_length() - 1
        return ([1] * len(Ms), k_prime, k_prime)
    return Ms, k, l
# Copyright 2019 MIT Probabilistic Computing Project.
# Released under Apache 2.0; refer to LICENSE.txt

def write_array(array, f):
    """Write a 1-D array to f as '<n> <v1> <v2> ... <vn>\n'."""
    f.writelines([
        '%d ' % (len(array),),
        ' '.join(str(v) for v in array),
        '\n',
    ])

def write_matrix(matrix, f):
    """Write a matrix to f: a '<nrow> <ncol>' header, then one row per line."""
    f.write('%d %d\n' % (len(matrix), len(matrix[0])))
    for row in matrix:
        f.writelines([' '.join(str(v) for v in row), '\n'])

def write_sample_ky_encoding(enc, n, k, fname):
    """Save a Knuth-Yao encoding: '<n> <k>' header followed by the encoding array."""
    with open(fname, 'w') as f:
        f.write('%d %d\n' % (n, k))
        write_array(enc, f)

def write_sample_ky_matrix(P, k, l, fname):
    """Save a Knuth-Yao probability matrix: '<k> <l>' header followed by P."""
    with open(fname, 'w') as f:
        f.write('%d %d\n' % (k, l))
        write_matrix(P, f)

def write_sample_ky_matrix_cached(k, l, h, T, fname):
    """Save the cached sampler state: '<k> <l>' header, Hamming vector h, matrix T."""
    with open(fname, 'w') as f:
        f.write('%d %d\n' % (k, l))
        write_array(h, f)
        write_matrix(T, f)
# Copyright 2019 MIT Probabilistic Computing Project.
# Released under Apache 2.0; refer to LICENSE.txt

import pytest

from optas.divergences import GENERATORS
from optas.divergences import KERNELS
from optas.divergences import LABELS
from optas.divergences import compute_divergence_generator
from optas.divergences import compute_divergence_kernel

from optas.tests.utils import allclose
from optas.tests.utils import get_random_dist

def disabled_test_f_divergences_graphical():
    """Visual sanity check of the divergence generators (not run by default)."""
    import matplotlib.pyplot as plt
    import numpy
    fig, axes = plt.subplots(nrows=3, ncols=4)
    for name, ax in zip(GENERATORS, numpy.ravel(axes)):
        ts = numpy.linspace(1e-1, 2, 100)
        ax.plot(ts, [GENERATORS[name](t) for t in ts])
        ax.set_title(LABELS[name])
    fig.set_size_inches((18, 10))
    fig.set_tight_layout(True)
    plt.show()


@pytest.mark.parametrize('x', range(20))
def test_kernel_generator_agree(x):
    """Kernel-based and generator-based divergence computations must agree."""
    p = get_random_dist(20)
    q = get_random_dist(20)
    for name in KERNELS:
        try:
            via_kernel = compute_divergence_kernel(p, q, KERNELS[name])
            via_generator = compute_divergence_generator(p, q, GENERATORS[name])
        except NotImplementedError:
            # Some divergences implement only one of the two code paths.
            continue
        assert allclose(float(via_kernel), float(via_generator))
# Copyright 2019 MIT Probabilistic Computing Project.
# Released under Apache 2.0; refer to LICENSE.txt

from optas.matrix import make_ddg_matrix
from optas.matrix import make_hamming_matrix
from optas.matrix import make_hamming_vector
from optas.matrix import make_matrix

from optas.utils import frac_to_bits

def test_make_matrix():
    """make_matrix expands each numerator; make_ddg_matrix also reduces."""
    Ms, k, l = [6, 6, 6, 6], 5, 3
    P = make_matrix(Ms, k, l)
    for row in P:
        assert row == frac_to_bits(Ms[0], k, l)

    Ms, k, l = [6, 6, 6, 6], 5, 3
    P, kp, lp = make_ddg_matrix(Ms, k, l)
    assert (kp, lp) == (2, 2)
    for row in P:
        assert row == [0, 1]

def test_make_hamming_vector_matrix():
    """Hamming vector counts ones per column; the matrix indexes them."""
    P = [
        [1, 0, 0, 1],
        [0, 1, 1, 1],
        [1, 0, 0, 1],
        [0, 0, 0, 1],
    ]
    h = make_hamming_vector(P)
    assert h == [2, 1, 1, 4]
    T = make_hamming_matrix(P)
    assert T == [
        [ 0,  1,  1,  0],
        [ 2, -1, -1,  1],
        [-1, -1, -1,  2],
        [-1, -1, -1,  3],
    ]

    P = [
        [0, 1, 0, 0],  # 4
        [0, 0, 0, 1],  # 1
        [1, 0, 1, 0],  # 10
        [0, 0, 0, 1],  # 1
    ]
    h = make_hamming_vector(P)
    assert h == [1, 1, 1, 2]
    T = make_hamming_matrix(P)
    assert T == [
        [ 2,  0,  2,  1],
        [-1, -1, -1,  3],
        [-1, -1, -1, -1],
        [-1, -1, -1, -1],
    ]
# Copyright 2019 MIT Probabilistic Computing Project.
# Released under Apache 2.0; refer to LICENSE.txt

import itertools

import pytest

from optas.divergences import KERNELS
from optas.divergences import compute_divergence_kernel
from optas.opt import get_optimal_probabilities

from optas.tests.utils import get_random_dist
from optas.tests.utils import get_random_dist_zeros
from optas.tests.utils import allclose

from optas.utils import argmin
from optas.utils import normalize_vector

def get_enumeration_tuples(Z, n):
    """Yield all length-n tuples of nonnegative integers which sum to Z."""
    candidates = itertools.product(range(Z + 1), repeat=n)
    return filter(lambda t: sum(t) == Z, candidates)

def get_enumeration_opt(Z, n, p_target, allocations, kernel):
    """Brute-force optimum over allocations (requires p_target > 0 element-wise)."""
    assert n == len(p_target)
    candidates = [normalize_vector(Z, Ms) for Ms in allocations]
    divergences = [compute_divergence_kernel(p_target, q, kernel) for q in candidates]
    return candidates[argmin(divergences)]

def check_solutions_match(Z, n, p_target, allocations, kernel):
    """Assert the closed-form optimizer matches brute-force enumeration."""
    assert sum(p_target) == 1
    try:
        M_enum = get_enumeration_opt(Z, n, p_target, allocations, kernel)
        M_opt = get_optimal_probabilities(Z, p_target, kernel)
    except NotImplementedError:
        return True
    assert sum(M_enum) == 1
    assert sum(M_opt) == 1
    e_enum = compute_divergence_kernel(p_target, M_enum, kernel)
    e_opt = compute_divergence_kernel(p_target, M_opt, kernel)
    assert M_enum == M_opt or allclose(e_enum, e_opt)

@pytest.mark.parametrize('n', [2, 3, 4])
@pytest.mark.parametrize('k', [2, 3, 4])
def test_get_optimal_probabilities(n, k):
    """Small instances: optimizer must match exhaustive enumeration."""
    Z = 2**k
    allocations = list(get_enumeration_tuples(Z, n))
    for kern in KERNELS:
        check_solutions_match(Z, n, get_random_dist(n), allocations, KERNELS[kern])

@pytest.mark.parametrize('n', [4, 5])
@pytest.mark.parametrize('k', [2, 3, 4])
def test_get_optimal_probabilities__ci_(n, k):
    """Larger instances, run only on CI."""
    Z = 2**k
    allocations = list(get_enumeration_tuples(Z, n))
    for kern in KERNELS:
        check_solutions_match(Z, n, get_random_dist(n), allocations, KERNELS[kern])

def test_opt_zeros():
    """Target entries with zero probability should receive zero mass."""
    Z = 100
    p_target = get_random_dist_zeros(50)
    idx_zero = [i for (i, p) in enumerate(p_target) if p == 0]
    assert idx_zero
    for kern in KERNELS:
        try:
            M_opt = get_optimal_probabilities(Z, p_target, KERNELS[kern])
        except NotImplementedError:
            continue
        # NOTE(review): asserting a non-empty list is always true; this was
        # probably meant to be all(M_opt[i] == 0 for i in idx_zero).
        assert [M_opt[i] == 0 for i in idx_zero]

def test_opt_insufficient_precision():
    """Degenerate case: more outcomes than the precision Z can represent."""
    Z = 16
    p_target = get_random_dist_zeros(100)
    idx_zero = [i for (i, p) in enumerate(p_target) if p == 0]
    assert idx_zero
    for kern in KERNELS:
        try:
            M_opt = get_optimal_probabilities(Z, p_target, KERNELS[kern])
            # NOTE(review): vacuous list-assert, same as in test_opt_zeros.
            assert [M_opt[i] == 0 for i in idx_zero]
        except NotImplementedError:
            continue
        except Exception:
            # TODO: Handle this case more gracefully!
            # All sorts of errors arise when the errors are all inf,
            # since the solution values become negative.
            assert kern in ['nchi2', 'kl', 'jf']
# Copyright 2019 MIT Probabilistic Computing Project.
# Released under Apache 2.0; refer to LICENSE.txt

from optas.packing import pack_tree
from optas.tree import make_ddg_tree

def test_one_back_edge():
    """The packed tree for numerators [3, 12] has one back edge to the root."""
    k, l = 4, 0
    P = [
        [0, 0, 1, 1],  # numerator 3
        [1, 1, 0, 0],  # numerator 12
    ]

    root = make_ddg_tree(P, k, l)
    encoding = {}
    pack_tree(encoding, root, 0)

    back_edges = [dst for src, dst in encoding.items() if 0 <= dst < src]
    assert back_edges == [0]

    # Two leaves for each outcome (labels are encoded as -1 and -2).
    assert sum(1 for dst in encoding.values() if dst == -1) == 2
    assert sum(1 for dst in encoding.values() if dst == -2) == 2

# Copyright 2019 MIT Probabilistic Computing Project.
# Released under Apache 2.0; refer to LICENSE.txt

import random

from collections import Counter

import pytest

from optas.matrix import make_ddg_matrix
from optas.matrix import make_hamming_matrix
from optas.matrix import make_hamming_vector
from optas.packing import pack_tree
from optas.tree import make_ddg_tree

from optas.sample import sample_ky_encoding
from optas.sample import sample_ky_matrix
from optas.sample import sample_ky_matrix_cached

import optas.flip

from optas.tests.utils import get_bitstrings
from optas.tests.utils import get_chisquare_pval

@pytest.mark.parametrize('seed', [10, 20, 100123])
def test_deterministic(seed):
    """A distribution with all mass on one outcome always samples that outcome."""
    random.seed(seed)
    P, kp, lp = make_ddg_matrix([0, 31], 5, 0)
    root = make_ddg_tree(P, kp, lp)
    encoding = {}
    pack_tree(encoding, root, 0)

    N_sample = 10000
    samples_mat = [sample_ky_matrix(P, kp, lp) for _ in range(N_sample)]
    samples_enc = [sample_ky_encoding(encoding) for _ in range(N_sample)]
    assert Counter(samples_mat)[1] == N_sample
    assert Counter(samples_enc)[1] == N_sample

@pytest.mark.parametrize('seed', [10, 20, 100123])
def test_nondetermistic(seed):
    """Samples from [3/15, 12/15] pass a chi-square goodness-of-fit test."""
    random.seed(seed)
    P, kp, lp = make_ddg_matrix([3, 12], 4, 0)
    root = make_ddg_tree(P, kp, lp)
    encoding = {}
    pack_tree(encoding, root, 0)

    N_sample = 10000
    samples_mat = [sample_ky_matrix(P, kp, lp) for _ in range(N_sample)]
    samples_enc = [sample_ky_encoding(encoding) for _ in range(N_sample)]

    assert 0.05 < get_chisquare_pval([3/15, 12/15], samples_mat)
    assert 0.05 < get_chisquare_pval([3/15, 12/15], samples_enc)

def test_sample_ky_matrix_cached():
    """Cached and uncached samplers agree on every possible 4-bit word."""
    Ms, k, l = [3, 2, 1, 7, 2, 1], 4, 4
    P, kp, lp = make_ddg_matrix(Ms, k, l)
    h = make_hamming_vector(P)
    T = make_hamming_matrix(P)

    samples = []
    optas.flip.k = 4
    for word in range(2**4):
        # Replay the identical bit word through both samplers.
        optas.flip.word = word
        optas.flip.pos = optas.flip.k
        result0 = sample_ky_matrix(P, kp, lp)

        optas.flip.word = word
        optas.flip.pos = optas.flip.k
        result1 = sample_ky_matrix_cached(kp, lp, h, T)

        assert result0 == result1
        samples.append(result0)

    # Exhausting all words recovers the exact numerators.
    counter = Counter(samples)
    assert [counter[i] for i in range(1, 7)] == [3, 2, 1, 7, 2, 1]
# Copyright 2019 MIT Probabilistic Computing Project.
# Released under Apache 2.0; refer to LICENSE.txt

from optas.matrix import make_ddg_matrix
from optas.matrix import make_matrix
from optas.tree import make_ddg_tree
from optas.tree import make_leaf_table
from optas.tree import make_tree

def test_probs_dyadic():
    """All-dyadic distribution: the tree terminates with no back edges."""
    P_desired = [
        [0, 0, 1],  # 1
        [0, 0, 1],  # 1
        [0, 1, 1],  # 3
        [0, 0, 1],  # 1
        [0, 1, 0],  # 2
    ]
    Ms, k, l = [1, 1, 3, 1, 2], 3, 3
    P = make_matrix(Ms, k, l)
    L = make_leaf_table(P)
    root = make_tree(0, k, l, [], L)

    assert P == P_desired

    assert root.right.label is None
    assert root.right.right.label == 3
    assert root.right.left.label == 5

    assert root.left.label is None
    assert root.left.left.label is None
    assert root.left.right.label is None

    assert root.left.right.right.label == 1
    assert root.left.right.left.label == 2
    assert root.left.left.right.label == 3
    assert root.left.left.left.label == 4

def test_probs_nondyadic_basic():
    """Non-dyadic distribution: the deepest left spine loops back to the root."""
    P_desired = [
        [0, 0, 1, 1],  # 3
        [1, 1, 0, 0],  # 12
    ]
    Ms, k, l = [3, 12], 4, 0
    P = make_matrix(Ms, k, l)
    L = make_leaf_table(P)
    root = make_tree(0, k, l, [], L)

    assert P == P_desired

    assert root.right.label == 2
    assert root.right.left is None

    assert root.left.label is None
    assert root.left.right.label == 2

    assert root.left.left.label is None
    assert root.left.left.right.label == 1

    assert root.left.left.left.label is None
    assert root.left.left.left.right.label == 1
    assert root.left.left.left.left == root

def test_probs_nondyadic_two_back_edge():
    """Mixed prefix/suffix distribution with two back edges into level one."""
    P_desired = [
        [0, 1, 0, 1],  # 5/14
        [0, 1, 0, 1],  # 5/14
        [0, 1, 0, 0],  # 4/14
    ]
    Ms, k, l = [5, 5, 4], 4, 1

    P = make_matrix(Ms, k, l)
    L = make_leaf_table(P)
    root = make_tree(0, k, l, [], L)

    assert P == P_desired

    assert len(L) == 5
    assert L[6] == 1
    assert L[5] == 2
    assert L[4] == 3
    assert L[18] == 1
    assert L[17] == 2

    assert root.right.right.label == 1
    assert root.right.right.right is None
    assert root.right.right.left is None
    assert root.right.left.label == 2
    assert root.right.left.right is None
    assert root.right.left.left is None

    assert root.left.right.label == 3
    assert root.left.right.left is None
    assert root.left.right.right is None

    assert root.left.left.right.right.label == 1
    assert root.left.left.right.left.label == 2

    assert root.left.left.left.left == root.left
    assert root.left.left.left.right == root.right

def test_probs_nondyadic_three_back_edges():
    """Five-outcome distribution whose suffix levels need three back edges."""
    P_desired = [
        [0, 1, 0, 0, 1],  # 8/28
        [0, 0, 1, 0, 1],  # 5/28
        [0, 0, 1, 0, 1],  # 5/28
        [0, 0, 1, 0, 1],  # 5/28
        [0, 0, 1, 0, 1],  # 5/28
    ]
    Ms, k, l = [8, 5, 5, 5, 5], 5, 2

    P = make_matrix(Ms, k, l)
    L = make_leaf_table(P)
    root = make_tree(0, k, l, [], L)

    assert P == P_desired

    assert len(L) == 10
    assert L[6] == 1

    assert L[12] == 2
    assert L[11] == 3
    assert L[10] == 4
    assert L[9] == 5

    assert L[38] == 1
    assert L[37] == 2
    assert L[36] == 3
    assert L[35] == 4
    assert L[34] == 5

    assert root.right.right.label == 1
    assert root.right.left.right.label == 2
    assert root.right.left.left.label == 3

    assert root.left.right.right.label == 4
    assert root.left.right.left.label == 5

    assert root.left.left.right.right.right.label == 1
    assert root.left.left.right.right.left.label == 2
    assert root.left.left.right.left.right.label == 3
    assert root.left.left.right.left.left.label == 4

    assert root.left.left.left.right.right.label == 5
    assert root.left.left.left.right.left.label is None
    assert root.left.left.left.right.left == root.right.left

    assert root.left.left.left.left.right == root.left.right
    assert root.left.left.left.left.left == root.left.left

def test_reduction_to_single_node():
    """End-to-end: a degenerate distribution reduces to a single labeled node."""
    P_desired = [
        [0, 0, 0, 0],  # 0
        [1, 1, 1, 1],  # 15
    ]
    Ms, k, l = [0, 15], 4, 0
    P, kp, lp = make_ddg_matrix(Ms, k, l)
    root = make_ddg_tree(P, kp, lp)
    assert kp == 1
    assert lp == 0
    assert root.label == 1
    assert root.right is None
    assert root.left is None

# Copyright 2019 MIT Probabilistic Computing Project.
# Released under Apache 2.0; refer to LICENSE.txt

import pytest

from optas.utils import encode_binary
from optas.utils import frac_to_bits
from optas.utils import get_Zb
from optas.utils import get_Zkl
from optas.utils import get_binary_expansion_length
from optas.utils import get_k_bit_prefixes
from optas.utils import reduce_fractions

def bits_to_int(bits):
    """Interpret a list of 0/1 values as a big-endian binary integer."""
    sbits = ''.join(map(str, bits))
    return int(sbits, 2)

def bits_to_frac(bits, k, l):
    """Invert frac_to_bits: recover (numerator, denominator) from a bit list."""
    Zkl = get_Zkl(k, l)
    Zb = get_Zb(k, l)
    prefix = bits[:l]
    suffix = bits[l:]
    int_prefix = bits_to_int(prefix) if prefix else 0
    int_suffix = bits_to_int(suffix) if suffix else 0
    return (Zb * int_prefix + int_suffix, Zkl)

def frac_to_bits_dyadic(M, k):
    """Binary expansion of M / 2**k, computed by direct bit masking."""
    bits = [0] * k
    for j in range(k):
        mask = 1 << ((k - 1) - j)
        bits[j] = int((M & mask) > 0)
    return bits

def test_orderm2():
    """Spot-check (k, l) expansion lengths against hand computations."""
    assert get_binary_expansion_length(2) == (1, 1)
    # 2^1 = 2
    assert get_binary_expansion_length(3) == (2, 0)
    # 2^2-2^0 = 3 = 1*3
    assert get_binary_expansion_length(4) == (2, 2)
    # 2^2 = 4
    assert get_binary_expansion_length(5) == (4, 0)
    # 2^4-2^0 = 15 = 3*5
    assert get_binary_expansion_length(6) == (3, 1)
    # 2^3-2^1 = 6 = 1*6
    assert get_binary_expansion_length(7) == (3, 0)
    # 2^3-2^0 = 7 = 1*7
    assert get_binary_expansion_length(8) == (3, 3)
    # 2^3 = 8
    assert get_binary_expansion_length(9) == (6, 0)
    # 2^6-2^0 = 63 = 7*9
    assert get_binary_expansion_length(10) == (5, 1)
    # 2^5-2^1 = 30 = 3*10
    assert get_binary_expansion_length(11) == (10, 0)
    # 2^10-2^0 = 1023 = 93*11
    assert get_binary_expansion_length(12) == (4, 2)
    # 2^4-2^2 = 12 = 1*12
    assert get_binary_expansion_length(13) == (12, 0)
    # 2^12-2^0 = 4095 = 315*13
    assert get_binary_expansion_length(14) == (4, 1)
    # 2^4-2^1 = 14 = 1*14
    assert get_binary_expansion_length(15) == (4, 0)
    # 2^4-2^0 = 15 = 1*15
    assert get_binary_expansion_length(16) == (4, 4)
    # 2^4 = 16

def test_frac_to_bits_to_frac():
    """frac_to_bits must be invertible via bits_to_frac for every (M, k, l)."""
    for k in range(1, 12):
        for l in get_k_bit_prefixes(k):
            Zkl = get_Zkl(k, l)
            for M in range(Zkl + 1 * (k == 1 and l == 0)):
                bits = frac_to_bits(M, k, l)
                numerator, _denominator = bits_to_frac(bits, k, l)
                assert numerator == M

def test_reduce_fractions_unit():
    """A degenerate distribution reduces to numerators over Z(1, 0)."""
    for k in [2, 5, 8, 10]:
        Ms = [2**k - 1, 0, 0, 0]
        (k, l) = (k, 0)
        Mp, kp, lp = reduce_fractions(Ms, k, l)
        assert Mp == [1, 0, 0, 0]
        assert kp == 1
        assert lp == 0

def test_reduce_fraction_dyadic_simplify():
    """Dyadic distributions with even numerators are halved repeatedly."""
    Ms, k, l = [2, 2], 2, 2
    Mp, kp, lp = reduce_fractions(Ms, k, l)
    assert Mp == [1, 1]
    assert kp == 1
    assert lp == 1

    Ms, k, l = [4, 8, 4], 4, 4
    Mp, kp, lp = reduce_fractions(Ms, k, l)
    assert Mp == [1, 2, 1]
    assert kp == 2
    assert lp == 2

    Ms, k, l = [8, 16, 2, 4, 2], 5, 5
    Mp, kp, lp = reduce_fractions(Ms, k, l)
    assert Mp == [4, 8, 1, 2, 1]
    assert kp == 4
    assert lp == 4

    Ms, k, l = [2, 22, 2, 4, 2], 5, 5
    Mp, kp, lp = reduce_fractions(Ms, k, l)
    assert Mp == [1, 11, 1, 2, 1]
    assert kp == 4
    assert lp == 4

def test_reduce_fraction_dyadic_nosimplify():
    """Distributions with an odd numerator are already in lowest terms."""
    Ms, k, l = [3, 1], 2, 2
    Mp, kp, lp = reduce_fractions(Ms, k, l)
    assert (Mp, kp, lp) == (Ms, k, l)

    Ms, k, l = [5, 7, 4], 4, 4
    Mp, kp, lp = reduce_fractions(Ms, k, l)
    assert (Mp, kp, lp) == (Ms, k, l)

    Ms, k, l = [8, 16, 2, 5, 1], 5, 5
    Mp, kp, lp = reduce_fractions(Ms, k, l)
    assert (Mp, kp, lp) == (Ms, k, l)

    Ms, k, l = [2, 22, 2, 5, 1], 5, 5
    Mp, kp, lp = reduce_fractions(Ms, k, l)
    # Bug fix: this case previously computed the reduction but asserted nothing.
    assert (Mp, kp, lp) == (Ms, k, l)

def test_reduce_fractions_uniform():
    """Uniform distributions reduce to unit numerators over a dyadic base."""
    Ms, k, l = [4, 4, 4], 4, 2
    Mp, kp, lp = reduce_fractions(Ms, k, l)
    assert Mp == [1, 1, 1]
    assert kp == 2
    assert lp == 0

    Ms, k, l = [6, 6, 6, 6], 5, 3
    Mp, kp, lp = reduce_fractions(Ms, k, l)
    assert Mp == [1, 1, 1, 1]
    assert kp == 2
    assert lp == 2

def test_encode_binary():
    """encode_binary pads to width and rejects values that do not fit."""
    with pytest.raises(AssertionError):
        encode_binary(3, 0)
    with pytest.raises(AssertionError):
        encode_binary(3, 1)

    assert encode_binary(3, 2) == '11'
    assert encode_binary(3, 3) == '011'
    assert encode_binary(3, 4) == '0011'
    assert encode_binary(3, 5) == '00011'

    assert encode_binary(0, 0) == ''
    assert encode_binary(0, 1) == '0'
    assert encode_binary(0, 2) == '00'

    assert encode_binary(255, 10) == '0011111111'
    assert encode_binary(108, 10) == '0001101100'

def test_frac_to_bits_dyadic():
    """For l == k, frac_to_bits agrees with the direct dyadic expansion."""
    for (x, k) in [(10, 5), (1, 11), (18, 10), (123, 9)]:
        assert frac_to_bits(x, k, k) == frac_to_bits_dyadic(x, k)

# Copyright 2019 MIT Probabilistic Computing Project.
# Released under Apache 2.0; refer to LICENSE.txt

import random

from collections import Counter
from fractions import Fraction
from itertools import product
from math import isinf

from scipy.stats import chisquare

def get_chisquare_pval(p_target, samples):
    """Return the chi-square goodness-of-fit p-value of samples against p_target.

    Bug fix: scipy.stats.chisquare takes the observed frequencies first
    and the expected frequencies second; the arguments were previously
    swapped, which computes the wrong test statistic.
    """
    N = len(samples)
    f_expected = [int(N * p) for p in p_target]
    counts = Counter(samples)
    keys = sorted(set(samples))
    f_actual = [counts[key] for key in keys]
    return chisquare(f_actual, f_expected)[1]

def get_bitstrings(k):
    """Return all length-k binary strings."""
    tuples = product((0, 1), repeat=k)
    return [''.join(map(str, t)) for t in tuples]

def get_random_dist(n):
    """Return a random distribution of n strictly positive rationals summing to 1."""
    numerators = [random.randint(1, n**2) for i in range(n)]
    Z = sum(numerators)
    return [Fraction(a, Z) for a in numerators]

def get_random_dist_zeros(n):
    """Return a random distribution of n rationals with at least one zero entry."""
    numerators = [random.randint(0, n**2) for i in range(n)]
    n_zero = random.randint(1, n - 1)
    numerators[:n_zero] = [0] * n_zero
    random.shuffle(numerators)
    # NOTE(review): if every draw is zero this divides by zero — unlikely
    # for the n used in the tests, but not impossible.
    Z = sum(numerators)
    return [Fraction(a, Z) for a in numerators]

def allclose(a, b, rtol=1e-5, atol=1e-8):
    """Numpy-style closeness test; any two infinities compare equal.

    NOTE(review): +inf and -inf are treated as equal here — presumably
    intentional for divergence comparisons, but worth confirming.
    """
    if isinf(a) and isinf(b):
        return True
    return abs(a - b) <= (atol + rtol * abs(b))