├── .gitignore ├── LICENSE ├── README.md ├── senn_cnn ├── Dockerfile ├── README.md ├── bashrc_docker ├── docker-compose.yaml ├── experiments │ ├── __init__.py │ ├── config.py │ ├── models.py │ ├── senn_cifar10_manycycle.py │ ├── senn_cifar10_onecycle.py │ ├── transfer_cifar10_pretrain.py │ ├── transfer_tinyimagenet_fixed.py │ └── transfer_tinyimagenet_senn.py ├── pdm.lock ├── pyproject.toml ├── senn │ ├── __init__.py │ ├── dummy.py │ ├── linalg.py │ ├── models.py │ ├── neural.py │ └── opt.py └── tests │ ├── test_linalg.py │ ├── test_models.py │ └── test_neural.py └── senn_mlp ├── Dockerfile ├── README.md ├── build.sh ├── checkpoints └── .placeholder ├── config.yaml ├── data.py ├── datasets └── .placeholder ├── experiment1.py ├── experiment1 └── default_config.yaml ├── experiment2.py ├── experiment2 └── default_config.yaml ├── experiment3.py ├── experiment3 └── default_config.yaml ├── experiment4.py ├── experiment4 └── default_config.yaml ├── experiment_utils.py ├── jaxutils.py ├── langevin.py ├── logs └── .placeholder ├── nets.py ├── nets_legacy.py ├── optim.py ├── requirements.txt ├── results └── .placeholder ├── tboard.sh ├── use └── visualisation.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | .pdm-python 162 | 163 | orbax -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 ml-research@TUDarmstadt 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Self-Expanding Neural Networks 2 | 3 | This repository contains all code necessary to reproduce our experiments from the main body of the paper "Self-Expanding 4 | Neural Networks". There are two codebases, one for CNN experiments and one for MLP experiments. 5 | They are located in the senn_cnn and senn_mlp folders respectively. 
6 | 
7 | ## Setting up the CNN Codebase
8 | Setting up the CNN codebase requires navigating to the senn_cnn folder and proceeding with the requirements and environment variables as described below.
9 | 
10 | ### Installing Requirements
11 | The requirements can be installed using PDM. To do so, first install PDM, then run ```pdm install``` to install the requirements. This will set up a virtual environment with all dependencies. We also provide a Dockerfile, which should be used after setting up PDM.
12 | 
13 | #### Docker
14 | To ease reproduction, we provide a Dockerfile.
15 | We also provide a docker-compose service to set up the Docker container.
16 | To use Docker, first install Docker and docker-compose, then run ```docker compose build``` to build the container.
17 | Then run ```docker compose up -d``` to start the container. This will start a container with all dependencies installed.
18 | 
19 | #### PDM
20 | It is possible to install the requirements using PDM, a Python package manager. To do so, first install PDM, then run
21 | ```pdm install``` to install the requirements. This will set up a virtual environment with all dependencies.
22 | 
23 | ### Environment Variables
24 | The following environment variables are used:
25 | - CUDA_VISIBLE_DEVICES: Specifies which GPU to use.
26 | - DATASETS_ROOT_DIR: Specifies the root directory for datasets.
27 | - WANDB_DIR: Specifies the root directory for wandb logs.
28 | - RTPT_INITIALS: Specifies the initials to use for RTPT.
29 | - WANDB_API_KEY: Specifies the API key to use for wandb. This is only necessary if you want to log to wandb.
30 | If using Docker, these may be set in the docker-compose.yaml file, or by creating a .env file in the root directory.
31 | If using a .env file, create it before running ```docker compose up -d```.
32 | 
33 | ### Running Experiments
34 | To run an experiment, simply run the corresponding script. If using Docker, this may be done by running
35 | ```docker compose exec -e CUDA_VISIBLE_DEVICES -e JAX_PLATFORMS main python experiments/{experiment}.py```
36 | where {experiment} is one of:
37 | - senn_cifar10_manycycle
38 | - senn_cifar10_onecycle
39 | - transfer_cifar10_pretrain
40 | - transfer_tinyimagenet_fixed
41 | - transfer_tinyimagenet_senn
42 | If using PDM only, source the virtual environment, then run ```python experiments/{experiment}.py```.
43 | 
44 | #### Transfer Learning Experiments
45 | The transfer learning experiments save a model checkpoint after training. This checkpoint is then loaded and used to
46 | initialize the model for the transfer learning task. In Docker the directory is /senn/orbax/pretrained/final. If you want to alter the checkpoint directory, you can change the CHECKPOINT_DIR variable in each experiment file.
47 | 
48 | #### Experiment Settings
49 | The hyperparameters and tasks may be varied by editing the settings in the corresponding experiment file.
50 | 
51 | #### Changing the random seed
52 | The random seeds used for model initialization and training can be altered in the experiment script.
53 | 
54 | ## Setting up the MLP Codebase
55 | You can set up the MLP codebase by navigating to the senn_mlp folder and proceeding with the requirements and environment variables as described below.
56 | ### Installing Requirements
57 | We provide two methods to quickly install all dependencies required for our experiments.
58 | 
59 | #### Docker
60 | To ease reproduction, we provide a Dockerfile and scripts with which to use it.
61 | 1. To build the container, run ```sh build.sh``` in the main ("senn") directory.
62 | 2. To run the container on GPU 0, execute the command ```./use 0```; the argument may be omitted if using the CPU.
63 | 3. To start a container with port forwarding for tensorboard, use ```sh tboard.sh```.
64 | 
65 | #### Pip
66 | It is of course possible to install dependencies with pip, without using Docker:
67 | 1. Install JAX - follow the instructions in the [official repository](https://github.com/google/jax), e.g. for CUDA as of May 2023:
68 | ```pip install --upgrade "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html```.
69 | 2. Install PyTorch and Torchvision - instructions at the [official website](https://pytorch.org/get-started/locally/), e.g. as of May 2023:
70 | ```pip install --upgrade torch torchvision --index-url https://download.pytorch.org/whl/cu118```. Make sure to install a version with CUDA version requirements compatible with JAX.
71 | 3. Install the other requirements from requirements.txt:
72 | ```pip install -r requirements.txt```
73 | 
74 | ### Running Experiments
75 | If you are not using Docker, adjust the "checkpoints", "datasets" and "logs" folders in config.yaml to your liking. To
76 | run an experiment, simply execute the corresponding script, e.g. for experiment 4:
77 | ```python experiment4.py --name my_experiment```.
78 | The "--name my_experiment" argument specifies that tensorboard will store any logs under that name.
79 | 
80 | #### Tensorboard Results
81 | Various metrics, such as training/validation loss and neuron count, will be logged during training. If using
82 | Docker, these metrics may be viewed by starting a container with ```sh tboard.sh``` and then running the command
83 | ```tensorboard --logdir /logs``` inside it. If Docker is not used, then simply run ```tensorboard --logdir ./mylogs```,
84 | replacing "./mylogs" with the directory chosen in "config.yaml" for "meta:logdir".
85 | 
86 | #### Experiment Settings
87 | The hyperparameters and tasks may be adjusted in the "default_config.yaml" in the folder corresponding to the experiment,
88 | e.g. for experiment 4 one would adjust "senn/experiment4/default_config.yaml".
89 | 
90 | ##### Changing the random seed
91 | The random seed used during an experiment may be varied by altering "meta:seed", or by passing the command-line argument
92 | ```--seed```, e.g. ```python experiment4.py --name my_experiment --seed 2```.
93 | 
94 | ##### Subset size in experiment 4
95 | In experiment 4 there is an additional replication-relevant setting, "data:defaults:N", initially set to 4800. This
96 | specifies the number of examples from each class which will be included in the MNIST subset trained on. For example,
97 | N=6000 would result in training on the full 60000 images.
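To make the replication settings above concrete, the sketch below shows how the relevant entries could look inside experiment4/default_config.yaml. The key nesting is an assumption inferred from the "meta:seed", "meta:logdir" and "data:defaults:N" paths referenced above; only N: 4800 is a documented default, the other values are placeholders.

```yaml
# Hypothetical excerpt of experiment4/default_config.yaml (nesting assumed from the paths above).
meta:
  seed: 0            # placeholder; can also be overridden with the --seed command-line argument
  logdir: ./mylogs   # placeholder; directory where tensorboard logs are written
data:
  defaults:
    N: 4800          # examples kept per MNIST class; N=6000 trains on the full 60000 images
```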
98 | -------------------------------------------------------------------------------- /senn_cnn/Dockerfile: -------------------------------------------------------------------------------- 1 | # select image, we choose tensorflow image to build off of for JAX 2 | FROM nvcr.io/nvidia/tensorflow:23.04-tf2-py3 AS builder 3 | 4 | # install python installation tools 5 | RUN pip install --upgrade pip setuptools wheel 6 | RUN pip install pdm 7 | 8 | # copy in files defining the build for pdm 9 | COPY pyproject.toml pdm.lock README.md /senn/ 10 | # COPY senn/ /senn/senn 11 | 12 | # set working directory and install all dependencies with pdm 13 | # note that local 'senn' package is *not* installed 14 | WORKDIR /senn 15 | RUN mkdir __pypackages__ && pdm sync --no-editable --no-self 16 | 17 | 18 | # select image, same as above 19 | FROM nvcr.io/nvidia/tensorflow:23.04-tf2-py3 20 | 21 | # set python path to new directory and copy pdm installation from builder into it 22 | ENV PYTHONPATH=/pdm_pkgs 23 | COPY --from=builder /senn/__pypackages__/3.8/lib /pdm_pkgs 24 | 25 | # set working directory 26 | WORKDIR /senn 27 | 28 | # entrypoint is bash so that we get an interactive shell by default 29 | # /root/.bashrc can be used to run shell commands on startup 30 | ENTRYPOINT "/bin/bash" 31 | -------------------------------------------------------------------------------- /senn_cnn/README.md: -------------------------------------------------------------------------------- 1 | # CNN Codebase 2 | This is the codebase for the CNN experiments. 3 | For a description of the setup and how to run experiments, see the README.md in the root directory. -------------------------------------------------------------------------------- /senn_cnn/bashrc_docker: -------------------------------------------------------------------------------- 1 | cd /senn 2 | 3 | # old commands for pdm passthrough setup: 4 | # pdm use --venv in-project 5 | # eval $(pdm venv activate in-project) 6 | # pdm install 7 | 8 | # new command to just add local package as editable from volume 9 | pip install --editable . 10 | -------------------------------------------------------------------------------- /senn_cnn/docker-compose.yaml: -------------------------------------------------------------------------------- 1 | version: "3" 2 | services: 3 | main: 4 | build: 5 | context: . 
6 | dockerfile: Dockerfile 7 | stdin_open: true 8 | tty: true 9 | env_file: .env 10 | environment: 11 | - CUDA_VISIBLE_DEVICES 12 | - JAX_PLATFORMS 13 | - DATASETS_ROOT_DIR=/datasets 14 | - WANDB_DIR=/wandb 15 | volumes: 16 | - .:/senn 17 | - ./bashrc_docker:/root/.bashrc 18 | - $DATASETS_ROOT_DIR:/datasets 19 | - $WANDB_DIR:/wandb 20 | shm_size: '8gb' 21 | deploy: 22 | resources: 23 | reservations: 24 | devices: 25 | - capabilities: [gpu] 26 | -------------------------------------------------------------------------------- /senn_cnn/experiments/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/self-expanding-neural-networks/3480be01cbbfa46726af3a84dd9fb834d1ca979e/senn_cnn/experiments/__init__.py -------------------------------------------------------------------------------- /senn_cnn/experiments/config.py: -------------------------------------------------------------------------------- 1 | import wandb 2 | import tensorflow_datasets as tfds 3 | import tensorflow.data 4 | from jax.random import PRNGKey 5 | import models 6 | import os 7 | import jax 8 | from jax import numpy as jnp 9 | from flax.core import frozen_dict 10 | from flax.traverse_util import ModelParamTraversal 11 | import optax 12 | from flax import linen as nn 13 | from functools import partial 14 | 15 | import senn.opt 16 | from senn.opt import ( 17 | TrainState, 18 | KronTracker, 19 | IRootTracker, 20 | InnerOpt, 21 | InnerState, 22 | InnerConfig, 23 | TreeOpt, 24 | FlattenOpt, 25 | DiagOpt, 26 | ) 27 | from senn.linalg import IRootWhitener, DiagWhitener, HybridWhitener, MaskedWhitener 28 | from senn.models import ExpandableDense 29 | 30 | from tiny_imagenet import TinyImagenetDataset 31 | import tensorflow as tf 32 | 33 | 34 | dataset_kwargs = dict( 35 | name=wandb.config.dataset, 36 | download=True, 37 | as_supervised=True, 38 | data_dir=os.environ["DATASETS_ROOT_DIR"], 39 | ) 40 | 41 | 42 | def get_trainset(info=False): 43 | return tfds.load( 44 | **dataset_kwargs, 45 | split=wandb.config.train_split, 46 | with_info=info, 47 | shuffle_files=True, 48 | ) 49 | 50 | 51 | trainset, info = get_trainset(info=True) 52 | evalset_kwargs = dict() 53 | 54 | 55 | def get_evalset_dict(): 56 | evalset_dict = { 57 | split: tfds.load(**dataset_kwargs, **evalset_kwargs, split=split) 58 | for split in wandb.config.eval_splits 59 | } 60 | return evalset_dict 61 | 62 | 63 | evalset_dict = get_evalset_dict() 64 | if "num_classes" not in wandb.config: 65 | wandb.config.num_classes = info.features["label"].num_classes 66 | 67 | def resize_dataset(img, label): 68 | img = tf.image.resize_with_pad(img, *wandb.config.resize_to) 69 | return img, label 70 | if wandb.config.resize_to is not None: 71 | trainset = trainset.map(resize_dataset) 72 | evalset_dict = {key: val.map(resize_dataset) for key, val in evalset_dict.items()} 73 | 74 | nonlin = {"tanh": jnp.tanh, "swish": jax.nn.swish}[wandb.config.model_nonlinearity] 75 | model_types = dict( 76 | Perceptron=models.Perceptron, 77 | SmallConvNet=models.SmallConvNet, 78 | AllCnnA=models.AllCnnA, 79 | AllCnnC=models.AllCnnC, 80 | DenseNet=models.DenseNet, 81 | WaveletNet=models.WaveletNet, 82 | BottleneckDense=models.BottleneckDense, 83 | ExpandableDense=senn.models.ExpandableDense, 84 | WaveNet=models.WaveNet, 85 | ) 86 | module_def = model_types[wandb.config.model_type] 87 | 88 | 89 | def blockify_kwargs(widths, **kwargs): 90 | blocks = list(models.ExpandableBlock(widths=ws, **kwargs) for ws in widths) 91 | 
return dict(blocks=blocks) 92 | 93 | 94 | model_kwargs = wandb.config.model_kwargs 95 | if wandb.config.model_type == "ExpandableDense": 96 | 97 | def build_from_widths(widths=None): 98 | kwargs = frozen_dict.freeze(wandb.config.model_kwargs) 99 | kwargs, init_widths = kwargs.pop("widths") 100 | widths = init_widths if widths is None else widths 101 | return senn.models.ExpandableDense.build( 102 | widthss=widths, out=wandb.config.num_classes, nonlin=nonlin, **kwargs 103 | ) 104 | 105 | model = build_from_widths() 106 | print(model) 107 | else: 108 | model = module_def(out=wandb.config.num_classes, **model_kwargs, nonlin=nonlin) 109 | (example, _) = next(trainset.take(1).as_numpy_iterator()) 110 | example = jnp.array(example, dtype=jnp.float32) 111 | print(model.tabulate(PRNGKey(0), example)) 112 | variables = model.init(PRNGKey(wandb.config.model_seed), example) 113 | 114 | 115 | def has_kernel(pytree): 116 | return isinstance(pytree, frozen_dict.FrozenDict) and "kernel" in pytree.keys() 117 | 118 | 119 | optimizer = senn.opt.SimpleOpt() 120 | SCHEDULE_REPEAT = wandb.config.num_cycles 121 | PHASE_LEN = (wandb.config.epochs * wandb.config.epoch_len) // SCHEDULE_REPEAT 122 | schedule_kwargs = dict( 123 | transition_steps=PHASE_LEN, 124 | peak_value=wandb.config.peak_learning_rate, 125 | pct_start=wandb.config.pct_start, 126 | ) 127 | schedule_fn = ( 128 | optax.linear_onecycle_schedule 129 | if wandb.config.linear_annealing 130 | else optax.cosine_onecycle_schedule 131 | ) 132 | schedules = [schedule_fn(**schedule_kwargs) for i in range(SCHEDULE_REPEAT)] 133 | boundaries = [i * PHASE_LEN for i in range(SCHEDULE_REPEAT)][1:] 134 | schedule = optax.join_schedules(schedules, boundaries) 135 | #schedule = schedule_fn(**schedule_kwargs) 136 | # schedule = optax.cosine_onecycle_schedule(transition_steps=wandb.config.epochs*781, peak_value=wandb.config.peak_learning_rate, pct_start=wandb.config.pct_start) 137 | # schedule = optax.linear_onecycle_schedule(transition_steps=wandb.config.epochs*781, peak_value=wandb.config.peak_learning_rate, pct_start=wandb.config.pct_start) 138 | # first_order = optax.adam(learning_rate=schedule, b2=0.99) 139 | first_order = senn.opt.MyAdam( 140 | lr=schedule, 141 | mom1=1e-1, 142 | mom2=1e-2, 143 | weight_decay=wandb.config.weight_decay, 144 | noise_std=wandb.config.noise_std, 145 | order=0, 146 | ) 147 | if wandb.config.fast_turbo: 148 | first_order = optax.adamw(learning_rate=schedule, b2=0.99, weight_decay=1e-2) 149 | optimizer = senn.opt.WrappedFirstOrder(tx=first_order) 150 | 151 | apply_fn = model.apply 152 | add_width_fn = partial(model.apply, method=model.maybe_add_width, mutable=True) 153 | init_variables = variables 154 | if (order := wandb.config.taylor_order) is not None: 155 | # apply_fn = taylorify(apply_fn, basepoint=variables, order=order) 156 | def apply_fn(variables, *args, **kwargs): 157 | def inner(params, variables, *args, **kwargs): 158 | variables = frozen_dict.copy(variables, dict(params=params)) 159 | return model.apply(variables, *args, **kwargs) 160 | 161 | tinner = taylorify(inner, basepoint=init_variables["params"], order=order) 162 | params = variables["params"] 163 | params = frozen_dict.freeze(params) 164 | return tinner(params, variables, *args, **kwargs) 165 | 166 | 167 | initial_train_state = TrainState.create( 168 | optimizer, 169 | variables["params"], 170 | variables.get("probes", {}), 171 | apply_fn, 172 | batch_stats=variables.get("batch_stats", {}), 173 | dummy_input=example, 174 | # add_width_fn=add_width_fn, 175 | 
model=model, 176 | # traversal=ModelParamTraversal(lambda s: "/bud/" not in s), 177 | path_pred=lambda path: "bud" not in path, 178 | ) 179 | -------------------------------------------------------------------------------- /senn_cnn/experiments/transfer_tinyimagenet_fixed.py: -------------------------------------------------------------------------------- 1 | CHECKPOINT_DIR = "/senn/orbax/pretrained/final" 2 | 3 | from collections.abc import Callable 4 | from functools import partial 5 | from compose import compose 6 | from math import prod 7 | 8 | import jax 9 | import tensorflow as tf 10 | from jax import numpy as jnp 11 | import numpy as np 12 | from jax.random import PRNGKey 13 | from jax.tree_util import tree_map, tree_leaves 14 | from flax import struct, linen as nn 15 | from typing import Any 16 | from flax.core import frozen_dict 17 | from senn.opt import ( 18 | Task, 19 | TreeOpt, 20 | TrainState, 21 | step as opt_step, 22 | Stepper, 23 | softmax_grad_hgrad, 24 | universal_grad_hgrad, 25 | DiagOpt, 26 | ) 27 | import senn.opt 28 | from senn.neural import HPerturb, hperturb 29 | 30 | import tensorflow_datasets as tfds 31 | import sklearn 32 | from tensorflow_probability.substrates import jax as tfp 33 | import dm_pix as pix 34 | 35 | from time import time 36 | from tqdm import tqdm, trange 37 | import os 38 | import sys 39 | 40 | import wandb 41 | from rtpt import RTPT 42 | import copy 43 | 44 | import models 45 | 46 | from flax.training import orbax_utils 47 | import orbax.checkpoint 48 | from orbax.checkpoint import PyTreeCheckpointer, CheckpointManager, CheckpointManagerOptions 49 | 50 | calib_err_bins: Any = jax.nn.sigmoid(-1.0 * jnp.arange(-2, 12)) 51 | 52 | 53 | class CalibCount(struct.PyTreeNode): 54 | err: float 55 | correct: int 56 | incorrect: int 57 | 58 | @classmethod 59 | def count(cls, logits, labels): 60 | active_idx = jax.vmap(jnp.argmax)(logits) 61 | active_err = jax.vmap(lambda x: 1.0 - jnp.max(jax.nn.softmax(x)))(logits) 62 | correct = active_idx == labels 63 | incorrect = active_idx != labels 64 | bin_idx = jnp.digitize(active_err, calib_err_bins) 65 | idxs = jnp.arange(len(calib_err_bins) + 1) 66 | 67 | def count_idx(idx): 68 | where = bin_idx == idx 69 | return cls( 70 | jnp.sum(active_err, where=where), 71 | jnp.sum(correct, where=where), 72 | jnp.sum(incorrect, where=where), 73 | ) 74 | 75 | return jax.vmap(count_idx)(idxs) 76 | 77 | def plot(self, title="Calibration"): 78 | percentiles = [50, 5, 95] 79 | keys = list(map(lambda p: f"{p}%", percentiles)) 80 | xs = list(map(float, self.err / (self.correct + self.incorrect))) 81 | 82 | def make_ys(percentile): 83 | return list( 84 | map( 85 | float, 86 | tfp.math.betaincinv( 87 | 1.0 * self.incorrect, 1.0 * self.correct, percentile / 100.0 88 | ), 89 | ) 90 | ) 91 | 92 | ys = list(map(make_ys, percentiles)) 93 | # print(len(xs)) 94 | # print(xs) 95 | # print(len(ys)) 96 | # print(ys) 97 | return wandb.plot.line_series(xs=xs, ys=ys, keys=keys, title=title, xname="err") 98 | 99 | 100 | def label_filter(filt, *item): 101 | img, label = item 102 | return filt[label] 103 | 104 | 105 | def ymetrics(task, ys): 106 | loss = jnp.sum(task.loss(ys)) 107 | argmax = jax.vmap(jnp.argmax)(ys) 108 | ground_truth = jnp.mod(task.label, wandb.config.num_classes) 109 | correct = argmax == ground_truth 110 | acc = jnp.sum(correct) 111 | err = jnp.sum(~correct) 112 | return dict( 113 | loss=loss, accuracy=acc, err=err, calib=CalibCount.count(ys, task.label) 114 | ) 115 | 116 | 117 | def eval_batch(key, task, train_state): 118 | out 
= dict() 119 | ys = train_state.eval(task) 120 | out = dict(max_a_posteriori=ymetrics(task, ys)) 121 | if wandb.config.only_max_a_posteriori_eval or wandb.config.fast_turbo: 122 | return out 123 | mys = train_state.eval_marginalized(task, key, wandb.config.eval_samples) 124 | rys = train_state.eval(task, key=key) 125 | return dict(**out, marginalized=ymetrics(task, mys), sampled=ymetrics(task, rys)) 126 | #return dict( 127 | # max_a_posteriori=ymetrics(task, ys), 128 | # marginalized=ymetrics(task, mys), 129 | # sampled=ymetrics(task, rys), 130 | #) 131 | 132 | 133 | def eval_one_epoch(name, dataset, train_state): 134 | filters = wandb.config.eval_filters 135 | for fidx, filt in enumerate(filters): 136 | if filt is None: 137 | # fname, fdataset = name, dataset 138 | fname = name 139 | else: 140 | fname = f"{name}_F{int(fidx)}" 141 | # fdataset = dataset.filter(partial(label_filter, tf.convert_to_tensor(np.array(filt)))) 142 | _eval_one_epoch(fname, dataset, train_state, filt=filt) 143 | 144 | 145 | def filtered_and_size(dataset, filt=None): 146 | dataset_size = dataset.cardinality().numpy() 147 | if False and filt is not None: 148 | fdataset = dataset.filter( 149 | partial(label_filter, tf.convert_to_tensor(np.array(filt))) 150 | ) 151 | fsize = dataset_size * np.sum(np.array(filt)) / len(filt) 152 | else: 153 | fdataset, fsize = dataset, dataset_size 154 | return fdataset, fsize 155 | 156 | 157 | def _eval_one_epoch(name, dataset, train_state, filt=None): 158 | fdataset, dataset_size = filtered_and_size(dataset, filt) 159 | key = PRNGKey(0) 160 | batched = fdataset.batch(wandb.config.eval_batch_size, drop_remainder=True) 161 | tasks = as_task_iter(batched) 162 | metrics = None 163 | items_seen = 0 164 | for task in tasks: 165 | items_seen += len(task.label) 166 | new_metrics = eval_batch(key, task, train_state) 167 | metrics = ( 168 | new_metrics if metrics is None else tree_map(jnp.add, metrics, new_metrics) 169 | ) 170 | is_leaf = lambda x: isinstance(x, CalibCount) 171 | 172 | def finalize(leaf): 173 | if isinstance(leaf, CalibCount): 174 | return leaf.plot() 175 | else: 176 | return leaf / items_seen 177 | 178 | # metrics = tree_map(finalize, metrics, is_leaf=is_leaf) 179 | name_metrics = dict() 180 | for key, value in metrics.items(): 181 | calib_count = value.pop("calib") 182 | name_metrics.update( 183 | {key: tree_map(lambda a: a / items_seen, copy.deepcopy(value))} 184 | ) 185 | # wandb.log({name: {key: tree_map(lambda a: a/items_seen, copy.deepcopy(value))}}, commit=False) 186 | plot_name = f"{name}_{key}_calib" 187 | if wandb.config.log_calibration: 188 | wandb.log({plot_name: calib_count.plot(title=plot_name)}, commit=False) 189 | # metrics = tree_map(lambda a: a/items_seen, metrics) 190 | wandb.log({name: name_metrics}, commit=False) 191 | 192 | 193 | def eval_final(name, dataset, train_state): 194 | eval_one_epoch(name, dataset, train_state) 195 | 196 | 197 | def get_train_weight(): 198 | if wandb.config.burn_in_period is None: 199 | w = 1.0 / wandb.config.temperature 200 | else: 201 | epoch = wandb.run.summary.get("epoch", 0) 202 | # total_epochs = wandb.config.epochs 203 | # centered = epoch - total_epochs/2 204 | 205 | period = wandb.config.burn_in_period 206 | initial = wandb.config.burn_in_initial_weight 207 | lninit = jnp.log(initial) 208 | t = jnp.minimum(epoch, period) / period 209 | lnw = lninit * (1.0 - t) 210 | w = jnp.exp(lnw) / wandb.config.temperature 211 | # scaled = 2 * width * centered / (total_epochs) 212 | # w = jax.nn.sigmoid(scaled) / 
wandb.config.temperature 213 | wandb.log(dict(train_weight=w), commit=False) 214 | return w 215 | 216 | 217 | def augment_img(img, key): 218 | if wandb.config.timnet_aug: 219 | return tiny_imagenet_augment_img(img, key) 220 | flip_key, trans_key = jax.random.split(key) 221 | img = pix.random_flip_left_right(flip_key, img) 222 | trans = jax.random.randint(trans_key, shape=(2,), minval=-5, maxval=5) 223 | affine = jnp.identity(4) 224 | affine = affine.at[:2, -1].set(trans) 225 | # affine = jnp.concatenate((jnp.identity(2), trans[:,None]), axis=-1) 226 | img = pix.affine_transform(img, affine) 227 | return img 228 | 229 | def tiny_imagenet_augment_img(img, key): 230 | initial_shape = img.shape 231 | flip_key, crop_key, rotate_key, scale_key, color_key, shortcut_key = jax.random.split(key, 6) 232 | img = pix.random_flip_left_right(flip_key, img) 233 | only_flipped = img 234 | padto = wandb.config.augment_timnet_padtowidth 235 | img = pix.pad_to_size(img, padto, padto) 236 | max_radians = 2*jnp.pi / 18 237 | if wandb.config.augment_timnet_rotate: 238 | angle_key, rbool_key = jax.random.split(rotate_key) 239 | angle = jax.random.uniform(angle_key, minval=-max_radians, maxval=max_radians) 240 | do_rotation = jax.random.bernoulli(rbool_key, p=wandb.config.augment_timnet_rotate_prob) 241 | img = jnp.where(do_rotation, pix.rotate(img, angle), img) 242 | if wandb.config.augment_timnet_scale: 243 | scale_key, sbool_key = jax.random.split(scale_key) 244 | scale = jnp.exp(jax.random.uniform(scale_key, minval=jnp.log(0.8), maxval=jnp.log(1.2))) 245 | do_scale = jax.random.bernoulli(sbool_key, p=wandb.config.augment_timnet_scale_prob) 246 | scaled = pix.affine_transform(img, jnp.array([scale, scale, 1.])) 247 | img = jnp.where(do_scale, scaled, img) 248 | img = pix.random_crop(crop_key, img, initial_shape) 249 | 250 | if wandb.config.augment_color: 251 | sat_key, bri_key, con_key, cbool_key = jax.random.split(color_key, 4) 252 | cimg = img 253 | cimg = pix.random_saturation(sat_key, cimg, 0.8, 1.3) 254 | cimg = pix.random_brightness(bri_key, cimg, 0.3) 255 | cimg = pix.random_contrast(con_key, cimg, 0.8, 1.3) 256 | do_color = jax.random.bernoulli(cbool_key, p=wandb.config.augment_timnet_color_prob) 257 | img = jnp.where(do_color, cimg, img) 258 | if (shortcut_prob := wandb.config.augment_timnet_shortcut_prob) != 0.: 259 | do_shortcut = jax.random.bernoulli(shortcut_key, p=shortcut_prob) 260 | img = jnp.where(do_shortcut, only_flipped, img) 261 | return img 262 | 263 | 264 | @jax.jit 265 | def augment_task(task, key): 266 | imgs = task.x 267 | imgs = jax.vmap(augment_img)(imgs, jax.random.split(key, len(imgs))) 268 | return task.replace(x=imgs) 269 | 270 | 271 | def train_one_epoch(stepper, key, dataset, train_state, filt=None): 272 | fdataset, dataset_size = filtered_and_size(dataset, filt) 273 | tasks = as_task_iter(fdataset.batch(wandb.config.batch_size, drop_remainder=True)) 274 | epoch_weight = get_train_weight() 275 | 276 | def check(tree): 277 | return jax.tree_util.tree_all( 278 | tree_map(lambda arr: jnp.isfinite(arr).all(), tree) 279 | ) 280 | 281 | for step, task in tqdm(enumerate(tasks)): 282 | # if not check(train_state): 283 | # exit() 284 | key, opt_key, augment_key = jax.random.split(key, 3) 285 | if wandb.config.augment_data: 286 | task = augment_task(task, augment_key) 287 | weights = jnp.ones(len(task.label)) / len(task.label) 288 | if wandb.config.batch_upweighting: 289 | weights = weights * dataset_size 290 | weights = weights * epoch_weight 291 | new_train_state = stepper(train_state, 
task, weights, opt_key) 292 | if ( 293 | not wandb.config.frozen_burn_in 294 | or (wandb.config.burn_in_period is None) or wandb.run.summary.get("epoch", 0) > wandb.config.burn_in_period 295 | ): 296 | train_state = new_train_state 297 | else: 298 | train_state = new_train_state.replace(params=train_state.params) 299 | 300 | def active_inputs(p): 301 | return jnp.sum(jax.vmap(jnp.linalg.norm, in_axes=-2)(p) != 0.) 302 | def activity_status(param, state): 303 | actual = int(active_inputs(param)) 304 | mask = int(jnp.sum(state.curv[0].mask)) 305 | print(f"{param.shape}: mask {mask} actual {actual}") 306 | #tree_map(activity_status, train_state.params, train_state.opt_state) 307 | return train_state 308 | 309 | 310 | def lossfn(label, y): 311 | label = jnp.mod(label, wandb.config.num_classes) 312 | lsm = jax.nn.log_softmax(y) 313 | label_log_prob = jnp.take_along_axis(lsm, label[..., None], axis=-1)[..., 0] 314 | soften = wandb.config.soften_lossfn 315 | label_log_prob = (1.0 - soften) * label_log_prob + soften * jnp.mean(lsm, axis=-1) 316 | return -label_log_prob 317 | 318 | 319 | def norm_numpy_img(npimg): 320 | return (npimg - jnp.array(wandb.config.dataset_mean)) / jnp.array( 321 | wandb.config.dataset_std 322 | ) 323 | 324 | 325 | def as_task_iter(dataset): 326 | def as_task(item): 327 | xs, labels = item 328 | xs = norm_numpy_img(xs) 329 | return Task(xs, labels, lossfn) 330 | 331 | return map(as_task, tfds.as_numpy(dataset)) 332 | 333 | 334 | def init_prune(train_state): 335 | return train_state.init_prune() 336 | # params = train_state.tx.init_prune_params(train_state.params) 337 | # opt_state = train_state.tx.init_prune_opt_state( 338 | # train_state.params, train_state.opt_state 339 | # ) 340 | # return train_state.replace(params=params, opt_state=opt_state) 341 | 342 | def reset_output_layer(train_state): 343 | params = train_state.params 344 | kernel = params["final"]["kernel"] 345 | new_shape = kernel.shape[:-1] + (wandb.config.num_classes,) 346 | new_kernel = jnp.zeros(new_shape, dtype=kernel.dtype) 347 | train_state = train_state.replace( 348 | params=params.copy( 349 | dict(final=params["final"].copy( 350 | dict(kernel=new_kernel) 351 | )) 352 | ) 353 | ) 354 | return train_state.tx_reinit_changed_shapes() 355 | 356 | 357 | def main(): 358 | exp_name = sys.argv[1] if len(sys.argv) > 1 else None 359 | if exp_name is not None: exp_name = str(exp_name) 360 | wandb_project = "senn_transfer" 361 | if len(sys.argv) > 2 and sys.argv[2] == "resume": 362 | # interpret exp_name as a run id 363 | wandb.init(project=wandb_project, config={}, id=exp_name, resume="must") 364 | else: 365 | wandb.init(project=wandb_project, config={}, name=exp_name) 366 | 367 | wandb.config.epochs = 100 368 | wandb.config.epoch_len = 1562 369 | wandb.config.dataset = "TinyImagenetDataset" 370 | wandb.config.resize_to = [64, 64] 371 | wandb.config.dataset_mean = 0.5 * 255 372 | wandb.config.dataset_std = 0.25 * 255 373 | wandb.config.augment_data = True 374 | wandb.config.augment_color = True 375 | wandb.config.augment_timnet_shortcut_prob = 0.5 376 | wandb.config.augment_timnet_color_prob = 0.3 377 | wandb.config.augment_timnet_rotate = True 378 | wandb.config.augment_timnet_rotate_prob = 0.3 379 | wandb.config.augment_timnet_scale = True 380 | wandb.config.augment_timnet_scale_prob = 0.3 381 | wandb.config.augment_timnet_padtowidth = 128 382 | wandb.config.timnet_aug = True 383 | wandb.config.train_split = "train" 384 | all_pass = jnp.ones(shape=(10,), dtype=jnp.bool_) 385 | none_pass = ~all_pass 386 | 
label_filters = [ 387 | none_pass.at[:2].set(True), 388 | none_pass.at[2:4].set(True), 389 | none_pass.at[4:6].set(True), 390 | none_pass.at[6:8].set(True), 391 | none_pass.at[8:].set(True), 392 | ] 393 | label_filters = [all_pass] 394 | wandb.config.train_filters = label_filters 395 | wandb.config.eval_splits = ["train[:10%]", "validation"] 396 | #wandb.config.eval_splits = ["train[:10%]", "test"] 397 | wandb.config.eval_filters = [None] 398 | wandb.config.num_classes = 200 399 | #wandb.config.num_classes = 10 400 | # wandb.config.model_type = "SmallConvNet" 401 | # wandb.config.model_kwargs = dict(hidden=64, depth=3, multiplicity=1, use_bias=False, dense_size=64) 402 | # wandb.config.model_type = "Perceptron" 403 | # wandb.config.model_kwargs = dict(hidden=32, depth=1, use_bias=False, flatten_last_n=3) 404 | # wandb.config.model_type = "BasicResNet" 405 | # wandb.config.model_kwargs = dict(stage_sizes=[2, 2, 2]) 406 | # wandb.config.model_type = "AllCnnC" 407 | # wandb.config.model_kwargs = dict(early_channels=1*96, late_channels=1*192, fake_batch_norm=True) 408 | # wandb.config.model_notes = "deleted_conv_4" 409 | # wandb.config.model_type = "DenseNet" 410 | # wandb.config.model_kwargs = dict(depth=2, growth=2**7, init_conv=False) 411 | # wandb.config.model_type = "WaveletNet" 412 | # wandb.config.model_kwargs = dict(hidden=64, depth=3, level=1) 413 | # wandb.config.model_type = "BottleneckDense" 414 | # wandb.config.model_kwargs = dict(blocks=3, width=128, extra_depth=0, maybe_bud_width=16) 415 | # config is special cased for ExpandableDense 416 | wandb.config.model_type = "ExpandableDense" 417 | IW = 128 418 | W0 = 1 419 | WN = 1 420 | wandb.config.bud_width = None 421 | #wandb.config.model_kwargs = dict(widths=([IW]*W0,) + ([IW] * WN,) * 2, maybe_bud_width=wandb.config.bud_width) 422 | WLIST = [3, 3, 3, 3] 423 | widths = tuple([IW]*w for w in WLIST) 424 | #widths = ([459], [503, 120], [499, 28]) 425 | wandb.config.model_kwargs = dict(widths=widths, maybe_bud_width=wandb.config.bud_width) 426 | #wandb.config.model_type = "WaveNet" 427 | #wandb.config.model_kwargs = dict(widths=[32]*3, levels=4) 428 | 429 | wandb.config.use_batchnorm = False 430 | wandb.config.use_dropout = False 431 | wandb.config.batch_size = 64 432 | #wandb.config.batch_size = 1024 433 | wandb.config.batch_upweighting = True 434 | wandb.config.eval_batch_size = 64 435 | #wandb.config.eval_batch_size = 1024 436 | wandb.config.eval_samples = 16 437 | wandb.config.lr = 1e-3 438 | wandb.config.add_unit_normal_curvature = True 439 | wandb.config.grad_update_as_curvature = False 440 | wandb.config.grad_as_curvature = False 441 | wandb.config.root_of_grad_for_curvature = False 442 | wandb.config.grad_curvature_mul = 1e-1 443 | wandb.config.model_seed = 0 444 | wandb.config.train_seed = 0 445 | wandb.config.varinf = False 446 | wandb.config.paired_varinf = False 447 | wandb.config.hess_only_varinf = True 448 | wandb.config.per_example_varinf_grad = False 449 | wandb.config.per_example_varinf = False 450 | wandb.config.varinf_sample_scale = 1e0 451 | wandb.config.model_nonlinearity = "swish" 452 | wandb.config.curvature = "UBAH" 453 | wandb.config.ubah_mul = 1e0 454 | wandb.config.temperature = 1e1**-0.0 455 | wandb.config.burn_in_period = None 456 | wandb.config.burn_in_initial_weight = 1e-3 457 | wandb.config.frozen_burn_in = True 458 | # wandb.config.whitener = "Masked" 459 | # wandb.config.whitener_diag_fraction = 0e-1 460 | # wandb.config.initial_precision = 1e1 461 | wandb.config.soften_lossfn = 0e-1 462 | # 
wandb.config.init_prior_precision = 1e1 463 | wandb.config.log_calibration = False 464 | 465 | wandb.config.taylor_order = None 466 | 467 | # OPTIMIZER 468 | wandb.config.peak_learning_rate = 3e-4 469 | wandb.config.noise_std = 0e-4 470 | wandb.config.weight_decay = 1e-2 471 | wandb.config.pct_start = 0.1 472 | wandb.config.linear_annealing = False 473 | wandb.config.num_cycles = 1 474 | 475 | # SIZE ADAPTION 476 | wandb.config.freeze_thaw_disable = True 477 | wandb.config.freeze_is_prune = True 478 | wandb.config.expansion_lower_bound = True 479 | wandb.config.freeze_thresh = 0.0 480 | wandb.config.thaw_thresh = 0.0 481 | wandb.config.freeze_thresh_rel = 3e-4 482 | wandb.config.thaw_thresh_rel = 3e-4 483 | wandb.config.thaw_prob_size_compensate = True 484 | wandb.config.minimum_width = 8 485 | wandb.config.maximum_width = 512 486 | wandb.config.ignore_width = wandb.config.bud_width if wandb.config.bud_width is not None else 0 487 | # wandb.config.bud_width = 16 488 | wandb.config.init_prune_to_min_width = not wandb.config.freeze_thaw_disable 489 | wandb.config.expansion_min_step = ( 490 | 0.00 * wandb.config.epochs * wandb.config.epoch_len 491 | ) 492 | wandb.config.expansion_max_step = ( 493 | 1.0 * wandb.config.epochs * wandb.config.epoch_len 494 | ) 495 | wandb.config.pruned_lr_rescale = False 496 | 497 | wandb.config.untouched_thresh = 10 498 | wandb.config.reinit_prob = 1e-1 499 | wandb.config.add_width_thresh = 0.2 500 | wandb.config.add_width_factor = 0.2 501 | 502 | wandb.config.enable_reinit = False 503 | wandb.config.enable_width_expansion = False 504 | wandb.config.enable_depth_expansion = False 505 | wandb.config.depth_score_max_k = 64 506 | wandb.config.min_epoch_for_depth_expansion = 0 507 | wandb.config.depth_score_add_to_current_score = 1e0 508 | #wandb.config.depth_score_addition_thresh = 1e1 509 | wandb.config.depth_score_abs_thresh = 0e1 510 | wandb.config.depth_score_rel_thresh = wandb.config.depth_score_max_k * 1e0 * wandb.config.thaw_thresh_rel 511 | wandb.config.block_size_hard_cap = 4 512 | 513 | wandb.config.use_global_expansion_score = True 514 | wandb.config.global_score_is_max_not_sum = False 515 | 516 | wandb.config.only_max_a_posteriori_eval = True 517 | wandb.config.iroot_error_warn = False 518 | # Minimum computation necessary for standard baseline training: 519 | wandb.config.fast = False 520 | if wandb.config.fast: 521 | assert wandb.config.curvature == "EMP_FISH" 522 | assert wandb.config.freeze_thaw_disable 523 | assert not wandb.config.enable_width_expansion 524 | assert not wandb.config.enable_depth_expansion 525 | assert wandb.config.bud_width is None 526 | assert not wandb.config.enable_reinit 527 | wandb.config.fast_turbo = False 528 | if wandb.config.fast_turbo: 529 | assert wandb.config.fast 530 | assert not wandb.config.varinf 531 | assert not wandb.config.log_calibration 532 | 533 | # Load previously saved 'final' as initial state 534 | #wandb.config.load_as_init = "/senn/orbax/cifar10_128_final" 535 | #wandb.config.load_as_init = "/senn/orbax/lif1r5r6/final" 536 | wandb.config.load_as_init = CHECKPOINT_DIR 537 | 538 | rtpt_initials = os.environ.get("RTPT_INITIALS") 539 | assert rtpt_initials is not None 540 | total_epochs = wandb.config.epochs * len(wandb.config.train_filters) 541 | rtpt = RTPT( 542 | name_initials=rtpt_initials, 543 | experiment_name=wandb.run.name, 544 | max_iterations=total_epochs, 545 | ) 546 | 547 | import config 548 | 549 | trainset = config.trainset 550 | evalset_dict = config.evalset_dict 551 | 552 | stepper = Stepper( 
553 | wandb.config.varinf, 554 | wandb.config.paired_varinf, 555 | wandb.config.hess_only_varinf, 556 | wandb.config.per_example_varinf_grad, 557 | wandb.config.per_example_varinf, 558 | wandb.config.varinf_sample_scale, 559 | grad_hgrad=universal_grad_hgrad, 560 | ) 561 | train_state = config.initial_train_state 562 | 563 | if wandb.config.init_prune_to_min_width: 564 | train_state = train_state.init_prune() 565 | # train_state = init_prune(train_state) 566 | 567 | run_id = wandb.run.id 568 | checkpoint_dir = f"/tmp/flax_ckpt/orbax/{run_id}/managed" 569 | final_checkpoint_dir = f"/senn/orbax/{run_id}/final" 570 | checkpointer = PyTreeCheckpointer() 571 | checkpoint_manager_opts = CheckpointManagerOptions( 572 | max_to_keep=2, 573 | create=True, 574 | ) 575 | checkpoint_manager = CheckpointManager( 576 | checkpoint_dir, 577 | checkpointer, 578 | checkpoint_manager_opts, 579 | ) 580 | #ckpt_save_args = orbax_utils.save_args_from_target(train_state) 581 | def final_save(state): 582 | ckpt_save_args = orbax_utils.save_args_from_target(state) 583 | checkpointer.save(final_checkpoint_dir, state, save_args=ckpt_save_args) 584 | def epoch_save(epoch, state): 585 | ckpt_save_args = orbax_utils.save_args_from_target(state) 586 | checkpoint_manager.save(epoch, state, save_kwargs=dict(save_args=ckpt_save_args)) 587 | 588 | if (load_dir := wandb.config.load_as_init) is not None: 589 | train_state = checkpointer.restore(load_dir, item=train_state) 590 | train_state = reset_output_layer(train_state) 591 | 592 | latest_ckpt_step = checkpoint_manager.latest_step() 593 | if latest_ckpt_step is not None: 594 | train_state = checkpoint_manager.restore(latest_ckpt_step, items=train_state) 595 | 596 | key = PRNGKey(wandb.config.train_seed) 597 | rtpt.start() 598 | def log_metrics(epoch): 599 | metrics = train_state.get_metrics() 600 | metrics = tree_map(np.array, metrics).unfreeze() 601 | wandb.log(metrics, commit=False) 602 | for name, evalset in evalset_dict.items(): 603 | eval_one_epoch(name, evalset, train_state) 604 | wandb.log(dict(epoch=epoch), commit=True) 605 | if checkpoint_manager.latest_step() is None: 606 | log_metrics(-1) 607 | for fidx, filt in enumerate(wandb.config.train_filters): 608 | for epoch, key in enumerate(jax.random.split(key, wandb.config.epochs)): 609 | if (ckpt_step := checkpoint_manager.latest_step()) is not None: 610 | if epoch <= ckpt_step: continue 611 | reinit_key, width_key, key = jax.random.split(key, 3) 612 | #metrics = train_state.get_metrics() 613 | #metrics = tree_map(np.array, metrics).unfreeze() 614 | #wandb.log(metrics, commit=False) 615 | #for name, evalset in evalset_dict.items(): 616 | # eval_one_epoch(name, evalset, train_state) 617 | if wandb.config.enable_reinit: 618 | train_state = train_state.maybe_reinit(reinit_key) 619 | train_state = train_one_epoch( 620 | stepper, key, trainset, train_state, filt=filt 621 | ) 622 | log_metrics(epoch) 623 | #wandb.log(dict(epoch=epoch), commit=True) 624 | rtpt.step() 625 | 626 | epoch_save(epoch, train_state) 627 | 628 | old_train_state = train_state 629 | if wandb.config.enable_width_expansion: 630 | train_state = train_state.maybe_expand_width( 631 | key=width_key, builder=config.build_from_widths 632 | ) 633 | width_was_added = tree_map(jnp.shape, train_state) != tree_map(jnp.shape, old_train_state) 634 | if wandb.config.enable_depth_expansion and not width_was_added: 635 | if epoch >= wandb.config.min_epoch_for_depth_expansion: 636 | train_state = train_state.maybe_insert_layer( 637 | builder=config.build_from_widths 
638 | ) 639 | final_save(train_state) 640 | train_state = train_state.pin_prior() 641 | for name, evalset in evalset_dict.items(): 642 | eval_final(name, evalset, train_state) 643 | eval_final("train_full", trainset, train_state) 644 | wandb.log(dict(epoch=epoch + 1), commit=True) 645 | 646 | 647 | if __name__ == "__main__": 648 | main() 649 | -------------------------------------------------------------------------------- /senn_cnn/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "senn" 7 | description = "Dynamically sized neural networks for JAX/Flax" 8 | authors = [ 9 | {name = "wittnus", email = "mail@wittnus.me"}, 10 | ] 11 | dependencies = [ 12 | "jax>=0.4.13", 13 | "tensorflow-probability==0.20.1", 14 | "compose>=1.4.8", 15 | "scikit-learn>=1.3.0", 16 | "torch>=2.0.1", 17 | "torchvision>=0.15.2", 18 | "tensorflow-datasets>=4.9.2", 19 | "tqdm>=4.65.0", 20 | "wandb>=0.15.5", 21 | "rtpt>=0.0.4", 22 | "absl-py>=1.4.0", 23 | "dm-pix>=0.4.1", 24 | "jaxwt>=0.1.0", 25 | "PyWavelets>=1.4.1", 26 | "dtcwt>=0.12.0", 27 | "tiny-imagenet-tfds @ git+https://github.com/rmenzenbach/tiny-imagenet-tfds.git", 28 | "flax==0.7.0", 29 | "etils[epath]>=1.3.0", 30 | "optax==0.1.7" 31 | ] 32 | requires-python = ">=3.8" 33 | readme = "README.md" 34 | license = {text = "MIT"} 35 | classifiers = [ 36 | "Programming Language :: Python :: 3", 37 | "License :: OSI Approved :: MIT License", 38 | "Operating System :: OS Independent", 39 | ] 40 | dynamic = ["version"] 41 | 42 | [project.optional-dependencies] 43 | cuda = [ 44 | ] 45 | [tool.pdm] 46 | version = {use_scm = true} 47 | 48 | [tool.pdm.dev-dependencies] 49 | test = [ 50 | "pytest>=7.3.1", 51 | ] 52 | dev = [ 53 | "black>=23.3.0", 54 | ] 55 | 56 | [[tool.pdm.source]] 57 | url = "https://storage.googleapis.com/jax-releases/jax_cuda_releases.html" 58 | name = "jax_cuda_releases" 59 | type = "find_links" 60 | 61 | [[tool.pdm.source]] 62 | url = "https://download.pytorch.org/whl/cpu" 63 | name = "pytorch" 64 | type = "find_links" 65 | 66 | [tool.pdm.scripts] 67 | pre_example = "docker compose up --no-recreate -d" 68 | example.cmd = "docker compose exec -e CUDA_VISIBLE_DEVICES -e JAX_PLATFORMS -w /senn/examples main python" 69 | test.cmd = "docker compose exec -e CUDA_VISIBLE_DEVICES='' -e JAX_PLATFORMS=cpu -w /senn main python -m pytest" 70 | exec.cmd = "docker compose exec -e CUDA_VISIBLE_DEVICES -e JAX_PLATFORMS main python" 71 | black.cmd = "docker compose exec -w /senn main python -m black ." 
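# Usage sketch for the scripts above (assumptions: the container from docker-compose is already
# running via `docker compose up -d`, and pdm forwards extra arguments to `cmd` scripts), e.g.:
#   pdm run exec experiments/senn_cifar10_onecycle.py   # run an experiment inside the container
#   pdm run test                                        # run the test suite on CPU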
72 | 73 | [tool.pytest.ini_options] 74 | addopts = [ 75 | "--import-mode=importlib", 76 | ] 77 | [tool.setuptools.packages.find] 78 | exclude = ["orbax"] 79 | -------------------------------------------------------------------------------- /senn_cnn/senn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/self-expanding-neural-networks/3480be01cbbfa46726af3a84dd9fb834d1ca979e/senn_cnn/senn/__init__.py -------------------------------------------------------------------------------- /senn_cnn/senn/dummy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import flax 3 | 4 | 5 | def hi(): 6 | print(np.ones((3, 2))) 7 | print(flax.__version__) 8 | print("hi") 9 | -------------------------------------------------------------------------------- /senn_cnn/senn/linalg.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from compose import compose 3 | from abc import ABC, abstractmethod 4 | from typing import Any 5 | 6 | import jax 7 | import jax.numpy as jnp 8 | from tensorflow_probability.substrates import jax as tfp 9 | from flax import struct 10 | 11 | import wandb 12 | 13 | 14 | def direct_update(M, v, multiplier=1.0): 15 | return M + jnp.outer(v, v) * multiplier 16 | 17 | 18 | def inv_update(M, v, multiplier=1.0, soln=None): 19 | EPS = 1e-12 20 | Mv = M @ v 21 | denom = 1 + jnp.inner(v, Mv.conj()) * multiplier 22 | new_mult = -multiplier / (denom + EPS) 23 | if soln is None: 24 | return direct_update(M, Mv, multiplier=new_mult) 25 | else: 26 | delta_soln = Mv @ soln 27 | delta_soln = jnp.outer(Mv, delta_soln) * new_mult 28 | new_soln = soln + delta_soln 29 | return direct_update(M, Mv, multiplier=new_mult), new_soln 30 | 31 | 32 | def chol_update(M, v, multiplier=1.0): 33 | return tfp.math.cholesky_update(M, v, multiplier=multiplier) 34 | 35 | 36 | def ichol_update(M, v, multiplier=1.0): 37 | MTv = M.T @ v 38 | denom = 1 + jnp.inner(MTv, MTv.conj()) * multiplier 39 | new_mult = -multiplier / denom 40 | return chol_update(M, M @ MTv, multiplier=new_mult) 41 | 42 | 43 | class SecondMoment(struct.PyTreeNode): 44 | direct: jax.Array 45 | inv: jax.Array 46 | chol: jax.Array 47 | ichol: jax.Array 48 | 49 | def scale_by(self, scale): 50 | return self.replace( 51 | direct=self.direct * scale, 52 | inv=self.inv / scale, 53 | chol=self.chol * jnp.sqrt(scale), 54 | ichol=self.ichol / jnp.sqrt(scale), 55 | ) 56 | 57 | def rank_one_update(self, v, multiplier=1.0, decay=None, soln=None): 58 | scale = 1.0 if decay is None else decay 59 | multiplier = multiplier if decay is None else (1.0 - decay) * multiplier 60 | inv, soln = inv_update(self.inv / scale, v, multiplier=multiplier, soln=soln) 61 | newmom = SecondMoment( 62 | direct_update(self.direct * scale, v, multiplier=multiplier), 63 | inv, 64 | chol_update(self.chol * jnp.sqrt(scale), v, multiplier=multiplier), 65 | ichol_update(self.ichol / jnp.sqrt(scale), v, multiplier=multiplier), 66 | ) 67 | return newmom if soln is None else newmom, soln 68 | 69 | def init_identity(size): 70 | return SecondMoment(*(jnp.identity(size),) * 4) 71 | 72 | 73 | class Whitener(struct.PyTreeNode): 74 | iroot: Any 75 | eps: float = 1e-12 76 | mag_cap: float = 1e3 77 | 78 | # for compatibility with senn.opt.HessTracker 79 | def init(self, x, initial_precision=1.0): 80 | return self.init_identity(x.size).rescale(initial_precision**2) 81 | 82 | @classmethod 83 
| def init_identity(cls, size, **kwargs): 84 | return cls(iroot=jnp.identity(size), **kwargs) 85 | 86 | def rescale(self, factor): 87 | # trace = jnp.sum(jnp.square(jnp.abs(self.iroot)), axis=-2) 88 | new_iroot = self.iroot * jnp.reciprocal(jnp.sqrt(factor)) 89 | # mag_cap = wandb.config.get("inv_root_curvature_diag_elem_cap", self.mag_cap) 90 | # new_iroot = jnp.where(trace[None,:] < self.mag_cap, self.iroot, new_iroot) 91 | return self.replace(iroot=new_iroot) 92 | # return self.replace(iroot=self.iroot*jnp.reciprocal(jnp.sqrt(factor))) 93 | 94 | def trace_inv(self): 95 | return jnp.sum(jnp.square(jnp.abs(self.iroot))) 96 | 97 | def diag_inv(self): 98 | return jnp.sum(jnp.square(jnp.abs(self.iroot)), axis=-1) 99 | 100 | def iroot_mul(self, tangents): 101 | return tangents @ self.iroot.T 102 | 103 | def whiten(self, tangents): 104 | return tangents @ self.iroot 105 | 106 | def solve(self, tangents): 107 | return self.iroot_mul(self.whiten(tangents)) 108 | 109 | def w_solve(self, whites): 110 | return self.iroot_mul(whites) 111 | 112 | def w_rank_n_update(self, whites): 113 | def update(carry, white): 114 | return carry.w_rank_one_update(white), None 115 | 116 | out, _ = jax.lax.scan(update, self, whites) 117 | return out 118 | 119 | def rank_n_update(self, vecs): 120 | return self.w_rank_n_update(self.whiten(vecs)) 121 | 122 | @abstractmethod 123 | def w_rank_one_update(self, white): 124 | raise NotImplementedError 125 | 126 | def rank_one_update(self, vec): 127 | return self.w_rank_one_update(self.whiten(vec)) 128 | 129 | 130 | class LinearWhitener(Whitener): 131 | def multiplier(self, white_sqnorm): 132 | x = -white_sqnorm * jnp.reciprocal(1.0 + white_sqnorm) 133 | x = jnp.expm1(0.5 * jnp.log1p(jnp.maximum(-1.0, x))) 134 | x = x * jnp.reciprocal(self.eps + white_sqnorm) 135 | jax.lax.cond( 136 | jnp.isfinite(x).all(), 137 | lambda a: None, 138 | lambda a: jax.debug.print("multiplier error with sqnorm {}", white_sqnorm), 139 | x, 140 | ) 141 | return x 142 | 143 | @abstractmethod 144 | def _get_factor(self, whites): 145 | pass 146 | 147 | def _iroot_update(self, factor): 148 | identity = jnp.identity(self.iroot.shape[-1]) 149 | # new_iroot = self.iroot @ (identity + factor) 150 | new_iroot = self.iroot + self.iroot @ factor 151 | return self.replace(iroot=new_iroot) 152 | 153 | def w_rank_n_update(self, whites): 154 | return self._iroot_update(self._get_factor(whites)) 155 | 156 | def w_rank_one_update(self, white): 157 | return self.w_rank_n_update(white[None, ...]) 158 | 159 | def w_rank_n_inv_update(self, whites): 160 | wmag = jnp.sum(jnp.square(jnp.abs(whites))) 161 | # jax.debug.print("inv_update wmag: {}", wmag) 162 | # jax.lax.cond( 163 | # jnp.isfinite(wmag), 164 | # lambda x: None, 165 | # lambda x: jax.debug.breakpoint(), 166 | # None) 167 | 168 | mul = jnp.expm1(0.5 * jnp.log1p(wmag)) 169 | mul = mul * jnp.reciprocal(wmag + self.eps) 170 | factor = mul * whites.T @ whites 171 | # jax.lax.cond(jnp.isfinite(factor).all(), lambda x: None, lambda x: jax.debug.breakpoint(), 0) 172 | return self._iroot_update(factor) 173 | 174 | 175 | class IRootWhitener(LinearWhitener): 176 | def _get_factor(self, whites): 177 | wmag = jnp.sum(jnp.inner(whites, whites)) 178 | return self.multiplier(wmag) * whites.T @ whites 179 | 180 | 181 | class DiagWhitener(LinearWhitener): 182 | def _get_factor(self, whites): 183 | mags = jax.vmap(lambda w: jnp.sum(jnp.inner(w, w)), in_axes=-1, out_axes=-1)( 184 | whites 185 | ) 186 | return jnp.diag(self.multiplier(mags) * mags) 187 | 188 | 189 | class 
HybridWhitener(IRootWhitener, DiagWhitener): 190 | diag_fraction: float = 0.8 191 | 192 | def check_iroot_finite(self, note): 193 | jax.lax.cond( 194 | jnp.isfinite(self.iroot).all(), 195 | lambda x: None, 196 | lambda x: jax.debug.print(note), 197 | 0, 198 | ) 199 | 200 | def w_rank_n_update(self, whites): 201 | # FOR DEBUGGING: 202 | # return self._iroot_update(IRootWhitener._get_factor(self, whites)) 203 | # END DEBUGGING 204 | 205 | out = self 206 | diag_whites = jnp.sqrt(self.diag_fraction) * whites 207 | diag_factor = DiagWhitener._get_factor(self, diag_whites) 208 | out = out._iroot_update(diag_factor) 209 | out.check_iroot_finite("nonfinite iroot after diag update") 210 | identity = jnp.identity(self.iroot.shape[-1]) 211 | # iroot_whites = jnp.sqrt(1. - self.diag_fraction)*whites@(identity + diag_factor) 212 | iroot_whites = jnp.sqrt(1.0 - self.diag_fraction) * whites 213 | iroot_whites = iroot_whites + iroot_whites @ diag_factor 214 | iroot_factor = IRootWhitener._get_factor(self, iroot_whites) 215 | out = out._iroot_update(iroot_factor) 216 | out.check_iroot_finite("nonfinite iroot after non-diag update") 217 | return out 218 | 219 | 220 | class MaskedWhitener(HybridWhitener): 221 | """Important: do not call w_rank_one_update or w_rank_n_update since this class needs access to unwhitened updates.""" 222 | 223 | direct: Any = None 224 | mask: Any = None 225 | 226 | def rescale(self, factor): 227 | out = super().rescale(factor) 228 | return out.replace(direct=out.direct * factor) 229 | 230 | def _direct_from_iroot(self): 231 | return jnp.linalg.inv(self.iroot @ self.iroot.T) 232 | 233 | def _recompute_direct(self): 234 | return self.replace(direct=self._direct_from_iroot()) 235 | 236 | def init(self, x, *args, **kwargs): 237 | out = super().init(x, *args, **kwargs) 238 | out = out.replace(mask=jnp.ones(x.size, dtype=jnp.bool_)) 239 | return out._recompute_direct() 240 | 241 | @classmethod 242 | def init_identity(cls, size, **kwargs): 243 | kwargs = {"mask": jnp.ones(size, dtype=jnp.bool_), **kwargs} 244 | out = super(MaskedWhitener, cls).init_identity(size, **kwargs) 245 | return out._recompute_direct() 246 | 247 | def rank_one_update(self, vec): 248 | return self.rank_n_update(vec[None, ...]) 249 | 250 | def rank_n_update(self, vecs): 251 | out = super().rank_n_update(vecs) 252 | delta_direct = vecs.T @ vecs 253 | # adjust for diag_fraction to make updates consistent with direct 254 | delta_direct = ( 255 | delta_direct * (1.0 - self.diag_fraction) 256 | + jnp.diag(jnp.diag(delta_direct)) * self.diag_fraction 257 | ) 258 | out = out.replace(direct=out.direct + delta_direct) 259 | jax.lax.cond( 260 | jnp.isfinite(self.iroot).all(), 261 | lambda x: None, 262 | lambda x: jax.debug.print("nonfinite iroot"), 263 | 0, 264 | ) 265 | return out.maybe_reset_iroot("rank_n_update") 266 | 267 | def direct_mul(self, vecs): 268 | return vecs @ self.direct 269 | 270 | def reset_iroot(self): 271 | D = len(self.direct) 272 | mdirect = jnp.where( 273 | self.mask[:, None] & self.mask, self.direct, jnp.identity(D) 274 | ) 275 | chol = jnp.linalg.cholesky(mdirect) 276 | chol = jnp.where(self.mask[:, None] & self.mask, chol, jnp.zeros_like(chol)) 277 | iroot = jnp.linalg.pinv(chol.T) 278 | # pinv = jnp.linalg.pinv(self.direct) 279 | # iroot = jnp.linalg.cholesky(pinv) 280 | return self.replace(iroot=iroot) 281 | 282 | def maybe_reset_iroot(self, where): 283 | D = len(self.iroot) 284 | error = self.iroot.T @ self.direct @ self.iroot - jnp.identity(D) 285 | error_norm = jnp.max(jnp.abs(error)) 286 | 
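        # Recompute iroot from `direct` when the maintained inverse root has drifted
        # too far from consistency with it, or has become non-finite.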
should_reset = (error_norm > 1e0) | (~jnp.isfinite(error_norm).all()) 287 | 288 | def do_reset(error_norm): 289 | if wandb.config.iroot_error_warn: 290 | jax.debug.print( 291 | "iroot error norm reached {} in " + where + ", recalculating...", 292 | error_norm, 293 | ) 294 | # jax.debug.print("iroot finite: {}", jnp.isfinite(self.iroot).all()) 295 | # jax.debug.print("direct finite: {}", jnp.isfinite(self.direct).all()) 296 | # jax.debug.print("error finite: {}", jnp.isfinite(error).all()) 297 | # jax.debug.print("direct: {}", self.direct) 298 | return self.reset_iroot() 299 | 300 | def do_nothing(error_norm): 301 | return self 302 | 303 | return jax.lax.cond(should_reset, do_reset, do_nothing, error_norm) 304 | 305 | def kill_latent(self, idx): 306 | # jax.debug.print("iroot sandwich {}",self.iroot.T @ self.direct @ self.iroot) 307 | # remove redundant column from iroot at idx 308 | true_at_idx = jnp.arange(len(self.mask)) == idx 309 | ir_col = jax.lax.dynamic_index_in_dim(self.iroot, idx, axis=-1) 310 | # could substitute lines below with a preconditioned solve? 311 | Dmul = self.direct @ ir_col 312 | Dmul = jnp.where(self.mask[:, None], Dmul, 0.0) 313 | mag = jnp.sum(ir_col.T @ Dmul) 314 | scale = jnp.reciprocal(jnp.maximum(1e-3, 1.0 - mag)) 315 | # scale = 1e-6 316 | new_iroot = jnp.where(true_at_idx[None, :], 0.0, self.iroot) 317 | # new_var = new_iroot.T @ Dmul * scale 318 | precon = new_iroot.T @ (self.direct + Dmul @ Dmul.T * scale) 319 | # new_var = jnp.linalg.lstsq(new_iroot, ir_col)[0] 320 | new_var = jax.scipy.sparse.linalg.gmres( 321 | new_iroot, 322 | ir_col, 323 | M=precon, 324 | maxiter=1, 325 | restart=20, 326 | solve_method="incremental", 327 | )[0] 328 | return self.replace(iroot=new_iroot).w_rank_n_inv_update(new_var.T) 329 | 330 | def freeze(self, idx): 331 | # jax.debug.print("pre-freeze iroot sandwich {}",jnp.diag(self.iroot.T @ self.direct @ self.iroot)) 332 | ir_row = jax.lax.dynamic_index_in_dim(self.iroot, idx) # [1, N] 333 | ir_row_mag = jnp.sum(jnp.square(jnp.abs(ir_row))) 334 | normed = ir_row / (jnp.sqrt(ir_row_mag + self.eps)) 335 | delta = (self.iroot @ normed.T) @ normed 336 | true_at_idx = jnp.arange(len(self.mask)) == idx 337 | new_iroot = self.iroot - delta 338 | new_iroot = jnp.where(true_at_idx[:, None], 0.0, new_iroot) 339 | # STILL NEED TO REINSERT VARIANCE FROM FOLLOWING LINE 340 | # extra_var = jnp.where(true_at_idx, 0., ir_row) 341 | # THIS PROBABLY IS WRONG: 342 | # extra_wvar = jnp.where(true_at_idx, 0., self.direct @ ir_row) 343 | # new_iroot = jnp.where(true_at_idx[None, :], 0., new_iroot) 344 | new_mask = self.mask & ~true_at_idx 345 | new_whitener = self.replace(iroot=new_iroot, mask=new_mask) 346 | jax.lax.cond( 347 | jnp.isfinite(new_iroot).all(), 348 | lambda x: None, 349 | lambda x: jax.debug.print("freeze error"), 350 | 0, 351 | ) 352 | # remove redundant column of self.iroot 353 | new_whitener = new_whitener.kill_latent(idx) 354 | jax.lax.cond( 355 | jnp.isfinite(new_whitener.iroot).all(), 356 | lambda x: None, 357 | lambda x: jax.debug.print("kill latent error"), 358 | 0, 359 | ) 360 | return new_whitener.maybe_reset_iroot("freeze") 361 | 362 | def freeze_many(self, where): 363 | cond_fn = lambda tup: jnp.any(tup[0]) 364 | 365 | def body_fn(tup): 366 | where, out = tup 367 | idx = jnp.argmax(where) 368 | out = out.freeze(idx) 369 | where = where & ~(jnp.arange(len(where)) == idx) 370 | return where, out 371 | 372 | where, out = jax.lax.while_loop(cond_fn, body_fn, (where, self)) 373 | return out 374 | 375 | def thaw(self, idx): 376 | 
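        # Re-activate a previously frozen coordinate: rebuild its column of the
        # inverse root from `direct`, rescale it, and mark the coordinate as
        # active again in the mask.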
dir_row = jax.lax.dynamic_index_in_dim(self.direct, idx) 377 | true_at_idx = jnp.arange(len(self.mask)) == idx 378 | new_col = -self.iroot @ (self.iroot.T @ dir_row.T) 379 | new_col = jnp.where(true_at_idx[:, None], 1.0, new_col) 380 | dir_elem = jax.lax.dynamic_index_in_dim(dir_row, idx, axis=1) 381 | 382 | dir_elem = jnp.maximum(dir_elem + dir_row @ new_col, 1e-3 * dir_elem) 383 | dir_elem = jnp.maximum(dir_elem, self.eps) 384 | 385 | col_scale = jnp.reciprocal(jnp.sqrt(jnp.abs(dir_elem))) 386 | new_iroot = jnp.where(true_at_idx[None, :], new_col * col_scale, self.iroot) 387 | 388 | # iroot_row = jnp.reciprocal(jnp.sqrt(jnp.abs(dir_row)) + self.eps) 389 | # only_diag = jnp.where(true_at_idx[None, :], iroot_row, jnp.zeros_like(iroot_row)) 390 | # new_iroot = jnp.where(true_at_idx[:, None], only_diag, self.iroot) 391 | new_mask = self.mask | true_at_idx 392 | # jax.lax.cond(jnp.isfinite(new_iroot).all(), lambda x: None, lambda x: jax.debug.breakpoint(), 0) 393 | jax.lax.cond( 394 | jnp.isfinite(new_iroot).all(), 395 | lambda x: None, 396 | lambda x: jax.debug.print("thaw error"), 397 | 0, 398 | ) 399 | health = jnp.max(jnp.diag(new_iroot.T @ self.direct @ new_iroot)) 400 | jax.lax.cond( 401 | health < 1e3, 402 | lambda x: None, 403 | lambda x: jax.debug.print("WARN: thaw health"), 404 | 0, 405 | ) 406 | return self.replace(iroot=new_iroot, mask=new_mask).maybe_reset_iroot("thaw") 407 | 408 | def thaw_many(self, where): 409 | cond_fn = lambda tup: jnp.any(tup[0]) 410 | 411 | def body_fn(tup): 412 | where, out = tup 413 | idx = jnp.argmax(where) 414 | out = out.thaw(idx) 415 | where = where & ~(jnp.arange(len(where)) == idx) 416 | return where, out 417 | 418 | where, out = jax.lax.while_loop(cond_fn, body_fn, (where, self)) 419 | return out 420 | 421 | @jax.jit 422 | def gmres_solve(self, vecs): 423 | subvecs = jnp.where(self.mask, vecs, 0.0) 424 | subdirect = jnp.where(self.mask[:, None] & self.mask, self.direct, 0.0) 425 | precon = self.iroot @ self.iroot.T 426 | solns = jax.scipy.sparse.linalg.gmres( 427 | subdirect, 428 | subvecs.T, 429 | M=precon, 430 | maxiter=1, 431 | restart=20, 432 | solve_method="batched", 433 | )[0].T 434 | return solns 435 | 436 | @jax.jit 437 | def cg_solve(self, vecs): 438 | subvecs = jnp.where(self.mask, vecs, 0.0) 439 | subdirect = jnp.where(self.mask[:, None] & self.mask, self.direct, 0.0) 440 | precon = self.iroot @ self.iroot.T 441 | solns = jax.scipy.sparse.linalg.cg( 442 | subdirect, 443 | subvecs.T, 444 | M=precon, 445 | maxiter=20, 446 | )[0].T 447 | return solns 448 | 449 | @jax.jit 450 | def cg_project(self, vecs): 451 | vecs = vecs @ self.direct 452 | return self.cg_solve(vecs) 453 | 454 | def freeze_prune_thaw_scores(self, grads, params, ngrad=None): 455 | if ngrad is None: 456 | ngrad = grads @ self.iroot @ self.iroot.T 457 | frz_scaling = jnp.reciprocal( 458 | jnp.sqrt(jnp.maximum(self.eps, jnp.abs(self.diag_inv()))) 459 | ) 460 | if wandb.config.expansion_lower_bound: 461 | frz_scaling = jnp.sqrt(jnp.diag(self.direct)) 462 | tha_center = jnp.abs(jnp.diag(self.direct)) 463 | if not wandb.config.expansion_lower_bound: 464 | tha_center_ = jnp.abs( 465 | tha_center - tha_center * self.diag_inv() * tha_center 466 | ) 467 | tha_center = jnp.maximum(0e-2 * tha_center, tha_center_) 468 | tha_scaling = jnp.reciprocal(jnp.sqrt(jnp.maximum(self.eps, tha_center))) 469 | root_scores = ( 470 | ngrad * frz_scaling, 471 | (ngrad - params) * frz_scaling, 472 | (grads - ngrad @ self.direct) * tha_scaling, 473 | ) 474 | scores = 
tuple(jnp.sum(jnp.square(jnp.abs(rs)), axis=0) for rs in root_scores) 475 | return scores 476 | 477 | def thaw(self, idx): 478 | dir_row = jax.lax.dynamic_index_in_dim(self.direct, idx) 479 | true_at_idx = jnp.arange(len(self.mask)) == idx 480 | new_col = -self.iroot @ (self.iroot.T @ dir_row.T) 481 | new_col = jnp.where(true_at_idx[:, None], 1.0, new_col) 482 | dir_elem = jax.lax.dynamic_index_in_dim(dir_row, idx, axis=1) 483 | 484 | dir_elem = jnp.maximum(dir_elem + dir_row @ new_col, 1e-3 * dir_elem) 485 | dir_elem = jnp.maximum(dir_elem, self.eps) 486 | 487 | col_scale = jnp.reciprocal(jnp.sqrt(jnp.abs(dir_elem))) 488 | new_iroot = jnp.where(true_at_idx[None, :], new_col * col_scale, self.iroot) 489 | 490 | # iroot_row = jnp.reciprocal(jnp.sqrt(jnp.abs(dir_row)) + self.eps) 491 | # only_diag = jnp.where(true_at_idx[None, :], iroot_row, jnp.zeros_like(iroot_row)) 492 | # new_iroot = jnp.where(true_at_idx[:, None], only_diag, self.iroot) 493 | new_mask = self.mask | true_at_idx 494 | # jax.lax.cond(jnp.isfinite(new_iroot).all(), lambda x: None, lambda x: jax.debug.breakpoint(), 0) 495 | jax.lax.cond( 496 | jnp.isfinite(new_iroot).all(), 497 | lambda x: None, 498 | lambda x: jax.debug.print("thaw error"), 499 | 0, 500 | ) 501 | health = jnp.max(jnp.diag(new_iroot.T @ self.direct @ new_iroot)) 502 | jax.lax.cond( 503 | health < 1e3, 504 | lambda x: None, 505 | lambda x: jax.debug.print("WARN: thaw health"), 506 | 0, 507 | ) 508 | return self.replace(iroot=new_iroot, mask=new_mask).maybe_reset_iroot("thaw") 509 | 510 | def thaw_many(self, where): 511 | cond_fn = lambda tup: jnp.any(tup[0]) 512 | 513 | def body_fn(tup): 514 | where, out = tup 515 | idx = jnp.argmax(where) 516 | out = out.thaw(idx) 517 | where = where & ~(jnp.arange(len(where)) == idx) 518 | return where, out 519 | 520 | where, out = jax.lax.while_loop(cond_fn, body_fn, (where, self)) 521 | return out 522 | 523 | @jax.jit 524 | def gmres_solve(self, vecs): 525 | subvecs = jnp.where(self.mask, vecs, 0.0) 526 | subdirect = jnp.where(self.mask[:, None] & self.mask, self.direct, 0.0) 527 | precon = self.iroot @ self.iroot.T 528 | solns = jax.scipy.sparse.linalg.gmres( 529 | subdirect, 530 | subvecs.T, 531 | M=precon, 532 | maxiter=1, 533 | restart=20, 534 | solve_method="batched", 535 | )[0].T 536 | return solns 537 | 538 | @jax.jit 539 | def cg_solve(self, vecs): 540 | subvecs = jnp.where(self.mask, vecs, 0.0) 541 | subdirect = jnp.where(self.mask[:, None] & self.mask, self.direct, 0.0) 542 | precon = self.iroot @ self.iroot.T 543 | solns = jax.scipy.sparse.linalg.cg( 544 | subdirect, 545 | subvecs.T, 546 | M=precon, 547 | maxiter=20, 548 | )[0].T 549 | return solns 550 | 551 | def freeze_prune_thaw_scores(self, grads, params, ngrad=None): 552 | if ngrad is None: 553 | ngrad = grads @ self.iroot @ self.iroot.T 554 | frz_scaling = jnp.reciprocal( 555 | jnp.sqrt(jnp.maximum(self.eps, jnp.abs(self.diag_inv()))) 556 | ) 557 | tha_center = jnp.abs(jnp.diag(self.direct)) 558 | tha_center_ = jnp.abs(tha_center - tha_center * self.diag_inv() * tha_center) 559 | tha_center = jnp.maximum(0e-2 * tha_center, tha_center_) 560 | tha_scaling = jnp.reciprocal(jnp.sqrt(jnp.maximum(self.eps, tha_center))) 561 | root_scores = ( 562 | ngrad * frz_scaling, 563 | (ngrad - params) * frz_scaling, 564 | (grads - ngrad @ self.direct) * tha_scaling, 565 | ) 566 | scores = tuple(jnp.sum(jnp.square(jnp.abs(rs)), axis=0) for rs in root_scores) 567 | return scores 568 | -------------------------------------------------------------------------------- 
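Note on senn/linalg.py above: `SecondMoment` and the `Whitener` classes keep several synchronised representations of a second-moment (curvature) matrix so that rank-one observations can be folded in without refactorising from scratch. The short, self-contained sketch below is illustrative only (it is not part of the repository; `M` and `v` are arbitrary example values) and shows the Sherman-Morrison identity that `inv_update` applies: updating the stored inverse incrementally matches inverting the updated matrix directly.

import jax.numpy as jnp

M = jnp.eye(3) + 0.1 * jnp.ones((3, 3))       # any symmetric positive-definite matrix
v = jnp.array([1.0, 2.0, 3.0])                # a rank-one observation

# Reference: invert the updated matrix from scratch.
direct = jnp.linalg.inv(M + jnp.outer(v, v))

# Incremental: Sherman-Morrison update of the stored inverse.
Minv = jnp.linalg.inv(M)
Mv = Minv @ v
incremental = Minv - jnp.outer(Mv, Mv) / (1.0 + v @ Mv)

assert jnp.allclose(direct, incremental, atol=1e-5)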
/senn_cnn/senn/models.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Callable 2 | from typing import Any, Optional, Sequence 3 | from functools import partial 4 | from compose import compose 5 | from math import prod 6 | 7 | import jax 8 | from jax import numpy as jnp 9 | from jax.random import PRNGKey 10 | from jax.tree_util import ( 11 | tree_map, 12 | tree_leaves, 13 | tree_flatten, 14 | tree_unflatten, 15 | tree_reduce, 16 | ) 17 | from flax import struct, linen as nn 18 | from flax.core import frozen_dict 19 | from senn.opt import ( 20 | Task, 21 | TreeOpt, 22 | TrainState, 23 | step as opt_step, 24 | Stepper, 25 | softmax_grad_hgrad, 26 | universal_grad_hgrad, 27 | DiagOpt, 28 | ) 29 | import senn.opt 30 | from senn.neural import HPerturb, hperturb 31 | 32 | import tensorflow_datasets as tfds 33 | 34 | from time import time 35 | from tqdm import tqdm, trange 36 | 37 | import wandb 38 | 39 | 40 | class FakeBatchNorm(nn.Module): 41 | @nn.compact 42 | def __call__(self, x): 43 | return x * self.param("gamma", nn.initializers.ones, (1, 1), jnp.float32) 44 | 45 | 46 | def pad_vars( 47 | module, 48 | index, 49 | length, 50 | axis=-1, 51 | from_back=False, 52 | collection="params", 53 | init=nn.initializers.zeros, 54 | filt=lambda arr: True, 55 | ): 56 | def pad(arr, init=init): 57 | if not filt(arr): 58 | return arr 59 | if not (-len(arr.shape) <= axis < len(arr.shape)): 60 | return arr 61 | pad_shape = tuple(jnp.array(arr.shape).at[axis].set(length)) 62 | key = module.make_rng("params") 63 | padding = init(key, pad_shape, arr.dtype) 64 | split_idx = arr.shape[axis] - index if from_back else index 65 | prefix, suffix = jnp.split(arr, [split_idx], axis=axis) 66 | new_arr = jnp.concatenate([prefix, padding, suffix], axis=axis) 67 | return new_arr 68 | 69 | for key, value in module.variables[collection].items(): 70 | results = tree_map(pad, value) 71 | module.put_variable(collection, key, tree_map(pad, value)) 72 | if collection == "params" and module.is_mutable_collection("was_padded"): 73 | # def fake_pad(arr): 74 | # arr = jnp.zeros_like(arr, dtype=jnp.bool_) 75 | # return pad(arr, init=nn.initializers.ones) 76 | 77 | if module.has_variable("was_padded", key): 78 | old = module.get_variable("was_padded", key) 79 | else: 80 | old = tree_map(lambda a: jnp.zeros_like(a, dtype=jnp.bool_), value) 81 | was_padded = tree_map(partial(pad, init=nn.initializers.ones), old) 82 | 83 | # was_padded = tree_map(fake_pad, value) 84 | # if module.has_variable("was_padded", key): 85 | # old = module.get_variable("was_padded", key) 86 | # was_padded = tree_map(jnp.logical_or, was_padded, old) 87 | module.put_variable("was_padded", key, was_padded) 88 | return None 89 | 90 | 91 | def pad_vars_back(*args, **kwargs): 92 | return pad_vars(*args, **kwargs, from_back=True) 93 | 94 | 95 | def pad_dense_inputs_back(mdl, idx, length): 96 | pad_vars_back(mdl, idx, length, collection="params", axis=-2) 97 | return None 98 | 99 | 100 | class DenseLayer(nn.Module): 101 | linear: nn.Module 102 | nonlin: Optional[nn.Module] 103 | norm: Optional[nn.Module] 104 | in_paddable_collections: Sequence[str] = ("params",) 105 | out_paddable_collections: Sequence[str] = ("params", "probes") 106 | 107 | @nn.compact 108 | def __call__(self, fmaps): 109 | cat = jnp.concatenate(fmaps, axis=-1) 110 | y = self.linear(cat) 111 | y = y if self.norm is None else self.norm(y) 112 | y = y if self.nonlin is None else self.nonlin(y) 113 | return y 114 | 115 | 116 | class 
XLayer(DenseLayer): 117 | def pad_back_inputs(self, idx, length): 118 | filt = lambda arr: len(arr.shape) > 1 and arr.shape[-2] > 1 119 | for col in self.variables.keys(): 120 | if col in self.in_paddable_collections: 121 | pad_vars_back(self, idx, length, collection=col, filt=filt, axis=-2) 122 | return None 123 | 124 | def pad_back_outputs(self, idx, length): 125 | filt = lambda arr: len(arr.shape) > 0 and arr.shape[-1] > 1 126 | for col in self.variables.keys(): 127 | if col in self.out_paddable_collections: 128 | pad_vars_back(self, idx, length, collection=col, filt=filt, axis=-1) 129 | return None 130 | 131 | def out_dim(self): 132 | def dim(arr): 133 | return 0 if len(arr.shape) < 1 else arr.shape[-1] 134 | 135 | dims = tree_map(dim, self.linear.variables["params"]) 136 | return tree_reduce(jnp.maximum, dims, 0) 137 | 138 | def zero_params(self): 139 | for name, tree in self.variables["params"].items(): 140 | new_tree = tree_map(lambda arr: jnp.zeros_like(arr), tree) 141 | self.put_variable("params", name, new_tree) 142 | 143 | 144 | class Buddable(nn.Module): 145 | main: XLayer 146 | bud: Optional[XLayer] 147 | 148 | def __call__(self, fmaps): 149 | if self.bud is not None: 150 | fmaps = fmaps + (self.bud(fmaps),) 151 | return self.main(fmaps) 152 | 153 | def out_dim(self): 154 | return self.main.out_dim() 155 | 156 | def pad_back_inputs(self, idx, length): 157 | if self.bud is not None: 158 | self.bud.pad_back_inputs(idx, length) 159 | idx = idx + self.bud.out_dim() 160 | self.main.pad_back_inputs(idx, length) 161 | return None 162 | 163 | def pad_back_outputs(self, idx, length): 164 | self.main.pad_back_outputs(idx, length) 165 | 166 | def bud_reinit_allowed(self, col="allowed"): 167 | if self.bud is None: 168 | return None 169 | for key, param in self.main.variables["params"].items(): 170 | bud_out = self.bud.out_dim() 171 | 172 | def crop(arr): 173 | if len(arr.shape) < 1 or arr.shape == (1, 1): 174 | return arr 175 | bo_true = jnp.ones(shape=(bud_out,), dtype=jnp.bool_) 176 | in_ok = jnp.any(arr[..., :-bud_out, :], axis=-1, keepdims=True) 177 | return in_ok & bo_true 178 | 179 | allowed = self.main.variables[col].get(key) 180 | self.bud.put_variable(col, key, tree_map(crop, allowed)) 181 | return None 182 | 183 | def new_layer_vars(self): 184 | mdl, bvars = self.bud.unbind() 185 | bud_length = self.bud.out_dim() 186 | _, new_bud = mdl.apply(bvars, method=mdl.zero_params, mutable=True) 187 | _, new_main = mdl.apply( 188 | bvars, 189 | idx=0, 190 | length=bud_length, 191 | method=mdl.pad_back_inputs, 192 | mutable=True, 193 | rngs=dict(params=self.make_rng("params")), 194 | ) 195 | merged = frozen_dict.freeze( 196 | { 197 | col: {"bud": new_bud[col], "main": new_main[col]} 198 | for col in self.variables.keys() 199 | } 200 | ) 201 | return merged 202 | 203 | def zero_bud(self): 204 | self.bud.zero_params() 205 | 206 | 207 | def expandable_conv(features, **kwargs): 208 | return nn.Conv( 209 | features=features, 210 | kernel_size=(3, 3), 211 | strides=(1, 1), 212 | padding="SAME", 213 | use_bias=False, 214 | **kwargs, 215 | ) 216 | 217 | 218 | def expandable_pool(x): 219 | return nn.max_pool(x, (3, 3), strides=(2, 2), padding="SAME") 220 | 221 | 222 | class Block(nn.Module): 223 | layers: Sequence[Buddable] 224 | 225 | @nn.compact 226 | def __call__(self, x): 227 | fmaps = (x,) 228 | for layer in self.layers[:-1]: 229 | fmaps = fmaps + (layer(fmaps),) 230 | for layer in self.layers[-1:]: 231 | x = layer(fmaps) 232 | return x 233 | 234 | 235 | class XBlock(Block): 236 | def 
pad_intermediate(self, idx, length): 237 | assert idx + 1 < len(self.layers) 238 | self.layers[idx].pad_back_outputs(0, length) 239 | index = 0 240 | for layer in self.layers[idx + 1 :]: 241 | layer.pad_back_inputs(index, length) 242 | index += layer.out_dim() 243 | return None 244 | 245 | def pad_final(self, length): 246 | self.layers[-1].pad_back_outputs(0, length) 247 | return None 248 | 249 | def pad_inputs(self, length): 250 | index = 0 251 | for layer in self.layers: 252 | layer.pad_back_inputs(index, length) 253 | index += layer.out_dim() 254 | return None 255 | 256 | def shift_old_to_new(self, idx): 257 | def maybe_write(source, dest, name, new_name=None): 258 | new_name = name if new_name is None else new_name 259 | if source.has_variable("old", name): 260 | v = source.get_variable("old", name) 261 | dest.put_variable("new", new_name, v) 262 | 263 | for i, (layer, next_layer) in enumerate(zip(layers, layers[1:])): 264 | if i < idx: 265 | maybe_write(layer, layer, "main") 266 | maybe_write(layer, layer, "bud") 267 | if i == idx: 268 | maybe_write(layer, layer, "bud", new_name="main") 269 | maybe_write(layer, next_layer, "main") 270 | if i > idx: 271 | maybe_write(layer, next_layer, "main") 272 | maybe_write(layer, next_layer, "bud") 273 | return None 274 | 275 | def activate_bud(self, idx): 276 | # collections: was_padded, params, probes 277 | new_layer = self.layers[idx].new_layer_vars() 278 | length = self.layers[idx].bud.out_dim() 279 | index = 0 280 | for layer in self.layers[idx:]: 281 | layer.pad_back_inputs(index, length) 282 | index += layer.out_dim() 283 | self.layers[idx].zero_bud() 284 | widths = list(layer.out_dim() for layer in self.layers) 285 | new_widths = widths[:idx] + [length] + widths[idx:] 286 | return new_widths, new_layer 287 | 288 | def insert_layer_vars(self, idx, new_layer): 289 | name = lambda i: f"layers_{i}" 290 | for i, layer in enumerate(self.layers[idx:]): 291 | for col in layer.variables.keys(): 292 | self.put_variable(col, name(i + 1), self.get_variable(col, name(i))) 293 | for col, variable in new_layer.variables.items(): 294 | self.put_variable(col, name(idx), variable) 295 | return None 296 | 297 | 298 | def width_to_add(module): 299 | if "add_width" in module.variables: 300 | return tree_reduce(jnp.maximum, module.variables["add_width"], 0) 301 | else: 302 | return 0 303 | 304 | 305 | class ExpandableDense(nn.Module): 306 | blocks: Sequence[XBlock] 307 | final: nn.Module 308 | 309 | @nn.nowrap 310 | @classmethod 311 | def build(cls, *, out, nonlin, widthss, maybe_bud_width=None): 312 | def layer(w): 313 | # norm = FakeBatchnorm 314 | # norm = IdentityModule 315 | return XLayer(linear=expandable_conv(w), nonlin=HPerturb(nonlin), norm=None) 316 | 317 | def bud(): 318 | return None if maybe_bud_width is None else layer(maybe_bud_width) 319 | 320 | def buddable(w): 321 | return Buddable(main=layer(w), bud=bud()) 322 | 323 | def block(widths): 324 | return XBlock(list(map(buddable, widths))) 325 | 326 | blocks = tuple(map(block, widthss)) 327 | final = nn.Dense(out, use_bias=False) 328 | return cls(blocks, final) 329 | 330 | @nn.compact 331 | def __call__(self, x): 332 | pool = expandable_pool 333 | for block in self.blocks: 334 | x = block(x) 335 | x = pool(x) 336 | x = jnp.mean(x, axis=(-2, -3)) 337 | x = self.final(x) 338 | return x 339 | 340 | def layer_widths(self): 341 | return tuple( 342 | tuple(layer.out_dim() for layer in block.layers) for block in self.blocks 343 | ) 344 | 345 | def maybe_add_width(self): 346 | old_widths = 
self.layer_widths() 347 | for i, block in enumerate(self.blocks): 348 | for j, layer in enumerate(block.layers): 349 | to_add = width_to_add(layer.main.linear) 350 | if to_add <= 0: 351 | continue 352 | jax.debug.print("adding {} neurons to block {}, layer {}", to_add, i, j) 353 | if j + 1 < len(block.layers): 354 | self.blocks[i].pad_intermediate(j, to_add) 355 | elif j + 1 == len(block.layers) and i + 1 < len(self.blocks): 356 | self.blocks[i].pad_final(to_add) 357 | self.blocks[i + 1].pad_inputs(to_add) 358 | elif j + 1 == len(block.layers) and i + 1 == len(self.blocks): 359 | self.blocks[i].pad_final(to_add) 360 | pad_dense_inputs_back(self.final, 0, to_add) 361 | # do not expand final layer 362 | continue 363 | new_widths = self.layer_widths() 364 | if new_widths == old_widths: 365 | return None 366 | else: 367 | return new_widths 368 | # for col, value in self.variables.items(): 369 | # if self.is_mutable_collection(col): 370 | # for key, variable in value.items(): 371 | # self.put_variable(col, key, variable) 372 | # if (to_add := width_to_add(self.dense)) > 0: 373 | # pad_dense_inputs_back(self.dense, 0, to_add) 374 | # self.blocks[-1].pad_final(to_add) 375 | 376 | def bud_reinit_allowed(self): 377 | for block in self.blocks: 378 | for layer in block.layers: 379 | layer.bud_reinit_allowed() 380 | 381 | def activate_bud(self, block_idx, layer_idx): 382 | # collections: params, probes, was_padded 383 | old_widths = self.layer_widths() 384 | widths, new_layer = self.blocks[block_idx].activate_bud(layer_idx) 385 | out_w = old_widths[:block_idx] + (widths,) + old_widths[block_idx + 1 :] 386 | return out_w, new_layer 387 | 388 | @nn.nowrap 389 | def insert_into_tree(self, tree, bidx, lidx, item): 390 | bname = lambda i: f"blocks_{i}" 391 | lname = lambda i: f"layers_{i}" 392 | new_block = {lname(lidx): item} 393 | old_block = tree[bname(bidx)] 394 | N = len(old_block.keys()) 395 | for i in range(0, lidx): 396 | new_block[lname(i)] = old_block[lname(i)] 397 | # new_block = cursor(new_block)[lname(i)].set(tree[lname(i)]) 398 | for i in range(lidx, N): 399 | new_block[lname(i + 1)] = old_block[lname(i)] 400 | # new_block = cursor(new_block)[lname(i+1)].set(tree[lname(i)]) 401 | # return cursor(tree)[bname(bidx)].set(new_block) 402 | new_block = frozen_dict.freeze(new_block) 403 | return tree.copy({bname(bidx): new_block}) 404 | 405 | def argmax_score(self, ignore_first=True): 406 | best = -jnp.inf 407 | coords = -1, -1 408 | for bidx, block in enumerate(self.blocks): 409 | if len(block.layers) >= wandb.config.block_size_hard_cap: 410 | continue 411 | for lidx, layer in enumerate(block.layers): 412 | if ignore_first and (bidx, lidx) == (0, 0): 413 | continue 414 | score = layer.main.linear.get_variable("score", "kernel", -jnp.inf) 415 | if score > best: 416 | best = score 417 | coords = bidx, lidx 418 | return best, coords 419 | -------------------------------------------------------------------------------- /senn_cnn/senn/neural.py: -------------------------------------------------------------------------------- 1 | """NN Modules and associated second order and expansion methods""" 2 | from functools import partial, wraps 3 | from compose import compose 4 | from collections.abc import Callable 5 | from typing import Tuple, Any, Optional 6 | from math import prod 7 | 8 | import jax 9 | from jax import numpy as jnp 10 | from jax.random import PRNGKey 11 | from jax.tree_util import tree_map, Partial 12 | from tensorflow_probability.substrates import jax as tfp 13 | import flax 14 | from flax 
import linen as nn 15 | 16 | from senn import linalg 17 | 18 | 19 | def with_dummy_cotan(mdl): 20 | def f(x, cotan): 21 | return x 22 | 23 | def fwd(x, cotan): 24 | y, vjp = nn.vjp(mdl.apply, mdl, x) 25 | params_t = vjp(cotan) 26 | return x, params_t 27 | 28 | def bwd(params_t, x_t): 29 | return params_t, x_t 30 | 31 | 32 | def homogenize_last_dim(x): 33 | """inserts a constant input of 1.0 at the first channel index""" 34 | ones = jnp.ones(x.shape[:-1] + (1,), dtype=x.dtype) 35 | return jnp.concatenate((ones, x), axis=-1) 36 | 37 | 38 | class Sherman(nn.Module): 39 | """Tracks second order statistics for a linear transform. 40 | 41 | Assumes that the last dimension of the kernel is output, and the rest are input. 42 | Uses exponential averaging, with weight for new observation given by "decay". 43 | """ 44 | 45 | decay: Any = None 46 | 47 | @nn.compact 48 | def update_and_track(self, x, fake_bias): 49 | # vec = x.reshape((self.in_size(),)) 50 | vec = jnp.ravel(x) 51 | init_fn = linalg.SecondMoment.init_identity 52 | self.variable("kron", "in", init_fn, vec.size) 53 | self.variable("kron", "out", init_fn, fake_bias.size) 54 | kron_in = self.get_variable("kron", "in") 55 | kron_in = kron_in.rank_one_update(vec, decay=self.decay) 56 | self.put_variable("kron", "in", kron_in) 57 | 58 | # assert fake_bias.shape == (self.features,) 59 | return self.perturb("out", fake_bias) 60 | 61 | def update_kron_out(self): 62 | x = self.get_variable("perturbations", "out") 63 | # vec = x.reshape((self.features,)) 64 | vec = jnp.ravel(x) 65 | kron_out = self.get_variable("kron", "out") 66 | kron_out = kron_out.rank_one_update(vec, decay=self.decay) 67 | self.put_variable("kron", "out", kron_out) 68 | 69 | def get_kron(self): 70 | kron_in = self.get_variable("kron", "in") 71 | kron_out = self.get_variable("kron", "out") 72 | return kron_in, kron_out 73 | 74 | @nn.nowrap 75 | def kron_mul(self, Kin, Kout, kernel): 76 | in_size = Kin.shape[-1] 77 | out_size = Kout.shape[-1] 78 | reshaped = kernel.reshape((in_size, out_size)) 79 | reshaped = Kin @ reshaped @ (jnp.conj(Kout).T) 80 | return reshaped.reshape(kernel.shape) 81 | 82 | def ichol_mul(self, kernel): 83 | Kin, Kout = self.get_kron() 84 | kernel = self.kron_mul(Kin.ichol, Kout.ichol, kernel) 85 | return kernel 86 | 87 | def inv_mul(self, kernel): 88 | Kin, Kout = self.get_kron() 89 | kernel = self.kron_mul(Kin.inv, Kout.inv, kernel) 90 | return kernel 91 | 92 | 93 | class Kronify(nn.Module): 94 | """Wraps a linear module, for which it tracks second order statistics. 95 | 96 | Optionally, appends a nonlinearity to the module and/or adds a homogeneous coordinate to 97 | the input. This coordinate simulates the use of a bias while introducing no new logic. 
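    A minimal usage sketch (mirroring tests/test_neural.py): wrap a bias-free linear
    module, e.g. Kronify(nn.Dense(8, use_bias=False)), initialise it with both a
    "params" and a "hutchinson" RNG stream via model.init({"params": key,
    "hutchinson": key}, x), and apply it with mutable="kron" and a "hutchinson" RNG
    to accumulate the input-side Kronecker factor; the output-side factor is then
    refreshed from the tracked perturbation via update_kron_out.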
98 | """ 99 | 100 | linear: nn.Module 101 | sherman: Sherman = Sherman() 102 | homogenize: bool = False 103 | nonlin: Optional[Callable] = None 104 | 105 | def reduced_linear_variables(self): 106 | params = self.linear.variables["params"] 107 | assert "bias" not in params, "bias not supported, use homogenize instead" 108 | params = tree_map(lambda arr: arr[..., [0]], params) 109 | return {"params": params} 110 | 111 | def update_kron_out(self): 112 | self.sherman.update_kron_out() 113 | 114 | def __call__(self, x): 115 | if self.homogenize: 116 | x = homogenize_last_dim(x) 117 | self.sow("intermediates", "kernel_in", x) 118 | 119 | if self.has_rng("noisy_params"): 120 | assert not self.is_initializing(), "do not use noise when initializing" 121 | 122 | lin, linvars = self.linear.unbind() 123 | kernel = self.linear.get_variable("params", "kernel") 124 | key = self.make_rng("noisy_params") 125 | white_noise = jax.random.normal(key, kernel.shape, kernel.dtype) 126 | unwhite_noise = self.sherman.ichol_mul(white_noise) 127 | noised_kernel = jax.lax.stop_gradient(unwhite_noise) + kernel 128 | y = lin.apply(linvars.copy({"params": {"kernel": noised_kernel}}), x) 129 | 130 | else: 131 | y = self.linear(x) 132 | 133 | if self.has_rng("hutchinson"): 134 | dummy = self.linear.clone(features=1).bind(self.reduced_linear_variables()) 135 | dummy_y, vjp = nn.vjp(lambda mdl: mdl(x), dummy) 136 | hutch_key = self.make_rng("hutchinson") 137 | hutch_cotan = jax.random.rademacher(hutch_key, dummy_y.shape, dummy_y.dtype) 138 | (params_t,) = vjp(hutch_cotan) 139 | kron_in_sample = params_t["params"]["kernel"] 140 | 141 | fake_bias = jnp.zeros(y.shape[-1:], dtype=y.dtype) 142 | assert hutch_cotan.shape[-1] == 1 143 | tracked_bias = self.sherman.update_and_track(kron_in_sample, fake_bias) 144 | y = y + hutch_cotan * tracked_bias 145 | if self.nonlin is not None: 146 | y = self.nonlin(y) 147 | return y 148 | 149 | 150 | class ScannedKronify(Kronify): 151 | initial_count: int = 1 152 | 153 | @nn.compact 154 | def __call__(self, x): 155 | assert ( 156 | x.shape[-1] == self.linear.features 157 | ), "we require feature size to remain unchanged" 158 | length = self.initial_count if self.is_initializing else None 159 | 160 | def body_fun(module, carry, _): 161 | return super(type(self), module).__call__(carry), None 162 | 163 | scan = nn.scan( 164 | body_fun, 165 | variable_axes={True: 0}, 166 | variable_broadcast=False, 167 | split_rngs={True: True}, 168 | length=length, 169 | ) 170 | x, _ = scan(self, x, None) 171 | return x 172 | 173 | 174 | def reduced_variables(mdl): 175 | variables = mdl.variables 176 | params = variables["params"] 177 | new_params = tree_map(lambda arr: arr[..., [0]], params) 178 | return variables.copy(dict(params=new_params)) 179 | 180 | 181 | def record_input_sensitivity(mdl, x, hutch): 182 | dummy = mdl.clone(features=1).bind(reduced_variables(mdl)) 183 | dummy_y, vjp = nn.vjp(lambda dum: super(type(dum), dum).__call__(x), dummy) 184 | (vars_cotan,) = vjp(hutch) 185 | for key, value in vars_cotan["params"].items(): 186 | mdl.put_variable("hutch_in", key, value) 187 | 188 | 189 | def record_output_sensitivity(mdl, y, hutch): 190 | assert hutch.shape[:-1] == y.shape[:-1] 191 | fake_bias = jnp.zeros(shape=y.shape[-1:], dtype=y.dtype) 192 | for name in mdl.variables["params"].keys(): 193 | fake_bias = mdl.perturb(name, fake_bias, collection="hutch_out") 194 | return y + fake_bias * hutch 195 | 196 | 197 | def make_hutch_for(mdl, x): 198 | key = mdl.make_rng("hutchinson") 199 | hutch_shape = 
x.shape[:-1] + (1,) 200 | return jax.random.rademacher(key, shape=hutch_shape, dtype=x.dtype) 201 | 202 | 203 | class KDense(nn.Dense): 204 | homogenize: bool = False 205 | nonlin: Optional[Callable] = None 206 | use_bias: bool = False 207 | 208 | @nn.compact 209 | def __call__(self, x): 210 | if self.homogenize: 211 | x = homogenize_last_dim(x) 212 | 213 | y = super().__call__(x) 214 | 215 | hutch = make_hutch_for(self, y) 216 | record_input_sensitivity(self, x, hutch) 217 | y = record_output_sensitivity(self, y, hutch) 218 | 219 | y = y if self.nonlin is None else self.nonlin(y) 220 | return y 221 | 222 | 223 | class KKronify(nn.Module): 224 | """Wraps a linear module, for which it tracks second order statistics. 225 | 226 | Optionally, appends a nonlinearity to the module and/or adds a homogeneous coordinate to 227 | the input. This coordinate simulates the use of a bias while introducing no new logic. 228 | """ 229 | 230 | linear: nn.Module 231 | homogenize: bool = False 232 | nonlin: Optional[Callable] = None 233 | 234 | def reduced_linear_variables(self): 235 | params = self.linear.variables["params"] 236 | assert "bias" not in params, "bias not supported, use homogenize instead" 237 | params = tree_map(lambda arr: arr[..., [0]], params) 238 | return {"params": params} 239 | 240 | def update_kron_out(self): 241 | self.sherman.update_kron_out() 242 | 243 | def __call__(self, x): 244 | if self.homogenize: 245 | x = homogenize_last_dim(x) 246 | self.sow("intermediates", "kernel_in", x) 247 | 248 | y = self.linear(x) 249 | 250 | if self.has_rng("hutchinson"): 251 | dummy = self.linear.clone(features=1).bind(self.reduced_linear_variables()) 252 | dummy_y, vjp = nn.vjp(lambda mdl: mdl(x), dummy) 253 | hutch_key = self.make_rng("hutchinson") 254 | hutch_cotan = jax.random.rademacher(hutch_key, dummy_y.shape, dummy_y.dtype) 255 | (params_t,) = vjp(hutch_cotan) 256 | kron_in_sample = params_t["params"]["kernel"] 257 | 258 | fake_bias = jnp.zeros(y.shape[-1:], dtype=y.dtype) 259 | assert hutch_cotan.shape[-1] == 1 260 | tracked_bias = self.sherman.update_and_track(kron_in_sample, fake_bias) 261 | y = y + hutch_cotan * tracked_bias 262 | if self.nonlin is not None: 263 | y = self.nonlin(y) 264 | return y 265 | 266 | 267 | def _homogenized(cls, argnums=(0,)): 268 | old_call = cls.__call__ 269 | 270 | @wraps(old_call) 271 | def new_call(self, *args, **kwargs): 272 | maybe_homog = ( 273 | lambda tup: homogenize_last_dim(tup[1]) if tup[0] in argnums else tup[1] 274 | ) 275 | return old_call(self, *map(maybe_homog, enumerate(args)), **kwargs) 276 | 277 | cls.__call__ = new_call 278 | return cls 279 | 280 | 281 | def _instrumented(cls): 282 | old_call = cls.__call__ 283 | 284 | def new_call(self, *args): 285 | pass 286 | 287 | 288 | def value_grad_curv(fn, x): 289 | ones_jvp = lambda x: jax.jvp(fn, (x,), (jnp.ones_like(x),)) 290 | (y, dy), (_, ddy) = jax.jvp(ones_jvp, (x,), (jnp.ones_like(x),)) 291 | return y, dy, ddy 292 | 293 | 294 | def general_hperturb(fn, transform=jnp.sign): 295 | @jax.custom_vjp 296 | def inner(x, *args, perturb): 297 | return fn(x, *args) 298 | 299 | def inner_fwd(x, *args, perturb): 300 | y, vjp = jax.vjp(fn, x, *args) 301 | ddy = jax.grad(jax.grad(fn))(x, *args) 302 | return vjp, ddy, perturb 303 | 304 | def inner_bwd(res, g): 305 | vjp, ddy, perturb = res 306 | stop_g = vjp(jax.lax.stop_gradient(g)) 307 | stop_vjp = jax.lax.stop_gradient(vjp)(g) 308 | scale = transform(ddy * g) 309 | 310 | def combine(a, b): 311 | return a + scale * b - scale * jax.lax.stop_gradient(b) 312 
| 313 | x_grad, *args_grad = map(combine, stop_vjp, stop_g) 314 | return x_grad + jnp.sqrt(scale * jnp.abs(g * ddy)) * perturb, *args_grad, None 315 | 316 | inner.defvjp(inner_fwd, inner_bwd) 317 | return inner 318 | 319 | 320 | def hperturb(fn, elementwise=True, chol_rank=None): 321 | @jax.custom_vjp 322 | def inner(key, mag, x): 323 | return fn(x) 324 | 325 | def inner_fwd(key, mag, x): 326 | noise = jax.random.rademacher(key, x.shape, x.dtype) 327 | if elementwise: 328 | y, dy, ddy = value_grad_curv(fn, x) 329 | assert y.shape == x.shape 330 | assert dy.shape == x.shape 331 | assert ddy.shape == x.shape 332 | perturbation = noise * mag 333 | 334 | @Partial 335 | def vjp(g): 336 | return (dy * g,) 337 | 338 | else: 339 | EPS = 1e-4 340 | assert len(x.shape) == 1 341 | (full_rank,) = x.shape 342 | y, vjp = jax.vjp(fn, x) 343 | hess = jax.hessian(fn)(x) 344 | hess = hess + EPS * jnp.identity(full_rank) 345 | max_rank = full_rank if chol_rank is None else chol_rank 346 | max_rank = full_rank 347 | low_rank, _, _ = tfp.math.low_rank_cholesky(hess, max_rank) 348 | perturbation = low_rank @ (noise[..., :max_rank] * mag) 349 | ddy = 1.0 350 | res = vjp, ddy, perturbation 351 | return y, res 352 | 353 | def inner_bwd(res, g): 354 | vjp, ddy, perturbation = res 355 | perturbation = jnp.sqrt(jnp.abs(ddy * g) + 1e-12) * perturbation 356 | # (grad,) = vjp(g) 357 | (grad_stopped_g,) = vjp(jax.lax.stop_gradient(g)) 358 | (grad_stopped_vjp,) = jax.lax.stop_gradient(vjp)(g) 359 | zero_unstopped_vjp = grad_stopped_g - jax.lax.stop_gradient(grad_stopped_g) 360 | grad_out = grad_stopped_vjp + jnp.sign(ddy * g) * zero_unstopped_vjp 361 | return ( 362 | None, 363 | None, 364 | grad_out + perturbation, 365 | ) 366 | 367 | inner.defvjp(inner_fwd, inner_bwd) 368 | return inner 369 | 370 | 371 | class HPerturb(nn.Module): 372 | fn: Callable 373 | rng_name: str = "nonlin" 374 | 375 | @nn.compact 376 | def __call__(self, x): 377 | probe = self.perturb("nonlin", jnp.zeros_like(x), collection="probes") 378 | if self.has_variable("probes", "nonlin"): 379 | key = PRNGKey(0) if self.is_initializing() else self.make_rng("nonlin") 380 | y = hperturb(self.fn)(key, probe, x) 381 | else: 382 | y = self.fn(x) 383 | return y 384 | -------------------------------------------------------------------------------- /senn_cnn/tests/test_linalg.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from tensorflow_probability.substrates import jax as tfp 4 | from functools import partial 5 | from compose import compose 6 | 7 | import pytest 8 | 9 | from senn import linalg as sennlinalg 10 | 11 | 12 | def weighted_samples(key, dim=8, count=128): 13 | init = jnp.eye(dim) 14 | vec_key, weight_key = jax.random.split(key) 15 | vecs = jax.vmap( 16 | partial( 17 | jax.random.normal, 18 | shape=(8,), 19 | ) 20 | )(jax.random.split(vec_key, count)) 21 | weights = jax.vmap(jax.random.uniform)(jax.random.split(weight_key, count)) 22 | return init, vecs, weights 23 | 24 | 25 | def scan_updates(fn, init, vecs): 26 | """fn: M, v --> newM, init: M, vecs: [v]""" 27 | as_pair = lambda a: (a, a) 28 | _, intermediates = jax.lax.scan(compose(as_pair, fn), init, vecs) 29 | return intermediates 30 | 31 | 32 | def check_rank_one_update_fn(updatefn, check_vs_direct): 33 | key = jax.random.PRNGKey(seed=0) 34 | init, vecs, weights = weighted_samples(key) 35 | 36 | directs = scan_updates(sennlinalg.direct_update, init, vecs) 37 | incrementals = scan_updates(updatefn, init, vecs) 38 | 
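    # Each incremental representation is compared with the directly accumulated
    # second-moment matrix after every rank-one update, not just the final one.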
39 | for i in range(len(weights)): 40 | check_vs_direct(incrementals[i], directs[i]) 41 | 42 | 43 | def approx_identity(direct): 44 | return pytest.approx(jnp.eye(direct.shape[-1]), abs=1e-5) 45 | 46 | 47 | def check_chol(chol, direct): 48 | assert tfp.math.hpsd_quadratic_form_solve(direct, chol) == approx_identity(direct) 49 | 50 | 51 | def test_chol_update(): 52 | check_rank_one_update_fn(sennlinalg.chol_update, check_chol) 53 | 54 | 55 | def check_inv(inv, direct): 56 | assert inv @ direct == approx_identity(direct) 57 | 58 | 59 | def test_inv_update(): 60 | check_rank_one_update_fn(sennlinalg.inv_update, check_inv) 61 | 62 | 63 | def check_ichol(ichol, direct): 64 | assert ichol.T @ direct @ ichol == approx_identity(direct) 65 | 66 | 67 | def test_ichol_update(): 68 | check_rank_one_update_fn(sennlinalg.ichol_update, check_ichol) 69 | 70 | 71 | def test_cholupdate(): 72 | dim = 8 73 | count = 128 74 | I = jnp.eye(8) 75 | key = jax.random.PRNGKey(seed=0) 76 | samples = jax.vmap(partial(jax.random.normal, shape=(8,)))( 77 | jax.random.split(key, count) 78 | ) 79 | C = I + jnp.sum(jax.vmap(lambda v: jnp.outer(v, v))(samples), axis=0) 80 | 81 | actual_chol = jnp.linalg.cholesky(C) 82 | 83 | def f(carry, x): 84 | return tfp.math.cholesky_update(carry, x), None 85 | 86 | update_chol, _ = jax.lax.scan(f, I, samples) 87 | 88 | assert update_chol == pytest.approx(actual_chol, rel=1e-4) 89 | 90 | actual_ichol = jnp.linalg.cholesky(jnp.linalg.inv(C)) 91 | 92 | def g(carry, x): 93 | y = carry.T @ x 94 | mult = -1 / (1 + jnp.inner(y, y)) 95 | return tfp.math.cholesky_update(carry, carry @ y, multiplier=mult), None 96 | 97 | update_ichol, _ = jax.lax.scan(g, I, samples) 98 | 99 | assert update_ichol == pytest.approx(actual_ichol, rel=1e-4) 100 | -------------------------------------------------------------------------------- /senn_cnn/tests/test_models.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from senn import models 3 | from senn.opt import TrainState 4 | import senn.opt 5 | 6 | import jax 7 | import optax 8 | from jax import numpy as jnp 9 | from jax.random import PRNGKey 10 | from jax.tree_util import tree_map, tree_reduce, tree_all 11 | import numpy as np 12 | import wandb 13 | 14 | import flax 15 | from flax.traverse_util import flatten_dict, unflatten_dict 16 | 17 | widthss = ([4, 6],) 18 | num_channels = 3 19 | img_size = 32 20 | example = jnp.zeros((1, img_size, img_size, num_channels)) 21 | num_classes = 2 22 | bud_size = 8 23 | 24 | wandb.init(mode="disabled") 25 | wandb.config.ubah_mul = 1.0 26 | wandb.config.add_unit_normal_curvature = True 27 | wandb.config.grad_update_as_curvature = True 28 | wandb.config.curvature = "UBAH" 29 | wandb.config.grad_curvature_mul = 1.0 30 | wandb.config.use_dropout = False 31 | wandb.config.pruned_lr_rescale = True 32 | wandb.config.freeze_thaw_disable = False 33 | wandb.config.expansion_cutoff_step = 1000000 34 | wandb.config.freeze_thresh = 1e-2 35 | wandb.config.thaw_thresh = 1e-2 36 | wandb.config.freeze_is_prune = True 37 | wandb.config.freeze_thresh_rel = 1e-2 38 | wandb.config.thaw_thresh_rel = 1e-2 39 | wandb.config.thaw_prob_size_compensate = True 40 | wandb.config.ignore_width = bud_size 41 | 42 | wandb.config.minimum_width = num_classes * 2 43 | wandb.config.num_classes = num_classes 44 | wandb.config.untouched_thresh = 1 45 | 46 | 47 | def builder(widthss): 48 | return models.ExpandableDense.build( 49 | widthss=widthss, out=num_classes, nonlin=jax.nn.swish, 
maybe_bud_width=bud_size 50 | ) 51 | 52 | 53 | def make_model(): 54 | # model_kwargs = dict(widthss=widthss, out=num_classes, nonlin=jax.nn.swish, maybe_bud_width=bud_size) 55 | # model = models.ExpandableDense.build(**model_kwargs) 56 | model = builder(widthss) 57 | variables = model.init(PRNGKey(0), example) 58 | return model, variables 59 | 60 | 61 | def make_schedule(): 62 | return optax.cosine_onecycle_schedule( 63 | transition_steps=100 * 781, 64 | peak_value=1e-3, 65 | pct_start=0.1, 66 | ) 67 | 68 | 69 | def make_optimizer(): 70 | schedule = make_schedule() 71 | first_order = senn.opt.MyAdam( 72 | lr=schedule, mom1=1e-1, mom2=1e-2, weight_decay=1e-2, noise_std=0.0, order=0 73 | ) 74 | optimizer = senn.opt.WrappedFirstOrder(tx=first_order) 75 | return optimizer 76 | 77 | 78 | def test_init(): 79 | model, variables = make_model() 80 | assert set(variables) == {"params", "probes"} 81 | 82 | 83 | def make_trainstate(): 84 | model, variables = make_model() 85 | optimizer = make_optimizer() 86 | ts = TrainState.create( 87 | optimizer, 88 | variables["params"], 89 | variables.get("probes", {}), 90 | model.apply, 91 | batch_stats=variables.get("batch_stats", {}), 92 | dummy_input=example, 93 | model=model, 94 | path_pred=lambda path: "bud" not in path, 95 | ) 96 | ts = ts.init_prune() 97 | return ts 98 | 99 | 100 | def make_stepper(): 101 | return senn.opt.Stepper(varinf=True, paired_varinf=True, hess_only_varinf=True) 102 | 103 | 104 | def lossfn(label, y): 105 | lsm = jax.nn.log_softmax(y) 106 | label_log_prob = jnp.take_along_axis(lsm, label[..., None], axis=-1)[..., 0] 107 | return -label_log_prob 108 | 109 | 110 | def make_tasks(batch_size=64, key=PRNGKey(0)): 111 | xkey, lkey = jax.random.split(key) 112 | shape = (batch_size,) + example.shape[1:] 113 | xs = jax.random.normal(xkey, shape=shape) 114 | logits = jnp.zeros((num_classes,)) 115 | labels = jax.random.categorical(lkey, logits, shape=(batch_size,)) 116 | return senn.opt.Task(xs, labels, lossfn) 117 | 118 | 119 | def test_trainstate_create(): 120 | train_state = make_trainstate() 121 | call_stepper(train_state) 122 | 123 | 124 | def call_stepper(train_state): 125 | stepper = make_stepper() 126 | key = PRNGKey(0) 127 | tasks = make_tasks() 128 | weights = jnp.ones(len(tasks.label)) 129 | new_train_state = stepper(train_state, tasks, weights, key) 130 | return new_train_state 131 | 132 | 133 | def test_get_metrics(): 134 | train_state = make_trainstate() 135 | train_state.get_metrics() 136 | 137 | 138 | def test_pruned_init(): 139 | train_state = make_trainstate() 140 | for key in flatten_dict(train_state.opt_state): 141 | assert "bud" not in key 142 | for key, value in flatten_dict(train_state.params).items(): 143 | if "main" in key: 144 | assert jnp.all(value[..., -bud_size:, :] == 0.0) 145 | 146 | 147 | def test_maybe_add_width(): 148 | train_state = make_trainstate() 149 | add_width = tree_map(lambda a: 7, train_state.subparams) 150 | key = PRNGKey(0) 151 | train_state = train_state.maybe_expand_width( 152 | key=key, builder=builder, add_width=add_width 153 | ) 154 | call_stepper(train_state) 155 | 156 | 157 | def test_insert_layer(): 158 | train_state = make_trainstate() 159 | train_state = call_stepper(train_state) 160 | bidx, lidx = 0, 1 161 | train_state = train_state.insert_layer(builder, bidx, lidx) 162 | print(tree_map(lambda arr: arr.shape, train_state.params)) 163 | print(tree_map(lambda arr: arr.shape, train_state.probes)) 164 | train_state = call_stepper(train_state) 165 | 166 | 167 | def 
test_inserted_same_function(): 168 | train_state = make_trainstate() 169 | print(train_state.params) 170 | tasks = make_tasks() 171 | ys0 = jax.vmap(train_state.eval)(tasks) 172 | print(ys0) 173 | bidx, lidx = 0, 1 174 | train_state = train_state.insert_layer(builder, bidx, lidx) 175 | print(train_state.params) 176 | ys1 = jax.vmap(train_state.eval)(tasks) 177 | print(ys1) 178 | assert pytest.approx(ys0) == ys1 179 | -------------------------------------------------------------------------------- /senn_cnn/tests/test_neural.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from compose import compose 3 | 4 | from flax import linen as nn 5 | from senn import neural 6 | import jax 7 | from jax import numpy as jnp 8 | from jax.random import PRNGKey 9 | from jax.tree_util import tree_map 10 | import numpy as np 11 | import pytest 12 | from tensorflow_probability.substrates import jax as tfp 13 | 14 | 15 | def make_hdense(num_features=8): 16 | key = jax.random.PRNGKey(seed=0) 17 | in_shape = (num_features + 3,) 18 | out_shape = (num_features,) 19 | # model = neural.HDense(num_features) 20 | model = neural.Kronify(nn.Dense(num_features, use_bias=False)) 21 | state = model.init({"params": key, "hutchinson": key}, jnp.zeros(in_shape)) 22 | return model, state, in_shape, out_shape 23 | 24 | 25 | def test_hdense_init(num_features=8): 26 | model, state, in_shape, out_shape = make_hdense() 27 | assert "params" in state 28 | assert "linear" in state["params"] 29 | assert "kernel" in state["params"]["linear"] 30 | assert state["params"]["linear"]["kernel"].shape == in_shape + out_shape 31 | 32 | 33 | def test_hdense_sow(): 34 | model, state, in_shape, out_shape = make_hdense() 35 | x = np.random.normal(size=in_shape) 36 | y1 = model.apply(state, x) 37 | y2, aux = model.apply(state, x, mutable="intermediates") 38 | 39 | assert y1 == pytest.approx(y2) 40 | assert aux["intermediates"]["kernel_in"][0] == pytest.approx(x) 41 | 42 | 43 | def test_hdense_stats(): 44 | model, state, in_shape, out_shape = make_hdense() 45 | hutch_key = PRNGKey(0) 46 | x = jnp.array(np.random.normal(size=in_shape)) 47 | yct = np.random.normal(size=out_shape) 48 | fwd = partial(model.apply, mutable="kron", rngs={"hutchinson": hutch_key}) 49 | y, aux = fwd(state, x) 50 | M = jnp.identity(in_shape[0]) + jnp.outer(x, x) 51 | approx = partial(pytest.approx, rel=1e-4) 52 | assert "sherman" in aux["kron"] 53 | kron_in = aux["kron"]["sherman"]["in"] 54 | assert kron_in.direct == approx(M) 55 | assert kron_in.inv == approx(jnp.linalg.inv(M)) 56 | assert kron_in.chol == approx(jnp.linalg.cholesky(M)) 57 | calc_ichol = compose(jnp.linalg.cholesky, jnp.linalg.inv) 58 | assert kron_in.ichol == approx(calc_ichol(M)) 59 | 60 | 61 | def test_hdense_grad_perturb(): 62 | model, state, in_shape, out_shape = make_hdense() 63 | assert state["perturbations"]["sherman"]["out"] == pytest.approx( 64 | jnp.zeros(out_shape) 65 | ) 66 | x = np.random.normal(size=in_shape) 67 | yct = np.random.normal(size=out_shape) 68 | fwd = partial(model.apply, mutable="kron", rngs={"hutchinson": PRNGKey(0)}) 69 | y, VJP, aux = jax.vjp(fwd, state, x, has_aux=True) 70 | state_grad, x_grad = VJP(yct) 71 | assert state_grad["params"]["linear"]["kernel"] == pytest.approx(jnp.outer(x, yct)) 72 | assert state_grad["perturbations"]["sherman"]["out"] == pytest.approx(yct) 73 | 74 | 75 | def test_hdense_perturb_rank_one_update(): 76 | model, state, in_shape, out_shape = make_hdense() 77 | assert 
state["perturbations"]["sherman"]["out"] == pytest.approx( 78 | jnp.zeros(out_shape) 79 | ) 80 | x = np.random.normal(size=in_shape) 81 | yct = np.random.normal(size=out_shape) 82 | fwd = partial(model.apply, mutable="kron", rngs={"hutchinson": PRNGKey(0)}) 83 | y, VJP, aux = jax.vjp(fwd, state, x, has_aux=True) 84 | state = state.copy(aux) 85 | state_grad, x_grad = VJP(yct) 86 | 87 | state_and_pgrad = state.copy({"perturbations": state_grad["perturbations"]}) 88 | _, aux2 = model.apply(state_and_pgrad, method=model.update_kron_out, mutable="kron") 89 | 90 | M = jnp.identity(out_shape[0]) + jnp.outer(yct, yct) 91 | approx = partial(pytest.approx, rel=1e-4) 92 | kron_out = aux2["kron"]["sherman"]["out"] 93 | assert kron_out.direct == approx(M) 94 | assert kron_out.inv == approx(jnp.linalg.inv(M)) 95 | assert kron_out.chol == approx(jnp.linalg.cholesky(M)) 96 | calc_ichol = compose(jnp.linalg.cholesky, jnp.linalg.inv) 97 | assert kron_out.ichol == approx(calc_ichol(M)) 98 | 99 | 100 | def test_hperturb_transform(fn=jnp.tanh): 101 | approx = partial(pytest.approx, rel=1e-6, abs=1e-6) 102 | hfn = neural.hperturb(fn) 103 | key = jax.random.PRNGKey(seed=0) 104 | x = np.random.normal(size=(3,)) 105 | true_y = fn(x) 106 | yct = np.random.normal(size=x.shape) 107 | true_y, true_vjp = jax.vjp(fn, x) 108 | (true_grad,) = true_vjp(yct) 109 | mags = np.random.normal(size=x.shape) 110 | 111 | def calc(mag): 112 | my_y, my_vjp = jax.vjp(partial(hfn, key, mag), x) 113 | return my_y, my_vjp(yct)[0] 114 | 115 | my_y, my_grad = calc(jnp.zeros_like(mags)) 116 | assert my_y == approx(true_y) 117 | assert my_grad == approx(true_grad) 118 | 119 | _, (zero, my_vec) = jax.jvp(calc, (jnp.zeros_like(mags),), (mags,)) 120 | assert zero == approx(jnp.zeros_like(zero)) 121 | hess_diag = jax.vmap(jax.hessian(fn))(x) 122 | abs_hess_diag = jnp.abs(hess_diag * yct) 123 | assert jnp.square(my_vec) == approx(jnp.square(mags) * abs_hess_diag) 124 | 125 | 126 | def test_hperturb_cholesky(fn=jax.scipy.special.logsumexp): 127 | dim = 8 128 | approx = partial(pytest.approx, rel=1e-1, abs=1e-2) 129 | hfn = neural.hperturb(fn, elementwise=False) 130 | keys = jax.random.split(PRNGKey(0), 1024) 131 | x = np.random.normal(size=dim) 132 | 133 | def gen_outer(key): 134 | def calc(mag): 135 | return jax.grad(hfn, argnums=2)(key, mag, x) 136 | 137 | _, vec = jax.jvp(calc, (0.0,), (1.0,)) 138 | return jnp.outer(vec, vec) 139 | 140 | outers = jax.vmap(gen_outer)(keys) 141 | estimated_hess = jnp.mean(outers, axis=0) 142 | 143 | direct_hess = jax.hessian(fn)(x) 144 | assert estimated_hess == approx(direct_hess) 145 | 146 | 147 | def make_scanned_kronify(initial_count=1, hidden_dim=3): 148 | linear = nn.Dense(hidden_dim, use_bias=False) 149 | kronify = neural.Kronify(linear, nonlin=None) 150 | scanned = neural.ScannedKronify(linear, nonlin=None, initial_count=initial_count) 151 | x = np.random.normal(size=(hidden_dim,)) 152 | rngs = {"params": PRNGKey(0), "hutchinson": PRNGKey(0)} 153 | kvars = kronify.init(rngs, x) 154 | svars = scanned.init(rngs, x) 155 | return (kronify, kvars), (scanned, svars) 156 | 157 | 158 | @pytest.mark.parametrize("initial_count", [7, 0, 1]) 159 | def test_scanned_init(initial_count): 160 | (kronify, kvars), (scanned, svars) = make_scanned_kronify(initial_count) 161 | 162 | def check(k, s): 163 | assert (initial_count,) + k.shape == s.shape 164 | 165 | tree_map(check, kvars, svars) 166 | 167 | 168 | @pytest.mark.parametrize("initial_count", [7, 0, 1]) 169 | def test_scanned_apply(initial_count, hidden_dim=3): 170 | 
_, (scanned, svars) = make_scanned_kronify(initial_count, hidden_dim) 171 | x = np.random.normal(size=(hidden_dim,)) 172 | y = scanned.apply(svars, x) 173 | approx = partial(pytest.approx, rel=1e-4) 174 | if initial_count == 0: 175 | assert y == approx(x) 176 | 177 | 178 | def test_noisy_params(hidden_dim=3): 179 | (kronify, kvars), _ = make_scanned_kronify(1, hidden_dim) 180 | key1, key2 = jax.random.split(PRNGKey(0), 2) 181 | x = np.random.normal(size=(hidden_dim,)) 182 | y1 = kronify.apply(kvars, x, rngs={"noisy_params": key1}) 183 | y2 = kronify.apply(kvars, x, rngs={"noisy_params": key2}) 184 | y1_1 = kronify.apply(kvars, x, rngs={"noisy_params": key1}) 185 | approx = partial(pytest.approx, rel=1e-4) 186 | assert y1 == approx(y1_1) 187 | assert y1 != approx(y2) 188 | 189 | 190 | def make_kdense(*args, x, **kwargs): 191 | key = PRNGKey(0) 192 | model = neural.KDense(*args, **kwargs) 193 | variables = model.init(rngs={"params": key, "hutchinson": key}, x=x) 194 | return model, variables 195 | 196 | 197 | def make_vec(dim): 198 | return np.random.normal(size=(dim,)) 199 | 200 | 201 | def test_kdense_init(): 202 | x = make_vec(3) 203 | mdl, state = make_kdense(features=3, x=x) 204 | assert state["params"].keys() == {"kernel"} 205 | assert state["hutch_out"].keys() == {"kernel"} 206 | y, aux = mdl.apply(state, x, mutable="hutch_in", rngs={"hutchinson": PRNGKey(0)}) 207 | assert aux["hutch_in"].keys() == {"kernel"} 208 | -------------------------------------------------------------------------------- /senn_mlp/Dockerfile: -------------------------------------------------------------------------------- 1 | # Select the base image 2 | # FROM nvcr.io/nvidia/tensorflow:21.11-tf2-py3 3 | FROM nvcr.io/nvidia/tensorflow:23.04-tf2-py3 4 | 5 | # Select the working directory 6 | WORKDIR /senn 7 | 8 | # Install Python requirements 9 | COPY ./requirements.txt ./requirements.txt 10 | RUN pip install --upgrade "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 11 | RUN pip install --upgrade torch torchvision --index-url https://download.pytorch.org/whl/cu118 12 | RUN pip install -r requirements.txt 13 | -------------------------------------------------------------------------------- /senn_mlp/README.md: -------------------------------------------------------------------------------- 1 | # MLP Codebase 2 | This is the codebase for the MLP experiments. 3 | For a description of the setup and how to run experiments, see the README.md in the root directory. 4 | -------------------------------------------------------------------------------- /senn_mlp/build.sh: -------------------------------------------------------------------------------- 1 | docker build . 
-t egwene 2 | -------------------------------------------------------------------------------- /senn_mlp/checkpoints/.placeholder: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/self-expanding-neural-networks/3480be01cbbfa46726af3a84dd9fb834d1ca979e/senn_mlp/checkpoints/.placeholder -------------------------------------------------------------------------------- /senn_mlp/config.yaml: -------------------------------------------------------------------------------- 1 | meta: 2 | logdir: /logs 3 | name: untitled 4 | data: 5 | root: /datasets 6 | defaults: 7 | root: /datasets 8 | checkpointing: 9 | directory: /checkpoints 10 | -------------------------------------------------------------------------------- /senn_mlp/data.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from itertools import islice 3 | 4 | import jax.numpy as jnp 5 | from torchvision.datasets import MNIST, FashionMNIST, CIFAR10 6 | import torchvision.transforms.functional as TF 7 | import torch.nn.functional as F 8 | 9 | from tqdm import tqdm 10 | 11 | def trfm(size, img, channels=1): 12 | t = TF.to_tensor(img) 13 | t = F.interpolate(t.unsqueeze(0), size=(size, size), mode="bilinear", align_corners=False) 14 | numpy_img = t.squeeze(0).permute(1,2,0).numpy() 15 | out = jnp.array(numpy_img) 16 | if out.shape[-1] != channels: 17 | if channels == 1: 18 | out = out.mean(axis=-1, keepdims=True) 19 | else: 20 | assert out.shape[-1] == 1, \ 21 | f"incompatible channel number: expected {channels} but got {out.shape[-1]}" 22 | out = jnp.tile(out, (3,)) 23 | return out 24 | 25 | def smallnist(N, classes, size, train=True, root="../datasets"): 26 | dataset = iter(MNIST(root, train=train, download=True, transform=partial(trfm, size))) 27 | imgs = [] 28 | labels = [] 29 | while len(imgs) < N: 30 | img, label = next(dataset) 31 | if label in classes: 32 | labels.append(classes.index(label)) 33 | imgs.append(img) 34 | else: 35 | continue 36 | return jnp.array(imgs), jnp.array(labels) 37 | 38 | def smallfnist(N, classes, size, train=True, root="../datasets"): 39 | dataset = iter(FashionMNIST(root, train=train, download=True, transform=partial(trfm, size))) 40 | imgs = [] 41 | labels = [] 42 | while len(imgs) < N: 43 | img, label = next(dataset) 44 | if label in classes: 45 | labels.append(classes.index(label)) 46 | imgs.append(img) 47 | pbar.update(1) 48 | else: 49 | continue 50 | return jnp.array(imgs), jnp.array(labels) 51 | 52 | def get_dataset(root, name, train, resolution): 53 | if name == "mnist": 54 | return MNIST(root, train=train, download=True, transform=partial(trfm, resolution)) 55 | elif name == "fmnist": 56 | return FashionMNIST(root, train=train, download=True, transform=partial(trfm, resolution)) 57 | elif name == "cifar10": 58 | return CIFAR10(root, train=train, download=True, transform=partial(trfm, resolution)) 59 | else: 60 | raise NotImplementedError(f"Dataset '{name}' not recognised.") 61 | 62 | def get_chunk(dataset, labels, remap, N, start=0): 63 | elems = tqdm(((d, remap(l)) for d, l in iter(dataset) if l in labels), total=N, desc="Compiling data tranch") 64 | imgs, labels = map(lambda gen: jnp.array(list(gen)), zip(*islice(elems, start, N))) 65 | return imgs, labels 66 | 67 | def cfg_tranch(defaults, tranch, resolution): 68 | def get(key): 69 | return tranch[key] if key in tranch else defaults[key].get() 70 | N = get('N') 71 | TN = get('TN') 72 | classes = 
get('classes') 73 | 74 | dataset = get('dataset') 75 | root = get('root') 76 | remap_val = get('remap') 77 | remap = lambda x: x if remap_val is None else remap_val[classes.index(x)] 78 | 79 | train = get_chunk(get_dataset(root, dataset, True, resolution), classes, remap, len(classes)*N) 80 | test = get_chunk(get_dataset(root, dataset, False, resolution), classes, remap, len(classes)*TN) 81 | 82 | return train, test 83 | 84 | def cfg_tranches(cfg, resolution): 85 | defaults = cfg['defaults'] 86 | return list(cfg_tranch(defaults, tranch, resolution) for tranch in cfg['tranches'].get()) 87 | 88 | def tranch_cat(tranches, index, train): 89 | tups = list(islice([tranch[0] if train else tranch[1] for tranch in tranches], index+1)) 90 | return tuple([jnp.concatenate(arrs, axis=0) for arrs in zip(*tups)]) 91 | -------------------------------------------------------------------------------- /senn_mlp/datasets/.placeholder: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/self-expanding-neural-networks/3480be01cbbfa46726af3a84dd9fb834d1ca979e/senn_mlp/datasets/.placeholder -------------------------------------------------------------------------------- /senn_mlp/experiment1/default_config.yaml: -------------------------------------------------------------------------------- 1 | meta: 2 | seed: 0 3 | propseed: 0 4 | task: 5 | type: regression 6 | out_size: 1 7 | in_size: 1 8 | hidden: [16, 16, 16, 16] 9 | data: 10 | N: 50 11 | TN: 1000 12 | net: 13 | capacities: [128] 14 | contents: [1] 15 | rational: true 16 | opt: 17 | max_epochs: 2000 18 | skip_first_n_batches: 20 19 | lr: 1.0e-1 20 | reduce_thresh: 1000 21 | max_accuracy: 0.99 22 | order: 100 23 | method: cg 24 | tikhonov: 1.0e-1 25 | batch_size: 50 26 | tau: null 27 | sqtau: 1000 28 | soln_tau: 29 | soln_sqtau: 30 | soln_adam: false 31 | 32 | weight_decay: 0.0e-3 33 | l2_regularization: 0.0e-3 34 | evo: 35 | pure_kfac: false 36 | proposals_per_layer: 100 37 | proposal_temperature: 1.0e+0 38 | layer_proposals_per_layer: 100 39 | steps: 300 40 | initial_lr: 3.0e-1 41 | thresh: 1.0e+0 42 | layer_cost_mul: 20. 
43 | recursive: true 44 | abs_thresh: 2.5e-3 45 | layer_abs_thresh: 2.5e-3 46 | cooldown: 30 47 | layer_cooldown: 3 48 | size_costing: false 49 | layer_eig_floor: 0.01 50 | total_size_scaling: false 51 | metrics: 52 | activation_histograms: false 53 | gradient_histograms: false 54 | natgrad_histograms: false 55 | checkpointing: 56 | enable: True 57 | cooldown: 1 58 | restore: True 59 | -------------------------------------------------------------------------------- /senn_mlp/experiment2/default_config.yaml: -------------------------------------------------------------------------------- 1 | meta: 2 | seed: 0 3 | propseed: 0 4 | task: 5 | type: classification 6 | sklearn: True 7 | out_size: 1 8 | in_size: 2 9 | data: 10 | name: make_moons 11 | N: 200 12 | TN: 500 13 | noise: 0.3 14 | train_seed: 0 15 | test_seed: 1 16 | net: 17 | capacities: [16, 16, 16] 18 | contents: [null, null, null] 19 | rational: true 20 | opt: 21 | max_epochs: 400 22 | skip_first_n_batches: 20 23 | lr: 1.0e-1 24 | reduce_thresh: 1000 25 | max_accuracy: 0.99 26 | order: 100 27 | method: cg 28 | tikhonov: 1.0e-1 29 | batch_size: 200 30 | tau: null 31 | sqtau: 1000 32 | soln_tau: 33 | soln_sqtau: 34 | soln_adam: false 35 | 36 | weight_decay: 0.0e-3 37 | l2_regularization: 0.0e-3 38 | evo: 39 | pure_kfac: false 40 | proposals_per_layer: 100 41 | proposal_temperature: 1.0e+0 42 | layer_proposals_per_layer: 100 43 | steps: 300 44 | initial_lr: 3.0e-1 45 | thresh: 1.0e+0 46 | layer_cost_mul: 2. 47 | recursive: true 48 | abs_thresh: 2.5e-3 49 | layer_abs_thresh: 2.5e-3 50 | cooldown: 30 51 | layer_cooldown: 3 52 | size_costing: false 53 | layer_eig_floor: 0.01 54 | total_size_scaling: false 55 | metrics: 56 | activation_histograms: false 57 | gradient_histograms: false 58 | natgrad_histograms: false 59 | checkpointing: 60 | enable: True 61 | cooldown: 1 62 | restore: True 63 | -------------------------------------------------------------------------------- /senn_mlp/experiment3/default_config.yaml: -------------------------------------------------------------------------------- 1 | meta: 2 | seed: 2 3 | propseed: 0 4 | task: 5 | type: classification 6 | resolution: 24 7 | 8 | out_size: 1 9 | hidden: [16, 16, 16, 16] 10 | data: 11 | N: 300 12 | TN: 1000 13 | defaults: 14 | N: 6000 15 | TN: 1000 16 | classes: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] 17 | dataset: mnist 18 | remap: null 19 | tranches: 20 | - {} 21 | net: 22 | capacities: [64, 64, 64] 23 | contents: [10, null, null] 24 | rational: true 25 | opt: 26 | max_epochs: 2000 27 | skip_first_n_batches: 20 28 | lr: 0.1 29 | reduce_thresh: 1000 30 | max_accuracy: 0.99 31 | order: 100 32 | method: cg 33 | tikhonov: 1.0e-1 34 | batch_size: 1024 35 | tau: null 36 | sqtau: 1000 37 | soln_tau: 38 | soln_sqtau: 39 | soln_adam: false 40 | 41 | weight_decay: 1.0e-3 42 | l2_regularization: 0.0e-3 43 | evo: 44 | pure_kfac: false 45 | proposals_per_layer: 10000 46 | layer_proposals_per_layer: 100 47 | steps: 10 48 | thresh: 0.007 49 | layer_cost_mul: 60. 
50 | recursive: true 51 | abs_thresh: 2.5e-1 52 | layer_abs_thresh: 2.5e-1 53 | cooldown: 10 54 | layer_cooldown: 1 55 | size_costing: false 56 | layer_eig_floor: 0.01 57 | total_size_scaling: false 58 | metrics: 59 | activation_histograms: false 60 | gradient_histograms: false 61 | natgrad_histograms: false 62 | checkpointing: 63 | enable: True 64 | cooldown: 1 65 | restore: True 66 | -------------------------------------------------------------------------------- /senn_mlp/experiment4/default_config.yaml: -------------------------------------------------------------------------------- 1 | meta: 2 | seed: 0 3 | propseed: 0 4 | task: 5 | type: classification 6 | resolution: 28 7 | 8 | out_size: 1 9 | hidden: [16, 16, 16, 16] 10 | data: 11 | N: 300 12 | TN: 1000 13 | defaults: 14 | N: 4800 15 | TN: 1000 16 | classes: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] 17 | dataset: mnist 18 | remap: null 19 | tranches: 20 | - {} 21 | net: 22 | capacities: [128] 23 | contents: [10] 24 | rational: true 25 | opt: 26 | max_epochs: 2000 27 | skip_first_n_batches: 20 28 | lr: 0.1 29 | reduce_thresh: 1000 30 | max_accuracy: 0.99 31 | order: 100 32 | method: cg 33 | tikhonov: 1.0e-1 34 | batch_size: 1024 35 | tau: null 36 | sqtau: 1000 37 | soln_tau: 38 | soln_sqtau: 39 | soln_adam: false 40 | 41 | weight_decay: 1.0e-3 42 | l2_regularization: 0.0e-3 43 | evo: 44 | pure_kfac: false 45 | proposals_per_layer: 100 46 | proposal_temperature: 1.0e1 47 | layer_proposals_per_layer: 100 48 | steps: 10 49 | thresh: 0.03 50 | layer_cost_mul: 20. 51 | recursive: true 52 | abs_thresh: 2.5e-3 53 | layer_abs_thresh: 2.5e-3 54 | cooldown: 10 55 | layer_cooldown: 3 56 | size_costing: false 57 | layer_eig_floor: 0.01 58 | total_size_scaling: false 59 | metrics: 60 | activation_histograms: false 61 | gradient_histograms: false 62 | natgrad_histograms: false 63 | checkpointing: 64 | enable: True 65 | cooldown: 1 66 | restore: True 67 | -------------------------------------------------------------------------------- /senn_mlp/experiment_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import confuse 3 | import argparse 4 | from rtpt import RTPT 5 | from datetime import datetime 6 | from tensorflow import summary 7 | 8 | def get_rtpt(name, max_iter=1): 9 | rtpt = RTPT(name_initials='RM', experiment_name=name, max_iterations=max_iter) 10 | rtpt.start() 11 | return rtpt 12 | 13 | def get_cfg(script_name, args=[]): 14 | """args: [('--argname', type, 'dest.var'), ...]""" 15 | 16 | cfg = confuse.Configuration('experiment') 17 | cfg.set_file(f'./{script_name}/default_config.yaml', base_for_paths=True) 18 | OVERRIDE = './config.yaml' 19 | if os.path.isfile(OVERRIDE): 20 | cfg.set_file(OVERRIDE, base_for_paths=True) 21 | parser = argparse.ArgumentParser() 22 | base_args = [ 23 | ('--name', str, 'meta.name'), 24 | ('--seed', int, 'meta.seed'), 25 | ('--epochs', int, 'opt.max_epochs'), 26 | ('--thresh', str, 'evo.thresh') 27 | ] 28 | for n, t, d in base_args + args: 29 | parser.add_argument(n, type=t, dest=d) 30 | cfg.set_args(parser.parse_args(), dots=True) 31 | return cfg 32 | 33 | def set_writer(cfg): 34 | now = datetime.now() 35 | exp_name = cfg['meta']['name'].get(f"{now.date()}_{now.time()}") 36 | logdir = cfg['meta']['logdir'].get() 37 | writer = summary.create_file_writer(f"{logdir}/{exp_name}") 38 | writer.set_as_default() 39 | return writer 40 | -------------------------------------------------------------------------------- /senn_mlp/jaxutils.py: 
-------------------------------------------------------------------------------- 1 | from jax import random 2 | 3 | def key_iter(seed=0): 4 | key = random.PRNGKey(seed) 5 | while True: 6 | key, key_ = random.split(key) 7 | yield key_ 8 | -------------------------------------------------------------------------------- /senn_mlp/langevin.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import jax 3 | from jax import flatten_util 4 | import jax.numpy as jnp 5 | from jax.tree_util import tree_map as jtm, tree_reduce as jtr 6 | from compose import compose 7 | from functools import partial 8 | 9 | 10 | def tree_normal_like(tree, key): 11 | treevec, unflatten = jax.flatten_util.ravel_pytree(tree) 12 | noise = jax.random.normal(key, treevec.shape) 13 | return unflatten(noise) 14 | 15 | 16 | def tree_inner(tree1, tree2): 17 | prods = jtm(jnp.multiply, tree1, tree2) 18 | sumprods = jtm(jnp.sum, prods) 19 | return jtr(jnp.add, sumprods) 20 | 21 | 22 | def mala_step(lossgradfn, priorvar, old_state, key, lr, temp=1e0, legacy=False): 23 | def reglossgrad(state): 24 | loss, grad = lossgradfn(state) 25 | prior_loss = jtr(jnp.add, jtm(lambda s, p: jnp.sum(s**2 / p), state, priorvar)) 26 | return prior_loss + loss, grad 27 | 28 | noise_key, accept_key = jax.random.split(key) 29 | old_loss, old_grad = reglossgrad(old_state) 30 | noise = tree_normal_like(old_state, noise_key) 31 | 32 | def _delta(prior, state, grad, noise): 33 | return -lr * (grad / temp + state / prior) + jnp.sqrt(2. * lr) * noise * jnp.sqrt(prior) 34 | 35 | delta = lambda s, g, n: jtm(_delta, priorvar, s, g, n) 36 | 37 | half_delta = jtm(partial(jnp.multiply, 0.5), delta(old_state, old_grad, noise)) 38 | half_state = jtm(jnp.add, old_state, half_delta) 39 | half_loss, half_grad = reglossgrad(half_state) 40 | 41 | prop_state = jtm(jnp.add, old_state, delta(half_state, half_grad, noise)) 42 | prop_loss, prop_grad = reglossgrad(prop_state) 43 | 44 | mean_grad = jtm(compose(partial(jnp.multiply, 0.5), jnp.add), old_grad, prop_grad) 45 | 46 | actual_improvement = old_loss - prop_loss 47 | expected_improvement = -tree_inner(mean_grad, delta(old_state, mean_grad, noise)) 48 | if not legacy: 49 | improvement_gap = (actual_improvement - expected_improvement)/temp 50 | else: 51 | improvement_gap = actual_improvement - expected_improvement 52 | 53 | accept_prob = jnp.minimum(1., jnp.exp(improvement_gap)) 54 | accept_decision = jax.random.bernoulli(accept_key, p=accept_prob, shape=()) 55 | next_state = jtm(partial(jax.lax.select, accept_decision), prop_state, old_state) 56 | return next_state, accept_prob 57 | 58 | 59 | def mala_steps(lossgradfn, priorvar, state, key, lr, steps, temp=1e0, legacy=False): 60 | def f(state, key): 61 | return mala_step(lossgradfn, priorvar, state, key, lr, temp=temp, legacy=legacy) 62 | 63 | keys = jax.random.split(key, steps) 64 | final_state, probs = jax.lax.scan(f, init=state, xs=keys) 65 | accept_rate = jnp.mean(probs) 66 | return final_state, accept_rate 67 | 68 | 69 | def vgd_step(lossgradfn, old_state, lr): 70 | loss, grad = lossgradfn(old_state) 71 | delta = jtm(partial(jnp.multiply, -lr), grad) 72 | new_state = jtm(jnp.add, old_state, delta) 73 | return new_state 74 | -------------------------------------------------------------------------------- /senn_mlp/logs/.placeholder: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ml-research/self-expanding-neural-networks/3480be01cbbfa46726af3a84dd9fb834d1ca979e/senn_mlp/logs/.placeholder -------------------------------------------------------------------------------- /senn_mlp/nets.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Sequence, Optional 2 | from abc import ABC, abstractmethod 3 | from functools import partial 4 | from itertools import count 5 | import numpy as np 6 | 7 | import jax 8 | from jax import lax, random, numpy as jnp 9 | from jax.tree_util import tree_map as jtm, tree_reduce as jtr 10 | import flax 11 | from flax import linen as nn 12 | from flax.linen import initializers 13 | from flax.traverse_util import t_identity as trav_identity 14 | 15 | def_prec = 'highest' 16 | 17 | 18 | 19 | def pad_axis(arr, count, axis=0): 20 | if count == 0: 21 | return arr 22 | assert count > 0 23 | pad_shape = arr.shape[:axis] + (count,) + arr.shape[axis:][1:] 24 | return jnp.concatenate([arr, jnp.zeros(pad_shape)], axis=axis) 25 | 26 | def pad_target(current): 27 | return int(2**jnp.ceil(jnp.log2(current+1))) 28 | 29 | def update_dict(target, args, func, is_leaf): 30 | for key, value in args.items(): 31 | if is_leaf(value): 32 | target[key] = func(target[key], value) 33 | else: 34 | target[key] = update_dict(target[key], value, func, is_leaf) 35 | return target 36 | 37 | 38 | 39 | class Identity(nn.Module): 40 | @nn.compact 41 | def __call__(self, x): 42 | return x 43 | 44 | 45 | 46 | class Passive(nn.Module): 47 | 48 | def insert_feature(self, feature, index, origin=0): 49 | func = lambda v, f: v.at[..., index].set(f[..., origin]) 50 | return self.variables.copy({ 51 | 'params': jtm(func, self.variables['params'].unfreeze(), feature['params']) 52 | }) 53 | 54 | def pad_features(self, count): 55 | trav = trav_identity['params'].tree() 56 | func = lambda p: pad_axis(p, count, axis=-1) 57 | return trav.update(func, self.variables.unfreeze()) 58 | 59 | def null(self): 60 | return jnp.bool_(True) 61 | 62 | def restrict_params(self, in_dim, out_dim): 63 | params = self.variables["params"] 64 | restrict = lambda arr: arr.at[..., out_dim:].set(0.) 65 | return jtm(restrict, params) 66 | 67 | 68 | 69 | class DConv(Passive): 70 | kernel_size: int = 3 71 | init_identity: bool = False 72 | #N.B. kernel param has shape [3,3,1,F] 73 | 74 | def identity_kernel(self, F): 75 | K = jnp.zeros((self.kernel_size, self.kernel_size, 1, F)) 76 | cidx = self.kernel_size//2 77 | return K.at[cidx, cidx,...].set(1.) 
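# Minimal standalone sketch (not part of DConv; the helper name and values below are
# illustrative only): identity_kernel places a single 1 at the centre tap of each
# feature, so a depthwise conv initialised with that kernel reproduces its input.
def _identity_kernel_demo(k=3, F=4):
    import jax.numpy as jnp
    from flax import linen as nn
    from jax.random import PRNGKey
    kernel = jnp.zeros((k, k, 1, F)).at[k // 2, k // 2, ...].set(1.0)
    conv = nn.Conv(F, (k, k), padding='SAME', feature_group_count=F,
                   kernel_init=lambda *_: kernel, use_bias=False)
    x = jnp.ones((1, 8, 8, F))
    y = conv.apply(conv.init(PRNGKey(0), x), x)
    assert jnp.allclose(y, x)  # centre-tap-only depthwise kernel acts as the identity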
78 | 79 | @nn.compact 80 | def __call__(self, x): 81 | dim = x.shape[-1] 82 | kernel = (self.kernel_size,)*2 83 | if self.init_identity: 84 | init = lambda rng, shape, dtype: self.identity_kernel(x.shape[-1]) 85 | x = nn.Conv(dim, kernel, padding='SAME', feature_group_count=dim, kernel_init=init, precision=def_prec)(x) 86 | else: 87 | lec_n = initializers.lecun_uniform() 88 | unif = initializers.uniform() 89 | def init(rng, shape, dtype): 90 | default = lec_n(rng, shape, dtype) 91 | return default / jnp.sqrt((default**2).sum(axis=-2, keepdims=True)) 92 | x = nn.Conv(dim, kernel, padding='SAME', feature_group_count=dim, kernel_init=init, bias_init=unif, precision=def_prec)(x) 93 | return x 94 | 95 | def identity_params(self, F): 96 | params = self.variables["params"] 97 | return {'Conv_0': { 98 | 'bias': jnp.zeros(F), 99 | 'kernel': self.identity_kernel(F) 100 | }} 101 | 102 | 103 | 104 | class Rational1D(Passive): 105 | residual: bool = True 106 | init_identity: bool = False 107 | epsilon: float = 1e-8 108 | 109 | @nn.compact 110 | def __call__(self, x): 111 | init = initializers.zeros if self.init_identity else random.normal 112 | v = self.param('w_vec', init, x.shape[-1:]) 113 | c = self.param('w_const', init, x.shape[-1:]) 114 | sumsq = v**2 + c**2 + self.epsilon 115 | if self.residual: 116 | init = initializers.ones if self.init_identity else random.normal 117 | l = self.param('w_lin', init, x.shape[-1:]) 118 | sumsq = sumsq + l**2 119 | norm = jnp.sqrt(sumsq) 120 | #DISABLE NORMING: 121 | # norm = jnp.ones_like(norm) 122 | den = 1/(1 + x**2) 123 | self.sow("intermediates", "lin", x) 124 | self.sow("intermediates", "odd", 2.*x*den) 125 | self.sow("intermediates", "even", den) 126 | num = 2.*v*x + c 127 | out = num*den 128 | if self.residual: 129 | out = out + l*x 130 | out = out / norm 131 | self.sow("intermediates", "activations", out) 132 | return out 133 | 134 | def null(self): 135 | null = lambda name: self.variables["params"][name] == 0. 136 | return null('w_vec') & null('w_const') & null('w_lin') 137 | 138 | def identity_params(self, F, rank=None): 139 | params = { 140 | 'w_vec': jnp.zeros(F), 141 | 'w_const': jnp.zeros(F), 142 | 'w_lin': jnp.ones(F) 143 | } 144 | if rank is not None: 145 | params['w_lin'] = params['w_lin'].at[rank:].set(0.) 
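# Beyond `rank`, w_vec, w_const and w_lin are all zero, so those padded features output exactly zero and are reported as null() until they are activated.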
146 | return params 147 | 148 | 149 | 150 | 151 | class Active(ABC, Passive): 152 | @abstractmethod 153 | def pad_inputs(self, count): 154 | pass 155 | 156 | @abstractmethod 157 | def input_size(self): 158 | pass 159 | 160 | @abstractmethod 161 | def output_size(self): 162 | pass 163 | 164 | def pad_pow2(self, override=False): 165 | null = self.null() & ~jnp.bool_(override) 166 | target = pad_target((~null).sum()) 167 | current = len(null) 168 | deficit = target - current 169 | if deficit > 0: 170 | params = self.pad_features(deficit) 171 | return params, target 172 | else: 173 | return self.variables, current 174 | 175 | 176 | 177 | class Splittable(Active): 178 | @abstractmethod 179 | def split_basis_size(self): 180 | pass 181 | 182 | @abstractmethod 183 | def split_params(self, basis, shift): 184 | pass 185 | 186 | # @abstractmethod 187 | # def split_module(self): 188 | # pass 189 | 190 | 191 | 192 | class DDense(nn.Dense, Splittable): 193 | init_zero: bool = False 194 | def __init__(self, features, init_zero = False): 195 | if init_zero: 196 | super().__init__(features, kernel_init=initializers.zeros, precision=def_prec) 197 | else: 198 | super().__init__(features, bias_init=initializers.normal(1e0), precision=def_prec) 199 | # init = initializers.lecun_uniform() 200 | # unif = initializers.uniform() 201 | # super().__init__(features, bias_init=unif, kernel_init=init, precision=def_prec) 202 | 203 | def __call__(self, x): 204 | self.sow("intermediates", "preactivations", x) 205 | x = nn.Dense.__call__(self, x) 206 | return x 207 | 208 | def pad_inputs(self, count): 209 | return self.variables.copy({ 210 | "params": self.variables["params"].copy({ 211 | "kernel": pad_axis(self.variables["params"]["kernel"], count, axis=-2) 212 | }) 213 | }) 214 | 215 | def input_size(self): 216 | return self.variables["params"]["kernel"].shape[-2] 217 | 218 | def output_size(self): 219 | return self.variables["params"]["kernel"].shape[-1] 220 | 221 | def null(self): 222 | null = (self.variables["params"]["kernel"] == 0.).all(axis=-2) 223 | return null & (self.variables["params"]["bias"] == 0.) 224 | 225 | def split_basis_size(self): 226 | return min(self.input_size(), self.output_size()) 227 | 228 | # def split_params(self, basis, shift, kill_in, kill_out): 229 | # old = self.variables["params"]["kernel"] 230 | # in_dim, out_dim = old.shape 231 | # # basis_inv = jnp.linalg.pinv(basis) 232 | # # if in_dim-kill_in < out_dim-kill_out: 233 | # if in_dim < out_dim: 234 | # if kill_in > 0: 235 | # basis = basis.at[-kill_in:,:].set(0.) 236 | # basis = basis.at[:,-kill_in:].set(0.) 237 | # shift = shift.at[-kill_in:].set(0.) 238 | # basis_inv = jnp.linalg.pinv(basis) 239 | # first = basis 240 | # second = basis_inv@old 241 | # intermediate = in_dim 242 | # else: 243 | # if kill_out > 0: 244 | # basis = basis.at[-kill_out:,:].set(0.) 245 | # basis = basis.at[:,-kill_out:].set(0.) 246 | # shift = shift.at[-kill_out:].set(0.) 
247 | # basis_inv = jnp.linalg.pinv(basis) 248 | # first = old@basis 249 | # second = basis_inv 250 | # intermediate = out_dim 251 | # first_params = { 252 | # "bias": shift, 253 | # "kernel": first 254 | # } 255 | # second_params = { 256 | # "bias": self.variables["params"]["bias"] - shift@second, 257 | # "kernel": second 258 | # } 259 | # return first_params, second_params 260 | 261 | def split_params(self, basis, inv_basis, shift): 262 | old = self.variables["params"]["kernel"] 263 | first = basis 264 | # print(old.shape) 265 | # print(inv_basis.shape) 266 | second = inv_basis@old[:inv_basis.shape[1],:] 267 | first_params = { 268 | "bias": shift, 269 | "kernel": first, 270 | } 271 | second_params = { 272 | "bias": self.variables["params"]["bias"] - shift@second, 273 | "kernel": second 274 | } 275 | return first_params, second_params 276 | 277 | def split_module(self): 278 | intermediate = min(self.input_size(), self.output_size()) 279 | return DDense(intermediate), DDense(self.output_size()) 280 | 281 | def identity_params(self): 282 | params = self.variables["params"] 283 | old_kernel = params["kernel"] 284 | diag = jnp.diag_indices(min(old_kernel.shape)) 285 | return params.copy({ 286 | "bias": jnp.zeros_like(params["bias"]), 287 | "kernel": jnp.zeros_like(old_kernel).at[diag].set(1.) 288 | }) 289 | 290 | def restrict_params(self, in_dim, out_dim): 291 | params = super().restrict_params(in_dim, out_dim) 292 | # print(params) 293 | return params.copy({ 294 | "kernel": params["kernel"].at[in_dim:, :].set(0.) 295 | }) 296 | 297 | def extract_activations(self): 298 | return self.variables["intermediates"]["preactivations"] 299 | 300 | def extract_out_grads(self): 301 | return self.variables["params"]["bias"] 302 | 303 | 304 | 305 | 306 | class Layer(Splittable): 307 | features: Optional[int] 308 | make_invariant: Sequence[nn.Module] = () 309 | make_linear: Callable[[int], Active] = DDense 310 | make_equivariant: Sequence[Callable[[], Passive]] = (partial(Rational1D, True),) 311 | 312 | def setup(self): 313 | F = self.features 314 | self.linear = None if F is None else self.make_linear(F) 315 | self.equivariant = None if F is None else [func() for func in self.make_equivariant] 316 | self.invariant = [func() for func in self.make_invariant] 317 | 318 | nn.nowrap 319 | def dormant(self): 320 | return self.features is None 321 | 322 | def __call__(self, x): 323 | if self.features is not None: 324 | for module in self.invariant: 325 | x = module(x) 326 | x = self.linear(x) 327 | for module in self.equivariant: 328 | x = module(x) 329 | return x 330 | 331 | def pad_inputs(self, count): 332 | return self.variables.copy({ 333 | "params": self.variables["params"].copy({ 334 | "linear": self.linear.pad_inputs(count)["params"] 335 | }) 336 | }) 337 | 338 | def null(self): 339 | if self.dormant(): 340 | return None 341 | else: 342 | out = self.linear.null() 343 | for i, module in enumerate(self.equivariant): 344 | if f'equivariant_{i}' in self.variables['params']: 345 | out = out & module.null() 346 | return out 347 | 348 | def input_size(self): 349 | return self.linear.input_size() 350 | 351 | def output_size(self): 352 | return self.linear.output_size() 353 | 354 | def split_basis_size(self): 355 | return self.linear.split_basis_size() 356 | 357 | # def split_params(self, basis, shift, kill_in, kill_out): 358 | # first, second = self.linear.split_params(basis, shift, kill_in, kill_out) 359 | # # F = first['kernel'].shape[-1] 360 | # # equivariant_params = [ 361 | # # eq.identity_params(F) for eq 
in self.equivariant 362 | # # ] 363 | # # first = self.variables["params"].copy({ 364 | # # "linear": first 365 | # # }) 366 | # # for i, p in enumerate(equivariant_params): 367 | # # first = first.copy({ 368 | # # f"equivariant_{i}": p 369 | # # }) 370 | # second = self.variables["params"].copy({ 371 | # "linear": second 372 | # }) 373 | # # assert "invariant_0" not in self.variables["params"] 374 | # return first, second 375 | 376 | def split_params(self, basis, inv_basis, shift): 377 | first, second = self.linear.split_params(basis, inv_basis, shift) 378 | second = self.variables["params"].copy({ 379 | "linear": second 380 | }) 381 | return first, second 382 | 383 | def split_identity(self, linear_params, rank=None): 384 | params = self.variables['params'] 385 | params = params.copy({ 386 | 'linear': linear_params, 387 | }) 388 | for i, eq in enumerate(self.equivariant): 389 | params = params.copy({ 390 | f'equivariant_{i}': eq.identity_params(self.features, rank=rank) 391 | }) 392 | assert 'invariant_0' not in params 393 | return params 394 | 395 | def restrict_params(self, in_dim, out_dim): 396 | params = self.variables["params"] 397 | params = params.copy({ 398 | 'linear': self.linear.restrict_params(in_dim, out_dim) 399 | }) 400 | params = params.copy({ 401 | f'equivariant_{i}': eq.restrict_params(in_dim, out_dim) \ 402 | for i, eq in enumerate(self.equivariant) \ 403 | if f'equivariant_{i}' in params 404 | }) 405 | assert 'invariant_0' not in params 406 | return params 407 | 408 | def identity_params(self): 409 | return self.split_identity(self.linear.identity_params()) 410 | 411 | 412 | 413 | class Layers(nn.Module): 414 | layers: Sequence[Layer] 415 | 416 | def features(self): 417 | return [layer.features for layer in self.layers if layer.features] 418 | 419 | def nulls(self): 420 | return [layer.null() for layer in self.layers if layer.features] 421 | 422 | def lift(self, index, func, *args): 423 | return func(self.layers[index], *args) 424 | 425 | def __call__(self, x): 426 | for L in self.layers: 427 | x = L(x) 428 | return x 429 | 430 | def pad_features(self, index, count): 431 | variables = self.variables 432 | def layer_update(i, new_params): 433 | nonlocal variables 434 | variables = variables.copy({ 435 | 'params': variables['params'].copy({ 436 | f'layers_{i}': new_params['params'] 437 | }) 438 | }) 439 | layer_update(index, self.layers[index].pad_features(count)) 440 | subsequent = index + 1 441 | if subsequent < len(self.layers): 442 | layer_update(subsequent, self.layers[subsequent].pad_inputs(count)) 443 | return variables 444 | 445 | def pad_pow2(self, index, override=False): 446 | features = self.features() 447 | increase = None 448 | old = features[index] 449 | updated, features[index] = self.layers[index].pad_pow2(override) 450 | variables = self.variables 451 | def layer_update(i, new_params): 452 | nonlocal variables 453 | variables = variables.copy({ 454 | "params": variables["params"].copy({ 455 | f"layers_{i}": new_params["params"] 456 | }) 457 | }) 458 | layer_update(index, updated) 459 | layer_update(index+1, self.layers[index+1].pad_inputs(features[index] - old)) 460 | return variables, features 461 | 462 | def split_layer(self, index, basis, shift, inserter): 463 | kill_out = sum(self.layers[index].null()) 464 | kill_in = sum(self.layers[index-1].null()) if index > 0 else 0 465 | first_linear, second_layer = self.layers[index].split_params(basis, shift, kill_in, kill_out) 466 | first_layer = inserter(first_linear) 467 | params = self.variables['params'] 
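# Shift every existing entry layers_{i} with i > index up to layers_{i+1}, iterating from the back so nothing is overwritten, then write the split halves into slots index and index+1.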
468 | for i in reversed(range(index+1, len(self.layers))): 469 | params = params.copy({ 470 | f'layers_{i+1}': params[f'layers_{i}'] 471 | }) 472 | params = params.copy({ 473 | f'layers_{index}': first_layer, 474 | f'layers_{index+1}': second_layer 475 | }) 476 | return params 477 | 478 | @nn.nowrap 479 | @staticmethod 480 | def get_nearest_active(dims, index): 481 | assert dims[index] is None 482 | for i, dim in enumerate(dims[index:]): 483 | if dim is not None: 484 | succeeding = index + i 485 | break 486 | preceding = None 487 | for i, dim in enumerate(reversed(dims[:index+1])): 488 | if dim is not None: 489 | preceding = index - i 490 | break 491 | return preceding, succeeding 492 | 493 | def activate_layer(self, dims, index, basis, shift): 494 | preceding, succeeding = self.get_nearest_active(dims, index) 495 | input_dim = self.layers[0].linear.input_size() 496 | # assert input_dim == 1 497 | new_rank = sum(~self.layers[preceding].null()) if preceding is not None else input_dim 498 | new_input_size = self.layers[index-1].linear.output_size() if index != 0 else input_dim 499 | basis = basis.at[new_rank:,:].set(0.) 500 | basis = basis.at[:,new_rank:].set(0.) 501 | shift = shift.at[new_rank:].set(0.) 502 | inv_basis = jnp.linalg.pinv(basis) 503 | basis = basis[:new_input_size,:] 504 | # kill_in = sum(self.layers[preceding].null()) if preceding is not None else 0 505 | # kill_out = sum(self.layers[succeeding].null()) 506 | # print(inv_basis.shape) 507 | new_linear, new_succeeding = self.layers[succeeding] \ 508 | .split_params(basis, 509 | inv_basis, 510 | shift) 511 | new_layer = self.layers[index].split_identity(new_linear, rank=new_rank) 512 | 513 | return self.variables.copy({ 514 | 'params': self.variables['params'].copy({ 515 | f'layers_{index}': new_layer, 516 | f'layers_{succeeding}': new_succeeding 517 | }) 518 | }) 519 | 520 | def restrict_params(self, dims): 521 | state = self.variables 522 | params = state["params"] 523 | 524 | def modify(layer, in_dim, out_dim): 525 | if out_dim is None: 526 | return layer.identity_params() 527 | else: 528 | return layer.restrict_params(in_dim, out_dim) 529 | 530 | def get_key(i): 531 | return f"layers_{i}" 532 | 533 | input_size = self.layers[0].input_size() 534 | in_dims = [input_size] + dims[:-1] 535 | out_dims = dims 536 | 537 | return state.copy({ 538 | "params": params.copy({ 539 | get_key(i): modify(layer, in_dim, out_dim) \ 540 | for i, (layer, in_dim, out_dim) in enumerate(zip(self.layers, in_dims, out_dims)) 541 | }) 542 | }) 543 | 544 | def restrict_grad(self, dims): 545 | params = self.variables["params"] 546 | # print(params) 547 | 548 | input_size = self.layers[0].input_size() 549 | # in_dims = [input_size] + dims[:-1] 550 | in_dims = [] 551 | carry = input_size 552 | for out_dim in dims: 553 | if out_dim is None: 554 | in_dims.append(None) 555 | else: 556 | in_dims.append(carry) 557 | carry = out_dim 558 | 559 | for i, (in_dim, out_dim) in enumerate(zip(in_dims, dims)): 560 | key = f"layers_{i}" 561 | if out_dim is None: 562 | params = params.copy({ 563 | key: jtm(lambda arr: arr*0., params[key]) 564 | }) 565 | else: 566 | params = params.copy({ 567 | key: self.layers[i].restrict_params(in_dim, out_dim) 568 | }) 569 | # print(params) 570 | # print(dims) 571 | # print('Done!') 572 | grad = self.variables.copy({ 573 | "params": params 574 | }) 575 | return grad 576 | 577 | def extract_activations(self): 578 | return [layer.linear.extract_activations() for layer in self.layers] 579 | 580 | def extract_out_grads(self): 581 | return 
[layer.linear.extract_out_grads() for layer in self.layers] 582 | 583 | 584 | 585 | 586 | class Seq(nn.Module): 587 | modules: Sequence[nn.Module] 588 | 589 | @nn.compact 590 | def __call__(self, x): 591 | for m in modules: 592 | x = m(x) 593 | return x 594 | 595 | 596 | 597 | def push_tangent(func, t, x): 598 | return jax.jvp(func, [x], [t]) 599 | 600 | def push_curvature(func, t, x): 601 | return jax.jvp(partial(push_tangent, func, t), [x], [t]) 602 | 603 | def reject_from(a, b): 604 | # return the component of a orthogonal to b 605 | return a - b * jnp.sum(a*b) / jnp.sum(b*b) 606 | 607 | def combine(a, b, c): 608 | # return coefficients for 'a' and 'b' to produce least squares solution for c 609 | # only accurate up to a scale factor 610 | alpha = (b*b).sum() * (a*c).sum() - (a*b).sum() * (b*c).sum() 611 | beta = (a*a).sum() * (b*c).sum() - (a*b).sum() * (a*c).sum() 612 | return alpha, beta 613 | 614 | def tree_length(v): 615 | tsq = jtm(lambda v: jnp.sum(v**2), v) 616 | ssq = jtr(lambda a, b: a + b, tsq) 617 | return jnp.sqrt(ssq) 618 | 619 | def pass_one(func, lossfn, params, p, x): 620 | func_x = lambda params: func(params, x) 621 | (y, Jp),(_, J2p) = push_curvature(func_x, p, params) 622 | # y, JT = jax.vjp(func_x, params) 623 | (loss, dL), (Jp_prod_dL, dLdJp) = jax.jvp(jax.value_and_grad(lossfn), [y], [Jp]) 624 | 625 | dL_p = Jp_prod_dL 626 | d2L_p = jnp.sum(J2p*dL) + jnp.sum(Jp*dLdJp) 627 | 628 | return y, loss, dL_p, d2L_p, dL, Jp, Jp_prod_dL#, JT 629 | 630 | def pass_two(func, params, err, x): 631 | func_x = lambda params: func(params, x) 632 | _, JT = jax.vjp(func_x, params) 633 | return JT(err) 634 | 635 | def pass_three(func, params, p, x): 636 | func_x = lambda params: func(params, x) 637 | return push_tangent(func_x, p[0], params)[1] 638 | 639 | def batch_ng(func, lossfn, metric, params, x, labels, p, axis=0): 640 | y, loss, dL_p, d2L_p, dL, Jp, corr = jax.vmap(lambda x, l: pass_one(func, partial(lossfn, l), params, p, x), 0)(x, labels) 641 | 642 | orth_grad = reject_from(dL, Jp) 643 | im_reject = jax.vmap(partial(pass_two, func, params), 0)(orth_grad, x) 644 | delta_p = jtm(lambda v: jnp.sum(v, axis=0), im_reject) 645 | delta_Jp = jax.vmap(partial(pass_three, func, params, delta_p), 0)(x) 646 | 647 | alpha, beta = combine(Jp, delta_Jp, dL) 648 | pu_rescale = beta/alpha * tree_length(p) 649 | p_update = jtm(lambda v: v * pu_rescale, delta_p) 650 | 651 | #curv = len(d2L_p)*jnp.sqrt(jnp.mean(d2L_p**2)) 652 | curv = jnp.abs(jnp.sum(d2L_p)) 653 | #curv = jnp.sum(jnp.abs(d2L_p)) 654 | p_rescale = -jnp.sum(dL_p) / curv 655 | param_update = jtm(lambda v: v * p_rescale, p) 656 | 657 | speed = jnp.sqrt((Jp**2).sum()) 658 | utility = corr.sum() / jnp.sqrt((dL**2).sum()) * speed 659 | distance = speed / p_rescale 660 | 661 | avg_loss = jnp.mean(loss) 662 | if metric is None: 663 | return avg_loss, param_update, p_update, utility 664 | else: 665 | return avg_loss, param_update, p_update, utility, metric(y, labels), curv, distance 666 | 667 | #goodness = dL . 
Jp / (|dL| |Jp|) 668 | -------------------------------------------------------------------------------- /senn_mlp/nets_legacy.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Sequence, Optional 2 | from abc import ABC, abstractmethod 3 | from functools import partial 4 | from itertools import count 5 | import numpy as np 6 | 7 | import jax 8 | from jax import lax, random, numpy as jnp 9 | from jax.tree_util import tree_map as jtm, tree_reduce as jtr 10 | import flax 11 | from flax import linen as nn 12 | from flax.linen import initializers 13 | from flax.traverse_util import t_identity as trav_identity 14 | 15 | def_prec = 'highest' 16 | 17 | 18 | 19 | def pad_axis(arr, count, axis=0): 20 | if count == 0: 21 | return arr 22 | assert count > 0 23 | pad_shape = arr.shape[:axis] + (count,) + arr.shape[axis:][1:] 24 | return jnp.concatenate([arr, jnp.zeros(pad_shape)], axis=axis) 25 | 26 | def pad_target(current): 27 | return int(2**jnp.ceil(jnp.log2(current+1))) 28 | 29 | def update_dict(target, args, func, is_leaf): 30 | for key, value in args.items(): 31 | if is_leaf(value): 32 | target[key] = func(target[key], value) 33 | else: 34 | target[key] = update_dict(target[key], value, func, is_leaf) 35 | return target 36 | 37 | 38 | 39 | class Identity(nn.Module): 40 | @nn.compact 41 | def __call__(self, x): 42 | return x 43 | 44 | 45 | 46 | class Passive(nn.Module): 47 | 48 | def insert_feature(self, feature, index, origin=0): 49 | func = lambda v, f: v.at[..., index].set(f[..., origin]) 50 | return self.variables.copy({ 51 | 'params': jtm(func, self.variables['params'].unfreeze(), feature['params']) 52 | }) 53 | 54 | def pad_features(self, count): 55 | trav = trav_identity['params'].tree() 56 | func = lambda p: pad_axis(p, count, axis=-1) 57 | return trav.update(func, self.variables.unfreeze()) 58 | 59 | def null(self): 60 | return jnp.bool_(True) 61 | 62 | def restrict_params(self, dim): 63 | params = self.variables["params"] 64 | restrict = lambda arr: arr.at[..., dim:].set(0.) 65 | return jtm(restrict, params) 66 | 67 | 68 | 69 | class DConv(Passive): 70 | kernel_size: int = 3 71 | init_identity: bool = False 72 | #N.B. kernel param has shape [3,3,1,F] 73 | 74 | def identity_kernel(self, F): 75 | K = jnp.zeros((self.kernel_size, self.kernel_size, 1, F)) 76 | cidx = self.kernel_size//2 77 | return K.at[cidx, cidx,...].set(1.) 
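# Standalone illustration (arbitrary example values): growth rounds the number of
# active features up to the next power of two via pad_target, and pad_axis appends
# the matching number of zero slices; _pad_pow2_demo is a hypothetical helper name.
def _pad_pow2_demo():
    import jax.numpy as jnp
    active = 5
    target = int(2 ** jnp.ceil(jnp.log2(active + 1)))      # pad_target(5) == 8
    w = jnp.ones((4, active))                               # e.g. a kernel with 5 output features
    padded = jnp.concatenate([w, jnp.zeros((4, target - active))], axis=-1)  # == pad_axis(w, 3, axis=-1)
    assert target == 8 and padded.shape == (4, 8)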
78 | 79 | @nn.compact 80 | def __call__(self, x): 81 | dim = x.shape[-1] 82 | kernel = (self.kernel_size,)*2 83 | if self.init_identity: 84 | init = lambda rng, shape, dtype: self.identity_kernel(x.shape[-1]) 85 | x = nn.Conv(dim, kernel, padding='SAME', feature_group_count=dim, kernel_init=init, precision=def_prec)(x) 86 | else: 87 | lec_n = initializers.lecun_uniform() 88 | unif = initializers.uniform() 89 | def init(rng, shape, dtype): 90 | default = lec_n(rng, shape, dtype) 91 | return default / jnp.sqrt((default**2).sum(axis=-2, keepdims=True)) 92 | x = nn.Conv(dim, kernel, padding='SAME', feature_group_count=dim, kernel_init=init, bias_init=unif, precision=def_prec)(x) 93 | return x 94 | 95 | def identity_params(self, F): 96 | params = self.variables["params"] 97 | return {'Conv_0': { 98 | 'bias': jnp.zeros(F), 99 | 'kernel': self.identity_kernel(F) 100 | }} 101 | 102 | 103 | 104 | class Rational1D(Passive): 105 | residual: bool = True 106 | init_identity: bool = False 107 | epsilon: float = 1e-8 108 | 109 | @nn.compact 110 | def __call__(self, x): 111 | init = initializers.zeros if self.init_identity else random.normal 112 | v = self.param('w_vec', init, x.shape[-1:]) 113 | c = self.param('w_const', init, x.shape[-1:]) 114 | sumsq = v**2 + c**2 + self.epsilon 115 | if self.residual: 116 | init = initializers.ones if self.init_identity else random.normal 117 | l = self.param('w_lin', init, x.shape[-1:]) 118 | sumsq = sumsq + l**2 119 | norm = jnp.sqrt(sumsq) 120 | #DISABLE NORMING: 121 | # norm = jnp.ones_like(norm) 122 | den = 1/(1 + x**2) 123 | self.sow("intermediates", "lin", x) 124 | self.sow("intermediates", "odd", 2.*x*den) 125 | self.sow("intermediates", "even", den) 126 | num = 2.*v*x + c 127 | out = num*den 128 | if self.residual: 129 | out = out + l*x 130 | out = out / norm 131 | self.sow("intermediates", "activations", out) 132 | return out 133 | 134 | def null(self): 135 | null = lambda name: self.variables["params"][name] == 0. 136 | return null('w_vec') & null('w_const') & null('w_lin') 137 | 138 | def identity_params(self, F, rank=None): 139 | params = { 140 | 'w_vec': jnp.zeros(F), 141 | 'w_const': jnp.zeros(F), 142 | 'w_lin': jnp.ones(F) 143 | } 144 | if rank is not None: 145 | params['w_lin'] = params['w_lin'].at[rank:].set(0.) 
146 | return params 147 | 148 | 149 | 150 | 151 | class Active(ABC, Passive): 152 | @abstractmethod 153 | def pad_inputs(self, count): 154 | pass 155 | 156 | @abstractmethod 157 | def input_size(self): 158 | pass 159 | 160 | @abstractmethod 161 | def output_size(self): 162 | pass 163 | 164 | def pad_pow2(self, override=False): 165 | null = self.null() & ~jnp.bool_(override) 166 | target = pad_target((~null).sum()) 167 | current = len(null) 168 | deficit = target - current 169 | if deficit > 0: 170 | params = self.pad_features(deficit) 171 | return params, target 172 | else: 173 | return self.variables, current 174 | 175 | 176 | 177 | class Splittable(Active): 178 | @abstractmethod 179 | def split_basis_size(self): 180 | pass 181 | 182 | @abstractmethod 183 | def split_params(self, basis, shift): 184 | pass 185 | 186 | # @abstractmethod 187 | # def split_module(self): 188 | # pass 189 | 190 | 191 | 192 | class DDense(nn.Dense, Splittable): 193 | init_zero: bool = False 194 | def __init__(self, features, init_zero = False): 195 | if init_zero: 196 | super().__init__(features, kernel_init=initializers.zeros, precision=def_prec) 197 | else: 198 | super().__init__(features, bias_init=initializers.normal(1e0), precision=def_prec) 199 | # init = initializers.lecun_uniform() 200 | # unif = initializers.uniform() 201 | # super().__init__(features, bias_init=unif, kernel_init=init, precision=def_prec) 202 | 203 | def __call__(self, x): 204 | self.sow("intermediates", "preactivations", x) 205 | x = nn.Dense.__call__(self, x) 206 | return x 207 | 208 | def pad_inputs(self, count): 209 | return self.variables.copy({ 210 | "params": self.variables["params"].copy({ 211 | "kernel": pad_axis(self.variables["params"]["kernel"], count, axis=-2) 212 | }) 213 | }) 214 | 215 | def input_size(self): 216 | return self.variables["params"]["kernel"].shape[-2] 217 | 218 | def output_size(self): 219 | return self.variables["params"]["kernel"].shape[-1] 220 | 221 | def null(self): 222 | null = (self.variables["params"]["kernel"] == 0.).all(axis=-2) 223 | return null & (self.variables["params"]["bias"] == 0.) 224 | 225 | def split_basis_size(self): 226 | return min(self.input_size(), self.output_size()) 227 | 228 | # def split_params(self, basis, shift, kill_in, kill_out): 229 | # old = self.variables["params"]["kernel"] 230 | # in_dim, out_dim = old.shape 231 | # # basis_inv = jnp.linalg.pinv(basis) 232 | # # if in_dim-kill_in < out_dim-kill_out: 233 | # if in_dim < out_dim: 234 | # if kill_in > 0: 235 | # basis = basis.at[-kill_in:,:].set(0.) 236 | # basis = basis.at[:,-kill_in:].set(0.) 237 | # shift = shift.at[-kill_in:].set(0.) 238 | # basis_inv = jnp.linalg.pinv(basis) 239 | # first = basis 240 | # second = basis_inv@old 241 | # intermediate = in_dim 242 | # else: 243 | # if kill_out > 0: 244 | # basis = basis.at[-kill_out:,:].set(0.) 245 | # basis = basis.at[:,-kill_out:].set(0.) 246 | # shift = shift.at[-kill_out:].set(0.) 
247 | # basis_inv = jnp.linalg.pinv(basis) 248 | # first = old@basis 249 | # second = basis_inv 250 | # intermediate = out_dim 251 | # first_params = { 252 | # "bias": shift, 253 | # "kernel": first 254 | # } 255 | # second_params = { 256 | # "bias": self.variables["params"]["bias"] - shift@second, 257 | # "kernel": second 258 | # } 259 | # return first_params, second_params 260 | 261 | def split_params(self, basis, inv_basis, shift): 262 | old = self.variables["params"]["kernel"] 263 | first = basis 264 | # print(old.shape) 265 | # print(inv_basis.shape) 266 | second = inv_basis@old 267 | first_params = { 268 | "bias": shift, 269 | "kernel": first, 270 | } 271 | second_params = { 272 | "bias": self.variables["params"]["bias"] - shift@second, 273 | "kernel": second 274 | } 275 | return first_params, second_params 276 | 277 | def split_module(self): 278 | intermediate = min(self.input_size(), self.output_size()) 279 | return DDense(intermediate), DDense(self.output_size()) 280 | 281 | def identity_params(self): 282 | params = self.variables["params"] 283 | old_kernel = params["kernel"] 284 | diag = jnp.diag_indices(min(old_kernel.shape)) 285 | return params.copy({ 286 | "bias": jnp.zeros_like(params["bias"]), 287 | "kernel": jnp.zeros_like(old_kernel).at[diag].set(1.) 288 | }) 289 | 290 | def restrict_params(self, dim): 291 | params = super().restrict_params(dim) 292 | # print(params) 293 | return params.copy({ 294 | "kernel": params["kernel"].at[dim:, :].set(0.) 295 | }) 296 | 297 | def extract_activations(self): 298 | return self.variables["intermediates"]["preactivations"] 299 | 300 | def extract_out_grads(self): 301 | return self.variables["params"]["bias"] 302 | 303 | 304 | 305 | 306 | class Layer(Splittable): 307 | features: Optional[int] 308 | make_invariant: Sequence[nn.Module] = () 309 | make_linear: Callable[[int], Active] = DDense 310 | make_equivariant: Sequence[Callable[[], Passive]] = (partial(Rational1D, True),) 311 | 312 | def setup(self): 313 | F = self.features 314 | self.linear = None if F is None else self.make_linear(F) 315 | self.equivariant = None if F is None else [func() for func in self.make_equivariant] 316 | self.invariant = [func() for func in self.make_invariant] 317 | 318 | nn.nowrap 319 | def dormant(self): 320 | return self.features is None 321 | 322 | def __call__(self, x): 323 | if self.features is not None: 324 | for module in self.invariant: 325 | x = module(x) 326 | x = self.linear(x) 327 | for module in self.equivariant: 328 | x = module(x) 329 | return x 330 | 331 | def pad_inputs(self, count): 332 | return self.variables.copy({ 333 | "params": self.variables["params"].copy({ 334 | "linear": self.linear.pad_inputs(count)["params"] 335 | }) 336 | }) 337 | 338 | def null(self): 339 | if self.dormant(): 340 | return None 341 | else: 342 | out = self.linear.null() 343 | for i, module in enumerate(self.equivariant): 344 | if f'equivariant_{i}' in self.variables['params']: 345 | out = out & module.null() 346 | return out 347 | 348 | def input_size(self): 349 | return self.linear.input_size() 350 | 351 | def output_size(self): 352 | return self.linear.output_size() 353 | 354 | def split_basis_size(self): 355 | return self.linear.split_basis_size() 356 | 357 | # def split_params(self, basis, shift, kill_in, kill_out): 358 | # first, second = self.linear.split_params(basis, shift, kill_in, kill_out) 359 | # # F = first['kernel'].shape[-1] 360 | # # equivariant_params = [ 361 | # # eq.identity_params(F) for eq in self.equivariant 362 | # # ] 363 | # # first = 
self.variables["params"].copy({ 364 | # # "linear": first 365 | # # }) 366 | # # for i, p in enumerate(equivariant_params): 367 | # # first = first.copy({ 368 | # # f"equivariant_{i}": p 369 | # # }) 370 | # second = self.variables["params"].copy({ 371 | # "linear": second 372 | # }) 373 | # # assert "invariant_0" not in self.variables["params"] 374 | # return first, second 375 | 376 | def split_params(self, basis, inv_basis, shift): 377 | first, second = self.linear.split_params(basis, inv_basis, shift) 378 | second = self.variables["params"].copy({ 379 | "linear": second 380 | }) 381 | return first, second 382 | 383 | def split_identity(self, linear_params, rank=None): 384 | params = self.variables['params'] 385 | params = params.copy({ 386 | 'linear': linear_params, 387 | }) 388 | for i, eq in enumerate(self.equivariant): 389 | params = params.copy({ 390 | f'equivariant_{i}': eq.identity_params(self.features, rank=rank) 391 | }) 392 | assert 'invariant_0' not in params 393 | return params 394 | 395 | def restrict_params(self, dim): 396 | params = self.variables["params"] 397 | params = params.copy({ 398 | 'linear': self.linear.restrict_params(dim) 399 | }) 400 | params = params.copy({ 401 | f'equivariant_{i}': eq.restrict_params(dim) \ 402 | for i, eq in enumerate(self.equivariant) \ 403 | if f'equivariant_{i}' in params 404 | }) 405 | assert 'invariant_0' not in params 406 | return params 407 | 408 | def identity_params(self): 409 | return self.split_identity(self.linear.identity_params()) 410 | 411 | 412 | 413 | class Layers(nn.Module): 414 | layers: Sequence[Layer] 415 | 416 | def features(self): 417 | return [layer.features for layer in self.layers if layer.features] 418 | 419 | def nulls(self): 420 | return [layer.null() for layer in self.layers if layer.features] 421 | 422 | def lift(self, index, func, *args): 423 | return func(self.layers[index], *args) 424 | 425 | def __call__(self, x): 426 | for L in self.layers: 427 | x = L(x) 428 | return x 429 | 430 | def pad_features(self, index, count): 431 | variables = self.variables 432 | def layer_update(i, new_params): 433 | nonlocal variables 434 | variables = variables.copy({ 435 | 'params': variables['params'].copy({ 436 | f'layers_{i}': new_params['params'] 437 | }) 438 | }) 439 | layer_update(index, self.layers[index].pad_features(count)) 440 | subsequent = index + 1 441 | if subsequent < len(self.layers): 442 | layer_update(subsequent, self.layers[subsequent].pad_inputs(count)) 443 | return variables 444 | 445 | def pad_pow2(self, index, override=False): 446 | features = self.features() 447 | increase = None 448 | old = features[index] 449 | updated, features[index] = self.layers[index].pad_pow2(override) 450 | variables = self.variables 451 | def layer_update(i, new_params): 452 | nonlocal variables 453 | variables = variables.copy({ 454 | "params": variables["params"].copy({ 455 | f"layers_{i}": new_params["params"] 456 | }) 457 | }) 458 | layer_update(index, updated) 459 | layer_update(index+1, self.layers[index+1].pad_inputs(features[index] - old)) 460 | return variables, features 461 | 462 | def split_layer(self, index, basis, shift, inserter): 463 | kill_out = sum(self.layers[index].null()) 464 | kill_in = sum(self.layers[index-1].null()) if index > 0 else 0 465 | first_linear, second_layer = self.layers[index].split_params(basis, shift, kill_in, kill_out) 466 | first_layer = inserter(first_linear) 467 | params = self.variables['params'] 468 | for i in reversed(range(index+1, len(self.layers))): 469 | params = 
params.copy({ 470 | f'layers_{i+1}': params[f'layers_{i}'] 471 | }) 472 | params = params.copy({ 473 | f'layers_{index}': first_layer, 474 | f'layers_{index+1}': second_layer 475 | }) 476 | return params 477 | 478 | @nn.nowrap 479 | @staticmethod 480 | def get_nearest_active(dims, index): 481 | assert dims[index] is None 482 | for i, dim in enumerate(dims[index:]): 483 | if dim is not None: 484 | succeeding = index + i 485 | break 486 | preceding = None 487 | for i, dim in enumerate(reversed(dims[:index+1])): 488 | if dim is not None: 489 | preceding = index - i 490 | break 491 | return preceding, succeeding 492 | 493 | def activate_layer(self, dims, index, basis, shift): 494 | preceding, succeeding = self.get_nearest_active(dims, index) 495 | input_dim = self.layers[0].linear.input_size() 496 | # assert input_dim == 1 497 | new_rank = sum(~self.layers[preceding].null()) if preceding is not None else input_dim 498 | new_input_size = self.layers[index-1].linear.output_size() if index != 0 else input_dim 499 | basis = basis.at[new_rank:,:].set(0.) 500 | basis = basis.at[:,new_rank:].set(0.) 501 | shift = shift.at[new_rank:].set(0.) 502 | inv_basis = jnp.linalg.pinv(basis) 503 | basis = basis[:new_input_size,:] 504 | # kill_in = sum(self.layers[preceding].null()) if preceding is not None else 0 505 | # kill_out = sum(self.layers[succeeding].null()) 506 | # print(inv_basis.shape) 507 | new_linear, new_succeeding = self.layers[succeeding] \ 508 | .split_params(basis, 509 | inv_basis, 510 | shift) 511 | new_layer = self.layers[index].split_identity(new_linear, rank=new_rank) 512 | 513 | return self.variables.copy({ 514 | 'params': self.variables['params'].copy({ 515 | f'layers_{index}': new_layer, 516 | f'layers_{succeeding}': new_succeeding 517 | }) 518 | }) 519 | 520 | def restrict_params(self, dims): 521 | state = self.variables 522 | params = state["params"] 523 | 524 | def modify(layer, dim): 525 | if dim is None: 526 | return layer.identity_params() 527 | else: 528 | return layer.restrict_params(dim) 529 | 530 | def get_key(i): 531 | return f"layers_{i}" 532 | 533 | return state.copy({ 534 | "params": params.copy({ 535 | get_key(i): modify(layer, dim) \ 536 | for i, (layer, dim) in enumerate(zip(self.layers, dims)) 537 | }) 538 | }) 539 | 540 | def restrict_grad(self, dims): 541 | params = self.variables["params"] 542 | # print(params) 543 | for i, dim in enumerate(dims): 544 | key = f"layers_{i}" 545 | if dim is None: 546 | params = params.copy({ 547 | key: jtm(lambda arr: arr*0., params[key]) 548 | }) 549 | # print(params) 550 | # print(dims) 551 | # print('Done!') 552 | grad = self.variables.copy({ 553 | "params": params 554 | }) 555 | return grad 556 | 557 | def extract_activations(self): 558 | return [layer.linear.extract_activations() for layer in self.layers] 559 | 560 | def extract_out_grads(self): 561 | return [layer.linear.extract_out_grads() for layer in self.layers] 562 | 563 | 564 | 565 | 566 | class Seq(nn.Module): 567 | modules: Sequence[nn.Module] 568 | 569 | @nn.compact 570 | def __call__(self, x): 571 | for m in modules: 572 | x = m(x) 573 | return x 574 | 575 | 576 | 577 | def push_tangent(func, t, x): 578 | return jax.jvp(func, [x], [t]) 579 | 580 | def push_curvature(func, t, x): 581 | return jax.jvp(partial(push_tangent, func, t), [x], [t]) 582 | 583 | def reject_from(a, b): 584 | # return the component of a orthogonal to b 585 | return a - b * jnp.sum(a*b) / jnp.sum(b*b) 586 | 587 | def combine(a, b, c): 588 | # return coefficients for 'a' and 'b' to produce least 
squares solution for c 589 | # only accurate up to a scale factor 590 | alpha = (b*b).sum() * (a*c).sum() - (a*b).sum() * (b*c).sum() 591 | beta = (a*a).sum() * (b*c).sum() - (a*b).sum() * (a*c).sum() 592 | return alpha, beta 593 | 594 | def tree_length(v): 595 | tsq = jtm(lambda v: jnp.sum(v**2), v) 596 | ssq = jtr(lambda a, b: a + b, tsq) 597 | return jnp.sqrt(ssq) 598 | 599 | def pass_one(func, lossfn, params, p, x): 600 | func_x = lambda params: func(params, x) 601 | (y, Jp),(_, J2p) = push_curvature(func_x, p, params) 602 | # y, JT = jax.vjp(func_x, params) 603 | (loss, dL), (Jp_prod_dL, dLdJp) = jax.jvp(jax.value_and_grad(lossfn), [y], [Jp]) 604 | 605 | dL_p = Jp_prod_dL 606 | d2L_p = jnp.sum(J2p*dL) + jnp.sum(Jp*dLdJp) 607 | 608 | return y, loss, dL_p, d2L_p, dL, Jp, Jp_prod_dL#, JT 609 | 610 | def pass_two(func, params, err, x): 611 | func_x = lambda params: func(params, x) 612 | _, JT = jax.vjp(func_x, params) 613 | return JT(err) 614 | 615 | def pass_three(func, params, p, x): 616 | func_x = lambda params: func(params, x) 617 | return push_tangent(func_x, p[0], params)[1] 618 | 619 | def batch_ng(func, lossfn, metric, params, x, labels, p, axis=0): 620 | y, loss, dL_p, d2L_p, dL, Jp, corr = jax.vmap(lambda x, l: pass_one(func, partial(lossfn, l), params, p, x), 0)(x, labels) 621 | 622 | orth_grad = reject_from(dL, Jp) 623 | im_reject = jax.vmap(partial(pass_two, func, params), 0)(orth_grad, x) 624 | delta_p = jtm(lambda v: jnp.sum(v, axis=0), im_reject) 625 | delta_Jp = jax.vmap(partial(pass_three, func, params, delta_p), 0)(x) 626 | 627 | alpha, beta = combine(Jp, delta_Jp, dL) 628 | pu_rescale = beta/alpha * tree_length(p) 629 | p_update = jtm(lambda v: v * pu_rescale, delta_p) 630 | 631 | #curv = len(d2L_p)*jnp.sqrt(jnp.mean(d2L_p**2)) 632 | curv = jnp.abs(jnp.sum(d2L_p)) 633 | #curv = jnp.sum(jnp.abs(d2L_p)) 634 | p_rescale = -jnp.sum(dL_p) / curv 635 | param_update = jtm(lambda v: v * p_rescale, p) 636 | 637 | speed = jnp.sqrt((Jp**2).sum()) 638 | utility = corr.sum() / jnp.sqrt((dL**2).sum()) * speed 639 | distance = speed / p_rescale 640 | 641 | avg_loss = jnp.mean(loss) 642 | if metric is None: 643 | return avg_loss, param_update, p_update, utility 644 | else: 645 | return avg_loss, param_update, p_update, utility, metric(y, labels), curv, distance 646 | 647 | #goodness = dL . 
Jp / (|dL| |Jp|) 648 | -------------------------------------------------------------------------------- /senn_mlp/optim.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Sequence, Optional 2 | from functools import partial 3 | from dataclasses import dataclass 4 | 5 | import numpy as np 6 | import jax 7 | from jax import numpy as jnp 8 | from jax import vmap, jit 9 | from jax.tree_util import tree_map, tree_reduce 10 | import flax 11 | 12 | 13 | 14 | def sqtree(tree): 15 | return tree_map(lambda v: v**2, tree) 16 | 17 | def sqlen(tree): 18 | sq = tree_map(lambda v: jnp.sum(v**2), tree) 19 | return tree_reduce(lambda a, b: a + b, sq) 20 | 21 | def zeros_like_tree(tree): 22 | return tree_map(lambda arr: jnp.zeros_like(arr), tree) 23 | 24 | def calc_update(tau, old, new): 25 | return (new-old)/tau 26 | 27 | def tree_update(old, direction, scale=1.): 28 | return tree_map(lambda a, b: a + b*scale, old, direction) 29 | 30 | def tree_scale(scale, tree): 31 | return tree_map(lambda arr: scale*arr, tree) 32 | 33 | def state_exp_update(key, tau, state, new): 34 | old = state[key] 35 | update = tree_map(partial(calc_update, tau), old, new) 36 | # umag = sqlen(update) 37 | newstate = state.copy({key: tree_update(old, update)}) 38 | return newstate 39 | 40 | def tree_axis(tree, axis): 41 | return tree_map(lambda: axis, tree) 42 | 43 | @dataclass 44 | class EMA: 45 | key: str 46 | tau: Optional[float] 47 | initial: float = 0. 48 | sqtau: Optional[float] = None 49 | t: int = 0 50 | 51 | @property 52 | def sqkey(self): 53 | return f'_{self.key}_sq' 54 | 55 | @property 56 | def tkey(self): 57 | return f'_{self.key}_t' 58 | 59 | def init(self, state): 60 | if 'ema' not in state: 61 | state = state.copy({'ema': {}}) 62 | return state.copy({'ema': state['ema'].copy({ 63 | self.key: self.initial, 64 | self.sqkey: sqtree(self.initial), 65 | self.tkey: self.t 66 | })}) 67 | 68 | def debias(self, tree, tau, t): 69 | if tau is None: 70 | return tree 71 | correction = 1/(1 - ((tau-1)/tau)**t) 72 | return tree_map(lambda arr: arr*correction, tree) 73 | 74 | def get_t(self, state): 75 | return state['ema'][self.tkey] 76 | 77 | def increment_t(self, state): 78 | return state.copy({'ema': state['ema'].copy({ 79 | self.tkey: self.get_t(state) + 1 80 | })}) 81 | 82 | def mean(self, state): 83 | return self.debias(state['ema'][self.key], self.tau, self.get_t(state)) 84 | 85 | def mean_sq(self, state): 86 | sqtau = self.tau if self.sqtau is None else self.sqtau 87 | return self.debias(state['ema'][self.sqkey], sqtau, self.get_t(state)) 88 | 89 | def sqmag(self, state): 90 | return sqlen(self.mean(state)) 91 | 92 | def variance(self, state): 93 | return tree_map(lambda a, b: a - b, state['ema'][self.sqkey], sqtree(self.mean(state))) 94 | 95 | def scalarvar(self, state): 96 | return sqlen(self.variance(state)) 97 | 98 | def update(self, state, value, batch_axis=None): 99 | sqval = sqtree(value) 100 | if batch_axis is not None: 101 | batch_mean = lambda v: jnp.mean(v, axis=batch_axis) 102 | value = tree_map(batch_mean, value) 103 | sqval = tree_map(batch_mean, sqval) 104 | if self.tau is None: 105 | state = state.copy({'ema': state['ema'].copy({ 106 | self.key: value, 107 | self.sqkey: sqval 108 | })}) 109 | else: 110 | sqtau = self.tau if self.sqtau is None else self.sqtau 111 | state = state.copy({'ema': state_exp_update(self.key, self.tau, state['ema'], value)}) 112 | state = state.copy({'ema': state_exp_update(self.sqkey, sqtau, state['ema'], 
sqval)}) 113 | state = self.increment_t(state) 114 | return state 115 | 116 | 117 | 118 | class Excalibur: 119 | 120 | def __init__(self, cfg, state): 121 | self.tau = cfg['tau'].get(None) 122 | self.LR = cfg['lr'].get(0.1) 123 | 124 | self.jl = EMA('jl', self.tau, state['params']) 125 | # state = self.jl.init(state) 126 | self.fp = EMA('fp', self.tau, state['params']) 127 | # state = self.fp.init(state) 128 | 129 | self.curv = EMA('curv', self.tau, 0.) 130 | 131 | # JP, FL, JFP 132 | self.lprods = EMA('lprods', self.tau, jnp.zeros((3,))) 133 | self.prods = EMA('prods', self.tau, jnp.zeros((3, 3))) 134 | 135 | self.lsqmag = EMA('lsqmag', self.tau, 1.) 136 | 137 | def init(self, state): 138 | for avg in [self.jl, self.fp, self.curv, self.lprods, self.prods]: 139 | state = avg.init(state) 140 | return state 141 | 142 | @staticmethod 143 | def alpha(func, state): 144 | tan = lambda params: jax.jvp(func, (params,), (state['p'],)) 145 | (Y, JP), (_, JP2) = jax.jvp(tan, (state['params'],), (state['p'],)) 146 | return Y, JP, JP2 147 | # for ema, val in zip([self.Y, self.JP, self.JP2], [Y, JP, JP2]): 148 | # state = ema.update(state, val) 149 | 150 | @staticmethod 151 | def beta(lossfn, state, Y, JP): 152 | (loss, L), (L_JP, L2) = jax.jvp(jax.value_and_grad(lossfn), (Y,), (JP,)) 153 | return loss, L, L_JP, L2 154 | 155 | @staticmethod 156 | def gamma(func, state, cotangents): 157 | _, JT = jax.vjp(func, state['params']) 158 | # print('1.3.1') 159 | return list(map(JT, cotangents)) 160 | 161 | def precondition(self, state, fp): 162 | # disable due to poor performance 163 | return fp 164 | loss_sqmag = self.lsqmag.mean(state) 165 | efdiag = self.jl.mean_sq(state) 166 | # efdiag = tree_map(lambda e, f: e.at[f == 0.].set(1.), efdiag, fp) 167 | return tree_map(lambda v, efd: v / (efd+1e-40) * loss_sqmag, fp, efdiag) 168 | 169 | def observe(self, model, lossfn, restrict_grad, xs, labels, state): 170 | # print('1') 171 | func = lambda x, params: model.apply(state.copy({'params': params}), x) 172 | # print('1.1') 173 | Y, JP, JP2 = vmap(lambda x: self.alpha(partial(func, x), state))(xs) 174 | # print('1.2') 175 | loss, L, L_JP, L2 = vmap(lambda l, Y_, JP_: self.beta(partial(lossfn, l), state, Y_, JP_))(labels, Y, JP) 176 | # print('1.3') 177 | (fp,), (jl,) = vmap(lambda x, JP_, L_: self.gamma(partial(func, x), state, [JP_, L_]))(xs, JP, L) 178 | rg = lambda tan: restrict_grad({'params': tan})['params'] 179 | fp = rg(fp) 180 | jl = rg(jl) 181 | # print('1.4') 182 | 183 | state = self.lsqmag.update(state, vmap(lambda l: jnp.sum(l**2))(L), batch_axis=0) 184 | # print('2') 185 | # fp = vmap(partial(self.precondition, state))(fp) 186 | state = self.fp.update(state, fp, batch_axis=0) 187 | state = self.jl.update(state, jl, batch_axis=0) 188 | fp, jl = rg(self.fp.mean(state)), rg(self.jl.mean(state)) 189 | fp = self.precondition(state, fp) 190 | # print('3') 191 | 192 | push_tan = lambda x, t: jax.jvp(partial(func, x), (state['params'],), (t,))[1] 193 | JFP = vmap(lambda x: push_tan(x, fp))(xs) 194 | FL = vmap(lambda x: push_tan(x, jl))(xs) 195 | # print('4') 196 | 197 | y_axis_count = len(Y.shape) - 1 198 | y_axes = list(-(i + 1) for i in range(y_axis_count)) 199 | vectors = jnp.stack([JP, FL, JFP], axis=1) 200 | lprodu = jnp.sum(L[:,None,...]*vectors, axis=y_axes) 201 | produ = jnp.sum(vectors[:,:,None,...] 
* vectors[:,None,:,...], axis=y_axes) 202 | # print('5') 203 | # lprodu = jnp.array([dot(L, b) for b in vectors]).moveaxis(-1, 0) 204 | # produ = jnp.array([[dot(a, b) for b in vectors] for a in vectors]).moveaxis(-1, 0) 205 | 206 | state = self.lprods.update(state, lprodu, batch_axis=0) 207 | state = self.prods.update(state, produ, batch_axis=0) 208 | # state = self.jl.update(state, jl, batch_axis=None) 209 | # state = self.fp.update(state, fp, batch_axis=None) 210 | # print('6') 211 | def dot(a, b): 212 | return vmap(lambda a, b: jnp.sum(a*b))(a, b) 213 | curv = dot(JP2, L) + dot(JP, L2) 214 | state = self.curv.update(state, curv, batch_axis=0) 215 | # print('7') 216 | return state 217 | 218 | def update_params(self, state): 219 | recip_tau = 1. if self.tau is None else 1/self.tau 220 | curv = jnp.abs(self.curv.mean(state)) + 1e-3 221 | #curvvar = jnp.abs(self.curv.variance(state)) 222 | scale = -self.lprods.mean(state)[0] / curv 223 | scale = scale * recip_tau * self.LR 224 | return state.copy({'params': tree_update(state['params'], state['p'], scale)}) 225 | 226 | def update_p(self, state): 227 | # raise NotImplementedError 228 | mat = self.prods.mean(state) 229 | vec = self.lprods.mean(state) 230 | coeffs = jnp.linalg.pinv(mat)@vec 231 | recip_tau = 1. if self.tau is None else 1/self.tau 232 | jl_scale = coeffs[1] * recip_tau 233 | fp_scale = coeffs[2] * recip_tau 234 | 235 | old = state['p'] 236 | new = tree_map(lambda v: v*coeffs[0], old) 237 | new = tree_update(new, self.jl.mean(state), jl_scale) 238 | new = tree_update(new, self.precondition(state, self.fp.mean(state)), fp_scale) 239 | 240 | new = tree_map(lambda v: v/jnp.sqrt(sqlen(new)), new) 241 | 242 | return state.copy({'p': new}) 243 | 244 | def step(self, func, lossfn, state): 245 | raise NotImplementedError 246 | state = self.observe(func, lossfn, state) 247 | state = self.update_params(state) 248 | state = self.update_p(state) 249 | return state 250 | 251 | 252 | 253 | class SimpleGradient: 254 | 255 | def __init__(self, cfg, params, name='grad'): 256 | self.tau = cfg['tau'].get(None) 257 | self.sqtau = cfg['sqtau'].get(None) 258 | self.epsilon = 1e-8 259 | 260 | self.grad = EMA(name, self.tau, zeros_like_tree(params), self.sqtau) 261 | 262 | def init(self, state): 263 | return self.grad.init(state) 264 | 265 | def observe(self, func, restrict_grad, xs, labels, params, state): 266 | maybe_restrict_grad = lambda x: restrict_grad(x) if restrict_grad is not None else x 267 | grad_for_x = lambda x, l: maybe_restrict_grad(jax.grad(func)(params, x, l)) 268 | # grad_for_x = lambda x, l: restrict_grad(jax.grad(func)(params, x, l)) 269 | grads = jax.vmap(grad_for_x)(xs, labels) 270 | return self.grad.update(state, grads, batch_axis=0) 271 | 272 | def read(self, state): 273 | return self.grad.mean(state) 274 | 275 | def adam(self, state): 276 | grad = self.grad.mean(state) 277 | sqgrad = self.grad.mean_sq(state) 278 | return tree_map( 279 | lambda a, b: a/(jnp.sqrt(b)+self.epsilon), 280 | grad, 281 | sqgrad 282 | ) 283 | 284 | 285 | 286 | class KrylovInverter: 287 | 288 | def __init__(self, cfg, params): 289 | self.tau = cfg['tau'].get(None) 290 | self.order = cfg['order'].get() 291 | 292 | initial = tree_map(lambda arr: jnp.zeros((self.order,) + arr.shape), params) 293 | 294 | self.krylov_powers = EMA('krylov_powers', self.tau, initial) 295 | 296 | initial_interleaved = jnp.zeros((2*self.order,)) 297 | self.interleaved = EMA('interleaved', self.tau, initial_interleaved) 298 | 299 | def init(self, state): 300 | state = 
self.krylov_powers.init(state) 301 | state = self.interleaved.init(state) 302 | return state 303 | 304 | def observe(self, operator, restrict_grad, xs, tangent, state): 305 | maybe_restrict_grad = lambda p: restrict_grad(p) if restrict_grad is not None else p 306 | def calc_powers(x): 307 | func = lambda p, _: (maybe_restrict_grad(operator(p, x)),)*2 308 | # func = lambda p, _: (restrict_grad(operator(p, x)),)*2 309 | init = tangent 310 | dummies = jnp.zeros((self.order,)) 311 | return jax.lax.scan(func, init, dummies)[1] 312 | krylov_powers = jax.vmap(calc_powers)(xs) 313 | state = self.krylov_powers.update(state, krylov_powers, batch_axis=0) 314 | 315 | krylov_powers = tree_map(lambda v: jnp.mean(v, axis=0), krylov_powers) 316 | dot = partial(tree_map, lambda a, b: jnp.sum(a*b)) 317 | shift = lambda kry, tan: jnp.roll(kry, 1, axis=0).at[0].set(tan) 318 | shifted_krylov = tree_map(shift, krylov_powers, tangent) 319 | 320 | evens = tree_reduce(jax.lax.add, jax.vmap(dot)(shifted_krylov, shifted_krylov)) 321 | odds = tree_reduce(jax.lax.add, jax.vmap(dot)(shifted_krylov, krylov_powers)) 322 | interleaved = jnp.zeros((2*self.order,)).at[0::2].set(evens)\ 323 | .at[1::2].set(odds) 324 | state = self.interleaved.update(state, interleaved) 325 | return state 326 | 327 | def inv_pow(self, tangent, exponent, state): 328 | assert exponent == 1 329 | dot = partial(tree_map, lambda a, b: jnp.sum(a*b)) 330 | shift = lambda kry, tan: jnp.roll(kry, 1, axis=0).at[0].set(tan) 331 | krylov_powers = self.krylov_powers.mean(state) 332 | shifted_krylov = tree_map(shift, krylov_powers, tangent) 333 | # 334 | # evens = tree_reduce(jax.lax.add, jax.vmap(dot)(shifted_krylov, shifted_krylov)) 335 | # odds = tree_reduce(jax.lax.add, jax.vmap(dot)(shifted_krylov, krylov_powers)) 336 | # interleaved = jnp.zeros((2*self.order,)).at[0::2].set(evens)\ 337 | # .at[1::2].set(odds) 338 | interleaved = self.interleaved.mean(state) 339 | # print(tangent) 340 | # print(krylov_powers) 341 | # exit() 342 | tanprods = jax.vmap(partial(dot, tangent))(shifted_krylov) 343 | prods_init = jnp.zeros((self.order,)) 344 | tanprods = tree_reduce(lambda a, b: a + b, tanprods, prods_init) 345 | print(f'tanprods: {tanprods}') 346 | # tanprods = interleaved[:self.order] 347 | # print(f'tanprods: {tanprods}') 348 | # crossprods = jax.vmap(lambda vec: jax.vmap(partial(dot, vec))(krylov_powers))(krylov_powers) 349 | # crossprods = jax.vmap(lambda vec: jax.vmap(partial(dot, vec))(krylov_powers))(shifted_krylov) 350 | # cross_init = jnp.zeros((self.order,)*2) 351 | # crossprods = tree_reduce(lambda a, b: a + b, crossprods, cross_init) 352 | # THIS IS NOT ALWAYS SYMMETRIC DUE TO NUMERICAL ERROR SO SYMMETRISE 353 | indices = jnp.arange(self.order)[:,None] + (jnp.arange(self.order)+1)[None,:] 354 | # print(indices) 355 | # exit() 356 | crossprods = interleaved[indices] 357 | power_penalty = 1.1 358 | penalties = jnp.power(power_penalty, jnp.arange(self.order)) 359 | dindices = jnp.diag_indices(self.order) 360 | new_diag = crossprods[dindices] * penalties 361 | crossprods = crossprods.at[dindices].set(new_diag) 362 | print(f'crossprods: {crossprods}') 363 | # crossprods = 0.5*(crossprods+jnp.transpose(crossprods)) 364 | # tancrossprod = dot(tangent, tangent) 365 | # tancrossprod = tree_reduce(lambda a, b: a + b, tancrossprod) 366 | # print(f'tancrossprod: {tancrossprod}') 367 | print(crossprods.dtype) 368 | w, v = jnp.linalg.eigh(crossprods) 369 | print(f'eig: {w}') 370 | winv = jax.nn.relu(jnp.reciprocal(w)) 371 | coeffs = jnp.sum(tanprods[None,:]@v 
* winv * v, axis=-1) 372 | print(f'coeffs: {coeffs}') 373 | # print(f'pinvprods: {jnp.linalg.pinv(crossprods)}') 374 | # coeffs = jnp.linalg.pinv(crossprods)@tanprods 375 | # coeffs = jnp.linalg.solve(crossprods, tanprods) 376 | # print(f'coeffs: {coeffs}') 377 | print(f'theory out dot: {jnp.sum(coeffs*tanprods)}') 378 | # print(f'inverse out dot: {jnp.sum(tanprods*(crossprods@tanprods))}') 379 | 380 | out = jax.vmap(lambda a, b: tree_map(partial(jnp.multiply, a), b))(coeffs, shifted_krylov) 381 | out = tree_map(lambda arr: jnp.sum(arr, axis=0), out) 382 | print(f'actual out dot: {tree_reduce(jax.lax.add, dot(tangent, out))}') 383 | # tangent_coeff = coeffs[0] 384 | # initial_out = tree_map(lambda tan: tangent_coeff*tan, tangent) 385 | # if self.order == 1: 386 | # out = initial_out 387 | # else: 388 | # krylov_coeffs = jnp.roll(coeffs, -1).at[-1].set(0.) 389 | # print(f'theory in dot: {jnp.sum(krylov_coeffs*tanprods) + tangent_coeff*tancrossprod}') 390 | # # out = tree_map(lambda tan, kry: tangent_coeff*tan + jnp.average(kry, 391 | # # axis=0, 392 | # # weights=krylov_coeffs)*jnp.sum(krylov_coeffs), 393 | # # tangent, 394 | # # krylov_powers) 395 | # rest = jax.vmap(lambda a, b: tree_map(partial(jnp.multiply, a), b))\ 396 | # (krylov_coeffs, krylov_powers) 397 | # out = tree_map(lambda init, rest: init + jnp.sum(rest, axis=0), initial_out, rest) 398 | # # out = jax.lax.associative_scan(func, initial_out, (krylov_coeffs, krylov_powers))[1] 399 | return out 400 | 401 | 402 | 403 | class CGInverter: 404 | 405 | def __init__(self, cfg, params): 406 | # self.tau = None 407 | self.tau = cfg['soln_tau'].get(None) 408 | self.sqtau = cfg['soln_sqtau'].get(None) 409 | self.epsilon = 1e-8 410 | self.order = cfg['order'].get() 411 | self.method = cfg['method'].get('cg') 412 | assert self.method in ['cg', 'gmres'], f"requested method: {self.method} not recognised" 413 | initial = tree_map(jnp.zeros_like, params) 414 | self.soln = EMA('soln', self.tau, initial, self.sqtau) 415 | 416 | def init(self, state): 417 | state = self.soln.init(state) 418 | return state 419 | 420 | def observe(self, operator, restrict_grad, xs, tangent, state): 421 | maybe_rg = lambda p: restrict_grad(p) if restrict_grad is not None else p 422 | F_x = lambda p, x: maybe_rg(operator(maybe_rg(p), x)) 423 | F = lambda p: tree_map(lambda arr: arr.mean(axis=0), vmap(partial(F_x, p))(xs)) 424 | if self.method == 'cg': 425 | soln, _ = jax.scipy.sparse.linalg.cg(F, tangent, maxiter=self.order) 426 | elif self.method == 'gmres': 427 | soln, _ = jax.scipy.sparse.linalg.gmres(F, tangent, maxiter=self.order, restart=self.order) 428 | state = self.soln.update(state, soln) 429 | return state 430 | 431 | def read(self, state): 432 | return self.soln.mean(state) 433 | 434 | def adam(self, state): 435 | grad = self.soln.mean(state) 436 | sqgrad = self.soln.mean_sq(state) 437 | return tree_map( 438 | lambda a, b: a/(jnp.sqrt(b)+self.epsilon), 439 | grad, 440 | sqgrad 441 | ) 442 | 443 | 444 | 445 | class FisherNorm: 446 | 447 | def __init__(self, cfg, params): 448 | self.tau = None 449 | self.Fsqmag = EMA('Fsqmag', self.tau) 450 | 451 | def init(self, state): 452 | state = self.Fsqmag.init(state) 453 | return state 454 | 455 | def observe(self, func, xs, tangent, params, state): 456 | fwd_jac = lambda tan, x: jax.jvp(lambda p: func(p, x), (params,), (tan,))[1] 457 | sqmag_x = lambda x: sqlen(fwd_jac(tangent, x)) 458 | sqmag = vmap(sqmag_x)(xs).mean(axis=0) 459 | state = self.Fsqmag.update(state, sqmag) 460 | return state 461 | 462 | def 
read(self, state): 463 | return self.Fsqmag.mean(state) 464 | 465 | def apply(self, tangent, state): 466 | mag = jnp.sqrt(self.read(state)) 467 | return tree_map(lambda arr: arr/mag, tangent) 468 | 469 | 470 | 471 | class KrylovNG: 472 | 473 | def __init__(self, cfg, params): 474 | self.SG = SimpleGradient(cfg, params) 475 | self.KI = KrylovInverter(cfg, params) 476 | self.key = 'nat_grad' 477 | self.adam = cfg['use_adam'].get(False) 478 | 479 | def init(self, state=flax.core.frozen_dict.freeze({})): 480 | state = self.SG.init(state) 481 | state = self.KI.init(state) 482 | return state 483 | 484 | def observe(self, func, lossfn, restrict_grad, xs, ls, params, state): 485 | tree_double = lambda tree: tree_map(lambda arr: arr.astype('float64'), tree) 486 | params=tree_double(params) 487 | # print(func(params, xs[0])) 488 | # print(lossfn(ls[0], func(params, xs[0]))) 489 | # exit() 490 | state = self.SG.observe(lambda params, x, l: lossfn(l, func(params, x)), 491 | restrict_grad, 492 | xs, 493 | ls, 494 | params, 495 | state) 496 | fwd_jac = lambda tan, x: jax.jvp(lambda p: func(p, x), (params,), (tan,))[1] 497 | bkd_jac = lambda tan, x: jax.vjp(lambda p: func(p, x), params,)[1](tan)[0] 498 | # sqjacob = lambda tan, x: bkd_jac(fwd_jac(tan, x), x) 499 | ID_MUL = 0e1 500 | sqjacob = lambda tan, x: tree_map(lambda a ,b: ID_MUL*a+b, tan, bkd_jac(fwd_jac(tan, x), x)) 501 | if self.adam: 502 | grad = tree_double(self.SG.adam(state)) 503 | else: 504 | grad = tree_double(self.SG.read(state)) 505 | # print(params) 506 | # print(grad) 507 | # tan_ = fwd_jac(grad, xs[0]) 508 | # # print(tan_) 509 | # # JT = jax.vjp(lambda p: func(p, xs[0]), params)[1] 510 | # # print(JT) 511 | # # tan_2 = JT(tan_) 512 | # tan_2 = bkd_jac(tan_, xs[0]) 513 | # print(tan_2) 514 | # exit() 515 | state = self.KI.observe(sqjacob, restrict_grad, xs, grad, state) 516 | nat_grad = self.KI.inv_pow(grad, 1, state) 517 | return state.copy({self.key: nat_grad}) 518 | 519 | def read(self, state): 520 | return state[self.key] 521 | 522 | 523 | 524 | class CGNG: 525 | 526 | def __init__(self, cfg, params): 527 | self.SG = SimpleGradient(cfg, params) 528 | self.CGinv = CGInverter(cfg, params) 529 | self.tikhonov = cfg['tikhonov'].get(1e0) 530 | self.adam = cfg['soln_adam'].get(False) 531 | self.norm = FisherNorm(cfg, params) 532 | self.gnorm = FisherNorm(cfg, params) 533 | self.param_Fnorm = FisherNorm(cfg, params) 534 | self.l2 = cfg['l2_regularization'].get(0e0) 535 | 536 | def init(self, state=flax.core.frozen_dict.freeze({})): 537 | state = self.SG.init(state) 538 | state = self.CGinv.init(state) 539 | state = self.norm.init(state) 540 | state = self.gnorm.init(state) 541 | state = self.param_Fnorm.init(state) 542 | return state 543 | 544 | def observe(self, func, lossfn, restrict_grad, xs, ls, params, state): 545 | tree_double = lambda tree: tree_map(lambda arr: arr.astype('float64'), tree) 546 | params=tree_double(params) 547 | state = self.SG.observe(lambda params, x, l: lossfn(l, func(params, x)), 548 | restrict_grad, 549 | xs, 550 | ls, 551 | params, 552 | state) 553 | fwd_jac = lambda tan, x: jax.jvp(lambda p: func(p, x), (params,), (tan,))[1] 554 | bkd_jac = lambda tan, x: jax.vjp(lambda p: func(p, x), params,)[1](tan)[0] 555 | sqjacob = lambda tan, x: bkd_jac(fwd_jac(tan, x), x) 556 | 557 | rparams = restrict_grad(params) 558 | grad = self.SG.read(state) 559 | grad = tree_map(lambda g, p: g + self.l2*p, grad, rparams) 560 | grad = tree_double(grad) 561 | 562 | # gradlen = jnp.sqrt(sqlen(grad)) 563 | state = 
self.gnorm.observe(func, xs, grad, params, state) 564 | eps = self.tikhonov 565 | # id_coeff = eps * self.gnorm.read(state)/sqlen(grad) 566 | id_coeff = eps 567 | tikhonov = lambda tan, x: tree_map(jax.lax.add, sqjacob(tan, x), tree_scale(id_coeff, tan)) 568 | 569 | state = self.CGinv.observe(tikhonov, restrict_grad, xs, grad, state) 570 | 571 | state = self.norm.observe(func, xs, self.read(state), params, state) 572 | state = self.param_Fnorm.observe(func, xs, rparams, rparams, state) 573 | return state 574 | 575 | def raw_grad(self, state): 576 | if self.adam: 577 | return self.SG.adam(state) 578 | else: 579 | return self.SG.read(state) 580 | 581 | def read(self, state): 582 | if self.adam: 583 | return self.CGinv.adam(state) 584 | else: 585 | return self.CGinv.read(state) 586 | 587 | def read_normed(self, state): 588 | return self.norm.apply(self.read(state), state) 589 | 590 | 591 | 592 | class KrylovNM: 593 | 594 | def __init__(self, cfg, params): 595 | self.SG = SimpleGradient(cfg, params) 596 | self.KI = KrylovInverter(cfg, params) 597 | 598 | def init(self, state): 599 | state = self.SG.init(state) 600 | state = self.KI.init(state) 601 | return state 602 | 603 | def observe(self, func, lossfn, params, xs, state): 604 | assert False 605 | #need to change sqjacob to hessian, then implement powers > 1 606 | state = self.SG.observe(state, lambda params, x: lossfn(func(params, x))) 607 | fwd_jac = lambda tan, x: jax.jvp(lambda p: func(p, x), (params,), (tan,))[1][0] 608 | bkd_jac = lambda tan, x: jax.vjp(lambda p: func(p, x), (params,))[1]((tan,))[0] 609 | sqjacob = lambda tan, x: bkd_jac(fwd_jac(tan, x), x) 610 | grad = self.SG.read(state) 611 | state = self.KI.observe(sqjacob, xs, grad, state) 612 | h2inv_grad = self.KI.inv_pow(grad, 2, state) 613 | return state.copy({self.key: nat_grad}) 614 | 615 | def read(self, state): 616 | return state[self.key] 617 | -------------------------------------------------------------------------------- /senn_mlp/requirements.txt: -------------------------------------------------------------------------------- 1 | tbp-nightly 2 | matplotlib 3 | flax 4 | confuse 5 | rtpt 6 | tqdm 7 | compose 8 | scikit-learn -------------------------------------------------------------------------------- /senn_mlp/results/.placeholder: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/self-expanding-neural-networks/3480be01cbbfa46726af3a84dd9fb834d1ca979e/senn_mlp/results/.placeholder -------------------------------------------------------------------------------- /senn_mlp/tboard.sh: -------------------------------------------------------------------------------- 1 | # IMAGE=$(docker build . --no-cache -q -t egwene) 2 | IMAGE=$(docker build . -q -t egwene) 3 | CONTAINER=$(docker run --rm -v $PWD:/water\ 4 | -v $PWD/results:/water/results\ 5 | -v $PWD/datasets:/datasets\ 6 | -v $PWD/logs:/logs\ 7 | -p 56006:6006\ 8 | -itd ${IMAGE}) 9 | docker exec ${CONTAINER} sh -c "tensorboard --logdir /logs &" 10 | docker attach ${CONTAINER} 11 | -------------------------------------------------------------------------------- /senn_mlp/use: -------------------------------------------------------------------------------- 1 | # IMAGE=$(docker build . --no-cache -q -t egwene) 2 | IMAGE=$(docker build . 
-q -t egwene) 3 | CONTAINER=$(docker run --rm -v $PWD:/senn\ 4 | -v $PWD/results:/senn/results\ 5 | -v $PWD/datasets:/datasets\ 6 | -v $PWD/logs:/logs\ 7 | -v $PWD/checkpoints:/checkpoints\ 8 | -itd --gpus='"device='$1'"' ${IMAGE}) 9 | docker exec ${CONTAINER} sh -c "tensorboard --logdir /logs &" 10 | docker attach ${CONTAINER} 11 | -------------------------------------------------------------------------------- /senn_mlp/visualisation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | from matplotlib import ticker 5 | from PIL import Image 6 | import glob 7 | 8 | 9 | def plot_decision_boundary(pred_func, X, Y, step, save_path): 10 | # Set min and max values and give it some padding 11 | x_min, x_max = X[:, 0].min() - 0.5, X[:, 0].max() + 0.5 12 | y_min, y_max = X[:, 1].min() - 0.5, X[:, 1].max() + 0.5 13 | h = 0.01 14 | # generate a grid of points with distance h between them 15 | xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h)) 16 | # predict on the whole grid 17 | Z = pred_func(np.c_[xx.ravel(), yy.ravel()]) 18 | Z = Z.reshape(xx.shape) 19 | # plot the contour 20 | levels = np.linspace(0.0, 1.0, 41) 21 | 22 | fig = plt.figure() 23 | ax = fig.add_subplot(111) 24 | cax = ax.contourf(xx, yy, Z, cmap='PRGn', levels=levels) 25 | cb = fig.colorbar(cax) 26 | ax.scatter(X[:, 0], X[:, 1], c=Y, cmap='Paired') 27 | 28 | # number of colorbar ticks 29 | tick_locator = ticker.MaxNLocator(nbins=10) 30 | cb.locator = tick_locator 31 | cb.update_ticks() 32 | 33 | # add a title with epoch 34 | plt.title("Epoch: " + '{:04d}'.format(step)) 35 | 36 | # Turn off the grid for this plot 37 | ax.grid(False) 38 | plt.tight_layout() 39 | 40 | plt.savefig(os.path.join(save_path, 'decision_boundary_step_' + '{:04d}'.format(step) + '.png'), 41 | bbox_inches='tight') 42 | 43 | 44 | def animate_decision_boundary(save_path): 45 | frames = [] 46 | imgs = glob.glob(save_path + '/*decision_boundary_step*.png') 47 | sorted_imgs = sorted(imgs) # I guess the wild card returns unsorted stuff 48 | 49 | for i in sorted_imgs: 50 | new_frame = Image.open(i) 51 | frames.append(new_frame) 52 | 53 | frames[0].save(os.path.join(save_path, 'decision_boundary_animation.gif'), format='GIF', 54 | append_images=frames[1:], duration=750, save_all=True, loop=0) 55 | 56 | 57 | def visualize_fisherD(F_logD, save_path): 58 | ticks_font_size = 10 59 | 60 | fig = plt.figure() 61 | ax = fig.add_subplot(111) 62 | 63 | ax.plot(F_logD) 64 | ax.grid(False) 65 | plt.xticks(fontsize=ticks_font_size) 66 | plt.yticks(fontsize=ticks_font_size) 67 | plt.xlabel('Optimization steps', fontsize=ticks_font_size + 4) 68 | plt.ylabel('|F|', fontsize=ticks_font_size + 4) 69 | 70 | plt.tight_layout() 71 | 72 | plt.savefig(os.path.join(save_path, 'fisher_logD.png'), bbox_inches='tight') 73 | 74 | # also save original data to be able to change/reproduce plots later 75 | np.save(os.path.join(save_path, 'fisher_logD.npy'), np.array(F_logD)) 76 | --------------------------------------------------------------------------------
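
The geometry helpers near the end of senn_mlp/nets.py (reject_from and combine) are easy to sanity-check numerically. The sketch below copies the two small functions verbatim and verifies, on made-up vectors, that the rejection is orthogonal to b and that combine returns the least-squares coefficients for c in the span of a and b up to a common scale (the Gram determinant), consistent with the "only accurate up to a scale factor" comment in the source. This is an illustrative check, not code from the repository.

import jax.numpy as jnp

# copied from senn_mlp/nets.py
def reject_from(a, b):
    # return the component of a orthogonal to b
    return a - b * jnp.sum(a * b) / jnp.sum(b * b)

def combine(a, b, c):
    # least-squares coefficients for c ~ alpha*a + beta*b, up to a scale factor
    alpha = (b * b).sum() * (a * c).sum() - (a * b).sum() * (b * c).sum()
    beta = (a * a).sum() * (b * c).sum() - (a * b).sum() * (a * c).sum()
    return alpha, beta

a = jnp.array([1.0, 2.0, 0.0])
b = jnp.array([0.0, 1.0, 1.0])
c = jnp.array([2.0, 1.0, -1.0])

print(jnp.dot(reject_from(a, b), b))  # ~0: rejection is orthogonal to b

alpha, beta = combine(a, b, c)
gram_det = jnp.sum(a * a) * jnp.sum(b * b) - jnp.sum(a * b) ** 2
exact, *_ = jnp.linalg.lstsq(jnp.stack([a, b], axis=1), c)
print(jnp.array([alpha, beta]) / gram_det)  # matches the lstsq solution below
print(exact)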
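
The EMA helper at the top of senn_mlp/optim.py keeps its running averages inside a flax FrozenDict state and applies an Adam-style bias correction 1/(1 - ((tau-1)/tau)^t) when read back. A minimal usage sketch, assuming optim.py is importable (e.g. the script is run from inside senn_mlp/) and feeding a constant value so the debiased readings can be checked by eye; the key name 'grad' and the tau value are arbitrary choices for the example.

import jax.numpy as jnp
import flax
from optim import EMA  # senn_mlp/optim.py

# running average of a 3-vector with time constant tau = 10
ema = EMA('grad', tau=10.0, initial=jnp.zeros(3))
state = ema.init(flax.core.frozen_dict.freeze({}))

for _ in range(5):
    state = ema.update(state, jnp.array([1.0, 2.0, 3.0]))

# the raw average has only reached ~41% of the target after 5 steps,
# but the debiased readings recover the constant input almost exactly
print(ema.mean(state))     # ~[1. 2. 3.]
print(ema.mean_sq(state))  # ~[1. 4. 9.]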
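
Finally, a sketch of how the plotting helpers in senn_mlp/visualisation.py might be driven. Everything here other than the two imported functions is a stand-in: the logistic pred function, the toy 2-D data, and the results output directory are invented for illustration and are not part of the experiments.

import os
import numpy as np
from visualisation import plot_decision_boundary, animate_decision_boundary  # senn_mlp/visualisation.py

# stand-in classifier: probability of class 1 for 2-D inputs, shape (N, 2) -> (N,)
def pred(points):
    return 1.0 / (1.0 + np.exp(-(points[:, 0] + points[:, 1])))

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 2))            # toy 2-D inputs
Y = (X[:, 0] + X[:, 1] > 0).astype(int)  # toy binary labels

save_path = 'results'                    # any writable directory
os.makedirs(save_path, exist_ok=True)
for step in range(3):                    # one frame per logged step
    plot_decision_boundary(pred, X, Y, step, save_path)
animate_decision_boundary(save_path)     # writes decision_boundary_animation.gif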