├── exp_random_nn
├── mnist_as_tf
│ └── figure.png
├── mnist_paper_convnet_gp
│ └── figure.png
├── mnist_paper_residual_cnn_gp
│ └── figure.png
├── run.bash
├── random_comparison.py
└── random_plot.py
├── requirements.txt
├── cnn_gp
├── __init__.py
├── kernel_save_tools.py
├── kernel_patch.py
├── data.py
└── kernels.py
├── setup.py
├── configs
├── mnist_paper_convnet_gp.py
├── mnist_paper_residual_cnn_gp.py
├── mnist.py
├── cifar10.py
└── mnist_as_tf.py
├── exp_mnist_resnet
├── merge_h5_files.py
├── run.bash
├── save_kernel.py
└── classify_gp.py
├── LICENSE
├── .gitignore
└── README.md
/exp_random_nn/mnist_as_tf/figure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cambridge-mlg/cnn-gp/HEAD/exp_random_nn/mnist_as_tf/figure.png
--------------------------------------------------------------------------------
/exp_random_nn/mnist_paper_convnet_gp/figure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cambridge-mlg/cnn-gp/HEAD/exp_random_nn/mnist_paper_convnet_gp/figure.png
--------------------------------------------------------------------------------
/exp_random_nn/mnist_paper_residual_cnn_gp/figure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cambridge-mlg/cnn-gp/HEAD/exp_random_nn/mnist_paper_residual_cnn_gp/figure.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # for main module
2 | numpy>=1.10.0
3 | torch>=1.1.0
4 | torchvision>=0.2.0
5 | tqdm>=4.32
6 |
7 | # For experiments
8 | h5py>=2.9.0
9 | scikit-learn>=0.21.0
10 | pandas>=0.24.0
11 | matplotlib>=3.1.0
12 |
--------------------------------------------------------------------------------
/cnn_gp/__init__.py:
--------------------------------------------------------------------------------
1 | from . import kernels, data, kernel_save_tools
2 | from .kernels import *
3 | from .data import *
4 | from .kernel_save_tools import *
5 |
6 | __all__ = kernels.__all__ + data.__all__ + kernel_save_tools.__all__
7 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | from __future__ import print_function
5 | from setuptools import setup
6 |
# Package metadata for the `cnn_gp` library. Only the dependencies of the main
# module are installed; the experiment scripts need extra packages.
setup(name='cnn_gp',
      version="0.1",
      author="Adrià Garriga-Alonso, Laurence Aitchison",
      author_email="adria.garriga@gmail.com",
      description="CNN-GPs in Pytorch",
      # Must agree with the title of the LICENSE file shipped with the project.
      license="BSD 2-Clause License",
      url="http://github.com/cambridge-mlg/cnn-gp-pytorch",
      ext_modules=[],
      packages=["cnn_gp"],
      # One requirement per whitespace-separated token.
      install_requires="""
      numpy>=1.10.0
      torch>=1.1.0
      torchvision>=0.2.0
      tqdm>=4.32
      """.split())
22 |
--------------------------------------------------------------------------------
/exp_random_nn/run.bash:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# For every config, runs `exp_random_nn.random_comparison` with several channel
# counts and then `exp_random_nn.random_plot` on the CSV files it produced.

datasets_path="/scratch/ag919/datasets/"
n_samples=10000
seed=12

