├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── animation.gif
│   └── sketch.png
├── conf
│   ├── config.yaml
│   ├── experiment
│   │   ├── function
│   │   │   ├── R
│   │   │   │   ├── R1.yaml
│   │   │   │   ├── R2.yaml
│   │   │   │   └── R3.yaml
│   │   │   ├── constant
│   │   │   │   ├── constant1.yaml
│   │   │   │   ├── constant2.yaml
│   │   │   │   ├── constant3.yaml
│   │   │   │   ├── constant4.yaml
│   │   │   │   ├── constant5.yaml
│   │   │   │   ├── constant6.yaml
│   │   │   │   ├── constant7.yaml
│   │   │   │   └── constant8.yaml
│   │   │   ├── keijzer
│   │   │   │   ├── keijzer10.yaml
│   │   │   │   ├── keijzer11.yaml
│   │   │   │   ├── keijzer12.yaml
│   │   │   │   ├── keijzer13.yaml
│   │   │   │   ├── keijzer14.yaml
│   │   │   │   ├── keijzer15.yaml
│   │   │   │   ├── keijzer3.yaml
│   │   │   │   ├── keijzer4.yaml
│   │   │   │   ├── keijzer6.yaml
│   │   │   │   ├── keijzer7.yaml
│   │   │   │   ├── keijzer8.yaml
│   │   │   │   └── keijzer9.yaml
│   │   │   └── nguyen
│   │   │       ├── nguyen1.yaml
│   │   │       ├── nguyen10.yaml
│   │   │       ├── nguyen11.yaml
│   │   │       ├── nguyen12.yaml
│   │   │       ├── nguyen2.yaml
│   │   │       ├── nguyen3.yaml
│   │   │       ├── nguyen4.yaml
│   │   │       ├── nguyen5.yaml
│   │   │       ├── nguyen6.yaml
│   │   │       ├── nguyen7.yaml
│   │   │       ├── nguyen8.yaml
│   │   │       └── nguyen9.yaml
│   │   ├── scorer
│   │   │   ├── basic_scorer.yaml
│   │   │   ├── complexity_scorer.yaml
│   │   │   └── minmax_scorer.yaml
│   │   ├── seed_functions
│   │   │   ├── 2D_linear.yaml
│   │   │   ├── 2D_mixed.yaml
│   │   │   ├── coefficients.yaml
│   │   │   ├── generate.yaml
│   │   │   ├── linear.yaml
│   │   │   └── mixed.yaml
│   │   └── standard.yaml
│   ├── logger
│   │   └── default_logger.yaml
│   └── model
│       ├── base_prompt
│       │   ├── basic_image.yaml
│       │   ├── basic_mixed.yaml
│       │   ├── basic_text.yaml
│       │   ├── image_all.yaml
│       │   ├── image_best.yaml
│       │   └── no_info.yaml
│       ├── gpt3.5-turbo.yaml
│       ├── gpt4o.yaml
│       ├── llama3-8b.yaml
│       └── llava1.6-34b.yaml
├── current_functions.py
├── data
│   ├── R
│   │   ├── R1
│   │   │   ├── test_points.npy
│   │   │   └── train_points.npy
│   │   ├── R2
│   │   │   ├── test_points.npy
│   │   │   └── train_points.npy
│   │   └── R3
│   │       ├── test_points.npy
│   │       └── train_points.npy
│   ├── constant
│   │   ├── constant1
│   │   │   ├── test_points.npy
│   │   │   └── train_points.npy
│   │   ├── constant2
│   │   │   ├── test_points.npy
│   │   │   └── train_points.npy
│   │   ├── constant3
│   │   │   ├── test_points.npy
│   │   │   └── train_points.npy
│   │   ├── constant4
│   │   │   ├── test_points.npy
│   │   │   └── train_points.npy
│   │   ├── constant5
│   │   │   ├── test_points.npy
│   │   │   └── train_points.npy
│   │   ├── constant6
│   │   │   ├── test_points.npy
│   │   │   └── train_points.npy
│   │   ├── constant7
│   │   │   ├── test_points.npy
│   │   │   └── train_points.npy
│   │   └── constant8
│   │       ├── test_points.npy
│   │       └── train_points.npy
│   ├── keijzer
│   │   ├── keijzer10
│   │   │   ├── test_points.npy
│   │   │   └── train_points.npy
│   │   ├── keijzer11
│   │   │   ├── test_points.npy
│   │   │   └── train_points.npy
│   │   ├── keijzer12
│   │   │   ├── test_points.npy
│   │   │   └── train_points.npy
│   │   ├── keijzer13
│   │   │   ├── test_points.npy
│   │   │   └── train_points.npy
│   │   ├── keijzer14
│   │   │   ├── test_points.npy
│   │   │   └── train_points.npy
│   │   ├── keijzer15
│   │   │   ├── test_points.npy
│   │   │   └── train_points.npy
│   │   ├── keijzer3
│   │   │   ├── test_points.npy
│   │   │   └── train_points.npy
│   │   ├── keijzer4
│   │   │   ├── test_points.npy
│   │   │   └── train_points.npy
│   │   ├── keijzer6
│   │   │   ├── test_points.npy
│   │   │   └── train_points.npy
│   │   ├── keijzer7
│   │   │   ├── test_points.npy
│   │   │   └── train_points.npy
│   │   ├── keijzer8
│   │   │   ├── test_points.npy
│   │   │   └── train_points.npy
│   │   └── keijzer9
│   │       ├── test_points.npy
│   │       └── train_points.npy
│   └── nguyen
│       ├── nguyen1
│       │   ├── test_points.npy
│       │   └── train_points.npy
│       ├── nguyen10
│       │   ├── test_points.npy
│       │   └── train_points.npy
│       ├── nguyen11
│       │   ├── test_points.npy
│       │   └── train_points.npy
│       ├── nguyen12
│       │   ├── test_points.npy
│       │   └── train_points.npy
│       ├── nguyen2
│       │   ├── test_points.npy
│       │   └── train_points.npy
│       ├── nguyen3
│       │   ├── test_points.npy
│       │   └── train_points.npy
│       ├── nguyen4
│       │   ├── test_points.npy
│       │   └── train_points.npy
│       ├── nguyen5
│       │   ├── test_points.npy
│       │   └── train_points.npy
│       ├── nguyen6
│       │   ├── test_points.npy
│       │   └── train_points.npy
│       ├── nguyen7
│       │   ├── test_points.npy
│       │   └── train_points.npy
│       ├── nguyen8
│       │   ├── test_points.npy
│       │   └── train_points.npy
│       └── nguyen9
│           ├── test_points.npy
│           └── train_points.npy
├── download_model.py
├── main.py
├── models
│   ├── __init__.py
│   ├── hf_model.py
│   ├── llava_model_hf.py
│   └── openai_model.py
├── optimizer.py
├── plotter.py
├── prompts
│   ├── OPRO
│   │   ├── basic_image.txt
│   │   ├── basic_mixed.txt
│   │   ├── basic_text.txt
│   │   ├── image_all.txt
│   │   ├── image_best.txt
│   │   └── no_info.txt
│   └── seed_functions
│       ├── generate_seed.txt
│       └── generate_seed_image.txt
├── requirements.txt
├── scorers
│   ├── __init__.py
│   ├── basic_scorer.py
│   ├── complexity_scorer.py
│   ├── minmax_scorer.py
│   └── scorer.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | **.out
2 | /__pycache__
3 | **/__pycache__
4 | /.ipynb_checkpoints
5 | /outputs
6 | **/outputs
7 | /runs
8 | **/runs
9 | /multirun
10 | **/multirun
11 | /openai
12 | **/profiles
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Matteo Merler
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # In-Context Symbolic Regression
2 |
3 | ![Overview of the method](assets/sketch.png)
4 |
5 | *Overview of the method.*
6 |
10 | Official code implementation for the ACL 2024 Student Research Workshop paper [In-Context Symbolic Regression: Leveraging Large Language Models for Function Discovery](https://aclanthology.org/2024.acl-srw.49/). The proposed approach defines an iterative procedure that refines a functional form with an LLM and determines its coefficients with an external optimizer.
11 |
12 | ![Example of the ICSR optimization loop](assets/animation.gif)
13 |
14 | *Example of the ICSR optimization loop.*
15 |
18 | ## Setup
19 |
20 | This codebase uses Python 3.11.6.
21 |
22 | First, clone the repository, then install the required packages.
23 | We recommend using a virtual environment to avoid conflicts with other Python packages:
24 |
25 | ```bash
26 | python3 -m venv .venv
27 | source .venv/bin/activate
28 | pip install -r requirements.txt
29 | ```
30 |
31 | ## Usage
32 |
33 | The method is self-contained in the `main.py` script, configured using [Hydra](https://hydra.cc/) from the `conf/` directory. Before running the method, ensure that the configuration files are set up correctly. Specifically, the `root` path in the `conf/config.yaml` file should be set to the absolute path of the root directory of the repository. If models from the HuggingFace platform are used (for example, Llama3 8B), the `cache_dir` path in `conf/model/llama3-8b.yaml` (or any other added model) should point to a directory where the models are stored. If models from OpenAI are used (for example, GPT-3.5 or GPT-4o), the `api_key_path` entry in `conf/model/gpt3.5-turbo.yaml` (or any other OpenAI model) should point to a file containing the API key.
34 |
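These values can also be overridden directly from the command line instead of editing the files; for example (the paths below are placeholders, not defaults):

```bash
# Illustrative paths only; substitute your own.
python3 main.py root=/abs/path/to/In-Context-Symbolic-Regression \
    model.cache_dir=/abs/path/to/hf_cache
```

Once configured, run: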
35 | ```bash
36 | python3 main.py
37 | ```
38 |
39 | This will run the method with the default configuration on the "nguyen1" experiment using Llama3 8B as the language model. Note that this requires a capable GPU. For a quick demo, we suggest using an OpenAI model, like GPT-3.5 (which requires an API key). To run the same experiment with GPT-3.5, use the following command:
40 |
41 | ```bash
42 | python3 main.py model=gpt3.5-turbo
43 | ```
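
The OpenAI API key is read from the file referenced by `api_key_path` in the model configuration (`openai/openai_key` by default; note that the `openai/` directory is git-ignored). Assuming the default path, one way to provide the key:

```bash
mkdir -p openai
echo "sk-..." > openai/openai_key   # replace with your actual API key
```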
44 |
45 | ## Configuration
46 |
47 | All the configuration files are stored in the `conf/` directory. The values for each parameter in the provided configuration files reflect those used for the paper's experiments. The configuration files are organized as follows:
48 |
49 | - All experiments are defined in the `conf/experiment/function` directory, where all four benchmarks are stored. To add a new benchmark, create a new YAML file in the `conf/experiment/function` directory (or in a subdirectory if necessary) following the existing file structure (see `conf/experiment/function/nguyen/nguyen1.yaml` for an example; a hypothetical new benchmark is also sketched after this list).
50 | - All models are defined in the `conf/model` directory. By default, we include configuration files for Llama3 8B (using the HuggingFace platform), LLaVA-NeXT, GPT-3.5 and GPT-4o. To add a new model, create a new YAML file in the `conf/model` directory following the existing file structure (see `conf/model/llama3-8b.yaml` for an example using HuggingFace models or `conf/model/gpt3.5-turbo.yaml` for an example using OpenAI models).
51 | - The file `conf/config.yaml` contains general configuration settings, like the torch device and random seed. The `root` path should be set to the absolute path for the root directory of the repository.
52 | - The file `conf/experiment/standard.yaml` contains the default configuration for all experiments independent of the function, like the coefficient optimizer settings.
53 | - The files in the `conf/experiment/function` subdirectories contain the configuration for each individual function. These include the ground truth function and the point range for the function evaluation. For reproducibility, we also provide the exact random points we sampled for the experiments in the `data` directory. By default, all experiments are performed on this set of points and are not re-sampled. The path for each function's data points is defined in the respective function configuration file.
54 | - The files in the `conf/experiment/scorer` directory contain the configuration for the scoring functions used to evaluate the discovered functions.
55 | - The files in the `conf/experiment/seed_functions` directory contain the configuration for the seed functions used to initialize the discovered functions. In the paper we ask the LLM to generate the initial seed functions, but changing this configuration allows for a manually defined set of seed functions.
56 | - The files in the `conf/model/base_prompt` directory contain the configuration for all the prompts used by the models. In practice, the main difference between prompts is the type of images provided for vision models, used in the experiments in the Appendix of the paper.
57 |
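As a sketch of the first point above, a hypothetical new one-variable benchmark in `conf/experiment/function/custom/my_function1.yaml` (the name and target function here are made up for illustration) would follow the same schema as the existing files:

```yaml
name: "my_function1"
group: "custom"

# Train points
train_points:
  generate_points: true   # sample new points instead of loading them from data/
  random_points: true
  min_points: -1
  max_points: 1
  num_points: 20
  xs_noise_std: 0
  ys_noise_std: 0
  add_extremes: true

# Test points
test_points:
  min_points: -1
  max_points: 1
  num_points: 200

tolerance: 0.99999

# Test function
test_function: "x^2 + sin(x)"
num_variables: 1

# Total iterations
iterations: 10
```

It could then be selected with `experiment/function=custom/my_function1`.
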
58 | Using Hydra, the configuration can be overridden from the command line. Note that overriding a single parameter in a file (for example, the number of iterations in `conf/experiment/function/nguyen/nguyen1.yaml`) differs from overriding an entire file (used when changing the model or the experiment): the former uses '.' separators, while the latter uses '/' separators. For example, to run the "keijzer6" experiment with the GPT-3.5 model, 100 iterations and early stopping when the $R^2$ score on the training points exceeds 0.95, use the following command:
59 |
60 | ```bash
61 | python3 main.py experiment/function=keijzer/keijzer6 model=gpt3.5-turbo experiment.function.iterations=100 experiment.function.tolerance=0.95
62 | ```
63 |
64 | To find the exact name of each parameter, check the respective configuration files.
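
Since the configuration is managed by Hydra, sweeps over several values are also possible with Hydra's multirun flag `-m` (results are written to the `multirun/` directory, which is git-ignored). For example, a hypothetical sweep over two functions and three seeds:

```bash
python3 main.py -m experiment/function=nguyen/nguyen1,nguyen/nguyen2 seed=1,2,3
```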
65 |
66 | ## Code Structure
67 |
68 | The code is organized as follows:
69 |
70 | - The main entry point is the `main.py` script. This defines a `Workspace` class that loads the configuration files, sets up the points and the model, and runs the main loop (a simplified sketch of this loop is shown after this list).
71 | - `current_functions.py` contains a helper class that manages the best functions found during the optimization and adds them to the context.
72 | - `optimizer.py` contains the optimizer class used to optimize the coefficients of the functions.
73 | - `plotter.py` contains the class used to plot the results of the experiments and to provide the image inputs for the vision models.
74 | - The `scorers` directory contains the classes used to evaluate the discovered functions.
75 | - The `models` directory contains the classes used to interact with the language models.
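
Putting these pieces together, the loop below summarizes what `main.py` does on each iteration (the identifiers are illustrative pseudocode, not the exact API):

```python
# Simplified sketch of the ICSR loop; all identifiers here are illustrative.
for iteration in range(cfg.experiment.function.iterations):
    prompt = build_prompt(current_functions, train_points)  # context holds the best functions so far
    candidate = model.generate_function(prompt)             # the LLM proposes a refined functional form
    fitted = optimizer.optimize(candidate)                  # the external optimizer fits its coefficients
    score = scorer.score(fitted, train_points)              # e.g. R^2, optionally complexity-penalized
    current_functions.add(fitted, score)                    # update the prompt context for the next round
    if score >= cfg.experiment.function.tolerance:          # early stop once the train score is high enough
        break
```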
76 |
77 | ## Citation
78 |
79 | If you find this code useful, please consider citing our paper:
80 |
81 | ```bibtex
82 | @inproceedings{merler-etal-2024-context,
83 | title = "In-Context Symbolic Regression: Leveraging Large Language Models for Function Discovery",
84 | author = "Merler, Matteo and
85 | Haitsiukevich, Katsiaryna and
86 | Dainese, Nicola and
87 | Marttinen, Pekka",
88 | editor = "Fu, Xiyan and
89 | Fleisig, Eve",
90 | booktitle = "Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 4: Student Research Workshop)",
91 | month = aug,
92 | year = "2024",
93 | address = "Bangkok, Thailand",
94 | publisher = "Association for Computational Linguistics",
95 | url = "https://aclanthology.org/2024.acl-srw.49",
96 | doi = "10.18653/v1/2024.acl-srw.49",
97 | pages = "589--606"
98 | }
99 |
100 | ```
101 |
--------------------------------------------------------------------------------
/assets/animation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/assets/animation.gif
--------------------------------------------------------------------------------
/assets/sketch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/assets/sketch.png
--------------------------------------------------------------------------------
/conf/config.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - model: llama3-8b
3 | - experiment: standard
4 | - logger: default_logger
5 | - _self_
6 |
7 | # Experiment
8 | output_dir: runs # Output directory where all runs will be saved, with the following structure: output_dir/benchmark_name/experiment_name/model_name/run_id
9 | max_retries: 5 # Maximum number of retries if prompt failed to generate a valid function as the output
10 | force_valid: false # Force the output to be valid (i.e. a valid function). If false, when the model fails to generate a valid function, after max_retries, the output will be the best generated function so far
11 | force_unique: false # Force the output to be unique (i.e. different from all the functions in the prompt)
12 | prompts_path: prompts
13 | max_points_in_prompt: 40 # Maximum number of points in the prompt (if more are provided, they will automatically be downsampled)
14 | checkpoints: [50, 100, 200, 300, 400, 500, 600, 700, 800, 900] # Partial results will be saved at these iterations
15 |
16 | # Torch
17 | device: 'auto' # auto works for both CPU and GPU and can be used in a multi-GPU setup
18 | use_bfloat16: false
19 | seed: -1 # If -1, the seed will be randomly generated
20 |
21 | # Project root
22 | root: ??? # Path to the root of the project, where the 'conf', 'data' directories are located and where main.py is executed
23 |
24 | # Plotter
25 | plotter:
26 | save_video: true
27 | save_frames: false
28 | gif_duration: 1000
29 | plotter_resolution: 1000
30 | plotter_fig_size: 10
--------------------------------------------------------------------------------
/conf/experiment/function/R/R1.yaml:
--------------------------------------------------------------------------------
1 | name: "R1"
2 | group: "R"
3 |
4 | # Train points
5 | train_points:
6 | generate_points: false
7 | data_folder: "data/R/R1"
8 | # This part is only used if generate_points is true
9 | random_points: true
10 | min_points: -1
11 | max_points: 1
12 | num_points: 20
13 | xs_noise_std: 0
14 | ys_noise_std: 0
15 | add_extremes: true # Manually adds points at the extremes of the interval
16 |
17 | # Test points
18 | test_points:
19 | min_points: -1
20 | max_points: 1
21 | num_points: 20
22 |
23 | tolerance: 0.99999
24 |
25 | # Test function
26 | test_function: "(x+1)^3/(x^2-x+1)"
27 | num_variables: 1
28 |
29 | # Total iterations
30 | iterations: 100
--------------------------------------------------------------------------------
/conf/experiment/function/R/R2.yaml:
--------------------------------------------------------------------------------
1 | name: "R2"
2 | group: "R"
3 |
4 | # Train points
5 | train_points:
6 | generate_points: false
7 | data_folder: "data/R/R2"
8 | # This part is only used if generate_points is true
9 | random_points: true
10 | min_points: -1
11 | max_points: 1
12 | num_points: 20
13 | xs_noise_std: 0
14 | ys_noise_std: 0
15 | add_extremes: true # Manually adds points at the extremes of the interval
16 |
17 | # Test points
18 | test_points:
19 | min_points: -1
20 | max_points: 1
21 | num_points: 20
22 |
23 | tolerance: 0.99999
24 |
25 | # Test function
26 | test_function: "(x^5-3*x^3+1)/(x^2+1)"
27 | num_variables: 1
28 |
29 | # Total iterations
30 | iterations: 100
--------------------------------------------------------------------------------
/conf/experiment/function/R/R3.yaml:
--------------------------------------------------------------------------------
1 | name: "R3"
2 | group: "R"
3 |
4 | # Train points
5 | train_points:
6 | generate_points: false
7 | data_folder: "data/R/R3"
8 | # This part is only used if generate_points is true
9 | random_points: true
10 | min_points: -1
11 | max_points: 1
12 | num_points: 20
13 | xs_noise_std: 0
14 | ys_noise_std: 0
15 | add_extremes: true # Manually adds points at the extremes of the interval
16 |
17 | # Test points
18 | test_points:
19 | min_points: -1
20 | max_points: 1
21 | num_points: 20
22 |
23 | tolerance: 0.99999
24 |
25 | # Test function
26 | test_function: "(x^6+x^5)/(x^4+x^3+x^2+x+1)"
27 | num_variables: 1
28 |
29 | # Total iterations
30 | iterations: 100
--------------------------------------------------------------------------------
/conf/experiment/function/constant/constant1.yaml:
--------------------------------------------------------------------------------
1 | name: "constant1"
2 | group: "constant"
3 |
4 | # Train points
5 | train_points:
6 | generate_points: false
7 | data_folder: "data/constant/constant1"
8 | # This part is only used if generate_points is true
9 | random_points: true
10 | min_points: -1
11 | max_points: 1
12 | num_points: 20
13 | xs_noise_std: 0
14 | ys_noise_std: 0
15 | add_extremes: true # Manually adds points at the extremes of the interval
16 |
17 | # Test points
18 | test_points:
19 | min_points: -1
20 | max_points: 1
21 | num_points: 200
22 |
23 | tolerance: 0.99999
24 |
25 | # Test function
26 | test_function: "3.39*x^3 + 2.12*x^2 + 1.78*x"
27 | num_variables: 1
28 |
29 | # Total iterations
30 | iterations: 100
--------------------------------------------------------------------------------
/conf/experiment/function/constant/constant2.yaml:
--------------------------------------------------------------------------------
1 | name: "constant2"
2 | group: "constant"
3 |
4 | # Train points
5 | train_points:
6 | generate_points: false
7 | data_folder: "data/constant/constant2"
8 | # This part is only used if generate_points is true
9 | random_points: true
10 | min_points: -1
11 | max_points: 1
12 | num_points: 20
13 | xs_noise_std: 0
14 | ys_noise_std: 0
15 | add_extremes: true # Manually adds points at the extremes of the interval
16 |
17 | # Test points
18 | test_points:
19 | min_points: -1
20 | max_points: 1
21 | num_points: 200
22 |
23 | tolerance: 0.99999
24 |
25 | # Test function
26 | test_function: "sin(x^2) * cos(x) - 0.75"
27 | num_variables: 1
28 |
29 | # Total iterations
30 | iterations: 100
--------------------------------------------------------------------------------
/conf/experiment/function/constant/constant3.yaml:
--------------------------------------------------------------------------------
1 | name: "constant3"
2 | group: "constant"
3 |
4 | # Train points
5 | train_points:
6 | generate_points: false
7 | data_folder: "data/constant/constant3"
8 | # This part is only used if generate_points is true
9 | random_points: true
10 | min_points: [-1, -1]
11 | max_points: [1, 1]
12 | num_points: 100
13 | xs_noise_std: 0
14 | ys_noise_std: 0
15 | add_extremes: true # Manually adds points at the extremes of the interval
16 |
17 | # Test points
18 | test_points:
19 | min_points: [-1, -1]
20 | max_points: [1, 1]
21 | num_points: 500
22 |
23 | tolerance: 0.99999
24 |
25 | # Test function
26 | test_function: "sin(1.5*x1)*cos(0.5*x2)"
27 | num_variables: 2
28 |
29 | # Total iterations
30 | iterations: 100
--------------------------------------------------------------------------------
/conf/experiment/function/constant/constant4.yaml:
--------------------------------------------------------------------------------
1 | name: "constant4"
2 | group: "constant"
3 |
4 | # Train points
5 | train_points:
6 | generate_points: false
7 | data_folder: "data/constant/constant4"
8 | # This part is only used if generate_points is true
9 | random_points: true
10 | min_points: [0, 0]
11 | max_points: [1, 1]
12 | num_points: 100
13 | xs_noise_std: 0
14 | ys_noise_std: 0
15 | add_extremes: true # Manually adds points at the extremes of the interval
16 |
17 | # Test points
18 | test_points:
19 | min_points: [0, 0]
20 | max_points: [1, 1]
21 | num_points: 500
22 |
23 | tolerance: 0.99999
24 |
25 | # Test function
26 | test_function: "2.7*x1^x2"
27 | num_variables: 2
28 |
29 | # Total iterations
30 | iterations: 100
--------------------------------------------------------------------------------
/conf/experiment/function/constant/constant5.yaml:
--------------------------------------------------------------------------------
1 | name: "constant5"
2 | group: "constant"
3 |
4 | # Train points
5 | train_points:
6 | generate_points: false
7 | data_folder: "data/constant/constant5"
8 | # This part is only used if generate_points is true
9 | random_points: true
10 | min_points: 0
11 | max_points: 4
12 | num_points: 20
13 | xs_noise_std: 0
14 | ys_noise_std: 0
15 | add_extremes: true # Manually adds points at the extremes of the interval
16 |
17 | # Test points
18 | test_points:
19 | min_points: 0
20 | max_points: 4
21 | num_points: 200
22 |
23 | tolerance: 0.99999
24 |
25 | # Test function
26 | test_function: "sqrt(1.23*x)"
27 | num_variables: 1
28 |
29 | # Total iterations
30 | iterations: 100
--------------------------------------------------------------------------------
/conf/experiment/function/constant/constant6.yaml:
--------------------------------------------------------------------------------
1 | name: "constant6"
2 | group: "constant"
3 |
4 | # Train points
5 | train_points:
6 | generate_points: false
7 | data_folder: "data/constant/constant6"
8 | # This part is only used if generate_points is true
9 | random_points: true
10 | min_points: 0
11 | max_points: 4
12 | num_points: 20
13 | xs_noise_std: 0
14 | ys_noise_std: 0
15 | add_extremes: true # Manually adds points at the extremes of the interval
16 |
17 | # Test points
18 | test_points:
19 | min_points: 0
20 | max_points: 4
21 | num_points: 200
22 |
23 | tolerance: 0.99999
24 |
25 | # Test function
26 | test_function: "x^(0.426)"
27 | num_variables: 1
28 |
29 | # Total iterations
30 | iterations: 100
--------------------------------------------------------------------------------
/conf/experiment/function/constant/constant7.yaml:
--------------------------------------------------------------------------------
1 | name: "constant7"
2 | group: "constant"
3 |
4 | # Train points
5 | train_points:
6 | generate_points: false
7 | data_folder: "data/constant/constant7"
8 | # This part is only used if generate_points is true
9 | random_points: true
10 | min_points: [-1, -1]
11 | max_points: [1, 1]
12 | num_points: 100
13 | xs_noise_std: 0
14 | ys_noise_std: 0
15 | add_extremes: true # Manually adds points at the extremes of the interval
16 |
17 | # Test points
18 | test_points:
19 | min_points: [-1, -1]
20 | max_points: [1, 1]
21 | num_points: 500
22 |
23 | tolerance: 0.99999
24 |
25 | # Test function
26 | test_function: "2*sin(1.3*x1) + cos(x2)"
27 | num_variables: 2
28 |
29 | # Total iterations
30 | iterations: 100
--------------------------------------------------------------------------------
/conf/experiment/function/constant/constant8.yaml:
--------------------------------------------------------------------------------
1 | name: "constant8"
2 | group: "constant"
3 |
4 | # Train points
5 | train_points:
6 | generate_points: false
7 | data_folder: "data/constant/constant8"
8 | # This part is only used if generate_points is true
9 | random_points: true
10 | min_points: 0
11 | max_points: 2
12 | num_points: 20
13 | xs_noise_std: 0
14 | ys_noise_std: 0
15 | add_extremes: true # Manually adds points at the extremes of the interval
16 |
17 | # Test points
18 | test_points:
19 | min_points: 0
20 | max_points: 2
21 | num_points: 200
22 |
23 | tolerance: 0.99999
24 |
25 | # Test function
26 | test_function: "ln(x+1.4) + ln(x^2+1.3)"
27 | num_variables: 1
28 |
29 | # Total iterations
30 | iterations: 100
--------------------------------------------------------------------------------
/conf/experiment/function/keijzer/keijzer10.yaml:
--------------------------------------------------------------------------------
1 | name: "keijzer10"
2 | group: "keijzer"
3 |
4 | # Train points
5 | train_points:
6 | generate_points: false
7 | data_folder: "data/keijzer/keijzer10"
8 | # This part is only used if generate_points is true
9 | random_points: true
10 | min_points: 0
11 | max_points: 1
12 | num_points: 100
13 | xs_noise_std: 0
14 | ys_noise_std: 0
15 | add_extremes: true # Manually adds points at the extremes of the interval
16 |
17 | # Test points
18 | test_points:
19 | min_points: 0
20 | max_points: 1
21 | num_points: 1000
22 |
23 | tolerance: 0.99999
24 |
25 | # Test function
26 | test_function: "x1^x2"
27 | num_variables: 2
28 |
29 | # Total iterations
30 | iterations: 5
--------------------------------------------------------------------------------
/conf/experiment/function/keijzer/keijzer11.yaml:
--------------------------------------------------------------------------------
1 | name: "keijzer11"
2 | group: "keijzer"
3 |
4 | # Train points
5 | train_points:
6 | generate_points: false
7 | data_folder: "data/keijzer/keijzer11"
8 | # This part is only used if generate_points is true
9 | random_points: true
10 | min_points: -3
11 | max_points: 3
12 | num_points: 20
13 | xs_noise_std: 0
14 | ys_noise_std: 0
15 | add_extremes: true # Manually adds points at the extremes of the interval
16 |
17 | # Test points
18 | test_points:
19 | min_points: -3
20 | max_points: 3
21 | num_points: 1000
22 |
23 | tolerance: 0.99999
24 |
25 | # Test function
26 | test_function: "x1*x2 + sin((x1 - 1) * (x2 - 1))"
27 | num_variables: 2
28 |
29 | # Total iterations
30 | iterations: 5
--------------------------------------------------------------------------------
/conf/experiment/function/keijzer/keijzer12.yaml:
--------------------------------------------------------------------------------
1 | name: "keijzer12"
2 | group: "keijzer"
3 |
4 | # Train points
5 | train_points:
6 | generate_points: false
7 | data_folder: "data/keijzer/keijzer12"
8 | # This part is only used if generate_points is true
9 | random_points: true
10 | min_points: -3
11 | max_points: 3
12 | num_points: 20
13 | xs_noise_std: 0
14 | ys_noise_std: 0
15 | add_extremes: true # Manually adds points at the extremes of the interval
16 |
17 | # Test points
18 | test_points:
19 | min_points: -3
20 | max_points: 3
21 | num_points: 1000
22 |
23 | tolerance: 0.99999
24 |
25 | # Test function
26 | test_function: "x1^4 - x1^3 + (x2^2)/2 - x2"
27 | num_variables: 2
28 |
29 | # Total iterations
30 | iterations: 5
--------------------------------------------------------------------------------
/conf/experiment/function/keijzer/keijzer13.yaml:
--------------------------------------------------------------------------------
1 | name: "keijzer13"
2 | group: "keijzer"
3 |
4 | # Train points
5 | train_points:
6 | generate_points: false
7 | data_folder: "data/keijzer/keijzer13"
8 | # This part is only used if generate_points is true
9 | random_points: true
10 | min_points: -3
11 | max_points: 3
12 | num_points: 20
13 | xs_noise_std: 0
14 | ys_noise_std: 0
15 | add_extremes: true # Manually adds points at the extremes of the interval
16 |
17 | # Test points
18 | test_points:
19 | min_points: -3
20 | max_points: 3
21 | num_points: 1000
22 |
23 | tolerance: 0.99999
24 |
25 | # Test function
26 | test_function: "6*sin(x1)*cos(x2)"
27 | num_variables: 2
28 |
29 | # Total iterations
30 | iterations: 5
--------------------------------------------------------------------------------
/conf/experiment/function/keijzer/keijzer14.yaml:
--------------------------------------------------------------------------------
1 | name: "keijzer14"
2 | group: "keijzer"
3 |
4 | # Train points
5 | train_points:
6 | generate_points: false
7 | data_folder: "data/keijzer/keijzer14"
8 | # This part is only used if generate_points is true
9 | random_points: true
10 | min_points: -3
11 | max_points: 3
12 | num_points: 20
13 | xs_noise_std: 0
14 | ys_noise_std: 0
15 | add_extremes: true # Manually adds points at the extremes of the interval
16 |
17 | # Test points
18 | test_points:
19 | min_points: -3
20 | max_points: 3
21 | num_points: 1000
22 |
23 | tolerance: 0.99999
24 |
25 | # Test function
26 | test_function: "8/(2+x1^2+x2^2)"
27 | num_variables: 2
28 |
29 | # Total iterations
30 | iterations: 5
--------------------------------------------------------------------------------
/conf/experiment/function/keijzer/keijzer15.yaml:
--------------------------------------------------------------------------------
1 | name: "keijzer15"
2 | group: "keijzer"
3 |
4 | # Train points
5 | train_points:
6 | generate_points: false
7 | data_folder: "data/keijzer/keijzer15"
8 | # This part is only used if generate_points is true
9 | random_points: true
10 | min_points: -3
11 | max_points: 3
12 | num_points: 20
13 | xs_noise_std: 0
14 | ys_noise_std: 0
15 | add_extremes: true # Manually adds points at the extremes of the interval
16 |
17 | # Test points
18 | test_points:
19 | min_points: -3
20 | max_points: 3
21 | num_points: 1000
22 |
23 | tolerance: 0.99999
24 |
25 | # Test function
26 | test_function: "x1^3/5 + x2^3/2 - x2 - x1"
27 | num_variables: 2
28 |
29 | # Total iterations
30 | iterations: 5
--------------------------------------------------------------------------------
/conf/experiment/function/keijzer/keijzer3.yaml:
--------------------------------------------------------------------------------
1 | name: "keijzer3"
2 | group: "keijzer"
3 |
4 | # Train points
5 | train_points:
6 | generate_points: false
7 | data_folder: "data/keijzer/keijzer3"
8 | # This part is only used if generate_points is true
9 | random_points: false
10 | min_points: -1
11 | max_points: 1
12 | num_points: 100
13 | xs_noise_std: 0
14 | ys_noise_std: 0
15 | add_extremes: true # Manually adds points at the extremes of the interval
16 |
17 | # Test points
18 | test_points:
19 | min_points: -1
20 | max_points: 1
21 | num_points: 10000
22 |
23 | tolerance: 0.99999
24 |
25 | # Test function
26 | test_function: "0.3*x * sin(2*pi*x)"
27 | num_variables: 1
28 |
29 | # Total iterations
30 | iterations: 100
--------------------------------------------------------------------------------
/conf/experiment/function/keijzer/keijzer4.yaml:
--------------------------------------------------------------------------------
1 | name: "keijzer4"
2 | group: "keijzer"
3 |
4 | # Train points
5 | train_points:
6 | generate_points: false
7 | data_folder: "data/keijzer/keijzer4"
8 | # This part is only used if generate_points is true
9 | random_points: false
10 | min_points: 0
11 | max_points: 10
12 | num_points: 200
13 | xs_noise_std: 0
14 | ys_noise_std: 0
15 | add_extremes: true # Manually adds points at the extremes of the interval
16 |
17 | # Test points
18 | test_points:
19 | min_points: 0.05
20 | max_points: 10.05
21 | num_points: 200
22 |
23 | tolerance: 0.99999
24 |
25 | # Test function
26 | test_function: "x^3*exp(-x)*cos(x)*sin(x)*(sin(x)^2*cos(x)-1)"
27 | num_variables: 1
28 |
29 | # Total iterations
30 | iterations: 5
--------------------------------------------------------------------------------
/conf/experiment/function/keijzer/keijzer6.yaml:
--------------------------------------------------------------------------------
1 | name: "keijzer6"
2 | group: "keijzer"
3 |
4 | # Train points
5 | train_points:
6 | generate_points: false
7 | data_folder: "data/keijzer/keijzer6"
8 | # This part is only used if generate_points is true
9 | random_points: false
10 | min_points: -1
11 | max_points: 1
12 | num_points: 50
13 | xs_noise_std: 0
14 | ys_noise_std: 0
15 | add_extremes: true # Manually adds points at the extremes of the interval
16 |
17 | # Test points
18 | test_points:
19 | min_points: -1
20 | max_points: 1
21 | num_points: 100
22 |
23 | tolerance: 0.99999
24 |
25 | # Test function
26 | test_function: "(x*(x+1))/2"
27 | num_variables: 1
28 |
29 | # Total iterations
30 | iterations: 5
--------------------------------------------------------------------------------
/conf/experiment/function/keijzer/keijzer7.yaml:
--------------------------------------------------------------------------------
1 | name: "keijzer7"
2 | group: "keijzer"
3 |
4 | # Train points
5 | train_points:
6 | generate_points: false
7 | data_folder: "data/keijzer/keijzer7"
8 | # This part is only used if generate_points is true
9 | random_points: false
10 | min_points: 1
11 | max_points: 100
12 | num_points: 100
13 | xs_noise_std: 0
14 | ys_noise_std: 0
15 | add_extremes: true # Manually adds points at the extremes of the interval
16 |
17 | # Test points
18 | test_points:
19 | min_points: 1
20 | max_points: 100
21 | num_points: 1000
22 |
23 | tolerance: 0.99999
24 |
25 | # Test function
26 | test_function: "ln(x)"
27 | num_variables: 1
28 |
29 | # Total iterations
30 | iterations: 5
--------------------------------------------------------------------------------
/conf/experiment/function/keijzer/keijzer8.yaml:
--------------------------------------------------------------------------------
1 | name: "keijzer8"
2 | group: "keijzer"
3 |
4 | # Train points
5 | train_points:
6 | generate_points: false
7 | data_folder: "data/keijzer/keijzer8"
8 | # This part is only used if generate_points is true
9 | random_points: false
10 | min_points: 0
11 | max_points: 100
12 | num_points: 100
13 | xs_noise_std: 0
14 | ys_noise_std: 0
15 | add_extremes: true # Manually adds points at the extremes of the interval
16 |
17 | # Test points
18 | test_points:
19 | min_points: 0
20 | max_points: 100
21 | num_points: 1000
22 |
23 | tolerance: 0.99999
24 |
25 | # Test function
26 | test_function: "sqrt(x)"
27 | num_variables: 1
28 |
29 | # Total iterations
30 | iterations: 5
--------------------------------------------------------------------------------
/conf/experiment/function/keijzer/keijzer9.yaml:
--------------------------------------------------------------------------------
1 | name: "keijzer9"
2 | group: "keijzer"
3 |
4 | # Train points
5 | train_points:
6 | generate_points: false
7 | data_folder: "data/keijzer/keijzer9"
8 | # This part is only used if generate_points is true
9 | random_points: false
10 | min_points: 0
11 | max_points: 100
12 | num_points: 100
13 | xs_noise_std: 0
14 | ys_noise_std: 0
15 | add_extremes: true # Manually adds points at the extremes of the interval
16 |
17 | # Test points
18 | test_points:
19 | min_points: 0
20 | max_points: 100
21 | num_points: 1000
22 |
23 | tolerance: 0.99999
24 |
25 | # Test function
26 | test_function: "ln(x + sqrt(x^2 + 1))"
27 | num_variables: 1
28 |
29 | # Total iterations
30 | iterations: 5
--------------------------------------------------------------------------------
/conf/experiment/function/nguyen/nguyen1.yaml:
--------------------------------------------------------------------------------
1 | name: "nguyen1"
2 | group: "nguyen"
3 |
4 | # Train points
5 | train_points:
6 | generate_points: false
7 | data_folder: "data/nguyen/nguyen1"
8 | # This part is only used if generate_points is true
9 | random_points: true
10 | min_points: -1
11 | max_points: 1
12 | num_points: 20
13 | xs_noise_std: 0
14 | ys_noise_std: 0
15 | add_extremes: true # Manually adds points at the extremes of the interval
16 |
17 | # Test points
18 | test_points:
19 | min_points: -1
20 | max_points: 1
21 | num_points: 200
22 |
23 | tolerance: 0.99999
24 |
25 | # Test function
26 | test_function: "x^3 + x^2 + x"
27 | num_variables: 1
28 |
29 | # Total iterations
30 | iterations: 10
--------------------------------------------------------------------------------
/conf/experiment/function/nguyen/nguyen10.yaml:
--------------------------------------------------------------------------------
1 | name: "nguyen10"
2 | group: "nguyen"
3 |
4 | # Train points
5 | train_points:
6 | generate_points: false
7 | data_folder: "data/nguyen/nguyen10"
8 | # This part is only used if generate_points is true
9 | random_points: true
10 | min_points: [-1, -1]
11 | max_points: [1, 1]
12 | num_points: 100
13 | xs_noise_std: 0
14 | ys_noise_std: 0
15 | add_extremes: true # Manually adds points at the extremes of the interval
16 |
17 | # Test points
18 | test_points:
19 | min_points: [-1, -1]
20 | max_points: [1, 1]
21 | num_points: 500
22 |
23 | tolerance: 0.99999
24 |
25 | # Test function
26 | test_function: "2*sin(x1)*cos(x2)"
27 | num_variables: 2
28 |
29 | # Total iterations
30 | iterations: 10
--------------------------------------------------------------------------------
/conf/experiment/function/nguyen/nguyen11.yaml:
--------------------------------------------------------------------------------
1 | name: "nguyen11"
2 | group: "nguyen"
3 |
4 | # Train points
5 | train_points:
6 | generate_points: false
7 | data_folder: "data/nguyen/nguyen11"
8 | # This part is only used if generate_points is true
9 | random_points: true
10 | min_points: [0, 0]
11 | max_points: [1, 1]
12 | num_points: 100
13 | xs_noise_std: 0
14 | ys_noise_std: 0
15 | add_extremes: true # Manually adds points at the extremes of the interval
16 |
17 | # Test points
18 | test_points:
19 | min_points: [0, 0]
20 | max_points: [1, 1]
21 | num_points: 500
22 |
23 | tolerance: 0.99999
24 |
25 | # Test function
26 | test_function: "x1^x2"
27 | num_variables: 2
28 |
29 | # Total iterations
30 | iterations: 10
--------------------------------------------------------------------------------
/conf/experiment/function/nguyen/nguyen12.yaml:
--------------------------------------------------------------------------------
1 | name: "nguyen12"
2 | group: "nguyen"
3 |
4 | # Train points
5 | train_points:
6 | generate_points: false
7 | data_folder: "data/nguyen/nguyen12"
8 | # This part is only used if generate_points is true
9 | random_points: true
10 | min_points: [-1, -1]
11 | max_points: [1, 1]
12 | num_points: 100
13 | xs_noise_std: 0
14 | ys_noise_std: 0
15 | add_extremes: true # Manually adds points at the extremes of the interval
16 |
17 | # Test points
18 | test_points:
19 | min_points: [-1, -1]
20 | max_points: [1, 1]
21 | num_points: 500
22 |
23 | tolerance: 0.99999
24 |
25 | # Test function
26 | test_function: "x1^4 - x1^3 + (1/2)*x2^2 - x2"
27 | num_variables: 2
28 |
29 | # Total iterations
30 | iterations: 10
--------------------------------------------------------------------------------
/conf/experiment/function/nguyen/nguyen2.yaml:
--------------------------------------------------------------------------------
1 | name: "nguyen2"
2 | group: "nguyen"
3 |
4 | # Train points
5 | train_points:
6 | generate_points: false
7 | data_folder: "data/nguyen/nguyen2"
8 | # This part is only used if generate_points is true
9 | random_points: true
10 | min_points: -1
11 | max_points: 1
12 | num_points: 20
13 | xs_noise_std: 0
14 | ys_noise_std: 0
15 | add_extremes: true # Manually adds points at the extremes of the interval
16 |
17 | # Test points
18 | test_points:
19 | min_points: -1
20 | max_points: 1
21 | num_points: 200
22 |
23 | tolerance: 0.99999
24 |
25 | # Test function
26 | test_function: "x^4 + x^3 + x^2 + x"
27 | num_variables: 1
28 |
29 | # Total iterations
30 | iterations: 10
--------------------------------------------------------------------------------
/conf/experiment/function/nguyen/nguyen3.yaml:
--------------------------------------------------------------------------------
1 | name: "nguyen3"
2 | group: "nguyen"
3 |
4 | # Train points
5 | train_points:
6 | generate_points: false
7 | data_folder: "data/nguyen/nguyen3"
8 | # This part is only used if generate_points is true
9 | random_points: true
10 | min_points: -1
11 | max_points: 1
12 | num_points: 20
13 | xs_noise_std: 0
14 | ys_noise_std: 0
15 | add_extremes: true # Manually adds points at the extremes of the interval
16 |
17 | # Test points
18 | test_points:
19 | min_points: -1
20 | max_points: 1
21 | num_points: 200
22 |
23 | tolerance: 0.99999
24 |
25 | # Test function
26 | test_function: "x^5 + x^4 + x^3 + x^2 + x"
27 | num_variables: 1
28 |
29 | # Total iterations
30 | iterations: 10
--------------------------------------------------------------------------------
/conf/experiment/function/nguyen/nguyen4.yaml:
--------------------------------------------------------------------------------
1 | name: "nguyen4"
2 | group: "nguyen"
3 |
4 | # Train points
5 | train_points:
6 | generate_points: false
7 | data_folder: "data/nguyen/nguyen4"
8 | # This part is only used if generate_points is true
9 | random_points: true
10 | min_points: -1
11 | max_points: 1
12 | num_points: 20
13 | xs_noise_std: 0
14 | ys_noise_std: 0
15 | add_extremes: true # Manually adds points at the extremes of the interval
16 |
17 | # Test points
18 | test_points:
19 | min_points: -1
20 | max_points: 1
21 | num_points: 200
22 |
23 | tolerance: 0.99999
24 |
25 | # Test function
26 | test_function: "x^6 + x^5 + x^4 + x^3 + x^2 + x"
27 | num_variables: 1
28 |
29 | # Total iterations
30 | iterations: 10
--------------------------------------------------------------------------------
/conf/experiment/function/nguyen/nguyen5.yaml:
--------------------------------------------------------------------------------
1 | name: "nguyen5"
2 | group: "nguyen"
3 |
4 | # Train points
5 | train_points:
6 | generate_points: false
7 | data_folder: "data/nguyen/nguyen5"
8 | # This part is only used if generate_points is true
9 | random_points: true
10 | min_points: -1
11 | max_points: 1
12 | num_points: 20
13 | xs_noise_std: 0
14 | ys_noise_std: 0
15 | add_extremes: true # Manually adds points at the extremes of the interval
16 |
17 | # Test points
18 | test_points:
19 | min_points: -1
20 | max_points: 1
21 | num_points: 200
22 |
23 | tolerance: 0.99999
24 |
25 | # Test function
26 | test_function: "sin(x^2) * cos(x) - 1"
27 | num_variables: 1
28 |
29 | # Total iterations
30 | iterations: 10
--------------------------------------------------------------------------------
/conf/experiment/function/nguyen/nguyen6.yaml:
--------------------------------------------------------------------------------
1 | name: "nguyen6"
2 | group: "nguyen"
3 |
4 | # Train points
5 | train_points:
6 | generate_points: false
7 | data_folder: "data/nguyen/nguyen6"
8 | # This part is only used if generate_points is true
9 | random_points: true
10 | min_points: -1
11 | max_points: 1
12 | num_points: 20
13 | xs_noise_std: 0
14 | ys_noise_std: 0
15 | add_extremes: true # Manually adds points at the extremes of the interval
16 |
17 | # Test points
18 | test_points:
19 | min_points: -1
20 | max_points: 1
21 | num_points: 200
22 |
23 | tolerance: 0.99999
24 |
25 | # Test function
26 | test_function: "sin(x) + sin(x + x^2)"
27 | num_variables: 1
28 |
29 | # Total iterations
30 | iterations: 10
--------------------------------------------------------------------------------
/conf/experiment/function/nguyen/nguyen7.yaml:
--------------------------------------------------------------------------------
1 | name: "nguyen7"
2 | group: "nguyen"
3 |
4 | # Train points
5 | train_points:
6 | generate_points: false
7 | data_folder: "data/nguyen/nguyen7"
8 | # This part is only used if generate_points is true
9 | random_points: true
10 | min_points: 0
11 | max_points: 2
12 | num_points: 20
13 | xs_noise_std: 0
14 | ys_noise_std: 0
15 | add_extremes: true # Manually adds points at the extremes of the interval
16 |
17 | # Test points
18 | test_points:
19 | min_points: 0
20 | max_points: 2
21 | num_points: 200
22 |
23 | tolerance: 0.99999
24 |
25 | # Test function
26 | test_function: "log(x + 1) + log(x^2 + 1)"
27 | num_variables: 1
28 |
29 | # Total iterations
30 | iterations: 10
--------------------------------------------------------------------------------
/conf/experiment/function/nguyen/nguyen8.yaml:
--------------------------------------------------------------------------------
1 | name: "nguyen8"
2 | group: "nguyen"
3 |
4 | # Train points
5 | train_points:
6 | generate_points: false
7 | data_folder: "data/nguyen/nguyen8"
8 | # This part is only used if generate_points is true
9 | random_points: true
10 | min_points: 0
11 | max_points: 4
12 | num_points: 20
13 | xs_noise_std: 0
14 | ys_noise_std: 0
15 | add_extremes: true # Manually adds points at the extremes of the interval
16 |
17 | # Test points
18 | test_points:
19 | min_points: 0
20 | max_points: 4
21 | num_points: 200
22 |
23 | tolerance: 0.99999
24 |
25 | # Test function
26 | test_function: "sqrt(x)"
27 | num_variables: 1
28 |
29 | # Total iterations
30 | iterations: 10
--------------------------------------------------------------------------------
/conf/experiment/function/nguyen/nguyen9.yaml:
--------------------------------------------------------------------------------
1 | name: "nguyen9"
2 | group: "nguyen"
3 |
4 | # Train points
5 | train_points:
6 | generate_points: false
7 | data_folder: "data/nguyen/nguyen9"
8 | # This part is only used if generate_points is true
9 | random_points: true
10 | min_points: [-1, -1]
11 | max_points: [1, 1]
12 | num_points: 100
13 | xs_noise_std: 0
14 | ys_noise_std: 0
15 | add_extremes: true # Manually adds points at the extremes of the interval
16 |
17 | # Test points
18 | test_points:
19 | min_points: [-1, -1]
20 | max_points: [1, 1]
21 | num_points: 500
22 |
23 | tolerance: 0.99999
24 |
25 | # Test function
26 | test_function: "sin(x1) + sin(x2^2)"
27 | num_variables: 2
28 |
29 | # Total iterations
30 | iterations: 10
--------------------------------------------------------------------------------
/conf/experiment/scorer/basic_scorer.yaml:
--------------------------------------------------------------------------------
1 | name: basic_scorer
2 | rounding: 8
3 | scientific: false
4 | normalize: false
5 | lambda: 0
--------------------------------------------------------------------------------
/conf/experiment/scorer/complexity_scorer.yaml:
--------------------------------------------------------------------------------
1 | name: complexity_scorer
2 | rounding: 8
3 | scientific: false
4 | normalize: false
5 | lambda: 0.05
6 | max_nodes: 30
7 | alternative: false # Use alternative scorer with custom function
--------------------------------------------------------------------------------
/conf/experiment/scorer/minmax_scorer.yaml:
--------------------------------------------------------------------------------
1 | name: minmax_scorer
2 | rounding: 5
3 | scientific: false
4 | normalize: true
5 |
6 | # Normalization
7 | min_score: 1
8 | max_score: 10
--------------------------------------------------------------------------------
/conf/experiment/seed_functions/2D_linear.yaml:
--------------------------------------------------------------------------------
1 | # Initial functions
2 | functions: ["2x1 + - 3x2 + 1", "-6x1 + 4x2 - 4", "3x1 + x2 + 2", "-10x1 - 14x2 + 21", "x1 + x2", "-x1 + x2", "-3 - 4x1", "x2 - 6"]
--------------------------------------------------------------------------------
/conf/experiment/seed_functions/2D_mixed.yaml:
--------------------------------------------------------------------------------
1 | # Initial functions
2 | functions: ["2x1 + 3x2 + 1", "-6x1 + x2 - 4", "3x1^2 - 7x2^2 + 2x1 + x2 + 1", "-2x1^2 + 3x1 - x2 - 6", "4x1^3 + x2^3 + 3x1^2 - 5x2^2 + 2x1 + 1", "sin(x1) - cos(x2) + 2x1 + 1", "4cos(x1) + 2x1 + x2^2 + 1", "-4x2^2 + 2x1 - 3", "2x1^2 + 5x2^2-7x1+5x2-6"]
--------------------------------------------------------------------------------
/conf/experiment/seed_functions/coefficients.yaml:
--------------------------------------------------------------------------------
1 | # Initial functions
2 | functions: ["c0*x1^3 + c1*x1", "c0*x1 + c1*log(x1)", "c0*sin(x1) + c1*cos(x1)", "c0*exp(x1)*x1", "c0*x1/x1^2 + c1", "c0*x1^4 + c1*x1^3"]
--------------------------------------------------------------------------------
/conf/experiment/seed_functions/generate.yaml:
--------------------------------------------------------------------------------
1 | functions: []
2 | max_tries: 10
3 | generation_tokens: 512
--------------------------------------------------------------------------------
/conf/experiment/seed_functions/linear.yaml:
--------------------------------------------------------------------------------
1 | # Initial functions
2 | functions: ["2x + 1", "-6x - 4", "3x + 2", "-10x + 21", "x", "-x", "-3 - 4x", "13x - 6"]
--------------------------------------------------------------------------------
/conf/experiment/seed_functions/mixed.yaml:
--------------------------------------------------------------------------------
1 | # Initial functions
2 | functions: ["2x + 1", "-6x - 4", "3x^2 + 2x + 1", "-2x^2 + 3x - 6", "4x^3 + 3x^2 + 2x + 1", "sin(x) + 2x + 1", "4cos(x) + 2x + 1", "-4x^3 + 2x - 3"]
--------------------------------------------------------------------------------
/conf/experiment/standard.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - function: nguyen/nguyen1
3 | - seed_functions: generate
4 | - scorer: complexity_scorer
5 | - _self_
6 |
7 | # Seed functions
8 | generate_seed_functions: true
9 |
10 | # Optimizer
11 | optimizer:
12 | optimizer_threads: 5 # Number of threads to use for the optimizer (each with a different initial value for the coefficients)
13 | timeout: 10 # Timeout in seconds, after which the optimizer will stop
14 | p0_min: -5 # Lower bound for the initial value of the coefficients
15 | p0_max: 5 # Upper bound for the initial value of the coefficients
16 | coeff_rounding: 4 # Number of decimal places to round the coefficients to
17 | tol: 1e-3 # Tolerance for the optimizer, under which the coefficients are considered to be zero
--------------------------------------------------------------------------------
/conf/logger/default_logger.yaml:
--------------------------------------------------------------------------------
1 | loggers: ["console", "file"]
2 | level: "INFO"
3 | run_id: ???
--------------------------------------------------------------------------------
/conf/model/base_prompt/basic_image.yaml:
--------------------------------------------------------------------------------
1 | prompt: basic_image.txt
2 |
3 | # Type of input image (points, best_guess, all_guesses)
4 | input_image: points
5 |
6 | # Number of functions to keep in the prompt
7 | prompt_size: 5
--------------------------------------------------------------------------------
/conf/model/base_prompt/basic_mixed.yaml:
--------------------------------------------------------------------------------
1 | prompt: basic_mixed.txt
2 |
3 | # Type of input image (points, best_guess, all_guesses)
4 | input_image: points
5 |
6 | # Number of functions to keep in the prompt
7 | prompt_size: 5
--------------------------------------------------------------------------------
/conf/model/base_prompt/basic_text.yaml:
--------------------------------------------------------------------------------
1 | prompt: basic_text.txt
2 |
3 | # Number of functions to keep in the prompt
4 | prompt_size: 5
--------------------------------------------------------------------------------
/conf/model/base_prompt/image_all.yaml:
--------------------------------------------------------------------------------
1 | prompt: image_all.txt
2 |
3 | # Type of input image (points, best_guess, all_guesses)
4 | input_image: all_guesses
5 |
6 | # Number of functions to keep in the prompt
7 | prompt_size: 5
--------------------------------------------------------------------------------
/conf/model/base_prompt/image_best.yaml:
--------------------------------------------------------------------------------
1 | prompt: image_best.txt
2 |
3 | # Type of input image (points, best_guess, all_guesses)
4 | input_image: best_guess
5 |
6 | # Number of functions to keep in the prompt
7 | prompt_size: 5
--------------------------------------------------------------------------------
/conf/model/base_prompt/no_info.yaml:
--------------------------------------------------------------------------------
1 | prompt: "no_info.txt"
2 |
3 | # Number of functions to keep in the prompt
4 | prompt_size: 0
--------------------------------------------------------------------------------
/conf/model/gpt3.5-turbo.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - base_prompt: basic_text
3 | - _self_
4 |
5 | # General
6 | name: gpt-3.5-turbo
7 | tokenizer_pad: '[PAD]'
8 | tokenizer_padding_side: left
9 | visual: false
10 | cache_dir: ''
11 |
12 | # Seed functions prompt
13 | seed_function_prompt: seed_functions/generate_seed.txt
14 |
15 | # Sampling
16 | max_new_tokens: 2048
17 | top_p: 0.90
18 | top_k: 60
19 | num_beams: 1
20 |
21 | # Sampling temperature
22 | temperature: 1.0
23 | temperature_schedule: false
24 | temperature_schedule_gamma: 0.995
25 |
26 | # OpenAI API
27 | api_key_path: openai/openai_key
28 | organization_id_path: openai/openai_org
--------------------------------------------------------------------------------
/conf/model/gpt4o.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - base_prompt: basic_text
3 | - _self_
4 |
5 | # General
6 | name: gpt-4o
7 | tokenizer_pad: '[PAD]'
8 | tokenizer_padding_side: left
9 | visual: false
10 | cache_dir: ''
11 |
12 | # Seed functions prompt
13 | seed_function_prompt: seed_functions/generate_seed.txt
14 |
15 | # Sampling
16 | max_new_tokens: 2048
17 | top_p: 0.90
18 | top_k: 60
19 | num_beams: 1
20 |
21 | # Sampling temperature
22 | temperature: 1.0
23 | temperature_schedule: false
24 | temperature_schedule_gamma: 0.995
25 |
26 | # OpenAI API
27 | api_key_path: openai/openai_key
28 | organization_id_path: openai/openai_org
--------------------------------------------------------------------------------
/conf/model/llama3-8b.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - base_prompt: basic_text
3 | - _self_
4 |
5 | # General
6 | name: meta-llama/Meta-Llama-3-8B-Instruct
7 | tokenizer_pad: '[PAD]'
8 | tokenizer_padding_side: left
9 | visual: false
10 | cache_dir: ''
11 |
12 | # Seed functions prompt
13 | seed_function_prompt: seed_functions/generate_seed.txt
14 |
15 | # Sampling
16 | max_new_tokens: 512
17 | top_p: 0.90
18 | top_k: 60
19 | num_beams: 1
20 |
21 | # Sampling temperature
22 | temperature: 1.0
23 | temperature_schedule: false
24 | temperature_schedule_gamma: 0.995
--------------------------------------------------------------------------------
/conf/model/llava1.6-34b.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - base_prompt: image_best
3 | - _self_
4 |
5 | # General
6 | name: llava-hf/llava-v1.6-34b-hf
7 | tokenizer_pad: '[PAD]'
8 | tokenizer_padding_side: left
9 | visual: true
10 | cache_dir: ''
11 |
12 | # Seed functions prompt
13 | seed_function_prompt: seed_functions/generate_seed_image.txt
14 |
15 | # Sampling
16 | max_new_tokens: 512
17 | top_p: 0.90
18 | top_k: 60
19 | num_beams: 1
20 |
21 | # Sampling temperature
22 | temperature: 1.0
23 | temperature_schedule: false
24 | temperature_schedule_gamma: 0.995
--------------------------------------------------------------------------------
/current_functions.py:
--------------------------------------------------------------------------------
1 | import re
2 | import sys
3 | import sympy
4 | import utils
5 | import numpy as np
6 |
7 | from typing import Any, Dict, List, Tuple
8 |
9 | class CurrentFunctions(object):
10 | """
11 | Helper class to manage the current functions in the prompt.
12 | """
13 |
14 | def __init__(self, seed_functions, scorer, optimizer, context_len, logger, num_variables) -> None:
15 | """
16 | Initialize the class.
17 |
18 | Parameters
19 | ----------
20 |         seed_functions -> the seed functions to use.
21 |         scorer -> the scorer to use. optimizer -> the coefficient optimizer.
22 |         context_len -> the maximum number of functions kept in the prompt.
23 |         logger -> the logger to use.
24 |         num_variables -> the number of variables.
25 | """
26 |
27 | self.seed_functions = seed_functions
28 | self.scorer = scorer
29 | self.optimizer = optimizer
30 | self.context_len = context_len
31 | self.logger = logger
32 | self.num_variables = num_variables
33 | functions = [utils.string_to_function(name, self.num_variables) for name in self.seed_functions.keys()]
34 | self.logger.info(f"Seed functions: {functions}.")
35 | self.functions = {}
36 | self.scores = {}
37 | self.norm_scores = {}
38 | self.screen_names = {}
39 |
40 | # Optimize seed function coefficients
41 | for function in functions:
42 | try:
43 | optimized_function, coeff_function = self.optimizer.optimize(function, return_coeff=True, quiet=False)
44 | if self.func_in_list(coeff_function):
45 | self.logger.warning(f"Function {coeff_function} already in prompt.")
46 | continue
47 | self.functions[coeff_function] = optimized_function
48 | self.logger.info(f"Optimized seed function: {str(coeff_function)}.")
49 | except Exception as e:
50 | self.logger.warning(f"Could not optimize function {function}. {e}")
51 |                 continue  # skip seed functions that fail to optimize
52 | self.logger.info(f"Optimized seed functions: {self.functions}.")
53 | if len(self.functions) == 0:
54 |             self.logger.warning("Failed to optimize any of the seed functions. Function list will be empty.")
55 | else:
56 | self.scores, self.norm_scores = self.scorer.score_current_functions(self.functions)
57 | self.clean_scores()
58 | self.screen_names = {function: re.sub(r'c\d+', 'c', str(function)) for function in self.functions}
59 |
60 | self.logger.info(f"Current scores: {self.scores}.")
61 | self.logger.info(f"Current normalized scores: {self.norm_scores}.")
62 |
63 | def func_in_list(self, function: Any) -> bool:
64 | """
65 | Checks if a function is already in the prompt by assigning the same symbol to all coefficients.
66 |
67 | Parameters
68 | ----------
69 | function -> the function to check.
70 |
71 | Returns
72 | -------
73 | bool -> whether the function is already in the prompt or not.
74 | """
75 | symbols = set(function.free_symbols)
76 | for f in self.functions:
77 | symbols = symbols | set(f.free_symbols)
78 | coeffs = [s for s in symbols if str(s).startswith("c")]
79 | subs = {c: sympy.Symbol('c') for c in coeffs}
80 | function = function.subs(subs)
81 | for f in self.functions:
82 | f = f.subs(subs)
83 | if utils.func_equals(f, function, self.num_variables):
84 | return True
85 | return False
86 |
87 | def clean_scores(self) -> None:
88 | """
89 |         Removes any functions with an infinite score from the prompt.
90 | """
91 |         self.logger.info(f"Started cleaning scores {self.scores}.")
92 |         # Collect functions whose score or normalized score is infinite
93 | removals = [function for function in self.scores if self.scores[function] == np.inf]
94 | removals += [function for function in self.norm_scores if self.norm_scores[function] == np.inf and function not in removals]
95 |
96 | for function in removals:
97 | self.logger.warning(f"Removing function {function} with score {self.scores[function]} ({self.norm_scores[function]}) from the prompt.")
98 | self.functions.pop(function)
99 | self.scores.pop(function)
100 | self.norm_scores.pop(function)
101 |
102 |         self.logger.info(f"Finished cleaning scores {self.scores}.")
103 |
104 | def add_function(self, expr: Any, function: Any) -> None:
105 | """
106 | Adds a function to the current functions.
107 |
108 | Parameters
109 | ----------
110 | expr -> the coefficient form of the function.
111 | function -> the function to add.
112 | """
113 | self.logger.info(f"Adding function {expr}.")
114 |
115 | # Check if the function is already in the prompt, necessary if force_unique is False
116 | if self.func_in_list(expr):
117 | self.logger.info(f"Function {expr} already in prompt.")
118 | return
119 |
120 | if len(self.scores) >= self.context_len and self.scorer.score(function) > np.max(list(self.scores.values())):
121 |             self.logger.info(f"Function {expr} has score {self.scorer.score(function)}, which is worse than the current worst score {np.max(list(self.scores.values()))}. Not adding it.")
122 | return
123 |
124 | self.functions[expr] = function
125 | self.screen_names[expr] = re.sub(r'c\d+', 'c', str(expr))
126 | self.scores, self.norm_scores = self.scorer.score_current_functions(self.functions)
127 | self.clean_scores()
128 |
129 | # Remove the worst function if the context is full
130 | if len(self.functions) > self.context_len:
131 | worst_function = sorted(self.scores.items(), key=lambda x: x[1], reverse=True)[0][0]
132 | self.logger.info(f"Removing function {worst_function}.")
133 | self.functions.pop(worst_function)
134 | self.screen_names.pop(worst_function)
135 | self.scores.pop(worst_function)
136 | self.norm_scores.pop(worst_function)
137 |
138 | self.logger.info(f"Current scores: {self.scores}.")
139 | self.logger.info(f"Current normalized scores: {self.norm_scores}.")
140 |
141 |     def get_best_function(self, return_coeff: bool = True) -> Any:
142 |         """
143 |         Gets the best (lowest-scoring) function among the current functions.
144 |
145 |         Returns
146 |         -------
147 |         best_function -> the best function, in coefficient form if
148 |         return_coeff is True, otherwise with its optimized constants.
149 | """
150 | best_function = sorted(self.scores.items(), key=lambda x: x[1])[0][0]
151 | if return_coeff:
152 | return best_function
153 | else:
154 | return self.functions[best_function]
155 |
156 |     def get_prompt_functions(self) -> List[Tuple[Any, float]]:
157 |         """
158 |         Gets the top functions (by normalized score) to include in the prompt, ordered worst-first so the best function appears last.
159 |
160 |         Returns
161 |         -------
162 |         top_functions -> a list of (function, normalized score) tuples.
163 | """
164 | top_functions = sorted(self.norm_scores.items(), key=lambda x: x[1])
165 | top_functions = top_functions[:self.context_len]
166 | top_functions = sorted(top_functions, key=lambda x: x[1], reverse=True)
167 | return top_functions
168 |
169 | def get_prompt(self, base_prompt: str) -> str:
170 | """
171 | Generates a prompt for the model, by appending the current functions and their scores to a base prompt.
172 |
173 | Parameters
174 | ----------
175 | base_prompt -> the base prompt to append to.
176 |
177 | Returns
178 | -------
179 | prompt -> the prompt to use for the model.
180 | """
181 | top_functions = self.get_prompt_functions()
182 | functions = "\n".join([f'Function: {self.screen_names[function_name]}\nError: {fit}\n' for function_name, fit in top_functions])
183 | functions += "\nNew Functions: "
184 | prompt = base_prompt.format(functions=functions)
185 | return prompt
--------------------------------------------------------------------------------
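A standalone illustration of the deduplication idea behind func_in_list above: every numbered coefficient (c0, c1, ...) is collapsed to a single symbol c, so two candidates that differ only in coefficient labels compare as the same template. Here sympy.simplify stands in for the repo's utils.func_equals:

import sympy

x = sympy.Symbol('x')
f1 = sympy.Symbol('c0') * x + sympy.Symbol('c1')
f2 = sympy.Symbol('c3') * x + sympy.Symbol('c4')

# Collapse every numbered coefficient to the same symbol 'c'
coeffs = [s for s in (f1.free_symbols | f2.free_symbols) if str(s).startswith('c')]
subs = {c: sympy.Symbol('c') for c in coeffs}

# Both expressions reduce to c*x + c, so they count as duplicates
print(sympy.simplify(f1.subs(subs) - f2.subs(subs)) == 0)  # True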
/data/R/R1/test_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/R/R1/test_points.npy
--------------------------------------------------------------------------------
/data/R/R1/train_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/R/R1/train_points.npy
--------------------------------------------------------------------------------
/data/R/R2/test_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/R/R2/test_points.npy
--------------------------------------------------------------------------------
/data/R/R2/train_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/R/R2/train_points.npy
--------------------------------------------------------------------------------
/data/R/R3/test_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/R/R3/test_points.npy
--------------------------------------------------------------------------------
/data/R/R3/train_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/R/R3/train_points.npy
--------------------------------------------------------------------------------
/data/constant/constant1/test_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/constant/constant1/test_points.npy
--------------------------------------------------------------------------------
/data/constant/constant1/train_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/constant/constant1/train_points.npy
--------------------------------------------------------------------------------
/data/constant/constant2/test_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/constant/constant2/test_points.npy
--------------------------------------------------------------------------------
/data/constant/constant2/train_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/constant/constant2/train_points.npy
--------------------------------------------------------------------------------
/data/constant/constant3/test_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/constant/constant3/test_points.npy
--------------------------------------------------------------------------------
/data/constant/constant3/train_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/constant/constant3/train_points.npy
--------------------------------------------------------------------------------
/data/constant/constant4/test_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/constant/constant4/test_points.npy
--------------------------------------------------------------------------------
/data/constant/constant4/train_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/constant/constant4/train_points.npy
--------------------------------------------------------------------------------
/data/constant/constant5/test_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/constant/constant5/test_points.npy
--------------------------------------------------------------------------------
/data/constant/constant5/train_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/constant/constant5/train_points.npy
--------------------------------------------------------------------------------
/data/constant/constant6/test_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/constant/constant6/test_points.npy
--------------------------------------------------------------------------------
/data/constant/constant6/train_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/constant/constant6/train_points.npy
--------------------------------------------------------------------------------
/data/constant/constant7/test_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/constant/constant7/test_points.npy
--------------------------------------------------------------------------------
/data/constant/constant7/train_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/constant/constant7/train_points.npy
--------------------------------------------------------------------------------
/data/constant/constant8/test_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/constant/constant8/test_points.npy
--------------------------------------------------------------------------------
/data/constant/constant8/train_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/constant/constant8/train_points.npy
--------------------------------------------------------------------------------
/data/keijzer/keijzer10/test_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/keijzer/keijzer10/test_points.npy
--------------------------------------------------------------------------------
/data/keijzer/keijzer10/train_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/keijzer/keijzer10/train_points.npy
--------------------------------------------------------------------------------
/data/keijzer/keijzer11/test_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/keijzer/keijzer11/test_points.npy
--------------------------------------------------------------------------------
/data/keijzer/keijzer11/train_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/keijzer/keijzer11/train_points.npy
--------------------------------------------------------------------------------
/data/keijzer/keijzer12/test_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/keijzer/keijzer12/test_points.npy
--------------------------------------------------------------------------------
/data/keijzer/keijzer12/train_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/keijzer/keijzer12/train_points.npy
--------------------------------------------------------------------------------
/data/keijzer/keijzer13/test_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/keijzer/keijzer13/test_points.npy
--------------------------------------------------------------------------------
/data/keijzer/keijzer13/train_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/keijzer/keijzer13/train_points.npy
--------------------------------------------------------------------------------
/data/keijzer/keijzer14/test_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/keijzer/keijzer14/test_points.npy
--------------------------------------------------------------------------------
/data/keijzer/keijzer14/train_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/keijzer/keijzer14/train_points.npy
--------------------------------------------------------------------------------
/data/keijzer/keijzer15/test_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/keijzer/keijzer15/test_points.npy
--------------------------------------------------------------------------------
/data/keijzer/keijzer15/train_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/keijzer/keijzer15/train_points.npy
--------------------------------------------------------------------------------
/data/keijzer/keijzer3/test_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/keijzer/keijzer3/test_points.npy
--------------------------------------------------------------------------------
/data/keijzer/keijzer3/train_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/keijzer/keijzer3/train_points.npy
--------------------------------------------------------------------------------
/data/keijzer/keijzer4/test_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/keijzer/keijzer4/test_points.npy
--------------------------------------------------------------------------------
/data/keijzer/keijzer4/train_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/keijzer/keijzer4/train_points.npy
--------------------------------------------------------------------------------
/data/keijzer/keijzer6/test_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/keijzer/keijzer6/test_points.npy
--------------------------------------------------------------------------------
/data/keijzer/keijzer6/train_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/keijzer/keijzer6/train_points.npy
--------------------------------------------------------------------------------
/data/keijzer/keijzer7/test_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/keijzer/keijzer7/test_points.npy
--------------------------------------------------------------------------------
/data/keijzer/keijzer7/train_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/keijzer/keijzer7/train_points.npy
--------------------------------------------------------------------------------
/data/keijzer/keijzer8/test_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/keijzer/keijzer8/test_points.npy
--------------------------------------------------------------------------------
/data/keijzer/keijzer8/train_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/keijzer/keijzer8/train_points.npy
--------------------------------------------------------------------------------
/data/keijzer/keijzer9/test_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/keijzer/keijzer9/test_points.npy
--------------------------------------------------------------------------------
/data/keijzer/keijzer9/train_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/keijzer/keijzer9/train_points.npy
--------------------------------------------------------------------------------
/data/nguyen/nguyen1/test_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/nguyen/nguyen1/test_points.npy
--------------------------------------------------------------------------------
/data/nguyen/nguyen1/train_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/nguyen/nguyen1/train_points.npy
--------------------------------------------------------------------------------
/data/nguyen/nguyen10/test_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/nguyen/nguyen10/test_points.npy
--------------------------------------------------------------------------------
/data/nguyen/nguyen10/train_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/nguyen/nguyen10/train_points.npy
--------------------------------------------------------------------------------
/data/nguyen/nguyen11/test_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/nguyen/nguyen11/test_points.npy
--------------------------------------------------------------------------------
/data/nguyen/nguyen11/train_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/nguyen/nguyen11/train_points.npy
--------------------------------------------------------------------------------
/data/nguyen/nguyen12/test_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/nguyen/nguyen12/test_points.npy
--------------------------------------------------------------------------------
/data/nguyen/nguyen12/train_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/nguyen/nguyen12/train_points.npy
--------------------------------------------------------------------------------
/data/nguyen/nguyen2/test_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/nguyen/nguyen2/test_points.npy
--------------------------------------------------------------------------------
/data/nguyen/nguyen2/train_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/nguyen/nguyen2/train_points.npy
--------------------------------------------------------------------------------
/data/nguyen/nguyen3/test_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/nguyen/nguyen3/test_points.npy
--------------------------------------------------------------------------------
/data/nguyen/nguyen3/train_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/nguyen/nguyen3/train_points.npy
--------------------------------------------------------------------------------
/data/nguyen/nguyen4/test_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/nguyen/nguyen4/test_points.npy
--------------------------------------------------------------------------------
/data/nguyen/nguyen4/train_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/nguyen/nguyen4/train_points.npy
--------------------------------------------------------------------------------
/data/nguyen/nguyen5/test_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/nguyen/nguyen5/test_points.npy
--------------------------------------------------------------------------------
/data/nguyen/nguyen5/train_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/nguyen/nguyen5/train_points.npy
--------------------------------------------------------------------------------
/data/nguyen/nguyen6/test_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/nguyen/nguyen6/test_points.npy
--------------------------------------------------------------------------------
/data/nguyen/nguyen6/train_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/nguyen/nguyen6/train_points.npy
--------------------------------------------------------------------------------
/data/nguyen/nguyen7/test_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/nguyen/nguyen7/test_points.npy
--------------------------------------------------------------------------------
/data/nguyen/nguyen7/train_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/nguyen/nguyen7/train_points.npy
--------------------------------------------------------------------------------
/data/nguyen/nguyen8/test_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/nguyen/nguyen8/test_points.npy
--------------------------------------------------------------------------------
/data/nguyen/nguyen8/train_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/nguyen/nguyen8/train_points.npy
--------------------------------------------------------------------------------
/data/nguyen/nguyen9/test_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/nguyen/nguyen9/test_points.npy
--------------------------------------------------------------------------------
/data/nguyen/nguyen9/train_points.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/merlerm/In-Context-Symbolic-Regression/d44d26b006591094d0b5ee65bb9f5ce2b4fe1a95/data/nguyen/nguyen9/train_points.npy
--------------------------------------------------------------------------------
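All of the .npy files above are NumPy arrays saved with np.save. Judging from generate_points() in main.py below, each row is a point (x1, ..., xn, y). A hedged loading sketch for one of the 1D benchmarks:

import numpy as np

pts = np.load("data/nguyen/nguyen1/train_points.npy")
xs, ys = pts[:, :-1], pts[:, -1]  # inputs and target values
print(pts.shape, xs.shape, ys.shape)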
/download_model.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import torch
4 | import time
5 |
6 | from transformers import AutoTokenizer, AutoModelForCausalLM
7 |
8 | def main():
9 | """
10 | Download a pretrained model from HuggingFace and save it to the hf cache directory.
11 | """
12 | parser = argparse.ArgumentParser(description="Download a pretrained model from HuggingFace.")
13 | parser.add_argument("model_name", type=str, help="The name of the model to download.")
14 | parser.add_argument("--hf_cache", type=str, default="models/cache/", help="The path to save the model.")
15 | parser.add_argument("--hf_token", type=str, default=None, help="Huggingface auth token.")
16 |
17 | args = parser.parse_args()
18 | dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
19 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20 |
21 | print(f"Downloading model {args.model_name} to {args.hf_cache}...")
22 |     tokenizer = AutoTokenizer.from_pretrained(args.model_name, cache_dir=args.hf_cache, token=args.hf_token, local_files_only=False)
23 | model = AutoModelForCausalLM.from_pretrained(args.model_name, torch_dtype=dtype, cache_dir=args.hf_cache,
24 | token=args.hf_token, local_files_only=False, device_map='auto')
25 | print(f"Model {args.model_name} downloaded and saved to {args.hf_cache}.")
26 |
27 | # print model vocab size
28 | print(f"Model vocab size: {tokenizer.vocab_size}")
29 |
30 | special_token_dict = tokenizer.special_tokens_map
31 | tokenizer.add_special_tokens(special_token_dict)
32 | model.resize_token_embeddings(len(tokenizer))
33 |     print(f"Model vocab size after resizing: {len(tokenizer)}")
34 |
35 | print(f"Sample inference:")
36 | prompt = "Tell me a joke about Symbolic Regression."
37 | start = time.perf_counter()
38 | inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to(device)
39 | input_ids = inputs.input_ids
40 | attention_mask = inputs.attention_mask
41 |     outputs = model.generate(input_ids, attention_mask=attention_mask, max_length=5000, do_sample=True, temperature=0.9, top_k=50, top_p=0.9, num_return_sequences=1)
42 |     end = time.perf_counter()  # stop the timer right after generation
43 |     print(f"Prompt: {prompt}")
44 |     output = tokenizer.decode(outputs[0], skip_special_tokens=True)
45 |     output = output[len(prompt):]  # strip the echoed prompt from the decoded text
46 | print(f"Output: {output}")
47 |
48 | print(f"Model was using device {model.device}")
49 | print(f"Model took {end-start:.2f} seconds to generate the output.")
50 |
51 | if __name__ == "__main__":
52 | main()
--------------------------------------------------------------------------------
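For reference, a typical invocation of the script above, with the model name taken from the llama3-8b config (--hf_token is only needed for gated checkpoints). Shown as a Python sketch around the command line:

# Equivalent to: python download_model.py meta-llama/Meta-Llama-3-8B-Instruct --hf_cache models/cache/
import subprocess

subprocess.run(
    ["python", "download_model.py",
     "meta-llama/Meta-Llama-3-8B-Instruct",  # positional model_name argument
     "--hf_cache", "models/cache/"],         # matches the script's default
    check=True,
)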
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import json
4 | import copy
5 | import time
6 | import datetime
7 | import warnings
8 | import signal
9 | import cProfile
10 |
11 | import hydra
12 | import torch
13 | import matplotlib.pyplot as plt
14 | import numpy as np
15 | import utils
16 |
17 | from transformers import set_seed
18 | from omegaconf import OmegaConf, DictConfig, listconfig
19 | from sklearn.metrics import r2_score
20 |
21 | from plotter import Plotter
22 | from optimizer import Optimizer
23 | from current_functions import CurrentFunctions
24 | from scorers import BasicScorer, MinMaxScorer, ComplexityScorer
25 | from mloggers import ConsoleLogger, FileLogger, MultiLogger, LogLevel
26 |
27 | from typing import Dict, Tuple, List, Any
28 | from collections.abc import Callable
29 |
30 |
31 | class Workspace(object):
32 | """
33 | Workspace class for running the symbolic regression experiment.
34 | """
35 | def __init__(self, cfg: DictConfig) -> None:
36 | self.cfg = cfg
37 |
38 | # Output setup
39 | self.root_dir = cfg.get("root", os.getcwd())
40 | self.output_dir = cfg.get("output_dir", "output")
41 | model_folder_name = cfg.model.name.strip()
42 | if "/" in model_folder_name:
43 | model_folder_name = model_folder_name.split("/")[-1]
44 | experiment_folder_name = os.path.join(cfg.experiment.function.group, cfg.experiment.function.name) if hasattr(cfg.experiment.function, "group") else cfg.experiment.function.name
45 | self.output_path = os.path.join(self.root_dir, self.output_dir, experiment_folder_name, model_folder_name, datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + "/")
46 | while os.path.exists(self.output_path):
47 | self.output_path = os.path.join(self.root_dir, self.output_dir, experiment_folder_name, model_folder_name, datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + "-" + str(np.random.randint(0, 1000)) + "/")
48 | os.makedirs(self.output_path)
49 |
50 | # Logger setup
51 | cfg.logger.run_id = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
52 | loggers_list = cfg.logger.loggers
53 | log_level = LogLevel[cfg.logger.get("level", "INFO")]
54 | loggers = []
55 | for logger in loggers_list:
56 | if logger == "console":
57 | loggers.append(ConsoleLogger(default_priority=log_level))
58 | elif logger == "file":
59 | loggers.append(FileLogger(os.path.join(self.output_path, 'log.json'), default_priority=log_level))
60 | elif logger == "":
61 | pass
62 | else:
63 | print(f'[WARNING] Logger "{logger}" is not supported')
64 | self.logger = MultiLogger(loggers, default_priority=log_level)
65 | self.logger.info(f"Project root: {self.root_dir}.")
66 | self.logger.info(f"Logging to {self.output_path}.")
67 | job_id = utils.get_job_id()
68 |         if job_id is not None: self.logger.info(f"Slurm job ID: {job_id}.")
69 |
70 | # Redirect warnings to logger
71 | warnings.filterwarnings("default")
72 | warnings.showwarning = lambda *args, **kwargs: self.logger.warning(str(args[0]))
73 |
74 | # RNG setup
75 | if not hasattr(cfg, "seed") or cfg.seed is None or cfg.seed == -1:
76 | self.cfg.seed = np.random.randint(0, np.iinfo(np.int32).max)
77 | self.logger.info(f"Seed not specified, using random seed: {self.cfg.seed}.")
78 | else:
79 | self.logger.info(f"Using seed: {self.cfg.seed}.")
80 |
81 | np.random.seed(self.cfg.seed)
82 | torch.manual_seed(self.cfg.seed)
83 |         if torch.cuda.is_available(): torch.cuda.manual_seed_all(self.cfg.seed)
84 | set_seed(self.cfg.seed)
85 |
86 | if torch.cuda.is_available():
87 | torch.cuda.init()
88 |
89 | if cfg.device == "auto":
90 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
91 | else:
92 | self.device = torch.device(cfg.device)
93 |
94 | if cfg.get("use_bfloat16", False):
95 | self.dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
96 | else:
97 | self.dtype = torch.float16
98 |
99 | self.logger.info(f"Using device: {self.device} with dtype: {self.dtype}.")
100 | if torch.cuda.is_available() and ('cuda' in cfg.device or 'auto' in cfg.device):
101 | self.logger.info(f"Device name: {torch.cuda.get_device_name()}.")
102 |
103 | self.cache_dir = self.cfg.model.get("cache_dir", os.environ.get("HF_HOME", None))
104 | if self.cache_dir == "":
105 | self.cache_dir = os.environ.get("HF_HOME", None)
106 |
107 | if self.cache_dir is not None:
108 | os.environ['HF_HOME'] = self.cache_dir
109 | os.environ['TRANSFORMERS_CACHE'] = os.environ['HF_HOME']
110 | self.logger.info(f"Cache dir: {os.environ.get('HF_HOME', None)}.")
111 |
112 | # Experiment settings
113 | self.data_folder = cfg.experiment.function.train_points.get("data_folder", None)
114 | self.data_folder = os.path.join(self.root_dir, self.data_folder) if self.data_folder is not None else None
115 |
116 | if cfg.experiment.function.train_points.generate_points:
117 | self.min_train_points = cfg.experiment.function.train_points.min_points
118 | self.max_train_points = cfg.experiment.function.train_points.max_points
119 | self.num_train_points = cfg.experiment.function.train_points.num_points
120 | self.xs_noise_std = cfg.experiment.function.train_points.xs_noise_std
121 | self.ys_noise_std = cfg.experiment.function.train_points.ys_noise_std
122 |             # random_points defaults to False when the config does not set it
123 |             self.random_train_points = bool(
124 |                 self.cfg.experiment.function.train_points.get("random_points", False))
125 | else:
126 | assert self.data_folder is not None, "No data folder specified."
127 | assert os.path.exists(self.data_folder), f"Data folder {self.data_folder} does not exist."
128 | assert os.path.exists(os.path.join(self.data_folder, 'train_points.npy')), f"Train points file {os.path.join(self.data_folder, 'train_points.npy')} does not exist."
129 | assert os.path.exists(os.path.join(self.data_folder, 'test_points.npy')), f"Test points file {os.path.join(self.data_folder, 'test_points.npy')} does not exist."
130 |
131 | train_points_file = os.path.join(self.data_folder, 'train_points.npy')
132 | self.train_points = utils.load_points(train_points_file)
133 | self.num_train_points = len(self.train_points)
134 | self.min_train_points = np.min(self.train_points)
135 | self.max_train_points = np.max(self.train_points)
136 | self.logger.info(f"Loaded train points from {train_points_file}.")
137 |
138 | test_points_file = os.path.join(self.data_folder, 'test_points.npy')
139 | self.test_points = utils.load_points(test_points_file)
140 | self.num_test_points = len(self.test_points)
141 | self.min_test_points = np.min(self.test_points)
142 | self.max_test_points = np.max(self.test_points)
143 | self.logger.info(f"Loaded test points from {test_points_file}.")
144 |
145 | self.tolerance = cfg.experiment.function.tolerance
146 | self.num_variables = cfg.experiment.function.num_variables
147 |         if self.num_variables > 2 and cfg.model.visual:  # self.visual_model is only set later, so read the config directly
148 | self.logger.error("Visual models only support up to 2 variables.")
149 | exit(1)
150 |
151 | self.iterations = cfg.experiment.function.iterations
152 | self.max_retries = cfg.max_retries
153 | self.force_valid = cfg.force_valid
154 | self.force_unique = cfg.force_unique
155 | self.checkpoints = cfg.checkpoints
156 |
157 | if "test_function" not in cfg.experiment.function:
158 | self.logger.info("Test function is not known.")
159 | self.test_function = None
160 | else:
161 | self.test_function_name = cfg.experiment.function.test_function
162 | self.test_function = utils.string_to_function(self.test_function_name, self.num_variables)
163 |
164 | # Points setup
165 | if cfg.experiment.function.train_points.generate_points:
166 | add_extremes = cfg.experiment.function.train_points.add_extremes if hasattr(cfg.experiment.function.train_points, "add_extremes") else False
167 | self.train_points = self.generate_points(self.test_function, self.min_train_points, self.max_train_points, self.num_train_points,
168 | xs_noise_std=self.xs_noise_std, ys_noise_std=self.ys_noise_std,
169 | random_points=self.random_train_points, save_fig=True, add_extremes=add_extremes)
170 | np.save(os.path.join(self.output_path, "train_points.npy"), self.train_points)
171 |
172 | self.min_test_points = cfg.experiment.function.test_points.min_points if hasattr(cfg.experiment.function.test_points, "min_points") else self.min_train_points
173 | self.max_test_points = cfg.experiment.function.test_points.max_points if hasattr(cfg.experiment.function.test_points, "max_points") else self.max_train_points
174 | self.num_test_points = cfg.experiment.function.test_points.num_points
175 | self.test_points = self.generate_points(self.test_function, self.min_test_points, self.max_test_points, self.num_test_points, random_points=False, save_fig=False)
176 | np.save(os.path.join(self.output_path, "test_points.npy"), self.test_points)
177 |
178 | if cfg.experiment.get("normalize_points", False):
179 | self.train_points = utils.normalize_points(self.train_points, cfg.experiment.normalize_method, cfg.experiment.normalize_percentile)
180 | self.logger.info(f"Normalized train points: {self.train_points}.")
181 |
182 | self.logger.info(f"Train points: {utils.array_to_string(self.train_points)}.")
183 | if self.num_test_points > 100:
184 | self.logger.info(f"Not logging test points as there are more than 100 ({self.num_test_points}).")
185 | else:
186 | self.logger.info(f"Test points: {utils.array_to_string(self.test_points)}.")
187 |
188 | # Optimizer settings
189 | self.optimizer = Optimizer(cfg, self.train_points, self.logger)
190 |
191 | # Plotter setup
192 | if self.num_variables > 2:
193 | self.logger.warning("Plotter will not plot points and animation as there are more than 2 variables.")
194 | self.save_frames = cfg.plotter.save_frames if hasattr(cfg.plotter, "save_frames") else False
195 | if self.save_frames:
196 | os.makedirs(self.output_path + "frames/")
197 |
198 | self.save_video = cfg.plotter.save_video if hasattr(cfg.plotter, "save_video") else True
199 | self.save_video = False if self.num_variables > 2 else self.save_video
200 | if self.save_video:
201 | plt.rcParams.update({'figure.max_open_warning': cfg.experiment.function.iterations + 5})
202 | self.plotter = Plotter(cfg, self.train_points, self.test_points, self.output_path)
203 | self.plotter.plot_points(save_fig=True, plot_test=False)
204 |
205 | # Base prompt
206 | self.prompts_path = os.path.join(self.root_dir, cfg.prompts_path)
207 | self.prompt_size = cfg.model.base_prompt.prompt_size
208 | with open(os.path.join(self.prompts_path, "OPRO", cfg.model.base_prompt.prompt), "r") as f:
209 | self.base_prompt = f.read()
210 |
211 | self.prompt_points = utils.decimate_points(self.train_points, cfg.max_points_in_prompt)
212 | self.prompt_points = utils.array_to_string(self.prompt_points)
213 | if self.num_train_points > self.cfg.max_points_in_prompt:
214 | self.logger.info(f"Found {self.num_train_points} train points, decimated to {cfg.max_points_in_prompt} for the prompt.")
215 | self.logger.info(f"Prompt points: {self.prompt_points}.")
216 |
217 | self.base_prompt = self.base_prompt.format(points=self.prompt_points, num_variables=self.num_variables,
218 | variables_list=[f"x{i+1}" for i in range(self.num_variables)], functions="{functions}")
219 |
220 | # Model settings
221 | self.model_name = cfg.model.name
222 | self.model = None
223 |
224 | self.visual_model = cfg.model.visual
225 | if hasattr(cfg.model.base_prompt, "input_image"):
226 | self.input_img = cfg.model.base_prompt.input_image
227 | valid_inputs = ["points", "previous_guess", "best_guess", "all_guesses"]
228 | if self.input_img not in valid_inputs:
229 | self.logger.error(f"Input image {self.input_img} not supported. Valid inputs are {valid_inputs}.")
230 | exit(1)
231 | else:
232 | self.input_img = None
233 |
234 | self.logger.info(f"Base Prompt: {self.base_prompt} with input image {self.input_img}.")
235 | self.tokenizer_pad = cfg.model.tokenizer_pad
236 | self.tokenizer_padding_side = cfg.model.tokenizer_padding_side
237 |
238 | self.max_new_tokens = cfg.model.max_new_tokens
239 | self.top_p = cfg.model.top_p
240 | self.top_k = cfg.model.top_k
241 | self.num_beams = cfg.model.num_beams
242 |
243 | self.temperature = cfg.model.temperature
244 | if cfg.model.temperature_schedule:
245 | self.temperature_scheduler = torch.optim.lr_scheduler.ExponentialLR(torch.optim.Adam([torch.tensor(self.temperature)], lr=1), gamma=cfg.model.temperature_schedule_gamma)
246 | else:
247 | self.temperature_scheduler = None
248 |
249 | model_args = {
250 | "temperature": self.temperature,
251 | "top_p": self.top_p,
252 | "top_k": self.top_k,
253 | "num_beams": self.num_beams,
254 | "max_length": self.max_new_tokens,
255 | "min_length": 0,
256 | "tokenizer_pad": self.tokenizer_pad,
257 | "tokenizer_padding_side": self.tokenizer_padding_side,
258 | "seed": self.cfg.seed,
259 | "api_key_path": os.path.join(self.root_dir, cfg.model.api_key_path) if hasattr(cfg.model, "api_key_path") else None,
260 | "organization_id_path": os.path.join(self.root_dir, cfg.model.organization_id_path) if hasattr(cfg.model, "organization_id_path") else None,
261 | }
262 | if torch.cuda.is_available() and 'A100' in torch.cuda.get_device_name(0):
263 | model_args['attn_implementation'] = 'flash_attention_2'
264 | model_args['use_flash_attn'] = True
265 | self.logger.info("Using Flash Attention 2")
266 |
267 | self.model = utils.load_model(self.model_name, self.device, self.dtype, self.cache_dir, model_args)
268 |         self.logger.info(f"Model loaded: {self.model_name}.")
269 |
270 | # Scorer settings
271 | if "basic" in cfg.experiment.scorer.name.lower():
272 | self.scorer = BasicScorer(self.train_points, rounding=cfg.experiment.scorer.rounding, scientific=cfg.experiment.scorer.scientific)
273 | self.test_scorer = BasicScorer(self.test_points, rounding=cfg.experiment.scorer.rounding, scientific=cfg.experiment.scorer.scientific)
274 | elif "minmax" in cfg.experiment.scorer.name.lower():
275 | min_score = cfg.experiment.scorer.min_score
276 | max_score = cfg.experiment.scorer.max_score
277 | self.scorer = MinMaxScorer(self.train_points, min_score=min_score, max_score=max_score, rounding=cfg.experiment.scorer.rounding, scientific=cfg.experiment.scorer.scientific)
278 | self.test_scorer = MinMaxScorer(self.test_points, min_score=min_score, max_score=max_score, rounding=cfg.experiment.scorer.rounding, scientific=cfg.experiment.scorer.scientific)
279 | elif "complexity" in cfg.experiment.scorer.name.lower():
280 | self.logger.info(f"Complexity scorer with lambda {cfg['experiment']['scorer']['lambda']} and max nodes {cfg.experiment.scorer.max_nodes}.")
281 | alternative = False
282 | if hasattr(cfg.experiment.scorer, "alternative") and cfg.experiment.scorer.alternative:
283 | alternative = True
284 | self.logger.info("Using alternative complexity scorer.")
285 | self.scorer = ComplexityScorer(self.train_points, rounding=cfg.experiment.scorer.rounding, scientific=cfg.experiment.scorer.scientific, lam=cfg['experiment']['scorer']['lambda'], max_nodes=cfg.experiment.scorer.max_nodes, alternative=alternative)
286 | self.test_scorer = ComplexityScorer(self.test_points, rounding=cfg.experiment.scorer.rounding, scientific=cfg.experiment.scorer.scientific, lam=cfg['experiment']['scorer']['lambda'], max_nodes=cfg.experiment.scorer.max_nodes, alternative=alternative)
287 | else:
288 | self.logger.error(f"Scorer {cfg.experiment.scorer.name} not supported.")
289 | exit(1)
290 |
291 | # Seed functions
292 | self.seed_functions = {}
293 | min_seed_functions = max(5, self.prompt_size) # If the prompt size is small (e.g. 1) we still want to generate a few seed functions to avoid getting stuck
294 | gen_time = 0
295 | if cfg.experiment.generate_seed_functions:
296 | self.seed_functions, gen_time = self.generate_seed_functions()
297 | assert len(self.seed_functions) >= min_seed_functions, f"Could not generate {min_seed_functions} seed functions. Generated {len(self.seed_functions)} seed functions."
298 | else:
299 | self.seed_functions = {name: utils.string_to_function(name, self.num_variables) for name in cfg.experiment.seed_functions.functions}
300 | self.logger.info(f"Loaded seed functions: {self.seed_functions}.")
301 | self.current_functions = CurrentFunctions(self.seed_functions, self.scorer, self.optimizer, self.prompt_size, self.logger, self.num_variables)
302 | self.logger.info(f"Current functions: {self.current_functions.functions}.")
303 | self.logger.info(f"Current scores: {self.current_functions.scores}.")
304 |
305 | if len(self.current_functions.functions) < self.prompt_size:
306 | self.logger.warning(f"Could not generate {self.prompt_size} seed functions. Generated {len(self.current_functions.functions)} seed functions.")
307 | if len(self.current_functions.functions) == 0:
308 | self.logger.error("No seed functions generated. Exiting.")
309 | exit(1)
310 | else:
311 |             self.logger.info(f"Successfully generated {self.prompt_size} seed functions in {gen_time} seconds.")
312 |
313 | # Results json
314 | self.results = {
315 | "experiment_name": self.cfg.experiment.function.name,
316 | "seed": self.cfg.seed,
317 | "train_points": utils.array_to_string(self.train_points),
318 | "test_points": utils.array_to_string(self.test_points),
319 | "best_expr": "",
320 | "best_function": "",
321 | "scores": [],
322 | "R2_trains": [],
323 | "R2_tests": [],
324 | "R2_alls": [],
325 | "best_scores": [],
326 | "best_scores_normalized": [],
327 | "iterations": 0,
328 | "tries_per_iteration": [],
329 | "generations_per_iteration": [],
330 | "num_unique": len(self.current_functions.functions),
331 | "best_found_at": 0,
332 | "sympy_equivalent": False,
333 | "temperatures": [],
334 | "times": {
335 | "iteration": [],
336 | "seed_function_generation": gen_time,
337 | "generation_per_iteration": [],
338 | "optimization_per_iteration": [],
339 | }
340 | }
341 | if "test_function" in self.cfg.experiment.function:
342 | self.results["test_function"] = self.cfg.experiment.function.test_function
343 |
344 | # Save config
345 | with open(self.output_path + "config.yaml", "w") as f:
346 | OmegaConf.save(self.cfg, f)
347 |
348 | def generate_points(self, function: Callable, min_points: float, max_points: float, num: int, xs_noise_std: float = 0, ys_noise_std: float = 0,
349 |                         random_points: bool = False, save_fig: bool = False, add_extremes: bool = True) -> np.ndarray:
350 | """
351 | Generates points from a given function, with optional noise.
352 |
353 | Parameters
354 | ----------
355 | function -> the function to generate points from.
356 | min_points -> the minimum value of the points to generate.
357 | max_points -> the maximum value of the points to generate.
358 | num -> the number of points to generate.
359 | xs_noise_std -> the standard deviation of the noise to add to the xs.
360 | ys_noise_std -> the standard deviation of the noise to add to the ys.
361 |         random_points -> whether to generate random points instead of a grid/meshgrid. save_fig -> accepted by the __init__ call sites; unused here.
362 |         add_extremes -> whether to add points at the extreme values of the interval manually to ensure they are included.
363 |
364 | Returns
365 | -------
366 |         points -> an array of shape (num, num_variables + 1) holding the xs and ys.
367 | """
368 | min_value = copy.deepcopy(min_points)
369 | max_value = copy.deepcopy(max_points)
370 |         if not isinstance(min_points, (list, listconfig.ListConfig)):
371 |             min_points = [min_points] * self.num_variables
372 |         if not isinstance(max_points, (list, listconfig.ListConfig)):
373 |             max_points = [max_points] * self.num_variables
374 | min_points = np.array(min_points, dtype=np.float32)
375 | max_points = np.array(max_points, dtype=np.float32)
376 |
377 | points_per_dim = int(np.floor(num**(1/self.num_variables)))
378 | self.logger.info(f"Generating {points_per_dim} points per dimension for a total of {points_per_dim**self.num_variables} points.")
379 |
380 | if random_points:
381 | # Add points at the extreme values of the interval manually to ensure they are included
382 | # This depends on the number of dimensions
383 | # For example, in 1D if the interval is [0, 1] we need to add points at 0 and 1
384 | # In 2D, if the interval is [(0, 0), (1, 1)] we need to add points at (0, 0), (0, 1), (1, 0), (1, 1)
385 | if add_extremes:
386 | variable_ranges = np.array([[min_points[i], max_points[i]] for i in range(self.num_variables)])
387 | extreme_points = np.array(np.meshgrid(*variable_ranges)).T.reshape(-1, self.num_variables)
388 | self.logger.info(f"Adding {len(extreme_points)} extreme points ({extreme_points}).")
389 |
390 | # Reshape min and max points to match the random shape. Currently min and max are of shape (num_variables,), so we need to add n dimensions of size points_per_dim by copying the min and max values
391 | # For example, if min is [0, 1] and num_variables is 2 and points_per_dim is 3, we need to reshape min to an array of shape (2, 3, 3) with all values being 0 and 1 across the last dimension
392 | random_shape = tuple([self.num_variables, *([points_per_dim] * self.num_variables)])
393 | min_points = np.expand_dims(min_points, axis=tuple(range(1, self.num_variables + 1)))
394 | max_points = np.expand_dims(max_points, axis=tuple(range(1, self.num_variables + 1)))
395 | max_points += 1e-10 # Add small eps to max_points as the rightmost value is not included in np.random.uniform
396 | for i in range(1, self.num_variables+1):
397 | min_points = np.repeat(min_points, points_per_dim, axis=i)
398 | max_points = np.repeat(max_points, points_per_dim, axis=i)
399 | Xs = np.random.uniform(min_points, max_points, random_shape)
400 |
401 | else:
402 | Xs = np.meshgrid(*[np.linspace(min_points[i], max_points[i], points_per_dim) for i in range(self.num_variables)])
403 | Xs = np.array(Xs)
404 | if xs_noise_std:
405 | Xs += np.random.normal(0, xs_noise_std, Xs.shape)
406 | pts = np.array(list(zip(*[x.flat for x in Xs])))
407 |
408 | ys = utils.eval_function(function, pts, self.num_variables).T
409 |
410 | if ys_noise_std:
411 | ys += np.random.normal(0, ys_noise_std, ys.shape)
412 |
413 | if random_points and add_extremes:
414 | pts = np.concatenate((extreme_points, pts))
415 | extreme_ys = utils.eval_function(function, extreme_points, self.num_variables).T
416 | ys = np.concatenate((extreme_ys, ys))
417 |
418 | points = np.concatenate((pts, ys.reshape(-1, 1)), axis=1)
419 | if add_extremes and len(points) > num:
420 | # Remove random points to account for the extra extremes
421 | # The points are sampled randomly so removing from the end is the same as removing random indices
422 | self.logger.info(f"Removing {len(points)-num} randomly generated points: {points[num:]}")
423 | points = points[:num]
424 | while any(np.isinf(points[:, -1])):
425 | # Remove points where the function is infinite
426 | inf_indices = np.where(np.isinf(points))
427 | self.logger.info(f"Removing {len(inf_indices)} points where the function is infinite.")
428 | points = np.delete(points, inf_indices[0], axis=0)
429 |
430 | if len(points) < num and random_points:
431 | # Generate new points to replace the infinite ones
432 | self.logger.info(f"Recursively generating {num-len(points)} new points.")
433 | new_points = self.generate_points(function, min_value, max_value, num-len(points), xs_noise_std, ys_noise_std, random_points, add_extremes=False)
434 | points = np.concatenate((points, new_points))
435 |
436 | return points
437 |
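
# A minimal sketch of the extreme-corner construction used above when
# random_points and add_extremes are set (editor's illustration; values chosen
# for a 2-variable interval [(0, 0), (1, 1)]):
#
#   >>> import numpy as np
#   >>> ranges = np.array([[0.0, 1.0], [0.0, 1.0]])  # one (min, max) row per variable
#   >>> np.array(np.meshgrid(*ranges)).T.reshape(-1, 2)
#   array([[0., 0.],
#          [0., 1.],
#          [1., 0.],
#          [1., 1.]])
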
438 | def generate_seed_functions(self) -> Tuple[Dict[str, Any], float]:
439 | """
440 | Generates initial seed functions for the experiment.
441 |
442 | Parameters
443 | ----------
444 |
445 | Returns
446 | -------
447 | seed_functions -> the generated seed functions.
448 | gen_time -> the time it took to generate the seed functions.
449 | """
450 | generation_tokens = self.cfg.experiment.seed_functions.generation_tokens if hasattr(self.cfg.experiment.seed_functions, "generation_tokens") else 512
451 | max_tries = self.cfg.experiment.seed_functions.max_tries if hasattr(self.cfg.experiment.seed_functions, "max_tries") else 10
452 | seed_functions = {}
453 |
454 | seed_prompt = self.cfg.get("model").get("seed_function_prompt", None)
455 | assert seed_prompt is not None, "Seed function prompt not specified."
456 | seed_prompt = os.path.join(self.prompts_path, seed_prompt)
457 |
458 | with open(seed_prompt, "r") as f:
459 | prompt = f.read()
460 | img_path = os.path.join(self.output_path, "points.png") if self.input_img else None
461 |
462 | prompt = prompt.format(points=self.prompt_points, num_variables=self.num_variables, variables_list=[f"x{i+1}" for i in range(self.num_variables)])
463 | self.logger.info("Prompt for seed functions generation:")
464 | self.logger.info(prompt)
465 |
466 | start_time = time.perf_counter()
467 | with torch.inference_mode():
468 | for i in range(max_tries):
469 | # Generate seed functions using the model
470 | self.logger.info(f"Attempt {i+1} of {max_tries} to generate seed functions.")
471 | seeds = self.model.generate(prompt, return_prompt=False, image_files=img_path, temperature=self.temperature, max_new_tokens=generation_tokens)
472 | self.logger.info("Model output for seed functions: " + seeds)
473 |
474 | # Parse model output
475 | for seed in seeds.split("\n"):
476 | if "x" not in seed:
477 | self.logger.info(f"Skipping line {seed} as it does not contain 'x' and is likely not a function.")
478 | continue
479 | if "Error" in seed:
480 | self.logger.info(f"Skipping line {seed} as it contains 'Error'.")
481 | continue
482 | seed = utils.clean_function(seed)
483 | self.logger.info(f"Seed function: {seed}.")
484 | if seed == "":
485 | continue
486 | try:
487 | valid, reason = utils.is_valid_function(seed, None, self.num_variables)
488 | self.logger.info(f"Function {seed}. Valid: {valid}. Reason: {reason}.")
489 | if valid:
490 | function = utils.string_to_function(seed, self.num_variables)
491 | seed_functions[seed] = function
492 | except Exception as e:
493 | self.logger.warning(f"Could not parse line {seed}.")
494 | self.logger.warning(str(e))
495 | pass
496 | # We continue even if we already have enough seed functions, as some of them may turn out to be invalid after optimization
497 | # Perhaps a better approach would be to optimize here directly and exit once we have enough valid seed functions
498 | end_time = time.perf_counter()
499 |
500 | self.logger.info(f"Generated seed functions: {seed_functions}.")
501 | return seed_functions, end_time - start_time
502 |
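
# The parser above expects one candidate function per line of model output; the
# line filter can be sketched as follows (editor's illustration with a
# hypothetical model response):
#
#   >>> seeds = "Sure! Here are some functions:\nf1(x) = c*x1 + c\nf2(x) = c*sin(c*x1)"
#   >>> [s for s in seeds.split("\n") if "x" in s and "Error" not in s]
#   ['f1(x) = c*x1 + c', 'f2(x) = c*sin(c*x1)']
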
503 |
504 | def get_new_function(self, prompt: str) -> Tuple[List, bool]:
505 | """
506 | Generates a new function from the model, given a prompt.
507 |
508 | Parameters
509 | ----------
510 | prompt -> the prompt to use for the model.
511 |
512 | Returns
513 | -------
514 | functions -> the list of new functions generated by the model.
515 | found_valid -> whether a valid function was found.
516 | """
517 | img = None
518 | if self.visual_model:
519 | if self.input_img == "points":
520 | img = os.path.join(self.output_path, "points.png")
521 | elif self.input_img == "previous_guess":
522 | img = os.path.join(self.output_path, "frames", f"{self.results['iterations']-1}.png")
523 | elif self.input_img == "best_guess":
524 | fig, ax = self.plotter.plot_results(self.current_functions.get_best_function(return_coeff=False), self.test_function, plot_true=False)
525 | fig.savefig(self.output_path + "best_guess.png")
526 | plt.close(fig)
527 | img = os.path.join(self.output_path, "best_guess.png")
528 | elif self.input_img == "all_guesses":
529 | os.makedirs(self.output_path + "prompt_input/")
530 | path = os.path.join(self.output_path, "prompt_input/")
531 | img = []
532 | functions = self.current_functions.get_prompt_functions()
533 | for expr, _ in functions:
534 | try:
535 | function = self.current_functions.functions[expr]
536 | fig, ax = self.plotter.plot_results(function, self.test_function, plot_true=False, label="Function: " + str(function))
537 | function_string = str(expr)
538 | fig.suptitle("Plot of " + function_string)
539 | fig.text(0.5, 0.90, "Error: " + str(self.current_functions.norm_scores[expr]), ha='center')
540 | file_name = function_string.replace(" ", "_").replace("/", "div")
541 | fig.savefig(os.path.join(path, f"{file_name}.png"))
542 | plt.close(fig)
543 | img.append(os.path.join(path, f"{file_name}.png"))
544 | except Exception as e:
545 | self.logger.warning(f"Could not plot function {function}.")
546 | self.logger.warning(str(e))
547 | pass
548 |
549 | new_output = self.model.generate(prompt, return_prompt=False, image_files=img, temperature=self.temperature)
550 | self.logger.info("Prompt: " + prompt)
551 | self.logger.info("Model output: " + new_output)
552 |
553 | # Clean up images
554 | if self.visual_model and self.input_img == "best_guess":
555 | os.remove(self.output_path + "best_guess.png")
556 | elif self.visual_model and self.input_img == "all_guesses":
557 | for file in os.listdir(path):
558 | os.remove(os.path.join(path, file))
559 | os.rmdir(path)
560 |
561 | functions = []
562 | lines = new_output.split("\n")
563 | for line in lines:
564 | if "x" not in line:
565 | self.logger.info(f"Skipping line {line} as it does not contain 'x' and is likely not a function.")
566 | continue
567 | if "Error" in line:
568 | self.logger.info(f"Skipping line {line} as it contains 'Error'.")
569 | continue
570 | line = utils.clean_function(line)
571 | if line == "":
572 | continue
573 | self.logger.info("Cleaned line: " + line + ".")
574 | try:
575 | valid, reason = utils.is_valid_function(line, self.current_functions, self.num_variables)
576 | self.logger.info(f"Valid: {valid}. Reason: {reason}.")
577 | if valid:
578 | functions.append(line)
579 | elif not valid and reason == "Function already in prompt." and not self.force_unique:
580 | functions.append(line)
581 | except Exception as e:
582 | self.logger.warning(f"Could not parse line {line}.")
583 | self.logger.warning(str(e))
584 | pass
585 |
586 | found_valid = False
587 | if len(functions) == 0:
588 | self.logger.warning("Could not find a valid function in the output. Using the last function in the output.")
589 | functions = [self.current_functions.get_best_function()]
590 | else:
591 | found_valid = True
592 | functions = [utils.string_to_function(function, self.num_variables) for function in functions]
593 | self.logger.info(f"Found functions: {functions}.")
594 |
595 | return functions, found_valid
596 |
597 | def get_new_function_and_score(self, prompt: str) -> Tuple[Any, Any, float]:
598 | """
599 | Generates a new function from the model, given a prompt, and scores it if it is valid.
600 |
601 | Parameters
602 | ----------
603 | prompt -> the prompt to use for the model.
604 |
605 | Returns
606 | -------
607 | expression -> the coefficient form of the function.
608 | function -> function with the optimized coefficients.
609 | score -> the score of the function.
610 | """
611 | valid = False
612 | start_time = time.perf_counter()
613 | for i in range(self.max_retries):
614 | self.logger.info(f"Attempt {i+1} of {self.max_retries} to find a valid function.")
615 | functions, valid = self.get_new_function(prompt)
616 | if valid and len(functions) >= 1:
617 | self.logger.info(f"Found {len(functions)} functions in the output.")
618 | break
619 | else:
620 | self.logger.info(f"Could not find a valid function in the output. Trying again.")
621 |
622 | self.results["tries_per_iteration"].append(i+1)
623 | self.results["times"]["generation_per_iteration"].append(time.perf_counter() - start_time)
624 |
625 | if not valid:
626 | if self.force_valid:
627 | self.logger.error(f"Could not find a valid function after {self.max_retries} tries. Exiting.")
628 | exit(1)
629 | else:
630 | best_expr = self.current_functions.get_best_function(return_coeff=True)
631 | best_function = self.current_functions.functions[best_expr]
632 | best_score = self.current_functions.scores[best_expr]
633 | self.logger.warning(f"Could not find a valid function after {self.max_retries} tries. Using {best_function}.")
634 | else:
635 | best_score = np.inf
636 | start_time = time.perf_counter()
637 | for function in functions:
638 | if not self.current_functions.func_in_list(function):
639 | self.results["num_unique"] += 1
640 | self.results["generations_per_iteration"].append(len(functions))
641 | try:
642 | opt_function, exp = self.optimizer.optimize(function, return_coeff=True)
643 | score = self.scorer.score(opt_function)
644 | self.logger.info(f"New function: {str(opt_function)}. Score: {score}.")
645 | if score < best_score:
646 | best_score = score
647 | best_function = opt_function
648 | best_expr = exp
649 | except Exception as e:
650 | self.logger.warning(f"Could not optimize function {function}. {e}")
651 | pass
652 |
653 | self.results["times"]["optimization_per_iteration"].append(time.perf_counter() - start_time)
654 | self.logger.info(f"Optimizer time: {time.perf_counter() - start_time}.")
655 | if best_score == np.inf:
656 | self.logger.warning(f"No functions scored below inf. Using the best function in the prompt.")
657 | best_expr = self.current_functions.get_best_function(return_coeff=True)
658 | best_function = self.current_functions.functions[best_expr]
659 | best_score = self.current_functions.scores[best_expr]
660 | self.logger.info(f"Best function: {best_function}. Score: {best_score}.")
661 |
662 | self.logger.info(f"Finished get new function and score. Best function: {best_function}. Score: {best_score}.")
663 | return best_expr, best_function, best_score
664 |
665 | def get_R2_scores(self, function: Any) -> Tuple[float, float, float]:
666 | """
667 | Computes the R2 scores for the train and test sets given a function, removing the 5% worst predictions.
668 |
669 | Parameters
670 | ----------
671 | function -> the function to evaluate.
672 |
673 | Returns
674 | -------
675 | r2_train -> the R2 score for the train set.
676 | r2_test -> the R2 score for the test set.
677 | r2_all -> the R2 score for all points.
678 | """
679 |
680 | y_true_train = self.train_points[:, -1]
681 | y_true_test = self.test_points[:, -1]
682 |
683 | def compute_predictions(function, points, num_variables):
684 | y_pred = utils.eval_function(function, points[:, 0:-1], num_variables)
685 |
686 | # Compute a boolean mask of the 5% worst predictions (guarded: with fewer than 20 points, [-0:] would otherwise select every point)
687 | num_worst = int(len(points) * 0.05)
688 | mask = np.zeros(len(points), dtype=bool)
689 | if num_worst > 0: mask[np.argsort(np.abs(y_pred - points[:, -1]))[-num_worst:]] = True
690 |
691 | return y_pred[~mask], mask
692 |
693 | try:
694 | y_pred_train, y_train_mask = compute_predictions(function, self.train_points, self.num_variables)
695 | y_pred_test, y_test_mask = compute_predictions(function, self.test_points, self.num_variables)
696 |
697 | y_true_train = y_true_train[~y_train_mask]
698 | y_true_test = y_true_test[~y_test_mask]
699 |
700 | assert len(y_true_train) == len(y_pred_train), f"Length of true train points ({len(y_true_train)}) does not match length of predicted train points ({len(y_pred_train)})."
701 | assert len(y_true_test) == len(y_pred_test), f"Length of true test points ({len(y_true_test)}) does not match length of predicted test points ({len(y_pred_test)})."
702 |
703 | except Exception as e:
704 | self.logger.warning(f"Could not evaluate function {function}. {e}")
705 | return np.nan, np.nan, np.nan
706 | try:
707 | r2_train = r2_score(y_true_train, y_pred_train)
708 | except Exception as e:
709 | self.logger.warning(f"Could not calculate R2 score for train set. {e}")
710 | r2_train = np.nan
711 | try:
712 | r2_test = r2_score(y_true_test, y_pred_test)
713 | except Exception as e:
714 | self.logger.warning(f"Could not calculate R2 score for test set. {e}")
715 | r2_test = np.nan
716 | try:
717 | r2_all = r2_score(np.concatenate((y_true_train, y_true_test)), np.concatenate((y_pred_train, y_pred_test)))
718 | except Exception as e:
719 | self.logger.warning(f"Could not calculate R2 score for all points. {e}")
720 | r2_all = np.nan
721 |
722 | return r2_train, r2_test, r2_all
723 |
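
# A small sketch of the 5%-trimming mask built in compute_predictions above
# (editor's illustration):
#
#   >>> import numpy as np
#   >>> errs = np.array([0.1, 5.0, 0.2, 0.3] * 5)  # |y_pred - y_true| for 20 points
#   >>> k = int(len(errs) * 0.05)                  # k == 1
#   >>> mask = np.zeros(len(errs), dtype=bool)
#   >>> mask[np.argsort(errs)[-k:]] = True         # flags the single largest error
#   >>> int(mask.sum())
#   1
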
724 | def run(self) -> None:
725 | """
726 | Runs the main experiment, iterating and generating new functions until the tolerance is reached.
727 | """
728 | main_timer_start = time.perf_counter()
729 | if self.save_video:
730 | frames = []
731 |
732 | # Check if one of the generated seed functions is already below the tolerance
733 | best_expr = self.current_functions.get_best_function(return_coeff=True)
734 | best_function = self.current_functions.get_best_function(return_coeff=False)
735 | score = self.current_functions.scores[best_expr]
736 | r2_train, r2_test, r2_all = self.get_R2_scores(best_function)
737 |
738 | if r2_train >= self.tolerance:
739 | self.logger.info(f"The seed function {best_expr} is already above the R2 tolerance {self.tolerance}.")
740 | self.logger.info(f"Best function: {best_function}. R2 (train): {r2_train}.")
741 |
742 | self.results["scores"].append(score) if score != np.inf else self.results["scores"].append("inf")
743 | self.results["best_scores"].append(self.current_functions.scores[best_expr])
744 | self.results["best_scores_normalized"].append(self.current_functions.norm_scores[best_expr])
745 | self.results["best_found_at"] = 0
746 | self.results["temperatures"].append(self.temperature)
747 |
748 | if self.save_video:
749 | frame, ax = self.plotter.record_frame(best_function, best_function, r2_test, self.test_function, -1, plot_true=True)
750 | if self.save_frames:
751 | frame.savefig(self.output_path + "frames/" + "0.png")  # No loop index exists in this branch; save the seed frame as 0.png
752 | frames.append(frame)
753 | else:
754 | # Start the main loop
755 | prompt = self.current_functions.get_prompt(self.base_prompt)
756 | for i in range(self.iterations):
757 | start_time = time.perf_counter()
758 | self.logger.info(f"Round {i+1}.")
759 | self.logger.info(f"Scores: {self.current_functions.scores}.")
760 |
761 | # Handle temperature schedule
762 | if self.temperature_scheduler is not None:
763 | self.temperature = self.temperature_scheduler.get_last_lr()[0]
764 | self.logger.info(f"Temperature: {self.temperature}.")
765 | self.results["temperatures"].append(self.temperature)
766 | self.temperature_scheduler.step()
767 |
768 | # Get new function and score
769 | expr, function, score = self.get_new_function_and_score(prompt)
770 | self.current_functions.add_function(expr, function)
771 | best_expr = self.current_functions.get_best_function(return_coeff=True)
772 | best_function = self.current_functions.functions[best_expr]
773 |
774 | # Update results
775 | if expr == best_expr:
776 | self.results["best_found_at"] = i+1
777 | self.results["iterations"] = i+1
778 | self.results["scores"].append(score) if score != np.inf else self.results["scores"].append("inf")
779 | self.results["best_scores"].append(self.current_functions.scores[best_expr])
780 | self.results["best_scores_normalized"].append(self.current_functions.norm_scores[best_expr])
781 | self.results["times"]["iteration"].append(time.perf_counter() - start_time)
782 |
783 | r2_train, r2_test, r2_all = self.get_R2_scores(function)
784 | self.results["R2_trains"].append(r2_train)
785 | self.results["R2_tests"].append(r2_test)
786 | self.results["R2_alls"].append(r2_all)
787 |
788 | # Update video
789 | if self.save_video:
790 | if score != np.inf:
791 | frame, ax = self.plotter.record_frame(best_function, function, r2_test, self.test_function, i, plot_true=True)
792 | if self.save_frames:
793 | frame.savefig(self.output_path + "frames/" + f"{i}.png")
794 | frames.append(frame)
795 | else:
796 | self.logger.warning(f"Skipping frame {i} as the score is inf.")
797 |
798 | # Check if the tolerance is reached
799 | if self.test_function is not None:
800 | if utils.func_equals(best_function, self.test_function, self.num_variables):
801 | self.logger.info(f"Function is equivalent to the true function.")
802 | self.results["equivalent"] = True
803 | break
804 |
805 | if r2_train >= self.tolerance:
806 | self.logger.info(f"Found a function with R2 (train) = {r2_train} above the tolerance {self.tolerance}.")
807 | break
808 | prompt = self.current_functions.get_prompt(self.base_prompt)
809 |
810 | if i in self.checkpoints:
811 | self.logger.info(f"Checkpoint {i}. Saving results.")
812 | results_checkpoint = copy.deepcopy(self.results)
813 | checkpoint_timer_end = time.perf_counter()
814 | results_checkpoint["times"]["total"] = checkpoint_timer_end - main_timer_start
815 |
816 | best_expr = self.current_functions.get_best_function(return_coeff=True)
817 | best_function = self.current_functions.get_best_function(return_coeff=False)
818 | test_score = self.test_scorer.score(best_function)
819 | results_checkpoint["test_score"] = test_score
820 | results_checkpoint["best_function"] = str(best_function)
821 | results_checkpoint["best_expr"] = str(best_expr)
822 | r2_train, r2_test, r2_all = self.get_R2_scores(best_function)
823 | results_checkpoint["r2_train"] = r2_train
824 | results_checkpoint["r2_test"] = r2_test
825 | results_checkpoint["r2_all"] = r2_all
826 | results_checkpoint["final_complexity"] = utils.count_nodes(best_function)
827 | with open(self.output_path + f"results_checkpoint_{i}.json", "w") as f:
828 | json.dump(results_checkpoint, f)
829 | self.logger.info(f"Checkpoint {i} saved.")
830 |
831 | # Save final results
832 | main_timer_end = time.perf_counter()
833 | best_expr = self.current_functions.get_best_function(return_coeff=True)
834 | best_function = self.current_functions.get_best_function(return_coeff=False)
835 | test_score = self.test_scorer.score(best_function)
836 | self.logger.info(f"Test score: {test_score}.")
837 |
838 | self.logger.info(f"Best function: {best_function}. Score: {self.current_functions.scores[best_expr]} ({self.current_functions.norm_scores[best_expr]}).")
839 |
840 | if hasattr(self, "test_function_name") and self.test_function_name is not None:
841 | self.logger.info(f"True function: {self.test_function_name}")
842 | fig, ax = self.plotter.plot_results(best_function, self.test_function)
843 | fig.savefig(self.output_path + "final.png")
844 |
845 | self.results["best_function"] = str(best_function)
846 | self.results["best_expr"] = str(best_expr)
847 | self.results["test_score"] = test_score
848 | self.results["times"]["total"] = main_timer_end - main_timer_start + self.results["times"]["seed_function_generation"]
849 | self.results["times"]["avg_generation"] = np.mean(self.results["times"]["generation_per_iteration"]) if len(self.results["times"]["generation_per_iteration"]) > 0 else 0
850 | self.results["times"]["avg_optimization"] = np.mean(self.results["times"]["optimization_per_iteration"]) if len(self.results["times"]["optimization_per_iteration"]) > 0 else 0
851 |
852 | r2_train, r2_test, r2_all = self.get_R2_scores(best_function)
853 | self.results["r2_train"] = r2_train
854 | self.results["r2_test"] = r2_test
855 | self.results["r2_all"] = r2_all
856 | self.logger.info(f"R2 train: {np.round(r2_train, 6)}. R2 test: {np.round(r2_test, 6)}. R2 all: {np.round(r2_all, 6)}.")
857 |
858 | final_complexity = utils.count_nodes(best_function)
859 | self.results["final_complexity"] = final_complexity
860 | self.logger.info(f"Number of nodes in final expression tree: {final_complexity}.")
861 |
862 | with open(self.output_path + "results.json", "w") as f:
863 | json.dump(self.results, f)
864 |
865 | if self.save_video and len(frames) > 0:
866 | self.plotter.record_video(frames)
867 |
868 |
869 | @hydra.main(version_base=None, config_path="conf", config_name="config")
870 | def main(cfg: DictConfig) -> None:
871 | workspace = Workspace(cfg)
872 | workspace.run()
873 |
874 | def dump_profile():
875 | profiler.disable()
876 | job_id = utils.get_job_id()
877 | print(f"Dumping profile to {os.path.join(os.getcwd(), 'profiles', 'profile')}_{job_id if job_id is not None else 'local'}")
878 | if not os.path.exists("./profiles"):
879 | os.makedirs("./profiles")
880 | profiler.dump_stats(f"./profiles/profile_{job_id if job_id is not None else 'local'}")
881 |
882 | def signal_handler(sig, frame):
883 | # Ignore warnings, as otherwise we break the logger
884 | warnings.filterwarnings("ignore")
885 | dump_profile()
886 | print(f"Detecting signal {sig}. Dumping profile to {os.path.join(os.getcwd(), 'profiles', 'profile')}_{job_id if job_id is not None else 'local'}")
887 | sys.stdout.flush()
888 | if sig == signal.SIGTERM or sig == signal.SIGINT:
889 | sys.exit(1)
890 |
891 | if __name__ == "__main__":
892 | # Run full profiler if env variable PROFILE is set
893 | do_profile = os.environ.get("PROFILE", False)
894 | print("Initializing profiler.")
895 | print("Profile will only be created if the code fails or is terminated.") if not do_profile else print("Profile will be created.")
896 | job_id = utils.get_job_id()
897 |
898 | # Set termination signal handlers to dump profile when terminated by SLURM
899 | signal.signal(signal.SIGTERM, signal_handler)
900 | signal.signal(signal.SIGINT, signal_handler)
901 | signal.signal(signal.SIGCONT, signal_handler)
902 |
903 | # Setup profiler
904 | global profiler
905 | profiler = cProfile.Profile()
906 | profiler.enable()
907 |
908 | try:
909 | main()
910 | except Exception as e:
911 | # Catch exceptions and dump profile
912 | print("Caught exception in main.")
913 | print(e)
914 | dump_profile()
915 | sys.exit(2)
916 |
917 | if do_profile:
918 | dump_profile()
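
# Typical invocations (editor's sketch; the entry-point file name and the exact
# Hydra overrides are assumptions based on the conf/ layout):
#
#   python run.py experiment/function=nguyen/nguyen1 model=gpt4o
#   PROFILE=1 python run.py ...   # also dump a cProfile profile on success
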
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .hf_model import HuggingFaceModel
2 | from .llava_model_hf import LLaVaModelHF
3 | from .openai_model import OpenAIModel
--------------------------------------------------------------------------------
/models/hf_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import fcntl
3 | import utils
4 |
5 | from transformers import AutoTokenizer, AutoModelForCausalLM
6 |
7 | class HuggingFaceModel(object):
8 | def __init__(self, model_name, device, dtype, cache_dir=None, **kwargs):
9 | self.model_name = model_name
10 | self.device = device
11 | self.dtype = dtype
12 | token = os.environ.get("HF_TOKEN", None)
13 |
14 | self.tokenizer = AutoTokenizer.from_pretrained(model_name, torch_dtype=dtype, cache_dir=cache_dir, token=token)
15 | self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=dtype, cache_dir=cache_dir, token=token, device_map='auto')
16 | self.model.eval()
17 |
18 | self.temperature = kwargs.get("temperature", 1.0)
19 | self.top_k = kwargs.get("top_k", 50)
20 | self.top_p = kwargs.get("top_p", 0.9)
21 | self.num_beams = kwargs.get("num_beams", 1)
22 | self.num_return_sequences = kwargs.get("num_return_sequences", 1)
23 | self.max_new_tokens = kwargs.get("max_new_tokens", 256)
24 | self.min_new_tokens = kwargs.get("min_new_tokens", 0)
25 |
26 | if "tokenizer_pad" in kwargs:
27 | self.tokenizer.pad_token = kwargs["tokenizer_pad"]
28 |
29 | if "tokenizer_padding_side" in kwargs:
30 | self.tokenizer.padding_side = kwargs["tokenizer_padding_side"]
31 |
32 | def generate(self, prompt, return_prompt=False, image_files=None, temperature=None, max_new_tokens=None):
33 | if temperature is None:
34 | temperature = self.temperature
35 | if max_new_tokens is None:
36 | max_new_tokens = self.max_new_tokens
37 |
38 | if '<image>' in prompt:
39 | prompt = prompt.replace('<image>', '') # The <image> token is only meaningful for vision models; this class is assumed to always wrap text models (the vision model, LLaVA, is implemented in a separate class)
40 |
41 | messages = utils.get_messages(prompt)
42 | inputs = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_dict=True, return_tensors="pt").to(self.device)
43 |
44 | outputs = self.model.generate(**inputs, do_sample=True, temperature=temperature, top_k=self.top_k, top_p=self.top_p, num_beams=self.num_beams,
45 | num_return_sequences=self.num_return_sequences, max_new_tokens=max_new_tokens, min_new_tokens=self.min_new_tokens, pad_token_id=self.tokenizer.eos_token_id)
46 |
47 | outputs = outputs[0][len(inputs[0]):] if not return_prompt else outputs[0]
48 | decoded_output = self.tokenizer.decode(outputs, skip_special_tokens=True)
49 |
50 | # Remove llama special words
51 | decoded_output = decoded_output.replace("assistant", "").replace("user", "").replace("system", "")
52 |
53 | return decoded_output
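
# A minimal usage sketch (editor's illustration; the model name is an
# assumption mirroring conf/model/llama3-8b.yaml, and the <system>/<user> tags
# are assumed to be the prompt convention parsed by utils.get_messages):
#
#   import torch
#   model = HuggingFaceModel("meta-llama/Meta-Llama-3-8B-Instruct",
#                            device="cuda", dtype=torch.float16)
#   out = model.generate("<system>You are a function generator.</system>"
#                        "<user>Suggest f(x) for y = 2x + 1.</user>",
#                        temperature=0.7, max_new_tokens=64)
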
--------------------------------------------------------------------------------
/models/llava_model_hf.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
3 |
4 | class LLaVaModelHF(object):
5 | def __init__(self, model_name, device, dtype, cache_dir=None, **kwargs):
6 | self.model_name = model_name
7 | self.device = device
8 | self.dtype = dtype
9 | print(f"Cache dir: {cache_dir}")
10 | use_flash_attention = kwargs.get("use_flash_attn", False)
11 | attn_implementation = "flash_attention_2" if use_flash_attention else None
12 |
13 | self.processor = LlavaNextProcessor.from_pretrained(self.model_name, cache_dir=cache_dir)
14 | self.model = LlavaNextForConditionalGeneration.from_pretrained(self.model_name, torch_dtype=self.dtype, low_cpu_mem_usage=True,
15 | device_map='auto', cache_dir=cache_dir, attn_implementation=attn_implementation)
16 | self.model.eval()
17 | print(f"Model loaded on device: {self.model.device}")
18 |
19 | self.temperature = kwargs.get("temperature", 1.0)
20 | self.top_k = kwargs.get("top_k", 50)
21 | self.top_p = kwargs.get("top_p", 0.9)
22 | self.num_beams = kwargs.get("num_beams", 1)
23 | self.num_return_sequences = kwargs.get("num_return_sequences", 1)
24 | self.max_new_tokens = kwargs.get("max_new_tokens", 256)
25 | self.min_new_tokens = kwargs.get("min_new_tokens", 0)
26 |
27 | def generate(self, prompt, return_prompt=False, image_files=None, temperature=None, max_new_tokens=None):
28 | if temperature is None:
29 | temperature = self.temperature
30 | if max_new_tokens is None:
31 | max_new_tokens = self.max_new_tokens
32 |
33 | if image_files is None and '<image>' in prompt:
34 | prompt = prompt.replace('<image>', '')
35 |
36 | image_path = image_files[0] if type(image_files) == list else image_files
37 | image = Image.open(image_path) if image_files is not None else None
38 | inputs = self.processor(prompt, image, return_tensors="pt").to(self.device)
39 |
40 | outputs = self.model.generate(**inputs, do_sample=True, temperature=temperature, top_k=self.top_k, top_p=self.top_p, num_beams=self.num_beams,
41 | num_return_sequences=self.num_return_sequences, max_new_tokens=max_new_tokens, min_new_tokens=self.min_new_tokens, pad_token_id=self.processor.tokenizer.eos_token_id)
42 |
43 | outputs = outputs[0][len(inputs['input_ids'][0]):] if not return_prompt else outputs[0]
44 | decoded_output = self.processor.tokenizer.decode(outputs, skip_special_tokens=True)
45 |
46 | return decoded_output
47 |
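
# A minimal usage sketch (editor's illustration; the model name is an
# assumption mirroring conf/model/llava1.6-34b.yaml; <image> is the standard
# LLaVA prompt placeholder consumed by the processor):
#
#   import torch
#   model = LLaVaModelHF("llava-hf/llava-v1.6-34b-hf", device="cuda", dtype=torch.float16)
#   out = model.generate("<image>\nDescribe the plotted points.",
#                        image_files=["outputs/points.png"])
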
--------------------------------------------------------------------------------
/models/openai_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import openai
3 | import base64
4 |
5 | class OpenAIModel(object):
6 | def __init__(self, model_name, device, dtype, cache_dir=None, **kwargs):
7 | self.model_name = model_name
8 | self.device = device
9 | self.dtype = dtype
10 |
11 | self.api_key_path = None if "api_key_path" not in kwargs else kwargs["api_key_path"]
12 | self.api_key = self.get_api_key()
13 | self.organization_id_path = None if "organization_id_path" not in kwargs else kwargs["organization_id_path"]
14 | self.organization_id = self.get_org_id()
15 | assert self.api_key is not None, "API key not found."
16 |
17 | self.client = openai.Client(
18 | api_key=self.api_key,
19 | organization=self.organization_id
20 | )
21 |
22 | self.top_p = 0.9 if "top_p" not in kwargs else kwargs["top_p"]
23 | self.temperature = 1.0 if "temperature" not in kwargs else kwargs["temperature"]
24 | self.max_length = 1024 if "max_length" not in kwargs else kwargs["max_length"]
25 | self.num_return_sequences = 1 if "num_return_sequences" not in kwargs else kwargs["num_return_sequences"]
26 | self.seed = None if "seed" not in kwargs else kwargs["seed"]
27 |
28 | def get_api_key(self):
29 | if "OPENAI_API_KEY" in os.environ:
30 | return os.environ["OPENAI_API_KEY"]
31 | elif self.api_key_path is not None:
32 | with open(self.api_key_path, "r") as f:
33 | return f.read().strip()
34 |
35 | return None
36 |
37 | def get_org_id(self):
38 | if "OPENAI_ORG_ID" in os.environ:
39 | return os.environ["OPENAI_ORG_ID"]
40 | elif self.organization_id_path is not None:
41 | with open(self.organization_id_path, "r") as f:
42 | return f.read().strip()
43 |
44 | return None
45 |
46 | def encode_image(self, image_path):
47 | with open(image_path, "rb") as image_file:
48 | return base64.b64encode(image_file.read()).decode('utf-8')
49 |
50 | def get_messages(self, prompt, splits=["system", "user"], image_files=None):
51 | messages = []
52 | for split in splits:
53 | start_tag = f"<{split}>"
54 | end_tag = f"{split}>"
55 | if start_tag not in prompt or end_tag not in prompt:
56 | continue
57 |
58 | start_idx = prompt.find(start_tag)
59 | end_idx = prompt.find(end_tag)
60 |
61 | messages.append({
62 | "role": split,
63 | "content": prompt[start_idx + len(start_tag):end_idx].strip()
64 | })
65 |
66 | if len(messages) == 0:
67 | messages.append({
68 | "role": "user",
69 | "content": prompt
70 | })
71 |
72 | if image_files is not None:
73 | user_index = next((i for i, item in enumerate(messages) if item["role"] == "user"), None)
74 | user_msg = messages[user_index]
75 | user_msg["content"] = [{"type": "text", "text": user_msg["content"]}]
76 | for image_file in (image_files if isinstance(image_files, list) else [image_files]):  # Accept a single path or a list of paths
77 | base64_image = self.encode_image(image_file)
78 | user_msg["content"].append({
79 | "type": "image",
80 | "url": f"data:image/jpeg;base64,{base64_image}"
81 | })
82 |
83 | return messages
84 |
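
# For example (editor's illustration), the prompt
#   "<system>Be terse.</system><user>Suggest a function.</user>"
# is split by get_messages into:
#   [{"role": "system", "content": "Be terse."},
#    {"role": "user", "content": "Suggest a function."}]
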
85 | def generate(self, prompt, return_prompt=False, image_files=None, temperature=None, max_new_tokens=None):
86 | if temperature is None:
87 | temperature = self.temperature
88 | if max_new_tokens is None:
89 | max_new_tokens = self.max_length
90 |
91 | messages = self.get_messages(prompt, image_files=image_files)
92 | response = self.client.chat.completions.create(
93 | messages=messages,
94 | model=self.model_name,
95 | temperature=temperature,
96 | max_tokens=max_new_tokens,
97 | top_p=self.top_p,
98 | n=self.num_return_sequences,
99 | seed=self.seed
100 | )
101 |
102 | if self.num_return_sequences==1:
103 | return response.choices[0].message.content.strip()
104 | else:
105 | return [choice.message.content.strip() for choice in response.choices]
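
# A minimal usage sketch (editor's illustration; assumes OPENAI_API_KEY is set
# in the environment, otherwise pass api_key_path=... as a kwarg):
#
#   model = OpenAIModel("gpt-4o", device=None, dtype=None, temperature=1.0)
#   text = model.generate("<system>You are a function generator.</system>"
#                         "<user>Suggest f(x) for y = x**2.</user>")
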
--------------------------------------------------------------------------------
/optimizer.py:
--------------------------------------------------------------------------------
1 | import re
2 | import sympy
3 | import threading
4 | import warnings
5 | import numpy as np
6 |
7 | from omegaconf import DictConfig
8 | from scipy.optimize import curve_fit
9 | from mloggers import MultiLogger
10 | from typing import Any
11 | from sympy import Mul, Add, Dummy, sift, numbered_symbols
12 |
13 | class Optimizer(object):
14 | """
15 | Optimizer class used to fit a function to a set of points, given a base shape.
16 | For example, takes as input something like "ax^2 + bx + c" and fits (a, b, c) to a set of points.
17 | """
18 | def __init__(self, cfg: DictConfig, points: np.ndarray, logger: MultiLogger) -> None:
19 | """
20 | Initializes the optimizer.
21 |
22 | Parameters
23 | ----------
24 | cfg : DictConfig -> The configuration file.
25 | points : np.ndarray -> The points to fit to.
26 | logger : MultiLogger -> The logger to log to.
27 | """
28 | self.cfg = cfg
29 | self.logger = logger
30 | self.points = points
31 | self.num_variables = cfg.experiment.function.num_variables
32 | self.invalid_coefficients = ["x", "y", "e"]
33 |
34 | self.coeff_rounding = cfg.experiment.optimizer.coeff_rounding if hasattr(cfg.experiment, "optimizer") and hasattr(cfg.experiment.optimizer, "coeff_rounding") else 2
35 | self.tol = cfg.experiment.optimizer.tol if hasattr(cfg.experiment, "optimizer") and hasattr(cfg.experiment.optimizer, "tol") else 1e-3 # Tolerance used to zero out coefficients that are close to 0
36 | self.num_threads = cfg.experiment.optimizer.optimizer_threads if hasattr(cfg.experiment, "optimizer") and hasattr(cfg.experiment.optimizer, "optimizer_threads") else 5
37 | self.timeout = cfg.experiment.optimizer.timeout if hasattr(cfg.experiment, "optimizer") and hasattr(cfg.experiment.optimizer, "timeout") else 10
38 | self.p0_min = cfg.experiment.optimizer.p0_min if hasattr(cfg.experiment, "optimizer") and hasattr(cfg.experiment.optimizer, "p0_min") else -10
39 | self.p0_max = cfg.experiment.optimizer.p0_max if hasattr(cfg.experiment, "optimizer") and hasattr(cfg.experiment.optimizer, "p0_max") else 10
40 |
41 | def replace_coefficients(self, exp: sympy.core.add.Add) -> sympy.core.add.Add:
42 | """
43 | Replaces the number coefficients of a function with symbols.
44 |
45 | Parameters
46 | ----------
47 | exp : sympy.core.add.Add -> The function to replace coefficients of.
48 |
49 | Returns
50 | -------
51 | exp : sympy.core.add.Add -> The function with coefficients replaced.
52 | """
53 | def is_coefficient(symbol: Any) -> bool:
54 | if len(symbol.args) > 0:
55 | for arg in symbol.args:
56 | if not is_coefficient(arg):
57 | return False
58 |
59 | if re.match(r"c\d+", str(symbol)):
60 | return True
61 | elif symbol.is_Dummy:
62 | return True
63 | elif symbol.is_number:
64 | return True
65 |
66 | return False
67 |
68 | # Adapted from https://stackoverflow.com/questions/59686990/replacing-numbers-with-parameters-in-sympy
69 | def nfact2dum(m):
70 | assert m.is_Mul or m.is_Add or m.is_Function
71 | if m.is_Mul:
72 | if not any([is_coefficient(i) for i in m.args]):
73 | return m
74 | nonnum = sift(m.args, lambda i:is_coefficient(i), binary=True)[1]
75 | return Mul(*([Dummy()] + nonnum))
76 | elif m.is_Add:
77 | if not any([is_coefficient(i) for i in m.args]):
78 | return m
79 | nonnum = sift(m.args, lambda i:is_coefficient(i), binary=True)[1]
80 | return Add(*([Dummy()] + nonnum))
81 | elif m.is_Function:
82 | args = []
83 | for arg in m.args:
84 | if arg.is_Mul or arg.is_Add or arg.is_Function:
85 | args.append(nfact2dum(arg))
86 | else:
87 | args.append(arg)
88 | return Dummy() * m.func(*args)
89 |
90 | # Add +1 at the end of the expression to make sure that a constant is included
91 | exp = exp + 1
92 |
93 | # Replace all symbols beginning with c with a dummy
94 | # (as they are coefficients, otherwise we could generate a symbol that is already in the expression)
95 | exp = exp.replace(lambda x: re.match(r"c\d+", str(x)) or str(x).lower() == "c", lambda x: Dummy())
96 |
97 | # Replace all coefficients with dummies
98 | exp = exp.replace(
99 | lambda x:x.is_Mul or x.is_Add or x.is_Function,
100 | lambda x: nfact2dum(x))
101 | # Replace all exponents of dummy variables with 1
102 | exp = exp.replace(lambda x: x.is_Pow and x.base.is_Dummy, lambda x: x.base)
103 |
104 | # Replace all dummies with symbols
105 | exp = exp.subs(list(zip(exp.atoms(Dummy),numbered_symbols('c'))))
106 |
107 | return exp
108 |
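
# Illustration of the intended transformation (editor's sketch; the exact
# numbering and ordering of the generated c-symbols may differ):
#
#   2*x1**2 + sin(3*x1) + 5   ->   c0*x1**2 + c1*sin(c2*x1) + c3
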
109 | def get_optimizable_sympy_exp(self, exp: sympy.core.add.Add, quiet: bool = False) -> Any:
110 | """
111 | Returns a sympy expression that can be optimized by scipy.
112 |
113 | Parameters
114 | ----------
115 | exp : sympy.core.add.Add -> The expression to make optimizable.
116 | quiet : bool -> Whether to log the results.
117 |
118 | Returns
119 | -------
120 | exp : Any -> The optimizable expression.
121 | """
122 | exp = self.replace_coefficients(exp)
123 | self.logger.info("Optimizing function: " + str(exp)) if not quiet else None
124 | symbols = list(exp.free_symbols)
125 |
126 | # Sort symbols so that all x's come first (to find the variables that aren't coefficients)
127 | symbols.sort(key=lambda x: str(x).replace("x", " "))
128 |
129 | # Safety check to make sure that the number of variables is correct
130 | num_variables = len(re.findall(r"x\d*", str(symbols)))
131 | assert num_variables == self.num_variables, f"Number of variables ({num_variables}) does not match number of variables in config ({self.num_variables})"
132 |
133 | symbols = [symbols[:num_variables], *symbols[num_variables:]]
134 | return sympy.lambdify(symbols, exp, "numpy"), exp
135 |
136 | def _run_curve_fit(self, f: Any, num_parameters: int, results: Any, done_event: Any, quiet: bool = True, random_p0: bool = True) -> Any:
137 | """
138 | Runs the curve fit function with a timeout.
139 |
140 | Parameters
141 | ----------
142 | f : Any -> The function to fit.
143 | num_parameters : int -> The number of parameters to fit.
144 | results : Any -> The results list to append to.
145 | done_event : Any -> The event to set when done.
146 | quiet : bool -> Whether to log the results.
147 | random_p0 : bool -> Whether to use random starting points.
148 |
149 | Returns
150 | -------
151 | popt : np.ndarray -> The optimized parameters.
152 | pcov : np.ndarray -> The covariance matrix.
153 | """
154 | p0 = np.random.uniform(self.p0_min, self.p0_max, num_parameters) if random_p0 else np.ones(num_parameters)
155 | popt = None
156 | try:
157 | popt, pcov = curve_fit(f, self.points[:, :-1].T, self.points[:, -1].T, p0=p0)
158 | results.append((popt, pcov))
159 | done_event.set()
160 | return True
161 | except Exception as e:
162 | print(f"Optimization failed: {e}")
163 | pass
164 |
165 | return False
166 |
167 | def optimize(self, exp: sympy.core.add.Add, return_coeff: bool = False, quiet: bool = False) -> sympy.core.add.Add:
168 | """
169 | Optimizes a function to a set of points.
170 |
171 | Parameters
172 | ----------
173 | exp : sympy.core.add.Add -> The base shape to optimize.
174 | return_coeff : bool -> Whether to return the expression in coefficient form.
175 | quiet : bool -> Whether to log the results.
176 |
177 | Returns
178 | -------
179 | exp : sympy.core.add.Add -> The optimized function.
180 | coeff_exp : sympy.core.add.Add -> The optimized function in coefficient form. (Only if return_coeff is True)
181 | """
182 | f, exp = self.get_optimizable_sympy_exp(exp, quiet=quiet)
183 | symbols = exp.free_symbols
184 | symbols = sorted(symbols, key=lambda x: str(x).replace("x", " "))
185 | coefficients = symbols[self.num_variables:]
186 | coeff_exp = exp if return_coeff else None
187 | Xs = self.points[:, :-1].T
188 | ys = self.points[:, -1].T
189 |
190 | # Direct warnings only to console logger as file logger breaks with threading
191 | warnings.filterwarnings("default")
192 | warnings.showwarning = lambda *args, **kwargs: self.logger.warning(str(args[0]), mask=["file"])
193 | self.logger.info("Redirecting warnings to console logger only.")
194 |
195 | # Run curve fit with a timeout with num_threads random starting points
196 | results = []
197 | if self.num_threads == 1:
198 | self.logger.info("Running optimization with 1 attempt.") if not quiet else None
199 | self._run_curve_fit(f, len(coefficients), results=results, done_event=threading.Event(), quiet=quiet, random_p0=False)
200 | popt, pcov = results[0]
201 | else:
202 | done_event = threading.Event()
203 | threads = []
204 | for i in range(self.num_threads):
205 | threads.append(threading.Thread(target=lambda: self._run_curve_fit(f, len(coefficients), results=results, done_event=done_event, quiet=quiet, random_p0=i!=0)))
206 | threads[-1].start()
207 |
208 | done_event.wait(self.timeout)
209 | for thread in threads:
210 | thread.join()
211 | if thread.is_alive():
212 | self.logger.warning(f"Thread {thread} did not finish in time.")
213 | thread._stop()  # Best-effort: Thread._stop() is a private API and does not forcibly kill the thread
214 | self.logger.info("All threads finished.") if not quiet else None
215 | if not done_event.is_set():
216 | raise ValueError("Optimization failed: timeout reached")
217 |
218 | # Direct warnings back to normal
219 | warnings.filterwarnings("default")
220 | warnings.showwarning = lambda *args, **kwargs: self.logger.warning(str(args[0]))
221 | self.logger.info("Redirecting warnings back to normal (both file and console).")
222 |
223 | # Get the best parameters
224 | best_popt = None
225 | best_pcov = None
226 | best_error = np.inf
227 | for popt, pcov in results:
228 | error = np.sum((f(Xs, *popt) - ys) ** 2)
229 | if error < best_error:
230 | best_error = error
231 | best_popt = popt
232 | best_pcov = pcov
233 |
234 | popt = best_popt
235 | pcov = best_pcov
236 |
237 | if pcov is None or np.isinf(pcov).any() or np.isnan(pcov).any():
238 | raise ValueError("Optimization failed: covariance matrix is invalid")
239 | popt = [np.round(x, self.coeff_rounding) for x in popt]
240 | self.logger.info("Optimized parameters: " + str(popt)) if not quiet else None
241 |
242 | assert len(coefficients) == len(popt), f"Number of found coefficients ({len(coefficients)}) does not match number of parameters ({len(popt)})"
243 | zero_subs = {}
244 | for i, coefficient in enumerate(coefficients):
245 | if popt[i] < self.tol and popt[i] > -self.tol:
246 | zero_subs[coefficient] = 0
247 | return exp.subs(list(zip(coefficients, popt))), coeff_exp.subs(zero_subs) if return_coeff else None
248 |
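
# A minimal usage sketch (editor's illustration; cfg, points, and logger are
# assumed to be constructed as in the Workspace class, with points of shape
# (N, num_variables + 1)):
#
#   import sympy
#   opt = Optimizer(cfg, points, logger)
#   fitted, coeff_form = opt.optimize(sympy.sympify("c0*x1**2 + c1*x1"), return_coeff=True)
#   # fitted     -> e.g. 2.0*x1**2 - 0.5*x1 + 1.0 (numeric coefficients)
#   # coeff_form -> the symbolic c-form, with near-zero coefficients substituted by 0
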
--------------------------------------------------------------------------------
/plotter.py:
--------------------------------------------------------------------------------
1 | import PIL
2 | import utils
3 | import tempfile
4 |
5 | import numpy as np
6 | import matplotlib.pyplot as plt
7 |
8 | from mpl_toolkits.mplot3d import axes3d
9 | from omegaconf import DictConfig
10 | from collections.abc import Callable
11 | from typing import List, Any
12 |
13 | class Plotter(object):
14 | def __init__(self, cfg: DictConfig, train_points: np.ndarray, test_points: np.ndarray, output_path: str) -> None:
15 | """
16 | Initializes the plotter.
17 |
18 | Parameters
19 | ----------
20 | cfg : DictConfig -> The configuration file.
21 | train_points : np.ndarray -> The train points.
22 | test_points : np.ndarray -> The test points.
25 | output_path : str -> The path to save the plots to.
26 | """
27 | self.cfg = cfg
28 | self.num_variables = cfg.experiment.function.num_variables
29 | self.output_path = output_path
30 |
31 | self.train_points = train_points
32 | self.test_points = test_points
33 | assert self.test_points.shape[1] == self.num_variables + 1
34 | Xs = self.train_points[:, :-1]
35 | Xs = np.concatenate((Xs, self.test_points[:, :-1]))
36 | self.min_points = [np.min(Xs[:, i]) for i in range(self.num_variables)]
37 | self.max_points = [np.max(Xs[:, i]) for i in range(self.num_variables)]
38 |
39 | num_test = self.cfg.plotter.plotter_resolution if hasattr(self.cfg, "plotter") and hasattr(self.cfg.plotter, "plotter_resolution") else 1000
40 | if self.num_variables == 1:
41 | self.Xs_test = np.linspace(self.min_points[0], self.max_points[0], num_test).reshape(-1, 1)
42 | elif self.num_variables == 2:
43 | num_test = np.floor(np.sqrt(num_test)).astype(int)
44 | self.Xs_test = np.meshgrid(*[np.linspace(self.min_points[i], self.max_points[i], num_test) for i in range(self.num_variables)])
45 |
46 | self.gif_duration = self.cfg.plotter.gif_duration if hasattr(self.cfg, "plotter") and hasattr(self.cfg.plotter, "gif_duration") else 1000
47 | self.fig_size = (self.cfg.plotter.plotter_fig_size, self.cfg.plotter.plotter_fig_size) if hasattr(self.cfg, "plotter") and hasattr(self.cfg.plotter, "plotter_fig_size") else (10, 10)
48 | self.function_cache = {}
49 |
50 | def _eval_function(self, function: Any, Xs: np.ndarray, num_variables: int) -> np.ndarray:
51 | if function in self.function_cache:
52 | return self.function_cache[function]
53 | else:
54 | ys = utils.eval_function(function, Xs, num_variables)
55 | self.function_cache[function] = ys
56 | return ys
57 |
58 | def plot_points(self, save_fig: bool = False, save_path: str = "points.png", plot_test=False) -> None:
59 | """
60 | Plots a set of points. Used to feed to visual models.
61 |
62 | Parameters
63 | ----------
64 | save_fig : bool -> Whether to save the figure or not.
65 | save_path : str -> The path to save the figure to.
66 | """
67 | if self.num_variables > 2:
68 | return
69 |
70 | save_path = self.output_path + save_path
71 | if self.num_variables == 1:
72 | plt.figure(figsize=self.fig_size)
73 | plt.scatter(self.train_points[:, 0], self.train_points[:, 1], color="blue", label="Train points", zorder=1)
74 | if plot_test:
75 | plt.scatter(self.test_points[:, 0], self.test_points[:, 1], color="red", label="Test points", alpha=.25, zorder=1)
76 | # plt.grid(alpha=.4,linestyle='--')
77 | plt.xlabel("x")
78 | plt.ylabel("f(x)")
79 | plt.xlim(self.min_points[0] - 0.1 * (self.max_points[0] - self.min_points[0]), self.max_points[0] + 0.1 * (self.max_points[0] - self.min_points[0]))
80 | plt.ylim(np.min(self.train_points[:, 1]) - 0.1 * (np.max(self.train_points[:, 1]) - np.min(self.train_points[:, 1])), np.max(self.train_points[:, 1]) + 0.1 * (np.max(self.train_points[:, 1]) - np.min(self.train_points[:, 1])))
81 | plt.legend()
82 | elif self.num_variables == 2:
83 | ax = plt.figure(figsize=self.fig_size).add_subplot(projection='3d')
84 | ax.scatter(self.train_points[:, 0], self.train_points[:, 1], self.train_points[:, 2], c='b', label='Train points')
85 | if plot_test:
86 | ax.scatter(self.test_points[:, 0], self.test_points[:, 1], self.test_points[:, 2], c='r', label='Test points', alpha=.25)
87 | plt.xlabel("x1")
88 | plt.ylabel("x2")
89 | ax.legend(loc="upper right")
90 | else:
91 | raise ValueError("Invalid number of variables.")
92 |
93 | plt.tight_layout()
94 |
95 | if save_fig:
96 | plt.savefig(save_path)
97 | else:
98 | return plt.gcf(), plt.gca()
99 |
100 | def plot_results(self, function: Any, test_function: Any = None, plot_true: bool = True, label: str = "Model's best guess") -> plt.Figure:
101 | """
102 | Plots the results of the experiment, showing the test points, the true function, and the model's best guess.
103 |
104 | Parameters
105 | ----------
106 | function : str -> The model's best guess.
107 | test_function : Callable[[float], float] -> The true function.
108 | plot_true : bool -> Whether to plot the true function or not.
109 | label : str -> The label of the model's guess.
110 |
111 | Returns
112 | -------
113 | plt.Figure -> The figure of the plot.
114 | """
115 | fig, ax = self.plot_points(plot_test=False)
116 | if test_function is None:
117 | plot_true = False
118 |
119 | if self.num_variables == 1:
120 | if plot_true:
121 | ys_test = self._eval_function(test_function, self.Xs_test, self.num_variables)
122 | plt.plot(self.Xs_test, ys_test, color="red", label="True function", zorder=0, linestyle="--")
123 |
124 | ys = self._eval_function(function, self.Xs_test, self.num_variables)
125 | plt.plot(self.Xs_test, ys, color="green", label=label, zorder=0)
126 | plt.legend(loc="lower right")
127 |
128 | elif self.num_variables == 2:
129 | X1, X2 = self.Xs_test
130 | if plot_true:
131 | Z_test = self._eval_function(test_function, np.array([X1, X2]), self.num_variables).reshape(X1.shape)
132 | ax.plot_surface(X1, X2, Z_test, edgecolor='orangered', lw=0.25, alpha=0.1, label="True function", color="red")
133 |
134 | Z = self._eval_function(function, np.array([X1, X2]), self.num_variables).reshape(X1.shape)
135 | ax.plot_surface(X1, X2, Z, edgecolor='mediumseagreen', lw=0.5, alpha=0.3, label=label, color="green")
136 | ax.legend(loc="upper right")
137 |
138 | return plt.gcf(), plt.gca()
139 |
140 | def record_frame(self, best_function: str, last_function: str, score: float, test_function: Callable[[float], float], round: int, plot_true : bool = True) -> plt.Figure:
141 | """
142 | Records a frame of the animation.
143 |
144 | Parameters
145 | ----------
146 | best_function : str -> The model's best guess.
147 | last_function : str -> The model's last guess.
148 | score : float -> The score of the best function.
149 | test_function : Callable[[float], float] -> The true function.
150 | round : int -> The round number.
151 | plot_true : bool -> Whether to plot the true function or not.
152 |
153 | Returns
154 | -------
155 | plt.Figure -> The figure of the frame.
156 | """
157 | fig, ax = self.plot_results(best_function, test_function, plot_true=plot_true)
158 | if self.num_variables == 1:
159 | ys_last = self._eval_function(last_function, self.Xs_test, self.num_variables)
160 | plt.plot(self.Xs_test, ys_last, color="orange", label="Last guess")
161 | plt.legend()
162 | elif self.num_variables == 2:
163 | X1, X2 = self.Xs_test
164 | Z_last = self._eval_function(last_function, np.array([X1, X2]), self.num_variables).reshape(X1.shape)
165 | ax.plot_surface(X1, X2, Z_last, edgecolor='orange', lw=0.5, alpha=0.3, label="Last guess", color="orange")
166 | ax.text2D(0.05, 0.95, f"Score: {score:.3f}", transform=ax.transAxes, fontsize=10, verticalalignment='top')
167 | ax.legend(loc="upper right")
168 |
169 | ax.set_title(f"Round {round+1}, R2: {score:.5f}")
170 | fig.tight_layout()
171 | return plt.gcf(), plt.gca()
172 |
173 | def record_video(self, frames: List[plt.Figure]) -> None:
174 | """
175 | Records the animation from the frames buffer.
176 |
177 | Parameters
178 | ----------
179 | frames : List[plt.Figure] -> The frames buffer.
180 | """
181 | images = []
182 | with tempfile.TemporaryDirectory() as tmp_path:
183 | for i, frame in enumerate(frames):
184 | frame.savefig(f"{tmp_path}/{i}.png")  # TemporaryDirectory names have no trailing separator
185 | for i in range(len(frames)):
186 | images.append(PIL.Image.open(f"{tmp_path}/{i}.png"))
187 | # Extend the last frame to make the last (best) result more visible
188 | for _ in range(5):
189 | images.append(images[-1])
190 |
191 | images[0].save(self.output_path + "animation.gif", save_all=True, append_images=images[1:], duration=self.gif_duration, loop=0)
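
# A minimal usage sketch (editor's illustration; the point arrays are assumed
# to have shape (N, num_variables + 1) with the function value last):
#
#   plotter = Plotter(cfg, train_points, test_points, output_path="outputs/run1/")
#   fig, ax = plotter.plot_results(best_function, test_function)
#   fig.savefig(plotter.output_path + "final.png")
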
--------------------------------------------------------------------------------
/prompts/OPRO/basic_image.txt:
--------------------------------------------------------------------------------
1 | I want you to act as a mathematical function generator.
2 | You are given an image of a set of points, plotted as a scatter plot on a graph.
3 | Below are some previous functions and the error they make on the points above. The errors are arranged in order of their fit values, with the highest values coming first, and lower is better.
4 |
5 | Your task is to give me a list of five new functions that are different from all the ones reported below, and have a lower error value than all of the ones below. Only output the new functions in a list and nothing else.
6 | Remember that the functions you generate should always have at most {num_variables} variables {variables_list}.
7 | The functions should have parametric form, using 'c' in place of any constant or coefficient. The coefficients will be optimized to fit the data. Make absolutely sure that the functions you generate are completely different from the ones already given to you.
8 | Remember that you can combine the simple building blocks (operations, constants, variables) in any way you want to generate more complex functions. Don't be afraid to experiment!
9 | The functions should all begin with the indicators "f1(x) = ", "f2(x) = "...
10 |
11 | {functions}
--------------------------------------------------------------------------------
/prompts/OPRO/basic_mixed.txt:
--------------------------------------------------------------------------------
1 | I want you to act as a mathematical function generator.
2 | You are given an image of a graph with a set of points plotted as a scatter plot, with (x, y) coordinates below:
3 | {points}
4 | Below are some previous functions and the error they make on the points above. The errors are arranged in order of their fit values, with the highest values coming first, and lower is better.
5 |
6 | Your task is to give me a list of five new functions that are different from all the ones reported below, and have a lower error value than all of the ones below. Only output the new functions and nothing else.
7 | Remember that the functions you generate should always have at most {num_variables} variables {variables_list}.
8 | The functions should have parametric form, using 'c' in place of any constant or coefficient. The coefficients will be optimized to fit the data. Make absolutely sure that the functions you generate are completely different from the ones already given to you.
9 | Remember that you can combine the simple building blocks (operations, constants, variables) in any way you want to generate more complex functions. Don't be afraid to experiment!
10 | The functions should all begin with the indicators "f1(x) = ", "f2(x) = "...
11 |
12 | {functions}
--------------------------------------------------------------------------------
/prompts/OPRO/basic_text.txt:
--------------------------------------------------------------------------------
1 | I want you to act as a mathematical function generator.
2 | You are given a set of points with (x, y) coordinates below:
3 | {points}
4 | Below are some previous functions and the error they make on the points above. The errors are arranged in order of their fit values, with the highest values coming first, and lower is better.
5 |
6 | Your task is to give me a list of five new functions that are different from all the ones reported below, and have a lower error value than all of the ones below. Only output the new functions and nothing else.
7 | Remember that the functions you generate should always have at most {num_variables} variables {variables_list}.
8 | The functions should have parametric form, using 'c' in place of any constant or coefficient. The coefficients will be optimized to fit the data. Make absolutely sure that the functions you generate are completely different from the ones already given to you.
9 | Remember that you can combine the simple building blocks (operations, constants, variables) in any way you want to generate more complex functions. Don't be afraid to experiment!
10 | The functions should all begin with the indicators "f1(x) = ", "f2(x) = "...
11 |
12 | {functions}
--------------------------------------------------------------------------------
/prompts/OPRO/image_all.txt:
--------------------------------------------------------------------------------
1 | I want you to act as a mathematical function generator.
2 | You are given an image of a graph with a set of points plotted as a scatter plot in blue.
3 | The coordinates of the points are: {points}
4 | Below are some previous functions and the error they make on the points above. The errors are arranged in order of their fit values, with the highest values coming first, and lower is better. These functions are also shown in the image in the form of a green line and the formula for the plotted function can be seen in the title.
5 |
6 | Your task is to give me a list of five new potential functions that are different from all the ones reported below, and have a lower error value than all of the functions below. Only output the new functions and nothing else.
7 | Remember that the functions you generate should always have at most {num_variables} variables {variables_list}.
8 | The functions should have parametric form, using 'c' in place of any constant or coefficient. The coefficients will be optimized to fit the data. Make absolutely sure that the functions you generate are completely different from the ones already given to you.
9 | The functions should all begin with the indicators "f1(x) = ", "f2(x) = "...
10 | Remember that you can combine the simple building blocks (operations, constants, variables) in any way you want to generate more complex functions. Don't be afraid to experiment!
11 |
12 | {functions}
--------------------------------------------------------------------------------
/prompts/OPRO/image_best.txt:
--------------------------------------------------------------------------------
1 | I want you to act as a mathematical function generator.
2 | You are given an image of a graph with a set of points plotted as a scatter plot in blue.
3 |
4 | The coordinates of the points are: {points}
5 | Below are some previous functions and the errors they make on the points above. The errors are arranged from highest to lowest; lower is better. The function with the lowest error (the best function so far) is also shown in the image as a green line.
6 |
7 | Your task is to give me a list of five new functions that are different from all the ones reported below, and have a lower error value than all of the ones below. Only output the new functions and nothing else.
8 | Remember that the functions you generate should always have at most {num_variables} variables {variables_list}.
9 | The functions should have parametric form, using 'c' in place of any constant or coefficient. The coefficients will be optimized to fit the data. Make absolutely sure that the functions you generate are completely different from the ones already given to you.
10 | The functions should all begin with the indicators "f1(x) = ", "f2(x) = "...
11 | Remember that you can combine the simple building blocks (operations, constants, variables) in any way you want to generate more complex functions. Don't be afraid to experiment!
12 |
13 | {functions}
--------------------------------------------------------------------------------
/prompts/OPRO/no_info.txt:
--------------------------------------------------------------------------------
1 | Generate five random functions of the form Function: f(x). The functions you generate should always have at most {num_variables} variables {variables_list}.
2 | Only output the functions and nothing else.
--------------------------------------------------------------------------------
/prompts/seed_functions/generate_seed.txt:
--------------------------------------------------------------------------------
1 | I want you to act as a mathematical function generator.
2 | Given a set of points below, you are to come up with 5 potential functions that would fit the points. Don't worry too much about accuracy: your task is to generate a set of functions that are as diverse as possible, so that they can serve as starting points for further optimization.
3 |
4 | To generate the functions, you will start from a set of basic operators and expressions, and combine them into something more complex.
5 | Your options are:
6 | - {num_variables} independent variable symbols: {variables_list}.
7 | - A coefficient symbol: c (there is no need to write a number - write this generic coefficient instead).
8 | - Basic operators: +, -, *, /, ^, sqrt, exp, log, abs
9 | - Trigonometric expressions: sin, cos, tan, sinh, cosh, tanh
10 | - Standard constants: "pi" represents pi and "E" represents Euler's number e.
11 |
12 | Make sure there are no numbers in the functions; use the coefficient token 'c' instead.
13 | You are required to use at least one of each variable from the available list ({variables_list}).
14 | Analyze the points carefully: if any of the input values are negative, sqrt and log cannot be used unless their arguments can never be negative. Be careful about dividing by zero!
15 | The functions should all begin with the indicators "f1 = ", "f2 = "... Only write the new functions and no additional output.
16 | Your task is to combine an arbitrary number of these basic blocks to create a complex expression. Don't be afraid to be creative and experiment! The functions should be as complex as possible, combining many different operations. Variety is key!
17 |
18 | Points: {points}
19 |
20 | Functions:
21 |
--------------------------------------------------------------------------------
/prompts/seed_functions/generate_seed_image.txt:
--------------------------------------------------------------------------------
1 | I want you to act as a mathematical function generator.
2 | Given a set of points below, you are to come up with 5 potential functions that would fit the points. Don't worry too much about accuracy: your task is to generate a set of functions that are as diverse as possible, so that they can serve as starting points for further optimization.
3 |
4 | To generate the functions, you will start from a set of basic operators and expressions, and combine them into something more complex.
5 | Your options are:
6 | - {num_variables} independent variable symbols: {variables_list}.
7 | - A coefficient symbol: c (there is no need to write a number - write this generic coefficient instead).
8 | - Basic operators: +, -, *, /, ^, sqrt, exp, log, abs
9 | - Trigonometric expressions: sin, cos, tan, sinh, cosh, tanh
10 | - Standard constants: "pi" represents pi and "E" represents Euler's number e.
11 |
12 | Make sure there are no numbers in the functions; use the coefficient token 'c' instead.
13 | You are required to use at least one of each variable from the available list ({variables_list}).
14 | Analyze the points carefully: if any of the input values are negative, sqrt and log cannot be used unless their arguments can never be negative. Be careful about dividing by zero!
15 | The functions should all begin with the indicators "f1 = ", "f2 = "... Only write the new functions and no additional output.
16 | Your task is to combine an arbitrary number of these basic blocks to create a complex expression. Don't be afraid to be creative and experiment! The functions should be as complex as possible, combining many different operations. Variety is key!
17 |
18 | Points: {points}
19 |
20 |
21 | Functions:
22 |
--------------------------------------------------------------------------------
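A note on the prompt templates above: {points}, {num_variables}, {variables_list}, and {functions} are Python format fields. Below is a minimal sketch of how a template might be filled, assuming the templates are rendered with str.format (the sample data and the rendered function list are hypothetical; array_to_string is defined in utils.py later in this dump):

    import numpy as np
    from utils import array_to_string

    with open("prompts/OPRO/basic_text.txt") as f:
        template = f.read()

    # Hypothetical sample points and one previously scored function
    points = np.array([[0.0, 1.0], [0.5, 1.6], [1.0, 2.7]])
    prompt = template.format(
        points=array_to_string(points),
        num_variables=1,
        variables_list="x1",
        functions="f1(x) = c*exp(c*x1), Error: 0.12",
    )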
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.23.0
2 | antlr4-python3-runtime==4.9.3
3 | bitsandbytes==0.41.1
4 | fire==0.6.0
5 | huggingface-hub==0.23.2
6 | hydra-core==1.3.2
7 | ipykernel==6.25.2
8 | ipython==8.16.1
9 | ipython-genutils==0.2.0
10 | ipywidgets==8.1.1
11 | jupyter==1.0.0
12 | jupyterlab==4.0.7
13 | matplotlib==3.8.0
14 | matplotlib-inline==0.1.6
15 | mloggers==1.3.1
16 | numpy==1.26.1
17 | omegaconf==2.3.0
18 | openai==1.29.0
19 | pandas==2.2.0
20 | safetensors==0.4.2
21 | scikit-image==0.23.2
22 | scikit-learn==1.3.1
23 | scipy==1.11.3
24 | sympy==1.12.1
25 | tokenizers==0.19.1
26 | torch==2.1.0
27 | torchaudio==2.1.0
28 | torchvision==0.16.0
29 | tqdm==4.66.1
30 | transformers==4.41.2
--------------------------------------------------------------------------------
/scorers/__init__.py:
--------------------------------------------------------------------------------
1 | from .scorer import Scorer
2 | from .basic_scorer import BasicScorer
3 | from .minmax_scorer import MinMaxScorer
4 | from .complexity_scorer import ComplexityScorer
--------------------------------------------------------------------------------
/scorers/basic_scorer.py:
--------------------------------------------------------------------------------
1 | from .scorer import Scorer
2 | from utils import format_exp
3 | from sklearn.metrics import mean_squared_error
4 | import numpy as np
5 | import utils
6 | import sys
7 |
8 | from typing import Dict, Tuple
9 |
10 | class BasicScorer(Scorer):
11 | """
12 | Basic scorer for symbolic regression.
13 | Scores the function using unnormalized MSE.
14 | """
15 |
16 | def __init__(self, points: np.ndarray, rounding: int = 2, scientific: bool = False):
17 | """
18 | Initialize the scorer.
19 |
20 | Parameters
21 | ----------
22 | points -> the points to evaluate the function on.
23 | rounding -> number of decimal places to round the score to (-1 for no rounding)
24 | scientific -> whether to use scientific notation for the score.
25 |
26 | Returns
27 | -------
28 | None.
29 | """
30 |
31 | super().__init__(points)
32 | self.round = rounding
33 | self.scientific = scientific
34 |
35 | def score(self, function: callable) -> float:
36 | """
37 | Scores a function on a given set of points.
38 |
39 | Parameters
40 | ----------
41 | function -> the function to score.
42 | (rounding and scientific notation follow the settings passed
43 | to the scorer's constructor, not arguments to this method)
44 |
45 | Returns
46 | -------
47 | score -> the score of the function.
48 | """
49 | xs = self.points[:, 0:-1]
50 | ys = self.points[:, -1]
51 | num_variables = xs.shape[1]
52 |
53 | try:
54 | ys_pred = utils.eval_function(function, xs, num_variables)
55 | fit = mean_squared_error(ys, ys_pred)
56 | except Exception as e:
57 | # If the function is invalid for some points (e.g. division by 0), return inf
58 | fit = np.inf
59 |
60 | if self.round > 0:
61 | fit = np.round(fit, self.round)
62 | fit = np.float64(fit)  # np.inf is a plain Python float and has no .astype
63 | return fit
64 |
65 | def score_current_functions(self, current_functions: Dict[str, callable]) -> Tuple[Dict, Dict]:
66 | """
67 | Scores the current functions in the prompt.
68 |
69 | Parameters
70 | ----------
71 | current_functions -> the current functions in the prompt.
72 | (rounding and scientific notation follow the settings passed
73 | to the scorer's constructor, not arguments to this method)
74 |
75 | Returns
76 | -------
77 | scores -> the score of the current functions.
78 | normalized_scores -> the normalized score of the current functions.
79 | """
80 | scores = { function: self.score(current_functions[function]) for function in current_functions }
81 | normalized_scores = scores.copy()
82 | if self.scientific:
83 | normalized_scores = { name: format_exp(score, self.round) for name, score in normalized_scores.items()}
84 |
85 | return scores, normalized_scores
--------------------------------------------------------------------------------
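For context, a minimal usage sketch of the scorer above (the toy data and the candidate expression are hypothetical; the x1, x2, ... variable naming follows the conventions used in utils.py below):

    import numpy as np
    import sympy
    from scorers import BasicScorer

    # Toy data: earlier columns are the variables, the last column is the target
    xs = np.linspace(0, 1, 20)
    points = np.column_stack([xs, 2 * xs + 1])

    scorer = BasicScorer(points, rounding=4)
    candidate = sympy.parse_expr("2*x1 + 1")
    print(scorer.score(candidate))  # unnormalized MSE, ~0.0 for the true function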
/scorers/complexity_scorer.py:
--------------------------------------------------------------------------------
1 | from .scorer import Scorer
2 | from utils import format_exp
3 | from sklearn.metrics import mean_squared_error
4 | import numpy as np
5 | import utils
6 | import sys
7 |
8 | from typing import Dict, Tuple
9 |
10 | class ComplexityScorer(Scorer):
11 | """
12 | Complexity scorer for symbolic regression.
13 | Scores the function using normalized MSE combined with a measure for the complexity of the function.
14 | The score measure was defined in https://arxiv.org/abs/2303.06833
15 | """
16 |
17 | def __init__(self, points: np.ndarray, rounding: int = 2, scientific: bool = False, lam: float = 0.5, max_nodes: int = 30, alternative: bool = False):
18 | """
19 | Initialize the scorer.
20 |
21 | Parameters
22 | ----------
23 | points -> the points to evaluate the function on.
24 | rounding -> number of decimal places to round the score to (-1 for no rounding)
25 | scientific -> whether to use scientific notation for the score.
26 | lam -> the lambda parameter for the complexity term.
27 | max_nodes -> the maximum number of nodes in the expression tree (used for normalization).
28 | alternative -> whether to use the alternative scoring function.
29 |
30 | Returns
31 | -------
32 | None.
33 | """
34 |
35 | super().__init__(points)
36 | self.round = rounding
37 | self.scientific = scientific
38 | self.eps = 1e-6
39 | self.lam = lam
40 | self.max_nodes = max_nodes
41 | self.alternative = alternative
42 |
43 | def score(self, function: callable) -> float:
44 | """
45 | Scores a function on a given set of points.
46 |
47 | Parameters
48 | ----------
49 | function -> the function to score.
50 | (rounding and scientific notation follow the settings passed
51 | to the scorer's constructor, not arguments to this method)
52 |
53 | Returns
54 | -------
55 | score -> the score of the function.
56 | """
57 | xs = self.points[:, 0:-1]
58 | ys = self.points[:, -1]
59 | n = xs.shape[0]
60 | num_variables = xs.shape[1]
61 |
62 | try:
63 | ys_pred = utils.eval_function(function, xs, num_variables)
64 | # Calculate normalized MSE
65 | fit = 1/n * np.linalg.norm(ys - ys_pred)**2 / (1/n * np.linalg.norm(ys)**2 + self.eps)
66 | except Exception as e:
67 | # If the function is invalid for some points (e.g. division by 0), return inf
68 | fit = np.inf
69 | fit = np.float64(fit)  # np.inf is a plain Python float and has no .astype
70 |
71 | complexity = utils.count_nodes(function)
72 | complexity_term = np.exp(-complexity/self.max_nodes)
73 |
74 | if self.alternative:
75 | error = fit + self.lam * (1-complexity_term)
76 | else:
77 | fit_term = 1/(1 + fit)
78 | error = 1/(fit_term + self.lam * complexity_term + self.eps)
79 |
80 | if self.round > 0:
81 | error = np.round(error, self.round)
82 |
83 | return error
84 |
85 | def score_current_functions(self, current_functions: Dict[str, callable]) -> Tuple[Dict, Dict]:
86 | """
87 | Scores the current functions in the prompt.
88 |
89 | Parameters
90 | ----------
91 | current_functions -> the current functions in the prompt.
92 | (rounding and scientific notation follow the settings passed
93 | to the scorer's constructor, not arguments to this method)
94 |
95 | Returns
96 | -------
97 | scores -> the score of the current functions.
98 | normalized_scores -> the normalized score of the current functions.
99 | """
100 | scores = { function: self.score(current_functions[function]) for function in current_functions }
101 | normalized_scores = scores.copy()
102 | if self.scientific:
103 | normalized_scores = { name: format_exp(score, self.round) for name, score in normalized_scores.items()}
104 |
105 | return scores, normalized_scores
--------------------------------------------------------------------------------
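To make the default scoring branch above concrete, here is the formula evaluated by hand for hypothetical values of the normalized MSE and the expression-tree node count:

    import numpy as np

    lam, max_nodes, eps = 0.5, 30, 1e-6  # defaults from __init__ above
    fit, complexity = 0.04, 12           # hypothetical normalized MSE and node count

    fit_term = 1 / (1 + fit)                           # ~0.96, approaches 1 as the fit improves
    complexity_term = np.exp(-complexity / max_nodes)  # ~0.67, approaches 1 for simpler trees
    error = 1 / (fit_term + lam * complexity_term + eps)
    print(error)  # ~0.77; both a better fit and a simpler tree lower the error

Note that with lam = 0 the score reduces to 1/(fit_term + eps), a monotone transform of the normalized MSE alone, so the complexity term only ever rewards simpler functions.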
/scorers/minmax_scorer.py:
--------------------------------------------------------------------------------
1 | from .scorer import Scorer
2 | from utils import format_exp
3 | from sklearn.metrics import mean_squared_error
4 | from sklearn.preprocessing import MinMaxScaler
5 | import numpy as np
6 | import utils
7 |
8 | from typing import Dict, Tuple
9 |
10 | class MinMaxScorer(Scorer):
11 | """
12 | MinMax scorer for symbolic regression.
13 | Scores each function using MSE; the batch of scores is then
14 | rescaled to the [min_score, max_score] interval with MinMaxScaler.
15 | """
16 |
17 | def __init__(self, points: np.ndarray, min_score: float, max_score: float, rounding: int = 2, scientific: bool = False):
18 | """
19 | Initialize the scorer.
20 |
21 | Parameters
22 | ----------
23 | points -> the points to evaluate the function on.
24 | min_score -> the minimum value for the interval to scale the scores to.
25 | max_score -> the maximum value for the interval to scale the scores to.
26 | rounding -> number of decimal places to round the score to (-1 for no rounding)
27 | scientific -> whether to use scientific notation for the score.
28 |
29 | Returns
30 | -------
31 | None.
32 | """
33 |
34 | super().__init__(points)
35 | self.round = rounding
36 | self.scientific = scientific
37 | self.min_score = min_score
38 | self.max_score = max_score
39 | self.scaler = MinMaxScaler(feature_range=(min_score, max_score))
40 |
41 | def score(self, function: callable) -> float:
42 | """
43 | Scores a function on a given set of points.
44 |
45 | Parameters
46 | ----------
47 | function -> the function to score.
48 | (rounding and scientific notation follow the settings passed
49 | to the scorer's constructor, not arguments to this method)
50 |
51 | Returns
52 | -------
53 | score -> the score of the function.
54 | """
55 | xs = self.points[:, 0:-1]
56 | ys = self.points[:, -1]  # last column holds the target values, as in the other scorers
57 | num_variables = xs.shape[1]
58 |
59 | try:
60 | ys_pred = utils.eval_function(function, xs, num_variables)
61 | fit = mean_squared_error(ys, ys_pred)
62 | except Exception:
63 | # If the function is invalid for some points (e.g. division by 0), return inf
64 | fit = np.inf
65 |
66 | if self.round > 0:
67 | fit = np.round(fit, self.round)
68 | return np.float64(fit)  # np.inf is a plain Python float and has no .astype
69 |
70 | def score_current_functions(self, current_functions: Dict[str, callable]) -> Tuple[Dict, Dict]:
71 | """
72 | Scores the current functions in the prompt.
73 |
74 | Parameters
75 | ----------
76 | current_functions -> the current functions in the prompt.
77 |
78 | Returns
79 | -------
80 | scores, normalized_scores -> the raw and the min-max-normalized scores of the current functions.
81 | """
82 | scores = { function: self.score(current_functions[function]) for function in current_functions }
83 | inf_scores = { name: score for name, score in scores.items() if score == np.inf }
84 | # Remove inf scores from the list of scores to normalize
85 | # Can't remove them entirely otherwise this would cause inconsistencies (current_functions would be missing some scores). We handle those in the main loop
86 | if len(inf_scores) > 0:
87 | for name in inf_scores:
88 | scores.pop(name)
89 | scores_array = np.array(list(scores.values())).reshape(-1, 1)
90 | if len(scores_array) != 0:
91 | self.scaler.fit(scores_array)
92 | normalized_scores = { name: self.scaler.transform(np.array(score).reshape(-1, 1))[0][0] for name, score in scores.items() }
93 | else:
94 | normalized_scores = scores.copy()
95 |
96 | # Add inf scores back to the list of scores
97 | for name in inf_scores:
98 | scores[name] = np.inf
99 | normalized_scores[name] = np.inf
100 | if self.round > 0 and not self.scientific:
101 | normalized_scores = { name: np.round(score, self.round) for name, score in normalized_scores.items() }
102 | elif self.scientific:
103 | normalized_scores = { name: format_exp(score, self.round) for name, score in normalized_scores.items() }
104 | return scores, normalized_scores
--------------------------------------------------------------------------------
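A sketch of the inf-handling path above, using a hypothetical batch in which one candidate is invalid on the data (sqrt of a negative argument):

    import numpy as np
    import sympy
    from scorers import MinMaxScorer

    xs = np.linspace(1, 2, 10)
    points = np.column_stack([xs, xs**2])
    scorer = MinMaxScorer(points, min_score=0, max_score=100)

    funcs = {
        "f1": sympy.parse_expr("x1**2"),
        "f2": sympy.parse_expr("x1 + 1"),
        "f3": sympy.parse_expr("sqrt(x1 - 5)"),  # complex on [1, 2] -> score becomes inf
    }
    scores, normalized = scorer.score_current_functions(funcs)
    # scores holds the raw MSEs (inf for f3); normalized rescales the finite
    # scores onto [0, 100] and keeps the inf entries as inf.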
/scorers/scorer.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Dict, Tuple
3 |
4 | import numpy as np
5 |
6 | class Scorer(ABC):
7 | """
8 | Abstract class for different scoring methods to evaluate the goodness of a function on a given set of points.
9 | """
10 |
11 | def __init__(self, points: np.ndarray, rounding: int = 2, scientific: bool = False):
12 | """
13 | Initialize the scorer.
14 |
15 | Parameters
16 | ----------
17 | points -> the points to evaluate the function on.
18 | rounding -> number of decimal places to round the score to (-1 for no rounding)
19 | scientific -> whether to use scientific notation for the score.
20 |
21 | Returns
22 | -------
23 | None.
24 | """
25 |
26 | self.points = points
27 | # self.tolerance = 1e-6 # Add tolerance to avoid division by 0
28 | # self.points[:, 0:-1] = self.points[:, 0:-1] + self.tolerance
29 |
30 | @abstractmethod
31 | def score(self, function: callable) -> float:
32 | """
33 | Scores a function on a given set of points.
34 |
35 | Parameters
36 | ----------
37 | function -> the function to score.
38 |
39 | Returns
40 | -------
41 | score -> the score of the function.
42 | """
43 |
44 | pass
45 |
46 | @abstractmethod
47 | def score_current_functions(self, current_functions: Dict[str, callable]) -> Tuple[Dict, Dict]:
48 | """
49 | Scores the current functions in the prompt.
50 |
51 | Parameters
52 | ----------
53 | current_functions -> the current functions in the prompt.
54 |
55 | Returns
56 | -------
57 | scores, normalized_scores -> the raw and the normalized scores of the current functions.
58 | """
59 |
60 | pass
61 |
--------------------------------------------------------------------------------
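The two abstract methods above are all a new scorer needs to implement. Below is a minimal sketch of a hypothetical subclass (MAEScorer is not part of the repository) that scores with mean absolute error:

    import numpy as np
    from typing import Dict, Tuple
    from scorers import Scorer
    import utils

    class MAEScorer(Scorer):
        """Hypothetical scorer using mean absolute error instead of MSE."""

        def score(self, function: callable) -> float:
            xs, ys = self.points[:, :-1], self.points[:, -1]
            try:
                ys_pred = utils.eval_function(function, xs, xs.shape[1])
                return float(np.mean(np.abs(ys - ys_pred)))
            except Exception:
                return np.inf  # invalid on these points (e.g. division by zero)

        def score_current_functions(self, current_functions: Dict[str, callable]) -> Tuple[Dict, Dict]:
            scores = {name: self.score(f) for name, f in current_functions.items()}
            return scores, scores.copy()  # no separate normalization step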
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import numpy as np
4 | import pandas as pd
5 |
6 | from torch import dtype, device
7 | from torch.cuda import get_device_name, is_available
8 | from collections.abc import Callable
9 | from typing import Tuple, List, Dict, Any
10 |
11 | from sklearn.metrics import mean_squared_error, r2_score
12 | from models import LLaVaModelHF, HuggingFaceModel, OpenAIModel
13 |
14 | import sympy
15 | from sympy.parsing.sympy_parser import parse_expr
16 |
17 | def get_job_id() -> str:
18 | """
19 | Gets the SLURM job id from the environment variables if available.
20 |
21 | Returns
22 | -------
23 | job_id -> The SLURM job id (or None if not available).
24 | """
25 |
26 | job_id = os.environ.get("SLURM_JOB_ID", None)
27 | if job_id is not None:
28 | job_id += ("_" + os.environ["SLURM_ARRAY_TASK_ID"]) if "SLURM_ARRAY_TASK_ID" in os.environ else ""
29 |
30 | return job_id
31 |
32 | def load_model(model_name: str, device: device, dtype: dtype, cache_dir: str = None, model_args = None) -> Any:
33 | """
34 | Utility to load a model from the HuggingFace model hub.
35 | Mostly needed to deal with LLaVA models, that are not available on the model hub yet.
36 |
37 | Parameters
38 | ----------
39 | model_name -> the name of the model to load.
40 | device -> the device to load the model on.
41 | dtype -> the dtype to load the model with.
42 | cache_dir -> the cache directory to use for the model.
43 |
44 | Returns
45 | -------
46 | model -> the loaded model.
47 | """
48 | if 'llava' in model_name:
49 | model = LLaVaModelHF(model_name, device, dtype, cache_dir, **model_args)
50 | elif 'gpt' in model_name:
51 | model = OpenAIModel(model_name, device, dtype, cache_dir, **model_args)
52 | else:
53 | model = HuggingFaceModel(model_name, device, dtype, cache_dir, **model_args)
54 |
55 | return model
56 |
57 | def get_messages(prompt: str, splits: List[str] = ["system", "user"]) -> List[Dict[str, str]]:
58 | """
59 | Converts a prompt string into a list of messages for each split.
60 |
61 | Parameters:
62 | prompt (str): The prompt string.
63 | splits (list[str]): A list of the splits to parse. Defaults to ["system", "user"].
64 |
65 | Returns:
66 | list[dict[str, str]]: A dictionary of the messages for each split.
67 | """
68 |
69 | messages = []
70 | for split in splits:
71 | start_tag = f"<{split}>"
72 | end_tag = f"</{split}>"
73 |
74 | start_idx = prompt.find(start_tag)
75 | end_idx = prompt.find(end_tag)
76 |
77 | # Skip if the split is not in the prompt (e.g. no system prompt)
78 | if start_idx == -1 or end_idx == -1:
79 | continue
80 | messages.append({
81 | "role": split,
82 | "content": prompt[start_idx + len(start_tag):end_idx].strip()
83 | })
84 |
85 | # If no splits at all, assume the whole prompt is a user message
86 | if len(messages) == 0:
87 | messages.append({
88 | "role": "user",
89 | "content": prompt
90 | })
91 |
92 | return messages
93 |
94 | def load_points(file_path: str) -> np.ndarray:
95 | """
96 | Loads a set of points from a file.
97 |
98 | Parameters
99 | ----------
100 | file_path -> the path to the file containing the points.
101 |
102 | Returns
103 | -------
104 | points -> the points.
105 | """
106 | if file_path.endswith(".npy"):
107 | points = np.load(file_path)
108 | elif file_path.endswith(".txt"):
109 | points = np.loadtxt(file_path)
110 | elif file_path.endswith(".csv"):
111 | points = pd.read_csv(file_path).values
112 | elif file_path.endswith(".tsv"):
113 | points = pd.read_csv(file_path, sep="\t").values
114 | else:
115 | raise ValueError("Invalid file format. (only .npy, .txt, .csv, and .tsv are supported)")
116 | return points
117 |
118 | def normalize_points(points: np.ndarray, method: str = "minmax", percentile: int = None) -> np.ndarray:
119 | """
120 | Normalizes a set of points.
121 |
122 | Parameters
123 | ----------
124 | points -> the points to normalize.
125 | method -> the normalization method to use. (minmax, zscore, percentile)
126 | percentile -> the percentile to use for percentile normalization (if applicable).
127 |
128 | Returns
129 | -------
130 | points -> the normalized points.
131 | """
132 | if method == "percentile" and percentile is None:
133 | raise ValueError("Percentile normalization requires a percentile value.")
134 |
135 | ys = np.array([point[-1] for point in points])
136 | if method == "minmax":
137 | points = np.array([np.concatenate([point[:-1], [(y - ys.min()) / (ys.max() - ys.min())]]) for point, y in zip(points, ys)])
138 | elif method == "zscore":
139 | points = np.array([np.concatenate([point[:-1], [(y - ys.mean()) / ys.std()]]) for point, y in zip(points, ys)])
140 | elif method == "percentile":
141 | points = np.array([np.concatenate([point[:-1], [y / np.percentile(ys, percentile)]]) for point, y in zip(points, ys)])
142 | else:
143 | raise ValueError("Invalid normalization method.")
144 |
145 | points = np.round(points, 4)
146 | return points
147 |
148 | def decimate_points(points: np.ndarray, max_points: int) -> np.ndarray:
149 | """
150 | Reduces the number of points to a maximum number to be used in the prompt.
151 |
152 | Parameters
153 | ----------
154 | points -> the points to decimate.
155 | max_points -> the maximum number of points to keep.
156 |
157 | Returns
158 | -------
159 | points -> the decimated points.
160 | """
161 |
162 | if points.shape[0] <= max_points:
163 | return points
164 |
165 | # Find an evenly spaced subset of points
166 | indices = np.linspace(0, points.shape[0] - 1, max_points, dtype=int)
167 | points = points[indices]
168 | return points
169 |
170 | def split_points(points: np.ndarray, test_fraction: float, split_strategy: str = "random", seed: int = None) -> Tuple[np.ndarray, np.ndarray]:
171 | """
172 | Splits a set of points into train and test sets.
173 |
174 | Parameters
175 | ----------
176 | points -> the points to split.
177 | test_fraction -> the fraction of points to use for the test set.
178 | split_strategy -> the strategy to use for splitting the points. (random, middle, end)
179 | seed -> the seed to use for the random split.
180 |
181 | Returns
182 | -------
183 | train_points -> the train points.
184 | test_points -> the test points.
185 | """
186 | num_points = points.shape[0]
187 | num_test_points = int(num_points * test_fraction)
188 | points = points[points[:, 0].argsort()]
189 |
190 | if seed is not None:
191 | np.random.seed(seed)
192 |
193 | #! Middle and end are not working properly with n_variables > 1, not fixed as unused in final version
194 | if split_strategy == "random":
195 | indices = np.random.choice(num_points, num_test_points, replace=False)
196 | mask = np.ones(num_points, dtype=bool)
197 | mask[indices] = False
198 | train_points = points[mask]
199 | test_points = points[~mask]
200 | elif split_strategy == "middle":
201 | start = (num_points - num_test_points) // 2
202 | end = start + num_test_points
203 | train_points = np.concatenate([points[:start], points[end:]])
204 | test_points = points[start:end]
205 | elif split_strategy == "end":
206 | train_points = points[:-num_test_points]
207 | test_points = points[-num_test_points:]
208 | else:
209 | raise ValueError("Invalid split strategy.")
210 |
211 | return train_points, test_points
212 |
213 | def array_to_string(points: np.ndarray) -> str:
214 | """
215 | Converts a numpy array of points to a string.
216 |
217 | Parameters
218 | ----------
219 | points -> the numpy array of points to convert.
220 |
221 | Returns
222 | -------
223 | points -> the string of points.
224 | """
225 | points = points.tolist()
226 | points_str = ""
227 | for point in points:
228 | point_str = ", ".join([str(np.round(x, 2)) for x in point])
229 | point_str = f"({point_str})"
230 | points_str += point_str + ", "
231 |
232 | return points_str[:-2]
233 |
234 | def string_to_array(points: str) -> np.ndarray:
235 | """
236 | Converts a string of points to a numpy array.
237 |
238 | Parameters
239 | ----------
240 | points -> the string of points to convert.
241 |
242 | Returns
243 | -------
244 | points -> the numpy array of points.
245 | """
246 | points = points.replace("(", "").split("), ")
247 | points = [point.replace(")", "") for point in points]
248 | points = [point.split(", ") for point in points]
249 | points = [[float(coordinate) for coordinate in point] for point in points]
250 | return np.array(points)
251 |
252 | def eval_function(function: sympy.core.function.Function, Xs: np.ndarray, num_variables: int) -> float:
253 | """
254 | Evaluates a sympy function at a point.
255 |
256 | Parameters
257 | ----------
258 | function -> the function to evaluate.
259 | Xs -> the points to evaluate the function at (columns must match the variables in alphabetical order).
260 | num_variables -> the number of variables the function takes.
261 |
262 | Returns
263 | -------
264 | ys -> the values of the function at the given points.
265 | """
266 | symbols = function.free_symbols
267 | symbols = sorted(symbols, key=lambda x: str(x))
268 | if Xs.shape[-1] != num_variables:
269 | Xs = np.array(list(zip(*[x.flat for x in Xs])))
270 |
271 | ys = []
272 | for point in Xs:
273 | if type(point) == np.ndarray:
274 | subs = {symbol: value for symbol, value in zip(symbols, point)}
275 | else:
276 | subs = {symbols[0]: point}
277 | try:
278 | y = function.evalf(subs=subs)
279 | y = float(y)
280 | except Exception as e:
281 | print(f"Error evaluating function: {function} at point {point}. {e}")
282 | y = np.inf
283 | ys.append(y)
284 |
285 | ys = np.array(ys)
286 | ys = ys.astype(np.float32)
287 | return ys
288 |
289 | def clean_function(function: str) -> str:
290 | """
291 | Cleans a function string to be evaluable.
292 | """
293 | function = function.strip(".")
294 | function = function.replace(" ", "")
295 |
296 | if "=" in function:
297 | function = function.split("=")[1]
298 | elif ":" in function:
299 | function = function.split(":")[1]
300 |
301 | # Remove characters that are not allowed in a function
302 | removals = ["'", '"', "\\", "\n", "\t", "\r", " ", "_"]
303 | for removal in removals:
304 | function = function.replace(removal, "")
305 |
306 | # Remove trailing operators (function[-1] is a single character, so "**" is covered by "*")
307 | while len(function) > 1 and function[-1] in ["+", "-", "*", "/"]:
308 | function = function[:-1]
309 |
310 | # Remove leading operators (a leading "-" is kept, as it is a valid unary minus)
311 | while len(function) > 1 and function[0] in ["+", "*", "/"]:
312 | function = function[1:]
313 |
314 | # If only a lone operator is left, fall back to the zero function
315 | if function in ["+", "-", "*", "/"]:
316 | return "0"
317 |
318 | # Remove leading indicators of a function definition
319 | removals = ["Function", "Newfunction", "Thefunctionis", ":"]
320 |
321 | for removal in removals:
322 | # Case-insensitive removal (the reply may use arbitrary capitalization)
323 | function = re.sub(re.escape(removal), "", function, flags=re.IGNORECASE)
324 | function = function.strip()
325 |
326 | return function
327 |
328 | def string_to_function(function: str, num_variables: int = 1) -> sympy.Expr:
329 | """
330 | Converts a string to a sympy expression using parse_expr.
331 |
332 | Parameters
333 | ----------
334 | function -> the string to convert.
335 | num_variables -> the number of variables the function should take.
336 |
337 | Returns
338 | -------
339 | f -> the parsed sympy expression.
340 | """
341 | function = clean_function(function)
342 |
343 | np_func = ["sin", "cos", "tan", "exp", "log", "sqrt"]
344 | function = function.replace("^", "**")
345 | #! This only works for variables in x (x1, x2, x3, ...)
346 | #! This only works with coefficients that end with numbers (e.g. c0, c1, c2, ...)
347 | function = re.sub(r"(\d)x", r"\1*x", function)
348 | regex = r"(\d)(" + "|".join(np_func) + ")"
349 | function = re.sub(regex, r"\1*\2", function)
350 | f = parse_expr(function)
351 | return f
352 |
353 | def is_valid_function(function: str, current_functions: Any, num_variables: int = 1) -> Tuple[bool, str]:
354 | """
355 | Checks if a function is valid.
356 |
357 | Parameters
358 | ----------
359 | function -> the function to check.
360 | current_functions -> the current functions in the prompt.
361 | num_variables -> the number of variables the function should take.
362 |
363 | Returns
364 | -------
365 | valid -> whether the function is valid.
366 | reason -> the reason the function is invalid (if applicable).
367 | """
368 | valid = True
369 | reason = ""
370 | if type(function) == str:
371 | f = string_to_function(function, num_variables)
372 | else:
373 | f = function
374 | symbols = f.free_symbols
375 | variables = [str(symbol) for symbol in symbols if str(symbol).startswith("x")]
376 |
377 | if len(variables) > num_variables:
378 | valid = False
379 | reason = "Too many variables in function."
380 | return valid, reason
381 |
382 | if current_functions is not None and current_functions.func_in_list(f):
383 | valid = False
384 | reason = "Function already in prompt."
385 | return valid, reason
386 |
387 | return valid, reason
388 |
389 | def format_exp(x: float, d: int = 6) -> str:
390 | """
391 | Formats a number in scientific notation with custom precision. (used in Scorers)
392 |
393 | Parameters
394 | ----------
395 | x -> the number to format.
396 | d -> the number of decimal places to round to.
397 |
398 | Returns
399 | -------
400 | x -> the formatted number.
401 | """
402 | n = int(np.floor(np.log10(abs(x)))) if x != 0 else 0  # log10 is undefined at 0
403 | significand = x / 10 ** n
404 | exp_sign = '+' if n >= 0 else '-'
405 | return f'{significand:.{d}f}e{exp_sign}{n:02d}'
406 |
407 | def func_equals(f1: Any, f2: Any, num_variables: int) -> bool:
408 | """
409 | Checks if two functions are equal. Used in place of sympy.equals as the latter can become very slow for certain functions.
410 | https://stackoverflow.com/questions/37112738/sympy-comparing-expressions
411 |
412 | Parameters
413 | ----------
414 | f1 -> the first function.
415 | f2 -> the second function.
416 | num_variables -> the number of variables the functions should take.
417 |
418 | Returns
419 | -------
420 | equal -> whether the functions are equal.
421 | """
422 | if f1 == f2:
423 | return True
424 | if f1 is None or f2 is None:
425 | return False
426 | if f1.free_symbols != f2.free_symbols:
427 | return False
428 | if f1.free_symbols != set([sympy.Symbol(f"x{i + 1}") for i in range(num_variables)]):
429 | return False
430 | return sympy.simplify(f1 - f2) == 0  # compare by simplifying the difference (see the link above); the cheap checks above short-circuit most cases
431 |
432 | def count_nodes(formula: Any) -> int:
433 | """
434 | Gets the complexity of a sympy formula, represented by the number of nodes in its expression tree.
435 |
436 | Parameters
437 | ----------
438 | formula -> the formula to get the complexity of.
439 |
440 | Returns
441 | -------
442 | complexity -> the complexity of the formula.
443 | """
444 | return formula.count_ops()
445 |
446 | def replace_zero_coefficients(expr: Any, formula: Any, threshold: float = 1e-2) -> Any:
447 | """
448 | Replaces coefficients that are close to zero in a formula with zero.
449 |
450 | Parameters
451 | ----------
452 | expr -> the expression to replace coefficients in (with coefficients c0, c1...)
453 | formula -> the formula to replace coefficients in (with numerical coefficients)
454 | threshold -> the threshold to consider a coefficient zero.
455 |
456 | Returns
457 | -------
458 | expr -> the expression with zero coefficients replaced.
459 | formula -> the formula with zero coefficients replaced.
460 | """
461 | coeffs_dict = formula.as_coefficients_dict()
462 | expr_dict = expr.as_coefficients_dict()
463 |
464 | for key, value in coeffs_dict.items():
465 | if abs(value) < threshold:
466 | expr_dict[key] = 0
467 | formula = formula.subs(key, 0)
468 |
469 | expr = expr.subs(expr_dict)
470 |
471 | return expr, formula
--------------------------------------------------------------------------------
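Finally, a short sketch of how the parsing helpers in utils.py fit together, from a raw model reply to a validated sympy expression (the reply string is hypothetical):

    import utils

    raw = "f1(x) = c*sin(x1) + c*x1^2"  # hypothetical model output line
    expr = utils.string_to_function(raw, num_variables=1)  # cleaned, then parsed by sympy
    print(expr)                     # a sympy expression with the generic coefficient c
    print(utils.count_nodes(expr))  # complexity as the number of expression-tree operations
    valid, reason = utils.is_valid_function(expr, None, num_variables=1)
    print(valid, reason)            # True, "" -- one variable, and no duplicate list given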