├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── mathematics_dataset
│   ├── __init__.py
│   ├── example.py
│   ├── generate.py
│   ├── generate_settings.py
│   ├── generate_test.py
│   ├── generate_to_file.py
│   ├── modules
│   │   ├── __init__.py
│   │   ├── algebra.py
│   │   ├── algebra_test.py
│   │   ├── arithmetic.py
│   │   ├── arithmetic_test.py
│   │   ├── calculus.py
│   │   ├── calculus_test.py
│   │   ├── comparison.py
│   │   ├── measurement.py
│   │   ├── modules.py
│   │   ├── numbers.py
│   │   ├── polynomials.py
│   │   ├── probability.py
│   │   └── train_test_split.py
│   ├── sample
│   │   ├── __init__.py
│   │   ├── arithmetic.py
│   │   ├── arithmetic_test.py
│   │   ├── linear_system.py
│   │   ├── linear_system_test.py
│   │   ├── number.py
│   │   ├── number_test.py
│   │   ├── ops.py
│   │   ├── ops_test.py
│   │   ├── polynomials.py
│   │   └── polynomials_test.py
│   └── util
│       ├── __init__.py
│       ├── combinatorics.py
│       ├── combinatorics_test.py
│       ├── composition.py
│       ├── composition_test.py
│       ├── display.py
│       ├── display_test.py
│       ├── probability.py
│       └── probability_test.py
└── setup.py
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement. You (or your employer) retain the copyright to your contribution;
10 | this simply gives us permission to use and redistribute your contributions as
11 | part of the project. Head over to to see
12 | your current agreements on file or to sign a new one.
13 |
14 | You generally only need to submit a CLA once, so if you've already submitted one
15 | (even if it was for a different project), you probably don't need to do it
16 | again.
17 |
18 | ## Code reviews
19 |
20 | All submissions, including submissions by project members, require review. We
21 | use GitHub pull requests for this purpose. Consult
22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
23 | information on using pull requests.
24 |
25 | ## Community Guidelines
26 |
27 | This project follows
28 | [Google's Open Source Community Guidelines](https://opensource.google.com/conduct/).
29 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Mathematics Dataset
2 |
3 | This code generates mathematical question and answer pairs from a range of
4 | question types at roughly school-level difficulty. It is designed to test the
5 | mathematical learning and algebraic reasoning skills of learning models.
6 |
7 | Original paper: [Analysing Mathematical
8 | Reasoning Abilities of Neural Models](https://openreview.net/pdf?id=H1gR5iR5FX)
9 | (Saxton, Grefenstette, Hill, Kohli).
10 |
11 | ## Example questions
12 |
13 | ```
14 | Question: Solve -42*r + 27*c = -1167 and 130*r + 4*c = 372 for r.
15 | Answer: 4
16 |
17 | Question: Calculate -841880142.544 + 411127.
18 | Answer: -841469015.544
19 |
20 | Question: Let x(g) = 9*g + 1. Let q(c) = 2*c + 1. Let f(i) = 3*i - 39. Let w(j) = q(x(j)). Calculate f(w(a)).
21 | Answer: 54*a - 30
22 |
23 | Question: Let e(l) = l - 6. Is 2 a factor of both e(9) and 2?
24 | Answer: False
25 |
26 | Question: Let u(n) = -n**3 - n**2. Let e(c) = -2*c**3 + c. Let l(j) = -118*e(j) + 54*u(j). What is the derivative of l(a)?
27 | Answer: 546*a**2 - 108*a - 118
28 |
29 | Question: Three letters picked without replacement from qqqkkklkqkkk. Give prob of sequence qql.
30 | Answer: 1/110
31 | ```
32 |
33 | ## Pre-generated data
34 |
35 | [Pre-generated files](https://console.cloud.google.com/storage/browser/mathematics-dataset)
36 |
37 | ### Version 1.0
38 |
39 | This is the version released with the original paper. It contains 2 million
40 | (question, answer) pairs per module, with questions limited to 160 characters in
41 | length and answers to 30 characters in length. Note that the training data for
42 | each question type is split into "train-easy", "train-medium", and "train-hard",
43 | which allows training models via a curriculum. The three training splits can
44 | also be mixed together uniformly to obtain the results reported in the paper.
45 | Categories:
46 |
47 | * **algebra** (linear equations, polynomial roots, sequences)
48 | * **arithmetic** (pairwise operations and mixed expressions, surds)
49 | * **calculus** (differentiation)
50 | * **comparison** (closest numbers, pairwise comparisons, sorting)
51 | * **measurement** (conversion, working with time)
52 | * **numbers** (base conversion, remainders, common divisors and multiples,
53 | primality, place value, rounding numbers)
54 | * **polynomials** (addition, simplification, composition, evaluating, expansion)
55 | * **probability** (sampling without replacement)
56 |
57 | ## Getting the source
58 |
59 | ### PyPI
60 |
61 | The easiest way to get the source is to use pip:
62 |
63 | ```shell
64 | $ pip install mathematics_dataset
65 | ```
66 |
67 | ### From GitHub
68 |
69 | Alternatively, you can get the source by cloning the mathematics_dataset
70 | repository:
71 |
72 | ```shell
73 | $ git clone https://github.com/deepmind/mathematics_dataset
74 | $ pip install --upgrade mathematics_dataset/
75 | ```
76 |
77 | ## Generating examples
78 |
79 | Generated examples can be printed to stdout via the `generate` script. For
80 | example:
81 |
82 | ```shell
83 | python -m mathematics_dataset.generate --filter=linear_1d
84 | ```
85 |
86 | will generate example (question, answer) pairs for solving linear equations in
87 | one variable.
88 |
89 | We've also included `generate_to_file.py` as an example of how to write the
90 | generated examples to text files. You can use this directly, or adapt it for
91 | your generation and training needs.
92 |
93 | ## Dataset Metadata
94 | The following table is necessary for this dataset to be indexed by search
95 | engines such as Google Dataset Search.
96 |
97 | | property | value |
98 | | --- | --- |
99 | | name | Mathematics Dataset |
100 | | url | https://github.com/deepmind/mathematics_dataset |
101 | | sameAs | https://github.com/deepmind/mathematics_dataset |
102 | | description | This dataset consists of mathematical question and answer pairs, from a range of question types at roughly school-level difficulty. This is designed to test the mathematical learning and algebraic reasoning skills of learning models. Example questions: `Solve -42*r + 27*c = -1167 and 130*r + 4*c = 372 for r.` (answer: `4`); `Calculate -841880142.544 + 411127.` (answer: `-841469015.544`); `Let x(g) = 9*g + 1. Let q(c) = 2*c + 1. Let f(i) = 3*i - 39. Let w(j) = q(x(j)). Calculate f(w(a)).` (answer: `54*a - 30`). It contains 2 million (question, answer) pairs per module, with questions limited to 160 characters in length, and answers to 30 characters in length. Note the training data for each question type is split into "train-easy", "train-medium", and "train-hard". This allows training models via a curriculum. The data can also be mixed together uniformly from these training datasets to obtain the results reported in the paper. Categories: **algebra** (linear equations, polynomial roots, sequences); **arithmetic** (pairwise operations and mixed expressions, surds); **calculus** (differentiation); **comparison** (closest numbers, pairwise comparisons, sorting); **measurement** (conversion, working with time); **numbers** (base conversion, remainders, common divisors and multiples, primality, place value, rounding numbers); **polynomials** (addition, simplification, composition, evaluating, expansion); **probability** (sampling without replacement). |
103 | | provider | name: DeepMind; sameAs: https://en.wikipedia.org/wiki/DeepMind |
104 | | citation | https://identifiers.org/arxiv:1904.01557 |
105 |
--------------------------------------------------------------------------------
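The last example question in the README can be checked by hand: "qqqkkklkqkkk" holds four q's, seven k's and one l, so drawing q, q, l in that order without replacement has probability 4/12 * 3/11 * 1/10 = 1/110. The helper `sequence_probability` below is a hypothetical sketch (it is not part of mathematics_dataset) that does the same computation with exact fractions.

```python
from collections import Counter
from fractions import Fraction


def sequence_probability(letters, sequence):
  """Probability of drawing `sequence` in order, without replacement."""
  remaining = Counter(letters)
  total = len(letters)
  prob = Fraction(1)
  for letter in sequence:
    if remaining[letter] == 0:
      return Fraction(0)
    # Chance that the next draw is `letter`, given the letters already drawn.
    prob *= Fraction(remaining[letter], total)
    remaining[letter] -= 1
    total -= 1
  return prob


print(sequence_probability('qqqkkklkqkkk', 'qql'))  # 1/110
```

Using `Fraction` keeps the result in lowest terms, matching the answer format shown in the README.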
/mathematics_dataset/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/mathematics_dataset/example.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Containers for "[example] problems" (i.e., question/answer) pairs."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 |
23 | from mathematics_dataset.util import composition
24 |
25 |
26 | def question(context, template, **kwargs):
27 |   """Makes a question, using the given context and template.
28 |
29 |   The format is similar to that for python's `format` function, for example:
30 |
31 |   ```
32 |   question(context, 'What is {x} plus {p} over {q}?', x=2, p=3, q=4)
33 |   ```
34 |
35 |   The main difference between this and the standard python formatting is that
36 |   this understands `Entity`s in the arguments, and will do appropriate expansion
37 |   of text and prefixing of their descriptions.
38 |
39 |   Arguments:
40 |     context: Instance of `composition.Context`, for extracting entities needed
41 |         for describing the problem.
42 |     template: A string, like "Calculate the value of {exp}.".
43 |     **kwargs: A dictionary mapping arguments to values, e.g.,
44 |         `{'exp': sympy.Add(2, 3, evaluate=False)}`.
45 |
46 |   Returns:
47 |     String.
48 |   """
49 |   assert isinstance(context, composition.Context)
50 |   assert isinstance(template, str)
51 |   prefix, kwargs = composition.expand_entities(context, **kwargs)
52 |   if prefix:
53 |     prefix += ' '
54 |   return prefix + template.format(**kwargs)
55 |
56 |
57 | Problem = collections.namedtuple('Problem', ('question', 'answer'))
58 |
59 |
60 |
--------------------------------------------------------------------------------
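`example.py` pairs each question string with its answer in the `Problem` namedtuple, and `question` fills a template with keyword arguments after prepending any entity descriptions returned by `composition.expand_entities`. Because `composition` is not shown in this section, the sketch below replaces that call with an explicit `prefix` argument; `simple_question` is a hypothetical stand-in, not the library's function.

```python
import collections

# Same definition as in example.py.
Problem = collections.namedtuple('Problem', ('question', 'answer'))


def simple_question(template, prefix='', **kwargs):
  """Hypothetical stand-in for `example.question` without entity expansion.

  The real function obtains the prefix (and rewritten kwargs) from
  `composition.expand_entities`; here the caller supplies the prefix directly.
  """
  if prefix:
    prefix += ' '
  return prefix + template.format(**kwargs)


problem = Problem(
    question=simple_question('What is {p} over {q}?', p=3, q=4),
    answer='3/4')
print(problem.question)  # What is 3 over 4?
print(problem.answer)  # 3/4
```

Only keyword placeholders work with the real `question`, since its signature is `question(context, template, **kwargs)` and it accepts no extra positional values.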
/mathematics_dataset/generate.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Prints to stdout different curriculum questions."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 | import textwrap
23 |
24 | # Dependency imports
25 | from absl import app
26 | from absl import flags
27 | from absl import logging
28 | from mathematics_dataset import generate_settings
29 | from mathematics_dataset.modules import modules
30 | import six
31 | from six.moves import range
32 |
33 |
34 | FLAGS = flags.FLAGS
35 |
36 | flags.DEFINE_string('filter', '', 'restrict to matching module names')
37 | flags.DEFINE_integer('per_train_module', 10, 'Num of examples per train module')
38 | flags.DEFINE_integer('per_test_module', 10, 'Num of examples per test module')
39 | flags.DEFINE_bool('show_dropped', False, 'Whether to print dropped questions')
40 |
41 |
42 | filtered_modules = collections.OrderedDict([])
43 | counts = {}
44 |
45 |
46 | def _make_entropy_fn(level, num_levels):
47 |   """This returns a function that returns a subrange of entropy.
48 |
49 |   E.g., if level=1 (medium) and num_levels=3, then the returned function will
50 |   map the range [x, x + y] to [x + y/3, x + 2y/3].
51 |
52 |   Args:
53 |     level: Integer in range [0, num_levels - 1].
54 |     num_levels: Number of difficulty levels.
55 |
56 |   Returns:
57 |     Function to restrict entropy range.
58 |   """
59 |   lower = level / num_levels
60 |   upper = (level + 1) / num_levels
61 |   def modify_entropy(range_):
62 |     assert len(range_) == 2
63 |     length = range_[1] - range_[0]
64 |     return (range_[0] + lower * length, range_[0] + upper * length)
65 |   return modify_entropy
66 |
67 |
68 | def _filter_and_flatten(modules_):
69 |   """Returns flattened dict, filtered according to FLAGS."""
70 |   flat = collections.OrderedDict()
71 |
72 |   def add(submodules, prefix=None):
73 |     for key, module_or_function in six.iteritems(submodules):
74 |       full_name = prefix + '__' + key if prefix is not None else key
75 |       if isinstance(module_or_function, dict):
76 |         add(module_or_function, full_name)
77 |       else:
78 |         if FLAGS.filter not in full_name:
79 |           continue
80 |         flat[full_name] = module_or_function
81 |
82 |   add(modules_)
83 |
84 |   # Make sure the list of modules is in deterministic order. This is important
85 |   # when generating across multiple machines.
86 |   flat = collections.OrderedDict(
87 |       [(key, flat[key]) for key in sorted(six.iterkeys(flat))])
88 |
89 |   return flat
90 |
91 |
92 | def init_modules(train_split=False):
93 |   """Inits the dicts containing functions for generating modules."""
94 |   if filtered_modules:
95 |     return  # already initialized
96 |
97 |   all_modules = collections.OrderedDict([])
98 |   if train_split:
99 |     all_modules['train-easy'] = modules.train(_make_entropy_fn(0, 3))
100 |     all_modules['train-medium'] = modules.train(_make_entropy_fn(1, 3))
101 |     all_modules['train-hard'] = modules.train(_make_entropy_fn(2, 3))
102 |   else:
103 |     all_modules['train'] = modules.train(_make_entropy_fn(0, 1))
104 |
105 |   all_modules['interpolate'] = modules.test()
106 |   all_modules['extrapolate'] = modules.test_extra()
107 |
108 |   counts['train'] = FLAGS.per_train_module
109 |   counts['train-easy'] = FLAGS.per_train_module // 3
110 |   counts['train-medium'] = FLAGS.per_train_module // 3
111 |   counts['train-hard'] = FLAGS.per_train_module // 3
112 |   counts['interpolate'] = FLAGS.per_test_module
113 |   counts['extrapolate'] = FLAGS.per_test_module
114 |
115 |   for regime_, modules_ in six.iteritems(all_modules):
116 |     filtered_modules[regime_] = _filter_and_flatten(modules_)
117 |
118 |
119 | def sample_from_module(module):
120 |   """Samples a problem, ignoring samples with overly long questions / answers.
121 |
122 |   Args:
123 |     module: Callable returning a `Problem`.
124 |
125 |   Returns:
126 |     Pair `(problem, num_dropped)`, where `problem` is an instance of `Problem`
127 |     and `num_dropped` is an integer >= 0 indicating the number of samples that
128 |     were dropped.
129 |   """
130 |   num_dropped = 0
131 |   while True:
132 |     problem = module()
133 |     question = str(problem.question)
134 |     if len(question) > generate_settings.MAX_QUESTION_LENGTH:
135 |       num_dropped += 1
136 |       if FLAGS.show_dropped:
137 |         logging.warning('Dropping question: %s', question)
138 |       continue
139 |     answer = str(problem.answer)
140 |     if len(answer) > generate_settings.MAX_ANSWER_LENGTH:
141 |       num_dropped += 1
142 |       if FLAGS.show_dropped:
143 |         logging.warning('Dropping question with answer: %s', answer)
144 |       continue
145 |     return problem, num_dropped
146 |
147 |
148 | def main(unused_argv):
149 |   """Prints Q&As from modules according to FLAGS.filter."""
150 |   init_modules()
151 |
152 |   text_wrapper = textwrap.TextWrapper(
153 |       width=80, initial_indent=' ', subsequent_indent=' ')
154 |
155 |   for regime, flat_modules in six.iteritems(filtered_modules):
156 |     per_module = counts[regime]
157 |     for module_name, module in six.iteritems(flat_modules):
158 |       # These ANSI escape codes make the header bold.
159 |       print('\033[1m{}/{}\033[0m'.format(regime, module_name))
160 |       num_dropped = 0
161 |       for _ in range(per_module):
162 |         problem, extra_dropped = sample_from_module(module)
163 |         num_dropped += extra_dropped
164 |         text = text_wrapper.fill(
165 |             '{} \033[92m{}\033[0m'.format(problem.question, problem.answer))
166 |         print(text)
167 |       if num_dropped > 0:
168 |         logging.warning('Dropped %d examples', num_dropped)
169 |
170 |
171 | if __name__ == '__main__':
172 |   app.run(main)
173 |
--------------------------------------------------------------------------------
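When `init_modules(train_split=True)` is used, `_make_entropy_fn(level, 3)` gives train-easy, train-medium and train-hard equal, non-overlapping thirds of each module's entropy range. Note that `main` calls `init_modules()` with the default `train_split=False`, so the stdout script emits a single "train" regime over the full range via `_make_entropy_fn(0, 1)`. The sketch below mirrors the helper but uses `Fraction` instead of true division so the output is exact; the `(3, 9)` range is a made-up illustration, not a value taken from any module.

```python
from fractions import Fraction


def make_entropy_fn(level, num_levels):
  """Mirrors `generate._make_entropy_fn`, but with exact Fractions."""
  lower = Fraction(level, num_levels)
  upper = Fraction(level + 1, num_levels)

  def modify_entropy(range_):
    assert len(range_) == 2
    length = range_[1] - range_[0]
    # Keep only the level-th of num_levels equal slices of the range.
    return (range_[0] + lower * length, range_[0] + upper * length)

  return modify_entropy


# A hypothetical module entropy range of (3, 9), split into three regimes.
for level, regime in enumerate(['train-easy', 'train-medium', 'train-hard']):
  low, high = make_entropy_fn(level, 3)((3, 9))
  print(regime, low, high)
# train-easy 3 5
# train-medium 5 7
# train-hard 7 9
```

The example counts follow the same split: with the default `--per_train_module=10`, each of the three training regimes samples `10 // 3 = 3` examples per module.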
/mathematics_dataset/generate_settings.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Settings for generation."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import os
22 | import string
23 |
24 | MAX_QUESTION_LENGTH = 160
25 | MAX_ANSWER_LENGTH = 30
26 | QUESTION_CHARS = (
27 |     ['', ' '] + list(string.ascii_letters + string.digits + string.punctuation))
28 | EMPTY_INDEX = QUESTION_CHARS.index('')
29 | NUM_INDICES = len(QUESTION_CHARS)
30 | CHAR_TO_INDEX = {char: index for index, char in enumerate(QUESTION_CHARS)}
31 | INDEX_TO_CHAR = {index: char for index, char in enumerate(QUESTION_CHARS)}
32 |
33 |
--------------------------------------------------------------------------------
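`generate_settings.py` builds a 96-entry character vocabulary: the empty string (so `EMPTY_INDEX` is 0), a space, then ASCII letters, digits and punctuation. The settings file does not say how the vocabulary is consumed, so the `encode`/`decode` helpers below are hypothetical: they assume a model pipeline maps each character to its index and pads to `MAX_QUESTION_LENGTH` with `EMPTY_INDEX`. The length check mirrors the limit that `sample_from_module` in `generate.py` enforces by dropping overlong questions.

```python
import string

# Vocabulary built exactly as in generate_settings.py.
MAX_QUESTION_LENGTH = 160
QUESTION_CHARS = (
    ['', ' '] + list(string.ascii_letters + string.digits + string.punctuation))
EMPTY_INDEX = QUESTION_CHARS.index('')
CHAR_TO_INDEX = {char: index for index, char in enumerate(QUESTION_CHARS)}
INDEX_TO_CHAR = {index: char for index, char in enumerate(QUESTION_CHARS)}


def encode(text, max_length=MAX_QUESTION_LENGTH):
  """Hypothetical encoder: chars to indices, padded with EMPTY_INDEX."""
  if len(text) > max_length:
    raise ValueError('text longer than %d characters' % max_length)
  # Characters outside the vocabulary (e.g. newlines) raise KeyError here.
  indices = [CHAR_TO_INDEX[char] for char in text]
  return indices + [EMPTY_INDEX] * (max_length - len(indices))


def decode(indices):
  # EMPTY_INDEX maps back to '', so the padding disappears when decoding.
  return ''.join(INDEX_TO_CHAR[index] for index in indices)


encoded = encode('Calculate 2 + 3.')
print(len(QUESTION_CHARS))  # 96
print(len(encoded))  # 160
print(decode(encoded))  # Calculate 2 + 3.
```

Reserving index 0 for the empty string lets padding round-trip to nothing, which is one plausible reason `EMPTY_INDEX` exists alongside the space character.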
/mathematics_dataset/generate_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for mathematics_dataset.generate."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | # Dependency imports
22 | from absl.testing import absltest
23 | from absl.testing import parameterized
24 | from mathematics_dataset import generate
25 | import six
26 | from six.moves import range
27 |
28 |
29 | class GenerateTest(parameterized.TestCase):
30 |
31 | def testMakeEntropyFn(self):
32 | entropy_full = generate._make_entropy_fn(0, 1)
33 | self.assertEqual(entropy_full((2, 3)), (2, 3))
34 | entropy_third = generate._make_entropy_fn(2, 3)
35 | self.assertEqual(entropy_third((3, 6)), (5, 6))
36 |
37 | @parameterized.parameters('train', 'interpolate', 'extrapolate')
38 | def testGenerate(self, regime):
39 | generate.init_modules()
40 | for module in six.itervalues(generate.filtered_modules[regime]):
41 | for _ in range(3):
42 | question = module()
43 | str(question)
44 |
45 |
46 | if __name__ == '__main__':
47 | absltest.main()
48 |
--------------------------------------------------------------------------------
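`testMakeEntropyFn` constrains `generate._make_entropy_fn` with only two input/output pairs. The function sketched below is one implementation consistent with both, but it is an assumption, not the package's code. It splits the entropy range into `num_levels` equal slices and returns the slice at index `level`.

```python
def make_entropy_fn(level, num_levels):
  """Hypothetical re-implementation: returns a fn selecting one sub-range.

  The (low, high) entropy range is split into `num_levels` equal slices and
  the `level`-th slice (0-based) is returned.
  """
  def modify_entropy(entropy_range):
    low, high = entropy_range
    length = high - low
    return (low + length * level / num_levels,
            low + length * (level + 1) / num_levels)
  return modify_entropy


print(make_entropy_fn(0, 1)((2, 3)))  # (2.0, 3.0)
print(make_entropy_fn(2, 3)((3, 6)))  # (5.0, 6.0)
```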
/mathematics_dataset/generate_to_file.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Example of how to write generated questions to text files.
16 |
17 | Given an output directory, this will create the following subdirectories:
18 |
19 | * train-easy
20 | * train-medium
21 | * train-hard
22 | * interpolate
23 | * extrapolate
24 |
25 | and populate each of these directories with a text file for each of the modules,
26 | where the text file contains lines alternating between the question and the
27 | answer.
28 |
29 | Passing --train_split=False will create a single output directory 'train' for
30 | training data.
31 | """
32 |
33 | from __future__ import absolute_import
34 | from __future__ import division
35 | from __future__ import print_function
36 |
37 | import os
38 |
39 | # Dependency imports
40 | from absl import app
41 | from absl import flags
42 | from absl import logging
43 | from mathematics_dataset import generate
44 | import six
45 | from six.moves import range
46 |
47 | FLAGS = flags.FLAGS
48 |
49 | flags.DEFINE_string('output_dir', None, 'Where to write output text')
50 | flags.DEFINE_boolean('train_split', True,
51 | 'Whether to split training data by difficulty')
52 | flags.mark_flag_as_required('output_dir')
53 |
54 |
55 | def main(unused_argv):
56 | generate.init_modules(FLAGS.train_split)
57 |
58 | output_dir = os.path.expanduser(FLAGS.output_dir)
59 | if os.path.exists(output_dir):
60 | logging.fatal('output dir %s already exists', output_dir)
61 | logging.info('Writing to %s', output_dir)
62 | os.makedirs(output_dir)
63 |
64 | for regime, flat_modules in six.iteritems(generate.filtered_modules):
65 | regime_dir = os.path.join(output_dir, regime)
66 | os.mkdir(regime_dir)
67 | per_module = generate.counts[regime]
68 | for module_name, module in six.iteritems(flat_modules):
69 | path = os.path.join(regime_dir, module_name + '.txt')
70 | with open(path, 'w') as text_file:
71 | for _ in range(per_module):
72 | problem, _ = generate.sample_from_module(module)
73 | text_file.write(str(problem.question) + '\n')
74 | text_file.write(str(problem.answer) + '\n')
75 | logging.info('Written %s', path)
76 |
77 |
78 | if __name__ == '__main__':
79 | app.run(main)
80 |
--------------------------------------------------------------------------------
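`generate_to_file.py` writes each module's problems as alternating question and answer lines. The sketch below reads such a file back into pairs. It assumes no question or answer contains a newline, which holds for questions because `QUESTION_CHARS` contains no newline character. `read_problems` is a hypothetical name, and the sample text is made up for illustration.

```python
import io


def read_problems(text_file):
  """Hypothetical reader: pairs up alternating question/answer lines."""
  lines = [line.rstrip('\n') for line in text_file]
  if len(lines) % 2:
    raise ValueError('expected an even number of lines')
  return list(zip(lines[0::2], lines[1::2]))


sample = io.StringIO('Solve 2*x = 4 for x.\n2\nWhat is 3 + 4?\n7\n')
for question, answer in read_problems(sample):
  print(question, '->', answer)
```

For a generated file, pass an open handle instead, e.g. `with open(path) as f: read_problems(f)`, where `path` points at `<output_dir>/<regime>/<module_name>.txt`.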
/mathematics_dataset/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/mathematics_dataset/modules/algebra.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Algebra-related questions, e.g., "Solve 1 + x = 2."."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import functools
22 | import random
23 |
24 | # Dependency imports
25 | from mathematics_dataset import example
26 | from mathematics_dataset.sample import linear_system
27 | from mathematics_dataset.sample import number
28 | from mathematics_dataset.sample import ops
29 | from mathematics_dataset.sample import polynomials
30 | from mathematics_dataset.util import composition
31 | from mathematics_dataset.util import display
32 | import numpy as np
33 | from six.moves import range
34 | import sympy
35 |
36 |
37 | _ENTROPY_TRAIN = (3, 10)
38 | _ENTROPY_INTERPOLATE = (8, 8)
39 | _ENTROPY_EXTRAPOLATE = (12, 12)
40 |
41 | # In generating a polynomial with real roots (where the roots are generated
42 | # sequentially), this is the probability of taking a previous root, thus giving
43 | # at least one repeated root, rather than sampling a new number. The value is
44 | # somewhat arbitrary, but gives a "medium probability" of seeing a repeated root
45 | # for lowish degree polynomials.
46 | _POLY_PROBABILITY_REPEATED_ROOT = 0.2
47 |
48 |
49 | def _make_modules(entropy):
50 | """Returns modules given "difficulty" parameters."""
51 | sample_args_pure = composition.PreSampleArgs(1, 1, *entropy)
52 | sample_args_composed = composition.PreSampleArgs(2, 4, *entropy)
53 |
54 | return {
55 | # Solving equations:
56 | 'polynomial_roots': functools.partial(
57 | polynomial_roots, None, sample_args_pure),
58 | 'polynomial_roots_composed': functools.partial(
59 | polynomial_roots, None, sample_args_composed),
60 | 'linear_1d': functools.partial(
61 | solve_linear_1d, None, sample_args_pure),
62 | 'linear_1d_composed': functools.partial(
63 | solve_linear_1d, None, sample_args_composed),
64 | 'linear_2d': functools.partial(
65 | solve_linear_2d, None, sample_args_pure),
66 | 'linear_2d_composed': functools.partial(
67 | solve_linear_2d, None, sample_args_composed),
68 |
69 | # Sequences:
70 | 'sequence_next_term': functools.partial(sequence_next_term, *entropy),
71 | 'sequence_nth_term': functools.partial(sequence_nth_term, *entropy),
72 | }
73 |
74 |
75 | def train(entropy_fn):
76 | """Returns dict of training modules."""
77 | return _make_modules(entropy_fn(_ENTROPY_TRAIN))
78 |
79 |
80 | def test():
81 | """Returns dict of testing modules."""
82 | return _make_modules(_ENTROPY_INTERPOLATE)
83 |
84 |
85 | def test_extra():
86 | """Returns dict of extrapolation testing modules."""
87 | sample_args_pure = composition.PreSampleArgs(1, 1, *_ENTROPY_EXTRAPOLATE)
88 | return {
89 | 'polynomial_roots_big': functools.partial(
90 | polynomial_roots, None, sample_args_pure),
91 | }
92 |
93 |
94 | def _sample_roots(entropy):
95 | """Generates `num_distinct + num_repeated` polynomial roots."""
96 | num_roots = random.randint(2, 5)
97 |
98 | num_repeated = np.random.binomial(
99 | num_roots - 1, _POLY_PROBABILITY_REPEATED_ROOT)
100 | # Slight hack: don't allow all the roots to be repeated when the entropy is
101 | # high, as this can create very large coefficients.
102 | if entropy > 4:
103 | num_repeated = min(num_repeated, int(num_roots / 2))
104 |
105 | num_distinct = num_roots - num_repeated
106 |
107 | entropies = entropy * np.random.dirichlet(np.ones(num_distinct))
108 |
109 | roots = []
110 |
111 | for root_entropy in entropies:
112 | # Generates a root with small probability of being rational.
113 | # (Otherwise when we multiply out the denominators, we get really large
114 | # coefficients in our polynomial.)
115 | if random.random() < 0.1:
116 | root = number.non_integer_rational(root_entropy, True)
117 | else:
118 | root = number.integer(root_entropy, True)
119 | roots.append(root)
120 |
121 | for _ in range(num_repeated):
122 | roots.append(random.choice(roots[:num_distinct]))
123 |
124 | return roots
125 |
126 |
127 | def _polynomial_coeffs_with_roots(roots, scale_entropy):
128 | """Returns a polynomial with the given roots.
129 |
130 | The polynomial is generated by expanding product_{root in roots} (x - root),
131 | and then (1) scaling by the lcm of the coefficient denominators so they are all integers,
132 | and then (2) further scaling the coefficients by a random integer or rational
133 | with `scale_entropy` digits.
134 |
135 | Args:
136 | roots: List of values.
137 | scale_entropy: Float; entropy of the random coefficient scaling.
138 |
139 | Returns:
140 | List of coefficients `coeffs`, such that `coeffs[i]` is the coefficient of
141 | variable ** i.
142 | """
143 | variable = sympy.Symbol('x') # doesn't matter, only use coefficients
144 | polynomial = sympy.Poly(sympy.prod([variable - root for root in roots]))
145 | coeffs_reversed = polynomial.all_coeffs()
146 | assert len(coeffs_reversed) == len(roots) + 1
147 | coeffs = list(reversed(coeffs_reversed))
148 | # Multiply terms to change rationals to integers, and then maybe reintroduce.
149 | lcm = sympy.lcm([sympy.denom(coeff) for coeff in coeffs])
150 | if scale_entropy > 0:
151 | while True:
152 | scale = number.integer_or_rational(scale_entropy, signed=True)
153 | if scale != 0:
154 | break
155 | else:
156 | scale = 1
157 | return [coeff * scale * lcm for coeff in coeffs]
158 |
159 |
160 | def polynomial_roots(value, sample_args, context=None):
161 | """E.g., "Solve 2*x**2 - 18 = 0."."""
162 | del value # not currently used
163 | # is_question = context is None
164 | if context is None:
165 | context = composition.Context()
166 |
167 | entropy, sample_args = sample_args.peel()
168 | scale_entropy = min(entropy / 2, 1)
169 |
170 | roots = _sample_roots(entropy - scale_entropy)
171 | solutions = sorted(list(sympy.FiniteSet(*roots)))
172 | coeffs = _polynomial_coeffs_with_roots(roots, scale_entropy)
173 | (polynomial_entity,) = context.sample(
174 | sample_args, [composition.Polynomial(coeffs)])
175 |
176 | if random.choice([False, True]):
177 | # Ask for explicit roots.
178 | if len(solutions) == 1:
179 | answer = solutions[0]
180 | else:
181 | answer = display.NumberList(solutions)
182 |
183 | if polynomial_entity.has_expression():
184 | equality = ops.Eq(polynomial_entity.expression, 0)
185 | variable = polynomial_entity.polynomial_variables[0]
186 | else:
187 | variable = sympy.Symbol(context.pop())
188 | equality = ops.Eq(polynomial_entity.handle.apply(variable), 0)
189 | template = random.choice([
190 | 'Let {equality}. What is {variable}?',
191 | 'Let {equality}. Calculate {variable}.',
192 | 'Suppose {equality}. What is {variable}?',
193 | 'Suppose {equality}. Calculate {variable}.',
194 | 'What is {variable} in {equality}?',
195 | 'Solve {equality} for {variable}.',
196 | 'Find {variable} such that {equality}.',
197 | 'Find {variable}, given that {equality}.',
198 | 'Determine {variable} so that {equality}.',
199 | 'Determine {variable}, given that {equality}.',
200 | 'Solve {equality}.'
201 | ])
202 | return example.Problem(
203 | question=example.question(
204 | context, template, equality=equality, variable=variable),
205 | answer=answer)
206 | else:
207 | if polynomial_entity.has_expression():
208 | expression = polynomial_entity.expression
209 | variable = polynomial_entity.polynomial_variables[0]
210 | else:
211 | variable = sympy.Symbol(context.pop())
212 | expression = polynomial_entity.handle.apply(variable)
213 | factored = sympy.factor(
214 | polynomials.coefficients_to_polynomial(coeffs, variable))
215 | template = random.choice([
216 | 'Factor {expression}.',
217 | ])
218 | return example.Problem(
219 | question=example.question(context, template, expression=expression),
220 | answer=factored)
221 |
222 |
223 | def _solve_linear_system(degree, value, sample_args, context=None):
224 | """Solve linear equations."""
225 | is_question = context is None
226 | if context is None:
227 | context = composition.Context()
228 |
229 | entropy, sample_args = sample_args.peel()
230 |
231 | solutions = []
232 | if value is not None:
233 | solutions.append(value)
234 |
235 | extra_solutions_needed = degree - len(solutions)
236 | if extra_solutions_needed > 0:
237 | entropies = (entropy / 4) * np.random.dirichlet(
238 | np.ones(extra_solutions_needed))
239 | entropies = np.maximum(1, entropies) # min per-solution entropy
240 | entropy -= sum(entropies)
241 | solutions += [number.integer(solution_entropy, True)
242 | for solution_entropy in entropies]
243 | entropy = max(1, entropy)
244 |
245 | variables = [sympy.Symbol(context.pop()) for _ in range(degree)]
246 |
247 | solution_index = 0
248 | # If we're going to be creating a linear system with constants to replace by
249 | # handles from other modules, then we need a linear system with constants
250 | # occurring. Very occasionally this can fail to happen, e.g., "x = -x";
251 | # normally this while loop will only see one iteration.
252 | while True:
253 | equations = linear_system.linear_system(
254 | variables=variables, solutions=solutions, entropy=entropy,
255 | non_trivial_in=solution_index)
256 | constants = ops.number_constants(equations)
257 | if sample_args.num_modules <= 1 or constants:
258 | break
259 |
260 | context.sample_by_replacing_constants(sample_args, equations)
261 |
262 | variable = variables[solution_index]
263 | answer = solutions[solution_index]
264 |
265 | equations = ', '.join([str(equation) for equation in equations])
266 |
267 | if is_question:
268 | template = random.choice([
269 | 'Solve {equations} for {variable}.',
270 | ])
271 | return example.Problem(
272 | example.question(
273 | context, template, equations=equations,
274 | variable=variable),
275 | answer)
276 | else:
277 | return composition.Entity(
278 | context=context,
279 | value=answer,
280 | description='Suppose {equations}.',
281 | handle=variable,
282 | equations=equations)
283 |
284 |
285 | @composition.module(number.is_integer)
286 | def solve_linear_1d(*args, **kwargs):
287 | return _solve_linear_system(1, *args, **kwargs)
288 |
289 |
290 | @composition.module(number.is_integer)
291 | def solve_linear_2d(*args, **kwargs):
292 | return _solve_linear_system(2, *args, **kwargs)
293 |
294 |
295 | class _PolynomialSequence(object):
296 | """A sequence given by a polynomial."""
297 |
298 | def __init__(self, variable, entropy, min_degree=1, max_degree=3):
299 | """Initializes a random polynomial sequence.
300 |
301 | Args:
302 | variable: Variable to use.
303 | entropy: Entropy for polynomial coefficients.
304 | min_degree: Minimum order of polynomial.
305 | max_degree: Maximum order of polynomial.
306 | """
307 | self._degree = random.randint(min_degree, max_degree)
308 | self._variable = variable
309 | polynomial = polynomials.sample_with_small_evaluation(
310 | variable=self._variable, degree=self._degree,
311 | max_abs_input=self._degree + 2, entropy=entropy)
312 | self._sympy = polynomial.sympy()
313 |
314 | @property
315 | def min_num_terms(self):
316 | """Returns the minimum number of terms to identify the sequence.
317 |
318 | This assumes a human-like prior over types of sequences.
319 |
320 | Returns:
321 | Integer >= 1.
322 | """
323 | return self._degree + 2
324 |
325 | @property
326 | def sympy(self):
327 | return self._sympy
328 |
329 | def term(self, n):
330 | """Returns the `n`th term of the sequence."""
331 | return self._sympy.subs(self._variable, n)
332 |
333 |
334 | def sequence_next_term(min_entropy, max_entropy):
335 | """E.g., "What is the next term in the sequence 1, 2, 3?"."""
336 | entropy = random.uniform(min_entropy, max_entropy)
337 | context = composition.Context()
338 | variable = sympy.Symbol(context.pop())
339 |
340 | sequence = _PolynomialSequence(variable, entropy)
341 | min_num_terms = sequence.min_num_terms
342 | num_terms = random.randint(min_num_terms, min_num_terms + 3)
343 | sequence_sample = [sequence.term(n + 1) for n in range(num_terms)]
344 | sequence_sample = display.NumberList(sequence_sample)
345 |
346 | template = random.choice([
347 | 'What is next in {sequence}?',
348 | 'What comes next: {sequence}?',
349 | 'What is the next term in {sequence}?',
350 | ])
351 | answer = sequence.term(num_terms + 1)
352 |
353 | return example.Problem(
354 | question=example.question(context, template, sequence=sequence_sample),
355 | answer=answer)
356 |
357 |
358 | def sequence_nth_term(min_entropy, max_entropy):
359 | """E.g., "What is the nth term in the sequence 1, 2, 3?"."""
360 | entropy = random.uniform(min_entropy, max_entropy)
361 | context = composition.Context()
362 | variable = sympy.Symbol(context.pop())
363 |
364 | sequence = _PolynomialSequence(variable, entropy)
365 | min_num_terms = sequence.min_num_terms
366 | num_terms = random.randint(min_num_terms, min_num_terms + 3)
367 | sequence_sample = [sequence.term(n + 1) for n in range(num_terms)]
368 | sequence_sample = display.NumberList(sequence_sample)
369 |
370 | template = random.choice([
371 | 'What is the {variable}\'th term of {sequence}?',
372 | ])
373 | answer = sequence.sympy
374 |
375 | return example.Problem(
376 | question=example.question(
377 | context, template, variable=variable, sequence=sequence_sample),
378 | answer=answer)
379 |
--------------------------------------------------------------------------------
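`_PolynomialSequence.min_num_terms` returns `degree + 2`. A polynomial of degree d is fixed by d + 1 values, so one plausible reading is that the extra term lets a reader confirm the pattern. The module computes answers with sympy. As an independent check, the sketch below uses a different technique, finite differences, to extrapolate the next term. `next_term` is a hypothetical helper.

```python
def next_term(terms):
  """Extrapolates a polynomial sequence by finite differences (hypothetical).

  Repeatedly takes differences until a row is constant, then sums the last
  entry of every row; that sum is the next term. A one-element row is
  trivially constant, so with too few terms this returns the lowest-degree
  interpolant rather than raising.
  """
  rows = [list(terms)]
  if not rows[0]:
    raise ValueError('need at least one term')
  while len(set(rows[-1])) > 1:
    rows.append([b - a for a, b in zip(rows[-1], rows[-1][1:])])
  return sum(row[-1] for row in rows)


# 2*n**2 + 1 for n = 1..5; the differences are [6, 10, 14, 18], then [4, 4, 4].
print(next_term([3, 9, 19, 33, 51]))  # 73
```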
/mathematics_dataset/modules/algebra_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for mathematics_dataset.modules.algebra."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import random
22 |
23 | # Dependency imports
24 | from absl.testing import absltest
25 | from mathematics_dataset.modules import algebra
26 | from mathematics_dataset.sample import polynomials
27 | from six.moves import range
28 | import sympy
29 |
30 |
31 | class AlgebraTest(absltest.TestCase):
32 |
33 | def testPolynomialCoeffsWithRoots(self):
34 | coeffs = algebra._polynomial_coeffs_with_roots([1, 2], scale_entropy=0.0)
35 | self.assertEqual(coeffs, [2, -3, 1])
36 |
37 | def testPolynomialRoots(self):
38 | variable = sympy.Symbol('x')
39 | for _ in range(10):
40 | roots = random.sample(list(range(-9, 10)), 3)
41 | coeffs = algebra._polynomial_coeffs_with_roots(roots, scale_entropy=10.0)
42 | polynomial = polynomials.coefficients_to_polynomial(coeffs, variable)
43 | calc_roots = sympy.polys.polytools.real_roots(polynomial)
44 | self.assertEqual(calc_roots, sorted(roots))
45 |
46 |
47 | if __name__ == '__main__':
48 | absltest.main()
49 |
--------------------------------------------------------------------------------
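The expected value `[2, -3, 1]` in `testPolynomialCoeffsWithRoots` follows from expanding (x - 1)(x - 2) = x**2 - 3*x + 2 and listing coefficients from the constant term up. The sketch below reproduces the deterministic part of `_polynomial_coeffs_with_roots` (the `scale_entropy=0` case) in pure Python, using `fractions.Fraction` instead of sympy. `coeffs_with_roots` is a hypothetical name, and the random rescaling is deliberately left out.

```python
from fractions import Fraction
from math import gcd


def coeffs_with_roots(roots):
  """Expands prod(x - root) and clears denominators (no random rescaling).

  Returns integer coefficients in increasing order of power, i.e. coeffs[i]
  multiplies x**i.
  """
  coeffs = [Fraction(1)]  # the constant polynomial 1
  for root in roots:
    root = Fraction(root)
    # Multiply by (x - root): shift every power up by one, then subtract
    # root times the old coefficients.
    shifted = [Fraction(0)] + coeffs
    scaled = [-root * c for c in coeffs] + [Fraction(0)]
    coeffs = [a + b for a, b in zip(shifted, scaled)]
  # Scale by the lcm of the denominators so every coefficient is an integer.
  lcm = 1
  for c in coeffs:
    lcm = lcm * c.denominator // gcd(lcm, c.denominator)
  return [int(c * lcm) for c in coeffs]


print(coeffs_with_roots([1, 2]))               # [2, -3, 1]
print(coeffs_with_roots([Fraction(1, 2), 3]))  # [3, -7, 2]
```

The second call shows the denominator clearing: (x - 1/2)(x - 3) = x**2 - 7/2*x + 3/2, which becomes 2*x**2 - 7*x + 3 after scaling by 2.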
/mathematics_dataset/modules/arithmetic_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for mathematics_dataset.modules.arithmetic."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | # Dependency imports
22 | from absl.testing import absltest
23 | from mathematics_dataset.modules import arithmetic
24 | import sympy
25 |
26 |
27 | class ArithmeticTest(absltest.TestCase):
28 |
29 | def testSurdCoefficients(self):
30 | exp = sympy.sympify('1')
31 | self.assertEqual(arithmetic._surd_coefficients(exp),
32 | (1, 0))
33 |
34 | exp = sympy.sympify('1/2')
35 | self.assertEqual(arithmetic._surd_coefficients(exp),
36 | (1/2, 0))
37 |
38 | exp = sympy.sympify('sqrt(2)')
39 | self.assertEqual(arithmetic._surd_coefficients(exp),
40 | (0, 1))
41 |
42 | exp = sympy.sympify('3*sqrt(2)')
43 | self.assertEqual(arithmetic._surd_coefficients(exp),
44 | (0, 3))
45 |
46 | exp = sympy.sympify('3*sqrt(5)/2')
47 | self.assertEqual(arithmetic._surd_coefficients(exp),
48 | (0, 3/2))
49 |
50 | exp = sympy.sympify('1 + 3 * sqrt(2)')
51 | self.assertEqual(arithmetic._surd_coefficients(exp),
52 | (1, 3))
53 |
54 | exp = sympy.sympify('1/2 + 3 * sqrt(5) / 2')
55 | self.assertEqual(arithmetic._surd_coefficients(exp),
56 | (1/2, 3/2))
57 |
58 | exp = sympy.sympify('sqrt(2)/(-1 + 2*sqrt(2))**2')
59 | self.assertEqual(arithmetic._surd_coefficients(exp),
60 | (8/49, 9/49))
61 |
62 |
63 | if __name__ == '__main__':
64 | absltest.main()
65 |
--------------------------------------------------------------------------------
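The expected pairs in `testSurdCoefficients` can be checked by hand. For the last case, (-1 + 2*sqrt(2))**2 = 9 - 4*sqrt(2). Multiplying numerator and denominator by the conjugate then gives sqrt(2)*(9 + 4*sqrt(2))/49 = 8/49 + 9/49*sqrt(2). The sketch below does the same arithmetic with exact fractions. `QuadraticSurd` is a hypothetical helper, not how `arithmetic._surd_coefficients` works (the test passes it sympy expressions).

```python
from fractions import Fraction


class QuadraticSurd(object):
  """Hypothetical number a + b*sqrt(d) with rational a, b and fixed d."""

  def __init__(self, a, b, d):
    self.a, self.b, self.d = Fraction(a), Fraction(b), d

  def __mul__(self, other):
    # (a + b*sqrt(d)) * (c + e*sqrt(d)) = (a*c + b*e*d) + (a*e + b*c)*sqrt(d)
    return QuadraticSurd(self.a * other.a + self.b * other.b * self.d,
                         self.a * other.b + self.b * other.a, self.d)

  def inverse(self):
    # Multiply by the conjugate: 1/(a + b*sqrt(d)) = (a - b*sqrt(d)) / (a**2 - d*b**2).
    norm = self.a ** 2 - self.d * self.b ** 2
    return QuadraticSurd(self.a / norm, -self.b / norm, self.d)

  def coefficients(self):
    return (self.a, self.b)


# sqrt(2) / (-1 + 2*sqrt(2))**2, the last case in testSurdCoefficients.
sqrt2 = QuadraticSurd(0, 1, 2)
base = QuadraticSurd(-1, 2, 2)
result = sqrt2 * (base * base).inverse()
print(result.coefficients())  # (Fraction(8, 49), Fraction(9, 49))
```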
/mathematics_dataset/modules/calculus.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Calculus-related questions, e.g., "differentiate x**2"."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import functools
22 | import math
23 | import random
24 |
25 | # Dependency imports
26 | from mathematics_dataset import example
27 | from mathematics_dataset.sample import polynomials
28 | from mathematics_dataset.util import composition
29 | from mathematics_dataset.util import display
30 | import numpy as np
31 | from six.moves import range
32 | import sympy
33 |
34 |
35 | _ENTROPY_TRAIN = (3, 10)
36 | _ENTROPY_INTERPOLATE = (8, 8)
37 |
38 |
39 | def _make_modules(entropy):
40 | """Returns modules given "difficulty" parameters."""
41 | sample_args_pure = composition.PreSampleArgs(1, 1, *entropy)
42 | sample_args_composed = composition.PreSampleArgs(2, 4, *entropy)
43 |
44 | return {
45 | 'differentiate_composed': functools.partial(
46 | differentiate_univariate, None, sample_args_composed),
47 | 'differentiate': functools.partial(differentiate, None, sample_args_pure),
48 | }
49 |
50 |
51 | def train(entropy_fn):
52 | """Returns dict of training modules."""
53 | return _make_modules(entropy_fn(_ENTROPY_TRAIN))
54 |
55 |
56 | def test():
57 | """Returns dict of testing modules."""
58 | return _make_modules(_ENTROPY_INTERPOLATE)
59 |
60 |
61 | def test_extra():
62 | """Returns dict of extrapolation testing modules."""
63 | return {
64 | }
65 |
66 |
67 | def _generate_polynomial(num_variables, entropy, derivative_order,
68 | derivative_axis):
69 | """Returns polynomial."""
70 | # Note: numpy's randint excludes the upper bound, unlike Python's random.randint.
71 | degrees = np.random.randint(1, 4, [num_variables])
72 | degrees[derivative_axis] = np.random.randint(0, 4) # allow to be zero here.
73 |
74 | coefficients = polynomials.sample_coefficients(degrees, entropy)
75 |
76 | # We also generate coefficients that will disappear when differentiated.
77 | # Thus we don't account for the entropy used here.
78 | assert derivative_order > 0
79 | degrees[derivative_axis] = derivative_order - 1
80 | extra_coefficients = polynomials.sample_coefficients(degrees, entropy)
81 |
82 | return np.concatenate(
83 | [extra_coefficients, coefficients], axis=derivative_axis)
84 |
85 |
86 | def _template(module_count, derivative_order, num_variables):
87 | """Selects appropriate template."""
88 | templates = [
89 | 'Find the {nth} derivative of {eq} wrt {var}.',
90 | 'What is the {nth} derivative of {eq} wrt {var}?',
91 | ]
92 | if derivative_order == 1:
93 | templates += [
94 | 'Differentiate {eq} with respect to {var}.',
95 | 'Differentiate {eq} wrt {var}.',
96 | 'What is the derivative of {eq} wrt {var}?',
97 | ]
98 |
99 | derivative_variable_is_unambiguous = num_variables == 1 and module_count == 1
100 | if derivative_variable_is_unambiguous:
101 | templates += [
102 | 'Find the {nth} derivative of {eq}.',
103 | 'What is the {nth} derivative of {eq}?',
104 | ]
105 | if derivative_order == 1:
106 | templates += [
107 | 'Differentiate {eq}.',
108 | 'What is the derivative of {eq}?',
109 | ]
110 |
111 | return random.choice(templates)
112 |
113 |
114 | def _sample_integrand(coefficients, derivative_order, derivative_axis, entropy):
115 | """Integrates `coefficients` and adds sampled "constant" terms."""
116 | coefficients = np.asarray(coefficients)
117 |
118 | # Integrate (with zero for constant terms).
119 | integrand = coefficients
120 | for _ in range(derivative_order):
121 | integrand = polynomials.integrate(integrand, derivative_axis)
122 |
123 | # Add on sampled constant terms.
124 | constant_degrees = np.array(integrand.shape) - 1
125 | constant_degrees[derivative_axis] = derivative_order - 1
126 | extra_coeffs = polynomials.sample_coefficients(constant_degrees, entropy)
127 | pad_amount = coefficients.shape[derivative_axis]
128 | pad = [(0, pad_amount if i == derivative_axis else 0)
129 | for i in range(coefficients.ndim)]
130 | extra_coeffs = np.pad(extra_coeffs, pad, 'constant', constant_values=0)
131 | return integrand + extra_coeffs
132 |
133 |
134 | def _differentiate_polynomial(value, sample_args, context, num_variables):
135 | """Generates a question for differentiating a polynomial."""
136 | is_question = context is None
137 | if context is None:
138 | context = composition.Context()
139 |
140 | if value is not None:
141 | num_variables = value.coefficients.ndim
142 |
143 | entropy, sample_args = sample_args.peel()
144 | max_derivative_order = 3
145 | derivative_order = random.randint(1, max_derivative_order)
146 | entropy = max(0, entropy - math.log10(max_derivative_order))
147 |
148 | derivative_axis = random.randint(0, num_variables - 1)
149 | if value is None:
150 | coefficients = _generate_polynomial(
151 | num_variables, entropy, derivative_order, derivative_axis)
152 | else:
153 | coefficients = _sample_integrand(
154 | value.coefficients, derivative_order, derivative_axis, entropy)
155 |
156 | (entity,) = context.sample(
157 | sample_args, [composition.Polynomial(coefficients)])
158 |
159 | value = coefficients
160 | for _ in range(derivative_order):
161 | value = polynomials.differentiate(value, axis=derivative_axis)
162 | nth = display.StringOrdinal(derivative_order)
163 |
164 | if entity.has_expression():
165 | polynomial = entity.expression
166 | variables = entity.polynomial_variables
167 | else:
168 | variables = [sympy.Symbol(context.pop()) for _ in range(num_variables)]
169 | polynomial = entity.handle.apply(*variables)
170 | variable = variables[derivative_axis]
171 |
172 | if is_question:
173 | template = _template(context.module_count, derivative_order, len(variables))
174 | answer = polynomials.coefficients_to_polynomial(value, variables).sympy()
175 | return example.Problem(
176 | question=example.question(
177 | context, template, eq=polynomial, var=variable, nth=nth),
178 | answer=answer)
179 | else:
180 | fn_symbol = context.pop()
181 | variables_string = ', '.join(str(variable) for variable in variables)
182 | assert len(variables) == 1 # since below we don't specify var we diff wrt
183 | return composition.Entity(
184 | context=context,
185 | value=composition.Polynomial(value),
186 | description='Let {fn}({variables}) be the {nth} derivative of {eq}.',
187 | handle=composition.FunctionHandle(fn_symbol),
188 | fn=fn_symbol, variables=variables_string, nth=nth, eq=polynomial)
189 |
190 |
191 | def differentiate_univariate(value, sample_args, context=None):
192 | return _differentiate_polynomial(value, sample_args, context, 1)
193 |
194 |
195 | @composition.module(composition.is_polynomial)
196 | def differentiate(value, sample_args, context=None):
197 | num_variables = random.randint(1, 4)
198 | return _differentiate_polynomial(value, sample_args, context, num_variables)
199 |
--------------------------------------------------------------------------------
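`_generate_polynomial` pads the coefficient array along `derivative_axis` with `derivative_order` extra slices of low-degree coefficients. Its comment notes that these disappear when differentiated, so their entropy is not counted. The univariate sketch below shows why. It assumes coefficients are stored in increasing order of power, as elsewhere in the module. `differentiate` here is a hypothetical stand-in for `polynomials.differentiate`, which works on arrays along an axis.

```python
def differentiate(coeffs, order=1):
  """Differentiates a univariate polynomial `order` times.

  coeffs[i] is the coefficient of x**i. Each pass multiplies the coefficient
  of x**i by i and drops the (now zero) constant term.
  """
  for _ in range(order):
    coeffs = [i * c for i, c in enumerate(coeffs)][1:]
  return coeffs


# 7 + 5*x + 4*x**2 + x**3: with order 2 the "extra" terms 7 and 5*x vanish.
print(differentiate([7, 5, 4, 1], order=2))  # [8, 6], i.e. 8 + 6*x
```

Any values in the first `order` positions give the same result, which is why their sampling entropy does not affect the answer.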
/mathematics_dataset/modules/calculus_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for mathematics_dataset.modules.calculus."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | # Dependency imports
22 | from mathematics_dataset.modules import calculus
23 | import tensorflow as tf
24 |
25 |
26 | class CalculusTest(tf.test.TestCase):
27 |
28 | def testSampleIntegrand(self):
29 | # y + 2*x + 3*x**2
30 | coefficients = [[0, 1], [2, 0], [3, 0]]
31 | derivative_order = 1
32 | derivative_axis = 0
33 | # const + x*y + x**2 + x**3
34 | expected = [[0, 1], [1, 0], [1, 0]]
35 | entropy = 4
36 | result = calculus._sample_integrand(
37 | coefficients, derivative_order, derivative_axis, entropy)
38 | result = result[1:, :] # ignore random constant terms
39 | self.assertAllEqual(result, expected)
40 |
41 |
42 | if __name__ == '__main__':
43 | tf.test.main()
44 |
--------------------------------------------------------------------------------
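`testSampleIntegrand` encodes polynomials as 2-D coefficient grids. As its comments indicate, entry `[i][j]` is the coefficient of `x**i * y**j`. Integrating once along axis 0 shifts every row down by one and divides old row `i` by `i + 1`. The sketch below reproduces the test's expected value under that reading. `integrate_coefficients` is a hypothetical helper that uses exact `Fraction` arithmetic and fixes the integration constants at zero. By contrast, the test's comment indicates that `calculus._sample_integrand` fills those constant terms randomly and also takes an entropy budget.

```python
from fractions import Fraction


def integrate_coefficients(coefficients, order=1):
  """Integrates a 2-D coefficient grid `order` times along axis 0.

  Hypothetical helper: coefficients[i][j] is taken to be the coefficient of
  x**i * y**j, and the new constant rows (functions of y only) are zero.
  """
  result = [[Fraction(c) for c in row] for row in coefficients]
  width = len(result[0])
  for _ in range(order):
    # Each x**i term becomes x**(i + 1) / (i + 1); prepend a zero constant row.
    result = [[Fraction(0)] * width] + [
        [c / (i + 1) for c in row] for i, row in enumerate(result)]
  return result


# y + 2*x + 3*x**2, as in testSampleIntegrand.
integrand = integrate_coefficients([[0, 1], [2, 0], [3, 0]])
print(integrand[1:] == [[0, 1], [1, 0], [1, 0]])  # True (x*y + x**2 + x**3)
```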
/mathematics_dataset/modules/comparison.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Comparisons, e.g. "is 2 > 3?"."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import functools
22 | import random
23 |
24 | # Dependency imports
25 | from mathematics_dataset import example
26 | from mathematics_dataset.sample import number
27 | from mathematics_dataset.sample import ops
28 | from mathematics_dataset.util import composition
29 | from mathematics_dataset.util import display
30 | import numpy as np
31 | from six.moves import range
32 | import sympy
33 |
34 |
35 | _ENTROPY_TRAIN = (3, 10)
36 | _ENTROPY_INTERPOLATE = (8, 8)
37 | _ENTROPY_EXTRAPOLATE = (12, 12)
38 |
39 | _EXTRAPOLATION_EXTRA_COUNT = 2
40 |
41 | _PROB_EQUAL = 0.2
42 |
43 |
44 | def _make_modules(entropy):
45 | """Returns modules given "difficulty" parameters."""
46 | sample_args_pure = composition.PreSampleArgs(1, 1, *entropy)
47 | sample_args_composed = composition.PreSampleArgs(2, 4, *entropy)
48 |
49 | return {
50 | 'pair': functools.partial(pair, sample_args_pure),
51 | 'pair_composed': functools.partial(pair, sample_args_composed),
52 | 'kth_biggest': functools.partial(kth_biggest, sample_args_pure),
53 | 'kth_biggest_composed': functools.partial(
54 | kth_biggest, sample_args_composed),
55 | 'closest': functools.partial(closest, sample_args_pure),
56 | 'closest_composed': functools.partial(closest, sample_args_composed),
57 | 'sort': functools.partial(sort, sample_args_pure),
58 | 'sort_composed': functools.partial(sort, sample_args_composed),
59 | }
60 |
61 |
62 | def train(entropy_fn):
63 | """Returns dict of training modules."""
64 | return _make_modules(entropy_fn(_ENTROPY_TRAIN))
65 |
66 |
67 | def test():
68 | """Returns dict of testing modules."""
69 | return _make_modules(_ENTROPY_INTERPOLATE)
70 |
71 |
72 | def test_extra():
73 | """Returns dict of extrapolation testing modules."""
74 | sample_args_pure = composition.PreSampleArgs(1, 1, *_ENTROPY_EXTRAPOLATE)
75 |
76 | def sort_count():
77 | lower = _sort_count_range(_ENTROPY_TRAIN[1])[1]
78 | return random.randint(lower + 1, lower + _EXTRAPOLATION_EXTRA_COUNT)
79 | def closest_count():
80 | lower = _closest_count_range(_ENTROPY_TRAIN[1])[1]
81 | return random.randint(lower + 1, lower + _EXTRAPOLATION_EXTRA_COUNT)
82 | def kth_biggest_more():
83 | return kth_biggest(sample_args_pure, count=sort_count())
84 | def sort_more():
85 | return sort(sample_args_pure, count=sort_count())
86 | def closest_more():
87 | return closest(sample_args_pure, count=closest_count())
88 |
89 | return {
90 | 'kth_biggest_more': kth_biggest_more,
91 | 'sort_more': sort_more,
92 | 'closest_more': closest_more,
93 | }
94 |
95 |
96 | def _make_comparison_question(context, left, right):
97 | """Makes a question for comparing two values."""
98 | if random.choice([False, True]) and sympy.Ne(left.value, right.value):
99 | # Do question of form: "Which is bigger: a or b?".
100 | if random.choice([False, True]):
101 | answer = (
102 | left.handle if sympy.Gt(left.value, right.value) else right.handle)
103 | template = random.choice([
104 | 'Which is bigger: {left} or {right}?',
105 | 'Which is greater: {left} or {right}?',
106 | ])
107 | else:
108 | answer = (
109 | left.handle if sympy.Lt(left.value, right.value) else right.handle)
110 | template = random.choice([
111 | 'Which is smaller: {left} or {right}?',
112 | ])
113 | return example.Problem(
114 | question=example.question(context, template, left=left, right=right),
115 | answer=answer)
116 |
117 | comparisons = {
118 | '<': sympy.Lt,
119 | '<=': sympy.Le,
120 | '>': sympy.Gt,
121 | '>=': sympy.Ge,
122 | '=': sympy.Eq,
123 | '!=': sympy.Ne,
124 | }
125 |
126 | templates = {
127 | '<': [
128 | 'Is {left} ' + ops.LT_SYMBOL + ' {right}?',
129 | 'Is {left} less than {right}?',
130 | 'Is {left} smaller than {right}?',
131 | ],
132 | '<=': [
133 | 'Is {left} ' + ops.LE_SYMBOL + ' {right}?',
134 | 'Is {left} less than or equal to {right}?',
135 | 'Is {left} at most {right}?',
136 | 'Is {left} at most as big as {right}?',
137 | ],
138 | '>': [
139 | 'Is {left} ' + ops.GT_SYMBOL + ' {right}?',
140 | 'Is {left} greater than {right}?',
141 | 'Is {left} bigger than {right}?',
142 | ],
143 | '>=': [
144 | 'Is {left} ' + ops.GE_SYMBOL + ' {right}?',
145 | 'Is {left} greater than or equal to {right}?',
146 | 'Is {left} at least {right}?',
147 | 'Is {left} at least as big as {right}?',
148 | ],
149 | '=': [
150 | 'Does {left} ' + ops.EQ_SYMBOL + ' {right}?',
151 | 'Are {left} and {right} equal?',
152 | 'Is {left} equal to {right}?',
153 | 'Do {left} and {right} have the same value?',
154 | ],
155 | '!=': [
156 | 'Is {left} ' + ops.NE_SYMBOL + ' {right}?',
157 | 'Is {left} not equal to {right}?',
158 | 'Are {left} and {right} unequal?',
159 | 'Are {left} and {right} nonequal?',
160 | 'Are {left} and {right} non-equal?',
161 | 'Do {left} and {right} have different values?',
162 | ],
163 | }
164 |
165 | comparison = random.choice(list(comparisons.keys()))
166 | template = random.choice(templates[comparison])
167 | question = example.question(context, template, left=left, right=right)
168 | answer = comparisons[comparison](left.value, right.value)
169 |
170 | return example.Problem(question=question, answer=answer)
171 |
172 |
173 | def integer_or_rational_or_decimal(entropy):
174 | if random.choice([False, True]):
175 | return number.integer_or_decimal(entropy, signed=True)
176 | else:
177 | return number.integer_or_rational(entropy, signed=True)
178 |
179 |
180 | def pair(sample_args, context=None):
181 | """Compares two numbers, e.g., "is 1/2 < 0.5?"."""
182 | if context is None:
183 | context = composition.Context()
184 | entropy, sample_args = sample_args.peel()
185 |
186 | def integers_close():
187 | entropy_diff, entropy_left = entropy * np.random.dirichlet([1, 3])
188 | left = number.integer(entropy_left, True)
189 | right = left + number.integer(entropy_diff, True)
190 | return left, right
191 |
192 | def rational_and_integer():
193 | # Pick a non-integer rational and an integer close to its value.
194 | left = number.non_integer_rational(entropy, True)
195 | right = int(round(left)) + random.randint(-1, 1)
196 | return left, right
197 |
198 | def independent():
199 | # Return an independent pair.
200 | entropy_left, entropy_right = entropy * np.random.dirichlet([1, 1])
201 | left = integer_or_rational_or_decimal(entropy_left)
202 | right = integer_or_rational_or_decimal(entropy_right)
203 | return left, right
204 |
205 | generator = random.choice([integers_close, rational_and_integer, independent])
206 |
207 | left, right = generator()
208 |
209 | # maybe swap for symmetry
210 | if random.choice([False, True]):
211 | left, right = right, left
212 | left, right = context.sample(sample_args, [left, right])
213 |
214 | return _make_comparison_question(context, left, right)
215 |
216 |
217 | def _entities_to_list(entities):
218 | entity_dict = {}
219 | values_template = ''
220 | for i, entity in enumerate(entities):
221 | if i > 0:
222 | values_template += ', '
223 | entity_name = 'entity_{}'.format(i)
224 | entity_dict[entity_name] = entity
225 | values_template += '{' + entity_name + '}'
226 | return entity_dict, values_template
227 |
228 |
229 | def _entities_to_choices(entities, answer):
230 | """Generate a multichoice question template."""
231 | if len(entities) > 26:
232 | raise ValueError('Too many choices: {}'.format(len(entities)))
233 |
234 | entity_dict = {}
235 | choices_template = ''
236 | answer_choice = None
237 | for i, entity in enumerate(entities):
238 | choices_template += ' '
239 | entity_name = 'entity_{}'.format(i)
240 | entity_dict[entity_name] = entity
241 | letter = chr(ord('a') + i)
242 | choices_template += '({letter}) {{{entity_name}}}'.format(
243 | letter=letter, entity_name=entity_name)
244 | if entity is answer:
245 | assert answer_choice is None
246 | answer_choice = letter
247 |
248 | assert answer_choice is not None
249 | return entity_dict, choices_template, answer_choice
250 |
251 |
252 | def _mark_choice_letters_used(count, context):
253 | """Marks the choice letters as used."""
254 | for i in range(count):
255 | context.mark_used(chr(ord('a') + i))
256 |
257 |
258 | def _kth_biggest_list_question(context, entities, adjective, answer):
259 | """Ask for the biggest (or smallest, or second biggest, etc) in a list."""
260 | entity_dict, values_template = _entities_to_list(entities)
261 |
262 | question = example.question(
263 | context, 'What is the {adjective} value in ' + values_template + '?',
264 | adjective=adjective, **entity_dict)
265 | return example.Problem(question=question, answer=answer.handle)
266 |
267 |
268 | def _kth_biggest_multichoice_question(context, entities, adjective, answer):
269 | """Ask for the biggest (or smallest, or second biggest, etc) of choices."""
270 | entity_dict, choices_template, answer_choice = _entities_to_choices(
271 | entities, answer)
272 | question = example.question(
273 | context, 'Which is the {adjective} value?' + choices_template,
274 | adjective=adjective, **entity_dict)
275 | return example.Problem(question=question, answer=answer_choice)
276 |
277 |
278 | def _entity_sort_key(entity):
279 | return sympy.default_sort_key(entity.value)
280 |
281 |
282 | def _sort_count_range(entropy):
283 | min_ = 3
284 | return min_, min_ + int(entropy/2)
285 |
286 |
287 | def _unique_values(entropy, only_integers=False, count=None):
288 | """Generates unique values."""
289 | if count is None:
290 | count = random.randint(*_sort_count_range(entropy))
291 |
292 | if only_integers:
293 | sampler = functools.partial(number.integer, signed=True)
294 | else:
295 | sampler = integer_or_rational_or_decimal
296 |
297 | for _ in range(1000):
298 | entropies = entropy * np.random.dirichlet(np.ones(count))
299 | entropies = np.maximum(1, entropies)
300 | values = [sampler(ent) for ent in entropies]
301 | if len(sympy.FiniteSet(*values)) == len(values):
302 | return values
303 | raise ValueError('Could not generate {} unique values with entropy={}'
304 | .format(count, entropy))
305 |
306 |
307 | def kth_biggest(sample_args, count=None):
308 | """Asks for the kth biggest value in a list."""
309 | sample_args = sample_args()
310 | context = composition.Context()
311 |
312 | entropy, sample_args = sample_args.peel()
313 | values = _unique_values(entropy, count=count)
314 | count = len(values)
315 |
316 | display_multichoice = random.choice([False, True])
317 | if display_multichoice:
318 | _mark_choice_letters_used(count, context)
319 |
320 | entities = context.sample(sample_args, values)
321 | sorted_entities = sorted(entities, key=_entity_sort_key)
322 | ordinal = random.randint(1, count)
323 |
324 | if random.choice([False, True]):
325 | # Do from biggest.
326 | answer = sorted_entities[-ordinal]
327 | adjective = 'biggest'
328 | else:
329 | # Do from smallest.
330 | answer = sorted_entities[ordinal - 1]
331 | adjective = 'smallest'
332 |
333 | if ordinal > 1:
334 | adjective = str(display.StringOrdinal(ordinal)) + ' ' + adjective
335 |
336 | if display_multichoice:
337 | return _kth_biggest_multichoice_question(
338 | context=context, entities=entities, adjective=adjective, answer=answer)
339 | else:
340 | return _kth_biggest_list_question(
341 | context=context, entities=entities, adjective=adjective, answer=answer)
342 |
343 |
344 | def _closest_in_list_question(context, entities, target, adjective, answer):
345 | """Ask for the closest to a given value in a list."""
346 | entity_dict, values_template = _entities_to_list(entities)
347 |
348 | question = example.question(
349 | context,
350 | 'What is the {adjective} to {target} in ' + values_template + '?',
351 | adjective=adjective, target=target, **entity_dict)
352 | return example.Problem(question=question, answer=answer.handle)
353 |
354 |
355 | def _closest_multichoice_question(context, entities, target, adjective, answer):
356 | """Ask for the closest to a given value in a set of choices."""
357 | entity_dict, choices_template, answer_choice = _entities_to_choices(
358 | entities, answer)
359 |
360 | question = example.question(
361 | context,
362 | 'Which is the {adjective} to {target}?' + choices_template,
363 | adjective=adjective, target=target, **entity_dict)
364 | return example.Problem(question=question, answer=answer_choice)
365 |
366 |
367 | def _closest_count_range(entropy):
368 | min_ = 3
369 | return min_, min_ + int(entropy/3)
370 |
371 |
372 | def closest(sample_args, count=None):
373 | """Ask for the closest to a given value in a list."""
374 | sample_args = sample_args()
375 | context = composition.Context()
376 |
377 | entropy, sample_args = sample_args.peel()
378 | if count is None:
379 | count = random.randint(*_closest_count_range(entropy))
380 |
381 | display_multichoice = random.choice([False, True])
382 | if display_multichoice:
383 | _mark_choice_letters_used(count, context)
384 |
385 | entropy_target, entropy_list = entropy * np.random.dirichlet([1, count])
386 | target = integer_or_rational_or_decimal(entropy_target)
387 |
388 | while True:
389 | value_entropies = entropy_list * np.random.dirichlet(np.ones(count))
390 | value_entropies = np.maximum(1, value_entropies)
391 | values = [integer_or_rational_or_decimal(ent) for ent in value_entropies]
392 | differences = [abs(sympy.sympify(value) - target) for value in values]
393 | if len(sympy.FiniteSet(*differences)) == count: # all differences unique
394 | break
395 |
396 | target_and_entities = context.sample(sample_args, [target] + values)
397 | target = target_and_entities[0]
398 | entities = target_and_entities[1:]
399 |
400 | min_difference = min(differences)
401 | answer_index = differences.index(min_difference)
402 | answer = entities[answer_index]
403 | adjective = random.choice(['closest', 'nearest'])
404 |
405 | if display_multichoice:
406 | return _closest_multichoice_question(
407 | context=context, entities=entities, target=target, adjective=adjective,
408 | answer=answer)
409 | else:
410 | return _closest_in_list_question(
411 | context=context, entities=entities, target=target, adjective=adjective,
412 | answer=answer)
413 |
414 |
415 | def sort(sample_args, count=None):
416 | """Ask to sort numbers in increasing or decreasing order."""
417 | sample_args = sample_args()
418 | context = composition.Context()
419 |
420 | entropy, sample_args = sample_args.peel()
421 | # Sometimes just integers, to allow for more terms in a short space.
422 | values = _unique_values(
423 | entropy, only_integers=random.choice([False, True]), count=count)
424 |
425 | entities = context.sample(sample_args, values)
426 |
427 | unsorted_dict, unsorted_template = _entities_to_list(entities)
428 |
429 | ascending = random.choice([False, True])
430 | templates = [
431 | 'Sort ' + unsorted_template + ' in {direction} order.',
432 | 'Put ' + unsorted_template + ' in {direction} order.',
433 | ]
434 | if ascending:
435 | templates.append('Sort ' + unsorted_template + '.')
436 | direction = random.choice(['ascending', 'increasing'])
437 | else:
438 | direction = random.choice(['descending', 'decreasing'])
439 | template = random.choice(templates)
440 |
441 | sorted_entities = sorted(
442 | entities, key=_entity_sort_key, reverse=(not ascending))
443 | answer = ''
444 | for i, entity in enumerate(sorted_entities):
445 | if i > 0:
446 | answer += ', '
447 | answer += str(entity.handle)
448 |
449 | return example.Problem(
450 | question=example.question(
451 | context, template, direction=direction, **unsorted_dict),
452 | answer=answer)
453 |
--------------------------------------------------------------------------------
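`_entities_to_choices` lays out a multiple-choice prompt by assigning the letters `(a)`, `(b)`, ... to entities and recording the letter of the answer. `_mark_choice_letters_used` reserves those same letters so that they are not reused as variable names. The following plain-string sketch shows the layout. `build_choices` is hypothetical. It takes an answer index instead of matching the answer entity by identity as the original does, and it substitutes the values directly rather than emitting `{entity_N}` placeholders for `example.question` to fill.

```python
import string


def build_choices(values, answer_index):
  """Hypothetical plain-string analogue of _entities_to_choices.

  Returns the choices suffix (e.g. " (a) 3 (b) 5") and the answer's letter.
  """
  if len(values) > 26:
    raise ValueError('Too many choices: {}'.format(len(values)))
  letters = string.ascii_lowercase
  choices = ''.join(
      ' ({}) {}'.format(letters[i], value) for i, value in enumerate(values))
  return choices, letters[answer_index]


choices, answer = build_choices(['0.4', '-2', '1/3'], answer_index=1)
print('Which is the smallest value?' + choices)
# Which is the smallest value? (a) 0.4 (b) -2 (c) 1/3
print(answer)  # b
```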
/mathematics_dataset/modules/measurement.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Measurement questions, e.g., "How many hours are there in a day?"."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 | import functools
23 | import random
24 |
25 | # Dependency imports
26 | from mathematics_dataset import example
27 | from mathematics_dataset.modules import train_test_split
28 | from mathematics_dataset.sample import number
29 | from mathematics_dataset.util import composition
30 | from mathematics_dataset.util import display
31 | import six
32 | import sympy
33 |
34 |
35 | def _make_modules(is_train):
36 | """Returns modules, with split based on the boolean `is_train`."""
37 | return {
38 | 'conversion': functools.partial(
39 | conversion, is_train=is_train, is_extrapolation=False),
40 | 'time': functools.partial(time, is_train=is_train),
41 | }
42 |
43 |
44 | def train(entropy_fn):
45 | """Returns dict of training modules."""
46 | del entropy_fn # unused
47 | return _make_modules(is_train=True)
48 |
49 |
50 | def test():
51 | """Returns dict of testing modules."""
52 | return _make_modules(is_train=False)
53 |
54 |
55 | def test_extra():
56 | """Returns dict of extrapolation testing modules."""
57 | return {
58 | 'conversion': functools.partial(
59 | conversion, is_train=False, is_extrapolation=True),
60 | }
61 |
62 |
63 | Unit = collections.namedtuple('Unit', ('name', 'symbol'))
64 |
65 |
66 | MICRO_SYMBOL = 'u'
67 |
68 |
69 | LENGTH = {
70 | Unit('meter', 'm'): 1,
71 | Unit('kilometer', 'km'): 1000,
72 | Unit('centimeter', 'cm'): sympy.Rational(1, 100),
73 | Unit('millimeter', 'mm'): sympy.Rational(1, 1000),
74 | Unit('micrometer', MICRO_SYMBOL + 'm'): sympy.Rational(1, 1e6),
75 | Unit('nanometer', 'nm'): sympy.Rational(1, 1e9),
76 | }
77 |
78 | TIME = {
79 | Unit('second', 's'): 1,
80 | Unit('minute', None): 60,
81 | Unit('hour', None): 60*60,
82 | Unit('day', None): 24*60*60,
83 | Unit('week', None): 7*24*60*60,
84 | Unit('millisecond', 'ms'): sympy.Rational(1, 1e3),
85 | Unit('microsecond', MICRO_SYMBOL + 's'): sympy.Rational(1, 1e6),
86 | Unit('nanosecond', 'ns'): sympy.Rational(1, 1e9),
87 | }
88 |
89 | TIME_YEARLY = {
90 | Unit('year', None): 1,
91 | Unit('decade', None): 10,
92 | Unit('century', None): 100,
93 | Unit('millennium', None): 1000,
94 | Unit('month', None): sympy.Rational(1, 12),
95 | }
96 |
97 | MASS = {
98 | Unit('kilogram', 'kg'): 1, # Yes, the *kilo*gram is the SI base unit.
99 | Unit('tonne', 't'): 1000,
100 | Unit('gram', 'g'): sympy.Rational(1, 1e3),
101 | Unit('milligram', 'mg'): sympy.Rational(1, 1e6),
102 | Unit('microgram', MICRO_SYMBOL + 'g'): sympy.Rational(1, 1e9),
103 | Unit('nanogram', 'ng'): sympy.Rational(1, 1e12),
104 | }
105 |
106 | VOLUME = {
107 | Unit('litre', 'l'): 1,
108 | Unit('millilitre', 'ml'): sympy.Rational(1, 1000),
109 | }
110 |
111 |
112 | DIMENSIONS = [LENGTH, TIME, TIME_YEARLY, MASS, VOLUME]
113 |
114 |
115 | def pluralize(name):
116 | if name == 'century':
117 | return 'centuries'
118 | if name == 'millennium':
119 | return 'millennia'
120 | return name + 's'
121 |
122 |
123 | def _factor_non_decimal(value):
124 | """Returns the largest divisor of value coprime to 2 and 5."""
125 | result = 1
126 | factors = sympy.factorint(value)
127 | for factor, power in six.iteritems(factors):
128 | if factor not in [2, 5]:
129 | result *= factor ** power
130 | return result
131 |
132 |
133 | def _sample_conversion_decimal(dimension, is_extrapolation):
134 | """Samples to and from units and values."""
135 | base_unit, target_unit = random.sample(list(dimension.keys()), 2)
136 | scale = sympy.Rational(dimension[base_unit]) / dimension[target_unit]
137 | scale_non_decimal = _factor_non_decimal(sympy.denom(scale))
138 | entropy = 9 if is_extrapolation else 7
139 | base_value = number.non_integer_decimal(entropy, signed=False)
140 | base_value = display.Decimal(base_value.value * scale_non_decimal)
141 | target_value = display.Decimal(base_value.value * scale)
142 | return base_value, base_unit, target_value, target_unit
143 |
144 |
145 | def _conversion_decimal(context, is_train, is_extrapolation):
146 | """E.g., "How many grams are in 5kg?"."""
147 | dimension = random.choice(DIMENSIONS)
148 | while True:
149 | base_value, base_unit, target_value, target_unit = (
150 | _sample_conversion_decimal(dimension, is_extrapolation))
151 | if train_test_split.is_train(base_value) == is_train:
152 | break
153 |
154 | templates = [
155 | 'How many {target_name} are there in {base_value} {base_name}?',
156 | 'What is {base_value} {base_name} in {target_name}?',
157 | 'Convert {base_value} {base_name} to {target_name}.',
158 | ]
159 | if base_unit.symbol is not None:
160 | templates += [
161 | 'How many {target_name} are there in {base_value}{base_symbol}?',
162 | 'What is {base_value}{base_symbol} in {target_name}?',
163 | 'Convert {base_value}{base_symbol} to {target_name}.',
164 | ]
165 | template = random.choice(templates)
166 |
167 | base_name = pluralize(base_unit.name)
168 | target_name = pluralize(target_unit.name)
169 |
170 | question = example.question(
171 | context,
172 | template,
173 | base_name=base_name,
174 | base_symbol=base_unit.symbol,
175 | base_value=base_value,
176 | target_name=target_name)
177 | return example.Problem(question=question, answer=target_value)
178 |
179 |
180 | def _conversion_fraction(context, is_train):
181 | """E.g., "How many grams are in three quarters of a kg?"."""
182 | dimension = random.choice(DIMENSIONS)
183 |
184 | # Limit probability of giving zero answer.
185 | allow_zero = random.random() < 0.2
186 |
187 | # Repeat until we find a pair with an integral answer. (Avoids ambiguity with
188 | # decimals.)
189 | while True:
190 | base_unit, target_unit = random.sample(list(dimension.keys()), 2)
191 | base_value = number.non_integer_rational(2, signed=False)
192 | if train_test_split.is_train(base_value) != is_train:
193 | continue
194 | answer = (base_value * sympy.Rational(dimension[base_unit])
195 | / sympy.Rational(dimension[target_unit]))
196 | if (abs(answer) <= 100000
197 | and sympy.denom(answer) == 1
198 | and (allow_zero or answer != 0)):
199 | break
200 |
201 | template = random.choice([
202 | 'How many {target_name} are there in {base_value} of a {base_name}?',
203 | 'What is {base_value} of a {base_name} in {target_name}?',
204 | ])
205 |
206 | if sympy.denom(base_value) > 20 or random.choice([False, True]):
207 | base_value_string = base_value # Will be represented as e.g., 2/3.
208 | else:
209 | base_value_string = display.StringNumber(base_value) # e.g., two thirds
210 |
211 | question = example.question(
212 | context, template,
213 | base_name=base_unit.name,
214 | base_value=base_value_string,
215 | target_name=pluralize(target_unit.name))
216 | return example.Problem(question=question, answer=answer)
217 |
218 |
219 | def conversion(is_train, is_extrapolation):
220 | """Conversion question, in decimal or fraction."""
221 | context = composition.Context()
222 | # TODO(b/124038528): implement extrapolation for fraction conversions too
223 | if is_extrapolation or random.choice([False, True]):
224 | return _conversion_decimal(
225 | context, is_train=is_train, is_extrapolation=is_extrapolation)
226 | else:
227 | return _conversion_fraction(context, is_train=is_train)
228 |
229 |
230 | def time(is_train):
231 | """Questions for calculating start, end, or time differences."""
232 | context = composition.Context()
233 | start_minutes = random.randint(1, 24*60 - 1)
234 | while True:
235 | duration_minutes = random.randint(1, 12*60 - 1)
236 | if train_test_split.is_train(duration_minutes) == is_train:
237 | break
238 | end_minutes = start_minutes + duration_minutes
239 |
240 | def format_12hr(minutes):
241 | """Format minutes from midnight in 12 hr format."""
242 | hours = (minutes // 60) % 24
243 | minutes %= 60
244 | am_pm = 'AM' if hours < 12 else 'PM'
245 | hours = (hours - 1) % 12 + 1
246 | return '{}:{:02} {}'.format(hours, minutes, am_pm)
247 |
248 | start = format_12hr(start_minutes)
249 | end = format_12hr(end_minutes)
250 |
251 | which_question = random.randint(0, 3)
252 | if which_question == 0:
253 | # Question: What is start = end - duration?
254 | template = random.choice([
255 | 'What is {duration} minutes before {end}?',
256 | ])
257 | return example.Problem(
258 | question=example.question(
259 | context, template, duration=duration_minutes, end=end),
260 | answer=start)
261 | elif which_question == 1:
262 | # Question: What is end = start + duration?
263 | template = random.choice([
264 | 'What is {duration} minutes after {start}?',
265 | ])
266 | return example.Problem(
267 | question=example.question(
268 | context, template, duration=duration_minutes, start=start),
269 | answer=end)
270 | else:
271 | # Question: What is duration = end - start?
272 | template = random.choice([
273 | 'How many minutes are there between {start} and {end}?',
274 | ])
275 | return example.Problem(
276 | question=example.question(context, template, start=start, end=end),
277 | answer=duration_minutes)
278 |
--------------------------------------------------------------------------------
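`_sample_conversion_decimal` multiplies the sampled base value by `_factor_non_decimal(denom(scale))` so that the converted value is a terminating decimal. In lowest terms, a fraction has a finite decimal expansion exactly when its denominator has no prime factors other than 2 and 5. The sketch below illustrates this with the seconds-to-minutes scale `1/60` from `TIME`. `non_decimal_part` is a hypothetical stand-in for `_factor_non_decimal` that uses trial division instead of `sympy.factorint`.

```python
from fractions import Fraction


def non_decimal_part(n):
  """Returns the largest divisor of the positive integer n coprime to 2 and 5."""
  for prime in (2, 5):
    while n % prime == 0:
      n //= prime
  return n


def terminates(value):
  """True if the Fraction `value` has a finite decimal expansion."""
  return non_decimal_part(value.denominator) == 1


scale = Fraction(1, 60)   # seconds -> minutes
base = Fraction(37, 10)   # 3.7 seconds
print(terminates(base * scale))  # False: 37/600 = 0.061666...

# Scale the base value by the denominator's non-decimal factor (here 3).
adjusted = base * non_decimal_part(scale.denominator)
print(adjusted, terminates(adjusted * scale))  # 111/10 True (11.1 s = 0.185 min)
```

For purely metric pairs, such as grams to kilograms, the scale's denominator is a power of 10. Its non-decimal part is then 1, and the base value is left unchanged.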
/mathematics_dataset/modules/modules.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """The various mathematics modules."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | from mathematics_dataset.modules import algebra
22 | from mathematics_dataset.modules import arithmetic
23 | from mathematics_dataset.modules import calculus
24 | from mathematics_dataset.modules import comparison
25 | from mathematics_dataset.modules import measurement
26 | from mathematics_dataset.modules import numbers
27 | from mathematics_dataset.modules import polynomials
28 | from mathematics_dataset.modules import probability
29 | import six
30 |
31 |
32 | all_ = {
33 | 'algebra': algebra,
34 | 'arithmetic': arithmetic,
35 | 'calculus': calculus,
36 | 'comparison': comparison,
37 | 'measurement': measurement,
38 | 'numbers': numbers,
39 | 'polynomials': polynomials,
40 | 'probability': probability,
41 | }
42 |
43 |
44 | def train(entropy_fn):
45 | """Returns dict of training modules."""
46 | return {
47 | name: module.train(entropy_fn) for name, module in six.iteritems(all_)
48 | }
49 |
50 |
51 | def test():
52 | """Returns dict of testing modules."""
53 | return {name: module.test() for name, module in six.iteritems(all_)}
54 |
55 |
56 | def test_extra():
57 | """Returns dict of extrapolation testing modules."""
58 | return {name: module.test_extra() for name, module in six.iteritems(all_)}
59 |
--------------------------------------------------------------------------------
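`modules.py` aggregates the per-topic modules into nested dictionaries. The outer key is a topic such as `'comparison'`. The inner dict, returned by that module's `train`, `test` or `test_extra`, maps a question type to a zero-argument callable (for example a `functools.partial`) that generates a problem. The sketch below shows the shape of the result using hypothetical `SimpleNamespace` stand-ins rather than the real modules. A hypothetical `flatten` helper then enumerates `(topic, question_type)` pairs. The sketch uses `dict.items()` where the original uses `six.iteritems` for Python 2 compatibility.

```python
import types

# Hypothetical stand-ins for two entries of `all_`. Like the real modules, each
# exposes train(entropy_fn), test() and test_extra(), returning a dict that
# maps a question type to a zero-argument callable.
comparison = types.SimpleNamespace(
    train=lambda entropy_fn: {'pair': lambda: 'train pair'},
    test=lambda: {'pair': lambda: 'test pair'},
    test_extra=lambda: {'sort_more': lambda: 'extra sort'})
polynomials = types.SimpleNamespace(
    train=lambda entropy_fn: {'expand': lambda: 'train expand'},
    test=lambda: {'expand': lambda: 'test expand'},
    test_extra=lambda: {})  # polynomials.test_extra() is empty in the source

all_ = {'comparison': comparison, 'polynomials': polynomials}


def test_extra():
  """Mirrors modules.test_extra(): topic -> {question type -> callable}."""
  return {name: module.test_extra() for name, module in all_.items()}


def flatten(nested):
  """Hypothetical helper: {(topic, question_type): callable}."""
  return {(topic, kind): fn
          for topic, kinds in nested.items()
          for kind, fn in kinds.items()}


problems = flatten(test_extra())
print(sorted(problems))  # [('comparison', 'sort_more')]
print(problems[('comparison', 'sort_more')]())  # extra sort
```

Topics with no extrapolation modules, such as `polynomials`, still appear in the nested dict but contribute no flattened entries.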
/mathematics_dataset/modules/polynomials.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Polynomial manipulation (adding, composing, finding coefficients, etc)."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import functools
22 | import math
23 | import random
24 |
25 | # Dependency imports
26 | from mathematics_dataset import example
27 | from mathematics_dataset.sample import number
28 | from mathematics_dataset.sample import ops
29 | from mathematics_dataset.sample import polynomials
30 | from mathematics_dataset.util import composition
31 | import numpy as np
32 | from six.moves import range
33 | import sympy
34 |
35 |
36 | _ENTROPY_TRAIN = (3, 10)
37 | _ENTROPY_INTERPOLATE = (8, 8)
38 |
39 |
40 | def _make_modules(entropy):
41 | """Returns modules given "difficulty" parameters."""
42 | sample_args_pure = composition.PreSampleArgs(1, 1, *entropy)
43 | sample_args_composed = composition.PreSampleArgs(2, 4, *entropy)
44 | sample_args_mixed = composition.PreSampleArgs(1, 4, *entropy)
45 |
46 | return {
47 | 'coefficient_named':
48 | functools.partial(coefficient_named, None, sample_args_pure),
49 | 'evaluate':
50 | functools.partial(evaluate, None, sample_args_pure),
51 | 'evaluate_composed':
52 | functools.partial(evaluate, None, sample_args_composed),
53 | # TODO(b/124038948): consider doing pure sample args for 'add'?
54 | 'add':
55 | functools.partial(add, None, sample_args_mixed),
56 | 'expand':
57 | functools.partial(expand, None, sample_args_pure),
58 | 'collect':
59 | functools.partial(collect, None, sample_args_pure),
60 | 'compose':
61 | functools.partial(compose, None, sample_args_mixed),
62 |
63 | # Rearranging powers:
64 | 'simplify_power':
65 | functools.partial(simplify_power, None, sample_args_pure),
66 | }
67 |
68 |
69 | def train(entropy_fn):
70 | """Returns dict of training modules."""
71 | return _make_modules(entropy_fn(_ENTROPY_TRAIN))
72 |
73 |
74 | def test():
75 | """Returns dict of testing modules."""
76 | return _make_modules(_ENTROPY_INTERPOLATE)
77 |
78 |
79 | def test_extra():
80 | """Returns dict of extrapolation testing modules."""
81 | return {
82 | }
83 |
84 |
85 | def coefficient_named(value, sample_args, context=None):
86 | """E.g., "Express x^2 + 2x in the form h * x^2 + k * x + t and give h."."""
87 | del value # not used
88 | if context is None:
89 | context = composition.Context()
90 | variable = sympy.Symbol(context.pop())
91 |
92 | entropy, sample_args = sample_args.peel()
93 | degree = random.randint(1, 4)
94 | if random.choice([False, True]):
95 | coefficients = polynomials.sample_coefficients(
96 | degree, entropy/2, min_non_zero=random.randint(degree - 1, degree))
97 | expanded = polynomials.expand_coefficients(coefficients, entropy/2)
98 | expression = polynomials.coefficients_to_polynomial(expanded, variable)
99 | else:
100 | expression = polynomials.sample_with_brackets(variable, degree, entropy)
101 | coefficients = list(reversed(sympy.Poly(expression).all_coeffs()))
102 |
103 | named_coeffs = [sympy.Symbol(context.pop()) for _ in range(degree + 1)]
104 | canonical = polynomials.coefficients_to_polynomial(named_coeffs, variable)
105 |
106 | if random.random() < 0.2: # small chance of asking for a (possibly zero) coefficient
107 | power = random.randint(0, degree)
108 | else:
109 | non_zero_powers = [i for i in range(degree + 1) if coefficients[i] != 0]
110 | power = random.choice(non_zero_powers)
111 |
112 | value = coefficients[power]
113 | named_coeff = named_coeffs[power]
114 |
115 | template = random.choice([
116 | 'Express {expression} as {canonical} and give {target}.',
117 | 'Rearrange {expression} to {canonical} and give {target}.',
118 | 'Express {expression} in the form {canonical} and give {target}.',
119 | 'Rearrange {expression} to the form {canonical} and give {target}.',
120 | ])
121 | return example.Problem(
122 | question=example.question(
123 | context, template, expression=expression, canonical=canonical,
124 | target=named_coeff),
125 | answer=value)
126 |
127 |
128 | _TEMPLATES = [
129 | 'What is {composed}?',
130 | 'Calculate {composed}.',
131 | 'Give {composed}.',
132 | 'Determine {composed}.',
133 | ]
134 |
135 |
136 | @composition.module(number.is_integer)
137 | def evaluate(value, sample_args, context=None):
138 | """Entity for evaluating an integer-valued polynomial at a given point."""
139 | is_question = context is None
140 | if context is None:
141 | context = composition.Context()
142 |
143 | entropy, sample_args = sample_args.peel()
144 |
145 | if value is None:
146 | entropy_value = random.uniform(1, 1 + entropy/3)
147 | entropy = max(0, entropy - entropy_value)
148 | value = number.integer(entropy_value, signed=True)
149 |
150 | entropy_input = random.uniform(1, 1 + entropy/3)
151 | entropy = max(0, entropy - entropy_input)
152 | input_ = number.integer(entropy_input, signed=True)
153 |
154 | degree = random.randint(1, 3)
155 |
156 | entropies = entropy * np.random.dirichlet(list(range(1, degree + 1)))
157 | # Calculate coefficients in reverse order.
158 | target = value
159 | coeffs_reversed = []
160 | for i, coeff_entropy in enumerate(entropies):
161 | power = degree - i
162 | coeff = number.integer(coeff_entropy, signed=True)
163 | if input_ != 0:
164 | coeff += int(round(target / input_ ** power))
165 | if coeff == 0 and i == 0:
166 | # Don't allow zero in leading coefficient.
167 | coeff += random.choice([-1, 1])
168 | coeffs_reversed.append(coeff)
169 | target -= coeff * (input_ ** power)
170 | coeffs_reversed.append(target)
171 |
172 | coefficients = list(reversed(coeffs_reversed))
173 |
174 | (polynomial_entity, input_) = context.sample(
175 | sample_args, [composition.Polynomial(coefficients), input_])
176 | composed = polynomial_entity.handle.apply(input_.handle)
177 |
178 | if is_question:
179 | template = random.choice(_TEMPLATES)
180 | return example.Problem(
181 | question=example.question(context, template, composed=composed),
182 | answer=value)
183 | else:
184 | return composition.Entity(
185 | context=context,
186 | value=value,
187 | expression=composed,
188 | description='Let {self} be {composed}.',
189 | composed=composed)
190 |
191 |
192 | # TODO(b/124039290): merge with compose? both add and compose do similar things.
193 | @composition.module(composition.is_integer_polynomial)
194 | def add(value, sample_args, context=None):
195 | """E.g., "Let f(x)=2x+1, g(x)=3x+2. What is 5*f(x) - 7*g(x)?"."""
196 | is_question = context is None
197 | if context is None:
198 | context = composition.Context()
199 |
200 | entropy, sample_args = sample_args.peel()
201 |
202 | if value is None:
203 | max_degree = 3
204 | degree = random.randint(1, max_degree)
205 | entropy -= math.log10(max_degree)
206 | entropy_value = entropy / 2
207 | entropy -= entropy_value
208 | value = polynomials.sample_coefficients(
209 | degree, entropy=entropy_value, min_non_zero=random.randint(1, 3))
210 | value = composition.Polynomial(value)
211 |
212 | c1, c2, coeffs1, coeffs2 = polynomials.coefficients_linear_split(
213 | value.coefficients, entropy)
214 | coeffs1 = polynomials.trim(coeffs1)
215 | coeffs2 = polynomials.trim(coeffs2)
216 |
217 | c1, c2, fn1, fn2 = context.sample(
218 | sample_args,
219 | [c1, c2, composition.Polynomial(coeffs1), composition.Polynomial(coeffs2)]
220 | )
221 |
222 | var = sympy.var(context.pop())
223 |
224 | expression = (
225 | c1.handle * fn1.handle.apply(var) + c2.handle * fn2.handle.apply(var))
226 |
227 | if is_question:
228 | answer = polynomials.coefficients_to_polynomial(value.coefficients, var)
229 | answer = answer.sympy()
230 | template = random.choice(_TEMPLATES)
231 | return example.Problem(
232 | question=example.question(context, template, composed=expression),
233 | answer=answer)
234 | else:
235 | intermediate_symbol = context.pop()
236 | intermediate = sympy.Function(intermediate_symbol)(var)
237 | return composition.Entity(
238 | context=context,
239 | value=value,
240 | description='Let {intermediate} = {composed}.',
241 | handle=composition.FunctionHandle(intermediate_symbol),
242 | intermediate=intermediate,
243 | composed=expression)
244 |
245 |
246 | def expand(value, sample_args, context=None):
247 | """E.g., "Expand (x**2 + 1)**2."."""
248 | del value # not used
249 | if context is None:
250 | context = composition.Context()
251 | variable = sympy.Symbol(context.pop())
252 | entropy, sample_args = sample_args.peel()
253 |
254 | min_order = 1
255 | max_order = 5
256 | order = random.randint(min_order, max_order)
257 | entropy -= math.log10(max_order - min_order + 1)
258 | expression_ = polynomials.sample_with_brackets(variable, order, entropy)
259 | expanded = sympy.expand(expression_)
260 | template = random.choice([
261 | 'Expand {expression}.'
262 | ])
263 | return example.Problem(
264 | question=example.question(context, template, expression=expression_),
265 | answer=expanded)
266 |
267 |
268 | @composition.module(composition.is_polynomial)
269 | def collect(value, sample_args, context=None):
270 | """Collect terms in an unsimplified polynomial."""
271 | is_question = context is None
272 | if context is None:
273 | context = composition.Context()
274 |
275 | entropy, sample_args = sample_args.peel()
276 | if value is None:
277 | entropy_value, entropy = entropy * np.random.dirichlet([2, 3])
278 | degrees = [random.randint(1, 3)]
279 | value = composition.Polynomial(
280 | polynomials.sample_coefficients(degrees, entropy_value))
281 |
282 | assert isinstance(value, composition.Polynomial)
283 | coefficients = value.coefficients
284 |
285 | all_coefficients_are_integer = True
286 | for coeff in coefficients.flat:
287 | if not number.is_integer(coeff):
288 | all_coefficients_are_integer = False
289 | break
290 |
291 | if all_coefficients_are_integer:
292 | coefficients = polynomials.expand_coefficients(coefficients, entropy)
293 | else:
294 | # put back the unused entropy
295 | sample_args = composition.SampleArgs(
296 | sample_args.num_modules, sample_args.entropy + entropy)
297 |
298 | num_variables = coefficients.ndim
299 | variables = [sympy.Symbol(context.pop()) for _ in range(num_variables)]
300 | unsimplified = polynomials.coefficients_to_polynomial(coefficients, variables)
301 | simplified = unsimplified.sympy().expand()
302 |
303 | # Bit of a hack: handle the very rare case where no number constants appear.
304 | if not ops.number_constants(unsimplified):
305 | unsimplified = ops.Add(unsimplified, ops.Constant(0))
306 | context.sample_by_replacing_constants(sample_args, unsimplified)
307 |
308 | if is_question:
309 | template = 'Collect the terms in {unsimplified}.'
310 | return example.Problem(
311 | question=example.question(context, template, unsimplified=unsimplified),
312 | answer=simplified)
313 | else:
314 | function_symbol = context.pop()
315 | function = sympy.Function(function_symbol)(*variables)
316 | return composition.Entity(
317 | context=context,
318 | value=value,
319 | handle=composition.FunctionHandle(function_symbol),
320 | expression=unsimplified,
321 | polynomial_variables=variables,
322 | description='Let {function} = {unsimplified}.',
323 | function=function,
324 | unsimplified=unsimplified)
325 |
326 |
327 | def compose(value, sample_args, context=None):
328 | """E.g., "Let f(x)=2x+1, let g(x)=3x+10. What is f(g(x))?"."""
329 | del value # unused
330 | if context is None:
331 | context = composition.Context()
332 |
333 | entropy, sample_args = sample_args.peel()
334 | entropy_f, entropy_g = entropy * np.random.dirichlet([1, 1])
335 |
336 | coeffs_f = polynomials.sample_coefficients([random.randint(1, 2)], entropy_f)
337 | coeffs_g = polynomials.sample_coefficients([random.randint(1, 2)], entropy_g)
338 |
339 | entity_f, entity_g = context.sample(
340 | sample_args,
341 | [composition.Polynomial(coeffs_f), composition.Polynomial(coeffs_g)])
342 |
343 | variable = sympy.var(context.pop())
344 |
345 | poly_f = polynomials.coefficients_to_polynomial(coeffs_f, variable)
346 | poly_g = polynomials.coefficients_to_polynomial(coeffs_g, variable)
347 |
348 | poly_f_g = poly_f.sympy().subs(variable, poly_g.sympy()).expand()
349 |
350 | expression = composition.FunctionHandle(entity_f, entity_g).apply(variable)
351 |
352 | template = random.choice(_TEMPLATES)
353 | return example.Problem(
354 | question=example.question(context, template, composed=expression),
355 | answer=poly_f_g)
356 |
357 |
358 | def simplify_power(value, sample_args, context=None):
359 | """E.g., "Simplify ((x**2)**3/x**4)**2/x**3."."""
360 | del value # unused
361 | if context is None:
362 | context = composition.Context()
363 |
364 | entropy, sample_args = sample_args.peel()
365 |
366 | variable = sympy.symbols(context.pop(), positive=True)
367 | unsimplified = polynomials.sample_messy_power(variable, entropy)
368 | answer = unsimplified.sympy()
369 |
370 | template = random.choice([
371 | 'Simplify {unsimplified} assuming {variable} is positive.',
372 | ])
373 | return example.Problem(
374 | example.question(
375 | context, template, unsimplified=unsimplified, variable=variable),
376 | answer)
377 |
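The coefficient loop inside `evaluate` guarantees an integer polynomial with a prescribed value at a prescribed input. From the leading term down, each coefficient is a small random integer plus the rounded quotient of the remaining target by `input_ ** power`, and the constant term absorbs whatever is left. The sketch below re-implements only that loop with the standard library. `sample_poly_with_value`, `evaluate_poly` and the fixed `noise` bound are hypothetical names standing in for the entropy-driven `number.integer` draws; they are not part of the package.

```python
import random


def sample_poly_with_value(value, input_, degree, noise=3, rng=random):
    """Returns integer coefficients [c_0, ..., c_degree] with p(input_) == value.

    Sketch of the loop in `evaluate`; `noise` is a hypothetical stand-in for the
    entropy-driven `number.integer(coeff_entropy, signed=True)` draws.
    """
    target = value
    coeffs_reversed = []  # leading coefficient first
    for i in range(degree):
        power = degree - i
        coeff = rng.randint(-noise, noise)
        if input_ != 0:
            # Add the rounded quotient so coeff * input_**power cancels most
            # of the remaining target.
            coeff += int(round(target / input_ ** power))
        if coeff == 0 and i == 0:
            # A zero leading coefficient would lower the degree.
            coeff += rng.choice([-1, 1])
        coeffs_reversed.append(coeff)
        target -= coeff * input_ ** power
    # The constant term absorbs the remainder, so p(input_) == value exactly.
    coeffs_reversed.append(target)
    return list(reversed(coeffs_reversed))


def evaluate_poly(coeffs, x):
    """Evaluates sum(c_k * x**k), constant term first (as in `coefficients`)."""
    return sum(c * x ** k for k, c in enumerate(coeffs))


coeffs = sample_poly_with_value(7, -2, 3, rng=random.Random(0))
print(coeffs, evaluate_poly(coeffs, -2))  # the second value is always 7
```

Because the constant term is computed last, the identity p(input_) == value holds for any random draws; the random parts only change how the value is spread across the coefficients.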
--------------------------------------------------------------------------------
/mathematics_dataset/modules/probability.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Probability questions (sampling, independence, expectations, ...)."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 | import functools
23 | import random
24 | import string
25 |
26 | # Dependency imports
27 | from mathematics_dataset import example
28 | from mathematics_dataset.modules import train_test_split
29 | from mathematics_dataset.util import combinatorics
30 | from mathematics_dataset.util import composition
31 | from mathematics_dataset.util import display
32 | from mathematics_dataset.util import probability
33 | import numpy as np
34 | from six.moves import range
35 | from six.moves import zip
36 |
37 |
38 | _LETTERS = string.ascii_lowercase
39 |
40 | _MAX_FRAC_TRIVIAL_PROB = 0.1
41 |
42 | # Limits on distinct letters, total letters, and repeats per letter in a bag.
43 | _MAX_DISTINCT_LETTERS = 6
44 | _MAX_TOTAL_LETTERS = 20
45 | _MAX_LETTER_REPEAT = 10
46 |
47 | _SWR_SAMPLE_COUNT = [2, 4]
48 | _SWR_SAMPLE_COUNT_EXTRAPOLATE = [5, 5]
49 |
50 | _GERUNDS = {
51 | 'pick': 'picking',
52 | }
53 |
54 |
55 | def _make_modules(is_train):
56 | """Returns modules, with split based on the boolean `is_train`."""
57 | return {
58 | 'swr_p_sequence': functools.partial(
59 | swr_prob_sequence, is_train=is_train, sample_range=_SWR_SAMPLE_COUNT),
60 | 'swr_p_level_set': functools.partial(
61 | swr_prob_level_set, is_train=is_train,
62 | sample_range=_SWR_SAMPLE_COUNT),
63 | }
64 |
65 |
66 | def train(entropy_fn):
67 | """Returns dict of training modules."""
68 | del entropy_fn # unused
69 | return _make_modules(is_train=True)
70 |
71 |
72 | def test():
73 | """Returns dict of testing modules."""
74 | return _make_modules(is_train=False)
75 |
76 |
77 | def test_extra():
78 | """Returns dict of extrapolation testing modules."""
79 | return {
80 | 'swr_p_sequence_more_samples': functools.partial(
81 | swr_prob_sequence, is_train=None,
82 | sample_range=_SWR_SAMPLE_COUNT_EXTRAPOLATE),
83 | 'swr_p_level_set_more_samples': functools.partial(
84 | swr_prob_level_set, is_train=None,
85 | sample_range=_SWR_SAMPLE_COUNT_EXTRAPOLATE),
86 | }
87 |
88 |
89 | def _sequence_event(values, length, verb):
90 | """Returns sequence (finite product) event.
91 |
92 | Args:
93 | values: List of values to sample from.
94 | length: Length of the sequence to generate.
95 | verb: Verb in infinitive form.
96 |
97 | Returns:
98 | Instance of `probability.FiniteProductEvent`, together with a text
99 | description.
100 | """
101 | del verb # unused
102 | samples = [random.choice(values) for _ in range(length)]
103 | events = [probability.DiscreteEvent([sample]) for sample in samples]
104 | event = probability.FiniteProductEvent(events)
105 | sequence = ''.join(str(sample) for sample in samples)
106 | event_description = 'sequence {sequence}'.format(sequence=sequence)
107 | return event, event_description
108 |
109 |
110 | def _word_series(words, conjunction='and'):
111 | """Combines the words using commas and the final conjunction."""
112 | len_words = len(words)
113 | if len_words == 0:
114 | return ''
115 | if len_words == 1:
116 | return words[0]
117 | return '{} {} {}'.format(', '.join(words[:-1]), conjunction, words[-1])
118 |
119 |
120 | def _level_set_event(values, length, verb):
121 | """Generates `CountLevelSetEvent`; see `_sequence_event`."""
122 | counts = combinatorics.uniform_non_negative_integers_with_sum(
123 | len(values), length)
124 | counts_dict = dict(list(zip(values, counts)))
125 | event = probability.CountLevelSetEvent(counts_dict)
126 |
127 | shuffled_values = list(values)
128 | random.shuffle(shuffled_values)
129 |
130 | counts_and_values = [
131 | '{} {}'.format(counts_dict[value], value)
132 | for value in shuffled_values
133 | if counts_dict[value] > 0
134 | ]
135 | counts_and_values = _word_series(counts_and_values)
136 | template = random.choice([
137 | '{verbing} {counts_and_values}',
138 | ])
139 | verbing = _GERUNDS[verb]
140 | event_description = template.format(
141 | counts_and_values=counts_and_values, verbing=verbing)
142 | return event, event_description
143 |
144 |
145 | LetterBag = collections.namedtuple(
146 | 'LetterBag',
147 | ('weights', 'random_variable', 'letters_distinct', 'bag_contents'))
148 |
149 |
150 | def _sample_letter_bag(is_train, min_total):
151 | """Samples a "container of letters" and returns info on it."""
152 | while True:
153 | num_distinct_letters = random.randint(1, _MAX_DISTINCT_LETTERS)
154 | num_letters_total = random.randint(
155 | max(num_distinct_letters, min_total),
156 | min(_MAX_TOTAL_LETTERS, num_distinct_letters * _MAX_LETTER_REPEAT))
157 | letter_counts = combinatorics.uniform_positive_integers_with_sum(
158 | num_distinct_letters, num_letters_total)
159 |
160 | # Test/train split.
161 | if (is_train is None
162 | or train_test_split.is_train(sorted(letter_counts)) == is_train):
163 | break
164 |
165 | letters_distinct = random.sample(_LETTERS, num_distinct_letters)
166 | weights = {i: 1 for i in range(num_letters_total)}
167 |
168 | letters_with_repetition = []
169 | for letter, count in zip(letters_distinct, letter_counts):
170 | letters_with_repetition += [letter] * count
171 | random.shuffle(letters_with_repetition)
172 |
173 | random_variable = probability.DiscreteRandomVariable(
174 | {i: letter for i, letter in enumerate(letters_with_repetition)})
175 |
176 | if random.choice([False, True]):
177 | bag_contents = ''.join(letters_with_repetition)
178 | else:
179 | letters_and_counts = [
180 | '{}: {}'.format(letter, count)
181 | for letter, count in zip(letters_distinct, letter_counts)]
182 | bag_contents = '{' + ', '.join(letters_and_counts) + '}'
183 |
184 | return LetterBag(
185 | weights=weights,
186 | random_variable=random_variable,
187 | letters_distinct=letters_distinct,
188 | bag_contents=bag_contents)
189 |
190 |
191 | def _swr_space(is_train, sample_range):
192 | """Returns probability space for sampling without replacement."""
193 | num_sampled = random.randint(*sample_range)
194 | sample = _sample_letter_bag(is_train=is_train, min_total=num_sampled)
195 |
196 | space = probability.SampleWithoutReplacementSpace(sample.weights, num_sampled)
197 |
198 | random_variable = probability.FiniteProductRandomVariable(
199 | [sample.random_variable] * num_sampled)
200 |
201 | random_variable.description = (
202 | str(display.StringNumber(num_sampled))
203 | + ' letters picked without replacement from '
204 | + sample.bag_contents)
205 |
206 | return sample.letters_distinct, space, random_variable
207 |
208 |
209 | def _sample_without_replacement_probability_question(
210 | is_train, event_fn, sample_range):
211 | """Question for prob of some event when sampling without replacement."""
212 | def too_big(event_in_space):
213 | if isinstance(event_in_space, probability.SequenceEvent):
214 | size = len(event_in_space.all_sequences())
215 | else:
216 | assert isinstance(event_in_space, probability.FiniteProductEvent)
217 | size = np.prod([len(event.values) for event in event_in_space.events])
218 | return size > int(2e5)
219 |
220 | allow_trivial_prob = random.random() < _MAX_FRAC_TRIVIAL_PROB
221 |
222 | while True:
223 | distinct_letters, space, random_variable = _swr_space(
224 | is_train, sample_range)
225 |
226 | event, event_description = event_fn(
227 | values=distinct_letters, length=space.n_samples, verb='pick')
228 | event_in_space = random_variable.inverse(event)
229 | if too_big(event_in_space):
230 | continue
231 | answer = space.probability(event_in_space)
232 | if answer not in [0, 1] or allow_trivial_prob:
233 | break
234 |
235 | context = composition.Context()
236 |
237 | template = random.choice([
238 | '{random_variable_capitalize}. What is prob of {event}?',
239 | '{random_variable_capitalize}. Give prob of {event}.',
240 | 'What is prob of {event} when {random_variable}?',
241 | 'Calculate prob of {event} when {random_variable}.',
242 | ])
243 | question = example.question(
244 | context,
245 | template,
246 | random_variable=random_variable.description,
247 | random_variable_capitalize=(
248 | str(random_variable.description).capitalize()),
249 | event=event_description)
250 | return example.Problem(question, answer)
251 |
252 |
253 | def swr_prob_sequence(is_train, sample_range):
254 | """Probability of given sequence when sampling without replacement."""
255 | return _sample_without_replacement_probability_question(
256 | is_train=is_train, event_fn=_sequence_event, sample_range=sample_range)
257 |
258 |
259 | def swr_prob_level_set(is_train, sample_range):
260 | """Probability of given level set when sampling without replacement."""
261 | return _sample_without_replacement_probability_question(
262 | is_train=is_train, event_fn=_level_set_event, sample_range=sample_range)
263 |
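Both probability modules share one model: every physical letter in the bag is an equally weighted item (the unit `weights` in `_sample_letter_bag`), `n_samples` items are drawn in order without replacement, and the event is either an exact sequence or a multiset of letter counts. Assuming a small bag, the answer can be checked by brute-force enumeration, as in this hypothetical stdlib-only sketch. It does not use the package's `probability` classes and makes no claim about how they compute the answer.

```python
import itertools
from collections import Counter
from fractions import Fraction


def swr_probability(bag, n_samples, event):
    """Exact P(event) when drawing `n_samples` letters without replacement.

    `bag` maps letter -> count and every physical letter is equally likely.
    `event` is a predicate on the ordered tuple of drawn letters. Enumeration
    is exhaustive, so this is only meant for small bags.
    """
    items = [letter for letter, count in bag.items() for _ in range(count)]
    outcomes = list(itertools.permutations(range(len(items)), n_samples))
    hits = sum(1 for idx in outcomes if event(tuple(items[i] for i in idx)))
    return Fraction(hits, len(outcomes))


def sequence_event(sequence):
    """Drawing exactly `sequence`, in order (cf. `_sequence_event`)."""
    target = tuple(sequence)
    return lambda draw: draw == target


def level_set_event(counts):
    """Drawing the given letter counts in any order (cf. `_level_set_event`)."""
    target = Counter({letter: c for letter, c in counts.items() if c > 0})
    return lambda draw: Counter(draw) == target


bag = {'a': 2, 'b': 3}
print(swr_probability(bag, 2, sequence_event('ab')))  # 3/10
print(swr_probability(bag, 2, level_set_event({'a': 1, 'b': 1})))  # 3/5
```

The sequence "ab" has probability 2/5 * 3/4 = 3/10, and the level set "one a and one b" doubles that because "ba" also qualifies.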
--------------------------------------------------------------------------------
/mathematics_dataset/modules/train_test_split.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Utility for train/test split based on hash value."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import hashlib
22 |
23 |
24 | def is_train(value):
25 | """Returns whether `value` should be used in a training question."""
26 | value_as_string = str(value).encode('utf-8')
27 | return int(hashlib.md5(value_as_string).hexdigest(), 16) % 2 == 0
28 |
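`_sample_letter_bag` passes `sorted(letter_counts)` to `is_train`, so a bag's split depends only on the multiset of counts, not on their order or on which letters were picked, while `is_train=None` (used by `test_extra`) accepts every bag. The sketch below restates that acceptance rule; `accepts_bag` is a hypothetical helper name, and `is_train` is copied from the file above.

```python
import hashlib


def is_train(value):
    """Copied from `train_test_split.is_train`: parity of the MD5 digest."""
    value_as_string = str(value).encode('utf-8')
    return int(hashlib.md5(value_as_string).hexdigest(), 16) % 2 == 0


def accepts_bag(letter_counts, is_train_flag):
    """Hypothetical restatement of the split check in `_sample_letter_bag`.

    Sorting makes the decision depend only on the multiset of counts;
    `is_train_flag=None` (extrapolation tests) accepts every bag.
    """
    return (is_train_flag is None
            or is_train(sorted(letter_counts)) == is_train_flag)


# Reordering the counts never moves a bag across the split.
print(accepts_bag([3, 1, 2], True) == accepts_bag([2, 3, 1], True))  # True
# Roughly half of all keys land in the training split.
print(sum(is_train(i) for i in range(1000)) / 1000)
```

Hashing the string form keeps the split deterministic across runs and machines without storing any list of held-out values.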
--------------------------------------------------------------------------------
/mathematics_dataset/sample/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/mathematics_dataset/sample/arithmetic.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Sample arithmetic expressions with a given value."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 | import math
23 | import random
24 |
25 | # Dependency imports
26 | from mathematics_dataset.sample import number
27 | from mathematics_dataset.sample import ops
28 | from mathematics_dataset.util import combinatorics
29 | import numpy as np
30 | import six
31 | from six.moves import zip
32 | import sympy
33 |
34 |
35 | class _SampleArgs(collections.namedtuple('SampleArgs', ('count', 'entropy'))):
36 | """For sampling mathematical expressions."""
37 |
38 | def peel(self, frac=1):
39 | """Peels off one op's share of the entropy (scaled by `frac`)."""
40 | entropy = frac * self.entropy / self.count
41 | new_sample_args = _SampleArgs(self.count, self.entropy - entropy)
42 | return entropy, new_sample_args
43 |
44 | def split(self, args):
45 | """Splits the entropy and op counts up."""
46 | non_integer_count = sum(not arg.is_Integer for arg in args)
47 | assert non_integer_count <= self.count - 1
48 | count_split = combinatorics.uniform_non_negative_integers_with_sum(
49 | len(args), (self.count - 1) - non_integer_count)
50 | for i, arg in enumerate(args):
51 | if not arg.is_Integer:
52 | count_split[i] += 1
53 | if all(count == 0 for count in count_split):
54 | assert self.entropy == 0
55 | entropies = np.zeros(len(count_split))
56 | else:
57 | entropies = (
58 | np.random.dirichlet(np.maximum(1e-9, count_split)) * self.entropy)
59 | return [_SampleArgs(op_count, entropy)
60 | for op_count, entropy in zip(count_split, entropies)]
61 |
62 |
63 | def _add_sub_filter(value, sample_args):
64 | return sample_args.count >= 2 or value.is_Integer
65 |
66 |
67 | def _add_op(value, sample_args, rationals_allowed):
68 | """Returns sampled args for `ops.Add`."""
69 | entropy, sample_args = sample_args.peel()
70 | if rationals_allowed and sample_args.count >= 3:
71 | x = number.integer_or_rational(entropy, True)
72 | else:
73 | x = number.integer(entropy, True)
74 | if random.choice([False, True]):
75 | op_args = [x, value - x]
76 | else:
77 | op_args = [value - x, x]
78 | return ops.Add, op_args, sample_args
79 |
80 |
81 | def _sub_op(value, sample_args, rationals_allowed):
82 | """Returns sampled args for `ops.Sub`."""
83 | entropy, sample_args = sample_args.peel()
84 | if rationals_allowed and sample_args.count >= 3:
85 | x = number.integer_or_rational(entropy, True)
86 | else:
87 | x = number.integer(entropy, True)
88 | if random.choice([False, True]):
89 | op_args = [x, x - value]
90 | else:
91 | op_args = [value + x, x]
92 | return ops.Sub, op_args, sample_args
93 |
94 |
95 | def _entropy_of_factor_split(integer):
96 | """Returns entropy (log base 10) of decomposing: integer = a * b."""
97 | assert integer.is_Integer
98 | if integer == 0:
99 | return 0
100 | # Gives dict of form {factor: multiplicity}
101 | factors = sympy.factorint(integer)
102 | return sum(math.log10(mult + 1) for mult in six.itervalues(factors))
103 |
104 |
105 | def _split_factors(integer):
106 | """Randomly factors integer into product of two integers."""
107 | assert integer.is_Integer
108 | if integer == 0:
109 | return [1, 0]
110 | # Gives dict of form {factor: multiplicity}
111 | factors = sympy.factorint(integer)
112 | left = sympy.Integer(1)
113 | right = sympy.Integer(1)
114 | for factor, mult in six.iteritems(factors):
115 | left_mult = random.randint(0, mult)
116 | right_mult = mult - left_mult
117 | left *= factor ** left_mult
118 | right *= factor ** right_mult
119 | return left, right
120 |
121 |
122 | def _mul_filter(value, sample_args):
123 | if sample_args.count >= 2:
124 | return True
125 | if not value.is_Integer:
126 | return False
127 | return sample_args.entropy <= _entropy_of_factor_split(value)
128 |
129 |
130 | def _mul_op(value, sample_args, rationals_allowed):
131 | """Returns sampled args for `ops.Mul`."""
132 | if sample_args.count >= 3:
133 | _, op_args, sample_args = _div_op(value, sample_args, rationals_allowed)
134 | op_args = [op_args[0], sympy.Integer(1) / op_args[1]]
135 | elif sample_args.count == 1:
136 | entropy, sample_args = sample_args.peel()
137 | assert _entropy_of_factor_split(value) >= entropy
138 | op_args = _split_factors(value)
139 | else:
140 | assert sample_args.count == 2
141 | entropy, sample_args = sample_args.peel()
142 | numer = sympy.numer(value)
143 | denom = sympy.denom(value)
144 | p1, p2 = _split_factors(numer)
145 | entropy -= _entropy_of_factor_split(numer)
146 | mult = number.integer(entropy, signed=True, min_abs=1, coprime_to=p1)
147 | op_args = [p1 / (mult * denom), p2 * mult]
148 |
149 | if random.choice([False, True]):
150 | op_args = list(reversed(op_args))
151 |
152 | return ops.Mul, op_args, sample_args
153 |
154 |
155 | def _div_filter(value, sample_args):
156 | del value # unused
157 | del sample_args # unused
158 | return True
159 |
160 |
161 | def _div_op(value, sample_args, rationals_allowed):
162 | """Returns sampled args for `ops.Div`."""
163 | assert rationals_allowed # should be True if this function gets invoked
164 | entropy, sample_args = sample_args.peel()
165 |
166 | numer = sympy.numer(value)
167 | denom = sympy.denom(value)
168 |
169 | if sample_args.count == 1:
170 | mult = number.integer(entropy, signed=True, min_abs=1)
171 | op_args = [numer * mult, denom * mult]
172 | elif sample_args.count == 2:
173 | if numer == 0 or random.choice([False, True]):
174 | x = number.integer(entropy, signed=True, min_abs=1, coprime_to=denom)
175 | op_args = [sympy.Rational(x * numer, denom), x]
176 | else:
177 | x = number.integer(entropy, signed=True, min_abs=1, coprime_to=numer)
178 | op_args = [x, sympy.Rational(x * denom, numer)]
179 | else:
180 | assert sample_args.count >= 3
181 | p2, p1 = _split_factors(numer)
182 | q1, q2 = _split_factors(denom)
183 | entropy -= _entropy_of_factor_split(numer) + _entropy_of_factor_split(denom)
184 | entropy_r = random.uniform(0, entropy)
185 | entropy_s = entropy - entropy_r
186 | r = number.integer(entropy_r, signed=True, min_abs=1, coprime_to=q1*p2)
187 | s = number.integer(entropy_s, signed=False, min_abs=1, coprime_to=p1*q2)
188 | op_args = [sympy.Rational(r*p1, s*q1), sympy.Rational(r*q2, s*p2)]
189 |
190 | return ops.Div, op_args, sample_args
191 |
192 |
193 | def _arithmetic(value, sample_args, add_sub, mul_div):
194 | """Recursively builds an `ops` expression tree that evaluates to `value`."""
195 | assert sample_args.count >= 0
196 | if sample_args.count == 0:
197 | assert sample_args.entropy == 0
198 | return ops.Constant(value)
199 |
200 | allowed = []
201 | if add_sub and _add_sub_filter(value, sample_args):
202 | allowed.append(_add_op)
203 | allowed.append(_sub_op)
204 | if mul_div and _mul_filter(value, sample_args):
205 | allowed.append(_mul_op)
206 | if mul_div and _div_filter(value, sample_args):
207 | allowed.append(_div_op)
208 | if not allowed:
209 | raise ValueError(
210 | 'No valid ops found, add_sub={} mul_div={} value={} sample_args={}'
211 | .format(add_sub, mul_div, value, sample_args))
212 | choice = random.choice(allowed)
213 |
214 | op, args, sample_args = choice(value, sample_args, rationals_allowed=mul_div)
215 | sample_args = sample_args.split(args)
216 | child_expressions = [_arithmetic(arg, child_sample_arg, add_sub, mul_div)
217 | for arg, child_sample_arg in zip(args, sample_args)]
218 |
219 | return op(*child_expressions)
220 |
221 |
222 | def length_range_for_entropy(entropy):
223 | """Returns length range to sample from for given entropy."""
224 | min_length = 3
225 | max_length = min_length + int(entropy / 2)
226 | return min_length, max_length
227 |
228 |
229 | def arithmetic(value, entropy, length=None, add_sub=True, mul_div=True):
230 | """Generates an arithmetic expression with a given value.
231 |
232 | Args:
233 | value: Target value (integer or rational).
234 | entropy: Amount of randomness to use in generating expression.
235 | length: Number of ops to use. If `None` then suitable length will be picked
236 | based on entropy by sampling within the range
237 | `length_range_for_entropy`.
238 | add_sub: Whether to include addition and subtraction operations.
239 | mul_div: Whether to include multiplication and division operations.
240 |
241 | Returns:
242 | Instance of `ops.Op` containing expression.
243 | """
244 | assert isinstance(entropy, float)
245 | if length is None:
246 | min_length, max_length = length_range_for_entropy(entropy)
247 | length = random.randint(min_length, max_length)
248 | # Some entropy used up in sampling the length.
249 | entropy -= math.log10(max_length - min_length + 1)
250 | else:
251 | assert isinstance(length, int)
252 |
253 | # Entropy adjustment, because different binary trees (from sampling ops) can
254 | # lead to the same expression. This is the correct value when we use just
255 | # addition as the op, and is otherwise an upper bound.
256 | entropy += combinatorics.log_number_binary_trees(length) / math.log(10)
257 |
258 | value = sympy.sympify(value)
259 | sample_args = _SampleArgs(length, entropy)
260 | return _arithmetic(value, sample_args, add_sub, mul_div)
261 |
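The entropy bookkeeping in `arithmetic` can be checked by hand. The sketch below re-implements it in plain Python. Because `util/combinatorics.py` is not shown here, it assumes that `combinatorics.log_number_binary_trees(n)` is the natural log of the n-th Catalan number, which counts the binary trees with `n` internal nodes. `catalan` and `entropy_after_sampling_length` are hypothetical helpers, not part of the package.

```python
import math


def length_range_for_entropy(entropy):
  # Same rule as above: at least 3 ops, plus one more per 2 units of entropy.
  min_length = 3
  max_length = min_length + int(entropy / 2)
  return min_length, max_length


def catalan(n):
  # Number of distinct binary trees with `n` internal nodes. ASSUMPTION: this
  # matches what `combinatorics.log_number_binary_trees` takes the log of.
  return math.comb(2 * n, n) // (n + 1)


def entropy_after_sampling_length(entropy, length):
  """Hypothetical helper mirroring the `length is None` path of `arithmetic`."""
  min_length, max_length = length_range_for_entropy(entropy)
  # Picking the length uniformly from the range costs log10(range size).
  entropy -= math.log10(max_length - min_length + 1)
  # Different op trees can print as the same expression, so add back
  # log10(#trees); this is exact only when every op is addition.
  entropy += math.log10(catalan(length))
  return entropy
```

For example, with entropy 8.0 the length range is (3, 7), i.e. 5 choices, and a length-3 expression has `catalan(3) == 5` trees, so the two adjustments cancel (up to floating-point rounding). Longer expressions end up with more entropy to spend on their constants.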
--------------------------------------------------------------------------------
/mathematics_dataset/sample/arithmetic_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for mathematics_dataset.sample.arithmetic."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import random
22 |
23 | # Dependency imports
24 | from absl.testing import absltest
25 | from absl.testing import parameterized
26 | from mathematics_dataset.sample import arithmetic
27 | from mathematics_dataset.sample import number
28 | from mathematics_dataset.sample import ops
29 | from six.moves import range
30 | import sympy
31 |
32 |
33 | class ArithmeticTest(parameterized.TestCase):
34 |
35 | def testArithmetic(self):
36 | for _ in range(1000):
37 | target = number.integer_or_rational(4, signed=True)
38 | entropy = 8.0
39 | expression = arithmetic.arithmetic(target, entropy)
40 | self.assertEqual(sympy.sympify(expression), target)
41 |
42 | def testArithmeticLength(self):
43 | """Tests that the generated arithmetic expressions have given length."""
44 | for _ in range(1000):
45 | target = number.integer_or_rational(4, signed=True)
46 | entropy = 8.0
47 | length = random.randint(2, 10)
48 | expression = arithmetic.arithmetic(target, entropy, length)
49 | # Note: actual length is #ops = #numbers - 1.
50 | actual_length = len(ops.number_constants(expression)) - 1
51 | self.assertEqual(actual_length, length)
52 |
53 |
54 | if __name__ == '__main__':
55 | absltest.main()
56 |
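`testArithmeticLength` relies on the fact that a tree built only from binary ops with `n` op nodes has exactly `n + 1` leaves, hence `#ops = #numbers - 1`. This holds regardless of how the ops are divided between subtrees. A minimal sketch of that invariant, using hypothetical nested tuples in place of `ops.Op` instances:

```python
import random


def random_tree(num_ops):
  """Returns a random binary tree (nested tuples) with `num_ops` op nodes."""
  if num_ops == 0:
    return 'leaf'  # Stands in for an `ops.Constant`.
  # Divide the remaining ops between the left and right subtrees.
  left_ops = random.randint(0, num_ops - 1)
  right_ops = num_ops - 1 - left_ops
  return ('op', random_tree(left_ops), random_tree(right_ops))


def count_ops_and_leaves(tree):
  """Returns (number of op nodes, number of leaves)."""
  if tree == 'leaf':
    return 0, 1
  _, left, right = tree
  left_ops, left_leaves = count_ops_and_leaves(left)
  right_ops, right_leaves = count_ops_and_leaves(right)
  return left_ops + right_ops + 1, left_leaves + right_leaves


for n in range(10):
  num_ops, num_leaves = count_ops_and_leaves(random_tree(n))
  assert num_ops == n and num_leaves == n + 1
```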
--------------------------------------------------------------------------------
/mathematics_dataset/sample/linear_system.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Generate linear systems with given set of solutions."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import random
22 |
23 | # Dependency imports
24 | from mathematics_dataset.sample import number
25 | from mathematics_dataset.sample import ops
26 | from mathematics_dataset.sample import polynomials
27 | import numpy as np
28 | from six.moves import range
29 | import sympy
30 |
31 |
32 | def _make_equals_zero_split(monomials):
33 | """Returns an `ops.Eq` for sum(monomials) = 0, each monomial placed on the left or (negated) on the right."""
34 | left = []
35 | right = []
36 | for monomial in monomials:
37 | if random.choice([False, True]):
38 | left.append(monomial)
39 | else:
40 | right.append(ops.Neg(monomial))
41 | if not left:
42 | left = [0]
43 | if not right:
44 | right = [0]
45 | left = ops.Add(*left)
46 | right = ops.Add(*right)
47 | return ops.Eq(left, right)
48 |
49 |
50 | def _is_trivial_in(matrix, variable):
51 | """Returns true if matrix_ij == 0 for some i and all j != variable."""
52 | matrix = np.asarray(matrix)
53 | assert matrix.ndim == 2 and matrix.shape[0] == matrix.shape[1]
54 | size = matrix.shape[0]
55 | if size == 1:
56 | return False
57 | for i in range(size):
58 | all_zero = True
59 | for j in range(size):
60 | if j != variable and matrix[i, j] != 0:
61 | all_zero = False
62 | break
63 | if all_zero:
64 | return True
65 | return False
66 |
67 |
68 | def _invertible_matrix(degree, entropy, non_trivial_in):
69 | """Generates a random invertible integer matrix of shape [degree, degree]."""
70 | matrix_entropies = entropy * np.random.dirichlet(np.ones(degree * degree))
71 | matrix_entropies = np.reshape(matrix_entropies, [degree, degree])
72 | matrix_entropies = np.maximum(1, matrix_entropies)
73 |
74 | while True:
75 | def gen(i, j):
76 | return number.integer(matrix_entropies[i, j], True)
77 |
78 | matrix = [[gen(i, j) for j in range(degree)] for i in range(degree)] # pylint: disable=g-complex-comprehension
79 | if non_trivial_in is not None and _is_trivial_in(matrix, non_trivial_in):
80 | continue
81 | if sympy.det(sympy.Matrix(matrix)) != 0:
82 | break
83 |
84 | matrix = np.asarray(matrix).astype(int)
85 | return matrix
86 |
87 |
88 | def linear_system(variables, solutions, entropy, non_trivial_in=None,
89 | length=None):
90 | """Returns a linear system (set of equalities) with the given solutions.
91 |
92 | Args:
93 | variables: List of variables.
94 | solutions: List of solutions, of the same length as `variables`.
95 | entropy: Float >= 0; the entropy used.
96 | non_trivial_in: Optional integer corresponding to a variable for which the
97 | solution shouldn't be "trivial". E.g., "solve a + b = 3, a = -2 for a"
98 | is disallowed if `variables[non_trivial_in] == 'a'`.
99 | length: Total number of terms appearing; if `None`, sampled between the minimum needed for the nonzero coefficients and a bound that grows with `degree` and `entropy`.
100 |
101 | Returns:
102 | List of `ops.Eq`.
103 | """
104 | degree = len(variables)
105 | assert degree == len(solutions)
106 |
107 | frac_entropy_matrix = random.uniform(1/3, 2/3)
108 | matrix = _invertible_matrix(
109 | degree, entropy * frac_entropy_matrix, non_trivial_in)
110 | solutions = np.asarray(solutions)
111 | constant = np.matmul(matrix, solutions.astype(int))
112 | flattened = np.concatenate([np.reshape(matrix, [degree * degree]), constant])
113 | is_zero = flattened == 0
114 |
115 | if length is None:
116 | min_length = np.count_nonzero(flattened) + 1
117 | max_length = max(min_length, 1 + int(degree * (1 + entropy / 2)))
118 | length = random.randint(min_length, max_length)
119 |
120 | counts = polynomials.expanded_coefficient_counts(
121 | length=length, is_zero=is_zero)
122 |
123 | entropies = (1 - frac_entropy_matrix) * entropy * np.random.dirichlet(
124 | np.maximum(1e-9, counts - 1))
125 |
126 | terms = []
127 | for i in range(len(flattened)):
128 | coeffs = polynomials.integers_with_sum(
129 | value=flattened[i], count=counts[i], entropy=entropies[i])
130 | terms.append(coeffs)
131 |
132 | matrix = terms[:degree*degree]
133 | constant = terms[-degree:]
134 | equations = []
135 | for row_index in range(degree):
136 | monomials = []
137 | for col_index in range(degree):
138 | for term in matrix[row_index * degree + col_index]:
139 | monomials.append(polynomials.monomial(term, variables[col_index], 1))
140 | for term in constant[row_index]:
141 | monomials.append(polynomials.monomial(-term, None, 0))
142 | equations.append(_make_equals_zero_split(monomials))
143 |
144 | return equations
145 |
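The core idea of `linear_system` is to pick an invertible integer matrix `A`, compute the constants as `b = A @ x` for the requested solutions `x`, and only then scatter the coefficients into randomly split monomials. The sketch below covers just the first two steps, with no entropy accounting and no term splitting, and uses a fixed matrix in place of `_invertible_matrix`. `system_with_solution` is a hypothetical name.

```python
import numpy as np
import sympy


def system_with_solution(matrix, solutions):
  """Returns (A, b) such that A @ x = b has the unique solution `solutions`."""
  matrix = np.asarray(matrix, dtype=int)
  solutions = np.asarray(solutions, dtype=int)
  # Exact integer determinant, as `_invertible_matrix` checks via sympy.
  if sympy.Matrix(matrix.tolist()).det() == 0:
    raise ValueError('matrix must be invertible')
  constant = np.matmul(matrix, solutions)
  return matrix, constant


A, b = system_with_solution([[2, 1], [1, -3]], [3, -2])
# 2*3 + 1*(-2) = 4 and 1*3 + (-3)*(-2) = 9, so b is [4, 9].
```

Invertibility is what guarantees that solving the generated equations recovers exactly the requested solutions, which is what `testLinearSystem` asserts.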
--------------------------------------------------------------------------------
/mathematics_dataset/sample/linear_system_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for mathematics_dataset.sample.linear_system."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import random
22 |
23 | # Dependency imports
24 | from absl.testing import absltest
25 | from absl.testing import parameterized
26 | from mathematics_dataset.sample import linear_system
27 | from six.moves import range
28 | import sympy
29 |
30 |
31 | class ExpressionWithValueTest(parameterized.TestCase):
32 |
33 | def testIsTrivialIn(self):
34 | self.assertEqual(linear_system._is_trivial_in([[1]], 0), False)
35 | self.assertEqual(linear_system._is_trivial_in([[1, 2], [3, 4]], 0), False)
36 | self.assertEqual(linear_system._is_trivial_in([[1, 2], [3, 0]], 0), True)
37 | self.assertEqual(linear_system._is_trivial_in([[1, 2], [3, 0]], 1), False)
38 | self.assertEqual(linear_system._is_trivial_in([[1, 2], [0, 3]], 0), False)
39 | self.assertEqual(linear_system._is_trivial_in([[1, 2], [0, 3]], 1), True)
40 |
41 | @parameterized.parameters([1, 2, 3])
42 | def testLinearSystem(self, degree):
43 | for _ in range(100): # test a few times
44 | target = [random.randint(-100, 100) for _ in range(degree)]
45 | variables = [sympy.Symbol(chr(ord('a') + i)) for i in range(degree)]
46 | system = linear_system.linear_system(
47 | variables=variables,
48 | solutions=target,
49 | entropy=10.0)
50 | solved = sympy.solve(system, variables)
51 | solved = [solved[symbol] for symbol in variables]
52 | self.assertEqual(target, solved)
53 |
54 |
55 | if __name__ == '__main__':
56 | absltest.main()
57 |
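`testIsTrivialIn` pins down what "trivial" means. Some row of the coefficient matrix is zero in every column except `variable`, so that one equation alone determines that variable. The same check can be written with NumPy array operations instead of nested loops. The sketch below uses the hypothetical name `is_trivial_in_vectorized`, is not the package's code, and reproduces the test's expected results.

```python
import numpy as np


def is_trivial_in_vectorized(matrix, variable):
  """Vectorized variant of `_is_trivial_in` (a sketch)."""
  matrix = np.asarray(matrix)
  assert matrix.ndim == 2 and matrix.shape[0] == matrix.shape[1]
  if matrix.shape[0] == 1:
    return False  # A 1x1 system is never considered trivial.
  # Drop the column for `variable`, then look for a row that is all zeros.
  others = np.delete(matrix, variable, axis=1)
  return bool(np.any(np.all(others == 0, axis=1)))
```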
--------------------------------------------------------------------------------
/mathematics_dataset/sample/number.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Generate random integers and rationals with minimum guarantees on entropy."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import math
22 | import random
23 |
24 | # Dependency imports
25 | from mathematics_dataset.util import display
26 | import numpy as np
27 | import six
28 | import sympy
29 |
30 |
31 | def _coprime_density(value):
32 | """Returns float > 0; asymptotic density of integers coprime to `value`."""
33 | factors = sympy.factorint(value)
34 | density = 1.0
35 | for prime in six.iterkeys(factors):
36 | density *= 1 - 1 / prime
37 | return density
38 |
39 |
40 | def integer(entropy, signed, min_abs=0, coprime_to=1):
41 | """Returns an integer from a set of size approximately ceil(10**entropy).
42 |
43 | If `signed` is True, then includes negative integers, otherwise includes just
44 | positive integers.
45 |
46 | Args:
47 | entropy: Float >= 0.
48 | signed: Boolean. Whether to also return negative numbers.
49 | min_abs: Integer >= 0. The minimum absolute value.
50 | coprime_to: Optional integer >= 1. The returned integer is guaranteed to be
51 | coprime to `coprime_to`, with entropy still accounted for.
52 |
53 | Returns:
54 | Integer.
55 | """
56 | assert isinstance(min_abs, int) and not isinstance(min_abs, bool)
57 | coprime_to = abs(coprime_to)
58 | assert min_abs >= 0
59 |
60 | max_ = math.pow(10, entropy)
61 | max_ += min_abs
62 | if coprime_to >= 2:
63 | max_ = max_ / _coprime_density(coprime_to) + 1
64 |
65 | if signed:
66 | max_ = int(math.ceil(max_ / 2))
67 | range_ = [-max_, max_]
68 | else:
69 | max_ = int(math.ceil(max_))
70 | range_ = [min_abs, max_]
71 |
72 | while True:
73 | value = random.randint(*range_)
74 | if abs(value) >= min_abs and sympy.gcd(value, coprime_to) == 1:
75 | break
76 |
77 | return sympy.Integer(value)
78 |
79 |
80 | def non_integer_rational(entropy, signed):
81 | """Similar args to `integer`; entropy is split randomly between numerator and denominator."""
82 | numer_entropy = random.uniform(0, entropy)
83 | denom_entropy = entropy - numer_entropy
84 | numer = integer(numer_entropy, signed, min_abs=1)
85 | denom = integer(denom_entropy, False, min_abs=2, coprime_to=numer)
86 | return sympy.Rational(numer, denom)
87 |
88 |
89 | def integer_or_rational(entropy, signed, min_abs=0):
90 | """Returns a rational, with 50% probability of it being an integer."""
91 | if random.choice([False, True]):
92 | return integer(entropy, signed, min_abs=min_abs)
93 | else:
94 | return non_integer_rational(entropy, signed)
95 |
96 |
97 | def non_integer_decimal(entropy, signed):
98 | """Returns a random decimal; integer divided by random power of ten.
99 |
100 | Guaranteed to be non-integer (i.e., it has nonzero digits after the decimal point).
101 |
102 | Args:
103 | entropy: Float.
104 | signed: Boolean. Whether to also return negative numbers.
105 |
106 | Returns:
107 | Non-integer decimal.
108 | """
109 | while True:
110 | base = integer(entropy, signed)
111 | shift = random.randint(1, int(math.ceil(entropy)))
112 | divisor = 10**shift
113 | if base % divisor != 0:
114 | return display.Decimal(sympy.Rational(base, divisor))
115 |
116 |
117 | def integer_or_decimal(entropy, signed):
118 | """Returns integer or non-integer decimal; 50% probability of each."""
119 | if random.choice([False, True]):
120 | # Represent it as a decimal so that arithmetic operations are supported:
121 | return display.Decimal(integer(entropy, signed))
122 | else:
123 | return non_integer_decimal(entropy, signed)
124 |
125 |
126 | def entropy_of_value(value):
127 | """Returns the approximate entropy at which sampling would typically produce `value`."""
128 | if isinstance(value, display.Decimal):
129 | return entropy_of_value(sympy.numer(value))
130 |
131 | if is_non_integer_rational(value):
132 | numer = sympy.numer(value)
133 | denom = sympy.denom(value)
134 | return entropy_of_value(numer) + entropy_of_value(denom)
135 | elif not is_integer(value):
136 | raise ValueError('Unhandled value: {}'.format(value))
137 |
138 | # Note: we sample integers in a range of size approx 10**entropy about zero,
139 | # so assume that `abs(value)` is about half of the upper range.
140 | return math.log10(5 * abs(value) + 1)
141 |
142 |
143 | def is_integer(value):
144 | return isinstance(value, (int, np.int64, np.int32, sympy.Integer))
145 |
146 |
147 | def is_positive_integer(value):
148 | """Filter for: value is a strictly positive integer."""
149 | return is_integer(value) and value > 0
150 |
151 |
152 | def is_integer_or_rational(value):
153 | return is_integer(value) or isinstance(value, sympy.Rational)
154 |
155 |
156 | def is_integer_or_decimal(value):
157 | return is_integer(value) or isinstance(value, display.Decimal)
158 |
159 |
160 | def is_integer_or_rational_or_decimal(value):
161 | return is_integer_or_rational(value) or is_integer_or_decimal(value)
162 |
163 |
164 | def is_non_integer_rational(value):
165 | return is_integer_or_rational(value) and not is_integer(value)
166 |
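`integer` widens its sampling range so that rejection sampling still leaves roughly `10**entropy` candidates. It adds `min_abs` and, when `coprime_to >= 2`, divides by the density of integers coprime to `coprime_to`. That density is the product of `(1 - 1/p)` over the distinct primes `p` dividing `coprime_to` (Euler's product for φ(n)/n). The sketch below recomputes the density by trial division instead of `sympy.factorint`, a substitution made for self-containment. It then returns the `range_` that `integer` would sample from before rejection. In the signed case, `min_abs` is enforced only by the rejection loop. Both helpers are hypothetical.

```python
import math


def coprime_density(value):
  """Product of (1 - 1/p) over distinct primes p dividing `value`."""
  n = abs(value)
  density = 1.0
  p = 2
  while p * p <= n:
    if n % p == 0:
      density *= 1 - 1 / p
      while n % p == 0:  # Remove every copy of this prime factor.
        n //= p
    p += 1
  if n > 1:  # Whatever remains is a single prime factor.
    density *= 1 - 1 / n
  return density


def integer_range(entropy, signed, min_abs=0, coprime_to=1):
  """Returns the inclusive range `integer` samples from before rejection."""
  max_ = math.pow(10, entropy) + min_abs
  coprime_to = abs(coprime_to)
  if coprime_to >= 2:
    max_ = max_ / coprime_density(coprime_to) + 1
  if signed:
    half = int(math.ceil(max_ / 2))
    return -half, half
  return min_abs, int(math.ceil(max_))
```

For example, `integer_range(1, signed=True)` is `(-5, 5)`, and requiring coprimality to 2 at entropy 2 doubles the unsigned range to `(0, 201)`, since only about half of its integers survive rejection.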
--------------------------------------------------------------------------------
/mathematics_dataset/sample/number_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for mathematics_dataset.sample.number."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import random
22 |
23 | # Dependency imports
24 | from absl.testing import absltest
25 | from absl.testing import parameterized
26 | from mathematics_dataset.sample import number
27 | from six.moves import range
28 | import sympy
29 |
30 |
31 | class NumberTest(parameterized.TestCase):
32 |
33 | def testCoprimeDensity(self):
34 | self.assertEqual(number._coprime_density(1), 1.0)
35 | self.assertEqual(number._coprime_density(2), 0.5)
36 | self.assertLess(abs(number._coprime_density(3) - 2/3), 1e-6)
37 | self.assertLess(abs(number._coprime_density(6) - 1/3), 1e-6)
38 |
39 | @parameterized.parameters(False, True)
40 | def testInteger_allowZero(self, signed):
41 | saw_zero = False
42 | saw_nonzero = False
43 | for _ in range(1000):
44 | sample = number.integer(1, signed=signed)
45 | if sample == 0:
46 | saw_zero = True
47 | else:
48 | saw_nonzero = True
49 | if saw_zero and saw_nonzero:
50 | break
51 | self.assertTrue(saw_zero)
52 | self.assertTrue(saw_nonzero)
53 |
54 | def testNonIntegerRational(self):
55 | for _ in range(1000):
56 | entropy = random.uniform(0, 10)
57 | signed = random.choice([False, True])
58 | sample = number.non_integer_rational(entropy, signed)
59 | self.assertNotEqual(sympy.denom(sample), 1)
60 |
61 | @parameterized.parameters(False, True)
62 | def testIntegerOrRational(self, signed):
63 | # Tests that it can be called. Call it several times so that both code paths (integer and rational) very likely get executed.
64 | for _ in range(10):
65 | number.integer_or_rational(2, signed)
66 |
67 | def testNonIntegerDecimal(self):
68 | for _ in range(1000):
69 | sample = number.non_integer_decimal(1, False)
70 | self.assertNotEqual(sympy.denom(sample), 1)
71 | self.assertLen(str(sample), 3) # should be of form "0.n"
72 | self.assertGreater(sample, 0) # positive
73 |
74 | def testNonIntegerDecimal_size(self):
75 | saw_bigger_one = False
76 | saw_smaller_one = False
77 | for _ in range(1000):
78 | sample = number.non_integer_decimal(2, False)
79 | if sample > 1:
80 | saw_bigger_one = True
81 | else:
82 | saw_smaller_one = True
83 | if saw_bigger_one and saw_smaller_one:
84 | break
85 | self.assertTrue(saw_bigger_one)
86 | self.assertTrue(saw_smaller_one)
87 |
88 | @parameterized.parameters(
89 | lambda: number.integer(0, True),
90 | lambda: number.integer(1, True),
91 | lambda: number.non_integer_rational(2, True),
92 | lambda: number.non_integer_decimal(1, True))
93 | def testGenerate_signed(self, generator):
94 | saw_positive = False
95 | saw_negative = False
96 | for _ in range(1000):
97 | sample = generator()
98 | saw_positive |= sample > 0
99 | saw_negative |= sample < 0
100 | if saw_positive and saw_negative:
101 | break
102 |
103 | self.assertTrue(saw_positive)
104 | self.assertTrue(saw_negative)
105 |
106 | @parameterized.parameters(
107 | lambda: number.integer(2, False),
108 | lambda: number.non_integer_rational(2, False))
109 | def testIntegerRational_distinctCount(self, generator):
110 | seen = set()
111 | for _ in range(3000):
112 | seen.add(generator())
113 | self.assertGreaterEqual(len(seen), 10 ** 2)
114 |
115 | @parameterized.parameters(number.integer, number.non_integer_decimal)
116 | def testEntropyOfValue(self, generator):
117 | for entropy in [1, 2, 4, 8, 16]:
118 | sum_entropy = 0.0
119 | count = 2000
120 | for _ in range(count):
121 | value = generator(entropy, signed=True)
122 | sum_entropy += number.entropy_of_value(value)
123 | avg_entropy = sum_entropy / count
124 | error = abs(entropy - avg_entropy) / entropy
125 | self.assertLess(error, 0.2)
126 |
127 |
128 | if __name__ == '__main__':
129 | absltest.main()
130 |
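`testEntropyOfValue` checks that `entropy_of_value` roughly inverts the samplers. For an integer `v` the estimate is `log10(5 * |v| + 1)`, based on the source's assumption that `|v|` is about half of the upper end of a range of roughly `10**entropy` integers centred on zero. For a non-integer rational, the numerator and denominator estimates are added. The sketch below uses `fractions.Fraction` instead of sympy and omits `display.Decimal`, which is not shown here. Instead of random sampling, it averages over every integer in the range that `integer(1, signed=True)` draws from, giving a deterministic version of the test's 20% tolerance check.

```python
import math
from fractions import Fraction


def entropy_of_value(value):
  """Sketch of `number.entropy_of_value` for ints and Fractions only."""
  value = Fraction(value)
  if value.denominator != 1:
    # Non-integer rational: numerator and denominator entropies add up.
    return (entropy_of_value(value.numerator)
            + entropy_of_value(value.denominator))
  # Integer: assume |value| is about half of the upper end of the range.
  return math.log10(5 * abs(value.numerator) + 1)


def mean_entropy_over_range(low, high):
  """Average estimate over every integer in [low, high] (uniform sampling)."""
  values = range(low, high + 1)
  return sum(entropy_of_value(v) for v in values) / len(values)


# `integer(1, signed=True)` samples from [-5, 5]; the mean estimate is ~1.047.
error = abs(1 - mean_entropy_over_range(-5, 5)) / 1
```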
--------------------------------------------------------------------------------
/mathematics_dataset/sample/ops.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Mathematical operations used to build up expressions for printing.
16 |
17 | We can't use sympy because sympy will automatically simplify many types of
18 | expressions, even with `evaluate=False` passed in. For example:
19 |
20 | * Mul(-2, -3, evaluate=False) gives -(-6), not (-2) x (-3).
21 | * Add(2, 1, evaluate=False) gives 1 + 2, because the terms are sorted.
22 |
23 | As such, it's easier just to work with our own op classes that display precisely
24 | as we created them. This also allows us to use custom symbols for the
25 | expressions, such as the multiplication symbol.
26 | """
27 |
28 | from __future__ import absolute_import
29 | from __future__ import division
30 | from __future__ import print_function
31 |
32 | import abc
33 |
34 | # Dependency imports
35 | from absl import logging
36 | from mathematics_dataset.sample import number
37 | from mathematics_dataset.util import display
38 | import numpy as np
39 | import six
40 | from six.moves import zip
41 | import sympy
42 |
43 |
44 | MUL_SYMBOL = '*'
45 | DIV_SYMBOL = '/'
46 | POW_SYMBOL = '**'
47 | GT_SYMBOL = '>'
48 | LT_SYMBOL = '<'
49 | GE_SYMBOL = '>='
50 | LE_SYMBOL = '<='
51 | EQ_SYMBOL = '='
52 | NE_SYMBOL = '!='
53 |
54 |
55 | # Operator precedence levels. Used to insert brackets if necessary.
56 | _EQ_PRECEDENCE = 0
57 | _CONSTANT_PRECEDENCE = 1
58 | _POW_PRECEDENCE = 2
59 | _SQRT_PRECEDENCE = 3
60 | _MUL_PRECEDENCE = 4
61 | _ADD_PRECEDENCE = 5
62 |
63 |
64 | def bracketed(child, parent, bracket_if_same_precedence):
65 | """Returns string representation of `child`, possibly bracketed.
66 |
67 | Args:
68 | child: Instance of `Op` or a valid value for `ConstantOp`.
69 | parent: Instance of `Op`. Used to determine whether `child` needs to be
70 | bracketed first before appearing in the parent op's expression.
71 | bracket_if_same_precedence: Whether to bracket if the child has the same
72 | operator precedence as the parent.
73 |
74 | Returns:
75 | String representation of `child`.
76 | """
77 | if not isinstance(child, Op):
78 | child = Constant(child)
79 |
80 | child_precedence = child.precedence
81 | parent_precedence = parent.precedence
82 | if (parent_precedence > child_precedence
83 | or (parent_precedence == child_precedence
84 | and not bracket_if_same_precedence)):
85 | return str(child)
86 | else:
87 | return '({})'.format(child)
88 |
89 |
90 | def _flatten(iterable):
91 | """Returns list."""
92 | if isinstance(iterable, (list, tuple)):
93 | result = list(iterable)
94 | else:
95 | assert isinstance(iterable, dict)
96 | keys = sorted(six.iterkeys(iterable))
97 | result = [iterable[key] for key in keys]
98 | # Check we don't have any hierarchy in the structure (otherwise would need
99 | # to use something recursive like tf.contrib.framework.nest.flatten).
100 | for item in result:
101 | assert not isinstance(item, (list, tuple, dict))
102 | return result
103 |
104 |
105 | def _pack_sequence_as(example, flat):
106 | if isinstance(example, list) or isinstance(example, tuple):
107 | return flat
108 | else:
109 | assert isinstance(example, dict)
110 | keys = sorted(six.iterkeys(example))
111 | return {key: value for key, value in zip(keys, flat)}
112 |
113 |
114 | @six.add_metaclass(abc.ABCMeta)
115 | class Op(object):
116 | """An operation.
117 |
118 | This needs to support being transformed into sympy (and possibly in the future
119 | other types such as an appropriately formatted string), when given the op
120 | arguments.
121 | """
122 |
123 | def __init__(self, children):
124 | """Initialize this `Op` base class.
125 |
126 | Args:
127 | children: Iterable structure containing child ops.
128 | """
129 | assert isinstance(children, (list, dict, tuple))
130 | flat_children = _flatten(children)
131 | flat_children = [child if isinstance(child, Op) else Constant(child)
132 | for child in flat_children]
133 | children = _pack_sequence_as(children, flat_children)
134 | self._children = children
135 |
136 | @property
137 | def children(self):
138 | """Returns iterable or dict over immediate children."""
139 | return self._children
140 |
141 | def descendants(self):
142 | """Returns list of all descendants (self, children, grandchildren, etc)."""
143 | descendants = [self]
144 | flat_children = _flatten(self._children)
145 | for child in flat_children:
146 | descendants += child.descendants()
147 | return descendants
148 |
149 | @abc.abstractmethod
150 | def __str__(self):
151 | """Returns a string format of this op."""
152 |
153 | @abc.abstractmethod
154 | def sympy(self):
155 | """Returns the sympification of this op."""
156 |
157 | def _sympy_(self):
158 | """Convenience method to automatically sympify this object."""
159 | try:
160 | return self.sympy()
161 | except AttributeError as e:
162 | # Note: we print this error here, before raising it again, because sympy
163 | # will think `AttributeError` refers to this object not having a `_sympy_`
164 | # method, rather than having it, which leads to otherwise confusing error
165 | # messages.
166 | logging.error(
167 | 'Encountered attribute error while trying to sympify: %s', e)
168 | raise e
169 |
170 | @abc.abstractproperty
171 | def precedence(self):
172 | """Returns the precedence (integer) of this op."""
173 |
174 |
175 | class Constant(Op):
176 | """Returns a constant value; a nullary op."""
177 |
178 | def __init__(self, value):
179 | super(Constant, self).__init__([])
180 | if isinstance(value, six.integer_types):
181 | value = sympy.Integer(value)
182 | self._value = value
183 |
184 | def __str__(self):
185 | return str(self._value)
186 |
187 | def sympy(self):
188 | return self._value
189 |
190 | @property
191 | def value(self):
192 | return self._value
193 |
194 | @value.setter
195 | def value(self, value):
196 | self._value = value
197 |
198 | def _is_simple(self):
199 | """Returns whether it's a simple number, rather than a division or neg."""
200 | if isinstance(self._value, sympy.Symbol):
201 | return True
202 | elif (isinstance(self._value, int)
203 | or isinstance(self._value, sympy.Integer)
204 | or isinstance(self._value, display.Decimal)
205 | or isinstance(self._value, np.int64)
206 | or isinstance(self._value, np.int32)):
207 | return self._value >= 0
208 | elif isinstance(self._value, sympy.Rational):
209 | return False
210 | elif isinstance(self._value, sympy.Function):
211 | return True
212 | else:
213 | raise ValueError('Unknown type {}'.format(type(self._value)))
214 |
215 | @property
216 | def precedence(self):
217 | if self._is_simple():
218 | return _CONSTANT_PRECEDENCE
219 | else:
220 | return _MUL_PRECEDENCE
221 |
222 |
223 | class _SumLikeOp(Op):
224 | """Abstract op for sum-like terms which may contain negative entries."""
225 |
226 | @abc.abstractmethod
227 | def expanded_signs_and_terms(self):
228 | """Returns a list of arguments, plus any sub-arguments from sub-adds.
229 |
230 | E.g., if this op is `Add(Add(2, Neg(3)), Mul(4, 5), 1)`, then will return
231 | `[(True, 2), (False, 3), (True, Mul(4, 5)), (True, 1)]` (the arguments of
232 | the inner add have been extracted).
233 | """
234 |
235 | def __str__(self):
236 | signs_and_terms = self.expanded_signs_and_terms()
237 | if not signs_and_terms:
238 | return '0'
239 | for i, (sign, term) in enumerate(signs_and_terms):
240 | if i == 0:
241 | if sign:
242 | expression = bracketed(term, self, True)
243 | else:
244 | expression = '-' + bracketed(term, self, True)
245 | else:
246 | if sign:
247 | expression += ' + ' + bracketed(term, self, True)
248 | else:
249 | expression += ' - ' + bracketed(term, self, True)
250 | return expression
251 |
252 |
253 | class Identity(_SumLikeOp):
254 | """The identity op (a unary op)."""
255 |
256 | def __init__(self, input_):
257 | super(Identity, self).__init__({'input': input_})
258 |
259 | def expanded_signs_and_terms(self):
260 | if isinstance(self.children['input'], _SumLikeOp):
261 | return self.children['input'].expanded_signs_and_terms()
262 | else:
263 | return [(True, self.children['input'])]
264 |
265 | def __str__(self):
266 | return str(self.children['input'])
267 |
268 | def sympy(self):
269 | return self.children['input'].sympy()
270 |
271 | @property
272 | def precedence(self):
273 | return self.children['input'].precedence
274 |
275 |
276 | class Neg(_SumLikeOp):
277 | """Negation, a unary op. Also has special display when appearing in a sum."""
278 |
279 | def __init__(self, arg):
280 | super(Neg, self).__init__({'input': arg})
281 |
282 | def expanded_signs_and_terms(self):
283 | if isinstance(self.children['input'], _SumLikeOp):
284 | inner_signs_and_terms = self.children['input'].expanded_signs_and_terms()
285 | return [(not sign, term) for (sign, term) in inner_signs_and_terms]
286 | else:
287 | return [(False, self.children['input'])]
288 |
289 | def sympy(self):
290 | return -sympy.sympify(self.children['input'])
291 |
292 | def inner(self):
293 | return self.children['input']
294 |
295 | @property
296 | def precedence(self):
297 | return _ADD_PRECEDENCE
298 |
299 |
300 | class Add(_SumLikeOp):
301 | """Addition."""
302 |
303 | def __init__(self, *args):
304 | super(Add, self).__init__(args)
305 |
306 | def expanded_signs_and_terms(self):
307 | """Returns a list of (sign, term) pairs, plus any pairs from sub-adds.
308 |
309 | E.g., if this op is `Add(Add(2, 3), Mul(4, 5), 1)`, then will return
310 | `[(True, 2), (True, 3), (True, Mul(4, 5)), (True, 1)]` (the arguments of the inner add have been extracted).
311 | """
312 | expanded = []
313 | for arg in self.children:
314 | if isinstance(arg, _SumLikeOp):
315 | expanded += arg.expanded_signs_and_terms()
316 | else:
317 | expanded.append((True, arg))
318 | return expanded
319 |
320 | def sympy(self):
321 | return sympy.Add(*[sympy.sympify(arg) for arg in self.children])
322 |
323 | @property
324 | def precedence(self):
325 | return _ADD_PRECEDENCE
326 |
327 |
328 | class Sub(Op):
329 | """Subtraction."""
330 |
331 | def __init__(self, left, right):
332 | super(Sub, self).__init__({'left': left, 'right': right})
333 |
334 | def __str__(self):
335 | return (bracketed(self.children['left'], self, False) + ' - '
336 | + bracketed(self.children['right'], self, True))
337 |
338 | def sympy(self):
339 | return sympy.Add(
340 | self.children['left'], sympy.Mul(-1, self.children['right']))
341 |
342 | @property
343 | def precedence(self):
344 | return _ADD_PRECEDENCE
345 |
346 |
347 | class Mul(Op):
348 | """Multiplication."""
349 |
350 | def __init__(self, *args):
351 | super(Mul, self).__init__(args)
352 |
353 | def __str__(self):
354 | if not self.children:
355 | return '1'
356 | else:
357 | args = [bracketed(arg, self, False) for arg in self.children]
358 | return MUL_SYMBOL.join(args)
359 |
360 | def sympy(self):
361 | return sympy.Mul(*[sympy.sympify(arg) for arg in self.children])
362 |
363 | @property
364 | def precedence(self):
365 | return _MUL_PRECEDENCE
366 |
367 |
368 | class Div(Op):
369 | """Division."""
370 |
371 | def __init__(self, numer, denom):
372 | super(Div, self).__init__({'numer': numer, 'denom': denom})
373 |
374 | def __str__(self):
375 | return u'{}{}{}'.format(
376 | bracketed(self.children['numer'], self, True), DIV_SYMBOL,
377 | bracketed(self.children['denom'], self, True))
378 |
379 | def sympy(self):
380 | return sympy.Mul(
381 | self.children['numer'], sympy.Pow(self.children['denom'], -1))
382 |
383 | @property
384 | def precedence(self):
385 | return _MUL_PRECEDENCE
386 |
387 |
388 | class Pow(Op):
389 | """Raises `a` to the power `b`."""
390 |
391 | def __init__(self, a, b):
392 | super(Pow, self).__init__({'a': a, 'b': b})
393 |
394 | def __str__(self):
395 | return u'{}{}{}'.format(
396 | bracketed(self.children['a'], self, True), POW_SYMBOL,
397 | bracketed(self.children['b'], self, True))
398 |
399 | def sympy(self):
400 | return sympy.Pow(
401 | sympy.sympify(self.children['a']), sympy.sympify(self.children['b']))
402 |
403 | @property
404 | def precedence(self):
405 | return _POW_PRECEDENCE
406 |
407 |
408 | class Sqrt(Op):
409 | """Square root of a value."""
410 |
411 | def __init__(self, a):
412 | super(Sqrt, self).__init__({'a': a})
413 |
414 | def __str__(self):
415 | return 'sqrt({})'.format(self.children['a'])
416 |
417 | def sympy(self):
418 | return sympy.sqrt(self.children['a'])
419 |
420 | @property
421 | def precedence(self):
422 | return _POW_PRECEDENCE
423 |
424 |
425 | class Eq(Op):
426 | """Equality."""
427 |
428 | def __init__(self, left, right):
429 | super(Eq, self).__init__({'left': left, 'right': right})
430 |
431 | def __str__(self):
432 | return '{} = {}'.format(self.children['left'], self.children['right'])
433 |
434 | def sympy(self):
435 | return sympy.Eq(self.children['left'], self.children['right'])
436 |
437 | @property
438 | def precedence(self):
439 | return _EQ_PRECEDENCE
440 |
441 |
442 | def number_constants(expressions):
443 | """Returns list of integer, rational, decimal constants in the expressions."""
444 | if isinstance(expressions, Op):
445 | expressions = [expressions]
446 | descendants = []
447 | for expression in expressions:
448 | descendants += expression.descendants()
449 | candidate_constants = [op for op in descendants if isinstance(op, Constant)]
450 | return [constant for constant in candidate_constants
451 | if number.is_integer_or_rational_or_decimal(constant.value)]
452 |
--------------------------------------------------------------------------------
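The `expanded_signs_and_terms` methods above flatten nested `Add` and `Neg` ops into a list of `(sign, term)` pairs, which is what lets `Add(3, Neg(2))` print as `3 - 2` instead of `3 + -2`. Below is a minimal stand-alone sketch of that idea. The tuple-based tree encoding and the `expand` and `render` helpers are hypothetical and not part of `ops.py`, which additionally handles `Identity`, bracketing and precedence.

```python
from typing import List, Tuple, Union

# Hypothetical tree encoding (not the ops.py classes): an int is a leaf,
# ('add', [children]) is a sum, and ('neg', child) is a negation.
Node = Union[int, tuple]


def expand(node: Node) -> List[Tuple[bool, int]]:
    """Flattens nested sums/negations into (is_positive, leaf) pairs."""
    if isinstance(node, tuple) and node[0] == 'add':
        pairs = []
        for child in node[1]:
            pairs += expand(child)  # splice in sub-sums, keeping their signs
        return pairs
    if isinstance(node, tuple) and node[0] == 'neg':
        # Negating a sum flips the sign of every term inside it.
        return [(not sign, term) for sign, term in expand(node[1])]
    return [(True, node)]


def render(node: Node) -> str:
    """Joins expanded terms, writing ' - ' before negative non-leading terms."""
    pairs = expand(node)
    if not pairs:
        return '0'  # an empty sum, matching str(ops.Add()) in the tests
    first_sign, first_term = pairs[0]
    out = ('' if first_sign else '-') + str(first_term)
    for sign, term in pairs[1:]:
        out += (' + ' if sign else ' - ') + str(term)
    return out


print(render(('add', [3, ('neg', 2)])))                # 3 - 2
print(render(('add', [3, ('add', [2, ('neg', 5)])])))  # 3 + 2 - 5
print(render(('neg', ('add', [1, 2]))))                # -1 - 2
```

Carrying the sign separately from the term means a double negation cancels during flattening rather than being printed as `--4`.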
/mathematics_dataset/sample/ops_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for mathematics_dataset.sample.ops."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | # Dependency imports
22 | from absl.testing import absltest
23 | from mathematics_dataset.sample import ops
24 | from six.moves import range
25 | import sympy
26 |
27 |
28 | class OpsTest(absltest.TestCase):
29 |
30 | def testNeg(self):
31 | op = ops.Neg(2)
32 | self.assertEqual(str(op), '-2')
33 | self.assertEqual(op.sympy(), -2)
34 |
35 | op = ops.Add(ops.Neg(2), 3)
36 | self.assertEqual(str(op), '-2 + 3')
37 | self.assertEqual(op.sympy(), 1)
38 |
39 | op = ops.Add(3, ops.Neg(2))
40 | self.assertEqual(str(op), '3 - 2')
41 | self.assertEqual(op.sympy(), 1)
42 |
43 | op = ops.Add(ops.Add(ops.Neg(2), 5), 3)
44 | self.assertEqual(str(op), '-2 + 5 + 3')
45 | self.assertEqual(op.sympy(), 6)
46 |
47 | op = ops.Add(3, ops.Add(ops.Identity(ops.Neg(2)), 5))
48 | self.assertEqual(str(op), '3 - 2 + 5')
49 | self.assertEqual(op.sympy(), 6)
50 |
51 | op = ops.Add(3, ops.Add(2, ops.Neg(5)))
52 | self.assertEqual(str(op), '3 + 2 - 5')
53 | self.assertEqual(op.sympy(), 0)
54 |
55 | def testAdd(self):
56 | add = ops.Add()
57 | self.assertEqual(str(add), '0')
58 | self.assertEqual(add.sympy(), 0)
59 |
60 | add = ops.Add(2, 3)
61 | self.assertEqual(str(add), '2 + 3')
62 | self.assertEqual(add.sympy(), 5)
63 |
64 | add = ops.Add(ops.Add(1, 2), 3)
65 | self.assertEqual(str(add), '1 + 2 + 3')
66 | self.assertEqual(add.sympy(), 6)
67 |
68 | def testSub(self):
69 | sub = ops.Sub(2, 3)
70 | self.assertEqual(str(sub), '2 - 3')
71 | self.assertEqual(sub.sympy(), -1)
72 |
73 | sub = ops.Sub(ops.Sub(1, 2), 3)
74 | self.assertEqual(str(sub), '1 - 2 - 3')
75 | self.assertEqual(sub.sympy(), -4)
76 |
77 | sub = ops.Sub(1, ops.Sub(2, 3))
78 | self.assertEqual(str(sub), '1 - (2 - 3)')
79 | self.assertEqual(sub.sympy(), 2)
80 |
81 | sub = ops.Sub(ops.Neg(1), 2)
82 | self.assertEqual(str(sub), '-1 - 2')
83 | self.assertEqual(sub.sympy(), -3)
84 |
85 | def testMul(self):
86 | mul = ops.Mul()
87 | self.assertEqual(str(mul), '1')
88 | self.assertEqual(mul.sympy(), 1)
89 |
90 | mul = ops.Mul(2, 3)
91 | self.assertEqual(str(mul), '2*3')
92 | self.assertEqual(mul.sympy(), 6)
93 |
94 | mul = ops.Mul(ops.Identity(ops.Constant(-2)), 3)
95 | self.assertEqual(str(mul), '-2*3')
96 | self.assertEqual(mul.sympy(), -6)
97 |
98 | mul = ops.Mul(ops.Add(1, 2), 3)
99 | self.assertEqual(str(mul), '(1 + 2)*3')
100 | self.assertEqual(mul.sympy(), 9)
101 |
102 | mul = ops.Mul(ops.Mul(2, 3), 5)
103 | self.assertEqual(str(mul), '2*3*5')
104 | self.assertEqual(mul.sympy(), 30)
105 |
106 | # TODO(b/124038946): reconsider how we want brackets in these cases:
107 | # mul = ops.Mul(ops.Div(2, 3), 5)
108 | # self.assertEqual(str(mul), '(2/3)*5')
109 | # self.assertEqual(mul.sympy(), sympy.Rational(10, 3))
110 | #
111 | # mul = ops.Mul(sympy.Rational(2, 3), 5)
112 | # self.assertEqual(str(mul), '(2/3)*5')
113 | # self.assertEqual(mul.sympy(), sympy.Rational(10, 3))
114 |
115 | def testDiv(self):
116 | div = ops.Div(2, 3)
117 | self.assertEqual(str(div), '2/3')
118 | self.assertEqual(div.sympy(), sympy.Rational(2, 3))
119 |
120 | div = ops.Div(2, sympy.Rational(4, 5))
121 | self.assertEqual(str(div), '2/(4/5)')
122 | self.assertEqual(div.sympy(), sympy.Rational(5, 2))
123 |
124 | div = ops.Div(1, ops.Div(2, 3))
125 | self.assertEqual(str(div), '1/(2/3)')
126 | self.assertEqual(div.sympy(), sympy.Rational(3, 2))
127 |
128 | div = ops.Div(ops.Div(2, 3), 4)
129 | self.assertEqual(str(div), '(2/3)/4')
130 | self.assertEqual(div.sympy(), sympy.Rational(1, 6))
131 |
132 | div = ops.Div(2, ops.Mul(3, 4))
133 | self.assertEqual(str(div), '2/(3*4)')
134 |
135 | div = ops.Div(2, sympy.Function('f')(sympy.Symbol('x')))
136 | self.assertEqual(str(div), '2/f(x)')
137 |
138 | def testPow(self):
139 | pow_ = ops.Pow(2, 3)
140 | self.assertEqual(str(pow_), '2**3')
141 | self.assertEqual(pow_.sympy(), 8)
142 |
143 | pow_ = ops.Pow(4, sympy.Rational(1, 2))
144 | self.assertEqual(str(pow_), '4**(1/2)')
145 | self.assertEqual(pow_.sympy(), 2)
146 |
147 | pow_ = ops.Pow(sympy.Rational(1, 2), 3)
148 | self.assertEqual(str(pow_), '(1/2)**3')
149 | self.assertEqual(pow_.sympy(), sympy.Rational(1, 8))
150 |
151 | pow_ = ops.Pow(3, ops.Pow(2, 1))
152 | self.assertEqual(str(pow_), '3**(2**1)')
153 | self.assertEqual(pow_.sympy(), 9)
154 |
155 | pow_ = ops.Pow(ops.Pow(2, 3), 4)
156 | self.assertEqual(str(pow_), '(2**3)**4')
157 | self.assertEqual(pow_.sympy(), 4096)
158 |
159 | pow_ = ops.Pow(-5, 2)
160 | self.assertEqual(str(pow_), '(-5)**2')
161 | self.assertEqual(pow_.sympy(), 25)
162 |
163 | def testEq(self):
164 | op = ops.Eq(ops.Add(2, 3), 4)
165 | self.assertEqual(str(op), '2 + 3 = 4')
166 | self.assertEqual(op.sympy(), False)
167 |
168 | def testDescendants(self):
169 | constants = [ops.Constant(i) for i in range(6)]
170 |
171 | # (0 + 1*2**3) / 4 - 5
172 | expression = ops.Sub(
173 | ops.Div(
174 | ops.Add(
175 | constants[0],
176 | ops.Mul(
177 | constants[1],
178 | ops.Pow(
179 | constants[2],
180 | constants[3]))),
181 | constants[4]),
182 | constants[5])
183 | descendants = expression.descendants()
184 | descendants = ops._flatten(descendants)
185 |
186 | for constant in constants:
187 | self.assertIn(constant, descendants)
188 | self.assertEqual(descendants.count(constant), 1)
189 |
190 | # Also test top-level.
191 | self.assertEqual(constants[0].descendants(), [constants[0]])
192 |
193 | # Descendants include the op itself as well as its children.
194 | constant = ops.Constant(3)
195 | expression = ops.Neg(constant)
196 | self.assertEqual(set(expression.descendants()), set([constant, expression]))
197 |
198 | def testNumberConstants(self):
199 | constant = ops.Constant(3)
200 | expression = ops.Neg(constant)
201 | constants = ops.number_constants([expression])
202 | self.assertEqual(constants, [constant])
203 |
204 |
205 | if __name__ == '__main__':
206 | absltest.main()
207 |
--------------------------------------------------------------------------------
/mathematics_dataset/sample/polynomials_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for mathematics_dataset.sample.polynomials."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import random
22 |
23 | # Dependency imports
24 | from absl.testing import parameterized
25 | from mathematics_dataset.sample import polynomials
26 | import numpy as np
27 | from six.moves import range
28 | import sympy
29 | import tensorflow as tf
30 |
31 |
32 | class ExpressionWithValueTest(tf.test.TestCase, parameterized.TestCase):
33 |
34 | def testSplitValueEqually(self):
35 | split = polynomials._split_value_equally(3, 2)
36 | self.assertEqual(split, [1, 2])
37 | split = polynomials._split_value_equally(sympy.sympify('3/4'), 2)
38 | self.assertEqual(split, [sympy.sympify('1/4'), sympy.sympify('1/2')])
39 |
40 | def testIntegersWithSum(self):
41 | value = 13
42 | count = 10
43 | terms = polynomials.integers_with_sum(value=value, count=count, entropy=4.0)
44 | self.assertLen(terms, count)
45 | self.assertEqual(sum(terms), value)
46 |
47 | def testMonomial(self):
48 | x, y = sympy.symbols('x y')
49 | self.assertEqual(str(polynomials.monomial(1, [x, y], [2, 3])), 'x**2*y**3')
50 | # TODO(b/124038530): how to handle rational coefficients; are they even used?
51 | # self.assertEqual(
52 | # str(polynomials.monomial(sympy.Rational(2, 3), [x], [1])), '2*x/3')
53 | # self.assertEqual(
54 | # str(polynomials.monomial(sympy.Rational(1, 3), [x], [1])), 'x/3')
55 | self.assertEqual(str(polynomials.monomial(x, [y], [4])), 'x*y**4')
56 |
57 | def testExpandCoefficients(self):
58 | for _ in range(10):
59 | num_variables = np.random.randint(1, 4)
60 | degrees = np.random.randint(0, 4, [num_variables])
61 | coefficients = np.random.randint(-3, 3, degrees + 1)
62 | entropy = np.random.uniform(0, 10)
63 | expanded = polynomials.expand_coefficients(coefficients, entropy)
64 | collapsed = np.vectorize(sum)(expanded)
65 | self.assertAllEqual(coefficients, collapsed)
66 |
67 | def testCoefficientsToPolynomial(self):
68 | coeffs = [3, 2, 1]
69 | x = sympy.Symbol('x')
70 | polynomial = polynomials.coefficients_to_polynomial(coeffs, [x])
71 | polynomial = sympy.sympify(polynomial)
72 | self.assertEqual(polynomial, x*x + 2*x + 3)
73 |
74 | def testUnivariate(self):
75 | # Test generation for: 3*x**2 + 2*x + 1
76 | x = sympy.Symbol('x')
77 | coeffs = [1, 2, 3]
78 | for _ in range(10):
79 | expanded = polynomials.expand_coefficients(coeffs, 5.0)
80 | polynomial = polynomials.coefficients_to_polynomial(expanded, [x])
81 | sympified = sympy.sympify(polynomial)
82 | self.assertEqual(sympified, 1 + 2*x + 3*x*x)
83 |
84 | def testMultivariate(self):
85 | # Test generation for: x**2 + 2*x*y + 3*y**2 - x + 5
86 | x, y = sympy.symbols('x y')
87 | coeffs = [[5, 0, 3], [-1, 2, 0], [1, 0, 0]]
88 | for _ in range(10):
89 | expanded = polynomials.expand_coefficients(coeffs, 5.0, length=10)
90 | polynomial = polynomials.coefficients_to_polynomial(expanded, [x, y])
91 | sympified = sympy.sympify(polynomial)
92 | self.assertEqual(sympified, x*x + 2*x*y + 3*y*y - x + 5)
93 |
94 | def testAddCoefficients(self):
95 | # Add x**2 + 2*y and 3*x + 4*y**3.
96 | coeffs1 = [[0, 2], [0, 0], [1, 0]]
97 | coeffs2 = [[0, 0, 0, 4], [3, 0, 0, 0]]
98 | target = [[0, 2, 0, 4], [3, 0, 0, 0], [1, 0, 0, 0]]
99 | actual = polynomials.add_coefficients(coeffs1, coeffs2)
100 | self.assertAllEqual(target, actual)
101 |
102 | def testCoefficientsLinearSplit(self):
103 | for degree in range(3):
104 | for ndims in range(3):
105 | for _ in range(10):
106 | coefficients = np.random.randint(-5, 5, [degree + 1] * ndims)
107 | entropy = random.uniform(1, 4)
108 | c1, c2, coeffs1, coeffs2 = polynomials.coefficients_linear_split(
109 | coefficients, entropy)
110 | c1 = int(c1)
111 | c2 = int(c2)
112 | coeffs1 = np.asarray(coeffs1, dtype=np.int32)
113 | coeffs2 = np.asarray(coeffs2, dtype=np.int32)
114 | sum_ = c1 * coeffs1 + c2 * coeffs2
115 | self.assertAllEqual(sum_, coefficients)
116 |
117 | def testSampleWithBrackets(self):
118 | x, y = sympy.symbols('x y')
119 | for _ in range(100):
120 | degrees = np.random.randint(1, 4, [2])
121 | entropy = random.uniform(0, 4)
122 | polynomial = polynomials.sample_with_brackets(
123 | variables=[x, y], degrees=degrees, entropy=entropy)
124 | self.assertIn('(', str(polynomial))
125 | poly = sympy.poly(sympy.sympify(polynomial).expand())
126 | self.assertEqual(poly.degree(x), degrees[0])
127 | self.assertEqual(poly.degree(y), degrees[1])
128 |
129 | def testTrim(self):
130 | self.assertAllEqual(polynomials.trim([1]), [1])
131 | self.assertAllEqual(polynomials.trim([1, 0]), [1])
132 | self.assertAllEqual(polynomials.trim([0, 1]), [0, 1])
133 | self.assertAllEqual(polynomials.trim([0]), [])
134 | self.assertAllEqual(polynomials.trim([0, 0]), [])
135 |
136 | def testDifferentiate_univariate(self):
137 | coeffs = [5, 3, 2]
138 | expected = [3, 4]
139 | actual = polynomials.differentiate(coeffs, 0)
140 | self.assertAllEqual(expected, actual)
141 |
142 | def testDifferentiate_multivariate(self):
143 | coeffs = [[0, 3, 1], [5, 0, 0], [0, 2, 0]]
144 | expected = [[5, 0], [0, 4]]
145 | actual = polynomials.differentiate(coeffs, 0)
146 | self.assertAllEqual(expected, actual)
147 |
148 | def testIntegrate_univariate(self):
149 | coeffs = [5, 3, 2]
150 | expected = [0, 5, sympy.Rational(3, 2), sympy.Rational(2, 3)]
151 | actual = polynomials.integrate(coeffs, 0)
152 | self.assertAllEqual(expected, actual)
153 |
154 | def testIntegrate_multivariate(self):
155 | coeffs = [[0, 1], [1, 0]]
156 | expected = [[0, 0, sympy.Rational(1, 2)], [0, 1, 0]]
157 | actual = polynomials.integrate(coeffs, 1)
158 | self.assertAllEqual(expected, actual)
159 |
160 |
161 | if __name__ == '__main__':
162 | tf.test.main()
163 |
--------------------------------------------------------------------------------
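The polynomial tests above store coefficients in nested arrays where `coeffs[i][j]` is the coefficient of `x**i * y**j`. Under that convention, differentiating with respect to one variable is an array operation: scale each slice by its exponent, drop the constant slice, then trim all-zero slices from the high-degree end. The sketch below is a NumPy re-derivation of the expected values in `testDifferentiate_*` and `testTrim`, not the library's `polynomials.differentiate`; `differentiate_coeffs` and `trim_trailing_zeros` are hypothetical names.

```python
import numpy as np


def differentiate_coeffs(coeffs, axis):
    """Differentiates a coefficient array w.r.t. the variable on `axis`.

    Assumes coeffs[i][j]... is the coefficient of x0**i * x1**j * ...
    """
    c = np.asarray(coeffs)
    # Exponents 0, 1, 2, ... of the chosen variable, shaped to broadcast
    # along `axis` only.
    shape = [1] * c.ndim
    shape[axis] = c.shape[axis]
    powers = np.arange(c.shape[axis]).reshape(shape)
    # d/dx x**i = i * x**(i - 1): scale each slice, then drop the i = 0 slice.
    return np.delete(c * powers, 0, axis=axis)


def trim_trailing_zeros(c):
    """Drops all-zero slices from the high-degree end of every axis."""
    c = np.asarray(c)
    for axis in range(c.ndim):
        while c.shape[axis] > 0:
            last = c.shape[axis] - 1
            if np.any(np.take(c, [last], axis=axis)):
                break
            c = np.delete(c, last, axis=axis)
    return c


# 3*y + y**2 + 5*x + 2*x**2*y, as in testDifferentiate_multivariate.
coeffs = [[0, 3, 1], [5, 0, 0], [0, 2, 0]]
print(trim_trailing_zeros(differentiate_coeffs(coeffs, 0)).tolist())  # [[5, 0], [0, 4]]
```

The result `[[5, 0], [0, 4]]` reads as `5 + 4*x*y`, the derivative with respect to `x`.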
/mathematics_dataset/util/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/mathematics_dataset/util/combinatorics.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Combinatorics utility functions."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import math
22 | import random
23 |
24 | # Dependency imports
25 | from six.moves import range
26 | from six.moves import zip
27 |
28 |
29 | def uniform_positive_integers_with_sum(count, sum_):
30 | """Returns list of size `count` of integers >= 1, summing to `sum_`."""
31 | assert sum_ >= 0
32 | if count > sum_:
33 | raise ValueError('Cannot find {} numbers >= 1 with sum {}'
34 | .format(count, sum_))
35 | if count == 0:
36 | return []
37 | # Select `count - 1` numbers from {1, ..., sum_ - 1}
38 | separators = random.sample(list(range(1, sum_)), count - 1)
39 | separators = sorted(separators)
40 | return [right - left
41 | for left, right in zip([0] + separators, separators + [sum_])]
42 |
43 |
44 | def uniform_non_negative_integers_with_sum(count, sum_):
45 | """Returns list of size `count` of integers >= 0, summing to `sum_`."""
46 | positive = uniform_positive_integers_with_sum(count, sum_ + count)
47 | return [i - 1 for i in positive]
48 |
49 |
50 | def log_number_binary_trees(size):
51 | """Returns (nat) log of number of binary trees with `size` internal nodes."""
52 | # This is equal to log of C_size, where C_n is the nth Catalan number.
53 | assert isinstance(size, int)
54 | assert size >= 0
55 | log = 0.0
56 | for k in range(2, size + 1):
57 | log += math.log(size + k) - math.log(k)
58 | return log
59 |
--------------------------------------------------------------------------------
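`uniform_positive_integers_with_sum` is a stars-and-bars construction: choosing `count - 1` distinct cut points from `{1, ..., sum_ - 1}` splits `sum_` into `count` positive parts, and each composition corresponds to exactly one set of cut points, so sampling the cut points uniformly samples compositions uniformly. `uniform_non_negative_integers_with_sum` reuses it by adding one to every part before splitting and subtracting it afterwards. The sketch below enumerates the cut-point sets for a small case to illustrate the bijection; `compositions_via_cuts` is a hypothetical helper and assumes `count >= 1`.

```python
import itertools


def compositions_via_cuts(count, total):
    """Lists every composition of `total` into `count` positive parts.

    Each (count - 1)-subset of {1, ..., total - 1} is a set of cut points;
    the gaps between consecutive cuts (plus the ends 0 and total) are the
    parts. Assumes count >= 1.
    """
    result = []
    for cuts in itertools.combinations(range(1, total), count - 1):
        bounds = [0] + list(cuts) + [total]
        result.append(tuple(b - a for a, b in zip(bounds, bounds[1:])))
    return result


parts = compositions_via_cuts(3, 5)
print(len(parts))  # 6, i.e. C(4, 2)
print(parts)
```

Every tuple is distinct and sums to 5, so a uniform choice of cut set (what `random.sample` provides in the library) is a uniform choice of composition.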
/mathematics_dataset/util/combinatorics_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests mathematics_dataset.util.combinatorics."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import math
22 |
23 | # Dependency imports
24 | from absl.testing import absltest
25 | from mathematics_dataset.util import combinatorics
26 |
27 |
28 | class CombinatoricsTest(absltest.TestCase):
29 |
30 | def testPositiveIntegersWithSum(self):
31 | result = combinatorics.uniform_positive_integers_with_sum(1, 1)
32 | self.assertEqual(result, [1])
33 | result = combinatorics.uniform_positive_integers_with_sum(2, 2)
34 | self.assertEqual(result, [1, 1])
35 | result = combinatorics.uniform_positive_integers_with_sum(1, 10)
36 | self.assertEqual(sum(result), 10)
37 | result = combinatorics.uniform_positive_integers_with_sum(2, 10)
38 | self.assertEqual(sum(result), 10)
39 | result = combinatorics.uniform_positive_integers_with_sum(0, 0)
40 | self.assertEqual(result, [])
41 |
42 | def testNonNegativeIntegersWithSum(self):
43 | result = combinatorics.uniform_non_negative_integers_with_sum(1, 0)
44 | self.assertEqual(result, [0])
45 | result = combinatorics.uniform_non_negative_integers_with_sum(2, 0)
46 | self.assertEqual(result, [0, 0])
47 | result = combinatorics.uniform_non_negative_integers_with_sum(3, 10)
48 | self.assertEqual(sum(result), 10)
49 |
50 | def testLogNumberBinaryTrees(self):
51 | self.assertAlmostEqual(
52 | combinatorics.log_number_binary_trees(0), math.log(1))
53 | self.assertAlmostEqual(
54 | combinatorics.log_number_binary_trees(1), math.log(1))
55 | self.assertAlmostEqual(
56 | combinatorics.log_number_binary_trees(2), math.log(2))
57 | self.assertAlmostEqual(
58 | combinatorics.log_number_binary_trees(3), math.log(5))
59 | self.assertAlmostEqual(
60 | combinatorics.log_number_binary_trees(4), math.log(14))
61 |
62 |
63 | if __name__ == '__main__':
64 | absltest.main()
65 |
--------------------------------------------------------------------------------
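The expected values in `testLogNumberBinaryTrees` are the Catalan numbers 1, 1, 2, 5, 14. `log_number_binary_trees` accumulates, in log space, the product form `C_n = prod_{k=2}^{n} (n + k) / k`, which equals the closed form `C_n = binom(2n, n) / (n + 1)`. The sketch below cross-checks the two forms exactly with integer and rational arithmetic (it assumes Python 3.8+ for `math.comb`); `catalan_closed_form` and `catalan_product_form` are hypothetical names.

```python
import math
from fractions import Fraction


def catalan_closed_form(n):
    """C_n = binom(2n, n) / (n + 1), computed exactly."""
    return math.comb(2 * n, n) // (n + 1)


def catalan_product_form(n):
    """C_n = prod_{k=2}^{n} (n + k) / k, the product summed in log space above."""
    value = Fraction(1)
    for k in range(2, n + 1):
        value *= Fraction(n + k, k)
    return value


for n in range(8):
    assert catalan_product_form(n) == catalan_closed_form(n)
    # Summing logs of the same factors, as log_number_binary_trees does.
    log_value = sum(math.log(n + k) - math.log(k) for k in range(2, n + 1))
    assert math.isclose(log_value, math.log(catalan_closed_form(n)))

print([catalan_closed_form(n) for n in range(8)])  # [1, 1, 2, 5, 14, 42, 132, 429]
```

Working in log space keeps the library's value a small float even for large `size`, where `C_n` itself grows roughly like `4**n`.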
/mathematics_dataset/util/composition_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for mathematics_dataset.util.composition."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | # Dependency imports
22 | from absl.testing import absltest
23 | from mathematics_dataset.util import composition
24 | import sympy
25 |
26 |
27 | class FunctionHandleTest(absltest.TestCase):
28 |
29 | def testApply(self):
30 | handle = composition.FunctionHandle('f', 'g')
31 | applied = handle.apply(*sympy.symbols('x y'))
32 | self.assertEqual(str(applied), 'f(g(x, y))')
33 | applied = handle.apply(sympy.symbols('x'))
34 | self.assertEqual(str(applied), 'f(g(x))')
35 |
36 |
37 | class ContextTest(absltest.TestCase):
38 |
39 | def testPeel(self):
40 | sample_args = composition.SampleArgs(4, 3.0)
41 | entropy, new_sample_args = sample_args.peel()
42 | self.assertAlmostEqual(entropy, 0.75)
43 | self.assertEqual(new_sample_args.num_modules, 4)
44 | self.assertAlmostEqual(new_sample_args.entropy, 2.25)
45 |
46 | def testSplit(self):
47 | sample_args = composition.SampleArgs(4, 5.0)
48 | children = sample_args.split(2)
49 | self.assertLen(children, 2)
50 | self.assertEqual(sum([child.num_modules for child in children]), 3)
51 | self.assertAlmostEqual(sum([child.entropy for child in children]), 5.0)
52 |
53 |
54 | class EntityTest(absltest.TestCase):
55 |
56 | def testInit_valueErrorIfSelfAndHandle(self):
57 | with self.assertRaisesRegex(ValueError, 'Cannot specify handle'):
58 | composition.Entity(context=composition.Context(),
59 | value=0,
60 | description='Something with {self}. ',
61 | handle='additional')
62 |
63 |
64 | if __name__ == '__main__':
65 | absltest.main()
66 |
--------------------------------------------------------------------------------
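`display.Decimal`, in the file that follows, only accepts values whose reduced denominator has no prime factors other than 2 and 5. Those are exactly the fractions with a terminating decimal expansion, because 10 = 2 * 5, so `a / (2**i * 5**j)` can be rewritten over a power of ten. The class also formats digits itself, since Python's `str(decimal.Decimal)` switches to scientific notation for small values such as `1E-7`. Below is a stand-alone sketch of both steps using only the standard library. It replaces SymPy's `factorint` with repeated division, and `is_terminating` and `to_decimal_string` are hypothetical names.

```python
import decimal
from fractions import Fraction


def is_terminating(value):
    """True if `value` (int, Fraction or 'a/b' string) has a finite decimal."""
    denom = Fraction(value).denominator  # Fraction keeps this in lowest terms
    for prime in (2, 5):
        while denom % prime == 0:
            denom //= prime
    return denom == 1


def to_decimal_string(value):
    """Exact fixed-point string for a terminating fraction; raises otherwise."""
    value = Fraction(value)
    if not is_terminating(value):
        raise ValueError(
            'Cannot represent {} as a non-recurring decimal.'.format(value))
    # The quotient terminates, so the division is exact as long as it fits in
    # the default 28-digit context.
    quotient = (decimal.Decimal(value.numerator)
                / decimal.Decimal(value.denominator))
    # 'f' forces fixed-point notation (no '1E-7' style output).
    return format(quotient, 'f')


print(to_decimal_string(Fraction(3, 4)))       # 0.75
print(to_decimal_string(Fraction(1, 10**7)))   # 0.0000001
print(is_terminating(Fraction(1, 3)))          # False
```

Passing a Python float is deliberately not supported: `Fraction(0.1)` is the exact binary value with a denominator of `2**55`, which terminates but needs far more digits than intended.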
/mathematics_dataset/util/display.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Functionality for displaying expressions.
16 |
17 | SymPy provides a lot of functionality for displaying expressions, but it's
18 | slightly too centered on being a symbolic maths engine to provide all our
19 | needs. For example, it cannot display an unsimplified fraction like 3/6, or
20 | display a decimal without representing it internally as a float (and thus
21 | subjecting it to rounding).
22 |
23 | Also provides other conveniences, such as converting numbers to words and
24 | displaying properly formatted percentages.
25 | """
26 |
27 | from __future__ import absolute_import
28 | from __future__ import division
29 | from __future__ import print_function
30 |
31 | import decimal
32 |
33 | # Dependency imports
34 | import sympy
35 |
36 |
37 | # For converting integers to words:
38 | _INTEGER_LOW = [
39 | 'zero', 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight',
40 | 'nine', 'ten', 'eleven', 'twelve', 'thirteen', 'fourteen', 'fifteen',
41 | 'sixteen', 'seventeen', 'eighteen', 'nineteen'
42 | ]
43 | _INTEGER_MID = [
44 | '', '', 'twenty', 'thirty', 'forty', 'fifty', 'sixty', 'seventy', 'eighty',
45 | 'ninety'
46 | ]
47 | _INTEGER_HIGH = [
48 | (int(1e12), 'trillion'), (int(1e9), 'billion'), (int(1e6), 'million'),
49 | (int(1e3), 'thousand'), (100, 'hundred')
50 | ]
51 |
52 |
53 | # For converting rationals to words:
54 | _SINGULAR_DENOMINATORS = [
55 | '', '', 'half', 'third', 'quarter', 'fifth', 'sixth', 'seventh', 'eighth',
56 | 'ninth', 'tenth', 'eleventh', 'twelfth', 'thirteenth', 'fourteenth',
57 | 'fifteenth', 'sixteenth', 'seventeenth', 'eighteenth', 'nineteenth',
58 | 'twentieth'
59 | ]
60 | _PLURAL_DENOMINATORS = [
61 | '', '', 'halves', 'thirds', 'quarters', 'fifths', 'sixths', 'sevenths',
62 | 'eighths', 'ninths', 'tenths', 'elevenths', 'twelfths', 'thirteenths',
63 | 'fourteenths', 'fifteenths', 'sixteenths', 'seventeenths', 'eighteenths',
64 | 'nineteenths', 'twentieths'
65 | ]
66 |
67 |
68 | # For converting ordinals to words:
69 | _ORDINALS = [
70 | 'zeroth', 'first', 'second', 'third', 'fourth', 'fifth', 'sixth', 'seventh',
71 | 'eighth', 'ninth', 'tenth', 'eleventh', 'twelfth', 'thirteenth',
72 | 'fourteenth', 'fifteenth', 'sixteenth', 'seventeenth', 'eighteenth',
73 | 'nineteenth', 'twentieth'
74 | ]
75 |
76 |
77 | class Decimal(object):
78 | """Display a value as a decimal."""
79 |
80 | def __init__(self, value):
81 | """Initializes a `Decimal`.
82 |
83 | Args:
84 | value: (Sympy) value to display as a decimal.
85 |
86 | Raises:
87 | ValueError: If `value` cannot be represented as a terminating decimal.
88 | """
89 | self._value = sympy.Rational(value)
90 |
91 | numer = int(sympy.numer(self._value))
92 | denom = int(sympy.denom(self._value))
93 |
94 | denom_factors = list(sympy.factorint(denom).keys())
95 | for factor in denom_factors:
96 | if factor not in [2, 5]:
97 | raise ValueError('Cannot represent {} as a non-recurring decimal.'
98 | .format(value))
99 | self._decimal = decimal.Decimal(numer) / decimal.Decimal(denom)
100 |
101 | @property
102 | def value(self):
103 | """Returns the value as a `sympy.Rational` object."""
104 | return self._value
105 |
106 | def _sympy_(self):
107 | return self._value
108 |
109 | def decimal_places(self):
110 | """Returns the number of decimal places, e.g., 32 has 0 and 1.43 has 2."""
111 | if isinstance(self._decimal, int):
112 | return 0
113 | elif isinstance(self._decimal, decimal.Decimal):
114 | return -self._decimal.as_tuple().exponent
115 |
116 | def __str__(self):
117 | sign, digits, exponent = self._decimal.as_tuple()
118 | sign = '' if sign == 0 else '-'
119 |
120 | num_left_digits = len(digits) + exponent # number digits "before" point
121 |
122 | if num_left_digits > 0:
123 | int_part = ''.join(str(digit) for digit in digits[:num_left_digits])
124 | else:
125 | int_part = '0'
126 |
127 | if exponent < 0:
128 | frac_part = '.'
129 | if num_left_digits < 0:
130 | frac_part += '0' * -num_left_digits
131 | frac_part += ''.join(str(digit) for digit in digits[exponent:])
132 | else:
133 | frac_part = ''
134 |
135 | return sign + int_part + frac_part
136 |
137 | def __add__(self, other):
138 | if not isinstance(other, Decimal):
139 | raise ValueError('Arithmetic support limited to other `Decimal`s.')
140 | return Decimal(self.value + other.value)
141 |
142 | def __sub__(self, other):
143 | if not isinstance(other, Decimal):
144 | raise ValueError('Arithmetic support limited to other `Decimal`s.')
145 | return Decimal(self.value - other.value)
146 |
147 | def __mul__(self, other):
148 | if not isinstance(other, Decimal):
149 | raise ValueError('Arithmetic support limited to other `Decimal`s.')
150 | return Decimal(self.value * other.value)
151 |
152 | def __neg__(self):
153 | return Decimal(-self.value)
154 |
155 | def round(self, ndigits=0):
156 | """Returns a new `Decimal` rounded to this many decimal places."""
157 | scale = sympy.Integer(10 ** ndigits)
158 | numer = sympy.numer(self.value) * scale
159 | denom = sympy.denom(self.value)
160 | return Decimal(int(round(numer / denom)) / scale)
161 |
162 | def __round__(self, ndigits=0):
163 | return self.round(ndigits)
164 |
165 | def __int__(self):
166 | """Returns conversion to integer if possible; TypeError if non-integer."""
167 | if self.decimal_places() == 0:
168 | return int(self._decimal)
169 | else:
170 | raise TypeError('Cannot represent {} as an integer.'.format(str(self)))
171 |
172 | # NOTE: this is implemented in addition to `__cmp__` because SymPy does not
173 | # support inequality comparison between sympy objects and objects that are not
174 | # convertible to sympy objects (such as strings).
175 | def __eq__(self, other):
176 | return self.value == other
177 |
178 | # Python 2 comparison
179 | def __cmp__(self, other):
180 | if self.value == other:
181 | return 0
182 | if self.value < other:
183 | return -1
184 | return 1
185 |
186 | # Python 3 comparison:
187 | def __lt__(self, other):
188 | return self.value < other
189 |
190 | def __le__(self, other):
191 | return self.value <= other
192 |
193 | def __gt__(self, other):
194 | return self.value > other
195 |
196 | def __ge__(self, other):
197 | return self.value >= other
198 |
199 |
200 | class Percentage(object):
201 | """Container for a percentage."""
202 |
203 | def __init__(self, value):
204 | """Initializes a `Percentage`.
205 |
206 | Args:
207 | value: Percentage as a fractional value. E.g., pass in
208 | `sympy.Rational(2, 5)` to create the percentage "40%".
209 | """
210 | self._value = value
211 |
212 | def _sympy_(self):
213 | return self._value
214 |
215 | def __str__(self):
216 | # Display percentages as decimals (not fractions).
217 | value = Decimal(self._value * 100)
218 | return str(value) + '%'
219 |
220 |
221 | class NonSimpleRational(object):
222 | """Container for a rational a / b that allows gcd(a, b) > 1."""
223 |
224 | def __init__(self, numer, denom):
225 | self._numer = numer
226 | self._denom = denom
227 |
228 | @property
229 | def numer(self):
230 | return self._numer
231 |
232 | @property
233 | def denom(self):
234 | return self._denom
235 |
236 | def __str__(self):
237 | return '{}/{}'.format(self._numer, self._denom)
238 |
239 |
240 | class StringNumber(object):
241 | """A string representation of a number that can also be sympified."""
242 |
243 | def __init__(self, value, join_number_words_with_hyphens=True):
244 | """Initializes a `StringNumber`.
245 |
246 | Args:
247 | value: An integer or rational.
248 | join_number_words_with_hyphens: Whether to join the words in integers with
249 | hyphens when describing as a string.
250 | """
251 | self._join_number_words_with_hyphens = join_number_words_with_hyphens
252 | self._sympy_value = sympy.sympify(value)
253 | self._string = self._to_string(value)
254 |
255 | def _integer_to_words(self, integer):
256 | """Converts an integer to a list of words."""
257 | if integer < 0:
258 | raise ValueError('Cannot handle negative numbers.')
259 |
260 | if integer < 20:
261 | return [_INTEGER_LOW[integer]]
262 |
263 | words = None
264 |
265 | if integer < 100:
266 | tens, ones = divmod(integer, 10)
267 | if ones > 0:
268 | return [_INTEGER_MID[tens], _INTEGER_LOW[ones]]
269 | else:
270 | return [_INTEGER_MID[tens]]
271 |
272 | for value, word in _INTEGER_HIGH:
273 | if integer >= value:
274 | den, rem = divmod(integer, value)
275 | words = self._integer_to_words(den) + [word]
276 | if rem > 0:
277 | if rem < 100:
278 | words.append('and')
279 | words += self._integer_to_words(rem)
280 | return words
281 |
282 | def _rational_to_string(self, rational):
283 | """Converts a rational to words, e.g., "two thirds"."""
284 | numer = sympy.numer(rational)
285 | denom = sympy.denom(rational)
286 |
287 | numer_words = self._to_string(numer)
288 |
289 | if denom == 1:
290 | return numer_words
291 |
292 | if denom <= 0 or denom >= len(_PLURAL_DENOMINATORS):
293 | raise ValueError('Unsupported denominator {}.'.format(denom))
294 |
295 | if numer == 1:
296 | denom_word = _SINGULAR_DENOMINATORS[denom]
297 | else:
298 | denom_word = _PLURAL_DENOMINATORS[denom]
299 |
300 | return '{} {}'.format(numer_words, denom_word)
301 |
302 | def _to_string(self, number):
303 | """Converts an integer or rational to words."""
304 | if isinstance(number, sympy.Integer) or isinstance(number, int):
305 | words = self._integer_to_words(number)
306 | join_char = '-' if self._join_number_words_with_hyphens else ' '
307 | return join_char.join(words)
308 | elif isinstance(number, sympy.Rational):
309 | return self._rational_to_string(number)
310 | else:
311 | raise ValueError('Unable to handle number {} with type {}.'
312 | .format(number, type(number)))
313 |
314 | def _sympy_(self):
315 | return self._sympy_value
316 |
317 | def __str__(self):
318 | return self._string
319 |
320 |
321 | class StringOrdinal(object):
322 | """A string representation of an ordinal, e.g., "first"."""
323 |
324 | def __init__(self, position):
325 | """Initializes a `StringOrdinal`.
326 |
327 | Args:
328 | position: An integer >= 0.
329 |
330 | Raises:
331 | ValueError: If `position` is negative or out of range.
332 | """
333 | if position < 0 or position >= len(_ORDINALS):
334 | raise ValueError('Unsupported ordinal {}.'.format(position))
335 | self._string = _ORDINALS[position]
336 |
337 | def __str__(self):
338 | return self._string
339 |
340 |
341 | class NumberList(object):
342 | """Contains a list of numbers, intended for display."""
343 |
344 | def __init__(self, numbers):
345 | self._numbers = numbers
346 |
347 | def __str__(self):
348 | """Converts the list to a string.
349 |
350 | Returns:
351 | Human readable string.
352 |
353 | Raises:
354 | ValueError: if any of the strings contains a comma and thus would lead to
355 | an ambiguous representation.
356 | """
357 | strings = []
358 | for number in self._numbers:
359 | string = str(number)
360 | if ',' in string:
361 | raise ValueError('String representation of the list will be ambiguous, '
362 | 'since term "{}" contains a comma.'.format(string))
363 | strings.append(string)
364 | return ', '.join(strings)
365 |
366 |
367 | class NumberInBase(object):
368 | """Contains value, represented in a given base."""
369 |
370 | def __init__(self, value, base):
371 | """Initializes a `NumberInBase`.
372 |
373 | Args:
374 | value: Positive or negative integer.
375 | base: Integer in the range [2, 36].
376 |
377 | Raises:
378 | ValueError: If base is not in the range [2, 36] (36 is the largest base
379 | whose digits can be written with the 10 decimal digits plus 26 letters).
380 | """
381 | if not 2 <= base <= 36:
382 | raise ValueError('base={} must be in the range [2, 36]'.format(base))
383 | self._value = value
384 | self._base = base
385 |
386 | chars = []
387 | remainder = abs(value)
388 | while True:
389 | digit = remainder % base
390 | char = str(digit) if digit <= 9 else chr(ord('a') + digit - 10)
391 | chars.append(char)
392 | remainder //= base
393 | if remainder == 0:
394 | break
395 | if value < 0:
396 | chars.append('-')
397 |
398 | self._str = ''.join(reversed(chars))
399 |
400 | def __str__(self):
401 | return self._str
402 |
403 | def _sympy_(self):
404 | return self._value
405 |
--------------------------------------------------------------------------------
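`NumberInBase` builds its string by repeated division: each remainder modulo `base` gives the next least-significant digit (`0`-`9`, then `a`-`z`), and the collected digits are reversed at the end. The following is a minimal standalone sketch of the same idea; `to_base` is a hypothetical helper, not part of `display`. It uses `divmod` (integer floor division), which stays exact for arbitrarily large Python integers, whereas dividing with `/` under `from __future__ import division` goes through floating point.

```python
import string

# Digit alphabet for bases up to 36: '0'-'9' followed by 'a'-'z'.
_DIGITS = string.digits + string.ascii_lowercase


def to_base(value, base):
  """Returns the int `value` written in `base` (2..36), e.g. 255 -> 'ff'."""
  if not 2 <= base <= 36:
    raise ValueError('base={} must be in the range [2, 36]'.format(base))
  chars = []
  remainder = abs(value)
  while True:
    # Exact integer division: works for ints of any size.
    remainder, digit = divmod(remainder, base)
    chars.append(_DIGITS[digit])
    if remainder == 0:
      break
  if value < 0:
    chars.append('-')
  # Digits were produced least-significant first, so reverse them.
  return ''.join(reversed(chars))


print(to_base(255, 16))  # ff
print(to_base(-2, 2))  # -10
print(to_base(2**70 + 1, 10) == str(2**70 + 1))  # True
```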
/mathematics_dataset/util/display_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for mathematics_dataset.util.display."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | # Dependency imports
22 | from absl.testing import absltest
23 | from mathematics_dataset.util import display
24 | import sympy
25 |
26 |
27 | class DecimalTest(absltest.TestCase):
28 |
29 | def testBasic_integer(self):
30 | decimal = display.Decimal(123)
31 | self.assertEqual(str(decimal), '123')
32 | self.assertEqual(sympy.sympify(decimal), sympy.Integer(123))
33 | self.assertEqual(decimal.decimal_places(), 0)
34 |
35 | def testBasic_ten(self):
36 | decimal = display.Decimal(10)
37 | self.assertEqual(str(decimal), '10')
38 | self.assertEqual(sympy.sympify(decimal), sympy.Integer(10))
39 | self.assertEqual(decimal.decimal_places(), 0)
40 |
41 | def testBasic(self):
42 | decimal = display.Decimal(sympy.Rational(123, 100))
43 | self.assertEqual(str(decimal), '1.23')
44 | self.assertEqual(sympy.sympify(decimal), sympy.Rational(123, 100))
45 | self.assertEqual(decimal.decimal_places(), 2)
46 |
47 | def testStr(self):
48 | self.assertEqual(str(display.Decimal(sympy.Rational(0, 10))), '0')
49 | self.assertEqual(str(display.Decimal(sympy.Rational(-1, 10))), '-0.1')
50 | self.assertEqual(str(display.Decimal(sympy.Rational(-11, 10))), '-1.1')
51 | self.assertEqual(str(display.Decimal(sympy.Rational(11, 10))), '1.1')
52 | self.assertEqual(str(display.Decimal(sympy.Rational(101, 1))), '101')
53 | self.assertEqual(
54 | str(display.Decimal(sympy.Rational(20171, 1000000))), '0.020171')
55 |
56 | def testStr_verySmall(self):
57 | # Tests it doesn't display in "scientific" notation 1E-9.
58 | decimal = display.Decimal(sympy.Rational(1, 1000000000))
59 | self.assertEqual(str(decimal), '0.000000001')
60 |
61 | def testAdd(self):
62 | self.assertEqual((display.Decimal(2) + display.Decimal(3)).value, 5)
63 |
64 | def testSub(self):
65 | self.assertEqual((display.Decimal(2) - display.Decimal(3)).value, -1)
66 |
67 | def testMul(self):
68 | self.assertEqual((display.Decimal(2) * display.Decimal(3)).value, 6)
69 |
70 | def testRound(self):
71 | decimal = display.Decimal(sympy.Rational(2675, 1000)) # 2.675
72 | self.assertEqual(sympy.sympify(decimal.round()), sympy.Integer(3))
73 | self.assertEqual(sympy.sympify(decimal.round(1)), sympy.Rational(27, 10))
74 | self.assertEqual(sympy.sympify(decimal.round(2)), sympy.Rational(268, 100))
75 | self.assertEqual(sympy.sympify(decimal.round(3)),
76 | sympy.Rational(2675, 1000))
77 |
78 | def testInt(self):
79 | decimal = display.Decimal(123)
80 | self.assertEqual(int(decimal), 123)
81 |
82 | def testInt_errorIfNonInt(self):
83 | decimal = display.Decimal(sympy.Rational(1, 2))
84 | with self.assertRaisesRegex(TypeError, 'Cannot represent'):
85 | int(decimal)
86 |
87 | def testComparison(self):
88 | decimal = display.Decimal(sympy.Rational(-1, 2))
89 | # pylint: disable=g-generic-assert
90 | self.assertFalse(decimal != -0.5)
91 | self.assertTrue(decimal != 0)
92 | self.assertFalse(decimal < -0.5)
93 | self.assertTrue(decimal < 0)
94 | self.assertTrue(decimal <= -0.5)
95 | self.assertTrue(decimal <= 0)
96 | self.assertFalse(decimal > -0.5)
97 | self.assertTrue(decimal > -1)
98 | self.assertTrue(decimal >= -0.5)
99 | self.assertFalse(decimal >= 0)
100 | self.assertFalse(decimal == 0)
101 | self.assertTrue(decimal == -0.5)
102 |
103 | def testNegation(self):
104 | decimal = display.Decimal(sympy.Rational(1, 2))
105 | decimal = -decimal
106 | self.assertNotEqual(decimal, 0.5)
107 | self.assertEqual(decimal, -0.5)
108 |
109 |
110 | class PercentageTest(absltest.TestCase):
111 |
112 | def testPercentage(self):
113 | percentage = display.Percentage(1.5)
114 | self.assertEqual(str(percentage), '150%')
115 |
116 | percentage = display.Percentage(sympy.Rational(67, 100))
117 | self.assertEqual(str(percentage), '67%')
118 |
119 | percentage = display.Percentage(sympy.Rational(67, 1000))
120 | self.assertEqual(str(percentage), '6.7%')
121 |
122 |
123 | class NonSimpleRationalTest(absltest.TestCase):
124 |
125 | def testBasic(self):
126 | frac = display.NonSimpleRational(4, 6)
127 | self.assertEqual(frac.numer, 4)
128 | self.assertEqual(frac.denom, 6)
129 | self.assertEqual(str(frac), '4/6')
130 |
131 |
132 | class StringNumberTest(absltest.TestCase):
133 |
134 | def testIntegerToWords(self):
135 | words = display.StringNumber(0)
136 | self.assertEqual(str(words), 'zero')
137 | self.assertEqual(sympy.sympify(words), 0)
138 |
139 | words = display.StringNumber(8)
140 | self.assertEqual(str(words), 'eight')
141 | self.assertEqual(sympy.sympify(words), 8)
142 |
143 | words = display.StringNumber(12)
144 | self.assertEqual(str(words), 'twelve')
145 | self.assertEqual(sympy.sympify(words), 12)
146 |
147 | words = display.StringNumber(30)
148 | self.assertEqual(str(words), 'thirty')
149 | self.assertEqual(sympy.sympify(words), 30)
150 |
151 | words = display.StringNumber(100)
152 | self.assertEqual(str(words), 'one-hundred')
153 | self.assertEqual(sympy.sympify(words), 100)
154 |
155 | words = display.StringNumber(103)
156 | self.assertEqual(str(words), 'one-hundred-and-three')
157 | self.assertEqual(sympy.sympify(words), 103)
158 |
159 | words = display.StringNumber(15439822)
160 | self.assertEqual(str(words), 'fifteen-million-four-hundred-and-thirty-nine'
161 | '-thousand-eight-hundred-and-twenty-two')
162 | self.assertEqual(sympy.sympify(words), 15439822)
163 |
164 | def testRationalToWords(self):
165 | words = display.StringNumber(sympy.Rational(2, 3))
166 | self.assertEqual(str(words), 'two thirds')
167 |
168 |
169 | class StringOrdinalTest(absltest.TestCase):
170 |
171 | def testBasic(self):
172 | ordinal = display.StringOrdinal(0)
173 | self.assertEqual(str(ordinal), 'zeroth')
174 | ordinal = display.StringOrdinal(10)
175 | self.assertEqual(str(ordinal), 'tenth')
176 |
177 | def testCreate_errorIfNegative(self):
178 | with self.assertRaisesRegex(ValueError, 'Unsupported ordinal'):
179 | display.StringOrdinal(-1)
180 |
181 |
182 | class NumberListTest(absltest.TestCase):
183 |
184 | def testBasic(self):
185 | numbers = [2, 3, 1]
186 | number_list = display.NumberList(numbers)
187 | string = str(number_list)
188 | self.assertEqual(string, '2, 3, 1')
189 |
190 |
191 | class NumberInBaseTest(absltest.TestCase):
192 |
193 | def testBasic(self):
194 | self.assertEqual(str(display.NumberInBase(1, 10)), '1')
195 | self.assertEqual(str(display.NumberInBase(-1, 10)), '-1')
196 | self.assertEqual(str(display.NumberInBase(1, 2)), '1')
197 | self.assertEqual(str(display.NumberInBase(-1, 2)), '-1')
198 | self.assertEqual(str(display.NumberInBase(2, 2)), '10')
199 | self.assertEqual(str(display.NumberInBase(-2, 2)), '-10')
200 | self.assertEqual(str(display.NumberInBase(10, 16)), 'a')
201 | self.assertEqual(str(display.NumberInBase(16, 16)), '10')
202 | self.assertEqual(str(display.NumberInBase(256, 16)), '100')
203 | self.assertEqual(str(display.NumberInBase(-75483, 10)), '-75483')
204 |
205 |
206 | if __name__ == '__main__':
207 | absltest.main()
208 |
--------------------------------------------------------------------------------
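The expected strings in `StringNumberTest` (for example `'one-hundred-and-three'`) come from the recursion in `StringNumber._integer_to_words`: numbers below 100 are looked up directly as one or two words, larger numbers are split by the largest applicable scale word, and `and` is inserted before a remainder below 100. Below is a standalone sketch of that recursion. The word tables are hypothetical stand-ins, since `_INTEGER_LOW`, `_INTEGER_MID`, and `_INTEGER_HIGH` are defined outside this excerpt; in particular, the set of scale words (up to `billion`) is an assumption.

```python
# Hypothetical stand-ins for display.py's _INTEGER_LOW / _INTEGER_MID /
# _INTEGER_HIGH tables, which are not shown in this excerpt.
_LOW = ['zero', 'one', 'two', 'three', 'four', 'five', 'six', 'seven',
        'eight', 'nine', 'ten', 'eleven', 'twelve', 'thirteen', 'fourteen',
        'fifteen', 'sixteen', 'seventeen', 'eighteen', 'nineteen']
_MID = ['', '', 'twenty', 'thirty', 'forty', 'fifty', 'sixty', 'seventy',
        'eighty', 'ninety']
# Largest scale first, so the first match is the biggest applicable unit.
_HIGH = [(10**9, 'billion'), (10**6, 'million'), (1000, 'thousand'),
         (100, 'hundred')]


def integer_to_words(integer):
  """Returns a list of words for a non-negative integer."""
  if integer < 0:
    raise ValueError('Cannot handle negative numbers.')
  if integer < 20:
    return [_LOW[integer]]
  if integer < 100:
    tens, ones = divmod(integer, 10)
    return [_MID[tens]] + ([_LOW[ones]] if ones else [])
  for value, word in _HIGH:
    if integer >= value:
      count, rem = divmod(integer, value)
      words = integer_to_words(count) + [word]
      if rem > 0:
        if rem < 100:
          words.append('and')  # e.g. "one hundred and three"
        words += integer_to_words(rem)
      return words


print('-'.join(integer_to_words(103)))  # one-hundred-and-three
print('-'.join(integer_to_words(15439822)))
# fifteen-million-four-hundred-and-thirty-nine-thousand-eight-hundred-and-twenty-two
```

Joining with `'-'` mirrors the default `join_number_words_with_hyphens=True`.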
/mathematics_dataset/util/probability.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Functionality for working with probability spaces and random variables.
16 |
17 | Basic recap of probability theory, and thus of classes in this file:
18 |
19 | * A probability space is a (finite or infinite) set Omega with a probability
20 | measure defined on this.
21 | * A random variable is a mapping from a probability space to another measure
22 | space.
23 | * An event is a measurable set in a sample space.
24 |
25 | For example, suppose a bag contains 3 balls: two red balls and one white ball.
26 | This could be represented by a discrete probability space of size 3 with
27 | elements {1, 2, 3}, with equal measure assigned to all 3 elements; and a random
28 | variable that maps 1->red, 2->red, and 3->white. Then the probability of drawing
29 | a red ball is the measure in the probability space of the inverse under the
30 | random variable mapping of {red}, i.e., of {1, 2}, which is 2/3.
31 | """
32 |
33 | from __future__ import absolute_import
34 | from __future__ import division
35 | from __future__ import print_function
36 |
37 | import abc
38 | import itertools
39 |
40 | # Dependency imports
41 | import six
42 | from six.moves import zip
43 | import sympy
44 |
45 |
46 | @six.add_metaclass(abc.ABCMeta)
47 | class Event(object):
48 | """Represents an event in a measure space."""
49 |
50 |
51 | @six.add_metaclass(abc.ABCMeta)
52 | class ProbabilitySpace(object):
53 | """Represents a probability space."""
54 |
55 | @abc.abstractmethod
56 | def probability(self, event):
57 | """Returns the probability of an event."""
58 |
59 |
60 | @six.add_metaclass(abc.ABCMeta)
61 | class RandomVariable(object):
62 | """Random variable; a mapping from a probability space to a measure space."""
63 |
64 | @abc.abstractmethod
65 | def __call__(self, event):
66 | """Maps an `Event` in the probability space to one in the sample space."""
67 |
68 | @abc.abstractmethod
69 | def inverse(self, event):
70 | """Maps an event in the sample space back to its preimage in the prob. space."""
71 |
72 |
73 | class DiscreteEvent(Event):
74 | """Set of discrete values."""
75 |
76 | def __init__(self, values):
77 | self._values = values
78 |
79 | @property
80 | def values(self):
81 | return self._values
82 |
83 |
84 | class FiniteProductEvent(Event):
85 | """Event consisting of the Cartesian product of events."""
86 |
87 | def __init__(self, events):
88 | """Initializes a `FiniteProductEvent`.
89 |
90 | Args:
91 | events: Tuple of `Event`s; the resulting event is the Cartesian product of
92 | these.
93 | """
94 | self._events = events
95 |
96 | @property
97 | def events(self):
98 | return self._events
99 |
100 | def all_sequences(self):
101 | """Returns an iterator over sequences formed by picking one value per coordinate.
102 |
103 | This assumes that every component event is an instance of `DiscreteEvent`.
104 |
105 | Returns:
106 | Iterator over tuples of values.
107 |
108 | Raises:
109 | ValueError: If one of the component events is not a `DiscreteEvent`.
110 | """
111 | if not all(isinstance(event, DiscreteEvent) for event in self._events):
112 | raise ValueError('Not all component events are DiscreteEvents')
113 | values_list = [event.values for event in self._events]
114 | return itertools.product(*values_list)
115 |
116 |
117 | class CountLevelSetEvent(Event):
118 | """Event of all sequences with fixed number of different values occurring."""
119 |
120 | def __init__(self, counts):
121 | """Initializes `CountLevelSetEvent`.
122 |
123 | E.g., to construct the event of getting two red balls and one green ball,
124 | pass `counts = {red: 2, green: 1}`. (Then `all_sequences()` would return
125 | `[(red, red, green), (red, green, red), (green, red, red)]`.)
126 |
127 | Args:
128 | counts: Dictionary mapping values to the number of times they occur in a
129 | sequence.
130 | """
131 | self._counts = counts
132 | self._all_sequences = None
133 |
134 | @property
135 | def counts(self):
136 | return self._counts
137 |
138 | def all_sequences(self):
139 | """Returns all sequences generated by this level set."""
140 | if self._all_sequences is None:
141 | # Generate via dynamic programming.
142 | cache = {} # dict mapping tuple -> list of tuples
143 | labels = list(self._counts.keys())
144 |
145 | def generate(counts):
146 | """Returns list of tuples for given `counts` of labels."""
147 | if sum(counts) == 0:
148 | return [()]
149 | counts = tuple(counts)
150 | if counts in cache:
151 | return cache[counts]
152 | generated = []
153 | for i, count in enumerate(counts):
154 | if count == 0:
155 | continue
156 | counts_minus = list(counts)
157 | counts_minus[i] -= 1
158 | counts_minus = tuple(counts_minus)
159 | extensions = generate(counts_minus)
160 | generated += [tuple([labels[i]] + list(extension))
161 | for extension in extensions]
162 | cache[counts] = generated
163 | return generated
164 |
165 | self._all_sequences = generate(list(self._counts.values()))
166 |
167 | return self._all_sequences
168 |
169 |
170 | class SequenceEvent(Event):
171 | """Collection of sequences."""
172 |
173 | def __init__(self, sequences):
174 | self._sequences = sequences
175 |
176 | def all_sequences(self):
177 | return self._sequences
178 |
179 |
180 | def normalize_weights(weights):
181 | """Normalizes the weights (as sympy.Rational) in dictionary of weights."""
182 | weight_sum = sum(six.itervalues(weights))
183 | return {
184 | i: sympy.Rational(weight, weight_sum)
185 | for i, weight in six.iteritems(weights)
186 | }
187 |
188 |
189 | class DiscreteProbabilitySpace(ProbabilitySpace):
190 | """Discrete probability space."""
191 |
192 | def __init__(self, weights=None):
193 | """Initializes a `DiscreteProbabilitySpace`.
194 |
195 | Args:
196 | weights: Dictionary mapping values to relative probability of selecting
197 | that value. This will be normalized.
198 | """
199 | self._weights = normalize_weights(weights)
200 |
201 | def probability(self, event):
202 | if isinstance(event, DiscreteEvent):
203 | return sum(self._weights[value]
204 | for value in event.values if value in self._weights)
205 | else:
206 | raise ValueError('Unhandled event type {}'.format(type(event)))
207 |
208 | @property
209 | def weights(self):
210 | """Returns dictionary of probability of each element."""
211 | return self._weights
212 |
213 |
214 | class FiniteProductSpace(ProbabilitySpace):
215 | """Finite Cartesian product of probability spaces."""
216 |
217 | def __init__(self, spaces):
218 | """Initializes a `FiniteProductSpace`.
219 |
220 | Args:
221 | spaces: List of `ProbabilitySpace`.
222 | """
223 | self._spaces = spaces
224 |
225 | def all_spaces_equal(self):
226 | return all([self._spaces[0] == space for space in self._spaces])
227 |
228 | def probability(self, event):
229 | # Specializations for optimization.
230 | if isinstance(event, FiniteProductEvent):
231 | assert len(self._spaces) == len(event.events)
232 | return sympy.prod([
233 | space.probability(event_slice)
234 | for space, event_slice in zip(self._spaces, event.events)])
235 |
236 | if isinstance(event, CountLevelSetEvent) and self.all_spaces_equal():
237 | space = self._spaces[0]
238 | counts = event.counts
239 | probabilities = {
240 | value: space.probability(DiscreteEvent({value}))
241 | for value in six.iterkeys(counts)
242 | }
243 |
244 | num_events = sum(six.itervalues(counts))
245 | assert num_events == len(self._spaces)
246 | # Multinomial coefficient:
247 | coeff = (
248 | sympy.factorial(num_events) / sympy.prod(
249 | [sympy.factorial(i) for i in six.itervalues(counts)]))
250 | return coeff * sympy.prod([
251 | pow(probabilities[value], counts[value])
252 | for value in six.iterkeys(counts)
253 | ])
254 |
255 | raise ValueError('Unhandled event type {}'.format(type(event)))
256 |
257 | @property
258 | def spaces(self):
259 | """Returns list of spaces."""
260 | return self._spaces
261 |
262 |
263 | class SampleWithoutReplacementSpace(ProbabilitySpace):
264 | """Probability space formed by sampling discrete space without replacement."""
265 |
266 | def __init__(self, weights, n_samples):
267 | """Initializes a `SampleWithoutReplacementSpace`.
268 |
269 | Args:
270 | weights: Dictionary mapping values to relative probability of selecting
271 | that value. This will be normalized.
272 | n_samples: Number of samples to draw.
273 |
274 | Raises:
275 | ValueError: If `n_samples > len(weights)`.
276 | """
277 | if n_samples > len(weights):
278 | raise ValueError('n_samples is more than number of discrete elements')
279 | self._weights = normalize_weights(weights)
280 | self._n_samples = n_samples
281 |
282 | @property
283 | def n_samples(self):
284 | """Number of samples to draw."""
285 | return self._n_samples
286 |
287 | def probability(self, event):
288 | try:
289 | all_sequences = event.all_sequences()
290 | except AttributeError:
291 | raise ValueError('Unhandled event type {}'.format(type(event)))
292 |
293 | probability_sum = 0
294 | for sequence in all_sequences:
295 | if len(sequence) != len(set(sequence)):
296 | continue # not all unique, so not "without replacement".
297 | p_sequence = 1
298 | removed_prob = 0
299 | for i in sequence:
300 | p = self._weights[i] if i in self._weights else 0
301 | if p == 0:
302 | p_sequence = 0
303 | break
304 | p_sequence *= p / (1 - removed_prob)
305 | removed_prob += p
306 | probability_sum += p_sequence
307 | return probability_sum
308 |
309 |
310 | class IdentityRandomVariable(RandomVariable):
311 | """Identity map of a probability space."""
312 |
313 | def __call__(self, event):
314 | return event
315 |
316 | def inverse(self, event):
317 | return event
318 |
319 |
320 | class DiscreteRandomVariable(RandomVariable):
321 | """Specialization to discrete random variable.
322 |
323 | This is simply a mapping from a discrete space to a discrete space (dictionary
324 | lookup).
325 | """
326 |
327 | def __init__(self, mapping):
328 | """Initializes `DiscreteRandomVariable` from `mapping` dict."""
329 | self._mapping = mapping
330 | self._inverse = {}
331 | for key, value in six.iteritems(mapping):
332 | if value in self._inverse:
333 | self._inverse[value].add(key)
334 | else:
335 | self._inverse[value] = set([key])
336 |
337 | def __call__(self, event):
338 | if isinstance(event, DiscreteEvent):
339 | return DiscreteEvent({self._mapping[value] for value in event.values})
340 | else:
341 | raise ValueError('Unhandled event type {}'.format(type(event)))
342 |
343 | def inverse(self, event):
344 | if isinstance(event, DiscreteEvent):
345 | set_ = set()
346 | for value in event.values:
347 | if value in self._inverse:
348 | set_.update(self._inverse[value])
349 | return DiscreteEvent(set_)
350 | else:
351 | raise ValueError('Unhandled event type {}'.format(type(event)))
352 |
353 |
354 | class FiniteProductRandomVariable(RandomVariable):
355 | """Product random variable.
356 |
357 | This has the following semantics. Let this be X = (X_1, ..., X_n). Then
358 |
359 | X(w) = (X_1(w_1), ..., X_n(w_n))
360 |
361 | (the sample space is assumed to be of sequence type).
362 | """
363 |
364 | def __init__(self, random_variables):
365 | """Initializes a `FiniteProductRandomVariable`.
366 |
367 | Args:
368 | random_variables: Tuple of `RandomVariable`.
369 | """
370 | self._random_variables = random_variables
371 |
372 | def __call__(self, event):
373 | if isinstance(event, FiniteProductEvent):
374 | assert len(event.events) == len(self._random_variables)
375 | zipped = list(zip(self._random_variables, event.events))
376 | return FiniteProductEvent(
377 | [random_variable(sub_event)
378 | for random_variable, sub_event in zipped])
379 | else:
380 | raise ValueError('Unhandled event type {}'.format(type(event)))
381 |
382 | def inverse(self, event):
383 | # Specialization for `FiniteProductEvent`; don't need to take all sequences.
384 | if isinstance(event, FiniteProductEvent):
385 | assert len(event.events) == len(self._random_variables)
386 | zipped = list(zip(self._random_variables, event.events))
387 | return FiniteProductEvent(tuple(
388 | random_variable.inverse(sub_event)
389 | for random_variable, sub_event in zipped))
390 |
391 | # Try fallback of mapping each sequence separately.
392 | try:
393 | all_sequences = event.all_sequences()
394 | except AttributeError:
395 | raise ValueError('Unhandled event type {}'.format(type(event)))
396 |
397 | mapped = set()
398 | for sequence in all_sequences:
399 | assert len(sequence) == len(self._random_variables)
400 | zipped = list(zip(self._random_variables, sequence))
401 | mapped_sequence = FiniteProductEvent(tuple(
402 | random_variable.inverse(DiscreteEvent({element}))
403 | for random_variable, element in zipped))
404 | mapped.update(mapped_sequence.all_sequences())
405 | return SequenceEvent(mapped)
406 |
--------------------------------------------------------------------------------
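For each sequence in an event, `SampleWithoutReplacementSpace.probability` multiplies the probability of each drawn item, renormalized by the probability mass not yet removed, and then sums over the sequences. The sketch below reproduces that computation standalone; it swaps SymPy rationals for `fractions.Fraction` so it has no dependencies, the function name `sequence_probability` is hypothetical, and weights are assumed to be integers.

```python
from fractions import Fraction
from itertools import permutations


def sequence_probability(weights, sequence):
  """P(drawing `sequence` in this order) when sampling without replacement.

  `weights` maps items to (unnormalized) integer weights, like the `weights`
  argument of `SampleWithoutReplacementSpace`.
  """
  total = sum(weights.values())
  probs = {item: Fraction(w, total) for item, w in weights.items()}
  if len(sequence) != len(set(sequence)):
    return Fraction(0)  # An item cannot be drawn twice.
  p_sequence = Fraction(1)
  removed = Fraction(0)
  for item in sequence:
    p = probs.get(item, Fraction(0))
    if p == 0:
      return Fraction(0)
    # Condition on the items already drawn: renormalize by remaining mass.
    p_sequence *= p / (1 - removed)
    removed += p
  return p_sequence


weights = {'a': 2, 'b': 3, 'c': 5}
print(sequence_probability(weights, ('c', 'a')))  # 1/5
# An event is a set of sequences; its probability is the sum over them.
print(sum(sequence_probability(weights, s) for s in permutations('ac')))  # 13/40
```

The first check in the loop mirrors the `len(sequence) != len(set(sequence))` skip in the original method, and the zero-probability early exit also avoids dividing by zero once all mass has been removed.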
/mathematics_dataset/util/probability_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for mathematics_dataset.util.probability."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | # Dependency imports
22 | from absl.testing import absltest
23 | from mathematics_dataset.util import probability
24 | import sympy
25 |
26 |
27 | class FiniteProductEventTest(absltest.TestCase):
28 |
29 | def testAllSequences(self):
30 | event = probability.FiniteProductEvent([probability.DiscreteEvent({1, 2}),
31 | probability.DiscreteEvent({3})])
32 | all_sequences = [i for i in event.all_sequences()]
33 | self.assertEqual(all_sequences, [(1, 3), (2, 3)])
34 |
35 |
36 | class CountLevelSetEventTest(absltest.TestCase):
37 |
38 | def testAllSequences(self):
39 | event = probability.CountLevelSetEvent({'a': 2, 'b': 4, 'c': 1})
40 | all_sequences = event.all_sequences()
41 |
42 | # Number of sequences should be 7! / (4! * 2! * 1!) = 105.
43 | self.assertLen(all_sequences, 105)
44 | # They should all be unique.
45 | self.assertEqual(len(all_sequences), len(set(all_sequences)))
46 | # And check contains one correctly generated tuple.
47 | self.assertIn(('a', 'b', 'c', 'b', 'b', 'a', 'b'), all_sequences)
48 |
49 |
50 | class DiscreteProbabilitySpaceTest(absltest.TestCase):
51 |
52 | def testBasic(self):
53 | space = probability.DiscreteProbabilitySpace({0: 1, 1: 2, 2: 3})
54 | p = space.probability(probability.DiscreteEvent([0]))
55 | self.assertEqual(p, sympy.Rational(1, 6))
56 | p = space.probability(probability.DiscreteEvent([0, 1]))
57 | self.assertEqual(p, sympy.Rational(1, 2))
58 | p = space.probability(probability.DiscreteEvent([0, 1, 2]))
59 | self.assertEqual(p, 1)
60 | p = space.probability(probability.DiscreteEvent([0, 1, 2, 3]))
61 | self.assertEqual(p, 1)
62 | p = space.probability(probability.DiscreteEvent([3]))
63 | self.assertEqual(p, 0)
64 |
65 |
66 | class FiniteProductSpaceTest(absltest.TestCase):
67 |
68 | def testProbability_FiniteProductEvent(self):
69 | # 5 coin flips of a biased coin with heads prob = 1/3.
70 | base_space = probability.DiscreteProbabilitySpace({'h': 1, 't': 2})
71 | space = probability.FiniteProductSpace([base_space] * 5)
72 |
73 | heads = probability.DiscreteEvent({'h'})
74 | tails = probability.DiscreteEvent({'t'})
75 | event = probability.FiniteProductEvent([heads, heads, tails, tails, heads])
76 | self.assertEqual(space.probability(event), sympy.Rational(4, 3**5))
77 |
78 | def testProbability_CountLevelSetEvent(self):
79 | base_space = probability.DiscreteProbabilitySpace({'a': 2, 'b': 3, 'c': 5})
80 | space = probability.FiniteProductSpace([base_space] * 12)
81 | event = probability.CountLevelSetEvent({'a': 7, 'b': 2, 'c': 3})
82 |
83 | # Probability should be (12 choose 7 2 3) * p(a)^7 p(b)^2 p(c)^3
84 | coeff = 7920
85 | p_a = sympy.Rational(1, 5)
86 | p_b = sympy.Rational(3, 10)
87 | p_c = sympy.Rational(1, 2)
88 | self.assertEqual(space.probability(event),
89 | coeff * pow(p_a, 7) * pow(p_b, 2) * pow(p_c, 3))
90 |
91 |
92 | class SampleWithoutReplacementSpaceTest(absltest.TestCase):
93 |
94 | def testBasic(self):
95 | space = probability.SampleWithoutReplacementSpace({0: 1, 1: 1}, 2)
96 | event_0_0 = probability.FiniteProductEvent(
97 | [probability.DiscreteEvent({0}), probability.DiscreteEvent({0})])
98 | event_0_1 = probability.FiniteProductEvent(
99 | [probability.DiscreteEvent({0}), probability.DiscreteEvent({1})])
100 | p_0_0 = space.probability(event_0_0)
101 | p_0_1 = space.probability(event_0_1)
102 | self.assertEqual(p_0_0, 0)
103 | self.assertEqual(p_0_1, sympy.Rational(1, 2))
104 |
105 | space = probability.SampleWithoutReplacementSpace({0: 1, 1: 0}, 1)
106 | event_0 = probability.FiniteProductEvent([probability.DiscreteEvent({0})])
107 | event_1 = probability.FiniteProductEvent([probability.DiscreteEvent({1})])
108 | event_2 = probability.FiniteProductEvent([probability.DiscreteEvent({2})])
109 | p_0 = space.probability(event_0)
110 | p_1 = space.probability(event_1)
111 | p_2 = space.probability(event_2)
112 | self.assertEqual(p_0, 1)
113 | self.assertEqual(p_1, 0)
114 | self.assertEqual(p_2, 0)
115 |
116 |
117 | class DiscreteRandomVariableTest(absltest.TestCase):
118 |
119 | def testCall(self):
120 | random_variable = probability.DiscreteRandomVariable({1: 1, 2: 3, 3: 4})
121 | forwards = random_variable(probability.DiscreteEvent({1, 3}))
122 | self.assertEqual(forwards.values, {1, 4})
123 |
124 | def testInverse(self):
125 | random_variable = probability.DiscreteRandomVariable({1: 1, 2: 3, 3: 4})
126 | inverse = random_variable.inverse(probability.DiscreteEvent({1, 3}))
127 | self.assertEqual(inverse.values, {1, 2})
128 |
129 | random_variable = probability.DiscreteRandomVariable({1: 1, 2: 1})
130 | inverse = random_variable.inverse(probability.DiscreteEvent({1, 5}))
131 | self.assertEqual(inverse.values, {1, 2})
132 |
133 |
134 | class FiniteProductRandomVariableTest(absltest.TestCase):
135 |
136 | def _random_variable(self):
137 | rv1 = probability.DiscreteRandomVariable({1: 'a', 2: 'b', 3: 'c'})
138 | rv2 = probability.DiscreteRandomVariable({1: 'x', 2: 'y', 3: 'x'})
139 | return probability.FiniteProductRandomVariable((rv1, rv2))
140 |
141 | def testCall_FiniteProductEvent(self):
142 | rv = self._random_variable()
143 | event1 = probability.DiscreteEvent({1, 2})
144 | event2 = probability.DiscreteEvent({1, 3})
145 | event = probability.FiniteProductEvent((event1, event2))
146 | result = rv(event)
147 | self.assertIsInstance(result, probability.FiniteProductEvent)
148 | self.assertLen(result.events, 2)
149 | self.assertEqual(result.events[0].values, {'a', 'b'})
150 | self.assertEqual(result.events[1].values, {'x'})
151 |
152 | def testInverse_FiniteProductEvent(self):
153 | rv = self._random_variable()
154 | event1 = probability.DiscreteEvent({'a', 'b'})
155 | event2 = probability.DiscreteEvent({'x'})
156 | event = probability.FiniteProductEvent((event1, event2))
157 | result = rv.inverse(event)
158 | self.assertIsInstance(result, probability.FiniteProductEvent)
159 | self.assertLen(result.events, 2)
160 | self.assertEqual(result.events[0].values, {1, 2})
161 | self.assertEqual(result.events[1].values, {1, 3})
162 |
163 | def testInverse_CountLevelSetEvent(self):
164 | rv = self._random_variable()
165 | event = probability.CountLevelSetEvent({'a': 1, 'x': 1})
166 | result = rv.inverse(event)
167 | sequences = result.all_sequences()
168 | self.assertLen(sequences, 2)
169 | self.assertEqual(set(sequences), {(1, 1), (1, 3)})
170 |
171 |
172 | if __name__ == '__main__':
173 | absltest.main()
174 |
--------------------------------------------------------------------------------
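The expected values hard-coded in the tests above, such as the multinomial coefficient 7920 in `testProbability_CountLevelSetEvent` or the `{(1, 1), (1, 3)}` preimage in `testInverse_CountLevelSetEvent`, can be re-derived by brute force. The sketch below is a hypothetical, stdlib-only re-derivation that uses `fractions.Fraction` in place of `sympy.Rational`. Its function names (`discrete_probability`, `count_level_set_probability`, `sample_without_replacement_probability`, `preimage`, and so on) are illustrative only and are not the API of `mathematics_dataset.util.probability`, whose implementation is not shown here. It assumes, as the tests suggest, that space weights are unnormalized and that outcomes absent from a space have probability zero.

```python
from collections import Counter
from fractions import Fraction
from itertools import permutations, product
from math import factorial


def discrete_probability(weights, event):
    """P(event) for a finite space whose outcomes carry unnormalized weights."""
    total = sum(weights.values())
    # Assumption: outcomes missing from `weights` contribute probability zero.
    return Fraction(sum(weights.get(value, 0) for value in event), total)


def product_event_probability(weights, events):
    """P(X_1 in E_1, ..., X_n in E_n) for n independent draws from one space."""
    p = Fraction(1)
    for event in events:
        p *= discrete_probability(weights, event)
    return p


def multinomial_coefficient(counts):
    """n! / (c_1! c_2! ... c_k!) where n = c_1 + ... + c_k."""
    result = factorial(sum(counts))
    for c in counts:
        result //= factorial(c)  # exact: every partial quotient is an integer
    return result


def count_level_set_probability(weights, counts):
    """P(each value v occurs exactly counts[v] times) in sum(counts) i.i.d. draws."""
    p = Fraction(multinomial_coefficient(counts.values()))
    for value, c in counts.items():
        p *= discrete_probability(weights, {value}) ** c
    return p


def sample_without_replacement_probability(counts, events):
    """P(the i-th draw lies in events[i]) when drawing len(events) items
    without replacement from an urn holding counts[v] copies of each value v.

    Assumes len(events) does not exceed the number of items in the urn.
    """
    urn = [value for value, c in counts.items() for _ in range(c)]
    # Every ordered selection of distinct urn positions is equally likely.
    orderings = list(permutations(range(len(urn)), len(events)))
    hits = sum(
        all(urn[i] in event for i, event in zip(ordering, events))
        for ordering in orderings)
    return Fraction(hits, len(orderings))


def image(mapping, event):
    """Forward image {f(x) : x in event} of a random variable given as a dict."""
    # Assumption of this sketch: inputs outside the domain are ignored.
    return {mapping[x] for x in event if x in mapping}


def preimage(mapping, event):
    """Inverse image {x : f(x) in event}."""
    return {x for x, y in mapping.items() if y in event}


def product_preimage_sequences(mappings, counts):
    """Input tuples (one coordinate per mapping) whose images contain each
    value v exactly counts[v] times."""
    target = Counter(counts)
    return {
        inputs
        for inputs in product(*(m.keys() for m in mappings))
        if Counter(m[x] for m, x in zip(mappings, inputs)) == target
    }
```

For instance, `count_level_set_probability({'a': 2, 'b': 3, 'c': 5}, {'a': 7, 'b': 2, 'c': 3})` returns `Fraction(891, 781250)`, the same value as the `coeff * pow(p_a, 7) * pow(p_b, 2) * pow(p_c, 3)` expression asserted in the test. Because `sample_without_replacement_probability` enumerates every ordered draw, its cost grows factorially with the urn size, so it only suits tiny urns like the ones used in these tests.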
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Module setuptools script."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | from setuptools import find_packages
22 | from setuptools import setup
23 |
24 | description = """A synthetic dataset of school-level mathematics questions.
25 |
26 | This dataset code generates mathematical question and answer pairs, from a range
27 | of question types (such as in arithmetic, algebra, probability, etc.), at roughly
28 | school-level difficulty. This is designed to test the mathematical learning and
29 | reasoning skills of learning models.
30 |
31 | Original paper: Analysing Mathematical Reasoning Abilities of Neural Models
32 | (Saxton, Grefenstette, Hill, Kohli) (https://openreview.net/pdf?id=H1gR5iR5FX).
33 | """
34 |
35 | setup(
36 | name='mathematics_dataset',
37 | version='1.0.1',
38 | description='A synthetic dataset of school-level mathematics questions',
39 | long_description=description,
40 | author='DeepMind',
41 | author_email='saxton@google.com',
42 | license='Apache License, Version 2.0',
43 | keywords='mathematics dataset',
44 | url='https://github.com/deepmind/mathematics_dataset',
45 | packages=find_packages(),
46 | install_requires=[
47 | 'absl-py>=0.1.0',
48 | 'numpy>=1.10',
49 | 'six',
50 | 'sympy>=1.2',
51 | ],
52 | classifiers=[
53 | 'Development Status :: 4 - Beta',
54 | 'Environment :: Console',
55 | 'Intended Audience :: Developers',
56 | 'Intended Audience :: Science/Research',
57 | 'License :: OSI Approved :: Apache Software License',
58 | 'Operating System :: POSIX :: Linux',
59 | 'Programming Language :: Python :: 2.7',
60 | 'Programming Language :: Python :: 3.4',
61 | 'Programming Language :: Python :: 3.5',
62 | 'Programming Language :: Python :: 3.6',
63 | 'Programming Language :: Python :: 3.7',
64 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
65 | ],
66 | )
67 |
--------------------------------------------------------------------------------