# Edit which configs to do this for
for config in cifar10 mnist_paper_convnet_gp mnist_paper_residual_cnn_gp; do
    out_path="./exp_random_nn/${config}/"
    # -p: the output directory may already exist (e.g. from an earlier run);
    # a plain `mkdir` would fail and the run would proceed with stale state.
    mkdir -p "$out_path"
    for chans in 3 10 30 100; do
        echo "Running with ${chans} channels"
        python -m exp_random_nn.random_comparison --datasets_path="$datasets_path" \
            --config="$config" --seed=$seed --channels=$chans --n_samples=$n_samples \
            --out_path="$out_path"
    done

    # The globs expand to the per-channel sample files followed by the
    # per-channel covariance files, which is the order random_plot expects.
    python -m exp_random_nn.random_plot "$out_path/figure.pdf" \
        "$out_path"/*_samples.csv "$out_path"/*_cov.csv
done
21 |
--------------------------------------------------------------------------------
/configs/mnist_paper_convnet_gp.py:
--------------------------------------------------------------------------------
1 | import torchvision
2 | from cnn_gp import Conv2d, ReLU, Sequential
3 |
# Index ranges of the train / validation / test splits.
train_range = range(5000, 55000)
validation_range = list(range(55000, 60000)) + list(range(0, 5000))
test_range = range(60000, 70000)

dataset_name = "MNIST"
model_name = "ResNet"
dataset = torchvision.datasets.MNIST
transforms = []
epochs = 0
in_channels = 1
out_channels = 10

# Prior variances shared by every convolution below.
var_bias = 7.86
var_weight = 2.79

# Seven (7x7 "same"-padded convolution, ReLU) pairs.
layers = []
for _ in range(7):  # n_layers
    layers.append(Conv2d(kernel_size=7, padding="same",
                         var_weight=var_weight * 7**2, var_bias=var_bias))
    layers.append(ReLU())

# A final unpadded 28x28 convolution closes the network.
initial_model = Sequential(*(layers + [
    Conv2d(kernel_size=28, padding=0, var_weight=var_weight,
           var_bias=var_bias),
]))
31 |
--------------------------------------------------------------------------------
/exp_mnist_resnet/merge_h5_files.py:
--------------------------------------------------------------------------------
1 | """
2 | Relatively inefficient way to merge the results of the workers
3 | """
4 | import h5py
5 | import sys
6 | from tqdm import tqdm
7 | import numpy as np
8 |
if len(sys.argv) < 3:
    print(f"Usage: {sys.argv[0]} dest_file [source_file1 source_file2 ...]")
    sys.exit(1)

dest_file = sys.argv[1]
src_files = sys.argv[2:]

# The destination is opened in append mode and updated in place.
with h5py.File(dest_file, "a") as dest_f:
    for path in tqdm(src_files):
        with h5py.File(path, "r") as src_f:
            # Only datasets that exist in both files are merged.
            valid_keys = [k for k in dest_f.keys() if k in src_f.keys()]
            for k in tqdm(valid_keys):
                dest_data, src_data = dest_f[k], src_f[k]
                # One slice of the leading axis at a time.
                for i in tqdm(range(len(dest_data))):
                    merged = dest_data[i, ...]
                    incoming = src_data[i, ...]
                    # Entries still NaN in the destination are taken from
                    # the source; everything else is left untouched.
                    missing = np.isnan(merged)
                    merged[missing] = incoming[missing]

                    dest_data[i, ...] = merged
31 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 2-Clause License
2 |
3 | Copyright (c) 2019, Adrià Garriga-Alonso and Laurence Aitchison
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 |
--------------------------------------------------------------------------------
/configs/mnist_paper_residual_cnn_gp.py:
--------------------------------------------------------------------------------
1 | """
2 | The best randomly-searched ResNet reported in the paper.
3 |
4 | In the original paper there is a bug. This network sums together layers after
5 | the ReLU nonlinearity, which are not Gaussian, and also do not have mean 0. As
6 | a result, the overall network does not converge to a Gaussian process. The
7 | defined kernel is still valid, even if it doesn't correspond to a NN.
8 |
9 | In the interest of making the results replicable, we have replicated this bug
10 | as well.
11 |
12 | The correct way to use ResNets is to sum things after a Conv2d layer, see for
13 | example the `resnet_block` in `cnn_gp/kernels.py`.
14 | """
15 | import torchvision
16 | from cnn_gp import Conv2d, ReLU, Sequential, Sum
17 |
# Index ranges of the train / validation / test splits.
train_range = range(5000, 55000)
validation_range = list(range(55000, 60000)) + list(range(0, 5000))
test_range = range(60000, 70000)

dataset_name = "MNIST"
model_name = "ResNet"
dataset = torchvision.datasets.MNIST
transforms = []
epochs = 0
in_channels = 1
out_channels = 10

# Prior variances shared by every convolution below.
var_bias = 4.69
var_weight = 7.27
initial_model = Sequential(
    # Eight blocks, each summing an empty `Sequential()` branch with a
    # (4x4 "same"-padded convolution, ReLU) branch.
    *[
        Sum([
            Sequential(),
            Sequential(
                Conv2d(kernel_size=4, padding="same",
                       var_weight=var_weight * 4**2, var_bias=var_bias),
                ReLU(),
            ),
        ])
        for _ in range(8)
    ],
    Conv2d(kernel_size=4, padding="same", var_weight=var_weight * 4**2,
           var_bias=var_bias),
    ReLU(),
    # A final unpadded 28x28 convolution closes the network.
    Conv2d(kernel_size=28, padding=0, var_weight=var_weight,
           var_bias=var_bias),
)
46 |
--------------------------------------------------------------------------------
/exp_mnist_resnet/run.bash:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Computes the kernel matrices with one worker per GPU, merges the per-worker
# HDF5 files into the first one, and classifies with the merged kernel.

CUDA_VISIBLE_DEVICES="0,1"
datasets_path="/scratch/ag919/datasets/"
out_path="/scratch/ag919/grams_pytorch/mnist_test3"
config="mnist_as_tf"
batch_size=200

# Refuse to write into an existing results directory.
if [ -d "$out_path" ]; then
    echo "Careful: directory \"$out_path\" already exists"
    exit 1
fi

# One worker per listed device.
space_separated_cuda="${CUDA_VISIBLE_DEVICES//,/ }"
n_workers=$(echo $space_separated_cuda | wc -w)
if [ "$n_workers" == 0 ]; then
    echo "You must specify CUDA_VISIBLE_DEVICES"
    exit 1
fi

echo "Downloading dataset"
python -c "import configs.$config as c; import cnn_gp; cnn_gp.DatasetFromConfig(\"$datasets_path\", c)"

echo "Starting kernel computation workers in parallel"

mkdir "$out_path"
worker_rank=0
pids=()
for cuda_i in $space_separated_cuda; do
    this_worker="${out_path}/$(printf "%02d_nw%02d.h5" $worker_rank $n_workers)"

    CUDA_VISIBLE_DEVICES=$cuda_i python -m exp_mnist_resnet.save_kernel --n_workers=$n_workers \
        --worker_rank=$worker_rank --datasets_path="$datasets_path" --batch_size=$batch_size \
        --config="$config" --out_path="$this_worker" &
    # Record the PID of every worker. Indexing the array with an unset
    # variable would not keep one entry per worker.
    pids+=("$!")
    worker_rank=$((worker_rank+1))
done
# Wait for all workers, so that the merge below only sees finished files
for pid in "${pids[@]}"; do
    wait "$pid"
done

echo "combining all data sets in one"
# The first file matched by the glob is the merge destination.
python -m exp_mnist_resnet.merge_h5_files "${out_path}"/*


echo "Classify using the complete set"
combined_file="${out_path}/$(printf "%02d_nw%02d.h5" 0 $n_workers)"
python -m exp_mnist_resnet.classify_gp --datasets_path="$datasets_path" \
    --config="$config" --in_path="$combined_file"
50 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
106 | # Latex
107 | *.fls
108 | *.fdb_latexmk
109 |
110 | MNIST-data
111 | MNIST_data
112 |
113 | *.kp
114 | *.pt
115 | *.pyc
116 | *.csv
--------------------------------------------------------------------------------
/configs/mnist.py:
--------------------------------------------------------------------------------
1 | import torchvision
2 |
3 | from cnn_gp import Conv2d, ReLU, Sequential, resnet_block
4 |
# Index ranges of the train / validation / test splits.
train_range = range(50000)
validation_range = range(50000, 60000)
test_range = range(60000, 70000)

dataset_name = "MNIST"
model_name = "ResNet"
dataset = torchvision.datasets.MNIST
transforms = []
epochs = 0
in_channels = 1
out_channels = 10
initial_model = Sequential(
    Conv2d(kernel_size=3),

    # Three big resnet blocks (multipliers 1, 2, 4) of five `resnet_block`s
    # each. Only the first block of a group has a projection shortcut, and
    # for the second and third group that first block also has stride 2.
    *(
        resnet_block(stride=2 if (is_first and multiplier > 1) else 1,
                     projection_shortcut=is_first,
                     multiplier=multiplier)
        for multiplier in (1, 2, 4)
        for is_first in (True, False, False, False, False)
    ),

    # No nonlinearity here, the next Conv2d substitutes the average pooling
    Conv2d(kernel_size=7, padding=0, in_channel_multiplier=4,
           out_channel_multiplier=4),
    ReLU(),
    Conv2d(kernel_size=1, padding=0, in_channel_multiplier=4),
)
46 |
--------------------------------------------------------------------------------
/configs/cifar10.py:
--------------------------------------------------------------------------------
1 | import torchvision
2 | from cnn_gp import Conv2d, ReLU, Sequential, resnet_block
3 |
# Index ranges of the train / validation / test splits.
train_range = range(40000)
validation_range = range(40000, 50000)
test_range = range(50000, 60000)

kernel_batch_size = 350

dataset_name = "CIFAR10"
model_name = "ResNet"
in_channels = 3
dataset = torchvision.datasets.CIFAR10
transforms = []
epochs = 0
initial_model = Sequential(
    Conv2d(kernel_size=3),

    # Three big resnet blocks (multipliers 1, 2, 4) of five `resnet_block`s
    # each. Only the first block of a group has a projection shortcut, and
    # for the second and third group that first block also has stride 2.
    *(
        resnet_block(stride=2 if (is_first and multiplier > 1) else 1,
                     projection_shortcut=is_first,
                     multiplier=multiplier)
        for multiplier in (1, 2, 4)
        for is_first in (True, False, False, False, False)
    ),

    # No nonlinearity here, the next Conv2d substitutes the average pooling
    Conv2d(kernel_size=8, padding=0, in_channel_multiplier=4,
           out_channel_multiplier=4),
    Conv2d(kernel_size=1, padding=0, in_channel_multiplier=4,
           out_channel_multiplier=4),
    ReLU(),
    Conv2d(kernel_size=1, padding=0, in_channel_multiplier=4),
)
48 |
--------------------------------------------------------------------------------
/configs/mnist_as_tf.py:
--------------------------------------------------------------------------------
1 | """
2 | Very similar to `./mnist.py`. In this case, however, the contents of the
3 | train/validation/test sets are the same as in the original paper's experiments,
4 | which were written in TensorFlow.
5 | """
6 | import torchvision
7 | from cnn_gp import Conv2d, ReLU, Sequential, resnet_block
8 |
# Index ranges of the train / validation / test splits.
train_range = range(5000, 55000)
validation_range = list(range(55000, 60000)) + list(range(0, 5000))
test_range = range(60000, 70000)

dataset_name = "MNIST"
model_name = "ResNet"
dataset = torchvision.datasets.MNIST
transforms = []
epochs = 0
in_channels = 1
out_channels = 10
initial_model = Sequential(
    Conv2d(kernel_size=3),

    # Three big resnet blocks (multipliers 1, 2, 4) of five `resnet_block`s
    # each. Only the first block of a group has a projection shortcut, and
    # for the second and third group that first block also has stride 2.
    *(
        resnet_block(stride=2 if (is_first and multiplier > 1) else 1,
                     projection_shortcut=is_first,
                     multiplier=multiplier)
        for multiplier in (1, 2, 4)
        for is_first in (True, False, False, False, False)
    ),

    # No nonlinearity here, the next Conv2d substitutes the average pooling
    Conv2d(kernel_size=7, padding=0, in_channel_multiplier=4,
           out_channel_multiplier=4),
    ReLU(),
    Conv2d(kernel_size=1, padding=0, in_channel_multiplier=4),
)
50 |
--------------------------------------------------------------------------------
/exp_mnist_resnet/save_kernel.py:
--------------------------------------------------------------------------------
1 | """
2 | Save a kernel matrix to disk
3 | """
4 | import absl.app
5 | import h5py
6 | import torch
7 | import importlib
8 | import os
9 |
10 | from cnn_gp import DatasetFromConfig, save_K
11 | FLAGS = absl.app.flags.FLAGS
12 |
13 |
def main(_):
    """
    Compute this worker's share of the kernel matrices of the configured model
    and save them to the HDF5 file `FLAGS.out_path`.

    Every worker writes `Kxx` (train vs. train), `Kxvx` (validation vs. train)
    and `Kxtx` (test vs. train). The worker with rank 0 additionally writes
    the diagonals `Kv_diag` and `Kt_diag`.
    """
    # Informational only, so a missing variable must not abort the run.
    print("CUDA_VISIBLE_DEVICES="
          f"{os.environ.get('CUDA_VISIBLE_DEVICES', '<unset>')}")
    n_workers, worker_rank = FLAGS.n_workers, FLAGS.worker_rank
    config = importlib.import_module(f"configs.{FLAGS.config}")
    dataset = DatasetFromConfig(FLAGS.datasets_path, config)
    model = config.initial_model.cuda()

    def kern(x, x2, same, diag):
        # Evaluate on the GPU, hand back a NumPy array on the CPU.
        with torch.no_grad():
            return model(x.cuda(), x2.cuda(), same,
                         diag).detach().cpu().numpy()

    # "w" truncates any existing file at `out_path`.
    with h5py.File(FLAGS.out_path, "w") as f:
        kwargs = dict(worker_rank=worker_rank, n_workers=n_workers,
                      batch_size=FLAGS.batch_size, print_interval=2.)
        save_K(f, kern, name="Kxx", X=dataset.train, X2=None, diag=False, **kwargs)
        save_K(f, kern, name="Kxvx", X=dataset.validation, X2=dataset.train, diag=False, **kwargs)
        save_K(f, kern, name="Kxtx", X=dataset.test, X2=dataset.train, diag=False, **kwargs)

    if worker_rank == 0:
        # Re-open in append mode to add the diagonals to the same file.
        with h5py.File(FLAGS.out_path, "a") as f:
            save_K(f, kern, name="Kv_diag", X=dataset.validation, X2=None, diag=True, **kwargs)
            save_K(f, kern, name="Kt_diag", X=dataset.test, X2=None, diag=True, **kwargs)
37 |
38 |
if __name__ == '__main__':
    # Command-line flags of the kernel-saving worker.
    f = absl.app.flags
    f.DEFINE_string("datasets_path", "/scratch/ag919/datasets/",
                    "where to save datasets")
    f.DEFINE_integer('batch_size', 200,
                     "max number of examples to simultaneously compute "
                     "the kernel of")
    f.DEFINE_string("config", "mnist", "which config to load from `configs`")
    f.DEFINE_integer("n_workers", 1, "num of workers")
    f.DEFINE_integer("worker_rank", 0, "rank of worker")
    f.DEFINE_string('out_path', None, "path of h5 file to save kernels in")
    # `out_path` has no usable default: fail at flag-parsing time with a clear
    # message instead of when `main` tries to open `None` as a file.
    f.mark_flag_as_required('out_path')
    absl.app.run(main)
51 |
--------------------------------------------------------------------------------
/cnn_gp/kernel_save_tools.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from .data import ProductIterator, DiagIterator, print_timings
3 |
4 | __all__ = ('create_h5py_dataset', 'save_K')
5 |
6 |
def create_h5py_dataset(f, batch_size, name, diag, N, N2):
    """
    Create and return a float32 dataset called `name` on `f`, filled with NaN.

    The dataset has a leading axis of length 1 that may be resized later
    (its `maxshape` entry is `None`), followed by `(N,)` when `diag` is true
    and `(N, N2)` otherwise. Chunks span one entry of the leading axis and at
    most `batch_size` entries of every other axis. `N2` is ignored when
    `diag` is true.
    """
    trailing = (N,) if diag else (N, N2)
    chunk_trailing = tuple(min(batch_size, extent) for extent in trailing)
    return f.create_dataset(name,
                            shape=(1,) + trailing,
                            dtype=np.float32,
                            fillvalue=np.nan,
                            chunks=(1,) + chunk_trailing,
                            maxshape=(None,) + trailing)
24 |
25 |
def save_K(f, kern, name, X, X2, diag, batch_size, worker_rank=0, n_workers=1,
           print_interval=2.):
    """
    Compute a kernel matrix batch by batch and save it to the h5py file `f`.

    Does nothing (apart from printing a notice) if `f` already has an entry
    called `name`; otherwise the dataset is created first.

    Arguments:
      f: open, writable h5py file.
      kern: callable `kern(x, x2, same, diag)` returning the kernel values of
        one batch as a NumPy array.
      name: name of the dataset in `f`.
      X, X2: data sets for the rows and columns; `X2=None` means the columns
        have the same length as `X`.
      diag: if true, only the diagonal is computed and the work is not split
        between workers.
      batch_size: batch size handed to the iterators and used for chunking.
      worker_rank, n_workers: which share of the non-diagonal batches this
        call computes.
      print_interval: passed on to `print_timings`.

    Raises:
      ValueError: if `kern` returns a NaN or infinite value. Nothing from the
        offending batch is written.
    """
    if name in f.keys():
        print("Skipping {} (group exists)".format(name))
        return

    N = len(X)
    N2 = N if X2 is None else len(X2)
    out = create_h5py_dataset(f, batch_size, name, diag, N, N2)

    if diag:
        # Don't split the load for diagonals, they are cheap
        it = DiagIterator(batch_size, X, X2)
    else:
        it = ProductIterator(batch_size, X, X2, worker_rank=worker_rank,
                             n_workers=n_workers)
    it = print_timings(it, desc=f"{name} (worker {worker_rank}/{n_workers})",
                       print_interval=print_interval)

    for same, (i, (x, _y)), (j, (x2, _y2)) in it:
        k = kern(x, x2, same, diag)
        # Fail loudly rather than dropping into an interactive debugger:
        # workers usually run unattended, where a breakpoint cannot be used.
        if not np.all(np.isfinite(k)):
            raise ValueError(
                f"Kernel {name} has a nan or inf in the batch at {i},{j}")

        if diag:
            out[0, i:i+len(x)] = k
        else:
            out[0, i:i+len(x), j:j+len(x2)] = k
59 |
--------------------------------------------------------------------------------
/exp_random_nn/random_comparison.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from tqdm import tqdm
4 | import importlib
5 | import numpy as np
6 | import pandas as pd
7 | import absl.app
8 |
9 | from cnn_gp import DatasetFromConfig
10 | FLAGS = absl.app.flags.FLAGS
11 |
12 |
def main(_):
    """
    Compare the kernel of `config.initial_model` with outputs of random finite
    networks drawn from `model.nn`, on the first 100 training inputs.

    Writes two CSV files into `FLAGS.out_path`: the first output of every
    sampled network divided by the square root of the kernel's first diagonal
    entry, and the estimated versus the model's covariance for every pair of
    inputs.
    """
    torch.manual_seed(FLAGS.seed)

    config = importlib.import_module(f"configs.{FLAGS.config}")
    dataset = DatasetFromConfig(FLAGS.datasets_path, config)
    loader = torch.utils.data.DataLoader(dataset.train, batch_size=100)
    inputs, _ = next(iter(loader))
    inputs = inputs.cuda()
    model = config.initial_model.cuda()

    with torch.no_grad():
        true_cov = model(inputs).cpu().numpy()

        # One freshly drawn network per sample; keep output [:, 0, 0, 0].
        results = []
        for _ in tqdm(range(FLAGS.n_samples)):
            net = model.nn(FLAGS.channels, in_channels=config.in_channels,
                           out_channels=1).cuda()
            results.append(net(inputs)[:, 0, 0, 0].cpu().numpy())
            del net

    file_prefix = f"{FLAGS.channels:04d}_{FLAGS.seed:04d}"

    # Output of every sampled network for the first input.
    r0 = [sample[0] for sample in results]
    samples_output_filename = os.path.join(
        FLAGS.out_path, f"{file_prefix}_samples.csv")
    samples_frame = pd.DataFrame({
        'r0': np.array(r0) / np.sqrt(true_cov[0, 0]),
    })
    samples_frame.to_csv(samples_output_filename, index=False)

    # Index grids matching the row-major flattening of the matrices below.
    Ni = inputs.shape[0]
    i = np.tile(np.arange(Ni, dtype=np.float64), (Ni, 1))
    j = i.T
    # Uncentred second moment of the sampled outputs.
    R = np.vstack(results)
    est_cov = R.T @ R / FLAGS.n_samples

    cov_output_filename = os.path.join(
        FLAGS.out_path, f"{file_prefix}_cov.csv")
    cov_frame = pd.DataFrame({
        'i': i.ravel(),
        'j': j.ravel(),
        'est': est_cov.ravel(),
        'true': true_cov.ravel()
    })
    cov_frame.to_csv(cov_output_filename, index=False)
59 |
if __name__ == '__main__':
    # Command-line flags of the random-network comparison.
    f = absl.app.flags
    f.DEFINE_string("datasets_path", "/scratch/ag919/datasets/",
                    "where to save datasets")
    f.DEFINE_string("out_path", None,
                    "where to save the drawn outputs of the NN and kernel")
    f.DEFINE_string("config", "cifar10", "which config to load from `configs`")
    f.DEFINE_integer("seed", 1, "the random seed")
    f.DEFINE_integer("channels", 30, "the number of channels of the random finite NNs")
    f.DEFINE_integer("n_samples", 10000, "Number of samples to draw from the NN")
    # `out_path` has no usable default: fail at flag-parsing time instead of
    # after all samples were drawn, when `main` builds the output file names.
    f.mark_flag_as_required("out_path")
    absl.app.run(main)
71 |
--------------------------------------------------------------------------------
/cnn_gp/kernel_patch.py:
--------------------------------------------------------------------------------
1 | __all__ = ('ConvKP', 'NonlinKP')
2 |
3 |
class KernelPatch:
    """
    A block of the kernel matrix.

    Besides the block itself (`xy`) the variances of its rows (`xx`) and
    columns (`yy`) are stored, because they are needed even when the diagonal
    is not part of the block; this is where most of the complexity comes from.
    `same` records whether rows and columns of the block correspond, since
    adding IID noise must be handled differently in that case.

    Subclasses implement `init(same, diag, xy, xx, yy)`, which stores the
    values in the layout they need.
    """
    def __init__(self, same_or_kp, diag=False, xy=None, xx=None, yy=None):
        if isinstance(same_or_kp, KernelPatch):
            # Conversion from another patch: every other argument is ignored.
            source = same_or_kp
            same, diag = source.same, source.diag
            xy, xx, yy = source.xy, source.xx, source.yy
        else:
            same = same_or_kp

        # The sizes are set before `init` runs so that it can rely on them.
        self.Nx, self.Ny = xx.size(0), yy.size(0)
        self.W, self.H = xy.size(-2), xy.size(-1)

        self.init(same, diag, xy, xx, yy)

    def __add__(self, other):
        return self._do_elementwise(other, '__add__')

    def __mul__(self, other):
        return self._do_elementwise(other, '__mul__')

    def __radd__(self, other):
        return self.__add__(other)

    def __rmul__(self, other):
        return self.__mul__(other)

    def _do_elementwise(self, other, op):
        """
        Apply the tensor method named `op` to `xy`, `xx` and `yy`, and return
        the result as a new patch of the same type as `self`.

        `other` is either another patch, which is first converted to the type
        of `self` and must agree on `same` and `diag`, or any operand the
        tensors accept directly.
        """
        patch_type = type(self)
        if isinstance(other, KernelPatch):
            other = patch_type(other)
            assert self.same == other.same
            assert self.diag == other.diag
            rhs_xy, rhs_xx, rhs_yy = other.xy, other.xx, other.yy
        else:
            rhs_xy = rhs_xx = rhs_yy = other

        return patch_type(
            self.same,
            self.diag,
            getattr(self.xy, op)(rhs_xy),
            getattr(self.xx, op)(rhs_xx),
            getattr(self.yy, op)(rhs_yy)
        )
64 |
65 |
class ConvKP(KernelPatch):
    """
    Kernel patch whose tensors all have a singleton second dimension: `xy` is
    viewed as `(Nx, 1, W, H)` for diagonal patches and `(Nx*Ny, 1, W, H)`
    otherwise, `xx` as `(Nx, 1, W, H)` and `yy` as `(Ny, 1, W, H)`.
    """
    def init(self, same, diag, xy, xx, yy):
        self.same, self.diag = same, diag
        # A diagonal patch has one entry per input, a full one per pair.
        n_entries = self.Nx if diag else self.Nx*self.Ny
        self.xy = xy.view(n_entries, 1, self.W, self.H)
        self.xx = xx.view(self.Nx, 1, self.W, self.H)
        self.yy = yy.view(self.Ny, 1, self.W, self.H)
76 |
77 |
class NonlinKP(KernelPatch):
    """
    Kernel patch that keeps rows and columns of `xy` on separate dimensions:
    `xy` is viewed as `(Nx, Ny, W, H)`, or `(Nx, 1, W, H)` for diagonal
    patches. `xx` is always viewed as `(Nx, 1, W, H)`.
    """
    def init(self, same, diag, xy, xx, yy):
        self.same, self.diag = same, diag
        spatial = (self.W, self.H)
        if diag:
            self.xy = xy.view(self.Nx, 1, *spatial)
            self.xx = xx.view(self.Nx, 1, *spatial)
            self.yy = yy.view(self.Ny, 1, *spatial)
        else:
            self.xy = xy.view(self.Nx, self.Ny, *spatial)
            self.xx = xx.view(self.Nx, 1, *spatial)
            # Three dimensions only, presumably so that `yy` broadcasts
            # against the second dimension of `xy` (TODO confirm).
            self.yy = yy.view(self.Ny, *spatial)
90 |
--------------------------------------------------------------------------------
/exp_mnist_resnet/classify_gp.py:
--------------------------------------------------------------------------------
1 | """
2 | Given a pre-computed kernel and a data set, compute train/validation/test accuracy.
3 | """
4 | import absl.app
5 | import h5py
6 | import numpy as np
7 | import scipy.linalg
8 | import torch
9 | import sklearn.metrics
10 | import scipy
11 |
12 | import importlib
13 | from cnn_gp import DatasetFromConfig
14 | FLAGS = absl.app.flags.FLAGS
15 |
16 |
def solve_system(Kxx, Y):
    """
    Solve `Kxx @ A = Y` for `A` with SciPy and return `A` as a torch tensor.

    Both arguments must be `float64` tensors. `Kxx` is treated as positive
    definite, and because SciPy is allowed to overwrite the array it is
    given, the contents of `Kxx` may be destroyed.
    """
    print("Running scipy solve Kxx^-1 Y routine")
    both_double = Kxx.dtype == torch.float64 and Y.dtype == torch.float64
    assert both_double, """
    It is important that `Kxx` and `Y` are `float64`s for the inversion,
    even if they were `float32` when being calculated. This makes the
    inversion much less likely to complain about the matrix being singular.
    """
    lhs, rhs = Kxx.numpy(), Y.numpy()
    A = scipy.linalg.solve(lhs, rhs, assume_a='pos', lower=False,
                           overwrite_a=True, overwrite_b=False,
                           check_finite=False)
    return torch.from_numpy(A)
28 |
29 |
def diag_add(K, diag):
    """
    Add `diag` in place to every `(K.shape[-1] + 1)`-th element of the
    flattened `K`, i.e. to the main diagonal of a square matrix.

    `K` must be a torch tensor or a NumPy array; anything else raises
    `TypeError`.
    """
    if not isinstance(K, (torch.Tensor, np.ndarray)):
        raise TypeError("What do I do with a `{}`, K={}?".format(type(K), K))

    # Distance between consecutive diagonal entries in the flattened matrix.
    step = K.shape[-1] + 1
    if isinstance(K, torch.Tensor):
        flat = K.view(K.numel())
        flat[::step] += diag
    else:
        K.flat[::step] += diag
37 |
38 |
def print_accuracy(A, Kxvx, Y, key):
    """
    Print the accuracy, labelled with `key`, of the predictions
    `argmax(Kxvx @ A)` (taken over dimension 1) against the labels `Y`.
    """
    scores = Kxvx @ A
    Ypred = scores.argmax(dim=1)
    acc = sklearn.metrics.accuracy_score(Y, Ypred)
    print(f"{key} accuracy: {acc*100}%")
43 |
44 |
def load_kern(dset, i):
    """
    Read slice `i` of the leading axis of the three-dimensional dataset
    `dset` and return it as a `float64` torch tensor.

    The data is read into a `float32` buffer and converted afterwards.
    """
    buffer = np.empty(dset.shape[1:], dtype=np.float32)
    selection = (i, slice(None), slice(None))
    dset.read_direct(buffer, source_sel=selection)
    return torch.from_numpy(buffer).double()
49 |
50 |
def main(_):
    """
    Load the pre-computed kernels from `FLAGS.in_path`, solve for the
    training targets and print validation and test accuracy.
    """
    config = importlib.import_module(f"configs.{FLAGS.config}")
    dataset = DatasetFromConfig(FLAGS.datasets_path, config)

    print("Reading training labels")
    _, Y = dataset.load_full(dataset.train)
    n_classes = Y.max() + 1
    # Regression targets: +1 for the true class, -1 for every other class.
    Y_1hot = torch.ones((len(Y), n_classes), dtype=torch.float64).neg_()
    Y_1hot[torch.arange(len(Y)), Y] = 1.

    with h5py.File(FLAGS.in_path, "r") as f:
        print("Loading kernel")
        Kxx = load_kern(f["Kxx"], 0)
        diag_add(Kxx, FLAGS.jitter)

        print("Solving Kxx^{-1} Y")
        A = solve_system(Kxx, Y_1hot)

        # Handle one evaluation split at a time and free its kernel before
        # loading the next one.
        for kernel_name, split in (("Kxvx", "validation"), ("Kxtx", "test")):
            _, labels = dataset.load_full(getattr(dataset, split))
            K_cross = load_kern(f[kernel_name], 0)
            print_accuracy(A, K_cross, labels, split)
            del K_cross
            del labels
80 |
81 |
82 | # @(py36) ag919@ulam:~/Programacio/cnn-gp-pytorch$ python classify_gp.py --in_path=/scratch/ag919/grams_pytorch/mnist_as_tf/00_nwork07.h5 --config=mnist_as_tf
83 | # magma.py has some problem loading. Proceeding anyways using CPU.
84 | # Original error: ignoring magma shit
85 | # Reading training labels
86 | # Loading kernel
87 | # Solving Kxx^{-1} Y
88 | # Running scipy solve Kxx^-1 Y routine
89 | # train accuracy: 10.26%
90 | # validation accuracy: 99.31%
91 | # test accuracy: 99.11999999999999%
92 |
93 |
if __name__ == '__main__':
    # Declare the command-line flags read through `FLAGS` in `main`, then
    # hand control to absl, which parses the arguments and calls `main`.
    f = absl.app.flags
    f.DEFINE_string(
        name="datasets_path",
        default="/scratch/ag919/datasets/",
        help="where to save datasets")
    f.DEFINE_string(
        name="config",
        default="mnist",
        help="which config to load from `configs`")
    f.DEFINE_string(
        name='in_path',
        default="/scratch/ag919/grams_pytorch/mnist/dest.h5",
        help="path of h5 file to load kernels from")
    f.DEFINE_float(
        name="jitter",
        default=0.0,
        help="add to the diagonal")
    absl.app.run(main)
103 |
--------------------------------------------------------------------------------
/exp_random_nn/random_plot.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage: python random_plot.py ./figure.pdf \
3 | ./configs/0030_0001_samples.csv ./configs/0030_0001_samples.csv ./configs/0030_0001_samples.csv ./configs/0030_0001_samples.csv \
4 | ./configs/0030_0001_cov.csv ./configs/0030_0001_cov.csv ./configs/0030_0001_cov.csv ./configs/0030_0001_cov.csv
5 | """
6 | import sys
7 | import numpy as np
8 | import scipy
9 | import scipy.stats
10 | import pandas as pd
11 |
# Positional command-line arguments (see the usage string at the top of the
# file): one output path, then four sample CSVs, then four covariance CSVs.
output_filename = sys.argv[1]
sample_filenames = sys.argv[2:6]
cov_filenames = sys.argv[6:10]
15 |
16 | import matplotlib
17 | import matplotlib.colors as colors
18 | import matplotlib.pyplot as plt
plt.switch_backend('PDF')
#plt.switch_backend('Agg')
# Render all text through LaTeX, loading the `helvet` and `sfmath` packages.
from matplotlib import rc
matplotlib.rcParams['text.usetex'] = True
# The preamble is given as a single string: newer Matplotlib releases no
# longer accept a list of strings for this rcParam.
matplotlib.rcParams['text.latex.preamble'] = r'\usepackage[helvet]{sfmath}\usepackage{helvet}'
25 |
26 | #Basic measurements
27 | nrows = 3
28 | ncols = 4
29 | points = 10 # Font size
30 | fig_w_in = 5.5 # Plot width (inches)
31 | panel_wh_ratio = 0.9
32 | panel_lm_in = 0.55
33 | panel_rm_in = 0.05
34 | panel_tm_in = 0.2
35 | panel_bm_in = 0.45
36 | fig_lm_in = 0.
37 | fig_rm_in = 0.
38 | fig_tm_in = 0.
39 | fig_bm_in = 0.
40 | y_labelpad = 5
41 |
42 | panel_w_in = (fig_w_in - fig_lm_in - fig_rm_in)/ncols
43 | panel_h_in = panel_wh_ratio * panel_w_in
44 | fig_h_in = nrows*panel_h_in + fig_tm_in + fig_bm_in
45 |
46 | panel_w_s = panel_w_in / fig_w_in
47 | panel_h_s = panel_h_in / fig_h_in
48 |
49 | panel_lm_s = panel_lm_in / fig_w_in
50 | panel_rm_s = panel_rm_in / fig_w_in
51 | panel_tm_s = panel_tm_in / fig_h_in
52 | panel_bm_s = panel_bm_in / fig_h_in
53 |
54 | fig_lm_s = fig_lm_in / fig_w_in
55 | fig_rm_s = fig_rm_in / fig_w_in
56 | fig_tm_s = fig_tm_in / fig_h_in
57 | fig_bm_s = fig_bm_in / fig_h_in
58 |
59 | pt_w_s = 1/72/fig_w_in
60 | pt_h_s = 1/72/fig_w_in
61 | char_w_s = pt_w_s*points
62 | char_h_s = pt_h_s*points
63 |
def bottom_margin(row):
    """Return the bottom edge, as a figure fraction, of the axes in panel row `row` (row 0 is the topmost)."""
    rows_below = nrows - (row + 1)
    return rows_below*panel_h_s + panel_bm_s + fig_bm_s
def left_margin(col):
    """Return the left edge, as a figure fraction, of the axes in panel column `col`."""
    panels_to_the_left = col*panel_w_s
    return panels_to_the_left + panel_lm_s + fig_lm_s
def rect(row, col):
    """Return `[left, bottom, width, height]` (figure fractions) of the axes in panel (`row`, `col`)."""
    left = left_margin(col)
    bottom = bottom_margin(row)
    width = panel_w_s - panel_lm_s - panel_rm_s
    height = panel_h_s - panel_tm_s - panel_bm_s
    return [left, bottom, width, height]
def label(ax, s):
    """Write the bold panel label `s` just above and to the left of the top-left corner of `ax`."""
    (left, _bottom), (_right, top) = ax.get_position().get_points()
    x = left-3.3*char_w_s
    y = top+char_h_s
    ax.figure.text(x, y, r'\textbf{' + s + r'}')
77 |
# The single figure that all panels are added to, sized in inches.
fig = plt.figure(figsize=(fig_w_in, fig_h_in))
79 |
def set_ylabel_coords(ax, yshift=0):
    """
    Place the y-axis label of `ax` at a fixed offset to the left of the axes.

    The label is vertically centred on the axes, then moved by `yshift`
    times the axes height. Coordinates are in figure fractions.
    """
    (left, bottom), (_right, top) = ax.get_position().get_points()
    height = top - bottom
    ax.yaxis.set_label_coords(
        left-2.5*char_w_s,
        bottom+height/2+height*yshift,
        transform=ax.figure.transFigure)
87 |
# Reference distribution (standard normal), shared axis limit, and the
# title of each of the four columns.
z = scipy.stats.norm(0, 1)
lim = 4
titles = ["C=3", "C=10", "C=30", "C=100"]
# Row A: histogram of the `r0` column of each samples CSV, with the standard
# normal pdf drawn on top.
for i in range(4):
    ax=fig.add_axes(rect(0, i))
    df = pd.read_csv(sample_filenames[i])
    ax.hist(np.array(df.r0), bins=50, range=(-lim, lim), density=True)
    xs = np.linspace(-lim, lim, 100)
    ax.plot(xs, z.pdf(xs), linewidth=1)
    ax.set_ylim(0, 0.7)
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.set_title(titles[i], pad=-5)
    ax.set_xlim(-lim, lim)
    ax.set_xticks([-lim, 0, lim])
    ax.set_xlabel('output')
    # Only the leftmost panel carries the row label and the y-axis label.
    if i == 0:
        label(ax, 'A')
        ax.set_ylabel("pdf")
        set_ylabel_coords(ax)
108 |
109 |
# Row B: probability plot of the `r0` samples against the quantiles of `z`,
# with the identity line as reference.
for i in range(4):
    ax=fig.add_axes(rect(1, i))
    df = pd.read_csv(sample_filenames[i])

    # With fit=False, probplot returns only the (theoretical, ordered sample)
    # quantile pair.
    xs, ys = scipy.stats.probplot(np.array(df.r0), dist=z, fit=False)

    ax.plot(xs, ys, linewidth=1)
    ax.plot([-lim, lim], [-lim, lim], linewidth=1)
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.set_xlim(-lim, lim)
    ax.set_ylim(-lim, lim)
    ax.set_xticks([-lim, 0, lim])
    ax.set_yticks([-lim, 0, lim])
    ax.set_xlabel('limiting q.')
    # Only the leftmost panel carries the row label and the y-axis label.
    if i == 0:
        label(ax, 'B')
        ax.set_ylabel('sampled q.')
        set_ylabel_coords(ax)
129 |
# Row C: scatter of the `est` column against the `true` column of each
# covariance CSV, with the identity line as reference.
for i in range(4):
    ax=fig.add_axes(rect(2, i))
    df = pd.read_csv(cov_filenames[i])

    est = np.array(df.est)
    true = np.array(df.true)
    # Upper axis limit: 110% of the largest value, rounded up to a multiple
    # of the power of ten that has as many digits as `hi_lim`.
    hi_lim = int(1.1 * np.max([est, true]))
    order = 10**(len(str(hi_lim))-1)
    lims = (0, ((hi_lim+order-1)//order) * order)

    ax.plot(lims, lims, color='tab:orange', linewidth=1)
    ax.scatter(true, est, 0.3, color='tab:blue')
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.set_xlabel('limiting cov.')
    ax.set_xlim(*lims)
    ax.set_ylim(*lims)
    ax.set_xticks(np.linspace(*lims, 3))
    ax.set_yticks(np.linspace(*lims, 3))
    # Only the leftmost panel carries the row label and the y-axis label.
    if i == 0:
        label(ax, 'C')
        ax.set_ylabel('sampled cov.')
        set_ylabel_coords(ax, yshift=-0.05)

# Write the finished figure to the path given as the first argument.
fig.savefig(output_filename, dpi=400)
155 |
--------------------------------------------------------------------------------
/cnn_gp/data.py:
--------------------------------------------------------------------------------
1 | import torchvision
2 | from torch.utils.data import ConcatDataset, DataLoader, Subset
3 | import os
4 | import numpy as np
5 | import itertools
6 |
7 | __all__ = ('DatasetFromConfig', 'ProductIterator', 'DiagIterator',
8 | 'print_timings')
9 |
10 |
11 | def _this_worker_batch(N_batches, worker_rank, n_workers):
12 | batches_per_worker = np.zeros([n_workers], dtype=np.int)
13 | batches_per_worker[:] = N_batches // n_workers
14 | batches_per_worker[:N_batches % n_workers] += 1
15 |
16 | start_batch = np.sum(batches_per_worker[:worker_rank])
17 | batches_this_worker = batches_per_worker[worker_rank]
18 |
19 | return int(start_batch), int(batches_this_worker)
20 |
21 |
22 | def _product_generator(N_batches_X, N_batches_X2, same):
23 | for i in range(N_batches_X):
24 | if same:
25 | # Yield only upper triangle
26 | yield (True, i, i)
27 | for j in range(i+1 if same else 0,
28 | N_batches_X2):
29 | yield (False, i, j)
30 |
31 |
32 | def _round_up_div(a, b):
33 | return (a+b-1)//b
34 |
35 |
class ProductIterator(object):
    """
    Returns an iterator for loading data from both X and X2. It divides the
    load equally among `n_workers`, returning only the one that belongs to
    `worker_rank`.

    Each item is `(same, (i*batch_size, x_batch), (j*batch_size, x2_batch))`,
    where `(same, i, j)` comes from `_product_generator`. If `X2` is None,
    `X` is paired with itself and only the upper triangle of batch pairs is
    visited.
    """
    def __init__(self, batch_size, X, X2=None, worker_rank=0, n_workers=1):
        N_batches_X = _round_up_div(len(X), batch_size)
        if X2 is None:
            same = True
            X2 = X
            N_batches_X2 = N_batches_X
            # Number of pairs in the upper triangle, diagonal included.
            N_batches = max(1, N_batches_X * (N_batches_X+1) // 2)
        else:
            same = False
            N_batches_X2 = _round_up_div(len(X2), batch_size)
            N_batches = N_batches_X * N_batches_X2

        start_batch, self.batches_this_worker = _this_worker_batch(
            N_batches, worker_rank, n_workers)

        # Only the slice of index pairs that belongs to this worker.
        self.idx_iter = itertools.islice(
            _product_generator(N_batches_X, N_batches_X2, same),
            start_batch,
            start_batch + self.batches_this_worker)

        self.worker_rank = worker_rank
        self.prev_j = -2  # this + 1 = -1, which is not a valid j
        # The loaders are created lazily in `__next__`.
        self.X_loader = None
        self.X2_loader = None
        self.x_batch = None
        self.X = X
        self.X2 = X2
        self.same = same
        self.batch_size = batch_size

    def __len__(self):
        # Number of batch pairs assigned to this worker.
        return self.batches_this_worker

    def __iter__(self):
        return self

    def dataloader_beginning_at(self, i, dataset):
        # Iterator over the batches of `dataset`, skipping the first `i`.
        return iter(DataLoader(
            Subset(dataset, range(i*self.batch_size, len(dataset))),
            batch_size=self.batch_size))

    def __next__(self):
        same, i, j = next(self.idx_iter)

        if self.X_loader is None:
            self.X_loader = self.dataloader_beginning_at(i, self.X)

        # A `j` that does not follow the previous one marks the start of a
        # new row: restart the X2 loader at `j` and move on to the next X
        # batch. Within a row the same `x_batch` is reused.
        if j != self.prev_j+1:
            self.X2_loader = self.dataloader_beginning_at(j, self.X2)
            self.x_batch = next(self.X_loader)
        self.prev_j = j

        return (same,
                (i*self.batch_size, self.x_batch),
                (j*self.batch_size, next(self.X2_loader)))
97 |
98 |
class DiagIterator(object):
    """
    Iterates over matching batches of `X` and `X2`, i.e. the batch pairs
    needed for the diagonal of the kernel matrix.

    Each item is `(same, (ib, xy), (ib, xy2))`, where `ib` is the index of
    the first element of the batch. If `X2` is None, `same` is True and
    `xy2` is the very same batch as `xy`. Otherwise both datasets are read
    in lockstep and iteration stops when the shorter one runs out.
    """
    def __init__(self, batch_size, X, X2=None):
        self.batch_size = batch_size
        dl = DataLoader(X, batch_size=batch_size)
        if X2 is None:
            self.same = True
            self.it = iter(enumerate(dl))
            self.length = len(dl)
        else:
            dl2 = DataLoader(X2, batch_size=batch_size)
            self.same = False
            self.it = iter(enumerate(zip(dl, dl2)))
            self.length = min(len(dl), len(dl2))

    def __iter__(self):
        return self

    def __len__(self):
        return self.length

    def __next__(self):
        if self.same:
            i, xy = next(self.it)
            xy2 = xy
        else:
            # `enumerate(zip(...))` yields `(i, (xy, xy2))`: the pair is
            # nested, so it has to be unpacked as such. Unpacking it as
            # three flat values raised ValueError on every call.
            i, (xy, xy2) = next(self.it)
        ib = i*self.batch_size
        return (self.same, (ib, xy), (ib, xy2))
127 |
128 |
class DatasetFromConfig(object):
    """
    Train, validation and test splits of one dataset, built from a config
    module.
    """
    def __init__(self, datasets_path, config):
        """
        Reads from `config`:
            config.dataset_name (e.g. "MNIST"), the sub-directory of
                `datasets_path` that holds the data
            config.dataset, called to build the train and test datasets
            config.transforms, a list of extra transforms applied after
                `ToTensor` (may be empty)
            config.train_range
            config.validation_range
            config.test_range
        """
        self.config = config

        to_tensor = torchvision.transforms.ToTensor()
        if len(config.transforms) > 0:
            trans = torchvision.transforms.Compose(
                [to_tensor] + config.transforms)
        else:
            trans = to_tensor

        # Full datasets; only the training part is downloaded on request.
        root = os.path.join(datasets_path, config.dataset_name)
        train_full = config.dataset(root, train=True, download=True,
                                    transform=trans)
        test_full = config.dataset(root, train=False, transform=trans)
        self.data_full = ConcatDataset([train_full, test_full])

        # Our own split, indexing into the concatenation of train and test
        # (could omit some data, or include validation in test)
        self.train, self.validation, self.test = [
            Subset(self.data_full, index_range)
            for index_range in (config.train_range,
                                config.validation_range,
                                config.test_range)]

    @staticmethod
    def load_full(dataset):
        """Return the whole of `dataset` collated as a single batch."""
        loader = DataLoader(dataset, batch_size=len(dataset))
        return next(iter(loader))
163 |
164 |
165 | def _hhmmss(s):
166 | m, s = divmod(int(s), 60)
167 | h, m = divmod(m, 60)
168 | if h == 0.0:
169 | return f"{m:02d}:{s:02d}"
170 | else:
171 | return f"{h:02d}:{m:02d}:{s:02d}"
172 |
173 |
def print_timings(iterator, desc="time", print_interval=2.):
    """
    Yield the values of `iterator`, printing progress lines along the way.

    A line shows the iterations done out of `len(iterator)`, the iteration
    speed, and the elapsed and estimated total time. It is printed after the
    first value and then whenever more than `print_interval` seconds have
    passed since the last printed line.

    Meant as a rudimentary replacement for `tqdm` that prints a new line
    each time, and thus can be used in multiple parallel processes in the
    same terminal.
    """
    import time
    start_time = time.perf_counter()
    total = len(iterator)
    last_printed = -print_interval
    for n_done, value in enumerate(iterator, start=1):
        yield value
        # Timings are taken once the consumer asks for the next value, so
        # they include the time it spent on this one.
        elapsed = time.perf_counter() - start_time
        it_s = n_done/elapsed
        total_s = total/it_s
        if elapsed > last_printed + print_interval:
            print(f"{desc}: {n_done}/{total} it, {it_s:.02f} it/s,"
                  f"[{_hhmmss(elapsed)}<{_hhmmss(total_s)}]")
            last_printed = elapsed
197 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # "Deep CNNs as shallow GPs" in Pytorch
2 | Code for "Deep Convolutional Networks as shallow Gaussian Processes"
3 | ([arXiv](https://arxiv.org/abs/1808.05587),
4 | [other material](https://agarri.ga/publication/convnets-as-gps/)), by
5 | [Adrià Garriga-Alonso](https://agarri.ga/),
6 | [Laurence Aitchison](http://www.gatsby.ucl.ac.uk/~laurence/) and
7 | [Carl Edward Rasmussen](http://mlg.eng.cam.ac.uk/carl/).
8 |
9 | The most extensively used libraries
10 | are [PyTorch](https://pytorch.org/), [NumPy](https://www.numpy.org/) and
11 | [H5Py](http://www.h5py.org/); check `requirements.txt` for the rest.
12 |
13 | # The `cnn_gp` package
14 | This library allows you to very easily write down neural network architectures,
15 | and get the kernels corresponding to their equivalent GPs. We can easily build
16 | `Sequential` architectures. For example, a 3-layer convolutional network with a
17 | dense layer at the end is:
18 |
19 | ```python
20 | from cnn_gp import Sequential, Conv2d, ReLU
21 |
22 | model = Sequential(
23 | Conv2d(kernel_size=3),
24 | ReLU(),
25 | Conv2d(kernel_size=3, stride=2),
26 | ReLU(),
27 | Conv2d(kernel_size=14, padding=0), # equivalent to a dense layer
28 | )
29 | ```
30 | Optionally call `model = model.cuda()` to use the GPU.
31 |
32 | Then, we can compute the kernel between batches of input images:
33 | ```python
34 | import torch
35 | # X and Z have shape [N_images, N_channels, img_width, img_height]
36 | X = torch.randn(2, 3, 28, 28)
37 | Z = torch.randn(2, 3, 28, 28)
38 |
39 | Kxx = model(X)
40 | Kxx = model(X, X, same=True)
41 |
42 | Kxz = model(X, Z)
43 |
44 | # diagonal of Kxx matrix above
45 | Kxx_diag = model(X, diag=True)
46 | ```
47 |
48 | We can also instantiate randomly initialized neural networks that have the
49 | architecture corresponding to the kernel.
50 | ```python
51 | network = model.nn(channels=16, in_channels=3, out_channels=10)
52 | isinstance(network, torch.nn.Module) # evaluates to True
53 |
54 | f_X = network(X) # evaluates network at X
55 | ```
56 | Calling `model.nn` will give us an instance of the network above that can do 10-class
57 | classification. It accepts inputs that are RGB images (3 channels) of size
58 | 28x28. We can then train this neural network as we would any normal Pytorch
59 | model.
60 |
61 | ## Installation
62 | It is possible to install the `cnn_gp` package without any of the dependencies
63 | that are needed for the experiments. Just run
64 | ```sh
65 | pip install -e .
66 | ```
67 | from the root directory of this same repository.
68 |
69 | ## Current limitations
70 | Dense layers are not implemented. The way to simulate them is to have a
71 | convolutional layer with `padding=0`, and with `kernel_size` as large as the
72 | activations in the previous layer.
73 |
74 | # Replicating the experiments
75 |
76 | First install the packages in `requirements.txt`. To run each of the
77 | experiments, first take a look at the files `exp_mnist_resnet/run.bash` or
78 | `exp_random_nn/run.bash`. Edit the configuration variables near the top
79 | appropriately. Then, run one of the files from the root of the directory, for
80 | example:
81 |
82 | ```bash
83 | bash ./exp_mnist_resnet/run.bash
84 | ```
85 |
86 | ## Experiment 1: classify MNIST
87 |
88 | Here are the validation and test errors for the best GPs corresponding to the
89 | NN architectures reported in the paper.
90 |
91 | Name in paper | Config file | Validation error | Test error
92 | --------------|-------------|------------------|----------
93 | ConvNet GP | `mnist_paper_convnet_gp` | 0.71% | 1.03%
94 | Residual CNN GP | `mnist_paper_residual_cnn_gp` | 0.72% | 0.96%
95 | ResNet GP | `mnist_as_tf` | 0.68% | 0.84%
96 |
97 |
98 | (click to expand) Architecture for ConvNet GP
99 |
100 | ```python
101 | var_bias = 7.86
102 | var_weight = 2.79
103 |
104 | initial_model = Sequential(
105 | Conv2d(kernel_size=7, padding="same", var_weight=var_weight * 7**2, var_bias=var_bias),
106 | ReLU(),
107 | Conv2d(kernel_size=7, padding="same", var_weight=var_weight * 7**2, var_bias=var_bias),
108 | ReLU(),
109 | Conv2d(kernel_size=7, padding="same", var_weight=var_weight * 7**2, var_bias=var_bias),
110 | ReLU(),
111 | Conv2d(kernel_size=7, padding="same", var_weight=var_weight * 7**2, var_bias=var_bias),
112 | ReLU(),
113 | Conv2d(kernel_size=7, padding="same", var_weight=var_weight * 7**2, var_bias=var_bias),
114 | ReLU(),
115 | Conv2d(kernel_size=7, padding="same", var_weight=var_weight * 7**2, var_bias=var_bias),
116 | ReLU(),
117 | Conv2d(kernel_size=7, padding="same", var_weight=var_weight * 7**2, var_bias=var_bias),
118 | ReLU(), # Total 7 layers before dense
119 |
120 | Conv2d(kernel_size=28, padding=0, var_weight=var_weight, var_bias=var_bias))
121 | ```
122 |
123 |
124 | (click to expand) Architecture for Residual CNN GP
125 |
126 | ```python
127 | var_bias = 4.69
128 | var_weight = 7.27
129 | initial_model = Sequential(
130 | *(Sum([
131 | Sequential(),
132 | Sequential(
133 | Conv2d(kernel_size=4, padding="same", var_weight=var_weight * 4**2,
134 | var_bias=var_bias),
135 | ReLU(),
136 | )]) for _ in range(8)),
137 | Conv2d(kernel_size=4, padding="same", var_weight=var_weight * 4**2,
138 | var_bias=var_bias),
139 | ReLU(),
140 | Conv2d(kernel_size=28, padding=0, var_weight=var_weight,
141 | var_bias=var_bias),
142 | )
143 | ```
144 |
145 |
146 |
147 | (click to expand) Architecture for ResNet GP
148 |
149 | ```python
150 | initial_model = Sequential(
151 | Conv2d(kernel_size=3),
152 |
153 | # Big resnet block #1
154 | resnet_block(stride=1, projection_shortcut=True, multiplier=1),
155 | resnet_block(stride=1, projection_shortcut=False, multiplier=1),
156 | resnet_block(stride=1, projection_shortcut=False, multiplier=1),
157 | resnet_block(stride=1, projection_shortcut=False, multiplier=1),
158 | resnet_block(stride=1, projection_shortcut=False, multiplier=1),
159 |
160 | # Big resnet block #2
161 | resnet_block(stride=2, projection_shortcut=True, multiplier=2),
162 | resnet_block(stride=1, projection_shortcut=False, multiplier=2),
163 | resnet_block(stride=1, projection_shortcut=False, multiplier=2),
164 | resnet_block(stride=1, projection_shortcut=False, multiplier=2),
165 | resnet_block(stride=1, projection_shortcut=False, multiplier=2),
166 |
167 | # Big resnet block #3
168 | resnet_block(stride=2, projection_shortcut=True, multiplier=4),
169 | resnet_block(stride=1, projection_shortcut=False, multiplier=4),
170 | resnet_block(stride=1, projection_shortcut=False, multiplier=4),
171 | resnet_block(stride=1, projection_shortcut=False, multiplier=4),
172 | resnet_block(stride=1, projection_shortcut=False, multiplier=4),
173 |
174 | # No nonlinearity here, the next Conv2d substitutes the average pooling
175 | Conv2d(kernel_size=7, padding=0, in_channel_multiplier=4,
176 | out_channel_multiplier=4),
177 | ReLU(),
178 | Conv2d(kernel_size=1, padding=0, in_channel_multiplier=4),
179 | )
180 | ```
181 |
182 |
183 | ## Experiment 2: Check that networks converge to a Gaussian process
184 | In the paper, only ResNet-32 GP is presented. This is why an issue when
185 | constructing the Residual CNN GP was originally not caught. More details in the
186 | relevant subsection.
187 | ### ResNet-32 GP
188 | 
189 | ### ConvNet GP
190 | 
191 | ### Residual CNN GP
192 | The best randomly-searched ResNet reported in the paper.
193 |
194 | In the original paper there is a slight issue with how the kernels relate to the
195 | underlying networks. The network sums together layers after the ReLU nonlinearity,
196 | which are not Gaussian, and also do not have mean 0. However, the kernel is valid
197 | and does correspond to a neural network. In particular, if we take an infinitely
198 | wide 1x1 convolution after each ReLU layer,
199 | we convert the output of the ReLUs into a zero-mean Gaussian
200 | with the same kernel, which can be summed.
201 | In the interest of making the results replicable, we have replicated this issue
202 | as well.
203 |
204 | The correct way to use ResNets is to sum things after a Conv2d layer, see for
205 | example the `resnet_block` in [`cnn_gp/kernels.py`](/cnn_gp/kernels.py).
206 |
207 | # BibTeX citation record
208 | Note: the version in arXiv is slightly newer and contains information about
209 | which hyperparameters turned out to be the most effective for each architecture.
210 |
211 | ```bibtex
212 | @inproceedings{aga2018cnngp,
213 | author = {{Garriga-Alonso}, Adri{\`a} and Aitchison, Laurence and Rasmussen, Carl Edward},
214 | title = {Deep Convolutional Networks as shallow {G}aussian Processes},
215 | booktitle = {International Conference on Learning Representations},
216 | year = {2019},
217 | url = {https://openreview.net/forum?id=Bklfsi0cKm}}
218 | ```
219 |
--------------------------------------------------------------------------------
/cnn_gp/kernels.py:
--------------------------------------------------------------------------------
1 | import torch as t
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | from .kernel_patch import ConvKP, NonlinKP
6 | import math
7 |
8 |
9 | __all__ = ("NNGPKernel", "Conv2d", "ReLU", "Sequential", "Mixture",
10 | "MixtureModule", "Sum", "SumModule", "resnet_block")
11 |
12 |
class NNGPKernel(nn.Module):
    """
    Base class of the kernel building blocks. Subclasses implement
    `propagate`, which transforms one kernel matrix into another:
    [N1, N2, W, H] -> [N1, N2, W, H]
    """
    def forward(self, x, y=None, same=None, diag=False):
        """
        Compute the kernel between image batches.

        Either takes one minibatch (`x`), or takes two minibatches (`x` and
        `y`) plus a boolean `same` saying whether they are the same data.
        Both must be 4-dimensional and agree on every dimension but the
        first.

        Returns a tensor of shape `[N1, N2]`, or of shape `[N1]` holding
        only the diagonal when `diag` is true.
        """
        if y is None:
            assert same is None
            y, same = x, True

        assert not diag or len(x) == len(y), (
            "diagonal kernels must operate with data of equal length")

        assert 4==len(x.size())
        assert 4==len(y.size())
        for dim in (1, 2, 3):
            assert x.size(dim) == y.size(dim)

        N1, _, W, H = x.size()
        N2 = y.size(0)

        # Averaging over channels gives the initial (co)variances:
        # [N1, C, W, H], [N2, C, W, H] -> [N1 N2, 1, W, H]
        # (or [N1, 1, W, H] when only the diagonal is requested)
        if diag:
            xy = (x*y).mean(1, keepdim=True)
        else:
            xy = (x.unsqueeze(1)*y).mean(2).view(N1*N2, 1, W, H)
        xx = (x**2).mean(1, keepdim=True)
        yy = (y**2).mean(1, keepdim=True)

        final_kp = self.propagate(ConvKP(same, diag, xy, xx, yy))
        r = NonlinKP(final_kp).xy
        out_shape = (N1,) if diag else (N1, N2)
        return r.view(*out_shape)
58 |
59 |
class Conv2d(NNGPKernel):
    """
    Kernel of a 2D convolutional layer with weight variance `var_weight` and
    bias variance `var_bias`.

    `padding` is either the string "same" or a value passed through to the
    convolution. The channel multipliers only affect the network built by
    `nn`.
    """
    def __init__(self, kernel_size, stride=1, padding="same", dilation=1,
                 var_weight=1., var_bias=0., in_channel_multiplier=1,
                 out_channel_multiplier=1):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.var_weight = var_weight
        self.var_bias = var_bias

        is_same = (padding == "same")
        self.padding = dilation*(kernel_size//2) if is_same else padding
        # An even-sized "same" convolution needs more padding on one side
        # than on the other. That is emulated with a kernel one element
        # larger, whose first row and first column are zero.
        self.kernel_has_row_of_zeros = bool(is_same and kernel_size % 2 == 0)

        if self.kernel_has_row_of_zeros:
            side = self.kernel_size+1
            kernel = t.ones(1, 1, side, side)
            kernel[:, :, 0, :] = 0.
            kernel[:, :, :, 0] = 0.
        else:
            kernel = t.ones(1, 1, self.kernel_size, self.kernel_size)
        self.register_buffer('kernel', kernel
                             * (self.var_weight / self.kernel_size**2))
        self.in_channel_multiplier = in_channel_multiplier
        self.out_channel_multiplier = out_channel_multiplier

    def propagate(self, kp):
        """Convolve `xy`, `xx` and `yy` of `kp` with the averaging kernel and add `var_bias`."""
        kp = ConvKP(kp)

        def convolve(patch):
            averaged = F.conv2d(patch, self.kernel, stride=self.stride,
                                padding=self.padding, dilation=self.dilation)
            return averaged + self.var_bias

        return ConvKP(kp.same, kp.diag,
                      convolve(kp.xy), convolve(kp.xx), convolve(kp.yy))

    def nn(self, channels, in_channels=None, out_channels=None):
        """
        Build a randomly initialised `torch.nn.Conv2d` matching this kernel.

        `in_channels` and `out_channels` default to `channels`; each is
        multiplied by the corresponding channel multiplier. A bias is only
        created when `var_bias > 0`.
        """
        if in_channels is None:
            in_channels = channels
        if out_channels is None:
            out_channels = channels
        has_bias = (self.var_bias > 0.)
        extra = 1 if self.kernel_has_row_of_zeros else 0
        conv2d = nn.Conv2d(
            in_channels=in_channels * self.in_channel_multiplier,
            out_channels=out_channels * self.out_channel_multiplier,
            kernel_size=self.kernel_size + extra,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            bias=has_bias,
        )
        weight_std = math.sqrt(
            self.var_weight / conv2d.in_channels) / self.kernel_size
        conv2d.weight.data.normal_(0, weight_std)
        if self.kernel_has_row_of_zeros:
            # Same zero pattern as the kernel built in `__init__`.
            conv2d.weight.data[:, :, 0, :] = 0
            conv2d.weight.data[:, :, :, 0] = 0
        if has_bias:
            conv2d.bias.data.normal_(0, math.sqrt(self.var_bias))
        return conv2d

    def layers(self):
        """A convolution counts as one layer."""
        return 1
126 |
127 |
class ReLU(NNGPKernel):
    """
    Kernel of a ReLU nonlinearity. The covariance is numerically stabilised
    by clamping values.
    """
    f32_tiny = np.finfo(np.float32).tiny

    def propagate(self, kp):
        """Map the pre-activation (co)variances in `kp` to those after the ReLU."""
        kp = NonlinKP(kp)
        # With (xy, xx, yy) == (c, v₁, v₂) and θ the angle whose cosine is
        # c / √(v₁v₂), we need to calculate
        #     √(v₁v₂) / 2π ( √(1 - c²/v₁v₂) + (π - θ)c / √(v₁v₂) )
        # which is equivalent to
        #     1/2π ( √(v₁v₂ - c²) + (π - θ)c )
        # NOTE we divide by 2 to avoid multiplying the ReLU by sqrt(2)

        # The tiny constant keeps `rsqrt` finite when a variance is zero.
        xx_yy = kp.xx * kp.yy + self.f32_tiny

        # Clamp these so the outputs are not NaN
        cos_theta = (kp.xy * xx_yy.rsqrt()).clamp(-1, 1)
        sin_theta = t.sqrt((xx_yy - kp.xy**2).clamp(min=0))
        theta = t.acos(cos_theta)
        xy = (sin_theta + (math.pi - theta)*kp.xy) / (2*math.pi)

        xx = kp.xx/2.
        if not kp.same:
            return NonlinKP(kp.same, kp.diag, xy, xx, kp.yy/2.)

        if kp.diag:
            xy = xx
        else:
            # Make sure the diagonal agrees with `xx`
            eye = t.eye(xy.size()[0]).unsqueeze(-1).unsqueeze(-1).to(kp.xy.device)
            xy = (1-eye)*xy + eye*xx
        # Both inputs are the same data, so `yy` is `xx`.
        return NonlinKP(kp.same, kp.diag, xy, xx, xx)

    def nn(self, channels, in_channels=None, out_channels=None):
        """Return a `torch.nn.ReLU`; it cannot change the number of channels."""
        assert in_channels is None
        assert out_channels is None
        return nn.ReLU()

    def layers(self):
        """A nonlinearity does not count as a layer."""
        return 0
174 |
175 |
176 | #### Combination classes
177 |
class Sequential(NNGPKernel):
    """Applies the given kernel modules one after another."""
    def __init__(self, *mods):
        super().__init__()
        self.mods = mods
        for position, module in enumerate(mods):
            self.add_module(str(position), module)

    def propagate(self, kp):
        """Pass `kp` through every module in order."""
        result = kp
        for module in self.mods:
            result = module.propagate(result)
        return result

    def nn(self, channels, in_channels=None, out_channels=None):
        """
        Build the matching network. `in_channels` is passed only to the
        first module and `out_channels` only to the last one.
        """
        n_mods = len(self.mods)
        if n_mods == 0:
            return nn.Sequential()
        if n_mods == 1:
            return self.mods[0].nn(channels, in_channels=in_channels, out_channels=out_channels)

        first, *middle, last = self.mods
        built = [first.nn(channels, in_channels=in_channels)]
        built.extend(module.nn(channels) for module in middle)
        built.append(last.nn(channels, out_channels=out_channels))
        return nn.Sequential(*built)

    def layers(self):
        """Total number of layers over all modules."""
        return sum(module.layers() for module in self.mods)
201 |
202 |
class Mixture(NNGPKernel):
    """
    Applies multiple modules to the input, and sums the results weighted by
    a proportion each (e.g. for the implementation of a ResNet).

    The proportions are the softmax of the `logit` parameter, so they add
    up to one: if each module has average variance 1, then the output also
    has average variance 1.
    """
    def __init__(self, mods, logit_proportions=None):
        super().__init__()
        self.mods = mods
        for position, module in enumerate(mods):
            self.add_module(str(position), module)
        # Default: equal proportions for all modules.
        if logit_proportions is None:
            logit_proportions = t.zeros(len(mods))
        self.logit = nn.Parameter(logit_proportions)

    def propagate(self, kp):
        """Return the proportion-weighted sum of every module's output for `kp`."""
        weights = F.softmax(self.logit, dim=0)
        mixed = self.mods[0].propagate(kp) * weights[0]
        for idx, module in enumerate(self.mods[1:], start=1):
            mixed = mixed + (module.propagate(kp) * weights[idx])
        return mixed

    def nn(self, channels, in_channels=None, out_channels=None):
        """Build a `MixtureModule` from the networks of all modules."""
        networks = []
        for module in self.mods:
            networks.append(module.nn(
                channels, in_channels=in_channels, out_channels=out_channels))
        return MixtureModule(networks, self.logit)

    def layers(self):
        """Number of layers of the deepest module."""
        return max(module.layers() for module in self.mods)
230 |
class MixtureModule(nn.Module):
    """
    Network counterpart of `Mixture`: sums the outputs of `mods`, each
    scaled by the square root of its proportion.

    The proportions are the softmax of `logit_parameter`. Scaling an output
    by `sqrt(p)` scales its covariance by `p`, which is the weight the
    kernel mixture gives to that module.
    """
    def __init__(self, mods, logit_parameter):
        super().__init__()
        self.mods = mods
        # Keep a detached copy of the logits. Tensors (including
        # `nn.Parameter`) are cloned explicitly, because `t.tensor(tensor)`
        # emits a warning.
        if isinstance(logit_parameter, t.Tensor):
            self.logit = logit_parameter.detach().clone()
        else:
            self.logit = t.tensor(logit_parameter)
        for idx, mod in enumerate(mods):
            self.add_module(str(idx), mod)

    def forward(self, input):
        sqrt_proportions = F.softmax(self.logit, dim=0).sqrt()
        total = self.mods[0](input)*sqrt_proportions[0]
        for i in range(1, len(self.mods)):
            # Every branch is scaled, not only the first one; the scaling of
            # the remaining branches had been left commented out.
            total = total + self.mods[i](input)*sqrt_proportions[i]
        return total
244 |
245 |
class Sum(NNGPKernel):
    """Applies every module to the same input and adds up the results."""
    def __init__(self, mods):
        super().__init__()
        self.mods = mods
        for position, module in enumerate(mods):
            self.add_module(str(position), module)

    def propagate(self, kp):
        """Return the sum of every module's output for `kp`."""
        # Starts from 0, like the builtin `sum`; adding 0 to the first kp is
        # hopefully a noop.
        total = 0
        for module in self.mods:
            total = total + module.propagate(kp)
        return total

    def nn(self, channels, in_channels=None, out_channels=None):
        """Build a `SumModule` from the networks of all modules."""
        networks = []
        for module in self.mods:
            networks.append(module.nn(
                channels, in_channels=in_channels, out_channels=out_channels))
        return SumModule(networks)

    def layers(self):
        """Number of layers of the deepest module."""
        return max(module.layers() for module in self.mods)
261 |
262 |
class SumModule(nn.Module):
    """Network counterpart of `Sum`: adds up the outputs of all `mods` for the same input."""
    def __init__(self, mods):
        super().__init__()
        self.mods = mods
        for position, module in enumerate(mods):
            self.add_module(str(position), module)

    def forward(self, input):
        # Starts from 0, like the builtin `sum`; adding 0 to the first value
        # is hopefully a noop.
        total = 0
        for module in self.mods:
            total = total + module(input)
        return total
272 |
273 |
def resnet_block(stride=1, projection_shortcut=False, multiplier=1):
    """
    Build the kernel of one residual block.

    With `stride == 1` and no projection shortcut, the block is the sum of
    the identity and a ReLU-conv-ReLU-conv branch. Otherwise the input first
    goes through a ReLU, and the result feeds both a 1x1 convolution (the
    shortcut) and a conv-ReLU-conv branch, which are summed.

    `multiplier` is the channel multiplier of the block's output; strided
    or projecting blocks take `multiplier // stride` as their input
    multiplier.
    """
    if stride != 1 or projection_shortcut:
        in_multiplier = multiplier//stride
        shortcut = Conv2d(1, stride=stride, in_channel_multiplier=in_multiplier, out_channel_multiplier=multiplier)
        branch = Sequential(
            Conv2d(3, stride=stride, in_channel_multiplier=in_multiplier, out_channel_multiplier=multiplier),
            ReLU(),
            Conv2d(3, in_channel_multiplier=multiplier, out_channel_multiplier=multiplier),
        )
        return Sequential(
            ReLU(),
            Sum([shortcut, branch]),
        )

    identity = Sequential()
    branch = Sequential(
        ReLU(),
        Conv2d(3, stride=stride, in_channel_multiplier=multiplier, out_channel_multiplier=multiplier),
        ReLU(),
        Conv2d(3, in_channel_multiplier=multiplier, out_channel_multiplier=multiplier),
    )
    return Sum([identity, branch])
297 |
--------------------------------------------------------------------------------