├── .gitignore ├── README.md ├── READ_ME_assets └── graph_construction.gif ├── math_prog_synth_env.egg-info ├── PKG-INFO ├── SOURCES.txt ├── dependency_links.txt ├── requires.txt └── top_level.txt ├── math_prog_synth_env ├── __init__.py ├── compute_graph.py ├── envs │ ├── __init__.py │ └── math_env.py ├── setup.py ├── tokenization │ ├── question_corpus_mlp.txt │ └── tokenizer.model ├── typed_operators │ └── __init__.py ├── unit_testing │ ├── __init__.py │ ├── artifacts │ │ ├── extract_formal_elements_examples.json │ │ ├── extract_formal_elements_examples.py │ │ ├── params.yaml │ │ └── problems │ │ │ ├── algebra__linear_1d.txt │ │ │ ├── algebra__linear_2d.txt │ │ │ ├── algebra__polynomial_roots.txt │ │ │ ├── numbers__div_remainder.txt │ │ │ ├── numbers__gcd.txt │ │ │ ├── numbers__is_factor.txt │ │ │ ├── numbers__is_prime.txt │ │ │ ├── numbers__lcm.txt │ │ │ ├── numbers__list_prime_factors.txt │ │ │ └── polynomials__evaluate.txt │ ├── extract_problems_for_guessing_test.py │ ├── test_compute_graph.py │ ├── test_extract_formal_elements.py │ ├── test_graphs.py │ ├── test_gym_environment.py │ ├── test_gym_environment_guessing.py │ ├── test_operators.py │ └── test_utils.py └── utils.py ├── params.yaml └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.ipynb 2 | __pycache__ 3 | mathematics_dataset-v1.0/* 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # math_prog_synth_env 2 | 3 | This repository contains an implementation of math_prog_synth_env as described in https://arxiv.org/abs/2107.07373. 4 | 5 | ![Graph construction video](https://github.com/JohnnyYeeee/math_prog_synth_env/blob/main/READ_ME_assets/graph_construction.gif?raw=true) 6 | 7 | The full code used to produce the results reported in the paper can be found here: https://github.com/joepalermo/dm_math_solvers 8 | 9 | ## Setup: 10 | 11 | ``` bash 12 | git clone https://github.com/JohnnyYeeee/math_prog_synth_env.git 13 | # optionally create and activate a new environment 14 | conda create -n math_prog_synth_env -y python=3.7 15 | conda activate math_prog_synth_env 16 | # install dependencies 17 | pip install -e math_prog_synth_env 18 | ``` 19 | 20 | ```python 21 | import gym 22 | # the first time running this may take awhile (particularly to download the data) 23 | env = gym.make('math_prog_synth_env:math-env-v0', config_file='params.yaml') 24 | ``` 25 | 26 | Before running the environment several pre-requisites need to be completed: 27 | 28 | - The raw data (https://storage.googleapis.com/mathematics-dataset/mathematics_dataset-v1.0.tar.gz) needs to be downloaded 29 | - The data needs to be split into train/val/test sets 30 | - A tokenizer needs to be created 31 | 32 | Upon running `gym.make('math_prog_synth_env:math-env-v0', config_file='params.yaml')` a check is performed to determine if the last step (tokenizer creation) has been completed. If not then all 3 steps will be automatically completed. 33 | 34 | ## Run unit tests 35 | 36 | To run the unit tests, change working directory to the root of the project and then run `python -m unittest discover math_prog_synth_env/unit_testing` 37 | -------------------------------------------------------------------------------- /READ_ME_assets/graph_construction.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JohnnyYeeee/math_prog_synth_env/15ba964796f265879d9ce41e27b5299c2e0831a0/READ_ME_assets/graph_construction.gif -------------------------------------------------------------------------------- /math_prog_synth_env.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 1.0 2 | Name: math_prog_synth_env 3 | Version: 0.0.1 4 | Summary: UNKNOWN 5 | Home-page: UNKNOWN 6 | Author: UNKNOWN 7 | Author-email: UNKNOWN 8 | License: UNKNOWN 9 | Description: UNKNOWN 10 | Platform: UNKNOWN 11 | -------------------------------------------------------------------------------- /math_prog_synth_env.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | README.md 2 | setup.py 3 | math_prog_synth_env.egg-info/PKG-INFO 4 | math_prog_synth_env.egg-info/SOURCES.txt 5 | math_prog_synth_env.egg-info/dependency_links.txt 6 | math_prog_synth_env.egg-info/requires.txt 7 | math_prog_synth_env.egg-info/top_level.txt -------------------------------------------------------------------------------- /math_prog_synth_env.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /math_prog_synth_env.egg-info/requires.txt: -------------------------------------------------------------------------------- 1 | gym 2 | sympy 3 | numpy 4 | scipy 5 | sentencepiece 6 | torch 7 | mathematics_dataset 8 | tqdm 9 | sklearn 10 | google-cloud-storage 11 | pyyaml 12 | -------------------------------------------------------------------------------- /math_prog_synth_env.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /math_prog_synth_env/__init__.py: -------------------------------------------------------------------------------- 1 | from gym.envs.registration import register 2 | 3 | register( 4 | id="math-env-v0", 5 | entry_point="math_prog_synth_env.envs:MathEnv", 6 | ) 7 | -------------------------------------------------------------------------------- /math_prog_synth_env/compute_graph.py: -------------------------------------------------------------------------------- 1 | from inspect import signature 2 | from math_prog_synth_env.utils import extract_formal_elements 3 | from math_prog_synth_env.typed_operators import * 4 | import signal 5 | 6 | class Node: 7 | def __init__(self, action): 8 | self.action = action 9 | self.args = [] 10 | if type(self.action) == str: # if action is a formal element 11 | self.num_parameters = 0 12 | self.types = [] 13 | else: 14 | self.num_parameters = len(signature(self.action).parameters) 15 | self.types = [ 16 | type_.annotation 17 | for name, type_ in signature(self.action).parameters.items() 18 | ] 19 | 20 | def set_arg(self, node): 21 | assert len(self.args) < self.num_parameters 22 | self.args.append(node) 23 | 24 | def set_args(self, nodes): 25 | assert len(self.args) == 0 26 | self.args = nodes 27 | 28 | def are_args_set(self): 29 | return len(self.args) == self.num_parameters 30 | 31 | 32 | class ComputeGraph: 33 | def __init__(self, question): 34 | self.formal_elements = extract_formal_elements(question) 35 | self.formal_element_types = [type(f) for f in self.formal_elements] 36 | self.root = None 37 | self.current_node = None # reference to the first node (breadth-first) that requires one or more arguments 38 | self.queue = [] 39 | self.n_nodes = 0 40 | 41 | def lookup_formal_element(self, action): 42 | """f12 => int(12)""" 43 | try: 44 | selected_formal_element = self.formal_elements[int(action[1:])] 45 | except: 46 | selected_formal_element = ( 47 | action # if index is out of range, return dummy value 48 | ) 49 | return selected_formal_element 50 | 51 | def build_string(self, current_node): 52 | if type(current_node) == str: # case: param 53 | return f"'{current_node}'" 54 | elif type(current_node.action) == str: # case: formal element 55 | assert current_node.action[0] == "f" 56 | formal_element = self.lookup_formal_element(current_node.action) 57 | return f"{type(formal_element).__name__}('{formal_element}')" 58 | elif current_node.action is None: # case: None (i.e. for an ap) 59 | return "None" 60 | else: 61 | arg_strings = [] 62 | if len(current_node.args) < current_node.num_parameters: 63 | num_params = current_node.num_parameters 64 | num_args = len(current_node.args) 65 | args = current_node.args + [ 66 | f"p_{i}" for i in range(num_args, num_params) 67 | ] 68 | else: 69 | args = current_node.args 70 | for arg in args: 71 | arg_string = self.build_string(arg) 72 | arg_strings.append(arg_string) 73 | return f"{current_node.action.__name__}({','.join(['{}'.format(arg_string) for arg_string in arg_strings])})" 74 | 75 | def __str__(self): 76 | """ 77 | traverse the graph to construct a string representing the compute graph. 78 | :return: 79 | """ 80 | return self.build_string(self.root) 81 | 82 | def eval(self): 83 | """ 84 | evaluate the compute graph 85 | :return: the output of the compute graph 86 | """ 87 | try: 88 | string_to_eval = str(self) 89 | if "\'p_" in string_to_eval: 90 | raise Exception("unreplaced params are in arb, e.g. 'p_0'") 91 | output = eval(string_to_eval) 92 | # if output is a set, reformat as a sorted string 93 | if type(output) == set: 94 | return ", ".join([str(x) for x in sorted(list(output))]) 95 | else: 96 | return output 97 | except: 98 | return None 99 | 100 | def add(self, action): 101 | """ 102 | Add an action to the compute graph. Elements are added breadth-first order: KNOB. 103 | 104 | :param action: either an operator or a formal element 105 | """ 106 | if self.root is None: 107 | self.root = Node(action) 108 | if not self.root.are_args_set(): 109 | self.current_node = self.root 110 | else: 111 | self.current_node = None 112 | else: 113 | new_node = Node(action) 114 | self.current_node.set_arg(new_node) 115 | if new_node.num_parameters > 0: 116 | self.queue.append( 117 | new_node 118 | ) # add new node to queue for later processing 119 | if self.current_node.are_args_set(): 120 | if len(self.queue) > 0: 121 | self.current_node = self.queue.pop() 122 | else: 123 | self.current_node = None 124 | 125 | def reset(self): 126 | self.root = None 127 | self.current_node = None 128 | self.queue = [] 129 | -------------------------------------------------------------------------------- /math_prog_synth_env/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from math_prog_synth_env.envs.math_env import MathEnv 2 | -------------------------------------------------------------------------------- /math_prog_synth_env/envs/math_env.py: -------------------------------------------------------------------------------- 1 | 2 | from inspect import signature 3 | import os 4 | from pathlib import Path 5 | from sympy import sympify 6 | from random import sample 7 | import gym 8 | import numpy as np 9 | from gym import spaces 10 | from scipy.special import softmax 11 | import sentencepiece as spm 12 | from math_prog_synth_env.compute_graph import ComputeGraph 13 | from math_prog_synth_env.typed_operators import * 14 | from math_prog_synth_env.utils import load_data, split_validation_data 15 | import torch 16 | 17 | class MathEnv(gym.Env): 18 | def __init__(self, config_file): 19 | import yaml 20 | 21 | self.compute_graph = None 22 | self.episode_actions = None 23 | # load config 24 | with open(config_file, 'r') as stream: 25 | config = yaml.safe_load(stream) 26 | self.config = config 27 | self.encode_question = config["encode_question"] 28 | self.max_num_nodes = self._max_episode_steps = config["max_num_nodes"] 29 | self.max_formal_elements = config["max_formal_elements"] 30 | self.max_difficulty = config["max_difficulty"] 31 | self.question_vocab_size = config["question_vocab_size"] 32 | self.max_sequence_length = config["max_sequence_length"] 33 | # define available operator functions 34 | self.operators = [ 35 | lookup_value, 36 | solve_system, 37 | append, 38 | append_to_empty_list, 39 | make_equation, 40 | lookup_value_equation, 41 | extract_isolated_variable, 42 | substitution_left_to_right, 43 | factor, 44 | differentiate, 45 | differentiate_wrt, 46 | simplify, 47 | make_function, 48 | replace_arg, 49 | mod, 50 | gcd, 51 | divides, 52 | is_prime, 53 | lcm, 54 | lcd, 55 | prime_factors, 56 | evaluate_function, 57 | not_op 58 | ] 59 | # ensure that every operator listed in config["operators"] is present in the above list 60 | valid_op_names = [op.__name__ for op in self.operators] 61 | assert all([op in valid_op_names for op in config["operators"]]) 62 | # define action and observation space 63 | self.operators = [operator for operator in self.operators if (operator.__name__ in config["operators"])] 64 | self.operator_output_types = [ 65 | signature(operator).return_annotation for operator in self.operators 66 | ] 67 | self.actions = self.operators + [ 68 | f"f{i}" for i in range(self.max_formal_elements) 69 | ] 70 | self.action_names = [op.__name__ for op in self.operators] + [f"f{i}" for i in range(self.max_formal_elements)] 71 | self.num_actions = len(self.actions) 72 | # increment by 2 to account for both the question padding and the answer padding 73 | self.total_vocab_size = self.question_vocab_size + self.num_actions + 2 74 | self.action_space = spaces.Discrete(len(self.actions)) 75 | self.action_indices = np.arange(len(self.actions)) 76 | self.observation_space = spaces.MultiDiscrete( 77 | [self.total_vocab_size for _ in range(config["max_sequence_length"])] 78 | ) 79 | 80 | # Set up if data not downloaded yet 81 | if not os.path.isfile(self.config["tokenizer_filepath"] + ".model"): 82 | print("No data/tokenizer found: Redownloading data") 83 | self.setup() 84 | # load data 85 | self.train = load_data(config, train=True) 86 | self.val = split_validation_data(config, self.train) 87 | self.test = load_data(config, train=False) 88 | # load tokenizer 89 | self.question_padding_token = config["question_vocab_size"] 90 | # increment config["question_vocab_size"] by 1 to account for padding token 91 | self.action_padding_token = (config["question_vocab_size"] + 1) + self.num_actions 92 | self.tokenizer = spm.SentencePieceProcessor(model_file=self.config["tokenizer_filepath"] + ".model") 93 | 94 | 95 | def step(self, action_index): 96 | """ 97 | :param action_index: index into the action space 98 | :return: observation, reward, done, info 99 | An action fills the next element in the compute graph. 100 | -observation: question + interim compute graph 101 | -reward: 0 if the compute doesn't evaluate correctly, 1 if it does 102 | -done: True if the graph is complete, False if it isn't 103 | -info: None 104 | """ 105 | action = self.actions[action_index] 106 | self.compute_graph.n_nodes += 1 107 | self.compute_graph.add(action) 108 | self.episode_actions.append(action_index) 109 | output = self.compute_graph.eval() 110 | compute_graph = str(self.compute_graph) 111 | full_raw_observation = f"{self.question}; {compute_graph}" 112 | if self.encode_question: 113 | encoded_question = self.encode(self.question) 114 | # increment by (self.question_vocab_size + 1) to ensure no overlap between question vocab and action vocab 115 | episode_actions_array = np.array(self.episode_actions) + (self.question_vocab_size + 1) 116 | episode_actions_padding_array = np.array([self.action_padding_token 117 | for _ in range(self.max_num_nodes - len(self.episode_actions))]) 118 | observation = np.concatenate([encoded_question, episode_actions_array, episode_actions_padding_array]) 119 | else: 120 | observation = full_raw_observation 121 | next_mask = self.compute_mask() 122 | done = ( 123 | self.compute_graph.current_node is None 124 | or self.compute_graph.n_nodes >= self.max_num_nodes 125 | or np.array_equal(next_mask, np.zeros(len(next_mask))) 126 | ) 127 | # get reward 128 | if done: 129 | # cleanup output 130 | sympify_output = None 131 | sympify_answer = None 132 | try: 133 | sympify_output = sympify(str(output)) 134 | sympify_answer = sympify(self.answer) 135 | except: 136 | pass 137 | if sympify_output is not None and sympify_answer is not None and \ 138 | sympify_output == sympify_answer: 139 | reward = 1 140 | elif str(output) == str(self.answer): 141 | reward = 1 142 | else: 143 | reward = 0 144 | else: 145 | reward = 0 146 | info = {"raw_observation": full_raw_observation} 147 | return observation, reward, done, info 148 | 149 | 150 | # tokenization utilities ------------------------------------------------------------------------------------------- 151 | 152 | def encode(self, raw_observation): 153 | encoded_ids = self.tokenizer.encode(raw_observation) 154 | # pad the encoded ids up to a maximum length 155 | encoded_ids.extend( 156 | [self.question_padding_token for _ in range(self.config["max_sequence_length"] - len(encoded_ids))] 157 | ) 158 | return np.array(encoded_ids) 159 | 160 | def decode_question(self, encoded_ids): 161 | '' 162 | # filter out padding tokens before decoding 163 | encoded_ids = [id_ for id_ in encoded_ids.tolist() if id_ < self.question_padding_token] 164 | return self.tokenizer.decode(encoded_ids) 165 | 166 | # utilities to reset the environment ------------------------------------------------------------------------------- 167 | 168 | def reset(self, mode='train'): 169 | # randomly sample a module and difficulty level 170 | module_name = sample(list(self.train.keys()), 1)[0] 171 | difficulty = sample(list(self.train[module_name].keys()), 1)[0] 172 | return self.reset_by_module_and_difficulty(module_name, difficulty, mode=mode) 173 | 174 | def reset_from_text(self, question, answer): 175 | self.module_name = 'N/A' 176 | self.difficulty = 'N/A' 177 | self.question = question 178 | self.answer = answer 179 | self.module_difficulty_index = 'N/A' 180 | self.compute_graph = ComputeGraph(self.question) 181 | self.episode_actions = list() 182 | obs = np.concatenate([self.encode(self.question), 183 | np.array([self.action_padding_token for _ in range(self.max_num_nodes)])]) 184 | return obs, {'raw_observation': self.question} 185 | 186 | def reset_with_same_problem(self): 187 | self.compute_graph = ComputeGraph(self.question) 188 | self.episode_actions = list() 189 | obs = np.concatenate([self.encode(self.question), 190 | np.array([self.action_padding_token for _ in range(self.max_num_nodes)])]) 191 | return obs, {'raw_observation': self.question} 192 | 193 | def reset_with_specific_problem( 194 | self, module_name, difficulty, module_difficulty_index, train=True 195 | ): 196 | self.module_name = module_name 197 | self.difficulty = difficulty 198 | if train: 199 | 200 | problem_dict = self.train[module_name][difficulty][module_difficulty_index] 201 | else: 202 | problem_dict = self.val[module_name][difficulty][module_difficulty_index] 203 | self.question = problem_dict['question'] 204 | self.answer = problem_dict['answer'] 205 | self.module_difficulty_index = problem_dict['module_difficulty_index'] 206 | self.compute_graph = ComputeGraph(self.question) 207 | self.episode_actions = list() 208 | obs = np.concatenate([self.encode(self.question), 209 | np.array([self.action_padding_token for _ in range(self.max_num_nodes)])]) 210 | return obs, {'raw_observation': self.question} 211 | 212 | def reset_by_module_and_difficulty(self, module_name, difficulty, mode='train'): 213 | self.module_name = module_name 214 | self.difficulty = difficulty 215 | if mode == 'train': 216 | problem_dict = sample( 217 | self.train[module_name][difficulty], 1 218 | )[0] 219 | elif mode == 'val': 220 | problem_dict = sample( 221 | self.val[module_name][difficulty], 1 222 | )[0] 223 | else: 224 | problem_dict = sample( 225 | self.test[module_name][difficulty], 1 226 | )[0] 227 | 228 | self.question = problem_dict['question'] 229 | self.answer = problem_dict['answer'] 230 | self.module_difficulty_index = problem_dict['module_difficulty_index'] 231 | self.compute_graph = ComputeGraph(self.question) 232 | self.episode_actions = list() 233 | obs = np.concatenate([self.encode(self.question), 234 | np.array([self.action_padding_token for _ in range(self.max_num_nodes)])]) 235 | return obs, {'raw_observation': self.question} 236 | 237 | # utilities to sample actions -------------------------------------------------------------------------------------- 238 | 239 | def get_action_index(self, action): 240 | return self.actions.index(action) 241 | 242 | def sample_action_index(self): 243 | return self.action_space.sample() 244 | 245 | def sample_masked_action_index(self): 246 | choices = np.arange(len(self.actions)) 247 | mask = self.compute_mask() 248 | valid_choices = np.array([x for x, m in zip(choices, mask) if m != 0]) 249 | return np.random.choice(valid_choices) 250 | 251 | def sample_masked_policy_vector(self): 252 | policy_vector = np.random.uniform(size=len(self.actions)) 253 | masked_policy_vector = self.mask_invalid_types(policy_vector) 254 | masked_normed_policy_vector = masked_policy_vector / np.sum( 255 | masked_policy_vector 256 | ) 257 | return masked_normed_policy_vector 258 | 259 | def sample_masked_action_from_model(self, model, obs): 260 | policy_vector = softmax(model(obs).detach().numpy()[0]) 261 | masked_policy_vector = self.mask_invalid_types(policy_vector) 262 | masked_normed_policy_vector = masked_policy_vector / np.sum( 263 | masked_policy_vector 264 | ) 265 | choices = np.arange(len(self.actions)) 266 | action_index = np.random.choice(choices, p=masked_normed_policy_vector) 267 | return action_index 268 | 269 | def compute_mask(self): 270 | if not self.compute_graph.current_node: 271 | # first action must be an operator 272 | mask = np.concatenate( 273 | [np.ones(len(self.operators)), np.zeros(self.max_formal_elements)] 274 | ) 275 | else: 276 | current_arg_index = len(self.compute_graph.current_node.args) 277 | next_type = self.compute_graph.current_node.types[current_arg_index] 278 | available_types = ( 279 | self.operator_output_types + self.compute_graph.formal_element_types 280 | ) 281 | mask = np.array( 282 | [1 if issubclass(type_, next_type) else 0 for type_ in available_types] 283 | ) 284 | mask = np.concatenate( 285 | [ 286 | mask, 287 | np.zeros( 288 | self.max_formal_elements 289 | - len(self.compute_graph.formal_elements) 290 | ), 291 | ] 292 | ) 293 | return mask 294 | 295 | def mask_invalid_types(self, model_output): 296 | mask = self.compute_mask() 297 | if torch.is_tensor(model_output): 298 | mask = torch.from_numpy(mask).type(torch.FloatTensor) 299 | masked_output = mask * model_output 300 | return masked_output 301 | 302 | def render(self): 303 | pass 304 | 305 | def close(self): 306 | pass 307 | 308 | def setup(self): 309 | """To be ran on first use of the environment. 310 | Downloads data, splits data and trains tokenizer.""" 311 | print("Downloading Data:") 312 | self._get_data() 313 | print("Splitting Data:") 314 | self._split_data() 315 | print("Training Tokenizer:") 316 | self._train_tokenizer() 317 | 318 | 319 | def _get_data(self): 320 | import tarfile 321 | import requests 322 | 323 | url = 'https://storage.googleapis.com/mathematics-dataset/mathematics_dataset-v1.0.tar.gz' 324 | myfile = requests.get(url) 325 | open("mathematics_dataset-v1.0.tar.gz", 'wb').write(myfile.content) 326 | 327 | print("Data Downloaded") 328 | 329 | data_tar = tarfile.open(name=self.config["data_download_location"], mode='r:gz') 330 | data_tar.extractall(path=self.config["data_unpack_dir"]) 331 | print("Data unpacked") 332 | 333 | 334 | def _split_data(self): 335 | import os 336 | from tqdm import tqdm 337 | 338 | problem_filepaths = [os.path.join(os.path.join(self.config["data_unpack_dir"],self.config["all_data_dirpath"]), filename) for filename in 339 | self.config["selected_filenames"]] 340 | train_problem_filepaths = [os.path.join(self.config["data_dirpath"], filename) for filename in 341 | self.config["selected_filenames"]] 342 | test_problem_filepaths = [os.path.join(self.config["test_data_dirpath"], filename) for filename in 343 | self.config["selected_filenames"]] 344 | 345 | if os.path.isdir(self.config["data_dirpath"]) or os.path.isdir(self.config["test_data_dirpath"]): 346 | raise ValueError(f"data directories already exist") 347 | else: 348 | os.mkdir(self.config["data_dirpath"]) 349 | os.mkdir(self.config["test_data_dirpath"]) 350 | 351 | for filepath, train_filepath, test_filepath in tqdm( 352 | zip(problem_filepaths, train_problem_filepaths, test_problem_filepaths)): 353 | # read data 354 | with open(filepath, "r") as f: 355 | lines = f.readlines() 356 | num_pairs = len(lines) // 2 357 | num_train_pairs = int((1 - self.config["test_percentage"]) * num_pairs) 358 | 359 | # Write data 360 | with open(train_filepath, "w") as f: 361 | f.writelines(lines[:2 * num_train_pairs]) 362 | with open(test_filepath, "w") as f: 363 | f.writelines(lines[2 * num_train_pairs:]) 364 | print("train and test datasets have been created") 365 | 366 | def _get_corpus_for_tokenizer(self): 367 | from random import shuffle 368 | from sklearn.model_selection import train_test_split 369 | 370 | filepaths = [ 371 | f"mathematics_dataset-v1.0/train-easy/{filename}" for filename in self.config["selected_filenames"] 372 | ] 373 | questions = [] 374 | 375 | for filepath in filepaths: 376 | with open(filepath, "r") as f: 377 | lines = f.readlines() 378 | num_pairs = min(len(lines) // 2, self.config["num_problems_per_module_corpus"]) 379 | for i in range(0, 2 * num_pairs, 2): 380 | question = lines[i].strip() 381 | answer = lines[i + 1].strip() 382 | questions.append(question) 383 | 384 | shuffle(questions) 385 | train_questions, val_questions = train_test_split(questions, test_size=0.4) 386 | with open(self.config["corpus_path"], "w") as f: 387 | f.write("\n".join(train_questions)) 388 | print("Downloaded corpus for training tokenizer") 389 | 390 | def _train_tokenizer(self): 391 | import sentencepiece as spm 392 | #Get corpus 393 | self._get_corpus_for_tokenizer() 394 | # train tokenizer on question corpus 395 | hardcoded_symbols = ['G'] # why is 'G' needed? 396 | spm.SentencePieceTrainer.train(input=self.config["corpus_path"], 397 | model_prefix=self.config["tokenizer_filepath"], 398 | vocab_size=self.config["question_vocab_size"], 399 | user_defined_symbols=hardcoded_symbols) 400 | print("Tokenizer saved") 401 | -------------------------------------------------------------------------------- /math_prog_synth_env/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup(name="math_env", version="0.0.1", install_requires=["gym", "sympy"]) 4 | -------------------------------------------------------------------------------- /math_prog_synth_env/tokenization/tokenizer.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JohnnyYeeee/math_prog_synth_env/15ba964796f265879d9ce41e27b5299c2e0831a0/math_prog_synth_env/tokenization/tokenizer.model -------------------------------------------------------------------------------- /math_prog_synth_env/typed_operators/__init__.py: -------------------------------------------------------------------------------- 1 | import re 2 | import sympy 3 | from typing import List, Dict, Set 4 | import multiprocess as mp 5 | import time 6 | # from math import log 7 | 8 | 9 | # type definitions -------------------------------------- 10 | 11 | class Equation(object): 12 | def __init__(self, equation: str): 13 | assert len(equation.split("=")) == 2 14 | self.equation = equation 15 | 16 | def __str__(self): 17 | return self.equation 18 | 19 | def __eq__(self, equation): 20 | return self.equation == str(equation) 21 | 22 | def split(self, split_on): 23 | return self.equation.split(split_on) 24 | 25 | 26 | class Function(Equation): 27 | def __init__(self, function: str): 28 | assert len(function.split("=")) == 2 29 | function_arg_pattern = "([a-zA-Z0-9\s]+)\(([a-zA-Z0-9\s]+)\)" 30 | # extract parts of function definition 31 | lhs, rhs = function.split("=") 32 | match = re.match(function_arg_pattern, lhs) 33 | assert match is not None 34 | self.name, self.parameter = match.group(1), match.group(2) 35 | self.function = function 36 | self.equation = function 37 | 38 | def __str__(self): 39 | return str(self.function) 40 | 41 | def __eq__(self, function): 42 | return self.function == str(function) 43 | 44 | 45 | class Expression(object): 46 | def __init__(self, expression: str): 47 | assert "=" not in expression 48 | self.expression = str(expression) 49 | 50 | def __str__(self): 51 | return self.expression 52 | 53 | def __eq__(self, other): 54 | return self.expression == other.expression 55 | 56 | def __hash__(self): 57 | return hash(self.expression) 58 | 59 | 60 | class Variable(Expression): 61 | def __init__(self, variable: str): 62 | self.variable = str(variable) 63 | assert variable.isalpha() 64 | 65 | def __str__(self): 66 | return self.variable 67 | 68 | def __eq__(self, variable): 69 | return self.variable == variable 70 | 71 | def __hash__(self): 72 | return hash(self.variable) 73 | 74 | 75 | class Value(Expression): 76 | def __init__(self, value: float): 77 | self.value = float(value) 78 | 79 | def __str__(self): 80 | if self.value % 1 == 0: 81 | return str(int(self.value)) 82 | else: 83 | return str(self.value) 84 | 85 | def __eq__(self, value): 86 | return self.value == value.value 87 | 88 | def __hash__(self): 89 | return hash(str(self.value)) 90 | 91 | def __lt__(self, other): 92 | return self.value < other.value 93 | 94 | def __gt__(self, other): 95 | return self.value > other.value 96 | 97 | class Rational(Expression): 98 | def __init__(self, rational: str): 99 | self.rational = str(rational) 100 | try: 101 | self.numerator, self.denominator = [Value(x) for x in self.rational.split('/')] 102 | except: 103 | self.numerator = Value(self.rational) 104 | self.denominator = 1 105 | 106 | def __str__(self): 107 | return self.rational 108 | 109 | def __eq__(self, rational): 110 | return self.rational == str(rational) 111 | 112 | def __hash__(self): 113 | return hash(str(self.rational)) 114 | 115 | def __lt__(self, other): 116 | return sympy.Rational(self.rational) < sympy.Rational(other.rational) 117 | 118 | # operator definitions -------------------------------------- 119 | 120 | 121 | # solve_system(system: List[Equation]) -> Dict[Variable, Set[Value]] 122 | def solve_system(system: list) -> dict: 123 | """ 124 | solve a system of linear equations. 125 | 126 | :param system: List[ 127 | :return: Dict[Variable, Value] 128 | """ 129 | def sympy_solve(system, return_dict): 130 | # run in try-except to suppress exception logging (must be done here due to use of multiprocess) 131 | try: 132 | solutions = sympy.solve(system, rational=True) 133 | return_dict["solutions"] = solutions 134 | except: 135 | pass 136 | sympy_equations = [] 137 | for equation in system: 138 | lhs, rhs = str(equation).split("=") 139 | sympy_eq = sympy.Eq(sympy.sympify(lhs), sympy.sympify(rhs)) 140 | sympy_equations.append(sympy_eq) 141 | 142 | manager = mp.Manager() 143 | return_dict = manager.dict() 144 | p = mp.Process(target=sympy_solve, args=(sympy_equations, return_dict)) 145 | p.start() 146 | p.join(1) 147 | 148 | if p.is_alive(): 149 | p.terminate() 150 | p.join() 151 | solutions = return_dict.get("solutions", []) 152 | 153 | # Convert list to dictionary if no solution found. 154 | if len(solutions) == 0: 155 | raise Exception("no solution found") 156 | elif type(solutions) is dict: 157 | return {Variable(str(k)): set([Rational(v)]) for k, v in solutions.items()} 158 | elif type(solutions) is list: 159 | solutions_dict = {} 160 | for soln in solutions: 161 | for k, v in soln.items(): 162 | if str(k) in solutions_dict.keys(): 163 | solutions_dict[Variable(str(k))].add(Rational(v)) 164 | else: 165 | solutions_dict[Variable(str(k))] = set([Rational(v)]) 166 | return solutions_dict 167 | 168 | 169 | # append(system: List[Equation], equation: Equation) -> List[Equation] 170 | def append(system: list, equation: Equation) -> list: 171 | if not system: 172 | return [equation] 173 | else: 174 | system.append(equation) 175 | return system 176 | 177 | 178 | def append_to_empty_list(equation: Equation) -> list: 179 | return [equation] 180 | 181 | 182 | # lookup_value(mapping: Dict[Variable, Set[Value]], key: Variable) 183 | def lookup_value(mapping: dict, key: Variable) -> object: 184 | # TODO: figure out how to constrain output type in this case (multiple output types) 185 | assert key in mapping 186 | corresponding_set = mapping[key] 187 | if len(corresponding_set) == 1: 188 | return corresponding_set.pop() 189 | else: 190 | return corresponding_set 191 | 192 | 193 | # lookup_value_equation(mapping: Dict[Variable, Set[Value]], key: Variable) -> Equation: 194 | def lookup_value_equation(mapping: dict, key: Variable) -> Equation: 195 | assert key in mapping 196 | corresponding_set = mapping[key] 197 | value = corresponding_set.pop() 198 | return Equation(f"{key} = {value}") 199 | 200 | 201 | def make_equation(expression1: Expression, expression2: Expression) -> Equation: 202 | return Equation(f"{expression1} = {expression2}") 203 | 204 | 205 | def make_function(expression1: Expression, expression2: Expression) -> Function: 206 | return Function(f"{expression1} = {expression2}") 207 | 208 | 209 | def extract_isolated_variable(equation: Equation) -> Variable: 210 | lhs, rhs = str(equation).split("=") 211 | lhs, rhs = lhs.strip(), rhs.strip() 212 | if len(lhs) == 1 and lhs.isalpha(): 213 | return lhs 214 | elif len(rhs) == 1 and rhs.isalpha(): 215 | return rhs 216 | else: 217 | raise Exception("there is no isolated variable") 218 | 219 | 220 | def project_lhs(equation: Equation) -> Expression: 221 | return Expression(str(equation).split("=")[0].strip()) 222 | 223 | 224 | def project_rhs(equation: Equation) -> Expression: 225 | return Expression(str(equation).split("=")[1].strip()) 226 | 227 | 228 | def substitution_left_to_right(arb: object, eq: Equation) -> object: 229 | return str(arb).replace(str(project_lhs(eq)), str(project_rhs(eq))) 230 | 231 | 232 | def substitution_right_to_left(arb: object, eq: Equation) -> object: 233 | """substitution_right_to_left""" 234 | return str(arb).replace(str(project_rhs(eq)), str(project_lhs(eq))) 235 | 236 | 237 | def factor(inpt: Expression) -> Expression: 238 | output = Expression(str(sympy.factor(inpt))) 239 | return output 240 | 241 | 242 | def simplify(inpt: object) -> object: 243 | if "=" in str(inpt): 244 | lhs, rhs = str(inpt).split("=") 245 | lhs, rhs = lhs.strip(), rhs.strip() 246 | output = Equation(f"{sympy.simplify(lhs)} = {sympy.simplify(rhs)}".strip()) 247 | else: 248 | output = Expression(str(sympy.simplify(str(inpt))).strip()) 249 | return output 250 | 251 | 252 | def differentiate(expression: Expression) -> Expression: 253 | derivative = sympy.diff(sympy.sympify(str(expression))) 254 | return Expression(str(derivative)) 255 | 256 | 257 | def differentiate_wrt(expression: Expression, variable: Variable) -> Expression: 258 | derivative = sympy.diff(sympy.sympify(str(expression)), sympy.sympify(str(variable))) 259 | return Expression(str(derivative)) 260 | 261 | 262 | def replace_arg(function: Function, var: Variable) -> Function: 263 | lhs, rhs = function.split(" = ") 264 | lhs = lhs.replace('(' + str(function.parameter) + ')', '(' + str(var) + ')') 265 | rhs = rhs.replace(str(function.parameter), str(var)) 266 | return make_function(Expression(lhs), Expression(rhs)) 267 | 268 | 269 | def mod(numerator: Value, denominator: Value) -> Value: 270 | return Value(numerator.value % denominator.value) 271 | 272 | 273 | def divides(numerator: Value, denominator: Value) -> bool: 274 | return numerator.value % denominator.value == 0 275 | 276 | 277 | def gcd(x: Value, y: Value) -> Value: 278 | """greatest common divisor""" 279 | from math import gcd 280 | 281 | return Value(gcd(int(x.value), int(y.value))) 282 | 283 | 284 | def is_prime(x: Value) -> bool: 285 | return sympy.isprime(int(x.value)) 286 | 287 | 288 | def lcm(x: Value, y: Value) -> Value: 289 | """least common multiple""" 290 | import math 291 | assert int(x.value) == x.value and int(y.value) == y.value 292 | x, y = int(x.value), int(y.value) 293 | return Value(abs(x * y) // math.gcd(x, y)) 294 | 295 | def lcd(x: Rational, y: Rational) -> Value: 296 | """least common denominator""" 297 | return lcm(x.denominator, y.denominator) 298 | 299 | 300 | def prime_factors(n: Value) -> set: 301 | # https://stackoverflow.com/questions/16996217/prime-factorization-list 302 | assert int(n.value) == n.value 303 | 304 | if is_prime(n): 305 | return n 306 | 307 | n = int(n.value) 308 | divisors = [d for d in range(2, n // 2 + 1) if n % d == 0] 309 | 310 | return set( 311 | [Value(d) for d in divisors if all(d % od != 0 for od in divisors if od != d)] 312 | ) 313 | 314 | 315 | def evaluate_function(function_definition: Function, function_argument: Expression) -> Value: 316 | """ 317 | :param function_definition: e.g. 'f(x) = x + x**3' 318 | :param function_argument: e.g. either '2' or 'f(2)' 319 | :return: 320 | """ 321 | function_definition_pattern = "([a-zA-Z0-9\s]+)\(([a-zA-Z0-9\s]+)\)" 322 | function_arg_pattern = "([a-zA-Z0-9\s]+)\((-?[a-zA-Z0-9\s]+)\)" 323 | # extract parts of function definition 324 | lhs, rhs = str(function_definition).split("=") 325 | match = re.match(function_definition_pattern, lhs) 326 | function_name_from_definition, function_parameter = match.group(1), match.group(2) 327 | # extract parts of function argument 328 | function_argument_ = re.match(function_arg_pattern, str(function_argument)) 329 | if function_argument_ is not None: 330 | function_name_from_argument, function_argument = ( 331 | function_argument_.group(1), 332 | function_argument_.group(2), 333 | ) 334 | assert function_name_from_definition == function_name_from_argument 335 | # evaluate function 336 | rhs_with_arg = rhs.replace(function_parameter, f'({function_argument})') 337 | return Value(eval(rhs_with_arg)) 338 | 339 | 340 | def not_op(x: bool) -> bool: 341 | assert type(x) == bool 342 | return not x 343 | 344 | -------------------------------------------------------------------------------- /math_prog_synth_env/unit_testing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JohnnyYeeee/math_prog_synth_env/15ba964796f265879d9ce41e27b5299c2e0831a0/math_prog_synth_env/unit_testing/__init__.py -------------------------------------------------------------------------------- /math_prog_synth_env/unit_testing/artifacts/extract_formal_elements_examples.json: -------------------------------------------------------------------------------- 1 | { 2 | "Suppose -3*z + 133 = 4*n - 10, 5*n = 25. Let l = -21 + z. Let r = l + -11. Calculate the least common multiple of 7 and r.": [ 3 | "-3*z + 133 = 4*n - 10", 4 | "5*n = 25", 5 | "l = -21 + z", 6 | "r = l + -11", 7 | "7", 8 | "r" 9 | ], 10 | "Calculate the common denominator of 1/(3/(-6)) - 402/(-60) and -71/12.": [ 11 | "1/(3/(-6)) - 402/(-60)", 12 | "-71/12" 13 | ], 14 | "What is the common denominator of -64/1065 and 92/105?": [ 15 | "-64/1065", 16 | "92/105" 17 | ], 18 | "What is the smallest common multiple of (-4)/12*(-20 - -2) and 4?": [ 19 | "(-4)/12*(-20 - -2)", 20 | "4" 21 | ], 22 | "Let q = -54.3 + 54. Suppose 0 = -5*z - 8 - 7. Which is the nearest to -1/5? (a) 5 (b) z (c) q": [ 23 | "q = -54.3 + 54", 24 | "0 = -5*z - 8 - 7", 25 | "-1/5", 26 | "(a) 5 (b) z (c) q" 27 | ], 28 | "Let d(j) = -j**3 - 5*j**2 - 4*j + 1. Let n be d(-4). Suppose -5*h = 2*i - 2*h + n, 0 = i + 5*h - 10. What is the nearest to 0 in 1/3, i, -2?": [ 29 | "d(j) = -j**3 - 5*j**2 - 4*j + 1", 30 | "n", 31 | "d(-4)", 32 | "-5*h = 2*i - 2*h + n", 33 | "0 = i + 5*h - 10", 34 | "0", 35 | "1/3", 36 | "i", 37 | "-2" 38 | ], 39 | "Let f = -2.31 + 0.31. What is the nearest to f in 0.3, -2, 0.2?": [ 40 | "f = -2.31 + 0.31", 41 | "f", 42 | "0.3", 43 | "-2", 44 | "0.2" 45 | ], 46 | "Let o(v) = 77*v + 1. Let b(l) = 155*l + 2. Suppose 4*c - 25 = -c. Let a(u) = c*o(u) - 3*b(u). Is a(-4) composite?": [ 47 | "o(v) = 77*v + 1", 48 | "b(l) = 155*l + 2", 49 | "4*c - 25 = -c", 50 | "a(u) = c*o(u) - 3*b(u)", 51 | "a(-4)" 52 | ], 53 | "Let j = -5 - 28. Is j/6*(-1 - 13) a composite number?": [ 54 | "j = -5 - 28", 55 | "j/6*(-1 - 13)" 56 | ], 57 | "Suppose 0 = -j - 4*a + 611, 4*j + a - 1468 = 1051. Is j prime?": [ 58 | "0 = -j - 4*a + 611", 59 | "4*j + a - 1468 = 1051", 60 | "j" 61 | ], 62 | "Let l(b) = -142004*b - 62917*b - 377393*b. Let d be l(-1). Let v = d - 262314. Round v to the nearest 100000.": [ 63 | "l(b) = -142004*b - 62917*b - 377393*b", 64 | "d", 65 | "l(-1)", 66 | "v = d - 262314", 67 | "v", 68 | "100000" 69 | ], 70 | "Suppose 5*t - 2 = -7. Let z be -612*1 + (-2 - -3). Let c be t/(-4) + z/(-4). What is c rounded to the nearest ten?": [ 71 | "5*t - 2 = -7", 72 | "z", 73 | "-612*1 + (-2 - -3)", 74 | "c", 75 | "t/(-4) + z/(-4)", 76 | "c" 77 | ], 78 | "Let m = 1.5 - 7.5. Let z = m - -22. Let v = z + -16.00017. Round v to four decimal places.": [ 79 | "m = 1.5 - 7.5", 80 | "z = m - -22", 81 | "v = z + -16.00017", 82 | "v" 83 | ], 84 | "Let w(b) = -2*b - 3. Suppose 0*j + 16 = -3*j - o, j + 3*o = 8. Let u = j - -5. What is w(u)?": [ 85 | "w(b) = -2*b - 3", 86 | "0*j + 16 = -3*j - o", 87 | "j + 3*o = 8", 88 | "u = j - -5", 89 | "w(u)" 90 | ], 91 | "Let p(o) = 2*o**3 - 12*o**2 + 6*o - 5. Let i(m) = -m**3 + 6*m**2 - 3*m + 2. Let q be 82/12 - 2/(-12). Let f(s) = q*i(s) + 3*p(s). Determine f(5).": [ 92 | "p(o) = 2*o**3 - 12*o**2 + 6*o - 5", 93 | "i(m) = -m**3 + 6*m**2 - 3*m + 2", 94 | "q", 95 | "82/12 - 2/(-12)", 96 | "f(s) = q*i(s) + 3*p(s)", 97 | "f(5)" 98 | ], 99 | "Let l(r) be the third derivative of 3*r**6/40 - r**5/60 - 6*r**2. What is l(-1)?": [ 100 | "l(r)", 101 | "3*r**6/40 - r**5/60 - 6*r**2", 102 | "l(-1)" 103 | ], 104 | "Let o = -788/3 - -260. Which is bigger: -0.1 or o?": [ 105 | "o = -788/3 - -260", 106 | "-0.1", 107 | "o" 108 | ], 109 | "Let r = 4 + -2. Which is greater: r or 0.09?": [ 110 | "r = 4 + -2", 111 | "r", 112 | "0.09" 113 | ], 114 | "Let q = 17 - 18. Let v be (2 + q)*12/(-16). Is v > -1?": [ 115 | "q = 17 - 18", 116 | "v", 117 | "(2 + q)*12/(-16)", 118 | "v > -1" 119 | ], 120 | "Suppose 3*x + 197 = 4*x. Calculate the remainder when x is divided by 33.": [ 121 | "3*x + 197 = 4*x", 122 | "x", 123 | "33" 124 | ], 125 | "Suppose -106 = -2*u + s, u - 40 = -5*s + 13. Calculate the remainder when u is divided by 14.": [ 126 | "-106 = -2*u + s", 127 | "u - 40 = -5*s + 13", 128 | "u", 129 | "14" 130 | ], 131 | "Let x = -41 - -20. Let t = x + 27. Calculate the remainder when t is divided by 4.": [ 132 | "x = -41 - -20", 133 | "t = x + 27", 134 | "t", 135 | "4" 136 | ], 137 | "Let d = -25019/90 - -278. Let v(j) be the third derivative of 0 + 1/27*j**3 - d*j**5 + 1/54*j**4 + 3*j**2 + 0*j. Suppose v(o) = 0. What is o?": [ 138 | "d = -25019/90 - -278", 139 | "v(j)", 140 | "0 + 1/27*j**3 - d*j**5 + 1/54*j**4 + 3*j**2 + 0*j", 141 | "v(o) = 0", 142 | "o" 143 | ], 144 | "Let g be 2 - (0 - (-1 - -1)). Determine q so that -q**4 - 6*q**2 + 0*q**4 - 3 + g - 4*q - 4*q**3 = 0.": [ 145 | "g", 146 | "2 - (0 - (-1 - -1))", 147 | "q", 148 | "-q**4 - 6*q**2 + 0*q**4 - 3 + g - 4*q - 4*q**3 = 0" 149 | ], 150 | "Let d(k) be the first derivative of -1 - 4/3*k**3 + 0*k + 1/2*k**2. Find z such that d(z) = 0.": [ 151 | "d(k)", 152 | "-1 - 4/3*k**3 + 0*k + 1/2*k**2", 153 | "z", 154 | "d(z) = 0" 155 | ], 156 | "Suppose -55 = -8*l + 3*l. Let k = l + -7. What is the units digit of k?": [ 157 | "-55 = -8*l + 3*l", 158 | "k = l + -7", 159 | "k" 160 | ], 161 | "Let t(p) = p**3 - 3*p**2 - 4*p + 2. Let a be t(4). Suppose 2*f = a + 2. Let l = f - -12. What is the units digit of l?": [ 162 | "t(p) = p**3 - 3*p**2 - 4*p + 2", 163 | "a", 164 | "t(4)", 165 | "2*f = a + 2", 166 | "l = f - -12", 167 | "l" 168 | ], 169 | "Suppose 5*j - 1126 + 331 = 0. What is the tens digit of j?": [ 170 | "5*j - 1126 + 331 = 0", 171 | "j" 172 | ], 173 | "Suppose 0 = -4*x + 8*x - 40. Let h(i) = i**2 - 9*i - 14. Let n be h(x). Sort -1, 4, n.": [ 174 | "0 = -4*x + 8*x - 40", 175 | "h(i) = i**2 - 9*i - 14", 176 | "n", 177 | "h(x)", 178 | "-1", 179 | "4", 180 | "n" 181 | ], 182 | "Let g = 1 + 2. Let a = 0.95 - -0.05. Put a, g, -1 in descending order.": [ 183 | "g = 1 + 2", 184 | "a = 0.95 - -0.05", 185 | "a", 186 | "g", 187 | "-1" 188 | ], 189 | "Let m be (-7)/56 - (-1)/(-8). Sort m, 0, -4 in descending order.": [ 190 | "m", 191 | "(-7)/56 - (-1)/(-8)", 192 | "m", 193 | "0", 194 | "-4" 195 | ], 196 | "Let w be (-1 + 13)*3/(-6). Let b = w - -6. Let i = 2 - b. Solve -15 = 3*c + i*c for c.": [ 197 | "w", 198 | "(-1 + 13)*3/(-6)", 199 | "b = w - -6", 200 | "i = 2 - b", 201 | "-15 = 3*c + i*c", 202 | "c" 203 | ], 204 | "Suppose -c + 4*v + 2 = -24, -4*c - 3*v + 9 = 0. Solve 2*b - c = -b for b.": [ 205 | "-c + 4*v + 2 = -24", 206 | "-4*c - 3*v + 9 = 0", 207 | "2*b - c = -b", 208 | "b" 209 | ], 210 | "Let v(k) = k**3 + k**2 - k - 3. Let d be v(0). Let a be ((-15)/2)/d*4. Let x = a + -8. Solve -3 + 11 = x*p for p.": [ 211 | "v(k) = k**3 + k**2 - k - 3", 212 | "d", 213 | "v(0)", 214 | "a", 215 | "((-15)/2)/d*4", 216 | "x = a + -8", 217 | "-3 + 11 = x*p", 218 | "p" 219 | ], 220 | "Let h(t) = t**3 + t**2 + 1. Let v(d) = 6*d**3 + 24*d**2 + 4. Let w(j) = 4*h(j) - v(j). What is the third derivative of w(x) wrt x?": [ 221 | "h(t) = t**3 + t**2 + 1", 222 | "v(d) = 6*d**3 + 24*d**2 + 4", 223 | "w(j) = 4*h(j) - v(j)", 224 | "w(x)", 225 | "x" 226 | ], 227 | "Let v = -7 - -12. Suppose 0 = 2*h - 3*x - 16 - 5, 0 = -v*h + 3*x + 30. What is the first derivative of 5*t - h - t + 0 - 2*t wrt t?": [ 228 | "v = -7 - -12", 229 | "0 = 2*h - 3*x - 16 - 5", 230 | "0 = -v*h + 3*x + 30", 231 | "5*t - h - t + 0 - 2*t", 232 | "t" 233 | ], 234 | "Let b(y) be the second derivative of -3*y**8/56 - y**4/6 - y. What is the third derivative of b(o) wrt o?": [ 235 | "b(y)", 236 | "-3*y**8/56 - y**4/6 - y", 237 | "b(o)", 238 | "o" 239 | ], 240 | "Let p = -3 - -6. Let w(d) = 0*d**2 + p*d**2 - 2*d**2 - 3*d**2. Let t(b) = -3*b. Give t(w(k)).": [ 241 | "p = -3 - -6", 242 | "w(d) = 0*d**2 + p*d**2 - 2*d**2 - 3*d**2", 243 | "t(b) = -3*b", 244 | "t(w(k))" 245 | ], 246 | "Let m(s) = 7*s - 12. Let z(g) = -5*g**2. What is z(m(k))?": [ 247 | "m(s) = 7*s - 12", 248 | "z(g) = -5*g**2", 249 | "z(m(k))" 250 | ], 251 | "Let w(q) = 2*q**2. Let v(x) be the first derivative of 0*x + 0*x**2 + 4/3*x**3 - 2. Determine v(w(p)).": [ 252 | "w(q) = 2*q**2", 253 | "v(x)", 254 | "0*x + 0*x**2 + 4/3*x**3 - 2", 255 | "v(w(p))" 256 | ], 257 | "Let f be 4/22 - 20/(-11). Suppose s = -0*s + 4*n + 12, 0 = -n - f. Which is the second smallest value? (a) -0.2 (b) s (c) 2/7": [ 258 | "f", 259 | "4/22 - 20/(-11)", 260 | "s = -0*s + 4*n + 12", 261 | "0 = -n - f", 262 | "(a) -0.2 (b) s (c) 2/7" 263 | ], 264 | "Let s = 1.5 + -1.5. Suppose 0 = p + p + 8. Which is the third biggest value? (a) s (b) -5 (c) p": [ 265 | "s = 1.5 + -1.5", 266 | "0 = p + p + 8", 267 | "(a) s (b) -5 (c) p" 268 | ], 269 | "Let r = 1 - -4. Let u be (-3 - -1)*3/(-2). Suppose -r*s - u = -4*s. Which is the third biggest value? (a) -0.3 (b) 2/11 (c) s": [ 270 | "r = 1 - -4", 271 | "u", 272 | "(-3 - -1)*3/(-2)", 273 | "-r*s - u = -4*s", 274 | "(a) -0.3 (b) 2/11 (c) s" 275 | ], 276 | "Suppose 3*n = -0*x - 3*x + 93, -2*n - 2 = 0. Does 12 divide x?": [ 277 | "3*n = -0*x - 3*x + 93", 278 | "-2*n - 2 = 0", 279 | "12", 280 | "x" 281 | ], 282 | "Is 1330/(-28)*4/(-2) a multiple of 19?": [ 283 | "1330/(-28)*4/(-2)", 284 | "19" 285 | ], 286 | "Is 3 - (1344/(-10) + 2/5) a multiple of 36?": [ 287 | "3 - (1344/(-10) + 2/5)", 288 | "36" 289 | ], 290 | "Suppose 2*y + 12 = 6*y. Suppose y = f - 15. Solve -8 = -4*w, -3*d - 4*w + f = -8*d for d.": [ 291 | "2*y + 12 = 6*y", 292 | "y = f - 15", 293 | "-8 = -4*w", 294 | "-3*d - 4*w + f = -8*d", 295 | "d" 296 | ], 297 | "Let l(v) = -v**3 + 12*v**2 + 13*v + 2. Let r(q) = -2*q + 5. Let c be r(-4). Let y be l(c). Solve -w + 2 = -3*s - 8, s + 1 = -y*w for s.": [ 298 | "l(v) = -v**3 + 12*v**2 + 13*v + 2", 299 | "r(q) = -2*q + 5", 300 | "c", 301 | "r(-4)", 302 | "y", 303 | "l(c)", 304 | "-w + 2 = -3*s - 8", 305 | "s + 1 = -y*w", 306 | "s" 307 | ], 308 | "Suppose 0 = f + 1, 17*f = 5*w + 12*f - 15. Suppose -3*s + 2 = -13. Suppose s*c + 3*j = 36, 3*c + 0*j - 18 = -3*j. Solve 20 = a - 5*z, -2*a + c = -w*z - 7 for a.": [ 309 | "0 = f + 1", 310 | "17*f = 5*w + 12*f - 15", 311 | "-3*s + 2 = -13", 312 | "s*c + 3*j = 36", 313 | "3*c + 0*j - 18 = -3*j", 314 | "20 = a - 5*z", 315 | "-2*a + c = -w*z - 7", 316 | "a" 317 | ], 318 | "Let q be (25 + 1)/2 - (5 + -3). What is the highest common divisor of q and 99?": [ 319 | "q", 320 | "(25 + 1)/2 - (5 + -3)", 321 | "q", 322 | "99" 323 | ], 324 | "Let n(j) = 5*j**3 - j**2 + 2*j - 1. Let u be 7/9 - 6/(-27). Let v be n(u). Calculate the greatest common factor of 1 and v.": [ 325 | "n(j) = 5*j**3 - j**2 + 2*j - 1", 326 | "u", 327 | "7/9 - 6/(-27)", 328 | "v", 329 | "n(u)", 330 | "1", 331 | "v" 332 | ], 333 | "Let f be (-6)/5*(-360)/(-27). Suppose -5*k - 5*a = -335, 0*k = 4*k + 3*a - 271. Let p = k + f. Calculate the greatest common factor of p and 6.": [ 334 | "f", 335 | "(-6)/5*(-360)/(-27)", 336 | "-5*k - 5*a = -335", 337 | "0*k = 4*k + 3*a - 271", 338 | "p = k + f", 339 | "p", 340 | "6" 341 | ], 342 | "Let k(w) = -w**2 + 13*w - 4. What are the prime factors of k(6)?": [ 343 | "k(w) = -w**2 + 13*w - 4", 344 | "k(6)" 345 | ], 346 | "Let w(x) = x**2 + 10*x + 24. List the prime factors of w(-11).": [ 347 | "w(x) = x**2 + 10*x + 24", 348 | "w(-11)" 349 | ], 350 | "Let x(m) = m**3 + 6*m**2 - 7*m + 4. Let k be x(-7). Suppose g + 4*u = -12, 3*u = -0*g - k*g + 4. What are the prime factors of g?": [ 351 | "x(m) = m**3 + 6*m**2 - 7*m + 4", 352 | "k", 353 | "x(-7)", 354 | "g + 4*u = -12", 355 | "3*u = -0*g - k*g + 4", 356 | "g" 357 | ] 358 | } 359 | -------------------------------------------------------------------------------- /math_prog_synth_env/unit_testing/artifacts/extract_formal_elements_examples.py: -------------------------------------------------------------------------------- 1 | from math_prog_synth_env.typed_operators import * 2 | 3 | typed_examples = { 4 | "Solve 0 = 4*b + b + 15 for b.": [ 5 | Equation("0 = 4*b + b + 15"), 6 | Variable("b") 7 | ], 8 | "Suppose -3*z + 133 = 4*n - 10, 5*n = 25. Let l = -21 + z. Let r = l + -11. Calculate the least common multiple of 7 and r.": [ 9 | Equation("-3*z + 133 = 4*n - 10"), 10 | Equation("5*n = 25"), 11 | Equation("l = -21 + z"), 12 | Equation("r = l + -11"), 13 | Value("7"), 14 | Variable("r") 15 | ], 16 | "Calculate the common denominator of 1/(3/(-6)) - 402/(-60) and -71/12.": [ 17 | Rational(sympy.sympify("1/(3/(-6)) - 402/(-60)")), 18 | Rational("-71/12") 19 | ], 20 | "What is the common denominator of -64/1065 and 92/105?": [ 21 | Rational("-64/1065"), 22 | Rational("92/105") 23 | ], 24 | "What is the smallest common multiple of (-4)/12*(-20 - -2) and 4?": [ 25 | Value(sympy.sympify("(-4)/12*(-20 - -2)")), 26 | Value("4") 27 | ], 28 | # "Let q = -54.3 + 54. Suppose 0 = -5*z - 8 - 7. Which is the nearest to -1/5? (a) 5 (b) z (c) q": [ 29 | # Equation("q = -54.3 + 54"), 30 | # Equation("0 = -5*z - 8 - 7"), 31 | # Rational("-1/5"), 32 | # "(a) 5 (b) z (c) q" 33 | # ], 34 | "Let d(j) = -j**3 - 5*j**2 - 4*j + 1. Let n be d(-4). Suppose -5*h = 2*i - 2*h + n, 0 = i + 5*h - 10. What is the nearest to 0 in 1/3, i, -2?": [ 35 | Function("d(j) = -j**3 - 5*j**2 - 4*j + 1"), 36 | Variable("n"), 37 | Expression("d(-4)"), 38 | Equation("-5*h = 2*i - 2*h + n"), 39 | Equation("0 = i + 5*h - 10"), 40 | Value("0"), 41 | Rational("1/3"), 42 | Variable("i"), 43 | Value("-2") 44 | ], 45 | "Let f = -2.31 + 0.31. What is the nearest to f in 0.3, -2, 0.2?": [ 46 | Equation("f = -2.31 + 0.31"), 47 | Variable("f"), 48 | Value("0.3"), 49 | Value("-2"), 50 | Value("0.2") 51 | ], 52 | "Let o(v) = 77*v + 1. Let b(l) = 155*l + 2. Suppose 4*c - 25 = -c. Let a(u) = c*o(u) - 3*b(u). Is a(-4) composite?": [ 53 | Function("o(v) = 77*v + 1"), 54 | Function("b(l) = 155*l + 2"), 55 | Equation("4*c - 25 = -c"), 56 | Function("a(u) = c*o(u) - 3*b(u)"), 57 | Expression("a(-4)") 58 | ], 59 | "Let j = -5 - 28. Is j/6*(-1 - 13) a composite number?": [ 60 | Equation("j = -5 - 28"), 61 | Expression("j/6*(-1 - 13)") 62 | ], 63 | "Suppose 0 = -j - 4*a + 611, 4*j + a - 1468 = 1051. Is j prime?": [ 64 | Equation("0 = -j - 4*a + 611"), 65 | Equation("4*j + a - 1468 = 1051"), 66 | Variable("j") 67 | ], 68 | "Let l(b) = -142004*b - 62917*b - 377393*b. Let d be l(-1). Let v = d - 262314. Round v to the nearest 100000.": [ 69 | Function("l(b) = -142004*b - 62917*b - 377393*b"), 70 | Variable("d"), 71 | Expression("l(-1)"), 72 | Equation("v = d - 262314"), 73 | Variable("v"), 74 | Value("100000") 75 | ], 76 | "Suppose 5*t - 2 = -7. Let z be -612*1 + (-2 - -3). Let c be t/(-4) + z/(-4). What is c rounded to the nearest ten?": [ 77 | Equation("5*t - 2 = -7"), 78 | Variable("z"), 79 | Value(sympy.sympify("-612*1 + (-2 - -3)")), 80 | Variable("c"), 81 | Expression("t/(-4) + z/(-4)"), 82 | Variable("c") 83 | ], 84 | "Let m = 1.5 - 7.5. Let z = m - -22. Let v = z + -16.00017. Round v to four decimal places.": [ 85 | Equation("m = 1.5 - 7.5"), 86 | Equation("z = m - -22"), 87 | Equation("v = z + -16.00017"), 88 | Variable("v") 89 | ], 90 | "Let w(b) = -2*b - 3. Suppose 0*j + 16 = -3*j - o, j + 3*o = 8. Let u = j - -5. What is w(u)?": [ 91 | Function("w(b) = -2*b - 3"), 92 | Equation("0*j + 16 = -3*j - o"), 93 | Equation("j + 3*o = 8"), 94 | Equation("u = j - -5"), 95 | Expression("w(u)") 96 | ], 97 | "Let p(o) = 2*o**3 - 12*o**2 + 6*o - 5. Let i(m) = -m**3 + 6*m**2 - 3*m + 2. Let q be 82/12 - 2/(-12). Let f(s) = q*i(s) + 3*p(s). Determine f(5).": [ 98 | Function("p(o) = 2*o**3 - 12*o**2 + 6*o - 5"), 99 | Function("i(m) = -m**3 + 6*m**2 - 3*m + 2"), 100 | Variable("q"), 101 | Value(sympy.sympify("82/12 - 2/(-12)")), 102 | Function("f(s) = q*i(s) + 3*p(s)"), 103 | Expression("f(5)") 104 | ], 105 | "Let l(r) be the third derivative of 3*r**6/40 - r**5/60 - 6*r**2. What is l(-1)?": [ 106 | Expression("l(r)"), 107 | Expression("3*r**6/40 - r**5/60 - 6*r**2"), 108 | Expression("l(-1)") 109 | ], 110 | "Let o = -788/3 - -260. Which is bigger: -0.1 or o?": [ 111 | Equation("o = -788/3 - -260"), 112 | Value("-0.1"), 113 | Variable("o") 114 | ], 115 | "Let r = 4 + -2. Which is greater: r or 0.09?": [ 116 | Equation("r = 4 + -2"), 117 | Variable("r"), 118 | Value("0.09") 119 | ], 120 | # "Let q = 17 - 18. Let v be (2 + q)*12/(-16). Is v > -1?": [ 121 | # Equation("q = 17 - 18"), 122 | # Variable("v"), 123 | # Expression("(2 + q)*12/(-16)"), 124 | # "v > -1" 125 | # ], 126 | "Suppose 3*x + 197 = 4*x. Calculate the remainder when x is divided by 33.": [ 127 | Equation("3*x + 197 = 4*x"), 128 | Variable("x"), 129 | Value("33") 130 | ], 131 | "Suppose -106 = -2*u + s, u - 40 = -5*s + 13. Calculate the remainder when u is divided by 14.": [ 132 | Equation("-106 = -2*u + s"), 133 | Equation("u - 40 = -5*s + 13"), 134 | Variable("u"), 135 | Value("14") 136 | ], 137 | "Let x = -41 - -20. Let t = x + 27. Calculate the remainder when t is divided by 4.": [ 138 | Equation("x = -41 - -20"), 139 | Equation("t = x + 27"), 140 | Variable("t"), 141 | Value("4") 142 | ], 143 | "Let d = -25019/90 - -278. Let v(j) be the third derivative of 0 + 1/27*j**3 - d*j**5 + 1/54*j**4 + 3*j**2 + 0*j. Suppose v(o) = 0. What is o?": [ 144 | Equation("d = -25019/90 - -278"), 145 | Expression("v(j)"), 146 | Expression("0 + 1/27*j**3 - d*j**5 + 1/54*j**4 + 3*j**2 + 0*j"), 147 | Function("v(o) = 0"), 148 | Variable("o") 149 | ], 150 | "Let g be 2 - (0 - (-1 - -1)). Determine q so that -q**4 - 6*q**2 + 0*q**4 - 3 + g - 4*q - 4*q**3 = 0.": [ 151 | Variable("g"), 152 | Value(sympy.sympify("2 - (0 - (-1 - -1))")), 153 | Variable("q"), 154 | Equation("-q**4 - 6*q**2 + 0*q**4 - 3 + g - 4*q - 4*q**3 = 0") 155 | ], 156 | "Let d(k) be the first derivative of -1 - 4/3*k**3 + 0*k + 1/2*k**2. Find z such that d(z) = 0.": [ 157 | Expression("d(k)"), 158 | Expression("-1 - 4/3*k**3 + 0*k + 1/2*k**2"), 159 | Variable("z"), 160 | Function("d(z) = 0") 161 | ], 162 | "Suppose -55 = -8*l + 3*l. Let k = l + -7. What is the units digit of k?": [ 163 | Equation("-55 = -8*l + 3*l"), 164 | Equation("k = l + -7"), 165 | Variable("k") 166 | ], 167 | "Let t(p) = p**3 - 3*p**2 - 4*p + 2. Let a be t(4). Suppose 2*f = a + 2. Let l = f - -12. What is the units digit of l?": [ 168 | Function("t(p) = p**3 - 3*p**2 - 4*p + 2"), 169 | Variable("a"), 170 | Expression("t(4)"), 171 | Equation("2*f = a + 2"), 172 | Equation("l = f - -12"), 173 | Variable("l") 174 | ], 175 | "Suppose 5*j - 1126 + 331 = 0. What is the tens digit of j?": [ 176 | Equation("5*j - 1126 + 331 = 0"), 177 | Variable("j") 178 | ], 179 | "Suppose 0 = -4*x + 8*x - 40. Let h(i) = i**2 - 9*i - 14. Let n be h(x). Sort -1, 4, n.": [ 180 | Equation("0 = -4*x + 8*x - 40"), 181 | Function("h(i) = i**2 - 9*i - 14"), 182 | Variable("n"), 183 | Expression("h(x)"), 184 | Value("-1"), 185 | Value("4"), 186 | Variable("n") 187 | ], 188 | "Let g = 1 + 2. Let a = 0.95 - -0.05. Put a, g, -1 in descending order.": [ 189 | Equation("g = 1 + 2"), 190 | Equation("a = 0.95 - -0.05"), 191 | Variable("a"), 192 | Variable("g"), 193 | Value("-1") 194 | ], 195 | "Let m be (-7)/56 - (-1)/(-8). Sort m, 0, -4 in descending order.": [ 196 | Variable("m"), 197 | Rational(sympy.sympify("(-7)/56 - (-1)/(-8)")), 198 | Variable("m"), 199 | Value("0"), 200 | Value("-4") 201 | ], 202 | "Let w be (-1 + 13)*3/(-6). Let b = w - -6. Let i = 2 - b. Solve -15 = 3*c + i*c for c.": [ 203 | Variable("w"), 204 | Value(sympy.sympify("(-1 + 13)*3/(-6)")), 205 | Equation("b = w - -6"), 206 | Equation("i = 2 - b"), 207 | Equation("-15 = 3*c + i*c"), 208 | Variable("c") 209 | ], 210 | "Suppose -c + 4*v + 2 = -24, -4*c - 3*v + 9 = 0. Solve 2*b - c = -b for b.": [ 211 | Equation("-c + 4*v + 2 = -24"), 212 | Equation("-4*c - 3*v + 9 = 0"), 213 | Equation("2*b - c = -b"), 214 | Variable("b") 215 | ], 216 | "Let v(k) = k**3 + k**2 - k - 3. Let d be v(0). Let a be ((-15)/2)/d*4. Let x = a + -8. Solve -3 + 11 = x*p for p.": [ 217 | Function("v(k) = k**3 + k**2 - k - 3"), 218 | Variable("d"), 219 | Expression("v(0)"), 220 | Variable("a"), 221 | Expression("((-15)/2)/d*4"), 222 | Equation("x = a + -8"), 223 | Equation("-3 + 11 = x*p"), 224 | Variable("p") 225 | ], 226 | "Let h(t) = t**3 + t**2 + 1. Let v(d) = 6*d**3 + 24*d**2 + 4. Let w(j) = 4*h(j) - v(j). What is the third derivative of w(x) wrt x?": [ 227 | Function("h(t) = t**3 + t**2 + 1"), 228 | Function("v(d) = 6*d**3 + 24*d**2 + 4"), 229 | Function("w(j) = 4*h(j) - v(j)"), 230 | Expression("w(x)"), 231 | Variable("x") 232 | ], 233 | "Let v = -7 - -12. Suppose 0 = 2*h - 3*x - 16 - 5, 0 = -v*h + 3*x + 30. What is the first derivative of 5*t - h - t + 0 - 2*t wrt t?": [ 234 | Equation("v = -7 - -12"), 235 | Equation("0 = 2*h - 3*x - 16 - 5"), 236 | Equation("0 = -v*h + 3*x + 30"), 237 | Expression("5*t - h - t + 0 - 2*t"), 238 | Variable("t") 239 | ], 240 | "Let b(y) be the second derivative of -3*y**8/56 - y**4/6 - y. What is the third derivative of b(o) wrt o?": [ 241 | Expression("b(y)"), 242 | Expression("-3*y**8/56 - y**4/6 - y"), 243 | Expression("b(o)"), 244 | Variable("o") 245 | ], 246 | "Let p = -3 - -6. Let w(d) = 0*d**2 + p*d**2 - 2*d**2 - 3*d**2. Let t(b) = -3*b. Give t(w(k)).": [ 247 | Equation("p = -3 - -6"), 248 | Function("w(d) = 0*d**2 + p*d**2 - 2*d**2 - 3*d**2"), 249 | Function("t(b) = -3*b"), 250 | Expression("t(w(k))") 251 | ], 252 | "Let m(s) = 7*s - 12. Let z(g) = -5*g**2. What is z(m(k))?": [ 253 | Function("m(s) = 7*s - 12"), 254 | Function("z(g) = -5*g**2"), 255 | Expression("z(m(k))") 256 | ], 257 | "Let w(q) = 2*q**2. Let v(x) be the first derivative of 0*x + 0*x**2 + 4/3*x**3 - 2. Determine v(w(p)).": [ 258 | Function("w(q) = 2*q**2"), 259 | Expression("v(x)"), 260 | Expression("0*x + 0*x**2 + 4/3*x**3 - 2"), 261 | Expression("v(w(p))") 262 | ], 263 | # "Let f be 4/22 - 20/(-11). Suppose s = -0*s + 4*n + 12, 0 = -n - f. Which is the second smallest value? (a) -0.2 (b) s (c) 2/7": [ 264 | # Value("f"), 265 | # Expression("4/22 - 20/(-11)"), 266 | # Equation("s = -0*s + 4*n + 12"), 267 | # Equation("0 = -n - f"), 268 | # "(a) -0.2 (b) s (c) 2/7" 269 | # ], 270 | # "Let s = 1.5 + -1.5. Suppose 0 = p + p + 8. Which is the third biggest value? (a) s (b) -5 (c) p": [ 271 | # Equation("s = 1.5 + -1.5"), 272 | # Equation("0 = p + p + 8"), 273 | # "(a) s (b) -5 (c) p" 274 | # ], 275 | # "Let r = 1 - -4. Let u be (-3 - -1)*3/(-2). Suppose -r*s - u = -4*s. Which is the third biggest value? (a) -0.3 (b) 2/11 (c) s": [ 276 | # Equation("r = 1 - -4"), 277 | # Value("u"), 278 | # Expression("(-3 - -1)*3/(-2)"), 279 | # Equation("-r*s - u = -4*s"), 280 | # "(a) -0.3 (b) 2/11 (c) s" 281 | # ], 282 | "Suppose 3*n = -0*x - 3*x + 93, -2*n - 2 = 0. Does 12 divide x?": [ 283 | Equation("3*n = -0*x - 3*x + 93"), 284 | Equation("-2*n - 2 = 0"), 285 | Value("12"), 286 | Variable("x") 287 | ], 288 | "Is 1330/(-28)*4/(-2) a multiple of 19?": [ 289 | Value(sympy.sympify("1330/(-28)*4/(-2)")), 290 | Value("19") 291 | ], 292 | "Is 3 - (1344/(-10) + 2/5) a multiple of 36?": [ 293 | Value(sympy.sympify("3 - (1344/(-10) + 2/5)")), 294 | Value("36") 295 | ], 296 | "Suppose 2*y + 12 = 6*y. Suppose y = f - 15. Solve -8 = -4*w, -3*d - 4*w + f = -8*d for d.": [ 297 | Equation("2*y + 12 = 6*y"), 298 | Equation("y = f - 15"), 299 | Equation("-8 = -4*w"), 300 | Equation("-3*d - 4*w + f = -8*d"), 301 | Variable("d") 302 | ], 303 | "Let l(v) = -v**3 + 12*v**2 + 13*v + 2. Let r(q) = -2*q + 5. Let c be r(-4). Let y be l(c). Solve -w + 2 = -3*s - 8, s + 1 = -y*w for s.": [ 304 | Function("l(v) = -v**3 + 12*v**2 + 13*v + 2"), 305 | Function("r(q) = -2*q + 5"), 306 | Variable("c"), 307 | Expression("r(-4)"), 308 | Variable("y"), 309 | Expression("l(c)"), 310 | Equation("-w + 2 = -3*s - 8"), 311 | Equation("s + 1 = -y*w"), 312 | Variable("s") 313 | ], 314 | "Suppose 0 = f + 1, 17*f = 5*w + 12*f - 15. Suppose -3*s + 2 = -13. Suppose s*c + 3*j = 36, 3*c + 0*j - 18 = -3*j. Solve 20 = a - 5*z, -2*a + c = -w*z - 7 for a.": [ 315 | Equation("0 = f + 1"), 316 | Equation("17*f = 5*w + 12*f - 15"), 317 | Equation("-3*s + 2 = -13"), 318 | Equation("s*c + 3*j = 36"), 319 | Equation("3*c + 0*j - 18 = -3*j"), 320 | Equation("20 = a - 5*z"), 321 | Equation("-2*a + c = -w*z - 7"), 322 | Variable("a") 323 | ], 324 | "Let q be (25 + 1)/2 - (5 + -3). What is the highest common divisor of q and 99?": [ 325 | Variable("q"), 326 | Value(sympy.sympify("(25 + 1)/2 - (5 + -3)")), 327 | Variable("q"), 328 | Value("99") 329 | ], 330 | "Let n(j) = 5*j**3 - j**2 + 2*j - 1. Let u be 7/9 - 6/(-27). Let v be n(u). Calculate the greatest common factor of 1 and v.": [ 331 | Function("n(j) = 5*j**3 - j**2 + 2*j - 1"), 332 | Variable("u"), 333 | Value(sympy.sympify("7/9 - 6/(-27)")), 334 | Variable("v"), 335 | Expression("n(u)"), 336 | Value("1"), 337 | Variable("v") 338 | ], 339 | "Let f be (-6)/5*(-360)/(-27). Suppose -5*k - 5*a = -335, 0*k = 4*k + 3*a - 271. Let p = k + f. Calculate the greatest common factor of p and 6.": [ 340 | Variable("f"), 341 | Value(sympy.sympify("(-6)/5*(-360)/(-27)")), 342 | Equation("-5*k - 5*a = -335"), 343 | Equation("0*k = 4*k + 3*a - 271"), 344 | Equation("p = k + f"), 345 | Variable("p"), 346 | Value("6") 347 | ], 348 | "Let k(w) = -w**2 + 13*w - 4. What are the prime factors of k(6)?": [ 349 | Function("k(w) = -w**2 + 13*w - 4"), 350 | Expression("k(6)") 351 | ], 352 | "Let w(x) = x**2 + 10*x + 24. List the prime factors of w(-11).": [ 353 | Function("w(x) = x**2 + 10*x + 24"), 354 | Expression("w(-11)") 355 | ], 356 | "Let x(m) = m**3 + 6*m**2 - 7*m + 4. Let k be x(-7). Suppose g + 4*u = -12, 3*u = -0*g - k*g + 4. What are the prime factors of g?": [ 357 | Function("x(m) = m**3 + 6*m**2 - 7*m + 4"), 358 | Variable("k"), 359 | Expression("x(-7)"), 360 | Equation("g + 4*u = -12"), 361 | Equation("3*u = -0*g - k*g + 4"), 362 | Variable("g"), 363 | ] 364 | } 365 | -------------------------------------------------------------------------------- /math_prog_synth_env/unit_testing/artifacts/params.yaml: -------------------------------------------------------------------------------- 1 | { 2 | "num_problems_per_module": 10, 3 | "num_problems_per_module_corpus": 30000, 4 | "validation_percentage": 0.2, 5 | "encode_question": true, 6 | "max_sequence_length": 125, 7 | "question_vocab_size": 250, 8 | "max_difficulty": 0, 9 | "univariate_differentiation": true, 10 | "num_environments": 50, 11 | "corpus_path": "corpus.txt", 12 | "tokenizer_filepath": "tokenizer", 13 | "max_formal_elements": 13, 14 | "max_num_nodes": 7, 15 | "data_download_location": "mathematics_dataset-v1.0.tar.gz", 16 | "data_unpack_dir": ".", 17 | "all_data_dirpath": "mathematics_dataset-v1.0/train-easy", 18 | "data_dirpath": "mathematics_dataset-v1.0/train", 19 | "test_data_dirpath": "mathematics_dataset-v1.0/test", 20 | "test_percentage": 0.1, # Percentage of data to be used as test set 21 | "selected_filenames": [ 22 | 'numbers__is_factor.txt', 23 | 'numbers__is_prime.txt', 24 | 'numbers__list_prime_factors.txt', 25 | 'calculus__differentiate.txt', 26 | 'polynomials__evaluate.txt', 27 | 'numbers__div_remainder.txt', 28 | 'numbers__gcd.txt', 29 | 'numbers__lcm.txt', 30 | 'algebra__linear_1d.txt', 31 | 'algebra__polynomial_roots.txt', 32 | 'algebra__linear_2d.txt', 33 | 'algebra__linear_1d_composed.txt', 34 | 'algebra__linear_2d_composed.txt', 35 | 'algebra__polynomial_roots_composed.txt', 36 | 'calculus__differentiate_composed.txt', 37 | 'numbers__div_remainder_composed.txt', 38 | 'numbers__gcd_composed.txt', 39 | 'numbers__is_factor_composed.txt', 40 | 'numbers__is_prime_composed.txt', 41 | 'numbers__lcm_composed.txt', 42 | 'numbers__list_prime_factors_composed.txt', 43 | 'polynomials__evaluate_composed.txt', 44 | 'polynomials__compose.txt' 45 | ], 46 | "types":[ 47 | "object", 48 | "Equation", 49 | "Function", 50 | "Expression", 51 | "Variable", 52 | "Value", 53 | "Rational" 54 | ], 55 | "operators":[ 56 | "lookup_value", 57 | "solve_system", 58 | "append", 59 | "append_to_empty_list", 60 | "factor", 61 | "differentiate", 62 | "mod", 63 | "gcd", 64 | "divides", 65 | "is_prime", 66 | "lcm", 67 | "lcd", 68 | "prime_factors", 69 | "evaluate_function", 70 | "not_op", 71 | "differentiate_wrt" 72 | # "make_equation", 73 | # "simplify", 74 | # "make_function", 75 | # "replace_arg", 76 | # "lookup_value_equation", 77 | # "extract_isolated_variable", 78 | # "substitution_left_to_right", 79 | ] 80 | } 81 | -------------------------------------------------------------------------------- /math_prog_synth_env/unit_testing/artifacts/problems/algebra__linear_1d.txt: -------------------------------------------------------------------------------- 1 | Solve 0 = 4*b + b + 15 for b. 2 | -3 3 | Solve -3*d = -0*d + 3 for d. 4 | -1 5 | Solve -4*h + 9 = 41 for h. 6 | -8 7 | Solve 2514*m = 2508*m - 24 for m. 8 | -4 9 | Solve -7*a + 6*a = 4 for a. 10 | -4 11 | Solve 288*w - 298*w = -70 for w. 12 | 7 13 | Solve -14*h = -4*h - 10 for h. 14 | 1 15 | Solve 5*w + 3 = -2 for w. 16 | -1 17 | Solve -15*f + 21*f - 12 = 0 for f. 18 | 2 19 | Solve -22 = 6*c - 4 for c. 20 | -3 -------------------------------------------------------------------------------- /math_prog_synth_env/unit_testing/artifacts/problems/algebra__linear_2d.txt: -------------------------------------------------------------------------------- 1 | Solve 0 = 4*f - 0*t - 4*t - 4, -4*f + t = -13 for f. 2 | 4 3 | Solve 5*a = -2*h + 15, 0 = -2*h + 6*a - 2*a - 12 for h. 4 | 0 5 | Solve 2*k + 14 = 2*j, -7*k - 4*j + 24 = -9*k for k. 6 | -2 7 | Solve -2*o - 3*w = 13, 0*o - 4*o - 5*w = 21 for o. 8 | 1 9 | Solve -16 = 4*d - 4, 3*o = 4*d + 3 for o. 10 | -3 11 | Solve 3*z + w + 7 = 15, w - 3 = 2*z for z. 12 | 1 13 | Solve 4*o + 7 = -0*o - 5*k, -11 = 2*o - 5*k for o. 14 | -3 15 | Solve -4*f + 3*t + 2 = 0, 0*t = 4*f + 2*t - 12 for f. 16 | 2 17 | Solve 0 = -3*q - 4*n + 15, 3*q + 2*n = -3*n + 18 for q. 18 | 1 19 | Solve 4*f - 35 + 0 = -3*y, 4*y - 30 = -2*f for f. 20 | 5 -------------------------------------------------------------------------------- /math_prog_synth_env/unit_testing/artifacts/problems/algebra__polynomial_roots.txt: -------------------------------------------------------------------------------- 1 | Solve -3*h**2/2 - 24*h - 45/2 = 0 for h. 2 | -15, -1 3 | Factor -n**2/3 - 25*n - 536/3. 4 | -(n + 8)*(n + 67)/3 5 | Let c**3/9 - 11*c**2/3 + 35*c - 75 = 0. What is c? 6 | 3, 15 7 | What is f in -87616*f**2 - 1776*f - 9 = 0? 8 | -3/296 9 | Find s such that 9*s**4 - 8958*s**3 - 14952*s**2 - 2994*s + 2991 = 0. 10 | -1, 1/3, 997 11 | Factor -4*a**2/9 - 184*a/9 + 800/9. 12 | -4*(a - 4)*(a + 50)/9 13 | Factor -a**3/4 + 9*a**2/4 + 210*a - 3100. 14 | -(a - 20)**2*(a + 31)/4 15 | Factor 54*a**3 + 483*a**2 + 405*a - 24. 16 | 3*(a + 1)*(a + 8)*(18*a - 1) 17 | Factor 5*q**3 - 295*q**2 - 605*q - 305. 18 | 5*(q - 61)*(q + 1)**2 19 | What is o in -5*o**5 - 65*o**4 - 170*o**3 + 310*o**2 + 175*o - 245 = 0? 20 | -7, -1, 1 -------------------------------------------------------------------------------- /math_prog_synth_env/unit_testing/artifacts/problems/numbers__div_remainder.txt: -------------------------------------------------------------------------------- 1 | Calculate the remainder when 93 is divided by 59. 2 | 34 3 | What is the remainder when 779 is divided by 223? 4 | 110 5 | What is the remainder when 1862 is divided by 16? 6 | 6 7 | What is the remainder when 813 is divided by 24? 8 | 21 9 | What is the remainder when 164 is divided by 85? 10 | 79 11 | What is the remainder when 412 is divided by 13? 12 | 9 13 | Calculate the remainder when 52 is divided by 9. 14 | 7 15 | Calculate the remainder when 352 is divided by 138. 16 | 76 17 | What is the remainder when 187 is divided by 107? 18 | 80 19 | What is the remainder when 15193 is divided by 15? 20 | 13 -------------------------------------------------------------------------------- /math_prog_synth_env/unit_testing/artifacts/problems/numbers__gcd.txt: -------------------------------------------------------------------------------- 1 | Calculate the highest common divisor of 1300 and 300. 2 | 100 3 | Calculate the greatest common factor of 11130 and 6. 4 | 6 5 | What is the greatest common divisor of 352 and 3454? 6 | 22 7 | Calculate the greatest common divisor of 17 and 272. 8 | 17 9 | Calculate the highest common divisor of 11711 and 49. 10 | 49 11 | What is the greatest common divisor of 275 and 495? 12 | 55 13 | Calculate the highest common factor of 4 and 534. 14 | 2 15 | What is the highest common factor of 8 and 2792? 16 | 8 17 | Calculate the highest common factor of 84 and 15932. 18 | 28 19 | Calculate the highest common divisor of 54 and 27. 20 | 27 -------------------------------------------------------------------------------- /math_prog_synth_env/unit_testing/artifacts/problems/numbers__is_factor.txt: -------------------------------------------------------------------------------- 1 | Is 15 a factor of 720? 2 | True 3 | Does 11 divide 1973? 4 | False 5 | Is 347 a multiple of 3? 6 | False 7 | Does 28 divide 1204? 8 | True 9 | Is 17 a factor of 3594? 10 | False 11 | Is 23 a factor of 3059? 12 | True 13 | Does 16 divide 6606? 14 | False 15 | Is 54094 a multiple of 86? 16 | True 17 | Is 6 a factor of 5064? 18 | True 19 | Is 28 a factor of 840? 20 | True -------------------------------------------------------------------------------- /math_prog_synth_env/unit_testing/artifacts/problems/numbers__is_prime.txt: -------------------------------------------------------------------------------- 1 | Is 93163 a prime number? 2 | False 3 | Is 29179 prime? 4 | True 5 | Is 323431 a prime number? 6 | False 7 | Is 5939 prime? 8 | True 9 | Is 1454 prime? 10 | False 11 | Is 350767 prime? 12 | True 13 | Is 66574 a composite number? 14 | True 15 | Is 3037 composite? 16 | False 17 | Is 28151 composite? 18 | False 19 | Is 2053 a prime number? 20 | True -------------------------------------------------------------------------------- /math_prog_synth_env/unit_testing/artifacts/problems/numbers__lcm.txt: -------------------------------------------------------------------------------- 1 | Calculate the smallest common multiple of 351 and 141. 2 | 16497 3 | Calculate the common denominator of -29/936 and -115/48. 4 | 1872 5 | Calculate the common denominator of 79/195 and 113/60. 6 | 780 7 | What is the lowest common multiple of 12 and 20? 8 | 60 9 | Calculate the common denominator of -3/40 and 57/1652. 10 | 16520 11 | What is the common denominator of 81/140 and 57/140? 12 | 140 13 | What is the lowest common multiple of 4 and 4? 14 | 4 15 | Find the common denominator of -7/33 and -55/54. 16 | 594 17 | What is the common denominator of -11/1458 and -26/9? 18 | 1458 19 | What is the smallest common multiple of 25 and 20? 20 | 100 -------------------------------------------------------------------------------- /math_prog_synth_env/unit_testing/artifacts/problems/numbers__list_prime_factors.txt: -------------------------------------------------------------------------------- 1 | What are the prime factors of 329? 2 | 7, 47 3 | What are the prime factors of 2250? 4 | 2, 3, 5 5 | What are the prime factors of 7380? 6 | 2, 3, 5, 41 7 | What are the prime factors of 6792? 8 | 2, 3, 283 9 | List the prime factors of 32253. 10 | 3, 13, 827 11 | List the prime factors of 1312. 12 | 2, 41 13 | What are the prime factors of 773? 14 | 773 15 | What are the prime factors of 12963? 16 | 3, 29, 149 17 | List the prime factors of 31114. 18 | 2, 47, 331 19 | What are the prime factors of 1316? 20 | 2, 7, 47 -------------------------------------------------------------------------------- /math_prog_synth_env/unit_testing/artifacts/problems/polynomials__evaluate.txt: -------------------------------------------------------------------------------- 1 | Let i(h) = -7*h - 15. Determine i(-2). 2 | -1 3 | Let k(u) = u**2 + u - 4. What is k(0)? 4 | -4 5 | Let x(f) = -f - 19. Calculate x(9). 6 | -28 7 | Let t(x) = -x**2 + 3*x - 3. Calculate t(3). 8 | -3 9 | Let s(c) = -7*c**2 - 2. Determine s(-2). 10 | -30 11 | Let g(a) = a**2 - 4*a + 10. Determine g(12). 12 | 106 13 | Let i(j) = -3*j - 45. Determine i(11). 14 | -78 15 | Let v(i) = 3*i + 19. Determine v(-9). 16 | -8 17 | Let j(m) = -m**3 - 11*m**2 - 14*m + 36. Give j(-9). 18 | 0 19 | Let k(t) = -t**3 + 9*t**2 - t + 3. Determine k(9). 20 | -6 -------------------------------------------------------------------------------- /math_prog_synth_env/unit_testing/extract_problems_for_guessing_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | from math_prog_synth_env.utils import load_training_data 3 | 4 | n = 20 5 | filename_to_top_lines = dict() 6 | for filename in os.listdir('mathematics_dataset-v1.0/train-easy'): 7 | filepath = os.path.join('mathematics_dataset-v1.0/train-easy', filename) 8 | with open(filepath) as f: 9 | lines = f.read().split('\n') 10 | top_n_lines = lines[:20] 11 | filename_to_top_lines[filename] = top_n_lines 12 | 13 | for filename, top_n_lines in filename_to_top_lines.items(): 14 | with open(f'environment/unit_testing/artifacts/problems/{filename}', 'w') as f: 15 | string = "\n".join(top_n_lines) 16 | f.write(string) 17 | -------------------------------------------------------------------------------- /math_prog_synth_env/unit_testing/test_compute_graph.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from math_prog_synth_env.utils import extract_formal_elements 3 | from math_prog_synth_env.typed_operators import * 4 | from math_prog_synth_env.compute_graph import ComputeGraph, Node 5 | 6 | 7 | class Test(unittest.TestCase): 8 | def test_easy_algebra__linear_1d(self): 9 | question = "Solve 0 = 4*b + b + 15 for b." 10 | f = extract_formal_elements(question) 11 | cg = ComputeGraph(question) 12 | lookup_value_node = Node(lookup_value) 13 | solve_system_node = Node(solve_system) 14 | append_to_empty_list_node = Node(append_to_empty_list) 15 | append_to_empty_list_node.set_arg(Node('f0')) 16 | solve_system_node.set_arg(append_to_empty_list_node) 17 | lookup_value_node.set_args([solve_system_node, Node('f1')]) 18 | cg.root = lookup_value_node 19 | assert str(cg) == "lookup_value(solve_system(append_to_empty_list(Equation('0 = 4*b + b + 15'))),Variable('b'))" 20 | assert cg.eval() == Value(-3) 21 | 22 | def test_incomplete_compute_graph(self): 23 | question = "Solve 0 = 4*b + b + 15 for b." 24 | cg = ComputeGraph(question) 25 | lookup_value_node = Node(lookup_value) 26 | solve_system_node = Node(solve_system) 27 | lookup_value_node.set_arg(solve_system_node) 28 | cg.root = lookup_value_node 29 | assert str(cg) == "lookup_value(solve_system('p_0'),'p_1')" 30 | assert cg.eval() == None 31 | -------------------------------------------------------------------------------- /math_prog_synth_env/unit_testing/test_extract_formal_elements.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from math_prog_synth_env.utils import extract_formal_elements 3 | 4 | 5 | def test_string_equality(question, formal_elements): 6 | '''this is a weaker test because types are ignored''' 7 | extracted_formal_elements = extract_formal_elements(question) 8 | for efe, fe in zip(extracted_formal_elements, formal_elements): 9 | print(efe, fe) 10 | assert str(efe) == fe 11 | 12 | 13 | def test_type_equality(question, formal_elements): 14 | '''this is a stronger test because it requires that the formal elements get casted correctly as well''' 15 | extracted_formal_elements = extract_formal_elements(question) 16 | for efe, fe in zip(extracted_formal_elements, formal_elements): 17 | assert efe == fe, (efe, fe) 18 | 19 | 20 | class Test(unittest.TestCase): 21 | 22 | def test_examples(self): 23 | 24 | # # do weak test 25 | # question_to_formal_elements = read_json( 26 | # "environment/unit_testing/artifacts/extract_formal_elements_examples.json" 27 | # ) 28 | # for question, formal_elements in question_to_formal_elements.items(): 29 | # test_string_equality(question, formal_elements) 30 | 31 | # do strong test 32 | from math_prog_synth_env.unit_testing.artifacts.extract_formal_elements_examples import typed_examples 33 | for question, formal_elements in typed_examples.items(): 34 | test_type_equality(question, formal_elements) 35 | 36 | 37 | 38 | 39 | -------------------------------------------------------------------------------- /math_prog_synth_env/unit_testing/test_graphs.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from math_prog_synth_env.utils import extract_formal_elements, cast_formal_element 3 | from math_prog_synth_env.typed_operators import * 4 | 5 | 6 | class Test(unittest.TestCase): 7 | """all test cases are taken from train-easy""" 8 | 9 | def test_easy_algebra__linear_1d(self): 10 | question = "Solve 0 = 4*b + b + 15 for b." 11 | fs = extract_formal_elements(question) 12 | assert fs == [Equation("0 = 4*b + b + 15"), Variable("b")] 13 | system = append_to_empty_list(fs[0]) 14 | solution = solve_system(system) 15 | value = lookup_value(solution, fs[1]) 16 | assert value == Value(-3) 17 | 18 | def test_easy_algebra__linear_1d_composed(self): 19 | question = "Let w be (-1 + 13)*3/(-6). Let b = w - -6. Let i = 2 - b. Solve -15 = 3*c + i*c for c." 20 | f = extract_formal_elements(question) 21 | assert f == [ 22 | Variable("w"), 23 | Value(sympy.sympify("(-1 + 13)*3/(-6)")), 24 | Equation("b = w - -6"), 25 | Equation("i = 2 - b"), 26 | Equation("-15 = 3*c + i*c"), 27 | Variable("c"), 28 | ] 29 | eq1 = make_equation(f[0], f[1]) 30 | system = append(append(append_to_empty_list(eq1), f[2]), f[3]) 31 | soln = solve_system(system) 32 | i_eq = lookup_value_equation(soln, extract_isolated_variable(f[3])) 33 | lin_eq = substitution_left_to_right(f[4], i_eq) 34 | assert lookup_value(solve_system(append_to_empty_list(lin_eq)), f[5]) == Value(-3) 35 | 36 | def test_easy_algebra__linear_2d(self): 37 | question = "Solve 0 = 4*f - 0*t - 4*t - 4, -4*f + t = -13 for f." 38 | f = extract_formal_elements(question) 39 | assert f == [ 40 | Equation("0 = 4*f - 0*t - 4*t - 4"), 41 | Equation("-4*f + t = -13"), 42 | Variable("f"), 43 | ] 44 | 45 | assert lookup_value( 46 | solve_system(append(append_to_empty_list(f[0]), f[1])), f[2] 47 | ) == Value(4) 48 | 49 | def test_algebra__linear_2d_composed(self): 50 | question = "Suppose 2*y + 12 = 6*y. Suppose y = f - 15. Solve -8 = -4*w, -3*d - 4*w + f = -8*d for d." 51 | f = extract_formal_elements(question) 52 | assert f == [ 53 | Equation("2*y + 12 = 6*y"), 54 | Equation("y = f - 15"), 55 | Equation("-8 = -4*w"), 56 | Equation("-3*d - 4*w + f = -8*d"), 57 | Variable("d"), 58 | ] 59 | system = append(append(append(append_to_empty_list(f[0]), f[1]), f[2]), f[3]) 60 | assert lookup_value(solve_system(system), f[4]) == Value(-2) 61 | 62 | def test_algebra__polynomial_roots_1(self): 63 | question = "Solve -3*h**2/2 - 24*h - 45/2 = 0 for h." 64 | f = extract_formal_elements(question) 65 | assert f == [Equation("-3*h**2/2 - 24*h - 45/2 = 0"), Variable("h")] 66 | soln = lookup_value(solve_system(append_to_empty_list(f[0])), f[1]) 67 | assert soln == {Rational(-1), Rational(-15)} 68 | 69 | def test_algebra__polynomial_roots_2(self): 70 | question = "Factor -n**2/3 - 25*n - 536/3." 71 | f = extract_formal_elements(question) 72 | assert f == [Expression("-n**2/3 - 25*n - 536/3")] 73 | assert factor(f[0]) == Expression("-(n + 8)*(n + 67)/3") 74 | 75 | def test_algebra__polynomial_roots_3(self): 76 | question = ( 77 | "Find s such that 9*s**4 - 8958*s**3 - 14952*s**2 - 2994*s + 2991 = 0." 78 | ) 79 | f = extract_formal_elements(question) 80 | assert f == [ 81 | Variable("s"), 82 | Equation("9*s**4 - 8958*s**3 - 14952*s**2 - 2994*s + 2991 = 0"), 83 | ] 84 | assert lookup_value(solve_system(append_to_empty_list(f[1])), f[0]) == { 85 | Rational(-1), 86 | Rational('1/3'), 87 | Rational(997), 88 | } 89 | 90 | def test_algebra__polynomial_roots_composed_1(self): 91 | question = "Let d = -25019/90 - -278. Let v(j) be the third derivative of 0 + 1/27*j**3 - d*j**5 + 1/54*j**4 + 3*j**2 + 0*j. Suppose v(o) = 0. What is o?" 92 | f = extract_formal_elements(question) 93 | assert f == [ 94 | Equation("d = -25019/90 - -278"), 95 | Expression("v(j)"), 96 | Expression("0 + 1/27*j**3 - d*j**5 + 1/54*j**4 + 3*j**2 + 0*j"), 97 | Function("v(o) = 0"), 98 | Variable("o"), 99 | ] 100 | d = simplify(f[0]) 101 | function = substitution_left_to_right(f[2], d) 102 | v = differentiate(differentiate(differentiate(function))) 103 | v_eq = make_function(f[1], v) 104 | v_eq_o = replace_arg(v_eq, f[4]) 105 | equation = substitution_left_to_right( 106 | f[3], v_eq_o 107 | ) # e.g. x.subs(sym.sympify('f(x)'), sym.sympify('v')) 108 | assert lookup_value(solve_system(append_to_empty_list(equation)), f[4]) == { 109 | Rational('-1/3'), 110 | Rational(1), 111 | } 112 | 113 | def test_calculus__differentiate(self): 114 | question = "What is the second derivative of 2*c*n**2*z**3 + 30*c*n**2 + 2*c*n*z**2 - 2*c + n**2*z**2 - 3*n*z**3 - 2*n*z wrt n?" 115 | f = extract_formal_elements(question) 116 | assert f == [Expression('2*c*n**2*z**3 + 30*c*n**2 + 2*c*n*z**2 - 2*c + n**2*z**2 - 3*n*z**3 - 2*n*z'), Variable('n')] 117 | assert differentiate_wrt(differentiate_wrt(f[0], f[1]), f[1]) == Expression('4*c*z**3 + 60*c + 2*z**2') 118 | 119 | def test_numbers__div_remainder(self): 120 | question = "Calculate the remainder when 93 is divided by 59." 121 | f = extract_formal_elements(question) 122 | assert f == [Value("93"), Value("59")] 123 | assert mod(f[0], f[1]) == Value("34") 124 | 125 | def test_numbers__gcd(self): 126 | question = "Calculate the greatest common fac of 11130 and 6." 127 | f = extract_formal_elements(question) 128 | assert f == [Value("11130"), Value("6")] 129 | assert gcd(f[0], f[1]) == Value("6") 130 | 131 | def test_numbers__is_factor(self): 132 | question = "Is 15 a fac of 720?" 133 | f = extract_formal_elements(question) 134 | assert f == [Value("15"), Value("720")] 135 | assert divides(f[1], f[0]) == True 136 | 137 | def test_numbers__is_prime(self): 138 | question = "Is 93163 a prime number?" 139 | f = extract_formal_elements(question) 140 | assert f == [Value("93163")] 141 | assert is_prime(f[0]) == False 142 | 143 | def test_numbers__lcm(self): 144 | question = "Calculate the smallest common multiple of 351 and 141." 145 | f = extract_formal_elements(question) 146 | assert f == [Value("351"), Value("141")] 147 | assert lcm(f[0], f[1]) == Value("16497") 148 | 149 | def test_numbers__list_prime_factors(self): 150 | question = "What are the prime factors of 329?" 151 | f = extract_formal_elements(question) 152 | assert f == [Value("329")] 153 | assert prime_factors(f[0]) == {Value(7), Value(47)} 154 | 155 | def test_polynomials_evaluate(self): 156 | question = "Let i(h) = -7*h - 15. Determine i(-2)." 157 | f = extract_formal_elements(question) 158 | assert f == [Function("i(h) = -7*h - 15"), Expression("i(-2)")] 159 | assert evaluate_function(f[0], f[1]) == Value(-1) 160 | 161 | # requiring new operators -------------------------------------------- 162 | 163 | # def test_comparison__closest(self): 164 | # question = 'Which is the closest to -1/3? (a) -8/7 (b) 5 (c) -1.3' 165 | # f = extract_formal_elements(question) 166 | # power_f0 = power(f[0], f[1]) 167 | # rounded_power_f0 = round_to_int(power_f0, f[2]) 168 | # assert rounded_power_f0 == '3' 169 | 170 | # def test_comparison__pair_composed(self): 171 | # question = 'Let o = -788/3 - -260. Which is bigger: -0.1 or o?' 172 | # f = extract_formal_elements(question) 173 | # assert f == [Equation('o = -788/3 - -260'), Value('-0.1'), Variable('o')] 174 | # o = sy(f[0]) 175 | # m = max_arg(f[1], pr(o)) 176 | # assert srl(m, o) == Value('-0.1') 177 | 178 | # def test_comparison__sort_composed(self): 179 | # question = 'Suppose $f[0 = -4*x + 8*x - 40]. Let $f[h(i) = i**2 - 9*i - 14]. Let $f[n] be $f[h(x)]. Sort $f[-1], $f[4], $f[n].' 180 | # f = extract_formal_elements(question) 181 | # x = lv(ss(f[0]), get 182 | 183 | # def test_arithmetic__add_or_sub_in_base(self): 184 | # question = 'In base 13, what is 7a79 - -5?' 185 | # f = extract_formal_elements(question) 186 | # assert f == ['13', '7a79 - -5'] 187 | # assert eval_in_base(f[1], f[0]) == '7a81' 188 | # 189 | # def test_arithmetic__nearest_integer_root_1(self): 190 | # question = 'What is the square root of 664 to the nearest 1?' 191 | # f = extract_formal_elements(question) 192 | # root_f1 = root(f[1], f[0]) 193 | # rounded_root_f1 = round_to_int(root_f1, f[2]) 194 | # assert rounded_root_f1 == '26' 195 | # 196 | # def test_arithmetic__nearest_integer_root_2(self): 197 | # question = 'What is $f[1699] to the power of $f[1/6], to the nearest $f[1]?' 198 | # f = extract_formal_elements(question) 199 | # power_f0 = power(f[0], f[1]) 200 | # rounded_power_f0 = round_to_int(power_f0, f[2]) 201 | # assert rounded_power_f0 == '3' 202 | -------------------------------------------------------------------------------- /math_prog_synth_env/unit_testing/test_gym_environment.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from math_prog_synth_env.utils import extract_formal_elements 3 | from math_prog_synth_env.envs.math_env import MathEnv 4 | from math_prog_synth_env.typed_operators import * 5 | import unittest 6 | 7 | env = MathEnv('math_prog_synth_env/unit_testing/artifacts/params.yaml') 8 | 9 | class Test(unittest.TestCase): 10 | def test_algebra_linear_1d_fail_1(self): 11 | # reset - then fail after 1st action 12 | encoded_question, _ = env.reset_from_text("Solve 0 = 4*b + b + 15 for b.", "-3") 13 | question = env.decode_question(encoded_question) 14 | f = extract_formal_elements(question) # for use below 15 | assert f == ["0 = 4*b + b + 15", "b"] 16 | action = "f0" 17 | action_index = env.get_action_index(action) 18 | observation_, reward, done, info = env.step(action_index) 19 | assert ( 20 | info["raw_observation"] 21 | == f"{question}; Equation('0 = 4*b + b + 15')" 22 | ) 23 | assert reward == 0 24 | assert done 25 | 26 | def test_algebra_linear_1d_fail_2(self): 27 | #env = MathEnv('params.yaml') 28 | # reset - then fail after 2nd action 29 | encoded_question, _ = env.reset_from_text("Solve 0 = 4*b + b + 15 for b.", "-3") 30 | question = env.decode_question(encoded_question) 31 | assert question == "Solve 0 = 4*b + b + 15 for b." 32 | action = solve_system 33 | action_index = env.get_action_index(action) 34 | observation, reward, done, info = env.step(action_index) 35 | assert ( 36 | info["raw_observation"] == f"{question}; solve_system('p_0')" 37 | ) 38 | assert reward == 0 39 | assert not done 40 | # next action 41 | action = "f0" 42 | action_index = env.get_action_index(action) 43 | observation_, reward, done, info = env.step(action_index) 44 | assert ( 45 | info["raw_observation"] 46 | == f"{question}; solve_system(Equation('0 = 4*b + b + 15'))" 47 | ) 48 | assert reward == 0 49 | assert done 50 | 51 | def test_algebra_linear_1d_fail_3(self): 52 | #env = MathEnv('params.yaml') 53 | # reset - then fail after 1st action 54 | encoded_question, _ = env.reset_from_text("Solve 0 = 4*b + b + 15 for b.", "-3") 55 | question = env.decode_question(encoded_question) 56 | f = extract_formal_elements(question) # for use below 57 | assert f == ["0 = 4*b + b + 15", "b"] 58 | action = "f10" # indexing out of range 59 | action_index = env.get_action_index(action) 60 | observation_, reward, done, info = env.step(action_index) 61 | assert reward == 0 62 | assert done 63 | 64 | def test_algebra_linear_1d_success_1(self): 65 | #env = MathEnv('params.yaml') 66 | # reset - then succeed after 4th action 67 | encoded_question, _ = env.reset_from_text("Solve 0 = 4*b + b + 15 for b.", "-3") 68 | question = env.decode_question(encoded_question) 69 | assert question == "Solve 0 = 4*b + b + 15 for b." 70 | action = lookup_value 71 | action_index = env.get_action_index(action) 72 | observation, reward, done, info = env.step(action_index) 73 | assert ( 74 | info["raw_observation"] 75 | == f"{question}; lookup_value('p_0','p_1')" 76 | ) 77 | assert reward == 0 78 | assert not done 79 | assert env.compute_graph.current_node == env.compute_graph.root 80 | # next action 81 | action = solve_system 82 | action_index = env.get_action_index(action) 83 | observation, reward, done, info = env.step(action_index) 84 | assert ( 85 | info["raw_observation"] 86 | == f"{question}; lookup_value(solve_system('p_0'),'p_1')" 87 | ) 88 | assert reward == 0 89 | assert not done 90 | # current node is still root because it takes 2 arguments and only 1 has been given 91 | assert env.compute_graph.current_node == env.compute_graph.root 92 | # next action 93 | action = "f1" 94 | action_index = env.get_action_index(action) 95 | observation, reward, done, info = env.step(action_index) 96 | assert ( 97 | info["raw_observation"] 98 | == f"{question}; lookup_value(solve_system('p_0'),Variable('b'))" 99 | ) 100 | assert reward == 0 101 | assert not done 102 | # current node is now the solve_system node because the lookup_value node has its args set 103 | assert env.compute_graph.current_node == env.compute_graph.root.args[0] 104 | # next action 105 | action = append_to_empty_list 106 | action_index = env.get_action_index(action) 107 | observation, reward, done, info = env.step(action_index) 108 | assert ( 109 | info["raw_observation"] 110 | == f"{question}; lookup_value(solve_system(append_to_empty_list('p_0')),Variable('b'))" 111 | ) 112 | assert reward == 0 113 | assert not done 114 | # next action 115 | action = "f0" 116 | action_index = env.get_action_index(action) 117 | observation, reward, done, info = env.step(action_index) 118 | assert ( 119 | info["raw_observation"] 120 | == f"{question}; lookup_value(solve_system(append_to_empty_list(Equation('0 = 4*b + b + 15'))),Variable('b'))" 121 | ) 122 | assert reward == 1 123 | assert done 124 | 125 | def test_calculus_differentiate_success_1_with_masking(self): 126 | #env = MathEnv('params.yaml') 127 | # reset - then succeed after 4th action 128 | encoded_question, _ = env.reset_from_text("Find the first derivative of 2*d**4 - 35*d**2 - 695 wrt d.", 129 | "8*d**3 - 70*d") 130 | question = env.decode_question(encoded_question) 131 | assert question == "Find the first derivative of 2*d**4 - 35*d**2 - 695 wrt d." 132 | # take action 133 | action = differentiate_wrt 134 | action_index = env.get_action_index(action) 135 | observation, reward, done, info = env.step(action_index) 136 | assert ( 137 | info["raw_observation"] == f"{question}; differentiate_wrt('p_0','p_1')" 138 | ) 139 | # take action 140 | action = "f0" 141 | action_index = env.get_action_index(action) 142 | observation, reward, done, info = env.step(action_index) 143 | assert reward == 0 144 | assert not done 145 | assert ( 146 | info["raw_observation"] == f"{question}; differentiate_wrt(Expression('2*d**4 - 35*d**2 - 695'),'p_1')" 147 | ) 148 | vector = np.ones(len(env.actions)) 149 | masked_vector = env.mask_invalid_types(vector) 150 | assert masked_vector[env.get_action_index("f0")] == 0 and \ 151 | masked_vector[env.get_action_index("f1")] == 1 152 | # take action 153 | action = "f1" 154 | action_index = env.get_action_index(action) 155 | observation, reward, done, info = env.step(action_index) 156 | assert reward == 1 157 | assert done 158 | 159 | def test_calculus_differentiate_success_2_with_masking(self): 160 | #env = MathEnv('params.yaml') 161 | # reset - then succeed after 4th action 162 | encoded_question, _ = env.reset_from_text("Find the first derivative of 2*d**4 - 35*d**2 - 695 wrt d.", 163 | "8*d**3 - 70*d") 164 | question = env.decode_question(encoded_question) 165 | assert question == "Find the first derivative of 2*d**4 - 35*d**2 - 695 wrt d." 166 | # take action 167 | action = differentiate 168 | action_index = env.get_action_index(action) 169 | observation, reward, done, info = env.step(action_index) 170 | assert ( 171 | info["raw_observation"] == f"{question}; differentiate('p_0')" 172 | ) 173 | # take action 174 | action = "f0" 175 | action_index = env.get_action_index(action) 176 | observation, reward, done, info = env.step(action_index) 177 | assert reward == 1 178 | assert done 179 | assert ( 180 | info["raw_observation"] == f"{question}; differentiate(Expression('2*d**4 - 35*d**2 - 695'))" 181 | ) 182 | 183 | def test_numbers_div_remainder_success(self): 184 | #env = MathEnv('params.yaml') 185 | # reset - then succeed after 4th action 186 | encoded_question, _ = env.reset_from_text("Calculate the remainder when 93 is divided by 59.", "34") 187 | question = env.decode_question(encoded_question) 188 | assert question == "Calculate the remainder when 93 is divided by 59." 189 | assert env.compute_graph.formal_elements == [Value("93"), Value("59")] 190 | # first action 191 | action = mod 192 | action_index = env.get_action_index(action) 193 | observation, reward, done, info = env.step(action_index) 194 | assert ( 195 | info["raw_observation"] == f"{question}; mod('p_0','p_1')" 196 | ) 197 | assert reward == 0 198 | assert not done 199 | # next action 200 | action = "f0" 201 | action_index = env.get_action_index(action) 202 | observation, reward, done, info = env.step(action_index) 203 | assert ( 204 | info["raw_observation"] 205 | == f"{question}; mod(Value('93'),'p_1')" 206 | ) 207 | assert reward == 0 208 | assert not done 209 | # next action 210 | action = "f1" 211 | action_index = env.get_action_index(action) 212 | observation, reward, done, info = env.step(action_index) 213 | assert ( 214 | info["raw_observation"] 215 | == f"{question}; mod(Value('93'),Value('59'))" 216 | ) 217 | assert reward == 1 218 | assert done 219 | 220 | def test_numbers_gcd_success(self): 221 | #env = MathEnv('params.yaml') 222 | # reset - then succeed after 4th action 223 | encoded_question, _ = env.reset_from_text("Calculate the highest common divisor of 1300 and 300.", "100") 224 | question = env.decode_question(encoded_question) 225 | assert question == "Calculate the highest common divisor of 1300 and 300." 226 | # first action 227 | action = gcd 228 | action_index = env.get_action_index(action) 229 | observation, reward, done, info = env.step(action_index) 230 | assert ( 231 | info["raw_observation"] == f"{question}; gcd('p_0','p_1')" 232 | ) 233 | assert reward == 0 234 | assert not done 235 | # next action 236 | action = "f0" 237 | action_index = env.get_action_index(action) 238 | observation, reward, done, info = env.step(action_index) 239 | assert ( 240 | info["raw_observation"] 241 | == f"{question}; gcd(Value('1300'),'p_1')" 242 | ) 243 | assert reward == 0 244 | assert not done 245 | # next action 246 | action = "f1" 247 | action_index = env.get_action_index(action) 248 | observation, reward, done, info = env.step(action_index) 249 | assert ( 250 | info["raw_observation"] 251 | == f"{question}; gcd(Value('1300'),Value('300'))" 252 | ) 253 | assert reward == 1 254 | assert done 255 | 256 | def test_is_prime_success_1(self): 257 | #env = MathEnv('params.yaml') 258 | # reset - then succeed after 4th action 259 | encoded_question, _ = env.reset_from_text("Is 93163 a prime number?", "False") 260 | question = env.decode_question(encoded_question) 261 | assert question == "Is 93163 a prime number?" 262 | # first action 263 | action = is_prime 264 | action_index = env.get_action_index(action) 265 | observation, reward, done, info = env.step(action_index) 266 | assert ( 267 | info["raw_observation"] == f"{question}; is_prime('p_0')" 268 | ) 269 | assert reward == 0 270 | assert not done 271 | # next action 272 | action = "f0" 273 | action_index = env.get_action_index(action) 274 | observation, reward, done, info = env.step(action_index) 275 | assert ( 276 | info["raw_observation"] 277 | == f"{question}; is_prime(Value('93163'))" 278 | ) 279 | assert reward == 1 280 | assert done 281 | 282 | def test_is_prime_success_2(self): 283 | #env = MathEnv('params.yaml') 284 | # reset - then succeed after 4th action 285 | encoded_question, _ = env.reset_from_text("Is 66574 a composite number?", "True") 286 | question = env.decode_question(encoded_question) 287 | assert question == "Is 66574 a composite number?" 288 | # first action 289 | action = not_op 290 | action_index = env.get_action_index(action) 291 | observation, reward, done, info = env.step(action_index) 292 | assert ( 293 | info["raw_observation"] == f"{question}; not_op('p_0')" 294 | ) 295 | assert reward == 0 296 | assert not done 297 | # next action 298 | action = is_prime 299 | action_index = env.get_action_index(action) 300 | observation, reward, done, info = env.step(action_index) 301 | assert ( 302 | info["raw_observation"] 303 | == f"{question}; not_op(is_prime('p_0'))" 304 | ) 305 | assert reward == 0 306 | assert not done 307 | # next action 308 | action = "f0" 309 | action_index = env.get_action_index(action) 310 | observation, reward, done, info = env.step(action_index) 311 | assert ( 312 | info["raw_observation"] 313 | == f"{question}; not_op(is_prime(Value('66574')))" 314 | ) 315 | assert reward == 1 316 | assert done 317 | 318 | def test_problem_third_diff_success(self): 319 | #env = MathEnv('params.yaml') 320 | # reset - then succeed after 4th action 321 | encoded_question, _ = env.reset_from_text("Find the third derivative of -272*j**5 + j**3 - 8234*j**2.", 322 | "-16320*j**2 + 6") 323 | question = env.decode_question(encoded_question) 324 | assert question == "Find the third derivative of -272*j**5 + j**3 - 8234*j**2." 325 | # take action 326 | action_index = env.get_action_index(differentiate) 327 | observation, reward, done, info = env.step(action_index) 328 | assert reward == 0 329 | assert not done 330 | # take action 331 | observation, reward, done, info = env.step(action_index) 332 | assert reward == 0 333 | assert not done 334 | # take action 335 | observation, reward, done, info = env.step(action_index) 336 | assert reward == 0 337 | assert not done 338 | # take action 339 | action_index = env.get_action_index("f0") 340 | observation, reward, done, info = env.step(action_index) 341 | assert reward == 1 342 | assert done 343 | 344 | def test_max_nodes_failure(self): 345 | #env = MathEnv('params.yaml') 346 | encoded_question, _ = env.reset_from_text("Is 66574 a composite number?", "True") 347 | question = env.decode_question(encoded_question) 348 | assert question == "Is 66574 a composite number?" 349 | nt_action_index = env.get_action_index(not_op) 350 | for i in range(env.max_num_nodes-1): 351 | # take action 352 | observation, reward, done, info = env.step(nt_action_index) 353 | assert reward == 0 354 | assert not done 355 | # take final action 356 | i += 1 357 | observation, reward, done, info = env.step(nt_action_index) 358 | assert reward == 0 359 | assert done 360 | 361 | 362 | def test_lcd1(self): 363 | #env = MathEnv('params.yaml') 364 | encoded_question, _ = env.reset_from_text("What is the common denominator of -64/1065 and 92/105?", "7455") 365 | question = env.decode_question(encoded_question) 366 | # lcd 367 | assert question == "What is the common denominator of -64/1065 and 92/105?" 368 | action_index = env.get_action_index(lcd) 369 | observation, reward, done, info = env.step(action_index) 370 | assert reward == 0 371 | assert not done 372 | # f0 373 | action_index = env.get_action_index("f0") 374 | observation, reward, done, info = env.step(action_index) 375 | assert reward == 0 376 | assert not done 377 | # f1 378 | action_index = env.get_action_index("f1") 379 | observation, reward, done, info = env.step(action_index) 380 | assert reward == 1 381 | assert done 382 | 383 | def test_lcd2(self): 384 | #env = MathEnv('params.yaml') 385 | encoded_question, _ = env.reset_from_text("Calculate the common denominator of 1/(3/(-6)) - 402/(-60) and -71/12.", "60") 386 | question = env.decode_question(encoded_question) 387 | # lcd 388 | assert question == "Calculate the common denominator of 1/(3/(-6)) - 402/(-60) and -71/12." 389 | action_index = env.get_action_index(lcd) 390 | observation, reward, done, info = env.step(action_index) 391 | assert reward == 0 392 | assert not done 393 | # f0 394 | action_index = env.get_action_index("f0") 395 | observation, reward, done, info = env.step(action_index) 396 | assert reward == 0 397 | assert not done 398 | # f1 399 | action_index = env.get_action_index("f1") 400 | observation, reward, done, info = env.step(action_index) 401 | assert reward == 1 402 | assert done 403 | 404 | def test_polynomial_roots_1(self): 405 | question = "What is f in -87616*f**2 - 1776*f - 9 = 0?" 406 | answer = "-3/296" 407 | #env = MathEnv('params.yaml') 408 | encoded_question, _ = env.reset_from_text(question, answer) 409 | action = lookup_value 410 | action_index = env.get_action_index(action) 411 | observation, reward, done, info = env.step(action_index) 412 | assert reward == 0 413 | assert not done 414 | # next action 415 | action = solve_system 416 | action_index = env.get_action_index(action) 417 | observation, reward, done, info = env.step(action_index) 418 | assert reward == 0 419 | assert not done 420 | # next action 421 | action = "f0" 422 | action_index = env.get_action_index(action) 423 | observation, reward, done, info = env.step(action_index) 424 | assert reward == 0 425 | assert not done 426 | # next action 427 | action = append_to_empty_list 428 | action_index = env.get_action_index(action) 429 | observation, reward, done, info = env.step(action_index) 430 | assert reward == 0 431 | assert not done 432 | # next action 433 | action = "f1" 434 | action_index = env.get_action_index(action) 435 | observation, reward, done, info = env.step(action_index) 436 | assert reward == 1 437 | assert done 438 | 439 | def test_polynomial_roots_2(self): 440 | question = "Solve -3*h**2/2 - 24*h - 45/2 = 0 for h." 441 | answer = "-15, -1" 442 | #env = MathEnv('params.yaml') 443 | encoded_question, _ = env.reset_from_text(question, answer) 444 | action = lookup_value 445 | action_index = env.get_action_index(action) 446 | observation, reward, done, info = env.step(action_index) 447 | assert reward == 0 448 | assert not done 449 | # next action 450 | action = solve_system 451 | action_index = env.get_action_index(action) 452 | observation, reward, done, info = env.step(action_index) 453 | assert reward == 0 454 | assert not done 455 | # next action 456 | action = "f1" 457 | action_index = env.get_action_index(action) 458 | observation, reward, done, info = env.step(action_index) 459 | assert reward == 0 460 | assert not done 461 | # next action 462 | action = append_to_empty_list 463 | action_index = env.get_action_index(action) 464 | observation, reward, done, info = env.step(action_index) 465 | assert reward == 0 466 | assert not done 467 | # next action 468 | action = "f0" 469 | action_index = env.get_action_index(action) 470 | observation, reward, done, info = env.step(action_index) 471 | print(info['raw_observation']) 472 | assert reward == 1 473 | assert done 474 | -------------------------------------------------------------------------------- /math_prog_synth_env/unit_testing/test_gym_environment_guessing.py: -------------------------------------------------------------------------------- 1 | from math_prog_synth_env.envs.math_env import MathEnv 2 | from math_prog_synth_env.utils import guess_until_problem_solved, load_question_answer_pairs 3 | import unittest 4 | import os 5 | 6 | class Test(unittest.TestCase): 7 | 8 | def test_guess_until_correct(self): 9 | """this test only terminates when the graph is correctly guessed or timeout is reached""" 10 | env = MathEnv('params.yaml') 11 | for filename in [fn for fn in os.listdir('math_prog_synth_env/unit_testing/artifacts/problems') if '.txt' in fn]: 12 | filepath = os.path.join(f'math_prog_synth_env/unit_testing/artifacts/problems/{filename}') 13 | question_answer_pairs = load_question_answer_pairs(filepath) 14 | for question, answer in question_answer_pairs[:5]: 15 | guess_until_problem_solved(env, question, answer, verbose=False, max_episode_index=50000) 16 | 17 | -------------------------------------------------------------------------------- /math_prog_synth_env/unit_testing/test_operators.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import numpy as np 3 | from math_prog_synth_env.typed_operators import * 4 | from sympy import sympify 5 | 6 | 7 | class Test(unittest.TestCase): 8 | def test_value(self): 9 | assert Value(1) == Value(1.0) 10 | assert {Value(1)} == {Value(1)} 11 | assert {Value(1)} == {Value(1.0)} 12 | 13 | def test_solve_system(self): 14 | system = [Equation("x = 1")] 15 | assert solve_system(system) == {Variable("x"): {Rational(1)}} 16 | 17 | system = [Equation("x = 1"), Equation("y = 1")] 18 | assert solve_system(system) == { 19 | Variable("x"): {Rational(1)}, 20 | Variable("y"): {Rational(1)}, 21 | } 22 | 23 | system = [Equation("x + y = 1"), Equation("x - y = 1")] 24 | assert solve_system(system) == { 25 | Variable("x"): {Rational(1)}, 26 | Variable("y"): {Rational(0)}, 27 | } 28 | 29 | system = [Equation("3*x + y = 9"), Equation("x + 2*y = 8")] 30 | assert solve_system(system) == { 31 | Variable("x"): {Rational(2)}, 32 | Variable("y"): {Rational(3)}, 33 | } 34 | 35 | # # fails on singular matrix 36 | system = [ 37 | Equation("x + 2*y - 3*z = 1"), 38 | Equation("3*x - 2*y + z = 2"), 39 | Equation("-x + 2*y - 2*z = 3"), 40 | ] 41 | self.assertRaises(Exception, solve_system, system) 42 | 43 | # system with floating point coefficients 44 | system = [Equation("-15 = 3*c + 2.0*c")] 45 | assert solve_system(system) == {Variable("c"): {Rational(-3)}} 46 | 47 | # quadratic equation 48 | system = [Equation("-3*h**2/2 - 24*h - 45/2 = 0")] 49 | assert solve_system(system) == {Variable("h"): {Rational(-15), Rational(-1)}} 50 | 51 | # unsolvable equation / infinite loop without timeout 52 | system = [Equation('-4*i**3*j**3 - 2272*i**3 - 769*i**2*j - j**3 = 1')] 53 | self.assertRaises(Exception, solve_system, system) 54 | 55 | system = [Equation('-g**3 - 9*g**2 - g + l(g) - 10 = 0')] 56 | self.assertRaises(Exception, solve_system, system) 57 | 58 | # unsolvable equation / infinite loop without timeout 59 | system = [Equation('-4*i**3*j**3 - 2272*i**3 - 769*i**2*j - j**3 = 1')] 60 | self.assertRaises(Exception, solve_system, system) 61 | 62 | system = [Equation('9*s**4 - 8958*s**3 - 14952*s**2 - 2994*s + 2991 = 0')] 63 | assert solve_system(system) == {Variable("s"): {Rational(-1), Rational('1/3'), Rational(997)}} 64 | 65 | system = [Equation('-3*h**2/2 - 24*h - 45/2 = 0')] 66 | assert solve_system(system) == {Variable("h"): {Rational('-1'), Rational('-15')}} 67 | # print([(str(k), [str(v) for v in vset]) for k,vset in solve_system(system).items()]) 68 | 69 | def test_is_prime(self): 70 | assert is_prime(Value('3')) 71 | assert not_op(is_prime(Value('4'))) 72 | 73 | def test_prime_factors(self): 74 | result = prime_factors(Value('7380')) 75 | assert ", ".join([str(x) for x in sorted(list(result))]) == '2, 3, 5, 41' 76 | 77 | def test_lcd(self): 78 | assert lcd(Rational('2/3'), Rational('3/5')) == Value('15') 79 | assert lcd(Rational('2/3'), Rational('3/5')) == Value('15') 80 | 81 | def test_third_derivative(self): 82 | inpt = Expression('-272*j**5 + j**3 - 8234*j**2') 83 | third_derivative = differentiate(differentiate(differentiate(inpt))) 84 | assert sympify(third_derivative) == sympify(Expression('-16320*j**2 + 6')) 85 | 86 | def test_function_evaluation1(self): 87 | f0 = Function('l(t) = -t**2 - 7*t - 7') 88 | f1 = Expression('l(-5)') 89 | output = evaluate_function(f0, f1) 90 | assert output == Value(3) 91 | 92 | def test_function_evaluation2(self): 93 | f0 = Function('x(k) = k**3 + k**2 + 6*k + 9') 94 | f1 = Expression('x(-2)') 95 | output = evaluate_function(f0, f1) 96 | assert output == Value(-7) 97 | 98 | def test_diff_distractors(self): 99 | expression = Expression('442*c**4 + 248') 100 | output1 = differentiate(factor(expression)) 101 | output2 = factor(differentiate(expression)) 102 | output3 = differentiate(simplify(expression)) 103 | output4 = simplify(differentiate(expression)) 104 | answer = Expression('1768*c**3') 105 | assert output1 == answer 106 | assert output2 == answer 107 | assert output3 == answer 108 | assert output4 == answer 109 | 110 | def test_lcd(self): 111 | arg1 = Rational('-64/1065') 112 | arg2 = Rational('92/105') 113 | output = lcd(arg1, arg2) 114 | assert output == Value('7455') 115 | 116 | def test_replace_arg(self): 117 | f = Function('funk(k) = k**3 + k**2 + 6*k + 9') 118 | replaced_f = Function('funk(x) = x**3 + x**2 + 6*x + 9') 119 | assert replace_arg(f, Variable('x')) == replaced_f -------------------------------------------------------------------------------- /math_prog_synth_env/unit_testing/test_utils.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from math_prog_synth_env.utils import is_numeric 3 | 4 | 5 | class Test(unittest.TestCase): 6 | def test_is_numeric(self): 7 | assert is_numeric("2") 8 | assert is_numeric("2.0") 9 | assert not is_numeric("2.0.") 10 | -------------------------------------------------------------------------------- /math_prog_synth_env/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | from tqdm import tqdm 4 | from math_prog_synth_env.typed_operators import Equation, Function, Expression, Variable, Value, Rational 5 | import sympy 6 | 7 | 8 | def is_numeric(string): 9 | return all([x.isnumeric() or x == "." for x in string] + [string.count(".") <= 1]) 10 | 11 | 12 | def extract_formal_elements_as_annotations(question): 13 | pattern = "\$f\[(.+?)\]" 14 | return re.findall(pattern, question) 15 | 16 | 17 | def extract_formal_elements(question, cast=True): 18 | # split on punctuation unless it is immediately preceded and followed by a number (indicating it is a decimal) 19 | split_on_punctuation = "***".join( 20 | [ 21 | string 22 | for string in re.split("(? 0 and not string.isspace() 24 | ] 25 | ) 26 | # TODO: use a more sophisticated mechanism (CFG?) to math expressions, equations, etc... this could account for variables names that have length greater than 1 27 | split_on_words = [ 28 | string 29 | for string in re.split("[A-Za-z]\w+|\*\*\*", split_on_punctuation) 30 | if len(string) > 0 and not string.isspace() 31 | ] 32 | # strip trailing or leading whitespace 33 | formal_elements = [string.strip() for string in split_on_words] 34 | # filter for the special case where the letter "a" gets included at the end of a formal element 35 | formal_elements = [ 36 | f if len(re.findall("[0-9A-Za-z\)](\sa)", f)) < 1 else f.split(" a")[0] 37 | for f in formal_elements 38 | ] 39 | # cast types 40 | if cast: 41 | formal_elements = [cast_formal_element(f) for f in formal_elements] 42 | return formal_elements 43 | 44 | 45 | def cast_formal_element(f): 46 | try: 47 | x = sympy.sympify(f) 48 | if type(x) == sympy.core.numbers.Rational: 49 | return Rational(str(x)) 50 | elif issubclass(type(x), sympy.core.numbers.Number): 51 | return Value(str(x)) 52 | elif type(x) == sympy.core.symbol.Symbol: 53 | return Variable(f) 54 | else: 55 | return Expression(f) 56 | except: 57 | if "=" in f: 58 | try: 59 | return Function(f) 60 | except: 61 | return Equation(f) 62 | 63 | 64 | def guess_until_problem_solved(env, question, answer, verbose=False, max_episode_index=1000): 65 | episode_i = 0 66 | graph_guessed_correctly = False 67 | encoded_question, _ = env.reset_from_text(question, answer) 68 | print(f"\nquestion: {env.decode_question(encoded_question)}") 69 | while not graph_guessed_correctly and episode_i < max_episode_index: 70 | encoded_question, _ = env.reset_from_text(question, answer) 71 | done = False 72 | step_i = 0 73 | if verbose: 74 | print(f"episode: {episode_i}") 75 | while not done: 76 | action_index = env.sample_masked_action_index() 77 | observation, reward, done, info = env.step(action_index) 78 | if verbose: 79 | if "lookup_value(solve_system(append_to_empty_list('p_0')),Variable('b'))" in info['raw_observation']: 80 | print() 81 | print(f"\t\tS': {info['raw_observation']}, R: {reward}, done: {done}") 82 | if reward == 1: 83 | graph_guessed_correctly = True 84 | step_i += 1 85 | episode_i += 1 86 | assert graph_guessed_correctly 87 | print(f'graph: {info["raw_observation"].split(";")[1]}') 88 | print(f"{episode_i} trials taken to guess: {question}") 89 | 90 | 91 | def filter_univariate(examples): 92 | univariate_examples = [] 93 | for example_dict in examples: 94 | question = example_dict['question'] 95 | formal_elements = extract_formal_elements(question, cast=False) 96 | function = formal_elements[0] 97 | num_vars = len([ch for ch in set(function) if ch.isalpha()]) 98 | if num_vars == 1: 99 | univariate_examples.append(example_dict) 100 | return univariate_examples 101 | 102 | 103 | def get_module_name_from_filepath(fp): 104 | module_name = fp.split("/")[-1].split(".txt")[0] 105 | if "compose" in module_name: 106 | module_name = module_name.split("_compose")[0] 107 | else: 108 | module_name = module_name 109 | return module_name 110 | 111 | 112 | def load_question_answer_pairs(filepath): 113 | qa_pairs = [] 114 | with open(filepath, "r") as f: 115 | lines = f.readlines() 116 | num_pairs = len(lines) // 2 117 | for i in range(0, 2 * num_pairs, 2): 118 | question = lines[i].strip() 119 | answer = lines[i + 1].strip() 120 | qa_pairs.append((question, answer)) 121 | return qa_pairs 122 | 123 | 124 | # load data 125 | def load_data(config, train=True): 126 | data = {} 127 | print("loading problems") 128 | if train: 129 | problem_filepaths = [os.path.join(config["data_dirpath"], filename) for filename in config["selected_filenames"]] 130 | else: 131 | problem_filepaths = [os.path.join(config["test_data_dirpath"], filename) for filename in config["selected_filenames"]] 132 | 133 | problem_counts = {} 134 | for filepath in tqdm(problem_filepaths): 135 | with open(filepath, "r") as f: 136 | lines = f.readlines() 137 | num_pairs = min(len(lines) // 2, config["num_problems_per_module"]) 138 | for i in range(0, 2 * num_pairs, 2): 139 | question = lines[i].strip() 140 | answer = lines[i + 1].strip() 141 | # for uncomposed problems set difficulty to 0 to distinguish them 142 | difficulty = ( 143 | len(re.split("(? config["max_difficulty"]: 149 | continue 150 | module_name = get_module_name_from_filepath(filepath) 151 | # increment problem count for (module_name, difficulty) 152 | if (module_name, difficulty) in problem_counts: 153 | problem_counts[(module_name, difficulty)] += 1 154 | else: 155 | problem_counts[(module_name, difficulty)] = 1 156 | # store problem 157 | problem_dict = {'module_difficulty_index': problem_counts[(module_name, difficulty)], 158 | 'question': question, 159 | 'answer': answer} 160 | if module_name in data: 161 | if difficulty in data[module_name]: 162 | data[module_name][difficulty].append(problem_dict) 163 | else: 164 | data[module_name][difficulty] = [problem_dict] 165 | else: 166 | data[module_name] = {difficulty: [problem_dict]} 167 | if config["univariate_differentiation"]: 168 | data['calculus__differentiate'][0] = filter_univariate(data['calculus__differentiate'][0]) 169 | return data 170 | 171 | def split_validation_data(config, train): 172 | val = {} 173 | for module_name in train: 174 | val[module_name] = {} 175 | for difficulty in train[module_name]: 176 | num_examples = len(train[module_name][difficulty]) 177 | num_val = int(num_examples * config["validation_percentage"]) 178 | val[module_name][difficulty] = train[module_name][difficulty][:num_val] 179 | train[module_name][difficulty] = train[module_name][difficulty][num_val:] 180 | assert ( 181 | len(train[module_name][difficulty]) 182 | + len(val[module_name][difficulty]) 183 | == num_examples 184 | ) 185 | return val -------------------------------------------------------------------------------- /params.yaml: -------------------------------------------------------------------------------- 1 | { 2 | 3 | "num_problems_per_module": 1000000, # int, max samples selected per module 4 | "num_problems_per_module_corpus": 30000, # int, max samples per module used to create tokenizer 5 | "validation_percentage": 0.2, # float, percentage of data to use as validation set 6 | "encode_question": true, # bool, controls whether question is encoded using tokenizer 7 | "max_sequence_length": 125, # int, size of the observation space 8 | "question_vocab_size": 250, # int, tokenizer vocab size 9 | "max_difficulty": 0, # int, controls maximum difficulty of problems in the dataset 10 | "univariate_differentiation": true, # bool, "true" removes all problems from calculus_differentiate with multiple variables 11 | "num_environments": 50, # int, Number of environments to run in parallel 12 | "corpus_path": "corpus.txt", # str, path to corpus for training tokenizer 13 | "tokenizer_filepath": "tokenizer", # str, prefix for saving tokenizer 14 | "max_formal_elements": 13, # int, max number of formal elements in a given problem. Setting this lower reduces action space 15 | "max_num_nodes": 7, # int, max number of nodes in a constructed graph. Setting this lower limits search depth. 16 | "data_download_location": "mathematics_dataset-v1.0.tar.gz", # str, path to download packed data 17 | "data_unpack_dir": ".", # str, directory to unpack data to. Creates a new directory in this location 18 | "all_data_dirpath": "mathematics_dataset-v1.0/train-easy", # str, path to data before splitting into train and test sets 19 | "data_dirpath": "mathematics_dataset-v1.0/train", # str, path to train data 20 | "test_data_dirpath": "mathematics_dataset-v1.0/test", # str, path to test data 21 | "test_percentage": 0.1, # float, Percentage of data to use as test set 22 | "selected_filenames": [ # List[str], modules to include in training/evaluation 23 | 'numbers__is_factor.txt', 24 | 'numbers__is_prime.txt', 25 | 'numbers__list_prime_factors.txt', 26 | 'calculus__differentiate.txt', 27 | 'polynomials__evaluate.txt', 28 | 'numbers__div_remainder.txt', 29 | 'numbers__gcd.txt', 30 | 'numbers__lcm.txt', 31 | 'algebra__linear_1d.txt', 32 | 'algebra__polynomial_roots.txt', 33 | 'algebra__linear_2d.txt', 34 | 'algebra__linear_1d_composed.txt', 35 | 'algebra__linear_2d_composed.txt', 36 | 'algebra__polynomial_roots_composed.txt', 37 | 'calculus__differentiate_composed.txt', 38 | 'numbers__div_remainder_composed.txt', 39 | 'numbers__gcd_composed.txt', 40 | 'numbers__is_factor_composed.txt', 41 | 'numbers__is_prime_composed.txt', 42 | 'numbers__lcm_composed.txt', 43 | 'numbers__list_prime_factors_composed.txt', 44 | 'polynomials__evaluate_composed.txt', 45 | 'polynomials__compose.txt' 46 | ], 47 | "operators":[ # List[str], operators that can be used for constructing graphs 48 | "lookup_value", 49 | "solve_system", 50 | "append", 51 | "append_to_empty_list", 52 | "factor", 53 | "differentiate", 54 | "mod", 55 | "gcd", 56 | "divides", 57 | "is_prime", 58 | "lcm", 59 | "lcd", 60 | "prime_factors", 61 | "evaluate_function", 62 | "not_op" 63 | # "differentiate_wrt", 64 | # "make_equation", 65 | # "simplify", 66 | # "make_function", 67 | # "replace_arg", 68 | # "lookup_value_equation", 69 | # "extract_isolated_variable", 70 | # "substitution_left_to_right", 71 | ] 72 | } 73 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup(name='math_prog_synth_env', 4 | version='0.0.1', 5 | install_requires=['gym', 'sympy', 'numpy', 'scipy', 'sentencepiece', 'torch', 'mathematics_dataset', 'tqdm', 6 | 'sklearn', "google-cloud-storage", "pyyaml", "multiprocess"] 7 | ) --------------------------------------------------------------------------------