├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── animation.gif
│   └── sketch.png
├── conf
│   ├── config.yaml
│   ├── experiment
│   │   ├── function
│   │   │   ├── R
│   │   │   │   ├── R1.yaml
│   │   │   │   ├── R2.yaml
│   │   │   │   └── R3.yaml
│   │   │   ├── constant
│   │   │   │   ├── constant1.yaml
│   │   │   │   ├── constant2.yaml
│   │   │   │   ├── constant3.yaml
│   │   │   │   ├── constant4.yaml
│   │   │   │   ├── constant5.yaml
│   │   │   │   ├── constant6.yaml
│   │   │   │   ├── constant7.yaml
│   │   │   │   └── constant8.yaml
│   │   │   ├── keijzer
│   │   │   │   ├── keijzer10.yaml
│   │   │   │   ├── keijzer11.yaml
│   │   │   │   ├── keijzer12.yaml
│   │   │   │   ├── keijzer13.yaml
│   │   │   │   ├── keijzer14.yaml
│   │   │   │   ├── keijzer15.yaml
│   │   │   │   ├── keijzer3.yaml
│   │   │   │   ├── keijzer4.yaml
│   │   │   │   ├── keijzer6.yaml
│   │   │   │   ├── keijzer7.yaml
│   │   │   │   ├── keijzer8.yaml
│   │   │   │   └── keijzer9.yaml
│   │   │   └── nguyen
│   │   │       ├── nguyen1.yaml
│   │   │       ├── nguyen10.yaml
│   │   │       ├── nguyen11.yaml
│   │   │       ├── nguyen12.yaml
│   │   │       ├── nguyen2.yaml
│   │   │       ├── nguyen3.yaml
│   │   │       ├── nguyen4.yaml
│   │   │       ├── nguyen5.yaml
│   │   │       ├── nguyen6.yaml
│   │   │       ├── nguyen7.yaml
│   │   │       ├── nguyen8.yaml
│   │   │       └── nguyen9.yaml
│   │   ├── scorer
│   │   │   ├── basic_scorer.yaml
│   │   │   ├── complexity_scorer.yaml
│   │   │   └── minmax_scorer.yaml
│   │   ├── seed_functions
│   │   │   ├── 2D_linear.yaml
│   │   │   ├── 2D_mixed.yaml
│   │   │   ├── coefficients.yaml
│   │   │   ├── generate.yaml
│   │   │   ├── linear.yaml
│   │   │   └── mixed.yaml
│   │   └── standard.yaml
│   ├── logger
│   │   └── default_logger.yaml
│   └── model
│       ├── base_prompt
│       │   ├── basic_image.yaml
│       │   ├── basic_mixed.yaml
│       │   ├── basic_text.yaml
│       │   ├── image_all.yaml
│       │   ├── image_best.yaml
│       │   └── no_info.yaml
│       ├── gpt3.5-turbo.yaml
│       ├── gpt4o.yaml
│       ├── llama3-8b.yaml
│       └── llava1.6-34b.yaml
├── current_functions.py
├── data
│   ├── R
│   │   ├── R1
│   │   │   ├── test_points.npy
│   │   │   └── train_points.npy
│   │   ├── R2
│   │   │   ├── test_points.npy
│   │   │   └── train_points.npy
│   │   └── R3
│   │       ├── test_points.npy
│   │       └── train_points.npy
│   ├── constant
│   │   ├── constant1
│   │   │   ├── test_points.npy
│   │   │   └── train_points.npy
│   │   ├── constant2
│   │   │   ├── test_points.npy
│   │   │   └── train_points.npy
│   │   ├── constant3
│   │   │   ├── test_points.npy
│   │   │   └── train_points.npy
│   │   ├── constant4
│   │   │   ├── test_points.npy
│   │   │   └── train_points.npy
│   │   ├── constant5
│   │   │   ├── test_points.npy
│   │   │   └── train_points.npy
│   │   ├── constant6
│   │   │   ├── test_points.npy
│   │   │   └── train_points.npy
│   │   ├── constant7
│   │   │   ├── test_points.npy
│   │   │   └── train_points.npy
│   │   └── constant8
│   │       ├── test_points.npy
│   │       └── train_points.npy
│   ├── keijzer
│   │   ├── keijzer10
│   │   │   ├── test_points.npy
│   │   │   └── train_points.npy
│   │   ├── keijzer11
│   │   │   ├── test_points.npy
│   │   │   └── train_points.npy
│   │   ├── keijzer12
│   │   │   ├── test_points.npy
│   │   │   └── train_points.npy
│   │   ├── keijzer13
│   │   │   ├── test_points.npy
│   │   │   └── train_points.npy
│   │   ├── keijzer14
│   │   │   ├── test_points.npy
│   │   │   └── train_points.npy
│   │   ├── keijzer15
│   │   │   ├── test_points.npy
│   │   │   └── train_points.npy
│   │   ├── keijzer3
│   │   │   ├── test_points.npy
│   │   │   └── train_points.npy
│   │   ├── keijzer4
│   │   │   ├── test_points.npy
│   │   │   └── train_points.npy
│   │   ├── keijzer6
│   │   │   ├── test_points.npy
│   │   │   └── train_points.npy
│   │   ├── keijzer7
│   │   │   ├── test_points.npy
│   │   │   └── train_points.npy
│   │   ├── keijzer8
│   │   │   ├── test_points.npy
│   │   │   └── train_points.npy
│   │   └── keijzer9
│   │       ├── test_points.npy
│   │       └── train_points.npy
│   └── nguyen
│       ├── nguyen1
│       │   ├── test_points.npy
│       │   └── train_points.npy
│       ├── nguyen10
│       │   ├── test_points.npy
│       │   └── train_points.npy
│       ├── nguyen11
│       │   ├── test_points.npy
│       │   └── train_points.npy
│       ├── nguyen12
│       │   ├── test_points.npy
│       │   └── train_points.npy
│       ├── nguyen2
│       │   ├── test_points.npy
│       │   └── train_points.npy
│       ├── nguyen3
│       │   ├── test_points.npy
│       │   └── train_points.npy
│       ├── nguyen4
│       │   ├── test_points.npy
│       │   └── train_points.npy
│       ├── nguyen5
│       │   ├── test_points.npy
│       │   └── train_points.npy
│       ├── nguyen6
│       │   ├── test_points.npy
│       │   └── train_points.npy
│       ├── nguyen7
│       │   ├── test_points.npy
│       │   └── train_points.npy
│       ├── nguyen8
│       │   ├── test_points.npy
│       │   └── train_points.npy
│       └── nguyen9
│           ├── test_points.npy
│           └── train_points.npy
├── download_model.py
├── main.py
├── models
│   ├── __init__.py
│   ├── hf_model.py
│   ├── llava_model_hf.py
│   └── openai_model.py
├── optimizer.py
├── plotter.py
├── prompts
│   ├── OPRO
│   │   ├── basic_image.txt
│   │   ├── basic_mixed.txt
│   │   ├── basic_text.txt
│   │   ├── image_all.txt
│   │   ├── image_best.txt
│   │   └── no_info.txt
│   └── seed_functions
│       ├── generate_seed.txt
│       └── generate_seed_image.txt
├── requirements.txt
├── scorers
│   ├── __init__.py
│   ├── basic_scorer.py
│   ├── complexity_scorer.py
│   ├── minmax_scorer.py
│   └── scorer.py
└── utils.py
/.gitignore: -------------------------------------------------------------------------------- 1 | **.out 2 | /__pycache__ 3 | **/__pycache__ 4 | /.ipynb_checkpoints 5 | /outputs 6 | **/outputs 7 | /runs 8 | **/runs 9 | /multirun 10 | **/multirun 11 | /openai 12 | **/profiles -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Matteo Merler 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # In-Context Symbolic Regression 2 | 3 |

4 | ![Overview of the method](assets/sketch.png) 5 |
6 | Overview of the method. 7 |

8 | 9 | 10 | Official code implementation for the ACL 2024 Student Research Workshop paper [In-Context Symbolic Regression: Leveraging Large Language Models for Function Discovery](https://aclanthology.org/2024.acl-srw.49/). The proposed approach defines an iterative procedure that refines a functional form with an LLM and determines its coefficients with an external optimizer; a minimal code sketch of this loop follows the animation below. 11 | 12 |

13 | ![Example of the ICSR optimization loop](assets/animation.gif) 14 |
15 | Example of the ICSR optimization loop. 16 |
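For orientation, the loop shown in the animation can be summarized in a few lines of Python. This is a simplified sketch written for this README, not the actual implementation in `main.py`: `llm_propose`, `fit_coefficients` and `r2_score` are hypothetical stand-ins for the model, optimizer and scorer classes described under Code Structure, and the default values mirror the `iterations` and `tolerance` fields of the function configs.

```python
# Simplified sketch of the ICSR loop (illustration only, not main.py).
def icsr_loop(llm_propose, fit_coefficients, r2_score, train_points,
              iterations=10, tolerance=0.99999):
    best, best_r2 = None, float("-inf")
    context = []  # best functional forms found so far, fed back to the LLM
    for _ in range(iterations):
        # 1. The LLM proposes a functional form, e.g. "c0*x**3 + c1*x".
        skeleton = llm_propose(context, train_points)
        # 2. An external optimizer fits the coefficients c0, c1, ...
        candidate = fit_coefficients(skeleton, train_points)
        # 3. The fitted candidate is scored on the training points.
        r2 = r2_score(candidate, train_points)
        context.append((skeleton, r2))
        if r2 > best_r2:
            best, best_r2 = candidate, r2
        if best_r2 >= tolerance:  # early stopping, cf. `tolerance` in the configs
            break
    return best, best_r2
```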

17 | 18 | ## Setup 19 | 20 | This codebase uses Python 3.11.6. 21 | 22 | First, clone the repository, then install the required packages. 23 | We recommend using a virtual environment to avoid conflicts with other Python packages: 24 | 25 | ```bash 26 | python3 -m venv .venv 27 | source .venv/bin/activate 28 | pip install -r requirements.txt 29 | ``` 30 | 31 | ## Usage 32 | 33 | The method is self-contained in the `main.py` script, configured using [Hydra](https://hydra.cc/) from the `conf/` directory. Before running the method, ensure that the configuration files are set up correctly. Specifically, the `root` path in the `conf/config.yaml` file should be set to the absolute path of the repository's root directory. If models from the HuggingFace platform are used (for example, Llama3 8B), the `cache_dir` path in `conf/model/llama3-8b.yaml` (or any other added model) should point to a directory where the models are stored. If models from OpenAI are used (for example, GPT-3.5 or GPT-4o), the `api_key_path` entry in `conf/model/gpt3.5-turbo.yaml` (or any other OpenAI model) should point to a file containing the API key. 34 | 35 | ```bash 36 | python3 main.py 37 | ``` 38 | 39 | This will run the method with the default configuration on the "nguyen1" experiment using Llama3 8B as the language model. Note that this requires a capable GPU. For a quick demo, we suggest using an OpenAI model, like GPT-3.5 (which requires an API key). To run the same experiment with GPT-3.5, use the following command: 40 | 41 | ```bash 42 | python3 main.py model=gpt3.5-turbo 43 | ``` 44 | 45 | ## Configuration 46 | 47 | All the configuration files are stored in the `conf/` directory. The values for each parameter in the provided configuration files reflect the ones used for the paper's experiments. The configuration files are organized as follows: 48 | 49 | - All experiments are defined in the `conf/experiment/function` directory, where all four benchmarks are stored. To add a new benchmark, create a new YAML file in the `conf/experiment/function` directory (or in a subdirectory if necessary) following the existing file structure (see `conf/experiment/function/nguyen/nguyen1.yaml` for an example). 50 | - All models are defined in the `conf/model` directory. By default, we include a configuration file for Llama3 8B (using the HuggingFace platform), LLaVa-NeXT, GPT-3.5 and GPT-4o. To add a new model, create a new YAML file in the `conf/model` directory following the existing file structure (see `conf/model/llama3-8b.yaml` for an example using HuggingFace models or `conf/model/gpt3.5-turbo.yaml` for an example using OpenAI models). 51 | - The file `conf/config.yaml` contains general configuration settings, like the torch device and random seed. The `root` path should be set to the absolute path of the repository's root directory. 52 | - The file `conf/experiment/standard.yaml` contains the default configuration for all experiments, independent of the specific function, like the coefficient optimizer settings. 53 | - The files in the `conf/experiment/function` subdirectories contain the configuration for each individual function. These include the ground truth function and the point range for the function evaluation. For reproducibility, we also provide the exact random points we sampled for the experiments in the `data` directory. By default, all experiments are performed on this set of points and are not re-sampled. The path for each function's data points is defined in the respective function configuration file. 
54 | - The files in the `conf/experiment/scorer` directory contain the configuration for the scoring functions used to evaluate the discovered functions. 55 | - The files in the `conf/experiment/seed_functions` directory contain the configuration for the seed functions used to initialize the search. In the paper, we ask the LLM to generate the initial seed functions, but changing this configuration allows for a manually defined set of seed functions. 56 | - The files in the `conf/model/base_prompt` directory contain the configuration for all the prompts used by the models. In practice, the main difference between prompts is the type of images provided for vision models, used in the experiments in the Appendix of the paper. 57 | 58 | Using Hydra, the configuration can be overridden from the command line. Note that there is a difference between overriding a single parameter (i.e. a single line) in a file (for example, the number of iterations in `conf/experiment/function/nguyen/nguyen1.yaml`) and overriding an entire file (used when changing the model or the experiment). The former is done with `.` separators, while the latter is done with `/` separators. For example, to run the "keijzer6" experiment with the GPT-3.5 model, 100 iterations and early stopping when the $R^2$ score on the training points is higher than 0.95, use the following command: 59 | 60 | ```bash 61 | python3 main.py experiment/function=keijzer/keijzer6 model=gpt3.5-turbo experiment.function.iterations=100 experiment.function.tolerance=0.95 62 | ``` 63 | 64 | To find the exact name of each parameter, check the respective configuration files. 65 | 66 | ## Code Structure 67 | 68 | The code is organized as follows: 69 | 70 | - The main entry point is `main.py`. It defines a `Workspace` class that loads the configuration files, sets up the points and the model, and runs the main loop. 71 | - `current_functions.py` contains a helper class that manages the best functions found during the optimization and adds them to the context. 72 | - `optimizer.py` contains the optimizer class used to optimize the coefficients of the functions. 73 | - `plotter.py` contains the class used to plot the results of the experiments and to produce the image inputs for the visual models. 74 | - The `scorers` directory contains the classes used to evaluate the discovered functions (an illustrative sketch of such a score follows this list). 75 | - The `models` directory contains the classes used to interact with the language models. 
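To make the scoring step concrete, the snippet below sketches one plausible complexity-penalized score. This is an illustration written for this README, not the code in `scorers/complexity_scorer.py`: only the `lambda` (here `lam`) and `max_nodes` names are borrowed from `conf/experiment/scorer/complexity_scorer.yaml`, the particular error/penalty combination is an assumption, and the exact formula used in the paper may differ.

```python
# Illustrative complexity-penalized scorer (assumed formula; lower is better).
import numpy as np
import sympy

def complexity_score(expr_str, xs, ys, lam=0.05, max_nodes=30):
    expr = sympy.sympify(expr_str)
    f = sympy.lambdify(sympy.Symbol("x"), expr, "numpy")
    # Normalized mean squared error on the training points.
    nmse = np.mean((f(xs) - ys) ** 2) / (np.var(ys) + 1e-12)
    # Complexity measured as the size of the expression tree, capped at max_nodes.
    nodes = min(sympy.count_ops(expr), max_nodes)
    return nmse + lam * nodes

xs = np.linspace(-1, 1, 20)
print(complexity_score("x**3 + x**2 + x", xs, xs**3 + xs**2 + xs))  # ~0 error + small penalty
```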
76 | 77 | ## Citation 78 | 79 | If you find this code useful, please consider citing our paper: 80 | 81 | ```bibtex 82 | @inproceedings{merler-etal-2024-context, 83 | title = "In-Context Symbolic Regression: Leveraging Large Language Models for Function Discovery", 84 | author = "Merler, Matteo and 85 | Haitsiukevich, Katsiaryna and 86 | Dainese, Nicola and 87 | Marttinen, Pekka", 88 | editor = "Fu, Xiyan and 89 | Fleisig, Eve", 90 | booktitle = "Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 4: Student Research Workshop)", 91 | month = aug, 92 | year = "2024", 93 | address = "Bangkok, Thailand", 94 | publisher = "Association for Computational Linguistics", 95 | url = "https://aclanthology.org/2024.acl-srw.49", 96 | doi = "10.18653/v1/2024.acl-srw.49", 97 | pages = "589--606" 98 | } 99 | 100 | ``` 101 | -------------------------------------------------------------------------------- /assets/animation.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/assets/animation.gif -------------------------------------------------------------------------------- /assets/sketch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/assets/sketch.png -------------------------------------------------------------------------------- /conf/config.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - model: llama3-8b 3 | - experiment: standard 4 | - logger: default_logger 5 | - _self_ 6 | 7 | # Experiment 8 | output_dir: runs # Output directory where all runs will be saved, with the following structure: output_dir/benchmark_name/experiment_name/model_name/run_id 9 | max_retries: 5 # Maximum number of retries if prompt failed to generate a valid function as the output 10 | force_valid: false # Force the output to be valid (i.e. a valid function). If false, when the model fails to generate a valid function, after max_retries, the output will be the best generated function so far 11 | force_unique: false # Force the output to be unique (i.e. different from all the functions in the prompt) 12 | prompts_path: prompts 13 | max_points_in_prompt: 40 # Maximum number of points in the prompt (if more are provided, they will automatically be downsampled) 14 | checkpoints: [50, 100, 200, 300, 400, 500, 600, 700, 800, 900] # Partial results will be saved at these iterations 15 | 16 | # Torch 17 | device: 'auto' # auto works for both CPU and GPU and can be used in a multi-GPU setup 18 | use_bfloat16: false 19 | seed: -1 # If -1, the seed will be randomly generated 20 | 21 | # Project root 22 | root: ??? 
# Path to the root of the project, where the 'conf', 'data' directories are located and where main.py is executed 23 | 24 | # Plotter 25 | plotter: 26 | save_video: true 27 | save_frames: false 28 | gif_duration: 1000 29 | plotter_resolution: 1000 30 | plotter_fig_size: 10 -------------------------------------------------------------------------------- /conf/experiment/function/R/R1.yaml: -------------------------------------------------------------------------------- 1 | name: "R1" 2 | group: "R" 3 | 4 | # Train points 5 | train_points: 6 | generate_points: false 7 | data_folder: "data/R/R1" 8 | # This part is only used if generate_points is true 9 | random_points: true 10 | min_points: -1 11 | max_points: 1 12 | num_points: 20 13 | xs_noise_std: 0 14 | ys_noise_std: 0 15 | add_extremes: true # Manually adds points at the extremes of the interval 16 | 17 | # Test points 18 | test_points: 19 | min_points: -1 20 | max_points: 1 21 | num_points: 20 22 | 23 | tolerance: 0.99999 24 | 25 | # Test function 26 | test_function: "(x+1)^3/(x^2-x+1)" 27 | num_variables: 1 28 | 29 | # Total iterations 30 | iterations: 100 -------------------------------------------------------------------------------- /conf/experiment/function/R/R2.yaml: -------------------------------------------------------------------------------- 1 | name: "R2" 2 | group: "R" 3 | 4 | # Train points 5 | train_points: 6 | generate_points: false 7 | data_folder: "data/R/R2" 8 | # This part is only used if generate_points is true 9 | random_points: true 10 | min_points: -1 11 | max_points: 1 12 | num_points: 20 13 | xs_noise_std: 0 14 | ys_noise_std: 0 15 | add_extremes: true # Manually adds points at the extremes of the interval 16 | 17 | # Test points 18 | test_points: 19 | min_points: -1 20 | max_points: 1 21 | num_points: 20 22 | 23 | tolerance: 0.99999 24 | 25 | # Test function 26 | test_function: "(x^5-3*x^3+1)/(x^2+1)" 27 | num_variables: 1 28 | 29 | # Total iterations 30 | iterations: 100 -------------------------------------------------------------------------------- /conf/experiment/function/R/R3.yaml: -------------------------------------------------------------------------------- 1 | name: "R3" 2 | group: "R" 3 | 4 | # Train points 5 | train_points: 6 | generate_points: false 7 | data_folder: "data/R/R3" 8 | # This part is only used if generate_points is true 9 | random_points: true 10 | min_points: -1 11 | max_points: 1 12 | num_points: 20 13 | xs_noise_std: 0 14 | ys_noise_std: 0 15 | add_extremes: true # Manually adds points at the extremes of the interval 16 | 17 | # Test points 18 | test_points: 19 | min_points: -1 20 | max_points: 1 21 | num_points: 20 22 | 23 | tolerance: 0.99999 24 | 25 | # Test function 26 | test_function: "(x^6+x^5)/(x^4+x^3+x^2+x+1)" 27 | num_variables: 1 28 | 29 | # Total iterations 30 | iterations: 100 -------------------------------------------------------------------------------- /conf/experiment/function/constant/constant1.yaml: -------------------------------------------------------------------------------- 1 | name: "constant1" 2 | group: "constant" 3 | 4 | # Train points 5 | train_points: 6 | generate_points: false 7 | data_folder: "data/constant/constant1" 8 | # This part is only used if generate_points is true 9 | random_points: true 10 | min_points: -1 11 | max_points: 1 12 | num_points: 20 13 | xs_noise_std: 0 14 | ys_noise_std: 0 15 | add_extremes: true # Manually adds points at the extremes of the interval 16 | 17 | # Test points 18 | test_points: 19 | min_points: -1 20 | 
max_points: 1 21 | num_points: 200 22 | 23 | tolerance: 0.99999 24 | 25 | # Test function 26 | test_function: "3.39*x^3 + 2.12*x^2 + 1.78*x" 27 | num_variables: 1 28 | 29 | # Total iterations 30 | iterations: 100 -------------------------------------------------------------------------------- /conf/experiment/function/constant/constant2.yaml: -------------------------------------------------------------------------------- 1 | name: "constant2" 2 | group: "constant" 3 | 4 | # Train points 5 | train_points: 6 | generate_points: false 7 | data_folder: "data/constant/constant2" 8 | # This part is only used if generate_points is true 9 | random_points: true 10 | min_points: -1 11 | max_points: 1 12 | num_points: 20 13 | xs_noise_std: 0 14 | ys_noise_std: 0 15 | add_extremes: true # Manually adds points at the extremes of the interval 16 | 17 | # Test points 18 | test_points: 19 | min_points: -1 20 | max_points: 1 21 | num_points: 200 22 | 23 | tolerance: 0.99999 24 | 25 | # Test function 26 | test_function: "sin(x^2) * cos(x) - 0.75" 27 | num_variables: 1 28 | 29 | # Total iterations 30 | iterations: 100 -------------------------------------------------------------------------------- /conf/experiment/function/constant/constant3.yaml: -------------------------------------------------------------------------------- 1 | name: "constant3" 2 | group: "constant" 3 | 4 | # Train points 5 | train_points: 6 | generate_points: false 7 | data_folder: "data/constant/constant3" 8 | # This part is only used if generate_points is true 9 | random_points: true 10 | min_points: [-1, -1] 11 | max_points: [1, 1] 12 | num_points: 100 13 | xs_noise_std: 0 14 | ys_noise_std: 0 15 | add_extremes: true # Manually adds points at the extremes of the interval 16 | 17 | # Test points 18 | test_points: 19 | min_points: [-1, -1] 20 | max_points: [1, 1] 21 | num_points: 500 22 | 23 | tolerance: 0.99999 24 | 25 | # Test function 26 | test_function: "sin(1.5*x1)*cos(0.5*x2)" 27 | num_variables: 2 28 | 29 | # Total iterations 30 | iterations: 100 -------------------------------------------------------------------------------- /conf/experiment/function/constant/constant4.yaml: -------------------------------------------------------------------------------- 1 | name: "constant4" 2 | group: "constant" 3 | 4 | # Train points 5 | train_points: 6 | generate_points: false 7 | data_folder: "data/constant/constant4" 8 | # This part is only used if generate_points is true 9 | random_points: true 10 | min_points: [0, 0] 11 | max_points: [1, 1] 12 | num_points: 100 13 | xs_noise_std: 0 14 | ys_noise_std: 0 15 | add_extremes: true # Manually adds points at the extremes of the interval 16 | 17 | # Test points 18 | test_points: 19 | min_points: [0, 0] 20 | max_points: [1, 1] 21 | num_points: 500 22 | 23 | tolerance: 0.99999 24 | 25 | # Test function 26 | test_function: "2.7*x1^x2" 27 | num_variables: 2 28 | 29 | # Total iterations 30 | iterations: 100 -------------------------------------------------------------------------------- /conf/experiment/function/constant/constant5.yaml: -------------------------------------------------------------------------------- 1 | name: "constant5" 2 | group: "constant" 3 | 4 | # Train points 5 | train_points: 6 | generate_points: false 7 | data_folder: "data/constant/constant5" 8 | # This part is only used if generate_points is true 9 | random_points: true 10 | min_points: 0 11 | max_points: 4 12 | num_points: 20 13 | xs_noise_std: 0 14 | ys_noise_std: 0 15 | add_extremes: true # Manually adds points at the 
extremes of the interval 16 | 17 | # Test points 18 | test_points: 19 | min_points: 0 20 | max_points: 4 21 | num_points: 200 22 | 23 | tolerance: 0.99999 24 | 25 | # Test function 26 | test_function: "sqrt(1.23*x)" 27 | num_variables: 1 28 | 29 | # Total iterations 30 | iterations: 100 -------------------------------------------------------------------------------- /conf/experiment/function/constant/constant6.yaml: -------------------------------------------------------------------------------- 1 | name: "constant6" 2 | group: "constant" 3 | 4 | # Train points 5 | train_points: 6 | generate_points: false 7 | data_folder: "data/constant/constant6" 8 | # This part is only used if generate_points is true 9 | random_points: true 10 | min_points: 0 11 | max_points: 4 12 | num_points: 20 13 | xs_noise_std: 0 14 | ys_noise_std: 0 15 | add_extremes: true # Manually adds points at the extremes of the interval 16 | 17 | # Test points 18 | test_points: 19 | min_points: 0 20 | max_points: 4 21 | num_points: 200 22 | 23 | tolerance: 0.99999 24 | 25 | # Test function 26 | test_function: "x^(0.426)" 27 | num_variables: 1 28 | 29 | # Total iterations 30 | iterations: 100 -------------------------------------------------------------------------------- /conf/experiment/function/constant/constant7.yaml: -------------------------------------------------------------------------------- 1 | name: "constant7" 2 | group: "constant" 3 | 4 | # Train points 5 | train_points: 6 | generate_points: false 7 | data_folder: "data/constant/constant7" 8 | # This part is only used if generate_points is true 9 | random_points: true 10 | min_points: [-1, -1] 11 | max_points: [1, 1] 12 | num_points: 100 13 | xs_noise_std: 0 14 | ys_noise_std: 0 15 | add_extremes: true # Manually adds points at the extremes of the interval 16 | 17 | # Test points 18 | test_points: 19 | min_points: [-1, -1] 20 | max_points: [1, 1] 21 | num_points: 500 22 | 23 | tolerance: 0.99999 24 | 25 | # Test function 26 | test_function: "2*sin(1.3*x1) + cos(x2)" 27 | num_variables: 2 28 | 29 | # Total iterations 30 | iterations: 100 -------------------------------------------------------------------------------- /conf/experiment/function/constant/constant8.yaml: -------------------------------------------------------------------------------- 1 | name: "constant8" 2 | group: "constant" 3 | 4 | # Train points 5 | train_points: 6 | generate_points: false 7 | data_folder: "data/constant/constant8" 8 | # This part is only used if generate_points is true 9 | random_points: true 10 | min_points: 0 11 | max_points: 2 12 | num_points: 20 13 | xs_noise_std: 0 14 | ys_noise_std: 0 15 | add_extremes: true # Manually adds points at the extremes of the interval 16 | 17 | # Test points 18 | test_points: 19 | min_points: 0 20 | max_points: 2 21 | num_points: 200 22 | 23 | tolerance: 0.99999 24 | 25 | # Test function 26 | test_function: "ln(x+1.4) + ln(x^2+1.3)" 27 | num_variables: 1 28 | 29 | # Total iterations 30 | iterations: 100 -------------------------------------------------------------------------------- /conf/experiment/function/keijzer/keijzer10.yaml: -------------------------------------------------------------------------------- 1 | name: "keijzer10" 2 | group: "keijzer" 3 | 4 | # Train points 5 | train_points: 6 | generate_points: false 7 | data_folder: "data/keijzer/keijzer10" 8 | # This part is only used if generate_points is true 9 | random_points: true 10 | min_points: 0 11 | max_points: 1 12 | num_points: 100 13 | xs_noise_std: 0 14 | ys_noise_std: 0 15 | 
add_extremes: true # Manually adds points at the extremes of the interval 16 | 17 | # Test points 18 | test_points: 19 | min_points: 0 20 | max_points: 1 21 | num_points: 1000 22 | 23 | tolerance: 0.99999 24 | 25 | # Test function 26 | test_function: "x1^x2" 27 | num_variables: 2 28 | 29 | # Total iterations 30 | iterations: 5 -------------------------------------------------------------------------------- /conf/experiment/function/keijzer/keijzer11.yaml: -------------------------------------------------------------------------------- 1 | name: "keijzer11" 2 | group: "keijzer" 3 | 4 | # Train points 5 | train_points: 6 | generate_points: false 7 | data_folder: "data/keijzer/keijzer11" 8 | # This part is only used if generate_points is true 9 | random_points: true 10 | min_points: -3 11 | max_points: 3 12 | num_points: 20 13 | xs_noise_std: 0 14 | ys_noise_std: 0 15 | add_extremes: true # Manually adds points at the extremes of the interval 16 | 17 | # Test points 18 | test_points: 19 | min_points: -3 20 | max_points: 3 21 | num_points: 1000 22 | 23 | tolerance: 0.99999 24 | 25 | # Test function 26 | test_function: "x1*x2 + sin((x1 - 1) * (x2 - 1))" 27 | num_variables: 2 28 | 29 | # Total iterations 30 | iterations: 5 -------------------------------------------------------------------------------- /conf/experiment/function/keijzer/keijzer12.yaml: -------------------------------------------------------------------------------- 1 | name: "keijzer12" 2 | group: "keijzer" 3 | 4 | # Train points 5 | train_points: 6 | generate_points: false 7 | data_folder: "data/keijzer/keijzer12" 8 | # This part is only used if generate_points is true 9 | random_points: true 10 | min_points: -3 11 | max_points: 3 12 | num_points: 20 13 | xs_noise_std: 0 14 | ys_noise_std: 0 15 | add_extremes: true # Manually adds points at the extremes of the interval 16 | 17 | # Test points 18 | test_points: 19 | min_points: -3 20 | max_points: 3 21 | num_points: 1000 22 | 23 | tolerance: 0.99999 24 | 25 | # Test function 26 | test_function: "x1^4 - x1^3 + (x2^2)/2 - x2" 27 | num_variables: 2 28 | 29 | # Total iterations 30 | iterations: 5 -------------------------------------------------------------------------------- /conf/experiment/function/keijzer/keijzer13.yaml: -------------------------------------------------------------------------------- 1 | name: "keijzer13" 2 | group: "keijzer" 3 | 4 | # Train points 5 | train_points: 6 | generate_points: false 7 | data_folder: "data/keijzer/keijzer13" 8 | # This part is only used if generate_points is true 9 | random_points: true 10 | min_points: -3 11 | max_points: 3 12 | num_points: 20 13 | xs_noise_std: 0 14 | ys_noise_std: 0 15 | add_extremes: true # Manually adds points at the extremes of the interval 16 | 17 | # Test points 18 | test_points: 19 | min_points: -3 20 | max_points: 3 21 | num_points: 1000 22 | 23 | tolerance: 0.99999 24 | 25 | # Test function 26 | test_function: "6*sin(x1)*cos(x2)" 27 | num_variables: 2 28 | 29 | # Total iterations 30 | iterations: 5 -------------------------------------------------------------------------------- /conf/experiment/function/keijzer/keijzer14.yaml: -------------------------------------------------------------------------------- 1 | name: "keijzer14" 2 | group: "keijzer" 3 | 4 | # Train points 5 | train_points: 6 | generate_points: false 7 | data_folder: "data/keijzer/keijzer14" 8 | # This part is only used if generate_points is true 9 | random_points: true 10 | min_points: -3 11 | max_points: 3 12 | num_points: 20 13 | xs_noise_std: 0
14 | ys_noise_std: 0 15 | add_extremes: true # Manually adds points at the extremes of the interval 16 | 17 | # Test points 18 | test_points: 19 | min_points: -3 20 | max_points: 3 21 | num_points: 1000 22 | 23 | tolerance: 0.99999 24 | 25 | # Test function 26 | test_function: "8/(2+x1^2+x2^2)" 27 | num_variables: 2 28 | 29 | # Total iterations 30 | iterations: 5 -------------------------------------------------------------------------------- /conf/experiment/function/keijzer/keijzer15.yaml: -------------------------------------------------------------------------------- 1 | name: "keijzer15" 2 | group: "keijzer" 3 | 4 | # Train points 5 | train_points: 6 | generate_points: false 7 | data_folder: "data/keijzer/keijzer15" 8 | # This part is only used if generate_points is true 9 | random_points: true 10 | min_points: -3 11 | max_points: 3 12 | num_points: 20 13 | xs_noise_std: 0 14 | ys_noise_std: 0 15 | add_extremes: true # Manually adds points at the extremes of the interval 16 | 17 | # Test points 18 | test_points: 19 | min_points: -3 20 | max_points: 3 21 | num_points: 1000 22 | 23 | tolerance: 0.99999 24 | 25 | # Test function 26 | test_function: "x1^3/5 + x2^3/2 - x2 - x1" 27 | num_variables: 2 28 | 29 | # Total iterations 30 | iterations: 5 -------------------------------------------------------------------------------- /conf/experiment/function/keijzer/keijzer3.yaml: -------------------------------------------------------------------------------- 1 | name: "keijzer3" 2 | group: "keijzer" 3 | 4 | # Train points 5 | train_points: 6 | generate_points: false 7 | data_folder: "data/keijzer/keijzer3" 8 | # This part is only used if generate_points is true 9 | random_points: false 10 | min_points: -1 11 | max_points: 1 12 | num_points: 100 13 | xs_noise_std: 0 14 | ys_noise_std: 0 15 | add_extremes: true # Manually adds points at the extremes of the interval 16 | 17 | # Test points 18 | test_points: 19 | min_points: -1 20 | max_points: 1 21 | num_points: 10000 22 | 23 | tolerance: 0.99999 24 | 25 | # Test function 26 | test_function: "0.3*x * sin(2*pi*x)" 27 | num_variables: 1 28 | 29 | # Total iterations 30 | iterations: 100 -------------------------------------------------------------------------------- /conf/experiment/function/keijzer/keijzer4.yaml: -------------------------------------------------------------------------------- 1 | name: "keijzer4" 2 | group: "keijzer" 3 | 4 | # Train points 5 | train_points: 6 | generate_points: false 7 | data_folder: "data/keijzer/keijzer4" 8 | # This part is only used if generate_points is true 9 | random_points: false 10 | min_points: 0 11 | max_points: 10 12 | num_points: 200 13 | xs_noise_std: 0 14 | ys_noise_std: 0 15 | add_extremes: true # Manually adds points at the extremes of the interval 16 | 17 | # Test points 18 | test_points: 19 | min_points: 0.05 20 | max_points: 10.05 21 | num_points: 200 22 | 23 | tolerance: 0.99999 24 | 25 | # Test function 26 | test_function: "x^3*exp(-x)*cos(x)*sin(x)*(sin(x)^2*cos(x)-1)" 27 | num_variables: 1 28 | 29 | # Total iterations 30 | iterations: 5 -------------------------------------------------------------------------------- /conf/experiment/function/keijzer/keijzer6.yaml: -------------------------------------------------------------------------------- 1 | name: "keijzer6" 2 | group: "keijzer" 3 | 4 | # Train points 5 | train_points: 6 | generate_points: false 7 | data_folder: "data/keijzer/keijzer6" 8 | # This part is only used if generate_points is true 9 | random_points: false 10 | min_points: -1 11
| max_points: 1 12 | num_points: 50 13 | xs_noise_std: 0 14 | ys_noise_std: 0 15 | add_extremes: true # Manually adds points at the extremes of the interval 16 | 17 | # Test points 18 | test_points: 19 | min_points: -1 20 | max_points: 1 21 | num_points: 100 22 | 23 | tolerance: 0.99999 24 | 25 | # Test function 26 | test_function: "(x*(x+1))/2" 27 | num_variables: 1 28 | 29 | # Total iterations 30 | iterations: 5 -------------------------------------------------------------------------------- /conf/experiment/function/keijzer/keijzer7.yaml: -------------------------------------------------------------------------------- 1 | name: "keijzer7" 2 | group: "keijzer" 3 | 4 | # Train points 5 | train_points: 6 | generate_points: false 7 | data_folder: "data/keijzer/keijzer7" 8 | # This part is only used if generate_points is true 9 | random_points: false 10 | min_points: 1 11 | max_points: 100 12 | num_points: 100 13 | xs_noise_std: 0 14 | ys_noise_std: 0 15 | add_extremes: true # Manually adds points at the extremes of the interval 16 | 17 | # Test points 18 | test_points: 19 | min_points: 1 20 | max_points: 100 21 | num_points: 1000 22 | 23 | tolerance: 0.99999 24 | 25 | # Test function 26 | test_function: "ln(x)" 27 | num_variables: 1 28 | 29 | # Total iterations 30 | iterations: 5 -------------------------------------------------------------------------------- /conf/experiment/function/keijzer/keijzer8.yaml: -------------------------------------------------------------------------------- 1 | name: "keijzer8" 2 | group: "keijzer" 3 | 4 | # Train points 5 | train_points: 6 | generate_points: false 7 | data_folder: "data/keijzer/keijzer8" 8 | # This part is only used if generate_points is true 9 | random_points: false 10 | min_points: 0 11 | max_points: 100 12 | num_points: 100 13 | xs_noise_std: 0 14 | ys_noise_std: 0 15 | add_extremes: true # Manually adds points at the extremes of the interval 16 | 17 | # Test points 18 | test_points: 19 | min_points: 0 20 | max_points: 100 21 | num_points: 1000 22 | 23 | tolerance: 0.99999 24 | 25 | # Test function 26 | test_function: "sqrt(x)" 27 | num_variables: 1 28 | 29 | # Total iterations 30 | iterations: 5 -------------------------------------------------------------------------------- /conf/experiment/function/keijzer/keijzer9.yaml: -------------------------------------------------------------------------------- 1 | name: "keijzer9" 2 | group: "keijzer" 3 | 4 | # Train points 5 | train_points: 6 | generate_points: false 7 | data_folder: "data/keijzer/keijzer9" 8 | # This part is only used if generate_points is true 9 | random_points: false 10 | min_points: 0 11 | max_points: 100 12 | num_points: 100 13 | xs_noise_std: 0 14 | ys_noise_std: 0 15 | add_extremes: true # Manually adds points at the extremes of the interval 16 | 17 | # Test points 18 | test_points: 19 | min_points: 0 20 | max_points: 100 21 | num_points: 1000 22 | 23 | tolerance: 0.99999 24 | 25 | # Test function 26 | test_function: "ln(x + sqrt(x^2 + 1))" 27 | num_variables: 1 28 | 29 | # Total iterations 30 | iterations: 5 -------------------------------------------------------------------------------- /conf/experiment/function/nguyen/nguyen1.yaml: -------------------------------------------------------------------------------- 1 | name: "nguyen1" 2 | group: "nguyen" 3 | 4 | # Train points 5 | train_points: 6 | generate_points: false 7 | data_folder: "data/nguyen/nguyen1" 8 | # This part is only used if generate_points is true 9 | random_points: true 10 | min_points: -1 11 | max_points: 
1 12 | num_points: 20 13 | xs_noise_std: 0 14 | ys_noise_std: 0 15 | add_extremes: true # Manually adds points at the extremes of the interval 16 | 17 | # Test points 18 | test_points: 19 | min_points: -1 20 | max_points: 1 21 | num_points: 200 22 | 23 | tolerance: 0.99999 24 | 25 | # Test function 26 | test_function: "x^3 + x^2 + x" 27 | num_variables: 1 28 | 29 | # Total iterations 30 | iterations: 10 -------------------------------------------------------------------------------- /conf/experiment/function/nguyen/nguyen10.yaml: -------------------------------------------------------------------------------- 1 | name: "nguyen10" 2 | group: "nguyen" 3 | 4 | # Train points 5 | train_points: 6 | generate_points: false 7 | data_folder: "data/nguyen/nguyen10" 8 | # This part is only used if generate_points is true 9 | random_points: true 10 | min_points: [-1, -1] 11 | max_points: [1, 1] 12 | num_points: 100 13 | xs_noise_std: 0 14 | ys_noise_std: 0 15 | add_extremes: true # Manually adds points at the extremes of the interval 16 | 17 | # Test points 18 | test_points: 19 | min_points: [-1, -1] 20 | max_points: [1, 1] 21 | num_points: 500 22 | 23 | tolerance: 0.99999 24 | 25 | # Test function 26 | test_function: "2*sin(x1)*cos(x2)" 27 | num_variables: 2 28 | 29 | # Total iterations 30 | iterations: 10 -------------------------------------------------------------------------------- /conf/experiment/function/nguyen/nguyen11.yaml: -------------------------------------------------------------------------------- 1 | name: "nguyen11" 2 | group: "nguyen" 3 | 4 | # Train points 5 | train_points: 6 | generate_points: false 7 | data_folder: "data/nguyen/nguyen11" 8 | # This part is only used if generate_points is true 9 | random_points: true 10 | min_points: [0, 0] 11 | max_points: [1, 1] 12 | num_points: 100 13 | xs_noise_std: 0 14 | ys_noise_std: 0 15 | add_extremes: true # Manually adds points at the extremes of the interval 16 | 17 | # Test points 18 | test_points: 19 | min_points: [0, 0] 20 | max_points: [1, 1] 21 | num_points: 500 22 | 23 | tolerance: 0.99999 24 | 25 | # Test function 26 | test_function: "x1^x2" 27 | num_variables: 2 28 | 29 | # Total iterations 30 | iterations: 10 -------------------------------------------------------------------------------- /conf/experiment/function/nguyen/nguyen12.yaml: -------------------------------------------------------------------------------- 1 | name: "nguyen12" 2 | group: "nguyen" 3 | 4 | # Train points 5 | train_points: 6 | generate_points: false 7 | data_folder: "data/nguyen/nguyen12" 8 | # This part is only used if generate_points is true 9 | random_points: true 10 | min_points: [-1, -1] 11 | max_points: [1, 1] 12 | num_points: 100 13 | xs_noise_std: 0 14 | ys_noise_std: 0 15 | add_extremes: true # Manually adds points at the extremes of the interval 16 | 17 | # Test points 18 | test_points: 19 | min_points: [-1, -1] 20 | max_points: [1, 1] 21 | num_points: 500 22 | 23 | tolerance: 0.99999 24 | 25 | # Test function 26 | test_function: "x1^4 - x1^3 + (1/2)*x2^2 - x2" 27 | num_variables: 2 28 | 29 | # Total iterations 30 | iterations: 10 -------------------------------------------------------------------------------- /conf/experiment/function/nguyen/nguyen2.yaml: -------------------------------------------------------------------------------- 1 | name: "nguyen2" 2 | group: "nguyen" 3 | 4 | # Train points 5 | train_points: 6 | generate_points: false 7 | data_folder: "data/nguyen/nguyen2" 8 | # This part is only used if generate_points is true 9 | 
random_points: true 10 | min_points: -1 11 | max_points: 1 12 | num_points: 20 13 | xs_noise_std: 0 14 | ys_noise_std: 0 15 | add_extremes: true # Manually adds points at the extremes of the interval 16 | 17 | # Test points 18 | test_points: 19 | min_points: -1 20 | max_points: 1 21 | num_points: 200 22 | 23 | tolerance: 0.99999 24 | 25 | # Test function 26 | test_function: "x^4 + x^3 + x^2 + x" 27 | num_variables: 1 28 | 29 | # Total iterations 30 | iterations: 10 -------------------------------------------------------------------------------- /conf/experiment/function/nguyen/nguyen3.yaml: -------------------------------------------------------------------------------- 1 | name: "nguyen3" 2 | group: "nguyen" 3 | 4 | # Train points 5 | train_points: 6 | generate_points: false 7 | data_folder: "data/nguyen/nguyen3" 8 | # This part is only used if generate_points is true 9 | random_points: true 10 | min_points: -1 11 | max_points: 1 12 | num_points: 20 13 | xs_noise_std: 0 14 | ys_noise_std: 0 15 | add_extremes: true # Manually adds points at the extremes of the interval 16 | 17 | # Test points 18 | test_points: 19 | min_points: -1 20 | max_points: 1 21 | num_points: 200 22 | 23 | tolerance: 0.99999 24 | 25 | # Test function 26 | test_function: "x^5 + x^4 + x^3 + x^2 + x" 27 | num_variables: 1 28 | 29 | # Total iterations 30 | iterations: 10 -------------------------------------------------------------------------------- /conf/experiment/function/nguyen/nguyen4.yaml: -------------------------------------------------------------------------------- 1 | name: "nguyen4" 2 | group: "nguyen" 3 | 4 | # Train points 5 | train_points: 6 | generate_points: false 7 | data_folder: "data/nguyen/nguyen4" 8 | # This part is only used if generate_points is true 9 | random_points: true 10 | min_points: -1 11 | max_points: 1 12 | num_points: 20 13 | xs_noise_std: 0 14 | ys_noise_std: 0 15 | add_extremes: true # Manually adds points at the extremes of the interval 16 | 17 | # Test points 18 | test_points: 19 | min_points: -1 20 | max_points: 1 21 | num_points: 200 22 | 23 | tolerance: 0.99999 24 | 25 | # Test function 26 | test_function: "x^6 + x^5 + x^4 + x^3 + x^2 + x" 27 | num_variables: 1 28 | 29 | # Total iterations 30 | iterations: 10 -------------------------------------------------------------------------------- /conf/experiment/function/nguyen/nguyen5.yaml: -------------------------------------------------------------------------------- 1 | name: "nguyen5" 2 | group: "nguyen" 3 | 4 | # Train points 5 | train_points: 6 | generate_points: false 7 | data_folder: "data/nguyen/nguyen5" 8 | # This part is only used if generate_points is true 9 | random_points: true 10 | min_points: -1 11 | max_points: 1 12 | num_points: 20 13 | xs_noise_std: 0 14 | ys_noise_std: 0 15 | add_extremes: true # Manually adds points at the extremes of the interval 16 | 17 | # Test points 18 | test_points: 19 | min_points: -1 20 | max_points: 1 21 | num_points: 200 22 | 23 | tolerance: 0.99999 24 | 25 | # Test function 26 | test_function: "sin(x^2) * cos(x) - 1" 27 | num_variables: 1 28 | 29 | # Total iterations 30 | iterations: 10 -------------------------------------------------------------------------------- /conf/experiment/function/nguyen/nguyen6.yaml: -------------------------------------------------------------------------------- 1 | name: "nguyen6" 2 | group: "nguyen" 3 | 4 | # Train points 5 | train_points: 6 | generate_points: false 7 | data_folder: "data/nguyen/nguyen6" 8 | # This part is only used if generate_points is 
true 9 | random_points: true 10 | min_points: -1 11 | max_points: 1 12 | num_points: 20 13 | xs_noise_std: 0 14 | ys_noise_std: 0 15 | add_extremes: true # Manually adds points at the extremes of the interval 16 | 17 | # Test points 18 | test_points: 19 | min_points: -1 20 | max_points: 1 21 | num_points: 200 22 | 23 | tolerance: 0.99999 24 | 25 | # Test function 26 | test_function: "sin(x) + sin(x + x^2)" 27 | num_variables: 1 28 | 29 | # Total iterations 30 | iterations: 10 -------------------------------------------------------------------------------- /conf/experiment/function/nguyen/nguyen7.yaml: -------------------------------------------------------------------------------- 1 | name: "nguyen7" 2 | group: "nguyen" 3 | 4 | # Train points 5 | train_points: 6 | generate_points: false 7 | data_folder: "data/nguyen/nguyen7" 8 | # This part is only used if generate_points is true 9 | random_points: true 10 | min_points: 0 11 | max_points: 2 12 | num_points: 20 13 | xs_noise_std: 0 14 | ys_noise_std: 0 15 | add_extremes: true # Manually adds points at the extremes of the interval 16 | 17 | # Test points 18 | test_points: 19 | min_points: 0 20 | max_points: 2 21 | num_points: 200 22 | 23 | tolerance: 0.99999 24 | 25 | # Test function 26 | test_function: "log(x + 1) + log(x^2 + 1)" 27 | num_variables: 1 28 | 29 | # Total iterations 30 | iterations: 10 -------------------------------------------------------------------------------- /conf/experiment/function/nguyen/nguyen8.yaml: -------------------------------------------------------------------------------- 1 | name: "nguyen8" 2 | group: "nguyen" 3 | 4 | # Train points 5 | train_points: 6 | generate_points: false 7 | data_folder: "data/nguyen/nguyen8" 8 | # This part is only used if generate_points is true 9 | random_points: true 10 | min_points: 0 11 | max_points: 4 12 | num_points: 20 13 | xs_noise_std: 0 14 | ys_noise_std: 0 15 | add_extremes: true # Manually adds points at the extremes of the interval 16 | 17 | # Test points 18 | test_points: 19 | min_points: 0 20 | max_points: 4 21 | num_points: 200 22 | 23 | tolerance: 0.99999 24 | 25 | # Test function 26 | test_function: "sqrt(x)" 27 | num_variables: 1 28 | 29 | # Total iterations 30 | iterations: 10 -------------------------------------------------------------------------------- /conf/experiment/function/nguyen/nguyen9.yaml: -------------------------------------------------------------------------------- 1 | name: "nguyen9" 2 | group: "nguyen" 3 | 4 | # Train points 5 | train_points: 6 | generate_points: false 7 | data_folder: "data/nguyen/nguyen9" 8 | # This part is only used if generate_points is true 9 | random_points: true 10 | min_points: [-1, -1] 11 | max_points: [1, 1] 12 | num_points: 100 13 | xs_noise_std: 0 14 | ys_noise_std: 0 15 | add_extremes: true # Manually adds points at the extremes of the interval 16 | 17 | # Test points 18 | test_points: 19 | min_points: [-1, -1] 20 | max_points: [1, 1] 21 | num_points: 500 22 | 23 | tolerance: 0.99999 24 | 25 | # Test function 26 | test_function: "sin(x1) + sin(x2^2)" 27 | num_variables: 2 28 | 29 | # Total iterations 30 | iterations: 10 -------------------------------------------------------------------------------- /conf/experiment/scorer/basic_scorer.yaml: -------------------------------------------------------------------------------- 1 | name: basic_scorer 2 | rounding: 8 3 | scientific: false 4 | normalize: false 5 | lambda: 0 -------------------------------------------------------------------------------- 
/conf/experiment/scorer/complexity_scorer.yaml: -------------------------------------------------------------------------------- 1 | name: complexity_scorer 2 | rounding: 8 3 | scientific: false 4 | normalize: false 5 | lambda: 0.05 6 | max_nodes: 30 7 | alternative: false # Use alternative scorer with custom function -------------------------------------------------------------------------------- /conf/experiment/scorer/minmax_scorer.yaml: -------------------------------------------------------------------------------- 1 | name: minmax_scorer 2 | rounding: 5 3 | scientific: false 4 | normalize: true 5 | 6 | # Normalization 7 | min_score: 1 8 | max_score: 10 -------------------------------------------------------------------------------- /conf/experiment/seed_functions/2D_linear.yaml: -------------------------------------------------------------------------------- 1 | # Initial functions 2 | functions: ["2x1 + - 3x2 + 1", "-6x1 + 4x2 - 4", "3x1 + x2 + 2", "-10x1 - 14x2 + 21", "x1 + x2", "-x1 + x2", "-3 - 4x1", "x2 - 6"] -------------------------------------------------------------------------------- /conf/experiment/seed_functions/2D_mixed.yaml: -------------------------------------------------------------------------------- 1 | # Initial functions 2 | functions: ["2x1 + 3x2 + 1", "-6x1 + x2 - 4", "3x1^2 - 7x2^2 + 2x1 + x2 + 1", "-2x1^2 + 3x1 - x2 - 6", "4x1^3 + x2^3 + 3x1^2 - 5x2^2 + 2x1 + 1", "sin(x1) - cos(x2) + 2x1 + 1", "4cos(x1) + 2x1 + x2^2 + 1", "-4x2^2 + 2x1 - 3", "2x1^2 + 5x2^2-7x1+5x2-6"] -------------------------------------------------------------------------------- /conf/experiment/seed_functions/coefficients.yaml: -------------------------------------------------------------------------------- 1 | # Initial functions 2 | functions: [c0*x1^3 + c1*x1, c0*x1 + c1*log(x1), c0*sin(x1) + c1*cos(x1), c0*exp(x1)*x1, c0*x1/x1^2 + c1, c0*x1^4 + c1*x1^3] -------------------------------------------------------------------------------- /conf/experiment/seed_functions/generate.yaml: -------------------------------------------------------------------------------- 1 | functions: [] 2 | max_tries: 10 3 | generation_tokens: 512 -------------------------------------------------------------------------------- /conf/experiment/seed_functions/linear.yaml: -------------------------------------------------------------------------------- 1 | # Initial functions 2 | functions: ["2x + 1", "-6x - 4", "3x + 2", "-10x + 21", "x", "-x", "-3 - 4x", "13x - 6"] -------------------------------------------------------------------------------- /conf/experiment/seed_functions/mixed.yaml: -------------------------------------------------------------------------------- 1 | # Initial functions 2 | functions: ["2x + 1", "-6x - 4", "3x^2 + 2x + 1", "-2x^2 + 3x - 6", "4x^3 + 3x^2 + 2x + 1", "sin(x) + 2x + 1", "4cos(x) + 2x + 1", "-4x^3 + 2x - 3"] -------------------------------------------------------------------------------- /conf/experiment/standard.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - function: nguyen/nguyen1 3 | - seed_functions: generate 4 | - scorer: complexity_scorer 5 | - _self_ 6 | 7 | # Seed functions 8 | generate_seed_functions: true 9 | 10 | # Optimizer 11 | optimizer: 12 | optimizer_threads: 5 # Number of threads to use for the optimizer (each with a different initial value for the coefficients) 13 | timeout: 10 # Timeout in seconds, after which the optimizer will stop 14 | p0_min: -5 # Lower bound for the initial value of the 
coefficients 15 | p0_max: 5 # Upper bound for the initial value of the coefficients 16 | coeff_rounding: 4 # Number of decimal places to round the coefficients to 17 | tol: 1e-3 # Tolerance for the optimizer, under which the coefficients are considered to be zero -------------------------------------------------------------------------------- /conf/logger/default_logger.yaml: -------------------------------------------------------------------------------- 1 | loggers: ["console", "file"] 2 | level: "INFO" 3 | run_id: ??? -------------------------------------------------------------------------------- /conf/model/base_prompt/basic_image.yaml: -------------------------------------------------------------------------------- 1 | prompt: basic_image.txt 2 | 3 | # Type of input image (points, best_guess, all_guesses) 4 | input_image: points 5 | 6 | # Number of functions to keep in the prompt 7 | prompt_size: 5 -------------------------------------------------------------------------------- /conf/model/base_prompt/basic_mixed.yaml: -------------------------------------------------------------------------------- 1 | prompt: basic_mixed.txt 2 | 3 | # Type of input image (points, best_guess, all_guesses) 4 | input_image: points 5 | 6 | # Number of functions to keep in the prompt 7 | prompt_size: 5 -------------------------------------------------------------------------------- /conf/model/base_prompt/basic_text.yaml: -------------------------------------------------------------------------------- 1 | prompt: basic_text.txt 2 | 3 | # Number of functions to keep in the prompt 4 | prompt_size: 5 -------------------------------------------------------------------------------- /conf/model/base_prompt/image_all.yaml: -------------------------------------------------------------------------------- 1 | prompt: image_all.txt 2 | 3 | # Type of input image (points, best_guess, all_guesses) 4 | input_image: all_guesses 5 | 6 | # Number of functions to keep in the prompt 7 | prompt_size: 5 -------------------------------------------------------------------------------- /conf/model/base_prompt/image_best.yaml: -------------------------------------------------------------------------------- 1 | prompt: image_best.txt 2 | 3 | # Type of input image (points, best_guess, all_guesses) 4 | input_image: best_guess 5 | 6 | # Number of functions to keep in the prompt 7 | prompt_size: 5 -------------------------------------------------------------------------------- /conf/model/base_prompt/no_info.yaml: -------------------------------------------------------------------------------- 1 | prompt: "no_info.txt" 2 | 3 | # Number of functions to keep in the prompt 4 | prompt_size: 0 -------------------------------------------------------------------------------- /conf/model/gpt3.5-turbo.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base_prompt: basic_text 3 | - _self_ 4 | 5 | # General 6 | name: gpt-3.5-turbo 7 | tokenizer_pad: \\[PAD\\] 8 | tokenizer_padding_side: left 9 | visual: false 10 | cache_dir: '' 11 | 12 | # Seed functions prompt 13 | seed_function_prompt: seed_functions/generate_seed.txt 14 | 15 | # Sampling 16 | max_new_tokens: 2048 17 | top_p: 0.90 18 | top_k: 60 19 | num_beams: 1 20 | 21 | # Sampling temperature 22 | temperature: 1.0 23 | temperature_schedule: false 24 | temperature_schedule_gamma: 0.995 25 | 26 | # OpenAI API 27 | api_key_path: openai/openai_key 28 | organization_id_path: openai/openai_org 
-------------------------------------------------------------------------------- /conf/model/gpt4o.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base_prompt: basic_text 3 | - _self_ 4 | 5 | # General 6 | name: gpt-4o 7 | tokenizer_pad: \\[PAD\\] 8 | tokenizer_padding_side: left 9 | visual: false 10 | cache_dir: '' 11 | 12 | # Seed functions prompt 13 | seed_function_prompt: seed_functions/generate_seed.txt 14 | 15 | # Sampling 16 | max_new_tokens: 2048 17 | top_p: 0.90 18 | top_k: 60 19 | num_beams: 1 20 | 21 | # Sampling temperature 22 | temperature: 1.0 23 | temperature_schedule: false 24 | temperature_schedule_gamma: 0.995 25 | 26 | # OpenAI API 27 | api_key_path: openai/openai_key 28 | organization_id_path: openai/openai_org -------------------------------------------------------------------------------- /conf/model/llama3-8b.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base_prompt: basic_text 3 | - _self_ 4 | 5 | # General 6 | name: meta-llama/Meta-Llama-3-8B-Instruct 7 | tokenizer_pad: \\[PAD\\] 8 | tokenizer_padding_side: left 9 | visual: false 10 | cache_dir: '' 11 | 12 | # Seed functions prompt 13 | seed_function_prompt: seed_functions/generate_seed.txt 14 | 15 | # Sampling 16 | max_new_tokens: 512 17 | top_p: 0.90 18 | top_k: 60 19 | num_beams: 1 20 | 21 | # Sampling temperature 22 | temperature: 1.0 23 | temperature_schedule: false 24 | temperature_schedule_gamma: 0.995 -------------------------------------------------------------------------------- /conf/model/llava1.6-34b.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base_prompt: image_best 3 | - _self_ 4 | 5 | # General 6 | name: llava-hf/llava-v1.6-34b-hf 7 | tokenizer_pad: \\[PAD\\] 8 | tokenizer_padding_side: left 9 | visual: true 10 | cache_dir: '' 11 | 12 | # Seed functions prompt 13 | seed_function_prompt: seed_functions/generate_seed_image.txt 14 | 15 | # Sampling 16 | max_new_tokens: 512 17 | top_p: 0.90 18 | top_k: 60 19 | num_beams: 1 20 | 21 | # Sampling temperature 22 | temperature: 1.0 23 | temperature_schedule: false 24 | temperature_schedule_gamma: 0.995 -------------------------------------------------------------------------------- /current_functions.py: -------------------------------------------------------------------------------- 1 | import re 2 | import sys 3 | import sympy 4 | import utils 5 | import numpy as np 6 | 7 | from typing import Any, Dict 8 | 9 | class CurrentFunctions(object): 10 | """ 11 | Helper class to manage the current functions in the prompt. 12 | """ 13 | 14 | def __init__(self, seed_functions, scorer, optimizer, context_len, logger, num_variables) -> None: 15 | """ 16 | Initialize the class. 17 | 18 | Parameters 19 | ---------- 20 | seed_functions -> the seed functions to use. 21 | scorer -> the scorer to use. 22 | context_len -> the length of the context. 23 | logger -> the logger to use. 24 | num_variables -> the number of variables. 
25 | """ 26 | 27 | self.seed_functions = seed_functions 28 | self.scorer = scorer 29 | self.optimizer = optimizer 30 | self.context_len = context_len 31 | self.logger = logger 32 | self.num_variables = num_variables 33 | functions = [utils.string_to_function(name, self.num_variables) for name in self.seed_functions.keys()] 34 | self.logger.info(f"Seed functions: {functions}.") 35 | self.functions = {} 36 | self.scores = {} 37 | self.norm_scores = {} 38 | self.screen_names = {} 39 | 40 | # Optimize seed function coefficients 41 | for function in functions: 42 | try: 43 | optimized_function, coeff_function = self.optimizer.optimize(function, return_coeff=True, quiet=False) 44 | if self.func_in_list(coeff_function): 45 | self.logger.warning(f"Function {coeff_function} already in prompt.") 46 | continue 47 | self.functions[coeff_function] = optimized_function 48 | self.logger.info(f"Optimized seed function: {str(coeff_function)}.") 49 | except Exception as e: 50 | self.logger.warning(f"Could not optimize function {function}. {e}") 51 | pass 52 | self.logger.info(f"Optimized seed functions: {self.functions}.") 53 | if len(self.functions) == 0: 54 | self.logger.warning("Failed to optimize all seed functions. Function list will be empty.") 55 | else: 56 | self.scores, self.norm_scores = self.scorer.score_current_functions(self.functions) 57 | self.clean_scores() 58 | self.screen_names = {function: re.sub(r'c\d+', 'c', str(function)) for function in self.functions} 59 | 60 | self.logger.info(f"Current scores: {self.scores}.") 61 | self.logger.info(f"Current normalized scores: {self.norm_scores}.") 62 | 63 | def func_in_list(self, function: Any) -> bool: 64 | """ 65 | Checks if a function is already in the prompt by assigning the same symbol to all coefficients. 66 | 67 | Parameters 68 | ---------- 69 | function -> the function to check. 70 | 71 | Returns 72 | ------- 73 | bool -> whether the function is already in the prompt or not. 74 | """ 75 | symbols = set(function.free_symbols) 76 | for f in self.functions: 77 | symbols = symbols | set(f.free_symbols) 78 | coeffs = [s for s in symbols if str(s).startswith("c")] 79 | subs = {c: sympy.Symbol('c') for c in coeffs} 80 | function = function.subs(subs) 81 | for f in self.functions: 82 | f = f.subs(subs) 83 | if utils.func_equals(f, function, self.num_variables): 84 | return True 85 | return False 86 | 87 | def clean_scores(self) -> None: 88 | """ 89 | Remove eventual inf scores from the scores. 90 | """ 91 | print(f"Started cleaning scores {self.scores}.") 92 | removals = [] 93 | removals = [function for function in self.scores if self.scores[function] == np.inf] 94 | removals += [function for function in self.norm_scores if self.norm_scores[function] == np.inf and function not in removals] 95 | 96 | for function in removals: 97 | self.logger.warning(f"Removing function {function} with score {self.scores[function]} ({self.norm_scores[function]}) from the prompt.") 98 | self.functions.pop(function) 99 | self.scores.pop(function) 100 | self.norm_scores.pop(function) 101 | 102 | print(f"Finished cleaning scores {self.scores}.") 103 | 104 | def add_function(self, expr: Any, function: Any) -> None: 105 | """ 106 | Adds a function to the current functions. 107 | 108 | Parameters 109 | ---------- 110 | expr -> the coefficient form of the function. 111 | function -> the function to add. 
112 | """ 113 | self.logger.info(f"Adding function {expr}.") 114 | 115 | # Check if the function is already in the prompt, necessary if force_unique is False 116 | if self.func_in_list(expr): 117 | self.logger.info(f"Function {expr} already in prompt.") 118 | return 119 | 120 | if len(self.scores) >= self.context_len and self.scorer.score(function) > np.max(list(self.scores.values())): 121 | self.logger.info(f"Function {expr} has score {self.scorer.score(function)}, which is higher than the current worst score {np.max(list(self.scores.values()))}.") 122 | return 123 | 124 | self.functions[expr] = function 125 | self.screen_names[expr] = re.sub(r'c\d+', 'c', str(expr)) 126 | self.scores, self.norm_scores = self.scorer.score_current_functions(self.functions) 127 | self.clean_scores() 128 | 129 | # Remove the worst function if the context is full 130 | if len(self.functions) > self.context_len: 131 | worst_function = sorted(self.scores.items(), key=lambda x: x[1], reverse=True)[0][0] 132 | self.logger.info(f"Removing function {worst_function}.") 133 | self.functions.pop(worst_function) 134 | self.screen_names.pop(worst_function) 135 | self.scores.pop(worst_function) 136 | self.norm_scores.pop(worst_function) 137 | 138 | self.logger.info(f"Current scores: {self.scores}.") 139 | self.logger.info(f"Current normalized scores: {self.norm_scores}.") 140 | 141 | def get_best_function(self, return_coeff: bool = True) -> str: 142 | """ 143 | Gets the best function in the current functions. 144 | 145 | Returns 146 | ------- 147 | best_function -> the best function in the current functions. 148 | return_coeff -> whether to return the function in coefficient form. 149 | """ 150 | best_function = sorted(self.scores.items(), key=lambda x: x[1])[0][0] 151 | if return_coeff: 152 | return best_function 153 | else: 154 | return self.functions[best_function] 155 | 156 | def get_prompt_functions(self) -> Dict[str, float]: 157 | """ 158 | Gets the prompt functions (from the normalized scores) 159 | 160 | Returns 161 | ------- 162 | prompt_functions -> the current functions. 163 | """ 164 | top_functions = sorted(self.norm_scores.items(), key=lambda x: x[1]) 165 | top_functions = top_functions[:self.context_len] 166 | top_functions = sorted(top_functions, key=lambda x: x[1], reverse=True) 167 | return top_functions 168 | 169 | def get_prompt(self, base_prompt: str) -> str: 170 | """ 171 | Generates a prompt for the model, by appending the current functions and their scores to a base prompt. 172 | 173 | Parameters 174 | ---------- 175 | base_prompt -> the base prompt to append to. 176 | 177 | Returns 178 | ------- 179 | prompt -> the prompt to use for the model. 
180 | """ 181 | top_functions = self.get_prompt_functions() 182 | functions = "\n".join([f'Function: {self.screen_names[function_name]}\nError: {fit}\n' for function_name, fit in top_functions]) 183 | functions += "\nNew Functions: " 184 | prompt = base_prompt.format(functions=functions) 185 | return prompt -------------------------------------------------------------------------------- /data/R/R1/test_points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/R/R1/test_points.npy -------------------------------------------------------------------------------- /data/R/R1/train_points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/R/R1/train_points.npy -------------------------------------------------------------------------------- /data/R/R2/test_points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/R/R2/test_points.npy -------------------------------------------------------------------------------- /data/R/R2/train_points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/R/R2/train_points.npy -------------------------------------------------------------------------------- /data/R/R3/test_points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/R/R3/test_points.npy -------------------------------------------------------------------------------- /data/R/R3/train_points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/R/R3/train_points.npy -------------------------------------------------------------------------------- /data/constant/constant1/test_points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/constant/constant1/test_points.npy -------------------------------------------------------------------------------- /data/constant/constant1/train_points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/constant/constant1/train_points.npy -------------------------------------------------------------------------------- /data/constant/constant2/test_points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/constant/constant2/test_points.npy -------------------------------------------------------------------------------- /data/constant/constant2/train_points.npy: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/constant/constant2/train_points.npy -------------------------------------------------------------------------------- /data/constant/constant3/test_points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/constant/constant3/test_points.npy -------------------------------------------------------------------------------- /data/constant/constant3/train_points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/constant/constant3/train_points.npy -------------------------------------------------------------------------------- /data/constant/constant4/test_points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/constant/constant4/test_points.npy -------------------------------------------------------------------------------- /data/constant/constant4/train_points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/constant/constant4/train_points.npy -------------------------------------------------------------------------------- /data/constant/constant5/test_points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/constant/constant5/test_points.npy -------------------------------------------------------------------------------- /data/constant/constant5/train_points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/constant/constant5/train_points.npy -------------------------------------------------------------------------------- /data/constant/constant6/test_points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/constant/constant6/test_points.npy -------------------------------------------------------------------------------- /data/constant/constant6/train_points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/constant/constant6/train_points.npy -------------------------------------------------------------------------------- /data/constant/constant7/test_points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/constant/constant7/test_points.npy -------------------------------------------------------------------------------- 
/data/constant/constant7/train_points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/constant/constant7/train_points.npy -------------------------------------------------------------------------------- /data/constant/constant8/test_points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/constant/constant8/test_points.npy -------------------------------------------------------------------------------- /data/constant/constant8/train_points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/constant/constant8/train_points.npy -------------------------------------------------------------------------------- /data/keijzer/keijzer10/test_points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/keijzer/keijzer10/test_points.npy -------------------------------------------------------------------------------- /data/keijzer/keijzer10/train_points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/keijzer/keijzer10/train_points.npy -------------------------------------------------------------------------------- /data/keijzer/keijzer11/test_points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/keijzer/keijzer11/test_points.npy -------------------------------------------------------------------------------- /data/keijzer/keijzer11/train_points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/keijzer/keijzer11/train_points.npy -------------------------------------------------------------------------------- /data/keijzer/keijzer12/test_points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/keijzer/keijzer12/test_points.npy -------------------------------------------------------------------------------- /data/keijzer/keijzer12/train_points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/keijzer/keijzer12/train_points.npy -------------------------------------------------------------------------------- /data/keijzer/keijzer13/test_points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/keijzer/keijzer13/test_points.npy 
-------------------------------------------------------------------------------- /data/keijzer/keijzer13/train_points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/keijzer/keijzer13/train_points.npy -------------------------------------------------------------------------------- /data/keijzer/keijzer14/test_points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/keijzer/keijzer14/test_points.npy -------------------------------------------------------------------------------- /data/keijzer/keijzer14/train_points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/keijzer/keijzer14/train_points.npy -------------------------------------------------------------------------------- /data/keijzer/keijzer15/test_points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/keijzer/keijzer15/test_points.npy -------------------------------------------------------------------------------- /data/keijzer/keijzer15/train_points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/keijzer/keijzer15/train_points.npy -------------------------------------------------------------------------------- /data/keijzer/keijzer3/test_points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/keijzer/keijzer3/test_points.npy -------------------------------------------------------------------------------- /data/keijzer/keijzer3/train_points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/keijzer/keijzer3/train_points.npy -------------------------------------------------------------------------------- /data/keijzer/keijzer4/test_points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/keijzer/keijzer4/test_points.npy -------------------------------------------------------------------------------- /data/keijzer/keijzer4/train_points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/keijzer/keijzer4/train_points.npy -------------------------------------------------------------------------------- /data/keijzer/keijzer6/test_points.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/keijzer/keijzer6/test_points.npy -------------------------------------------------------------------------------- /data/keijzer/keijzer6/train_points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/keijzer/keijzer6/train_points.npy -------------------------------------------------------------------------------- /data/keijzer/keijzer7/test_points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/keijzer/keijzer7/test_points.npy -------------------------------------------------------------------------------- /data/keijzer/keijzer7/train_points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/keijzer/keijzer7/train_points.npy -------------------------------------------------------------------------------- /data/keijzer/keijzer8/test_points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/keijzer/keijzer8/test_points.npy -------------------------------------------------------------------------------- /data/keijzer/keijzer8/train_points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/keijzer/keijzer8/train_points.npy -------------------------------------------------------------------------------- /data/keijzer/keijzer9/test_points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/keijzer/keijzer9/test_points.npy -------------------------------------------------------------------------------- /data/keijzer/keijzer9/train_points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/keijzer/keijzer9/train_points.npy -------------------------------------------------------------------------------- /data/nguyen/nguyen1/test_points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/nguyen/nguyen1/test_points.npy -------------------------------------------------------------------------------- /data/nguyen/nguyen1/train_points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/nguyen/nguyen1/train_points.npy -------------------------------------------------------------------------------- /data/nguyen/nguyen10/test_points.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/nguyen/nguyen10/test_points.npy -------------------------------------------------------------------------------- /data/nguyen/nguyen10/train_points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/nguyen/nguyen10/train_points.npy -------------------------------------------------------------------------------- /data/nguyen/nguyen11/test_points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/nguyen/nguyen11/test_points.npy -------------------------------------------------------------------------------- /data/nguyen/nguyen11/train_points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/nguyen/nguyen11/train_points.npy -------------------------------------------------------------------------------- /data/nguyen/nguyen12/test_points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/nguyen/nguyen12/test_points.npy -------------------------------------------------------------------------------- /data/nguyen/nguyen12/train_points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/nguyen/nguyen12/train_points.npy -------------------------------------------------------------------------------- /data/nguyen/nguyen2/test_points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/nguyen/nguyen2/test_points.npy -------------------------------------------------------------------------------- /data/nguyen/nguyen2/train_points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/nguyen/nguyen2/train_points.npy -------------------------------------------------------------------------------- /data/nguyen/nguyen3/test_points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/nguyen/nguyen3/test_points.npy -------------------------------------------------------------------------------- /data/nguyen/nguyen3/train_points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/nguyen/nguyen3/train_points.npy -------------------------------------------------------------------------------- /data/nguyen/nguyen4/test_points.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/nguyen/nguyen4/test_points.npy -------------------------------------------------------------------------------- /data/nguyen/nguyen4/train_points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/nguyen/nguyen4/train_points.npy -------------------------------------------------------------------------------- /data/nguyen/nguyen5/test_points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/nguyen/nguyen5/test_points.npy -------------------------------------------------------------------------------- /data/nguyen/nguyen5/train_points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/nguyen/nguyen5/train_points.npy -------------------------------------------------------------------------------- /data/nguyen/nguyen6/test_points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/nguyen/nguyen6/test_points.npy -------------------------------------------------------------------------------- /data/nguyen/nguyen6/train_points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/nguyen/nguyen6/train_points.npy -------------------------------------------------------------------------------- /data/nguyen/nguyen7/test_points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/nguyen/nguyen7/test_points.npy -------------------------------------------------------------------------------- /data/nguyen/nguyen7/train_points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/nguyen/nguyen7/train_points.npy -------------------------------------------------------------------------------- /data/nguyen/nguyen8/test_points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/nguyen/nguyen8/test_points.npy -------------------------------------------------------------------------------- /data/nguyen/nguyen8/train_points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/nguyen/nguyen8/train_points.npy -------------------------------------------------------------------------------- /data/nguyen/nguyen9/test_points.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/nguyen/nguyen9/test_points.npy -------------------------------------------------------------------------------- /data/nguyen/nguyen9/train_points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/nguyen/nguyen9/train_points.npy -------------------------------------------------------------------------------- /download_model.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | import time 5 | 6 | from transformers import AutoTokenizer, AutoModelForCausalLM 7 | 8 | def main(): 9 | """ 10 | Download a pretrained model from HuggingFace and save it to the hf cache directory. 11 | """ 12 | parser = argparse.ArgumentParser(description="Download a pretrained model from HuggingFace.") 13 | parser.add_argument("model_name", type=str, help="The name of the model to download.") 14 | parser.add_argument("--hf_cache", type=str, default="models/cache/", help="The path to save the model.") 15 | parser.add_argument("--hf_token", type=str, default=None, help="Huggingface auth token.") 16 | 17 | args = parser.parse_args() 18 | dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16 19 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 20 | 21 | print(f"Downloading model {args.model_name} to {args.hf_cache}...") 22 | tokenizer = AutoTokenizer.from_pretrained(args.model_name, torch_dtype=dtype, cache_dir=args.hf_cache, token=args.hf_token, local_files_only=False) 23 | model = AutoModelForCausalLM.from_pretrained(args.model_name, torch_dtype=dtype, cache_dir=args.hf_cache, 24 | token=args.hf_token, local_files_only=False, device_map='auto') 25 | print(f"Model {args.model_name} downloaded and saved to {args.hf_cache}.") 26 | 27 | # print model vocab size 28 | print(f"Model vocab size: {tokenizer.vocab_size}") 29 | 30 | special_token_dict = tokenizer.special_tokens_map 31 | tokenizer.add_special_tokens(special_token_dict) 32 | model.resize_token_embeddings(len(tokenizer)) 33 | print(f"Model vocab size after resizing: {len(tokenizer)}") # tokenizer.vocab_size ignores added special tokens, so report len(tokenizer) instead 34 | 35 | print("Sample inference:") 36 | prompt = "Tell me a joke about Symbolic Regression."
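# Smoke test: time one sampled generation to verify that the downloaded weights actually run.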
37 | start = time.perf_counter() 38 | inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to(device) 39 | input_ids = inputs.input_ids 40 | attention_mask = inputs.attention_mask 41 | outputs = model.generate(input_ids, attention_mask=attention_mask, max_length=5000, do_sample=True, temperature=0.9, top_k=50, top_p=0.9, num_return_sequences=1) 42 | print(f"Prompt: {prompt}") 43 | output = tokenizer.decode(outputs[0], skip_special_tokens=True) 44 | output = output[len(prompt):] 45 | end = time.perf_counter() 46 | print(f"Output: {output}") 47 | 48 | print(f"Model was using device {model.device}") 49 | print(f"Model took {end-start:.2f} seconds to generate the output.") 50 | 51 | if __name__ == "__main__": 52 | main() -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import copy 5 | import time 6 | import datetime 7 | import warnings 8 | import signal 9 | import cProfile 10 | 11 | import hydra 12 | import torch 13 | import matplotlib.pyplot as plt 14 | import numpy as np 15 | import utils 16 | 17 | from transformers import set_seed 18 | from omegaconf import OmegaConf, DictConfig, listconfig 19 | from sklearn.metrics import r2_score 20 | 21 | from plotter import Plotter 22 | from optimizer import Optimizer 23 | from current_functions import CurrentFunctions 24 | from scorers import BasicScorer, MinMaxScorer, ComplexityScorer 25 | from mloggers import ConsoleLogger, FileLogger, MultiLogger, LogLevel 26 | 27 | from typing import Dict, Tuple, List, Any 28 | from collections.abc import Callable 29 | 30 | 31 | class Workspace(object): 32 | """ 33 | Workspace class for running the symbolic regression experiment. 
34 | """ 35 | def __init__(self, cfg: DictConfig) -> None: 36 | self.cfg = cfg 37 | 38 | # Output setup 39 | self.root_dir = cfg.get("root", os.getcwd()) 40 | self.output_dir = cfg.get("output_dir", "output") 41 | model_folder_name = cfg.model.name.strip() 42 | if "/" in model_folder_name: 43 | model_folder_name = model_folder_name.split("/")[-1] 44 | experiment_folder_name = os.path.join(cfg.experiment.function.group, cfg.experiment.function.name) if hasattr(cfg.experiment.function, "group") else cfg.experiment.function.name 45 | self.output_path = os.path.join(self.root_dir, self.output_dir, experiment_folder_name, model_folder_name, datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + "/") 46 | while os.path.exists(self.output_path): 47 | self.output_path = os.path.join(self.root_dir, self.output_dir, experiment_folder_name, model_folder_name, datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + "-" + str(np.random.randint(0, 1000)) + "/") 48 | os.makedirs(self.output_path) 49 | 50 | # Logger setup 51 | cfg.logger.run_id = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") 52 | loggers_list = cfg.logger.loggers 53 | log_level = LogLevel[cfg.logger.get("level", "INFO")] 54 | loggers = [] 55 | for logger in loggers_list: 56 | if logger == "console": 57 | loggers.append(ConsoleLogger(default_priority=log_level)) 58 | elif logger == "file": 59 | loggers.append(FileLogger(os.path.join(self.output_path, 'log.json'), default_priority=log_level)) 60 | elif logger == "": 61 | pass 62 | else: 63 | print(f'[WARNING] Logger "{logger}" is not supported') 64 | self.logger = MultiLogger(loggers, default_priority=log_level) 65 | self.logger.info(f"Project root: {self.root_dir}.") 66 | self.logger.info(f"Logging to {self.output_path}.") 67 | job_id = utils.get_job_id() 68 | self.logger.info(f"Slurm job ID: {job_id}.") if job_id is not None else None 69 | 70 | # Redirect warnings to logger 71 | warnings.filterwarnings("default") 72 | warnings.showwarning = lambda *args, **kwargs: self.logger.warning(str(args[0])) 73 | 74 | # RNG setup 75 | if not hasattr(cfg, "seed") or cfg.seed is None or cfg.seed == -1: 76 | self.cfg.seed = np.random.randint(0, np.iinfo(np.int32).max) 77 | self.logger.info(f"Seed not specified, using random seed: {self.cfg.seed}.") 78 | else: 79 | self.logger.info(f"Using seed: {self.cfg.seed}.") 80 | 81 | np.random.seed(self.cfg.seed) 82 | torch.manual_seed(self.cfg.seed) 83 | torch.cuda.manual_seed_all(self.cfg.seed) if torch.cuda.is_available() else None 84 | set_seed(self.cfg.seed) 85 | 86 | if torch.cuda.is_available(): 87 | torch.cuda.init() 88 | 89 | if cfg.device == "auto": 90 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 91 | else: 92 | self.device = torch.device(cfg.device) 93 | 94 | if cfg.get("use_bfloat16", False): 95 | self.dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16 96 | else: 97 | self.dtype = torch.float16 98 | 99 | self.logger.info(f"Using device: {self.device} with dtype: {self.dtype}.") 100 | if torch.cuda.is_available() and ('cuda' in cfg.device or 'auto' in cfg.device): 101 | self.logger.info(f"Device name: {torch.cuda.get_device_name()}.") 102 | 103 | self.cache_dir = self.cfg.model.get("cache_dir", os.environ.get("HF_HOME", None)) 104 | if self.cache_dir == "": 105 | self.cache_dir = os.environ.get("HF_HOME", None) 106 | 107 | if self.cache_dir is not None: 108 | os.environ['HF_HOME'] = self.cache_dir 109 | os.environ['TRANSFORMERS_CACHE'] = os.environ['HF_HOME'] 110 | 
self.logger.info(f"Cache dir: {os.environ.get('HF_HOME', None)}.") 111 | 112 | # Experiment settings 113 | self.data_folder = cfg.experiment.function.train_points.get("data_folder", None) 114 | self.data_folder = os.path.join(self.root_dir, self.data_folder) if self.data_folder is not None else None 115 | 116 | if cfg.experiment.function.train_points.generate_points: 117 | self.min_train_points = cfg.experiment.function.train_points.min_points 118 | self.max_train_points = cfg.experiment.function.train_points.max_points 119 | self.num_train_points = cfg.experiment.function.train_points.num_points 120 | self.xs_noise_std = cfg.experiment.function.train_points.xs_noise_std 121 | self.ys_noise_std = cfg.experiment.function.train_points.ys_noise_std 122 | self.random_train_points = self.cfg.experiment.function.train_points.random_points \ 123 | if hasattr(self.cfg.experiment.function.train_points, "random_points") \ 124 | and self.cfg.experiment.function.train_points.random_points else False 125 | else: 126 | assert self.data_folder is not None, "No data folder specified." 127 | assert os.path.exists(self.data_folder), f"Data folder {self.data_folder} does not exist." 128 | assert os.path.exists(os.path.join(self.data_folder, 'train_points.npy')), f"Train points file {os.path.join(self.data_folder, 'train_points.npy')} does not exist." 129 | assert os.path.exists(os.path.join(self.data_folder, 'test_points.npy')), f"Test points file {os.path.join(self.data_folder, 'test_points.npy')} does not exist." 130 | 131 | train_points_file = os.path.join(self.data_folder, 'train_points.npy') 132 | self.train_points = utils.load_points(train_points_file) 133 | self.num_train_points = len(self.train_points) 134 | self.min_train_points = np.min(self.train_points) 135 | self.max_train_points = np.max(self.train_points) 136 | self.logger.info(f"Loaded train points from {train_points_file}.") 137 | 138 | test_points_file = os.path.join(self.data_folder, 'test_points.npy') 139 | self.test_points = utils.load_points(test_points_file) 140 | self.num_test_points = len(self.test_points) 141 | self.min_test_points = np.min(self.test_points) 142 | self.max_test_points = np.max(self.test_points) 143 | self.logger.info(f"Loaded test points from {test_points_file}.") 144 | 145 | self.tolerance = cfg.experiment.function.tolerance 146 | self.num_variables = cfg.experiment.function.num_variables 147 | if self.num_variables > 2 and self.visual_model: 148 | self.logger.error("Visual models only support up to 2 variables.") 149 | exit(1) 150 | 151 | self.iterations = cfg.experiment.function.iterations 152 | self.max_retries = cfg.max_retries 153 | self.force_valid = cfg.force_valid 154 | self.force_unique = cfg.force_unique 155 | self.checkpoints = cfg.checkpoints 156 | 157 | if "test_function" not in cfg.experiment.function: 158 | self.logger.info("Test function is not known.") 159 | self.test_function = None 160 | else: 161 | self.test_function_name = cfg.experiment.function.test_function 162 | self.test_function = utils.string_to_function(self.test_function_name, self.num_variables) 163 | 164 | # Points setup 165 | if cfg.experiment.function.train_points.generate_points: 166 | add_extremes = cfg.experiment.function.train_points.add_extremes if hasattr(cfg.experiment.function.train_points, "add_extremes") else False 167 | self.train_points = self.generate_points(self.test_function, self.min_train_points, self.max_train_points, self.num_train_points, 168 | xs_noise_std=self.xs_noise_std, ys_noise_std=self.ys_noise_std, 169 
| random_points=self.random_train_points, save_fig=True, add_extremes=add_extremes) 170 | np.save(os.path.join(self.output_path, "train_points.npy"), self.train_points) 171 | 172 | self.min_test_points = cfg.experiment.function.test_points.min_points if hasattr(cfg.experiment.function.test_points, "min_points") else self.min_train_points 173 | self.max_test_points = cfg.experiment.function.test_points.max_points if hasattr(cfg.experiment.function.test_points, "max_points") else self.max_train_points 174 | self.num_test_points = cfg.experiment.function.test_points.num_points 175 | self.test_points = self.generate_points(self.test_function, self.min_test_points, self.max_test_points, self.num_test_points, random_points=False, save_fig=False) 176 | np.save(os.path.join(self.output_path, "test_points.npy"), self.test_points) 177 | 178 | if cfg.experiment.get("normalize_points", False): 179 | self.train_points = utils.normalize_points(self.train_points, cfg.experiment.normalize_method, cfg.experiment.normalize_percentile) 180 | self.logger.info(f"Normalized train points: {self.train_points}.") 181 | 182 | self.logger.info(f"Train points: {utils.array_to_string(self.train_points)}.") 183 | if self.num_test_points > 100: 184 | self.logger.info(f"Not logging test points as there are more than 100 ({self.num_test_points}).") 185 | else: 186 | self.logger.info(f"Test points: {utils.array_to_string(self.test_points)}.") 187 | 188 | # Optimizer settings 189 | self.optimizer = Optimizer(cfg, self.train_points, self.logger) 190 | 191 | # Plotter setup 192 | if self.num_variables > 2: 193 | self.logger.warning("Plotter will not plot points and animation as there are more than 2 variables.") 194 | self.save_frames = cfg.plotter.save_frames if hasattr(cfg.plotter, "save_frames") else False 195 | if self.save_frames: 196 | os.makedirs(self.output_path + "frames/") 197 | 198 | self.save_video = cfg.plotter.save_video if hasattr(cfg.plotter, "save_video") else True 199 | self.save_video = False if self.num_variables > 2 else self.save_video 200 | if self.save_video: 201 | plt.rcParams.update({'figure.max_open_warning': cfg.experiment.function.iterations + 5}) 202 | self.plotter = Plotter(cfg, self.train_points, self.test_points, self.output_path) 203 | self.plotter.plot_points(save_fig=True, plot_test=False) 204 | 205 | # Base prompt 206 | self.prompts_path = os.path.join(self.root_dir, cfg.prompts_path) 207 | self.prompt_size = cfg.model.base_prompt.prompt_size 208 | with open(os.path.join(self.prompts_path, "OPRO", cfg.model.base_prompt.prompt), "r") as f: 209 | self.base_prompt = f.read() 210 | 211 | self.prompt_points = utils.decimate_points(self.train_points, cfg.max_points_in_prompt) 212 | self.prompt_points = utils.array_to_string(self.prompt_points) 213 | if self.num_train_points > self.cfg.max_points_in_prompt: 214 | self.logger.info(f"Found {self.num_train_points} train points, decimated to {cfg.max_points_in_prompt} for the prompt.") 215 | self.logger.info(f"Prompt points: {self.prompt_points}.") 216 | 217 | self.base_prompt = self.base_prompt.format(points=self.prompt_points, num_variables=self.num_variables, 218 | variables_list=[f"x{i+1}" for i in range(self.num_variables)], functions="{functions}") 219 | 220 | # Model settings 221 | self.model_name = cfg.model.name 222 | self.model = None 223 | 224 | self.visual_model = cfg.model.visual 225 | if hasattr(cfg.model.base_prompt, "input_image"): 226 | self.input_img = cfg.model.base_prompt.input_image 227 | valid_inputs = ["points", 
"previous_guess", "best_guess", "all_guesses"] 228 | if self.input_img not in valid_inputs: 229 | self.logger.error(f"Input image {self.input_img} not supported. Valid inputs are {valid_inputs}.") 230 | exit(1) 231 | else: 232 | self.input_img = None 233 | 234 | self.logger.info(f"Base Prompt: {self.base_prompt} with input image {self.input_img}.") 235 | self.tokenizer_pad = cfg.model.tokenizer_pad 236 | self.tokenizer_padding_side = cfg.model.tokenizer_padding_side 237 | 238 | self.max_new_tokens = cfg.model.max_new_tokens 239 | self.top_p = cfg.model.top_p 240 | self.top_k = cfg.model.top_k 241 | self.num_beams = cfg.model.num_beams 242 | 243 | self.temperature = cfg.model.temperature 244 | if cfg.model.temperature_schedule: 245 | self.temperature_scheduler = torch.optim.lr_scheduler.ExponentialLR(torch.optim.Adam([torch.tensor(self.temperature)], lr=1), gamma=cfg.model.temperature_schedule_gamma) 246 | else: 247 | self.temperature_scheduler = None 248 | 249 | model_args = { 250 | "temperature": self.temperature, 251 | "top_p": self.top_p, 252 | "top_k": self.top_k, 253 | "num_beams": self.num_beams, 254 | "max_length": self.max_new_tokens, 255 | "min_length": 0, 256 | "tokenizer_pad": self.tokenizer_pad, 257 | "tokenizer_padding_side": self.tokenizer_padding_side, 258 | "seed": self.cfg.seed, 259 | "api_key_path": os.path.join(self.root_dir, cfg.model.api_key_path) if hasattr(cfg.model, "api_key_path") else None, 260 | "organization_id_path": os.path.join(self.root_dir, cfg.model.organization_id_path) if hasattr(cfg.model, "organization_id_path") else None, 261 | } 262 | if torch.cuda.is_available() and 'A100' in torch.cuda.get_device_name(0): 263 | model_args['attn_implementation'] = 'flash_attention_2' 264 | model_args['use_flash_attn'] = True 265 | self.logger.info("Using Flash Attention 2") 266 | 267 | self.model = utils.load_model(self.model_name, self.device, self.dtype, self.cache_dir, model_args) 268 | self.logger.info("Model loaded - {model_name}.".format(model_name=self.model_name)) 269 | 270 | # Scorer settings 271 | if "basic" in cfg.experiment.scorer.name.lower(): 272 | self.scorer = BasicScorer(self.train_points, rounding=cfg.experiment.scorer.rounding, scientific=cfg.experiment.scorer.scientific) 273 | self.test_scorer = BasicScorer(self.test_points, rounding=cfg.experiment.scorer.rounding, scientific=cfg.experiment.scorer.scientific) 274 | elif "minmax" in cfg.experiment.scorer.name.lower(): 275 | min_score = cfg.experiment.scorer.min_score 276 | max_score = cfg.experiment.scorer.max_score 277 | self.scorer = MinMaxScorer(self.train_points, min_score=min_score, max_score=max_score, rounding=cfg.experiment.scorer.rounding, scientific=cfg.experiment.scorer.scientific) 278 | self.test_scorer = MinMaxScorer(self.test_points, min_score=min_score, max_score=max_score, rounding=cfg.experiment.scorer.rounding, scientific=cfg.experiment.scorer.scientific) 279 | elif "complexity" in cfg.experiment.scorer.name.lower(): 280 | self.logger.info(f"Complexity scorer with lambda {cfg['experiment']['scorer']['lambda']} and max nodes {cfg.experiment.scorer.max_nodes}.") 281 | alternative = False 282 | if hasattr(cfg.experiment.scorer, "alternative") and cfg.experiment.scorer.alternative: 283 | alternative = True 284 | self.logger.info("Using alternative complexity scorer.") 285 | self.scorer = ComplexityScorer(self.train_points, rounding=cfg.experiment.scorer.rounding, scientific=cfg.experiment.scorer.scientific, lam=cfg['experiment']['scorer']['lambda'], 
max_nodes=cfg.experiment.scorer.max_nodes, alternative=alternative) 286 | self.test_scorer = ComplexityScorer(self.test_points, rounding=cfg.experiment.scorer.rounding, scientific=cfg.experiment.scorer.scientific, lam=cfg['experiment']['scorer']['lambda'], max_nodes=cfg.experiment.scorer.max_nodes, alternative=alternative) 287 | else: 288 | self.logger.error(f"Scorer {cfg.experiment.scorer.name} not supported.") 289 | exit(1) 290 | 291 | # Seed functions 292 | self.seed_functions = {} 293 | min_seed_functions = max(5, self.prompt_size) # If the prompt size is small (e.g. 1) we still want to generate a few seed functions to avoid getting stuck 294 | gen_time = 0 295 | if cfg.experiment.generate_seed_functions: 296 | self.seed_functions, gen_time = self.generate_seed_functions() 297 | assert len(self.seed_functions) >= min_seed_functions, f"Could not generate {min_seed_functions} seed functions. Generated {len(self.seed_functions)} seed functions." 298 | else: 299 | self.seed_functions = {name: utils.string_to_function(name, self.num_variables) for name in cfg.experiment.seed_functions.functions} 300 | self.logger.info(f"Loaded seed functions: {self.seed_functions}.") 301 | self.current_functions = CurrentFunctions(self.seed_functions, self.scorer, self.optimizer, self.prompt_size, self.logger, self.num_variables) 302 | self.logger.info(f"Current functions: {self.current_functions.functions}.") 303 | self.logger.info(f"Current scores: {self.current_functions.scores}.") 304 | 305 | if len(self.current_functions.functions) < self.prompt_size: 306 | self.logger.warning(f"Could not generate {self.prompt_size} seed functions. Generated {len(self.current_functions.functions)} seed functions.") 307 | if len(self.current_functions.functions) == 0: 308 | self.logger.error("No seed functions generated. Exiting.") 309 | exit(1) 310 | else: 311 | self.logger.info(f"Successfully generated {self.prompt_size} seed functions in {gen_time} seconds.") 312 | 313 | # Results json 314 | self.results = { 315 | "experiment_name": self.cfg.experiment.function.name, 316 | "seed": self.cfg.seed, 317 | "train_points": utils.array_to_string(self.train_points), 318 | "test_points": utils.array_to_string(self.test_points), 319 | "best_expr": "", 320 | "best_function": "", 321 | "scores": [], 322 | "R2_trains": [], 323 | "R2_tests": [], 324 | "R2_alls": [], 325 | "best_scores": [], 326 | "best_scores_normalized": [], 327 | "iterations": 0, 328 | "tries_per_iteration": [], 329 | "generations_per_iteration": [], 330 | "num_unique": len(self.current_functions.functions), 331 | "best_found_at": 0, 332 | "sympy_equivalent": False, 333 | "temperatures": [], 334 | "times": { 335 | "iteration": [], 336 | "seed_function_generation": gen_time, 337 | "generation_per_iteration": [], 338 | "optimization_per_iteration": [], 339 | } 340 | } 341 | if "test_function" in self.cfg.experiment.function: 342 | self.results["test_function"] = self.cfg.experiment.function.test_function 343 | 344 | # Save config 345 | with open(self.output_path + "config.yaml", "w") as f: 346 | OmegaConf.save(self.cfg, f) 347 | 348 | def generate_points(self, function: Callable, min_points: float, max_points: float, num: int, xs_noise_std: float = 0, ys_noise_std: float = 0, 349 | random_points: bool = False, add_extremes: bool = True) -> np.ndarray: 350 | """ 351 | Generates points from a given function, with optional noise. 352 | 353 | Parameters 354 | ---------- 355 | function -> the function to generate points from.
356 | min_points -> the minimum value of the points to generate. 357 | max_points -> the maximum value of the points to generate. 358 | num -> the number of points to generate. 359 | xs_noise_std -> the standard deviation of the noise to add to the xs. 360 | ys_noise_std -> the standard deviation of the noise to add to the ys. 361 | random_points -> whether to generate random points instead of a grid/meshgrid. 362 | add_extremes -> whether to add points at the extreme values of the interval manually to ensure they are included. 363 | 364 | Returns 365 | ------- 366 | points -> the generated points as a NumPy array of shape (num, num_variables + 1). 367 | """ 368 | min_value = copy.deepcopy(min_points) 369 | max_value = copy.deepcopy(max_points) 370 | if not isinstance(min_points, (list, listconfig.ListConfig)): 371 | min_points = [min_points] * self.num_variables 372 | if not isinstance(max_points, (list, listconfig.ListConfig)): 373 | max_points = [max_points] * self.num_variables 374 | min_points = np.array(min_points, dtype=np.float32) 375 | max_points = np.array(max_points, dtype=np.float32) 376 | 377 | points_per_dim = int(np.floor(num**(1/self.num_variables))) 378 | self.logger.info(f"Generating {points_per_dim} points per dimension for a total of {points_per_dim**self.num_variables} points.") 379 | 380 | if random_points: 381 | # Add points at the extreme values of the interval manually to ensure they are included 382 | # This depends on the number of dimensions 383 | # For example, in 1D if the interval is [0, 1] we need to add points at 0 and 1 384 | # In 2D, if the interval is [(0, 0), (1, 1)] we need to add points at (0, 0), (0, 1), (1, 0), (1, 1) 385 | if add_extremes: 386 | variable_ranges = np.array([[min_points[i], max_points[i]] for i in range(self.num_variables)]) 387 | extreme_points = np.array(np.meshgrid(*variable_ranges)).T.reshape(-1, self.num_variables) 388 | self.logger.info(f"Adding {len(extreme_points)} extreme points ({extreme_points}).") 389 | 390 | # Reshape min and max points to match the random shape.
Currently min and max are of shape (num_variables,), so we need to add n dimensions of size points_per_dim by copying the min and max values 391 | # For example, if min is [0, 1] and num_variables is 2 and points_per_dim is 3, we need to reshape min to an array of shape (2, 3, 3) with all values being 0 and 1 across the last dimension 392 | random_shape = tuple([self.num_variables, *([points_per_dim] * self.num_variables)]) 393 | min_points = np.expand_dims(min_points, axis=tuple(range(1, self.num_variables + 1))) 394 | max_points = np.expand_dims(max_points, axis=tuple(range(1, self.num_variables + 1))) 395 | max_points += 1e-10 # Add small eps to max_points as the rightmost value is not included in np.random.uniform 396 | for i in range(1, self.num_variables+1): 397 | min_points = np.repeat(min_points, points_per_dim, axis=i) 398 | max_points = np.repeat(max_points, points_per_dim, axis=i) 399 | Xs = np.random.uniform(min_points, max_points, random_shape) 400 | 401 | else: 402 | Xs = np.meshgrid(*[np.linspace(min_points[i], max_points[i], points_per_dim) for i in range(self.num_variables)]) 403 | Xs = np.array(Xs) 404 | if xs_noise_std: 405 | Xs += np.random.normal(0, xs_noise_std, Xs.shape) 406 | pts = np.array(list(zip(*[x.flat for x in Xs]))) 407 | 408 | ys = utils.eval_function(function, pts, self.num_variables).T 409 | 410 | if ys_noise_std: 411 | ys += np.random.normal(0, ys_noise_std, ys.shape) 412 | 413 | if random_points and add_extremes: 414 | pts = np.concatenate((extreme_points, pts)) 415 | extreme_ys = utils.eval_function(function, extreme_points, self.num_variables).T 416 | ys = np.concatenate((extreme_ys, ys)) 417 | 418 | points = np.concatenate((pts, ys.reshape(-1, 1)), axis=1) 419 | if add_extremes and len(points) > num: 420 | # Remove random points to account for the extra extremes 421 | # The points are sampled randomly so removing from the end is the same as removing random indices 422 | self.logger.info(f"Removing {len(points)-num} randomly generated points: {points[num:]}") 423 | points = points[:num] 424 | while any(np.isinf(points[:, -1])): 425 | # Remove points where the function is infinite 426 | inf_indices = np.where(np.isinf(points)) 427 | self.logger.info(f"Removing {len(inf_indices[0])} points where the function is infinite.") 428 | points = np.delete(points, inf_indices[0], axis=0) 429 | 430 | if len(points) < num and random_points: 431 | # Generate new points to replace the infinite ones 432 | self.logger.info(f"Recursively generating {num-len(points)} new points.") 433 | new_points = self.generate_points(function, min_value, max_value, num-len(points), xs_noise_std, ys_noise_std, random_points, add_extremes=False) 434 | points = np.concatenate((points, new_points)) 435 | 436 | return points 437 | 438 | def generate_seed_functions(self) -> Tuple[Dict[str, Any], float]: 439 | """ 440 | Generates initial seed functions for the experiment. 441 | 442 | Parameters 443 | ---------- 444 | 445 | Returns 446 | ------- 447 | seed_functions -> the generated seed functions. 448 | gen_time -> the time it took to generate the seed functions.
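Generation is retried up to max_tries times; every valid function parsed from the model output is kept.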
438 | def generate_seed_functions(self) -> Tuple[Dict[str, Any], float]: 439 | """ 440 | Generates initial seed functions for the experiment. 441 | 442 | Parameters 443 | ---------- 444 | 445 | Returns 446 | ------- 447 | seed_functions -> the generated seed functions. 448 | gen_time -> the time it took to generate the seed functions. 449 | """ 450 | generation_tokens = self.cfg.experiment.seed_functions.generation_tokens if hasattr(self.cfg.experiment.seed_functions, "generation_tokens") else 512 451 | max_tries = self.cfg.experiment.seed_functions.max_tries if hasattr(self.cfg.experiment.seed_functions, "max_tries") else 10 452 | seed_functions = {} 453 | 454 | seed_prompt = self.cfg.get("model").get("seed_function_prompt", None) 455 | assert seed_prompt is not None, "Seed function prompt not specified." 456 | seed_prompt = os.path.join(self.prompts_path, seed_prompt) 457 | 458 | with open(seed_prompt, "r") as f: 459 | prompt = f.read() 460 | img_path = os.path.join(self.output_path, "points.png") if self.input_img else None 461 | 462 | prompt = prompt.format(points=self.prompt_points, num_variables=self.num_variables, variables_list=[f"x{i+1}" for i in range(self.num_variables)]) 463 | self.logger.info("Prompt for seed function generation:") 464 | self.logger.info(prompt) 465 | 466 | start_time = time.perf_counter() 467 | with torch.inference_mode(): 468 | for i in range(max_tries): 469 | # Generate seed functions using the model 470 | self.logger.info(f"Attempt {i+1} of {max_tries} to generate seed functions.") 471 | seeds = self.model.generate(prompt, return_prompt=False, image_files=img_path, temperature=self.temperature, max_new_tokens=generation_tokens) 472 | self.logger.info("Model output for seed functions: " + seeds) 473 | 474 | # Parse model output 475 | for seed in seeds.split("\n"): 476 | if "x" not in seed: 477 | self.logger.info(f"Skipping line {seed} as it does not contain 'x' and is likely not a function.") 478 | continue 479 | if "Error" in seed: 480 | self.logger.info(f"Skipping line {seed} as it contains 'Error'.") 481 | continue 482 | seed = utils.clean_function(seed) 483 | self.logger.info(f"Seed function: {seed}.") 484 | if seed == "": 485 | continue 486 | try: 487 | valid, reason = utils.is_valid_function(seed, None, self.num_variables) 488 | self.logger.info(f"Function {seed}. Valid: {valid}. Reason: {reason}.") 489 | if valid: 490 | function = utils.string_to_function(seed, self.num_variables) 491 | seed_functions[seed] = function 492 | except Exception as e: 493 | self.logger.warning(f"Could not parse line {seed}.") 494 | self.logger.warning(str(e)) 495 | pass 496 | # We keep generating even when we already have enough seed functions, since some of them may turn out to be invalid after optimization 497 | # A better approach might be to optimize here directly and exit early once enough valid seed functions have been found 498 | end_time = time.perf_counter() 499 | 500 | self.logger.info(f"Generated seed functions: {seed_functions}.") 501 | return seed_functions, end_time - start_time 502 | 503 |
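# Example: a model reply following prompts/seed_functions/generate_seed.txt, such as "f1 = c*x1 + c" and "f2 = c*sin(c*x1)" on separate lines, yields two entries in seed_functions after cleaning and validation.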
504 | def get_new_function(self, prompt: str) -> Tuple[List, bool]: 505 | """ 506 | Generates a new function from the model, given a prompt. 507 | 508 | Parameters 509 | ---------- 510 | prompt -> the prompt to use for the model. 511 | 512 | Returns 513 | ------- 514 | functions -> the new functions generated by the model as a string. 515 | found_valid -> whether a valid function was found. 516 | """ 517 | img = None 518 | if self.visual_model: 519 | if self.input_img == "points": 520 | img = os.path.join(self.output_path, "points.png") 521 | elif self.input_img == "previous_guess": 522 | img = os.path.join(self.output_path, "frames", f"{self.results['iterations']-1}.png") 523 | elif self.input_img == "best_guess": 524 | fig, ax = self.plotter.plot_results(self.current_functions.get_best_function(return_coeff=False), self.test_function, plot_true=False) 525 | fig.savefig(self.output_path + "best_guess.png") 526 | plt.close(fig) 527 | img = os.path.join(self.output_path, "best_guess.png") 528 | elif self.input_img == "all_guesses": 529 | os.makedirs(self.output_path + "prompt_input/") 530 | path = os.path.join(self.output_path, "prompt_input/") 531 | img = [] 532 | functions = self.current_functions.get_prompt_functions() 533 | for expr, _ in functions: 534 | try: 535 | function = self.current_functions.functions[expr] 536 | fig, ax = self.plotter.plot_results(function, self.test_function, plot_true=False, label="Function: " + str(function)) 537 | function_string = str(expr) 538 | fig.suptitle("Plot of " + function_string) 539 | fig.text(0.5, 0.90, "Error: " + str(self.current_functions.norm_scores[expr]), ha='center') 540 | file_name = function_string.replace(" ", "_").replace("/", "div") 541 | fig.savefig(os.path.join(path, f"{file_name}.png")) 542 | plt.close(fig) 543 | img.append(os.path.join(path, f"{file_name}.png")) 544 | except Exception as e: 545 | self.logger.warning(f"Could not plot function {expr}.") 546 | self.logger.warning(str(e)) 547 | pass 548 | 549 | new_output = self.model.generate(prompt, return_prompt=False, image_files=img, temperature=self.temperature) 550 | self.logger.info("Prompt: " + prompt) 551 | self.logger.info("Model output: " + new_output) 552 | 553 | # Clean up images 554 | if self.visual_model and self.input_img == "best_guess": 555 | os.remove(self.output_path + "best_guess.png") 556 | elif self.visual_model and self.input_img == "all_guesses": 557 | for file in os.listdir(path): 558 | os.remove(os.path.join(path, file)) 559 | os.rmdir(path) 560 | 561 | functions = [] 562 | lines = new_output.split("\n") 563 | for line in lines: 564 | if "x" not in line: 565 | self.logger.info(f"Skipping line {line} as it does not contain 'x' and is likely not a function.") 566 | continue 567 | if "Error" in line: 568 | self.logger.info(f"Skipping line {line} as it contains 'Error'.") 569 | continue 570 | line = utils.clean_function(line) 571 | if line == "": 572 | continue 573 | self.logger.info("Cleaned line: " + line + ".") 574 | try: 575 | valid, reason = utils.is_valid_function(line, self.current_functions, self.num_variables) 576 | self.logger.info(f"Valid: {valid}. Reason: {reason}.") 577 | if valid: 578 | functions.append(line) 579 | elif not valid and reason == "Function already in prompt." and not self.force_unique: 580 | functions.append(line) 581 | except Exception as e: 582 | self.logger.warning(f"Could not parse line {line}.") 583 | self.logger.warning(str(e)) 584 | pass 585 |
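# Fallback: if no candidate line survives cleaning and validation, reuse the best function found so far rather than failing the iteration.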
586 | found_valid = False 587 | if len(functions) == 0: 588 | self.logger.warning("Could not find a valid function in the output. Falling back to the best function found so far.") 589 | functions = [self.current_functions.get_best_function()] 590 | else: 591 | found_valid = True 592 | functions = [utils.string_to_function(function, self.num_variables) for function in functions] 593 | self.logger.info(f"Found functions: {functions}.") 594 | 595 | return functions, found_valid 596 | 597 | def get_new_function_and_score(self, prompt: str) -> Tuple[Any, Any, float]: 598 | """ 599 | Generates a new function from the model, given a prompt, and scores it if it is valid. 600 | 601 | Parameters 602 | ---------- 603 | prompt -> the prompt to use for the model. 604 | 605 | Returns 606 | ------- 607 | expression -> the coefficient form of the function. 608 | function -> function with the optimized coefficients. 609 | score -> the score of the function. 610 | """ 611 | valid = False 612 | start_time = time.perf_counter() 613 | for i in range(self.max_retries): 614 | self.logger.info(f"Attempt {i+1} of {self.max_retries} to find a valid function.") 615 | functions, valid = self.get_new_function(prompt) 616 | if valid and len(functions) > 1: 617 | self.logger.info(f"Found {len(functions)} functions in the output.") 618 | break 619 | else: 620 | self.logger.info("Did not find enough valid functions in the output. Trying again.") 621 | 622 | self.results["tries_per_iteration"].append(i+1) 623 | self.results["times"]["generation_per_iteration"].append(time.perf_counter() - start_time) 624 | 625 | if not valid: 626 | if self.force_valid: 627 | self.logger.error(f"Could not find a valid function after {self.max_retries} tries. Exiting.") 628 | exit(1) 629 | else: 630 | best_expr = self.current_functions.get_best_function(return_coeff=True) 631 | best_function = self.current_functions.functions[best_expr] 632 | best_score = self.current_functions.scores[best_expr] 633 | self.logger.warning(f"Could not find a valid function after {self.max_retries} tries. Using {best_function}.") 634 | else: 635 | best_score = np.inf 636 | start_time = time.perf_counter() 637 | for function in functions: 638 | if not self.current_functions.func_in_list(function): 639 | self.results["num_unique"] += 1 640 | self.results["generations_per_iteration"].append(len(functions)) 641 | try: 642 | opt_function, exp = self.optimizer.optimize(function, return_coeff=True) 643 | score = self.scorer.score(opt_function) 644 | self.logger.info(f"New function: {str(opt_function)}. Score: {score}.") 645 | if score < best_score: 646 | best_score = score 647 | best_function = opt_function 648 | best_expr = exp 649 | except Exception as e: 650 | self.logger.warning(f"Could not optimize function {function}. {e}") 651 | pass 652 | 653 | self.results["times"]["optimization_per_iteration"].append(time.perf_counter() - start_time) 654 | self.logger.info(f"Optimizer time: {time.perf_counter() - start_time}.") 655 | if best_score == np.inf: 656 | self.logger.warning("No functions scored below inf. Using the best function in the prompt.") 657 | best_expr = self.current_functions.get_best_function(return_coeff=True) 658 | best_function = self.current_functions.functions[best_expr] 659 | best_score = self.current_functions.scores[best_expr] 660 | self.logger.info(f"Best function: {best_function}. Score: {best_score}.") 661 | 662 | self.logger.info(f"Finished get new function and score. Best function: {best_function}. Score: {best_score}.") 663 | return best_expr, best_function, best_score 664 |
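# Note: the R2 metrics below are computed on trimmed data: the 5% of points with the largest absolute residuals are masked out first, so a single blow-up in a candidate (e.g. near a pole inside the interval) cannot dominate the score.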
Score: {best_score}.") 663 | return best_expr, best_function, best_score 664 | 665 | def get_R2_scores(self, function: Any) -> Tuple[float, float, float]: 666 | """ 667 | Computes the R2 scores for the train and test sets given a function, removing the 5% worst predictions. 668 | 669 | Parameters 670 | ---------- 671 | function -> the function to evaluate. 672 | 673 | Returns 674 | ------- 675 | r2_train -> the R2 score for the train set. 676 | r2_test -> the R2 score for the test set. 677 | r2_all -> the R2 score for all points. 678 | """ 679 | 680 | y_true_train = self.train_points[:, -1] 681 | y_true_test = self.test_points[:, -1] 682 | 683 | def compute_predictions(function, points, num_variables): 684 | y_pred = utils.eval_function(function, points[:, 0:-1], num_variables) 685 | 686 | # Compute a boolean mask of the 5% worst predictions 687 | worst_indices = np.argsort(np.abs(y_pred - points[:, -1]))[-int(len(points) * 0.05):] 688 | mask = np.zeros(len(points), dtype=bool) 689 | mask[worst_indices] = True 690 | 691 | return y_pred[~mask], mask 692 | 693 | try: 694 | y_pred_train, y_train_mask = compute_predictions(function, self.train_points, self.num_variables) 695 | y_pred_test, y_test_mask = compute_predictions(function, self.test_points, self.num_variables) 696 | 697 | y_true_train = y_true_train[~y_train_mask] 698 | y_true_test = y_true_test[~y_test_mask] 699 | 700 | assert len(y_true_train) == len(y_pred_train), f"Length of true train points ({len(y_true_train)}) does not match length of predicted train points ({len(y_pred_train)})." 701 | assert len(y_true_test) == len(y_pred_test), f"Length of true test points ({len(y_true_test)}) does not match length of predicted test points ({len(y_pred_test)})." 702 | 703 | except Exception as e: 704 | self.logger.warning(f"Could not evaluate function {function}. {e}") 705 | return np.nan, np.nan, np.nan 706 | try: 707 | r2_train = r2_score(y_true_train, y_pred_train) 708 | except Exception as e: 709 | self.logger.warning(f"Could not calculate R2 score for train set. {e}") 710 | r2_train = np.nan 711 | try: 712 | r2_test = r2_score(y_true_test, y_pred_test) 713 | except Exception as e: 714 | self.logger.warning(f"Could not calculate R2 score for test set. {e}") 715 | r2_test = np.nan 716 | try: 717 | r2_all = r2_score(np.concatenate((y_true_train, y_true_test)), np.concatenate((y_pred_train, y_pred_test))) 718 | except Exception as e: 719 | self.logger.warning(f"Could not calculate R2 score for all points. {e}") 720 | r2_all = np.nan 721 | 722 | return r2_train, r2_test, r2_all 723 | 724 | def run(self) -> None: 725 | """ 726 | Runs the main experiment, iterating and generating new functions until the tolerance is reached. 727 | """ 728 | main_timer_start = time.perf_counter() 729 | if self.save_video: 730 | frames = [] 731 | 732 | # Check if one of the generated seed functions is already below the tolerance 733 | best_expr = self.current_functions.get_best_function(return_coeff=True) 734 | best_function = self.current_functions.get_best_function(return_coeff=False) 735 | score = self.current_functions.scores[best_expr] 736 | r2_train, r2_test, r2_all = self.get_R2_scores(best_function) 737 | 738 | if r2_train >= self.tolerance: 739 | self.logger.info(f"The seed function {best_expr} is already above the R2 tolerance {self.tolerance}.") 740 | self.logger.info(f"Best function: {best_function}. 
724 | def run(self) -> None: 725 | """ 726 | Runs the main experiment, iterating and generating new functions until the tolerance is reached. 727 | """ 728 | main_timer_start = time.perf_counter() 729 | if self.save_video: 730 | frames = [] 731 | 732 | # Check if one of the generated seed functions is already below the tolerance 733 | best_expr = self.current_functions.get_best_function(return_coeff=True) 734 | best_function = self.current_functions.get_best_function(return_coeff=False) 735 | score = self.current_functions.scores[best_expr] 736 | r2_train, r2_test, r2_all = self.get_R2_scores(best_function) 737 | 738 | if r2_train >= self.tolerance: 739 | self.logger.info(f"The seed function {best_expr} is already above the R2 tolerance {self.tolerance}.") 740 | self.logger.info(f"Best function: {best_function}. R2 (train): {r2_train}.") 741 | 742 | self.results["scores"].append(score if score != np.inf else "inf") 743 | self.results["best_scores"].append(self.current_functions.scores[best_expr]) 744 | self.results["best_scores_normalized"].append(self.current_functions.norm_scores[best_expr]) 745 | self.results["best_found_at"] = 0 746 | self.results["temperatures"].append(self.temperature) 747 | 748 | if self.save_video: 749 | frame, ax = self.plotter.record_frame(best_function, best_function, r2_test, self.test_function, -1, plot_true=True) 750 | if self.save_frames: 751 | frame.savefig(self.output_path + "frames/" + "0.png") 752 | frames.append(frame) 753 | else: 754 | # Start the main loop 755 | prompt = self.current_functions.get_prompt(self.base_prompt) 756 | for i in range(self.iterations): 757 | start_time = time.perf_counter() 758 | self.logger.info(f"Round {i+1}.") 759 | self.logger.info(f"Scores: {self.current_functions.scores}.") 760 | 761 | # Handle temperature schedule 762 | if self.temperature_scheduler is not None: 763 | self.temperature = self.temperature_scheduler.get_last_lr()[0] 764 | self.logger.info(f"Temperature: {self.temperature}.") 765 | self.results["temperatures"].append(self.temperature) 766 | self.temperature_scheduler.step() 767 | 768 | # Get new function and score 769 | expr, function, score = self.get_new_function_and_score(prompt) 770 | self.current_functions.add_function(expr, function) 771 | best_expr = self.current_functions.get_best_function(return_coeff=True) 772 | best_function = self.current_functions.functions[best_expr] 773 | 774 | # Update results 775 | if expr == best_expr: 776 | self.results["best_found_at"] = i+1 777 | self.results["iterations"] = i+1 778 | self.results["scores"].append(score if score != np.inf else "inf") 779 | self.results["best_scores"].append(self.current_functions.scores[best_expr]) 780 | self.results["best_scores_normalized"].append(self.current_functions.norm_scores[best_expr]) 781 | self.results["times"]["iteration"].append(time.perf_counter() - start_time) 782 | 783 | r2_train, r2_test, r2_all = self.get_R2_scores(function) 784 | self.results["R2_trains"].append(r2_train) 785 | self.results["R2_tests"].append(r2_test) 786 | self.results["R2_alls"].append(r2_all) 787 | 788 | # Update video 789 | if self.save_video: 790 | if score != np.inf: 791 | frame, ax = self.plotter.record_frame(best_function, function, r2_test, self.test_function, i, plot_true=True) 792 | if self.save_frames: 793 | frame.savefig(self.output_path + "frames/" + f"{i}.png") 794 | frames.append(frame) 795 | else: 796 | self.logger.warning(f"Skipping frame {i} as the score is inf.") 797 | 798 | # Check if the tolerance is reached 799 | if self.test_function is not None: 800 | if utils.func_equals(best_function, self.test_function, self.num_variables): 801 | self.logger.info("Function is equivalent to the true function.") 802 | self.results["equivalent"] = True 803 | break 804 | 805 | if r2_train >= self.tolerance: 806 | self.logger.info(f"Found a function with R2 (train) = {r2_train} above the tolerance {self.tolerance}.") 807 | break 808 | prompt = self.current_functions.get_prompt(self.base_prompt) 809 |
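# Checkpointing: every iteration index listed in self.checkpoints triggers a full snapshot of the running results (scores, R2 values, expression complexity) to results_checkpoint_<i>.json.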
Saving results.") 812 | results_checkpoint = copy.deepcopy(self.results) 813 | checkpoint_timer_end = time.perf_counter() 814 | results_checkpoint["times"]["total"] = checkpoint_timer_end - main_timer_start 815 | 816 | best_expr = self.current_functions.get_best_function(return_coeff=True) 817 | best_function = self.current_functions.get_best_function(return_coeff=False) 818 | test_score = self.test_scorer.score(best_function) 819 | results_checkpoint["test_score"] = test_score 820 | results_checkpoint["best_function"] = str(best_function) 821 | results_checkpoint["best_expr"] = str(best_expr) 822 | r2_train, r2_test, r2_all = self.get_R2_scores(best_function) 823 | results_checkpoint["r2_train"] = r2_train 824 | results_checkpoint["r2_test"] = r2_test 825 | results_checkpoint["r2_all"] = r2_all 826 | results_checkpoint["final_complexity"] = utils.count_nodes(best_function) 827 | with open(self.output_path + f"results_checkpoint_{i}.json", "w") as f: 828 | json.dump(results_checkpoint, f) 829 | self.logger.info(f"Checkpoint {i} saved.") 830 | 831 | # Save final results 832 | main_timer_end = time.perf_counter() 833 | best_expr = self.current_functions.get_best_function(return_coeff=True) 834 | best_function = self.current_functions.get_best_function(return_coeff=False) 835 | test_score = self.test_scorer.score(best_function) 836 | self.logger.info(f"Test score: {test_score}.") 837 | 838 | self.logger.info(f"Best function: {best_function}. Score: {self.current_functions.scores[best_expr]} ({self.current_functions.norm_scores[best_expr]}).") 839 | 840 | if hasattr(self, "test_function_name") and self.test_function_name is not None: 841 | self.logger.info(f"True function: {self.test_function_name}") 842 | fig, ax = self.plotter.plot_results(best_function, self.test_function) 843 | fig.savefig(self.output_path + "final.png") 844 | 845 | self.results["best_function"] = str(best_function) 846 | self.results["best_expr"] = str(best_expr) 847 | self.results["test_score"] = test_score 848 | self.results["times"]["total"] = main_timer_end - main_timer_start + self.results["times"]["seed_function_generation"] 849 | self.results["times"]["avg_generation"] = np.mean(self.results["times"]["generation_per_iteration"]) if len(self.results["times"]["generation_per_iteration"]) > 0 else 0 850 | self.results["times"]["avg_optimization"] = np.mean(self.results["times"]["optimization_per_iteration"]) if len(self.results["times"]["optimization_per_iteration"]) > 0 else 0 851 | 852 | r2_train, r2_test, r2_all = self.get_R2_scores(best_function) 853 | self.results["r2_train"] = r2_train 854 | self.results["r2_test"] = r2_test 855 | self.results["r2_all"] = r2_all 856 | self.logger.info(f"R2 train: {np.round(r2_train, 6)}. R2 test: {np.round(r2_test, 6)}. 
869 | @hydra.main(version_base=None, config_path="conf", config_name="config") 870 | def main(cfg: DictConfig) -> None: 871 | workspace = Workspace(cfg) 872 | workspace.run() 873 | 874 | def dump_profile(): 875 | profiler.disable() 876 | job_id = utils.get_job_id() 877 | print(f"Dumping profile to {os.path.join(os.getcwd(), 'profiles', 'profile')}_{job_id if job_id is not None else 'local'}") 878 | if not os.path.exists("./profiles"): 879 | os.makedirs("./profiles") 880 | profiler.dump_stats(f"./profiles/profile_{job_id if job_id is not None else 'local'}") 881 | 882 | def signal_handler(sig, frame): 883 | # Ignore warnings, as otherwise we break the logger 884 | warnings.filterwarnings("ignore") 885 | dump_profile() 886 | print(f"Detected signal {sig}. Dumped profile to {os.path.join(os.getcwd(), 'profiles', 'profile')}_{job_id if job_id is not None else 'local'}") 887 | sys.stdout.flush() 888 | if sig == signal.SIGTERM or sig == signal.SIGINT: 889 | sys.exit(1) 890 | 891 | if __name__ == "__main__": 892 | # Run full profiler if env variable PROFILE is set 893 | do_profile = os.environ.get("PROFILE", False) 894 | print("Initializing profiler.") 895 | print("Profile will be created." if do_profile else "Profile will only be created if the code fails or is terminated.") 896 | job_id = utils.get_job_id() 897 | 898 | # Set termination signal handlers to dump profile when terminated by SLURM 899 | signal.signal(signal.SIGTERM, signal_handler) 900 | signal.signal(signal.SIGINT, signal_handler) 901 | signal.signal(signal.SIGCONT, signal_handler) 902 | 903 | # Setup profiler 904 | global profiler 905 | profiler = cProfile.Profile() 906 | profiler.enable() 907 | 908 | try: 909 | main() 910 | except Exception as e: 911 | # Catch exceptions and dump profile 912 | print("Caught exception in main.") 913 | print(e) 914 | dump_profile() 915 | sys.exit(2) 916 | 917 | if do_profile: 918 | dump_profile() -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .hf_model import HuggingFaceModel 2 | from .llava_model_hf import LLaVaModelHF 3 | from .openai_model import OpenAIModel -------------------------------------------------------------------------------- /models/hf_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import fcntl 3 | import utils 4 | 5 | from transformers import AutoTokenizer, AutoModelForCausalLM 6 | 7 | class HuggingFaceModel(object): 8 | def __init__(self, model_name, device, dtype, cache_dir=None, **kwargs): 9 | self.model_name = model_name 10 | self.device = device 11 | self.dtype = dtype 12 | token = os.environ.get("HF_TOKEN", None) 13 | 14 | self.tokenizer = AutoTokenizer.from_pretrained(model_name, torch_dtype=dtype, cache_dir=cache_dir, token=token) 15 | self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=dtype, cache_dir=cache_dir, token=token, device_map='auto') 16 | self.model.eval() 17 | 18 | self.temperature = 
kwargs.get("temperature", 1.0) 19 | self.top_k = kwargs.get("top_k", 50) 20 | self.top_p = kwargs.get("top_p", 0.9) 21 | self.num_beams = kwargs.get("num_beams", 1) 22 | self.num_return_sequences = kwargs.get("num_return_sequences", 1) 23 | self.max_new_tokens = kwargs.get("max_new_tokens", 256) 24 | self.min_new_tokens = kwargs.get("min_new_tokens", 0) 25 | 26 | if "tokenizer_pad" in kwargs: 27 | self.tokenizer.pad_token = kwargs["tokenizer_pad"] 28 | 29 | if "tokenizer_padding_side" in kwargs: 30 | self.tokenizer.padding_side = kwargs["tokenizer_padding_side"] 31 | 32 | def generate(self, prompt, return_prompt=False, image_files=None, temperature=None, max_new_tokens=None): 33 | if temperature is None: 34 | temperature = self.temperature 35 | if max_new_tokens is None: 36 | max_new_tokens = self.max_new_tokens 37 | 38 | if '' in prompt: 39 | prompt = prompt.replace('', '') # Not used for non vision models, this assumes that this class is always used for text models (as the vision model used is LLaVA and is implemented in a different class) 40 | 41 | messages = utils.get_messages(prompt) 42 | inputs = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_dict=True, return_tensors="pt").to(self.device) 43 | 44 | outputs = self.model.generate(**inputs, do_sample=True, temperature=temperature, top_k=self.top_k, top_p=self.top_p, num_beams=self.num_beams, 45 | num_return_sequences=self.num_return_sequences, max_new_tokens=max_new_tokens, min_new_tokens=self.min_new_tokens, pad_token_id=self.tokenizer.eos_token_id) 46 | 47 | outputs = outputs[0][len(inputs[0]):] if not return_prompt else outputs[0] 48 | decoded_output = self.tokenizer.decode(outputs, skip_special_tokens=True) 49 | 50 | # Remove llama special words 51 | decoded_output = decoded_output.replace("assistant", "").replace("user", "").replace("system", "") 52 | 53 | return decoded_output -------------------------------------------------------------------------------- /models/llava_model_hf.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration 3 | 4 | class LLaVaModelHF(object): 5 | def __init__(self, model_name, device, dtype, cache_dir=None, **kwargs): 6 | self.model_name = model_name 7 | self.device = device 8 | self.dtype = dtype 9 | print(f"Cache dir: {cache_dir}") 10 | use_flash_attention = kwargs.get("use_flash_attn", False) 11 | attn_implementation = "flash_attention_2" if use_flash_attention else None 12 | 13 | self.processor = LlavaNextProcessor.from_pretrained(self.model_name, cache_dir=cache_dir) 14 | self.model = LlavaNextForConditionalGeneration.from_pretrained(self.model_name, torch_dtype=self.dtype, low_cpu_mem_usage=True, 15 | device_map='auto', cache_dir=cache_dir, attn_implementation=attn_implementation) 16 | self.model.eval() 17 | print(f"Model loaded on device: {self.model.device}") 18 | 19 | self.temperature = kwargs.get("temperature", 1.0) 20 | self.top_k = kwargs.get("top_k", 50) 21 | self.top_p = kwargs.get("top_p", 0.9) 22 | self.num_beams = kwargs.get("num_beams", 1) 23 | self.num_return_sequences = kwargs.get("num_return_sequences", 1) 24 | self.max_new_tokens = kwargs.get("max_new_tokens", 256) 25 | self.min_new_tokens = kwargs.get("min_new_tokens", 0) 26 | 27 | def generate(self, prompt, return_prompt=False, image_files=None, temperature=None, max_new_tokens=None): 28 | if temperature is None: 29 | temperature = self.temperature 30 | if 
-------------------------------------------------------------------------------- /models/llava_model_hf.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration 3 | 4 | class LLaVaModelHF(object): 5 | def __init__(self, model_name, device, dtype, cache_dir=None, **kwargs): 6 | self.model_name = model_name 7 | self.device = device 8 | self.dtype = dtype 9 | print(f"Cache dir: {cache_dir}") 10 | use_flash_attention = kwargs.get("use_flash_attn", False) 11 | attn_implementation = "flash_attention_2" if use_flash_attention else None 12 | 13 | self.processor = LlavaNextProcessor.from_pretrained(self.model_name, cache_dir=cache_dir) 14 | self.model = LlavaNextForConditionalGeneration.from_pretrained(self.model_name, torch_dtype=self.dtype, low_cpu_mem_usage=True, 15 | device_map='auto', cache_dir=cache_dir, attn_implementation=attn_implementation) 16 | self.model.eval() 17 | print(f"Model loaded on device: {self.model.device}") 18 | 19 | self.temperature = kwargs.get("temperature", 1.0) 20 | self.top_k = kwargs.get("top_k", 50) 21 | self.top_p = kwargs.get("top_p", 0.9) 22 | self.num_beams = kwargs.get("num_beams", 1) 23 | self.num_return_sequences = kwargs.get("num_return_sequences", 1) 24 | self.max_new_tokens = kwargs.get("max_new_tokens", 256) 25 | self.min_new_tokens = kwargs.get("min_new_tokens", 0) 26 | 27 | def generate(self, prompt, return_prompt=False, image_files=None, temperature=None, max_new_tokens=None): 28 | if temperature is None: 29 | temperature = self.temperature 30 | if max_new_tokens is None: 31 | max_new_tokens = self.max_new_tokens 32 | 33 | if image_files is None and '<image>' in prompt: 34 | prompt = prompt.replace('<image>', '') 35 | 36 | image_path = image_files[0] if isinstance(image_files, list) else image_files 37 | image = Image.open(image_path) if image_files is not None else None 38 | inputs = self.processor(prompt, image, return_tensors="pt").to(self.device) 39 | 40 | outputs = self.model.generate(**inputs, do_sample=True, temperature=temperature, top_k=self.top_k, top_p=self.top_p, num_beams=self.num_beams, 41 | num_return_sequences=self.num_return_sequences, max_new_tokens=max_new_tokens, min_new_tokens=self.min_new_tokens, pad_token_id=self.processor.tokenizer.eos_token_id) 42 | 43 | outputs = outputs[0][len(inputs['input_ids'][0]):] if not return_prompt else outputs[0] 44 | decoded_output = self.processor.tokenizer.decode(outputs, skip_special_tokens=True) 45 | 46 | return decoded_output 47 | -------------------------------------------------------------------------------- /models/openai_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import openai 3 | import base64 4 | 5 | class OpenAIModel(object): 6 | def __init__(self, model_name, device, dtype, cache_dir=None, **kwargs): 7 | self.model_name = model_name 8 | self.device = device 9 | self.dtype = dtype 10 | 11 | self.api_key_path = None if "api_key_path" not in kwargs else kwargs["api_key_path"] 12 | self.api_key = self.get_api_key() 13 | self.organization_id_path = None if "organization_id_path" not in kwargs else kwargs["organization_id_path"] 14 | self.organization_id = self.get_org_id() 15 | assert self.api_key is not None, "API key not found." 16 | 17 | self.client = openai.Client( 18 | api_key=self.api_key, 19 | organization=self.organization_id 20 | ) 21 | 22 | self.top_p = 0.9 if "top_p" not in kwargs else kwargs["top_p"] 23 | self.temperature = 1.0 if "temperature" not in kwargs else kwargs["temperature"] 24 | self.max_length = 1024 if "max_length" not in kwargs else kwargs["max_length"] 25 | self.num_return_sequences = 1 if "num_return_sequences" not in kwargs else kwargs["num_return_sequences"] 26 | self.seed = None if "seed" not in kwargs else kwargs["seed"] 27 | 28 | def get_api_key(self): 29 | if "OPENAI_API_KEY" in os.environ: 30 | return os.environ["OPENAI_API_KEY"] 31 | elif self.api_key_path is not None: 32 | with open(self.api_key_path, "r") as f: 33 | return f.read().strip() 34 | 35 | return None 36 | 37 | def get_org_id(self): 38 | if "OPENAI_ORG_ID" in os.environ: 39 | return os.environ["OPENAI_ORG_ID"] 40 | elif self.organization_id_path is not None: 41 | with open(self.organization_id_path, "r") as f: 42 | return f.read().strip() 43 | 44 | return None 45 | 46 | def encode_image(self, image_path): 47 | with open(image_path, "rb") as image_file: 48 | return base64.b64encode(image_file.read()).decode('utf-8') 49 | 50 | def get_messages(self, prompt, splits=["system", "user"], image_files=None): 51 | messages = [] 52 | for split in splits: 53 | start_tag = f"<{split}>" 54 | end_tag = f"</{split}>" 55 | if start_tag not in prompt or end_tag not in prompt: 56 | continue 57 | 58 | start_idx = prompt.find(start_tag) 59 | end_idx = prompt.find(end_tag) 60 | 61 | messages.append({ 62 | "role": split, 63 | "content": prompt[start_idx + len(start_tag):end_idx].strip() 64 | }) 65 | 66 | if len(messages) == 0: 67 | messages.append({ 68 | "role": "user", 69 | "content": prompt 70 | }) 71 | 72 | if image_files is not None: 73 | user_index = 
next((i for i, item in enumerate(messages) if item["role"] == "user"), None) 74 | user_msg = messages[user_index] 75 | user_msg["content"] = [{"type": "text", "text": user_msg["content"]}] 76 | for image_file in image_files: 77 | base64_image = self.encode_image(image_file) 78 | user_msg["content"].append({ 79 | "type": "image_url", 80 | "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"} 81 | }) 82 | 83 | return messages 84 | 85 | def generate(self, prompt, return_prompt=False, image_files=None, temperature=None, max_new_tokens=None): 86 | if temperature is None: 87 | temperature = self.temperature 88 | if max_new_tokens is None: 89 | max_new_tokens = self.max_length 90 | 91 | messages = self.get_messages(prompt, image_files=[image_files] if isinstance(image_files, str) else image_files) # a single image path may be passed; get_messages iterates over a list 92 | response = self.client.chat.completions.create( 93 | messages=messages, 94 | model=self.model_name, 95 | temperature=temperature, 96 | max_tokens=max_new_tokens, 97 | top_p=self.top_p, 98 | n=self.num_return_sequences, 99 | seed=self.seed 100 | ) 101 | 102 | if self.num_return_sequences == 1: 103 | return response.choices[0].message.content.strip() 104 | else: 105 | return [choice.message.content.strip() for choice in response.choices] -------------------------------------------------------------------------------- /optimizer.py: -------------------------------------------------------------------------------- 1 | import re 2 | import sympy 3 | import threading 4 | import warnings 5 | import numpy as np 6 | 7 | from omegaconf import DictConfig 8 | from scipy.optimize import curve_fit 9 | from mloggers import MultiLogger 10 | from typing import Any 11 | from sympy import Mul, Add, Dummy, sift, numbered_symbols 12 | 13 | class Optimizer(object): 14 | """ 15 | Optimizer class used to fit a function to a set of points, given a base shape. 16 | For example, takes as input something like "ax^2 + bx + c" and fits (a, b, c) to a set of points. 17 | """ 18 | def __init__(self, cfg: DictConfig, points: np.ndarray, logger: MultiLogger) -> None: 19 | """ 20 | Initializes the optimizer. 21 | 22 | Parameters 23 | ---------- 24 | cfg : DictConfig -> The configuration file. 25 | points : np.ndarray -> The points to fit to. 26 | logger : MultiLogger -> The logger to log to. 
27 | """ 28 | self.cfg = cfg 29 | self.logger = logger 30 | self.points = points 31 | self.num_variables = cfg.experiment.function.num_variables 32 | self.invalid_coefficients = ["x", "y", "e"] 33 | 34 | self.coeff_rounding = cfg.experiment.optimizer.coeff_rounding if hasattr(cfg.experiment, "optimizer") and hasattr(cfg.experiment.optimizer, "coeff_rounding") else 2 35 | self.tol = cfg.experiment.optimizer.tol if hasattr(cfg.experiment, "optimizer") and hasattr(cfg.experiment.optimizer, "tol") else 1e-3 # Tolerance used to zero out coefficients that are close to 0 36 | self.num_threads = cfg.experiment.optimizer.optimizer_threads if hasattr(cfg.experiment, "optimizer") and hasattr(cfg.experiment.optimizer, "optimizer_threads") else 5 37 | self.timeout = cfg.experiment.optimizer.timeout if hasattr(cfg.experiment, "optimizer") and hasattr(cfg.experiment.optimizer, "timeout") else 10 38 | self.p0_min = cfg.experiment.optimizer.p0_min if hasattr(cfg.experiment, "optimizer") and hasattr(cfg.experiment.optimizer, "p0_min") else -10 39 | self.p0_max = cfg.experiment.optimizer.p0_max if hasattr(cfg.experiment, "optimizer") and hasattr(cfg.experiment.optimizer, "p0_max") else 10 40 | 41 | def replace_coefficients(self, exp: sympy.core.add.Add) -> sympy.core.add.Add: 42 | """ 43 | Replaces the number coefficients of a function with symbols. 44 | 45 | Parameters 46 | ---------- 47 | exp : sympy.core.add.Add -> The function to replace coefficients of. 48 | 49 | Returns 50 | ------- 51 | exp : sympy.core.add.Add -> The function with coefficients replaced. 52 | """ 53 | def is_coefficient(symbol: Any) -> bool: 54 | if len(symbol.args) > 0: 55 | for arg in symbol.args: 56 | if not is_coefficient(arg): 57 | return False 58 | 59 | if re.match(r"c\d+", str(symbol)): 60 | return True 61 | elif symbol.is_Dummy: 62 | return True 63 | elif symbol.is_number: 64 | return True 65 | 66 | return False 67 | 68 | # Adapted from https://stackoverflow.com/questions/59686990/replacing-numbers-with-parameters-in-sympy 69 | def nfact2dum(m): 70 | assert m.is_Mul or m.is_Add or m.is_Function 71 | if m.is_Mul: 72 | if not any([is_coefficient(i) for i in m.args]): 73 | return m 74 | nonnum = sift(m.args, lambda i:is_coefficient(i), binary=True)[1] 75 | return Mul(*([Dummy()] + nonnum)) 76 | elif m.is_Add: 77 | if not any([is_coefficient(i) for i in m.args]): 78 | return m 79 | nonnum = sift(m.args, lambda i:is_coefficient(i), binary=True)[1] 80 | return Add(*([Dummy()] + nonnum)) 81 | elif m.is_Function: 82 | args = [] 83 | for arg in m.args: 84 | if arg.is_Mul or arg.is_Add or arg.is_Function: 85 | args.append(nfact2dum(arg)) 86 | else: 87 | args.append(arg) 88 | return Dummy() * m.func(*args) 89 | 90 | # Add +1 at the end of the expression to make sure that a constant is included 91 | exp = exp + 1 92 | 93 | # Replace all symbols beginning with c with a dummy 94 | # (as they are coefficients, otherwise we could generate a symbol that is already in the expression) 95 | exp = exp.replace(lambda x: re.match(r"c\d+", str(x)) or str(x).lower() == "c", lambda x: Dummy()) 96 | 97 | # Replace all coefficients with dummies 98 | exp = exp.replace( 99 | lambda x:x.is_Mul or x.is_Add or x.is_Function, 100 | lambda x: nfact2dum(x)) 101 | # Replace all exponents of dummy variables with 1 102 | exp = exp.replace(lambda x: x.is_Pow and x.base.is_Dummy, lambda x: x.base) 103 | 104 | # Replace all dummies with symbols 105 | exp = exp.subs(list(zip(exp.atoms(Dummy),numbered_symbols('c')))) 106 | 107 | return exp 108 | 109 | def 
get_optimizable_sympy_exp(self, exp: sympy.core.add.Add, quiet: bool = False) -> Any: 110 | """ 111 | Returns a sympy expression that can be optimized by scipy. 112 | 113 | Parameters 114 | ---------- 115 | exp : sympy.core.add.Add -> The expression to make optimizable. 116 | quiet : bool -> Whether to log the results. 117 | 118 | Returns 119 | ------- 120 | exp : Any -> The optimizable expression. 121 | """ 122 | exp = self.replace_coefficients(exp) 123 | self.logger.info("Optimizing function: " + str(exp)) if not quiet else None 124 | symbols = list(exp.free_symbols) 125 | 126 | # Sort symbols so that all x's come first (to find the variables that aren't coefficients) 127 | symbols.sort(key=lambda x: str(x).replace("x", " ")) 128 | 129 | # Safety check to make sure that the number of variables is correct 130 | num_variables = len(re.findall(r"x\d*", str(symbols))) 131 | assert num_variables == self.num_variables, f"Number of variables ({num_variables}) does not match number of variables in config ({self.num_variables})" 132 | 133 | symbols = [symbols[:num_variables], *symbols[num_variables:]] 134 | return sympy.lambdify(symbols, exp, "numpy"), exp 135 | 136 | def _run_curve_fit(self, f: Any, num_parameters: int, results: Any, done_event: Any, quiet: bool = True, random_p0: bool = True) -> Any: 137 | """ 138 | Runs the curve fit function with a timeout. 139 | 140 | Parameters 141 | ---------- 142 | f : Any -> The function to fit. 143 | num_parameters : int -> The number of parameters to fit. 144 | results : Any -> The results list to append to. 145 | done_event : Any -> The event to set when done. 146 | quiet : bool -> Whether to log the results. 147 | random_p0 : bool -> Whether to use random starting points. 148 | 149 | Returns 150 | ------- 151 | popt : np.ndarray -> The optimized parameters. 152 | pcov : np.ndarray -> The covariance matrix. 153 | """ 154 | p0 = np.random.uniform(self.p0_min, self.p0_max, num_parameters) if random_p0 else np.ones(num_parameters) 155 | popt = None 156 | try: 157 | popt, pcov = curve_fit(f, self.points[:, :-1].T, self.points[:, -1].T, p0=p0) 158 | results.append((popt, pcov)) 159 | done_event.set() 160 | return True 161 | except Exception as e: 162 | print(f"Optimization failed: {e}") 163 | pass 164 | 165 | return False 166 | 167 | def optimize(self, exp: sympy.core.add.Add, return_coeff: bool = False, quiet: bool = False) -> sympy.core.add.Add: 168 | """ 169 | Optimizes a function to a set of points. 170 | 171 | Parameters 172 | ---------- 173 | exp : sympy.core.add.Add -> The base shape to optimize. 174 | return_coeff : bool -> Whether to return the expression in coefficient form. 175 | quiet : bool -> Whether to log the results. 176 | 177 | Returns 178 | ------- 179 | exp : sympy.core.add.Add -> The optimized function. 180 | coeff_exp : sympy.core.add.Add -> The optimized function in coefficient form. 
(Only if return_coeff is True) 181 | """ 182 | f, exp = self.get_optimizable_sympy_exp(exp, quiet=quiet) 183 | symbols = exp.free_symbols 184 | symbols = sorted(symbols, key=lambda x: str(x).replace("x", " ")) 185 | coefficients = symbols[self.num_variables:] 186 | coeff_exp = exp if return_coeff else None 187 | Xs = self.points[:, :-1].T 188 | ys = self.points[:, -1].T 189 | 190 | # Direct warnings only to console logger as file logger breaks with threading 191 | warnings.filterwarnings("default") 192 | warnings.showwarning = lambda *args, **kwargs: self.logger.warning(str(args[0]), mask=["file"]) 193 | self.logger.info("Redirecting warnings to console logger only.") 194 | 195 | # Run curve_fit with a timeout, using num_threads random starting points 196 | results = [] 197 | if self.num_threads == 1: 198 | self.logger.info("Running optimization with 1 attempt.") if not quiet else None 199 | self._run_curve_fit(f, len(coefficients), results=results, done_event=threading.Event(), quiet=quiet, random_p0=False) 200 | popt, pcov = results[0] 201 | else: 202 | done_event = threading.Event() 203 | threads = [] 204 | for i in range(self.num_threads): 205 | threads.append(threading.Thread(target=lambda i=i: self._run_curve_fit(f, len(coefficients), results=results, done_event=done_event, quiet=quiet, random_p0=i!=0))) # bind i at creation time so each thread sees its own value 206 | threads[-1].start() 207 | 208 | done_event.wait(self.timeout) 209 | for thread in threads: 210 | thread.join(self.timeout) 211 | if thread.is_alive(): 212 | self.logger.warning(f"Thread {thread} did not finish in time.") 213 | thread._stop() 214 | self.logger.info("All threads finished.") if not quiet else None 215 | if not done_event.is_set(): 216 | raise ValueError("Optimization failed: timeout reached") 217 | 218 | # Direct warnings back to normal 219 | warnings.filterwarnings("default") 220 | warnings.showwarning = lambda *args, **kwargs: self.logger.warning(str(args[0])) 221 | self.logger.info("Redirecting warnings back to normal (both file and console).") 222 | 223 | # Get the best parameters 224 | best_popt = None 225 | best_pcov = None 226 | best_error = np.inf 227 | for popt, pcov in results: 228 | error = np.sum((f(Xs, *popt) - ys) ** 2) 229 | if error < best_error: 230 | best_error = error 231 | best_popt = popt 232 | best_pcov = pcov 233 | 234 | popt = best_popt 235 | pcov = best_pcov 236 | 237 | if pcov is None or np.isinf(pcov).any() or np.isnan(pcov).any(): 238 | raise ValueError("Optimization failed: covariance matrix is invalid") 239 | popt = [np.round(x, self.coeff_rounding) for x in popt] 240 | self.logger.info("Optimized parameters: " + str(popt)) if not quiet else None 241 | 242 | assert len(coefficients) == len(popt), f"Number of found coefficients ({len(coefficients)}) does not match number of parameters ({len(popt)})" 243 | zero_subs = {} 244 | for i, coefficient in enumerate(coefficients): 245 | if -self.tol < popt[i] < self.tol: 246 | zero_subs[coefficient] = 0 247 | return exp.subs(list(zip(coefficients, popt))), (coeff_exp.subs(zero_subs) if return_coeff else None) 248 |
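# Usage sketch: given a parametric candidate such as exp = sympy.sympify("c*x1**2 + c*sin(x1)"), optimize(exp) fits the coefficient slots to self.points with scipy's curve_fit and returns, e.g., an expression like 2.03*x1**2 + 0.51*sin(x1) (the fitted values depend on the data).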
-------------------------------------------------------------------------------- /plotter.py: -------------------------------------------------------------------------------- 1 | import PIL 2 | import utils 3 | import tempfile 4 | 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | 8 | from mpl_toolkits.mplot3d import axes3d 9 | from omegaconf import DictConfig 10 | from collections.abc import Callable 11 | from typing import List, Any 12 | 13 | class Plotter(object): 14 | def __init__(self, cfg: DictConfig, train_points: np.ndarray, test_points: np.ndarray, output_path: str) -> None: 15 | """ 16 | Initializes the plotter. 17 | 18 | Parameters 19 | ---------- 20 | cfg : DictConfig -> The configuration file. 21 | train_points : np.ndarray -> The train points. 22 | test_points : np.ndarray -> The test points. 25 | output_path : str -> The path to save the plots to. 26 | """ 27 | self.cfg = cfg 28 | self.num_variables = cfg.experiment.function.num_variables 29 | self.output_path = output_path 30 | 31 | self.train_points = train_points 32 | self.test_points = test_points 33 | assert self.test_points.shape[1] == self.num_variables + 1 34 | Xs = self.train_points[:, :-1] 35 | Xs = np.concatenate((Xs, self.test_points[:, :-1])) 36 | self.min_points = [np.min(self.test_points[:, i]) for i in range(self.num_variables)] 37 | self.max_points = [np.max(self.test_points[:, i]) for i in range(self.num_variables)] 38 | 39 | num_test = self.cfg.plotter.plotter_resolution if hasattr(self.cfg, "plotter") and hasattr(self.cfg.plotter, "plotter_resolution") else 1000 40 | if self.num_variables == 1: 41 | self.Xs_test = np.linspace(self.min_points[0], self.max_points[0], num_test).reshape(-1, 1) 42 | elif self.num_variables == 2: 43 | num_test = np.floor(np.sqrt(num_test)).astype(int) 44 | self.Xs_test = np.meshgrid(*[np.linspace(self.min_points[i], self.max_points[i], num_test) for i in range(self.num_variables)]) 45 | 46 | self.gif_duration = self.cfg.plotter.gif_duration if hasattr(self.cfg, "plotter") and hasattr(self.cfg.plotter, "gif_duration") else 1000 47 | self.fig_size = (self.cfg.plotter.plotter_fig_size, self.cfg.plotter.plotter_fig_size) if hasattr(self.cfg, "plotter") and hasattr(self.cfg.plotter, "plotter_fig_size") else (10, 10) 48 | self.function_cache = {} 49 | 50 | def _eval_function(self, function: Any, Xs: np.ndarray, num_variables: int) -> np.ndarray: 51 | if function in self.function_cache: 52 | return self.function_cache[function] 53 | else: 54 | ys = utils.eval_function(function, Xs, num_variables) 55 | self.function_cache[function] = ys 56 | return ys 57 |
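# Note: plot_results() and record_frame() below draw on top of the figure returned by plot_points(), so the scatter of train points is the backdrop of every rendered frame.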
58 | def plot_points(self, save_fig: bool = False, save_path: str = "points.png", plot_test=False) -> None: 59 | """ 60 | Plots a set of points. Used to feed to visual models. 61 | 62 | Parameters 63 | ---------- 64 | save_fig : bool -> Whether to save the figure or not. 65 | save_path : str -> The path to save the figure to. plot_test : bool -> Whether to also plot the test points. 66 | """ 67 | if self.num_variables > 2: 68 | return 69 | 70 | save_path = self.output_path + save_path 71 | if self.num_variables == 1: 72 | plt.figure(figsize=self.fig_size) 73 | plt.scatter(self.train_points[:, 0], self.train_points[:, 1], color="blue", label="Train points", zorder=1) 74 | if plot_test: 75 | plt.scatter(self.test_points[:, 0], self.test_points[:, 1], color="red", label="Test points", alpha=.25, zorder=1) 76 | # plt.grid(alpha=.4,linestyle='--') 77 | plt.xlabel("x") 78 | plt.ylabel("f(x)") 79 | plt.xlim(self.min_points[0] - 0.1 * (self.max_points[0] - self.min_points[0]), self.max_points[0] + 0.1 * (self.max_points[0] - self.min_points[0])) 80 | plt.ylim(np.min(self.train_points[:, 1]) - 0.1 * (np.max(self.train_points[:, 1]) - np.min(self.train_points[:, 1])), np.max(self.train_points[:, 1]) + 0.1 * (np.max(self.train_points[:, 1]) - np.min(self.train_points[:, 1]))) 81 | plt.legend() 82 | elif self.num_variables == 2: 83 | ax = plt.figure(figsize=self.fig_size).add_subplot(projection='3d') 84 | ax.scatter(self.train_points[:, 0], self.train_points[:, 1], self.train_points[:, 2], c='b', label='Train points') 85 | if plot_test: 86 | ax.scatter(self.test_points[:, 0], self.test_points[:, 1], self.test_points[:, 2], c='r', label='Test points', alpha=.25) 87 | plt.xlabel("x1") 88 | plt.ylabel("x2") 89 | ax.legend(loc="upper right") 90 | else: 91 | raise ValueError("Invalid number of variables.") 92 | 93 | plt.tight_layout() 94 | 95 | if save_fig: 96 | plt.savefig(save_path) 97 | else: 98 | return plt.gcf(), plt.gca() 99 | 100 | def plot_results(self, function: Any, test_function: Any = None, plot_true: bool = True, label: str = "Model's best guess") -> plt.Figure: 101 | """ 102 | Plots the results of the experiment, showing the test points, the true function, and the model's best guess. 103 | 104 | Parameters 105 | ---------- 106 | function : Any -> The model's best guess. 107 | test_function : Callable[[float], float] -> The true function. 108 | plot_true : bool -> Whether to plot the true function or not. 109 | label : str -> The label of the model's guess. 110 | 111 | Returns 112 | ------- 113 | fig, ax -> the figure and axes of the plot. 
114 | """ 115 | fig, ax = self.plot_points(plot_test=False) 116 | if test_function is None: 117 | plot_true = False 118 | 119 | if self.num_variables == 1: 120 | if plot_true: 121 | ys_test = self._eval_function(test_function, self.Xs_test, self.num_variables) 122 | plt.plot(self.Xs_test, ys_test, color="red", label="True function", zorder=0, linestyle="--") 123 | 124 | ys = self._eval_function(function, self.Xs_test, self.num_variables) 125 | plt.plot(self.Xs_test, ys, color="green", label=label, zorder=0) 126 | plt.legend(loc="lower right") 127 | 128 | elif self.num_variables == 2: 129 | X1, X2 = self.Xs_test 130 | if plot_true: 131 | Z_test = self._eval_function(test_function, np.array([X1, X2]), self.num_variables).reshape(X1.shape) 132 | ax.plot_surface(X1, X2, Z_test, edgecolor='orangered', lw=0.25, alpha=0.1, label="True function", color="red") 133 | 134 | Z = self._eval_function(function, np.array([X1, X2]), self.num_variables).reshape(X1.shape) 135 | ax.plot_surface(X1, X2, Z, edgecolor='mediumseagreen', lw=0.5, alpha=0.3, label=label, color="green") 136 | ax.legend(loc="upper right") 137 | 138 | return plt.gcf(), plt.gca() 139 | 140 | def record_frame(self, best_function: str, last_function: str, score: float, test_function: Callable[[float], float], round: int, plot_true : bool = True) -> plt.Figure: 141 | """ 142 | Records a frame of the animation. 143 | 144 | Parameters 145 | ---------- 146 | best_function : str -> The model's best guess. 147 | last_function : str -> The model's last guess. 148 | score : float -> The R2 (test) score of the last guess. 149 | test_function : Callable[[float], float] -> The true function. 150 | round : int -> The round number. 151 | plot_true : bool -> Whether to plot the true function or not. 152 | 153 | Returns 154 | ------- 155 | fig, ax -> the figure and axes of the frame. 156 | """ 157 | fig, ax = self.plot_results(best_function, test_function, plot_true=plot_true) 158 | if self.num_variables == 1: 159 | ys_last = self._eval_function(last_function, self.Xs_test, self.num_variables) 160 | plt.plot(self.Xs_test, ys_last, color="orange", label="Last guess") 161 | plt.legend() 162 | elif self.num_variables == 2: 163 | X1, X2 = self.Xs_test 164 | Z_last = self._eval_function(last_function, np.array([X1, X2]), self.num_variables).reshape(X1.shape) 165 | ax.plot_surface(X1, X2, Z_last, edgecolor='orange', lw=0.5, alpha=0.3, label="Last guess", color="orange") 166 | ax.text2D(0.05, 0.95, f"Score: {score:.3f}", transform=ax.transAxes, fontsize=10, verticalalignment='top') 167 | ax.legend(loc="upper right") 168 | 169 | ax.set_title(f"Round {round+1}, R2: {score:.5f}") 170 | fig.tight_layout() 171 | return plt.gcf(), plt.gca() 172 | 173 | def record_video(self, frames: List[plt.Figure]) -> None: 174 | """ 175 | Records the animation from the frames buffer. 176 | 177 | Parameters 178 | ---------- 179 | frames : List[plt.Figure] -> The frames buffer. 
180 | """ 181 | images = [] 182 | with tempfile.TemporaryDirectory() as tmp_path: 183 | for i, frame in enumerate(frames): 184 | frame.savefig(tmp_path + f"/{i}.png") 185 | for i in range(len(frames)): 186 | images.append(PIL.Image.open(tmp_path + f"/{i}.png")) 187 | # Extend the last frame to make the last (best) result more visible 188 | for _ in range(5): 189 | images.append(images[-1]) 190 | 191 | images[0].save(self.output_path + "animation.gif", save_all=True, append_images=images[1:], duration=self.gif_duration, loop=0) -------------------------------------------------------------------------------- /prompts/OPRO/basic_image.txt: -------------------------------------------------------------------------------- 1 | I want you to act as a mathematical function generator. 2 | You are given an image of a set of points, plotted as a scatter plot on a graph. 3 | Below are some previous functions and the error they make on the points above. The errors are arranged in order of their fit values, with the highest values coming first, and lower is better. 4 | 5 | Your task is to give me a list of five new functions that are different from all the ones reported below, and have a lower error value than all of the ones below. Only output the new functions in a list and nothing else. 6 | Remember that the functions you generate should always have at most {num_variables} variables {variables_list}. 7 | The functions should have parametric form, using 'c' in place of any constant or coefficient. The coefficients will be optimized to fit the data. Make absolutely sure that the functions you generate are completely different from the ones already given to you. 8 | Remember that you can combine the simple building blocks (operations, constants, variables) in any way you want to generate more complex functions. Don't be afraid to experiment! 9 | The functions should all begin with the indicators "f1(x) = ", "f2(x) = "... 10 | 11 | {functions} -------------------------------------------------------------------------------- /prompts/OPRO/basic_mixed.txt: -------------------------------------------------------------------------------- 1 | I want you to act as a mathematical function generator. 2 | You are given an image of a graph with a set of points plotted as a scatter plot with (x, y) coordinates below: 3 | {points} 4 | Below are some previous functions and the error they make on the points above. The errors are arranged in order of their fit values, with the highest values coming first, and lower is better. 5 | 6 | Your task is to give me a list of five new functions that are different from all the ones reported below, and have a lower error value than all of the ones below. Only output the new functions and nothing else. 7 | Remember that the functions you generate should always have at most {num_variables} variables {variables_list}. 8 | The functions should have parametric form, using 'c' in place of any constant or coefficient. The coefficients will be optimized to fit the data. Make absolutely sure that the functions you generate are completely different from the ones already given to you. 9 | Remember that you can combine the simple building blocks (operations, constants, variables) in any way you want to generate more complex functions. Don't be afraid to experiment! 10 | The functions should all begin with the indicators "f1(x) = ", "f2(x) = "... 
11 | 12 | {functions} -------------------------------------------------------------------------------- /prompts/OPRO/basic_text.txt: -------------------------------------------------------------------------------- 1 | I want you to act as a mathematical function generator. 2 | You are given a set of points with (x, y) coordinates below: 3 | {points} 4 | Below are some previous functions and the error they make on the points above. The errors are arranged in order of their fit values, with the highest values coming first, and lower is better. 5 | 6 | Your task is to give me a list of five new functions that are different from all the ones reported below, and have a lower error value than all of the ones below. Only output the new functions and nothing else. 7 | Remember that the functions you generate should always have at most {num_variables} variables {variables_list}. 8 | The functions should have parametric form, using 'c' in place of any constant or coefficient. The coefficients will be optimized to fit the data. Make absolutely sure that the functions you generate are completely different from the ones already given to you. 9 | Remember that you can combine the simple building blocks (operations, constants, variables) in any way you want to generate more complex functions. Don't be afraid to experiment! 10 | The functions should all begin with the indicators "f1(x) = ", "f2(x) = "... 11 | 12 | {functions} -------------------------------------------------------------------------------- /prompts/OPRO/image_all.txt: -------------------------------------------------------------------------------- 1 | I want you to act as a mathematical function generator. 2 | You are given an image of a graph with a set of points plotted as a scatter plot in blue. 3 | The coordinates of the points are: {points} 4 | Below are some previous functions and the error they make on the points above. The errors are arranged in order of their fit values, with the highest values coming first, and lower is better. These functions are also shown in the image in the form of a green line and the formula for the plotted function can be seen in the title. 5 | 6 | Your task is to give me a list of five new potential functions that are different from all the ones reported below, and have a lower error value than all of the functions below. Only output the new functions and nothing else. 7 | Remember that the functions you generate should always have at most {num_variables} variables {variables_list}. 8 | The functions should have parametric form, using 'c' in place of any constant or coefficient. The coefficients will be optimized to fit the data. Make absolutely sure that the functions you generate are completely different from the ones already given to you. 9 | The functions should all begin with the indicators "f1(x) = ", "f2(x) = "... 10 | Remember that you can combine the simple building blocks (operations, constants, variables) in any way you want to generate more complex functions. Don't be afraid to experiment! 11 | 12 | {functions} -------------------------------------------------------------------------------- /prompts/OPRO/image_best.txt: -------------------------------------------------------------------------------- 1 | I want you to act as a mathematical function generator. 2 | You are given an image of a graph with a set of points plotted as a scatter plot in blue. 3 | 4 | The coordinates of the points are: {points} 5 | Below are some previous functions and the error they make on the points above. 
The errors are arranged in order of their fit values, with the highest values coming first, and lower is better. The function with the lowest error (the best function so far) is also shown in the image in the form of a green line. 6 | 7 | Your task is to give me a list of five new functions that are different from all the ones reported below, and have a lower error value than all of the ones below. Only output the new functions and nothing else. 8 | Remember that the functions you generate should always have at most {num_variables} variables {variables_list}. 9 | The functions should have parametric form, using 'c' in place of any constant or coefficient. The coefficients will be optimized to fit the data. Make absolutely sure that the functions you generate are completely different from the ones already given to you. 10 | The functions should all begin with the indicators "f1(x) = ", "f2(x) = "... 11 | Remember that you can combine the simple building blocks (operations, constants, variables) in any way you want to generate more complex functions. Don't be afraid to experiment! 12 | 13 | {functions} -------------------------------------------------------------------------------- /prompts/OPRO/no_info.txt: -------------------------------------------------------------------------------- 1 | Generate five random functions of the form Function: f(x). The functions you generate should always have at most {num_variables} variables {variables_list}. 2 | Only output the functions and nothing else. -------------------------------------------------------------------------------- /prompts/seed_functions/generate_seed.txt: -------------------------------------------------------------------------------- 1 | I want you to act as a mathematical function generator. 2 | Given a set of points below, you are to come up with 5 potential functions that would fit the points. Don't worry too much about accuracy: your task is to generate a set of functions that are as diverse as possible, so that they can serve as starting points for further optimization. 3 | 4 | To generate the functions, you will start from a set of basic operators and expressions, and combine them into something more complex. 5 | Your options are: 6 | - {num_variables} independent variable symbols: {variables_list}. 7 | - A coefficient symbol: c (there is no need to write a number - write this generic coefficient instead). 8 | - Basic operators: +, -, *, /, ^, sqrt, exp, log, abs 9 | - Trigonometric expressions: sin, cos, tan, sinh, cosh, tanh 10 | - Standard constants: "pi" represents pi and "E" represents Euler's constant. 11 | 12 | Make sure there are no numbers in the functions, use the coefficient token 'c' instead. 13 | You are required to use at least one of each variable from the available list ({variables_list}). 14 | Analyze the points carefully: if there are any negative points in the input, sqrt and log cannot be used unless the input can never be negative. Be careful about dividing by zero! 15 | The functions should all begin with the indicators "f1 = ", "f2 = "... Only write the new functions and no additional output. 16 | Your task is to combine an arbitrary number of these basic blocks to create a complex expression. Don't be afraid to be creative and experiment! The functions should be as complex as possible, combining many different operations. Variety is key! 
17 | 18 | Points: {points} 19 | 20 | Functions: 21 | -------------------------------------------------------------------------------- /prompts/seed_functions/generate_seed_image.txt: -------------------------------------------------------------------------------- 1 | I want you to act as a mathematical function generator. 2 | Given a set of points below, you are to come up with 5 potential functions that would fit the points. Don't worry too much about accuracy: your task is to generate a set of functions that are as diverse as possible, so that they can serve as starting points for further optimization. 3 | 4 | To generate the functions, you will start from a set of basic operators and expressions, and combine them into something more complex. 5 | Your options are: 6 | - {num_variables} independent variable symbols: {variables_list}. 7 | - A coefficient symbol: c (there is no need to write a number - write this generic coefficient instead). 8 | - Basic operators: +, -, *, /, ^, sqrt, exp, log, abs 9 | - Trigonometric expressions: sin, cos, tan, sinh, cosh, tanh 10 | - Standard constants: "pi" represents pi and "E" represents Euler's number. 11 | 12 | Make sure there are no numbers in the functions; use the coefficient token 'c' instead. 13 | You are required to use at least one of each variable from the available list ({variables_list}). 14 | Analyze the points carefully: if there are any negative points in the input, sqrt and log cannot be used unless the input can never be negative. Be careful about dividing by zero! 15 | The functions should all begin with the indicators "f1 = ", "f2 = "... Only write the new functions and no additional output. 16 | Your task is to combine an arbitrary number of these basic blocks to create a complex expression. Don't be afraid to be creative and experiment! The functions should be as complex as possible, combining many different operations. Variety is key!
17 | 18 | Points: {points} 19 | 20 | 21 | Functions: 22 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.23.0 2 | antlr4-python3-runtime==4.9.3 3 | bitsandbytes==0.41.1 4 | fire==0.6.0 5 | huggingface-hub==0.23.2 6 | hydra-core==1.3.2 7 | ipykernel==6.25.2 8 | ipython==8.16.1 9 | ipython-genutils==0.2.0 10 | ipywidgets==8.1.1 11 | jupyter==1.0.0 12 | jupyterlab==4.0.7 13 | matplotlib==3.8.0 14 | matplotlib-inline==0.1.6 15 | mloggers==1.3.1 16 | numpy==1.26.1 17 | omegaconf==2.3.0 18 | openai==1.29.0 19 | pandas==2.2.0 20 | safetensors==0.4.2 21 | scikit-image==0.23.2 22 | scikit-learn==1.3.1 23 | scipy==1.11.3 24 | sympy==1.12.1 25 | tokenizers==0.19.1 26 | torch==2.1.0 27 | torchaudio==2.1.0 28 | torchvision==0.16.0 29 | tqdm==4.66.1 30 | transformers==4.41.2 -------------------------------------------------------------------------------- /scorers/__init__.py: -------------------------------------------------------------------------------- 1 | from .scorer import Scorer 2 | from .basic_scorer import BasicScorer 3 | from .minmax_scorer import MinMaxScorer 4 | from .complexity_scorer import ComplexityScorer -------------------------------------------------------------------------------- /scorers/basic_scorer.py: -------------------------------------------------------------------------------- 1 | from .scorer import Scorer 2 | from utils import format_exp 3 | from sklearn.metrics import mean_squared_error 4 | import numpy as np 5 | import utils 6 | import sys 7 | 8 | from typing import Dict, Tuple 9 | 10 | class BasicScorer(Scorer): 11 | """ 12 | Basic scorer for symbolic regression. 13 | Scores the function using unnormalized MSE. 14 | """ 15 | 16 | def __init__(self, points: np.ndarray, rounding: int = 2, scientific: bool = False): 17 | """ 18 | Initialize the scorer. 19 | 20 | Parameters 21 | ---------- 22 | points -> the points to evaluate the function on. 23 | rounding -> number of decimal places to round the score to (-1 for no rounding) 24 | scientific -> whether to use scientific notation for the score. 25 | 26 | Returns 27 | ------- 28 | None. 29 | """ 30 | 31 | super().__init__(points) 32 | self.round = rounding 33 | self.scientific = scientific 34 | 35 | def score(self, function: callable) -> float: 36 | """ 37 | Scores a function on a given set of points. 38 | 39 | Parameters 40 | ---------- 41 | function -> the function to score. 42 | Note: rounding is controlled by the `rounding` value passed to __init__; 43 | scientific notation only affects the formatted scores, not this method. 44 | 45 | Returns 46 | ------- 47 | score -> the score of the function. 48 | """ 49 | xs = self.points[:, 0:-1] 50 | ys = self.points[:, -1] 51 | num_variables = xs.shape[1] 52 | 53 | try: 54 | ys_pred = utils.eval_function(function, xs, num_variables) 55 | fit = mean_squared_error(ys, ys_pred) 56 | except Exception as e: 57 | # If the function is invalid for some points (e.g. division by 0), return inf 58 | fit = np.inf 59 | 60 | if self.round > 0: 61 | fit = np.round(fit, self.round) 62 | fit = np.float64(fit) # np.float64() also handles the plain float np.inf 63 | return fit 64 | 65 | def score_current_functions(self, current_functions: Dict[str, callable]) -> Tuple[Dict, Dict]: 66 | """ 67 | Scores the current functions in the prompt. 68 | 69 | Parameters 70 | ---------- 71 | current_functions -> the current functions in the prompt. 72 | Note: rounding and scientific formatting follow the `rounding` and 73 | `scientific` settings passed to __init__. 74 | 75 | Returns 76 | ------- 77 | scores -> the scores of the current functions. 78 | normalized_scores -> the scores formatted for the prompt (scientific notation if enabled). 79 | """ 80 | scores = { function: self.score(current_functions[function]) for function in current_functions } 81 | normalized_scores = scores.copy() 82 | if self.scientific: 83 | normalized_scores = { name: format_exp(score, self.round) for name, score in normalized_scores.items()} 84 | 85 | return scores, normalized_scores
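86 | 87 | # Minimal usage sketch (illustrative only; assumes points shaped (N, num_variables + 1) with y in the last column). 88 | # Run from the repo root with `python -m scorers.basic_scorer` so the relative import resolves. 89 | if __name__ == "__main__": 90 | from sympy.parsing.sympy_parser import parse_expr 91 | pts = np.array([[0.0, 1.0], [1.0, 3.0], [2.0, 5.0]]) # columns: x1, y 92 | scorer = BasicScorer(pts, rounding=2) 93 | print(scorer.score(parse_expr("2*x1 + 1"))) # exact fit -> MSE of 0.0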
-------------------------------------------------------------------------------- /scorers/complexity_scorer.py: -------------------------------------------------------------------------------- 1 | from .scorer import Scorer 2 | from utils import format_exp 3 | from sklearn.metrics import mean_squared_error 4 | import numpy as np 5 | import utils 6 | import sys 7 | 8 | from typing import Dict, Tuple 9 | 10 | class ComplexityScorer(Scorer): 11 | """ 12 | Complexity scorer for symbolic regression. 13 | Scores the function using normalized MSE combined with a measure of the complexity of the function. 14 | The score measure was defined in https://arxiv.org/abs/2303.06833 15 | """ 16 | 17 | def __init__(self, points: np.ndarray, rounding: int = 2, scientific: bool = False, lam: float = 0.5, max_nodes: int = 30, alternative: bool = False): 18 | """ 19 | Initialize the scorer. 20 | 21 | Parameters 22 | ---------- 23 | points -> the points to evaluate the function on. 24 | rounding -> number of decimal places to round the score to (-1 for no rounding) 25 | scientific -> whether to use scientific notation for the score. 26 | lam -> the lambda parameter for the complexity term. 27 | max_nodes -> the maximum number of nodes in the expression tree (used for normalization). 28 | alternative -> whether to use the alternative scoring function. 29 | 30 | Returns 31 | ------- 32 | None. 33 | """ 34 | 35 | super().__init__(points) 36 | self.round = rounding 37 | self.scientific = scientific 38 | self.eps = 1e-6 39 | self.lam = lam 40 | self.max_nodes = max_nodes 41 | self.alternative = alternative 42 | 43 | def score(self, function: callable) -> float: 44 | """ 45 | Scores a function on a given set of points. 46 | 47 | Parameters 48 | ---------- 49 | function -> the function to score. 50 | Note: rounding is controlled by the `rounding` value passed to __init__; 51 | scientific notation only affects the formatted scores, not this method. 52 | 53 | Returns 54 | ------- 55 | score -> the score of the function. 56 | """ 57 | xs = self.points[:, 0:-1] 58 | ys = self.points[:, -1] 59 | n = xs.shape[0] 60 | num_variables = xs.shape[1] 61 | 62 | try: 63 | ys_pred = utils.eval_function(function, xs, num_variables) 64 | # Calculate normalized MSE 65 | fit = 1/n * np.linalg.norm(ys - ys_pred)**2 / (1/n * np.linalg.norm(ys)**2 + self.eps) 66 | except Exception as e: 67 | # If the function is invalid for some points (e.g. division by 0), return inf 68 | fit = np.inf 69 | fit = np.float64(fit) # np.float64() also handles the plain float np.inf (.astype() would not) 70 | 71 | complexity = utils.count_nodes(function) 72 | complexity_term = np.exp(-complexity/self.max_nodes) 73 | 74 | if self.alternative: 75 | error = fit + self.lam * (1-complexity_term) 76 | else: 77 | fit_term = 1/(1 + fit) 78 | error = 1/(fit_term + self.lam * complexity_term + self.eps) 79 | 80 | if self.round > 0: 81 | error = np.round(error, self.round) 82 | 83 | return error 84 | 85 | def score_current_functions(self, current_functions: Dict[str, callable]) -> Tuple[Dict, Dict]: 86 | """ 87 | Scores the current functions in the prompt. 88 | 89 | Parameters 90 | ---------- 91 | current_functions -> the current functions in the prompt. 92 | Note: rounding and scientific formatting follow the `rounding` and 93 | `scientific` settings passed to __init__. 94 | 95 | Returns 96 | ------- 97 | scores -> the scores of the current functions. 98 | normalized_scores -> the scores formatted for the prompt (scientific notation if enabled). 99 | """ 100 | scores = { function: self.score(current_functions[function]) for function in current_functions } 101 | normalized_scores = scores.copy() 102 | if self.scientific: 103 | normalized_scores = { name: format_exp(score, self.round) for name, score in normalized_scores.items()} 104 | 105 | return scores, normalized_scores
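106 | 107 | # Worked illustration of the default score (comments only; the numbers are approximate): 108 | # with a perfect fit (fit = 0, so fit_term = 1) and a 15-node expression, 109 | # complexity_term = exp(-15/30) ~ 0.61 and error ~ 1 / (1 + 0.5 * 0.61) ~ 0.77, 110 | # while a 30-node expression with the same fit gets complexity_term ~ 0.37 and 111 | # error ~ 1 / (1 + 0.5 * 0.37) ~ 0.84 -- simpler expressions score lower (better).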
-------------------------------------------------------------------------------- /scorers/minmax_scorer.py: -------------------------------------------------------------------------------- 1 | from .scorer import Scorer 2 | from utils import format_exp 3 | from sklearn.metrics import mean_squared_error 4 | from sklearn.preprocessing import MinMaxScaler 5 | import numpy as np 6 | import utils 7 | 8 | from typing import Dict, Tuple 9 | 10 | class MinMaxScorer(Scorer): 11 | """ 12 | Min-max scorer for symbolic regression. 13 | Scores the function using normalized MSE. 14 | The MSE is normalized by the MinMaxScaler. 15 | """ 16 | 17 | def __init__(self, points: np.ndarray, min_score: float, max_score: float, rounding: int = 2, scientific: bool = False): 18 | """ 19 | Initialize the scorer. 20 | 21 | Parameters 22 | ---------- 23 | points -> the points to evaluate the function on. 24 | min_score -> the minimum value for the interval to scale the scores to. 25 | max_score -> the maximum value for the interval to scale the scores to. 26 | rounding -> number of decimal places to round the score to (-1 for no rounding) 27 | scientific -> whether to use scientific notation for the score. 28 | 29 | Returns 30 | ------- 31 | None. 32 | """ 33 | 34 | super().__init__(points) 35 | self.round = rounding 36 | self.scientific = scientific 37 | self.min_score = min_score 38 | self.max_score = max_score 39 | self.scaler = MinMaxScaler(feature_range=(min_score, max_score)) 40 | 41 | def score(self, function: callable) -> float: 42 | """ 43 | Scores a function on a given set of points. 44 | 45 | Parameters 46 | ---------- 47 | function -> the function to score. 48 | Note: rounding is controlled by the `rounding` value passed to __init__; 49 | scientific notation only affects the formatted scores, not this method. 50 | 51 | Returns 52 | ------- 53 | score -> the score of the function. 54 | """ 55 | xs = self.points[:, 0:-1] 56 | ys = self.points[:, -1] # y is the last column (works for any number of variables) 57 | num_variables = xs.shape[1] 58 | 59 | try: 60 | ys_pred = utils.eval_function(function, xs, num_variables) 61 | fit = mean_squared_error(ys, ys_pred) 62 | except Exception: 63 | # If the function is invalid for some points (e.g. division by 0), return inf 64 | fit = np.inf 65 | 66 | if self.round > 0: 67 | fit = np.round(fit, self.round) 68 | return np.float64(fit) # np.float64() also handles the plain float np.inf 69 | 70 | def score_current_functions(self, current_functions: Dict[str, callable]) -> Tuple[Dict, Dict]: 71 | """ 72 | Scores the current functions in the prompt. 73 | 74 | Parameters 75 | ---------- 76 | current_functions -> the current functions in the prompt. 77 | 78 | Returns 79 | ------- 80 | scores -> the raw scores; normalized_scores -> the scores rescaled to [min_score, max_score]. 81 | """ 82 | scores = { function: self.score(current_functions[function]) for function in current_functions } 83 | inf_scores = { name: score for name, score in scores.items() if score == np.inf } 84 | # Remove inf scores from the list of scores to normalize 85 | # Can't remove them entirely, otherwise this would cause inconsistencies (current_functions would be missing some scores). We handle those in the main loop 86 | if len(inf_scores) > 0: 87 | for name in inf_scores: 88 | scores.pop(name) 89 | scores_array = np.array(list(scores.values())).reshape(-1, 1) 90 | if len(scores_array) != 0: 91 | self.scaler.fit(scores_array) 92 | normalized_scores = { name: self.scaler.transform(np.array(score).reshape(-1, 1))[0][0] for name, score in scores.items() } 93 | else: 94 | normalized_scores = scores.copy() 95 | 96 | # Add inf scores back to the list of scores 97 | for name in inf_scores: 98 | scores[name] = np.inf 99 | normalized_scores[name] = np.inf 100 | if self.round > 0 and not self.scientific: 101 | normalized_scores = { name: np.round(score, self.round) for name, score in normalized_scores.items() } 102 | elif self.scientific: 103 | normalized_scores = { name: format_exp(score, self.round) for name, score in normalized_scores.items() } 104 | return scores, normalized_scores
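105 | 106 | # Minimal usage sketch (illustrative only; assumes 1-variable points shaped (N, 2)). 107 | # Run from the repo root with `python -m scorers.minmax_scorer` so the relative import resolves. 108 | if __name__ == "__main__": 109 | from sympy.parsing.sympy_parser import parse_expr 110 | pts = np.array([[0.0, 0.0], [1.0, 1.0], [2.0, 4.0]]) # y = x1**2 111 | scorer = MinMaxScorer(pts, min_score=0.0, max_score=100.0, rounding=2) 112 | scores, normalized = scorer.score_current_functions({"f1": parse_expr("x1**2"), "f2": parse_expr("x1 + 1")}) 113 | print(scores, normalized) # raw MSEs and their min-max rescaling onto [0, 100]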
-------------------------------------------------------------------------------- /scorers/scorer.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Dict, Tuple 3 | 4 | import numpy as np 5 | 6 | class Scorer(ABC): 7 | """ 8 | Abstract class for different scoring methods to evaluate the goodness of a function on a given set of points. 9 | """ 10 | 11 | def __init__(self, points: np.ndarray, rounding: int = 2, scientific: bool = False): 12 | """ 13 | Initialize the scorer. 14 | 15 | Parameters 16 | ---------- 17 | points -> the points to evaluate the function on. 18 | rounding -> number of decimal places to round the score to (-1 for no rounding) 19 | scientific -> whether to use scientific notation for the score. 20 | 21 | Returns 22 | ------- 23 | None. 24 | """ 25 | 26 | self.points = points 27 | # self.tolerance = 1e-6 # Add tolerance to avoid division by 0 28 | # self.points[:, 0:-1] = self.points[:, 0:-1] + self.tolerance 29 | 30 | @abstractmethod 31 | def score(self, function: callable) -> float: 32 | """ 33 | Scores a function on a given set of points. 34 | 35 | Parameters 36 | ---------- 37 | function -> the function to score. 38 | 39 | Returns 40 | ------- 41 | score -> the score of the function. 42 | """ 43 | 44 | pass 45 | 46 | @abstractmethod 47 | def score_current_functions(self, current_functions: Dict[str, callable]) -> Tuple[Dict, Dict]: 48 | """ 49 | Scores the current functions in the prompt. 50 | 51 | Parameters 52 | ---------- 53 | current_functions -> the current functions in the prompt. 54 | 55 | Returns 56 | ------- 57 | scores -> the raw scores; normalized_scores -> the scores normalized/formatted for the prompt. 58 | """ 59 | 60 | pass 61 |
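62 | # Minimal subclass sketch (illustrative only; a MeanAbsScorer is not part of the package): 63 | # class MeanAbsScorer(Scorer): 64 | # def score(self, function): 65 | # import utils 66 | # xs, ys = self.points[:, 0:-1], self.points[:, -1] 67 | # ys_pred = utils.eval_function(function, xs, xs.shape[1]) 68 | # return float(np.mean(np.abs(ys - ys_pred))) 69 | # def score_current_functions(self, current_functions): 70 | # scores = {name: self.score(f) for name, f in current_functions.items()} 71 | # return scores, scores.copy()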
58 | """ 59 | 60 | pass 61 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import numpy as np 4 | import pandas as pd 5 | 6 | from torch import dtype, device 7 | from torch.cuda import get_device_name, is_available 8 | from collections.abc import Callable 9 | from typing import Tuple, List, Dict, Any 10 | 11 | from sklearn.metrics import mean_squared_error, r2_score 12 | from models import LLaVaModelHF, HuggingFaceModel, OpenAIModel 13 | 14 | import sympy 15 | from sympy.parsing.sympy_parser import parse_expr 16 | 17 | def get_job_id() -> str: 18 | """ 19 | Gets the SLURM job id from the environment variables if available. 20 | 21 | Returns 22 | ------- 23 | job_id -> The SLURM job id (or None if not available). 24 | """ 25 | 26 | job_id = os.environ.get("SLURM_JOB_ID", None) 27 | if job_id is not None: 28 | job_id += "_" + os.environ.get("SLURM_ARRAY_TASK_ID", None) if "SLURM_ARRAY_TASK_ID" in os.environ else "" 29 | 30 | return job_id 31 | 32 | def load_model(model_name: str, device: device, dtype: dtype, cache_dir: str = None, model_args = None) -> Any: 33 | """ 34 | Utility to load a model from the HuggingFace model hub. 35 | Mostly needed to deal with LLaVA models, that are not available on the model hub yet. 36 | 37 | Parameters 38 | ---------- 39 | model_name -> the name of the model to load. 40 | device -> the device to load the model on. 41 | dtype -> the dtype to load the model with. 42 | cache_dir -> the cache directory to use for the model. 43 | 44 | Returns 45 | ------- 46 | model -> the loaded model. 47 | """ 48 | if 'llava' in model_name: 49 | model = LLaVaModelHF(model_name, device, dtype, cache_dir, **model_args) 50 | elif 'gpt' in model_name: 51 | model = OpenAIModel(model_name, device, dtype, cache_dir, **model_args) 52 | else: 53 | model = HuggingFaceModel(model_name, device, dtype, cache_dir, **model_args) 54 | 55 | return model 56 | 57 | def get_messages(prompt: str, splits: List[str] = ["system", "user"]) -> List[Dict[str, str]]: 58 | """ 59 | Converts a prompt string into a list of messages for each split. 60 | 61 | Parameters: 62 | prompt (str): The prompt string. 63 | splits (list[str]): A list of the splits to parse. Defaults to ["system", "user"]. 64 | 65 | Returns: 66 | list[dict[str, str]]: A dictionary of the messages for each split. 67 | """ 68 | 69 | messages = [] 70 | for split in splits: 71 | start_tag = f"<{split}>" 72 | end_tag = f"" 73 | 74 | start_idx = prompt.find(start_tag) 75 | end_idx = prompt.find(end_tag) 76 | 77 | # Skip if the split is not in the prompt (e.g. no system prompt) 78 | if start_idx == -1 or end_idx == -1: 79 | continue 80 | messages.append({ 81 | "role": split, 82 | "content": prompt[start_idx + len(start_tag):end_idx].strip() 83 | }) 84 | 85 | # If no splits at all, assume the whole prompt is a user message 86 | if len(messages) == 0: 87 | messages.append({ 88 | "role": "user", 89 | "content": prompt 90 | }) 91 | 92 | return messages 93 | 94 | def load_points(file_path: str) -> np.ndarray: 95 | """ 96 | Loads a set of points from a file. 97 | 98 | Parameters 99 | ---------- 100 | file_path -> the path to the file containing the points. 101 | 102 | Returns 103 | ------- 104 | points -> the points. 
105 | """ 106 | if file_path.endswith(".npy"): 107 | points = np.load(file_path) 108 | elif file_path.endswith(".txt"): 109 | points = np.loadtxt(file_path) 110 | elif file_path.endswith(".csv"): 111 | points = pd.read_csv(file_path).values 112 | elif file_path.endswith(".tsv"): 113 | points = pd.read_csv(file_path, sep="\t").values 114 | else: 115 | raise ValueError("Invalid file format. (only .npy, .txt, .csv, and .tsv are supported)") 116 | return points 117 | 118 | def normalize_points(points: np.ndarray, method: str = "minmax", percentile: int = None) -> np.ndarray: 119 | """ 120 | Normalizes a set of points. 121 | 122 | Parameters 123 | ---------- 124 | points -> the points to normalize. 125 | method -> the normalization method to use. (minmax, zscore, percentile) 126 | percentile -> the percentile to use for percentile normalization (if applicable). 127 | 128 | Returns 129 | ------- 130 | points -> the normalized points. 131 | """ 132 | if method == "percentile" and percentile is None: 133 | raise ValueError("Percentile normalization requires a percentile value.") 134 | 135 | ys = np.array([point[-1] for point in points]) 136 | if method == "minmax": 137 | points = np.array([np.concatenate([point[:-1], [(y - ys.min()) / (ys.max() - ys.min())]]) for point, y in zip(points, ys)]) 138 | elif method == "zscore": 139 | points = np.array([np.concatenate([point[:-1], [(y - ys.mean()) / ys.std()]]) for point, y in zip(points, ys)]) 140 | elif method == "percentile": 141 | points = np.array([np.concatenate([point[:-1], [y /np.percentile(ys, percentile)]]) for point, y in zip(points, ys)]) 142 | else: 143 | raise ValueError("Invalid normalization method.") 144 | 145 | points = np.round(points, 4) 146 | return points 147 | 148 | def decimate_points(points: np.ndarray, max_points: int) -> np.ndarray: 149 | """ 150 | Reduces the number of points to a maximum number to be used in the prompt. 151 | 152 | Parameters 153 | ---------- 154 | points -> the points to decimate. 155 | max_points -> the maximum number of points to keep. 156 | 157 | Returns 158 | ------- 159 | points -> the decimated points. 160 | """ 161 | 162 | if points.shape[0] <= max_points: 163 | return points 164 | 165 | # Find an evenly spaced subset of points 166 | indices = np.linspace(0, points.shape[0] - 1, max_points, dtype=int) 167 | points = points[indices] 168 | return points 169 | 170 | def split_points(points: np.ndarray, test_fraction: float, split_strategy: str = "random", seed: int = None) -> Tuple[np.ndarray, np.ndarray]: 171 | """ 172 | Splits a set of points into train and test sets. 173 | 174 | Parameters 175 | ---------- 176 | points -> the points to split. 177 | test_fraction -> the fraction of points to use for the test set. 178 | split_strategy -> the strategy to use for splitting the points. (random, middle, end) 179 | seed -> the seed to use for the random split. 180 | 181 | Returns 182 | ------- 183 | train_points -> the train points. 184 | test_points -> the test points. 185 | """ 186 | num_points = points.shape[0] 187 | num_test_points = int(num_points * test_fraction) 188 | points = points[points[:, 0].argsort()] 189 | 190 | if seed is not None: 191 | np.random.seed(seed) 192 | 193 | #! 
Middle and end are not working properly with n_variables > 1, not fixed as unused in final version 194 | if split_strategy == "random": 195 | indices = np.random.choice(num_points, num_test_points, replace=False) 196 | mask = np.ones(num_points, dtype=bool) 197 | mask[indices] = False 198 | train_points = points[mask] 199 | test_points = points[~mask] 200 | elif split_strategy == "middle": 201 | start = (num_points - num_test_points) // 2 202 | end = start + num_test_points 203 | train_points = np.concatenate([points[:start], points[end:]]) 204 | test_points = points[start:end] 205 | elif split_strategy == "end": 206 | train_points = points[:-num_test_points] 207 | test_points = points[-num_test_points:] 208 | else: 209 | raise ValueError("Invalid split strategy.") 210 | 211 | return train_points, test_points 212 | 213 | def array_to_string(points: np.ndarray) -> str: 214 | """ 215 | Converts a numpy array of points to a string. 216 | 217 | Parameters 218 | ---------- 219 | points -> the numpy array of points to convert. 220 | 221 | Returns 222 | ------- 223 | points -> the string of points. 224 | """ 225 | points = points.tolist() 226 | points_str = "" 227 | for point in points: 228 | point_str = ", ".join([str(np.round(x, 2)) for x in point]) 229 | point_str = f"({point_str})" 230 | points_str += point_str + ", " 231 | 232 | return points_str[:-2] 233 | 234 | def string_to_array(points: str) -> np.ndarray: 235 | """ 236 | Converts a string of points to a numpy array. 237 | 238 | Parameters 239 | ---------- 240 | points -> the string of points to convert. 241 | 242 | Returns 243 | ------- 244 | points -> the numpy array of points. 245 | """ 246 | points = points.replace("(", "").split("), ") 247 | points = [point.replace(")", "") for point in points] 248 | points = [point.split(", ") for point in points] 249 | points = [[float(coordinate) for coordinate in point] for point in points] 250 | return np.array(points) 251 | 252 | def eval_function(function: sympy.core.function.Function, Xs: np.ndarray, num_variables: int) -> float: 253 | """ 254 | Evaluates a sympy function at a set of points. 255 | 256 | Parameters 257 | ---------- 258 | function -> the function to evaluate. 259 | Xs -> the points to evaluate the function at. (Variables have to be sorted alphabetically) 260 | num_variables -> the number of variables the function takes. 261 | 262 | Returns 263 | ------- 264 | ys -> the values of the function at the given points. 265 | """ 266 | symbols = function.free_symbols 267 | symbols = sorted(symbols, key=lambda x: str(x)) 268 | if Xs.shape[-1] != num_variables: 269 | Xs = np.array(list(zip(*[x.flat for x in Xs]))) 270 | 271 | ys = [] 272 | for point in Xs: 273 | if isinstance(point, np.ndarray): 274 | subs = {symbol: value for symbol, value in zip(symbols, point)} 275 | else: 276 | subs = {symbols[0]: point} 277 | try: 278 | y = function.evalf(subs=subs) 279 | y = float(y) 280 | except Exception as e: 281 | print(f"Error evaluating function: {function} at point {point}. {e}") 282 | y = np.inf 283 | ys.append(y) 284 | 285 | ys = np.array(ys) 286 | ys = ys.astype(np.float32) 287 | return ys 288 | 289 | def clean_function(function: str) -> str: 290 | """ 291 | Cleans a function string to be evaluable.
292 | """ 293 | function = function.strip(".") 294 | function = function.replace(" ", "") 295 | 296 | if "=" in function: 297 | function = function.split("=")[1] 298 | elif ":" in function: 299 | function = function.split(":")[1] 300 | 301 | # Remove characters that are not allowed in a function 302 | removals = ["'", '"', "\\", "\n", "\t", "\r", " ", "_"] 303 | for removal in removals: 304 | function = function.replace(removal, "") 305 | 306 | # Remove trailing operators 307 | while len(function) > 1 and function[-1] in ["+", "-", "*", "/", "**"]: 308 | if len(function) == 1: 309 | return lambda x: 0 310 | function = function[:-1] 311 | 312 | # Remove leading operators 313 | while len(function) > 1 and function[0] in ["+", "*", "/", "**"]: 314 | if len(function) == 1: 315 | return lambda x: 0 316 | function = function[1:] 317 | 318 | # Remove leading indicators of a function definition 319 | removals = ["Function", "Newfunction", "Thefunctionis", ":"] 320 | 321 | for removal in removals: 322 | if removal.lower() in function.lower(): 323 | function = function.replace(removal, "") 324 | function = function.strip() 325 | 326 | return function 327 | 328 | def string_to_function(function: str, num_variables: int = 1) -> Callable[[float], float]: 329 | """ 330 | Converts a string to a callable function using eval. 331 | 332 | Parameters 333 | ---------- 334 | function -> the string to convert. 335 | num_variables -> the number of variables the function should take. 336 | 337 | Returns 338 | ------- 339 | f -> the callable function. 340 | """ 341 | function = clean_function(function) 342 | 343 | np_func = ["sin", "cos", "tan", "exp", "log", "sqrt"] 344 | function = function.replace("^", "**") 345 | #! This only works for variables in x (x1, x2, x3, ...) 346 | #! This only works with coefficients that end with numbers (e.g. c0, c1, c2, ...) 347 | function = re.sub(r"(\d)x", r"\1*x", function) 348 | regex = r"(\d)(" + "|".join(np_func) + ")" 349 | function = re.sub(regex, r"\1*\2", function) 350 | f = parse_expr(function) 351 | return f 352 | 353 | def is_valid_function(function: str, current_functions: Any, num_variables: int = 1) -> Tuple[bool, str]: 354 | """ 355 | Checks if a function is valid. 356 | 357 | Parameters 358 | ---------- 359 | function -> the function to check. 360 | current_functions -> the current functions in the prompt. 361 | num_variables -> the number of variables the function should take. 362 | 363 | Returns 364 | ------- 365 | valid -> whether the function is valid. 366 | reason -> the reason the function is invalid (if applicable). 367 | """ 368 | valid = True 369 | reason = "" 370 | if type(function) == str: 371 | f = string_to_function(function, num_variables) 372 | else: 373 | f = function 374 | symbols = f.free_symbols 375 | variables = [str(symbol) for symbol in symbols if str(symbol).startswith("x")] 376 | 377 | if len(variables) > num_variables: 378 | valid = False 379 | reason = "Too many variables in function." 380 | return valid, reason 381 | 382 | if current_functions is not None and current_functions.func_in_list(f): 383 | valid = False 384 | reason = "Function already in prompt." 385 | return valid, reason 386 | 387 | return valid, reason 388 | 389 | def format_exp(x: float, d: int = 6) -> str: 390 | """ 391 | Formats a number in scientific notation with custom precision. (used in Scorers) 392 | 393 | Parameters 394 | ---------- 395 | x -> the number to format. 396 | d -> the number of decimal places to round to. 
397 | 398 | Returns 399 | ------- 400 | x -> the formatted number. 401 | """ 402 | n = int(np.floor(np.log10(abs(x)))) if x != 0 and np.isfinite(x) else 0 # guard against log10(0) and non-finite scores 403 | significand = x / 10 ** n 404 | exp_sign = '+' if n >= 0 else '-' 405 | return f'{significand:.{d}f}e{exp_sign}{abs(n):02d}' 406 | 407 | def func_equals(f1: Any, f2: Any, num_variables: int) -> bool: 408 | """ 409 | Checks if two functions are equal. Used in place of sympy.equals as the latter can become very slow for certain functions. 410 | https://stackoverflow.com/questions/37112738/sympy-comparing-expressions 411 | 412 | Parameters 413 | ---------- 414 | f1 -> the first function. 415 | f2 -> the second function. 416 | num_variables -> the number of variables the functions should take. 417 | 418 | Returns 419 | ------- 420 | equal -> whether the functions are equal. 421 | """ 422 | if f1 == f2: 423 | return True 424 | if f1 is None or f2 is None: 425 | return False 426 | if f1.free_symbols != f2.free_symbols: 427 | return False 428 | if f1.free_symbols != set([sympy.Symbol(f"x{i + 1}") for i in range(num_variables)]): 429 | return False 430 | return False # conservative default: expressions that are not syntactically equal are treated as different 431 | 432 | def count_nodes(formula: Any) -> int: 433 | """ 434 | Gets the complexity of a sympy formula, represented by the number of operations in its expression tree (via sympy's count_ops). 435 | 436 | Parameters 437 | ---------- 438 | formula -> the formula to get the complexity of. 439 | 440 | Returns 441 | ------- 442 | complexity -> the complexity of the formula. 443 | """ 444 | return formula.count_ops() 445 | 446 | def replace_zero_coefficients(expr: Any, formula: Any, threshold: float = 1e-2) -> Any: 447 | """ 448 | Replaces coefficients that are close to zero in a formula with zero. 449 | 450 | Parameters 451 | ---------- 452 | expr -> the expression to replace coefficients in (with coefficients c0, c1...) 453 | formula -> the formula to replace coefficients in (with numerical coefficients) 454 | threshold -> the threshold to consider a coefficient zero. 455 | 456 | Returns 457 | ------- 458 | expr -> the expression with zero coefficients replaced. 459 | formula -> the formula with zero coefficients replaced. 460 | """ 461 | coeffs_dict = formula.as_coefficients_dict() 462 | expr_dict = expr.as_coefficients_dict() 463 | 464 | for key, value in coeffs_dict.items(): 465 | if abs(value) < threshold: 466 | expr_dict[key] = 0 467 | formula = formula.subs(key, 0) 468 | 469 | expr = expr.subs(expr_dict) 470 | 471 | return expr, formula --------------------------------------------------------------------------------