├── exp_random_nn ├── mnist_as_tf │ └── figure.png ├── mnist_paper_convnet_gp │ └── figure.png ├── mnist_paper_residual_cnn_gp │ └── figure.png ├── run.bash ├── random_comparison.py └── random_plot.py ├── requirements.txt ├── cnn_gp ├── __init__.py ├── kernel_save_tools.py ├── kernel_patch.py ├── data.py └── kernels.py ├── setup.py ├── configs ├── mnist_paper_convnet_gp.py ├── mnist_paper_residual_cnn_gp.py ├── mnist.py ├── cifar10.py └── mnist_as_tf.py ├── exp_mnist_resnet ├── merge_h5_files.py ├── run.bash ├── save_kernel.py └── classify_gp.py ├── LICENSE ├── .gitignore └── README.md /exp_random_nn/mnist_as_tf/figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cambridge-mlg/cnn-gp/HEAD/exp_random_nn/mnist_as_tf/figure.png -------------------------------------------------------------------------------- /exp_random_nn/mnist_paper_convnet_gp/figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cambridge-mlg/cnn-gp/HEAD/exp_random_nn/mnist_paper_convnet_gp/figure.png -------------------------------------------------------------------------------- /exp_random_nn/mnist_paper_residual_cnn_gp/figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cambridge-mlg/cnn-gp/HEAD/exp_random_nn/mnist_paper_residual_cnn_gp/figure.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # for main module 2 | numpy>=1.10.0 3 | torch>=1.1.0 4 | torchvision>=0.2.0 5 | tqdm>=4.32 6 | 7 | # For experiments 8 | h5py>=2.9.0 9 | scikit-learn>=0.21.0 10 | pandas>=0.24.0 11 | matplotlib>=3.1.0 12 | -------------------------------------------------------------------------------- /cnn_gp/__init__.py: 
#!/usr/bin/env bash
# Draw outputs from random finite networks for several channel counts, then
# plot them against the limiting kernel, once per config.

datasets_path="/scratch/ag919/datasets/"
n_samples=10000
seed=12
# random_plot.py reads exactly four sample files followed by four covariance
# files, in this channel order.
channels=(3 10 30 100)

# Edit which configs to do this for
for config in cifar10 mnist_paper_convnet_gp mnist_paper_residual_cnn_gp; do
    out_path="./exp_random_nn/${config}"
    # -p: some of these directories already exist in the repository
    mkdir -p "$out_path"

    sample_files=()
    cov_files=()
    for chans in "${channels[@]}"; do
        echo "Running with ${chans} channels"
        python -m exp_random_nn.random_comparison --datasets_path="$datasets_path" \
               --config="$config" --seed="$seed" --channels="$chans" --n_samples="$n_samples" \
               --out_path="$out_path"

        # random_comparison.py names its outputs "<channels:04d>_<seed:04d>_*.csv".
        # Listing them explicitly keeps stale CSVs from other seeds out of the plot.
        prefix="$(printf "%04d_%04d" "$chans" "$seed")"
        sample_files+=("${out_path}/${prefix}_samples.csv")
        cov_files+=("${out_path}/${prefix}_cov.csv")
    done

    python -m exp_random_nn.random_plot "${out_path}/figure.pdf" \
           "${sample_files[@]}" "${cov_files[@]}"
done
"""
Relatively inefficient way to merge the results of the workers.

Usage: merge_h5_files.py dest_file [source_file1 source_file2 ...]

Every dataset that exists in both `dest_file` and a source file is merged row
by row: entries that are NaN in the destination are replaced by the
corresponding entries of the source. `dest_file` is modified in place.
"""
import sys

import numpy as np


def merge_missing(dest, src):
    """
    Replace the NaN entries of `dest` by the matching entries of `src`.

    `dest` is modified in place and also returned. Both arguments must be
    arrays of the same shape.
    """
    dest_is_nan = np.isnan(dest)
    dest[dest_is_nan] = src[dest_is_nan]
    return dest


def merge_datasets(dest_data, src_data, progress=None):
    """
    Merge `src_data` into `dest_data` one leading-dimension row at a time.

    Works on anything that supports `.shape`, `len()` and `[i, ...]` indexing
    for reading and writing (h5py datasets, NumPy arrays). `progress`, if
    given, wraps the row iterator (e.g. `tqdm`).

    Raises `ValueError` if the trailing dimensions differ or if the source has
    fewer rows than the destination.
    """
    if (src_data.shape[1:] != dest_data.shape[1:]
            or len(src_data) < len(dest_data)):
        raise ValueError(
            f"Cannot merge data of shape {src_data.shape} into "
            f"shape {dest_data.shape}")

    rows = range(len(dest_data))
    if progress is not None:
        rows = progress(rows)
    for i in rows:
        # Read one row of each, patch the holes, write the row back.
        dest_data[i, ...] = merge_missing(dest_data[i, ...], src_data[i, ...])


def main(argv):
    """Merge the files named in `argv`; returns the process exit status."""
    if len(argv) < 3:
        print(f"Usage: {argv[0]} dest_file [source_file1 source_file2 ...]")
        return 1

    # Imported here so the NumPy-only helpers above can be imported (and
    # tested) without the HDF5 and progress-bar dependencies.
    import h5py
    from tqdm import tqdm

    _, dest_file, *src_files = argv

    with h5py.File(dest_file, "a") as dest_f:
        for path in tqdm(src_files):
            with h5py.File(path, "r") as src_f:
                # Only datasets present in both files are merged.
                valid_keys = [k for k in dest_f.keys() if k in src_f.keys()]
                for k in tqdm(valid_keys):
                    merge_datasets(dest_f[k], src_f[k], progress=tqdm)
    return 0


if __name__ == "__main__":
    sys.exit(main(sys.argv))
#!/usr/bin/env bash
# Compute the kernel matrices with one worker per GPU, merge the per-worker
# HDF5 files, then classify with the merged kernel.

CUDA_VISIBLE_DEVICES="0,1"
datasets_path="/scratch/ag919/datasets/"
out_path="/scratch/ag919/grams_pytorch/mnist_test3"
config="mnist_as_tf"
batch_size=200

if [ -d "$out_path" ]; then
    echo "Careful: directory \"$out_path\" already exists"
    exit 1
fi

# One worker per listed GPU
space_separated_cuda="${CUDA_VISIBLE_DEVICES//,/ }"
n_workers=$(echo $space_separated_cuda | wc -w)
if [ "$n_workers" -eq 0 ]; then
    echo "You must specify CUDA_VISIBLE_DEVICES"
    exit 1
fi

echo "Downloading dataset"
python -c "import configs.$config as c; import cnn_gp; cnn_gp.DatasetFromConfig(\"$datasets_path\", c)"

echo "Starting kernel computation workers in parallel"

mkdir "$out_path"
worker_rank=0
pids=()
for cuda_i in $space_separated_cuda; do
    this_worker="${out_path}/$(printf "%02d_nw%02d.h5" "$worker_rank" "$n_workers")"

    CUDA_VISIBLE_DEVICES=$cuda_i python -m exp_mnist_resnet.save_kernel --n_workers="$n_workers" \
           --worker_rank="$worker_rank" --datasets_path="$datasets_path" --batch_size="$batch_size" \
           --config="$config" --out_path="$this_worker" &
    # Remember every worker's PID so that all of them are waited for below
    pids+=("$!")
    worker_rank=$((worker_rank+1))
done

# Wait for all workers; do not merge partial results if any of them failed
failed=0
for pid in "${pids[@]}"; do
    if ! wait "$pid"; then
        echo "Worker with PID $pid failed" >&2
        failed=1
    fi
done
if [ "$failed" -ne 0 ]; then
    exit 1
fi

