├── .gitignore
├── LICENSE.md
├── README.md
├── __init__.py
├── examples
│   ├── example_notebooks
│   │   ├── README.md
│   │   ├── control_policy_optimization.ipynb
│   │   ├── control_with_memory_SHO.ipynb
│   │   ├── control_with_memory_acrobot.ipynb
│   │   ├── objective_function_optimization.ipynb
│   │   └── symbolic_regression_dynamical_system.ipynb
│   └── example_scripts
│       ├── control_policy_optimization.py
│       ├── control_with_memory_SHO.py
│       ├── control_with_memory_acrobot.py
│       ├── objective_function_optimization.py
│       └── symbolic_regression_dynamical_system.py
├── figures
│   ├── applications.jpg
│   └── logo.jpg
└── kozax
    ├── environments
    │   ├── SR_environments
    │   │   ├── lorenz_attractor.py
    │   │   ├── lotka_volterra.py
    │   │   ├── time_series_environment_base.py
    │   │   └── vd_pol_oscillator.py
    │   └── control_environments
    │       ├── acrobot.py
    │       ├── cart_pole.py
    │       ├── control_environment_base.py
    │       ├── harmonic_oscillator.py
    │       └── reactor.py
    ├── fitness_functions
    │   ├── Gymnax_fitness_function.py
    │   ├── ODE_fitness_function.py
    │   ├── SR_fitness_function.py
    │   └── base_fitness_function.py
    ├── genetic_operators
    │   ├── crossover.py
    │   ├── initialization.py
    │   ├── mutation.py
    │   └── reproduction.py
    └── genetic_programming.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | version.py
3 | pyproject.toml
4 | dist/*
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | kozax: Genetic programming framework in JAX
2 |
3 | Copyright (c) 2024 sdevries0
4 |
5 | This work is licensed under the Creative Commons Attribution-NonCommercial-NoDerivs 4.0 International License.
6 |
7 | You are free to:
8 | Share — copy and redistribute the material in any medium or format
9 | The licensor cannot revoke these freedoms as long as you follow the license terms.
10 |
11 | Under the following terms:
12 | Attribution — You must give appropriate credit, provide a link to the license, and indicate if changes were made. You may do so in any reasonable manner, but not in any way that suggests the licensor endorses you or your use.
13 | NonCommercial — You may not use the material for commercial purposes.
14 | NoDerivatives — If you remix, transform, or build upon the material, you may not distribute the modified material.
15 | No additional restrictions — You may not apply legal terms or technological measures that legally restrict others from doing anything the license permits.
16 |
17 | Notices:
18 | You do not have to comply with the license for elements of the material in the public domain or where your use is permitted by an applicable exception or limitation.
19 | No warranties are given. The license may not give you all of the permissions necessary for your intended use. For example, other rights such as publicity, privacy, or moral rights may limit how you use the material.
20 |
21 | License Details:
22 | This work is licensed under the Creative Commons Attribution-NonCommercial-NoDerivs 4.0 International License. To view a copy of this license, visit https://creativecommons.org/licenses/by-nc-nd/4.0/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | # Kozax: Flexible and Scalable Genetic Programming in JAX
6 | Kozax introduces a general framework for evolving computer programs with genetic programming in JAX. With JAX, the computer programs can be vectorized and evaluated in parallel on CPU and GPU. Furthermore, just-in-time compilation provides massive speedups for evolving offspring. Check out the [paper](https://arxiv.org/abs/2502.03047) introducing Kozax.
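
The example notebooks in this repository expose multiple CPU cores as XLA devices before importing JAX, so a population can be evaluated in parallel; a minimal sketch of that setup (the device count of 10 simply mirrors the notebooks):

```python
# Expose 10 CPU cores as XLA devices; this flag must be set before importing jax.
import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=10'

import jax
print(jax.devices())  # [CpuDevice(id=0), ..., CpuDevice(id=9)]
```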
7 |
8 | # Features
9 | Kozax allows the user to:
10 | - define custom operators (see the sketch after this list)
11 | - define custom fitness functions
12 | - use trees flexibly, ranging from symbolic regression to reinforcement learning
13 | - evolve multiple trees simultaneously, even with different inputs
14 | - numerically optimise constants in the computer programs
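
Custom operators are supplied to `GeneticProgramming` as `(symbol, function, arity, sampling weight)` tuples, and variables as a list of input names per tree. A minimal sketch, with values taken from the example notebooks:

```python
import jax.numpy as jnp

# (symbol, function, arity, sampling weight)
operator_list = [
    ("+", lambda x, y: jnp.add(x, y), 2, 0.5),
    ("-", lambda x, y: jnp.subtract(x, y), 2, 0.1),
    ("*", lambda x, y: jnp.multiply(x, y), 2, 0.5),
    ("log", lambda x: jnp.log(x + 1e-7), 1, 0.1),  # guarded log, as in the examples
]

variable_list = [["x0", "x1"]]  # one list of variable names per tree

# Passed as: GeneticProgramming(num_generations, population_size, fitness_function,
#                               operator_list, variable_list)
```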
15 |
16 | # How to use
17 | You can install Kozax via pip with
18 | ```
19 | pip install kozax
20 | ```
21 |
22 | Below is a short demo showing how you can use kozax. First we generate data:
23 | ```python
24 | import jax
25 | import jax.numpy as jnp
26 | import jax.random as jr
27 |
28 | key = jr.PRNGKey(0) #Initialize key
29 | data_key, gp_key = jr.split(key) #Split key for data and genetic programming
30 | x = jr.uniform(data_key, shape=(30,), minval=-5, maxval = 5) #Inputs
31 | y = -0.1*x**3 + 0.3*x**2 + 1.5*x #Targets
32 | ```
33 |
34 | Now we have to define a fitness function. This offers a lot of freedom, because you can use the computer program in any way you want during evaluation.
35 | ```python
36 | from kozax.fitness_functions.base_fitness_function import BaseFitnessFunction
37 |
38 | class FitnessFunction(BaseFitnessFunction):
39 | """
40 | The fitness function inherits from BaseFitnessFunction and should implement the __call__ method, which receives the candidate, the data and the tree_evaluator. The tree_evaluator computes the output of the candidate for a given input, and jax.vmap vectorizes this evaluation over all inputs. The fitness is the mean squared error between the candidate's predictions and the targets.
41 | """
42 | def __call__(self, candidate, data, tree_evaluator):
43 | X, Y = data
44 | predictions = jax.vmap(tree_evaluator, in_axes=[None, 0])(candidate, X)
45 | return jnp.mean(jnp.square(predictions-Y))
46 |
47 | fitness_function = FitnessFunction()
48 | ```
49 |
50 | Now we will use genetic programming to recover the equation from the data. This requires defining the hyperparameters, initializing the genetic programming strategy, and fitting the strategy to the data.
51 | ```python
52 | from kozax.genetic_programming import GeneticProgramming
53 |
54 | #Define hyperparameters
55 | population_size = 500
56 | num_generations = 100
57 |
58 | #Initialize genetic programming strategy
59 | strategy = GeneticProgramming(num_generations, population_size, fitness_function)
60 |
61 | #Fit the strategy on the data. With verbose, we can print the intermediate solutions.
62 | strategy.fit(gp_key, (x, y), verbose = True)
63 | ```
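
The `fit` method wraps a loop of evaluating and evolving the population. The same loop can also be written out explicitly, which is useful when the data should change between generations; the sketch below mirrors the pattern used in the symbolic regression example notebook:

```python
population = strategy.initialize_population(gp_key)

for g in range(num_generations):
    key, eval_key, sample_key = jr.split(key, 3)

    # Evaluate the population on the data and return the fitness
    fitness, population = strategy.evaluate_population(population, (x, y), eval_key)

    # Print the best solution in this generation
    best_fitness, best_solution = strategy.get_statistics(g)
    print(f"In generation {g+1}, best fitness = {best_fitness:.4f}, "
          f"best solution = {strategy.expression_to_string(best_solution)}")

    # Evolve the population until the last generation
    if g < num_generations - 1:
        population = strategy.evolve_population(population, fitness, sample_key)
```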
64 |
65 | There are additional [examples](https://github.com/sdevries0/kozax/tree/main/examples) showing how to use kozax on more complex problems.
66 |
67 | |Example|Notebook|Script|
68 | |---|---|---|
69 | |Symbolic regression of a dynamical system|[Notebook](examples/example_notebooks/symbolic_regression_dynamical_system.ipynb)|[Script](examples/example_scripts/symbolic_regression_dynamical_system.py)|
70 | |Control policy optimization in Gymnax environment|[Notebook](examples/example_notebooks/control_policy_optimization.ipynb)|[Script](examples/example_scripts/control_policy_optimization.py)|
71 | |Control policy optimization with dynamic memory|[Notebook](examples/example_notebooks/control_with_memory_SHO.ipynb)|[Script](examples/example_scripts/control_with_memory_SHO.py)|
72 | |Optimization of a loss function to train a neural network|[Notebook](examples/example_notebooks/objective_function_optimization.ipynb)|[Script](examples/example_scripts/objective_function_optimization.py)|
73 |
74 |
75 | # Citation
76 | If you make use of this code in your research paper, please cite:
77 | ```
78 | @article{de2025kozax,
79 | title={Kozax: Flexible and Scalable Genetic Programming in JAX},
80 | author={de Vries, Sigur and Keemink, Sander W and van Gerven, Marcel AJ},
81 | journal={arXiv preprint arXiv:2502.03047},
82 | year={2025}
83 | }
84 | ```
85 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | from .version import __version__
--------------------------------------------------------------------------------
/examples/example_notebooks/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | Here you can find several examples of applying Kozax to various problems. These examples can easily be extended to other use cases.
--------------------------------------------------------------------------------
/examples/example_notebooks/objective_function_optimization.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Objective function optimization\n",
8 | "\n",
9 | "In this example, Kozax is used to evolve a symbolic loss function to train a neural network. With each candidate loss function, a neural network is trained on the task of binary classification of XOR data points."
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": null,
15 | "metadata": {},
16 | "outputs": [
17 | {
18 | "name": "stdout",
19 | "output_type": "stream",
20 | "text": [
21 | "These device(s) are detected: [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7), CpuDevice(id=8), CpuDevice(id=9)]\n"
22 | ]
23 | }
24 | ],
25 | "source": [
26 | "# Specify the cores to use for XLA\n",
27 | "import os\n",
28 | "os.environ[\"XLA_FLAGS\"] = '--xla_force_host_platform_device_count=10'\n",
29 | "\n",
30 | "import jax\n",
31 | "import jax.numpy as jnp\n",
32 | "import jax.random as jr\n",
33 | "import optax \n",
34 | "from typing import Callable, Tuple\n",
35 | "from jax import Array\n",
36 | "\n",
37 | "from kozax.genetic_programming import GeneticProgramming"
38 | ]
39 | },
40 | {
41 | "cell_type": "markdown",
42 | "metadata": {},
43 | "source": [
44 |     "We define a fitness function class that includes the network initialization, training loop and weight updates. At every epoch, a new batch of data is sampled, and the fitness is computed as one minus the accuracy of the trained network on a validation set."
45 | ]
46 | },
47 | {
48 | "cell_type": "code",
49 | "execution_count": 2,
50 | "metadata": {},
51 | "outputs": [],
52 | "source": [
53 | "class FitnessFunction:\n",
54 | " \"\"\"\n",
55 | " A class to define the fitness function for evaluating candidate loss functions.\n",
56 | " The fitness is computed as the accuracy of a neural network trained with the candidate loss function\n",
57 | " on a binary classification task (XOR data).\n",
58 | "\n",
59 | " Attributes:\n",
60 | " input_dim (int): Dimension of the input data.\n",
61 | " hidden_dim (int): Dimension of the hidden layers in the neural network.\n",
62 | " output_dim (int): Dimension of the output.\n",
63 | " epochs (int): Number of training epochs.\n",
64 | " learning_rate (float): Learning rate for the optimizer.\n",
65 | " optim (optax.GradientTransformation): Optax optimizer instance.\n",
66 | " \"\"\"\n",
67 | " def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, epochs: int, learning_rate: float):\n",
68 | " self.input_dim = input_dim\n",
69 | " self.hidden_dim = hidden_dim\n",
70 | " self.output_dim = output_dim\n",
71 | " self.optim = optax.adam(learning_rate)\n",
72 | " self.epochs = epochs\n",
73 | "\n",
74 | " def __call__(self, candidate: str, data: Tuple[Array, Array, Array], tree_evaluator: Callable) -> Array:\n",
75 | " \"\"\"\n",
76 | " Computes the fitness of a candidate loss function.\n",
77 | "\n",
78 | " Args:\n",
79 | " candidate: The candidate loss function (symbolic tree).\n",
80 | " data (tuple): A tuple containing the data keys, test keys, and network keys.\n",
81 | " tree_evaluator: A function to evaluate the symbolic tree.\n",
82 | "\n",
83 | " Returns:\n",
84 | " Array: The mean loss (1 - accuracy) on the validation set.\n",
85 | " \"\"\"\n",
86 | " data_keys, test_keys, network_keys = data\n",
87 | " losses = jax.vmap(self.train, in_axes=[None, 0, 0, 0, None])(candidate, data_keys, test_keys, network_keys, tree_evaluator)\n",
88 | " return jnp.mean(losses)\n",
89 | "\n",
90 | " def get_data(self, key: jr.PRNGKey, n_samples: int = 50) -> Tuple[Array, Array]:\n",
91 | " \"\"\"\n",
92 | " Generates XOR data.\n",
93 | "\n",
94 | " Args:\n",
95 | " key (jax.random.PRNGKey): Random key for data generation.\n",
96 | " n_samples (int): Number of samples to generate.\n",
97 | "\n",
98 | " Returns:\n",
99 | " tuple: A tuple containing the input data (x) and the target labels (y).\n",
100 | " \"\"\"\n",
101 | " x = jr.uniform(key, shape=(n_samples, 2))\n",
102 | " y = jnp.logical_xor(x[:,0]>0.5, x[:,1]>0.5)\n",
103 | "\n",
104 | " return x, y[:,None]\n",
105 | "\n",
106 | " def loss_function(self, params: Tuple[Array, Array, Array, Array, Array, Array], x: Array, y: Array, candidate: str, tree_evaluator: Callable) -> Array:\n",
107 | " \"\"\"\n",
108 | " Computes the loss with an evolved loss function for a given set of parameters and data.\n",
109 | "\n",
110 | " Args:\n",
111 | " params (tuple): The parameters of the neural network.\n",
112 | " x (Array): The input data.\n",
113 | " y (Array): The target labels.\n",
114 | " candidate: The candidate loss function (symbolic tree).\n",
115 | " tree_evaluator: A function to evaluate the symbolic tree.\n",
116 | "\n",
117 | " Returns:\n",
118 | " Array: The mean loss.\n",
119 | " \"\"\"\n",
120 | " pred = self.neural_network(params, x)\n",
121 | " return jnp.mean(jax.vmap(tree_evaluator, in_axes=[None, 0])(candidate, jnp.concatenate([pred, y], axis=-1)))\n",
122 | " \n",
123 | " def train(self, candidate: str, data_key: jr.PRNGKey, test_key: jr.PRNGKey, network_key: jr.PRNGKey, tree_evaluator: Callable) -> Array:\n",
124 | " \"\"\"\n",
125 | " Trains a neural network with a given candidate loss function.\n",
126 | "\n",
127 | " Args:\n",
128 | " candidate: The candidate loss function (symbolic tree).\n",
129 | " data_key (jax.random.PRNGKey): Random key for data generation during training.\n",
130 | " test_key (jax.random.PRNGKey): Random key for data generation during testing.\n",
131 | " network_key (jax.random.PRNGKey): Random key for initializing the network parameters.\n",
132 | " tree_evaluator: A function to evaluate the symbolic tree.\n",
133 | "\n",
134 | " Returns:\n",
135 | " Array: The validation loss (1 - accuracy).\n",
136 | " \"\"\"\n",
137 | " params = self.init_network_params(network_key)\n",
138 | "\n",
139 | " optim_state = self.optim.init(params)\n",
140 | "\n",
141 | "        def step(i: int, carry: Tuple[Tuple[Array, Array, Array, Array, Array, Array], optax.OptState, jr.PRNGKey]) -> Tuple[Tuple[Array, Array, Array, Array, Array, Array], optax.OptState, jr.PRNGKey]:\n",
142 | " params, optim_state, key = carry\n",
143 | "\n",
144 | " key, _key = jr.split(key)\n",
145 | "\n",
146 | " x_train, y_train = self.get_data(_key, n_samples=50)\n",
147 | "\n",
148 | " # Evaluate network parameters and compute gradients\n",
149 | " grads = jax.grad(self.loss_function)(params, x_train, y_train, candidate, tree_evaluator)\n",
150 | " \n",
151 | " # Update parameters\n",
152 | " updates, optim_state = self.optim.update(grads, optim_state, params)\n",
153 | " params = optax.apply_updates(params, updates)\n",
154 | "\n",
155 | " return (params, optim_state, key)\n",
156 | "\n",
157 | " (params, _, _) = jax.lax.fori_loop(0, self.epochs, step, (params, optim_state, data_key))\n",
158 | "\n",
159 | " # Evaluate parameters on test set\n",
160 | " x_test, y_test = self.get_data(test_key, n_samples=500)\n",
161 | "\n",
162 | " pred = self.neural_network(params, x_test)\n",
163 | " return 1 - jnp.mean(y_test==(pred>0.5)) # Return 1 - accuracy\n",
164 | "\n",
165 | " def neural_network(self, params: Tuple[Array, Array, Array, Array, Array, Array], x: Array) -> Array:\n",
166 | " \"\"\"\n",
167 | " Defines the neural network architecture (forward pass).\n",
168 | "\n",
169 | " Args:\n",
170 | " params (tuple): The parameters of the neural network.\n",
171 | " x (Array): The input data.\n",
172 | "\n",
173 | " Returns:\n",
174 | " Array: The output of the neural network.\n",
175 | " \"\"\"\n",
176 | " w1, b1, w2, b2, w3, b3 = params\n",
177 | " hidden = jnp.tanh(jnp.dot(x, w1) + b1)\n",
178 | " hidden = jnp.tanh(jnp.dot(hidden, w2) + b2)\n",
179 | " output = jnp.dot(hidden, w3) + b3\n",
180 | " return jax.nn.sigmoid(output)\n",
181 | "\n",
182 | " def init_network_params(self, key: jr.PRNGKey) -> Tuple[Array, Array, Array, Array, Array, Array]:\n",
183 | " \"\"\"\n",
184 | " Initializes the parameters of the neural network.\n",
185 | "\n",
186 | " Args:\n",
187 | " key (jax.random.PRNGKey): Random key for parameter initialization.\n",
188 | "\n",
189 | " Returns:\n",
190 | " tuple: A tuple containing the initialized weights and biases.\n",
191 | " \"\"\"\n",
192 | " key1, key2, key3 = jr.split(key, 3)\n",
193 | " w1 = jr.normal(key1, (self.input_dim, self.hidden_dim)) * jnp.sqrt(2.0 / self.input_dim)\n",
194 | " b1 = jnp.zeros(self.hidden_dim)\n",
195 | " w2 = jr.normal(key2, (self.hidden_dim, self.hidden_dim)) * jnp.sqrt(2.0 / self.hidden_dim)\n",
196 | " b2 = jnp.zeros(self.hidden_dim)\n",
197 | " w3 = jr.normal(key3, (self.hidden_dim, self.output_dim)) * jnp.sqrt(2.0 / self.hidden_dim)\n",
198 | " b3 = jnp.zeros(self.output_dim)\n",
199 | " return (w1, b1, w2, b2, w3, b3)"
200 | ]
201 | },
202 | {
203 | "cell_type": "markdown",
204 | "metadata": {},
205 | "source": [
206 |     "To make sure the optimized loss function generalizes, a batch of neural networks is trained with different data and weight initializations. For this purpose, batches of keys for network initialization, data sampling and validation data are generated."
207 | ]
208 | },
209 | {
210 | "cell_type": "code",
211 | "execution_count": 3,
212 | "metadata": {},
213 | "outputs": [],
214 | "source": [
215 | "def generate_keys(key, batch_size=4):\n",
216 | " key1, key2, key3 = jr.split(key, 3)\n",
217 | " return jr.split(key1, batch_size), jr.split(key2, batch_size), jr.split(key3, batch_size)"
218 | ]
219 | },
220 | {
221 | "cell_type": "markdown",
222 | "metadata": {},
223 | "source": [
224 | "Here we define the hyperparameters and inputs to the genetic programming algorithm. The inputs to the trees are the prediction and target value. "
225 | ]
226 | },
227 | {
228 | "cell_type": "code",
229 | "execution_count": 4,
230 | "metadata": {},
231 | "outputs": [
232 | {
233 | "name": "stdout",
234 | "output_type": "stream",
235 | "text": [
236 | "Input data should be formatted as: ['pred', 'y'].\n",
237 | "In generation 1, best fitness = 0.4820, best solution = log(pred*y + pred - y)\n",
238 | "In generation 2, best fitness = 0.4560, best solution = y*(y - 0.257)*log(pred*y + pred - y)\n",
239 | "In generation 3, best fitness = 0.4560, best solution = y*(y - 0.257)*log(pred*y + pred - y)\n",
240 | "In generation 4, best fitness = 0.4560, best solution = y*(y - 0.257)*log(pred*y + pred - y)\n",
241 | "In generation 5, best fitness = 0.4560, best solution = y*(y - 0.257)*log(pred*y + pred - y)\n",
242 | "In generation 6, best fitness = 0.3335, best solution = 2*pred*(0.425 - y)\n",
243 | "In generation 7, best fitness = 0.3335, best solution = 2*pred*(0.425 - y)\n",
244 | "In generation 8, best fitness = 0.3335, best solution = pred*(0.425 - y)\n",
245 | "In generation 9, best fitness = 0.3335, best solution = pred*(0.425 - y)\n",
246 | "In generation 10, best fitness = 0.3300, best solution = (0.425 - y)*log(pred + 0.793)\n",
247 | "In generation 11, best fitness = 0.1045, best solution = 3*pred*(pred - y - 0.597)\n",
248 | "In generation 12, best fitness = 0.1045, best solution = 3*pred*(pred - y - 0.597)\n",
249 | "In generation 13, best fitness = 0.1045, best solution = 3*pred*(pred - y - 0.597)\n",
250 | "In generation 14, best fitness = 0.0955, best solution = (2*pred + 0.107)*(pred - y - 0.597)\n",
251 | "In generation 15, best fitness = 0.0955, best solution = (2*pred + 0.107)*(pred - y - 0.597)\n",
252 | "In generation 16, best fitness = 0.0915, best solution = (pred - 2*y + 0.107)*log(pred + 0.793)\n",
253 | "In generation 17, best fitness = 0.0915, best solution = (pred - 2*y + 0.107)*log(pred + 0.793)\n",
254 | "In generation 18, best fitness = 0.0915, best solution = (pred - 2*y + 0.107)*log(pred + 0.793)\n",
255 | "In generation 19, best fitness = 0.0915, best solution = (pred - 2*y + 0.107)*log(pred + 0.793)\n",
256 | "In generation 20, best fitness = 0.0895, best solution = (pred + 0.107)*(pred - 2*y - 0.103)\n",
257 | "In generation 21, best fitness = 0.0865, best solution = pred*(pred - y - log(pred + 0.793) + 0.107)\n",
258 | "In generation 22, best fitness = 0.0850, best solution = (pred - 0.211)*(pred - y - log(pred + 0.793) + 0.107)\n",
259 | "In generation 23, best fitness = 0.0850, best solution = (pred - 0.211)*(pred - y - log(pred + 0.793) + 0.107)\n",
260 | "In generation 24, best fitness = 0.0850, best solution = (pred - 0.211)*(pred - y - log(pred + 0.793) + 0.107)\n",
261 | "In generation 25, best fitness = 0.0850, best solution = (pred - 0.211)*(pred - y - log(pred + 0.793) + 0.107)\n",
262 | "Complexity: 1, fitness: 0.4909999668598175, equations: nan\n",
263 | "Complexity: 4, fitness: 0.4714999794960022, equations: log(0.25 - pred)\n",
264 | "Complexity: 6, fitness: 0.4139999747276306, equations: log(-pred + y - 0.496)\n",
265 | "Complexity: 7, fitness: 0.119999960064888, equations: pred*(pred - y - 0.354)\n",
266 | "Complexity: 9, fitness: 0.09499996155500412, equations: pred*(pred - y - 0.457)\n",
267 | "Complexity: 11, fitness: 0.09149995446205139, equations: (pred - 0.211)*(pred - 2*y + 0.107)\n",
268 | "Complexity: 12, fitness: 0.08649995177984238, equations: pred*(pred - y - log(pred + 0.793) + 0.107)\n",
269 | "Complexity: 14, fitness: 0.0849999487400055, equations: (pred - 0.211)*(pred - y - log(pred + 0.793) + 0.107)\n"
270 | ]
271 | }
272 | ],
273 | "source": [
274 | "key = jr.PRNGKey(0)\n",
275 | "data_key, gp_key = jr.split(key)\n",
276 | "\n",
277 | "population_size = 50\n",
278 | "num_populations = 5\n",
279 | "num_generations = 25\n",
280 | "\n",
281 | "operator_list = [(\"+\", lambda x, y: jnp.add(x, y), 2, 0.5), \n",
282 | " (\"-\", lambda x, y: jnp.subtract(x, y), 2, 0.5),\n",
283 | " (\"*\", lambda x, y: jnp.multiply(x, y), 2, 0.5),\n",
284 | " (\"log\", lambda x: jnp.log(x + 1e-7), 1, 0.1),\n",
285 | " ]\n",
286 | "\n",
287 | "variable_list = [[\"pred\", \"y\"]]\n",
288 | "\n",
289 | "input_dim = 2\n",
290 | "hidden_dim = 16\n",
291 | "output_dim = 1\n",
292 | "\n",
293 | "fitness_function = FitnessFunction(input_dim, hidden_dim, output_dim, learning_rate=0.01, epochs=100)\n",
294 | "\n",
295 | "strategy = GeneticProgramming(num_generations, population_size, fitness_function, operator_list, variable_list, num_populations = num_populations)\n",
296 | "\n",
297 | "data_keys, test_keys, network_keys = generate_keys(data_key)\n",
298 | "\n",
299 | "strategy.fit(gp_key, (data_keys, test_keys, network_keys), verbose=True)"
300 | ]
301 | }
302 | ],
303 | "metadata": {
304 | "kernelspec": {
305 | "display_name": "Python 3",
306 | "language": "python",
307 | "name": "python3"
308 | },
309 | "language_info": {
310 | "codemirror_mode": {
311 | "name": "ipython",
312 | "version": 3
313 | },
314 | "file_extension": ".py",
315 | "mimetype": "text/x-python",
316 | "name": "python",
317 | "nbconvert_exporter": "python",
318 | "pygments_lexer": "ipython3",
319 | "version": "3.11.4"
320 | },
321 | "orig_nbformat": 4
322 | },
323 | "nbformat": 4,
324 | "nbformat_minor": 2
325 | }
326 |
--------------------------------------------------------------------------------
/examples/example_notebooks/symbolic_regression_dynamical_system.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Symbolic regression of a dynamical system\n",
8 | "\n",
9 | "In this example, Kozax is applied to recover the state equations of the Lotka-Volterra system. The candidate solutions are integrated as a system of differential equations, after which the predictions are compared to the true observations to determine a fitness score."
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": 1,
15 | "metadata": {},
16 | "outputs": [
17 | {
18 | "name": "stdout",
19 | "output_type": "stream",
20 | "text": [
21 | "These device(s) are detected: [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7), CpuDevice(id=8), CpuDevice(id=9)]\n"
22 | ]
23 | }
24 | ],
25 | "source": [
26 | "# Specify the cores to use for XLA\n",
27 | "import os\n",
28 | "os.environ[\"XLA_FLAGS\"] = '--xla_force_host_platform_device_count=10'\n",
29 | "\n",
30 | "import jax\n",
31 | "import diffrax\n",
32 | "import jax.numpy as jnp\n",
33 | "import jax.random as jr\n",
35 | "\n",
36 | "from kozax.genetic_programming import GeneticProgramming\n",
37 | "from kozax.fitness_functions.ODE_fitness_function import ODEFitnessFunction\n",
38 | "from kozax.environments.SR_environments.lotka_volterra import LotkaVolterra"
39 | ]
40 | },
41 | {
42 | "cell_type": "markdown",
43 | "metadata": {},
44 | "source": [
45 | "First the data is generated, consisting of initial conditions, time points and the true observations. Kozax provides the Lotka-Volterra environment, which is integrated with Diffrax."
46 | ]
47 | },
48 | {
49 | "cell_type": "code",
50 | "execution_count": 2,
51 | "metadata": {},
52 | "outputs": [],
53 | "source": [
54 | "def get_data(key, env, dt, T, batch_size=20):\n",
55 | " x0s = env.sample_init_states(batch_size, key)\n",
56 | " ts = jnp.arange(0, T, dt)\n",
57 | "\n",
58 | " def solve(env, ts, x0):\n",
59 | " solver = diffrax.Dopri5()\n",
60 | " dt0 = 0.001\n",
61 | " saveat = diffrax.SaveAt(ts=ts)\n",
62 | "\n",
63 | " system = diffrax.ODETerm(env.drift)\n",
64 | "\n",
65 | "        # Solve the system for a given initial condition\n",
66 | " sol = diffrax.diffeqsolve(system, solver, ts[0], ts[-1], dt0, x0, saveat=saveat, max_steps=500, \n",
67 | " adjoint=diffrax.DirectAdjoint(), stepsize_controller=diffrax.PIDController(atol=1e-7, rtol=1e-7, dtmin=0.001))\n",
68 | " \n",
69 | " return sol.ys\n",
70 | "\n",
71 | " ys = jax.vmap(solve, in_axes=[None, None, 0])(env, ts, x0s) #Parallelize over the batch dimension\n",
72 | " \n",
73 | " return x0s, ts, ys\n",
74 | "\n",
75 | "key = jr.PRNGKey(0)\n",
76 | "data_key, gp_key = jr.split(key)\n",
77 | "\n",
78 | "T = 30\n",
79 | "dt = 0.2\n",
80 | "env = LotkaVolterra()\n",
81 | "\n",
82 | "# Simulate the data\n",
83 | "data = get_data(data_key, env, dt, T, batch_size=4)\n",
84 | "x0s, ts, ys = data"
85 | ]
86 | },
87 | {
88 | "cell_type": "markdown",
89 | "metadata": {},
90 | "source": [
91 |     "For the fitness function, we use the ODEFitnessFunction, which integrates candidate solutions with Diffrax. It is possible to select the solver, the time step, the number of steps and a stepsize controller to balance efficiency and accuracy. To ensure convergence of the genetic programming algorithm, constant optimization is applied to the best candidates at every generation. The constant optimization is performed with a couple of simple evolutionary steps that adjust the values of the constants in a candidate. The hyperparameters that define the constant optimization are:\n- `constant_optimization_N_offspring`: the number of constant variations sampled for each candidate,\n- `constant_optimization_steps`: the number of iterations of constant optimization for each candidate,\n- `optimize_constants_elite`: the number of candidates that constant optimization is applied to,\n- `constant_step_size_init`: the initial step size for sampling constants,\n- `constant_step_size_decay`: the rate at which the step size decreases over generations."
92 | ]
93 | },
94 | {
95 | "cell_type": "code",
96 | "execution_count": 3,
97 | "metadata": {},
98 | "outputs": [
99 | {
100 | "name": "stdout",
101 | "output_type": "stream",
102 | "text": [
103 | "Input data should be formatted as: ['x0', 'x1'].\n"
104 | ]
105 | }
106 | ],
107 | "source": [
108 | "#Define the nodes and hyperparameters\n",
109 | "operator_list = [\n",
110 | " (\"+\", lambda x, y: jnp.add(x, y), 2, 0.5), \n",
111 | " (\"-\", lambda x, y: jnp.subtract(x, y), 2, 0.1), \n",
112 | " (\"*\", lambda x, y: jnp.multiply(x, y), 2, 0.5), \n",
113 | " ]\n",
114 | "\n",
115 | "variable_list = [[\"x\" + str(i) for i in range(env.n_var)]]\n",
116 | "layer_sizes = jnp.array([env.n_var])\n",
117 | "\n",
118 | "population_size = 100\n",
119 | "num_populations = 10\n",
120 | "num_generations = 50\n",
121 | "\n",
122 | "#Initialize the fitness function and the genetic programming strategy\n",
123 | "fitness_function = ODEFitnessFunction(solver=diffrax.Dopri5(), dt0 = 0.01, stepsize_controller=diffrax.PIDController(atol=1e-6, rtol=1e-6, dtmin=0.001), max_steps=300)\n",
124 | "\n",
125 | "strategy = GeneticProgramming(num_generations, population_size, fitness_function, operator_list, variable_list, layer_sizes, num_populations = num_populations,\n",
126 | " size_parsimony=0.003, constant_optimization_method=\"evolution\", constant_optimization_N_offspring = 50, constant_optimization_steps = 3, \n",
127 | " optimize_constants_elite=100, constant_step_size_init=0.1, constant_step_size_decay=0.99)"
128 | ]
129 | },
130 | {
131 | "cell_type": "markdown",
132 | "metadata": {},
133 | "source": [
134 |     "Kozax provides a fit function that receives the data and a random key. However, it is also possible to run Kozax with a simple loop of evaluating and evolving the population. This is useful when different input data should be provided during evaluation: in symbolic regression of dynamical systems, it helps to first optimize on a subset of the time points, and to provide the full trajectories only after a number of generations."
135 | ]
136 | },
137 | {
138 | "cell_type": "code",
139 | "execution_count": 4,
140 | "metadata": {},
141 | "outputs": [
142 | {
143 | "name": "stdout",
144 | "output_type": "stream",
145 | "text": [
146 | "In generation 1, best fitness = 1.3755, best solution = [1.29 - 2.42*x0, 0.0769 - 0.336*x1]\n",
147 | "In generation 2, best fitness = 1.3496, best solution = [-0.893*x0*x1 + 1.1, 1.6*x0 - 2.06]\n",
148 | "In generation 3, best fitness = 1.3289, best solution = [-0.331*x0*x1, x0 - 0.413*x1 + 0.0743]\n",
149 | "In generation 4, best fitness = 1.3097, best solution = [x1*(-0.35*x0 - 0.246) + 1.23, 0.364 - 0.36*x1]\n",
150 | "In generation 5, best fitness = 1.2857, best solution = [-1.6*x0 - 0.248*x1 + 1.4, x0 - 0.415*x1 - 0.35]\n",
151 | "In generation 6, best fitness = 1.2720, best solution = [-1.72*x0 - 0.245*x1 + 1.42, 0.822*x0 - 0.384*x1 - 0.265]\n",
152 | "In generation 7, best fitness = 1.2277, best solution = [-1.16*x0*x1 + x0 + 0.551, 0.107*x0 - 0.228*x1 - 0.16]\n",
153 | "In generation 8, best fitness = 1.2014, best solution = [-0.639*x0*x1 + x0 + 0.461, 0.146*x0 - 0.197*x1 - 0.208]\n",
154 | "In generation 9, best fitness = 0.9929, best solution = [-2.6*x0*(-0.0276*x0 + 0.273*x1) + x0 + 0.461, 0.146*x0 - 0.197*x1 - 0.208]\n",
155 | "In generation 10, best fitness = 0.8489, best solution = [-2.56*x0*(-0.0312*x0 + 0.256*x1) + x0 + 0.245, 0.172*x0 - 0.197*x1 - 0.239]\n",
156 | "In generation 11, best fitness = 0.8489, best solution = [-2.56*x0*(-0.0312*x0 + 0.256*x1) + x0 + 0.245, 0.172*x0 - 0.197*x1 - 0.239]\n",
157 | "In generation 12, best fitness = 0.8420, best solution = [-2.56*x0*(-0.0312*x0 + 0.238*x1) + x0 + 0.245, 0.172*x0 - 0.197*x1 - 0.239]\n",
158 | "In generation 13, best fitness = 0.8420, best solution = [-2.56*x0*(-0.0312*x0 + 0.238*x1) + x0 + 0.245, 0.172*x0 - 0.197*x1 - 0.239]\n",
159 | "In generation 14, best fitness = 0.8420, best solution = [-2.56*x0*(-0.0312*x0 + 0.238*x1) + x0 + 0.245, 0.172*x0 - 0.197*x1 - 0.239]\n",
160 | "In generation 15, best fitness = 0.8420, best solution = [-2.56*x0*(-0.0312*x0 + 0.238*x1) + x0 + 0.245, 0.172*x0 - 0.197*x1 - 0.239]\n",
161 | "In generation 16, best fitness = 0.8212, best solution = [-2.62*x0*(0.158*x1 + 0.101) + x0 + 0.098, 0.175*x0 - 0.0532*x1 - 0.88]\n",
162 | "In generation 17, best fitness = 0.7373, best solution = [-0.74*x0*(x1 - 2.42), 0.122*x0*x1 - 0.00769*x0 - 0.36*x1]\n",
163 | "In generation 18, best fitness = 0.4043, best solution = [-0.42*x0*(x1 - 2.5), 0.0866*x0*x1 - 0.406*x1]\n",
164 | "In generation 19, best fitness = 0.2175, best solution = [-0.391*x0*(x1 - 2.5), 0.11*x0*x1 - 0.458*x1 + 0.0161]\n",
165 | "In generation 20, best fitness = 0.2129, best solution = [-0.402*x0*(x1 - 2.46), 0.117*x0*x1 - 0.463*x1]\n",
166 | "In generation 21, best fitness = 0.2129, best solution = [-0.402*x0*(x1 - 2.46), 0.117*x0*x1 - 0.463*x1]\n",
167 | "In generation 22, best fitness = 0.2129, best solution = [-0.402*x0*(x1 - 2.46), 0.117*x0*x1 - 0.463*x1]\n",
168 | "In generation 23, best fitness = 0.1362, best solution = [-0.409*x0*(x1 - 2.75), 0.0999*x0*x1 - 0.403*x1 + 0.0227]\n",
169 | "In generation 24, best fitness = 0.1362, best solution = [-0.409*x0*(x1 - 2.75), 0.0999*x0*x1 - 0.403*x1 + 0.0227]\n",
170 | "In generation 25, best fitness = 0.1362, best solution = [-0.409*x0*(x1 - 2.75), 0.0999*x0*x1 - 0.403*x1 + 0.0227]\n",
171 | "In generation 26, best fitness = 0.1218, best solution = [-0.393*x0*(x1 - 2.75), 0.102*x0*x1 - 0.406*x1]\n",
172 | "In generation 27, best fitness = 0.1218, best solution = [-0.393*x0*(x1 - 2.75), 0.102*x0*x1 - 0.406*x1]\n",
173 | "In generation 28, best fitness = 0.1218, best solution = [-0.393*x0*(x1 - 2.75), 0.102*x0*x1 - 0.406*x1]\n",
174 | "In generation 29, best fitness = 0.1218, best solution = [-0.393*x0*(x1 - 2.75), 0.102*x0*x1 - 0.406*x1]\n",
175 | "In generation 30, best fitness = 0.1218, best solution = [-0.393*x0*(x1 - 2.75), 0.102*x0*x1 - 0.406*x1]\n",
176 | "In generation 31, best fitness = 0.1218, best solution = [-0.393*x0*(x1 - 2.75), 0.102*x0*x1 - 0.406*x1]\n",
177 | "In generation 32, best fitness = 0.1218, best solution = [-0.393*x0*(x1 - 2.75), 0.102*x0*x1 - 0.406*x1]\n",
178 | "In generation 33, best fitness = 0.1218, best solution = [-0.393*x0*(x1 - 2.75), 0.102*x0*x1 - 0.406*x1]\n",
179 | "In generation 34, best fitness = 0.1189, best solution = [-0.41*x0*(x1 - 2.7), 0.105*x0*x1 - 0.407*x1]\n",
180 | "In generation 35, best fitness = 0.0978, best solution = [-0.395*x0*(x1 - 2.74), 0.0977*x0*x1 - 0.402*x1]\n",
181 | "In generation 36, best fitness = 0.0978, best solution = [-0.395*x0*(x1 - 2.74), 0.0977*x0*x1 - 0.402*x1]\n",
182 | "In generation 37, best fitness = 0.0978, best solution = [-0.395*x0*(x1 - 2.74), 0.0977*x0*x1 - 0.402*x1]\n",
183 | "In generation 38, best fitness = 0.0761, best solution = [-0.402*x0*(x1 - 2.76), 0.0993*x0*x1 - 0.397*x1]\n",
184 | "In generation 39, best fitness = 0.0761, best solution = [-0.402*x0*(x1 - 2.76), 0.0993*x0*x1 - 0.397*x1]\n",
185 | "In generation 40, best fitness = 0.0761, best solution = [-0.402*x0*(x1 - 2.76), 0.0993*x0*x1 - 0.397*x1]\n",
186 | "In generation 41, best fitness = 0.0761, best solution = [-0.402*x0*(x1 - 2.76), 0.0993*x0*x1 - 0.397*x1]\n",
187 | "In generation 42, best fitness = 0.0761, best solution = [-0.402*x0*(x1 - 2.76), 0.0993*x0*x1 - 0.397*x1]\n",
188 | "In generation 43, best fitness = 0.0761, best solution = [-0.402*x0*(x1 - 2.76), 0.0993*x0*x1 - 0.397*x1]\n",
189 | "In generation 44, best fitness = 0.0761, best solution = [-0.402*x0*(x1 - 2.76), 0.0993*x0*x1 - 0.397*x1]\n",
190 | "In generation 45, best fitness = 0.0761, best solution = [-0.402*x0*(x1 - 2.76), 0.0993*x0*x1 - 0.397*x1]\n",
191 | "In generation 46, best fitness = 0.0761, best solution = [-0.402*x0*(x1 - 2.76), 0.0993*x0*x1 - 0.397*x1]\n",
192 | "In generation 47, best fitness = 0.0761, best solution = [-0.402*x0*(x1 - 2.76), 0.0993*x0*x1 - 0.397*x1]\n",
193 | "In generation 48, best fitness = 0.0761, best solution = [-0.402*x0*(x1 - 2.76), 0.0993*x0*x1 - 0.397*x1]\n",
194 | "In generation 49, best fitness = 0.0761, best solution = [-0.402*x0*(x1 - 2.76), 0.0993*x0*x1 - 0.397*x1]\n",
195 | "In generation 50, best fitness = 0.0761, best solution = [-0.402*x0*(x1 - 2.76), 0.0993*x0*x1 - 0.397*x1]\n"
196 | ]
197 | }
198 | ],
199 | "source": [
200 | "# Sample the initial population\n",
201 | "population = strategy.initialize_population(gp_key)\n",
202 | "\n",
203 | "# Define the number of timepoints to include in the data\n",
204 | "end_ts = int(ts.shape[0]/2)\n",
205 | "\n",
206 | "for g in range(num_generations):\n",
207 | " if g == 25: # After 25 generations, use the full data\n",
208 | " end_ts = ts.shape[0]\n",
209 | "\n",
210 | " key, eval_key, sample_key = jr.split(key, 3)\n",
211 | " # Evaluate the population on the data, and return the fitness\n",
212 | " fitness, population = strategy.evaluate_population(population, (x0s, ts[:end_ts], ys[:,:end_ts]), eval_key)\n",
213 | "\n",
214 | " # Print the best solution in the population in this generation\n",
215 | " best_fitness, best_solution = strategy.get_statistics(g)\n",
216 | " print(f\"In generation {g+1}, best fitness = {best_fitness:.4f}, best solution = {strategy.expression_to_string(best_solution)}\")\n",
217 | "\n",
218 | " # Evolve the population until the last generation. The fitness should be given to the evolve function.\n",
219 | " if g < (num_generations-1):\n",
220 | " population = strategy.evolve_population(population, fitness, sample_key)"
221 | ]
222 | },
223 | {
224 | "cell_type": "code",
225 | "execution_count": 5,
226 | "metadata": {},
227 | "outputs": [
228 | {
229 | "name": "stdout",
230 | "output_type": "stream",
231 | "text": [
232 | "Complexity: 2, fitness: 2.8892037868499756, equations: [-0.776, -0.763]\n",
233 | "Complexity: 4, fitness: 1.9611964225769043, equations: [-2.14*x0, -0.682]\n",
234 | "Complexity: 6, fitness: 1.3917429447174072, equations: [-1.98*x0, -0.285*x1]\n",
235 | "Complexity: 8, fitness: 1.3050819635391235, equations: [-2.23*x0, x0 - 0.41*x1]\n",
236 | "Complexity: 10, fitness: 1.2563836574554443, equations: [0.835 - 2.31*x0, x0 - 0.537*x1]\n",
237 | "Complexity: 12, fitness: 1.2215551137924194, equations: [1.11 - 2.31*x0, x0 - 0.45*x1 - 0.276]\n",
238 | "Complexity: 14, fitness: 0.8544017672538757, equations: [-0.845*x0*(x1 - 2.48), 0.141*x0 - 0.271*x1]\n",
239 | "Complexity: 16, fitness: 0.06887245923280716, equations: [-0.415*x0*(x1 - 2.7), 0.104*x0*x1 - 0.404*x1]\n",
240 | "Complexity: 18, fitness: 0.06406070291996002, equations: [-0.398*x0*(x1 - 2.67), 0.103*x0*x1 - 0.415*x1 - 0.000306]\n",
241 | "Complexity: 20, fitness: 0.016078392043709755, equations: [-0.402*x0*(x1 - 2.76), 0.0993*x0*x1 - 0.397*x1]\n"
242 | ]
243 | }
244 | ],
245 | "source": [
246 | "strategy.print_pareto_front()"
247 | ]
248 | },
249 | {
250 | "cell_type": "markdown",
251 | "metadata": {},
252 | "source": [
253 |     "Instead of using evolution to optimize the constants, Kozax also offers gradient-based optimization. For gradient optimization, it is possible to specify the optimizer, the number of candidates to apply constant optimization to, the initial learning rate and the learning rate decay over generations. Both methods are provided because either can be more effective or efficient, depending on the problem."
254 | ]
255 | },
256 | {
257 | "cell_type": "code",
258 | "execution_count": 6,
259 | "metadata": {},
260 | "outputs": [
261 | {
262 | "name": "stdout",
263 | "output_type": "stream",
264 | "text": [
265 | "Input data should be formatted as: ['x0', 'x1'].\n"
266 | ]
267 | }
268 | ],
269 | "source": [
270 | "import optax\n",
271 | "\n",
272 | "strategy = GeneticProgramming(num_generations, population_size, fitness_function, operator_list, variable_list, layer_sizes, num_populations = num_populations,\n",
273 | " size_parsimony=0.003, constant_optimization_method=\"gradient\", constant_optimization_steps = 15, optimizer_class = optax.adam,\n",
274 | " optimize_constants_elite=100, constant_step_size_init=0.025, constant_step_size_decay=0.95)"
275 | ]
276 | },
277 | {
278 | "cell_type": "code",
279 | "execution_count": 7,
280 | "metadata": {},
281 | "outputs": [
282 | {
283 | "name": "stdout",
284 | "output_type": "stream",
285 | "text": [
286 | "In generation 1, best fitness = 1.4242, best solution = [1.41 - 2.33*x0, -0.226*x1 - 0.145]\n",
287 | "In generation 2, best fitness = 1.3698, best solution = [1.16 - 2.29*x0, 0.113 - 0.317*x1]\n",
288 | "In generation 3, best fitness = 1.2569, best solution = [-0.49*x0*x1 + 0.674, 0.907*x0 - 0.261*x1 - 0.681]\n",
289 | "In generation 4, best fitness = 1.1925, best solution = [-0.283*x0*x1 + 0.391, 0.882*x0 - 0.272*x1 - 0.805]\n",
290 | "In generation 5, best fitness = 1.1737, best solution = [-0.276*x0*x1 + 0.392, 0.729*x0 - 0.286*x1 - 0.678]\n",
291 | "In generation 6, best fitness = 1.1737, best solution = [-0.276*x0*x1 + 0.392, 0.729*x0 - 0.286*x1 - 0.678]\n",
292 | "In generation 7, best fitness = 1.1177, best solution = [(1.41 - 0.532*x1)*(x0 - 0.218), x0 - 0.427*x1]\n",
293 | "In generation 8, best fitness = 1.0575, best solution = [2*x0*(0.495 - 0.22*x1), (-1.18*x0 + x1)*(0.00345*x0 - 0.399)]\n",
294 | "In generation 9, best fitness = 0.9400, best solution = [(0.686 - 0.268*x1)*(2*x0 + 0.161), (-1.26*x0 + x1)*(0.0518*x0 - 0.375)]\n",
295 | "In generation 10, best fitness = 0.9385, best solution = [(0.655 - 0.258*x1)*(2*x0 + 0.144), (-1.13*x0 + x1)*(0.0412*x0 - 0.371)]\n",
296 | "In generation 11, best fitness = 0.9128, best solution = [(0.453 - 0.175*x1)*(4*x0 + 0.168), (-1.14*x0 + x1)*(0.0378*x0 - 0.361)]\n",
297 | "In generation 12, best fitness = 0.6460, best solution = [(0.649 - 0.251*x1)*(2*x0 + 0.15), (0.097*x0 - 0.401)*(0.127*x0 + x1 - 0.342)]\n",
298 | "In generation 13, best fitness = 0.3462, best solution = [(0.543 - 0.205*x1)*(2*x0 + 0.111), (0.102*x0 - 0.448)*(0.0216*x0 + x1 - 0.267)]\n",
299 | "In generation 14, best fitness = 0.3462, best solution = [(0.543 - 0.205*x1)*(2*x0 + 0.111), (0.102*x0 - 0.448)*(0.0216*x0 + x1 - 0.267)]\n",
300 | "In generation 15, best fitness = 0.3462, best solution = [(0.543 - 0.205*x1)*(2*x0 + 0.111), (0.102*x0 - 0.448)*(0.0216*x0 + x1 - 0.267)]\n",
301 | "In generation 16, best fitness = 0.3462, best solution = [(0.543 - 0.205*x1)*(2*x0 + 0.111), (0.102*x0 - 0.448)*(0.0216*x0 + x1 - 0.267)]\n",
302 | "In generation 17, best fitness = 0.1724, best solution = [(0.557 - 0.21*x1)*(2*x0 + 0.0978), (0.111*x0 - 0.453)*(x1 - 0.294)]\n",
303 | "In generation 18, best fitness = 0.1724, best solution = [(0.557 - 0.21*x1)*(2*x0 + 0.0978), (0.111*x0 - 0.453)*(x1 - 0.294)]\n",
304 | "In generation 19, best fitness = 0.1724, best solution = [(0.557 - 0.21*x1)*(2*x0 + 0.0978), (0.111*x0 - 0.453)*(x1 - 0.294)]\n",
305 | "In generation 20, best fitness = 0.1724, best solution = [(0.557 - 0.21*x1)*(2*x0 + 0.0978), (0.111*x0 - 0.453)*(x1 - 0.294)]\n",
306 | "In generation 21, best fitness = 0.1688, best solution = [(0.566 - 0.215*x1)*(2*x0 - 0.098), x1*(0.113*x0 - 0.423)]\n",
307 | "In generation 22, best fitness = 0.1231, best solution = [(0.531 - 0.201*x1)*(2*x0 - 0.021), x1*(0.107*x0 - 0.424)]\n",
308 | "In generation 23, best fitness = 0.1231, best solution = [(0.531 - 0.201*x1)*(2*x0 - 0.021), x1*(0.107*x0 - 0.424)]\n",
309 | "In generation 24, best fitness = 0.1231, best solution = [(0.531 - 0.201*x1)*(2*x0 - 0.021), x1*(0.107*x0 - 0.424)]\n",
310 | "In generation 25, best fitness = 0.1231, best solution = [(0.531 - 0.201*x1)*(2*x0 - 0.021), x1*(0.107*x0 - 0.424)]\n",
311 | "In generation 26, best fitness = 0.1189, best solution = [(0.53 - 0.2*x1)*(2*x0 - 0.0242), x1*(0.104*x0 - 0.421)]\n",
312 | "In generation 27, best fitness = 0.1177, best solution = [(0.574 - 0.218*x1)*(1.87*x0 - 0.06), x1*(0.105*x0 - 0.422)]\n",
313 | "In generation 28, best fitness = 0.1133, best solution = [(0.539 - 0.203*x1)*(2*x0 - 0.0842), x1*(0.105*x0 - 0.421)]\n",
314 | "In generation 29, best fitness = 0.1107, best solution = [1.87*x0*(0.572 - 0.217*x1), x1*(0.105*x0 - 0.42)]\n",
315 | "In generation 30, best fitness = 0.1107, best solution = [1.87*x0*(0.572 - 0.217*x1), x1*(0.105*x0 - 0.42)]\n",
316 | "In generation 31, best fitness = 0.1107, best solution = [1.87*x0*(0.572 - 0.217*x1), x1*(0.105*x0 - 0.42)]\n",
317 | "In generation 32, best fitness = 0.1054, best solution = [1.87*x0*(0.571 - 0.216*x1), x1*(0.104*x0 - 0.419)]\n",
318 | "In generation 33, best fitness = 0.1053, best solution = [(0.536 - 0.202*x1)*(2*x0 - 0.0521), x1*(0.103*x0 - 0.418)]\n",
319 | "In generation 34, best fitness = 0.1053, best solution = [(0.536 - 0.202*x1)*(2*x0 - 0.0521), x1*(0.103*x0 - 0.418)]\n",
320 | "In generation 35, best fitness = 0.1053, best solution = [(0.536 - 0.202*x1)*(2*x0 - 0.0521), x1*(0.103*x0 - 0.418)]\n",
321 | "In generation 36, best fitness = 0.1053, best solution = [(0.536 - 0.202*x1)*(2*x0 - 0.0521), x1*(0.103*x0 - 0.418)]\n",
322 | "In generation 37, best fitness = 0.1053, best solution = [(0.536 - 0.202*x1)*(2*x0 - 0.0521), x1*(0.103*x0 - 0.418)]\n",
323 | "In generation 38, best fitness = 0.1053, best solution = [(0.536 - 0.202*x1)*(2*x0 - 0.0521), x1*(0.103*x0 - 0.418)]\n",
324 | "In generation 39, best fitness = 0.1053, best solution = [(0.536 - 0.202*x1)*(2*x0 - 0.0521), x1*(0.103*x0 - 0.418)]\n",
325 | "In generation 40, best fitness = 0.1053, best solution = [(0.536 - 0.202*x1)*(2*x0 - 0.0521), x1*(0.103*x0 - 0.418)]\n",
326 | "In generation 41, best fitness = 0.0998, best solution = [2*x0*(0.536 - 0.203*x1), x1*(0.104*x0 - 0.418)]\n",
327 | "In generation 42, best fitness = 0.0941, best solution = [1.86*x0*(0.574 - 0.215*x1), x1*(0.103*x0 - 0.415)]\n",
328 | "In generation 43, best fitness = 0.0941, best solution = [1.86*x0*(0.574 - 0.215*x1), x1*(0.103*x0 - 0.415)]\n",
329 | "In generation 44, best fitness = 0.0941, best solution = [1.86*x0*(0.574 - 0.215*x1), x1*(0.103*x0 - 0.415)]\n",
330 | "In generation 45, best fitness = 0.0941, best solution = [1.86*x0*(0.574 - 0.215*x1), x1*(0.103*x0 - 0.415)]\n",
331 | "In generation 46, best fitness = 0.0941, best solution = [1.86*x0*(0.574 - 0.215*x1), x1*(0.103*x0 - 0.415)]\n",
332 | "In generation 47, best fitness = 0.0941, best solution = [1.86*x0*(0.574 - 0.215*x1), x1*(0.103*x0 - 0.415)]\n",
333 | "In generation 48, best fitness = 0.0941, best solution = [1.86*x0*(0.574 - 0.215*x1), x1*(0.103*x0 - 0.415)]\n",
334 | "In generation 49, best fitness = 0.0938, best solution = [2*x0*(0.537 - 0.202*x1), x1*(0.103*x0 - 0.415)]\n",
335 | "In generation 50, best fitness = 0.0938, best solution = [2*x0*(0.537 - 0.202*x1), x1*(0.103*x0 - 0.415)]\n"
336 | ]
337 | }
338 | ],
339 | "source": [
340 | "key = jr.PRNGKey(0)\n",
341 | "data_key, gp_key = jr.split(key)\n",
342 | "\n",
343 | "T = 30\n",
344 | "dt = 0.2\n",
345 | "env = LotkaVolterra()\n",
346 | "\n",
347 | "# Simulate the data\n",
348 | "data = get_data(data_key, env, dt, T, batch_size=4)\n",
349 | "x0s, ts, ys = data\n",
350 | "\n",
351 | "# Sample the initial population\n",
352 | "population = strategy.initialize_population(gp_key)\n",
353 | "\n",
354 | "# Define the number of timepoints to include in the data\n",
355 | "end_ts = int(ts.shape[0]/2)\n",
356 | "\n",
357 | "for g in range(num_generations):\n",
358 | " if g == 25: # After 25 generations, use the full data\n",
359 | " end_ts = ts.shape[0]\n",
360 | "\n",
361 | " key, eval_key, sample_key = jr.split(key, 3)\n",
362 | " # Evaluate the population on the data, and return the fitness\n",
363 | " fitness, population = strategy.evaluate_population(population, (x0s, ts[:end_ts], ys[:,:end_ts]), eval_key)\n",
364 | "\n",
365 | " # Print the best solution in the population in this generation\n",
366 | " best_fitness, best_solution = strategy.get_statistics(g)\n",
367 | " print(f\"In generation {g+1}, best fitness = {best_fitness:.4f}, best solution = {strategy.expression_to_string(best_solution)}\")\n",
368 | "\n",
369 | " # Evolve the population until the last generation. The fitness should be given to the evolve function.\n",
370 | " if g < (num_generations-1):\n",
371 | " population = strategy.evolve_population(population, fitness, sample_key)"
372 | ]
373 | },
374 | {
375 | "cell_type": "code",
376 | "execution_count": 8,
377 | "metadata": {},
378 | "outputs": [
379 | {
380 | "name": "stdout",
381 | "output_type": "stream",
382 | "text": [
383 | "Complexity: 2, fitness: 2.889601707458496, equations: [-0.789, -0.746]\n",
384 | "Complexity: 4, fitness: 2.231682300567627, equations: [1.05 - x0, -0.336]\n",
385 | "Complexity: 6, fitness: 1.392594814300537, equations: [-1.95*x0, -0.291*x1]\n",
386 | "Complexity: 8, fitness: 1.3738036155700684, equations: [-1.62*x0, x0 - 0.407*x1]\n",
387 | "Complexity: 10, fitness: 1.3001673221588135, equations: [-0.276*x0*x1, x0 - 0.386*x1]\n",
388 | "Complexity: 12, fitness: 1.226672649383545, equations: [-0.274*x0*x1 + 0.218, x0 - 0.407*x1]\n",
389 | "Complexity: 14, fitness: 0.916778028011322, equations: [x0*(1.33 - 0.49*x1), 0.311*x0 - 0.327*x1]\n",
390 | "Complexity: 16, fitness: 0.04582058638334274, equations: [2*x0*(0.537 - 0.202*x1), x1*(0.103*x0 - 0.415)]\n",
391 | "Complexity: 18, fitness: 0.04011291265487671, equations: [1.86*x0*(0.574 - 0.215*x1), x1*(0.103*x0 - 0.415)]\n"
392 | ]
393 | }
394 | ],
395 | "source": [
396 | "strategy.print_pareto_front()"
397 | ]
398 | }
399 | ],
400 | "metadata": {
401 | "kernelspec": {
402 | "display_name": "test_kozax",
403 | "language": "python",
404 | "name": "python3"
405 | },
406 | "language_info": {
407 | "codemirror_mode": {
408 | "name": "ipython",
409 | "version": 3
410 | },
411 | "file_extension": ".py",
412 | "mimetype": "text/x-python",
413 | "name": "python",
414 | "nbconvert_exporter": "python",
415 | "pygments_lexer": "ipython3",
416 | "version": "3.12.2"
417 | },
418 | "orig_nbformat": 4
419 | },
420 | "nbformat": 4,
421 | "nbformat_minor": 2
422 | }
423 |
--------------------------------------------------------------------------------
/examples/example_scripts/control_policy_optimization.py:
--------------------------------------------------------------------------------
1 | """
2 | Control policy optimization
3 |
4 | In this example, a symbolic policy is evolved for the pendulum swingup task. Gymnax is used to simulate the pendulum environment, showing that Kozax is easily combined with external libraries.
5 | """
6 |
7 | # Specify the cores to use for XLA
8 | import os
9 | os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=10'
10 | import jax
11 | import jax.numpy as jnp
12 | import jax.random as jr
13 | import gymnax
14 | import matplotlib.pyplot as plt
15 |
16 | from kozax.genetic_programming import GeneticProgramming
17 | from kozax.fitness_functions.Gymnax_fitness_function import GymFitnessFunction
18 |
19 | if __name__ == "__main__":
20 | """
21 | Kozax provides a simple fitness function for Gymnax environments, which is used in this example.
22 | """
23 |
24 | #Define hyperparameters
25 | population_size = 100
26 | num_populations = 5
27 | num_generations = 50
28 | batch_size = 16
29 |
30 | fitness_function = GymFitnessFunction("Pendulum-v1")
31 |
32 | #Define operators and variables
33 | operator_list = [
34 | ("+", lambda x, y: jnp.add(x, y), 2, 0.5),
35 | ("-", lambda x, y: jnp.subtract(x, y), 2, 0.1),
36 | ("*", lambda x, y: jnp.multiply(x, y), 2, 0.5),
37 | ]
38 |
39 | variable_list = [[f"y{i}" for i in range(fitness_function.env.observation_space(fitness_function.env_params).shape[0])]]
40 |
41 | #Initialize strategy
42 | strategy = GeneticProgramming(num_generations, population_size, fitness_function, operator_list, variable_list, num_populations=num_populations)
43 |
44 | key = jr.PRNGKey(0)
45 | data_key, gp_key = jr.split(key, 2)
46 |
47 |     # The data comprises the keys needed to initialize the batch of environments.
48 | batch_keys = jr.split(data_key, batch_size)
49 |
50 | strategy.fit(gp_key, batch_keys, verbose=True)
51 |
52 | """
53 | ## Visualize best solution
54 |
55 | We can visualize the sin and cos position in a trajectory using the best solution.
56 | """
57 |
58 | env, env_params = gymnax.make('Pendulum-v1')
59 | key = jr.PRNGKey(2)
60 | obs, env_state = env.reset(key)
61 | all_obs = []
62 | treward = []
63 | actions = []
64 |
65 | done = False
66 |
67 | sin = jnp.sin
68 | cos = jnp.cos
69 |
70 | T=199
71 | for t in range(T):
72 |
73 | y0, y1, y2 = obs
74 | action = 2.92*y0*(-6.56*y1 - 1.29*y2)
75 | obs, env_state, reward, done, _ = env.step(
76 | jr.fold_in(key, t), env_state, action, env_params
77 | )
78 | all_obs.append(obs)
79 | treward.append(reward)
80 | actions.append(action)
81 |
82 | all_obs = jnp.array(all_obs)
83 | plt.plot(all_obs[:,0], label='cos(x)')
84 | plt.plot(all_obs[:,1], label='sin(x)')
85 | plt.legend()
86 | plt.show()
--------------------------------------------------------------------------------
/examples/example_scripts/control_with_memory_SHO.py:
--------------------------------------------------------------------------------
1 | """
2 | # Memory-based control policy optimization (Stochastic Harmonic Oscillator)
3 |
4 | In this more advanced example, we will optimize a symbolic control policy that is extended with a dynamic latent memory (check out this paper for more details:
5 | https://arxiv.org/abs/2406.02765). The memory is defined by a set of differential equations, consisting of a tree for each latent unit. The memory updates every time step, and the
6 | control policy maps the memory to a control signal via an additional readout tree. This setup is applied to stabilization of the stochastic harmonic oscillator at random targets.
7 | """
8 |
9 | # Specify the cores to use for XLA
10 | import os
11 | os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=10'
12 |
13 | import jax
14 | import jax.numpy as jnp
15 | import diffrax
16 | import jax.random as jr
17 | import matplotlib.pyplot as plt
18 | from jax import Array
19 | from typing import Tuple, Callable
20 | import copy
21 |
22 | from kozax.genetic_programming import GeneticProgramming
23 | from kozax.environments.control_environments.harmonic_oscillator import HarmonicOscillator
24 |
25 | """
26 | First we generate data, consisting of initial conditions, keys for noise, targets and parameters of the environment.
27 | """
28 |
29 | def get_data(key, env, batch_size, dt, T, param_setting):
30 | init_key, noise_key, param_key = jr.split(key, 3)
31 | x0, targets = env.sample_init_states(batch_size, init_key)
32 | noise_keys = jr.split(noise_key, batch_size)
33 | ts = jnp.arange(0, T, dt)
34 |
35 | params = env.sample_params(batch_size, param_setting, ts, param_key)
36 | return x0, ts, targets, noise_keys, params
37 |
38 | key = jr.PRNGKey(0)
39 | gp_key, data_key = jr.split(key)
40 | batch_size = 8
41 | T = 40
42 | dt = 0.2
43 | param_setting = "Constant"
44 |
45 | env = HarmonicOscillator(process_noise = 0.05, obs_noise = 0.05)
46 |
47 | data = get_data(data_key, env, batch_size, dt, T, param_setting)
48 |
49 | """
50 | The evaluator class parallelizes the simulation of the control loop over the batched data. In each trajectory, a coupled dynamical system is simulated using diffrax,
51 | integrating both the dynamic memory and the dynamic environment, defined by the drift and diffusion functions. At every time step, the state of the environment is mapped to
52 | observations and the latent memory is mapped to a control signal. Afterwards, the state equation of the memory and environment state is computed. When the simulation is done,
53 | the fitness is computed given the control and environment state.
54 | """
55 |
56 | class Evaluator:
57 | def __init__(self, env, state_size: int, dt0: float, solver=diffrax.Euler(), max_steps: int = 16**4, stepsize_controller: diffrax.AbstractStepSizeController = diffrax.ConstantStepSize()) -> None:
58 | """Evaluates dynamic symbolic policies in control tasks.
59 |
60 | Args:
61 | env: Environment on which the candidate is evaluated.
62 | state_size: Dimensionality of the hidden state.
63 | dt0: Initial step size for integration.
64 | solver: Solver used for integration (default: diffrax.Euler()).
65 | max_steps: The maximum number of steps that can be used in integration (default: 16**4).
66 | stepsize_controller: Controller for the stepsize during integration (default: diffrax.ConstantStepSize()).
67 |
68 | Attributes:
69 | env: Environment on which the candidate is evaluated.
70 | max_fitness: Max fitness which is assigned when a trajectory returns an invalid value.
71 | state_size: Dimensionality of the hidden state.
72 | obs_size: Dimensionality of the observations.
73 | control_size: Dimensionality of the control.
74 | latent_size: Dimensionality of the state of the environment.
75 | dt0: Initial step size for integration.
76 | solver: Solver used for integration.
77 | max_steps: The maximum number of steps that can be used in integration.
78 | stepsize_controller: Controller for the stepsize during integration.
79 | """
80 | self.env = env
81 | self.max_fitness = 1e4
82 | self.state_size = state_size
83 | self.obs_size = env.n_obs
84 | self.control_size = env.n_control_inputs
85 | self.latent_size = env.n_var*env.n_dim
86 | self.dt0 = dt0
87 | self.solver = solver
88 | self.max_steps = max_steps
89 | self.stepsize_controller = stepsize_controller
90 |
91 | def __call__(self, candidate: Array, data: Tuple, tree_evaluator: Callable) -> float:
92 | """Evaluates the candidate on a task.
93 |
94 | Args:
95 | candidate: The coefficients of the candidate.
96 | data: The data required to evaluate the candidate.
97 | tree_evaluator: Function for evaluating trees.
98 |
99 | Returns:
100 | Fitness of the candidate.
101 | """
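        # vmap over the batch: the candidate, time points and tree_evaluator are shared (in_axes None),
        # while initial states, targets, noise keys and environment parameters vary per trajectory (in_axes 0)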
102 | _, _, _, _, fitness = jax.vmap(self.evaluate_trajectory, in_axes=[None, 0, None, 0, 0, 0, None])(candidate, *data, tree_evaluator)
103 |
104 | fitness = jnp.mean(fitness)
105 | return fitness
106 |
107 | def evaluate_trajectory(self, candidate: Array, x0: Array, ts: Array, target: float, noise_key: jr.PRNGKey, params: Tuple, tree_evaluator: Callable) -> Tuple[Array, Array, Array, Array, float]:
108 | """Solves the coupled differential equation of the system and controller.
109 | The differential equation of the system is defined in the environment and the differential equation
110 | of the control is defined by the set of trees.
111 |
112 | Args:
113 | candidate: Candidate with trees for the hidden state and readout.
114 | x0: Initial state of the system.
115 | ts: Time points on which the controller is evaluated.
116 | target: Target position that the system should reach.
117 | noise_key: Key to generate noisy observations.
118 | params: Parameters that define the environment.
119 | tree_evaluator: Function for evaluating trees.
120 |
121 | Returns:
122 | States, observations, control, activities of the hidden state of the candidate and the fitness of the candidate.
123 | """
124 | env = copy.copy(self.env)
125 | env.initialize_parameters(params, ts)
126 |
127 | state_equation = candidate[:self.state_size]
128 | readout = candidate[self.state_size:]
129 |
130 | solver = self.solver
131 | dt0 = self.dt0
132 | saveat = diffrax.SaveAt(ts=ts)
133 |
134 | process_noise_key, obs_noise_key = jr.split(noise_key, 2)
135 |
136 | # Concatenate the initial state of the system with the initial state of the latent memory
137 | _x0 = jnp.concatenate([x0, jnp.zeros(self.state_size)])
138 |
139 | brownian_motion = diffrax.UnsafeBrownianPath(shape=(env.n_var,), key=process_noise_key, levy_area=diffrax.SpaceTimeLevyArea) #define process noise
140 | system = diffrax.MultiTerm(diffrax.ODETerm(self._drift), diffrax.ControlTerm(self._diffusion, brownian_motion))
141 |
142 | # Solve the coupled system of the environment and the controller
143 | sol = diffrax.diffeqsolve(
144 | system, solver, ts[0], ts[-1], dt0, _x0, saveat=saveat, adjoint=diffrax.DirectAdjoint(), max_steps=self.max_steps, event=diffrax.Event(self.env.cond_fn_nan),
145 | args=(env, state_equation, readout, obs_noise_key, target, tree_evaluator), stepsize_controller=self.stepsize_controller, throw=False
146 | )
147 |
148 | xs = sol.ys[:,:self.latent_size]
149 | # Get observations of the state at every time step
150 | _, ys = jax.lax.scan(env.f_obs, obs_noise_key, (ts, xs))
151 |
152 | activities = sol.ys[:,self.latent_size:]
153 | # Get control actions at every time step
154 |         us = jax.vmap(lambda y, a, tar: tree_evaluator(readout, jnp.concatenate([y, a, jnp.zeros(self.control_size), tar])), in_axes=[0,0,None])(ys, activities, target) # use the mapped per-trajectory target (tar)
155 |
156 | # Compute the fitness of the candidate in this trajectory
157 | fitness = env.fitness_function(xs, us[:,None], target, ts)
158 |
159 | return xs, ys, us, activities, fitness
160 |
161 | def _drift(self, t, x_a, args):
162 | env, state_equation, readout, obs_noise_key, target, tree_evaluator = args
163 | x = x_a[:self.latent_size]
164 | a = x_a[self.latent_size:]
165 |
166 | _, y = env.f_obs(obs_noise_key, (t, x)) #Get observations from system
167 | u = tree_evaluator(readout, jnp.concatenate([jnp.zeros(self.obs_size), a, jnp.zeros(self.control_size), target])) #Compute control action from latent memory
168 |
169 | u = jnp.atleast_1d(u)
170 |
171 | dx = env.drift(t, x, u) #Apply control to system and get system change
172 | da = tree_evaluator(state_equation, jnp.concatenate([y, a, u, target])) #Compute change in latent memory
173 |
174 | return jnp.concatenate([dx, da])
175 |
176 | def _diffusion(self, t, x_a, args):
177 | env, state_equation, readout, obs_noise_key, target, tree_evaluator = args
178 | x = x_a[:self.latent_size]
179 | a = x_a[self.latent_size:]
180 |
181 | return jnp.concatenate([env.diffusion(t, x, jnp.array([0])), jnp.zeros((self.state_size, self.latent_size))]) #Only the system is stochastic
182 |
183 | """
184 | Here we define the hyperparameters, operators and variables, and initialize the strategy. The control policy will consist of two latent variables and a readout layer.
185 | `layer_sizes` allows us to define different types of trees, and `variable_list` contains a separate set of input variables for each type of tree.
186 | The readout layer only receives the latent states and the target, while the inputs to the state equations consist of the observations, latent states, control signal and target.
187 | """
188 |
189 | #Define hyperparameters
190 | population_size = 100
191 | num_populations = 10
192 | num_generations = 50
193 | state_size = 2
194 |
195 | #Define expressions
196 | operator_list = [("+", lambda x, y: x + y, 2, 0.5),
197 | ("-", lambda x, y: x - y, 2, 0.1),
198 | ("*", lambda x, y: x * y, 2, 0.5)]
199 |
200 | variable_list = [["x" + str(i) for i in range(env.n_obs)] + ["a1", "a2", "u", "tar"], ["a1", "a2", "tar"]]
201 |
202 | layer_sizes = jnp.array([state_size, env.n_control_inputs])
203 |
204 | #Define evaluator
205 | fitness_function = Evaluator(env, state_size, 0.05, solver=diffrax.GeneralShARK(), max_steps=1000)
206 |
207 | #Initialize strategy
208 | strategy = GeneticProgramming(num_generations, population_size, fitness_function, operator_list, variable_list, layer_sizes, num_populations = num_populations, size_parsimony=0.003)
209 |
210 | strategy.fit(gp_key, data, verbose=True)
211 |
212 | # Visualize best solution
213 |
214 | #Generate test_data
215 | data = get_data(jr.PRNGKey(10), env, 4, 0.01, T, param_setting)
216 | x0s, ts, targets, noise_keys, params = data
217 |
218 | best_candidate = strategy.pareto_front[1][17]
219 | print(strategy.expression_to_string(best_candidate))
220 |
221 | xs, ys, us, activities, fitness = jax.vmap(fitness_function.evaluate_trajectory, in_axes=[None, 0, None, 0, 0, 0, None])(best_candidate, *data, strategy.tree_evaluator)
222 |
223 | fig, ax = plt.subplots(2, 2, figsize=(10, 5))
224 | ax = ax.ravel()
225 | for i in range(4):
226 | ax[i].plot(ts, xs[i,:,0], label="$x_1$", color = "blue")
227 | ax[i].plot(ts, xs[i,:,1], label="$x_2$", color = "orange")
228 | ax[i].plot(ts, ys[i,:,0], alpha=0.3, color = "blue")
229 | ax[i].plot(ts, ys[i,:,1], alpha=0.3, color = "orange")
230 | ax[i].plot(ts, us[i,:], label="$u_1$", color = "green")
231 | ax[i].hlines(targets[i], ts[0], ts[-1], linestyles='dashed', color = "black")
232 |
233 | plt.legend(loc="best")
234 | plt.show()
--------------------------------------------------------------------------------
/examples/example_scripts/control_with_memory_acrobot.py:
--------------------------------------------------------------------------------
1 | """
2 | Memory-based control policy optimization
3 |
4 | In this more advanced example, we will optimize a symbolic control policy that is extended with a dynamic latent memory (check out this paper for more details: https://arxiv.org/abs/2406.02765).
5 | The memory is defined by a set of differential equations, with one tree per latent unit. The memory is updated at every time step, and the control policy maps the memory to a control
6 | signal via an additional readout tree. This setup is applied to the partially observable acrobot swing-up task, in which the angular velocities are hidden.
7 | """
8 |
9 | # Specify the cores to use for XLA
10 | import os
11 | os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=10'
12 |
13 | import jax
14 | import jax.numpy as jnp
15 | import diffrax
16 | import jax.random as jr
17 | import matplotlib.pyplot as plt
18 | from jax import Array
19 | from typing import Tuple, Callable
20 | import copy
21 |
22 | from kozax.genetic_programming import GeneticProgramming
23 | from kozax.environments.control_environments.acrobot import Acrobot
24 |
25 | """
26 | First, we generate data consisting of initial conditions, noise keys, targets and parameters of the environment.
27 | """
28 |
29 | def get_data(key, env, batch_size, dt, T, param_setting):
30 |     init_key, noise_key, param_key = jr.split(key, 3)
31 |     x0, targets = env.sample_init_states(batch_size, init_key)
32 |     noise_keys = jr.split(noise_key, batch_size)
33 | ts = jnp.arange(0, T, dt)
34 |
35 | params = env.sample_params(batch_size, param_setting, ts, param_key)
36 | return x0, ts, targets, noise_keys, params
37 |
38 | """
39 | The evaluator class parallelizes the simulation of the control loop over the batched data. In each trajectory, a coupled dynamical system is simulated using diffrax, integrating both the
40 | dynamic memory and the dynamic environment. At every time step, the state of the environment is mapped to observations and the latent memory is mapped to a control signal. The time
41 | derivatives of the memory and of the environment state are then computed. Unlike the harmonic oscillator example, the acrobot is simulated as a deterministic ODE with only a drift term. When the simulation is done, the fitness is computed from the control signals and the environment states.
42 | """
43 |
44 | class Evaluator:
45 | def __init__(self, env, state_size: int, dt0: float, solver=diffrax.Euler(), max_steps: int = 16**4, stepsize_controller: diffrax.AbstractStepSizeController = diffrax.ConstantStepSize()) -> None:
46 | """Evaluates dynamic symbolic policies in control tasks.
47 |
48 | Args:
49 | env: Environment on which the candidate is evaluated.
50 | state_size: Dimensionality of the hidden state.
51 | dt0: Initial step size for integration.
52 | solver: Solver used for integration (default: diffrax.Euler()).
53 | max_steps: The maximum number of steps that can be used in integration (default: 16**4).
54 | stepsize_controller: Controller for the stepsize during integration (default: diffrax.ConstantStepSize()).
55 |
56 | Attributes:
57 | env: Environment on which the candidate is evaluated.
58 | max_fitness: Max fitness which is assigned when a trajectory returns an invalid value.
59 | state_size: Dimensionality of the hidden state.
60 | obs_size: Dimensionality of the observations.
61 | control_size: Dimensionality of the control.
62 | latent_size: Dimensionality of the state of the environment.
63 | dt0: Initial step size for integration.
64 | solver: Solver used for integration.
65 | max_steps: The maximum number of steps that can be used in integration.
66 | stepsize_controller: Controller for the stepsize during integration.
67 | """
68 | self.env = env
69 | self.max_fitness = 1e4
70 | self.state_size = state_size
71 | self.obs_size = env.n_obs
72 | self.control_size = env.n_control_inputs
73 | self.latent_size = env.n_var*env.n_dim
74 | self.dt0 = dt0
75 | self.solver = solver
76 | self.max_steps = max_steps
77 | self.stepsize_controller = stepsize_controller
78 |
79 | def __call__(self, candidate: Array, data: Tuple, tree_evaluator: Callable) -> float:
80 | """Evaluates the candidate on a task.
81 |
82 | Args:
83 | candidate: The coefficients of the candidate.
84 | data: The data required to evaluate the candidate.
85 | tree_evaluator: Function for evaluating trees.
86 |
87 | Returns:
88 | Fitness of the candidate.
89 | """
90 | _, _, _, _, fitness = jax.vmap(self.evaluate_trajectory, in_axes=[None, 0, None, 0, 0, 0, None])(candidate, *data, tree_evaluator)
91 |
92 | fitness = jnp.mean(fitness)
93 | return fitness
94 |
95 | def evaluate_trajectory(self, candidate: Array, x0: Array, ts: Array, target: float, noise_key: jr.PRNGKey, params: Tuple, tree_evaluator: Callable) -> Tuple[Array, Array, Array, Array, float]:
96 | """Solves the coupled differential equation of the system and controller.
97 | The differential equation of the system is defined in the environment and the differential equation
98 | of the control is defined by the set of trees.
99 |
100 | Args:
101 | candidate: Candidate with trees for the hidden state and readout.
102 | x0: Initial state of the system.
103 | ts: Time points on which the controller is evaluated.
104 | target: Target position that the system should reach.
105 | noise_key: Key to generate noisy observations.
106 | params: Parameters that define the environment.
107 | tree_evaluator: Function for evaluating trees.
108 |
109 | Returns:
110 | States, observations, control, activities of the hidden state of the candidate and the fitness of the candidate.
111 | """
112 | env = copy.copy(self.env)
113 | env.initialize_parameters(params, ts)
114 |
115 | state_equation = candidate[:self.state_size]
116 | readout = candidate[self.state_size:]
117 |
118 | solver = self.solver
119 | dt0 = self.dt0
120 | saveat = diffrax.SaveAt(ts=ts)
121 |
122 | # Concatenate the initial state of the system with the initial state of the latent memory
123 | _x0 = jnp.concatenate([x0, jnp.zeros(self.state_size)])
124 |
125 | system = diffrax.ODETerm(self._drift)
126 |
127 | # Solve the coupled system of the environment and the controller
128 | sol = diffrax.diffeqsolve(
129 | system, solver, ts[0], ts[-1], dt0, _x0, saveat=saveat, adjoint=diffrax.DirectAdjoint(), max_steps=self.max_steps, event=diffrax.Event(self.env.cond_fn_nan),
130 | args=(env, state_equation, readout, noise_key, target, tree_evaluator), stepsize_controller=self.stepsize_controller, throw=False
131 | )
132 |
133 | xs = sol.ys[:,:self.latent_size]
134 | # Get observations of the state at every time step
135 | _, ys = jax.lax.scan(env.f_obs, noise_key, (ts, xs))
136 |
137 | activities = sol.ys[:,self.latent_size:]
138 | # Get control actions at every time step
139 |         us = jax.vmap(lambda y, a, tar: tree_evaluator(readout, jnp.concatenate([y, a, jnp.zeros(self.control_size), tar])), in_axes=[0,0,None])(ys, activities, target) # use the mapped per-trajectory target (tar)
140 |
141 | # Compute the fitness of the candidate in this trajectory
142 | fitness = env.fitness_function(xs, us[:,None], target, ts)
143 |
144 | return xs, ys, us, activities, fitness
145 |
146 | def _drift(self, t, x_a, args):
147 | env, state_equation, readout, noise_key, target, tree_evaluator = args
148 | x = x_a[:self.latent_size]
149 | a = x_a[self.latent_size:]
150 |
151 | _, y = env.f_obs(noise_key, (t, x)) #Get observations from system
152 | u = tree_evaluator(readout, jnp.concatenate([jnp.zeros(self.obs_size), a, jnp.zeros(self.control_size), target])) #Compute control action from latent memory
153 |
154 | u = jnp.atleast_1d(u)
155 |
156 | dx = env.drift(t, x, u) #Apply control to system and get system change
157 | da = tree_evaluator(state_equation, jnp.concatenate([y, a, u, target])) #Compute change in latent memory
158 |
159 | return jnp.concatenate([dx, da])
160 |
161 | """
162 | Here we define the hyperparameters, operators and variables, and initialize the strategy. The control policy will consist of two latent variables and a readout layer.
163 | `layer_sizes` allows us to define different types of trees, and `variable_list` contains a separate set of input variables for each type of tree.
164 | The readout layer only receives the latent states, while the inputs to the state equations consist of the observations, latent states and control signal.
165 | """
166 |
167 | if __name__ == "__main__":
168 | key = jr.PRNGKey(0)
169 | gp_key, data_key = jr.split(key)
170 | batch_size = 8
171 | T = 40
172 | dt = 0.2
173 | param_setting = "Constant"
174 |
175 |     env = Acrobot(n_obs = 2) # Partially observable acrobot: only the two joint angles are observed
176 |
177 | data = get_data(data_key, env, batch_size, dt, T, param_setting)
178 |
179 | #Define hyperparameters
180 | population_size = 100
181 | num_populations = 10
182 | num_generations = 50
183 | state_size = 2
184 |
185 | #Define expressions
186 | operator_list = [("+", lambda x, y: x + y, 2, 0.5),
187 | ("-", lambda x, y: x - y, 2, 0.1),
188 | ("*", lambda x, y: x * y, 2, 0.5),
189 | ("sin", lambda x: jnp.sin(x), 1, 0.1),
190 | ("cos", lambda x: jnp.cos(x), 1, 0.1)]
191 |
192 | variable_list = [["x" + str(i) for i in range(env.n_obs)] + ["a1", "a2", "u"], ["a1", "a2"]]
193 |
194 | layer_sizes = jnp.array([state_size, env.n_control_inputs])
195 |
196 | #Define evaluator
197 | fitness_function = Evaluator(env, state_size, 0.05, solver=diffrax.Dopri5(), stepsize_controller=diffrax.PIDController(atol=1e-4, rtol=1e-4, dtmin=0.001), max_steps=1000)
198 |
199 | #Initialize strategy
200 | strategy = GeneticProgramming(num_generations, population_size, fitness_function, operator_list, variable_list, layer_sizes, num_populations = num_populations, size_parsimony=0.003)
201 |
202 | strategy.fit(gp_key, data, verbose=True)
203 |
204 | """
205 | ## Visualize best solution
206 | """
207 |
208 | #Generate test_data
209 | data = get_data(jr.PRNGKey(10), env, 1, 0.01, T, param_setting)
210 | x0s, ts, _, _, params = data
211 |
212 | best_candidate = strategy.pareto_front[1][21]
213 |
214 | xs, ys, us, activities, fitness = jax.vmap(fitness_function.evaluate_trajectory, in_axes=[None, 0, None, 0, 0, 0, None])(best_candidate, *data, strategy.tree_evaluator)
215 |
216 | plt.plot(ts, -jnp.cos(xs[0,:,0]), color = f"C{0}", label="first link")
217 | plt.plot(ts, -jnp.cos(xs[0,:,0]) - jnp.cos(xs[0,:,0] + xs[0,:,1]), color = f"C{1}", label="second link")
218 | plt.hlines(1.5, ts[0], ts[-1], linestyles='dashed', color = "black")
219 |
220 | plt.legend(loc="best")
221 | plt.show()
--------------------------------------------------------------------------------
/examples/example_scripts/objective_function_optimization.py:
--------------------------------------------------------------------------------
1 | """
2 | Objective function optimization
3 |
4 | In this example, Kozax is used to evolve a symbolic loss function to train a neural network.
5 | With each candidate loss function, a neural network is trained on the task of binary classification of XOR data points.
6 | """
7 |
8 | # Specify the cores to use for XLA
9 | import os
10 | os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=10'
11 |
12 | import jax
13 | import jax.numpy as jnp
14 | import jax.random as jr
15 | import optax
16 | from typing import Callable, Tuple
17 | from jax import Array
18 |
19 | from kozax.genetic_programming import GeneticProgramming
20 |
21 | """
22 | We define a fitness function class that includes the network initialization, training loop and weight updates.
23 | At every epoch, a new batch of data is sampled, and the fitness is computed as one minus the accuracy of the trained network on a validation set.
24 | """
25 |
26 | class FitnessFunction:
27 | """
28 | A class to define the fitness function for evaluating candidate loss functions.
29 | The fitness is computed as the accuracy of a neural network trained with the candidate loss function
30 | on a binary classification task (XOR data).
31 |
32 | Attributes:
33 | input_dim (int): Dimension of the input data.
34 | hidden_dim (int): Dimension of the hidden layers in the neural network.
35 | output_dim (int): Dimension of the output.
36 | epochs (int): Number of training epochs.
37 | learning_rate (float): Learning rate for the optimizer.
38 | optim (optax.GradientTransformation): Optax optimizer instance.
39 | """
40 | def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, epochs: int, learning_rate: float):
41 | self.input_dim = input_dim
42 | self.hidden_dim = hidden_dim
43 | self.output_dim = output_dim
44 | self.optim = optax.adam(learning_rate)
45 | self.epochs = epochs
46 |
47 |     def __call__(self, candidate: Array, data: Tuple[Array, Array, Array], tree_evaluator: Callable) -> Array:
48 | """
49 | Computes the fitness of a candidate loss function.
50 |
51 | Args:
52 | candidate: The candidate loss function (symbolic tree).
53 | data (tuple): A tuple containing the data keys, test keys, and network keys.
54 | tree_evaluator: A function to evaluate the symbolic tree.
55 |
56 | Returns:
57 | Array: The mean loss (1 - accuracy) on the validation set.
58 | """
59 | data_keys, test_keys, network_keys = data
60 | losses = jax.vmap(self.train, in_axes=[None, 0, 0, 0, None])(candidate, data_keys, test_keys, network_keys, tree_evaluator)
61 | return jnp.mean(losses)
62 |
63 | def get_data(self, key: jr.PRNGKey, n_samples: int = 50) -> Tuple[Array, Array]:
64 | """
65 | Generates XOR data.
66 |
67 | Args:
68 | key (jax.random.PRNGKey): Random key for data generation.
69 | n_samples (int): Number of samples to generate.
70 |
71 | Returns:
72 | tuple: A tuple containing the input data (x) and the target labels (y).
73 | """
74 | x = jr.uniform(key, shape=(n_samples, 2))
75 | y = jnp.logical_xor(x[:,0]>0.5, x[:,1]>0.5)
76 |
77 | return x, y[:,None]
78 |
79 |     def loss_function(self, params: Tuple[Array, Array, Array, Array, Array, Array], x: Array, y: Array, candidate: Array, tree_evaluator: Callable) -> Array:
80 | """
81 | Computes the loss with an evolved loss function for a given set of parameters and data.
82 |
83 | Args:
84 | params (tuple): The parameters of the neural network.
85 | x (Array): The input data.
86 | y (Array): The target labels.
87 | candidate: The candidate loss function (symbolic tree).
88 | tree_evaluator: A function to evaluate the symbolic tree.
89 |
90 | Returns:
91 | Array: The mean loss.
92 | """
93 | pred = self.neural_network(params, x)
94 | return jnp.mean(jax.vmap(tree_evaluator, in_axes=[None, 0])(candidate, jnp.concatenate([pred, y], axis=-1)))
95 |
96 |     def train(self, candidate: Array, data_key: jr.PRNGKey, test_key: jr.PRNGKey, network_key: jr.PRNGKey, tree_evaluator: Callable) -> Array:
97 | """
98 | Trains a neural network with a given candidate loss function.
99 |
100 | Args:
101 | candidate: The candidate loss function (symbolic tree).
102 | data_key (jax.random.PRNGKey): Random key for data generation during training.
103 | test_key (jax.random.PRNGKey): Random key for data generation during testing.
104 | network_key (jax.random.PRNGKey): Random key for initializing the network parameters.
105 | tree_evaluator: A function to evaluate the symbolic tree.
106 |
107 | Returns:
108 | Array: The validation loss (1 - accuracy).
109 | """
110 | params = self.init_network_params(network_key)
111 |
112 | optim_state = self.optim.init(params)
113 |
114 |         def step(i: int, carry: Tuple[Tuple[Array, Array, Array, Array, Array, Array], optax.OptState, jr.PRNGKey]) -> Tuple[Tuple[Array, Array, Array, Array, Array, Array], optax.OptState, jr.PRNGKey]:
115 | params, optim_state, key = carry
116 |
117 | key, _key = jr.split(key)
118 |
119 | x_train, y_train = self.get_data(_key, n_samples=50)
120 |
121 | # Evaluate network parameters and compute gradients
122 | grads = jax.grad(self.loss_function)(params, x_train, y_train, candidate, tree_evaluator)
123 |
124 | # Update parameters
125 | updates, optim_state = self.optim.update(grads, optim_state, params)
126 | params = optax.apply_updates(params, updates)
127 |
128 | return (params, optim_state, key)
129 |
130 | (params, _, _) = jax.lax.fori_loop(0, self.epochs, step, (params, optim_state, data_key))
131 |
132 | # Evaluate parameters on test set
133 | x_test, y_test = self.get_data(test_key, n_samples=500)
134 |
135 | pred = self.neural_network(params, x_test)
136 | return 1 - jnp.mean(y_test==(pred>0.5)) # Return 1 - accuracy
137 |
138 | def neural_network(self, params: Tuple[Array, Array, Array, Array, Array, Array], x: Array) -> Array:
139 | """
140 | Defines the neural network architecture (forward pass).
141 |
142 | Args:
143 | params (tuple): The parameters of the neural network.
144 | x (Array): The input data.
145 |
146 | Returns:
147 | Array: The output of the neural network.
148 | """
149 | w1, b1, w2, b2, w3, b3 = params
150 | hidden = jnp.tanh(jnp.dot(x, w1) + b1)
151 | hidden = jnp.tanh(jnp.dot(hidden, w2) + b2)
152 | output = jnp.dot(hidden, w3) + b3
153 | return jax.nn.sigmoid(output)
154 |
155 | def init_network_params(self, key: jr.PRNGKey) -> Tuple[Array, Array, Array, Array, Array, Array]:
156 | """
157 | Initializes the parameters of the neural network.
158 |
159 | Args:
160 | key (jax.random.PRNGKey): Random key for parameter initialization.
161 |
162 | Returns:
163 | tuple: A tuple containing the initialized weights and biases.
164 | """
165 | key1, key2, key3 = jr.split(key, 3)
166 | w1 = jr.normal(key1, (self.input_dim, self.hidden_dim)) * jnp.sqrt(2.0 / self.input_dim)
167 | b1 = jnp.zeros(self.hidden_dim)
168 | w2 = jr.normal(key2, (self.hidden_dim, self.hidden_dim)) * jnp.sqrt(2.0 / self.hidden_dim)
169 | b2 = jnp.zeros(self.hidden_dim)
170 | w3 = jr.normal(key3, (self.hidden_dim, self.output_dim)) * jnp.sqrt(2.0 / self.hidden_dim)
171 | b3 = jnp.zeros(self.output_dim)
172 | return (w1, b1, w2, b2, w3, b3)
173 |
174 | """
175 | To make sure the optimized loss function generalizes, a batch of neural networks is trained, each with different training data and weight initialization.
176 | For this purpose, batches of keys for network initialization, training data sampling and validation data are generated.
177 | """
178 |
179 | def generate_keys(key, batch_size=4):
180 | key1, key2, key3 = jr.split(key, 3)
181 | return jr.split(key1, batch_size), jr.split(key2, batch_size), jr.split(key3, batch_size)
182 |
183 | """
184 | Here we define the hyperparameters and inputs to the genetic programming algorithm.
185 | The inputs to the trees are the prediction and target value.
186 | """
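"""
For intuition: with these operators, the mean squared error corresponds to the tree (pred - y) * (pred - y),
and, if constant leaves are available, binary cross-entropy corresponds to
0 - (y * log(pred) + (1 - y) * log(1 - pred)), where the guarded log(x + 1e-7) protects against log(0)
when the network saturates. Evolution searches this space for a loss that yields high validation accuracy.
"""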
187 |
188 | if __name__ == "__main__":
189 | key = jr.PRNGKey(0)
190 | data_key, gp_key = jr.split(key)
191 |
192 | population_size = 50
193 | num_populations = 5
194 | num_generations = 25
195 |
196 | operator_list = [("+", lambda x, y: jnp.add(x, y), 2, 0.5),
197 | ("-", lambda x, y: jnp.subtract(x, y), 2, 0.5),
198 | ("*", lambda x, y: jnp.multiply(x, y), 2, 0.5),
199 | ("log", lambda x: jnp.log(x + 1e-7), 1, 0.1),
200 | ]
201 |
202 | variable_list = [["pred", "y"]]
203 |
204 | input_dim = 2
205 | hidden_dim = 16
206 | output_dim = 1
207 |
208 | fitness_function = FitnessFunction(input_dim, hidden_dim, output_dim, learning_rate=0.01, epochs=100)
209 |
210 | strategy = GeneticProgramming(num_generations, population_size, fitness_function, operator_list, variable_list, num_populations = num_populations)
211 |
212 | data_keys, test_keys, network_keys = generate_keys(data_key)
213 |
214 | strategy.fit(gp_key, (data_keys, test_keys, network_keys), verbose=True)
--------------------------------------------------------------------------------
/examples/example_scripts/symbolic_regression_dynamical_system.py:
--------------------------------------------------------------------------------
1 | """
2 | # Symbolic regression of a dynamical system
3 |
4 | In this example, Kozax is applied to recover the state equations of the Lotka-Volterra system. The candidate solutions are integrated as a system of differential equations, after which the
5 | predictions are compared to the true observations to determine a fitness score.
6 | """
7 |
8 | # Specify the cores to use for XLA
9 | import os
10 | os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=10'
11 |
12 | import jax
13 | import diffrax
14 | import jax.numpy as jnp
15 | import jax.random as jr
17 |
18 | from kozax.genetic_programming import GeneticProgramming
19 | from kozax.fitness_functions.ODE_fitness_function import ODEFitnessFunction
20 | from kozax.environments.SR_environments.lotka_volterra import LotkaVolterra
21 |
22 | """
23 | First the data is generated, consisting of initial conditions, time points and the true observations. Kozax provides the Lotka-Volterra environment, which is integrated with Diffrax.
24 | """
25 |
26 | def get_data(key, env, dt, T, batch_size=20):
27 | x0s = env.sample_init_states(batch_size, key)
28 | ts = jnp.arange(0, T, dt)
29 |
30 | def solve(env, ts, x0):
31 | solver = diffrax.Dopri5()
32 | dt0 = 0.001
33 | saveat = diffrax.SaveAt(ts=ts)
34 |
35 | system = diffrax.ODETerm(env.drift)
36 |
37 | # Solve the system given an initial conditions
38 | sol = diffrax.diffeqsolve(system, solver, ts[0], ts[-1], dt0, x0, saveat=saveat, max_steps=500,
39 | adjoint=diffrax.DirectAdjoint(), stepsize_controller=diffrax.PIDController(atol=1e-7, rtol=1e-7, dtmin=0.001))
40 |
41 | return sol.ys
42 |
43 | ys = jax.vmap(solve, in_axes=[None, None, 0])(env, ts, x0s) #Parallelize over the batch dimension
44 |
45 | return x0s, ts, ys
46 |
47 | """
48 | For the fitness function, we use the ODEFitnessFunction, which relies on Diffrax to integrate candidate solutions. It is possible to select the solver, time step, number of steps and a
49 | stepsize controller to balance efficiency and accuracy. To ensure convergence of the genetic programming algorithm, constant optimization is applied to the best candidates at every
50 | generation. The constant optimization is performed with a couple of simple evolutionary steps that adjust the values of the constants in a candidate. The hyperparameters that define the
51 | constant optimization are `constant_optimization_N_offspring`, `constant_optimization_steps`, `optimize_constants_elite` and `constant_step_size_init`.
52 | """
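"""
As a rough illustration (a standalone sketch of the general idea, not Kozax's internal implementation),
evolutionary constant optimization can be pictured as hill climbing on the constants of a fixed expression:
sample perturbed copies of the constants, evaluate them, and keep the best.
"""

def evolve_constants(key, constants, loss_of_constants, n_offspring=25, steps=5, step_size=0.1, decay=0.99):
    # `loss_of_constants` is a hypothetical stand-in for evaluating a candidate with the given constants.
    for _ in range(steps):
        key, sample_key = jr.split(key)
        noise = step_size * jr.normal(sample_key, (n_offspring,) + constants.shape)
        candidates = jnp.concatenate([constants[None], constants[None] + noise])  # keep the parent
        losses = jnp.array([loss_of_constants(c) for c in candidates])
        constants = candidates[jnp.argmin(losses)]  # elitist selection over the perturbed copies
        step_size *= decay  # shrink the search radius, mirroring constant_step_size_decay
    return constants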
53 |
54 | if __name__ == "__main__":
55 | key = jr.PRNGKey(0)
56 | data_key, gp_key = jr.split(key)
57 |
58 | T = 30
59 | dt = 0.2
60 | env = LotkaVolterra()
61 |
62 | # Simulate the data
63 | data = get_data(data_key, env, dt, T, batch_size=4)
64 | x0s, ts, ys = data
65 |
66 | #Define the nodes and hyperparameters
67 | operator_list = [
68 | ("+", lambda x, y: jnp.add(x, y), 2, 0.5),
69 | ("-", lambda x, y: jnp.subtract(x, y), 2, 0.1),
70 | ("*", lambda x, y: jnp.multiply(x, y), 2, 0.5),
71 | ]
72 |
73 | variable_list = [["x" + str(i) for i in range(env.n_var)]]
74 | layer_sizes = jnp.array([env.n_var])
75 |
76 | population_size = 100
77 | num_populations = 10
78 | num_generations = 50
79 |
80 | #Initialize the fitness function and the genetic programming strategy
81 | fitness_function = ODEFitnessFunction(solver=diffrax.Dopri5(), dt0 = 0.01, stepsize_controller=diffrax.PIDController(atol=1e-6, rtol=1e-6, dtmin=0.001), max_steps=300)
82 |
83 | strategy = GeneticProgramming(num_generations, population_size, fitness_function, operator_list, variable_list, layer_sizes, num_populations = num_populations,
84 | size_parsimony=0.003, constant_optimization_method="evolution", constant_optimization_N_offspring = 25, constant_optimization_steps = 5,
85 | optimize_constants_elite=100, constant_step_size_init=0.1, constant_step_size_decay=0.99)
86 | """
87 | Kozax provides a fit function that receives the data and a random key. However, it is also possible to run Kozax with a simple loop that alternates evaluation and evolution.
88 | This is useful because different input data can be provided at each evaluation. In symbolic regression of dynamical systems, it helps to first optimize on a small part of the time points,
89 | and provide the full data trajectories only after a couple of generations.
90 | """
91 |
92 | # Sample the initial population
93 | population = strategy.initialize_population(gp_key)
94 |
95 | # Define the number of timepoints to include in the data
96 | end_ts = int(ts.shape[0]/2)
97 |
98 | for g in range(num_generations):
99 | if g == 25: # After 25 generations, use the full data
100 | end_ts = ts.shape[0]
101 |
102 | key, eval_key, sample_key = jr.split(key, 3)
103 | # Evaluate the population on the data, and return the fitness
104 | fitness, population = strategy.evaluate_population(population, (x0s, ts[:end_ts], ys[:,:end_ts]), eval_key)
105 |
106 | # Print the best solution in the population in this generation
107 | best_fitness, best_solution = strategy.get_statistics(g)
108 | print(f"In generation {g+1}, best fitness = {best_fitness:.4f}, best solution = {strategy.expression_to_string(best_solution)}")
109 |
110 | # Evolve the population until the last generation. The fitness should be given to the evolve function.
111 | if g < (num_generations-1):
112 | population = strategy.evolve_population(population, fitness, sample_key)
113 |
114 | strategy.print_pareto_front()
115 |
116 | """
117 | Instead of using evolution to optimize the constants, Kozax also offers gradient-based optimization. For gradient optimization, it is possible to specify the optimizer, the number of
118 | candidates to apply constant optimization to, the initial learning rate and the learning rate decay over generations. Both methods are provided because either one can be the more
119 | effective or efficient choice, depending on the problem.
120 | """
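"""
Conceptually (again a standalone sketch rather than Kozax's internals), gradient-based constant
optimization differentiates the loss with respect to the constants and takes a few optimizer steps:
"""

import optax

def optimize_constants_with_gradients(constants, loss_of_constants, steps=15, learning_rate=0.025):
    # `loss_of_constants` is a hypothetical differentiable function of the constants of a fixed candidate.
    optimizer = optax.adam(learning_rate)
    opt_state = optimizer.init(constants)
    for _ in range(steps):
        grads = jax.grad(loss_of_constants)(constants)
        updates, opt_state = optimizer.update(grads, opt_state, constants)
        constants = optax.apply_updates(constants, updates)
    return constants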
121 |
122 | import optax
123 |
124 | strategy = GeneticProgramming(num_generations, population_size, fitness_function, operator_list, variable_list, layer_sizes, num_populations = num_populations,
125 | size_parsimony=0.003, constant_optimization_method="gradient", constant_optimization_steps = 15, optimizer_class = optax.adam,
126 | optimize_constants_elite=100, constant_step_size_init=0.025, constant_step_size_decay=0.95)
127 |
128 | key = jr.PRNGKey(0)
129 | data_key, gp_key = jr.split(key)
130 |
131 | T = 30
132 | dt = 0.2
133 | env = LotkaVolterra()
134 |
135 | # Simulate the data
136 | data = get_data(data_key, env, dt, T, batch_size=4)
137 | x0s, ts, ys = data
138 |
139 | # Sample the initial population
140 | population = strategy.initialize_population(gp_key)
141 |
142 | # Define the number of timepoints to include in the data
143 | end_ts = int(ts.shape[0]/2)
144 |
145 | for g in range(num_generations):
146 | if g == 25: # After 25 generations, use the full data
147 | end_ts = ts.shape[0]
148 |
149 | key, eval_key, sample_key = jr.split(key, 3)
150 | # Evaluate the population on the data, and return the fitness
151 | fitness, population = strategy.evaluate_population(population, (x0s, ts[:end_ts], ys[:,:end_ts]), eval_key)
152 |
153 | # Print the best solution in the population in this generation
154 | best_fitness, best_solution = strategy.get_statistics(g)
155 | print(f"In generation {g+1}, best fitness = {best_fitness:.4f}, best solution = {strategy.expression_to_string(best_solution)}")
156 |
157 | # Evolve the population until the last generation. The fitness should be given to the evolve function.
158 | if g < (num_generations-1):
159 | population = strategy.evolve_population(population, fitness, sample_key)
160 |
161 | strategy.print_pareto_front()
--------------------------------------------------------------------------------
/figures/applications.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sdevries0/Kozax/95c96ea580c1c9f2aca15e1d29d4ea6aecc9c527/figures/applications.jpg
--------------------------------------------------------------------------------
/figures/logo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sdevries0/Kozax/95c96ea580c1c9f2aca15e1d29d4ea6aecc9c527/figures/logo.jpg
--------------------------------------------------------------------------------
/kozax/environments/SR_environments/lorenz_attractor.py:
--------------------------------------------------------------------------------
1 | """
2 | kozax: Genetic programming framework in JAX
3 |
4 | Copyright (c) 2024 sdevries0
5 |
6 | This work is licensed under the Creative Commons Attribution-NonCommercial-NoDerivs 4.0 International License.
7 | """
8 |
9 | import jax
10 | import jax.numpy as jnp
11 | import jax.random as jrandom
12 | from kozax.environments.SR_environments.time_series_environment_base import EnvironmentBase
13 | from jaxtyping import Array
14 | from typing import Tuple
15 |
16 | class LorenzAttractor(EnvironmentBase):
17 | """
18 | Lorenz Attractor environment for symbolic regression tasks.
19 |
20 | Parameters
21 | ----------
22 | process_noise : float, optional
23 | Standard deviation of the process noise. Default is 0.
24 |
25 | Attributes
26 | ----------
27 | init_mu : :class:`jax.Array`
28 | Mean of the initial state distribution.
29 | init_sd : float
30 | Standard deviation of the initial state distribution.
31 | sigma : float
32 | Parameter sigma of the Lorenz system.
33 | rho : float
34 | Parameter rho of the Lorenz system.
35 | beta : float
36 | Parameter beta of the Lorenz system.
37 | V : :class:`jax.Array`
38 | Process noise covariance matrix.
39 |
40 | Methods
41 | -------
42 | sample_init_states(batch_size, key)
43 | Samples initial states for the environment.
44 | drift(t, state, args)
45 | Computes the drift function for the environment.
46 | diffusion(t, state, args)
47 | Computes the diffusion function for the environment.
48 | """
49 |
50 | def __init__(self, process_noise: float = 0) -> None:
51 | n_var = 3
52 | super().__init__(n_var, process_noise)
53 |
54 | self.init_mu = jnp.array([1, 1, 1])
55 | self.init_sd = 0.1
56 |
57 | self.sigma = 10
58 | self.rho = 28
59 | self.beta = 8 / 3
60 | self.V = self.process_noise * jnp.eye(self.n_var)
61 |
62 | def sample_init_states(self, batch_size: int, key: jrandom.PRNGKey) -> Array:
63 | """
64 | Samples initial states for the environment.
65 |
66 | Parameters
67 | ----------
68 | batch_size : int
69 | Number of initial states to sample.
70 | key : :class:`jax.random.PRNGKey`
71 | Random key for sampling.
72 |
73 | Returns
74 | -------
75 | :class:`jax.Array`
76 | Initial states.
77 | """
78 | return self.init_mu + self.init_sd * jrandom.normal(key, shape=(batch_size, self.n_var))
79 |
80 | def drift(self, t: float, state: Array, args: Tuple) -> Array:
81 | """
82 | Computes the drift function for the environment.
83 |
84 | Parameters
85 | ----------
86 | t : float
87 | Current time.
88 | state : :class:`jax.Array`
89 | Current state.
90 | args : tuple
91 | Additional arguments.
92 |
93 | Returns
94 | -------
95 | :class:`jax.Array`
96 | Drift.
97 | """
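        # Lorenz system: dx/dt = sigma * (y - x), dy/dt = x * (rho - z) - y, dz/dt = x * y - beta * z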
98 | return jnp.array([self.sigma * (state[1] - state[0]), state[0] * (self.rho - state[2]) - state[1], state[0] * state[1] - self.beta * state[2]])
99 |
100 | def diffusion(self, t: float, state: Array, args: Tuple) -> Array:
101 | """
102 | Computes the diffusion function for the environment.
103 |
104 | Parameters
105 | ----------
106 | t : float
107 | Current time.
108 | state : :class:`jax.Array`
109 | Current state.
110 | args : tuple
111 | Additional arguments.
112 |
113 | Returns
114 | -------
115 | :class:`jax.Array`
116 | Diffusion.
117 | """
118 | return self.V
119 |
120 | class Lorenz96(EnvironmentBase):
121 | """
122 | Lorenz 96 environment for symbolic regression tasks.
123 |
124 | Parameters
125 | ----------
126 | n_dim : int, optional
127 | Number of dimensions. Default is 3.
128 | process_noise : float, optional
129 | Standard deviation of the process noise. Default is 0.
130 |
131 | Attributes
132 | ----------
133 | F : float
134 | Forcing term.
135 | init_mu : :class:`jax.Array`
136 | Mean of the initial state distribution.
137 | init_sd : float
138 | Standard deviation of the initial state distribution.
139 | V : :class:`jax.Array`
140 | Process noise covariance matrix.
141 |
142 | Methods
143 | -------
144 | sample_init_states(batch_size, key)
145 | Samples initial states for the environment.
146 | drift(t, state, args)
147 | Computes the drift function for the environment.
148 | diffusion(t, state, args)
149 | Computes the diffusion function for the environment.
150 | """
151 |
152 | def __init__(self, n_dim: int = 3, process_noise: float = 0) -> None:
153 | n_var = n_dim
154 | super().__init__(n_var, process_noise)
155 |
156 | self.F = 8
157 | self.init_mu = jnp.ones(self.n_var) * self.F
158 | self.init_sd = 0.1
159 |
160 | self.V = self.process_noise * jnp.eye(self.n_var)
161 |
162 | def sample_init_states(self, batch_size: int, key: jrandom.PRNGKey) -> Array:
163 | """
164 | Samples initial states for the environment.
165 |
166 | Parameters
167 | ----------
168 | batch_size : int
169 | Number of initial states to sample.
170 | key : :class:`jax.random.PRNGKey`
171 | Random key for sampling.
172 |
173 | Returns
174 | -------
175 | :class:`jax.Array`
176 | Initial states.
177 | """
178 | return self.init_mu + self.init_sd * jrandom.normal(key, shape=(batch_size, self.n_var))
179 |
180 | def drift(self, t: float, state: Array, args: Tuple) -> Array:
181 | """
182 | Computes the drift function for the environment.
183 |
184 | Parameters
185 | ----------
186 | t : float
187 | Current time.
188 | state : :class:`jax.Array`
189 | Current state.
190 | args : tuple
191 | Additional arguments.
192 |
193 | Returns
194 | -------
195 | :class:`jax.Array`
196 | Drift.
197 | """
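        # Lorenz-96 dynamics: dx_i/dt = (x_{i+1} - x_{i-2}) * x_{i-1} - x_i + F, with cyclic indices via jnp.roll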
198 | f = lambda x_cur, x_next, x_prev1, x_prev2: (x_next - x_prev2) * x_prev1 - x_cur + self.F
199 | return jax.vmap(f)(state, jnp.roll(state, -1), jnp.roll(state, 1), jnp.roll(state, 2))
200 |
201 | def diffusion(self, t: float, state: Array, args: Tuple) -> Array:
202 | """
203 | Computes the diffusion function for the environment.
204 |
205 | Parameters
206 | ----------
207 | t : float
208 | Current time.
209 | state : :class:`jax.Array`
210 | Current state.
211 | args : tuple
212 | Additional arguments.
213 |
214 | Returns
215 | -------
216 | :class:`jax.Array`
217 | Diffusion.
218 | """
219 | return self.V
--------------------------------------------------------------------------------
/kozax/environments/SR_environments/lotka_volterra.py:
--------------------------------------------------------------------------------
1 | """
2 | kozax: Genetic programming framework in JAX
3 |
4 | Copyright (c) 2024 sdevries0
5 |
6 | This work is licensed under the Creative Commons Attribution-NonCommercial-NoDerivs 4.0 International License.
7 | """
8 |
9 | import jax
10 | import jax.numpy as jnp
11 | import jax.random as jrandom
12 | from kozax.environments.SR_environments.time_series_environment_base import EnvironmentBase
13 | from jaxtyping import Array
14 | from typing import Tuple
15 |
16 | class LotkaVolterra(EnvironmentBase):
17 | """
18 | Lotka-Volterra environment for symbolic regression tasks.
19 |
20 | Parameters
21 | ----------
22 | process_noise : float, optional
23 | Standard deviation of the process noise. Default is 0.
24 |
25 | Attributes
26 | ----------
27 | init_mu : :class:`jax.Array`
28 | Mean of the initial state distribution.
29 | init_sd : float
30 | Standard deviation of the initial state distribution.
31 | alpha : float
32 | Parameter alpha of the Lotka-Volterra system.
33 | beta : float
34 | Parameter beta of the Lotka-Volterra system.
35 | delta : float
36 | Parameter delta of the Lotka-Volterra system.
37 | gamma : float
38 | Parameter gamma of the Lotka-Volterra system.
39 | V : :class:`jax.Array`
40 | Process noise covariance matrix.
41 |
42 | Methods
43 | -------
44 | sample_init_states(batch_size, key)
45 | Samples initial states for the environment.
46 | drift(t, state, args)
47 | Computes the drift function for the environment.
48 | diffusion(t, state, args)
49 | Computes the diffusion function for the environment.
50 | """
51 |
52 | def __init__(self, process_noise: float = 0) -> None:
53 | n_var = 2
54 | super().__init__(n_var, process_noise)
55 |
56 | self.init_mu = jnp.array([10, 10])
57 | self.init_sd = 2
58 |
59 | self.alpha = 1.1
60 | self.beta = 0.4
61 | self.delta = 0.1
62 | self.gamma = 0.4
63 | self.V = self.process_noise * jnp.eye(self.n_var)
64 |
65 | def sample_init_states(self, batch_size: int, key: jrandom.PRNGKey) -> Array:
66 | """
67 | Samples initial states for the environment.
68 |
69 | Parameters
70 | ----------
71 | batch_size : int
72 | Number of initial states to sample.
73 | key : :class:`jax.random.PRNGKey`
74 | Random key for sampling.
75 |
76 | Returns
77 | -------
78 | :class:`jax.Array`
79 | Initial states.
80 | """
81 | return jrandom.uniform(key, shape=(batch_size, self.n_var), minval=5, maxval=15)
82 |
83 | def drift(self, t: float, state: Array, args: Tuple) -> Array:
84 | """
85 | Computes the drift function for the environment.
86 |
87 | Parameters
88 | ----------
89 | t : float
90 | Current time.
91 | state : :class:`jax.Array`
92 | Current state.
93 | args : tuple
94 | Additional arguments.
95 |
96 | Returns
97 | -------
98 | :class:`jax.Array`
99 | Drift.
100 | """
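        # Lotka-Volterra: dprey/dt = alpha * prey - beta * prey * predator, dpredator/dt = delta * prey * predator - gamma * predator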
101 | return jnp.array([self.alpha * state[0] - self.beta * state[0] * state[1], self.delta * state[0] * state[1] - self.gamma * state[1]])
102 |
103 | def diffusion(self, t: float, state: Array, args: Tuple) -> Array:
104 | """
105 | Computes the diffusion function for the environment.
106 |
107 | Parameters
108 | ----------
109 | t : float
110 | Current time.
111 | state : :class:`jax.Array`
112 | Current state.
113 | args : tuple
114 | Additional arguments.
115 |
116 | Returns
117 | -------
118 | :class:`jax.Array`
119 | Diffusion.
120 | """
121 | return self.V
--------------------------------------------------------------------------------
/kozax/environments/SR_environments/time_series_environment_base.py:
--------------------------------------------------------------------------------
1 | """
2 | kozax: Genetic programming framework in JAX
3 |
4 | Copyright (c) 2024 sdevries0
5 |
6 | This work is licensed under the Creative Commons Attribution-NonCommercial-NoDerivs 4.0 International License.
7 | """
8 |
9 | from jax.random import PRNGKey
10 | import abc
11 | from typing import Tuple, Any
12 | from jaxtyping import Array
13 |
14 | class EnvironmentBase(abc.ABC):
15 | """
16 | Abstract base class for time series environments in symbolic regression tasks.
17 |
18 | Parameters
19 | ----------
20 | n_var : int
21 | Number of variables in the state.
22 |
23 | Methods
24 | -------
25 | sample_init_states(batch_size, key)
26 | Samples initial states for the environment.
27 | drift(t, state, args)
28 | Computes the drift function for the environment.
29 | diffusion(t, state, args)
30 | Computes the diffusion function for the environment.
31 | terminate_event(state, **kwargs)
32 | Checks if the termination condition is met.
33 | """
34 |
35 | def __init__(self, n_var: int, process_noise: float) -> None:
36 | self.n_var = n_var
37 | self.process_noise = process_noise
38 |
39 | @abc.abstractmethod
40 | def sample_init_states(self, batch_size: int, key: PRNGKey) -> Any:
41 | """
42 | Samples initial states for the environment.
43 |
44 | Parameters
45 | ----------
46 | batch_size : int
47 | Number of initial states to sample.
48 | key : :class:`jax.random.PRNGKey`
49 | Random key for sampling.
50 |
51 | Returns
52 | -------
53 | Any
54 | Initial states.
55 | """
56 | raise NotImplementedError
57 |
58 | @abc.abstractmethod
59 | def drift(self, t: float, state: Array, args: Tuple) -> Array:
60 | """
61 | Computes the drift function for the environment.
62 |
63 | Parameters
64 | ----------
65 | t : float
66 | Current time.
67 | state : :class:`jax.Array`
68 | Current state.
69 | args : tuple
70 | Additional arguments.
71 |
72 | Returns
73 | -------
74 | :class:`jax.Array`
75 | Drift.
76 | """
77 | raise NotImplementedError
78 |
79 | @abc.abstractmethod
80 | def diffusion(self, t: float, state: Array, args: Tuple) -> Array:
81 | """
82 | Computes the diffusion function for the environment.
83 |
84 | Parameters
85 | ----------
86 | t : float
87 | Current time.
88 | state : :class:`jax.Array`
89 | Current state.
90 | args : tuple
91 | Additional arguments.
92 |
93 | Returns
94 | -------
95 | :class:`jax.Array`
96 | Diffusion.
97 | """
98 | raise NotImplementedError
99 |
100 | def terminate_event(self, state: Array, **kwargs) -> bool:
101 | """
102 | Checks if the termination condition is met.
103 |
104 | Parameters
105 | ----------
106 | state : :class:`jax.Array`
107 | Current state.
108 | kwargs : dict
109 | Additional arguments.
110 |
111 | Returns
112 | -------
113 | bool
114 | True if the termination condition is met, False otherwise.
115 | """
116 | return False
--------------------------------------------------------------------------------
/kozax/environments/SR_environments/vd_pol_oscillator.py:
--------------------------------------------------------------------------------
1 | """
2 | kozax: Genetic programming framework in JAX
3 |
4 | Copyright (c) 2024 sdevries0
5 |
6 | This work is licensed under the Creative Commons Attribution-NonCommercial-NoDerivs 4.0 International License.
7 | """
8 |
9 | import jax
10 | import jax.numpy as jnp
11 | import jax.random as jrandom
12 | from kozax.environments.SR_environments.time_series_environment_base import EnvironmentBase
13 | from jaxtyping import Array
14 | from typing import Tuple
15 |
16 | class VanDerPolOscillator(EnvironmentBase):
17 | """
18 | Van der Pol Oscillator environment for symbolic regression tasks.
19 |
20 | Parameters
21 | ----------
22 | process_noise : float, optional
23 | Standard deviation of the process noise. Default is 0.
24 |
25 | Attributes
26 | ----------
27 | init_mu : :class:`jax.Array`
28 | Mean of the initial state distribution.
29 | init_sd : :class:`jax.Array`
30 | Standard deviation of the initial state distribution.
31 | mu : float
32 | Parameter mu of the Van der Pol system.
33 | V : :class:`jax.Array`
34 | Process noise covariance matrix.
35 |
36 | Methods
37 | -------
38 | sample_init_states(batch_size, key)
39 | Samples initial states for the environment.
40 | drift(t, state, args)
41 | Computes the drift function for the environment.
42 | diffusion(t, state, args)
43 | Computes the diffusion function for the environment.
44 | """
45 |
46 | def __init__(self, process_noise: float = 0) -> None:
47 | n_var = 2
48 | super().__init__(n_var, process_noise)
49 |
50 | self.init_mu = jnp.array([0, 0])
51 | self.init_sd = jnp.array([1.0, 1.0])
52 |
53 | self.mu = 1
54 | self.V = self.process_noise * jnp.eye(self.n_var)
55 |
56 | def sample_init_states(self, batch_size: int, key: jrandom.PRNGKey) -> Array:
57 | """
58 | Samples initial states for the environment.
59 |
60 | Parameters
61 | ----------
62 | batch_size : int
63 | Number of initial states to sample.
64 | key : :class:`jax.random.PRNGKey`
65 | Random key for sampling.
66 |
67 | Returns
68 | -------
69 | :class:`jax.Array`
70 | Initial states.
71 | """
72 | return self.init_mu + self.init_sd * jrandom.normal(key, shape=(batch_size, self.n_var))
73 |
74 | def drift(self, t: float, state: Array, args: Tuple) -> Array:
75 | """
76 | Computes the drift function for the environment.
77 |
78 | Parameters
79 | ----------
80 | t : float
81 | Current time.
82 | state : :class:`jax.Array`
83 | Current state.
84 | args : tuple
85 | Additional arguments.
86 |
87 | Returns
88 | -------
89 | :class:`jax.Array`
90 | Drift.
91 | """
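        # Van der Pol oscillator: dx/dt = v, dv/dt = mu * (1 - x^2) * v - x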
92 | return jnp.array([state[1], self.mu * (1 - state[0]**2) * state[1] - state[0]])
93 |
94 | def diffusion(self, t: float, state: Array, args: Tuple) -> Array:
95 | """
96 | Computes the diffusion function for the environment.
97 |
98 | Parameters
99 | ----------
100 | t : float
101 | Current time.
102 | state : :class:`jax.Array`
103 | Current state.
104 | args : tuple
105 | Additional arguments.
106 |
107 | Returns
108 | -------
109 | :class:`jax.Array`
110 | Diffusion.
111 | """
112 | return self.V
--------------------------------------------------------------------------------
/kozax/environments/control_environments/acrobot.py:
--------------------------------------------------------------------------------
1 | """
2 | kozax: Genetic programming framework in JAX
3 |
4 | Copyright (c) 2024 sdevries0
5 |
6 | This work is licensed under the Creative Commons Attribution-NonCommercial-NoDerivs 4.0 International License.
7 | """
8 |
9 | import jax
10 | import jax.numpy as jnp
11 | import jax.random as jrandom
12 | from kozax.environments.control_environments.control_environment_base import EnvironmentBase
13 | from jaxtyping import Array
14 | from typing import Tuple
15 |
16 | class Acrobot(EnvironmentBase):
17 | """
18 | Acrobot environment for control tasks.
19 |
20 | Parameters
21 | ----------
22 | process_noise : float
23 | Standard deviation of the process noise.
24 | obs_noise : float
25 | Standard deviation of the observation noise.
26 | n_obs : int, optional
27 | Number of observations. Default is 4.
28 |
29 | Attributes
30 | ----------
31 | n_var : int
32 | Number of variables in the state.
33 | n_control_inputs : int
34 | Number of control inputs.
35 | n_targets : int
36 | Number of targets.
37 | n_dim : int
38 | Number of dimensions.
39 | init_bounds : :class:`jax.Array`
40 | Bounds for initial state sampling.
41 | R : :class:`jax.Array`
42 | Control cost matrix.
43 |
44 | Methods
45 | -------
46 | sample_init_states(batch_size, key)
47 | Samples initial states for the environment.
48 | sample_params(batch_size, mode, ts, key)
49 | Samples parameters for the environment.
50 | f_obs(key, t_x)
51 | Computes the observation function.
52 | initialize_parameters(params, ts)
53 | Initializes the parameters of the environment.
54 | drift(t, state, args)
55 | Computes the drift function for the environment.
56 | diffusion(t, state, args)
57 | Computes the diffusion function for the environment.
58 | fitness_function(state, control, target, ts)
59 | Computes the fitness function for the environment.
60 | cond_fn_nan(t, y, args, **kwargs)
61 | Checks for NaN or infinite values in the state.
62 | """
63 |
64 | def __init__(self, process_noise: float = 0.0, obs_noise: float = 0.0, n_obs: int = 4) -> None:
65 | self.n_var = 4
66 | self.n_control_inputs = 1
67 | self.n_targets = 0
68 | self.n_dim = 1
69 | self.init_bounds = jnp.array([0.1, 0.1, 0.1, 0.1])
70 | super().__init__(process_noise, obs_noise, self.n_var, self.n_control_inputs, self.n_dim, n_obs)
71 |
72 | self.R = 0.1 * jnp.eye(self.n_control_inputs)
73 |
74 | def sample_init_states(self, batch_size: int, key: jrandom.PRNGKey) -> Tuple[Array, Array]:
75 | """
76 | Samples initial states for the environment.
77 |
78 | Parameters
79 | ----------
80 | batch_size : int
81 | Number of initial states to sample.
82 | key : :class:`jax.random.PRNGKey`
83 | Random key for sampling.
84 |
85 | Returns
86 | -------
87 | x0 : :class:`jax.Array`
88 | Initial states.
89 | targets : :class:`jax.Array`
90 | Target states.
91 | """
92 | init_key, target_key = jrandom.split(key)
93 | x0 = jrandom.uniform(init_key, shape=(batch_size, self.n_var), minval=-self.init_bounds, maxval=self.init_bounds)
94 | targets = jnp.zeros((batch_size, self.n_targets))
95 | return x0, targets
96 |
97 | def sample_params(self, batch_size: int, mode: str, ts: Array, key: jrandom.PRNGKey) -> Tuple[Array, Array, Array, Array]:
98 | """
99 | Samples parameters for the environment.
100 |
101 | Parameters
102 | ----------
103 | batch_size : int
104 | Number of parameters to sample.
105 | mode : str
106 | Mode for sampling parameters.
107 | ts : :class:`jax.Array`
108 | Time steps.
109 | key : :class:`jax.random.PRNGKey`
110 | Random key for sampling.
111 |
112 | Returns
113 | -------
114 | tuple of :class:`jax.Array`
115 | Sampled parameters.
116 | """
117 | l1 = l2 = m1 = m2 = jnp.ones((batch_size))
118 | return l1, l2, m1, m2
119 |
120 | def f_obs(self, key: jrandom.PRNGKey, t_x: Tuple[float, Array]) -> Tuple[jrandom.PRNGKey, Array]:
121 | """
122 | Computes the observation function.
123 |
124 | Parameters
125 | ----------
126 | key : :class:`jax.random.PRNGKey`
127 | Random key for sampling noise.
128 | t_x : tuple of (float, :class:`jax.Array`)
129 | Tuple containing the current time and state.
130 |
131 | Returns
132 | -------
133 | key : :class:`jax.random.PRNGKey`
134 | Updated random key.
135 | out : :class:`jax.Array`
136 | Observation.
137 | """
138 | _, out = super().f_obs(key, t_x)
139 | out = jnp.array([(out[0] + jnp.pi) % (2 * jnp.pi) - jnp.pi, (out[1] + jnp.pi) % (2 * jnp.pi) - jnp.pi, out[2], out[3]])[:self.n_obs]
140 | return key, out
141 |
142 | def initialize_parameters(self, params: Tuple[Array, Array, Array, Array], ts: Array) -> None:
143 | """
144 | Initializes the parameters of the environment.
145 |
146 | Parameters
147 | ----------
148 | params : tuple of :class:`jax.Array`
149 | Parameters to initialize.
150 | ts : :class:`jax.Array`
151 | Time steps.
152 | """
153 | l1, l2, m1, m2 = params
154 | self.l1 = l1 # [m]
155 | self.l2 = l2 # [m]
156 | self.m1 = m1 #: [kg] mass of link 1
157 | self.m2 = m2 #: [kg] mass of link 2
158 | self.lc1 = 0.5 * self.l1 #: [m] position of the center of mass of link 1
159 | self.lc2 = 0.5 * self.l2 #: [m] position of the center of mass of link 2
160 | self.moi1 = self.moi2 = 1.0
161 | self.g = 9.81
162 |
163 | self.G = jnp.array([[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])
164 | self.V = self.process_noise * self.G
165 |
166 | self.C = jnp.eye(self.n_var)[:self.n_obs]
167 | self.W = self.obs_noise * jnp.eye(self.n_obs)
168 |
169 | def drift(self, t: float, state: Array, args: Tuple) -> Array:
170 | """
171 | Computes the drift function for the environment.
172 |
173 | Parameters
174 | ----------
175 | t : float
176 | Current time.
177 | state : :class:`jax.Array`
178 | Current state.
179 | args : tuple
180 | Additional arguments.
181 |
182 | Returns
183 | -------
184 | :class:`jax.Array`
185 | Drift.
186 | """
187 | control = jnp.squeeze(args)
188 | control = jnp.clip(control, -1, 1)
189 | theta1, theta2, theta1_dot, theta2_dot = state
190 |
191 | d1 = self.m1 * self.lc1**2 + self.m2 * (self.l1**2 + self.lc2**2 + 2 * self.l1 * self.lc2 * jnp.cos(theta2)) + self.moi1 + self.moi2
192 | d2 = self.m2 * (self.lc2**2 + self.l1 * self.lc2 * jnp.cos(theta2)) + self.moi2
193 |
194 | phi2 = self.m2 * self.lc2 * self.g * jnp.cos(theta1 + theta2 - jnp.pi/2)
195 |         phi1 = -self.m2 * self.l1 * self.lc2 * theta2_dot**2 * jnp.sin(theta2) - 2 * self.m2 * self.l1 * self.lc2 * theta1_dot * theta2_dot * jnp.sin(theta2) \
196 |             + (self.m1 * self.lc1 + self.m2 * self.l1) * self.g * jnp.cos(theta1 - jnp.pi/2) + phi2
197 |
198 | if self.n_control_inputs == 1:
199 | theta2_acc = (control + d2/d1 * phi1 - self.m2 * self.l1 * self.lc2 * theta1_dot**2 * jnp.sin(theta2) - phi2) \
200 | / (self.m2 * self.lc2**2 + self.moi2 - d2**2 / d1)
201 | theta1_acc = -(d2 * theta2_acc + phi1)/d1
202 | else:
203 | c1, c2 = control
204 | theta2_acc = (c1 + d2/d1 * phi1 - self.m2 * self.l1 * self.lc2 * theta1_dot**2 * jnp.sin(theta2) - phi2) \
205 | / (self.m2 * self.lc2**2 + self.moi2 - d2**2 / d1)
206 | theta1_acc = (c2 - d2 * theta2_acc - phi1)/d1
207 |
208 | return jnp.array([
209 | theta1_dot,
210 | theta2_dot,
211 | theta1_acc,
212 | theta2_acc
213 | ])
214 |
215 | def diffusion(self, t: float, state: Array, args: Tuple) -> Array:
216 | """
217 | Computes the diffusion function for the environment.
218 |
219 | Parameters
220 | ----------
221 | t : float
222 | Current time.
223 | state : :class:`jax.Array`
224 | Current state.
225 | args : tuple
226 | Additional arguments.
227 |
228 | Returns
229 | -------
230 | :class:`jax.Array`
231 | Diffusion.
232 | """
233 | return self.V
234 |
235 | def fitness_function(self, state: Array, control: Array, target: Array, ts: Array) -> float:
236 | """
237 | Computes the fitness function for the environment.
238 |
239 | Parameters
240 | ----------
241 | state : :class:`jax.Array`
242 | Current state.
243 | control : :class:`jax.Array`
244 | Control inputs.
245 | target : :class:`jax.Array`
246 | Target states.
247 | ts : :class:`jax.Array`
248 | Time steps.
249 |
250 | Returns
251 | -------
252 | float
253 | Fitness value.
254 | """
255 | reached_threshold = jax.vmap(lambda theta1, theta2: -jnp.cos(theta1) - jnp.cos(theta1 + theta2) > 1.5)(state[:,0], state[:,1])
256 | first_success = jnp.argmax(reached_threshold)
257 |
258 | control = jnp.clip(control, -1, 1)
259 |
260 | control_cost = jax.vmap(lambda _state, _u: _u @ self.R @ _u)(state, control)
261 |         costs = jnp.where((ts / (ts[1] - ts[0])) > first_success, jnp.zeros_like(control_cost), control_cost)  # zero the control cost after the first success (assumes a uniform time grid)
262 |
263 |         return (first_success + (first_success == 0) * ts.shape[0] + jnp.sum(costs))  # steps until success (full horizon if never reached) plus the accumulated control cost
264 |
265 | def cond_fn_nan(self, t: float, y: Array, args: Tuple, **kwargs) -> float:
266 | """
267 | Checks for NaN or infinite values in the state.
268 |
269 | Parameters
270 | ----------
271 | t : float
272 | Current time.
273 | y : :class:`jax.Array`
274 | Current state.
275 | args : tuple
276 | Additional arguments.
277 | kwargs : dict
278 | Additional keyword arguments.
279 |
280 | Returns
281 | -------
282 | float
283 | -1.0 if NaN or infinite values are found, 1.0 otherwise.
284 | """
285 | return jnp.where((jnp.abs(y[2]) > (4 * jnp.pi)) | (jnp.abs(y[3]) > (9 * jnp.pi)) | jnp.any(jnp.isnan(y)) | jnp.any(jnp.isinf(y)), -1.0, 1.0)
--------------------------------------------------------------------------------
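A minimal usage sketch for the environment above (the class name `Acrobot` is assumed from this file; the batch size, time grid, and torque are illustrative):

import jax.numpy as jnp
import jax.random as jrandom
from kozax.environments.control_environments.acrobot import Acrobot  # class name assumed

key = jrandom.PRNGKey(0)
ts = jnp.linspace(0.0, 10.0, 101)  # illustrative time grid

env = Acrobot(process_noise=0.0, obs_noise=0.0, n_obs=4)
x0, targets = env.sample_init_states(batch_size=4, key=key)  # states near the downward equilibrium
l1, l2, m1, m2 = env.sample_params(4, "Constant", ts, key)   # unit lengths and masses
env.initialize_parameters((l1[0], l2[0], m1[0], m2[0]), ts)  # configure one environment instance

dx = env.drift(0.0, x0[0], jnp.array([0.5]))  # deterministic dynamics under a constant torque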
/kozax/environments/control_environments/cart_pole.py:
--------------------------------------------------------------------------------
1 | """
2 | kozax: Genetic programming framework in JAX
3 |
4 | Copyright (c) 2024 sdevries0
5 |
6 | This work is licensed under the Creative Commons Attribution-NonCommercial-NoDerivs 4.0 International License.
7 | """
8 |
9 | import jax
10 | import jax.numpy as jnp
11 | import jax.random as jrandom
12 | from kozax.environments.control_environments.control_environment_base import EnvironmentBase
13 | from jaxtyping import Array
14 | from typing import Tuple
15 |
16 | class CartPole(EnvironmentBase):
17 | """
18 | CartPole environment for control tasks.
19 |
20 | Parameters
21 | ----------
22 | process_noise : float
23 | Standard deviation of the process noise.
24 | obs_noise : float
25 | Standard deviation of the observation noise.
26 | n_obs : int, optional
27 | Number of observations. Default is 4.
28 |
29 | Attributes
30 | ----------
31 | n_var : int
32 | Number of variables in the state.
33 | n_control_inputs : int
34 | Number of control inputs.
35 | n_targets : int
36 | Number of targets.
37 | n_dim : int
38 | Number of dimensions.
39 | init_bounds : :class:`jax.Array`
40 | Bounds for initial state sampling.
41 |     Q : :class:`jax.Array`
42 |         State cost matrix for the fitness function (unused placeholder in this environment).
43 |     R : :class:`jax.Array`
44 |         Control cost matrix for the fitness function (unused placeholder in this environment).
45 |
46 | Methods
47 | -------
48 | sample_init_states(batch_size, key)
49 | Samples initial states for the environment.
50 | sample_params(batch_size, mode, ts, key)
51 | Samples parameters for the environment.
52 | initialize_parameters(params, ts)
53 | Initializes the parameters of the environment.
54 | drift(t, state, args)
55 | Computes the drift function for the environment.
56 | diffusion(t, state, args)
57 | Computes the diffusion function for the environment.
58 | fitness_function(state, control, target, ts)
59 | Computes the fitness function for the environment.
60 | terminate_event(state, **kwargs)
61 | Checks if the termination condition is met.
62 | """
63 |
64 | def __init__(self, process_noise: float = 0.0, obs_noise: float = 0.0, n_obs: int = 4) -> None:
65 | self.n_var = 4
66 | self.n_control_inputs = 1
67 | self.n_targets = 0
68 | self.n_dim = 1
69 | self.init_bounds = jnp.array([0.05, 0.05, 0.05, 0.05])
70 | super().__init__(process_noise, obs_noise, self.n_var, self.n_control_inputs, self.n_dim, n_obs)
71 |
72 | self.Q = jnp.array(0)
73 | self.R = jnp.array([[0.0]])
74 |
75 | def sample_init_states(self, batch_size: int, key: jrandom.PRNGKey) -> Tuple[Array, Array]:
76 | """
77 | Samples initial states for the environment.
78 |
79 | Parameters
80 | ----------
81 | batch_size : int
82 | Number of initial states to sample.
83 | key : :class:`jax.random.PRNGKey`
84 | Random key for sampling.
85 |
86 | Returns
87 | -------
88 | x0 : :class:`jax.Array`
89 | Initial states.
90 | targets : :class:`jax.Array`
91 | Target states.
92 | """
93 | init_key, target_key = jrandom.split(key)
94 | x0 = jrandom.uniform(init_key, shape=(batch_size, self.n_var), minval=-self.init_bounds, maxval=self.init_bounds)
95 | targets = jnp.zeros((batch_size, self.n_targets))
96 | return x0, targets
97 |
98 | def sample_params(self, batch_size: int, mode: str, ts: Array, key: jrandom.PRNGKey) -> Array:
99 | """
100 | Samples parameters for the environment.
101 |
102 | Parameters
103 | ----------
104 | batch_size : int
105 | Number of parameters to sample.
106 | mode : str
107 | Mode for sampling parameters.
108 | ts : :class:`jax.Array`
109 | Time steps.
110 | key : :class:`jax.random.PRNGKey`
111 | Random key for sampling.
112 |
113 | Returns
114 | -------
115 | :class:`jax.Array`
116 | Sampled parameters.
117 | """
118 | return jnp.zeros((batch_size))
119 |
120 | def initialize_parameters(self, params: Array, ts: Array) -> None:
121 | """
122 | Initializes the parameters of the environment.
123 |
124 | Parameters
125 | ----------
126 | params : :class:`jax.Array`
127 | Parameters to initialize.
128 | ts : :class:`jax.Array`
129 | Time steps.
130 | """
131 | _ = params
132 | self.g = 9.81
133 | self.pole_mass = 0.1
134 | self.pole_length = 0.5
135 | self.cart_mass = 1
136 |
137 | self.G = jnp.array([[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 0], [0, 0, 0, 0]])
138 | self.V = self.process_noise * self.G
139 |
140 | self.C = jnp.eye(self.n_var)[:self.n_obs]
141 | self.W = self.obs_noise * jnp.eye(self.n_obs)
142 |
143 | def drift(self, t: float, state: Array, args: Tuple) -> Array:
144 | """
145 | Computes the drift function for the environment.
146 |
147 | Parameters
148 | ----------
149 | t : float
150 | Current time.
151 | state : :class:`jax.Array`
152 | Current state.
153 | args : tuple
154 | Additional arguments.
155 |
156 | Returns
157 | -------
158 | :class:`jax.Array`
159 | Drift.
160 | """
161 | control = jnp.squeeze(args)
162 | control = jnp.clip(control, -1, 1)
163 | x, theta, x_dot, theta_dot = state
164 |
165 | cos_theta = jnp.cos(theta)
166 | sin_theta = jnp.sin(theta)
167 |
168 | theta_acc = (self.g * sin_theta - cos_theta * (
169 | control + self.pole_mass * self.pole_length * theta_dot**2 * sin_theta
170 | ) / (self.cart_mass + self.pole_mass)) / (
171 | self.pole_length * (4/3 - (self.pole_mass * cos_theta**2) / (self.cart_mass + self.pole_mass)))
172 |
173 | x_acc = (control + self.pole_mass * self.pole_length * (theta_dot**2 * sin_theta - theta_acc * cos_theta)) / (self.cart_mass + self.pole_mass)
174 |
175 | return jnp.array([
176 | x_dot,
177 | theta_dot,
178 | x_acc,
179 | theta_acc
180 | ])
181 |
182 | def diffusion(self, t: float, state: Array, args: Tuple) -> Array:
183 | """
184 | Computes the diffusion function for the environment.
185 |
186 | Parameters
187 | ----------
188 | t : float
189 | Current time.
190 | state : :class:`jax.Array`
191 | Current state.
192 | args : tuple
193 | Additional arguments.
194 |
195 | Returns
196 | -------
197 | :class:`jax.Array`
198 | Diffusion.
199 | """
200 | return self.V
201 |
202 | def fitness_function(self, state: Array, control: Array, target: Array, ts: Array) -> float:
203 | """
204 | Computes the fitness function for the environment.
205 |
206 | Parameters
207 | ----------
208 | state : :class:`jax.Array`
209 | Current state.
210 | control : :class:`jax.Array`
211 | Control inputs.
212 | target : :class:`jax.Array`
213 | Target states.
214 | ts : :class:`jax.Array`
215 | Time steps.
216 |
217 | Returns
218 | -------
219 | float
220 | Fitness value.
221 | """
222 |         invalid_points = jax.vmap(lambda _x, _u: jnp.any(jnp.isinf(_x)) | jnp.isnan(_u))(state, control[:, 0])  # flag time points where the state blew up or the control is NaN
223 | punishment = jnp.ones_like(invalid_points)
224 |
225 | costs = jnp.where(invalid_points, punishment, jnp.zeros_like(punishment))
226 |
227 | return jnp.sum(costs)
228 |
229 | def terminate_event(self, state: Array, **kwargs) -> bool:
230 | """
231 | Checks if the termination condition is met.
232 |
233 | Parameters
234 | ----------
235 | state : :class:`jax.Array`
236 | Current state.
237 | kwargs : dict
238 | Additional arguments.
239 |
240 | Returns
241 | -------
242 | bool
243 | True if the termination condition is met, False otherwise.
244 | """
245 | return (jnp.abs(state[1]) > 0.2) | (jnp.abs(state[0]) > 4.8) | jnp.any(jnp.isnan(state)) | jnp.any(jnp.isinf(state))
--------------------------------------------------------------------------------
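For reference, the drift above appears to follow the classic cart-pole equations of motion (Barto, Sutton & Anderson, 1983), with the control force $F$ clipped to $[-1, 1]$:

$$\ddot{\theta} = \frac{g\sin\theta - \cos\theta\,\frac{F + m_p \ell \dot{\theta}^2 \sin\theta}{m_c + m_p}}{\ell\left(\tfrac{4}{3} - \frac{m_p \cos^2\theta}{m_c + m_p}\right)}, \qquad \ddot{x} = \frac{F + m_p \ell\left(\dot{\theta}^2 \sin\theta - \ddot{\theta}\cos\theta\right)}{m_c + m_p},$$

where $m_c$ is the cart mass, $m_p$ the pole mass, and $\ell$ the pole length parameter (0.5 here, corresponding to the half-length in the classic formulation).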
/kozax/environments/control_environments/control_environment_base.py:
--------------------------------------------------------------------------------
1 | """
2 | kozax: Genetic programming framework in JAX
3 |
4 | Copyright (c) 2024 sdevries0
5 |
6 | This work is licensed under the Creative Commons Attribution-NonCommercial-NoDerivs 4.0 International License.
7 | """
8 |
9 | import jax
10 | import jax.numpy as jnp
11 | import jax.random as jrandom
12 | import abc
13 |
14 | _itemsize_kind_type = {
15 | (1, "i"): jnp.int8,
16 | (2, "i"): jnp.int16,
17 | (4, "i"): jnp.int32,
18 | (8, "i"): jnp.int64,
19 | (2, "f"): jnp.float16,
20 | (4, "f"): jnp.float32,
21 | (8, "f"): jnp.float64,
22 | }
23 |
24 | def force_bitcast_convert_type(val, new_type=jnp.int32):  # reinterpret the bits of val (e.g. a float time) as an integer
25 | val = jnp.asarray(val)
26 | intermediate_type = _itemsize_kind_type[new_type.dtype.itemsize, val.dtype.kind]
27 | val = val.astype(intermediate_type)
28 | return jax.lax.bitcast_convert_type(val, new_type)
29 |
30 | class EnvironmentBase(abc.ABC):
31 | def __init__(self, process_noise, obs_noise, n_var, n_control_inputs, n_dim, n_obs):
32 | self.process_noise = process_noise
33 | self.obs_noise = obs_noise
34 | self.n_var = n_var
35 | self.n_control_inputs = n_control_inputs
36 | self.n_dim = n_dim
37 | self.n_obs = n_obs
38 |
39 | @abc.abstractmethod
40 | def initialize_parameters(self, params, ts):
41 | raise NotImplementedError
42 |
43 | @abc.abstractmethod
44 | def sample_init_states(self, batch_size, key):
45 | raise NotImplementedError
46 |
47 | @abc.abstractmethod
48 | def sample_params(self, batch_size, mode, ts, key):
49 | raise NotImplementedError
50 |
51 | def f_obs(self, key, t_x):
52 | t, x = t_x
53 |         new_key = jrandom.fold_in(key, force_bitcast_convert_type(t))  # fold the bitcast time into the key for deterministic per-time noise
54 | out = self.C@x + jrandom.normal(new_key, shape=(self.n_obs,))@self.W
55 | return key, out
56 |
57 | @abc.abstractmethod
58 | def drift(self, t, state, args):
59 | raise NotImplementedError
60 |
61 | @abc.abstractmethod
62 | def diffusion(self, t, state, args):
63 | raise NotImplementedError
64 |
65 | @abc.abstractmethod
66 | def fitness_function(self, state, control, target, ts):
67 | raise NotImplementedError
68 |
69 | def terminate_event(self, state, **kwargs):
70 | return False
--------------------------------------------------------------------------------
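A minimal subclass sketch showing the contract the abstract base class imposes; the 1-D integrator dynamics are illustrative and not part of kozax:

import jax.numpy as jnp
import jax.random as jrandom
from kozax.environments.control_environments.control_environment_base import EnvironmentBase

class Integrator(EnvironmentBase):
    """Toy environment: a 1-D integrator driven by the clipped control input."""

    def __init__(self, process_noise=0.0, obs_noise=0.0):
        self.n_targets = 0
        super().__init__(process_noise, obs_noise, n_var=1, n_control_inputs=1, n_dim=1, n_obs=1)

    def sample_init_states(self, batch_size, key):
        x0 = jrandom.uniform(key, (batch_size, self.n_var), minval=-1.0, maxval=1.0)
        return x0, jnp.zeros((batch_size, self.n_targets))

    def sample_params(self, batch_size, mode, ts, key):
        return jnp.ones(batch_size)

    def initialize_parameters(self, params, ts):
        # C and W are required by the inherited f_obs; V is returned by diffusion
        self.C = jnp.eye(self.n_var)[:self.n_obs]
        self.W = self.obs_noise * jnp.eye(self.n_obs)
        self.V = self.process_noise * jnp.eye(self.n_var)

    def drift(self, t, state, args):
        return jnp.clip(jnp.squeeze(args), -1.0, 1.0) * jnp.ones_like(state)

    def diffusion(self, t, state, args):
        return self.V

    def fitness_function(self, state, control, target, ts):
        return jnp.sum(state ** 2)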
/kozax/environments/control_environments/harmonic_oscillator.py:
--------------------------------------------------------------------------------
1 | """
2 | kozax: Genetic programming framework in JAX
3 |
4 | Copyright (c) 2024 sdevries0
5 |
6 | This work is licensed under the Creative Commons Attribution-NonCommercial-NoDerivs 4.0 International License.
7 | """
8 |
9 | import jax
10 | import jax.numpy as jnp
11 | import jax.random as jrandom
12 | from kozax.environments.control_environments.control_environment_base import EnvironmentBase
13 | from jaxtyping import Array
14 | from typing import Tuple
15 |
16 | class HarmonicOscillator(EnvironmentBase):
17 | """
18 | Harmonic Oscillator environment for control tasks.
19 |
20 | Parameters
21 | ----------
22 | process_noise : float
23 | Standard deviation of the process noise.
24 | obs_noise : float
25 | Standard deviation of the observation noise.
26 | n_obs : int, optional
27 | Number of observations. Default is 2.
28 |
29 | Attributes
30 | ----------
31 | n_dim : int
32 | Number of dimensions.
33 | n_var : int
34 | Number of variables in the state.
35 | n_control_inputs : int
36 | Number of control inputs.
37 | n_targets : int
38 | Number of targets.
39 | mu0 : :class:`jax.Array`
40 | Mean of the initial state distribution.
41 | P0 : :class:`jax.Array`
42 |         Scale matrix (square root of the covariance) of the initial state distribution.
43 |     q : float
44 |         State cost weight.
45 |     r : float
46 |         Control cost weight.
47 |     Q : :class:`jax.Array`
48 |         State cost matrix used in the fitness function.
49 |     R : :class:`jax.Array`
50 |         Control cost matrix used in the fitness function.
51 |
52 | Methods
53 | -------
54 | sample_init_states(batch_size, key)
55 | Samples initial states for the environment.
56 | sample_params(batch_size, mode, ts, key)
57 | Samples parameters for the environment.
58 | initialize_parameters(params, ts)
59 | Initializes the parameters of the environment.
60 | drift(t, state, args)
61 | Computes the drift function for the environment.
62 | diffusion(t, state, args)
63 | Computes the diffusion function for the environment.
64 | fitness_function(state, control, target, ts)
65 | Computes the fitness function for the environment.
66 | cond_fn_nan(t, y, args, **kwargs)
67 | Checks for NaN or infinite values in the state.
68 | """
69 |
70 | def __init__(self, process_noise: float = 0.0, obs_noise: float = 0.0, n_obs: int = 2) -> None:
71 | self.n_dim = 1
72 | self.n_var = 2
73 | self.n_control_inputs = 1
74 | self.n_targets = 1
75 | self.mu0 = jnp.zeros(self.n_var)
76 | self.P0 = jnp.eye(self.n_var) * jnp.array([3.0, 1.0])
77 | super().__init__(process_noise, obs_noise, self.n_var, self.n_control_inputs, self.n_dim, n_obs)
78 |
79 | self.q = self.r = 0.5
80 | self.Q = jnp.array([[self.q, 0], [0, 0]])
81 | self.R = jnp.array([[self.r]])
82 |
83 | def sample_init_states(self, batch_size: int, key: jrandom.PRNGKey) -> Tuple[Array, Array]:
84 | """
85 | Samples initial states for the environment.
86 |
87 | Parameters
88 | ----------
89 | batch_size : int
90 | Number of initial states to sample.
91 | key : :class:`jax.random.PRNGKey`
92 | Random key for sampling.
93 |
94 | Returns
95 | -------
96 | x0 : :class:`jax.Array`
97 | Initial states.
98 | targets : :class:`jax.Array`
99 | Target states.
100 | """
101 | init_key, target_key = jrandom.split(key)
102 | x0 = self.mu0 + jrandom.normal(init_key, shape=(batch_size, self.n_var)) @ self.P0
103 | targets = jrandom.uniform(target_key, shape=(batch_size, self.n_targets), minval=-3, maxval=3)
104 | return x0, targets
105 |
106 | def sample_params(self, batch_size: int, mode: str, ts: Array, key: jrandom.PRNGKey) -> Tuple[Array, Array]:
107 | """
108 | Samples parameters for the environment.
109 |
110 | Parameters
111 | ----------
112 | batch_size : int
113 | Number of parameters to sample.
114 | mode : str
115 | Mode for sampling parameters. Options are "Constant", "Different", "Changing".
116 | ts : :class:`jax.Array`
117 | Time steps.
118 | key : :class:`jax.random.PRNGKey`
119 | Random key for sampling.
120 |
121 | Returns
122 | -------
123 | omegas : :class:`jax.Array`
124 | Sampled omega parameters.
125 | zetas : :class:`jax.Array`
126 | Sampled zeta parameters.
127 | """
128 | omega_key, zeta_key, args_key = jrandom.split(key, 3)
129 | if mode == "Constant":
130 | omegas = jnp.ones((batch_size))
131 | zetas = jnp.zeros((batch_size))
132 | elif mode == "Different":
133 | omegas = jrandom.uniform(omega_key, shape=(batch_size,), minval=0.0, maxval=2.0)
134 | zetas = jrandom.uniform(zeta_key, shape=(batch_size,), minval=0.0, maxval=1.5)
135 | elif mode == "Changing":
136 | decay_factors = jrandom.uniform(args_key, shape=(batch_size, 2), minval=0.98, maxval=1.02)
137 | init_omegas = jrandom.uniform(omega_key, shape=(batch_size,), minval=0.5, maxval=1.5)
138 | init_zetas = jrandom.uniform(zeta_key, shape=(batch_size,), minval=0.0, maxval=1.0)
139 | omegas = jax.vmap(lambda o, d, t: o * (d ** t), in_axes=[0, 0, None])(init_omegas, decay_factors[:, 0], ts)
140 | zetas = jax.vmap(lambda z, d, t: z * (d ** t), in_axes=[0, 0, None])(init_zetas, decay_factors[:, 1], ts)
141 | return omegas, zetas
142 |
143 | def initialize_parameters(self, params: Tuple[Array, Array], ts: Array) -> None:
144 | """
145 | Initializes the parameters of the environment.
146 |
147 | Parameters
148 | ----------
149 | params : tuple of :class:`jax.Array`
150 | Parameters to initialize.
151 | ts : :class:`jax.Array`
152 | Time steps.
153 | """
154 | omega, zeta = params
155 | self.A = jnp.array([[0, 1], [-omega, -zeta]])
156 | self.b = jnp.array([[0.0, 1.0]]).T
157 | self.G = jnp.array([[0, 0], [0, 1]])
158 | self.V = self.process_noise * self.G
159 | self.C = jnp.eye(self.n_var)[:self.n_obs]
160 | self.W = self.obs_noise * jnp.eye(self.n_obs)
161 |
162 | def drift(self, t: float, state: Array, args: Tuple) -> Array:
163 | """
164 | Computes the drift function for the environment.
165 |
166 | Parameters
167 | ----------
168 | t : float
169 | Current time.
170 | state : :class:`jax.Array`
171 | Current state.
172 | args : tuple
173 | Additional arguments.
174 |
175 | Returns
176 | -------
177 | :class:`jax.Array`
178 | Drift.
179 | """
180 | return self.A @ state + self.b @ args
181 |
182 | def diffusion(self, t: float, state: Array, args: Tuple) -> Array:
183 | """
184 | Computes the diffusion function for the environment.
185 |
186 | Parameters
187 | ----------
188 | t : float
189 | Current time.
190 | state : :class:`jax.Array`
191 | Current state.
192 | args : tuple
193 | Additional arguments.
194 |
195 | Returns
196 | -------
197 | :class:`jax.Array`
198 | Diffusion.
199 | """
200 | return self.V
201 |
202 | def fitness_function(self, state: Array, control: Array, target: Array, ts: Array) -> float:
203 | """
204 | Computes the fitness function for the environment.
205 |
206 | Parameters
207 | ----------
208 | state : :class:`jax.Array`
209 | Current state.
210 | control : :class:`jax.Array`
211 | Control inputs.
212 | target : :class:`jax.Array`
213 | Target states.
214 | ts : :class:`jax.Array`
215 | Time steps.
216 |
217 | Returns
218 | -------
219 | float
220 | Fitness value.
221 | """
222 | x_d = jnp.array([jnp.squeeze(target), 0])
223 | u_d = -jnp.linalg.pinv(self.b) @ self.A @ x_d
224 | costs = jax.vmap(lambda _state, _u: (_state - x_d).T @ self.Q @ (_state - x_d) + (_u - u_d) @ self.R @ (_u - u_d))(state, control)
225 | return jnp.sum(costs)
226 |
227 | def cond_fn_nan(self, t: float, y: Array, args: Tuple, **kwargs) -> float:
228 | """
229 | Checks for NaN or infinite values in the state.
230 |
231 | Parameters
232 | ----------
233 | t : float
234 | Current time.
235 | y : :class:`jax.Array`
236 | Current state.
237 | args : tuple
238 | Additional arguments.
239 | kwargs : dict
240 | Additional keyword arguments.
241 |
242 | Returns
243 | -------
244 | float
245 | -1.0 if NaN or infinite values are found, 1.0 otherwise.
246 | """
247 |         return jnp.where(jnp.any(jnp.isinf(y) | jnp.isnan(y)), -1.0, 1.0)
--------------------------------------------------------------------------------
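A short sketch of the parameter-sampling modes above; as the code shows, "Constant" and "Different" return one value per sample while "Changing" returns a geometrically decaying value per time step:

import jax.numpy as jnp
import jax.random as jrandom
from kozax.environments.control_environments.harmonic_oscillator import HarmonicOscillator

env = HarmonicOscillator()
key = jrandom.PRNGKey(0)
ts = jnp.linspace(0.0, 5.0, 51)

omegas, zetas = env.sample_params(8, "Constant", ts, key)  # omega = 1, zeta = 0
print(omegas.shape)                                        # (8,)

omegas, zetas = env.sample_params(8, "Changing", ts, key)  # omega_0 * d**t at every time step
print(omegas.shape)                                        # (8, 51)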
/kozax/environments/control_environments/reactor.py:
--------------------------------------------------------------------------------
1 | """
2 | kozax: Genetic programming framework in JAX
3 |
4 | Copyright (c) 2024 sdevries0
5 |
6 | This work is licensed under the Creative Commons Attribution-NonCommercial-NoDerivs 4.0 International License.
7 | """
8 |
9 | import jax.numpy as jnp
10 | import jax
11 | import jax.random as jrandom
12 | import diffrax
13 | from kozax.environments.control_environments.control_environment_base import EnvironmentBase
14 | from jaxtyping import Array
15 | from typing import Tuple
16 |
17 | class StirredTankReactor(EnvironmentBase):
18 | """
19 | Stirred Tank Reactor environment for control tasks.
20 |
21 | Parameters
22 | ----------
23 |     process_noise : :class:`jax.Array`
24 |         Start and end scale of the process noise, interpolated linearly over the time horizon.
25 | obs_noise : float
26 | Standard deviation of the observation noise.
27 | n_obs : int, optional
28 | Number of observations. Default is 3.
29 | n_targets : int, optional
30 | Number of targets. Default is 1.
31 | max_control : :class:`jax.Array`, optional
32 | Maximum control values. Default is jnp.array([300]).
33 | external_f : callable, optional
34 | External influence function. Default is lambda t: 0.0.
35 |
36 | Attributes
37 | ----------
38 | n_var : int
39 | Number of variables in the state.
40 | n_control_inputs : int
41 | Number of control inputs.
42 | n_dim : int
43 | Number of dimensions.
44 | n_targets : int
45 | Number of targets.
46 | init_lower_bounds : :class:`jax.Array`
47 | Lower bounds for initial state sampling.
48 | init_upper_bounds : :class:`jax.Array`
49 | Upper bounds for initial state sampling.
50 | max_control : :class:`jax.Array`
51 | Maximum control values.
52 |     Q : :class:`jax.Array`
53 |         State cost matrix used in the fitness function.
54 |     r : :class:`jax.Array`
55 |         Control cost matrix used in the fitness function.
56 | external_f : callable
57 | External influence function.
58 |
59 | Methods
60 | -------
61 | initialize_parameters(params, ts)
62 | Initializes the parameters of the environment.
63 | sample_param_change(key, batch_size, ts, low, high)
64 | Samples parameter changes over time.
65 | sample_params(batch_size, mode, ts, key)
66 | Samples parameters for the environment.
67 | sample_init_states(batch_size, key)
68 | Samples initial states for the environment.
69 | drift(t, state, args)
70 | Computes the drift function for the environment.
71 | diffusion(t, state, args)
72 | Computes the diffusion function for the environment.
73 | fitness_function(state, control, targets, ts)
74 | Computes the fitness function for the environment.
75 | cond_fn_nan(t, y, args, **kwargs)
76 | Checks for NaN or infinite values in the state.
77 | """
78 |
79 |     def __init__(self, process_noise: Array = jnp.array([0.0, 0.0]), obs_noise: float = 0.0, n_obs: int = 3, n_targets: int = 1, max_control: Array = jnp.array([300]), external_f: callable = lambda t: 0.0) -> None:
80 | self.process_noise = process_noise
81 | self.obs_noise = obs_noise
82 | self.n_var = 3
83 | self.n_control_inputs = 1
84 | self.n_dim = 1
85 | self.n_targets = n_targets
86 | self.init_lower_bounds = jnp.array([275, 350, 0.5])
87 | self.init_upper_bounds = jnp.array([300, 375, 1.0])
88 | self.max_control = max_control
89 | super().__init__(process_noise, obs_noise, self.n_var, self.n_control_inputs, self.n_dim, n_obs)
90 |
91 | self.Q = jnp.array([[0, 0, 0], [0, 0.01, 0], [0, 0, 0]])
92 | self.r = jnp.array([[0.0001]])
93 | self.external_f = external_f
94 |
95 | def initialize_parameters(self, params: Tuple[Array, Array, Array, Array, Array, Array, Array, Array], ts: Array) -> None:
96 | """
97 | Initializes the parameters of the environment.
98 |
99 | Parameters
100 | ----------
101 | params : tuple of :class:`jax.Array`
102 | Parameters to initialize.
103 | ts : :class:`jax.Array`
104 | Time steps.
105 | """
106 | Vol, Cp, dHr, UA, q, Tf, Tcf, Volc = params
107 | self.Ea = 72750 # activation energy J/gmol
108 | self.R = 8.314 # gas constant J/gmol/K
109 | self.k0 = 7.2e10 # Arrhenius rate constant 1/min
110 | self.Vol = Vol # Volume [L]
111 | self.Cp = Cp # Heat capacity [J/g/K]
112 | self.dHr = dHr # Enthalpy of reaction [J/mol]
113 | self.UA = UA # Heat transfer [J/min/K]
114 | self.q = q # Flowrate [L/min]
115 | self.Cf = 1.0 # Inlet feed concentration [mol/L]
116 | self.Tf = diffrax.LinearInterpolation(ts, Tf) # Inlet feed temperature [K]
117 | self.Tcf = Tcf # Coolant feed temperature [K]
118 | self.Volc = Volc # Cooling jacket volume
119 |
120 | self.k = lambda T: self.k0 * jnp.exp(-self.Ea / self.R / T)
121 |
122 | self.G = jnp.eye(self.n_var) * jnp.array([6, 6, 0.05])
123 | self.process_noise_ts = diffrax.LinearInterpolation(ts, jnp.linspace(self.process_noise[0], self.process_noise[1], ts.shape[0]))
124 |
125 | self.C = jnp.eye(self.n_var)[:self.n_obs]
126 | self.W = self.obs_noise * jnp.eye(self.n_obs) * (jnp.array([15, 15, 0.1])[:self.n_obs])
127 |
128 | self.max_control_ts = diffrax.LinearInterpolation(ts, jnp.hstack([mc * jnp.ones(int(ts.shape[0] // self.max_control.shape[0])) for mc in self.max_control]))
129 | self.external_influence = diffrax.LinearInterpolation(ts, jax.vmap(self.external_f)(ts))
130 |
131 | def sample_param_change(self, key: jrandom.PRNGKey, batch_size: int, ts: Array, low: float, high: float) -> Array:
132 | """
133 | Samples parameter changes over time.
134 |
135 | Parameters
136 | ----------
137 | key : :class:`jax.random.PRNGKey`
138 | Random key for sampling.
139 | batch_size : int
140 | Number of samples.
141 | ts : :class:`jax.Array`
142 | Time steps.
143 | low : float
144 | Lower bound for sampling.
145 | high : float
146 | Upper bound for sampling.
147 |
148 | Returns
149 | -------
150 | :class:`jax.Array`
151 | Sampled parameter values.
152 | """
153 | init_key, decay_key = jrandom.split(key)
154 | decay_factors = jrandom.uniform(decay_key, shape=(batch_size,), minval=1.01, maxval=1.02)
155 | init_values = jrandom.uniform(init_key, shape=(batch_size,), minval=low, maxval=high)
156 | values = jax.vmap(lambda v, d, t: v * (d ** t), in_axes=[0, 0, None])(init_values, decay_factors, ts)
157 | return values
158 |
159 | def sample_params(self, batch_size: int, mode: str, ts: Array, key: jrandom.PRNGKey) -> Tuple[Array, Array, Array, Array, Array, Array, Array, Array]:
160 | """
161 | Samples parameters for the environment.
162 |
163 | Parameters
164 | ----------
165 | batch_size : int
166 | Number of samples.
167 | mode : str
168 | Sampling mode. Options are "Constant", "Different", "Changing".
169 | ts : :class:`jax.Array`
170 | Time steps.
171 | key : :class:`jax.random.PRNGKey`
172 | Random key for sampling.
173 |
174 | Returns
175 | -------
176 | tuple of :class:`jax.Array`
177 | Sampled parameters.
178 | """
179 | if mode == "Constant":
180 | Vol = 100 * jnp.ones(batch_size)
181 | Cp = 239 * jnp.ones(batch_size)
182 | dHr = -5.0e4 * jnp.ones(batch_size)
183 | UA = 5.0e4 * jnp.ones(batch_size)
184 | q = 100 * jnp.ones(batch_size)
185 | Tf = 300 * jnp.ones((batch_size, ts.shape[0]))
186 | Tcf = 300 * jnp.ones(batch_size)
187 | Volc = 20.0 * jnp.ones(batch_size)
188 | elif mode == "Different":
189 | keys = jrandom.split(key, 8)
190 | Vol = jrandom.uniform(keys[0], (batch_size,), minval=75, maxval=150)
191 | Cp = jrandom.uniform(keys[1], (batch_size,), minval=200, maxval=350)
192 | dHr = jrandom.uniform(keys[2], (batch_size,), minval=-55000, maxval=-45000)
193 | UA = jrandom.uniform(keys[3], (batch_size,), minval=25000, maxval=75000)
194 | q = jrandom.uniform(keys[4], (batch_size,), minval=75, maxval=125)
195 | Tf = jnp.repeat(jrandom.uniform(keys[5], (batch_size,), minval=300, maxval=350)[:, None], ts.shape[0], axis=1)
196 | Tcf = jrandom.uniform(keys[6], (batch_size,), minval=250, maxval=300)
197 | Volc = jrandom.uniform(keys[7], (batch_size,), minval=10, maxval=30)
198 | elif mode == "Changing":
199 | keys = jrandom.split(key, 8)
200 | Vol = jrandom.uniform(keys[0], (batch_size,), minval=75, maxval=150)
201 | Cp = jrandom.uniform(keys[1], (batch_size,), minval=200, maxval=350)
202 | dHr = jrandom.uniform(keys[2], (batch_size,), minval=-55000, maxval=-45000)
203 | UA = jrandom.uniform(keys[3], (batch_size,), minval=25000, maxval=75000)
204 | q = jrandom.uniform(keys[4], (batch_size,), minval=75, maxval=125)
205 | Tf = self.sample_param_change(keys[5], batch_size, ts, 300, 350)
206 | Tcf = jrandom.uniform(keys[6], (batch_size,), minval=250, maxval=300)
207 | Volc = jrandom.uniform(keys[7], (batch_size,), minval=10, maxval=30)
208 | return Vol, Cp, dHr, UA, q, Tf, Tcf, Volc
209 |
210 | def sample_init_states(self, batch_size: int, key: jrandom.PRNGKey) -> Tuple[Array, Array]:
211 | """
212 | Samples initial states for the environment.
213 |
214 | Parameters
215 | ----------
216 | batch_size : int
217 | Number of initial states to sample.
218 | key : :class:`jax.random.PRNGKey`
219 | Random key for sampling.
220 |
221 | Returns
222 | -------
223 | x0 : :class:`jax.Array`
224 | Initial states.
225 | targets : :class:`jax.Array`
226 | Target states.
227 | """
228 | init_key, target_key = jrandom.split(key)
229 | x0 = jrandom.uniform(init_key, shape=(batch_size, self.n_var), minval=self.init_lower_bounds, maxval=self.init_upper_bounds)
230 | targets = jrandom.uniform(target_key, shape=(batch_size, self.n_targets), minval=400, maxval=480)
231 | return x0, targets
232 |
233 | def drift(self, t: float, state: Array, args: Tuple) -> Array:
234 | """
235 | Computes the drift function for the environment.
236 |
237 | Parameters
238 | ----------
239 | t : float
240 | Current time.
241 | state : :class:`jax.Array`
242 | Current state.
243 | args : tuple
244 | Additional arguments.
245 |
246 | Returns
247 | -------
248 | :class:`jax.Array`
249 | Drift.
250 | """
251 | Tc, T, c = state
252 | control = jnp.squeeze(args)
253 | control = jnp.clip(control, 0, 300)
254 |
255 | dc = (self.q / self.Vol) * (self.Cf - c) - self.k(T) * c
256 | dT = (self.q / self.Vol) * (self.Tf.evaluate(t) - T) + (-self.dHr / self.Cp) * self.k(T) * c + (self.UA / self.Vol / self.Cp) * (Tc - T) + self.external_influence.evaluate(t)
257 | dTc = (control / self.Volc) * (self.Tcf - Tc) + (self.UA / self.Volc / self.Cp) * (T - Tc)
258 |
259 | return jnp.array([dTc, dT, dc])
260 |
261 | def diffusion(self, t: float, state: Array, args: Tuple) -> Array:
262 | """
263 | Computes the diffusion function for the environment.
264 |
265 | Parameters
266 | ----------
267 | t : float
268 | Current time.
269 | state : :class:`jax.Array`
270 | Current state.
271 | args : tuple
272 | Additional arguments.
273 |
274 | Returns
275 | -------
276 | :class:`jax.Array`
277 | Diffusion.
278 | """
279 | return self.process_noise_ts.evaluate(t) * self.G
280 |
281 | def fitness_function(self, state: Array, control: Array, targets: Array, ts: Array) -> float:
282 | """
283 | Computes the fitness function for the environment.
284 |
285 | Parameters
286 | ----------
287 | state : :class:`jax.Array`
288 | Current state.
289 | control : :class:`jax.Array`
290 | Control inputs.
291 | targets : :class:`jax.Array`
292 | Target states.
293 | ts : :class:`jax.Array`
294 | Time steps.
295 |
296 | Returns
297 | -------
298 | float
299 | Fitness value.
300 | """
301 | x_d = jax.vmap(lambda tar: jnp.array([0, tar, 0]))(targets)
302 | costs = jax.vmap(lambda _state, _u, _x_d: (_state - _x_d).T @ self.Q @ (_state - _x_d) + (_u) @ self.r @ (_u))(state, control, x_d)
303 | return jnp.sum(costs)
304 |
305 | def cond_fn_nan(self, t: float, y: Array, args: Tuple, **kwargs) -> float:
306 | """
307 | Checks for NaN or infinite values in the state.
308 |
309 | Parameters
310 | ----------
311 | t : float
312 | Current time.
313 | y : :class:`jax.Array`
314 | Current state.
315 | args : tuple
316 | Additional arguments.
317 | kwargs : dict
318 | Additional keyword arguments.
319 |
320 | Returns
321 | -------
322 | float
323 | -1.0 if NaN or infinite values are found, 1.0 otherwise.
324 | """
325 |         return jnp.where(jnp.any(jnp.isinf(y) | jnp.isnan(y)), -1.0, 1.0)
--------------------------------------------------------------------------------
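The rate constant above follows the Arrhenius law k(T) = k0 * exp(-Ea / (R * T)); a quick sanity check with the constants from initialize_parameters shows why the reactor is far more reactive in the target band (400-480 K) than near the feed temperature:

import jax.numpy as jnp

Ea, R, k0 = 72750.0, 8.314, 7.2e10  # J/gmol, J/gmol/K, 1/min (values from initialize_parameters)
k = lambda T: k0 * jnp.exp(-Ea / R / T)

print(k(350.0))  # ~1.0 1/min near the feed temperature
print(k(450.0))  # ~260 1/min inside the target band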
/kozax/fitness_functions/Gymnax_fitness_function.py:
--------------------------------------------------------------------------------
1 | """
2 | kozax: Genetic programming framework in JAX
3 |
4 | Copyright (c) 2024 sdevries0
5 |
6 | This work is licensed under the Creative Commons Attribution-NonCommercial-NoDerivs 4.0 International License.
7 | """
8 |
9 | import jax
10 | from jaxtyping import Array
11 | import jax.numpy as jnp
12 | import jax.random as jr
13 | from typing import Tuple, Callable, Any
14 | from kozax.fitness_functions.base_fitness_function import BaseFitnessFunction
15 | import gymnax
16 |
17 | class GymFitnessFunction(BaseFitnessFunction):
18 | """
19 |     Evaluator for static symbolic policies in control tasks using Gymnax environments.
20 |
21 | Parameters
22 | ----------
23 | env_name : str
24 | Name of the Gym environment to be used.
25 |
26 | Attributes
27 | ----------
28 | env : gymnax environment
29 | The gymnax environment.
30 | env_params : dict
31 | Parameters for the gymnax environment.
32 | num_steps : int
33 | Number of steps in the environment for each episode.
34 |
35 | Methods
36 | -------
37 | __call__(candidate, keys, tree_evaluator)
38 | Evaluates the candidate on a task.
39 | evaluate_trajectory(candidate, key, tree_evaluator)
40 | Evaluates a rollout of the candidate in the environment.
41 | """
42 |
43 | def __init__(self, env_name: str) -> None:
44 | self.env, self.env_params = gymnax.make(env_name)
45 | self.num_steps = self.env_params.max_steps_in_episode
46 |
47 | def __call__(self, candidate: Array, keys: Array, tree_evaluator: Callable) -> float:
48 | """
49 | Evaluates the candidate on a task.
50 |
51 | Parameters
52 | ----------
53 | candidate : :class:`jax.Array`
54 | The candidate solution to be evaluated.
55 | keys : :class:`jax.Array`
56 | Random keys for evaluation.
57 | tree_evaluator : :class:`Callable`
58 | Function for evaluating trees.
59 |
60 | Returns
61 | -------
62 | float
63 | Fitness of the candidate.
64 | """
65 |         fitness = jax.vmap(self.evaluate_trajectory, in_axes=(None, 0, None))(candidate, keys, tree_evaluator)
66 |         return jnp.mean(fitness)
67 |
68 |     def evaluate_trajectory(self, candidate: Array, key: jr.PRNGKey, tree_evaluator: Callable) -> float:
69 | """
70 | Evaluates a rollout of the candidate in the environment.
71 |
72 | Parameters
73 | ----------
74 | candidate : :class:`jax.Array`
75 | The candidate solution to be evaluated.
76 | key : :class:`jax.random.PRNGKey`
77 | Random key for evaluation.
78 | tree_evaluator : :class:`Callable`
79 | Function for evaluating trees.
80 |
81 | Returns
82 | -------
83 |         float
84 |             Negated total reward obtained during the trajectory (kozax minimizes fitness).
85 | """
86 | key, subkey = jr.split(key)
87 | state, env_state = self.env.reset(subkey, self.env_params)
88 |
89 | def policy(state: Array) -> Array:
90 | """Symbolic policy."""
91 | a = tree_evaluator(candidate, state)
92 | return a
93 |
94 | def step_fn(carry: Tuple[Array, Any, jr.PRNGKey], _) -> Tuple[Tuple[Array, Any, jr.PRNGKey], Tuple[Array, float, bool]]:
95 | """Step function for lax.scan."""
96 | state, env_state, key = carry
97 |
98 | # Select action based on policy
99 | action = policy(state)
100 |
101 | # Step the environment
102 | key, subkey = jr.split(key)
103 | next_state, next_env_state, reward, done, _ = self.env.step(
104 | subkey, env_state, action, self.env_params
105 | )
106 |
107 | return (next_state, next_env_state, key), (state, reward, done)
108 |
109 | # Run the rollout using lax.scan
110 | (final_carry, (states, rewards, dones)) = jax.lax.scan(
111 | step_fn, (state, env_state, key), None, length=self.num_steps
112 | )
113 |
114 | return -jnp.sum(rewards)
--------------------------------------------------------------------------------
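A usage sketch for the evaluator above, assuming the Gymnax environment name "CartPole-v1" and substituting a stand-in policy for the tree evaluator (in kozax the candidate and evaluator come from the genetic programming strategy):

import jax.numpy as jnp
import jax.random as jr
from kozax.fitness_functions.Gymnax_fitness_function import GymFitnessFunction

fitness_function = GymFitnessFunction("CartPole-v1")

def dummy_tree_evaluator(candidate, state):
    # Stand-in for kozax's tree evaluator: a fixed policy that ignores the candidate
    return (jnp.sum(state) > 0).astype(jnp.int32)

keys = jr.split(jr.PRNGKey(0), 4)  # the fitness is averaged over four rollouts
fitness = fitness_function(jnp.zeros(1), keys, dummy_tree_evaluator)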
/kozax/fitness_functions/ODE_fitness_function.py:
--------------------------------------------------------------------------------
1 | """
2 | kozax: Genetic programming framework in JAX
3 |
4 | Copyright (c) 2024 sdevries0
5 |
6 | This work is licensed under the Creative Commons Attribution-NonCommercial-NoDerivs 4.0 International License.
7 | """
8 |
9 | import jax
10 | from jaxtyping import Array
11 | import jax.numpy as jnp
12 | from typing import Tuple, Callable
13 | from kozax.fitness_functions.base_fitness_function import BaseFitnessFunction
14 | import diffrax
15 |
16 | class ODEFitnessFunction(BaseFitnessFunction):
17 | """
18 |     Evaluator for candidates on symbolic regression tasks, where candidates are integrated as differential equations.
19 |
20 | Parameters
21 | ----------
22 | solver : :class:`diffrax.AbstractSolver`, optional
23 | Solver used for integration. Default is `diffrax.Euler()`.
24 | dt0 : float, optional
25 | Initial step size for integration. Default is 0.01.
26 | max_steps : int, optional
27 | The maximum number of steps that can be used in integration. Default is 16**4.
28 | stepsize_controller : :class:`diffrax.AbstractStepSizeController`, optional
29 | Controller for the stepsize during integration. Default is `diffrax.ConstantStepSize()`.
30 |
31 | Attributes
32 | ----------
33 | dt0 : float
34 | Initial step size for integration.
35 |     MSE : Callable
36 |         Error function; despite its name it computes the mean absolute error, normalized by the mean absolute value of the ground truth.
37 | system : :class:`diffrax.ODETerm`
38 | ODE term of the drift function.
39 | solver : :class:`diffrax.AbstractSolver`
40 | Solver used for integration.
41 | stepsize_controller : :class:`diffrax.AbstractStepSizeController`
42 | Controller for the stepsize during integration.
43 | max_steps : int
44 | The maximum number of steps that can be used in integration.
45 |
46 | Methods
47 | -------
48 | __call__(candidate, data, tree_evaluator)
49 | Evaluates the candidate on a task.
50 | evaluate_time_series(candidate, x0, ts, ys, tree_evaluator)
51 | Integrate the candidate as a differential equation and compute the fitness given the predictions.
52 | drift(t, x, args)
53 | Drift function for the ODE system.
54 | """
55 |
56 | def __init__(self, solver: diffrax.AbstractSolver = diffrax.Euler(), dt0: float = 0.01, max_steps: int = 16**4, stepsize_controller: diffrax.AbstractStepSizeController = diffrax.ConstantStepSize()) -> None:
57 | self.dt0 = dt0
58 | self.MSE = lambda pred_ys, true_ys: jnp.mean(jnp.sum(jnp.abs(pred_ys - true_ys), axis=-1))/jnp.mean(jnp.abs(true_ys))
59 | self.system = diffrax.ODETerm(self.drift)
60 | self.solver = solver
61 | self.stepsize_controller = stepsize_controller
62 | self.max_steps = max_steps
63 |
64 | def __call__(self, candidate: Array, data: Tuple, tree_evaluator: Callable) -> float:
65 | """
66 | Evaluates the candidate on a task.
67 |
68 | Parameters
69 | ----------
70 | candidate : :class:`jax.Array`
71 | The candidate solution to be evaluated.
72 | data : :class:`tuple`
73 | The data required to evaluate the candidate.
74 | tree_evaluator : :class:`Callable`
75 | Function for evaluating trees.
76 |
77 | Returns
78 | -------
79 | float
80 | Fitness of the candidate.
81 | """
82 | x0, ts, ys = data
83 | fitness = jax.vmap(self.evaluate_time_series, in_axes=[None, 0, None, 0, None])(candidate, x0, ts, ys, tree_evaluator)
84 | return jnp.mean(fitness)
85 |
86 | def evaluate_time_series(self, candidate: Array, x0: Array, ts: Array, ys: Array, tree_evaluator: Callable) -> float:
87 | """
88 | Integrate the candidate as a differential equation and compute the fitness given the predictions.
89 |
90 | Parameters
91 | ----------
92 | candidate : :class:`jax.Array`
93 | Candidate that is evaluated.
94 | x0 : :class:`jax.Array`
95 | Initial conditions of the environment.
96 | ts : :class:`jax.Array`
97 | Timepoints of which the system has to be solved.
98 | ys : :class:`jax.Array`
99 | Ground truth data used to compute the fitness.
100 | tree_evaluator : :class:`Callable`
101 | Function for evaluating trees.
102 |
103 | Returns
104 | -------
105 | float
106 | Fitness of the candidate.
107 | """
108 | saveat = diffrax.SaveAt(ts=ts)
109 | sol = diffrax.diffeqsolve(
110 | self.system, self.solver, ts[0], ts[-1], self.dt0, x0, args=(candidate, tree_evaluator), saveat=saveat, max_steps=self.max_steps, stepsize_controller=self.stepsize_controller,
111 | adjoint=diffrax.DirectAdjoint(), throw=False
112 | )
113 | pred_ys = sol.ys
114 | fitness = self.MSE(pred_ys, ys)
115 | return fitness
116 |
117 | def drift(self, t: float, x: Array, args: Tuple) -> Array:
118 | """
119 | Drift function for the ODE system.
120 |
121 | Parameters
122 | ----------
123 | t : float
124 | Current time.
125 | x : :class:`jax.Array`
126 | Current state.
127 | args : :class:`tuple`
128 | Additional arguments, including the candidate and tree evaluator.
129 |
130 | Returns
131 | -------
132 | :class:`jax.Array`
133 | Derivative of the state.
134 | """
135 | candidate, tree_evaluator = args
136 | dx = tree_evaluator(candidate, x)
137 | return dx
--------------------------------------------------------------------------------
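A usage sketch for the evaluator above, with a stand-in tree evaluator that encodes exponential decay (in kozax the candidate and evaluator come from the genetic programming strategy):

import jax.numpy as jnp
import diffrax
from kozax.fitness_functions.ODE_fitness_function import ODEFitnessFunction

fitness_function = ODEFitnessFunction(solver=diffrax.Dopri5(), dt0=0.01,
                                      stepsize_controller=diffrax.PIDController(rtol=1e-4, atol=1e-6))

def dummy_tree_evaluator(candidate, x):
    return -x  # pretend the candidate encodes dx/dt = -x

ts = jnp.linspace(0.0, 2.0, 50)
x0 = jnp.array([[0.5], [1.0], [2.0]])              # three initial conditions
ys = x0[:, None, :] * jnp.exp(-ts)[None, :, None]  # analytic ground truth trajectories
fitness = fitness_function(jnp.zeros(1), (x0, ts, ys), dummy_tree_evaluator)  # ~0.0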
/kozax/fitness_functions/SR_fitness_function.py:
--------------------------------------------------------------------------------
1 | """
2 | kozax: Genetic programming framework in JAX
3 |
4 | Copyright (c) 2024 sdevries0
5 |
6 | This work is licensed under the Creative Commons Attribution-NonCommercial-NoDerivs 4.0 International License.
7 | """
8 |
9 | import jax
10 | from typing import Tuple, Callable
11 | from jaxtyping import Array
12 | import jax.numpy as jnp
13 | from kozax.fitness_functions.base_fitness_function import BaseFitnessFunction
14 |
15 | class SymbolicRegressionFitnessFunction(BaseFitnessFunction):
16 | """
17 | Evaluator for candidates on symbolic regression tasks with x, y data.
18 |
19 | Methods
20 | -------
21 | __call__(candidate, data, tree_evaluator)
22 | Evaluates the candidate on a task.
23 | """
24 |
25 | def __call__(self, candidate: Array, data: Tuple[Array, Array], tree_evaluator: Callable) -> float:
26 | """
27 | Evaluates the candidate on a task.
28 |
29 | Parameters
30 | ----------
31 | candidate : :class:`jax.Array`
32 | The candidate solution to be evaluated.
33 | data : :class:`tuple` of :class:`jax.Array`
34 | The data required to evaluate the candidate. Tuple of (x, y) where x is the input data and y is the true output data.
35 | tree_evaluator : :class:`Callable`
36 | Function for evaluating trees.
37 |
38 | Returns
39 | -------
40 | float
41 | Fitness of the candidate.
42 | """
43 | x, y = data
44 | pred = jax.vmap(tree_evaluator, in_axes=[None, 0])(candidate, x)
45 | return jnp.mean(jnp.square(pred - y))
--------------------------------------------------------------------------------
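A usage sketch, again with a stand-in tree evaluator in place of kozax's real one:

import jax.numpy as jnp
from kozax.fitness_functions.SR_fitness_function import SymbolicRegressionFitnessFunction

fitness_function = SymbolicRegressionFitnessFunction()

def dummy_tree_evaluator(candidate, x):
    return 2.0 * x[0] + 1.0  # pretend the candidate encodes y = 2x + 1

x = jnp.linspace(-1.0, 1.0, 100)[:, None]
y = 2.0 * x[:, 0] + 1.0
fitness = fitness_function(jnp.zeros(1), (x, y), dummy_tree_evaluator)  # 0.0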
/kozax/fitness_functions/base_fitness_function.py:
--------------------------------------------------------------------------------
1 | """
2 | kozax: Genetic programming framework in JAX
3 |
4 | Copyright (c) 2024 sdevries0
5 |
6 | This work is licensed under the Creative Commons Attribution-NonCommercial-NoDerivs 4.0 International License.
7 | """
8 |
9 | from abc import ABC, abstractmethod
10 | from typing import Tuple, Callable
11 | from jaxtyping import Array
12 |
13 | class BaseFitnessFunction(ABC):
14 | """
15 | Abstract base class for evaluating candidates in genetic programming.
16 |
17 | Methods
18 | -------
19 | __call__(candidate, data, tree_evaluator)
20 | Evaluates the candidate on a task.
21 | """
22 |
23 | @abstractmethod
24 | def __call__(self, candidate: Array, data: Tuple, tree_evaluator: Callable) -> float:
25 | """
26 | Evaluates the candidate on a task.
27 |
28 | Parameters
29 | ----------
30 | candidate : :class:`jax.Array`
31 | The candidate solution to be evaluated.
32 | data : :class:`tuple`
33 | The data required to evaluate the candidate.
34 | tree_evaluator : :class:`Callable`
35 | Function for evaluating trees.
36 |
37 | Returns
38 | -------
39 | float
40 | Fitness of the candidate.
41 | """
42 | raise NotImplementedError
--------------------------------------------------------------------------------
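Custom fitness functions subclass the base class above and only need to implement __call__; a minimal sketch computing a mean absolute error:

import jax
import jax.numpy as jnp
from kozax.fitness_functions.base_fitness_function import BaseFitnessFunction

class MAEFitnessFunction(BaseFitnessFunction):
    """Mean absolute error between the candidate's predictions and the targets."""

    def __call__(self, candidate, data, tree_evaluator):
        x, y = data
        pred = jax.vmap(tree_evaluator, in_axes=[None, 0])(candidate, x)
        return jnp.mean(jnp.abs(pred - y))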
/kozax/genetic_operators/crossover.py:
--------------------------------------------------------------------------------
1 | """
2 | kozax: Genetic programming framework in JAX
3 |
4 | Copyright (c) 2024 sdevries0
5 |
6 | This work is licensed under the Creative Commons Attribution-NonCommercial-NoDerivs 4.0 International License.
7 | """
8 |
9 | import jax
10 | import jax.numpy as jnp
11 | import jax.random as jr
12 | from typing import Tuple
13 | from jaxtyping import Array
14 | from jax.random import PRNGKey
15 |
16 | def sample_indices(carry: Tuple[PRNGKey, Array, float]) -> Tuple[PRNGKey, Array, float]:
17 | """
18 |     Samples a binary mask over the trees in a candidate, selecting which trees take part in crossover.
19 |
20 | Parameters
21 | ----------
22 | carry : tuple of (PRNGKey, Array, float)
23 | Tuple containing the random key, indices of trees, and reproduction probability.
24 |
25 | Returns
26 | -------
27 | tuple of (PRNGKey, Array, float)
28 |         Updated tuple with the random key, mask of trees selected for crossover, and reproduction probability.
29 | """
30 | key, indices, reproduction_probability = carry
31 | indices = jr.bernoulli(key, p=reproduction_probability, shape=indices.shape) * 1.0
32 | return (jr.split(key, 1)[0], indices, reproduction_probability)
33 |
34 | def find_end_idx(carry: Tuple[Array, int, int]) -> Tuple[Array, int, int]:
35 | """
36 | Finds the index of the last node in a subtree.
37 |
38 | Parameters
39 | ----------
40 | carry : tuple of (Array, int, int)
41 | Tuple containing the tree, the number of open slots, and the current node index.
42 |
43 | Returns
44 | -------
45 | tuple of (Array, int, int)
46 | Updated tuple with the tree, open slots, and current node index.
47 | """
48 | tree, open_slots, counter = carry
49 | _, idx1, idx2, _ = tree[counter]
50 | open_slots -= 1 # Reduce open slot for current node
51 | open_slots = jax.lax.select(idx1 < 0, open_slots, open_slots + 1) # Increase the open slots for a child
52 | open_slots = jax.lax.select(idx2 < 0, open_slots, open_slots + 1) # Increase the open slots for a child
53 | counter -= 1
54 | return (tree, open_slots, counter)
55 |
56 | def check_invalid_cx_nodes(carry: Tuple[Array, Array, Array, int, int, Array, Array]) -> bool:
57 | """
58 | Checks if the sampled subtrees are different and if the trees after crossover are valid.
59 |
60 | Parameters
61 | ----------
62 | carry : tuple of (Array, Array, Array, int, int, Array, Array)
63 | Tuple containing the trees, node indices, and other parameters.
64 |
65 | Returns
66 | -------
67 | bool
68 |         True if the sampled nodes are invalid for crossover and have to be resampled.
69 | """
70 | tree1, tree2, _, node_idx1, node_idx2, _, _ = carry
71 |
72 | _, _, end_idx1 = jax.lax.while_loop(lambda carry: carry[1] > 0, find_end_idx, (tree1, 1, node_idx1))
73 | _, _, end_idx2 = jax.lax.while_loop(lambda carry: carry[1] > 0, find_end_idx, (tree2, 1, node_idx2))
74 |
75 | subtree_size1 = node_idx1 - end_idx1
76 | subtree_size2 = node_idx2 - end_idx2
77 |
78 | empty_nodes1 = jnp.sum(tree1[:, 0] == 0)
79 | empty_nodes2 = jnp.sum(tree2[:, 0] == 0)
80 |
81 | # Check if the subtrees can be inserted
82 | return (empty_nodes1 < subtree_size2 - subtree_size1) | (empty_nodes2 < subtree_size1 - subtree_size2)
83 |
84 | def sample_cx_nodes(carry: Tuple[Array, Array, Array, int, int, Array, Array]) -> Tuple[Array, Array, Array, int, int, Array, Array]:
85 | """
86 | Samples nodes in a pair of trees for crossover.
87 |
88 | Parameters
89 | ----------
90 | carry : tuple of (Array, Array, Array, int, int, Array, Array)
91 | Tuple containing the trees, node indices, and other parameters.
92 |
93 | Returns
94 | -------
95 | tuple of (Array, Array, Array, int, int, Array, Array)
96 | Updated tuple with the sampled nodes.
97 | """
98 | tree1, tree2, keys, _, _, node_ids, operator_indices = carry
99 | key1, key2 = keys
100 |
101 | # Sample nodes from the non-empty nodes, with higher probability for operator nodes
102 | cx_prob1 = jnp.isin(tree1[:, 0], operator_indices)
103 | cx_prob1 = jnp.where(tree1[:, 0] == 0, cx_prob1, cx_prob1 + 1)
104 | node_idx1 = jr.choice(key1, node_ids, p=cx_prob1 * 1.0)
105 |
106 | cx_prob2 = jnp.isin(tree2[:, 0], operator_indices)
107 | cx_prob2 = jnp.where(tree2[:, 0] == 0, cx_prob2, cx_prob2 + 1)
108 | node_idx2 = jr.choice(key2, node_ids, p=cx_prob2 * 1.0)
109 |
110 | return (tree1, tree2, jr.split(key1), node_idx1, node_idx2, node_ids, operator_indices)
111 |
112 | def tree_crossover(tree1: Array,
113 | tree2: Array,
114 | keys: Array,
115 | node_ids: Array,
116 | operator_indices: Array) -> Tuple[Array, Array]:
117 | """
118 | Applies crossover to a pair of trees to produce two new trees.
119 |
120 | Parameters
121 | ----------
122 | tree1 : Array
123 | First tree.
124 | tree2 : Array
125 | Second tree.
126 | keys : Array
127 | Random keys.
128 | node_ids : Array
129 | Indices of all the nodes in the trees.
130 | operator_indices : Array
131 | The indices that belong to operator nodes.
132 |
133 | Returns
134 | -------
135 | tuple of (Array, Array)
136 | Pair of new trees.
137 | """
138 | # Define indices of the nodes
139 | tree_indices = jnp.tile(node_ids[:, None], reps=(1, 4))
140 | key1, key2 = keys
141 |
142 | # Define last node in tree
143 | last_node_idx1 = jnp.sum(tree1[:, 0] == 0)
144 | last_node_idx2 = jnp.sum(tree2[:, 0] == 0)
145 |
146 | # Randomly select nodes for crossover
147 | _, _, _, node_idx1, node_idx2, _, _ = sample_cx_nodes((tree1, tree2, jr.split(key1), 0, 0, node_ids, operator_indices))
148 |
149 | # Reselect until valid crossover nodes have been found
150 | _, _, _, node_idx1, node_idx2, _, _ = jax.lax.while_loop(check_invalid_cx_nodes, sample_cx_nodes, (tree1, tree2, jr.split(key2), node_idx1, node_idx2, node_ids, operator_indices))
151 |
152 | # Retrieve subtrees of selected nodes
153 | _, _, end_idx1 = jax.lax.while_loop(lambda carry: carry[1] > 0, find_end_idx, (tree1, 1, node_idx1))
154 | _, _, end_idx2 = jax.lax.while_loop(lambda carry: carry[1] > 0, find_end_idx, (tree2, 1, node_idx2))
155 |
156 | # Initialize children
157 | child1 = jnp.tile(jnp.array([0.0, -1.0, -1.0, 0.0]), (len(node_ids), 1))
158 | child2 = jnp.tile(jnp.array([0.0, -1.0, -1.0, 0.0]), (len(node_ids), 1))
159 |
160 | # Compute subtree sizes
161 | subtree_size1 = node_idx1 - end_idx1
162 | subtree_size2 = node_idx2 - end_idx2
163 |
164 | # Insert nodes before subtree in children
165 | child1 = jnp.where(tree_indices >= node_idx1 + 1, tree1, child1)
166 | child2 = jnp.where(tree_indices >= node_idx2 + 1, tree2, child2)
167 |
168 | # Align nodes after subtree with first open spot after new subtree in children
169 | rolled_tree1 = jnp.roll(tree1, subtree_size1 - subtree_size2, axis=0)
170 | rolled_tree2 = jnp.roll(tree2, subtree_size2 - subtree_size1, axis=0)
171 |
172 | # Insert nodes after subtree in children
173 | child1 = jnp.where((tree_indices >= node_idx1 - subtree_size2 - (end_idx1 - last_node_idx1)) & (tree_indices < node_idx1 + 1 - subtree_size2), rolled_tree1, child1)
174 | child2 = jnp.where((tree_indices >= node_idx2 - subtree_size1 - (end_idx2 - last_node_idx2)) & (tree_indices < node_idx2 + 1 - subtree_size1), rolled_tree2, child2)
175 |
176 | # Update index references to moved nodes in staying nodes
177 | child1 = child1.at[:, 1:3].set(jnp.where((child1[:, 1:3] < (node_idx1 - subtree_size1 + 1)) & (child1[:, 1:3] > -1), child1[:, 1:3] + (subtree_size1 - subtree_size2), child1[:, 1:3]))
178 | child2 = child2.at[:, 1:3].set(jnp.where((child2[:, 1:3] < (node_idx2 - subtree_size2 + 1)) & (child2[:, 1:3] > -1), child2[:, 1:3] + (subtree_size2 - subtree_size1), child2[:, 1:3]))
179 |
180 | # Align subtree with the selected node in children
181 | rolled_subtree1 = jnp.roll(tree1, node_idx2 - node_idx1, axis=0)
182 | rolled_subtree2 = jnp.roll(tree2, node_idx1 - node_idx2, axis=0)
183 |
184 | # Update index references in subtree
185 | rolled_subtree1 = rolled_subtree1.at[:, 1:3].set(jnp.where(rolled_subtree1[:, 1:3] > -1, rolled_subtree1[:, 1:3] + (node_idx2 - node_idx1), -1))
186 | rolled_subtree2 = rolled_subtree2.at[:, 1:3].set(jnp.where(rolled_subtree2[:, 1:3] > -1, rolled_subtree2[:, 1:3] + (node_idx1 - node_idx2), -1))
187 |
188 | # Insert subtree in selected node in children
189 | child1 = jnp.where((tree_indices >= node_idx1 + 1 - subtree_size2) & (tree_indices < node_idx1 + 1), rolled_subtree2, child1)
190 | child2 = jnp.where((tree_indices >= node_idx2 + 1 - subtree_size1) & (tree_indices < node_idx2 + 1), rolled_subtree1, child2)
191 |
192 | return child1, child2
193 |
194 | def full_crossover(tree1: Array,
195 | tree2: Array,
196 | keys: Array,
197 | node_ids: Array,
198 | operator_indices: Array) -> Tuple[Array, Array]:
199 | """
200 | Swaps the entire trees between two candidates.
201 |
202 | Parameters
203 | ----------
204 | tree1 : Array
205 | First tree.
206 | tree2 : Array
207 | Second tree.
208 | keys : Array
209 | Random keys.
210 | node_ids : Array
211 | Indices of all the nodes in the trees.
212 | operator_indices : Array
213 | The indices that belong to operator nodes.
214 |
215 | Returns
216 | -------
217 | tuple of (Array, Array)
218 | Swapped trees.
219 | """
220 | return tree2, tree1
221 |
222 | def crossover(tree1: Array,
223 | tree2: Array,
224 | keys: Array,
225 | node_ids: Array,
226 | operator_indices: Array,
227 |               crossover_types: bool) -> Tuple[Array, Array]:
228 | """
229 | Applies crossover to a pair of trees based on the crossover type.
230 |
231 | Parameters
232 | ----------
233 | tree1 : Array
234 | First tree.
235 | tree2 : Array
236 | Second tree.
237 | keys : Array
238 | Random keys.
239 | node_ids : Array
240 | Indices of all the nodes in the trees.
241 | operator_indices : Array
242 | The indices that belong to operator nodes.
243 |     crossover_types : bool
244 |         If True, subtree crossover is applied; otherwise the entire trees are swapped.
245 |
246 | Returns
247 | -------
248 | tuple of (Array, Array)
249 | Pair of new trees.
250 | """
251 | return jax.lax.cond(crossover_types, tree_crossover, full_crossover, tree1, tree2, keys, node_ids, operator_indices)
252 |
253 | def check_different_tree(parent1: Array, parent2: Array, child1: Array, child2: Array) -> bool:
254 | """
255 | Checks if the offspring are different from the parents.
256 |
257 | Parameters
258 | ----------
259 | parent1 : Array
260 | First parent tree.
261 | parent2 : Array
262 | Second parent tree.
263 | child1 : Array
264 | First child tree.
265 | child2 : Array
266 | Second child tree.
267 |
268 | Returns
269 | -------
270 | bool
271 |         True if the offspring merely duplicate the parents (or are empty) and crossover has to be retried, False otherwise.
272 | """
273 | size1 = jnp.sum(child1[:, 0] != 0)
274 | size2 = jnp.sum(child2[:, 0] != 0)
275 |
276 | check1 = (jnp.all(parent1 == child1) | jnp.all(parent2 == child1))
277 | check2 = (jnp.all(parent1 == child2) | jnp.all(parent2 == child2))
278 |
279 | return ((check1 | check2) & ((size1 > 1) & (size2 > 1))) | (size1 == 0)
280 |
281 | def check_different_trees(carry: Tuple[Array, Array, Array, Array, Array, float, Array, Array]) -> bool:
282 | """
283 | Checks if the offspring are different from the parents for all trees in the population.
284 |
285 | Parameters
286 | ----------
287 | carry : tuple of (Array, Array, Array, Array, Array, float, Array, Array)
288 | Tuple containing the parent trees, child trees, and other parameters.
289 |
290 | Returns
291 | -------
292 | bool
293 |         True if every tree pair is rejected by check_different_tree, so the crossover is retried; False as soon as at least one tree pair is acceptable.
294 | """
295 | parent1, parent2, child1, child2, _, _, _, _ = carry
296 | return jnp.all(jax.vmap(check_different_tree)(parent1, parent2, child1, child2))
297 |
298 | def safe_crossover(carry: Tuple[Array, Array, Array, Array, Array, float, Array, Array]) -> Tuple[Array, Array, Array, Array, Array, float, Array, Array]:
299 | """
300 |     Performs one crossover attempt; called repeatedly by crossover_trees until the offspring are accepted.
301 |
302 | Parameters
303 | ----------
304 | carry : tuple of (Array, Array, Array, Array, Array, float, Array, Array)
305 | Tuple containing the parent trees, child trees, and other parameters.
306 |
307 | Returns
308 | -------
309 | tuple of (Array, Array, Array, Array, Array, float, Array, Array)
310 | Updated tuple with the parent trees, child trees, and other parameters.
311 | """
312 | parent1, parent2, _, _, keys, reproduction_probability, node_ids, operator_indices = carry
313 | index_key, type_key, new_key = jr.split(keys[0, 0], 3)
314 | _, cx_indices, _ = jax.lax.while_loop(lambda carry: jnp.sum(carry[1]) == 0, sample_indices, (index_key, jnp.zeros(parent1.shape[0]), reproduction_probability))
315 | crossover_types = jr.bernoulli(type_key, p=0.9, shape=(parent1.shape[0],))
316 | offspring1, offspring2 = jax.vmap(crossover, in_axes=[0, 0, 0, None, None, 0])(parent1, parent2, keys, node_ids, operator_indices, crossover_types)
317 | child1 = jnp.where(cx_indices[:, None, None] * jnp.ones_like(parent1), offspring1, parent1)
318 | child2 = jnp.where(cx_indices[:, None, None] * jnp.ones_like(parent2), offspring2, parent2)
319 |
320 | keys = jr.split(new_key, keys.shape[:-1])
321 |
322 | return parent1, parent2, child1, child2, keys, reproduction_probability, node_ids, operator_indices
323 |
324 | def crossover_trees(parent1: Array,
325 | parent2: Array,
326 | keys: Array,
327 | reproduction_probability: float,
328 | max_nodes: int,
329 | operator_indices: Array) -> Tuple[Array, Array]:
330 | """
331 | Applies crossover to the trees in a pair of candidates.
332 |
333 | Parameters
334 | ----------
335 | parent1 : Array
336 | First parent candidate.
337 | parent2 : Array
338 | Second parent candidate.
339 | keys : Array
340 | Random keys.
341 | reproduction_probability : float
342 |         Probability of a tree being selected for crossover.
343 | max_nodes : int
344 | Max number of nodes in a tree.
345 | operator_indices : Array
346 | The indices that belong to operator nodes.
347 |
348 | Returns
349 | -------
350 | tuple of (Array, Array)
351 | Pair of candidates after crossover.
352 | """
353 | _, _, child1, child2, _, _, _, _ = jax.lax.while_loop(check_different_trees, safe_crossover, (
354 | parent1, parent2, jnp.zeros_like(parent1), jnp.zeros_like(parent2), keys, reproduction_probability, jnp.arange(max_nodes), operator_indices))
355 |
356 | return child1, child2
--------------------------------------------------------------------------------
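A note on the retry loop above: jax.lax.while_loop keeps executing its body while the condition function returns True, so check_different_trees signals rejection, not success. A minimal self-contained sketch of this rejection-sampling pattern (my own illustration, not kozax code):

import jax
import jax.numpy as jnp
import jax.random as jr

def needs_retry(carry):
    # True -> reject the current draw and run the body again,
    # mirroring how check_different_trees rejects duplicate offspring.
    _, draw = carry
    return draw < 0.5

def resample(carry):
    key, _ = carry
    key, subkey = jr.split(key)
    return key, jr.uniform(subkey)

# Start from an invalid value so the loop runs at least once, just as
# crossover_trees starts from all-zero child trees.
_, accepted = jax.lax.while_loop(needs_retry, resample, (jr.PRNGKey(0), jnp.array(0.0)))
--------------------------------------------------------------------------------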
/kozax/genetic_operators/initialization.py:
--------------------------------------------------------------------------------
1 | """
2 | kozax: Genetic programming framework in JAX
3 |
4 | Copyright (c) 2024 sdevries0
5 |
6 | This work is licensed under the Creative Commons Attribution-NonCommercial-NoDerivs 4.0 International License.
7 | """
8 |
9 | import jax
10 | import jax.numpy as jnp
11 | import jax.random as jr
12 | from jax.random import PRNGKey
13 | from functools import partial
14 | from jax import Array
15 | from typing import Tuple, Callable
16 |
17 | def sample_node(i: int,
18 | carry: Tuple[PRNGKey, Array, int, int, int, Array, Tuple]) -> Tuple[PRNGKey, Array, int, int, int, Array, Tuple]:
19 | """Samples nodes sequentially in breadth-first order, storing them depth-first.
20 |
21 | Parameters
22 | ----------
23 | i : int
24 | Index of the node.
25 | carry : Tuple[PRNGKey, Array, int, int, int, Array, Tuple]
26 | Tuple containing the random key, tree, open slots, max init depth, max nodes, variable array, and other arguments.
27 |
28 | Returns
29 | -------
30 | Tuple[PRNGKey, Array, int, int, int, Array, Tuple]
31 | Updated tuple with the random key, tree, open slots, max init depth, max nodes, variable array, and other arguments.
32 | """
33 | key, tree, open_slots, max_init_depth, max_nodes, variable_array, args = carry
34 | variable_indices, operator_indices, operator_probabilities, slots, coefficient_sd, map_b_to_d = args
35 | coefficient_key, leaf_key, variable_key, node_key, operator_key = jr.split(key, 5)
36 | _i = map_b_to_d[i].astype(int) # Get depth first index
37 |
38 | depth = (jnp.log(i + 1 + 1e-10) / jnp.log(2)).astype(int) # Compute depth of node
39 | coefficient = jr.normal(coefficient_key) * coefficient_sd
40 | leaf = jax.lax.select(jr.uniform(leaf_key) < 0.5, 1, jr.choice(variable_key, variable_indices, shape=(), p=variable_array)) # Sample coefficient or variable
41 |
42 |     index = jax.lax.select((open_slots < max_nodes - i - 1) & (depth + 1 < max_init_depth), # Only allow an operator node if the max init depth has not been reached and there is room left for its children
43 | jax.lax.select(jr.uniform(node_key) < (0.7 ** depth), # At higher depth, a leaf node is more probable
44 | jr.choice(operator_key, a=operator_indices, shape=(), p=operator_probabilities),
45 | leaf),
46 | leaf)
47 |
48 | index = jax.lax.select(open_slots == 0, 0, index) # If there are no open slots, the node should be empty
49 |
50 | # If parent node is a leaf, the node should be empty
51 | index = jax.lax.select(i > 0, jax.lax.select((slots[jnp.maximum(tree[map_b_to_d[(i + (i % 2) - 2) // 2].astype(int), 0], 0).astype(int)] + i % 2) > 1, index, 0), index)
52 |
53 | # Set index references
54 | tree = jax.lax.select(slots[index] > 0, tree.at[_i, 1].set(map_b_to_d[2 * i + 1]), tree.at[_i, 1].set(-1))
55 | tree = jax.lax.select(slots[index] > 1, tree.at[_i, 2].set(map_b_to_d[2 * i + 2]), tree.at[_i, 2].set(-1))
56 |
57 | tree = jax.lax.select(index == 1, tree.at[_i, 3].set(coefficient), tree) # Set coefficient value
58 | tree = tree.at[_i, 0].set(index)
59 |
60 | open_slots = jax.lax.select(index == 0, open_slots, jnp.maximum(0, open_slots + slots[index] - 1)) # Update the number of open slots
61 |
62 | return (jr.fold_in(key, i), tree, open_slots, max_init_depth, max_nodes, variable_array, args)
63 |
64 | def prune_row(i: int,
65 | carry: Tuple[Array, int, int],
66 | old_tree: Array) -> Tuple[Array, int, int]:
67 | """Sequentially adds nodes to the new tree if it is not empty.
68 |
69 | Parameters
70 | ----------
71 | i : int
72 | Index of the node.
73 | carry : Tuple[Array, int, int]
74 | Tuple containing the tree, counter, and tree size.
75 | old_tree : Array
76 | Tree with empty nodes that have to be pruned.
77 |
78 | Returns
79 | -------
80 | Tuple[Array, int, int]
81 | Updated tuple with the tree, counter, and tree size.
82 | """
83 | tree, counter, tree_size = carry
84 | _i = tree_size - i - 1
85 | row = old_tree[_i]
86 |
87 | # If node is not empty, add node and update index references
88 | tree = jax.lax.select(row[0] != 0, tree.at[counter].set(row), tree.at[:, 1:3].set(jnp.where(tree[:, 1:3] > _i, tree[:, 1:3] - 1, tree[:, 1:3])))
89 | counter = jax.lax.select(row[0] != 0, counter - 1, counter)
90 |
91 | return (tree, counter, tree_size)
92 |
93 | def prune_tree(tree: Array,
94 | tree_size: int,
95 | max_nodes: int) -> Array:
96 | """Removes empty nodes from a tree. The new tree is filled with empty nodes at the end to match the max number of nodes.
97 |
98 | Parameters
99 | ----------
100 | tree : Array
101 | Tree to be pruned.
102 | tree_size : int
103 | Max size of the old tree.
104 | max_nodes : int
105 | Max number of nodes in the new tree.
106 |
107 | Returns
108 | -------
109 | Array
110 | Tree with empty nodes pruned.
111 | """
112 | tree, counter, _ = jax.lax.fori_loop(0, tree_size, partial(prune_row, old_tree=tree), (jnp.tile(jnp.array([0.0, -1.0, -1.0, 0.0]), (max_nodes, 1)), max_nodes - 1, tree_size))
113 | tree = tree.at[:, 1:3].set(jnp.where(tree[:, 1:3] > -1, tree[:, 1:3] + counter + 1, tree[:, 1:3])) # Update index references after pruning
114 | return tree
115 |
116 | def sample_tree(key: PRNGKey,
117 | depth: int,
118 | variable_array: Array,
119 | tree_size: int,
120 | max_nodes: int,
121 | args: Tuple) -> Array:
122 | """Initializes a tree.
123 |
124 | Parameters
125 | ----------
126 | key : PRNGKey
127 | Random key.
128 | depth : int
129 | Max depth in a tree at initialization.
130 | variable_array : Array
131 | The valid variables for this tree.
132 | tree_size : int
133 | Max size of the tree.
134 | max_nodes : int
135 | Max number of nodes in a tree.
136 | args : Tuple
137 | Miscellaneous parameters required for initialization.
138 |
139 | Returns
140 | -------
141 | Array
142 | Initialized tree.
143 | """
144 | # First sample tree at full size given depth
145 | tree = jax.lax.fori_loop(0, tree_size, sample_node, (key, jnp.zeros((tree_size, 4)), 1, depth, max_nodes, variable_array, args))[1] # Sample nodes in a tree sequentially
146 |
147 | # Prune empty rows in tree
148 | pruned_tree = prune_tree(tree, tree_size, max_nodes)
149 | return pruned_tree
150 |
151 | def sample_population(key: PRNGKey,
152 | population_size: int,
153 | num_trees: int,
154 | max_init_depth: int,
155 | variable_array: Array,
156 | sample_function: Callable) -> Array:
157 | """Initializes a population of candidates.
158 |
159 | Parameters
160 | ----------
161 | key : PRNGKey
162 | Random key.
163 | population_size : int
164 | Number of candidates that have to be sampled.
165 | num_trees : int
166 | Number of trees in a candidate.
167 | max_init_depth : int
168 | Max depth in a tree at initialization.
169 | variable_array : Array
170 | The valid variables for each tree.
171 | sample_function : Callable
172 | Function to sample a tree.
173 |
174 | Returns
175 | -------
176 | Array
177 | Population of candidates.
178 | """
179 | sample_candidate = lambda keys: jax.vmap(sample_function, in_axes=[0, None, 0])(keys, max_init_depth, variable_array)
180 | return jax.vmap(sample_candidate)(jr.split(key, (population_size, num_trees)))
--------------------------------------------------------------------------------
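Concretely, each tree above is an array with one row per node: [node_index, left_child_row, right_child_row, coefficient]. Index 0 marks an empty row, index 1 a coefficient leaf, -1 a missing child, and prune_tree packs the nodes towards the end of the array so the root ends up in the last row. A hand-built example with a toy evaluator (my own sketch; the indices 2 and 3 for the variable and the addition operator are hypothetical, since the real mapping depends on the configured operator and variable sets):

import jax.numpy as jnp

# Hypothetical node indices: 1 = coefficient leaf, 2 = variable x, 3 = addition.
tree = jnp.array([
    [0.0, -1.0, -1.0, 0.0],  # empty padding row (pruning packs nodes towards the end)
    [1.0, -1.0, -1.0, 2.0],  # coefficient leaf with value 2.0
    [2.0, -1.0, -1.0, 0.0],  # variable leaf x
    [3.0,  2.0,  1.0, 0.0],  # root: x + 2.0, children at rows 2 and 1
])

def evaluate(tree, row, x):
    idx = int(tree[row, 0])
    if idx == 1:          # coefficient leaf
        return float(tree[row, 3])
    if idx == 2:          # variable leaf (hypothetical index)
        return x
    left = evaluate(tree, int(tree[row, 1]), x)
    right = evaluate(tree, int(tree[row, 2]), x)
    return left + right   # only addition in this toy mapping

assert evaluate(tree, row=3, x=1.5) == 3.5
--------------------------------------------------------------------------------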
/kozax/genetic_operators/reproduction.py:
--------------------------------------------------------------------------------
1 | """
2 | kozax: Genetic programming framework in JAX
3 |
4 | Copyright (c) 2024 sdevries0
5 |
6 | This work is licensed under the Creative Commons Attribution-NonCommercial-NoDerivs 4.0 International License.
7 | """
8 |
9 | import jax
10 | from jax import Array
11 | import jax.numpy as jnp
12 | import jax.random as jr
13 | from jax.random import PRNGKey
14 | from typing import Callable, List, Tuple
15 |
16 | def evolve_trees(parent1: Array,
17 | parent2: Array,
18 | keys: Array,
19 | type: int,
20 | reproduction_probability: float,
21 | reproduction_functions: List[Callable]) -> Tuple[Array, Array]:
22 | """Applies reproduction function to pair of candidates.
23 |
24 | Parameters
25 | ----------
26 | parent1 : Array
27 | First parent candidate.
28 | parent2 : Array
29 | Second parent candidate.
30 | keys : Array
31 | Random keys.
32 | type : int
33 |         Index into reproduction_functions selecting which function to apply (used with jax.lax.switch).
34 | reproduction_probability : float
35 |         Probability of a tree being modified by the reproduction function.
36 | reproduction_functions : List[Callable]
37 | Functions that can be applied for reproduction.
38 |
39 | Returns
40 | -------
41 | Tuple[Array, Array]
42 | Pair of reproduced candidates.
43 | """
44 | child0, child1 = jax.lax.switch(type, reproduction_functions, parent1, parent2, keys, reproduction_probability)
45 |
46 | return child0, child1
47 |
48 | def tournament_selection(population: Array,
49 | fitness: Array,
50 | key: PRNGKey,
51 | tournament_probabilities: Array,
52 | tournament_size: int,
53 | population_indices: Array) -> Array:
54 | """Selects a candidate for reproduction from a tournament.
55 |
56 | Parameters
57 | ----------
58 | population : Array
59 | Population of candidates.
60 | fitness : Array
61 | Fitness of candidates.
62 | key : PRNGKey
63 | Random key.
64 | tournament_probabilities : Array
65 | Probability of each of the ranks in the tournament to be selected for reproduction.
66 | tournament_size : int
67 | Size of the tournament.
68 | population_indices : Array
69 | Indices of the population.
70 |
71 | Returns
72 | -------
73 | Array
74 | Candidate that won the tournament.
75 | """
76 | tournament_key, winner_key = jr.split(key)
77 | indices = jr.choice(tournament_key, population_indices, shape=(tournament_size,))
78 |
79 | index = jr.choice(winner_key, indices[jnp.argsort(fitness[indices])], p=tournament_probabilities)
80 | return population[index]
81 |
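# Illustrative note (my own sketch, not part of kozax): tournament_probabilities
# assigns a selection probability to each fitness rank inside the tournament,
# with rank 0 being the fittest of the sampled candidates. One common choice
# (an assumption here, not prescribed by this file) is a normalized geometric
# distribution over ranks:
#
#     p = 0.7
#     raw = p * (1 - p) ** jnp.arange(tournament_size)
#     tournament_probabilities = raw / raw.sum()
#
# so the tournament winner is usually, but not always, the best candidate drawn.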
82 | def evolve_population(population: Array,
83 | fitness: Array,
84 | key: PRNGKey,
85 | reproduction_type_probabilities: Array,
86 | reproduction_probability: float,
87 | tournament_probabilities: Array,
88 | population_indices: Array,
89 | population_size: int,
90 | tournament_size: int,
91 | num_trees: int,
92 | elite_size: int,
93 | reproduction_functions: List[Callable]) -> Array:
94 | """Reproduces pairs of candidates to obtain a new population.
95 |
96 | Parameters
97 | ----------
98 | population : Array
99 | Population of candidates.
100 | fitness : Array
101 | Fitness of candidates.
102 | key : PRNGKey
103 | Random key.
104 | reproduction_type_probabilities : Array
105 | Probability of each reproduction function to be applied.
106 | reproduction_probability : float
107 |         Probability of a tree being modified by the reproduction function.
108 | tournament_probabilities : Array
109 | Probability of each of the ranks in the tournament to be selected for reproduction.
110 | population_indices : Array
111 | Indices of the population.
112 | population_size : int
113 | Size of the population.
114 | tournament_size : int
115 | Size of the tournament.
116 | num_trees : int
117 | Number of trees in a candidate.
118 | elite_size : int
119 | Number of candidates that remain in the new population without reproduction.
120 | reproduction_functions : List[Callable]
121 | Functions that can be applied for reproduction.
122 |
123 | Returns
124 | -------
125 | Array
126 | Evolved population.
127 | """
128 | left_key, right_key, repro_key, evo_key = jr.split(key, 4)
129 | elite = population[jnp.argsort(fitness)[:elite_size]]
130 |
131 | # Sample parents for reproduction
132 | left_parents = jax.vmap(tournament_selection, in_axes=[None, None, 0, None, None, None])(population,
133 | fitness,
134 | jr.split(left_key, (population_size - elite_size)//2),
135 | tournament_probabilities,
136 | tournament_size,
137 | population_indices)
138 |
139 | right_parents = jax.vmap(tournament_selection, in_axes=[None, None, 0, None, None, None])(population,
140 | fitness,
141 | jr.split(right_key, (population_size - elite_size)//2),
142 | tournament_probabilities,
143 | tournament_size,
144 | population_indices)
145 | # Sample which reproduction function to apply to the parents
146 | reproduction_type = jr.choice(repro_key, jnp.arange(3), shape=((population_size - elite_size)//2,), p=reproduction_type_probabilities)
147 |
148 | left_children, right_children = jax.vmap(evolve_trees, in_axes=[0, 0, 0, 0, None, None])(left_parents,
149 | right_parents,
150 | jr.split(evo_key, ((population_size - elite_size)//2, num_trees, 2)),
151 | reproduction_type,
152 | reproduction_probability,
153 | reproduction_functions)
154 |
155 | evolved_population = jnp.concatenate([elite, left_children, right_children], axis=0)
156 | return evolved_population
157 |
158 | def migrate_population(receiver: Array,
159 | sender: Array,
160 | receiver_fitness: Array,
161 | sender_fitness: Array,
162 | migration_size: int,
163 | population_indices: Array) -> Tuple[Array, Array]:
164 | """Unfit candidates from one population are replaced with fit candidates from another population.
165 |
166 | Parameters
167 | ----------
168 | receiver : Array
169 | Population that receives new candidates.
170 | sender : Array
171 | Population that sends fit candidates.
172 | receiver_fitness : Array
173 | Fitness of the candidates in the receiving population.
174 | sender_fitness : Array
175 | Fitness of the candidates in the sending population.
176 | migration_size : int
177 | How many candidates are migrated to new population.
178 | population_indices : Array
179 | Indices of the population.
180 |
181 | Returns
182 | -------
183 | Tuple[Array, Array]
184 | Population after migration and their fitness.
185 | """
186 | sorted_receiver = receiver[jnp.argsort(receiver_fitness, descending=True)]
187 | sorted_sender = sender[jnp.argsort(sender_fitness, descending=False)]
188 | migrated_population = jnp.where((population_indices < migration_size)[:,None,None,None], sorted_sender, sorted_receiver)
189 | migrated_fitness = jnp.where(population_indices < migration_size, jnp.sort(sender_fitness, descending=False), jnp.sort(receiver_fitness, descending=True))
190 | return migrated_population, migrated_fitness
191 |
192 | def evolve_populations(jit_evolve_population: Callable,
193 | populations: Array,
194 | fitness: Array,
195 | key: PRNGKey,
196 | current_generation: int,
197 | migration_period: int,
198 | migration_size: int,
199 | reproduction_type_probabilities: Array,
200 | reproduction_probabilities: Array,
201 | tournament_probabilities: Array) -> Array:
202 | """Evolves each population independently.
203 |
204 | Parameters
205 | ----------
206 | jit_evolve_population : Callable
207 | Function for evolving trees that is jittable and parallelizable.
208 | populations : Array
209 | Populations of candidates.
210 | fitness : Array
211 | Fitness of candidates.
212 | key : PRNGKey
213 | Random key.
214 | current_generation : int
215 | Current generation number.
216 | migration_period : int
217 |         Number of generations between migration events.
218 | migration_size : int
219 | How many candidates are migrated to new population.
220 | reproduction_type_probabilities : Array
221 | Probability of each reproduction function to be applied.
222 | reproduction_probabilities : Array
223 |         Probability of a tree being modified by the reproduction function.
224 | tournament_probabilities : Array
225 | Probability of each of the ranks in the tournament to be selected for reproduction.
226 |
227 | Returns
228 | -------
229 | Array
230 | Evolved populations.
231 | """
232 | num_populations, population_size, num_trees, _, _ = populations.shape
233 | population_indices = jnp.arange(population_size)
234 |
235 | # Migrate candidates between populations. The populations and fitnesses are rolled for circular migration.
236 | populations, fitness = jax.lax.cond((num_populations > 1) & (((current_generation+1)%migration_period) == 0),
237 | jax.vmap(migrate_population, in_axes=[0, 0, 0, 0, None, None]),
238 | jax.vmap(lambda receiver, sender, receiver_fitness, sender_fitness, migration_size, population_indices: (receiver, receiver_fitness), in_axes=[0, 0, 0, 0, None, None]),
239 | populations,
240 | jnp.roll(populations, 1, axis=0),
241 | fitness,
242 | jnp.roll(fitness, 1, axis=0),
243 | migration_size,
244 | population_indices)
245 |
246 | new_population = jit_evolve_population(populations,
247 | fitness,
248 | jr.split(key, num_populations),
249 | reproduction_type_probabilities,
250 | reproduction_probabilities,
251 | tournament_probabilities,
252 | population_indices)
253 | return new_population
--------------------------------------------------------------------------------
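To make the migration semantics concrete, here is a small self-contained check (my own sketch; it assumes migrate_population is importable from kozax.genetic_operators.reproduction and that lower fitness is better, as in the elitism code above): the worst candidates of the receiving population are replaced by the best candidates of the sender.

import jax.numpy as jnp
from kozax.genetic_operators.reproduction import migrate_population

# Four dummy candidates per population; the trailing dimensions stand in for
# (num_trees, max_nodes, 4) and carry no meaning in this toy example.
receiver = jnp.arange(4, dtype=jnp.float32).reshape(4, 1, 1, 1)
sender = jnp.arange(10, 14, dtype=jnp.float32).reshape(4, 1, 1, 1)
receiver_fitness = jnp.array([0.1, 0.2, 0.3, 0.4])    # candidate 3 is the worst
sender_fitness = jnp.array([0.05, 0.15, 0.25, 0.35])  # candidate 10 is the best

migrated, fitness = migrate_population(receiver, sender,
                                       receiver_fitness, sender_fitness,
                                       1,                 # migration_size
                                       jnp.arange(4))     # population_indices
# The receiver's worst candidate (3, fitness 0.4) is gone; the sender's best
# (10, fitness 0.05) took the first slot: fitness == [0.05, 0.3, 0.2, 0.1].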