├── .gitignore ├── EQL_Layer_tf.py ├── LICENSE ├── README.md ├── createjobs.py ├── data_utils.py ├── evaluation.py ├── example_data ├── F1data_test ├── F1data_train_val ├── F2data_test ├── F2data_train_val ├── F3data_test ├── F3data_train_val ├── F4data_test ├── F4data_train_val ├── F5data_test └── F5data_train_val ├── example_results ├── F1 │ ├── graph0_y1.png │ ├── latex0.png │ ├── parameters.json │ └── results.csv └── F2 │ ├── graph_y0.png │ ├── latex_y0.png │ ├── parameters.json │ └── results.csv ├── model_selection.py ├── timeout.py ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Project specific 2 | data 3 | results 4 | jobs 5 | *.sub 6 | *.sh 7 | 8 | 9 | # Byte-compiled / optimized / DLL files 10 | __pycache__/ 11 | *.py[cod] 12 | *$py.class 13 | 14 | # C extensions 15 | *.so 16 | 17 | # Distribution / packaging 18 | .Python 19 | build/ 20 | develop-eggs/ 21 | dist/ 22 | downloads/ 23 | eggs/ 24 | .eggs/ 25 | lib/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | var/ 30 | wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # celery beat schedule file 91 | celerybeat-schedule 92 | 93 | # SageMath parsed files 94 | *.sage.py 95 | 96 | # Environments 97 | .env 98 | .venv 99 | env/ 100 | venv/ 101 | ENV/ 102 | env.bak/ 103 | venv.bak/ 104 | 105 | # Spyder project settings 106 | .spyderproject 107 | .spyproject 108 | 109 | # Rope project settings 110 | .ropeproject 111 | 112 | # mkdocs documentation 113 | /site 114 | 115 | # mypy 116 | .mypy_cache/ 117 | .dmypy.json 118 | dmypy.json -------------------------------------------------------------------------------- /EQL_Layer_tf.py: -------------------------------------------------------------------------------- 1 | """ 2 | Representations of EQL functions and layers. 3 | - *EQL_Layer* are regularized tf.layers.Dense objects representing the function layers in the network. The 4 | intermediate EQL_Layers consist of multiple *EQL_fn* objects and have a helper method get_fns to retrieve the 5 | fn structure of each layer. 6 | - *EQL_fn* is a a group of *self.repeats* identical functions. It takes the input for all these functions as a list 7 | and returns a list of all the outputs. *self.tf_fn* and *self.sympy_fn* are tensorflow and sympy representations of 8 | the function used. 9 | - *reg_div* implements the regularized division and reg_div.__call__ is used like a normal tf function in an EQL_fn. 
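  Minimal usage sketch (illustrative only, assuming TensorFlow 1.x and sympy as imported below; the literal values are made up):
      sine_block = EQL_fn(tf.sin, sp.sin, repeats=3)                       # three parallel sine units
      y = sine_block(tf.zeros([8, 3]))                                     # input width equals repeats * num_positional_args
      hidden = EQL_Layer(weight_init_scale=0.1, sin=2, id=2)               # hidden layer with 2 sine and 2 identity units
      out = hidden(tf.zeros([8, 4]), l1_reg_sched=1e-5, l0_threshold=0.0)  # output shape (8, 4)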
10 | """ 11 | import sympy as sp 12 | import tensorflow as tf 13 | 14 | from utils import number_of_positional_arguments 15 | 16 | 17 | class EQL_fn(object): 18 | """EQL_fn is a group of *self.repeats* identical nodes, e.g. 10 sine functions.""" 19 | 20 | def __init__(self, tf_fn, sympy_fn, repeats): 21 | """ 22 | :param tf_fn: A Tensorflow operation or a class instance with __call__ acting as a tensorflow function. 23 | :param sympy_fn: A sympy operation matching tf_fn 24 | :param repeats: number of times the function is used in the layer 25 | """ 26 | self.tf_fn = tf_fn 27 | self.sympy_fn = sympy_fn 28 | self.num_positional_args = number_of_positional_arguments(self.tf_fn) 29 | self.repeats = repeats 30 | 31 | def __call__(self, data): 32 | slices = tf.split(data, [self.repeats] * self.num_positional_args, axis=1) 33 | return self.tf_fn(*slices) 34 | 35 | def get_total_dimension(self): 36 | return self.repeats * self.num_positional_args 37 | 38 | 39 | class reg_div(object): 40 | """Save regularized division, used as tf function for division layer.""" 41 | 42 | def __init__(self, div_thresh_fn): 43 | """ 44 | Initializing regularized division. 45 | :param div_thresh_fn: a fn that calculated the division threshold from a given train step tensor 46 | """ 47 | self.div_thresh_fn = div_thresh_fn 48 | 49 | def __call__(self, numerator, denominator): 50 | """ 51 | Acts as a normal tensorflow math function, performing save regularized division. Implemented as a class so that 52 | it can follow the tf.div signature. Adds division loss (threshold penalty) to loss collections. 53 | """ 54 | step = tf.train.get_or_create_global_step() 55 | div_thresh = self.div_thresh_fn(step) 56 | mask = tf.cast(denominator > div_thresh, dtype=tf.float32) 57 | div = tf.reciprocal(tf.abs(denominator) + 1e-10) 58 | output = numerator * div * mask 59 | P_theta = tf.maximum((div_thresh - denominator), 0.0) # equation 5 in paper 60 | tf.add_to_collection('Threshold_penalties', P_theta) 61 | return output 62 | 63 | 64 | # Dict of function tuples consisting of matching tensorflow functions/function classes and sympy funcs 65 | dict_of_ops = {'multiply': (tf.multiply, sp.Symbol.__mul__), 66 | 'sin': (tf.sin, sp.sin), 67 | 'cos': (tf.cos, sp.cos), 68 | 'id': (tf.identity, sp.Id), 69 | 'sub': (tf.subtract, sp.Symbol.__sub__), 70 | 'log': (tf.log, sp.log), 71 | 'exp': (tf.exp, sp.exp), 72 | 'reg_div': (reg_div, sp.Symbol.__truediv__)} 73 | 74 | 75 | def validate_op_dict(op_dict): 76 | """Checks if dictionary only includes keywords matching the keywords in dict_of_ops.""" 77 | if not isinstance(op_dict, dict): 78 | raise ValueError("Operation dict has to be a dictionary.") 79 | if not op_dict: 80 | raise ValueError("No parameters given") 81 | for key in op_dict: 82 | if key not in dict_of_ops: 83 | raise ValueError('Unknown parameter {} passed'.format(key)) 84 | 85 | 86 | def op_dict_to_eql_op_list(op_dict): 87 | """Transforms a dict of fn_tuples specified by strings into list of EQL functions.""" 88 | list_of_EQL_fn = [] 89 | for key, value in op_dict.items(): 90 | if key == 'reg_div': 91 | reg_division = dict_of_ops[key][0](value.div_thresh_fn) # This is __init__ call to reg_div class. 
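# Note: for the 'reg_div' key, *value* is expected to provide .repeats and .div_thresh_fn
# (train.py passes a namedtuple('reg_div', ['repeats', 'div_thresh_fn'])); for every other key
# *value* is simply the integer repeat count.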
92 | list_of_EQL_fn.append(EQL_fn(reg_division.__call__, dict_of_ops[key][1], repeats=value.repeats)) 93 | else: 94 | list_of_EQL_fn.append(EQL_fn(dict_of_ops[key][0], dict_of_ops[key][1], repeats=value)) 95 | return list_of_EQL_fn 96 | 97 | 98 | def kill_small_elements(matrix, threshold): 99 | """Routine setting all elements in a matrix with absolute value smaller than a given threshold to zero.""" 100 | kill_matrix = tf.abs(matrix) > threshold 101 | matrix = tf.multiply(tf.cast(kill_matrix, dtype=tf.float32), matrix) 102 | return matrix 103 | 104 | 105 | class EQL_Layer(object): 106 | """ 107 | CREATES THE EQL LAYERS 108 | Uses module 'tf.layers.dense' to perform the operation (inputs*weights+biases), 109 | data is split, chunks to given to different activation functions. 110 | Returns: Output tensor following concatenation of the chunks after passing them through corresponding activation functons 111 | """ 112 | 113 | def __init__(self, weight_init_scale, seed=None, **op_dict): 114 | validate_op_dict(op_dict) 115 | self.list_of_ops = op_dict_to_eql_op_list(op_dict) 116 | self.matmul_output_dim = sum(item.get_total_dimension() for item in self.list_of_ops) 117 | self.w_init_scale = weight_init_scale 118 | self.seed = seed 119 | 120 | def __call__(self, data, l1_reg_sched, l0_threshold): 121 | layer_output = self.get_matmul_output(data, l1_reg_sched, l0_threshold) 122 | indices = [item.get_total_dimension() for item in self.list_of_ops] 123 | slices = tf.split(layer_output, indices, axis=1) 124 | outputs = [op(tensor_slice) for op, tensor_slice in zip(self.list_of_ops, slices)] 125 | return tf.concat(outputs, axis=1) 126 | 127 | def get_matmul_output(self, data, l1_reg_sched, l0_threshold): 128 | """Method building the regularized matrix multiplication layer and returning the output for given data.""" 129 | k_init = tf.random_uniform_initializer(minval=-self.w_init_scale, maxval=self.w_init_scale, seed=self.seed) 130 | k_regu = tf.contrib.layers.l1_regularizer(scale=l1_reg_sched) 131 | layer = tf.layers.Dense(self.matmul_output_dim, kernel_initializer=k_init, kernel_regularizer=k_regu, 132 | bias_regularizer=k_regu) 133 | layer.build(data.get_shape()) 134 | new_kernel = kill_small_elements(layer.kernel, l0_threshold) 135 | new_bias = kill_small_elements(layer.bias, l0_threshold) 136 | assign_kernel = tf.assign(layer.kernel, new_kernel) 137 | assign_bias = tf.assign(layer.bias, new_bias) 138 | with tf.control_dependencies([assign_kernel, assign_bias]): 139 | layer_output = layer(data) 140 | return layer_output 141 | 142 | def get_fns(self): 143 | """Method returning the functions used in layer as a list of (tf_fn, sympy fn, repeats, num_args) tuples.""" 144 | fn_list = [(eql_fn.tf_fn, eql_fn.sympy_fn, eql_fn.repeats, eql_fn.num_positional_args) for eql_fn in 145 | self.list_of_ops] 146 | return fn_list 147 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Arnab Bhattacharjee 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following 
conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [Equation Learner](https://al.is.tuebingen.mpg.de/publications/sahoolampertmartius2018-eqldiv). 2 | 3 | By [Subham S. Sahoo](https://arxiv.org/search/cs?searchtype=author&query=Sahoo%2C+S+S), [Christoph H. Lampert](https://cvml.ist.ac.at/) and [Georg Martius](http://georg.playfulmachines.com/) 4 | 5 | Implemented by [Anselm Paulus](https://scholar.google.com/citations?user=njZL5CQAAAAJ&hl=en), Arnab Bhattacharjee and [Michal Rolínek](https://scholar.google.de/citations?user=DVdSTFQAAAAJ&hl=en). 6 | 7 | Autonomous Learning Group, [Max Planck Institute](https://is.tuebingen.mpg.de/) for Intelligent Systems. 8 | 9 | ## Table of Contents 10 | 0. [Introduction](#introduction) 11 | 0. [Usage](#usage) 12 | 0. [Dependencies](#dependencies) 13 | 0. [Notes](#notes) 14 | 15 | 16 | ## Introduction 17 | 18 | This repository contains TensorFlow implementation of the EQL-Div architecture presented in ICML 2018 paper ["Learning Equations for Extrapolation and Control"](https://al.is.tuebingen.mpg.de/publications/sahoolampertmartius2018-eqldiv). This work proposes a neural network architecture for symbolic regression. 19 | There is also a [Theano implementation, see martius-lab/EQL](https://github.com/martius-lab/EQL). 20 | 21 | 22 | ## Usage 23 | 24 | ### Prepare data 25 | Either provide a python function to 'learn' by calling 26 | ``` 27 | python3 data_utils.py "{'file_name': 'F1data', 'fn_to_learn': 'F1', 'train_val_examples': 10000, 'test_examples': 5000}" 28 | 29 | ``` 30 | or use your own numpy arrays saved in training/evaluation data files. 31 | 32 | ### Train individual model 33 | 34 | Once the data is fixed train the model with 35 | ``` 36 | python3 train.py '{"train_val_file": "data/F1data_train_val", "test_file": "data/F1data_test"}' 37 | ``` 38 | Or possibly change some parameters with 39 | ``` 40 | python3 train.py '{"train_val_file": "data/F1data_train_val", "test_file": "data/F1data_test", "batch_size": 25}' 41 | ``` 42 | 43 | ### Perform model selection 44 | 45 | In case you want to follow the model selection procedure from the paper, first generate runfiles for all the required settings with 46 | ``` 47 | python3 createjobs.py '{"train_val_file": "data/F1data_train_val", "test_file": "data/F1data_test"}' 48 | ``` 49 | 50 | Then run all scripts in the jobs folder. 51 | 52 | Finally the model selection is performed by 53 | 54 | ``` 55 | python3 model_selection.py "{'results_path': 'results/model_selection'}" 56 | ``` 57 | 58 | ### Inspect the learned formulas 59 | 60 | In each result folder one can find png files with latex and graph representations of the learned formulas. 
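Each result folder also contains a `parameters.json` with the full run configuration and a `results.csv` with the validation error, complexity and (if test data was given) extrapolation error; these csv files are what `model_selection.py` aggregates. A quick way to inspect them, e.g. for the bundled example results (a sketch, assuming pandas is installed):
```
import pandas as pd
print(pd.read_csv('example_results/F1/results.csv'))  # columns: val_error, complexity, extr_error, id
```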
61 | 62 | Latex representation of function F1: 63 | ![alt text](example_results/F1/latex0.png "Latex example") 64 | Graph representation of function F1: 65 | ![alt text](example_results/F1/graph0_y1.png "Graph example") 66 | 67 | ## Dependencies: 68 | - python>=3.5 69 | - tensorflow>=1.7 70 | - graphviz (including binaries) 71 | - latex 72 | 73 | ## Notes 74 | 75 | *Disclaimer*: This code is a PROTOTYPE and may contains bugs. Use at your own risk. 76 | 77 | *Contribute*: If you spot some incompatibility of have some additional ideas, contribute via a pull request! Thank you! 78 | -------------------------------------------------------------------------------- /createjobs.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generation of job files for model selection. 3 | - *generate_jobs* creates shell script files for multiple jobs. Varying the number of hidden layers and other 4 | parameters allows for effective model selection. Also creates a submission file to submit all jobs. 5 | """ 6 | import os 7 | from ast import literal_eval 8 | from sys import argv 9 | 10 | import numpy as np 11 | 12 | job_dir = 'jobs' 13 | result_dir = os.path.join("results", "model_selection") 14 | submitfile = os.path.join('jobs', 'submit.sh') 15 | 16 | 17 | def generate_jobs(train_val_file, test_file): 18 | if not os.path.exists(job_dir): 19 | os.mkdir(job_dir) 20 | with open(submitfile, 'w') as submit: 21 | pwd = os.getcwd() 22 | id = 0 23 | l1_reg_range = 10 ** (-np.linspace(35, 60, num=10) / 10.0) 24 | for l1_reg_scale in l1_reg_range: 25 | for num_h_layers in [1, 2, 3, 4]: 26 | params = dict(model_base_dir=result_dir, train_val_file=train_val_file, test_file=test_file, id=id, 27 | num_h_layers=num_h_layers, reg_scale=l1_reg_scale, kill_summaries=True, 28 | generate_symbolic_expr=False) 29 | dict_str = str(params) 30 | cmd = '{} "{}"'.format('python3 ' + os.path.join(pwd, 'train.py '), dict_str) 31 | script_fname = os.path.join(job_dir, str(id) + ".sh") 32 | submit.write(str(id) + ".sh" + "\n") 33 | with open(script_fname, 'w') as f: 34 | f.write(cmd) 35 | os.chmod(script_fname, 0o755) # makes script executable 36 | id += 1 37 | os.chmod(submitfile, 0o755) # makes script executable 38 | print('Jobs succesfully generated.') 39 | 40 | 41 | if __name__ == '__main__': 42 | if len(argv) > 1: 43 | passed_dict = literal_eval(argv[1]) 44 | train_val_file = passed_dict['train_val_file'] 45 | test_file = passed_dict['test_file'] 46 | else: 47 | raise ValueError('No dict given. Please pass a dict {"train_val_file": ..., "test_file": ...}.') 48 | print('Using %s and %s as paths to data files.' % (train_val_file, test_file)) 49 | generate_jobs(train_val_file=train_val_file, test_file=test_file) 50 | -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Handling of data related tasks, e.g. reading of input and generating data files. 3 | - *input_from_file* extracts input and output data from data file. Data file must contain a list of training, 4 | validation, extrapolation and metadata where metadata is a dictionary of the data parameters. 5 | - *input_penalty_epoch* generates new input data (using penalty boundaries) for penalty epochs. 
The output fed into 6 | the Estimator for these epochs is set to zero because in penalty epochs we compute gradients based only on the 7 | output calculated by the EQL, not the expected output (no MSE or similar is calculated). 8 | - *files_from_fn* generates a data file containing training-, validation-, extrapolation- and metadata for a fn 9 | passed through the input in parameter dictionary. The python function has to be defined in data_utils.py. 10 | *files_from_fn* is also called when *data_utils.py* is run from command line with the parameter dictionary passed as 11 | a string. 12 | """ 13 | import gzip 14 | import os.path 15 | import pickle 16 | from ast import literal_eval 17 | from sys import argv 18 | 19 | import numpy as np 20 | import tensorflow as tf 21 | 22 | from utils import to_float32, number_of_positional_arguments 23 | 24 | """Equation 1-4 from the paper. Equation 5 describes the cart pendulum from the paper.""" 25 | 26 | 27 | def F1(x1, x2, x3, x4): 28 | """Requires 1 hidden layer.""" 29 | y0 = (np.sin(np.pi * x1) + np.sin(2 * np.pi * x2 + np.pi / 8.0) + x2 - x3 * x4) / 3.0 30 | return y0, 31 | 32 | 33 | def F2(x1, x2, x3, x4): 34 | """Requires 2 hidden layers.""" 35 | y0 = (np.sin(np.pi * x1) + x2 * np.cos(2 * np.pi * x1 + np.pi / 4.0) + x3 - x4 * x4) / 3.0 36 | return y0, 37 | 38 | 39 | def F3(x1, x2, x3, x4): 40 | """Requires 2 hidden layers.""" 41 | y0 = ((1.0 + x2) * np.sin(np.pi * x1) + x2 * x3 * x4) / 3.0 42 | return y0, 43 | 44 | 45 | def F4(x1, x2, x3, x4): 46 | """Requires 4 hidden layers.""" 47 | y0 = 0.5 * (np.sin(np.pi * x1) + np.cos(2.0 * x2 * np.sin(np.pi * x1)) + x2 * x3 * x4) 48 | return y0, 49 | 50 | 51 | def F5(x1, x2, x3, x4): 52 | """Equation for cart pendulum. Requires 4 hidden layers.""" 53 | y1 = x3 54 | y2 = x4 55 | y3 = (-x1 - 0.01 * x3 + x4 ** 2 * np.sin(x2) + 0.1 * x4 * np.cos(x2) + 9.81 * np.sin(x2) * np.cos(x2)) \ 56 | / (np.sin(x2) ** 2 + 1) 57 | y4 = -0.2 * x4 - 19.62 * np.sin(x2) + x1 * np.cos(x2) + 0.01 * x3 * np.cos(x2) - x4 ** 2 * np.sin(x2) * np.cos(x2) \ 58 | / (np.sin(x2) ** 2 + 1) 59 | return y1, y2, y3, y4, 60 | 61 | 62 | data_gen_params = {'file_name': 'F1data', # file name for the generated data file, will be created in data/file_name 63 | 'fn_to_learn': 'F1', # python function to learn, should be defined in data_utils 64 | 'train_val_examples': 10000, # total number of examples for training and validation 65 | 'train_val_bounds': (-1.0, 1.0), # domain boundaries for validation and training normal epochs 66 | 'test_examples': 5000, # number of test examples, if set to None no test_data file is created 67 | 'test_bounds': (-2.0, 2.0), # domain boundaries for test data 68 | 'noise': 0.01, 69 | 'seed': None 70 | } 71 | 72 | 73 | def generate_data(fn, num_examples, bounds, noise, seed=None): 74 | np.random.seed(seed) 75 | lower, upper = bounds 76 | input_dim = number_of_positional_arguments(fn) 77 | xs = np.random.uniform(lower, upper, (num_examples, input_dim)).astype(np.float32) 78 | xs_as_list = np.split(xs, input_dim, axis=1) 79 | ys = fn(*xs_as_list) 80 | ys = np.concatenate(ys, axis=1) 81 | ys += np.random.uniform(-noise, noise, ys.shape) 82 | return xs, ys 83 | 84 | 85 | def data_from_file(filename, split=None): 86 | """ 87 | Routine extracting data from given file. 88 | :param filename: path to the file data should be extracted from 89 | :param split: if split is not None, the data is split into two chunks, one of size split*num_examples and one of 90 | size (1-split)*num_examples. 
If it is None, all data is returned as one chunk 91 | :return: if split is not None list of data-chunks, otherwise all data as one chunk 92 | """ 93 | data = to_float32(pickle.load(gzip.open(filename, "rb"), encoding='latin1')) 94 | if split is not None: 95 | split_point = int(len(data[0]) * split) 96 | data = [np.split(dat, [split_point]) for dat in data] 97 | data = zip(*data) 98 | return data 99 | 100 | 101 | def input_from_data(data, batch_size, repeats): 102 | """ 103 | Function turning data into input for the network. Provides enough data for *repeats* epochs. 104 | :param data: numpy array of data 105 | :param batch_size: size of batch returned, only relevant for training regime 106 | :param repeats: integer factor determining how many times (epochs) data is reused 107 | :return: *repeats* times data split into inputs and labels in batches 108 | """ 109 | ds = tf.data.Dataset.from_tensor_slices(data).shuffle(buffer_size=1000).repeat(repeats).batch(batch_size) 110 | xs, ys = ds.make_one_shot_iterator().get_next() 111 | return xs, ys 112 | 113 | 114 | def get_penalty_data(num_examples, penalty_bounds, num_inputs, num_outputs): 115 | """ 116 | Function returning penalty data. In penalty epoch labels are irrelevant, therefore labels are set to zero. 117 | Only provides enough data to train for one epoch. 118 | :param num_examples: Total number of examples to be trained in penalty epoch. 119 | :param penalty_bounds: Boundaries to be used to generate penalty data, either a tuple or a list of tuples 120 | """ 121 | if isinstance(penalty_bounds, tuple): 122 | lower, upper = penalty_bounds 123 | else: 124 | lower, upper = zip(*penalty_bounds) 125 | xs = np.random.uniform(lower, upper, (num_examples, num_inputs)).astype(np.float32) 126 | ys = np.zeros((num_examples, num_outputs), dtype=np.float32) 127 | return xs, ys 128 | 129 | 130 | def get_input_fns(train_val_split, batch_size, train_val_file, test_file, penalty_every, num_inputs, num_outputs, 131 | train_val_examples, penalty_bounds, extracted_penalty_bounds, **_): 132 | """ 133 | Routine to determine which input function to use for training(normal or penalty epoch) / validation / testing. 134 | :param train_val_split: float specifying the data split, .8 means 80% of data is used for training, 20% for val 135 | :param batch_size: Size of batches used for training (both in normal and penalty epochs). 136 | :param train_val_file: Path to file containing training and validation data. 137 | :param test_file: Path to file containing test data. 138 | :param penalty_every: Integer specifying after how many normal epochs a penalty epoch occurs. 139 | :param num_inputs: number of input arguments 140 | :param num_outputs: number of outputs 141 | :param train_val_examples: number of examples to use for training and validation 142 | :param penalty_bounds: default domain boundaries used to generate penalty epoch training data. 
143 | :param extracted_penalty_bounds: domain boundaries for penalty data generation extracted from data files 144 | :return: functions returning train-, penalty_train-, validation- and (if provided in datafile) test-input 145 | if no extrapolation test data is provided test_input is None 146 | """ 147 | penalty_bounds = penalty_bounds or extracted_penalty_bounds 148 | train_data, val_data = data_from_file(train_val_file, split=train_val_split) 149 | penalty_data = get_penalty_data(num_examples=int(train_val_split * train_val_examples), 150 | penalty_bounds=penalty_bounds, num_inputs=num_inputs, num_outputs=num_outputs) 151 | train_input = lambda: input_from_data(data=train_data, batch_size=batch_size, repeats=penalty_every) 152 | val_input = lambda: input_from_data(data=val_data, batch_size=batch_size, repeats=1) 153 | penalty_input = lambda: input_from_data(data=penalty_data, batch_size=batch_size, repeats=1) 154 | if test_file is not None: 155 | test_data = data_from_file(test_file) 156 | test_input = lambda: input_from_data(data=test_data, batch_size=batch_size, repeats=1) 157 | else: 158 | test_input = None 159 | return train_input, penalty_input, val_input, test_input 160 | 161 | 162 | def extract_metadata(train_val_file, test_file, domain_bound_factor=2, res_bound_factor=10): 163 | """ 164 | Routine to extract additional information about data from data file. 165 | :param train_val_file: Path to training/validation data file 166 | :param test_file: Path to extrapolation data file 167 | :param domain_bound_factor: factor to scale the domain boundary of train/val data to get penalty data boundary 168 | :param res_bound_factor: factor to scale the maximum output of train/val data to get penalty data result boundary 169 | :return: metadata dict 170 | """ 171 | train_val_data = pickle.load(gzip.open(train_val_file, "rb"), encoding='latin1') 172 | train_val_examples = train_val_data[0].shape[0] 173 | num_inputs = train_val_data[0].shape[1] 174 | num_outputs = train_val_data[1].shape[1] 175 | extracted_output_bound = np.max(np.abs(train_val_data[1])) * res_bound_factor 176 | if test_file is not None: 177 | test_data = pickle.load(gzip.open(test_file, "rb"), encoding='latin1') 178 | extracted_penalty_bounds = zip(np.min(test_data[0], axis=0), np.max(test_data[0], axis=0)) 179 | else: 180 | extracted_penalty_bounds = zip(np.min(train_val_data[0], axis=0) * domain_bound_factor, 181 | np.max(train_val_data[0], axis=0) * domain_bound_factor) 182 | metadata = dict(train_val_examples=train_val_examples, num_inputs=num_inputs, num_outputs=num_outputs, 183 | extracted_output_bound=extracted_output_bound, extracted_penalty_bounds=extracted_penalty_bounds) 184 | return metadata 185 | 186 | 187 | def files_from_fn(file_name, fn_to_learn, train_val_examples, test_examples, train_val_bounds, 188 | test_bounds, noise, seed=None): 189 | """ 190 | Routine generating .gz file with train-, validation, test and meta-data from function. 191 | It is worth noting that that the function is saved as a string in metadata. 192 | :param file_name: Name of the data file to be created. It is being saved in the directory 'data'. 193 | :param fn_to_learn: string name of python function used to generate data. Should be defined in data_utils.py. 194 | :param train_val_examples: Total number of examples used for training and validation. 195 | :param train_val_bounds: Boundaries used to generate training and validation data. 196 | :param test_examples: Total number of examples used for testing. 
197 | :param test_bounds: Boundaries used to generate test data. 198 | """ 199 | fn_to_learn = globals()[fn_to_learn] 200 | if not os.path.exists('data'): 201 | os.mkdir('data') 202 | train_val_set = generate_data(fn=fn_to_learn, num_examples=train_val_examples, bounds=train_val_bounds, noise=noise, 203 | seed=seed) 204 | train_val_data_file = os.path.join('data', file_name + '_train_val') 205 | pickle.dump(train_val_set, gzip.open(train_val_data_file, "wb")) 206 | print('Successfully created train/val data file in %s.' % train_val_data_file) 207 | 208 | if test_examples is not None: 209 | test_set = generate_data(fn=fn_to_learn, num_examples=test_examples, bounds=test_bounds, noise=noise, seed=seed) 210 | test_data_file = os.path.join('data', file_name + '_test') 211 | pickle.dump(test_set, gzip.open(test_data_file, "wb")) 212 | print('Successfully created test data file in %s.' % test_data_file) 213 | 214 | 215 | if __name__ == '__main__': 216 | if len(argv) > 1: 217 | print('Updating default parameters.') 218 | data_gen_params.update(literal_eval(argv[1])) 219 | else: 220 | print('Using default parameters.') 221 | files_from_fn(**data_gen_params) 222 | -------------------------------------------------------------------------------- /evaluation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Module for symbolic manipulation with formulas and evaluation. 3 | - Implements *EvaluationHook* which is used to generate symbolic expressions of the current formula represented 4 | by the network structure and to calculate the complexity of the current network. 5 | - Generation of the symbolic expression mainly consists of *symbolic_eql_layer* and *symbolic_matmul_and_bias* 6 | routines, which perform the symbolic representation of the EQL fns and the matrix multiplication/bias addition. 7 | - Symbolic expressions are saved as pngs of a latex representation and of a rendered graphviz graph. 
8 | - The complexity calculation is performed in three steps: 9 | calculate_complexity -> complexity_of_layer -> complexity of node 10 | """ 11 | 12 | from functools import reduce 13 | from os import path 14 | 15 | import numpy as np 16 | import sympy 17 | import tensorflow as tf 18 | from graphviz import Source 19 | from sympy.printing.dot import dotprint 20 | from tensorflow.python.training.session_run_hook import SessionRunHook 21 | 22 | from timeout import time_limit, TimeoutException 23 | from utils import generate_arguments, yield_with_repeats, weight_name_for_i 24 | 25 | 26 | class EvaluationHook(SessionRunHook): 27 | """Hook for saving evaluating the eql.""" 28 | 29 | def __init__(self, list_of_vars, store_path=None): 30 | self.list_of_vars = list_of_vars 31 | self.weights = None 32 | self.store_path = store_path 33 | self.fns_list = None 34 | self.round_decimals = 3 35 | self.complexity = None 36 | 37 | def begin(self): 38 | self.iteration = 0 39 | 40 | def after_create_session(self, session, coord): 41 | pass 42 | 43 | def before_run(self, run_context): 44 | if self.iteration == 0: 45 | graph = tf.get_default_graph() 46 | tens = {v: graph.get_tensor_by_name(v) for v in self.list_of_vars} 47 | else: 48 | tens = {} 49 | return tf.train.SessionRunArgs(fetches=tens) 50 | 51 | def after_run(self, run_context, run_values): 52 | if self.iteration == 0: 53 | self.weights = run_values.results 54 | self.iteration += 1 55 | 56 | def end(self, session): 57 | if self.store_path is not None: 58 | if self.fns_list is None: 59 | raise ValueError("Network structure not provided. Call init_network_structure first.") 60 | kernels = [value for key, value in self.weights.items() if 'kernel' in key.lower()] 61 | biases = [value for key, value in self.weights.items() if 'bias' in key.lower()] 62 | self.complexity = calculate_complexity(kernels, biases, self.fns_list, self.thresh) 63 | if self.generate_symbolic_expr: 64 | save_symbolic_expression(kernels, biases, self.fns_list, self.store_path, self.round_decimals) 65 | 66 | def init_network_structure(self, model, params): 67 | self.fns_list = [layer.get_fns() for layer in model.eql_layers] 68 | self.thresh = params['complexity_threshold'] 69 | self.generate_symbolic_expr = params['generate_symbolic_expr'] 70 | 71 | def get_complexity(self): 72 | if self.complexity is not None: 73 | return self.complexity 74 | else: 75 | raise ValueError('Complexity not yet evaluated.') 76 | 77 | 78 | def set_evaluation_hook(num_h_layers, model_dir, **_): 79 | kernel_tensornames = [weight_name_for_i(i, 'kernel') for i in range(num_h_layers + 1)] 80 | bias_tensornames = [weight_name_for_i(i, 'bias') for i in range(num_h_layers + 1)] 81 | symbolic_hook = EvaluationHook([*kernel_tensornames, *bias_tensornames], store_path=model_dir) 82 | return symbolic_hook 83 | 84 | 85 | @time_limit(60) 86 | def proper_simplify(expr): 87 | """ Combine trig and normal simplification for sympy expression.""" 88 | return sympy.simplify(sympy.trigsimp(expr)) 89 | 90 | 91 | def symbolic_matmul_and_bias(input_nodes_symbolic, weight_matrix, bias_vector): 92 | """ Computes a symbolic representations of nodes in a layer after matrix mul of the previous layer. 
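    For illustration (example added, not part of the original docstring): with input_nodes_symbolic = [x1, x2],
    weight_matrix = [[1, 0], [0, 2]] and bias_vector = [0.5, 0], the returned expressions are [x1 + 0.5, 2*x2].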
93 | :param input_nodes_symbolic: list of sympy expressions 94 | :param weight_matrix: 2D numpy array of shape (input_dim, output_dim) 95 | :param bias_vector: 1D numpy array of shape (output_dim) 96 | :return: list of sympy expressions at output nodes of length (output_dim) 97 | """ 98 | 99 | def output_for_index(i): 100 | return bias_vector[i] + sum([w * x for w, x in zip(weight_matrix[:, i], input_nodes_symbolic)]) 101 | 102 | return [output_for_index(i) for i in range(weight_matrix.shape[1])] 103 | 104 | 105 | def symbolic_eql_layer(input_nodes_symbolic, output_fn_group_list): 106 | """ Computes a symbolic representation of a node given incoming weights and the output fn. 107 | :param input_nodes_symbolic: list of sympy expressions 108 | :param output_fn_group_list: list of (sympy function, repeats) tuples to be applied to input nodes. 109 | :return: list of sympy expressions at output nodes 110 | """ 111 | _, output_fns, repeats, arg_nums = zip(*output_fn_group_list) 112 | arg_iterator = generate_arguments(input_nodes_symbolic, repeats, arg_nums) 113 | fn_iterator = yield_with_repeats(output_fns, repeats) 114 | return [fn(*items) for fn, items in zip(fn_iterator, arg_iterator)] 115 | 116 | 117 | def get_symbol_list(number_of_symbols): 118 | """ Returns a list of sympy expression, each being an identity of a variable. To be used for input layer.""" 119 | return sympy.symbols(['x_{}'.format(i + 1) for i in range(number_of_symbols)], real=True) 120 | 121 | 122 | def expression_graph_as_png(expr, output_file, view=True): 123 | """ Save a PNG of rendered graph (graphviz) of the symbolic expression. 124 | :param expr: sympy expression 125 | :param output_file: string with .png extension 126 | :param view: set to True if system default PNG viewer should pop up 127 | :return: None 128 | """ 129 | assert output_file.endswith('.png') 130 | graph = Source(dotprint(expr)) 131 | graph.format = 'png' 132 | graph.render(output_file.rpartition('.png')[0], view=view, cleanup=True) 133 | 134 | 135 | def expr_to_latex_png(expr, output_file): 136 | """Saves a png of a latex representation of a symbolic expression.""" 137 | sympy.preview(expr, viewer='file', filename=output_file) 138 | 139 | 140 | def expr_to_latex(expr): 141 | """Returns latex representation (as string) of a symbolic expression.""" 142 | return sympy.latex(expr) 143 | 144 | 145 | def round_sympy_expr(expr, decimals): 146 | """Returns the expression with every float rounded to the given number of decimals.""" 147 | rounded_expr = expr 148 | for a in sympy.preorder_traversal(expr): 149 | if isinstance(a, sympy.Float): 150 | rounded_expr = rounded_expr.subs(a, round(a, decimals)) 151 | return rounded_expr 152 | 153 | 154 | def save_symbolic_expression(kernels, biases, fns_list, save_path, round_decimals): 155 | """ 156 | Saves a symbolic expression of network as pngs showing the equation as a tree and as a latex equation. 
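    For a network with two outputs this writes latex_y0.png, graph_y0.png, latex_y1.png and graph_y1.png to save_path.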
157 |     :param kernels: list of 2D numpy arrays
158 |     :param biases: list of 1D numpy arrays
159 |     :param fns_list: list of lists of (tf_fn, sp_fn, repeats, num_args) tuples
160 |     :param save_path: path (str) for saving the symbolic expressions
161 |     :param round_decimals: integer specifying to which decimal the expression is rounded
162 |     """
163 |     in_nodes = get_symbol_list(kernels[0].shape[0])
164 |     res = in_nodes
165 |     for kernel, bias, fns in zip(kernels, biases, fns_list):
166 |         res = symbolic_matmul_and_bias(res, kernel, bias)
167 |         res = symbolic_eql_layer(res, fns)
168 |     for i, result in enumerate(res):
169 |         result = round_sympy_expr(result, round_decimals)
170 |         try:
171 |             result = proper_simplify(result)
172 |         except (TimeoutException, RecursionError):
173 |             print('Simplification of result y%i failed. Saving representations of non-simplified formula.' % i)
174 |         expr_to_latex_png(result, path.join(save_path, 'latex_y' + str(i) + '.png'))
175 |         expression_graph_as_png(result, path.join(save_path, 'graph_y' + str(i) + '.png'), view=False)
176 | 
177 | 
178 | def calculate_complexity(kernels, biases, fns_list, thresh):
179 |     """
180 |     Routine that counts units with nonzero input * output weights (only non-identity units).
181 |     :param kernels: list of numpy matrices
182 |     :param biases: list of numpy arrays
183 |     :param fns_list: list of lists containing (tf_fn, sp_fn, repeats, arg_num) tuples
184 |     :param thresh: threshold to determine how active a node has to be to be considered active in the calculation
185 |     :return: complexity (number of active nodes) of network
186 |     """
187 |     complexities = [
188 |         complexity_of_layer(fns=fns, in_biases=in_biases, in_weights=in_weights, out_weights=out_weights, thresh=thresh)
189 |         for fns, in_biases, in_weights, out_weights in zip(fns_list, biases[:-1], kernels[:-1], kernels[1:])]
190 |     complexity = sum(complexities)
191 |     return complexity
192 | 
193 | 
194 | def complexity_of_layer(fns, in_biases, in_weights, out_weights, thresh):
195 |     """
196 |     Routine that returns the complexity (number of active nodes) of a given layer.
197 |     :param fns: list of (tf_fn, sp_fn, repeats, arg_num) tuples, one for each fn block in layer
198 |     :param in_biases: numpy array describing the biases added to inputs for this layer
199 |     :param in_weights: numpy matrix describing the weights between the previous layer and this layer
200 |     :param out_weights: numpy matrix describing the weights between this layer and the next layer
201 |     :param thresh: threshold to determine how active a node has to be to be considered active in the calculation
202 |     :return: complexity (number of active nodes) of a given layer
203 |     """
204 |     in_weight_sum = np.sum(np.abs(in_weights), axis=0) + in_biases  # adding up all abs weights contributing to input
205 |     out_weight_sum = np.sum(np.abs(out_weights), axis=1)  # adding up all abs weights that use specific output
206 |     output_fns, _, repeats, arg_nums = zip(*fns)
207 |     input_iterator = generate_arguments(all_args=in_weight_sum, repeats=repeats, arg_nums=arg_nums)
208 |     fn_iterator = yield_with_repeats(output_fns, repeats)
209 |     count = sum([complexity_of_node(out_weight, in_weights, fn, thresh)
210 |                  for out_weight, in_weights, fn in zip(out_weight_sum, input_iterator, fn_iterator)])
211 |     return count
212 | 
213 | 
214 | def complexity_of_node(out_weight, in_weights, fn, thresh):
215 |     """
216 |     Routine that returns the complexity of a node.
217 | :param out_weight: float output weight of node 218 | :param in_weights: tuple of input weights for node 219 | :param fn: tensorflow function used in this node 220 | :param thresh: threshold to determine how active a node has to be to be considered active in the calculation 221 | :return: 1 if node is active and 0 if inactive 222 | """ 223 | if fn == tf.identity: 224 | return 0 225 | all_weights = [out_weight, *in_weights] 226 | weight_product = reduce(lambda x, y: x * y, all_weights) 227 | # Note that multiplication units can also be linear units if one of their inputs is constant 228 | # we only count the nodes with multiple inputs if both inputs are bigger than a threshold and the product of 229 | # the output weight with the sum of the input weights is bigger than the squared threshold 230 | if all(weight > thresh for weight in all_weights) and (weight_product > thresh ** len(all_weights)): 231 | return 1 232 | else: 233 | return 0 234 | -------------------------------------------------------------------------------- /example_data/F1data_test: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martius-lab/EQL_Tensorflow/95f6de7e9e4494fd838fbabf3622f3d76623fe2f/example_data/F1data_test -------------------------------------------------------------------------------- /example_data/F1data_train_val: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martius-lab/EQL_Tensorflow/95f6de7e9e4494fd838fbabf3622f3d76623fe2f/example_data/F1data_train_val -------------------------------------------------------------------------------- /example_data/F2data_test: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martius-lab/EQL_Tensorflow/95f6de7e9e4494fd838fbabf3622f3d76623fe2f/example_data/F2data_test -------------------------------------------------------------------------------- /example_data/F2data_train_val: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martius-lab/EQL_Tensorflow/95f6de7e9e4494fd838fbabf3622f3d76623fe2f/example_data/F2data_train_val -------------------------------------------------------------------------------- /example_data/F3data_test: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martius-lab/EQL_Tensorflow/95f6de7e9e4494fd838fbabf3622f3d76623fe2f/example_data/F3data_test -------------------------------------------------------------------------------- /example_data/F3data_train_val: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martius-lab/EQL_Tensorflow/95f6de7e9e4494fd838fbabf3622f3d76623fe2f/example_data/F3data_train_val -------------------------------------------------------------------------------- /example_data/F4data_test: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martius-lab/EQL_Tensorflow/95f6de7e9e4494fd838fbabf3622f3d76623fe2f/example_data/F4data_test -------------------------------------------------------------------------------- /example_data/F4data_train_val: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martius-lab/EQL_Tensorflow/95f6de7e9e4494fd838fbabf3622f3d76623fe2f/example_data/F4data_train_val 
-------------------------------------------------------------------------------- /example_data/F5data_test: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martius-lab/EQL_Tensorflow/95f6de7e9e4494fd838fbabf3622f3d76623fe2f/example_data/F5data_test -------------------------------------------------------------------------------- /example_data/F5data_train_val: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martius-lab/EQL_Tensorflow/95f6de7e9e4494fd838fbabf3622f3d76623fe2f/example_data/F5data_train_val -------------------------------------------------------------------------------- /example_results/F1/graph0_y1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martius-lab/EQL_Tensorflow/95f6de7e9e4494fd838fbabf3622f3d76623fe2f/example_results/F1/graph0_y1.png -------------------------------------------------------------------------------- /example_results/F1/latex0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martius-lab/EQL_Tensorflow/95f6de7e9e4494fd838fbabf3622f3d76623fe2f/example_results/F1/latex0.png -------------------------------------------------------------------------------- /example_results/F1/parameters.json: -------------------------------------------------------------------------------- 1 | { 2 | "batch_size": 20, 3 | "beta1": 0.4, 4 | "bound": 10, 5 | "complexity_threshold": 0.01, 6 | "epoch_factor": 1000, 7 | "generate_symbolic_expr": true, 8 | "id": 1, 9 | "kill_summaries": false, 10 | "l0_threshold": 0.008, 11 | "layer_width": 10, 12 | "learning_rate": 0.0005, 13 | "model_base_dir": "results", 14 | "model_dir": "results\\1", 15 | "network_init_seed": null, 16 | "num_h_layers": 1, 17 | "penalty_bounds": null, 18 | "penalty_every": 50, 19 | "reg_scale": 1e-05, 20 | "reg_sched": [ 21 | 0.25, 22 | 0.95 23 | ], 24 | "test_div_threshold": 0.0001, 25 | "test_file": "data/F1data_test", 26 | "train_val_file": "data/F1data_train_val", 27 | "train_val_split": 0.9, 28 | "weight_init_param": 1.0 29 | } -------------------------------------------------------------------------------- /example_results/F1/results.csv: -------------------------------------------------------------------------------- 1 | val_error,complexity,extr_error,id 2 | 0.0068555544,3,0.0076549686,1 3 | -------------------------------------------------------------------------------- /example_results/F2/graph_y0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martius-lab/EQL_Tensorflow/95f6de7e9e4494fd838fbabf3622f3d76623fe2f/example_results/F2/graph_y0.png -------------------------------------------------------------------------------- /example_results/F2/latex_y0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martius-lab/EQL_Tensorflow/95f6de7e9e4494fd838fbabf3622f3d76623fe2f/example_results/F2/latex_y0.png -------------------------------------------------------------------------------- /example_results/F2/parameters.json: -------------------------------------------------------------------------------- 1 | { 2 | "batch_size": 20, 3 | "beta1": 0.4, 4 | "complexity_threshold": 0.01, 5 | "epoch_factor": 1000, 6 | "generate_symbolic_expr": true, 7 | "id": 1000, 8 | "kill_summaries": false, 9 | 
"l0_threshold": 0.05, 10 | "layer_width": 10, 11 | "learning_rate": 0.0005, 12 | "model_base_dir": "results", 13 | "model_dir": "results\\1000", 14 | "network_init_seed": null, 15 | "num_h_layers": 2, 16 | "output_bound": null, 17 | "penalty_bounds": null, 18 | "penalty_every": 50, 19 | "reg_scale": 1e-05, 20 | "reg_sched": [ 21 | 0.25, 22 | 0.95 23 | ], 24 | "test_div_threshold": 0.0001, 25 | "test_file": "data/F2data_test", 26 | "train_val_file": "data/F2data_train_val", 27 | "train_val_split": 0.9, 28 | "weight_init_param": 1.0 29 | } -------------------------------------------------------------------------------- /example_results/F2/results.csv: -------------------------------------------------------------------------------- 1 | val_error,complexity,extr_error,id 2 | 0.005967019,5,0.01763588,1000 3 | -------------------------------------------------------------------------------- /model_selection.py: -------------------------------------------------------------------------------- 1 | """ 2 | Aggregation and scoring of multiple results. 3 | - *aggregate_csv_files_recursively* collects the results in a given directory in a pandas dataframe. 4 | - *select_instance* selects the best performing model instance. It expects a pandas dataframe or a filename of a 5 | file containing the results of each model instance and selects the best instance based on validation- 6 | and extrapolation-performance or validation-performance and complexity, depending on the availability of extrapolation data. 7 | - Running *model_selection.py* executes both of these routines. 8 | """ 9 | from ast import literal_eval 10 | from os import path, walk 11 | from sys import argv 12 | 13 | import numpy as np 14 | import pandas as pd 15 | 16 | 17 | def select_instance(df=None, file=None, use_extrapolation=True): 18 | """ 19 | Expects a file with one row per network and columns reporting the parameters and complexity and performance 20 | First line should be the column names, col1 col2 col3..., then one additional comments line which can be empty. 21 | Third line should be the values for each column. 22 | :param df: pandas dataframe containing data about model performance 23 | :param file: file containing data about model performance, only used if dataframe is none 24 | :param use_extrapolation: flag to determine if extrapolation data should be used 25 | :return: pandas dataframe containing id and performance data of best model. 26 | """ 27 | if df is not None and file is not None: 28 | raise ValueError('Both results_df and file specified. 
Only specify one.') 29 | if df is None: 30 | if file is None: 31 | raise ValueError('Either results_df or file have to be specified.') 32 | df = pd.read_csv(file) 33 | if 'extr_error' in df.keys(): 34 | extr_available = not df['extr_error'].isnull().values.any() 35 | else: 36 | extr_available = False 37 | if use_extrapolation and not extr_available: 38 | raise ValueError("use_extrapolation flag is set to True but no extrapolation results were found.") 39 | 40 | if use_extrapolation: 41 | df['extr_normed'] = normalize_to_zero_one(df['extr_error']) 42 | df['val_normed'] = normalize_to_zero_one(df['val_error']) 43 | df['complexity_normed'] = normalize_to_zero_one(df['complexity'], defensive=False) 44 | 45 | if use_extrapolation: 46 | print('Extrapolation data used.') 47 | df['score'] = np.sqrt(df['extr_normed'] ** 2 + df['val_normed'] ** 2) 48 | else: 49 | print('No extrapolation data used, performing model selection based on complexity and validation instead.') 50 | df['score'] = np.sqrt(df['complexity_normed'] ** 2 + df['val_normed'] ** 2) 51 | 52 | scored_df = df.sort_values(['score']) 53 | best_instance = scored_df.iloc[[0]] 54 | return best_instance, scored_df 55 | 56 | 57 | def normalize_to_zero_one(arr, defensive=True): 58 | """ 59 | Routine that normalizes an array to zero and one. 60 | :param arr: array to be normalized 61 | :param defensive: flag to determine if behavior is defensive (if all array elements are the same raise exception) 62 | or not (if all array elements are the same return an array of same length filled with zeros) 63 | """ 64 | if np.isclose(np.max(arr), np.min(arr)): 65 | if defensive: 66 | raise ValueError('All elements in array are the same, no normalization possible.') 67 | else: 68 | return np.zeros(len(arr)) 69 | norm_arr = (arr - np.min(arr)) / (np.max(arr) - np.min(arr)) 70 | return norm_arr 71 | 72 | 73 | def aggregate_csv_files_recursively(directory, filename): 74 | """ Returns a pandas DF that is a concatenation of csvs with given filename in given directory (recursively).""" 75 | return pd.concat(_df_from_csv_recursive_generator(directory, filename)) 76 | 77 | 78 | def _df_from_csv_recursive_generator(directory, filename): 79 | """ Returns a generator producing pandas DF for each csv with given filename in given directory (recursively).""" 80 | for root, dirs, files in walk(directory): 81 | if filename in files: 82 | yield pd.read_csv(path.join(root, filename)) 83 | 84 | 85 | if __name__ == '__main__': 86 | if len(argv) > 1: 87 | passed_dict = literal_eval(argv[1]) 88 | results_path = passed_dict['results_path'] 89 | use_extrapolation = passed_dict['use_extrapolation'] 90 | else: 91 | raise ValueError('Path to results directory must be passed.') 92 | aggregated_results = aggregate_csv_files_recursively(results_path, "results.csv") 93 | best_instance, ordered_instances = select_instance(df=aggregated_results, use_extrapolation=use_extrapolation) 94 | ordered_instances.to_csv(path.join(results_path, "scored_results.csv")) 95 | print('All instances, ordered by score:\n', ordered_instances) 96 | print('Selected model instance:\n', best_instance) 97 | -------------------------------------------------------------------------------- /timeout.py: -------------------------------------------------------------------------------- 1 | import signal 2 | import threading 3 | from contextlib import contextmanager 4 | 5 | from six.moves import _thread 6 | 7 | 8 | class TimeoutException(Exception): 9 | pass 10 | 11 | 12 | @contextmanager 13 | def time_limit(seconds): 14 
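    """Context manager that aborts the wrapped block after *seconds* by raising TimeoutException:
    via SIGALRM on Unix, or via a timer thread that interrupts the main thread on Windows."""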
| if hasattr(signal, "SIGALRM"): 15 | # for Linux 16 | def signal_handler(signum, frame): 17 | raise TimeoutException("Timeout after {} seconds.".format(seconds)) 18 | 19 | signal.signal(signal.SIGALRM, signal_handler) 20 | signal.alarm(seconds) 21 | try: 22 | yield 23 | finally: 24 | signal.alarm(0) 25 | else: 26 | # for Windows 27 | timer = threading.Timer(seconds, lambda: _thread.interrupt_main()) 28 | timer.start() 29 | try: 30 | yield 31 | except KeyboardInterrupt: 32 | raise TimeoutException("Timeout after {} seconds.".format(seconds)) 33 | finally: 34 | timer.cancel() 35 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ Neural Network Estimator for EQL - Equation Learner """ 2 | import math 3 | import sys 4 | from collections import namedtuple 5 | 6 | import tensorflow as tf 7 | 8 | import EQL_Layer_tf as eql 9 | from data_utils import get_input_fns, extract_metadata 10 | from evaluation import set_evaluation_hook 11 | from utils import step_to_epochs, get_run_config, save_results, update_runtime_params, \ 12 | get_div_thresh_fn, get_max_episode 13 | 14 | # more network parameters are loaded from utils.py 15 | default_params = {'model_base_dir': 'results', 16 | 'id': 1, # job_id to identify jobs in result metrics file, separate model_dir for each id 17 | 'train_val_file': 'data/F1data_train_val', # Datafile containing training, validation data 18 | 'test_file': 'data/F1data_test', # Datafile containing test data, if set to None no test data is used 19 | 'epoch_factor': 1000, # max_epochs = epoch_factor * num_h_layers 20 | 'num_h_layers': 1, # number of hidden layers used in network 21 | 'generate_symbolic_expr': True, # saves final network as a latex png and symbolic graph 22 | 'kill_summaries': False, # reduces data generation, recommended when creating many jobs 23 | } 24 | 25 | 26 | class Model(object): 27 | """ Class that defines a graph for EQL. 
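    The graph consists of *num_h_layers* hidden EQL layers, each holding *layer_width* sine, cosine, multiplication
    and identity units, followed by a final layer of regularized division units (one per output). Weights are
    L1-regularized between the reg_start and reg_end epochs; after reg_end, weights with magnitude below
    l0_threshold are clamped to zero.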
""" 28 | 29 | def __init__(self, mode, layer_width, num_h_layers, reg_sched, output_bound, weight_init_param, epoch_factor, 30 | batch_size, test_div_threshold, reg_scale, l0_threshold, train_val_split, network_init_seed=None, **_): 31 | self.train_data_size = int(train_val_split * metadata['train_val_examples']) 32 | self.width = layer_width 33 | self.num_h_layers = num_h_layers 34 | self.weight_init_scale = weight_init_param / math.sqrt(metadata['num_inputs'] + num_h_layers) 35 | self.seed = network_init_seed 36 | self.reg_start = math.floor(num_h_layers * epoch_factor * reg_sched[0]) 37 | self.reg_end = math.floor(num_h_layers * epoch_factor * reg_sched[1]) 38 | self.output_bound = output_bound or metadata['extracted_output_bound'] 39 | self.reg_scale = reg_scale 40 | self.batch_size = batch_size 41 | self.l0_threshold = l0_threshold 42 | self.is_training = (mode == tf.estimator.ModeKeys.TRAIN) 43 | div_thresh_fn = get_div_thresh_fn(self.is_training, self.batch_size, test_div_threshold, 44 | train_examples=self.train_data_size) 45 | reg_div = namedtuple('reg_div', ['repeats', 'div_thresh_fn']) 46 | self.eql_layers = [eql.EQL_Layer(sin=self.width, cos=self.width, multiply=self.width, id=self.width, 47 | weight_init_scale=self.weight_init_scale, seed=self.seed) 48 | for _ in range(self.num_h_layers)] 49 | self.eql_layers.append( 50 | eql.EQL_Layer(reg_div=reg_div(repeats=metadata['num_outputs'], div_thresh_fn=div_thresh_fn), 51 | weight_init_scale=self.weight_init_scale, seed=self.seed)) 52 | 53 | def __call__(self, inputs): 54 | global_step = tf.train.get_or_create_global_step() 55 | num_epochs = step_to_epochs(global_step, self.batch_size, self.train_data_size) 56 | l1_reg_sched = tf.multiply(tf.cast(tf.less(num_epochs, self.reg_end), tf.float32), 57 | tf.cast(tf.greater(num_epochs, self.reg_start), tf.float32)) * self.reg_scale 58 | l0_threshold = tf.cond(tf.less(num_epochs, self.reg_end), lambda: tf.zeros(1), lambda: self.l0_threshold) 59 | 60 | output = inputs 61 | for layer in self.eql_layers: 62 | output = layer(output, l1_reg_sched=l1_reg_sched, l0_threshold=l0_threshold) 63 | 64 | P_bound = (tf.abs(output) - self.output_bound) * tf.cast((tf.abs(output) > self.output_bound), dtype=tf.float32) 65 | tf.add_to_collection('Bound_penalties', P_bound) 66 | return output 67 | 68 | 69 | def model_fn(features, labels, mode, params): 70 | """ The model_fn argument for creating an Estimator. 
""" 71 | model = Model(mode=mode, **params) 72 | evaluation_hook.init_network_structure(model, params) 73 | global_step = tf.train.get_or_create_global_step() 74 | input_data = features 75 | predictions = model(input_data) 76 | if mode == tf.estimator.ModeKeys.TRAIN: 77 | reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES) 78 | reg_loss = tf.reduce_sum([tf.reduce_mean(reg_loss) for reg_loss in reg_losses], name='reg_loss_mean_sum') 79 | bound_penalty = tf.reduce_sum(tf.get_collection('Bound_penalties')) 80 | P_theta = tf.reduce_sum(tf.get_collection('Threshold_penalties')) 81 | penalty_loss = P_theta + bound_penalty 82 | mse_loss = tf.losses.mean_squared_error(labels, predictions) 83 | normal_loss = tf.losses.get_total_loss() + P_theta 84 | loss = penalty_loss if penalty_flag else normal_loss 85 | train_accuracy = tf.identity( 86 | tf.metrics.percentage_below(values=tf.abs(labels - predictions), threshold=0.02)[1], name='train_accuracy') 87 | tf.summary.scalar('total_loss', loss, family='losses') 88 | tf.summary.scalar('MSE_loss', mse_loss, family='losses') # inaccurate for penalty epochs (ignore) 89 | tf.summary.scalar('Penalty_Loss', penalty_loss, family='losses') 90 | tf.summary.scalar("Regularization_loss", reg_loss, family='losses') 91 | tf.summary.scalar('train_acc', train_accuracy, family='accuracies') # inaccurate for penalty epochs (ignore) 92 | return tf.estimator.EstimatorSpec( 93 | mode=tf.estimator.ModeKeys.TRAIN, loss=loss, 94 | train_op=tf.train.AdamOptimizer(params['learning_rate'], beta1=params['beta1']).minimize(loss, global_step)) 95 | if mode == tf.estimator.ModeKeys.EVAL: 96 | loss = tf.sqrt(tf.losses.mean_squared_error(labels, predictions)) 97 | eval_acc_metric = tf.metrics.percentage_below(values=tf.abs(labels - predictions), threshold=0.02) 98 | return tf.estimator.EstimatorSpec(mode=tf.estimator.ModeKeys.EVAL, loss=loss, 99 | eval_metric_ops={'eval_accuracy': eval_acc_metric}) 100 | 101 | 102 | if __name__ == '__main__': 103 | tf.logging.set_verbosity(tf.logging.INFO) 104 | runtime_params = update_runtime_params(sys.argv, default_params) 105 | metadata = extract_metadata(runtime_params['train_val_file'], runtime_params['test_file']) 106 | run_config = get_run_config(runtime_params['kill_summaries']) 107 | eqlearner = tf.estimator.Estimator(model_fn=model_fn, config=run_config, model_dir=runtime_params['model_dir'], 108 | params=runtime_params) 109 | logging_hook = tf.train.LoggingTensorHook(tensors={'train_accuracy': 'train_accuracy'}, every_n_iter=1000) 110 | evaluation_hook = set_evaluation_hook(**runtime_params) 111 | max_episode = get_max_episode(**runtime_params) 112 | 113 | train_input, penalty_train_input, val_input, test_input = get_input_fns(**runtime_params, **metadata) 114 | print('One train episode equals %d normal epochs and 1 penalty epoch.' % runtime_params['penalty_every']) 115 | for train_episode in range(1, max_episode + 1): 116 | print('Train episode: %d out of %d.' % (train_episode, max_episode)) 117 | penalty_flag = True 118 | eqlearner.train(input_fn=penalty_train_input) 119 | penalty_flag = False 120 | eqlearner.train(input_fn=train_input, hooks=[logging_hook]) 121 | print('Training complete. 
122 |     val_results = eqlearner.evaluate(input_fn=val_input, name='validation', hooks=[evaluation_hook])
123 |     results = dict(val_error=val_results['loss'], complexity=evaluation_hook.get_complexity())
124 |     if test_input is not None:  # test_input function is only provided if extrapolation data is given
125 |         extr_results = eqlearner.evaluate(input_fn=test_input, name='extrapolation')
126 |         results['extr_error'] = extr_results['loss']
127 |     save_results(results, runtime_params)
128 |     print('Model evaluated. Results:\n', results)
129 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
  1 | """ Useful Routines used in EQL. """
  2 | import csv
  3 | import inspect
  4 | import json
  5 | from ast import literal_eval
  6 | from itertools import accumulate
  7 | from os import path
  8 | 
  9 | import numpy as np
 10 | import tensorflow as tf
 11 | 
 12 | # The following parameters should not be changed in most cases.
 13 | network_parameters = {'train_val_split': .9, # how data in train_val_file is split, .9 means 90% train 10% validation
 14 |                       'layer_width': 10, # number of identical nodes per hidden layer
 15 |                       'batch_size': 20, # size of data batches used for training
 16 |                       'learning_rate': 5e-4,
 17 |                       'beta1': .4,
 18 |                       'l0_threshold': .05, # threshold for regularization, see paper: chapter 2.3 Reg Phases
 19 |                       'reg_scale': 1e-5,
 20 |                       'reg_sched': (.25, .95), # (reg_start, reg_end)
 21 |                       'output_bound': None, # output boundary for penalty epochs, if set to None it is calculated
 22 |                                             # from training/validation data
 23 |                       'weight_init_param': 1.,
 24 |                       'test_div_threshold': 1e-4, # threshold for denominator in division layer used when testing
 25 |                       'complexity_threshold': 0.01, # determines how small a weight has to be to be considered inactive
 26 |                       'penalty_every': 50, # feed in penalty data for training and evaluate after every n epochs
 27 |                       'penalty_bounds': None, # domain boundaries for generating penalty data, if None it is calculated
 28 |                                               # from extrapolation_data (if provided) or training/validation data
 29 |                       'network_init_seed': None, # seed for initializing weights in network
 30 |                       }
 31 | 
 32 | 
 33 | def update_runtime_params(argv, params):
 34 |     """Routine to update the default parameters with network_parameters and parameters from the command line."""
 35 |     params.update(network_parameters)
 36 |     if len(argv) > 1:
 37 |         params.update(literal_eval(argv[1]))
 38 |     params['model_dir'] = path.join(params['model_base_dir'], str(params['id']))
 39 |     return params
 40 | 
 41 | 
 42 | def get_max_episode(num_h_layers, epoch_factor, penalty_every, **_):
 43 |     """Routine to calculate the total number of training episodes
 44 |     (1 episode = 1 penalty epoch + *penalty_every* normal epochs)."""
 45 |     max_episode = (num_h_layers * epoch_factor) // penalty_every
 46 |     if max_episode == 0:
 47 |         raise ValueError('Penalty_every has to be smaller than the total number of epochs.')
 48 |     return max_episode
 49 | 
 50 | 
 51 | def step_to_epochs(global_step, batch_size, train_examples, **_):
 52 |     epoch = tf.div(global_step, int(train_examples / batch_size)) + 1
 53 |     return epoch
 54 | 
 55 | 
 56 | def to_float32(list_of_arrays):
 57 |     return tuple([arr.astype(np.float32) for arr in list_of_arrays])
 58 | 
 59 | 
 60 | def number_of_positional_arguments(fn):
 61 |     params = [value.default for key, value in inspect.signature(fn).parameters.items()]
 62 |     return sum(1 for item in params if item == inspect.Parameter.empty)
 63 | 
 64 | 
 65 | def get_run_config(kill_summaries):
 66 |     """
 67 |     Creates run config for Estimator.
 68 |     :param kill_summaries: Boolean flag, if set to True run_config prevents creating too many checkpoint files.
 69 |     """
 70 |     if kill_summaries:
 71 |         checkpoint_args = dict(save_summary_steps=None, save_checkpoints_secs=None, save_checkpoints_steps=1e8)
 72 |     else:
 73 |         checkpoint_args = dict(save_summary_steps=1000)
 74 |     session_conf = tf.ConfigProto(intra_op_parallelism_threads=1)
 75 |     run_config = tf.estimator.RunConfig().replace(log_step_count_steps=1000, keep_checkpoint_max=1,
 76 |                                                   session_config=session_conf, **checkpoint_args)
 77 |     return run_config
 78 | 
 79 | 
 80 | def weight_name_for_i(i, weight_type):
 81 |     if i == 0:
 82 |         return 'dense/{}:0'.format(weight_type)
 83 |     return 'dense_{}/{}:0'.format(i, weight_type)
 84 | 
 85 | 
 86 | def save_results(results, params):
 87 |     """
 88 |     Routine that saves the results as a csv file.
 89 |     :param results: dictionary containing evaluation results
 90 |     :param params: dict of runtime parameters
 91 |     """
 92 |     results['id'] = params['id']
 93 |     results_file = path.join(params['model_dir'], 'results.csv')
 94 |     with open(path.join(params['model_dir'], 'parameters.json'), 'w') as f:
 95 |         json.dump(params, f, sort_keys=True, indent=4)
 96 |     save_dict_as_csv(results, results_file)
 97 | 
 98 | 
 99 | def save_dict_as_csv(dict_to_save, file_path):
100 |     with open(file_path, 'w') as f:
101 |         writer = csv.DictWriter(f, fieldnames=dict_to_save.keys())
102 |         writer.writeheader()
103 |         writer.writerow(dict_to_save)
104 | 
105 | 
106 | def yield_with_repeats(iterable, repeats):
107 |     """ Yield the ith item in iterable repeats[i] times. """
108 |     it = iter(iterable)
109 |     for num in repeats:
110 |         new_val = next(it)
111 |         for i in range(num):
112 |             yield new_val
113 | 
114 | 
115 | def yield_equal_chunks(l, n):
116 |     """Yield successive n-sized chunks from l."""
117 |     for i in range(0, len(l), n):
118 |         yield l[i:i + n]
119 | 
120 | 
121 | def iter_by_chunks(lst, chunk_lens):
122 |     """ Split list into groups of given size and return an iterator of the groups.
123 |     Example iter_by_chunks([1, 2, 3, 4], [1, 0, 0, 2, 1]) = ([1], [], [], [2, 3], [4]).
124 |     :param lst: a list
125 |     :param chunk_lens: a list specifying lengths of individual chunks
126 |     :return: a generator object yielding one chunk at a time
127 |     """
128 |     splits = [0] + list(accumulate(chunk_lens))
129 |     for beg, end in zip(splits[:-1], splits[1:]):
130 |         yield lst[beg:end]
131 | 
132 | 
133 | def generate_arguments(all_args, repeats, arg_nums):
134 |     """
135 |     Split all args into chunks for functions. Example:
136 |     generate_arguments([0,1,2,3,4,5,6,7,8,9,10], [1, 3, 1], [2, 2, 3]) -> [(0,1), (2,5), (3,6), (4,7), (8,9,10)]
137 |     :param all_args: list of all arguments
138 |     :param repeats: list of number of repeats for each function group
139 |     :param arg_nums: list of number of inputs for each function group
140 |     :return a generator object yielding one chunk at a time
141 |     """
142 |     lengths = (a * b for a, b in zip(repeats, arg_nums))
143 |     all_chunks = iter_by_chunks(all_args, lengths)
144 |     for big_chunk, repeat in zip(all_chunks, repeats):
145 |         yield from zip(*yield_equal_chunks(big_chunk, repeat))
146 | 
147 | 
148 | def get_div_thresh_fn(is_training, batch_size, test_div_threshold, train_examples, **_):
149 |     """
150 |     Returns function to calculate the division threshold from a given step.
151 |     :param is_training: Boolean to decide if training threshold or test threshold is used.
152 |     """
153 |     if is_training:
154 |         def get_div_thresh(step):
155 |             epoch = step_to_epochs(global_step=step, batch_size=batch_size, train_examples=train_examples)
156 |             return 1. / tf.sqrt(tf.cast(epoch, dtype=tf.float32))
157 |     else:
158 |         def get_div_thresh(step):
159 |             return test_div_threshold
160 |     return get_div_thresh
161 | 
--------------------------------------------------------------------------------
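
The division threshold returned by get_div_thresh_fn decays as 1 / sqrt(epoch) during training, with one epoch corresponding to train_examples // batch_size steps (see step_to_epochs), while evaluation uses the constant test_div_threshold. The following stand-alone sketch mirrors that schedule in plain Python; the helper name div_thresh_schedule and the default values for batch_size and train_examples are illustrative only and are not part of the repository.

import math

def div_thresh_schedule(step, batch_size=20, train_examples=900, is_training=True, test_div_threshold=1e-4):
    # Pure-Python mirror of get_div_thresh_fn / step_to_epochs (illustrative defaults).
    if not is_training:
        return test_div_threshold  # fixed threshold for the denominator at test time
    epoch = step // int(train_examples / batch_size) + 1
    return 1.0 / math.sqrt(epoch)  # decays from 1.0 towards 0 as training progresses

# Threshold applied to denominators in the division layer: strict early in training,
# relaxing as epochs accumulate.
for step in (0, 45, 450, 4500):
    print(step, round(div_thresh_schedule(step), 4))

Because the threshold starts at 1.0 in the first epoch and shrinks steadily, small denominators are tolerated only late in training; at evaluation time the fixed test_div_threshold from network_parameters applies instead.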