├── .DS_Store
├── Demo.py
├── DySymNet.egg-info
│   ├── PKG-INFO
│   ├── SOURCES.txt
│   ├── dependency_links.txt
│   ├── requires.txt
│   └── top_level.txt
├── DySymNet
│   ├── SymbolicRegression.py
│   ├── __init__.py
│   └── scripts
│       ├── __init__.py
│       ├── controller.py
│       ├── functions.py
│       ├── params.py
│       ├── pretty_print.py
│       ├── regularization.py
│       ├── symbolic_network.py
│       └── utils.py
├── LICENSE
├── README.md
├── build
│   └── lib
│       └── DySymNet
│           ├── SymbolicRegression.py
│           ├── __init__.py
│           └── scripts
│               ├── __init__.py
│               ├── controller.py
│               ├── functions.py
│               ├── params.py
│               ├── pretty_print.py
│               ├── regularization.py
│               ├── symbolic_network.py
│               └── utils.py
├── data
│   └── Nguyen-1.csv
├── dist
│   ├── DySymNet-0.2.0-py3-none-any.whl
│   └── DySymNet-0.2.0.tar.gz
├── environment.yml
├── img
│   ├── .DS_Store
│   ├── ICML-logo.svg
│   ├── Overview.png
│   └── Snipaste.png
└── setup.py
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AILWQ/DySymNet/3f356066be8b11dd8a9692d26a27aed876accdf3/.DS_Store
--------------------------------------------------------------------------------
/Demo.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from DySymNet import SymbolicRegression
3 | from DySymNet.scripts.params import Params
4 | from DySymNet.scripts.functions import *
5 |
6 | # You can customize some hyperparameters according to the README
7 | config = Params()
8 |
9 | # such as operators
10 | funcs = [Identity(), Sin(), Cos(), Square(), Plus(), Sub(), Product()]
11 | config.funcs_avail = funcs
12 |
13 |
14 | # Example 1: Input ground truth expression
15 | SR = SymbolicRegression.SymboliRegression(config=config, func="x_1**3 + x_1**2 + x_1", func_name="Nguyen-1")
16 | eq, R2, error, relative_error = SR.solve_environment()
17 | print('Expression: ', eq)
18 | print('R2: ', R2)
19 | print('error: ', error)
20 | print('relative_error: ', relative_error)
21 | print('log(1 + MSE): ', np.log(1 + error))
22 |
23 |
24 | # Example 2: Load the data file
25 | params = Params() # configuration for a specific task
26 | data_path = './data/Nguyen-1.csv' # data file should be in csv format
27 | SR = SymbolicRegression.SymboliRegression(config=params, func_name='Nguyen-1', data_path=data_path)  # func_name can be any name you like
28 | eq, R2, error, relative_error = SR.solve_environment() # return results
29 | print('Expression: ', eq)
30 | print('R2: ', R2)
31 | print('error: ', error)
32 | print('relative_error: ', relative_error)
33 | print('log(1 + MSE): ', np.log(1 + error))
--------------------------------------------------------------------------------
/DySymNet.egg-info/PKG-INFO:
--------------------------------------------------------------------------------
1 | Metadata-Version: 2.1
2 | Name: DySymNet
3 | Version: 0.2.0
4 | Summary: This package contains the official Pytorch implementation for the paper "A Neural-Guided Dynamic Symbolic Network for Exploring Mathematical Expressions from Data" accepted by ICML'24.
5 | Home-page: https://github.com/AILWQ/DySymNet
6 | Author: Wenqiang Li
7 | Author-email: liwenqiang2021@gmail.com
8 | License: MIT
9 | Platform: UNKNOWN
10 | Classifier: Programming Language :: Python :: 3
11 | Classifier: License :: OSI Approved :: MIT License
12 | Classifier: Operating System :: OS Independent
13 | Requires-Python: >=3.6.0
14 | Description-Content-Type: text/markdown
15 | License-File: LICENSE
16 |
17 |
18 | ## DySymNet
19 |
20 |
21 |
22 |
23 |
24 | 
25 |
26 | This repository contains the official PyTorch implementation for the paper [***A Neural-Guided Dynamic Symbolic Network for Exploring Mathematical Expressions from Data***](https://openreview.net/forum?id=pTmrk4XPFx) accepted by ICML'24.
27 |
28 | [Paper (OpenReview)](https://openreview.net/pdf?id=IejxxE9DO2)
29 | [arXiv](https://arxiv.org/abs/2309.13705)
30 | 
31 | 
32 |
33 | ## 🔥 Highlights
34 |
35 | - ***DySymNet*** is a new search paradigm for symbolic regression (SR) that searches the symbolic network with various architectures instead of searching expressions in the large functional space.
36 | - ***DySymNet*** possesses promising capabilities in solving high-dimensional problems and optimizing coefficients, which are lacking in current SR methods.
37 | - ***DySymNet*** outperforms state-of-the-art baselines across various SR standard benchmark datasets and the well-known SRBench with more variables.
38 |
39 | ## 📦 Requirements
40 |
41 | Install the conda environment and packages:
42 |
43 | ```setup
44 | conda env create -f environment.yml
45 | conda activate dysymnet
46 | ```
47 |
48 | The packages have been tested on Linux.
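Alternatively, a minimal sketch of installing the pinned dependencies with pip (versions taken from `DySymNet.egg-info/requires.txt`; the conda route above is the documented one):

```setup
pip install scikit-learn==1.5.2 numpy==1.26.4 sympy==1.13.3 torch==2.2.2 matplotlib==3.9.2 tqdm==4.66.5 pandas==2.2.3 scipy==1.13.1
```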
49 |
50 | ## 📋 Getting started
51 |
52 | The main running script is `SymbolicRegression.py`, and runs are configured via `params.py`, which contains the hyperparameters of the controller RNN and the symbolic network. You can adjust the following hyperparameters as required:
53 |
54 | #### parameters for symbolic network structure
55 |
56 | | Parameters | Description | **Example Values** |
57 | | :--------------: | :----------------------------------------------------------: | :----------------: |
58 | | `funcs_avail` | Operator library | See `params.py` |
59 | | `n_layers` | Range of symbolic network layers | [2, 3, 4, 5] |
60 | | `num_func_layer` | Range of the number of neurons per layer of a symbolic network | [2, 3, 4, 5, 6] |
61 |
62 | Note: You can add additional operators in `functions.py` by following the pattern of the existing operators, and include them in `funcs_avail` if you want to use them; a sketch is shown below.
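As a concrete illustration, here is a minimal sketch of a custom unary operator written in the same style as `Sin` and `Cos` in `functions.py` (the `Tanh` class below is hypothetical and not part of the shipped operator library):

```python
import torch
import sympy as sp

from DySymNet.scripts.functions import BaseFunction


class Tanh(BaseFunction):
    """Hypothetical unary operator following the Sin/Cos pattern in functions.py."""

    def __init__(self):
        super().__init__()
        self.name = 'tanh'  # name shown when the sampled architecture is logged

    def torch(self, x):
        # PyTorch implementation used inside the symbolic network
        return torch.tanh(x) / self.norm

    def sp(self, x):
        # SymPy implementation used when pretty-printing the final expression
        return sp.tanh(x) / self.norm
```

An instance can then be exposed to the controller by adding it to the operator library, e.g. `config.funcs_avail = [Identity(), Sin(), Cos(), Tanh(), Plus(), Product()]`.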
63 |
64 | #### parameters for controller RNN
65 |
66 | | Parameters | Description | **Example Values** |
67 | | :--------------: | :---------------------------------------: | :----------------: |
68 | | `num_epochs`     | Epochs for sampling                       | 500                |
69 | | `batch_size`     | Batch size for sampling                   | 10                 |
70 | | `optimizer` | Optimizer for training RNN | Adam |
71 | | `hidden_size` | Hidden dim. of RNN layer | 32 |
72 | | `embedding_size` | Embedding dim. | 16 |
73 | | `learning_rate1` | Learning rate for training RNN | 0.0006 |
74 | | `risk_seeking`   | Use risk-seeking policy gradient or not   | True               |
75 | | `risk_factor` | Risk factor | 0.5 |
76 | | `entropy_weight` | Entropy weight | 0.005 |
77 | | `reward_type` | Loss type for computing reward | mse |
78 |
79 |
80 |
81 | #### parameters for symbolic network training
82 |
83 | | Parameters | Description | **Example Values** |
84 | | :----------------: | :-------------------------------------------: | :----------------: |
85 | | `learning_rate2` | Learning rate for training symbolic network | 0.01 |
86 | | `reg_weight`       | Regularization weight                         | 5e-3               |
87 | | `threshold`        | Pruning threshold                             | 0.05               |
88 | | `trials` | Training trials for training symbolic network | 1 |
89 | | `n_epochs1` | Epochs for the first training stage | 10001 |
90 | | `n_epochs2` | Epochs for the second training stage | 10001 |
91 | | `summary_step` | Summary for every `n` training steps | 1000 |
92 | | `clip_grad` | Using adaptive gradient clipping or not | True |
93 | | `max_norm` | Norm threshold for gradient clipping | 1.0 |
94 | | `window_size` | Window size for adaptive gradient clipping | 50 |
95 | | `refine_constants` | Refining constants or not | True |
96 | | `n_restarts` | Number of restarts for BFGS optimization | 1 |
97 | | `add_bias` | adding bias or not | False |
98 | | `verbose` | Print training process or not | True |
99 | | `use_gpu` | Using cuda or not | False |
100 | | `plot_reward` | Plot reward curve or not | False |
101 |
102 | **Note:** `threshold` controls the complexity of the final expression; it trades off complexity against precision, and you can customise it according to your actual requirements.
103 |
104 | #### parameters for generating input data
105 |
106 | | Parameters | Description | **Example Values** |
107 | | :-----------: | :----------------------------------------: | :----------------: |
108 | | `N_TRAIN` | Size of input data | 100 |
109 | | `N_VAL` | Size of validation dataset | 100 |
110 | | `NOISE` | Standard deviation of noise for input data | 0 |
111 | | `DOMAIN` | Domain of input data | (-1, 1) |
112 | | `N_TEST` | Size of test dataset | 100 |
113 | | `DOMAIN_TEST` | Domain of test dataset | (-1, 1) |
114 |
115 | #### Additional parameters
116 |
117 | `results_dir` configures the save path for all results. A short sketch of overriding these defaults in code is shown below.
118 |
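For example, a minimal sketch of overriding a few of the defaults in `params.py` before running (all attribute names below exist in `Params`):

```python
from DySymNet.scripts.params import Params
from DySymNet.scripts.functions import Identity, Sin, Cos, Square, Plus, Sub, Product

config = Params()
config.funcs_avail = [Identity(), Sin(), Cos(), Square(), Plus(), Sub(), Product()]  # operator library
config.threshold = 0.05                 # pruning threshold: lower keeps more terms (more precise, more complex)
config.reward_type = 'nrmse'            # 'mse' or 'nrmse'
config.results_dir = './results/test'   # save path for all result files
```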
119 | ## 🤖 Symbolic Regression
120 |
121 | We provide two ways to perform symbolic regression tasks.
122 |
123 | #### Option 1: Input ground truth expression
124 |
125 | When you want to discover an expression for which the ground truth is known, for example to test a standard benchmark, you can edit the script `SymbolicRegression.py` as follows:
126 |
127 | ```python
128 | # SymbolicRegression.py
129 | params = Params() # configuration for a specific task
130 | ground_truth_eq = "x_1 + x_2" # variable names should be written as x_i, where i>=1.
131 | eq_name = "x_1+x_2"
132 | SR = SymboliRegression(config=params, func=ground_truth_eq, func_name=eq_name)  # A new folder named after func_name will be created to store the result files.
133 | eq, R2, error, relative_error = SR.solve_environment() # return results
134 | ```
135 |
136 | In this way, the function `generate_data` is used to automatically generate the corresponding data set $\mathcal{D}(X, y)$ for inference, instead of you generating the data yourself.
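If you want to inspect such a dataset yourself, `generate_data` (defined near the top of `SymbolicRegression.py`) can also be called directly; a minimal sketch:

```python
from DySymNet.SymbolicRegression import generate_data

# 100 samples with x_1 drawn uniformly from (-1, 1) for the Nguyen-1 expression
X, y = generate_data("x_1**3 + x_1**2 + x_1", N=100, range_min=-1, range_max=1)
print(X.shape, y.shape)  # torch tensors of shape (100, 1) each
```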
137 |
138 | Then, you can run `SymbolicRegression.py` directly, or you can run it in the terminal as follows:
139 |
140 | ```bash
141 | python SymbolicRegression.py
142 | ```
143 |
144 | After running this script, the results will be stored in path `./results/test/func_name`.
145 |
146 | #### Option 2: Load the data file
147 |
148 | When you only have observed data and do not know the ground truth, you can perform symbolic regression by providing the path to a CSV data file:
149 |
150 | ```python
151 | # SymbolicRegression.py
152 | params = Params() # configuration for a specific task
153 | data_path = './data/Nguyen-1.csv' # data file should be in csv format
154 | SR = SymboliRegression(config=params, func_name='Nguyen-1', data_path=data_path)  # func_name can be any name you like
155 | eq, R2, error, relative_error = SR.solve_environment() # return results
156 | ```
157 |
158 | **Note:** the data file should contain $X\_{dim} + 1$ columns, where $X\_{dim}$ is the number of independent variables and the last column contains the corresponding $y$ values; a minimal sketch of generating such a file is shown below.
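A minimal sketch of producing such a file with pandas (the file name and column headers below are illustrative; `load_data` only relies on the column order, and `pandas.read_csv` expects a header row by default):

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
x1 = rng.uniform(-1.0, 1.0, size=100)   # one independent variable
y = x1**3 + x1**2 + x1                  # Nguyen-1 target values

# One column per variable, with y as the last column
pd.DataFrame({'x_1': x1, 'y': y}).to_csv('./data/my_data.csv', index=False)
```

The resulting path can then be passed as `data_path='./data/my_data.csv'`.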
159 |
160 | Then, you can run `SymbolicRegression.py` directly, or you can run it in the terminal as follows:
161 |
162 | ```bash
163 | python SymbolicRegression.py
164 | ```
165 |
166 | After running this script, the results will be stored in path `./results/test/func_name`.
167 |
168 | #### Output
169 |
170 | Once the script stops early or finishes running, you will get the following output:
171 |
172 | ```
173 | Expression: x_1 + x_2
174 | R2: 1.0
175 | error: 4.3591795754679974e-13
176 | relative_error: 2.036015757767018e-06
177 | log(1 + MSE): 4.3587355946774144e-13
178 | ```
179 |
180 | ## 🔗 Citing this work
181 |
182 | If you find our work and this codebase helpful, please consider starring this repo 🌟 and citing:
183 |
184 | ```bibtex
185 | @inproceedings{
186 | li2024a,
187 | title={A Neural-Guided Dynamic Symbolic Network for Exploring Mathematical Expressions from Data},
188 | author={Wenqiang Li and Weijun Li and Lina Yu and Min Wu and Linjun Sun and Jingyi Liu and Yanjie Li and Shu Wei and Deng Yusong and Meilan Hao},
189 | booktitle={Forty-first International Conference on Machine Learning},
190 | year={2024},
191 | url={https://openreview.net/forum?id=IejxxE9DO2}
192 | }
193 | ```
194 |
195 |
196 |
--------------------------------------------------------------------------------
/DySymNet.egg-info/SOURCES.txt:
--------------------------------------------------------------------------------
1 | LICENSE
2 | README.md
3 | setup.py
4 | DySymNet/SymbolicRegression.py
5 | DySymNet/__init__.py
6 | DySymNet.egg-info/PKG-INFO
7 | DySymNet.egg-info/SOURCES.txt
8 | DySymNet.egg-info/dependency_links.txt
9 | DySymNet.egg-info/requires.txt
10 | DySymNet.egg-info/top_level.txt
11 | DySymNet/scripts/__init__.py
12 | DySymNet/scripts/controller.py
13 | DySymNet/scripts/functions.py
14 | DySymNet/scripts/params.py
15 | DySymNet/scripts/pretty_print.py
16 | DySymNet/scripts/regularization.py
17 | DySymNet/scripts/symbolic_network.py
18 | DySymNet/scripts/utils.py
--------------------------------------------------------------------------------
/DySymNet.egg-info/dependency_links.txt:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/DySymNet.egg-info/requires.txt:
--------------------------------------------------------------------------------
1 | scikit-learn==1.5.2
2 | numpy==1.26.4
3 | sympy==1.13.3
4 | torch==2.2.2
5 | matplotlib==3.9.2
6 | tqdm==4.66.5
7 | pandas==2.2.3
8 | pip==24.2
9 | scipy==1.13.1
10 |
--------------------------------------------------------------------------------
/DySymNet.egg-info/top_level.txt:
--------------------------------------------------------------------------------
1 | DySymNet
2 |
--------------------------------------------------------------------------------
/DySymNet/SymbolicRegression.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import time
4 | import os
5 | import torch
6 | import sympy as sp
7 | import pandas as pd
8 | from scipy.optimize import minimize
9 | from .scripts.functions import *
10 | from .scripts import functions as functions
11 | import collections
12 | import numpy as np
13 | import matplotlib.pyplot as plt
14 | from sympy import symbols, Float
15 | from torch import nn, optim
16 | from .scripts.controller import Agent
17 | import torch.nn.functional as F
18 | from .scripts import pretty_print
19 | from .scripts.regularization import L12Smooth
20 | from .scripts.symbolic_network import SymbolicNet
21 | from sklearn.metrics import r2_score
22 | from .scripts.params import Params
23 | from .scripts.utils import nrmse, R_Square, MSE, Relative_Error
24 |
25 |
26 | def generate_data(func, N, range_min, range_max):
27 | """Generates datasets."""
28 | free_symbols = sp.sympify(func).free_symbols
29 | x_dim = free_symbols.__len__()
30 | sp_expr = sp.lambdify(free_symbols, func)
31 | x = (range_max - range_min) * torch.rand([N, x_dim]) + range_min
32 | y = torch.tensor([[sp_expr(*x_i)] for x_i in x])
33 | return x, y
34 |
35 |
36 | class TimedFun:
37 | def __init__(self, fun, stop_after=10):
38 | self.fun_in = fun
39 | self.started = False
40 | self.stop_after = stop_after
41 |
42 | def fun(self, x, *args):
43 | if self.started is False:
44 | self.started = time.time()
45 | elif abs(time.time() - self.started) >= self.stop_after:
46 | raise ValueError("Time is over.")
47 | self.fun_value = self.fun_in(*x, *args) # sp.lambdify()
48 | self.x = x
49 | return self.fun_value
50 |
51 |
52 | class SymboliRegression:
53 | def __init__(self, config, func=None, func_name=None, data_path=None):
54 | """
55 | Args:
56 | config: All configs in the Params class, type: Params
57 | func: the function to be predicted, type: str
58 | func_name: the name of the function, type: str
59 | data_path: the path of the data, type: str
60 | """
61 | self.data_path = data_path
62 | self.X = None
63 | self.y = None
64 | self.funcs_per_layer = None
65 | self.num_epochs = config.num_epochs
66 | self.batch_size = config.batch_size
67 | self.input_size = config.input_size # number of operators
68 | self.hidden_size = config.hidden_size
69 | self.embedding_size = config.embedding_size
70 | self.n_layers = config.n_layers
71 | self.num_func_layer = config.num_func_layer
72 | self.funcs_avail = config.funcs_avail
73 | self.optimizer = config.optimizer
74 | self.auto = False
75 | self.add_bias = config.add_bias
76 | self.threshold = config.threshold
77 |
78 | self.clip_grad = config.clip_grad
79 | self.max_norm = config.max_norm
80 | self.window_size = config.window_size
81 | self.refine_constants = config.refine_constants
82 | self.n_restarts = config.n_restarts
83 | self.reward_type = config.reward_type
84 |
85 | if config.use_gpu:
86 | self.device = torch.device('cuda')
87 | else:
88 | self.device = torch.device('cpu')
89 | print("Use Device:", self.device)
90 |
91 | # Standard deviation of random distribution for weight initializations.
92 | self.init_sd_first = 0.1
93 | self.init_sd_last = 1.0
94 | self.init_sd_middle = 0.5
95 |
96 | self.config = config
97 |
98 | self.func = func
99 | self.func_name = func_name
100 |
101 | # generate data or load data from file
102 | if self.func is not None:
103 | # add noise
104 | if config.NOISE > 0:
105 | self.X, self.y = generate_data(func, self.config.N_TRAIN, self.config.DOMAIN[0], self.config.DOMAIN[1]) # y shape is (N, 1)
106 | y_rms = torch.sqrt(torch.mean(self.y ** 2))
107 | scale = config.NOISE * y_rms
108 | self.y += torch.empty(self.y.shape[-1]).normal_(mean=0, std=scale)
109 | self.x_test, self.y_test = generate_data(func, self.config.N_TRAIN, range_min=self.config.DOMAIN_TEST[0],
110 | range_max=self.config.DOMAIN_TEST[1])
111 |
112 | else:
113 | self.X, self.y = generate_data(func, self.config.N_TRAIN, self.config.DOMAIN[0], self.config.DOMAIN[1]) # y shape is (N, 1)
114 | self.x_test, self.y_test = generate_data(func, self.config.N_TRAIN, range_min=self.config.DOMAIN_TEST[0],
115 | range_max=self.config.DOMAIN_TEST[1])
116 | else:
117 | self.X, self.y = self.load_data(self.data_path)
118 | self.x_test, self.y_test = self.X, self.y
119 |
120 | self.dtype = self.X.dtype # obtain the data type, which determines the parameter type of the model
121 |
122 | if isinstance(self.n_layers, list) or isinstance(self.num_func_layer, list):
123 | print('*' * 25, 'Start Sampling...', '*' * 25 + '\n')
124 | self.auto = True
125 |
126 | self.agent = Agent(auto=self.auto, input_size=self.input_size, hidden_size=self.hidden_size,
127 | num_funcs_avail=len(self.funcs_avail), n_layers=self.n_layers,
128 | num_funcs_layer=self.num_func_layer, device=self.device, dtype=self.dtype)
129 |
130 | self.agent = self.agent.to(self.dtype)
131 |
132 | if not os.path.exists(self.config.results_dir):
133 | os.makedirs(self.config.results_dir)
134 |
135 | func_dir = os.path.join(self.config.results_dir, func_name)
136 | if not os.path.exists(func_dir):
137 | os.makedirs(func_dir)
138 | self.results_dir = func_dir
139 |
140 | self.now_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
141 |
142 | # save hyperparameters
143 | args = {
144 | "date": self.now_time,
145 | "add_bias": config.add_bias,
146 | "train_domain": config.DOMAIN,
147 | "test_domain": config.DOMAIN_TEST,
148 | "num_epochs": config.num_epochs,
149 | "batch_size": config.batch_size,
150 | "input_size": config.input_size,
151 | "hidden_size": config.hidden_size,
152 | "risk_factor": config.risk_factor,
153 | "n_layers": config.n_layers,
154 | "num_func_layer": config.num_func_layer,
155 | "funcs_avail": str([func.name for func in config.funcs_avail]),
156 | "init_sd_first": 0.1,
157 | "init_sd_last": 1.0,
158 | "init_sd_middle": 0.5,
159 | "noise_level": config.NOISE
160 | }
161 | with open(os.path.join(self.results_dir, 'args_{}.txt'.format(self.func_name)), 'a') as f:
162 | f.write(json.dumps(args))
163 | f.write("\n")
164 | f.close()
165 |
166 | def solve_environment(self):
167 | epoch_best_expressions = []
168 | epoch_best_rewards = []
169 | epoch_mean_rewards = []
170 | epoch_mean_r2 = []
171 | epoch_best_r2 = []
172 | epoch_best_relative_error = []
173 | epoch_mean_relative_error = []
174 | best_expression, best_performance, best_relative_error = None, float('-inf'), float('inf')
175 | early_stopping = False
176 |
177 | # log the expressions of all epochs
178 | f1 = open(os.path.join(self.results_dir, 'eq_{}_all.txt'.format(self.func_name)), 'a')
179 | f1.write('\n{}\t\t{}\n'.format(self.now_time, self.func_name))
180 | f1.write('{}\t\tReward\t\tR2\t\tExpression\t\tnum_layers\t\tnum_funcs_layer\t\tfuncs_per_layer\n'.format(self.reward_type))
181 |
182 | # log the best expressions of each epoch
183 | f2 = open(os.path.join(self.results_dir, 'eq_{}_summary.txt'.format(self.func_name)), 'a')
184 | f2.write('\n{}\t\t{}\n'.format(self.now_time, self.func_name))
185 | f2.write('Epoch\t\tReward\t\tR2\t\tExpression\n')
186 |
187 | if self.optimizer == "Adam":
188 | optimizer = torch.optim.Adam(self.agent.parameters(), lr=self.config.learning_rate1)
189 | else:
190 | optimizer = torch.optim.RMSprop(self.agent.parameters(), lr=self.config.learning_rate1)
191 |
192 | for i in range(self.num_epochs):
193 | print("******************** Epoch {:02d} ********************".format(i))
194 | expressions = []
195 | rewards = []
196 | r2 = []
197 | relative_error_list = []
198 | batch_log_probs = torch.zeros([self.batch_size], device=self.device)
199 | batch_entropies = torch.zeros([self.batch_size], device=self.device)
200 |
201 | j = 0
202 | while j < self.batch_size:
203 | error, R2, eq, log_probs, entropies, num_layers, num_func_layer, funcs_per_layer_name = self.play_episodes() # play an episode
204 | # if the expression is invalid, e.g. a constant or None, resample the structure of the symbolic network
205 | if 'x_1' not in str(eq) or eq is None:
206 | R2 = 0.0
207 | if 'x_1' in str(eq) and self.refine_constants:
208 | res = self.bfgs(eq, self.X, self.y, self.n_restarts)
209 | eq = res['best expression']
210 | R2 = res['R2']
211 | error = res['error']
212 | relative_error = res['relative error']
213 | else:
214 | relative_error = 100
215 |
216 | reward = 1 / (1 + error)
217 | print("Final expression: ", eq)
218 | print("Test R2: ", R2)
219 | print("Test error: ", error)
220 | print("Relative error: ", relative_error)
221 | print("Reward: ", reward)
222 | print('\n')
223 |
224 | f1.write('{:.8f}\t\t{:.8f}\t\t{:.8f}\t\t{:.8f}\t\t{}\t\t{}\t\t{}\t\t{}\n'.format(error, relative_error, reward, R2, eq, num_layers,
225 | num_func_layer,
226 | funcs_per_layer_name))
227 |
228 | if R2 > 0.99:
229 | print("~ Early Stopping Met ~")
230 | print("Best expression: ", eq)
231 | print("Best reward: ", reward)
232 | print(f"{self.config.reward_type} error: ", error)
233 | print("Relative error: ", relative_error)
234 | early_stopping = True
235 | break
236 |
237 | batch_log_probs[j] = log_probs
238 | batch_entropies[j] = entropies
239 | expressions.append(eq)
240 | rewards.append(reward)
241 | r2.append(R2)
242 | relative_error_list.append(relative_error)
243 | j += 1
244 |
245 | if early_stopping:
246 | f2.write('{}\t\t{:.8f}\t\t{:.8f}\t\t{}\n'.format(i, reward, R2, eq))
247 | break
248 |
249 | # a batch expressions
250 | ## reward
251 | rewards = torch.tensor(rewards, device=self.device)
252 | best_epoch_expression = expressions[np.argmax(rewards.cpu())]
253 | epoch_best_expressions.append(best_epoch_expression)
254 | epoch_best_rewards.append(max(rewards).item())
255 | epoch_mean_rewards.append(torch.mean(rewards).item())
256 |
257 | ## R2
258 | r2 = torch.tensor(r2, device=self.device)
259 | best_r2_expression = expressions[np.argmax(r2.cpu())]
260 | epoch_best_r2.append(max(r2).item())
261 | epoch_mean_r2.append(torch.mean(r2).item())
262 |
263 | epoch_best_relative_error.append(relative_error_list[np.argmax(r2.cpu())])
264 |
265 | # log the best expression of a batch
266 | f2.write(
267 | '{}\t\t{:.8f}\t\t{:.8f}\t\t{:.8f}\t\t{}\n'.format(i, relative_error_list[np.argmax(r2.cpu())], max(rewards).item(), max(r2).item(),
268 | best_r2_expression))
269 |
270 | # save the best expression from the beginning to now
271 | if max(r2) > best_performance:
272 | best_performance = max(r2)
273 | best_expression = best_r2_expression
274 | best_relative_error = min(epoch_best_relative_error)
275 |
276 | if self.config.risk_seeking:
277 | threshold = np.quantile(rewards.cpu(), self.config.risk_factor)
278 | indices_to_keep = torch.tensor([j for j in range(len(rewards)) if rewards[j] > threshold], device=self.device)
279 | if len(indices_to_keep) == 0:
280 | print("Threshold removes all expressions. Terminating.")
281 | break
282 |
283 | # Select corresponding subset of rewards, log_probabilities, and entropies
284 | sub_rewards = torch.index_select(rewards, 0, indices_to_keep)
285 | sub_log_probs = torch.index_select(batch_log_probs, 0, indices_to_keep)
286 | sub_entropies = torch.index_select(batch_entropies, 0, indices_to_keep)
287 |
288 | # Compute risk seeking and entropy gradient
289 | risk_seeking_grad = torch.sum((sub_rewards - threshold) * sub_log_probs, dim=0)
290 | entropy_grad = torch.sum(sub_entropies, dim=0)
291 |
292 | # Mean reduction and clip to limit exploding gradients
293 | risk_seeking_grad = torch.clip(risk_seeking_grad / (self.config.risk_factor * len(sub_rewards)), min=-1e6, max=1e6)
294 | entropy_grad = self.config.entropy_weight * torch.clip(entropy_grad / (self.config.risk_factor * len(sub_rewards)), min=-1e6, max=1e6)
295 |
296 | # compute loss and update parameters
297 | loss = -1 * (risk_seeking_grad + entropy_grad)
298 | optimizer.zero_grad()
299 | loss.backward()
300 | optimizer.step()
301 |
302 | f1.close()
303 | f2.close()
304 |
305 | # save the rewards
306 | f3 = open(os.path.join(self.results_dir, "reward_{}_{}.txt".format(self.func_name, self.now_time)), 'w')
307 | for i in range(len(epoch_mean_rewards)):
308 | f3.write("{} {:.8f}\n".format(i + 1, epoch_mean_rewards[i]))
309 | f3.close()
310 |
311 | # plot reward curve
312 | if self.config.plot_reward:
313 | # plt.plot([i + 1 for i in range(len(epoch_best_rewards))], epoch_best_rewards) # best reward of full epoch
314 | plt.plot([i + 1 for i in range(len(epoch_mean_rewards))], epoch_mean_rewards) # mean reward of full epoch
315 | plt.xlabel('Epoch')
316 | plt.ylabel('Reward')
317 | plt.title('Reward over Time ' + self.now_time)
318 | plt.show()
319 | plt.savefig(os.path.join(self.results_dir, "reward_{}_{}.png".format(self.func_name, self.now_time)))
320 |
321 | if early_stopping:
322 | return eq, R2, error, relative_error
323 | else:
324 | return best_expression, best_performance.item(), 1 / max(rewards).item() - 1, best_relative_error
325 |
326 | def bfgs(self, eq, X, y, n_restarts):
327 | variable = self.vars_name
328 |
329 | # Parse the expression and get all the constants
330 | expr = eq
331 | c = symbols('c0:10000') # Suppose we have at most n constants, c0, c1, ..., cn-1
332 | consts = list(expr.atoms(Float)) # Only floating-point coefficients are counted, not power exponents
333 | consts_dict = {c[i]: const for i, const in enumerate(consts)} # map between c_i and unoptimized constants
334 |
335 | for c_i, val in consts_dict.items():
336 | expr = expr.subs(val, c_i)
337 |
338 | def loss(expr, X):
339 | diffs = []
340 | for i in range(X.shape[0]):
341 | curr_expr = expr
342 | for idx, j in enumerate(variable):
343 | curr_expr = sp.sympify(curr_expr).subs(j, X[i, idx])
344 | diff = curr_expr - y[i]
345 | diffs.append(diff)
346 | return np.mean(np.square(diffs))
347 |
348 | # Lists where all restarted will be appended
349 | F_loss = []
350 | RE_list = []
351 | R2_list = []
352 | consts_ = []
353 | funcs = []
354 |
355 | print('Constructing BFGS loss...')
356 | loss_func = loss(expr, X)
357 |
358 | for i in range(n_restarts):
359 | x0 = np.array(consts, dtype=float)
360 | s = list(consts_dict.keys())
361 | # bfgs optimization
362 | fun_timed = TimedFun(fun=sp.lambdify(s, loss_func, modules=['numpy']), stop_after=int(1e10))
363 | if len(x0):
364 | minimize(fun_timed.fun, x0, method='BFGS') # check consts interval and if they are int
365 | consts_.append(fun_timed.x)
366 | else:
367 | consts_.append([])
368 |
369 | final = expr
370 | for i in range(len(s)):
371 | final = sp.sympify(final).replace(s[i], fun_timed.x[i])
372 |
373 | funcs.append(final)
374 |
375 | values = {x: X[:, idx] for idx, x in enumerate(variable)}
376 | y_pred = sp.lambdify(variable, final)(**values)
377 | if isinstance(y_pred, float):
378 | print('y_pred is float: ', y_pred, type(y_pred))
379 | R2 = 0.0
380 | loss_eq = 10000
381 | else:
382 | y_pred = torch.where(torch.isinf(y_pred), 10000, y_pred) # check if there is inf
383 | y_pred = torch.where(y_pred.clone().detach() > 10000, 10000, y_pred) # check if there is large number
384 | R2 = max(0.0, R_Square(y.squeeze(), y_pred))
385 | loss_eq = torch.mean(torch.square(y.squeeze() - y_pred)).item()
386 | relative_error = torch.mean(torch.abs((y.squeeze() - y_pred) / y.squeeze())).item()
387 | R2_list.append(R2)
388 | F_loss.append(loss_eq)
389 | RE_list.append(relative_error)
390 | best_R2_id = np.nanargmax(R2_list)
391 | best_consts = consts_[best_R2_id]
392 | best_expr = funcs[best_R2_id]
393 | best_R2 = R2_list[best_R2_id]
394 | best_error = F_loss[best_R2_id]
395 | best_re = RE_list[best_R2_id]
396 |
397 | return {'best expression': best_expr,
398 | 'constants': best_consts,
399 | 'R2': best_R2,
400 | 'error': best_error,
401 | 'relative error': best_re}
402 |
403 | def play_episodes(self):
404 | ############################### Sample a symbolic network ##############################
405 | init_state = torch.rand((1, self.input_size), device=self.device, dtype=self.dtype) # initial the input state
406 |
407 | if self.auto:
408 | num_layers, num_funcs_layer, action_index, log_probs, entropies = self.agent(
409 | init_state) # output the symbolic network structure parameters
410 | self.n_layers = num_layers
411 | self.num_func_layer = num_funcs_layer
412 | else:
413 | action_index, log_probs, entropies = self.agent(init_state)
414 |
415 | self.funcs_per_layer = {}
416 | self.funcs_per_layer_name = {}
417 |
418 | for i in range(self.n_layers):
419 | layer_funcs_list = list()
420 | layer_funcs_list_name = list()
421 | for j in range(self.num_func_layer):
422 | layer_funcs_list.append(self.funcs_avail[action_index[i, j]])
423 | layer_funcs_list_name.append(self.funcs_avail[action_index[i, j]].name)
424 | self.funcs_per_layer.update({i + 1: layer_funcs_list})
425 | self.funcs_per_layer_name.update({i + 1: layer_funcs_list_name})
426 |
427 | # let binary functions follow unary functions
428 | for layer, funcs in self.funcs_per_layer.items():
429 | unary_funcs = [func for func in funcs if isinstance(func, BaseFunction)]
430 | binary_funcs = [func for func in funcs if isinstance(func, BaseFunction2)]
431 | sorted_funcs = unary_funcs + binary_funcs
432 | self.funcs_per_layer[layer] = sorted_funcs
433 |
434 | print("Operators of each layer obtained by sampling: ", self.funcs_per_layer_name)
435 |
436 | ############################### Start training ##############################
437 | error_test, r2_test, eq = self.train(self.config.trials)
438 |
439 | return error_test, r2_test, eq, log_probs, entropies, self.n_layers, self.num_func_layer, self.funcs_per_layer_name
440 |
441 | def train(self, trials=1):
442 | """Train the network to find a given function"""
443 |
444 | data, target = self.X.to(self.device), self.y.to(self.device)
445 | test_data, test_target = self.x_test.to(self.device), self.y_test.to(self.device)
446 |
447 | self.x_dim = data.shape[-1]
448 |
449 | self.vars_name = [f'x_{i}' for i in range(1, self.x_dim + 1)] # Variable names
450 |
451 | width_per_layer = [len(f) for f in self.funcs_per_layer.values()]
452 | n_double_per_layer = [functions.count_double(f) for f in self.funcs_per_layer.values()]
453 |
454 | if self.auto:
455 | init_stddev = [self.init_sd_first] + [self.init_sd_middle] * (self.n_layers - 2) + [self.init_sd_last]
456 |
457 | # Arrays to keep track of various quantities as a function of epoch
458 | loss_list = [] # Total loss (MSE + regularization)
459 | error_list = [] # MSE
460 | reg_list = [] # Regularization
461 | error_test_list = [] # Test error
462 | r2_test_list = [] # Test R2
463 |
464 | error_test_final = []
465 | r2_test_final = []
466 | eq_list = []
467 |
468 | def log_grad_norm(net):
469 | sqsum = 0.0
470 | for p in net.parameters():
471 | if p.grad is not None:
472 | sqsum += (p.grad ** 2).sum().item()
473 | return np.sqrt(sqsum)
474 |
475 | # for trial in range(trials):
476 | retrain_num = 0
477 | trial = 0
478 | while 0 <= trial < trials:
479 | print("Training on function " + self.func_name + " Trial " + str(trial + 1) + " out of " + str(trials))
480 |
481 | # reinitialize for each trial
482 | if self.auto:
483 | net = SymbolicNet(self.n_layers,
484 | x_dim=self.x_dim,
485 | funcs=self.funcs_per_layer,
486 | initial_weights=None,
487 | init_stddev=init_stddev,
488 | add_bias=self.add_bias).to(self.device)
489 |
490 | else:
491 | net = SymbolicNet(self.n_layers,
492 | x_dim=self.x_dim,
493 | funcs=self.funcs_per_layer,
494 | initial_weights=[
495 | # kind of a hack for truncated normal
496 | torch.fmod(torch.normal(0, self.init_sd_first, size=(self.x_dim, width_per_layer[0] + n_double_per_layer[0])),
497 | 2),
498 | # binary operator has two inputs
499 | torch.fmod(
500 | torch.normal(0, self.init_sd_middle, size=(width_per_layer[0], width_per_layer[1] + n_double_per_layer[1])),
501 | 2),
502 | torch.fmod(
503 | torch.normal(0, self.init_sd_middle, size=(width_per_layer[1], width_per_layer[2] + n_double_per_layer[2])),
504 | 2),
505 | torch.fmod(torch.normal(0, self.init_sd_last, size=(width_per_layer[-1], 1)), 2)
506 | ]).to(self.device)
507 |
508 | net.to(self.dtype)
509 |
510 | loss_val = np.nan
511 | restart_flag = False
512 | while np.isnan(loss_val):
513 | # training restarts if gradients blow up
514 | criterion = nn.MSELoss()
515 | optimizer = optim.RMSprop(net.parameters(),
516 | lr=self.config.learning_rate2,
517 | alpha=0.9, # smoothing constant
518 | eps=1e-10,
519 | momentum=0.0,
520 | centered=False)
521 |
522 | # adaptive learning rate
523 | lmbda = lambda epoch: 0.1
524 | scheduler = optim.lr_scheduler.MultiplicativeLR(optimizer, lr_lambda=lmbda)
525 |
526 | if self.clip_grad:
527 | que = collections.deque()
528 |
529 | net.train() # Set model to training mode
530 |
531 | # First stage of training, preceded by 0th warmup stage
532 | for epoch in range(self.config.n_epochs1 + 2000):
533 | optimizer.zero_grad() # zero the parameter gradients
534 | outputs = net(data) # forward pass
535 | regularization = L12Smooth(a=0.01)
536 | mse_loss = criterion(outputs, target)
537 |
538 | reg_loss = regularization(net.get_weights_tensor())
539 | loss = mse_loss + self.config.reg_weight * reg_loss
540 | # loss = mse_loss
541 | loss.backward()
542 |
543 | if self.clip_grad:
544 | grad_norm = log_grad_norm(net)
545 | que.append(grad_norm)
546 | if len(que) > self.window_size:
547 | que.popleft()
548 | clip_threshold = 0.1 * sum(que) / len(que)
549 | torch.nn.utils.clip_grad_norm_(parameters=net.parameters(), max_norm=clip_threshold, norm_type=2)
550 | else:
551 | torch.nn.utils.clip_grad_norm_(parameters=net.parameters(), max_norm=self.max_norm, norm_type=2)
552 |
553 | optimizer.step()
554 |
555 | # summary
556 | if epoch % self.config.summary_step == 0:
557 | error_val = mse_loss.item()
558 | reg_val = reg_loss.item()
559 | loss_val = loss.item()
560 |
561 | error_list.append(error_val)
562 | reg_list.append(reg_val)
563 | loss_list.append(loss_val)
564 |
565 | with torch.no_grad(): # test error
566 | test_outputs = net(test_data) # [num_points, 1] as same as test_target
567 | if self.reward_type == 'mse':
568 | test_loss = F.mse_loss(test_outputs, test_target)
569 | elif self.reward_type == 'nrmse':
570 | test_loss = nrmse(test_target, test_outputs)
571 | error_test_val = test_loss.item()
572 | error_test_list.append(error_test_val)
573 | test_outputs = torch.where(torch.isnan(test_outputs), torch.full_like(test_outputs, 100),
574 | test_outputs)
575 | r2 = R_Square(test_target, test_outputs)
576 | r2_test_list.append(r2)
577 |
578 | if self.config.verbose:
579 | print("Epoch: {}\tTotal training loss: {}\tTest {}: {}".format(epoch, loss_val, self.reward_type, error_test_val))
580 |
581 | if np.isnan(loss_val) or loss_val > 1000: # If loss goes to NaN, restart training
582 | restart_flag = True
583 | break
584 |
585 | if epoch == 2000:
586 | scheduler.step() # lr /= 10
587 |
588 | if restart_flag:
589 | break
590 |
591 | scheduler.step() # lr /= 10 again
592 |
593 | for epoch in range(self.config.n_epochs2):
594 | optimizer.zero_grad() # zero the parameter gradients
595 | outputs = net(data)
596 | regularization = L12Smooth(a=0.01)
597 | mse_loss = criterion(outputs, target)
598 | reg_loss = regularization(net.get_weights_tensor())
599 | loss = mse_loss + self.config.reg_weight * reg_loss
600 | loss.backward()
601 |
602 | if self.clip_grad:
603 | grad_norm = log_grad_norm(net)
604 | que.append(grad_norm)
605 | if len(que) > self.window_size:
606 | que.popleft()
607 | clip_threshold = 0.1 * sum(que) / len(que)
608 | torch.nn.utils.clip_grad_norm_(parameters=net.parameters(), max_norm=clip_threshold, norm_type=2)
609 | else:
610 | torch.nn.utils.clip_grad_norm_(parameters=net.parameters(), max_norm=self.max_norm, norm_type=2)
611 |
612 | optimizer.step()
613 |
614 | if epoch % self.config.summary_step == 0:
615 | error_val = mse_loss.item()
616 | reg_val = reg_loss.item()
617 | loss_val = loss.item()
618 | error_list.append(error_val)
619 | reg_list.append(reg_val)
620 | loss_list.append(loss_val)
621 |
622 | with torch.no_grad(): # test error
623 | test_outputs = net(test_data)
624 | if self.reward_type == 'mse':
625 | test_loss = F.mse_loss(test_outputs, test_target)
626 | elif self.reward_type == 'nrmse':
627 | test_loss = nrmse(test_target, test_outputs)
628 | error_test_val = test_loss.item()
629 | error_test_list.append(error_test_val)
630 | test_outputs = torch.where(torch.isnan(test_outputs), torch.full_like(test_outputs, 100),
631 | test_outputs)
632 | r2 = R_Square(test_target, test_outputs)
633 | r2_test_list.append(r2)
634 | if self.config.verbose:
635 | print("Epoch: {}\tTotal training loss: {}\tTest {}: {}".format(epoch, loss_val, self.reward_type, error_test_val))
636 |
637 | if np.isnan(loss_val) or loss_val > 1000: # If loss goes to NaN, restart training
638 | break
639 |
640 | if restart_flag:
641 | # self.play_episodes()
642 | retrain_num += 1
643 | if retrain_num == 5: # only allow 5 restarts
644 | return 10000, None, None
645 | continue
646 |
647 | # After the training, the symbolic network was transformed into an expression by pruning
648 | with torch.no_grad():
649 | weights = net.get_weights()
650 | if self.add_bias:
651 | biases = net.get_biases()
652 | else:
653 | biases = None
654 | expr = pretty_print.network(weights, self.funcs_per_layer, self.vars_name, self.threshold, self.add_bias, biases)
655 |
656 | # results of training trials
657 | error_test_final.append(error_test_list[-1])
658 | r2_test_final.append(r2_test_list[-1])
659 | eq_list.append(expr)
660 |
661 | trial += 1
662 |
663 | error_expr_sorted = sorted(zip(error_test_final, r2_test_final, eq_list), key=lambda x: x[0]) # List of (error, r2, expr)
664 | print('error_expr_sorted', error_expr_sorted)
665 |
666 | return error_expr_sorted[0]
667 |
668 | def load_data(self, path):
669 | data = pd.read_csv(path)
670 |
671 | if data.shape[1] < 2:
672 | raise ValueError('CSV file must contain at least 2 columns.')
673 |
674 | x_data = data.iloc[:, :-1]
675 | y_data = data.iloc[:, -1:]
676 |
677 | X = torch.tensor(x_data.values, dtype=torch.float32)
678 | y = torch.tensor(y_data.values, dtype=torch.float32)
679 |
680 | return X, y
681 |
682 |
683 | if __name__ == "__main__":
684 | # Configure the parameters
685 | config = Params()
686 |
687 | # Example 1: Input ground truth expression
688 | SR = SymboliRegression(config=config, func="x_1 + x_2", func_name="x_1+x_2")
689 | eq, R2, error, relative_error = SR.solve_environment()
690 | print('Expression: ', eq)
691 | print('R2: ', R2)
692 | print('error: ', error)
693 | print('relative_error: ', relative_error)
694 | print('log(1 + MSE): ', np.log(1 + error))
695 |
696 | # Example 2: Input data path of csv file
697 | # SR = SymboliRegression(config=config, func_name="Nguyen-1", data_path="./data/Nguyen-1.csv")
698 | # eq, R2, error, relative_error = SR.solve_environment()
699 | # print('Expression: ', eq)
700 | # print('R2: ', R2)
701 | # print('error: ', error)
702 | # print('relative_error: ', relative_error)
703 | # print('log(1 + MSE): ', np.log(1 + error))
--------------------------------------------------------------------------------
/DySymNet/__init__.py:
--------------------------------------------------------------------------------
1 | name = "DySymNet"
--------------------------------------------------------------------------------
/DySymNet/scripts/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AILWQ/DySymNet/3f356066be8b11dd8a9692d26a27aed876accdf3/DySymNet/scripts/__init__.py
--------------------------------------------------------------------------------
/DySymNet/scripts/controller.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.distributions import Categorical
5 | from torch.nn.functional import one_hot, log_softmax
6 |
7 |
8 | class Agent(nn.Module):
9 |
10 | def __init__(self, auto, input_size, hidden_size, num_funcs_avail, n_layers, num_funcs_layer, device=None, dtype=torch.float32):
11 | super(Agent, self).__init__()
12 | self.auto = auto
13 | self.num_funcs_avail = num_funcs_avail # Optional operator category per layer
14 | self.n_layers = n_layers # Optional number of layers
15 | self.num_funcs_layer = num_funcs_layer # Optional number of operators per layer
16 | self.dtype = dtype
17 |
18 | if device is not None:
19 | self.device = device
20 | else:
21 | self.device = 'cpu'
22 |
23 | if self.auto:
24 | self.n_layer_decoder = nn.Linear(hidden_size, len(self.n_layers), device=device)
25 | self.num_funcs_layer_decoder = nn.Linear(hidden_size, len(self.num_funcs_layer), device=device)
26 | self.max_input_size = max(len(self.n_layers), len(self.num_funcs_layer))
27 | self.dynamic_lstm_cell = nn.LSTMCell(self.max_input_size, hidden_size, device=device)
28 | self.embedding = nn.Linear(self.num_funcs_avail, len(self.num_funcs_layer), device=device)
29 |
30 | self.lstm_cell = nn.LSTMCell(input_size, hidden_size, device=device)
31 | self.decoder = nn.Linear(hidden_size, self.num_funcs_avail, device=device) # output probability distribution
32 | self.n_steps = n_layers
33 | self.hidden_size = hidden_size
34 | self.hidden = self.init_hidden()
35 |
36 | def init_hidden(self):
37 | h_t = torch.zeros(1, self.hidden_size, dtype=self.dtype, device=self.device) # [batch_size, hidden_size]
38 | c_t = torch.zeros(1, self.hidden_size, dtype=self.dtype, device=self.device) # [batch_size, hidden_size]
39 |
40 | return h_t, c_t
41 |
42 | def forward(self, input):
43 |
44 | if self.auto:
45 | if input.shape[-1] < self.max_input_size:
46 | input = nn.functional.pad(input, (0, self.max_input_size - input.shape[-1]), 'constant', 0)
47 |
48 | assert input.shape[-1] == self.max_input_size, 'Error: the input dim of the first step is not equal to the max dim'
49 |
50 | h_t, c_t = self.hidden
51 |
52 | # Sample the number of layers first
53 | h_t, c_t = self.dynamic_lstm_cell(input, (h_t, c_t)) # [batch_size, hidden_size]
54 | n_layer_logits = self.n_layer_decoder(h_t) # [batch_size, len(n_layers)]
55 | n_layer_probs = F.softmax(n_layer_logits, dim=-1)
56 | dist = Categorical(probs=n_layer_probs)
57 | action_index1 = dist.sample()
58 | log_prob1 = dist.log_prob(action_index1)
59 | entropy1 = dist.entropy()
60 | num_layers = self.n_layers[action_index1]
61 |
62 | # Sample the number of operators per layer
63 | input = n_layer_logits
64 | if input.shape[-1] < self.max_input_size:
65 | input = nn.functional.pad(input, (0, self.max_input_size - input.shape[-1]), 'constant', 0)
66 | h_t, c_t = self.dynamic_lstm_cell(input, (h_t, c_t))
67 | n_funcs_layer_logits = self.num_funcs_layer_decoder(h_t) # [batch_size, len(num_funcs_layer)]
68 | n_funcs_layer_probs = F.softmax(n_funcs_layer_logits, dim=-1)
69 | dist = Categorical(probs=n_funcs_layer_probs)
70 | action_index2 = dist.sample()
71 | log_prob2 = dist.log_prob(action_index2)
72 | entropy2 = dist.entropy()
73 | num_funcs_layer = self.num_funcs_layer[action_index2]
74 |
75 | # Sample the operators
76 | input = n_funcs_layer_logits
77 | if input.shape[-1] < self.max_input_size:
78 | input = nn.functional.pad(input, (0, self.max_input_size - input.shape[-1]), 'constant', 0)
79 |
80 | outputs = []
81 | for t in range(num_layers):
82 | h_t, c_t = self.dynamic_lstm_cell(input, (h_t, c_t))
83 | output = self.decoder(h_t) # [batch_size, len(func_avail)]
84 | outputs.append(output)
85 | input = self.embedding(output)
86 |
87 | outputs = torch.stack(outputs).squeeze(1) # [n_layers, len(funcs)]
88 | probs = F.softmax(outputs, dim=-1)
89 | dist = Categorical(probs=probs)
90 | action_index3 = dist.sample((num_funcs_layer,)).transpose(0, 1) # [num_layers, num_func_layer]
91 | # print("action_index: ", action_index)
92 | log_probs = dist.log_prob(action_index3.transpose(0, 1)).transpose(0, 1) # [num_layers, num_func_layer] compute the log probability of the sampled action
93 | entropies = dist.entropy() # [num_layers] compute the entropy of the action distribution
94 | log_probs, entropies = torch.sum(log_probs), torch.sum(entropies)
95 |
96 | # another way to sample
97 | # probs = F.softmax(episode_logits, dim=-1)
98 | # action_index = torch.multinomial(probs, self.num_func_layer, replacement=True)
99 |
100 | # mask = one_hot(action_index, num_classes=self.input_size).squeeze(1)
101 | # log_probs = log_softmax(episode_logits, dim=-1)
102 | # episode_log_probs = torch.sum(mask.float() * log_probs)
103 |
104 | log_probs = log_probs + log_prob1 + log_prob2
105 | entropies = entropies + entropy1 + entropy2
106 |
107 | return num_layers, num_funcs_layer, action_index3, log_probs, entropies
108 |
109 | # Fix the number of layers and the number of operators per layer, only sample the operators, each layer is different
110 | else:
111 | outputs = []
112 | h_t, c_t = self.hidden
113 |
114 | for i in range(self.n_steps):
115 | h_t, c_t = self.lstm_cell(input, (h_t, c_t))
116 | output = self.decoder(h_t) # [batch_size, num_choices]
117 | outputs.append(output)
118 | input = output
119 |
120 | outputs = torch.stack(outputs).squeeze(1) # [num_steps, num_choices]
121 | probs = F.softmax(outputs, dim=-1)
122 | dist = Categorical(probs=probs)
123 | action_index = dist.sample((self.num_funcs_layer,)).transpose(0, 1) # [num_layers, num_func_layer]
124 | # print("action_index: ", action_index)
125 | log_probs = dist.log_prob(action_index.transpose(0, 1)).transpose(0, 1) # [num_layers, num_func_layer]
126 | entropies = dist.entropy() # [num_layers]
127 | log_probs, entropies = torch.sum(log_probs), torch.sum(entropies)
128 | return action_index, log_probs, entropies
129 |
--------------------------------------------------------------------------------
/DySymNet/scripts/functions.py:
--------------------------------------------------------------------------------
1 | """Functions for use with symbolic regression.
2 |
3 | These functions encapsulate multiple implementations (sympy, Tensorflow, numpy) of a particular function so that the
4 | functions can be used in multiple contexts."""
5 |
6 | import torch
7 | # import tensorflow as tf
8 | import numpy as np
9 | import sympy as sp
10 |
11 |
12 | class BaseFunction:
13 | """Abstract class for primitive functions"""
14 |
15 | def __init__(self, norm=1):
16 | self.norm = norm
17 |
18 | def sp(self, x):
19 | """Sympy implementation"""
20 | return None
21 |
22 | def torch(self, x):
23 | """No need for base function"""
24 | return None
25 |
26 | def tf(self, x):
27 | """Automatically convert sympy to TensorFlow"""
28 | z = sp.symbols('z')
29 | return sp.utilities.lambdify(z, self.sp(z), 'tensorflow')(x)
30 |
31 | def np(self, x):
32 | """Automatically convert sympy to numpy"""
33 | z = sp.symbols('z')
34 | return sp.utilities.lambdify(z, self.sp(z), 'numpy')(x)
35 |
36 |
37 | class Constant(BaseFunction):
38 | def torch(self, x):
39 | return torch.ones_like(x)
40 |
41 | def sp(self, x):
42 | return 1
43 |
44 | def np(self, x):
45 | return np.ones_like(x)
46 |
47 |
48 | class Identity(BaseFunction):
49 | def __init__(self):
50 | super(Identity, self).__init__()
51 | self.name = 'id'
52 |
53 | def torch(self, x):
54 | return x / self.norm # ??
55 |
56 | def sp(self, x):
57 | return x / self.norm
58 |
59 | def np(self, x):
60 | return np.array(x) / self.norm
61 |
62 |
63 | class Square(BaseFunction):
64 | def __init__(self):
65 | super(Square, self).__init__()
66 | self.name = 'pow2'
67 |
68 | def torch(self, x):
69 | return torch.square(x) / self.norm
70 |
71 | def sp(self, x):
72 | return x ** 2 / self.norm
73 |
74 | def np(self, x):
75 | return np.square(x) / self.norm
76 |
77 |
78 | class Pow(BaseFunction):
79 | def __init__(self, power, norm=1):
80 | BaseFunction.__init__(self, norm=norm)
81 | self.power = power
82 | self.name = 'pow{}'.format(int(power))
83 |
84 | def torch(self, x):
85 | return torch.pow(x, self.power) / self.norm
86 |
87 | def sp(self, x):
88 | return x ** self.power / self.norm
89 |
90 |
91 | # class Sin(BaseFunction):
92 | # def torch(self, x):
93 | # return torch.sin(x * 2 * 2 * np.pi) / self.norm
94 | #
95 | # def sp(self, x):
96 | # return sp.sin(x * 2 * 2 * np.pi) / self.norm
97 |
98 | class Sin(BaseFunction):
99 | def __init__(self):
100 | super().__init__()
101 | self.name = 'sin'
102 |
103 | def torch(self, x):
104 | return torch.sin(x) / self.norm
105 |
106 | def sp(self, x):
107 | return sp.sin(x) / self.norm
108 |
109 |
110 | class Cos(BaseFunction):
111 | def __init__(self):
112 | super(Cos, self).__init__()
113 | self.name = 'cos'
114 |
115 | def torch(self, x):
116 | return torch.cos(x) / self.norm
117 |
118 | def sp(self, x):
119 | return sp.cos(x) / self.norm
120 |
121 |
122 | class Tan(BaseFunction):
123 | def __init__(self):
124 | super(Tan, self).__init__()
125 | self.name = 'tan'
126 |
127 | def torch(self, x):
128 | return torch.tan(x) / self.norm
129 |
130 | def sp(self, x):
131 | return sp.tan(x) / self.norm
132 |
133 |
134 | class Sigmoid(BaseFunction):
135 | def torch(self, x):
136 | return torch.sigmoid(x) / self.norm
137 |
138 | # def tf(self, x):
139 | # return tf.sigmoid(x) / self.norm
140 |
141 | def sp(self, x):
142 | return 1 / (1 + sp.exp(-20 * x)) / self.norm
143 |
144 | def np(self, x):
145 | return 1 / (1 + np.exp(-20 * x)) / self.norm
146 |
147 | def name(self, x):
148 | return "sigmoid(x)"
149 |
150 |
151 | # class Exp(BaseFunction):
152 | # def __init__(self, norm=np.e):
153 | # super().__init__(norm)
154 | #
155 | # # ?? why the minus 1
156 | # def torch(self, x):
157 | # return (torch.exp(x) - 1) / self.norm
158 | #
159 | # def sp(self, x):
160 | # return (sp.exp(x) - 1) / self.norm
161 |
162 | class Exp(BaseFunction):
163 | def __init__(self):
164 | super().__init__()
165 | self.name = 'exp'
166 |
167 | # ?? why the minus 1
168 | def torch(self, x):
169 | return torch.exp(x)
170 |
171 | def sp(self, x):
172 | return sp.exp(x)
173 |
174 |
175 | class Log(BaseFunction):
176 | def __init__(self):
177 | super(Log, self).__init__()
178 | self.name = 'log'
179 |
180 | def torch(self, x):
181 | return torch.log(torch.abs(x) + 1e-6) / self.norm
182 |
183 | def sp(self, x):
184 | return sp.log(sp.Abs(x) + 1e-6) / self.norm
185 |
186 |
187 | class Sqrt(BaseFunction):
188 | def __init__(self):
189 | super(Sqrt, self).__init__()
190 | self.name = 'sqrt'
191 |
192 | def torch(self, x):
193 | return torch.sqrt(torch.abs(x)) / self.norm
194 |
195 | def sp(self, x):
196 | return sp.sqrt(sp.Abs(x)) / self.norm
197 |
198 |
199 | class BaseFunction2:
200 | """Abstract class for primitive functions with 2 inputs"""
201 |
202 | def __init__(self, norm=1.):
203 | self.norm = norm
204 |
205 | def sp(self, x, y):
206 | """Sympy implementation"""
207 | return None
208 |
209 | def torch(self, x, y):
210 | return None
211 |
212 | def tf(self, x, y):
213 | """Automatically convert sympy to TensorFlow"""
214 | a, b = sp.symbols('a b')
215 | return sp.utilities.lambdify([a, b], self.sp(a, b), 'tensorflow')(x, y)
216 |
217 | def np(self, x, y):
218 | """Automatically convert sympy to numpy"""
219 | a, b = sp.symbols('a b')
220 | return sp.utilities.lambdify([a, b], self.sp(a, b), 'numpy')(x, y)
221 |
222 | # def name(self, x, y):
223 | # return str(self.sp)
224 |
225 |
226 | class Product(BaseFunction2):
227 | def __init__(self, norm=0.1):
228 | super().__init__(norm=norm)
229 | self.name = '*'
230 |
231 | def torch(self, x, y):
232 | return x * y / self.norm
233 |
234 | def sp(self, x, y):
235 | return x * y / self.norm
236 |
237 |
238 | class Plus(BaseFunction2):
239 | def __init__(self, norm=1.0):
240 | super().__init__(norm=norm)
241 | self.name = '+'
242 |
243 | def torch(self, x, y):
244 | return (x + y) / self.norm
245 |
246 | def sp(self, x, y):
247 | return (x + y) / self.norm
248 |
249 |
250 | class Sub(BaseFunction2):
251 | def __init__(self, norm=1.0):
252 | super().__init__(norm=norm)
253 | self.name = '-'
254 |
255 | def torch(self, x, y):
256 | return (x - y) / self.norm
257 |
258 | def sp(self, x, y):
259 | return (x - y) / self.norm
260 |
261 |
262 | class Div(BaseFunction2):
263 | def __init__(self):
264 | super(Div, self).__init__()
265 | self.name = '/'
266 |
267 | def torch(self, x, y):
268 | return x / (y + 1e-6)
269 |
270 | def sp(self, x, y):
271 | return x / (y + 1e-6)
272 |
273 |
274 | def count_inputs(funcs):
275 | i = 0
276 | for func in funcs:
277 | if isinstance(func, BaseFunction):
278 | i += 1
279 | elif isinstance(func, BaseFunction2):
280 | i += 2
281 | return i
282 |
283 |
284 | def count_double(funcs):
285 | i = 0
286 | for func in funcs:
287 | if isinstance(func, BaseFunction2):
288 | i += 1
289 | return i
290 |
291 |
292 | default_func = [
293 | Product(),
294 | Plus(),
295 | Sin(),
296 | ]
297 |
--------------------------------------------------------------------------------
/DySymNet/scripts/params.py:
--------------------------------------------------------------------------------
1 | from .functions import *
2 |
3 |
4 | class Params:
5 | # Optional operators during sampling
6 | funcs_avail = [Identity(),
7 | Sin(),
8 | Cos(),
9 | # Tan(),
10 | # Exp(),
11 | # Log(),
12 | # Sqrt(),
13 | Square(),
14 | # Pow(3),
15 | # Pow(4),
16 | # Pow(5),
17 | # Pow(6),
18 | Plus(),
19 | Sub(),
20 | Product(),
21 | # Div()
22 | ]
23 | n_layers = [2, 3, 4, 5] # optional number of layers
24 | num_func_layer = [2, 3, 4, 5, 6] # optional number of functions in each layer
25 |
26 | # symbolic network training parameters
27 | learning_rate2 = 1e-2
28 | reg_weight = 5e-3
29 | threshold = 0.05
30 | trials = 1 # training trials of symbolic network
31 | n_epochs1 = 10001
32 | n_epochs2 = 10001
33 | summary_step = 1000
34 | clip_grad = True # clip gradient or not
35 | max_norm = 1 # norm threshold for gradient clipping
36 | window_size = 50 # window size for adaptive gradient clipping
37 | refine_constants = True # refine constants or not
38 | n_restarts = 1 # number of restarts for BFGS optimization
39 | add_bias = False # add bias or not
40 | verbose = True # print training process or not
41 | use_gpu = False # use cuda or not
42 | plot_reward = False # plot reward or not
43 |
44 | # controller parameters
45 | num_epochs = 500
46 | batch_size = 10
47 | if isinstance(n_layers, list) or isinstance(num_func_layer, list):
48 | input_size = max(len(n_layers), len(num_func_layer))
49 | else:
50 | input_size = len(funcs_avail)
51 | optimizer = "Adam"
52 | hidden_size = 32
53 | embedding_size = 16
54 | learning_rate1 = 0.0006
55 | risk_seeking = True
56 | risk_factor = 0.5
57 | entropy_weight = 0.005
58 | reward_type = "mse" # mse, nrmse
59 |
60 | # dataset parameters
61 | N_TRAIN = 100 # Size of training dataset
62 | N_VAL = 100 # Size of validation dataset
63 | NOISE = 0 # Standard deviation of noise for training dataset
64 | DOMAIN = (-1, 1) # Domain of dataset - range from which we sample x. Default (-1, 1)
65 | # DOMAIN = np.array([[0, -1, -1], [1, 1, 1]]) # Use this format if each input variable has a different domain
66 | N_TEST = 100 # Size of test dataset
67 | DOMAIN_TEST = (-1, 1) # Domain of test dataset - should be larger than training domain to test extrapolation. Default (-2, 2)
68 | var_names = [f'x_{i}' for i in range(1, 21)] # not used
69 |
70 | # save path
71 | results_dir = './results/test'
72 |
--------------------------------------------------------------------------------
/DySymNet/scripts/pretty_print.py:
--------------------------------------------------------------------------------
1 | """
2 | Generate a mathematical expression of the symbolic regression network (AKA EQL network) using SymPy. This expression
3 | can be used to pretty-print the expression (including human-readable text, LaTeX, etc.). SymPy also allows algebraic
4 | manipulation of the expression.
5 | The main function is network(...)
6 | There are several filtering functions to simplify expressions, although these are not always needed if the weight matrix
7 | is already pruned.
8 | """
9 | import pdb
10 |
11 | import sympy as sp
12 | from . import functions
13 |
14 |
15 | def apply_activation(W, funcs, n_double=0):
16 | """Given an (n, m) matrix W and (m) vector of funcs, apply funcs to W.
17 |
18 | Arguments:
19 | W: (n, m) matrix
20 | funcs: list of activation functions (SymPy functions)
21 | n_double: Number of activation functions that take in 2 inputs
22 |
23 | Returns:
24 | SymPy matrix with 1 column that represents the output of applying the activation functions.
25 | """
26 | W = sp.Matrix(W)
27 | if n_double == 0:
28 | for i in range(W.shape[0]):
29 | for j in range(W.shape[1]):
30 | W[i, j] = funcs[j](W[i, j])
31 | else:
32 | W_new = W.copy()
33 | out_size = len(funcs)
34 | for i in range(W.shape[0]):
35 | in_j = 0
36 | out_j = 0
37 | while out_j < out_size - n_double:
38 | W_new[i, out_j] = funcs[out_j](W[i, in_j])
39 | in_j += 1
40 | out_j += 1
41 | while out_j < out_size:
42 | W_new[i, out_j] = funcs[out_j](W[i, in_j], W[i, in_j + 1])
43 | in_j += 2
44 | out_j += 1
45 | for i in range(n_double):
46 | W_new.col_del(-1)
47 | W = W_new
48 | return W
49 |
50 |
51 | def sym_pp(W_list, funcs, var_names, threshold=0.01, n_double=None, add_bias=False, biases=None):
52 | """Pretty print the hidden layers (not the last layer) of the symbolic regression network
53 |
54 | Arguments:
55 | W_list: list of weight matrices for the hidden layers
56 | funcs: dict of lambda functions using sympy. has the same size as W_list[i][j, :]
57 | var_names: list of strings for names of variables
58 | threshold: threshold for filtering expression. set to 0 for no filtering.
59 | n_double: list Number of activation functions that take in 2 inputs
60 |
61 | Returns:
62 | Simplified sympy expression.
63 | """
64 | vars = []
65 | for var in var_names:
66 | if isinstance(var, str):
67 | vars.append(sp.Symbol(var))
68 | else:
69 | vars.append(var)
70 | try:
71 | expr = sp.Matrix(vars).T
72 |
73 | if add_bias and biases is not None:
74 | assert len(W_list) == len(biases), "The number of biases must be equal to the number of weights."
75 | for i, (W, b) in enumerate(zip(W_list, biases)):
76 | W = filter_mat(sp.Matrix(W), threshold=threshold)
77 | b = filter_mat(sp.Matrix(b), threshold=threshold)
78 | expr = expr * W + b
79 | expr = apply_activation(expr, funcs[i + 1], n_double=n_double[i])
80 |
81 | else:
82 | for i, W in enumerate(W_list):
83 | W = filter_mat(sp.Matrix(W), threshold=threshold) # Pruning
84 | expr = expr * W
85 | expr = apply_activation(expr, funcs[i + 1], n_double=n_double[i])
86 | except:
87 | pdb.set_trace()
88 | # expr = expr * W_list[-1]
89 | return expr
90 |
91 |
92 | def last_pp(eq, W, add_bias=False, biases=None):
93 | """Pretty print the last layer."""
94 | if add_bias and biases is not None:
95 | return eq * filter_mat(sp.Matrix(W)) + filter_mat(sp.Matrix(biases))
96 | else:
97 | return eq * filter_mat(sp.Matrix(W))
98 |
99 |
100 | def network(weights, funcs, var_names, threshold=0.01, add_bias=False, biases=None):
101 | """Pretty print the entire symbolic regression network.
102 |
103 | Arguments:
104 | weights: list of weight matrices for the entire network
105 | funcs: dict of lambda functions using sympy. has the same size as W_list[i][j, :]
106 | var_names: list of strings for names of variables
107 | threshold: threshold for filtering expression. set to 0 for no filtering.
108 |
109 | Returns:
110 | Simplified sympy expression."""
111 | n_double = [functions.count_double(funcs_per_layer) for funcs_per_layer in funcs.values()]
112 | # translate operators to sympy operators
113 | sp_funcs = {}
114 | for key, value in funcs.items():
115 | sp_value = [func.sp for func in value]
116 | sp_funcs.update({key: sp_value})
117 |
118 | if add_bias and biases is not None:
119 |         assert len(weights) == len(biases), "The number of biases must be equal to the number of weights."
120 | expr = sym_pp(weights[:-1], sp_funcs, var_names, threshold=threshold, n_double=n_double, add_bias=add_bias, biases=biases[:-1])
121 | expr = last_pp(expr, weights[-1], add_bias=add_bias, biases=biases[-1])
122 | else:
123 | expr = sym_pp(weights[:-1], sp_funcs, var_names, threshold=threshold, n_double=n_double, add_bias=add_bias)
124 | expr = last_pp(expr, weights[-1], add_bias=add_bias)
125 |
126 | try:
127 | expr = expr[0, 0]
128 | return expr
129 | except Exception as e:
130 | print("An exception occurred:", e)
131 |
132 |
133 |
134 | def filter_mat(mat, threshold=0.01):
135 | """Remove elements of a matrix below a threshold."""
136 | for i in range(mat.shape[0]):
137 | for j in range(mat.shape[1]):
138 | if abs(mat[i, j]) < threshold:
139 | mat[i, j] = 0
140 | return mat
141 |
142 |
143 | def filter_expr(expr, threshold=0.01):
144 | """Remove additive terms with coefficient below threshold
145 | TODO: Make more robust. This does not work in all cases."""
146 | expr_new = sp.Integer(0)
147 | for arg in expr.args:
148 | if arg.is_constant() and abs(arg) > threshold: # hack way to check if it's a number
149 | expr_new = expr_new + arg
150 | elif not arg.is_constant() and abs(arg.args[0]) > threshold:
151 | expr_new = expr_new + arg
152 | return expr_new
153 |
154 |
155 | def filter_expr2(expr, threshold=0.01):
156 | """Sets all constants under threshold to 0
157 | TODO: Test"""
158 | for a in sp.preorder_traversal(expr):
159 | if isinstance(a, sp.Float) and a < threshold:
160 | expr = expr.subs(a, 0)
161 | return expr
162 |
--------------------------------------------------------------------------------
/DySymNet/scripts/regularization.py:
--------------------------------------------------------------------------------
1 | """Methods for regularization to produce sparse networks.
2 |
3 | L2 regularization mostly penalizes the weight magnitudes without introducing sparsity.
4 | L1 regularization promotes sparsity.
5 | L1/2 promotes sparsity even more than L1. However, it can be difficult to train due to non-convexity and exploding
6 | gradients close to 0. Thus, we introduce a smoothed L1/2 regularization to remove the exploding gradients."""
7 |
8 | import torch
9 | import torch.nn as nn
10 |
11 |
12 | class L12Smooth(nn.Module):
13 | def __init__(self, a):
14 | super(L12Smooth, self).__init__()
15 | self.a = a
16 |
17 | def forward(self, input_tensor):
18 | """input: predictions"""
19 | return self.l12_smooth(input_tensor, self.a)
20 |
21 | def l12_smooth(self, input_tensor, a=0.05):
22 | """Smoothed L1/2 norm"""
23 | if type(input_tensor) == list:
24 |             return sum([self.l12_smooth(tensor, a) for tensor in input_tensor])
25 |
26 | smooth_abs = torch.where(torch.abs(input_tensor) < a,
27 | torch.pow(input_tensor, 4) / (-8 * a ** 3) + torch.square(input_tensor) * 3 / 4 / a + 3 * a / 8,
28 | torch.abs(input_tensor))
29 |
30 | return torch.sum(torch.sqrt(smooth_abs))
31 |
--------------------------------------------------------------------------------
/DySymNet/scripts/symbolic_network.py:
--------------------------------------------------------------------------------
1 | """Contains the symbolic regression neural network architecture."""
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | from . import functions as functions
6 |
7 |
8 | class SymbolicLayer(nn.Module):
9 | """Neural network layer for symbolic regression where activation functions correspond to primitive functions.
10 | Can take multi-input activation functions (like multiplication)"""
11 |
12 | def __init__(self, funcs=None, initial_weight=None, init_stddev=0.1, in_dim=None, add_bias=False):
13 | """
14 | funcs: List of activation functions, using utils.functions
15 | initial_weight: (Optional) Initial value for weight matrix
16 |             add_bias: (Optional) whether to add a bias term to the layer output
17 | init_stddev: (Optional) if initial_weight isn't passed in, this is standard deviation of initial weight
18 | """
19 | super().__init__()
20 |
21 | if funcs is None:
22 | funcs = functions.default_func
23 | self.initial_weight = initial_weight
24 | self.W = None # Weight matrix
25 | self.built = False # Boolean whether weights have been initialized
26 | self.add_bias = add_bias
27 |
28 | self.output = None # tensor for layer output
29 | self.n_funcs = len(funcs) # Number of activation functions (and number of layer outputs)
30 | self.funcs = [func.torch for func in funcs] # Convert functions to list of PyTorch functions
31 | self.n_double = functions.count_double(funcs) # Number of activation functions that take 2 inputs
32 | self.n_single = self.n_funcs - self.n_double # Number of activation functions that take 1 input
33 |
34 | self.out_dim = self.n_funcs + self.n_double
35 |
36 | if self.initial_weight is not None: # use the given initial weight
37 | self.W = nn.Parameter(self.initial_weight.clone().detach()) # copies
38 | self.built = True
39 | else:
40 | self.W = nn.Parameter(torch.fmod(torch.normal(mean=0.0, std=init_stddev, size=(in_dim, self.out_dim)), 2))
41 | if add_bias:
42 | self.b = nn.Parameter(torch.fmod(torch.normal(mean=0.0, std=init_stddev, size=(1, self.out_dim)), 2))
43 |
44 | def forward(self, x): # used to be __call__
45 | """Multiply by weight matrix and apply activation units"""
46 |
47 | if self.add_bias:
48 | g = torch.matmul(x, self.W) + self.b
49 | else:
50 | g = torch.matmul(x, self.W)
51 | self.output = []
52 |
53 | in_i = 0 # input index
54 | out_i = 0 # output index
55 | # Apply functions with only a single input, binary operators must come after unary operators
56 | while out_i < self.n_single:
57 | self.output.append(self.funcs[out_i](g[:, in_i])) # g[:, in_i] is the input to the activation function
58 | in_i += 1
59 | out_i += 1
60 | # Apply functions that take 2 inputs and produce 1 output
61 | while out_i < self.n_funcs:
62 | self.output.append(self.funcs[out_i](g[:, in_i], g[:, in_i + 1]))
63 | in_i += 2
64 | out_i += 1
65 |
66 | self.output = torch.stack(self.output, dim=1) # [n_points, n_funcs]
67 |
68 | return self.output
69 |
70 | def get_weight(self):
71 | return self.W.cpu().detach().numpy()
72 |
73 | def get_bias(self):
74 | return self.b.cpu().detach().numpy()
75 |
76 | def get_weight_tensor(self):
77 | return self.W.clone()
78 |
79 |
80 | class SymbolicNet(nn.Module):
81 | """Symbolic regression network with multiple layers. Produces one output."""
82 |
83 | def __init__(self, symbolic_depth, x_dim, funcs=None, initial_weights=None, init_stddev=0.1, add_bias=False):
84 | super(SymbolicNet, self).__init__()
85 |
86 | self.depth = symbolic_depth # symbolic network depths
87 | self.funcs = funcs # operators for each layer sampled by controller,{id: []}
88 | self.add_bias = add_bias # add bias or not
89 | layer_in_dim = [x_dim] + [len(funcs[i+1]) for i in range(self.depth)]
90 |
91 | if initial_weights is not None:
92 | layers = [SymbolicLayer(funcs=funcs[i+1], initial_weight=initial_weights[i], in_dim=layer_in_dim[i], add_bias=self.add_bias) for i in range(self.depth)]
93 | self.output_weight = nn.Parameter(initial_weights[-1].clone().detach())
94 |
95 | else:
96 | # Each layer initializes its own weights
97 | if not isinstance(init_stddev, list):
98 | init_stddev = [init_stddev] * self.depth
99 | layers = [SymbolicLayer(funcs=self.funcs[i+1], init_stddev=init_stddev[i], in_dim=layer_in_dim[i], add_bias=self.add_bias)
100 | for i in range(self.depth)]
101 | # Initialize weights for last layer (without activation functions)
102 | self.output_weight = nn.Parameter(torch.rand((layers[-1].n_funcs, 1)))
103 | if add_bias:
104 | self.output_bias = nn.Parameter(torch.rand((1, 1)))
105 |
106 | self.hidden_layers = nn.Sequential(*layers)
107 |
108 | def forward(self, input):
109 | h = self.hidden_layers(input) # Building hidden layers
110 | return torch.matmul(h, self.output_weight) # Final output (no activation units) of network
111 |
112 | def get_weights(self):
113 | """Return list of weight matrices"""
114 | # First part is iterating over hidden weights. Then append the output weight.
115 | return [self.hidden_layers[i].get_weight() for i in range(self.depth)] + \
116 | [self.output_weight.cpu().detach().numpy()]
117 |
118 | def get_biases(self):
119 | return [self.hidden_layers[i].get_bias() for i in range(self.depth)] + \
120 | [self.output_bias.cpu().detach().numpy()]
121 |
122 | def get_weights_tensor(self):
123 | """Return list of weight matrices as tensors"""
124 | return [self.hidden_layers[i].get_weight_tensor() for i in range(self.depth)] + \
125 | [self.output_weight.clone()]
126 |
127 |
128 | class SymbolicLayerL0(SymbolicLayer):
129 | def __init__(self, in_dim=None, funcs=None, initial_weight=None, init_stddev=0.1,
130 | bias=False, droprate_init=0.5, lamba=1.,
131 | beta=2 / 3, gamma=-0.1, zeta=1.1, epsilon=1e-6):
132 | super().__init__(in_dim=in_dim, funcs=funcs, initial_weight=initial_weight, init_stddev=init_stddev)
133 |
134 | self.droprate_init = droprate_init if droprate_init != 0 else 0.5
135 | self.use_bias = bias
136 | self.lamba = lamba
137 | self.bias = None
138 | self.in_dim = in_dim
139 | self.eps = None
140 |
141 | self.beta = beta
142 | self.gamma = gamma
143 | self.zeta = zeta
144 | self.epsilon = epsilon
145 |
146 | if self.use_bias:
147 | self.bias = nn.Parameter(0.1 * torch.ones((1, self.out_dim)))
148 | self.qz_log_alpha = nn.Parameter(torch.normal(mean=np.log(1 - self.droprate_init) - np.log(self.droprate_init),
149 | std=1e-2, size=(in_dim, self.out_dim)))
150 |
151 | def quantile_concrete(self, u):
152 | """Quantile, aka inverse CDF, of the 'stretched' concrete distribution"""
153 | y = torch.sigmoid((torch.log(u) - torch.log(1.0 - u) + self.qz_log_alpha) / self.beta)
154 | return y * (self.zeta - self.gamma) + self.gamma
155 |
156 | def sample_u(self, shape, reuse_u=False):
157 | """Uniform random numbers for concrete distribution"""
158 | if self.eps is None or not reuse_u:
159 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
160 | self.eps = torch.rand(size=shape).to(device) * (1 - 2 * self.epsilon) + self.epsilon
161 | return self.eps
162 |
163 | def sample_z(self, batch_size, sample=True):
164 | """Use the hard concrete distribution as described in https://arxiv.org/abs/1712.01312"""
165 | if sample:
166 | eps = self.sample_u((batch_size, self.in_dim, self.out_dim))
167 | z = self.quantile_concrete(eps)
168 | return torch.clamp(z, min=0, max=1)
169 | else: # Mean of the hard concrete distribution
170 | pi = torch.sigmoid(self.qz_log_alpha)
171 | return torch.clamp(pi * (self.zeta - self.gamma) + self.gamma, min=0.0, max=1.0)
172 |
173 | def get_z_mean(self):
174 | """Mean of the hard concrete distribution"""
175 | pi = torch.sigmoid(self.qz_log_alpha)
176 | return torch.clamp(pi * (self.zeta - self.gamma) + self.gamma, min=0.0, max=1.0)
177 |
178 | def sample_weights(self, reuse_u=False):
179 | z = self.quantile_concrete(self.sample_u((self.in_dim, self.out_dim), reuse_u=reuse_u))
180 | mask = torch.clamp(z, min=0.0, max=1.0)
181 | return mask * self.W
182 |
183 | def get_weight(self):
184 | """Deterministic value of weight based on mean of z"""
185 | return self.W * self.get_z_mean()
186 |
187 | def loss(self):
188 | """Regularization loss term"""
189 | return torch.sum(torch.sigmoid(self.qz_log_alpha - self.beta * np.log(-self.gamma / self.zeta)))
190 |
191 | def forward(self, x, sample=True, reuse_u=False):
192 | """Multiply by weight matrix and apply activation units"""
193 | if sample:
194 | h = torch.matmul(x, self.sample_weights(reuse_u=reuse_u))
195 | else:
196 | w = self.get_weight()
197 | h = torch.matmul(x, w)
198 |
199 | if self.use_bias:
200 | h = h + self.bias
201 |
202 | # shape of h = (?, self.n_funcs)
203 |
204 | output = []
205 | # apply a different activation unit to each column of h
206 | in_i = 0 # input index
207 | out_i = 0 # output index
208 | # Apply functions with only a single input
209 | while out_i < self.n_single:
210 | output.append(self.funcs[out_i](h[:, in_i]))
211 | in_i += 1
212 | out_i += 1
213 | # Apply functions that take 2 inputs and produce 1 output
214 | while out_i < self.n_funcs:
215 | output.append(self.funcs[out_i](h[:, in_i], h[:, in_i + 1]))
216 | in_i += 2
217 | out_i += 1
218 | output = torch.stack(output, dim=1)
219 | return output
220 |
221 |
222 | class SymbolicNetL0(nn.Module):
223 | """Symbolic regression network with multiple layers. Produces one output."""
224 |
225 | def __init__(self, symbolic_depth, in_dim=1, funcs=None, initial_weights=None, init_stddev=0.1):
226 | super(SymbolicNetL0, self).__init__()
227 | self.depth = symbolic_depth # Number of hidden layers
228 | self.funcs = funcs
229 |
230 | layer_in_dim = [in_dim] + self.depth * [len(funcs)]
231 | if initial_weights is not None:
232 | layers = [SymbolicLayerL0(funcs=funcs, initial_weight=initial_weights[i],
233 | in_dim=layer_in_dim[i])
234 | for i in range(self.depth)]
235 | self.output_weight = nn.Parameter(initial_weights[-1].clone().detach())
236 | else:
237 | # Each layer initializes its own weights
238 | if not isinstance(init_stddev, list):
239 | init_stddev = [init_stddev] * self.depth
240 | layers = [SymbolicLayerL0(funcs=funcs, init_stddev=init_stddev[i], in_dim=layer_in_dim[i])
241 | for i in range(self.depth)]
242 | # Initialize weights for last layer (without activation functions)
243 |             self.output_weight = nn.Parameter(torch.rand(size=(layers[-1].n_funcs, 1)) * 2)  # self.hidden_layers is not built yet, so use the local list
244 |
245 | self.hidden_layers = nn.Sequential(*layers)
246 |
247 | def forward(self, input, sample=True, reuse_u=False):
248 | # connect output from previous layer to input of next layer
249 | h = input
250 | for i in range(self.depth):
251 | h = self.hidden_layers[i](h, sample=sample, reuse_u=reuse_u)
252 |
253 | h = torch.matmul(h, self.output_weight) # Final output (no activation units) of network
254 | return h
255 |
256 | def get_loss(self):
257 | return torch.sum(torch.stack([self.hidden_layers[i].loss() for i in range(self.depth)]))
258 |
259 | def get_weights(self):
260 | """Return list of weight matrices"""
261 | # First part is iterating over hidden weights. Then append the output weight.
262 | return [self.hidden_layers[i].get_weight().cpu().detach().numpy() for i in range(self.depth)] + \
263 | [self.output_weight.cpu().detach().numpy()]
264 |
--------------------------------------------------------------------------------
/DySymNet/scripts/utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import numpy as np
3 | import pandas as pd
4 | import torch
5 | from sklearn import feature_selection
6 |
7 |
8 | def nrmse(y_true, y_pred):
9 | """y, y_pred should be (num_samples,)"""
10 | assert y_true.shape == y_pred.shape, "y_true and y_pred must have the same shape"
11 | var = torch.var(y_true)
12 | return (torch.sqrt(torch.mean((y_true - y_pred) ** 2)) / var).item()
13 |
14 |
15 | def MSE(y, y_pred):
16 | return torch.mean(torch.square(y - y_pred)).item()
17 |
18 |
19 | def Relative_Error(y, y_pred):
20 | return torch.mean(torch.abs((y - y_pred) / y)).item()
21 |
22 |
23 | def nrmse_np(y_true, y_pred):
24 | """y, y_pred should be (num_samples,)"""
25 | assert y_true.shape == y_pred.shape, "y_true and y_pred must have the same shape"
26 | var = np.var(y_true)
27 | return np.sqrt(np.mean((y_true - y_pred) ** 2)) / var
28 |
29 |
30 | def R_Square(y, y_pred):
31 | """y, y_pred should be same shape (num_samples,) or (num_samples, 1)"""
32 | return (1 - torch.sum(torch.square(y - y_pred)) / torch.sum(torch.square(y - torch.mean(y)))).item()
33 |
34 |
35 | def get_logger(filename, verbosity=1, name=None):
36 | level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING}
37 | formatter = logging.Formatter(
38 | "[%(asctime)s][%(filename)s][line:%(lineno)d][%(levelname)s] %(message)s"
39 | )
40 | logger = logging.getLogger(name)
41 | logger.setLevel(level_dict[verbosity])
42 |
43 | fh = logging.FileHandler(filename, "w")
44 | fh.setFormatter(formatter)
45 | logger.addHandler(fh)
46 |
47 | sh = logging.StreamHandler()
48 | sh.setFormatter(formatter)
49 | logger.addHandler(sh)
50 |
51 | return logger
52 |
53 |
54 | def get_top_k_features(X, y, k=10):
55 |     if y.ndim == 2:
56 |         y = y[:, 0]
57 |     if X.shape[1] <= k:  # nothing to select if there are already at most k features
58 |         return [i for i in range(X.shape[1])]
59 |     else:
60 |         kbest = feature_selection.SelectKBest(feature_selection.r_regression, k=k)
61 |         kbest.fit(X, y)
62 |         scores = kbest.scores_
63 |         # scores = corr(X, y)
64 |         top_features = np.argsort(-np.abs(scores))
65 |         print("keeping only the top-{} features. Order was {}".format(k, top_features))
66 |         return list(top_features[:k])
67 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) [2024] [Wenqiang Li]
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## DySymNet
2 |
3 |
4 | 
5 |
6 |
7 | 
8 |
9 | This repository contains the official Pytorch implementation for the paper [***A Neural-Guided Dynamic Symbolic Network for Exploring Mathematical Expressions from Data***](https://openreview.net/forum?id=pTmrk4XPFx) accepted by ICML'24.
10 |
11 | [](https://openreview.net/pdf?id=IejxxE9DO2)
12 | [](https://arxiv.org/abs/2309.13705)
13 | 
14 | 
15 |
16 | ## 🔥 News
17 | [2024/10/12] ***Now DySymNet can be installed via 'pip install DySymNet'. You only need 1 command to start exploring expressions!***
18 |
19 |
20 | ## 🚀 Highlights
21 |
22 | - ***DySymNet*** is a new search paradigm for symbolic regression (SR) that searches the symbolic network with various architectures instead of searching expressions in the large functional space.
23 | - ***DySymNet*** possesses promising capabilities in solving high-dimensional problems and optimizing coefficients, which are lacking in current SR methods.
24 | - ***DySymNet*** outperforms state-of-the-art baselines across various standard SR benchmark datasets and the well-known SRBench benchmark with more variables.
25 |
26 | ## 📦 Install
27 |
28 | Install ***DySymNet*** using only one command:
29 |
30 | ```setup
31 | pip install DySymNet
32 | ```
33 |
34 | ## 🤗 Quick start
35 |
36 | You can create and run the following script in any directory:
37 |
38 | ```python
39 | # Demo.py
40 | import numpy as np
41 | from DySymNet import SymbolicRegression
42 | from DySymNet.scripts.params import Params
43 | from DySymNet.scripts.functions import *
44 |
45 | # You can customize some hyperparameters according to parameter configuration
46 | config = Params()
47 |
48 | # such as operators
49 | funcs = [Identity(), Sin(), Cos(), Square(), Plus(), Sub(), Product()]
50 | config.funcs_avail = funcs
51 |
52 | # Example 1: Input ground truth expression
53 | SR = SymbolicRegression.SymboliRegression(config=config, func="x_1**3 + x_1**2 + x_1", func_name="Nguyen-1")
54 | eq, R2, error, relative_error = SR.solve_environment()
55 | print('Expression: ', eq)
56 | print('R2: ', R2)
57 | print('error: ', error)
58 | print('relative_error: ', relative_error)
59 | print('log(1 + MSE): ', np.log(1 + error))
60 | ```
61 |
62 | Running the script creates a folder named "results" in the current directory, with a subfolder named after `func_name` that records the logs of the run.
63 |
64 |
65 |
66 | ## ⚙️ Parameter configuration
67 |
68 | The main script is `SymbolicRegression.py`, and runs are configured via `params.py`, which contains the hyperparameters of the controller RNN and of the symbolic network. You can configure the following hyperparameters as required (a short example of overriding them is given at the end of this section):
69 |
70 | #### parameters for symbolic network structure
71 |
72 | | Parameters | Description | **Example Values** |
73 | | :--------------: | :----------------------------------------------------------: | :----------------: |
74 | | `funcs_avail` | Operator library | See `params.py` |
75 | | `n_layers` | Range of symbolic network layers | [2, 3, 4, 5] |
76 | | `num_func_layer` | Range of the number of neurons per layer of a symbolic network | [2, 3, 4, 5, 6] |
77 |
78 | Note: You can add additional operators in `functions.py` by following the pattern of the existing operators, and include them in `funcs_avail` if you want to use them (a sketch is shown below).
79 |
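As a rough illustration (not part of the repository), a custom unary operator could mirror the built-in operators, assuming `BaseFunction` exposes the same `name` / `torch` / `sp` interface as the binary `BaseFunction2` operators (such as `Div`) in `functions.py`:

```python
# Hypothetical sketch of a cube operator; check BaseFunction's actual
# signature in DySymNet/scripts/functions.py before relying on this.
from DySymNet.scripts.functions import BaseFunction


class Cube(BaseFunction):
    def __init__(self):
        super().__init__()
        self.name = 'cube'

    def torch(self, x):  # used when evaluating the symbolic network
        return x ** 3

    def sp(self, x):     # used when converting the trained network to a SymPy expression
        return x ** 3
```

An instance of such a class can then be added to `funcs_avail`, either in `params.py` or on a `Params` object, e.g. `config.funcs_avail = [Identity(), Sin(), Cube(), Plus(), Product()]`.
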
80 | #### parameters for controller RNN
81 |
82 | | Parameters | Description | **Example Values** |
83 | | :--------------: | :---------------------------------------: | :----------------: |
84 | | `num_epochs` | epochs for sampling | 500 |
85 | | `batch_size` | Size for a batch sampling | 10 |
86 | | `optimizer` | Optimizer for training RNN | Adam |
87 | | `hidden_size` | Hidden dim. of RNN layer | 32 |
88 | | `embedding_size` | Embedding dim. | 16 |
89 | | `learning_rate1` | Learning rate for training RNN | 0.0006 |
90 | | `risk_seeking` | using risk seeking policy gradient or not | True |
91 | | `risk_factor` | Risk factor | 0.5 |
92 | | `entropy_weight` | Entropy weight | 0.005 |
93 | | `reward_type` | Loss type for computing reward | mse |
94 |
95 |
96 |
97 | #### parameters for symbolic network training
98 |
99 | | Parameters | Description | **Example Values** |
100 | | :----------------: | :-------------------------------------------: | :----------------: |
101 | | `learning_rate2` | Learning rate for training symbolic network | 0.01 |
102 | |   `reg_weight`    |         Regularization weight          |        5e-3        |
103 | |    `threshold`    |            Pruning threshold           |        0.05        |
104 | | `trials` | Training trials for training symbolic network | 1 |
105 | | `n_epochs1` | Epochs for the first training stage | 10001 |
106 | | `n_epochs2` | Epochs for the second training stage | 10001 |
107 | | `summary_step` | Summary for every `n` training steps | 1000 |
108 | | `clip_grad` | Using adaptive gradient clipping or not | True |
109 | | `max_norm` | Norm threshold for gradient clipping | 1.0 |
110 | | `window_size` | Window size for adaptive gradient clipping | 50 |
111 | | `refine_constants` | Refining constants or not | True |
112 | | `n_restarts` | Number of restarts for BFGS optimization | 1 |
113 | | `add_bias` | adding bias or not | False |
114 | | `verbose` | Print training process or not | True |
115 | | `use_gpu` | Using cuda or not | False |
116 | | `plot_reward` | Plot reward curve or not | False |
117 |
118 | **Note:** `threshold` controls the complexity of the final expression; it trades off complexity against precision, and you can tune it according to your requirements.
119 |
120 | #### parameters for generating input data
121 |
122 | | Parameters | Description | **Example Values** |
123 | | :-----------: | :----------------------------------------: | :----------------: |
124 | | `N_TRAIN` | Size of input data | 100 |
125 | | `N_VAL` | Size of validation dataset | 100 |
126 | | `NOISE` | Standard deviation of noise for input data | 0 |
127 | | `DOMAIN` | Domain of input data | (-1, 1) |
128 | | `N_TEST` | Size of test dataset | 100 |
129 | | `DOMAIN_TEST` | Domain of test dataset | (-1, 1) |
130 |
131 | #### Additional parameters
132 |
133 | `results_dir` configures the save path for all results
134 |
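As an illustration (the values below are hypothetical; the attribute names are those defined in `params.py`), a run can be customized by overriding fields on a `Params` instance before constructing the regressor:

```python
from DySymNet.scripts.params import Params
from DySymNet.scripts.functions import Identity, Sin, Cos, Square, Plus, Sub, Product

config = Params()
config.funcs_avail = [Identity(), Sin(), Cos(), Square(), Plus(), Sub(), Product()]  # operator library
config.threshold = 0.01                   # lower pruning threshold: more precise, but more complex expressions
config.N_TRAIN = 256                      # more training points
config.DOMAIN = (-2, 2)                   # wider sampling domain for the training data
config.results_dir = './results/my_run'   # where logs and discovered expressions are written
```
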
135 | ## 🤖 Symbolic Regression
136 |
137 | We provide two ways to perform symbolic regression tasks.
138 |
139 | #### Option1: Input ground truth expression
140 |
141 | When you want to discover an expression for which the ground truth is known, for example to test a standard benchmark, you can edit the script `SymbolicRegression.py` as follows:
142 |
143 | ```python
144 | # SymbolicRegression.py
145 | params = Params() # configuration for a specific task
146 | ground_truth_eq = "x_1 + x_2" # variable names should be written as x_i, where i>=1.
147 | eq_name = "x_1+x_2"
148 | SR = SymboliRegression(config=params, func=ground_truth_eq, func_name=eq_name)  # a new folder named after func_name will be created to store the result files
149 | eq, R2, error, relative_error = SR.solve_environment() # return results
150 | ```
151 |
152 | In this case, the function `generate_data` automatically generates the corresponding dataset $\mathcal{D}(X, y)$ for you, so you do not need to prepare the data yourself.
153 |
154 | Then, you can run `SymbolicRegression.py` directly, or you can run it in the terminal as follows:
155 |
156 | ```bash
157 | python SymbolicRegression.py
158 | ```
159 |
160 | After running this script, the results will be stored in path `./results/test/func_name`.
161 |
162 | #### Option2: Load the data file
163 |
164 | When you only have observed data and do not know the ground truth, you can perform symbolic regression by entering the path to the csv data file:
165 |
166 | ```python
167 | # SymbolicRegression.py
168 | params = Params() # configuration for a specific task
169 | data_path = './data/Nguyen-1.csv' # data file should be in csv format
170 | SR = SymboliRegression(config=params, func_name='Nguyen-1', data_path=data_path)  # you can rename func_name to anything you like
171 | eq, R2, error, relative_error = SR.solve_environment() # return results
172 | ```
173 |
174 | **Note:** the data file should contain ($X\_{dim} + 1$) columns, where $X\_{dim}$ is the number of independent variables and the last column holds the corresponding $y$ values (a sketch of writing such a file is shown below).
175 |
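For example, a minimal sketch (not part of the repository) of writing such a file for the Nguyen-1 benchmark, with the input column(s) first and the target column last:

```python
import numpy as np
import pandas as pd

# 100 samples of x_1 in [-1, 1]; the last column is y = x_1**3 + x_1**2 + x_1
X = np.random.uniform(-1, 1, size=(100, 1))
y = X[:, 0] ** 3 + X[:, 0] ** 2 + X[:, 0]
pd.DataFrame(np.column_stack([X, y]), columns=['x_1', 'y']).to_csv('./data/Nguyen-1.csv', index=False)
```
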
176 | Then, you can run `SymbolicRegression.py` directly, or you can run it in the terminal as follows:
177 |
178 | ```bash
179 | python SymbolicRegression.py
180 | ```
181 |
182 | After running this script, the results will be stored in path `./results/test/func_name`.
183 |
184 | #### Output
185 |
186 | Once the script stops early or finishes running, you will get the following output:
187 |
188 | ```
189 | Expression: x_1 + x_2
190 | R2: 1.0
191 | error: 4.3591795754679974e-13
192 | relative_error: 2.036015757767018e-06
193 | log(1 + MSE): 4.3587355946774144e-13
194 | ```
195 |
196 | ## 🌟 Citing this work
197 |
198 | If you find our work and this codebase helpful, please consider starring this repo and citing:
199 |
200 | ```bibtex
201 | @inproceedings{
202 | li2024a,
203 | title={A Neural-Guided Dynamic Symbolic Network for Exploring Mathematical Expressions from Data},
204 | author={Wenqiang Li and Weijun Li and Lina Yu and Min Wu and Linjun Sun and Jingyi Liu and Yanjie Li and Shu Wei and Deng Yusong and Meilan Hao},
205 | booktitle={Forty-first International Conference on Machine Learning},
206 | year={2024},
207 | url={https://openreview.net/forum?id=IejxxE9DO2}
208 | }
209 | ```
210 |
--------------------------------------------------------------------------------
/build/lib/DySymNet/SymbolicRegression.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import time
4 | import os
5 | import torch
6 | import sympy as sp
7 | import pandas as pd
8 | from scipy.optimize import minimize
9 | from .scripts.functions import *
10 | from .scripts import functions as functions
11 | import collections
12 | import numpy as np
13 | import matplotlib.pyplot as plt
14 | from sympy import symbols, Float
15 | from torch import nn, optim
16 | from .scripts.controller import Agent
17 | import torch.nn.functional as F
18 | from .scripts import pretty_print
19 | from .scripts.regularization import L12Smooth
20 | from .scripts.symbolic_network import SymbolicNet
21 | from sklearn.metrics import r2_score
22 | from .scripts.params import Params
23 | from .scripts.utils import nrmse, R_Square, MSE, Relative_Error
24 |
25 |
26 | def generate_data(func, N, range_min, range_max):
27 | """Generates datasets."""
28 | free_symbols = sp.sympify(func).free_symbols
29 | x_dim = free_symbols.__len__()
30 | sp_expr = sp.lambdify(free_symbols, func)
31 | x = (range_max - range_min) * torch.rand([N, x_dim]) + range_min
32 | y = torch.tensor([[sp_expr(*x_i)] for x_i in x])
33 | return x, y
34 |
35 |
36 | class TimedFun:
37 | def __init__(self, fun, stop_after=10):
38 | self.fun_in = fun
39 | self.started = False
40 | self.stop_after = stop_after
41 |
42 | def fun(self, x, *args):
43 | if self.started is False:
44 | self.started = time.time()
45 | elif abs(time.time() - self.started) >= self.stop_after:
46 | raise ValueError("Time is over.")
47 | self.fun_value = self.fun_in(*x, *args) # sp.lambdify()
48 | self.x = x
49 | return self.fun_value
50 |
51 |
52 | class SymboliRegression:
53 | def __init__(self, config, func=None, func_name=None, data_path=None):
54 | """
55 | Args:
56 | config: All configs in the Params class, type: Params
57 | func: the function to be predicted, type: str
58 | func_name: the name of the function, type: str
59 | data_path: the path of the data, type: str
60 | """
61 | self.data_path = data_path
62 | self.X = None
63 | self.y = None
64 | self.funcs_per_layer = None
65 | self.num_epochs = config.num_epochs
66 | self.batch_size = config.batch_size
67 | self.input_size = config.input_size # number of operators
68 | self.hidden_size = config.hidden_size
69 | self.embedding_size = config.embedding_size
70 | self.n_layers = config.n_layers
71 | self.num_func_layer = config.num_func_layer
72 | self.funcs_avail = config.funcs_avail
73 | self.optimizer = config.optimizer
74 | self.auto = False
75 | self.add_bias = config.add_bias
76 | self.threshold = config.threshold
77 |
78 | self.clip_grad = config.clip_grad
79 | self.max_norm = config.max_norm
80 | self.window_size = config.window_size
81 | self.refine_constants = config.refine_constants
82 | self.n_restarts = config.n_restarts
83 | self.reward_type = config.reward_type
84 |
85 | if config.use_gpu:
86 | self.device = torch.device('cuda')
87 | else:
88 | self.device = torch.device('cpu')
89 | print("Use Device:", self.device)
90 |
91 | # Standard deviation of random distribution for weight initializations.
92 | self.init_sd_first = 0.1
93 | self.init_sd_last = 1.0
94 | self.init_sd_middle = 0.5
95 |
96 | self.config = config
97 |
98 | self.func = func
99 | self.func_name = func_name
100 |
101 | # generate data or load data from file
102 | if self.func is not None:
103 | # add noise
104 | if config.NOISE > 0:
105 | self.X, self.y = generate_data(func, self.config.N_TRAIN, self.config.DOMAIN[0], self.config.DOMAIN[1]) # y shape is (N, 1)
106 | y_rms = torch.sqrt(torch.mean(self.y ** 2))
107 | scale = config.NOISE * y_rms
108 | self.y += torch.empty(self.y.shape[-1]).normal_(mean=0, std=scale)
109 | self.x_test, self.y_test = generate_data(func, self.config.N_TRAIN, range_min=self.config.DOMAIN_TEST[0],
110 | range_max=self.config.DOMAIN_TEST[1])
111 |
112 | else:
113 | self.X, self.y = generate_data(func, self.config.N_TRAIN, self.config.DOMAIN[0], self.config.DOMAIN[1]) # y shape is (N, 1)
114 | self.x_test, self.y_test = generate_data(func, self.config.N_TRAIN, range_min=self.config.DOMAIN_TEST[0],
115 | range_max=self.config.DOMAIN_TEST[1])
116 | else:
117 | self.X, self.y = self.load_data(self.data_path)
118 | self.x_test, self.y_test = self.X, self.y
119 |
120 | self.dtype = self.X.dtype # obtain the data type, which determines the parameter type of the model
121 |
122 | if isinstance(self.n_layers, list) or isinstance(self.num_func_layer, list):
123 | print('*' * 25, 'Start Sampling...', '*' * 25 + '\n')
124 | self.auto = True
125 |
126 | self.agent = Agent(auto=self.auto, input_size=self.input_size, hidden_size=self.hidden_size,
127 | num_funcs_avail=len(self.funcs_avail), n_layers=self.n_layers,
128 | num_funcs_layer=self.num_func_layer, device=self.device, dtype=self.dtype)
129 |
130 | self.agent = self.agent.to(self.dtype)
131 |
132 | if not os.path.exists(self.config.results_dir):
133 | os.makedirs(self.config.results_dir)
134 |
135 | func_dir = os.path.join(self.config.results_dir, func_name)
136 | if not os.path.exists(func_dir):
137 | os.makedirs(func_dir)
138 | self.results_dir = func_dir
139 |
140 | self.now_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
141 |
142 | # save hyperparameters
143 | args = {
144 | "date": self.now_time,
145 | "add_bias": config.add_bias,
146 | "train_domain": config.DOMAIN,
147 | "test_domain": config.DOMAIN_TEST,
148 | "num_epochs": config.num_epochs,
149 | "batch_size": config.batch_size,
150 | "input_size": config.input_size,
151 | "hidden_size": config.hidden_size,
152 | "risk_factor": config.risk_factor,
153 | "n_layers": config.n_layers,
154 | "num_func_layer": config.num_func_layer,
155 | "funcs_avail": str([func.name for func in config.funcs_avail]),
156 | "init_sd_first": 0.1,
157 | "init_sd_last": 1.0,
158 | "init_sd_middle": 0.5,
159 | "noise_level": config.NOISE
160 | }
161 | with open(os.path.join(self.results_dir, 'args_{}.txt'.format(self.func_name)), 'a') as f:
162 | f.write(json.dumps(args))
163 | f.write("\n")
164 | f.close()
165 |
166 | def solve_environment(self):
167 | epoch_best_expressions = []
168 | epoch_best_rewards = []
169 | epoch_mean_rewards = []
170 | epoch_mean_r2 = []
171 | epoch_best_r2 = []
172 | epoch_best_relative_error = []
173 | epoch_mean_relative_error = []
174 | best_expression, best_performance, best_relative_error = None, float('-inf'), float('inf')
175 | early_stopping = False
176 |
177 | # log the expressions of all epochs
178 | f1 = open(os.path.join(self.results_dir, 'eq_{}_all.txt'.format(self.func_name)), 'a')
179 | f1.write('\n{}\t\t{}\n'.format(self.now_time, self.func_name))
180 | f1.write('{}\t\tReward\t\tR2\t\tExpression\t\tnum_layers\t\tnum_funcs_layer\t\tfuncs_per_layer\n'.format(self.reward_type))
181 |
182 | # log the best expressions of each epoch
183 | f2 = open(os.path.join(self.results_dir, 'eq_{}_summary.txt'.format(self.func_name)), 'a')
184 | f2.write('\n{}\t\t{}\n'.format(self.now_time, self.func_name))
185 | f2.write('Epoch\t\tReward\t\tR2\t\tExpression\n')
186 |
187 | if self.optimizer == "Adam":
188 | optimizer = torch.optim.Adam(self.agent.parameters(), lr=self.config.learning_rate1)
189 | else:
190 | optimizer = torch.optim.RMSprop(self.agent.parameters(), lr=self.config.learning_rate1)
191 |
192 | for i in range(self.num_epochs):
193 | print("******************** Epoch {:02d} ********************".format(i))
194 | expressions = []
195 | rewards = []
196 | r2 = []
197 | relative_error_list = []
198 | batch_log_probs = torch.zeros([self.batch_size], device=self.device)
199 | batch_entropies = torch.zeros([self.batch_size], device=self.device)
200 |
201 | j = 0
202 | while j < self.batch_size:
203 | error, R2, eq, log_probs, entropies, num_layers, num_func_layer, funcs_per_layer_name = self.play_episodes() # play an episode
204 | # if the expression is invalid, e.g. a constant or None, resample the structure of the symbolic network
205 | if 'x_1' not in str(eq) or eq is None:
206 | R2 = 0.0
207 | if 'x_1' in str(eq) and self.refine_constants:
208 | res = self.bfgs(eq, self.X, self.y, self.n_restarts)
209 | eq = res['best expression']
210 | R2 = res['R2']
211 | error = res['error']
212 | relative_error = res['relative error']
213 | else:
214 | relative_error = 100
215 |
216 | reward = 1 / (1 + error)
217 | print("Final expression: ", eq)
218 | print("Test R2: ", R2)
219 | print("Test error: ", error)
220 | print("Relative error: ", relative_error)
221 | print("Reward: ", reward)
222 | print('\n')
223 |
224 | f1.write('{:.8f}\t\t{:.8f}\t\t{:.8f}\t\t{:.8f}\t\t{}\t\t{}\t\t{}\t\t{}\n'.format(error, relative_error, reward, R2, eq, num_layers,
225 | num_func_layer,
226 | funcs_per_layer_name))
227 |
228 | if R2 > 0.99:
229 | print("~ Early Stopping Met ~")
230 | print("Best expression: ", eq)
231 | print("Best reward: ", reward)
232 | print(f"{self.config.reward_type} error: ", error)
233 | print("Relative error: ", relative_error)
234 | early_stopping = True
235 | break
236 |
237 | batch_log_probs[j] = log_probs
238 | batch_entropies[j] = entropies
239 | expressions.append(eq)
240 | rewards.append(reward)
241 | r2.append(R2)
242 | relative_error_list.append(relative_error)
243 | j += 1
244 |
245 | if early_stopping:
246 | f2.write('{}\t\t{:.8f}\t\t{:.8f}\t\t{}\n'.format(i, reward, R2, eq))
247 | break
248 |
249 | # a batch expressions
250 | ## reward
251 | rewards = torch.tensor(rewards, device=self.device)
252 | best_epoch_expression = expressions[np.argmax(rewards.cpu())]
253 | epoch_best_expressions.append(best_epoch_expression)
254 | epoch_best_rewards.append(max(rewards).item())
255 | epoch_mean_rewards.append(torch.mean(rewards).item())
256 |
257 | ## R2
258 | r2 = torch.tensor(r2, device=self.device)
259 | best_r2_expression = expressions[np.argmax(r2.cpu())]
260 | epoch_best_r2.append(max(r2).item())
261 | epoch_mean_r2.append(torch.mean(r2).item())
262 |
263 | epoch_best_relative_error.append(relative_error_list[np.argmax(r2.cpu())])
264 |
265 | # log the best expression of a batch
266 | f2.write(
267 | '{}\t\t{:.8f}\t\t{:.8f}\t\t{:.8f}\t\t{}\n'.format(i, relative_error_list[np.argmax(r2.cpu())], max(rewards).item(), max(r2).item(),
268 | best_r2_expression))
269 |
270 | # save the best expression from the beginning to now
271 | if max(r2) > best_performance:
272 | best_performance = max(r2)
273 | best_expression = best_r2_expression
274 | best_relative_error = min(epoch_best_relative_error)
275 |
276 | if self.config.risk_seeking:
277 | threshold = np.quantile(rewards.cpu(), self.config.risk_factor)
278 | indices_to_keep = torch.tensor([j for j in range(len(rewards)) if rewards[j] > threshold], device=self.device)
279 | if len(indices_to_keep) == 0:
280 | print("Threshold removes all expressions. Terminating.")
281 | break
282 |
283 | # Select corresponding subset of rewards, log_probabilities, and entropies
284 | sub_rewards = torch.index_select(rewards, 0, indices_to_keep)
285 | sub_log_probs = torch.index_select(batch_log_probs, 0, indices_to_keep)
286 | sub_entropies = torch.index_select(batch_entropies, 0, indices_to_keep)
287 |
288 | # Compute risk seeking and entropy gradient
289 | risk_seeking_grad = torch.sum((sub_rewards - threshold) * sub_log_probs, dim=0)
290 | entropy_grad = torch.sum(sub_entropies, dim=0)
291 |
292 | # Mean reduction and clip to limit exploding gradients
293 | risk_seeking_grad = torch.clip(risk_seeking_grad / (self.config.risk_factor * len(sub_rewards)), min=-1e6, max=1e6)
294 | entropy_grad = self.config.entropy_weight * torch.clip(entropy_grad / (self.config.risk_factor * len(sub_rewards)), min=-1e6, max=1e6)
295 |
296 | # compute loss and update parameters
297 | loss = -1 * (risk_seeking_grad + entropy_grad)
298 | optimizer.zero_grad()
299 | loss.backward()
300 | optimizer.step()
301 |
302 | f1.close()
303 | f2.close()
304 |
305 | # save the rewards
306 | f3 = open(os.path.join(self.results_dir, "reward_{}_{}.txt".format(self.func_name, self.now_time)), 'w')
307 | for i in range(len(epoch_mean_rewards)):
308 | f3.write("{} {:.8f}\n".format(i + 1, epoch_mean_rewards[i]))
309 | f3.close()
310 |
311 | # plot reward curve
312 | if self.config.plot_reward:
313 | # plt.plot([i + 1 for i in range(len(epoch_best_rewards))], epoch_best_rewards) # best reward of full epoch
314 | plt.plot([i + 1 for i in range(len(epoch_mean_rewards))], epoch_mean_rewards) # mean reward of full epoch
315 | plt.xlabel('Epoch')
316 | plt.ylabel('Reward')
317 | plt.title('Reward over Time ' + self.now_time)
318 | plt.show()
319 | plt.savefig(os.path.join(self.results_dir, "reward_{}_{}.png".format(self.func_name, self.now_time)))
320 |
321 | if early_stopping:
322 | return eq, R2, error, relative_error
323 | else:
324 | return best_expression, best_performance.item(), 1 / max(rewards).item() - 1, best_relative_error
325 |
326 | def bfgs(self, eq, X, y, n_restarts):
327 | variable = self.vars_name
328 |
329 | # Parse the expression and get all the constants
330 | expr = eq
331 | c = symbols('c0:10000') # Suppose we have at most n constants, c0, c1, ..., cn-1
332 | consts = list(expr.atoms(Float)) # Only floating-point coefficients are counted, not power exponents
333 | consts_dict = {c[i]: const for i, const in enumerate(consts)} # map between c_i and unoptimized constants
334 |
335 | for c_i, val in consts_dict.items():
336 | expr = expr.subs(val, c_i)
337 |
338 | def loss(expr, X):
339 | diffs = []
340 | for i in range(X.shape[0]):
341 | curr_expr = expr
342 | for idx, j in enumerate(variable):
343 | curr_expr = sp.sympify(curr_expr).subs(j, X[i, idx])
344 | diff = curr_expr - y[i]
345 | diffs.append(diff)
346 | return np.mean(np.square(diffs))
347 |
348 | # Lists where all restarted will be appended
349 | F_loss = []
350 | RE_list = []
351 | R2_list = []
352 | consts_ = []
353 | funcs = []
354 |
355 | print('Constructing BFGS loss...')
356 | loss_func = loss(expr, X)
357 |
358 | for i in range(n_restarts):
359 | x0 = np.array(consts, dtype=float)
360 | s = list(consts_dict.keys())
361 | # bfgs optimization
362 | fun_timed = TimedFun(fun=sp.lambdify(s, loss_func, modules=['numpy']), stop_after=int(1e10))
363 | if len(x0):
364 | minimize(fun_timed.fun, x0, method='BFGS') # check consts interval and if they are int
365 | consts_.append(fun_timed.x)
366 | else:
367 | consts_.append([])
368 |
369 | final = expr
370 | for i in range(len(s)):
371 | final = sp.sympify(final).replace(s[i], fun_timed.x[i])
372 |
373 | funcs.append(final)
374 |
375 | values = {x: X[:, idx] for idx, x in enumerate(variable)}
376 | y_pred = sp.lambdify(variable, final)(**values)
377 | if isinstance(y_pred, float):
378 | print('y_pred is float: ', y_pred, type(y_pred))
379 | R2 = 0.0
380 | loss_eq = 10000
381 | else:
382 | y_pred = torch.where(torch.isinf(y_pred), 10000, y_pred) # check if there is inf
383 | y_pred = torch.where(y_pred.clone().detach() > 10000, 10000, y_pred) # check if there is large number
384 | R2 = max(0.0, R_Square(y.squeeze(), y_pred))
385 | loss_eq = torch.mean(torch.square(y.squeeze() - y_pred)).item()
386 | relative_error = torch.mean(torch.abs((y.squeeze() - y_pred) / y.squeeze())).item()
387 | R2_list.append(R2)
388 | F_loss.append(loss_eq)
389 | RE_list.append(relative_error)
390 | best_R2_id = np.nanargmax(R2_list)
391 | best_consts = consts_[best_R2_id]
392 | best_expr = funcs[best_R2_id]
393 | best_R2 = R2_list[best_R2_id]
394 | best_error = F_loss[best_R2_id]
395 | best_re = RE_list[best_R2_id]
396 |
397 | return {'best expression': best_expr,
398 | 'constants': best_consts,
399 | 'R2': best_R2,
400 | 'error': best_error,
401 | 'relative error': best_re}
402 |
403 | def play_episodes(self):
404 | ############################### Sample a symbolic network ##############################
405 | init_state = torch.rand((1, self.input_size), device=self.device, dtype=self.dtype) # initial the input state
406 |
407 | if self.auto:
408 | num_layers, num_funcs_layer, action_index, log_probs, entropies = self.agent(
409 | init_state) # output the symbolic network structure parameters
410 | self.n_layers = num_layers
411 | self.num_func_layer = num_funcs_layer
412 | else:
413 | action_index, log_probs, entropies = self.agent(init_state)
414 |
415 | self.funcs_per_layer = {}
416 | self.funcs_per_layer_name = {}
417 |
418 | for i in range(self.n_layers):
419 | layer_funcs_list = list()
420 | layer_funcs_list_name = list()
421 | for j in range(self.num_func_layer):
422 | layer_funcs_list.append(self.funcs_avail[action_index[i, j]])
423 | layer_funcs_list_name.append(self.funcs_avail[action_index[i, j]].name)
424 | self.funcs_per_layer.update({i + 1: layer_funcs_list})
425 | self.funcs_per_layer_name.update({i + 1: layer_funcs_list_name})
426 |
427 | # let binary functions follow unary functions
428 | for layer, funcs in self.funcs_per_layer.items():
429 | unary_funcs = [func for func in funcs if isinstance(func, BaseFunction)]
430 | binary_funcs = [func for func in funcs if isinstance(func, BaseFunction2)]
431 | sorted_funcs = unary_funcs + binary_funcs
432 | self.funcs_per_layer[layer] = sorted_funcs
433 |
434 | print("Operators of each layer obtained by sampling: ", self.funcs_per_layer_name)
435 |
436 | ############################### Start training ##############################
437 | error_test, r2_test, eq = self.train(self.config.trials)
438 |
439 | return error_test, r2_test, eq, log_probs, entropies, self.n_layers, self.num_func_layer, self.funcs_per_layer_name
440 |
441 | def train(self, trials=1):
442 | """Train the network to find a given function"""
443 |
444 | data, target = self.X.to(self.device), self.y.to(self.device)
445 | test_data, test_target = self.x_test.to(self.device), self.y_test.to(self.device)
446 |
447 | self.x_dim = data.shape[-1]
448 |
449 | self.vars_name = [f'x_{i}' for i in range(1, self.x_dim + 1)] # Variable names
450 |
451 | width_per_layer = [len(f) for f in self.funcs_per_layer.values()]
452 | n_double_per_layer = [functions.count_double(f) for f in self.funcs_per_layer.values()]
453 |
454 | if self.auto:
455 | init_stddev = [self.init_sd_first] + [self.init_sd_middle] * (self.n_layers - 2) + [self.init_sd_last]
456 |
457 | # Arrays to keep track of various quantities as a function of epoch
458 | loss_list = [] # Total loss (MSE + regularization)
459 | error_list = [] # MSE
460 | reg_list = [] # Regularization
461 | error_test_list = [] # Test error
462 | r2_test_list = [] # Test R2
463 |
464 | error_test_final = []
465 | r2_test_final = []
466 | eq_list = []
467 |
468 | def log_grad_norm(net):
469 | sqsum = 0.0
470 | for p in net.parameters():
471 | if p.grad is not None:
472 | sqsum += (p.grad ** 2).sum().item()
473 | return np.sqrt(sqsum)
474 |
475 | # for trial in range(trials):
476 | retrain_num = 0
477 | trial = 0
478 | while 0 <= trial < trials:
479 | print("Training on function " + self.func_name + " Trial " + str(trial + 1) + " out of " + str(trials))
480 |
481 | # reinitialize for each trial
482 | if self.auto:
483 | net = SymbolicNet(self.n_layers,
484 | x_dim=self.x_dim,
485 | funcs=self.funcs_per_layer,
486 | initial_weights=None,
487 | init_stddev=init_stddev,
488 | add_bias=self.add_bias).to(self.device)
489 |
490 | else:
491 | net = SymbolicNet(self.n_layers,
492 | x_dim=self.x_dim,
493 | funcs=self.funcs_per_layer,
494 | initial_weights=[
495 | # kind of a hack for truncated normal
496 | torch.fmod(torch.normal(0, self.init_sd_first, size=(self.x_dim, width_per_layer[0] + n_double_per_layer[0])),
497 | 2),
498 | # binary operator has two inputs
499 | torch.fmod(
500 | torch.normal(0, self.init_sd_middle, size=(width_per_layer[0], width_per_layer[1] + n_double_per_layer[1])),
501 | 2),
502 | torch.fmod(
503 | torch.normal(0, self.init_sd_middle, size=(width_per_layer[1], width_per_layer[2] + n_double_per_layer[2])),
504 | 2),
505 | torch.fmod(torch.normal(0, self.init_sd_last, size=(width_per_layer[-1], 1)), 2)
506 | ]).to(self.device)
507 |
508 | net.to(self.dtype)
509 |
510 | loss_val = np.nan
511 | restart_flag = False
512 | while np.isnan(loss_val):
513 | # training restarts if gradients blow up
514 | criterion = nn.MSELoss()
515 | optimizer = optim.RMSprop(net.parameters(),
516 | lr=self.config.learning_rate2,
517 | alpha=0.9, # smoothing constant
518 | eps=1e-10,
519 | momentum=0.0,
520 | centered=False)
521 |
522 | # adaptive learning rate
523 | lmbda = lambda epoch: 0.1
524 | scheduler = optim.lr_scheduler.MultiplicativeLR(optimizer, lr_lambda=lmbda)
525 |
526 | if self.clip_grad:
527 | que = collections.deque()
528 |
529 | net.train() # Set model to training mode
530 |
531 | # First stage of training, preceded by 0th warmup stage
532 | for epoch in range(self.config.n_epochs1 + 2000):
533 | optimizer.zero_grad() # zero the parameter gradients
534 | outputs = net(data) # forward pass
535 | regularization = L12Smooth(a=0.01)
536 | mse_loss = criterion(outputs, target)
537 |
538 | reg_loss = regularization(net.get_weights_tensor())
539 | loss = mse_loss + self.config.reg_weight * reg_loss
540 | # loss = mse_loss
541 | loss.backward()
542 |
543 | if self.clip_grad:
544 | grad_norm = log_grad_norm(net)
545 | que.append(grad_norm)
546 | if len(que) > self.window_size:
547 | que.popleft()
548 | clip_threshold = 0.1 * sum(que) / len(que)
549 | torch.nn.utils.clip_grad_norm_(parameters=net.parameters(), max_norm=clip_threshold, norm_type=2)
550 | else:
551 | torch.nn.utils.clip_grad_norm_(parameters=net.parameters(), max_norm=self.max_norm, norm_type=2)
552 |
553 | optimizer.step()
554 |
555 | # summary
556 | if epoch % self.config.summary_step == 0:
557 | error_val = mse_loss.item()
558 | reg_val = reg_loss.item()
559 | loss_val = loss.item()
560 |
561 | error_list.append(error_val)
562 | reg_list.append(reg_val)
563 | loss_list.append(loss_val)
564 |
565 | with torch.no_grad(): # test error
566 | test_outputs = net(test_data) # [num_points, 1] as same as test_target
567 | if self.reward_type == 'mse':
568 | test_loss = F.mse_loss(test_outputs, test_target)
569 | elif self.reward_type == 'nrmse':
570 | test_loss = nrmse(test_target, test_outputs)
571 | error_test_val = test_loss.item()
572 | error_test_list.append(error_test_val)
573 | test_outputs = torch.where(torch.isnan(test_outputs), torch.full_like(test_outputs, 100),
574 | test_outputs)
575 | r2 = R_Square(test_target, test_outputs)
576 | r2_test_list.append(r2)
577 |
578 | if self.config.verbose:
579 | print("Epoch: {}\tTotal training loss: {}\tTest {}: {}".format(epoch, loss_val, self.reward_type, error_test_val))
580 |
581 | if np.isnan(loss_val) or loss_val > 1000: # If loss goes to NaN, restart training
582 | restart_flag = True
583 | break
584 |
585 | if epoch == 2000:
586 | scheduler.step() # lr /= 10
587 |
588 | if restart_flag:
589 | break
590 |
591 | scheduler.step() # lr /= 10 again
592 |
593 | for epoch in range(self.config.n_epochs2):
594 | optimizer.zero_grad() # zero the parameter gradients
595 | outputs = net(data)
596 | regularization = L12Smooth(a=0.01)
597 | mse_loss = criterion(outputs, target)
598 | reg_loss = regularization(net.get_weights_tensor())
599 | loss = mse_loss + self.config.reg_weight * reg_loss
600 | loss.backward()
601 |
602 | if self.clip_grad:
603 | grad_norm = log_grad_norm(net)
604 | que.append(grad_norm)
605 | if len(que) > self.window_size:
606 | que.popleft()
607 | clip_threshold = 0.1 * sum(que) / len(que)
608 | torch.nn.utils.clip_grad_norm_(parameters=net.parameters(), max_norm=clip_threshold, norm_type=2)
609 | else:
610 | torch.nn.utils.clip_grad_norm_(parameters=net.parameters(), max_norm=self.max_norm, norm_type=2)
611 |
612 | optimizer.step()
613 |
614 | if epoch % self.config.summary_step == 0:
615 | error_val = mse_loss.item()
616 | reg_val = reg_loss.item()
617 | loss_val = loss.item()
618 | error_list.append(error_val)
619 | reg_list.append(reg_val)
620 | loss_list.append(loss_val)
621 |
622 | with torch.no_grad(): # test error
623 | test_outputs = net(test_data)
624 | if self.reward_type == 'mse':
625 | test_loss = F.mse_loss(test_outputs, test_target)
626 | elif self.reward_type == 'nrmse':
627 | test_loss = nrmse(test_target, test_outputs)
628 | error_test_val = test_loss.item()
629 | error_test_list.append(error_test_val)
630 | test_outputs = torch.where(torch.isnan(test_outputs), torch.full_like(test_outputs, 100),
631 | test_outputs)
632 | r2 = R_Square(test_target, test_outputs)
633 | r2_test_list.append(r2)
634 | if self.config.verbose:
635 | print("Epoch: {}\tTotal training loss: {}\tTest {}: {}".format(epoch, loss_val, self.reward_type, error_test_val))
636 |
637 | if np.isnan(loss_val) or loss_val > 1000: # If loss goes to NaN, restart training
638 | break
639 |
640 | if restart_flag:
641 | # self.play_episodes()
642 | retrain_num += 1
643 | if retrain_num == 5: # only allow 5 restarts
644 | return 10000, None, None
645 | continue
646 |
647 |         # After training, prune the symbolic network and convert it into a symbolic expression
648 | with torch.no_grad():
649 | weights = net.get_weights()
650 | if self.add_bias:
651 | biases = net.get_biases()
652 | else:
653 | biases = None
654 | expr = pretty_print.network(weights, self.funcs_per_layer, self.vars_name, self.threshold, self.add_bias, biases)
655 |
656 | # results of training trials
657 | error_test_final.append(error_test_list[-1])
658 | r2_test_final.append(r2_test_list[-1])
659 | eq_list.append(expr)
660 |
661 | trial += 1
662 |
663 | error_expr_sorted = sorted(zip(error_test_final, r2_test_final, eq_list), key=lambda x: x[0]) # List of (error, r2, expr)
664 | print('error_expr_sorted', error_expr_sorted)
665 |
666 | return error_expr_sorted[0]
667 |
668 | def load_data(self, path):
669 | data = pd.read_csv(path)
670 |
671 | if data.shape[1] < 2:
672 | raise ValueError('CSV file must contain at least 2 columns.')
673 |
674 | x_data = data.iloc[:, :-1]
675 | y_data = data.iloc[:, -1:]
676 |
677 | X = torch.tensor(x_data.values, dtype=torch.float32)
678 | y = torch.tensor(y_data.values, dtype=torch.float32)
679 |
680 | return X, y
681 |
682 |
683 | if __name__ == "__main__":
684 | # Configure the parameters
685 | config = Params()
686 |
687 | # Example 1: Input ground truth expression
688 | SR = SymboliRegression(config=config, func="x_1 + x_2", func_name="x_1+x_2")
689 | eq, R2, error, relative_error = SR.solve_environment()
690 | print('Expression: ', eq)
691 | print('R2: ', R2)
692 | print('error: ', error)
693 | print('relative_error: ', relative_error)
694 | print('log(1 + MSE): ', np.log(1 + error))
695 |
696 | # Example 2: Input data path of csv file
697 | # SR = SymboliRegression(config=config, func_name="Nguyen-1", data_path="./data/Nguyen-1.csv")
698 | # eq, R2, error, relative_error = SR.solve_environment()
699 | # print('Expression: ', eq)
700 | # print('R2: ', R2)
701 | # print('error: ', error)
702 | # print('relative_error: ', relative_error)
703 | # print('log(1 + MSE): ', np.log(1 + error))
--------------------------------------------------------------------------------
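
Note: both training loops above clip gradients adaptively: a sliding window of recent gradient norms is kept in a deque, and gradients are clipped to 10% of the window's mean norm. Below is a minimal stand-alone sketch of that rule; `log_grad_norm` is defined elsewhere in this module, so the inline norm computation and the helper name `adaptive_clip` are assumptions for illustration only.

import torch
from collections import deque

def adaptive_clip(net, que, window_size=50):
    # total L2 norm of the current gradients (stand-in for log_grad_norm)
    grad_norm = torch.norm(torch.stack([p.grad.detach().norm(2)
                                        for p in net.parameters() if p.grad is not None]), 2).item()
    que.append(grad_norm)
    if len(que) > window_size:
        que.popleft()
    clip_threshold = 0.1 * sum(que) / len(que)  # clip to 10% of the windowed mean norm
    torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=clip_threshold, norm_type=2)

# usage: que = deque(); then after each loss.backward(): adaptive_clip(net, que); optimizer.step()
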
/build/lib/DySymNet/__init__.py:
--------------------------------------------------------------------------------
1 | name = "DySymNet"
--------------------------------------------------------------------------------
/build/lib/DySymNet/scripts/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AILWQ/DySymNet/3f356066be8b11dd8a9692d26a27aed876accdf3/build/lib/DySymNet/scripts/__init__.py
--------------------------------------------------------------------------------
/build/lib/DySymNet/scripts/controller.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.distributions import Categorical
5 | from torch.nn.functional import one_hot, log_softmax
6 |
7 |
8 | class Agent(nn.Module):
9 |
10 | def __init__(self, auto, input_size, hidden_size, num_funcs_avail, n_layers, num_funcs_layer, device=None, dtype=torch.float32):
11 | super(Agent, self).__init__()
12 | self.auto = auto
13 |         self.num_funcs_avail = num_funcs_avail  # number of candidate operators to sample from
14 |         self.n_layers = n_layers  # candidate numbers of layers (list in auto mode, int otherwise)
15 |         self.num_funcs_layer = num_funcs_layer  # candidate numbers of operators per layer (list in auto mode, int otherwise)
16 | self.dtype = dtype
17 |
18 | if device is not None:
19 | self.device = device
20 | else:
21 | self.device = 'cpu'
22 |
23 | if self.auto:
24 | self.n_layer_decoder = nn.Linear(hidden_size, len(self.n_layers), device=device)
25 | self.num_funcs_layer_decoder = nn.Linear(hidden_size, len(self.num_funcs_layer), device=device)
26 | self.max_input_size = max(len(self.n_layers), len(self.num_funcs_layer))
27 | self.dynamic_lstm_cell = nn.LSTMCell(self.max_input_size, hidden_size, device=device)
28 | self.embedding = nn.Linear(self.num_funcs_avail, len(self.num_funcs_layer), device=device)
29 |
30 | self.lstm_cell = nn.LSTMCell(input_size, hidden_size, device=device)
31 | self.decoder = nn.Linear(hidden_size, self.num_funcs_avail, device=device) # output probability distribution
32 | self.n_steps = n_layers
33 | self.hidden_size = hidden_size
34 | self.hidden = self.init_hidden()
35 |
36 | def init_hidden(self):
37 | h_t = torch.zeros(1, self.hidden_size, dtype=self.dtype, device=self.device) # [batch_size, hidden_size]
38 | c_t = torch.zeros(1, self.hidden_size, dtype=self.dtype, device=self.device) # [batch_size, hidden_size]
39 |
40 | return h_t, c_t
41 |
42 | def forward(self, input):
43 |
44 | if self.auto:
45 | if input.shape[-1] < self.max_input_size:
46 |                 input = nn.functional.pad(input, (0, self.max_input_size - input.shape[-1]), 'constant', 0)
47 |
48 | assert input.shape[-1] == self.max_input_size, 'Error: the input dim of the first step is not equal to the max dim'
49 |
50 | h_t, c_t = self.hidden
51 |
52 | # Sample the number of layers first
53 | h_t, c_t = self.dynamic_lstm_cell(input, (h_t, c_t)) # [batch_size, hidden_size]
54 | n_layer_logits = self.n_layer_decoder(h_t) # [batch_size, len(n_layers)]
55 | n_layer_probs = F.softmax(n_layer_logits, dim=-1)
56 | dist = Categorical(probs=n_layer_probs)
57 | action_index1 = dist.sample()
58 | log_prob1 = dist.log_prob(action_index1)
59 | entropy1 = dist.entropy()
60 | num_layers = self.n_layers[action_index1]
61 |
62 | # Sample the number of operators per layer
63 | input = n_layer_logits
64 | if input.shape[-1] < self.max_input_size:
65 | input = nn.functional.pad(input, (0, self.max_input_size - input.shape[-1]), 'constant', 0)
66 | h_t, c_t = self.dynamic_lstm_cell(input, (h_t, c_t))
67 | n_funcs_layer_logits = self.num_funcs_layer_decoder(h_t) # [batch_size, len(num_funcs_layer)]
68 | n_funcs_layer_probs = F.softmax(n_funcs_layer_logits, dim=-1)
69 | dist = Categorical(probs=n_funcs_layer_probs)
70 | action_index2 = dist.sample()
71 | log_prob2 = dist.log_prob(action_index2)
72 | entropy2 = dist.entropy()
73 | num_funcs_layer = self.num_funcs_layer[action_index2]
74 |
75 | # Sample the operators
76 | input = n_funcs_layer_logits
77 | if input.shape[-1] < self.max_input_size:
78 |                 input = nn.functional.pad(input, (0, self.max_input_size - input.shape[-1]), 'constant', 0)
79 |
80 | outputs = []
81 | for t in range(num_layers):
82 | h_t, c_t = self.dynamic_lstm_cell(input, (h_t, c_t))
83 | output = self.decoder(h_t) # [batch_size, len(func_avail)]
84 | outputs.append(output)
85 | input = self.embedding(output)
86 |
87 | outputs = torch.stack(outputs).squeeze(1) # [n_layers, len(funcs)]
88 | probs = F.softmax(outputs, dim=-1)
89 | dist = Categorical(probs=probs)
90 | action_index3 = dist.sample((num_funcs_layer,)).transpose(0, 1) # [num_layers, num_func_layer]
91 | # print("action_index: ", action_index)
92 | log_probs = dist.log_prob(action_index3.transpose(0, 1)).transpose(0, 1) # [num_layers, num_func_layer] compute the log probability of the sampled action
93 | entropies = dist.entropy() # [num_layers] compute the entropy of the action distribution
94 | log_probs, entropies = torch.sum(log_probs), torch.sum(entropies)
95 |
96 | # another way to sample
97 | # probs = F.softmax(episode_logits, dim=-1)
98 | # action_index = torch.multinomial(probs, self.num_func_layer, replacement=True)
99 |
100 | # mask = one_hot(action_index, num_classes=self.input_size).squeeze(1)
101 | # log_probs = log_softmax(episode_logits, dim=-1)
102 | # episode_log_probs = torch.sum(mask.float() * log_probs)
103 |
104 | log_probs = log_probs + log_prob1 + log_prob2
105 | entropies = entropies + entropy1 + entropy2
106 |
107 | return num_layers, num_funcs_layer, action_index3, log_probs, entropies
108 |
109 |         # With the number of layers and operators per layer fixed, sample only the operators; each layer is sampled independently
110 | else:
111 | outputs = []
112 | h_t, c_t = self.hidden
113 |
114 | for i in range(self.n_steps):
115 | h_t, c_t = self.lstm_cell(input, (h_t, c_t))
116 | output = self.decoder(h_t) # [batch_size, num_choices]
117 | outputs.append(output)
118 | input = output
119 |
120 | outputs = torch.stack(outputs).squeeze(1) # [num_steps, num_choices]
121 | probs = F.softmax(outputs, dim=-1)
122 | dist = Categorical(probs=probs)
123 | action_index = dist.sample((self.num_funcs_layer,)).transpose(0, 1) # [num_layers, num_func_layer]
124 | # print("action_index: ", action_index)
125 | log_probs = dist.log_prob(action_index.transpose(0, 1)).transpose(0, 1) # [num_layers, num_func_layer]
126 | entropies = dist.entropy() # [num_layers]
127 | log_probs, entropies = torch.sum(log_probs), torch.sum(entropies)
128 | return action_index, log_probs, entropies
129 |
--------------------------------------------------------------------------------
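
The `Agent` above is the RNN controller that samples a symbolic-network architecture. A minimal sketch of the fixed-architecture mode (auto=False) follows; the argument values are hypothetical and the import path assumes the installed package layout. Note that in this mode the decoder output is fed back as the next LSTM input, so input_size must equal num_funcs_avail.

import torch
from DySymNet.scripts.controller import Agent

# fixed mode: 3 layers, 4 operators sampled per layer, 6 candidate operators
agent = Agent(auto=False, input_size=6, hidden_size=32,
              num_funcs_avail=6, n_layers=3, num_funcs_layer=4)
x0 = torch.zeros(1, 6)  # initial controller input (batch of 1)
action_index, log_probs, entropies = agent(x0)
print(action_index.shape)  # torch.Size([3, 4]): one operator index per (layer, slot)
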
/build/lib/DySymNet/scripts/functions.py:
--------------------------------------------------------------------------------
1 | """Functions for use with symbolic regression.
2 |
3 | These functions encapsulate multiple implementations (sympy, Tensorflow, numpy) of a particular function so that the
4 | functions can be used in multiple contexts."""
5 |
6 | import torch
7 | # import tensorflow as tf
8 | import numpy as np
9 | import sympy as sp
10 |
11 |
12 | class BaseFunction:
13 | """Abstract class for primitive functions"""
14 |
15 | def __init__(self, norm=1):
16 | self.norm = norm
17 |
18 | def sp(self, x):
19 | """Sympy implementation"""
20 | return None
21 |
22 | def torch(self, x):
23 | """No need for base function"""
24 | return None
25 |
26 | def tf(self, x):
27 | """Automatically convert sympy to TensorFlow"""
28 | z = sp.symbols('z')
29 | return sp.utilities.lambdify(z, self.sp(z), 'tensorflow')(x)
30 |
31 | def np(self, x):
32 | """Automatically convert sympy to numpy"""
33 | z = sp.symbols('z')
34 | return sp.utilities.lambdify(z, self.sp(z), 'numpy')(x)
35 |
36 |
37 | class Constant(BaseFunction):
38 | def torch(self, x):
39 | return torch.ones_like(x)
40 |
41 | def sp(self, x):
42 | return 1
43 |
44 | def np(self, x):
45 |         return np.ones_like(x)
46 |
47 |
48 | class Identity(BaseFunction):
49 | def __init__(self):
50 | super(Identity, self).__init__()
51 | self.name = 'id'
52 |
53 | def torch(self, x):
54 | return x / self.norm # ??
55 |
56 | def sp(self, x):
57 | return x / self.norm
58 |
59 | def np(self, x):
60 | return np.array(x) / self.norm
61 |
62 |
63 | class Square(BaseFunction):
64 | def __init__(self):
65 | super(Square, self).__init__()
66 | self.name = 'pow2'
67 |
68 | def torch(self, x):
69 | return torch.square(x) / self.norm
70 |
71 | def sp(self, x):
72 | return x ** 2 / self.norm
73 |
74 | def np(self, x):
75 | return np.square(x) / self.norm
76 |
77 |
78 | class Pow(BaseFunction):
79 | def __init__(self, power, norm=1):
80 | BaseFunction.__init__(self, norm=norm)
81 | self.power = power
82 | self.name = 'pow{}'.format(int(power))
83 |
84 | def torch(self, x):
85 | return torch.pow(x, self.power) / self.norm
86 |
87 | def sp(self, x):
88 | return x ** self.power / self.norm
89 |
90 |
91 | # class Sin(BaseFunction):
92 | # def torch(self, x):
93 | # return torch.sin(x * 2 * 2 * np.pi) / self.norm
94 | #
95 | # def sp(self, x):
96 | # return sp.sin(x * 2 * 2 * np.pi) / self.norm
97 |
98 | class Sin(BaseFunction):
99 | def __init__(self):
100 | super().__init__()
101 | self.name = 'sin'
102 |
103 | def torch(self, x):
104 | return torch.sin(x) / self.norm
105 |
106 | def sp(self, x):
107 | return sp.sin(x) / self.norm
108 |
109 |
110 | class Cos(BaseFunction):
111 | def __init__(self):
112 | super(Cos, self).__init__()
113 | self.name = 'cos'
114 |
115 | def torch(self, x):
116 | return torch.cos(x) / self.norm
117 |
118 | def sp(self, x):
119 | return sp.cos(x) / self.norm
120 |
121 |
122 | class Tan(BaseFunction):
123 | def __init__(self):
124 | super(Tan, self).__init__()
125 | self.name = 'tan'
126 |
127 | def torch(self, x):
128 | return torch.tan(x) / self.norm
129 |
130 | def sp(self, x):
131 | return sp.tan(x) / self.norm
132 |
133 |
134 | class Sigmoid(BaseFunction):
135 | def torch(self, x):
136 | return torch.sigmoid(x) / self.norm
137 |
138 | # def tf(self, x):
139 | # return tf.sigmoid(x) / self.norm
140 |
141 | def sp(self, x):
142 | return 1 / (1 + sp.exp(-20 * x)) / self.norm
143 |
144 | def np(self, x):
145 | return 1 / (1 + np.exp(-20 * x)) / self.norm
146 |
147 | def name(self, x):
148 | return "sigmoid(x)"
149 |
150 |
151 | # class Exp(BaseFunction):
152 | # def __init__(self, norm=np.e):
153 | # super().__init__(norm)
154 | #
155 | # # ?? why the minus 1
156 | # def torch(self, x):
157 | # return (torch.exp(x) - 1) / self.norm
158 | #
159 | # def sp(self, x):
160 | # return (sp.exp(x) - 1) / self.norm
161 |
162 | class Exp(BaseFunction):
163 | def __init__(self):
164 | super().__init__()
165 | self.name = 'exp'
166 |
167 |     # unlike the commented-out variant above, no "- 1" offset or normalization is applied here
168 | def torch(self, x):
169 | return torch.exp(x)
170 |
171 | def sp(self, x):
172 | return sp.exp(x)
173 |
174 |
175 | class Log(BaseFunction):
176 | def __init__(self):
177 | super(Log, self).__init__()
178 | self.name = 'log'
179 |
180 | def torch(self, x):
181 | return torch.log(torch.abs(x) + 1e-6) / self.norm
182 |
183 | def sp(self, x):
184 | return sp.log(sp.Abs(x) + 1e-6) / self.norm
185 |
186 |
187 | class Sqrt(BaseFunction):
188 | def __init__(self):
189 | super(Sqrt, self).__init__()
190 | self.name = 'sqrt'
191 |
192 | def torch(self, x):
193 | return torch.sqrt(torch.abs(x)) / self.norm
194 |
195 | def sp(self, x):
196 | return sp.sqrt(sp.Abs(x)) / self.norm
197 |
198 |
199 | class BaseFunction2:
200 | """Abstract class for primitive functions with 2 inputs"""
201 |
202 | def __init__(self, norm=1.):
203 | self.norm = norm
204 |
205 | def sp(self, x, y):
206 | """Sympy implementation"""
207 | return None
208 |
209 | def torch(self, x, y):
210 | return None
211 |
212 | def tf(self, x, y):
213 | """Automatically convert sympy to TensorFlow"""
214 | a, b = sp.symbols('a b')
215 | return sp.utilities.lambdify([a, b], self.sp(a, b), 'tensorflow')(x, y)
216 |
217 | def np(self, x, y):
218 | """Automatically convert sympy to numpy"""
219 | a, b = sp.symbols('a b')
220 | return sp.utilities.lambdify([a, b], self.sp(a, b), 'numpy')(x, y)
221 |
222 | # def name(self, x, y):
223 | # return str(self.sp)
224 |
225 |
226 | class Product(BaseFunction2):
227 | def __init__(self, norm=0.1):
228 | super().__init__(norm=norm)
229 | self.name = '*'
230 |
231 | def torch(self, x, y):
232 | return x * y / self.norm
233 |
234 | def sp(self, x, y):
235 | return x * y / self.norm
236 |
237 |
238 | class Plus(BaseFunction2):
239 | def __init__(self, norm=1.0):
240 | super().__init__(norm=norm)
241 | self.name = '+'
242 |
243 | def torch(self, x, y):
244 | return (x + y) / self.norm
245 |
246 | def sp(self, x, y):
247 | return (x + y) / self.norm
248 |
249 |
250 | class Sub(BaseFunction2):
251 | def __init__(self, norm=1.0):
252 | super().__init__(norm=norm)
253 | self.name = '-'
254 |
255 | def torch(self, x, y):
256 | return (x - y) / self.norm
257 |
258 | def sp(self, x, y):
259 | return (x - y) / self.norm
260 |
261 |
262 | class Div(BaseFunction2):
263 | def __init__(self):
264 | super(Div, self).__init__()
265 | self.name = '/'
266 |
267 | def torch(self, x, y):
268 | return x / (y + 1e-6)
269 |
270 | def sp(self, x, y):
271 | return x / (y + 1e-6)
272 |
273 |
274 | def count_inputs(funcs):
275 | i = 0
276 | for func in funcs:
277 | if isinstance(func, BaseFunction):
278 | i += 1
279 | elif isinstance(func, BaseFunction2):
280 | i += 2
281 | return i
282 |
283 |
284 | def count_double(funcs):
285 | i = 0
286 | for func in funcs:
287 | if isinstance(func, BaseFunction2):
288 | i += 1
289 | return i
290 |
291 |
292 | default_func = [
293 | Product(),
294 | Plus(),
295 | Sin(),
296 | ]
297 |
--------------------------------------------------------------------------------
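
Each primitive above bundles a PyTorch and a SymPy implementation, so the same operator can be used both inside the network and when printing the final expression. A small sanity-check sketch (import path assumes the installed package layout):

import torch
import sympy as sp
from DySymNet.scripts.functions import Sin, Product, count_inputs, count_double

f = Sin()
print(f.torch(torch.tensor([0.0, 1.5708])))  # tensor([0.0000, 1.0000])
x = sp.Symbol('x')
print(f.sp(x))                               # sin(x)

funcs = [Sin(), Product()]
print(count_inputs(funcs), count_double(funcs))  # 3 1  (Product consumes two inputs)
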
/build/lib/DySymNet/scripts/params.py:
--------------------------------------------------------------------------------
1 | from .functions import *
2 |
3 |
4 | class Params:
5 | # Optional operators during sampling
6 | funcs_avail = [Identity(),
7 | Sin(),
8 | Cos(),
9 | # Tan(),
10 | # Exp(),
11 | # Log(),
12 | # Sqrt(),
13 | Square(),
14 | # Pow(3),
15 | # Pow(4),
16 | # Pow(5),
17 | # Pow(6),
18 | Plus(),
19 | Sub(),
20 | Product(),
21 | # Div()
22 | ]
23 | n_layers = [2, 3, 4, 5] # optional number of layers
24 | num_func_layer = [2, 3, 4, 5, 6] # optional number of functions in each layer
25 |
26 | # symbolic network training parameters
27 | learning_rate2 = 1e-2
28 | reg_weight = 5e-3
29 | threshold = 0.05
30 | trials = 1 # training trials of symbolic network
31 | n_epochs1 = 10001
32 | n_epochs2 = 10001
33 | summary_step = 1000
34 | clip_grad = True # clip gradient or not
35 | max_norm = 1 # norm threshold for gradient clipping
36 | window_size = 50 # window size for adaptive gradient clipping
37 | refine_constants = True # refine constants or not
38 | n_restarts = 1 # number of restarts for BFGS optimization
39 | add_bias = False # add bias or not
40 | verbose = True # print training process or not
41 | use_gpu = False # use cuda or not
42 | plot_reward = False # plot reward or not
43 |
44 | # controller parameters
45 | num_epochs = 500
46 | batch_size = 10
47 | if isinstance(n_layers, list) or isinstance(num_func_layer, list):
48 | input_size = max(len(n_layers), len(num_func_layer))
49 | else:
50 | input_size = len(funcs_avail)
51 | optimizer = "Adam"
52 | hidden_size = 32
53 | embedding_size = 16
54 | learning_rate1 = 0.0006
55 | risk_seeking = True
56 | risk_factor = 0.5
57 | entropy_weight = 0.005
58 | reward_type = "mse" # mse, nrmse
59 |
60 | # dataset parameters
61 | N_TRAIN = 100 # Size of training dataset
62 | N_VAL = 100 # Size of validation dataset
63 | NOISE = 0 # Standard deviation of noise for training dataset
64 | DOMAIN = (-1, 1) # Domain of dataset - range from which we sample x. Default (-1, 1)
65 | # DOMAIN = np.array([[0, -1, -1], [1, 1, 1]]) # Use this format if each input variable has a different domain
66 | N_TEST = 100 # Size of test dataset
67 |     DOMAIN_TEST = (-1, 1)  # Domain of test dataset; can be set larger than the training domain to test extrapolation
68 | var_names = [f'x_{i}' for i in range(1, 21)] # not used
69 |
70 | # save path
71 | results_dir = './results/test'
72 |
--------------------------------------------------------------------------------
/build/lib/DySymNet/scripts/pretty_print.py:
--------------------------------------------------------------------------------
1 | """
2 | Generate a mathematical expression of the symbolic regression network (AKA EQL network) using SymPy. This expression
3 | can be used to pretty-print the expression (including human-readable text, LaTeX, etc.). SymPy also allows algebraic
4 | manipulation of the expression.
5 | The main function is network(...)
6 | There are several filtering functions to simplify expressions, although these are not always needed if the weight matrix
7 | is already pruned.
8 | """
9 | import pdb
10 |
11 | import sympy as sp
12 | from . import functions
13 |
14 |
15 | def apply_activation(W, funcs, n_double=0):
16 | """Given an (n, m) matrix W and (m) vector of funcs, apply funcs to W.
17 |
18 | Arguments:
19 | W: (n, m) matrix
20 | funcs: list of activation functions (SymPy functions)
21 | n_double: Number of activation functions that take in 2 inputs
22 |
23 | Returns:
24 | SymPy matrix with 1 column that represents the output of applying the activation functions.
25 | """
26 | W = sp.Matrix(W)
27 | if n_double == 0:
28 | for i in range(W.shape[0]):
29 | for j in range(W.shape[1]):
30 | W[i, j] = funcs[j](W[i, j])
31 | else:
32 | W_new = W.copy()
33 | out_size = len(funcs)
34 | for i in range(W.shape[0]):
35 | in_j = 0
36 | out_j = 0
37 | while out_j < out_size - n_double:
38 | W_new[i, out_j] = funcs[out_j](W[i, in_j])
39 | in_j += 1
40 | out_j += 1
41 | while out_j < out_size:
42 | W_new[i, out_j] = funcs[out_j](W[i, in_j], W[i, in_j + 1])
43 | in_j += 2
44 | out_j += 1
45 | for i in range(n_double):
46 | W_new.col_del(-1)
47 | W = W_new
48 | return W
49 |
50 |
51 | def sym_pp(W_list, funcs, var_names, threshold=0.01, n_double=None, add_bias=False, biases=None):
52 | """Pretty print the hidden layers (not the last layer) of the symbolic regression network
53 |
54 | Arguments:
55 | W_list: list of weight matrices for the hidden layers
56 | funcs: dict of lambda functions using sympy. has the same size as W_list[i][j, :]
57 | var_names: list of strings for names of variables
58 | threshold: threshold for filtering expression. set to 0 for no filtering.
59 |         n_double: list of the number of 2-input activation functions in each layer
60 |
61 | Returns:
62 | Simplified sympy expression.
63 | """
64 | vars = []
65 | for var in var_names:
66 | if isinstance(var, str):
67 | vars.append(sp.Symbol(var))
68 | else:
69 | vars.append(var)
70 | try:
71 | expr = sp.Matrix(vars).T
72 |
73 | if add_bias and biases is not None:
74 | assert len(W_list) == len(biases), "The number of biases must be equal to the number of weights."
75 | for i, (W, b) in enumerate(zip(W_list, biases)):
76 | W = filter_mat(sp.Matrix(W), threshold=threshold)
77 | b = filter_mat(sp.Matrix(b), threshold=threshold)
78 | expr = expr * W + b
79 | expr = apply_activation(expr, funcs[i + 1], n_double=n_double[i])
80 |
81 | else:
82 | for i, W in enumerate(W_list):
83 | W = filter_mat(sp.Matrix(W), threshold=threshold) # Pruning
84 | expr = expr * W
85 | expr = apply_activation(expr, funcs[i + 1], n_double=n_double[i])
86 | except:
87 | pdb.set_trace()
88 | # expr = expr * W_list[-1]
89 | return expr
90 |
91 |
92 | def last_pp(eq, W, add_bias=False, biases=None):
93 | """Pretty print the last layer."""
94 | if add_bias and biases is not None:
95 | return eq * filter_mat(sp.Matrix(W)) + filter_mat(sp.Matrix(biases))
96 | else:
97 | return eq * filter_mat(sp.Matrix(W))
98 |
99 |
100 | def network(weights, funcs, var_names, threshold=0.01, add_bias=False, biases=None):
101 | """Pretty print the entire symbolic regression network.
102 |
103 | Arguments:
104 | weights: list of weight matrices for the entire network
105 | funcs: dict of lambda functions using sympy. has the same size as W_list[i][j, :]
106 | var_names: list of strings for names of variables
107 | threshold: threshold for filtering expression. set to 0 for no filtering.
108 |
109 | Returns:
110 | Simplified sympy expression."""
111 | n_double = [functions.count_double(funcs_per_layer) for funcs_per_layer in funcs.values()]
112 | # translate operators to sympy operators
113 | sp_funcs = {}
114 | for key, value in funcs.items():
115 | sp_value = [func.sp for func in value]
116 | sp_funcs.update({key: sp_value})
117 |
118 | if add_bias and biases is not None:
119 |         assert len(weights) == len(biases), "The number of biases must be equal to the number of weights."
120 | expr = sym_pp(weights[:-1], sp_funcs, var_names, threshold=threshold, n_double=n_double, add_bias=add_bias, biases=biases[:-1])
121 | expr = last_pp(expr, weights[-1], add_bias=add_bias, biases=biases[-1])
122 | else:
123 | expr = sym_pp(weights[:-1], sp_funcs, var_names, threshold=threshold, n_double=n_double, add_bias=add_bias)
124 | expr = last_pp(expr, weights[-1], add_bias=add_bias)
125 |
126 | try:
127 | expr = expr[0, 0]
128 | return expr
129 | except Exception as e:
130 | print("An exception occurred:", e)
131 |
132 |
133 |
134 | def filter_mat(mat, threshold=0.01):
135 | """Remove elements of a matrix below a threshold."""
136 | for i in range(mat.shape[0]):
137 | for j in range(mat.shape[1]):
138 | if abs(mat[i, j]) < threshold:
139 | mat[i, j] = 0
140 | return mat
141 |
142 |
143 | def filter_expr(expr, threshold=0.01):
144 | """Remove additive terms with coefficient below threshold
145 | TODO: Make more robust. This does not work in all cases."""
146 | expr_new = sp.Integer(0)
147 | for arg in expr.args:
148 | if arg.is_constant() and abs(arg) > threshold: # hack way to check if it's a number
149 | expr_new = expr_new + arg
150 | elif not arg.is_constant() and abs(arg.args[0]) > threshold:
151 | expr_new = expr_new + arg
152 | return expr_new
153 |
154 |
155 | def filter_expr2(expr, threshold=0.01):
156 | """Sets all constants under threshold to 0
157 | TODO: Test"""
158 | for a in sp.preorder_traversal(expr):
159 | if isinstance(a, sp.Float) and a < threshold:
160 | expr = expr.subs(a, 0)
161 | return expr
162 |
--------------------------------------------------------------------------------
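
network(...) above turns a list of weight matrices plus the per-layer operator dict into a SymPy expression. A minimal sketch with hand-written (hypothetical) weights for one hidden layer; the import path assumes the installed package layout:

import numpy as np
from DySymNet.scripts.functions import Identity, Sin
from DySymNet.scripts.pretty_print import network

funcs = {1: [Identity(), Sin()]}
W1 = np.array([[1.0, 2.0]])    # 1 input variable -> 2 activation inputs
W2 = np.array([[0.5], [1.0]])  # 2 activation outputs -> 1 network output
expr = network([W1, W2], funcs, ['x_1'], threshold=0.01)
print(expr)  # roughly 0.5*x_1 + 1.0*sin(2.0*x_1)
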
/build/lib/DySymNet/scripts/regularization.py:
--------------------------------------------------------------------------------
1 | """Methods for regularization to produce sparse networks.
2 |
3 | L2 regularization mostly penalizes the weight magnitudes without introducing sparsity.
4 | L1 regularization promotes sparsity.
5 | L1/2 promotes sparsity even more than L1. However, it can be difficult to train due to non-convexity and exploding
6 | gradients close to 0. Thus, we introduce a smoothed L1/2 regularization to remove the exploding gradients."""
7 |
8 | import torch
9 | import torch.nn as nn
10 |
11 |
12 | class L12Smooth(nn.Module):
13 | def __init__(self, a):
14 | super(L12Smooth, self).__init__()
15 | self.a = a
16 |
17 | def forward(self, input_tensor):
18 | """input: predictions"""
19 | return self.l12_smooth(input_tensor, self.a)
20 |
21 | def l12_smooth(self, input_tensor, a=0.05):
22 | """Smoothed L1/2 norm"""
23 | if type(input_tensor) == list:
24 |             return sum([self.l12_smooth(tensor, a) for tensor in input_tensor])
25 |
26 | smooth_abs = torch.where(torch.abs(input_tensor) < a,
27 | torch.pow(input_tensor, 4) / (-8 * a ** 3) + torch.square(input_tensor) * 3 / 4 / a + 3 * a / 8,
28 | torch.abs(input_tensor))
29 |
30 | return torch.sum(torch.sqrt(smooth_abs))
31 |
--------------------------------------------------------------------------------
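
The penalty above is the sum of sqrt(h_a(w)) over all weights, where h_a(w) equals |w| for |w| >= a and switches to the quartic polynomial branch below a so that gradients stay bounded near zero. A short usage sketch with made-up weights; the list input mirrors what net.get_weights_tensor() returns, and the import path assumes the installed package layout:

import torch
from DySymNet.scripts.regularization import L12Smooth

reg = L12Smooth(a=0.05)
weights = [torch.tensor([[0.001, 0.5], [-2.0, 0.03]])]  # e.g. net.get_weights_tensor()
print(reg(weights))  # scalar penalty tensor; entries with |w| < 0.05 use the smooth branch
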
/build/lib/DySymNet/scripts/symbolic_network.py:
--------------------------------------------------------------------------------
1 | """Contains the symbolic regression neural network architecture."""
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | from . import functions as functions
6 |
7 |
8 | class SymbolicLayer(nn.Module):
9 | """Neural network layer for symbolic regression where activation functions correspond to primitive functions.
10 | Can take multi-input activation functions (like multiplication)"""
11 |
12 | def __init__(self, funcs=None, initial_weight=None, init_stddev=0.1, in_dim=None, add_bias=False):
13 | """
14 | funcs: List of activation functions, using utils.functions
15 | initial_weight: (Optional) Initial value for weight matrix
16 | variable: Boolean of whether initial_weight is a variable or not
17 | init_stddev: (Optional) if initial_weight isn't passed in, this is standard deviation of initial weight
18 | """
19 | super().__init__()
20 |
21 | if funcs is None:
22 | funcs = functions.default_func
23 | self.initial_weight = initial_weight
24 | self.W = None # Weight matrix
25 | self.built = False # Boolean whether weights have been initialized
26 | self.add_bias = add_bias
27 |
28 | self.output = None # tensor for layer output
29 | self.n_funcs = len(funcs) # Number of activation functions (and number of layer outputs)
30 | self.funcs = [func.torch for func in funcs] # Convert functions to list of PyTorch functions
31 | self.n_double = functions.count_double(funcs) # Number of activation functions that take 2 inputs
32 | self.n_single = self.n_funcs - self.n_double # Number of activation functions that take 1 input
33 |
34 | self.out_dim = self.n_funcs + self.n_double
35 |
36 | if self.initial_weight is not None: # use the given initial weight
37 | self.W = nn.Parameter(self.initial_weight.clone().detach()) # copies
38 | self.built = True
39 | else:
40 | self.W = nn.Parameter(torch.fmod(torch.normal(mean=0.0, std=init_stddev, size=(in_dim, self.out_dim)), 2))
41 | if add_bias:
42 | self.b = nn.Parameter(torch.fmod(torch.normal(mean=0.0, std=init_stddev, size=(1, self.out_dim)), 2))
43 |
44 | def forward(self, x): # used to be __call__
45 | """Multiply by weight matrix and apply activation units"""
46 |
47 | if self.add_bias:
48 | g = torch.matmul(x, self.W) + self.b
49 | else:
50 | g = torch.matmul(x, self.W)
51 | self.output = []
52 |
53 | in_i = 0 # input index
54 | out_i = 0 # output index
55 | # Apply functions with only a single input, binary operators must come after unary operators
56 | while out_i < self.n_single:
57 | self.output.append(self.funcs[out_i](g[:, in_i])) # g[:, in_i] is the input to the activation function
58 | in_i += 1
59 | out_i += 1
60 | # Apply functions that take 2 inputs and produce 1 output
61 | while out_i < self.n_funcs:
62 | self.output.append(self.funcs[out_i](g[:, in_i], g[:, in_i + 1]))
63 | in_i += 2
64 | out_i += 1
65 |
66 | self.output = torch.stack(self.output, dim=1) # [n_points, n_funcs]
67 |
68 | return self.output
69 |
70 | def get_weight(self):
71 | return self.W.cpu().detach().numpy()
72 |
73 | def get_bias(self):
74 | return self.b.cpu().detach().numpy()
75 |
76 | def get_weight_tensor(self):
77 | return self.W.clone()
78 |
79 |
80 | class SymbolicNet(nn.Module):
81 | """Symbolic regression network with multiple layers. Produces one output."""
82 |
83 | def __init__(self, symbolic_depth, x_dim, funcs=None, initial_weights=None, init_stddev=0.1, add_bias=False):
84 | super(SymbolicNet, self).__init__()
85 |
86 | self.depth = symbolic_depth # symbolic network depths
87 | self.funcs = funcs # operators for each layer sampled by controller,{id: []}
88 | self.add_bias = add_bias # add bias or not
89 | layer_in_dim = [x_dim] + [len(funcs[i+1]) for i in range(self.depth)]
90 |
91 | if initial_weights is not None:
92 | layers = [SymbolicLayer(funcs=funcs[i+1], initial_weight=initial_weights[i], in_dim=layer_in_dim[i], add_bias=self.add_bias) for i in range(self.depth)]
93 | self.output_weight = nn.Parameter(initial_weights[-1].clone().detach())
94 |
95 | else:
96 | # Each layer initializes its own weights
97 | if not isinstance(init_stddev, list):
98 | init_stddev = [init_stddev] * self.depth
99 | layers = [SymbolicLayer(funcs=self.funcs[i+1], init_stddev=init_stddev[i], in_dim=layer_in_dim[i], add_bias=self.add_bias)
100 | for i in range(self.depth)]
101 | # Initialize weights for last layer (without activation functions)
102 | self.output_weight = nn.Parameter(torch.rand((layers[-1].n_funcs, 1)))
103 | if add_bias:
104 | self.output_bias = nn.Parameter(torch.rand((1, 1)))
105 |
106 | self.hidden_layers = nn.Sequential(*layers)
107 |
108 | def forward(self, input):
109 | h = self.hidden_layers(input) # Building hidden layers
110 | return torch.matmul(h, self.output_weight) # Final output (no activation units) of network
111 |
112 | def get_weights(self):
113 | """Return list of weight matrices"""
114 | # First part is iterating over hidden weights. Then append the output weight.
115 | return [self.hidden_layers[i].get_weight() for i in range(self.depth)] + \
116 | [self.output_weight.cpu().detach().numpy()]
117 |
118 | def get_biases(self):
119 | return [self.hidden_layers[i].get_bias() for i in range(self.depth)] + \
120 | [self.output_bias.cpu().detach().numpy()]
121 |
122 | def get_weights_tensor(self):
123 | """Return list of weight matrices as tensors"""
124 | return [self.hidden_layers[i].get_weight_tensor() for i in range(self.depth)] + \
125 | [self.output_weight.clone()]
126 |
127 |
128 | class SymbolicLayerL0(SymbolicLayer):
129 | def __init__(self, in_dim=None, funcs=None, initial_weight=None, init_stddev=0.1,
130 | bias=False, droprate_init=0.5, lamba=1.,
131 | beta=2 / 3, gamma=-0.1, zeta=1.1, epsilon=1e-6):
132 | super().__init__(in_dim=in_dim, funcs=funcs, initial_weight=initial_weight, init_stddev=init_stddev)
133 |
134 | self.droprate_init = droprate_init if droprate_init != 0 else 0.5
135 | self.use_bias = bias
136 | self.lamba = lamba
137 | self.bias = None
138 | self.in_dim = in_dim
139 | self.eps = None
140 |
141 | self.beta = beta
142 | self.gamma = gamma
143 | self.zeta = zeta
144 | self.epsilon = epsilon
145 |
146 | if self.use_bias:
147 | self.bias = nn.Parameter(0.1 * torch.ones((1, self.out_dim)))
148 | self.qz_log_alpha = nn.Parameter(torch.normal(mean=np.log(1 - self.droprate_init) - np.log(self.droprate_init),
149 | std=1e-2, size=(in_dim, self.out_dim)))
150 |
151 | def quantile_concrete(self, u):
152 | """Quantile, aka inverse CDF, of the 'stretched' concrete distribution"""
153 | y = torch.sigmoid((torch.log(u) - torch.log(1.0 - u) + self.qz_log_alpha) / self.beta)
154 | return y * (self.zeta - self.gamma) + self.gamma
155 |
156 | def sample_u(self, shape, reuse_u=False):
157 | """Uniform random numbers for concrete distribution"""
158 | if self.eps is None or not reuse_u:
159 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
160 | self.eps = torch.rand(size=shape).to(device) * (1 - 2 * self.epsilon) + self.epsilon
161 | return self.eps
162 |
163 | def sample_z(self, batch_size, sample=True):
164 | """Use the hard concrete distribution as described in https://arxiv.org/abs/1712.01312"""
165 | if sample:
166 | eps = self.sample_u((batch_size, self.in_dim, self.out_dim))
167 | z = self.quantile_concrete(eps)
168 | return torch.clamp(z, min=0, max=1)
169 | else: # Mean of the hard concrete distribution
170 | pi = torch.sigmoid(self.qz_log_alpha)
171 | return torch.clamp(pi * (self.zeta - self.gamma) + self.gamma, min=0.0, max=1.0)
172 |
173 | def get_z_mean(self):
174 | """Mean of the hard concrete distribution"""
175 | pi = torch.sigmoid(self.qz_log_alpha)
176 | return torch.clamp(pi * (self.zeta - self.gamma) + self.gamma, min=0.0, max=1.0)
177 |
178 | def sample_weights(self, reuse_u=False):
179 | z = self.quantile_concrete(self.sample_u((self.in_dim, self.out_dim), reuse_u=reuse_u))
180 | mask = torch.clamp(z, min=0.0, max=1.0)
181 | return mask * self.W
182 |
183 | def get_weight(self):
184 | """Deterministic value of weight based on mean of z"""
185 | return self.W * self.get_z_mean()
186 |
187 | def loss(self):
188 | """Regularization loss term"""
189 | return torch.sum(torch.sigmoid(self.qz_log_alpha - self.beta * np.log(-self.gamma / self.zeta)))
190 |
191 | def forward(self, x, sample=True, reuse_u=False):
192 | """Multiply by weight matrix and apply activation units"""
193 | if sample:
194 | h = torch.matmul(x, self.sample_weights(reuse_u=reuse_u))
195 | else:
196 | w = self.get_weight()
197 | h = torch.matmul(x, w)
198 |
199 | if self.use_bias:
200 | h = h + self.bias
201 |
202 | # shape of h = (?, self.n_funcs)
203 |
204 | output = []
205 | # apply a different activation unit to each column of h
206 | in_i = 0 # input index
207 | out_i = 0 # output index
208 | # Apply functions with only a single input
209 | while out_i < self.n_single:
210 | output.append(self.funcs[out_i](h[:, in_i]))
211 | in_i += 1
212 | out_i += 1
213 | # Apply functions that take 2 inputs and produce 1 output
214 | while out_i < self.n_funcs:
215 | output.append(self.funcs[out_i](h[:, in_i], h[:, in_i + 1]))
216 | in_i += 2
217 | out_i += 1
218 | output = torch.stack(output, dim=1)
219 | return output
220 |
221 |
222 | class SymbolicNetL0(nn.Module):
223 | """Symbolic regression network with multiple layers. Produces one output."""
224 |
225 | def __init__(self, symbolic_depth, in_dim=1, funcs=None, initial_weights=None, init_stddev=0.1):
226 | super(SymbolicNetL0, self).__init__()
227 | self.depth = symbolic_depth # Number of hidden layers
228 | self.funcs = funcs
229 |
230 | layer_in_dim = [in_dim] + self.depth * [len(funcs)]
231 | if initial_weights is not None:
232 | layers = [SymbolicLayerL0(funcs=funcs, initial_weight=initial_weights[i],
233 | in_dim=layer_in_dim[i])
234 | for i in range(self.depth)]
235 | self.output_weight = nn.Parameter(initial_weights[-1].clone().detach())
236 | else:
237 | # Each layer initializes its own weights
238 | if not isinstance(init_stddev, list):
239 | init_stddev = [init_stddev] * self.depth
240 | layers = [SymbolicLayerL0(funcs=funcs, init_stddev=init_stddev[i], in_dim=layer_in_dim[i])
241 | for i in range(self.depth)]
242 | # Initialize weights for last layer (without activation functions)
243 |             self.output_weight = nn.Parameter(torch.rand(size=(layers[-1].n_funcs, 1)) * 2)
244 |
245 | self.hidden_layers = nn.Sequential(*layers)
246 |
247 | def forward(self, input, sample=True, reuse_u=False):
248 | # connect output from previous layer to input of next layer
249 | h = input
250 | for i in range(self.depth):
251 | h = self.hidden_layers[i](h, sample=sample, reuse_u=reuse_u)
252 |
253 | h = torch.matmul(h, self.output_weight) # Final output (no activation units) of network
254 | return h
255 |
256 | def get_loss(self):
257 | return torch.sum(torch.stack([self.hidden_layers[i].loss() for i in range(self.depth)]))
258 |
259 | def get_weights(self):
260 | """Return list of weight matrices"""
261 | # First part is iterating over hidden weights. Then append the output weight.
262 | return [self.hidden_layers[i].get_weight().cpu().detach().numpy() for i in range(self.depth)] + \
263 | [self.output_weight.cpu().detach().numpy()]
264 |
--------------------------------------------------------------------------------
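
A minimal sketch of building and running SymbolicNet with a hypothetical two-layer architecture; each layer's operator list places binary operators after unary ones, as the forward pass expects. The import path assumes the installed package layout.

import torch
from DySymNet.scripts.functions import Identity, Sin, Plus, Product
from DySymNet.scripts.symbolic_network import SymbolicNet

funcs = {1: [Identity(), Sin(), Product()], 2: [Identity(), Plus()]}
net = SymbolicNet(symbolic_depth=2, x_dim=1, funcs=funcs)
x = torch.rand(16, 1)   # [num_points, x_dim]
y = net(x)              # [num_points, 1]
print([w.shape for w in net.get_weights()])  # [(1, 4), (3, 3), (2, 1)]
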
/build/lib/DySymNet/scripts/utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import numpy as np
3 | import pandas as pd
4 | import torch
5 | from sklearn import feature_selection
6 |
7 |
8 | def nrmse(y_true, y_pred):
9 | """y, y_pred should be (num_samples,)"""
10 | assert y_true.shape == y_pred.shape, "y_true and y_pred must have the same shape"
11 | var = torch.var(y_true)
12 | return (torch.sqrt(torch.mean((y_true - y_pred) ** 2)) / var).item()
13 |
14 |
15 | def MSE(y, y_pred):
16 | return torch.mean(torch.square(y - y_pred)).item()
17 |
18 |
19 | def Relative_Error(y, y_pred):
20 | return torch.mean(torch.abs((y - y_pred) / y)).item()
21 |
22 |
23 | def nrmse_np(y_true, y_pred):
24 | """y, y_pred should be (num_samples,)"""
25 | assert y_true.shape == y_pred.shape, "y_true and y_pred must have the same shape"
26 | var = np.var(y_true)
27 | return np.sqrt(np.mean((y_true - y_pred) ** 2)) / var
28 |
29 |
30 | def R_Square(y, y_pred):
31 | """y, y_pred should be same shape (num_samples,) or (num_samples, 1)"""
32 | return (1 - torch.sum(torch.square(y - y_pred)) / torch.sum(torch.square(y - torch.mean(y)))).item()
33 |
34 |
35 | def get_logger(filename, verbosity=1, name=None):
36 | level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING}
37 | formatter = logging.Formatter(
38 | "[%(asctime)s][%(filename)s][line:%(lineno)d][%(levelname)s] %(message)s"
39 | )
40 | logger = logging.getLogger(name)
41 | logger.setLevel(level_dict[verbosity])
42 |
43 | fh = logging.FileHandler(filename, "w")
44 | fh.setFormatter(formatter)
45 | logger.addHandler(fh)
46 |
47 | sh = logging.StreamHandler()
48 | sh.setFormatter(formatter)
49 | logger.addHandler(sh)
50 |
51 | return logger
52 |
53 |
54 | def get_top_k_features(X, y, k=10):
55 |     if y.ndim == 2:  # flatten a (num_samples, 1) target
56 |         y = y[:, 0]
57 |     # if X.shape[1] <= k:
58 |     #     return [i for i in range(X.shape[1])]
59 |     kbest = feature_selection.SelectKBest(feature_selection.r_regression, k=k)
60 |     kbest.fit(X, y)
61 |     scores = kbest.scores_
62 |     # scores = corr(X, y)
63 |     top_features = np.argsort(-np.abs(scores))
64 |     print("keeping only the top-{} features. Order was {}".format(k, top_features))
65 |     return list(top_features[:k])
66 |
--------------------------------------------------------------------------------
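
A quick sanity check of the metric helpers above on toy tensors (import path assumes the installed package layout):

import torch
from DySymNet.scripts.utils import MSE, R_Square, Relative_Error

y_true = torch.tensor([[1.0], [2.0], [3.0]])
y_pred = torch.tensor([[1.1], [1.9], [3.2]])
print(MSE(y_true, y_pred))             # ~0.02
print(R_Square(y_true, y_pred))        # ~0.97
print(Relative_Error(y_true, y_pred))  # ~0.072
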
/data/Nguyen-1.csv:
--------------------------------------------------------------------------------
1 | -0.7403801071368961,-0.5980661649010526
2 | 0.5206056048792103,0.9327356397523248
3 | 0.6525207575940659,1.3561365636798772
4 | 0.5356322049996658,0.9762079394047138
5 | -0.586133136442661,-0.443948325535601
6 | -0.3769022485680784,-0.28838790745876036
7 | -0.6201560629882985,-0.4740705376705796
8 | -0.8404642346449558,-0.7277713400189059
9 | 0.31264230226858225,0.4409467986663268
10 | -0.20966006871096332,-0.17491882414882493
11 | 0.48421935624763335,0.8322218716406187
12 | -0.5506969824140331,-0.41443812939783053
13 | -0.6350467381160174,-0.4878667956161994
14 | 0.5237341566285307,0.941690575897929
15 | 0.4441307242480925,0.7289885425998979
16 | 0.8378658262994001,2.12808281840724
17 | 0.4609140775860472,0.771273274760623
18 | 0.27611115397676866,0.37339851135463886
19 | -0.8594323772124557,-0.7556057558522368
20 | -0.10877247441601634,-0.09822795944736544
--------------------------------------------------------------------------------
/dist/DySymNet-0.2.0-py3-none-any.whl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AILWQ/DySymNet/3f356066be8b11dd8a9692d26a27aed876accdf3/dist/DySymNet-0.2.0-py3-none-any.whl
--------------------------------------------------------------------------------
/dist/DySymNet-0.2.0.tar.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AILWQ/DySymNet/3f356066be8b11dd8a9692d26a27aed876accdf3/dist/DySymNet-0.2.0.tar.gz
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: dysymnet
2 | channels:
3 | - pytorch
4 | - conda-forge
5 | dependencies:
6 | - _libgcc_mutex=0.1=conda_forge
7 | - _openmp_mutex=4.5=2_gnu
8 | - blas=1.0=mkl
9 | - brotlipy=0.7.0=py38h0a891b7_1005
10 | - bzip2=1.0.8=h7f98852_4
11 | - ca-certificates=2023.01.10=h06a4308_0
12 | - certifi=2022.12.7=py38h06a4308_0
13 | - cffi=1.15.1=py38h4a40e3a_3
14 | - charset-normalizer=2.1.1=pyhd8ed1ab_0
15 | - cryptography=39.0.0=py38h3d167d9_0
16 | - cudatoolkit=11.3.1=h9edb442_11
17 | - ffmpeg=4.3=hf484d3e_0
18 | - freetype=2.12.1=hca18f0e_1
19 | - gmp=6.2.1=h58526e2_0
20 | - gnutls=3.6.13=h85f3911_1
21 | - idna=3.4=pyhd8ed1ab_0
22 | - intel-openmp=2021.4.0=h06a4308_3561
23 | - jpeg=9e=h166bdaf_2
24 | - lame=3.100=h166bdaf_1003
25 | - lcms2=2.14=hfd0df8a_1
26 | - ld_impl_linux-64=2.40=h41732ed_0
27 | - lerc=4.0.0=h27087fc_0
28 | - libdeflate=1.17=h0b41bf4_0
29 | - libffi=3.4.2=h7f98852_5
30 | - libgcc-ng=12.2.0=h65d4601_19
31 | - libgomp=12.2.0=h65d4601_19
32 | - libiconv=1.17=h166bdaf_0
33 | - libnsl=2.0.0=h7f98852_0
34 | - libpng=1.6.39=h753d276_0
35 | - libsqlite=3.40.0=h753d276_0
36 | - libstdcxx-ng=12.2.0=h46fd767_19
37 | - libtiff=4.5.0=h6adf6a1_2
38 | - libuuid=2.32.1=h7f98852_1000
39 | - libwebp-base=1.2.4=h166bdaf_0
40 | - libxcb=1.13=h7f98852_1004
41 | - libzlib=1.2.13=h166bdaf_4
42 | - mkl=2021.4.0=h06a4308_640
43 | - mkl-service=2.4.0=py38h95df7f1_0
44 | - mkl_fft=1.3.1=py38h8666266_1
45 | - mkl_random=1.2.2=py38h1abd341_0
46 | - ncurses=6.3=h27087fc_1
47 | - nettle=3.6=he412f7d_0
48 | - numpy=1.23.5=py38h14f4228_0
49 | - numpy-base=1.23.5=py38h31eccc5_0
50 | - openh264=2.1.1=h780b84a_0
51 | - openjpeg=2.5.0=hfec8fc6_2
52 | - openssl=3.1.0=hd590300_2
53 | - pillow=9.4.0=py38hde6dc18_1
54 | - pip=23.0=pyhd8ed1ab_0
55 | - pthread-stubs=0.4=h36c2ea0_1001
56 | - pycparser=2.21=pyhd8ed1ab_0
57 | - pyopenssl=23.0.0=pyhd8ed1ab_0
58 | - pysocks=1.7.1=pyha2e5f31_6
59 | - python=3.8.16=he550d4f_1_cpython
60 | - python_abi=3.8=3_cp38
61 | - pytorch=1.12.1=py3.8_cuda11.3_cudnn8.3.2_0
62 | - pytorch-mutex=1.0=cuda
63 | - readline=8.1.2=h0f457ee_0
64 | - requests=2.28.2=pyhd8ed1ab_0
65 | - setuptools=67.1.0=pyhd8ed1ab_0
66 | - six=1.16.0=pyh6c4a22f_0
67 | - tk=8.6.12=h27826a3_0
68 | - torchaudio=0.12.1=py38_cu113
69 | - torchvision=0.13.1=py38_cu113
70 | - tqdm=4.65.0=py38hb070fc8_0
71 | - typing_extensions=4.4.0=pyha770c72_0
72 | - urllib3=1.26.14=pyhd8ed1ab_0
73 | - wheel=0.38.4=pyhd8ed1ab_0
74 | - xorg-libxau=1.0.9=h7f98852_0
75 | - xorg-libxdmcp=1.1.3=h7f98852_0
76 | - xz=5.2.6=h166bdaf_0
77 | - zlib=1.2.13=h166bdaf_4
78 | - zstd=1.5.2=h3eb15da_6
79 | - pip:
80 | - asttokens==2.2.1
81 | - colorama==0.4.6
82 | - contourpy==1.0.7
83 | - cycler==0.11.0
84 | - executing==1.2.0
85 | - fonttools==4.38.0
86 | - icecream==2.1.3
87 | - imageio==2.26.1
88 | - joblib==1.2.0
89 | - kiwisolver==1.4.4
90 | - lazy-loader==0.2
91 | - matplotlib==3.6.3
92 | - mpmath==1.2.1
93 | - natsort==8.3.1
94 | - networkx==3.0
95 | - packaging==23.0
96 | - pandas==2.0.1
97 | - pmlb==1.0.1.post3
98 | - pygments==2.14.0
99 | - pylustrator==1.3.0
100 | - pyparsing==3.0.9
101 | - pyqt5==5.15.9
102 | - pyqt5-qt5==5.15.2
103 | - pyqt5-sip==12.11.1
104 | - python-dateutil==2.8.2
105 | - pytz==2023.3
106 | - pywavelets==1.4.1
107 | - pyyaml==6.0
108 | - qtawesome==1.2.3
109 | - qtpy==2.3.0
110 | - scikit-image==0.20.0
111 | - scikit-learn==1.2.2
112 | - scipy==1.9.1
113 | - seaborn==0.12.2
114 | - sympy==1.11.1
115 | - threadpoolctl==3.1.0
116 | - tifffile==2023.3.21
117 | - tzdata==2023.3
118 |
--------------------------------------------------------------------------------
/img/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AILWQ/DySymNet/3f356066be8b11dd8a9692d26a27aed876accdf3/img/.DS_Store
--------------------------------------------------------------------------------
/img/ICML-logo.svg:
--------------------------------------------------------------------------------
1 |
2 |
297 |
--------------------------------------------------------------------------------
/img/Overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AILWQ/DySymNet/3f356066be8b11dd8a9692d26a27aed876accdf3/img/Overview.png
--------------------------------------------------------------------------------
/img/Snipaste.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AILWQ/DySymNet/3f356066be8b11dd8a9692d26a27aed876accdf3/img/Snipaste.png
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | # Note: To use the 'upload' functionality of this file, you must:
5 | # $ pipenv install twine --dev
6 |
7 | import io
8 | import os
9 | import sys
10 | from shutil import rmtree
11 |
12 |
13 | import setuptools
14 |
15 |
16 | # Package meta-data.
17 | NAME = 'DySymNet'
18 | DESCRIPTION = 'This package contains the official Pytorch implementation for the paper "A Neural-Guided Dynamic Symbolic Network for Exploring Mathematical Expressions from Data" accepted by ICML\'24.'
19 | URL = 'https://github.com/AILWQ/DySymNet'
20 | EMAIL = 'liwenqiang2021@gmail.com'
21 | AUTHOR = 'Wenqiang Li'
22 | REQUIRES_PYTHON = '>=3.6.0'
23 | VERSION = '0.2.0'
24 |
25 | # What packages are required for this module to be executed?
26 | REQUIRED = [
27 | 'scikit-learn==1.5.2',
28 | 'numpy==1.26.4',
29 | 'sympy==1.13.3',
30 | 'torch==2.2.2',
31 | 'matplotlib==3.9.2',
32 | 'tqdm==4.66.5',
33 | 'pandas==2.2.3',
34 | 'pip==24.2',
35 | 'scipy==1.13.1'
36 | ]
37 |
38 | # The rest you shouldn't have to touch too much :)
39 | # ------------------------------------------------
40 | # Except, perhaps the License and Trove Classifiers!
41 | # If you do change the License, remember to change the Trove Classifier for that!
42 |
43 | here = os.path.abspath(os.path.dirname(__file__))
44 |
45 | # Import the README and use it as the long-description.
46 | # Note: this will only work if 'README.md' is present in your MANIFEST.in file!
47 | try:
48 | with io.open(os.path.join(here, 'README.md'), encoding='utf-8') as f:
49 | long_description = '\n' + f.read()
50 | except FileNotFoundError:
51 | long_description = DESCRIPTION
52 |
53 |
54 | # Where the magic happens:
55 | setuptools.setup(
56 | name=NAME,
57 | version=VERSION,
58 | description=DESCRIPTION,
59 | long_description=long_description,
60 | long_description_content_type='text/markdown',
61 | author=AUTHOR,
62 | author_email=EMAIL,
63 | python_requires=REQUIRES_PYTHON,
64 | url=URL,
65 | packages=setuptools.find_packages(),
66 | install_requires=REQUIRED,
67 | license='MIT',
68 | classifiers=[
69 | "Programming Language :: Python :: 3",
70 | "License :: OSI Approved :: MIT License",
71 | "Operating System :: OS Independent",
72 |
73 | ]
74 | )
--------------------------------------------------------------------------------