├── .gitignore ├── LICENSE.md ├── README.md ├── __init__.py ├── examples ├── example_notebooks │ ├── README.md │ ├── control_policy_optimization.ipynb │ ├── control_with_memory_SHO.ipynb │ ├── control_with_memory_acrobot.ipynb │ ├── objective_function_optimization.ipynb │ └── symbolic_regression_dynamical_system.ipynb └── example_scripts │ ├── control_policy_optimization.py │ ├── control_with_memory_SHO.py │ ├── control_with_memory_acrobot.py │ ├── objective_function_optimization.py │ └── symbolic_regression_dynamical_system.py ├── figures ├── applications.jpg └── logo.jpg └── kozax ├── environments ├── SR_environments │ ├── lorenz_attractor.py │ ├── lotka_volterra.py │ ├── time_series_environment_base.py │ └── vd_pol_oscillator.py └── control_environments │ ├── acrobot.py │ ├── cart_pole.py │ ├── control_environment_base.py │ ├── harmonic_oscillator.py │ └── reactor.py ├── fitness_functions ├── Gymnax_fitness_function.py ├── ODE_fitness_function.py ├── SR_fitness_function.py └── base_fitness_function.py ├── genetic_operators ├── crossover.py ├── initialization.py ├── mutation.py └── reproduction.py └── genetic_programming.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | version.py 3 | pyproject.toml 4 | dist/* -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | kozax: Genetic programming framework in JAX 2 | 3 | Copyright (c) 2024 sdevries0 4 | 5 | This work is licensed under the Creative Commons Attribution-NonCommercial-NoDerivs 4.0 International License. 6 | 7 | You are free to: 8 | Share — copy and redistribute the material in any medium or format 9 | The licensor cannot revoke these freedoms as long as you follow the license terms. 
10 | 11 | Under the following terms: 12 | Attribution — You must give appropriate credit, provide a link to the license, and indicate if changes were made. You may do so in any reasonable manner, but not in any way that suggests the licensor endorses you or your use. 13 | NonCommercial — You may not use the material for commercial purposes. 14 | NoDerivatives — If you remix, transform, or build upon the material, you may not distribute the modified material. 15 | No additional restrictions — You may not apply legal terms or technological measures that legally restrict others from doing anything the license permits. 16 | 17 | Notices: 18 | You do not have to comply with the license for elements of the material in the public domain or where your use is permitted by an applicable exception or limitation. 19 | No warranties are given. The license may not give you all of the permissions necessary for your intended use. For example, other rights such as publicity, privacy, or moral rights may limit how you use the material. 20 | 21 | License Details: 22 | This work is licensed under the Creative Commons Attribution-NonCommercial-NoDerivs 4.0 International License. To view a copy of this license, visit https://creativecommons.org/licenses/by-nc-nd/4.0/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 |
4 | 5 | # Kozax: Flexible and Scalable Genetic Programming in JAX 6 | Kozax introduces a general framework for evolving computer programs with genetic programming in JAX. With JAX, the computer programs can be vectorized and evaluated in parallel on CPU and GPU. Furthermore, just-in-time compilation provides massive speedups for evolving offspring. Check out the [paper](https://arxiv.org/abs/2502.03047) introducing Kozax. 7 | 8 | # Features 9 | Kozax allows the user to: 10 | - define custom operators 11 | - define custom fitness functions 12 | - use trees flexibly, ranging from symbolic regression to reinforcement learning 13 | - evolve multiple trees simultaneously, even with different inputs 14 | - numerically optimise constants in the computer programs 15 | 16 | # How to use 17 | You can install Kozax via pip with 18 | ``` 19 | pip install kozax 20 | ``` 21 | 22 | Below is a short demo showing how you can use kozax. First we generate data: 23 | ```python 24 | import jax 25 | import jax.numpy as jnp 26 | import jax.random as jr 27 | 28 | key = jr.PRNGKey(0) #Initialize key 29 | data_key, gp_key = jr.split(key) #Split key for data and genetic programming 30 | x = jr.uniform(data_key, shape=(30,), minval=-5, maxval = 5) #Inputs 31 | y = -0.1*x**3 + 0.3*x**2 + 1.5*x #Targets 32 | ``` 33 | 34 | Now we have to define a fitness function. This allows for much freedom, because you can use the computer program any way you want during evaluation. 35 | ```python 36 | from kozax.fitness_functions.base_fitness_function import BaseFitnessFunction 37 | 38 | class FitnessFunction(BaseFitnessFunction): 39 | """ 40 | The fitness function inherits the class BaseFitnessFunction and should implement the __call__ function, with the candidate, data and tree_evaluator as inputs. The tree_evaluator is used to compute the value of the candidate for each input. jax.vmap is used to vectorize the evaluation of the candidate over the inputs.
The candidate's predictions are used to compute the fitness value with the mean squared error. 41 | """ 42 | def __call__(self, candidate, data, tree_evaluator): 43 | X, Y = data 44 | predictions = jax.vmap(tree_evaluator, in_axes=[None, 0])(candidate, X) 45 | return jnp.mean(jnp.square(predictions-Y)) 46 | 47 | fitness_function = FitnessFunction() 48 | ``` 49 | 50 | Now we will use genetic programming to recover the equation from the data. This requires defining the hyperparameters, initializing the population and the general loop of evaluating and evolving the population. 51 | ```python 52 | from kozax.genetic_programming import GeneticProgramming 53 | 54 | #Define hyperparameters 55 | population_size = 500 56 | num_generations = 100 57 | 58 | #Initialize genetic programming strategy 59 | strategy = GeneticProgramming(num_generations, population_size, fitness_function) 60 | 61 | #Fit the strategy on the data. With verbose, we can print the intermediate solutions. 62 | strategy.fit(gp_key, (x, y), verbose = True) 63 | ``` 64 | 65 | There are additional [examples](https://github.com/sdevries0/kozax/tree/main/examples) on how to use kozax on more complex problems. 
66 | 67 | |Example|Notebook|Script| 68 | |---|---|---| 69 | |Symbolic regression of a dynamical system|[Notebook](examples/example_notebooks/symbolic_regression_dynamical_system.ipynb)|[Script](examples/example_scripts/symbolic_regression_dynamical_system.py)| 70 | |Control policy optimization in Gymnax environment|[Notebook](examples/example_notebooks/control_policy_optimization.ipynb)|[Script](examples/example_scripts/control_policy_optimization.py)| 71 | |Control policy optimization with dynamic memory|[Notebook (harmonic oscillator)](examples/example_notebooks/control_with_memory_SHO.ipynb), [Notebook (acrobot)](examples/example_notebooks/control_with_memory_acrobot.ipynb)|[Script (harmonic oscillator)](examples/example_scripts/control_with_memory_SHO.py), [Script (acrobot)](examples/example_scripts/control_with_memory_acrobot.py)| 72 | |Optimization of a loss function to train a neural network|[Notebook](examples/example_notebooks/objective_function_optimization.ipynb)|[Script](examples/example_scripts/objective_function_optimization.py)| 73 | 74 | 75 | # Citation 76 | If you make use of this code in your research paper, please cite: 77 | ``` 78 | @article{de2025kozax, 79 | title={Kozax: Flexible and Scalable Genetic Programming in JAX}, 80 | author={de Vries, Sigur and Keemink, Sander W and van Gerven, Marcel AJ}, 81 | journal={arXiv preprint arXiv:2502.03047}, 82 | year={2025} 83 | } 84 | ``` 85 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from .version import __version__ -------------------------------------------------------------------------------- /examples/example_notebooks/README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 |
4 | 5 | Here you can find several examples of applications of Kozax to various problems. These problems can easily be extended for other use cases. -------------------------------------------------------------------------------- /examples/example_notebooks/objective_function_optimization.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Objective function optimization\n", 8 | "\n", 9 | "In this example, Kozax is used to evolve a symbolic loss function to train a neural network. With each candidate loss function, a neural network is trained on the task of binary classification of XOR data points." 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": {}, 16 | "outputs": [ 17 | { 18 | "name": "stdout", 19 | "output_type": "stream", 20 | "text": [ 21 | "These device(s) are detected: [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7), CpuDevice(id=8), CpuDevice(id=9)]\n" 22 | ] 23 | } 24 | ], 25 | "source": [ 26 | "# Specify the cores to use for XLA\n", 27 | "import os\n", 28 | "os.environ[\"XLA_FLAGS\"] = '--xla_force_host_platform_device_count=10'\n", 29 | "\n", 30 | "import jax\n", 31 | "import jax.numpy as jnp\n", 32 | "import jax.random as jr\n", 33 | "import optax \n", 34 | "from typing import Callable, Tuple\n", 35 | "from jax import Array\n", 36 | "\n", 37 | "from kozax.genetic_programming import GeneticProgramming" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "metadata": {}, 43 | "source": [ 44 | "We define a fitness function class that includes the network initialization, training loop and weight updates. At every epoch, a new batch of data is sampled, and the fitness is computed as one minus the accuracy (i.e. the error rate) of the trained network on a validation set, so a lower fitness is better."
45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 2, 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "class FitnessFunction:\n", 54 | " \"\"\"\n", 55 | " A class to define the fitness function for evaluating candidate loss functions.\n", 56 | " The fitness is computed as the accuracy of a neural network trained with the candidate loss function\n", 57 | " on a binary classification task (XOR data).\n", 58 | "\n", 59 | " Attributes:\n", 60 | " input_dim (int): Dimension of the input data.\n", 61 | " hidden_dim (int): Dimension of the hidden layers in the neural network.\n", 62 | " output_dim (int): Dimension of the output.\n", 63 | " epochs (int): Number of training epochs.\n", 64 | " learning_rate (float): Learning rate for the optimizer.\n", 65 | " optim (optax.GradientTransformation): Optax optimizer instance.\n", 66 | " \"\"\"\n", 67 | " def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, epochs: int, learning_rate: float):\n", 68 | " self.input_dim = input_dim\n", 69 | " self.hidden_dim = hidden_dim\n", 70 | " self.output_dim = output_dim\n", 71 | " self.optim = optax.adam(learning_rate)\n", 72 | " self.epochs = epochs\n", 73 | "\n", 74 | " def __call__(self, candidate: str, data: Tuple[Array, Array, Array], tree_evaluator: Callable) -> Array:\n", 75 | " \"\"\"\n", 76 | " Computes the fitness of a candidate loss function.\n", 77 | "\n", 78 | " Args:\n", 79 | " candidate: The candidate loss function (symbolic tree).\n", 80 | " data (tuple): A tuple containing the data keys, test keys, and network keys.\n", 81 | " tree_evaluator: A function to evaluate the symbolic tree.\n", 82 | "\n", 83 | " Returns:\n", 84 | " Array: The mean loss (1 - accuracy) on the validation set.\n", 85 | " \"\"\"\n", 86 | " data_keys, test_keys, network_keys = data\n", 87 | " losses = jax.vmap(self.train, in_axes=[None, 0, 0, 0, None])(candidate, data_keys, test_keys, network_keys, tree_evaluator)\n", 88 | " return jnp.mean(losses)\n", 
89 | "\n", 90 | " def get_data(self, key: jr.PRNGKey, n_samples: int = 50) -> Tuple[Array, Array]:\n", 91 | " \"\"\"\n", 92 | " Generates XOR data.\n", 93 | "\n", 94 | " Args:\n", 95 | " key (jax.random.PRNGKey): Random key for data generation.\n", 96 | " n_samples (int): Number of samples to generate.\n", 97 | "\n", 98 | " Returns:\n", 99 | " tuple: A tuple containing the input data (x) and the target labels (y).\n", 100 | " \"\"\"\n", 101 | " x = jr.uniform(key, shape=(n_samples, 2))\n", 102 | " y = jnp.logical_xor(x[:,0]>0.5, x[:,1]>0.5)\n", 103 | "\n", 104 | " return x, y[:,None]\n", 105 | "\n", 106 | " def loss_function(self, params: Tuple[Array, Array, Array, Array, Array, Array], x: Array, y: Array, candidate: str, tree_evaluator: Callable) -> Array:\n", 107 | " \"\"\"\n", 108 | " Computes the loss with an evolved loss function for a given set of parameters and data.\n", 109 | "\n", 110 | " Args:\n", 111 | " params (tuple): The parameters of the neural network.\n", 112 | " x (Array): The input data.\n", 113 | " y (Array): The target labels.\n", 114 | " candidate: The candidate loss function (symbolic tree).\n", 115 | " tree_evaluator: A function to evaluate the symbolic tree.\n", 116 | "\n", 117 | " Returns:\n", 118 | " Array: The mean loss.\n", 119 | " \"\"\"\n", 120 | " pred = self.neural_network(params, x)\n", 121 | " return jnp.mean(jax.vmap(tree_evaluator, in_axes=[None, 0])(candidate, jnp.concatenate([pred, y], axis=-1)))\n", 122 | " \n", 123 | " def train(self, candidate: str, data_key: jr.PRNGKey, test_key: jr.PRNGKey, network_key: jr.PRNGKey, tree_evaluator: Callable) -> Array:\n", 124 | " \"\"\"\n", 125 | " Trains a neural network with a given candidate loss function.\n", 126 | "\n", 127 | " Args:\n", 128 | " candidate: The candidate loss function (symbolic tree).\n", 129 | " data_key (jax.random.PRNGKey): Random key for data generation during training.\n", 130 | " test_key (jax.random.PRNGKey): Random key for data generation during testing.\n", 
131 | " network_key (jax.random.PRNGKey): Random key for initializing the network parameters.\n", 132 | " tree_evaluator: A function to evaluate the symbolic tree.\n", 133 | "\n", 134 | " Returns:\n", 135 | " Array: The validation loss (1 - accuracy).\n", 136 | " \"\"\"\n", 137 | " params = self.init_network_params(network_key)\n", 138 | "\n", 139 | " optim_state = self.optim.init(params)\n", 140 | "\n", 141 | " def step(i: int, carry: Tuple[Tuple[Array, Array, Array, Array, Array, Array], optax._src.base.OptState, jr.PRNGKey]) -> Tuple[Tuple[Array, Array, Array, Array, Array, Array], optax._src.base.OptState, jr.PRNGKey]:\n", 142 | " params, optim_state, key = carry\n", 143 | "\n", 144 | " key, _key = jr.split(key)\n", 145 | "\n", 146 | " x_train, y_train = self.get_data(_key, n_samples=50)\n", 147 | "\n", 148 | " # Evaluate network parameters and compute gradients\n", 149 | " grads = jax.grad(self.loss_function)(params, x_train, y_train, candidate, tree_evaluator)\n", 150 | " \n", 151 | " # Update parameters\n", 152 | " updates, optim_state = self.optim.update(grads, optim_state, params)\n", 153 | " params = optax.apply_updates(params, updates)\n", 154 | "\n", 155 | " return (params, optim_state, key)\n", 156 | "\n", 157 | " (params, _, _) = jax.lax.fori_loop(0, self.epochs, step, (params, optim_state, data_key))\n", 158 | "\n", 159 | " # Evaluate parameters on test set\n", 160 | " x_test, y_test = self.get_data(test_key, n_samples=500)\n", 161 | "\n", 162 | " pred = self.neural_network(params, x_test)\n", 163 | " return 1 - jnp.mean(y_test==(pred>0.5)) # Return 1 - accuracy\n", 164 | "\n", 165 | " def neural_network(self, params: Tuple[Array, Array, Array, Array, Array, Array], x: Array) -> Array:\n", 166 | " \"\"\"\n", 167 | " Defines the neural network architecture (forward pass).\n", 168 | "\n", 169 | " Args:\n", 170 | " params (tuple): The parameters of the neural network.\n", 171 | " x (Array): The input data.\n", 172 | "\n", 173 | " Returns:\n", 174 | " 
Array: The output of the neural network.\n", 175 | " \"\"\"\n", 176 | " w1, b1, w2, b2, w3, b3 = params\n", 177 | " hidden = jnp.tanh(jnp.dot(x, w1) + b1)\n", 178 | " hidden = jnp.tanh(jnp.dot(hidden, w2) + b2)\n", 179 | " output = jnp.dot(hidden, w3) + b3\n", 180 | " return jax.nn.sigmoid(output)\n", 181 | "\n", 182 | " def init_network_params(self, key: jr.PRNGKey) -> Tuple[Array, Array, Array, Array, Array, Array]:\n", 183 | " \"\"\"\n", 184 | " Initializes the parameters of the neural network.\n", 185 | "\n", 186 | " Args:\n", 187 | " key (jax.random.PRNGKey): Random key for parameter initialization.\n", 188 | "\n", 189 | " Returns:\n", 190 | " tuple: A tuple containing the initialized weights and biases.\n", 191 | " \"\"\"\n", 192 | " key1, key2, key3 = jr.split(key, 3)\n", 193 | " w1 = jr.normal(key1, (self.input_dim, self.hidden_dim)) * jnp.sqrt(2.0 / self.input_dim)\n", 194 | " b1 = jnp.zeros(self.hidden_dim)\n", 195 | " w2 = jr.normal(key2, (self.hidden_dim, self.hidden_dim)) * jnp.sqrt(2.0 / self.hidden_dim)\n", 196 | " b2 = jnp.zeros(self.hidden_dim)\n", 197 | " w3 = jr.normal(key3, (self.hidden_dim, self.output_dim)) * jnp.sqrt(2.0 / self.hidden_dim)\n", 198 | " b3 = jnp.zeros(self.output_dim)\n", 199 | " return (w1, b1, w2, b2, w3, b3)" 200 | ] 201 | }, 202 | { 203 | "cell_type": "markdown", 204 | "metadata": {}, 205 | "source": [ 206 | "To make sure the optimized loss function generalizes, a batch of neural networks are trained with different data and weight initialization. For this purpose, a batch of keys for initialization, data sampling and validation data are generated." 
207 | ] 208 | }, 209 | { 210 | "cell_type": "code", 211 | "execution_count": 3, 212 | "metadata": {}, 213 | "outputs": [], 214 | "source": [ 215 | "def generate_keys(key, batch_size=4):\n", 216 | " key1, key2, key3 = jr.split(key, 3)\n", 217 | " return jr.split(key1, batch_size), jr.split(key2, batch_size), jr.split(key3, batch_size)" 218 | ] 219 | }, 220 | { 221 | "cell_type": "markdown", 222 | "metadata": {}, 223 | "source": [ 224 | "Here we define the hyperparameters and inputs to the genetic programming algorithm. The inputs to the trees are the prediction and target value. " 225 | ] 226 | }, 227 | { 228 | "cell_type": "code", 229 | "execution_count": 4, 230 | "metadata": {}, 231 | "outputs": [ 232 | { 233 | "name": "stdout", 234 | "output_type": "stream", 235 | "text": [ 236 | "Input data should be formatted as: ['pred', 'y'].\n", 237 | "In generation 1, best fitness = 0.4820, best solution = log(pred*y + pred - y)\n", 238 | "In generation 2, best fitness = 0.4560, best solution = y*(y - 0.257)*log(pred*y + pred - y)\n", 239 | "In generation 3, best fitness = 0.4560, best solution = y*(y - 0.257)*log(pred*y + pred - y)\n", 240 | "In generation 4, best fitness = 0.4560, best solution = y*(y - 0.257)*log(pred*y + pred - y)\n", 241 | "In generation 5, best fitness = 0.4560, best solution = y*(y - 0.257)*log(pred*y + pred - y)\n", 242 | "In generation 6, best fitness = 0.3335, best solution = 2*pred*(0.425 - y)\n", 243 | "In generation 7, best fitness = 0.3335, best solution = 2*pred*(0.425 - y)\n", 244 | "In generation 8, best fitness = 0.3335, best solution = pred*(0.425 - y)\n", 245 | "In generation 9, best fitness = 0.3335, best solution = pred*(0.425 - y)\n", 246 | "In generation 10, best fitness = 0.3300, best solution = (0.425 - y)*log(pred + 0.793)\n", 247 | "In generation 11, best fitness = 0.1045, best solution = 3*pred*(pred - y - 0.597)\n", 248 | "In generation 12, best fitness = 0.1045, best solution = 3*pred*(pred - y - 0.597)\n", 249 | "In 
generation 13, best fitness = 0.1045, best solution = 3*pred*(pred - y - 0.597)\n", 250 | "In generation 14, best fitness = 0.0955, best solution = (2*pred + 0.107)*(pred - y - 0.597)\n", 251 | "In generation 15, best fitness = 0.0955, best solution = (2*pred + 0.107)*(pred - y - 0.597)\n", 252 | "In generation 16, best fitness = 0.0915, best solution = (pred - 2*y + 0.107)*log(pred + 0.793)\n", 253 | "In generation 17, best fitness = 0.0915, best solution = (pred - 2*y + 0.107)*log(pred + 0.793)\n", 254 | "In generation 18, best fitness = 0.0915, best solution = (pred - 2*y + 0.107)*log(pred + 0.793)\n", 255 | "In generation 19, best fitness = 0.0915, best solution = (pred - 2*y + 0.107)*log(pred + 0.793)\n", 256 | "In generation 20, best fitness = 0.0895, best solution = (pred + 0.107)*(pred - 2*y - 0.103)\n", 257 | "In generation 21, best fitness = 0.0865, best solution = pred*(pred - y - log(pred + 0.793) + 0.107)\n", 258 | "In generation 22, best fitness = 0.0850, best solution = (pred - 0.211)*(pred - y - log(pred + 0.793) + 0.107)\n", 259 | "In generation 23, best fitness = 0.0850, best solution = (pred - 0.211)*(pred - y - log(pred + 0.793) + 0.107)\n", 260 | "In generation 24, best fitness = 0.0850, best solution = (pred - 0.211)*(pred - y - log(pred + 0.793) + 0.107)\n", 261 | "In generation 25, best fitness = 0.0850, best solution = (pred - 0.211)*(pred - y - log(pred + 0.793) + 0.107)\n", 262 | "Complexity: 1, fitness: 0.4909999668598175, equations: nan\n", 263 | "Complexity: 4, fitness: 0.4714999794960022, equations: log(0.25 - pred)\n", 264 | "Complexity: 6, fitness: 0.4139999747276306, equations: log(-pred + y - 0.496)\n", 265 | "Complexity: 7, fitness: 0.119999960064888, equations: pred*(pred - y - 0.354)\n", 266 | "Complexity: 9, fitness: 0.09499996155500412, equations: pred*(pred - y - 0.457)\n", 267 | "Complexity: 11, fitness: 0.09149995446205139, equations: (pred - 0.211)*(pred - 2*y + 0.107)\n", 268 | "Complexity: 12, fitness: 
0.08649995177984238, equations: pred*(pred - y - log(pred + 0.793) + 0.107)\n", 269 | "Complexity: 14, fitness: 0.0849999487400055, equations: (pred - 0.211)*(pred - y - log(pred + 0.793) + 0.107)\n" 270 | ] 271 | } 272 | ], 273 | "source": [ 274 | "key = jr.PRNGKey(0)\n", 275 | "data_key, gp_key = jr.split(key)\n", 276 | "\n", 277 | "population_size = 50\n", 278 | "num_populations = 5\n", 279 | "num_generations = 25\n", 280 | "\n", 281 | "operator_list = [(\"+\", lambda x, y: jnp.add(x, y), 2, 0.5), \n", 282 | " (\"-\", lambda x, y: jnp.subtract(x, y), 2, 0.5),\n", 283 | " (\"*\", lambda x, y: jnp.multiply(x, y), 2, 0.5),\n", 284 | " (\"log\", lambda x: jnp.log(x + 1e-7), 1, 0.1),\n", 285 | " ]\n", 286 | "\n", 287 | "variable_list = [[\"pred\", \"y\"]]\n", 288 | "\n", 289 | "input_dim = 2\n", 290 | "hidden_dim = 16\n", 291 | "output_dim = 1\n", 292 | "\n", 293 | "fitness_function = FitnessFunction(input_dim, hidden_dim, output_dim, learning_rate=0.01, epochs=100)\n", 294 | "\n", 295 | "strategy = GeneticProgramming(num_generations, population_size, fitness_function, operator_list, variable_list, num_populations = num_populations)\n", 296 | "\n", 297 | "data_keys, test_keys, network_keys = generate_keys(data_key)\n", 298 | "\n", 299 | "strategy.fit(gp_key, (data_keys, test_keys, network_keys), verbose=True)" 300 | ] 301 | } 302 | ], 303 | "metadata": { 304 | "kernelspec": { 305 | "display_name": "Python 3", 306 | "language": "python", 307 | "name": "python3" 308 | }, 309 | "language_info": { 310 | "codemirror_mode": { 311 | "name": "ipython", 312 | "version": 3 313 | }, 314 | "file_extension": ".py", 315 | "mimetype": "text/x-python", 316 | "name": "python", 317 | "nbconvert_exporter": "python", 318 | "pygments_lexer": "ipython3", 319 | "version": "3.11.4" 320 | }, 321 | "orig_nbformat": 4 322 | }, 323 | "nbformat": 4, 324 | "nbformat_minor": 2 325 | } 326 | -------------------------------------------------------------------------------- 
/examples/example_notebooks/symbolic_regression_dynamical_system.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Symbolic regression of a dynamical system\n", 8 | "\n", 9 | "In this example, Kozax is applied to recover the state equations of the Lotka-Volterra system. The candidate solutions are integrated as a system of differential equations, after which the predictions are compared to the true observations to determine a fitness score." 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 1, 15 | "metadata": {}, 16 | "outputs": [ 17 | { 18 | "name": "stdout", 19 | "output_type": "stream", 20 | "text": [ 21 | "These device(s) are detected: [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7), CpuDevice(id=8), CpuDevice(id=9)]\n" 22 | ] 23 | } 24 | ], 25 | "source": [ 26 | "# Specify the cores to use for XLA\n", 27 | "import os\n", 28 | "os.environ[\"XLA_FLAGS\"] = '--xla_force_host_platform_device_count=10'\n", 29 | "\n", 30 | "import jax\n", 31 | "import diffrax\n", 32 | "import jax.numpy as jnp\n", 33 | "import jax.random as jr\n", 34 | "import diffrax\n", 35 | "\n", 36 | "from kozax.genetic_programming import GeneticProgramming\n", 37 | "from kozax.fitness_functions.ODE_fitness_function import ODEFitnessFunction\n", 38 | "from kozax.environments.SR_environments.lotka_volterra import LotkaVolterra" 39 | ] 40 | }, 41 | { 42 | "cell_type": "markdown", 43 | "metadata": {}, 44 | "source": [ 45 | "First the data is generated, consisting of initial conditions, time points and the true observations. Kozax provides the Lotka-Volterra environment, which is integrated with Diffrax." 
46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 2, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "def get_data(key, env, dt, T, batch_size=20):\n", 55 | " x0s = env.sample_init_states(batch_size, key)\n", 56 | " ts = jnp.arange(0, T, dt)\n", 57 | "\n", 58 | " def solve(env, ts, x0):\n", 59 | " solver = diffrax.Dopri5()\n", 60 | " dt0 = 0.001\n", 61 | " saveat = diffrax.SaveAt(ts=ts)\n", 62 | "\n", 63 | " system = diffrax.ODETerm(env.drift)\n", 64 | "\n", 65 | " # Solve the system given an initial conditions\n", 66 | " sol = diffrax.diffeqsolve(system, solver, ts[0], ts[-1], dt0, x0, saveat=saveat, max_steps=500, \n", 67 | " adjoint=diffrax.DirectAdjoint(), stepsize_controller=diffrax.PIDController(atol=1e-7, rtol=1e-7, dtmin=0.001))\n", 68 | " \n", 69 | " return sol.ys\n", 70 | "\n", 71 | " ys = jax.vmap(solve, in_axes=[None, None, 0])(env, ts, x0s) #Parallelize over the batch dimension\n", 72 | " \n", 73 | " return x0s, ts, ys\n", 74 | "\n", 75 | "key = jr.PRNGKey(0)\n", 76 | "data_key, gp_key = jr.split(key)\n", 77 | "\n", 78 | "T = 30\n", 79 | "dt = 0.2\n", 80 | "env = LotkaVolterra()\n", 81 | "\n", 82 | "# Simulate the data\n", 83 | "data = get_data(data_key, env, dt, T, batch_size=4)\n", 84 | "x0s, ts, ys = data" 85 | ] 86 | }, 87 | { 88 | "cell_type": "markdown", 89 | "metadata": {}, 90 | "source": [ 91 | "For the fitness function, we used the ODEFitnessFunction that uses Diffrax to integrate candidate solutions. It is possible to select the solver, time step, number of steps and a stepsize controller to balance efficiency and accuracy. To ensure convergence of the genetic programming algorithm, constant optimization is applied to the best candidates at every generation. The constant optimization is performed with a couple of simple evolutionary steps that adjust the values of the constants in a candidate. 
The hyperparameters that define the constant optimization are `constant_optimization_N_offspring` (the number of candidates with different constants that are sampled for each candidate), `constant_optimization_steps` (number of iterations of constant optimization for each candidate), `optimize_constants_elite` (number of candidates that constant optimization is applied to), `constant_step_size_init` (initial value of the step size for sampling constants) and `constant_step_size_decay` (the rate of decrease of the step size over generations)." 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": 3, 97 | "metadata": {}, 98 | "outputs": [ 99 | { 100 | "name": "stdout", 101 | "output_type": "stream", 102 | "text": [ 103 | "Input data should be formatted as: ['x0', 'x1'].\n" 104 | ] 105 | } 106 | ], 107 | "source": [ 108 | "#Define the nodes and hyperparameters\n", 109 | "operator_list = [\n", 110 | " (\"+\", lambda x, y: jnp.add(x, y), 2, 0.5), \n", 111 | " (\"-\", lambda x, y: jnp.subtract(x, y), 2, 0.1), \n", 112 | " (\"*\", lambda x, y: jnp.multiply(x, y), 2, 0.5), \n", 113 | " ]\n", 114 | "\n", 115 | "variable_list = [[\"x\" + str(i) for i in range(env.n_var)]]\n", 116 | "layer_sizes = jnp.array([env.n_var])\n", 117 | "\n", 118 | "population_size = 100\n", 119 | "num_populations = 10\n", 120 | "num_generations = 50\n", 121 | "\n", 122 | "#Initialize the fitness function and the genetic programming strategy\n", 123 | "fitness_function = ODEFitnessFunction(solver=diffrax.Dopri5(), dt0 = 0.01, stepsize_controller=diffrax.PIDController(atol=1e-6, rtol=1e-6, dtmin=0.001), max_steps=300)\n", 124 | "\n", 125 | "strategy = GeneticProgramming(num_generations, population_size, fitness_function, operator_list, variable_list, layer_sizes, num_populations = num_populations,\n", 126 | " size_parsimony=0.003, constant_optimization_method=\"evolution\", constant_optimization_N_offspring = 50, constant_optimization_steps = 3, \n", 127 | "
optimize_constants_elite=100, constant_step_size_init=0.1, constant_step_size_decay=0.99)" 128 | ] 129 | }, 130 | { 131 | "cell_type": "markdown", 132 | "metadata": {}, 133 | "source": [ 134 | "Kozax provides a fit function that receives the data and a random key. However, it is also possible to run Kozax with an easy loop consisting of evaluating and evolving. This is useful as different input data can be provided during evaluation. In symbolic regression of dynamical systems, it helps to first optimize on a small part of the time points, and provide the full data trajectories only after a couple of generations." 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": 4, 140 | "metadata": {}, 141 | "outputs": [ 142 | { 143 | "name": "stdout", 144 | "output_type": "stream", 145 | "text": [ 146 | "In generation 1, best fitness = 1.3755, best solution = [1.29 - 2.42*x0, 0.0769 - 0.336*x1]\n", 147 | "In generation 2, best fitness = 1.3496, best solution = [-0.893*x0*x1 + 1.1, 1.6*x0 - 2.06]\n", 148 | "In generation 3, best fitness = 1.3289, best solution = [-0.331*x0*x1, x0 - 0.413*x1 + 0.0743]\n", 149 | "In generation 4, best fitness = 1.3097, best solution = [x1*(-0.35*x0 - 0.246) + 1.23, 0.364 - 0.36*x1]\n", 150 | "In generation 5, best fitness = 1.2857, best solution = [-1.6*x0 - 0.248*x1 + 1.4, x0 - 0.415*x1 - 0.35]\n", 151 | "In generation 6, best fitness = 1.2720, best solution = [-1.72*x0 - 0.245*x1 + 1.42, 0.822*x0 - 0.384*x1 - 0.265]\n", 152 | "In generation 7, best fitness = 1.2277, best solution = [-1.16*x0*x1 + x0 + 0.551, 0.107*x0 - 0.228*x1 - 0.16]\n", 153 | "In generation 8, best fitness = 1.2014, best solution = [-0.639*x0*x1 + x0 + 0.461, 0.146*x0 - 0.197*x1 - 0.208]\n", 154 | "In generation 9, best fitness = 0.9929, best solution = [-2.6*x0*(-0.0276*x0 + 0.273*x1) + x0 + 0.461, 0.146*x0 - 0.197*x1 - 0.208]\n", 155 | "In generation 10, best fitness = 0.8489, best solution = [-2.56*x0*(-0.0312*x0 + 0.256*x1) + x0 + 0.245, 0.172*x0 
- 0.197*x1 - 0.239]\n", 156 | "In generation 11, best fitness = 0.8489, best solution = [-2.56*x0*(-0.0312*x0 + 0.256*x1) + x0 + 0.245, 0.172*x0 - 0.197*x1 - 0.239]\n", 157 | "In generation 12, best fitness = 0.8420, best solution = [-2.56*x0*(-0.0312*x0 + 0.238*x1) + x0 + 0.245, 0.172*x0 - 0.197*x1 - 0.239]\n", 158 | "In generation 13, best fitness = 0.8420, best solution = [-2.56*x0*(-0.0312*x0 + 0.238*x1) + x0 + 0.245, 0.172*x0 - 0.197*x1 - 0.239]\n", 159 | "In generation 14, best fitness = 0.8420, best solution = [-2.56*x0*(-0.0312*x0 + 0.238*x1) + x0 + 0.245, 0.172*x0 - 0.197*x1 - 0.239]\n", 160 | "In generation 15, best fitness = 0.8420, best solution = [-2.56*x0*(-0.0312*x0 + 0.238*x1) + x0 + 0.245, 0.172*x0 - 0.197*x1 - 0.239]\n", 161 | "In generation 16, best fitness = 0.8212, best solution = [-2.62*x0*(0.158*x1 + 0.101) + x0 + 0.098, 0.175*x0 - 0.0532*x1 - 0.88]\n", 162 | "In generation 17, best fitness = 0.7373, best solution = [-0.74*x0*(x1 - 2.42), 0.122*x0*x1 - 0.00769*x0 - 0.36*x1]\n", 163 | "In generation 18, best fitness = 0.4043, best solution = [-0.42*x0*(x1 - 2.5), 0.0866*x0*x1 - 0.406*x1]\n", 164 | "In generation 19, best fitness = 0.2175, best solution = [-0.391*x0*(x1 - 2.5), 0.11*x0*x1 - 0.458*x1 + 0.0161]\n", 165 | "In generation 20, best fitness = 0.2129, best solution = [-0.402*x0*(x1 - 2.46), 0.117*x0*x1 - 0.463*x1]\n", 166 | "In generation 21, best fitness = 0.2129, best solution = [-0.402*x0*(x1 - 2.46), 0.117*x0*x1 - 0.463*x1]\n", 167 | "In generation 22, best fitness = 0.2129, best solution = [-0.402*x0*(x1 - 2.46), 0.117*x0*x1 - 0.463*x1]\n", 168 | "In generation 23, best fitness = 0.1362, best solution = [-0.409*x0*(x1 - 2.75), 0.0999*x0*x1 - 0.403*x1 + 0.0227]\n", 169 | "In generation 24, best fitness = 0.1362, best solution = [-0.409*x0*(x1 - 2.75), 0.0999*x0*x1 - 0.403*x1 + 0.0227]\n", 170 | "In generation 25, best fitness = 0.1362, best solution = [-0.409*x0*(x1 - 2.75), 0.0999*x0*x1 - 0.403*x1 + 0.0227]\n", 171 | "In 
generation 26, best fitness = 0.1218, best solution = [-0.393*x0*(x1 - 2.75), 0.102*x0*x1 - 0.406*x1]\n", 172 | "In generation 27, best fitness = 0.1218, best solution = [-0.393*x0*(x1 - 2.75), 0.102*x0*x1 - 0.406*x1]\n", 173 | "In generation 28, best fitness = 0.1218, best solution = [-0.393*x0*(x1 - 2.75), 0.102*x0*x1 - 0.406*x1]\n", 174 | "In generation 29, best fitness = 0.1218, best solution = [-0.393*x0*(x1 - 2.75), 0.102*x0*x1 - 0.406*x1]\n", 175 | "In generation 30, best fitness = 0.1218, best solution = [-0.393*x0*(x1 - 2.75), 0.102*x0*x1 - 0.406*x1]\n", 176 | "In generation 31, best fitness = 0.1218, best solution = [-0.393*x0*(x1 - 2.75), 0.102*x0*x1 - 0.406*x1]\n", 177 | "In generation 32, best fitness = 0.1218, best solution = [-0.393*x0*(x1 - 2.75), 0.102*x0*x1 - 0.406*x1]\n", 178 | "In generation 33, best fitness = 0.1218, best solution = [-0.393*x0*(x1 - 2.75), 0.102*x0*x1 - 0.406*x1]\n", 179 | "In generation 34, best fitness = 0.1189, best solution = [-0.41*x0*(x1 - 2.7), 0.105*x0*x1 - 0.407*x1]\n", 180 | "In generation 35, best fitness = 0.0978, best solution = [-0.395*x0*(x1 - 2.74), 0.0977*x0*x1 - 0.402*x1]\n", 181 | "In generation 36, best fitness = 0.0978, best solution = [-0.395*x0*(x1 - 2.74), 0.0977*x0*x1 - 0.402*x1]\n", 182 | "In generation 37, best fitness = 0.0978, best solution = [-0.395*x0*(x1 - 2.74), 0.0977*x0*x1 - 0.402*x1]\n", 183 | "In generation 38, best fitness = 0.0761, best solution = [-0.402*x0*(x1 - 2.76), 0.0993*x0*x1 - 0.397*x1]\n", 184 | "In generation 39, best fitness = 0.0761, best solution = [-0.402*x0*(x1 - 2.76), 0.0993*x0*x1 - 0.397*x1]\n", 185 | "In generation 40, best fitness = 0.0761, best solution = [-0.402*x0*(x1 - 2.76), 0.0993*x0*x1 - 0.397*x1]\n", 186 | "In generation 41, best fitness = 0.0761, best solution = [-0.402*x0*(x1 - 2.76), 0.0993*x0*x1 - 0.397*x1]\n", 187 | "In generation 42, best fitness = 0.0761, best solution = [-0.402*x0*(x1 - 2.76), 0.0993*x0*x1 - 0.397*x1]\n", 188 | "In generation 43, best 
fitness = 0.0761, best solution = [-0.402*x0*(x1 - 2.76), 0.0993*x0*x1 - 0.397*x1]\n", 189 | "In generation 44, best fitness = 0.0761, best solution = [-0.402*x0*(x1 - 2.76), 0.0993*x0*x1 - 0.397*x1]\n", 190 | "In generation 45, best fitness = 0.0761, best solution = [-0.402*x0*(x1 - 2.76), 0.0993*x0*x1 - 0.397*x1]\n", 191 | "In generation 46, best fitness = 0.0761, best solution = [-0.402*x0*(x1 - 2.76), 0.0993*x0*x1 - 0.397*x1]\n", 192 | "In generation 47, best fitness = 0.0761, best solution = [-0.402*x0*(x1 - 2.76), 0.0993*x0*x1 - 0.397*x1]\n", 193 | "In generation 48, best fitness = 0.0761, best solution = [-0.402*x0*(x1 - 2.76), 0.0993*x0*x1 - 0.397*x1]\n", 194 | "In generation 49, best fitness = 0.0761, best solution = [-0.402*x0*(x1 - 2.76), 0.0993*x0*x1 - 0.397*x1]\n", 195 | "In generation 50, best fitness = 0.0761, best solution = [-0.402*x0*(x1 - 2.76), 0.0993*x0*x1 - 0.397*x1]\n" 196 | ] 197 | } 198 | ], 199 | "source": [ 200 | "# Sample the initial population\n", 201 | "population = strategy.initialize_population(gp_key)\n", 202 | "\n", 203 | "# Define the number of timepoints to include in the data\n", 204 | "end_ts = int(ts.shape[0]/2)\n", 205 | "\n", 206 | "for g in range(num_generations):\n", 207 | " if g == 25: # After 25 generations, use the full data\n", 208 | " end_ts = ts.shape[0]\n", 209 | "\n", 210 | " key, eval_key, sample_key = jr.split(key, 3)\n", 211 | " # Evaluate the population on the data, and return the fitness\n", 212 | " fitness, population = strategy.evaluate_population(population, (x0s, ts[:end_ts], ys[:,:end_ts]), eval_key)\n", 213 | "\n", 214 | " # Print the best solution in the population in this generation\n", 215 | " best_fitness, best_solution = strategy.get_statistics(g)\n", 216 | " print(f\"In generation {g+1}, best fitness = {best_fitness:.4f}, best solution = {strategy.expression_to_string(best_solution)}\")\n", 217 | "\n", 218 | " # Evolve the population until the last generation. 
The fitness should be given to the evolve function.\n", 219 | " if g < (num_generations-1):\n", 220 | " population = strategy.evolve_population(population, fitness, sample_key)" 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": 5, 226 | "metadata": {}, 227 | "outputs": [ 228 | { 229 | "name": "stdout", 230 | "output_type": "stream", 231 | "text": [ 232 | "Complexity: 2, fitness: 2.8892037868499756, equations: [-0.776, -0.763]\n", 233 | "Complexity: 4, fitness: 1.9611964225769043, equations: [-2.14*x0, -0.682]\n", 234 | "Complexity: 6, fitness: 1.3917429447174072, equations: [-1.98*x0, -0.285*x1]\n", 235 | "Complexity: 8, fitness: 1.3050819635391235, equations: [-2.23*x0, x0 - 0.41*x1]\n", 236 | "Complexity: 10, fitness: 1.2563836574554443, equations: [0.835 - 2.31*x0, x0 - 0.537*x1]\n", 237 | "Complexity: 12, fitness: 1.2215551137924194, equations: [1.11 - 2.31*x0, x0 - 0.45*x1 - 0.276]\n", 238 | "Complexity: 14, fitness: 0.8544017672538757, equations: [-0.845*x0*(x1 - 2.48), 0.141*x0 - 0.271*x1]\n", 239 | "Complexity: 16, fitness: 0.06887245923280716, equations: [-0.415*x0*(x1 - 2.7), 0.104*x0*x1 - 0.404*x1]\n", 240 | "Complexity: 18, fitness: 0.06406070291996002, equations: [-0.398*x0*(x1 - 2.67), 0.103*x0*x1 - 0.415*x1 - 0.000306]\n", 241 | "Complexity: 20, fitness: 0.016078392043709755, equations: [-0.402*x0*(x1 - 2.76), 0.0993*x0*x1 - 0.397*x1]\n" 242 | ] 243 | } 244 | ], 245 | "source": [ 246 | "strategy.print_pareto_front()" 247 | ] 248 | }, 249 | { 250 | "cell_type": "markdown", 251 | "metadata": {}, 252 | "source": [ 253 | "Instead of using evolution to optimize the constants, Kozax also offers gradient-based optimization. For gradient optimization, it is possible to specify the optimizer, the number of candidates to apply constant optimization to, the initial learning rate and the learning rate decay over generations. Both methods are provided because either one can be more effective or efficient, depending on the problem."
254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "execution_count": 6, 259 | "metadata": {}, 260 | "outputs": [ 261 | { 262 | "name": "stdout", 263 | "output_type": "stream", 264 | "text": [ 265 | "Input data should be formatted as: ['x0', 'x1'].\n" 266 | ] 267 | } 268 | ], 269 | "source": [ 270 | "import optax\n", 271 | "\n", 272 | "strategy = GeneticProgramming(num_generations, population_size, fitness_function, operator_list, variable_list, layer_sizes, num_populations = num_populations,\n", 273 | " size_parsimony=0.003, constant_optimization_method=\"gradient\", constant_optimization_steps = 15, optimizer_class = optax.adam,\n", 274 | " optimize_constants_elite=100, constant_step_size_init=0.025, constant_step_size_decay=0.95)" 275 | ] 276 | }, 277 | { 278 | "cell_type": "code", 279 | "execution_count": 7, 280 | "metadata": {}, 281 | "outputs": [ 282 | { 283 | "name": "stdout", 284 | "output_type": "stream", 285 | "text": [ 286 | "In generation 1, best fitness = 1.4242, best solution = [1.41 - 2.33*x0, -0.226*x1 - 0.145]\n", 287 | "In generation 2, best fitness = 1.3698, best solution = [1.16 - 2.29*x0, 0.113 - 0.317*x1]\n", 288 | "In generation 3, best fitness = 1.2569, best solution = [-0.49*x0*x1 + 0.674, 0.907*x0 - 0.261*x1 - 0.681]\n", 289 | "In generation 4, best fitness = 1.1925, best solution = [-0.283*x0*x1 + 0.391, 0.882*x0 - 0.272*x1 - 0.805]\n", 290 | "In generation 5, best fitness = 1.1737, best solution = [-0.276*x0*x1 + 0.392, 0.729*x0 - 0.286*x1 - 0.678]\n", 291 | "In generation 6, best fitness = 1.1737, best solution = [-0.276*x0*x1 + 0.392, 0.729*x0 - 0.286*x1 - 0.678]\n", 292 | "In generation 7, best fitness = 1.1177, best solution = [(1.41 - 0.532*x1)*(x0 - 0.218), x0 - 0.427*x1]\n", 293 | "In generation 8, best fitness = 1.0575, best solution = [2*x0*(0.495 - 0.22*x1), (-1.18*x0 + x1)*(0.00345*x0 - 0.399)]\n", 294 | "In generation 9, best fitness = 0.9400, best solution = [(0.686 - 0.268*x1)*(2*x0 + 0.161), (-1.26*x0 + 
x1)*(0.0518*x0 - 0.375)]\n", 295 | "In generation 10, best fitness = 0.9385, best solution = [(0.655 - 0.258*x1)*(2*x0 + 0.144), (-1.13*x0 + x1)*(0.0412*x0 - 0.371)]\n", 296 | "In generation 11, best fitness = 0.9128, best solution = [(0.453 - 0.175*x1)*(4*x0 + 0.168), (-1.14*x0 + x1)*(0.0378*x0 - 0.361)]\n", 297 | "In generation 12, best fitness = 0.6460, best solution = [(0.649 - 0.251*x1)*(2*x0 + 0.15), (0.097*x0 - 0.401)*(0.127*x0 + x1 - 0.342)]\n", 298 | "In generation 13, best fitness = 0.3462, best solution = [(0.543 - 0.205*x1)*(2*x0 + 0.111), (0.102*x0 - 0.448)*(0.0216*x0 + x1 - 0.267)]\n", 299 | "In generation 14, best fitness = 0.3462, best solution = [(0.543 - 0.205*x1)*(2*x0 + 0.111), (0.102*x0 - 0.448)*(0.0216*x0 + x1 - 0.267)]\n", 300 | "In generation 15, best fitness = 0.3462, best solution = [(0.543 - 0.205*x1)*(2*x0 + 0.111), (0.102*x0 - 0.448)*(0.0216*x0 + x1 - 0.267)]\n", 301 | "In generation 16, best fitness = 0.3462, best solution = [(0.543 - 0.205*x1)*(2*x0 + 0.111), (0.102*x0 - 0.448)*(0.0216*x0 + x1 - 0.267)]\n", 302 | "In generation 17, best fitness = 0.1724, best solution = [(0.557 - 0.21*x1)*(2*x0 + 0.0978), (0.111*x0 - 0.453)*(x1 - 0.294)]\n", 303 | "In generation 18, best fitness = 0.1724, best solution = [(0.557 - 0.21*x1)*(2*x0 + 0.0978), (0.111*x0 - 0.453)*(x1 - 0.294)]\n", 304 | "In generation 19, best fitness = 0.1724, best solution = [(0.557 - 0.21*x1)*(2*x0 + 0.0978), (0.111*x0 - 0.453)*(x1 - 0.294)]\n", 305 | "In generation 20, best fitness = 0.1724, best solution = [(0.557 - 0.21*x1)*(2*x0 + 0.0978), (0.111*x0 - 0.453)*(x1 - 0.294)]\n", 306 | "In generation 21, best fitness = 0.1688, best solution = [(0.566 - 0.215*x1)*(2*x0 - 0.098), x1*(0.113*x0 - 0.423)]\n", 307 | "In generation 22, best fitness = 0.1231, best solution = [(0.531 - 0.201*x1)*(2*x0 - 0.021), x1*(0.107*x0 - 0.424)]\n", 308 | "In generation 23, best fitness = 0.1231, best solution = [(0.531 - 0.201*x1)*(2*x0 - 0.021), x1*(0.107*x0 - 0.424)]\n", 309 | "In 
generation 24, best fitness = 0.1231, best solution = [(0.531 - 0.201*x1)*(2*x0 - 0.021), x1*(0.107*x0 - 0.424)]\n", 310 | "In generation 25, best fitness = 0.1231, best solution = [(0.531 - 0.201*x1)*(2*x0 - 0.021), x1*(0.107*x0 - 0.424)]\n", 311 | "In generation 26, best fitness = 0.1189, best solution = [(0.53 - 0.2*x1)*(2*x0 - 0.0242), x1*(0.104*x0 - 0.421)]\n", 312 | "In generation 27, best fitness = 0.1177, best solution = [(0.574 - 0.218*x1)*(1.87*x0 - 0.06), x1*(0.105*x0 - 0.422)]\n", 313 | "In generation 28, best fitness = 0.1133, best solution = [(0.539 - 0.203*x1)*(2*x0 - 0.0842), x1*(0.105*x0 - 0.421)]\n", 314 | "In generation 29, best fitness = 0.1107, best solution = [1.87*x0*(0.572 - 0.217*x1), x1*(0.105*x0 - 0.42)]\n", 315 | "In generation 30, best fitness = 0.1107, best solution = [1.87*x0*(0.572 - 0.217*x1), x1*(0.105*x0 - 0.42)]\n", 316 | "In generation 31, best fitness = 0.1107, best solution = [1.87*x0*(0.572 - 0.217*x1), x1*(0.105*x0 - 0.42)]\n", 317 | "In generation 32, best fitness = 0.1054, best solution = [1.87*x0*(0.571 - 0.216*x1), x1*(0.104*x0 - 0.419)]\n", 318 | "In generation 33, best fitness = 0.1053, best solution = [(0.536 - 0.202*x1)*(2*x0 - 0.0521), x1*(0.103*x0 - 0.418)]\n", 319 | "In generation 34, best fitness = 0.1053, best solution = [(0.536 - 0.202*x1)*(2*x0 - 0.0521), x1*(0.103*x0 - 0.418)]\n", 320 | "In generation 35, best fitness = 0.1053, best solution = [(0.536 - 0.202*x1)*(2*x0 - 0.0521), x1*(0.103*x0 - 0.418)]\n", 321 | "In generation 36, best fitness = 0.1053, best solution = [(0.536 - 0.202*x1)*(2*x0 - 0.0521), x1*(0.103*x0 - 0.418)]\n", 322 | "In generation 37, best fitness = 0.1053, best solution = [(0.536 - 0.202*x1)*(2*x0 - 0.0521), x1*(0.103*x0 - 0.418)]\n", 323 | "In generation 38, best fitness = 0.1053, best solution = [(0.536 - 0.202*x1)*(2*x0 - 0.0521), x1*(0.103*x0 - 0.418)]\n", 324 | "In generation 39, best fitness = 0.1053, best solution = [(0.536 - 0.202*x1)*(2*x0 - 0.0521), x1*(0.103*x0 - 0.418)]\n", 
325 | "In generation 40, best fitness = 0.1053, best solution = [(0.536 - 0.202*x1)*(2*x0 - 0.0521), x1*(0.103*x0 - 0.418)]\n", 326 | "In generation 41, best fitness = 0.0998, best solution = [2*x0*(0.536 - 0.203*x1), x1*(0.104*x0 - 0.418)]\n", 327 | "In generation 42, best fitness = 0.0941, best solution = [1.86*x0*(0.574 - 0.215*x1), x1*(0.103*x0 - 0.415)]\n", 328 | "In generation 43, best fitness = 0.0941, best solution = [1.86*x0*(0.574 - 0.215*x1), x1*(0.103*x0 - 0.415)]\n", 329 | "In generation 44, best fitness = 0.0941, best solution = [1.86*x0*(0.574 - 0.215*x1), x1*(0.103*x0 - 0.415)]\n", 330 | "In generation 45, best fitness = 0.0941, best solution = [1.86*x0*(0.574 - 0.215*x1), x1*(0.103*x0 - 0.415)]\n", 331 | "In generation 46, best fitness = 0.0941, best solution = [1.86*x0*(0.574 - 0.215*x1), x1*(0.103*x0 - 0.415)]\n", 332 | "In generation 47, best fitness = 0.0941, best solution = [1.86*x0*(0.574 - 0.215*x1), x1*(0.103*x0 - 0.415)]\n", 333 | "In generation 48, best fitness = 0.0941, best solution = [1.86*x0*(0.574 - 0.215*x1), x1*(0.103*x0 - 0.415)]\n", 334 | "In generation 49, best fitness = 0.0938, best solution = [2*x0*(0.537 - 0.202*x1), x1*(0.103*x0 - 0.415)]\n", 335 | "In generation 50, best fitness = 0.0938, best solution = [2*x0*(0.537 - 0.202*x1), x1*(0.103*x0 - 0.415)]\n" 336 | ] 337 | } 338 | ], 339 | "source": [ 340 | "key = jr.PRNGKey(0)\n", 341 | "data_key, gp_key = jr.split(key)\n", 342 | "\n", 343 | "T = 30\n", 344 | "dt = 0.2\n", 345 | "env = LotkaVolterra()\n", 346 | "\n", 347 | "# Simulate the data\n", 348 | "data = get_data(data_key, env, dt, T, batch_size=4)\n", 349 | "x0s, ts, ys = data\n", 350 | "\n", 351 | "# Sample the initial population\n", 352 | "population = strategy.initialize_population(gp_key)\n", 353 | "\n", 354 | "# Define the number of timepoints to include in the data\n", 355 | "end_ts = int(ts.shape[0]/2)\n", 356 | "\n", 357 | "for g in range(num_generations):\n", 358 | " if g == 25: # After 25 generations, use the 
full data\n", 359 | " end_ts = ts.shape[0]\n", 360 | "\n", 361 | " key, eval_key, sample_key = jr.split(key, 3)\n", 362 | " # Evaluate the population on the data, and return the fitness\n", 363 | " fitness, population = strategy.evaluate_population(population, (x0s, ts[:end_ts], ys[:,:end_ts]), eval_key)\n", 364 | "\n", 365 | " # Print the best solution in the population in this generation\n", 366 | " best_fitness, best_solution = strategy.get_statistics(g)\n", 367 | " print(f\"In generation {g+1}, best fitness = {best_fitness:.4f}, best solution = {strategy.expression_to_string(best_solution)}\")\n", 368 | "\n", 369 | " # Evolve the population until the last generation. The fitness should be given to the evolve function.\n", 370 | " if g < (num_generations-1):\n", 371 | " population = strategy.evolve_population(population, fitness, sample_key)" 372 | ] 373 | }, 374 | { 375 | "cell_type": "code", 376 | "execution_count": 8, 377 | "metadata": {}, 378 | "outputs": [ 379 | { 380 | "name": "stdout", 381 | "output_type": "stream", 382 | "text": [ 383 | "Complexity: 2, fitness: 2.889601707458496, equations: [-0.789, -0.746]\n", 384 | "Complexity: 4, fitness: 2.231682300567627, equations: [1.05 - x0, -0.336]\n", 385 | "Complexity: 6, fitness: 1.392594814300537, equations: [-1.95*x0, -0.291*x1]\n", 386 | "Complexity: 8, fitness: 1.3738036155700684, equations: [-1.62*x0, x0 - 0.407*x1]\n", 387 | "Complexity: 10, fitness: 1.3001673221588135, equations: [-0.276*x0*x1, x0 - 0.386*x1]\n", 388 | "Complexity: 12, fitness: 1.226672649383545, equations: [-0.274*x0*x1 + 0.218, x0 - 0.407*x1]\n", 389 | "Complexity: 14, fitness: 0.916778028011322, equations: [x0*(1.33 - 0.49*x1), 0.311*x0 - 0.327*x1]\n", 390 | "Complexity: 16, fitness: 0.04582058638334274, equations: [2*x0*(0.537 - 0.202*x1), x1*(0.103*x0 - 0.415)]\n", 391 | "Complexity: 18, fitness: 0.04011291265487671, equations: [1.86*x0*(0.574 - 0.215*x1), x1*(0.103*x0 - 0.415)]\n" 392 | ] 393 | } 394 | ], 395 | "source": [ 396 | 
"strategy.print_pareto_front()" 397 | ] 398 | } 399 | ], 400 | "metadata": { 401 | "kernelspec": { 402 | "display_name": "test_kozax", 403 | "language": "python", 404 | "name": "python3" 405 | }, 406 | "language_info": { 407 | "codemirror_mode": { 408 | "name": "ipython", 409 | "version": 3 410 | }, 411 | "file_extension": ".py", 412 | "mimetype": "text/x-python", 413 | "name": "python", 414 | "nbconvert_exporter": "python", 415 | "pygments_lexer": "ipython3", 416 | "version": "3.12.2" 417 | }, 418 | "orig_nbformat": 4 419 | }, 420 | "nbformat": 4, 421 | "nbformat_minor": 2 422 | } 423 | -------------------------------------------------------------------------------- /examples/example_scripts/control_policy_optimization.py: -------------------------------------------------------------------------------- 1 | """ 2 | Control policy optimization 3 | 4 | In this example, a symbolic policy is evolved for the pendulum swingup task. Gymnax is used for simulation of the pendulum environment, showing that Kozax can easily be extended to external libraries. 5 | """ 6 | 7 | # Specify the cores to use for XLA 8 | import os 9 | os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=10' 10 | import jax 11 | import jax.numpy as jnp 12 | import jax.random as jr 13 | import gymnax 14 | import matplotlib.pyplot as plt 15 | 16 | from kozax.genetic_programming import GeneticProgramming 17 | from kozax.fitness_functions.Gymnax_fitness_function import GymFitnessFunction 18 | 19 | if __name__ == "__main__": 20 | """ 21 | Kozax provides a simple fitness function for Gymnax environments, which is used in this example. 
22 | """ 23 | 24 | #Define hyperparameters 25 | population_size = 100 26 | num_populations = 5 27 | num_generations = 50 28 | batch_size = 16 29 | 30 | fitness_function = GymFitnessFunction("Pendulum-v1") 31 | 32 | #Define operators and variables 33 | operator_list = [ 34 | ("+", lambda x, y: jnp.add(x, y), 2, 0.5), 35 | ("-", lambda x, y: jnp.subtract(x, y), 2, 0.1), 36 | ("*", lambda x, y: jnp.multiply(x, y), 2, 0.5), 37 | ] 38 | 39 | variable_list = [[f"y{i}" for i in range(fitness_function.env.observation_space(fitness_function.env_params).shape[0])]] 40 | 41 | #Initialize strategy 42 | strategy = GeneticProgramming(num_generations, population_size, fitness_function, operator_list, variable_list, num_populations=num_populations) 43 | 44 | key = jr.PRNGKey(0) 45 | data_key, gp_key = jr.split(key, 2) 46 | 47 | # The data comprises keys need to initialize the batch of environments. 48 | batch_keys = jr.split(data_key, batch_size) 49 | 50 | strategy.fit(gp_key, batch_keys, verbose=True) 51 | 52 | """ 53 | ## Visualize best solution 54 | 55 | We can visualize the sin and cos position in a trajectory using the best solution. 
56 | """ 57 | 58 | env, env_params = gymnax.make('Pendulum-v1') 59 | key = jr.PRNGKey(2) 60 | obs, env_state = env.reset(key) 61 | all_obs = [] 62 | treward = [] 63 | actions = [] 64 | 65 | done = False 66 | 67 | sin = jnp.sin 68 | cos = jnp.cos 69 | 70 | T=199 71 | for t in range(T): 72 | 73 | y0, y1, y2 = obs 74 | action = 2.92*y0*(-6.56*y1 - 1.29*y2) 75 | obs, env_state, reward, done, _ = env.step( 76 | jr.fold_in(key, t), env_state, action, env_params 77 | ) 78 | all_obs.append(obs) 79 | treward.append(reward) 80 | actions.append(action) 81 | 82 | all_obs = jnp.array(all_obs) 83 | plt.plot(all_obs[:,0], label='cos(x)') 84 | plt.plot(all_obs[:,1], label='sin(x)') 85 | plt.legend() 86 | plt.show() -------------------------------------------------------------------------------- /examples/example_scripts/control_with_memory_SHO.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Memory-based control policy optimization (Stochastic Harmonic Oscillator) 3 | 4 | In this more advanced example, we will optimize a symbolic control policy that is extended with a dynamic latent memory (check out this paper for more details: 5 | https://arxiv.org/abs/2406.02765). The memory is defined by a set of differential equations, consisting of a tree for each latent unit. The memory updates every time step, and the 6 | control policy maps the memory to a control signal via an additional readout tree. This setup is applied to stabilization of the stochastic harmonic oscillator at random targets. 
7 | """ 8 | 9 | # Make XLA expose 10 host (CPU) devices; this is set before jax is imported 10 | import os 11 | os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=10' 12 | 13 | import jax 14 | import jax.numpy as jnp 15 | import diffrax 16 | import jax.random as jr 17 | import matplotlib.pyplot as plt 18 | from jax import Array 19 | from typing import Tuple, Callable 20 | import copy 21 | 22 | from kozax.genetic_programming import GeneticProgramming 23 | from kozax.environments.control_environments.harmonic_oscillator import HarmonicOscillator 24 | 25 | """ 26 | First we generate data, consisting of initial conditions, keys for noise, targets and parameters of the environment. 27 | """ 28 | # Returns (x0, ts, targets, noise_keys, params): batched initial states and targets, one noise key per trajectory, a time grid shared by all trajectories, and the sampled environment parameters 29 | def get_data(key, env, batch_size, dt, T, param_setting): 30 | init_key, noise_key, param_key = jr.split(key, 3) 31 | x0, targets = env.sample_init_states(batch_size, init_key) 32 | noise_keys = jr.split(noise_key, batch_size) 33 | ts = jnp.arange(0, T, dt) 34 | 35 | params = env.sample_params(batch_size, param_setting, ts, param_key) 36 | return x0, ts, targets, noise_keys, params 37 | # Training data settings and an environment with process and observation noise 38 | key = jr.PRNGKey(0) 39 | gp_key, data_key = jr.split(key) 40 | batch_size = 8 41 | T = 40 42 | dt = 0.2 43 | param_setting = "Constant" 44 | 45 | env = HarmonicOscillator(process_noise = 0.05, obs_noise = 0.05) 46 | 47 | data = get_data(data_key, env, batch_size, dt, T, param_setting) 48 | 49 | """ 50 | The evaluator class parallelizes the simulation of the control loop over the batched data. In each trajectory, a coupled dynamical system is simulated using diffrax, 51 | integrating both the dynamic memory and the dynamic environment, defined by the drift and diffusion functions. At every time step, the state of the environment is mapped to 52 | observations and the latent memory is mapped to a control signal. Afterwards, the state equation of the memory and environment state is computed. When the simulation is done, 53 | the fitness is computed given the control and environment state.
54 | """ 55 | 56 | class Evaluator: 57 | def __init__(self, env, state_size: int, dt0: float, solver=diffrax.Euler(), max_steps: int = 16**4, stepsize_controller: diffrax.AbstractStepSizeController = diffrax.ConstantStepSize()) -> None: 58 | """Evaluates dynamic symbolic policies in control tasks. 59 | 60 | Args: 61 | env: Environment on which the candidate is evaluated. 62 | state_size: Dimensionality of the hidden state. 63 | dt0: Initial step size for integration. 64 | solver: Solver used for integration (default: diffrax.Euler()). 65 | max_steps: The maximum number of steps that can be used in integration (default: 16**4). 66 | stepsize_controller: Controller for the stepsize during integration (default: diffrax.ConstantStepSize()). 67 | 68 | Attributes: 69 | env: Environment on which the candidate is evaluated. 70 | max_fitness: Max fitness which is assigned when a trajectory returns an invalid value. 71 | state_size: Dimensionality of the hidden state. 72 | obs_size: Dimensionality of the observations. 73 | control_size: Dimensionality of the control. 74 | latent_size: Dimensionality of the state of the environment. 75 | dt0: Initial step size for integration. 76 | solver: Solver used for integration. 77 | max_steps: The maximum number of steps that can be used in integration. 78 | stepsize_controller: Controller for the stepsize during integration. 79 | """ 80 | self.env = env 81 | self.max_fitness = 1e4 82 | self.state_size = state_size 83 | self.obs_size = env.n_obs 84 | self.control_size = env.n_control_inputs 85 | self.latent_size = env.n_var*env.n_dim 86 | self.dt0 = dt0 87 | self.solver = solver 88 | self.max_steps = max_steps 89 | self.stepsize_controller = stepsize_controller 90 | 91 | def __call__(self, candidate: Array, data: Tuple, tree_evaluator: Callable) -> float: 92 | """Evaluates the candidate on a task. 93 | 94 | Args: 95 | candidate: The coefficients of the candidate. 96 | data: The data required to evaluate the candidate. 
97 | tree_evaluator: Function for evaluating trees. 98 | 99 | Returns: 100 | Fitness of the candidate. 101 | """ 102 | _, _, _, _, fitness = jax.vmap(self.evaluate_trajectory, in_axes=[None, 0, None, 0, 0, 0, None])(candidate, *data, tree_evaluator) 103 | 104 | fitness = jnp.mean(fitness) 105 | return fitness 106 | 107 | def evaluate_trajectory(self, candidate: Array, x0: Array, ts: Array, target: float, noise_key: jr.PRNGKey, params: Tuple, tree_evaluator: Callable) -> Tuple[Array, Array, Array, Array, float]: 108 | """Solves the coupled differential equation of the system and controller. 109 | The differential equation of the system is defined in the environment and the differential equation 110 | of the control is defined by the set of trees. 111 | 112 | Args: 113 | candidate: Candidate with trees for the hidden state and readout. 114 | x0: Initial state of the system. 115 | ts: Time points on which the controller is evaluated. 116 | target: Target position that the system should reach. 117 | noise_key: Key to generate noisy observations. 118 | params: Parameters that define the environment. 119 | tree_evaluator: Function for evaluating trees. 120 | 121 | Returns: 122 | States, observations, control, activities of the hidden state of the candidate and the fitness of the candidate. 
123 | """ 124 | env = copy.copy(self.env) 125 | env.initialize_parameters(params, ts) 126 | 127 | state_equation = candidate[:self.state_size] 128 | readout = candidate[self.state_size:] 129 | 130 | solver = self.solver 131 | dt0 = self.dt0 132 | saveat = diffrax.SaveAt(ts=ts) 133 | 134 | process_noise_key, obs_noise_key = jr.split(noise_key, 2) 135 | 136 | # Concatenate the initial state of the system with the initial state of the latent memory 137 | _x0 = jnp.concatenate([x0, jnp.zeros(self.state_size)]) 138 | 139 | brownian_motion = diffrax.UnsafeBrownianPath(shape=(env.n_var,), key=process_noise_key, levy_area=diffrax.SpaceTimeLevyArea) #define process noise 140 | system = diffrax.MultiTerm(diffrax.ODETerm(self._drift), diffrax.ControlTerm(self._diffusion, brownian_motion)) 141 | 142 | # Solve the coupled system of the environment and the controller 143 | sol = diffrax.diffeqsolve( 144 | system, solver, ts[0], ts[-1], dt0, _x0, saveat=saveat, adjoint=diffrax.DirectAdjoint(), max_steps=self.max_steps, event=diffrax.Event(self.env.cond_fn_nan), 145 | args=(env, state_equation, readout, obs_noise_key, target, tree_evaluator), stepsize_controller=self.stepsize_controller, throw=False 146 | ) 147 | 148 | xs = sol.ys[:,:self.latent_size] 149 | # Get observations of the state at every time step 150 | _, ys = jax.lax.scan(env.f_obs, obs_noise_key, (ts, xs)) 151 | 152 | activities = sol.ys[:,self.latent_size:] 153 | # Get control actions at every time step 154 | us = jax.vmap(lambda y, a, tar: tree_evaluator(readout, jnp.concatenate([y, a, jnp.zeros(self.control_size), target])), in_axes=[0,0,None])(ys, activities, target) 155 | 156 | # Compute the fitness of the candidate in this trajectory 157 | fitness = env.fitness_function(xs, us[:,None], target, ts) 158 | 159 | return xs, ys, us, activities, fitness 160 | 161 | def _drift(self, t, x_a, args): 162 | env, state_equation, readout, obs_noise_key, target, tree_evaluator = args 163 | x = x_a[:self.latent_size] 164 | a = 
x_a[self.latent_size:] 165 | 166 | _, y = env.f_obs(obs_noise_key, (t, x)) #Get observations from system 167 | u = tree_evaluator(readout, jnp.concatenate([jnp.zeros(self.obs_size), a, jnp.zeros(self.control_size), target])) #Compute control action from latent memory 168 | 169 | u = jnp.atleast_1d(u) 170 | 171 | dx = env.drift(t, x, u) #Apply control to system and get system change 172 | da = tree_evaluator(state_equation, jnp.concatenate([y, a, u, target])) #Compute change in latent memory 173 | 174 | return jnp.concatenate([dx, da]) 175 | 176 | def _diffusion(self, t, x_a, args): 177 | env, state_equation, readout, obs_noise_key, target, tree_evaluator = args 178 | x = x_a[:self.latent_size] 179 | a = x_a[self.latent_size:] 180 | 181 | return jnp.concatenate([env.diffusion(t, x, jnp.array([0])), jnp.zeros((self.state_size, self.latent_size))]) #Only the system is stochastic 182 | 183 | """ 184 | Here we define the hyperparameters, operators, variables and initialize the strategy. The control policy will consist of two latent variables and a readout layer. 185 | `layer_sizes` allows us to define different types of tree, where `variable_list` contains different sets of input variables for each type of tree. 186 | The readout layer will only receive the latent states, while the inputs to the state equations consists of the observations, control signal and latent states. 
187 | """ 188 | 189 | #Define hyperparameters 190 | population_size = 100 191 | num_populations = 10 192 | num_generations = 50 193 | state_size = 2 194 | 195 | #Define expressions 196 | operator_list = [("+", lambda x, y: x + y, 2, 0.5), 197 | ("-", lambda x, y: x - y, 2, 0.1), 198 | ("*", lambda x, y: x * y, 2, 0.5)] 199 | 200 | variable_list = [["x" + str(i) for i in range(env.n_obs)] + ["a1", "a2", "u", "tar"], ["a1", "a2", "tar"]] 201 | 202 | layer_sizes = jnp.array([state_size, env.n_control_inputs]) 203 | 204 | #Define evaluator 205 | fitness_function = Evaluator(env, state_size, 0.05, solver=diffrax.GeneralShARK(), max_steps=1000) 206 | 207 | #Initialize strategy 208 | strategy = GeneticProgramming(num_generations, population_size, fitness_function, operator_list, variable_list, layer_sizes, num_populations = num_populations, size_parsimony=0.003) 209 | 210 | strategy.fit(gp_key, data, verbose=True) 211 | 212 | # Visualize best solution 213 | 214 | #Generate test_data 215 | data = get_data(jr.PRNGKey(10), env, 4, 0.01, T, param_setting) 216 | x0s, ts, targets, noise_key, params = data 217 | 218 | best_candidate = strategy.pareto_front[1][17] 219 | print(strategy.expression_to_string(best_candidate)) 220 | 221 | xs, ys, us, activities, fitness = jax.vmap(fitness_function.evaluate_trajectory, in_axes=[None, 0, None, 0, 0, 0, None])(best_candidate, *data, strategy.tree_evaluator) 222 | 223 | fig, ax = plt.subplots(2, 2, figsize=(10, 5)) 224 | ax = ax.ravel() 225 | for i in range(4): 226 | ax[i].plot(ts, xs[i,:,0], label="$x_1$", color = "blue") 227 | ax[i].plot(ts, xs[i,:,1], label="$x_2$", color = "orange") 228 | ax[i].plot(ts, ys[i,:,0], alpha=0.3, color = "blue") 229 | ax[i].plot(ts, ys[i,:,1], alpha=0.3, color = "orange") 230 | ax[i].plot(ts, us[i,:], label="$u_1$", color = "green") 231 | ax[i].hlines(targets[i], ts[0], ts[-1], linestyles='dashed', color = "black") 232 | 233 | plt.legend(loc="best") 234 | plt.show() 
-------------------------------------------------------------------------------- /examples/example_scripts/control_with_memory_acrobot.py: -------------------------------------------------------------------------------- 1 | """ 2 | Memory-based control policy optimization 3 | 4 | In this more advanced example, we will optimize a symbolic control policy that is extended with a dynamic latent memory (check out this paper for more details: https://arxiv.org/abs/2406.02765). 5 | The memory is defined by a set of differential equations, consisting of a tree for each latent unit. The memory updates every time step, and the control policy maps the memory to a control 6 | signal via an additional readout tree. This setup is applied to the partially observable acrobot swingup task, in which the angular velocity is hidden. 7 | """ 8 | 9 | # Specify the cores to use for XLA 10 | import os 11 | os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=10' 12 | 13 | import jax 14 | import jax.numpy as jnp 15 | import diffrax 16 | import jax.random as jr 17 | import matplotlib.pyplot as plt 18 | from jax import Array 19 | from typing import Tuple, Callable 20 | import copy 21 | 22 | from kozax.genetic_programming import GeneticProgramming 23 | from kozax.environments.control_environments.acrobot import Acrobot 24 | 25 | """ 26 | First we generate data, consisting of initial conditions, keys for noise, targets and parameters of the environment. 27 | """ 28 | 29 | def get_data(key, env, batch_size, dt, T, param_setting): 30 | init_key, noise_key2, param_key = jr.split(key, 3) 31 | x0, targets = env.sample_init_states(batch_size, init_key) 32 | noise_keys = jr.split(noise_key2, batch_size) 33 | ts = jnp.arange(0, T, dt) 34 | 35 | params = env.sample_params(batch_size, param_setting, ts, param_key) 36 | return x0, ts, targets, noise_keys, params 37 | 38 | """ 39 | The evaluator class parallelizes the simulation of the control loop over the batched data. 
In each trajectory, a coupled dynamical system is simulated using diffrax, integrating both the 40 | dynamic memory and the dynamic environment. At every time step, the state of the environment is mapped to observations and the latent memory is mapped to a control signal. Afterwards, 41 | the state equations of the memory and the environment are evaluated. When the simulation is done, the fitness is computed from the control signals and environment states. 42 | """ 43 | 44 | class Evaluator: 45 | def __init__(self, env, state_size: int, dt0: float, solver=diffrax.Euler(), max_steps: int = 16**4, stepsize_controller: diffrax.AbstractStepSizeController = diffrax.ConstantStepSize()) -> None: 46 | """Evaluates dynamic symbolic policies in control tasks. 47 | 48 | Args: 49 | env: Environment on which the candidate is evaluated. 50 | state_size: Dimensionality of the hidden state. 51 | dt0: Initial step size for integration. 52 | solver: Solver used for integration (default: diffrax.Euler()). 53 | max_steps: The maximum number of steps that can be used in integration (default: 16**4). 54 | stepsize_controller: Controller for the stepsize during integration (default: diffrax.ConstantStepSize()). 55 | 56 | Attributes: 57 | env: Environment on which the candidate is evaluated. 58 | max_fitness: Fitness value intended for trajectories that return an invalid value (currently not referenced by this class). 59 | state_size: Dimensionality of the hidden state. 60 | obs_size: Dimensionality of the observations. 61 | control_size: Dimensionality of the control. 62 | latent_size: Dimensionality of the state of the environment. 63 | dt0: Initial step size for integration. 64 | solver: Solver used for integration. 65 | max_steps: The maximum number of steps that can be used in integration. 66 | stepsize_controller: Controller for the stepsize during integration. 
67 | """ 68 | self.env = env 69 | self.max_fitness = 1e4 70 | self.state_size = state_size 71 | self.obs_size = env.n_obs 72 | self.control_size = env.n_control_inputs 73 | self.latent_size = env.n_var*env.n_dim 74 | self.dt0 = dt0 75 | self.solver = solver 76 | self.max_steps = max_steps 77 | self.stepsize_controller = stepsize_controller 78 | 79 | def __call__(self, candidate: Array, data: Tuple, tree_evaluator: Callable) -> float: 80 | """Evaluates the candidate on a task. 81 | 82 | Args: 83 | candidate: The coefficients of the candidate. 84 | data: The data required to evaluate the candidate. 85 | tree_evaluator: Function for evaluating trees. 86 | 87 | Returns: 88 | Fitness of the candidate. 89 | """ 90 | _, _, _, _, fitness = jax.vmap(self.evaluate_trajectory, in_axes=[None, 0, None, 0, 0, 0, None])(candidate, *data, tree_evaluator) 91 | 92 | fitness = jnp.mean(fitness) 93 | return fitness 94 | 95 | def evaluate_trajectory(self, candidate: Array, x0: Array, ts: Array, target: float, noise_key: jr.PRNGKey, params: Tuple, tree_evaluator: Callable) -> Tuple[Array, Array, Array, Array, float]: 96 | """Solves the coupled differential equation of the system and controller. 97 | The differential equation of the system is defined in the environment and the differential equation 98 | of the control is defined by the set of trees. 99 | 100 | Args: 101 | candidate: Candidate with trees for the hidden state and readout. 102 | x0: Initial state of the system. 103 | ts: Time points on which the controller is evaluated. 104 | target: Target position that the system should reach. 105 | noise_key: Key to generate noisy observations. 106 | params: Parameters that define the environment. 107 | tree_evaluator: Function for evaluating trees. 108 | 109 | Returns: 110 | States, observations, control, activities of the hidden state of the candidate and the fitness of the candidate. 
111 | """ 112 | env = copy.copy(self.env) 113 | env.initialize_parameters(params, ts) 114 | 115 | state_equation = candidate[:self.state_size] 116 | readout = candidate[self.state_size:] 117 | 118 | solver = self.solver 119 | dt0 = self.dt0 120 | saveat = diffrax.SaveAt(ts=ts) 121 | 122 | # Concatenate the initial state of the system with the initial state of the latent memory 123 | _x0 = jnp.concatenate([x0, jnp.zeros(self.state_size)]) 124 | 125 | system = diffrax.ODETerm(self._drift) 126 | 127 | # Solve the coupled system of the environment and the controller 128 | sol = diffrax.diffeqsolve( 129 | system, solver, ts[0], ts[-1], dt0, _x0, saveat=saveat, adjoint=diffrax.DirectAdjoint(), max_steps=self.max_steps, event=diffrax.Event(self.env.cond_fn_nan), 130 | args=(env, state_equation, readout, noise_key, target, tree_evaluator), stepsize_controller=self.stepsize_controller, throw=False 131 | ) 132 | 133 | xs = sol.ys[:,:self.latent_size] 134 | # Get observations of the state at every time step 135 | _, ys = jax.lax.scan(env.f_obs, noise_key, (ts, xs)) 136 | 137 | activities = sol.ys[:,self.latent_size:] 138 | # Get control actions at every time step 139 | us = jax.vmap(lambda y, a, tar: tree_evaluator(readout, jnp.concatenate([y, a, jnp.zeros(self.control_size), target])), in_axes=[0,0,None])(ys, activities, target) 140 | 141 | # Compute the fitness of the candidate in this trajectory 142 | fitness = env.fitness_function(xs, us[:,None], target, ts) 143 | 144 | return xs, ys, us, activities, fitness 145 | 146 | def _drift(self, t, x_a, args): 147 | env, state_equation, readout, noise_key, target, tree_evaluator = args 148 | x = x_a[:self.latent_size] 149 | a = x_a[self.latent_size:] 150 | 151 | _, y = env.f_obs(noise_key, (t, x)) #Get observations from system 152 | u = tree_evaluator(readout, jnp.concatenate([jnp.zeros(self.obs_size), a, jnp.zeros(self.control_size), target])) #Compute control action from latent memory 153 | 154 | u = jnp.atleast_1d(u) 155 | 156 
| dx = env.drift(t, x, u) #Apply control to system and get system change 157 | da = tree_evaluator(state_equation, jnp.concatenate([y, a, u, target])) #Compute change in latent memory 158 | 159 | return jnp.concatenate([dx, da]) 160 | 161 | """ 162 | Here we define the hyperparameters, operators and variables, and initialize the strategy. The control policy will consist of two latent variables and a readout layer. 163 | `layer_sizes` allows us to define different types of trees, where `variable_list` contains different sets of input variables for each type of tree. 164 | The readout layer will only receive the latent states, while the inputs to the state equations consist of the observations, control signal and latent states. 165 | """ 166 | 167 | if __name__ == "__main__": 168 | key = jr.PRNGKey(0) 169 | gp_key, data_key = jr.split(key) 170 | batch_size = 8 171 | T = 40 172 | dt = 0.2 173 | param_setting = "Constant" 174 | 175 | env = Acrobot(n_obs = 2) # Partially observable Acrobot 176 | 177 | data = get_data(data_key, env, batch_size, dt, T, param_setting) 178 | 179 | #Define hyperparameters 180 | population_size = 100 181 | num_populations = 10 182 | num_generations = 50 183 | state_size = 2 184 | 185 | #Define expressions 186 | operator_list = [("+", lambda x, y: x + y, 2, 0.5), 187 | ("-", lambda x, y: x - y, 2, 0.1), 188 | ("*", lambda x, y: x * y, 2, 0.5), 189 | ("sin", lambda x: jnp.sin(x), 1, 0.1), 190 | ("cos", lambda x: jnp.cos(x), 1, 0.1)] 191 | 192 | variable_list = [["x" + str(i) for i in range(env.n_obs)] + ["a1", "a2", "u"], ["a1", "a2"]] 193 | 194 | layer_sizes = jnp.array([state_size, env.n_control_inputs]) 195 | 196 | #Define evaluator 197 | fitness_function = Evaluator(env, state_size, 0.05, solver=diffrax.Dopri5(), stepsize_controller=diffrax.PIDController(atol=1e-4, rtol=1e-4, dtmin=0.001), max_steps=1000) 198 | 199 | #Initialize strategy 200 | strategy = GeneticProgramming(num_generations, population_size, fitness_function, operator_list, 
variable_list, layer_sizes, num_populations = num_populations, size_parsimony=0.003) 201 | 202 | strategy.fit(gp_key, data, verbose=True) 203 | 204 | """ 205 | ## Visualize best solution 206 | """ 207 | 208 | #Generate test_data 209 | data = get_data(jr.PRNGKey(10), env, 1, 0.01, T, param_setting) 210 | x0s, ts, _, _, params = data 211 | 212 | best_candidate = strategy.pareto_front[1][21] 213 | 214 | xs, ys, us, activities, fitness = jax.vmap(fitness_function.evaluate_trajectory, in_axes=[None, 0, None, 0, 0, 0, None])(best_candidate, *data, strategy.tree_evaluator) 215 | 216 | plt.plot(ts, -jnp.cos(xs[0,:,0]), color = f"C{0}", label="first link") 217 | plt.plot(ts, -jnp.cos(xs[0,:,0]) - jnp.cos(xs[0,:,0] + xs[0,:,1]), color = f"C{1}", label="second link") 218 | plt.hlines(1.5, ts[0], ts[-1], linestyles='dashed', color = "black") 219 | 220 | plt.legend(loc="best") 221 | plt.show() -------------------------------------------------------------------------------- /examples/example_scripts/objective_function_optimization.py: -------------------------------------------------------------------------------- 1 | """ 2 | Objective function optimization 3 | 4 | In this example, Kozax is used to evolve a symbolic loss function to train a neural network. 5 | With each candidate loss function, a neural network is trained on the task of binary classification of XOR data points. 6 | """ 7 | 8 | # Specify the cores to use for XLA 9 | import os 10 | os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=10' 11 | 12 | import jax 13 | import jax.numpy as jnp 14 | import jax.random as jr 15 | import optax 16 | from typing import Callable, Tuple 17 | from jax import Array 18 | 19 | from kozax.genetic_programming import GeneticProgramming 20 | 21 | """ 22 | We define a fitness function class that includes the network initialization, training loop and weight updates. 
23 | At every epoch, a new batch of data is sampled, and the fitness is computed as one minus the accuracy of the trained network on a validation set (lower is better). 24 | """ 25 | 26 | class FitnessFunction: 27 | """ 28 | A class to define the fitness function for evaluating candidate loss functions. 29 | The fitness is computed as 1 - accuracy of a neural network trained with the candidate loss function 30 | on a binary classification task (XOR data). 31 | 32 | Attributes: 33 | input_dim (int): Dimension of the input data. 34 | hidden_dim (int): Dimension of the hidden layers in the neural network. 35 | output_dim (int): Dimension of the output. 36 | epochs (int): Number of training epochs. 37 | learning_rate (float): Learning rate passed to the Adam optimizer (not stored as an attribute). 38 | optim (optax.GradientTransformation): Optax optimizer instance. 39 | """ 40 | def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, epochs: int, learning_rate: float): 41 | self.input_dim = input_dim 42 | self.hidden_dim = hidden_dim 43 | self.output_dim = output_dim 44 | self.optim = optax.adam(learning_rate) 45 | self.epochs = epochs 46 | 47 | def __call__(self, candidate: str, data: Tuple[Array, Array, Array], tree_evaluator: Callable) -> Array: 48 | """ 49 | Computes the fitness of a candidate loss function. 50 | 51 | Args: 52 | candidate: The candidate loss function (symbolic tree). 53 | data (tuple): A tuple containing the data keys, test keys, and network keys. 54 | tree_evaluator: A function to evaluate the symbolic tree. 55 | 56 | Returns: 57 | Array: The mean loss (1 - accuracy) on the validation set. 58 | """ 59 | data_keys, test_keys, network_keys = data 60 | losses = jax.vmap(self.train, in_axes=[None, 0, 0, 0, None])(candidate, data_keys, test_keys, network_keys, tree_evaluator) 61 | return jnp.mean(losses) 62 | 63 | def get_data(self, key: jr.PRNGKey, n_samples: int = 50) -> Tuple[Array, Array]: 64 | """ 65 | Generates XOR data. 66 | 67 | Args: 68 | key (jax.random.PRNGKey): Random key for data generation. 
69 | n_samples (int): Number of samples to generate. 70 | 71 | Returns: 72 | tuple: A tuple containing the input data (x) and the target labels (y). 73 | """ 74 | x = jr.uniform(key, shape=(n_samples, 2)) 75 | y = jnp.logical_xor(x[:,0]>0.5, x[:,1]>0.5) 76 | 77 | return x, y[:,None] 78 | 79 | def loss_function(self, params: Tuple[Array, Array, Array, Array, Array, Array], x: Array, y: Array, candidate: str, tree_evaluator: Callable) -> Array: 80 | """ 81 | Computes the loss with an evolved loss function for a given set of parameters and data. 82 | 83 | Args: 84 | params (tuple): The parameters of the neural network. 85 | x (Array): The input data. 86 | y (Array): The target labels. 87 | candidate: The candidate loss function (symbolic tree). 88 | tree_evaluator: A function to evaluate the symbolic tree. 89 | 90 | Returns: 91 | Array: The mean loss. 92 | """ 93 | pred = self.neural_network(params, x) 94 | return jnp.mean(jax.vmap(tree_evaluator, in_axes=[None, 0])(candidate, jnp.concatenate([pred, y], axis=-1))) 95 | 96 | def train(self, candidate: str, data_key: jr.PRNGKey, test_key: jr.PRNGKey, network_key: jr.PRNGKey, tree_evaluator: Callable) -> Array: 97 | """ 98 | Trains a neural network with a given candidate loss function. 99 | 100 | Args: 101 | candidate: The candidate loss function (symbolic tree). 102 | data_key (jax.random.PRNGKey): Random key for data generation during training. 103 | test_key (jax.random.PRNGKey): Random key for data generation during testing. 104 | network_key (jax.random.PRNGKey): Random key for initializing the network parameters. 105 | tree_evaluator: A function to evaluate the symbolic tree. 106 | 107 | Returns: 108 | Array: The validation loss (1 - accuracy). 
109 | """ 110 | params = self.init_network_params(network_key) 111 | 112 | optim_state = self.optim.init(params) 113 | 114 | def step(i: int, carry: Tuple[Tuple[Array, Array, Array, Array, Array, Array], optax._src.base.OptState, jr.PRNGKey]) -> Tuple[Tuple[Array, Array, Array, Array, Array, Array], optax._src.base.OptState, jr.PRNGKey]: 115 | params, optim_state, key = carry 116 | 117 | key, _key = jr.split(key) 118 | 119 | x_train, y_train = self.get_data(_key, n_samples=50) 120 | 121 | # Evaluate network parameters and compute gradients 122 | grads = jax.grad(self.loss_function)(params, x_train, y_train, candidate, tree_evaluator) 123 | 124 | # Update parameters 125 | updates, optim_state = self.optim.update(grads, optim_state, params) 126 | params = optax.apply_updates(params, updates) 127 | 128 | return (params, optim_state, key) 129 | 130 | (params, _, _) = jax.lax.fori_loop(0, self.epochs, step, (params, optim_state, data_key)) 131 | 132 | # Evaluate parameters on test set 133 | x_test, y_test = self.get_data(test_key, n_samples=500) 134 | 135 | pred = self.neural_network(params, x_test) 136 | return 1 - jnp.mean(y_test==(pred>0.5)) # Return 1 - accuracy 137 | 138 | def neural_network(self, params: Tuple[Array, Array, Array, Array, Array, Array], x: Array) -> Array: 139 | """ 140 | Defines the neural network architecture (forward pass). 141 | 142 | Args: 143 | params (tuple): The parameters of the neural network. 144 | x (Array): The input data. 145 | 146 | Returns: 147 | Array: The output of the neural network. 148 | """ 149 | w1, b1, w2, b2, w3, b3 = params 150 | hidden = jnp.tanh(jnp.dot(x, w1) + b1) 151 | hidden = jnp.tanh(jnp.dot(hidden, w2) + b2) 152 | output = jnp.dot(hidden, w3) + b3 153 | return jax.nn.sigmoid(output) 154 | 155 | def init_network_params(self, key: jr.PRNGKey) -> Tuple[Array, Array, Array, Array, Array, Array]: 156 | """ 157 | Initializes the parameters of the neural network. 
158 | 159 | Args: 160 | key (jax.random.PRNGKey): Random key for parameter initialization. 161 | 162 | Returns: 163 | tuple: A tuple containing the initialized weights and biases. 164 | """ 165 | key1, key2, key3 = jr.split(key, 3) 166 | w1 = jr.normal(key1, (self.input_dim, self.hidden_dim)) * jnp.sqrt(2.0 / self.input_dim) 167 | b1 = jnp.zeros(self.hidden_dim) 168 | w2 = jr.normal(key2, (self.hidden_dim, self.hidden_dim)) * jnp.sqrt(2.0 / self.hidden_dim) 169 | b2 = jnp.zeros(self.hidden_dim) 170 | w3 = jr.normal(key3, (self.hidden_dim, self.output_dim)) * jnp.sqrt(2.0 / self.hidden_dim) 171 | b3 = jnp.zeros(self.output_dim) 172 | return (w1, b1, w2, b2, w3, b3) 173 | 174 | """ 175 | To make sure the optimized loss function generalizes, a batch of neural networks is trained, each with different data and a different weight initialization. 176 | For this purpose, batches of keys are generated for training-data sampling, validation-data sampling and weight initialization. 177 | """ 178 | 179 | def generate_keys(key, batch_size=4): 180 | key1, key2, key3 = jr.split(key, 3) 181 | return jr.split(key1, batch_size), jr.split(key2, batch_size), jr.split(key3, batch_size) 182 | 183 | """ 184 | Here we define the hyperparameters and inputs to the genetic programming algorithm. 185 | The inputs to the trees are the network prediction and the target value. 
186 | """ 187 | 188 | if __name__ == "__main__": 189 | key = jr.PRNGKey(0) 190 | data_key, gp_key = jr.split(key) 191 | 192 | population_size = 50 193 | num_populations = 5 194 | num_generations = 25 195 | 196 | operator_list = [("+", lambda x, y: jnp.add(x, y), 2, 0.5), 197 | ("-", lambda x, y: jnp.subtract(x, y), 2, 0.5), 198 | ("*", lambda x, y: jnp.multiply(x, y), 2, 0.5), 199 | ("log", lambda x: jnp.log(x + 1e-7), 1, 0.1), 200 | ] 201 | 202 | variable_list = [["pred", "y"]] 203 | 204 | input_dim = 2 205 | hidden_dim = 16 206 | output_dim = 1 207 | 208 | fitness_function = FitnessFunction(input_dim, hidden_dim, output_dim, learning_rate=0.01, epochs=100) 209 | 210 | strategy = GeneticProgramming(num_generations, population_size, fitness_function, operator_list, variable_list, num_populations = num_populations) 211 | 212 | data_keys, test_keys, network_keys = generate_keys(data_key) 213 | 214 | strategy.fit(gp_key, (data_keys, test_keys, network_keys), verbose=True) -------------------------------------------------------------------------------- /examples/example_scripts/symbolic_regression_dynamical_system.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Symbolic regression of a dynamical system 3 | 4 | In this example, Kozax is applied to recover the state equations of the Lotka-Volterra system. The candidate solutions are integrated as a system of differential equations, after which the 5 | predictions are compared to the true observations to determine a fitness score. 
6 | """ 7 | 8 | # Specify the cores to use for XLA 9 | import os 10 | os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=10' 11 | 12 | import jax 13 | import diffrax 14 | import jax.numpy as jnp 15 | import jax.random as jr 16 | import diffrax 17 | 18 | from kozax.genetic_programming import GeneticProgramming 19 | from kozax.fitness_functions.ODE_fitness_function import ODEFitnessFunction 20 | from kozax.environments.SR_environments.lotka_volterra import LotkaVolterra 21 | 22 | """ 23 | First the data is generated, consisting of initial conditions, time points and the true observations. Kozax provides the Lotka-Volterra environment, which is integrated with Diffrax. 24 | """ 25 | 26 | def get_data(key, env, dt, T, batch_size=20): 27 | x0s = env.sample_init_states(batch_size, key) 28 | ts = jnp.arange(0, T, dt) 29 | 30 | def solve(env, ts, x0): 31 | solver = diffrax.Dopri5() 32 | dt0 = 0.001 33 | saveat = diffrax.SaveAt(ts=ts) 34 | 35 | system = diffrax.ODETerm(env.drift) 36 | 37 | # Solve the system for a single initial condition 38 | sol = diffrax.diffeqsolve(system, solver, ts[0], ts[-1], dt0, x0, saveat=saveat, max_steps=500, 39 | adjoint=diffrax.DirectAdjoint(), stepsize_controller=diffrax.PIDController(atol=1e-7, rtol=1e-7, dtmin=0.001)) 40 | 41 | return sol.ys 42 | 43 | ys = jax.vmap(solve, in_axes=[None, None, 0])(env, ts, x0s) #Parallelize over the batch dimension 44 | 45 | return x0s, ts, ys 46 | 47 | """ 48 | For the fitness function, we use the ODEFitnessFunction, which uses Diffrax to integrate candidate solutions. It is possible to select the solver, initial time step, maximum number of steps and a 49 | stepsize controller to balance efficiency and accuracy. To aid convergence of the genetic programming algorithm, constant optimization is applied to the best candidates at every 50 | generation. The constant optimization is performed with a couple of simple evolutionary steps that adjust the values of the constants in a candidate. 
The hyperparameters that define the 51 | constant optimization are `constant_optimization_N_offspring`, `constant_optimization_steps`, `optimize_constants_elite` and `constant_step_size_init`. 52 | """ 53 | 54 | if __name__ == "__main__": 55 | key = jr.PRNGKey(0) 56 | data_key, gp_key = jr.split(key) 57 | 58 | T = 30 59 | dt = 0.2 60 | env = LotkaVolterra() 61 | 62 | # Simulate the data 63 | data = get_data(data_key, env, dt, T, batch_size=4) 64 | x0s, ts, ys = data 65 | 66 | #Define the nodes and hyperparameters 67 | operator_list = [ 68 | ("+", lambda x, y: jnp.add(x, y), 2, 0.5), 69 | ("-", lambda x, y: jnp.subtract(x, y), 2, 0.1), 70 | ("*", lambda x, y: jnp.multiply(x, y), 2, 0.5), 71 | ] 72 | 73 | variable_list = [["x" + str(i) for i in range(env.n_var)]] 74 | layer_sizes = jnp.array([env.n_var]) 75 | 76 | population_size = 100 77 | num_populations = 10 78 | num_generations = 50 79 | 80 | #Initialize the fitness function and the genetic programming strategy 81 | fitness_function = ODEFitnessFunction(solver=diffrax.Dopri5(), dt0 = 0.01, stepsize_controller=diffrax.PIDController(atol=1e-6, rtol=1e-6, dtmin=0.001), max_steps=300) 82 | 83 | strategy = GeneticProgramming(num_generations, population_size, fitness_function, operator_list, variable_list, layer_sizes, num_populations = num_populations, 84 | size_parsimony=0.003, constant_optimization_method="evolution", constant_optimization_N_offspring = 25, constant_optimization_steps = 5, 85 | optimize_constants_elite=100, constant_step_size_init=0.1, constant_step_size_decay=0.99) 86 | """ 87 | Kozax provides a fit function that receives the data and a random key. However, it is also possible to run Kozax with an easy loop consisting of evaluating and evolving. 88 | This is useful as different input data can be provided during evaluation. 
In symbolic regression of dynamical systems, it helps to first optimize on only the first half of the time points, 89 | and to provide the full data trajectories only after a couple of generations. 90 | """ 91 | 92 | # Sample the initial population 93 | population = strategy.initialize_population(gp_key) 94 | 95 | # Define the number of timepoints to include in the data 96 | end_ts = int(ts.shape[0]/2) 97 | 98 | for g in range(num_generations): 99 | if g == 25: # After 25 generations, use the full data 100 | end_ts = ts.shape[0] 101 | 102 | key, eval_key, sample_key = jr.split(key, 3) 103 | # Evaluate the population on the data, and return the fitness 104 | fitness, population = strategy.evaluate_population(population, (x0s, ts[:end_ts], ys[:,:end_ts]), eval_key) 105 | 106 | # Print the best solution in the population in this generation 107 | best_fitness, best_solution = strategy.get_statistics(g) 108 | print(f"In generation {g+1}, best fitness = {best_fitness:.4f}, best solution = {strategy.expression_to_string(best_solution)}") 109 | 110 | # Evolve the population until the last generation. The fitness should be given to the evolve function. 111 | if g < (num_generations-1): 112 | population = strategy.evolve_population(population, fitness, sample_key) 113 | 114 | strategy.print_pareto_front() 115 | 116 | """ 117 | Instead of using evolution to optimize the constants, Kozax also offers gradient-based optimization. For gradient optimization, it is possible to specify the optimizer, the number of 118 | candidates to apply constant optimization to, the initial learning rate and the learning rate decay over generations. Both methods are provided because either one can be more effective or 119 | efficient, depending on the problem. 
120 | """ 121 | 122 | import optax 123 | 124 | strategy = GeneticProgramming(num_generations, population_size, fitness_function, operator_list, variable_list, layer_sizes, num_populations = num_populations, 125 | size_parsimony=0.003, constant_optimization_method="gradient", constant_optimization_steps = 15, optimizer_class = optax.adam, 126 | optimize_constants_elite=100, constant_step_size_init=0.025, constant_step_size_decay=0.95) 127 | 128 | key = jr.PRNGKey(0) 129 | data_key, gp_key = jr.split(key) 130 | 131 | T = 30 132 | dt = 0.2 133 | env = LotkaVolterra() 134 | 135 | # Simulate the data 136 | data = get_data(data_key, env, dt, T, batch_size=4) 137 | x0s, ts, ys = data 138 | 139 | # Sample the initial population 140 | population = strategy.initialize_population(gp_key) 141 | 142 | # Define the number of timepoints to include in the data 143 | end_ts = int(ts.shape[0]/2) 144 | 145 | for g in range(num_generations): 146 | if g == 25: # After 25 generations, use the full data 147 | end_ts = ts.shape[0] 148 | 149 | key, eval_key, sample_key = jr.split(key, 3) 150 | # Evaluate the population on the data, and return the fitness 151 | fitness, population = strategy.evaluate_population(population, (x0s, ts[:end_ts], ys[:,:end_ts]), eval_key) 152 | 153 | # Print the best solution in the population in this generation 154 | best_fitness, best_solution = strategy.get_statistics(g) 155 | print(f"In generation {g+1}, best fitness = {best_fitness:.4f}, best solution = {strategy.expression_to_string(best_solution)}") 156 | 157 | # Evolve the population until the last generation. The fitness should be given to the evolve function. 
158 | if g < (num_generations-1): 159 | population = strategy.evolve_population(population, fitness, sample_key) 160 | 161 | strategy.print_pareto_front() -------------------------------------------------------------------------------- /figures/applications.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdevries0/Kozax/95c96ea580c1c9f2aca15e1d29d4ea6aecc9c527/figures/applications.jpg -------------------------------------------------------------------------------- /figures/logo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sdevries0/Kozax/95c96ea580c1c9f2aca15e1d29d4ea6aecc9c527/figures/logo.jpg -------------------------------------------------------------------------------- /kozax/environments/SR_environments/lorenz_attractor.py: -------------------------------------------------------------------------------- 1 | """ 2 | kozax: Genetic programming framework in JAX 3 | 4 | Copyright (c) 2024 sdevries0 5 | 6 | This work is licensed under the Creative Commons Attribution-NonCommercial-NoDerivs 4.0 International License. 7 | """ 8 | 9 | import jax 10 | import jax.numpy as jnp 11 | import jax.random as jrandom 12 | from kozax.environments.SR_environments.time_series_environment_base import EnvironmentBase 13 | from jaxtyping import Array 14 | from typing import Tuple 15 | 16 | class LorenzAttractor(EnvironmentBase): 17 | """ 18 | Lorenz Attractor environment for symbolic regression tasks. 19 | 20 | Parameters 21 | ---------- 22 | process_noise : float, optional 23 | Standard deviation of the process noise. Default is 0. 24 | 25 | Attributes 26 | ---------- 27 | init_mu : :class:`jax.Array` 28 | Mean of the initial state distribution. 29 | init_sd : float 30 | Standard deviation of the initial state distribution. 31 | sigma : float 32 | Parameter sigma of the Lorenz system. 
33 | rho : float 34 | Parameter rho of the Lorenz system. 35 | beta : float 36 | Parameter beta of the Lorenz system. 37 | V : :class:`jax.Array` 38 | Process noise covariance matrix. 39 | 40 | Methods 41 | ------- 42 | sample_init_states(batch_size, key) 43 | Samples initial states for the environment. 44 | drift(t, state, args) 45 | Computes the drift function for the environment. 46 | diffusion(t, state, args) 47 | Computes the diffusion function for the environment. 48 | """ 49 | 50 | def __init__(self, process_noise: float = 0) -> None: 51 | n_var = 3 52 | super().__init__(n_var, process_noise) 53 | 54 | self.init_mu = jnp.array([1, 1, 1]) 55 | self.init_sd = 0.1 56 | 57 | self.sigma = 10 58 | self.rho = 28 59 | self.beta = 8 / 3 60 | self.V = self.process_noise * jnp.eye(self.n_var) 61 | 62 | def sample_init_states(self, batch_size: int, key: jrandom.PRNGKey) -> Array: 63 | """ 64 | Samples initial states for the environment. 65 | 66 | Parameters 67 | ---------- 68 | batch_size : int 69 | Number of initial states to sample. 70 | key : :class:`jax.random.PRNGKey` 71 | Random key for sampling. 72 | 73 | Returns 74 | ------- 75 | :class:`jax.Array` 76 | Initial states. 77 | """ 78 | return self.init_mu + self.init_sd * jrandom.normal(key, shape=(batch_size, self.n_var)) 79 | 80 | def drift(self, t: float, state: Array, args: Tuple) -> Array: 81 | """ 82 | Computes the drift function for the environment. 83 | 84 | Parameters 85 | ---------- 86 | t : float 87 | Current time. 88 | state : :class:`jax.Array` 89 | Current state. 90 | args : tuple 91 | Additional arguments. 92 | 93 | Returns 94 | ------- 95 | :class:`jax.Array` 96 | Drift. 97 | """ 98 | return jnp.array([self.sigma * (state[1] - state[0]), state[0] * (self.rho - state[2]) - state[1], state[0] * state[1] - self.beta * state[2]]) 99 | 100 | def diffusion(self, t: float, state: Array, args: Tuple) -> Array: 101 | """ 102 | Computes the diffusion function for the environment. 
103 | 104 | Parameters 105 | ---------- 106 | t : float 107 | Current time. 108 | state : :class:`jax.Array` 109 | Current state. 110 | args : tuple 111 | Additional arguments. 112 | 113 | Returns 114 | ------- 115 | :class:`jax.Array` 116 | Diffusion. 117 | """ 118 | return self.V 119 | 120 | class Lorenz96(EnvironmentBase): 121 | """ 122 | Lorenz 96 environment for symbolic regression tasks. 123 | 124 | Parameters 125 | ---------- 126 | n_dim : int, optional 127 | Number of dimensions. Default is 3. 128 | process_noise : float, optional 129 | Standard deviation of the process noise. Default is 0. 130 | 131 | Attributes 132 | ---------- 133 | F : float 134 | Forcing term. 135 | init_mu : :class:`jax.Array` 136 | Mean of the initial state distribution. 137 | init_sd : float 138 | Standard deviation of the initial state distribution. 139 | V : :class:`jax.Array` 140 | Process noise covariance matrix. 141 | 142 | Methods 143 | ------- 144 | sample_init_states(batch_size, key) 145 | Samples initial states for the environment. 146 | drift(t, state, args) 147 | Computes the drift function for the environment. 148 | diffusion(t, state, args) 149 | Computes the diffusion function for the environment. 150 | """ 151 | 152 | def __init__(self, n_dim: int = 3, process_noise: float = 0) -> None: 153 | n_var = n_dim 154 | super().__init__(n_var, process_noise) 155 | 156 | self.F = 8 157 | self.init_mu = jnp.ones(self.n_var) * self.F 158 | self.init_sd = 0.1 159 | 160 | self.V = self.process_noise * jnp.eye(self.n_var) 161 | 162 | def sample_init_states(self, batch_size: int, key: jrandom.PRNGKey) -> Array: 163 | """ 164 | Samples initial states for the environment. 165 | 166 | Parameters 167 | ---------- 168 | batch_size : int 169 | Number of initial states to sample. 170 | key : :class:`jax.random.PRNGKey` 171 | Random key for sampling. 172 | 173 | Returns 174 | ------- 175 | :class:`jax.Array` 176 | Initial states. 
177 | """ 178 | return self.init_mu + self.init_sd * jrandom.normal(key, shape=(batch_size, self.n_var)) 179 | 180 | def drift(self, t: float, state: Array, args: Tuple) -> Array: 181 | """ 182 | Computes the drift function for the environment. 183 | 184 | Parameters 185 | ---------- 186 | t : float 187 | Current time. 188 | state : :class:`jax.Array` 189 | Current state. 190 | args : tuple 191 | Additional arguments. 192 | 193 | Returns 194 | ------- 195 | :class:`jax.Array` 196 | Drift. 197 | """ 198 | f = lambda x_cur, x_next, x_prev1, x_prev2: (x_next - x_prev2) * x_prev1 - x_cur + self.F 199 | return jax.vmap(f)(state, jnp.roll(state, -1), jnp.roll(state, 1), jnp.roll(state, 2)) 200 | 201 | def diffusion(self, t: float, state: Array, args: Tuple) -> Array: 202 | """ 203 | Computes the diffusion function for the environment. 204 | 205 | Parameters 206 | ---------- 207 | t : float 208 | Current time. 209 | state : :class:`jax.Array` 210 | Current state. 211 | args : tuple 212 | Additional arguments. 213 | 214 | Returns 215 | ------- 216 | :class:`jax.Array` 217 | Diffusion. 218 | """ 219 | return self.V -------------------------------------------------------------------------------- /kozax/environments/SR_environments/lotka_volterra.py: -------------------------------------------------------------------------------- 1 | """ 2 | kozax: Genetic programming framework in JAX 3 | 4 | Copyright (c) 2024 sdevries0 5 | 6 | This work is licensed under the Creative Commons Attribution-NonCommercial-NoDerivs 4.0 International License. 7 | """ 8 | 9 | import jax 10 | import jax.numpy as jnp 11 | import jax.random as jrandom 12 | from kozax.environments.SR_environments.time_series_environment_base import EnvironmentBase 13 | from jaxtyping import Array 14 | from typing import Tuple 15 | 16 | class LotkaVolterra(EnvironmentBase): 17 | """ 18 | Lotka-Volterra environment for symbolic regression tasks. 
19 | 20 | Parameters 21 | ---------- 22 | process_noise : float, optional 23 | Standard deviation of the process noise. Default is 0. 24 | 25 | Attributes 26 | ---------- 27 | init_mu : :class:`jax.Array` 28 | Mean of the initial state distribution. 29 | init_sd : float 30 | Standard deviation of the initial state distribution. 31 | alpha : float 32 | Parameter alpha of the Lotka-Volterra system. 33 | beta : float 34 | Parameter beta of the Lotka-Volterra system. 35 | delta : float 36 | Parameter delta of the Lotka-Volterra system. 37 | gamma : float 38 | Parameter gamma of the Lotka-Volterra system. 39 | V : :class:`jax.Array` 40 | Process noise covariance matrix. 41 | 42 | Methods 43 | ------- 44 | sample_init_states(batch_size, key) 45 | Samples initial states for the environment. 46 | drift(t, state, args) 47 | Computes the drift function for the environment. 48 | diffusion(t, state, args) 49 | Computes the diffusion function for the environment. 50 | """ 51 | 52 | def __init__(self, process_noise: float = 0) -> None: 53 | n_var = 2 54 | super().__init__(n_var, process_noise) 55 | 56 | self.init_mu = jnp.array([10, 10]) 57 | self.init_sd = 2 58 | 59 | self.alpha = 1.1 60 | self.beta = 0.4 61 | self.delta = 0.1 62 | self.gamma = 0.4 63 | self.V = self.process_noise * jnp.eye(self.n_var) 64 | 65 | def sample_init_states(self, batch_size: int, key: jrandom.PRNGKey) -> Array: 66 | """ 67 | Samples initial states for the environment. 68 | 69 | Parameters 70 | ---------- 71 | batch_size : int 72 | Number of initial states to sample. 73 | key : :class:`jax.random.PRNGKey` 74 | Random key for sampling. 75 | 76 | Returns 77 | ------- 78 | :class:`jax.Array` 79 | Initial states. 80 | """ 81 | return jrandom.uniform(key, shape=(batch_size, self.n_var), minval=5, maxval=15) 82 | 83 | def drift(self, t: float, state: Array, args: Tuple) -> Array: 84 | """ 85 | Computes the drift function for the environment. 
86 | 87 | Parameters 88 | ---------- 89 | t : float 90 | Current time. 91 | state : :class:`jax.Array` 92 | Current state. 93 | args : tuple 94 | Additional arguments. 95 | 96 | Returns 97 | ------- 98 | :class:`jax.Array` 99 | Drift. 100 | """ 101 | return jnp.array([self.alpha * state[0] - self.beta * state[0] * state[1], self.delta * state[0] * state[1] - self.gamma * state[1]]) 102 | 103 | def diffusion(self, t: float, state: Array, args: Tuple) -> Array: 104 | """ 105 | Computes the diffusion function for the environment. 106 | 107 | Parameters 108 | ---------- 109 | t : float 110 | Current time. 111 | state : :class:`jax.Array` 112 | Current state. 113 | args : tuple 114 | Additional arguments. 115 | 116 | Returns 117 | ------- 118 | :class:`jax.Array` 119 | Diffusion. 120 | """ 121 | return self.V -------------------------------------------------------------------------------- /kozax/environments/SR_environments/time_series_environment_base.py: -------------------------------------------------------------------------------- 1 | """ 2 | kozax: Genetic programming framework in JAX 3 | 4 | Copyright (c) 2024 sdevries0 5 | 6 | This work is licensed under the Creative Commons Attribution-NonCommercial-NoDerivs 4.0 International License. 7 | """ 8 | 9 | from jax.random import PRNGKey 10 | import abc 11 | from typing import Tuple, Any 12 | from jaxtyping import Array 13 | 14 | class EnvironmentBase(abc.ABC): 15 | """ 16 | Abstract base class for time series environments in symbolic regression tasks. 17 | 18 | Parameters 19 | ---------- 20 | n_var : int 21 | Number of variables in the state. 22 | 23 | Methods 24 | ------- 25 | sample_init_states(batch_size, key) 26 | Samples initial states for the environment. 27 | drift(t, state, args) 28 | Computes the drift function for the environment. 29 | diffusion(t, state, args) 30 | Computes the diffusion function for the environment. 31 | terminate_event(state, **kwargs) 32 | Checks if the termination condition is met. 
33 | """ 34 | 35 | def __init__(self, n_var: int, process_noise: float) -> None: 36 | self.n_var = n_var 37 | self.process_noise = process_noise 38 | 39 | @abc.abstractmethod 40 | def sample_init_states(self, batch_size: int, key: PRNGKey) -> Any: 41 | """ 42 | Samples initial states for the environment. 43 | 44 | Parameters 45 | ---------- 46 | batch_size : int 47 | Number of initial states to sample. 48 | key : :class:`jax.random.PRNGKey` 49 | Random key for sampling. 50 | 51 | Returns 52 | ------- 53 | Any 54 | Initial states. 55 | """ 56 | raise NotImplementedError 57 | 58 | @abc.abstractmethod 59 | def drift(self, t: float, state: Array, args: Tuple) -> Array: 60 | """ 61 | Computes the drift function for the environment. 62 | 63 | Parameters 64 | ---------- 65 | t : float 66 | Current time. 67 | state : :class:`jax.Array` 68 | Current state. 69 | args : tuple 70 | Additional arguments. 71 | 72 | Returns 73 | ------- 74 | :class:`jax.Array` 75 | Drift. 76 | """ 77 | raise NotImplementedError 78 | 79 | @abc.abstractmethod 80 | def diffusion(self, t: float, state: Array, args: Tuple) -> Array: 81 | """ 82 | Computes the diffusion function for the environment. 83 | 84 | Parameters 85 | ---------- 86 | t : float 87 | Current time. 88 | state : :class:`jax.Array` 89 | Current state. 90 | args : tuple 91 | Additional arguments. 92 | 93 | Returns 94 | ------- 95 | :class:`jax.Array` 96 | Diffusion. 97 | """ 98 | raise NotImplementedError 99 | 100 | def terminate_event(self, state: Array, **kwargs) -> bool: 101 | """ 102 | Checks if the termination condition is met. 103 | 104 | Parameters 105 | ---------- 106 | state : :class:`jax.Array` 107 | Current state. 108 | kwargs : dict 109 | Additional arguments. 110 | 111 | Returns 112 | ------- 113 | bool 114 | True if the termination condition is met, False otherwise. 
115 | """ 116 | return False -------------------------------------------------------------------------------- /kozax/environments/SR_environments/vd_pol_oscillator.py: -------------------------------------------------------------------------------- 1 | """ 2 | kozax: Genetic programming framework in JAX 3 | 4 | Copyright (c) 2024 sdevries0 5 | 6 | This work is licensed under the Creative Commons Attribution-NonCommercial-NoDerivs 4.0 International License. 7 | """ 8 | 9 | import jax 10 | import jax.numpy as jnp 11 | import jax.random as jrandom 12 | from kozax.environments.SR_environments.time_series_environment_base import EnvironmentBase 13 | from jaxtyping import Array 14 | from typing import Tuple 15 | 16 | class VanDerPolOscillator(EnvironmentBase): 17 | """ 18 | Van der Pol Oscillator environment for symbolic regression tasks. 19 | 20 | Parameters 21 | ---------- 22 | process_noise : float, optional 23 | Standard deviation of the process noise. Default is 0. 24 | 25 | Attributes 26 | ---------- 27 | init_mu : :class:`jax.Array` 28 | Mean of the initial state distribution. 29 | init_sd : :class:`jax.Array` 30 | Standard deviation of the initial state distribution. 31 | mu : float 32 | Parameter mu of the Van der Pol system. 33 | V : :class:`jax.Array` 34 | Process noise covariance matrix. 35 | 36 | Methods 37 | ------- 38 | sample_init_states(batch_size, key) 39 | Samples initial states for the environment. 40 | drift(t, state, args) 41 | Computes the drift function for the environment. 42 | diffusion(t, state, args) 43 | Computes the diffusion function for the environment. 
44 | """ 45 | 46 | def __init__(self, process_noise: float = 0) -> None: 47 | n_var = 2 48 | super().__init__(n_var, process_noise) 49 | 50 | self.init_mu = jnp.array([0, 0]) 51 | self.init_sd = jnp.array([1.0, 1.0]) 52 | 53 | self.mu = 1 54 | self.V = self.process_noise * jnp.eye(self.n_var) 55 | 56 | def sample_init_states(self, batch_size: int, key: jrandom.PRNGKey) -> Array: 57 | """ 58 | Samples initial states for the environment. 59 | 60 | Parameters 61 | ---------- 62 | batch_size : int 63 | Number of initial states to sample. 64 | key : :class:`jax.random.PRNGKey` 65 | Random key for sampling. 66 | 67 | Returns 68 | ------- 69 | :class:`jax.Array` 70 | Initial states. 71 | """ 72 | return self.init_mu + self.init_sd * jrandom.normal(key, shape=(batch_size, self.n_var)) 73 | 74 | def drift(self, t: float, state: Array, args: Tuple) -> Array: 75 | """ 76 | Computes the drift function for the environment. 77 | 78 | Parameters 79 | ---------- 80 | t : float 81 | Current time. 82 | state : :class:`jax.Array` 83 | Current state. 84 | args : tuple 85 | Additional arguments. 86 | 87 | Returns 88 | ------- 89 | :class:`jax.Array` 90 | Drift. 91 | """ 92 | return jnp.array([state[1], self.mu * (1 - state[0]**2) * state[1] - state[0]]) 93 | 94 | def diffusion(self, t: float, state: Array, args: Tuple) -> Array: 95 | """ 96 | Computes the diffusion function for the environment. 97 | 98 | Parameters 99 | ---------- 100 | t : float 101 | Current time. 102 | state : :class:`jax.Array` 103 | Current state. 104 | args : tuple 105 | Additional arguments. 106 | 107 | Returns 108 | ------- 109 | :class:`jax.Array` 110 | Diffusion. 
111 | """ 112 | return self.V -------------------------------------------------------------------------------- /kozax/environments/control_environments/acrobot.py: -------------------------------------------------------------------------------- 1 | """ 2 | kozax: Genetic programming framework in JAX 3 | 4 | Copyright (c) 2024 sdevries0 5 | 6 | This work is licensed under the Creative Commons Attribution-NonCommercial-NoDerivs 4.0 International License. 7 | """ 8 | 9 | import jax 10 | import jax.numpy as jnp 11 | import jax.random as jrandom 12 | from kozax.environments.control_environments.control_environment_base import EnvironmentBase 13 | from jaxtyping import Array 14 | from typing import Tuple 15 | 16 | class Acrobot(EnvironmentBase): 17 | """ 18 | Acrobot environment for control tasks. 19 | 20 | Parameters 21 | ---------- 22 | process_noise : float 23 | Standard deviation of the process noise. 24 | obs_noise : float 25 | Standard deviation of the observation noise. 26 | n_obs : int, optional 27 | Number of observations. Default is 4. 28 | 29 | Attributes 30 | ---------- 31 | n_var : int 32 | Number of variables in the state. 33 | n_control_inputs : int 34 | Number of control inputs. 35 | n_targets : int 36 | Number of targets. 37 | n_dim : int 38 | Number of dimensions. 39 | init_bounds : :class:`jax.Array` 40 | Bounds for initial state sampling. 41 | R : :class:`jax.Array` 42 | Control cost matrix. 43 | 44 | Methods 45 | ------- 46 | sample_init_states(batch_size, key) 47 | Samples initial states for the environment. 48 | sample_params(batch_size, mode, ts, key) 49 | Samples parameters for the environment. 50 | f_obs(key, t_x) 51 | Computes the observation function. 52 | initialize_parameters(params, ts) 53 | Initializes the parameters of the environment. 54 | drift(t, state, args) 55 | Computes the drift function for the environment. 56 | diffusion(t, state, args) 57 | Computes the diffusion function for the environment. 
58 | fitness_function(state, control, target, ts) 59 | Computes the fitness function for the environment. 60 | cond_fn_nan(t, y, args, **kwargs) 61 | Checks for NaN or infinite values in the state. 62 | """ 63 | 64 | def __init__(self, process_noise: float = 0.0, obs_noise: float = 0.0, n_obs: int = 4) -> None: 65 | self.n_var = 4 66 | self.n_control_inputs = 1 67 | self.n_targets = 0 68 | self.n_dim = 1 69 | self.init_bounds = jnp.array([0.1, 0.1, 0.1, 0.1]) 70 | super().__init__(process_noise, obs_noise, self.n_var, self.n_control_inputs, self.n_dim, n_obs) 71 | 72 | self.R = 0.1 * jnp.eye(self.n_control_inputs) 73 | 74 | def sample_init_states(self, batch_size: int, key: jrandom.PRNGKey) -> Tuple[Array, Array]: 75 | """ 76 | Samples initial states for the environment. 77 | 78 | Parameters 79 | ---------- 80 | batch_size : int 81 | Number of initial states to sample. 82 | key : :class:`jax.random.PRNGKey` 83 | Random key for sampling. 84 | 85 | Returns 86 | ------- 87 | x0 : :class:`jax.Array` 88 | Initial states. 89 | targets : :class:`jax.Array` 90 | Target states. 91 | """ 92 | init_key, target_key = jrandom.split(key) 93 | x0 = jrandom.uniform(init_key, shape=(batch_size, self.n_var), minval=-self.init_bounds, maxval=self.init_bounds) 94 | targets = jnp.zeros((batch_size, self.n_targets)) 95 | return x0, targets 96 | 97 | def sample_params(self, batch_size: int, mode: str, ts: Array, key: jrandom.PRNGKey) -> Tuple[Array, Array, Array, Array]: 98 | """ 99 | Samples parameters for the environment. 100 | 101 | Parameters 102 | ---------- 103 | batch_size : int 104 | Number of parameters to sample. 105 | mode : str 106 | Mode for sampling parameters. 107 | ts : :class:`jax.Array` 108 | Time steps. 109 | key : :class:`jax.random.PRNGKey` 110 | Random key for sampling. 111 | 112 | Returns 113 | ------- 114 | tuple of :class:`jax.Array` 115 | Sampled parameters. 
116 | """ 117 | l1 = l2 = m1 = m2 = jnp.ones((batch_size)) 118 | return l1, l2, m1, m2 119 | 120 | def f_obs(self, key: jrandom.PRNGKey, t_x: Tuple[float, Array]) -> Tuple[jrandom.PRNGKey, Array]: 121 | """ 122 | Computes the observation function. 123 | 124 | Parameters 125 | ---------- 126 | key : :class:`jax.random.PRNGKey` 127 | Random key for sampling noise. 128 | t_x : tuple of (float, :class:`jax.Array`) 129 | Tuple containing the current time and state. 130 | 131 | Returns 132 | ------- 133 | key : :class:`jax.random.PRNGKey` 134 | Updated random key. 135 | out : :class:`jax.Array` 136 | Observation. 137 | """ 138 | _, out = super().f_obs(key, t_x) 139 | out = jnp.array([(out[0] + jnp.pi) % (2 * jnp.pi) - jnp.pi, (out[1] + jnp.pi) % (2 * jnp.pi) - jnp.pi, out[2], out[3]])[:self.n_obs] 140 | return key, out 141 | 142 | def initialize_parameters(self, params: Tuple[Array, Array, Array, Array], ts: Array) -> None: 143 | """ 144 | Initializes the parameters of the environment. 145 | 146 | Parameters 147 | ---------- 148 | params : tuple of :class:`jax.Array` 149 | Parameters to initialize. 150 | ts : :class:`jax.Array` 151 | Time steps. 152 | """ 153 | l1, l2, m1, m2 = params 154 | self.l1 = l1 # [m] 155 | self.l2 = l2 # [m] 156 | self.m1 = m1 #: [kg] mass of link 1 157 | self.m2 = m2 #: [kg] mass of link 2 158 | self.lc1 = 0.5 * self.l1 #: [m] position of the center of mass of link 1 159 | self.lc2 = 0.5 * self.l2 #: [m] position of the center of mass of link 2 160 | self.moi1 = self.moi2 = 1.0 161 | self.g = 9.81 162 | 163 | self.G = jnp.array([[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]) 164 | self.V = self.process_noise * self.G 165 | 166 | self.C = jnp.eye(self.n_var)[:self.n_obs] 167 | self.W = self.obs_noise * jnp.eye(self.n_obs) 168 | 169 | def drift(self, t: float, state: Array, args: Tuple) -> Array: 170 | """ 171 | Computes the drift function for the environment. 172 | 173 | Parameters 174 | ---------- 175 | t : float 176 | Current time. 
177 | state : :class:`jax.Array` 178 | Current state. 179 | args : tuple 180 | Additional arguments. 181 | 182 | Returns 183 | ------- 184 | :class:`jax.Array` 185 | Drift. 186 | """ 187 | control = jnp.squeeze(args) 188 | control = jnp.clip(control, -1, 1) 189 | theta1, theta2, theta1_dot, theta2_dot = state 190 | 191 | d1 = self.m1 * self.lc1**2 + self.m2 * (self.l1**2 + self.lc2**2 + 2 * self.l1 * self.lc2 * jnp.cos(theta2)) + self.moi1 + self.moi2 192 | d2 = self.m2 * (self.lc2**2 + self.l1 * self.lc2 * jnp.cos(theta2)) + self.moi2 193 | 194 | phi2 = self.m2 * self.lc2 * self.g * jnp.cos(theta1 + theta2 - jnp.pi/2) 195 | phi1 = -self.m2 * self.l1 * self.lc2 * theta2_dot**2 * jnp.sin(theta2) - 2 * self.m2 * self.l1 * self.lc2 * theta1_dot * theta2_dot * jnp.sin(theta1) \ 196 | + (self.m1 * self.lc1 + self.m2 * self.l1) * self.g * jnp.cos(theta1 - jnp.pi/2) + phi2 197 | 198 | if self.n_control_inputs == 1: 199 | theta2_acc = (control + d2/d1 * phi1 - self.m2 * self.l1 * self.lc2 * theta1_dot**2 * jnp.sin(theta2) - phi2) \ 200 | / (self.m2 * self.lc2**2 + self.moi2 - d2**2 / d1) 201 | theta1_acc = -(d2 * theta2_acc + phi1)/d1 202 | else: 203 | c1, c2 = control 204 | theta2_acc = (c1 + d2/d1 * phi1 - self.m2 * self.l1 * self.lc2 * theta1_dot**2 * jnp.sin(theta2) - phi2) \ 205 | / (self.m2 * self.lc2**2 + self.moi2 - d2**2 / d1) 206 | theta1_acc = (c2 - d2 * theta2_acc - phi1)/d1 207 | 208 | return jnp.array([ 209 | theta1_dot, 210 | theta2_dot, 211 | theta1_acc, 212 | theta2_acc 213 | ]) 214 | 215 | def diffusion(self, t: float, state: Array, args: Tuple) -> Array: 216 | """ 217 | Computes the diffusion function for the environment. 218 | 219 | Parameters 220 | ---------- 221 | t : float 222 | Current time. 223 | state : :class:`jax.Array` 224 | Current state. 225 | args : tuple 226 | Additional arguments. 227 | 228 | Returns 229 | ------- 230 | :class:`jax.Array` 231 | Diffusion. 
232 | """ 233 | return self.V 234 | 235 | def fitness_function(self, state: Array, control: Array, target: Array, ts: Array) -> float: 236 | """ 237 | Computes the fitness function for the environment. 238 | 239 | Parameters 240 | ---------- 241 | state : :class:`jax.Array` 242 | Current state. 243 | control : :class:`jax.Array` 244 | Control inputs. 245 | target : :class:`jax.Array` 246 | Target states. 247 | ts : :class:`jax.Array` 248 | Time steps. 249 | 250 | Returns 251 | ------- 252 | float 253 | Fitness value. 254 | """ 255 | reached_threshold = jax.vmap(lambda theta1, theta2: -jnp.cos(theta1) - jnp.cos(theta1 + theta2) > 1.5)(state[:,0], state[:,1]) 256 | first_success = jnp.argmax(reached_threshold) 257 | 258 | control = jnp.clip(control, -1, 1) 259 | 260 | control_cost = jax.vmap(lambda _state, _u: _u @ self.R @ _u)(state, control) 261 | costs = jnp.where((ts / (ts[1] - ts[0])) > first_success, jnp.zeros_like(control_cost), control_cost) 262 | 263 | return (first_success + (first_success == 0) * ts.shape[0] + jnp.sum(costs)) 264 | 265 | def cond_fn_nan(self, t: float, y: Array, args: Tuple, **kwargs) -> float: 266 | """ 267 | Checks for NaN or infinite values in the state. 268 | 269 | Parameters 270 | ---------- 271 | t : float 272 | Current time. 273 | y : :class:`jax.Array` 274 | Current state. 275 | args : tuple 276 | Additional arguments. 277 | kwargs : dict 278 | Additional keyword arguments. 279 | 280 | Returns 281 | ------- 282 | float 283 | -1.0 if NaN or infinite values are found, 1.0 otherwise. 
284 | """ 285 | return jnp.where((jnp.abs(y[2]) > (4 * jnp.pi)) | (jnp.abs(y[3]) > (9 * jnp.pi)) | jnp.any(jnp.isnan(y)) | jnp.any(jnp.isinf(y)), -1.0, 1.0) -------------------------------------------------------------------------------- /kozax/environments/control_environments/cart_pole.py: -------------------------------------------------------------------------------- 1 | """ 2 | kozax: Genetic programming framework in JAX 3 | 4 | Copyright (c) 2024 sdevries0 5 | 6 | This work is licensed under the Creative Commons Attribution-NonCommercial-NoDerivs 4.0 International License. 7 | """ 8 | 9 | import jax 10 | import jax.numpy as jnp 11 | import jax.random as jrandom 12 | from kozax.environments.control_environments.control_environment_base import EnvironmentBase 13 | from jaxtyping import Array 14 | from typing import Tuple 15 | 16 | class CartPole(EnvironmentBase): 17 | """ 18 | CartPole environment for control tasks. 19 | 20 | Parameters 21 | ---------- 22 | process_noise : float 23 | Standard deviation of the process noise. 24 | obs_noise : float 25 | Standard deviation of the observation noise. 26 | n_obs : int, optional 27 | Number of observations. Default is 4. 28 | 29 | Attributes 30 | ---------- 31 | n_var : int 32 | Number of variables in the state. 33 | n_control_inputs : int 34 | Number of control inputs. 35 | n_targets : int 36 | Number of targets. 37 | n_dim : int 38 | Number of dimensions. 39 | init_bounds : :class:`jax.Array` 40 | Bounds for initial state sampling. 41 | Q : :class:`jax.Array` 42 | Process noise covariance matrix. 43 | R : :class:`jax.Array` 44 | Observation noise covariance matrix. 45 | 46 | Methods 47 | ------- 48 | sample_init_states(batch_size, key) 49 | Samples initial states for the environment. 50 | sample_params(batch_size, mode, ts, key) 51 | Samples parameters for the environment. 52 | initialize_parameters(params, ts) 53 | Initializes the parameters of the environment. 
54 | drift(t, state, args) 55 | Computes the drift function for the environment. 56 | diffusion(t, state, args) 57 | Computes the diffusion function for the environment. 58 | fitness_function(state, control, target, ts) 59 | Computes the fitness function for the environment. 60 | terminate_event(state, **kwargs) 61 | Checks if the termination condition is met. 62 | """ 63 | 64 | def __init__(self, process_noise: float = 0.0, obs_noise: float = 0.0, n_obs: int = 4) -> None: 65 | self.n_var = 4 66 | self.n_control_inputs = 1 67 | self.n_targets = 0 68 | self.n_dim = 1 69 | self.init_bounds = jnp.array([0.05, 0.05, 0.05, 0.05]) 70 | super().__init__(process_noise, obs_noise, self.n_var, self.n_control_inputs, self.n_dim, n_obs) 71 | 72 | self.Q = jnp.array(0) 73 | self.R = jnp.array([[0.0]]) 74 | 75 | def sample_init_states(self, batch_size: int, key: jrandom.PRNGKey) -> Tuple[Array, Array]: 76 | """ 77 | Samples initial states for the environment. 78 | 79 | Parameters 80 | ---------- 81 | batch_size : int 82 | Number of initial states to sample. 83 | key : :class:`jax.random.PRNGKey` 84 | Random key for sampling. 85 | 86 | Returns 87 | ------- 88 | x0 : :class:`jax.Array` 89 | Initial states. 90 | targets : :class:`jax.Array` 91 | Target states. 92 | """ 93 | init_key, target_key = jrandom.split(key) 94 | x0 = jrandom.uniform(init_key, shape=(batch_size, self.n_var), minval=-self.init_bounds, maxval=self.init_bounds) 95 | targets = jnp.zeros((batch_size, self.n_targets)) 96 | return x0, targets 97 | 98 | def sample_params(self, batch_size: int, mode: str, ts: Array, key: jrandom.PRNGKey) -> Array: 99 | """ 100 | Samples parameters for the environment. 101 | 102 | Parameters 103 | ---------- 104 | batch_size : int 105 | Number of parameters to sample. 106 | mode : str 107 | Mode for sampling parameters. 108 | ts : :class:`jax.Array` 109 | Time steps. 110 | key : :class:`jax.random.PRNGKey` 111 | Random key for sampling. 
112 | 113 | Returns 114 | ------- 115 | :class:`jax.Array` 116 | Sampled parameters. 117 | """ 118 | return jnp.zeros((batch_size)) 119 | 120 | def initialize_parameters(self, params: Array, ts: Array) -> None: 121 | """ 122 | Initializes the parameters of the environment. 123 | 124 | Parameters 125 | ---------- 126 | params : :class:`jax.Array` 127 | Parameters to initialize. 128 | ts : :class:`jax.Array` 129 | Time steps. 130 | """ 131 | _ = params 132 | self.g = 9.81 133 | self.pole_mass = 0.1 134 | self.pole_length = 0.5 135 | self.cart_mass = 1 136 | 137 | self.G = jnp.array([[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 0], [0, 0, 0, 0]]) 138 | self.V = self.process_noise * self.G 139 | 140 | self.C = jnp.eye(self.n_var)[:self.n_obs] 141 | self.W = self.obs_noise * jnp.eye(self.n_obs) 142 | 143 | def drift(self, t: float, state: Array, args: Tuple) -> Array: 144 | """ 145 | Computes the drift function for the environment. 146 | 147 | Parameters 148 | ---------- 149 | t : float 150 | Current time. 151 | state : :class:`jax.Array` 152 | Current state. 153 | args : tuple 154 | Additional arguments. 155 | 156 | Returns 157 | ------- 158 | :class:`jax.Array` 159 | Drift. 
160 | """ 161 | control = jnp.squeeze(args) 162 | control = jnp.clip(control, -1, 1) 163 | x, theta, x_dot, theta_dot = state 164 | 165 | cos_theta = jnp.cos(theta) 166 | sin_theta = jnp.sin(theta) 167 | 168 | theta_acc = (self.g * sin_theta - cos_theta * ( 169 | control + self.pole_mass * self.pole_length * theta_dot**2 * sin_theta 170 | ) / (self.cart_mass + self.pole_mass)) / ( 171 | self.pole_length * (4/3 - (self.pole_mass * cos_theta**2) / (self.cart_mass + self.pole_mass))) 172 | 173 | x_acc = (control + self.pole_mass * self.pole_length * (theta_dot**2 * sin_theta - theta_acc * cos_theta)) / (self.cart_mass + self.pole_mass) 174 | 175 | return jnp.array([ 176 | x_dot, 177 | theta_dot, 178 | x_acc, 179 | theta_acc 180 | ]) 181 | 182 | def diffusion(self, t: float, state: Array, args: Tuple) -> Array: 183 | """ 184 | Computes the diffusion function for the environment. 185 | 186 | Parameters 187 | ---------- 188 | t : float 189 | Current time. 190 | state : :class:`jax.Array` 191 | Current state. 192 | args : tuple 193 | Additional arguments. 194 | 195 | Returns 196 | ------- 197 | :class:`jax.Array` 198 | Diffusion. 199 | """ 200 | return self.V 201 | 202 | def fitness_function(self, state: Array, control: Array, target: Array, ts: Array) -> float: 203 | """ 204 | Computes the fitness function for the environment. 205 | 206 | Parameters 207 | ---------- 208 | state : :class:`jax.Array` 209 | Current state. 210 | control : :class:`jax.Array` 211 | Control inputs. 212 | target : :class:`jax.Array` 213 | Target states. 214 | ts : :class:`jax.Array` 215 | Time steps. 216 | 217 | Returns 218 | ------- 219 | float 220 | Fitness value. 
221 | """ 222 | invalid_points = jax.vmap(lambda _x, _u: jnp.any(jnp.isinf(_x)) + jnp.isnan(_u))(state, control[:, 0]) 223 | punishment = jnp.ones_like(invalid_points) 224 | 225 | costs = jnp.where(invalid_points, punishment, jnp.zeros_like(punishment)) 226 | 227 | return jnp.sum(costs) 228 | 229 | def terminate_event(self, state: Array, **kwargs) -> bool: 230 | """ 231 | Checks if the termination condition is met. 232 | 233 | Parameters 234 | ---------- 235 | state : :class:`jax.Array` 236 | Current state. 237 | kwargs : dict 238 | Additional arguments. 239 | 240 | Returns 241 | ------- 242 | bool 243 | True if the termination condition is met, False otherwise. 244 | """ 245 | return (jnp.abs(state[1]) > 0.2) | (jnp.abs(state[0]) > 4.8) | jnp.any(jnp.isnan(state)) | jnp.any(jnp.isinf(state)) -------------------------------------------------------------------------------- /kozax/environments/control_environments/control_environment_base.py: -------------------------------------------------------------------------------- 1 | """ 2 | kozax: Genetic programming framework in JAX 3 | 4 | Copyright (c) 2024 sdevries0 5 | 6 | This work is licensed under the Creative Commons Attribution-NonCommercial-NoDerivs 4.0 International License. 
7 | """ 8 | 9 | import jax 10 | import jax.numpy as jnp 11 | import jax.random as jrandom 12 | import abc 13 | 14 | _itemsize_kind_type = { 15 | (1, "i"): jnp.int8, 16 | (2, "i"): jnp.int16, 17 | (4, "i"): jnp.int32, 18 | (8, "i"): jnp.int64, 19 | (2, "f"): jnp.float16, 20 | (4, "f"): jnp.float32, 21 | (8, "f"): jnp.float64, 22 | } 23 | 24 | def force_bitcast_convert_type(val, new_type=jnp.int32): 25 | val = jnp.asarray(val) 26 | intermediate_type = _itemsize_kind_type[new_type.dtype.itemsize, val.dtype.kind] 27 | val = val.astype(intermediate_type) 28 | return jax.lax.bitcast_convert_type(val, new_type) 29 | 30 | class EnvironmentBase(abc.ABC): 31 | def __init__(self, process_noise, obs_noise, n_var, n_control_inputs, n_dim, n_obs): 32 | self.process_noise = process_noise 33 | self.obs_noise = obs_noise 34 | self.n_var = n_var 35 | self.n_control_inputs = n_control_inputs 36 | self.n_dim = n_dim 37 | self.n_obs = n_obs 38 | 39 | @abc.abstractmethod 40 | def initialize_parameters(self, params, ts): 41 | raise NotImplementedError 42 | 43 | @abc.abstractmethod 44 | def sample_init_states(self, batch_size, key): 45 | raise NotImplementedError 46 | 47 | @abc.abstractmethod 48 | def sample_params(self, batch_size, mode, ts, key): 49 | raise NotImplementedError 50 | 51 | def f_obs(self, key, t_x): 52 | t, x = t_x 53 | new_key = jrandom.fold_in(key, force_bitcast_convert_type(t)) 54 | out = self.C@x + jrandom.normal(new_key, shape=(self.n_obs,))@self.W 55 | return key, out 56 | 57 | @abc.abstractmethod 58 | def drift(self, t, state, args): 59 | raise NotImplementedError 60 | 61 | @abc.abstractmethod 62 | def diffusion(self, t, state, args): 63 | raise NotImplementedError 64 | 65 | @abc.abstractmethod 66 | def fitness_function(self, state, control, target, ts): 67 | raise NotImplementedError 68 | 69 | def terminate_event(self, state, **kwargs): 70 | return False -------------------------------------------------------------------------------- 
/kozax/environments/control_environments/harmonic_oscillator.py: -------------------------------------------------------------------------------- 1 | """ 2 | kozax: Genetic programming framework in JAX 3 | 4 | Copyright (c) 2024 sdevries0 5 | 6 | This work is licensed under the Creative Commons Attribution-NonCommercial-NoDerivs 4.0 International License. 7 | """ 8 | 9 | import jax 10 | import jax.numpy as jnp 11 | import jax.random as jrandom 12 | from kozax.environments.control_environments.control_environment_base import EnvironmentBase 13 | from jaxtyping import Array 14 | from typing import Tuple 15 | 16 | class HarmonicOscillator(EnvironmentBase): 17 | """ 18 | Harmonic Oscillator environment for control tasks. 19 | 20 | Parameters 21 | ---------- 22 | process_noise : float 23 | Standard deviation of the process noise. 24 | obs_noise : float 25 | Standard deviation of the observation noise. 26 | n_obs : int, optional 27 | Number of observations. Default is 2. 28 | 29 | Attributes 30 | ---------- 31 | n_dim : int 32 | Number of dimensions. 33 | n_var : int 34 | Number of variables in the state. 35 | n_control_inputs : int 36 | Number of control inputs. 37 | n_targets : int 38 | Number of targets. 39 | mu0 : :class:`jax.Array` 40 | Mean of the initial state distribution. 41 | P0 : :class:`jax.Array` 42 | Scale matrix of the initial state distribution: initial states are ``mu0 + z @ P0`` with standard-normal ``z``, so their covariance is ``P0.T @ P0`` rather than ``P0``. 43 | q : float 44 | Weight on the first state variable in the state cost matrix ``Q``. 45 | r : float 46 | Weight in the control cost matrix ``R``. 47 | Q : :class:`jax.Array` 48 | State cost matrix used by :meth:`fitness_function`. 49 | R : :class:`jax.Array` 50 | Control cost matrix used by :meth:`fitness_function`. 51 | 52 | Methods 53 | ------- 54 | sample_init_states(batch_size, key) 55 | Samples initial states for the environment. 56 | sample_params(batch_size, mode, ts, key) 57 | Samples parameters for the environment. 58 | initialize_parameters(params, ts) 59 | Initializes the parameters of the environment. 60 | drift(t, state, args) 61 | Computes the drift function for the environment.
62 | diffusion(t, state, args) 63 | Computes the diffusion function for the environment. 64 | fitness_function(state, control, target, ts) 65 | Computes the fitness function for the environment. 66 | cond_fn_nan(t, y, args, **kwargs) 67 | Checks for NaN or infinite values in the state. 68 | """ 69 | 70 | def __init__(self, process_noise: float = 0.0, obs_noise: float = 0.0, n_obs: int = 2) -> None: 71 | self.n_dim = 1 72 | self.n_var = 2 73 | self.n_control_inputs = 1 74 | self.n_targets = 1 75 | self.mu0 = jnp.zeros(self.n_var) 76 | self.P0 = jnp.eye(self.n_var) * jnp.array([3.0, 1.0]) 77 | super().__init__(process_noise, obs_noise, self.n_var, self.n_control_inputs, self.n_dim, n_obs) 78 | 79 | self.q = self.r = 0.5 80 | self.Q = jnp.array([[self.q, 0], [0, 0]]) 81 | self.R = jnp.array([[self.r]]) 82 | 83 | def sample_init_states(self, batch_size: int, key: jrandom.PRNGKey) -> Tuple[Array, Array]: 84 | """ 85 | Samples initial states for the environment. 86 | 87 | Parameters 88 | ---------- 89 | batch_size : int 90 | Number of initial states to sample. 91 | key : :class:`jax.random.PRNGKey` 92 | Random key for sampling. 93 | 94 | Returns 95 | ------- 96 | x0 : :class:`jax.Array` 97 | Initial states ``mu0 + z @ P0`` with standard-normal ``z``. 98 | targets : :class:`jax.Array` 99 | Targets for the first state variable, drawn uniformly from [-3, 3]. 100 | """ 101 | init_key, target_key = jrandom.split(key) 102 | x0 = self.mu0 + jrandom.normal(init_key, shape=(batch_size, self.n_var)) @ self.P0 103 | targets = jrandom.uniform(target_key, shape=(batch_size, self.n_targets), minval=-3, maxval=3) 104 | return x0, targets 105 | 106 | def sample_params(self, batch_size: int, mode: str, ts: Array, key: jrandom.PRNGKey) -> Tuple[Array, Array]: 107 | """ 108 | Samples parameters for the environment. 109 | 110 | Parameters 111 | ---------- 112 | batch_size : int 113 | Number of parameters to sample. 114 | mode : str 115 | Mode for sampling parameters. Options are "Constant", "Different", "Changing". Any other value leaves the outputs unassigned and raises ``UnboundLocalError``. 116 | ts : :class:`jax.Array` 117 | Time steps.
118 | key : :class:`jax.random.PRNGKey` 119 | Random key for sampling. 120 | 121 | Returns 122 | ------- 123 | omegas : :class:`jax.Array` 124 | Sampled omega parameters, shape ``(batch_size,)``, or ``(batch_size, len(ts))`` for "Changing". 125 | zetas : :class:`jax.Array` 126 | Sampled zeta parameters, shaped like ``omegas``. 127 | """ 128 | omega_key, zeta_key, args_key = jrandom.split(key, 3) 129 | if mode == "Constant": 130 | omegas = jnp.ones((batch_size)) 131 | zetas = jnp.zeros((batch_size)) 132 | elif mode == "Different": 133 | omegas = jrandom.uniform(omega_key, shape=(batch_size,), minval=0.0, maxval=2.0) 134 | zetas = jrandom.uniform(zeta_key, shape=(batch_size,), minval=0.0, maxval=1.5) 135 | elif mode == "Changing": 136 | decay_factors = jrandom.uniform(args_key, shape=(batch_size, 2), minval=0.98, maxval=1.02) 137 | init_omegas = jrandom.uniform(omega_key, shape=(batch_size,), minval=0.5, maxval=1.5) 138 | init_zetas = jrandom.uniform(zeta_key, shape=(batch_size,), minval=0.0, maxval=1.0) 139 | omegas = jax.vmap(lambda o, d, t: o * (d ** t), in_axes=[0, 0, None])(init_omegas, decay_factors[:, 0], ts) 140 | zetas = jax.vmap(lambda z, d, t: z * (d ** t), in_axes=[0, 0, None])(init_zetas, decay_factors[:, 1], ts) 141 | return omegas, zetas 142 | 143 | def initialize_parameters(self, params: Tuple[Array, Array], ts: Array) -> None: 144 | """ 145 | Initializes the parameters of the environment. 146 | 147 | Parameters 148 | ---------- 149 | params : tuple of :class:`jax.Array` 150 | ``(omega, zeta)``; they enter linearly as ``A = [[0, 1], [-omega, -zeta]]`` (not as ``omega**2`` and ``2 * zeta * omega``). 151 | ts : :class:`jax.Array` 152 | Time steps. 153 | """ 154 | omega, zeta = params 155 | self.A = jnp.array([[0, 1], [-omega, -zeta]]) 156 | self.b = jnp.array([[0.0, 1.0]]).T 157 | self.G = jnp.array([[0, 0], [0, 1]]) 158 | self.V = self.process_noise * self.G 159 | self.C = jnp.eye(self.n_var)[:self.n_obs] 160 | self.W = self.obs_noise * jnp.eye(self.n_obs) 161 | 162 | def drift(self, t: float, state: Array, args: Tuple) -> Array: 163 | """ 164 | Computes the drift function for the environment.
165 | 166 | Parameters 167 | ---------- 168 | t : float 169 | Current time. 170 | state : :class:`jax.Array` 171 | Current state. 172 | args : tuple 173 | Control input, applied as ``b @ args``. 174 | 175 | Returns 176 | ------- 177 | :class:`jax.Array` 178 | Drift ``A @ state + b @ args``. 179 | """ 180 | return self.A @ state + self.b @ args 181 | 182 | def diffusion(self, t: float, state: Array, args: Tuple) -> Array: 183 | """ 184 | Computes the diffusion function for the environment. 185 | 186 | Parameters 187 | ---------- 188 | t : float 189 | Current time. 190 | state : :class:`jax.Array` 191 | Current state. 192 | args : tuple 193 | Additional arguments. 194 | 195 | Returns 196 | ------- 197 | :class:`jax.Array` 198 | Diffusion. 199 | """ 200 | return self.V 201 | 202 | def fitness_function(self, state: Array, control: Array, target: Array, ts: Array) -> float: 203 | """ 204 | Computes the fitness function for the environment. 205 | 206 | Parameters 207 | ---------- 208 | state : :class:`jax.Array` 209 | State trajectory, one row per time step. 210 | control : :class:`jax.Array` 211 | Control inputs. 212 | target : :class:`jax.Array` 213 | Target value for the first state variable; the desired state is ``x_d = [target, 0]``. 214 | ts : :class:`jax.Array` 215 | Time steps. 216 | 217 | Returns 218 | ------- 219 | float 220 | Sum over time of ``(x - x_d)^T Q (x - x_d) + (u - u_d)^T R (u - u_d)``, where ``u_d = -pinv(b) @ A @ x_d`` is the least-squares control that zeroes the drift at ``x_d``. 221 | """ 222 | x_d = jnp.array([jnp.squeeze(target), 0]) 223 | u_d = -jnp.linalg.pinv(self.b) @ self.A @ x_d 224 | costs = jax.vmap(lambda _state, _u: (_state - x_d).T @ self.Q @ (_state - x_d) + (_u - u_d) @ self.R @ (_u - u_d))(state, control) 225 | return jnp.sum(costs) 226 | 227 | def cond_fn_nan(self, t: float, y: Array, args: Tuple, **kwargs) -> float: 228 | """ 229 | Checks for NaN or infinite values in the state. 230 | 231 | Parameters 232 | ---------- 233 | t : float 234 | Current time. 235 | y : :class:`jax.Array` 236 | Current state. 237 | args : tuple 238 | Additional arguments. 239 | kwargs : dict 240 | Additional keyword arguments. 241 | 242 | Returns 243 | ------- 244 | float 245 | -1.0 if NaN or infinite values are found, 1.0 otherwise.
246 | """ 247 | return jnp.where(jnp.any(jnp.isinf(y) + jnp.isnan(y)), -1.0, 1.0) -------------------------------------------------------------------------------- /kozax/environments/control_environments/reactor.py: -------------------------------------------------------------------------------- 1 | """ 2 | kozax: Genetic programming framework in JAX 3 | 4 | Copyright (c) 2024 sdevries0 5 | 6 | This work is licensed under the Creative Commons Attribution-NonCommercial-NoDerivs 4.0 International License. 7 | """ 8 | 9 | import jax.numpy as jnp 10 | import jax 11 | import jax.random as jrandom 12 | import diffrax 13 | from kozax.environments.control_environments.control_environment_base import EnvironmentBase 14 | from jaxtyping import Array 15 | from typing import Tuple 16 | 17 | class StirredTankReactor(EnvironmentBase): 18 | """ 19 | Stirred Tank Reactor environment for control tasks. 20 | 21 | Parameters 22 | ---------- 23 | process_noise : float 24 | Standard deviation of the process noise. 25 | obs_noise : float 26 | Standard deviation of the observation noise. 27 | n_obs : int, optional 28 | Number of observations. Default is 3. 29 | n_targets : int, optional 30 | Number of targets. Default is 1. 31 | max_control : :class:`jax.Array`, optional 32 | Maximum control values. Default is jnp.array([300]). 33 | external_f : callable, optional 34 | External influence function. Default is lambda t: 0.0. 35 | 36 | Attributes 37 | ---------- 38 | n_var : int 39 | Number of variables in the state. 40 | n_control_inputs : int 41 | Number of control inputs. 42 | n_dim : int 43 | Number of dimensions. 44 | n_targets : int 45 | Number of targets. 46 | init_lower_bounds : :class:`jax.Array` 47 | Lower bounds for initial state sampling. 48 | init_upper_bounds : :class:`jax.Array` 49 | Upper bounds for initial state sampling. 50 | max_control : :class:`jax.Array` 51 | Maximum control values. 52 | Q : :class:`jax.Array` 53 | Process noise covariance matrix. 
54 | r : :class:`jax.Array` 55 | Observation noise covariance matrix. 56 | external_f : callable 57 | External influence function. 58 | 59 | Methods 60 | ------- 61 | initialize_parameters(params, ts) 62 | Initializes the parameters of the environment. 63 | sample_param_change(key, batch_size, ts, low, high) 64 | Samples parameter changes over time. 65 | sample_params(batch_size, mode, ts, key) 66 | Samples parameters for the environment. 67 | sample_init_states(batch_size, key) 68 | Samples initial states for the environment. 69 | drift(t, state, args) 70 | Computes the drift function for the environment. 71 | diffusion(t, state, args) 72 | Computes the diffusion function for the environment. 73 | fitness_function(state, control, targets, ts) 74 | Computes the fitness function for the environment. 75 | cond_fn_nan(t, y, args, **kwargs) 76 | Checks for NaN or infinite values in the state. 77 | """ 78 | 79 | def __init__(self, process_noise: float = 0.0, obs_noise: float = 0.0, n_obs: int = 3, n_targets: int = 1, max_control: Array = jnp.array([300]), external_f: callable = lambda t: 0.0) -> None: 80 | self.process_noise = process_noise 81 | self.obs_noise = obs_noise 82 | self.n_var = 3 83 | self.n_control_inputs = 1 84 | self.n_dim = 1 85 | self.n_targets = n_targets 86 | self.init_lower_bounds = jnp.array([275, 350, 0.5]) 87 | self.init_upper_bounds = jnp.array([300, 375, 1.0]) 88 | self.max_control = max_control 89 | super().__init__(process_noise, obs_noise, self.n_var, self.n_control_inputs, self.n_dim, n_obs) 90 | 91 | self.Q = jnp.array([[0, 0, 0], [0, 0.01, 0], [0, 0, 0]]) 92 | self.r = jnp.array([[0.0001]]) 93 | self.external_f = external_f 94 | 95 | def initialize_parameters(self, params: Tuple[Array, Array, Array, Array, Array, Array, Array, Array], ts: Array) -> None: 96 | """ 97 | Initializes the parameters of the environment. 98 | 99 | Parameters 100 | ---------- 101 | params : tuple of :class:`jax.Array` 102 | Parameters to initialize. 
103 | ts : :class:`jax.Array` 104 | Time steps. 105 | """ 106 | Vol, Cp, dHr, UA, q, Tf, Tcf, Volc = params 107 | self.Ea = 72750 # activation energy J/gmol 108 | self.R = 8.314 # gas constant J/gmol/K 109 | self.k0 = 7.2e10 # Arrhenius rate constant 1/min 110 | self.Vol = Vol # Volume [L] 111 | self.Cp = Cp # Heat capacity [J/g/K] 112 | self.dHr = dHr # Enthalpy of reaction [J/mol] 113 | self.UA = UA # Heat transfer [J/min/K] 114 | self.q = q # Flowrate [L/min] 115 | self.Cf = 1.0 # Inlet feed concentration [mol/L] 116 | self.Tf = diffrax.LinearInterpolation(ts, Tf) # Inlet feed temperature [K] 117 | self.Tcf = Tcf # Coolant feed temperature [K] 118 | self.Volc = Volc # Cooling jacket volume 119 | 120 | self.k = lambda T: self.k0 * jnp.exp(-self.Ea / self.R / T) 121 | 122 | self.G = jnp.eye(self.n_var) * jnp.array([6, 6, 0.05]) 123 | self.process_noise_ts = diffrax.LinearInterpolation(ts, jnp.linspace(self.process_noise[0], self.process_noise[1], ts.shape[0])) 124 | 125 | self.C = jnp.eye(self.n_var)[:self.n_obs] 126 | self.W = self.obs_noise * jnp.eye(self.n_obs) * (jnp.array([15, 15, 0.1])[:self.n_obs]) 127 | 128 | self.max_control_ts = diffrax.LinearInterpolation(ts, jnp.hstack([mc * jnp.ones(int(ts.shape[0] // self.max_control.shape[0])) for mc in self.max_control])) 129 | self.external_influence = diffrax.LinearInterpolation(ts, jax.vmap(self.external_f)(ts)) 130 | 131 | def sample_param_change(self, key: jrandom.PRNGKey, batch_size: int, ts: Array, low: float, high: float) -> Array: 132 | """ 133 | Samples parameter changes over time. 134 | 135 | Parameters 136 | ---------- 137 | key : :class:`jax.random.PRNGKey` 138 | Random key for sampling. 139 | batch_size : int 140 | Number of samples. 141 | ts : :class:`jax.Array` 142 | Time steps. 143 | low : float 144 | Lower bound for sampling. 145 | high : float 146 | Upper bound for sampling. 147 | 148 | Returns 149 | ------- 150 | :class:`jax.Array` 151 | Sampled parameter values. 
152 | """ 153 | init_key, decay_key = jrandom.split(key) 154 | decay_factors = jrandom.uniform(decay_key, shape=(batch_size,), minval=1.01, maxval=1.02) 155 | init_values = jrandom.uniform(init_key, shape=(batch_size,), minval=low, maxval=high) 156 | values = jax.vmap(lambda v, d, t: v * (d ** t), in_axes=[0, 0, None])(init_values, decay_factors, ts) 157 | return values 158 | 159 | def sample_params(self, batch_size: int, mode: str, ts: Array, key: jrandom.PRNGKey) -> Tuple[Array, Array, Array, Array, Array, Array, Array, Array]: 160 | """ 161 | Samples parameters for the environment. 162 | 163 | Parameters 164 | ---------- 165 | batch_size : int 166 | Number of samples. 167 | mode : str 168 | Sampling mode. Options are "Constant", "Different", "Changing". 169 | ts : :class:`jax.Array` 170 | Time steps. 171 | key : :class:`jax.random.PRNGKey` 172 | Random key for sampling. 173 | 174 | Returns 175 | ------- 176 | tuple of :class:`jax.Array` 177 | Sampled parameters. 178 | """ 179 | if mode == "Constant": 180 | Vol = 100 * jnp.ones(batch_size) 181 | Cp = 239 * jnp.ones(batch_size) 182 | dHr = -5.0e4 * jnp.ones(batch_size) 183 | UA = 5.0e4 * jnp.ones(batch_size) 184 | q = 100 * jnp.ones(batch_size) 185 | Tf = 300 * jnp.ones((batch_size, ts.shape[0])) 186 | Tcf = 300 * jnp.ones(batch_size) 187 | Volc = 20.0 * jnp.ones(batch_size) 188 | elif mode == "Different": 189 | keys = jrandom.split(key, 8) 190 | Vol = jrandom.uniform(keys[0], (batch_size,), minval=75, maxval=150) 191 | Cp = jrandom.uniform(keys[1], (batch_size,), minval=200, maxval=350) 192 | dHr = jrandom.uniform(keys[2], (batch_size,), minval=-55000, maxval=-45000) 193 | UA = jrandom.uniform(keys[3], (batch_size,), minval=25000, maxval=75000) 194 | q = jrandom.uniform(keys[4], (batch_size,), minval=75, maxval=125) 195 | Tf = jnp.repeat(jrandom.uniform(keys[5], (batch_size,), minval=300, maxval=350)[:, None], ts.shape[0], axis=1) 196 | Tcf = jrandom.uniform(keys[6], (batch_size,), minval=250, maxval=300) 197 | 
Volc = jrandom.uniform(keys[7], (batch_size,), minval=10, maxval=30) 198 | elif mode == "Changing": 199 | keys = jrandom.split(key, 8) 200 | Vol = jrandom.uniform(keys[0], (batch_size,), minval=75, maxval=150) 201 | Cp = jrandom.uniform(keys[1], (batch_size,), minval=200, maxval=350) 202 | dHr = jrandom.uniform(keys[2], (batch_size,), minval=-55000, maxval=-45000) 203 | UA = jrandom.uniform(keys[3], (batch_size,), minval=25000, maxval=75000) 204 | q = jrandom.uniform(keys[4], (batch_size,), minval=75, maxval=125) 205 | Tf = self.sample_param_change(keys[5], batch_size, ts, 300, 350) 206 | Tcf = jrandom.uniform(keys[6], (batch_size,), minval=250, maxval=300) 207 | Volc = jrandom.uniform(keys[7], (batch_size,), minval=10, maxval=30) 208 | return Vol, Cp, dHr, UA, q, Tf, Tcf, Volc 209 | 210 | def sample_init_states(self, batch_size: int, key: jrandom.PRNGKey) -> Tuple[Array, Array]: 211 | """ 212 | Samples initial states for the environment. 213 | 214 | Parameters 215 | ---------- 216 | batch_size : int 217 | Number of initial states to sample. 218 | key : :class:`jax.random.PRNGKey` 219 | Random key for sampling. 220 | 221 | Returns 222 | ------- 223 | x0 : :class:`jax.Array` 224 | Initial states. 225 | targets : :class:`jax.Array` 226 | Target states. 227 | """ 228 | init_key, target_key = jrandom.split(key) 229 | x0 = jrandom.uniform(init_key, shape=(batch_size, self.n_var), minval=self.init_lower_bounds, maxval=self.init_upper_bounds) 230 | targets = jrandom.uniform(target_key, shape=(batch_size, self.n_targets), minval=400, maxval=480) 231 | return x0, targets 232 | 233 | def drift(self, t: float, state: Array, args: Tuple) -> Array: 234 | """ 235 | Computes the drift function for the environment. 236 | 237 | Parameters 238 | ---------- 239 | t : float 240 | Current time. 241 | state : :class:`jax.Array` 242 | Current state. 243 | args : tuple 244 | Additional arguments. 245 | 246 | Returns 247 | ------- 248 | :class:`jax.Array` 249 | Drift. 
250 | """ 251 | Tc, T, c = state 252 | control = jnp.squeeze(args) 253 | control = jnp.clip(control, 0, 300) 254 | 255 | dc = (self.q / self.Vol) * (self.Cf - c) - self.k(T) * c 256 | dT = (self.q / self.Vol) * (self.Tf.evaluate(t) - T) + (-self.dHr / self.Cp) * self.k(T) * c + (self.UA / self.Vol / self.Cp) * (Tc - T) + self.external_influence.evaluate(t) 257 | dTc = (control / self.Volc) * (self.Tcf - Tc) + (self.UA / self.Volc / self.Cp) * (T - Tc) 258 | 259 | return jnp.array([dTc, dT, dc]) 260 | 261 | def diffusion(self, t: float, state: Array, args: Tuple) -> Array: 262 | """ 263 | Computes the diffusion function for the environment. 264 | 265 | Parameters 266 | ---------- 267 | t : float 268 | Current time. 269 | state : :class:`jax.Array` 270 | Current state. 271 | args : tuple 272 | Additional arguments. 273 | 274 | Returns 275 | ------- 276 | :class:`jax.Array` 277 | Diffusion. 278 | """ 279 | return self.process_noise_ts.evaluate(t) * self.G 280 | 281 | def fitness_function(self, state: Array, control: Array, targets: Array, ts: Array) -> float: 282 | """ 283 | Computes the fitness function for the environment. 284 | 285 | Parameters 286 | ---------- 287 | state : :class:`jax.Array` 288 | Current state. 289 | control : :class:`jax.Array` 290 | Control inputs. 291 | targets : :class:`jax.Array` 292 | Target states. 293 | ts : :class:`jax.Array` 294 | Time steps. 295 | 296 | Returns 297 | ------- 298 | float 299 | Fitness value. 300 | """ 301 | x_d = jax.vmap(lambda tar: jnp.array([0, tar, 0]))(targets) 302 | costs = jax.vmap(lambda _state, _u, _x_d: (_state - _x_d).T @ self.Q @ (_state - _x_d) + (_u) @ self.r @ (_u))(state, control, x_d) 303 | return jnp.sum(costs) 304 | 305 | def cond_fn_nan(self, t: float, y: Array, args: Tuple, **kwargs) -> float: 306 | """ 307 | Checks for NaN or infinite values in the state. 308 | 309 | Parameters 310 | ---------- 311 | t : float 312 | Current time. 313 | y : :class:`jax.Array` 314 | Current state. 
315 | args : tuple 316 | Additional arguments. 317 | kwargs : dict 318 | Additional keyword arguments. 319 | 320 | Returns 321 | ------- 322 | float 323 | -1.0 if NaN or infinite values are found, 1.0 otherwise. 324 | """ 325 | return jnp.where(jnp.any(jnp.isinf(y) + jnp.isnan(y)), -1.0, 1.0) -------------------------------------------------------------------------------- /kozax/fitness_functions/Gymnax_fitness_function.py: -------------------------------------------------------------------------------- 1 | """ 2 | kozax: Genetic programming framework in JAX 3 | 4 | Copyright (c) 2024 sdevries0 5 | 6 | This work is licensed under the Creative Commons Attribution-NonCommercial-NoDerivs 4.0 International License. 7 | """ 8 | 9 | import jax 10 | from jaxtyping import Array 11 | import jax.numpy as jnp 12 | import jax.random as jr 13 | from typing import Tuple, Callable, Any 14 | from kozax.fitness_functions.base_fitness_function import BaseFitnessFunction 15 | import gymnax 16 | 17 | class GymFitnessFunction(BaseFitnessFunction): 18 | """ 19 | Evaluator for static symbolic policies in control tasks using Gym environments. 20 | 21 | Parameters 22 | ---------- 23 | env_name : str 24 | Name of the Gym environment to be used. 25 | 26 | Attributes 27 | ---------- 28 | env : gymnax environment 29 | The gymnax environment. 30 | env_params : dict 31 | Parameters for the gymnax environment. 32 | num_steps : int 33 | Number of steps in the environment for each episode. 34 | 35 | Methods 36 | ------- 37 | __call__(candidate, keys, tree_evaluator) 38 | Evaluates the candidate on a task. 39 | evaluate_trajectory(candidate, key, tree_evaluator) 40 | Evaluates a rollout of the candidate in the environment. 
41 | """ 42 | 43 | def __init__(self, env_name: str) -> None: 44 | self.env, self.env_params = gymnax.make(env_name) 45 | self.num_steps = self.env_params.max_steps_in_episode 46 | 47 | def __call__(self, candidate: Array, keys: Array, tree_evaluator: Callable) -> float: 48 | """ 49 | Evaluates the candidate on a task. 50 | 51 | Parameters 52 | ---------- 53 | candidate : :class:`jax.Array` 54 | The candidate solution to be evaluated. 55 | keys : :class:`jax.Array` 56 | Random keys for evaluation. 57 | tree_evaluator : :class:`Callable` 58 | Function for evaluating trees. 59 | 60 | Returns 61 | ------- 62 | float 63 | Fitness of the candidate. 64 | """ 65 | reward = jax.vmap(self.evaluate_trajectory, in_axes=(None, 0, None))(candidate, keys, tree_evaluator) 66 | return jnp.mean(reward) 67 | 68 | def evaluate_trajectory(self, candidate: Array, key: jr.PRNGKey, tree_evaluator: Callable) -> Tuple[Array, float]: 69 | """ 70 | Evaluates a rollout of the candidate in the environment. 71 | 72 | Parameters 73 | ---------- 74 | candidate : :class:`jax.Array` 75 | The candidate solution to be evaluated. 76 | key : :class:`jax.random.PRNGKey` 77 | Random key for evaluation. 78 | tree_evaluator : :class:`Callable` 79 | Function for evaluating trees. 80 | 81 | Returns 82 | ------- 83 | reward : float 84 | Total reward obtained during the trajectory. 
85 | """ 86 | key, subkey = jr.split(key) 87 | state, env_state = self.env.reset(subkey, self.env_params) 88 | 89 | def policy(state: Array) -> Array: 90 | """Symbolic policy.""" 91 | a = tree_evaluator(candidate, state) 92 | return a 93 | 94 | def step_fn(carry: Tuple[Array, Any, jr.PRNGKey], _) -> Tuple[Tuple[Array, Any, jr.PRNGKey], Tuple[Array, float, bool]]: 95 | """Step function for lax.scan.""" 96 | state, env_state, key = carry 97 | 98 | # Select action based on policy 99 | action = policy(state) 100 | 101 | # Step the environment 102 | key, subkey = jr.split(key) 103 | next_state, next_env_state, reward, done, _ = self.env.step( 104 | subkey, env_state, action, self.env_params 105 | ) 106 | 107 | return (next_state, next_env_state, key), (state, reward, done) 108 | 109 | # Run the rollout using lax.scan 110 | (final_carry, (states, rewards, dones)) = jax.lax.scan( 111 | step_fn, (state, env_state, key), None, length=self.num_steps 112 | ) 113 | 114 | return -jnp.sum(rewards) -------------------------------------------------------------------------------- /kozax/fitness_functions/ODE_fitness_function.py: -------------------------------------------------------------------------------- 1 | """ 2 | kozax: Genetic programming framework in JAX 3 | 4 | Copyright (c) 2024 sdevries0 5 | 6 | This work is licensed under the Creative Commons Attribution-NonCommercial-NoDerivs 4.0 International License. 7 | """ 8 | 9 | import jax 10 | from jaxtyping import Array 11 | import jax.numpy as jnp 12 | from typing import Tuple, Callable 13 | from kozax.fitness_functions.base_fitness_function import BaseFitnessFunction 14 | import diffrax 15 | 16 | class ODEFitnessFunction(BaseFitnessFunction): 17 | """ 18 | Evaluator for candidates on symbolic regression tasks. 19 | 20 | Parameters 21 | ---------- 22 | solver : :class:`diffrax.AbstractSolver`, optional 23 | Solver used for integration. Default is `diffrax.Euler()`. 
24 | dt0 : float, optional 25 | Initial step size for integration. Default is 0.01. 26 | max_steps : int, optional 27 | The maximum number of steps that can be used in integration. Default is 16**4. 28 | stepsize_controller : :class:`diffrax.AbstractStepSizeController`, optional 29 | Controller for the stepsize during integration. Default is `diffrax.ConstantStepSize()`. 30 | 31 | Attributes 32 | ---------- 33 | dt0 : float 34 | Initial step size for integration. 35 | MSE : Callable 36 | Function that computes the mean squared error. 37 | system : :class:`diffrax.ODETerm` 38 | ODE term of the drift function. 39 | solver : :class:`diffrax.AbstractSolver` 40 | Solver used for integration. 41 | stepsize_controller : :class:`diffrax.AbstractStepSizeController` 42 | Controller for the stepsize during integration. 43 | max_steps : int 44 | The maximum number of steps that can be used in integration. 45 | 46 | Methods 47 | ------- 48 | __call__(candidate, data, tree_evaluator) 49 | Evaluates the candidate on a task. 50 | evaluate_time_series(candidate, x0, ts, ys, tree_evaluator) 51 | Integrate the candidate as a differential equation and compute the fitness given the predictions. 52 | drift(t, x, args) 53 | Drift function for the ODE system. 54 | """ 55 | 56 | def __init__(self, solver: diffrax.AbstractSolver = diffrax.Euler(), dt0: float = 0.01, max_steps: int = 16**4, stepsize_controller: diffrax.AbstractStepSizeController = diffrax.ConstantStepSize()) -> None: 57 | self.dt0 = dt0 58 | self.MSE = lambda pred_ys, true_ys: jnp.mean(jnp.sum(jnp.abs(pred_ys - true_ys), axis=-1))/jnp.mean(jnp.abs(true_ys)) 59 | self.system = diffrax.ODETerm(self.drift) 60 | self.solver = solver 61 | self.stepsize_controller = stepsize_controller 62 | self.max_steps = max_steps 63 | 64 | def __call__(self, candidate: Array, data: Tuple, tree_evaluator: Callable) -> float: 65 | """ 66 | Evaluates the candidate on a task. 
67 | 68 | Parameters 69 | ---------- 70 | candidate : :class:`jax.Array` 71 | The candidate solution to be evaluated. 72 | data : :class:`tuple` 73 | The data required to evaluate the candidate. 74 | tree_evaluator : :class:`Callable` 75 | Function for evaluating trees. 76 | 77 | Returns 78 | ------- 79 | float 80 | Fitness of the candidate. 81 | """ 82 | x0, ts, ys = data 83 | fitness = jax.vmap(self.evaluate_time_series, in_axes=[None, 0, None, 0, None])(candidate, x0, ts, ys, tree_evaluator) 84 | return jnp.mean(fitness) 85 | 86 | def evaluate_time_series(self, candidate: Array, x0: Array, ts: Array, ys: Array, tree_evaluator: Callable) -> float: 87 | """ 88 | Integrate the candidate as a differential equation and compute the fitness given the predictions. 89 | 90 | Parameters 91 | ---------- 92 | candidate : :class:`jax.Array` 93 | Candidate that is evaluated. 94 | x0 : :class:`jax.Array` 95 | Initial conditions of the environment. 96 | ts : :class:`jax.Array` 97 | Timepoints of which the system has to be solved. 98 | ys : :class:`jax.Array` 99 | Ground truth data used to compute the fitness. 100 | tree_evaluator : :class:`Callable` 101 | Function for evaluating trees. 102 | 103 | Returns 104 | ------- 105 | float 106 | Fitness of the candidate. 107 | """ 108 | saveat = diffrax.SaveAt(ts=ts) 109 | sol = diffrax.diffeqsolve( 110 | self.system, self.solver, ts[0], ts[-1], self.dt0, x0, args=(candidate, tree_evaluator), saveat=saveat, max_steps=self.max_steps, stepsize_controller=self.stepsize_controller, 111 | adjoint=diffrax.DirectAdjoint(), throw=False 112 | ) 113 | pred_ys = sol.ys 114 | fitness = self.MSE(pred_ys, ys) 115 | return fitness 116 | 117 | def drift(self, t: float, x: Array, args: Tuple) -> Array: 118 | """ 119 | Drift function for the ODE system. 120 | 121 | Parameters 122 | ---------- 123 | t : float 124 | Current time. 125 | x : :class:`jax.Array` 126 | Current state. 
127 | args : :class:`tuple` 128 | Additional arguments, including the candidate and tree evaluator. 129 | 130 | Returns 131 | ------- 132 | :class:`jax.Array` 133 | Derivative of the state. 134 | """ 135 | candidate, tree_evaluator = args 136 | dx = tree_evaluator(candidate, x) 137 | return dx -------------------------------------------------------------------------------- /kozax/fitness_functions/SR_fitness_function.py: -------------------------------------------------------------------------------- 1 | """ 2 | kozax: Genetic programming framework in JAX 3 | 4 | Copyright (c) 2024 sdevries0 5 | 6 | This work is licensed under the Creative Commons Attribution-NonCommercial-NoDerivs 4.0 International License. 7 | """ 8 | 9 | import jax 10 | from typing import Tuple, Callable 11 | from jaxtyping import Array 12 | import jax.numpy as jnp 13 | from kozax.fitness_functions.base_fitness_function import BaseFitnessFunction 14 | 15 | class SymbolicRegressionFitnessFunction(BaseFitnessFunction): 16 | """ 17 | Evaluator for candidates on symbolic regression tasks with x, y data. 18 | 19 | Methods 20 | ------- 21 | __call__(candidate, data, tree_evaluator) 22 | Evaluates the candidate on a task. 23 | """ 24 | 25 | def __call__(self, candidate: Array, data: Tuple[Array, Array], tree_evaluator: Callable) -> float: 26 | """ 27 | Evaluates the candidate on a task. 28 | 29 | Parameters 30 | ---------- 31 | candidate : :class:`jax.Array` 32 | The candidate solution to be evaluated. 33 | data : :class:`tuple` of :class:`jax.Array` 34 | The data required to evaluate the candidate. Tuple of (x, y) where x is the input data and y is the true output data. 35 | tree_evaluator : :class:`Callable` 36 | Function for evaluating trees. 37 | 38 | Returns 39 | ------- 40 | float 41 | Fitness of the candidate. 
42 | """ 43 | x, y = data 44 | pred = jax.vmap(tree_evaluator, in_axes=[None, 0])(candidate, x) 45 | return jnp.mean(jnp.square(pred - y)) -------------------------------------------------------------------------------- /kozax/fitness_functions/base_fitness_function.py: -------------------------------------------------------------------------------- 1 | """ 2 | kozax: Genetic programming framework in JAX 3 | 4 | Copyright (c) 2024 sdevries0 5 | 6 | This work is licensed under the Creative Commons Attribution-NonCommercial-NoDerivs 4.0 International License. 7 | """ 8 | 9 | from abc import ABC, abstractmethod 10 | from typing import Tuple, Callable 11 | from jaxtyping import Array 12 | 13 | class BaseFitnessFunction(ABC): 14 | """ 15 | Abstract base class for evaluating candidates in genetic programming. 16 | 17 | Methods 18 | ------- 19 | __call__(candidate, data, tree_evaluator) 20 | Evaluates the candidate on a task. 21 | """ 22 | 23 | @abstractmethod 24 | def __call__(self, candidate: Array, data: Tuple, tree_evaluator: Callable) -> float: 25 | """ 26 | Evaluates the candidate on a task. 27 | 28 | Parameters 29 | ---------- 30 | candidate : :class:`jax.Array` 31 | The candidate solution to be evaluated. 32 | data : :class:`tuple` 33 | The data required to evaluate the candidate. 34 | tree_evaluator : :class:`Callable` 35 | Function for evaluating trees. 36 | 37 | Returns 38 | ------- 39 | float 40 | Fitness of the candidate. 41 | """ 42 | raise NotImplementedError -------------------------------------------------------------------------------- /kozax/genetic_operators/crossover.py: -------------------------------------------------------------------------------- 1 | """ 2 | kozax: Genetic programming framework in JAX 3 | 4 | Copyright (c) 2024 sdevries0 5 | 6 | This work is licensed under the Creative Commons Attribution-NonCommercial-NoDerivs 4.0 International License. 
7 | """ 8 | 9 | import jax 10 | import jax.numpy as jnp 11 | import jax.random as jr 12 | from typing import Tuple 13 | from jaxtyping import Array 14 | from jax.random import PRNGKey 15 | 16 | def sample_indices(carry: Tuple[PRNGKey, Array, float]) -> Tuple[PRNGKey, Array, float]: 17 | """ 18 | Samples indices of the trees in a candidate that will be mutated. 19 | 20 | Parameters 21 | ---------- 22 | carry : tuple of (PRNGKey, Array, float) 23 | Tuple containing the random key, indices of trees, and reproduction probability. 24 | 25 | Returns 26 | ------- 27 | tuple of (PRNGKey, Array, float) 28 | Updated tuple with the random key, indices of trees to be mutated, and reproduction probability. 29 | """ 30 | key, indices, reproduction_probability = carry 31 | indices = jr.bernoulli(key, p=reproduction_probability, shape=indices.shape) * 1.0 32 | return (jr.split(key, 1)[0], indices, reproduction_probability) 33 | 34 | def find_end_idx(carry: Tuple[Array, int, int]) -> Tuple[Array, int, int]: 35 | """ 36 | Finds the index of the last node in a subtree. 37 | 38 | Parameters 39 | ---------- 40 | carry : tuple of (Array, int, int) 41 | Tuple containing the tree, the number of open slots, and the current node index. 42 | 43 | Returns 44 | ------- 45 | tuple of (Array, int, int) 46 | Updated tuple with the tree, open slots, and current node index. 47 | """ 48 | tree, open_slots, counter = carry 49 | _, idx1, idx2, _ = tree[counter] 50 | open_slots -= 1 # Reduce open slot for current node 51 | open_slots = jax.lax.select(idx1 < 0, open_slots, open_slots + 1) # Increase the open slots for a child 52 | open_slots = jax.lax.select(idx2 < 0, open_slots, open_slots + 1) # Increase the open slots for a child 53 | counter -= 1 54 | return (tree, open_slots, counter) 55 | 56 | def check_invalid_cx_nodes(carry: Tuple[Array, Array, Array, int, int, Array, Array]) -> bool: 57 | """ 58 | Checks if the sampled subtrees are different and if the trees after crossover are valid. 
59 | 60 | Parameters 61 | ---------- 62 | carry : tuple of (Array, Array, Array, int, int, Array, Array) 63 | Tuple containing the trees, node indices, and other parameters. 64 | 65 | Returns 66 | ------- 67 | bool 68 | If the sampled nodes are valid nodes for crossover. 69 | """ 70 | tree1, tree2, _, node_idx1, node_idx2, _, _ = carry 71 | 72 | _, _, end_idx1 = jax.lax.while_loop(lambda carry: carry[1] > 0, find_end_idx, (tree1, 1, node_idx1)) 73 | _, _, end_idx2 = jax.lax.while_loop(lambda carry: carry[1] > 0, find_end_idx, (tree2, 1, node_idx2)) 74 | 75 | subtree_size1 = node_idx1 - end_idx1 76 | subtree_size2 = node_idx2 - end_idx2 77 | 78 | empty_nodes1 = jnp.sum(tree1[:, 0] == 0) 79 | empty_nodes2 = jnp.sum(tree2[:, 0] == 0) 80 | 81 | # Check if the subtrees can be inserted 82 | return (empty_nodes1 < subtree_size2 - subtree_size1) | (empty_nodes2 < subtree_size1 - subtree_size2) 83 | 84 | def sample_cx_nodes(carry: Tuple[Array, Array, Array, int, int, Array, Array]) -> Tuple[Array, Array, Array, int, int, Array, Array]: 85 | """ 86 | Samples nodes in a pair of trees for crossover. 87 | 88 | Parameters 89 | ---------- 90 | carry : tuple of (Array, Array, Array, int, int, Array, Array) 91 | Tuple containing the trees, node indices, and other parameters. 92 | 93 | Returns 94 | ------- 95 | tuple of (Array, Array, Array, int, int, Array, Array) 96 | Updated tuple with the sampled nodes. 
97 | """ 98 | tree1, tree2, keys, _, _, node_ids, operator_indices = carry 99 | key1, key2 = keys 100 | 101 | # Sample nodes from the non-empty nodes, with higher probability for operator nodes 102 | cx_prob1 = jnp.isin(tree1[:, 0], operator_indices) 103 | cx_prob1 = jnp.where(tree1[:, 0] == 0, cx_prob1, cx_prob1 + 1) 104 | node_idx1 = jr.choice(key1, node_ids, p=cx_prob1 * 1.0) 105 | 106 | cx_prob2 = jnp.isin(tree2[:, 0], operator_indices) 107 | cx_prob2 = jnp.where(tree2[:, 0] == 0, cx_prob2, cx_prob2 + 1) 108 | node_idx2 = jr.choice(key2, node_ids, p=cx_prob2 * 1.0) 109 | 110 | return (tree1, tree2, jr.split(key1), node_idx1, node_idx2, node_ids, operator_indices) 111 | 112 | def tree_crossover(tree1: Array, 113 | tree2: Array, 114 | keys: Array, 115 | node_ids: Array, 116 | operator_indices: Array) -> Tuple[Array, Array]: 117 | """ 118 | Applies crossover to a pair of trees to produce two new trees. 119 | 120 | Parameters 121 | ---------- 122 | tree1 : Array 123 | First tree. 124 | tree2 : Array 125 | Second tree. 126 | keys : Array 127 | Random keys. 128 | node_ids : Array 129 | Indices of all the nodes in the trees. 130 | operator_indices : Array 131 | The indices that belong to operator nodes. 132 | 133 | Returns 134 | ------- 135 | tuple of (Array, Array) 136 | Pair of new trees. 
137 | """ 138 | # Define indices of the nodes 139 | tree_indices = jnp.tile(node_ids[:, None], reps=(1, 4)) 140 | key1, key2 = keys 141 | 142 | # Define last node in tree 143 | last_node_idx1 = jnp.sum(tree1[:, 0] == 0) 144 | last_node_idx2 = jnp.sum(tree2[:, 0] == 0) 145 | 146 | # Randomly select nodes for crossover 147 | _, _, _, node_idx1, node_idx2, _, _ = sample_cx_nodes((tree1, tree2, jr.split(key1), 0, 0, node_ids, operator_indices)) 148 | 149 | # Reselect until valid crossover nodes have been found 150 | _, _, _, node_idx1, node_idx2, _, _ = jax.lax.while_loop(check_invalid_cx_nodes, sample_cx_nodes, (tree1, tree2, jr.split(key2), node_idx1, node_idx2, node_ids, operator_indices)) 151 | 152 | # Retrieve subtrees of selected nodes 153 | _, _, end_idx1 = jax.lax.while_loop(lambda carry: carry[1] > 0, find_end_idx, (tree1, 1, node_idx1)) 154 | _, _, end_idx2 = jax.lax.while_loop(lambda carry: carry[1] > 0, find_end_idx, (tree2, 1, node_idx2)) 155 | 156 | # Initialize children 157 | child1 = jnp.tile(jnp.array([0.0, -1.0, -1.0, 0.0]), (len(node_ids), 1)) 158 | child2 = jnp.tile(jnp.array([0.0, -1.0, -1.0, 0.0]), (len(node_ids), 1)) 159 | 160 | # Compute subtree sizes 161 | subtree_size1 = node_idx1 - end_idx1 162 | subtree_size2 = node_idx2 - end_idx2 163 | 164 | # Insert nodes before subtree in children 165 | child1 = jnp.where(tree_indices >= node_idx1 + 1, tree1, child1) 166 | child2 = jnp.where(tree_indices >= node_idx2 + 1, tree2, child2) 167 | 168 | # Align nodes after subtree with first open spot after new subtree in children 169 | rolled_tree1 = jnp.roll(tree1, subtree_size1 - subtree_size2, axis=0) 170 | rolled_tree2 = jnp.roll(tree2, subtree_size2 - subtree_size1, axis=0) 171 | 172 | # Insert nodes after subtree in children 173 | child1 = jnp.where((tree_indices >= node_idx1 - subtree_size2 - (end_idx1 - last_node_idx1)) & (tree_indices < node_idx1 + 1 - subtree_size2), rolled_tree1, child1) 174 | child2 = jnp.where((tree_indices >= node_idx2 - 
subtree_size1 - (end_idx2 - last_node_idx2)) & (tree_indices < node_idx2 + 1 - subtree_size1), rolled_tree2, child2) 175 | 176 | # Update index references to moved nodes in staying nodes 177 | child1 = child1.at[:, 1:3].set(jnp.where((child1[:, 1:3] < (node_idx1 - subtree_size1 + 1)) & (child1[:, 1:3] > -1), child1[:, 1:3] + (subtree_size1 - subtree_size2), child1[:, 1:3])) 178 | child2 = child2.at[:, 1:3].set(jnp.where((child2[:, 1:3] < (node_idx2 - subtree_size2 + 1)) & (child2[:, 1:3] > -1), child2[:, 1:3] + (subtree_size2 - subtree_size1), child2[:, 1:3])) 179 | 180 | # Align subtree with the selected node in children 181 | rolled_subtree1 = jnp.roll(tree1, node_idx2 - node_idx1, axis=0) 182 | rolled_subtree2 = jnp.roll(tree2, node_idx1 - node_idx2, axis=0) 183 | 184 | # Update index references in subtree 185 | rolled_subtree1 = rolled_subtree1.at[:, 1:3].set(jnp.where(rolled_subtree1[:, 1:3] > -1, rolled_subtree1[:, 1:3] + (node_idx2 - node_idx1), -1)) 186 | rolled_subtree2 = rolled_subtree2.at[:, 1:3].set(jnp.where(rolled_subtree2[:, 1:3] > -1, rolled_subtree2[:, 1:3] + (node_idx1 - node_idx2), -1)) 187 | 188 | # Insert subtree in selected node in children 189 | child1 = jnp.where((tree_indices >= node_idx1 + 1 - subtree_size2) & (tree_indices < node_idx1 + 1), rolled_subtree2, child1) 190 | child2 = jnp.where((tree_indices >= node_idx2 + 1 - subtree_size1) & (tree_indices < node_idx2 + 1), rolled_subtree1, child2) 191 | 192 | return child1, child2 193 | 194 | def full_crossover(tree1: Array, 195 | tree2: Array, 196 | keys: Array, 197 | node_ids: Array, 198 | operator_indices: Array) -> Tuple[Array, Array]: 199 | """ 200 | Swaps the entire trees between two candidates. 201 | 202 | Parameters 203 | ---------- 204 | tree1 : Array 205 | First tree. 206 | tree2 : Array 207 | Second tree. 208 | keys : Array 209 | Random keys. 210 | node_ids : Array 211 | Indices of all the nodes in the trees. 
212 | operator_indices : Array 213 | The indices that belong to operator nodes. 214 | 215 | Returns 216 | ------- 217 | tuple of (Array, Array) 218 | Swapped trees. 219 | """ 220 | return tree2, tree1 221 | 222 | def crossover(tree1: Array, 223 | tree2: Array, 224 | keys: Array, 225 | node_ids: Array, 226 | operator_indices: Array, 227 | crossover_types: int) -> Tuple[Array, Array]: 228 | """ 229 | Applies crossover to a pair of trees based on the crossover type. 230 | 231 | Parameters 232 | ---------- 233 | tree1 : Array 234 | First tree. 235 | tree2 : Array 236 | Second tree. 237 | keys : Array 238 | Random keys. 239 | node_ids : Array 240 | Indices of all the nodes in the trees. 241 | operator_indices : Array 242 | The indices that belong to operator nodes. 243 | crossover_types : int 244 | Type of crossover to apply. 245 | 246 | Returns 247 | ------- 248 | tuple of (Array, Array) 249 | Pair of new trees. 250 | """ 251 | return jax.lax.cond(crossover_types, tree_crossover, full_crossover, tree1, tree2, keys, node_ids, operator_indices) 252 | 253 | def check_different_tree(parent1: Array, parent2: Array, child1: Array, child2: Array) -> bool: 254 | """ 255 | Checks if the offspring are different from the parents. 256 | 257 | Parameters 258 | ---------- 259 | parent1 : Array 260 | First parent tree. 261 | parent2 : Array 262 | Second parent tree. 263 | child1 : Array 264 | First child tree. 265 | child2 : Array 266 | Second child tree. 267 | 268 | Returns 269 | ------- 270 | bool 271 | True if the offspring are different from the parents, False otherwise. 
272 | """ 273 | size1 = jnp.sum(child1[:, 0] != 0) 274 | size2 = jnp.sum(child2[:, 0] != 0) 275 | 276 | check1 = (jnp.all(parent1 == child1) | jnp.all(parent2 == child1)) 277 | check2 = (jnp.all(parent1 == child2) | jnp.all(parent2 == child2)) 278 | 279 | return ((check1 | check2) & ((size1 > 1) & (size2 > 1))) | (size1 == 0) 280 | 281 | def check_different_trees(carry: Tuple[Array, Array, Array, Array, Array, float, Array, Array]) -> bool: 282 | """ 283 | Checks if the offspring are different from the parents for all trees in the population. 284 | 285 | Parameters 286 | ---------- 287 | carry : tuple of (Array, Array, Array, Array, Array, float, Array, Array) 288 | Tuple containing the parent trees, child trees, and other parameters. 289 | 290 | Returns 291 | ------- 292 | bool 293 | True if the offspring are different from the parents for all trees, False otherwise. 294 | """ 295 | parent1, parent2, child1, child2, _, _, _, _ = carry 296 | return jnp.all(jax.vmap(check_different_tree)(parent1, parent2, child1, child2)) 297 | 298 | def safe_crossover(carry: Tuple[Array, Array, Array, Array, Array, float, Array, Array]) -> Tuple[Array, Array, Array, Array, Array, float, Array, Array]: 299 | """ 300 | Ensures that the crossover produces valid offspring. 301 | 302 | Parameters 303 | ---------- 304 | carry : tuple of (Array, Array, Array, Array, Array, float, Array, Array) 305 | Tuple containing the parent trees, child trees, and other parameters. 306 | 307 | Returns 308 | ------- 309 | tuple of (Array, Array, Array, Array, Array, float, Array, Array) 310 | Updated tuple with the parent trees, child trees, and other parameters. 
311 | """ 312 | parent1, parent2, _, _, keys, reproduction_probability, node_ids, operator_indices = carry 313 | index_key, type_key, new_key = jr.split(keys[0, 0], 3) 314 | _, cx_indices, _ = jax.lax.while_loop(lambda carry: jnp.sum(carry[1]) == 0, sample_indices, (index_key, jnp.zeros(parent1.shape[0]), reproduction_probability)) 315 | crossover_types = jr.bernoulli(type_key, p=0.9, shape=(parent1.shape[0],)) 316 | offspring1, offspring2 = jax.vmap(crossover, in_axes=[0, 0, 0, None, None, 0])(parent1, parent2, keys, node_ids, operator_indices, crossover_types) 317 | child1 = jnp.where(cx_indices[:, None, None] * jnp.ones_like(parent1), offspring1, parent1) 318 | child2 = jnp.where(cx_indices[:, None, None] * jnp.ones_like(parent2), offspring2, parent2) 319 | 320 | keys = jr.split(new_key, keys.shape[:-1]) 321 | 322 | return parent1, parent2, child1, child2, keys, reproduction_probability, node_ids, operator_indices 323 | 324 | def crossover_trees(parent1: Array, 325 | parent2: Array, 326 | keys: Array, 327 | reproduction_probability: float, 328 | max_nodes: int, 329 | operator_indices: Array) -> Tuple[Array, Array]: 330 | """ 331 | Applies crossover to the trees in a pair of candidates. 332 | 333 | Parameters 334 | ---------- 335 | parent1 : Array 336 | First parent candidate. 337 | parent2 : Array 338 | Second parent candidate. 339 | keys : Array 340 | Random keys. 341 | reproduction_probability : float 342 | Probability of a tree to be mutated. 343 | max_nodes : int 344 | Max number of nodes in a tree. 345 | operator_indices : Array 346 | The indices that belong to operator nodes. 347 | 348 | Returns 349 | ------- 350 | tuple of (Array, Array) 351 | Pair of candidates after crossover. 
352 | """ 353 | _, _, child1, child2, _, _, _, _ = jax.lax.while_loop(check_different_trees, safe_crossover, ( 354 | parent1, parent2, jnp.zeros_like(parent1), jnp.zeros_like(parent2), keys, reproduction_probability, jnp.arange(max_nodes), operator_indices)) 355 | 356 | return child1, child2 -------------------------------------------------------------------------------- /kozax/genetic_operators/initialization.py: -------------------------------------------------------------------------------- 1 | """ 2 | kozax: Genetic programming framework in JAX 3 | 4 | Copyright (c) 2024 sdevries0 5 | 6 | This work is licensed under the Creative Commons Attribution-NonCommercial-NoDerivs 4.0 International License. 7 | """ 8 | 9 | import jax 10 | import jax.numpy as jnp 11 | import jax.random as jr 12 | from jax.random import PRNGKey 13 | from functools import partial 14 | from jax import Array 15 | from typing import Tuple, Callable 16 | 17 | def sample_node(i: int, 18 | carry: Tuple[PRNGKey, Array, int, int, int, Array, Tuple]) -> Tuple[PRNGKey, Array, int, int, int, Array, Tuple]: 19 | """Samples nodes sequentially in breadth-first order, storing them depth-first. 20 | 21 | Parameters 22 | ---------- 23 | i : int 24 | Index of the node. 25 | carry : Tuple[PRNGKey, Array, int, int, int, Array, Tuple] 26 | Tuple containing the random key, tree, open slots, max init depth, max nodes, variable array, and other arguments. 27 | 28 | Returns 29 | ------- 30 | Tuple[PRNGKey, Array, int, int, int, Array, Tuple] 31 | Updated tuple with the random key, tree, open slots, max init depth, max nodes, variable array, and other arguments. 
32 | """ 33 | key, tree, open_slots, max_init_depth, max_nodes, variable_array, args = carry 34 | variable_indices, operator_indices, operator_probabilities, slots, coefficient_sd, map_b_to_d = args 35 | coefficient_key, leaf_key, variable_key, node_key, operator_key = jr.split(key, 5) 36 | _i = map_b_to_d[i].astype(int) # Get depth first index 37 | 38 | depth = (jnp.log(i + 1 + 1e-10) / jnp.log(2)).astype(int) # Compute depth of node 39 | coefficient = jr.normal(coefficient_key) * coefficient_sd 40 | leaf = jax.lax.select(jr.uniform(leaf_key) < 0.5, 1, jr.choice(variable_key, variable_indices, shape=(), p=variable_array)) # Sample coefficient or variable 41 | 42 | index = jax.lax.select((open_slots < max_nodes - i - 1) & (depth + 1 < max_init_depth), # Check if max depth has been reached, or if the number of open slots reached the max number of nodes 43 | jax.lax.select(jr.uniform(node_key) < (0.7 ** depth), # At higher depth, a leaf node is more probable 44 | jr.choice(operator_key, a=operator_indices, shape=(), p=operator_probabilities), 45 | leaf), 46 | leaf) 47 | 48 | index = jax.lax.select(open_slots == 0, 0, index) # If there are no open slots, the node should be empty 49 | 50 | # If parent node is a leaf, the node should be empty 51 | index = jax.lax.select(i > 0, jax.lax.select((slots[jnp.maximum(tree[map_b_to_d[(i + (i % 2) - 2) // 2].astype(int), 0], 0).astype(int)] + i % 2) > 1, index, 0), index) 52 | 53 | # Set index references 54 | tree = jax.lax.select(slots[index] > 0, tree.at[_i, 1].set(map_b_to_d[2 * i + 1]), tree.at[_i, 1].set(-1)) 55 | tree = jax.lax.select(slots[index] > 1, tree.at[_i, 2].set(map_b_to_d[2 * i + 2]), tree.at[_i, 2].set(-1)) 56 | 57 | tree = jax.lax.select(index == 1, tree.at[_i, 3].set(coefficient), tree) # Set coefficient value 58 | tree = tree.at[_i, 0].set(index) 59 | 60 | open_slots = jax.lax.select(index == 0, open_slots, jnp.maximum(0, open_slots + slots[index] - 1)) # Update the number of open slots 61 | 62 | return 
(jr.fold_in(key, i), tree, open_slots, max_init_depth, max_nodes, variable_array, args) 63 | 64 | def prune_row(i: int, 65 | carry: Tuple[Array, int, int], 66 | old_tree: Array) -> Tuple[Array, int, int]: 67 | """Sequentially adds nodes to the new tree if it is not empty. 68 | 69 | Parameters 70 | ---------- 71 | i : int 72 | Index of the node. 73 | carry : Tuple[Array, int, int] 74 | Tuple containing the tree, counter, and tree size. 75 | old_tree : Array 76 | Tree with empty nodes that have to be pruned. 77 | 78 | Returns 79 | ------- 80 | Tuple[Array, int, int] 81 | Updated tuple with the tree, counter, and tree size. 82 | """ 83 | tree, counter, tree_size = carry 84 | _i = tree_size - i - 1 85 | row = old_tree[_i] 86 | 87 | # If node is not empty, add node and update index references 88 | tree = jax.lax.select(row[0] != 0, tree.at[counter].set(row), tree.at[:, 1:3].set(jnp.where(tree[:, 1:3] > _i, tree[:, 1:3] - 1, tree[:, 1:3]))) 89 | counter = jax.lax.select(row[0] != 0, counter - 1, counter) 90 | 91 | return (tree, counter, tree_size) 92 | 93 | def prune_tree(tree: Array, 94 | tree_size: int, 95 | max_nodes: int) -> Array: 96 | """Removes empty nodes from a tree. The new tree is filled with empty nodes at the end to match the max number of nodes. 97 | 98 | Parameters 99 | ---------- 100 | tree : Array 101 | Tree to be pruned. 102 | tree_size : int 103 | Max size of the old tree. 104 | max_nodes : int 105 | Max number of nodes in the new tree. 106 | 107 | Returns 108 | ------- 109 | Array 110 | Tree with empty nodes pruned. 
111 | """ 112 | tree, counter, _ = jax.lax.fori_loop(0, tree_size, partial(prune_row, old_tree=tree), (jnp.tile(jnp.array([0.0, -1.0, -1.0, 0.0]), (max_nodes, 1)), max_nodes - 1, tree_size)) 113 | tree = tree.at[:, 1:3].set(jnp.where(tree[:, 1:3] > -1, tree[:, 1:3] + counter + 1, tree[:, 1:3])) # Update index references after pruning 114 | return tree 115 | 116 | def sample_tree(key: PRNGKey, 117 | depth: int, 118 | variable_array: Array, 119 | tree_size: int, 120 | max_nodes: int, 121 | args: Tuple) -> Array: 122 | """Initializes a tree. 123 | 124 | Parameters 125 | ---------- 126 | key : PRNGKey 127 | Random key. 128 | depth : int 129 | Max depth in a tree at initialization. 130 | variable_array : Array 131 | The valid variables for this tree. 132 | tree_size : int 133 | Max size of the tree. 134 | max_nodes : int 135 | Max number of nodes in a tree. 136 | args : Tuple 137 | Miscellaneous parameters required for initialization. 138 | 139 | Returns 140 | ------- 141 | Array 142 | Initialized tree. 143 | """ 144 | # First sample tree at full size given depth 145 | tree = jax.lax.fori_loop(0, tree_size, sample_node, (key, jnp.zeros((tree_size, 4)), 1, depth, max_nodes, variable_array, args))[1] # Sample nodes in a tree sequentially 146 | 147 | # Prune empty rows in tree 148 | pruned_tree = prune_tree(tree, tree_size, max_nodes) 149 | return pruned_tree 150 | 151 | def sample_population(key: PRNGKey, 152 | population_size: int, 153 | num_trees: int, 154 | max_init_depth: int, 155 | variable_array: Array, 156 | sample_function: Callable) -> Array: 157 | """Initializes a population of candidates. 158 | 159 | Parameters 160 | ---------- 161 | key : PRNGKey 162 | Random key. 163 | population_size : int 164 | Number of candidates that have to be sampled. 165 | num_trees : int 166 | Number of trees in a candidate. 167 | max_init_depth : int 168 | Max depth in a tree at initialization. 169 | variable_array : Array 170 | The valid variables for each tree. 
171 | sample_function : Callable 172 | Function to sample a tree. 173 | 174 | Returns 175 | ------- 176 | Array 177 | Population of candidates. 178 | """ 179 | sample_candidate = lambda keys: jax.vmap(sample_function, in_axes=[0, None, 0])(keys, max_init_depth, variable_array) 180 | return jax.vmap(sample_candidate)(jr.split(key, (population_size, num_trees))) -------------------------------------------------------------------------------- /kozax/genetic_operators/reproduction.py: -------------------------------------------------------------------------------- 1 | """ 2 | kozax: Genetic programming framework in JAX 3 | 4 | Copyright (c) 2024 sdevries0 5 | 6 | This work is licensed under the Creative Commons Attribution-NonCommercial-NoDerivs 4.0 International License. 7 | """ 8 | 9 | import jax 10 | from jax import Array 11 | import jax.numpy as jnp 12 | import jax.random as jr 13 | from jax.random import PRNGKey 14 | from typing import Callable, List, Tuple 15 | 16 | def evolve_trees(parent1: Array, 17 | parent2: Array, 18 | keys: Array, 19 | type: int, 20 | reproduction_probability: float, 21 | reproduction_functions: List[Callable]) -> Tuple[Array, Array]: 22 | """Applies reproduction function to pair of candidates. 23 | 24 | Parameters 25 | ---------- 26 | parent1 : Array 27 | First parent candidate. 28 | parent2 : Array 29 | Second parent candidate. 30 | keys : Array 31 | Random keys. 32 | type : int 33 | Type of reproduction function to apply. 34 | reproduction_probability : float 35 | Probability of a tree to be mutated. 36 | reproduction_functions : List[Callable] 37 | Functions that can be applied for reproduction. 38 | 39 | Returns 40 | ------- 41 | Tuple[Array, Array] 42 | Pair of reproduced candidates. 
43 | """ 44 | child0, child1 = jax.lax.switch(type, reproduction_functions, parent1, parent2, keys, reproduction_probability) 45 | 46 | return child0, child1 47 | 48 | def tournament_selection(population: Array, 49 | fitness: Array, 50 | key: PRNGKey, 51 | tournament_probabilities: Array, 52 | tournament_size: int, 53 | population_indices: Array) -> Array: 54 | """Selects a candidate for reproduction from a tournament. 55 | 56 | Parameters 57 | ---------- 58 | population : Array 59 | Population of candidates. 60 | fitness : Array 61 | Fitness of candidates. 62 | key : PRNGKey 63 | Random key. 64 | tournament_probabilities : Array 65 | Probability of each of the ranks in the tournament to be selected for reproduction. 66 | tournament_size : int 67 | Size of the tournament. 68 | population_indices : Array 69 | Indices of the population. 70 | 71 | Returns 72 | ------- 73 | Array 74 | Candidate that won the tournament. 75 | """ 76 | tournament_key, winner_key = jr.split(key) 77 | indices = jr.choice(tournament_key, population_indices, shape=(tournament_size,)) 78 | 79 | index = jr.choice(winner_key, indices[jnp.argsort(fitness[indices])], p=tournament_probabilities) 80 | return population[index] 81 | 82 | def evolve_population(population: Array, 83 | fitness: Array, 84 | key: PRNGKey, 85 | reproduction_type_probabilities: Array, 86 | reproduction_probability: float, 87 | tournament_probabilities: Array, 88 | population_indices: Array, 89 | population_size: int, 90 | tournament_size: int, 91 | num_trees: int, 92 | elite_size: int, 93 | reproduction_functions: List[Callable]) -> Array: 94 | """Reproduces pairs of candidates to obtain a new population. 95 | 96 | Parameters 97 | ---------- 98 | population : Array 99 | Population of candidates. 100 | fitness : Array 101 | Fitness of candidates. 102 | key : PRNGKey 103 | Random key. 104 | reproduction_type_probabilities : Array 105 | Probability of each reproduction function to be applied. 
106 | reproduction_probability : float 107 | Probability of a tree to be mutated. 108 | tournament_probabilities : Array 109 | Probability of each of the ranks in the tournament to be selected for reproduction. 110 | population_indices : Array 111 | Indices of the population. 112 | population_size : int 113 | Size of the population. 114 | tournament_size : int 115 | Size of the tournament. 116 | num_trees : int 117 | Number of trees in a candidate. 118 | elite_size : int 119 | Number of candidates that remain in the new population without reproduction. 120 | reproduction_functions : List[Callable] 121 | Functions that can be applied for reproduction. 122 | 123 | Returns 124 | ------- 125 | Array 126 | Evolved population. 127 | """ 128 | left_key, right_key, repro_key, evo_key = jr.split(key, 4) 129 | elite = population[jnp.argsort(fitness)[:elite_size]] 130 | 131 | # Sample parents for reproduction 132 | left_parents = jax.vmap(tournament_selection, in_axes=[None, None, 0, None, None, None])(population, 133 | fitness, 134 | jr.split(left_key, (population_size - elite_size)//2), 135 | tournament_probabilities, 136 | tournament_size, 137 | population_indices) 138 | 139 | right_parents = jax.vmap(tournament_selection, in_axes=[None, None, 0, None, None, None])(population, 140 | fitness, 141 | jr.split(right_key, (population_size - elite_size)//2), 142 | tournament_probabilities, 143 | tournament_size, 144 | population_indices) 145 | # Sample which reproduction function to apply to the parents 146 | reproduction_type = jr.choice(repro_key, jnp.arange(3), shape=((population_size - elite_size)//2,), p=reproduction_type_probabilities) 147 | 148 | left_children, right_children = jax.vmap(evolve_trees, in_axes=[0, 0, 0, 0, None, None])(left_parents, 149 | right_parents, 150 | jr.split(evo_key, ((population_size - elite_size)//2, num_trees, 2)), 151 | reproduction_type, 152 | reproduction_probability, 153 | reproduction_functions) 154 | 155 | evolved_population = 
jnp.concatenate([elite, left_children, right_children], axis=0) 156 | return evolved_population 157 | 158 | def migrate_population(receiver: Array, 159 | sender: Array, 160 | receiver_fitness: Array, 161 | sender_fitness: Array, 162 | migration_size: int, 163 | population_indices: Array) -> Tuple[Array, Array]: 164 | """Unfit candidates from one population are replaced with fit candidates from another population. 165 | 166 | Parameters 167 | ---------- 168 | receiver : Array 169 | Population that receives new candidates. 170 | sender : Array 171 | Population that sends fit candidates. 172 | receiver_fitness : Array 173 | Fitness of the candidates in the receiving population. 174 | sender_fitness : Array 175 | Fitness of the candidates in the sending population. 176 | migration_size : int 177 | How many candidates are migrated to new population. 178 | population_indices : Array 179 | Indices of the population. 180 | 181 | Returns 182 | ------- 183 | Tuple[Array, Array] 184 | Population after migration and their fitness. 185 | """ 186 | sorted_receiver = receiver[jnp.argsort(receiver_fitness, descending=True)] 187 | sorted_sender = sender[jnp.argsort(sender_fitness, descending=False)] 188 | migrated_population = jnp.where((population_indices < migration_size)[:,None,None,None], sorted_sender, sorted_receiver) 189 | migrated_fitness = jnp.where(population_indices < migration_size, jnp.sort(sender_fitness, descending=False), jnp.sort(receiver_fitness, descending=True)) 190 | return migrated_population, migrated_fitness 191 | 192 | def evolve_populations(jit_evolve_population: Callable, 193 | populations: Array, 194 | fitness: Array, 195 | key: PRNGKey, 196 | current_generation: int, 197 | migration_period: int, 198 | migration_size: int, 199 | reproduction_type_probabilities: Array, 200 | reproduction_probabilities: Array, 201 | tournament_probabilities: Array) -> Array: 202 | """Evolves each population independently. 
203 | 204 | Parameters 205 | ---------- 206 | jit_evolve_population : Callable 207 | Function for evolving trees that is jittable and parallelizable. 208 | populations : Array 209 | Populations of candidates. 210 | fitness : Array 211 | Fitness of candidates. 212 | key : PRNGKey 213 | Random key. 214 | current_generation : int 215 | Current generation number. 216 | migration_period : int 217 | After how many generations migration occurs. 218 | migration_size : int 219 | How many candidates are migrated to new population. 220 | reproduction_type_probabilities : Array 221 | Probability of each reproduction function to be applied. 222 | reproduction_probabilities : Array 223 | Probability of a tree to be mutated. 224 | tournament_probabilities : Array 225 | Probability of each of the ranks in the tournament to be selected for reproduction. 226 | 227 | Returns 228 | ------- 229 | Array 230 | Evolved populations. 231 | """ 232 | num_populations, population_size, num_trees, _, _ = populations.shape 233 | population_indices = jnp.arange(population_size) 234 | 235 | # Migrate candidates between populations. The populations and fitnesses are rolled for circular migration. 
236 | populations, fitness = jax.lax.cond((num_populations > 1) & (((current_generation+1)%migration_period) == 0), 237 | jax.vmap(migrate_population, in_axes=[0, 0, 0, 0, None, None]), 238 | jax.vmap(lambda receiver, sender, receiver_fitness, sender_fitness, migration_size, population_indices: (receiver, receiver_fitness), in_axes=[0, 0, 0, 0, None, None]), 239 | populations, 240 | jnp.roll(populations, 1, axis=0), 241 | fitness, 242 | jnp.roll(fitness, 1, axis=0), 243 | migration_size, 244 | population_indices) 245 | 246 | new_population = jit_evolve_population(populations, 247 | fitness, 248 | jr.split(key, num_populations), 249 | reproduction_type_probabilities, 250 | reproduction_probabilities, 251 | tournament_probabilities, 252 | population_indices) 253 | return new_population --------------------------------------------------------------------------------