├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── mathematics_dataset ├── __init__.py ├── example.py ├── generate.py ├── generate_settings.py ├── generate_test.py ├── generate_to_file.py ├── modules │ ├── __init__.py │ ├── algebra.py │ ├── algebra_test.py │ ├── arithmetic.py │ ├── arithmetic_test.py │ ├── calculus.py │ ├── calculus_test.py │ ├── comparison.py │ ├── measurement.py │ ├── modules.py │ ├── numbers.py │ ├── polynomials.py │ ├── probability.py │ └── train_test_split.py ├── sample │ ├── __init__.py │ ├── arithmetic.py │ ├── arithmetic_test.py │ ├── linear_system.py │ ├── linear_system_test.py │ ├── number.py │ ├── number_test.py │ ├── ops.py │ ├── ops_test.py │ ├── polynomials.py │ └── polynomials_test.py └── util │ ├── __init__.py │ ├── combinatorics.py │ ├── combinatorics_test.py │ ├── composition.py │ ├── composition_test.py │ ├── display.py │ ├── display_test.py │ ├── probability.py │ └── probability_test.py └── setup.py /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement. You (or your employer) retain the copyright to your contribution; 10 | this simply gives us permission to use and redistribute your contributions as 11 | part of the project. Head over to https://cla.developers.google.com/ to see 12 | your current agreements on file or to sign a new one. 13 | 14 | You generally only need to submit a CLA once, so if you've already submitted one 15 | (even if it was for a different project), you probably don't need to do it 16 | again. 17 | 18 | ## Code reviews 19 | 20 | All submissions, including submissions by project members, require review. We 21 | use GitHub pull requests for this purpose. 
Consult 22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 23 | information on using pull requests. 24 | 25 | ## Community Guidelines 26 | 27 | This project follows 28 | [Google's Open Source Community Guidelines](https://opensource.google.com/conduct/). 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 
35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Mathematics Dataset 2 | 3 | This dataset code generates mathematical question and answer pairs, from a range 4 | of question types at roughly school-level difficulty. This is designed to test 5 | the mathematical learning and algebraic reasoning skills of learning models. 6 | 7 | Original paper: [Analysing Mathematical 8 | Reasoning Abilities of Neural Models](https://openreview.net/pdf?id=H1gR5iR5FX) 9 | (Saxton, Grefenstette, Hill, Kohli). 10 | 11 | ## Example questions 12 | 13 | ``` 14 | Question: Solve -42*r + 27*c = -1167 and 130*r + 4*c = 372 for r. 15 | Answer: 4 16 | 17 | Question: Calculate -841880142.544 + 411127. 18 | Answer: -841469015.544 19 | 20 | Question: Let x(g) = 9*g + 1. Let q(c) = 2*c + 1. Let f(i) = 3*i - 39. Let w(j) = q(x(j)). Calculate f(w(a)). 21 | Answer: 54*a - 30 22 | 23 | Question: Let e(l) = l - 6. Is 2 a factor of both e(9) and 2? 24 | Answer: False 25 | 26 | Question: Let u(n) = -n**3 - n**2. 
Let e(c) = -2*c**3 + c. Let l(j) = -118*e(j) + 54*u(j). What is the derivative of l(a)? 27 | Answer: 546*a**2 - 108*a - 118 28 | 29 | Question: Three letters picked without replacement from qqqkkklkqkkk. Give prob of sequence qql. 30 | Answer: 1/110 31 | ``` 32 | 33 | ## Pre-generated data 34 | 35 | [Pre-generated files](https://console.cloud.google.com/storage/browser/mathematics-dataset) 36 | 37 | ### Version 1.0 38 | 39 | This is the version released with the original paper. It contains 2 million 40 | (question, answer) pairs per module, with questions limited to 160 characters in 41 | length, and answers to 30 characters in length. Note the training data for each 42 | question type is split into "train-easy", "train-medium", and "train-hard". This 43 | allows training models via a curriculum. The data can also be mixed together 44 | uniformly from these training datasets to obtain the results reported in the 45 | paper. Categories: 46 | 47 | * **algebra** (linear equations, polynomial roots, sequences) 48 | * **arithmetic** (pairwise operations and mixed expressions, surds) 49 | * **calculus** (differentiation) 50 | * **comparison** (closest numbers, pairwise comparisons, sorting) 51 | * **measurement** (conversion, working with time) 52 | * **numbers** (base conversion, remainders, common divisors and multiples, 53 | primality, place value, rounding numbers) 54 | * **polynomials** (addition, simplification, composition, evaluating, expansion) 55 | * **probability** (sampling without replacement) 56 | 57 | ## Getting the source 58 | 59 | ### PyPI 60 | 61 | The easiest way to get the source is to use pip: 62 | 63 | ```shell 64 | $ pip install mathematics_dataset 65 | ``` 66 | 67 | ### From GitHub 68 | 69 | Alternately you can get the source by cloning the mathematics_dataset 70 | repository: 71 | 72 | ```shell 73 | $ git clone https://github.com/deepmind/mathematics_dataset 74 | $ pip install --upgrade mathematics_dataset/ 75 | ``` 76 | 77 | ## Generating 
examples 78 | 79 | Generated examples can be printed to stdout via the `generate` script. For 80 | example: 81 | 82 | ```shell 83 | python -m mathematics_dataset.generate --filter=linear_1d 84 | ``` 85 | 86 | will generate example (question, answer) pairs for solving linear equations in 87 | one variable. 88 | 89 | We've also included `generate_to_file.py` as an example of how to write the 90 | generated examples to text files. You can use this directly, or adapt it for 91 | your generation and training needs. 92 | 93 | ## Dataset Metadata 94 | The following table is necessary for this dataset to be indexed by search 95 | engines such as Google Dataset Search. 96 |
97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 150 | 151 | 152 | 153 | 171 | 172 | 173 | 174 | 175 | 176 |
| property | value |
| name | Mathematics Dataset |
| url | |
| sameAs | https://github.com/deepmind/mathematics_dataset |
descriptionThis dataset consists of mathematical question and answer pairs, from a range 117 | of question types at roughly school-level difficulty. This is designed to test 118 | the mathematical learning and algebraic reasoning skills of learning models.\n 119 | \n 120 | ## Example questions\n 121 | \n 122 | ```\n 123 | Question: Solve -42*r + 27*c = -1167 and 130*r + 4*c = 372 for r.\n 124 | Answer: 4\n 125 | \n 126 | Question: Calculate -841880142.544 + 411127.\n 127 | Answer: -841469015.544\n 128 | \n 129 | Question: Let x(g) = 9*g + 1. Let q(c) = 2*c + 1. Let f(i) = 3*i - 39. Let w(j) = q(x(j)). Calculate f(w(a)).\n 130 | Answer: 54*a - 30\n 131 | ```\n 132 | \n 133 | It contains 2 million 134 | (question, answer) pairs per module, with questions limited to 160 characters in 135 | length, and answers to 30 characters in length. Note the training data for each 136 | question type is split into "train-easy", "train-medium", and "train-hard". This 137 | allows training models via a curriculum. The data can also be mixed together 138 | uniformly from these training datasets to obtain the results reported in the 139 | paper. Categories:\n 140 | \n 141 | * **algebra** (linear equations, polynomial roots, sequences)\n 142 | * **arithmetic** (pairwise operations and mixed expressions, surds)\n 143 | * **calculus** (differentiation)\n 144 | * **comparison** (closest numbers, pairwise comparisons, sorting)\n 145 | * **measurement** (conversion, working with time)\n 146 | * **numbers** (base conversion, remainders, common divisors and multiples,\n 147 | primality, place value, rounding numbers)\n 148 | * **polynomials** (addition, simplification, composition, evaluating, expansion)\n 149 | * **probability** (sampling without replacement)
provider 154 |
155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 |
| property | value |
| name | DeepMind |
| sameAs | https://en.wikipedia.org/wiki/DeepMind |
169 |
170 |
| citation | https://identifiers.org/arxiv:1904.01557 |
177 |
178 | -------------------------------------------------------------------------------- /mathematics_dataset/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /mathematics_dataset/example.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Containers for "[example] problems" (i.e., question/answer) pairs.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | 23 | from mathematics_dataset.util import composition 24 | 25 | 26 | def question(context, template, **kwargs): 27 | """Makes a question, using the given context and template. 28 | 29 | The format is similar to that for python's `format` function, for example: 30 | 31 | ``` 32 | question(context, 'What is {n} plus {p} over {q}?', n=2, p=3, q=4) 33 | ``` 34 | 35 | The main difference between this and the standard python formatting is that 36 | this understands `Entity`s in the arguments, and will do appropriate expansion 37 | of text and prefixing of their descriptions. 38 | 39 | Arguments: 40 | context: Instance of `composition.Context`, for extracting entities needed 41 | for describing the problem. 42 | template: A string, like "Calculate the value of {exp}.". 43 | **kwargs: A dictionary mapping arguments to values, e.g., 44 | `{'exp': sympy.Add(2, 3, evaluate=False)}`. 45 | 46 | Returns: 47 | String. 48 | """ 49 | assert isinstance(context, composition.Context) 50 | assert isinstance(template, str) 51 | prefix, kwargs = composition.expand_entities(context, **kwargs) 52 | if prefix: 53 | prefix += ' ' 54 | return prefix + template.format(**kwargs) 55 | 56 | 57 | Problem = collections.namedtuple('Problem', ('question', 'answer')) 58 | 59 | 60 | -------------------------------------------------------------------------------- /mathematics_dataset/generate.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Prints to stdout different curriculum questions.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import textwrap 23 | 24 | # Dependency imports 25 | from absl import app 26 | from absl import flags 27 | from absl import logging 28 | from mathematics_dataset import generate_settings 29 | from mathematics_dataset.modules import modules 30 | import six 31 | from six.moves import range 32 | 33 | 34 | FLAGS = flags.FLAGS 35 | 36 | flags.DEFINE_string('filter', '', 'restrict to matching module names') 37 | flags.DEFINE_integer('per_train_module', 10, 'Num of examples per train module') 38 | flags.DEFINE_integer('per_test_module', 10, 'Num of examples per test module') 39 | flags.DEFINE_bool('show_dropped', False, 'Whether to print dropped questions') 40 | 41 | 42 | filtered_modules = collections.OrderedDict([]) 43 | counts = {} 44 | 45 | 46 | def _make_entropy_fn(level, num_levels): 47 | """This returns a function that returns a subrange of entropy. 48 | 49 | E.g., if level=1 (medium) and num_levels=3, then the returned function will 50 | map the range [x, x + y] to [x + y/3, x + 2y/3]. 51 | 52 | Args: 53 | level: Integer in range [0, num_levels - 1]. 54 | num_levels: Number of difficulty levels. 55 | 56 | Returns: 57 | Function to restrict entropy range. 
58 | """ 59 | lower = level / num_levels 60 | upper = (level + 1) / num_levels 61 | def modify_entropy(range_): 62 | assert len(range_) == 2 63 | length = range_[1] - range_[0] 64 | return (range_[0] + lower * length, range_[0] + upper * length) 65 | return modify_entropy 66 | 67 | 68 | def _filter_and_flatten(modules_): 69 | """Returns flattened dict, filtered according to FLAGS.""" 70 | flat = collections.OrderedDict() 71 | 72 | def add(submodules, prefix=None): 73 | for key, module_or_function in six.iteritems(submodules): 74 | full_name = prefix + '__' + key if prefix is not None else key 75 | if isinstance(module_or_function, dict): 76 | add(module_or_function, full_name) 77 | else: 78 | if FLAGS.filter not in full_name: 79 | continue 80 | flat[full_name] = module_or_function 81 | 82 | add(modules_) 83 | 84 | # Make sure list of modules are in deterministic order. This is important when 85 | # generating across multiple machines. 86 | flat = collections.OrderedDict( 87 | [(key, flat[key]) for key in sorted(six.iterkeys(flat))]) 88 | 89 | return flat 90 | 91 | 92 | def init_modules(train_split=False): 93 | """Inits the dicts containing functions for generating modules.""" 94 | if filtered_modules: 95 | return # already initialized 96 | 97 | all_modules = collections.OrderedDict([]) 98 | if train_split: 99 | all_modules['train-easy'] = modules.train(_make_entropy_fn(0, 3)) 100 | all_modules['train-medium'] = modules.train(_make_entropy_fn(1, 3)) 101 | all_modules['train-hard'] = modules.train(_make_entropy_fn(2, 3)) 102 | else: 103 | all_modules['train'] = modules.train(_make_entropy_fn(0, 1)) 104 | 105 | all_modules['interpolate'] = modules.test() 106 | all_modules['extrapolate'] = modules.test_extra() 107 | 108 | counts['train'] = FLAGS.per_train_module 109 | counts['train-easy'] = FLAGS.per_train_module // 3 110 | counts['train-medium'] = FLAGS.per_train_module // 3 111 | counts['train-hard'] = FLAGS.per_train_module // 3 112 | counts['interpolate'] = 
FLAGS.per_test_module 113 | counts['extrapolate'] = FLAGS.per_test_module 114 | 115 | for regime_, modules_ in six.iteritems(all_modules): 116 | filtered_modules[regime_] = _filter_and_flatten(modules_) 117 | 118 | 119 | def sample_from_module(module): 120 | """Samples a problem, ignoring samples with overly long questions / answers. 121 | 122 | Args: 123 | module: Callable returning a `Problem`. 124 | 125 | Returns: 126 | Pair `(problem, num_dropped)`, where `problem` is an instance of `Problem` 127 | and `num_dropped` is an integer >= 0 indicating the number of samples that 128 | were dropped. 129 | """ 130 | num_dropped = 0 131 | while True: 132 | problem = module() 133 | question = str(problem.question) 134 | if len(question) > generate_settings.MAX_QUESTION_LENGTH: 135 | num_dropped += 1 136 | if FLAGS.show_dropped: 137 | logging.warning('Dropping question: %s', question) 138 | continue 139 | answer = str(problem.answer) 140 | if len(answer) > generate_settings.MAX_ANSWER_LENGTH: 141 | num_dropped += 1 142 | if FLAGS.show_dropped: 143 | logging.warning('Dropping question with answer: %s', answer) 144 | continue 145 | return problem, num_dropped 146 | 147 | 148 | def main(unused_argv): 149 | """Prints Q&As from modules according to FLAGS.filter.""" 150 | init_modules() 151 | 152 | text_wrapper = textwrap.TextWrapper( 153 | width=80, initial_indent=' ', subsequent_indent=' ') 154 | 155 | for regime, flat_modules in six.iteritems(filtered_modules): 156 | per_module = counts[regime] 157 | for module_name, module in six.iteritems(flat_modules): 158 | # These magic print constants make the header bold. 
159 | print('\033[1m{}/{}\033[0m'.format(regime, module_name)) 160 | num_dropped = 0 161 | for _ in range(per_module): 162 | problem, extra_dropped = sample_from_module(module) 163 | num_dropped += extra_dropped 164 | text = text_wrapper.fill( 165 | '{} \033[92m{}\033[0m'.format(problem.question, problem.answer)) 166 | print(text) 167 | if num_dropped > 0: 168 | logging.warning('Dropped %d examples', num_dropped) 169 | 170 | 171 | if __name__ == '__main__': 172 | app.run(main) 173 | -------------------------------------------------------------------------------- /mathematics_dataset/generate_settings.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Settings for generation.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | import string 23 | 24 | MAX_QUESTION_LENGTH = 160 25 | MAX_ANSWER_LENGTH = 30 26 | QUESTION_CHARS = ( 27 | ['', ' '] + list(string.ascii_letters + string.digits + string.punctuation)) 28 | EMPTY_INDEX = QUESTION_CHARS.index('') 29 | NUM_INDICES = len(QUESTION_CHARS) 30 | CHAR_TO_INDEX = {char: index for index, char in enumerate(QUESTION_CHARS)} 31 | INDEX_TO_CHAR = {index: char for index, char in enumerate(QUESTION_CHARS)} 32 | 33 | -------------------------------------------------------------------------------- /mathematics_dataset/generate_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Tests for mathematics_dataset.generate.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | # Dependency imports 22 | from absl.testing import absltest 23 | from absl.testing import parameterized 24 | from mathematics_dataset import generate 25 | import six 26 | from six.moves import range 27 | 28 | 29 | class GenerateTest(parameterized.TestCase): 30 | 31 | def testMakeEntropyFn(self): 32 | entropy_full = generate._make_entropy_fn(0, 1) 33 | self.assertEqual(entropy_full((2, 3)), (2, 3)) 34 | entropy_third = generate._make_entropy_fn(2, 3) 35 | self.assertEqual(entropy_third((3, 6)), (5, 6)) 36 | 37 | @parameterized.parameters('train', 'interpolate', 'extrapolate') 38 | def testGenerate(self, regime): 39 | generate.init_modules() 40 | for module in six.itervalues(generate.filtered_modules[regime]): 41 | for _ in range(3): 42 | question = module() 43 | str(question) 44 | 45 | 46 | if __name__ == '__main__': 47 | absltest.main() 48 | -------------------------------------------------------------------------------- /mathematics_dataset/generate_to_file.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Example of how to write generated questions to text files. 
"""Example of how to write generated questions to text files.

Given an output directory, this will create the following subdirectories:

*   train-easy
*   train-medium
*   train-hard
*   interpolate
*   extrapolate

and populate each of these directories with one text file per module, where
the text file contains lines alternating between the question and the answer.

Passing --train_split=False will create a single output directory 'train' for
training data.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

# Dependency imports
from absl import app
from absl import flags
from absl import logging
from mathematics_dataset import generate
import six
from six.moves import range

FLAGS = flags.FLAGS

flags.DEFINE_string('output_dir', None, 'Where to write output text')
flags.DEFINE_boolean('train_split', True,
                     'Whether to split training data by difficulty')
flags.mark_flag_as_required('output_dir')


def _write_problems(path, module, num_problems):
  """Samples `num_problems` problems from `module` and writes them to `path`.

  Each problem takes two lines: the question, then the answer.
  """
  with open(path, 'w') as text_file:
    for _ in range(num_problems):
      problem, unused_extra = generate.sample_from_module(module)
      text_file.write(str(problem.question) + '\n')
      text_file.write(str(problem.answer) + '\n')


def main(unused_argv):
  """Writes one text file per module and regime under `--output_dir`."""
  generate.init_modules(FLAGS.train_split)

  output_dir = os.path.expanduser(FLAGS.output_dir)
  if os.path.exists(output_dir):
    # NOTE(review): this relies on logging.fatal terminating the program so
    # that an existing directory is never written into -- confirm.
    logging.fatal('output dir %s already exists', output_dir)
  logging.info('Writing to %s', output_dir)
  os.makedirs(output_dir)

  for regime, flat_modules in six.iteritems(generate.filtered_modules):
    regime_dir = os.path.join(output_dir, regime)
    os.mkdir(regime_dir)
    num_problems = generate.counts[regime]
    for module_name, module in six.iteritems(flat_modules):
      path = os.path.join(regime_dir, module_name + '.txt')
      _write_problems(path, module, num_problems)
      logging.info('Written %s', path)


if __name__ == '__main__':
  app.run(main)


# ------------------------------------------------------------------------------
# /mathematics_dataset/modules/__init__.py:
# ------------------------------------------------------------------------------
# Copyright 2018 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# ------------------------------------------------------------------------------
# /mathematics_dataset/modules/algebra.py:
# ------------------------------------------------------------------------------
# Copyright 2018 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Algebra-related questions, e.g., "Solve 1 + x = 2."."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import random

# Dependency imports
from mathematics_dataset import example
from mathematics_dataset.sample import linear_system
from mathematics_dataset.sample import number
from mathematics_dataset.sample import ops
from mathematics_dataset.sample import polynomials
from mathematics_dataset.util import composition
from mathematics_dataset.util import display
import numpy as np
from six.moves import range
import sympy


_ENTROPY_TRAIN = (3, 10)
_ENTROPY_INTERPOLATE = (8, 8)
_ENTROPY_EXTRAPOLATE = (12, 12)

# In generating a polynomial with real roots (where the roots are generated
# sequentially), this is the probability of taking a previous root, thus giving
# at least one repeated root, rather than sampling a new number. The value is
# somewhat arbitrary, but gives a "medium probability" of seeing a repeated root
# for lowish degree polynomials.
_POLY_PROBABILITY_REPEATED_ROOT = 0.2


def _make_modules(entropy):
  """Returns modules given "difficulty" parameters.

  Args:
    entropy: Pair `(min_entropy, max_entropy)`.

  Returns:
    Dict mapping module name to a zero-argument callable that samples a
    problem.
  """
  sample_args_pure = composition.PreSampleArgs(1, 1, *entropy)
  sample_args_composed = composition.PreSampleArgs(2, 4, *entropy)

  return {
      # Solving equations:
      'polynomial_roots': functools.partial(
          polynomial_roots, None, sample_args_pure),
      'polynomial_roots_composed': functools.partial(
          polynomial_roots, None, sample_args_composed),
      'linear_1d': functools.partial(
          solve_linear_1d, None, sample_args_pure),
      'linear_1d_composed': functools.partial(
          solve_linear_1d, None, sample_args_composed),
      'linear_2d': functools.partial(
          solve_linear_2d, None, sample_args_pure),
      'linear_2d_composed': functools.partial(
          solve_linear_2d, None, sample_args_composed),

      # Sequences:
      'sequence_next_term': functools.partial(sequence_next_term, *entropy),
      'sequence_nth_term': functools.partial(sequence_nth_term, *entropy),
  }


def train(entropy_fn):
  """Returns dict of training modules."""
  return _make_modules(entropy_fn(_ENTROPY_TRAIN))


def test():
  """Returns dict of testing modules."""
  return _make_modules(_ENTROPY_INTERPOLATE)


def test_extra():
  """Returns dict of extrapolation testing modules."""
  sample_args_pure = composition.PreSampleArgs(1, 1, *_ENTROPY_EXTRAPOLATE)
  return {
      'polynomial_roots_big': functools.partial(
          polynomial_roots, None, sample_args_pure),
  }


def _sample_roots(entropy):
  """Generates between 2 and 5 polynomial roots, possibly with repeats.

  Args:
    entropy: Float; total entropy shared between the freshly sampled roots.

  Returns:
    List of roots: first the `num_distinct` independently sampled roots, then
    `num_repeated` further entries, each copied from one of the sampled roots.
  """
  num_roots = random.randint(2, 5)

  num_repeated = np.random.binomial(
      num_roots - 1, _POLY_PROBABILITY_REPEATED_ROOT)
  # Slight hack: don't allow all the roots to be repeated when the entropy is
  # high, as this can create very large coefficients.
  if entropy > 4:
    num_repeated = min(num_repeated, num_roots // 2)

  num_distinct = num_roots - num_repeated

  entropies = entropy * np.random.dirichlet(np.ones(num_distinct))

  roots = []

  for root_entropy in entropies:
    # Generates a root with small probability of being rational.
    # (Otherwise when we multiply out the denominators, we get really large
    # coefficients in our polynomial.)
    if random.random() < 0.1:
      root = number.non_integer_rational(root_entropy, True)
    else:
      root = number.integer(root_entropy, True)
    roots.append(root)

  for _ in range(num_repeated):
    roots.append(random.choice(roots[:num_distinct]))

  return roots


def _polynomial_coeffs_with_roots(roots, scale_entropy):
  """Returns a polynomial with the given roots.

  The polynomial is generated by expanding product_{root in roots} (x - root),
  and then (1) scaling the coefficients by the lcm of their denominators, so
  they are all integers with gcd 1, and then (2) further scaling the
  coefficients by a random non-zero integer or rational with `scale_entropy`
  digits (no further scaling if `scale_entropy` is not positive).

  Args:
    roots: List of values.
    scale_entropy: Float; entropy of the random coefficient scaling.

  Returns:
    List of coefficients `coeffs`, such that `coeffs[i]` is the coefficient of
    variable ** i.
  """
  variable = sympy.Symbol('x')  # doesn't matter, only use coefficients
  polynomial = sympy.Poly(sympy.prod([variable - root for root in roots]))
  coeffs_reversed = polynomial.all_coeffs()
  assert len(coeffs_reversed) == len(roots) + 1
  coeffs = list(reversed(coeffs_reversed))
  # Multiply terms to change rationals to integers, and then maybe reintroduce.
  lcm = sympy.lcm([sympy.denom(coeff) for coeff in coeffs])
  if scale_entropy > 0:
    # Resample until the scale is non-zero; a zero scale would erase the
    # polynomial.
    while True:
      scale = number.integer_or_rational(scale_entropy, signed=True)
      if scale != 0:
        break
  else:
    scale = 1
  return [coeff * scale * lcm for coeff in coeffs]


def polynomial_roots(value, sample_args, context=None):
  """E.g., "Solve 2*x**2 - 18 = 0.".

  With equal probability asks either for the roots of a polynomial or for its
  factorization.

  Args:
    value: Unused.
    sample_args: Sampling arguments; `peel()` is called on it to obtain this
        module's entropy and the arguments for any composed sub-modules.
    context: Optional `composition.Context`; a fresh one is created if None.

  Returns:
    `example.Problem`.
  """
  del value  # not currently used
  if context is None:
    context = composition.Context()

  entropy, sample_args = sample_args.peel()
  scale_entropy = min(entropy / 2, 1)

  roots = _sample_roots(entropy - scale_entropy)
  solutions = sorted(sympy.FiniteSet(*roots))
  coeffs = _polynomial_coeffs_with_roots(roots, scale_entropy)
  (polynomial_entity,) = context.sample(
      sample_args, [composition.Polynomial(coeffs)])

  # Drawn before a variable is popped from the context, so the order of
  # random draws matches the order used when both branches popped separately.
  ask_for_roots = random.choice([False, True])

  if polynomial_entity.has_expression():
    expression = polynomial_entity.expression
    variable = polynomial_entity.polynomial_variables[0]
  else:
    variable = sympy.Symbol(context.pop())
    expression = polynomial_entity.handle.apply(variable)

  if ask_for_roots:
    # Ask for explicit roots.
    if len(solutions) == 1:
      answer = solutions[0]
    else:
      answer = display.NumberList(solutions)

    equality = ops.Eq(expression, 0)
    template = random.choice([
        'Let {equality}. What is {variable}?',
        'Let {equality}. Calculate {variable}.',
        'Suppose {equality}. What is {variable}?',
        'Suppose {equality}. Calculate {variable}.',
        'What is {variable} in {equality}?',
        'Solve {equality} for {variable}.',
        'Find {variable} such that {equality}.',
        'Find {variable}, given that {equality}.',
        'Determine {variable} so that {equality}.',
        'Determine {variable}, given that {equality}.',
        'Solve {equality}.'
    ])
    return example.Problem(
        question=example.question(
            context, template, equality=equality, variable=variable),
        answer=answer)
  else:
    factored = sympy.factor(
        polynomials.coefficients_to_polynomial(coeffs, variable))
    template = random.choice([
        'Factor {expression}.',
    ])
    return example.Problem(
        question=example.question(context, template, expression=expression),
        answer=factored)


def _solve_linear_system(degree, value, sample_args, context=None):
  """Solve linear equations.

  Args:
    degree: Number of variables (and solutions) in the system.
    value: Optional required value for the first solution; the remaining
        solutions are sampled.
    sample_args: Sampling arguments; `peel()` is called on it to obtain this
        module's entropy and the arguments for any composed sub-modules.
    context: Optional `composition.Context`. If None, a question is returned;
        otherwise an entity for use by an enclosing module.

  Returns:
    `example.Problem` if `context` was None, else `composition.Entity`.
  """
  is_question = context is None
  if context is None:
    context = composition.Context()

  entropy, sample_args = sample_args.peel()

  solutions = []
  if value is not None:
    solutions.append(value)

  extra_solutions_needed = degree - len(solutions)
  if extra_solutions_needed > 0:
    entropies = (entropy / 4) * np.random.dirichlet(
        np.ones(extra_solutions_needed))
    entropies = np.maximum(1, entropies)  # min per-solution entropy
    entropy -= sum(entropies)
    solutions += [number.integer(solution_entropy, True)
                  for solution_entropy in entropies]
    entropy = max(1, entropy)

  variables = [sympy.Symbol(context.pop()) for _ in range(degree)]

  solution_index = 0
  # If we're going to be creating a linear system with constants to replace by
  # handles from other modules, then we need a linear system with constants
  # occurring. Very occasionally this can fail to happen, e.g., "x = -x";
  # normally this while loop will only see one iteration.
  while True:
    equations = linear_system.linear_system(
        variables=variables, solutions=solutions, entropy=entropy,
        non_trivial_in=solution_index)
    constants = ops.number_constants(equations)
    if sample_args.num_modules <= 1 or constants:
      break

  context.sample_by_replacing_constants(sample_args, equations)

  variable = variables[solution_index]
  answer = solutions[solution_index]

  equations = ', '.join(str(equation) for equation in equations)

  if is_question:
    template = random.choice([
        'Solve {equations} for {variable}.',
    ])
    return example.Problem(
        example.question(
            context, template, equations=equations,
            variable=variable),
        answer)
  else:
    return composition.Entity(
        context=context,
        value=answer,
        description='Suppose {equations}.',
        handle=variable,
        equations=equations)


@composition.module(number.is_integer)
def solve_linear_1d(*args, **kwargs):
  """Linear equation in one variable; see `_solve_linear_system`."""
  return _solve_linear_system(1, *args, **kwargs)


@composition.module(number.is_integer)
def solve_linear_2d(*args, **kwargs):
  """Linear system in two variables; see `_solve_linear_system`."""
  return _solve_linear_system(2, *args, **kwargs)


class _PolynomialSequence(object):
  """A sequence given by a polynomial."""

  def __init__(self, variable, entropy, min_degree=1, max_degree=3):
    """Initializes a random polynomial sequence.

    Args:
      variable: Variable to use.
      entropy: Entropy for polynomial coefficients.
      min_degree: Minimum order of polynomial.
      max_degree: Maximum order of polynomial.
    """
    self._degree = random.randint(min_degree, max_degree)
    self._variable = variable
    polynomial = polynomials.sample_with_small_evaluation(
        variable=self._variable, degree=self._degree,
        max_abs_input=self._degree + 2, entropy=entropy)
    self._sympy = polynomial.sympy()

  @property
  def min_num_terms(self):
    """Returns the minimum number of terms to identify the sequence.

    This assumes a human-like prior over types of sequences.

    Returns:
      Integer >= 1.
    """
    return self._degree + 2

  @property
  def sympy(self):
    """Returns the sympy expression of the polynomial defining the sequence."""
    return self._sympy

  def term(self, n):
    """Returns the `n`th term of the sequence."""
    return self._sympy.subs(self._variable, n)


def _sample_sequence_prefix(min_entropy, max_entropy):
  """Samples a polynomial sequence together with a prefix of its terms.

  Shared setup for `sequence_next_term` and `sequence_nth_term`.

  Args:
    min_entropy: Float; lower bound of the entropy range.
    max_entropy: Float; upper bound of the entropy range.

  Returns:
    Tuple `(context, variable, sequence, num_terms, sequence_sample)`, where
    `sequence_sample` is a `display.NumberList` of terms 1..`num_terms`.
  """
  entropy = random.uniform(min_entropy, max_entropy)
  context = composition.Context()
  variable = sympy.Symbol(context.pop())

  sequence = _PolynomialSequence(variable, entropy)
  min_num_terms = sequence.min_num_terms
  num_terms = random.randint(min_num_terms, min_num_terms + 3)
  terms = [sequence.term(n + 1) for n in range(num_terms)]
  return context, variable, sequence, num_terms, display.NumberList(terms)


def sequence_next_term(min_entropy, max_entropy):
  """E.g., "What is the next term in the sequence 1, 2, 3?"."""
  context, _, sequence, num_terms, sequence_sample = _sample_sequence_prefix(
      min_entropy, max_entropy)

  template = random.choice([
      'What is next in {sequence}?',
      'What comes next: {sequence}?',
      'What is the next term in {sequence}?',
  ])
  answer = sequence.term(num_terms + 1)

  return example.Problem(
      question=example.question(context, template, sequence=sequence_sample),
      answer=answer)


def sequence_nth_term(min_entropy, max_entropy):
  """E.g., "What is the nth term in the sequence 1, 2, 3?"."""
  context, variable, sequence, _, sequence_sample = _sample_sequence_prefix(
      min_entropy, max_entropy)

  template = random.choice([
      'What is the {variable}\'th term of {sequence}?',
  ])
  answer = sequence.sympy

  return example.Problem(
      question=example.question(
          context, template, variable=variable, sequence=sequence_sample),
      answer=answer)


# ------------------------------------------------------------------------------
# /mathematics_dataset/modules/algebra_test.py:
# ------------------------------------------------------------------------------
# Copyright 2018 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for mathematics_dataset.modules.algebra."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import random

# Dependency imports
from absl.testing import absltest
from mathematics_dataset.modules import algebra
from mathematics_dataset.sample import polynomials
from six.moves import range
import sympy


class AlgebraTest(absltest.TestCase):
  """Tests for the polynomial-construction helper in `algebra`."""

  def testPolynomialCoeffsWithRoots(self):
    """(x - 1)(x - 2) = 2 - 3x + x**2, listed lowest power first."""
    actual = algebra._polynomial_coeffs_with_roots([1, 2], scale_entropy=0.0)
    self.assertEqual(actual, [2, -3, 1])

  def testPolynomialRoots(self):
    """Random scaling of the coefficients must leave the real roots intact."""
    x = sympy.Symbol('x')
    candidate_roots = list(range(-9, 10))
    for _ in range(10):
      expected_roots = random.sample(candidate_roots, 3)
      coeffs = algebra._polynomial_coeffs_with_roots(
          expected_roots, scale_entropy=10.0)
      polynomial = polynomials.coefficients_to_polynomial(coeffs, x)
      found_roots = sympy.polys.polytools.real_roots(polynomial)
      self.assertEqual(found_roots, sorted(expected_roots))


if __name__ == '__main__':
  absltest.main()


# ------------------------------------------------------------------------------
# /mathematics_dataset/modules/arithmetic_test.py:
# ------------------------------------------------------------------------------
# Copyright 2018 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for mathematics_dataset.modules.arithmetic."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Dependency imports
from absl.testing import absltest
from mathematics_dataset.modules import arithmetic
import sympy


class ArithmeticTest(absltest.TestCase):
  """Tests for `arithmetic._surd_coefficients`."""

  def testSurdCoefficients(self):
    """Checks the coefficient pair returned for a range of surd expressions."""
    # (expression text, expected pair); checked in this order.
    cases = [
        ('1', (1, 0)),
        ('1/2', (1/2, 0)),
        ('sqrt(2)', (0, 1)),
        ('3*sqrt(2)', (0, 3)),
        ('3*sqrt(5)/2', (0, 3/2)),
        ('1 + 3 * sqrt(2)', (1, 3)),
        ('1/2 + 3 * sqrt(5) / 2', (1/2, 3/2)),
        ('sqrt(2)/(-1 + 2*sqrt(2))**2', (8/49, 9/49)),
    ]
    for expression_text, expected in cases:
      exp = sympy.sympify(expression_text)
      self.assertEqual(arithmetic._surd_coefficients(exp), expected)


if __name__ == '__main__':
  absltest.main()


# ------------------------------------------------------------------------------
# /mathematics_dataset/modules/calculus.py:
# ------------------------------------------------------------------------------
# Copyright 2018 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Calculus related questions, e.g., "differentiate x**2"."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import math
import random

# Dependency imports
from mathematics_dataset import example
from mathematics_dataset.sample import polynomials
from mathematics_dataset.util import composition
from mathematics_dataset.util import display
import numpy as np
from six.moves import range
import sympy


_ENTROPY_TRAIN = (3, 10)
_ENTROPY_INTERPOLATE = (8, 8)


def _make_modules(entropy):
  """Returns modules given "difficulty" parameters.

  Args:
    entropy: Pair `(min_entropy, max_entropy)`.

  Returns:
    Dict mapping module name to a zero-argument callable that samples a
    problem.
  """
  pure_args = composition.PreSampleArgs(1, 1, *entropy)
  composed_args = composition.PreSampleArgs(2, 4, *entropy)

  return {
      'differentiate_composed': functools.partial(
          differentiate_univariate, None, composed_args),
      'differentiate': functools.partial(differentiate, None, pure_args),
  }


def train(entropy_fn):
  """Returns dict of training modules."""
  train_entropy = entropy_fn(_ENTROPY_TRAIN)
  return _make_modules(train_entropy)


def test():
  """Returns dict of testing modules."""
  return _make_modules(_ENTROPY_INTERPOLATE)


def test_extra():
  """Returns dict of extrapolation testing modules (currently none)."""
  return {
  }


def _generate_polynomial(num_variables, entropy, derivative_order,
                         derivative_axis):
  """Returns coefficients of a random polynomial to be differentiated.

  Args:
    num_variables: Number of variables of the polynomial.
    entropy: Entropy passed to the coefficient sampler.
    derivative_order: Positive integer; how often the result will be
        differentiated.
    derivative_axis: Index of the variable that will be differentiated.

  Returns:
    Coefficient array: the low-order coefficients (which disappear under
    differentiation) concatenated with the main coefficients along
    `derivative_axis`.
  """
  # Note: numpy randint has upper bound as ) not ], unlike python random.randint
  degrees = np.random.randint(1, 4, [num_variables])
  degrees[derivative_axis] = np.random.randint(0, 4)  # allow to be zero here.

  main_coefficients = polynomials.sample_coefficients(degrees, entropy)

  # We also generate coefficients that will disappear when differentiated.
  # Thus we don't account for the entropy used here.
  assert derivative_order > 0
  degrees[derivative_axis] = derivative_order - 1
  vanishing_coefficients = polynomials.sample_coefficients(degrees, entropy)

  return np.concatenate(
      [vanishing_coefficients, main_coefficients], axis=derivative_axis)


def _template(module_count, derivative_order, num_variables):
  """Selects appropriate template.

  Templates that omit the differentiation variable are only eligible when
  there is a single variable and a single module; templates that omit the
  derivative order are only eligible for first derivatives.
  """
  is_first_derivative = derivative_order == 1

  templates = [
      'Find the {nth} derivative of {eq} wrt {var}.',
      'What is the {nth} derivative of {eq} wrt {var}?',
  ]
  if is_first_derivative:
    templates.extend([
        'Differentiate {eq} with respect to {var}.',
        'Differentiate {eq} wrt {var}.',
        'What is the derivative of {eq} wrt {var}?',
    ])

  if num_variables == 1 and module_count == 1:
    # The derivative variable is unambiguous, so it may be left out.
    templates.extend([
        'Find the {nth} derivative of {eq}.',
        'What is the {nth} derivative of {eq}?',
    ])
    if is_first_derivative:
      templates.extend([
          'Differentiate {eq}.',
          'What is the derivative of {eq}?',
      ])

  return random.choice(templates)


def _sample_integrand(coefficients, derivative_order, derivative_axis, entropy):
  """Integrates `coefficients` and adds sampled "constant" terms.

  Args:
    coefficients: Array-like of polynomial coefficients.
    derivative_order: Number of times to integrate along `derivative_axis`.
    derivative_axis: Axis (variable index) to integrate along.
    entropy: Entropy passed to the sampler of the "constant" terms.

  Returns:
    Coefficient array of the integrated polynomial plus the sampled terms.
  """
  coefficients = np.asarray(coefficients)

  # Integrate (with zero for constant terms).
  integrand = coefficients
  for _ in range(derivative_order):
    integrand = polynomials.integrate(integrand, derivative_axis)

  # Add on sampled constant terms.
  constant_degrees = np.array(integrand.shape) - 1
  constant_degrees[derivative_axis] = derivative_order - 1
  constant_terms = polynomials.sample_coefficients(constant_degrees, entropy)

  # Zero-pad the constant terms at the end of `derivative_axis` so that they
  # can be added to the integrand.
  pad = []
  for axis in range(coefficients.ndim):
    if axis == derivative_axis:
      pad.append((0, coefficients.shape[derivative_axis]))
    else:
      pad.append((0, 0))
  constant_terms = np.pad(constant_terms, pad, 'constant', constant_values=0)
  return integrand + constant_terms


def _differentiate_polynomial(value, sample_args, context, num_variables):
  """Generates a question for differentiating a polynomial.

  Args:
    value: Optional required result of the differentiation (an object with a
        `coefficients` array); if None, a polynomial is sampled freely.
    sample_args: Sampling arguments; `peel()` is called on it to obtain this
        module's entropy and the arguments for any composed sub-modules.
    context: Optional `composition.Context`. If None, a question is returned;
        otherwise an entity for use by an enclosing module.
    num_variables: Number of variables; overridden by the rank of
        `value.coefficients` when `value` is given.

  Returns:
    `example.Problem` if `context` was None, else `composition.Entity`.
  """
  is_question = context is None
  if is_question:
    context = composition.Context()

  if value is not None:
    num_variables = value.coefficients.ndim

  entropy, sample_args = sample_args.peel()
  max_derivative_order = 3
  derivative_order = random.randint(1, max_derivative_order)
  entropy = max(0, entropy - math.log10(max_derivative_order))

  derivative_axis = random.randint(0, num_variables - 1)
  if value is not None:
    coefficients = _sample_integrand(
        value.coefficients, derivative_order, derivative_axis, entropy)
  else:
    coefficients = _generate_polynomial(
        num_variables, entropy, derivative_order, derivative_axis)

  (entity,) = context.sample(
      sample_args, [composition.Polynomial(coefficients)])

  # Differentiate the sampled coefficients to obtain the expected result.
  value = coefficients
  for _ in range(derivative_order):
    value = polynomials.differentiate(value, axis=derivative_axis)
  nth = display.StringOrdinal(derivative_order)

  if entity.has_expression():
    polynomial = entity.expression
    variables = entity.polynomial_variables
  else:
    variables = [sympy.Symbol(context.pop()) for _ in range(num_variables)]
    polynomial = entity.handle.apply(*variables)
  variable = variables[derivative_axis]

  if not is_question:
    fn_symbol = context.pop()
    variables_string = ', '.join(str(symbol) for symbol in variables)
    assert len(variables) == 1  # since below we don't specify var we diff wrt
    return composition.Entity(
        context=context,
        value=composition.Polynomial(value),
        description='Let {fn}({variables}) be the {nth} derivative of {eq}.',
        handle=composition.FunctionHandle(fn_symbol),
        fn=fn_symbol, variables=variables_string, nth=nth, eq=polynomial)

  template = _template(context.module_count, derivative_order, len(variables))
  answer = polynomials.coefficients_to_polynomial(value, variables).sympy()
  return example.Problem(
      question=example.question(
          context, template, eq=polynomial, var=variable, nth=nth),
      answer=answer)


def differentiate_univariate(value, sample_args, context=None):
  """Differentiation question restricted to a single variable."""
  return _differentiate_polynomial(value, sample_args, context, 1)


@composition.module(composition.is_polynomial)
def differentiate(value, sample_args, context=None):
  """Differentiation question in between one and four variables."""
  num_variables = random.randint(1, 4)
  return _differentiate_polynomial(value, sample_args, context, num_variables)


# ------------------------------------------------------------------------------
# /mathematics_dataset/modules/calculus_test.py:
# ------------------------------------------------------------------------------
# Copyright 2018 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for mathematics_dataset.modules.calculus."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Dependency imports
from mathematics_dataset.modules import calculus
import tensorflow as tf


class CalculusTest(tf.test.TestCase):
  """Tests for the integrand sampler in `calculus`."""

  def testSampleIntegrand(self):
    """Integrating once along axis 0 yields the expected non-constant part."""
    # y + 2*x + 3*x**2
    derivative_coefficients = [[0, 1], [2, 0], [3, 0]]
    # const + x*y + x**2 + x**3
    expected_without_constants = [[0, 1], [1, 0], [1, 0]]

    integrand = calculus._sample_integrand(
        derivative_coefficients, 1, 0, 4)

    # Row 0 holds the randomly sampled constant terms, so it is ignored.
    self.assertAllEqual(integrand[1:, :], expected_without_constants)


if __name__ == '__main__':
  tf.test.main()


# ------------------------------------------------------------------------------
# /mathematics_dataset/modules/comparison.py:
# ------------------------------------------------------------------------------
# Copyright 2018 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Comparisons, e.g. "is 2 > 3?".""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import functools 22 | import random 23 | 24 | # Dependency imports 25 | from mathematics_dataset import example 26 | from mathematics_dataset.sample import number 27 | from mathematics_dataset.sample import ops 28 | from mathematics_dataset.util import composition 29 | from mathematics_dataset.util import display 30 | import numpy as np 31 | from six.moves import range 32 | import sympy 33 | 34 | 35 | _ENTROPY_TRAIN = (3, 10) 36 | _ENTROPY_INTERPOLATE = (8, 8) 37 | _ENTROPY_EXTRAPOLATE = (12, 12) 38 | 39 | _EXTRAPOLATION_EXTRA_COUNT = 2 40 | 41 | _PROB_EQUAL = 0.2 42 | 43 | 44 | def _make_modules(entropy): 45 | """Returns modules given "difficulty" parameters.""" 46 | sample_args_pure = composition.PreSampleArgs(1, 1, *entropy) 47 | sample_args_composed = composition.PreSampleArgs(2, 4, *entropy) 48 | 49 | return { 50 | 'pair': functools.partial(pair, sample_args_pure), 51 | 'pair_composed': functools.partial(pair, sample_args_composed), 52 | 'kth_biggest': functools.partial(kth_biggest, sample_args_pure), 53 | 'kth_biggest_composed': functools.partial( 54 | kth_biggest, sample_args_composed), 55 | 'closest': functools.partial(closest, sample_args_pure), 56 | 'closest_composed': functools.partial(closest, sample_args_composed), 57 | 'sort': functools.partial(sort, sample_args_pure), 58 | 'sort_composed': functools.partial(sort, 
sample_args_composed), 59 | } 60 | 61 | 62 | def train(entropy_fn): 63 | """Returns dict of training modules.""" 64 | return _make_modules(entropy_fn(_ENTROPY_TRAIN)) 65 | 66 | 67 | def test(): 68 | """Returns dict of testing modules.""" 69 | return _make_modules(_ENTROPY_INTERPOLATE) 70 | 71 | 72 | def test_extra(): 73 | """Returns dict of extrapolation testing modules.""" 74 | sample_args_pure = composition.PreSampleArgs(1, 1, *_ENTROPY_EXTRAPOLATE) 75 | 76 | def sort_count(): 77 | lower = _sort_count_range(_ENTROPY_TRAIN[1])[1] 78 | return random.randint(lower + 1, lower + _EXTRAPOLATION_EXTRA_COUNT) 79 | def closest_count(): 80 | lower = _closest_count_range(_ENTROPY_TRAIN[1])[1] 81 | return random.randint(lower + 1, lower + _EXTRAPOLATION_EXTRA_COUNT) 82 | def kth_biggest_more(): 83 | return kth_biggest(sample_args_pure, count=sort_count()) 84 | def sort_more(): 85 | return sort(sample_args_pure, count=sort_count()) 86 | def closest_more(): 87 | return closest(sample_args_pure, count=closest_count()) 88 | 89 | return { 90 | 'kth_biggest_more': kth_biggest_more, 91 | 'sort_more': sort_more, 92 | 'closest_more': closest_more, 93 | } 94 | 95 | 96 | def _make_comparison_question(context, left, right): 97 | """Makes a question for comparing two values.""" 98 | if random.choice([False, True]) and sympy.Ne(left.value, right.value): 99 | # Do question of form: "Which is bigger: a or b?". 
100 | if random.choice([False, True]): 101 | answer = ( 102 | left.handle if sympy.Gt(left.value, right.value) else right.handle) 103 | template = random.choice([ 104 | 'Which is bigger: {left} or {right}?', 105 | 'Which is greater: {left} or {right}?', 106 | ]) 107 | else: 108 | answer = ( 109 | left.handle if sympy.Lt(left.value, right.value) else right.handle) 110 | template = random.choice([ 111 | 'Which is smaller: {left} or {right}?', 112 | ]) 113 | return example.Problem( 114 | question=example.question(context, template, left=left, right=right), 115 | answer=answer) 116 | 117 | comparisons = { 118 | '<': sympy.Lt, 119 | '<=': sympy.Le, 120 | '>': sympy.Gt, 121 | '>=': sympy.Ge, 122 | '=': sympy.Eq, 123 | '!=': sympy.Ne, 124 | } 125 | 126 | templates = { 127 | '<': [ 128 | 'Is {left} ' + ops.LT_SYMBOL + ' {right}?', 129 | 'Is {left} less than {right}?', 130 | 'Is {left} smaller than {right}?', 131 | ], 132 | '<=': [ 133 | 'Is {left} ' + ops.LE_SYMBOL + ' {right}?', 134 | 'Is {left} less than or equal to {right}?', 135 | 'Is {left} at most {right}?', 136 | 'Is {left} at most as big as {right}?', 137 | ], 138 | '>': [ 139 | 'Is {left} ' + ops.GT_SYMBOL + ' {right}?', 140 | 'Is {left} greater than {right}?', 141 | 'Is {left} bigger than {right}?', 142 | ], 143 | '>=': [ 144 | 'Is {left} ' + ops.GE_SYMBOL + ' {right}?', 145 | 'Is {left} greater than or equal to {right}?', 146 | 'Is {left} at least {right}?', 147 | 'Is {left} at least as big as {right}?', 148 | ], 149 | '=': [ 150 | 'Does {left} ' + ops.EQ_SYMBOL + ' {right}?', 151 | 'Are {left} and {right} equal?', 152 | 'Is {left} equal to {right}?', 153 | 'Do {left} and {right} have the same value?', 154 | ], 155 | '!=': [ 156 | 'Is {left} ' + ops.NE_SYMBOL + ' {right}?', 157 | 'Is {left} not equal to {right}?', 158 | 'Are {left} and {right} unequal?', 159 | 'Are {left} and {right} nonequal?', 160 | 'Are {left} and {right} non-equal?', 161 | 'Do {left} and {right} have different values?', 162 | ], 163 | } 164 
| 165 | comparison = random.choice(list(comparisons.keys())) 166 | template = random.choice(templates[comparison]) 167 | question = example.question(context, template, left=left, right=right) 168 | answer = comparisons[comparison](left.value, right.value) 169 | 170 | return example.Problem(question=question, answer=answer) 171 | 172 | 173 | def integer_or_rational_or_decimal(entropy): 174 | if random.choice([False, True]): 175 | return number.integer_or_decimal(entropy, signed=True) 176 | else: 177 | return number.integer_or_rational(entropy, signed=True) 178 | 179 | 180 | def pair(sample_args, context=None): 181 | """Compares two numbers, e.g., "is 1/2 < 0.5?".""" 182 | if context is None: 183 | context = composition.Context() 184 | entropy, sample_args = sample_args.peel() 185 | 186 | def integers_close(): 187 | entropy_diff, entropy_left = entropy * np.random.dirichlet([1, 3]) 188 | left = number.integer(entropy_left, True) 189 | right = left + number.integer(entropy_diff, True) 190 | return left, right 191 | 192 | def rational_and_integer(): 193 | # Pick rational, and integer close to rational evaluation 194 | left = number.non_integer_rational(entropy, True) 195 | right = int(round(left)) + random.randint(-1, 1) 196 | return left, right 197 | 198 | def independent(): 199 | # Return an independent pair. 
200 | entropy_left, entropy_right = entropy * np.random.dirichlet([1, 1]) 201 | left = integer_or_rational_or_decimal(entropy_left) 202 | right = integer_or_rational_or_decimal(entropy_right) 203 | return left, right 204 | 205 | generator = random.choice([integers_close, rational_and_integer, independent]) 206 | 207 | left, right = generator() 208 | 209 | # maybe swap for symmetry 210 | if random.choice([False, True]): 211 | left, right = right, left 212 | left, right = context.sample(sample_args, [left, right]) 213 | 214 | return _make_comparison_question(context, left, right) 215 | 216 | 217 | def _entities_to_list(entities): 218 | entity_dict = {} 219 | values_template = '' 220 | for i, entity in enumerate(entities): 221 | if i > 0: 222 | values_template += ', ' 223 | entity_name = 'entity_{}'.format(i) 224 | entity_dict[entity_name] = entity 225 | values_template += '{' + entity_name + '}' 226 | return entity_dict, values_template 227 | 228 | 229 | def _entities_to_choices(entities, answer): 230 | """Generate a multichoice question template.""" 231 | if len(entities) > 26: 232 | raise ValueError('Too many choices: {}'.format(len(entities))) 233 | 234 | entity_dict = {} 235 | choices_template = '' 236 | answer_choice = None 237 | for i, entity in enumerate(entities): 238 | choices_template += ' ' 239 | entity_name = 'entity_{}'.format(i) 240 | entity_dict[entity_name] = entity 241 | letter = chr(ord('a') + i) 242 | choices_template += '({letter}) {{{entity_name}}}'.format( 243 | letter=letter, entity_name=entity_name) 244 | if entity is answer: 245 | assert answer_choice is None 246 | answer_choice = letter 247 | 248 | assert answer_choice is not None 249 | return entity_dict, choices_template, answer_choice 250 | 251 | 252 | def _mark_choice_letters_used(count, context): 253 | """Marks the choice letters as used.""" 254 | for i in range(count): 255 | context.mark_used(chr(ord('a') + i)) 256 | 257 | 258 | def _kth_biggest_list_question(context, entities, 
adjective, answer): 259 | """Ask for the biggest (or smallest, or second biggest, etc) in a list.""" 260 | entity_dict, values_template = _entities_to_list(entities) 261 | 262 | question = example.question( 263 | context, 'What is the {adjective} value in ' + values_template + '?', 264 | adjective=adjective, **entity_dict) 265 | return example.Problem(question=question, answer=answer.handle) 266 | 267 | 268 | def _kth_biggest_multichoice_question(context, entities, adjective, answer): 269 | """Ask for the biggest (or smallest, or second biggest, etc) of choices.""" 270 | entity_dict, choices_template, answer_choice = _entities_to_choices( 271 | entities, answer) 272 | question = example.question( 273 | context, 'Which is the {adjective} value?' + choices_template, 274 | adjective=adjective, **entity_dict) 275 | return example.Problem(question=question, answer=answer_choice) 276 | 277 | 278 | def _entity_sort_key(entity): 279 | return sympy.default_sort_key(entity.value) 280 | 281 | 282 | def _sort_count_range(entropy): 283 | min_ = 3 284 | return min_, min_ + int(entropy/2) 285 | 286 | 287 | def _unique_values(entropy, only_integers=False, count=None): 288 | """Generates unique values.""" 289 | if count is None: 290 | count = random.randint(*_sort_count_range(entropy)) 291 | 292 | if only_integers: 293 | sampler = functools.partial(number.integer, signed=True) 294 | else: 295 | sampler = integer_or_rational_or_decimal 296 | 297 | for _ in range(1000): 298 | entropies = entropy * np.random.dirichlet(np.ones(count)) 299 | entropies = np.maximum(1, entropies) 300 | values = [sampler(ent) for ent in entropies] 301 | if len(sympy.FiniteSet(*values)) == len(values): 302 | return values 303 | raise ValueError('Could not generate {} unique values with entropy={}' 304 | .format(count, entropy)) 305 | 306 | 307 | def kth_biggest(sample_args, count=None): 308 | """Asks for the kth biggest value in a list.""" 309 | sample_args = sample_args() 310 | context = 
composition.Context() 311 | 312 | entropy, sample_args = sample_args.peel() 313 | values = _unique_values(entropy, count=count) 314 | count = len(values) 315 | 316 | display_multichoice = random.choice([False, True]) 317 | if display_multichoice: 318 | _mark_choice_letters_used(count, context) 319 | 320 | entities = context.sample(sample_args, values) 321 | sorted_entities = sorted(entities, key=_entity_sort_key) 322 | ordinal = random.randint(1, count) 323 | 324 | if random.choice([False, True]): 325 | # Do from biggest. 326 | answer = sorted_entities[-ordinal] 327 | adjective = 'biggest' 328 | else: 329 | # Do from smallest. 330 | answer = sorted_entities[ordinal - 1] 331 | adjective = 'smallest' 332 | 333 | if ordinal > 1: 334 | adjective = str(display.StringOrdinal(ordinal)) + ' ' + adjective 335 | 336 | if display_multichoice: 337 | return _kth_biggest_multichoice_question( 338 | context=context, entities=entities, adjective=adjective, answer=answer) 339 | else: 340 | return _kth_biggest_list_question( 341 | context=context, entities=entities, adjective=adjective, answer=answer) 342 | 343 | 344 | def _closest_in_list_question(context, entities, target, adjective, answer): 345 | """Ask for the closest to a given value in a list.""" 346 | entity_dict, values_template = _entities_to_list(entities) 347 | 348 | question = example.question( 349 | context, 350 | 'What is the {adjective} to {target} in ' + values_template + '?', 351 | adjective=adjective, target=target, **entity_dict) 352 | return example.Problem(question=question, answer=answer.handle) 353 | 354 | 355 | def _closest_multichoice_question(context, entities, target, adjective, answer): 356 | """Ask for the closest to a given value in a set of choices.""" 357 | entity_dict, choices_template, answer_choice = _entities_to_choices( 358 | entities, answer) 359 | 360 | question = example.question( 361 | context, 362 | 'Which is the {adjective} to {target}?' 
+ choices_template, 363 | adjective=adjective, target=target, **entity_dict) 364 | return example.Problem(question=question, answer=answer_choice) 365 | 366 | 367 | def _closest_count_range(entropy): 368 | min_ = 3 369 | return min_, min_ + int(entropy/3) 370 | 371 | 372 | def closest(sample_args, count=None): 373 | """Ask for the closest to a given value in a list.""" 374 | sample_args = sample_args() 375 | context = composition.Context() 376 | 377 | entropy, sample_args = sample_args.peel() 378 | if count is None: 379 | count = random.randint(*_closest_count_range(entropy)) 380 | 381 | display_multichoice = random.choice([False, True]) 382 | if display_multichoice: 383 | _mark_choice_letters_used(count, context) 384 | 385 | entropy_target, entropy_list = entropy * np.random.dirichlet([1, count]) 386 | target = integer_or_rational_or_decimal(entropy_target) 387 | 388 | while True: 389 | value_entropies = entropy_list * np.random.dirichlet(np.ones(count)) 390 | value_entropies = np.maximum(1, value_entropies) 391 | values = [integer_or_rational_or_decimal(ent) for ent in value_entropies] 392 | differences = [abs(sympy.sympify(value) - target) for value in values] 393 | if len(sympy.FiniteSet(*differences)) == count: # all differences unique 394 | break 395 | 396 | target_and_entities = context.sample(sample_args, [target] + values) 397 | target = target_and_entities[0] 398 | entities = target_and_entities[1:] 399 | 400 | min_difference = min(differences) 401 | answer_index = differences.index(min_difference) 402 | answer = entities[answer_index] 403 | adjective = random.choice(['closest', 'nearest']) 404 | 405 | if display_multichoice: 406 | return _closest_multichoice_question( 407 | context=context, entities=entities, target=target, adjective=adjective, 408 | answer=answer) 409 | else: 410 | return _closest_in_list_question( 411 | context=context, entities=entities, target=target, adjective=adjective, 412 | answer=answer) 413 | 414 | 415 | def sort(sample_args, 
count=None): 416 | """Ask to sort numbers in increasing or decreasing order.""" 417 | sample_args = sample_args() 418 | context = composition.Context() 419 | 420 | entropy, sample_args = sample_args.peel() 421 | # Sometimes just integers, to allow for more terms in a short space. 422 | values = _unique_values( 423 | entropy, only_integers=random.choice([False, True]), count=count) 424 | 425 | entities = context.sample(sample_args, values) 426 | 427 | unsorted_dict, unsorted_template = _entities_to_list(entities) 428 | 429 | ascending = random.choice([False, True]) 430 | templates = [ 431 | 'Sort ' + unsorted_template + ' in {direction} order.', 432 | 'Put ' + unsorted_template + ' in {direction} order.', 433 | ] 434 | if ascending: 435 | templates.append('Sort ' + unsorted_template + '.') 436 | direction = random.choice(['ascending', 'increasing']) 437 | else: 438 | direction = random.choice(['descending', 'decreasing']) 439 | template = random.choice(templates) 440 | 441 | sorted_entities = sorted( 442 | entities, key=_entity_sort_key, reverse=(not ascending)) 443 | answer = '' 444 | for i, entity in enumerate(sorted_entities): 445 | if i > 0: 446 | answer += ', ' 447 | answer += str(entity.handle) 448 | 449 | return example.Problem( 450 | question=example.question( 451 | context, template, direction=direction, **unsorted_dict), 452 | answer=answer) 453 | -------------------------------------------------------------------------------- /mathematics_dataset/modules/measurement.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Measurement questions, e.g., "How many hours are there in a day?".""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import functools 23 | import random 24 | 25 | # Dependency imports 26 | from mathematics_dataset import example 27 | from mathematics_dataset.modules import train_test_split 28 | from mathematics_dataset.sample import number 29 | from mathematics_dataset.util import composition 30 | from mathematics_dataset.util import display 31 | import six 32 | import sympy 33 | 34 | 35 | def _make_modules(is_train): 36 | """Returns modules, with split based on the boolean `is_train`.""" 37 | return { 38 | 'conversion': functools.partial( 39 | conversion, is_train=is_train, is_extrapolation=False), 40 | 'time': functools.partial(time, is_train=is_train), 41 | } 42 | 43 | 44 | def train(entropy_fn): 45 | """Returns dict of training modules.""" 46 | del entropy_fn # unused 47 | return _make_modules(is_train=True) 48 | 49 | 50 | def test(): 51 | """Returns dict of testing modules.""" 52 | return _make_modules(is_train=False) 53 | 54 | 55 | def test_extra(): 56 | """Returns dict of extrapolation testing modules.""" 57 | return { 58 | 'conversion': functools.partial( 59 | conversion, is_train=False, is_extrapolation=True), 60 | } 61 | 62 | 63 | Unit = collections.namedtuple('Unit', ('name', 'symbol')) 64 | 65 | 66 | MICRO_SYMBOL = 'u' 67 | 68 | 69 | LENGTH = { 70 | Unit('meter', 'm'): 1, 71 | 
Unit('kilometer', 'km'): 1000, 72 | Unit('centimeter', 'cm'): sympy.Rational(1, 100), 73 | Unit('millimeter', 'mm'): sympy.Rational(1, 1000), 74 | Unit('micrometer', 'um'): sympy.Rational(1, 1e6), 75 | Unit('nanometer', 'nm'): sympy.Rational(1, 1e9), 76 | } 77 | 78 | TIME = { 79 | Unit('second', 's'): 1, 80 | Unit('minute', None): 60, 81 | Unit('hour', None): 60*60, 82 | Unit('day', None): 24*60*60, 83 | Unit('week', None): 7*24*60*60, 84 | Unit('millisecond', 'ms'): sympy.Rational(1, 1e3), 85 | Unit('microsecond', MICRO_SYMBOL + 's'): sympy.Rational(1, 1e6), 86 | Unit('nanosecond', 'ns'): sympy.Rational(1, 1e9), 87 | } 88 | 89 | TIME_YEARLY = { 90 | Unit('year', None): 1, 91 | Unit('decade', None): 10, 92 | Unit('century', None): 100, 93 | Unit('millennium', None): 1000, 94 | Unit('month', None): sympy.Rational(1, 12), 95 | } 96 | 97 | MASS = { 98 | Unit('kilogram', 'kg'): 1, # Yes, the *kilo*gram is the SI base unit. 99 | Unit('tonne', 't'): 1000, 100 | Unit('gram', 'g'): sympy.Rational(1, 1e3), 101 | Unit('milligram', 'mg'): sympy.Rational(1, 1e6), 102 | Unit('microgram', MICRO_SYMBOL + 'g'): sympy.Rational(1, 1e9), 103 | Unit('nanogram', 'ng'): sympy.Rational(1, 1e12), 104 | } 105 | 106 | VOLUME = { 107 | Unit('litre', 'l'): 1, 108 | Unit('millilitre', 'ml'): sympy.Rational(1, 1000), 109 | } 110 | 111 | 112 | DIMENSIONS = [LENGTH, TIME, TIME_YEARLY, MASS, VOLUME] 113 | 114 | 115 | def pluralize(name): 116 | if name == 'century': 117 | return 'centuries' 118 | if name == 'millennium': 119 | return 'millennia' 120 | return name + 's' 121 | 122 | 123 | def _factor_non_decimal(value): 124 | """Extras x dividing value such that x is coprime to 2 and 5.""" 125 | result = 1 126 | factors = sympy.factorint(value) 127 | for factor, power in six.iteritems(factors): 128 | if factor not in [2, 5]: 129 | result *= factor ** power 130 | return result 131 | 132 | 133 | def _sample_conversion_decimal(dimension, is_extrapolation): 134 | """Samples to and from units and 
values.""" 135 | base_unit, target_unit = random.sample(list(dimension.keys()), 2) 136 | scale = sympy.Rational(dimension[base_unit]) / dimension[target_unit] 137 | scale_non_decimal = _factor_non_decimal(sympy.denom(scale)) 138 | entropy = 9 if is_extrapolation else 7 139 | base_value = number.non_integer_decimal(entropy, signed=False) 140 | base_value = display.Decimal(base_value.value * scale_non_decimal) 141 | target_value = display.Decimal(base_value.value * scale) 142 | return base_value, base_unit, target_value, target_unit 143 | 144 | 145 | def _conversion_decimal(context, is_train, is_extrapolation): 146 | """E.g., "How many grams are in 5kg?".""" 147 | dimension = random.choice(DIMENSIONS) 148 | while True: 149 | base_value, base_unit, target_value, target_unit = ( 150 | _sample_conversion_decimal(dimension, is_extrapolation)) 151 | if train_test_split.is_train(base_value) == is_train: 152 | break 153 | 154 | templates = [ 155 | 'How many {target_name} are there in {base_value} {base_name}?', 156 | 'What is {base_value} {base_name} in {target_name}?', 157 | 'Convert {base_value} {base_name} to {target_name}.', 158 | ] 159 | if base_unit.symbol is not None: 160 | templates += [ 161 | 'How many {target_name} are there in {base_value}{base_symbol}?', 162 | 'What is {base_value}{base_symbol} in {target_name}?', 163 | 'Convert {base_value}{base_symbol} to {target_name}.', 164 | ] 165 | template = random.choice(templates) 166 | 167 | base_name = pluralize(base_unit.name) 168 | target_name = pluralize(target_unit.name) 169 | 170 | question = example.question( 171 | context, 172 | template, 173 | base_name=base_name, 174 | base_symbol=base_unit.symbol, 175 | base_value=base_value, 176 | target_name=target_name) 177 | return example.Problem(question=question, answer=target_value) 178 | 179 | 180 | def _conversion_fraction(context, is_train): 181 | """E.g., "How many grams are in three quarters of a kg?".""" 182 | dimension = random.choice(DIMENSIONS) 183 | 184 | # 
Limit probability of giving zero answer. 185 | allow_zero = random.random() < 0.2 186 | 187 | # Repeat until we find a pair with an integral answer. (Avoids ambiguity with 188 | # decimals.) 189 | while True: 190 | base_unit, target_unit = random.sample(list(dimension.keys()), 2) 191 | base_value = number.non_integer_rational(2, signed=False) 192 | if train_test_split.is_train(base_value) != is_train: 193 | continue 194 | answer = (base_value * sympy.Rational(dimension[base_unit]) 195 | / sympy.Rational(dimension[target_unit])) 196 | if (abs(answer) <= 100000 197 | and sympy.denom(answer) == 1 198 | and (allow_zero or answer != 0)): 199 | break 200 | 201 | template = random.choice([ 202 | 'How many {target_name} are there in {base_value} of a {base_name}?', 203 | 'What is {base_value} of a {base_name} in {target_name}?', 204 | ]) 205 | 206 | if sympy.denom(base_value) > 20 or random.choice([False, True]): 207 | base_value_string = base_value # Will be represented as e.g., 2/3. 208 | else: 209 | base_value_string = display.StringNumber(base_value) # e.g., two thirds 210 | 211 | question = example.question( 212 | context, template, 213 | base_name=base_unit.name, 214 | base_value=base_value_string, 215 | target_name=pluralize(target_unit.name)) 216 | return example.Problem(question=question, answer=answer) 217 | 218 | 219 | def conversion(is_train, is_extrapolation): 220 | """Conversion question, in decimal or fraction.""" 221 | context = composition.Context() 222 | # TODO(b/124038528): implement extrapolation for fraction conversions too 223 | if is_extrapolation or random.choice([False, True]): 224 | return _conversion_decimal( 225 | context, is_train=is_train, is_extrapolation=is_extrapolation) 226 | else: 227 | return _conversion_fraction(context, is_train=is_train) 228 | 229 | 230 | def time(is_train): 231 | """Questions for calculating start, end, or time differences.""" 232 | context = composition.Context() 233 | start_minutes = random.randint(1, 24*60 - 1) 
234 | while True: 235 | duration_minutes = random.randint(1, 12*60 - 1) 236 | if train_test_split.is_train(duration_minutes) == is_train: 237 | break 238 | end_minutes = start_minutes + duration_minutes 239 | 240 | def format_12hr(minutes): 241 | """Format minutes from midnight in 12 hr format.""" 242 | hours = (minutes // 60) % 24 243 | minutes %= 60 244 | am_pm = 'AM' if hours < 12 else 'PM' 245 | hours = (hours - 1) % 12 + 1 246 | return '{}:{:02} {}'.format(hours, minutes, am_pm) 247 | 248 | start = format_12hr(start_minutes) 249 | end = format_12hr(end_minutes) 250 | 251 | which_question = random.randint(0, 3) 252 | if which_question == 0: 253 | # Question: What is start = end - duration? 254 | template = random.choice([ 255 | 'What is {duration} minutes before {end}?', 256 | ]) 257 | return example.Problem( 258 | question=example.question( 259 | context, template, duration=duration_minutes, end=end), 260 | answer=start) 261 | elif which_question == 1: 262 | # Question: What is end = start + duration? 263 | template = random.choice([ 264 | 'What is {duration} minutes after {start}?', 265 | ]) 266 | return example.Problem( 267 | question=example.question( 268 | context, template, duration=duration_minutes, start=start), 269 | answer=end) 270 | else: 271 | # Question: What is duration = end - start? 272 | template = random.choice([ 273 | 'How many minutes are there between {start} and {end}?', 274 | ]) 275 | return example.Problem( 276 | question=example.question(context, template, start=start, end=end), 277 | answer=duration_minutes) 278 | -------------------------------------------------------------------------------- /mathematics_dataset/modules/modules.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""The various mathematics modules."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from mathematics_dataset.modules import algebra
from mathematics_dataset.modules import arithmetic
from mathematics_dataset.modules import calculus
from mathematics_dataset.modules import comparison
from mathematics_dataset.modules import measurement
from mathematics_dataset.modules import numbers
from mathematics_dataset.modules import polynomials
from mathematics_dataset.modules import probability
import six


# Registry of every question-generating module, keyed by its public name.
all_ = {
    'algebra': algebra,
    'arithmetic': arithmetic,
    'calculus': calculus,
    'comparison': comparison,
    'measurement': measurement,
    'numbers': numbers,
    'polynomials': polynomials,
    'probability': probability,
}


def train(entropy_fn):
  """Collects the training modules of every registered module, by name."""
  collected = {}
  for module_name, module in six.iteritems(all_):
    collected[module_name] = module.train(entropy_fn)
  return collected


def test():
  """Collects the testing modules of every registered module, by name."""
  collected = {}
  for module_name, module in six.iteritems(all_):
    collected[module_name] = module.test()
  return collected


def test_extra():
  """Collects the extrapolation testing modules of every module, by name."""
  collected = {}
  for module_name, module in six.iteritems(all_):
    collected[module_name] = module.test_extra()
  return collected
-------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Polynomial manipulation (adding, composing, finding coefficients, etc).""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import functools 22 | import math 23 | import random 24 | 25 | # Dependency imports 26 | from mathematics_dataset import example 27 | from mathematics_dataset.sample import number 28 | from mathematics_dataset.sample import ops 29 | from mathematics_dataset.sample import polynomials 30 | from mathematics_dataset.util import composition 31 | import numpy as np 32 | from six.moves import range 33 | import sympy 34 | 35 | 36 | _ENTROPY_TRAIN = (3, 10) 37 | _ENTROPY_INTERPOLATE = (8, 8) 38 | 39 | 40 | def _make_modules(entropy): 41 | """Returns modules given "difficulty" parameters.""" 42 | sample_args_pure = composition.PreSampleArgs(1, 1, *entropy) 43 | sample_args_composed = composition.PreSampleArgs(2, 4, *entropy) 44 | sample_args_mixed = composition.PreSampleArgs(1, 4, *entropy) 45 | 46 | return { 47 | 'coefficient_named': 48 | functools.partial(coefficient_named, None, sample_args_pure), 49 | 'evaluate': 50 | functools.partial(evaluate, None, sample_args_pure), 51 | 'evaluate_composed': 52 | functools.partial(evaluate, None, 
sample_args_composed), 53 | # TODO(b/124038948): consider doing pure sample args for 'add'? 54 | 'add': 55 | functools.partial(add, None, sample_args_mixed), 56 | 'expand': 57 | functools.partial(expand, None, sample_args_pure), 58 | 'collect': 59 | functools.partial(collect, None, sample_args_pure), 60 | 'compose': 61 | functools.partial(compose, None, sample_args_mixed), 62 | 63 | # Rearranging powers: 64 | 'simplify_power': 65 | functools.partial(simplify_power, None, sample_args_pure), 66 | } 67 | 68 | 69 | def train(entropy_fn): 70 | """Returns dict of training modules.""" 71 | return _make_modules(entropy_fn(_ENTROPY_TRAIN)) 72 | 73 | 74 | def test(): 75 | """Returns dict of testing modules.""" 76 | return _make_modules(_ENTROPY_INTERPOLATE) 77 | 78 | 79 | def test_extra(): 80 | """Returns dict of extrapolation testing modules.""" 81 | return { 82 | } 83 | 84 | 85 | def coefficient_named(value, sample_args, context=None): 86 | """E.g., "Express x^2 + 2x in the form h * x^2 + k * x + t and give h.".""" 87 | del value # not used 88 | if context is None: 89 | context = composition.Context() 90 | variable = sympy.Symbol(context.pop()) 91 | 92 | entropy, sample_args = sample_args.peel() 93 | degree = random.randint(1, 4) 94 | if random.choice([False, True]): 95 | coefficients = polynomials.sample_coefficients( 96 | degree, entropy/2, min_non_zero=random.randint(degree - 1, degree)) 97 | expanded = polynomials.expand_coefficients(coefficients, entropy/2) 98 | expression = polynomials.coefficients_to_polynomial(expanded, variable) 99 | else: 100 | expression = polynomials.sample_with_brackets(variable, degree, entropy) 101 | coefficients = list(reversed(sympy.Poly(expression).all_coeffs())) 102 | 103 | named_coeffs = [sympy.Symbol(context.pop()) for _ in range(degree + 1)] 104 | canonical = polynomials.coefficients_to_polynomial(named_coeffs, variable) 105 | 106 | if random.random() < 0.2: # only small probability of non-zero power 107 | power = random.randint(0, 
degree) 108 | else: 109 | non_zero_powers = [i for i in range(degree + 1) if coefficients[i] != 0] 110 | power = random.choice(non_zero_powers) 111 | 112 | value = coefficients[power] 113 | named_coeff = named_coeffs[power] 114 | 115 | template = random.choice([ 116 | 'Express {expression} as {canonical} and give {target}.', 117 | 'Rearrange {expression} to {canonical} and give {target}.', 118 | 'Express {expression} in the form {canonical} and give {target}.', 119 | 'Rearrange {expression} to the form {canonical} and give {target}.', 120 | ]) 121 | return example.Problem( 122 | question=example.question( 123 | context, template, expression=expression, canonical=canonical, 124 | target=named_coeff), 125 | answer=value) 126 | 127 | 128 | _TEMPLATES = [ 129 | 'What is {composed}?', 130 | 'Calculate {composed}.', 131 | 'Give {composed}.', 132 | 'Determine {composed}.', 133 | ] 134 | 135 | 136 | @composition.module(number.is_integer) 137 | def evaluate(value, sample_args, context=None): 138 | """Entity for evaluating an integer-valued polynomial at a given point.""" 139 | is_question = context is None 140 | if context is None: 141 | context = composition.Context() 142 | 143 | entropy, sample_args = sample_args.peel() 144 | 145 | if value is None: 146 | entropy_value = random.uniform(1, 1 + entropy/3) 147 | entropy = max(0, entropy - entropy_value) 148 | value = number.integer(entropy_value, signed=True) 149 | 150 | entropy_input = random.uniform(1, 1 + entropy/3) 151 | entropy = max(0, entropy - entropy_input) 152 | input_ = number.integer(entropy_input, signed=True) 153 | 154 | degree = random.randint(1, 3) 155 | 156 | entropies = entropy * np.random.dirichlet(list(range(1, degree + 1))) 157 | # Calculate coefficients in reverse order. 
# TODO(b/124039290): merge with compose? both add and compose do similar things.
@composition.module(composition.is_integer_polynomial)
def add(value, sample_args, context=None):
  """E.g., "Let f(x)=2x+1, g(x)=3x+2. What is 5*f(x) - 7*g(x)?".

  Splits a target polynomial into a linear combination `c1 * f + c2 * g` and
  either asks for that combination or defines an entity equal to it.

  Args:
    value: `composition.Polynomial` to be produced, or `None` to sample one of
      degree 1 to 3 here.
    sample_args: Sampling arguments; one entropy share is peeled off for this
      module and the remainder is handed to `context.sample`.
    context: Optional `composition.Context`. When `None` a new context is
      created and a question is returned; otherwise an entity is returned.

  Returns:
    An `example.Problem` when called without a context, else a
    `composition.Entity` whose handle is a newly popped function symbol.
  """
  is_question = context is None
  if context is None:
    context = composition.Context()

  entropy, sample_args = sample_args.peel()

  if value is None:
    max_degree = 3
    degree = random.randint(1, max_degree)
    # Account for the entropy spent on choosing the degree.
    entropy -= math.log10(max_degree)
    # Half of what is left samples the target; the rest drives the split.
    entropy_value = entropy / 2
    entropy -= entropy_value
    value = polynomials.sample_coefficients(
        degree, entropy=entropy_value, min_non_zero=random.randint(1, 3))
    value = composition.Polynomial(value)

  c1, c2, coeffs1, coeffs2 = polynomials.coefficients_linear_split(
      value.coefficients, entropy)
  coeffs1 = polynomials.trim(coeffs1)
  coeffs2 = polynomials.trim(coeffs2)

  c1, c2, fn1, fn2 = context.sample(
      sample_args,
      [c1, c2, composition.Polynomial(coeffs1), composition.Polynomial(coeffs2)]
  )

  # Use `sympy.Symbol`, not `sympy.var`: `var` also injects the new symbol
  # into the calling module's global namespace, an unwanted side effect.
  var = sympy.Symbol(context.pop())

  expression = (
      c1.handle * fn1.handle.apply(var) + c2.handle * fn2.handle.apply(var))

  if is_question:
    answer = polynomials.coefficients_to_polynomial(value.coefficients, var)
    answer = answer.sympy()
    template = random.choice(_TEMPLATES)
    return example.Problem(
        question=example.question(context, template, composed=expression),
        answer=answer)
  else:
    intermediate_symbol = context.pop()
    intermediate = sympy.Function(intermediate_symbol)(var)
    return composition.Entity(
        context=context,
        value=value,
        description='Let {intermediate} = {composed}.',
        handle=composition.FunctionHandle(intermediate_symbol),
        intermediate=intermediate,
        composed=expression)
sample_args.peel() 253 | 254 | min_order = 1 255 | max_order = 5 256 | order = random.randint(min_order, max_order) 257 | entropy -= math.log10(max_order - min_order + 1) 258 | expression_ = polynomials.sample_with_brackets(variable, order, entropy) 259 | expanded = sympy.expand(expression_) 260 | template = random.choice([ 261 | 'Expand {expression}.' 262 | ]) 263 | return example.Problem( 264 | question=example.question(context, template, expression=expression_), 265 | answer=expanded) 266 | 267 | 268 | @composition.module(composition.is_polynomial) 269 | def collect(value, sample_args, context=None): 270 | """Collect terms in an unsimplified polynomial.""" 271 | is_question = context is None 272 | if context is None: 273 | context = composition.Context() 274 | 275 | entropy, sample_args = sample_args.peel() 276 | if value is None: 277 | entropy_value, entropy = entropy * np.random.dirichlet([2, 3]) 278 | degrees = [random.randint(1, 3)] 279 | value = composition.Polynomial( 280 | polynomials.sample_coefficients(degrees, entropy_value)) 281 | 282 | assert isinstance(value, composition.Polynomial) 283 | coefficients = value.coefficients 284 | 285 | all_coefficients_are_integer = True 286 | for coeff in coefficients.flat: 287 | if not number.is_integer(coeff): 288 | all_coefficients_are_integer = False 289 | break 290 | 291 | if all_coefficients_are_integer: 292 | coefficients = polynomials.expand_coefficients(coefficients, entropy) 293 | else: 294 | # put back the unused entropy 295 | sample_args = composition.SampleArgs( 296 | sample_args.num_modules, sample_args.entropy + entropy) 297 | 298 | num_variables = coefficients.ndim 299 | variables = [sympy.Symbol(context.pop()) for _ in range(num_variables)] 300 | unsimplified = polynomials.coefficients_to_polynomial(coefficients, variables) 301 | simplified = unsimplified.sympy().expand() 302 | 303 | # Bit of a hack: handle the very rare case where no number constants appearing 304 | if not 
def compose(value, sample_args, context=None):
  """E.g., "Let f(x)=2x+1, let g(x)=3x+10. What is f(g(x))?".

  Args:
    value: Ignored; both polynomials are sampled inside this function.
    sample_args: Sampling arguments; one entropy share is peeled off and split
      at random between `f` and `g`, the remainder goes to `context.sample`.
    context: Optional `composition.Context`; a new one is created if `None`.

  Returns:
    An `example.Problem` whose answer is the expanded composition `f(g(x))`.
  """
  del value  # unused
  if context is None:
    context = composition.Context()

  entropy, sample_args = sample_args.peel()
  entropy_f, entropy_g = entropy * np.random.dirichlet([1, 1])

  # Each of f and g is univariate with degree 1 or 2.
  coeffs_f = polynomials.sample_coefficients([random.randint(1, 2)], entropy_f)
  coeffs_g = polynomials.sample_coefficients([random.randint(1, 2)], entropy_g)

  entity_f, entity_g = context.sample(
      sample_args,
      [composition.Polynomial(coeffs_f), composition.Polynomial(coeffs_g)])

  # Use `sympy.Symbol`, not `sympy.var`: `var` also injects the new symbol
  # into the calling module's global namespace, an unwanted side effect.
  variable = sympy.Symbol(context.pop())

  poly_f = polynomials.coefficients_to_polynomial(coeffs_f, variable)
  poly_g = polynomials.coefficients_to_polynomial(coeffs_g, variable)

  # Substitute g for the variable of f and expand to get the answer.
  poly_f_g = poly_f.sympy().subs(variable, poly_g.sympy()).expand()

  expression = composition.FunctionHandle(entity_f, entity_g).apply(variable)

  template = random.choice(_TEMPLATES)
  return example.Problem(
      question=example.question(context, template, composed=expression),
      answer=poly_f_g)
def _make_modules(is_train):
  """Returns modules, with split based on the boolean `is_train`.

  Args:
    is_train: Forwarded to every sampler as its `is_train` keyword argument.

  Returns:
    Dict mapping module name to a sampler bound to `is_train` and to the
    interpolation sample-count range `_SWR_SAMPLE_COUNT`.
  """
  samplers = (
      ('swr_p_sequence', swr_prob_sequence),
      ('swr_p_level_set', swr_prob_level_set),
  )
  return {
      name: functools.partial(
          sampler, is_train=is_train, sample_range=_SWR_SAMPLE_COUNT)
      for name, sampler in samplers
  }
sample_range=_SWR_SAMPLE_COUNT_EXTRAPOLATE), 83 | 'swr_p_level_set_more_samples': functools.partial( 84 | swr_prob_level_set, is_train=None, 85 | sample_range=_SWR_SAMPLE_COUNT_EXTRAPOLATE), 86 | } 87 | 88 | 89 | def _sequence_event(values, length, verb): 90 | """Returns sequence (finite product) event. 91 | 92 | Args: 93 | values: List of values to sample from. 94 | length: Length of the sequence to generate. 95 | verb: Verb in infinitive form. 96 | 97 | Returns: 98 | Instance of `probability.FiniteProductEvent`, together with a text 99 | description. 100 | """ 101 | del verb # unused 102 | samples = [random.choice(values) for _ in range(length)] 103 | events = [probability.DiscreteEvent([sample]) for sample in samples] 104 | event = probability.FiniteProductEvent(events) 105 | sequence = ''.join(str(sample) for sample in samples) 106 | event_description = 'sequence {sequence}'.format(sequence=sequence) 107 | return event, event_description 108 | 109 | 110 | def _word_series(words, conjunction='and'): 111 | """Combines the words using commas and the final conjunction.""" 112 | len_words = len(words) 113 | if len_words == 0: 114 | return '' 115 | if len_words == 1: 116 | return words[0] 117 | return '{} {} {}'.format(', '.join(words[:-1]), conjunction, words[-1]) 118 | 119 | 120 | def _level_set_event(values, length, verb): 121 | """Generates `LevelSetEvent`; see _generate_sequence_event.""" 122 | counts = combinatorics.uniform_non_negative_integers_with_sum( 123 | len(values), length) 124 | counts_dict = dict(list(zip(values, counts))) 125 | event = probability.CountLevelSetEvent(counts_dict) 126 | 127 | shuffled_values = list(values) 128 | random.shuffle(shuffled_values) 129 | 130 | counts_and_values = [ 131 | '{} {}'.format(counts_dict[value], value) 132 | for value in shuffled_values 133 | if counts_dict[value] > 0 134 | ] 135 | counts_and_values = _word_series(counts_and_values) 136 | template = random.choice([ 137 | '{verbing} {counts_and_values}', 138 | ]) 
def _sample_letter_bag(is_train, min_total):
  """Samples a "container of letters" and returns info on it.

  Args:
    is_train: `True` or `False` to keep only letter counts on the matching
      side of `train_test_split.is_train`, or `None` to accept any counts.
    min_total: Lower bound on the total number of letters in the bag.

  Returns:
    A `LetterBag` holding unit weights for each position in the bag, the
    random variable mapping position to letter, the distinct letters, and a
    textual description of the bag contents.
  """
  # Rejection-sample the letter counts until they fall in the requested split.
  while True:
    num_distinct = random.randint(1, _MAX_DISTINCT_LETTERS)
    fewest = max(num_distinct, min_total)
    most = min(_MAX_TOTAL_LETTERS, num_distinct * _MAX_LETTER_REPEAT)
    num_total = random.randint(fewest, most)
    counts = combinatorics.uniform_positive_integers_with_sum(
        num_distinct, num_total)
    if is_train is None:
      break
    # The split is decided on the sorted counts, so it does not depend on
    # which letter received which count.
    if train_test_split.is_train(sorted(counts)) == is_train:
      break

  letters = random.sample(_LETTERS, num_distinct)
  weights = dict.fromkeys(range(num_total), 1)

  # One entry per physical letter in the bag, in shuffled order.
  bag = [
      letter
      for letter, count in zip(letters, counts)
      for _ in range(count)
  ]
  random.shuffle(bag)

  random_variable = probability.DiscreteRandomVariable(dict(enumerate(bag)))

  # Describe the bag either by spelling it out or as "{letter: count, ...}".
  spell_out = random.choice([False, True])
  if spell_out:
    bag_contents = ''.join(bag)
  else:
    pairs = ', '.join(
        '{}: {}'.format(letter, count)
        for letter, count in zip(letters, counts))
    bag_contents = '{' + pairs + '}'

  return LetterBag(
      weights=weights,
      random_variable=random_variable,
      letters_distinct=letters,
      bag_contents=bag_contents)
def _sample_without_replacement_probability_question(
    is_train, event_fn, sample_range):
  """Question for prob of some event when sampling without replacement.

  Args:
    is_train: Train/test selector forwarded to `_swr_space`.
    event_fn: Callable taking `values`, `length` and `verb` keyword arguments
      and returning an `(event, description)` pair.
    sample_range: Range for the number of samples, forwarded to `_swr_space`.

  Returns:
    An `example.Problem` asking for the probability of the sampled event.
  """
  max_event_size = int(2e5)

  def too_big(candidate):
    """Returns whether the event is too large to evaluate."""
    if isinstance(candidate, probability.SequenceEvent):
      return len(candidate.all_sequences()) > max_event_size
    assert isinstance(candidate, probability.FiniteProductEvent)
    factor_sizes = [len(factor.values) for factor in candidate.events]
    return np.prod(factor_sizes) > max_event_size

  # Only a small fraction of questions may have probability exactly 0 or 1.
  allow_trivial_prob = random.random() < _MAX_FRAC_TRIVIAL_PROB

  while True:
    distinct_letters, space, random_variable = _swr_space(
        is_train, sample_range)
    event, event_description = event_fn(
        values=distinct_letters, length=space.n_samples, verb='pick')
    event_in_space = random_variable.inverse(event)
    if too_big(event_in_space):
      continue
    answer = space.probability(event_in_space)
    is_trivial = answer in [0, 1]
    if not is_trivial or allow_trivial_prob:
      break

  context = composition.Context()

  templates = [
      '{random_variable_capitalize}. What is prob of {event}?',
      '{random_variable_capitalize}. Give prob of {event}.',
      'What is prob of {event} when {random_variable}?',
      'Calculate prob of {event} when {random_variable}.',
  ]
  template = random.choice(templates)
  description = random_variable.description
  question = example.question(
      context,
      template,
      random_variable=description,
      random_variable_capitalize=str(description).capitalize(),
      event=event_description)
  return example.Problem(question, answer)
def is_train(value):
  """Returns whether `value` should be used in a training question.

  The decision is the parity of the MD5 digest of `str(value)` encoded as
  UTF-8, so it is deterministic for a given string form.

  Args:
    value: Any object; only its `str()` form is hashed.

  Returns:
    `True` if the digest, read as an integer, is even.
  """
  digest_hex = hashlib.md5(str(value).encode('utf-8')).hexdigest()
  # The parity of the whole digest is the parity of its last hex digit.
  lowest_digit = int(digest_hex[-1], 16)
  return lowest_digit % 2 == 0
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Sample arithmetic expressions with a given value.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import math 23 | import random 24 | 25 | # Dependency imports 26 | from mathematics_dataset.sample import number 27 | from mathematics_dataset.sample import ops 28 | from mathematics_dataset.util import combinatorics 29 | import numpy as np 30 | import six 31 | from six.moves import zip 32 | import sympy 33 | 34 | 35 | class _SampleArgs(collections.namedtuple('SampleArgs', ('count', 'entropy'))): 36 | """For sampling mathematical expressions.""" 37 | 38 | def peel(self, frac=1): 39 | """Peels one (or `frac`) of an op's entropy.""" 40 | entropy = frac * self.entropy / self.count 41 | new_sample_args = _SampleArgs(self.count, self.entropy - entropy) 42 | return entropy, new_sample_args 43 | 44 | def split(self, args): 45 | """Splits the entropy and op counts up.""" 46 | non_integer_count = sum(not arg.is_Integer for arg in args) 47 | assert non_integer_count <= self.count - 1 48 | count_split = combinatorics.uniform_non_negative_integers_with_sum( 49 | len(args), (self.count - 1) - non_integer_count) 50 | for i, arg in enumerate(args): 51 | if not arg.is_Integer: 52 | count_split[i] += 1 53 | if all(count == 0 for count in count_split): 54 | assert self.entropy == 0 55 | entropies = np.zeros(len(count_split)) 56 | else: 57 | entropies = ( 58 | np.random.dirichlet(np.maximum(1e-9, count_split)) * 
def _split_factors(integer):
  """Randomly factors integer into product of two integers.

  Args:
    integer: A sympy integer (must have `is_Integer` set).

  Returns:
    Tuple `(left, right)` of `sympy.Integer` with `left * right == integer`.
    For zero the split is always `(1, 0)`.
  """
  assert integer.is_Integer
  if integer == 0:
    # Return sympy integers in a tuple, like every other path does. `_mul_op`
    # hands the result to `_SampleArgs.split`, which reads `arg.is_Integer`,
    # so plain Python ints would raise `AttributeError` there.
    return sympy.Integer(1), sympy.Integer(0)
  # Gives dict of form {factor: multiplicity}
  factors = sympy.factorint(integer)
  left = sympy.Integer(1)
  right = sympy.Integer(1)
  for factor, mult in six.iteritems(factors):
    # Send a random share of each prime power to the left factor and the
    # remainder to the right factor.
    left_mult = random.randint(0, mult)
    right_mult = mult - left_mult
    left *= factor ** left_mult
    right *= factor ** right_mult
  return left, right
def length_range_for_entropy(entropy):
  """Returns length range to sample from for given entropy.

  Args:
    entropy: Amount of entropy available for the expression.

  Returns:
    Tuple `(min_length, max_length)`. The minimum is always 3 and the maximum
    grows by one for every two whole units of entropy.
  """
  shortest = 3
  extra_length = int(entropy / 2)
  return shortest, shortest + extra_length
class ArithmeticTest(parameterized.TestCase):
  """Randomised checks of `arithmetic.arithmetic`."""

  def testArithmetic(self):
    """Tests that generated expressions evaluate to the requested target."""
    entropy = 8.0
    for _ in range(1000):
      target = number.integer_or_rational(4, signed=True)
      expression = arithmetic.arithmetic(target, entropy)
      evaluated = sympy.sympify(expression)
      self.assertEqual(evaluated, target)

  def testArithmeticLength(self):
    """Tests that the generated arithmetic expressions have given length."""
    entropy = 8.0
    for _ in range(1000):
      target = number.integer_or_rational(4, signed=True)
      length = random.randint(2, 10)
      expression = arithmetic.arithmetic(target, entropy, length)
      # Note: actual length is #ops = #numbers - 1.
      num_constants = len(ops.number_constants(expression))
      self.assertEqual(num_constants - 1, length)
14 | 15 | """Generate linear systems with given set of solutions.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import random 22 | 23 | # Dependency imports 24 | from mathematics_dataset.sample import number 25 | from mathematics_dataset.sample import ops 26 | from mathematics_dataset.sample import polynomials 27 | import numpy as np 28 | from six.moves import range 29 | import sympy 30 | 31 | 32 | def _make_equals_zero_split(monomials): 33 | """Returns an `ops.Eq` containing sum of monomials split on left and right.""" 34 | left = [] 35 | right = [] 36 | for monomial in monomials: 37 | if random.choice([False, True]): 38 | left.append(monomial) 39 | else: 40 | right.append(ops.Neg(monomial)) 41 | if not left: 42 | left = [0] 43 | if not right: 44 | right = [0] 45 | left = ops.Add(*left) 46 | right = ops.Add(*right) 47 | return ops.Eq(left, right) 48 | 49 | 50 | def _is_trivial_in(matrix, variable): 51 | """Returns true if matrix_ij == 0 for some i and all j != variable.""" 52 | matrix = np.asarray(matrix) 53 | assert matrix.ndim == 2 and matrix.shape[0] == matrix.shape[1] 54 | size = matrix.shape[0] 55 | if size == 1: 56 | return False 57 | for i in range(size): 58 | all_zero = True 59 | for j in range(size): 60 | if j != variable and matrix[i, j] != 0: 61 | all_zero = False 62 | break 63 | if all_zero: 64 | return True 65 | return False 66 | 67 | 68 | def _invertible_matrix(degree, entropy, non_trivial_in): 69 | """Generates random invertible matrix.""" 70 | matrix_entropies = entropy * np.random.dirichlet(np.ones(degree * degree)) 71 | matrix_entropies = np.reshape(matrix_entropies, [degree, degree]) 72 | matrix_entropies = np.maximum(1, matrix_entropies) 73 | 74 | while True: 75 | def gen(i, j): 76 | return number.integer(matrix_entropies[i, j], True) 77 | 78 | matrix = [[gen(i, j) for i in range(degree)] for j in range(degree)] # pylint: 
disable=g-complex-comprehension 79 | if non_trivial_in is not None and _is_trivial_in(matrix, non_trivial_in): 80 | continue 81 | if sympy.det(sympy.Matrix(matrix)) != 0: 82 | break 83 | 84 | matrix = np.asarray(matrix).astype(int) 85 | return matrix 86 | 87 | 88 | def linear_system(variables, solutions, entropy, non_trivial_in=None, 89 | length=None): 90 | """Returns a linear system (set of equalities) with the given solutions. 91 | 92 | Args: 93 | variables: List of variables. 94 | solutions: List of solutions, of the same length as `variables`. 95 | entropy: Float >= 0; the entropy used. 96 | non_trivial_in: Optional integer corresponding to a variable for which the 97 | solution shouldn't be "trivial". E.g., "solve a + b = 3, a = -2 for a" 98 | is disallowed if `variables[non_trivial_in] == 'a'`. 99 | length: Total number of terms appearing; if `None` then selected wisely. 100 | 101 | Returns: 102 | List of `ops.Eq`. 103 | """ 104 | degree = len(variables) 105 | assert degree == len(solutions) 106 | 107 | frac_entropy_matrix = random.uniform(1/3, 2/3) 108 | matrix = _invertible_matrix( 109 | degree, entropy * frac_entropy_matrix, non_trivial_in) 110 | solutions = np.asarray(solutions) 111 | constant = np.matmul(matrix, solutions.astype(int)) 112 | flattened = np.concatenate([np.reshape(matrix, [degree * degree]), constant]) 113 | is_zero = flattened == 0 114 | 115 | if length is None: 116 | min_length = np.count_nonzero(flattened) + 1 117 | max_length = max(min_length, 1 + int(degree * (1 + entropy / 2))) 118 | length = random.randint(min_length, max_length) 119 | 120 | counts = polynomials.expanded_coefficient_counts( 121 | length=length, is_zero=is_zero) 122 | 123 | entropies = (1 - frac_entropy_matrix) * entropy * np.random.dirichlet( 124 | np.maximum(1e-9, counts - 1)) 125 | 126 | terms = [] 127 | for i in range(len(flattened)): 128 | coeffs = polynomials.integers_with_sum( 129 | value=flattened[i], count=counts[i], entropy=entropies[i]) 130 | 
terms.append(coeffs) 131 | 132 | matrix = terms[:degree*degree] 133 | constant = terms[-degree:] 134 | equations = [] 135 | for row_index in range(degree): 136 | monomials = [] 137 | for col_index in range(degree): 138 | for term in matrix[row_index * degree + col_index]: 139 | monomials.append(polynomials.monomial(term, variables[col_index], 1)) 140 | for term in constant[row_index]: 141 | monomials.append(polynomials.monomial(-term, None, 0)) 142 | equations.append(_make_equals_zero_split(monomials)) 143 | 144 | return equations 145 | -------------------------------------------------------------------------------- /mathematics_dataset/sample/linear_system_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Tests for mathematics_dataset.sample.linear_system.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import random 22 | 23 | # Dependency imports 24 | from absl.testing import absltest 25 | from absl.testing import parameterized 26 | from mathematics_dataset.sample import linear_system 27 | from six.moves import range 28 | import sympy 29 | 30 | 31 | class ExpressionWithValueTest(parameterized.TestCase): 32 | 33 | def testIsTrivialIn(self): 34 | self.assertEqual(linear_system._is_trivial_in([[1]], 0), False) 35 | self.assertEqual(linear_system._is_trivial_in([[1, 2], [3, 4]], 0), False) 36 | self.assertEqual(linear_system._is_trivial_in([[1, 2], [3, 0]], 0), True) 37 | self.assertEqual(linear_system._is_trivial_in([[1, 2], [3, 0]], 1), False) 38 | self.assertEqual(linear_system._is_trivial_in([[1, 2], [0, 3]], 0), False) 39 | self.assertEqual(linear_system._is_trivial_in([[1, 2], [0, 3]], 1), True) 40 | 41 | @parameterized.parameters([1, 2, 3]) 42 | def testLinearSystem(self, degree): 43 | for _ in range(100): # test a few times 44 | target = [random.randint(-100, 100) for _ in range(degree)] 45 | variables = [sympy.Symbol(chr(ord('a') + i)) for i in range(degree)] 46 | system = linear_system.linear_system( 47 | variables=variables, 48 | solutions=target, 49 | entropy=10.0) 50 | solved = sympy.solve(system, variables) 51 | solved = [solved[symbol] for symbol in variables] 52 | self.assertEqual(target, solved) 53 | 54 | 55 | if __name__ == '__main__': 56 | absltest.main() 57 | -------------------------------------------------------------------------------- /mathematics_dataset/sample/number.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Generate random integers and rationals with minimum guarantees on entropy.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import math 22 | import random 23 | 24 | # Dependency imports 25 | from mathematics_dataset.util import display 26 | import numpy as np 27 | import six 28 | import sympy 29 | 30 | 31 | def _coprime_density(value): 32 | """Returns float > 0; asymptotic density of integers coprime to `value`.""" 33 | factors = sympy.factorint(value) 34 | density = 1.0 35 | for prime in six.iterkeys(factors): 36 | density *= 1 - 1 / prime 37 | return density 38 | 39 | 40 | def integer(entropy, signed, min_abs=0, coprime_to=1): 41 | """Returns an integer from a set of size ceil(10**entropy). 42 | 43 | If `signed` is True, then includes negative integers, otherwise includes just 44 | positive integers. 45 | 46 | Args: 47 | entropy: Float >= 0. 48 | signed: Boolean. Whether to also return negative numbers. 49 | min_abs: Integer >= 0. The minimum absolute value. 50 | coprime_to: Optional integer >= 1. The returned integer is guaranteed to be 51 | coprime to `coprime_to`, with entropy still accounted for. 52 | 53 | Returns: 54 | Integer. 
55 | """ 56 | assert isinstance(min_abs, int) and not isinstance(min_abs, bool) 57 | coprime_to = abs(coprime_to) 58 | assert min_abs >= 0 59 | 60 | max_ = math.pow(10, entropy) 61 | max_ += min_abs 62 | if coprime_to >= 2: 63 | max_ = max_ / _coprime_density(coprime_to) + 1 64 | 65 | if signed: 66 | max_ = int(math.ceil(max_ / 2)) 67 | range_ = [-max_, max_] 68 | else: 69 | max_ = int(math.ceil(max_)) 70 | range_ = [min_abs, max_] 71 | 72 | while True: 73 | value = random.randint(*range_) 74 | if abs(value) >= min_abs and sympy.gcd(value, coprime_to) == 1: 75 | break 76 | 77 | return sympy.Integer(value) 78 | 79 | 80 | def non_integer_rational(entropy, signed): 81 | """Similar args to `integer`. Entropy split between denom and numer.""" 82 | numer_entropy = random.uniform(0, entropy) 83 | denom_entropy = entropy - numer_entropy 84 | numer = integer(numer_entropy, signed, min_abs=1) 85 | denom = integer(denom_entropy, False, min_abs=2, coprime_to=numer) 86 | return sympy.Rational(numer, denom) 87 | 88 | 89 | def integer_or_rational(entropy, signed, min_abs=0): 90 | """Returns a rational, with 50% probability of it being an integer.""" 91 | if random.choice([False, True]): 92 | return integer(entropy, signed, min_abs=min_abs) 93 | else: 94 | return non_integer_rational(entropy, signed) 95 | 96 | 97 | def non_integer_decimal(entropy, signed): 98 | """Returns a random decimal; integer divided by random power of ten. 99 | 100 | Guaranteed to be non-integer (i.e., numbers after the decimal point). 101 | 102 | Args: 103 | entropy: Float. 104 | signed: Boolean. Whether to also return negative numbers. 105 | 106 | Returns: 107 | Non-integer decimal. 
108 | """ 109 | while True: 110 | base = integer(entropy, signed) 111 | shift = random.randint(1, int(math.ceil(entropy))) 112 | divisor = 10**shift 113 | if base % divisor != 0: 114 | return display.Decimal(sympy.Rational(base, divisor)) 115 | 116 | 117 | def integer_or_decimal(entropy, signed): 118 | """Returns integer or non-integer decimal; 50% probability of each.""" 119 | if random.choice([False, True]): 120 | # Represent it as a decimal so that arithmetic operations are supported: 121 | return display.Decimal(integer(entropy, signed)) 122 | else: 123 | return non_integer_decimal(entropy, signed) 124 | 125 | 126 | def entropy_of_value(value): 127 | """Returns "min entropy" that would give probability of getting this value.""" 128 | if isinstance(value, display.Decimal): 129 | return entropy_of_value(sympy.numer(value)) 130 | 131 | if is_non_integer_rational(value): 132 | numer = sympy.numer(value) 133 | denom = sympy.denom(value) 134 | return entropy_of_value(numer) + entropy_of_value(denom) 135 | elif not is_integer(value): 136 | raise ValueError('Unhandled value: {}'.format(value)) 137 | 138 | # Note: we sample integers in a range of size approx 10**entropy about zero, 139 | # so assume that `abs(value)` is about half of the upper range. 
140 | return math.log10(5 * abs(value) + 1) 141 | 142 | 143 | def is_integer(value): 144 | return isinstance(value, (int, np.int64, np.int32, sympy.Integer)) 145 | 146 | 147 | def is_positive_integer(value): 148 | """Filter for: value is a strictly positive integer.""" 149 | return is_integer(value) and value > 0 150 | 151 | 152 | def is_integer_or_rational(value): 153 | return is_integer(value) or isinstance(value, sympy.Rational) 154 | 155 | 156 | def is_integer_or_decimal(value): 157 | return is_integer(value) or isinstance(value, display.Decimal) 158 | 159 | 160 | def is_integer_or_rational_or_decimal(value): 161 | return is_integer_or_rational(value) or is_integer_or_decimal(value) 162 | 163 | 164 | def is_non_integer_rational(value): 165 | return is_integer_or_rational(value) and not is_integer(value) 166 | -------------------------------------------------------------------------------- /mathematics_dataset/sample/number_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Tests for mathematics_dataset.sample.number.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import random 22 | 23 | # Dependency imports 24 | from absl.testing import absltest 25 | from absl.testing import parameterized 26 | from mathematics_dataset.sample import number 27 | from six.moves import range 28 | import sympy 29 | 30 | 31 | class NumberTest(parameterized.TestCase): 32 | 33 | def testCoprimeDensity(self): 34 | self.assertEqual(number._coprime_density(1), 1.0) 35 | self.assertEqual(number._coprime_density(2), 0.5) 36 | self.assertLess(abs(number._coprime_density(3) - 2/3), 1e-6) 37 | self.assertLess(abs(number._coprime_density(6) - 1/3), 1e-6) 38 | 39 | @parameterized.parameters(False, True) 40 | def testInteger_allowZero(self, signed): 41 | saw_zero = False 42 | saw_nonzero = False 43 | for _ in range(1000): 44 | sample = number.integer(1, signed=signed) 45 | if sample == 0: 46 | saw_zero = True 47 | else: 48 | saw_nonzero = True 49 | if saw_zero and saw_nonzero: 50 | break 51 | self.assertTrue(saw_zero) 52 | self.assertTrue(saw_nonzero) 53 | 54 | def testNonIntegerRational(self): 55 | for _ in range(1000): 56 | entropy = random.uniform(0, 10) 57 | signed = random.choice([False, True]) 58 | sample = number.non_integer_rational(entropy, signed) 59 | self.assertNotEqual(sympy.denom(sample), 1) 60 | 61 | @parameterized.parameters(False, True) 62 | def testIntegerOrRational(self, signed): 63 | # Tests we can call it. Do it a few times so both code paths get executed. 
64 | for _ in range(10): 65 | number.integer_or_rational(2, signed) 66 | 67 | def testNonIntegerDecimal(self): 68 | for _ in range(1000): 69 | sample = number.non_integer_decimal(1, False) 70 | self.assertNotEqual(sympy.denom(sample), 1) 71 | self.assertLen(str(sample), 3) # should be of form "0.n" 72 | self.assertGreater(sample, 0) # positive 73 | 74 | def testNonIntegerDecimal_size(self): 75 | saw_bigger_one = False 76 | saw_smaller_one = False 77 | for _ in range(1000): 78 | sample = number.non_integer_decimal(2, False) 79 | if sample > 1: 80 | saw_bigger_one = True 81 | else: 82 | saw_smaller_one = True 83 | if saw_bigger_one and saw_smaller_one: 84 | break 85 | self.assertTrue(saw_bigger_one) 86 | self.assertTrue(saw_smaller_one) 87 | 88 | @parameterized.parameters( 89 | lambda: number.integer(0, True), 90 | lambda: number.integer(1, True), 91 | lambda: number.non_integer_rational(2, True), 92 | lambda: number.non_integer_decimal(1, True)) 93 | def testGenerate_signed(self, generator): 94 | saw_positive = False 95 | saw_negative = False 96 | for _ in range(1000): 97 | sample = generator() 98 | saw_positive |= sample > 0 99 | saw_negative |= sample < 0 100 | if saw_positive and saw_negative: 101 | break 102 | 103 | self.assertTrue(saw_positive) 104 | self.assertTrue(saw_negative) 105 | 106 | @parameterized.parameters( 107 | lambda: number.integer(2, False), 108 | lambda: number.non_integer_rational(2, False)) 109 | def testIntegerRational_distinctCount(self, generator): 110 | seen = set() 111 | for _ in range(3000): 112 | seen.add(generator()) 113 | self.assertGreaterEqual(len(seen), 10 ** 2) 114 | 115 | @parameterized.parameters(number.integer, number.non_integer_decimal) 116 | def testEntropyOfValue(self, generator): 117 | for entropy in [1, 2, 4, 8, 16]: 118 | sum_entropy = 0.0 119 | count = 2000 120 | for _ in range(count): 121 | value = generator(entropy, signed=True) 122 | sum_entropy += number.entropy_of_value(value) 123 | avg_entropy = sum_entropy / 
count 124 | error = abs(entropy - avg_entropy) / entropy 125 | self.assertLess(error, 0.2) 126 | 127 | 128 | if __name__ == '__main__': 129 | absltest.main() 130 | -------------------------------------------------------------------------------- /mathematics_dataset/sample/ops.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Mathematical operations used to build up expressions for printing. 16 | 17 | We can't use sympy because sympy will automatically simplify many types of 18 | expressions, even with `evaluate=False` passed in. For example: 19 | 20 | * Mul(-2, -3, evaluate=False) gives -(-6), not (-2) x (-3). 21 | * Add(2, 1, evaluate=False) gives 1 + 2, because the terms are sorted. 22 | 23 | As such, it's easier just to work with our own op classes that display precisely 24 | as we created them. This also allows us to use custom symbols for the 25 | expressions, such as the multiplication symbol. 
26 | """ 27 | 28 | from __future__ import absolute_import 29 | from __future__ import division 30 | from __future__ import print_function 31 | 32 | import abc 33 | 34 | # Dependency imports 35 | from absl import logging 36 | from mathematics_dataset.sample import number 37 | from mathematics_dataset.util import display 38 | import numpy as np 39 | import six 40 | from six.moves import zip 41 | import sympy 42 | 43 | 44 | MUL_SYMBOL = '*' 45 | DIV_SYMBOL = '/' 46 | POW_SYMBOL = '**' 47 | GT_SYMBOL = '>' 48 | LT_SYMBOL = '<' 49 | GE_SYMBOL = '>=' 50 | LE_SYMBOL = '<=' 51 | EQ_SYMBOL = '=' 52 | NE_SYMBOL = '!=' 53 | 54 | 55 | # Operator precedence levels. Used to insert brackets if necessary. 56 | _EQ_PRECEDENCE = 0 57 | _CONSTANT_PRECEDENCE = 1 58 | _POW_PRECEDENCE = 2 59 | _SQRT_PRECEDENCE = 3 60 | _MUL_PRECEDENCE = 4 61 | _ADD_PRECEDENCE = 5 62 | 63 | 64 | def bracketed(child, parent, bracket_if_same_precedence): 65 | """Returns string representation of `child`, possibly bracketed. 66 | 67 | Args: 68 | child: Instance of `Op` or a valid value for `ConstantOp`. 69 | parent: Instance of `Op`. Used to determine whether `child` needs to be 70 | bracketed first before appearing in the parent op's expression. 71 | bracket_if_same_precedence: Whether to bracket if the child has the same 72 | operator precedence as the parent. 73 | 74 | Returns: 75 | String representation of `child`. 
76 | """ 77 | if not isinstance(child, Op): 78 | child = Constant(child) 79 | 80 | child_precedence = child.precedence 81 | parent_precedence = parent.precedence 82 | if (parent_precedence > child_precedence 83 | or (parent_precedence == child_precedence 84 | and not bracket_if_same_precedence)): 85 | return str(child) 86 | else: 87 | return '({})'.format(child) 88 | 89 | 90 | def _flatten(iterable): 91 | """Returns list.""" 92 | if isinstance(iterable, (list, tuple)): 93 | result = list(iterable) 94 | else: 95 | assert isinstance(iterable, dict) 96 | keys = sorted(six.iterkeys(iterable)) 97 | result = [iterable[key] for key in keys] 98 | # Check we don't have any hierarchy in the structure (otherwise would need 99 | # to use something recursive like tf.contrib.framework.nest.flatten). 100 | for item in result: 101 | assert not isinstance(item, (list, tuple, dict)) 102 | return result 103 | 104 | 105 | def _pack_sequence_as(example, flat): 106 | if isinstance(example, list) or isinstance(example, tuple): 107 | return flat 108 | else: 109 | assert isinstance(example, dict) 110 | keys = sorted(six.iterkeys(example)) 111 | return {key: value for key, value in zip(keys, flat)} 112 | 113 | 114 | @six.add_metaclass(abc.ABCMeta) 115 | class Op(object): 116 | """An operation. 117 | 118 | This needs to support being transformed into sympy (and possibly in the future 119 | other types such as an appropriately formatted string), when given the op 120 | arguments. 121 | """ 122 | 123 | def __init__(self, children): 124 | """Initialize this `Op` base class. 125 | 126 | Args: 127 | children: Iterable structure containing child ops. 
128 | """ 129 | assert isinstance(children, (list, dict, tuple)) 130 | flat_children = _flatten(children) 131 | flat_children = [child if isinstance(child, Op) else Constant(child) 132 | for child in flat_children] 133 | children = _pack_sequence_as(children, flat_children) 134 | self._children = children 135 | 136 | @property 137 | def children(self): 138 | """Returns iterable or dict over immediate children.""" 139 | return self._children 140 | 141 | def descendants(self): 142 | """Returns list of all descendants (self, children, grandchildren, etc).""" 143 | descendants = [self] 144 | flat_children = _flatten(self._children) 145 | for child in flat_children: 146 | descendants += child.descendants() 147 | return descendants 148 | 149 | @abc.abstractmethod 150 | def __str__(self): 151 | """Returns a string format of this op.""" 152 | 153 | @abc.abstractmethod 154 | def sympy(self): 155 | """Returns the sympifcation of this op.""" 156 | 157 | def _sympy_(self): 158 | """Convenience method to automatically sympify this object.""" 159 | try: 160 | return self.sympy() 161 | except AttributeError as e: 162 | # Note: we print this error here, before raising it again, because sympy 163 | # will think `AttributeError` refers to this object not having a `_sympy_` 164 | # method, rather than having it, which leads to otherwise confusing error 165 | # messages. 
166 | logging.error( 167 | 'Encountered attribute error while trying to sympify: %s', e) 168 | raise e 169 | 170 | @abc.abstractproperty 171 | def precedence(self): 172 | """Returns the precedence (integer) of this op.""" 173 | 174 | 175 | class Constant(Op): 176 | """Returns a constant value; a nullary op.""" 177 | 178 | def __init__(self, value): 179 | super(Constant, self).__init__([]) 180 | if isinstance(value, six.integer_types): 181 | value = sympy.Integer(value) 182 | self._value = value 183 | 184 | def __str__(self): 185 | return str(self._value) 186 | 187 | def sympy(self): 188 | return self._value 189 | 190 | @property 191 | def value(self): 192 | return self._value 193 | 194 | @value.setter 195 | def value(self, value): 196 | self._value = value 197 | 198 | def _is_simple(self): 199 | """Returns whether it's a simple number, rather than a division or neg.""" 200 | if isinstance(self._value, sympy.Symbol): 201 | return True 202 | elif (isinstance(self._value, int) 203 | or isinstance(self._value, sympy.Integer) 204 | or isinstance(self._value, display.Decimal) 205 | or isinstance(self._value, np.int64) 206 | or isinstance(self._value, np.int32)): 207 | return self._value >= 0 208 | elif isinstance(self._value, sympy.Rational): 209 | return False 210 | elif isinstance(self._value, sympy.Function): 211 | return True 212 | else: 213 | raise ValueError('Unknown type {}'.format(type(self._value))) 214 | 215 | @property 216 | def precedence(self): 217 | if self._is_simple(): 218 | return _CONSTANT_PRECEDENCE 219 | else: 220 | return _MUL_PRECEDENCE 221 | 222 | 223 | class _SumLikeOp(Op): 224 | """Abstract op for sum-like terms which may contain negative entries.""" 225 | 226 | @abc.abstractmethod 227 | def expanded_signs_and_terms(self): 228 | """Returns a list of arguments, plus any sub-arguments from sub-adds. 
229 | 230 | E.g., if this op is `Add(Add(2, Neg(3)), Mul(4, 5), 1)`, then will return 231 | `[(True, 2), (False, 3), (True, Mul(4, 5)), (True, 1)]` (the arguments of 232 | the inner add have been extracted). 233 | """ 234 | 235 | def __str__(self): 236 | signs_and_terms = self.expanded_signs_and_terms() 237 | if not signs_and_terms: 238 | return '0' 239 | for i, (sign, term) in enumerate(signs_and_terms): 240 | if i == 0: 241 | if sign: 242 | expression = bracketed(term, self, True) 243 | else: 244 | expression = '-' + bracketed(term, self, True) 245 | else: 246 | if sign: 247 | expression += ' + ' + bracketed(term, self, True) 248 | else: 249 | expression += ' - ' + bracketed(term, self, True) 250 | return expression 251 | 252 | 253 | class Identity(_SumLikeOp): 254 | """The identity op (a unitary op).""" 255 | 256 | def __init__(self, input_): 257 | super(Identity, self).__init__({'input': input_}) 258 | 259 | def expanded_signs_and_terms(self): 260 | if isinstance(self.children['input'], _SumLikeOp): 261 | return self.children['input'].expanded_signs_and_terms() 262 | else: 263 | return [(True, self.children['input'])] 264 | 265 | def __str__(self): 266 | return str(self.children['input']) 267 | 268 | def sympy(self): 269 | return self.children['input'].sympy() 270 | 271 | @property 272 | def precedence(self): 273 | return self.children['input'].precedence 274 | 275 | 276 | class Neg(_SumLikeOp): 277 | """Negation, a unary op. 
Also has special display when appearing in a sum.""" 278 | 279 | def __init__(self, arg): 280 | super(Neg, self).__init__({'input': arg}) 281 | 282 | def expanded_signs_and_terms(self): 283 | if isinstance(self.children['input'], _SumLikeOp): 284 | inner_signs_and_terms = self.children['input'].expanded_signs_and_terms() 285 | return [(not sign, term) for (sign, term) in inner_signs_and_terms] 286 | else: 287 | return [(False, self.children['input'])] 288 | 289 | def sympy(self): 290 | return -sympy.sympify(self.children['input']) 291 | 292 | def inner(self): 293 | return self.children['input'] 294 | 295 | @property 296 | def precedence(self): 297 | return _ADD_PRECEDENCE 298 | 299 | 300 | class Add(_SumLikeOp): 301 | """Addition.""" 302 | 303 | def __init__(self, *args): 304 | super(Add, self).__init__(args) 305 | 306 | def expanded_signs_and_terms(self): 307 | """Returns a list of arguments, plus any sub-arguments from sub-adds. 308 | 309 | E.g., if this op is `Add(Add(2, 3), Mul(4, 5), 1)`, then will return 310 | `[2, 3, Mul(4, 5), 1]` (the arguments of the inner add have been extracted). 
311 | """ 312 | expanded = [] 313 | for arg in self.children: 314 | if isinstance(arg, _SumLikeOp): 315 | expanded += arg.expanded_signs_and_terms() 316 | else: 317 | expanded.append((True, arg)) 318 | return expanded 319 | 320 | def sympy(self): 321 | return sympy.Add(*[sympy.sympify(arg) for arg in self.children]) 322 | 323 | @property 324 | def precedence(self): 325 | return _ADD_PRECEDENCE 326 | 327 | 328 | class Sub(Op): 329 | """Subtraction.""" 330 | 331 | def __init__(self, left, right): 332 | super(Sub, self).__init__({'left': left, 'right': right}) 333 | 334 | def __str__(self): 335 | return (bracketed(self.children['left'], self, False) + ' - ' 336 | + bracketed(self.children['right'], self, True)) 337 | 338 | def sympy(self): 339 | return sympy.Add( 340 | self.children['left'], sympy.Mul(-1, self.children['right'])) 341 | 342 | @property 343 | def precedence(self): 344 | return _ADD_PRECEDENCE 345 | 346 | 347 | class Mul(Op): 348 | """Multiplication.""" 349 | 350 | def __init__(self, *args): 351 | super(Mul, self).__init__(args) 352 | 353 | def __str__(self): 354 | if not self.children: 355 | return '1' 356 | else: 357 | args = [bracketed(arg, self, False) for arg in self.children] 358 | return MUL_SYMBOL.join(args) 359 | 360 | def sympy(self): 361 | return sympy.Mul(*[sympy.sympify(arg) for arg in self.children]) 362 | 363 | @property 364 | def precedence(self): 365 | return _MUL_PRECEDENCE 366 | 367 | 368 | class Div(Op): 369 | """Division.""" 370 | 371 | def __init__(self, numer, denom): 372 | super(Div, self).__init__({'numer': numer, 'denom': denom}) 373 | 374 | def __str__(self): 375 | return u'{}{}{}'.format( 376 | bracketed(self.children['numer'], self, True), DIV_SYMBOL, 377 | bracketed(self.children['denom'], self, True)) 378 | 379 | def sympy(self): 380 | return sympy.Mul( 381 | self.children['numer'], sympy.Pow(self.children['denom'], -1)) 382 | 383 | @property 384 | def precedence(self): 385 | return _MUL_PRECEDENCE 386 | 387 | 388 | class 
Pow(Op): 389 | """Power a to the power b.""" 390 | 391 | def __init__(self, a, b): 392 | super(Pow, self).__init__({'a': a, 'b': b}) 393 | 394 | def __str__(self): 395 | return u'{}{}{}'.format( 396 | bracketed(self.children['a'], self, True), POW_SYMBOL, 397 | bracketed(self.children['b'], self, True)) 398 | 399 | def sympy(self): 400 | return sympy.Pow( 401 | sympy.sympify(self.children['a']), sympy.sympify(self.children['b'])) 402 | 403 | @property 404 | def precedence(self): 405 | return _POW_PRECEDENCE 406 | 407 | 408 | class Sqrt(Op): 409 | """Square root of a value.""" 410 | 411 | def __init__(self, a): 412 | super(Sqrt, self).__init__({'a': a}) 413 | 414 | def __str__(self): 415 | return 'sqrt({})'.format(self.children['a']) 416 | 417 | def sympy(self): 418 | return sympy.sqrt(self.children['a']) 419 | 420 | @property 421 | def precedence(self): 422 | return _POW_PRECEDENCE 423 | 424 | 425 | class Eq(Op): 426 | """Equality.""" 427 | 428 | def __init__(self, left, right): 429 | super(Eq, self).__init__({'left': left, 'right': right}) 430 | 431 | def __str__(self): 432 | return '{} = {}'.format(self.children['left'], self.children['right']) 433 | 434 | def sympy(self): 435 | return sympy.Eq(self.children['left'], self.children['right']) 436 | 437 | @property 438 | def precedence(self): 439 | return _EQ_PRECEDENCE 440 | 441 | 442 | def number_constants(expressions): 443 | """Returns list of integer, rational, decimal constants in the expressions.""" 444 | if isinstance(expressions, Op): 445 | expressions = [expressions] 446 | descendants = [] 447 | for expression in expressions: 448 | descendants += expression.descendants() 449 | candidate_constants = [op for op in descendants if isinstance(op, Constant)] 450 | return [constant for constant in candidate_constants 451 | if number.is_integer_or_rational_or_decimal(constant.value)] 452 | -------------------------------------------------------------------------------- /mathematics_dataset/sample/ops_test.py: 
# ------------------------------------------------------------------------------
# Copyright 2018 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for mathematics_dataset.sample.ops."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Dependency imports
from absl.testing import absltest
from mathematics_dataset.sample import ops
from six.moves import range
import sympy


class OpsTest(absltest.TestCase):
  """Checks the string form and the sympy value of `ops` expression trees."""

  def testNeg(self):
    """Negation on its own and nested inside additions."""
    op = ops.Neg(2)
    self.assertEqual(str(op), '-2')
    self.assertEqual(op.sympy(), -2)

    op = ops.Add(ops.Neg(2), 3)
    self.assertEqual(str(op), '-2 + 3')
    self.assertEqual(op.sympy(), 1)

    op = ops.Add(3, ops.Neg(2))
    self.assertEqual(str(op), '3 - 2')
    self.assertEqual(op.sympy(), 1)

    op = ops.Add(ops.Add(ops.Neg(2), 5), 3)
    self.assertEqual(str(op), '-2 + 5 + 3')
    self.assertEqual(op.sympy(), 6)

    op = ops.Add(3, ops.Add(ops.Identity(ops.Neg(2)), 5))
    self.assertEqual(str(op), '3 - 2 + 5')
    self.assertEqual(op.sympy(), 6)

    op = ops.Add(3, ops.Add(2, ops.Neg(5)))
    self.assertEqual(str(op), '3 + 2 - 5')
    self.assertEqual(op.sympy(), 0)

  def testAdd(self):
    """Empty, binary and nested additions."""
    add = ops.Add()
    self.assertEqual(str(add), '0')
    self.assertEqual(add.sympy(), 0)

    add = ops.Add(2, 3)
    self.assertEqual(str(add), '2 + 3')
    self.assertEqual(add.sympy(), 5)

    add = ops.Add(ops.Add(1, 2), 3)
    self.assertEqual(str(add), '1 + 2 + 3')
    self.assertEqual(add.sympy(), 6)

  def testSub(self):
    """Subtraction, including bracketing of a right-nested subtraction."""
    sub = ops.Sub(2, 3)
    self.assertEqual(str(sub), '2 - 3')
    self.assertEqual(sub.sympy(), -1)

    sub = ops.Sub(ops.Sub(1, 2), 3)
    self.assertEqual(str(sub), '1 - 2 - 3')
    self.assertEqual(sub.sympy(), -4)

    sub = ops.Sub(1, ops.Sub(2, 3))
    self.assertEqual(str(sub), '1 - (2 - 3)')
    self.assertEqual(sub.sympy(), 2)

    sub = ops.Sub(ops.Neg(1), 2)
    self.assertEqual(str(sub), '-1 - 2')
    self.assertEqual(sub.sympy(), -3)

  def testMul(self):
    """Empty, binary and nested multiplications."""
    mul = ops.Mul()
    self.assertEqual(str(mul), '1')
    self.assertEqual(mul.sympy(), 1)

    mul = ops.Mul(2, 3)
    self.assertEqual(str(mul), '2*3')
    self.assertEqual(mul.sympy(), 6)

    mul = ops.Mul(ops.Identity(ops.Constant(-2)), 3)
    self.assertEqual(str(mul), '-2*3')
    self.assertEqual(mul.sympy(), -6)

    mul = ops.Mul(ops.Add(1, 2), 3)
    self.assertEqual(str(mul), '(1 + 2)*3')
    self.assertEqual(mul.sympy(), 9)

    mul = ops.Mul(ops.Mul(2, 3), 5)
    self.assertEqual(str(mul), '2*3*5')
    self.assertEqual(mul.sympy(), 30)

    # TODO(b/124038946): reconsider how we want brackets in these cases:
    #     mul = ops.Mul(ops.Div(2, 3), 5)
    #     self.assertEqual(str(mul), '(2/3)*5')
    #     self.assertEqual(mul.sympy(), sympy.Rational(10, 3))
    #
    #     mul = ops.Mul(sympy.Rational(2, 3), 5)
    #     self.assertEqual(str(mul), '(2/3)*5')
    #     self.assertEqual(mul.sympy(), sympy.Rational(10, 3))

  def testDiv(self):
    """Division, including bracketing of nested quotients and products."""
    div = ops.Div(2, 3)
    self.assertEqual(str(div), '2/3')
    self.assertEqual(div.sympy(), sympy.Rational(2, 3))

    div = ops.Div(2, sympy.Rational(4, 5))
    self.assertEqual(str(div), '2/(4/5)')
    self.assertEqual(div.sympy(), sympy.Rational(5, 2))

    div = ops.Div(1, ops.Div(2, 3))
    self.assertEqual(str(div), '1/(2/3)')
    self.assertEqual(div.sympy(), sympy.Rational(3, 2))

    div = ops.Div(ops.Div(2, 3), 4)
    self.assertEqual(str(div), '(2/3)/4')
    self.assertEqual(div.sympy(), sympy.Rational(1, 6))

    div = ops.Div(2, ops.Mul(3, 4))
    self.assertEqual(str(div), '2/(3*4)')

    div = ops.Div(2, sympy.Function('f')(sympy.Symbol('x')))
    self.assertEqual(str(div), '2/f(x)')

  def testPow(self):
    """Powers, including rational and negative bases and nested powers."""
    pow_ = ops.Pow(2, 3)
    self.assertEqual(str(pow_), '2**3')
    self.assertEqual(pow_.sympy(), 8)

    pow_ = ops.Pow(4, sympy.Rational(1, 2))
    self.assertEqual(str(pow_), '4**(1/2)')
    self.assertEqual(pow_.sympy(), 2)

    pow_ = ops.Pow(sympy.Rational(1, 2), 3)
    self.assertEqual(str(pow_), '(1/2)**3')
    # Compare against an exact rational: `1/8` is the Python float 0.125 here
    # (true division), and recent SymPy releases no longer treat a Rational
    # and a Float as equal under `==`.
    self.assertEqual(pow_.sympy(), sympy.Rational(1, 8))

    pow_ = ops.Pow(3, ops.Pow(2, 1))
    self.assertEqual(str(pow_), '3**(2**1)')
    self.assertEqual(pow_.sympy(), 9)

    pow_ = ops.Pow(ops.Pow(2, 3), 4)
    self.assertEqual(str(pow_), '(2**3)**4')
    self.assertEqual(pow_.sympy(), 4096)

    pow_ = ops.Pow(-5, 2)
    self.assertEqual(str(pow_), '(-5)**2')
    self.assertEqual(pow_.sympy(), 25)

  def testEq(self):
    """An equation prints with ` = ` and evaluates to a sympy boolean."""
    op = ops.Eq(ops.Add(2, 3), 4)
    self.assertEqual(str(op), '2 + 3 = 4')
    self.assertEqual(op.sympy(), False)

  def testDescendants(self):
    """Every constant of a nested expression is a descendant exactly once."""
    constants = [ops.Constant(i) for i in range(6)]

    # The constants hold the values 0..5, so this builds
    # (0 + 1*2**3) / 4 - 5.
    expression = ops.Sub(
        ops.Div(
            ops.Add(
                constants[0],
                ops.Mul(
                    constants[1],
                    ops.Pow(
                        constants[2],
                        constants[3]))),
            constants[4]),
        constants[5])
    descendants = expression.descendants()
    descendants = ops._flatten(descendants)

    for constant in constants:
      self.assertIn(constant, descendants)
      self.assertEqual(descendants.count(constant), 1)

    # Also test top-level.
    self.assertEqual(constants[0].descendants(), [constants[0]])

    # Also general structure.
    constant = ops.Constant(3)
    expression = ops.Neg(constant)
    self.assertEqual(set(expression.descendants()), set([constant, expression]))

  def testNumberConstants(self):
    """`number_constants` returns the single constant under a negation."""
    constant = ops.Constant(3)
    expression = ops.Neg(constant)
    constants = ops.number_constants([expression])
    self.assertEqual(constants, [constant])


if __name__ == '__main__':
  absltest.main()

# ------------------------------------------------------------------------------
# /mathematics_dataset/sample/polynomials_test.py:
# ------------------------------------------------------------------------------
# Copyright 2018 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for mathematics_dataset.sample.polynomials."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import random

# Dependency imports
from absl.testing import parameterized
from mathematics_dataset.sample import polynomials
import numpy as np
from six.moves import range
import sympy
import tensorflow as tf


class ExpressionWithValueTest(tf.test.TestCase, parameterized.TestCase):
  """Exercises the coefficient and polynomial sampling helpers."""

  def testSplitValueEqually(self):
    """Splitting 3 and 3/4 into two parts gives the expected pairs."""
    pieces = polynomials._split_value_equally(3, 2)
    self.assertEqual(pieces, [1, 2])
    pieces = polynomials._split_value_equally(sympy.sympify('3/4'), 2)
    self.assertEqual(pieces, [sympy.sympify('1/4'), sympy.sympify('1/2')])

  def testIntegersWithSum(self):
    """The sampled terms have the requested count and total."""
    target = 13
    num_terms = 10
    terms = polynomials.integers_with_sum(
        value=target, count=num_terms, entropy=4.0)
    self.assertLen(terms, num_terms)
    self.assertEqual(sum(terms), target)

  def testMonomial(self):
    """Monomials print as products of powers."""
    x, y = sympy.symbols('x y')
    self.assertEqual(str(polynomials.monomial(1, [x, y], [2, 3])), 'x**2*y**3')
    # TODO(b/124038530): how handle rational coefficients; are they even used?
    # self.assertEqual(
    #     str(polynomials.monomial(sympy.Rational(2, 3), [x], [1])), '2*x/3')
    # self.assertEqual(
    #     str(polynomials.monomial(sympy.Rational(1, 3), [x], [1])), 'x/3')
    self.assertEqual(str(polynomials.monomial(x, [y], [4])), 'x*y**4')

  def testExpandCoefficients(self):
    """Summing each expanded entry recovers the original coefficient."""
    for _ in range(10):
      num_variables = np.random.randint(1, 4)
      degrees = np.random.randint(0, 4, [num_variables])
      original = np.random.randint(-3, 3, degrees + 1)
      entropy = np.random.uniform(0, 10)
      expanded = polynomials.expand_coefficients(original, entropy)
      recovered = np.vectorize(sum)(expanded)
      self.assertAllEqual(original, recovered)

  def testCoefficientsToPolynomial(self):
    """Coefficients [3, 2, 1] give x**2 + 2*x + 3."""
    x = sympy.Symbol('x')
    built = polynomials.coefficients_to_polynomial([3, 2, 1], [x])
    self.assertEqual(sympy.sympify(built), x*x + 2*x + 3)

  def testUnivariate(self):
    """Expanded univariate coefficients still sympify to the same polynomial."""
    # Coefficients [1, 2, 3] are compared against 1 + 2*x + 3*x**2.
    x = sympy.Symbol('x')
    base_coeffs = [1, 2, 3]
    for _ in range(10):
      expanded = polynomials.expand_coefficients(base_coeffs, 5.0)
      built = polynomials.coefficients_to_polynomial(expanded, [x])
      self.assertEqual(sympy.sympify(built), 1 + 2*x + 3*x*x)

  def testMultivariate(self):
    """Expanded bivariate coefficients still sympify to the same polynomial."""
    # Test generation for: x**2 + 2*x*y + 3*y**2 - x + 5
    x, y = sympy.symbols('x y')
    base_coeffs = [[5, 0, 3], [-1, 2, 0], [1, 0, 0]]
    for _ in range(10):
      expanded = polynomials.expand_coefficients(base_coeffs, 5.0, length=10)
      built = polynomials.coefficients_to_polynomial(expanded, [x, y])
      self.assertEqual(sympy.sympify(built), x*x + 2*x*y + 3*y*y - x + 5)

  def testAddCoefficients(self):
    """Adding coefficient arrays of different shapes pads and sums them."""
    # Add x**2 + 2*y and 3*x + 4*y**3.
    first = [[0, 2], [0, 0], [1, 0]]
    second = [[0, 0, 0, 4], [3, 0, 0, 0]]
    expected = [[0, 2, 0, 4], [3, 0, 0, 0], [1, 0, 0, 0]]
    self.assertAllEqual(expected, polynomials.add_coefficients(first, second))

  def testCoefficientsLinearSplit(self):
    """c1 * coeffs1 + c2 * coeffs2 reproduces the input coefficients."""
    for degree in range(3):
      for ndims in range(3):
        for _ in range(10):
          original = np.random.randint(-5, 5, [degree + 1] * ndims)
          entropy = random.uniform(1, 4)
          split = polynomials.coefficients_linear_split(original, entropy)
          c1, c2, coeffs1, coeffs2 = split
          left = int(c1) * np.asarray(coeffs1, dtype=np.int32)
          right = int(c2) * np.asarray(coeffs2, dtype=np.int32)
          self.assertAllEqual(left + right, original)

  def testSampleWithBrackets(self):
    """Samples contain a bracket and expand to the requested degrees."""
    x, y = sympy.symbols('x y')
    for _ in range(100):
      degrees = np.random.randint(1, 4, [2])
      entropy = random.uniform(0, 4)
      sampled = polynomials.sample_with_brackets(
          variables=[x, y], degrees=degrees, entropy=entropy)
      self.assertIn('(', str(sampled))
      expanded = sympy.poly(sympy.sympify(sampled).expand())
      self.assertEqual(expanded.degree(x), degrees[0])
      self.assertEqual(expanded.degree(y), degrees[1])

  def testTrim(self):
    """Trailing zeros are trimmed; leading zeros are kept."""
    self.assertAllEqual(polynomials.trim([1]), [1])
    self.assertAllEqual(polynomials.trim([1, 0]), [1])
    self.assertAllEqual(polynomials.trim([0, 1]), [0, 1])
    self.assertAllEqual(polynomials.trim([0]), [])
    self.assertAllEqual(polynomials.trim([0, 0]), [])

  def testDifferentiate_univariate(self):
    """Differentiating [5, 3, 2] along axis 0 gives [3, 4]."""
    derivative = polynomials.differentiate([5, 3, 2], 0)
    self.assertAllEqual([3, 4], derivative)

  def testDifferentiate_multivariate(self):
    """Differentiating a 3x3 coefficient grid along axis 0."""
    grid = [[0, 3, 1], [5, 0, 0], [0, 2, 0]]
    derivative = polynomials.differentiate(grid, 0)
    self.assertAllEqual([[5, 0], [0, 4]], derivative)

  def testIntegrate_univariate(self):
    """Integrating [5, 3, 2] along axis 0 gives rational coefficients."""
    expected = [0, 5, sympy.Rational(3, 2), sympy.Rational(2, 3)]
    integral = polynomials.integrate([5, 3, 2], 0)
    self.assertAllEqual(expected, integral)

  def testIntegrate_multivariate(self):
    """Integrating a 2x2 coefficient grid along axis 1."""
    expected = [[0, 0, sympy.Rational(1, 2)], [0, 1, 0]]
    integral = polynomials.integrate([[0, 1], [1, 0]], 1)
    self.assertAllEqual(expected, integral)


if __name__ == '__main__':
  tf.test.main()

# ------------------------------------------------------------------------------
# /mathematics_dataset/util/__init__.py:
# ------------------------------------------------------------------------------
# Copyright 2018 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# ------------------------------------------------------------------------------
# /mathematics_dataset/util/combinatorics.py:
# ------------------------------------------------------------------------------
# Copyright 2018 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Combinatorics utility functions."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import random

# Dependency imports
from six.moves import range
from six.moves import zip


def uniform_positive_integers_with_sum(count, sum_):
  """Returns list of size `count` of integers >= 1, summing to `sum_`.

  The list is built by drawing `count - 1` distinct cut points from
  {1, ..., sum_ - 1} and returning the gaps between consecutive cuts.

  Args:
    count: Number of integers to return.
    sum_: Non-negative total the integers must add up to.

  Returns:
    List of `count` integers, each >= 1. With `count == 0` the result is `[]`
    for any non-negative `sum_`.

  Raises:
    ValueError: If `count > sum_`.
  """
  assert sum_ >= 0
  if count > sum_:
    raise ValueError('Cannot find {} numbers >= 1 with sum {}'
                     .format(count, sum_))
  if count == 0:
    return []
  # Select `count - 1` numbers from {1, ..., sum_ - 1}. Sampling straight from
  # the range object avoids building a list of `sum_ - 1` elements first.
  separators = sorted(random.sample(range(1, sum_), count - 1))
  return [right - left
          for left, right in zip([0] + separators, separators + [sum_])]


def uniform_non_negative_integers_with_sum(count, sum_):
  """Returns list of size `count` of integers >= 0, summing to `sum_`.

  Args:
    count: Number of integers to return.
    sum_: Total the integers must add up to.

  Returns:
    List of `count` integers, each >= 0.
  """
  # Sample positive integers with a total that is `count` larger, then take
  # one off each so that zeros become possible.
  positive = uniform_positive_integers_with_sum(count, sum_ + count)
  return [i - 1 for i in positive]


def log_number_binary_trees(size):
  """Returns (nat) log of number of binary trees with `size` internal nodes."""
  # This is equal to log of C_size, where C_n is the nth Catalan number,
  # accumulated as the sum over k = 2..size of log((size + k) / k).
  assert isinstance(size, int)
  assert size >= 0
  log = 0.0
  for k in range(2, size + 1):
    log += math.log(size + k) - math.log(k)
  return log

# ------------------------------------------------------------------------------
# /mathematics_dataset/util/combinatorics_test.py:
# ------------------------------------------------------------------------------
# Copyright 2018 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests mathematics_dataset.util.combinatorics."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math

# Dependency imports
from absl.testing import absltest
from mathematics_dataset.util import combinatorics


class CombinatoricsTest(absltest.TestCase):
  """Covers the integer-partition samplers and the binary-tree count."""

  def testPositiveIntegersWithSum(self):
    """Forced outcomes are exact; random outcomes keep the requested sum."""
    sample = combinatorics.uniform_positive_integers_with_sum
    self.assertEqual(sample(1, 1), [1])
    self.assertEqual(sample(2, 2), [1, 1])
    self.assertEqual(sum(sample(1, 10)), 10)
    self.assertEqual(sum(sample(2, 10)), 10)
    self.assertEqual(sample(0, 0), [])

  def testNonNegativeIntegersWithSum(self):
    """A zero total gives all zeros; a random draw keeps the requested sum."""
    sample = combinatorics.uniform_non_negative_integers_with_sum
    self.assertEqual(sample(1, 0), [0])
    self.assertEqual(sample(2, 0), [0, 0])
    self.assertEqual(sum(sample(3, 10)), 10)

  def testLogNumberBinaryTrees(self):
    """Sizes 0..4 give the logs of 1, 1, 2, 5 and 14."""
    expected_counts = [1, 1, 2, 5, 14]
    for size, tree_count in enumerate(expected_counts):
      self.assertAlmostEqual(
          combinatorics.log_number_binary_trees(size), math.log(tree_count))


if __name__ == '__main__':
  absltest.main()
# ------------------------------------------------------------------------------
# /mathematics_dataset/util/composition_test.py:
# ------------------------------------------------------------------------------
# Copyright 2018 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for mathematics_dataset.util.composition."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Dependency imports
from absl.testing import absltest
from mathematics_dataset.util import composition
import sympy


class FunctionHandleTest(absltest.TestCase):
  """Tests for `composition.FunctionHandle`."""

  def testApply(self):
    """Applying a two-function handle nests the calls in order."""
    handle = composition.FunctionHandle('f', 'g')
    applied = handle.apply(*sympy.symbols('x y'))
    self.assertEqual(str(applied), 'f(g(x, y))')
    applied = handle.apply(sympy.symbols('x'))
    self.assertEqual(str(applied), 'f(g(x))')


class ContextTest(absltest.TestCase):
  """Tests for `composition.SampleArgs`."""

  def testPeel(self):
    """Peeling returns some entropy and leaves the rest in the new args."""
    sample_args = composition.SampleArgs(4, 3.0)
    entropy, new_sample_args = sample_args.peel()
    self.assertAlmostEqual(entropy, 0.75)
    self.assertEqual(new_sample_args.num_modules, 4)
    self.assertAlmostEqual(new_sample_args.entropy, 2.25)

  def testSplit(self):
    """Splitting in two keeps the total entropy and uses up one module."""
    sample_args = composition.SampleArgs(4, 5.0)
    children = sample_args.split(2)
    self.assertLen(children, 2)
    self.assertEqual(sum([child.num_modules for child in children]), 3)
    self.assertAlmostEqual(sum([child.entropy for child in children]), 5.0)


class EntityTest(absltest.TestCase):
  """Tests for `composition.Entity`."""

  def testInit_valueErrorIfSelfAndHandle(self):
    """A `{self}` description combined with a handle raises ValueError."""
    # `assertRaisesRegex` is a bound method: the first argument is the
    # expected exception type, not the test case.
    with self.assertRaisesRegex(ValueError, 'Cannot specify handle'):
      composition.Entity(context=composition.Context(),
                         value=0,
                         description='Something with {self}. ',
                         handle='additional')


if __name__ == '__main__':
  absltest.main()

# ------------------------------------------------------------------------------
# /mathematics_dataset/util/display.py:
# ------------------------------------------------------------------------------
# Copyright 2018 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Functionality for displaying expressions.

SymPy provides a lot of functionality for displaying expressions, but it's
slightly too centered on being a symbolic maths engine to provide everything we
need. For example, it's impossible to display an unsimplified fraction like
3/6, or a decimal that isn't internally represented as a float and thus subject
to rounding.

Also provides some other conveniences such as converting numbers to words, and
displaying percentages (properly formatted).
"""

# display.py opens with the same three `from __future__` imports
# (absolute_import, division, print_function) that already appear at the top
# of this span; Python only accepts them at the very top of a compilation
# unit, so they are not repeated here.

import decimal

# Dependency imports
import sympy


# For converting integers to words:
_INTEGER_LOW = [
    'zero', 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight',
    'nine', 'ten', 'eleven', 'twelve', 'thirteen', 'fourteen', 'fifteen',
    'sixteen', 'seventeen', 'eighteen', 'nineteen'
]
_INTEGER_MID = [
    '', '', 'twenty', 'thirty', 'forty', 'fifty', 'sixty', 'seventy', 'eighty',
    'ninety'
]
_INTEGER_HIGH = [
    (int(1e12), 'trillion'), (int(1e9), 'billion'), (int(1e6), 'million'),
    (int(1e3), 'thousand'), (100, 'hundred')
]


# For converting rationals to words:
_SINGULAR_DENOMINATORS = [
    '', '', 'half', 'third', 'quarter', 'fifth', 'sixth', 'seventh', 'eighth',
    'ninth', 'tenth', 'eleventh', 'twelfth', 'thirteenth', 'fourteenth',
    'fifteenth', 'sixteenth', 'seventeenth', 'eighteenth', 'nineteenth',
    'twentieth'
]
_PLURAL_DENOMINATORS = [
    '', '', 'halves', 'thirds', 'quarters', 'fifths', 'sixths', 'sevenths',
    'eighths', 'ninths', 'tenths', 'elevenths', 'twelfths', 'thirteenths',
    'fourteenths', 'fifteenths', 'sixteenths', 'seventeenths', 'eighteenths',
    'nineteenths', 'twentieths'
]


# For converting ordinals to words:
_ORDINALS = [
    'zeroth', 'first', 'second', 'third', 'fourth', 'fifth', 'sixth', 'seventh',
    'eighth', 'ninth', 'tenth', 'eleventh', 'twelfth', 'thirteenth',
    'fourteenth', 'fifteenth', 'sixteenth', 'seventeenth', 'eighteenth',
    'nineteenth', 'twentieth'
]


class Decimal(object):
  """Display a value as a decimal."""

  def __init__(self, value):
    """Initializes a `Decimal`.

    Args:
      value: (Sympy) value to display as a decimal.

    Raises:
      ValueError: If `value` cannot be represented as a terminating
        (non-recurring) decimal, i.e. its reduced denominator has a prime
        factor other than 2 or 5.
    """
    self._value = sympy.Rational(value)

    numer = int(sympy.numer(self._value))
    denom = int(sympy.denom(self._value))

    denom_factors = list(sympy.factorint(denom).keys())
    for factor in denom_factors:
      if factor not in [2, 5]:
        raise ValueError('Cannot represent {} as a non-recurring decimal.'
                         .format(value))
    self._decimal = decimal.Decimal(numer) / decimal.Decimal(denom)

  @property
  def value(self):
    """Returns the value as a `sympy.Rational` object."""
    return self._value

  def _sympy_(self):
    return self._value

  def decimal_places(self):
    """Returns the number of decimal places, e.g., 32 has 0 and 1.43 has 2."""
    if isinstance(self._decimal, int):
      return 0
    elif isinstance(self._decimal, decimal.Decimal):
      return -self._decimal.as_tuple().exponent

  def __str__(self):
    """Returns plain positional notation (never scientific notation)."""
    sign, digits, exponent = self._decimal.as_tuple()
    sign = '' if sign == 0 else '-'

    num_left_digits = len(digits) + exponent  # number digits "before" point

    if num_left_digits > 0:
      int_part = ''.join(str(digit) for digit in digits[:num_left_digits])
    else:
      int_part = '0'

    if exponent < 0:
      frac_part = '.'
      if num_left_digits < 0:
        # Zeros between the decimal point and the first stored digit.
        frac_part += '0' * -num_left_digits
      frac_part += ''.join(str(digit) for digit in digits[exponent:])
    else:
      frac_part = ''

    return sign + int_part + frac_part

  def __add__(self, other):
    if not isinstance(other, Decimal):
      raise ValueError('Arithmetic support limited to other `Decimal`s.')
    return Decimal(self.value + other.value)

  def __sub__(self, other):
    if not isinstance(other, Decimal):
      raise ValueError('Arithmetic support limited to other `Decimal`s.')
    return Decimal(self.value - other.value)

  def __mul__(self, other):
    if not isinstance(other, Decimal):
      raise ValueError('Arithmetic support limited to other `Decimal`s.')
    return Decimal(self.value * other.value)

  def __neg__(self):
    return Decimal(-self.value)

  def round(self, ndigits=0):
    """Returns a new `Decimal` rounded to this many decimal places."""
    scale = sympy.Integer(10 ** ndigits)
    numer = sympy.numer(self.value) * scale
    denom = sympy.denom(self.value)
    return Decimal(int(round(numer / denom)) / scale)

  def __round__(self, ndigits=0):
    # The builtin `round(x)` calls `x.__round__()` with no argument, so
    # `ndigits` needs the same default as `round` above.
    return self.round(ndigits)

  def __int__(self):
    """Returns conversion to integer if possible; TypeError if non-integer."""
    if self.decimal_places() == 0:
      return int(self._decimal)
    else:
      raise TypeError('Cannot represent {} as an integer.'.format(str(self)))

  # NOTE: this is implemented in addition to `__cmp__` because SymPy does not
  # support inequality comparison between sympy objects and objects that are not
  # convertible to sympy objects (such as strings).
  def __eq__(self, other):
    return self.value == other

  # Python 2 does not derive `!=` from `__eq__`.
  def __ne__(self, other):
    return not self == other

  # Defining `__eq__` removes the default hash on Python 3; hash the
  # underlying rational so that equal values hash alike.
  def __hash__(self):
    return hash(self._value)

  # Python 2 comparison
  def __cmp__(self, other):
    if self.value == other:
      return 0
    if self.value < other:
      return -1
    return 1

  # Python 3 comparison:
  def __lt__(self, other):
    return self.value < other

  def __le__(self, other):
    return self.value <= other

  def __gt__(self, other):
    return self.value > other

  def __ge__(self, other):
    return self.value >= other


class Percentage(object):
  """Container for a percentage."""

  def __init__(self, value):
    """Initializes a `Percentage`.

    Args:
      value: Percentage as a fractional value. E.g., pass in
          `sympy.Rational(2, 5)` to create the percentage "40%".
    """
    self._value = value

  def _sympy_(self):
    return self._value

  def __str__(self):
    # Display percentages as decimals (not fractions).
    value = Decimal(self._value * 100)
    return str(value) + '%'


class NonSimpleRational(object):
  """Container for rational a / b where allow gcd(a, b) > 1."""

  def __init__(self, numer, denom):
    self._numer = numer
    self._denom = denom

  @property
  def numer(self):
    return self._numer

  @property
  def denom(self):
    return self._denom

  def __str__(self):
    return '{}/{}'.format(self._numer, self._denom)


class StringNumber(object):
  """A string representing a number, that can also be sympified."""

  def __init__(self, value, join_number_words_with_hyphens=True):
    """Initializes a `StringNumber`.

    Args:
      value: An integer or rational.
      join_number_words_with_hyphens: Whether to join the words in integers with
          hyphens when describing as a string.
    """
    self._join_number_words_with_hyphens = join_number_words_with_hyphens
    self._sympy_value = sympy.sympify(value)
    self._string = self._to_string(value)

  def _integer_to_words(self, integer):
    """Converts an integer to a list of words."""
    if integer < 0:
      raise ValueError('Cannot handle negative numbers.')

    if integer < 20:
      return [_INTEGER_LOW[integer]]

    if integer < 100:
      tens, ones = divmod(integer, 10)
      if ones > 0:
        return [_INTEGER_MID[tens], _INTEGER_LOW[ones]]
      else:
        return [_INTEGER_MID[tens]]

    # `_INTEGER_HIGH` is ordered from largest to smallest, so the first match
    # is the largest unit that fits; the last entry (100) always fits here.
    for value, word in _INTEGER_HIGH:
      if integer >= value:
        den, rem = divmod(integer, value)
        words = self._integer_to_words(den) + [word]
        if rem > 0:
          if rem < 100:
            words.append('and')
          words += self._integer_to_words(rem)
        return words

  def _rational_to_string(self, rational):
    """Converts a rational to words, e.g., "two thirds"."""
    numer = sympy.numer(rational)
    denom = sympy.denom(rational)

    numer_words = self._to_string(numer)

    if denom == 1:
      return numer_words

    if denom <= 0 or denom >= len(_PLURAL_DENOMINATORS):
      raise ValueError('Unsupported denominator {}.'.format(denom))

    if numer == 1:
      denom_word = _SINGULAR_DENOMINATORS[denom]
    else:
      denom_word = _PLURAL_DENOMINATORS[denom]

    return '{} {}'.format(numer_words, denom_word)

  def _to_string(self, number):
    """Converts an integer or rational to words."""
    if isinstance(number, (sympy.Integer, int)):
      words = self._integer_to_words(number)
      join_char = '-' if self._join_number_words_with_hyphens else ' '
      return join_char.join(words)
    elif isinstance(number, sympy.Rational):
      return self._rational_to_string(number)
    else:
      raise ValueError('Unable to handle number {} with type {}.'
                       .format(number, type(number)))

  def _sympy_(self):
    return self._sympy_value

  def __str__(self):
    return self._string


class StringOrdinal(object):
  """A string representation of an ordinal, e.g., "first"."""

  def __init__(self, position):
    """Initializes a `StringOrdinal`.

    Args:
      position: An integer >= 0.

    Raises:
      ValueError: If `position` is negative or out of range.
    """
    if position < 0 or position >= len(_ORDINALS):
      raise ValueError('Unsupported ordinal {}.'.format(position))
    self._string = _ORDINALS[position]

  def __str__(self):
    return self._string


class NumberList(object):
  """Contains a list of numbers, intended for display."""

  def __init__(self, numbers):
    self._numbers = numbers

  def __str__(self):
    """Converts the list to a string.

    Returns:
      Human readable string.

    Raises:
      ValueError: if any of the strings contain a comma and thus would lead to
          an ambiguous representation.
    """
    strings = []
    for number in self._numbers:
      string = str(number)
      if ',' in string:
        raise ValueError('String representation of the list will be ambiguous, '
                         'since term "{}" contains a comma.'.format(string))
      strings.append(string)
    return ', '.join(strings)


class NumberInBase(object):
  """Contains value, represented in a given base."""

  def __init__(self, value, base):
    """Initializes a `NumberInBase`.

    Args:
      value: Positive or negative integer.
      base: Integer in the range [2, 36].

    Raises:
      ValueError: If base is not in the range [2, 36] (since this is the limit
          that can be represented by 10 numbers plus 26 letters).
    """
    if not 2 <= base <= 36:
      raise ValueError('base={} must be in the range [2, 36]'.format(base))
    self._value = value
    self._base = base

    chars = []
    remainder = abs(value)
    while True:
      # Stay in exact integer arithmetic: `int(remainder / base)` goes through
      # a float under true division and corrupts the digits of large values.
      remainder, digit = divmod(remainder, base)
      digit = int(digit)
      char = str(digit) if digit <= 9 else chr(ord('a') + digit - 10)
      chars.append(char)
      if remainder == 0:
        break
    if value < 0:
      chars.append('-')

    # Digits were produced least-significant first.
    self._str = ''.join(reversed(chars))

  def __str__(self):
    return self._str

  def _sympy_(self):
    return self._value

# ------------------------------------------------------------------------------
# /mathematics_dataset/util/display_test.py:
# ------------------------------------------------------------------------------
# Copyright 2018 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for mathematics_dataset.util.display."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Dependency imports
from absl.testing import absltest
from mathematics_dataset.util import display
import sympy


class DecimalTest(absltest.TestCase):

  def testBasic_integer(self):
    decimal = display.Decimal(123)
    self.assertEqual(str(decimal), '123')
    self.assertEqual(sympy.sympify(decimal), sympy.Integer(123))
    self.assertEqual(decimal.decimal_places(), 0)

  def testBasic_ten(self):
    decimal = display.Decimal(10)
    self.assertEqual(str(decimal), '10')
    self.assertEqual(sympy.sympify(decimal), sympy.Integer(10))
    self.assertEqual(decimal.decimal_places(), 0)

  def testBasic(self):
    decimal = display.Decimal(sympy.Rational(123, 100))
    self.assertEqual(str(decimal), '1.23')
    self.assertEqual(sympy.sympify(decimal), sympy.Rational(123, 100))
    self.assertEqual(decimal.decimal_places(), 2)

  def testStr(self):
    self.assertEqual(str(display.Decimal(sympy.Rational(0, 10))), '0')
    self.assertEqual(str(display.Decimal(sympy.Rational(-1, 10))), '-0.1')
    self.assertEqual(str(display.Decimal(sympy.Rational(-11, 10))), '-1.1')
    self.assertEqual(str(display.Decimal(sympy.Rational(11, 10))), '1.1')
    self.assertEqual(str(display.Decimal(sympy.Rational(101, 1))), '101')
    self.assertEqual(
        str(display.Decimal(sympy.Rational(20171, 1000000))), '0.020171')

  def testStr_verySmall(self):
    # Tests it doesn't display in "scientific" notation 1E-9.
    decimal = display.Decimal(sympy.Rational(1, 1000000000))
    self.assertEqual(str(decimal), '0.000000001')

  def testAdd(self):
    self.assertEqual((display.Decimal(2) + display.Decimal(3)).value, 5)

  def testSub(self):
    self.assertEqual((display.Decimal(2) - display.Decimal(3)).value, -1)

  def testMul(self):
    self.assertEqual((display.Decimal(2) * display.Decimal(3)).value, 6)

  def testRound(self):
    decimal = display.Decimal(sympy.Rational(2675, 1000))  # 2.675
    self.assertEqual(sympy.sympify(decimal.round()), sympy.Integer(3))
    self.assertEqual(sympy.sympify(decimal.round(1)), sympy.Rational(27, 10))
    self.assertEqual(sympy.sympify(decimal.round(2)), sympy.Rational(268, 100))
    self.assertEqual(sympy.sympify(decimal.round(3)),
                     sympy.Rational(2675, 1000))

  def testInt(self):
    decimal = display.Decimal(123)
    self.assertEqual(int(decimal), 123)

  def testInt_errorIfNonInt(self):
    decimal = display.Decimal(sympy.Rational(1, 2))
    # `assertRaisesRegex` is a bound method: its first argument is the expected
    # exception class, not the test case.
    with self.assertRaisesRegex(TypeError, 'Cannot represent'):
      int(decimal)

  def testComparison(self):
    decimal = display.Decimal(sympy.Rational(-1, 2))
    # Each rich-comparison operator is exercised directly on purpose.
    # pylint: disable=g-generic-assert
    self.assertFalse(decimal != -0.5)
    self.assertTrue(decimal != 0)
    self.assertFalse(decimal < -0.5)
    self.assertTrue(decimal < 0)
    self.assertTrue(decimal <= -0.5)
    self.assertTrue(decimal <= 0)
    self.assertFalse(decimal > -0.5)
    self.assertTrue(decimal > -1)
    self.assertTrue(decimal >= -0.5)
    self.assertFalse(decimal >= 0)
    self.assertFalse(decimal == 0)
    self.assertTrue(decimal == -0.5)

  def testNegation(self):
    decimal = display.Decimal(sympy.Rational(1, 2))
    decimal = -decimal
    self.assertNotEqual(decimal, 0.5)
    self.assertEqual(decimal, -0.5)


class PercentageTest(absltest.TestCase):

  def testPercentage(self):
    percentage = display.Percentage(1.5)
    self.assertEqual(str(percentage), '150%')

    percentage = display.Percentage(sympy.Rational(67, 100))
    self.assertEqual(str(percentage), '67%')

    percentage = display.Percentage(sympy.Rational(67, 1000))
    self.assertEqual(str(percentage), '6.7%')


class NonSimpleRationalTest(absltest.TestCase):

  def testBasic(self):
    frac = display.NonSimpleRational(4, 6)
    self.assertEqual(frac.numer, 4)
    self.assertEqual(frac.denom, 6)
    self.assertEqual(str(frac), '4/6')


class StringNumberTest(absltest.TestCase):

  def testIntegerToWords(self):
    words = display.StringNumber(0)
    self.assertEqual(str(words), 'zero')
    self.assertEqual(sympy.sympify(words), 0)

    words = display.StringNumber(8)
    self.assertEqual(str(words), 'eight')
    self.assertEqual(sympy.sympify(words), 8)

    words = display.StringNumber(12)
    self.assertEqual(str(words), 'twelve')
    self.assertEqual(sympy.sympify(words), 12)

    words = display.StringNumber(30)
    self.assertEqual(str(words), 'thirty')
    self.assertEqual(sympy.sympify(words), 30)

    words = display.StringNumber(100)
    self.assertEqual(str(words), 'one-hundred')
    self.assertEqual(sympy.sympify(words), 100)

    words = display.StringNumber(103)
    self.assertEqual(str(words), 'one-hundred-and-three')
    self.assertEqual(sympy.sympify(words), 103)

    words = display.StringNumber(15439822)
    self.assertEqual(str(words), 'fifteen-million-four-hundred-and-thirty-nine'
                     '-thousand-eight-hundred-and-twenty-two')
    self.assertEqual(sympy.sympify(words), 15439822)

  def testRationalToWords(self):
    words = display.StringNumber(sympy.Rational(2, 3))
    self.assertEqual(str(words), 'two thirds')


class StringOrdinalTest(absltest.TestCase):

  def testBasic(self):
    ordinal = display.StringOrdinal(0)
    self.assertEqual(str(ordinal), 'zeroth')
    ordinal = display.StringOrdinal(10)
    self.assertEqual(str(ordinal), 'tenth')

  def testCreate_errorIfNegative(self):
    # Expected exception class comes first; `self` is already bound.
    with self.assertRaisesRegex(ValueError, 'Unsupported ordinal'):
      display.StringOrdinal(-1)


class NumberListTest(absltest.TestCase):

  def testBasic(self):
    numbers = [2, 3, 1]
    number_list = display.NumberList(numbers)
    string = str(number_list)
    self.assertEqual(string, '2, 3, 1')


class NumberInBaseTest(absltest.TestCase):

  def testBasic(self):
    self.assertEqual(str(display.NumberInBase(1, 10)), '1')
    self.assertEqual(str(display.NumberInBase(-1, 10)), '-1')
    self.assertEqual(str(display.NumberInBase(1, 2)), '1')
    self.assertEqual(str(display.NumberInBase(-1, 2)), '-1')
    self.assertEqual(str(display.NumberInBase(2, 2)), '10')
    self.assertEqual(str(display.NumberInBase(-2, 2)), '-10')
    self.assertEqual(str(display.NumberInBase(10, 16)), 'a')
    self.assertEqual(str(display.NumberInBase(16, 16)), '10')
    self.assertEqual(str(display.NumberInBase(256, 16)), '100')
    self.assertEqual(str(display.NumberInBase(-75483, 10)), '-75483')


if __name__ == '__main__':
  absltest.main()

# --------------------------------------------------------------------------------
# /mathematics_dataset/util/probability.py:
# --------------------------------------------------------------------------------
# Copyright 2018 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Functionality for working with probability spaces and random variables. 16 | 17 | Basic recap of probability theory, and thus of classes in this file: 18 | 19 | * A probability space is a (finite or infinite) set Omega with a probability 20 | measure defined on this. 21 | * A random variable is a mapping from a probability space to another measure 22 | space. 23 | * An event is a measurable set in a sample space. 24 | 25 | For example, suppose a bag contains 3 balls: two red balls, and one white ball. 26 | This could be represented by a discrete probability space of size 3 with 27 | elements {1, 2, 3}, with equal measure assigned to all 3 elements; and a random 28 | variable that maps 1->red, 2->red, and 3->white. Then the probability of drawing 29 | a red ball is the measure in the probability space of the inverse under the 30 | random variable mapping of {red}, i.e., of {1, 2}, which is 2/3. 
31 | """ 32 | 33 | from __future__ import absolute_import 34 | from __future__ import division 35 | from __future__ import print_function 36 | 37 | import abc 38 | import itertools 39 | 40 | # Dependency imports 41 | import six 42 | from six.moves import zip 43 | import sympy 44 | 45 | 46 | @six.add_metaclass(abc.ABCMeta) 47 | class Event(object): 48 | """Represents an event in a measure space.""" 49 | 50 | 51 | @six.add_metaclass(abc.ABCMeta) 52 | class ProbabilitySpace(object): 53 | """Represents a probability space.""" 54 | 55 | @abc.abstractmethod 56 | def probability(self, event): 57 | """Returns the probability of an event.""" 58 | 59 | 60 | @six.add_metaclass(abc.ABCMeta) 61 | class RandomVariable(object): 62 | """Random variable; a mapping from a probability space to a measure space.""" 63 | 64 | @abc.abstractmethod 65 | def __call__(self, event): 66 | """Maps an `_Event` in the probability space to one in the sample space.""" 67 | 68 | @abc.abstractmethod 69 | def inverse(self, event): 70 | """Maps event in the sample space back to the inverse in the prob. space.""" 71 | 72 | 73 | class DiscreteEvent(Event): 74 | """Set of discrete values.""" 75 | 76 | def __init__(self, values): 77 | self._values = values 78 | 79 | @property 80 | def values(self): 81 | return self._values 82 | 83 | 84 | class FiniteProductEvent(Event): 85 | """Event consisting of cartesian product of events.""" 86 | 87 | def __init__(self, events): 88 | """Initializes a `FiniteProductEvent`. 89 | 90 | Args: 91 | events: Tuple of `Event`s; resulting event will be cartesian product of 92 | these. 93 | """ 94 | self._events = events 95 | 96 | @property 97 | def events(self): 98 | return self._events 99 | 100 | def all_sequences(self): 101 | """Returns iterator of sequences by selecting a single event in each coord. 102 | 103 | This assumes that every component event is an instance of `DiscreteEvent`. 104 | 105 | Returns: 106 | Iterator over tuples of values. 
107 | 108 | Raises: 109 | ValueError: If one of the component events is not a `DiscreteEvent`. 110 | """ 111 | if not all(isinstance(event, DiscreteEvent) for event in self._events): 112 | raise ValueError('Not all component events are DiscreteEvents') 113 | values_list = [event.values for event in self._events] 114 | return itertools.product(*values_list) 115 | 116 | 117 | class CountLevelSetEvent(Event): 118 | """Event of all sequences with fixed number of different values occurring.""" 119 | 120 | def __init__(self, counts): 121 | """Initializes `CountLevelSetEvent`. 122 | 123 | E.g., to construct the event of getting two red balls and one green ball, 124 | pass `counts = {red: 2, green: 1}`. (Then `all_sequences()` would return 125 | `[(red, red, green), (red, green, red), (green, red, red)]`. 126 | 127 | Args: 128 | counts: Dictionary mapping values to the number of times they occur in a 129 | sequence. 130 | """ 131 | self._counts = counts 132 | self._all_sequences = None 133 | 134 | @property 135 | def counts(self): 136 | return self._counts 137 | 138 | def all_sequences(self): 139 | """Returns all sequences generated by this level set.""" 140 | if self._all_sequences is None: 141 | # Generate via dynamic programming. 
142 | cache = {} # dict mapping tuple -> list of tuples 143 | labels = list(self._counts.keys()) 144 | 145 | def generate(counts): 146 | """Returns list of tuples for given `counts` of labels.""" 147 | if sum(counts) == 0: 148 | return [()] 149 | counts = tuple(counts) 150 | if counts in cache: 151 | return cache[counts] 152 | generated = [] 153 | for i, count in enumerate(counts): 154 | if count == 0: 155 | continue 156 | counts_minus = list(counts) 157 | counts_minus[i] -= 1 158 | counts_minus = tuple(counts_minus) 159 | extensions = generate(counts_minus) 160 | generated += [tuple([labels[i]] + list(extension)) 161 | for extension in extensions] 162 | cache[counts] = generated 163 | return generated 164 | 165 | self._all_sequences = generate(list(self._counts.values())) 166 | 167 | return self._all_sequences 168 | 169 | 170 | class SequenceEvent(Event): 171 | """Collection of sequences.""" 172 | 173 | def __init__(self, sequences): 174 | self._sequences = sequences 175 | 176 | def all_sequences(self): 177 | return self._sequences 178 | 179 | 180 | def normalize_weights(weights): 181 | """Normalizes the weights (as sympy.Rational) in dictionary of weights.""" 182 | weight_sum = sum(six.itervalues(weights)) 183 | return { 184 | i: sympy.Rational(weight, weight_sum) 185 | for i, weight in six.iteritems(weights) 186 | } 187 | 188 | 189 | class DiscreteProbabilitySpace(ProbabilitySpace): 190 | """Discrete probability space.""" 191 | 192 | def __init__(self, weights=None): 193 | """Initializes an `DiscreteProbabilitySpace`. 194 | 195 | Args: 196 | weights: Dictionary mapping values to relative probability of selecting 197 | that value. This will be normalized. 
198 | """ 199 | self._weights = normalize_weights(weights) 200 | 201 | def probability(self, event): 202 | if isinstance(event, DiscreteEvent): 203 | return sum(self._weights[value] 204 | for value in event.values if value in self._weights) 205 | else: 206 | raise ValueError('Unhandled event type {}'.format(type(event))) 207 | 208 | @property 209 | def weights(self): 210 | """Returns dictionary of probability of each element.""" 211 | return self._weights 212 | 213 | 214 | class FiniteProductSpace(ProbabilitySpace): 215 | """Finite cartesian product of probability spaces.""" 216 | 217 | def __init__(self, spaces): 218 | """Initializes a `FiniteProductSpace`. 219 | 220 | Args: 221 | spaces: List of `ProbabilitySpace`. 222 | """ 223 | self._spaces = spaces 224 | 225 | def all_spaces_equal(self): 226 | return all([self._spaces[0] == space for space in self._spaces]) 227 | 228 | def probability(self, event): 229 | # Specializations for optimization. 230 | if isinstance(event, FiniteProductEvent): 231 | assert len(self._spaces) == len(event.events) 232 | return sympy.prod([ 233 | space.probability(event_slice) 234 | for space, event_slice in zip(self._spaces, event.events)]) 235 | 236 | if isinstance(event, CountLevelSetEvent) and self.all_spaces_equal(): 237 | space = self._spaces[0] 238 | counts = event.counts 239 | probabilities = { 240 | value: space.probability(DiscreteEvent({value})) 241 | for value in six.iterkeys(counts) 242 | } 243 | 244 | num_events = sum(six.itervalues(counts)) 245 | assert num_events == len(self._spaces) 246 | # Multinomial coefficient: 247 | coeff = ( 248 | sympy.factorial(num_events) / sympy.prod( 249 | [sympy.factorial(i) for i in six.itervalues(counts)])) 250 | return coeff * sympy.prod([ 251 | pow(probabilities[value], counts[value]) 252 | for value in six.iterkeys(counts) 253 | ]) 254 | 255 | raise ValueError('Unhandled event type {}'.format(type(event))) 256 | 257 | @property 258 | def spaces(self): 259 | """Returns list of spaces.""" 
260 | return self._spaces 261 | 262 | 263 | class SampleWithoutReplacementSpace(ProbabilitySpace): 264 | """Probability space formed by sampling discrete space without replacement.""" 265 | 266 | def __init__(self, weights, n_samples): 267 | """Initializes a `SampleWithoutReplacementSpace`. 268 | 269 | Args: 270 | weights: Dictionary mapping values to relative probability of selecting 271 | that value. This will be normalized. 272 | n_samples: Number of samples to draw. 273 | 274 | Raises: 275 | ValueError: If `n_samples > len(weights)`. 276 | """ 277 | if n_samples > len(weights): 278 | raise ValueError('n_samples is more than number of discrete elements') 279 | self._weights = normalize_weights(weights) 280 | self._n_samples = n_samples 281 | 282 | @property 283 | def n_samples(self): 284 | """Number of samples to draw.""" 285 | return self._n_samples 286 | 287 | def probability(self, event): 288 | try: 289 | all_sequences = event.all_sequences() 290 | except AttributeError: 291 | raise ValueError('Unhandled event type {}'.format(type(event))) 292 | 293 | probability_sum = 0 294 | for sequence in all_sequences: 295 | if len(sequence) != len(set(sequence)): 296 | continue # not all unique, so not "without replacement". 297 | p_sequence = 1 298 | removed_prob = 0 299 | for i in sequence: 300 | p = self._weights[i] if i in self._weights else 0 301 | if p == 0: 302 | p_sequence = 0 303 | break 304 | p_sequence *= p / (1 - removed_prob) 305 | removed_prob += p 306 | probability_sum += p_sequence 307 | return probability_sum 308 | 309 | 310 | class IdentityRandomVariable(RandomVariable): 311 | """Identity map of a probability space.""" 312 | 313 | def __call__(self, event): 314 | return event 315 | 316 | def inverse(self, event): 317 | return event 318 | 319 | 320 | class DiscreteRandomVariable(RandomVariable): 321 | """Specialization to discrete random variable. 322 | 323 | This is simply a mapping from a discrete space to a discrete space (dictionary 324 | lookup). 
325 | """ 326 | 327 | def __init__(self, mapping): 328 | """Initializes `DiscreteRandomVariable` from `mapping` dict.""" 329 | self._mapping = mapping 330 | self._inverse = {} 331 | for key, value in six.iteritems(mapping): 332 | if value in self._inverse: 333 | self._inverse[value].add(key) 334 | else: 335 | self._inverse[value] = set([key]) 336 | 337 | def __call__(self, event): 338 | if isinstance(event, DiscreteEvent): 339 | return DiscreteEvent({self._mapping[value] for value in event.values}) 340 | else: 341 | raise ValueError('Unhandled event type {}'.format(type(event))) 342 | 343 | def inverse(self, event): 344 | if isinstance(event, DiscreteEvent): 345 | set_ = set() 346 | for value in event.values: 347 | if value in self._inverse: 348 | set_.update(self._inverse[value]) 349 | return DiscreteEvent(set_) 350 | else: 351 | raise ValueError('Unhandled event type {}'.format(type(event))) 352 | 353 | 354 | class FiniteProductRandomVariable(RandomVariable): 355 | """Product random variable. 356 | 357 | This has the following semantics. Let this be X = (X_1, ..., X_n). Then 358 | 359 | X(w) = (X_1(w_1), ..., X_n(w_n)) 360 | 361 | (the sample space is assumed to be of sequence type). 362 | """ 363 | 364 | def __init__(self, random_variables): 365 | """Initializes a `FiniteProductRandomVariable`. 366 | 367 | Args: 368 | random_variables: Tuple of `RandomVariable`. 369 | """ 370 | self._random_variables = random_variables 371 | 372 | def __call__(self, event): 373 | if isinstance(event, FiniteProductEvent): 374 | assert len(event.events) == len(self._random_variables) 375 | zipped = list(zip(self._random_variables, event.events)) 376 | return FiniteProductEvent( 377 | [random_variable(sub_event) 378 | for random_variable, sub_event in zipped]) 379 | else: 380 | raise ValueError('Unhandled event type {}'.format(type(event))) 381 | 382 | def inverse(self, event): 383 | # Specialization for `FiniteProductEvent`; don't need to take all sequences. 
384 | if isinstance(event, FiniteProductEvent): 385 | assert len(event.events) == len(self._random_variables) 386 | zipped = list(zip(self._random_variables, event.events)) 387 | return FiniteProductEvent(tuple( 388 | random_variable.inverse(sub_event) 389 | for random_variable, sub_event in zipped)) 390 | 391 | # Try fallback of mapping each sequence separately. 392 | try: 393 | all_sequences = event.all_sequences() 394 | except AttributeError: 395 | raise ValueError('Unhandled event type {}'.format(type(event))) 396 | 397 | mapped = set() 398 | for sequence in all_sequences: 399 | assert len(sequence) == len(self._random_variables) 400 | zipped = list(zip(self._random_variables, sequence)) 401 | mapped_sequence = FiniteProductEvent(tuple( 402 | random_variable.inverse(DiscreteEvent({element})) 403 | for random_variable, element in zipped)) 404 | mapped.update(mapped_sequence.all_sequences()) 405 | return SequenceEvent(mapped) 406 | -------------------------------------------------------------------------------- /mathematics_dataset/util/probability_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 

"""Tests for mathematics_dataset.util.probability."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Dependency imports
from absl.testing import absltest
from mathematics_dataset.util import probability
import sympy


class FiniteProductEventTest(absltest.TestCase):

  def testAllSequences(self):
    event = probability.FiniteProductEvent([probability.DiscreteEvent({1, 2}),
                                            probability.DiscreteEvent({3})])
    # The first component is a set, so the order in which sequences are
    # produced is not part of the contract; compare in sorted order.
    all_sequences = sorted(event.all_sequences())
    self.assertEqual(all_sequences, [(1, 3), (2, 3)])


class CountLevelSetEventTest(absltest.TestCase):

  def testAllSequences(self):
    event = probability.CountLevelSetEvent({'a': 2, 'b': 4, 'c': 1})
    all_sequences = event.all_sequences()

    # Number of sequences should be 7! / (4! * 2! * 1!) = 105.
    self.assertLen(all_sequences, 105)
    # They should all be unique.
    self.assertEqual(len(all_sequences), len(set(all_sequences)))
    # And check contains one correctly generated tuple.
    self.assertIn(('a', 'b', 'c', 'b', 'b', 'a', 'b'), all_sequences)


class DiscreteProbabilitySpaceTest(absltest.TestCase):

  def testBasic(self):
    space = probability.DiscreteProbabilitySpace({0: 1, 1: 2, 2: 3})
    p = space.probability(probability.DiscreteEvent([0]))
    self.assertEqual(p, sympy.Rational(1, 6))
    p = space.probability(probability.DiscreteEvent([0, 1]))
    self.assertEqual(p, sympy.Rational(1, 2))
    p = space.probability(probability.DiscreteEvent([0, 1, 2]))
    self.assertEqual(p, 1)
    # Values outside the space contribute nothing.
    p = space.probability(probability.DiscreteEvent([0, 1, 2, 3]))
    self.assertEqual(p, 1)
    p = space.probability(probability.DiscreteEvent([3]))
    self.assertEqual(p, 0)


class FiniteProductSpaceTest(absltest.TestCase):

  def testProbability_FiniteProductEvent(self):
    # 5 coin flips of a biased coin with heads prob = 1/3.
    base_space = probability.DiscreteProbabilitySpace({'h': 1, 't': 2})
    space = probability.FiniteProductSpace([base_space] * 5)

    heads = probability.DiscreteEvent({'h'})
    tails = probability.DiscreteEvent({'t'})
    event = probability.FiniteProductEvent([heads, heads, tails, tails, heads])
    self.assertEqual(space.probability(event), sympy.Rational(4, 3**5))

  def testProbability_CountLevelSetEvent(self):
    base_space = probability.DiscreteProbabilitySpace({'a': 2, 'b': 3, 'c': 5})
    space = probability.FiniteProductSpace([base_space] * 12)
    event = probability.CountLevelSetEvent({'a': 7, 'b': 2, 'c': 3})

    # Probability should be (12 choose 7 2 3) * p(a)^7 p(b)^2 p(c)^3
    coeff = 7920
    p_a = sympy.Rational(1, 5)
    p_b = sympy.Rational(3, 10)
    p_c = sympy.Rational(1, 2)
    self.assertEqual(space.probability(event),
                     coeff * pow(p_a, 7) * pow(p_b, 2) * pow(p_c, 3))


class SampleWithoutReplacementSpaceTest(absltest.TestCase):

  def testBasic(self):
    space = probability.SampleWithoutReplacementSpace({0: 1, 1: 1}, 2)
    event_0_0 = probability.FiniteProductEvent(
        [probability.DiscreteEvent({0}), probability.DiscreteEvent({0})])
    event_0_1 = probability.FiniteProductEvent(
        [probability.DiscreteEvent({0}), probability.DiscreteEvent({1})])
    p_0_0 = space.probability(event_0_0)
    p_0_1 = space.probability(event_0_1)
    self.assertEqual(p_0_0, 0)
    self.assertEqual(p_0_1, sympy.Rational(1, 2))

    space = probability.SampleWithoutReplacementSpace({0: 1, 1: 0}, 1)
    event_0 = probability.FiniteProductEvent([probability.DiscreteEvent({0})])
    event_1 = probability.FiniteProductEvent([probability.DiscreteEvent({1})])
    event_2 = probability.FiniteProductEvent([probability.DiscreteEvent({2})])
    p_0 = space.probability(event_0)
    p_1 = space.probability(event_1)
    p_2 = space.probability(event_2)
    self.assertEqual(p_0, 1)
    self.assertEqual(p_1, 0)
    self.assertEqual(p_2, 0)


class DiscreteRandomVariableTest(absltest.TestCase):

  def testCall(self):
    random_variable = probability.DiscreteRandomVariable({1: 1, 2: 3, 3: 4})
    forwards = random_variable(probability.DiscreteEvent({1, 3}))
    self.assertEqual(forwards.values, {1, 4})

  def testInverse(self):
    random_variable = probability.DiscreteRandomVariable({1: 1, 2: 3, 3: 4})
    inverse = random_variable.inverse(probability.DiscreteEvent({1, 3}))
    self.assertEqual(inverse.values, {1, 2})

    random_variable = probability.DiscreteRandomVariable({1: 1, 2: 1})
    inverse = random_variable.inverse(probability.DiscreteEvent({1, 5}))
    self.assertEqual(inverse.values, {1, 2})


class FiniteProductRandomVariableTest(absltest.TestCase):

  def _random_variable(self):
    rv1 = probability.DiscreteRandomVariable({1: 'a', 2: 'b', 3: 'c'})
    rv2 = probability.DiscreteRandomVariable({1: 'x', 2: 'y', 3: 'x'})
    return probability.FiniteProductRandomVariable((rv1, rv2))

  def testCall_FiniteProductEvent(self):
    rv = self._random_variable()
    event1 = probability.DiscreteEvent({1, 2})
    event2 = probability.DiscreteEvent({1, 3})
    event = probability.FiniteProductEvent((event1, event2))
    result = rv(event)
    self.assertIsInstance(result, probability.FiniteProductEvent)
    self.assertLen(result.events, 2)
    self.assertEqual(result.events[0].values, {'a', 'b'})
    self.assertEqual(result.events[1].values, {'x'})

  def testInverse_FiniteProductEvent(self):
    rv = self._random_variable()
    event1 = probability.DiscreteEvent({'a', 'b'})
    event2 = probability.DiscreteEvent({'x'})
    event = probability.FiniteProductEvent((event1, event2))
    result = rv.inverse(event)
    self.assertIsInstance(result, probability.FiniteProductEvent)
    self.assertLen(result.events, 2)
    self.assertEqual(result.events[0].values, {1, 2})
    self.assertEqual(result.events[1].values, {1, 3})

  def testInverse_CountLevelSetEvent(self):
    rv = self._random_variable()
    event = probability.CountLevelSetEvent({'a': 1, 'x': 1})
    result = rv.inverse(event)
    sequences = result.all_sequences()
    self.assertLen(sequences, 2)
    self.assertEqual(set(sequences), {(1, 1), (1, 3)})


if __name__ == '__main__':
  absltest.main()

# --------------------------------------------------------------------------------
# /setup.py:
# --------------------------------------------------------------------------------
# Copyright 2019 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module setuptools script."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from setuptools import find_packages
from setuptools import setup

# Long description shown on the package index page.
description = """A synthetic dataset of school-level mathematics questions.

This dataset code generates mathematical question and answer pairs, from a range
of question types (such as in arithmetic, algebra, probability, etc), at roughly
school-level difficulty. This is designed to test the mathematical learning and
reasoning skills of learning models.

Original paper: Analysing Mathematical Reasoning Abilities of Neural Models
(Saxton, Grefenstette, Hill, Kohli) (https://openreview.net/pdf?id=H1gR5iR5FX).
"""

# Runtime dependencies of the package.
install_requires = [
    'absl-py>=0.1.0',
    'numpy>=1.10',
    'six',
    'sympy>=1.2',
]

# Trove classifiers published with the package metadata.
classifiers = [
    'Development Status :: 4 - Beta',
    'Environment :: Console',
    'Intended Audience :: Developers',
    'Intended Audience :: Science/Research',
    'License :: OSI Approved :: Apache Software License',
    'Operating System :: POSIX :: Linux',
    'Programming Language :: Python :: 2.7',
    'Programming Language :: Python :: 3.4',
    'Programming Language :: Python :: 3.5',
    'Programming Language :: Python :: 3.6',
    'Programming Language :: Python :: 3.7',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
]

setup(
    name='mathematics_dataset',
    version='1.0.1',
    description='A synthetic dataset of school-level mathematics questions',
    long_description=description,
    author='DeepMind',
    author_email='saxton@google.com',
    license='Apache License, Version 2.0',
    keywords='mathematics dataset',
    url='https://github.com/deepmind/mathematics_dataset',
    packages=find_packages(),
    install_requires=install_requires,
    classifiers=classifiers,
)

# --------------------------------------------------------------------------------