echo "combining all data sets in one"
# The first file of the glob (worker 00) is the destination of the merge
python -m exp_mnist_resnet.merge_h5_files "${out_path}"/*


echo "Classify using the complete set"
combined_file="${out_path}/$(printf "%02d_nw%02d.h5" 0 "$n_workers")"
python -m exp_mnist_resnet.classify_gp --datasets_path="$datasets_path" \
       --config="$config" --in_path="$combined_file"
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # Latex 107 | *.fls 108 | *.fdb_latexmk 109 | 110 | MNIST-data 111 | MNIST_data 112 | 113 | *.kp 114 | *.pt 115 | *.pyc 116 | *.csv -------------------------------------------------------------------------------- /configs/mnist.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | 3 | from cnn_gp import Conv2d, ReLU, Sequential, resnet_block 4 | 5 | train_range = range(50000) 6 | validation_range = range(50000, 60000) 7 | test_range = range(60000, 70000) 8 | 9 | dataset_name = "MNIST" 10 | model_name = "ResNet" 11 | dataset = torchvision.datasets.MNIST 12 | transforms = [] 13 | epochs = 0 14 | in_channels = 1 15 | out_channels = 10 16 | initial_model = Sequential( 17 | Conv2d(kernel_size=3), 18 | 19 | # Big resnet block #1 20 | resnet_block(stride=1, projection_shortcut=True, 
multiplier=1), 21 | resnet_block(stride=1, projection_shortcut=False, multiplier=1), 22 | resnet_block(stride=1, projection_shortcut=False, multiplier=1), 23 | resnet_block(stride=1, projection_shortcut=False, multiplier=1), 24 | resnet_block(stride=1, projection_shortcut=False, multiplier=1), 25 | 26 | # Big resnet block #2 27 | resnet_block(stride=2, projection_shortcut=True, multiplier=2), 28 | resnet_block(stride=1, projection_shortcut=False, multiplier=2), 29 | resnet_block(stride=1, projection_shortcut=False, multiplier=2), 30 | resnet_block(stride=1, projection_shortcut=False, multiplier=2), 31 | resnet_block(stride=1, projection_shortcut=False, multiplier=2), 32 | 33 | # Big resnet block #3 34 | resnet_block(stride=2, projection_shortcut=True, multiplier=4), 35 | resnet_block(stride=1, projection_shortcut=False, multiplier=4), 36 | resnet_block(stride=1, projection_shortcut=False, multiplier=4), 37 | resnet_block(stride=1, projection_shortcut=False, multiplier=4), 38 | resnet_block(stride=1, projection_shortcut=False, multiplier=4), 39 | 40 | # No nonlinearity here, the next Conv2d substitutes the average pooling 41 | Conv2d(kernel_size=7, padding=0, in_channel_multiplier=4, 42 | out_channel_multiplier=4), 43 | ReLU(), 44 | Conv2d(kernel_size=1, padding=0, in_channel_multiplier=4), 45 | ) 46 | -------------------------------------------------------------------------------- /configs/cifar10.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | from cnn_gp import Conv2d, ReLU, Sequential, resnet_block 3 | 4 | train_range = range(40000) 5 | validation_range = range(40000, 50000) 6 | test_range = range(50000, 60000) 7 | 8 | kernel_batch_size = 350 9 | 10 | dataset_name = "CIFAR10" 11 | model_name = "ResNet" 12 | in_channels = 3 13 | dataset = torchvision.datasets.CIFAR10 14 | transforms = [] 15 | epochs = 0 16 | initial_model = Sequential( 17 | Conv2d(kernel_size=3), 18 | 19 | # Big resnet block #1 20 | 
resnet_block(stride=1, projection_shortcut=True, multiplier=1), 21 | resnet_block(stride=1, projection_shortcut=False, multiplier=1), 22 | resnet_block(stride=1, projection_shortcut=False, multiplier=1), 23 | resnet_block(stride=1, projection_shortcut=False, multiplier=1), 24 | resnet_block(stride=1, projection_shortcut=False, multiplier=1), 25 | 26 | # Big resnet block #2 27 | resnet_block(stride=2, projection_shortcut=True, multiplier=2), 28 | resnet_block(stride=1, projection_shortcut=False, multiplier=2), 29 | resnet_block(stride=1, projection_shortcut=False, multiplier=2), 30 | resnet_block(stride=1, projection_shortcut=False, multiplier=2), 31 | resnet_block(stride=1, projection_shortcut=False, multiplier=2), 32 | 33 | # Big resnet block #3 34 | resnet_block(stride=2, projection_shortcut=True, multiplier=4), 35 | resnet_block(stride=1, projection_shortcut=False, multiplier=4), 36 | resnet_block(stride=1, projection_shortcut=False, multiplier=4), 37 | resnet_block(stride=1, projection_shortcut=False, multiplier=4), 38 | resnet_block(stride=1, projection_shortcut=False, multiplier=4), 39 | 40 | # No nonlinearity here, the next Conv2d substitutes the average pooling 41 | Conv2d(kernel_size=8, padding=0, in_channel_multiplier=4, 42 | out_channel_multiplier=4), 43 | Conv2d(kernel_size=1, padding=0, in_channel_multiplier=4, 44 | out_channel_multiplier=4), 45 | ReLU(), 46 | Conv2d(kernel_size=1, padding=0, in_channel_multiplier=4), 47 | ) 48 | -------------------------------------------------------------------------------- /configs/mnist_as_tf.py: -------------------------------------------------------------------------------- 1 | """ 2 | Very similar to `./mnist.py`. In this case, however, the contents of the 3 | train/validation/test sets are the same as in the original paper's experiments, 4 | which were written in TensorFlow. 
5 | """ 6 | import torchvision 7 | from cnn_gp import Conv2d, ReLU, Sequential, resnet_block 8 | 9 | train_range = range(5000, 55000) 10 | validation_range = list(range(55000, 60000)) + list(range(0, 5000)) 11 | test_range = range(60000, 70000) 12 | 13 | dataset_name = "MNIST" 14 | model_name = "ResNet" 15 | dataset = torchvision.datasets.MNIST 16 | transforms = [] 17 | epochs = 0 18 | in_channels = 1 19 | out_channels = 10 20 | initial_model = Sequential( 21 | Conv2d(kernel_size=3), 22 | 23 | # Big resnet block #1 24 | resnet_block(stride=1, projection_shortcut=True, multiplier=1), 25 | resnet_block(stride=1, projection_shortcut=False, multiplier=1), 26 | resnet_block(stride=1, projection_shortcut=False, multiplier=1), 27 | resnet_block(stride=1, projection_shortcut=False, multiplier=1), 28 | resnet_block(stride=1, projection_shortcut=False, multiplier=1), 29 | 30 | # Big resnet block #2 31 | resnet_block(stride=2, projection_shortcut=True, multiplier=2), 32 | resnet_block(stride=1, projection_shortcut=False, multiplier=2), 33 | resnet_block(stride=1, projection_shortcut=False, multiplier=2), 34 | resnet_block(stride=1, projection_shortcut=False, multiplier=2), 35 | resnet_block(stride=1, projection_shortcut=False, multiplier=2), 36 | 37 | # Big resnet block #3 38 | resnet_block(stride=2, projection_shortcut=True, multiplier=4), 39 | resnet_block(stride=1, projection_shortcut=False, multiplier=4), 40 | resnet_block(stride=1, projection_shortcut=False, multiplier=4), 41 | resnet_block(stride=1, projection_shortcut=False, multiplier=4), 42 | resnet_block(stride=1, projection_shortcut=False, multiplier=4), 43 | 44 | # No nonlinearity here, the next Conv2d substitutes the average pooling 45 | Conv2d(kernel_size=7, padding=0, in_channel_multiplier=4, 46 | out_channel_multiplier=4), 47 | ReLU(), 48 | Conv2d(kernel_size=1, padding=0, in_channel_multiplier=4), 49 | ) 50 | -------------------------------------------------------------------------------- 
def main(_):
    """
    Compute this worker's share of the kernel matrices and save them.

    Loads the config module `configs.<FLAGS.config>`, builds its dataset and
    moves its `initial_model` to the GPU. The train/train ("Kxx"),
    validation/train ("Kxvx") and test/train ("Kxtx") kernels are written to a
    fresh HDF5 file at `FLAGS.out_path`. The worker with rank 0 additionally
    writes the validation and test diagonals ("Kv_diag", "Kt_diag").
    """
    # Informational only: an unset variable must not abort the run.
    visible = os.environ.get('CUDA_VISIBLE_DEVICES', '<unset>')
    print(f"CUDA_VISIBLE_DEVICES={visible}")
    n_workers, worker_rank = FLAGS.n_workers, FLAGS.worker_rank
    config = importlib.import_module(f"configs.{FLAGS.config}")
    dataset = DatasetFromConfig(FLAGS.datasets_path, config)
    model = config.initial_model.cuda()

    def kern(x, x2, same, diag):
        # Evaluate the kernel on the GPU, hand the result back as a NumPy array
        with torch.no_grad():
            return model(x.cuda(), x2.cuda(), same,
                         diag).detach().cpu().numpy()

    kwargs = dict(worker_rank=worker_rank, n_workers=n_workers,
                  batch_size=FLAGS.batch_size, print_interval=2.)

    with h5py.File(FLAGS.out_path, "w") as f:
        save_K(f, kern, name="Kxx", X=dataset.train, X2=None, diag=False, **kwargs)
        save_K(f, kern, name="Kxvx", X=dataset.validation, X2=dataset.train, diag=False, **kwargs)
        save_K(f, kern, name="Kxtx", X=dataset.test, X2=dataset.train, diag=False, **kwargs)

    if worker_rank == 0:
        with h5py.File(FLAGS.out_path, "a") as f:
            save_K(f, kern, name="Kv_diag", X=dataset.validation, X2=None, diag=True, **kwargs)
            save_K(f, kern, name="Kt_diag", X=dataset.test, X2=None, diag=True, **kwargs)
def create_h5py_dataset(f, batch_size, name, diag, N, N2):
    """
    Creates a dataset named `name` on `f`, with chunks of size `batch_size`.
    The chunks have leading dimension 1, so as to accommodate future resizing
    of the leading dimension of the dataset (which starts at 1).

    The dataset has shape `(1, N)` if `diag`, else `(1, N, N2)`. It holds
    float32 values and is filled with NaN until written.
    """
    if diag:
        chunk_shape = (1, min(batch_size, N))
        shape = (1, N)
        maxshape = (None, N)
    else:
        chunk_shape = (1, min(batch_size, N), min(batch_size, N2))
        shape = (1, N, N2)
        maxshape = (None, N, N2)
    return f.create_dataset(name, shape=shape, dtype=np.float32,
                            fillvalue=np.nan, chunks=chunk_shape,
                            maxshape=maxshape)


def save_K(f, kern, name, X, X2, diag, batch_size, worker_rank=0, n_workers=1,
           print_interval=2.):
    """
    Saves a kernel to the h5py file `f`. Creates its dataset with name `name`
    if necessary.

    If `f` already contains `name`, nothing is computed. Otherwise
    `kern(x, x2, same, diag)` is evaluated for every batch yielded by the
    iterator, and the result is written into the matching slice of row 0 of
    the dataset.

    Raises `FloatingPointError` if `kern` returns a NaN or an infinity. NaN is
    the fill value of the dataset, so a written NaN could not be told apart
    from an entry that was never computed.
    """
    if name in f.keys():
        print("Skipping {} (group exists)".format(name))
        return

    N = len(X)
    N2 = N if X2 is None else len(X2)
    out = create_h5py_dataset(f, batch_size, name, diag, N, N2)

    if diag:
        # Don't split the load for diagonals, they are cheap
        it = DiagIterator(batch_size, X, X2)
    else:
        it = ProductIterator(batch_size, X, X2, worker_rank=worker_rank,
                             n_workers=n_workers)
    it = print_timings(it, desc=f"{name} (worker {worker_rank}/{n_workers})",
                       print_interval=print_interval)

    for same, (i, (x, _y)), (j, (x2, _y2)) in it:
        k = kern(x, x2, same, diag)
        if not np.all(np.isfinite(k)):
            # Fail loudly instead of dropping into an interactive debugger:
            # the workers run unattended.
            raise FloatingPointError(
                f"About to write a nan or inf for {i},{j} of {name}")

        if diag:
            out[0, i:i+len(x)] = k
        else:
            out[0, i:i+len(x), j:j+len(x2)] = k
results.append(nn(inputs)[:, 0, 0, 0].cpu().numpy()) 34 | r0.append(results[-1][0]) 35 | del nn 36 | 37 | samples_output_filename = os.path.join( 38 | FLAGS.out_path, 39 | f"{FLAGS.channels:04d}_{FLAGS.seed:04d}_samples.csv") 40 | pd.DataFrame({ 41 | 'r0': np.array(r0) / np.sqrt(true_cov[0, 0]), 42 | }).to_csv(samples_output_filename, index=False) 43 | 44 | Ni = inputs.shape[0] 45 | i = np.arange(Ni) * np.ones([Ni, 1]) 46 | j = i.T 47 | R = np.vstack(results) 48 | est_cov = R.T @ R / FLAGS.n_samples 49 | 50 | cov_output_filename = os.path.join( 51 | FLAGS.out_path, 52 | f"{FLAGS.channels:04d}_{FLAGS.seed:04d}_cov.csv") 53 | pd.DataFrame({ 54 | 'i': i.ravel(), 55 | 'j': j.ravel(), 56 | 'est': est_cov.ravel(), 57 | 'true': true_cov.ravel() 58 | }).to_csv(cov_output_filename, index=False) 59 | 60 | if __name__ == '__main__': 61 | f = absl.app.flags 62 | f.DEFINE_string("datasets_path", "/scratch/ag919/datasets/", 63 | "where to save datasets") 64 | f.DEFINE_string("out_path", None, 65 | "where to save the drawn outputs of the NN and kernel") 66 | f.DEFINE_string("config", "cifar10", "which config to load from `configs`") 67 | f.DEFINE_integer("seed", 1, "the random seed") 68 | f.DEFINE_integer("channels", 30, "the number of channels of the random finite NNs") 69 | f.DEFINE_integer("n_samples", 10000, "Number of samples to draw from the NN") 70 | absl.app.run(main) 71 | -------------------------------------------------------------------------------- /cnn_gp/kernel_patch.py: -------------------------------------------------------------------------------- 1 | __all__ = ('ConvKP', 'NonlinKP') 2 | 3 | 4 | class KernelPatch: 5 | """ 6 | Represents a block of the kernel matrix. 7 | Critically, we need the variances of the rows and columns, even if the 8 | diagonal isn't part of the block, and this introduces considerable 9 | complexity. 
10 | In particular, we also need to know whether the 11 | rows and columns of the matrix correspond, in which case, we need to do 12 | something different when we add IID noise. 13 | """ 14 | def __init__(self, same_or_kp, diag=False, xy=None, xx=None, yy=None): 15 | if isinstance(same_or_kp, KernelPatch): 16 | same = same_or_kp.same 17 | diag = same_or_kp.diag 18 | xy = same_or_kp.xy 19 | xx = same_or_kp.xx 20 | yy = same_or_kp.yy 21 | else: 22 | same = same_or_kp 23 | 24 | self.Nx = xx.size(0) 25 | self.Ny = yy.size(0) 26 | self.W = xy.size(-2) 27 | self.H = xy.size(-1) 28 | 29 | self.init(same, diag, xy, xx, yy) 30 | 31 | def __radd__(self, other): 32 | return self.__add__(other) 33 | 34 | def __rmul__(self, other): 35 | return self.__mul__(other) 36 | 37 | def __add__(self, other): 38 | return self._do_elementwise(other, '__add__') 39 | 40 | def __mul__(self, other): 41 | return self._do_elementwise(other, '__mul__') 42 | 43 | def _do_elementwise(self, other, op): 44 | KP = type(self) 45 | if isinstance(other, KernelPatch): 46 | other = KP(other) 47 | assert self.same == other.same 48 | assert self.diag == other.diag 49 | return KP( 50 | self.same, 51 | self.diag, 52 | getattr(self.xy, op)(other.xy), 53 | getattr(self.xx, op)(other.xx), 54 | getattr(self.yy, op)(other.yy) 55 | ) 56 | else: 57 | return KP( 58 | self.same, 59 | self.diag, 60 | getattr(self.xy, op)(other), 61 | getattr(self.xx, op)(other), 62 | getattr(self.yy, op)(other) 63 | ) 64 | 65 | 66 | class ConvKP(KernelPatch): 67 | def init(self, same, diag, xy, xx, yy): 68 | self.same = same 69 | self.diag = diag 70 | if diag: 71 | self.xy = xy.view(self.Nx, 1, self.W, self.H) 72 | else: 73 | self.xy = xy.view(self.Nx*self.Ny, 1, self.W, self.H) 74 | self.xx = xx.view(self.Nx, 1, self.W, self.H) 75 | self.yy = yy.view(self.Ny, 1, self.W, self.H) 76 | 77 | 78 | class NonlinKP(KernelPatch): 79 | def init(self, same, diag, xy, xx, yy): 80 | self.same = same 81 | self.diag = diag 82 | if diag: 83 | self.xy 
= xy.view(self.Nx, 1, self.W, self.H) 84 | self.xx = xx.view(self.Nx, 1, self.W, self.H) 85 | self.yy = yy.view(self.Ny, 1, self.W, self.H) 86 | else: 87 | self.xy = xy.view(self.Nx, self.Ny, self.W, self.H) 88 | self.xx = xx.view(self.Nx, 1, self.W, self.H) 89 | self.yy = yy.view( self.Ny, self.W, self.H) 90 | -------------------------------------------------------------------------------- /exp_mnist_resnet/classify_gp.py: -------------------------------------------------------------------------------- 1 | """ 2 | Given a pre-computed kernel and a data set, compute train/validation/test accuracy. 3 | """ 4 | import absl.app 5 | import h5py 6 | import numpy as np 7 | import scipy.linalg 8 | import torch 9 | import sklearn.metrics 10 | import scipy 11 | 12 | import importlib 13 | from cnn_gp import DatasetFromConfig 14 | FLAGS = absl.app.flags.FLAGS 15 | 16 | 17 | def solve_system(Kxx, Y): 18 | print("Running scipy solve Kxx^-1 Y routine") 19 | assert Kxx.dtype == torch.float64 and Y.dtype == torch.float64, """ 20 | It is important that `Kxx` and `Y` are `float64`s for the inversion, 21 | even if they were `float32` when being calculated. This makes the 22 | inversion much less likely to complain about the matrix being singular. 
def diag_add(K, diag):
    """
    Add `diag` (a scalar, or one value per diagonal entry) to the main
    diagonal of the last two dimensions of `K`, in place.

    `K` must be a `torch.Tensor` or a `np.ndarray`; anything else raises
    `TypeError`. Only true diagonal entries are touched, also when `K` is not
    square.
    """
    if isinstance(K, torch.Tensor):
        idx = torch.arange(min(K.shape[-2:]), device=K.device)
    elif isinstance(K, np.ndarray):
        idx = np.arange(min(K.shape[-2:]))
    else:
        raise TypeError("What do I do with a `{}`, K={}?".format(type(K), K))
    # Index pairs (i, i) instead of a stride over the flattened matrix: this
    # cannot run past the diagonal and does not need contiguous memory.
    K[..., idx, idx] += diag


def print_accuracy(A, Kxvx, Y, key):
    """
    Print the accuracy of the predictions `Kxvx @ A` against the labels `Y`.

    The predicted class of each row is the column with the largest value.
    `key` names the data set in the printed message.
    """
    Ypred = (Kxvx @ A).argmax(dim=1)
    acc = sklearn.metrics.accuracy_score(Y, Ypred)
    print(f"{key} accuracy: {acc*100}%")


def load_kern(dset, i):
    """
    Read entry `i` along the leading dimension of the dataset `dset` and
    return it as a `float64` tensor. The data is read as `float32` first.
    """
    A = np.empty(dset.shape[1:], dtype=np.float32)
    dset.read_direct(A, source_sel=np.s_[i, :, :])
    return torch.from_numpy(A).to(dtype=torch.float64)


def main(_):
    """
    Classify the validation and test sets with a pre-computed kernel.

    The training labels are turned into targets that are +1 for the correct
    class and -1 elsewhere. After adding `FLAGS.jitter` to the diagonal of
    "Kxx", the system `Kxx A = targets` is solved, and the accuracies obtained
    with the cross-kernels "Kxvx" and "Kxtx" are printed.
    """
    config = importlib.import_module(f"configs.{FLAGS.config}")
    dataset = DatasetFromConfig(FLAGS.datasets_path, config)

    print("Reading training labels")
    _, Y = dataset.load_full(dataset.train)
    n_classes = int(Y.max()) + 1
    Y_1hot = torch.ones((len(Y), n_classes), dtype=torch.float64).neg_()  # all -1
    Y_1hot[torch.arange(len(Y)), Y] = 1.

    with h5py.File(FLAGS.in_path, "r") as f:
        print("Loading kernel")
        Kxx = load_kern(f["Kxx"], 0)
        diag_add(Kxx, FLAGS.jitter)

        print("Solving Kxx^{-1} Y")
        A = solve_system(Kxx, Y_1hot)

        # The cross-kernels are deleted as soon as they have been used
        _, Yv = dataset.load_full(dataset.validation)
        Kxvx = load_kern(f["Kxvx"], 0)
        print_accuracy(A, Kxvx, Yv, "validation")
        del Kxvx
        del Yv

        _, Yt = dataset.load_full(dataset.test)
        Kxtx = load_kern(f["Kxtx"], 0)
        print_accuracy(A, Kxtx, Yt, "test")
        del Kxtx
        del Yt


# Example run (library start-up warnings omitted):
#
# $ python classify_gp.py --in_path=/scratch/ag919/grams_pytorch/mnist_as_tf/00_nwork07.h5 --config=mnist_as_tf
# Reading training labels
# Loading kernel
# Solving Kxx^{-1} Y
# Running scipy solve Kxx^-1 Y routine
# train accuracy: 10.26%
# validation accuracy: 99.31%
# test accuracy: 99.11999999999999%
#
# NOTE(review): `main` above prints no train accuracy, so this transcript
# presumably comes from an earlier version of the script -- confirm.


if __name__ == '__main__':
    f = absl.app.flags
    f.DEFINE_string("datasets_path", "/scratch/ag919/datasets/",
                    "where to save datasets")
    f.DEFINE_string("config", "mnist", "which config to load from `configs`")
    f.DEFINE_string('in_path', "/scratch/ag919/grams_pytorch/mnist/dest.h5",
                    "path of h5 file to load kernels from")
    f.DEFINE_float("jitter", 0.0, "add to the diagonal")
    absl.app.run(main)
# Basic measurements. Names ending in `_in` are in inches; names ending in
# `_s` are the same quantity scaled to figure-relative units (divided by the
# figure width for horizontal sizes and by the figure height for vertical
# ones).
nrows = 3
ncols = 4
points = 10  # Font size
fig_w_in = 5.5  # Plot width (inches)
panel_wh_ratio = 0.9
# Margins inside each panel (left, right, top, bottom)
panel_lm_in = 0.55
panel_rm_in = 0.05
panel_tm_in = 0.2
panel_bm_in = 0.45
# Margins around the whole figure
fig_lm_in = 0.
fig_rm_in = 0.
fig_tm_in = 0.
fig_bm_in = 0.
y_labelpad = 5

panel_w_in = (fig_w_in - fig_lm_in - fig_rm_in)/ncols
panel_h_in = panel_wh_ratio * panel_w_in
fig_h_in = nrows*panel_h_in + fig_tm_in + fig_bm_in

panel_w_s = panel_w_in / fig_w_in
panel_h_s = panel_h_in / fig_h_in

panel_lm_s = panel_lm_in / fig_w_in
panel_rm_s = panel_rm_in / fig_w_in
panel_tm_s = panel_tm_in / fig_h_in
panel_bm_s = panel_bm_in / fig_h_in

fig_lm_s = fig_lm_in / fig_w_in
fig_rm_s = fig_rm_in / fig_w_in
fig_tm_s = fig_tm_in / fig_h_in
fig_bm_s = fig_bm_in / fig_h_in

# One typographic point (1/72 inch) in figure-relative units.
pt_w_s = 1/72/fig_w_in
# Vertical sizes are relative to the figure *height*; this previously divided
# by the width, which made every vertical offset derived from it too small.
pt_h_s = 1/72/fig_h_in
char_w_s = pt_w_s*points
char_h_s = pt_h_s*points


def bottom_margin(row):
    """Return the figure-relative y coordinate of the axes bottom in `row`
    (row 0 is the top row)."""
    return (nrows - (row + 1))*panel_h_s + panel_bm_s + fig_bm_s


def left_margin(col):
    """Return the figure-relative x coordinate of the axes left edge in
    `col` (column 0 is the leftmost)."""
    return col*panel_w_s + panel_lm_s + fig_lm_s


def rect(row, col):
    """Return `[left, bottom, width, height]` in figure-relative units for
    the axes of the panel at (`row`, `col`), with panel margins removed."""
    return [left_margin(col),
            bottom_margin(row),
            panel_w_s - panel_lm_s - panel_rm_s,
            panel_h_s - panel_tm_s - panel_bm_s]
label(ax, s): 71 | lmbm, rmtm = ax.get_position().get_points() 72 | lm, bm = lmbm 73 | rm, tm = rmtm 74 | w = rm - lm 75 | h = tm - bm 76 | ax.figure.text(lm-3.3*char_w_s, tm+char_h_s, r'\textbf{' + s + r'}') 77 | 78 | fig = plt.figure(figsize=(fig_w_in, fig_h_in)) 79 | 80 | def set_ylabel_coords(ax, yshift=0): 81 | lmbm, rmtm = ax.get_position().get_points() 82 | lm, bm = lmbm 83 | rm, tm = rmtm 84 | w = rm - lm 85 | h = tm - bm 86 | ax.yaxis.set_label_coords(lm-2.5*char_w_s, bm+h/2+h*yshift, transform = ax.figure.transFigure) 87 | 88 | z = scipy.stats.norm(0, 1) 89 | lim = 4 90 | titles = ["C=3", "C=10", "C=30", "C=100"] 91 | for i in range(4): 92 | ax=fig.add_axes(rect(0, i)) 93 | df = pd.read_csv(sample_filenames[i]) 94 | ax.hist(np.array(df.r0), bins=50, range=(-lim, lim), density=True) 95 | xs = np.linspace(-lim, lim, 100) 96 | ax.plot(xs, z.pdf(xs), linewidth=1) 97 | ax.set_ylim(0, 0.7) 98 | ax.spines['right'].set_visible(False) 99 | ax.spines['top'].set_visible(False) 100 | ax.set_title(titles[i], pad=-5) 101 | ax.set_xlim(-lim, lim) 102 | ax.set_xticks([-lim, 0, lim]) 103 | ax.set_xlabel('output') 104 | if i == 0: 105 | label(ax, 'A') 106 | ax.set_ylabel("pdf") 107 | set_ylabel_coords(ax) 108 | 109 | 110 | for i in range(4): 111 | ax=fig.add_axes(rect(1, i)) 112 | df = pd.read_csv(sample_filenames[i]) 113 | 114 | xs, ys = scipy.stats.probplot(np.array(df.r0), dist=z, fit=False) 115 | 116 | ax.plot(xs, ys, linewidth=1) 117 | ax.plot([-lim, lim], [-lim, lim], linewidth=1) 118 | ax.spines['right'].set_visible(False) 119 | ax.spines['top'].set_visible(False) 120 | ax.set_xlim(-lim, lim) 121 | ax.set_ylim(-lim, lim) 122 | ax.set_xticks([-lim, 0, lim]) 123 | ax.set_yticks([-lim, 0, lim]) 124 | ax.set_xlabel('limiting q.') 125 | if i == 0: 126 | label(ax, 'B') 127 | ax.set_ylabel('sampled q.') 128 | set_ylabel_coords(ax) 129 | 130 | for i in range(4): 131 | ax=fig.add_axes(rect(2, i)) 132 | df = pd.read_csv(cov_filenames[i]) 133 | 134 | est = np.array(df.est) 135 
| true = np.array(df.true) 136 | hi_lim = int(1.1 * np.max([est, true])) 137 | order = 10**(len(str(hi_lim))-1) 138 | lims = (0, ((hi_lim+order-1)//order) * order) 139 | 140 | ax.plot(lims, lims, color='tab:orange', linewidth=1) 141 | ax.scatter(true, est, 0.3, color='tab:blue') 142 | ax.spines['right'].set_visible(False) 143 | ax.spines['top'].set_visible(False) 144 | ax.set_xlabel('limiting cov.') 145 | ax.set_xlim(*lims) 146 | ax.set_ylim(*lims) 147 | ax.set_xticks(np.linspace(*lims, 3)) 148 | ax.set_yticks(np.linspace(*lims, 3)) 149 | if i == 0: 150 | label(ax, 'C') 151 | ax.set_ylabel('sampled cov.') 152 | set_ylabel_coords(ax, yshift=-0.05) 153 | 154 | fig.savefig(output_filename, dpi=400) 155 | -------------------------------------------------------------------------------- /cnn_gp/data.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | from torch.utils.data import ConcatDataset, DataLoader, Subset 3 | import os 4 | import numpy as np 5 | import itertools 6 | 7 | __all__ = ('DatasetFromConfig', 'ProductIterator', 'DiagIterator', 8 | 'print_timings') 9 | 10 | 11 | def _this_worker_batch(N_batches, worker_rank, n_workers): 12 | batches_per_worker = np.zeros([n_workers], dtype=np.int) 13 | batches_per_worker[:] = N_batches // n_workers 14 | batches_per_worker[:N_batches % n_workers] += 1 15 | 16 | start_batch = np.sum(batches_per_worker[:worker_rank]) 17 | batches_this_worker = batches_per_worker[worker_rank] 18 | 19 | return int(start_batch), int(batches_this_worker) 20 | 21 | 22 | def _product_generator(N_batches_X, N_batches_X2, same): 23 | for i in range(N_batches_X): 24 | if same: 25 | # Yield only upper triangle 26 | yield (True, i, i) 27 | for j in range(i+1 if same else 0, 28 | N_batches_X2): 29 | yield (False, i, j) 30 | 31 | 32 | def _round_up_div(a, b): 33 | return (a+b-1)//b 34 | 35 | 36 | class ProductIterator(object): 37 | """ 38 | Returns an iterator for loading data from both X and X2. 
class DiagIterator(object):
    """
    Iterates over batches of `X` (and optionally `X2`) in lockstep.

    Each item is `(same, (ib, batch), (ib, batch2))`, where `ib` is the
    index of the first element of the batch (`batch_index * batch_size`).
    If `X2` is None, `same` is True and `batch2` is the very same object as
    `batch`. Otherwise `same` is False, `batch2` is the corresponding batch
    of `X2`, and iteration stops when the shorter dataset runs out.

    Batches are whatever `DataLoader` yields for the dataset (presumably
    `(inputs, labels)` pairs, going by the `xy` naming -- confirm with the
    callers).
    """
    def __init__(self, batch_size, X, X2=None):
        self.batch_size = batch_size
        dl = DataLoader(X, batch_size=batch_size)
        if X2 is None:
            self.same = True
            self.it = iter(enumerate(dl))
            self.length = len(dl)
        else:
            dl2 = DataLoader(X2, batch_size=batch_size)
            self.same = False
            self.it = iter(enumerate(zip(dl, dl2)))
            self.length = min(len(dl), len(dl2))

    def __iter__(self):
        return self

    def __len__(self):
        return self.length

    def __next__(self):
        if self.same:
            i, xy = next(self.it)
            xy2 = xy
        else:
            # `enumerate(zip(...))` yields `(i, (batch, batch2))`, so the
            # pair has to be unpacked as a nested tuple.
            i, (xy, xy2) = next(self.it)
        ib = i*self.batch_size
        return (self.same, (ib, xy), (ib, xy2))
f"{m:02d}:{s:02d}" 170 | else: 171 | return f"{h:02d}:{m:02d}:{s:02d}" 172 | 173 | 174 | def print_timings(iterator, desc="time", print_interval=2.): 175 | """ 176 | Prints the current total number of iterations, speed of iteration, and 177 | elapsed time. 178 | 179 | Meant as a rudimentary replacement for `tqdm` that prints a new line at 180 | each iteration, and thus can be used in multiple parallel processes in the 181 | same terminal. 182 | """ 183 | import time 184 | start_time = time.perf_counter() 185 | total = len(iterator) 186 | last_printed = -print_interval 187 | for i, value in enumerate(iterator): 188 | yield value 189 | cur_time = time.perf_counter() 190 | elapsed = cur_time - start_time 191 | it_s = (i+1)/elapsed 192 | total_s = total/it_s 193 | if elapsed > last_printed + print_interval: 194 | print(f"{desc}: {i+1}/{total} it, {it_s:.02f} it/s," 195 | f"[{_hhmmss(elapsed)}<{_hhmmss(total_s)}]") 196 | last_printed = elapsed 197 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # "Deep CNNs as shallow GPs" in Pytorch 2 | Code for "Deep Convolutional Networks as shallow Gaussian Processes" 3 | ([arXiv](https://arxiv.org/abs/1808.05587), 4 | [other material](https://agarri.ga/publication/convnets-as-gps/)), by 5 | [Adrià Garriga-Alonso](https://agarri.ga/), 6 | [Laurence Aitchison](http://www.gatsby.ucl.ac.uk/~laurence/) and 7 | [Carl Edward Rasmussen](http://mlg.eng.cam.ac.uk/carl/). 8 | 9 | The most extensively used libraries 10 | are [PyTorch](https://pytorch.org/), [NumPy](https://www.numpy.org/) and 11 | [H5Py](http://www.h5py.org/), check `requirements.txt` for the rest. 12 | 13 | # The `cnn_gp` package 14 | This library allows you to very easily write down neural network architectures, 15 | and get the kernels corresponding to their equivalent GPs. We can easily build 16 | `Sequential` architectures. 
For example, a 3-layer convolutional network with a 17 | dense layer at the end is: 18 | 19 | ```python 20 | from cnn_gp import Sequential, Conv2d, ReLU 21 | 22 | model = Sequential( 23 | Conv2d(kernel_size=3), 24 | ReLU(), 25 | Conv2d(kernel_size=3, stride=2), 26 | ReLU(), 27 | Conv2d(kernel_size=14, padding=0), # equivalent to a dense layer 28 | ) 29 | ``` 30 | Optionally call `model = model.cuda()` to use the GPU. 31 | 32 | Then, we can compute the kernel between batches of input images: 33 | ```python 34 | import torch 35 | # X and Z have shape [N_images, N_channels, img_width, img_height] 36 | X = torch.randn(2, 3, 28, 28) 37 | Z = torch.randn(2, 3, 28, 28) 38 | 39 | Kxx = model(X) 40 | Kxx = model(X, X, same=True) 41 | 42 | Kxz = model(X, Z) 43 | 44 | # diagonal of Kxx matrix above 45 | Kxx_diag = model(X, diag=True) 46 | ``` 47 | 48 | We can also instantiate randomly initialized neural networks that have the 49 | architecture corresponding to the kernel. 50 | ```python 51 | network = model.nn(channels=16, in_channels=3, out_channels=10) 52 | isinstance(network, torch.nn.Module) # evaluates to True 53 | 54 | f_X = network(X) # evaluates network at X 55 | ``` 56 | Calling `model.nn` will give us an instance of the network above that can do 10-class 57 | classification. It accepts inputs that are RGB images (3 channels) of size 58 | 28x28. We can then train this neural network as we would any normal Pytorch 59 | model. 60 | 61 | ## Installation 62 | It is possible to install the `cnn_gp` package without any of the dependencies 63 | that are needed for the experiments. Just run 64 | ```sh 65 | pip install -e . 66 | ``` 67 | from the root directory of this same repository. 68 | 69 | ## Current limitations 70 | Dense layers are not implemented. The way to simulate them is to have a 71 | convolutional layer with `padding=0`, and with `kernel_size` as large as the 72 | activations in the previous layer. 
73 | 74 | # Replicating the experiments 75 | 76 | First install the packages in `requirements.txt`. To run each of the 77 | experiments, first take a look at the files `exp_mnist_resnet/run.bash` or 78 | `exp_random_nn/run.bash`. Edit the configuration variables near the top 79 | appropriately. Then, run one of the files from the root of the directory, for 80 | example: 81 | 82 | ```bash 83 | bash ./exp_mnist_resnet/run.bash 84 | ``` 85 | 86 | ## Experiment 1: classify MNIST 87 | 88 | Here are the test errors for the best GPs corresponding to the NN architectures 89 | reported in the paper. 90 | 91 | Name in paper | Config file | Validation error | Test error 92 | --------------|-------------|------------------|---------- 93 | ConvNet GP | `mnist_paper_convnet_gp` | 0.71% | 1.03% 94 | Residual CNN GP | `mnist_paper_residual_cnn_gp` | 0.72% | 0.96% 95 | ResNet GP | `mnist_as_tf` | 0.68% | 0.84% 96 | 97 |
98 | (click to expand) Architecture for ConvNet GP 99 | 100 | ```python 101 | var_bias = 7.86 102 | var_weight = 2.79 103 | 104 | initial_model = Sequential( 105 | Conv2d(kernel_size=7, padding="same", var_weight=var_weight * 7**2, var_bias=var_bias), 106 | ReLU(), 107 | Conv2d(kernel_size=7, padding="same", var_weight=var_weight * 7**2, var_bias=var_bias), 108 | ReLU(), 109 | Conv2d(kernel_size=7, padding="same", var_weight=var_weight * 7**2, var_bias=var_bias), 110 | ReLU(), 111 | Conv2d(kernel_size=7, padding="same", var_weight=var_weight * 7**2, var_bias=var_bias), 112 | ReLU(), 113 | Conv2d(kernel_size=7, padding="same", var_weight=var_weight * 7**2, var_bias=var_bias), 114 | ReLU(), 115 | Conv2d(kernel_size=7, padding="same", var_weight=var_weight * 7**2, var_bias=var_bias), 116 | ReLU(), 117 | Conv2d(kernel_size=7, padding="same", var_weight=var_weight * 7**2, var_bias=var_bias), 118 | ReLU(), # Total 7 layers before dense 119 | 120 | Conv2d(kernel_size=28, padding=0, var_weight=var_weight, var_bias=var_bias), 121 | ``` 122 |
123 |
124 | (click to expand) Architecture for Residual CNN GP 125 | 126 | ```python 127 | var_bias = 4.69 128 | var_weight = 7.27 129 | initial_model = Sequential( 130 | *(Sum([ 131 | Sequential(), 132 | Sequential( 133 | Conv2d(kernel_size=4, padding="same", var_weight=var_weight * 4**2, 134 | var_bias=var_bias), 135 | ReLU(), 136 | )]) for _ in range(8)), 137 | Conv2d(kernel_size=4, padding="same", var_weight=var_weight * 4**2, 138 | var_bias=var_bias), 139 | ReLU(), 140 | Conv2d(kernel_size=28, padding=0, var_weight=var_weight, 141 | var_bias=var_bias), 142 | ) 143 | ``` 144 |
145 | 146 |
147 | (click to expand) Architecture for ResNet GP 148 | 149 | ```python 150 | initial_model = Sequential( 151 | Conv2d(kernel_size=3), 152 | 153 | # Big resnet block #1 154 | resnet_block(stride=1, projection_shortcut=True, multiplier=1), 155 | resnet_block(stride=1, projection_shortcut=False, multiplier=1), 156 | resnet_block(stride=1, projection_shortcut=False, multiplier=1), 157 | resnet_block(stride=1, projection_shortcut=False, multiplier=1), 158 | resnet_block(stride=1, projection_shortcut=False, multiplier=1), 159 | 160 | # Big resnet block #2 161 | resnet_block(stride=2, projection_shortcut=True, multiplier=2), 162 | resnet_block(stride=1, projection_shortcut=False, multiplier=2), 163 | resnet_block(stride=1, projection_shortcut=False, multiplier=2), 164 | resnet_block(stride=1, projection_shortcut=False, multiplier=2), 165 | resnet_block(stride=1, projection_shortcut=False, multiplier=2), 166 | 167 | # Big resnet block #3 168 | resnet_block(stride=2, projection_shortcut=True, multiplier=4), 169 | resnet_block(stride=1, projection_shortcut=False, multiplier=4), 170 | resnet_block(stride=1, projection_shortcut=False, multiplier=4), 171 | resnet_block(stride=1, projection_shortcut=False, multiplier=4), 172 | resnet_block(stride=1, projection_shortcut=False, multiplier=4), 173 | 174 | # No nonlinearity here, the next Conv2d substitutes the average pooling 175 | Conv2d(kernel_size=7, padding=0, in_channel_multiplier=4, 176 | out_channel_multiplier=4), 177 | ReLU(), 178 | Conv2d(kernel_size=1, padding=0, in_channel_multiplier=4), 179 | ) 180 | ``` 181 |
182 | 183 | ## Experiment 2: Check that networks converge to a Gaussian process 184 | In the paper, only the ResNet-32 GP is presented. This is why an issue when 185 | constructing the Residual CNN GP was originally not caught. More details in the 186 | relevant subsection. 187 | ### ResNet-32 GP 188 | ![Resnet-32 GP](/exp_random_nn/mnist_as_tf/figure.png) 189 | ### ConvNet GP 190 | ![ConvNet GP](/exp_random_nn/mnist_paper_convnet_gp/figure.png) 191 | ### Residual CNN GP 192 | The best randomly-searched ResNet reported in the paper. 193 | 194 | In the original paper there is a slight issue with how the kernels relate to the 195 | underlying networks. The network sums together layers after the ReLU nonlinearity, 196 | which are not Gaussian, and also do not have mean 0. However, the kernel is valid 197 | and does correspond to a neural network. In particular, if we place an infinitely wide 198 | 1x1 convolution after each ReLU layer, 199 | we convert the outputs of the ReLUs into zero-mean Gaussians 200 | with the same kernel, which can be summed. 201 | In the interest of making the results replicable, we have replicated this issue 202 | as well. 203 | 204 | The correct way to use ResNets is to sum things after a Conv2d layer, see for 205 | example the `resnet_block` in [`cnn_gp/kernels.py`](/cnn_gp/kernels.py). 206 | 207 | # BibTeX citation record 208 | Note: the version in arXiv is slightly newer and contains information about 209 | which hyperparameters turned out to be the most effective for each architecture.
210 | 211 | ```bibtex 212 | @inproceedings{aga2018cnngp, 213 | author = {{Garriga-Alonso}, Adri{\`a} and Aitchison, Laurence and Rasmussen, Carl Edward}, 214 | title = {Deep Convolutional Networks as shallow {G}aussian Processes}, 215 | booktitle = {International Conference on Learning Representations}, 216 | year = {2019}, 217 | url = {https://openreview.net/forum?id=Bklfsi0cKm}} 218 | ``` 219 | -------------------------------------------------------------------------------- /cnn_gp/kernels.py: -------------------------------------------------------------------------------- 1 | import torch as t 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from .kernel_patch import ConvKP, NonlinKP 6 | import math 7 | 8 | 9 | __all__ = ("NNGPKernel", "Conv2d", "ReLU", "Sequential", "Mixture", 10 | "MixtureModule", "Sum", "SumModule", "resnet_block") 11 | 12 | 13 | class NNGPKernel(nn.Module): 14 | """ 15 | Transforms one kernel matrix into another. 16 | [N1, N2, W, H] -> [N1, N2, W, H] 17 | """ 18 | def forward(self, x, y=None, same=None, diag=False): 19 | """ 20 | Either takes one minibatch (x), or takes two minibatches (x and y), and 21 | a boolean indicating whether they're the same. 
22 | """ 23 | if y is None: 24 | assert same is None 25 | y = x 26 | same = True 27 | 28 | assert not diag or len(x) == len(y), ( 29 | "diagonal kernels must operate with data of equal length") 30 | 31 | assert 4==len(x.size()) 32 | assert 4==len(y.size()) 33 | assert x.size(1) == y.size(1) 34 | assert x.size(2) == y.size(2) 35 | assert x.size(3) == y.size(3) 36 | 37 | N1 = x.size(0) 38 | N2 = y.size(0) 39 | C = x.size(1) 40 | W = x.size(2) 41 | H = x.size(3) 42 | 43 | # [N1, C, W, H], [N2, C, W, H] -> [N1 N2, 1, W, H] 44 | if diag: 45 | xy = (x*y).mean(1, keepdim=True) 46 | else: 47 | xy = (x.unsqueeze(1)*y).mean(2).view(N1*N2, 1, W, H) 48 | xx = (x**2).mean(1, keepdim=True) 49 | yy = (y**2).mean(1, keepdim=True) 50 | 51 | initial_kp = ConvKP(same, diag, xy, xx, yy) 52 | final_kp = self.propagate(initial_kp) 53 | r = NonlinKP(final_kp).xy 54 | if diag: 55 | return r.view(N1) 56 | else: 57 | return r.view(N1, N2) 58 | 59 | 60 | class Conv2d(NNGPKernel): 61 | def __init__(self, kernel_size, stride=1, padding="same", dilation=1, 62 | var_weight=1., var_bias=0., in_channel_multiplier=1, 63 | out_channel_multiplier=1): 64 | super().__init__() 65 | self.kernel_size = kernel_size 66 | self.stride = stride 67 | self.dilation = dilation 68 | self.var_weight = var_weight 69 | self.var_bias = var_bias 70 | self.kernel_has_row_of_zeros = False 71 | if padding == "same": 72 | self.padding = dilation*(kernel_size//2) 73 | if kernel_size % 2 == 0: 74 | self.kernel_has_row_of_zeros = True 75 | else: 76 | self.padding = padding 77 | 78 | if self.kernel_has_row_of_zeros: 79 | # We need to pad one side larger than the other. We just make a 80 | # kernel that is slightly too large and make its last column and 81 | # row zeros. 82 | kernel = t.ones(1, 1, self.kernel_size+1, self.kernel_size+1) 83 | kernel[:, :, 0, :] = 0. 84 | kernel[:, :, :, 0] = 0. 
85 | else: 86 | kernel = t.ones(1, 1, self.kernel_size, self.kernel_size) 87 | self.register_buffer('kernel', kernel 88 | * (self.var_weight / self.kernel_size**2)) 89 | self.in_channel_multiplier, self.out_channel_multiplier = ( 90 | in_channel_multiplier, out_channel_multiplier) 91 | 92 | def propagate(self, kp): 93 | kp = ConvKP(kp) 94 | def f(patch): 95 | return (F.conv2d(patch, self.kernel, stride=self.stride, 96 | padding=self.padding, dilation=self.dilation) 97 | + self.var_bias) 98 | return ConvKP(kp.same, kp.diag, f(kp.xy), f(kp.xx), f(kp.yy)) 99 | 100 | def nn(self, channels, in_channels=None, out_channels=None): 101 | if in_channels is None: 102 | in_channels = channels 103 | if out_channels is None: 104 | out_channels = channels 105 | conv2d = nn.Conv2d( 106 | in_channels=in_channels * self.in_channel_multiplier, 107 | out_channels=out_channels * self.out_channel_multiplier, 108 | kernel_size=self.kernel_size + ( 109 | 1 if self.kernel_has_row_of_zeros else 0), 110 | stride=self.stride, 111 | padding=self.padding, 112 | dilation=self.dilation, 113 | bias=(self.var_bias > 0.), 114 | ) 115 | conv2d.weight.data.normal_(0, math.sqrt( 116 | self.var_weight / conv2d.in_channels) / self.kernel_size) 117 | if self.kernel_has_row_of_zeros: 118 | conv2d.weight.data[:, :, 0, :] = 0 119 | conv2d.weight.data[:, :, :, 0] = 0 120 | if self.var_bias > 0.: 121 | conv2d.bias.data.normal_(0, math.sqrt(self.var_bias)) 122 | return conv2d 123 | 124 | def layers(self): 125 | return 1 126 | 127 | 128 | class ReLU(NNGPKernel): 129 | """ 130 | A ReLU nonlinearity, the covariance is numerically stabilised by clamping 131 | values. 
132 | """ 133 | f32_tiny = np.finfo(np.float32).tiny 134 | def propagate(self, kp): 135 | kp = NonlinKP(kp) 136 | """ 137 | We need to calculate (xy, xx, yy == c, v₁, v₂): 138 | ⏤⏤⏤⏤⏤⏤⏤⏤⏤⏤⏤⏤⏤⏤⏤⏤⏤⏤ 139 | √(v₁v₂) / 2π ⎷1 - c²/v₁v₂ + (π - θ)c / √(v₁v₂) 140 | 141 | which is equivalent to: 142 | 1/2π ( √(v₁v₂ - c²) + (π - θ)c ) 143 | 144 | # NOTE we divide by 2 to avoid multiplying the ReLU by sqrt(2) 145 | """ 146 | xx_yy = kp.xx * kp.yy + self.f32_tiny 147 | 148 | # Clamp these so the outputs are not NaN 149 | cos_theta = (kp.xy * xx_yy.rsqrt()).clamp(-1, 1) 150 | sin_theta = t.sqrt((xx_yy - kp.xy**2).clamp(min=0)) 151 | theta = t.acos(cos_theta) 152 | xy = (sin_theta + (math.pi - theta)*kp.xy) / (2*math.pi) 153 | 154 | xx = kp.xx/2. 155 | if kp.same: 156 | yy = xx 157 | if kp.diag: 158 | xy = xx 159 | else: 160 | # Make sure the diagonal agrees with `xx` 161 | eye = t.eye(xy.size()[0]).unsqueeze(-1).unsqueeze(-1).to(kp.xy.device) 162 | xy = (1-eye)*xy + eye*xx 163 | else: 164 | yy = kp.yy/2. 
class Sequential(NNGPKernel):
    """
    Chains kernel modules: the output of each one is fed to the next, in the
    order in which they were passed to the constructor.
    """
    def __init__(self, *mods):
        super().__init__()
        self.mods = mods
        # Register every stage under its position so it is a proper child
        for position, layer in enumerate(mods):
            self.add_module(str(position), layer)

    def propagate(self, kp):
        """Pass `kp` through every stage in turn and return the result."""
        result = kp
        for layer in self.mods:
            result = layer.propagate(result)
        return result

    def nn(self, channels, in_channels=None, out_channels=None):
        """
        Build the `torch.nn` counterpart of this chain.

        `in_channels` is forwarded only to the first stage and
        `out_channels` only to the last one; stages in between are built
        with `channels` alone. A chain with a single stage returns that
        stage's network directly, and an empty chain returns an empty
        `nn.Sequential`.
        """
        n_mods = len(self.mods)
        if n_mods == 0:
            return nn.Sequential()
        if n_mods == 1:
            only = self.mods[0]
            return only.nn(channels, in_channels=in_channels, out_channels=out_channels)

        first, *middle, last = self.mods
        # Built front to back, same order as the stages themselves
        built = [first.nn(channels, in_channels=in_channels)]
        built.extend(layer.nn(channels) for layer in middle)
        built.append(last.nn(channels, out_channels=out_channels))
        return nn.Sequential(*built)

    def layers(self):
        """Return the sum of `layers()` over all stages."""
        total = 0
        for layer in self.mods:
            total += layer.layers()
        return total
class MixtureModule(nn.Module):
    """
    Network counterpart of `Mixture`: applies every sub-module to the same
    input and returns their weighted sum.

    Branch `i` is weighted by `sqrt(softmax(logit)[i])`. `Mixture.propagate`
    weights the kernel of branch `i` by `softmax(logit)[i]`, and scaling an
    output by `sqrt(p)` scales its second moment by `p`, so the two stay
    consistent only if *every* branch is scaled.

    `mods` is a sequence of modules; `logit_parameter` holds one logit per
    module. A detached copy of the logits is stored, so they are not
    trainable parameters of this module.
    """
    def __init__(self, mods, logit_parameter):
        super().__init__()
        self.mods = mods
        # Detached copy; avoids copy-constructing with `t.tensor(tensor)`,
        # which PyTorch warns about, and still accepts plain sequences.
        self.logit = t.as_tensor(logit_parameter).detach().clone()
        for idx, mod in enumerate(mods):
            self.add_module(str(idx), mod)

    def forward(self, input):
        sqrt_proportions = F.softmax(self.logit, dim=0).sqrt()
        total = self.mods[0](input)*sqrt_proportions[0]
        for i in range(1, len(self.mods)):
            # Each branch gets its own weight (previously only branch 0 did)
            total = total + self.mods[i](input)*sqrt_proportions[i]
        return total
class SumModule(nn.Module):
    """
    Applies every sub-module to the same input and adds up the outputs.

    With no sub-modules the result is the integer 0.
    """
    def __init__(self, mods):
        super().__init__()
        self.mods = mods
        # Register each branch under its position so it is a proper child
        for position, branch in enumerate(mods):
            self.add_module(str(position), branch)

    def forward(self, input):
        # Start from 0 and accumulate, exactly as the builtin `sum` does
        total = 0
        for branch in self.mods:
            total = total + branch(input)
        return total