├── pyrightconfig.json
├── README.md
├── LICENSE
├── trainer.py
├── train.py
├── .gitignore
├── utils
│   ├── foraging.py
│   ├── exputils.py
│   ├── model.py
│   ├── task.py
│   └── trainer.py
├── task.py
└── lndp.py

--------------------------------------------------------------------------------
/pyrightconfig.json:
--------------------------------------------------------------------------------
1 | {
2 |   // Install LSP-json to get validation and autocompletion in this file.
3 |   "venvPath": "/Users/erpl/anaconda3/envs",
4 |   "venv": "jax",
5 | }
6 | 

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # LNDP
2 | 
3 | Code for the paper "Evolving Self-Assembling Neural Networks: From Spontaneous Activity to Experience-Dependent Learning" ([PDF](https://arxiv.org/abs/2406.09787))
4 | 
5 | # Training
6 | 
7 | To launch training, run the train.py script. The default training config is defined in train.py; you can either modify it directly in the file or override config variables with command-line arguments (e.g. ```python train.py --env_name="Acrobot-v1"``` to train on Acrobot).
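8 | 
9 | Any field of the `Config` class in train.py can be overridden this way; for instance (illustrative values):
10 | 
11 | ```python train.py --env_name="CartPole-v1" --popsize=128 --generations=1000 --log=1```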
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2024 Erwan Plantec
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 

--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
1 | import evosax as ex
2 | import jax.numpy as jnp
3 | from utils.trainer import *
4 | import jax
5 | import equinox as eqx
6 | 
7 | 
8 | def make(config, task, params_like)->EvosaxTrainer:
9 | 
10 |     def metrics_fn(state, data):
11 |         y = {}
12 |         for k, v in data["data"].items():
13 |             y[f"{k}_mean"] = v.mean()
14 |             y[f"{k}_max"] = v.max()
15 |         y["best"] = - state.best_fitness # FitnessShaper(maximize=True) negates fitness internally, so flip the sign back to recover the raw best return
16 |         y["gen_best"] = data["fitness"].max()
17 |         y["gen_mean"] = data["fitness"].mean()
18 |         y["gen_worse"] = data["fitness"].min()
19 |         y["var"] = jnp.var(data["fitness"])
20 |         return y, state.mean, state.gen_counter
21 | 
22 |     params_shaper = ex.ParameterReshaper(params_like)
23 | 
24 |     fitness_shaper = ex.FitnessShaper(maximize=True)
25 | 
26 |     if config.ckpt_file:
27 |         ckpt_file = f"{config.ckpt_file}_{config.env_name}_{config.seed}"
28 |     else:
29 |         ckpt_file = f"./ckpts/{config.env_name}_{config.seed}"
30 |     # save config
31 |     eqx.tree_serialise_leaves(f"{ckpt_file}_config.eqx", config)
32 |     logger = Logger(bool(config.log), metrics_fn=metrics_fn, ckpt_file=ckpt_file)
33 | 
34 |     trainer = EvosaxTrainer(train_steps = config.generations,
35 |                             task = task,
36 |                             params_shaper=params_shaper,
37 |                             strategy=config.strategy,
38 |                             popsize=config.popsize,
39 |                             fitness_shaper=fitness_shaper,
40 |                             eval_reps=config.eval_reps,
41 |                             logger=logger,
42 |                             n_devices=len(jax.devices()))
43 | 
44 |     return trainer
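45 | 
46 | # Note (illustrative): checkpoints written by `Logger` land in f"{ckpt_file}.eqx" as a flat
47 | # parameter vector and can be restored with `EvosaxTrainer.load_ckpt` (see utils/trainer.py);
48 | # the config saved above can be restored with eqx.tree_deserialise_leaves(f"{ckpt_file}_config.eqx", Config()).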
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from utils.exputils import load_config
2 | from lndp import make as model_factory
3 | from task import make as task_factory, MultiTask
4 | from trainer import make as trainer_factory
5 | 
6 | from typing import NamedTuple
7 | import equinox as eqx
8 | import jax.random as jr
9 | 
10 | PROJECT = "lndp"
11 | 
12 | class Config(NamedTuple):
13 |     project: str=PROJECT # wandb project name
14 |     seed: int=1
15 |     n_seeds: int=1
16 |     # --- trainer ---
17 |     strategy: str="CMA_ES"
18 |     popsize: int=256
19 |     generations: int=10_000
20 |     ckpt_file: str=""
21 |     log: int=0 # 1 for logging to wandb
22 |     eval_reps: int=1 # number of evaluations to average over
23 |     # --- task ---
24 |     env_name: str="CartPole-v1"
25 |     n_episodes: int=3 # number of environment episodes
26 |     l1_penalty: float=0.
27 |     dev_after_episode: int=0
28 |     env_size: int=5
29 |     p_switch: float=0.
30 |     dense_reward: int=0
31 |     # --- model ---
32 |     n_nodes: int=32 # max number of nodes in the network
33 |     node_features: int=8
34 |     edge_features: int=4
35 |     pruning: int=1
36 |     synaptogenesis: int=1
37 |     rnn_iters: int=3 # number of propagation steps
38 |     dev_steps: int=0 # number of developmental steps
39 |     p_hh: float=0.1 # initial connection probabilities (average/variance)
40 |     s_hh: float=0.0001
41 |     p_ih: float=0.1
42 |     s_ih: float=0.0001
43 |     p_ho: float=0.1
44 |     s_ho: float=0.0001
45 |     use_bias: int=0
46 |     is_recurrent: int=0 # activity is reset between env steps if is_recurrent=0
47 |     gnn_iters: int=1 # number of GNN forward passes
48 |     stochastic_decisions: int=0 # if 1, synaptogenesis and pruning decisions are probabilistic
49 |     block_lt_updates: int=0 # if 1, blocks any change during the agent's lifetime
50 | 
51 | 
52 | if __name__ == '__main__':
53 | 
54 |     cfg = load_config(Config)
55 |     key_model, key_train = jr.split(jr.key(cfg.seed))
56 | 
57 |     if "," not in cfg.env_name:
58 |         mdl = model_factory(cfg, key_model)
59 |         params, statics = eqx.partition(mdl, eqx.is_array)
60 |         task = task_factory(cfg, statics)
61 |     else: # comma-separated env names: the same model is evaluated on several tasks (MultiTask)
62 |         env_names = cfg.env_name.split(",")
63 |         statics = []
64 |         tsks = []
65 |         for env_name in env_names:
66 |             _cfg = cfg._replace(env_name=env_name)
67 |             mdl = model_factory(_cfg, key_model)
68 |             params, statics = eqx.partition(mdl, eqx.is_array)
69 |             task = task_factory(_cfg, statics)
70 |             tsks.append(task)
71 |         task = MultiTask(tsks)
72 | 
73 | 
74 |     trainer = trainer_factory(cfg, task, params) #type:ignore
75 | 
76 |     for seed in range(cfg.n_seeds):
77 |         key_train, _ktrain = jr.split(key_train)
78 |         if cfg.log:
79 |             trainer.logger.wandb_init(cfg.project, cfg._replace(seed=cfg.seed+seed)._asdict()) #type:ignore
80 |         trainer.init_and_train_(_ktrain)
81 |         if cfg.log:
82 |             trainer.logger.wandb_finish() #type:ignore

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | ckpts/*
2 | 
3 | # Byte-compiled / optimized / DLL files
4 | __pycache__/
5 | *.py[cod]
6 | *$py.class
7 | 
8 | # C extensions
9 | *.so
10 | 
11 | # Distribution / packaging
12 | .Python
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | share/python-wheels/
26 | *.egg-info/
27 | .installed.cfg
28 | *.egg
29 | MANIFEST
30 | 
31 | # PyInstaller
32 | # Usually these files are written by a python script from a template
33 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | cover/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | .pybuilder/ 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | # For a library or package, you might want to ignore these files since the code is 89 | # intended to run in multiple environments; otherwise, check them in: 90 | # .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # poetry 100 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 101 | # This is especially recommended for binary packages to ensure reproducibility, and is more 102 | # commonly ignored for libraries. 103 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 104 | #poetry.lock 105 | 106 | # pdm 107 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 108 | #pdm.lock 109 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 110 | # in version control. 111 | # https://pdm.fming.dev/#use-with-ide 112 | .pdm.toml 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
162 | #.idea/
163 | 

--------------------------------------------------------------------------------
/utils/foraging.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | import jax.random as jr
4 | from typing import NamedTuple
5 | from jaxtyping import Bool, Float, Int
6 | import equinox as eqx
7 | 
8 | 
9 | class EnvState(NamedTuple):
10 |     obs: Float
11 |     reward: Float
12 |     done: Bool
13 |     pos: Int
14 |     goal: Int
15 |     steps: Int
16 | 
17 | 
18 | action_effects = jnp.array([1, -1, 0])
19 | 
20 | def manhattan_distance(a, b):
21 |     return jnp.sum(jnp.absolute(a-b))
22 | 
23 | 
24 | class GridChemotaxis:
25 | 
26 |     #-------------------------------------------------------------------
27 | 
28 |     def __init__(self, p_switch: float=0., n_types: int=2, env_size: int=5, max_steps=10, dense_reward=False) -> None:
29 | 
30 |         self.n_types = n_types
31 |         self.env_size = env_size
32 |         self.max_steps = max_steps
33 |         self.p_switch = p_switch
34 |         self.dense_reward = dense_reward
35 | 
36 |     #-------------------------------------------------------------------
37 | 
38 |     def step(self, state: EnvState, action: Int, key: jax.Array)->EnvState:
39 | 
40 |         goals_pos = jnp.array([0, self.env_size-1], dtype=int)
41 |         goal_pos = goals_pos[state.goal]
42 | 
43 |         dp = action_effects[action]
44 |         new_pos = jnp.clip(state.pos + dp, 0, self.env_size-1)
45 |         dist_to_goal = jnp.abs(new_pos-goal_pos)
46 |         close_to_goal = dist_to_goal == 0
47 | 
48 |         if not self.dense_reward:
49 |             r = (close_to_goal.astype(float)*10)
50 |         else:
51 |             r = - dist_to_goal
52 | 
53 |         done = close_to_goal | (state.steps==self.max_steps)
54 |         kp, kres = jr.split(key)
55 |         switch = jr.uniform(kp) < self.p_switch
56 |         goal = jnp.where(switch, 1-state.goal, state.goal)
57 | 
58 |         state = state._replace(pos=new_pos,
59 |                                obs=self._get_obs(new_pos, goal),
60 |                                reward=r,
61 |                                done=done,
62 |                                steps=state.steps+1,
63 |                                goal=goal)
64 | 
65 |         # assumption (reconstructed lines): when an episode terminates the environment
66 |         # auto-resets (using kres) so that several short episodes fit into one
67 |         # fixed-length rollout; the reward and done flag of the last step are kept
68 |         reset_state = self.reset(kres)._replace(reward=r, done=done)
69 |         state = jax.tree_util.tree_map(lambda a, b: jnp.where(done, a, b), reset_state, state)
70 | 
71 |         return state
72 | 
73 |     #-------------------------------------------------------------------
74 | 
75 |     def reset(self, key: jax.Array)->EnvState:
76 |         kpos, ktpos, kt = jr.split(key, 3)
77 | 
78 |         pos = jnp.array(self.env_size//2, dtype=int)
79 |         goal = self._sample_goal(kt)
80 | 
81 |         return EnvState(pos=pos, obs=self._get_obs(pos, goal), reward=jnp.zeros(()),
82 |                         done=jnp.zeros(()).astype(bool), steps=jnp.zeros((), dtype=int),
83 |                         goal=goal)
84 | 
85 |     #-------------------------------------------------------------------
86 | 
87 |     def _sample_goal(self, key):
88 | 
89 |         return jr.randint(key, (), minval=0, maxval=2)
90 | 
91 |     #-------------------------------------------------------------------
92 | 
93 |     def _get_obs(self, pos, target_pos):
94 |         return pos
95 | 
96 |     #-------------------------------------------------------------------
97 | 
98 | class GridTask:
99 | 
100 |     #-------------------------------------------------------------------
101 | 
102 |     def __init__(self, statics, n_steps=100, **kwargs):
103 | 
104 |         self.statics = statics
105 |         self.n_steps = n_steps
106 |         self.env = GridChemotaxis(**kwargs)
107 | 
108 |     #-------------------------------------------------------------------
109 | 
110 |     def __call__(self, params, key, *args, **kwargs):
111 | 
112 |         raise NotImplementedError
113 | 
114 |     #-------------------------------------------------------------------
115 | 
116 | class GridEpisodicTask(GridTask):
117 | 
118 |     #-------------------------------------------------------------------
119 | 
120 |     def __call__(self, params, key, *args, **kwargs):
121 | 
122 |         key, kpinit, keinit = jr.split(key,3)
123 | 
124 |         pi = eqx.combine(params, self.statics)
125 |         pi_state = pi.initialize(kpinit)
126 | 
127 |         env_state = self.env.reset(keinit)
128 |         rews = jnp.zeros(())
129 |         def step(c, _):
130 |             pi_state, env_state, key = c
131 |             key, k = jr.split(key)
132 |             pi_state = pi_state._replace(r=env_state.reward[None])
133 |             action, pi_state = 
pi(env_state.obs, pi_state, k) 134 | env_state = self.env.step(env_state, action, k) 135 | return [pi_state, env_state, key], [env_state.reward, env_state] 136 | 137 | [pi_state, env_state, _], [rews, env_states] = jax.lax.scan(step, [pi_state, env_state, key], None, self.n_steps) 138 | return rews.sum(), dict() 139 | 140 | 141 | class PiState(NamedTuple): 142 | r: jax.Array=jnp.zeros((1,)) 143 | 144 | class Pi(eqx.Module): 145 | """ 146 | """ 147 | #------------------------------------------------------------------- 148 | # Parameters: 149 | 150 | # Statics: 151 | 152 | #------------------------------------------------------------------- 153 | 154 | def __call__(self, obs, state, key, *args, **kwargs): 155 | p = obs 156 | action = jnp.argmax((p*action_effects)) 157 | return action, state 158 | 159 | #------------------------------------------------------------------- 160 | 161 | def initialize(self, key): 162 | return PiState() 163 | 164 | if __name__ == '__main__': 165 | env = GridChemotaxis() 166 | s = env.reset(jr.key(1)) 167 | key = jr.key(2) 168 | ret = 0 169 | for i in range(100): 170 | key, k = jr.split(key) 171 | s = env.step(s, jnp.array(1,dtype=int), k) 172 | ret += s.reward 173 | print(ret) 174 | 175 | -------------------------------------------------------------------------------- /utils/exputils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def load_config(factory): 4 | default_config = factory() 5 | parser = argparse.ArgumentParser() 6 | bools = [] 7 | for k, v in default_config._asdict().items(): 8 | dtype = int if isinstance(v, bool) else type(v) 9 | dv = int(v) if isinstance(v, bool) else v 10 | if isinstance(v, bool): bools.append(k) 11 | parser.add_argument(f"--{k}", type=dtype, default=dv) 12 | config = vars(parser.parse_args()) 13 | for k in bools: 14 | config[k] = bool(config[k]) 15 | config = factory(**config) 16 | return config 17 | 18 | import jax 19 | import jax.experimental.host_callback as hcb 20 | from tqdm import tqdm 21 | 22 | def progress_bar_scan(num_samples, message=None): 23 | "Progress bar for a JAX scan" 24 | if message is None: 25 | message="" 26 | tqdm_bars = {} 27 | 28 | print_rate = 5 29 | remainder = num_samples % print_rate 30 | 31 | def _define_tqdm(arg, transform): 32 | tqdm_bars[0] = tqdm(range(num_samples)) 33 | tqdm_bars[0].set_description(message, refresh=False) 34 | 35 | def _update_tqdm(arg, transform): 36 | tqdm_bars[0].update(arg) 37 | 38 | def _update_progress_bar(iter_num): 39 | "Updates tqdm progress bar of a JAX scan or loop" 40 | _ = jax.lax.cond( 41 | iter_num == 0, 42 | lambda _: hcb.id_tap(_define_tqdm, None, result=iter_num), 43 | lambda _: iter_num, 44 | operand=None, 45 | ) 46 | _ = jax.lax.cond( 47 | # update tqdm every multiple of `print_rate` except at the end 48 | (iter_num % print_rate == 0) & (iter_num != num_samples-remainder), 49 | lambda _: hcb.id_tap(_update_tqdm, print_rate, result=iter_num), 50 | lambda _: iter_num, 51 | operand=None, 52 | ) 53 | _ = jax.lax.cond( 54 | # update tqdm by `remainder` 55 | iter_num == num_samples-remainder, 56 | lambda _: hcb.id_tap(_update_tqdm, remainder, result=iter_num), 57 | lambda _: iter_num, 58 | operand=None, 59 | ) 60 | 61 | def _close_tqdm(arg, transform): 62 | tqdm_bars[0].close() 63 | 64 | def close_tqdm(result, iter_num): 65 | return jax.lax.cond( 66 | iter_num == num_samples-1, 67 | lambda _: hcb.id_tap(_close_tqdm, None, result=result), 68 | lambda _: result, 69 | operand=None, 70 | ) 71 | 72 | def 
_progress_bar_scan(func):
73 |         """Decorator that adds a progress bar to `body_fun` used in `lax.scan`.
74 |         Note that `body_fun` must either be looping over `np.arange(num_samples)`,
75 |         or be looping over a tuple whose first element is `np.arange(num_samples)`.
76 |         This means that `iter_num` is the current iteration number
77 |         """
78 |         def wrapper_progress_bar(carry, x):
79 |             if type(x) is tuple:
80 |                 iter_num, *_ = x
81 |             else:
82 |                 iter_num = x
83 |             _update_progress_bar(iter_num)
84 |             result = func(carry, x)
85 |             return close_tqdm(result, iter_num)
86 | 
87 |         return wrapper_progress_bar
88 | 
89 |     return _progress_bar_scan
90 | 
91 | 
92 | 
93 | def progress_bar_fori(num_samples, message=None):
94 |     "Progress bar for a JAX fori_loop"
95 |     if message is None:
96 |         message=""
97 |     tqdm_bars = {}
98 | 
99 |     print_rate = 5
100 |     remainder = num_samples % print_rate
101 | 
102 |     def _define_tqdm(arg, transform):
103 |         tqdm_bars[0] = tqdm(range(num_samples))
104 |         tqdm_bars[0].set_description(message, refresh=False)
105 | 
106 |     def _update_tqdm(arg, transform):
107 |         tqdm_bars[0].update(arg)
108 | 
109 |     def _update_progress_bar(iter_num):
110 |         "Updates tqdm progress bar of a JAX scan or loop"
111 |         _ = jax.lax.cond(
112 |             iter_num == 0,
113 |             lambda _: hcb.id_tap(_define_tqdm, None, result=iter_num),
114 |             lambda _: iter_num,
115 |             operand=None,
116 |         )
117 |         _ = jax.lax.cond(
118 |             # update tqdm every multiple of `print_rate` except at the end
119 |             (iter_num % print_rate == 0) & (iter_num != num_samples-remainder),
120 |             lambda _: hcb.id_tap(_update_tqdm, print_rate, result=iter_num),
121 |             lambda _: iter_num,
122 |             operand=None,
123 |         )
124 |         _ = jax.lax.cond(
125 |             # update tqdm by `remainder`
126 |             iter_num == num_samples-remainder,
127 |             lambda _: hcb.id_tap(_update_tqdm, remainder, result=iter_num),
128 |             lambda _: iter_num,
129 |             operand=None,
130 |         )
131 | 
132 |     def _close_tqdm(arg, transform):
133 |         tqdm_bars[0].close()
134 | 
135 |     def close_tqdm(result, iter_num):
136 |         return jax.lax.cond(
137 |             iter_num == num_samples-1,
138 |             lambda _: hcb.id_tap(_close_tqdm, None, result=result),
139 |             lambda _: result,
140 |             operand=None,
141 |         )
142 | 
143 |     def _progress_bar_scan(func):
144 |         """Decorator that adds a progress bar to `body_fun` used in `lax.fori_loop`.
145 |         Here `body_fun` receives the loop index directly;
146 |         that index is used as `iter_num`,
147 |         the current iteration number.
148 |         """
149 |         def wrapper_progress_bar(x, carry):
150 |             iter_num = x
151 |             _update_progress_bar(iter_num)
152 |             result = func(x, carry)
153 |             return close_tqdm(result, iter_num)
154 | 
155 |         return wrapper_progress_bar
156 | 
157 |     return _progress_bar_scan
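158 | 
159 | # Usage sketch (illustrative, not part of the training pipeline): decorate a scan
160 | # body so the bar advances with the iteration index supplied through `xs`:
161 | #
162 | #   import jax.numpy as jnp
163 | #   body = progress_bar_scan(1000, "training")(lambda carry, i: (carry + i, None))
164 | #   total, _ = jax.lax.scan(body, jnp.zeros(()), jnp.arange(1000))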
--------------------------------------------------------------------------------
/utils/model.py:
--------------------------------------------------------------------------------
1 | from typing import NamedTuple, Optional, Callable, Union
2 | from jaxtyping import Float
3 | import jax
4 | import jax.numpy as jnp
5 | import jax.random as jr
6 | import jax.nn as jnn
7 | import equinox as eqx
8 | import equinox.nn as nn
9 | 
10 | class Graph(NamedTuple):
11 |     A: jax.Array
12 |     h: jax.Array
13 |     e: jax.Array
14 | 
15 |     @property
16 |     def N(self):
17 |         return self.h.shape[0]
18 | 
19 | 
20 | def simple_rnn(h, w):
21 |     h = jnn.tanh(jnp.dot(h,w))
22 |     return h
23 | 
24 | scaled_dot_product = lambda q, k: jnp.dot(q,k.T) / jnp.sqrt(k.shape[-1])
25 | 
26 | normalize = lambda x: x / (jnp.linalg.norm(x, axis=-1, keepdims=True)+1e-8)
27 | 
28 | 
29 | class GraphTransformer(eqx.Module):
30 | 
31 |     """
32 |     paper: https://arxiv.org/pdf/2012.09699v2.pdf
33 |     """
34 |     #-------------------------------------------------------------------
35 |     # params :
36 |     Q: nn.Linear # Query function
37 |     K: nn.Linear # Key
38 |     V: nn.Linear # Value
39 |     O: nn.Linear # Output V*heads->do
40 |     E: Optional[nn.Linear] # Edge attention
41 |     # statics :
42 |     n_heads: int
43 |     use_edge_features: bool
44 |     qk_features: int
45 |     value_features: int
46 |     #-------------------------------------------------------------------
47 | 
48 |     def __init__(self,
49 |                  in_features: int,
50 |                  out_features: int,
51 |                  qk_features: int,
52 |                  value_features: int,
53 |                  n_heads: int,
54 |                  *,
55 |                  use_edge_features: bool=False,
56 |                  in_edge_features: int=1,
57 |                  use_bias: bool=False,
58 |                  key: jax.Array):
59 |         """
60 |         """
61 |         key_Q, key_K, key_V, key_O, key_E = jr.split(key, 5)
62 |         self.n_heads = n_heads
63 |         self.use_edge_features = use_edge_features
64 |         self.qk_features = qk_features
65 |         self.value_features = value_features
66 | 
67 |         self.Q = nn.Linear(in_features, qk_features*n_heads, key=key_Q, use_bias=use_bias)
68 |         self.K = nn.Linear(in_features, qk_features*n_heads, key=key_K, use_bias=use_bias)
69 |         self.V = nn.Linear(in_features, value_features*n_heads, key=key_V, use_bias=use_bias)
70 |         self.O = nn.Linear(value_features*n_heads, out_features, key=key_O)
71 |         if use_edge_features:
72 |             self.E = nn.Linear(in_edge_features, n_heads, key=key_E)
73 |         else:
74 |             self.E = None
75 | 
76 |     #-------------------------------------------------------------------
77 | 
78 |     def __call__(self, graph: Graph)->Graph:
79 |         """return features aggregated through attention"""
80 |         h = graph.h
81 |         N = h.shape[0]
82 |         q, k, v = jax.vmap(self.Q)(h), jax.vmap(self.K)(h), jax.vmap(self.V)(h)
83 |         # Compute attention scores (before softmax) (N x N x H)
84 |         scores = jax.vmap(scaled_dot_product, in_axes=-1, out_axes=-1)(q.reshape((N, self.qk_features, -1)),
85 |                                                                        k.reshape((N, self.qk_features, -1)))
86 |         if self.use_edge_features:
87 |             assert self.E is not None
88 |             # use edge features to compute attention scores
89 |             we = jax.vmap(jax.vmap(self.E))(graph.e)
90 |             scores = scores * we
91 | 
92 |         w = jnn.softmax(scores, axis=1) # (N x N x H)
93 |         x = jax.vmap(jnp.dot, in_axes=-1, out_axes=-1)(
94 |             w.transpose((1,0,2)), v.reshape(N, self.value_features, -1)
95 |         ) # (N x dv X H)
96 |         x = x.reshape((N, -1)) # (N x dv*H) (concatenate the heads)
97 | 
98 |         h = jax.vmap(self.O)(x)
99 | 
100 |         return eqx.tree_at(lambda G: G.h, graph, h)
101 | 
102 | 
103 | def erdos_renyi(key: jax.Array, N: int, p: float, self_loops: bool=False):
104 |     """random adjacency matrix"""
105 |     A = (jr.uniform(key, (N,N)) < p).astype(float)
106 |     if not self_loops:
107 |         A = jnp.where(jnp.identity(N), 0., A)
108 |     return A
109 | 
110 | def reservoir(key: jax.Array, N: int, in_dims: int, out_dims: int, p_hh: float=1., p_ih: float=.3, p_ho: float=.5):
111 |     key_ih, key_hh, key_ho = jr.split(key, 3)
112 |     A = jnp.zeros((N, N))
113 |     I = jnp.arange(in_dims)
114 |     O = jnp.arange(out_dims) + (N-out_dims)
115 |     H = jnp.arange(N-out_dims-in_dims) + in_dims
116 | 
117 |     A = A.at[jnp.ix_(I, H)].set((jr.uniform(key_ih, (in_dims, len(H))) < p_ih).astype(float))
118 |     A = A.at[jnp.ix_(H, H)].set((jr.uniform(key_hh, (len(H), len(H))) < p_hh).astype(float))
119 |     A = A.at[jnp.ix_(H, O)].set((jr.uniform(key_ho, (len(H), out_dims)) < p_ho).astype(float))
120 |     return A
121 | 
122 | # NOTE: `meta_reservoir` below is a reconstruction (the original lines are missing);
123 | # the signature is taken from its call in LNDP.initialize, and the (mu, s) pairs are
124 | # assumed to parameterize per-block connection probabilities.
125 | def meta_reservoir(key: jax.Array, N: int, in_dims: int, out_dims: int,
126 |                    mu_hh: float=.2, s_hh: float=.05, mu_ih: float=.2, s_ih: float=.05,
127 |                    mu_ho: float=.2, s_ho: float=.05, clip: bool=True):
128 |     key_p, key_r = jr.split(key)
129 |     mus = jnp.array([mu_hh, mu_ih, mu_ho])
130 |     ss = jnp.array([s_hh, s_ih, s_ho])
131 |     ps = mus + jr.normal(key_p, (3,))*jnp.sqrt(ss)
132 |     if clip:
133 |         ps = jnp.clip(ps, 0., 1.)
134 |     return reservoir(key_r, N, in_dims, out_dims, p_hh=ps[0], p_ih=ps[1], p_ho=ps[2])
135 | 
136 | 
137 | class OrnsteinUhlenbeckProcess(eqx.Module):
138 |     """Ornstein-Uhlenbeck process generating spontaneous activity during development"""
139 |     #-------------------------------------------------------------------
140 |     # Parameters (shapes inferred from __call__ and initialize below):
141 |     sigma: jax.Array # (d, d) covariance of the driving noise
142 |     alpha: jax.Array # mean-reversion rate
143 |     mu: jax.Array    # long-term mean
144 |     dt: jax.Array    # integration time step
145 |     # Statics:
146 |     d: int
147 |     #-------------------------------------------------------------------
148 | 
149 |     def __init__(self, d: int, key: jax.Array):
150 |         # NOTE: reconstructed initializer; the initial values are assumptions
151 |         # (these arrays are evolved parameters, so only their shapes are constrained)
152 |         key_sigma, key_mu = jr.split(key)
153 |         self.d = d
154 |         self.sigma = jr.normal(key_sigma, (d, d)) * .1
155 |         self.alpha = jnp.ones(())
156 |         self.mu = jr.normal(key_mu, (d,)) * .1
157 |         self.dt = jnp.array(.1)
158 | 
159 |     #-------------------------------------------------------------------
160 | 
161 |     def __call__(self, x: jax.Array, key: jax.Array)->jax.Array:
162 | 
163 |         sigma = jnp.clip(self.sigma, -1., 1.)
164 |         alpha = jnp.clip(self.alpha, 0., jnp.inf)
165 |         mu = self.mu
166 |         dt = jnp.clip(self.dt, 0.01, 1.)
167 |         W = jr.multivariate_normal(key, mean=jnp.zeros((self.d,)), cov=sigma, method="svd")
168 |         x = x + alpha*(mu-x) * dt + W
169 |         return x
170 | 
171 |     #-------------------------------------------------------------------
172 | 
173 |     def initialize(self, key: jax.Array)->jax.Array:
174 |         sigma = jnp.clip(self.sigma, -1., 1.)
175 |         W = jr.multivariate_normal(key, mean=jnp.zeros((self.d,)), cov=sigma, method="svd")
176 |         return W

--------------------------------------------------------------------------------
/task.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | import jax.random as jr
4 | import jax.nn as jnn
5 | from jaxtyping import PyTree
6 | from typing import Callable, Optional
7 | from utils.task import *
8 | from utils.foraging import GridEpisodicTask
9 | 
10 | class MultiTask:
11 | 
12 |     def __init__(self, tasks):
13 |         self.tasks = tasks
14 | 
15 |     def __call__(self, params, key, data=None):
16 |         f_total = 0
17 |         data = {}
18 |         for i, task in enumerate(self.tasks):
19 |             f, d = task(params, key)
20 |             f_total += f
21 |             data[f"tsk_{i}"] = f
22 |             for k, v in d.items():
23 |                 data[f"tsk_{i}_{k}"] = v
24 | 
25 |         return f_total, data
26 | 
27 | class MultiEpisodeGymnaxTask(GymnaxTask):
28 |     """
29 |     """
30 |     #-------------------------------------------------------------------
31 |     n_episodes: int
32 |     fitness_transform: Optional[Callable]
33 |     l1_penalty: float
34 |     episode_intervention: Callable
35 |     dev_after_episode: bool
36 |     #-------------------------------------------------------------------
37 | 
38 |     def __init__(
39 |         self,
40 |         statics: PyTree,
41 |         env: str,
42 |         n_episodes: int=8,
43 |         env_params: Optional[PyTree] = None,
44 |         fitness_transform: Optional[Callable]=None,
45 |         data_fn: Callable=lambda d: d,
46 |         l1_penalty: float=0.,
47 |         episode_intervention: Callable=lambda pi, e, k: pi,
48 |         dev_after_episode: bool=False):
49 | 
50 |         super().__init__(statics, env, env_params, data_fn)
51 | 
52 |         self.n_episodes = n_episodes
53 |         self.fitness_transform = fitness_transform
54 |         self.l1_penalty = l1_penalty
55 |         self.episode_intervention = episode_intervention
56 |         self.dev_after_episode = dev_after_episode
57 | 
58 |     #-------------------------------------------------------------------
59 | 
60 |     def __call__(
61 |         self,
62 |         params: Params,
63 | 
key: jax.Array, 64 | task_params: Optional[TaskParams]=None): 65 | 66 | policy_state=None 67 | full_return = 0. 68 | returns = [] 69 | densities = [] 70 | for episode in range(self.n_episodes): 71 | key, key_, key__ = jr.split(key, 3) 72 | policy_state, episode_return, data = self._rollout(params, key_, policy_state) 73 | d = policy_state.G.A.sum() / (policy_state.G.A.shape[0]**2) 74 | l1 = d * self.l1_penalty 75 | densities.append(d) 76 | full_return = (full_return + (episode_return-l1) / self.n_episodes) 77 | returns.append(episode_return) 78 | policy_state = self.episode_intervention(policy_state, episode, key__) 79 | 80 | if self.fitness_transform is not None: 81 | fitness = self.fitness_transform(full_return, data) #type:ignore 82 | else: 83 | fitness = full_return 84 | data = {f"ep_{i}":r for i, r in enumerate(returns)} 85 | data["density"] = sum(densities) / self.n_episodes 86 | return fitness, data 87 | #------------------------------------------------------------------- 88 | 89 | def _rollout(self, params: Params, key: jax.Array, policy_state: Optional[PolicyState]=None)->Tuple[PolicyState,Float,PyTree]: 90 | """ 91 | code adapted from: https://github.com/RobertTLange/gymnax/blob/main/gymnax/experimental/rollout.py 92 | """ 93 | 94 | model = eqx.combine(params, self.statics) 95 | key_reset, key_episode, key_model, key_dev = jr.split(key, 4) 96 | obs, state = self.env.reset(key_reset, self.env_params) 97 | 98 | def policy_step(state_input, tmp): 99 | """lax.scan compatible step transition in jax env.""" 100 | policy_state, obs, state, rng, last_reward, cum_reward, valid_mask = state_input 101 | rng, rng_step, rng_net = jax.random.split(rng, 3) 102 | action, policy_state = model(obs, policy_state._replace(r=last_reward), rng_net) 103 | next_obs, next_state, reward, done, _ = self.env.step( 104 | rng_step, state, action, self.env_params 105 | ) 106 | new_cum_reward = cum_reward + reward * valid_mask 107 | new_valid_mask = valid_mask * (1 - done) 108 | carry = [ 109 | policy_state._replace(d=1-valid_mask), 110 | next_obs, 111 | next_state, 112 | rng, 113 | reward*valid_mask, 114 | new_cum_reward, 115 | new_valid_mask, 116 | ] 117 | y = [policy_state, obs, action, reward*valid_mask, next_obs, done] 118 | return carry, y 119 | 120 | if policy_state is None: 121 | policy_state = model.initialize(key_model) 122 | # Scan over episode step loop 123 | carry_out, scan_out = jax.lax.scan( 124 | policy_step, 125 | [ 126 | policy_state, 127 | obs, 128 | state, 129 | key_episode, 130 | jnp.array([0.0]), 131 | jnp.array([0.0]), 132 | jnp.array([1.0]), 133 | ], 134 | (), 135 | self.env.default_params.max_steps_in_episode, 136 | ) 137 | # Return the sum of rewards accumulated by agent in episode rollout 138 | policy_states, obs, action, reward, _, _ = scan_out 139 | policy_state, *_ = carry_out 140 | cum_return = carry_out[-2][0] 141 | data = {"policy_states": policy_states, "obs": obs, 142 | "action": action, "rewards": reward} 143 | data = self.data_fn(data) 144 | 145 | if self.dev_after_episode: 146 | policy_state = model.dev(policy_state._replace(r=jnp.zeros((1,))), key_dev)#type:ignore 147 | 148 | return policy_state, cum_return, data 149 | 150 | class MultiepisodeBraxTask(BraxTask): 151 | 152 | """ 153 | """ 154 | #------------------------------------------------------------------- 155 | n_episodes: int 156 | fitness_transform: Optional[Callable] 157 | #------------------------------------------------------------------- 158 | 159 | def __init__( 160 | self, 161 | statics: PyTree, 162 | env: 
str,
163 |         n_episodes: int=8,
164 |         backend: str="positional",
165 |         fitness_transform: Optional[Callable]=None,
166 |         data_fn: Callable=lambda d: d):
167 | 
168 |         super().__init__(statics, env, 500, backend, data_fn)
169 | 
170 |         self.n_episodes = n_episodes
171 |         self.fitness_transform = fitness_transform
172 | 
173 |     #-------------------------------------------------------------------
174 | 
175 |     def __call__(
176 |         self,
177 |         params: Params,
178 |         key: jax.Array,
179 |         task_params: Optional[TaskParams]=None):
180 | 
181 |         policy_state=None
182 |         full_return = 0.
183 |         returns = []
184 | 
185 |         for _ in range(self.n_episodes):
186 |             key, key_ = jr.split(key)
187 |             policy_state, episode_return, data = self._rollout(params, key_, policy_state)
188 |             full_return = full_return + episode_return / self.n_episodes
189 |             returns.append(episode_return)
190 | 
191 |         if self.fitness_transform is not None:
192 |             fitness = self.fitness_transform(full_return, data) #type:ignore
193 |         else:
194 |             fitness = full_return
195 |         return fitness, {f"ep_{i}":r for i, r in enumerate(returns)}
196 | 
197 |     #-------------------------------------------------------------------
198 | 
199 |     def _rollout(self, params: Params, key: jax.Array, policy_state: Optional[PolicyState]=None):
200 | 
201 |         policy = eqx.combine(params, self.statics)
202 |         key, init_env_key, init_policy_key, rollout_key = jr.split(key, 4)
203 | 
204 |         policy_state = policy.initialize(init_policy_key) if policy_state is None else policy_state
205 |         env_state = self.initialize(init_env_key)
206 |         init_state = State(env_state=env_state, policy_state=policy_state)
207 | 
208 |         def env_step(carry, x):
209 |             state, key = carry
210 |             key, _key = jr.split(key)
211 |             action, policy_state = policy(state.env_state.obs, state.policy_state._replace(r=state.env_state.reward[None], d=state.env_state.done[None]), _key)
212 |             env_state = self.env.step(state.env_state, action)
213 |             new_state = State(env_state=env_state, policy_state=policy_state)
214 | 
215 |             return [new_state, key], state
216 | 
217 |         [state, _], states = jax.lax.scan(env_step, [init_state, rollout_key], None, self.max_steps)
218 |         data = {"policy_states": states.policy_state, "obs": states.env_state.obs}
219 |         data = self.data_fn(data)
220 |         data["reward"] = states.env_state.reward
221 |         return state.policy_state, states.env_state.reward.sum(), data
222 | 
223 | 
224 | def make(config, statics):
225 |     data_fn = lambda d: d
226 |     if config.env_name=="Grid":
227 |         return GridEpisodicTask(statics, p_switch=config.p_switch, env_size=config.env_size, dense_reward=bool(config.dense_reward))
228 |     elif config.env_name[0].isupper():
229 |         return MultiEpisodeGymnaxTask(statics, n_episodes=config.n_episodes, env=config.env_name,
230 |                                       data_fn=data_fn, l1_penalty=config.l1_penalty,
231 |                                       dev_after_episode=bool(config.dev_after_episode))
232 |     else:
233 |         return MultiepisodeBraxTask(statics, env=config.env_name, n_episodes=config.n_episodes)

--------------------------------------------------------------------------------
/utils/task.py:
--------------------------------------------------------------------------------
1 | 
2 | from typing import Callable, NamedTuple, Optional, Tuple, TypeAlias, Union
3 | import jax
4 | import jax.numpy as jnp
5 | import jax.random as jr
6 | import equinox as eqx
7 | 
8 | import gymnax
9 | 
10 | from brax import envs
11 | from brax.envs import Env
12 | from jaxtyping import Float, PyTree
13 | 
14 | Params: TypeAlias = PyTree
15 | TaskParams: TypeAlias = PyTree
16 | 
EnvState: TypeAlias = PyTree 17 | Action: TypeAlias = jax.Array 18 | PolicyState: TypeAlias = PyTree 19 | BraxEnv: TypeAlias = Env 20 | GymEnv: TypeAlias = gymnax.environments.environment.Environment 21 | 22 | class State(NamedTuple): 23 | env_state: EnvState 24 | policy_state: PolicyState 25 | 26 | 27 | #======================================================================= 28 | #======================================================================= 29 | #======================================================================= 30 | 31 | class BraxTask(eqx.Module): 32 | 33 | """ 34 | """ 35 | #------------------------------------------------------------------- 36 | env: BraxEnv 37 | statics: PyTree[...] 38 | max_steps: int 39 | data_fn: Callable[[PyTree], dict] 40 | #------------------------------------------------------------------- 41 | 42 | def __init__( 43 | self, 44 | statics: PyTree[...], 45 | env: Union[str, BraxEnv], 46 | max_steps: int, 47 | backend: str="positional", 48 | data_fn: Callable=lambda x: x, 49 | env_kwargs: dict={}): 50 | 51 | if isinstance(env, str): 52 | self.env = envs.get_environment(env, backend=backend, **env_kwargs) 53 | else: 54 | self.env = env 55 | 56 | self.statics = statics 57 | self.max_steps = max_steps 58 | self.data_fn = data_fn 59 | 60 | #------------------------------------------------------------------- 61 | 62 | def __call__( 63 | self, 64 | params: Params, 65 | key: jax.Array, 66 | task_params: Optional[TaskParams]=None)->Tuple[Float, PyTree]: 67 | 68 | _, _, data = self.rollout(params, key) 69 | return jnp.sum(data["reward"]), data 70 | 71 | #------------------------------------------------------------------- 72 | 73 | def rollout( 74 | self, 75 | params: Params, 76 | key: jax.Array, 77 | task_params: Optional[TaskParams]=None)->Tuple[State, State, dict]: 78 | 79 | key, init_env_key, init_policy_key, rollout_key = jr.split(key, 4) 80 | policy = eqx.combine(params, self.statics) 81 | 82 | policy_state = policy.initialize(init_policy_key) 83 | env_state = self.initialize(init_env_key) 84 | init_state = State(env_state=env_state, policy_state=policy_state) 85 | 86 | def env_step(carry, x): 87 | state, key = carry 88 | key, _key = jr.split(key) 89 | action, policy_state = policy(state.env_state.obs, state.policy_state, _key) 90 | env_state = self.env.step(state.env_state, action) 91 | new_state = State(env_state=env_state, policy_state=policy_state) 92 | 93 | return [new_state, key], state 94 | 95 | [state, _], states = jax.lax.scan(env_step, [init_state, rollout_key], None, self.max_steps) 96 | data = {"policy_states": states.policy_state, "obs": states.env_state.obs} 97 | data = self.data_fn(data) 98 | data["reward"] = states.env_state.reward 99 | return state, states, data 100 | 101 | #------------------------------------------------------------------- 102 | 103 | def step(self, *args, **kwargs): 104 | return self.env.step(*args, **kwargs) 105 | 106 | def reset(self, *args, **kwargs): 107 | return self.env.reset(*args, **kwargs) 108 | 109 | #------------------------------------------------------------------- 110 | 111 | def initialize(self, key:jax.Array)->EnvState: 112 | 113 | return self.env.reset(key) 114 | 115 | #------------------------------------------------------------------- 116 | 117 | 118 | #======================================================================= 119 | #======================================================================= 120 | #======================================================================= 121 | 122 | 123 | 
class GymnaxTask(eqx.Module): 124 | 125 | """ 126 | """ 127 | #------------------------------------------------------------------- 128 | statics: PyTree 129 | env: GymEnv 130 | env_params: PyTree 131 | data_fn: Callable 132 | #------------------------------------------------------------------- 133 | 134 | def __init__( 135 | self, 136 | statics: PyTree, 137 | env: str, 138 | env_params: Optional[PyTree] = None, 139 | data_fn: Callable=lambda d: d): 140 | 141 | self.statics = statics 142 | self.env, default_env_params = gymnax.make(env) #type: ignore 143 | self.env_params = env_params if env_params is not None else default_env_params 144 | self.data_fn = data_fn 145 | 146 | #------------------------------------------------------------------- 147 | 148 | def __call__( 149 | self, 150 | params: Params, 151 | key: jax.Array, 152 | task_params: Optional[TaskParams]=None)->Tuple[Float, PyTree]: 153 | 154 | return self.rollout(params, key, task_params) 155 | 156 | #------------------------------------------------------------------- 157 | 158 | def rollout(self, params: Params, key: jax.Array, task_params: Optional[TaskParams]=None)->Tuple[Float, PyTree]: 159 | """ 160 | code adapted from: https://github.com/RobertTLange/gymnax/blob/main/gymnax/experimental/rollout.py 161 | """ 162 | 163 | model = eqx.combine(params, self.statics) 164 | key_reset, key_episode, key_model = jr.split(key, 3) 165 | obs, state = self.env.reset(key_reset, self.env_params) 166 | 167 | def policy_step(state_input, tmp): 168 | """lax.scan compatible step transition in jax env.""" 169 | policy_state, obs, state, rng, cum_reward, valid_mask = state_input 170 | rng, rng_step, rng_net = jax.random.split(rng, 3) 171 | 172 | action, policy_state = model(obs, policy_state, rng_net) 173 | next_obs, next_state, reward, done, _ = self.env.step( 174 | rng_step, state, action, self.env_params 175 | ) 176 | new_cum_reward = cum_reward + reward * valid_mask 177 | new_valid_mask = valid_mask * (1 - done) 178 | carry = [ 179 | policy_state, 180 | next_obs, 181 | next_state, 182 | rng, 183 | new_cum_reward, 184 | new_valid_mask, 185 | ] 186 | y = [policy_state, obs, action, reward, next_obs, done] 187 | return carry, y 188 | 189 | policy_state = model.initialize(key_model) 190 | # Scan over episode step loop 191 | carry_out, scan_out = jax.lax.scan( 192 | policy_step, 193 | [ 194 | policy_state, 195 | obs, 196 | state, 197 | key_episode, 198 | jnp.array([0.0]), 199 | jnp.array([1.0]), 200 | ], 201 | (), 202 | self.env.default_params.max_steps_in_episode, 203 | ) 204 | # Return the sum of rewards accumulated by agent in episode rollout 205 | policy_state, obs, action, reward, _, _ = scan_out 206 | cum_return = carry_out[-2][0] 207 | data = {"policy_states": policy_state, "obs": obs, 208 | "action": action, "rewards": reward} 209 | data = self.data_fn(data) 210 | return cum_return, data 211 | 212 | #------------------------------------------------------------------- 213 | 214 | 215 | #======================================================================= 216 | #======================================================================= 217 | #======================================================================= 218 | 219 | class RandomDiscretePolicy(eqx.Module): 220 | """ 221 | """ 222 | #------------------------------------------------------------------- 223 | n_actions: int 224 | #------------------------------------------------------------------- 225 | def __init__(self, n_actions: int): 226 | self.n_actions = n_actions 227 | 
#-------------------------------------------------------------------
228 |     def __call__(self, env_state: EnvState, policy_state: PolicyState, key: jax.Array):
229 |         return jr.randint(key, (), 0, self.n_actions), None
230 |     #-------------------------------------------------------------------
231 |     def initialize(self, *args, **kwargs):
232 |         return None
233 | 
234 | class RandomContinuousPolicy(eqx.Module):
235 |     """
236 |     """
237 |     #-------------------------------------------------------------------
238 |     action_dims: int
239 |     #-------------------------------------------------------------------
240 |     def __init__(self, action_dims: int):
241 |         self.action_dims = action_dims
242 |     #-------------------------------------------------------------------
243 |     def __call__(self, env_state: EnvState, policy_state: PolicyState, key: jax.Array):
244 |         return jr.normal(key, (self.action_dims,)), None
245 |     #-------------------------------------------------------------------
246 |     def initialize(self, *args, **kwargs):
247 |         return None
248 | 
249 | class StatefulPolicyWrapper(eqx.Module):
250 |     """
251 |     Wrapper adding a policy state to the call signature of a stateless policy
252 |     """
253 |     #-------------------------------------------------------------------
254 |     policy: Union[PyTree[...], Callable[[EnvState, jax.Array], Action]]
255 |     #-------------------------------------------------------------------
256 |     def __init__(self, policy: Union[PyTree[...], Callable[[EnvState, jax.Array], Action]]):
257 |         self.policy = policy
258 |     #-------------------------------------------------------------------
259 |     def __call__(self, env_state, policy_state, key):
260 |         action = self.policy(env_state, key)
261 |         return action, None
262 |     #-------------------------------------------------------------------
263 |     def initialize(self, *args, **kwargs):
264 |         return None
265 | 
266 | 
267 | ENV_SPACES = {
268 |     "CartPole-v1": (4, 2, "discrete"),
269 |     "Acrobot-v1": (6, 3, "discrete"),
270 |     "MountainCar-v0": (2, 3, "discrete"),
271 |     "halfcheetah": (17, 6, "continuous"),
272 |     "ant": (27, 8, "continuous"),
273 |     "walker2d": (17, 6, "continuous"),
274 |     "inverted_pendulum": (4, 1, "continuous"),
275 |     'inverted_double_pendulum': (8, 1, "continuous"),
276 |     "hopper": (11, 3, "continuous"),
277 |     "Pendulum-v1": (3, 1, "continuous"),
278 |     "PointRobot-misc": (6, 2, "continuous"),
279 |     "MetaMaze-misc": (15, 4, "discrete"),
280 |     "Reacher-misc": (8, 2, "continuous")
281 | }
282 | 
283 | 

--------------------------------------------------------------------------------
/utils/trainer.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | import jax.random as jr
4 | import jax.nn as jnn
5 | from jax.experimental.shard_map import shard_map as shmap
6 | import jax.experimental.host_callback as hcb
7 | from jax.sharding import Mesh, PartitionSpec as P
8 | from jax.experimental import mesh_utils
9 | import evosax as ex
10 | import equinox as eqx
11 | from typing import Any, Callable, Dict, Optional, Union, Tuple, TypeAlias
12 | from jaxtyping import PyTree
13 | import os
14 | import wandb
15 | from utils.exputils import *
16 | 
17 | Data: TypeAlias = PyTree[...]
18 | TaskParams: TypeAlias = PyTree[...]
19 | TrainState: TypeAlias = PyTree[...]
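# NOTE: `metrics_fn(state, data)` must return a 3-tuple (log_data, ckpt_data, epoch):
# a dict of scalars sent to wandb, a PyTree checkpointed every `ckpt_freq` generations,
# and the current generation counter (see `metrics_fn` in /trainer.py for an instance).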
20 | 21 | class Logger: 22 | 23 | #------------------------------------------------------------------- 24 | 25 | def __init__( 26 | self, 27 | wandb_log: bool, 28 | metrics_fn: Callable[[TrainState, Data], Tuple[Data, Data, int]], 29 | ckpt_file: Optional[str]=None, 30 | ckpt_freq: int=100, 31 | verbose: bool=False): 32 | 33 | if ckpt_file is not None and "/" in ckpt_file: 34 | if not os.path.isdir(ckpt_file[:ckpt_file.rindex("/")]): 35 | os.makedirs(ckpt_file[:ckpt_file.rindex("/")]) 36 | self.wandb_log = wandb_log 37 | self.metrics_fn = metrics_fn 38 | self.ckpt_file = ckpt_file 39 | self.ckpt_freq = ckpt_freq 40 | self.epoch = [0] 41 | self.verbose = verbose 42 | 43 | #------------------------------------------------------------------- 44 | 45 | def log(self, state: TrainState, data: Data): 46 | 47 | log_data, ckpt_data, epoch = self.metrics_fn(state, data) 48 | if self.wandb_log: 49 | self._log(log_data) 50 | self.save_chkpt(ckpt_data, epoch) 51 | return log_data 52 | 53 | #------------------------------------------------------------------- 54 | 55 | def _log(self, data: dict): 56 | hcb.id_tap( 57 | lambda d, *_: wandb.log(d), data 58 | ) 59 | 60 | #------------------------------------------------------------------- 61 | 62 | def save_chkpt(self, data: dict, epoch: int): 63 | 64 | def save(data): 65 | assert self.ckpt_file is not None 66 | file = f"{self.ckpt_file}.eqx" 67 | if self.verbose: 68 | print("saving data at: ", file) 69 | eqx.tree_serialise_leaves(file, data) 70 | 71 | def tap_save(data): 72 | hcb.id_tap(lambda d, *_: save(d), data) 73 | return None 74 | 75 | if self.ckpt_file is not None: 76 | jax.lax.cond( 77 | (jnp.mod(epoch, self.ckpt_freq))==0, 78 | lambda data : tap_save(data), 79 | lambda data : None, 80 | data 81 | ) 82 | 83 | #------------------------------------------------------------------- 84 | 85 | def wandb_init(self, project: str, config: dict, **kwargs): 86 | if self.wandb_log: 87 | wandb.init(project=project, config=config, **kwargs) 88 | 89 | #------------------------------------------------------------------- 90 | 91 | def wandb_finish(self, *args, **kwargs): 92 | if self.wandb_log: 93 | wandb.finish(*args, **kwargs) 94 | 95 | 96 | 97 | class BaseTrainer(eqx.Module): 98 | 99 | """ 100 | """ 101 | #------------------------------------------------------------------- 102 | train_steps: int 103 | logger: Optional[Logger] 104 | progress_bar: Optional[bool] 105 | #------------------------------------------------------------------- 106 | 107 | def __init__(self, 108 | train_steps: int, 109 | logger: Optional[Logger]=None, 110 | progress_bar: Optional[bool]=False): 111 | 112 | self.train_steps = train_steps 113 | self.progress_bar = progress_bar 114 | self.logger = logger 115 | 116 | #------------------------------------------------------------------- 117 | 118 | def __call__(self, key: jax.Array): 119 | 120 | return self.init_and_train(key) 121 | 122 | #------------------------------------------------------------------- 123 | 124 | def train(self, state: TrainState, key: jax.Array, data: Optional[Data]=None)->Tuple[TrainState, Data]: 125 | 126 | def _step(c, x): 127 | s, k = c 128 | k, k_ = jr.split(k) 129 | s, data = self.train_step(s, k_) 130 | 131 | if self.logger is not None: 132 | self.logger.log(s, data) 133 | 134 | return [s, k], {"states": s, "metrics": data} 135 | 136 | if self.progress_bar: 137 | _step = progress_bar_scan(self.train_steps)(_step) #type: ignore 138 | 139 | [state, key], data = jax.lax.scan(_step, [state, key], 
jnp.arange(self.train_steps)) 140 | 141 | return state, data 142 | 143 | #------------------------------------------------------------------- 144 | 145 | def train_(self, state: TrainState, key: jax.Array, data: Optional[Data]=None)->TrainState: 146 | 147 | def _step(i, c): 148 | s, k = c 149 | k, k_ = jr.split(k) 150 | s, data = self.train_step(s, k_) 151 | if self.logger is not None: 152 | self.logger.log(s, data) 153 | return [s, k] 154 | 155 | if self.progress_bar: 156 | _step = progress_bar_fori(self.train_steps)(_step) #type: ignore 157 | 158 | [state, key] = jax.lax.fori_loop(0, self.train_steps, _step, [state, key]) 159 | return state 160 | 161 | #------------------------------------------------------------------- 162 | 163 | def log(self, data): 164 | hcb.id_tap( 165 | lambda d, *_: wandb.log(d), data 166 | ) 167 | 168 | #------------------------------------------------------------------- 169 | 170 | def init_and_train(self, key: jax.Array, data: Optional[Data]=None)->Tuple[TrainState, Data]: 171 | init_key, train_key = jr.split(key) 172 | state = self.initialize(init_key) 173 | return self.train(state, train_key, data) 174 | 175 | #------------------------------------------------------------------- 176 | 177 | def init_and_train_(self, key: jax.Array, data: Optional[Data]=None)->TrainState: 178 | init_key, train_key = jr.split(key) 179 | state = self.initialize(init_key) 180 | return self.train_(state, train_key, data) 181 | 182 | #------------------------------------------------------------------- 183 | 184 | 185 | Params = PyTree[...] 186 | Task = Callable 187 | 188 | 189 | def default_metrics(state, data): 190 | y = {} 191 | y["best"] = state.best_fitness 192 | y["gen_best"] = data["fitness"].min() 193 | y["gen_mean"] = data["fitness"].mean() 194 | y["gen_worse"] = data["fitness"].max() 195 | y["var"] = data["fitness"].var() 196 | return y 197 | 198 | 199 | class EvosaxTrainer(BaseTrainer): 200 | 201 | """ 202 | """ 203 | #------------------------------------------------------------------- 204 | strategy: ex.Strategy 205 | es_params: ex.EvoParams 206 | params_shaper: ex.ParameterReshaper 207 | task: Task 208 | fitness_shaper: ex.FitnessShaper 209 | n_devices: int 210 | multi_device_mode: str 211 | #------------------------------------------------------------------- 212 | 213 | def __init__( 214 | self, 215 | train_steps: int, 216 | strategy: Union[ex.Strategy, str], 217 | task: Callable, 218 | params_shaper: ex.ParameterReshaper, 219 | popsize: Optional[int]=None, 220 | fitness_shaper: Optional[ex.FitnessShaper]=None, 221 | es_kws: Optional[Dict[str, Any]]={}, 222 | es_params: Optional[ex.EvoParams]=None, 223 | eval_reps: int=1, 224 | logger: Optional[Logger]=None, 225 | progress_bar: Optional[bool]=True, 226 | n_devices: int=1, 227 | multi_device_mode: str="shmap"): 228 | 229 | super().__init__(train_steps=train_steps, 230 | logger=logger, 231 | progress_bar=progress_bar) 232 | 233 | if isinstance(strategy, str): 234 | assert popsize is not None 235 | self.strategy = self.create_strategy(strategy, popsize, params_shaper.total_params, **es_kws) # type: ignore 236 | else: 237 | self.strategy = strategy 238 | 239 | if es_params is None: 240 | self.es_params = self.strategy.default_params 241 | else: 242 | self.es_params = es_params 243 | 244 | self.params_shaper = params_shaper 245 | 246 | if eval_reps > 1: 247 | def _eval_fn(p: Params, k: jax.Array, tp: Optional[PyTree]=None): 248 | """ 249 | """ 250 | fit, info = jax.vmap(task, in_axes=(None,0,None))(p, jr.split(k,eval_reps), 
tp) 251 | return jnp.mean(fit), info 252 | self.task = _eval_fn 253 | else : 254 | self.task = task 255 | 256 | if fitness_shaper is None: 257 | self.fitness_shaper = ex.FitnessShaper() 258 | else: 259 | self.fitness_shaper = fitness_shaper 260 | 261 | self.n_devices = n_devices 262 | self.multi_device_mode = multi_device_mode 263 | 264 | #------------------------------------------------------------------- 265 | 266 | def eval(self, *args, **kwargs): 267 | 268 | if self.n_devices == 1: 269 | return self._eval(*args, **kwargs) 270 | if self.multi_device_mode=="shmap": 271 | return self._eval_shmap(*args, **kwargs) 272 | elif self.multi_device_mode == "pmap": 273 | return self._eval_pmap(*args, **kwargs) 274 | else: 275 | raise ValueError(f"multi_device_mode {self.multi_device_mode} is not a valid mode") 276 | 277 | #------------------------------------------------------------------- 278 | 279 | def _eval(self, x: jax.Array, key: jax.Array, task_params: PyTree)->Tuple[jax.Array, PyTree]: 280 | 281 | params = self.params_shaper.reshape(x) 282 | _eval = jax.vmap(self.task, in_axes=(0, 0, None)) 283 | return _eval(params, jr.split(key, x.shape[0]), task_params) 284 | 285 | #------------------------------------------------------------------- 286 | 287 | def _eval_shmap(self, x: jax.Array, key: jax.Array, task_params: PyTree)->Tuple[jax.Array, PyTree]: 288 | 289 | devices = mesh_utils.create_device_mesh((self.n_devices,)) 290 | device_mesh = Mesh(devices, axis_names=("p")) 291 | 292 | _eval = lambda x, k: self.task(self.params_shaper.reshape_single(x), k) 293 | batch_eval = jax.vmap(_eval, in_axes=(0,None)) 294 | sheval = shmap(batch_eval, 295 | mesh=device_mesh, 296 | in_specs=(P("p",), P()), 297 | out_specs=(P("p"), P("p")), 298 | check_rep=False) 299 | 300 | return sheval(x, key) 301 | 302 | #------------------------------------------------------------------- 303 | 304 | def _eval_pmap(self, x: jax.Array, key: jax.Array, data: PyTree)->Tuple[jax.Array, PyTree]: 305 | 306 | _eval = lambda x, k: self.task(self.params_shaper.reshape_single(x), k) 307 | batch_eval = jax.vmap(_eval, in_axes=(0,None)) 308 | pop_batch = x.shape[0] // self.n_devices 309 | x = x.reshape((self.n_devices, pop_batch, -1)) 310 | pmapeval = jax.pmap(batch_eval, in_axes=(0,None)) #type: ignore 311 | f, eval_data = pmapeval(x, key) 312 | return f.reshape((-1,)), eval_data 313 | 314 | #------------------------------------------------------------------- 315 | 316 | def train_step(self, state: TrainState, key: jax.Array, data: Optional[TaskParams]=None) -> Tuple[TrainState, Data]: 317 | 318 | ask_key, eval_key = jr.split(key, 2) 319 | x, state = self.strategy.ask(ask_key, state, self.es_params) 320 | fitness, eval_data = self.eval(x, eval_key, data) 321 | f = self.fitness_shaper.apply(x, fitness) 322 | state = self.strategy.tell(x, f, state, self.es_params) 323 | state = self._update_evo_state(state, x, fitness) 324 | return state, {"fitness": fitness, "data": eval_data} #TODO def best as >= 325 | 326 | #------------------------------------------------------------------- 327 | 328 | def _update_evo_state(self, state: TrainState, x: jax.Array, f: jax.Array)->TrainState: 329 | is_best = f.min() <= state.best_fitness 330 | gen_best = x[jnp.argmin(f)] 331 | best_member = jax.lax.select(is_best, gen_best, state.best_member) 332 | state = state.replace(best_member=best_member) #type:ignore 333 | return state 334 | 335 | #------------------------------------------------------------------- 336 | 337 | def initialize(self, key: 
jax.Array, **kwargs) -> TrainState:
338 | 
339 |         state = self.strategy.initialize(key, self.es_params)
340 |         state = state.replace(**kwargs)
341 |         return state
342 | 
343 |     #-------------------------------------------------------------------
344 | 
345 |     def create_strategy(self, name: str, popsize: int, num_dims: int, **kwargs)->ex.Strategy:
346 | 
347 |         ES = getattr(ex, name)
348 |         es = ES(popsize=popsize, num_dims=num_dims, **kwargs)
349 |         return es
350 | 
351 |     #-------------------------------------------------------------------
352 | 
353 |     def load_ckpt(self, ckpt_path: str)->Params:
354 |         params = eqx.tree_deserialise_leaves(
355 |             ckpt_path, jnp.zeros((self.params_shaper.total_params,))
356 |         )
357 |         return params
358 | 
359 |     #-------------------------------------------------------------------
360 | 
361 |     def train_from_model_ckpt(self, ckpt_path: str, key: jax.Array)->Tuple[TrainState, Data]: #type:ignore
362 | 
363 |         key_init, key_train = jr.split(key)
364 |         params = self.load_ckpt(ckpt_path)
365 |         state = self.initialize(key_init, mean=self.params_shaper.flatten_single(params))
366 |         return self.train(state, key_train)
367 | 
368 |     #-------------------------------------------------------------------
369 | 
370 |     def train_from_model_ckpt_(self, ckpt_path: str, key: jax.Array)->TrainState: #type:ignore
371 | 
372 |         key_init, key_train = jr.split(key)
373 |         params = self.load_ckpt(ckpt_path)
374 |         state = self.initialize(key_init, mean=self.params_shaper.flatten_single(params))
375 |         return self.train_(state, key_train)

--------------------------------------------------------------------------------
/lndp.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from typing import Callable, NamedTuple, Optional, Tuple, TypeAlias
3 | from jaxtyping import Bool, Float, Array, PyTree
4 | import jax
5 | import jax.numpy as jnp
6 | import jax.random as jr
7 | import jax.nn as jnn
8 | import equinox as eqx
9 | import equinox.nn as nn
10 | 
11 | 
12 | from utils.model import *
13 | #from src.gnn.generators import meta_reservoir, reservoir
14 | #from src.tasks.rl import ENV_SPACES
15 | 
16 | double_vmap = lambda f: jax.vmap(jax.vmap(f))
17 | normalize = lambda x: x/(jnp.linalg.norm(x, axis=-1, keepdims=True)+1e-8)
18 | 
19 | class PolicyState(NamedTuple):
20 |     G: Graph
21 |     a: jax.Array
22 |     w: jax.Array
23 |     r: Float
24 |     d: Float
25 |     t: Float
26 |     a_seq: jax.Array
27 | 
28 | 
29 | #================================================================================================================================================
30 | #================================================================================================================================================
31 | #================================================================================================================================================
32 | 
33 | 
34 | class LNDP(eqx.Module):
35 | 
36 |     #-------------------------------------------------------------------
37 |     # Params :
38 |     gnn: Callable
39 |     Wpre: Callable
40 |     node_fn: Callable
41 |     edge_fn: Callable
42 |     # Statics:
43 |     action_dims: int
44 |     obs_dims: int
45 |     n_nodes: int
46 |     dev_steps: int
47 |     rnn_iters: int
48 |     node_features: int
49 |     edge_features: int
50 |     p_hh: float
51 |     s_hh: float
52 |     p_ih: float
53 |     s_ih: float
54 |     p_ho: float
55 |     s_ho: float
56 |     mask_A: Callable
57 |     discrete_action: bool
58 |     env_is_pendulum: bool
59 |     use_bias: bool
60 |     is_recurrent: bool
61 | 
--------------------------------------------------------------------------------
/lndp.py:
--------------------------------------------------------------------------------
from functools import partial
from typing import Callable, NamedTuple, Optional, Tuple
from jaxtyping import Float, Array
import jax
import jax.numpy as jnp
import jax.random as jr
import jax.nn as jnn
import equinox as eqx
import equinox.nn as nn

from utils.model import *

double_vmap = lambda f: jax.vmap(jax.vmap(f))
normalize = lambda x: x/(jnp.linalg.norm(x, axis=-1, keepdims=True)+1e-8)

class PolicyState(NamedTuple):
    G: Graph          # network graph (adjacency A, edge states e, node states h)
    a: jax.Array      # node activations
    w: jax.Array      # synaptic weights (first edge feature)
    r: Float          # reward signal
    d: Float          # done flag
    t: Float          # lifetime step counter
    a_seq: jax.Array  # per-node activation sequence from the last RNN unroll


#================================================================================================================================================
#================================================================================================================================================
#================================================================================================================================================


class LNDP(eqx.Module):

    #-------------------------------------------------------------------
    # Params:
    gnn: Callable
    Wpre: Callable
    node_fn: Callable
    edge_fn: Callable
    # Statics:
    action_dims: int
    obs_dims: int
    n_nodes: int
    dev_steps: int
    rnn_iters: int
    node_features: int
    edge_features: int
    p_hh: float
    s_hh: float
    p_ih: float
    s_ih: float
    p_ho: float
    s_ho: float
    mask_A: Callable
    discrete_action: bool
    env_is_pendulum: bool
    use_bias: bool
    is_recurrent: bool
    gnn_iters: int
    stochastic_decisions: bool
    pruning: bool
    synaptogenesis: bool
    ablate_gt: bool=False
    block_lt_updates: bool=False
    # Optional Params:
    prune_fn: Optional[Callable]=None
    adde_fn: Optional[Callable]=None
    sa_fn: Optional[OrnsteinUhlenbeckProcess]=None
    #-------------------------------------------------------------------

    def __init__(self,
                 action_dims: int,
                 obs_dims: int,
                 n_nodes: int,
                 edge_features: int=4,
                 dev_steps: int=0,
                 rnn_iters: int=5,
                 node_features: int=8,
                 p_hh: float=.2,
                 s_hh: float=.05,
                 p_ih: float=.2,
                 s_ih: float=.05,
                 p_ho: float=.2,
                 s_ho: float=.05,
                 discrete_action: bool=True,
                 env_is_pendulum: bool=False,
                 use_bias: bool=False,
                 is_recurrent: bool=False,
                 gnn_iters: int=1,
                 stochastic_decisions: bool=False,
                 pruning: bool=True,
                 synaptogenesis: bool=True,
                 ablate_gt: bool=False,
                 block_lt_updates: bool=False,
                 *,
                 key: jax.Array):
        """
        Self-assembling network policy. Node states are updated by a GRU cell
        (nodes are also GRUs), and edge states by a second GRU cell.
        """
        # ---

        self.action_dims = action_dims
        self.obs_dims = obs_dims
        self.n_nodes = n_nodes
        self.dev_steps = dev_steps
        self.rnn_iters = rnn_iters
        self.node_features = node_features
        self.edge_features = edge_features
        self.p_hh = p_hh
        self.s_hh = s_hh
        self.p_ih = p_ih
        self.s_ih = s_ih
        self.p_ho = p_ho
        self.s_ho = s_ho
        self.discrete_action = discrete_action
        self.env_is_pendulum = env_is_pendulum
        self.use_bias = use_bias
        self.is_recurrent = is_recurrent
        self.gnn_iters = gnn_iters
        self.stochastic_decisions = stochastic_decisions
        self.pruning = pruning
        self.synaptogenesis = synaptogenesis
        self.ablate_gt = ablate_gt
        self.block_lt_updates = block_lt_updates

        # ---

        key, key_gnn, key_Wpre = jr.split(key, 3)
        self.gnn = GraphTransformer(in_features=node_features,
                                    out_features=node_features,
                                    qk_features=4,
                                    value_features=8,
                                    n_heads=3,
                                    use_edge_features=True,
                                    in_edge_features=edge_features+3,  # edge state + A, A.T and identity indicators
                                    key=key_gnn)
        # Input: 5 structural node features + (rnn_iters+1) activations + node state
        self.Wpre = nn.Linear(5+rnn_iters+node_features+1, node_features, key=key_Wpre)

        # ---

        key, key_node = jr.split(key)
        self.node_fn = nn.GRUCell(node_features, node_features, key=key_node)

        # ---

        key, key_edge = jr.split(key)
        # Input: pre/post node states, pre/post activation sequences, and the reward signal
        self.edge_fn = nn.GRUCell(2*node_features + 2*(rnn_iters+1) + 1,
                                  edge_features,
                                  key=key_edge)

        # ---

        if synaptogenesis:
            # Probability of adding an edge, computed from the pre/post node states
            key, key_adde = jr.split(key)
            self.adde_fn = nn.MLP(2*node_features, 1, 16, 1, key=key_adde, final_activation=jnn.sigmoid)
        else:
            self.adde_fn = None

        # ---

        if pruning:
            # Probability of removing an edge, computed from its edge state
            key, key_prune = jr.split(key)
            self.prune_fn = nn.MLP(edge_features, 1, 16, 1, key=key_prune, final_activation=jnn.sigmoid)
        else:
            self.prune_fn = None

        # ---

        if dev_steps:
            # Spontaneous-activity generator driving the development phase
            key, key_sa = jr.split(key)
            self.sa_fn = OrnsteinUhlenbeckProcess(obs_dims, key_sa)
        else:
            self.sa_fn = None

        # ---

        # Mask of allowed connections (dense reservoir wiring with a fixed key)
        self.mask_A = partial(reservoir,
                              N=n_nodes,
                              in_dims=obs_dims,
                              out_dims=action_dims,
                              p_hh=1.,
                              p_ih=1.,
                              p_ho=1.,
                              key=jr.key(1))
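
    # Usage sketch (illustrative): constructing an LNDP policy directly.
    # The CartPole-v1 sizes (4 observation dims, 2 discrete actions) match
    # ENV_SPACES below; n_nodes=32 and dev_steps=16 are arbitrary choices.
    #
    #   model = LNDP(action_dims=2, obs_dims=4, n_nodes=32,
    #                dev_steps=16, key=jr.key(0))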
    #-------------------------------------------------------------------

    def __call__(self, obs: jax.Array, state: PolicyState, key: jax.Array)->Tuple[jax.Array, PolicyState]:
        """
        One policy step: apply lifetime plasticity (unless blocked or the
        episode is done), run the RNN, and read out an action.
        """
        if not self.block_lt_updates:
            # Freeze structural and synaptic updates once the done flag is set
            state = jax.lax.cond(
                state.d.sum() > 0,
                lambda s, k: s,
                lambda s, k: self.update_state(s, k),
                state, key)
        # ---
        a, a_seq = self.forward_rnn(obs, state)
        state = state._replace(a_seq=a_seq)
        # ---
        k = 2. if self.env_is_pendulum else 1.  # Pendulum-v1 actions live in [-2, 2]
        if self.discrete_action:
            action = jnp.argmax(a[-self.action_dims:])
        else:
            action = a[-self.action_dims:] * k
            action = jnp.where(jnp.isnan(action), 0., action)  # guard against NaNs
        if self.is_recurrent:
            state = state._replace(a=a)
        return action, state

    #-------------------------------------------------------------------

    def initialize(self, key: jax.Array)->PolicyState:
        """
        Sample an initial graph and run the development phase if enabled.
        """
        key_A, key_e, key_h, key_dev = jr.split(key, 4)
        A = meta_reservoir(N=self.n_nodes,
                           in_dims=self.obs_dims,
                           out_dims=self.action_dims,
                           mu_hh=self.p_hh,
                           s_hh=self.s_hh,
                           mu_ih=self.p_ih,
                           s_ih=self.s_ih,
                           mu_ho=self.p_ho,
                           s_ho=self.s_ho,
                           clip=False,
                           key=key_A)
        h = jr.uniform(key_h, (self.n_nodes, self.node_features), minval=-1., maxval=1.)
        e = jr.uniform(key_e, (self.n_nodes, self.n_nodes, self.edge_features), minval=-1., maxval=1.) * A[...,None]
        G = Graph(A=A, e=e, h=h)
        a = jnp.zeros((self.n_nodes,))
        w = e[...,0]  # the first edge feature acts as the synaptic weight
        state = PolicyState(a=a, w=w, G=G, r=jnp.zeros((1,)), d=jnp.zeros((1,)),
                            t=jnp.zeros(()), a_seq=jnp.zeros((self.n_nodes, self.rnn_iters+1)))

        # --- Development phase ---
        if self.dev_steps:
            state = self.dev(state, key_dev)

        return state

    #-------------------------------------------------------------------

    def dev(self, state: PolicyState, key: jax.Array):
        """
        Development phase: drive the network with spontaneous activity (an
        Ornstein-Uhlenbeck process in observation space) before the first
        environment interaction.
        """
        assert self.sa_fn is not None

        def dev_step(i, s):
            state, sa, key = s
            key, key_sa, key_up = jr.split(key, 3)
            sa = self.sa_fn(sa, key_sa) #type:ignore
            state = self.update_state(state, key_up, is_dev=jnp.ones(()))
            _, a_seq = self.forward_rnn(sa, state)
            state = state._replace(a_seq=a_seq)
            return [state, sa, key]

        key_init, key_dev = jr.split(key)
        sa = self.sa_fn.initialize(key_init)
        state, *_ = jax.lax.fori_loop(0, self.dev_steps, dev_step, [state, sa, key_dev])
        state = state._replace(a=jnp.zeros_like(state.a))
        return state
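
    # Usage sketch (illustrative), continuing the construction example above:
    #
    #   key_init, key_step = jr.split(jr.key(0))
    #   state = model.initialize(key_init)           # sample graph (+ development if dev_steps > 0)
    #   obs = jnp.zeros((4,))                        # dummy CartPole-like observation
    #   action, state = model(obs, state, key_step)  # one policy step with plastic updates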
    #-------------------------------------------------------------------

    def update_state(self, state: PolicyState, key: jax.Array, is_dev: Float=jnp.zeros(()))->PolicyState:
        """
        Update the state of the network:
            - update node states from their activation sequences and GNN message passing
            - update edge states from the node states and reward signal
            - prune and add edges (structural plasticity)
        """
        G = state.G
        N = G.h.shape[0]

        # --- Update node states ---

        h = G.h
        a = state.a_seq
        x = jax.vmap(self.Wpre)(jnp.concatenate([self.get_node_features(G), a, h], axis=-1))
        x = jax.lax.fori_loop(0, self.gnn_iters, lambda _, G: self.gnn(G),
                              G._replace(h=x, e=self.get_edge_features(G))).h
        h = jax.vmap(self.node_fn)(x, h)

        # --- Update edge states ---

        e = G.e
        # Pair every (pre, post) node state and activation sequence with the reward signal
        hh = jnp.concatenate(
            [jnp.repeat(h[None,:], N, axis=0),
             jnp.repeat(h[:,None], N, axis=1),
             jnp.repeat(a[None,:], N, axis=0),
             jnp.repeat(a[:,None], N, axis=1),
             jnp.ones((self.n_nodes, self.n_nodes, 1))*state.r],
            axis=-1)
        e = double_vmap(self.edge_fn)(hh, e) * state.G.A[...,None]
        w = e[...,0]

        # --- Prune edges ---
        if self.pruning:
            key, key_prune = jr.split(key)
            p_pruned = double_vmap(self.prune_fn)(e)[...,0]
            if self.stochastic_decisions:
                pruned = (jr.uniform(key_prune, p_pruned.shape)<(p_pruned*.5)).astype(float) * self.mask_A()
            else:
                pruned = (p_pruned>.6).astype(float) * self.mask_A()
        else:
            pruned = jnp.zeros((self.n_nodes, self.n_nodes))

        # --- Add edges ---
        if self.synaptogenesis:
            key, key_add = jr.split(key)
            hh = jnp.concatenate(
                [jnp.repeat(h[None,:], N, axis=0),
                 jnp.repeat(h[:,None], N, axis=1)],
                axis=-1)
            p_add = double_vmap(self.adde_fn)(hh)[...,0]
            if self.stochastic_decisions:
                add = (jr.uniform(key_add, p_add.shape)<(p_add*.5)).astype(float) * self.mask_A()
            else:
                add = (p_add>.6).astype(float) * self.mask_A()
        else:
            add = jnp.zeros((self.n_nodes, self.n_nodes))

        # --- Modify adjacency matrix ---
        A = jnp.where(add, 1., G.A)
        A = jnp.where(pruned, 0., A)

        return eqx.tree_at(lambda s: [s.G.h, s.G.e, s.G.A, s.w, s.t],
                           state,
                           [h, e, A, w, state.t+1.])

    #-------------------------------------------------------------------

    def get_edge_features(self, graph: Graph)->jax.Array:
        # Edge states augmented with adjacency, reverse-adjacency and self-loop indicators
        return jnp.concatenate(
            [
                graph.e,
                graph.A[...,None],
                graph.A.T[...,None],
                jnp.identity(graph.N)[...,None]
            ], axis=-1
        )

    #-------------------------------------------------------------------

    def get_node_features(self, graph: Graph)->jax.Array:
        """
        5 features per node:
            - in degree
            - out degree
            - total degree
            - node type (2-dim indicator for {input, output}; hidden nodes get zeros)
        """
        i_d = graph.A.sum(1)[...,None] / 10.
        o_d = graph.A.sum(0)[...,None] / 10.
        d = (i_d + o_d) / 2.  # total degree scaled by 1/20 (parenthesized to fix an operator-precedence bug)
        typ = jnp.zeros((self.n_nodes, 2)).at[:self.obs_dims, 0].set(1.).at[-self.action_dims:, 1].set(1.)

        return jnp.concatenate([i_d, o_d, d, typ], axis=-1)
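
    # The pruning and synaptogenesis decisions above share one rule: threshold the
    # per-edge probability at 0.6, or, when stochastic_decisions is set, draw a
    # Bernoulli sample with probability p/2. Standalone sketch (names illustrative):
    #
    #   def structural_decision(p, mask, key, stochastic=False):
    #       if stochastic:
    #           return (jr.uniform(key, p.shape) < p * .5).astype(float) * mask
    #       return (p > .6).astype(float) * mask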
    #-------------------------------------------------------------------

    def forward_rnn(self,
                    obs: Float[Array, "O"],
                    state: PolicyState
                    )->Tuple[Float[Array, "N"], Float[Array, "N t+1"]]:
        """
        Run the RNN forward and return the final activations together with the
        activation sequence of each node.
        N: number of nodes
        t: number of RNN iterations (the sequence includes the final step, hence t+1)
        """
        if self.use_bias:
            b = state.G.h[:, 0]  # the first node feature doubles as a per-node bias
        else:
            b = jnp.zeros_like(state.a)
        # Clamp observation nodes to obs at every iteration, then propagate through the weights
        rnn = lambda a, _: (jnn.tanh(a.at[:self.obs_dims].set(obs) @ state.w + b), a.at[:self.obs_dims].set(obs))
        a, a_seq = jax.lax.scan(rnn, state.a, None, self.rnn_iters)
        return a, jnp.concatenate([a_seq, a[None,:]], axis=0).T



#================================================================================================================================================
#================================================================================================================================================
#================================================================================================================================================

# (obs_dims, action_dims, action_type) per environment
ENV_SPACES = {
    "CartPole-v1": (4, 2, "discrete"),
    "Acrobot-v1": (6, 3, "discrete"),
    "MountainCar-v0": (2, 3, "discrete"),
    "halfcheetah": (17, 6, "continuous"),
    "ant": (27, 8, "continuous"),
    "walker2d": (17, 6, "continuous"),
    "inverted_pendulum": (4, 1, "continuous"),
    "inverted_double_pendulum": (8, 1, "continuous"),
    "hopper": (11, 3, "continuous"),
    "Pendulum-v1": (3, 1, "continuous"),
    "PointRobot-misc": (6, 2, "continuous"),
    "MetaMaze-misc": (15, 4, "discrete"),
    "Reacher-misc": (8, 2, "continuous")
}



def make(cfg, key):

    if cfg.env_name=="Grid":
        obs_dims, action_dims, action_type = 1, 3, "discrete"
    elif cfg.env_name in ENV_SPACES:
        obs_dims, action_dims, action_type = ENV_SPACES[cfg.env_name]
    else:
        # Fail loudly instead of unpacking None from ENV_SPACES.get
        raise ValueError(f"Unknown environment: {cfg.env_name}")

    return LNDP(action_dims=action_dims,
                obs_dims=obs_dims,
                n_nodes=cfg.n_nodes,
                edge_features=cfg.edge_features,
                dev_steps=cfg.dev_steps,
                rnn_iters=cfg.rnn_iters,
                node_features=cfg.node_features,
                p_hh=cfg.p_hh,
                s_hh=cfg.s_hh,
                p_ih=cfg.p_ih,
                s_ih=cfg.s_ih,
                p_ho=cfg.p_ho,
                s_ho=cfg.s_ho,
                discrete_action=action_type in ["d", "discrete"],
                env_is_pendulum=cfg.env_name=="Pendulum-v1",
                use_bias=bool(cfg.use_bias),
                is_recurrent=bool(cfg.is_recurrent),
                gnn_iters=cfg.gnn_iters,
                stochastic_decisions=bool(cfg.stochastic_decisions),
                pruning=bool(cfg.pruning),
                synaptogenesis=bool(cfg.synaptogenesis),
                block_lt_updates=bool(cfg.block_lt_updates),
                key=key)
--------------------------------------------------------------------------------
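
A minimal sketch of building a model through `make`, assuming an illustrative config object that exposes the fields `make` reads (a `SimpleNamespace` stands in for the training config; all values here are arbitrary):

```
from types import SimpleNamespace
import jax.random as jr

cfg = SimpleNamespace(env_name="CartPole-v1", n_nodes=32, edge_features=4,
                      dev_steps=16, rnn_iters=5, node_features=8,
                      p_hh=.2, s_hh=.05, p_ih=.2, s_ih=.05, p_ho=.2, s_ho=.05,
                      use_bias=0, is_recurrent=0, gnn_iters=1,
                      stochastic_decisions=0, pruning=1, synaptogenesis=1,
                      block_lt_updates=0)
model = make(cfg, jr.key(0))
```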