├── .flake8
├── .gitignore
├── LICENCE
├── README.md
├── assets
│   ├── neural_graphs_dark_bg.png
│   ├── neural_graphs_dark_transparent_bg.png
│   ├── neural_graphs_light_bg.png
│   └── neural_graphs_light_transparent_bg.png
├── experiments
│   ├── __init__.py
│   ├── cnn_generalization
│   │   ├── README.md
│   │   ├── configs
│   │   │   ├── base.yaml
│   │   │   ├── data
│   │   │   │   ├── cnn_park.yaml
│   │   │   │   └── zoo_cifar_nfn.yaml
│   │   │   └── model
│   │   │       ├── dynamic_stat.yaml
│   │   │       ├── head_cls
│   │   │       │   └── mlp.yaml
│   │   │       ├── nfn.yaml
│   │   │       ├── pna.yaml
│   │   │       ├── rtransformer.yaml
│   │   │       └── stat.yaml
│   │   ├── dataset
│   │   │   ├── cnn_park_splits.json
│   │   │   ├── cnn_sampler.py
│   │   │   ├── cnn_trainer.py
│   │   │   ├── fast_tensor_dataloader.py
│   │   │   ├── generate_cnn_park_config
│   │   │   │   ├── base.yaml
│   │   │   │   └── data
│   │   │   │       ├── cifar10.yaml
│   │   │   │       └── fmnist.yaml
│   │   │   ├── generate_cnn_park_splits.py
│   │   │   ├── nfn_cifar10_split.csv
│   │   │   ├── train_cnn_park.py
│   │   │   └── zoo_cifar_nfn_statistics.pth
│   │   ├── main.py
│   │   ├── scripts
│   │   │   ├── cnn_park_pna.sh
│   │   │   ├── cnn_park_pna_no_act.sh
│   │   │   ├── cnn_park_rt.sh
│   │   │   ├── cnn_park_rt_no_act.sh
│   │   │   ├── cnn_zoo_pna.sh
│   │   │   └── cnn_zoo_rt.sh
│   │   ├── sweep_configs
│   │   │   ├── sweep_cnn_park_gnn.yaml
│   │   │   ├── sweep_cnn_park_statnn.yaml
│   │   │   └── sweep_cnn_park_transformer.yaml
│   │   └── utils.py
│   ├── data.py
│   ├── data_generalization.py
│   ├── data_nfn.py
│   ├── inr_classification
│   │   ├── README.md
│   │   ├── configs
│   │   │   ├── base.yaml
│   │   │   ├── data
│   │   │   │   ├── dummy_inr.yaml
│   │   │   │   ├── fmnist.yaml
│   │   │   │   └── mnist.yaml
│   │   │   └── model
│   │   │       ├── dwsnet.yaml
│   │   │       ├── mlp.yaml
│   │   │       ├── nfn.yaml
│   │   │       ├── pna.yaml
│   │   │       └── rtransformer.yaml
│   │   ├── dataset
│   │   │   ├── compute_fmnist_statistics.py
│   │   │   ├── compute_mnist_statistics.py
│   │   │   ├── compute_nfn_mnist_statistics.py
│   │   │   ├── fmnist_splits.json
│   │   │   ├── fmnist_statistics.pth
│   │   │   ├── generate_mnist_data_splits.py
│   │   │   ├── mnist_splits.json
│   │   │   ├── mnist_statistics.pth
│   │   │   └── preprocess_fmnist.py
│   │   ├── main.py
│   │   └── scripts
│   │       ├── fmnist_cls_pna.sh
│   │       ├── fmnist_cls_pna_probe_64.sh
│   │       ├── fmnist_cls_rt.sh
│   │       ├── fmnist_cls_rt_probe_64.sh
│   │       ├── mnist_cls_pna.sh
│   │       ├── mnist_cls_pna_probe_64.sh
│   │       ├── mnist_cls_rt.sh
│   │       └── mnist_cls_rt_probe_64.sh
│   ├── learning_to_optimize
│   │   └── README.md
│   ├── lr_scheduler.py
│   ├── style_editing
│   │   ├── README.md
│   │   ├── configs
│   │   │   ├── base.yaml
│   │   │   ├── data
│   │   │   │   ├── fmnist.yaml
│   │   │   │   ├── mnist.yaml
│   │   │   │   ├── nfn_cifar.yaml
│   │   │   │   └── nfn_mnist.yaml
│   │   │   ├── model
│   │   │   │   ├── dwsnet.yaml
│   │   │   │   ├── nfn.yaml
│   │   │   │   ├── pna.yaml
│   │   │   │   └── rtransformer.yaml
│   │   │   └── out_of_domain_data
│   │   │       ├── fmnist.yaml
│   │   │       └── mnist.yaml
│   │   ├── dataset
│   │   ├── image_processing.py
│   │   ├── main.py
│   │   └── scripts
│   │       ├── mnist_dilation_gnn.sh
│   │       └── mnist_dilation_rt.sh
│   └── utils.py
├── nn
│   ├── __init__.py
│   ├── activation_embedding.py
│   ├── dense_gnn.py
│   ├── dense_relational_transformer.py
│   ├── dws
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── bias_to_bias.py
│   │   ├── bias_to_weight.py
│   │   ├── layers.py
│   │   ├── models.py
│   │   ├── weight_to_bias.py
│   │   └── weight_to_weight.py
│   ├── dynamic_gnn.py
│   ├── dynamic_graph_constructor.py
│   ├── dynamic_relational_transformer.py
│   ├── dynamic_stat_net.py
│   ├── gnn.py
│   ├── graph_constructor.py
│   ├── inr.py
│   ├── nfn
│   │   ├── __init__.py
│   │   ├── common
│   │   │   ├── __init__.py
│   │   │   └── data.py
│   │   ├── layers
│   │   │   ├── __init__.py
│   │   │   ├── encoding.py
│   │   │   ├── layer_utils.py
│   │   │   ├── layers.py
│   │   │   ├── misc_layers.py
│   │   │   └── regularize.py
│   │   └── models.py
│   ├── original_nfn
│   │   ├── __init__.py
│   │   ├── common
│   │   │   ├── __init__.py
│   │   │   └── data.py
│   │   ├── layers
│   │   │   ├── __init__.py
│   │   │   ├── encoding.py
│   │   │   ├── layer_utils.py
│   │   │   ├── layers.py
│   │   │   ├── misc_layers.py
│   │   │   └── regularize.py
│   │   └── models.py
│   ├── pooling.py
│   ├── probe_features.py
│   └── relational_transformer.py
├── notebooks
│   └── mnist-inr-classification.ipynb
├── pyproject.toml
├── requirements.txt
├── setup.py
└── tests
    ├── README.md
    ├── __init__.py
    ├── test_cnn_invariance.py
    ├── test_inr_equivariance.py
    ├── test_inr_invariance.py
    └── utils.py
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max-line-length = 88
3 | extend-ignore = E203
4 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 | .idea/
131 | output/
132 | wandb/
133 | #*.pth
134 | .DS_Store
135 | NFN_data
136 | experiments/cifar
137 | experiments/fashion-mnist
138 | experiments/data/MNIST
139 |
140 | outputs
141 | mnist_data
142 | *.ckpt
143 | experiments/inr-classification/dataset/mnist
144 | experiments/inr-classification/dataset/MNIST
145 | experiments/inr-classification/dataset/fmnist
146 | experiments/inr-classification/dataset/fashion-mnist
147 | experiments/inr-classification/dataset/FashionMNIST
148 | experiments/inr_classification/dataset/fmnist_inrs
149 | experiments/inr_classification/dataset/mnist-inrs
150 | experiments/mnist/logs
151 | experiments/cnn_generalization/dataset/small-zoo-cifar10
152 | experiments/cnn_generalization/dataset/cifar10_zooV2
153 | experiments/cnn_generalization/dataset/cifar10
154 | experiments/cnn_generalization/raw_dataset/
155 | experiments/transformer_generalization/dataset/transformer_park/
156 | dataset/**/*.pth
157 | *.gz
158 | *.pt
159 | *-ubyte
160 | .vscode
161 |
162 | siren_cifar_wts.tar
163 | siren_fashion_wts.tar
164 | siren_mnist_wts.tar
165 | cifar-10-batches-py
166 | nfn-cifar10-inrs
167 | nfn-cifar-inrs
168 | nfn-mnist-inrs
169 | nfn-fmnist-inrs
170 | states_cifar10_32x32_100trajectories_vit_inr_adamw_lr0.003_wd0.01_50ep.dat
171 | states_cifar10_32x32_100trajectories_vit_inr_adamw_lr0.003_wd0.01_50ep_meta_data.pkl
172 |
173 | __MACOSX/
174 |
--------------------------------------------------------------------------------
/LICENCE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Miltiadis Kofinas & David W. Zhang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Graph Neural Networks for Learning Equivariant Representations of Neural Networks
2 |
3 | Official implementation for
4 |
5 | Graph Neural Networks for Learning Equivariant Representations of Neural Networks
6 | Miltiadis Kofinas*, Boris Knyazev, Yan Zhang, Yunlu Chen, Gertjan J. Burghouts, Efstratios Gavves, Cees G. M. Snoek, David W. Zhang*
7 | ICLR 2024
8 | https://arxiv.org/abs/2403.12143/
9 | *Joint first and last authors
10 |
11 |
12 | [](https://arxiv.org/abs/2403.12143)
13 | [](https://openreview.net/forum?id=oO6FsMyDBt)
14 |
15 | [](https://doi.org/10.5281/zenodo.12797219)
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 | ## Setup environment
24 |
25 | To run the experiments, first create a clean virtual environment and install the requirements.
26 |
27 | ```bash
28 | conda create -n neural-graphs python=3.9
29 | conda activate neural-graphs
30 | conda install pytorch==2.0.1 torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia
31 | conda install pyg==2.3.0 pytorch-scatter -c pyg
32 | pip install hydra-core einops opencv-python
33 | ```
34 |
35 | Install the repo:
36 |
37 | ```bash
38 | git clone https://github.com/mkofinas/neural-graphs.git
39 | cd neural-graphs
40 | pip install -e .
41 | ```
42 |
43 | ## Introduction Notebook
44 |
45 | An introduction notebook for INR classification with **Neural Graphs**:
46 | [](https://colab.research.google.com/github/mkofinas/neural-graphs/blob/main/notebooks/mnist-inr-classification.ipynb)
47 | [](notebooks/mnist-inr-classification.ipynb)
48 |
49 | ## Run experiments
50 |
51 | To run a specific experiment, please follow the instructions in the README file within each experiment folder.
52 | It provides full instructions and details for downloading the data and reproducing the results reported in the paper.
53 |
54 | - INR classification: [`experiments/inr_classification`](experiments/inr_classification)
55 | - INR style editing: [`experiments/style_editing`](experiments/style_editing)
56 | - CNN generalization: [`experiments/cnn_generalization`](experiments/cnn_generalization)
57 | - Learning to optimize (coming soon): [`experiments/learning_to_optimize`](experiments/learning_to_optimize)
58 |
59 | ## Datasets
60 |
61 | ### INR classification and style editing
62 |
63 | For INR classification, we use MNIST and Fashion MNIST. **The datasets are available [here](https://www.dropbox.com/sh/56pakaxe58z29mq/AABtWNkRYroLYe_cE3c90DXVa?dl=0).**
64 |
65 | - [MNIST INRs](https://www.dropbox.com/sh/56pakaxe58z29mq/AABtWNkRYroLYe_cE3c90DXVa?dl=0&preview=mnist-inrs.zip)
66 | - [Fashion MNIST INRs](https://www.dropbox.com/sh/56pakaxe58z29mq/AABtWNkRYroLYe_cE3c90DXVa?dl=0&preview=fmnist_inrs.zip)
67 |
68 | For INR style editing, we use MNIST. **The dataset is available [here](https://www.dropbox.com/sh/56pakaxe58z29mq/AABtWNkRYroLYe_cE3c90DXVa?dl=0).**
69 |
70 | - [MNIST INRs](https://www.dropbox.com/sh/56pakaxe58z29mq/AABtWNkRYroLYe_cE3c90DXVa?dl=0&preview=mnist-inrs.zip)
71 |
72 | ### CNN generalization
73 |
74 | For CNN generalization, we use the grayscale CIFAR-10 (CIFAR10-GS) from the
75 | [_Small CNN Zoo_](https://github.com/google-research/google-research/tree/master/dnn_predict_accuracy)
76 | dataset.
77 | We also introduce *CNN Wild Park*, a dataset of CNNs with varying numbers of
78 | layers, kernel sizes, activation functions, and residual connections between
79 | arbitrary layers.
80 |
81 | - [CIFAR10-GS](https://storage.cloud.google.com/gresearch/smallcnnzoo-dataset/cifar10.tar.xz)
82 | - [CNN Wild Park](https://zenodo.org/records/12797219)
83 |
84 | ## Citation
85 |
86 | If you find our work or this code to be useful in your own research, please consider citing the following paper:
87 |
88 | ```bib
89 | @inproceedings{kofinas2024graph,
90 | title={{G}raph {N}eural {N}etworks for {L}earning {E}quivariant {R}epresentations of {N}eural {N}etworks},
91 | author={Kofinas, Miltiadis and Knyazev, Boris and Zhang, Yan and Chen, Yunlu and Burghouts,
92 | Gertjan J. and Gavves, Efstratios and Snoek, Cees G. M. and Zhang, David W.},
93 | booktitle = {12th International Conference on Learning Representations ({ICLR})},
94 | year={2024}
95 | }
96 | ```
97 |
98 | ```bib
99 | @inproceedings{zhang2023neural,
100 | title={{N}eural {N}etworks {A}re {G}raphs! {G}raph {N}eural {N}etworks for {E}quivariant {P}rocessing of {N}eural {N}etworks},
101 | author={Zhang, David W. and Kofinas, Miltiadis and Zhang, Yan and Chen, Yunlu and Burghouts, Gertjan J. and Snoek, Cees G. M.},
102 | booktitle = {Workshop on Topology, Algebra, and Geometry in Machine Learning (TAG-ML), ICML},
103 | year={2023}
104 | }
105 | ```
106 |
107 | ## Acknowledgments
108 |
109 | - This codebase started based on [github.com/AvivNavon/DWSNets](https://github.com/AvivNavon/DWSNets) and the DWSNet implementation is copied from there
110 | - The NFN implementation is copied and slightly adapted from [github.com/AllanYangZhou/nfn](https://github.com/AllanYangZhou/nfn)
111 | - We implemented the relational transformer in PyTorch following the JAX implementation at [github.com/CameronDiao/relational-transformer](https://github.com/CameronDiao/relational-transformer). Our implementation has some differences that we describe in the paper.
112 |
113 | ## Contributors
114 |
115 | - [David W. Zhang](https://davzha.netlify.app/)
116 | - [Miltiadis (Miltos) Kofinas](https://mkofinas.github.io/)
117 |
--------------------------------------------------------------------------------
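Assuming `setup.py` registers the two top-level packages visible in the tree above (`nn` and `experiments`, both of which ship an `__init__.py`), a quick import check verifies the editable install from the Setup section:

```bash
python -c "import nn, experiments; print('neural-graphs import OK')"
```

--------------------------------------------------------------------------------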
/assets/neural_graphs_dark_bg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mkofinas/neural-graphs/1f2b671ab4988ef212469363005a5b99eec16580/assets/neural_graphs_dark_bg.png
--------------------------------------------------------------------------------
/assets/neural_graphs_dark_transparent_bg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mkofinas/neural-graphs/1f2b671ab4988ef212469363005a5b99eec16580/assets/neural_graphs_dark_transparent_bg.png
--------------------------------------------------------------------------------
/assets/neural_graphs_light_bg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mkofinas/neural-graphs/1f2b671ab4988ef212469363005a5b99eec16580/assets/neural_graphs_light_bg.png
--------------------------------------------------------------------------------
/assets/neural_graphs_light_transparent_bg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mkofinas/neural-graphs/1f2b671ab4988ef212469363005a5b99eec16580/assets/neural_graphs_light_transparent_bg.png
--------------------------------------------------------------------------------
/experiments/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mkofinas/neural-graphs/1f2b671ab4988ef212469363005a5b99eec16580/experiments/__init__.py
--------------------------------------------------------------------------------
/experiments/cnn_generalization/README.md:
--------------------------------------------------------------------------------
1 | ## CNN generalization
2 |
3 | ### NFN CNN Zoo data
4 |
5 | This experiment follows [NFN](https://arxiv.org/abs/2302.14040).
6 | Download the
7 | [CIFAR10](https://storage.cloud.google.com/gresearch/smallcnnzoo-dataset/cifar10.tar.xz)
8 | data (originally from [Unterthiner et al.,
9 | 2020](https://github.com/google-research/google-research/tree/master/dnn_predict_accuracy))
10 | into `./dataset` and extract it. Change `data_path` in
11 | `./configs/data/zoo_cifar_nfn.yaml` if you want to store the data somewhere else.
12 |
13 | Options for `data`:
14 | - `zoo_cifar_nfn`: NFN CNN Zoo (CIFAR) dataset
15 |
16 |
17 |
18 | #### Run experiments with scripts
19 |
20 | You can run the experiments using the scripts provided in the `scripts` directory.
21 | For example, to train and evaluate a __Neural Graph Transformer__ (NG-T) model on the CNN Zoo dataset, run the following command:
22 |
23 | ```sh
24 | ./scripts/cnn_zoo_rt.sh
25 | ```
26 | This script will run the experiment for 3 different seeds.
27 |
28 | ### CNN Wild Park
29 |
30 | [](https://doi.org/10.5281/zenodo.12797219)
31 |
32 | Download the dataset from [Zenodo](https://doi.org/10.5281/zenodo.12797219) and extract it into `./dataset`.
33 |
34 | #### Run experiments with scripts
35 |
36 | You can run the experiments using the scripts provided in the `scripts` directory.
37 | For example, to train and evaluate a __Neural Graph Transformer__ (NG-T) model on the CNN Wild Park dataset, run the following command:
38 |
39 | ```sh
40 | ./scripts/cnn_park_rt.sh
41 | ```
42 | This script will run the experiment for 3 different seeds.
43 |
44 | #### Hyperparameter Sweep
45 |
46 | We also provide sweep configs for NG-T, NG-GNN, and StatNN in the `sweep_configs` directory.
47 | In the following commands, change the `--project` and the `--entity` according to
48 | your WandB account, or change the corresponding `yaml` files.
49 |
50 | __NG-T__:
51 | ```sh
52 | wandb sweep --project cnn-generalization --entity neural-graphs sweep_configs/sweep_cnn_park_transformer.yaml
53 | ```
54 |
55 | __NG-GNN__:
56 | ```sh
57 | wandb sweep --project cnn-generalization --entity neural-graphs sweep_configs/sweep_cnn_park_gnn.yaml
58 | ```
59 |
60 | __StatNN__:
61 | ```sh
62 | wandb sweep --project cnn-generalization --entity neural-graphs sweep_configs/sweep_cnn_park_statnn.yaml
63 | ```
64 |
--------------------------------------------------------------------------------
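Each `wandb sweep` command above prints a sweep ID; the sweep itself is then executed by one or more agents. A minimal sketch, with the entity/project taken from the commands above and `<sweep_id>` as a placeholder:

```sh
# `wandb sweep` prints: "Run sweep agent with: wandb agent <entity>/<project>/<sweep_id>"
wandb agent neural-graphs/cnn-generalization/<sweep_id>
```

--------------------------------------------------------------------------------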
/experiments/cnn_generalization/configs/base.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - model: rtransformer
3 | - data: zoo_cifar_nfn
4 | - _self_
5 |
6 | n_epochs: 200
7 | batch_size: 256
8 |
9 | n_views: 1
10 | num_workers: 8
11 | eval_every: 100
12 | num_accum: 1
13 |
14 | compile: false
15 | compile_kwargs:
16 | # mode: reduce-overhead
17 | mode: null
18 | options:
19 | matmul-padding: True
20 |
21 | optim:
22 | _target_: torch.optim.AdamW
23 | lr: 1e-3
24 | weight_decay: 5e-4
25 | amsgrad: True
26 | fused: False
27 |
28 | loss:
29 | _target_: torch.nn.MSELoss # torch.nn.BCELoss
30 |
31 | scheduler:
32 | _target_: experiments.lr_scheduler.WarmupLRScheduler
33 | warmup_steps: 0
34 |
35 | distributed:
36 | world_size: 1
37 | rank: 0
38 | device_ids: null
39 |
40 | load_ckpt: null
41 |
42 | use_amp: False
43 | gradscaler:
44 | enabled: ${use_amp}
45 | autocast:
46 | device_type: cuda
47 | enabled: ${use_amp}
48 | dtype: float16
49 |
50 | clip_grad: True
51 | clip_grad_max_norm: 10.0
52 |
53 | seed: 42
54 | save_path: ./output
55 | wandb:
56 | project: cnn-generalization
57 | entity: null
58 | name: null
59 |
60 | matmul_precision: high
61 | cudnn_benchmark: False
62 |
63 | debug: False
64 |
--------------------------------------------------------------------------------
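Every `_target_` block in this config is meant to be resolved through Hydra's `instantiate` utility. A minimal sketch of the consuming side (illustrative only; the repo's actual `main.py` is not reproduced in this dump):

```python
import hydra
from omegaconf import DictConfig


@hydra.main(config_path="configs", config_name="base", version_base=None)
def main(cfg: DictConfig) -> None:
    # instantiate() imports cfg.<key>._target_ and calls it with the remaining keys,
    # so cfg.optim becomes torch.optim.AdamW(model.parameters(), lr=1e-3, ...).
    model = hydra.utils.instantiate(cfg.model)
    optimizer = hydra.utils.instantiate(cfg.optim, params=model.parameters())
    loss_fn = hydra.utils.instantiate(cfg.loss)  # torch.nn.MSELoss by default


if __name__ == "__main__":
    main()
```

--------------------------------------------------------------------------------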
/experiments/cnn_generalization/configs/data/cnn_park.yaml:
--------------------------------------------------------------------------------
1 | # shared
2 | target: experiments.data_generalization.CNNParkCIFAR10
3 | normalize: False
4 | dataset_dir: dataset
5 | splits_path: cnn_park_splits.json
6 | statistics_path: null
7 | input_channels: 3
8 | num_classes: 10
9 | layer_layout: null
10 | img_shape: [32, 32]
11 | _max_kernel_height: 7
12 | _max_kernel_width: 7
13 | max_kernel_size:
14 | - ${data._max_kernel_height}
15 | - ${data._max_kernel_width}
16 | linear_as_conv: True
17 | flattening_method: null # repeat_nodes or extra_layer
18 | max_spatial_resolution: 49 # 7x7 feature map size
19 | deg: [118980, 0, 0, 607096, 394242, 98840, 0, 0, 390958,
20 | 87468, 0, 0, 0, 0, 0, 0, 381834, 59684,
21 | 0, 0, 0, 0, 0, 0, 0, 0, 0,
22 | 0, 0, 0, 0, 0, 351606, 81132]
23 | data_format: graph
24 | max_num_hidden_layers: 5
25 | inr_model: null
26 |
27 | stats:
28 | weights_mean: null
29 | weights_std: null
30 | biases_mean: null
31 | biases_std: null
32 |
33 | train:
34 | _target_: ${data.target}
35 | _recursive_: True
36 | dataset_dir: ${data.dataset_dir}
37 | splits_path: ${data.splits_path}
38 | split: train
39 | normalize: ${data.normalize}
40 | augmentation: False
41 | statistics_path: ${data.statistics_path}
42 | max_kernel_size: ${data.max_kernel_size}
43 | linear_as_conv: ${data.linear_as_conv}
44 | flattening_method: ${data.flattening_method}
45 | max_num_hidden_layers: ${data.max_num_hidden_layers}
46 | data_format: ${data.data_format}
47 | # num_classes: ${data.num_classes}
48 |
49 | val:
50 | _target_: ${data.target}
51 | dataset_dir: ${data.dataset_dir}
52 | splits_path: ${data.splits_path}
53 | split: val
54 | normalize: ${data.normalize}
55 | augmentation: False
56 | statistics_path: ${data.statistics_path}
57 | max_kernel_size: ${data.max_kernel_size}
58 | linear_as_conv: ${data.linear_as_conv}
59 | flattening_method: ${data.flattening_method}
60 | max_num_hidden_layers: ${data.max_num_hidden_layers}
61 | data_format: ${data.data_format}
62 | # num_classes: ${data.num_classes}
63 |
64 | test:
65 | _target_: ${data.target}
66 | dataset_dir: ${data.dataset_dir}
67 | splits_path: ${data.splits_path}
68 | split: test
69 | normalize: ${data.normalize}
70 | augmentation: False
71 | statistics_path: ${data.statistics_path}
72 | max_kernel_size: ${data.max_kernel_size}
73 | linear_as_conv: ${data.linear_as_conv}
74 | flattening_method: ${data.flattening_method}
75 | max_num_hidden_layers: ${data.max_num_hidden_layers}
76 | data_format: ${data.data_format}
77 | # num_classes: ${data.num_classes}
78 |
79 |
--------------------------------------------------------------------------------
/experiments/cnn_generalization/configs/data/zoo_cifar_nfn.yaml:
--------------------------------------------------------------------------------
1 | target: experiments.data_generalization.NFNZooDataset
2 | data_path: dataset/small-zoo-cifar10
3 | idcs_file: dataset/nfn_cifar10_split.csv
4 | normalize: False
5 | statistics_path: dataset/zoo_cifar_nfn_statistics.pth
6 |
7 | input_channels: 1 # grayscale cifar10
8 | num_classes: 10
9 | layer_layout: [1, 16, 16, 16, 10]
10 | img_shape: [32, 32]
11 | _max_kernel_height: 3
12 | _max_kernel_width: 3
13 | max_kernel_size:
14 | - ${data._max_kernel_height}
15 | - ${data._max_kernel_width}
16 | linear_as_conv: True
17 | flattening_method: null # repeat_nodes or extra_layer or None
18 | max_spatial_resolution: 49 # 7x7 feature map size
19 | deg: [12000, 192000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 504000]
20 | data_format: graph
21 | max_num_hidden_layers: 3
22 |
23 | stats:
24 | # weights_mean: [-0.008810473, -0.019692749, -0.012631954, 0.018839896]
25 | # weights_std: [0.5502305, 0.4353398, 0.3642972, 0.36821017]
26 | # biases_mean: [-0.050750412, 0.006151379, 0.046173226, 0.04843864]
27 | # biases_std: [0.40749395, 0.9723978, 1.9454101, 0.5446171]
28 | weights_mean: null
29 | weights_std: null
30 | biases_mean: null
31 | biases_std: null
32 |
33 | train:
34 | _target_: ${data.target}
35 | _recursive_: True
36 | data_path: ${data.data_path}
37 | idcs_file: ${data.idcs_file}
38 | split: train
39 | augmentation: True
40 | noise_scale: 0.1
41 | drop_rate: 0.01
42 | normalize: ${data.normalize}
43 | max_kernel_size: ${data.max_kernel_size}
44 | linear_as_conv: ${data.linear_as_conv}
45 | flattening_method: ${data.flattening_method}
46 | max_num_hidden_layers: ${model.graph_constructor.max_num_hidden_layers}
47 | data_format: ${data.data_format}
48 |
49 | val:
50 | _target_: ${data.target}
51 | data_path: ${data.data_path}
52 | idcs_file: ${data.idcs_file}
53 | split: val
54 | augmentation: False
55 | normalize: ${data.normalize}
56 | max_kernel_size: ${data.max_kernel_size}
57 | linear_as_conv: ${data.linear_as_conv}
58 | flattening_method: ${data.flattening_method}
59 | max_num_hidden_layers: ${model.graph_constructor.max_num_hidden_layers}
60 | data_format: ${data.data_format}
61 |
62 | test:
63 | _target_: ${data.target}
64 | data_path: ${data.data_path}
65 | idcs_file: ${data.idcs_file}
66 | split: test
67 | augmentation: False
68 | normalize: ${data.normalize}
69 | max_kernel_size: ${data.max_kernel_size}
70 | linear_as_conv: ${data.linear_as_conv}
71 | flattening_method: ${data.flattening_method}
72 | max_num_hidden_layers: ${model.graph_constructor.max_num_hidden_layers}
73 | data_format: ${data.data_format}
74 |
--------------------------------------------------------------------------------
/experiments/cnn_generalization/configs/model/dynamic_stat.yaml:
--------------------------------------------------------------------------------
1 | _target_: nn.dynamic_stat_net.DynamicStatNet
2 |
3 | h_size: 1000
4 | max_kernel_size: ${prod:${data._max_kernel_height}, ${data._max_kernel_width}}
5 | max_num_hidden_layers: ${data.max_num_hidden_layers}
6 | max_kernel_height: ${data._max_kernel_height}
7 | max_kernel_width: ${data._max_kernel_width}
8 |
--------------------------------------------------------------------------------
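`${prod:...}` above is not a built-in OmegaConf resolver, so it has to be registered by the experiment code before the config is composed. A minimal sketch of such a registration (the repo's actual registration site is not shown in this dump):

```python
from omegaconf import OmegaConf

# With data._max_kernel_height = data._max_kernel_width = 7,
# ${prod:7, 7} resolves max_kernel_size to 49.
OmegaConf.register_new_resolver("prod", lambda x, y: x * y)
```

--------------------------------------------------------------------------------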
/experiments/cnn_generalization/configs/model/head_cls/mlp.yaml:
--------------------------------------------------------------------------------
1 | _target_: nn.original_nfn.models.MlpHead
2 | _partial_: true
3 | sigmoid: false
4 | num_out: 1
5 | lnorm: false
6 | pool_mode: HNP
7 | dropout: 0
8 |
--------------------------------------------------------------------------------
/experiments/cnn_generalization/configs/model/nfn.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - head_cls: mlp
3 | _target_: nn.original_nfn.models.InvariantNFN
4 | hchannels: [16, 16, 5]
5 | mode: HNP
6 | feature_dropout: 0.0
7 | normalize: false
8 | lnorm: null
9 | append_stats: false
10 | max_num_hidden_layers: ${data.max_num_hidden_layers}
11 |
12 |
--------------------------------------------------------------------------------
/experiments/cnn_generalization/configs/model/pna.yaml:
--------------------------------------------------------------------------------
1 | _target_: nn.dynamic_gnn.GNNForGeneralization
2 | _recursive_: False
3 | d_out: 1
4 | d_hid: 32
5 | compile: False
6 | rev_edge_features: False
7 | pooling_method: cat
8 | pooling_layer_idx: last # all, last, or 0, 1, ...
9 | jit: False
10 | layer_layout: ${data.layer_layout}
11 |
12 | input_channels: ${data.input_channels}
13 | num_classes: ${data.num_classes}
14 |
15 | gnn_backbone:
16 | _target_: nn.gnn.PNA
17 | _convert_: all
18 | in_channels: ${model.d_hid}
19 | hidden_channels: ${model.d_hid}
20 | out_channels: ${model.d_hid}
21 | num_layers: 4
22 | aggregators: ['mean', 'min', 'max', 'std']
23 | scalers: ['identity', 'amplification']
24 | edge_dim: ${model.d_hid}
25 | dropout: 0.
26 | norm: layernorm
27 | act: silu
28 | deg: ${data.deg}
29 | update_edge_attr: True
30 | modulate_edges: True
31 | gating_edges: False
32 | final_edge_update: False
33 |
34 | graph_constructor:
35 | _target_: nn.dynamic_graph_constructor.GraphConstructor
36 | _recursive_: False
37 | _convert_: all
38 | d_in: 1
39 | d_edge_in: ${prod:${data._max_kernel_height}, ${data._max_kernel_width}}
40 | max_num_hidden_layers: ${data.max_num_hidden_layers}
41 | zero_out_bias: False
42 | zero_out_weights: False
43 | sin_emb: True
44 | sin_emb_dim: 128
45 | use_pos_embed: True
46 | input_layers: 1
47 | inp_factor: 3
48 | num_probe_features: 0
49 | # inr_model: ${data.inr_model}
50 | stats: ${data.stats}
51 | linear_as_conv: ${data.linear_as_conv}
52 | flattening_method: ${data.flattening_method}
53 | max_spatial_resolution: ${data.max_spatial_resolution}
54 | use_act_embed: True
55 |
--------------------------------------------------------------------------------
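The `deg` list interpolated above (`${data.deg}`) is the in-degree histogram that PNA-style layers need for their degree scalers: entry `i` counts training-set nodes with in-degree `i`. A minimal sketch of computing such a histogram with PyTorch Geometric, where `train_dataset` stands in for any iterable of `torch_geometric.data.Data` graphs:

```python
import torch
from torch_geometric.utils import degree

# First pass: find the largest in-degree in the training graphs.
max_degree = -1
for graph in train_dataset:
    d = degree(graph.edge_index[1], num_nodes=graph.num_nodes, dtype=torch.long)
    max_degree = max(max_degree, int(d.max()))

# Second pass: accumulate the histogram (the cnn_park deg list covers 0..33).
deg = torch.zeros(max_degree + 1, dtype=torch.long)
for graph in train_dataset:
    d = degree(graph.edge_index[1], num_nodes=graph.num_nodes, dtype=torch.long)
    deg += torch.bincount(d, minlength=deg.numel())
```

--------------------------------------------------------------------------------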
/experiments/cnn_generalization/configs/model/rtransformer.yaml:
--------------------------------------------------------------------------------
1 | _target_: nn.dynamic_relational_transformer.DynamicRelationalTransformer
2 | _recursive_: False
3 | d_in: 1
4 | d_node: 64
5 | d_edge: 32
6 | d_attn_hid: 128
7 | d_node_hid: 128
8 | d_edge_hid: 64
9 | d_out_hid: 128
10 | d_out: 1
11 | n_layers: 4
12 | n_heads: 8
13 | node_update_type: rt
14 | disable_edge_updates: False
15 | use_cls_token: False
16 | pooling_method: cat
17 | pooling_layer_idx: last # all, last, or 0, 1, ...
18 | dropout: 0.0
19 | rev_edge_features: False
20 | use_ln: True
21 | tfixit_init: False
22 | modulate_v: True
23 | input_channels: ${data.input_channels}
24 | num_classes: ${data.num_classes}
25 | layer_layout: ${data.layer_layout}
26 |
27 | graph_constructor:
28 | _target_: nn.dynamic_graph_constructor.GraphConstructor
29 | _recursive_: False
30 | _convert_: all
31 | d_edge_in: ${prod:${data._max_kernel_height}, ${data._max_kernel_width}}
32 | zero_out_bias: False
33 | zero_out_weights: False
34 | sin_emb: True
35 | sin_emb_dim: 128
36 | use_pos_embed: True
37 | input_layers: 1
38 | inp_factor: 1
39 | num_probe_features: 0
40 | # inr_model: ${data.inr_model}
41 | stats: ${data.stats}
42 | max_num_hidden_layers: ${data.max_num_hidden_layers}
43 | linear_as_conv: ${data.linear_as_conv}
44 | flattening_method: ${data.flattening_method}
45 | max_spatial_resolution: ${data.max_spatial_resolution}
46 | use_act_embed: True
47 |
--------------------------------------------------------------------------------
/experiments/cnn_generalization/configs/model/stat.yaml:
--------------------------------------------------------------------------------
1 | _target_: nn.original_nfn.models.StatNet
2 |
3 | h_size: 1000
4 | dropout: 0.0
5 | sigmoid: true
6 | normalize: false
7 | max_num_hidden_layers: ${data.max_num_hidden_layers}
8 |
--------------------------------------------------------------------------------
/experiments/cnn_generalization/dataset/cnn_sampler.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | import random
3 | from typing import Union
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 |
10 | @dataclass(frozen=True)
11 | class CNNConfig:
12 | n_layers: int
13 | n_classes: int
14 | channels: list[int]
15 | kernel_size: list[int]
16 | stride: list[int]
17 | padding: list[int]
18 | residual: list[int]
19 | activation: list[str]
20 |
21 |
22 | DEFAULT_CONFIG_OPTIONS = {
23 | "n_layers": [2, 3, 4, 5],
24 | "n_classes": 10,
25 | "in_channels": 3,
26 | "channels": [4, 8, 16, 32],
27 | "kernel_size": [3, 5, 7],
28 | "stride": [1],
29 | "activation": ["relu", "gelu", "tanh", "sigmoid", "leaky_relu"],
30 | }
31 |
32 |
33 | ACTIVATION_FN = {
34 | "relu": F.relu,
35 | "gelu": F.gelu,
36 | "tanh": torch.tanh,
37 | "sigmoid": torch.sigmoid,
38 | "leaky_relu": F.leaky_relu,
39 | "none": lambda x: x,
40 | }
41 |
42 |
43 | class CNN(nn.Module):
44 | def __init__(self, cfg: CNNConfig) -> None:
45 | super().__init__()
46 | self._assert_cfg(cfg)
47 | self.cfg = cfg
48 |
49 | self.layers = nn.ModuleList()
50 | for i in range(len(cfg.channels) - 1):
51 | in_channels = cfg.channels[i]
52 | out_channels = cfg.channels[i + 1]
53 | kernel_size = cfg.kernel_size[i]
54 | stride = cfg.stride[i]
55 | padding = cfg.padding[i]
56 | conv = nn.Conv2d(
57 | in_channels,
58 | out_channels,
59 | kernel_size=kernel_size,
60 | padding=padding,
61 | stride=stride,
62 | )
63 | self.layers.append(conv)
64 | self.global_pool = nn.AdaptiveAvgPool2d(1)
65 | self.flatten = nn.Flatten()
66 | self.fc = nn.Linear(cfg.channels[-1], cfg.n_classes)
67 |
68 |     def _assert_cfg(self, cfg: CNNConfig) -> None:
69 |         assert (  # every per-layer list must describe exactly n_layers conv layers
70 |             len(cfg.channels) - 1
71 |             == len(cfg.kernel_size)
72 |             == len(cfg.stride)
73 |             == len(cfg.residual)
74 |             == len(cfg.activation)
75 |             == cfg.n_layers
76 |         )
77 |         assert len(cfg.channels) >= 2  # input channels plus at least one conv layer
78 |         assert cfg.residual[0] == -1  # the first layer cannot receive a residual
79 |         assert cfg.n_layers - 1 not in cfg.residual  # the last output is never a residual source
80 |         for i, r in enumerate(cfg.residual):
81 |             assert r < i - 1 or r < 0  # residuals must skip the directly preceding layer
82 |
83 | def forward(self, x: torch.Tensor) -> torch.Tensor:
84 | residuals = dict()
85 | for i, layer in enumerate(self.layers):
86 | x = layer(x)
87 | if self.cfg.residual[i] > -1:
88 | # shared channels between residual and current layer
89 | ch_res = self.cfg.channels[self.cfg.residual[i] + 1]
90 | ch_out = self.cfg.channels[i + 1]
91 | ch = min(ch_res, ch_out)
92 | x = x.clone()
93 | x[:, :ch] += residuals[self.cfg.residual[i]][:, :ch]
94 | x = ACTIVATION_FN[self.cfg.activation[i]](x)
95 | # ------ x here is the node in the computation graph ------
96 | if i in self.cfg.residual:
97 | # save the residual for later use
98 | residuals[i] = x
99 |
100 | x = self.global_pool(x)
101 | x = self.flatten(x)
102 | x = self.fc(x)
103 | return x
104 |
105 |
106 | def sample_cnn_config(options: Union[dict, None] = None) -> CNNConfig:
107 | if options is None:
108 | options = DEFAULT_CONFIG_OPTIONS
109 |
110 | n_layers = random.choice(options["n_layers"])
111 | n_classes = options["n_classes"]
112 | channels = [options["in_channels"]] + [
113 | random.choice(options["channels"]) for _ in range(n_layers)
114 | ]
115 | kernel_size = [random.choice(options["kernel_size"]) for _ in range(n_layers)]
116 | stride = [random.choice(options["stride"]) for _ in range(n_layers)]
117 | # padding based on (odd) kernel size
118 | padding = [kernel_size[i] // 2 for i in range(n_layers)]
119 | # residuals can come from any previous layer, but not the preceding one
120 | residual = [-1] + [random.choice([-1, *range(i)]) for i in range(n_layers - 1)]
121 | activation = [random.choice(options["activation"]) for _ in range(n_layers)]
122 | return CNNConfig(
123 | n_layers=n_layers,
124 | n_classes=n_classes,
125 | channels=channels,
126 | kernel_size=kernel_size,
127 | stride=stride,
128 | padding=padding,
129 | residual=residual,
130 | activation=activation,
131 | )
132 |
--------------------------------------------------------------------------------
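A quick way to sanity-check the sampler above: draw a random architecture, build the `CNN`, and push a dummy CIFAR-sized batch through it. A usage sketch assuming the default options:

```python
import torch

from experiments.cnn_generalization.dataset.cnn_sampler import CNN, sample_cnn_config

cfg = sample_cnn_config()      # random architecture from DEFAULT_CONFIG_OPTIONS
model = CNN(cfg)               # residual wiring is validated by _assert_cfg
x = torch.randn(2, 3, 32, 32)  # dummy batch of CIFAR-10-sized images
logits = model(x)
assert logits.shape == (2, cfg.n_classes)
```

--------------------------------------------------------------------------------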
/experiments/cnn_generalization/dataset/cnn_trainer.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import hydra
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from omegaconf import OmegaConf
8 | from ray.tune import Trainable
9 |
10 | from experiments.cnn_generalization.dataset.cnn_sampler import CNN
11 | from experiments.cnn_generalization.dataset.fast_tensor_dataloader import (
12 | FastTensorDataLoader,
13 | )
14 |
15 |
16 | class NN_tune_trainable(Trainable):
17 | def setup(self, cfg: dict):
18 | self.cfg = OmegaConf.create(cfg)
19 |
20 | dataset = torch.load(self.cfg.data.efficient_dataset_path)
21 | self.trainset = dataset["trainset"]
22 | self.testset = dataset["testset"]
23 | self.valset = dataset.get("valset", None)
24 |
25 | # instantiate Tensordatasets
26 | self.trainloader = FastTensorDataLoader(
27 | dataset=self.trainset,
28 | batch_size=self.cfg.batch_size,
29 | shuffle=True,
30 | # num_workers=self.cfg.num_workers,
31 | )
32 | self.testloader = FastTensorDataLoader(
33 | dataset=self.testset, batch_size=len(self.testset), shuffle=False
34 | )
35 |
36 | self.steps_per_epoch = len(self.trainloader)
37 |
38 | # init model
39 | self.model = CNN(self.cfg.model).to(self.cfg.device)
40 | self.optimizer = hydra.utils.instantiate(
41 | self.cfg.optimizer, params=self.model.parameters()
42 | )
43 |
44 | # run first test epoch and log results
45 | self._iteration = -1
46 |
47 | def step(self):
48 | # here, all manual writers are disabled. tune takes care of that
49 | # run one training epoch
50 | train(self.model, self.optimizer, self.trainloader, self.cfg.device, 1)
51 | # run one test epoch
52 | test_results = evaluate(self.model, self.testloader, self.cfg.device)
53 |
54 | result_dict = {
55 | **{"test/" + k: v for k, v in test_results.items()},
56 | }
57 | # if self.valset is not None:
58 | # pass
59 | self.stats = result_dict
60 |
61 | return result_dict
62 |
63 | def save_checkpoint(self, tmp_checkpoint_dir):
64 | # define checkpoint path
65 | path = Path(tmp_checkpoint_dir) / "checkpoint.pt"
66 | torch.save(
67 | {
68 | "model": self.model.state_dict(),
69 | "optimizer": self.optimizer.state_dict(),
70 | "config": self.cfg.model,
71 | **self.get_state(),
72 | },
73 | path,
74 | )
75 |
76 | # tune apparently expects to return the directory
77 | return tmp_checkpoint_dir
78 |
79 | def load_checkpoint(self, tmp_checkpoint_dir):
80 | # define checkpoint path
81 | path = Path(tmp_checkpoint_dir) / "checkpoint.pt"
82 | # save model state dict
83 | checkpoint = torch.load(path)
84 | self.model.load_state_dict(checkpoint["model"])
85 | # load optimizer
86 | try:
87 | # opt_dict = torch.load(path / "optimizer")
88 | self.optimizer.load_state_dict(checkpoint["optimizer"])
89 |         except Exception:
90 | print(f"Could not load optimizer state_dict. (not found at path {path})")
91 |
92 | def reset_config(self, new_config):
93 | success = False
94 | try:
95 | print(
96 | "### warning: reuse actors / reset_config only if the dataset remains exactly the same. \n ### only dataloader and model are reconfiugred"
97 | )
98 |             self.cfg = OmegaConf.create(new_config)
99 |
100 | # init model
101 |             self.model = CNN(self.cfg.model).to(self.cfg.device)
102 |
103 |             # instantiate Tensordatasets
104 | self.trainloader = FastTensorDataLoader(
105 | dataset=self.trainset,
106 | batch_size=self.cfg.batch_size,
107 | shuffle=True,
108 | )
109 | self.testloader = FastTensorDataLoader(
110 | dataset=self.testset, batch_size=len(self.testset), shuffle=False
111 | )
112 |
113 |             # drop initial checkpoint
114 | self.save()
115 |
116 | # run first test epoch and log results
117 | self._iteration = -1
118 |
119 | # if we got to this point:
120 | success = True
121 |
122 | except Exception as e:
123 | print(e)
124 |
125 | return success
126 |
127 |
128 | def train(
129 | model: nn.Module,
130 | optimizer: torch.optim.Optimizer,
131 | train_loader: torch.utils.data.DataLoader,
132 | device: torch.device,
133 | epochs: int,
134 | ) -> None:
135 | model.train()
136 | model.to(device)
137 | for e in range(epochs):
138 | for batch_idx, (data, target) in enumerate(train_loader):
139 | data, target = data.to(device), target.to(device)
140 | optimizer.zero_grad()
141 | output = model(data)
142 | loss = F.cross_entropy(output, target)
143 | loss.backward()
144 | optimizer.step()
145 |
146 |
147 | def evaluate(
148 | model: nn.Module, test_loader: torch.utils.data.DataLoader, device: torch.device
149 | ) -> dict:
150 | model.eval()
151 | model.to(device)
152 | correct = 0
153 | loss = 0
154 | with torch.no_grad():
155 | for data, target in test_loader:
156 | data, target = data.to(device), target.to(device)
157 | output = model(data)
158 | pred = output.argmax(dim=1)
159 | correct += pred.eq(target).sum().item()
160 | loss += F.cross_entropy(output, target, reduction="sum").item()
161 |
162 | return {
163 | "acc": correct / len(test_loader.dataset),
164 | "loss": loss / len(test_loader.dataset),
165 | }
166 |
--------------------------------------------------------------------------------
/experiments/cnn_generalization/dataset/fast_tensor_dataloader.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class FastTensorDataLoader:
5 | """
6 | A DataLoader-like object for a set of tensors that can be much faster than
7 | TensorDataset + DataLoader because dataloader grabs individual indices of
8 | the dataset and calls cat (slow).
9 | """
10 |
11 | def __init__(self, dataset, batch_size=32, shuffle=False):
12 | """
13 | Initialize a FastTensorDataLoader.
14 |
15 |         :param dataset: a TensorDataset holding the tensors to load. All tensors must have the same length @ dim 0.
16 | :param batch_size: batch size to load.
17 | :param shuffle: if True, shuffle the data *in-place* whenever an
18 | iterator is created out of this object.
19 |
20 | :returns: A FastTensorDataLoader.
21 | """
22 | self.dataset = dataset
23 | assert all(
24 | t.shape[0] == self.dataset.tensors[0].shape[0] for t in self.dataset.tensors
25 | )
26 | self.tensors = self.dataset.tensors
27 |
28 | self.dataset_len = self.tensors[0].shape[0]
29 | self.device = self.tensors[0].device
30 | self.batch_size = batch_size
31 | self.shuffle = shuffle
32 |
33 | # Calculate # batches
34 | n_batches, remainder = divmod(self.dataset_len, self.batch_size)
35 | if remainder > 0:
36 | n_batches += 1
37 | self.n_batches = n_batches
38 |
39 | def __iter__(self):
40 | if self.shuffle:
41 | self.indices = torch.randperm(self.dataset_len, device=self.device)
42 | else:
43 | self.indices = None
44 | self.i = 0
45 | return self
46 |
47 | def __next__(self):
48 | if self.batch_size == self.dataset_len:
49 | # check if this is the first full batch
50 | if self.i == 0:
51 | # raise counter
52 | self.i = 1
53 | return self.tensors
54 | else:
55 | raise StopIteration
56 | else:
57 | if self.i >= self.dataset_len:
58 | raise StopIteration
59 | if self.indices is not None:
60 | indices = self.indices[self.i : self.i + self.batch_size]
61 | batch = tuple(torch.index_select(t, 0, indices) for t in self.tensors)
62 | else:
63 | batch = tuple(
64 | t[self.i : self.i + self.batch_size] for t in self.tensors
65 | )
66 | self.i += self.batch_size
67 | return batch
68 |
69 | def __len__(self):
70 | return self.n_batches
71 |
--------------------------------------------------------------------------------
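For reference, a minimal usage sketch of `FastTensorDataLoader`: it wraps an existing `TensorDataset` (as `cnn_trainer.py` does) and slices its tensors directly instead of collating per-item samples:

```python
import torch
from torch.utils.data import TensorDataset

from experiments.cnn_generalization.dataset.fast_tensor_dataloader import (
    FastTensorDataLoader,
)

data = torch.randn(1000, 3, 32, 32)
labels = torch.randint(0, 10, (1000,))
loader = FastTensorDataLoader(TensorDataset(data, labels), batch_size=32, shuffle=True)
for x, y in loader:  # 32 batches: 31 full slices plus one remainder
    assert x.shape[0] == y.shape[0]
```

--------------------------------------------------------------------------------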
/experiments/cnn_generalization/dataset/generate_cnn_park_config/base.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - data: cifar10
3 | - _self_
4 |
5 | name: cifar10_zooV3
6 |
7 | cpus: 32
8 | gpus: 1
9 | cpu_per_trial: 8
10 | device: cuda
11 |
12 | cudnn_benchmark: True
13 | matmul_precision: high
14 | num_workers: 8
15 |
16 | random_options:
17 | n_layers: [2, 3, 4, 5, 6, 7, 8]
18 | n_classes: ${data.n_classes}
19 | in_channels: ${data.in_channels}
20 | channels: [4, 8, 16, 32, 64]
21 | kernel_size: [3, 5, 7] # NOTE: We only use odd kernels for now
22 | stride: [1] # NOTE: We only use stide 1 for now
23 | activation: ['relu', 'gelu', 'tanh', 'sigmoid', 'leaky_relu', 'none']
24 |
25 | num_epochs: 200
26 | ckpt_freq: 10
27 |
28 | seed: 1
29 | num_models: 50_000
30 | model: null
31 |
32 | batch_size: 512
33 | optimizer:
34 | _target_: torch.optim.AdamW
35 | lr: 0.001
36 | weight_decay: 0.0
37 |
38 | wandb:
39 | project: cnn-park
40 | # entity: null
41 | # name: null
42 | # log_config: True
43 |
44 | hydra:
45 | output_subdir: null
46 |
--------------------------------------------------------------------------------
/experiments/cnn_generalization/dataset/generate_cnn_park_config/data/cifar10.yaml:
--------------------------------------------------------------------------------
1 | root: cifar10
2 | efficient_dataset_path: cifar10/dataset.pt
3 | n_classes: 10
4 | in_channels: 3
5 | dataset_seed: 0
6 | train:
7 | _target_: torchvision.datasets.CIFAR10
8 | root: ${data.root}
9 | train: True
10 | download: True
11 | transform:
12 | _target_: torchvision.transforms.Compose
13 | transforms:
14 | - _target_: torchvision.transforms.ToTensor
15 | - _target_: torchvision.transforms.Normalize
16 | mean: [0.49139968, 0.48215841, 0.44653091]
17 | std: [0.24703223, 0.24348513, 0.26158784]
18 | test:
19 | _target_: torchvision.datasets.CIFAR10
20 | root: ${data.root}
21 | train: False
22 | download: True
23 | transform:
24 | _target_: torchvision.transforms.Compose
25 | transforms:
26 | - _target_: torchvision.transforms.ToTensor
27 | - _target_: torchvision.transforms.Normalize
28 | mean: [0.49139968, 0.48215841, 0.44653091]
29 | std: [0.24703223, 0.24348513, 0.26158784]
30 |
--------------------------------------------------------------------------------
/experiments/cnn_generalization/dataset/generate_cnn_park_config/data/fmnist.yaml:
--------------------------------------------------------------------------------
1 | root: fashion_mnist
2 | efficient_dataset_path: fashion_mnist/dataset.pt
3 | n_classes: 10
4 | in_channels: 1
5 | dataset_seed: 0
6 | train:
7 | _target_: torchvision.datasets.FashionMNIST
8 | root: ${data.root}
9 | train: True
10 | download: True
11 | transform:
12 | _target_: torchvision.transforms.Compose
13 | transforms:
14 | - _target_: torchvision.transforms.ToTensor
15 | - _target_: torchvision.transforms.Normalize
16 | mean: [0.2860405969887955]
17 | std: [0.35302424451492237]
18 | test:
19 | _target_: torchvision.datasets.FashionMNIST
20 | root: ${data.root}
21 | train: False
22 | download: True
23 | transform:
24 | _target_: torchvision.transforms.Compose
25 | transforms:
26 | - _target_: torchvision.transforms.ToTensor
27 | - _target_: torchvision.transforms.Normalize
28 | mean: [0.2860405969887955]
29 | std: [0.35302424451492237]
30 |
--------------------------------------------------------------------------------
/experiments/cnn_generalization/dataset/generate_cnn_park_splits.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | from argparse import ArgumentParser
4 | from pathlib import Path
5 | from collections import defaultdict
6 | from itertools import groupby
7 |
8 | import numpy as np
9 | import torch
10 |
11 | from sklearn.model_selection import train_test_split
12 |
13 | from experiments.utils import common_parser, set_logger
14 |
15 |
16 | def generate_splits(
17 | data_path,
18 | save_path,
19 | name="cnn_park_splits.json",
20 | val_size=10000,
21 | test_size=10000,
22 | seed=42,
23 | ):
24 | script_dir = Path(__file__).parent
25 | data_path = script_dir / Path(data_path)
26 | # We have to sort the files to make sure that the order between checkpoints
27 | # and progresses is the same. We will randomize later.
28 | checkpoints = sorted(data_path.glob("cifar10_zooV2/*/*/checkpoint.pt"))
29 | checkpoint_parents = sorted(list(set([c.parent.parent for c in checkpoints])))
30 | progresses = {
31 | ckpt.as_posix(): torch.load(ckpt, map_location="cpu")["last_result"]["test/acc"]
32 | for ckpt in checkpoints
33 | }
34 |
35 | checkpoint_steps = {
36 | ckpt.as_posix(): torch.load(ckpt, map_location="cpu")["iteration"]
37 | for ckpt in checkpoints
38 | }
39 | print(
40 | len(checkpoint_steps),
41 | len(progresses),
42 | len(checkpoint_parents),
43 | len(checkpoints),
44 | )
45 |
46 | trainval_indices, test_indices = train_test_split(
47 | range(len(checkpoint_parents)), test_size=test_size, random_state=seed
48 | )
49 | train_indices, val_indices = train_test_split(
50 | trainval_indices, test_size=val_size, random_state=seed
51 | )
52 | grouped_checkpoints = [
53 | list(g) for _, g in groupby(checkpoints, lambda x: x.parent.parent)
54 | ]
55 |
56 | data_split = defaultdict(lambda: defaultdict(list))
57 | data_split["train"]["path"] = sum(
58 | [
59 | [
60 | ckpt.relative_to(script_dir).as_posix()
61 | for ckpt in grouped_checkpoints[idx]
62 | ]
63 | for idx in train_indices
64 | ],
65 | [],
66 | )
67 | data_split["train"]["score"] = sum(
68 | [
69 | [progresses[str(ckpt)] for ckpt in grouped_checkpoints[idx]]
70 | for idx in train_indices
71 | ],
72 | [],
73 | )
74 | data_split["train"]["step"] = sum(
75 | [
76 | [checkpoint_steps[str(ckpt)] for ckpt in grouped_checkpoints[idx]]
77 | for idx in train_indices
78 | ],
79 | [],
80 | )
81 | permutation = np.random.permutation(len(data_split["train"]["path"]))
82 | data_split["train"]["path"] = [
83 | data_split["train"]["path"][idx] for idx in permutation
84 | ]
85 | data_split["train"]["score"] = [
86 | data_split["train"]["score"][idx] for idx in permutation
87 | ]
88 | data_split["train"]["step"] = [
89 | data_split["train"]["step"][idx] for idx in permutation
90 | ]
91 |
92 | data_split["val"]["path"] = sum(
93 | [
94 | [
95 | ckpt.relative_to(script_dir).as_posix()
96 | for ckpt in grouped_checkpoints[idx]
97 | ]
98 | for idx in val_indices
99 | ],
100 | [],
101 | )
102 | data_split["val"]["score"] = sum(
103 | [
104 | [progresses[str(ckpt)] for ckpt in grouped_checkpoints[idx]]
105 | for idx in val_indices
106 | ],
107 | [],
108 | )
109 | data_split["val"]["step"] = sum(
110 | [
111 | [checkpoint_steps[str(ckpt)] for ckpt in grouped_checkpoints[idx]]
112 | for idx in val_indices
113 | ],
114 | [],
115 | )
116 | permutation = np.random.permutation(len(data_split["val"]["path"]))
117 | data_split["val"]["path"] = [data_split["val"]["path"][idx] for idx in permutation]
118 | data_split["val"]["score"] = [
119 | data_split["val"]["score"][idx] for idx in permutation
120 | ]
121 | data_split["val"]["step"] = [data_split["val"]["step"][idx] for idx in permutation]
122 |
123 | data_split["test"]["path"] = sum(
124 | [
125 | [
126 | ckpt.relative_to(script_dir).as_posix()
127 | for ckpt in grouped_checkpoints[idx]
128 | ]
129 | for idx in test_indices
130 | ],
131 | [],
132 | )
133 | data_split["test"]["score"] = sum(
134 | [
135 | [progresses[str(ckpt)] for ckpt in grouped_checkpoints[idx]]
136 | for idx in test_indices
137 | ],
138 | [],
139 | )
140 | data_split["test"]["step"] = sum(
141 | [
142 | [checkpoint_steps[str(ckpt)] for ckpt in grouped_checkpoints[idx]]
143 | for idx in test_indices
144 | ],
145 | [],
146 | )
147 | permutation = np.random.permutation(len(data_split["test"]["path"]))
148 | data_split["test"]["path"] = [
149 | data_split["test"]["path"][idx] for idx in permutation
150 | ]
151 | data_split["test"]["score"] = [
152 | data_split["test"]["score"][idx] for idx in permutation
153 | ]
154 | data_split["test"]["step"] = [
155 | data_split["test"]["step"][idx] for idx in permutation
156 | ]
157 |
158 | logging.info(
159 | f"train size: {len(data_split['train']['path'])}, "
160 | f"val size: {len(data_split['val']['path'])}, "
161 | f"test size: {len(data_split['test']['path'])}"
162 | f"train score size: {len(data_split['train']['score'])}, "
163 | f"val score size: {len(data_split['val']['score'])}, "
164 | f"test score size: {len(data_split['test']['score'])}"
165 | f"train step size: {len(data_split['train']['step'])}, "
166 | f"val step size: {len(data_split['val']['step'])}, "
167 | f"test step size: {len(data_split['test']['step'])}"
168 | )
169 |
170 | save_path = script_dir / Path(save_path) / name
171 | with open(save_path, "w") as file:
172 | json.dump(data_split, file)
173 |
174 |
175 | if __name__ == "__main__":
176 | parser = ArgumentParser(
177 | "CNN Generalization - generate data splits", parents=[common_parser]
178 | )
179 | parser.add_argument(
180 | "--name", type=str, default="cnn_park_splits.json", help="json file name"
181 | )
182 | parser.add_argument(
183 | "--val-size", type=int, default=25, help="number of validation examples"
184 | )
185 | parser.add_argument(
186 | "--test-size", type=int, default=50, help="number of test examples"
187 | )
188 | parser.set_defaults(
189 | save_path=".",
190 | data_path=".",
191 | )
192 | args = parser.parse_args()
193 |
194 | set_logger()
195 |
196 | generate_splits(
197 | args.data_path,
198 | args.save_path,
199 | name=args.name,
200 | val_size=args.val_size,
201 | test_size=args.test_size,
202 | )
203 |
--------------------------------------------------------------------------------
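Assuming the `cifar10_zooV2` checkpoint tree sits next to this script (both `save_path` and `data_path` default to `.` via `set_defaults`), generating the split file reduces to an invocation like the following, using only the flags defined above:

```sh
python generate_cnn_park_splits.py --name cnn_park_splits.json --val-size 25 --test-size 50
```

--------------------------------------------------------------------------------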
/experiments/cnn_generalization/dataset/train_cnn_park.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import hydra
4 | from omegaconf import DictConfig, OmegaConf
5 | import ray
6 | from ray import tune
7 | from ray import air
8 | from ray.air.integrations.wandb import WandbLoggerCallback
9 | import torch
10 |
11 | from experiments.cnn_generalization.dataset import cnn_sampler
12 | from experiments.cnn_generalization.dataset.cnn_trainer import NN_tune_trainable
13 |
14 |
15 | def prepare_dataset(cfg: DictConfig):
16 | """
17 | partially from https://github.com/ModelZoos/ModelZooDataset/blob/main/code/zoo_generators/train_zoo_f_mnist_uniform.py
18 | """
19 | data_path = Path(cfg.efficient_dataset_path).expanduser().resolve()
20 | if not data_path.exists():
21 | Path(cfg.root).expanduser().resolve().mkdir(parents=True, exist_ok=True)
22 | data_path.parent.mkdir(parents=True, exist_ok=True)
23 | val_and_trainset_raw = hydra.utils.instantiate(cfg.train)
24 | testset_raw = hydra.utils.instantiate(cfg.test)
25 | trainset_raw, valset_raw = torch.utils.data.random_split(
26 | val_and_trainset_raw,
27 | [len(val_and_trainset_raw) - 1, 1],
28 | generator=torch.Generator().manual_seed(cfg.dataset_seed),
29 | )
30 |
31 | # temp dataloaders
32 | trainloader_raw = torch.utils.data.DataLoader(
33 | dataset=trainset_raw, batch_size=len(trainset_raw), shuffle=True
34 | )
35 | valloader_raw = torch.utils.data.DataLoader(
36 | dataset=valset_raw, batch_size=len(valset_raw), shuffle=True
37 | )
38 | testloader_raw = torch.utils.data.DataLoader(
39 | dataset=testset_raw, batch_size=len(testset_raw), shuffle=True
40 | )
41 | # one forward pass
42 | assert (
43 | trainloader_raw.__len__() == 1
44 | ), "temp trainloader has more than one batch"
45 | for train_data, train_labels in trainloader_raw:
46 | pass
47 | assert valloader_raw.__len__() == 1, "temp valloader has more than one batch"
48 | for val_data, val_labels in valloader_raw:
49 | pass
50 | assert testloader_raw.__len__() == 1, "temp testloader has more than one batch"
51 | for test_data, test_labels in testloader_raw:
52 | pass
53 |
54 | trainset = torch.utils.data.TensorDataset(train_data, train_labels)
55 | valset = torch.utils.data.TensorDataset(val_data, val_labels)
56 | testset = torch.utils.data.TensorDataset(test_data, test_labels)
57 |
58 | # save dataset and seed in data directory
59 | dataset = {
60 | "trainset": trainset,
61 | "valset": valset,
62 | "testset": testset,
63 | "dataset_seed": cfg.dataset_seed,
64 | }
65 | torch.save(dataset, data_path)
66 |
67 |
68 | @hydra.main(
69 | config_path="generate_cnn_park_config", config_name="base", version_base=None
70 | )
71 | def main(cfg: DictConfig):
72 | torch.backends.cudnn.benchmark = cfg.cudnn_benchmark
73 | torch.set_float32_matmul_precision(cfg.matmul_precision)
74 | torch.manual_seed(cfg.seed)
75 |
76 |     # Resolve the relative path now (as a str; OmegaConf rejects Path objects)
77 |     cfg.data.efficient_dataset_path = (
78 |         Path(cfg.data.efficient_dataset_path).expanduser().resolve().as_posix()
79 |     )
80 |
81 | prepare_dataset(cfg.data)
82 |
83 | ray.init(
84 | num_cpus=cfg.cpus,
85 | num_gpus=cfg.gpus,
86 | )
87 |
88 | gpu_fraction = ((cfg.gpus * 100) // (cfg.cpus / cfg.cpu_per_trial)) / 100
89 | resources_per_trial = {"cpu": cfg.cpu_per_trial, "gpu": gpu_fraction}
90 |
91 |     assert ray.is_initialized()
92 |
93 | # create tune config
94 | tune_config = OmegaConf.to_container(cfg, resolve=True)
95 | model_configs = []
96 | for _ in range(cfg.num_models):
97 | model_configs.append(cnn_sampler.sample_cnn_config(cfg.random_options))
98 | tune_config["model"] = tune.grid_search(model_configs)
99 |
100 | # run tune trainable experiment
101 | analysis = tune.run(
102 | NN_tune_trainable,
103 | name=cfg.name,
104 | stop={
105 | "training_iteration": cfg.num_epochs,
106 | },
107 | checkpoint_config=air.CheckpointConfig(checkpoint_frequency=cfg.ckpt_freq),
108 | config=tune_config,
109 | local_dir=Path(cfg.data.root).expanduser().resolve().as_posix(),
110 | callbacks=[WandbLoggerCallback(**cfg.wandb)],
111 | reuse_actors=False,
112 |         # resume="ERRORED_ONLY",  # restart only the trials that errored in a previous run
113 |         # resume="LOCAL",  # resume all trials of a previous run from local checkpoints
114 |         resume=False,  # start the run from scratch
115 | resources_per_trial=resources_per_trial,
116 | verbose=3,
117 | )
118 |
119 | ray.shutdown()
120 |     assert not ray.is_initialized()
121 |
122 |
123 | if __name__ == "__main__":
124 | main()
125 |
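126 | # Example invocation (hypothetical resource settings): with gpus=1, cpus=8
127 | # and cpu_per_trial=2, four trials run concurrently and each is assigned
128 | # gpu_fraction = ((1 * 100) // (8 / 2)) / 100 = 0.25 of a GPU:
129 | #
130 | #   python dataset/train_cnn_park.py cpus=8 gpus=1 cpu_per_trial=2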
--------------------------------------------------------------------------------
/experiments/cnn_generalization/dataset/zoo_cifar_nfn_statistics.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mkofinas/neural-graphs/1f2b671ab4988ef212469363005a5b99eec16580/experiments/cnn_generalization/dataset/zoo_cifar_nfn_statistics.pth
--------------------------------------------------------------------------------
/experiments/cnn_generalization/scripts/cnn_park_pna.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | extra_args="$@"
4 | seeds=(0 1 2)
5 |
6 | for seed in "${seeds[@]}"
7 | do
8 | python -u main.py seed=$seed model=pna data=cnn_park distributed=False \
9 | eval_every=1000 n_epochs=20 batch_size=128 loss._target_=torch.nn.BCELoss \
10 | model.d_hid=64 model.pooling_method=cat model.pooling_layer_idx=last \
11 | wandb.name=cnn_generalization_cnn_park_pna_seed_${seed}_epoch_20_bce \
12 | $extra_args
13 | done
14 |
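15 | # Extra command-line arguments are forwarded to main.py as Hydra overrides,
16 | # e.g. (hypothetical override): ./scripts/cnn_park_pna.sh n_epochs=10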
--------------------------------------------------------------------------------
/experiments/cnn_generalization/scripts/cnn_park_pna_no_act.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | extra_args="$@"
4 | seeds=(0 1 2)
5 |
6 | for seed in "${seeds[@]}"
7 | do
8 | python -u main.py seed=$seed model=pna data=cnn_park distributed=False \
9 | eval_every=1000 n_epochs=20 batch_size=128 loss._target_=torch.nn.BCELoss \
10 |     model.d_hid=64 model.pooling_method=cat model.pooling_layer_idx=last model.use_act_embed=False \
11 | wandb.name=cnn_generalization_cnn_park_pna_seed_${seed}_epoch_20_no_act_bce \
12 | $extra_args
13 | done
14 |
--------------------------------------------------------------------------------
/experiments/cnn_generalization/scripts/cnn_park_rt.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | extra_args="$@"
4 | seeds=(0 1 2)
5 |
6 | for seed in "${seeds[@]}"
7 | do
8 | python -u main.py seed=$seed model=rtransformer data=cnn_park distributed=False \
9 | eval_every=1000 n_epochs=20 batch_size=128 loss._target_=torch.nn.BCELoss \
10 | model.d_node=32 model.d_node_hid=64 model.d_edge=16 model.d_edge_hid=64 model.d_attn_hid=32 \
11 | model.d_out_hid=64 model.n_heads=4 model.n_layers=3 model.dropout=0.2 \
12 | model.pooling_method=cat model.pooling_layer_idx=last \
13 | wandb.name=cnn_generalization_cnn_park_rt_seed_${seed}_epoch_20_drop_0.2_bce \
14 | $extra_args
15 | done
16 |
--------------------------------------------------------------------------------
/experiments/cnn_generalization/scripts/cnn_park_rt_no_act.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | extra_args="$@"
4 | seeds=(0 1 2)
5 |
6 | for seed in "${seeds[@]}"
7 | do
8 | python -u main.py seed=$seed model=rtransformer data=cnn_park distributed=False \
9 | eval_every=1000 n_epochs=20 batch_size=128 loss._target_=torch.nn.BCELoss \
10 | model.d_node=32 model.d_node_hid=64 model.d_edge=16 model.d_edge_hid=64 model.d_attn_hid=32 \
11 | model.d_out_hid=64 model.n_heads=4 model.n_layers=3 model.dropout=0.2 \
12 | model.pooling_method=cat model.pooling_layer_idx=last model.use_act_embed=False \
13 | wandb.name=cnn_generalization_cnn_park_rt_seed_${seed}_epoch_20_drop_0.2_no_act_bce \
14 | $extra_args
15 | done
16 |
--------------------------------------------------------------------------------
/experiments/cnn_generalization/scripts/cnn_zoo_pna.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | extra_args="$@"
4 | seeds=(0 1 2)
5 |
6 | for seed in "${seeds[@]}"
7 | do
8 | python -u main.py seed=$seed model=pna data=zoo_cifar_nfn distributed=False \
9 | loss._target_=torch.nn.MSELoss batch_size=128 n_epochs=300 \
10 | data.train.augmentation=True data.linear_as_conv=False model.d_hid=256 \
11 | wandb.name=cnn_generalization_cnn_zoo_pna_seed_${seed}_epoch_300_mse_b_128_d_256_no_lac_aug $extra_args
12 | done
13 |
14 |
--------------------------------------------------------------------------------
/experiments/cnn_generalization/scripts/cnn_zoo_rt.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | extra_args="$@"
4 | seeds=(0 1 2)
5 |
6 | for seed in "${seeds[@]}"
7 | do
8 | python -u main.py seed=$seed model=rtransformer data=zoo_cifar_nfn distributed=False \
9 | loss._target_=torch.nn.BCELoss batch_size=192 n_epochs=300 \
10 | model.d_node=128 model.d_edge=64 model.d_attn_hid=256 model.d_node_hid=256 \
11 | model.d_edge_hid=128 model.d_out_hid=256 \
12 | wandb.name=cnn_generalization_cnn_zoo_rt_seed_${seed}_epoch_300_bce_double_b_192 $extra_args
13 | done
14 |
--------------------------------------------------------------------------------
/experiments/cnn_generalization/sweep_configs/sweep_cnn_park_gnn.yaml:
--------------------------------------------------------------------------------
1 | program: main.py
2 | project: cnn-generalization
3 | entity: neural-graphs
4 | method: bayes
5 | metric:
6 | goal: maximize
7 | name: test/best_tau
8 | parameters:
9 | data:
10 | value: cnn_park
11 | model:
12 | value: pna
13 | eval_every:
14 | value: 1000
15 | n_epochs:
16 | value: 5
17 | loss._target_:
18 | values:
19 | - torch.nn.BCELoss
20 | - torch.nn.MSELoss
21 | distributed:
22 | value: False
23 | batch_size:
24 | values:
25 | - 8
26 | - 32
27 | - 128
28 | model.d_hid:
29 | values:
30 | - 8
31 | - 16
32 | - 32
33 | - 64
34 | - 128
35 | model.gnn_backbone.num_layers:
36 | values:
37 | - 2
38 | - 3
39 | - 4
40 | model.pooling_method:
41 | values:
42 | - mean
43 | - cat
44 | model.pooling_layer_idx:
45 | value: last
46 |
47 | command:
48 | - ${env}
49 | - ${interpreter}
50 | - ${program}
51 | - ${args_no_hyphens}
52 |
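53 | # Launch sketch (standard W&B CLI; run from `experiments/cnn_generalization`,
54 | # `<sweep_id>` is printed by `wandb sweep`):
55 | #   wandb sweep sweep_configs/sweep_cnn_park_gnn.yaml
56 | #   wandb agent neural-graphs/cnn-generalization/<sweep_id>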
--------------------------------------------------------------------------------
/experiments/cnn_generalization/sweep_configs/sweep_cnn_park_statnn.yaml:
--------------------------------------------------------------------------------
1 | program: main.py
2 | project: cnn-generalization
3 | entity: neural-graphs
4 | method: bayes
5 | metric:
6 | goal: maximize
7 | name: test/best_tau
8 | parameters:
9 | data:
10 | value: cnn_park
11 | model:
12 | value: dynamic_stat
13 | data.data_format:
14 | value: stat
15 | eval_every:
16 | value: 1000
17 | n_epochs:
18 | value: 5
19 | loss._target_:
20 | values:
21 | - torch.nn.BCELoss
22 | - torch.nn.MSELoss
23 | distributed:
24 | value: False
25 | batch_size:
26 | values:
27 | - 8
28 | - 32
29 | - 128
30 | model.h_size:
31 | values:
32 | - 8
33 | - 16
34 | - 128
35 | - 512
36 | - 1000
37 |
38 | command:
39 | - ${env}
40 | - ${interpreter}
41 | - ${program}
42 | - ${args_no_hyphens}
43 |
--------------------------------------------------------------------------------
/experiments/cnn_generalization/sweep_configs/sweep_cnn_park_transformer.yaml:
--------------------------------------------------------------------------------
1 | program: main.py
2 | project: cnn-generalization
3 | entity: neural-graphs
4 | method: bayes
5 | metric:
6 | goal: maximize
7 | name: test/best_tau
8 | parameters:
9 | data:
10 | value: cnn_park
11 | model:
12 | value: rtransformer
13 | eval_every:
14 | value: 1000
15 | n_epochs:
16 | value: 5
17 | loss._target_:
18 | values:
19 | - torch.nn.BCELoss
20 | - torch.nn.MSELoss
21 | distributed:
22 | value: False
23 | batch_size:
24 | values:
25 | - 8
26 | - 32
27 | - 128
28 | model.dropout:
29 | values:
30 | - 0.0
31 | - 0.2
32 | model.d_node:
33 | values:
34 | - 8
35 | - 16
36 | - 64
37 | model.d_edge:
38 | values:
39 | - 8
40 | - 16
41 | - 32
42 | model.d_attn_hid:
43 | values:
44 | - 16
45 | - 32
46 | - 128
47 | model.d_node_hid:
48 | values:
49 | - 16
50 | - 32
51 | - 128
52 | model.d_edge_hid:
53 | values:
54 | - 8
55 | - 16
56 | - 64
57 | model.d_out_hid:
58 | values:
59 | - 8
60 | - 16
61 | - 128
62 | model.n_layers:
63 | values:
64 | - 2
65 | - 3
66 | - 4
67 | model.n_heads:
68 | values:
69 | - 1
70 | - 2
71 | - 4
72 | model.pooling_method:
73 | values:
74 | - mean
75 | - cat
76 | model.pooling_layer_idx:
77 | value: last
78 |
79 | command:
80 | - ${env}
81 | - ${interpreter}
82 | - ${program}
83 | - ${args_no_hyphens}
84 |
--------------------------------------------------------------------------------
/experiments/cnn_generalization/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch_geometric.data import Data
4 |
5 |
6 | def pad_and_flatten_kernel(kernel, max_kernel_size):
7 | full_padding = (
8 | max_kernel_size[0] - kernel.shape[2],
9 | max_kernel_size[1] - kernel.shape[3],
10 | )
11 | padding = (
12 | full_padding[0] // 2,
13 | full_padding[0] - full_padding[0] // 2,
14 | full_padding[1] // 2,
15 | full_padding[1] - full_padding[1] // 2,
16 | )
17 | return F.pad(kernel, padding).flatten(2, 3)
18 |
19 |
20 | def cnn_to_graph(
21 | weights,
22 | biases,
23 | weights_mean=None,
24 | weights_std=None,
25 | biases_mean=None,
26 | biases_std=None,
27 | ):
28 | weights_mean = weights_mean if weights_mean is not None else [0.0] * len(weights)
29 | weights_std = weights_std if weights_std is not None else [1.0] * len(weights)
30 | biases_mean = biases_mean if biases_mean is not None else [0.0] * len(biases)
31 | biases_std = biases_std if biases_std is not None else [1.0] * len(biases)
32 |
33 | # The graph will have as many nodes as the total number of channels in the
34 | # CNN, plus the number of output dimensions for each linear layer
35 | device = weights[0].device
36 | num_input_nodes = weights[0].shape[0]
37 | num_nodes = num_input_nodes + sum(b.shape[0] for b in biases)
38 |
39 | edge_features = torch.zeros(
40 | num_nodes, num_nodes, weights[0].shape[-1], device=device
41 | )
42 | edge_feature_masks = torch.zeros(
43 | num_nodes, num_nodes, device=device, dtype=torch.bool
44 | )
45 | adjacency_matrix = torch.zeros(
46 | num_nodes, num_nodes, device=device, dtype=torch.bool
47 | )
48 |
49 | row_offset = 0
50 | col_offset = num_input_nodes # no edge to input nodes
51 | for i, w in enumerate(weights):
52 | num_in, num_out = w.shape[:2]
53 | edge_features[
54 | row_offset : row_offset + num_in,
55 | col_offset : col_offset + num_out,
56 | : w.shape[-1],
57 | ] = (w - weights_mean[i]) / weights_std[i]
58 | edge_feature_masks[
59 | row_offset : row_offset + num_in, col_offset : col_offset + num_out
60 | ] = (w.shape[-1] == 1)
61 | adjacency_matrix[
62 | row_offset : row_offset + num_in, col_offset : col_offset + num_out
63 | ] = True
64 | row_offset += num_in
65 | col_offset += num_out
66 |
67 | node_features = torch.cat(
68 | [
69 | torch.zeros((num_input_nodes, 1), device=device, dtype=biases[0].dtype),
70 | *[(b - biases_mean[i]) / biases_std[i] for i, b in enumerate(biases)],
71 | ]
72 | )
73 |
74 | return node_features, edge_features, edge_feature_masks, adjacency_matrix
75 |
76 |
77 | def cnn_to_tg_data(
78 | weights,
79 | biases,
80 | conv_mask,
81 | weights_mean=None,
82 | weights_std=None,
83 | biases_mean=None,
84 | biases_std=None,
85 | **kwargs,
86 | ):
87 | node_features, edge_features, edge_feature_masks, adjacency_matrix = cnn_to_graph(
88 | weights, biases, weights_mean, weights_std, biases_mean, biases_std
89 | )
90 | edge_index = adjacency_matrix.nonzero().t()
91 |
92 | num_input_nodes = weights[0].shape[0]
93 | cnn_sizes = [w.shape[1] for i, w in enumerate(weights) if conv_mask[i]]
94 | num_cnn_nodes = num_input_nodes + sum(cnn_sizes)
95 | send_nodes = num_input_nodes + sum(cnn_sizes[:-1])
96 | spatial_embed_mask = torch.zeros_like(node_features[:, 0], dtype=torch.bool)
97 | spatial_embed_mask[send_nodes:num_cnn_nodes] = True
98 | node_types = torch.cat(
99 | [
100 | torch.zeros(num_cnn_nodes, dtype=torch.long),
101 | torch.ones(node_features.shape[0] - num_cnn_nodes, dtype=torch.long),
102 | ]
103 | )
104 | if "residual_connections" in kwargs and "layer_layout" in kwargs:
105 | residual_edge_index = get_residuals_graph(
106 | kwargs["residual_connections"],
107 | kwargs["layer_layout"],
108 | )
109 | edge_index = torch.cat([edge_index, residual_edge_index], dim=1)
110 | # TODO: Do this in a more general way, now it works for square kernels
111 | center_pixel_index = edge_features.shape[-1] // 2
112 | edge_features[
113 | residual_edge_index[0], residual_edge_index[1], center_pixel_index
114 | ] = 1.0
115 |
116 | data = Data(
117 | x=node_features,
118 | edge_attr=edge_features[edge_index[0], edge_index[1]],
119 | edge_index=edge_index,
120 | mlp_edge_masks=edge_feature_masks[edge_index[0], edge_index[1]],
121 | spatial_embed_mask=spatial_embed_mask,
122 | node_types=node_types,
123 | conv_mask=conv_mask,
124 | **kwargs,
125 | )
126 |
127 | return data
128 |
129 |
130 | def get_residuals_graph(residual_connections, layer_layout):
131 | residual_layer_index = torch.LongTensor(
132 | [(e, i) for i, e in enumerate(residual_connections) if e >= 0]
133 | )
134 | if residual_layer_index.numel() == 0:
135 | return torch.zeros((2, 0), dtype=torch.long)
136 |
137 | residual_layer_index = residual_layer_index.T
138 | layout = torch.tensor(layer_layout)
139 | hidden_layout = layout[1:-1]
140 | min_residuals = hidden_layout[residual_layer_index].min(0, keepdim=True).values
141 | starting_indices = torch.cumsum(layout, dim=0)[residual_layer_index]
142 |
143 | residual_edge_index = torch.cat(
144 | [
145 | torch.stack(
146 | [
147 | torch.arange(
148 | starting_indices[0, i],
149 | starting_indices[0, i] + min_residuals[0, i],
150 | dtype=torch.long,
151 | ),
152 | torch.arange(
153 | starting_indices[1, i],
154 | starting_indices[1, i] + min_residuals[0, i],
155 | dtype=torch.long,
156 | ),
157 | ],
158 | dim=0,
159 | )
160 | for i in range(starting_indices.shape[1])
161 | ],
162 | dim=1,
163 | )
164 | return residual_edge_index
165 |
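166 | 
167 | if __name__ == "__main__":
168 |     # Minimal usage sketch (hypothetical toy CNN, not part of the training
169 |     # pipeline): one 3x3 conv layer (3 -> 8 channels) followed by a linear
170 |     # layer (8 -> 10). Conv kernels are flattened to kh * kw edge features,
171 |     # while linear weights carry a single feature and set `mlp_edge_masks`.
172 |     conv_w = pad_and_flatten_kernel(torch.randn(3, 8, 3, 3), (3, 3))  # [3, 8, 9]
173 |     lin_w = torch.randn(8, 10, 1)
174 |     biases = [torch.randn(8, 1), torch.randn(10, 1)]
175 |     data = cnn_to_tg_data(
176 |         weights=[conv_w, lin_w], biases=biases, conv_mask=[True, False]
177 |     )
178 |     # 3 input + 8 conv + 10 linear nodes; 3 * 8 + 8 * 10 = 104 edges
179 |     print(data)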
--------------------------------------------------------------------------------
/experiments/inr_classification/README.md:
--------------------------------------------------------------------------------
1 | # INR classification
2 |
3 | The following commands assume that you are executing them from the current directory `experiments/inr_classification`.
4 | If you are in the root of the repository, please navigate to the `experiments/inr_classification` directory:
5 |
6 | ```sh
7 | cd experiments/inr_classification
8 | ```
9 |
10 | Activate the conda environment:
11 |
12 | ```sh
13 | conda activate neural-graphs
14 | ```
15 |
16 | ## Setup
17 |
18 | ### Download the data
19 |
20 | For INR classification, we use MNIST and Fashion MNIST.
21 | The datasets are available [here](https://www.dropbox.com/sh/56pakaxe58z29mq/AABtWNkRYroLYe_cE3c90DXVa?dl=0).
22 |
23 | - [MNIST INRs](https://www.dropbox.com/sh/56pakaxe58z29mq/AABtWNkRYroLYe_cE3c90DXVa?dl=0&preview=mnist-inrs.zip)
24 | - [Fashion MNIST INRs](https://www.dropbox.com/sh/56pakaxe58z29mq/AABtWNkRYroLYe_cE3c90DXVa?dl=0&preview=fmnist_inrs.zip)
25 |
26 | Please download the data and place it in `dataset/mnist-inrs` and `dataset/fmnist_inrs`, respectively.
27 | If you want to use a different path, please change the following commands
28 | accordingly, or symlink your dataset path to the default ones.
29 |
30 | #### MNIST
31 |
32 | ```sh
33 | wget "https://www.dropbox.com/sh/56pakaxe58z29mq/AABrctdu2U65jGYr2WQRzmMna/mnist-inrs.zip?dl=0" -O mnist-inrs.zip &&
34 | mkdir -p dataset/mnist-inrs &&
35 | unzip -q mnist-inrs.zip -d dataset &&
36 | rm mnist-inrs.zip
37 | ```
38 |
39 | #### Fashion MNIST
40 |
41 | ```sh
42 | wget "https://www.dropbox.com/sh/56pakaxe58z29mq/AAAssoHq719OmSHSKKTiKKHGa/fmnist_inrs.zip?dl=0" -O fmnist_inrs.zip &&
43 | mkdir -p dataset/fmnist_inrs &&
44 | unzip -q fmnist_inrs.zip -d dataset &&
45 | rm fmnist_inrs.zip
46 | ```
47 |
48 | ### Data preprocessing
49 |
50 | We have already performed the data preprocessing required for MNIST and Fashion
51 | MNIST and provide the files within the repository. The preprocessing generates the
52 | data splits and the dataset statistics. These correspond to the files
53 | `dataset/mnist_splits.json` and `dataset/mnist_statistics.pth`
54 | for MNIST, and `dataset/fmnist_splits.json` and `dataset/fmnist_statistics.pth`
55 | for Fashion MNIST.
56 |
57 | However, if you want to use different directories for your experiments, you have
58 | to run the scripts that follow, or simply symlink your paths to the default ones.
59 |
60 | #### MNIST
61 |
62 | First, create the data split using:
63 |
64 | ```shell
65 | python dataset/generate_mnist_data_splits.py \
66 | --data-path mnist-inrs --save-path . --name mnist_splits.json
67 | ```
68 | This will create the JSON file `dataset/mnist_splits.json`.
69 | **Note** that the `--data-path` and `--save-path` arguments should be set relative
70 | to the `dataset` directory.
71 |
72 | Next, compute the dataset (INRs) statistics:
73 | ```shell
74 | python dataset/compute_mnist_statistics.py \
75 | --data-path . --save-path . \
76 | --splits-path mnist_splits.json --statistics-path mnist_statistics.pth
77 | ```
78 | This will create the `dataset/mnist_statistics.pth` file.
79 | Again, `--data-path` and `--save-path` should be set relative to the `dataset`
80 | directory.
81 |
82 | #### Fashion MNIST
83 |
84 | Fashion MNIST requires slightly different preprocessing.
85 | First, prepare the data splits using:
86 |
87 | ```shell
88 | python dataset/preprocess_fmnist.py \
89 | --data-path fmnist_inrs/splits.json --save-path . --name fmnist_splits.json
90 | ```
91 |
92 | Next, compute the dataset statistics:
93 | ```shell
94 | python dataset/compute_fmnist_statistics.py \
95 | --data-path . --save-path . \
96 | --splits-path fmnist_splits.json --statistics-path fmnist_statistics.pth
97 | ```
98 | This will create the `dataset/fmnist_statistics.pth` file.
99 |
100 |
101 | ## Run the experiment
102 |
103 | Now for the fun part! :rocket:
104 | To train and evaluate a __Neural Graph Transformer__ (NG-T) model on the MNIST dataset, run the following command:
105 |
106 | ```shell
107 | python main.py model=rtransformer data=mnist
108 | ```
109 |
110 | Make sure to check the model configuration in `configs/model/rtransformer.yaml`
111 | and the data configuration in `configs/data/mnist.yaml`.
112 | If you used different paths for the data, you can either overwrite the default
113 | paths in `configs/data/mnist.yaml` or pass the paths as arguments to the command:
114 |
115 | ```shell
116 | python main.py model=rtransformer data=mnist \
117 | data.dataset_dir= data.splits_path= \
118 | data.statistics_path=
119 | ```
120 |
121 | Training a different model or using a different dataset is as simple as changing
122 | the `model` and `data` arguments!
123 | For example, you can train and evaluate a __Neural Graph Graph Neural Network__ (NG-GNN)
124 | on Fashion MNIST using the following command:
125 |
126 | ```shell
127 | python main.py model=pna data=fmnist
128 | ```
129 |
130 | ### Run experiments with scripts
131 |
132 | You can also run the experiments using the scripts provided in the `scripts` directory.
133 | For example, to train and evaluate a __Neural Graph Transformer__ (NG-T) model on the MNIST dataset, run the following command:
134 |
135 | ```sh
136 | ./scripts/mnist_cls_rt.sh
137 | ```
138 | This script will run the experiment for 3 different seeds.
139 |
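140 | These scripts forward any extra command-line arguments to `main.py` as Hydra
141 | overrides (via `extra_args="$@"`), so you can tweak a run without editing the
142 | script. For example (a hypothetical override; any config key works):
143 | 
144 | ```sh
145 | ./scripts/mnist_cls_rt.sh batch_size=64
146 | ```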
--------------------------------------------------------------------------------
/experiments/inr_classification/configs/base.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - model: rtransformer
3 | - data: mnist
4 | - _self_
5 |
6 | n_epochs: 300
7 | batch_size: 128
8 |
9 | n_views: 1
10 | num_workers: 8
11 | eval_every: 1000
12 | num_accum: 1
13 |
14 | compile: false
15 | compile_kwargs:
16 | # mode: reduce-overhead
17 | mode: null
18 | options:
19 | matmul-padding: True
20 |
21 | optim:
22 | _target_: torch.optim.AdamW
23 | lr: 1e-3
24 | weight_decay: 5e-4
25 | amsgrad: True
26 | fused: False
27 |
28 | scheduler:
29 | _target_: experiments.lr_scheduler.WarmupLRScheduler
30 | warmup_steps: 1000
31 |
32 | distributed:
33 | world_size: 1
34 | rank: 0
35 | device_ids: null
36 |
37 | load_ckpt: null
38 |
39 | use_amp: False
40 | gradscaler:
41 | enabled: ${use_amp}
42 | autocast:
43 | device_type: cuda
44 | enabled: ${use_amp}
45 | dtype: float16
46 |
47 | clip_grad: True
48 | clip_grad_max_norm: 10.0
49 |
50 | seed: 42
51 | save_path: ./output
52 | wandb:
53 | project: inr-classification
54 | entity: null
55 | name: null
56 |
57 | matmul_precision: high
58 | cudnn_benchmark: False
59 |
60 | debug: False
61 |
--------------------------------------------------------------------------------
/experiments/inr_classification/configs/data/dummy_inr.yaml:
--------------------------------------------------------------------------------
1 | # shared
2 | target: experiments.data.INRDummyDataset
3 | normalize: False
4 | num_classes: 10
5 | img_shape: [28, 28]
6 | inr_model:
7 | _target_: nn.inr.INRPerLayer
8 | in_features: 2
9 | n_layers: 3
10 | hidden_features: 32
11 | out_features: 1
12 | layer_layout: [2, 2048, 2048, 3]
13 | stats: null
14 |
15 | train:
16 | _target_: ${data.target}
17 | _recursive_: True
18 | layer_layout: ${data.layer_layout}
19 |
20 | val:
21 | _target_: ${data.target}
22 | layer_layout: ${data.layer_layout}
23 |
24 | test:
25 | _target_: ${data.target}
26 | layer_layout: ${data.layer_layout}
27 |
--------------------------------------------------------------------------------
/experiments/inr_classification/configs/data/fmnist.yaml:
--------------------------------------------------------------------------------
1 | # shared
2 | target: experiments.data.INRDataset
3 | normalize: False
4 | dataset_name: fmnist
5 | dataset_dir: dataset
6 | splits_path: fmnist_splits.json
7 | statistics_path: fmnist_statistics.pth
8 | num_classes: 10
9 | img_shape: [28, 28]
10 | inr_model:
11 | _target_: nn.inr.INRPerLayer
12 | in_features: 2
13 | n_layers: 3
14 | hidden_features: 32
15 | out_features: 1
16 |
17 | stats:
18 |   # NOTE: Generated with `compute_fmnist_statistics.py`
19 | weights_mean: [6.370305982272839e-06, 6.88720547259436e-06, 1.0729863788583316e-05]
20 | weights_std: [0.07822809368371964, 0.03240188956260681, 0.13454964756965637]
21 | biases_mean: [1.6790845336345228e-07, -1.1566662578843534e-05, -0.020282816141843796]
22 | biases_std: [0.028561526909470558, 0.016700252890586853, 0.09595609456300735]
23 |
24 | train:
25 | _target_: ${data.target}
26 | _recursive_: True
27 | dataset_dir: ${data.dataset_dir}
28 | splits_path: ${data.splits_path}
29 | split: train
30 | normalize: ${data.normalize}
31 | augmentation: True
32 | permutation: False
33 | statistics_path: ${data.statistics_path}
34 | # num_classes: ${data.num_classes}
35 |
36 | val:
37 | _target_: ${data.target}
38 | dataset_dir: ${data.dataset_dir}
39 | splits_path: ${data.splits_path}
40 | split: val
41 | normalize: ${data.normalize}
42 | augmentation: False
43 | permutation: False
44 | statistics_path: ${data.statistics_path}
45 | # num_classes: ${data.num_classes}
46 |
47 | test:
48 | _target_: ${data.target}
49 | dataset_dir: ${data.dataset_dir}
50 | splits_path: ${data.splits_path}
51 | split: test
52 | normalize: ${data.normalize}
53 | augmentation: False
54 | permutation: False
55 | statistics_path: ${data.statistics_path}
56 | # num_classes: ${data.num_classes}
57 |
58 |
--------------------------------------------------------------------------------
/experiments/inr_classification/configs/data/mnist.yaml:
--------------------------------------------------------------------------------
1 | # shared
2 | target: experiments.data.INRDataset
3 | normalize: False
4 | dataset_name: mnist
5 | dataset_dir: dataset
6 | splits_path: mnist_splits.json
7 | statistics_path: mnist_statistics.pth
8 | num_classes: 10
9 | img_shape: [28, 28]
10 | inr_model:
11 | _target_: nn.inr.INRPerLayer
12 | in_features: 2
13 | n_layers: 3
14 | hidden_features: 32
15 | out_features: 1
16 |
17 | stats:
18 | weights_mean: [-4.215954686515033e-05, -7.55547659991862e-07, 7.886120874900371e-05]
19 | weights_std: [0.06281130015850067, 0.018268151208758354, 0.11791174858808517]
20 | biases_mean: [5.419965418695938e-06, 3.7173406326473923e-06, -0.01239530649036169]
21 | biases_std: [0.021334609016776085, 0.011004417203366756, 0.09989194571971893]
22 |
23 | train:
24 | _target_: ${data.target}
25 | _recursive_: True
26 | dataset_dir: ${data.dataset_dir}
27 | splits_path: ${data.splits_path}
28 | split: train
29 | normalize: ${data.normalize}
30 | augmentation: True
31 | permutation: False
32 | statistics_path: ${data.statistics_path}
33 | # num_classes: ${data.num_classes}
34 |
35 | val:
36 | _target_: ${data.target}
37 | dataset_dir: ${data.dataset_dir}
38 | splits_path: ${data.splits_path}
39 | split: val
40 | normalize: ${data.normalize}
41 | augmentation: False
42 | permutation: False
43 | statistics_path: ${data.statistics_path}
44 | # num_classes: ${data.num_classes}
45 |
46 | test:
47 | _target_: ${data.target}
48 | dataset_dir: ${data.dataset_dir}
49 | splits_path: ${data.splits_path}
50 | split: test
51 | normalize: ${data.normalize}
52 | augmentation: False
53 | permutation: False
54 | statistics_path: ${data.statistics_path}
55 | # num_classes: ${data.num_classes}
56 |
--------------------------------------------------------------------------------
/experiments/inr_classification/configs/model/dwsnet.yaml:
--------------------------------------------------------------------------------
1 | _target_: nn.dws.models.DWSModelForClassification
2 | _recursive_: False
3 | input_features: 1
4 | hidden_dim: 32
5 | n_hidden: 4
6 | reduction: max
7 | n_fc_layers: 1
8 | num_heads: 8
9 | set_layer: sab
10 | n_out_fc: 1
11 | dropout_rate: 0.0
12 | bn: true
13 | diagonal: false
14 |
--------------------------------------------------------------------------------
/experiments/inr_classification/configs/model/mlp.yaml:
--------------------------------------------------------------------------------
1 | name: mlp
2 | kwargs:
3 | dim_hidden: 32
4 | n_hidden: 4
5 | add_bn: true
--------------------------------------------------------------------------------
/experiments/inr_classification/configs/model/nfn.yaml:
--------------------------------------------------------------------------------
1 | _target_: nn.nfn.models.InvariantNFN
2 | hchannels: [512, 512, 512]
3 | mode: NP
4 | normalize: False
5 | in_channels: 1
6 | d_out: 10
7 | dropout: 0.5
8 |
--------------------------------------------------------------------------------
/experiments/inr_classification/configs/model/pna.yaml:
--------------------------------------------------------------------------------
1 | _target_: nn.gnn.GNNForClassification
2 | _recursive_: False
3 | d_out: ${data.num_classes}
4 | d_hid: 32
5 | compile: False
6 | rev_edge_features: False
7 | pooling_method: cat
8 | pooling_layer_idx: last # all, last, or 0, 1, ...
9 | jit: False
10 |
11 | gnn_backbone:
12 | _target_: nn.gnn.PNA
13 | _convert_: all
14 | in_channels: ${model.d_hid}
15 | hidden_channels: ${model.d_hid}
16 | out_channels: ${model.d_hid}
17 | num_layers: 4
18 | aggregators: ['mean', 'min', 'max', 'std']
19 | scalers: ['identity', 'amplification']
20 | edge_dim: ${model.d_hid}
21 | dropout: 0.
22 | norm: layernorm
23 | act: silu
24 | deg: null
25 | update_edge_attr: True
26 | modulate_edges: True
27 | gating_edges: False
28 | final_edge_update: False
29 |
30 | graph_constructor:
31 | _target_: nn.graph_constructor.GraphConstructor
32 | _recursive_: False
33 | _convert_: all
34 | d_in: 1
35 | d_edge_in: 1
36 | zero_out_bias: False
37 | zero_out_weights: False
38 | sin_emb: True
39 | sin_emb_dim: 128
40 | use_pos_embed: True
41 | input_layers: 1
42 | inp_factor: 3
43 | num_probe_features: 0
44 | inr_model: ${data.inr_model}
45 | stats: ${data.stats}
46 |
--------------------------------------------------------------------------------
/experiments/inr_classification/configs/model/rtransformer.yaml:
--------------------------------------------------------------------------------
1 | _target_: nn.relational_transformer.RelationalTransformer
2 | _recursive_: False
3 | d_out: ${data.num_classes}
4 | d_node: 64
5 | d_edge: 32
6 | d_attn_hid: 128
7 | d_node_hid: 128
8 | d_edge_hid: 64
9 | d_out_hid: 128
10 | n_layers: 4
11 | n_heads: 8
12 | node_update_type: rt
13 | disable_edge_updates: False
14 | use_cls_token: False
15 | pooling_method: cat
16 | pooling_layer_idx: last # all, last, or 0, 1, ...
17 | dropout: 0.0
18 | rev_edge_features: False
19 |
20 | use_ln: True
21 | tfixit_init: False
22 | modulate_v: True
23 |
24 | graph_constructor:
25 | _target_: nn.graph_constructor.GraphConstructor
26 | _recursive_: False
27 | _convert_: all
28 | d_in: 1
29 | d_edge_in: 1
30 | zero_out_bias: False
31 | zero_out_weights: False
32 | sin_emb: True
33 | sin_emb_dim: 128
34 | use_pos_embed: True
35 | input_layers: 1
36 | inp_factor: 1
37 | num_probe_features: 0
38 | inr_model: ${data.inr_model}
39 | stats: ${data.stats}
40 |
--------------------------------------------------------------------------------
/experiments/inr_classification/dataset/compute_fmnist_statistics.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 |
3 | from experiments.inr_classification.dataset.compute_mnist_statistics import (
4 | compute_stats,
5 | )
6 | from experiments.utils import common_parser
7 |
8 | if __name__ == "__main__":
9 | parser = ArgumentParser(
10 | "Fashion MNIST - generate statistics", parents=[common_parser]
11 | )
12 | parser.add_argument(
13 | "--splits-path", type=str, default="fmnist_splits.json", help="json file name"
14 | )
15 | parser.add_argument(
16 | "--statistics-path",
17 | type=str,
18 | default="fmnist_statistics.pth",
19 | help="Pytorch statistics file name",
20 | )
21 | parser.set_defaults(
22 | save_path=".",
23 | data_path=".",
24 | )
25 | args = parser.parse_args()
26 |
27 | compute_stats(
28 | data_path=args.data_path,
29 | save_path=args.save_path,
30 | splits_path=args.splits_path,
31 | statistics_path=args.statistics_path,
32 | )
33 |
--------------------------------------------------------------------------------
/experiments/inr_classification/dataset/compute_mnist_statistics.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | from pathlib import Path
3 |
4 | import torch
5 |
6 | from experiments.data import INRDataset
7 | from experiments.utils import common_parser
8 |
9 |
10 | def compute_stats(
11 | data_path: str, save_path: str, splits_path: str, statistics_path: str
12 | ):
13 | script_dir = Path(__file__).parent
14 | data_path = script_dir / Path(data_path)
15 |
16 | train_set = INRDataset(
17 | dataset_dir=data_path,
18 | splits_path=splits_path,
19 | split="train",
20 | statistics_path=None,
21 | normalize=False,
22 | )
23 |
24 | train_loader = torch.utils.data.DataLoader(
25 | train_set, batch_size=len(train_set), shuffle=False, num_workers=4
26 | )
27 |
28 | train_data = next(iter(train_loader))
29 |
30 | train_weights_mean = [w.mean().item() for w in train_data.weights]
31 | train_weights_std = [w.std().item() for w in train_data.weights]
32 | train_biases_mean = [w.mean().item() for w in train_data.biases]
33 | train_biases_std = [w.std().item() for w in train_data.biases]
34 |
35 | print(f"weights_mean: {train_weights_mean}")
36 | print(f"weights_std: {train_weights_std}")
37 | print(f"biases_mean: {train_biases_mean}")
38 | print(f"biases_std: {train_biases_std}")
39 |
40 | dws_weights_mean = [w.mean(0) for w in train_data.weights]
41 | dws_weights_std = [w.std(0) for w in train_data.weights]
42 | dws_biases_mean = [w.mean(0) for w in train_data.biases]
43 | dws_biases_std = [w.std(0) for w in train_data.biases]
44 |
45 | statistics = {
46 | "weights": {"mean": dws_weights_mean, "std": dws_weights_std},
47 | "biases": {"mean": dws_biases_mean, "std": dws_biases_std},
48 | }
49 |
50 | out_path = script_dir / Path(save_path)
51 | out_path.mkdir(exist_ok=True, parents=True)
52 | torch.save(statistics, out_path / statistics_path)
53 |
54 |
55 | if __name__ == "__main__":
56 | parser = ArgumentParser("MNIST - generate statistics", parents=[common_parser])
57 | parser.add_argument(
58 | "--splits-path", type=str, default="mnist_splits.json", help="json file name"
59 | )
60 | parser.add_argument(
61 | "--statistics-path",
62 | type=str,
63 | default="mnist_statistics.pth",
64 | help="Pytorch statistics file name",
65 | )
66 | parser.set_defaults(
67 | save_path=".",
68 | data_path=".",
69 | )
70 | args = parser.parse_args()
71 |
72 | compute_stats(
73 | data_path=args.data_path,
74 | save_path=args.save_path,
75 | splits_path=args.splits_path,
76 | statistics_path=args.statistics_path,
77 | )
78 |
--------------------------------------------------------------------------------
/experiments/inr_classification/dataset/compute_nfn_mnist_statistics.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | from pathlib import Path
3 |
4 | import torch
5 |
6 | from experiments.data_nfn import SirenDataset
7 | from experiments.utils import common_parser
8 |
9 |
10 | def compute_stats(data_path: str, save_path: str):
11 | script_dir = Path(__file__).parent
12 | data_path = script_dir / Path(data_path)
13 | dset = SirenDataset(data_path, "randinit_smaller")
14 |
15 | train_set = torch.utils.data.Subset(dset, range(45_000))
16 |
17 | all_weights = [d[0][0] for d in train_set]
18 | all_biases = [d[0][1] for d in train_set]
19 |
20 | weights_mean = []
21 | weights_std = []
22 | biases_mean = []
23 | biases_std = []
24 | for i in range(len(all_weights[0])):
25 | weights_mean.append(torch.stack([w[i] for w in all_weights]).mean().item())
26 | weights_std.append(torch.stack([w[i] for w in all_weights]).std().item())
27 | biases_mean.append(torch.stack([b[i] for b in all_biases]).mean().item())
28 | biases_std.append(torch.stack([b[i] for b in all_biases]).std().item())
29 | print(weights_mean)
30 | print(weights_std)
31 | print(biases_mean)
32 | print(biases_std)
33 |
34 | dws_weights_mean = []
35 | dws_weights_std = []
36 | dws_biases_mean = []
37 | dws_biases_std = []
38 | for i in range(len(all_weights[0])):
39 | dws_weights_mean.append(
40 | torch.stack([w[i] for w in all_weights]).mean(0).squeeze(0).unsqueeze(-1)
41 | )
42 | dws_weights_std.append(
43 | torch.stack([w[i] for w in all_weights]).std(0).squeeze(0).unsqueeze(-1)
44 | )
45 | dws_biases_mean.append(
46 | torch.stack([b[i] for b in all_biases]).mean(0).squeeze(0).unsqueeze(-1)
47 | )
48 | dws_biases_std.append(
49 | torch.stack([b[i] for b in all_biases]).std(0).squeeze(0).unsqueeze(-1)
50 | )
51 |
52 | statistics = {
53 | "weights": {"mean": dws_weights_mean, "std": dws_weights_std},
54 | "biases": {"mean": dws_biases_mean, "std": dws_biases_std},
55 | }
56 |
57 | out_path = script_dir / Path(save_path)
58 | out_path.mkdir(exist_ok=True, parents=True)
59 | torch.save(statistics, out_path / "nfn_mnist_statistics.pth")
60 |
61 |
62 | if __name__ == "__main__":
63 | parser = ArgumentParser("NFN MNIST - generate statistics", parents=[common_parser])
64 | parser.set_defaults(
65 | data_path="nfn-mnist-inrs",
66 | save_path=".",
67 | )
68 | args = parser.parse_args()
69 |
70 | compute_stats(data_path=args.data_path, save_path=args.save_path)
71 |
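72 | # Example invocation (run from `experiments/inr_classification`; paths follow
73 | # the defaults set above):
74 | #   python dataset/compute_nfn_mnist_statistics.py --data-path nfn-mnist-inrs --save-path .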
--------------------------------------------------------------------------------
/experiments/inr_classification/dataset/fmnist_statistics.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mkofinas/neural-graphs/1f2b671ab4988ef212469363005a5b99eec16580/experiments/inr_classification/dataset/fmnist_statistics.pth
--------------------------------------------------------------------------------
/experiments/inr_classification/dataset/generate_mnist_data_splits.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | from argparse import ArgumentParser
4 | from collections import defaultdict
5 | from pathlib import Path
6 |
7 | from sklearn.model_selection import train_test_split
8 |
9 | from experiments.utils import common_parser, set_logger
10 |
11 |
12 | def generate_splits(
13 | data_path, save_path, name="mnist_splits.json", val_size=5000, random_state=None
14 | ):
15 | script_dir = Path(__file__).parent
16 | inr_path = script_dir / Path(data_path)
17 | data_split = defaultdict(lambda: defaultdict(list))
18 | for p in list(inr_path.glob("mnist_png_*/**/*.pth")):
19 | s = "train" if "train" in p.as_posix() else "test"
20 | data_split[s]["path"].append(p.relative_to(script_dir).as_posix())
21 | data_split[s]["label"].append(p.parent.parent.stem.split("_")[-2])
22 |
23 | # val split
24 | train_indices, val_indices = train_test_split(
25 | range(len(data_split["train"]["path"])),
26 | test_size=val_size,
27 | random_state=random_state,
28 | )
29 | data_split["val"]["path"] = [data_split["train"]["path"][v] for v in val_indices]
30 | data_split["val"]["label"] = [data_split["train"]["label"][v] for v in val_indices]
31 |
32 | data_split["train"]["path"] = [
33 | data_split["train"]["path"][v] for v in train_indices
34 | ]
35 | data_split["train"]["label"] = [
36 | data_split["train"]["label"][v] for v in train_indices
37 | ]
38 |
39 | logging.info(
40 | f"train size: {len(data_split['train']['path'])}, "
41 | f"val size: {len(data_split['val']['path'])}, "
42 | f"test size: {len(data_split['test']['path'])}"
43 | )
44 |
45 | save_path = script_dir / Path(save_path) / name
46 | with open(save_path, "w") as file:
47 | json.dump(data_split, file)
48 |
49 |
50 | if __name__ == "__main__":
51 | parser = ArgumentParser("MNIST - generate data splits", parents=[common_parser])
52 | parser.add_argument(
53 | "--name", type=str, default="mnist_splits.json", help="json file name"
54 | )
55 | parser.add_argument(
56 | "--val-size", type=int, default=5000, help="number of validation examples"
57 | )
58 | parser.add_argument(
59 | "--random-state", type=int, default=None, help="random state for split"
60 | )
61 | parser.set_defaults(
62 | save_path=".",
63 | data_path="mnist-inrs",
64 | )
65 | args = parser.parse_args()
66 |
67 | set_logger()
68 |
69 | generate_splits(
70 | data_path=args.data_path,
71 | save_path=args.save_path,
72 | name=args.name,
73 | val_size=args.val_size,
74 | random_state=args.random_state,
75 | )
76 |
--------------------------------------------------------------------------------
/experiments/inr_classification/dataset/mnist_statistics.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mkofinas/neural-graphs/1f2b671ab4988ef212469363005a5b99eec16580/experiments/inr_classification/dataset/mnist_statistics.pth
--------------------------------------------------------------------------------
/experiments/inr_classification/dataset/preprocess_fmnist.py:
--------------------------------------------------------------------------------
1 | import json
2 | from argparse import ArgumentParser
3 | from collections import defaultdict
4 | from pathlib import Path
5 |
6 | import torch
7 |
8 | from experiments.utils import common_parser, set_logger
9 |
10 |
11 | def generate_splits(
12 | data_path,
13 | save_path,
14 | name="fmnist_splits.json",
15 | ):
16 | script_dir = Path(__file__).parent
17 | inr_path = script_dir / Path(data_path)
18 | with open(inr_path, "r") as f:
19 | data = json.load(f)
20 |
21 | splits = ["train", "val", "test"]
22 | data_split = defaultdict(lambda: defaultdict(list))
23 | for split in splits:
24 | print(f"Processing {split} split")
25 |
26 | data_split[split]["path"] = [
27 | (Path(data_path).parent / Path(*Path(di).parts[-2:])).as_posix()
28 | for di in data[split]
29 | ]
30 |
31 | data_split[split]["label"] = [
32 | torch.load(p, map_location=lambda storage, loc: storage)["label"]
33 | for p in data_split[split]["path"]
34 | ]
35 |
36 | print(f"Finished processing {split} split")
37 |
38 | save_path = script_dir / Path(save_path) / name
39 | with open(save_path, "w") as file:
40 | json.dump(data_split, file)
41 |
42 |
43 | if __name__ == "__main__":
44 | parser = ArgumentParser(
45 | "INR Classification - Fashion MNIST - preprocess data", parents=[common_parser]
46 | )
47 | parser.add_argument(
48 | "--name", type=str, default="fmnist_splits.json", help="json file name"
49 | )
50 | parser.set_defaults(
51 | save_path=".",
52 | data_path="fmnist_inrs/splits.json",
53 | )
54 | args = parser.parse_args()
55 |
56 | set_logger()
57 |
58 | generate_splits(
59 | args.data_path,
60 | args.save_path,
61 | name=args.name,
62 | )
63 |
--------------------------------------------------------------------------------
/experiments/inr_classification/scripts/fmnist_cls_pna.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | extra_args="$@"
4 | seeds=(0 1 2)
5 |
6 | for seed in "${seeds[@]}"
7 | do
8 | python -u main.py seed=$seed model=pna data=fmnist n_epochs=200 \
9 | data.train.augmentation=True model.graph_constructor.num_probe_features=0 \
10 | model.gnn_backbone.dropout=0.2 model.graph_constructor.use_pos_embed=True \
11 | model.modulate_v=True model.rev_edge_features=True \
12 | wandb.name=inr_cls_fmnist_pna_mod_pe_seed_${seed}_epoch_200_drop_0.2 \
13 | $extra_args
14 | done
15 |
--------------------------------------------------------------------------------
/experiments/inr_classification/scripts/fmnist_cls_pna_probe_64.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | extra_args="$@"
4 | seeds=(0 1 2)
5 |
6 | for seed in "${seeds[@]}"
7 | do
8 | python -u main.py seed=$seed model=pna data=fmnist n_epochs=200 \
9 | data.train.augmentation=True model.graph_constructor.num_probe_features=64 \
10 | model.gnn_backbone.dropout=0.2 model.graph_constructor.use_pos_embed=True \
11 | model.modulate_v=True model.rev_edge_features=True \
12 | wandb.name=inr_cls_fmnist_pna_probe_64_mod_pe_seed_${seed}_epoch_200_drop_0.2 \
13 | $extra_args
14 | done
15 |
--------------------------------------------------------------------------------
/experiments/inr_classification/scripts/fmnist_cls_rt.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | extra_args="$@"
4 | seeds=(0 1 2)
5 |
6 | for seed in "${seeds[@]}"
7 | do
8 | python -u main.py seed=$seed model=rtransformer data=fmnist n_epochs=200 \
9 | data.train.augmentation=True model.graph_constructor.num_probe_features=0 \
10 | model.dropout=0.2 model.graph_constructor.use_pos_embed=True model.modulate_v=True \
11 | wandb.name=inr_cls_fmnist_rt_mod_pe_seed_${seed}_epoch_200_drop_0.2 \
12 | $extra_args
13 | done
14 |
--------------------------------------------------------------------------------
/experiments/inr_classification/scripts/fmnist_cls_rt_probe_64.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | extra_args="$@"
4 | seeds=(0 1 2)
5 |
6 | for seed in "${seeds[@]}"
7 | do
8 | python -u main.py seed=$seed model=rtransformer data=fmnist n_epochs=200 \
9 | data.train.augmentation=True model.graph_constructor.num_probe_features=64 \
10 | model.dropout=0.2 model.graph_constructor.use_pos_embed=True model.modulate_v=True \
11 | wandb.name=inr_cls_fmnist_rt_probe_64_mod_pe_seed_${seed}_epoch_200_drop_0.2 \
12 | $extra_args
13 | done
14 |
--------------------------------------------------------------------------------
/experiments/inr_classification/scripts/mnist_cls_pna.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | extra_args="$@"
4 | seeds=(0 1 2)
5 |
6 | for seed in "${seeds[@]}"
7 | do
8 | python -u main.py seed=$seed model=pna data=mnist n_epochs=200 \
9 | model.graph_constructor.num_probe_features=0 model.gnn_backbone.dropout=0.2 \
10 | model.graph_constructor.use_pos_embed=True model.modulate_v=True model.rev_edge_features=True \
11 | wandb.name=inr_cls_mnist_pna_mod_pe_seed_${seed}_epoch_200_drop_0.2 \
12 | $extra_args
13 | done
14 |
--------------------------------------------------------------------------------
/experiments/inr_classification/scripts/mnist_cls_pna_probe_64.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | extra_args="$@"
4 | seeds=(0 1 2)
5 |
6 | for seed in "${seeds[@]}"
7 | do
8 | python -u main.py seed=$seed model=pna data=mnist n_epochs=200 \
9 | model.graph_constructor.num_probe_features=64 model.gnn_backbone.dropout=0.2 \
10 | model.graph_constructor.use_pos_embed=True model.modulate_v=True model.rev_edge_features=True \
11 | wandb.name=inr_cls_mnist_pna_probe_64_mod_pe_seed_${seed}_epoch_200_drop_0.2 \
12 | $extra_args
13 | done
14 |
--------------------------------------------------------------------------------
/experiments/inr_classification/scripts/mnist_cls_rt.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | extra_args="$@"
4 | seeds=(0 1 2)
5 |
6 | for seed in "${seeds[@]}"
7 | do
8 | python -u main.py seed=$seed model=rtransformer data=mnist n_epochs=200 \
9 | model.graph_constructor.num_probe_features=0 model.dropout=0.2 \
10 | model.graph_constructor.use_pos_embed=True model.modulate_v=True \
11 | wandb.name=inr_cls_mnist_rt_mod_pe_seed_${seed}_epoch_200_drop_0.2 \
12 | $extra_args
13 | done
14 |
--------------------------------------------------------------------------------
/experiments/inr_classification/scripts/mnist_cls_rt_probe_64.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | extra_args="$@"
4 | seeds=(0 1 2)
5 |
6 | for seed in "${seeds[@]}"
7 | do
8 | python -u main.py seed=$seed model=rtransformer data=mnist n_epochs=200 \
9 | model.graph_constructor.num_probe_features=64 model.dropout=0.2 \
10 | model.graph_constructor.use_pos_embed=True model.modulate_v=True \
11 | wandb.name=inr_cls_mnist_rt_probe_64_mod_pe_seed_${seed}_epoch_200_drop_0.2 \
12 | $extra_args
13 | done
14 |
--------------------------------------------------------------------------------
/experiments/learning_to_optimize/README.md:
--------------------------------------------------------------------------------
1 | # Learning to Optimize
2 |
3 | Coming soon!
4 |
--------------------------------------------------------------------------------
/experiments/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | from torch.optim.lr_scheduler import LRScheduler
4 |
5 |
6 | class WarmupLRScheduler(LRScheduler):
7 | def __init__(self, optimizer, warmup_steps=10000, last_epoch=-1, verbose=False):
8 | self.warmup_steps = warmup_steps
9 | super().__init__(optimizer, last_epoch=last_epoch, verbose=verbose)
10 |
11 | def get_lr(self):
12 | if self._step_count < self.warmup_steps:
13 | return [
14 | base_lr * self._step_count / self.warmup_steps
15 | for base_lr in self.base_lrs
16 | ]
17 | else:
18 | return self.base_lrs
19 |
20 |
21 | class ExpLRScheduler(LRScheduler):
22 | def __init__(
23 | self,
24 | optimizer,
25 | warmup_steps=10000,
26 | decay_rate=0.5,
27 | decay_steps=100000,
28 | last_epoch=-1,
29 | verbose=False,
30 | ):
31 | self.warmup_steps = warmup_steps
32 | self.decay_rate = decay_rate
33 | self.decay_steps = decay_steps
34 | super().__init__(optimizer, last_epoch=last_epoch, verbose=verbose)
35 |
36 | def get_lr(self):
37 | if self._step_count < self.warmup_steps:
38 | learning_rates = [
39 | base_lr * self._step_count / self.warmup_steps
40 | for base_lr in self.base_lrs
41 | ]
42 | else:
43 | learning_rates = self.base_lrs
44 | learning_rates = [
45 | lr * (self.decay_rate ** (self._step_count / self.decay_steps))
46 | for lr in learning_rates
47 | ]
48 | # print(self._step_count, learning_rates)
49 | return learning_rates
50 |
51 |
52 | class CosLRScheduler(LRScheduler):
53 | def __init__(
54 | self,
55 | optimizer,
56 | warmup_steps,
57 | decay_steps,
58 | last_epoch=-1,
59 | alpha=0.0,
60 | verbose=False,
61 | ):
62 | self.warmup_steps = warmup_steps
63 | self.decay_steps = decay_steps
64 | self.alpha = alpha
65 | super().__init__(optimizer, last_epoch=last_epoch, verbose=verbose)
66 |
67 | def get_lr(self):
68 | if self._step_count < self.warmup_steps:
69 | learning_rates = [
70 | base_lr * self._step_count / self.warmup_steps
71 | for base_lr in self.base_lrs
72 | ]
73 | else:
74 | decay_steps = self.decay_steps - self.warmup_steps
75 | step = min(self._step_count - self.warmup_steps, decay_steps)
76 |             cosine_decay = 0.5 * (1 + math.cos(math.pi * step / decay_steps))
77 | decayed = (1 - self.alpha) * cosine_decay + self.alpha
78 | learning_rates = [lr * decayed for lr in self.base_lrs]
79 | return learning_rates
80 |
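81 | 
82 | if __name__ == "__main__":
83 |     # Minimal usage sketch (hypothetical optimizer and parameter): the LR
84 |     # ramps linearly from ~0 to the base LR over `warmup_steps` optimizer
85 |     # steps, then stays constant.
86 |     import torch
87 | 
88 |     param = torch.nn.Parameter(torch.zeros(1))
89 |     optimizer = torch.optim.AdamW([param], lr=1e-3)
90 |     scheduler = WarmupLRScheduler(optimizer, warmup_steps=4)
91 |     for _ in range(6):
92 |         optimizer.step()
93 |         scheduler.step()
94 |         print(scheduler.get_last_lr())  # reaches [0.001] once warmup ends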
--------------------------------------------------------------------------------
/experiments/style_editing/README.md:
--------------------------------------------------------------------------------
1 | # INR style editing
2 |
3 | The following commands assume that you are executing them from the current directory `experiments/style_editing`.
4 | If you are in the root of the repository, please navigate to the `experiments/style_editing` directory:
5 |
6 | ```sh
7 | cd experiments/style_editing
8 | ```
9 |
10 | Activate the conda environment:
11 |
12 | ```sh
13 | conda activate neural-graphs
14 | ```
15 |
16 | ## Setup
17 |
18 | Follow the directions from the [INR classification
19 | experiment](../inr_classification#setup) to download the data and
20 | preprocess it. The default
21 | dataset directory is `dataset` and it is shared with the `inr_classification` experiment.
22 |
23 | ## Run the experiment
24 |
25 | Now for the fun part! :rocket:
26 | To train and evaluate a __Neural Graph Transformer__ (NG-T) model on the MNIST dataset, run the following command:
27 |
28 | ```shell
29 | python main.py model=rtransformer data=mnist
30 | ```
31 |
32 | Make sure to check the model configuration in `configs/model/rtransformer.yaml`
33 | and the data configuration in `configs/data/mnist.yaml`.
34 | If you used different paths for the data, you can either overwrite the default
35 | paths in `configs/data/mnist.yaml` or pass the paths as arguments to the command:
36 |
37 | ```shell
38 | python main.py model=rtransformer data=mnist \
39 | data.dataset_dir= data.splits_path= \
40 | data.statistics_path=
41 | ```
42 |
43 | Training a different model is as simple as changing the `model` argument!
44 | For example, you can train and evaluate a __Neural Graph Graph Neural Network__ (NG-GNN)
45 | on MNIST using the following command:
46 |
47 | ```shell
48 | python main.py model=pna data=mnist
49 | ```
50 |
51 | ### Run experiments with scripts
52 |
53 | You can also run the experiments using the scripts provided in the `scripts` directory.
54 | For example, to train and evaluate a __Neural Graph Transformer__ (NG-T) model on the MNIST dataset, run the following command:
55 |
56 | ```sh
57 | ./scripts/mnist_dilation_rt.sh
58 | ```
59 | This script will run the experiment for 3 different seeds.
60 |
--------------------------------------------------------------------------------
/experiments/style_editing/configs/base.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - model: pna
3 | - data: mnist
4 | - out_of_domain_data: fmnist
5 | - _self_
6 |
7 | n_epochs: 200
8 | batch_size: 128
9 |
10 | num_workers: 8
11 | eval_every: 1000
12 | num_accum: 1
13 |
14 | compile: false
15 | compile_kwargs:
16 | # mode: reduce-overhead
17 | mode: null
18 | options:
19 | matmul-padding: True
20 |
21 | optim:
22 | _target_: torch.optim.AdamW
23 | lr: 1e-3
24 | weight_decay: 5e-4
25 | amsgrad: True
26 | fused: False
27 |
28 | scheduler:
29 | _target_: experiments.lr_scheduler.WarmupLRScheduler
30 | warmup_steps: 10000
31 |
32 | distributed:
33 | world_size: 1
34 | rank: 0
35 | device_ids: null
36 |
37 | load_ckpt: null
38 |
39 | use_amp: False
40 | gradscaler:
41 | enabled: ${use_amp}
42 | autocast:
43 | device_type: cuda
44 | enabled: ${use_amp}
45 | dtype: float16
46 |
47 | clip_grad: True
48 | clip_grad_max_norm: 1.0
49 |
50 | seed: 42
51 | save_path: ./output
52 | wandb:
53 | project: style-editing
54 | entity: null
55 | name: null
56 |
57 | log_n_imgs: 4
58 |
59 | matmul_precision: high
60 | cudnn_benchmark: False
61 |
62 | debug: False
63 |
--------------------------------------------------------------------------------
/experiments/style_editing/configs/data/fmnist.yaml:
--------------------------------------------------------------------------------
1 | # shared
2 | target: experiments.data.INRAndImageDataset
3 | data_format: dws_mnist
4 | style:
5 | _target_: experiments.style_editing.image_processing.Dilate
6 | normalize: False
7 | dataset_name: fmnist
8 | dataset_dir: dataset
9 | splits_path: fmnist_splits.json
10 | statistics_path: fmnist_statistics.pth
11 | img_shape: [28, 28]
12 | inr_model:
13 | _target_: nn.inr.INRPerLayer
14 | in_features: 2
15 | n_layers: 3
16 | hidden_features: 32
17 | out_features: 1
18 | img_ds_cls: torchvision.datasets.FashionMNIST
19 | img_path: dataset/fashion-mnist
20 | img_download: True
21 |
22 | batch_siren:
23 | _target_: experiments.data.BatchSiren
24 | in_features: ${data.inr_model.in_features}
25 | out_features: ${data.inr_model.out_features}
26 | n_layers: ${data.inr_model.n_layers}
27 | hidden_features: ${data.inr_model.hidden_features}
28 | img_shape: ${data.img_shape}
29 |
30 | stats:
31 | weights_mean: [6.370305982272839e-06, 6.88720547259436e-06, 1.0729863788583316e-05]
32 | weights_std: [0.07822809368371964, 0.03240188956260681, 0.13454964756965637]
33 | biases_mean: [1.6790845336345228e-07, -1.1566662578843534e-05, -0.020282816141843796]
34 | biases_std: [0.028561526909470558, 0.016700252890586853, 0.09595609456300735]
35 |
36 | train:
37 | _target_: ${data.target}
38 | _recursive_: True
39 | dataset_name: ${data.dataset_name}
40 | dataset_dir: ${data.dataset_dir}
41 | splits_path: ${data.splits_path}
42 | split: train
43 | normalize: ${data.normalize}
44 | augmentation: False
45 | permutation: False
46 | statistics_path: ${data.statistics_path}
47 | img_offset: 0
48 | # num_classes: ${data.num_classes}
49 | style_function: ${data.style}
50 | img_ds:
51 | _target_: ${data.img_ds_cls}
52 | train: True
53 | root: ${data.img_path}
54 | download: ${data.img_download}
55 |
56 | val:
57 | _target_: ${data.target}
58 | _recursive_: True
59 | dataset_name: ${data.dataset_name}
60 | dataset_dir: ${data.dataset_dir}
61 | splits_path: ${data.splits_path}
62 | split: val
63 | normalize: ${data.normalize}
64 | augmentation: False
65 | permutation: False
66 | statistics_path: ${data.statistics_path}
67 | img_offset: 45000
68 | # num_classes: ${data.num_classes}
69 | style_function: ${data.style}
70 | img_ds:
71 | _target_: ${data.img_ds_cls}
72 | train: True
73 | root: ${data.img_path}
74 | download: ${data.img_download}
75 |
76 | test:
77 | _target_: ${data.target}
78 | _recursive_: True
79 | dataset_name: ${data.dataset_name}
80 | dataset_dir: ${data.dataset_dir}
81 | splits_path: ${data.splits_path}
82 | split: test
83 | normalize: ${data.normalize}
84 | augmentation: False
85 | permutation: False
86 | statistics_path: ${data.statistics_path}
87 | img_offset: 0
88 | # num_classes: ${data.num_classes}
89 | style_function: ${data.style}
90 | img_ds:
91 | _target_: ${data.img_ds_cls}
92 | train: False
93 | root: ${data.img_path}
94 | download: ${data.img_download}
95 |
96 |
--------------------------------------------------------------------------------
/experiments/style_editing/configs/data/mnist.yaml:
--------------------------------------------------------------------------------
1 | # shared
2 | target: experiments.data.INRAndImageDataset
3 | data_format: dws_mnist
4 | style:
5 | _target_: experiments.style_editing.image_processing.Dilate
6 | normalize: False
7 | dataset_name: mnist
8 | dataset_dir: dataset
9 | splits_path: mnist_splits.json
10 | statistics_path: mnist_statistics.pth
11 | img_shape: [28, 28]
12 | inr_model:
13 | _target_: nn.inr.INRPerLayer
14 | in_features: 2
15 | n_layers: 3
16 | hidden_features: 32
17 | out_features: 1
18 | img_ds_cls: torchvision.datasets.MNIST
19 | img_path: dataset/mnist
20 | img_download: True
21 |
22 | batch_siren:
23 | _target_: experiments.data.BatchSiren
24 | in_features: ${data.inr_model.in_features}
25 | out_features: ${data.inr_model.out_features}
26 | n_layers: ${data.inr_model.n_layers}
27 | hidden_features: ${data.inr_model.hidden_features}
28 | img_shape: ${data.img_shape}
29 |
30 | stats:
31 | weights_mean: [-0.0001166215879493393, -3.2710825053072767e-06, 7.234242366394028e-05]
32 | weights_std: [0.06279338896274567, 0.01827024295926094, 0.11813738197088242]
33 | biases_mean: [4.912401891488116e-06, -3.210141949239187e-05, -0.012279038317501545]
34 | biases_std: [0.021347912028431892, 0.0109943225979805, 0.09998151659965515]
35 |
36 | train:
37 | _target_: ${data.target}
38 | _recursive_: True
39 | dataset_name: ${data.dataset_name}
40 | dataset_dir: ${data.dataset_dir}
41 | splits_path: ${data.splits_path}
42 | split: train
43 | normalize: ${data.normalize}
44 | augmentation: False
45 | permutation: False
46 | statistics_path: ${data.statistics_path}
47 | img_offset: 0
48 | # num_classes: ${data.num_classes}
49 | style_function: ${data.style}
50 | img_ds:
51 | _target_: ${data.img_ds_cls}
52 | train: True
53 | root: ${data.img_path}
54 | download: ${data.img_download}
55 |
56 | val:
57 | _target_: ${data.target}
58 | _recursive_: True
59 | dataset_name: ${data.dataset_name}
60 | dataset_dir: ${data.dataset_dir}
61 | splits_path: ${data.splits_path}
62 | split: val
63 | normalize: ${data.normalize}
64 | augmentation: False
65 | permutation: False
66 | statistics_path: ${data.statistics_path}
67 | img_offset: 45000
68 | # num_classes: ${data.num_classes}
69 | style_function: ${data.style}
70 | img_ds:
71 | _target_: ${data.img_ds_cls}
72 | train: True
73 | root: ${data.img_path}
74 | download: ${data.img_download}
75 |
76 | test:
77 | _target_: ${data.target}
78 | _recursive_: True
79 | dataset_name: ${data.dataset_name}
80 | dataset_dir: ${data.dataset_dir}
81 | splits_path: ${data.splits_path}
82 | split: test
83 | normalize: ${data.normalize}
84 | augmentation: False
85 | permutation: False
86 | statistics_path: ${data.statistics_path}
87 | img_offset: 0
88 | # num_classes: ${data.num_classes}
89 | style_function: ${data.style}
90 | img_ds:
91 | _target_: ${data.img_ds_cls}
92 | train: False
93 | root: ${data.img_path}
94 | download: ${data.img_download}
95 |
--------------------------------------------------------------------------------
/experiments/style_editing/configs/data/nfn_cifar.yaml:
--------------------------------------------------------------------------------
1 | # shared
2 | target: experiments.data.SirenAndOriginalDataset
3 | data_format: nfn_mnist
4 | dataset_name: cifar10
5 | style:
6 | _target_: experiments.style_editing.image_processing.IncreaseContrast
7 | normalize: False
8 | img_shape: [32, 32]
9 | inr_model:
10 | _target_: experiments.data_nfn.SirenPerLayer
11 | in_features: 2
12 | hidden_features: 32
13 | hidden_layers: 1
14 | out_features: 3
15 | outermost_linear: True
16 | first_omega_0: 30.0
17 | hidden_omega_0: 30.0
18 | img_ds_cls: torchvision.datasets.CIFAR10
19 | img_path: dataset/cifar10
20 | img_download: True
21 |
22 | siren_path: dataset/nfn-cifar10-inrs
23 |
24 | batch_siren:
25 | _target_: experiments.data_nfn.BatchSiren
26 | in_features: ${data.inr_model.in_features}
27 | hidden_features: ${data.inr_model.hidden_features}
28 | hidden_layers: ${data.inr_model.hidden_layers}
29 | out_features: ${data.inr_model.out_features}
30 | outermost_linear: ${data.inr_model.outermost_linear}
31 | first_omega_0: ${data.inr_model.first_omega_0}
32 | hidden_omega_0: ${data.inr_model.hidden_omega_0}
33 | img_shape: ${data.img_shape}
34 |
35 | stats:
36 | weights_mean: [0.00018394182552583516, -2.5748543066583807e-06, -4.988231376046315e-05]
37 | weights_std: [0.2802596390247345, 0.017659902572631836, 0.05460081994533539]
38 | biases_mean: [0.0005445665447041392, -2.380055775574874e-06, -0.0024678376503288746]
39 | biases_std: [0.40869608521461487, 0.10388434678316116, 0.08734994381666183]
40 |
--------------------------------------------------------------------------------
/experiments/style_editing/configs/data/nfn_mnist.yaml:
--------------------------------------------------------------------------------
1 | # shared
2 | target: experiments.data.SirenAndOriginalDataset
3 | data_format: nfn_mnist
4 | dataset_name: mnist
5 | style:
6 | _target_: experiments.style_editing.image_processing.Dilate
7 | normalize: False
8 | img_shape: [28, 28]
9 | inr_model:
10 | _target_: experiments.data_nfn.SirenPerLayer
11 | in_features: 2
12 | hidden_features: 32
13 | hidden_layers: 1
14 | out_features: 1
15 | outermost_linear: True
16 | first_omega_0: 30.0
17 | hidden_omega_0: 30.0
18 | img_ds_cls: torchvision.datasets.MNIST
19 | img_path: dataset/mnist
20 | img_download: True
21 |
22 | siren_path: dataset/nfn-mnist-inrs
23 |
24 | batch_siren:
25 | _target_: experiments.data_nfn.BatchSiren
26 | in_features: ${data.inr_model.in_features}
27 | hidden_features: ${data.inr_model.hidden_features}
28 | hidden_layers: ${data.inr_model.hidden_layers}
29 | out_features: ${data.inr_model.out_features}
30 | outermost_linear: ${data.inr_model.outermost_linear}
31 | first_omega_0: ${data.inr_model.first_omega_0}
32 | hidden_omega_0: ${data.inr_model.hidden_omega_0}
33 | img_shape: ${data.img_shape}
34 |
35 | stats:
36 | weights_mean: [0.00012268121645320207, -8.858834803504578e-07, 2.4448696422041394e-05]
37 | weights_std: [0.2868247926235199, 0.017109761014580727, 0.06391365826129913]
38 | biases_mean: [0.0006445261533372104, -3.312843546154909e-05, -0.03267413377761841]
39 | biases_std: [0.40904879570007324, 0.10408575087785721, 0.09695733338594437]
40 |
--------------------------------------------------------------------------------
/experiments/style_editing/configs/model/dwsnet.yaml:
--------------------------------------------------------------------------------
1 | _target_: nn.dws.models.DWSModel
2 | _recursive_: False
3 | input_features: 1
4 | hidden_dim: 32
5 | output_features: 1
6 | n_hidden: 4
7 | reduction: max
8 | n_fc_layers: 1
9 | num_heads: 8
10 | set_layer: sab
11 | dropout_rate: 0.0
12 | bn: true
13 | diagonal: false
14 |
--------------------------------------------------------------------------------
/experiments/style_editing/configs/model/nfn.yaml:
--------------------------------------------------------------------------------
1 | name: nn.nfn
2 | cls: nn.nfn.models.TransferNet
3 | kwargs:
4 | hidden_chan: 128
5 | hidden_layers: 3
6 | mode: NP
7 | out_scale: 0.01
8 | dropout: 0.0
9 | gfft:
10 | in_channels: 1
11 | mapping_size: 128
12 | scale: 3
13 | iosinemb:
14 | max_freq: 3
15 | num_bands: 3
16 | enc_layers: false
17 |
--------------------------------------------------------------------------------
/experiments/style_editing/configs/model/pna.yaml:
--------------------------------------------------------------------------------
1 | _target_: nn.dense_gnn.GNNParams
2 | _recursive_: False
3 | d_out: 1
4 | d_hid: 128
5 | compile: False
6 | rev_edge_features: False
7 | jit: False
8 | stats: ${data.stats}
9 | normalize: False
10 | out_scale: 0.01
11 |
12 | gnn_backbone:
13 | _target_: nn.gnn.PNA
14 | _convert_: all
15 | in_channels: ${model.d_hid}
16 | hidden_channels: ${model.d_hid}
17 | out_channels: ${model.d_hid}
18 | num_layers: 4
19 | aggregators: ['mean', 'min', 'max', 'std']
20 | scalers: ['identity', 'amplification']
21 | edge_dim: ${model.d_hid}
22 | dropout: 0.
23 | norm: layernorm
24 | act: silu
25 | deg: null
26 | update_edge_attr: True
27 | modulate_edges: True
28 | gating_edges: False
29 | final_edge_update: False
30 |
31 | graph_constructor:
32 | _target_: nn.graph_constructor.GraphConstructor
33 | _recursive_: False
34 | _convert_: all
35 | d_in: 1
36 | d_edge_in: 1
37 | zero_out_bias: False
38 | zero_out_weights: False
39 | sin_emb: False
40 | sin_emb_dim: 128
41 | use_pos_embed: True
42 | input_layers: 1
43 | inp_factor: 3
44 | num_probe_features: 0
45 | inr_model: ${data.inr_model}
46 |
--------------------------------------------------------------------------------
/experiments/style_editing/configs/model/rtransformer.yaml:
--------------------------------------------------------------------------------
1 | _target_: nn.dense_relational_transformer.RelationalTransformerParams
2 | _recursive_: False
3 | d_out: 1
4 | d_node: 64
5 | d_edge: 32
6 | d_attn_hid: 128
7 | d_node_hid: 128
8 | d_edge_hid: 64
9 | d_out_hid: 128
10 | n_layers: 4
11 | n_heads: 4
12 | node_update_type: rt
13 | disable_edge_updates: False
14 | dropout: 0.0
15 | rev_edge_features: False
16 | use_ln: True
17 | tfixit_init: False
18 | stats: ${data.stats}
19 | normalize: False
20 | modulate_v: True
21 |
22 | graph_constructor:
23 | _target_: nn.graph_constructor.GraphConstructor
24 | _recursive_: False
25 | _convert_: all
26 | d_in: 1
27 | d_edge_in: 1
28 | zero_out_bias: False
29 | zero_out_weights: False
30 | sin_emb: True
31 | sin_emb_dim: 128
32 | use_pos_embed: True
33 | input_layers: 1
34 | inp_factor: 3
35 | num_probe_features: 0
36 | inr_model: ${data.inr_model}
37 |
--------------------------------------------------------------------------------
/experiments/style_editing/configs/out_of_domain_data/fmnist.yaml:
--------------------------------------------------------------------------------
1 | # shared
2 | target: experiments.data.INRAndImageDataset
3 | data_format: dws_mnist
4 | style:
5 | _target_: experiments.style_editing.image_processing.Dilate
6 | normalize: False
7 | dataset_name: fmnist
8 | dataset_dir: dataset
9 | splits_path: fmnist_splits.json
10 | statistics_path: fmnist_statistics.pth
11 | img_shape: [28, 28]
12 | inr_model:
13 | _target_: nn.inr.INRPerLayer
14 | in_features: 2
15 | n_layers: 3
16 | hidden_features: 32
17 | out_features: 1
18 | img_ds_cls: torchvision.datasets.FashionMNIST
19 | img_path: dataset/fashion-mnist
20 | img_download: True
21 |
22 | batch_siren:
23 | _target_: experiments.data.BatchSiren
24 | in_features: ${out_of_domain_data.inr_model.in_features}
25 | out_features: ${out_of_domain_data.inr_model.out_features}
26 | n_layers: ${out_of_domain_data.inr_model.n_layers}
27 | hidden_features: ${out_of_domain_data.inr_model.hidden_features}
28 | img_shape: ${out_of_domain_data.img_shape}
29 |
30 | stats:
31 | weights_mean: [6.370305982272839e-06, 6.88720547259436e-06, 1.0729863788583316e-05]
32 | weights_std: [0.07822809368371964, 0.03240188956260681, 0.13454964756965637]
33 | biases_mean: [1.6790845336345228e-07, -1.1566662578843534e-05, -0.020282816141843796]
34 | biases_std: [0.028561526909470558, 0.016700252890586853, 0.09595609456300735]
35 |
36 | train:
37 | _target_: ${out_of_domain_data.target}
38 | _recursive_: True
39 | dataset_name: ${out_of_domain_data.dataset_name}
40 | dataset_dir: ${out_of_domain_data.dataset_dir}
41 | splits_path: ${out_of_domain_data.splits_path}
42 | split: train
43 | normalize: ${out_of_domain_data.normalize}
44 | augmentation: False
45 | permutation: False
46 | statistics_path: ${out_of_domain_data.statistics_path}
47 | img_offset: 0
48 | # num_classes: ${out_of_domain_data.num_classes}
49 | style_function: ${out_of_domain_data.style}
50 | img_ds:
51 | _target_: ${out_of_domain_data.img_ds_cls}
52 | train: True
53 | root: ${out_of_domain_data.img_path}
54 | download: ${out_of_domain_data.img_download}
55 |
56 | val:
57 | _target_: ${out_of_domain_data.target}
58 | _recursive_: True
59 | dataset_name: ${out_of_domain_data.dataset_name}
60 | dataset_dir: ${out_of_domain_data.dataset_dir}
61 | splits_path: ${out_of_domain_data.splits_path}
62 | split: val
63 | normalize: ${out_of_domain_data.normalize}
64 | augmentation: False
65 | permutation: False
66 | statistics_path: ${out_of_domain_data.statistics_path}
67 | img_offset: 45000
68 | # num_classes: ${out_of_domain_data.num_classes}
69 | style_function: ${out_of_domain_data.style}
70 | img_ds:
71 | _target_: ${out_of_domain_data.img_ds_cls}
72 | train: True
73 | root: ${out_of_domain_data.img_path}
74 | download: ${out_of_domain_data.img_download}
75 |
76 | test:
77 | _target_: ${out_of_domain_data.target}
78 | _recursive_: True
79 | dataset_name: ${out_of_domain_data.dataset_name}
80 | dataset_dir: ${out_of_domain_data.dataset_dir}
81 | splits_path: ${out_of_domain_data.splits_path}
82 | split: test
83 | normalize: ${out_of_domain_data.normalize}
84 | augmentation: False
85 | permutation: False
86 | statistics_path: ${out_of_domain_data.statistics_path}
87 | img_offset: 0
88 | # num_classes: ${out_of_domain_data.num_classes}
89 | style_function: ${out_of_domain_data.style}
90 | img_ds:
91 | _target_: ${out_of_domain_data.img_ds_cls}
92 | train: False
93 | root: ${out_of_domain_data.img_path}
94 | download: ${out_of_domain_data.img_download}
95 |
96 |
--------------------------------------------------------------------------------
/experiments/style_editing/configs/out_of_domain_data/mnist.yaml:
--------------------------------------------------------------------------------
1 | # shared
2 | target: experiments.data.INRAndImageDataset
3 | data_format: dws_mnist
4 | style:
5 | _target_: experiments.style_editing.image_processing.Dilate
6 | normalize: False
7 | dataset_name: mnist
8 | dataset_dir: dataset
9 | splits_path: mnist_splits.json
10 | statistics_path: mnist_statistics.pth
11 | img_shape: [28, 28]
12 | inr_model:
13 | _target_: nn.inr.INRPerLayer
14 | in_features: 2
15 | n_layers: 3
16 | hidden_features: 32
17 | out_features: 1
18 | img_ds_cls: torchvision.datasets.MNIST
19 | img_path: dataset/mnist
20 | img_download: True
21 |
22 | batch_siren:
23 | _target_: experiments.data.BatchSiren
24 | in_features: ${out_of_domain_data.inr_model.in_features}
25 | out_features: ${out_of_domain_data.inr_model.out_features}
26 | n_layers: ${out_of_domain_data.inr_model.n_layers}
27 | hidden_features: ${out_of_domain_data.inr_model.hidden_features}
28 | img_shape: ${out_of_domain_data.img_shape}
29 |
30 | stats:
31 | weights_mean: [-0.0001166215879493393, -3.2710825053072767e-06, 7.234242366394028e-05]
32 | weights_std: [0.06279338896274567, 0.01827024295926094, 0.11813738197088242]
33 | biases_mean: [4.912401891488116e-06, -3.210141949239187e-05, -0.012279038317501545]
34 | biases_std: [0.021347912028431892, 0.0109943225979805, 0.09998151659965515]
35 |
36 | train:
37 | _target_: ${out_of_domain_data.target}
38 | _recursive_: True
39 | dataset_name: ${out_of_domain_data.dataset_name}
40 | dataset_dir: ${out_of_domain_data.dataset_dir}
41 | splits_path: ${out_of_domain_data.splits_path}
42 | split: train
43 | normalize: ${out_of_domain_data.normalize}
44 | augmentation: False
45 | permutation: False
46 | statistics_path: ${out_of_domain_data.statistics_path}
47 | img_offset: 0
48 | # num_classes: ${out_of_domain_data.num_classes}
49 | style_function: ${out_of_domain_data.style}
50 | img_ds:
51 | _target_: ${out_of_domain_data.img_ds_cls}
52 | train: True
53 | root: ${out_of_domain_data.img_path}
54 | download: ${out_of_domain_data.img_download}
55 |
56 | val:
57 | _target_: ${out_of_domain_data.target}
58 | _recursive_: True
59 | dataset_name: ${out_of_domain_data.dataset_name}
60 | dataset_dir: ${out_of_domain_data.dataset_dir}
61 | splits_path: ${out_of_domain_data.splits_path}
62 | split: val
63 | normalize: ${out_of_domain_data.normalize}
64 | augmentation: False
65 | permutation: False
66 | statistics_path: ${out_of_domain_data.statistics_path}
67 | img_offset: 45000
68 | # num_classes: ${out_of_domain_data.num_classes}
69 | style_function: ${out_of_domain_data.style}
70 | img_ds:
71 | _target_: ${out_of_domain_data.img_ds_cls}
72 | train: True
73 | root: ${out_of_domain_data.img_path}
74 | download: ${out_of_domain_data.img_download}
75 |
76 | test:
77 | _target_: ${out_of_domain_data.target}
78 | _recursive_: True
79 | dataset_name: ${out_of_domain_data.dataset_name}
80 | dataset_dir: ${out_of_domain_data.dataset_dir}
81 | splits_path: ${out_of_domain_data.splits_path}
82 | split: test
83 | normalize: ${out_of_domain_data.normalize}
84 | augmentation: False
85 | permutation: False
86 | statistics_path: ${out_of_domain_data.statistics_path}
87 | img_offset: 0
88 | # num_classes: ${out_of_domain_data.num_classes}
89 | style_function: ${out_of_domain_data.style}
90 | img_ds:
91 | _target_: ${out_of_domain_data.img_ds_cls}
92 | train: False
93 | root: ${out_of_domain_data.img_path}
94 | download: ${out_of_domain_data.img_download}
95 |
--------------------------------------------------------------------------------
/experiments/style_editing/dataset:
--------------------------------------------------------------------------------
1 | ../inr_classification/dataset/
--------------------------------------------------------------------------------
/experiments/style_editing/image_processing.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 |
4 |
5 | def increase_contrast(img):
6 | # https://stackoverflow.com/questions/39308030/how-do-i-increase-the-contrast-of-an-image-in-python-opencv
7 | lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
8 | l_channel, a, b = cv2.split(lab)
9 |     # apply CLAHE (contrast-limited adaptive histogram equalization) to the L channel;
10 |     # feel free to try different values for the clip limit and grid size
11 |     clahe = cv2.createCLAHE(clipLimit=1.0, tileGridSize=(3, 3))
12 |     cl = clahe.apply(l_channel)
13 |     # merge the CLAHE-enhanced L channel with the a and b channels
14 |     limg = cv2.merge((cl, a, b))
15 |     # convert back from the LAB color model to BGR color space
16 |     enhanced_img = cv2.cvtColor(limg, cv2.COLOR_LAB2BGR)
17 | return enhanced_img
18 |
19 |
20 | def dilate(img):
21 | kernel = np.ones((3, 3), np.uint8)
22 | return cv2.dilate(img, kernel, iterations=1)
23 |
24 |
25 | class Dilate:
26 | def __call__(self, img):
27 | return dilate(img)
28 |
29 |
30 | class IncreaseContrast:
31 | def __call__(self, img):
32 | return increase_contrast(img)
33 |
--------------------------------------------------------------------------------
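A quick usage sketch for the two style transforms above (the dummy arrays are illustrative; in the configs they are wired in as `style_function`). `IncreaseContrast` expects a 3-channel BGR image, since it converts to LAB internally, while `Dilate` also works on single-channel images:

import numpy as np
from experiments.style_editing.image_processing import Dilate, IncreaseContrast

gray = (np.random.rand(28, 28) * 255).astype(np.uint8)    # dummy MNIST-like image
dilated = Dilate()(gray)                                   # 3x3 dilation, one iteration

bgr = (np.random.rand(32, 32, 3) * 255).astype(np.uint8)   # dummy CIFAR-like image
contrasted = IncreaseContrast()(bgr)                       # CLAHE on the L channel
--------------------------------------------------------------------------------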
/experiments/style_editing/scripts/mnist_dilation_gnn.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | extra_args=("$@")
4 | seeds=(0 1 2)
5 |
6 | for seed in "${seeds[@]}"
7 | do
8 | python -u main.py seed=$seed model=pna data=mnist n_epochs=200 \
9 | model.graph_constructor.num_probe_features=0 model.gnn_backbone.dropout=0.2 \
10 | model.rev_edge_features=True \
11 | wandb.name=style_editing_mnist_dilation_pna_seed_${seed}_epoch_200_rev_edge_epoch_200_drop_0.2 \
12 |     "${extra_args[@]}"
13 | done
14 |
--------------------------------------------------------------------------------
/experiments/style_editing/scripts/mnist_dilation_rt.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | extra_args=("$@")
4 | seeds=(0 1 2)
5 |
6 | for seed in "${seeds[@]}"
7 | do
8 | python -u main.py seed=$seed model=rtransformer data=mnist n_epochs=200 \
9 | model.graph_constructor.num_probe_features=0 model.d_node=64 model.d_edge=32 \
10 | model.dropout=0.3 model.graph_constructor.use_pos_embed=True \
11 | model.modulate_v=True model.rev_edge_features=True \
12 | wandb.name=style_editing_mnist_dilation_rt_seed_${seed}_hid_64_epoch_200_rev_edge_epoch_200_drop_0.3 \
13 |     "${extra_args[@]}"
14 | done
15 |
--------------------------------------------------------------------------------
/experiments/utils.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import os
4 | import random
5 | from typing import List, Tuple, Union
6 |
7 | import numpy as np
8 | import torch
9 | from torch.distributed import init_process_group
10 |
11 | common_parser = argparse.ArgumentParser(add_help=False, description="common parser")
12 | common_parser.add_argument("--data-path", type=str, help="path for dataset")
13 | common_parser.add_argument("--save-path", type=str, help="path for output file")
14 |
15 |
16 | def set_seed(seed):
17 |     """Set the NumPy, Python, and PyTorch random seeds for reproducibility.
18 | 
19 |     :param seed: seed value used for all random number generators
20 |     """
21 | np.random.seed(seed)
22 | random.seed(seed)
23 |
24 | torch.manual_seed(seed)
25 | if torch.cuda.is_available():
26 | torch.cuda.manual_seed(seed)
27 | torch.cuda.manual_seed_all(seed)
28 |
29 | torch.backends.cudnn.enabled = True
30 | torch.backends.cudnn.benchmark = False
31 | torch.backends.cudnn.deterministic = True
32 |
33 |
34 | def set_logger():
35 | logging.basicConfig(
36 | format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
37 | level=logging.INFO,
38 | )
39 |
40 |
41 | def make_coordinates(
42 | shape: Union[Tuple[int], List[int]],
43 | bs: int,
44 | coord_range: Union[Tuple[int], List[int]] = (-1, 1),
45 | ) -> torch.Tensor:
46 | x_coordinates = np.linspace(coord_range[0], coord_range[1], shape[0])
47 | y_coordinates = np.linspace(coord_range[0], coord_range[1], shape[1])
48 | x_coordinates, y_coordinates = np.meshgrid(x_coordinates, y_coordinates)
49 | x_coordinates = x_coordinates.flatten()
50 | y_coordinates = y_coordinates.flatten()
51 | coordinates = np.stack([x_coordinates, y_coordinates]).T
52 | coordinates = np.repeat(coordinates[np.newaxis, ...], bs, axis=0)
53 | return torch.from_numpy(coordinates).type(torch.float)
54 |
55 |
56 | def count_parameters(model):
57 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
58 |
59 |
60 | def ddp_setup(rank: int, world_size: int):
61 | """
62 | Args:
63 | rank: Unique identifier of each process
64 | world_size: Total number of processes
65 | """
66 | os.environ["MASTER_ADDR"] = "localhost"
67 | os.environ["MASTER_PORT"] = "12355"
68 | init_process_group(backend="nccl", rank=rank, world_size=world_size)
69 | torch.cuda.set_device(rank)
70 |
--------------------------------------------------------------------------------
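For reference, `make_coordinates` builds the normalized coordinate grid on which the INRs are evaluated; a small sketch of its output shape:

from experiments.utils import make_coordinates

coords = make_coordinates((28, 28), bs=4)
print(coords.shape)  # torch.Size([4, 784, 2]); x/y values span [-1, 1]
--------------------------------------------------------------------------------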
/nn/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mkofinas/neural-graphs/1f2b671ab4988ef212469363005a5b99eec16580/nn/__init__.py
--------------------------------------------------------------------------------
/nn/activation_embedding.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | class ActivationEmbedding(torch.nn.Module):
6 | ACTIVATION_FN = [
7 | "none",
8 | "relu",
9 | "gelu",
10 | "silu",
11 | "tanh",
12 | "sigmoid",
13 | "leaky_relu",
14 | ]
15 |
16 | def __init__(self, embedding_dim):
17 | super().__init__()
18 |
19 | self.activation_idx = {k: i for i, k in enumerate(self.ACTIVATION_FN)}
20 | self.idx_activation = {i: k for i, k in enumerate(self.ACTIVATION_FN)}
21 |
22 | self.embedding = torch.nn.Embedding(len(self.ACTIVATION_FN), embedding_dim)
23 |
24 | def forward(self, activations, layer_layout, device):
25 | indices = torch.tensor(
26 | [self.activation_idx[act] for act in activations],
27 | device=device,
28 | dtype=torch.long,
29 | )
30 | emb = self.embedding(indices)
31 | emb = emb.repeat_interleave(
32 | torch.tensor(layer_layout[1:-1], device=device), dim=0
33 | )
34 | emb = F.pad(emb, (0, 0, layer_layout[0], layer_layout[-1]))
35 | return emb
36 |
--------------------------------------------------------------------------------
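A sketch of the embedding layout for a fixed MLP architecture (the layout below is illustrative): one activation name is given per hidden layer, its embedding is repeated across that layer's nodes, and the input/output nodes receive zero padding:

from nn.activation_embedding import ActivationEmbedding

emb_layer = ActivationEmbedding(embedding_dim=16)
layer_layout = [2, 32, 32, 1]  # input, two hidden layers, output
emb = emb_layer(["relu", "tanh"], layer_layout, device="cpu")
print(emb.shape)  # torch.Size([67, 16]); rows 0-1 (input) and row 66 (output) are zeros
--------------------------------------------------------------------------------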
/nn/dense_gnn.py:
--------------------------------------------------------------------------------
1 | import hydra
2 | import torch
3 | import torch.nn as nn
4 | import torch_geometric
5 |
6 | from nn.gnn import nn_to_edge_index, to_pyg_batch
7 |
8 |
9 | def graph_to_wb(
10 | edge_features,
11 | node_features,
12 | weights,
13 | biases,
14 | normalize=False,
15 | weights_mean=None,
16 | weights_std=None,
17 | biases_mean=None,
18 | biases_std=None,
19 | ):
20 | new_weights = []
21 | new_biases = []
22 |
23 | start = 0
24 | for i, w in enumerate(weights):
25 | size = torch.prod(torch.tensor(w.shape[1:]))
26 | w_mean = weights_mean[i] if normalize and weights_mean is not None else 0
27 | w_std = weights_std[i] if normalize and weights_std is not None else 1
28 | new_weights.append(
29 | edge_features[:, start : start + size].view(w.shape) * w_std + w_mean
30 | )
31 | start += size
32 |
33 | start = 0
34 | for i, b in enumerate(biases):
35 | size = torch.prod(torch.tensor(b.shape[1:]))
36 | b_mean = biases_mean[i] if normalize and biases_mean is not None else 0
37 | b_std = biases_std[i] if normalize and biases_std is not None else 1
38 | new_biases.append(
39 | node_features[:, start : start + size].view(b.shape) * b_std + b_mean
40 | )
41 | start += size
42 |
43 | return new_weights, new_biases
44 |
45 |
46 | class GNNParams(nn.Module):
47 | def __init__(
48 | self,
49 | d_hid,
50 | d_out,
51 | graph_constructor,
52 | gnn_backbone,
53 | layer_layout,
54 | rev_edge_features,
55 | stats=None,
56 | normalize=False,
57 | compile=False,
58 | jit=False,
59 | out_scale=0.01,
60 | ):
61 | super().__init__()
62 | self.nodes_per_layer = layer_layout
63 | self.layer_idx = torch.cumsum(torch.tensor([0] + layer_layout), dim=0)
64 |
65 | edge_index = nn_to_edge_index(self.nodes_per_layer, "cpu", dtype=torch.long)
66 | if rev_edge_features:
67 | edge_index = torch.cat([edge_index, edge_index.flip(dims=(0,))], dim=-1)
68 | self.register_buffer(
69 | "edge_index",
70 | edge_index,
71 | persistent=False,
72 | )
73 |
74 | self.construct_graph = hydra.utils.instantiate(
75 | graph_constructor,
76 | d_node=d_hid,
77 | d_edge=d_hid,
78 | layer_layout=layer_layout,
79 | rev_edge_features=rev_edge_features,
80 | stats=stats,
81 | )
82 |
83 | self.proj_edge = nn.Sequential(
84 | nn.Linear(d_hid, d_hid),
85 | nn.ReLU(),
86 | nn.Linear(d_hid, d_out),
87 | )
88 | self.proj_node = nn.Sequential(
89 | nn.Linear(d_hid, d_hid),
90 | nn.ReLU(),
91 | nn.Linear(d_hid, d_out),
92 | )
93 |
94 | gnn_kwargs = dict()
95 |         if gnn_backbone.get("deg", False) is None:  # "deg: null" in the config requests the in-degree histogram below
96 | extended_layout = [0] + layer_layout
97 | deg = torch.zeros(max(extended_layout) + 1, dtype=torch.long)
98 | for li in range(len(extended_layout) - 1):
99 | deg[extended_layout[li]] += extended_layout[li + 1]
100 |
101 | gnn_kwargs["deg"] = deg
102 | self.gnn = hydra.utils.instantiate(gnn_backbone, **gnn_kwargs)
103 | if jit:
104 | self.gnn = torch.jit.script(self.gnn)
105 | if compile:
106 | self.gnn = torch_geometric.compile(self.gnn)
107 |
108 | self.weight_scale = nn.ParameterList(
109 | [
110 | nn.Parameter(torch.tensor(out_scale))
111 | for _ in range(len(layer_layout) - 1)
112 | ]
113 | )
114 | self.bias_scale = nn.ParameterList(
115 | [
116 | nn.Parameter(torch.tensor(out_scale))
117 | for _ in range(len(layer_layout) - 1)
118 | ]
119 | )
120 | self.stats = stats
121 | self.normalize = normalize
122 |
123 | def forward(self, inputs):
124 | node_features, edge_features, _ = self.construct_graph(inputs)
125 |
126 | batch = to_pyg_batch(node_features, edge_features, self.edge_index)
127 | node_out, edge_out = self.gnn(
128 | x=batch.x, edge_index=batch.edge_index, edge_attr=batch.edge_attr
129 | )
130 | edge_features = edge_out.reshape(edge_features.shape[0], -1, edge_out.shape[-1])
131 | node_features = node_out.reshape(node_features.shape[0], -1, node_out.shape[-1])
132 | edge_features = self.proj_edge(edge_features)
133 | node_features = self.proj_node(node_features)
134 |
135 | weights, biases = graph_to_wb(
136 | edge_features=edge_features,
137 | node_features=node_features,
138 | weights=inputs[0],
139 | biases=inputs[1],
140 | normalize=self.normalize,
141 | **self.stats,
142 | )
143 |
144 | weights = [w * s for w, s in zip(weights, self.weight_scale)]
145 | biases = [b * s for b, s in zip(biases, self.bias_scale)]
146 |
147 | return weights, biases
148 |
--------------------------------------------------------------------------------
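The degree histogram computed when the backbone config leaves `deg: null` (see `pna.yaml` above) counts how many nodes have each possible in-degree; a worked sketch for an illustrative 3-layer layout:

import torch

layer_layout = [2, 32, 1]
extended = [0] + layer_layout  # [0, 2, 32, 1]
deg = torch.zeros(max(extended) + 1, dtype=torch.long)
for li in range(len(extended) - 1):
    deg[extended[li]] += extended[li + 1]
# deg[0] == 2:  the input nodes have no incoming edges
# deg[2] == 32: each hidden node receives one edge per input node
# deg[32] == 1: the output node receives one edge per hidden node
--------------------------------------------------------------------------------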
/nn/dense_relational_transformer.py:
--------------------------------------------------------------------------------
1 | import hydra
2 | import torch
3 | import torch.nn as nn
4 |
5 | from nn.relational_transformer import RTLayer
6 |
7 |
8 | def graphs_to_batch(
9 | edge_features,
10 | node_features,
11 | weights,
12 | biases,
13 | normalize=False,
14 | weights_mean=None,
15 | weights_std=None,
16 | biases_mean=None,
17 | biases_std=None,
18 | ):
19 | new_weights = []
20 | new_biases = []
21 |
22 | row_offset = 0
23 | col_offset = weights[0].shape[1] # no edge to input nodes
24 | for i, w in enumerate(weights):
25 | _, num_in, num_out, _ = w.shape
26 | w_mean = weights_mean[i] if normalize and weights_mean is not None else 0
27 | w_std = weights_std[i] if normalize and weights_std is not None else 1
28 | new_weights.append(
29 | edge_features[
30 | :, row_offset : row_offset + num_in, col_offset : col_offset + num_out
31 | ]
32 | * w_std
33 | + w_mean
34 | )
35 | row_offset += num_in
36 | col_offset += num_out
37 |
38 | row_offset = weights[0].shape[1] # no bias in input nodes
39 | for i, b in enumerate(biases):
40 | _, num_out, _ = b.shape
41 | b_mean = biases_mean[i] if normalize and biases_mean is not None else 0
42 | b_std = biases_std[i] if normalize and biases_std is not None else 1
43 | new_biases.append(
44 | node_features[:, row_offset : row_offset + num_out] * b_std + b_mean
45 | )
46 | row_offset += num_out
47 |
48 | return new_weights, new_biases
49 |
50 |
51 | class RelationalTransformerParams(nn.Module):
52 | def __init__(
53 | self,
54 | d_node,
55 | d_edge,
56 | d_attn_hid,
57 | d_node_hid,
58 | d_edge_hid,
59 | d_out_hid,
60 | d_out,
61 | n_layers,
62 | n_heads,
63 | layer_layout,
64 | graph_constructor,
65 | dropout=0.0,
66 | node_update_type="rt",
67 | disable_edge_updates=False,
68 | rev_edge_features=False,
69 | modulate_v=True,
70 | use_ln=True,
71 | tfixit_init=False,
72 | stats=None,
73 | normalize=False,
74 | out_scale=0.01,
75 | ):
76 | super().__init__()
77 | self.rev_edge_features = rev_edge_features
78 | self.nodes_per_layer = layer_layout
79 | self.construct_graph = hydra.utils.instantiate(
80 | graph_constructor,
81 | d_node=d_node,
82 | d_edge=d_edge,
83 | layer_layout=layer_layout,
84 | rev_edge_features=rev_edge_features,
85 | stats=stats,
86 | )
87 |
88 | self.layers = nn.ModuleList(
89 | [
90 | torch.jit.script(
91 | RTLayer(
92 | d_node,
93 | d_edge,
94 | d_attn_hid,
95 | d_node_hid,
96 | d_edge_hid,
97 | n_heads,
98 | dropout,
99 | node_update_type=node_update_type,
100 | disable_edge_updates=disable_edge_updates,
101 | modulate_v=modulate_v,
102 | use_ln=use_ln,
103 | tfixit_init=tfixit_init,
104 | n_layers=n_layers,
105 | )
106 | )
107 | for _ in range(n_layers)
108 | ]
109 | )
110 |
111 | self.proj_edge = nn.Sequential(
112 | nn.Linear(d_edge, d_edge),
113 | nn.ReLU(),
114 | nn.Linear(d_edge, d_out),
115 | )
116 | self.proj_node = nn.Sequential(
117 | nn.Linear(d_node, d_node),
118 | nn.ReLU(),
119 | nn.Linear(d_node, d_out),
120 | )
121 |
122 | self.weight_scale = nn.ParameterList(
123 | [
124 | nn.Parameter(torch.tensor(out_scale))
125 | for _ in range(len(layer_layout) - 1)
126 | ]
127 | )
128 | self.bias_scale = nn.ParameterList(
129 | [
130 | nn.Parameter(torch.tensor(out_scale))
131 | for _ in range(len(layer_layout) - 1)
132 | ]
133 | )
134 | self.stats = stats
135 | self.normalize = normalize
136 |
137 | def forward(self, inputs):
138 | node_features, edge_features, mask = self.construct_graph(inputs)
139 |
140 | for layer in self.layers:
141 | node_features, edge_features = layer(node_features, edge_features, mask)
142 |
143 | node_features = self.proj_node(node_features)
144 | edge_features = self.proj_edge(edge_features)
145 |
146 | weights, biases = graphs_to_batch(
147 | edge_features,
148 | node_features,
149 | *inputs,
150 | normalize=self.normalize,
151 | **self.stats,
152 | )
153 | weights = [w * s for w, s in zip(weights, self.weight_scale)]
154 | biases = [b * s for b, s in zip(biases, self.bias_scale)]
155 | return weights, biases
156 |
--------------------------------------------------------------------------------
/nn/dws/__init__.py:
--------------------------------------------------------------------------------
1 | from nn.dws.layers import (
2 | BN,
3 | DownSampleDWSLayer,
4 | Dropout,
5 | DWSLayer,
6 | InvariantLayer,
7 | LeakyReLU,
8 | NaiveInvariantLayer,
9 | ReLU,
10 | )
11 |
--------------------------------------------------------------------------------
/nn/dynamic_gnn.py:
--------------------------------------------------------------------------------
1 | import hydra
2 | import torch
3 | import torch.nn as nn
4 | import torch_geometric
5 | from torch_geometric.utils import to_dense_batch
6 |
7 | from nn.pooling import HeterogeneousAggregator
8 |
9 |
10 | def to_pyg_batch(node_features, edge_features, edge_index, node_mask):
11 | data_list = [
12 | torch_geometric.data.Data(
13 | x=node_features[i][node_mask[i]],
14 | edge_index=edge_index[i],
15 | edge_attr=edge_features[i, edge_index[i][0], edge_index[i][1]],
16 | )
17 | for i in range(node_features.shape[0])
18 | ]
19 | return torch_geometric.data.Batch.from_data_list(data_list)
20 |
21 |
22 | class GNNForGeneralization(nn.Module):
23 | def __init__(
24 | self,
25 | d_hid,
26 | d_out,
27 | graph_constructor,
28 | gnn_backbone,
29 | rev_edge_features,
30 | pooling_method,
31 | pooling_layer_idx,
32 | compile=False,
33 | jit=False,
34 | input_channels=3,
35 | num_classes=10,
36 | layer_layout=None,
37 | ):
38 | super().__init__()
39 | self.pooling_method = pooling_method
40 | self.pooling_layer_idx = pooling_layer_idx
41 | self.out_features = d_out
42 | self.num_classes = num_classes
43 | self.rev_edge_features = rev_edge_features
44 |
45 | self.construct_graph = hydra.utils.instantiate(
46 | graph_constructor,
47 | d_node=d_hid,
48 | d_edge=d_hid,
49 | d_out=d_out,
50 | rev_edge_features=rev_edge_features,
51 | input_channels=input_channels,
52 | num_classes=num_classes,
53 | )
54 |
55 | num_graph_features = d_hid
56 | if pooling_method == "cat" and pooling_layer_idx == "last":
57 | num_graph_features = num_classes * d_hid
58 | elif pooling_method == "cat" and pooling_layer_idx == "all":
59 | # NOTE: Only allowed with datasets of fixed architectures
60 | num_graph_features = sum(layer_layout) * d_hid
61 |
62 | self.pool = HeterogeneousAggregator(
63 | d_hid,
64 | d_hid,
65 | d_hid,
66 | pooling_method,
67 | pooling_layer_idx,
68 | input_channels,
69 | num_classes,
70 | )
71 |
72 | self.proj_out = nn.Sequential(
73 | nn.Linear(num_graph_features, d_hid),
74 | nn.ReLU(),
75 | nn.Linear(d_hid, d_hid),
76 | nn.ReLU(),
77 | nn.Linear(d_hid, d_out),
78 | )
79 |
80 | gnn_kwargs = dict()
81 | gnn_kwargs["deg"] = torch.tensor(gnn_backbone["deg"], dtype=torch.long)
82 |
83 | self.gnn = hydra.utils.instantiate(gnn_backbone, **gnn_kwargs)
84 | if jit:
85 | self.gnn = torch.jit.script(self.gnn)
86 | if compile:
87 | self.gnn = torch_geometric.compile(self.gnn)
88 |
89 | def forward(self, batch):
90 | # self.register_buffer("edge_index", batch.edge_index, persistent=False)
91 | node_features, edge_features, _, node_mask = self.construct_graph(batch)
92 |
93 | if self.rev_edge_features:
94 | edge_index = [
95 | torch.cat(
96 | [batch[i].edge_index, batch[i].edge_index.flip(dims=(0,))], dim=-1
97 | )
98 | for i in range(len(batch))
99 | ]
100 | else:
101 | edge_index = [batch[i].edge_index for i in range(len(batch))]
102 |
103 | new_batch = to_pyg_batch(node_features, edge_features, edge_index, node_mask)
104 | out_node, out_edge = self.gnn(
105 | x=new_batch.x,
106 | edge_index=new_batch.edge_index,
107 | edge_attr=new_batch.edge_attr,
108 | )
109 | node_features = to_dense_batch(out_node, new_batch.batch)[0]
110 |
111 | graph_features = self.pool(
112 | node_features, batch.layer_layout, node_mask=node_mask
113 | )
114 |
115 | return self.proj_out(graph_features)
116 |
--------------------------------------------------------------------------------
/nn/dynamic_relational_transformer.py:
--------------------------------------------------------------------------------
1 | import hydra
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from nn.pooling import HeterogeneousAggregator
7 | from nn.relational_transformer import RTLayer
8 |
9 |
10 | class DynamicRelationalTransformer(nn.Module):
11 | def __init__(
12 | self,
13 | d_in,
14 | d_node,
15 | d_edge,
16 | d_attn_hid,
17 | d_node_hid,
18 | d_edge_hid,
19 | d_out_hid,
20 | d_out,
21 | n_layers,
22 | n_heads,
23 | graph_constructor,
24 | dropout=0.0,
25 | node_update_type="rt",
26 | disable_edge_updates=False,
27 | use_cls_token=True,
28 | pooling_method="cat",
29 | pooling_layer_idx="last",
30 | rev_edge_features=False,
31 | modulate_v=True,
32 | use_ln=True,
33 | tfixit_init=False,
34 | input_channels=3,
35 | num_classes=10,
36 | layer_layout=None,
37 | ):
38 | super().__init__()
39 | assert use_cls_token == (pooling_method == "cls_token")
40 | self.pooling_method = pooling_method
41 | self.pooling_layer_idx = pooling_layer_idx
42 |
43 | self.rev_edge_features = rev_edge_features
44 | self.out_features = d_out
45 | self.num_classes = num_classes
46 | self.construct_graph = hydra.utils.instantiate(
47 | graph_constructor,
48 | d_in=d_in,
49 | d_node=d_node,
50 | d_edge=d_edge,
51 | d_out=d_out,
52 | rev_edge_features=rev_edge_features,
53 | input_channels=input_channels,
54 | num_classes=num_classes,
55 | )
56 | self.use_cls_token = use_cls_token
57 | if use_cls_token:
58 | self.cls_token = nn.Parameter(torch.randn(d_node))
59 |
60 | self.layers = nn.ModuleList(
61 | [
62 | torch.jit.script(
63 | RTLayer(
64 | d_node,
65 | d_edge,
66 | d_attn_hid,
67 | d_node_hid,
68 | d_edge_hid,
69 | n_heads,
70 | float(dropout),
71 | node_update_type=node_update_type,
72 | disable_edge_updates=(
73 | (disable_edge_updates or (i == n_layers - 1))
74 | and pooling_method != "mean_edge"
75 | and pooling_layer_idx != "all"
76 | ),
77 | modulate_v=modulate_v,
78 | use_ln=use_ln,
79 | tfixit_init=tfixit_init,
80 | n_layers=n_layers,
81 | )
82 | )
83 | for i in range(n_layers)
84 | ]
85 | )
86 | num_graph_features = d_node
87 | if pooling_method == "cat" and pooling_layer_idx == "last":
88 | num_graph_features = num_classes * d_node
89 | elif pooling_method == "cat" and pooling_layer_idx == "all":
90 | # NOTE: Only allowed with datasets of fixed architectures
91 | num_graph_features = sum(layer_layout) * d_node
92 | elif pooling_method in ("mean_edge", "max_edge"):
93 | num_graph_features = d_edge
94 |
95 | if pooling_method in (
96 | "mean",
97 | "max",
98 | "cat",
99 | "attentional_aggregation",
100 | "set_transformer",
101 | "graph_multiset_transformer",
102 | ):
103 | self.pool = HeterogeneousAggregator(
104 | d_node,
105 | d_out_hid,
106 | d_node,
107 | pooling_method,
108 | pooling_layer_idx,
109 | input_channels,
110 | num_classes,
111 | )
112 |
113 | self.proj_out = nn.Sequential(
114 | nn.Linear(num_graph_features, d_out_hid),
115 | nn.ReLU(),
116 | nn.Linear(d_out_hid, d_out_hid),
117 | nn.ReLU(),
118 | nn.Linear(d_out_hid, d_out),
119 | )
120 |
121 | def forward(self, batch):
122 | node_features, edge_features, mask, node_mask = self.construct_graph(batch)
123 |
124 | if self.use_cls_token:
125 | node_features = torch.cat(
126 | [
127 | # repeat(self.cls_token, "d -> b 1 d", b=node_features.size(0)),
128 | self.cls_token.unsqueeze(0).expand(node_features.size(0), 1, -1),
129 | node_features,
130 | ],
131 | dim=1,
132 | )
133 | edge_features = F.pad(edge_features, (0, 0, 1, 0, 1, 0), value=0)
134 |
135 | for layer in self.layers:
136 | node_features, edge_features = layer(node_features, edge_features, mask)
137 |
138 | if self.pooling_method == "cls_token":
139 | graph_features = node_features[:, 0]
140 | elif self.pooling_method == "mean_edge" and self.pooling_layer_idx == "all":
141 | graph_features = edge_features.mean(dim=(1, 2))
142 | elif self.pooling_method == "max_edge" and self.pooling_layer_idx == "all":
143 | graph_features = edge_features.flatten(1, 2).max(dim=1).values
144 | elif self.pooling_method == "mean_edge" and self.pooling_layer_idx == "last":
145 | valid_layer_indices = (
146 | torch.arange(node_mask.shape[1], device=node_mask.device)[None, :]
147 | * node_mask
148 | )
149 | last_layer_indices = valid_layer_indices.topk(
150 | k=self.num_classes, dim=1
151 | ).values.fliplr()
152 | batch_range = torch.arange(node_mask.shape[0], device=node_mask.device)[
153 | :, None
154 | ]
155 | graph_features = edge_features[batch_range, last_layer_indices, :].mean(
156 | dim=(1, 2)
157 | )
158 | else:
159 | # FIXME: Node features are not masked, some contain garbage
160 | graph_features = self.pool(
161 | node_features, batch.layer_layout, node_mask=node_mask
162 | )
163 |
164 | return self.proj_out(graph_features)
165 |
--------------------------------------------------------------------------------
/nn/dynamic_stat_net.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class DynamicStatFeaturizer(nn.Module):
7 | def __init__(
8 | self,
9 | max_kernel_size,
10 | max_num_hidden_layers,
11 | max_kernel_height,
12 | max_kernel_width,
13 | ):
14 | super().__init__()
15 | self.max_kernel_size = max_kernel_size
16 | self.max_kernel_dimensions = (max_kernel_height, max_kernel_width)
17 | self.max_num_hidden_layers = max_num_hidden_layers
18 | self.max_size = 2 * 7 * (max_num_hidden_layers + 1)
19 |
20 | def forward(self, batch) -> torch.Tensor:
21 | out = []
22 | for i in range(len(batch)):
23 | elem = []
24 | biases = batch[i].x.split(batch[i].layer_layout)[1:]
25 | prods = [
26 | batch[i].layer_layout[j] * batch[i].layer_layout[j + 1]
27 | for j in range(len(batch[i].layer_layout) - 1)
28 | ]
29 | weights = batch[i].edge_attr.split(prods)
30 | for j, (weight, bias) in enumerate(zip(weights, biases)):
31 | if j < len(batch[i].initial_weight_shapes):
32 | weight = weight.unflatten(-1, self.max_kernel_dimensions)
33 | # TODO: Rewrite in a more efficient way
34 | start0 = (
35 | self.max_kernel_dimensions[0]
36 | - batch[i].initial_weight_shapes[j][0]
37 | ) // 2
38 | end0 = start0 + batch[i].initial_weight_shapes[j][0]
39 | start1 = (
40 | self.max_kernel_dimensions[1]
41 | - batch[i].initial_weight_shapes[j][1]
42 | ) // 2
43 | end1 = start1 + batch[i].initial_weight_shapes[j][1]
44 | weight = weight[:, start0:end0, start1:end1]
45 | elem.append(self.compute_stats(weight))
46 | elem.append(self.compute_stats(bias))
47 | elem = torch.cat(elem, dim=-1)
48 | elem = F.pad(elem, (0, self.max_size - elem.shape[0]))
49 | out.append(elem)
50 |
51 | return torch.stack(out, dim=0)
52 |
53 | def compute_stats(self, tensor: torch.Tensor) -> torch.Tensor:
54 |         """Computes summary statistics (mean, variance, quantiles) of the given tensor."""
55 |         mean = tensor.mean()  # scalar
56 |         var = tensor.var()  # scalar
57 |         q = torch.tensor([0.0, 0.25, 0.5, 0.75, 1.0]).to(tensor.device)
58 |         quantiles = torch.quantile(tensor, q)  # (5,)
59 |         return torch.stack([mean, var, *quantiles], dim=-1)  # (7,)
60 |
61 |
62 | class DynamicStatNet(nn.Module):
63 | """Outputs a scalar."""
64 |
65 | def __init__(
66 | self,
67 | h_size,
68 | dropout=0.0,
69 | max_kernel_size=49,
70 | max_num_hidden_layers=5,
71 | max_kernel_height=7,
72 | max_kernel_width=7,
73 | ):
74 | super().__init__()
75 | num_features = 2 * 7 * (max_num_hidden_layers + 1)
76 | self.hypernetwork = nn.Sequential(
77 | # DynamicStatFeaturizer(max_kernel_size, max_num_hidden_layers, max_kernel_height, max_kernel_width),
78 | nn.Linear(num_features, h_size),
79 | nn.ReLU(),
80 | nn.Dropout(p=dropout),
81 | nn.Linear(h_size, h_size),
82 | nn.ReLU(),
83 | nn.Dropout(p=dropout),
84 | nn.Linear(h_size, 1),
85 | )
86 |
87 | def forward(self, batch):
88 | stacked_weights = torch.stack(batch[0], dim=0)
89 | stacked_biases = torch.stack(batch[1], dim=0)
90 | stacked_params = torch.cat([stacked_weights, stacked_biases], dim=-1)
91 | return self.hypernetwork(stacked_params)
92 |
--------------------------------------------------------------------------------
/nn/inr.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import math
3 | from typing import Optional
4 |
5 | import torch
6 | import torch.nn.functional as F
7 | from rff.layers import GaussianEncoding, PositionalEncoding
8 | from torch import nn
9 |
10 |
11 | class Sine(nn.Module):
12 | def __init__(self, w0=1.0):
13 | super().__init__()
14 | self.w0 = w0
15 |
16 | def forward(self, x: torch.Tensor) -> torch.Tensor:
17 | return torch.sin(self.w0 * x)
18 |
19 |
20 | def params_to_tensor(params):
21 | return torch.cat([p.flatten() for p in params]), [p.shape for p in params]
22 |
23 |
24 | def tensor_to_params(tensor, shapes):
25 | params = []
26 | start = 0
27 | for shape in shapes:
28 | size = torch.prod(torch.tensor(shape)).item()
29 | param = tensor[start : start + size].reshape(shape)
30 | params.append(param)
31 | start += size
32 | return tuple(params)
33 |
34 |
35 | def wrap_func(func, shapes):
36 | def wrapped_func(params, *args, **kwargs):
37 | params = tensor_to_params(params, shapes)
38 | return func(params, *args, **kwargs)
39 |
40 | return wrapped_func
41 |
42 |
43 | class Siren(nn.Module):
44 | def __init__(
45 | self,
46 | dim_in,
47 | dim_out,
48 | w0=30.0,
49 | c=6.0,
50 | is_first=False,
51 | use_bias=True,
52 | activation=None,
53 | ):
54 | super().__init__()
55 | self.w0 = w0
56 | self.c = c
57 | self.dim_in = dim_in
58 | self.dim_out = dim_out
59 | self.is_first = is_first
60 |
61 | weight = torch.zeros(dim_out, dim_in)
62 | bias = torch.zeros(dim_out) if use_bias else None
63 | self.init_(weight, bias, c=c, w0=w0)
64 |
65 | self.weight = nn.Parameter(weight)
66 | self.bias = nn.Parameter(bias) if use_bias else None
67 | self.activation = Sine(w0) if activation is None else activation
68 |
69 | def init_(self, weight: torch.Tensor, bias: torch.Tensor, c: float, w0: float):
70 | dim = self.dim_in
71 |
72 | w_std = (1 / dim) if self.is_first else (math.sqrt(c / dim) / w0)
73 | weight.uniform_(-w_std, w_std)
74 |
75 | if bias is not None:
76 | # bias.uniform_(-w_std, w_std)
77 | bias.zero_()
78 |
79 | def forward(self, x: torch.Tensor) -> torch.Tensor:
80 | out = F.linear(x, self.weight, self.bias)
81 | out = self.activation(out)
82 | return out
83 |
84 |
85 | class INR(nn.Module):
86 | def __init__(
87 | self,
88 | in_features: int = 2,
89 | n_layers: int = 3,
90 | hidden_features: int = 32,
91 | out_features: int = 1,
92 | pe_features: Optional[int] = None,
93 | fix_pe=True,
94 | ):
95 | super().__init__()
96 |
97 | if pe_features is not None:
98 | if fix_pe:
99 | self.layers = [PositionalEncoding(sigma=10, m=pe_features)]
100 | encoded_dim = in_features * pe_features * 2
101 | else:
102 | self.layers = [
103 | GaussianEncoding(
104 | sigma=10, input_size=in_features, encoded_size=pe_features
105 | )
106 | ]
107 | encoded_dim = pe_features * 2
108 | self.layers.append(Siren(dim_in=encoded_dim, dim_out=hidden_features))
109 | else:
110 | self.layers = [Siren(dim_in=in_features, dim_out=hidden_features)]
111 | for i in range(n_layers - 2):
112 | self.layers.append(Siren(hidden_features, hidden_features))
113 | self.layers.append(nn.Linear(hidden_features, out_features))
114 | self.seq = nn.Sequential(*self.layers)
115 | self.num_layers = len(self.layers)
116 |
117 | def forward(self, x: torch.Tensor) -> torch.Tensor:
118 | return self.seq(x) + 0.5
119 |
120 |
121 | class INRPerLayer(INR):
122 | def forward(self, x: torch.Tensor) -> torch.Tensor:
123 | nodes = [x]
124 | for layer in self.seq:
125 | nodes.append(layer(nodes[-1]))
126 | nodes[-1] = nodes[-1] + 0.5
127 | return nodes
128 |
129 |
130 | def make_functional(mod, disable_autograd_tracking=False):
131 | params_dict = dict(mod.named_parameters())
132 | params_names = params_dict.keys()
133 | params_values = tuple(params_dict.values())
134 |
135 | stateless_mod = copy.deepcopy(mod)
136 | stateless_mod.to("meta")
137 |
138 | def fmodel(new_params_values, *args, **kwargs):
139 | new_params_dict = {
140 | name: value for name, value in zip(params_names, new_params_values)
141 | }
142 | return torch.func.functional_call(stateless_mod, new_params_dict, args, kwargs)
143 |
144 | if disable_autograd_tracking:
145 | params_values = torch.utils._pytree.tree_map(torch.Tensor.detach, params_values)
146 | return fmodel, params_values
147 |
--------------------------------------------------------------------------------
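A sketch of how the functional helpers above compose: `make_functional` splits the module into a stateless callable plus its parameter values, and `params_to_tensor`/`tensor_to_params` round-trip between a flat vector and the original shapes (the query points are illustrative):

import torch
from nn.inr import INR, make_functional, params_to_tensor, tensor_to_params, wrap_func

inr = INR(in_features=2, n_layers=3, hidden_features=32, out_features=1)
func, params = make_functional(inr)
flat, shapes = params_to_tensor(params)

coords = torch.rand(10, 2) * 2 - 1           # query points in [-1, 1]^2
out = wrap_func(func, shapes)(flat, coords)  # same as func(tensor_to_params(flat, shapes), coords)
--------------------------------------------------------------------------------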
/nn/nfn/__init__.py:
--------------------------------------------------------------------------------
1 | from nn.nfn import common, layers
2 |
--------------------------------------------------------------------------------
/nn/nfn/common/__init__.py:
--------------------------------------------------------------------------------
1 | from nn.nfn.common.data import (
2 | ArraySpec,
3 | NetworkSpec,
4 | WeightSpaceFeatures,
5 | network_spec_from_wsfeat,
6 | params_to_state_dicts,
7 | state_dict_to_tensors,
8 | )
9 |
--------------------------------------------------------------------------------
/nn/nfn/common/data.py:
--------------------------------------------------------------------------------
1 | import collections
2 | from collections import OrderedDict
3 | from dataclasses import dataclass
4 | from typing import List, Tuple
5 |
6 |
7 | @dataclass(frozen=True)
8 | class ArraySpec:
9 | shape: Tuple[int, ...]
10 |
11 |
12 | @dataclass(frozen=True)
13 | class NetworkSpec:
14 | weight_spec: List[ArraySpec]
15 | bias_spec: List[ArraySpec]
16 |
17 | def get_io(self):
18 | # n_in, n_out
19 | return self.weight_spec[0].shape[1], self.weight_spec[-1].shape[0]
20 |
21 | def get_num_params(self):
22 | """Returns the number of parameters in the network."""
23 | num_params = 0
24 | for w, b in zip(self.weight_spec, self.bias_spec):
25 | num_weights = 1
26 | for dim in w.shape:
27 | assert dim != -1
28 | num_weights *= dim
29 | num_biases = 1
30 | for dim in b.shape:
31 | assert dim != -1
32 | num_biases *= dim
33 | num_params += num_weights + num_biases
34 | return num_params
35 |
36 | def __len__(self):
37 | return len(self.weight_spec)
38 |
39 |
40 | class WeightSpaceFeatures(collections.abc.Sequence):
41 | def __init__(self, weights, biases):
42 | # No mutability
43 | if isinstance(weights, list):
44 | weights = tuple(weights)
45 | if isinstance(biases, list):
46 | biases = tuple(biases)
47 | self.weights = weights
48 | self.biases = biases
49 |
50 | def __len__(self):
51 | return len(self.weights)
52 |
53 | def __iter__(self):
54 | return zip(self.weights, self.biases)
55 |
56 | def __getitem__(self, idx):
57 | return (self.weights[idx], self.biases[idx])
58 |
59 | def __add__(self, other):
60 | out_weights = tuple(w1 + w2 for w1, w2 in zip(self.weights, other.weights))
61 | out_biases = tuple(b1 + b2 for b1, b2 in zip(self.biases, other.biases))
62 | return WeightSpaceFeatures(out_weights, out_biases)
63 |
64 | def __mul__(self, other):
65 | if isinstance(other, WeightSpaceFeatures):
66 | weights = tuple(w1 * w2 for w1, w2 in zip(self.weights, other.weights))
67 | biases = tuple(b1 * b2 for b1, b2 in zip(self.biases, other.biases))
68 | return WeightSpaceFeatures(weights, biases)
69 | return self.map(lambda x: x * other)
70 |
71 | def detach(self):
72 | """Returns a copy with detached tensors."""
73 | return WeightSpaceFeatures(
74 | tuple(w.detach() for w in self.weights),
75 | tuple(b.detach() for b in self.biases),
76 | )
77 |
78 | def map(self, func):
79 | """Applies func to each weight and bias tensor."""
80 | return WeightSpaceFeatures(
81 | tuple(func(w) for w in self.weights), tuple(func(b) for b in self.biases)
82 | )
83 |
84 | def to(self, device):
85 | """Moves all tensors to device."""
86 | return WeightSpaceFeatures(
87 | tuple(w.to(device) for w in self.weights),
88 | tuple(b.to(device) for b in self.biases),
89 | )
90 |
91 | @classmethod
92 | def from_zipped(cls, weight_and_biases):
93 | """Converts a list of (weights, biases) into a WeightSpaceFeatures object."""
94 | weights, biases = zip(*weight_and_biases)
95 | return cls(weights, biases)
96 |
97 |
98 | def state_dict_to_tensors(state_dict):
99 | """Converts a state dict into two lists of equal length:
100 | 1. list of weight tensors
101 |     2. list of bias tensors (a bias entry is expected after every weight)
102 | Assumes the state_dict key order is [0.weight, 0.bias, 1.weight, 1.bias, ...]
103 | """
104 | weights, biases = [], []
105 | keys = list(state_dict.keys())
106 | i = 0
107 | while i < len(keys):
108 | weights.append(state_dict[keys[i]][None])
109 | i += 1
110 | assert keys[i].endswith("bias")
111 | biases.append(state_dict[keys[i]][None])
112 | i += 1
113 | return weights, biases
114 |
115 |
116 | def params_to_state_dicts(keys, wsfeat: WeightSpaceFeatures) -> List[OrderedDict]:
117 |     """Converts batched weight and bias tensors into one state dict per batch element.
118 | Assumes the state_dict key order is [0.weight, 0.bias, 1.weight, 1.bias, ...]
119 | """
120 | batch_size = wsfeat.weights[0].shape[0]
121 | assert wsfeat.weights[0].shape[1] == 1
122 | state_dicts = [OrderedDict() for _ in range(batch_size)]
123 | layer_idx = 0
124 | while layer_idx < len(keys):
125 | for batch_idx in range(batch_size):
126 | state_dicts[batch_idx][keys[layer_idx]] = wsfeat.weights[layer_idx // 2][
127 | batch_idx
128 | ].squeeze(0)
129 | layer_idx += 1
130 | for batch_idx in range(batch_size):
131 | state_dicts[batch_idx][keys[layer_idx]] = wsfeat.biases[layer_idx // 2][
132 | batch_idx
133 | ].squeeze(0)
134 | layer_idx += 1
135 | return state_dicts
136 |
137 |
138 | def network_spec_from_wsfeat(
139 | wsfeat: WeightSpaceFeatures, set_all_dims=False
140 | ) -> NetworkSpec:
141 | assert len(wsfeat.weights) == len(wsfeat.biases)
142 | weight_specs = []
143 | bias_specs = []
144 | for i, (weight, bias) in enumerate(zip(wsfeat.weights, wsfeat.biases)):
145 | # -1 means the dimension has symmetry, hence summed/broadcast over.
146 |         # Recall that weight has two leading dimensions (BS, channels, ...)
147 | if weight.dim() == 4:
148 | weight_shape = [-1, -1]
149 | elif weight.dim() == 6:
150 | weight_shape = [-1, -1, weight.shape[-2], weight.shape[-1]]
151 | else:
152 | raise ValueError(f"Unsupported weight dim: {weight.dim()}")
153 | if i == 0 or set_all_dims:
154 | weight_shape[1] = weight.shape[3]
155 | if i == len(wsfeat) - 1 or set_all_dims:
156 | weight_shape[0] = weight.shape[2]
157 | weight_specs.append(ArraySpec(tuple(weight_shape)))
158 | bias_shape = (-1,)
159 | if i == len(wsfeat) - 1 or set_all_dims:
160 | bias_shape = (bias.shape[-1],)
161 | bias_specs.append(ArraySpec(bias_shape))
162 | return NetworkSpec(weight_specs, bias_specs)
163 |
--------------------------------------------------------------------------------
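A sketch connecting these helpers (the two-layer MLP is illustrative): `state_dict_to_tensors` adds a leading channel dimension, `WeightSpaceFeatures` additionally needs a batch dimension, and the resulting spec recovers the I/O sizes and parameter count:

import torch
from torch import nn
from nn.nfn.common.data import (
    WeightSpaceFeatures,
    network_spec_from_wsfeat,
    state_dict_to_tensors,
)

mlp = nn.Sequential(nn.Linear(2, 32), nn.Linear(32, 1))
weights, biases = state_dict_to_tensors(mlp.state_dict())  # each tensor gains a channel dim
wsfeat = WeightSpaceFeatures(
    [w.unsqueeze(0) for w in weights],  # (B=1, C=1, n_out, n_in)
    [b.unsqueeze(0) for b in biases],   # (B=1, C=1, n_out)
)
spec = network_spec_from_wsfeat(wsfeat, set_all_dims=True)
print(spec.get_io())          # (2, 1)
print(spec.get_num_params())  # 129 = 2*32 + 32 + 32*1 + 1
--------------------------------------------------------------------------------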
/nn/nfn/layers/__init__.py:
--------------------------------------------------------------------------------
1 | from nn.nfn.layers.encoding import GaussianFourierFeatureTransform, IOSinusoidalEncoding
2 | from nn.nfn.layers.layers import HNPLinear, HNPPool, NPLinear, NPPool, Pointwise
3 | from nn.nfn.layers.misc_layers import (
4 | FlattenWeights,
5 | LearnedScale,
6 | ResBlock,
7 | StatFeaturizer,
8 | TupleOp,
9 | UnflattenWeights,
10 | )
11 | from nn.nfn.layers.regularize import ChannelDropout, ParamLayerNorm, SimpleLayerNorm
12 |
--------------------------------------------------------------------------------
/nn/nfn/layers/encoding.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | from torch import nn
5 |
6 | from nn.nfn.common import NetworkSpec, WeightSpaceFeatures
7 |
8 |
9 | class GaussianFourierFeatureTransform(nn.Module):
10 | """
11 | Given an input of size [batches, num_input_channels, ...],
12 | returns a tensor of size [batches, mapping_size*2, ...].
13 | """
14 |
15 | def __init__(self, in_channels, mapping_size=256, scale=10):
16 | super().__init__()
17 | # del network_spec
18 | self._num_input_channels = in_channels
19 | self._mapping_size = mapping_size
20 | self.out_channels = mapping_size * 2
21 | self.register_buffer("_B", torch.randn((in_channels, mapping_size)) * scale)
22 |
23 | def encode_tensor(self, x):
24 | assert len(x.shape) >= 3
25 | # Put channels dimension last.
26 | x = (x.transpose(1, -1) @ self._B).transpose(1, -1)
27 | x = 2 * math.pi * x
28 | return torch.cat([torch.sin(x), torch.cos(x)], dim=1)
29 |
30 | def forward(self, wsfeat):
31 | out_weights, out_biases = [], []
32 | for weight, bias in wsfeat:
33 | out_weights.append(self.encode_tensor(weight))
34 | out_biases.append(self.encode_tensor(bias))
35 | return WeightSpaceFeatures(out_weights, out_biases)
36 |
37 |
38 | def fourier_encode(x, max_freq, num_bands=4):
39 | x = x.unsqueeze(-1)
40 | device, dtype, orig_x = x.device, x.dtype, x
41 |
42 | scales = torch.linspace(1.0, max_freq / 2, num_bands, device=device, dtype=dtype)
43 | scales = scales[(*((None,) * (len(x.shape) - 1)), Ellipsis)]
44 |
45 | x = x * scales * math.pi
46 | x = torch.cat([x.sin(), x.cos()], dim=-1)
47 | x = torch.cat((x, orig_x), dim=-1)
48 | return x
49 |
50 |
51 | class IOSinusoidalEncoding(nn.Module):
52 | def __init__(self, layer_layout, max_freq=10, num_bands=6, enc_layers=True):
53 | super().__init__()
54 | self.layer_layout = layer_layout
55 | self.max_freq = max_freq
56 | self.num_bands = num_bands
57 | self.enc_layers = enc_layers
58 | self.n_in, self.n_out = layer_layout[0], layer_layout[-1]
59 |
60 | def forward(self, wsfeat: WeightSpaceFeatures):
61 | device, dtype = wsfeat.weights[0].device, wsfeat.weights[0].dtype
62 | L = len(self.layer_layout) - 1
63 | layernum = torch.linspace(-1.0, 1.0, steps=L, device=device, dtype=dtype)
64 | if self.enc_layers:
65 | layer_enc = fourier_encode(
66 | layernum, self.max_freq, self.num_bands
67 | ) # (L, 2 * num_bands + 1)
68 | else:
69 | layer_enc = torch.zeros(
70 | (L, 2 * self.num_bands + 1), device=device, dtype=dtype
71 | )
72 | inpnum = torch.linspace(-1.0, 1.0, steps=self.n_in, device=device, dtype=dtype)
73 | inp_enc = fourier_encode(
74 | inpnum, self.max_freq, self.num_bands
75 | ) # (n_in, 2 * num_bands + 1)
76 | outnum = torch.linspace(-1.0, 1.0, steps=self.n_out, device=device, dtype=dtype)
77 | out_enc = fourier_encode(
78 | outnum, self.max_freq, self.num_bands
79 | ) # (n_out, 2 * num_bands + 1)
80 |
81 | d = 2 * self.num_bands + 1
82 |
83 | out_weights, out_biases = [], []
84 | for i, (weight, bias) in enumerate(wsfeat):
85 | b, _, *axes = weight.shape
86 | enc_i = layer_enc[i].unsqueeze(0)[..., None, None]
87 | for _ in axes[2:]:
88 | enc_i = enc_i.unsqueeze(-1)
89 | enc_i = enc_i.expand(b, d, *axes) # (B, d, n_row, n_col, ...)
90 | bias_enc_i = layer_enc[i][None, :, None].expand(
91 | b, d, bias.shape[-1]
92 | ) # (B, d, n_row)
93 | if i == 0:
94 | # weight has shape (B, c_in, n_out, n_in)
95 | inp_enc_i = (
96 | inp_enc.transpose(0, 1).unsqueeze(0).unsqueeze(-2)
97 | ) # (1, d, 1, n_col)
98 | for _ in axes[2:]:
99 | inp_enc_i = inp_enc_i.unsqueeze(-1)
100 | enc_i = enc_i + inp_enc_i
101 | if i == len(wsfeat) - 1:
102 | out_enc_i = (
103 | out_enc.transpose(0, 1).unsqueeze(0).unsqueeze(-1)
104 | ) # (1, d, n_row, 1)
105 | for _ in axes[2:]:
106 | out_enc_i = out_enc_i.unsqueeze(-1)
107 | enc_i = enc_i + out_enc_i
108 | bias_enc_i = bias_enc_i + out_enc.transpose(0, 1).unsqueeze(0)
109 | out_weights.append(torch.cat([weight, enc_i], dim=1))
110 | out_biases.append(torch.cat([bias, bias_enc_i], dim=1))
111 | return WeightSpaceFeatures(out_weights, out_biases)
112 |
113 | def num_out_chan(self, in_chan):
114 | return in_chan + (2 * self.num_bands + 1)
115 |
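As a quick shape check on the encoding above (a sketch, assuming the module path `nn.nfn.layers.encoding`): `fourier_encode` concatenates a sin and a cos per band plus the raw value, giving `2 * num_bands + 1` features per position.

```py
import torch

from nn.nfn.layers.encoding import fourier_encode

pos = torch.linspace(-1.0, 1.0, steps=8)
enc = fourier_encode(pos, max_freq=10, num_bands=6)
assert enc.shape == (8, 2 * 6 + 1)  # sin + cos per band, plus the raw value
```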
--------------------------------------------------------------------------------
/nn/nfn/layers/layer_utils.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | from einops import rearrange
4 | from torch import nn
5 |
6 | from nn.nfn.common import WeightSpaceFeatures
7 |
8 |
9 | def set_init_(*layers):
10 | in_chan = 0
11 | for layer in layers:
12 | if isinstance(layer, (nn.Conv2d, nn.Conv1d)):
13 | in_chan += layer.in_channels
14 | elif isinstance(layer, nn.Linear):
15 | in_chan += layer.in_features
16 | else:
17 | raise NotImplementedError(f"Unknown layer type {type(layer)}")
18 | bd = math.sqrt(1 / in_chan)
19 | for layer in layers:
20 | nn.init.uniform_(layer.weight, -bd, bd)
21 | if layer.bias is not None:
22 | nn.init.uniform_(layer.bias, -bd, bd)
23 |
24 |
25 | def shape_wsfeat_symmetry(params, weight_shapes):
26 | """Reshape so last 2 dims have symmetry, channel dims have all nonsymmetry.
27 | E.g., for conv weights we reshape (B, C, out, in, h, w) -> (B, C * h * w, out, in)
28 | """
29 | weights, bias = params.weights, params.biases
30 | reshaped_weights = []
31 | for weight, weight_spec in zip(weights, weight_shapes):
32 | if len(weight_spec) == 2: # mlp weight matrix:
33 | reshaped_weights.append(weight)
34 | else:
35 | reshaped_weights.append(rearrange(weight, "b c o i h w -> b (c h w) o i"))
36 | return WeightSpaceFeatures(reshaped_weights, bias)
37 |
38 |
39 | def unshape_wsfeat_symmetry(params, weight_shapes):
40 | """Reverse shape_params_symmetry"""
41 | weights, bias = params.weights, params.biases
42 | unreshaped_weights = []
43 | for weight, weight_spec in zip(weights, weight_shapes):
44 | if len(weight_spec) == 2: # mlp weight matrix:
45 | unreshaped_weights.append(weight)
46 | else:
47 | _, _, h, w = weight_spec
48 | unreshaped_weights.append(
49 | rearrange(weight, "b (c h w) o i -> b c o i h w", h=h, w=w)
50 | )
51 | return WeightSpaceFeatures(unreshaped_weights, bias)
52 |
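A round-trip sketch for the two helpers above; note that this variant takes plain shape tuples (`weight_shapes`), not a `NetworkSpec`:

```py
import torch

from nn.nfn.common import WeightSpaceFeatures
from nn.nfn.layers.layer_utils import (
    shape_wsfeat_symmetry,
    unshape_wsfeat_symmetry,
)

# One conv layer: weights (B, C, out, in, h, w) with bias (B, C, out).
w, b = torch.randn(2, 4, 16, 3, 5, 5), torch.randn(2, 4, 16)
wsfeat = WeightSpaceFeatures([w], [b])
weight_shapes = [(16, 3, 5, 5)]  # len != 2, so treated as a conv kernel

flat = shape_wsfeat_symmetry(wsfeat, weight_shapes)
assert flat.weights[0].shape == (2, 4 * 5 * 5, 16, 3)  # channels absorb h, w
restored = unshape_wsfeat_symmetry(flat, weight_shapes)
assert torch.equal(restored.weights[0], w)
```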
--------------------------------------------------------------------------------
/nn/nfn/layers/misc_layers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import nn
4 |
5 | from nn.nfn.common import NetworkSpec, WeightSpaceFeatures
6 | from nn.nfn.layers.layer_utils import shape_wsfeat_symmetry
7 |
8 |
9 | class FlattenWeights(nn.Module):
10 | def __init__(self, network_spec):
11 | super().__init__()
12 | self.network_spec = network_spec
13 |
14 | def forward(self, wsfeat):
15 | wsfeat = shape_wsfeat_symmetry(wsfeat, self.network_spec)
16 | outs = []
17 | for w, b in wsfeat:
18 | outs.append(torch.flatten(w, start_dim=2).transpose(1, 2))
19 | outs.append(b.transpose(1, 2))
20 | return torch.cat(outs, dim=1) # (B, N, C)
21 |
22 |
23 | class UnflattenWeights(nn.Module):
24 | def __init__(self, network_spec: NetworkSpec):
25 | super().__init__()
26 | self.network_spec = network_spec
27 |
28 | def forward(self, x: torch.Tensor) -> WeightSpaceFeatures:
29 | # x.shape == (bs, num weights and biases)
30 | out_weights, out_biases = [], []
31 | curr_idx = 0
32 | for weight_spec, bias_spec in zip(
33 | self.network_spec.weight_spec, self.network_spec.bias_spec
34 | ):
35 | num_wts = np.prod(weight_spec.shape)
36 | # reshape to (bs, 1, *weight_spec.shape) where 1 is channels.
37 | wt = (
38 | x[:, curr_idx : curr_idx + num_wts]
39 | .view(-1, *weight_spec.shape)
40 | .unsqueeze(1)
41 | )
42 | out_weights.append(wt)
43 | curr_idx += num_wts
44 | num_bs = np.prod(bias_spec.shape)
45 | bs = (
46 | x[:, curr_idx : curr_idx + num_bs]
47 | .view(-1, *bias_spec.shape)
48 | .unsqueeze(1)
49 | )
50 | out_biases.append(bs)
51 | curr_idx += num_bs
52 | return WeightSpaceFeatures(out_weights, out_biases)
53 |
54 |
55 | class LearnedScale(nn.Module):
56 | def __init__(self, layer_layout, init_scale):
57 | super().__init__()
58 | self.weight_scales = nn.ParameterList()
59 | self.bias_scales = nn.ParameterList()
60 | for _ in range(len(layer_layout) - 1):
61 | self.weight_scales.append(
62 | nn.Parameter(torch.tensor(init_scale, dtype=torch.float32))
63 | )
64 | self.bias_scales.append(
65 | nn.Parameter(torch.tensor(init_scale, dtype=torch.float32))
66 | )
67 |
68 | def forward(self, wsfeat: WeightSpaceFeatures) -> WeightSpaceFeatures:
69 | out_weights, out_biases = [], []
70 | for i, (weight, bias) in enumerate(zip(wsfeat.weights, wsfeat.biases)):
71 | out_weights.append(weight * self.weight_scales[i])
72 | out_biases.append(bias * self.bias_scales[i])
73 | return WeightSpaceFeatures(out_weights, out_biases)
74 |
75 |
76 | class ResBlock(nn.Module):
77 | def __init__(self, base_layer, activation, dropout, norm):
78 | super().__init__()
79 | self.base_layer = base_layer
80 | self.activation = activation
81 | self.dropout = None
82 | if dropout > 0:
83 | self.dropout = TupleOp(nn.Dropout(dropout))
84 | self.norm = norm
85 |
86 | def forward(self, x: WeightSpaceFeatures) -> WeightSpaceFeatures:
87 | res = self.activation(self.base_layer(self.norm(x)))
88 | if self.dropout is not None:
89 | res = self.dropout(res)
90 | return x + res
91 |
92 |
93 | class StatFeaturizer(nn.Module):
94 | def forward(self, wsfeat: WeightSpaceFeatures) -> torch.Tensor:
95 | out = []
96 | for weight, bias in wsfeat:
97 | out.append(self.compute_stats(weight))
98 | out.append(self.compute_stats(bias))
99 | return torch.cat(out, dim=-1)
100 |
101 | def compute_stats(self, tensor: torch.Tensor) -> torch.Tensor:
102 | """Computes the statistics of the given tensor."""
103 | tensor = torch.flatten(tensor, start_dim=2) # (B, C, H*W)
104 | mean = tensor.mean(-1) # (B, C)
105 | var = tensor.var(-1) # (B, C)
106 | q = torch.tensor([0.0, 0.25, 0.5, 0.75, 1.0]).to(tensor.device)
107 | quantiles = torch.quantile(tensor, q, dim=-1) # (5, B, C)
108 | return torch.stack([mean, var, *quantiles], dim=-1) # (B, C, 7)
109 |
110 | @staticmethod
111 | def get_num_outs(network_spec):
112 | """Returns the number of outputs of the StatFeaturizer layer."""
113 | return 2 * len(network_spec) * 7
114 |
115 |
116 | class TupleOp(nn.Module):
117 | def __init__(self, op):
118 | super().__init__()
119 | self.op = op
120 |
121 | def forward(self, wsfeat: WeightSpaceFeatures) -> WeightSpaceFeatures:
122 | out_weights = [self.op(w) for w in wsfeat.weights]
123 | out_bias = [self.op(b) for b in wsfeat.biases]
124 | return WeightSpaceFeatures(out_weights, out_bias)
125 |
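`StatFeaturizer` reduces every weight and bias tensor to 7 statistics (mean, variance, and 5 quantiles), so a sketch like the following, assuming a 2 → 32 → 3 MLP in weight-space form, should hold:

```py
import torch

from nn.nfn.common import WeightSpaceFeatures
from nn.nfn.layers import StatFeaturizer

wsfeat = WeightSpaceFeatures(
    [torch.randn(2, 1, 32, 2), torch.randn(2, 1, 3, 32)],
    [torch.randn(2, 1, 32), torch.randn(2, 1, 3)],
)
out = StatFeaturizer()(wsfeat)
# 7 statistics per tensor, 2 tensors (weight + bias) per layer, 2 layers.
assert out.shape == (2, 1, 2 * 2 * 7)
```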
--------------------------------------------------------------------------------
/nn/nfn/layers/regularize.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from einops import rearrange
3 | from torch import nn
4 |
5 | from nn.nfn.common import NetworkSpec, WeightSpaceFeatures
6 | from nn.nfn.layers.layer_utils import shape_wsfeat_symmetry, unshape_wsfeat_symmetry
7 |
8 |
9 | class ChannelDropout(nn.Module):
10 | def __init__(self, dropout):
11 | super().__init__()
12 | self.dropout = dropout
13 | self.matrix_dropout = nn.Dropout2d(dropout)
14 | self.bias_dropout = nn.Dropout(dropout)
15 |
16 | def forward(self, x: WeightSpaceFeatures) -> WeightSpaceFeatures:
17 | weights = [self.process_matrix(w) for w in x.weights]
18 | bias = [self.bias_dropout(b) for b in x.biases]
19 | return WeightSpaceFeatures(weights, bias)
20 |
21 | def process_matrix(self, mat):
22 | shape = mat.shape
23 | is_conv = len(shape) > 4
24 | if is_conv:
25 | _, _, _, _, h, w = shape
26 | mat = rearrange(mat, "b c o i h w -> b (c h w) o i")
27 | mat = self.matrix_dropout(mat)
28 | if is_conv:
29 | mat = rearrange(mat, "b (c h w) o i -> b c o i h w", h=h, w=w)
30 | return mat
31 |
32 |
33 | class SimpleLayerNorm(nn.Module):
34 | def __init__(self, network_spec, channels):
35 | super().__init__()
36 | self.network_spec = network_spec
37 | self.channels = channels
38 | for i in range(len(network_spec)):
39 | eff_channels = int(
40 | channels * np.prod(network_spec.weight_spec[i].shape[2:])
41 | )
42 | self.add_module(
43 | f"norm{i}_w", nn.LayerNorm(normalized_shape=(eff_channels,))
44 | )
45 | self.add_module(f"norm{i}_v", nn.LayerNorm(normalized_shape=(channels,)))
46 |
47 | def forward(self, wsfeat: WeightSpaceFeatures) -> WeightSpaceFeatures:
48 | wsfeat = shape_wsfeat_symmetry(wsfeat, self.network_spec)
49 | in_weights, in_biases = [], []
50 | for weight, bias in wsfeat:
51 | in_weights.append(weight.transpose(-3, -1))
52 | in_biases.append(bias.transpose(-1, -2))
53 | out_weights, out_biases = [], []
54 | for i, (weight, bias) in enumerate(zip(in_weights, in_biases)):
55 | w_norm = getattr(self, f"norm{i}_w")
56 | v_norm = getattr(self, f"norm{i}_v")
57 | out_weights.append(w_norm(weight))
58 | out_biases.append(v_norm(bias))
59 | out_weights = [w.transpose(-3, -1) for w in out_weights]
60 | out_biases = [b.transpose(-1, -2) for b in out_biases]
61 | return unshape_wsfeat_symmetry(
62 | WeightSpaceFeatures(out_weights, out_biases), self.network_spec
63 | )
64 |
65 | def __repr__(self):
66 | return f"SimpleLayerNorm(channels={self.channels})"
67 |
68 |
69 | class ParamLayerNorm(nn.Module):
70 | def __init__(self, network_spec: NetworkSpec, channels):
71 | # TODO: This doesn't work for convs yet.
72 | super().__init__()
73 | self.n_in, self.n_out = network_spec.get_io()
74 | self.channels = channels
75 | for i in range(len(network_spec)):
76 | if i == 0:
77 | w_shape = (channels, self.n_in)
78 | v_shape = (channels,)
79 | elif i == len(network_spec) - 1:
80 | w_shape = (self.n_out, channels)
81 | v_shape = (channels, self.n_out)
82 | else:
83 | w_shape = (channels,)
84 | v_shape = (channels,)
85 | self.add_module(f"norm{i}_w", nn.LayerNorm(normalized_shape=w_shape))
86 | self.add_module(f"norm{i}_v", nn.LayerNorm(normalized_shape=v_shape))
87 |
88 | def forward(self, wsfeat: WeightSpaceFeatures) -> WeightSpaceFeatures:
89 | out_weights, out_biases = [], []
90 | for i, (weight, bias) in enumerate(wsfeat):
91 | w_norm = getattr(self, f"norm{i}_w")
92 | v_norm = getattr(self, f"norm{i}_v")
93 | if i == 0:
94 | out_weights.append(w_norm(weight.transpose(-3, -2)).transpose(-3, -2))
95 | out_biases.append(v_norm(bias.transpose(-1, -2)).transpose(-1, -2))
96 | elif i == len(wsfeat) - 1:
97 | out_weights.append(w_norm(weight.transpose(-3, -1)).transpose(-3, -1))
98 | out_biases.append(v_norm(bias))
99 | else:
100 | out_weights.append(w_norm(weight.transpose(-3, -1)).transpose(-3, -1))
101 | out_biases.append(v_norm(bias.transpose(-1, -2)).transpose(-1, -2))
102 | return WeightSpaceFeatures(out_weights, out_biases)
103 |
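`ChannelDropout` above drops whole channels of each weight matrix (via `nn.Dropout2d`) but applies ordinary element-wise dropout to biases; a minimal sketch:

```py
import torch

from nn.nfn.common import WeightSpaceFeatures
from nn.nfn.layers import ChannelDropout

drop = ChannelDropout(0.5).train()
wsfeat = WeightSpaceFeatures([torch.randn(2, 8, 32, 2)], [torch.randn(2, 8, 32)])
out = drop(wsfeat)
# Shapes are preserved; entire weight channels are zeroed jointly.
assert out.weights[0].shape == (2, 8, 32, 2)
```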
--------------------------------------------------------------------------------
/nn/original_nfn/__init__.py:
--------------------------------------------------------------------------------
1 | from nn.original_nfn import common, layers
2 |
--------------------------------------------------------------------------------
/nn/original_nfn/common/__init__.py:
--------------------------------------------------------------------------------
1 | from nn.original_nfn.common.data import (
2 | ArraySpec,
3 | NetworkSpec,
4 | WeightSpaceFeatures,
5 | network_spec_from_wsfeat,
6 | params_to_func_params,
7 | params_to_state_dicts,
8 | state_dict_to_tensors,
9 | )
10 |
--------------------------------------------------------------------------------
/nn/original_nfn/common/data.py:
--------------------------------------------------------------------------------
1 | import collections
2 | from collections import OrderedDict
3 | from dataclasses import dataclass
4 | from typing import List, Tuple
5 |
6 |
7 | @dataclass(frozen=True)
8 | class ArraySpec:
9 | shape: Tuple[int, ...]
10 |
11 |
12 | @dataclass(frozen=True)
13 | class NetworkSpec:
14 | weight_spec: List[ArraySpec]
15 | bias_spec: List[ArraySpec]
16 |
17 | def get_io(self):
18 | # n_in, n_out
19 | return self.weight_spec[0].shape[1], self.weight_spec[-1].shape[0]
20 |
21 | def get_num_params(self):
22 | """Returns the number of parameters in the network."""
23 | num_params = 0
24 | for w, b in zip(self.weight_spec, self.bias_spec):
25 | num_weights = 1
26 | for dim in w.shape:
27 | assert dim != -1
28 | num_weights *= dim
29 | num_biases = 1
30 | for dim in b.shape:
31 | assert dim != -1
32 | num_biases *= dim
33 | num_params += num_weights + num_biases
34 | return num_params
35 |
36 | def __len__(self):
37 | return len(self.weight_spec)
38 |
39 |
40 | class WeightSpaceFeatures(collections.abc.Sequence):
41 | def __init__(self, weights, biases):
42 | # No mutability
43 | if isinstance(weights, list):
44 | weights = tuple(weights)
45 | if isinstance(biases, list):
46 | biases = tuple(biases)
47 | self.weights = weights
48 | self.biases = biases
49 |
50 | def __len__(self):
51 | return len(self.weights)
52 |
53 | def __iter__(self):
54 | return zip(self.weights, self.biases)
55 |
56 | def __getitem__(self, idx):
57 | return (self.weights[idx], self.biases[idx])
58 |
59 | def __add__(self, other):
60 | out_weights = tuple(w1 + w2 for w1, w2 in zip(self.weights, other.weights))
61 | out_biases = tuple(b1 + b2 for b1, b2 in zip(self.biases, other.biases))
62 | return WeightSpaceFeatures(out_weights, out_biases)
63 |
64 | def __mul__(self, other):
65 | if isinstance(other, WeightSpaceFeatures):
66 | weights = tuple(w1 * w2 for w1, w2 in zip(self.weights, other.weights))
67 | biases = tuple(b1 * b2 for b1, b2 in zip(self.biases, other.biases))
68 | return WeightSpaceFeatures(weights, biases)
69 | return self.map(lambda x: x * other)
70 |
71 | def detach(self):
72 | """Returns a copy with detached tensors."""
73 | return WeightSpaceFeatures(
74 | tuple(w.detach() for w in self.weights),
75 | tuple(b.detach() for b in self.biases),
76 | )
77 |
78 | def map(self, func):
79 | """Applies func to each weight and bias tensor."""
80 | return WeightSpaceFeatures(
81 | tuple(func(w) for w in self.weights), tuple(func(b) for b in self.biases)
82 | )
83 |
84 | def to(self, device):
85 | """Moves all tensors to device."""
86 | return WeightSpaceFeatures(
87 | tuple(w.to(device, non_blocking=True) for w in self.weights),
88 | tuple(b.to(device, non_blocking=True) for b in self.biases),
89 | )
90 |
91 | @classmethod
92 | def from_zipped(cls, weight_and_biases):
93 | """Converts a list of (weights, biases) into a WeightSpaceFeatures object."""
94 | weights, biases = zip(*weight_and_biases)
95 | return cls(weights, biases)
96 |
97 |
98 | def state_dict_to_tensors(state_dict):
99 | """Converts a state dict into two lists of equal length:
100 | 1. list of weight tensors, each with a leading singleton dim
101 | 2. list of bias tensors, each with a leading singleton dim
102 | Assumes the state_dict key order is [0.weight, 0.bias, 1.weight, 1.bias, ...]
103 | """
104 | weights, biases = [], []
105 | keys = list(state_dict.keys())
106 | i = 0
107 | while i < len(keys):
108 | weights.append(state_dict[keys[i]][None])
109 | i += 1
110 | assert keys[i].endswith("bias")
111 | biases.append(state_dict[keys[i]][None])
112 | i += 1
113 | return weights, biases
114 |
115 |
116 | def params_to_state_dicts(keys, wsfeat: WeightSpaceFeatures) -> List[OrderedDict]:
117 | """Converts a list of weight tensors and a list of biases into a state dict.
118 | Assumes the state_dict key order is [0.weight, 0.bias, 1.weight, 1.bias, ...]
119 | """
120 | batch_size = wsfeat.weights[0].shape[0]
121 | assert wsfeat.weights[0].shape[1] == 1
122 | state_dicts = [OrderedDict() for _ in range(batch_size)]
123 | layer_idx = 0
124 | while layer_idx < len(keys):
125 | for batch_idx in range(batch_size):
126 | state_dicts[batch_idx][keys[layer_idx]] = wsfeat.weights[layer_idx // 2][
127 | batch_idx
128 | ].squeeze(0)
129 | layer_idx += 1
130 | for batch_idx in range(batch_size):
131 | state_dicts[batch_idx][keys[layer_idx]] = wsfeat.biases[layer_idx // 2][
132 | batch_idx
133 | ].squeeze(0)
134 | layer_idx += 1
135 | return state_dicts
136 |
137 |
138 | def network_spec_from_wsfeat(
139 | wsfeat: WeightSpaceFeatures, set_all_dims=False
140 | ) -> NetworkSpec:
141 | assert len(wsfeat.weights) == len(wsfeat.biases)
142 | weight_specs = []
143 | bias_specs = []
144 | for i, (weight, bias) in enumerate(zip(wsfeat.weights, wsfeat.biases)):
145 | # -1 means the dimension has symmetry, hence is summed/broadcast over.
146 | # Recall that weight has two leading dimensions: (BS, channels, ...).
147 | if weight.dim() == 4:
148 | weight_shape = [-1, -1]
149 | elif weight.dim() == 6:
150 | weight_shape = [-1, -1, weight.shape[-2], weight.shape[-1]]
151 | else:
152 | raise ValueError(f"Unsupported weight dim: {weight.dim()}")
153 | if i == 0 or set_all_dims:
154 | weight_shape[1] = weight.shape[3]
155 | if i == len(wsfeat) - 1 or set_all_dims:
156 | weight_shape[0] = weight.shape[2]
157 | weight_specs.append(ArraySpec(tuple(weight_shape)))
158 | bias_shape = (-1,)
159 | if i == len(wsfeat) - 1 or set_all_dims:
160 | bias_shape = (bias.shape[-1],)
161 | bias_specs.append(ArraySpec(bias_shape))
162 | return NetworkSpec(weight_specs, bias_specs)
163 |
164 |
165 | def params_to_func_params(params: WeightSpaceFeatures):
166 | """Convert our WeightSpaceFeatures object to a tuple of parameters for the functional model."""
167 | out_params = []
168 | for weight, bias in params:
169 | if weight.shape[1] == 1:
170 | weight, bias = weight.squeeze(1), bias.squeeze(1)
171 | out_params.append(weight)
172 | out_params.append(bias)
173 | return tuple(out_params)
174 |
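A round-trip sketch for the helpers above, assuming a plain `nn.Sequential` of `Linear` layers, whose state-dict keys follow the `0.weight, 0.bias, ...` order the docstrings require:

```py
import torch
from torch import nn

from nn.original_nfn.common import (
    WeightSpaceFeatures,
    params_to_state_dicts,
    state_dict_to_tensors,
)

mlp = nn.Sequential(nn.Linear(2, 32), nn.Linear(32, 3))
sd = mlp.state_dict()
weights, biases = state_dict_to_tensors(sd)  # adds a leading singleton dim
# Add a batch dim so tensors are (B, 1, ...), as params_to_state_dicts expects.
wsfeat = WeightSpaceFeatures([w[None] for w in weights], [b[None] for b in biases])
restored = params_to_state_dicts(list(sd.keys()), wsfeat)[0]
assert all(torch.equal(sd[k], restored[k]) for k in sd)
```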
--------------------------------------------------------------------------------
/nn/original_nfn/layers/__init__.py:
--------------------------------------------------------------------------------
1 | from nn.original_nfn.layers.encoding import (
2 | GaussianFourierFeatureTransform,
3 | IOSinusoidalEncoding,
4 | LearnedPosEmbedding,
5 | )
6 | from nn.original_nfn.layers.layers import (
7 | ChannelLinear,
8 | HNPLinear,
9 | HNPPool,
10 | NPAttention,
11 | NPLinear,
12 | NPPool,
13 | Pointwise,
14 | )
15 | from nn.original_nfn.layers.misc_layers import (
16 | CrossAttnDecoder,
17 | CrossAttnEncoder,
18 | FlattenWeights,
19 | LearnedScale,
20 | ResBlock,
21 | StatFeaturizer,
22 | TupleOp,
23 | UnflattenWeights,
24 | )
25 | from nn.original_nfn.layers.regularize import (
26 | ChannelDropout,
27 | ChannelLayerNorm,
28 | ParamLayerNorm,
29 | SimpleLayerNorm,
30 | )
31 |
--------------------------------------------------------------------------------
/nn/original_nfn/layers/layer_utils.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | from einops import rearrange
4 | from torch import nn
5 |
6 | from nn.original_nfn.common import WeightSpaceFeatures
7 |
8 |
9 | def set_init_(*layers, init_type="pytorch_default"):
10 | in_chan = 0
11 | for layer in layers:
12 | if isinstance(layer, (nn.Conv2d, nn.Conv1d)):
13 | in_chan += layer.in_channels
14 | elif isinstance(layer, nn.Linear):
15 | in_chan += layer.in_features
16 | else:
17 | raise NotImplementedError(f"Unknown layer type {type(layer)}")
18 | if init_type == "pytorch_default":
19 | bd = math.sqrt(1 / in_chan)
20 | for layer in layers:
21 | nn.init.uniform_(layer.weight, -bd, bd)
22 | if layer.bias is not None:
23 | nn.init.uniform_(layer.bias, -bd, bd)
24 | elif init_type == "kaiming_normal":
25 | std = math.sqrt(2 / in_chan)
26 | for layer in layers:
27 | nn.init.normal_(layer.weight, 0, std)
28 | if layer.bias is not None: layer.bias.data.zero_()
29 | else:
30 | raise NotImplementedError(f"Unknown init type {init_type}.")
31 |
32 |
33 | def shape_wsfeat_symmetry(params, network_spec):
34 | """Reshape so last 2 dims have symmetry, channel dims have all nonsymmetry.
35 | E.g., for conv weights we reshape (B, C, out, in, h, w) -> (B, C * h * w, out, in)
36 | """
37 | weights, bias = params.weights, params.biases
38 | reshaped_weights = []
39 | for weight, weight_spec in zip(weights, network_spec.weight_spec):
40 | if len(weight_spec.shape) == 2: # mlp weight matrix:
41 | reshaped_weights.append(weight)
42 | else:
43 | reshaped_weights.append(rearrange(weight, "b c o i h w -> b (c h w) o i"))
44 | return WeightSpaceFeatures(reshaped_weights, bias)
45 |
46 |
47 | def unshape_wsfeat_symmetry(params, network_spec):
48 | """Reverse shape_params_symmetry"""
49 | weights, bias = params.weights, params.biases
50 | unreshaped_weights = []
51 | for weight, weight_spec in zip(weights, network_spec.weight_spec):
52 | if len(weight_spec.shape) == 2: # mlp weight matrix:
53 | unreshaped_weights.append(weight)
54 | else:
55 | _, _, h, w = weight_spec.shape
56 | unreshaped_weights.append(
57 | rearrange(weight, "b (c h w) o i -> b c o i h w", h=h, w=w)
58 | )
59 | return WeightSpaceFeatures(unreshaped_weights, bias)
60 |
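`set_init_` exists because several NFN layers sum the outputs of multiple linear maps; initializing them jointly makes the bound depend on the total fan-in, matching what a single fused layer would get. A sketch:

```py
from torch import nn

from nn.original_nfn.layers.layer_utils import set_init_

# Two branches whose outputs are summed: the total fan-in is 16 + 16 = 32.
branch_a, branch_b = nn.Linear(16, 8), nn.Linear(16, 8)
set_init_(branch_a, branch_b)  # uniform(-1/sqrt(32), 1/sqrt(32))
set_init_(branch_a, branch_b, init_type="kaiming_normal")  # std sqrt(2/32), zero bias
```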
--------------------------------------------------------------------------------
/nn/original_nfn/layers/misc_layers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | from torch import nn
5 |
6 | from nn.original_nfn.common import NetworkSpec, WeightSpaceFeatures
7 | from nn.original_nfn.layers.layer_utils import shape_wsfeat_symmetry
8 |
9 |
10 | class FlattenWeights(nn.Module):
11 | def __init__(self, network_spec):
12 | super().__init__()
13 | self.network_spec = network_spec
14 |
15 | def forward(self, wsfeat):
16 | wsfeat = shape_wsfeat_symmetry(wsfeat, self.network_spec)
17 | outs = []
18 | for i in range(len(self.network_spec)):
19 | w, b = wsfeat[i]
20 | outs.append(torch.flatten(w, start_dim=2).transpose(1, 2))
21 | outs.append(b.transpose(1, 2))
22 | return torch.cat(outs, dim=1) # (B, N, C)
23 |
24 |
25 | class UnflattenWeights(nn.Module):
26 | def __init__(self, network_spec: NetworkSpec):
27 | super().__init__()
28 | self.network_spec = network_spec
29 | self.num_wts, self.num_bs = [], []
30 | for weight_spec, bias_spec in zip(
31 | self.network_spec.weight_spec, self.network_spec.bias_spec
32 | ):
33 | self.num_wts.append(np.prod(weight_spec.shape))
34 | self.num_bs.append(np.prod(bias_spec.shape))
35 |
36 | def forward(self, x: torch.Tensor) -> WeightSpaceFeatures:
37 | # x.shape == (bs, num weights and biases, n_chan)
38 | n_chan = x.shape[2]
39 | out_weights, out_biases = [], []
40 | curr_idx = 0
41 | for i, (weight_spec, bias_spec) in enumerate(
42 | zip(self.network_spec.weight_spec, self.network_spec.bias_spec)
43 | ):
44 | num_wts, num_bs = self.num_wts[i], self.num_bs[i]
45 | # reshape to (bs, 1, *weight_spec.shape) where 1 is channels.
46 | wt = (
47 | x[:, curr_idx : curr_idx + num_wts]
48 | .transpose(1, 2)
49 | .reshape(-1, n_chan, *weight_spec.shape)
50 | )
51 | out_weights.append(wt)
52 | curr_idx += num_wts
53 | bs = (
54 | x[:, curr_idx : curr_idx + num_bs]
55 | .transpose(1, 2)
56 | .reshape(-1, n_chan, *bias_spec.shape)
57 | )
58 | out_biases.append(bs)
59 | curr_idx += num_bs
60 | return WeightSpaceFeatures(out_weights, out_biases)
61 |
62 |
63 | class LearnedScale(nn.Module):
64 | def __init__(self, network_spec: NetworkSpec, init_scale):
65 | super().__init__()
66 | self.weight_scales = nn.ParameterList()
67 | self.bias_scales = nn.ParameterList()
68 | for _ in range(len(network_spec)):
69 | self.weight_scales.append(
70 | nn.Parameter(torch.tensor(init_scale, dtype=torch.float32))
71 | )
72 | self.bias_scales.append(
73 | nn.Parameter(torch.tensor(init_scale, dtype=torch.float32))
74 | )
75 |
76 | def forward(self, wsfeat: WeightSpaceFeatures) -> WeightSpaceFeatures:
77 | out_weights, out_biases = [], []
78 | for i, (weight, bias) in enumerate(zip(wsfeat.weights, wsfeat.biases)):
79 | out_weights.append(weight * self.weight_scales[i])
80 | out_biases.append(bias * self.bias_scales[i])
81 | return WeightSpaceFeatures(out_weights, out_biases)
82 |
83 |
84 | class ResBlock(nn.Module):
85 | def __init__(self, base_layer, activation, dropout, norm):
86 | super().__init__()
87 | self.base_layer = base_layer
88 | self.activation = activation
89 | self.dropout = None
90 | if dropout > 0:
91 | self.dropout = TupleOp(nn.Dropout(dropout))
92 | self.norm = norm
93 |
94 | def forward(self, x: WeightSpaceFeatures) -> WeightSpaceFeatures:
95 | res = self.activation(self.base_layer(self.norm(x)))
96 | if self.dropout is not None:
97 | res = self.dropout(res)
98 | return x + res
99 |
100 |
101 | class StatFeaturizer(nn.Module):
102 | def forward(self, wsfeat: WeightSpaceFeatures) -> torch.Tensor:
103 | out = []
104 | for weight, bias in wsfeat:
105 | out.append(self.compute_stats(weight))
106 | out.append(self.compute_stats(bias))
107 | return torch.cat(out, dim=-1)
108 |
109 | def compute_stats(self, tensor: torch.Tensor) -> torch.Tensor:
110 | """Computes the statistics of the given tensor."""
111 | tensor = torch.flatten(tensor, start_dim=2) # (B, C, H*W)
112 | mean = tensor.mean(-1) # (B, C)
113 | var = tensor.var(-1) # (B, C)
114 | q = torch.tensor([0.0, 0.25, 0.5, 0.75, 1.0]).to(tensor.device)
115 | quantiles = torch.quantile(tensor, q, dim=-1) # (5, B, C)
116 | return torch.stack([mean, var, *quantiles], dim=-1) # (B, C, 7)
117 |
118 | @staticmethod
119 | def get_num_outs(network_spec):
120 | """Returns the number of outputs of the StatFeaturizer layer."""
121 | return 2 * len(network_spec) * 7
122 |
123 |
124 | class TupleOp(nn.Module):
125 | def __init__(self, op):
126 | super().__init__()
127 | self.op = op
128 |
129 | def forward(self, wsfeat: WeightSpaceFeatures) -> WeightSpaceFeatures:
130 | out_weights = [self.op(w) for w in wsfeat.weights]
131 | out_bias = [self.op(b) for b in wsfeat.biases]
132 | return WeightSpaceFeatures(out_weights, out_bias)
133 |
134 | def __repr__(self):
135 | return f"TupleOp({self.op})"
136 |
137 |
138 | class CrossAttnEncoder(nn.Module):
139 | def __init__(self, network_spec, channels, num_latents):
140 | super().__init__()
141 | self.embeddings = nn.Parameter(torch.randn(num_latents, channels))
142 | self.flatten = FlattenWeights(network_spec)
143 |
144 | def forward(self, params):
145 | flat_params = self.flatten(params)
146 | # (B, num_latents, C)
147 | return F.scaled_dot_product_attention(self.embeddings, flat_params, flat_params)
148 |
149 |
150 | class CrossAttnDecoder(nn.Module):
151 | def __init__(self, network_spec, channels, num_params):
152 | super().__init__()
153 | self.embeddings = nn.Parameter(torch.randn(num_params, channels))
154 | self.unflatten = UnflattenWeights(network_spec)
155 |
156 | def forward(self, latents):
157 | # latents: (B, num_latents, C)
158 | return self.unflatten(
159 | F.scaled_dot_product_attention(self.embeddings, latents, latents)
160 | )
161 |
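`FlattenWeights` and `UnflattenWeights` are inverses once the spec has concrete shapes; a sketch using `set_all_dims=True` (needed because `np.prod` over `-1` placeholders would be meaningless):

```py
import torch

from nn.original_nfn.common import WeightSpaceFeatures, network_spec_from_wsfeat
from nn.original_nfn.layers import FlattenWeights, UnflattenWeights

wsfeat = WeightSpaceFeatures(
    [torch.randn(2, 4, 32, 2), torch.randn(2, 4, 3, 32)],
    [torch.randn(2, 4, 32), torch.randn(2, 4, 3)],
)
spec = network_spec_from_wsfeat(wsfeat, set_all_dims=True)
flat = FlattenWeights(spec)(wsfeat)  # (B, N, C) with N = 64 + 32 + 96 + 3
restored = UnflattenWeights(spec)(flat)
assert torch.equal(restored.weights[0], wsfeat.weights[0])
```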
--------------------------------------------------------------------------------
/nn/original_nfn/layers/regularize.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from einops import rearrange
4 | from einops.layers.torch import Rearrange
5 | from torch import nn
6 |
7 | from nn.original_nfn.common import NetworkSpec, WeightSpaceFeatures
8 | from nn.original_nfn.layers.layer_utils import (
9 | shape_wsfeat_symmetry,
10 | unshape_wsfeat_symmetry,
11 | )
12 |
13 |
14 | class ChannelDropout(nn.Module):
15 | def __init__(self, dropout):
16 | super().__init__()
17 | self.dropout = dropout
18 | self.matrix_dropout = nn.Dropout2d(dropout)
19 | self.bias_dropout = nn.Dropout(dropout)
20 |
21 | def forward(self, x: WeightSpaceFeatures) -> WeightSpaceFeatures:
22 | weights = [self.process_matrix(w) for w in x.weights]
23 | bias = [self.bias_dropout(b) for b in x.biases]
24 | return WeightSpaceFeatures(weights, bias)
25 |
26 | def process_matrix(self, mat):
27 | shape = mat.shape
28 | is_conv = len(shape) > 4
29 | if is_conv:
30 | _, _, _, _, h, w = shape
31 | mat = rearrange(mat, "b c o i h w -> b (c h w) o i")
32 | mat = self.matrix_dropout(mat)
33 | if is_conv:
34 | mat = rearrange(mat, "b (c h w) o i -> b c o i h w", h=h, w=w)
35 | return mat
36 |
37 |
38 | class SimpleLayerNorm(nn.Module):
39 | def __init__(self, network_spec, channels):
40 | super().__init__()
41 | self.network_spec = network_spec
42 | self.channels = channels
43 | self.w_norms, self.v_norms = nn.ModuleList(), nn.ModuleList()
44 | for i in range(len(network_spec)):
45 | eff_channels = int(
46 | channels * np.prod(network_spec.weight_spec[i].shape[2:])
47 | )
48 | self.w_norms.append(ChannelLayerNorm(eff_channels))
49 | self.v_norms.append(ChannelLayerNorm(channels))
50 |
51 | def forward(self, wsfeat: WeightSpaceFeatures) -> WeightSpaceFeatures:
52 | wsfeat = shape_wsfeat_symmetry(wsfeat, self.network_spec)
53 | out_weights, out_biases = [], []
54 | for i in range(len(self.network_spec)):
55 | weight, bias = wsfeat[i]
56 | out_weights.append(self.w_norms[i](weight))
57 | out_biases.append(self.v_norms[i](bias))
58 | return unshape_wsfeat_symmetry(
59 | WeightSpaceFeatures(out_weights, out_biases), self.network_spec
60 | )
61 |
62 | def __repr__(self):
63 | return f"SimpleLayerNorm(channels={self.channels})"
64 |
65 |
66 | class ParamLayerNorm(nn.Module):
67 | def __init__(self, network_spec: NetworkSpec, channels):
68 | # TODO: This doesn't work for convs yet.
69 | super().__init__()
70 | self.n_in, self.n_out = network_spec.get_io()
71 | self.channels = channels
72 | for i in range(len(network_spec)):
73 | if i == 0:
74 | w_shape = (channels, self.n_in)
75 | v_shape = (channels,)
76 | elif i == len(network_spec) - 1:
77 | w_shape = (self.n_out, channels)
78 | v_shape = (channels, self.n_out)
79 | else:
80 | w_shape = (channels,)
81 | v_shape = (channels,)
82 | self.add_module(f"norm{i}_w", nn.LayerNorm(normalized_shape=w_shape))
83 | self.add_module(f"norm{i}_v", nn.LayerNorm(normalized_shape=v_shape))
84 |
85 | def forward(self, wsfeat: WeightSpaceFeatures) -> WeightSpaceFeatures:
86 | out_weights, out_biases = [], []
87 | for i, (weight, bias) in enumerate(wsfeat):
88 | w_norm = getattr(self, f"norm{i}_w")
89 | v_norm = getattr(self, f"norm{i}_v")
90 | if i == 0:
91 | out_weights.append(w_norm(weight.transpose(-3, -2)).transpose(-3, -2))
92 | out_biases.append(v_norm(bias.transpose(-1, -2)).transpose(-1, -2))
93 | elif i == len(wsfeat) - 1:
94 | out_weights.append(w_norm(weight.transpose(-3, -1)).transpose(-3, -1))
95 | out_biases.append(v_norm(bias))
96 | else:
97 | out_weights.append(w_norm(weight.transpose(-3, -1)).transpose(-3, -1))
98 | out_biases.append(v_norm(bias.transpose(-1, -2)).transpose(-1, -2))
99 | return WeightSpaceFeatures(out_weights, out_biases)
100 |
101 |
102 | class ChannelLayerNorm(nn.Module):
103 | def __init__(self, channels, eps=1e-5):
104 | super().__init__()
105 | self.eps = eps
106 | self.gamma = nn.Parameter(torch.ones(channels))
107 | self.beta = nn.Parameter(torch.zeros(channels))
108 | self.channels_last = Rearrange("b c ... -> b ... c")
109 | self.channels_first = Rearrange("b ... c -> b c ...")
110 |
111 | def forward(self, x):
112 | # x.shape = (b, c, ...)
113 | x = self.channels_last(x)
114 | mean = x.mean(dim=-1, keepdim=True)
115 | std = x.std(dim=-1, keepdim=True)
116 | x = (x - mean) / (std + self.eps)
117 | out = self.channels_first(x * self.gamma + self.beta)
118 | return out
119 |
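`SimpleLayerNorm` normalizes over the channel axis only (via `ChannelLayerNorm`), so it preserves the weight-space permutation symmetry; a sketch, again on a 2 → 32 → 3 MLP:

```py
import torch

from nn.original_nfn.common import WeightSpaceFeatures, network_spec_from_wsfeat
from nn.original_nfn.layers import SimpleLayerNorm

wsfeat = WeightSpaceFeatures(
    [torch.randn(2, 8, 32, 2), torch.randn(2, 8, 3, 32)],
    [torch.randn(2, 8, 32), torch.randn(2, 8, 3)],
)
norm = SimpleLayerNorm(network_spec_from_wsfeat(wsfeat), channels=8)
out = norm(wsfeat)
assert out.weights[0].shape == (2, 8, 32, 2)  # shapes unchanged
```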
--------------------------------------------------------------------------------
/nn/probe_features.py:
--------------------------------------------------------------------------------
1 | import hydra
2 | import torch
3 | import torch.nn as nn
4 | from einops.layers.torch import Rearrange
5 |
6 | from nn.inr import make_functional, params_to_tensor, wrap_func
7 |
8 |
9 | class GraphProbeFeatures(nn.Module):
10 | def __init__(self, d_in, num_inputs, inr_model, input_init=None, proj_dim=None):
11 | super().__init__()
12 | inr = hydra.utils.instantiate(inr_model)
13 | fmodel, params = make_functional(inr)
14 |
15 | vparams, vshapes = params_to_tensor(params)
16 | self.sirens = torch.vmap(wrap_func(fmodel, vshapes))
17 |
18 | inputs = (
19 | input_init
20 | if input_init is not None
21 | else 2 * torch.rand(1, num_inputs, d_in) - 1
22 | )
23 | self.inputs = nn.Parameter(inputs, requires_grad=input_init is None)
24 |
25 | self.reshape_weights = Rearrange("b i o 1 -> b (o i)")
26 | self.reshape_biases = Rearrange("b o 1 -> b o")
27 |
28 | self.proj_dim = proj_dim
29 | if proj_dim is not None:
30 | self.proj = nn.ModuleList(
31 | [
32 | nn.Sequential(
33 | nn.Linear(num_inputs, proj_dim),
34 | nn.LayerNorm(proj_dim),
35 | )
36 | for _ in range(inr.num_layers + 1)
37 | ]
38 | )
39 |
40 | def forward(self, weights, biases):
41 | weights = [self.reshape_weights(w) for w in weights]
42 | biases = [self.reshape_biases(b) for b in biases]
43 | params_flat = torch.cat(
44 | [w_or_b for p in zip(weights, biases) for w_or_b in p], dim=-1
45 | )
46 |
47 | out = self.sirens(params_flat, self.inputs.expand(params_flat.shape[0], -1, -1))
48 | if self.proj_dim is not None:
49 | out = [proj(out[i].permute(0, 2, 1)) for i, proj in enumerate(self.proj)]
50 | out = torch.cat(out, dim=1)
51 | return out
52 | else:
53 | out = torch.cat(out, dim=-1)
54 | return out.permute(0, 2, 1)
55 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.isort]
2 | profile = "black"
3 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | torch>=1.12.1
3 | torchvision
4 | tqdm
5 | wandb
6 | random-fourier-features-pytorch
7 | scikit-learn
8 | matplotlib
9 | seaborn
10 | pytest
11 | # imported across nn/, experiments/, and tests/
12 | einops
13 | hydra-core
14 | omegaconf
15 | torch-geometric
16 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from io import open
2 | from os import path
3 |
4 | from setuptools import find_packages, setup
5 |
6 | here = path.abspath(path.dirname(__file__))
7 |
8 | # get the long description from the README.md file
9 | with open(path.join(here, "README.md"), encoding="utf-8") as f:
10 | long_description = f.read()
11 |
12 |
13 | # get reqs
14 | def requirements():
15 | list_requirements = []
16 | with open("requirements.txt") as f:
17 | for line in f:
18 | list_requirements.append(line.rstrip())
19 | return list_requirements
20 |
21 |
22 | setup(
23 | name="neural-graphs",
24 | version="0.0.1", # Required
25 | description="Graph Neural Networks for Learning Equivariant Representations of Neural Networks", # Optional
26 | long_description=long_description, # Optional
27 | long_description_content_type="text/markdown", # Optional
28 | url="", # Optional
29 | author="", # Optional
30 | author_email="", # Optional
31 | packages=find_packages(exclude=["contrib", "docs", "tests"]),
32 | python_requires=">=3.9",
33 | install_requires=requirements(), # Optional
34 | )
35 |
--------------------------------------------------------------------------------
/tests/README.md:
--------------------------------------------------------------------------------
1 | ### Test invariance
2 |
3 | To test the model invariance/equivariance on INR tasks or CNN generalization, run the
4 | following from the repository root:
5 | ```sh
6 | pytest tests/test_inr_invariance.py
7 | pytest tests/test_inr_equivariance.py
8 | pytest tests/test_cnn_invariance.py
9 | ```
10 |
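All three suites can also be run in one go:

```sh
pytest tests
```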
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/tests/test_cnn_invariance.py:
--------------------------------------------------------------------------------
1 | import hydra
2 | import pytest
3 | import torch
4 | from omegaconf import OmegaConf
5 |
6 | from experiments.utils import set_seed
7 | from tests.utils import wb_to_batch, permute_weights_biases
8 |
9 | set_seed(42)
10 |
11 | OmegaConf.register_new_resolver("prod", lambda x, y: x * y)
12 |
13 |
14 | @pytest.fixture
15 | def model():
16 | with hydra.initialize(
17 | version_base=None, config_path="../experiments/cnn_generalization/configs"
18 | ):
19 | cfg = hydra.compose(config_name="base")
20 | cfg.data.flattening_method = None
21 | cfg.data._max_kernel_height = 5
22 | cfg.data._max_kernel_width = 5
23 | model = hydra.utils.instantiate(cfg.model)
24 | return model
25 |
26 |
27 | def test_model_invariance(model):
28 | batch_size = 4
29 | layer_layout = [3, 16, 32, 10]
30 | dims = [25, 25, 1]
31 | pad = (12, 12)
32 | weights = tuple(
33 | torch.randn(batch_size, layer_layout[i], layer_layout[i + 1], dims[i])
34 | for i in range(len(layer_layout) - 1)
35 | )
36 | biases = tuple(
37 | torch.randn(batch_size, layer_layout[i], 1) for i in range(1, len(layer_layout))
38 | )
39 |
40 | batch = wb_to_batch(weights, biases, layer_layout, pad=pad)
41 | out = model(batch)
42 |
43 | # Generate random permutations
44 | permutations = [
45 | torch.randperm(layer_layout[i]) for i in range(1, len(layer_layout) - 1)
46 | ]
47 |
48 | perm_weights, perm_biases = permute_weights_biases(weights, biases, permutations)
49 | perm_batch = wb_to_batch(perm_weights, perm_biases, layer_layout, pad=pad)
50 |
51 | out_perm = model(perm_batch)
52 |
53 | assert torch.allclose(out, out_perm, atol=1e-5, rtol=0)
54 |
--------------------------------------------------------------------------------
/tests/test_inr_equivariance.py:
--------------------------------------------------------------------------------
1 | import hydra
2 | import pytest
3 | import torch
4 |
5 | from experiments.utils import set_seed
6 | from tests.utils import permute_weights_biases
7 |
8 | set_seed(42)
9 |
10 |
11 | @pytest.fixture
12 | def model():
13 | with hydra.initialize(
14 | version_base=None, config_path="../experiments/style_editing/configs"
15 | ):
16 | cfg = hydra.compose(config_name="base")
17 | cfg.data.stats = {
18 | "weights_mean": None,
19 | "weights_std": None,
20 | "biases_mean": None,
21 | "biases_std": None,
22 | }
23 | model = hydra.utils.instantiate(cfg.model, layer_layout=(2, 32, 32, 32, 32, 3))
24 | return model
25 |
26 |
27 | def test_model_equivariance(model):
28 | batch_size = 4
29 | layer_layout = [2, 32, 32, 32, 32, 3]
30 | weights = tuple(
31 | torch.randn(batch_size, layer_layout[i], layer_layout[i + 1], 1)
32 | for i in range(len(layer_layout) - 1)
33 | )
34 | biases = tuple(
35 | torch.randn(batch_size, layer_layout[i], 1) for i in range(1, len(layer_layout))
36 | )
37 | out_weights, out_biases = model((weights, biases))
38 |
39 | # Generate random permutations
40 | permutations = [
41 | torch.randperm(layer_layout[i]) for i in range(1, len(layer_layout) - 1)
42 | ]
43 |
44 | perm_weights, perm_biases = permute_weights_biases(weights, biases, permutations)
45 | out_perm_weights, out_perm_biases = model((perm_weights, perm_biases))
46 | perm_out_weights, perm_out_biases = permute_weights_biases(
47 | out_weights, out_biases, permutations
48 | )
49 |
50 | for i in range(len(out_weights)):
51 | assert torch.allclose(
52 | perm_out_weights[i], out_perm_weights[i], atol=1e-5, rtol=1e-8
53 | )
54 | assert torch.allclose(
55 | perm_out_biases[i], out_perm_biases[i], atol=1e-4, rtol=1e-1
56 | )
57 |
--------------------------------------------------------------------------------
/tests/test_inr_invariance.py:
--------------------------------------------------------------------------------
1 | import hydra
2 | import pytest
3 | import torch
4 |
5 | from experiments.utils import set_seed
6 | from tests.utils import permute_weights_biases
7 |
8 | set_seed(42)
9 |
10 |
11 | @pytest.fixture
12 | def model():
13 | with hydra.initialize(
14 | version_base=None, config_path="../experiments/inr_classification/configs"
15 | ):
16 | cfg = hydra.compose(config_name="base")
17 | cfg.data.stats = None
18 | model = hydra.utils.instantiate(cfg.model, layer_layout=(2, 32, 32, 32, 32, 3))
19 | return model
20 |
21 |
22 | def test_model_invariance(model):
23 | batch_size = 4
24 | layer_layout = [2, 32, 32, 32, 32, 3]
25 | weights = tuple(
26 | torch.randn(batch_size, layer_layout[i], layer_layout[i + 1], 1)
27 | for i in range(len(layer_layout) - 1)
28 | )
29 | biases = tuple(
30 | torch.randn(batch_size, layer_layout[i], 1) for i in range(1, len(layer_layout))
31 | )
32 | out = model((weights, biases))
33 |
34 | # Generate random permutations
35 | permutations = [
36 | torch.randperm(layer_layout[i]) for i in range(1, len(layer_layout) - 1)
37 | ]
38 |
39 | perm_weights, perm_biases = permute_weights_biases(weights, biases, permutations)
40 | out_perm = model((perm_weights, perm_biases))
41 |
42 | assert torch.allclose(out, out_perm, atol=1e-5, rtol=0)
43 |
--------------------------------------------------------------------------------
/tests/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch_geometric
4 |
5 |
6 | def permute_weights_biases(weights, biases, permutations):
7 | # Permute each layer's output axis with its own permutation and its input
8 | # axis with the previous layer's permutation; the network input axis
9 | # (layer 0) and output axis (last layer) stay fixed.
10 | perm_weights = []
11 | for i, w in enumerate(weights):
12 | if i > 0:
13 | w = w[:, permutations[i - 1], :, :]
14 | if i < len(weights) - 1:
15 | w = w[:, :, permutations[i], :]
16 | perm_weights.append(w)
17 | perm_weights = tuple(perm_weights)
18 |
19 | perm_biases = tuple(
20 | biases[i][:, permutations[i], :] if i < len(biases) - 1 else biases[i]
21 | for i in range(len(biases))
22 | )
23 | return perm_weights, perm_biases
24 |
25 |
26 | def wb_to_batch(weights, biases, layer_layout, pad):
27 | batch_size = weights[0].shape[0]
28 | x = torch.cat(
29 | [
30 | torch.zeros(
31 | (biases[0].shape[0], layer_layout[0], 1),
32 | dtype=biases[0].dtype,
33 | device=biases[0].device,
34 | ),
35 | *biases,
36 | ],
37 | dim=1,
38 | )
39 | cumsum_layout = [0] + torch.tensor(layer_layout).cumsum(dim=0).tolist()
40 | edge_index = torch.cat(
41 | [
42 | torch.cartesian_prod(
43 | torch.arange(cumsum_layout[i], cumsum_layout[i + 1]),
44 | torch.arange(cumsum_layout[i + 1], cumsum_layout[i + 2]),
45 | ).T
46 | for i in range(len(cumsum_layout) - 2)
47 | ],
48 | dim=1,
49 | )
50 | edge_attr = torch.cat(
51 | [
52 | weights[0].flatten(1, 2),
53 | weights[1].flatten(1, 2),
54 | F.pad(weights[-1], pad=pad).flatten(1, 2),  # pad the last layer's smaller kernel dim to match
55 | ],
56 | dim=1,
57 | )
58 | batch = torch_geometric.data.Batch.from_data_list(
59 | [
60 | torch_geometric.data.Data(
61 | x=x[i],
62 | edge_index=edge_index,
63 | edge_attr=edge_attr[i],
64 | layer_layout=layer_layout,
65 | conv_mask=[1 if w.shape[-1] > 1 else 0 for w in weights],
66 | fmap_size=1,
67 | )
68 | for i in range(batch_size)
69 | ]
70 | )
71 | return batch
72 |
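For reference, a small sketch of what `permute_weights_biases` does: a hidden-neuron permutation acts jointly on a layer's output axis, its bias, and the next layer's input axis, which is exactly the symmetry the tests above exploit.

```py
import torch

from tests.utils import permute_weights_biases

layer_layout = [2, 4, 3]  # one hidden layer of width 4
weights = tuple(
    torch.randn(1, layer_layout[i], layer_layout[i + 1], 1) for i in range(2)
)
biases = tuple(torch.randn(1, layer_layout[i], 1) for i in range(1, 3))
perm = [torch.randperm(4)]  # one permutation per hidden layer

perm_w, perm_b = permute_weights_biases(weights, biases, perm)
assert torch.equal(perm_b[0], biases[0][:, perm[0], :])
assert torch.equal(perm_w[1], weights[1][:, perm[0], :, :])
```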
--------------------------------------------------------------------------------